<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          實(shí)操教程|PyTorch AutoGrad C++層實(shí)現(xiàn)

          共 9036字,需瀏覽 19分鐘

           ·

          2021-04-13 22:12

          ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺(tái)

          作者丨xxy-zhihu@知乎
          來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/339039943
          編輯丨極市平臺(tái)

          極市導(dǎo)讀

           

          本文為一篇實(shí)操教程,作者介紹了PyTorch AutoGrad C++層實(shí)現(xiàn)中各個(gè)概念的解釋。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿

          autograd依賴(lài)的數(shù)據(jù)結(jié)構(gòu)

          at::Tensor:shared ptr 指向 TensorImpl

          TensorImpl:對(duì) at::Tensor 的實(shí)現(xiàn)

          • 包含一個(gè)類(lèi)型為 [AutogradMetaInterface](c10::AutogradMetaInterface) 的autograd_meta_,在tensor是需要求導(dǎo)的variable時(shí),會(huì)被實(shí)例化為 [AutogradMeta](c10::AutogradMetaInterface) ,里面包含了autograd需要的信息

          Variable: 就是Tensor,為了向前兼容保留的

          • using Variable = at::Tensor;
          • 概念上有區(qū)別, Variable 是需要計(jì)算gradient的, Tensor 是不需要計(jì)算gradient的
          • VariableAutogradMeta是對(duì) [AutogradMetaInterface](c10::AutogradMetaInterface)的實(shí)現(xiàn),里面包含了一個(gè) Variable,就是該variable的gradient
          • 帶有version和view
          • 會(huì)實(shí)例化 AutogradMeta , autograd需要的關(guān)鍵信息都在這里

          AutoGradMeta : 記錄 Variable 的autograd歷史信息

          • 包含一個(gè)叫g(shù)rad_的 Variable, 即 AutoGradMeta 對(duì)應(yīng)的var的梯度tensor
          • 包含類(lèi)型為 Node 指針的 grad_fn (var在graph內(nèi)部時(shí))和 grad_accumulator(var時(shí)葉子時(shí)), 記錄生成grad_的方法
          • 包含 output_nr ,標(biāo)識(shí)var對(duì)應(yīng) grad_fn的輸入編號(hào)
          • 構(gòu)造函數(shù)包含一個(gè)類(lèi)型為 Edge的gradient_edge, gradient_edge.function 就是 grad_fn, 另外 gradient_edge.input_nr 記錄著對(duì)應(yīng) grad_fn的輸入編號(hào),會(huì)賦值給 AutoGradMetaoutput_nr

          autograd::Edge: 指向autograd::Node的一個(gè)輸入

          • 包含類(lèi)型為 Node 指針,表示edge指向的Node
          • 包含 input_nr, 表示edge指向的Node的輸入編號(hào)

          autograd::Node: 對(duì)應(yīng)AutoGrad Graph中的Op

          • 是所有autograd op的抽象基類(lèi),子類(lèi)重載apply方法

            • next_edges_記錄出邊
            • input_metadata_記錄輸入的tensor的metadata
          • 實(shí)現(xiàn)的子類(lèi)一般是可求導(dǎo)的函數(shù)和他們的梯度計(jì)算op

          • Node in AutoGrad Graph

            • Variable通過(guò)Edge關(guān)聯(lián)Node的輸入和輸出
            • 多個(gè)Edge指向同一個(gè)Var時(shí),默認(rèn)做累加
          • call operator

            • 最重要的方法,實(shí)現(xiàn)計(jì)算
          • next_edge

            • 縫合Node的操作
            • 獲取Node的出邊,next_edge(index)/next_edges()
            • add_next_edge(),創(chuàng)建

          前向計(jì)算

          PyTorch通過(guò)tracing只生成了后向AutoGrad Graph.

          代碼是生成的,需要編譯才能看到對(duì)應(yīng)的生成結(jié)果

          • gen_variable_type.py生成可導(dǎo)版本的op
          • 生成的代碼在 pytorch/torch/csrc/autograd/generated/
          • 前向計(jì)算時(shí),進(jìn)行了tracing,記錄了后向計(jì)算圖構(gòu)建需要的信息
          • 這里以relu為例,代碼在pytorch/torch/csrc/autograd/generated/VariableType_0.cpp
          Tensor relu(const Tensor & self) {                                                                                                                                                                   auto& self_ = unpack(self, "self", 0);                                                                                                                                                             std::shared_ptr<ReluBackward0> grad_fn;                                                                                                                                                            if (compute_requires_grad( self )) { // 如果輸入var需要grad    // ReluBackward0的類(lèi)型是Node                                                                                                                                                                grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);                                                                                                                          // collect_next_edges(var)返回輸入var對(duì)應(yīng)的指向的    // grad_fn(前一個(gè)op的backward或者是一個(gè)accumulator的)的輸入的Edge    // set_next_edges(),在grad_fn中記錄這些Edge(這里完成了后向的構(gòu)圖)    grad_fn->set_next_edges(collect_next_edges( self ));     // 記錄當(dāng)前var的一個(gè)版本                                                                                                                                              grad_fn->self_ = SavedVariable(self, false);                                                                                                                                                     }                                                                                                                                                                                                  #ifndef NDEBUG                                                                                                                                                                                     c10::optional<Storage> self__storage_saved =                                                                                                                                                         self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;                                                                                                                    c10::intrusive_ptr<TensorImpl> self__impl_saved;                                                                                                                                                   if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();                                                                                                                                   #endif                                                                                                                                                                                             auto tmp = ([&]() {                                                                                                                                                                                  at::AutoNonVariableTypeMode non_var_type_mode(true);                                                                                                                                               return at::relu(self_); // 前向計(jì)算                                                                                                                                                                          })();                                                                                                                                                                                              auto result = std::move(tmp);                                                                                                                                                                      #ifndef NDEBUG                                                                                                                                                                                     if (self__storage_saved.has_value())                                                                                                                                                                 AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));                                                                                                                             if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());                                                                                                                      #endif                                                                                                                                                                                             if (grad_fn) {                   // grad_fn增加一個(gè)輸入,記錄輸出var的metadata作為grad_fn的輸入      // 輸出var的AutoGradMeta實(shí)例化,輸出var的AutoGradMeta指向起grad_fn的輸入                                                                                                                                                                            set_history(flatten_tensor_args( result ), grad_fn);                                                                                                                                           }                                                                                                                                                                                                  return result;                                                                                                                                                                                   }

          • 可以看到和 grad_fn 相關(guān)的操作trace了一個(gè)op的計(jì)算,構(gòu)建了后向計(jì)算圖.

          后向計(jì)算

          autograd::backward():計(jì)算output var的梯度值,調(diào)用的 run_backward()

          autograd::grad() :計(jì)算有output var和到特定input的梯度值,調(diào)用的 run_backward()

          autograd::run_backward()

          • 對(duì)于要求梯度的output var,獲取其指向的grad_fn作為roots,是后向圖的起點(diǎn)
          • 對(duì)于有input var的,獲取其指向的grad_fn作為output_edges, 是后向圖的終點(diǎn)
          • 調(diào)用 autograd::Engine::get_default_engine().execute(...) 執(zhí)行后向計(jì)算

          autograd::Engine::execute(...)

          • 創(chuàng)建 GraphTask ,記錄了一些配置信息

          • 創(chuàng)建 GraphRoot ,是一個(gè)Node,把所有的roots作為其輸出邊,Node的apply()返回的是roots的grad【這里已經(jīng)得到一個(gè)單起點(diǎn)的圖】

          • 計(jì)算依賴(lài) compute_dependencies(...)

            • 從GraphRoot開(kāi)始,廣度遍歷,記錄所有碰到的grad_fn的指針,并統(tǒng)計(jì)grad_fn被遇到的次數(shù),這些信息記錄到GraphTask中
          • GraphTask 初始化:當(dāng)有input var時(shí),判斷后向圖中哪些節(jié)點(diǎn)是真正需要計(jì)算的

          • GraphTask 執(zhí)行

            • 選擇CPU or GPU線(xiàn)程執(zhí)行
            • 以CPU為例,調(diào)用的 autograd::Engine::thread_main(...)

          autograd::Engine::thread_main(...)

          • evaluate_function(...) ,輸入輸出的處理,調(diào)度

            • call_function(...) , 調(diào)用對(duì)應(yīng)的Node計(jì)算
            • 執(zhí)行后向過(guò)程中的生成的中間grad Tensor,如果不釋放,可以用于計(jì)算高階導(dǎo)數(shù);(同構(gòu)的后向圖,之前的grad tensor是新的輸出,grad_fn變成之前grad_fn的backward,這些新的輸出還可以再backward)
          • 具體的執(zhí)行機(jī)制可以支撐單獨(dú)開(kāi)一個(gè)Topic分析,在這里討論到后向圖完成構(gòu)建為止.


          推薦閱讀


          2021 年了,TensorFlow 和 PyTorch 兩個(gè)深度學(xué)習(xí)框架地位又有什么變化嗎?

          2021-04-11

          PyTorch 源碼解讀之即時(shí)編譯篇

          2021-04-08

          模型部署翻車(chē)記:pytorch轉(zhuǎn)onnx踩坑實(shí)錄

          2021-04-07



          # CV技術(shù)社群邀請(qǐng)函 #

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~


          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          覺(jué)得有用麻煩給個(gè)在看啦~  
          瀏覽 50
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  欧美在线va| 无码一二区 | 亚洲综合免费观看 | 岛国av在线观看网址国产 | 日批视频免费播放 |