现状

如果你现在想运行一个MLIR程序,你在搜索引擎上目前能找到的最好的中文资料是这个:

使用MLIR完成一个端到端的编译流程 — 一条通路

这份资料并不怎么让人满意:虽然整个流程看起来并没错,但MLIR更新的速度很快,4年前的东西很可能用不了。而需要跑通这个端到端流程,你还需要了解TensorFlow,这未免太笨重了。

私认为是MLIR的Toy Tutorial用于炫技的产物,虽然在Chapter #6提到了如何JIT或AOT运行,但很多细节依然需要弄清。

而我是在看了MLIR — Lowering through LLVM才意识到一个问题:既然MLIR最后转换成LLVM IR,那理论上MLIR程序的调用方案和LLVM IR程序几乎别无二致——区别只在于MLIR程序需要mlir-opt进行lowering和mlir-translate进行转译

解决方案

关于如何写出一个简单好用的端到端案例,我想了一个晚上,原先我计划在Toy Tutorial上面修改,但Toy Tutorial限制太多(Example 7所有函数与Main内联,非main函数设置为Private属性,有些函数没添加LLVM Lowering)

思来想去,还是直接手搓MLIR吧😜做个简单的加减乘除即可

Note: 文章以Debian Linux发行版为例,LLVM相关指令请按情况修改

获取LLVM IR

ChatGPT目前还不能输出符合标准的MLIR程序,需要在回答的基础上人工进行修改。将下面这部分代码的文件命名为basic.mlir

module {
// 加法函数:返回 a + b
func.func @add(%0: i32, %1: i32) -> i32 {
%c = arith.addi %0, %1 : i32
return %c : i32
}
// 减法函数:返回 a - b
func.func @sub(%0: i32, %1: i32) -> i32 {
%c = arith.subi %0, %1 : i32
return %c : i32
}
// 乘法函数:返回 a * b
func.func @mul(%0: i32, %1: i32) -> i32 {
%c = arith.muli %0, %1 : i32
return %c : i32
}
// 除法函数:返回 a / b(假设b不为0)
func.func @div(%0: i32, %1: i32) -> i32 {
%c = arith.divsi %0, %1 : i32
return %c : i32
}
}

走Pipeline获得LLVM IR,生成.obj文件

mlir-opt-18 basic.mlir -convert-arith-to-llvm -convert-func-to-llvm > lowered.mlir
mlir-translate-18 --mlir-to-llvmir lowered.mlir > output.ll
llc-18 -filetype=obj -relocation-model=pic output.ll -o output.o

llc-18 -filetype=obj -relocation-model=pic output.ll -o output.o等价于下面代码

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/MC/TargetRegistry.h"
using namespace llvm;
int main(int argc, char **argv) {
InitLLVM X(argc, argv);
InitializeNativeTarget();
InitializeNativeTargetAsmParser();
InitializeNativeTargetAsmPrinter();
// 创建LLVM上下文和源管理器
LLVMContext Context;
SMDiagnostic Err;
// 从文件中读取LLVM IR
// std::string InputFilename = argv[1];
std::unique_ptr<Module> TheModule = parseIRFile("input.ll", Err, Context);
if (!TheModule) {
errs() << "Error loading file: " << Err.getMessage() << "\n";
return 1;
}
// 获取目标三元组(Target Triple)
auto TargetTriple = sys::getDefaultTargetTriple();
TheModule->setTargetTriple(TargetTriple);
std::string Error;
auto Target = TargetRegistry::lookupTarget(TargetTriple, Error);
if (!Target) {
errs() << Error;
return 1;
}
// 配置目标机器
auto CPU = "generic";
auto Features = "";
TargetOptions opt;
auto TheTargetMachine = Target->createTargetMachine(
TargetTriple, CPU, Features, opt, Reloc::PIC_);
// 设置模块的数据布局
TheModule->setDataLayout(TheTargetMachine->createDataLayout());
// 打开输出文件
std::string OutputFilename = "output.o";
std::error_code EC;
raw_fd_ostream dest(OutputFilename, EC, sys::fs::OF_None);
if (EC) {
errs() << "Could not open file: " << EC.message();
return 1;
}
// 创建PassManager并生成目标文件
legacy::PassManager pass;
auto FileType = CodeGenFileType::ObjectFile;
if (TheTargetMachine->addPassesToEmitFile(pass, dest, nullptr, FileType)) {
errs() << "TheTargetMachine can't emit a file of this type";
return 1;
}
// 运行PassManager并生成目标文件
pass.run(*TheModule);
dest.flush();
outs() << "Wrote " << OutputFilename << "\n";
return 0;
}

可以给大家看看生成的LLVM IR文件

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
define i32 @add(i32 %0, i32 %1) {
%3 = add i32 %0, %1
ret i32 %3
}
define i32 @sub(i32 %0, i32 %1) {
%3 = sub i32 %0, %1
ret i32 %3
}
define i32 @mul(i32 %0, i32 %1) {
%3 = mul i32 %0, %1
ret i32 %3
}
define i32 @div(i32 %0, i32 %1) {
%3 = sdiv i32 %0, %1
ret i32 %3
}
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}

可以使用objdump查看output.o

AOT运行

写一个简单的main.c与mlir.h进行连结

main.c:

#include<stdio.h>
#include "mlir.h"
int main(){
int a = 2;
int b = 4;
printf("add: %d\n",add(b,a));
printf("sub: %d\n",sub(b,a));
printf("mul: %d\n",mul(b,a));
printf("div: %d\n",div(b,a));
return 0;
}

mlir.h

extern int add(int a,int b);
extern int sub(int a,int b);
extern int mul(int a,int b);
extern int div(int a,int b);

接下来有三种方案可以调用MLIR的程序:

  1. 直接链接目标文件(.obj/.o)
  2. 使用静态库(以Linux平台为例是.a)
  3. 使用动态库(以Linux平台为例是.so)

直接链接目标文件(.obj)

将main.c转成.o后链接即可

clang-18 -c main.c
clang-18 main.o output.o -o main
./main

使用静态库

用LLVM archiver生成静态库

llvm-ar-18 rcs libmylibrary.a output.o
clang-18 main.c -L. -lmylibrary -o main
./main

使用动态库

需要修改下main.c的内容打开动态库

#include <stdio.h>
#include <dlfcn.h> // 包含动态加载库相关的头文件
int main() {
void *handle = dlopen("./libmylibrary.so", RTLD_LAZY);
if (!handle) {
fprintf(stderr, "Error loading library: %s\n", dlerror());
return -1;
}
dlerror();
int (*add)(int, int) = (int (*)(int, int)) dlsym(handle, "add");
int (*sub)(int, int) = (int (*)(int, int)) dlsym(handle, "sub");
int (*mul)(int, int) = (int (*)(int, int)) dlsym(handle, "mul");
int (*div)(int, int) = (int (*)(int, int)) dlsym(handle, "div");
char *error = dlerror();
if (error != NULL) {
fprintf(stderr, "Error finding symbol: %s\n", error);
dlclose(handle);
return -1;
}
int a = 3;
int b = 6;
printf("add: %d\n",add(b,a));
printf("sub: %d\n",sub(b,a));
printf("mul: %d\n",mul(b,a));
printf("div: %d\n",div(b,a));
dlclose(handle);
return 0;
}

将.o转为动态库,链接,然后运行即可

clang-18 -shared -o libmylibrary.so output.o
clang-18 -o main main.c -ldl
./main

JIT运行

使用LLI运行

直接链接运行当然没问题,在此不进行赘述。这里主要演示动态库如何操作

clang-18 -shared -o libmylibrary.so output.o
# clang-18 -S -emit-llvm main.c -o main.ll 也可以
clang-18 -c -emit-llvm main.c -o main.bc
lli-18 -load=./libmylibrary.so main.bc

使用ORC JIT代码运行

ByteCode & ll导入

使用之前生成output.ll将其导入即可,将其命名为jit.cpp

同理导入Bytecode也是可行的,参照代码注释内容

#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/TargetSelect.h"
// #include "llvm/Bitcode/BitcodeReader.h"
using namespace llvm;
using namespace llvm::orc;
ExitOnError ExitOnErr;
int main(int argc, char *argv[]) {
// 初始化LLVM
InitLLVM X(argc, argv);
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
// 创建LLVM上下文
LLVMContext Context;
SMDiagnostic Err;
// 从.ll文件加载LLVM IR模块
std::unique_ptr<Module> M = parseIRFile("output.ll", Err, Context);
if (!M) {
errs() << "Error loading file: " << Err.getMessage() << "\n";
return 1;
}
//从.bc文件加载LLVM IR模块
// ErrorOr<std::unique_ptr<MemoryBuffer>> MBOrErr = MemoryBuffer::getFile("output.bc");
// if (std::error_code EC = MBOrErr.getError()) {
// errs() << "Error reading file: " << EC.message() << "\n";
// return 1;
// }
// Expected<std::unique_ptr<Module>> MOrErr = parseBitcodeFile(MBOrErr.get()->getMemBufferRef(), Context);
// if (!MOrErr) {
// errs() << "Error parsing bitcode: " << toString(MOrErr.takeError()) << "\n";
// return 1;
// }
// std::unique_ptr<Module> M = std::move(MOrErr.get());
// 创建JIT实例
auto J = ExitOnErr(LLJITBuilder().create());
// 将模块添加到JIT
ExitOnErr(J->addIRModule(ThreadSafeModule(std::move(M), std::make_unique<LLVMContext>())));
// 查找并执行函数
auto AddSymbol = ExitOnErr(J->lookup("add"));
auto *Add = AddSymbol.toPtr<int(int, int)>();
auto SubSymbol = ExitOnErr(J->lookup("sub"));
auto *Sub = SubSymbol.toPtr<int(int, int)>();
auto MulSymbol = ExitOnErr(J->lookup("mul"));
auto *Mul = MulSymbol.toPtr<int(int, int)>();
auto DivSymbol = ExitOnErr(J->lookup("div"));
auto *Div = DivSymbol.toPtr<int(int, int)>();
int a = 2;
int b = 4;
outs() << "add: " << Add(b, a) << "\n";
outs() << "sub: " << Sub(b, a) << "\n";
outs() << "mul: " << Mul(b, a) << "\n";
outs() << "div: " << Div(b, a) << "\n";
return 0;
}

编译生成JIT引擎,运行即可得到输出

clang++-18 jit.cpp `llvm-config-18 --cxxflags --ldflags --system-libs --libs core orcjit native` -o jit_example
./jit_example

导入静态库和动态库会比较麻烦,因为ORC JIT自身实现了一套JIT Linker的实现方式,而不是Linux系统默认的ld

既然lli可以运行动态库,那使用动态库理论上就没问题

动态库导入

更新于2024.10.27

由于LLVM迭代很快,在找了很多资料的情况下,终于完成了测试

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
class JITLoader {
public:
JITLoader() {
// 初始化本地目标
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
}
Expected<std::unique_ptr<LLJIT>> createJIT() {
auto Builder = LLJITBuilder();
return Builder.create();
}
Error loadLibrary(LLJIT &JIT, const std::string &LibPath) {
// 加载动态库
std::string ErrMsg;
if (sys::DynamicLibrary::LoadLibraryPermanently(LibPath.c_str(), &ErrMsg)) {
return createStringError(inconvertibleErrorCode(),
"Failed to load library: " + ErrMsg);
}
// 添加动态库到搜索路径
JIT.getMainJITDylib().addGenerator(
cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
JIT.getDataLayout().getGlobalPrefix())));
return Error::success();
}
Expected<JITEvaluatedSymbol> lookupSymbol(LLJIT &JIT, const std::string &Name) {
// 打印正在查找的符号
outs() << "Looking for symbol: " << Name << "\n";
// 查找符号
if (auto Addr = JIT.lookup(Name)) {
return JITEvaluatedSymbol(Addr->getValue(),
JITSymbolFlags::Exported);
}
return createStringError(inconvertibleErrorCode(),
"Symbol not found: " + Name);
}
};
// 函数类型定义
using MathFunc = int(*)(int,int);
// 测试函数
void testMathFunction(LLJIT &JIT, JITLoader &Loader,
const std::string &FuncName,
int a, int b) {
if (auto Symbol = Loader.lookupSymbol(JIT, FuncName)) {
auto Func = (MathFunc)(Symbol->getAddress());
outs() << FuncName << "(" << a << ", " << b << ") = "
<< Func(a, b) << "\n";
} else {
errs() << "Failed to find " << FuncName << ": "
<< toString(Symbol.takeError()) << "\n";
}
}
int main(int argc, char *argv[]) {
// 检查命令行参数
if (argc < 2) {
errs() << "Usage: " << argv[0] << " <path-to-libmath_ops.so>\n";
return 1;
}
JITLoader Loader;
// 创建 JIT 实例
auto JIT = Loader.createJIT();
if (!JIT) {
errs() << "Failed to create JIT: "
<< toString(JIT.takeError()) << "\n";
return 1;
}
// 加载动态库
if (auto Err = Loader.loadLibrary(**JIT, argv[1])) {
errs() << "Failed to load library: "
<< toString(std::move(Err)) << "\n";
return 1;
}
// 打印库信息
outs() << "Successfully loaded library: " << argv[1] << "\n";
// 测试所有数学函数
std::vector<std::string> mathFuncs = {"add", "sub", "mul", "div"};
std::vector<std::pair<int, int>> testCases = {
{10, 5},
{20, 4},
{15, 3}
};
for (const auto &func : mathFuncs) {
outs() << "\nTesting " << func << ":\n";
for (const auto &[a, b] : testCases) {
testMathFunction(**JIT, Loader, func, a, b);
}
}
return 0;
}

启动代码:

clang++-18 dynamic_jit.cpp `llvm-config-18 --cxxflags --ldflags --system-libs --libs core orcjit native` -o jit_example
./jit_example ./libmylibrary.so

Note:写一个能和前面对照的上的代码,可以看出差异还是很大的

#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
using MathFunc = int(*)(int,int);
int main(int argc, char *argv[]) {
llvm::ExitOnError ExitOnErr;
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
auto JIT = ExitOnErr(LLJITBuilder().create());
std::string ErrMsg;
if (sys::DynamicLibrary::LoadLibraryPermanently("./libmylibrary.so", &ErrMsg)) {
outs() << "Failed to load library: " + ErrMsg << "\n";
}
// 添加动态库到搜索路径
JIT->getMainJITDylib().addGenerator(
cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
JIT->getDataLayout().getGlobalPrefix())));
// 查找并执行函数
auto AddSymbol = JITEvaluatedSymbol(JIT->lookup("add")->getValue(), JITSymbolFlags::Exported);
auto Add = (MathFunc)(AddSymbol.getAddress());
auto SubSymbol = JITEvaluatedSymbol(JIT->lookup("sub")->getValue(), JITSymbolFlags::Exported);
auto Sub = (MathFunc)(SubSymbol.getAddress());
auto MulSymbol = JITEvaluatedSymbol(JIT->lookup("mul")->getValue(), JITSymbolFlags::Exported);
auto Mul = (MathFunc)(MulSymbol.getAddress());
auto DivSymbol = JITEvaluatedSymbol(JIT->lookup("div")->getValue(), JITSymbolFlags::Exported);
auto Div = (MathFunc)(DivSymbol.getAddress());
int a = 2;
int b = 4;
outs() << "add: " << Add(b, a) << "\n";
outs() << "sub: " << Sub(b, a) << "\n";
outs() << "mul: " << Mul(b, a) << "\n";
outs() << "div: " << Div(b, a) << "\n";
return 0;
}

与Rust联动

通过FFI调用程序肯定也没问题

使用静态库

修改Cargo.toml,增加下面一行:

[build-dependencies]

并在项目根目录(注意不是/src)下添加build.rs

use std::env;
use std::path::PathBuf;
fn main() {
let src_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR").unwrap()).join("src");
println!("cargo:rustc-link-search=native={}", src_dir.display());
}

将之前的libmylibrary.a放入/src,并修改main.rs

main.rs
#[link(name = "mylibrary", kind = "static")]
extern "C" {
fn add(a: i32, b: i32) -> i32;
fn sub(a: i32, b: i32) -> i32;
fn mul(a: i32, b: i32) -> i32;
fn div(a: i32, b: i32) -> i32;
}
fn main() {
unsafe {
let a = 2;
let b = 4;
println!("add: {}", add(b,a));
println!("sub: {}", sub(b,a));
println!("mul: {}", mul(b,a));
println!("div: {}", div(b,a));
}
}

项目结构目录树如下

├── Cargo.lock
├── Cargo.toml
├── build.rs
├── src
│ ├── libmylibrary.a
│ └── main.rs

直接Cargo run运行即可得到结果

Finished `dev` profile [unoptimized + debuginfo] target(s) in 0.00s
Running `target/debug/test_ffi`
add: 6
sub: 2
mul: 8
div: 2

使用动态库(以Linux为例)

上接使用静态库,在该基础上修改部分内容即可

需要告诉ld动态库在哪里,在Bash里修改环境变量

export LD_LIBRARY_PATH=$(pwd)/src:$LD_LIBRARY_PATH

删除main.ckind = "static"

#[link(name = "mylibrary")]
extern "C" {
fn add(a: i32, b: i32) -> i32;
fn sub(a: i32, b: i32) -> i32;
fn mul(a: i32, b: i32) -> i32;
fn div(a: i32, b: i32) -> i32;
}

将前文的libmylibrary.so放入.src,然后cargo run即可

进阶拓展

MLIR中调用C++ Function

更新于2024.12.29

走完上面步骤其实就不能理解MLIR了:只要MLIR还想要在CPU上运行,就会回到LLVM的逻辑,进而回归类似传统动态库的解决方案

addInteger.cpp

#include <cstdint>
#include <cstdio>
extern "C" {
int32_t addInteger(int32_t a, int32_t b) {
const int32_t result = a + b;
printf("Result:%d\n",result);
return result;
}
}

example.mlir

module {
llvm.func @addInteger(i32, i32) -> i32
func.func @main() -> i32 {
%2 = arith.constant 10 : i32
%3 = arith.constant 20 : i32
%4 = llvm.call @addInteger(%2, %3) : (i32, i32) -> i32
%ret = arith.constant 0 : i32
return %ret : i32
}
}

处理操作的Bash:

clang++-18 -c addInteger.cpp -o addInteger.o
mlir-opt-18 example.mlir -convert-func-to-llvm -convert-scf-to-cf
mlir-translate-18 lower.mlir --mlir-to-llvmir > example.ll
clang++-18 example.ll addInteger.o -o example

结果:

Result:30

对应的MLIRContext构建

int arith_work() {
mlir::MLIRContext context;
// Register dialects
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::LLVM::LLVMDialect>();
mlir::OpBuilder builder(&context);
mlir::OwningOpRef<mlir::ModuleOp> module = mlir::ModuleOp::create(builder.getUnknownLoc());
// Create function returning i32
auto i32Type = builder.getI32Type();
auto addIntegerType = mlir::LLVM::LLVMFunctionType::get(i32Type, {i32Type, i32Type}, false);
auto addInteger = builder.create<mlir::LLVM::LLVMFuncOp>(
builder.getUnknownLoc(),
"addInteger",
addIntegerType
);
auto mainType = builder.getFunctionType({}, {i32Type});
auto mainFunc = builder.create<mlir::func::FuncOp>(
builder.getUnknownLoc(),
"main",
mainType
);
auto entryBlock = mainFunc.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
auto ten = builder.create<mlir::arith::ConstantOp>(
builder.getUnknownLoc(),
builder.getI32IntegerAttr(10)
);
auto twenty = builder.create<mlir::arith::ConstantOp>(
builder.getUnknownLoc(),
builder.getI32IntegerAttr(20)
);
auto callResult = builder.create<mlir::LLVM::CallOp>(
builder.getUnknownLoc(),
i32Type,
"addInteger",
mlir::ValueRange{ten, twenty}
);
auto retVal = builder.create<mlir::arith::ConstantOp>(
builder.getUnknownLoc(),
builder.getI32IntegerAttr(0)
);
builder.create<mlir::func::ReturnOp>(
builder.getUnknownLoc(),
mlir::ValueRange{retVal});
module->push_back(addInteger);
module->push_back(mainFunc);
module->print(llvm::outs());
return 0;
}

结语

大家都习惯于使用MLIR的产物,但是真正理解MLIR全链路端到端流程的人却很少。今天最主要的工作就是把这部分知识缺漏补上😆以方便推进后续的研究进展。

附录

记录下动态库生成可能用上,但实际并没用上的Bash指令

clang++-18 -o jit_example dynamic_jit.cpp `llvm-config-18 --cxxflags --ldflags --system-libs --libs core orcjit native` -fno-rtti
clang-18 -shared -o libexample.so example.o -Wl,--export-dynamic