At the beginning of this month I attended the LLVM Workshop at CGO. After giving my talk on the 2nd floor of the Sydney ICC, I wandered up to the 3rd floor (which was hosting CC, plus the PPoPP/HPCA workshops and tutorials) and happened to catch the MLIR Tutorial. I thought it was well done — it walks through implementing a simple tile abstraction in MLIR — but I wasn't feeling well that day and left early, and it was only in the past few days that I worked through the tutorial more or less end to end.


Project repository: Groverkss/mlir-tutor

Following the setup steps in the project's README.md gets everything up and running.

Notes on tutorial-opt

The project's opt tool lives at build/tutorial/tutorial-opt.

When producing output, you can pass --split-input-file so that the different snippets in a single input file are processed and printed separately.

If you want to run the Python programs, make sure the TUTORIAL_OPT environment variable is set correctly.

You can grep the help output to find the relevant passes (this display has never felt great to me — worth seeing whether it can be improved):

tutorial-opt --help | grep "tiny"

Bridging MLIR and Python

To me this is the most appealing part of the tutorial: it provides a very basic set of Python bindings for MLIR types and ops.

For example, here is the Ptr binding:

class Ptr:
    """Wrapper for !tiny.ptr SSA values."""

    _value: Value

    @staticmethod
    def _wrap(value: Value) -> "Ptr":
        p = Ptr()
        p._value = value
        return p

    @staticmethod
    def get_type() -> Type:
        """Get the !tiny.ptr type."""
        return Type.parse("!tiny.ptr")

    def load(self, offset: Index, num_elements: int) -> F16Vector:
        """Load vector<Nxf16> from pointer at offset."""
        vec_type = VectorType.get([num_elements], F16Type.get())
        op = Operation.create(
            "tiny.load",
            results=[vec_type],
            operands=[self._value, offset._value],
        )
        return F16Vector._wrap(op.result)

    def store(self, offset: Index, vec: F16Vector) -> None:
        """Store vector<Nxf16> to pointer at offset."""
        Operation.create(
            "tiny.store",
            results=[],
            operands=[vec._value, self._value, offset._value],
        )

compile_and_print converts a Python function to MLIR and prints the IR at each lowering stage:

def compile_and_print(fn):
    """Compile and print all lowering stages."""
    opt = TutorialOpt()
    with MLIRModule() as m:
        tiny_ir = m.build_func_verified(fn, _get_type_map(), opt)
        print("=== Tiny Dialect ===")
        print(tiny_ir)
        arith_ir = opt.run(tiny_ir, ["tiny-to-arith", "canonicalize", "cse"])
        print("=== After tiny-to-arith ===")
        print(arith_ir)
        llvm_ir = opt.run(tiny_ir, ["tiny-to-arith", "canonicalize", "cse",
                                    "tiny-to-llvm", "convert-to-llvm"])
        print("=== LLVM Dialect ===")
        print(llvm_ir)
        return llvm_ir

Similarly, the MLIR vector type is bound to NumPy arrays (a NumPy array is converted into an MLIR vector), and as the code shows, the arithmetic methods are bound along with it:

class F16Vector:
    """Wrapper for vector<Nxf16> SSA values."""

    _value: Value

    @staticmethod
    def _wrap(value: Value) -> "F16Vector":
        vec = F16Vector()
        vec._value = value
        return vec

    @staticmethod
    def constant(vals: list[float], size: int = None) -> "F16Vector":
        """Create a constant f16 vector via tiny.constant."""
        n = size or len(vals)
        vec_type = VectorType.get([n], F16Type.get())
        data = np.array(vals, dtype=np.float16)
        attr = DenseElementsAttr.get(data, type=vec_type)
        op = Operation.create(
            "tiny.constant",
            results=[vec_type],
            attributes={"value": attr},
        )
        return F16Vector._wrap(op.result)

    def _binop(self, other: "F16Vector", op_name: str) -> "F16Vector":
        op = Operation.create(
            f"tiny.{op_name}",
            results=[self._value.type],
            operands=[self._value, other._value],
        )
        return F16Vector._wrap(op.result)

    def __add__(self, other): return self._binop(other, "addf")
    def __sub__(self, other): return self._binop(other, "subf")
    def __mul__(self, other): return self._binop(other, "mulf")
    def __truediv__(self, other): return self._binop(other, "divf")

    def sum(self) -> "F16Vector":
        """Reduce to vector<1xf16> via tiny.sum."""
        result_type = VectorType.get([1], F16Type.get())
        op = Operation.create(
            "tiny.sum",
            results=[result_type],
            operands=[self._value],
        )
        return F16Vector._wrap(op.result)
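The pattern here — wrap an SSA value, overload Python's operators, and emit one op per operator — is easy to see in isolation. Below is a minimal MLIR-free sketch of the same idea; TraceValue and IR_TRACE are hypothetical names of mine, not part of the tutorial.

```python
# A minimal, MLIR-free sketch of the tutorial's wrapper pattern:
# each Python operator emits one "op" into a growing IR trace.
# TraceValue / IR_TRACE are hypothetical names, not from the tutorial.

IR_TRACE: list[str] = []
_counter = [0]

def _fresh() -> str:
    _counter[0] += 1
    return f"%{_counter[0]}"

class TraceValue:
    def __init__(self, name: str):
        self.name = name

    def _binop(self, other: "TraceValue", op_name: str) -> "TraceValue":
        result = TraceValue(_fresh())
        IR_TRACE.append(f"{result.name} = tiny.{op_name} {self.name}, {other.name}")
        return result

    def __add__(self, other): return self._binop(other, "addf")
    def __mul__(self, other): return self._binop(other, "mulf")

a, b = TraceValue("%a"), TraceValue("%b")
c = a * b + b  # records tiny.mulf, then tiny.addf
print("\n".join(IR_TRACE))
```

Ordinary Python evaluation order does the scheduling: by the time `a * b + b` finishes evaluating, the trace holds the ops in SSA order.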

On top of that, MLIRModule gets its own Python wrapper to suit the different usage scenarios:

class MLIRModule:
    """Context manager for building MLIR modules with unregistered dialects."""

    def __init__(self):
        self.ctx = None
        self.loc = None
        self.module = None

    def __enter__(self):
        self.ctx = Context()
        self.ctx.allow_unregistered_dialects = True
        self.ctx.__enter__()
        self.loc = Location.unknown()
        self.loc.__enter__()
        return self

    def __exit__(self, *args):
        self.loc.__exit__(*args)
        self.ctx.__exit__(*args)

    def build_func(self, fn: Callable, type_map: dict) -> Module:
        """Build MLIR module from a Python function.

        Args:
            fn: Function to compile (uses type annotations for args)
            type_map: Maps annotation types to (mlir_type, wrapper_class) tuples
        """
        sig = inspect.signature(fn)
        self.module = Module.create()
        with InsertionPoint(self.module.body):
            # Build input types from annotations
            input_types = []
            for param in sig.parameters.values():
                if param.annotation not in type_map:
                    raise ValueError(f"Unsupported type: {param.annotation}")
                mlir_type, _ = type_map[param.annotation]
                input_types.append(mlir_type() if callable(mlir_type) else mlir_type)
            # Create func.func
            func_op = func_d.FuncOp(fn.__name__, (input_types, []))
            with InsertionPoint(func_op.add_entry_block()):
                # Wrap block arguments in DSL types
                args = []
                for i, param in enumerate(sig.parameters.values()):
                    _, wrapper_cls = type_map[param.annotation]
                    args.append(wrapper_cls._wrap(func_op.arguments[i]))
                # Execute user function body
                fn(*args)
                # Add return
                func_d.return_([])
        return self.module

    def build_func_verified(self, fn: Callable, type_map: dict, opt: "TutorialOpt") -> str:
        """Build and verify module, return pretty-printed IR.

        Runs the generated IR through tutorial-opt to verify it's valid
        and get pretty-printed output using the dialect's assembly format.
        """
        self.build_func(fn, type_map)
        raw_ir = str(self.module)
        # Round-trip through tutorial-opt to verify and pretty-print
        return opt.run(raw_ir, [])

There are a few more implementation details worth noting:

  • tutorial-opt: the C++-built MLIR tool, which implements all the passes (tiny-to-arith, tiny-to-llvm, etc.)
  • TutorialOpt: a Python wrapper that invokes tutorial-opt through subprocess
  • run(): builds the command-line arguments, executes the tutorial-opt process, and returns the result
  • pass list: specifies which MLIR transformation passes to run
  • stdin/stdout: the IR is fed in via standard input and the transformed result is read from standard output
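A TutorialOpt-style wrapper can be sketched in a few lines. This is my own simplified model (SubprocessOpt is a hypothetical name, and a Python echo process stands in for the real tutorial-opt binary); it shows only the argv-building and the stdin/stdout plumbing.

```python
import subprocess
import sys

# Sketch of driving an external opt tool: build argv from the pass list,
# pipe IR through stdin/stdout. SubprocessOpt is a hypothetical name;
# the real tutorial wraps the tutorial-opt binary instead.

class SubprocessOpt:
    def __init__(self, binary: list[str]):
        self.binary = binary  # e.g. ["tutorial-opt"] in the real setup

    def run(self, ir: str, passes: list[str]) -> str:
        # Passes become flags like --tiny-to-arith on the command line.
        args = self.binary + [f"--{p}" for p in passes]
        proc = subprocess.run(args, input=ir, capture_output=True,
                              text=True, check=True)
        return proc.stdout

# Stand-in "opt" that just echoes stdin, so the sketch runs anywhere.
echo = [sys.executable, "-c", "import sys; sys.stdout.write(sys.stdin.read())"]
opt = SubprocessOpt(echo)
print(opt.run("func.func @f() { return }", []))
```

With `check=True`, a failing verifier in the child process surfaces as a Python exception, which is how IR errors propagate back to the DSL user.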

Incidentally, pulling in the complex MLIR dependency as a Python package was a first for me. The approach genuinely solves the MLIR-dependency problem, but at the cost of not being able to pin versions — by the time I tested, the mlir wheel for the original version was already gone (see the Issue).

Chapter 1

For the dialect, the passes need to be declared, and they can be defined in their own TableGen file:

//===- TinyPasses.td - Tiny dialect passes -----------------*- tablegen -*-===//
//
// Defines passes for the Tiny dialect.
//
// Reference: https://mlir.llvm.org/docs/PassManagement/#tablegen-specification
//
//===----------------------------------------------------------------------===//

#ifndef TINY_PASSES
#define TINY_PASSES

include "mlir/Pass/PassBase.td"

def TinyToArith : Pass<"tiny-to-arith"> {
  let summary = "Lower Tiny arithmetic operations to arith dialect.";
  let description = [{
    This pass lowers Tiny dialect arithmetic operations to equivalent operations
    in the arith dialect. Memory operations (load/store) and the ptr type are
    NOT converted by this pass - use --tiny-to-llvm for that.

    The lowering includes:
    - `tiny.constant` -> `arith.constant`
    - `tiny.addf/subf/mulf/divf` -> `arith.addf/subf/mulf/divf`
    - `tiny.addi/subi/muli/divi` -> `arith.addi/subi/muli/divsi`

    Example:
    ```mlir
    %0 = tiny.addf %a, %b : vector<4xf16>
    // Becomes:
    %0 = arith.addf %a, %b : vector<4xf16>
    ```
  }];
  let dependentDialects = [
    "mlir::arith::ArithDialect",
    "mlir::vector::VectorDialect"
  ];
}

def TinyToLLVM : Pass<"tiny-to-llvm"> {
  let summary = "Lower Tiny memory operations and ptr type to LLVM dialect.";
  let description = [{
    This pass lowers Tiny dialect memory operations and the ptr type to
    equivalent operations in the LLVM dialect.

    The lowering includes:
    - `!tiny.ptr` type -> `!llvm.ptr`
    - `tiny.load %ptr, %offset` -> GEP to compute byte address, then `llvm.load`
    - `tiny.store %val, %ptr, %offset` -> GEP to compute byte address, then `llvm.store`

    The offset in tiny.load/store is in f16 elements. The lowering converts this
    to a byte offset by using GEP with f16 element type:
    ```mlir
    %0 = tiny.load %ptr, %offset : vector<4xf16>
    // Becomes:
    %gep = llvm.getelementptr %ptr[%offset] : (!llvm.ptr, i64) -> !llvm.ptr, f16
    %0 = llvm.load %gep : !llvm.ptr -> vector<4xf16>
    ```

    Note: Run --tiny-to-arith first to convert arithmetic operations.
  }];
  let dependentDialects = [
    "mlir::LLVM::LLVMDialect"
  ];
}

#endif // TINY_PASSES

The type and op definitions are textbook-clean; the CPred built on isPowerOf2_64 validates the vector size:

// Constraint for vector<Nxf16> where N is a power of 2.
// Uses VectorOfRankAndType from CommonTypeConstraints.td (included via OpBase.td)
// with an additional power-of-2 size check.
def Tiny_VectorF16 : Type<
  And<[
    VectorOfRankAndType<[1], [F16]>.predicate,
    CPred<"::llvm::isPowerOf2_64("
          "::llvm::cast<::mlir::VectorType>($_self).getDimSize(0))">
  ]>,
  "vector of f16 with power-of-2 size",
  "::mlir::VectorType"
>;
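The predicate inside the CPred reduces to a one-line check: llvm::isPowerOf2_64(n) behaves like the classic bit trick, sketched here in Python.

```python
# What the CPred above checks, as plain Python: llvm::isPowerOf2_64(n)
# is equivalent to "n has exactly one set bit".

def is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

# The Tiny_VectorF16 constraint would accept these vector sizes...
assert all(is_power_of_2(n) for n in (1, 2, 4, 8, 16, 128))
# ...and reject these.
assert not any(is_power_of_2(n) for n in (0, 3, 6, 12, 100))
```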

Type definitions can also build on one another:

// Constraint for vector<1xf16> (result type for sum operation).
// Reuses Tiny_VectorF16 predicate and adds a size=1 constraint.
def Tiny_Vector1F16 : Type<
  And<[
    Tiny_VectorF16.predicate,
    CPred<"::llvm::cast<::mlir::VectorType>($_self).getDimSize(0) == 1">
  ]>,
  "vector<1xf16>",
  "::mlir::VectorType"
>;

ConstantOp accepts either a vector or an index:

def Tiny_ConstantOp : Tiny_Op<"constant", [Pure,
    AllTypesMatch<["value", "result"]>]> {
  let summary = "Creates a constant vector or index value.";
  let description = [{
    The `tiny.constant` operation creates a constant value which can be either
    a vector<Nxf16> or an index.

    Examples:
    ```mlir
    %0 = tiny.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf16>
    %1 = tiny.constant 42 : index
    ```
  }];

  // TypedAttrInterface allows the type to be inferred from the attribute.
  let arguments = (ins TypedAttrInterface:$value);
  let results = (outs AnyTypeOf<[Tiny_VectorF16, Index]>:$result);

  // The attribute itself contains the type, so no need to print it separately.
  let assemblyFormat = "attr-dict $value";
}

Among the available op traits there is a SameOperandsAndResultType option:

class Tiny_IndexBinaryOp<string mnemonic, list<Trait> traits = []> :
    Tiny_Op<mnemonic, !listconcat([Pure, SameOperandsAndResultType], traits)> {
  let arguments = (ins Index:$lhs, Index:$rhs);
  let results = (outs Index:$result);
  let assemblyFormat = "$lhs `,` $rhs attr-dict";
}

The tiny dialect's arithmetic ops are lowered to the arith and vector dialects:

struct SumOpLowering : public OpRewritePattern<SumOp> {
  using OpRewritePattern<SumOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SumOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getInput();
    // vector.reduction<add> returns a scalar f16.
    Value scalarSum = vector::ReductionOp::create(
        rewriter, loc, vector::CombiningKind::ADD, input);
    // Broadcast the scalar to vector<1xf16>.
    VectorType resultType = op.getResult().getType();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, resultType, scalarSum);
    return success();
  }
};

The tiny dialect's pointer operations are lowered to the LLVM dialect:

class TinyToLLVMTypeConverter : public TypeConverter {
public:
  TinyToLLVMTypeConverter() {
    // Identity conversion for all types (fallback).
    addConversion([](Type type) { return type; });
    // Convert tiny.ptr to llvm.ptr.
    addConversion([](PtrType type) -> Type {
      return LLVM::LLVMPointerType::get(type.getContext());
    });
  }
};

struct StoreOpToLLVMLowering : public OpConversionPattern<StoreOp> {
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // The adaptor provides the converted operands (ptr is now !llvm.ptr).
    Value value = adaptor.getValue();
    Value ptr = adaptor.getPtr();
    Value offset = adaptor.getOffset();
    // Convert index offset to i64 for GEP.
    Type i64Type = rewriter.getI64Type();
    Value offsetI64 =
        arith::IndexCastOp::create(rewriter, loc, i64Type, offset);
    // Create GEP with f16 element type to compute the address.
    Type f16Type = rewriter.getF16Type();
    Type llvmPtrType = LLVM::LLVMPointerType::get(getContext());
    Value gep = LLVM::GEPOp::create(rewriter, loc, llvmPtrType, f16Type, ptr,
                                    ValueRange{offsetI64});
    // Store the vector to the computed address.
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, value, gep);
    return success();
  }
};

The func dialect is marked as dynamically legal:

class TinyToLLVMPass : public impl::TinyToLLVMBase<TinyToLLVMPass> {
public:
  void runOnOperation() override {
    // Set up the type converter.
    TinyToLLVMTypeConverter typeConverter;

    // Set up the conversion target.
    ConversionTarget target(getContext());
    // Mark memory operations as illegal.
    target.addIllegalOp<LoadOp, StoreOp>();
    // Mark LLVM and arith dialects as legal.
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addLegalDialect<arith::ArithDialect>();
    // Mark func dialect operations as dynamically legal if their types are
    // converted.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
             typeConverter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
      return typeConverter.isLegal(op.getOperandTypes());
    });

    // Set up rewrite patterns.
    RewritePatternSet patterns(&getContext());
    // Add conversion patterns that use the type converter.
    patterns.add<LoadOpToLLVMLowering, StoreOpToLLVMLowering>(typeConverter,
                                                              &getContext());
    // Add function signature conversion patterns.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);

    // Apply the conversion.
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

Chapter 2

This chapter is mainly about lowering Tiny to SCF; on the Python side, for loops have to be turned into SCF constructs.

Take accumulate as an example: the Python decorator turns a loop body into tiny_loop.accumulate and tiny_loop.yield ops:

def decorator(body_fn):
    # Get init value types and MLIR values
    init_values = [v._value for v in inits]
    init_types = [v._value.type for v in inits]
    result_types = init_types  # Results match init types

    # Create the accumulate op with one region
    op = Operation.create(
        "tiny_loop.accumulate",
        results=result_types,
        operands=[bound._value, step._value] + init_values,
        regions=1,  # One region for the body
    )

    # Set up the block with arguments: (index, *iter_args)
    region = op.regions[0]
    block_arg_types = [IndexType.get()] + init_types
    block = Block.create_at_start(region, block_arg_types)

    # Execute body with wrapped arguments
    with InsertionPoint(block):
        iv = Index._wrap(block.arguments[0])
        iter_args = [_wrap_value(block.arguments[i + 1], inits[i])
                     for i in range(len(inits))]
        # Call user's body function
        if inits:
            results = body_fn(iv, *iter_args)
            if not isinstance(results, (list, tuple)):
                results = [results]
            yield_values = [r._value for r in results]
        else:
            body_fn(iv)
            yield_values = []
        # Create tiny_loop.yield
        Operation.create("tiny_loop.yield", operands=yield_values)

    # Wrap and return results
    if result_types:
        return [_wrap_value(op.results[i], inits[i])
                for i in range(len(result_types))]
    return None

return decorator
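To make the semantics concrete, here is a pure-Python model of what an accumulate loop computes. It executes directly instead of emitting tiny_loop.accumulate/tiny_loop.yield ops; this accumulate helper is my own sketch, not the tutorial's API.

```python
# Pure-Python model of the accumulate construct's semantics: an
# scf.for-style loop carrying iter_args, driven by a decorated body.
# The helper is a hypothetical sketch, not the tutorial's API.

def accumulate(bound: int, step: int, *inits):
    def decorator(body_fn):
        iter_args = list(inits)
        for iv in range(0, bound, step):
            results = body_fn(iv, *iter_args)
            if not isinstance(results, (list, tuple)):
                results = [results]
            iter_args = list(results)  # carried into the next iteration
        return iter_args if len(iter_args) > 1 else iter_args[0]
    return decorator

# Sum 0 + 1 + ... + 9 with a single carried accumulator.
total = accumulate(10, 1, 0)(lambda iv, acc: acc + iv)
print(total)  # 45
```

The body's return values become the next iteration's iter_args, exactly the role tiny_loop.yield plays in the emitted IR.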

matmul.py demonstrates matrix multiplication.

It is a vectorized matrix multiplication, showing how to use Chapter 2's loop construct (accumulate) to write a high-performance matmul.

Algorithm: C[M,N] = A[M,K] * B[K,N]^T

Key points:

  • B is stored transposed, so both A and B are contiguous along the K dimension (convenient for vectorized loads)
  • the vector size is 16 (16 f16 elements are loaded at a time)
  • three nested loops: over M, over N, and over K
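The layout trick can be modeled in plain Python (a scalar sketch of mine, ignoring the 16-wide f16 vectorization): B is stored as Bt with shape [N, K], so the innermost K loop walks both operands contiguously.

```python
# Plain-Python model of the kernel's data layout: B is stored transposed
# (shape [N, K]), so both A[i] and Bt[j] are walked contiguously along K.
# The real tutorial kernel vectorizes this inner loop 16 elements at a time.

def matmul_bt(A, Bt, M, N, K):
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):          # M dimension
        for j in range(N):      # N dimension
            acc = 0.0
            for k in range(K):  # K dimension: contiguous in both A and Bt
                acc += A[i][k] * Bt[j][k]
            C[i][j] = acc
    return C

A  = [[1.0, 2.0], [3.0, 4.0]]     # 2x2
Bt = [[5.0, 7.0], [6.0, 8.0]]     # B transposed: B = [[5, 6], [7, 8]]
print(matmul_bt(A, Bt, 2, 2, 2))  # [[19.0, 22.0], [43.0, 50.0]]
```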

Chapter 2 has a standalone tiny_loop dialect, implementing accumulate (corresponding to scf.for) and yield (corresponding to scf.yield).

Chapter 3

Chapter 3 builds a tile abstraction along the lines of Triton and TileIR.

The corresponding dialect is tiny_tile; the lowering uses the GPU dialect (output only — nothing is actually executed).

For the Tile type, assemblyFormat cannot express the required parsing, so matching parse and print methods have to be written by hand:

Type TileType::parse(AsmParser &parser) {
  if (parser.parseLess())
    return Type();

  // Parse "HxW" as a dimension list. MLIR's lexer treats "64x128" as a single
  // dimension list token, so we must use parseDimensionList.
  SmallVector<int64_t, 2> dims;
  if (parser.parseDimensionList(dims, /*allowDynamic=*/false,
                                /*withTrailingX=*/false))
    return Type();

  // We expect exactly 2 dimensions for a 2D tile.
  if (dims.size() != 2) {
    parser.emitError(parser.getCurrentLocation())
        << "expected 2 dimensions for tile, got " << dims.size();
    return Type();
  }

  // Parse required comma and layout attribute.
  if (parser.parseComma())
    return Type();
  LayoutAttr layout;
  if (parser.parseAttribute(layout))
    return Type();
  if (parser.parseGreater())
    return Type();

  return TileType::get(parser.getContext(), dims[0], dims[1], layout);
}

/// Print a tile type: `<` HxW `,` layout `>`
void TileType::print(AsmPrinter &printer) const {
  printer << "<" << getHeight() << "x" << getWidth() << ", " << getLayout()
          << ">";
}
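The same parse logic is easy to mirror in Python. A rough sketch (parse_tile is a hypothetical helper of mine; the C++ above is the real implementation):

```python
import re

# Sketch of the same parse logic: split "<HxW, layout>" into two
# dimensions plus a layout-attribute string, mirroring TileType::parse.
# parse_tile is a hypothetical helper, not from the tutorial.

def parse_tile(text: str):
    m = re.fullmatch(r"<(\d+)x(\d+),\s*(.+)>", text)
    if m is None:
        raise ValueError("expected <HxW, layout>")
    return int(m.group(1)), int(m.group(2)), m.group(3)

print(parse_tile("<64x128, #tiny_tile.layout<thread = [1, 32], vector_size = 8>>"))
```

Note why the C++ needs parseDimensionList: "64x128" arrives from MLIR's lexer as one token, whereas the regex here can split it character by character.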

For tiny_tile.splat, the layout plans the matrix and the threads at the same time:

#tiny_tile.layout<thread = [1, 32], vector_size = 8>

The layout's parameter definition — note that the parameters are plain int64_t (in MLIR this becomes an attribute):

let parameters = (ins
  ArrayRefParameter<"int64_t">:$thread,
  "int64_t":$vectorSize
);

The tile size is computed jointly from thread and vector_size.

Example 1: thread = [4, 16], vector_size = 4

Thread grid: 4 rows × 16 columns = 64 threads. Tile size: 4 × (16×4) = 4×64 = 256 elements.

(A 2D thread layout maps nicely onto GPU warps.)
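The rule from the example can be written down directly (tile_shape is a hypothetical helper of mine):

```python
# Tile shape implied by a layout, per the rule
# tile = thread[0] x (thread[1] * vector_size).
# tile_shape is a hypothetical helper, not from the tutorial.

def tile_shape(thread: list[int], vector_size: int) -> tuple[int, int]:
    return thread[0], thread[1] * vector_size

h, w = tile_shape([4, 16], 4)
print(h, w, h * w)  # 4 64 256 -> 64 threads covering 256 elements
```

The earlier layout `thread = [1, 32], vector_size = 8` gives a 1×256 tile by the same rule.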

tiny_tile can be lowered to tiny and tiny_loop, with the thread handling lowered to the GPU dialect.

The lowerings should produce:

  • LoadOp -> gpu.thread_id + tiny.load (compute per-thread offset from layout)

  • StoreOp -> gpu.thread_id + tiny.store (compute per-thread offset from layout)

  • SumOp -> tiny.sum + vector.extract + gpu.subgroup_reduce + vector.broadcast

The rationale for representing threads in two dimensions:

  • thread[0]: number of threads along Y (vertical) = the range of thread_y
  • thread[1]: number of threads along X (horizontal) = the range of thread_x
  • vector_size: the vector width each thread processes
  • actual tile size: thread[0] × (thread[1] × vector_size)
  • mapping order: row-major — X (columns) fills up first, then Y (rows) advances
  • benefit: exploits GPU hardware characteristics for better cache locality and execution efficiency
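The row-major mapping can be sketched as follows (thread_offset is my own illustration of the per-thread offset arithmetic the lowerings perform, with a fixed row stride standing in for the runtime stride):

```python
# Row-major mapping from a flat thread id to its (y, x) position and the
# starting element each thread owns, under thread = [Ty, Tx] and vector_size V.
# thread_offset is a hypothetical sketch of the lowerings' offset arithmetic.

def thread_offset(tid: int, thread: list[int], vector_size: int,
                  row_stride: int) -> int:
    ty, tx = thread
    y, x = divmod(tid, tx)  # X fills first, then Y advances (row-major)
    assert y < ty, "thread id outside the grid"
    return y * row_stride + x * vector_size

# thread = [4, 16], vector_size = 4: one tile row holds 16 * 4 = 64 elements.
offsets = [thread_offset(t, [4, 16], 4, 64) for t in range(64)]
print(offsets[:4], offsets[16])  # threads 0..3 -> 0,4,8,12; thread 16 starts row 1
```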

Chapter 3's gpu_dot_product.py uses tiny_tile to compute a dot product.

It is a GPU-parallel dot product implementation, using the tile-based DSL for SPMD (Single Program Multiple Data) computation across multiple GPU blocks.

This was the first time I'd seen an op definition whose assembly contains a bare keyword (it is actually passed in as an EnumAttr):

- tiny_tile.elementwise add -> tiny.addf

- tiny_tile.elementwise sub -> tiny.subf

- tiny_tile.elementwise mul -> tiny.mulf

- tiny_tile.elementwise div -> tiny.divf

def TinyTile_EWKind_Add : I32EnumAttrCase<"add", 0>;
def TinyTile_EWKind_Sub : I32EnumAttrCase<"sub", 1>;
def TinyTile_EWKind_Mul : I32EnumAttrCase<"mul", 2>;
def TinyTile_EWKind_Div : I32EnumAttrCase<"div", 3>;

def TinyTile_ElementwiseKind : I32EnumAttr<
    "ElementwiseKind",
    "Elementwise operation kind",
    [TinyTile_EWKind_Add, TinyTile_EWKind_Sub,
     TinyTile_EWKind_Mul, TinyTile_EWKind_Div]> {
  let cppNamespace = "::mlir::tiny_tile";
  let genSpecializedAttr = 0;
}

def TinyTile_ElementwiseKindAttr :
    EnumAttr<TinyTile_Dialect, TinyTile_ElementwiseKind, "ew_kind">;

getKind() decides which direction the lowering takes:

LogicalResult ElementwiseOp::convertToSIMT(RewriterBase &rewriter,
                                           ValueRange simtOperands) {
  Value lhs = simtOperands[0];
  Value rhs = simtOperands[1];
  switch (getKind()) {
  case ElementwiseKind::add:
    rewriter.replaceOpWithNewOp<tiny::AddFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::sub:
    rewriter.replaceOpWithNewOp<tiny::SubFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::mul:
    rewriter.replaceOpWithNewOp<tiny::MulFOp>(*this, lhs, rhs);
    break;
  case ElementwiseKind::div:
    rewriter.replaceOpWithNewOp<tiny::DivFOp>(*this, lhs, rhs);
    break;
  }
  return success();
}

tutorial/ch3-gpu-tile-dsl/TinyTileDialect.cpp implements convertToSIMT for many types of ops; the comments in that file are worth a close read.

  • SplatOp: creates a per-thread constant vector; input (tile): a scalar value; output (vector): tiny.constant dense<…> : vector
  • LoadOp: computes the per-thread address offset; input: ptr, row, col, stride; output: %thread_id + address arithmetic → tiny.load
  • StoreOp: computes the per-thread address offset; input: value, ptr, row, col, stride; output: %thread_id + address arithmetic → tiny.store
  • SumOp: two-level reduction (thread-local + cross-thread); input: vector; output: tiny.sum + extract + subgroup_reduce + broadcast

The challenge exercise is to implement a tiny_tile matmul.

Since I'm not very familiar with the GPU dialect, I don't feel I fully understood the conversions in this chapter.

Chapter 4

This chapter is mainly about the Linalg dialect and the Transform dialect, used to achieve tile-level operator fusion.

The docs mention that the Transform dialect takes inspiration from Halide IR ("Halide like") — a claim I hadn't heard before (and if you press an AI on it, you'll find Halide is an ancestor of both TVM and MLIR).

Halide is very interesting because it separates the algorithm from the schedule.

  • arithmetic, math functions (sin, exp): algorithm (compute logic)
  • dependencies between pixels (e.g. a box blur): algorithm (compute logic)
  • loop order (row-by-row, or tiled): schedule
  • whether to use multithreading (parallel): schedule
  • whether to vectorize: schedule
  • size and placement of temporary buffers: schedule

There are three core reasons for separating the algorithm from the schedule:

  1. Mathematical correctness: the algorithm part is easy to verify. As long as the math is right, the result should in theory be identical no matter how the schedule changes (tiling, parallelism).
  2. Modularity: one algorithm can have many schedules. For example, one schedule for a phone CPU and another for a high-end GPU, without changing a single line of the algorithm code.
  3. Operator fusion: because the algorithm is a pure functional description, the compiler can clearly see the dependencies between functions and decide automatically whether to fuse two steps into one computation, without the programmer manually restructuring loops.
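The separation can be shown in miniature (an illustrative sketch of mine, not Halide code): the "algorithm" stays fixed while two "schedules" traverse the data differently and agree on the result.

```python
# Algorithm vs. schedule, in miniature: the "algorithm" (square then sum)
# is fixed, while two "schedules" (naive vs. blocked loop order) produce
# the same result. Illustrative sketch only, not Halide code.

def algorithm(x: float) -> float:
    return x * x  # the math: pure and schedule-free

def schedule_naive(data: list[float]) -> float:
    return sum(algorithm(v) for v in data)

def schedule_blocked(data: list[float], block: int = 4) -> float:
    total = 0.0
    for start in range(0, len(data), block):  # tiled loop: a schedule choice
        for v in data[start:start + block]:
            total += algorithm(v)
    return total

data = [float(i) for i in range(10)]
assert schedule_naive(data) == schedule_blocked(data)  # same math, any schedule
print(schedule_naive(data))  # 285.0
```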

This part is essentially a simplified version of the official MLIR tutorial; if it still doesn't click, go read the Transform Dialect material on the MLIR website.