<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch 源碼解讀之即時編譯篇

          共 84968字,需瀏覽 170分鐘

           ·

          2021-06-13 21:57


          作者丨OpenMMLab
          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/361101354
          編輯丨GiantPandaCV

          前言

          torch 從 1.0 開始支持了 jit 模塊,其大概包括以下幾個部分:

          • 一種新的計算圖中間表示 (Intermediate Representation),之后簡稱為 IR.
          • 從 Python 代碼導出IR的兩種方法,即 trace 與 script.
          • IR 優(yōu)化以及 IR 的解釋器(翻譯為具體的運算 op).

          這篇解讀會分為以下幾個部分:

          • jit 的簡單介紹以及兩種導出方式的使用例子
          • jit 中 IR 的形式
          • 導出 IR 的兩種方式,trace 與 script 的源碼解讀
          • IR 優(yōu)化的簡單介紹

          1 jit 的簡單介紹以及使用例子

          JIT 簡介

          如前言,這篇解讀雖然標題是 JIT,但是真正稱得上即時編譯器的部分是在導出 IR 后,即優(yōu)化 IR 計算圖,并且解釋為對應 operation 的過程,即PyTorch jit 相關 code 帶來的優(yōu)化一般是計算圖級別優(yōu)化,比如部分運算的融合,但是對具體算子(如卷積)是沒有特定優(yōu)化的,其依舊調(diào)用 torch的基礎算子庫.

          大家也可以在導出 IR 也就是 torchscript 后,使用其他的編譯優(yōu)化或者解釋器,如現(xiàn)在也有script to a TensorRT engine,TRTtorch轉 tensorRT 的方案。

          trace

          給大家一個簡單例子。

          import torchvision.models as models
              resnet = torch.jit.trace(models.resnet18(),torch.rand(1,3,224,224))
              output=resnet(torch.ones(1,3,224,224))
              print(output)
              output=resnet(torch.ones(1,3,224,224))
              resnet.save('resnet.pt')

          output 便是我們導出的中間表示,其可以 save 下來,在其他框架使用

          我們可以看下 output 中的 IR,即 torchscript 表征的計算圖是什么樣子的。

          graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
              %input.1 : Float(1:1505283:50176224:224224:1, requires_grad=0, device=cpu)):
              %1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
              %1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
              %1468 : __torch__.torch.nn.modulesjieshao.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
              %1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
              ....
              %1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
              %1202 : int = prim::Constant[value=1]()
              %1203 : int = prim::Constant[value=-1]()
              %input : Float(1:512512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203
              %1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
              return (%1557)

          這便是 trace 方法的使用,其核心實現(xiàn)的入口便是torch.jit.trace,參數(shù)為你需要導出的 model,以及合法輸入input,其大概原理恰如其名,便是跟蹤模型 inference 過程,將模型對輸入進行的操作逐一記錄下來,并對應到 IR 的操作,從而得到原本模型forward 的 IR。

          ote :但是這種實現(xiàn)方式有很明顯的缺陷,PyTorch 作為動態(tài)圖網(wǎng)絡,會有很多的 input dependent的控制流語句,根據(jù)輸入的不同可能會執(zhí)行情況會不同(if 或者 變長的 loop),這樣就無法 trace 到完整的計算圖。如下就是一個 trace

          失敗的 case:

          if x > 2.0:
              r = torch.tensor(1.0)
              else:
               r = torch.tensor(2.0)
              return r
              
          ftrace = torch.jit.trace(test, (torch.ones(1)))
          y = torch.ones(1) * 5
          print(ftrace(y))
          # results: tensor(2.)
          # 因為輸入只走了的分支else

          script

          @torch.jit.script
          def foo(x, y):
              if x.max() > y.max():
                  r = x
              else:
                  r = y
              return r
              
          print(foo.graph)
              
          print(foo(torch.Tensor([0]), torch.Tensor([1])))
          print(foo(torch.Tensor([1]), torch.Tensor([0])))
              
          graph(%x.1 : Tensor,
                %y.1 : Tensor):
            %3 : Tensor = aten::max(%x.1
            %5 : Tensor = aten::max(%y.1
            # 可以看到確實捕捉到了控制語句,
            %6 : Tensor = aten::gt(%3, %5
            %7 : bool = aten::Bool(%6
            %r : Tensor = prim::If(%7
              block0():
                -> (%x.1)
              block1():
                -> (%y.1)
            return (%r)
              
          tensor([1.])
          tensor([1.])

          script 使用是在你需要的地方 (fuction or nn.Module (默認追蹤 forward函數(shù)))掛載裝飾器torch.jit.script,其轉換方式跟 trace 是完全不同的思路,script 直接解析你的 PyTorch代碼,通過語法分析解析你的邏輯為一棵語法樹,然后轉換為中間表示 IR。

          Note: 雖然其可以解決 trace 存在無法追蹤動態(tài)邏輯的問題,但是 Python 作為靈活度極高的語法, 想完整支持解析各種 Python 操作幾乎是不可能的,因此我們需要額外的時間熟悉哪些寫法是可以被解析的,讓我們寫代碼的體驗大打折扣。

          兩者結合

          兩者各有優(yōu)勢,支持靈活集合。

          import torch
          import torch.nn as nn
          import torch.nn.functional as F
              
          class MyModule(nn.Module):
              def __init__(self):
                  super(MyModule, self).__init__()
                  # torch.jit.trace produces a ScriptModule's conv1 and conv2
                  self.conv1 = torch.jit.trace(nn.Conv2d(1205), torch.rand(111616))
                  self.conv2 = torch.jit.trace(nn.Conv2d(20205), torch.rand(1201616))
              
              def forward(self, input):
                  input = F.relu(self.conv1(input))
                  input = F.relu(self.conv2(input))
                  return input
              
          scripted_module = torch.jit.script(MyModule())

          因此實際使用時候,可以有如下準則:

          1 大部分情況 model 只有 tensor operation,就直接無腦 tracing

          2 帶 control-flow (if-else, for-loop) 的,上 scripting

          3 碰上 scripting 不能 handle 的語法,要么重寫,要么把 tracing 和 scripting 合起來用(比如說只在有 control-

          flow 的代碼用 scripting,其他用 tracing)

          如何擴展

          trace 與 script 都不能轉換第三方 Python 庫中的函數(shù),盡量所有代碼都使用 PyTorch 實現(xiàn), 自定義 op 需要注冊成 jit

          操作( torch 的 op 其實也注冊了),最后轉成 torchscript。

              TORCH_LIBRARY(my_ops, m) {
                m.def("warp_perspective", warp_perspective);
              }

          更多可以參考官方教程

          1 EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS

          2 IR (torchscript)的基本表示

          PyTorch 中的各種設計(parameter,計算節(jié)點等)在 torchscript 中是如何對應的呢?

          這便是轉換出的 IR 結果,torchscrip 以下結構組合。

          名稱source code簡介
          Modulesmodule.h對標 nn.Module
          Parametersmodule.h對標 PyTorch 的 parameter
          MethodMethod.h包括 FunctionSchema 方法描述,Graph 實際計算圖,GraphExecutor do the optimization and execution
          FunctionSchemafunction_schema.h描述參數(shù)與返回類型
          Graphir.h定義 function 的具體實現(xiàn),包括 Nodes,Blocks,Values
          Nodesir.h一個指令,如一次卷積運算,一次矩陣運算
          Blockir.h控制語句 if,loop + list of nodes

          還有with,Value,Type

              # %x.1 value
              graph(%x.1 : Tensor,
                    %y.1 : Tensor):
                    # aten::max 就是一個Node
                    # Tensor: Type-TensorType
                %3 : Tensor = aten::max(%x.1
                %5 : Tensor = aten::max(%y.1
                %6 : Tensor = aten::gt(%3, %5
                %7 : bool = aten::Bool(%6
                %r : Tensor = prim::If(%7
                 # Blocks 
                  block0():
                    -> (%x.1)
                  block1():
                    -> (%y.1)
                return (%r)

          3 導出 IR 的兩種方式,trace 與 script

          因為其具體實現(xiàn)頗為復雜,粘貼的源碼也僅僅保留了簡單 case 跑過的分支,并且省去了絕大部分細節(jié),讀者如有需要更多細節(jié)可以自行去源碼查閱。

          trace 實現(xiàn)

              func,
                  example_inputs,
                  optimize=None,
                  check_trace=True,
                  check_inputs=None,
                  check_tolerance=1e-5,
                  strict=True,
                  _force_outplace=False,
                  _module_class=None,
                  _compilation_unit=_python_cu,
              ):


                  # 發(fā)現(xiàn)是nn.Module instacene forward, 追蹤forward
                  if isinstance(func, torch.nn.Module):
                      return trace_module(
                          func,
                          {"forward": example_inputs},
                          None,
                          check_trace,
                          wrap_check_inputs(check_inputs),
                          check_tolerance,
                          strict,
                          _force_outplace,
                          _module_class,
                      )
                  # 傳進來的是某個module instance的forward
                  if (
                      hasattr(func, "__self__")
                      and isinstance(func.__self__, torch.nn.Module)
                      and func.__name__ == "forward"
                  ):
                      return trace_module(
                          func.__self__,
                          {"forward": example_inputs},
                          None,
                          check_trace,
                          wrap_check_inputs(check_inputs),
                          check_tolerance,
                          strict,
                          _force_outplace,
                          _module_class,
                      )
                  # 一個查找變量名的接口
                  var_lookup_fn = _create_interpreter_name_lookup_fn(0)
              
                 # C++ 入口 
                 traced = torch._C._create_function_from_trace(
                     name, func, example_inputs, var_lookup_fn, strict,_force_outplace
                  )
              
                  # 檢查traced 與 原func是否有差異
                  if check_trace:
                      if check_inputs is not None:
                          _check_trace(
                              check_inputs,
                              func,
                              traced,
                              check_tolerance,
                              strict,
                              _force_outplace,
                              False,
                              _module_class,
                          )
                      else:
                          _check_trace(
                              [example_inputs],
                              func,
                              traced,
                              check_tolerance,
                              strict,
                              _force_outplace,
                              False,
                              _module_class,
                          )
              
                  return traced

          我們發(fā)現(xiàn)經(jīng)過簡單的判斷,代碼便進入了 C++ 相關函數(shù)

              traced = torch._C._create_function_from_trace(
                      name, func, example_inputs, var_lookup_fn, strict, _force_outplace
                  )

          我們?nèi)?C++ 中看下發(fā)生了什么

              std::pair<std::shared_ptr<TracingState>, Stack> trace(
                  Stack inputs,
                  const std::function<Stack(Stack)>& traced_fn,
                  std::function<std::string(const Variable&)> var_name_lookup_fn,
                  bool strict,
                  bool force_outplace,
                  Module* self)
           
          {
                try {
              
                  auto state = std::make_shared<TracingState>();
                  # setTracingState 將state 這個實例set下來,在之后計算節(jié)點get出來insert計算過程
                  setTracingState(state);
              
                  #state這個數(shù)據(jù)結構會在forward過程中存儲trace到的計算過程
                  if (self) {
                    Value* self_value = state->graph->insertInput(0"self")->setType(
                        self->_ivalue()->type());
                    gatherParametersAndBuffers(state, self_value, *self, {"__module"});
                  }
              
                  for (IValue& input : inputs) {
                    input = addInput(state, input, input.type(), state->graph->addInput());
                  }
                  auto graph = state->graph;
                  # 將python中的變量名解析函數(shù)綁定下來
                  getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
                  getTracingState()->strict = strict;
                  getTracingState()->force_outplace = force_outplace;
              
                  # 開始forward,在計算發(fā)生時,會把計算記錄到state中
                  auto out_stack = traced_fn(inputs);
              
                  // Exit a trace, treating 'out_stack' as the outputs of the trace.  These
                  // are the variables whose values will be computed upon subsequent
                  // invocations of the trace.
                  size_t i = 0;
                  for (auto& output : out_stack) {
                    // NB: The stack is in "reverse" order, so when we pass the diagnostic
                    // number we need to flip it based on size.
                    state->graph->registerOutput(
                        state->getOutput(output, out_stack.size() - i));
                    i++;
                  }
                  setTracingState(nullptr);
              
                  if (getInlineEverythingMode()) {
                    Inline(*graph);
                  }
                  FixupTraceScopeBlocks(graph, self);
                  NormalizeOps(graph);
                  return {state, out_stack};
                } catch (...) {
                  tracer::abandon();
                  throw;
                }
              }

          那么具體記錄 operation 的過程發(fā)生在哪里呢?

          pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp

              Operator createOperatorFromC10_withTracingHandledHere(
                  const c10::OperatorHandle& op)
           
          {
                return Operator(op, [op](Stack& stack) {
                  const auto input_size = op.schema().arguments().size();
                  const auto output_size = op.schema().returns().size();
              
                  Node* node = nullptr;
                  std::shared_ptr<jit::tracer::TracingState> tracer_state;
              
                  // trace the input before unwrapping, otherwise we may lose
                  // the input information
                  if (jit::tracer::isTracing()) {
                    # 獲取 tracer_state
                    tracer_state = jit::tracer::getTracingState();
                    auto symbol = Symbol::fromQualString(op.schema().name());
                    const auto& graph = tracer::getTracingState()->graph;
                    node = graph->create(symbol, 0);
                    tracer::recordSourceLocation(node);
                    const auto& args = op.schema().arguments();
                    int i = 0;
                    # 記錄args 
                    for (auto iter = stack.end() - input_size; iter != stack.end();
                         ++iter, ++i) {
                      // TODO we need to refactor graph APIs (e.g., addInputs)
                      // appropriately; after that, we can get rid of the giant if-else
                      // block we will clean this tech debt together in the following PRs
                      auto type = args[i].type();
                      if (type->kind() == TypeKind::OptionalType) {
                        if (iter->isNone()) {
                          Value* none = graph->insertNode(graph->createNone())->output();
                          node->addInput(none);
                          continue;
                        } else {
                          type = type->expect<OptionalType>()->getElementType();
                        }
                      }
                      if (type->isSubtypeOf(TensorType::get())) {
                        AT_ASSERT(iter->isTensor());
                        tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
                      } else if (type->kind() == TypeKind::FloatType) {
                        AT_ASSERT(iter->isDouble());
                        tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
                      } else if (type->kind() == TypeKind::IntType) {
                        AT_ASSERT(iter->isInt());
                        tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
                      } else if (type->kind() == TypeKind::BoolType) {
                        AT_ASSERT(iter->isBool());
                        tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
                      } else if (type->kind() == TypeKind::StringType) {
                        AT_ASSERT(iter->isString());
                        tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
                      } else if (type->kind() == TypeKind::NumberType) {
                        tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
                      } else if (type->kind() == TypeKind::ListType) {
                        const auto& elem_type = type->expect<ListType>()->getElementType();
                        if (elem_type->isSubtypeOf(TensorType::get())) {
                          AT_ASSERT(iter->isTensorList());
                          auto list = iter->toTensorVector();
                          tracer::addInputs(node, args[i].name().c_str(), list);
                        } else if (elem_type->kind() == TypeKind::FloatType) {
                          AT_ASSERT(iter->isDoubleList());
                          // NB: now, tracer doesn't support tracing double list. We add
                          // special handling here, since in our case, we assume that all the
                          // doubles in the list are constants
                          auto value = iter->toDoubleVector();
                          std::vector<Value*> info(value.size());
                          for (size_t value_index = 0; value_index < value.size();
                               ++value_index) {
                            info[value_index] = graph->insertConstant(value[value_index]);
                            tracer::recordSourceLocation(info[value_index]->node());
                          }
                          node->addInput(
                              graph
                                  ->insertNode(graph->createList(jit::FloatType::get(), info))
                                  ->output());
                        } else if (elem_type->kind() == TypeKind::IntType) {
                          AT_ASSERT(iter->isIntList());
                          tracer::addInputs(
                              node, args[i].name().c_str(), iter->toIntVector());
                        } else if (elem_type->kind() == TypeKind::BoolType) {
                          AT_ASSERT(iter->isBoolList());
                          tracer::addInputs(
                              node, args[i].name().c_str(), iter->toBoolList().vec());
                        } else {
                          throw std::runtime_error(
                              "unsupported input list type: " + elem_type->str());
                        }
                      } else if (iter->isObject()) {
                        tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
                      } else {
                        throw std::runtime_error("unsupported input type: " + type->str());
                      }
                    }
                    # node嵌入graph
                    graph->insertNode(node);
              
                    jit::tracer::setTracingState(nullptr);
                  }

          可以看到,在具體運算發(fā)生時,會使用 getTracingState() 得到 forward 開始去創(chuàng)建的 state,然后看到根據(jù)op.schema().name() 得到計算類型(比如相加),根據(jù)計算類型通過 createNone 方法創(chuàng)建一個計算節(jié)點,然后創(chuàng)建計算輸入,最后把計算node insert 到 graph 中,完成一次對計算的記錄。

          script

          因為 script 得到 IR 的方式是解析源碼,因此對于不同的代碼形式會略有不同(函數(shù),class,nn.Module的instance):1 Python 函數(shù) 簡化后 code

              def script(obj, optimize=None, _frames_up=0, _rcb=None):
                  # fucntion 分支
                  if hasattr(obj, "__script_if_tracing_wrapper"):
                      obj = obj.__original_fn
                      _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
              
                  # 檢查重載
                  _check_directly_compile_overloaded(obj)
                  # 是否之前被script過了
                  maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
                  if maybe_already_compiled_fn:
                      return maybe_already_compiled_fn
                  # 得到ast語法樹
                  ast = get_jit_def(obj, obj.__name__)
                  if _rcb is None:
                      _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
                  #c++ 入口,根據(jù)ast得到ir
                  fn = torch._C._jit_script_compile(
                      qualified_name, ast, _rcb, get_default_args(obj)
                  )
                  # Forward docstrings
                  fn.__doc__ = obj.__doc__
                  # cache起來
                  _set_jit_function_cache(obj, fn)
                  return fn

          我們看下get_jit_def是如何得到 jit 規(guī)定的 ast 語法樹的

          僅保留邏輯代碼,細節(jié)刪掉

              def get_jit_def(fn, def_name, self_name=None):

                  # 得到源代碼的一些信息
                  sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
                  sourcelines = normalize_source_lines(sourcelines)
                  source =  dedent_src ''.join(sourcelines)
                  # dedent_src 為包含了要script函數(shù)的字符串
                  dedent_src = dedent(source)
                  # 調(diào)用python ast包將字符串解析為Python的ast
                  py_ast = ast.parse(dedent_src)
              
                  # 得到python類型注釋
                  type_line = torch.jit.annotations.get_type_line(source)
                  #ctx中包含了函數(shù)所有原信息
                  ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
                  fn_def = py_ast.body[0]
              
                  # build_def將python 的ast 轉化為torchjit 使用的ast格式
                  return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)

          用一個簡單的例子給大家解釋下 py_ast.body[0] 是什么

              import ast
              ... func_def= \
              ... """def test(a):
              ...     a = a + 2
              ...     return a + 1"""

              ... results = ast.parse(func_def)

          Python 解析出的 AST

          可見,ast.body 是一個 list,其長度等于解析的 string 中包含的函數(shù)的個數(shù),我們看第一個元素,其中 value 是一個

          Binop具體為一個Add,left 是Name類型,id為``a,right是Num,也就是2,這個Binop即解析的a = a + 2`。

          因為我們 get_source_lines_and_file 返回的一定是一個 single top-level function, 因此我們直接取用第 0個元素,即 py_ast.body[0] 就可以了。

          接下來看build_def是如何將 Python 的 ast 轉化為自己需要的 ast 的。

          進入buid_def

              def build_def(ctx, py_def, type_line, def_name, self_name=None):
                  ....
                  return Def(Ident(r, def_name),
                             decl,
                             build_stmts(ctx, body))

          因為ctx 包含 source code 所有信息, body 是 Python ast 解析結果,那么build_stmts中應該包含我們想要的答案。

          我們用例子中a+2為例看會怎么轉換,這部分可見frontend.py

          關于StmtBuilder

              
              from torch._C._jit_tree_views import (
                  ClassDef, Ident, Stmt, Decl, Def, Var,
                  EmptyTypeAnnotation, Param, ExprStmt, Assign,
                  Delete, Return, Raise, Assert, AugAssign, While,
                  For, If, Pass, Break, Continue, Apply, Dots, Select,
                  TrueLiteral, FalseLiteral, NoneLiteral, Starred,
                  ListLiteral, TupleLiteral, DictLiteral, Const,
                  StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
                  SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
                  DictComp,
              )
              # jit中定義的ast基本結構
              
              def build_stmts(ctx, stmts):
                  #發(fā)現(xiàn)其調(diào)用了`build_stmt`
                  stmts = [build_stmt(ctx, s) for s in stmts]
                  return list(filter(None, stmts))
              
              #`build_stmt` 是一個StmtBuilder()的instance
              build_stmt = StmtBuilder()
              build_expr = ExprBuilder()
              
              class Builder(object):
                  def __call__(self, ctx, node):
                      # 可見會根據(jù)解析出的ast的類型返回相應的build方法,從截圖可以看到`a+2`是一個`Assign`類型
                      # 因此會調(diào)用build_Assign
                      method = getattr(self, 'build_' + node.__class__.__name__, None)
                      if method is None:
                          raise UnsupportedNodeError(ctx, node)
                      return method(ctx, node)
              
              class StmtBuilder(Builder):
                  @staticmethod
                  def build_Assign(ctx, stmt):
                      # 截圖可以看到stmt.value是一個Binop
                      # build_expr是ExprBuilder的INSTANCE,其會調(diào)用`build_BinOp`
                      rhs = build_expr(ctx, stmt.value)
                      lhs = [build_expr(ctx, x) for x in stmt.targets]
                      return Assign(lhs, rhs)
              
                  @staticmethod
                  def build_Expr(ctx, stmt):
                      # Binop
                      value = stmt.value
                      if value.__class__.__name__ == 'Str':
                          # If a statement is a string literal expression,
                          # then it is a docstring. Just ignore it.
                          return None
                      else:
                          return ExprStmt(build_expr(ctx, value))
              
               class ExprBuilder(Builder):
                      binop_map = {
                      ast.Add: '+',
                      ast.Sub: '-',
                      ast.Mult: '*',
                      ast.Div: '/',
                      ast.Pow: '**',
                      ast.Mod: '%',
                      ast.FloorDiv: '//',
                      ast.BitAnd: '&',
                      ast.BitXor: '^',
                      ast.BitOr: '|',
                      ast.LShift: '<<',
                      ast.RShift: '>>',
                  }
                      @staticmethod
                  def build_BinOp(ctx, expr):
                      #expr.left是個`Name`調(diào)用build_Name
                      lhs = build_expr(ctx, expr.left)
                      rhs = build_expr(ctx, expr.right)
                      op = type(expr.op)
                      # 轉化為約定的代表運算類型的string 符號
                      op_token = ExprBuilder.binop_map.get(op)
                      return BinOp(op_token, lhs, rhs)

          最終轉化為的格式,類似于S-expression.

              (def
                (ident test)
                (decl
                  (list
                    (param
                      (ident a)
                      (option)
                      (option)
                      (False))
          )

                  (option))

                (list
                  (assign
                    (list (variable (ident a)))
                    (option
                      (+
                        (variable (ident a))
                        (const 2))
          )

                    (option))

                  (return
                    (+
                      (variable (ident a))
                      (const 1))
          )
          )
          )

          好的,我們已經(jīng)得到得到jit約定的 AST 樹了,接下來我們要進入 torch._C._jit_script_compile查看如何將這樣的 ast 樹轉化為 IR.

          C++ 入口為 script_compile_function

              static StrongFunctionPtr script_compile_function(
                  const c10::QualifiedName& name,
                  const Def& def,
                  const FunctionDefaults& defaults,
                  const ResolutionCallback& rcb)
           
          {
                 #  def 中包含ast,跟著它就能找到答案
                auto cu = get_python_cu();
                #看來是get_python_cu這個類中的define函數(shù)完成的
                auto defined_functions = cu->define(
                    QualifiedName(name.prefix()),
                    /*properties=*/{},
                    /*propResolvers=*/{},
                    {def},
                    {pythonResolver(rcb)},
                    nullptr,
                    true);
                TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
                auto& defined = defined_functions[0];
                defined->setSchema(getSchemaWithNameAndDefaults(
                    def.range(), defined->getSchema(), def.name().name(), defaults));
                StrongFunctionPtr ret(std::move(cu), defined);
                didFinishEmitFunction(ret);
                return ret;
              }
              # 發(fā)現(xiàn)只是wapper了下CompilationUnit
              inline std::shared_ptr<CompilationUnit> get_python_cu() 
          {
                return py::module::import("torch.jit._state")
                    .attr("_python_cu")
                    .cast<std::shared_ptr<CompilationUnit>>();
              }
              
              #關于compilation_unit
              #/torch/csrc/jit/api/compilation_unit.h
               // for historic reasons, these are defined in ir_emitter.cpp
               // Returns the list of Functions just defined.
                std::vector<Function*> define(
                    const c10::optional<c10::QualifiedName>& prefix,
                    const std::vector<Property>& properties,
                    const std::vector<ResolverPtr>& propResolvers,
                    const std::vector<Def>& definitions,
                    const std::vector<ResolverPtr>&
                        defResolvers, /* determines how we handle free
                                   variables in each definition*/

                    // if non-null, the first argument to each def, is bound to this value
                    const Self* self,
                    // see [name mangling]
                    bool shouldMangle = false)
          ;
              #實現(xiàn)在torch/csrc/jit/frontend/ir_emitter.cpp
              std::unique_ptr<Function> CompilationUnit::define(
                  const c10::optional<QualifiedName>& prefix,
                  const Def& def,
                  const ResolverPtr& resolver,
                  const Self* self,
                  const std::unordered_map<std::string, Function*>& function_table,
                  bool shouldMangle)
           const 
          {
              
                auto _resolver = resolver;
                .....
                auto creator = [def, _resolver, self](Function& method) {
                  ....
                  ##核心代碼to_ir
                  to_ir(def, _resolver, self, method);
                };
              
                auto fn = torch::make_unique<GraphFunction>(
                    std::move(name), std::make_shared<Graph>(), creator);
                return fn;
              }

          我們跟隨 def,找到了一個轉化為 IR 的關鍵的structto_ir,其輸入中有 def,也就是 ast,_resolver 是 Python 中傳過來的解析名字的函數(shù),我們可以在內(nèi)部找到關鍵部分

              to_ir(
                    const Def& def,
                    ResolverPtr resolver_,
                    const Self* self,
                    Function& method) // method being constructed
                    : method(method),
                      graph(method.graph()),
                      resolver(std::move(resolver_)),
                      typeParser_(resolver),
                      environment_stack(nullptr) {
                  AT_ASSERT(resolver);
                  pushFrame(graph->block(), /*starts_def=*/true);
              
                  #emitDef 中會調(diào)用emitStatements
                  method.setSchema(emitDef(def, self, graph->block()));
                  ConvertToSSA(graph);
                  CanonicalizeModifiedLoops(graph);
                  NormalizeOps(graph);
                  runCleanupPasses(graph);
                }
              private:
               #在to_ir 的private中我們可以看到Graph Function這些我們之前介紹的IR的組成部分
                Function& method;
                std::shared_ptr<Graph> graph;
                ResolverPtr resolver;
                std::unordered_map<int64_t, Value*> integral_constants;  
              
               #emitDef 中會調(diào)用emitStatements
               FunctionSchema emitDef(const Def& def, const Self* self, Block* block) 
          {
                  ......
                  // body
                  auto stmts_list = def.statements();
                  emitStatements(stmts_list.begin(), stmts_list.end());
                   ........
                }
               void emitStatements(
                    List<Stmt>::const_iterator begin,
                    List<Stmt>::const_iterator end)
           
          {
                  for (; begin != end; ++begin) {
                    auto stmt = *begin;
                    ErrorReport::CallStack::update_pending_range(stmt.range());
                    switch (stmt.kind()) {
                      case TK_IF:
                        emitIf(If(stmt));
                        break;
                      case TK_WHILE:
                        emitWhile(While(stmt));
                        break;
                      case TK_FOR:
                        emitFor(For(stmt));
                        break;
                      case TK_ASSIGN:
                        emitAssignment(Assign(stmt));
                     .................
                        break;
                      default:
                        throw ErrorReport(stmt)
                            << "Unrecognized statement kind " << kindToString(stmt.kind());
                    }
                    // Found an exit statement in this block. The remaining statements aren't
                    // reachable so we don't emit them.
                    if (exit_blocks.count(environment_stack->block()))
                      return;
                  }
                }


          我們可以看到根據(jù)stmt.kind(),會進入而各種emit里面,其中一定可以找到
          graph->insertNode(graph->create(.....));
          類似的操作,對應我們建立IR graph

          以上是我們以一個 function 為例子,接下來我們以 script 一個 module為例,其有一些獨有的挑戰(zhàn),因為有一些變量的指代,是需要初始化后才知道的,同時,我們希望 script 完的 module 對外還能保持一樣的接口,即可以正常訪問原有 module 的屬性,那么應該怎么做呢?

          1. 在 module 原有的 init 結束后隨即開始完整的 script forward 函數(shù),替換涉及到的所有函數(shù)為 script 后的函數(shù)
          2. 如何正常訪問原有的屬性

          如何在一個類的 init 函數(shù)后面綁定行為呢,我們想到 metaclass,torch.jit 實現(xiàn)了 ScriptMeta這個 metaclass。

          class MyModule(torch.jit.ScriptModule):
              @torch.jit.script_method
              def f(self.x):
                  return x * x
              @torch.jit.script_method
              def forward(self, x):
                   return x + self.f(x)

          關于script_method

              def script_method(fn):
              
                  _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
                  ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
                  #暫時沒有script,只是返回包含ast的nametuple
                  return ScriptMethodStub(_rcb, ast, fn)
              
                  ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback''def_''original_method'))

          1. 移除所有script_method屬性被(@script_method修飾的方法),確保訪問到的是script function
          2. 修改module的_init_,確保module的self.param或者self.module初始化后立即編譯所有的script_method,從而生成的instance的forward已經(jīng)被替換

              class ScriptMeta(type):
                  def __init__(cls, name, bases, attrs):  # noqa: B902
                      # cls ScriptMeta的instance,是一個類如ScriptModule
                      cls._methods: Dict[str, Any] = {}
                      cls._constants_set = set(getattr(cls, "__constants__", ()))
                      for base in reversed(bases):
                          # 還記得嗎t(yī)race的module也是有一個_methods的屬性
                          for k, v in getattr(base, "_methods", {}).items():
                              cls._methods[k] = v
                          base_constants = getattr(base, "_constants_set", set())
                          cls._constants_set = cls._constants_set.union(base_constants)
              
                      # 找到現(xiàn)在所有被@script_method修飾的方法,放到_method,并刪除原有attr
                      # init后之后統(tǒng)一script
                      for k, v in sorted(attrs.items()):
                          if isinstance(v, ScriptMethodStub):
                              delattr(cls, k)
                              cls._methods[v.original_method.__name__] = v


              
                      original_init = getattr(cls, "__init__"lambda self: None)
              
                      # 此處實現(xiàn)了init結束后,調(diào)用create_script_module進行script
                      @functools.wraps(original_init)
                      def init_then_script(self, *args, **kwargs):
                          # 此處的self為instance
                          num_methods = len(cls._methods)
                          original_init(self, *args, **kwargs)
                          added_methods_in_init = len(cls._methods) > num_methods
              
                          if type(self) == cls:
                              # 選取需要script的method
                              def make_stubs(module):
                                  cls = type(module)
                                  if hasattr(cls, "_methods"):
                                      return [v for k, v in sorted(cls._methods.items())]
                                  else:
                                      # infer_methods_to_compile 是一個選取要script函數(shù)的函數(shù)
                                      return infer_methods_to_compile(module)
                              # 講所有script_method一塊編譯為_actual_script_module屬性
              
                              self.__dict__[
                                  "_actual_script_module"
                              ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
              
                              # Delete the Python attributes that now shadow the ScriptModule
                              # ones, so that __getattr__ and __setattr__ will properly find
                              # the scripted versions.
                              concrete_type = self._actual_script_module._concrete_type
                              for name in concrete_type.get_attributes():
                                  delattr(self, name)
                              for name, _ in concrete_type.get_modules():
                                  delattr(self, name)
                              for name in ("_parameters""_buffers""_modules"):
                                  delattr(self, name)
              
                      cls.__init__ = init_then_script  # type: ignore
              
                      return super(ScriptMeta, cls).__init__(name, bases, attrs)
              
                class _CachedForward(object):
                      def __get__(self, obj, cls):
                          return self.__getattr__("forward")  # type: ignore
              
                 class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
              
                      def __init__(self):
                          super(ScriptModule, self).__init__()
              
                      forward = _CachedForward()
                      # 想訪問module的attr,返回_actual_script_module的attr
                      def __getattr__(self, attr):
                          if "_actual_script_module" not in self.__dict__:
                              return super(ScriptModule, self).__getattr__(attr)
                          return getattr(self._actual_script_module, attr)
              
                      def __setattr__(self, attr, value):
                          if "_actual_script_module" not in self.__dict__:
                              # Unwrap torch.jit.Attribute into a regular setattr + recording
                              # the provided type in __annotations__.
                              #
                              # This ensures that if we use the attr again in `__init__`, it
                              # will look like the actual value, not an instance of Attribute.
                              if isinstance(value, Attribute):
                                  if "__annotations__" not in self.__class__.__dict__:
                                      self.__class__.__annotations__ = {}
                                  self.__annotations__[attr] = value.type
                                  value = value.value
                              return super(ScriptModule, self).__setattr__(attr, value)
              
                          setattr(self._actual_script_module, attr, value)

          關于 create_script_module 函數(shù)會 script method 然后返回一個RecursiveScriptModule,但是其邏輯較為復雜,在此不再展開。

          關于 getattribute vs getattr

          當訪問某個實例屬性時,getattribute 會被無條件調(diào)用,當這個屬性不存在,則會調(diào)用 getattr,如未實現(xiàn)自己的 getattr 方法,會拋出AttributeError 提示找不到這個屬性,如果自定義了自己 getattr 方法的話方法會在這種找不到屬性的情況下被調(diào)用。

          4 IR優(yōu)化的簡單介紹

          jit 一般涉及如下優(yōu)化: loop unrolling peephole optimization constant propagation DCE fusion inlining... 我們看如下例子:

              def test(x):
                  # Dead code Elimination
                  for i in range(1000):
                      y = x + 1
                  for i in range(100):
                      #peephole optimization
                      x = x.t()
                      x = x.t()
                  return x.sum()
              
              opt_test = torch.jit.script(test)
              s = time()
              inputs = torch.ones(4,4).cuda()
              s = time()
              for i in range(10000):
                  test(inputs)
              print(time()-s)
              # 95s
              s = time()
              for i in range(10000):
                  opt_test(inputs)
              print(time()-s)
              # 0.13s
              print(opt_test.graph)
              print(opt_test.graph_for(inputs))
              95.13823795318604
              0.13010907173156738
              graph(%x.1 : Tensor):
                %22 : None = prim::Constant()
                %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
                %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
                %x : Tensor = prim::Loop(%10, %13, %x.1# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
                  block0(%i : int, %x.10 : Tensor):
                    %x.4 : Tensor = aten::t(%x.10# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
                    %x.7 : Tensor = aten::t(%x.4# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
                    -> (%13, %x.7)
                %23 : Tensor = aten::sum(%x, %22# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
                return (%23)
              
              graph(%x.1 : Tensor):
                %1 : None = prim::Constant()
                %2 : Tensor = aten::sum(%x.1, %1# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
                return (%2)

          關于 IR 計算圖優(yōu)化

          IR 的 Method 中內(nèi)置 GraphExecutor object,創(chuàng)建于第一次執(zhí)行的時候,負責優(yōu)化。
          文件 pytorch-master/torch/csrc/jit/api/method.h scritp_method 的 C++ 原型里

              GraphExecutor& get_executor() {
                  return function_->get_executor();
                }

          GraphExecutor 的定義在/torch/csrc/jit/runtime/graph_executor.cpp,可見其由 graph 產(chǎn)生,定義了 run 方法執(zhí)行

              GraphExecutor::GraphExecutor(
                  const std::shared_ptr<Graph>& graph,
                  std::string function_name)
                  : pImpl(
                        IsNewExecutorEnabled()
                            ? dynamic_cast<GraphExecutorImplBase*>(
                                  new ProfilingGraphExecutorImpl(
                                      graph,
                                      std::move(function_name)))
                            : dynamic_cast<GraphExecutorImplBase*>(
                                  new GraphExecutorImpl(graph, std::move(function_name)))) {}
              std::shared_ptr<Graph> GraphExecutor::graph() const {
                return pImpl->graph;
              }
              const ExecutionPlan& GraphExecutor::getPlanFor(
                  Stack& inputs,
                  size_t remaining_bailout_depth)
           
          {
                return pImpl->getPlanFor(inputs, remaining_bailout_depth);
              }
              
               std::shared_ptr<GraphExecutorImplBase> pImpl;
              .....

          關于 GraphExecutorImplBase,/torch/csrc/jit/runtime/graph_executor.cpp


              const ExecutionPlan& getOrCompile(const Stack& stack) 
          {
                    .....
                    auto plan = compileSpec(spec);
              
                  }
                }
              # compileSpec 會返回一個plan
              ExecutionPlan compileSpec(const ArgumentSpec& spec) 
          {
                  auto opt_graph = graph->copy();
                  GRAPH_DUMP("Optimizing the following function:", opt_graph);
                  arg_spec_creator_.specializeTypes(*opt_graph, spec);
              
                  // Phase 0. Inline functions, then clean up any artifacts that the inliner
                  //          left in that may inhibit optimization
                   .....
                  runRequiredPasses(opt_graph);
                  GRAPH_DEBUG(
                      "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);
              
                  // Phase 2. Propagate detailed information about the spec through the
                  //          graph (enabled more specializations in later passes).
                  //          Shape propagation sometimes depends on certain arguments being
                  //          constants, and constant propagation doesn't need shape
                  //          information anyway, so it's better to run it first.
                  ConstantPropagation(opt_graph);
                  GRAPH_DEBUG(
                      "After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
                  PropagateInputShapes(opt_graph);
                  GRAPH_DEBUG(
                      "After PropagateInputShapes, before PropagateRequiresGrad\n",
                      *opt_graph);
                  PropagateRequiresGrad(opt_graph);
                  GRAPH_DEBUG(
                      "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);
              
                  // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
                  //          that we can still execute using autograd).
                  runOptimization(opt_graph);
                  .....各種優(yōu)化
                  return ExecutionPlan(opt_graph, function_name_);
                }

          這些優(yōu)化在 torch/csrc/jit/passes/ 文件夾

          torch/csrc/jit/passes/dead_code_elimination.cpp

          /torch/csrc/jit/passes/fuse_linear.cpp

          torch/csrc/jit/passes/remove_dropout.cpp

          torch/csrc/jit/passes/fold_conv_bn.cpp

          參考

          1. INTRODUCTION TO TORCHSCRIPT

          2. PyTorch 部署_TorchScript

          3.pytorch_wiki

          4. PyTorch-JIT-Source-Code-Read-Note

          5. Abstract_syntax_tree


          - The End -


          GiantPandaCV

          長按二維碼關注我們

          本公眾號專注:

          1. 技術分享;

          2. 學術交流

          3. 資料共享

          歡迎關注我們,一起成長!

          瀏覽 60
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国产三级片电影成人久久久 | 免费无码在线视频 | 99久久久国产精品无码 | 91成人视频 | 五月丁香六月婷婷久久 |