
          A Minimal Quantization Tutorial with torch.fx in 100 Lines of Code

          2022-04-17 17:09


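          Author: Jin Tian (金天) @ Zhihu (republished with permission)
          Source: https://zhuanlan.zhihu.com/p/498286238
          Editor: Jishi Platform (极市平台)

          Introduction

          In 100 lines of code, this article walks you through a fairly standard quantization workflow, covering how to use it, where it applies, and where it does not.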

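          Most quantization articles online either run the incomplete official examples or rely on APIs that are long out of date. The fashionable way to quantize today is torch.fx. This article uses 100 lines of code to walk you through a fairly standard quantization workflow. The point is not to tell you what torch.fx is good for, since everyone already knows that. The open questions are how to use it, where it applies, and where it does not. The 100 lines are small but complete: whatever model you want to quantize, apply the same recipe, and if something breaks, blame me.

          Many older articles still insert stubs by hand to create quantization nodes, which is like sending messages by carrier pigeon in the 21st century. We will cover the following in full:

          • How fx inserts quantization nodes (don't be intimidated, it is a single line of code);
          • How to save the weights of a quantized model to disk;
          • How to load the quantized weights back;
          • How to run calibration, and how much difference it makes;
          • Whether fx has any limitations.

          This article covers all of these questions.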

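          Quantization background

          Thirty thousand words omitted here; please search for the details yourself. There is not much to say.

          The state of quantization

          If you ask me what the best quantization tool is right now, my answer is that there isn't one. Really. Whether it is nni, NVIDIA's pytorch_quantization, nncf, and so on: it is not that these tools are bad, it is just that, as the meme goes, everyone in the room is garbage.

          At bottom these tools all do the same thing, at least from a quantization point of view, but in the end none of them is general-purpose. When you see that a model saved by pytorch_quantization is the same size as the float32 one, you start to question your life choices. Who does that? It is like ordering a medium and being told it is a large.

          When the existing wheels don't roll, you have to build your own. All I can say is that torch.fx is the GOAT. Everyone who has used it likes it; try it and you will see.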

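          100 lines of code

          Talk is cheap, so let's go straight to the code. Note that torch.fx is best used with the latest stable release, because older versions may have a different API. I tested with `1.11`.

          The ImageNet models that ship with PyTorch leave us with no data to calibrate on, so we use the smaller CIFAR-10 instead. Nothing has to be downloaded by hand, since torchvision fetches the dataset for you, but it does mean we have to fine-tune the model ourselves.

          First, get the fine-tuning code ready. Its only purpose is to fine-tune the model that we will then quantize and calibrate: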

          import torch
          import torch.nn as nn
          import torch.nn.functional as F
          import copy
          import torchvision
          from torchvision import transforms
          from torchvision.models.resnet import resnet50, resnet18
          from torch.quantization.quantize_fx import prepare_fx, convert_fx
          from torch.ao.quantization.fx.graph_module import ObservedGraphModule
          from torch.quantization import (
              get_default_qconfig,
          )
          from torch import optim
          import os
          import time


          def train_model(model, train_loader, test_loader, device):
              # The training configurations were not carefully selected.
              learning_rate = 1e-2
              num_epochs = 20
              criterion = nn.CrossEntropyLoss()
              model.to(device)
              # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
              optimizer = optim.SGD(
                  model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5
              )
              # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
              for epoch in range(num_epochs):
                  # Training
                  model.train()

                  running_loss = 0
                  running_corrects = 0

                  for inputs, labels in train_loader:
                      inputs = inputs.to(device)
                      labels = labels.to(device)

                      # zero the parameter gradients
                      optimizer.zero_grad()

                      # forward + backward + optimize
                      outputs = model(inputs)
                      _, preds = torch.max(outputs, 1)
                      loss = criterion(outputs, labels)
                      loss.backward()
                      optimizer.step()

                      # statistics
                      running_loss += loss.item() * inputs.size(0)
                      running_corrects += torch.sum(preds == labels.data)

                  train_loss = running_loss / len(train_loader.dataset)
                  train_accuracy = running_corrects / len(train_loader.dataset)

                  # Evaluation
                  model.eval()
                  eval_loss, eval_accuracy = evaluate_model(
                      model=model, test_loader=test_loader, device=device, criterion=criterion
                  )
                  print(
                      "Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                          epoch, train_loss, train_accuracy, eval_loss, eval_accuracy
                      )
                  )
              return model

          def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
              train_transform = transforms.Compose(
                  [
                      transforms.RandomCrop(32, padding=4),
                      transforms.RandomHorizontalFlip(),
                      transforms.ToTensor(),
                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                  ]
              )
              test_transform = transforms.Compose(
                  [
                      transforms.ToTensor(),
                      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                  ]
              )
              train_set = torchvision.datasets.CIFAR10(
                  root="data", train=True, download=True, transform=train_transform
              )
              # We will use test set for validation and test in this project.
              # Do not use test set for validation in practice!
              test_set = torchvision.datasets.CIFAR10(
                  root="data", train=False, download=True, transform=test_transform
              )
              train_sampler = torch.utils.data.RandomSampler(train_set)
              test_sampler = torch.utils.data.SequentialSampler(test_set)

              train_loader = torch.utils.data.DataLoader(
                  dataset=train_set,
                  batch_size=train_batch_size,
                  sampler=train_sampler,
                  num_workers=num_workers,
              )
              test_loader = torch.utils.data.DataLoader(
                  dataset=test_set,
                  batch_size=eval_batch_size,
                  sampler=test_sampler,
                  num_workers=num_workers,
              )
              return train_loader, test_loader

          Then train the model:

          if __name__ == "__main__":
              train_loader, test_loader = prepare_dataloader()

              # First fine-tune the model on CIFAR-10; we have no ImageNet data, so CIFAR-10 stands in.
              model = resnet18(pretrained=True)
              # Replace the 1000-class ImageNet head with a 10-class head for CIFAR-10.
              model.fc = nn.Linear(512, 10)
              if os.path.exists("r18_row.pth"):
                  # Reuse the fine-tuned weights if they are already on disk.
                  model.load_state_dict(torch.load("r18_row.pth", map_location="cpu"))
              else:
                  # Train on the GPU when one is available, otherwise fall back to the CPU.
                  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                  train_model(model, train_loader, test_loader, device)
                  print("train finished.")
                  torch.save(model.state_dict(), "r18_row.pth")

          Next comes the core code:

          def quant_fx(model):
              model.eval()
              # Default static-quantization config for the fbgemm (x86) backend.
              qconfig = get_default_qconfig("fbgemm")
              # The empty key "" applies this qconfig to the whole model.
              qconfig_dict = {
                  "": qconfig,
                  # 'object_type': []
              }
              model_to_quantize = copy.deepcopy(model)
              # This single call traces the model and inserts the observers (PyTorch 1.11 signature).
              prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
              print("prepared model: ", prepared_model)

              # No calibration data has been run yet, so the int8 model below is uncalibrated.
              quantized_model = convert_fx(prepared_model)
              print("quantized model: ", quantized_model)
              torch.save(model.state_dict(), "r18.pth")
              torch.save(quantized_model.state_dict(), "r18_quant.pth")

          Got it? That fast: bang, and an int8 quantized model is generated.

          That's right, it doesn't even take 100 lines; 15 are enough. torch.fx really is that good!
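For readers who skipped the background section, here is a rough sketch of what "int8" means for a single value. It shows the usual affine mapping between a float and an 8-bit integer code. This is not the code torch.fx runs internally; the function names and the `scale` and `zero_point` values are hypothetical, chosen only for illustration.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to an unsigned 8-bit integer code (affine quantization)."""
    q = round(x / scale) + zero_point
    # Values outside the representable range saturate at the ends.
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Map an integer code back to an approximate float."""
    return (q - zero_point) * scale


# Hypothetical parameters, chosen only for illustration.
scale, zero_point = 0.05, 128

for x in (-1.0, 0.0, 0.26, 10.0):
    q = quantize(x, scale, zero_point)
    print(x, "->", q, "->", dequantize(q, scale, zero_point))
```

Values inside the representable range come back with an error of at most half a step (`scale / 2`), while values outside it, such as 10.0 here, are clipped. Picking a `scale` and `zero_point` that fit the real data is the job of calibration.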

          Let's run an evaluation to see what the accuracy looks like without calibration:

          def evaluate_model(model, test_loader, device=torch.device("cpu"), criterion=None):
              t0 = time.time()
              model.eval()
              model.to(device)
              running_loss = 0
              running_corrects = 0
              for inputs, labels in test_loader:

                  inputs = inputs.to(device)
                  labels = labels.to(device)
                  outputs = model(inputs)
                  _, preds = torch.max(outputs, 1)

                  if criterion is not None:
                      loss = criterion(outputs, labels).item()
                  else:
                      loss = 0

                  # statistics
                  running_loss += loss * inputs.size(0)
                  running_corrects += torch.sum(preds == labels.data)

              eval_loss = running_loss / len(test_loader.dataset)
              eval_accuracy = running_corrects / len(test_loader.dataset)
              t1 = time.time()
              print(f"eval loss: {eval_loss}, eval acc: {eval_accuracy}, cost: {t1 - t0}")
              return eval_loss, eval_accuracy

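          Here are the evaluation results (float model first, then the uncalibrated int8 model; the loss prints as 0.0 because no criterion is passed):

          eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
          eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038

          As you can see, accuracy drops badly, from about 84.8% to about 15.2%. The model needs calibration at this point, so here is the calibration function: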

          def calib_quant_model(model, calib_dataloader):
              # Calibration must run on the output of prepare_fx, before convert_fx.
              assert isinstance(
                  model, ObservedGraphModule
              ), "model must be a prepared fx ObservedGraphModule."
              model.eval()
              with torch.inference_mode():
                  # Forward passes only: the observers record activation statistics.
                  for inputs, labels in calib_dataloader:
                      model(inputs)
              print("calib done.")

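          That's all. It is that simple.

          If you have some other, non-classification model, you can pass its dataloader in directly. Note that the labels are never used here: calibration only needs to collect statistics on the distribution of the data.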
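To make "collecting statistics on the distribution" concrete, here is a toy observer in plain Python. It is only a sketch of the min/max idea: `MinMaxObserverSketch` and the calibration data are hypothetical, and the observers PyTorch actually inserts are more elaborate (the default activation observer is histogram-based, for example).

```python
class MinMaxObserverSketch:
    """Toy stand-in for an observer: remembers the smallest and largest value seen."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        # Only the inputs are inspected; labels never enter the picture.
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self):
        # Keep 0.0 inside the range so that it maps to an exact integer code.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        if scale == 0.0:
            # Nothing (or only zeros) observed: fall back to a placeholder scale.
            scale = 1.0
        zero_point = self.qmin - round(lo / scale)
        return scale, max(self.qmin, min(self.qmax, zero_point))


# Hypothetical calibration data: (inputs, label) pairs, as a dataloader would yield.
calib_batches = [([-1.0, 0.5, 2.0], 3), ([0.1, 4.1, -0.3], 7)]

observer = MinMaxObserverSketch()
for inputs, _label in calib_batches:
    observer.observe(inputs)

scale, zero_point = observer.calculate_qparams()
print(scale, zero_point)
```

The observed range here is [-1.0, 4.1], so the 256 integer codes are spread over exactly that interval. An observer that never sees data has nothing to go on and can only return a placeholder.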

          Finally, let's evaluate once more:

          def quant_calib_and_eval(model):
              # test only on CPU
              model.to(torch.device("cpu"))
              model.eval()

              qconfig = get_default_qconfig("fbgemm")
              qconfig_dict = {
                  "": qconfig,
                  # 'object_type': []
              }

              # Rebuild the quantized graph, then load the int8 weights saved by quant_fx().
              model2 = copy.deepcopy(model)
              model_prepared = prepare_fx(model2, qconfig_dict)
              model_int8 = convert_fx(model_prepared)
              model_int8.load_state_dict(torch.load("r18_quant.pth"))
              model_int8.eval()

              # Sanity check on a random input: compare the float and int8 outputs.
              a = torch.randn([1, 3, 224, 224])
              o1 = model(a)
              o2 = model_int8(a)

              diff = torch.allclose(o1, o2, 1e-4)
              print(diff)
              print(o1.shape, o2.shape)
              print(o1, o2)
              # The original called get_output_from_logits(), which is not defined in
              # this article; print the predicted class indices instead.
              print(o1.argmax(dim=1), o2.argmax(dim=1))

              train_loader, test_loader = prepare_dataloader()
              evaluate_model(model, test_loader)
              evaluate_model(model_int8, test_loader)

              # Uncalibrated int8 model, saved for reference.
              # Note: this overwrites the "r18.pth" file written by quant_fx().
              model2 = copy.deepcopy(model)
              model_prepared = prepare_fx(model2, qconfig_dict)
              model_int8 = convert_fx(model_prepared)
              torch.save(model_int8.state_dict(), "r18.pth")
              model_int8.eval()

              # Calibrated int8 model: prepare a fresh copy, run calibration, then convert.
              model_prepared = prepare_fx(copy.deepcopy(model), qconfig_dict)
              calib_quant_model(model_prepared, test_loader)
              model_int8 = convert_fx(model_prepared)
              torch.save(model_int8.state_dict(), "r18_quant_calib.pth")
              evaluate_model(model_int8, test_loader)

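          The results:

          eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
          eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038
          calib done.
          eval loss: 0.0, eval acc: 0.8442999720573425, cost: 1.2966759204864502

          Accuracy recovers immediately, to about 84.4% against 84.8% for the float model, and the evaluation time drops from about 2.9 s to about 1.3 s, which is less than half.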
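One plausible reading of these numbers, sketched in plain Python: an observer that never saw any data has nothing to derive a range from, so the converted model is left with placeholder quantization parameters. The sketch assumes those placeholders are `scale=1.0` and `zero_point=0`. The activation values and the helper names are hypothetical, and the real model has many such tensors, so treat this as an illustration of the mechanism rather than a reproduction of the experiment.

```python
def fake_quant(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize to an 8-bit code and immediately map back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale


def mean_abs_error(values, scale, zero_point):
    """Average round-trip error over a list of floats."""
    return sum(abs(v - fake_quant(v, scale, zero_point)) for v in values) / len(values)


# Hypothetical activations, roughly in the range of normalized image data.
activations = [-1.8, -0.7, -0.2, 0.0, 0.3, 0.9, 1.6, 2.1]

# Placeholder parameters (assumed: what an observer that saw no data falls back to).
err_uncalibrated = mean_abs_error(activations, scale=1.0, zero_point=0)

# Parameters derived from the observed range of the data.
lo, hi = min(activations), max(activations)
scale = (hi - lo) / 255
zero_point = -round(lo / scale)
err_calibrated = mean_abs_error(activations, scale, zero_point)

print(err_uncalibrated, err_calibrated)
```

With the placeholder parameters, every negative value is clipped to zero and everything else is rounded to a whole number, so most of the information is lost. With a range taken from the data, the round-trip error stays below one quantization step.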

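          Summary

          OK, a few dozen lines of code were enough to complete the quantization process, and calibration restored the accuracy, which shows how powerful fx is.

          One open question; answers are welcome in the comments:

          • How can a model quantized with torch.fx be exported to ONNX and run with another inference engine?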

