深度學(xué)習(xí)框架量化感知訓(xùn)練的思考及OneFlow的一種解決方案
【GiantPandaCV導(dǎo)語(yǔ)】這篇文章分享的是筆者最近在OneFlow做的一個(gè)項(xiàng)目,將Pytorch FX移植到OneFlow之后實(shí)現(xiàn)了自動(dòng)量化感知訓(xùn)練動(dòng)態(tài)圖模型(在Pytorch和OneFlow中都稱為nn.Module)?,F(xiàn)在用戶可以在自己構(gòu)建的nn.Module基礎(chǔ)上,修改很少的代碼即可完成從nn.Module量化感知訓(xùn)練到用TensorRT將量化感知訓(xùn)練后的模型部署到GPU上運(yùn)行的完整鏈路。在TensorRT上推理是利用了ONNX作為中間表示,即Oneflow動(dòng)態(tài)圖模型(nn.Module)->OneFlow量化感知訓(xùn)練模型(nn.Module)->OneFlow靜態(tài)圖(nn.Graph)->ONNX->TensorRT。量化感知訓(xùn)練是基于支持在Eager下寫(xiě)Pass的FX模塊(FX被Pytorch率先提出,筆者將其基礎(chǔ)設(shè)施移植到了OneFlow)來(lái)完成的。讀者如果想體驗(yàn)這個(gè)功能可以按照本文的方法進(jìn)行操作,有任何使用上的問(wèn)題可以聯(lián)系筆者。
0x0. 總覽
好久不見(jiàn),大家國(guó)慶快樂(lè)!
相信不少小伙伴都了解或者使用了一些深度學(xué)習(xí)框架比如Pytorch,TensorFlow,OneFlow(也是筆者目前正在參與開(kāi)發(fā)的)。但當(dāng)大家使用深度學(xué)習(xí)框架的訓(xùn)練量化方案時(shí)如果第一感覺(jué)就是太復(fù)雜了,那么你可能會(huì)對(duì)這篇文章感興趣!因?yàn)槲以?個(gè)月前開(kāi)始接觸這個(gè)項(xiàng)目前,對(duì)量化感知訓(xùn)練的知識(shí)積累也非常少,并且我也會(huì)認(rèn)為各個(gè)框架的量化感知訓(xùn)練方案很復(fù)雜,甚至不想研究這些API。
這篇文章我會(huì)以Pytorch的兩代量化方案開(kāi)始切入談一談他們的好處和壞處,然后我會(huì)講講我在吸收了Pytorch的部分優(yōu)秀成果(FX模塊)并加上一些自己的想法后把OneFlow的量化感知訓(xùn)練方案做成了什么樣子。這里先羅列一下這篇文章中涉及到的知識(shí)點(diǎn):
Pytorch FX模塊 Eager Pass 量化感知訓(xùn)練 Conv+BN的融合 OneFlow的動(dòng)靜轉(zhuǎn)換(nn.Graph) ONNX TensorRT
如果你對(duì)上面的任意一個(gè)知識(shí)點(diǎn)不熟悉,那也是完全沒(méi)有關(guān)系的。實(shí)際上即使你只會(huì)用Pytorch搭建模型也可以快速把本文的量化感知訓(xùn)練方案用起來(lái)。因?yàn)榱炕兄?xùn)練的工作和模型轉(zhuǎn)化成ONNX以及用TensorRT來(lái)部署運(yùn)行的代碼我們?cè)贠neFlow社區(qū)中均開(kāi)源了。
簡(jiǎn)單總結(jié)一下就是,用戶可以基于OneFlow搭建一個(gè)動(dòng)態(tài)圖模型(即nn.Module,算子的API和Pytorch基本一樣),然后調(diào)用下面的幾行代碼就可以完成這個(gè)動(dòng)態(tài)圖模型(是一個(gè)nn.Module)自動(dòng)在合適的位置插入量化模塊生成一個(gè)量化模型(仍然是nn.Module),然后基于這個(gè)量化模型完成量化感知訓(xùn)練。
gm:?flow.fx.GraphModule?=?flow.fx.symbolic_trace(net)
qconfig?=?{
????'quantization_bit':?8,?
????'quantization_scheme':?"symmetric",?
????'quantization_formula':?"cambricon",?
????'per_layer_quantization':?True,
????'momentum':?0.95,
}
net?=?quantization_aware_training(gm,?flow.randn(1,?3,?32,?32),?qconfig)
net?=?net.to(device)
在訓(xùn)練完成后,調(diào)用下面的代碼完成訓(xùn)練量化模型到ONNX的轉(zhuǎn)換,并使用TensorRT在GPU上推理。
quantization_resnet18?=?quantization_aware_training(gm,?flow.randn(1,?3,?32,?32).to("cuda"),?qconfig)
quantization_resnet18?=?quantization_resnet18.to("cuda")
quantization_resnet18.eval()
checkpoint?=?flow.load('/home/zhangxiaoyu/oneflow-cifar/checkpoint/epoch_11_val_acc_83.280000')
quantization_resnet18.load_state_dict(checkpoint)
origin_gm:?flow.fx.GraphModule?=?flow.fx.symbolic_trace(resnet18)
dequantization_resnet18?=?dequantization_aware_training(origin_gm,?gm,?flow.randn(1,?3,?32,?32).to("cuda"),?qconfig)
dequantization_resnet18?=?dequantization_resnet18.to("cuda")
dequantization_resnet18.eval()
class?ResNet18Graph(flow.nn.Graph):
????def?__init__(self):
????????super().__init__()
????????self.m?=?dequantization_resnet18
????def?build(self,?x):
????????out?=?self.m(x)
????????return?out
def?test_resnet():???
????resnet_graph?=?ResNet18Graph()
????resnet_graph._compile(flow.randn(1,?3,?32,?32).to("cuda"))
????with?tempfile.TemporaryDirectory()?as?tmpdirname:
????????flow.save(dequantization_resnet18.state_dict(),?tmpdirname)
????????convert_to_onnx_and_check(resnet_graph,?flow_weight_dir=tmpdirname,?onnx_model_path="/tmp",?print_outlier=True)
????????ipt_dict,?onnx_res?=?run_onnx("/tmp/model.onnx",?get_onnx_provider("cpu"))
????????trt_res?=?run_tensorrt("/tmp/model.onnx",?ipt_dict[list(ipt_dict.keys())[0]])
????????compare_result(onnx_res,?trt_res,?atol=1e-4,?print_outlier=True)
test_resnet()
用戶只需要使用上面示例中的短短幾十行代碼就可以完成一個(gè)端到端的量化感知訓(xùn)練到GPU部署的全流程。所以我認(rèn)為這項(xiàng)工作是有趣并且相對(duì)簡(jiǎn)潔的,當(dāng)然我更希望聽(tīng)到用戶的想法,然后就寫(xiě)了這篇文章來(lái)分享這個(gè)項(xiàng)目。這個(gè)項(xiàng)目的所有代碼均開(kāi)源在了OneFlow社區(qū),下面是對(duì)應(yīng)的鏈接。如果你使用這個(gè)方案碰到了什么問(wèn)題都可以第一時(shí)間聯(lián)系我。我的個(gè)人微信號(hào)是bbuf23333,來(lái)時(shí)請(qǐng)備注 量化感知訓(xùn)練
OneFlow FX(用來(lái)實(shí)現(xiàn)量化感知訓(xùn)練的基礎(chǔ)設(shè)施):https://github.com/Oneflow-Inc/oneflow/pull/5939 OneFlow Cifar(基于OneFlow FX量化訓(xùn)練Cifar10):https://github.com/BBuf/oneflow-cifar OneFlow->ONNX和TensorRT運(yùn)行:https://github.com/Oneflow-Inc/oneflow_convert/pull/45
0x1. Pytorch量化方案的沉浮
這一節(jié)主要基于Pytorch的官方文檔:https://pytorch.org/docs/1.9.0/quantization.html來(lái)進(jìn)行說(shuō)明。Pytorch第一代量化方案叫作Eager Mode Quantization,然后從1.8開(kāi)始推出FX Graph Mode Quantization。Eager Mode Quantization需要用戶手動(dòng)更改模型,并手動(dòng)指定需要融合的Op。FX Graph Mode Quantization解放了用戶,一鍵自動(dòng)量化,無(wú)需用戶修改模型和關(guān)心內(nèi)部操作。這個(gè)改動(dòng)具體可以體現(xiàn)在下面的圖中。

下面分別解釋一下Pytorch這兩種量化方式的區(qū)別。
Eager Mode Quantization
class?Net(nn.Module):
????def?__init__(self,?num_channels=1):
????????super(Net,?self).__init__()
????????self.conv1?=?nn.Conv2d(num_channels,?40,?3,?1)
????????self.conv2?=?nn.Conv2d(40,?40,?3,?1)
????????self.fc?=?nn.Linear(5*5*40,?10)
????def?forward(self,?x):
????????x?=?F.relu(self.conv1(x))
????????x?=?F.max_pool2d(x,?2,?2)
????????x?=?F.relu(self.conv2(x))
????????x?=?F.max_pool2d(x,?2,?2)
????????x?=?x.reshape(-1,?5*5*40)
????????x?=?self.fc(x)
????????return?x
Pytorch可以在nn.Module的foward里面隨意構(gòu)造網(wǎng)絡(luò),可以調(diào)用其它nn.Module,也可以調(diào)用nn.functional.xxx,甚至可以在里面寫(xiě)If這種控制邏輯。但這也帶來(lái)了一個(gè)問(wèn)題,就是在Eager層面比較難獲取這個(gè)模型的圖結(jié)構(gòu)。所以在Eager Mode Quantization中,要量化這個(gè)網(wǎng)絡(luò)必須做手動(dòng)修改:
class?NetQuant(nn.Module):
????def?__init__(self,?num_channels=1):
????????super(NetQuant,?self).__init__()
????????self.conv1?=?nn.Conv2d(num_channels,?40,?3,?1)
????????self.relu1?=?nn.ReLU()
????????self.pool1?=?nn.MaxPool2d(2,?2)
????????self.conv2?=?nn.Conv2d(40,?40,?3,?1)
????????self.relu2?=?nn.ReLU()
????????self.pool2?=?nn.MaxPool2d(2,?2)
????????self.fc?=?nn.Linear(5*5*40,?10)
????????self.quant?=?torch.quantization.QuantStub()
????????self.dequant?=?torch.quantization.DeQuantStub()
????def?forward(self,?x):
????????x?=?self.quant(x)
????????x?=?self.relu1(self.conv1(x))
????????x?=?self.pool1(x)
????????x?=?self.relu2(self.conv2(x))
????????x?=?self.pool2(x)
????????x?=?x.reshape(-1,?5*5*40)
????????x?=?self.fc(x)
????????x?=?self.dequant(x)
????????return?x
也就是說(shuō),除了Conv,Linear這些含有參數(shù)的Module外,ReLU,MaxPool2d也要在__init__中定義,Eager Mode Quantization才可以正確處理。
除了這一點(diǎn),還有一些情況是需要Fuse之后做量化比如Conv+ReLU,那么還需要手動(dòng)指定這些層進(jìn)行折疊,目前這種量化模式支持ConV + BN、ConV + BN + ReLU、Conv + ReLU、Linear + ReLU、BN + ReLU的折疊。
model?=?NetQuant()model.qconfig?=?torch.quantization.get_default_qconfig('fbgemm')
modules_to_fuse?=?[['conv1',?'relu1'],?['conv2',?'relu2']]??#?指定合并layer的名字
model_fused?=?torch.quantization.fuse_modules(model,?modules_to_fuse)
model_prepared?=?torch.quantization.prepare(model_fused)
post_training_quantize(model_prepared,?train_loader)???#?這一步是做后訓(xùn)練量化
model_int8?=?torch.quantization.convert(model_prepared)
整個(gè)流程比較逆天,不知道有沒(méi)有人用。不過(guò)公眾號(hào)有小伙伴確實(shí)用過(guò),見(jiàn)文章:Pytorch量化感知訓(xùn)練詳解
FX Graph Mode Quantization
關(guān)于Pytorch FX模塊是什么,我們放到下一節(jié)來(lái)講。
由于 Pytorch FX 可以自動(dòng)跟蹤 forward 里面的代碼,因此它是真正記錄了網(wǎng)絡(luò)里面的每個(gè)節(jié)點(diǎn),在 fuse 和動(dòng)態(tài)插入量化節(jié)點(diǎn)方面,比 Eager 模式強(qiáng)太多。對(duì)于前面那個(gè)模型代碼,我們不需要對(duì)網(wǎng)絡(luò)做修改,直接讓 FX 幫我們自動(dòng)修改網(wǎng)絡(luò)即可,一個(gè)使用示例如下:
from?torch.quantization?import?get_default_qconfig,?quantize_jit
from?torch.quantization.quantize_fx?import?prepare_fx,?convert_fx
model?=?Net()??
qconfig?=?get_default_qconfig("fbgemm")
qconfig_dict?=?{"":?qconfig}
model_prepared?=?prepare_fx(model,?qconfig_dict)
post_training_quantize(model_prepared,?train_loader)??????#?這一步是做后訓(xùn)練量化
model_int8?=?convert_fx(model_prepared)
基于這兩套量化方案來(lái)看,基于FX的量化方案顯然更加優(yōu)秀,因?yàn)樗恍枰脩粼诙x模型的時(shí)候做什么額外限制,用戶仍然是隨心所欲的寫(xiě)模型代碼就行了,這才符合人的常識(shí)。我在做OneFlow的量化感知訓(xùn)練方案時(shí)也正是基于FX這個(gè)基礎(chǔ)設(shè)施(我將其核心功能移植到了OneFlow框架下,代碼鏈接第一節(jié)給了)來(lái)完成的。
另外在TensorRT的工程中:https://github.com/NVIDIA/TensorRT/tree/master/tools/pytorch-quantization發(fā)現(xiàn)Pytorch量化模型要轉(zhuǎn)為ONNX來(lái)部署現(xiàn)在似乎還是得基于第一個(gè)版本的方案,Pytorch FX這邊似乎想直接從nn.Module轉(zhuǎn)到TensorRT,不經(jīng)過(guò)ONNX的中間表示,所以我這里的技術(shù)路線還是有點(diǎn)不一樣。
0x2. OneFlow FX (在Eager中寫(xiě)Pass)
FX可以用來(lái)做什么?

FX可以將一個(gè)nn.Module變換后生成另外一個(gè)nn.Module,只需要在這個(gè)架構(gòu)的基礎(chǔ)上實(shí)現(xiàn)一些Transformation(也可以叫Pass),比如在Conv后自動(dòng)插入偽量化節(jié)點(diǎn)實(shí)現(xiàn)訓(xùn)練量化,然后生成GraphModule(這個(gè)也是nn.Module)進(jìn)行訓(xùn)練和轉(zhuǎn)為ONNX進(jìn)行部署。
OneFlow FX模塊在這個(gè)PR(https://github.com/Oneflow-Inc/oneflow/pull/5939)中實(shí)現(xiàn),這里復(fù)用了Pytorch FX基礎(chǔ)設(shè)施的核心邏輯和代碼,這個(gè)PR里的主要工作為:
[x] 精簡(jiǎn)Pytorch FX的特殊設(shè)計(jì)比如對(duì)_C的Trace,和Jit的交互。保留核心功能,即Symbolic Tracing,Intermediate Representation和Transformation以及Python Codegen這4個(gè)組成部分。 [x] ?分步實(shí)現(xiàn)以上四大功能的代碼,完全適配OneFlow的相關(guān)設(shè)計(jì),現(xiàn)在可以一鍵import oneflow.fx來(lái)體驗(yàn)??梢訲race住基本所有OneFlow API搭建的Eager模型的結(jié)構(gòu),并將其變換成一個(gè)等價(jià)的 nn.Module,我們還可以在這個(gè)nn.Module的基礎(chǔ)上自定義自己的Transformation Pass,我這里實(shí)現(xiàn)了Shape Infer和Quantization以及Dequantization的Pass。[x] 增加AlexNet,ResNet50,MobileNetV2等模型的測(cè)試。
然后分享一下OneFlow FX的整體思路。
先看一個(gè)示例:
????import?oneflow
????#?Simple?module?for?demonstration
????class?MyModule(oneflow.nn.Module):
????????def?__init__(self):
????????????super().__init__()
????????????self.param?=?oneflow.nn.Parameter(oneflow.rand(3,?4))
????????????self.linear?=?oneflow.nn.Linear(4,?5)
????????def?forward(self,?x):
????????????return?self.linear(x?+?self.param).clamp(min=0.0,?max=1.0)
????module?=?MyModule()
????from?oneflow.fx?import?symbolic_trace
????#?Symbolic?tracing?frontend?-?captures?the?semantics?of?the?module
????symbolic_traced?:?oneflow.fx.GraphModule?=?symbolic_trace(module)
????#?High-level?intermediate?representation?(IR)?-?Graph?representation
????print(symbolic_traced.graph)
????"""
????graph():
????????%x?:?[#users=1]?=?placeholder[target=x]
????????%param?:?[#users=1]?=?get_attr[target=param]
????????%add?:?[#users=1]?=?call_function[target=operator.add](args?=?(%x,?%param),?kwargs?=?{})
????????%linear?:?[#users=1]?=?call_module[target=linear](args?=?(%add,),?kwargs?=?{})
????????%clamp?:?[#users=1]?=?call_method[target=clamp](args?=?(%linear,),?kwargs?=?{min:?0.0,?max:?1.0})
????????return?clamp
????"""
????#?Code?generation?-?valid?Python?code
????print(symbolic_traced.code)
????"""
????def?forward(self,?x):
????????param?=?self.param
????????add?=?x?+?param;??x?=?param?=?None
????????linear?=?self.linear(add);??add?=?None
????????clamp?=?linear.clamp(min?=?0.0,?max?=?1.0);??linear?=?None
????????return?clamp
????"""
在FX中有一個(gè)Proxy類,它會(huì)把oneflow中所有的call_method和call_function以及math庫(kù)中的函數(shù)和常見(jiàn)的魔法函數(shù)都包裝一遍來(lái)記錄OneFlow中所有的運(yùn)算符,這個(gè)在import oneflow.fx時(shí)就做好了。然后在傳入一個(gè)nn.Module調(diào)用symbolic_trace進(jìn)行跟蹤代碼的時(shí)候會(huì)首先處理__init__中的其它nn.Module,把這些nn.Module也用Proxy包起來(lái),同時(shí)輸入數(shù)據(jù)也要包起來(lái)。
用Proxy包好所有程序中可能存在的運(yùn)算符之后就執(zhí)行一遍forward,這個(gè)forward的輸入數(shù)據(jù)不再是Tensor而是Proxy(Tensor)。由于程序的執(zhí)行過(guò)程類似于一個(gè)運(yùn)算符和數(shù)據(jù)入棧出棧的過(guò)程,所以我們可以直接按照這個(gè)執(zhí)行順序?qū)偛庞肞roxy記錄下來(lái)的數(shù)據(jù)和Op進(jìn)行unpack,unpack之后可以拿到真實(shí)的Tensor, Parameter和運(yùn)算符等等,我們將這些數(shù)據(jù)和運(yùn)算符當(dāng)作點(diǎn)和邊去構(gòu)造一個(gè)新的Graph。那么Graph是怎么轉(zhuǎn)化成nn.Module的呢?FX中通過(guò)引入GraphModule的數(shù)據(jù)結(jié)構(gòu)來(lái)持有這個(gè)Graph,此外GraphModule還持有code和foward成員,這兩者都是基于Graph自動(dòng)生成的,注意GraphModule仍然是nn.Module。
自動(dòng)生成的代碼就是GraphModule中的code,打印出來(lái)其實(shí)就是整個(gè)forward函數(shù)的完整執(zhí)行過(guò)程。
另外FX還提供了一個(gè)Interpreter類用來(lái)讓用戶自定義nn.Module的執(zhí)行過(guò)程,比如這個(gè)PR提供了一個(gè)基于這個(gè)類做所有中間Tensor形狀推導(dǎo)的Pass。另外還提供了一個(gè)基于pydot將GraphModule結(jié)構(gòu)可視化的Pass,如下圖。

相信到這里大家對(duì)FX有一個(gè)了解了,這里最棒的一個(gè)功能就是我們可以對(duì)nn.Module進(jìn)行修改,然后返回變化后的nn.Module。說(shuō)
到這里,我們自然能想到量化感知訓(xùn)練不就是把Conv+BN或者Conv,Linear等組件替換為插入了偽量化節(jié)點(diǎn)的組件嗎?所以我們基于FX來(lái)寫(xiě)一個(gè)Pass就可以完成這件事了。
這就是上面說(shuō)的,FX支持在Eager寫(xiě)Pass。
然而FX也存在缺陷,目前無(wú)法處理控制流,需要注意網(wǎng)絡(luò)中不要帶控制流(不過(guò)這一點(diǎn)暫時(shí)影響不大,因?yàn)橛脩粢话愣疾粫?huì)部署含有控制流的網(wǎng)絡(luò),如果有這個(gè)需求我們可以交流)。
0x3. 實(shí)現(xiàn)量化感知訓(xùn)練Pass
有了OneFlow FX之后我們就可以實(shí)現(xiàn)一個(gè)量化感知訓(xùn)練的Pass來(lái)將用戶自定義的網(wǎng)絡(luò)中自動(dòng)插入量化感知訓(xùn)練組件來(lái)完成量化感知訓(xùn)練了。
以ResNet18為例,它只有Conv+BN這種模式,即對(duì)于任意一個(gè)卷積層后面都跟了一個(gè)BN層,在推理的時(shí)候TensorRT會(huì)做Conv+BN的融合,那么我們?cè)谟?xùn)練的時(shí)候也是必須要做Conv+BN的融合的,不然會(huì)影響部署的精度。所以,我們首先需要把BN層的參數(shù)和卷積層的參數(shù)融合,然后再對(duì)這個(gè)參數(shù)做量化,具體過(guò)程如下圖所示:

下面是Conv和BN融合的公式:
所以:
公式中的,和分別表示卷積層的權(quán)值與偏置,和分別為卷積層的輸入與輸出,則根據(jù)的計(jì)算公式,可以推出融合了batchnorm參數(shù)之后的權(quán)值與偏置,和。
按照這個(gè)公式就可以實(shí)現(xiàn)Conv+BN融合后的量化感知訓(xùn)練組件,在實(shí)現(xiàn)中對(duì)訓(xùn)練和推理的處理有些不一樣的地方,我在代碼中標(biāo)注出來(lái)了。
class?QConvBN(flow.nn.Module):
????def?__init__(
????????self,
????????conv_module,
????????bn_module,
????????quantization_bit=8,
????????quantization_scheme="symmetric",
????????quantization_formula="google",
????????per_layer_quantization=True,
????????momentum=0.95,
????):
????????super().__init__()
????????self.quantization_bit?=?quantization_bit
????????self.quantization_scheme?=?quantization_scheme
????????self.quantization_formula?=?quantization_formula
????????self.per_layer_quantization?=?per_layer_quantization
????????self.conv_module?=?conv_module
????????self.bn_module?=?bn_module
????????self.moving_min_max_observer?=?flow.nn.MovingAverageMinMaxObserver(
????????????training=self.training,
????????????quantization_formula=quantization_formula,
????????????stop_update_after_iters=1,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????????momentum=momentum,
????????)
????????self.min_max_observer?=?flow.nn.MinMaxObserver(
????????????quantization_formula=quantization_formula,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????????per_layer_quantization=per_layer_quantization,
????????)
????????self.fake_quantization?=?flow.nn.FakeQuantization(
????????????quantization_formula=quantization_formula,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????)
????def?fold_bn(self,?mean,?std):
????????if?self.bn_module.affine:
????????????gamma_?=?self.bn_module.weight?/?std
????????????weight?=?self.conv_module.weight?*?gamma_.view(
????????????????self.conv_module.out_channels,?1,?1,?1
????????????)
????????????if?self.conv_module.bias?is?not?None:
????????????????bias?=?(
????????????????????gamma_?*?self.conv_module.bias?-?gamma_?*?mean?+?self.bn_module.bias
????????????????)
????????????else:
????????????????bias?=?self.bn_module.bias?-?gamma_?*?mean
????????else:
????????????gamma_?=?1?/?std
????????????weight?=?self.conv_module.weight?*?gamma_
????????????if?self.conv_module.bias?is?not?None:
????????????????bias?=?gamma_?*?self.conv_module.bias?-?gamma_?*?mean
????????????else:
????????????????bias?=?-gamma_?*?mean
????????return?weight,?bias
????def?forward(self,?x):
????????scale,?zero_point?=?self.moving_min_max_observer(
????????????x,?flow.tensor([0],?dtype=flow.int64).to(x.device.type)
????????)
????????x?=?self.fake_quantization(x,?scale,?zero_point)
????????if?self.training:
????????????y?=?flow.nn.functional.conv2d(
????????????????x,
????????????????self.conv_module.weight,
????????????????self.conv_module.bias,
????????????????stride=self.conv_module.stride,
????????????????padding=self.conv_module.padding,
????????????????dilation=self.conv_module.dilation,
????????????????groups=self.conv_module.groups,
????????????)
????????????y?=?y.permute(1,?0,?2,?3)??#?NCHW?->?CNHW
????????????y?=?y.view(self.conv_module.out_channels,?-1)??#?CNHW?->?C,NHW
????????????mean?=?y.mean(1)
????????????var?=?y.var(1)
????????????with?flow.no_grad():
????????????????self.bn_module.running_mean?=?(
????????????????????self.bn_module.momentum?*?self.bn_module.running_mean
????????????????????+?(1?-?self.bn_module.momentum)?*?mean
????????????????)
????????????????self.bn_module.running_var?=?(
????????????????????self.bn_module.momentum?*?self.bn_module.running_var
????????????????????+?(1?-?self.bn_module.momentum)?*?var
????????????????)
????????else:
????????????mean?=?flow.Tensor(self.bn_module.running_mean)
????????????var?=?flow.Tensor(self.bn_module.running_var)
????????std?=?flow.sqrt(var?+?self.bn_module.eps)
????????weight,?bias?=?self.fold_bn(mean,?std)
????????weight_scale,?weight_zero_point?=?self.min_max_observer(weight)
????????res?=?flow.nn.functional.conv2d(
????????????x,
????????????self.fake_quantization(weight,?weight_scale,?weight_zero_point),
????????????bias,
????????????stride=self.conv_module.stride,
????????????padding=self.conv_module.padding,
????????????dilation=self.conv_module.dilation,
????????????groups=self.conv_module.groups,
????????)
????????return?res
實(shí)現(xiàn)了這個(gè)組件之后我們就可以實(shí)現(xiàn)一個(gè)量化感知訓(xùn)練Pass,即將用戶的nn.Module抽象的計(jì)算圖中的Conv+BN都替換成這個(gè)QConvBN組件,替換部分的代碼實(shí)現(xiàn)如下:
for?x?in?gm.graph.nodes:
????????if?x.target?in?insert_place:
????????????with?gm.graph.inserting_after(x):
????????????????y?=?x.next
????????????????if?(
????????????????????isinstance(insert_op_state[x.target],?flow.nn.Conv2d)
????????????????????and?y.target?in?insert_place
????????????????????and?isinstance(insert_op_state[y.target],?flow.nn.BatchNorm2d)
????????????????):
????????????????????now_target?=?get_current_module_space(x.target)
????????????????????if?now_target?==?"":
????????????????????????now_target?=?f"fake_conv_bn.{cnt}"
????????????????????else:
????????????????????????now_target?=?(
????????????????????????????f"{get_current_module_space(x.target)}.fake_conv_bn.{cnt}"
????????????????????????)
????????????????????gm.add_submodule(
????????????????????????now_target,
????????????????????????QConvBN(
????????????????????????????insert_op_state[x.target],
????????????????????????????insert_op_state[y.target],
????????????????????????????quantization_bit,
????????????????????????????quantization_scheme,
????????????????????????????quantization_formula,
????????????????????????????per_layer_quantization,
????????????????????????????momentum,
????????????????????????),
????????????????????)
????????????????????y.replace_all_uses_with(x)
????????????????????gm.graph.erase_node(y)
????????????????????gm.delete_submodule(y.target)
????????????????????qconvbn?=?gm.graph.call_module(module_name=now_target,?args=x.args,)
????????????????????cnt?=?cnt?+?1
????????????????????x.replace_all_uses_with(qconvbn)
????????????????????gm.graph.erase_node(x)
????????????????????gm.delete_submodule(x.target)
在gm(ResNet18 Trace出來(lái)的GraphModule,仍然是nn.Module)中找到Conv+BN的組件,將其刪除然后替換成QConvBN組件。
0x4. 基于ResNet18量化感知訓(xùn)練Cifar10
基于上面實(shí)現(xiàn)的量化Pass,我們就可以方便的對(duì)自定義的模型進(jìn)行量化感知訓(xùn)練了,以ResNet18為例,我們?cè)谠嫉膭?dòng)態(tài)圖訓(xùn)練代碼基礎(chǔ)上加上下面幾行代碼就可以了。
gm:?flow.fx.GraphModule?=?flow.fx.symbolic_trace(net)
qconfig?=?{
????'quantization_bit':?8,?
????'quantization_scheme':?"symmetric",?
????'quantization_formula':?"cambricon",?
????'per_layer_quantization':?True,
????'momentum':?0.95,
}
net?=?quantization_aware_training(gm,?flow.randn(1,?3,?32,?32),?qconfig)
net?=?net.to(device)
這里qconfig讓用戶可以方便的配置OneFlow支持的各種量化方式。具體可以看之前的文章介紹:基于OneFlow實(shí)現(xiàn)量化感知訓(xùn)練
第一個(gè)net就是用戶定義的動(dòng)態(tài)圖模型,經(jīng)過(guò)這個(gè)Pass之后獲得新的net,新的net就已經(jīng)自動(dòng)插入好了量化感知訓(xùn)練組件。其它的訓(xùn)練和測(cè)試的過(guò)程和普通的FP32訓(xùn)練完全一致,就不贅述了。我基于ResNet18在Cifar10上訓(xùn)練了幾個(gè)OneFlow支持的量化配置,均訓(xùn)練了200個(gè)Epoch,超參一致,結(jié)果如下:
Note:
The?`momentum`?parameter?in?the?`MovingAverageMinMaxObserver`?class?defaults?to?0.95,?which?will?not?be?changed?in?the?following?experiments.?
##?Accuracy
|?Model?????????????|?quantization_bit?|?quantization_scheme?|?quantization_formula?|?per_layer_quantization?|?Acc?|
|?-----------------?|?-----------?|?-----------?|?-----------?|?-----------?|?-----------?|
|?ResNet18??????????|??8?????|??symmetric??????|?google???????|???True?????|??95.19%??????|?
|?ResNet18??????????|??8?????|??symmetric??????|?google???????|???False????|??95.24%??????|?
|?ResNet18??????????|??8?????|??affine?????????|?google???????|???True?????|??95.32%??????|?
|?ResNet18??????????|??8?????|??affine?????????|?google???????|???False????|??95.30%??????|?
|?ResNet18??????????|??8?????|??symmetric??????|?cambricon????|???True?????|??95.19%??????|
工程地址:https://github.com/BBuf/oneflow-cifar。ResNet18在Cifar10上基于FP32訓(xùn)練的精度是:95.62%。這里各種量化參數(shù)下的量化感知訓(xùn)練精度均和原始精度持平。上面的cambricon代表的是寒武紀(jì)量化方案,google代表的是Google的量化方案。
0x5. 基于量化感知訓(xùn)練模型改寫(xiě)原始模型
上面我們已經(jīng)基于量化感知訓(xùn)練模型進(jìn)行了量化感知訓(xùn)練,接下來(lái)我們要考慮怎么部署這個(gè)量化感知訓(xùn)練模型了。顯然現(xiàn)在這個(gè)模型不是我們期望部署的樣子,因?yàn)槲覀冇脕?lái)部署的模型BN應(yīng)該已經(jīng)合并到卷積層里了,而不是被保留下來(lái)。所以我們要基于量化感知訓(xùn)練模型的參數(shù)對(duì)原始模型進(jìn)行改寫(xiě),然后將其用于轉(zhuǎn)化ONNX再到TensorRT。
這里和量化感知訓(xùn)練類似,我們實(shí)現(xiàn)一個(gè)dequantization Pass。這個(gè)Pass用來(lái)將QConvBN組件替換成一個(gè)DConv2d組件。DConv2d組件代碼實(shí)現(xiàn)如下:
class?DConv2d(flow.nn.Conv2d):
????def?__init__(
????????self,
????????in_channels,
????????out_channels,
????????kernel_size,
????????stride,
????????padding,
????????dilation,
????????groups,
????????quantization_bit=8,
????????quantization_scheme="symmetric",
????????quantization_formula="google",
????????per_layer_quantization=True,
????????momentum=0.95,
????)?->?None:
????????super(DConv2d,?self).__init__(
????????????in_channels,?out_channels,?kernel_size,?stride,?padding,?dilation,?groups
????????)
????????self.moving_min_max_observer?=?flow.nn.MovingAverageMinMaxObserver(
????????????training=self.training,
????????????quantization_formula=quantization_formula,
????????????stop_update_after_iters=1,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????????momentum=momentum,
????????)
????????self.min_max_observer?=?flow.nn.MinMaxObserver(
????????????quantization_formula=quantization_formula,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????????per_layer_quantization=per_layer_quantization,
????????)
????????self.fake_quantization?=?flow.nn.FakeQuantization(
????????????quantization_formula=quantization_formula,
????????????quantization_bit=quantization_bit,
????????????quantization_scheme=quantization_scheme,
????????)
????????self.register_buffer("new_zero",?flow.Tensor(1))
????????self.new_zero.fill_(0)
????def?forward(self,?x):
????????scale,?zero_point?=?self.moving_min_max_observer(
????????????x,?self.new_zero.to(flow.int64).to(x.device.type)
????????)
????????x?=?self.fake_quantization(x,?scale,?zero_point)
????????return?flow.nn.functional.conv2d(
????????????x,
????????????self.weight,
????????????self.bias,
????????????stride=self.stride,
????????????padding=self.padding,
????????????dilation=self.dilation,
????????????groups=self.groups,
????????)
然后我們只需要將原始的ResNet18模型里面的Conv+BN換成這個(gè)組件即可,請(qǐng)注意!!!這個(gè)組件的權(quán)重和偏置以及moving_min_max_observer的moving_min/max參數(shù)要賦值為訓(xùn)練好的量化感知模型的QConvBN組件對(duì)應(yīng)的權(quán)重和偏置以及moving_min_max_observer的moving_min/max參數(shù)。dequantization Pass的核心部分如下:
for?x?in?origin_gm.graph.nodes:
????????if?x.target?in?insert_place:
????????????with?origin_gm.graph.inserting_after(x):
????????????????y?=?x.next
????????????????if?(
????????????????????isinstance(insert_op_state[x.target],?flow.nn.Conv2d)
????????????????????and?y.target?in?insert_place
????????????????????and?isinstance(insert_op_state[y.target],?flow.nn.BatchNorm2d)
????????????????):
????????????????????now_target?=?get_current_module_space(x.target)
????????????????????if?now_target?==?"":
????????????????????????now_target?=?f"fake_conv_bn.{cnt}"
????????????????????else:
????????????????????????now_target?=?(
????????????????????????????f"{get_current_module_space(x.target)}.fake_conv_bn.{cnt}"
????????????????????????)
????????????????????dequanzation_conv?=?DConv2d(
????????????????????????quantization_op_state[now_target].conv_module.in_channels,?
????????????????????????quantization_op_state[now_target].conv_module.out_channels,
????????????????????????quantization_op_state[now_target].conv_module.kernel_size,
????????????????????????quantization_op_state[now_target].conv_module.stride,
????????????????????????quantization_op_state[now_target].conv_module.padding,
????????????????????????quantization_op_state[now_target].conv_module.dilation,
????????????????????????quantization_op_state[now_target].conv_module.groups,
????????????????????????quantization_bit,
????????????????????????quantization_scheme,
????????????????????????quantization_formula,
????????????????????????per_layer_quantization,
????????????????????????momentum,
????????????????????????)
????????????????????mean?=?flow.Tensor(quantization_op_state[now_target].bn_module.running_mean)
????????????????????var?=?flow.Tensor(quantization_op_state[now_target].bn_module.running_var)
????????????????????std?=?flow.sqrt(var?+?quantization_op_state[now_target].bn_module.eps)
????????????????????if?quantization_op_state[now_target].bn_module.affine:
????????????????????????gamma_?=?quantization_op_state[now_target].bn_module.weight?/?std
????????????????????????weight?=?quantization_op_state[now_target].conv_module.weight?*?gamma_.view(
????????????????????????????quantization_op_state[now_target].conv_module.out_channels,?1,?1,?1
????????????????????????)
????????????????????????if?quantization_op_state[now_target].conv_module.bias?is?not?None:
????????????????????????????bias?=?(
????????????????????????????????gamma_?*?quantization_op_state[now_target].conv_module.bias?-?gamma_?*?mean?+?quantization_op_state[now_target].bn_module.bias
????????????????????????????)
????????????????????????else:
????????????????????????????bias?=?quantization_op_state[now_target].bn_module.bias?-?gamma_?*?mean
????????????????????else:
????????????????????????gamma_?=?1?/?std
????????????????????????weight?=?quantization_op_state[now_target].conv_module.weight?*?gamma_
????????????????????????if?quantization_op_state[now_target].conv_module.bias?is?not?None:
????????????????????????????bias?=?gamma_?*?quantization_op_state[now_target].conv_module.bias?-?gamma_?*?mean
????????????????????????else:
????????????????????????????bias?=?-gamma_?*?mean
????????????????????dequanzation_conv.weight?=?flow.nn.Parameter(weight)
????????????????????dequanzation_conv.bias?=?flow.nn.Parameter(bias)
????????????????????dequanzation_conv.moving_min_max_observer.moving_max?=?quantization_op_state[now_target].moving_min_max_observer.moving_max
????????????????????dequanzation_conv.moving_min_max_observer.moving_min?=?quantization_op_state[now_target].moving_min_max_observer.moving_min
????????????????????origin_gm.add_submodule(
????????????????????????now_target,
????????????????????????dequanzation_conv,
????????????????????)
????????????????????y.replace_all_uses_with(x)
????????????????????origin_gm.graph.erase_node(y)
????????????????????origin_gm.delete_submodule(y.target)
????????????????????qconvbn?=?origin_gm.graph.call_module(module_name=now_target,?args=x.args,)
????????????????????cnt?=?cnt?+?1
????????????????????x.replace_all_uses_with(qconvbn)
????????????????????origin_gm.graph.erase_node(x)
????????????????????origin_gm.delete_submodule(x.target)
這里手動(dòng)執(zhí)行了Conv和BN融合的工作并把融合后的權(quán)重和偏置賦給DConv2d組件。
0x6. 轉(zhuǎn)換ONNX以及TensorRT推理
基于量化感知訓(xùn)練的模型以及dequantization Pass,我們就可以獲得用于推理時(shí)的nn.Module了。我們將這個(gè)nn.Module轉(zhuǎn)換成ONNX然后再放到TensorRT中進(jìn)行推理就可以了。這部分的示例代碼在:https://github.com/Oneflow-Inc/oneflow_convert/blob/add_fx_train_quantization/examples/oneflow2onnx/quantization/test_resnet18.py。我們截取核心部分進(jìn)行解釋。
#?加載訓(xùn)練好的量化模型權(quán)重
quantization_resnet18?=?quantization_aware_training(gm,?flow.randn(1,?3,?32,?32).to("cuda"),?qconfig)
quantization_resnet18?=?quantization_resnet18.to("cuda")
quantization_resnet18.eval()
checkpoint?=?flow.load('/home/zhangxiaoyu/oneflow-cifar/checkpoint/epoch_11_val_acc_83.280000')
quantization_resnet18.load_state_dict(checkpoint)
#?基于量化感知訓(xùn)練模型改寫(xiě)原始模型
origin_gm:?flow.fx.GraphModule?=?flow.fx.symbolic_trace(resnet18)
dequantization_resnet18?=?dequantization_aware_training(origin_gm,?gm,?flow.randn(1,?3,?32,?32).to("cuda"),?qconfig)
dequantization_resnet18?=?dequantization_resnet18.to("cuda")
dequantization_resnet18.eval()
#?nn.Graph是轉(zhuǎn)ONNX的橋梁,是把OneFlow的動(dòng)態(tài)圖轉(zhuǎn)為靜態(tài)圖
class?ResNet18Graph(flow.nn.Graph):
????def?__init__(self):
????????super().__init__()
????????self.m?=?dequantization_resnet18
????def?build(self,?x):
????????out?=?self.m(x)
????????return?out
#?測(cè)試函數(shù)
def?test_resnet():???
????resnet_graph?=?ResNet18Graph()
????resnet_graph._compile(flow.randn(1,?3,?32,?32).to("cuda"))
????with?tempfile.TemporaryDirectory()?as?tmpdirname:
????????flow.save(dequantization_resnet18.state_dict(),?tmpdirname)
????????convert_to_onnx_and_check(resnet_graph,?flow_weight_dir=tmpdirname,?onnx_model_path="/tmp",?print_outlier=True)
????????ipt_dict,?onnx_res?=?run_onnx("/tmp/model.onnx",?get_onnx_provider("cpu"))
????????trt_res?=?run_tensorrt("/tmp/model.onnx",?ipt_dict[list(ipt_dict.keys())[0]])
????????compare_result(onnx_res,?trt_res,?atol=1e-4,?print_outlier=True)
test_resnet()
首先我們使用dequantization Pass將原始模型改寫(xiě)成了部署時(shí)的模型,并且在這個(gè)Pass中同步處理好了權(quán)重的更改。然后我們將現(xiàn)在需要部署的這個(gè)模型(類型是nn.Module)通過(guò)OneFlow的nn.Graph將其轉(zhuǎn)為靜態(tài)圖,nn.Graph的資料見(jiàn):https://docs.oneflow.org/master/basics/08_nn_graph.html。
為什么要nn.Graph這一步?這是因?yàn)镺neFlow的轉(zhuǎn)化ONNX工具是基于靜態(tài)圖做的,所以額外多了這一步,如果你不想理解也沒(méi)關(guān)系,上面的代碼中已經(jīng)展示了完整的用法了。
要使用OneFlow->ONNX的轉(zhuǎn)化工具需要安裝下面的包:
python>=3.5
onnx>=1.8.0
onnxruntime>=1.6.0
oneflow>=0.5.0
然后pip install oneflow_onnx
然后調(diào)用oneflow_onnx中的convert_to_onnx_and_check API將量化訓(xùn)練模型轉(zhuǎn)化為ONNX。我們看一眼量化感知訓(xùn)練后的ResNet18轉(zhuǎn)化成ONNX之后長(zhǎng)什么樣子吧。

然后我們還需要用TesnsorRT來(lái)運(yùn)行這個(gè)量化感知訓(xùn)練模型,也要配置一些環(huán)境。我們需要安裝:
onnx>=1.8.0
onnxruntime-gpu>=1.8.0
opencv-python
pytest
nvidia-tensorrt==8.0.0.3
pycuda
flake8
這些包就緒之后就可以使用TensorRT來(lái)推理了。即上面的代碼:
ipt_dict,?onnx_res?=?run_onnx("/tmp/model.onnx",?get_onnx_provider("cpu"))
trt_res?=?run_tensorrt("/tmp/model.onnx",?ipt_dict[list(ipt_dict.keys())[0]])
compare_result(onnx_res,?trt_res,?atol=1e-4,?print_outlier=True)
具體的推理代碼和其它細(xì)節(jié)可以去代碼倉(cāng)庫(kù)看,這里展示一下最后的結(jié)果。在相同的輸隨機(jī)入下,ONNX的結(jié)果和TensorRT推理結(jié)果基本一致:
-2.9825006?-2.9825
-5.438802?-5.4388037
3.5198674?3.5198674
2.409646?2.4096458
4.5826764?4.5826764
0.019911028?0.019910894
6.6347113?6.634712
-3.5996702?-3.5996711
-1.3407612?-1.340761
-3.8473191?-3.847319
至此,我們完成了將原始的動(dòng)態(tài)圖模型通過(guò)量化感知訓(xùn)練后部署到了GPU上進(jìn)行推理,整個(gè)過(guò)程雖然我開(kāi)發(fā)的波折比較大,但總算完成了基礎(chǔ)功能的開(kāi)發(fā),感謝我的同事們。
我想你可能會(huì)好奇,這里為什么沒(méi)有給精度和速度對(duì)比,因?yàn)槟壳拔沂稚峡ú粔蜻€不能做更好的實(shí)驗(yàn)(比如用更大的數(shù)據(jù)集訓(xùn)練)所以只能用Cifar10跑一下精度。關(guān)于速度測(cè)試方面,TensorRT那部分需要排除編譯engine的影響只計(jì)算推理那部分的時(shí)間,我還沒(méi)有改那部分代碼,讀者如果感興趣可以先自行計(jì)算一下時(shí)間。后面可能會(huì)專門(mén)寫(xiě)一篇文章來(lái)介紹一下部署前后的精度和速度對(duì)比,另外目前實(shí)現(xiàn)的方案可能還存在漏洞需要更加精細(xì)的Check。
總的來(lái)說(shuō),這篇文章只是一篇學(xué)習(xí)交流筆記,所以目前并不會(huì)正式的給出量化感知訓(xùn)練精度和速度的BenchMark。因?yàn)樵诤喜⒌絆neFlow主分支前還有諸多的工程問(wèn)題需要解決。
0x7. 總結(jié)
這篇文章分享的是筆者最近在OneFlow做的一個(gè)項(xiàng)目,將Pytorch FX移植到OneFlow之后實(shí)現(xiàn)了自動(dòng)量化感知訓(xùn)練動(dòng)態(tài)圖模型(在Pytorch和OneFlow中都稱為nn.Module)。現(xiàn)在用戶可以在自己構(gòu)建的nn.Module基礎(chǔ)上,修改很少的代碼即可完成從nn.Module量化感知訓(xùn)練到用TensorRT將量化感知訓(xùn)練后的模型部署到GPU上運(yùn)行的完整鏈路。在TensorRT上推理是利用了ONNX作為中間表示,即Oneflow動(dòng)態(tài)圖模型(nn.Module)->OneFlow量化感知訓(xùn)練模型(nn.Module)->OneFlow靜態(tài)圖(nn.Graph)->ONNX->TensorRT。量化感知訓(xùn)練是基于支持在Eager下寫(xiě)Pass的FX模塊(FX被Pytorch率先提出,筆者將其基礎(chǔ)設(shè)施移植到了OneFlow)來(lái)完成的。讀者如果想體驗(yàn)這個(gè)功能可以按照本文的方法進(jìn)行操作,有任何使用上的問(wèn)題可以聯(lián)系筆者。
0x8. 相關(guān)鏈接和學(xué)習(xí)資料
https://docs.oneflow.org https://github.com/Oneflow-Inc/oneflow https://github.com/Oneflow-Inc/oneflow_convert https://github.com/BBuf/oneflow-cifar 神經(jīng)網(wǎng)絡(luò)量化入門(mén)--Folding BN ReLU代碼實(shí)現(xiàn) 基于OneFlow實(shí)現(xiàn)量化感知訓(xùn)練
