實(shí)操教程|PyTorch AutoGrad C++層實(shí)現(xiàn)
點(diǎn)擊上方“程序員大白”,選擇“星標(biāo)”公眾號(hào)
重磅干貨,第一時(shí)間送達(dá)

極市導(dǎo)讀
本文為一篇實(shí)操教程,作者介紹了PyTorch AutoGrad C++層實(shí)現(xiàn)中各個(gè)概念的解釋。
autograd依賴的數(shù)據(jù)結(jié)構(gòu)
at::Tensor:shared ptr 指向 TensorImpl
TensorImpl:對(duì) at::Tensor 的實(shí)現(xiàn)
包含一個(gè)類型為 [AutogradMetaInterface](c10::AutogradMetaInterface)的autograd_meta_,在tensor是需要求導(dǎo)的variable時(shí),會(huì)被實(shí)例化為[AutogradMeta](c10::AutogradMetaInterface),里面包含了autograd需要的信息
Variable: 就是Tensor,為了向前兼容保留的
using Variable = at::Tensor; 概念上有區(qū)別, Variable是需要計(jì)算gradient的,Tensor是不需要計(jì)算gradient的Variable的AutogradMeta是對(duì)[AutogradMetaInterface](c10::AutogradMetaInterface)的實(shí)現(xiàn),里面包含了一個(gè)Variable,就是該variable的gradient帶有version和view 會(huì)實(shí)例化 AutogradMeta, autograd需要的關(guān)鍵信息都在這里
AutoGradMeta : 記錄 Variable 的autograd歷史信息
包含一個(gè)叫g(shù)rad_的 Variable, 即AutoGradMeta對(duì)應(yīng)的var的梯度tensor包含類型為 Node指針的grad_fn(var在graph內(nèi)部時(shí))和grad_accumulator(var時(shí)葉子時(shí)), 記錄生成grad_的方法包含 output_nr,標(biāo)識(shí)var對(duì)應(yīng)grad_fn的輸入編號(hào)構(gòu)造函數(shù)包含一個(gè)類型為 Edge的gradient_edge,gradient_edge.function就是grad_fn, 另外gradient_edge.input_nr記錄著對(duì)應(yīng)grad_fn的輸入編號(hào),會(huì)賦值給AutoGradMeta的output_nr
autograd::Edge: 指向autograd::Node的一個(gè)輸入
包含類型為 Node指針,表示edge指向的Node包含 input_nr, 表示edge指向的Node的輸入編號(hào)
autograd::Node: 對(duì)應(yīng)AutoGrad Graph中的Op
是所有autograd op的抽象基類,子類重載apply方法
next_edges_記錄出邊input_metadata_記錄輸入的tensor的metadata實(shí)現(xiàn)的子類一般是可求導(dǎo)的函數(shù)和他們的梯度計(jì)算op
Node in AutoGrad Graph
Variable通過(guò)Edge關(guān)聯(lián)Node的輸入和輸出 多個(gè)Edge指向同一個(gè)Var時(shí),默認(rèn)做累加 call operator
最重要的方法,實(shí)現(xiàn)計(jì)算 next_edge
縫合Node的操作 獲取Node的出邊,next_edge(index)/next_edges() add_next_edge(),創(chuàng)建
前向計(jì)算
PyTorch通過(guò)tracing只生成了后向AutoGrad Graph.
代碼是生成的,需要編譯才能看到對(duì)應(yīng)的生成結(jié)果
gen_variable_type.py生成可導(dǎo)版本的op 生成的代碼在 pytorch/torch/csrc/autograd/generated/前向計(jì)算時(shí),進(jìn)行了tracing,記錄了后向計(jì)算圖構(gòu)建需要的信息 這里以relu為例,代碼在 pytorch/torch/csrc/autograd/generated/VariableType_0.cpp
Tensor relu(const Tensor & self) {auto& self_ = unpack(self, "self", 0);std::shared_ptr<ReluBackward0> grad_fn;if (compute_requires_grad( self )) { // 如果輸入var需要grad// ReluBackward0的類型是Nodegrad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);// collect_next_edges(var)返回輸入var對(duì)應(yīng)的指向的// grad_fn(前一個(gè)op的backward或者是一個(gè)accumulator的)的輸入的Edge// set_next_edges(),在grad_fn中記錄這些Edge(這里完成了后向的構(gòu)圖)grad_fn->set_next_edges(collect_next_edges( self ));// 記錄當(dāng)前var的一個(gè)版本grad_fn->self_ = SavedVariable(self, false);}c10::optional<Storage> self__storage_saved =self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;c10::intrusive_ptr<TensorImpl> self__impl_saved;if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();auto tmp = ([&]() {at::AutoNonVariableTypeMode non_var_type_mode(true);return at::relu(self_); // 前向計(jì)算})();auto result = std::move(tmp);if (self__storage_saved.has_value())AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());if (grad_fn) {// grad_fn增加一個(gè)輸入,記錄輸出var的metadata作為grad_fn的輸入// 輸出var的AutoGradMeta實(shí)例化,輸出var的AutoGradMeta指向起grad_fn的輸入set_history(flatten_tensor_args( result ), grad_fn);}return result;}
可以看到和 grad_fn相關(guān)的操作trace了一個(gè)op的計(jì)算,構(gòu)建了后向計(jì)算圖.
后向計(jì)算
autograd::backward():計(jì)算output var的梯度值,調(diào)用的 run_backward()
autograd::grad() :計(jì)算有output var和到特定input的梯度值,調(diào)用的 run_backward()
autograd::run_backward()
對(duì)于要求梯度的output var,獲取其指向的grad_fn作為roots,是后向圖的起點(diǎn) 對(duì)于有input var的,獲取其指向的grad_fn作為output_edges, 是后向圖的終點(diǎn) 調(diào)用 autograd::Engine::get_default_engine().execute(...)執(zhí)行后向計(jì)算
autograd::Engine::execute(...)
創(chuàng)建
GraphTask,記錄了一些配置信息創(chuàng)建
GraphRoot,是一個(gè)Node,把所有的roots作為其輸出邊,Node的apply()返回的是roots的grad【這里已經(jīng)得到一個(gè)單起點(diǎn)的圖】計(jì)算依賴
compute_dependencies(...)從GraphRoot開始,廣度遍歷,記錄所有碰到的grad_fn的指針,并統(tǒng)計(jì)grad_fn被遇到的次數(shù),這些信息記錄到GraphTask中 GraphTask初始化:當(dāng)有input var時(shí),判斷后向圖中哪些節(jié)點(diǎn)是真正需要計(jì)算的GraphTask執(zhí)行選擇CPU or GPU線程執(zhí)行 以CPU為例,調(diào)用的 autograd::Engine::thread_main(...)
autograd::Engine::thread_main(...)
evaluate_function(...),輸入輸出的處理,調(diào)度call_function(...), 調(diào)用對(duì)應(yīng)的Node計(jì)算執(zhí)行后向過(guò)程中的生成的中間grad Tensor,如果不釋放,可以用于計(jì)算高階導(dǎo)數(shù);(同構(gòu)的后向圖,之前的grad tensor是新的輸出,grad_fn變成之前grad_fn的backward,這些新的輸出還可以再backward) 具體的執(zhí)行機(jī)制可以支撐單獨(dú)開一個(gè)Topic分析,在這里討論到后向圖完成構(gòu)建為止.
推薦閱讀
國(guó)產(chǎn)小眾瀏覽器因屏蔽視頻廣告,被索賠100萬(wàn)(后續(xù))
年輕人“不講武德”:因看黃片上癮,把網(wǎng)站和786名女主播起訴了
關(guān)于程序員大白
程序員大白是一群哈工大,東北大學(xué),西湖大學(xué)和上海交通大學(xué)的碩士博士運(yùn)營(yíng)維護(hù)的號(hào),大家樂(lè)于分享高質(zhì)量文章,喜歡總結(jié)知識(shí),歡迎關(guān)注[程序員大白],大家一起學(xué)習(xí)進(jìn)步!

