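How can a deep learning framework align operators elegantly?
0x0. Preface
I previously answered the Zhihu question "How to contribute to PyTorch"; the original post is at https://www.zhihu.com/question/502301777/answer/2248950419 . That answer mentioned that last year, while developing operators for OneFlow, we used the operator AutoTest framework to find several bugs in PyTorch operators and reported or fixed them upstream. The answer did not show what the AutoTest framework looks like, nor did it explain the principles behind it. This article therefore introduces OneFlow's operator AutoTest framework and shows how OneFlow handles the operator alignment task elegantly during operator development. The framework was originally developed by @大缺弦 and was later extended with more features by me and other colleagues into its current form. It can also be ported to other deep learning training frameworks with little effort. The code lives at https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py .
0x1. Traditional approaches to operator alignment
This is not specific to OneFlow: any deep learning training framework, whether written by an organization or an individual, has to verify that its operators are implemented correctly. So how is operator correctness usually verified? Take Baidu's PaddlePaddle as an example. Its tests generally compare against results obtained from another reference library (the convolution operator is checked against cudnn's convolution, and the erf operator against scipy's erf), or against a result simulated directly with numpy (the full operator is checked this way). PyTorch's tests also hard-code some test samples: the known answer for a fixed input is compared with the operator's output to decide whether the implementation is correct.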
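As a minimal sketch of the reference-based style described above, the following compares a hand-written operator with a trusted reference on a fixed list of inputs. The names erf_series and check_against_reference are hypothetical and are not taken from any of the frameworks mentioned; math.erf stands in for the reference library.

```python
import math


def erf_series(x, terms=60):
    """Hypothetical operator under test: erf computed from its Maclaurin series."""
    total = 0.0
    for n in range(terms):
        total += (-1) ** n * x ** (2 * n + 1) / (math.factorial(n) * (2 * n + 1))
    return 2.0 / math.sqrt(math.pi) * total


def check_against_reference(op, reference, inputs, rtol=1e-9, atol=1e-12):
    """Return the inputs on which op and reference disagree beyond the tolerances."""
    return [
        x for x in inputs
        if not math.isclose(op(x), reference(x), rel_tol=rtol, abs_tol=atol)
    ]


# Hard-coded sample inputs, in the spirit of a hand-written unit test.
inputs = [-2.0, -0.5, 0.0, 0.3, 1.0, 2.0]
mismatches = check_against_reference(erf_series, math.erf, inputs)
print(mismatches)
```

The weakness is visible in the inputs list: someone has to choose it by hand, and any corner case that is not listed is never exercised.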
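There is nothing wrong with these methods, but writing such tests takes a fair amount of manual effort, and in the early stage of operator development some corner cases are easy to overlook. Take OneFlow as an example. Its operators are meant to behave like PyTorch's, so if we want to verify the transposed convolution Op under all kinds of configurations, what test code would verify it thoroughly? One option is to enumerate every parameter: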
import torch
import numpy as np
import oneflow as flow
for N in range(1, 5):
    for C_in in range(1, 10):
        for L_in in range(1, 10):
            for H_in in range(1, 10):
                for C_out in range(1, 10):
                    for Ksize in range(1, 10):
                        for Pad in range(1, 10):
                            for Dilation in range(1, 10):
                                for Stride in range(1, min(L_in, H_in)):
                                    for OutPad in range(1, min(Dilation, Stride)):
                                        try:
                                            torch_input = torch.randn(N, C_in, L_in, H_in)
                                            flow_input = flow.tensor(torch_input.numpy())
                                            torch_input.requires_grad = True
                                            flow_input.requires_grad = True
                                            torch_m = torch.nn.ConvTranspose2d(in_channels=C_in, out_channels=C_out, kernel_size=Ksize, padding=Pad, stride=Stride,
                                                output_padding=(OutPad), dilation=Dilation, bias=False)
                                            flow_m = flow.nn.ConvTranspose2d(in_channels=C_in, out_channels=C_out, kernel_size=Ksize, padding=Pad, stride=Stride,
                                                output_padding=(OutPad), dilation=Dilation, bias=False)
                                            flow_m.weight.data = flow.tensor(torch_m.weight.data.detach().numpy(), requires_grad=True)
                                            torch_out = torch_m(torch_input)
                                            flow_out = flow_m(flow_input)
                                            torch_out = torch_out.sum()
                                            flow_out = flow_out.sum()
                                            assert(np.allclose(torch_out.detach().numpy(), flow_out.detach().numpy(), 1e-06, 1e-06)), "forward not equal"
                                            torch_out.backward()
                                            flow_out.backward()
                                            print(torch_input.grad.detach().numpy())
                                            print(flow_input.grad.detach()[:N, :C_in, :L_in, :H_in].numpy())
                                            assert(np.allclose(torch_input.grad.detach().numpy(), flow_input.grad.detach()[:N, :C_in, :L_in, :H_in].numpy(), 1e-03, 1e-03)), "backward not equal"
                                        except Exception as e:
                                            print('Input Param Error')
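This approach verifies the operator fairly thoroughly, but it has drawbacks of its own. First, how should the upper bounds of the enumeration be chosen? With large bounds the verification takes a very long time, which makes it unsuitable for a CI pipeline. With small bounds some corner cases may be missed, so the test is still not comprehensive and the risk of a bug in the operator goes up.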
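To make the cost concrete, the sketch below counts the combinations produced by the eight outer loops of the enumeration and contrasts that with drawing a fixed number of random configurations from the same ranges. The helper sample_case and the budget of 20 cases are illustrative assumptions; Stride and OutPad are left out because their ranges depend on the other parameters.

```python
import random

# Number of values per outer loop: N has 4 values, the next seven parameters have 9 each.
grid_sizes = [4] + [9] * 7
total = 1
for size in grid_sizes:
    total *= size
print(total)  # 19131876 combinations before Stride and OutPad are even counted


def sample_case(rng):
    """Draw one random configuration from the same ranges as the loops."""
    names = ["C_in", "L_in", "H_in", "C_out", "Ksize", "Pad", "Dilation"]
    case = {"N": rng.randrange(1, 5)}
    for name in names:
        case[name] = rng.randrange(1, 10)
    return case


rng = random.Random(0)  # fixed seed so a failing case can be reproduced
cases = [sample_case(rng) for _ in range(20)]
print(cases[0])
```

A fixed sampling budget keeps the running time bounded no matter how wide the parameter ranges are, at the price of covering the space only probabilistically.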
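To address these problems with operator testing, my colleague @大缺弦 developed an operator AutoTest framework that solves the problem of aligning OneFlow operators with PyTorch operators. I later added more features on top of it. It now works quite well in practice, so the rest of this article gives a full introduction.
The whole AutoTest framework consists of only 2 Python files: https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py and https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/generators.py . The framework can also be ported easily to any other deep learning framework for operator alignment.
0x2. Using the operator AutoTest framework
Before explaining how it works, let's look at how the AutoTest framework is used. Taking the transposed convolution operator above as an example, with the AutoTest framework the alignment test can be written as follows: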
@autotest()
def test_deconv2d_with_random_data(test_case):
    channels = random(1, 6)
    m = torch.nn.ConvTranspose2d(
        in_channels=channels,
        out_channels=random(1, 20),
        kernel_size=random(1, 4),
        stride=random() | nothing(),
        padding=random(1, 3).to(int) | nothing(),
        dilation=random(1, 5) | nothing(),
        groups=random(1, 5) | nothing(),
        padding_mode=constant("zeros") | nothing(),
    )
    m.train(random())
    device = random_device()
    m.to(device)
    x = random_pytorch_tensor(ndim=4, dim1=channels).to(device)
    y = m(x)
    return y
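Readers familiar with PyTorch will notice that this test reads almost exactly like ordinary PyTorch code. Indeed, the AutoTest framework acts as a high level PyTorch: its interface is the same as PyTorch's, but for a given input it runs the program once with OneFlow and once with PyTorch, records every tensor produced along the way together with the corresponding gradient tensor, and then checks that the tensors produced by OneFlow and PyTorch agree in value and shape. That completes the automatic test. The details are covered later.
Here is another example that tests the matmul operator: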
@autotest()
def test_flow_matmul_with_random_data(test_case):
    k = random(1, 6)
    x = random_pytorch_tensor(ndim=2, dim1=k)
    y = random_pytorch_tensor(ndim=2, dim0=k)
    z = torch.matmul(x, y)
    return z
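Using random_pytorch_tensor we build two random tensors x and y with shapes [m, k] and [k, n]; the values of these dimensions are generated randomly.
When these two tests run, the framework randomly generates Ops from all kinds of legal parameter combinations, runs the PyTorch code and the OneFlow code on input Tensors with identical values and types (one copy for PyTorch and one for OneFlow), and completes the automatic test of the operator. Because the usage of the framework mirrors PyTorch, writing test cases after developing an operator becomes very simple. There is no longer any need to pull in another reference library or to simulate the operator's forward and backward computation with Numpy, which frees up a lot of developer time.
Moreover, as long as the test is run enough times, it will with high probability hit samples on which the OneFlow operator and the PyTorch operator fail to align. If the corresponding reproduction sample can be obtained at that point, it helps us determine whether the OneFlow implementation has a problem.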
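The "enough runs" argument can be made quantitative with a small calculation. Assume, purely for illustration, that a fraction p of the random configurations triggers a mismatch and that draws are independent; then n runs hit at least one failing configuration with probability 1 - (1 - p)^n. The value p = 0.05 below is an invented number, not a measurement.

```python
def hit_probability(p, n):
    """Probability that at least one of n independent random cases exposes a bug
    that a fraction p of all configurations would trigger."""
    return 1.0 - (1.0 - p) ** n


# n = 20 matches the default number of runs of the autotest() decorator.
for n in (20, 100, 500):
    print(n, round(hit_probability(0.05, n), 4))
```

With these assumed numbers a single CI run of 20 cases finds the bug roughly two times out of three, and repeated CI runs push the probability close to one.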
0x3. How the operator AutoTest framework is implemented
Now that we know how to use the AutoTest framework, let's go through its implementation. From the usage above you can guess that the implementation has two parts: one generates random data, and the other runs the AutoTest program while recording and comparing the shapes and values of the intermediate tensors and their gradient tensors.
0x3.1 How is random data generated?
Random data here means more than random input tensors. It also covers the attribute parameters of an Op. For example, kernel_size=random(1, 4) in the transposed convolution test above specifies that kernel_size takes a value in the half-open interval [1, 4).
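The half-open interval is easy to check with a standard-library stand-in. The sketch below uses random.randrange, which also excludes its upper bound; it only illustrates the interval semantics and is not the framework's own sampling code.

```python
import random

rng = random.Random(0)
# Collect the distinct values that a [1, 4) integer range can produce.
samples = {rng.randrange(1, 4) for _ in range(1000)}
print(sorted(samples))
```

For an integer parameter such as kernel_size, only 1, 2 and 3 can occur; 4 is never generated.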
This part is implemented in https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/generators.py . First, let's see which interfaces the file exports:
__all__ = [
    "random_tensor",
    "random_bool",
    "random_device",
    "random",
    "random_or_nothing",
    "oneof",
    "constant",
    "nothing"
]
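These interfaces are classes derived from the generator base class (random_or_nothing, shown further below, is a small helper function built on top of them) and are used to produce random data structures. A data structure here can be a built-in type such as int or a custom type such as tensor. All randomness in the parameters of the AutoTest framework comes from these interfaces. Here is the implementation of the generator base class: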
class generator:
    def __init__(self, children):
        self.children = children
        self._value = None
    def _init(self):
        self._value = None
        for x in self.children:
            x._init()
    def eval(self):
        self._init()
        return self.value()
    def _calc_value(self):
        raise NotImplementedError()
    def value(self):
        if self._value is None:
            self._value = self._calc_value()
        return self._value
    def size(self):
        return 1
    def __or__(self, other):
        other = pack(other)
        return oneof(
            self, other, possibility=self.size() / (self.size() + other.size())
        )
    def __ror__(self, other):
        return self | other
    def __add__(self, other):
        return add(self, other)
    def __radd__(self, other):
        return self + other
    def __sub__(self, other):
        return self + neg(other)
    def __rsub__(self, other):
        return neg(self - other)
    def __mul__(self, other):
        return mul(self, other)
    def __rmul__(self, other):
        return self * other
    def to(self, annotation):
        self._to(annotation)
        for x in self.children:
            x.to(annotation)
        return self
    def _to(self, annotation):
        pass
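Besides the value-related functions _calc_value, value and eval, this class has a size function that reports how many alternatives the generator represents. It also defines a series of magic methods so that different generator subclasses can be combined with each other, which makes test programs more flexible to write. Finally, there is a to member function. It calls the _to hook, which subclasses override, and propagates to the children in order to fix the value type of the random data structure.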
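The caching in value() and the reset in eval() are what allow one generator, such as channels in the earlier test, to be used in two places and still yield the same number within a single run. The following is a stripped-down sketch of that mechanism. Gen, RandInt, Const and OneOf are simplified stand-ins invented for this illustration; they are not the classes from generators.py, and the weighting of alternatives by size() is omitted.

```python
import random as py_random

_rng = py_random.Random(0)


class Gen:
    """Simplified stand-in for the generator base class."""

    def __init__(self, children=()):
        self.children = list(children)
        self._value = None

    def _init(self):
        # Forget the cached sample of this node and of everything below it.
        self._value = None
        for child in self.children:
            child._init()

    def eval(self):
        self._init()
        return self.value()

    def value(self):
        # Sample lazily, then keep returning the same sample until reset.
        if self._value is None:
            self._value = self._calc_value()
        return self._value

    def _calc_value(self):
        raise NotImplementedError

    def __or__(self, other):
        return OneOf(self, other)


class RandInt(Gen):
    def __init__(self, low, high):
        super().__init__()
        self.low, self.high = low, high

    def _calc_value(self):
        return _rng.randrange(self.low, self.high)


class Const(Gen):
    def __init__(self, payload):
        super().__init__()
        self.payload = payload

    def _calc_value(self):
        return self.payload


class OneOf(Gen):
    def __init__(self, *options):
        super().__init__(options)

    def _calc_value(self):
        return _rng.choice(self.children).value()


channels = RandInt(1, 6)
a = channels.value()
b = channels.value()  # cached: both uses see the same sample
c = channels.eval()   # eval() clears the cache and draws again
mixed = (RandInt(1, 6) | Const("zeros")).eval()
print(a, b, c, mixed)
```

Building the choice through the | operator keeps test code declarative: the test states which alternatives are legal, and the sampling happens only when a value is requested.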
Every generator subclass inherits from the generator base class and overrides member functions such as __init__, _calc_value, size and _to. For example, the nothing subclass simply overrides _calc_value and returns an instance of an empty class that does nothing.
class Nothing:
    pass
class nothing(generator):
    def __init__(self):
        super().__init__([])
    def _calc_value(self):
        return Nothing()
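One plausible way to consume such a sentinel, shown here only as an assumption about the idea rather than as the framework's actual get_args code, is to drop every argument whose generated value is a Nothing instance, so that the callee falls back to its own default. The functions drop_nothing and conv are hypothetical.

```python
class Nothing:
    """Sentinel meaning: do not pass this argument at all."""


def drop_nothing(kwargs):
    """Remove the arguments whose generated value is the Nothing sentinel."""
    return {k: v for k, v in kwargs.items() if not isinstance(v, Nothing)}


def conv(kernel_size, stride=1, padding=0):
    """Hypothetical operator constructor with default arguments."""
    return (kernel_size, stride, padding)


generated = {"kernel_size": 3, "stride": Nothing(), "padding": 2}
result = conv(**drop_nothing(generated))
print(result)  # stride was omitted, so its default of 1 is used
```

This is why writing random(1, 5) | nothing() in a test covers both the explicit-value path and the default-value path of an operator.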
As another example, the random subclass of generator is defined as follows:
class random(generator):
    def __init__(self, low=1, high=6):
        self.low = pack(low)
        self.high = pack(high)
        super().__init__([self.low, self.high])
        self.annotation = None
    def _to(self, annotation):
        if self.annotation is not None:
            return
        if hasattr(annotation, "__origin__"):
            # PyTorch _size_2_t and similar types are defined by type variables,
            # leading to unexpected __args__ and __origin__
            #
            # >>> _size_2_t = Union[T, Tuple[T, T]][int]
            # >>> _size_2_t.__origin__
            # typing.Union[~T, typing.Tuple[~T, ~T]]
            #
            # So recreate a new annotation object by repr and eval
            #
            # >>> _size_2_t
            # typing.Union[int, typing.Tuple[int, int]]
            # >>> _size_2_t_new = eval(repr(annotation))
            # >>> _size_2_t_new.__origin__
            # typing.Union
            annotation = eval(repr(annotation))
        self.annotation = annotation
    def _generate(self, annotation):
        if hasattr(annotation, "__origin__"):
            if annotation.__origin__ is Union:
                x = random_util.choice(annotation.__args__)
                return self._generate(x)
            if annotation.__origin__ is Tuple or annotation.__origin__ is py_tuple:
                return [self._generate(x) for x in annotation.__args__]
            else:
                raise NotImplementedError(
                    f"Not implemented annotation {annotation} in random, type(annotation.__origin__) is {type(annotation.__origin__)}"
                )
        low, high = self.low.value(), self.high.value()
        if annotation == int:
            val = int(rng.integers(low, high))
        elif annotation == float:
            val = float(rng.random() * (high - low) + low)
        elif annotation == bool:
            val = random_util.choice([True, False])
        else:
            raise NotImplementedError(
                f"Not implemented annotation {annotation} in random"
            )
        return val
    def _calc_value(self):
        return self._generate(self.annotation)
def random_or_nothing(low, high):
    return oneof(random(low, high), nothing(), possibility=2 / 3)
One thing to note: a generator subclass that holds an annotation attribute can have it updated through to (as the random class does), or it can ignore the annotation and construct a random result of the appropriate type directly in _calc_value (as the random_device class does).
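The annotation-driven dispatch in _generate can be illustrated with the public typing helpers. This sketch assumes Python 3.8 or later and uses typing.get_origin and typing.get_args instead of reading __origin__ and __args__ directly as the original does. The function generate and the alias size_2_t are hypothetical names for this example.

```python
import random as py_random
from typing import Tuple, Union, get_args, get_origin

_rng = py_random.Random(0)


def generate(annotation, low=1, high=6):
    """Produce a random value whose structure follows the type annotation."""
    origin = get_origin(annotation)
    if origin is Union:
        # Pick one member of the union, then generate a value of that type.
        return generate(_rng.choice(get_args(annotation)), low, high)
    if origin is tuple:
        # One independently generated element per tuple position.
        return [generate(arg, low, high) for arg in get_args(annotation)]
    if annotation is int:
        return _rng.randrange(low, high)
    if annotation is float:
        return _rng.random() * (high - low) + low
    if annotation is bool:
        return _rng.choice([True, False])
    raise NotImplementedError(f"Not implemented annotation {annotation}")


size_2_t = Union[int, Tuple[int, int]]  # same shape as PyTorch's _size_2_t
values = [generate(size_2_t) for _ in range(10)]
print(values)
```

With an annotation such as size_2_t, a single random(1, 4) can therefore yield either a scalar kernel size or a pair of sizes, which is how one test exercises both calling conventions.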
0x3.2 Core implementation of AutoTest
The core of the AutoTest framework is implemented in https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py . The last 2 lines of this file are:
torch = GetDualObject("", torch_original, flow)
__all__ = ["autotest", "random_pytorch_tensor"]
In the line torch = GetDualObject("", torch_original, flow), torch_original is the original PyTorch framework, while the torch returned by GetDualObject wraps both the original PyTorch and OneFlow into a high level PyTorch. The key piece of the implementation is therefore the GetDualObject function. For now, let's ignore what the function does internally and focus on what it returns. The code shows that it returns an object of the DualObject class, so let's study that class first:
class DualObject:
    def __init__(self, name, pytorch, oneflow):
        self.name = name
        self.pytorch = pytorch
        self.oneflow = oneflow
        if isinstance(pytorch, torch_original.nn.Module):
            state_dict = pytorch.state_dict()
            state_dict = {k: v.detach().cpu().numpy() for (k, v) in state_dict.items()}
            oneflow.load_state_dict(state_dict, strict=False)
            if testing:
                dual_modules_to_test.append(self)
        if isinstance(pytorch, torch_original.Tensor):
            if testing:
                dual_objects_to_test.append(self)
    def __repr__(self):
        return f"PyTorch object:\n{self.pytorch}\n\nOneFlow object:\n{self.oneflow}"
    def __getattr__(self, key):
        pytorch_attr = getattr(self.pytorch, key)
        oneflow_attr = getattr(self.oneflow, key)
        new_name = f"{self.name}.{key}"
        global call_pytorch
        call_pytorch = self.pytorch
        return GetDualObject(new_name, pytorch_attr, oneflow_attr)
__init__ receives the object's name and a pair of pytorch/oneflow objects. When the high level PyTorch is exported, the pair is torch_original and flow; when the random_pytorch_tensor interface builds its result, the pair is pytorch_tensor and flow_tensor. Let's look at the implementation of random_pytorch_tensor first:
def random_pytorch_tensor(
    ndim=None,
    dim0=1,
    dim1=None,
    dim2=None,
    dim3=None,
    dim4=None,
    low=0,
    high=1,
    dtype=float,
    requires_grad=True,
):
    if isinstance(requires_grad, generator):
        requires_grad = requires_grad.value()
    pytorch_tensor = (
        random_tensor(ndim, dim0, dim1, dim2, dim3, dim4, low, high, dtype)
        .value()
        .requires_grad_(requires_grad and dtype != int)
    )
    flow_tensor = flow.tensor(
        pytorch_tensor.detach().cpu().numpy(),
        requires_grad=(requires_grad and dtype != int),
    )
    return GetDualObject("unused", pytorch_tensor, flow_tensor)
As you can see, just like the export of the high level PyTorch, it obtains an object by calling GetDualObject. Going back to the DualObject class, note that it uses two lists, dual_modules_to_test and dual_objects_to_test, to record the paired OneFlow and PyTorch nn.Module objects and tensor objects respectively. DualObject also overrides the __getattr__ magic method. Taking Flatten as an example, let's see which attributes this method fetches while an AutoTest program runs:
    def __getattr__(self, key):
        pytorch_attr = getattr(self.pytorch, key)
        oneflow_attr = getattr(self.oneflow, key)
        print(key)
        # print(pytorch_attr)
        # print(oneflow_attr)
        new_name = f"{self.name}.{key}"
        return GetDualObject(new_name, pytorch_attr, oneflow_attr)
# AutoTest program for flatten
@autotest(auto_backward=False)
def test_against_pytorch(test_case):
    m = torch.nn.Flatten(
        start_dim=random(1, 6) | nothing(), end_dim=random(1, 6) | nothing()
    )
    m.train(random())
    device = random_device()
    m.to(device)
    x = random_pytorch_tensor().to(device)
    y = m(x)
    return y
Here are the keys printed from __getattr__:
nn
Flatten
train
to
to
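This shows that every attribute access on an nn.Module or other function in a test program decorated with autotest(), on the PyTorch side and on the OneFlow side alike, goes through this overridden method. It looks up the requested attribute on both objects and, again through GetDualObject, returns a new DualObject. We can print the DualObject that corresponds to the Flatten nn.Module: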
PyTorch object:
1, end_dim=-1)>
OneFlow object:
1, end_dim=-1)>
GetDualObject builds a DualObject from the PyTorch object and the OneFlow object passed in, together with their name. For the high level PyTorch it overrides the __call__ magic method of the original PyTorch and OneFlow objects and finally returns a DualObject. Along the way it also skips some magic methods that do not need attention, checks that the attributes of the incoming objects are legal, and, based on the types of the default arguments of nn.Module and other APIs, binds a concrete type to the random data produced by generator subclasses (this is done in the get_args function). There is also a special case for Tensor methods, because they are invoked differently (through getattr) from Modules and other functions (through __call__).
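The pairing idea can be reduced to a few lines of plain Python. In the sketch below, Dual is a hypothetical, heavily simplified stand-in for DualObject and GetDualObject combined: attribute lookups are forwarded to both backends, calls are executed on both, and every pair of results is recorded for later comparison. The standard math module plays the reference backend and a SimpleNamespace plays the implementation under test; neither choice comes from the original code.

```python
import math
import types

records = []  # every pair of results produced while the test body runs


class Dual:
    """Pairs one object from a reference backend with one from the backend under test."""

    def __init__(self, name, ref, impl):
        self.name, self.ref, self.impl = name, ref, impl

    def __getattr__(self, key):
        # Reached only for names that Dual itself lacks: look them up on both sides.
        return Dual(f"{self.name}.{key}", getattr(self.ref, key), getattr(self.impl, key))

    def __call__(self, *args):
        # Unwrap Dual arguments so that each backend sees only its own values.
        ref_args = [a.ref if isinstance(a, Dual) else a for a in args]
        impl_args = [a.impl if isinstance(a, Dual) else a for a in args]
        out = Dual(f"{self.name}()", self.ref(*ref_args), self.impl(*impl_args))
        records.append(out)
        return out


candidate = types.SimpleNamespace(sqrt=lambda x: x ** 0.5)
lib = Dual("lib", math, candidate)

y = lib.sqrt(lib.sqrt(16.0))  # reads like single-backend code, runs on both
for r in records:
    print(r.name, r.ref, r.impl, math.isclose(r.ref, r.impl))
```

Because intermediate results are Dual objects as well, chained calls keep both computations in lockstep, which is what allows every intermediate value to be compared rather than only the final output.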
That is roughly how GetDualObject works. The code is fairly long, so it is not reproduced here; if you are interested, see https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py#L195-L401 .
Finally, let's look at the implementation of the autotest() decorator:
def autotest(
    n=20,
    auto_backward=True,
    rtol=0.0001,
    atol=1e-05,
    check_graph=True,
    check_allclose=True,
):
    verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
    def deco(f):
        @functools.wraps(f)
        def new_f(test_case):
            nonlocal n
            loop_limit = n * 20
            loop = 0
            while n > 0:
                clear_note_fake_program()
                if loop > loop_limit:
                    raise ValueError("autotest stuck in an endless loop!")
                dual_modules_to_test.clear()
                dual_objects_to_test.clear()
                try:
                    global testing
                    testing = True
                    global testing_graph
                    if check_graph:
                        testing_graph = True
                    res = f(test_case)
                    testing = False
                    testing_graph = False
                except (PyTorchDoesNotSupportError, BothDoNotSupportError) as e:
                    if verbose:
                        print(f"{f.__name__}")
                        print(e)
                    loop += 1
                    continue
                if res is not None:
                    if not isinstance(res, collections.abc.Sequence):
                        res = [res]
                    func_outputs = res
                    for x in res:
                        if auto_backward:
                            if isinstance(x.pytorch, torch_original.Tensor):
                                call_tensor_id.append(id(x.pytorch))
                                x.sum().backward()
                        dual_objects_to_test.append(x)
                for x in dual_modules_to_test:
                    for key in x.pytorch.state_dict().keys():
                        if key not in x.oneflow.state_dict().keys():
                            warnings.warn(f"oneflow module don't have `{key}`")
                            continue
                        vis_parameters[key] = x.pytorch.state_dict()[key]
                        dual_objects_to_test.append(
                            GetDualObject(
                                "unused",
                                getattr(x.pytorch, key),
                                getattr(x.oneflow, key),
                            )
                        )
                        call_tensor_id.append(id(getattr(x.pytorch, key)))
                        dual_objects_to_test.append(
                            GetDualObject(
                                "unused",
                                getattr(x.pytorch, key).grad,
                                getattr(x.oneflow, key).grad,
                            )
                        )
                        call_tensor_id.append(id(getattr(x.pytorch, key).grad))
                for x in dual_objects_to_test:
                    if (
                        isinstance(x.pytorch, torch_original.Tensor)
                        and id(x.pytorch) not in call_tensor_id
                    ):
                        vis_tensor.append(x.pytorch)
                # check eager
                for x in dual_objects_to_test:
                    if check_allclose:
                        test_case.assertTrue(check_equality(x, rtol=rtol, atol=atol), x)
                    if verbose:
                        print(f"{f.__name__} test eager passed.")
                n -= 1
                loop += 1
        return new_f
    return deco
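The line res = f(test_case) in this decorator executes the decorated test program. For the given inputs it runs the PyTorch program and the OneFlow program and collects every intermediate output tensor, including gradients, into the dual_objects_to_test list. Modules are tracked in dual_modules_to_test, and their parameters and parameter gradients are appended to dual_objects_to_test as well. The decorator then walks through this list and checks each pair of tensors for matching shape and for values that agree within the rtol and atol tolerances. The comparison function is implemented at https://github.com/Oneflow-Inc/oneflow/blob/v0.6.0/python/oneflow/test_utils/automated_test_util/torch_flow_dual_object.py#L565-L599 ; in essence it takes the numpy data of the tensors and compares them. The autotest() decorator also exposes several parameters that control whether the backward pass is run, how many times the test is executed, and the precision thresholds used in the final comparison.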
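0x4. Automatically generating the program and data that trigger a BUG
Having covered the principles and usage of the AutoTest framework, this section shows how it yields a program that reproduces a BUG, together with the corresponding input tensors and parameters. The principle is simple: the APIs used during the GetDualObject process are recorded and concatenated into a complete program. Here is what that looks like in CI. The run at https://github.com/Oneflow-Inc/oneflow/runs/4760189461?check_suite_focus=true shows a case in which OneFlow's conv_transpose2d operator and PyTorch's conv_transpose2d operator did not align. When CI reported the error it also printed the reproduction code and data, which makes it easy for framework developers to locate and diagnose the problem.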
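The record-and-concatenate idea can be sketched as follows. CallRecorder is a hypothetical helper written for this illustration and is not the mechanism used in torch_flow_dual_object.py; it logs each call as one line of source text with concrete argument values, so the joined lines form a replayable script.

```python
import math


class CallRecorder:
    """Hypothetical helper: logs each call as a line of source text."""

    def __init__(self):
        self.lines = []

    def call(self, result_name, func_name, func, *args, **kwargs):
        # Render concrete argument values so that the log is self-contained.
        rendered = [repr(a) for a in args]
        rendered += [f"{k}={v!r}" for k, v in kwargs.items()]
        self.lines.append(f"{result_name} = {func_name}({', '.join(rendered)})")
        return func(*args, **kwargs)

    def program(self):
        return "\n".join(self.lines)


rec = CallRecorder()
a = rec.call("a", "math.pow", math.pow, 2.0, 3.0)
b = rec.call("b", "round", round, a, ndigits=1)
print(rec.program())

# Replaying the recorded text reproduces the same result.
namespace = {"math": math}
exec(rec.program(), namespace)
print(namespace["b"] == b)
```

When a comparison fails, printing such a program next to the failing inputs gives a developer a self-contained script to run locally.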

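Beyond this, the AutoTest framework is no longer limited to testing Eager operators. We have extended it to support nn.Graph, Eager Consistent and other scenarios, which is a great convenience for framework developers.
0x5. Summary
This article introduced OneFlow's operator AutoTest framework, an elegant way for a deep learning framework to perform operator alignment. It lets developers and users write test programs as easily as they write PyTorch code. The framework is both flexible and easy to use, and everyone is welcome to study it or use it.
0x6. Related links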
https://github.com/Oneflow-Inc/oneflow https://github.com/pytorch/pytorch
