【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】七,萬(wàn)字長(zhǎng)文入門TVM Pass
0x0. 前言
這篇文章基于TVM 0.8.0.dev版本。在【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】五,TVM Relay以及Pass簡(jiǎn)介 這篇推文中已經(jīng)簡(jiǎn)單介紹了Relay和Pass機(jī)制。但對(duì)Pass的基礎(chǔ)設(shè)施(Pass Infrastructure)和Relay樹(shù)結(jié)構(gòu)都沒(méi)有詳細(xì)介紹,所以這篇文章主要介紹一下Pass Infrastructure和Relay樹(shù)結(jié)構(gòu),再基于這些關(guān)鍵的基礎(chǔ)知識(shí)詳細(xì)了解一下Constant Folding Pass,相信讀者讀完這篇文章會(huì)對(duì)TVM的Pass有更深的理解,并且在閱讀其它Pass和實(shí)現(xiàn)自定義Pass時(shí)可以很Relax。
0x1. Pass Infrastructure
首先來(lái)看Pass Infrastructure,基于官方文檔進(jìn)行介紹。
在講解Pass通用的注冊(cè)和運(yùn)行流程前,先來(lái)介紹一下TVM的Pass Infrastructure。參考官方文檔:https://tvm.apache.org/docs/dev/pass_infra.html 。
Relay 和 TVM IR 都包含一系列優(yōu)化passes,可提高模型的性能指標(biāo),例如平均推理速度、內(nèi)存占用或特定設(shè)備的功耗。TVM有一套標(biāo)準(zhǔn)優(yōu)化方法以及特定于機(jī)器學(xué)習(xí)的優(yōu)化方法,包括常量折疊、死代碼消除、運(yùn)算符布局更改、算符融合、緩沖區(qū)處理和循環(huán)變換等。每一個(gè)Pass都使用在traversal期間和/或之前收集的分析結(jié)果來(lái)構(gòu)造ir-to-ir的pass。
然而,隨著TVM的迅速發(fā)展,需要一種更系統(tǒng)、更有效的方法來(lái)管理這些passes。此外,一個(gè)可以管理跨TVM堆棧不同層(如Relay和tir)的passes的通用框架,為開(kāi)發(fā)人員快速原型化并將實(shí)現(xiàn)的passes插入系統(tǒng)鋪平了道路。
例如,許多現(xiàn)有的生產(chǎn)編譯器,如 GCC 和 LLVM,都采用pass manager來(lái)有效管理passes的執(zhí)行。最初管理 pass 很簡(jiǎn)單,因?yàn)?pass 的數(shù)量很少,但成熟的編譯器將包含數(shù)百個(gè)單獨(dú)的 pass。Often external users will want to have custom passes correctly scheduled without having to modify a single handcrafted pass order.
同樣,現(xiàn)代深度學(xué)習(xí)框架,如 Pytorch 和 MXNet Gluon,也有分別通過(guò) Sequential 和 Block 啟用pass-style層構(gòu)建方案的趨勢(shì)。有了這樣的結(jié)構(gòu),這些現(xiàn)代框架能夠方便地將模塊/層添加到它們的容器中,并輕松地構(gòu)建神經(jīng)網(wǎng)絡(luò)。
Relay pass infra 的設(shè)計(jì)很大程度上受到 LLVM 中使用的分層pass manager和流行的深度學(xué)習(xí)框架中使用的block-style容器的啟發(fā)。pass infra 的主要目標(biāo)包括:
實(shí)現(xiàn)更好的optimizer編程編排。這允許用戶靈活地定制和構(gòu)建自己的優(yōu)化管道。 提供一種用戶友好的方式來(lái)調(diào)試passes。 減輕開(kāi)發(fā)人員手動(dòng)和分別解決passes之間的依賴關(guān)系。 為開(kāi)發(fā)人員簡(jiǎn)化實(shí)現(xiàn)新passes的難度。例如,我們?cè)试S用戶在 Python 中實(shí)現(xiàn)一個(gè) pass 并讓 pass infra 操縱它的執(zhí)行。
The Design
我們專注于為用戶提供易于擴(kuò)展的功能,讓用戶可以快速添加新passes而不會(huì)失去向后兼容性。該設(shè)計(jì)包含后端和前端。前者實(shí)現(xiàn)了 pass infra 的主要邏輯。后者為用戶提供簡(jiǎn)單的 API 進(jìn)行交互,即允許用戶快速創(chuàng)建自己的優(yōu)化管道。
C++ Backend
我們提供了一個(gè) PassInfo 對(duì)象來(lái)包含一個(gè)pass所需的基本信息。name 是 pass 名稱,opt_level 指示將啟用 pass 的優(yōu)化級(jí)別, required 表示執(zhí)行某個(gè) pass 所需的 pass(更多詳細(xì)信息請(qǐng)參見(jiàn)include/tvm/ir/transform.h)。例如,在注冊(cè)pass的時(shí)候(將在后面介紹),pass開(kāi)發(fā)人員可以指定pass的名稱、將執(zhí)行的優(yōu)化級(jí)別和/或所需的pass。opt_level 可用于幫助 pass infra 識(shí)別在用戶提供的優(yōu)化級(jí)別下運(yùn)行時(shí)是否需要執(zhí)行某個(gè) pass。required字段可以由pass infra用來(lái)解決pass依賴關(guān)系。
class PassInfoNode : public Object {
String name;
int opt_level;
Array<String> required;
};
PassContext
PassContext 帶有用于優(yōu)化pass的有用信息。例如,它包含錯(cuò)誤報(bào)告系統(tǒng),因此pass的作者可以提供有關(guān)優(yōu)化失敗原因的注釋。PassContext 還旨在替換舊的BuildConfig,它用于幫助用戶配置編譯選項(xiàng),包括優(yōu)化級(jí)別和必需/禁用的pass等。例如,我們可能有一個(gè)配置,它在 opt_level=3 時(shí)執(zhí)行所有pass,除開(kāi)使用 PassContext 提供的 disabled_pass=xx禁用的一些passes 。現(xiàn)在我們可以在 opt_level=3 處對(duì)所有passes進(jìn)行全局處理,并排除禁用pass列表中的那些pass。
這個(gè)類是為方便用戶編寫(xiě)Python而設(shè)計(jì)的,它的語(yǔ)法可以在特定的配置下執(zhí)行優(yōu)化。此外,用戶可以通過(guò) PassContext::Current()以線程安全的方式獲取某個(gè)程序范圍內(nèi)可用的context,因?yàn)門hreadLocalStore用于保存創(chuàng)建的pass context對(duì)象,關(guān)于ThreadLocalStore建議看這篇文章:https://zhuanlan.zhihu.com/p/61587053,TVM模仿Java中的ThreadLocalStore在C++層自己實(shí)現(xiàn)了用來(lái)管理線程。稍后將提供示例以展示我們?nèi)绾问褂?C++ 和 Python API 來(lái)創(chuàng)建使用pass context的編譯管道。
class PassContextNode : public Object {
public:
ErrorReporter err_reporter;
int opt_level{2};
tvm::Array<tvm::Expr> required_pass;
tvm::Array<tvm::Expr> disabled_pass;
};
class PassContext : public NodeRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_node<PassContextNode>());
}
};
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
Pass Constructs
pass infra 是以分層方式設(shè)計(jì)的,它可以在不同粒度的Relay/tir 程序下工作。引入了一個(gè)純虛擬類 PassNode 作為不同優(yōu)化pass的基礎(chǔ)。此類包含幾個(gè)必須由子類在modules, functions, or sequences of passes實(shí)現(xiàn)的虛擬方法。
class PassNode : Object {
virtual PassInfo Info() const = 0;
virtual Module operator()(const IRModule& mod
const PassContext& pass_ctx) const = 0;
};
成員函數(shù)展示了一個(gè)pass應(yīng)該如何實(shí)現(xiàn),例如它始終在特定context下工作在 IRModule中,所有的pass都被設(shè)計(jì)在一個(gè)Module to Module的管理器中。因此,由 pass infra 控制的優(yōu)化將始終更新整個(gè)module。
已經(jīng)創(chuàng)建了幾個(gè)子類來(lái)實(shí)現(xiàn)不同類型的優(yōu)化pass,例如,function-level passes, module-level passes, and sequential passes。每個(gè)子類本身都可以充當(dāng)pass管理器。例如,他們可以收集所需的passes并執(zhí)行它們或基于給定的元數(shù)據(jù)構(gòu)建依賴關(guān)系圖。它們的完整定義可以在src/relay/ir/transform.cc 和 src/ir/transform.cc 中找到。
Module-Level Passes
Module Level Passes主要用于全局和過(guò)程間優(yōu)化 (IPO),類似于 LLVM 中使用的module pass。Relay 中一些典型的 pass 需要一個(gè)模塊的global picture,比如 A-normal form conversion 和 lambda lifting等,都屬于這個(gè)集合。在此級(jí)別,用戶甚至可以在一個(gè)module中添加和/或刪除function。
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info 維護(hù)module-level pass所需的信息。pass_func 實(shí)現(xiàn)了真正的optimization。例如,我們可能需要對(duì)module執(zhí)行死代碼消除。我們可以在 pass_func 中實(shí)現(xiàn)算法并讓它在module上運(yùn)行。然后它將刪除死代碼,包括module中未使用的函數(shù)。請(qǐng)注意,該字段被設(shè)計(jì)為一個(gè)packed function,所以這個(gè)優(yōu)化不僅可以使用C++還可以使用Python來(lái)實(shí)現(xiàn)。
Function-Level Passes
Function-level passes用于為給定的 Relay/tir module實(shí)現(xiàn)各種內(nèi)部函數(shù)級(jí)優(yōu)化。它一次從module的函數(shù)列表中獲取一個(gè)函數(shù)以進(jìn)行優(yōu)化,并生成一個(gè)重寫(xiě)的 Relay Function 或 tir PrimFunc。大多數(shù)pass可以歸入這一類,例如Relay中的常見(jiàn)子表達(dá)式消除和inference simplification 以及tir中的向量化和flattening storage等。
請(qǐng)注意,此級(jí)別的passes范圍是 Relay Function或 tir PrimFunc。因此,我們無(wú)法通過(guò)這些passes添加或刪除函數(shù),因?yàn)樗鼈儾恢廊中畔ⅰ?/p>
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted...
};
pass_info 與我們剛剛在Module pass 中描述的相同。pass_func 需要一個(gè)函數(shù)進(jìn)行優(yōu)化,它還需要一個(gè)Module,因?yàn)槲覀兛赡軙?huì)使用它來(lái)報(bào)告錯(cuò)誤。一個(gè)函數(shù)可以用“SkipOptimization”注釋,以便在優(yōu)化過(guò)程中被忽略。
Sequential Passes
SequentialPass 類似于 Pytorch nn.Sequential,它包含許多用于執(zhí)行的passes。
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array<Pass> passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
目前在Relay中只有少數(shù)passes 被放入這組中。例如,FoldScaleAxis 需要在內(nèi)部調(diào)度 ForwardFoldScaleAxis 和 BackwardFoldScaleAxis。此外,建議先完成BackwardFoldScaleAxis。因此,該pass是SequentialPass的理想候選者。
以下代碼顯示了如何調(diào)用sequential pass中的各個(gè)pass。
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
Module mod = module;
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
在調(diào)用pass時(shí),我們首先檢查是否啟用了此pass。這是通過(guò)首先檢查用戶是否明確禁用該pass,然后檢查它是否被用戶指定為必需pass來(lái)完成的。如果仍然不確定是否啟用了此傳遞,則將檢查其 opt_level。只有當(dāng)它的opt_level不低于pass context中配置的優(yōu)化級(jí)別時(shí),才會(huì)啟用并因此執(zhí)行此pass。
要執(zhí)行pass,我們首先需要使用pass name在 TVM packed function注冊(cè)表中已注冊(cè)的pass。這是可能的,因?yàn)槊總€(gè)pass都注冊(cè)了一個(gè) API 接口,我們將在后面展示。
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
提供了一些helper function來(lái)創(chuàng)建上述每種類型的Pass。這些helper function也暴露給 Python 前端,以便用戶可以方便地使用 Python API 來(lái)創(chuàng)建特定的 pass 對(duì)象。
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
Pass Registration
我們已經(jīng)介紹了不同級(jí)別pass的概念和用于編譯的context。用戶可以多么輕松地注冊(cè)pass是一件有意義的事。,我們以constant folding為例。這個(gè) pass 已經(jīng)被實(shí)現(xiàn)來(lái)折疊 Relay Function中的常量(在 tvm/src/relay/transforms/fold_constant.cc 中找到)。
提供了一個(gè) API 來(lái)執(zhí)行 Expr 到 Expr 的轉(zhuǎn)換。
Expr FoldConstant(const Expr& expr);
為了將這個(gè)pass注冊(cè)到pass infra,我們首先需要決定這個(gè)pass將在哪個(gè)級(jí)別執(zhí)行。由于常量折疊發(fā)生在單個(gè)函數(shù)上,我們應(yīng)該直觀地通過(guò) CreateFunctionPass為其創(chuàng)建一個(gè) FunctionPass。pass_func 作為packed function返回,該函數(shù)在 IRModule 中的每個(gè)function上調(diào)用 Expr to Expr API。{} 表示此pass不需要先決條件。否則,pass開(kāi)發(fā)人員必須識(shí)別并列出它們。
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
} // namespace transform
為了允許其他 C++ 模塊應(yīng)用此pass,我們?cè)?include/tvm/relay/transform.h中聲明了一個(gè)free function,如下所示:
TVM_DLL Pass FoldConstant();
Python Frontend
python前端只需要一些簡(jiǎn)單的 APIs。例如,我們可以為用戶提供以下 APIs 來(lái)創(chuàng)建和執(zhí)行一個(gè) pass(完整的實(shí)現(xiàn)在 python/tvm/relay/transform.py 和 python/tvm/ir/transform.py 中提供)。后端接收信息并決定它應(yīng)該使用哪個(gè)函數(shù)來(lái)創(chuàng)建 Pass 對(duì)象。
PassContext
Python 前端為 PassContext 提供了一個(gè)包裝器,通過(guò)覆蓋 __enter__ 和 __exit__ 來(lái)啟用 with 語(yǔ)法。為用戶提供了一個(gè) current 靜態(tài)方法來(lái)獲取在特定范圍內(nèi)使用的上下文。
@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace, config):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
PassContext 用于配置編譯選項(xiàng),包括優(yōu)化級(jí)別和必需/禁用的pass。它還可以帶一個(gè)配置字典,以便不同的pass可以方便地獲取passed的數(shù)據(jù),例如回退設(shè)備信息和循環(huán)展開(kāi)的步數(shù)/深度等。為了能夠獲取所需的配置,必須通過(guò)TVM_REGISTER_PASS_CONFIG_OPTION注冊(cè)關(guān)鍵字。例如,loop unrolling pass使用以下內(nèi)容:
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
更多細(xì)節(jié)請(qǐng)參考 src/tir/transforms/unroll_loop.cc。
Pass Objects
Pass 是所有 pass 對(duì)象的基類。這里的所有方法都只是在后端實(shí)現(xiàn)的簡(jiǎn)單包裝器。它們是為了用戶方便地與 Python 中的基類進(jìn)行交互而定義的。在 pass 基類中只定義了一個(gè)__call__來(lái)使子類成為可調(diào)用對(duì)象,以便它們可以很容易地被調(diào)用(例如 pass_xx(arg))來(lái)執(zhí)行。
@register_relay_node
class Pass(RelayNode):
def __call__(self, mod):
return _transform.RunPass(self, mod)
提供了一些輔助 APIs 以支持從 Python 前端輕松創(chuàng)建pass并讓pass infra控制執(zhí)行。比如提供給用戶module_pass、function_pass、sequential,讓他們可以自定義自己的pass或者pass管道。
對(duì)于在C++后端實(shí)現(xiàn)的所有pass,我們分別在python/tvm/ir/transform.py和python/tvm/relay/transform.py中提供了相應(yīng)的Python API。例如,const 折疊有一個(gè) Python API,如下所示:
def FoldConstant():
return _transform.FoldConstant()
用戶可以通過(guò)裝飾器像下面這樣構(gòu)建一個(gè)pass:
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10,), "float32")
x = relay.var("x", tp)
gv = relay.GlobalVar("abs")
func = relay.Function([x], relay.abs(x))
new_mod = relay.Module({gv: func})
new_mod.update(mod)
return new_mod
module_pass = transform
assert isinstance(module_pass, transform.ModulePass)
assert module_pass.info.opt_level == 2
這里的transform函數(shù)向輸入的module添加了一個(gè)abs 函數(shù),但它可以是module level的任何自定義pass。創(chuàng)建此 module_pass 后,用戶可以將其應(yīng)用于任何 Relay 模塊。例如,我們可以構(gòu)建一個(gè)empty module并應(yīng)用此pass來(lái)添加 abs 函數(shù)。
mod = relay.Module()
mod = module_pass(mod)
相應(yīng)地,我們也為 function_pass 提供了這樣的功能。例如,一個(gè)示例function-level pass可以寫(xiě)成如下:
@relay.transform.function_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# Just for demo purposes
# Transform func to new_func
return self.new_func
x = relay.var("x", shape=(10, 20))
f1 = relay.Function([x], x)
f2 = relay.Function([x], relay.log(x))
# fpass is now a special pass that replaces every
# function to f1
fpass = TestReplaceFunc(f1)
# Now every function in input_mod is replaced by f1
res_mod = fpass(input_mod)
或者,用戶也可以不使用裝飾器直接注冊(cè)pass,然后調(diào)用它。有關(guān)如何自定義您自己的優(yōu)化管道以及調(diào)試 Relay 和 tir pass 的更多示例,請(qǐng)參閱 use pass infra 教程(https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py)。
0x2. TVM Relay樹(shù)結(jié)構(gòu)
AST
摘自wiki 在計(jì)算機(jī)科學(xué)中,抽象語(yǔ)法樹(shù)(Abstract Syntax Tree,AST),或簡(jiǎn)稱語(yǔ)法樹(shù)(Syntax tree),是源代碼語(yǔ)法結(jié)構(gòu)的一種抽象表示。它以樹(shù)狀的形式表現(xiàn)編程語(yǔ)言的語(yǔ)法結(jié)構(gòu),樹(shù)上的每個(gè)節(jié)點(diǎn)都表示源代碼中的一種結(jié)構(gòu)。之所以說(shuō)語(yǔ)法是“抽象”的,是因?yàn)檫@里的語(yǔ)法并不會(huì)表示出真實(shí)語(yǔ)法中出現(xiàn)的每個(gè)細(xì)節(jié)。比如,嵌套括號(hào)被隱含在樹(shù)的結(jié)構(gòu)中,并沒(méi)有以節(jié)點(diǎn)的形式呈現(xiàn);而類似于 if-condition-then 這樣的條件跳轉(zhuǎn)語(yǔ)句,可以使用帶有三個(gè)分支的節(jié)點(diǎn)來(lái)表示。和抽象語(yǔ)法樹(shù)相對(duì)的是具體語(yǔ)法樹(shù)(通常稱作分析樹(shù))。一般的,在源代碼的翻譯和編譯過(guò)程中,語(yǔ)法分析器創(chuàng)建出分析樹(shù),然后從分析樹(shù)生成AST。一旦AST被創(chuàng)建出來(lái),在后續(xù)的處理過(guò)程中,比如語(yǔ)義分析階段,會(huì)添加一些信息。
之前在解析TVM Relay的ONNX前端的時(shí)候,已經(jīng)提到在完成每個(gè)OP轉(zhuǎn)換之后需要使用IRModule.from_expr將所有轉(zhuǎn)換后的Relay Function包起來(lái)返回,過(guò)程如下,這里關(guān)心最后一行代碼即可:
def from_onnx(self, graph, opset, get_output_expr=False):
"""基于ONNX模型構(gòu)建Relay IR。
參數(shù)
----------
graph : onnx protobuf 對(duì)象
加載進(jìn)來(lái)的ONNX Graph
opset : 操作集版本
get_output_expr: bool
如果設(shè)置為true,則此轉(zhuǎn)換將返回每個(gè)輸出表達(dá)式,而不是打包的模塊。
將子圖轉(zhuǎn)換為Relay時(shí),這可能很有用。
Returns
-------
mod : tvm.IRModule
The returned relay module
params : dict
A dict of name: tvm.nd.array pairs, used as pretrained weights
"""
self.opset = opset
# 解析網(wǎng)絡(luò)的輸入到relay中, 又叫參數(shù),onnx的initializer就是用來(lái)保存模型參數(shù)的
for init_tensor in graph.initializer:
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
# 具體實(shí)現(xiàn)就是先把這個(gè)TensorProto使用get_numpy函數(shù)獲得值,再reshape到特定形狀,再基于這個(gè)numpy構(gòu)造tvm.nd.array。
array = self._parse_array(init_tensor)
# 前面解釋過(guò),如果設(shè)置凍結(jié)參數(shù),則將這個(gè)參數(shù)設(shè)置為Relay中的常量OP
if self._freeze_params:
self._nodes[init_tensor.name] = _expr.const(array)
else:
self._params[init_tensor.name] = array
self._nodes[init_tensor.name] = new_var(
init_tensor.name,
shape=self._params[init_tensor.name].shape,
dtype=self._params[init_tensor.name].dtype,
)
# 解析ONNX模型的輸入
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
# 獲取i這個(gè)輸入的名字,shape,數(shù)據(jù)類型以及shape每個(gè)維度對(duì)應(yīng)的名字
i_name, i_shape, d_type, i_shape_name = get_info(i)
# 判斷i這個(gè)輸入是權(quán)重參數(shù)還是輸入
if i_name in self._params:
# i is a param instead of input
self._num_param += 1
self._params[i_name] = self._params.pop(i_name)
self._nodes[i_name] = new_var(
i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype
)
# 輸入節(jié)點(diǎn)已經(jīng)在Relay IR中了就不用處理了
elif i_name in self._nodes:
continue
else:
# 真正的輸入節(jié)點(diǎn),依賴用戶進(jìn)行指定
self._num_input += 1
self._input_names.append(i_name)
if i_name in self._shape:
i_shape = self._shape[i_name]
else:
if "?" in str(i_shape):
warning_msg = (
"Input %s has unknown dimension shapes: %s. "
"Specifying static values may improve performance"
% (i_name, str(i_shape_name))
)
warnings.warn(warning_msg)
if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else:
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
self._inputs[i_name] = self._nodes[i_name]
# Only check user inputs in the outer-most graph scope.
if self._old_manager is None:
assert all(
[name in self._input_names for name in self._shape.keys()]
), "User specified the shape for inputs that weren't found in the graph: " + str(
self._shape
)
# 獲取不支持的算子列表
convert_map = _get_convert_map(opset)
unsupported_ops = set()
for node in graph.node:
op_name = node.op_type
if (
op_name not in convert_map
and op_name != "Constant"
and op_name not in _identity_list
):
unsupported_ops.add(op_name)
# 輸出不支持的算子集合
if unsupported_ops:
msg = "The following operators are not supported for frontend ONNX: "
msg += ", ".join(unsupported_ops)
raise tvm.error.OpNotImplemented(msg)
# 到這里說(shuō)明這個(gè)ONNX模型的所有算子都被Relay支持,可以正常進(jìn)行轉(zhuǎn)換了
for node in graph.node:
op_name = node.op_type
# 解析attribute參數(shù)
attr = self._parse_attr(node.attribute)
# 創(chuàng)建并填充onnx輸入對(duì)象。
inputs = onnx_input()
for i in node.input:
if i != "":
# self._renames.get(i, i)用來(lái)獲取ONNX Graph每個(gè)節(jié)點(diǎn)的輸入
inputs[i] = self._nodes[self._renames.get(i, i)]
else:
inputs[i] = None
i_name = self._parse_value_proto(node)
node_output = self._fix_outputs(op_name, node.output)
attr["tvm_custom"] = {}
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(node_output)
# 執(zhí)行轉(zhuǎn)換操作
op = self._convert_operator(op_name, inputs, attr, opset)
# op的輸出可能只有一個(gè)也可能有多個(gè)
if not isinstance(op, _expr.TupleWrapper):
outputs_num = 1
else:
outputs_num = len(op)
if outputs_num > 1:
# ONNX的某些節(jié)點(diǎn)支持可選輸出
# 這一塊在ONNX的Graph中搜索缺失的輸出并移除不需要的節(jié)點(diǎn)
valid_outputs = [False] * outputs_num
for i, output in enumerate(node_output):
if output != "":
valid_outputs[i] = True
# If we have outputs ONNX isn't expecting, we need to drop them
# 如果我們有ONNX不期望出現(xiàn)的輸出,我們需要?jiǎng)h除它們
if not all(valid_outputs):
tup = op.astuple()
# TupleWrapper can also wrap ops with TupleType outputs
if isinstance(tup, _expr.Tuple):
# For tuples, we extract the fields instead of using GetTupleItem
outputs = [tup.fields[i] for i, valid in enumerate(valid_outputs) if valid]
else:
# For call nodes, we need to GetTupleItem
outputs = [op[i] for i, valid in enumerate(valid_outputs) if valid]
# Create the new op with valid outputs
if len(outputs) == 1:
op = outputs[0]
else:
op = _expr.TupleWrapper(outputs, len(outputs))
# Drop invalid outputs for the onnx node
outputs_num = len(outputs)
node_output = [output for output in node_output if output != ""]
assert (
len(node_output) == outputs_num
), "Number of output mismatch {} vs {} in {}.".format(
len(node_output), outputs_num, op_name
)
# 輸出只有一個(gè)有可能是常量OP,可以執(zhí)行一次常量折疊功能
if outputs_num == 1:
self._nodes[node_output[0]] = fold_constant(op)
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]
# 解析ONNX模型的輸出
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
# 如果需要直接返回轉(zhuǎn)換后的表達(dá)式,在這里return
if get_output_expr:
return outputs
# 保持來(lái)自O(shè)NNX Graph的輸入和參數(shù)順序,但僅僅包含這些需要執(zhí)行轉(zhuǎn)換到Relay的節(jié)點(diǎn)
free_vars = analysis.free_vars(outputs)
nodes = {v: k for k, v in self._nodes.items()}
free_vars = [nodes[var] for var in free_vars]
for i_name in self._params:
if i_name in free_vars and i_name not in self._inputs:
self._inputs[i_name] = self._nodes[i_name]
# 根據(jù)我們的輸出表達(dá)式和所有輸入變量創(chuàng)建一個(gè)函數(shù)。
func = _function.Function([v for k, v in self._inputs.items()], outputs)
# 把這個(gè)函數(shù)用IRModule包起來(lái)返回,并同時(shí)返回權(quán)重參數(shù)
return IRModule.from_expr(func), self._params
這里IRModule.from_expr(func)就完成了Relay 抽象語(yǔ)法樹(shù)結(jié)構(gòu)的構(gòu)建,TVM將這個(gè)樹(shù)結(jié)構(gòu)定義為tvm.IRModule這個(gè)類,也即Relay IR。
Relay 樹(shù)結(jié)構(gòu)
現(xiàn)在來(lái)學(xué)習(xí)一下Relay 抽象語(yǔ)法樹(shù),也就是tvm.IRModule相關(guān)的數(shù)據(jù)結(jié)構(gòu)。
節(jié)點(diǎn)定義
樹(shù)的節(jié)點(diǎn)定義為在/include/tvm/relay/expr.h 中,主要有以下幾種類型:ConstantNode、VarNode、TupleNode、CallNode、LetNode、IfNode。
這些Node都繼承了在include/tvm/ir/expr.h定義的RelayExprNode,而RelayExprNode又繼承了BaseExprNode,RelayExprNode可以做什么可以參考這幾行注釋:
/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
* first class citizens. The life-cycle of the corresponding
* objects are implicitly managed by the language.
*
* \sa RelayExpr
*/
/*!
* \brief 所有非原始表達(dá)式的基節(jié)點(diǎn)。
*
* RelayExpr 支持張量類型、函數(shù)和 ADT 作為
* 一等公民。 對(duì)應(yīng)的生命周期
* 對(duì)象由語(yǔ)言隱式管理。
*
* \sa RelayExpr
*/
然后這里以IfNode和CallNode為例看一下它們的實(shí)現(xiàn):
class IfNode : public ExprNode {
public:
/*! \brief The condition */
Expr cond;
/*! \brief The expression evaluated when condition is true. */
Expr true_branch;
/*! \brief The expression evaluated when condition is false */
Expr false_branch;
};
class CallNode : public ExprNode {
public:
/*!
* \brief The operator(function) being invoked
*
* - It can be tvm::Op which corresponds to the primitive operators.
* - It can also be user defined functions (Function, GlobalVar, Var).
*/
Expr op;
/*! \brief The arguments(inputs) of the call */
tvm::Array<relay::Expr> args;
/*! \brief The additional attributes */
Attrs attrs;
};
這里展示了這些節(jié)點(diǎn)的成員變量,可以大致了解到這些節(jié)點(diǎn)的內(nèi)部結(jié)構(gòu)。
節(jié)點(diǎn)的數(shù)據(jù)訪問(wèn)
在了解了Relay模型樹(shù)節(jié)點(diǎn)后,我們需要知道TVM是如何去訪問(wèn)這些節(jié)點(diǎn)的數(shù)據(jù)的。在官方文檔中可以找到這樣一句話:ExprVisitor用于不修改程序而是執(zhí)行程序分析和收集信息的passes。而ExprVisitor又繼承自ExprFunctor(定義在tvm/include/tvm/relay/expr_functor.h),ExprFunctor設(shè)置了VisitExpr_的虛函數(shù),在解析時(shí)會(huì)回到ExprVisitor來(lái)解析節(jié)點(diǎn)。ExprFunctor提供了一個(gè)public接口方法VisitExpr,它接受一個(gè)表達(dá)式和零個(gè)或多個(gè)參數(shù)并返回某種類型的實(shí)例。當(dāng)你擴(kuò)展這個(gè)類時(shí),你通過(guò)為每種類型的表達(dá)式覆蓋VisitExpr_的實(shí)現(xiàn)來(lái)定義 AST 遍歷模式。
VisitExpr和VisitExpr_之間的關(guān)系與調(diào)度有關(guān)。每個(gè)VisitExpr_定義針對(duì)特定類型的表達(dá)式,但你并不總是知道你將訪問(wèn)節(jié)點(diǎn)是哪種類型。為了解決這個(gè)問(wèn)題,ExprFunctor提供了一個(gè)VisitExpr函數(shù),它從給定的表達(dá)式路由到處理它的VisitExpr_case。盡管 C++ 已經(jīng)提供了動(dòng)態(tài)調(diào)度,但ExprFunctor定義了自己的 vtable,VisitExpr使用它。通過(guò)定義我們自己的vtable,我們可以更好地控制調(diào)度。例如,如果我們想定義一個(gè)PrintVisitor遍歷器,在每次訪問(wèn)之前打印“Here”,我們可以覆蓋VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor本身是一個(gè)非常通用的類,這就是為什么通常會(huì)擴(kuò)展ExprVisitor或ExprMutator的原因。這些類擴(kuò)展了ExprFunctor 并提供VisitExpr_的默認(rèn)實(shí)現(xiàn),用于捕獲每個(gè)表達(dá)式類型的常見(jiàn)遍歷模式。擁有這些默認(rèn)實(shí)現(xiàn)意味著我們只需要為需要不同行為的表達(dá)式類型提供進(jìn)行重寫(xiě)VisitExpr_方法即可。
比如對(duì)于tvm/src/relay/transforms/fold_constant.cc中的ConstantChecker這個(gè)類,就繼承了ExprVisitor,并通過(guò)VisitExpr(expr),訪問(wèn)數(shù)據(jù)。ExprVisitor的VisitExpr成員函數(shù)實(shí)現(xiàn)如下:
void ExprVisitor::VisitExpr(const Expr& expr) {
auto it = visit_counter_.find(expr.get());
if (it != visit_counter_.end()) {
++it->second;
} else {
using TParent = ExprFunctor<void(const Expr&)>;
TParent::VisitExpr(expr);
visit_counter_.insert({expr.get(), 1});
}
}
可以看到這個(gè)類實(shí)際上調(diào)用的是父類(ExprFunctor)的VisitExpr,而ExprFunctor的VisitExpr的實(shí)現(xiàn)如下:
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward<Args>(args)...);
}
可以看到ExprFunctor設(shè)置了VisitExpr虛函數(shù),在解析時(shí)會(huì)回到ExprVisitor來(lái)解析節(jié)點(diǎn),而ConstantChecker這個(gè)類繼承了ExprVisitor,這樣我們只需要在ConstantChecker類中重寫(xiě)VisitExpr_就可以了。
在ExprFunctor的VisitExpr實(shí)現(xiàn)中有一個(gè)RELAY_EXPR_FUNCTOR_DISPATCH宏,這個(gè)宏的定義如下:
#define RELAY_EXPR_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \
return self->VisitExpr_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \
});
這里的self即為ExprFunctor的VisitExpr的實(shí)現(xiàn)中的vtable(n, this, std::forward<Args>(args)...),而this指向ExprFunctor。又因?yàn)?code style="font-size: 14px;word-wrap: break-word;padding: 2px 4px;border-radius: 4px;margin: 0 2px;color: #1e6bb8;background-color: rgba(27,31,35,.05);font-family: Operator Mono, Consolas, Monaco, Menlo, monospace;word-break: break-all;">ExprVisitor::VisitExpr方法調(diào)用的是ExprFunctor的函數(shù),所以這里的this指向的是ExprVisitor實(shí)例。
以IfNode為例子,看看ExprVisitor的VisitExpr_實(shí)現(xiàn)。由于this指向的是ExprVisitor實(shí)例,最后會(huì)在ExprVisitor實(shí)例中生成visit_counter_的列表。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitSpan(op->span);
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
visit_counter_是在ExprVisitor中定義的一個(gè)unordered_map,來(lái)標(biāo)記在遍歷Relay AST時(shí)某種Expr是否出現(xiàn),同時(shí)記錄下出現(xiàn)的次數(shù)。
// Internal visiting counter
std::unordered_map<const Object*, size_t> visit_counter_;
節(jié)點(diǎn)修改
pass是對(duì)Relay 樹(shù)結(jié)構(gòu),也可以說(shuō)計(jì)算圖進(jìn)行優(yōu)化,優(yōu)化必然設(shè)計(jì)到對(duì)圖結(jié)構(gòu)的修改。這就是上面提到的ExprMutator子類,它和ExprVisitor一樣繼承自ExprFunctor。類的定義如下:
class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreate來(lái)表記Node* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
};
我們需要關(guān)注的是memo_這個(gè)成員變量,然后我們看一下這個(gè)類的VisitExpr實(shí)現(xiàn):
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
}
可以看到memo_存儲(chǔ)了圖中的各個(gè)節(jié)點(diǎn)。參考IfNode的實(shí)現(xiàn):
Expr ExprMutator::VisitExpr_(const IfNode* op) {
auto guard = this->Mutate(op->cond);
auto true_b = this->Mutate(op->true_branch);
auto false_b = this->Mutate(op->false_branch);
if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b)) {
return GetRef<Expr>(op);
} else {
return If(guard, true_b, false_b, op->span);
}
}
如果IFNode的子節(jié)點(diǎn)都沒(méi)有被修改,那么就返回這個(gè)節(jié)點(diǎn)本身。否則創(chuàng)建新的節(jié)點(diǎn)If(guard, true_b, false_b, op->span);并返回。這里構(gòu)造新節(jié)點(diǎn)的類If的定義和實(shí)現(xiàn)分別在tvm/src/relay/ir/expr.h和tvm/src/relay/ir/expr.cc中:
class If : public Expr {
public:
/*!
* \brief The constructor
* \param cond The condition of a if node.
* \param true_branch The fall through branch
* \param false_branch The branch for execution when condition is false.
* \param span The source span of the expression.
*/
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
};
If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) {
ObjectPtr<IfNode> n = make_object<IfNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
n->span = std::move(span);
data_ = std::move(n);
}
總結(jié)
這一節(jié)主要解析了Relay 表達(dá)式樹(shù)的數(shù)據(jù)結(jié)構(gòu),TVM的所有pass都是基于在tvm/include/tvm/relay/expr.h中定義的各種Node組成的表達(dá)式樹(shù)來(lái)完成的,也可以說(shuō)是計(jì)算圖。另外還講解了TVM為了方便對(duì)這些表達(dá)式節(jié)點(diǎn)進(jìn)行訪問(wèn)和操作抽象出了ExprFunctor這個(gè)類,并在ExprFunctor這個(gè)類的基礎(chǔ)上擴(kuò)展ExprVisitor或ExprMutator,這在實(shí)現(xiàn)各個(gè)Pass的C++后端代碼時(shí)非常有用。最后我以IfNode的實(shí)現(xiàn)和常量折疊Pass中的ConstantChecker類實(shí)現(xiàn)為例,展示了這些類的具體用法。
0x3. Function Pass的C++后端通用創(chuàng)建流程
這里先基于Constant Folding Pass講解一下Function Pass的C++后端通用創(chuàng)建流程。這里先看一下FoldConstant的定義:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f, m));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);
} // namespace transform
CreateFunctionPass這個(gè)函數(shù)用來(lái)創(chuàng)建FunctionPass,相關(guān)代碼如下:
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
FunctionPass::FunctionPass(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
data_ = std::move(n);
}
可以看到在FunctionPass中創(chuàng)建了一個(gè)FunctionPassNode實(shí)例并將其放到data_中,data_來(lái)自于ObjectRef這個(gè)類的成員變量,這里FunctionPass->Pass->ObjectRef。
如果將上述代碼生成的Pass對(duì)象提供給Pass Infrastructure,它將確保將 AST 遍歷應(yīng)用于給定 Relay Module中的每個(gè)Function,這是我們對(duì)Constant Folding Pass所期望的行為(它應(yīng)該盡可能折疊所有常量)。
函數(shù)CreateFunctionPass允許注冊(cè)傳遞的優(yōu)化級(jí)別(在本例中為2),可用于根據(jù)pass的通用效用、pass的名稱以及pass的任何依賴關(guān)系將pass組合在一起。一個(gè)pass的依賴是一系列可能會(huì)對(duì)這個(gè)pass的結(jié)果產(chǎn)生影響的pass。FoldConstant沒(méi)有任何依賴。但是很多Relay pass確實(shí)依賴于類型信息,所以InferType是一個(gè)常見(jiàn)的依賴;others may depend on the program’s being in A-normal form, via the ToANormalForm pass.
注意,PassContext 對(duì)象包含傳遞用于錯(cuò)誤報(bào)告和配置選項(xiàng)的信息;FoldConstant不需要此信息,但其它Pass可能會(huì)引用它們的PassContext對(duì)象。
現(xiàn)在可以通過(guò)Pass Infrastructure調(diào)用pass,不過(guò)最好也為 pass 添加 Python 綁定,如以下代碼片段所示:
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
一旦以上述方式定義了 Pass 對(duì)象,就可以使用 Pass 基礎(chǔ)結(jié)構(gòu)的 Sequential 構(gòu)造調(diào)用它們,該構(gòu)造采用傳遞列表并將它們按順序應(yīng)用于Relay 模塊,從而獲得轉(zhuǎn)換后的Module。例如,下面的代碼將 FoldConstant 和 ToANormalForm Pass(一個(gè)接一個(gè))應(yīng)用于mod中的每個(gè)函數(shù)并獲得一個(gè)新Module。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
我們可以看一下Sequential的調(diào)用流程:
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!pass_ctx.PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(std::move(mod), pass_ctx);
}
return mod;
}
這里分成兩個(gè)部分,如果pass有依賴,則先運(yùn)行依賴pass。GetPass會(huì)在relay._transform的列表中根據(jù)命名返回對(duì)應(yīng)的pass。代碼實(shí)現(xiàn)如下:
Pass GetPass(const String& pass_name) {
using tvm::runtime::Registry;
const runtime::PackedFunc* f = nullptr;
if (pass_name.operator std::string().find("transform.") != std::string::npos) {
f = Registry::Get(pass_name);
} else if ((f = Registry::Get("transform." + pass_name))) {
// pass
} else if ((f = Registry::Get("relay._transform." + pass_name))) {
}
ICHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass";
return (*f)();
}
接著再跟進(jìn)一下mod = pass(std::move(mod), pass_ctx);,代碼實(shí)現(xiàn)如下:
IRModule Pass::operator()(IRModule mod, const PassContext& pass_ctx) const {
const PassNode* node = operator->();
ICHECK(node != nullptr);
PassProfile::EnterPass(node->Info()->name);
auto ret = node->operator()(std::move(mod), pass_ctx);
PassProfile::ExitPass();
return std::move(ret);
}
// 創(chuàng)建PassNode類型的node實(shí)例
const Object* operator->() const { return get(); }
// 虛函數(shù),需要子類重寫(xiě)
virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx) const = 0;
這里關(guān)注一下virtual IRModule operator()(IRModule mod, const PassContext& pass_ctx)的接口實(shí)現(xiàn),具體到FunctionPassNode重寫(xiě)這個(gè)operator()方法,因?yàn)镕unctionPassNode繼承了PassNode,代碼實(shí)現(xiàn)如下:
// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
DiagnosticContext previous = DiagnosticContext::Default(mod);
if (pass_ctx->diag_ctx) {
DiagnosticContext tmp = pass_ctx->diag_ctx.value();
pass_ctx->diag_ctx = previous;
previous = tmp;
} else {
pass_ctx->diag_ctx = previous;
}
ICHECK(pass_ctx->diag_ctx)
<< "The diagnostic context was set at the top of this block this is a bug.";
const PassInfo& pass_info = Info();
ICHECK(mod.defined());
DLOG(INFO) << "Executing function pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module.
IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
// only picks up relay::Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);
updates.push_back({it.first, updated_func});
}
}
for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true);
}
ICHECK(pass_ctx->diag_ctx)
<< "The diagnostic context was set at the top of this block this is a bug.";
pass_ctx->diag_ctx.value().Render();
pass_ctx->diag_ctx = previous;
pass_ctx.Trace(updated_mod, pass_info, false);
// TODO(@jroesch): move away from eager type checking for performance reasons
// make issue.
return transform::InferType()(updated_mod);
}
這個(gè)實(shí)現(xiàn)比較復(fù)雜,但我們只需要關(guān)心這行auto updated_func = SkipFunction(func) ? func : pass_func(func, updated_mod, pass_ctx);,這是執(zhí)行Pass的核心操作。
以上就是Function Pass的C++后端通用創(chuàng)建流程。
0x4. Constant Folding Pass
下面我們來(lái)看一下Constant Folding Pass的C++后端代碼實(shí)現(xiàn)需要注意哪些東西,首先Constant Folding Pass屬于Funtion-level Pass。入口依舊是Function Pass的注冊(cè)接口:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f, m));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant").set_body_typed(FoldConstant);
} // namespace transform
我們看一下FoldConstant的具體實(shí)現(xiàn):
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
return ConstantFolder(mod).Mutate(expr);
}
可以看到常量折疊主要調(diào)用了ConstantFolder這個(gè)類的Mutate函數(shù)。而ConstantFolder繼承了MixedModeMutator這個(gè)類,MixedModeMutator這個(gè)類比較有趣,定義如下:
/*! \brief 用于自定義重寫(xiě)Pass的非遞歸 DFS 圖遍歷
*
* MixedModeMutator 將 Expr 視為數(shù)據(jù)流圖,并且每個(gè) Expr 只重寫(xiě)一次。
* mutated的結(jié)果將被記錄到一個(gè)map中被重用,以便數(shù)據(jù)流上的本地transformation保留了圖形結(jié)構(gòu)
*
* MixedModeMutator 提供與 ExprMutator 相同的遞歸 API,并使用
* 遞歸遍歷大多數(shù)形式的 IR,但在幕后它擴(kuò)展了嵌套的數(shù)據(jù)流區(qū)域
* 并迭代處理它們以防止堆棧溢出
*
* Uses Rewrite_ API of ExprRewriter for a cleaner split between recrusive and non-recursive
* behavior.
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
MixedModeMutator(bool pre = false) : pre_{pre} {};
Expr VisitExpr(const Expr& expr) final;
virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expr VisitExpr_(const TupleGetItemNode* op) final { return Rewrite(op); };
/*!
* \brief 用戶應(yīng)該重寫(xiě) Rewrite_ 方法來(lái)實(shí)現(xiàn)他們的Pass。Rewrite_ functions will be able to rewrite
* the op only with data about the original node `pre` and the same node with modified
* inputs `post` and should not recurse.
* \param pre 重寫(xiě)前的表達(dá)式節(jié)點(diǎn)。
* \param post 具有重寫(xiě)輸入的表達(dá)式。
*/
virtual Expr Rewrite_(const TupleNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const CallNode* pre, const Expr& post) { return post; }
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }
protected:
bool pre_;
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
template <typename T>
Expr Rewrite(const T* op) {
Expr post = ExprMutator::VisitExpr_(op);
return Rewrite_(op, post);
}
virtual void VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);
};
我們?cè)?x2節(jié)講到pass是對(duì)Relay 樹(shù)結(jié)構(gòu),也可以說(shuō)計(jì)算圖進(jìn)行優(yōu)化,優(yōu)化必然設(shè)計(jì)到對(duì)圖結(jié)構(gòu)的修改。這就是上面提到的ExprMutator子類,它和ExprVisitor一樣繼承自ExprFunctor,我們實(shí)現(xiàn)Pass其實(shí)就是為其重寫(xiě)VisitExpr_成員函數(shù)。但是在這個(gè)MixedModeMutator類中,VisitExpr_成員函數(shù)實(shí)際上又調(diào)用了Rewrite_,所以Constant Folding在修改節(jié)點(diǎn)時(shí)只需要重寫(xiě)這個(gè)Rewrite_成員函數(shù)即可。(繞來(lái)繞去的,要仔細(xì)看看)。
然后我們看一下ConstantChecker的實(shí)現(xiàn),Constant Folding通過(guò)ConstantChecker遞歸實(shí)現(xiàn)了Expr的常量判斷??梢詮拇a看出,Constant 主要是判斷元素是否是ConstantNode,或者TupleNode里的元素都是ConstantNode。
class ConstantChecker : private ExprVisitor {
public:
// Check whether an expression is constant. The results are memoized.
bool Check(const Expr& expr) {
// The `ConstantNode` case is common enough that we check directly for the
// case here, to avoid the time overhead of dispatching through the vtable
// and the space overhead of memoizing always-true results.
if (expr.as<ConstantNode>()) {
return true;
}
const auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;
VisitExpr(expr);
return memo_[expr]; // return memoized result or the default value false
}
private:
std::unordered_map<Expr, bool, ObjectPtrHash, ObjectPtrEqual> memo_;
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
};
bool ConstantCheck(const Expr& e) { return ConstantChecker().Check(e); }
接下來(lái)我們就解析一下真正的常量融合發(fā)生的函數(shù),即在ConstantFolder中重寫(xiě)的Rewrite_函數(shù),代碼實(shí)現(xiàn)如下:
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (inside_primitive) {
return GetRef<Expr>(call);
}
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
auto origin_args = call->args;
call = post.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return post;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return post;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return post;
// Try to evaluate shape_of op
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(post, origin_args, call->attrs);
}
if (call->op == ndarray_size_op_) {
return EvaluateNdarraySize(post, origin_args, call->attrs);
}
// We should think about potentially constant evaluation over these ops too.
static auto fnoncomputational = Op::GetAttrMap<TNonComputational>("TNonComputational");
if (const auto* call_node = call->op.as<OpNode>()) {
Op op = GetRef<Op>(call_node);
if ((fnoncomputational.count(op) && fnoncomputational[op]) || (call->op == device_copy_op_)) {
return GetRef<Call>(call);
}
}
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(post);
} else {
return post;
}
}
這里根據(jù)callnode的類型又分了幾種情況,我們假設(shè)代碼走到了return ConstEvaluate(post);這一行,這個(gè)函數(shù)為了完成常量折疊做了什么事情呢?
// Constant evaluate an expression.
Expr ConstEvaluate(Expr expr) {
std::vector<transform::Pass> passes = {transform::FuseOps(0), transform::ToANormalForm(),
transform::InferType()};
Function func;
if (expr.as<FunctionNode>()) {
func = Downcast<Function>(expr);
} else {
// TODO(@jroesch): fix this
func = Function(FreeVars(expr), expr, Type(), FreeTypeVars(expr, module_), {});
}
auto mod = IRModule({}, module_->type_definitions, module_->Imports());
auto global = GlobalVar("main");
mod->Add(global, func);
auto seq = transform::Sequential(passes);
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
using tvm::transform::PassContext;
Device dev;
dev.device_type = kDLCPU;
dev.device_id = 0;
Target target = Target("llvm");
// use a fresh build context
// in case we are already in a build context.
// needed for both execution and creation(due to JIT)
With<PassContext> fresh_build_ctx(PassContext::Create());
FInterpreter executor = CreateInterpreter(mod, dev, target);
return ObjectToExpr(executor(expr));
}
可以看到這里增加了三個(gè)passes,接下來(lái)執(zhí)行這三個(gè)Pass完成常量折疊的功能,這里有個(gè)ToANormalForm定義可以在wikipedia找到(https://en.wikipedia.org/wiki/A-normal_form):
passes = {transform::FuseOps(0), transform::ToANormalForm(),
transform::InferType()};
獲得了Sequential Pass和target以及全局module信息之后執(zhí)行ObjectToExpr完成常量值到表達(dá)式的轉(zhuǎn)換。
// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(value);
return Constant(nd_array);
} else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);
Array<Expr> fields;
for (size_t i = 0; i < adt.size(); ++i) {
fields.push_back(ObjectToExpr(adt[i]));
}
return Tuple(fields);
} else {
LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
return Expr();
}
}
這個(gè)函數(shù)主要實(shí)現(xiàn)了基于runtime的結(jié)果生成新的Expr來(lái)代替原來(lái)的Expr。在Rewrite_函數(shù)中還有幾種類型的CallNode的常量折疊實(shí)現(xiàn)這里就不介紹了,感興趣的小伙伴可以自己看一下。
0x5. 筆者建議
我的建議是,如果你C++能力不是很出色,建議對(duì)于TVM的C++ backend Pass了解即可,不用深究,只需要按照第一節(jié)介紹的方法就可以將TVM已經(jīng)實(shí)現(xiàn)的Pass玩的風(fēng)聲水起。如果你要自定義Pass,那么直接基于TVM提供的裝飾器在Python層實(shí)現(xiàn)就可以了,具體參考:https://github.com/apache/tvm/blob/main/tutorials/dev/use_pass_infra.py。所以如果嫌棄文章太長(zhǎng),只看第一節(jié)就好了。
0x6. 總結(jié)
好像開(kāi)頭寫(xiě)好了,拷貝一下:這篇文章基于TVM 0.8.0.dev版本。在【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】五,TVM Relay以及Pass簡(jiǎn)介 這篇推文中已經(jīng)簡(jiǎn)單介紹了Relay和Pass機(jī)制。但對(duì)Pass的基礎(chǔ)設(shè)施(Pass Infrastructure)和Relay樹(shù)結(jié)構(gòu)都沒(méi)有詳細(xì)介紹,所以這篇文章主要介紹一下Pass Infrastructure和Relay樹(shù)結(jié)構(gòu),再基于這些關(guān)鍵的基礎(chǔ)知識(shí)詳細(xì)了解一下Constant Folding Pass,相信讀者讀完這篇文章會(huì)對(duì)TVM的Pass有更深的理解,并且在閱讀其它Pass和實(shí)現(xiàn)自定義Pass時(shí)可以很Relax。
0x7. 推薦閱讀
【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】六,TVM的編譯流程詳解 【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】五,TVM Relay以及Pass簡(jiǎn)介 【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】番外一,Data Flow和Control Flow 【從零開(kāi)始學(xué)TVM】三,基于ONNX模型結(jié)構(gòu)了解TVM的前端 【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】二,TVM中的scheduler 【從零開(kāi)始學(xué)深度學(xué)習(xí)編譯器】一,深度學(xué)習(xí)編譯器及TVM 介紹
0x8. 參考
https://tvm.apache.org/docs https://zhuanlan.zhihu.com/p/151815380
歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識(shí)。( ? ?ω?? )?
有對(duì)文章相關(guān)的問(wèn)題,或者想要加入交流群,歡迎添加BBuf微信:
