<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          詳盡 | PyTorch動(dòng)態(tài)圖解析

          共 16365字,需瀏覽 33分鐘

           ·

          2021-05-16 07:08

          點(diǎn)擊上方小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          本文轉(zhuǎn)自:深度學(xué)習(xí)這件小事


          背景
          PyTorch的動(dòng)態(tài)圖框架主要是由torch/csrc/autograd下的代碼實(shí)現(xiàn)的。這個(gè)目錄下定義了3個(gè)主要的基類(lèi):Variable、Function、Engine,這三個(gè)基類(lèi)及其繼承體系共同構(gòu)成了PyTorch動(dòng)態(tài)圖的根基。
          為什么叫作動(dòng)態(tài)圖呢?圖容易理解,F(xiàn)unction是nodes/vertices,(Function, input_nr)是edges。那么動(dòng)態(tài)體現(xiàn)在什么地方呢?每一次前向時(shí)構(gòu)建graph,反向時(shí)銷(xiāo)毀。本文就以torch/csrc/autograd/下的代碼為基礎(chǔ),深入講解PyTorch的動(dòng)態(tài)圖系統(tǒng)——這也可能是互聯(lián)網(wǎng)上關(guān)于PyTorch動(dòng)態(tài)圖最詳盡的文章了。
          在專(zhuān)欄文章《PyTorch的初始化》(https://zhuanlan.zhihu.com/p/57571317)中,gemfield描述了PyTorch的初始化流程,在文末提到了THPAutograd_initFunctions()調(diào)用:“最后的THPAutograd_initFunctions()則是初始化了torch的自動(dòng)微分系統(tǒng),這是PyTorch動(dòng)態(tài)圖框架的基礎(chǔ)”。而本文將以THPAutograd_initFunctions開(kāi)始,帶你走入到PyTorch的動(dòng)態(tài)圖世界中。首先為上篇,主要介紹Function、Variable、Engine的類(lèi)的繼承體系。
          autograd初始化
          THPAutograd_initFunctions這個(gè)函數(shù)實(shí)現(xiàn)如下:
             
          void THPAutograd_initFunctions(){  THPObjectPtr module(PyModule_New("torch._C._functions"));  ......  generated::initialize_autogenerated_functions();  auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));}

          用來(lái)初始化cpp_function_types表,這個(gè)表維護(hù)了從cpp類(lèi)型的函數(shù)到python類(lèi)型的映射:
             
          static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types

          這個(gè)表里存放的都是和autograd相關(guān)的函數(shù)的映射關(guān)系,起什么作用呢?比如我在python中print一個(gè)Variable的grad_fn:
             
          >>> gemfield = torch.empty([2,2],requires_grad=True)>>> syszux = gemfield * gemfield>>> syszux.grad_fn<ThMulBackward object at 0x7f111621c350>

          grad_fn是一個(gè)Function的實(shí)例,我們?cè)贑++中定義了那么多反向函數(shù)(參考下文),但是怎么在python中訪(fǎng)問(wèn)呢?就靠上面這個(gè)表的映射。實(shí)際上,cpp_function_types這個(gè)映射表就是為了在python中打印grad_fn服務(wù)的。

          Variable
          參考:https://zhuanlan.zhihu.com/p/64135058
          以下面的代碼片段作為例子:
             
          gemfield = torch.ones(2, 2, requires_grad=True)syszux = gemfield + 2civilnet = syszux * syszux * 3gemfieldout = civilnet.mean()gemfieldout.backward()

          需要指出的是,動(dòng)態(tài)圖是在前向的時(shí)候建立起來(lái)的。gemfieldout作為前向的最終輸出,在反向傳播的時(shí)候,卻是計(jì)算的最初輸入—在動(dòng)態(tài)圖中,我們稱(chēng)之為root。在下文介紹Engine的時(shí)候,你就會(huì)看到,我們會(huì)使用gemfieldout這個(gè)root來(lái)構(gòu)建GraphRoot實(shí)例,以此作為Graph的輸入。
          Function
          在開(kāi)始介紹Function之前,還是以上面的代碼為例,在一次前向的過(guò)程中,我們會(huì)創(chuàng)建出如下的Variable和Function實(shí)例:
             
          #Variable實(shí)例gemfield --> grad_fn_ (Function實(shí)例)= None         --> grad_accumulator_ (Function實(shí)例)= AccumulateGrad實(shí)例0x55ca7f304500         --> output_nr_ = 0
          #Function實(shí)例, 0x55ca7f872e90AddBackward0實(shí)例 --> sequence_nr_ (uint64_t) = 0 --> next_edges_ (edge_list) --> std::vector<Edge> = [(AccumulateGrad實(shí)例, 0),(0, 0)] --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])] --> alpha (Scalar) = 1 --> apply() --> 使用 AddBackward0 的apply
          #Variable實(shí)例syszux --> grad_fn_ (Function實(shí)例)= AddBackward0實(shí)例0x55ca7f872e90 --> output_nr_ = 0
          #Function實(shí)例, 0x55ca7ebba2a0MulBackward0 --> sequence_nr_ (uint64_t) = 1 --> next_edges_ (edge_list) = [(AddBackward0實(shí)例0x55ca7f872e90,0),(AddBackward0實(shí)例0x55ca7f872e90,0)] --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])] --> alpha (Scalar) = 1 --> apply() --> 使用 MulBackward0 的apply
          # #Variable實(shí)例,syszux * syszux得到的tmptmp --> grad_fn_ (Function實(shí)例)= MulBackward0實(shí)例0x55ca7ebba2a0 --> output_nr_ = 0
          #Function實(shí)例,0x55ca7fada2f0MulBackward0 --> sequence_nr_ (uint64_t) = 2 (每個(gè)線(xiàn)程內(nèi)自增) --> next_edges_ (edge_list) = [(MulBackward0實(shí)例0x55ca7ebba2a0,0),(0,0)] --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType, [2, 2],cpu])] --> self_ (SavedVariable) = tmp的淺拷貝 --> other_ (SavedVariable) = 3的淺拷貝 --> apply() --> 使用 MulBackward0 的apply
          #Variable實(shí)例civilnet --> grad_fn_ (Function實(shí)例)= MulBackward0實(shí)例0x55ca7fada2f0 -
          #Function實(shí)例,0x55ca7eb358b0MeanBackward0 --> sequence_nr_ (uint64_t) = 3 (每個(gè)線(xiàn)程內(nèi)自增) --> next_edges_ (edge_list) = [(MulBackward0實(shí)例0x55ca7fada2f0,0)] --> input_metadata_ --> [(type, shape, device)...] = [(CPUFloatType|[]|cpu])] --> self_sizes (std::vector<int64_t>) = (2, 2) --> self_numel = 4 --> apply() --> 使用 MulBackward0 的apply#Variable實(shí)例gemfieldout --> grad_fn_ (Function實(shí)例)= MeanBackward0實(shí)例0x55ca7eb358b0 --> output_nr_ = 0
          這些用于反向計(jì)算的Function實(shí)例之間通過(guò)next_edges_連接在一起,因?yàn)檫@些Function的實(shí)際運(yùn)行都是在反向期間,因此,輸出輸出關(guān)系正好和前向期間是反過(guò)來(lái)的。它們通過(guò)next_edges_連接在一起。用一個(gè)圖來(lái)概括,就是下面這樣:
          這就引入一個(gè)新的話(huà)題——Function類(lèi)是如何抽象出來(lái)的。
          #Function基類(lèi)定義

          Function的數(shù)據(jù)成員如下所示:
             
          using edge_list = std::vector<Edge>;using variable_list = std::vector<Variable>;
          struct TORCH_API Function {... virtual variable_list apply(variable_list&& inputs) = 0;... const uint64_t sequence_nr_; edge_list next_edges_; PyObject* pyobj_ = nullptr; // weak reference std::unique_ptr<AnomalyMetadata> anomaly_metadata_ = nullptr; std::vector<std::unique_ptr<FunctionPreHook>> pre_hooks_; std::vector<std::unique_ptr<FunctionPostHook>> post_hooks_; at::SmallVector<InputMetadata, 2> input_metadata_;};

          #Function call
          Function類(lèi)是抽象出來(lái)的基類(lèi),代表一個(gè)op(operation),每個(gè)op接收的參數(shù)是0個(gè)、1個(gè)或多個(gè)Variable實(shí)例(使用std::vector封裝),并與此同時(shí)輸出0個(gè)、1個(gè)或多個(gè)Variable實(shí)例。PyTorch中所有用于反向傳播計(jì)算的函數(shù)都繼承自Function類(lèi),并重寫(xiě)了Function類(lèi)中的apply純虛函數(shù)。因?yàn)镕unction類(lèi)中實(shí)現(xiàn)了call函數(shù):
             
          variable_list operator()(variable_list&& inputs) {
          return apply(std::move(inputs));
          }
          所以依靠C++的多態(tài),對(duì)op的call將轉(zhuǎn)化為自身(子類(lèi))的apply調(diào)用。Function類(lèi)中最重要的方法是call函數(shù),call會(huì)調(diào)用apply,call函數(shù)接收vector封裝的多個(gè)Variable實(shí)例,并輸出vector封裝的多個(gè)Variable實(shí)例。輸入?yún)?shù)的vector長(zhǎng)度可以由num_inputs()調(diào)用獲得,對(duì)應(yīng)的,輸出的vector長(zhǎng)度則由num_outputs()獲得。
          #Function的輸入
          Function成員input_metadata_代表input data的meta信息,界定了一個(gè)Function的輸入:
             
          struct InputMetadata {...  const at::Type* type_ = nullptr;  at::DimVector shape_;  at::Device device_ = at::kCPU;};

          #Autograd graph的edge和vertices
          如果將PyTorch的autograd系統(tǒng)看作是一個(gè)圖(graph)的話(huà),那么每個(gè)Function實(shí)例就是graph中的節(jié)點(diǎn)(nodes/vertices),各個(gè)Function實(shí)例之間則是通過(guò)Edge連接的。Edge是個(gè)結(jié)構(gòu)體,通過(guò) (Function, input_nr) 的配對(duì)來(lái)代表graph中的edge:
             
          struct Edge {...  std::shared_ptr<Function> function;  uint32_t input_nr;};

          Function的成員next_edges_正是一組這樣的Edge實(shí)例,代表此function實(shí)例的返回值要輸出到的(另外)function,也即next_edges_是function和function之間的紐帶。
          Function的輸入輸出都是Variable實(shí)例,因此,當(dāng)一個(gè)graph被執(zhí)行的時(shí)候,Variable實(shí)例就在這些edges之間來(lái)傳輸流動(dòng)。當(dāng)兩個(gè)或者多個(gè)Edge指向同一個(gè)Function的時(shí)候(這個(gè)節(jié)點(diǎn)的入度大于1),這些edges的輸出將會(huì)隱含的相加起來(lái)再送給指向的目標(biāo)Function。
          Function和Function之間通過(guò)next_edge接口連接在一起,你可以使用add_next_edge()來(lái)向Function添加一個(gè)edge, 通過(guò)next_edge(index)獲取對(duì)應(yīng)的edge,通過(guò)next_edges()方法獲得迭代edge的迭代器。每一個(gè)Function都有一個(gè)sequence number,隨著Function實(shí)例的不斷構(gòu)建而單調(diào)增長(zhǎng)。你可以通過(guò)sequence_nr()方法來(lái)或者一個(gè)Function的sequence number。
          Function繼承體系

          基類(lèi)Function直接派生出TraceableFunction和以下這些Function:
             
          CopySlices : public Function DelayedError : public Function Error : public Function Gather : public Function GraphRoot : public Function Scatter : public FunctionAccumulateGrad : public Function AliasBackward : public Function AsStridedBackward : public Function CopyBackwards : public Function DiagonalBackward : public Function ExpandBackward : public Function IndicesBackward0 : public Function IndicesBackward1 : public Function PermuteBackward : public Function SelectBackward : public Function SliceBackward : public Function SqueezeBackward0 : public Function SqueezeBackward1 : public Function TBackward : public Function TransposeBackward0 : public Function UnbindBackward : public Function UnfoldBackward : public Function UnsqueezeBackward0 : public Function ValuesBackward0 : public Function ValuesBackward1 : public Function ViewBackward : public Function
          PyFunction : public Function

          這其中,從基類(lèi)Function派生出來(lái)的AccumulateGrad、TraceableFunction、GraphRoot是比較關(guān)鍵的類(lèi)。
          #派生類(lèi)AccumulateGrad
          先說(shuō)說(shuō)AccumulateGrad,AccumulateGrad正是Variable的grad_accumulator_成員的類(lèi)型:
             
          struct AccumulateGrad : public Function {  explicit AccumulateGrad(Variable variable_);  variable_list apply(variable_list&& grads) override;  Variable variable;};

          可見(jiàn)一個(gè)AccumulateGrad實(shí)例必須用一個(gè)Variable構(gòu)建,apply調(diào)用接收一個(gè)list的Variable的實(shí)例——這都是和Variable的grad_accumulator_相關(guān)的。
          #派生類(lèi)GraphRoot
          對(duì)于GraphRoot,前向時(shí)候的最終輸出——在反向的時(shí)候作為最初輸入——是由GraphRoot封裝的:
             
          struct GraphRoot : public Function {  GraphRoot(edge_list functions, variable_list inputs)      : Function(std::move(functions)),        outputs(std::move(inputs)) {}  variable_list apply(variable_list&& inputs) override {    return outputs;  }  variable_list outputs;};

          GraphRoot——正如Function的靈魂在apply一樣——其apply函數(shù)僅僅返回它的輸入!
          #派生類(lèi)TraceableFunction
          再說(shuō)說(shuō)TraceableFunction:
             
          struct TraceableFunction : public Function {
          using Function::Function;
          bool is_traceable() final {
          return true;
          }
          };
          TraceableFunction會(huì)進(jìn)一步派生出372個(gè)子類(lèi)(2019年4月),這些子類(lèi)的名字都含有一個(gè)共同的部分:Backward。這說(shuō)明什么呢?這些函數(shù)將只會(huì)用在反向傳播中:
             
          AbsBackward : public TraceableFunction AcosBackward : public TraceableFunction AdaptiveAvgPool2DBackwardBackward : public TraceableFunction AdaptiveAvgPool2DBackward : public TraceableFunction AdaptiveAvgPool3DBackwardBackward : public TraceableFunction AdaptiveAvgPool3DBackward : public TraceableFunction AdaptiveMaxPool2DBackwardBackward : public TraceableFunction AdaptiveMaxPool2DBackward : public TraceableFunction AdaptiveMaxPool3DBackwardBackward : public TraceableFunction AdaptiveMaxPool3DBackward : public TraceableFunction AddBackward0 : public TraceableFunction AddBackward1 : public TraceableFunction AddbmmBackward : public TraceableFunction AddcdivBackward : public TraceableFunction AddcmulBackward : public TraceableFunction AddmmBackward : public TraceableFunction AddmvBackward : public TraceableFunction AddrBackward : public TraceableFunction ......SoftmaxBackwardDataBackward : public TraceableFunction SoftmaxBackward : public TraceableFunction ......UpsampleBicubic2DBackwardBackward : public TraceableFunction UpsampleBicubic2DBackward : public TraceableFunction UpsampleBilinear2DBackwardBackward : public TraceableFunction UpsampleBilinear2DBackward : public TraceableFunction UpsampleLinear1DBackwardBackward : public TraceableFunction UpsampleLinear1DBackward : public TraceableFunction UpsampleNearest1DBackwardBackward : public TraceableFunction UpsampleNearest1DBackward : public TraceableFunction UpsampleNearest2DBackwardBackward : public TraceableFunction UpsampleNearest2DBackward : public TraceableFunction UpsampleNearest3DBackwardBackward : public TraceableFunction UpsampleNearest3DBackward : public TraceableFunction UpsampleTrilinear3DBackwardBackward : public TraceableFunction UpsampleTrilinear3DBackward : public TraceableFunction ......

          這300多個(gè)Backward function都重寫(xiě)了apply函數(shù),來(lái)實(shí)現(xiàn)自己的反向求導(dǎo)算法,比如加法的反向求導(dǎo)函數(shù)AddBackward0:
             
          struct AddBackward0 : public TraceableFunction {  using TraceableFunction::TraceableFunction;  variable_list apply(variable_list&& grads) override;  Scalar alpha;};

          這些apply函數(shù)是Function的靈魂,是反向傳播計(jì)算時(shí)候的核心執(zhí)行邏輯。
          Engine
          Engine類(lèi)實(shí)現(xiàn)了從輸出的variable(以及它的gradients)到root variables(用戶(hù)創(chuàng)建的并且requires_grad=True)之間的反向傳播。
             
          gemfield = torch.ones(2, 2, requires_grad=True)syszux = gemfield + 2civilnet = syszux * syszux * 3gemfieldout = civilnet.mean()gemfieldout.backward()

          還是以上面這個(gè)代碼片段為例,Engine實(shí)現(xiàn)了從gemfieldout到gemfield的反向傳播:
          1,如何根據(jù)gemfieldout構(gòu)建GraphRoot;
          2,如何根據(jù)這些Function實(shí)例及它們上的metadata構(gòu)建graph;
          3,如何實(shí)現(xiàn)Queue來(lái)多線(xiàn)程完成反向計(jì)算的工作。
          #Engine類(lèi)定義
          Engine類(lèi)的定義如下:
             
          struct Engine {  using ready_queue_type = std::deque<std::pair<std::shared_ptr<Function>, InputBuffer>>;  using dependencies_type = std::unordered_map<Function*, int>;  virtual variable_list execute(const edge_list& roots,const variable_list& inputs,...const edge_list& outputs = {});  void queue_callback(std::function<void()> callback);protected:  void compute_dependencies(Function* root, GraphTask& task);  void evaluate_function(FunctionTask& task);  void start_threads();  virtual void thread_init(int device);  virtual void thread_main(GraphTask *graph_task);  std::vector<std::shared_ptr<ReadyQueue>> ready_queues;};

          核心就是execute函數(shù),它接收一組Edge——(Function, input number) pairs ——來(lái)作為函數(shù)的輸入,然后通過(guò)next_edge不斷的找到指向的下一個(gè)Edge,最終完成整個(gè)Graph的計(jì)算。
          #派生類(lèi)PythonEngine
          然而我們實(shí)際使用的是Engine類(lèi)的派生類(lèi):PythonEngine。PythonEngine子類(lèi)重寫(xiě)了父類(lèi)的execute,只不過(guò)僅僅提供了把C++異常翻譯為Python異常的功能,核心工作還是由Engine基類(lèi)來(lái)完成:
             
          struct PythonEngine : public Engine
          整個(gè)PyTorch程序全局只維護(hù)一個(gè)Engine實(shí)例,也就是PythonEngine實(shí)例。
          BP調(diào)用棧

          既然Engine是用來(lái)計(jì)算網(wǎng)絡(luò)反向傳播的,我們不妨看下這個(gè)調(diào)用棧是怎么到達(dá)Engine類(lèi)的。如果我們對(duì)gemfieldout進(jìn)行backward計(jì)算,則調(diào)用棧如下所示:
             
          #torch/tensor.py,self is gemfieldoutdef backward(self, gradient=None, retain_graph=None, create_graph=False)|V#torch.autograd.backward(self, gradient, retain_graph, create_graph)#torch/autograd/__init__.pydef backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, grad_variables=None)|VVariable._execution_engine.run_backward(tensors, grad_tensors, retain_graph, create_graph,allow_unreachable=True)#轉(zhuǎn)化為Variable._execution_engine.run_backward((gemfieldout,), (tensor(1.),), False, False,True)|V#torch/csrc/autograd/python_engine.cppPyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)|V#torch/csrc/autograd/python_engine.cppvariable_list PythonEngine::execute(const edge_list& roots, const variable_list& inputs, bool keep_graph, bool create_graph, const edge_list& outputs)|V#torch/csrc/autograd/engine.cpp

          總結(jié)

          在下段文章中,Gemfield將主要介紹Engine這個(gè)類(lèi)是如何在gemfieldout.backward()中運(yùn)行PyTorch動(dòng)態(tài)圖的。

          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目即可下載包括圖像分割、口罩檢測(cè)、車(chē)道線(xiàn)檢測(cè)、車(chē)輛計(jì)數(shù)、添加眼線(xiàn)、車(chē)牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱(chēng)+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~


          瀏覽 49
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  色图av| 人妻精品无码 | 精品久久久久久蜜桃 | 伊人五月婷婷 | 激情国产内射 |