[Learning Deep Learning Compilers from Scratch] 13: How to Write a Pass in MLIR?
[GiantPandaCV Editor's Note] This article took a while to digest; I re-summarized it in my own order of steps. It mainly covers chapters 3 and 4 of the MLIR Toy Tutorials: how to define custom Passes in MLIR, using four examples — eliminating consecutive Transpose operations, eliminating redundant Reshape operations, an inlining Pass, and a shape inference Pass — to introduce the techniques for writing Passes in MLIR. None of it is hard to understand, but mastering these Pass-writing techniques is necessary for getting started with MLIR. While learning deep learning compilers from scratch I maintain a project at https://github.com/BBuf/tvm_mlir_learn that records study notes and experimental code. It has earned 150+ stars so far; anyone interested in deep learning compilers is welcome to take a look, and a star would be much appreciated.
Preface
In [Learning Deep Learning Compilers from Scratch] 11: A First Look at MLIR and [Learning Deep Learning Compilers from Scratch] 12: MLIR Toy Tutorials Notes Part 1, we got a first sense of what MLIR is, walked through how the Toy language is lowered from source files to MLIR, and saw how MLIR's core components — MLIRGen, Dialect, Operation, and TableGen — interact in that process.
This note, based on the Toy Tutorials, summarizes how expression transformations (rewrites) are implemented in MLIR.
Chapter 3: Expression Rewrites in MLIR (How to Write a Pass)
In Chapter 2 we generated valid, but unoptimized, MLIR expressions. These can usually be processed and simplified further, much like TVM's Passes optimize Relay IR. So how do we transform an initial MLIR expression? In MLIR, transformations are done by pattern matching and rewriting. The tutorial introduces two ways to define rewrite rules: writing the match-and-rewrite logic directly in C++, and declaring the rules with the DRR framework (https://mlir.llvm.org/docs/DeclarativeRewrites/), in which case the ODS framework generates the code automatically.
Optimizing Transpose with C++ Pattern Matching and Rewriting
The goal here is to eliminate pairs of transposes that cancel each other out: transpose(transpose(X)) -> X, since applying Transpose twice in a row to the same input is clearly redundant. The corresponding source (in mlir/test/Examples/Toy/Ch3/transpose_transpose.toy):
def transpose_transpose(x) {
  return transpose(transpose(x));
}
Without any optimization Pass, let's see the MLIR this Toy program produces, using: ./toyc-ch3 ../../mlir/test/Examples/Toy/Ch3/transpose_transpose.toy -emit=mlir.
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%0 : tensor<*xf64>) to tensor<*xf64>
  toy.return %1 : tensor<*xf64>
}
The generated MLIR really does transpose x twice and returns the doubly transposed tensor. But both transposes are unnecessary: the output is simply the input x. To optimize this case, we first write the match-and-rewrite logic in C++ (in mlir/examples/toy/Ch3/mlir/ToyCombine.cpp):
/// This is an example of a c++ rewrite pattern for the TransposeOp. It
/// optimizes the following scenario: transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// We register this pattern to match every toy.transpose in the IR.
  /// The "benefit" is used by the framework to order the patterns and process
  /// them in order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// This method attempts to match a pattern and rewrite it. The rewriter
  /// argument is the orchestrator of the sequence of rewrites. The pattern is
  /// expected to interact with it to perform any changes to the IR from here.
  mlir::LogicalResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Look through the input of the current transpose.
    mlir::Value transposeInput = op.getOperand();
    TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();

    // Input defined by another transpose? If not, no match.
    if (!transposeInputOp)
      return failure();

    // Otherwise, we have a redundant transpose. Use the rewriter.
    rewriter.replaceOp(op, {transposeInputOp.getOperand()});
    return success();
  }
};
In matchAndRewrite, we first get the operand of the current operation, then check whether that operand is itself produced by a transpose. If it is, the expression is rewritten to the inner transpose's operand; otherwise there is nothing to optimize and the IR is left as is.
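The effect of this pattern under MLIR's greedy rewrite driver can be sketched outside MLIR with a self-contained simulation. The `Expr` type, the `var`/`transpose` helpers, and the `canonicalize` driver below are illustrative stand-ins, not tutorial code:

```cpp
#include <cassert>
#include <memory>
#include <string>

// A toy expression: either a named variable or transpose(child).
struct Expr {
  bool isTranspose;
  std::string name;            // set when !isTranspose
  std::shared_ptr<Expr> child; // set when isTranspose
};

std::shared_ptr<Expr> var(std::string n) {
  return std::make_shared<Expr>(Expr{false, std::move(n), nullptr});
}
std::shared_ptr<Expr> transpose(std::shared_ptr<Expr> c) {
  return std::make_shared<Expr>(Expr{true, "", std::move(c)});
}

// One application of the pattern: transpose(transpose(x)) -> x.
// Returns nullptr when the pattern does not match, mirroring `failure()`.
std::shared_ptr<Expr> matchAndRewrite(const std::shared_ptr<Expr> &e) {
  if (!e->isTranspose || !e->child->isTranspose)
    return nullptr;
  return e->child->child;
}

// The driver: reapply the pattern until no match remains (a fixed point).
std::shared_ptr<Expr> canonicalize(std::shared_ptr<Expr> e) {
  while (auto rewritten = matchAndRewrite(e))
    e = rewritten;
  return e;
}
```

Repeated application is what removes arbitrarily long transpose chains, which is why registering this single pattern with the canonicalization framework is enough.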
Next, the pattern we just wrote must be registered with the Canonicalization Framework so that the framework can invoke it (see https://mlir.llvm.org/docs/Canonicalization/ for more on canonicalization). The registration code (still in mlir/examples/toy/Ch3/mlir/ToyCombine.cpp):
/// Register our patterns as "canonicalization" patterns on the TransposeOp so
/// that they can be picked up by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<SimplifyRedundantTranspose>(context);
}
After adding the rewrite rule to the canonicalization framework, we also have to update the Operation's td definition to enable canonicalization and to add the NoSideEffect trait to the Operation. The Transpose definition now looks like this:
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
  let summary = "transpose operation";

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor);

  let assemblyFormat = [{
    `(` $input `:` type($input) `)` attr-dict `to` type(results)
  }];

  // Enable registering canonicalization patterns with this operation.
  let hasCanonicalizer = 1;

  // Allow building a TransposeOp from the input operand.
  let builders = [
    OpBuilder<(ins "Value":$input)>
  ];

  // Invoke a static verify method to verify this transpose operation.
  let verifier = [{ return ::verify(*this); }];
}
Finally, we need to add this canonicalization-based optimization to the main program's run flow. This lives in the dumpMLIR function in mlir/examples/toy/Ch3/toyc.cpp: when optimization is enabled, a mlir::PassManager runs mlir::createCanonicalizerPass() over the module.
With that, the C++-based MLIR match and rewrite is complete. We can check whether the rewrite removed the transposes with: ./toyc-ch3 ../../mlir/test/Examples/Toy/Ch3/transpose_transpose.toy -emit=mlir -opt. The result:
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> {
  toy.return %arg0 : tensor<*xf64>
}
The optimized MLIR no longer contains any transpose operation, so the optimization worked.
Optimizing Reshape with DRR
MLIR offers a second way to write rewrites: rules declared with DRR, from which the match-and-rewrite functions are generated automatically, with code generation again built on the ODS framework. DRR (Declarative, Rule-based Pattern-match and Rewrite) is a declarative, DAG-based rewriter providing a table-driven syntax for pattern-match and rewrite rules.
As an example, we eliminate redundant tensor reshape operations. The corresponding Toy source (in mlir/test/Examples/Toy/Ch3/trivial_reshape.toy):
def main() {
  var a<2,1> = [1, 2];
  var b<2,1> = a;
  var c<2,1> = b;
  print(c);
}
First generate the corresponding MLIR with: ./toyc-ch3 ../../mlir/test/Examples/Toy/Ch3/trivial_reshape.toy -emit=mlir
module {
  func @main() {
    %0 = toy.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>
    %1 = toy.reshape(%0 : tensor<2xf64>) to tensor<2x1xf64>
    %2 = toy.reshape(%1 : tensor<2x1xf64>) to tensor<2x1xf64>
    %3 = toy.reshape(%2 : tensor<2x1xf64>) to tensor<2x1xf64>
    toy.print %3 : tensor<2x1xf64>
    toy.return
  }
}
Clearly a, b, and c all have the same shape and values, so these reshape operations are redundant. Below we define the match-and-rewrite rules with the DRR framework, handling a few cases (all implemented in mlir/examples/toy/Ch3/mlir/ToyCombine.td).
Eliminating the redundancy in Reshape(Reshape(x)) = Reshape(x):
// Reshape(Reshape(x)) = Reshape(x)
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;
That is, ReshapeOp(ReshapeOp $arg) is replaced by ReshapeOp $arg: a chain of identical reshapes only needs to be performed once.
When a reshape's argument and result have the same type, the reshape does nothing, so we can return the input directly, i.e. Reshape(x) = x.
// Reshape(x) = x, where input and output shapes are identical
def TypesAreIdentical : Constraint<CPred<"$0.getType() == $1.getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;
That is, when $0.getType() equals $1.getType() the reshape is redundant and the operand $arg is used in its place.
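The combined effect of these two patterns can be sketched with a self-contained simulation. The `Val` type and the `simplify` driver are illustrative stand-ins; the generated ToyCombine.inc manipulates real ops instead:

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using Shape = std::vector<int>;

// A value is either a source tensor or reshape(input) to a target shape.
struct Val {
  Shape shape;
  std::shared_ptr<Val> reshapeInput; // null for a source tensor
};

std::shared_ptr<Val> source(Shape s) {
  return std::make_shared<Val>(Val{std::move(s), nullptr});
}
std::shared_ptr<Val> reshape(std::shared_ptr<Val> in, Shape to) {
  return std::make_shared<Val>(Val{std::move(to), std::move(in)});
}

// Apply both DRR patterns to a fixed point.
std::shared_ptr<Val> simplify(std::shared_ptr<Val> v) {
  // ReshapeReshapeOptPattern: Reshape(Reshape(x)) = Reshape(x) —
  // only the outermost target shape matters.
  while (v->reshapeInput && v->reshapeInput->reshapeInput)
    v = reshape(v->reshapeInput->reshapeInput, v->shape);
  // RedundantReshapeOptPattern: Reshape(x) = x when the types are identical
  // (the TypesAreIdentical constraint).
  if (v->reshapeInput && v->reshapeInput->shape == v->shape)
    v = v->reshapeInput;
  return v;
}
```

On the trivial_reshape.toy example, the three chained reshapes of a collapse to a single reshape, matching the optimized output shown below.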
Now the ODS framework plus the ToyCombine.td file can generate the code file ToyCombine.inc, using the following commands:
$ cd llvm-project/build
$ ./bin/mlir-tblgen --gen-rewriters ${mlir_src_root}/examples/toy/Ch3/mlir/ToyCombine.td -I ${mlir_src_root}/include/
This generation step can also be wired into the build, in mlir/examples/toy/Ch3/CMakeLists.txt:
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters)
add_public_tablegen_target(ToyCh3CombineIncGen)
Finally, running ./toyc-ch3 ../../mlir/test/Examples/Toy/Ch3/trivial_reshape.toy -emit=mlir -opt produces the MLIR optimized by these Passes:
module {
  func @main() {
    %0 = toy.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf64>
    toy.print %0 : tensor<2x1xf64>
    toy.return
  }
}
Chapter 4: Generic Transformations
In Chapter 3 we learned how to implement expression rewrites in MLIR, but there is an obvious problem: the Passes we wrote for the Toy language cannot be reused from other Dialects, since they are specialized to Toy's Operations, and re-implementing every transformation for every Dialect would mean massive code duplication. This chapter therefore uses two examples to show how to implement generic (dialect-reusable) transformations in MLIR.
We use the following example (in mlir/test/Examples/Toy/Ch5/codegen.toy):
def multiply_transpose(a, b) {
  return transpose(a) * transpose(b);
}

def main() {
  var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
  var b<2, 3> = [1, 2, 3, 4, 5, 6];
  var c = multiply_transpose(a, b);
  var d = multiply_transpose(b, a);
  print(d);
}
First look at its MLIR, via ./toyc-ch4 ../../mlir/test/Examples/Toy/Ch4/codegen.toy -emit=mlir:
module {
  func private @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
    %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
    %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
    %2 = toy.mul %0, %1 : tensor<*xf64>
    toy.return %2 : tensor<*xf64>
  }
  func @main() {
    %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
    %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
    %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64>
    %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    toy.print %5 : tensor<*xf64>
    toy.return
  }
}
This is the MLIR before optimization. Notice that tensor shapes are unknown before the tensors are instantiated, shown as tensor<*xf64>. This hurts later Passes, such as the shape-dependent ones from Chapter 3, and leaves optimizations on the table. We would therefore like every tensor's shape to be known before the reshape-related Passes run, so this chapter introduces a shape inference Pass. It also introduces an inlining Pass that removes function-call overhead.
The Inlining Pass
In the code above, the small function multiply_transpose is called repeatedly, so the overhead of the calls themselves is no longer negligible. We define an inlining Pass that inlines multiply_transpose to improve performance.
Step 1
MLIR provides a generic interface for handling inlining, DialectInlinerInterface, containing a set of virtual hooks a Dialect can override. We build on this class to define the inlining interface and rewrite rules for Toy Operations. The code is in mlir/examples/toy/Ch5/mlir/Dialect.cpp:
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and override
/// the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given callable operation is legal to inline
  /// into the given call. For Toy this hook can simply return true, as the Toy
  /// Call operation is always inlinable.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *, bool,
                       BlockAndValueMapping &) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation(toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        ArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
This code defines the inlining interface and rewrite rules for Toy Operations. The two isLegalToInline overloads are hooks: the first checks whether it is legal to inline the given callable operation into the given call; the second checks whether the given operation may legally be inlined into the given region. handleTerminator only has to deal with toy.return: for each position it.index(), every use of the call's corresponding result value, valuesToRepl[it.index()], is replaced by the matching operand of the return, it.value(). After inlining, users of the call therefore see the values the function body would have returned.
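What handleTerminator does can be seen in isolation: every use of the call's i-th result is redirected to the return's i-th operand. A standalone sketch, with a hypothetical use-list standing in for MLIR's replaceAllUsesWith:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Each "value" is just an id; a Use records the id of the value it reads.
struct Use { int valueId; };

// Mirror of handleTerminator: valuesToRepl are the call's results,
// returnOperands are the operands of the inlined toy.return.
void handleTerminator(const std::vector<int> &returnOperands,
                      const std::vector<int> &valuesToRepl,
                      std::vector<Use> &uses) {
  assert(returnOperands.size() == valuesToRepl.size());
  for (std::size_t i = 0; i < valuesToRepl.size(); ++i)
    for (Use &u : uses)
      if (u.valueId == valuesToRepl[i]) // replaceAllUsesWith
        u.valueId = returnOperands[i];
}
```

After the call, no user refers to the call's results any more, so the call (and the inlined return) can be erased.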
Step 2
Next, these rules must be registered in the Toy Dialect's initialization, in mlir/examples/toy/Ch5/mlir/Dialect.cpp:
/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
void ToyDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
  addInterfaces<ToyInlinerInterface>();
}
The addInterfaces<ToyInlinerInterface>() call registers the inliner interface, where ToyInlinerInterface is the set of hooks we defined above.
Step 3
Next, we must let the inliner know that toy.generic_call in the IR represents a call to a function. MLIR provides an Operation interface, CallOpInterface, that marks an Operation as a call. Using it requires adding the line include "mlir/Interfaces/CallInterfaces.td" to the Toy Dialect's definition file (mlir/examples/toy/Ch5/include/toy/Ops.td).
The Operation definition is then updated as follows:
def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  let summary = "generic call operation";
  let description = [{
    Generic calls represent calls to a user defined function that needs to
    be specialized for the shape of its arguments. The callee name is attached
    as a symbol reference via an attribute. The arguments list must match the
    arguments expected by the callee. For example:

    ```mlir
     %4 = toy.generic_call @my_func(%1, %3)
           : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    ```

    This is only valid if a function named "my_func" exists and takes two
    arguments.
  }];

  // The generic call operation takes a symbol reference attribute as the
  // callee, and inputs for the call.
  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

  // The generic call operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Specialize assembly printing and parsing using a declarative format.
  let assemblyFormat = [{
    $callee `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
  }];

  // Add custom build methods for the generic call operation.
  let builders = [
    OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
  ];
}
解釋?zhuān)何覀兪褂昧?code style="font-size: 14px;word-wrap: break-word;padding: 2px 4px;border-radius: 4px;margin: 0 2px;background-color: rgba(27,31,35,.05);font-family: Operator Mono, Consolas, Monaco, Menlo, monospace;word-break: break-all;color: #595959;">DeclareOpInterfaceMethods在CallOpInterface的聲明中聲明所用的接口方法。DeclareOpInterfaceMethods這個(gè)特征說(shuō)明程序會(huì)識(shí)別generic_call操作(在原始的MLIR表達(dá)式中對(duì)應(yīng)toy.generic_call),并在該位置調(diào)用接口函數(shù)。
然后在mlir/examples/toy/Ch5/mlir/Dialect.cpp中實(shí)現(xiàn)了GenericCallOp的功能,代碼如下:
/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return getAttrOfType<SymbolRefAttr>("callee");
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
GenericCallOp::getCallableForCallee() returns the callee of the generic call Operation, while GenericCallOp::getArgOperands() returns the argument operands passed to the callee.
Step 4
Next we add a cast operation to the Dialect and hook it up. Why is a cast needed? At a call site the input tensor types are concrete, but in the function definition they are generic (unranked), as the unoptimized MLIR above shows. Inlining therefore needs an implicit type conversion between the two, otherwise it cannot proceed; the cast operation converts a concrete tensor type to the type the function expects. We add it in mlir/examples/toy/Ch5/include/toy/Ops.td:
def CastOp : Toy_Op<"cast", [
     DeclareOpInterfaceMethods<CastOpInterface>,
     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
     NoSideEffect,
     SameOperandsAndResultShape
  ]> {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types must
    both be tensor types with the same element type. If both are ranked, then
    shape is required to match. The operation is invalid if converting to a
    mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);

  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
Here DeclareOpInterfaceMethods<CastOpInterface> declares CastOpInterface's methods on the cast operation, so that the framework recognizes cast operations and calls those methods there.
We then have to implement the cast op's areCastCompatible method (in mlir/examples/toy/Ch5/mlir/Dialect.cpp):
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}
This method decides whether a cast between the given types is valid: it returns true if the inputs and outputs are compatible tensor types (one input, one output, same element type, and matching shapes when both are ranked), and false otherwise.
We also need to override one more hook on ToyInlinerInterface, the materializeCallConversion function:
struct ToyInlinerInterface : public DialectInlinerInterface {
  ....
  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
The inliner calls this hook whenever the types at a call site do not match the callee's; here it materializes the conversion by creating a toy.cast.
Step 5
Add the inlining Pass to the optimization pipeline, in mlir/examples/toy/Ch5/toyc.cpp:
if (enableOpt) {
    mlir::PassManager pm(&context);
    // Apply any generic pass manager command line options and run the pipeline.
    applyPassManagerCLOptions(pm);

    // Inline all functions into main and then delete them.
    pm.addPass(mlir::createInlinerPass());
...
}
With the line pm.addPass(mlir::createInlinerPass());, the optimization pipeline now contains the inlining Pass.
Let's see what the original MLIR looks like after the inlining Pass:
func @main() {
  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %1 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
  %3 = "toy.cast"(%0) : (tensor<2x3xf64>) -> tensor<*xf64>
  %4 = "toy.transpose"(%2) : (tensor<*xf64>) -> tensor<*xf64>
  %5 = "toy.transpose"(%3) : (tensor<*xf64>) -> tensor<*xf64>
  %6 = "toy.mul"(%4, %5) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
  toy.print %6 : tensor<*xf64>
  toy.return
}
Now the MLIR has only a main function — multiply_transpose has been inlined — and we can see what toy.cast does.
The Shape Inference Pass
The inlining Pass above casts concretely typed tensors to generic tensors so that inlining can proceed. Next, we infer the shapes of those generic tensors from the tensors whose shapes are known. We use the ODS framework to generate a custom Operation interface for this, and the whole shape inference process, like inlining, is packaged as a Pass over the MLIR.
Step 1: Define the Shape Inference Operation Interface with ODS
The code is in mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td:
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
ShapeInferenceOpInterface derives from OpInterface, which takes as a template argument the name to give the generated C++ interface class, "ShapeInference". The description field briefly documents the interface, and the methods field defines the interface methods an Operation will have to provide.
Step 2: Add the Trait to the Relevant Toy Operation Definitions
Taking Toy's Mul Operation as an example, in mlir/examples/toy/Ch5/include/toy/Ops.td:
def MulOp : Toy_Op<"mul",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Specify a parser and printer method.
  let parser = [{ return ::parseBinaryOp(parser, result); }];
  let printer = [{ return ::printBinaryOp(p, *this); }];

  // Allow building a MulOp from the two input operands.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs)>
  ];
}
In the code above, DeclareOpInterfaceMethods<ShapeInferenceOpInterface> adds the shape inference trait to the Mul Operation, just as the inlining section declared interface methods on the generic_call and cast Operations.
Step 3: Define the Shape Inference Function for Each Operation
Every Operation that needs shape inference defines a corresponding inferShapes() function. For Mul, the result shape is simply the input shape, since the operation is elementwise. The code is in mlir/examples/toy/Ch5/mlir/Dialect.cpp:
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getOperand(0).getType()); }
Step 4: Implement the Shape Inference Pass
This step implements the shape inference Pass itself; the previous steps were its prerequisites. We define a Pass class implementing the shape inference algorithm, and create the Pass from it. The code is in mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp:
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, FunctionPass> {
public:
  void runOnFunction() override {
    auto f = getFunction();

    // Populate the worklist with the operations that need shape inference:
    // these are operations that return a dynamic shape.
    llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
    f.walk([&](mlir::Operation *op) {
      if (returnsDynamicShape(op))
        opWorklist.insert(op);
    });

    // Iterate on the operations in the worklist until all operations have been
    // inferred or no change happened (fix point).
    while (!opWorklist.empty()) {
      // Find the next operation ready for inference, that is an operation
      // with all operands already resolved (non-generic).
      auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
      if (nextop == opWorklist.end())
        break;

      Operation *op = *nextop;
      opWorklist.erase(op);

      // Ask the operation to infer its output shapes.
      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
        shapeOp.inferShapes();
      } else {
        op->emitError("unable to infer shape of operation without shape "
                      "inference interface");
        return signalPassFailure();
      }
    }

    // If the operation worklist isn't empty, this indicates a failure.
    if (!opWorklist.empty()) {
      f.emitError("Shape inference failed, ")
          << opWorklist.size() << " operations couldn't be inferred\n";
      signalPassFailure();
    }
  }

  /// A utility method that returns if the given operation has all of its
  /// operands inferred.
  static bool allOperandsInferred(Operation *op) {
    return llvm::all_of(op->getOperandTypes(), [](Type operandType) {
      return operandType.isa<RankedTensorType>();
    });
  }

  /// A utility method that returns if the given operation has a dynamically
  /// shaped result.
  static bool returnsDynamicShape(Operation *op) {
    return llvm::any_of(op->getResultTypes(), [](Type resultType) {
      return !resultType.isa<RankedTensorType>();
    });
  }
};
} // end anonymous namespace

/// Create a Shape Inference pass.
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
ShapeInferencePass wraps a FunctionPass and overrides its runOnFunction() method to implement the shape inference algorithm. It first builds a worklist of Operations whose results are generic (dynamically shaped) tensors, then repeatedly looks for an Operation in the worklist all of whose operands already have concrete tensor types. If none is found the loop exits; otherwise that Operation is removed from the worklist and its inferShapes() is called to infer its result shape. The algorithm succeeds when the worklist becomes empty.
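This worklist fixed-point algorithm can be exercised on its own with a toy op graph. The `Op` struct and the index-based operand links below are hypothetical stand-ins; in the real pass, allOperandsInferred and returnsDynamicShape consult actual MLIR types, and inferShapes() sets them:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct Op {
  std::vector<int> operands; // indices of the ops whose results feed this op
  bool shaped;               // does this op's result have a known shape yet?
};

// Fixed-point shape inference over a list of ops. Returns true when every
// op ends up with a known result shape; otherwise the pass would fail.
bool inferAll(std::vector<Op> &ops) {
  std::vector<int> worklist;
  for (int i = 0; i < (int)ops.size(); ++i)
    if (!ops[i].shaped)
      worklist.push_back(i); // returnsDynamicShape
  while (!worklist.empty()) {
    // allOperandsInferred: pick an op whose operands are all resolved.
    auto it = std::find_if(worklist.begin(), worklist.end(), [&](int i) {
      return std::all_of(ops[i].operands.begin(), ops[i].operands.end(),
                         [&](int o) { return ops[o].shaped; });
    });
    if (it == worklist.end())
      break; // no progress possible; the fixed point leaves work undone
    ops[*it].shaped = true; // stand-in for shapeOp.inferShapes()
    worklist.erase(it);
  }
  return worklist.empty();
}
```

On the running example, the constant is already shaped, which lets the transpose be inferred, which in turn lets the mul be inferred.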
Step 5: Add the Shape Inference Pass to the Optimization Pipeline
As with inlining, the shape inference Pass must be added to the optimization pipeline. This was already shown alongside the inlining Pass above, so the code is not repeated here.
With that, the inlining and shape inference Passes are complete. Let's see what the MLIR looks like after these two Passes. Running ./toyc-ch4 ../../mlir/test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt gives the optimized MLIR:
module {
  func @main() {
    %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
    %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
    %2 = toy.mul %1, %1 : tensor<3x2xf64>
    toy.print %2 : tensor<3x2xf64>
    toy.return
  }
}
References
https://zhuanlan.zhihu.com/p/106472878
https://www.zhihu.com/people/CHUNerr/posts
https://mlir.llvm.org/docs/Tutorials/Toy/Ch-4/
