In addition to specializing the mlir::Op C++ template, MLIR also supports defining operations in a declarative manner. This is achieved via the Operation Definition Specification (ODS) framework. Facts regarding an operation are specified concisely into a TableGen record, which will be expanded into an equivalent mlir::Op C++ template specialization at compile time.
// Command-line option `-emit=<ast|mlir>` selecting which compiler stage to
// output: the Toy AST dump or the MLIR dump.
static cl::opt<enum Action> emitAction( "emit", cl::desc("Select the kind of output desired"), cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename){ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code ec = fileOrErr.getError()) { llvm::errs() << "Could not open input file: " << ec.message() << "\n"; returnnullptr; } auto buffer = fileOrErr.get()->getBuffer(); LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename)); Parser parser(lexer); return parser.parseModule(); }
// If the type is a function type, it contains the input and result types of // this operation. if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) { if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc, result.operands)) return mlir::failure(); result.addTypes(funcType.getResults()); return mlir::success(); }
// Otherwise, the parsed type is the type of both operands and results. if (parser.resolveOperands(operands, type, result.operands)) return mlir::failure(); result.addTypes(type); return mlir::success(); }
这里涉及到mlir::OpAsmParser和mlir::OperationState
文档分别写着:
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors etc.
It has an intentionally high-level API that is designed to reduce/constrain syntax innovation in individual operations.
For example, consider an op like this:
%x = load %p[%1, %2] : memref<…>
The “%x = load” tokens are already parsed and therefore invisible to the custom op parser. This can be supported by calling parseOperandList to parse the %p, then calling parseOperandList with a SquareDelimiter to parse the indices, then calling parseColonTypeList to parse the result type.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This object is a large and heavy weight object meant to be used as a temporary object on the stack. It is generally unwise to put this in a collection.
intdumpMLIR(){ mlir::MLIRContext context; // Load our Dialect in this MLIR Context. context.getOrLoadDialect<mlir::toy::ToyDialect>();
// Handle '.toy' input to the compiler. if (inputType != InputType::MLIR && !llvm::StringRef(inputFilename).ends_with(".mlir")) { auto moduleAST = parseInputFile(inputFilename); if (!moduleAST) return6; mlir::OwningOpRef<mlir::ModuleOp> module = mlirGen(context, *moduleAST); if (!module) return1;
structSimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> { /// We register this pattern to match every toy.transpose in the IR. /// The "benefit" is used by the framework to order the patterns and process /// them in order of profitability. SimplifyRedundantTranspose(mlir::MLIRContext *context) : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter)constoverride{ // Look through the input of the current transpose. mlir::Value transposeInput = op.getOperand(); TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
// Input defined by another transpose? If not, no match. if (!transposeInputOp) returnfailure();
// Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); returnsuccess(); } };
如果运行程序时添加-opt,dumpMLIR()就会进入下面这个程序片段
1 2 3 4 5 6 7 8 9 10 11
// When -opt is set, run the pass pipeline over the freshly-built module.
if (enableOpt) {
  mlir::PassManager pm(module.get()->getName());
  // Apply any generic pass manager command line options and run the pipeline.
  if (mlir::failed(mlir::applyPassManagerCLOptions(pm)))
    return 4;

  // Add a run of the canonicalizer to optimize the mlir module. Nested so it
  // runs on each toy.func rather than on the whole module.
  pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
  if (mlir::failed(pm.run(*module)))
    return 4;
}
/// All call operations within toy can be inlined. boolisLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned)constfinal{ returntrue; }
/// All operations within toy can be inlined. boolisLegalToInline(Operation *, Region *, bool, IRMapping &)constfinal{ returntrue; }
// All functions within toy can be inlined. boolisLegalToInline(Region *, Region *, bool, IRMapping &)constfinal{ returntrue; }
// When -opt is set, run the inliner followed by per-function cleanups.
if (enableOpt) {
  mlir::PassManager pm(module.get()->getName());
  // Apply any generic pass manager command line options and run the pipeline.
  if (mlir::failed(mlir::applyPassManagerCLOptions(pm)))
    return 4;

  // Inline all functions into main and then delete them.
  pm.addPass(mlir::createInlinerPass());

  // Now that there is only one function, we can infer the shapes of each of
  // the operations. These passes are nested so they run on each toy.func.
  mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
  optPM.addPass(mlir::toy::createShapeInferencePass());
  optPM.addPass(mlir::createCanonicalizerPass());
  optPM.addPass(mlir::createCSEPass());

  if (mlir::failed(pm.run(*module)))
    return 4;
}
感觉有些细节不太清楚,应该还需要折回来看
Chapter 5
开启Opt后额外添加了循环融合(LoopFusion)和AffineScalarReplacement(标量替换,消除冗余的load/store)两个pass
1 2 3 4
// When optimization is enabled, add the affine-level passes: loop fusion and
// scalar replacement (forwarding stores to loads to remove redundant memory
// traffic).
if (enableOpt) { optPM.addPass(mlir::affine::createLoopFusionPass()); optPM.addPass(mlir::affine::createAffineScalarReplacementPass()); }
voidToyToAffineLoweringPass::runOnOperation(){ // The first thing to define is the conversion target. This will define the // final target for this lowering. ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for // this lowering. In our case, we are lowering to a combination of the // `Affine`, `Arith`, `Func`, and `MemRef` dialects. target.addLegalDialect<affine::AffineDialect, BuiltinDialect, arith::ArithDialect, func::FuncDialect, memref::MemRefDialect>();
// We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want // a partial lowering, we explicitly mark the Toy operations that don't want // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands // to be updated though (as we convert from TensorType to MemRefType), so we // only treat it as `legal` if its operands are legal. target.addIllegalDialect<toy::ToyDialect>(); target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) { return llvm::none_of(op->getOperandTypes(), [](Type type) { return llvm::isa<TensorType>(type); }); });
// Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the Toy operations. RewritePatternSet patterns(&getContext()); patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering, PrintOpLowering, ReturnOpLowering, TransposeOpLowering>( &getContext());
// With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); }
这里的Add和Mul用了同一个Template BinaryOpLowering,有意思🤔
1 2
// Both binary ops share one lowering template, parameterized on the Toy op
// being matched and the arith op it lowers to.
using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter)constfinal{ auto loc = op->getLoc(); lowerOpToLoops(op, operands, rewriter, [loc](OpBuilder &builder, ValueRange memRefOperands, ValueRange loopIvs) { // Generate an adaptor for the remapped operands of the // BinaryOp. This allows for using the nice named accessors // that are generated by the ODS. typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
// Generate loads for the element of 'lhs' and 'rhs' at the // inner loop. auto loadedLhs = builder.create<affine::AffineLoadOp>( loc, binaryAdaptor.getLhs(), loopIvs); auto loadedRhs = builder.create<affine::AffineLoadOp>( loc, binaryAdaptor.getRhs(), loopIvs);
// Create the binary operation performed on the loaded // values. return builder.create<LoweredBinaryOp>(loc, loadedLhs, loadedRhs); }); returnsuccess(); } };
staticvoidlowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration){ auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin())); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// Create a nest of affine loops, with one loop per dimension of the shape. // The buildAffineLoopNest function takes a callback that is used to construct // the body of the innermost loop given a builder, a location and a range of // loop induction variables. SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0); SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1); affine::buildAffineLoopNest( rewriter, loc, lowerBounds, tensorType.getShape(), steps, [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) { // Call the processing function with the rewriter, the memref operands, // and the loop induction variables. This function will return the value // to store at the current index. Value valueToStore = processIteration(nestedBuilder, operands, ivs); nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc, ivs); });
// Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); }
buildAffineLoopNest()
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only another loop and a terminator.
The loops iterate from “lbs” to “ubs” with “steps”. The body of the innermost loop is populated by calling “bodyBuilderFn” and providing it with an OpBuilder, a Location and a list of loop induction variables.
// Provide extra utility definitions on the C++ operation class definition.
let extraClassDeclaration = [{
  bool hasOperand() { return getNumOperands() != 0; }
}];
Chapter 6
加上了实现LLVM的Pass
1 2 3 4 5 6 7 8
// When lowering all the way to LLVM, finish the pipeline with the LLVM
// lowering pass plus debug-scope emission.
if (isLoweringToLLVM) {
  // Finish lowering the toy IR to the LLVM dialect.
  pm.addPass(mlir::toy::createLowerToLLVMPass());
  // This is necessary to have line tables emitted and basic
  // debugger working. In the future we will add proper debug information
  // emission directly from our frontend.
  pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass());
}
std::unique_ptr<StructAST> parseStruct(){ auto loc = lexer.getLastLocation(); lexer.consume(tok_struct); if (lexer.getCurToken() != tok_identifier) returnparseError<StructAST>("name", "in struct definition"); std::string name(lexer.getId()); lexer.consume(tok_identifier);
// Parse: '{' if (lexer.getCurToken() != '{') returnparseError<StructAST>("{", "in struct definition"); lexer.consume(Token('{'));
// Parse: decl+ std::vector<std::unique_ptr<VarDeclExprAST>> decls; do { auto decl = parseDeclaration(/*requiresInitializer=*/false); if (!decl) returnnullptr; decls.push_back(std::move(decl));
if (lexer.getCurToken() != ';') returnparseError<StructAST>(";", "after variable in struct definition"); lexer.consume(Token(';')); } while (lexer.getCurToken() != '}');
/// Helper function to signal errors while parsing, it takes an argument /// indicating the expected token and another argument giving more context. /// Location is retrieved from the lexer to enrich the error message. template <typename R, typename T, typename U = constchar *> std::unique_ptr<R> parseError(T &&expected, U &&context = "") { auto curToken = lexer.getCurToken(); llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", " << lexer.getLastLocation().col << "): expected '" << expected << "' " << context << " but has Token " << curToken; if (isprint(curToken)) llvm::errs() << " '" << (char)curToken << "'"; llvm::errs() << "\n"; returnnullptr; } };