PyTorch Source Code Walkthrough: the JIT
Preface
Starting from version 1.0, torch has shipped a jit module, which roughly comprises the following parts: a new intermediate representation (IR) for computation graphs, referred to simply as the IR below; two ways of exporting the IR from Python code, namely trace and script; and IR optimizations plus an interpreter for the IR (which translates it into concrete ops).
This walkthrough is split into the following parts:
1 a brief introduction to jit and usage examples of the two export methods
2 the form of the IR in jit
3 a source-code walkthrough of trace and script, the two ways of exporting IR
4 a brief introduction to IR optimization
1 A brief introduction to jit and usage examples
About JIT
As hinted in the preface, although the title says JIT, the part that truly deserves the name just-in-time compiler comes after the IR has been exported: optimizing the IR computation graph and interpreting it into the corresponding operations. That is, the optimizations brought by PyTorch's jit code are graph-level optimizations, such as fusing certain operations; individual operators (e.g. convolution) receive no special optimization and still call into torch's base operator library.
After exporting the IR, i.e. TorchScript, you can also plug in other compiler optimizations or interpreters; for instance there are now solutions that compile a script to a TensorRT engine, such as TRTorch.
trace
Here is a simple example.
import torch
import torchvision.models as models

resnet = torch.jit.trace(models.resnet18(), torch.rand(1, 3, 224, 224))
output = resnet(torch.ones(1, 3, 224, 224))
print(output)
resnet.save('resnet.pt')
Here resnet is the traced module, i.e. the intermediate representation we exported; it can be saved and then used in other environments.
Let's look at the IR it contains, i.e. what the computation graph represented by TorchScript looks like (printed via resnet.graph):
graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
%1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
%1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
%1468 : __torch__.torch.nn.modules.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
%1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
....
%1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
%1202 : int = prim::Constant[value=1]()
%1203 : int = prim::Constant[value=-1]()
%input : Float(1:512, 512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203)
%1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
return (%1557)
That is how the trace method is used. Its core entry point is torch.jit.trace, whose arguments are the model you want to export and a valid example input. It works much as its name suggests: it traces the model's inference pass, records one by one every operation the model applies to the input, maps each to the corresponding IR operation, and thereby obtains the IR of the original model's forward.
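Besides .graph, a traced (or scripted) module also exposes a .code attribute, which renders the same IR back as Python-like TorchScript source and is often easier to read; continuing the resnet example above:

print(resnet.code)
# prints something like:
# def forward(self, input: Tensor) -> Tensor:
#     ... (the same graph, rendered as Python-like code)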
Note: this implementation has an obvious flaw. PyTorch is a dynamic-graph framework, so models contain plenty of input-dependent control flow: execution can differ from input to input (an if branch, or a loop of varying length), and in such cases trace cannot capture the complete computation graph. Below is a case where trace fails:
def test(x):
    if x > 2.0:
        r = torch.tensor(1.0)
    else:
        r = torch.tensor(2.0)
    return r

ftrace = torch.jit.trace(test, (torch.ones(1)))
y = torch.ones(1) * 5
print(ftrace(y))
# result: tensor(2.)
# because the example input only ever exercised the else branch
script
@torch.jit.script
def foo(x, y):
if x.max() > y.max():
r = x
else:
r = y
return r
print(foo.graph)
print(foo(torch.Tensor([0]), torch.Tensor([1])))
print(foo(torch.Tensor([1]), torch.Tensor([0])))
graph(%x.1 : Tensor,
%y.1 : Tensor):
%3 : Tensor = aten::max(%x.1)
%5 : Tensor = aten::max(%y.1)
# you can see the control statement was indeed captured
%6 : Tensor = aten::gt(%3, %5)
%7 : bool = aten::Bool(%6)
%r : Tensor = prim::If(%7)
block0():
-> (%x.1)
block1():
-> (%y.1)
return (%r)
tensor([1.])
tensor([1.])
To use script, you attach the torch.jit.script decorator wherever you need it (on a function or an nn.Module; by default the forward function is followed). Its conversion takes a completely different approach from trace: script parses your PyTorch code directly, resolving your logic into a syntax tree through syntactic analysis, and then converts that into the IR.
Note: although this solves trace's inability to follow dynamic logic, Python is an extremely flexible language, and fully supporting the parsing of every Python construct is close to impossible. So we need to spend extra time learning which constructs can be parsed, which considerably degrades the coding experience.
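A common example of such a limitation (an illustrative sketch, not from the original article): calls into third-party Python libraries such as numpy cannot be compiled, but torch.jit.ignore lets script leave that call to the regular Python interpreter:

import torch
import numpy as np

@torch.jit.ignore
def fetch_threshold() -> float:
    # runs in plain Python; the compiler only sees the annotated signature
    return float(np.random.rand())

@torch.jit.script
def clip(x):
    t = fetch_threshold()   # compiled code calls back into Python here
    return torch.clamp(x, max=t)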
Combining the two
Each has its strengths, and they can be combined flexibly.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
# torch.jit.trace produces a ScriptModule's conv1 and conv2
self.conv1 = torch.jit.trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = torch.jit.trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
def forward(self, input):
input = F.relu(self.conv1(input))
input = F.relu(self.conv2(input))
return input
scripted_module = torch.jit.script(MyModule())
So in practice the following rules of thumb apply (see the sketch after this list for rule 3):
1 in most cases the model contains only tensor operations, and you can simply trace it
2 with control flow (if-else, for-loop), use scripting
3 when you hit syntax that scripting cannot handle, either rewrite the code, or combine tracing and scripting (e.g. use scripting only for the code containing control flow, and tracing for the rest)
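A minimal sketch of rule 3 (module and function names here are made up for illustration): script only the control-flow helper; when a scripted function is called during tracing, it is embedded in the trace as a call instead of being frozen into one branch:

import torch
import torch.nn as nn

@torch.jit.script
def choose(x, y):
    # data-dependent branch: must be scripted, not traced
    if x.sum() > y.sum():
        return x
    return y

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
    def forward(self, x, y):
        return choose(self.fc(x), self.fc(y))

traced = torch.jit.trace(Net(), (torch.rand(2, 4), torch.rand(2, 4)))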
How to extend it
Neither trace nor script can convert functions from third-party Python libraries, so try to implement all code in PyTorch. A custom op must be registered as a jit operation (torch's own ops are registered the same way) before it can end up in TorchScript:
TORCH_LIBRARY(my_ops, m) {
m.def("warp_perspective", warp_perspective);
}
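Once the shared library containing this registration is compiled, it can be loaded from Python, after which the op is callable (and traceable/scriptable) under the torch.ops namespace; a sketch following the official tutorial, where the .so path is a placeholder:

import torch

# loading the library runs TORCH_LIBRARY at static-initialization time
torch.ops.load_library("build/libwarp_perspective.so")
out = torch.ops.my_ops.warp_perspective(torch.rand(32, 32), torch.rand(3, 3))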
For more, see the official tutorial:
1 EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS
2 The basic representation of the IR (TorchScript)
How do PyTorch's various constructs (parameters, compute nodes, etc.) map into TorchScript?
The exported IR is composed of the following TorchScript structures.
| Name | Source code | Description |
|---|---|---|
| Modules | module.h | Counterpart of nn.Module |
| Parameters | module.h | Counterpart of PyTorch parameters |
| Method | Method.h | Contains a FunctionSchema describing the method, the Graph holding the actual computation, and a GraphExecutor that does the optimization and execution |
| FunctionSchema | function_schema.h | Describes argument and return types |
| Graph | ir.h | Defines the concrete implementation of a function, made up of Nodes, Blocks, and Values |
| Nodes | ir.h | One instruction, e.g. a convolution or a matrix multiplication |
| Block | ir.h | A control construct (if, loop) plus its list of nodes |
There are also With, Value, Type, and so on.
# %x.1 is a Value
graph(%x.1 : Tensor,
%y.1 : Tensor):
# aten::max is a Node
# Tensor: its Type is TensorType
%3 : Tensor = aten::max(%x.1)
%5 : Tensor = aten::max(%y.1)
%6 : Tensor = aten::gt(%3, %5)
%7 : bool = aten::Bool(%6)
%r : Tensor = prim::If(%7)
# Blocks
block0():
-> (%x.1)
block1():
-> (%y.1)
return (%r)
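These structures are exposed to Python as well, so a graph can be walked programmatically; a small sketch, reusing the scripted foo from above:

# iterate over the Nodes of foo's Graph
for node in foo.graph.nodes():
    print(node.kind())           # e.g. aten::max, prim::If
    for blk in node.blocks():    # the prim::If node owns two Blocks
        print('  block with', len(list(blk.nodes())), 'nodes')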
3 The two ways of exporting IR: trace and script
Because the actual implementation is quite involved, the source code pasted below keeps only the branches that simple cases run through and omits the vast majority of details; readers who need more can consult the source directly.
The trace implementation
def trace(
    func,
example_inputs,
optimize=None,
check_trace=True,
check_inputs=None,
check_tolerance=1e-5,
strict=True,
_force_outplace=False,
_module_class=None,
_compilation_unit=_python_cu,
):
# func is an nn.Module instance: trace its forward
if isinstance(func, torch.nn.Module):
return trace_module(
func,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
# what was passed in is the forward of some module instance
if (
hasattr(func, "__self__")
and isinstance(func.__self__, torch.nn.Module)
and func.__name__ == "forward"
):
return trace_module(
func.__self__,
{"forward": example_inputs},
None,
check_trace,
wrap_check_inputs(check_inputs),
check_tolerance,
strict,
_force_outplace,
_module_class,
)
# an interface for looking up variable names
var_lookup_fn = _create_interpreter_name_lookup_fn(0)
# the C++ entry point
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict,_force_outplace
)
# check whether the traced function diverges from the original func
if check_trace:
if check_inputs is not None:
_check_trace(
check_inputs,
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
else:
_check_trace(
[example_inputs],
func,
traced,
check_tolerance,
strict,
_force_outplace,
False,
_module_class,
)
return traced
After these simple dispatch checks, the code enters the C++ function:
traced = torch._C._create_function_from_trace(
name, func, example_inputs, var_lookup_fn, strict, _force_outplace
)
Let's look at what happens on the C++ side:
std::pair<std::shared_ptr<TracingState>, Stack> trace(
Stack inputs,
const std::function<Stack(Stack)>& traced_fn,
std::function<std::string(const Variable&)> var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self) {
try {
auto state = std::make_shared<TracingState>();
// setTracingState stores this state instance; later, when compute nodes
// run, they fetch it back and insert the computation into it
setTracingState(state);
// the state structure records the traced computation during forward
if (self) {
Value* self_value = state->graph->insertInput(0, "self")->setType(
self->_ivalue()->type());
gatherParametersAndBuffers(state, self_value, *self, {"__module"});
}
for (IValue& input : inputs) {
input = addInput(state, input, input.type(), state->graph->addInput());
}
auto graph = state->graph;
// bind the Python variable-name resolution function
getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
getTracingState()->strict = strict;
getTracingState()->force_outplace = force_outplace;
// start forward; as each computation happens it is recorded into state
auto out_stack = traced_fn(inputs);
// Exit a trace, treating 'out_stack' as the outputs of the trace. These
// are the variables whose values will be computed upon subsequent
// invocations of the trace.
size_t i = 0;
for (auto& output : out_stack) {
// NB: The stack is in "reverse" order, so when we pass the diagnostic
// number we need to flip it based on size.
state->graph->registerOutput(
state->getOutput(output, out_stack.size() - i));
i++;
}
setTracingState(nullptr);
if (getInlineEverythingMode()) {
Inline(*graph);
}
FixupTraceScopeBlocks(graph, self);
NormalizeOps(graph);
return {state, out_stack};
} catch (...) {
tracer::abandon();
throw;
}
}
So where does the actual recording of each operation take place? In
pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp
Operator createOperatorFromC10_withTracingHandledHere(
const c10::OperatorHandle& op) {
return Operator(op, [op](Stack& stack) {
const auto input_size = op.schema().arguments().size();
const auto output_size = op.schema().returns().size();
Node* node = nullptr;
std::shared_ptr<jit::tracer::TracingState> tracer_state;
// trace the input before unwrapping, otherwise we may lose
// the input information
if (jit::tracer::isTracing()) {
// fetch the tracer_state
tracer_state = jit::tracer::getTracingState();
auto symbol = Symbol::fromQualString(op.schema().name());
const auto& graph = tracer::getTracingState()->graph;
node = graph->create(symbol, 0);
tracer::recordSourceLocation(node);
const auto& args = op.schema().arguments();
int i = 0;
// record the args
for (auto iter = stack.end() - input_size; iter != stack.end();
++iter, ++i) {
// TODO we need to refactor graph APIs (e.g., addInputs)
// appropriately; after that, we can get rid of the giant if-else
// block we will clean this tech debt together in the following PRs
auto type = args[i].type();
if (type->kind() == TypeKind::OptionalType) {
if (iter->isNone()) {
Value* none = graph->insertNode(graph->createNone())->output();
node->addInput(none);
continue;
} else {
type = type->expect<OptionalType>()->getElementType();
}
}
if (type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensor());
tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
} else if (type->kind() == TypeKind::FloatType) {
AT_ASSERT(iter->isDouble());
tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
} else if (type->kind() == TypeKind::IntType) {
AT_ASSERT(iter->isInt());
tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
} else if (type->kind() == TypeKind::BoolType) {
AT_ASSERT(iter->isBool());
tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
} else if (type->kind() == TypeKind::StringType) {
AT_ASSERT(iter->isString());
tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
} else if (type->kind() == TypeKind::NumberType) {
tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
} else if (type->kind() == TypeKind::ListType) {
const auto& elem_type = type->expect<ListType>()->getElementType();
if (elem_type->isSubtypeOf(TensorType::get())) {
AT_ASSERT(iter->isTensorList());
auto list = iter->toTensorVector();
tracer::addInputs(node, args[i].name().c_str(), list);
} else if (elem_type->kind() == TypeKind::FloatType) {
AT_ASSERT(iter->isDoubleList());
// NB: now, tracer doesn't support tracing double list. We add
// special handling here, since in our case, we assume that all the
// doubles in the list are constants
auto value = iter->toDoubleVector();
std::vector<Value*> info(value.size());
for (size_t value_index = 0; value_index < value.size();
++value_index) {
info[value_index] = graph->insertConstant(value[value_index]);
tracer::recordSourceLocation(info[value_index]->node());
}
node->addInput(
graph
->insertNode(graph->createList(jit::FloatType::get(), info))
->output());
} else if (elem_type->kind() == TypeKind::IntType) {
AT_ASSERT(iter->isIntList());
tracer::addInputs(
node, args[i].name().c_str(), iter->toIntVector());
} else if (elem_type->kind() == TypeKind::BoolType) {
AT_ASSERT(iter->isBoolList());
tracer::addInputs(
node, args[i].name().c_str(), iter->toBoolList().vec());
} else {
throw std::runtime_error(
"unsupported input list type: " + elem_type->str());
}
} else if (iter->isObject()) {
tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
} else {
throw std::runtime_error("unsupported input type: " + type->str());
}
}
// insert the node into the graph
graph->insertNode(node);
jit::tracer::setTracingState(nullptr);
}
As you can see, when a concrete computation happens, getTracingState() retrieves the state created at the start of forward; the kind of computation (say, an addition) is read from op.schema().name(), a compute node of that kind is created with graph->create, the node's inputs are wired up, and finally the node is inserted into the graph, completing the record of one computation.
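To make the mechanism concrete, here is a toy, purely illustrative Python sketch of the same idea (none of these names exist in PyTorch): a global recording state is set, every intercepted op appends a node while still computing its result, and the state is detached at the end:

# toy model of trace-by-recording; every name here is made up
_state = None

def set_tracing_state(s):
    global _state
    _state = s

def traced_add(a, b):             # stand-in for an intercepted aten op
    out = a + b                   # the real computation still happens
    if _state is not None:        # like jit::tracer::isTracing()
        _state.append(('aten::add', id(a), id(b), id(out)))
    return out

def trace(fn, *inputs):
    set_tracing_state([])         # like setTracingState(state)
    out = fn(*inputs)             # forward runs; ops record themselves
    graph = _state
    set_tracing_state(None)       # like setTracingState(nullptr)
    return graph, out

graph, out = trace(lambda a, b: traced_add(traced_add(a, b), b), 1, 2)
print(graph)                      # two recorded aten::add nodes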
script
Because script obtains the IR by parsing source code, the flow differs slightly for different code forms (a function, a class, an nn.Module instance). 1 A Python function; simplified code:
def script(obj, optimize=None, _frames_up=0, _rcb=None):
# the function branch
if hasattr(obj, "__script_if_tracing_wrapper"):
obj = obj.__original_fn
_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
# check for directly compiled overloads
_check_directly_compile_overloaded(obj)
# was it scripted before?
maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
if maybe_already_compiled_fn:
return maybe_already_compiled_fn
# obtain the AST
ast = get_jit_def(obj, obj.__name__)
if _rcb is None:
_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
# C++ entry point: produce the IR from the AST
fn = torch._C._jit_script_compile(
qualified_name, ast, _rcb, get_default_args(obj)
)
# Forward docstrings
fn.__doc__ = obj.__doc__
# cache it
_set_jit_function_cache(obj, fn)
return fn
Let's see how get_jit_def produces the AST in the form jit prescribes (only the logic is kept; details are removed):
def get_jit_def(fn, def_name, self_name=None):
# gather some information about the source code
sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
sourcelines = normalize_source_lines(sourcelines)
source = ''.join(sourcelines)
# dedent_src is the dedented string containing the function to script
dedent_src = dedent(source)
# call Python's ast package to parse the string into a Python AST
py_ast = ast.parse(dedent_src)
# get the Python type comment line
type_line = torch.jit.annotations.get_type_line(source)
# ctx holds all the function's source metadata
ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
fn_def = py_ast.body[0]
# build_def converts the Python AST into the AST format torch.jit uses
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)
Let me use a simple example to explain what py_ast.body[0] is:
import ast

func_def = """def test(a):
    a = a + 2
    return a + 1"""
results = ast.parse(func_def)
(Figure: the Python AST produced by ast.parse)
As the figure shows, ast.body is a list whose length equals the number of functions contained in the parsed string. Looking at the first element: its value is a BinOp, concretely an Add, whose left is a Name with id a and whose right is a Num, namely 2; this BinOp is the parse of a = a + 2.
Because get_source_lines_and_file always returns a single top-level function, we can simply take element 0, i.e. py_ast.body[0].
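If the figure is unavailable, ast.dump reproduces the same information in text form (output abridged; the exact field layout varies across Python versions):

print(ast.dump(results.body[0]))
# FunctionDef(name='test', args=..., body=[
#   Assign(targets=[Name(id='a')],
#          value=BinOp(left=Name(id='a'), op=Add(), right=Num(n=2))),
#   Return(value=BinOp(left=Name(id='a'), op=Add(), right=Num(n=1)))])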
Next, let's look at how build_def converts the Python AST into the AST it needs.
Entering build_def:
def build_def(ctx, py_def, type_line, def_name, self_name=None):
....
return Def(Ident(r, def_name),
decl,
build_stmts(ctx, body))
Since ctx contains all the source-code information and body is the Python AST parse result, build_stmts should hold the answer we are after.
Let's take a + 2 from the example and see how it is converted; this part can be found in frontend.py.
About StmtBuilder:
from torch._C._jit_tree_views import (
ClassDef, Ident, Stmt, Decl, Def, Var,
EmptyTypeAnnotation, Param, ExprStmt, Assign,
Delete, Return, Raise, Assert, AugAssign, While,
For, If, Pass, Break, Continue, Apply, Dots, Select,
TrueLiteral, FalseLiteral, NoneLiteral, Starred,
ListLiteral, TupleLiteral, DictLiteral, Const,
StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
DictComp,
)
# the basic AST structures defined in jit
def build_stmts(ctx, stmts):
# we find that it calls build_stmt on each statement
stmts = [build_stmt(ctx, s) for s in stmts]
return list(filter(None, stmts))
# build_stmt is an instance of StmtBuilder()
build_stmt = StmtBuilder()
build_expr = ExprBuilder()
class Builder(object):
def __call__(self, ctx, node):
# dispatch to the build method matching the parsed AST node's type;
# the figure showed that a = a + 2 is an Assign, so build_Assign is called
method = getattr(self, 'build_' + node.__class__.__name__, None)
if method is None:
raise UnsupportedNodeError(ctx, node)
return method(ctx, node)
class StmtBuilder(Builder):
@staticmethod
def build_Assign(ctx, stmt):
# the figure showed that stmt.value is a BinOp
# build_expr is an instance of ExprBuilder, and it will call build_BinOp
rhs = build_expr(ctx, stmt.value)
lhs = [build_expr(ctx, x) for x in stmt.targets]
return Assign(lhs, rhs)
@staticmethod
def build_Expr(ctx, stmt):
# Binop
value = stmt.value
if value.__class__.__name__ == 'Str':
# If a statement is a string literal expression,
# then it is a docstring. Just ignore it.
return None
else:
return ExprStmt(build_expr(ctx, value))
class ExprBuilder(Builder):
binop_map = {
ast.Add: '+',
ast.Sub: '-',
ast.Mult: '*',
ast.Div: '/',
ast.Pow: '**',
ast.Mod: '%',
ast.FloorDiv: '//',
ast.BitAnd: '&',
ast.BitXor: '^',
ast.BitOr: '|',
ast.LShift: '<<',
ast.RShift: '>>',
}
@staticmethod
def build_BinOp(ctx, expr):
# expr.left is a Name, so this calls build_Name
lhs = build_expr(ctx, expr.left)
rhs = build_expr(ctx, expr.right)
op = type(expr.op)
# map to the agreed string token representing the operation
op_token = ExprBuilder.binop_map.get(op)
return BinOp(op_token, lhs, rhs)
The final converted format resembles an S-expression:
(def
(ident test)
(decl
(list
(param
(ident a)
(option)
(option)
(False)))
(option))
(list
(assign
(list (variable (ident a)))
(option
(+
(variable (ident a))
(const 2)))
(option))
(return
(+
(variable (ident a))
(const 1)))))
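You can reproduce such a tree yourself: the jit tree views print in this S-expression form (a sketch, assuming the test function from the ast.parse example above):

from torch.jit.frontend import get_jit_def

def test(a):
    a = a + 2
    return a + 1

# the returned Def prints as the S-expression shown above
print(get_jit_def(test, test.__name__))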
Good, we now have the AST tree in the form jit prescribes; next we enter torch._C._jit_script_compile to see how such an AST is converted into IR.
The C++ entry point is script_compile_function:
static StrongFunctionPtr script_compile_function(
const c10::QualifiedName& name,
const Def& def,
const FunctionDefaults& defaults,
const ResolutionCallback& rcb) {
// def contains the AST; following it leads to the answer
auto cu = get_python_cu();
// so the define function of the CompilationUnit returned by get_python_cu does the work
auto defined_functions = cu->define(
QualifiedName(name.prefix()),
/*properties=*/{},
/*propResolvers=*/{},
{def},
{pythonResolver(rcb)},
nullptr,
true);
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
auto& defined = defined_functions[0];
defined->setSchema(getSchemaWithNameAndDefaults(
def.range(), defined->getSchema(), def.name().name(), defaults));
StrongFunctionPtr ret(std::move(cu), defined);
didFinishEmitFunction(ret);
return ret;
}
// it turns out to be just a wrapper around CompilationUnit
inline std::shared_ptr<CompilationUnit> get_python_cu() {
return py::module::import("torch.jit._state")
.attr("_python_cu")
.cast<std::shared_ptr<CompilationUnit>>();
}
// about CompilationUnit:
// torch/csrc/jit/api/compilation_unit.h
// for historic reasons, these are defined in ir_emitter.cpp
// Returns the list of Functions just defined.
std::vector<Function*> define(
const c10::optional<c10::QualifiedName>& prefix,
const std::vector<Property>& properties,
const std::vector<ResolverPtr>& propResolvers,
const std::vector<Def>& definitions,
const std::vector<ResolverPtr>&
defResolvers, /* determines how we handle free
variables in each definition*/
// if non-null, the first argument to each def, is bound to this value
const Self* self,
// see [name mangling]
bool shouldMangle = false);
// the implementation is in torch/csrc/jit/frontend/ir_emitter.cpp
std::unique_ptr<Function> CompilationUnit::define(
const c10::optional<QualifiedName>& prefix,
const Def& def,
const ResolverPtr& resolver,
const Self* self,
const std::unordered_map<std::string, Function*>& function_table,
bool shouldMangle) const {
auto _resolver = resolver;
.....
auto creator = [def, _resolver, self](Function& method) {
....
// the core code: to_ir
to_ir(def, _resolver, self, method);
};
auto fn = torch::make_unique<GraphFunction>(
std::move(name), std::make_shared<Graph>(), creator);
return fn;
}
Following def, we find the key struct that performs the conversion to IR: to_ir. Its inputs include def, i.e. the AST, and _resolver, the name-resolution function passed over from Python. We can find the key parts inside it:
to_ir(
const Def& def,
ResolverPtr resolver_,
const Self* self,
Function& method) // method being constructed
: method(method),
graph(method.graph()),
resolver(std::move(resolver_)),
typeParser_(resolver),
environment_stack(nullptr) {
AT_ASSERT(resolver);
pushFrame(graph->block(), /*starts_def=*/true);
// emitDef will call emitStatements
method.setSchema(emitDef(def, self, graph->block()));
ConvertToSSA(graph);
CanonicalizeModifiedLoops(graph);
NormalizeOps(graph);
runCleanupPasses(graph);
}
private:
// among to_ir's private members we see Function and Graph, the IR components introduced earlier
Function& method;
std::shared_ptr<Graph> graph;
ResolverPtr resolver;
std::unordered_map<int64_t, Value*> integral_constants;
// emitDef calls emitStatements
FunctionSchema emitDef(const Def& def, const Self* self, Block* block) {
......
// body
auto stmts_list = def.statements();
emitStatements(stmts_list.begin(), stmts_list.end());
........
}
void emitStatements(
List<Stmt>::const_iterator begin,
List<Stmt>::const_iterator end) {
for (; begin != end; ++begin) {
auto stmt = *begin;
ErrorReport::CallStack::update_pending_range(stmt.range());
switch (stmt.kind()) {
case TK_IF:
emitIf(If(stmt));
break;
case TK_WHILE:
emitWhile(While(stmt));
break;
case TK_FOR:
emitFor(For(stmt));
break;
case TK_ASSIGN:
emitAssignment(Assign(stmt));
.................
break;
default:
throw ErrorReport(stmt)
<< "Unrecognized statement kind " << kindToString(stmt.kind());
}
// Found an exit statement in this block. The remaining statements aren't
// reachable so we don't emit them.
if (exit_blocks.count(environment_stack->block()))
return;
}
}
Depending on stmt.kind(), we enter one of the emit* functions, inside each of which we can always find an operation like
graph->insertNode(graph->create(.....));
which corresponds to building up our IR graph.
Above we used a function as the example. Next let's script a module, which brings some challenges of its own: some of the names it refers to are only known after initialization, and at the same time we want the scripted module to keep the same interface to the outside, i.e. the attributes of the original module should remain accessible as usual. How can that be done?
Right after the module's original init finishes, script the complete forward function and replace every function involved with its scripted counterpart, while keeping normal access to the original attributes.
How do we attach behavior to run after a class's init? A metaclass comes to mind, and torch.jit implements the ScriptMeta metaclass.
class MyModule(torch.jit.ScriptModule):
@torch.jit.script_method
def f(self, x):
return x * x
@torch.jit.script_method
def forward(self, x):
return x + self.f(x)
About script_method:
def script_method(fn):
_rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
# nothing is scripted yet; this just returns a namedtuple carrying the AST
return ScriptMethodStub(_rcb, ast, fn)
ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
1. Remove all script_method attributes (methods decorated with @script_method), so that what gets accessed is the scripted function.
2. Modify the module's __init__ so that, as soon as the module's self.param / self.module initialization is done, all script_methods are compiled at once; the resulting instance's forward is then already the replaced one.
class ScriptMeta(type):
def __init__(cls, name, bases, attrs): # noqa: B902
# cls is an instance of ScriptMeta, i.e. a class such as ScriptModule
cls._methods: Dict[str, Any] = {}
cls._constants_set = set(getattr(cls, "__constants__", ()))
for base in reversed(bases):
# remember: a traced module also has a _methods attribute
for k, v in getattr(base, "_methods", {}).items():
cls._methods[k] = v
base_constants = getattr(base, "_constants_set", set())
cls._constants_set = cls._constants_set.union(base_constants)
# find all methods currently decorated with @script_method, move them into
# _methods and delete the original attrs; they are scripted together after init
for k, v in sorted(attrs.items()):
if isinstance(v, ScriptMethodStub):
delattr(cls, k)
cls._methods[v.original_method.__name__] = v
original_init = getattr(cls, "__init__", lambda self: None)
# this makes create_script_module run for scripting once init has finished
@functools.wraps(original_init)
def init_then_script(self, *args, **kwargs):
# self here is the instance
num_methods = len(cls._methods)
original_init(self, *args, **kwargs)
added_methods_in_init = len(cls._methods) > num_methods
if type(self) == cls:
# select the methods that need scripting
def make_stubs(module):
cls = type(module)
if hasattr(cls, "_methods"):
return [v for k, v in sorted(cls._methods.items())]
else:
# infer_methods_to_compile is a function that picks which methods to script
return infer_methods_to_compile(module)
# compile all script_methods together into the _actual_script_module attribute
self.__dict__[
"_actual_script_module"
] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
# Delete the Python attributes that now shadow the ScriptModule
# ones, so that __getattr__ and __setattr__ will properly find
# the scripted versions.
concrete_type = self._actual_script_module._concrete_type
for name in concrete_type.get_attributes():
delattr(self, name)
for name, _ in concrete_type.get_modules():
delattr(self, name)
for name in ("_parameters", "_buffers", "_modules"):
delattr(self, name)
cls.__init__ = init_then_script # type: ignore
return super(ScriptMeta, cls).__init__(name, bases, attrs)
class _CachedForward(object):
def __get__(self, obj, cls):
return self.__getattr__("forward") # type: ignore
class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore
def __init__(self):
super(ScriptModule, self).__init__()
forward = _CachedForward()
# accessing an attr of the module returns the attr of _actual_script_module
def __getattr__(self, attr):
if "_actual_script_module" not in self.__dict__:
return super(ScriptModule, self).__getattr__(attr)
return getattr(self._actual_script_module, attr)
def __setattr__(self, attr, value):
if "_actual_script_module" not in self.__dict__:
# Unwrap torch.jit.Attribute into a regular setattr + recording
# the provided type in __annotations__.
#
# This ensures that if we use the attr again in `__init__`, it
# will look like the actual value, not an instance of Attribute.
if isinstance(value, Attribute):
if "__annotations__" not in self.__class__.__dict__:
self.__class__.__annotations__ = {}
self.__annotations__[attr] = value.type
value = value.value
return super(ScriptModule, self).__setattr__(attr, value)
setattr(self._actual_script_module, attr, value)
The create_script_module function scripts the methods and returns a RecursiveScriptModule, but its logic is rather involved, so we will not expand on it here.
About __getattribute__ vs __getattr__
When an instance attribute is accessed, __getattribute__ is called unconditionally; when the attribute does not exist, __getattr__ is called next. If you have not implemented your own __getattr__, an AttributeError is raised saying the attribute cannot be found; if you have defined your own __getattr__, it is called exactly in this attribute-not-found case.
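A short demonstration of that lookup order (illustrative only):

class A:
    def __getattr__(self, name):   # only called when normal lookup fails
        return 'fallback for ' + name

a = A()
a.x = 1
print(a.x)   # 1 -> found by __getattribute__; __getattr__ is not called
print(a.y)   # 'fallback for y' -> normal lookup failed, __getattr__ steps in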
4 A brief introduction to IR optimization
jit typically performs optimizations such as loop unrolling, peephole optimization, constant propagation, dead code elimination (DCE), fusion, inlining, and so on. Consider the following example:
import torch
from time import time

def test(x):
    # dead code elimination
    for i in range(1000):
        y = x + 1
    for i in range(100):
        # peephole optimization
        x = x.t()
        x = x.t()
    return x.sum()

opt_test = torch.jit.script(test)
inputs = torch.ones(4, 4).cuda()
s = time()
for i in range(10000):
    test(inputs)
print(time() - s)
# 95s
s = time()
for i in range(10000):
    opt_test(inputs)
print(time() - s)
# 0.13s
print(opt_test.graph)
print(opt_test.graph_for(inputs))
95.13823795318604
0.13010907173156738
graph(%x.1 : Tensor):
%22 : None = prim::Constant()
%13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
%10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
%x : Tensor = prim::Loop(%10, %13, %x.1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
block0(%i : int, %x.10 : Tensor):
%x.4 : Tensor = aten::t(%x.10) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
%x.7 : Tensor = aten::t(%x.4) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
-> (%13, %x.7)
%23 : Tensor = aten::sum(%x, %22) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
return (%23)
graph(%x.1 : Tensor):
%1 : None = prim::Constant()
%2 : Tensor = aten::sum(%x.1, %1) # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
return (%2)
About optimizing the IR computation graph
A Method in the IR has a built-in GraphExecutor object, created on first execution, which is responsible for the optimization.
In pytorch-master/torch/csrc/jit/api/method.h, inside the C++ prototype of script_method:
GraphExecutor& get_executor() {
return function_->get_executor();
}
GraphExecutor is defined in torch/csrc/jit/runtime/graph_executor.cpp; as you can see it is constructed from a graph and defines a run method for execution:
GraphExecutor::GraphExecutor(
const std::shared_ptr<Graph>& graph,
std::string function_name)
: pImpl(
IsNewExecutorEnabled()
? dynamic_cast<GraphExecutorImplBase*>(
new ProfilingGraphExecutorImpl(
graph,
std::move(function_name)))
: dynamic_cast<GraphExecutorImplBase*>(
new GraphExecutorImpl(graph, std::move(function_name)))) {}
std::shared_ptr<Graph> GraphExecutor::graph() const {
return pImpl->graph;
}
const ExecutionPlan& GraphExecutor::getPlanFor(
Stack& inputs,
size_t remaining_bailout_depth) {
return pImpl->getPlanFor(inputs, remaining_bailout_depth);
}
std::shared_ptr<GraphExecutorImplBase> pImpl;
.....
About GraphExecutorImplBase, in torch/csrc/jit/runtime/graph_executor.cpp:
const ExecutionPlan& getOrCompile(const Stack& stack) {
.....
auto plan = compileSpec(spec);
}
}
// compileSpec returns an ExecutionPlan
ExecutionPlan compileSpec(const ArgumentSpec& spec) {
auto opt_graph = graph->copy();
GRAPH_DUMP("Optimizing the following function:", opt_graph);
arg_spec_creator_.specializeTypes(*opt_graph, spec);
// Phase 0. Inline functions, then clean up any artifacts that the inliner
// left in that may inhibit optimization
.....
runRequiredPasses(opt_graph);
GRAPH_DEBUG(
"After runRequiredPasses, before ConstantPropagation\n", *opt_graph);
// Phase 2. Propagate detailed information about the spec through the
// graph (enabled more specializations in later passes).
// Shape propagation sometimes depends on certain arguments being
// constants, and constant propagation doesn't need shape
// information anyway, so it's better to run it first.
ConstantPropagation(opt_graph);
GRAPH_DEBUG(
"After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
PropagateInputShapes(opt_graph);
GRAPH_DEBUG(
"After PropagateInputShapes, before PropagateRequiresGrad\n",
*opt_graph);
PropagateRequiresGrad(opt_graph);
GRAPH_DEBUG(
"After PropagateRequiresGrad, before runOptimization\n", *opt_graph);
// Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
// that we can still execute using autograd).
runOptimization(opt_graph);
..... (various further optimizations)
return ExecutionPlan(opt_graph, function_name_);
}
These optimization passes live in the torch/csrc/jit/passes/ folder, for example:
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/fuse_linear.cpp
torch/csrc/jit/passes/remove_dropout.cpp
torch/csrc/jit/passes/fold_conv_bn.cpp
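Many of these passes also have Python bindings under torch._C, so you can apply them to a graph by hand; a sketch (the _jit_pass_* bindings are internal and may change across versions):

import torch

@torch.jit.script
def f(x):
    y = x + 1   # dead: y is never used
    return x.sum()

g = f.graph.copy()
torch._C._jit_pass_dce(g)                    # dead code elimination
torch._C._jit_pass_constant_propagation(g)   # constant propagation
print(g)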
References
1. INTRODUCTION TO TORCHSCRIPT
2. PyTorch deployment: TorchScript
3. PyTorch wiki
4. PyTorch-JIT-Source-Code-Read-Note
5. Abstract syntax tree