【從零開始學深度學習編譯器】五,TVM Relay以及Pass簡介
【GiantPandaCV導語】這篇文章主要介紹了一下TVM的Relay并介紹了如何基于Relay構建一個Conv+BN+ReLU的小網(wǎng)絡,然后介紹了一下TVM中的Pass的工作機制,并較為詳細的介紹了RemoveUnusedFunctions,ToBasicBlockNormalForm,EliminateCommonSubexpr三種Pass。其中Relay部分的詳細介紹大部分引用自官方文檔:https://tvm.apache.org/docs/tutorials/get_started/introduction.html。
0x0. 介紹
在前面幾節(jié)的介紹中我們了解到了TVM是如何將ONNX前端模型轉換為IR Module的,并且還剖析了TVM中的Relay算子和TOPI算子的扭轉過程,知道了Relay算子的最終計算也是基于TOPI算子集合完成的。然后我們在基于ONNX模型結構了解TVM的前端那篇文章貼出的示例程序中還有一個很重要的細節(jié)即TVM的編譯流程沒有詳細介紹,即下面這幾行代碼:
######################################################################
# Relay Build
# -----------
# Compile the graph to llvm target with given input specification.
target = "llvm"
target_host = "llvm"
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target, target_host=target_host, params=params)
這幾行代碼展示了TVM的編譯流程,在這個編譯流程里面不僅包含了基于Relay IR進行的優(yōu)化策略來去除冗余的算子(也叫Pass)還包含了將Relay程序編譯成特定后端(這里是llvm)可以執(zhí)行的代碼(codegen)。
在這篇文章中我們將簡單介紹一下Relay,然后再認識一下TVM中的Pass,也就是解釋with tvm.transform.PassContext(opt_level=3)這個類具體完成了什么工作。至于code gen和詳細的編譯流程,由于TVM的水太深,我還沒把握住,下次再探索吧。
0x2. Relay介紹
這一節(jié)主要結合TVM的文檔(https://tvm.apache.org/docs/dev/relay_intro.html)來介紹一下NNVM的第二代Relay。Relay的設計目標有以下幾點:
支持傳統(tǒng)的數(shù)據(jù)流(DataFlow)風格編程。 支持functional-style scoping,并融合了編程語言領域的一些知識,帶了一些新的特性(支持Let表達式,支持遞歸等等) 支持數(shù)據(jù)流風格和函數(shù)式風格混合編程。
0x2.1 使用Relay建立一個計算圖
傳統(tǒng)的深度學習框架使用計算圖作為它們的中間表示。計算圖(或數(shù)據(jù)流圖)是代表計算過程的有向無環(huán)圖(DAG)。盡管由于缺少控制流,數(shù)據(jù)流圖在計算能力方面受到限制,但它們的簡單性使其易于實現(xiàn)自動微分并針對異構執(zhí)行環(huán)境進行編譯(例如,在專用硬件上執(zhí)行計算圖的某些部分,即子圖)。

我們可以使用Relay來構建一個計算(DataFlow)圖。具體來說,上面的代碼顯示了如何構造一個簡單的兩個節(jié)點的計算圖,我們可以發(fā)現(xiàn)這個示例的代碼和現(xiàn)有的Garph IR如NNVMv1沒有太大區(qū)別,唯一的區(qū)別是在術語方面:
現(xiàn)有框架通常使用圖和子圖 Relay使用函數(shù),例如 – fn(%x),表示圖
每個數(shù)據(jù)流節(jié)點都是Relay中的一個CallNode。通過Relay的Python DSL,我們可以快速構建計算圖。在上面的代碼需要注意的是這里顯示構造了一個Add節(jié)點,兩個輸入都指向%1。當一個深度學習框架對上面的計算圖進行推理時,它將會按照拓撲序進行計算,并且%1只會被計算一次。雖然這個事實對于深度學習框架的開發(fā)者來說是一件很自然的事情,但這或許會使得只關心算法的研究員困惑。如果我們實現(xiàn)一個簡單的vistor來打印結果并將結果視為嵌套的Call表達式,它將是log(%x) + log(%x)。
當DAG中存在共享節(jié)點時,這種歧義是由程序語義的解釋不同而引起的。在正常的函數(shù)式編程IR中,嵌套表達式被視為表達式樹,并沒有考慮%1實際上在%2中被重用了2次的事實。
Relay IR注意到了這個區(qū)別。其實深度學習框架用戶經常使用這種方式構建計算圖,其中經常發(fā)生DAG節(jié)點重用。然后當我們以文本格式打印Relay程序時,我們每行打印一個CallNode,并為每個CallNode分配一個臨時ID(%1, %2),以便可以在程序的后續(xù)部分中引用每個公共節(jié)點。
0x2.2 Module:支持多個函數(shù)(Graphs)
上面介紹了如何構建一個數(shù)據(jù)流圖為一個函數(shù)。然后一個很自然的問題是可以做到構建多個函數(shù)并相互調用嗎?Relay允許將多個函數(shù)組合在一個Module中,下面的代碼展示了一個函數(shù)調用另外一個函數(shù)的例子。
def @muladd(%x, %y, %z) {
%1 = mul(%x, %y)
%2 = add(%1, %z)
%2
}
def @myfunc(%x) {
%1 = @muladd(%x, 1, 2)
%2 = @muladd(%1, 2, 3)
%2
}
Module可以被看作Map<GlobalVar, Function>,其中GlobalVar僅僅是一個表示函數(shù)名的ID,上面的程序中GlobalVar是@muladd和@myfunc。當一個CallNode被用來調用另外一個函數(shù)時,相應的GlobalVar被存在CallNode的OP中。它包含了一個間接的等級關系---我們需要使用相應的GlobalVar從Module中查找被調用函數(shù)的主體。在這種情況下,我們也可以直接將引用的函數(shù)存儲為CallNode中的OP。那么為什么需要引入GlobalVar呢?主要原因是為了解耦定義和聲明,并支持了函數(shù)的遞歸和延遲聲明。
def @myfunc(%x) {
%1 = equal(%x, 1)
if (%1) {
%x
} else {
%2 = sub(%x, 1)
%3 = @myfunc(%2)
%4 = add(%3, %3)
%4
}
}
在上面的例子中,@myfunc遞歸調用它自己。使用GlobalVar @myfunc來表示函數(shù)避免了數(shù)據(jù)結構中的循環(huán)依賴性。至此,已經介紹完了Relay中的基本概念。值得注意的是,相比NNVM,Relay在如下方面進行了改進:
有文本形式中間表示,便于開發(fā)和 debug 支持子圖函數(shù)、聯(lián)合模塊,便于聯(lián)合優(yōu)化 前端用戶友好,便于調優(yōu)
0x2.3 Let Binding and Scopes
至此,已經介紹了如何用深度學習框架中的舊方法來構建計算圖。這一節(jié)將討論一個Relay的一個新的構造-let bindings。
Let binding被每一種高級的編程語言應用。在Relay中,他是一個擁有三個字段Let(var, value, body)的數(shù)據(jù)結構。當我們計算一個Let表達式時,我們首先計算value部分,然后將其綁定到var,最后在body表達式中返回計算結果。
我們可以使用一系列的Let綁定來構造一個邏輯上等效于數(shù)據(jù)流程序的程序,下面的代碼示例顯示了這個用法:

嵌套的Let Binding被稱作A-normal形式,作為函數(shù)式編程語言中的常用IR。通過上面的圖我們可以發(fā)現(xiàn)雖然這兩個程序的語義完全等價,它們的文本表示也一樣(除了A-norm形式有l(wèi)et的前綴),但AST抽象語法樹卻不一樣。
由于程序的優(yōu)化使用了這些AST數(shù)據(jù)結構并對其進行了變換,這兩種不同的結構會影響到最終編譯器生成的代碼。比如,我們想要檢測add(log(x), y)這個模式。在數(shù)據(jù)流程序中,我們可以首先進入add節(jié)點,然后直接檢查它的第一個參數(shù)是不是log。而在A-form的程序中,我們不能直接檢查任何東西,因為add節(jié)點的輸入是%v1-我們需要維護一個映射表將變量和它綁定的值進行映射,然后查表才知道%v1代表的是log。
0x2.4 為什么我們可能需要Let Binding
Let Binding的一種關鍵用法是它可以指定計算的scope。我們看一下下面這個沒有使用Let Binding的例子:

當我們嘗試在該在哪里計算%1節(jié)點時,問題就來了。特別的是,雖然文本格式似乎建議我們應該在if的scope之外計算節(jié)點%1,但AST卻不建議這樣做。實際上數(shù)據(jù)流圖永遠不會定義它的計算scope,這在語義上產生了一些歧義。
當我們有閉包時,這種歧義更加有趣,考慮下面的程序,該程序返回一個閉包。我們不知道在哪里計算%1,它可以在閉包的內部和外部。
fn (%x) {
%1 = log(%x)
%2 = fn(%y) {
add(%y, %1)
}
%2
}
Let Binding解決了這些問題,因為值的計算發(fā)生在let節(jié)點上。在這兩個程序中,如果將%1 = log(%x)改成let %v1 = log(%x),則我們將計算位置明確指定為if scope和閉包之外。可以看到Let Binding為計算端提供了更精確的范圍,并且在生成后端代碼時會很有用(因為這種范圍在IR中)。
另一方面,沒有指定計算scope的數(shù)據(jù)流形式也有其自身的優(yōu)勢,我們不需要擔心在生成代碼時將let放到哪里。數(shù)據(jù)流格式還為后面決定將計算節(jié)點放到哪里的Passes提供了更大的自由度。因此,在優(yōu)化的初始階段如果發(fā)現(xiàn)數(shù)據(jù)流形式還是挺方便的,那么使用數(shù)據(jù)流圖的編碼方法可能不是一個壞主意。目前在Relay中也實現(xiàn)了很多針對數(shù)據(jù)流圖的優(yōu)化方式。
但是,當我們將IR lower到實際的運行時程序時,我們需要精確的計算scope。特別是當我們使用子函數(shù)和閉包時,我們要明確指定計算scope應在哪里發(fā)生。在后期執(zhí)行特定的優(yōu)化中,可以使用Let Binding來解決此問題。
0x2.5 對IR轉換的影響
希望到目前為止,你們已經熟悉兩種表示形式。大多數(shù)函數(shù)式編程語言都以A-normal形式進行分析,分析人員無需注意表達式是DAG。
Relay選擇同時支持數(shù)據(jù)流形式和Let Binding。TVM相信讓框架開發(fā)者選擇熟悉的表達形式很重要。但是這確實對我們寫通用的Passes產生了一些影響。由于這里還沒介紹Passes,以及對Passes理解不深并且我沒有使用過Let表達式來構建網(wǎng)絡,就不繼續(xù)介紹具體有哪些影響了。
詳細內容可以參考:https://tvm.apache.org/docs/dev/relay_intro.html#let-binding-and-scopes
0x3. 基于Relay構建一個自定義的神經網(wǎng)絡示例
我們基于Relay的接口定義一個Conv+BN+ReLU的小網(wǎng)絡,展示一下Relay接口應該如何使用,這里TVM版本是0.8.0.dev,代碼如下:
#coding=utf-8
import tvm
from tvm import relay
import numpy as np
from tvm.contrib import graph_executor
# 構造BN
def batch_norm(data,
gamma=None,
beta=None,
moving_mean=None,
moving_var=None,
**kwargs):
name = kwargs.get("name")
kwargs.pop("name")
if not gamma:
gamma = relay.var(name + "_gamma")
if not beta:
beta = relay.var(name + "_beta")
if not moving_mean:
moving_mean = relay.var(name + "_moving_mean")
if not moving_var:
moving_var = relay.var(name + "_moving_var")
return relay.nn.batch_norm(data,
gamma=gamma,
beta=beta,
moving_mean=moving_mean,
moving_var=moving_var,
**kwargs)[0]
# 構造卷積
def conv2d(data, weight=None, **kwargs):
name = kwargs.get("name")
kwargs.pop("name")
if not weight:
weight = relay.var(name + "_weight")
return relay.nn.conv2d(data, weight, **kwargs)
# 構造卷積+BN+ReLU的simpleNet
def simplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
padding=(1, 1), epsilon=1e-5):
conv = conv2d(
data=data,
channels=channels,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_layout='NCHW',
name=name+'_conv')
bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn')
act = relay.nn.relu(data=bn)
return act
data_shape = (1, 3, 224, 224)
kernel_shape = (32, 3, 3, 3)
dtype = "float32"
data = relay.var("data", shape=data_shape, dtype=dtype)
act = simplenet(data, "graph", 32, strides=(2, 2))
func = relay.Function(relay.analysis.free_vars(act), act)
print(func)
np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))
params = {
"graph_conv_weight": tvm.nd.array(np.random.uniform(-1, 1, (32, 3, 3, 3)).astype(dtype)),
"graph_bn_gamma": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),
"graph_bn_beta": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),
"graph_bn_moving_mean": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),
"graph_bn_moving_var": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),
}
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(func, "llvm", params=params)
dev = tvm.cpu(0)
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# set inputs
m.set_input("data", tvm.nd.array(np_data.astype(dtype)))
# execute
m.run()
# get outputs
tvm_output = m.get_output(0)
就是一個很常規(guī)的過程,創(chuàng)建Relay Function,然后將所有的OP的權重信息用params這個字典存起來,注意這里的權重信息是隨機初始化的。在編譯Relay IR之前可以先看一下優(yōu)化前的IR長什么樣:
fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var) {
%0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]);
%1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var);
%2 = %1.0;
nn.relu(%2)
}
符合我們第二節(jié)介紹的規(guī)則,Relay IR時一個函數(shù)。
0x4. 初識Pass
在上面構造simplenet的代碼中,relay.build外部包了一層tvm.transform.PassContext,如下:
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(func, "llvm", params=params)
實際上tvm.transform.PassContext這個接口就定義了Pass,如文檔所示:

Pass是TVM中基于Relay IR進行的一系列優(yōu)化,類似于onnx-simplifier里面用到的onnxoptimizer,它可以簡化計算圖,去除一些冗余的算子,提高模型的推理效率。TVM將所有的pass都抽象到了tvm/include/tvm/ir/transform.h這個文件中,主要包含PassContext,PassInfo,Pass,以及Sequential。
這里的PassContext即是上面Python接口對應的C++實現(xiàn),它包含了Pass執(zhí)行依賴的一些參數(shù)如優(yōu)化level,依賴的其它特定Pass以及設置不使用某種指定Pass等。PassInfo是用來記錄Pass信息的類,包含Pass的opy_level,name,以及當前Pass需要哪些前置Pass。而Pass這個類就執(zhí)行pass的主體,這是一個基類,每種Pass具體的C++代碼實現(xiàn)在tvm/src/relay/transforms中,它們都會繼承Pass這個基類。最后,Sequential是一個container,裝載所有Pass。
需要說明一下,不是所有的Pass都定義在tvm/src/relay/transforms這里,比如下面的第一個例子就在tvm/src/relay/backend/vm這個文件夾里。接下來我們將幾個Pass的例子,看看它到底對Relay IR做了什么?
RemoveUnusedFunctions
首先來看一下定義在tvm/src/relay/backend/vm/removed_unused_funcs.cc這里的RemoveUnusedFunctions 這個pass,核心的代碼實現(xiàn)如下:
void VisitExpr_(const FunctionNode* func_node) final {
auto func = GetRef<Function>(func_node);
if (visiting_.find(func) == visiting_.end()) {
visiting_.insert(func);
for (auto param : func_node->params) {
ExprVisitor::VisitExpr(param);
}
ExprVisitor::VisitExpr(func_node->body);
}
}
IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs) {
std::unordered_set<std::string> called_funcs{};
for (auto entry : entry_funcs) {
auto funcs = CallTracer(module).Trace(entry);
called_funcs.insert(funcs.cbegin(), funcs.cend());
}
auto existing_functions = module->functions;
for (auto f : existing_functions) {
auto it = called_funcs.find(f.first->name_hint);
if (it == called_funcs.end()) {
module->Remove(f.first);
}
}
return module;
}
比較容易看出這個pass就是去除Relay IR中的冗余節(jié)點,VisitExpr_這個函數(shù)就是完成了一個圖的遍歷,然后把沒有遍歷到的節(jié)點刪掉。刪除發(fā)生在RemoveUnusedFunctions這個函數(shù)中。
ToBasicBlockNormalForm
這個Pass實現(xiàn)在tvm/src/relay/transforms/to_basic_block_normal_form.cc,代碼實現(xiàn)如下:
Expr ToBasicBlockNormalFormAux(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
/* The scope of the whole expr is global.
* The scope of any subexpr, is the lowest common ancestor of all incoming edge.
* We also record the set of expressions whose scope is lifted.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}
IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;
tvm::Map<GlobalVar, Function> updates;
auto funcs = mod->functions;
for (const auto& it : funcs) {
ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second);
updates.Set(it.first, Downcast<Function>(ret));
}
for (auto pair : updates) {
mod->Add(pair.first, pair.second, true);
}
DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;
return mod;
}
bool BasicBlockNormalFormCheck(const Expr& e) {
// calculate all the dependency between nodes.
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
for (auto expr : scopes.second) {
LOG(FATAL) << "The expression below violates the basic block normal form in that "
<< "its scope should be lifted:\n"
<< expr;
}
return scopes.second.size() == 0;
}
ToBasicBlockNormalForm這個函數(shù)通過遍歷Relay IR中的function,將每個function轉換為基本塊形式(即ToBasicBlockNormalFormAux這個函數(shù)),ToBasicBlockNormalFormAux這個函數(shù)分成以下幾個部分:
調用 DependencyGraph dg = DependencyGraph::Create(&arena, e)創(chuàng)建一個DependencyGraph,這個數(shù)據(jù)結構是一個表達式相互依賴的圖結構。通過 std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg)計算每個節(jié)點的scope,這個scope可以簡單理解為由跳轉指令如Ifnode,F(xiàn)unctionNode,LetNode等隔開的那些子圖,因為一旦碰到這些節(jié)點在上面通過Relay Function創(chuàng)建DependencyGraph就會為這種節(jié)點分配一個new_scope標志。然后CalcScope這個函數(shù)具體做了哪些事情,我們需要跟進去看一下:
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg) {
NodeScopeMap expr_scope;
ExprSet lifted_exprs;
std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;
// 首先讓每個節(jié)點都屬于一個單獨的scope
for (auto expr_node : dg.expr_node) {
node_to_expr[expr_node.second] = expr_node.first;
}
bool global_scope_used = false;
Scope global_scope = std::make_shared<ScopeNode>();
// 使用LCA算法來更新每個節(jié)點的真正scope
for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) {
DependencyGraph::Node* n = *it;
auto iit = n->parents.head;
Scope s;
if (iit == nullptr) {
ICHECK(!global_scope_used);
s = global_scope;
global_scope_used = true;
} else {
s = expr_scope.at(iit->value);
const auto original_s = s;
iit = iit->next;
for (; iit != nullptr; iit = iit->next) {
s = LCA(s, expr_scope.at(iit->value));
}
if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {
// filter out exprs whose scope do not matter
Expr expr = node_to_expr[n];
if (!expr.as<OpNode>()) {
lifted_exprs.insert(expr);
}
}
}
if (n->new_scope) {
auto child_scope = std::make_shared<ScopeNode>(s);
expr_scope.insert({n, child_scope});
} else {
expr_scope.insert({n, s});
}
}
ICHECK(global_scope_used);
return std::make_pair(expr_scope, lifted_exprs);
}
這個函數(shù)首先讓每個節(jié)點都屬于一個單獨的scope,然后使用LCA算法來更新每個節(jié)點的真正scope。這里簡單介紹一下LCA算法以及這里具體是如何求取每個節(jié)點的scope的。
最近公共祖先簡稱 LCA(Lowest Common Ancestor)。兩個節(jié)點的最近公共祖先,就是這兩個點的公共祖先里面,離根最遠的那個。為了方便,我們記某點集 的最近公共祖先為 或 。LCA有以下性質,引自OI-wiki:
; 是 的祖先,當且僅當 ; 如果 不為 的祖先并且 不為 的祖先,那么 分別處于 的兩棵不同子樹中; 前序遍歷中, 出現(xiàn)在所有 中元素之前,后序遍歷中 則出現(xiàn)在所有 中元素之后; 兩點集并的最近公共祖先為兩點集分別的最近公共祖先的最近公共祖先,即 ; 兩點的最近公共祖先必定處在樹上兩點間的最短路上; ,其中 是樹上兩點間的距離, 代表某點到樹根的距離。
其實不看這個性質也沒關系,了解LCA可以求圖中兩個節(jié)點的最近公共祖先即可。然后CalcScope這個函數(shù)的具體思路就是先將每個節(jié)點初始化為一個單獨的scope,然后按照后DFS序遍歷這些節(jié)點,對于每一個遍歷到的節(jié)點(這里記作n),看一下它的父親節(jié)點iit是否存在,如果不存在則說明當前節(jié)點是根節(jié)點,它的scope應該為global_scope。如果iit存在,那么遍歷iit的子節(jié)點,看一下這些節(jié)點的scope的LCA表達式,如果這個通過LCA求出來的表達式和iit節(jié)點的表達式完全相同,說明這個子圖和當前節(jié)點是屬于同一個scope的,否則就將當前節(jié)點插入到lifted_exprs,lifted_exprs是一個集合用來保存這個DependencyGraph里面的那些跳轉指令節(jié)點,這也是為什么上面再插入節(jié)點到lifted_exprs之前需要判斷一下這個節(jié)點的類型是否為OpNode。另外如果當前枚舉的節(jié)點有new_scope標志,說明當前節(jié)點屬于一個新的scope,需要為當前節(jié)點分配新的類型為ScopeNode的一個智能指針。
通過上面的算法,DependencyGraph中的節(jié)點和scope節(jié)點的關系就被映射到了一個map中,并且scope節(jié)點也被建立起了一個樹結構。最后調用這個Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);來創(chuàng)建一個Fill類,這個類包含了DependencyGraph以及scope相關的信息,通過ToBasicBlockNormalForm成員函數(shù)實現(xiàn)基本塊轉換。它的實現(xiàn)在tvm/src/relay/transforms/to_a_normal_form.cc這個文件中,沒有看得太懂,感興趣的讀者可以自己跟進來看一下,知乎的moon博主對這個Pass也做了解釋,這里引用一下:
它(
ToBasicBlockNormalForm)的基本邏輯通過VisitExpr函數(shù)遍歷dependency節(jié)點,將具有相同scope的節(jié)點壓入到同一個let_list中。Let_list文檔中是這樣解釋的:
/*!
* \file let_list.h
* \brief LetList record let binding and insert let expression implicitly.
* using it, one can treat AST as value instead of expression,
* and pass them around freely without fear of AST explosion (or effect duplication).
* for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
* if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
* the AST will contain 2 'a', as b and c are now variables.
Let_list使得抽象語法樹簡潔化,不會因為變量的復制導致樹的爆炸。具有相同的scope的expr被約束到相同的let_list中,用一個var來表達,這樣就將表達式轉化為var的形式。一個var也就對應了一個基本塊。
EliminateCommonSubexpr
最后再看一個消除公共子表達式的Pass,所謂公共子表達式指的就是具有相同的OP類型以及相同的參數(shù),并且參數(shù)的順序都是完全相同的,那么這些表達式就可以合成一個公共子表達式。舉個例子:
a = b + cd = b + c
可以看到這兩個表達式時完全一致的,那么經過這個Pass之后計算圖就會消除其中一個表達式。代碼實現(xiàn)在:tvm/src/relay/transforms/eliminate_common_subexpr.cc。這里定義了一個CommonSubexprEliminator類,這個類重載了兩個Rewrite_函數(shù)來對expr進行遍歷和重寫。代碼實現(xiàn)如下:
Expr Rewrite_(const CallNode* call, const Expr& post) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr new_expr = post;
const CallNode* new_call = new_expr.as<CallNode>();
ICHECK(new_call);
const OpNode* op = new_call->op.as<OpNode>();
StructuralEqual attrs_equal;
if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
return new_expr;
}
if (fskip_ != nullptr && fskip_(new_expr)) {
return new_expr;
}
auto it = expr_map_.find(new_call->op);
if (it != expr_map_.end()) {
for (const Expr& candidate_expr : it->second) {
if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
bool is_equivalent = true;
// attrs匹配
if (!attrs_equal(new_call->attrs, candidate->attrs)) {
continue;
}
// args匹配
for (size_t i = 0; i < new_call->args.size(); i++) {
if (!new_call->args[i].same_as(candidate->args[i]) &&
!IsEqualScalar(new_call->args[i], candidate->args[i])) {
is_equivalent = false;
break;
}
}
if (!is_equivalent) continue;
return GetRef<Call>(candidate);
}
}
}
expr_map_[new_call->op].push_back(new_expr);
return new_expr;
}
可以看到大概的思路就是利用expr_map_這個std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;來映射遍歷過的具有相同op的expr,然后每次碰到相同op的表達式都會對已經記錄的expr進行匹配,匹配不僅包含OP的attrs屬性還包含參數(shù)列表,如果它們完全一樣說明這兩個表達式就是公共表達式,就不返回新的表達式。這樣就可以去掉Relay Function中的公共表達式了。
到這里可能還不是特別清楚我們最開始加載的那個simplenet的Relay Function經過一些Pass之后具體變成什么樣,我其實目前也還沒搞清楚這個問題,這個問題應該就需要留到后面再解答了。
0x5. 小結
這篇文章主要介紹了一下TVM的Relay并介紹了如何基于Relay構建一個Conv+BN+ReLU的小網(wǎng)絡,然后介紹了一下TVM中的Pass的工作機制,并較為詳細的介紹了RemoveUnusedFunctions,ToBasicBlockNormalForm,EliminateCommonSubexpr三種Pass。其中Relay部分的詳細介紹大部分引用自官方文檔:https://tvm.apache.org/docs/tutorials/get_started/introduction.html。
0x6. 參考資料
https://zhuanlan.zhihu.com/p/358437531 https://zhuanlan.zhihu.com/p/91283238 https://tvm.apache.org/docs/tutorials/get_started/introduction.html
歡迎關注GiantPandaCV, 在這里你將看到獨家的深度學習分享,堅持原創(chuàng),每天分享我們學習到的新鮮知識。( ? ?ω?? )?
有對文章相關的問題,或者想要加入交流群,歡迎添加BBuf微信:
