一個(gè)Tensor在深度學(xué)習(xí)框架中的執(zhí)行過(guò)程簡(jiǎn)單梳理
?0x0. 前言撰文:BBuf。審稿:王迎港。
相信看到這篇文章的人都對(duì)深度學(xué)習(xí)框架是有所了解和熟悉的,也多多少少會(huì)使用Python寫(xiě)一些神經(jīng)網(wǎng)絡(luò)相關(guān)的代碼。例如我們可以在PyTorch寫(xiě)出下面的代碼:
import?torch
x?=?torch.tensor([-1.0,?2.0],?device="cuda")
y?=?torch.relu(x)
print(y)
使用PyTorch運(yùn)行之后我們會(huì)獲得如下結(jié)果:
tensor([0.,?2.],?device='cuda:0')
對(duì)于x這個(gè)輸入Tensor來(lái)說(shuō),它被喂給了relu這個(gè)Op,然后輸出結(jié)果,一切看起來(lái)都很簡(jiǎn)單和正常。但如果有人問(wèn)你是否清楚這背后到底發(fā)生了什么,relu這個(gè)Op對(duì)應(yīng)的Cuda Kernel是在什么時(shí)候被GPU調(diào)用的,相信一部分人是不會(huì)很清楚的。因?yàn)榘ㄎ业拇蠖鄶?shù)人習(xí)慣在舒適區(qū)使用深度學(xué)習(xí)框架,對(duì)背后的原理可能沒(méi)有深入了解,所以回答不了也很正常。
這篇文章我就將嘗試解開(kāi)這個(gè)問(wèn)題,但我并不是以PyTorch為例來(lái)講解,而是以O(shè)neFlow為例子。為什么以O(shè)neFlow為例子呢?首先我在OneFlow工作,對(duì)這背后的執(zhí)行機(jī)制比PyTorch要清楚一些,在調(diào)用鏈跟蹤的時(shí)候會(huì)更流暢。其次,OneFlow背后這套運(yùn)行機(jī)制含有挺多PyTorch不存在的設(shè)計(jì)思想,相信讀者看完之后對(duì)深度學(xué)習(xí)框架系統(tǒng)設(shè)計(jì)方面有更多的思考和啟發(fā)。
所以,接下來(lái)就一起看看一個(gè)Tensor在OneFlow深度學(xué)習(xí)框架中的執(zhí)行過(guò)程吧。為了簡(jiǎn)單起見(jiàn),本文只考慮單機(jī)單卡模式下的Op執(zhí)行過(guò)程,不涉及OneFlow特有的consistent模式(和分布式相關(guān)),如果你對(duì)這部分感興趣可以自行查看。
0x1. Python和C++的橋梁當(dāng)我們敲下如下代碼并將其移交給OneFlow執(zhí)行時(shí):
import?oneflow?as?flow
x?=?flow.tensor([-1.0,?2.0],?device="cuda")
y?=?flow.relu(x)
print(y)
系統(tǒng)首先創(chuàng)建了一個(gè)在GPU上的輸入Tensor,然后調(diào)用了導(dǎo)出到python端的c++ functional接口relu。這里涉及到pybind11綁定相關(guān)的Python wrapper和C++ relu functor。這個(gè)交互的上層,同事在OneFlow學(xué)習(xí)筆記:python到C++調(diào)用過(guò)程分析 這篇文章有解析過(guò)了,感興趣可以看看。我們上面Python代碼中的flow.relu這個(gè)Op最終調(diào)用的是ReLU C++ Functor的實(shí)現(xiàn),我們看一下代碼。
class?ReluFunctor?{
?public:
??ReluFunctor()?{?op_?=?CHECK_JUST(one::OpBuilder("relu").Input("x",?1).Output("y",?1).Build());?}
??Maybe?operator()(const?std::shared_ptr&?x,?bool?inplace) ?const? {
????if?(inplace)?{
??????...
????}?else?{
??????return?OpInterpUtil::Dispatch(*op_,?{x});
????}
??}
?private:
??std::shared_ptr?op_;
};
這段代碼里面的op_是一個(gè)OpExpr的指針,然后在構(gòu)造函數(shù)里面調(diào)用了OpBuilder函數(shù)來(lái)創(chuàng)建了一個(gè)新的OpExpr。從后面的實(shí)際調(diào)用代碼OpInterpUtil::Dispatch可以發(fā)現(xiàn)這里的算子構(gòu)建和執(zhí)行是分開(kāi)的(因?yàn)镈ispatch函數(shù)是同時(shí)將OpExpr和輸入Tensor等分發(fā)出去,沒(méi)有直接分發(fā)執(zhí)行的結(jié)果Tensor出去,所以這里還沒(méi)有真正的執(zhí)行Op),這里的OpInterpUtil::Dispatch是負(fù)責(zé)將OpExpr,輸入Tensor和其它參數(shù)(ReLU這個(gè)算子沒(méi)有除輸入外的參數(shù))分發(fā)出去,還沒(méi)有真正的執(zhí)行。
OpExpr可以簡(jiǎn)單理解為是OneFlow算子的統(tǒng)一抽象。OpExpr大體可以分為BuiltinOpExpr、FunctionOpExpr和其他類(lèi)別的OpExpr,其中BuiltinOpExpr又可以細(xì)分為UserOpExpr和其他非UserOpExpr,用戶(hù)可以通過(guò)OpBuilder構(gòu)建出UserOpExpr。
不需要完全理解OpExpr的定義,我們只需要知道這里是通過(guò)OpBuilder類(lèi)構(gòu)造了一個(gè)新的OpExpr,這個(gè)OpExpr有Op name,UserOpConf proto_這個(gè)序列化Op信息的ProtoBuf對(duì)象,以及輸入輸出Tensor的名字等關(guān)鍵信息。然后順著這個(gè)Dispatch函數(shù)可以發(fā)現(xiàn)最后在oneflow/core/framework/op_interpreter/op_interpreter_util.cpp中調(diào)用到了GetInterpreter函數(shù)的Apply方法:
/*?static?*/?Maybe<void>?OpInterpUtil::Dispatch(const?OpExpr&?op_expr,?const?TensorTuple&?inputs,
????????????????????????????????????????????????TensorTuple*?outputs,
????????????????????????????????????????????????const?OpExprInterpContext&?ctx)?{
??return?JUST(GetInterpreter(inputs,?ctx,?op_expr))->Apply(op_expr,?inputs,?outputs,?ctx);
}
這里的OpExprInterpContext對(duì)象會(huì)存儲(chǔ)Op的動(dòng)態(tài)屬性,設(shè)備信息,分布式信息等,對(duì)于Relu Functor來(lái)說(shuō),這里為空,所以我們這里不關(guān)注這個(gè)對(duì)象。再往下跟就屬于InterPreter的內(nèi)容了,新開(kāi)一節(jié)來(lái)講。
0x2. Interpreter從上面的Op調(diào)用流程可以看出,我們?cè)赑ython層的Op實(shí)際上是調(diào)用的導(dǎo)出到Python的Functor接口,而Functor接口會(huì)將OpExpr,輸入Tensor和動(dòng)態(tài)屬性attr遞交給Interpreter來(lái)處理,因?yàn)樯厦娴?code style="font-size:14px;background-color:rgba(27,31,35,.05);font-family:'Operator Mono', Consolas, Monaco, Menlo, monospace;color:#f48a00;">GetInterpreter函數(shù)獲取的就是一個(gè)Interpreter對(duì)象。Interpreter這個(gè)類(lèi)就是專(zhuān)門(mén)用來(lái)解釋Op執(zhí)行過(guò)程的,上一節(jié)在Relu Functor里面的Dispatch就是把任務(wù)分發(fā)到Interpreter來(lái)執(zhí)行。OneFlow的Interpreter又分為幾種類(lèi)型,如Eager Mirrored Interpreter,Eager Consistent Interpreter和LazyInterpreter,我們這篇文章的例子沒(méi)有考慮分布式信息,所以輸入Tensor都是Eager Mirroed Tensor,所以走的是Eager Mirrored Interpreter這個(gè)調(diào)用鏈。Mirrored Tensor和PyTorch的Tensor類(lèi)似,在各個(gè)Rank上是獨(dú)立的。
再往下跟一下我們發(fā)現(xiàn)上面的Apply實(shí)際上調(diào)用的是oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp文件中的NaiveInterpret函數(shù),這個(gè)函數(shù)接收OpExpr對(duì)象,輸入輸出Tensor和一個(gè)OpExprInterpContext對(duì)象來(lái)對(duì)Op的device,輸出dtype,輸出shape等進(jìn)行推導(dǎo),然后根據(jù)推導(dǎo)的元信息(元信息對(duì)應(yīng)TensorMeta類(lèi)對(duì)象,把 Tensor 的基本信息:shape, dtype, stride 等抽出來(lái)一個(gè)類(lèi)型,放一起方便管理)構(gòu)造分別對(duì)應(yīng)輸入輸出的BlobObject對(duì)象input_eager_blob_objects和output_eager_blob_objects(可理解為輸入輸出Tensor的數(shù)據(jù)指針),另外還會(huì)根據(jù)OpExpr和推導(dǎo)后的device構(gòu)造一個(gè)特定執(zhí)行kernel。最后將執(zhí)行kernel,輸入輸出Tensor的數(shù)據(jù)指針以及OpExprInterpContext對(duì)象以指令的方式發(fā)給OneFlow的虛擬機(jī)(VM,可以理解為OneFlow的Eager運(yùn)行時(shí),后面會(huì)細(xì)講)執(zhí)行并獲得結(jié)果。
這里我們分段看一下NaiveInterpret的實(shí)現(xiàn)。第一段:
Maybe<void>?NaiveInterpret(const?UserOpExpr&?user_op_expr,?const?TensorTuple&?inputs,
???????????????????????????const?Symbol&?default_device,?TensorTuple*?outputs,
???????????????????????????const?OpExprInterpContext&?ctx) ?{
??const?auto&?attrs?=?ctx.attrs;
??std::shared_ptr?input_eager_blob_objects?=
??????std::make_shared(inputs.size());
??for?(int?i?=?0;?i?????const?auto&?input_device?=?JUST(inputs.at(i)->device());
????if?(i?>?0)?{
??????CHECK_OR_RETURN(*default_device?==?*input_device)?<????}
????input_eager_blob_objects->at(i)?=?JUST(inputs.at(i)->eager_blob_object());
??}
上面這段代碼遍歷輸入Tensor的列表,將每一個(gè)輸入Tensor的device和函數(shù)傳入的默認(rèn)device進(jìn)行比較,如果發(fā)現(xiàn)輸入Tensor的device和默認(rèn)device不一致就拋出異常??梢詫?duì)類(lèi)似輸入Tensor在CPU上,但nn.Module在GPU上的例子進(jìn)行錯(cuò)誤檢查,輸出設(shè)備不匹配的錯(cuò)誤信息。如果設(shè)備都匹配上了,這個(gè)時(shí)候會(huì)將輸入Tensor的eager_blob_object添加到input_eager_blob_objects這個(gè)列表中。輸入Tensor的eager_blob_object是一個(gè)EagerBlobObject類(lèi)型的指針,是輸入Tensor的數(shù)據(jù)指針,后續(xù)通過(guò)它和OneFlow的虛擬機(jī)(VM)進(jìn)行交互。
?這里要補(bǔ)充說(shuō)明一下OneFlow中Tensor,TensorImpl,TensorMeta和BlobObject的關(guān)系。Tensor 和 TensorImpl 用了橋接設(shè)計(jì)模式,Tensor 負(fù)責(zé)向上和 python 接口、autograd 的對(duì)接;TensorImpl 是向下負(fù)責(zé)真實(shí)數(shù)據(jù)這部分。TensorMeta 就是把 Tensor 的基本信息:shape, dtype, stride 等抽出來(lái)一個(gè)類(lèi)型,放一起方便管理。BlobObject是真正的數(shù)據(jù)對(duì)象,數(shù)據(jù)指針在這個(gè)對(duì)象中,這個(gè)類(lèi)被虛擬機(jī)使用來(lái)完成指令的計(jì)算任務(wù)。
第二段:
std::shared_ptr?output_eager_blob_objects?=
??????std::make_shared(outputs->size());
??auto*?output_tensor_metas?=?ThreadLocalDefaultOutputMutTensorMetas(outputs->size());
??for?(int?i?=?0;?i?size();?i++)?{
????if?(!outputs->at(i))?{
??????const?auto&?tensor_impl?=?std::make_shared();
??????outputs->at(i)?=?std::make_shared(tensor_impl);
??????output_tensor_metas->at(i)?=?tensor_impl->mut_tensor_meta();
????}?else?{
??????bool?has_eager_blob_object?=?JUST(outputs->at(i)->has_eager_blob_object());
??????CHECK_OR_RETURN(has_eager_blob_object);
??????output_eager_blob_objects->at(i)?=?JUST(outputs->at(i)->eager_blob_object());
????}
??}
這里首先聲明了一個(gè)EagerBlobObjectList類(lèi)型的指針output_eager_blob_objects 以及存儲(chǔ)輸出Tensor元信息的output_tensor_metas,然后遍歷輸出Tensor列表判斷第i個(gè)Tensor是否已經(jīng)有值,如果沒(méi)有就申請(qǐng)一個(gè)MirroredTensor類(lèi)型的指針并初始化為tensor_impl這個(gè)對(duì)象,并將output_tensor_metas在索引i處的值更新為tensor_impl的Tensor元信息,為接下來(lái)的形狀和類(lèi)型推導(dǎo)做準(zhǔn)備(這里如果有值的話,那就是 inplace 調(diào)用了,如果加一些判斷,可以發(fā)現(xiàn)有值的 BlobObject 和某個(gè)輸入的 BlobObject 是同一個(gè)對(duì)象)。如果這個(gè)輸出Tensor已經(jīng)有值了(inplace模式),那么就判斷它是否存在EagerBlobObject類(lèi)型的數(shù)據(jù)指針,如果存在就將這個(gè)數(shù)據(jù)指針取出來(lái)放到剛才申請(qǐng)好的EagerBlobObjectList類(lèi)型的output_eager_blob_objects列表里。后續(xù)的shape推導(dǎo)和dtype推導(dǎo)也將用到這個(gè)output_eager_blob_objects。
第三段:
Symbol?op_device;
??bool?need_check_mem_case?=?true;
??//?Infer?devices
??if?(!user_op_expr.has_device_infer_fn())?{
????op_device?=?default_device;
????for?(int?i?=?0;?i?size();?i++)?{
??????auto*?tensor_impl?=?JUST(TensorImpl4Tensor(outputs->at(i)));
??????*JUST(tensor_impl->mut_device())?=?default_device;
????}
??}?else?{
????need_check_mem_case?=?false;
????op_device?=?JUST(user_op_expr.InferDevices(attrs,?inputs,?outputs));
??}
??//?Infer?shapes?and?dtypes
??const?auto&?device_tag?=?JUST(op_device->of_type());
??JUST(user_op_expr.InferPhysicalShapeAndDType(
??????attrs,?device_tag,
??????[&](int32_t?i)?->?const?TensorMeta*?{
????????return?CHECK_JUST(TensorImpl4Tensor(inputs.at(i)))->mut_tensor_meta();
??????},
??????[&](int32_t?i)?->?TensorMeta*?{
????????//?using?thread_local?TensorMeta?pointer?if?inplace.
????????//?using?tensor_impl?TensorMeta?pointer?if?not?inplace.
????????return?output_tensor_metas->at(i);
??????}));
??for?(int?i?=?0;?i?size();?i++)?{
????auto*?tensor_impl?=?JUST(TensorImpl4Tensor(outputs->at(i)));
????if?(!output_eager_blob_objects->at(i))?{
??????tensor_impl->mut_tensor_meta()->set_stride(std::make_shared(*tensor_impl->shape()));
??????const?auto&?dep_object?=?JUST(GetLocalDepObjectFromDevicePool(op_device));
??????JUST(tensor_impl->InitEagerBlobObject(dep_object));
??????output_eager_blob_objects->at(i)?=?JUST(tensor_impl->eager_blob_object());
????}?else?{
??????//?output?i?is?inplaced.
??????//?check?thread_local?TensorMeta?and?tensor_impl?TensorMeta.
??????CHECK_OR_RETURN(tensor_impl->tensor_meta()->shape()?==?output_tensor_metas->at(i)->shape());
??????CHECK_OR_RETURN(tensor_impl->tensor_meta()->dtype()?==?output_tensor_metas->at(i)->dtype());
????}
??}
這一段代碼是Op的device,shape和dtype推導(dǎo)。user_op_expr.has_device_infer_fn()用來(lái)判斷當(dāng)前的OpExpr是否存在device信息推導(dǎo)函數(shù),如果沒(méi)有就將輸出Tensor的device信息更新為當(dāng)前的default_device。如果有就直接從user_op_expr取出來(lái)即可。這里是否推導(dǎo)過(guò)在注冊(cè)User Op的時(shí)候就已經(jīng)決定了,我們可以在oneflow/core/framework/op_expr.cpp這里的UserOpExpr::Init看到對(duì)注冊(cè)器是否有device推導(dǎo)函數(shù)的判斷,另外我們可以在oneflow/ir/include/OneFlow/OneFlowUserOps.td這個(gè)td文件中看到哪些Op實(shí)現(xiàn)了device推導(dǎo)函數(shù)。
接下來(lái)調(diào)用了OpExpr中的InferPhysicalShapeAndDType完成對(duì)輸出Tensor的shape和dtype推導(dǎo)。跟進(jìn)InferPhysicalShapeAndDType函數(shù)可以發(fā)現(xiàn)它實(shí)際調(diào)用了注冊(cè)User Op時(shí)定義的shape推導(dǎo)和dtype推導(dǎo)函數(shù)。然后會(huì)遍歷output_eager_blob_objects并基于已經(jīng)推導(dǎo)出的TensorMeta對(duì)它做更新或者檢查(這里的TensorMeta檢查就是因?yàn)樯厦嫣岬降目赡艽嬖诘腎nplace的情況,inplace 前后的TensorMeta不能改變)。
最后一段:
const?auto&?kernel?=?JUST(user_op_expr.MutKernel4Device(op_device));
??kernel->set_need_check_mem_case(need_check_mem_case);
??for?(int64_t?index?:?kernel->output_tuple_indexes4mut2_obns())?{
????output_eager_blob_objects->at(index)->set_is_shape_synced(false);
??}
??JUST(PhysicalRun([&](InstructionsBuilder*?builder)?->?Maybe<void>?{
????return?builder->LocalCallOpKernel(kernel,?input_eager_blob_objects,?output_eager_blob_objects,
??????????????????????????????????????ctx,?op_device);
??}));
最后一段代碼就是Interpreter和VM交互時(shí)最關(guān)鍵的一步了,這里用user_op_expr.MutKernel4Device構(gòu)造了在op_device上的StatefulOpKernel ,并將output_eager_blob_objects中每個(gè)EagerBlobObject對(duì)象的is_shape_synced_屬性設(shè)置為False,這個(gè)is_shape_synced_設(shè)置為False代表輸出Tensor的形狀是在運(yùn)行時(shí)確定的,要Kernel執(zhí)行完之后才能獲得輸出Tensor的shape。為什么這里要默認(rèn)都設(shè)置為False呢?因?yàn)閷?duì)于一個(gè) Op 來(lái)說(shuō),它的 shape 是不是需要推導(dǎo)是 Op 自己的屬性,這里默認(rèn)會(huì)給一個(gè) false。然后在 StatefulOpKernel 那里還有個(gè) flag,這里就真正知道哪些 op 是動(dòng)態(tài) shape 了,如果不是動(dòng)態(tài) shape,就給這個(gè) flag 置為 True,表示已經(jīng)同步(不用同步)。這里的builder->LocalCallOpKernel函數(shù)就是在構(gòu)建虛擬機(jī)(VM)的指令,而PhysicalRun負(fù)責(zé)給虛擬機(jī)發(fā)送這個(gè)指令并執(zhí)行獲得最終結(jié)果。
OneFlow Eager的運(yùn)行時(shí)被抽象為虛擬機(jī)(VM)。當(dāng)我們執(zhí)行flow.relu(x)這句代碼時(shí),會(huì)通過(guò)上面的Interpreter發(fā)一個(gè)LocalCallOpKernel指令給VM。VM再執(zhí)行這個(gè)指令的時(shí)候會(huì)為輸出Tensor申請(qǐng)顯存,調(diào)用ReLU的Cuda Kernel進(jìn)行計(jì)算并將計(jì)算結(jié)果寫(xiě)到輸出Tensor。
我們先介紹一下虛擬機(jī)一些概念,然后再追關(guān)鍵代碼進(jìn)一步說(shuō)明。
OneFlow程序在運(yùn)行期間虛擬機(jī)會(huì)在后臺(tái)不斷的輪詢(xún),如果有新的可以執(zhí)行的指令就執(zhí)行,沒(méi)有就繼續(xù)輪詢(xún)。虛擬機(jī)有兩種線程,稱(chēng)作scheduler線程以及worker線程(如果我們運(yùn)行Python腳本,Python腳本是在主線程也叫main線程中運(yùn)行)。虛擬機(jī)的輪詢(xún)是在scheduler線程中,而worker線程則是處理一些阻塞的操作,這種操作比較慢不適合放到scheduler線程里面做。
剛才我們已經(jīng)多次提到指令這個(gè)名詞,虛擬機(jī)執(zhí)行的最小單位就是指令。OneFlow中的指令類(lèi)型有AccessBlobByCallback,LocalCallOpKernel,ReleaseTensor等。AccessBlobByCallback用于讀取和修改Blob的值的指令,而LocalCallOpKernel是運(yùn)行一個(gè)Op的指令,ReleaseTensor就是釋放聲明周期已經(jīng)結(jié)束的Tensor的內(nèi)存。每一種指令都會(huì)攜帶一個(gè)parallel_desc表示指令在哪些設(shè)備上執(zhí)行(例如只在 1 號(hào)卡上執(zhí)行,或在所有的卡上執(zhí)行),還會(huì)綁定一個(gè) StreamType,表示指令在哪種 Stream 上執(zhí)行(在我們文章開(kāi)頭舉的例子中,ReLU對(duì)應(yīng)的LocalCallOpKernel就是在CudaStream上執(zhí)行)。以LocalCallOpKernel為例,根據(jù)StreamType的不同有以下類(lèi)型的指令:
Maybe<const?std::string&>?GetLocalCallInstructionName(const?std::string&?type)?{
??static?const?HashMap<std::string,?std::string>?type2instr_name{
??????{"cpu",?"cpu.LocalCallOpKernel"},
??????{"gpu",?"gpu.LocalCallOpKernel"},
??????{"cuda",?"gpu.LocalCallOpKernel"},
??????{"cuda_h2d",?"cuda_h2d.LocalCallOpKernel"},
??????{"cuda_d2h",?"cuda_d2h.LocalCallOpKernel"},
??????{"comm_net",?"cpu.LocalCallOpKernel"},
??????{"sync_launched_nccl",?"gpu.LocalCallOpKernel"},
??????{"async_launched_nccl",?"async.gpu.LocalCallOpKernel"},
??????//?no?compute?instruction?on?critical_section?device.
??????{"critical_section",?"UNIMPLEMENTED?INSTRUCTION?NAME"},
??};
??return?MapAt(type2instr_name,?type);
}
以cpu.LocalCallOpKernel指令來(lái)看就將它的stram_type綁定為CpuStreamType,在oneflow/core/eager/cpu_opkernel_instruction_type.cpp的定義如下:
class?CpuLocalCallOpKernelInstructionType?final?:?public?LocalCallOpKernelInstructionType?{
?public:
??CpuLocalCallOpKernelInstructionType()?=?default;
??~CpuLocalCallOpKernelInstructionType()?override?=?default;
??using?stream_type?=?vm::CpuStreamType;?//?綁定stream_type
?private:
??const?char*?device_tag()?const?override?{?return?stream_type().device_tag();?}
};
COMMAND(vm::RegisterInstructionType("cpu.LocalCallOpKernel"));
每種StreamType都可以設(shè)置這種類(lèi)型的Stream是否工作在scheduler線程上,初始化和查詢(xún)指令狀態(tài),完成指令計(jì)算等工作。
?這里的Stream是虛擬機(jī)里面的device抽象,每一種Stream對(duì)應(yīng)一種device。另外指令都有Infer和Compute過(guò)程,Infer是推導(dǎo)元信息,而Compute才是真正的啟動(dòng)計(jì)算Kernel進(jìn)行執(zhí)行。
接下來(lái)我們看看指令間的依賴(lài)關(guān)系,虛擬機(jī)的指令是亂序執(zhí)行的,但對(duì)有依賴(lài)關(guān)系的指令的執(zhí)行順序也是有要求的。例如用戶(hù)發(fā)射了a和b兩條指令,然后a指令要修改Blob c的值,但b指令要讀取Blob c的值,那a指令就得先于b指令執(zhí)行。
那么指令間的依賴(lài)關(guān)系是如何構(gòu)建的呢?指令間的依賴(lài)關(guān)系是依靠指令攜帶的操作數(shù)來(lái)實(shí)現(xiàn)的,操作數(shù)的主要類(lèi)型有 const、mut、mut2。const 對(duì)應(yīng)輸入(讀?。?,mut 和 mut2 對(duì)應(yīng)輸出(寫(xiě)入)。上述的 a 指令有一個(gè) mut operand c,b 指令有一個(gè) const operand c。這樣,通過(guò)檢查 a 和 b 指令中 c 的類(lèi)型,就可以在 a 和 b 之間建立依賴(lài)關(guān)系:b 的 infer 一定要在 a infer 完成之后、b 的 compute 一定要在 a compute 之后。mut2 operand 是為了處理一些 output shape 在 compute 階段才能確定的 op(如 unique),例如,如果 a 以 mut2 operand 形式持有 c,那么 b 的 infer 和 compute 都需要發(fā)生在 a 的 compute 之后。從oneflow/core/eager/local_call_opkernel_phy_instr_operand.h定義的LocalCallOpKernelPhyInstrOperand指令來(lái)看,它重載了ForEachConstMirroredObject,ForEachMutMirroredObject,ForEachMut2MirroredObject三種方法,分別對(duì)應(yīng)的是const,mut,mut2操作數(shù)。在重載的每個(gè)方法里去調(diào)用傳入的回調(diào)函數(shù)(const std::function)來(lái)構(gòu)建指令間的依賴(lài)關(guān)系,以const為例:
void?LocalCallOpKernelPhyInstrOperand::ForEachConstMirroredObject(
????const?std::function<void(vm::MirroredObject*?compute)>&?DoEach)?const?{
??const?auto&?input_list?=?inputs();
??for?(int64_t?index?:?opkernel().input_tuple_indexes4const_ibns())?{
????const?auto&?input?=?input_list->at(index);
????DoEach(CHECK_JUST(input->compute_local_dep_object())->mut_mirrored_object());
??}
}
for (int64_t index : opkernel().input_tuple_indexes4const_ibns()) 這行代碼用來(lái)遍歷StatefulOpKernel對(duì)象里面的const操作數(shù),得到它在Input Tuple里面的下標(biāo)獲得index,然后根據(jù)index取出這個(gè)下標(biāo)對(duì)應(yīng)的對(duì)應(yīng)的EagerBlobObject對(duì)象。再對(duì)這個(gè)EagerBlobObject上的compute_local_dep_object調(diào)用DoEach這個(gè)回調(diào),相當(dāng)于以const的方式去消費(fèi)這個(gè)compute_local_dep_object。mut和mut2類(lèi)似。
這里還要說(shuō)明一下虛擬機(jī)的指令間依賴(lài)關(guān)系具體是怎么建立的。在oneflow/core/vm/virtual_machine_engine.cpp里面的HandlePending成員函數(shù)里面,ConsumeMirroredObjects這個(gè)函數(shù)中的for (const auto& operand : operands) 針對(duì)每種operand調(diào)用ForEachMutMirroredObject函數(shù),比如對(duì)于mut來(lái)說(shuō):
for?(const?auto&?operand?:?operands)?{
?if?(operand->has_mut_operand())?{
??ForEachMutMirroredObject(interpret_type,?id2logical_object,
?????????????????????????????????????????????operand->mut_operand(),?global_device_id,
?????????????????????????????????????????????ConsumeMutMirroredObject);
?}?...
}
templatetypename ?DoEachT>
void?VirtualMachineEngine::ForEachMutMirroredObject(
const?InterpretType?interpret_type,?Id2LogicalObject*?id2logical_object,
const?ModifiedOperand&?mut_operand,
int64_t?global_device_id,?const?DoEachT&?DoEach) ?{
????const?Operand&?operand?=?mut_operand.operand();
????if?(interpret_type?==?InterpretType::kCompute)?{
????????ForEachMirroredObject<&IdUtil::GetValueId>(id2logical_object,?operand,?global_device_id,
???????????????????????????????????????????DoEach);
????}?else?if?(interpret_type?==?InterpretType::kInfer)?{
?????ForEachMirroredObject<&IdUtil::GetTypeId>(id2logical_object,?operand,?global_device_id,?DoEach);
????}?else?{
????UNIMPLEMENTED();
????}
}
這里的DoEachT就是ConsumeMutMirroredObject,即消費(fèi)MutMirroredObject。繼續(xù)跟進(jìn)ConsumeMutMirroredObject的實(shí)現(xiàn):
const?auto&?ConsumeMirroredObject?=?[&](OperandAccessType?access_type,
????????????????????????????????MirroredObject*?mirrored_object)?{
????auto*?access?=?AccessMirroredObject(access_type,?mirrored_object,?instruction);
????instruction->mut_mirrored_object_id2access()->Insert(access);
????return?access;
};
這里的AccessMirroredObject將這個(gè)指令添加到了會(huì)訪問(wèn)這個(gè)mirrored_object的指令列表里面。
RwMutexedObjectAccess*?VirtualMachineEngine::AccessMirroredObject(OperandAccessType?access_type,
??????????????????????????????????????????????????????MirroredObject*?mirrored_object,
??????????????????????????????????????????????????????Instruction*?instruction)?{
????auto?access?=?access_pool_.make_shared(instruction,?mirrored_object,?access_type);
????auto*?ptr?=?access.Mutable();
????instruction->mut_access_list()->PushBack(ptr);
????mirrored_object->mut_rw_mutexed_object()->mut_access_list()->EmplaceBack(std::move(access));
????return?ptr;
}
RwMutexedObject這里是對(duì)mirrored_object的讀寫(xiě)進(jìn)行加鎖。有了指令的依賴(lài)關(guān)系之后我們就可以構(gòu)造指令邊了,構(gòu)建完指令邊之后虛擬機(jī)就可以執(zhí)行有指令節(jié)點(diǎn)構(gòu)成的一個(gè)Dag。處理Dag的一個(gè)有效方式是拓?fù)渑判?,但在OneFlow的虛擬機(jī)里面是通過(guò)ready_instruction_list和pending_instaruction_list將其做成一個(gè)迭代的方式,即scheduler輪詢(xún)的時(shí)候只需要不斷處理這兩個(gè)list即可。這里再看一下指令邊的構(gòu)建流程,在ConsumeMirroredObjects的這部分:
void?VirtualMachineEngine::TryConnectInstruction(Instruction*?src_instruction,
?????????????????????????????????????Instruction*?dst_instruction)?{
????if?(unlikely(src_instruction?==?dst_instruction))?{?return;?}
????if?(likely(EdgeDispatchable(src_instruction,?dst_instruction)))?{?return;?}
????auto?edge?=?instruction_edge_pool_.make_shared(src_instruction,?dst_instruction);
????src_instruction->mut_out_edges()->PushBack(edge.Mutable());
????dst_instruction->mut_in_edges()->PushBack(edge.Mutable());
}
void?VirtualMachineEngine::ConnectInstructionsByWrite(RwMutexedObjectAccess*?dst_access)?{
????CHECK(dst_access->is_mut_operand());
????auto*?mirrored_object?=?dst_access->mut_mirrored_object();
????auto*?dst_instruction?=?dst_access->mut_instruction();
????auto*?access_list?=?mirrored_object->mut_rw_mutexed_object()->mut_access_list();
????if?(likely(access_list->Begin()?==?dst_access))?{?return;?}
????INTRUSIVE_FOR_EACH_PTR(src_access,?access_list)?{
????if?(unlikely(src_access?==?dst_access))?{?break;?}
????TryConnectInstruction(src_access->mut_instruction(),?dst_instruction);
????CHECK_EQ(src_access->mut_rw_mutexed_object(),?mirrored_object->mut_rw_mutexed_object());
????access_list->Erase(src_access);
}
}
void?VirtualMachineEngine::ConnectInstructionsByRead(RwMutexedObjectAccess*?dst_access)?{
????CHECK(dst_access->is_const_operand());
????auto*?mirrored_object?=?dst_access->mut_mirrored_object();
????auto*?dst_instruction?=?dst_access->mut_instruction();
????auto*?first?=?mirrored_object->mut_rw_mutexed_object()->mut_access_list()->Begin();
????if?(first->is_mut_operand())?{
????TryConnectInstruction(first->mut_instruction(),?dst_instruction);
????}?else?if?(first->is_const_operand())?{
????//?do?nothing
????}?else?{
????UNIMPLEMENTED();
????}
}
if?(likely(phy_instr_operand))?{
//?Connect?instructions?by?write?before?connecting?by?read.
????for?(auto*?mirrored_object?:?phy_instr_operand->output_dependences())?{
????ConnectInstructionsByWrite(
????AccessMirroredObject(kMutableOperandAccess,?mirrored_object,?instruction));
????}
????for?(auto*?mirrored_object?:?phy_instr_operand->input_dependences())?{
????ConnectInstructionsByRead(
????AccessMirroredObject(kConstOperandAccess,?mirrored_object,?instruction));
????}
}
會(huì)去分析兩個(gè)指令的關(guān)系,例如一個(gè)讀一個(gè)寫(xiě),或者兩個(gè)讀或者寫(xiě),來(lái)分別構(gòu)造指令邊,把兩個(gè)指令連在一起。
因此,虛擬機(jī)的指令依賴(lài)關(guān)系并不是虛擬機(jī)內(nèi)嵌的,而是通過(guò)消費(fèi)指令的操作數(shù)實(shí)現(xiàn)出來(lái)的,并且除了消費(fèi)操作數(shù)構(gòu)造指令依賴(lài)關(guān)系,還可以消費(fèi)device。以LocalCallOpKernelPhyInstrOperand指令的mut操作數(shù)為例,這里會(huì)拿到StatefulOpKernel對(duì)應(yīng)的device,比如cuda,然后每個(gè)device方法上也有一個(gè)local_dep_object成員,每個(gè)指令都以mut形式來(lái)消費(fèi)device上的local_dep_object,這樣就實(shí)現(xiàn)了比如前后兩個(gè)指令都在同一個(gè)device上執(zhí)行,那么這兩個(gè)指令的執(zhí)行順序一定是需要按照發(fā)射時(shí)的順序進(jìn)行執(zhí)行的這種依賴(lài)關(guān)系,因?yàn)樗鼈兌家詍ut的方式消費(fèi)了同一個(gè)local_dep_object。
?0x4. VM和Interpreter的整體調(diào)用鏈這里的
local_dep_object是專(zhuān)門(mén)用來(lái)幫助虛擬機(jī)構(gòu)建指令邊的一個(gè) 對(duì)象。這個(gè)對(duì)象被EagerBlobObject,Device持有,然后按先后順序消費(fèi)它就建立了指令之間的聯(lián)系。
虛擬機(jī)的基礎(chǔ)知識(shí)就點(diǎn)到為止了,因?yàn)槲业睦斫饽壳耙彩钟邢?。這一節(jié)再宏觀的梳理一下Interpter和虛擬機(jī)的調(diào)用鏈。首先,Python層調(diào)用OneFlow的Op會(huì)發(fā)經(jīng)過(guò)Interpreter去構(gòu)建虛擬機(jī)的指令并執(zhí)行。以ReLU為例,在Interpreter的最后一步是:
JUST(PhysicalRun([&](InstructionsBuilder*?builder)?->?Maybe<void>?{
????return?builder->LocalCallOpKernel(kernel,?input_eager_blob_objects,?output_eager_blob_objects,
??????????????????????????????????????ctx,?op_device);
??}));
然后跟進(jìn)LocalCallOpKernel的實(shí)現(xiàn):
Maybe<void>?InstructionsBuilder::LocalCallOpKernel(
????const?std::shared_ptr&?opkernel,
????const?one::EagerBlobObjectListPtr&?input_eager_blob_objects,
????const?one::EagerBlobObjectListPtr&?output_eager_blob_objects,
????const?std::shared_ptr<const?one::ConsistentTensorInferResult>&?consistent_tensor_infer_result,
????const?one::OpExprInterpContext&?ctx,?Symbol?op_device) ?{
??const?auto&?parallel_desc_sym?=?JUST(Placement4Device(op_device)).shared_from_symbol();
??for?(const?auto&?input?:?*input_eager_blob_objects)?{
????const?auto&?blob_last_used_device?=?JUST(input->last_used_device());
????if?(blob_last_used_device?!=?op_device)?{
??????auto*?dep_object?=?JUST(input->compute_local_dep_object());
??????JUST(SoftSyncStream(dep_object,?"mut",?blob_last_used_device));
????}
????input->set_last_used_device(op_device);
??}
??auto?phy_instr_operand?=?JUST(vm::LocalCallOpKernelPhyInstrOperand::New(
??????opkernel,?input_eager_blob_objects,?output_eager_blob_objects,?consistent_tensor_infer_result,
??????ctx,?*one::CurrentDevVmDepObjectConsumeMode()));
??auto?instruction?=?intrusive::make_shared(
??????Global::Get()->mut_vm(),?JUST(op_device->local_call_instruction_name()),
??????parallel_desc_sym,?phy_instr_operand);?
??instruction_list_->EmplaceBack(std::move(instruction));
??for?(const?auto&?output?:?*output_eager_blob_objects)?{
????if?(!output->producer_op_device().has_value())?{
??????JUST(output->init_producer_op_device(op_device));
????}
????output->set_last_used_device(op_device);
??}
??return?Maybe<void>::Ok();
}
auto instruction = intrusive::make_shared這句代碼,構(gòu)建了一條新的指令給它綁定了一個(gè)parallel_desc,表示在哪些設(shè)備上執(zhí)行(例如只在 0 號(hào)卡上執(zhí)行,或在所有的卡上執(zhí)行)和一個(gè) StreamType,表示指令在哪種 stream 上執(zhí)行。而這句代碼上面的 auto phy_instr_operand = JUST(vm::LocalCallOpKernelPhyInstrOperand::New...是用來(lái)將指令和操作數(shù)進(jìn)行綁定的。現(xiàn)在指令有了,接下來(lái)就應(yīng)該和VM進(jìn)行交互基于這些新建的指令構(gòu)建指令邊并執(zhí)行了,這個(gè)交互的接口是PhysicalInterpreter::Run(從PhysicalRun跳進(jìn)去)。
Maybe<void>?PhysicalInterpreter::Run(
????const?std::functionvoid >(InstructionsBuilder*)>&?Build)?{
??InstructionsBuilder?instructions_builder(mut_id_generator(),?mut_instruction_list(),
???????????????????????????????????????????mut_eager_symbol_list());
??JUST(Build(&instructions_builder));
??if?(instructions_builder.instruction_list().empty())?{
????CHECK(instructions_builder.eager_symbol_list().eager_symbol().empty());
????return?Maybe<void>::Ok();
??}
??return?Global::Get()->RunPhysicalInstruction(
??????instructions_builder.mut_instruction_list(),?instructions_builder.eager_symbol_list());
}
跳到RunPhysicalInstruction的定義,在oneflow/core/eager/eager_oneflow.cpp:
Maybe<void>?EagerOneflow::RunPhysicalInstruction(
????vm::InstructionMsgList*?instruction_list,
????const?vm::cfg::EagerSymbolList&?cfg_eager_symbol_list)?{
??vm::EagerSymbolList?eager_symbol_list;
??cfg_eager_symbol_list.ToProto(&eager_symbol_list);
??return?RunPhysicalInstruction(instruction_list,?eager_symbol_list);
}
它的入?yún)⒕褪俏覀儤?gòu)造指令那個(gè)地方定義的全局InstructionsBuilder對(duì)象的mut_instruction_list和eager_symbol_list(是虛擬機(jī)里面的對(duì)象)。再跳轉(zhuǎn)一下RunPhysicalInstruction(instruction_list, eager_symbol_list)可以看到如下定義:
Maybe<void>?EagerOneflow::RunPhysicalInstruction(vm::InstructionMsgList*?instruction_list,
?????????????????????????????????????????????????const?vm::EagerSymbolList&?eager_symbol_list)?{
??for?(const?auto&?eager_symbol?:?eager_symbol_list.eager_symbol())?{
????JUST(StorageAdd(eager_symbol));
??}
??return?vm::Run(instruction_list);
}
Maybe<void>?Run(vm::InstructionMsgList*?instr_msg_list)?{
??auto*?virtual_machine?=?JUST(GlobalMaybe());
??JUST(virtual_machine->Receive(instr_msg_list));
??return?Maybe<void>::Ok();
}
這里的virtual_machine->Receive(instr_msg_list)就可以獲取剛才構(gòu)建的指令了。
Maybe<bool>?VirtualMachineEngine::Receive(
????intrusive::shared_ptr&&?compute_instr_msg) ?{
??InstructionMsgList?instr_msg_list;
??instr_msg_list.EmplaceBack(std::move(compute_instr_msg));
??return?Receive(&instr_msg_list);
}
獲取到指令之后就可以在VM的Scheduler線程進(jìn)行輪詢(xún)的時(shí)候處理這些指令了,即oneflow/core/vm/virtual_machine_engine.cpp這里的VirtualMachineEngine::Schedule函數(shù):
void?VirtualMachineEngine::Schedule()?{
??//?Release?finished?instructions?and?try?to?schedule?out?instructions?in?DAG?onto?ready?list.
??if?(unlikely(mut_active_stream_list()->size()))?{?ReleaseFinishedInstructions();?}
??//?TODO(lixinqi):?remove?this?line?after?disabling?vm?single-client?support.
??if?(unlikely(mut_delete_logical_object_list()->size()))?{?TryDeleteLogicalObjects();?}
??//?Try?run?the?first?barrier?instruction.
??if?(unlikely(mut_barrier_instruction_list()->size()))?{?TryRunBarrierInstruction();?}
??//?Handle?pending?instructions,?and?try?schedule?them?to?ready?list.
??//?Use?thread_unsafe_size?to?avoid?acquiring?mutex?lock.
??//?The?inconsistency?between?pending_msg_list.list_head_.list_head_.container_?and
??//?pending_msg_list.list_head_.list_head_.size_?is?not?a?fatal?error?because
??//?VirtualMachineEngine::Schedule?is?always?in?a?buzy?loop.?All?instructions?will?get?handled
??//?eventually.
??//??VirtualMachineEngine::Receive?may?be?less?effiencient?if?the?thread?safe?version
??//??`pending_msg_list().size()`?used?here,?because?VirtualMachineEngine::Schedule?is?more?likely
??//??to?get?the?mutex?lock.
??if?(unlikely(pending_msg_list().thread_unsafe_size()))?{?HandlePending();?}
??//?dispatch?ready?instructions?and?try?to?schedule?out?instructions?in?DAG?onto?ready?list.
??if?(unlikely(mut_ready_instruction_list()->size()))?{?DispatchAndPrescheduleInstructions();?}
}
Schedule函數(shù)在不斷的輪詢(xún),整體功能大概可以分為接受main線程發(fā)出的指令,輪詢(xún)指令的完成情況,處理阻塞指令以及Dispatch已經(jīng)就緒的指令。實(shí)際上當(dāng)我們點(diǎn)進(jìn)HandlePending可以發(fā)現(xiàn),它正是在消費(fèi)我們的local_dep_opbject進(jìn)行指令的構(gòu)建和指令邊鏈接,和上面分析的過(guò)程也對(duì)應(yīng)上了。
關(guān)于Interpreter和VM我大概就梳理到這里,實(shí)際上里面的細(xì)節(jié)比我想象的復(fù)雜很多,我對(duì)OneFlow的整體知識(shí)欠缺得還很多,所以目前理解也比較初級(jí)請(qǐng)見(jiàn)諒。最后再放一張某個(gè)網(wǎng)絡(luò)訓(xùn)練時(shí)生成的nsys圖:
可以看到虛擬機(jī)正在工作,scheduler線程正在分發(fā)就緒的指令并且launch Adam的cuda kernel執(zhí)行參數(shù)更新0x6. 總結(jié)這篇文章以oneflow.relu這個(gè)op為例,介紹了要執(zhí)行這個(gè)Op需要依賴(lài)的Interpreter和VM機(jī)制,對(duì)想了解OneFlow Eager執(zhí)行機(jī)制的同事以及用戶(hù)希望有一點(diǎn)幫助。
- 設(shè)計(jì)模式之橋接模式:https://segmentfault.com/a/1190000041225650
- https://github.com/Oneflow-Inc/oneflow
