How to Define a Custom Type in MLIR and Run It with Output
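This article is a sequel to "Damn! Getting MLIR to Print Hello World Shouldn't Be This Hard". It is also my attempt to answer some of my own questions about Chapter 7 of the official MLIR Toy Tutorial.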
The implementation code for this article is in the GitHub repository:
https://github.com/mocusez/mlir-hello-world
The GitHub CI script shows the full build-and-run flow.
Questions about the Toy Tutorial
The tutorial defines a composite type named struct in the Toy dialect:
# A struct is defined by using the `struct` keyword followed by a name.
struct MyStruct {
  # Inside of the struct is a list of variable declarations without initializers
  # or shapes, which may also be other previously defined structs.
  var a;
  var b;
}
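The tutorial only walks us through the steps for defining a Type:
- Define the Type
- Define the TypeStorage
- Define the ODS, i.e. the TableGen description
- Define parsing and printing for the Type
- Define the operations on the Type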
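But many details are left vague:
- How are the Type and its TypeStorage supposed to be lowered?
- I'm already using MLIR, so why do I need to handle parsing and printing of the Type myself? (And is that parsing and printing for the Toy language, or for the MLIR dialect?)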
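The second question is genuinely confusing, but the answer is that the parsing and printing are for MLIR's textual source (.mlir files), which is entirely separate from parsing Toy source code.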
With these two questions in mind, here is today's goal: implement a simple HashTable and Print its contents so we can see the result.
Overview of the approach
Define a Dict Type in the Hello Dialect together with operations on it (Put, Get, Delete), then convert everything to LLVM IR and execute it.
Adding the custom Type
First, declare the type in the Dialect's TableGen file:
def DictType : DialectType<Hello_Dialect, CPred<"::llvm::isa<DictType>($_self)">, "Hello dict type">;
If you're not sure how to write the CPred here, just adapt this one.
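In dialect.h, declare the C++ class for the Type, whose methods are used to create the type and read its parameters during conversion, and forward-declare the TypeStorage it uses: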
namespace hello {
struct DictTypeStorage;
}

class DictType : public mlir::Type::TypeBase<DictType, mlir::Type,
                                             hello::DictTypeStorage> {
public:
  /// Inherit some necessary constructors from 'TypeBase'.
  using Base::Base;

  /// Create an instance of a `DictType` with the given key and value types.
  static DictType get(mlir::Type keyType, mlir::Type valueType);

  /// Returns the key type of this dict type.
  mlir::Type getKeyType();

  /// Returns the value type of this dict type.
  mlir::Type getValueType();

  /// The name of this dict type.
  static constexpr mlir::StringLiteral name = "hello.dict";
};
In dialect.cpp, implement the Type's methods. getImpl() below returns the type's underlying TypeStorage, from which the stored key and value MLIR Types are read.
DictType DictType::get(mlir::Type keyType, mlir::Type valueType) {
  return Base::get(keyType.getContext(), keyType, valueType);
}

/// Returns the key type of this dict type.
mlir::Type DictType::getKeyType() { return getImpl()->keyType; }

/// Returns the value type of this dict type.
mlir::Type DictType::getValueType() { return getImpl()->valueType; }
Then register the Type in the Dialect's initialize() method:
void HelloDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "Hello/HelloOps.cpp.inc"
      >();
  addTypes<DictType>();
}
Adding the custom TypeStorage
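MLIR splits a type into a Type and a TypeStorage. The Type is a lightweight handle, while the TypeStorage can be thought of as the actual storage for the type's data and metadata; it is uniqued and owned by the MLIRContext.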
Since DictTypeStorage is already declared in Dialect.h, in Dialect.cpp we just derive it from mlir::TypeStorage and implement it:
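- operator== decides whether an existing storage matches a given key, i.e. whether two Types are the same
- hashKey hashes the key so that MLIR can quickly look up whether a storage with the same key and value types already exists
- construct uses the allocator to allocate and build a new storage instance when no match exists (this step does not show up in the lowering)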
struct DictTypeStorage : public mlir::TypeStorage {
  /// The `KeyTy` defines what uniquely identifies this type.
  /// For dict type, we unique on the key type and value type pair.
  using KeyTy = std::pair<mlir::Type, mlir::Type>;

  /// Constructor for the type storage instance.
  DictTypeStorage(mlir::Type keyType, mlir::Type valueType)
      : keyType(keyType), valueType(valueType) {}

  /// Define the comparison function for the key type.
  bool operator==(const KeyTy &key) const {
    return key.first == keyType && key.second == valueType;
  }

  /// Define a hash function for the key type.
  static llvm::hash_code hashKey(const KeyTy &key) {
    return llvm::hash_combine(key.first, key.second);
  }

  /// Define a construction function for the key type.
  static KeyTy getKey(mlir::Type keyType, mlir::Type valueType) {
    return KeyTy(keyType, valueType);
  }

  /// Define a construction method for creating a new instance of this storage.
  static DictTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                    const KeyTy &key) {
    // Allocate the storage instance and construct it.
    return new (allocator.allocate<DictTypeStorage>())
        DictTypeStorage(key.first, key.second);
  }

  /// The key and value types of the dict.
  mlir::Type keyType;
  mlir::Type valueType;
};
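To make the roles of KeyTy, hashKey, operator== and construct more concrete, here is a toy model in C of the uniquing that DictType::get triggers. It is a hypothetical sketch, not MLIR's actual uniquer: mlir::Type is replaced by small integer ids, and every name in it (DictKey, DictStorage, dict_type_get, and so on) is made up for illustration. What it demonstrates is the property the real storage provides: asking twice for the same (key, value) pair returns the same storage pointer, which is why comparing two DictTypes is just a pointer comparison.

```c
#include <stdlib.h>

/* Hypothetical stand-in for KeyTy = std::pair<mlir::Type, mlir::Type>;
 * small integer ids play the role of mlir::Type. */
typedef struct {
  int key_type;
  int value_type;
} DictKey;

/* Stand-in for DictTypeStorage: the key plus a bucket-chain link. */
typedef struct DictStorage {
  DictKey key;
  struct DictStorage *next;
} DictStorage;

#define NUM_BUCKETS 64
static DictStorage *buckets[NUM_BUCKETS];

/* Counterpart of hashKey(): combine both halves of the key. */
static size_t hash_key(DictKey k) {
  return (size_t)(unsigned)k.key_type * 31u + (size_t)(unsigned)k.value_type;
}

/* Counterpart of operator==(const KeyTy &). */
static int storage_matches(const DictStorage *s, DictKey k) {
  return s->key.key_type == k.key_type && s->key.value_type == k.value_type;
}

/* Counterpart of construct(): allocate storage for a key seen for the first time. */
static DictStorage *construct(DictKey k) {
  DictStorage *s = malloc(sizeof *s);
  if (s) {
    s->key = k;
    s->next = NULL;
  }
  return s;
}

/* Counterpart of DictType::get(): equal keys always yield the same storage pointer. */
DictStorage *dict_type_get(int key_type, int value_type) {
  DictKey k = {key_type, value_type};
  size_t b = hash_key(k) % NUM_BUCKETS;
  for (DictStorage *s = buckets[b]; s; s = s->next)
    if (storage_matches(s, k))
      return s;                    /* hit: reuse the existing instance */
  DictStorage *s = construct(k);   /* miss: build it once and remember it */
  if (s) {
    s->next = buckets[b];
    buckets[b] = s;
  }
  return s;
}
```

The storages are never freed, which loosely mirrors how MLIR keeps type storage alive for the lifetime of the MLIRContext.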
Type conversion
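The Dict Type in the Hello Dialect has to be converted into a pointer (!llvm.ptr). The conversion itself happens while the Operations are lowered: it is applied to the types of the operands and results of each Op being lowered.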
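One approach is to create a class that inherits from LLVMTypeConverter and calls addConversion in its constructor. This should work in theory, but I haven't tried it: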
class HelloTypeConverter : public mlir::LLVMTypeConverter {
public:
  HelloTypeConverter(mlir::MLIRContext *ctx) : mlir::LLVMTypeConverter(ctx) {
    // Map every hello.dict<K, V> to an opaque LLVM pointer.
    addConversion([](hello::DictType type) -> mlir::Type {
      return mlir::LLVM::LLVMPointerType::get(type.getContext());
    });
  }
};
Another approach is to call addConversion directly on the typeConverter. There are two ways to write this (both passed my tests):
- Name the Type to convert directly as the lambda's parameter type:
typeConverter.addConversion([](DictType type) -> mlir::Type {
  return mlir::LLVM::LLVMPointerType::get(type.getContext());
});
- Check the type with mlir::dyn_cast and decide what to return; returning std::nullopt tells the converter that this callback does not apply:
typeConverter.addConversion([](mlir::Type type) -> std::optional<mlir::Type> {
  if (auto dictType = mlir::dyn_cast<DictType>(type))
    return mlir::LLVM::LLVMPointerType::get(dictType.getContext());
  return std::nullopt;
});
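Why do the two forms behave the same? As I understand MLIR's TypeConverter, a callback whose parameter is DictType is only invoked for types that actually are a DictType, while a callback taking mlir::Type can return std::nullopt to say "not mine" and let the converter try other callbacks; registered conversions are tried most-recently-added first, and a null Type (as opposed to std::nullopt) makes the conversion fail. The C model below is a hypothetical simplification of that dispatch with invented names (TypeConverterModel, add_conversion, convert_type); the real LLVMTypeConverter registers many more conversions than the one stand-in used here.

```c
/* Hypothetical kinds standing in for mlir::Type values. */
typedef enum { KIND_I32, KIND_HELLO_DICT, KIND_LLVM_PTR } TypeKind;

/* Outcome of one callback, mirroring std::optional<mlir::Type>:
 *   NOT_APPLICABLE ~ std::nullopt (try the next callback)
 *   FAILED         ~ a null mlir::Type (stop; conversion fails)
 *   CONVERTED      ~ a valid converted type written to *out */
typedef enum { NOT_APPLICABLE, FAILED, CONVERTED } ConvStatus;
typedef ConvStatus (*ConversionFn)(TypeKind in, TypeKind *out);

#define MAX_CONVERSIONS 8
typedef struct {
  ConversionFn fns[MAX_CONVERSIONS];
  int count;
} TypeConverterModel;

static void add_conversion(TypeConverterModel *tc, ConversionFn fn) {
  if (tc->count < MAX_CONVERSIONS)
    tc->fns[tc->count++] = fn;
}

/* Tries callbacks newest-first; returns 1 and fills *out on success, 0 otherwise. */
static int convert_type(const TypeConverterModel *tc, TypeKind in, TypeKind *out) {
  for (int i = tc->count - 1; i >= 0; --i) {
    ConvStatus st = tc->fns[i](in, out);
    if (st == CONVERTED) return 1;
    if (st == FAILED) return 0;
  }
  return 0; /* no callback claimed this type */
}

/* Registered first: a stand-in for a pre-existing conversion that keeps i32 as-is. */
static ConvStatus keep_i32(TypeKind in, TypeKind *out) {
  if (in != KIND_I32) return NOT_APPLICABLE;
  *out = in;
  return CONVERTED;
}

/* Same decision as the dyn_cast-style lambda: claim only hello.dict, otherwise "nullopt".
 * The DictType-typed lambda is equivalent because it is only ever called for DictType. */
static ConvStatus dict_to_ptr(TypeKind in, TypeKind *out) {
  if (in != KIND_HELLO_DICT) return NOT_APPLICABLE;
  *out = KIND_LLVM_PTR; /* every dict<K, V> becomes the same opaque pointer */
  return CONVERTED;
}
```

Because dict_to_ptr ignores K and V entirely, every dict type maps to the same pointer type, which is exactly why the resultType computed during lowering never varies.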
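Once the typeConverter has been passed in through the pattern's constructor, getTypeConverter() can be used to decide the result type based on the original type: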
auto resultType = getTypeConverter()->convertType(op->getResult(0).getType());
if (!resultType) {
  return mlir::failure();
}
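Because every DictType is converted to the same type, this code is effectively equivalent to
auto resultType = mlir::LLVM::LLVMPointerType::get(context);
In other words, it always produces the same result: resultType does not change with the type parameters of the value passed to the Operation.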
Converting the operations on the Type
A Dict (hash table) needs put, get and delete operations, plus create and free to allocate and release its memory. Each of these needs a corresponding Op (Operation) defined in the Dialect.
In the code below, each Op defines an assemblyFormat so that it, including its types, can be parsed correctly from a textual .mlir file:
class Hello_Op<string mnemonic, list<Trait> traits = []> :
    Op<Hello_Dialect, mnemonic, traits>;

def Dict_CreateOp : Hello_Op<"dict.create", [Pure]> {
  let summary = "Create a new dict<string,i32>";
  let results = (outs DictType:$dict);
  let assemblyFormat = "attr-dict `:` type($dict)";
}

def Dict_FreeOp : Hello_Op<"dict.free", []> {
  let summary = "Free the dict<string,i32> memory";
  let arguments = (ins DictType:$dict);
  let assemblyFormat = "$dict attr-dict `:` type($dict)";
}

def Dict_PutOp : Hello_Op<"dict.put", []> {
  let summary = "Insert string->i32";
  let arguments = (ins DictType:$dict, StrAttr:$key, I32Attr:$value);
  let results = (outs DictType:$out);
  let assemblyFormat = "$dict `,` $key `=` $value attr-dict `:` type($dict) `->` type($out)";
}

def Dict_GetOp : Hello_Op<"dict.get", []> {
  let summary = "Lookup string->i32, returns i32";
  let arguments = (ins DictType:$dict, StrAttr:$key);
  let results = (outs I32:$value);
  let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($value)";
}

def Dict_DeleteOp : Hello_Op<"dict.delete", []> {
  let summary = "Delete key string";
  let arguments = (ins DictType:$dict, StrAttr:$key);
  let results = (outs DictType:$out);
  let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($out)";
}
We also need to write the lowering for each Op. Here the lowering of Dict_CreateOp serves as an example:
class DictCreateOpLowering : public mlir::ConversionPattern {
public:
  explicit DictCreateOpLowering(mlir::TypeConverter &typeConverter,
                                mlir::MLIRContext *context)
      : mlir::ConversionPattern(typeConverter,
                                hello::CreateOp::getOperationName(), 1,
                                context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // hello.dict<K, V> -> !llvm.ptr via the registered type conversion.
    auto resultType = getTypeConverter()->convertType(op->getResult(0).getType());
    if (!resultType) {
      return mlir::failure();
    }

    mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
    auto createMapRef = getOrInsertCreateMap(rewriter, module);

    // Replace hello.dict.create with a call to create_map().
    auto callOp = rewriter.create<mlir::LLVM::CallOp>(
        loc, resultType, createMapRef, mlir::ValueRange{});

    rewriter.replaceOp(op, callOp.getResult());
    return mlir::success();
  }

private:
  /// Returns a symbol reference to create_map, declaring it in the module the
  /// first time it is needed.
  mlir::FlatSymbolRefAttr getOrInsertCreateMap(mlir::PatternRewriter &rewriter,
                                               mlir::ModuleOp module) const {
    auto *context = module.getContext();
    if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("create_map"))
      return mlir::SymbolRefAttr::get(context, "create_map");

    // Create function type: () -> !llvm.ptr
    auto resultType = mlir::LLVM::LLVMPointerType::get(context);
    auto fnType = mlir::LLVM::LLVMFunctionType::get(resultType, std::nullopt, false);

    // Insert function declaration at the start of the module body.
    mlir::PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "create_map", fnType);

    return mlir::SymbolRefAttr::get(context, "create_map");
  }
};
All this pattern does is turn Dict_CreateOp into a call to the C/LLVM IR function create_map. As long as that function is implemented, C's printf() can be used to produce output.
Defining the MLIR textual parsing rules
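This part is very important: we have to define how the custom Type is parsed and printed, i.e. parseType and printType. To do that, set the Dialect's useDefaultTypePrinterParser to 1, which makes the generated dialect class declare the parseType and printType hooks. Since DictType is written by hand in C++ rather than as a TableGen TypeDef, nothing generates their bodies for us, so we define them ourselves: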
def Hello_Dialect : Dialect {
  let name = "hello";
  let summary = "A hello out-of-tree MLIR dialect.";
  let description = [{
    This dialect is minimal example to implement hello-world kind of sample code
    for MLIR.
  }];
  let cppNamespace = "::hello";
  let useDefaultTypePrinterParser = 1;
}
Then define parseType and printType in the corresponding dialect.cpp:
mlir::Type HelloDialect::parseType(mlir::DialectAsmParser &parser) const {
  // Read the type mnemonic, e.g. `dict` in `!hello.dict<index, i32>`.
  llvm::StringRef typeTag;
  if (parser.parseKeyword(&typeTag))
    return mlir::Type();

  if (typeTag == "dict") {
    // Expect: `<` keyType `,` valueType `>`
    if (parser.parseLess())
      return mlir::Type();

    mlir::Type keyType;
    if (parser.parseType(keyType))
      return mlir::Type();

    if (parser.parseComma())
      return mlir::Type();

    mlir::Type valueType;
    if (parser.parseType(valueType))
      return mlir::Type();

    if (parser.parseGreater())
      return mlir::Type();

    return DictType::get(keyType, valueType);
  }

  parser.emitError(parser.getNameLoc(), "unknown hello type: ") << typeTag;
  return mlir::Type();
}

void HelloDialect::printType(mlir::Type type,
                             mlir::DialectAsmPrinter &printer) const {
  // MLIR prints the `!hello.` prefix itself; we only print the body.
  if (auto dictType = mlir::dyn_cast<DictType>(type)) {
    printer << "dict<";
    printer.printType(dictType.getKeyType());
    printer << ", ";
    printer.printType(dictType.getValueType());
    printer << ">";
    return;
  }

  llvm_unreachable("unhandled hello type");
}
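The grammar that parseType accepts and printType emits can be sketched as a tiny standalone C parser. This is only a model with hypothetical names (DictTypeText, parse_dict_type, print_dict_type): nested types are reduced to bare identifiers such as index or i32, whereas the real parser calls parser.parseType recursively, and the !hello. prefix is handled by MLIR itself. It follows the same sequence of steps (keyword, "<", type, ",", type, ">") and shows that printing gives back the text that was parsed.

```c
#include <ctype.h>
#include <stdio.h>
#include <string.h>

/* Hypothetical parsed form of the body "dict<K, V>". */
typedef struct {
  char key[32];
  char value[32];
} DictTypeText;

static void skip_ws(const char **p) {
  while (isspace((unsigned char)**p)) ++*p;
}

/* Like parseLess/parseComma/parseGreater: consume one expected punctuation char. */
static int parse_punct(const char **p, char c) {
  skip_ws(p);
  if (**p != c) return 0;
  ++*p;
  return 1;
}

/* Like parseKeyword, and a heavily reduced parseType: read [A-Za-z0-9_.]+ */
static int parse_ident(const char **p, char *buf, size_t cap) {
  skip_ws(p);
  size_t n = 0;
  while (isalnum((unsigned char)**p) || **p == '_' || **p == '.') {
    if (n + 1 >= cap) return 0; /* identifier too long for the buffer */
    buf[n++] = *(*p)++;
  }
  buf[n] = '\0';
  return n > 0;
}

/* Returns 1 on success, 0 on a syntax error (where MLIR would emitError). */
int parse_dict_type(const char *text, DictTypeText *out) {
  const char *p = text;
  char tag[32];
  if (!parse_ident(&p, tag, sizeof tag) || strcmp(tag, "dict") != 0) return 0;
  if (!parse_punct(&p, '<')) return 0;
  if (!parse_ident(&p, out->key, sizeof out->key)) return 0;
  if (!parse_punct(&p, ',')) return 0;
  if (!parse_ident(&p, out->value, sizeof out->value)) return 0;
  if (!parse_punct(&p, '>')) return 0;
  skip_ws(&p);
  return *p == '\0';
}

/* Like printType: emits "dict<" key ", " value ">". Returns 1 if it fit in buf. */
int print_dict_type(const DictTypeText *t, char *buf, size_t cap) {
  int n = snprintf(buf, cap, "dict<%s, %s>", t->key, t->value);
  return n >= 0 && (size_t)n < cap;
}
```

Keeping the printer's output inside the parser's grammar is what lets MLIR round-trip the type through a .mlir file.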
Run it!
Prepare some MLIR code:
module {
  func.func @test_dict_operations() -> i32 {
    %dict0 = hello.dict.create : !hello.dict<index, i32>

    %dict1 = hello.dict.put %dict0, "first1" = 100 : !hello.dict<index, i32> -> !hello.dict<index, i32>
    %dict2 = hello.dict.put %dict1, "second" = 200 : !hello.dict<index, i32> -> !hello.dict<index, i32>
    %dict3 = hello.dict.put %dict2, "third" = 300 : !hello.dict<index, i32> -> !hello.dict<index, i32>

    %val1 = hello.dict.get %dict3, "first1" : !hello.dict<index, i32> -> i32
    %val2 = hello.dict.get %dict3, "second" : !hello.dict<index, i32> -> i32
    %val3 = hello.dict.get %dict3, "third" : !hello.dict<index, i32> -> i32

    %dict4 = hello.dict.delete %dict3, "second" : !hello.dict<index, i32> -> !hello.dict<index, i32>

    hello.dict.free %dict4 : !hello.dict<index, i32>

    func.return %val2 : i32
  }
}
Lower the MLIR to LLVM IR and take a look at the printed LLVM IR:
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-pc-linux-gnu"

@key_third = internal constant [6 x i8] c"third\00"
@key_second = internal constant [7 x i8] c"second\00"
@key_first1 = internal constant [7 x i8] c"first1\00"

declare void @free_map(ptr) local_unnamed_addr

declare void @delete(ptr, ptr) local_unnamed_addr

declare ptr @get(ptr, ptr) local_unnamed_addr

declare void @put(ptr, ptr, i32) local_unnamed_addr

declare ptr @create_map() local_unnamed_addr

define i32 @test_dict_operations() local_unnamed_addr {
  %1 = tail call ptr @create_map()
  tail call void @put(ptr %1, ptr nonnull @key_first1, i32 100)
  tail call void @put(ptr %1, ptr nonnull @key_second, i32 200)
  tail call void @put(ptr %1, ptr nonnull @key_third, i32 300)
  %2 = tail call ptr @get(ptr %1, ptr nonnull @key_first1)
  %3 = tail call ptr @get(ptr %1, ptr nonnull @key_second)
  %4 = load i32, ptr %3, align 4
  %5 = tail call ptr @get(ptr %1, ptr nonnull @key_third)
  tail call void @delete(ptr %1, ptr nonnull @key_second)
  tail call void @free_map(ptr %1)
  ret i32 %4
}

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}
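Almost everything has become calls to C/LLVM IR functions, which turns this into a problem we already know how to solve: it works just like the printf() in the previous article.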
Implement the corresponding interface functions in C/C++, then compile and run:
../../build/bin/hello-opt dict.mlir -emit=llvm > dict.ll
clang-20 dict.c dict.ll -o dict
#include <stdio.h>

/* Defined in dict.ll, generated from dict.mlir. */
extern int test_dict_operations(void);

int main(void) {
  printf("Starting dictionary test...\n");
  int result = test_dict_operations();
  printf("Result from test_dict_operations: %d\n", result);

  return 0;
}
The expected output is:
Starting dictionary test...
Result from test_dict_operations: 200
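The interface functions themselves are not listed here. Below is a minimal sketch of what dict.c could contain next to main(), written against the signatures declared in the LLVM IR above: ptr @create_map(), void @put(ptr, ptr, i32), ptr @get(ptr, ptr), void @delete(ptr, ptr) and void @free_map(ptr). It is a hypothetical implementation, a plain separate-chaining hash table from C strings to int, and not necessarily what the repository does. Note that it has to be compiled as C, since delete is a keyword in C++.

```c
#include <stdlib.h>
#include <string.h>

#define BUCKETS 64

typedef struct Entry {
  char *key;          /* private copy of the key string */
  int value;
  struct Entry *next;
} Entry;

typedef struct {
  Entry *buckets[BUCKETS];
} Map;

/* djb2 string hash */
static unsigned long hash_str(const char *s) {
  unsigned long h = 5381;
  for (; *s; ++s) h = h * 33 + (unsigned char)*s;
  return h;
}

/* ptr @create_map(): a new, empty map (all buckets zeroed). */
void *create_map(void) { return calloc(1, sizeof(Map)); }

/* void @put(ptr, ptr, i32): insert or overwrite key -> value. */
void put(void *map, const char *key, int value) {
  Map *m = map;
  Entry **slot = &m->buckets[hash_str(key) % BUCKETS];
  for (Entry *cur = *slot; cur; cur = cur->next)
    if (strcmp(cur->key, key) == 0) { cur->value = value; return; }
  Entry *e = malloc(sizeof *e);
  size_t len = strlen(key) + 1;
  char *copy = malloc(len);
  if (!e || !copy) { free(e); free(copy); return; } /* out of memory: drop it */
  memcpy(copy, key, len);
  e->key = copy;
  e->value = value;
  e->next = *slot;
  *slot = e;
}

/* ptr @get(ptr, ptr): pointer to the stored value, or NULL if absent. */
int *get(void *map, const char *key) {
  Map *m = map;
  for (Entry *cur = m->buckets[hash_str(key) % BUCKETS]; cur; cur = cur->next)
    if (strcmp(cur->key, key) == 0) return &cur->value;
  return NULL;
}

/* void @delete(ptr, ptr): unlink and free the entry for key, if any. */
void delete(void *map, const char *key) {
  Map *m = map;
  Entry **link = &m->buckets[hash_str(key) % BUCKETS];
  while (*link) {
    Entry *e = *link;
    if (strcmp(e->key, key) == 0) {
      *link = e->next;
      free(e->key);
      free(e);
      return;
    }
    link = &e->next;
  }
}

/* void @free_map(ptr): release every entry and the map itself. */
void free_map(void *map) {
  Map *m = map;
  if (!m) return;
  for (int i = 0; i < BUCKETS; ++i) {
    Entry *e = m->buckets[i];
    while (e) {
      Entry *next = e->next;
      free(e->key);
      free(e);
      e = next;
    }
  }
  free(m);
}
```

Because get returns a pointer and the generated code loads from it without a NULL check (the load i32 in the IR), looking up a key that was never put would crash; in the test program above every looked-up key exists, and the 200 comes from the "second" lookup.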
Food for thought
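This shows the limitation of this approach: it only works for hash tables of type <string,i32> and cannot be used with any other types. The type parameters in !hello.dict<index, i32> are carried along but never used by the lowering, since every dict becomes the same !llvm.ptr and the ops always take string keys and i32 values. This could be solved by generating LLVM IR dynamically, but how to do that is a whole other topic 😂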
I still recommend trying it hands-on yourself; that is the best way to deepen your understanding.