Honestly, the MLIR tutorial broke me 😅 A beginners' Tutorial that expects you to build the entire LLVM Project. Never mind that the LLVM Project download is already 200-300 MB as a tarball (I don't dare imagine the size of a straight git clone) and unpacks to around 700 MB; building it took my potato laptop half an hour (and it seems the build failed in the end anyway 🤣 these days I take a detour whenever I see MSVC). Can't take it, really can't afford to play 😅

(Note: would some kind soul please donate an online build environment 😆 The Tutorial also never clearly says you have to build first; I only realized after flipping through two chapters)

But the work still has to get done. Since I'd come this far anyway, I went and learned LLVM first, worked through the Kaleidoscope Tutorial (study notes: link), and then circled back to MLIR.

Skimming the table of contents, it turns out that as long as you can find the right dependencies you don't actually need to build the LLVM Project, and Debian's sid repository carries the LLVM dependencies.
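A minimal setup sketch. The package names below (llvm-18-dev, libmlir-18-dev, mlir-18-tools, clang-18) are my assumption of the relevant Debian packages; check `apt search mlir` on your release before copying:

```shell
# Install prebuilt LLVM/MLIR 18 development packages instead of
# building llvm-project from source.
sudo apt install llvm-18-dev libmlir-18-dev mlir-18-tools clang-18

# Headers then live under /usr/lib/llvm-18/include, which is the -I
# path passed to mlir-tblgen-18 later in this post.
```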

That'll work! 😍 Time to get going.

Environment Setup

I've extracted the project files from the LLVM Project and uploaded them to GitHub for reference:

mocusez/CMake_MLIR_Toy

Pitfalls

error: no type named ‘LogicalResult’ in namespace ‘llvm’; did you mean ‘mlir::LogicalResult’?

You need the code from around LLVM 18's initial release. LLVM 19 changed the MLIR example code, moving LogicalResult from the MLIR namespace into LLVM, which triggers this error when built against LLVM 18.

The commit that caused it: mlir/LogicalResult: move into llvm

CMake errors around mlir-tblgen

Since the files were pulled out of the LLVM Project, the CMake macros they depend on are missing, hence the errors.

This is the part that generates C++ source from the .td files, which needs mlir-tblgen.

My workaround was to run a shell script:

mlir-tblgen-18 -gen-op-decls -I /usr/lib/llvm-18/include Ops.td > Ops.h.inc
mlir-tblgen-18 -gen-op-defs -I /usr/lib/llvm-18/include Ops.td > Ops.cpp.inc
mlir-tblgen-18 -gen-dialect-decls -I /usr/lib/llvm-18/include Ops.td > Dialect.h.inc
mlir-tblgen-18 -gen-dialect-defs -I /usr/lib/llvm-18/include Ops.td > Dialect.cpp.inc

Update 2024-08-08:

I later found that MLIR's CMake modules already define the macros; you just have to include them in your CMakeLists.txt:

list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")

include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(HandleLLVMOptions)

Learning Resources

Zhihu & Bilibili user 法斯特豪斯 wrote a Chinese overview of the Toy Tutorial, which gives a good feel for the Tutorial's structure:

There is also a Bilibili video by 先进编译实验室: 人工智能编译器MLIR-官方入门教程讲解

MLIR Toy Tutorial概述

But to be honest, I don't think I've really grasped the details yet, even after going through the whole Toy Tutorial 😓

The figure below, captured from the video, is essentially MLIR's complete end-to-end flow.

Study Tips

  1. Make good use of VS Code's code search; the Toy Tutorial often doesn't say where code is actually implemented 😢
  2. Step through the code with GDB/LLDB to better understand the flow of execution
  3. After finishing the Toy Tutorial, look at the practical examples under mlir/test
  4. Use godbolt.org to inspect the MLIR that gets generated

VS Code Debug Setup

Install GDB:

apt install gdb

You can write your own launch.json based on the one in this repository.

LLDB would probably be nicer 😢 but I don't know yet how to hook it up to VS Code.

{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Debug Toy1",
            "type": "cppdbg",
            "request": "launch",
            "program": "${workspaceFolder}/build/Ch1/toyc-ch1",
            "args": ["./Examples/Toy/Ch1/ast.toy","-emit=ast"],
            "stopAtEntry": false,
            "cwd": "${workspaceFolder}",
            "environment": [],
            "externalConsole": false,
            "MIMode": "gdb",
            "setupCommands": [
                {
                    "description": "Enable pretty-printing for gdb",
                    "text": "-enable-pretty-printing",
                    "ignoreFailures": true
                },
                {
                    "description": "Set Disassembly Flavor to Intel",
                    "text": "-gdb-set disassembly-flavor intel",
                    "ignoreFailures": true
                }
            ]
        }

    ]
}

Study Notes

MLIR is a further abstraction above LLVM IR; through Dialects it can support more complex computational structures (such as Tensor).

However complex things get, everything eventually becomes LLVM IR for the final end-to-end flow.
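To make that concrete, a small illustrative snippet in the Toy dialect (shapes and syntax follow the tutorial's printed output; treat it as a sketch rather than verbatim tool output):

```mlir
// toy.constant materializes a tensor literal; toy.transpose consumes it.
// Tensors are exactly the kind of structure LLVM IR has no notion of.
toy.func @main() {
  %0 = toy.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf64>
  %1 = toy.transpose(%0 : tensor<2x2xf64>) to tensor<2x2xf64>
  toy.print %1 : tensor<2x2xf64>
  toy.return
}
```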

The Toy Tutorial doesn't cover the Lexer, Parser, AST, or the end-to-end backend in any detail (presumably assuming everyone has done the Kaleidoscope Tutorial); it focuses on generating and optimizing MLIR. Doing that requires a Dialect, which in turn means learning the ODS (Operation Definition Specification) and Declarative Rewrite Rule (DRR) syntax. The seven chapters of the Tutorial only scratch the surface of these; there are many details that will take real time to learn.

This writeup is a first impression after two days with MLIR; there is plenty more to learn 😀

Chapter 1

This chapter is really LLVM territory: AST generation. It belongs to the LLVM side (the headers it includes are basically all LLVM's), following the design of LLVM's Kaleidoscope (the lexer provides the Location and Token structs plus the Lexer class and the memory-backed LexerBuffer; ExprAST is the base class from which the various AST nodes derive).

The dependency chain is roughly Lexer -> AST -> Parser.

LLVM's ADT and Support libraries provide utilities used to build the AST:

cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
auto moduleAST = parseInputFile(inputFilename);

inputFilename here is filled in by cl::ParseCommandLineOptions, and inputFilename is defined as static.

The manual's description of cl::ParseCommandLineOptions might as well not exist 🙄

(screenshot: the reference manual entry for cl::ParseCommandLineOptions)

namespace cl = llvm::cl;
static cl::opt<std::string> inputFilename(cl::Positional,
                                          cl::desc("<input toy file>"),
                                          cl::init("-"),
                                          cl::value_desc("filename"));

The output is decided entirely by emitAction (from Ch2 on you can see it become user-selectable); a single dump prints everything, which feels like a lot of black box 😯

static cl::opt<enum Action> emitAction(
    "emit", cl::desc("Select the kind of output desired"),
    cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")));

  switch (emitAction) {
  case Action::DumpAST:
    dump(*moduleAST);
    return 0;
  default:
    llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
  }

The details of AST construction can be inspected by debugging dump() in AST.cpp.

(screenshot: the AST dump output)

Overloading is used to fit the Parser: every node type has a dump(), but each overload prints something different.

void ASTDumper::dump(FunctionAST *node) {
  INDENT();
  llvm::errs() << "Function \n";
  dump(node->getProto());
  dump(node->getBody());
}

LLVM provides the following for AST generation:

llvm::errs() for unbuffered output

llvm::StringRef: a utility class in the LLVM project representing an immutable reference to a string


llvm::ArrayRef provides a lightweight way to access the elements of an array or container without copying them

llvm::any_of, analogous to std::any_of

llvm::isa for type checks

llvm::Twine for efficient string concatenation

llvm::interleaveComma inserts a comma separator between a sequence of items.

Chapter 2

This is where the MLIR Dialect part begins.

In addition to specializing the mlir::Op C++ template, MLIR also supports defining operations in a declarative manner. This is achieved via the Operation Definition Specification framework. Facts regarding an operation are specified concisely into a TableGen record, which will be expanded into an equivalent mlir::Op C++ template

ODS generates the C++ for mlir::Op, so in theory, if you know mlir::Op, you could get by without ODS (this design is presumably meant to ease decoupling later on).

Many of the details of how to write ODS are scattered across hyperlinks throughout the tutorial.
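For orientation, here is an abbreviated sketch of what an op definition in Ops.td looks like; the field names follow the ODS documentation, but the real tutorial definition carries more fields:

```tablegen
def AddOp : Toy_Op<"add"> {
  let summary = "element-wise addition operation";

  // Two unranked f64 tensor operands, one tensor result.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Declares parse()/print() hooks, implemented by hand in Dialect.cpp.
  let hasCustomAssemblyFormat = 1;

  // Declares a convenience build() overload, also implemented by hand.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
  ];
}
```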

And PrintOp::print, which takes up a decent chunk of the chapter's text, isn't hand-written at all 🤣 (it is actually generated into Ops.cpp.inc)

emitAction decides the output: run with -emit=ast and you get the AST; the tutorial asks for -emit=mlir output.

static cl::opt<enum Action> emitAction(
    "emit", cl::desc("Select the kind of output desired"),
    cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
    cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
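Invoking it both ways looks something like this (binary and sample paths follow my extracted repo layout, so adjust as needed):

```shell
# Dump the AST, as in Chapter 1.
./build/Ch2/toyc-ch2 Examples/Toy/Ch2/codegen.toy -emit=ast

# Dump the Toy-dialect MLIR instead.
./build/Ch2/toyc-ch2 Examples/Toy/Ch2/codegen.toy -emit=mlir
```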

/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
      llvm::MemoryBuffer::getFileOrSTDIN(filename);
  if (std::error_code ec = fileOrErr.getError()) {
    llvm::errs() << "Could not open input file: " << ec.message() << "\n";
    return nullptr;
  }
  auto buffer = fileOrErr.get()->getBuffer();
  LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename));
  Parser parser(lexer);
  return parser.parseModule();
}

mlir/Dialect.cpp

The MLIR output is controlled in mlir/Dialect.cpp.

Ops are added by implementing initialize() on the Dialect class that ops.td generated into Dialect.h.inc.

The tutorial calls this "register this operation in the ToyDialect initializer".

void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
}

The classes that ops.td generated into Ops.h.inc are implemented here (although what gets included is Ops.cpp.inc).

Take AddOp as an example: it implements build(), print() and parse().

void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
  state.addOperands({lhs, rhs});
}

mlir::ParseResult AddOp::parse(mlir::OpAsmParser &parser,
                               mlir::OperationState &result) {
  return parseBinaryOp(parser, result);
}

void AddOp::print(mlir::OpAsmPrinter &p) { printBinaryOp(p, *this); }

parse() goes through parseBinaryOp() and print() through printBinaryOp(); the comments refer to them collectively as Toy Operations.

static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  SmallVector<mlir::OpAsmParser::UnresolvedOperand, 2> operands;
  SMLoc operandsLoc = parser.getCurrentLocation();
  Type type;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type))
    return mlir::failure();

  // If the type is a function type, it contains the input and result types of
  // this operation.
  if (FunctionType funcType = llvm::dyn_cast<FunctionType>(type)) {
    if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                               result.operands))
      return mlir::failure();
    result.addTypes(funcType.getResults());
    return mlir::success();
  }

  // Otherwise, the parsed type is the type of both operands and results.
  if (parser.resolveOperands(operands, type, result.operands))
    return mlir::failure();
  result.addTypes(type);
  return mlir::success();
}

This involves mlir::OpAsmParser and mlir::OperationState.

The docs say, respectively:

The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors etc.

It has an intentionally high-level API that is designed to reduce/constrain syntax innovation in individual operations.

For example, consider an op like this:

%x = load %p[%1, %2] : memref<…>

The “%x = load” tokens are already parsed and therefore invisible to the custom op parser. This can be supported by calling parseOperandList to parse the %p, then calling parseOperandList with a SquareDelimiter to parse the indices, then calling parseColonTypeList to parse the result type.

This represents an operation in an abstracted form, suitable for use with the builder APIs.

This object is a large and heavy weight object meant to be used as a temporary object on the stack. It is generally unwise to put this in a collection.

mlir/MLIRGen.cpp

This is the part that turns the AST into MLIR.

It provides mlirGen() for use in toyc.cpp; mlirGen() takes a mlir::MLIRContext and a ModuleAST.

In other words, MLIR generation happens on top of the AST:

int dumpMLIR() {
  mlir::MLIRContext context;
  // Load our Dialect in this MLIR Context.
  context.getOrLoadDialect<mlir::toy::ToyDialect>();

  // Handle '.toy' input to the compiler.
  if (inputType != InputType::MLIR &&
      !llvm::StringRef(inputFilename).ends_with(".mlir")) {
    auto moduleAST = parseInputFile(inputFilename);
    if (!moduleAST)
      return 6;
    mlir::OwningOpRef<mlir::ModuleOp> module = mlirGen(context, *moduleAST);
    if (!module)
      return 1;

    module->dump();
    return 0;
  }

Chapter 3

ToyCombine.td is simply DRR! It optimizes away the nested Transpose calls.

The precondition for using DRR is that the corresponding Op sets hasCanonicalizer = 1.

The corresponding Op then needs its optimization patterns registered (in ToyCombine.cpp); and marking an op Pure lets you drop the side-effect assumption.

void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<SimplifyRedundantTranspose>(context);
}

void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
              FoldConstantReshapeOptPattern>(context);
}

The details are implemented by inheriting from mlir::OpRewritePattern<TransposeOp>:

struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method attempts to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. The pattern is
  /// expected to interact with it to perform any changes to the IR from here.
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();

    // Input defined by another transpose? If not, no match.
    if (!transposeInputOp)
      return failure();

    // Otherwise, we have a redundant transpose. Use the rewriter.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return success();
  }
};
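The effect of this pattern, sketched on a hypothetical snippet (Toy-dialect syntax as printed by the tutorial tools):

```mlir
// Before canonicalization: a transpose of a transpose.
%1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
%2 = toy.transpose(%1 : tensor<3x2xf64>) to tensor<2x3xf64>

// After SimplifyRedundantTranspose: uses of %2 are replaced by %0,
// and both transposes become dead code to be cleaned up.
```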

If you run the program with -opt, dumpMLIR() enters this code path:

  if (enableOpt) {
    mlir::PassManager pm(module.get()->getName());
    // Apply any generic pass manager command line options and run the pipeline.
    if (mlir::failed(mlir::applyPassManagerCLOptions(pm)))
      return 4;

    // Add a run of the canonicalizer to optimize the mlir module.
    pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
    if (mlir::failed(pm.run(*module)))
      return 4;
  }

(From Chapter 3 on, loadMLIR() is split out of dumpMLIR().)

pm.run(*module) automatically picks up the registered passes (the same effect can be had via mlir-opt --pass-pipeline="builtin.module()").

The TableGen below is equivalent to the C++ code above and implements the pattern match:

// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;

For more implementation details you have to read the full DRR documentation.
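ToyCombine.td also demonstrates DRR constraints. A sketch close to the tutorial's RedundantReshapeOptPattern, where a CPred constraint guards the rewrite:

```tablegen
// Fires only when the reshape's input already has the result type.
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;

// Reshape(x) = x, when the types already match.
def RedundantReshapeOptPattern : Pat<
    (ReshapeOp:$res $arg), (replaceWithValue $arg),
    [(TypesAreIdentical $res, $arg)]>;
```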

Chapter 4

Dialect.cpp gets a lot of changes so that inlining works: inherit from DialectInlinerInterface and register it with the dialect.

struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// All call operations within toy can be inlined.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// All operations within toy can be inlined.
  bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
    return true;
  }

  // All functions within toy can be inlined.
  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
    return true;
  }

Ops can take on traits to get something like inheritance? (Is it like Rust's traits?)

def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
    // Typora mangles this code when rendering, so it's omitted here
}

mlir::OpPassManager can be used to run multiple passes scoped to FuncOp:

if (enableOpt) {
    mlir::PassManager pm(module.get()->getName());
    // Apply any generic pass manager command line options and run the pipeline.
    if (mlir::failed(mlir::applyPassManagerCLOptions(pm)))
      return 4;

    // Inline all functions into main and then delete them.
    pm.addPass(mlir::createInlinerPass());

    // Now that there is only one function, we can infer the shapes of each of
    // the operations.
    mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
    optPM.addPass(mlir::toy::createShapeInferencePass());
    optPM.addPass(mlir::createCanonicalizerPass());
    optPM.addPass(mlir::createCSEPass());

    if (mlir::failed(pm.run(*module)))
      return 4;
  }

Some details are still unclear to me; I'll probably have to circle back.

Chapter 5

With optimization enabled, this chapter adds loop fusion and AffineScalarReplacement (scalar replacement?):

if (enableOpt) {
  optPM.addPass(mlir::affine::createLoopFusionPass());
  optPM.addPass(mlir::affine::createAffineScalarReplacementPass());
}

And the most wonderful part: both passes already exist, you just use them 😍

On the left with -opt, the most visible difference is one fewer for loop.

(screenshot: the generated loops with and without -opt)

The lowering itself lives in LowerToAffineLoops.cpp and is explained well in the Tutorial; every lowering is implemented by inheriting ConversionPattern / OpRewritePattern<Type>.

Not sure whether I've understood this right: does writing lowerings fall under the umbrella of DRR?

Each lowering has to be registered in ToyToAffineLoweringPass::runOnOperation().

Only a partial lowering is done here, because toy.print can only be realized at the LLVM Dialect level (we'll see it in Chapter 6).

void ToyToAffineLoweringPass::runOnOperation() {
  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering. In our case, we are lowering to a combination of the
  // `Affine`, `Arith`, `Func`, and `MemRef` dialects.
  target.addLegalDialect<affine::AffineDialect, BuiltinDialect,
                         arith::ArithDialect, func::FuncDialect,
                         memref::MemRefDialect>();

  // We also define the Toy dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted. Given that we actually want
  // a partial lowering, we explicitly mark the Toy operations that don't want
  // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands
  // to be updated though (as we convert from TensorType to MemRefType), so we
  // only treat it as `legal` if its operands are legal.
  target.addIllegalDialect<toy::ToyDialect>();
  target.addDynamicallyLegalOp<toy::PrintOp>([](toy::PrintOp op) {
    return llvm::none_of(op->getOperandTypes(),
                         [](Type type) { return llvm::isa<TensorType>(type); });
  });

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the Toy operations.
  RewritePatternSet patterns(&getContext());
  patterns.add<AddOpLowering, ConstantOpLowering, FuncOpLowering, MulOpLowering,
               PrintOpLowering, ReturnOpLowering, TransposeOpLowering>(
      &getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

Add and Mul here share a single template, BinaryOpLowering, neat 🤔

using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;

The implementation:

template <typename BinaryOp, typename LoweredBinaryOp>
struct BinaryOpLowering : public ConversionPattern {
  BinaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
    lowerOpToLoops(op, operands, rewriter,
                   [loc](OpBuilder &builder, ValueRange memRefOperands,
                         ValueRange loopIvs) {
                     // Generate an adaptor for the remapped operands of the
                     // BinaryOp. This allows for using the nice named accessors
                     // that are generated by the ODS.
                     typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);

                     // Generate loads for the element of 'lhs' and 'rhs' at the
                     // inner loop.
                     auto loadedLhs = builder.create<affine::AffineLoadOp>(
                         loc, binaryAdaptor.getLhs(), loopIvs);
                     auto loadedRhs = builder.create<affine::AffineLoadOp>(
                         loc, binaryAdaptor.getRhs(), loopIvs);

                     // Create the binary operation performed on the loaded
                     // values.
                     return builder.create<LoweredBinaryOp>(loc, loadedLhs,
                                                            loadedRhs);
                   });
    return success();
  }
};

lowerOpToLoops performs the conversion (or rather, the rewrite) into Affine.

It doesn't seem to involve writing any Affine-side dialect either; it's all calls into C++ functions:

static void lowerOpToLoops(Operation *op, ValueRange operands,
                           PatternRewriter &rewriter,
                           LoopIterationFn processIteration) {
  auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
  auto loc = op->getLoc();

  // Insert an allocation and deallocation for the result of this operation.
  auto memRefType = convertTensorToMemRef(tensorType);
  auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

  // Create a nest of affine loops, with one loop per dimension of the shape.
  // The buildAffineLoopNest function takes a callback that is used to construct
  // the body of the innermost loop given a builder, a location and a range of
  // loop induction variables.
  SmallVector<int64_t, 4> lowerBounds(tensorType.getRank(), /*Value=*/0);
  SmallVector<int64_t, 4> steps(tensorType.getRank(), /*Value=*/1);
  affine::buildAffineLoopNest(
      rewriter, loc, lowerBounds, tensorType.getShape(), steps,
      [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
        // Call the processing function with the rewriter, the memref operands,
        // and the loop induction variables. This function will return the value
        // to store at the current index.
        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
        nestedBuilder.create<affine::AffineStoreOp>(loc, valueToStore, alloc,
                                                    ivs);
      });

  // Replace this operation with the generated alloc.
  rewriter.replaceOp(op, alloc);
}

buildAffineLoopNest()

Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only another loop and a terminator.

The loops iterate from “lbs” to “ubs” with “steps”. The body of the innermost loop is populated by calling “bodyBuilderFn” and providing it with an OpBuilder, a Location and a list of loop induction variables.


(The Tutorial doesn't explain what Passes.h is for; to be filled in later.)

Digging through the source, Ops.td contains this snippet (no idea when it crept in; the Tutorial never mentions it 🤐), and Ops.h.inc does indeed contain the generated code:

 // Provide extra utility definitions on the c++ operation class definition.
 let extraClassDeclaration = [{
  bool hasOperand() { return getNumOperands() != 0; }
 }];

Chapter 6

This chapter adds the passes that lower to LLVM:

if (isLoweringToLLVM) {
    // Finish lowering the toy IR to the LLVM dialect.
    pm.addPass(mlir::toy::createLowerToLLVMPass());
    // This is necessary to have line tables emitted and basic
    // debugger working. In the future we will add proper debug information
    // emission directly from our frontend.
    pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass());
}

The details are in LowerToLLVM.cpp, which lowers the PrintOp we previously left alone: it can map straight onto generating a call to LLVM's printf()!!! 😦

It's still intricate in there, though; there's a call chain matchAndRewrite -> getOrInsertPrintf / getOrCreateGlobalString (not sure what the latter is for) that replaces toy.print with printf().

Dialects at the level of arith and Affine already come with their Conversions set up; just add them in runOnOperation, and this time it's a full lowering:

void ToyToLLVMLoweringPass::runOnOperation() {
  LLVMConversionTarget target(getContext());
  target.addLegalOp<ModuleOp>();

  LLVMTypeConverter typeConverter(&getContext());

  RewritePatternSet patterns(&getContext());
  populateAffineToStdConversionPatterns(patterns);
  populateSCFToControlFlowConversionPatterns(patterns);
  mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
  populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns);
  cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns);
  populateFuncToLLVMConversionPatterns(typeConverter, patterns);

  patterns.add<PrintOpLowering>(&getContext());

  auto module = getOperation();
  if (failed(applyFullConversion(module, target, std::move(patterns))))
    signalPassFailure();
}

All that's left is choosing between emitting LLVM IR and running under the JIT.

With the JIT you also need to name the entry function:

  auto invocationResult = engine->invokePacked("main");
  if (invocationResult) {
    llvm::errs() << "JIT invocation failed\n";
    return -1;
  }

The options -emit=mlir, -emit=mlir-affine, -emit=mlir-llvm (the output is still MLIR, but written entirely in the LLVM dialect) and -emit=llvm are provided for testing.

If you pick -emit=llvm, you can save the output as a.ll and compile it with Clang:

clang-18 a.ll -o a

which gives you an executable binary.
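Put end to end, the Chapter 6 flow looks roughly like this (binary and sample paths follow my extracted repo layout; -emit=jit is the tutorial's JIT mode):

```shell
# Emit textual LLVM IR from the Toy source ...
./build/Ch6/toyc-ch6 Examples/Toy/Ch6/codegen.toy -emit=llvm -opt > a.ll
# ... and hand it to clang for codegen and linking,
clang-18 a.ll -o a && ./a

# or skip the .ll file entirely and run under the JIT.
./build/Ch6/toyc-ch6 Examples/Toy/Ch6/codegen.toy -emit=jit -opt
```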

Chapter 7

This chapter is mainly about implementing structs in MLIR, which requires new ODS plus changes to the AST and Parser:

  1. Dialect.h adds the StructType implementation (this spot took me forever to find 😅 I was too annoyed to remember VS Code search exists)
  2. Dialect.cpp adds the StructTypeStorage implementation
  3. Ops.td gains new ODS introducing Toy_StructType
  4. (Underneath, Parser [Parser.h] parsing and AST [AST.h] construction are implemented too; the Toy Tutorial doesn't mention this low-level part)
  5. Dialect.cpp implements parseType and printType
  6. Ops.td modifies ReturnOp so it can handle Toy_StructType
  7. Two new Ops, struct_constant and struct_access, are added (solving struct member access)
  8. Fold-related attributes are added so that struct elements can be recognized by FoldConstantReshapeOptPattern
  9. The Toy_Dialect definition in Ops.td sets hasConstantMaterializer = 1, and Dialect.cpp implements materializeConstant() to produce Toy Dialect constants

The magical part: we never need to write any lowering for Struct 😮 materializeConstant turns it into mlir::ArrayAttr, which is Builtin.

mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder,
                                                 mlir::Attribute value,
                                                 mlir::Type type,
                                                 mlir::Location loc) {
  if (type.isa<StructType>())
    return builder.create<StructConstantOp>(loc, type,
                                            value.cast<mlir::ArrayAttr>());
  return builder.create<ConstantOp>(loc, type,
                                    value.cast<mlir::DenseElementsAttr>());
}

parseType() and printType() are only called with --emit=mlir (under the JIT their bodies could even be deleted entirely! 😯 but since the Dialect makes implementing them mandatory, so be it).

Postscript:

At some point a RecordAST appeared 😅 the Tutorial just adds things without telling you, doesn't it. I haven't fully figured that code out, so I won't paste it here.

StructAST inherits from RecordAST:

class StructAST : public RecordAST {
  Location location;
  std::string name;
  std::vector<std::unique_ptr<VarDeclExprAST>> variables;

public:
  StructAST(Location location, const std::string &name,
            std::vector<std::unique_ptr<VarDeclExprAST>> variables)
      : RecordAST(Record_Struct), location(std::move(location)), name(name),
        variables(std::move(variables)) {}

  const Location &loc() { return location; }
  llvm::StringRef getName() const { return name; }
  llvm::ArrayRef<std::unique_ptr<VarDeclExprAST>> getVariables() {
    return variables;
  }

  /// LLVM style RTTI
  static bool classof(const RecordAST *r) {
    return r->getKind() == Record_Struct;
  }
};

The Parser also needs the corresponding parsing support:

  std::unique_ptr<StructAST> parseStruct() {
    auto loc = lexer.getLastLocation();
    lexer.consume(tok_struct);
    if (lexer.getCurToken() != tok_identifier)
      return parseError<StructAST>("name", "in struct definition");
    std::string name(lexer.getId());
    lexer.consume(tok_identifier);

    // Parse: '{'
    if (lexer.getCurToken() != '{')
      return parseError<StructAST>("{", "in struct definition");
    lexer.consume(Token('{'));

    // Parse: decl+
    std::vector<std::unique_ptr<VarDeclExprAST>> decls;
    do {
      auto decl = parseDeclaration(/*requiresInitializer=*/false);
      if (!decl)
        return nullptr;
      decls.push_back(std::move(decl));

      if (lexer.getCurToken() != ';')
        return parseError<StructAST>(";",
                                     "after variable in struct definition");
      lexer.consume(Token(';'));
    } while (lexer.getCurToken() != '}');

    // Parse: '}'
    lexer.consume(Token('}'));
    return std::make_unique<StructAST>(loc, name, std::move(decls));
  }

  /// Get the precedence of the pending binary operator token.
  int getTokPrecedence() {
    if (!isascii(lexer.getCurToken()))
      return -1;

    // 1 is lowest precedence.
    switch (static_cast<char>(lexer.getCurToken())) {
    case '-':
      return 20;
    case '+':
      return 20;
    case '*':
      return 40;
    case '.':
      return 60;
    default:
      return -1;
    }
  }

  /// Helper function to signal errors while parsing, it takes an argument
  /// indicating the expected token and another argument giving more context.
  /// Location is retrieved from the lexer to enrich the error message.
  template <typename R, typename T, typename U = const char *>
  std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
    auto curToken = lexer.getCurToken();
    llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
                 << lexer.getLastLocation().col << "): expected '" << expected
                 << "' " << context << " but has Token " << curToken;
    if (isprint(curToken))
      llvm::errs() << " '" << (char)curToken << "'";
    llvm::errs() << "\n";
    return nullptr;
  }
};