<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          實操教程|PyTorch實現(xiàn)斷點繼續(xù)訓(xùn)練

          共 9266字,需瀏覽 19分鐘

           ·

          2021-05-15 21:38

          ↑ 點擊藍(lán)字 關(guān)注極市平臺

          作者丨HUST小菜雞@知乎(已授權(quán))
          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/133250753
          編輯丨極市平臺

          極市導(dǎo)讀

           

          本文整理了pytorch實現(xiàn)斷電繼續(xù)訓(xùn)練時需要注意的要點,附有代碼詳解。

          最近在嘗試用CIFAR10訓(xùn)練分類問題的時候,由于數(shù)據(jù)集體量比較大,訓(xùn)練的過程中時間比較長,有時候想給停下來,但是停下來了之后就得重新訓(xùn)練,之前師兄讓我們學(xué)習(xí)斷點繼續(xù)訓(xùn)練及繼續(xù)訓(xùn)練的時候注意epoch的改變等,今天上午給大致整理了一下,不全面僅供參考

          Epoch:  9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018  sEpoch:  9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216  sEpoch:  9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398  sEpoch:  9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921  sEpoch:  9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974  sEpoch:  9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034  sEpoch:  9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831  s絕望!!!!!訓(xùn)練到了一定次數(shù)發(fā)現(xiàn)訓(xùn)練次數(shù)少了,或者中途斷了又得重新開始訓(xùn)練

          一、模型的保存與加載

          PyTorch中的保存(序列化,從內(nèi)存到硬盤)與反序列化(加載,從硬盤到內(nèi)存)

          torch.save主要參數(shù):obj:對象 、f:輸出路徑

          torch.load 主要參數(shù) :f:文件路徑 、map_location:指定存放位置、 cpu or gpu

          模型的保存的兩種方法:

          1、保存整個Module

          torch.save(net, path)

          2、保存模型參數(shù)

          state_dict = net.state_dict()torch.save(state_dict , path)

          二、模型的訓(xùn)練過程中保存

          checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }

          將網(wǎng)絡(luò)訓(xùn)練過程中的網(wǎng)絡(luò)的權(quán)重,優(yōu)化器的權(quán)重保存,以及epoch 保存,便于繼續(xù)訓(xùn)練恢復(fù)

          在訓(xùn)練過程中,可以根據(jù)自己的需要,每多少代,或者多少epoch保存一次網(wǎng)絡(luò)參數(shù),便于恢復(fù),提高程序的魯棒性。

          checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }    if not os.path.isdir("./models/checkpoint"):        os.mkdir("./models/checkpoint")    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
          通過上述的過程可以在訓(xùn)練過程自動在指定位置創(chuàng)建文件夾,并保存斷點文件

          三、模型的斷點繼續(xù)訓(xùn)練

          if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 斷點路徑    checkpoint = torch.load(path_checkpoint)  # 加載斷點
          model.load_state_dict(checkpoint['net']) # 加載模型可學(xué)習(xí)參數(shù)
          optimizer.load_state_dict(checkpoint['optimizer']) # 加載優(yōu)化器參數(shù) start_epoch = checkpoint['epoch'] # 設(shè)置開始的epoch

          指出這里的是否繼續(xù)訓(xùn)練,及訓(xùn)練的checkpoint的文件位置等可以通過argparse從命令行直接讀取,也可以通過log文件直接加載,也可以自己在代碼中進(jìn)行修改。關(guān)于argparse參照我的這一篇文章:

          HUST小菜雞:argparse 命令行選項、參數(shù)和子命令解析器

          https://zhuanlan.zhihu.com/p/133285373

          四、重點在于epoch的恢復(fù)

          start_epoch = -1

          if RESUME: path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 斷點路徑 checkpoint = torch.load(path_checkpoint) # 加載斷點
          model.load_state_dict(checkpoint['net']) # 加載模型可學(xué)習(xí)參數(shù)
          optimizer.load_state_dict(checkpoint['optimizer']) # 加載優(yōu)化器參數(shù) start_epoch = checkpoint['epoch'] # 設(shè)置開始的epoch


          for epoch in range(start_epoch + 1 ,EPOCH): # print('EPOCH:',epoch) for step, (b_img,b_label) in enumerate(train_loader): train_output = model(b_img) loss = loss_func(train_output,b_label) # losses.append(loss) optimizer.zero_grad() loss.backward() optimizer.step()

          通過定義start_epoch變量來保證繼續(xù)訓(xùn)練的時候epoch不會變化

          斷點繼續(xù)訓(xùn)練

          一、初始化隨機數(shù)種子

          import torchimport randomimport numpy as np
          def set_random_seed(seed = 10,deterministic=False,benchmark=False): random.seed(seed) np.random(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) if deterministic: torch.backends.cudnn.deterministic = True if benchmark: torch.backends.cudnn.benchmark = True

          關(guān)于torch.backends.cudnn.deterministic和torch.backends.cudnn.benchmark詳見

          Pytorch學(xué)習(xí)0.01:cudnn.benchmark= True的設(shè)置

          https://www.cnblogs.com/captain-dl/p/11938864.html

          pytorch---之cudnn.benchmark和cudnn.deterministic_人工智能_zxyhhjs2017的博客

          https://blog.csdn.net/zxyhhjs2017/article/details/91348108

          benchmark用在輸入尺寸一致,可以加速訓(xùn)練,deterministic用來固定內(nèi)部隨機性

          二、多步長SGD繼續(xù)訓(xùn)練

          在簡單的任務(wù)中,我們使用固定步長(也就是學(xué)習(xí)率LR)進(jìn)行訓(xùn)練,但是如果學(xué)習(xí)率lr設(shè)置的過小的話,則會導(dǎo)致很難收斂,如果學(xué)習(xí)率很大的時候,就會導(dǎo)致在最小值附近,總會錯過最小值,loss產(chǎn)生震蕩,無法收斂。所以這要求我們要對于不同的訓(xùn)練階段使用不同的學(xué)習(xí)率,一方面可以加快訓(xùn)練的過程,另一方面可以加快網(wǎng)絡(luò)收斂。

          采用多步長 torch.optim.lr_scheduler的多種步長設(shè)置方式來實現(xiàn)步長的控制,lr_scheduler的各種使用推薦參考如下教程:

          【轉(zhuǎn)載】 Pytorch中的學(xué)習(xí)率調(diào)整lr_scheduler,ReduceLROnPlateau

          https://www.cnblogs.com/devilmaycry812839668/p/10630302.html

          所以我們在保存網(wǎng)絡(luò)中的訓(xùn)練的參數(shù)的過程中,還需要保存lr_scheduler的state_dict,然后斷點繼續(xù)訓(xùn)練的時候恢復(fù)

          #這里我設(shè)置了不同的epoch對應(yīng)不同的學(xué)習(xí)率衰減,在10->20->30,學(xué)習(xí)率依次衰減為原來的0.1,即一個數(shù)量級lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
          for epoch in range(start_epoch+1,80): optimizer.zero_grad() optimizer.step() lr_schedule.step()
          if epoch %10 ==0: print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
          lr的變化過程如下:
          epoch: 10learning rate: 0.1epoch: 20learning rate: 0.010000000000000002epoch: 30learning rate: 0.0010000000000000002epoch: 40learning rate: 0.00010000000000000003epoch: 50learning rate: 1.0000000000000004e-05epoch: 60learning rate: 1.0000000000000004e-06epoch: 70learning rate: 1.0000000000000004e-06

          我們在保存的時候,也需要對lr_scheduler的state_dict進(jìn)行保存,斷點繼續(xù)訓(xùn)練的時候也需要恢復(fù)lr_scheduler

          #加載恢復(fù)if RESUME:    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 斷點路徑    checkpoint = torch.load(path_checkpoint)  # 加載斷點
          model.load_state_dict(checkpoint['net']) # 加載模型可學(xué)習(xí)參數(shù)
          optimizer.load_state_dict(checkpoint['optimizer']) # 加載優(yōu)化器參數(shù) start_epoch = checkpoint['epoch'] # 設(shè)置開始的epoch lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加載lr_scheduler


          #保存for epoch in range(start_epoch+1,80):
          optimizer.zero_grad()
          optimizer.step() lr_schedule.step()

          if epoch %10 ==0: print('epoch:',epoch) print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr']) checkpoint = { "net": model.state_dict(), 'optimizer': optimizer.state_dict(), "epoch": epoch, 'lr_schedule': lr_schedule.state_dict() } if not os.path.isdir("./model_parameter/test"): os.mkdir("./model_parameter/test")        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

          三、保存最好的結(jié)果

          每一個epoch中的每個step會有不同的結(jié)果,可以保存每一代最好的結(jié)果,用于后續(xù)的訓(xùn)練

          第一次實驗代碼

          RESUME = True
          EPOCH = 40LR = 0.0005

          model = cifar10_cnn.CIFAR10_CNN()
          print(model)optimizer = torch.optim.Adam(model.parameters(),lr=LR)loss_func = nn.CrossEntropyLoss()
          start_epoch = -1

          if RESUME: path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 斷點路徑 checkpoint = torch.load(path_checkpoint) # 加載斷點
          model.load_state_dict(checkpoint['net']) # 加載模型可學(xué)習(xí)參數(shù)
          optimizer.load_state_dict(checkpoint['optimizer']) # 加載優(yōu)化器參數(shù) start_epoch = checkpoint['epoch'] # 設(shè)置開始的epoch


          for epoch in range(start_epoch + 1 ,EPOCH): # print('EPOCH:',epoch) for step, (b_img,b_label) in enumerate(train_loader): train_output = model(b_img) loss = loss_func(train_output,b_label) # losses.append(loss) optimizer.zero_grad() loss.backward() optimizer.step()
          if step % 100 == 0: now = time.time() print('EPOCH:',epoch,'| step :',step,'| loss :',loss.data.numpy(),'| train time: %.4f'%(now-start_time))
          checkpoint = { "net": model.state_dict(), 'optimizer':optimizer.state_dict(), "epoch": epoch } if not os.path.isdir("./models/checkpoint"): os.mkdir("./models/checkpoint")    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))

          更新實驗代碼

          optimizer = torch.optim.SGD(model.parameters(),lr=0.1)lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)start_epoch = 9# print(schedule)

          if RESUME: path_checkpoint = "./model_parameter/test/ckpt_best_50.pth" # 斷點路徑 checkpoint = torch.load(path_checkpoint) # 加載斷點
          model.load_state_dict(checkpoint['net']) # 加載模型可學(xué)習(xí)參數(shù)
          optimizer.load_state_dict(checkpoint['optimizer']) # 加載優(yōu)化器參數(shù) start_epoch = checkpoint['epoch'] # 設(shè)置開始的epoch lr_schedule.load_state_dict(checkpoint['lr_schedule'])
          for epoch in range(start_epoch+1,80):
          optimizer.zero_grad()
          optimizer.step() lr_schedule.step()

          if epoch %10 ==0: print('epoch:',epoch) print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr']) checkpoint = { "net": model.state_dict(), 'optimizer': optimizer.state_dict(), "epoch": epoch, 'lr_schedule': lr_schedule.state_dict() } if not os.path.isdir("./model_parameter/test"): os.mkdir("./model_parameter/test")        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

          如果覺得有用,就請分享到朋友圈吧!

          △點擊卡片關(guān)注極市平臺,獲取最新CV干貨

          公眾號后臺回復(fù)“pytorch”獲取Pytorch 官方書籍英文版電子版


          極市干貨

          YOLO教程:YOLO算法最全綜述:從YOLOv1到Y(jié)OLOv5YOLO系列(從V1到V5)模型解讀!
          實操教程:PyTorch自定義CUDA算子教程與運行時間分析詳解PyTorch中的ModuleList和Sequential詳細(xì)記錄solov2的ncnn實現(xiàn)和優(yōu)化
          算法技巧(trick):深度神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練中的 tricks(原理與代碼匯總)神經(jīng)網(wǎng)絡(luò)訓(xùn)練trick總結(jié)深度學(xué)習(xí)調(diào)參tricks總結(jié)
          最新CV競賽:2021 高通人工智能應(yīng)用創(chuàng)新大賽CVPR 2021 | Short-video Face Parsing Challenge3D人體目標(biāo)檢測與行為分析競賽開賽,獎池7萬+,數(shù)據(jù)集達(dá)16671張!


          CV技術(shù)社群邀請函 #

          △長按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)


          即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~



          覺得有用麻煩給個在看啦~  
          瀏覽 49
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  无码视屏| 操碰97人人操 | 亚洲一区二区三区人妻 | 亚洲国产激情视频 | 天天干天 |