淦!MLIR输出Hello World不应该这么难!
众所周知,大家学习一门编程语言的时候,应该都是从学习从控制台输出“Hello World”开始
这让培养了我们“输出一个数 比 让数进行计算 要更加容易的”认知
这种认知会在你学习完LLVM/MLIR后改变。实际情况是:IO是一种复杂的计算,单纯的运算要比控制输入输出要更容易上手😉
从直觉上来看, 让仅仅只会”计算”的TRM来支撑一个功能齐全的操作系统的运行还是不太现实的. 这给我们的感觉就是, 计算机也有一定的”功能强弱”之分, 计算机越”强大”, 就能跑越复杂的程序. 换句话说, 程序的运行其实是对计算机的功能有需求的. 在你运行Hello World程序时, 你敲入一条命令(或者点击一下鼠标), 程序就成功运行了, 但这背后其实隐藏着操作系统开发者和库函数开发者的无数汗水.
——Chapter 程序, 运行时环境与AM 《南京大学 计算机科学与技术系 计算机系统基础 课程实验》
但对于MLIR而言,输出“Hello world”确实是一个有些复杂的事情,让我们看看网上的大家是怎么做的
网上样例
MLIR Toy Tutorials
在MLIR官方的样例Toy Tutorials中,你会在Chapter 6中看到 toy.print 的从Toy Dialect下降到LLVM Dialect
/// Return a symbol reference to the printf function, inserting it into the/// module if necessary.static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, ModuleOp module, LLVM::LLVMDialect *llvmDialect) { auto *context = module.getContext(); if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf")) return SymbolRefAttr::get("printf", context);
// Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` auto llvmI32Ty = IntegerType::get(context, 32); auto llvmI8PtrTy = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, /*isVarArg=*/true);
// Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", llvmFnType); return SymbolRefAttr::get("printf", context);}
当时我看到的时候是懵逼的,然后就直接略过了😑(见我之前的笔记)
getOrInsertPrintf
这个函数做的工作是:
先查看Module中有没有先前声明过printf()
,
如果有就直接用(SymbolRefAttr::get("printf", context);
)
如果没有的话,就声明printf()
函数(返回一个32位的值,传入一个指针——即字符串)
然后进行rewrite
,替换toy.print
为llvm.func printf()
有用信息就这些,后面就开始讲Conversion Patterns, Full Lowering, JIT生成
说那么多,你就会发现:你马上要把Toy Tutorial做完了,但这东西连Hello World如何输出都没讲清楚🤣
vector.print
一次偶然的机会,你看到Vector Dialect里面有一个PrintOp,最美妙的事情莫过于:这个Op不仅能输出向量,甚至连字符串也能输出😄
顺着信息,你去Github搜索vector.print str "Hello, World!"
,并很幸运的看到了下面这段代码:
// RUN: mlir-opt %s -test-lower-to-llvm | \// RUN: mlir-cpu-runner -e entry -entry-point-result=void \// RUN: -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils | \// RUN: FileCheck %s
/// This tests printing (multiple) string literals works.
func.func @entry() { // CHECK: Hello, World! vector.print str "Hello, World!\n" // CHECK-NEXT: Bye! vector.print str "Bye!\n" return}
把这段复制下来,并试着按照注释执行了一下
mlir-opt-18 print-str.mlir -test-lower-to-llvm
得到以下内容,看清来不错👍,基本都lower到了 LLVM Dialect:
module { llvm.mlir.global private constant @vector_print_str_0(dense<[66, 121, 101, 33, 10, 10, 0]> : tensor<7xi8>) {addr_space = 0 : i32} : !llvm.array<7 x i8> llvm.func @printString(!llvm.ptr) llvm.mlir.global private constant @vector_print_str(dense<[72, 101, 108, 108, 111, 44, 32, 87, 111, 114, 108, 100, 33, 10, 10, 0]> : tensor<16xi8>) {addr_space = 0 : i32} : !llvm.array<16 x i8> llvm.func @entry() { %0 = llvm.mlir.addressof @vector_print_str : !llvm.ptr llvm.call @printString(%0) : (!llvm.ptr) -> () %1 = llvm.mlir.addressof @vector_print_str_0 : !llvm.ptr llvm.call @printString(%1) : (!llvm.ptr) -> () llvm.return }}
对于%mlir_c_runner_utils,%mlir_runner_utils
是什么并没有什么头绪,但GPT给出了答案
echo "$(llvm-config --prefix)/lib/libmlir_c_runner_utils.so"echo "$(llvm-config --prefix)/lib/libmlir_runner_utils.so"
参照GPT的提示对内容补齐
mlir-opt-18 print-str.mlir -test-lower-to-llvm | \mlir-cpu-runner-18 -e entry -entry-point-result=void \-shared-libs=/usr/lib/llvm-18/lib/libmlir_c_runner_utils.so,/usr/lib/llvm-18/lib/libmlir_runner_utils.so
控制台成功输出了Hello world!😍
但需要指出:这个样例隐藏了太多实现细节,在你自己的项目中,你既不能到处添加vector.print
——没人因为仅仅想要一个输出添加一整个Vector Dialect
,你也不能随意使用mlir-cpu-runner
(虽然他实际上就是LLVM JIT引擎)
样例总结
虽然以上两个案例都有很多地方让人困惑,但仔细想想,其中的关键点也已经显现:
MLIR而本身并不提供输出相关操作——这部分内容是LLVM负责实现的
只要处理好 自定义Dialect 下降到 LLVM Dialect的逻辑,并通过llvm.call
调用外部的printf()
,就可以完成输出!😲
从头实现输出Hello World
精简下 mlir-hello,改一个只有一个Op的Dialect,且这个Op将会输出Hello world
完整代码已上传Github:mlir-hello-world
代码讲解
首先编写有关Diealect的TableGen文件(HelloDialect.td
)
def Hello_Dialect : Dialect { let name = "hello"; let summary = "A hello out-of-tree MLIR dialect."; let description = [{ This dialect is minimal example to implement hello-world kind of sample code for MLIR. }]; let cppNamespace = "::hello";}
定义Op的TableGen文件(HelloOps.td
)
#ifndef HELLO_OPS#define HELLO_OPS
include "HelloDialect.td"include "mlir/Interfaces/SideEffectInterfaces.td"
def WorldOp : Hello_Op<"world", [Pure]> { let summary = "print Hello, World"; let description = [{ The "world" operation prints "Hello, World", and produces no results. }];}
#endif // HELLO_OPS
输出Hello world的关键点在于如何lower到LLVM,这点可看LowerToLLVM.cpp
和Toy tutorials功能一致的getOrInsertPrintf()
,但将函数类型的声明单独独立为getPrintfType()
static mlir::LLVM::LLVMFunctionTypegetPrintfType(mlir::MLIRContext *context) { auto llvmI32Ty = mlir::IntegerType::get(context, 32); auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(context); auto llvmFnType = mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmPtrTy, /*isVarArg=*/true); return llvmFnType;}
static mlir::FlatSymbolRefAttrgetOrInsertPrintf(mlir::PatternRewriter &rewriter, mlir::ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("printf")) { return mlir::SymbolRefAttr::get(context, "printf"); }
mlir::PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "printf", getPrintfType(context)); return mlir::SymbolRefAttr::get(context, "printf");}
在 getOrCreateGlobalString()
里,我们将Hello world文本设为全局的LLVMArray类型
static mlir::Value getOrCreateGlobalString(mlir::Location loc, mlir::OpBuilder &builder, mlir::StringRef name, mlir::StringRef value, mlir::ModuleOp module) { // Create the global at the entry of the module. mlir::LLVM::GlobalOp global; if (!(global = module.lookupSymbol<mlir::LLVM::GlobalOp>(name))) { mlir::OpBuilder::InsertionGuard insertGuard(builder); builder.setInsertionPointToStart(module.getBody()); auto type = mlir::LLVM::LLVMArrayType::get( mlir::IntegerType::get(builder.getContext(), 8), value.size()); global = builder.create<mlir::LLVM::GlobalOp>( loc, type, /*isConstant=*/true, mlir::LLVM::Linkage::Internal, name, builder.getStringAttr(value)); }
// Get the pointer to the first character in the global string. mlir::Value globalPtr = builder.create<mlir::LLVM::AddressOfOp>(loc, global); mlir::Value cst0 = builder.create<mlir::LLVM::ConstantOp>( loc, mlir::IntegerType::get(builder.getContext(), 64), builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<mlir::LLVM::GEPOp>( loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), global.getType(), globalPtr, mlir::ArrayRef<mlir::Value>({cst0, cst0})); }
编写Pass的runOnOperation()
,下面的内容应该理论上应该可以精简(不精简也能run就是了😀)
void HelloToLLVMLoweringPass::runOnOperation() { mlir::LLVMConversionTarget target(getContext()); target.addLegalOp<mlir::ModuleOp>();
mlir::LLVMTypeConverter typeConverter(&getContext()); mlir::RewritePatternSet patterns(&getContext());
populateAffineToStdConversionPatterns(patterns); populateSCFToControlFlowConversionPatterns(patterns); mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateFinalizeMemRefToLLVMConversionPatterns(typeConverter, patterns); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); populateFuncToLLVMConversionPatterns(typeConverter, patterns);
patterns.add<hello::WorldOpLowering>(&getContext());
auto module = getOperation(); if (failed(applyFullConversion(module, target, std::move(patterns)))) { signalPassFailure(); }}
必须要点明是对hello.world
进行改写和下降
patterns.add<hello::WorldOpLowering>(&getContext());
记得要把Pass注册到HelloPasses.h
里头
namespace hello { std::unique_ptr<mlir::Pass> createLowerToLLVMPass();} // namespace hello
运行效果
准备好这样一段MLIR
func.func @main() { "hello.world"() : () -> () return}
如果是使用mlir::OpBuilder构建代码块的话,这段MLIR等价于下面代码:
mlir::OpBuilder builder(&context);mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(builder.getUnknownLoc());// Generate Code Blockauto funcType = builder.getFunctionType({}, {});auto func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "main", funcType);module->push_back(func);auto entryBlock = func.addEntryBlock();builder.setInsertionPointToStart(entryBlock);builder.create<hello::WorldOp>(builder.getUnknownLoc());builder.create<mlir::func::ReturnOp>(builder.getUnknownLoc());你可以在仓库里的
hello_world.cpp
得到验证
./build/bin/helloworld -emit=mlir为什么提及这个部分呢?如果要从解析器开始的话,语法AST树最后都会转成
mlir::OpBuilder
构建的形式,才能进行各类Pass
对于这样一段MLIR,它会首先下降到LLVM Dialect(这条命令在做mlir-opt
的工作)
./build/bin/hello-opt test/hello_world.mlir -emit=mlir-llvm
效果:
module { llvm.mlir.global internal constant @hello_word_string("Hello, World! \0A\00") {addr_space = 0 : i32} llvm.func @printf(!llvm.ptr, ...) -> i32 llvm.func @main() { %0 = llvm.mlir.addressof @hello_word_string : !llvm.ptr %1 = llvm.mlir.constant(0 : index) : i64 %2 = llvm.getelementptr %0[%1, %1] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.array<16 x i8> %3 = llvm.call @printf(%2) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32 llvm.return }}
可以看到,确实生成了lvm.mlir.global internal constant
这一字符串常量,并且声明了printf()
的外部LLVM函数存在
再从LLVM Dialect下降到LLVM IR(这条命令在做mlir-translate
的工作)
./build/bin/hello-opt test/hello_world.mlir -emit=llvm
效果:
; ModuleID = 'LLVMDialectModule'source_filename = "LLVMDialectModule"target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"target triple = "x86_64-pc-linux-gnu"
@str = private unnamed_addr constant [15 x i8] c"Hello, World! \00", align 1
; Function Attrs: nofree nounwinddefine void @main() local_unnamed_addr #0 { %puts = tail call i32 @puts(ptr nonnull dereferenceable(1) @str) ret void}
; Function Attrs: nofree nounwinddeclare noundef i32 @puts(ptr nocapture noundef readonly) local_unnamed_addr #0
attributes #0 = { nofree nounwind }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
到这一步,你已经可以用lli
运行,但我在代码里也添加了相关JIT运行实现
./build/bin/hello-opt test/hello_world.mlir
通过这个方案,你就能得到Hello, World!的文本输出😊
FAQ
Q:为什么不用mlir::ExecutionEngine::create
生成JIT运行环境?(就像 Toy Tutorials Chapter 6 那样)
A:这是个好问题😂原因是我在Toy Tutorials以外的地方从来没有运行成功过这个函数,而且mlir::ExecutionEngine::create
用的也还是LLVM ORC JIT提供的运行环境,那还不如直接使用LLVM ORC JIT,也能提供更加灵活的操作
前人の工作
https://github.com/Lewuathe/mlir-hello
我写的代码之从中精简出来的,这个仓库经常更新,用Github Action保持与LLVM主线匹配,有很多值得学习的地方😋
https://github.com/Yancey1989/mlir-hello-world
我写玩这篇文章后从Google搜到的(和我的仓库名称撞了),使用的流程是从AST开始,看完我这篇后,再去试试也不错😂(3年没更新,可能LLVM18跑不起来)
结语
如果你之前已经阅读完MLIR Toy Tutorial,那么这篇文章将有助于理清MLIR于LLVM的层次关系,以及两者在不同阶段所发挥的作用
对我而言,输出Hello World可能只是个执念罢了😂这内容确实要比我预想的要复杂,但整理的过程当中也有新的收获——这也许就是写Blog的乐趣所在。