PPoPP 2026 MLIR Tutorial Notes
At the beginning of this month I attended the LLVM Workshop at CGO. After finishing my talk on the 2nd floor of the Sydney ICC, I wandered up to the 3rd floor (CC was in session there, along with the PPoPP/HPCA workshops and tutorials) and dropped in on the MLIR Tutorial. I thought it was well done: it walks through implementing a simple Tile abstraction in MLIR. I was not feeling well that day and left early, and only in the past few days have I worked through the tutorial more or less from start to finish.

Project: Groverkss/mlir-tutor
Following the project's README.md is enough to get everything running.
Notes on tutorial-opt
The project's opt tool is built at build/tutorial/tutorial-opt.
When producing output, you can pass --split-input-file so that the different op chunks of a single input file are processed and printed separately.
If you want to run the Python programs, check that the TUTORIAL_OPT environment variable is set correctly.
You can grep the help text to find the relevant passes (the way this is displayed never feels great; worth seeing whether it can be improved):
```shell
tutorial-opt --help | grep "tiny"
```

MLIR and Python Interop
To me this is the most appealing part of the tutorial: it provides a very basic set of bindings between Python and the tutorial's MLIR types and ops.
For example, here is the Ptr binding:
```python
class Ptr:
    """Wrapper for !tiny.ptr SSA values."""

    _value: Value

    @staticmethod
    def _wrap(value: Value) -> "Ptr":
        p = Ptr()
        p._value = value
        return p

    @staticmethod
    def get_type() -> Type:
        """Get the !tiny.ptr type."""
        return Type.parse("!tiny.ptr")

    def load(self, offset: Index, num_elements: int) -> F16Vector:
        """Load vector<Nxf16> from pointer at offset."""
        vec_type = VectorType.get([num_elements], F16Type.get())
        op = Operation.create(
            "tiny.load",
            results=[vec_type],
            operands=[self._value, offset._value],
        )
        return F16Vector._wrap(op.result)

    def store(self, offset: Index, vec: F16Vector) -> None:
        """Store vector<Nxf16> to pointer at offset."""
        Operation.create(
            "tiny.store",
            results=[],
            operands=[vec._value, self._value, offset._value],
        )
```

compile_and_print converts the Python function to MLIR and prints each lowering stage:
```python
def compile_and_print(fn):
    """Compile and print all lowering stages."""
    opt = TutorialOpt()

    with MLIRModule() as m:
        tiny_ir = m.build_func_verified(fn, _get_type_map(), opt)

        print("=== Tiny Dialect ===")
        print(tiny_ir)

        arith_ir = opt.run(tiny_ir, ["tiny-to-arith", "canonicalize", "cse"])
        print("=== After tiny-to-arith ===")
        print(arith_ir)

        llvm_ir = opt.run(tiny_ir, ["tiny-to-arith", "canonicalize", "cse",
                                    "tiny-to-llvm", "convert-to-llvm"])
        print("=== LLVM Dialect ===")
        print(llvm_ir)

        return llvm_ir
```

Likewise, the MLIR vector type is bound to NumPy arrays (a NumPy array becomes an MLIR vector), and judging from the code the arithmetic methods are bound along with it:
```python
class F16Vector:
    """Wrapper for vector<Nxf16> SSA values."""

    _value: Value

    @staticmethod
    def _wrap(value: Value) -> "F16Vector":
        vec = F16Vector()
        vec._value = value
        return vec

    @staticmethod
    def constant(vals: list[float], size: int = None) -> "F16Vector":
        """Create a constant f16 vector via tiny.constant."""
        n = size or len(vals)
        vec_type = VectorType.get([n], F16Type.get())
        data = np.array(vals, dtype=np.float16)
        attr = DenseElementsAttr.get(data, type=vec_type)
        op = Operation.create(
            "tiny.constant",
            results=[vec_type],
            attributes={"value": attr},
        )
        return F16Vector._wrap(op.result)

    def _binop(self, other: "F16Vector", op_name: str) -> "F16Vector":
        op = Operation.create(
            f"tiny.{op_name}",
            results=[self._value.type],
            operands=[self._value, other._value],
        )
        return F16Vector._wrap(op.result)

    def __add__(self, other):
        return self._binop(other, "addf")

    def __sub__(self, other):
        return self._binop(other, "subf")

    def __mul__(self, other):
        return self._binop(other, "mulf")

    def __truediv__(self, other):
        return self._binop(other, "divf")

    def sum(self) -> "F16Vector":
        """Reduce to vector<1xf16> via tiny.sum."""
        result_type = VectorType.get([1], F16Type.get())
        op = Operation.create(
            "tiny.sum",
            results=[result_type],
            operands=[self._value],
        )
        return F16Vector._wrap(op.result)
```

In addition, MLIRModule gets its own Python wrapper to cover the different usage scenarios:
```python
class MLIRModule:
    """Context manager for building MLIR modules with unregistered dialects."""

    def __init__(self):
        self.ctx = None
        self.loc = None
        self.module = None

    def __enter__(self):
        self.ctx = Context()
        self.ctx.allow_unregistered_dialects = True
        self.ctx.__enter__()
        self.loc = Location.unknown()
        self.loc.__enter__()
        return self

    def __exit__(self, *args):
        self.loc.__exit__(*args)
        self.ctx.__exit__(*args)

    def build_func(self, fn: Callable, type_map: dict) -> Module:
        """Build MLIR module from a Python function.

        Args:
            fn: Function to compile (uses type annotations for args)
            type_map: Maps annotation types to (mlir_type, wrapper_class) tuples
        """
        sig = inspect.signature(fn)
        self.module = Module.create()

        with InsertionPoint(self.module.body):
            # Build input types from annotations
            input_types = []
            for param in sig.parameters.values():
                if param.annotation not in type_map:
                    raise ValueError(f"Unsupported type: {param.annotation}")
                mlir_type, _ = type_map[param.annotation]
                input_types.append(mlir_type() if callable(mlir_type) else mlir_type)

            # Create func.func
            func_op = func_d.FuncOp(fn.__name__, (input_types, []))

            with InsertionPoint(func_op.add_entry_block()):
                # Wrap block arguments in DSL types
                args = []
                for i, param in enumerate(sig.parameters.values()):
                    _, wrapper_cls = type_map[param.annotation]
                    args.append(wrapper_cls._wrap(func_op.arguments[i]))

                # Execute user function body
                fn(*args)

                # Add return
                func_d.return_([])

        return self.module

    def build_func_verified(self, fn: Callable, type_map: dict,
                            opt: "TutorialOpt") -> str:
        """Build and verify module, return pretty-printed IR.

        Runs the generated IR through tutorial-opt to verify it's valid and
        get pretty-printed output using the dialect's assembly format.
        """
        self.build_func(fn, type_map)
        raw_ir = str(self.module)
        # Round-trip through tutorial-opt to verify and pretty-print
        return opt.run(raw_ir, [])
```

A few implementation details are worth noting:
| Concept | Description |
|---|---|
| `tutorial-opt` | C++-compiled MLIR tool implementing all the passes (`tiny-to-arith`, `tiny-to-llvm`, etc.) |
| `TutorialOpt` class | Python wrapper that invokes `tutorial-opt` via `subprocess` |
| `run()` method | Builds the command-line arguments, runs the `tutorial-opt` process, and returns the result |
| pass list | Specifies which MLIR transformation passes to run |
| stdin/stdout | IR is fed in on standard input and the transformed result is read from standard output |
That said, pulling in the heavyweight MLIR dependency through a Python package was new to me. It does solve the problem of MLIR's complicated build dependencies, but at the cost of not being able to pin a version: when I tested, the mlir wheel used at the time could no longer be found (see the Issue).
Chapter 1
Passes for a dialect need their own definitions, and they too can be defined in a separate TableGen file:
````tablegen
//===- TinyPasses.td - Tiny dialect passes -----------------*- tablegen -*-===//
//
// Defines passes for the Tiny dialect.
//
// Reference: https://mlir.llvm.org/docs/PassManagement/#tablegen-specification
//
//===----------------------------------------------------------------------===//

#ifndef TINY_PASSES
#define TINY_PASSES

include "mlir/Pass/PassBase.td"

def TinyToArith : Pass<"tiny-to-arith"> {
  let summary = "Lower Tiny arithmetic operations to arith dialect.";
  let description = [{
    This pass lowers Tiny dialect arithmetic operations to equivalent
    operations in the arith dialect. Memory operations (load/store) and the
    ptr type are NOT converted by this pass - use --tiny-to-llvm for that.

    The lowering includes:
    - `tiny.constant` -> `arith.constant`
    - `tiny.addf/subf/mulf/divf` -> `arith.addf/subf/mulf/divf`
    - `tiny.addi/subi/muli/divi` -> `arith.addi/subi/muli/divsi`

    Example:
    ```mlir
    %0 = tiny.addf %a, %b : vector<4xf16>
    // Becomes:
    %0 = arith.addf %a, %b : vector<4xf16>
    ```
  }];

  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::vector::VectorDialect"
  ];
}

def TinyToLLVM : Pass<"tiny-to-llvm"> {
  let summary = "Lower Tiny memory operations and ptr type to LLVM dialect.";
  let description = [{
    This pass lowers Tiny dialect memory operations and the ptr type to
    equivalent operations in the LLVM dialect.

    The lowering includes:
    - `!tiny.ptr` type -> `!llvm.ptr`
    - `tiny.load %ptr, %offset` -> GEP to compute byte address, then `llvm.load`
    - `tiny.store %val, %ptr, %offset` -> GEP to compute byte address, then `llvm.store`

    The offset in tiny.load/store is in f16 elements. The lowering converts
    this to a byte offset by using GEP with f16 element type:
    ```mlir
    %0 = tiny.load %ptr, %offset : vector<4xf16>
    // Becomes:
    %gep = llvm.getelementptr %ptr[%offset] : (!llvm.ptr, i64) -> !llvm.ptr, f16
    %0 = llvm.load %gep : !llvm.ptr -> vector<4xf16>
    ```

    Note: Run --tiny-to-arith first to convert arithmetic operations.
  }];

  let dependentDialects = [
    "mlir::LLVM::LLVMDialect"
  ];
}

#endif // TINY_PASSES
````

The Type and Op definitions are textbook-clean; the CPred calling isPowerOf2_64 validates the parameter:
```tablegen
// Constraint for vector<Nxf16> where N is a power of 2.
// Uses VectorOfRankAndType from CommonTypeConstraints.td (included via OpBase.td)
// with an additional power-of-2 size check.
def Tiny_VectorF16 : Type<
    And<[
      VectorOfRankAndType<[1], [F16]>.predicate,
      CPred<"::llvm::isPowerOf2_64("
            "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0))">
    ]>,
    "vector of f16 with power-of-2 size",
    "::mlir::VectorType">;
```

Type definitions can also be built on top of one another:
```tablegen
// Constraint for vector<1xf16> (result type for sum operation).
// Reuses Tiny_VectorF16 predicate and adds a size=1 constraint.
def Tiny_Vector1F16 : Type<
    And<[
      Tiny_VectorF16.predicate,
      CPred<"::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 1">
    ]>,
    "vector<1xf16>",
    "::mlir::VectorType">;
```

ConstantOp can take either a vector or an index:
````tablegen
def Tiny_ConstantOp : Tiny_Op<"constant",
    [Pure, AllTypesMatch<["value", "result"]>]> {
  let summary = "Creates a constant vector or index value.";
  let description = [{
    The `tiny.constant` operation creates a constant value which can be
    either a vector<Nxf16> or an index.

    Examples:
    ```mlir
    %0 = tiny.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf16>
    %1 = tiny.constant 42 : index
    ```
  }];

  // TypedAttrInterface allows the type to be inferred from the attribute.
  let arguments = (ins TypedAttrInterface:$value);
  let results = (outs AnyTypeOf<[Tiny_VectorF16, Index]>:$result);

  // The attribute itself contains the type, so no need to print it separately.
  let assemblyFormat = "attr-dict $value";
}
````

Among the op traits there is a SameOperandsAndResultType option:
```tablegen
class Tiny_IndexBinaryOp<string mnemonic, list<Trait> traits = []>
    : Tiny_Op<mnemonic, !listconcat([Pure, SameOperandsAndResultType], traits)> {

  let arguments = (ins Index:$lhs, Index:$rhs);
  let results = (outs Index:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict";
}
```

The tiny dialect's arithmetic ops are lowered to the arith and vector dialects:
```cpp
struct SumOpLowering : public OpRewritePattern<SumOp> {
  using OpRewritePattern<SumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SumOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getInput();

    // vector.reduction<add> returns a scalar f16.
    Value scalarSum = vector::ReductionOp::create(
        rewriter, loc, vector::CombiningKind::ADD, input);

    // Broadcast the scalar to vector<1xf16>.
    VectorType resultType = op.getResult().getType();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, resultType, scalarSum);

    return success();
  }
};
```

The tiny dialect's pointer operations are lowered to the LLVM dialect:
```cpp
class TinyToLLVMTypeConverter : public TypeConverter {
public:
  TinyToLLVMTypeConverter() {
    // Identity conversion for all types (fallback).
    addConversion([](Type type) { return type; });

    // Convert tiny.ptr to llvm.ptr.
    addConversion([](PtrType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
  }
};

struct StoreOpToLLVMLowering : public OpConversionPattern<StoreOp> {
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // The adaptor provides the converted operands (ptr is now !llvm.ptr).
    Value value = adaptor.getValue();
    Value ptr = adaptor.getPtr();
    Value offset = adaptor.getOffset();

    // Convert index offset to i64 for GEP.
    Type i64Type = rewriter.getI64Type();
    Value offsetI64 = arith::IndexCastOp::create(rewriter, loc, i64Type, offset);

    // Create GEP with f16 element type to compute the address.
    Type f16Type = rewriter.getF16Type();
    Type llvmPtrType = LLVM::LLVMPointerType::get(getContext());
    Value gep = LLVM::GEPOp::create(rewriter, loc, llvmPtrType, f16Type, ptr,
                                    ValueRange{offsetI64});

    // Store the vector to the computed address.
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, value, gep);
    return success();
  }
};
```

The func dialect is marked dynamically legal:
```cpp
class TinyToLLVMPass : public impl::TinyToLLVMBase<TinyToLLVMPass> {
public:
  void runOnOperation() override {
    // Set up the type converter.
    TinyToLLVMTypeConverter typeConverter;

    // Set up the conversion target.
    ConversionTarget target(getContext());

    // Mark memory operations as illegal.
    target.addIllegalOp<LoadOp, StoreOp>();

    // Mark LLVM and arith dialects as legal.
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addLegalDialect<arith::ArithDialect>();

    // Mark func dialect operations as dynamically legal if their types are
    // converted.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
             typeConverter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return typeConverter.isLegal(op.getOperandTypes());
    });

    // Set up rewrite patterns.
    RewritePatternSet patterns(&getContext());

    // Add conversion patterns that use the type converter.
    patterns.add<LoadOpToLLVMLowering, StoreOpToLLVMLowering>(typeConverter,
                                                              &getContext());

    // Add function signature conversion patterns.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    // Apply the conversion.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
```

Chapter 2
This chapter mainly lowers Tiny to SCF; on the Python side, for loops have to be translated into SCF.
Take accumulate as an example: a Python decorator turns the loop body into tiny_loop.accumulate and tiny_loop.yield ops.
```python
def decorator(body_fn):
    # Get init value types and MLIR values
    init_values = [v._value for v in inits]
    init_types = [v._value.type for v in inits]
    result_types = init_types  # Results match init types

    # Create the accumulate op with one region
    op = Operation.create(
        "tiny_loop.accumulate",
        results=result_types,
        operands=[bound._value, step._value] + init_values,
        regions=1,  # One region for the body
    )

    # Set up the block with arguments: (index, *iter_args)
    region = op.regions[0]
    block_arg_types = [IndexType.get()] + init_types
    block = Block.create_at_start(region, block_arg_types)

    # Execute body with wrapped arguments
    with InsertionPoint(block):
        iv = Index._wrap(block.arguments[0])
        iter_args = [_wrap_value(block.arguments[i + 1], inits[i])
                     for i in range(len(inits))]

        # Call user's body function
        if inits:
            results = body_fn(iv, *iter_args)
            if not isinstance(results, (list, tuple)):
                results = [results]
            yield_values = [r._value for r in results]
        else:
            body_fn(iv)
            yield_values = []

        # Create tiny_loop.yield
        Operation.create("tiny_loop.yield", operands=yield_values)

    # Wrap and return results
    if result_types:
        return [_wrap_value(op.results[i], inits[i])
                for i in range(len(result_types))]
    return None

return decorator
```

matmul.py then demonstrates matrix multiplication.
It is a vectorized matrix-multiplication implementation, showing how to use the Ch2 loop construct (accumulate) to write a high-performance matmul. Algorithm: C[M,N] = A[M,K] * B[K,N]^T
Key points:
- B is transposed so that both A and B are contiguous along the K dimension (convenient for vectorized loads)
- The vector size is 16 (16 f16 elements loaded at a time)
- Three nested loops: over M, N, and K
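As a rough reference for that loop structure, here is a plain-Python sketch of the loop nest (my own illustration, not the tutorial's matmul.py; the names and the `VEC` constant are assumptions):

```python
VEC = 16  # hypothetical vector width: elements per tiny.load

def matmul_ref(A, B_T):
    """C[m][n] = sum_k A[m][k] * B_T[n][k], with B pre-transposed.

    Plain-Python sketch of the M/N/K loop nest that matmul.py builds out of
    accumulate loops; purely illustrative."""
    M, K, N = len(A), len(A[0]), len(B_T)
    assert K % VEC == 0
    C = [[0.0] * N for _ in range(M)]
    for m in range(M):                    # M loop
        for n in range(N):                # N loop
            acc = [0.0] * VEC             # per-lane accumulator (a vector<16xf16>)
            for k in range(0, K, VEC):    # K loop, stepping by the vector width
                # Both rows are contiguous along K, so these are unit-stride loads.
                for lane in range(VEC):
                    acc[lane] += A[m][k + lane] * B_T[n][k + lane]
            C[m][n] = sum(acc)            # horizontal reduction, like tiny.sum
    return C
```

The transposed-B layout is what makes both operands unit-stride in the innermost loop.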
Ch2 has a separate tiny_loop dialect that implements Accumulate (analogous to scf.for) and Yield (analogous to scf.yield).
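Semantically, an accumulate loop with iter_args behaves like this plain-Python reduction (a sketch of the semantics only, not the real binding):

```python
def accumulate_semantics(bound, step, inits, body_fn):
    """What tiny_loop.accumulate / tiny_loop.yield compute, in plain Python.

    body_fn(iv, *iter_args) returns the values yielded to the next iteration;
    the final iter_args become the op's results (same shape as scf.for)."""
    iter_args = list(inits)
    iv = 0
    while iv < bound:
        results = body_fn(iv, *iter_args)
        if results is None:
            results = []
        elif not isinstance(results, (list, tuple)):
            results = [results]
        iter_args = list(results)  # the yield feeds the next iteration
        iv += step
    return iter_args
```

For example, summing the induction variable from 0 to 9 with one iter_arg yields 45.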
Chapter 3
Chapter 3 implements a tile abstraction in the style of Triton and TileIR.
It corresponds to the tiny_tile dialect; lowering uses the GPU dialect (output only, nothing is actually executed).
For the Tile type, assemblyFormat cannot express the required syntax, so custom parse and print methods are needed:
```cpp
Type TileType::parse(AsmParser &parser) {
  if (parser.parseLess())
    return Type();

  // Parse "HxW" as a dimension list. MLIR's lexer treats "64x128" as a single
  // dimension list token, so we must use parseDimensionList.
  SmallVector<int64_t, 2> dims;
  if (parser.parseDimensionList(dims, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return Type();

  // We expect exactly 2 dimensions for a 2D tile.
  if (dims.size() != 2) {
    parser.emitError(parser.getCurrentLocation())
        << "expected 2 dimensions for tile, got " << dims.size();
    return Type();
  }

  // Parse required comma and layout attribute.
  if (parser.parseComma())
    return Type();

  LayoutAttr layout;
  if (parser.parseAttribute(layout))
    return Type();

  if (parser.parseGreater())
    return Type();

  return TileType::get(parser.getContext(), dims[0], dims[1], layout);
}

/// Print a tile type: `<` HxW `,` layout `>`
void TileType::print(AsmPrinter &printer) const {
  printer << "<" << getHeight() << "x" << getWidth() << ", " << getLayout()
          << ">";
}
```

For tiny_tile.splat, laying out the matrix also lays out the threads:
```mlir
#tiny_tile.layout<thread = [1, 32], vector_size = 8>
```

The layout's parameter definition shows that the parameters are plain int64_t values (in MLIR this is modeled as an attribute):
```tablegen
let parameters = (ins ArrayRefParameter<"int64_t">:$thread,
                      "int64_t":$vectorSize);
```

The tile size is computed jointly from thread and vector_size.
Example 1: thread = [4, 16], vector_size = 4
Thread grid: 4 rows × 16 columns = 64 threads. Tile size: 4 × (16 × 4) = 4 × 64, i.e. 256 elements.
(A 2D thread layout maps well onto the warp structure of a GPU.)
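The arithmetic above can be captured in a small helper (my own sketch; the real computation lives in the dialect's C++):

```python
def tile_shape(thread, vector_size):
    """Tile covered by one thread grid.

    Each of the thread[0] x thread[1] threads owns `vector_size` contiguous
    elements along X, so the tile is thread[0] rows by thread[1]*vector_size
    columns. Returns (height, width, total_elements)."""
    height = thread[0]
    width = thread[1] * vector_size
    return height, width, height * width
```

For thread = [4, 16] and vector_size = 4 this gives a 4×64 tile of 256 elements, matching Example 1.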
tiny_tile can be lowered to tiny and tiny_loop; the thread-related parts are lowered to the GPU dialect.
The lowerings should produce:
- LoadOp -> gpu.thread_id + tiny.load (compute per-thread offset from layout)
- StoreOp -> gpu.thread_id + tiny.store (compute per-thread offset from layout)
- SumOp -> tiny.sum + vector.extract + gpu.subgroup_reduce + vector.broadcast
The reason threads are modeled in two dimensions:
| Concept | Description |
|---|---|
| thread[0] | Number of threads in Y (vertical) = range of thread_y |
| thread[1] | Number of threads in X (horizontal) = range of thread_x |
| vector_size | Vector width handled by each thread |
| Actual tile size | thread[0] × (thread[1] × vector_size) |
| Mapping order | Row-major: fill X (columns) first, then advance Y (rows) |
| Benefit | Exploits GPU hardware characteristics for better cache locality and execution efficiency |
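One plausible reading of that row-major mapping, assuming linear thread ids (a sketch only; the actual offset math is in the convertToSIMT implementations):

```python
def thread_offset(tid, thread, vector_size, stride):
    """Map a linear thread id to its starting element offset in the tile.

    thread = [ty_count, tx_count]; X is filled first, so tid advances along
    columns before rows. Each thread owns `vector_size` contiguous elements,
    and `stride` is the row stride of the underlying buffer in elements.
    All names here are hypothetical."""
    ty = tid // thread[1]          # row index: advances only after X is filled
    tx = tid % thread[1]           # column index within the row
    return ty * stride + tx * vector_size

# With thread = [4, 16], vector_size = 4 and a row stride of 64 elements,
# thread 0 starts at element 0, thread 1 at element 4, and thread 16 starts
# the second row.
```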
Chapter 3's gpu_dot_product.py uses tiny_tile for the multiplication.
It is a GPU-parallel dot-product implementation: the tile-based DSL runs an SPMD (Single Program Multiple Data) computation across multiple GPU blocks.
It was the first time I saw an op whose textual form contains a space (the extra word is actually passed as an EnumAttr):
- tiny_tile.elementwise add -> tiny.addf
- tiny_tile.elementwise sub -> tiny.subf
- tiny_tile.elementwise mul -> tiny.mulf
- tiny_tile.elementwise div -> tiny.divf
```tablegen
def TinyTile_EWKind_Add : I32EnumAttrCase<"add", 0>;
def TinyTile_EWKind_Sub : I32EnumAttrCase<"sub", 1>;
def TinyTile_EWKind_Mul : I32EnumAttrCase<"mul", 2>;
def TinyTile_EWKind_Div : I32EnumAttrCase<"div", 3>;

def TinyTile_ElementwiseKind : I32EnumAttr<
    "ElementwiseKind", "Elementwise operation kind",
    [TinyTile_EWKind_Add, TinyTile_EWKind_Sub,
     TinyTile_EWKind_Mul, TinyTile_EWKind_Div]> {
  let cppNamespace = "::mlir::tiny_tile";
  let genSpecializedAttr = 0;
}

def TinyTile_ElementwiseKindAttr
    : EnumAttr<TinyTile_Dialect, TinyTile_ElementwiseKind, "ew_kind">;
```

getKind() decides which lowering path to take:
```cpp
LogicalResult ElementwiseOp::convertToSIMT(RewriterBase &rewriter,
                                           ValueRange simtOperands) {
  Value lhs = simtOperands[0];
  Value rhs = simtOperands[1];

  switch (getKind()) {
  case ElementwiseKind::add:
    rewriter.replaceOpWithNewOp<tiny::AddFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::sub:
    rewriter.replaceOpWithNewOp<tiny::SubFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::mul:
    rewriter.replaceOpWithNewOp<tiny::MulFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::div:
    rewriter.replaceOpWithNewOp<tiny::DivFOp>(*this, lhs, rhs);
    break;
  }

  return success();
}
```

tutorial/ch3-gpu-tile-dsl/TinyTileDialect.cpp implements convertToSIMT for many op types; the comments in that file are worth reading closely.
| Op | What convertToSIMT does | Input (tile) | Output (vector) |
|---|---|---|---|
| SplatOp | Creates a per-thread constant vector | scalar value | tiny.constant dense<…> : vector |
| LoadOp | Computes the per-thread address offset | ptr, row, col, stride | %thread_id + address math → tiny.load |
| StoreOp | Computes the per-thread address offset | value, ptr, row, col, stride | %thread_id + address math → tiny.store |
| SumOp | Two-level reduction (local + cross-thread) | vector | tiny.sum + extract + subgroup_reduce + broadcast |
The Challenge Exercise is to implement a matmul with tiny_tile.
Since I am not very familiar with the GPU dialect, I do not feel I fully understood the conversions in this chapter.
Chapter 4
This chapter is mainly about the Linalg dialect and the Transform dialect, used to implement operator fusion over tiles.
The docs describe the Transform dialect as "Halide like", a framing I had not heard before (press an AI on it and you will find that Halide is an ancestor of both TVM and MLIR).
Halide is very interesting because it separates the algorithm from the schedule.
| Category | Part of the "algorithm"? | What is it? |
|---|---|---|
| Add/sub/mul/div, math functions (sin, exp) | Yes | Compute logic |
| Dependencies between pixels (e.g. a mean blur) | Yes | Compute logic |
| Loop order (row-by-row vs tiled) | No | Schedule |
| Whether to use multithreading (Parallel) | No | Schedule |
| Whether to vectorize (Vectorize) | No | Schedule |
| Size and placement of temporary buffers | No | Schedule |

There are three core reasons for separating algorithm from schedule:
- Mathematical correctness: the algorithm part is easy to verify. As long as the formula is right, the result should in principle be identical no matter how the schedule changes (tiling, parallelism).
- Modularity: one algorithm can have many schedules. For example, one schedule for a phone CPU and another for a high-end GPU, without changing a single line of the algorithm code.
- Operator fusion: because the algorithm is a pure functional description, the compiler can see the dependencies between stages clearly and decide on its own whether to fuse two steps, without the programmer manually restructuring loops.
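The separation can be illustrated with a toy sketch (pure Python, my own construction, nothing to do with Halide's actual API): the same "algorithm" function runs under two different "schedules" and produces identical results:

```python
def blur_alg(x, img, W):
    """Algorithm: 1D three-point mean blur with clamped borders."""
    lo, hi = max(x - 1, 0), min(x + 1, W - 1)
    return (img[lo] + img[x] + img[hi]) / 3.0

def schedule_naive(img):
    """Schedule 1: simple left-to-right loop."""
    W = len(img)
    return [blur_alg(x, img, W) for x in range(W)]

def schedule_tiled(img, tile=4):
    """Schedule 2: iterate in tiles; only the loop structure changes."""
    W = len(img)
    out = [0.0] * W
    for t in range(0, W, tile):                # scheduling decision: tiling
        for x in range(t, min(t + tile, W)):
            out[x] = blur_alg(x, img, W)
    return out
```

Both schedules evaluate exactly the same per-element expression, so swapping schedules cannot change the result, which is the point of the separation.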
This part is essentially a simplified version of the official MLIR tutorial; if it does not click, go read the Transform Dialect material on the MLIR website.
