
          Hands-on Tutorial | A PyTorch Training Code Template for Deep Learning (Personal Habits)

          6,111 characters in total; about a 13-minute read


          2021-08-20 14:18


          Author: wfnian @ 知乎 (republished with permission)
          Source: https://zhuanlan.zhihu.com/p/396666255
          Editor: 極市平臺

          極市 Editor's Note

          This article lays out a fairly intuitive template covering everything from defining hyperparameters and the network model to the training, validation, and testing steps.

          The outline is as follows:

          1. Import packages and set the random seed
          2. Define hyperparameters as a class
          3. Define your model
          4. Define an early-stopping class (this step can be skipped)
          5. Define your own Dataset and DataLoader
          6. Instantiate the model and set up the loss, optimizer, etc.
          7. Train and adjust the learning rate
          8. Plot
          9. Predict

          1. Import packages and set the random seed

          import random

          import numpy as np
          import pandas as pd
          import matplotlib.pyplot as plt
          import torch
          import torch.nn as nn
          from torch.utils.data import DataLoader, Dataset
          from sklearn.model_selection import train_test_split

          seed = 42
          random.seed(seed)
          np.random.seed(seed)
          torch.manual_seed(seed)
          torch.cuda.manual_seed_all(seed)  # also seed every GPU, if any

          2. Define hyperparameters as a class

          class argparse():
              pass

          args = argparse()
          args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]
          args.hidden_size, args.input_size = [40, 30]
          args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), ]
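The same hyperparameters can also be held in a standard-library `dataclass`, which avoids shadowing the real `argparse` module. A minimal sketch, with the field names and values taken from the template above:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class Args:
    """Holds the same hyperparameters as the argparse-style class above."""
    epochs: int = 30
    learning_rate: float = 0.001
    patience: int = 4
    hidden_size: int = 40
    input_size: int = 30
    device: torch.device = field(
        default_factory=lambda: torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))


args = Args()
print(args.epochs, args.learning_rate)
```

A `dataclass` also gives you a readable `repr` of the full configuration for free, which is handy when logging experiments.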

          3. Define your model

          class Your_model(nn.Module):
              def __init__(self):
                  super(Your_model, self).__init__()
                  pass

              def forward(self, x):
                  pass
                  return x

          4. Define an early-stopping class (this step can be skipped)

          class EarlyStopping():
              def __init__(self, patience=7, verbose=False, delta=0):
                  self.patience = patience
                  self.verbose = verbose
                  self.counter = 0
                  self.best_score = None
                  self.early_stop = False
                  self.val_loss_min = np.inf
                  self.delta = delta

              def __call__(self, val_loss, model, path):
                  print("val_loss={}".format(val_loss))
                  score = -val_loss
                  if self.best_score is None:
                      self.best_score = score
                      self.save_checkpoint(val_loss, model, path)
                  elif score < self.best_score + self.delta:
                      self.counter += 1
                      print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                      if self.counter >= self.patience:
                          self.early_stop = True
                  else:
                      self.best_score = score
                      self.save_checkpoint(val_loss, model, path)
                      self.counter = 0

              def save_checkpoint(self, val_loss, model, path):
                  if self.verbose:
                      print(
                          f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
                  torch.save(model.state_dict(), path + '/' + 'model_checkpoint.pth')
                  self.val_loss_min = val_loss

          5. Define your own Dataset and DataLoader

          class Dataset_name(Dataset):
              def __init__(self, flag='train'):
                  assert flag in ['train', 'test', 'valid']
                  self.flag = flag
                  self.__load_data__()

              def __getitem__(self, index):
                  pass

              def __len__(self):
                  pass

              def __load_data__(self):
                  # Load and split your data here (e.g. from csv files).
                  pass
                  print("train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
                        .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))

          train_dataset = Dataset_name(flag='train')
          train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
          valid_dataset = Dataset_name(flag='valid')
          valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=False)  # no need to shuffle validation data
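To make the skeleton above concrete, here is a minimal hypothetical dataset: the class name, sample counts, and feature sizes are all illustrative (random regression data standing in for whatever `__load_data__` would actually load):

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class RandomRegressionDataset(Dataset):
    """Toy stand-in for Dataset_name: 30 features -> 1 target."""

    def __init__(self, flag='train'):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        rng = np.random.default_rng(0)
        n = 80 if flag == 'train' else 20
        self.X = rng.standard_normal((n, 30)).astype(np.float32)
        self.Y = rng.standard_normal((n, 1)).astype(np.float32)

    def __getitem__(self, index):
        # Returning numpy arrays is fine: the default collate
        # function converts them to torch tensors per batch.
        return self.X[index], self.Y[index]

    def __len__(self):
        return len(self.X)


train_ds = RandomRegressionDataset('train')
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
xb, yb = next(iter(train_dl))
print(xb.shape, yb.shape)  # torch.Size([16, 30]) torch.Size([16, 1])
```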

          6. Instantiate the model and set up the loss, optimizer, etc.

          model = Your_model().to(args.device)
          criterion = torch.nn.MSELoss()
          optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

          train_loss = []
          valid_loss = []
          train_epochs_loss = []
          valid_epochs_loss = []

          early_stopping = EarlyStopping(patience=args.patience,verbose=True)

          7. Train and adjust the learning rate

          for epoch in range(args.epochs):
              model.train()
              train_epoch_loss = []
              for idx, (data_x, data_y) in enumerate(train_dataloader, 0):
                  data_x = data_x.to(torch.float32).to(args.device)
                  data_y = data_y.to(torch.float32).to(args.device)
                  outputs = model(data_x)
                  optimizer.zero_grad()
                  loss = criterion(outputs, data_y)
                  loss.backward()
                  optimizer.step()
                  train_epoch_loss.append(loss.item())
                  train_loss.append(loss.item())
                  if idx % (len(train_dataloader) // 2) == 0:
                      print("epoch={}/{},{}/{}of train, loss={}".format(
                          epoch, args.epochs, idx, len(train_dataloader), loss.item()))
              train_epochs_loss.append(np.average(train_epoch_loss))

              # =====================valid============================
              model.eval()
              valid_epoch_loss = []
              with torch.no_grad():  # no gradients needed during validation
                  for idx, (data_x, data_y) in enumerate(valid_dataloader, 0):
                      data_x = data_x.to(torch.float32).to(args.device)
                      data_y = data_y.to(torch.float32).to(args.device)
                      outputs = model(data_x)
                      loss = criterion(outputs, data_y)
                      valid_epoch_loss.append(loss.item())
                      valid_loss.append(loss.item())
              valid_epochs_loss.append(np.average(valid_epoch_loss))
              # ==================early stopping======================
              early_stopping(valid_epochs_loss[-1], model=model, path=r'c:\your_model_to_save')
              if early_stopping.early_stop:
                  print("Early stopping")
                  break
              # ====================adjust lr========================
              lr_adjust = {
                  2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
                  10: 5e-7, 15: 1e-7, 20: 5e-8
              }
              if epoch in lr_adjust.keys():
                  lr = lr_adjust[epoch]
                  for param_group in optimizer.param_groups:
                      param_group['lr'] = lr
                  print('Updating learning rate to {}'.format(lr))
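The hand-written `lr_adjust` dict above can also be replaced with one of PyTorch's built-in schedulers. A minimal sketch using `ReduceLROnPlateau`, which lowers the learning rate when the validation loss stops improving (the stand-in `Linear` model and the constant validation loss here are hypothetical, just to exercise the API):

```python
import torch

# Hypothetical stand-in for model/optimizer from the template.
model = torch.nn.Linear(30, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Halve the lr after the validation loss fails to improve
# for more than `patience` consecutive epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2)

for epoch in range(5):
    val_loss = 1.0  # plug in valid_epochs_loss[-1] in the real loop
    scheduler.step(val_loss)

print(optimizer.param_groups[0]['lr'])
```

This keeps the schedule tied to the observed validation metric instead of hard-coded epoch numbers.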

          8. Plot

          plt.figure(figsize=(12,4))
          plt.subplot(121)
          plt.plot(train_loss[:])
          plt.title("train_loss")
          plt.subplot(122)
          plt.plot(train_epochs_loss[1:],'-o',label="train_loss")
          plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")
          plt.title("epochs_loss")
          plt.legend()
          plt.show()

          9. Predict

          # You can define a DataLoader for the prediction set here, or simply
          # reshape your prediction data and feed it through with batch_size=1.
          model.eval()
          with torch.no_grad():
              predict = model(data)
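For a single un-batched sample, the reshaping mentioned in the comment usually just means adding a batch dimension with `unsqueeze`. A minimal sketch (the `Linear` stand-in for `Your_model` and the feature size are hypothetical):

```python
import torch

# Hypothetical stand-in for the trained Your_model.
model = torch.nn.Linear(30, 1)
model.eval()

sample = torch.randn(30)     # one un-batched sample of 30 features
batch = sample.unsqueeze(0)  # add a batch dimension -> shape (1, 30)
with torch.no_grad():        # no gradients needed at inference time
    pred = model(batch)
print(pred.shape)  # torch.Size([1, 1])
```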
