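This article is a sequel to "淦!MLIR输出Hello World不应该这么难" ("Damn! Printing Hello World with MLIR shouldn't be this hard"), and also my attempt to answer some of my own questions about Chapter #7 of the official MLIR Toy Tutorial.

The code for this article is in this GitHub repository:

https://github.com/mocusez/mlir-hello-world

The GitHub CI script shows how to build and run it.

Questions about the Toy Tutorial

The tutorial defines a composite type named struct in the Toy dialect: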

# A struct is defined by using the `struct` keyword followed by a name.
struct MyStruct {
# Inside of the struct is a list of variable declarations without initializers
# or shapes, which may also be other previously defined structs.
var a;
var b;
}

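The tutorial only tells us the steps for defining a Type:

  1. Define the Type
  2. Define the TypeStorage
  3. Define the ODS, i.e. TableGen
  4. Define parsing and printing for the Type
  5. Define the operations on the Type

But many details are left vague:

  1. How should we handle the lowering of the Type and its TypeStorage?
  2. I am already using MLIR, so why do I need to handle parsing and printing of the Type? (And is this parsing and printing for the Toy language, or for the MLIR dialect?)

The second question is really confusing at first, but the answer is: the hooks are used to parse MLIR textual source, which is quite different from parsing the Toy language.

With these two questions in mind, here is today's topic: implement a simple HashTable and be able to print the result.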

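Implementation overview

Define a Dict type in the Hello dialect together with its operations (Put, Get, Delete), then lower everything to LLVM IR so that it can be executed.

Adding a custom Type

First, declare the type in the dialect's TableGen file: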

def DictType :
DialectType<Hello_Dialect, CPred<"::llvm::isa<DictType>($_self)">,
"Hello dict type">;

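If you are not sure how to write the CPred here, just copy this one and adapt it.

In dialect.h, define the C++ class for the Type, which is used to create and query the type during conversion, and declare the TypeStorage it uses: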

namespace hello {
struct DictTypeStorage;
}
class DictType : public mlir::Type::TypeBase<DictType, mlir::Type,
hello::DictTypeStorage> {
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
/// Create an instance of a `DictType` with the given key and value types.
static DictType get(mlir::Type keyType, mlir::Type valueType);
/// Returns the key type of this dict type.
mlir::Type getKeyType();
/// Returns the value type of this dict type.
mlir::Type getValueType();
/// The name of this dict type.
static constexpr mlir::StringLiteral name = "hello.dict";
};

In dialect.cpp, fill in the Type's methods. getImpl() below returns the TypeStorage, through which the stored MLIR types are accessed:

DictType DictType::get(mlir::Type keyType, mlir::Type valueType) {
return Base::get(keyType.getContext(), keyType, valueType);
}
mlir::Type DictType::getKeyType() {
return getImpl()->keyType;
}
/// Returns the value type of this dict type.
mlir::Type DictType::getValueType() {
return getImpl()->valueType;
}

Also register the Type in the dialect's initialize() method:

void HelloDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "Hello/HelloOps.cpp.inc"
>();
addTypes<DictType>();
}

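Adding a custom TypeStorage

MLIR separates Type from TypeStorage. The TypeStorage can be thought of as the place where a Type's data and metadata actually live.

Since DictTypeStorage is already declared in Dialect.h, it can simply be implemented in Dialect.cpp by inheriting from mlir::TypeStorage.

operator== compares the storage against a key, which decides whether two Types are equal.

hashKey hashes the key so that an already existing instance of the same Type can be found.

construct uses the allocator to allocate the space for the storage instance (this step does not show up in the lowering).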

struct DictTypeStorage : public mlir::TypeStorage {
/// The `KeyTy` defines what uniquely identifies this type.
/// For dict type, we unique on the key type and value type pair.
using KeyTy = std::pair<mlir::Type, mlir::Type>;
/// Constructor for the type storage instance.
DictTypeStorage(mlir::Type keyType, mlir::Type valueType)
: keyType(keyType), valueType(valueType) {}
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {
return key.first == keyType && key.second == valueType;
}
/// Define a hash function for the key type.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, key.second);
}
/// Define a construction function for the key type.
static KeyTy getKey(mlir::Type keyType, mlir::Type valueType) {
return KeyTy(keyType, valueType);
}
/// Define a construction method for creating a new instance of this storage.
static DictTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
const KeyTy &key) {
// Allocate the storage instance and construct it.
return new (allocator.allocate<DictTypeStorage>())
DictTypeStorage(key.first, key.second);
}
/// The key and value types of the dict.
mlir::Type keyType;
mlir::Type valueType;
};
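
To make the roles of KeyTy, operator==, hashKey and construct more concrete, here is a minimal sketch of the uniquing idea that does not depend on MLIR at all. ToyType, ToyDictStorage and ToyUniquer are hypothetical names made up for illustration, and the bucket table is only an assumption about how such a lookup could be organized, not how MLIR implements it internally.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Type: a type is just its name here.
using ToyType = std::string;

// Plays the role of DictTypeStorage: owns the data and defines the key.
struct ToyDictStorage {
  using KeyTy = std::pair<ToyType, ToyType>;

  explicit ToyDictStorage(const KeyTy &key)
      : keyType(key.first), valueType(key.second) {}

  // Compare this storage against a key (same role as operator== above).
  bool operator==(const KeyTy &key) const {
    return key.first == keyType && key.second == valueType;
  }

  // Hash a key (same role as hashKey above).
  static std::size_t hashKey(const KeyTy &key) {
    std::size_t h1 = std::hash<ToyType>{}(key.first);
    std::size_t h2 = std::hash<ToyType>{}(key.second);
    return h1 ^ (h2 + 0x9e3779b9u + (h1 << 6) + (h1 >> 2));
  }

  ToyType keyType;
  ToyType valueType;
};

// Hypothetical uniquer: keeps exactly one storage object per distinct key.
class ToyUniquer {
public:
  const ToyDictStorage *get(const ToyType &keyType, const ToyType &valueType) {
    ToyDictStorage::KeyTy key(keyType, valueType);
    auto &bucket = buckets[ToyDictStorage::hashKey(key)];
    for (const auto &storage : bucket)
      if (*storage == key)
        return storage.get(); // Already created: reuse the same object.
    // Not found: construct a new storage instance for this key.
    bucket.push_back(std::make_unique<ToyDictStorage>(key));
    assert(*bucket.back() == key && "new storage must match its key");
    return bucket.back().get();
  }

private:
  std::unordered_map<std::size_t,
                     std::vector<std::unique_ptr<ToyDictStorage>>>
      buckets;
};
```

Asking the uniquer twice for ("index", "i32") returns the same pointer, while ("i32", "index") gets a different one. This is why two DictType values with the same key and value types compare equal.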

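Converting the Type

The Dict type in the Hello dialect has to be converted to a pointer. The conversion actually happens while lowering the Operations, by converting each Op's operands and results.

One option is to create a class that inherits from LLVMTypeConverter and calls addConversion in its constructor. This should work in theory, but I have not tried it: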

class HelloTypeConverter : public mlir::LLVMTypeConverter {
public:
HelloTypeConverter(mlir::MLIRContext *ctx) : mlir::LLVMTypeConverter(ctx) {
addConversion([](hello::DictType type) {
return mlir::LLVM::LLVMPointerType::get(type.getContext());
});
}
};

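The other option is to call addConversion directly on the typeConverter.

There are two ways to write it, both shown here for reference (both passed my tests):

  1. Name the type to convert directly in the callback's parameter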
typeConverter.addConversion([](DictType type) -> mlir::Type {
return mlir::LLVM::LLVMPointerType::get(type.getContext());
});
  2. Check with mlir::dyn_cast and decide what to return
typeConverter.addConversion([](mlir::Type type) -> std::optional<mlir::Type> {
if (auto dictType = mlir::dyn_cast<DictType>(type))
return mlir::LLVM::LLVMPointerType::get(type.getContext());
return std::nullopt;
});

Once the typeConverter has been passed in through the pattern's constructor, getTypeConverter() can be used to look up the converted type for a given type:

auto resultType = getTypeConverter()->convertType(op->getResult(0).getType());
if (!resultType) {
return mlir::failure();
}

For a DictType result, this code is therefore equivalent to:

auto resultType = mlir::LLVM::LLVMPointerType::get(context);

In other words, the outcome is fixed: resultType does not change with the arguments passed to the Operation.
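
The addConversion mechanics in this section can be modeled with a small, MLIR-free sketch. ToyTypeConverter and makeConverter are hypothetical names, and types are plain strings. The sketch tries the most recently added callback first and treats std::nullopt as "not handled, try the next rule"; this mirrors how I understand mlir::TypeConverter to behave, but it is only a model of the idea.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical miniature of a type converter; a type is its textual form.
class ToyTypeConverter {
public:
  using Callback =
      std::function<std::optional<std::string>(const std::string &)>;

  void addConversion(Callback cb) {
    assert(cb && "conversion callback must be callable");
    callbacks.push_back(std::move(cb));
  }

  // Try the most recently added callback first. std::nullopt means
  // "this rule does not apply", so the next rule gets a chance.
  std::optional<std::string> convertType(const std::string &type) const {
    for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it)
      if (auto result = (*it)(type))
        return result;
    return std::nullopt; // No rule matched: the conversion fails.
  }

private:
  std::vector<Callback> callbacks;
};

// Build a converter with one rule for i32 and one rule for the dict type.
ToyTypeConverter makeConverter() {
  ToyTypeConverter tc;
  tc.addConversion([](const std::string &t) -> std::optional<std::string> {
    if (t == "i32")
      return t; // Already legal: keep as-is.
    return std::nullopt;
  });
  tc.addConversion([](const std::string &t) -> std::optional<std::string> {
    // Every dict type maps to the same opaque pointer type,
    // regardless of its key and value types.
    if (t.rfind("!hello.dict<", 0) == 0)
      return std::string("!llvm.ptr");
    return std::nullopt;
  });
  return tc;
}
```

With this converter, both "!hello.dict<index, i32>" and "!hello.dict<i32, i32>" convert to "!llvm.ptr", and an unknown type such as "f32" yields std::nullopt, which corresponds to the `if (!resultType) return mlir::failure();` check in the lowering code.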

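Converting the Type's operations

A Dict (hash table) has the operations put, get and delete, plus create and free for allocating and releasing memory. Each of them needs a corresponding Op (Operation) defined in the dialect.

In the code below, assemblyFormat is defined so that the ops and their types can be parsed correctly from a textual .mlir file: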

class Hello_Op<string mnemonic, list<Trait> traits = []> :
Op<Hello_Dialect, mnemonic, traits>;
def Dict_CreateOp : Hello_Op<"dict.create", [Pure]> {
let summary = "Create a new dict<string,i32>";
let results = (outs DictType:$dict);
let assemblyFormat = "attr-dict `:` type($dict)";
}
def Dict_FreeOp : Hello_Op<"dict.free", []> {
let summary = "Free the dict<string,i32> memory";
let arguments = (ins DictType:$dict);
let assemblyFormat = "$dict attr-dict `:` type($dict)";
}
def Dict_PutOp : Hello_Op<"dict.put", []> {
let summary = "Insert string->i32";
let arguments = (ins DictType:$dict, StrAttr:$key, I32Attr:$value);
let results = (outs DictType:$out);
let assemblyFormat = "$dict `,` $key `=` $value attr-dict `:` type($dict) `->` type($out)";
}
def Dict_GetOp : Hello_Op<"dict.get", []> {
let summary = "Lookup string->i32, returns i32";
let arguments = (ins DictType:$dict, StrAttr:$key);
let results = (outs I32:$value);
let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($value)";
}
def Dict_DeleteOp : Hello_Op<"dict.delete", []> {
let summary = "Delete key string";
let arguments = (ins DictType:$dict, StrAttr:$key);
let results = (outs DictType:$out);
let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($out)";
}

Each Op also needs its lowering. Here is the implementation for Dict_CreateOp as an example:

class DictCreateOpLowering : public mlir::ConversionPattern {
public:
explicit DictCreateOpLowering(mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context)
: mlir::ConversionPattern(
typeConverter, hello::CreateOp::getOperationName(), 1, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
mlir::ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto resultType =
getTypeConverter()->convertType(op->getResult(0).getType());
if (!resultType) {
return mlir::failure();
}
mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
auto createMapRef = getOrInsertCreateMap(rewriter, module);
auto callOp = rewriter.create<mlir::LLVM::CallOp>(
loc, resultType, createMapRef, mlir::ValueRange{});
rewriter.replaceOp(op, callOp.getResult());
return mlir::success();
}
private:
mlir::FlatSymbolRefAttr getOrInsertCreateMap(mlir::PatternRewriter &rewriter,
mlir::ModuleOp module) const {
auto *context = module.getContext();
if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("create_map"))
return mlir::SymbolRefAttr::get(context, "create_map");
// Create function type: () -> !llvm.ptr
auto resultType = mlir::LLVM::LLVMPointerType::get(context);
auto fnType =
mlir::LLVM::LLVMFunctionType::get(resultType, std::nullopt, false);
// Insert function declaration
mlir::PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "create_map",
fnType);
return mlir::SymbolRefAttr::get(context, "create_map");
}
};

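What this pattern does: it converts Dict_CreateOp into a call to the C/LLVM IR function create_map. Once that works, output can be produced with C's printf().

Defining the MLIR text parsing rules

This part is very important: the dialect has to declare how the custom Type is parsed and printed, i.e. printType and parseType.

To do this, set useDefaultTypePrinterParser to 1 in the Dialect definition, which makes TableGen generate the declarations of these two hooks: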

def Hello_Dialect : Dialect {
let name = "hello";
let summary = "A hello out-of-tree MLIR dialect.";
let description = [{
This dialect is minimal example to implement hello-world kind of sample code
for MLIR.
}];
let cppNamespace = "::hello";
let useDefaultTypePrinterParser = 1;
}

Define printType and parseType in the corresponding dialect.cpp:

mlir::Type DictType::getValueType() {
return getImpl()->valueType;
}
mlir::Type HelloDialect::parseType(mlir::DialectAsmParser &parser) const {
llvm::StringRef typeTag;
if (parser.parseKeyword(&typeTag))
return mlir::Type();
if (typeTag == "dict") {
if (parser.parseLess())
return mlir::Type();
mlir::Type keyType;
if (parser.parseType(keyType))
return mlir::Type();
if (parser.parseComma())
return mlir::Type();
mlir::Type valueType;
if (parser.parseType(valueType))
return mlir::Type();
if (parser.parseGreater())
return mlir::Type();
return DictType::get(keyType, valueType);
}
parser.emitError(parser.getNameLoc(), "unknown hello type: ") << typeTag;
return mlir::Type();
}
void HelloDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &printer) const {
if (auto dictType = mlir::dyn_cast<DictType>(type)) {
printer << "dict<";
printer.printType(dictType.getKeyType());
printer << ", ";
printer.printType(dictType.getValueType());
printer << ">";
return;
}
llvm_unreachable("unhandled hello type");
}
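
The parse and print hooks above are a small hand-written recursive-descent parser plus its inverse. The following MLIR-free sketch shows the same sequence of steps (keyword, `<`, key type, `,`, value type, `>`) on a plain string. ToyParser, ToyDict, parseDict and printDict are hypothetical names, and element types are simplified to single identifiers, whereas the real hook calls parser.parseType and can therefore handle arbitrary nested types. As in the hook, the input is assumed to be only the part after the `!hello.` prefix.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>

// Hypothetical parsed form of `dict<key, value>`.
struct ToyDict {
  std::string keyType;
  std::string valueType;
};

// Hypothetical cursor over the text, in the role of a DialectAsmParser.
// Note: these helpers return true on SUCCESS, unlike MLIR's ParseResult.
class ToyParser {
public:
  explicit ToyParser(std::string text) : input(std::move(text)) {}

  bool parseToken(char c) {
    skipSpace();
    if (pos < input.size() && input[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }

  bool parseWord(std::string &out) {
    skipSpace();
    std::size_t start = pos;
    while (pos < input.size() &&
           (std::isalnum(static_cast<unsigned char>(input[pos])) ||
            input[pos] == '_'))
      ++pos;
    out = input.substr(start, pos - start);
    return !out.empty();
  }

  bool atEnd() {
    skipSpace();
    return pos == input.size();
  }

private:
  void skipSpace() {
    while (pos < input.size() &&
           std::isspace(static_cast<unsigned char>(input[pos])))
      ++pos;
  }

  std::string input;
  std::size_t pos = 0;
};

// Same step order as HelloDialect::parseType above.
std::optional<ToyDict> parseDict(const std::string &text) {
  ToyParser p(text);
  std::string tag;
  ToyDict d;
  if (!p.parseWord(tag) || tag != "dict")
    return std::nullopt;
  if (!p.parseToken('<') || !p.parseWord(d.keyType) || !p.parseToken(',') ||
      !p.parseWord(d.valueType) || !p.parseToken('>') || !p.atEnd())
    return std::nullopt;
  return d;
}

// Inverse of parseDict, same output shape as HelloDialect::printType above.
std::string printDict(const ToyDict &d) {
  assert(!d.keyType.empty() && !d.valueType.empty());
  return "dict<" + d.keyType + ", " + d.valueType + ">";
}
```

The useful property to check is the round trip: printing a parsed type gives back text that parses to the same type, which is what lets a tool read a .mlir file and write it out again unchanged.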

Run!

Prepare some MLIR code:

module {
func.func @test_dict_operations() -> i32 {
%dict0 = hello.dict.create : !hello.dict<index, i32>
%dict1 = hello.dict.put %dict0, "first1" = 100 : !hello.dict<index, i32> -> !hello.dict<index, i32>
%dict2 = hello.dict.put %dict1, "second" = 200 : !hello.dict<index, i32> -> !hello.dict<index, i32>
%dict3 = hello.dict.put %dict2, "third" = 300 : !hello.dict<index, i32> -> !hello.dict<index, i32>
%val1 = hello.dict.get %dict3, "first1" : !hello.dict<index, i32> -> i32
%val2 = hello.dict.get %dict3, "second" : !hello.dict<index, i32> -> i32
%val3 = hello.dict.get %dict3, "third" : !hello.dict<index, i32> -> i32
%dict4 = hello.dict.delete %dict3, "second" : !hello.dict<index, i32> -> !hello.dict<index, i32>
hello.dict.free %dict4 : !hello.dict<index, i32>
func.return %val2 : i32
}
}

Lower the MLIR to LLVM IR and take a look at the printed LLVM IR:

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-pc-linux-gnu"
@key_third = internal constant [6 x i8] c"third\00"
@key_second = internal constant [7 x i8] c"second\00"
@key_first1 = internal constant [7 x i8] c"first1\00"
declare void @free_map(ptr) local_unnamed_addr
declare void @delete(ptr, ptr) local_unnamed_addr
declare ptr @get(ptr, ptr) local_unnamed_addr
declare void @put(ptr, ptr, i32) local_unnamed_addr
declare ptr @create_map() local_unnamed_addr
define i32 @test_dict_operations() local_unnamed_addr {
%1 = tail call ptr @create_map()
tail call void @put(ptr %1, ptr nonnull @key_first1, i32 100)
tail call void @put(ptr %1, ptr nonnull @key_second, i32 200)
tail call void @put(ptr %1, ptr nonnull @key_third, i32 300)
%2 = tail call ptr @get(ptr %1, ptr nonnull @key_first1)
%3 = tail call ptr @get(ptr %1, ptr nonnull @key_second)
%4 = load i32, ptr %3, align 4
%5 = tail call ptr @get(ptr %1, ptr nonnull @key_third)
tail call void @delete(ptr %1, ptr nonnull @key_second)
tail call void @free_map(ptr %1)
ret i32 %4
}
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}

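Most of it has become calls to C/LLVM IR functions, which turns this into a familiar problem, just like the printf() case in the previous article.

Implement the corresponding interface functions in C/C++, then compile and run: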

../../build/bin/hello-opt dict.mlir -emit=llvm > dict.ll
clang-20 dict.c dict.ll -o dict

The driver in dict.c:

#include <stdio.h>

extern int test_dict_operations();
int main() {
printf("Starting dictionary test...\n");
int result = test_dict_operations();
printf("Result from test_dict_operations: %d\n", result);
return 0;
}

The expected output is:

Starting dictionary test...
Result from test_dict_operations: 200
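
For orientation, here is a minimal sketch of what the runtime functions declared in the LLVM IR above could look like. It is an assumption, not the code from the repository: it stores everything in a std::unordered_map<std::string, int> behind an opaque pointer, and the signatures are read off the `declare` lines (get returns a pointer because the IR loads the i32 from its result). One caveat: `delete` is a keyword in C++, so the sketch uses the hypothetical name delete_key, which will not link against the IR as-is. A C file such as dict.c can define a function called delete directly.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

namespace {
using Map = std::unordered_map<std::string, int>;

// Recover the map from the opaque handle passed through the IR.
Map *asMap(void *handle) {
  assert(handle && "map handle must not be null");
  return static_cast<Map *>(handle);
}
} // namespace

extern "C" {

// ptr @create_map()
void *create_map() { return new Map(); }

// void @put(ptr, ptr, i32): insert or overwrite key -> value.
void put(void *map, const char *key, int value) {
  (*asMap(map))[std::string(key)] = value;
}

// ptr @get(ptr, ptr): pointer to the stored value, or null if missing.
int *get(void *map, const char *key) {
  Map *m = asMap(map);
  auto it = m->find(std::string(key));
  return it == m->end() ? nullptr : &it->second;
}

// Stand-in for void @delete(ptr, ptr); hypothetical name, see note above.
void delete_key(void *map, const char *key) {
  asMap(map)->erase(std::string(key));
}

// void @free_map(ptr)
void free_map(void *map) { delete asMap(map); }

} // extern "C"
```

In this sketch a lookup of a missing key returns a null pointer, so IR that loads from the result unconditionally (as the generated code above does) would crash for absent keys. How the real runtime handles that case is not shown in the article.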


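Thoughts

This approach has an obvious limitation: it only works for hash tables of shape <string,i32> and not for other types. That could be solved by generating the LLVM IR dynamically, but how to do that is a separate question 😂

I still recommend trying it hands-on to really deepen your understanding.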