實(shí)操教程|PyTorch AutoGrad C++層實(shí)現(xiàn)

極市導(dǎo)讀
本文為一篇實(shí)操教程,作者介紹了PyTorch AutoGrad C++層實(shí)現(xiàn)中各個(gè)概念的解釋。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿
autograd依賴(lài)的數(shù)據(jù)結(jié)構(gòu)
at::Tensor:shared ptr 指向 TensorImpl
TensorImpl:對(duì) at::Tensor 的實(shí)現(xiàn)
包含一個(gè)類(lèi)型為 [AutogradMetaInterface](c10::AutogradMetaInterface)的autograd_meta_,在tensor是需要求導(dǎo)的variable時(shí),會(huì)被實(shí)例化為[AutogradMeta](c10::AutogradMetaInterface),里面包含了autograd需要的信息
Variable: 就是Tensor,為了向前兼容保留的
using Variable = at::Tensor; 概念上有區(qū)別, Variable是需要計(jì)算gradient的,Tensor是不需要計(jì)算gradient的Variable的AutogradMeta是對(duì)[AutogradMetaInterface](c10::AutogradMetaInterface)的實(shí)現(xiàn),里面包含了一個(gè)Variable,就是該variable的gradient帶有version和view 會(huì)實(shí)例化 AutogradMeta, autograd需要的關(guān)鍵信息都在這里
AutoGradMeta : 記錄 Variable 的autograd歷史信息
包含一個(gè)叫g(shù)rad_的 Variable, 即AutoGradMeta對(duì)應(yīng)的var的梯度tensor包含類(lèi)型為 Node指針的grad_fn(var在graph內(nèi)部時(shí))和grad_accumulator(var時(shí)葉子時(shí)), 記錄生成grad_的方法包含 output_nr,標(biāo)識(shí)var對(duì)應(yīng)grad_fn的輸入編號(hào)構(gòu)造函數(shù)包含一個(gè)類(lèi)型為 Edge的gradient_edge,gradient_edge.function就是grad_fn, 另外gradient_edge.input_nr記錄著對(duì)應(yīng)grad_fn的輸入編號(hào),會(huì)賦值給AutoGradMeta的output_nr
autograd::Edge: 指向autograd::Node的一個(gè)輸入
包含類(lèi)型為 Node指針,表示edge指向的Node包含 input_nr, 表示edge指向的Node的輸入編號(hào)
autograd::Node: 對(duì)應(yīng)AutoGrad Graph中的Op
是所有autograd op的抽象基類(lèi),子類(lèi)重載apply方法
next_edges_記錄出邊input_metadata_記錄輸入的tensor的metadata實(shí)現(xiàn)的子類(lèi)一般是可求導(dǎo)的函數(shù)和他們的梯度計(jì)算op
Node in AutoGrad Graph
Variable通過(guò)Edge關(guān)聯(lián)Node的輸入和輸出 多個(gè)Edge指向同一個(gè)Var時(shí),默認(rèn)做累加 call operator
最重要的方法,實(shí)現(xiàn)計(jì)算 next_edge
縫合Node的操作 獲取Node的出邊,next_edge(index)/next_edges() add_next_edge(),創(chuàng)建
前向計(jì)算
PyTorch通過(guò)tracing只生成了后向AutoGrad Graph.
代碼是生成的,需要編譯才能看到對(duì)應(yīng)的生成結(jié)果
gen_variable_type.py生成可導(dǎo)版本的op 生成的代碼在 pytorch/torch/csrc/autograd/generated/前向計(jì)算時(shí),進(jìn)行了tracing,記錄了后向計(jì)算圖構(gòu)建需要的信息 這里以relu為例,代碼在 pytorch/torch/csrc/autograd/generated/VariableType_0.cpp
Tensor relu(const Tensor & self) {auto& self_ = unpack(self, "self", 0);std::shared_ptr<ReluBackward0> grad_fn;if (compute_requires_grad( self )) { // 如果輸入var需要grad// ReluBackward0的類(lèi)型是Nodegrad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);// collect_next_edges(var)返回輸入var對(duì)應(yīng)的指向的// grad_fn(前一個(gè)op的backward或者是一個(gè)accumulator的)的輸入的Edge// set_next_edges(),在grad_fn中記錄這些Edge(這里完成了后向的構(gòu)圖)grad_fn->set_next_edges(collect_next_edges( self ));// 記錄當(dāng)前var的一個(gè)版本grad_fn->self_ = SavedVariable(self, false);}c10::optional<Storage> self__storage_saved =self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;c10::intrusive_ptr<TensorImpl> self__impl_saved;if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();auto tmp = ([&]() {at::AutoNonVariableTypeMode non_var_type_mode(true);return at::relu(self_); // 前向計(jì)算})();auto result = std::move(tmp);if (self__storage_saved.has_value())AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());if (grad_fn) {// grad_fn增加一個(gè)輸入,記錄輸出var的metadata作為grad_fn的輸入// 輸出var的AutoGradMeta實(shí)例化,輸出var的AutoGradMeta指向起grad_fn的輸入set_history(flatten_tensor_args( result ), grad_fn);}return result;}
可以看到和 grad_fn相關(guān)的操作trace了一個(gè)op的計(jì)算,構(gòu)建了后向計(jì)算圖.
后向計(jì)算
autograd::backward():計(jì)算output var的梯度值,調(diào)用的 run_backward()
autograd::grad() :計(jì)算有output var和到特定input的梯度值,調(diào)用的 run_backward()
autograd::run_backward()
對(duì)于要求梯度的output var,獲取其指向的grad_fn作為roots,是后向圖的起點(diǎn) 對(duì)于有input var的,獲取其指向的grad_fn作為output_edges, 是后向圖的終點(diǎn) 調(diào)用 autograd::Engine::get_default_engine().execute(...)執(zhí)行后向計(jì)算
autograd::Engine::execute(...)
創(chuàng)建
GraphTask,記錄了一些配置信息創(chuàng)建
GraphRoot,是一個(gè)Node,把所有的roots作為其輸出邊,Node的apply()返回的是roots的grad【這里已經(jīng)得到一個(gè)單起點(diǎn)的圖】計(jì)算依賴(lài)
compute_dependencies(...)從GraphRoot開(kāi)始,廣度遍歷,記錄所有碰到的grad_fn的指針,并統(tǒng)計(jì)grad_fn被遇到的次數(shù),這些信息記錄到GraphTask中 GraphTask初始化:當(dāng)有input var時(shí),判斷后向圖中哪些節(jié)點(diǎn)是真正需要計(jì)算的GraphTask執(zhí)行選擇CPU or GPU線(xiàn)程執(zhí)行 以CPU為例,調(diào)用的 autograd::Engine::thread_main(...)
autograd::Engine::thread_main(...)
evaluate_function(...),輸入輸出的處理,調(diào)度call_function(...), 調(diào)用對(duì)應(yīng)的Node計(jì)算執(zhí)行后向過(guò)程中的生成的中間grad Tensor,如果不釋放,可以用于計(jì)算高階導(dǎo)數(shù);(同構(gòu)的后向圖,之前的grad tensor是新的輸出,grad_fn變成之前grad_fn的backward,這些新的輸出還可以再backward) 具體的執(zhí)行機(jī)制可以支撐單獨(dú)開(kāi)一個(gè)Topic分析,在這里討論到后向圖完成構(gòu)建為止.
推薦閱讀
2021-04-11
2021-04-08
2021-04-07

# CV技術(shù)社群邀請(qǐng)函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)
即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~

