<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          淺談混合精度訓(xùn)練imagenet

          共 7814字,需瀏覽 16分鐘

           ·

          2021-07-14 09:39

          ↑ 點(diǎn)擊藍(lán)字 關(guān)注極市平臺(tái)

          作者丨jmc
          來源丨GiantPandaCV
          編輯丨極市平臺(tái)

          極市導(dǎo)讀

           

          本文作者通過自己實(shí)驗(yàn)得出了一些關(guān)于使用混合精度訓(xùn)練的結(jié)論,附有相關(guān)代碼。 >>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

          零、序

          本文沒有任何的原理和解讀,只有一些實(shí)驗(yàn)的結(jié)論,對(duì)于想使用混合精度訓(xùn)練的同學(xué)可以直接參考結(jié)論白嫖,或者直接拿github上的代碼(文末放送)。

          一、引言

          以前做項(xiàng)目的時(shí)候出現(xiàn)過一個(gè)問題,使用FP16訓(xùn)練的時(shí)候,只要BatchSize增加(LR也對(duì)應(yīng)增加)的時(shí)候訓(xùn)練,一段時(shí)間后就會(huì)出現(xiàn)loss異常,同時(shí)val對(duì)應(yīng)的明顯降低,甚至直接NAN的情況出現(xiàn),圖示如下:

          • 這種是比較正常的損失和acc的情況,因?yàn)轫?xiàng)目的數(shù)據(jù)非常長(zhǎng)尾。
          • 這種就是不正常的訓(xùn)練情況, val的損失不下降反而上升,acc不升反而降。
          • 還有一種情況,就是訓(xùn)練十幾個(gè)epoch以后,loss上升到非常大,acc為nan,后續(xù)訓(xùn)練都是nan,tensorboard顯示有點(diǎn)問題,只好看ckpt的結(jié)果了。

          由于以前每周都會(huì)跑很多模型,問題也不是經(jīng)常出現(xiàn),所以以為是偶然出現(xiàn),不過最近恰好最近要做一些transformer的實(shí)驗(yàn),在跑imagenet baseline(R50)的時(shí)候,出現(xiàn)了類似的問題,由于FP16訓(xùn)練的時(shí)候,出現(xiàn)了溢出的情況所導(dǎo)致的。簡(jiǎn)單的做了一些實(shí)驗(yàn),整理如下。

          二、混合精度訓(xùn)練

          混合精度訓(xùn)練,以pytorch 1.6版本為基礎(chǔ)的話,大致是有3種方案,依次介紹如下:

          • 模型和輸入輸出直接half,如果有BN,那么BN計(jì)算需要轉(zhuǎn)為FP32精度,我上面的問題就是基于此來訓(xùn)練的,代碼如下:
           if args.FP16:        model = model.half()        for bn in get_bn_modules(model):            bn.float()    ...
          for data in dataloader: if args.FP16: image, label = data[0].half() output = model(image) losses = criterion(output, label)
          optimizer.zero_grad() losses.backward() optimizer.step()


          • 使用NVIDIA的Apex庫(kù),這里有O1,O2,O3三種訓(xùn)練模式,代碼如下:
          try:    from apex import amp     from apex.parallel import convert_syncbn_model    from apex.parallel import DistributedDataParallel as DDP except Exception as e:    print("amp have not been import !!!")
          if args.apex: model = convert_syncbn_model(model)
          if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level=args.mode) model = DDP(model, delay_allreduce=True)
          ...
          for data in dataloader: image, label = data[0], data[1] batch_output = model(image) losses = criterion(batch_output, label)
          optimizer.zero_grad() if args.apex: with amp.scale_loss(losses, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step()

          • pytorch1.6版本以后把a(bǔ)pex并入到了自身的庫(kù)里面,代碼如下:
          from torch.cuda.amp import autocast as autocastfrom torch.nn.parallel import DistributedDataParallel as DataParallel
          model = DataParallel(model, device_ids=[args.local_rank], find_unused_parameters=True)
          if args.amp: scaler = torch.cuda.amp.GradScaler()
          for data in dataloader: image, label = data[0], data[1] if args.amp: with autocast(): batch_output = model(image) losses = criterion(batch_output, label)
          if args.amp: scaler.scale(losses).backward() scaler.step(optimizer) scaler.update()

          三、pytorch不同的分布式訓(xùn)練速度對(duì)比

          • 環(huán)境配置如下:
            CPU Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
            GPU 8XV100 32G
            cuda 10.2
            pytorch 1.7.1

          pytorch分布式有兩種不同的啟動(dòng)方法,一種是單機(jī)多卡啟動(dòng),一種是多機(jī)多卡啟動(dòng)。ps: DataParallel不是分布式訓(xùn)練。

          • 多機(jī)啟動(dòng)
          #!/bin/bashcd$FOLDER;CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train_lanuch.py \...

          • 單機(jī)啟動(dòng)
          #!/bin/bashcd$FOLDER;CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore test.py \--dist-url 'tcp://127.0.0.1:9966' \--dist-backend 'nccl' \--multiprocessing-distributed=1 \--world-size=1 \--rank=0 \...

          詳細(xì)代碼看文末的github鏈接。

          實(shí)驗(yàn)一、num workers對(duì)于速度的影響
          我的服務(wù)器是48個(gè)物理核心,96個(gè)邏輯核心,所以48的情況下,效果最好,不過增加和減少對(duì)于模型的影響不大,基本上按照CPU的物理核心個(gè)數(shù)來設(shè)置就可以。

          numworkers

          實(shí)驗(yàn)二、OMP和MKL對(duì)于速度的影響
          OMP和MKL對(duì)于多機(jī)模式下的速度有輕微的影響,如果不想每個(gè)都去試,直接經(jīng)驗(yàn)設(shè)置為1最合理。FP16大幅度提升模型的訓(xùn)練速度,可以節(jié)省2/5的時(shí)間。

          omp&amp;amp;amp;amp;amp;mkl

          實(shí)驗(yàn)三、單機(jī)和多機(jī)啟動(dòng)速度差異
          單機(jī)和多機(jī)啟動(dòng),對(duì)于模型的前向基本是沒有影響的, 主要的差異是在loader開始執(zhí)行的速度,多機(jī)比起單機(jī)啟動(dòng)要快2倍-5倍左右的時(shí)間。

          四、不同混合精度訓(xùn)練方法對(duì)比

          實(shí)驗(yàn)均在ResNet50和imagenet下面進(jìn)行的,LR隨著BS變換和線性增長(zhǎng),公式如下

          • 實(shí)驗(yàn)結(jié)果

            模型FP16+BNFP32實(shí)驗(yàn)記錄

          fp16

          很明顯可以發(fā)現(xiàn),單存使用FP16進(jìn)行訓(xùn)練,但是沒有l(wèi)oss縮放的情況下,當(dāng)BS和LR都增大的時(shí)候,訓(xùn)練是無法進(jìn)行的,直接原因就是因?yàn)長(zhǎng)R過大,導(dǎo)致模型更新的時(shí)候數(shù)值范圍溢出了,同理loss也就直接為NAN了,我嘗試把LR調(diào)小后發(fā)現(xiàn),模型是可以正常訓(xùn)練的,只是精度略有所下降。

          Apex混合精度實(shí)驗(yàn)記錄

          apex

          Apex O3模式下的訓(xùn)練情況和上面FP16的結(jié)論是一致的,存FP16訓(xùn)練,不管是否有l(wèi)oss縮放都會(huì)導(dǎo)致訓(xùn)練NaN,O2和O1是沒有任何問題的,O2的精度略低于O1的精度。

          AMP實(shí)驗(yàn)記錄

          amp

          AMP自動(dòng)把模型需要用FP32計(jì)算的層或者op直接轉(zhuǎn)換,不需要顯著性指定。精度比apex高,同時(shí)訓(xùn)練時(shí)間更少。

          2-bit訓(xùn)練,ACTNN
          簡(jiǎn)單的嘗試了一下2bit訓(xùn)練,1k的bs是可以跑的,不過速度相比FP16跑,慢了太多,基本可以pass掉了。

          附上一個(gè)比較合理的收斂情況

          train

          val

          五、結(jié)論

          • 如果使用分布式訓(xùn)練,使用pytorch 多機(jī)模式啟動(dòng),收益比較高,如果你不希望所有卡都用的話,那么建議使用單機(jī)多卡的模式。
          • 如果使用FP16方式計(jì)算的話,那么無腦pytorch amp就可以了,速度和精度都比較有優(yōu)勢(shì),代碼量也不多。
          • 我的增強(qiáng)只用了隨機(jī)裁剪,水平翻轉(zhuǎn),跑了90個(gè)epoch,原版的resnet50是跑了120個(gè)epoch,還有color jitter,imagenet上one crop的結(jié)果0.76012,和我的結(jié)果相差無幾,所以分類任務(wù)(基本上最后是求概率的問題,圖像,視頻都work,已經(jīng)驗(yàn)證過)上FP16很明顯完全可以替代FP32。我跑了一個(gè)120epoch的版本,結(jié)果是0.767,吊打原版本結(jié)果了QAQ。
          • 如果跑小的bs,第一種FP16的方法完全是ok的,對(duì)于大的bs來說,使用AMP會(huì)使得模型的收斂更加穩(wěn)定。
          • 如果顯存足夠大,用大的BS會(huì)獲得更好的訓(xùn)練收益。
          • 代碼自行取用。

          FlyEgle/imageclassificationgithub.com


          如果覺得有用,就請(qǐng)分享到朋友圈吧!

          △點(diǎn)擊卡片關(guān)注極市平臺(tái),獲取最新CV干貨

          公眾號(hào)后臺(tái)回復(fù)“CVPR21檢測(cè)”獲取CVPR2021目標(biāo)檢測(cè)論文下載~

          極市干貨


          YOLO教程:一文讀懂YOLO V5 與 YOLO V4大盤點(diǎn)|YOLO 系目標(biāo)檢測(cè)算法總覽全面解析YOLO V4網(wǎng)絡(luò)結(jié)構(gòu)
          實(shí)操教程:PyTorch vs LibTorch:網(wǎng)絡(luò)推理速度誰更快?只用兩行代碼,我讓Transformer推理加速了50倍PyTorch AutoGrad C++層實(shí)現(xiàn)
          算法技巧(trick):深度學(xué)習(xí)訓(xùn)練tricks總結(jié)(有實(shí)驗(yàn)支撐)深度強(qiáng)化學(xué)習(xí)調(diào)參Tricks合集長(zhǎng)尾識(shí)別中的Tricks匯總(AAAI2021
          最新CV競(jìng)賽:2021 高通人工智能應(yīng)用創(chuàng)新大賽CVPR 2021 | Short-video Face Parsing Challenge3D人體目標(biāo)檢測(cè)與行為分析競(jìng)賽開賽,獎(jiǎng)池7萬+,數(shù)據(jù)集達(dá)16671張!


          CV技術(shù)社群邀請(qǐng)函 #

          △長(zhǎng)按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測(cè)-深圳)


          即可申請(qǐng)加入極市目標(biāo)檢測(cè)/圖像分割/工業(yè)檢測(cè)/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實(shí)項(xiàng)目需求對(duì)接、求職內(nèi)推、算法競(jìng)賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~



          覺得有用麻煩給個(gè)在看啦~  
          瀏覽 52
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  丁香五月六月婷婷 | 日本黄色A片 | 韩国不卡无码 | 亚洲人成人网站色 | 日韩综合网站 |