<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          拿什么拯救我的 4G 顯卡: PyTorch 節(jié)省顯存的策略總結(jié)

          共 4101字,需瀏覽 9分鐘

           ·

          2021-11-10 01:19

          ↑ 點(diǎn)擊藍(lán)字?關(guān)注極市平臺(tái)

          作者丨OpenMMLab
          來(lái)源丨https://zhuanlan.zhihu.com/p/430123077
          編輯丨極市平臺(tái)

          極市導(dǎo)讀

          ?

          隨著深度學(xué)習(xí)快速發(fā)展,同時(shí)伴隨著模型參數(shù)的爆炸式增長(zhǎng),對(duì)顯卡的顯存容量提出了越來(lái)越高的要求,如何在單卡小容量顯卡上面訓(xùn)練模型是一直以來(lái)大家關(guān)心的問(wèn)題。本文結(jié)合 MMCV 開(kāi)源庫(kù)對(duì)一些常用的節(jié)省顯存策略進(jìn)行了簡(jiǎn)要分析。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿

          0 前言

          本文涉及到的 PyTorch 節(jié)省顯存的策略包括:

          • 混合精度訓(xùn)練
          • 大 batch 訓(xùn)練或者稱為梯度累加
          • gradient checkpointing 梯度檢查點(diǎn)

          1 混合精度訓(xùn)練

          混合精度訓(xùn)練全稱為 Automatic Mixed Precision,簡(jiǎn)稱為 AMP,也就是我們常說(shuō)的 FP16。在前系列解讀中已經(jīng)詳細(xì)分析了 AMP 原理、源碼實(shí)現(xiàn)以及 MMCV 中如何一行代碼使用 AMP,具體鏈接見(jiàn):

          OpenMMLab:PyTorch 源碼解讀之 torch.cuda.amp: 自動(dòng)混合精度詳解

          https://zhuanlan.zhihu.com/p/348554267

          OpenMMLab:OpenMMLab 中混合精度訓(xùn)練 AMP 的正確打開(kāi)方式

          https://zhuanlan.zhihu.com/p/375224982

          由于前面兩篇文章已經(jīng)分析的非常詳細(xì)了,本文只簡(jiǎn)要描述原理和具體說(shuō)明用法。

          考慮到訓(xùn)練過(guò)程中梯度幅值大部分是非常小的,故訓(xùn)練默認(rèn)是 FP32 格式,如果能直接以 FP16 格式精度進(jìn)行訓(xùn)練,理論上可以減少一半的內(nèi)存,達(dá)到加速訓(xùn)練和采用更大 batch size 的目的,但是直接以 FP16 訓(xùn)練會(huì)出現(xiàn)溢出問(wèn)題,導(dǎo)致 NAN 或者參數(shù)更新失敗問(wèn)題,而 AMP 的出現(xiàn)就是為了解決這個(gè)問(wèn)題,其核心思想是 混合精度訓(xùn)練+動(dòng)態(tài)損失放大

          1. 維護(hù)一個(gè) FP32 數(shù)值精度模型的副本
          2. 在每個(gè) iteration
          • 拷貝并且轉(zhuǎn)換成 FP16 模型
          • 前向傳播(FP16 的模型參數(shù)),此時(shí) weights, activations 都是 FP16
          • loss 乘 scale factor s
          • 反向傳播(FP16 的模型參數(shù)和參數(shù)梯度), 此時(shí) gradients 也是 FP16
          • 參數(shù)梯度乘 1/s
          • 利用 FP16 的梯度更新 FP32 的模型參數(shù)

          在 MMCV 中使用 AMP 分成兩種情況:

          • 在 OpenMMLab 上游庫(kù)例如 MMDetection 中使用 MMCV 的 AMP
          • 用戶只想簡(jiǎn)單調(diào)用 MMCV 中的 AMP,而不依賴上游庫(kù)

          (1) OpenMMLab 上游庫(kù)如何使用 MMCV 的 AMP

          以 MMDectection 為例,用法非常簡(jiǎn)單,只需要在配置中設(shè)置:

          fp16?=?dict(loss_scale=512.)?#?表示靜態(tài)?scale?

          #?表示動(dòng)態(tài)?scale?
          fp16?=?dict(loss_scale='dynamic')??

          #?通過(guò)字典形式靈活開(kāi)啟動(dòng)態(tài)?scale?
          fp16?=?dict(loss_scale=dict(init_scale=512.,mode='dynamic'))??

          三種不同設(shè)置在大部分模型上性能都非常接近,如果不想設(shè)置 loss_scale,則可以簡(jiǎn)單的采用 loss_scale='dynamic'

          (2) 調(diào)用 MMCV 中的 AMP

          直接調(diào)用 MMCV 中的 AMP,這通常意味著用戶可能在其他庫(kù)或者自己寫的代碼庫(kù)中支持 AMP 功能。需要特別強(qiáng)調(diào)的是 PyTorch 官方僅僅在 1.6 版本及其之后版本中開(kāi)始支持 AMP,而 MMCV 中的 AMP 支持 1.3 及其之后版本。如果你想在 1.3 或者 1.5 中使用 AMP,那么使用 MMCV 是個(gè)非常不錯(cuò)的選擇。

          使用 MMCV 的 AMP 功能,只需要遵循以下幾個(gè)步驟即可:

          1. 將 auto_fp16 裝飾器應(yīng)用到 model 的 forward 函數(shù)上
          2. 設(shè)置模型的 fp16_enabled 為 True 表示開(kāi)啟 AMP 訓(xùn)練,否則不生效
          3. 如果開(kāi)啟了 AMP,需要同時(shí)配置對(duì)應(yīng)的 FP16 優(yōu)化器配置 Fp16OptimizerHook
          4. 在訓(xùn)練的不同時(shí)刻,調(diào)用 Fp16OptimizerHook,如果你同時(shí)使用了 MMCV 中的 Runner 模塊,那么直接將第 3 步的參數(shù)輸入到 Runner 中即可
          5. (可選) 如果對(duì)應(yīng)某些 OP 希望強(qiáng)制運(yùn)行在 FP32 上,則可以在對(duì)應(yīng)位置引入 force_fp32 裝飾器
          #?1?作用到?forward?函數(shù)中
          class?ExampleModule(nn.Module):

          ????@auto_fp16()
          ????def?forward(self,?x,?y):
          ????????return?x,?y
          ????????
          #?2?如果開(kāi)啟?AMP,則需要加入開(kāi)啟標(biāo)志
          model.fp16_enabled?=?True?????

          #?3?配置?Fp16OptimizerHook
          optimizer_config?=?Fp16OptimizerHook(
          ????**cfg.optimizer_config,?**fp16_cfg,?distributed=distributed)

          #?4?傳遞給?runner
          runner.register_training_hooks(cfg.lr_config,?optimizer_config,
          ???????????????????????????????cfg.checkpoint_config,?cfg.log_config,
          ???????????????????????????????cfg.get('momentum_config',?None))???
          ?
          #?5?可選
          class?ExampleModule(nn.Module):

          ????@auto_fp16()
          ????def?forward(self,?x,?y):
          ????????features=self._forward(x,?y)
          ????????loss=self._loss(features,labels)
          ????????return?loss
          ????
          ????def?_forward(self,?x,?y):
          ???????pass
          ????
          ????
          ????@force_fp32(apply_to=('features',))
          ????def?_loss(features,labels)?:
          ????????pass??????????????????????????????

          注意 force_fp32 要生效,依然需要 fp16_enabled 為 True 才生效。

          2 大 Batch 訓(xùn)練(梯度累加)

          大 Batch 訓(xùn)練通常也稱為梯度累加策略,通常 PyTorch 一次迭代訓(xùn)練流程為:

          y_pred?=?model(xx)
          loss?=?loss_fn(y_pred,?y)
          loss.backward()
          optimizer.step()?
          optimizer.zero_grad()

          而梯度累加策略下常見(jiàn)的一次迭代訓(xùn)練流程為:

          y_pred?=?model(xx)
          loss?=?loss_fn(y_pred,?y)

          loss?=?loss?/?cumulative_iters
          loss.backward()

          if?current_iter?%?cumulative_iters==0
          ????optimizer.step()?
          ????optimizer.zero_grad()

          其核心思想就是對(duì)前幾次梯度進(jìn)行累加,然后再統(tǒng)一進(jìn)行參數(shù)更新,從而變相實(shí)現(xiàn)大 batch size 功能。需要注意的是如果模型中包括 BN 等考慮 batch 信息的層,那么性能可能會(huì)有輕微的差距。

          細(xì)節(jié)可以參考:

          https://github.com/open-mmlab/mmcv/pull/1221

          在 MMCV 中已經(jīng)實(shí)現(xiàn)了梯度累加功能,其核心代碼位于 mmcv/runner/hooks/optimizer.py

          GradientCumulativeOptimizerHook 中,和 AMP 實(shí)現(xiàn)一樣是采用 Hook 實(shí)現(xiàn)的。使用方法和 AMP 類似,只需要將第一節(jié)中的 Fp16OptimizerHook 替換為 GradientCumulativeOptimizerHook 或者 GradientCumulativeFp16OptimizerHook 即可。其核心實(shí)現(xiàn)如下所示:

          @HOOKS.register_module()
          class?GradientCumulativeOptimizerHook(OptimizerHook):
          ????def?__init__(self,?cumulative_iters=1,?**kwargs):
          ????
          ????????self.cumulative_iters?=?cumulative_iters
          ????????self.divisible_iters?=?0??#?剩余的可以被?cumulative_iters?整除的訓(xùn)練迭代次數(shù)
          ????????self.remainder_iters?=?0??#?剩余累加次數(shù)
          ????????self.initialized?=?False
          ????????
          ????def?after_train_iter(self,?runner):
          ????????#?只需要運(yùn)行一次即可
          ????????if?not?self.initialized:
          ????????????self._init(runner)
          ????????
          ????????if?runner.iter?????????????loss_factor?=?self.cumulative_iters
          ????????else:
          ????????????loss_factor?=?self.remainder_iters
          ????????????
          ????????loss?=?runner.outputs['loss']
          ????????loss?=?loss?/?loss_factor
          ????????loss.backward()
          ????
          ????????if?(self.every_n_iters(runner,?self.cumulative_iters)
          ????????????????or?self.is_last_iter(runner)):
          ????
          ????????????runner.optimizer.step()
          ????????????runner.optimizer.zero_grad()????


          ????def?_init(self,?runner):
          ?
          ????????residual_iters?=?runner.max_iters?-?runner.iter
          ????
          ????????self.divisible_iters?=?(
          ????????????residual_iters?//?self.cumulative_iters?*?self.cumulative_iters)
          ????????self.remainder_iters?=?residual_iters?-?self.divisible_iters
          ????
          ????????self.initialized?=?True????????????

          需要明白 divisible_iters 和 remainder_iters 的含義:

          (1) 從頭訓(xùn)練

          此時(shí)在開(kāi)始訓(xùn)練時(shí) iter=0,一共迭代 max_iters=102 次,梯度累加次數(shù)是 4,由于 102 無(wú)法被 4 整除,也就是最后的 102-(102 // 4)*4=2 個(gè)迭代是額外需要考慮的,在最后 2 個(gè)訓(xùn)練迭代中 loss_factor 不能除以 4,而是 2,這樣才是最合理的做法。其中 remainder_iters=2,divisible_iters=100,residual_iters=102。

          (2) resume 訓(xùn)練

          假設(shè)在梯度累加的中途退出,然后進(jìn)行 resume 訓(xùn)練,此時(shí) iter 不是 0,由于優(yōu)化器對(duì)象需要重新初始化,為了保證剩余的不能被累加次數(shù)的訓(xùn)練迭代次數(shù)能夠正常計(jì)算,需要重新計(jì)算 residual_iters。

          3 梯度檢查點(diǎn)

          梯度檢查點(diǎn)是一種用訓(xùn)練時(shí)間換取顯存的辦法,其核心原理是在反向傳播時(shí)重新計(jì)算神經(jīng)網(wǎng)絡(luò)的中間激活值而不用在前向時(shí)存儲(chǔ),torch.utils.checkpoint 包中已經(jīng)實(shí)現(xiàn)了對(duì)應(yīng)功能。簡(jiǎn)要實(shí)現(xiàn)過(guò)程是:在前向階段傳遞到 checkpoint 中的 forward 函數(shù)會(huì)以 _torch.no_grad_ 模式運(yùn)行,并且僅僅保存輸入?yún)?shù)和 forward 函數(shù),在反向階段重新計(jì)算其 forward 輸出值。

          具體用法非常簡(jiǎn)單,以 ResNet 的 BasicBlock 為例:

          def?forward(self,?x):
          ????def?_inner_forward(x):
          ????????identity?=?x
          ????????out?=?self.conv1(x)
          ????????out?=?self.norm1(out)
          ????????out?=?self.relu(out)
          ????????out?=?self.conv2(out)
          ????????out?=?self.norm2(out)
          ????????if?self.downsample?is?not?None:
          ????????????identity?=?self.downsample(x)
          ????????out?+=?identity
          ????????return?out
          ????????
          ????#?x.requires_grad?這個(gè)判斷很有必要
          ????if?self.with_cp?and?x.requires_grad:
          ????????out?=?cp.checkpoint(_inner_forward,?x)
          ????else:
          ????????out?=?_inner_forward(x)
          ????out?=?self.relu(out)
          ????return?out

          self.with_cp 為 True,表示要開(kāi)啟梯度檢查點(diǎn)功能。

          checkpoint 在用法上面需要注意以下幾點(diǎn):

          1. 模型的第一層不能用 checkpoint 或者說(shuō) forward 輸入中不能所有輸入的 requires_grad 屬性都是 False,因?yàn)槠鋬?nèi)部實(shí)現(xiàn)是依靠輸入的 requires_grad 屬性來(lái)判斷輸出返回是否需要梯度,而通常模型第一層輸入是 image tensor,其 requires_grad 通常是 False。一旦你第一層用了 checkpoint,那么意味著這個(gè) forward 函數(shù)不會(huì)有任何梯度,也就是說(shuō)不會(huì)進(jìn)行任何參數(shù)更新,沒(méi)有任何使用的必要,具體見(jiàn) https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271。如果第一層用了 checkpoint, PyTorch 會(huì)打印 None of the inputs have requires_grad=True. Gradients will be Non 警告
          2. 對(duì)于 dropout 這種 forward 存在隨機(jī)性的層,需要保證 preserve_rng_state 為 True (默認(rèn)就是 True,所以不用擔(dān)心),一旦標(biāo)志位設(shè)置為 True,在 forward 會(huì)存儲(chǔ) RNG 狀態(tài),然后在反向傳播的時(shí)候讀取該 RNG,保證兩次 forward 輸出一致。如果你確定不需要保存 RNG,則可以設(shè)置 preserve_rng_state 為 False,省掉一些不必要的運(yùn)行邏輯
          3. 其他注意事項(xiàng),可以參考官方文檔 https://pytorch.org/docs/stable/checkpoint.html#

          其核心實(shí)現(xiàn)如下所示:

          class?CheckpointFunction(torch.autograd.Function):

          ????@staticmethod
          ????def?forward(ctx,?run_function,?preserve_rng_state,?*args):
          ????????#?檢查輸入?yún)?shù)是否需要梯度
          ????????check_backward_validity(args)
          ????????#?保存必要的狀態(tài)
          ????????ctx.run_function?=?run_function
          ????????ctx.save_for_backward(*args)
          ????????with?torch.no_grad():
          ????????????#?以?no_grad?模型運(yùn)行一遍
          ????????????outputs?=?run_function(*args)
          ????????return?outputs

          ????@staticmethod
          ????def?backward(ctx,?*args):
          ????????#?讀取輸入?yún)?shù)
          ????????inputs?=?ctx.saved_tensors
          ????????#?Stash?the?surrounding?rng?state,?and?mimic?the?state?that?was
          ????????#?present?at?this?time?during?forward.??Restore?the?surrounding?state
          ????????#?when?we're?done.
          ????????rng_devices?=?[]
          ????????with?torch.random.fork_rng(devices=rng_devices,?enabled=ctx.preserve_rng_state):
          ????????????#?detach?掉當(dāng)前不需要考慮的節(jié)點(diǎn)
          ????????????detached_inputs?=?detach_variable(inputs)
          ????????????#?重新運(yùn)行一遍
          ????????????with?torch.enable_grad():
          ????????????????outputs?=?ctx.run_function(*detached_inputs)
          ???????
          ????????if?isinstance(outputs,?torch.Tensor):
          ????????????outputs?=?(outputs,)
          ????????#?計(jì)算該子圖梯度
          ????????torch.autograd.backward(outputs,?args)
          ????????grads?=?tuple(inp.grad?if?isinstance(inp,?torch.Tensor)?else?inp
          ??????????????????????for?inp?in?detached_inputs)
          ????????return?(None,?None)?+?grads

          4 實(shí)驗(yàn)驗(yàn)證

          為了驗(yàn)證上述策略是否真的能夠省顯存,采用 mmdetection 庫(kù)進(jìn)行驗(yàn)證,基本環(huán)境如下:

          顯卡:?GeForce?GTX?1660
          PyTorch:?1.7.1
          CUDA?Runtime?10.1
          MMCV:?1.3.16
          MMDetection:?2.17.0

          (1) base

          • 數(shù)據(jù)集:pascal voc
          • 算法是 retinanet,對(duì)應(yīng)配置文件為 retinanet_r50_fpn_1x_voc0712.py
          • 為了防止 lr 過(guò)大導(dǎo)致訓(xùn)練出現(xiàn) nan,需要將 lr 設(shè)置為 0.01/8=0.00125
          • bs 設(shè)置為 2

          (2) 混合精度 AMP

          在 base 配置基礎(chǔ)上新增如下配置即可:

          fp16?=?dict(loss_scale=512.)

          (3) 梯度累加

          在 base 配置基礎(chǔ)上替換 optimizer_config 為如下:

          #?累加2次
          optimizer_config?=?dict(type='GradientCumulativeOptimizerHook',?cumulative_iters=2)

          (4) 梯度檢查點(diǎn)

          在 base 配置基礎(chǔ)上在 backbone 部分開(kāi)啟 with_cp 標(biāo)志即可:

          model?=?dict(backbone=dict(with_cp=True),
          ?????????????bbox_head=dict(num_classes=20))

          每個(gè)實(shí)驗(yàn)總共迭代 1300 次,統(tǒng)計(jì)占用顯存、訓(xùn)練總時(shí)長(zhǎng)。

          配置顯存占用(MB)訓(xùn)練時(shí)長(zhǎng)
          base29007 分 45 秒
          混合精度 AMP224336 分
          梯度累加31777 分 32 秒
          梯度檢查點(diǎn)25908 分 37 秒
          1. 對(duì)比 base 和 AMP 可以發(fā)現(xiàn),由于實(shí)驗(yàn)顯卡是不支持 AMP 的,故只能節(jié)省顯存,速度會(huì)特別慢,如果本身顯卡支持 AMP 則可以實(shí)現(xiàn)在節(jié)省顯存的同時(shí)提升訓(xùn)練速度
          2. 對(duì)比 base 和梯度累加可以發(fā)現(xiàn),在相同 bs 情況下,梯度累加 2 次相當(dāng)于 bs 擴(kuò)大一倍,但是顯存增加不多。如果將 bs 縮小一倍,則可以實(shí)現(xiàn)在相同 bs 情況下節(jié)省大概一倍顯存
          3. 對(duì)比 base 和梯度檢查點(diǎn)可以發(fā)現(xiàn),可以節(jié)省一定的顯存,但是訓(xùn)練時(shí)長(zhǎng)會(huì)增加一些

          從上面簡(jiǎn)單實(shí)驗(yàn)可以發(fā)現(xiàn),AMP、梯度累加和梯度檢查點(diǎn)確實(shí)可以在不同程度減少顯存,而且這三個(gè)策略是正交的,可以同時(shí)使用。

          5 總結(jié)

          本文簡(jiǎn)要描述了三個(gè)在 MMCV 中集成且可以通過(guò)配置一行開(kāi)啟的節(jié)省顯存策略,這三個(gè)策略比較常用也比較成熟。隨著模型規(guī)模的不斷增長(zhǎng),也出現(xiàn)了很多新的策略,例如模型參數(shù)壓縮、動(dòng)態(tài)顯存優(yōu)化、使用 CPU 內(nèi)存暫存策略以及分布式情況下 PyTorch 1.10 最新支持的 ZeroRedundancyOptimizer 等等。

          快速鏈接直達(dá) MMCV 算法庫(kù),歡迎大家 Star:

          https://github.com/open-mmlab/mmcv

          如果覺(jué)得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          公眾號(hào)后臺(tái)回復(fù)“CVPR21檢測(cè)”獲取CVPR2021目標(biāo)檢測(cè)論文下載~


          極市干貨
          神經(jīng)網(wǎng)絡(luò):視覺(jué)神經(jīng)網(wǎng)絡(luò)模型優(yōu)秀開(kāi)源工作:timm庫(kù)使用方法和最新代碼解讀
          技術(shù)綜述:綜述:神經(jīng)網(wǎng)絡(luò)中 Normalization 的發(fā)展歷程CNN輕量化模型及其設(shè)計(jì)原則綜述
          算法技巧(trick):8點(diǎn)PyTorch提速技巧匯總圖像分類算法優(yōu)化技巧


          #?CV技術(shù)社群邀請(qǐng)函?#

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart4)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與?10000+來(lái)自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺(jué)開(kāi)發(fā)者互動(dòng)交流~



          覺(jué)得有用麻煩給個(gè)在看啦~??
          瀏覽 33
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  天天网综合 | 骚逼久久| 一区=区三区四区 视频 | 九热精品视频 | 小早川怜子 无码 在线 |