100行代碼使用torch.fx極簡量化教程

極市導讀
本文使用100行代碼,極簡地教大家入門比較標準的量化步驟,從怎么用、用在哪里、哪里不能用等問題都將涵蓋。>>加入極市CV技術交流群,走在計算機視覺的最前沿
網(wǎng)上很多關(guān)于量化的文章,要么就是跑一跑官方殘缺的例子,要么就是過舊的API,早已經(jīng)不潮流。現(xiàn)在比較fashion的方式,是使用 torch.fx來做量化。本文將使用100行代碼,極簡的教你入門比較標(biāo)準(zhǔn)的量化步驟。這些步驟不是簡單的告訴你torch.fx有什么卵用,大家都知道它有什么卵用,只是怎么用,用在哪里,哪里不能用,這些問題需要解答。本文100行代碼,麻雀雖小五臟俱全,不管你量化什么模型,一頓套用就是了,出了問題我背鍋。
很多古老的文章,還在用手動插入stub來做量化節點,這就好比在21世紀還在飛鴿傳書。我們必然會包含以下幾個完整的內容:
fx怎么插入量化節點,不要嚇倒,這就一行代碼; 量化的模型怎么保存權重到本地; 怎么把量化后的權重再load回來; 怎么做calibration,做跟不做區別多大; fx到底有沒有局限性;
以上問題,本文都將囊括。
量化前期知識
此處省略三萬字,具體大家請百度。沒啥好講的。
量化現狀
如果你要問我現在最好的量化工具是什么,我的回答是沒有。真的,不管是 nni,還是 nvidia的 pytorch_quantization,還是 nncf 等等,不是說這些東西不好,而是在做的各位都是垃圾。
這些東西本質上是在做一件事情,至少從量化角度上看是這樣的,但是到最后不具備通用性,當你看到 pytorch_quantization 這個工具保存的模型體積跟float32一樣的時候,就會開始懷疑人生了,這tm是人干的事兒?這就好比普通人想要中杯,他便要說這是大杯。
輪子不好用,那就只能自己造輪子了。只能說,torch.fx yyds。用了都說好,誰用誰知道。
100行代碼
talk is cheap,我們直接上代碼。需要注意的是,torch.fx最好使用最新的stable版本,老版本API或有不同之處,我測試的是 `1.11`。
由于pytorch自帶的是ImageNet系列模型,而我們手頭沒有ImageNet數據,沒有辦法做calibration,所以我們用小一些的CIFAR-10:不需要手動下載,pytorch自己可以處理,但是這就需要我們自己finetune一下。
先把finetune的代碼備好:
這只是用來finetune一個我們準備去量化,并且校準的模型:
import?torch
import?torch.nn?as?nn
import?torch.nn.functional?as?F
import?copy
import?torchvision
from?torchvision?import?transforms
from?torchvision.models.resnet?import?resnet50,?resnet18
from?torch.quantization.quantize_fx?import?prepare_fx,?convert_fx
from?torch.ao.quantization.fx.graph_module?import?ObservedGraphModule
from?torch.quantization?import?(
????get_default_qconfig,
)
from?torch?import?optim
import?os
import?time
def train_model(model, train_loader, test_loader, device, learning_rate=1e-2, num_epochs=20):
    """Fine-tune ``model`` with SGD, printing train/eval metrics after every epoch.

    Args:
        model: Classification network; moved to ``device`` and trained in place.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        test_loader: DataLoader handed to ``evaluate_model`` after each epoch.
        device: ``torch.device`` used for training and per-epoch evaluation.
        learning_rate: SGD learning rate (defaults to the previously hard-coded 1e-2).
        num_epochs: Number of passes over ``train_loader`` (defaults to 20).

    Returns:
        The same ``model`` object, after training.
    """
    # The training configuration was not carefully tuned.
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # SGD seemed to work better than Adam for ResNet18 fine-tuning on CIFAR10.
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5
    )
    num_train_samples = len(train_loader.dataset)
    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss() averages over the batch by default; weight it by the
            # batch size so the epoch figure is a per-sample average even when the
            # last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            # .item() keeps the counter a Python int rather than a device tensor,
            # so train_accuracy below is a plain float.
            running_corrects += torch.sum(preds == labels).item()
        train_loss = running_loss / num_train_samples
        train_accuracy = running_corrects / num_train_samples
        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model=model, test_loader=test_loader, device=device, criterion=criterion
        )
        print(
            "Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                epoch, train_loss, train_accuracy, eval_loss, eval_accuracy
            )
        )
    return model
def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
    """Create the CIFAR10 train and test DataLoaders.

    The datasets are stored under ``./data`` (``download=True`` fetches them when
    missing). Training images get random-crop and horizontal-flip augmentation;
    both splits are converted to tensors and normalized with the same per-channel
    mean/std.

    Args:
        num_workers: Worker processes used by each DataLoader.
        train_batch_size: Batch size of the shuffled training loader.
        eval_batch_size: Batch size of the sequential test loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    def _cifar10(is_train, transform):
        # Both splits share the same root directory and download behaviour.
        return torchvision.datasets.CIFAR10(
            root="data", train=is_train, download=True, transform=transform
        )

    train_set = _cifar10(True, train_transform)
    # The test split doubles as the validation set in this project;
    # do not use the test set for validation in practice.
    test_set = _cifar10(False, test_transform)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=train_batch_size,
        sampler=torch.utils.data.RandomSampler(train_set),
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=eval_batch_size,
        sampler=torch.utils.data.SequentialSampler(test_set),
        num_workers=num_workers,
    )
    return train_loader, test_loader
然后訓練一波模型:
if __name__ == "__main__":
    train_loader, test_loader = prepare_dataloader()
    # First fine-tune the ImageNet-pretrained ResNet18 on CIFAR10: ImageNet is not
    # available here, so CIFAR10 is used for fine-tuning, calibration and testing.
    model = resnet18(pretrained=True)
    # Replace the classifier head with a 10-way layer for the CIFAR10 classes.
    model.fc = nn.Linear(512, 10)
    if os.path.exists("r18_row.pth"):
        # Reuse the fine-tuned float weights from a previous run.
        model.load_state_dict(torch.load("r18_row.pth", map_location="cpu"))
    else:
        # Fall back to CPU so the script still runs on machines without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_model(model, train_loader, test_loader, device)
        print("train finished.")
        torch.save(model.state_dict(), "r18_row.pth")
接下來就是核心代碼:
def quant_fx(model):
    """Quantize ``model`` to int8 with torch.fx (no calibration) and save both versions.

    The model is put into eval mode, a deep copy is prepared with the default
    ``fbgemm`` qconfig applied globally, and the copy is converted to an int8
    GraphModule. The float weights are saved to ``r18.pth`` and the int8 weights
    to ``r18_quant.pth``.

    Args:
        model: Float model to quantize; switched to eval mode in place, otherwise
            left untouched (quantization works on a deep copy).

    Returns:
        The converted int8 GraphModule.
    """
    model.eval()
    qconfig = get_default_qconfig("fbgemm")
    # The empty-string key applies ``qconfig`` to the whole model; per-type
    # overrides can be added under an "object_type" entry.
    qconfig_dict = {"": qconfig}
    model_to_quantize = copy.deepcopy(model)
    # NOTE(review): this call matches the torch 1.11 API the article was tested on;
    # newer releases of prepare_fx also require an ``example_inputs`` argument —
    # confirm against the installed version.
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    print("prepared model: ", prepared_model)
    quantized_model = convert_fx(prepared_model)
    print("quantized model: ", quantized_model)
    torch.save(model.state_dict(), "r18.pth")
    torch.save(quantized_model.state_dict(), "r18_quant.pth")
    # Return the int8 model so callers can use it without reloading the checkpoint.
    return quantized_model
懂了嗎?很快啊,啪一下,一個int8的量化模型就生成了。
沒錯,其實都不用100行,15行就夠了。torch.fx 就是這么的牛逼!
我們做一個evaluation,來驗證一下,在不校準的情況下,精度如何:
def evaluate_model(model, test_loader, device=torch.device("cpu"), criterion=None):
    """Evaluate a classifier and return its average loss and top-1 accuracy.

    Args:
        model: Network whose output has class scores along dim 1 (top-1 is taken
            with ``torch.max(outputs, 1)``); moved to ``device`` and set to eval mode.
        test_loader: DataLoader yielding ``(inputs, labels)`` batches; its
            ``dataset`` length is used as the sample count.
        device: Device the model and batches are moved to (default CPU).
        criterion: Optional loss module; when None the reported loss is 0.0.

    Returns:
        ``(eval_loss, eval_accuracy)`` as Python floats, both averaged per sample.
    """
    t0 = time.time()
    model.eval()
    model.to(device)
    running_loss = 0.0
    running_corrects = 0
    # Evaluation needs no gradients; disabling autograd avoids building graphs for
    # every batch, which matters when this runs after each training epoch.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            if criterion is not None:
                # Weight the batch-mean loss by batch size for a per-sample average.
                running_loss += criterion(outputs, labels).item() * inputs.size(0)
            # .item() keeps the count a Python int, so the accuracy is a plain float.
            running_corrects += torch.sum(preds == labels).item()
    num_samples = len(test_loader.dataset)
    eval_loss = running_loss / num_samples
    eval_accuracy = running_corrects / num_samples
    t1 = time.time()
    print(f"eval loss: {eval_loss}, eval acc: {eval_accuracy}, cost: {t1 - t0}")
    return eval_loss, eval_accuracy
這是evaluation的結果(第一行為float模型,第二行為未校準的int8模型):
eval?loss:?0.0,?eval?acc:?0.8476999998092651,?cost:?2.8914074897766113
eval?loss:?0.0,?eval?acc:?0.15240000188350677,?cost:?1.240293264389038
可以看到,精度下降嚴重。此時需要進行一下校準,我直接放校準函數:
def calib_quant_model(model, calib_dataloader):
    """Calibrate a prepared (observer-instrumented) FX model by running data through it.

    Only the inputs are used; the observers record activation statistics, which a
    later ``convert_fx`` turns into quantization parameters.

    Args:
        model: Output of ``prepare_fx`` (must contain observer modules).
        calib_dataloader: Iterable of batches. Each batch is either an input tensor
            or a list/tuple whose first element is the input (e.g. ``(inputs, labels)``);
            any remaining elements such as labels are ignored.

    Raises:
        TypeError: If ``model`` contains no observers, e.g. a float model that was
            never prepared or one that has already been converted.
    """
    from torch.ao.quantization.observer import ObserverBase

    # An explicit check instead of ``assert`` (stripped under -O). Looking for
    # observer modules works whether prepare_fx returns an ObservedGraphModule
    # or a plain GraphModule.
    if not any(isinstance(m, ObserverBase) for m in model.modules()):
        raise TypeError("model must be a prepared fx model (output of prepare_fx).")
    model.eval()
    # no_grad rather than inference_mode: observers update their state in place,
    # and inference tensors cannot be modified in place outside inference mode.
    with torch.no_grad():
        for batch in calib_dataloader:
            # Labels are not needed for calibration; accept bare-tensor batches too.
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(inputs)
    print("calib done.")
that's all. 就這么簡單。
如果你有其他非分類模型,也可以直接把dataloader丟進來。請注意,這里的標籤并沒有用到,只需要統計數據的分布即可。
非常簡單。
最后我們再次eval一下:
def quant_calib_and_eval(model):
    """Compare the float model with uncalibrated and calibrated int8 FX models on CPU.

    1. Rebuild the int8 graph with prepare_fx/convert_fx and load the weights saved
       in ``r18_quant.pth``. Compare its output with the float model on one random
       input, then evaluate both models on the CIFAR10 test set.
    2. Prepare a fresh copy of the float model, calibrate it on the test loader,
       convert it to int8, save it to ``r18_quant_calib.pth`` and evaluate it.

    Args:
        model: Trained float model; moved to CPU and set to eval mode in place.
    """
    # test only on CPU
    model.to(torch.device("cpu"))
    model.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}

    # 1. Uncalibrated int8 model restored from its saved state dict.
    model_prepared = prepare_fx(copy.deepcopy(model), qconfig_dict)
    model_int8 = convert_fx(model_prepared)
    model_int8.load_state_dict(torch.load("r18_quant.pth"))
    model_int8.eval()

    a = torch.randn([1, 3, 224, 224])
    with torch.no_grad():
        o1 = model(a)
        o2 = model_int8(a)
    diff = torch.allclose(o1, o2, 1e-4)
    print(diff)
    print(o1.shape, o2.shape)
    print(o1, o2)
    # Top-1 class index predicted by each model for the random input.
    print(o1.argmax(dim=1), o2.argmax(dim=1))

    _, test_loader = prepare_dataloader()
    evaluate_model(model, test_loader)
    evaluate_model(model_int8, test_loader)

    # 2. Calibrate a fresh prepared copy, then convert it. Each prepare_fx call gets
    # its own deep copy so no state is shared with the step-1 model.
    # NOTE(review): calibration uses test_loader, the same data used for evaluation,
    # which can flatter the calibrated accuracy; a held-out subset would be cleaner.
    model_prepared = prepare_fx(copy.deepcopy(model), qconfig_dict)
    calib_quant_model(model_prepared, test_loader)
    model_int8 = convert_fx(model_prepared)
    torch.save(model_int8.state_dict(), "r18_quant_calib.pth")
    evaluate_model(model_int8, test_loader)
得到結果:
eval?loss:?0.0,?eval?acc:?0.8476999998092651,?cost:?2.8914074897766113
eval?loss:?0.0,?eval?acc:?0.15240000188350677,?cost:?1.240293264389038
calib?done.
eval?loss:?0.0,?eval?acc:?0.8442999720573425,?cost:?1.2966759204864502
精度瞬間恢復了,而且int8模型的推理耗時比float模型減少了一半以上。
總結
ok,我們用幾十行代碼就完成了這個量化過程,并且使用校準恢復了精度。由此可見fx的強大之處。
拋出一個問題,歡迎留言區解答:
torch.fx量化的模型,如果export到onnx并使用其他前推引擎推理,會遇到什么問題?
公眾號后臺回復“CVPR 2022”獲取論文打包合集下載~

# CV技術社群邀請函 #

備注:姓名-學校/公司-研究方向-城市(如:小極-北大-目標檢測-深圳)
即可申請加入極市目標檢測/圖像分割/工業檢測/人臉/醫學影像/3D/SLAM/自動駕駛/超分辨率/姿態估計/ReID/GAN/圖像增強/OCR/視頻理解等技術交流群
每月大咖直播分享、真實項目需求對接、求職內推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發者互動交流~

