
Hands-on Tutorial | Using PyTorch Lightning


          2021-05-20 11:03


Author | Caliber@Zhihu (reposted with permission)
Source | https://zhuanlan.zhihu.com/p/370185203
Editor | CVMart (極市平臺)

CVMart summary

           

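PyTorch Lightning lets you build deep-learning code very concisely, but most people never need many of its more complex features, and it can feel a little inflexible in use. The author shares some lessons learned from using it, with a link to the code.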

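PyTorch Lightning (pl for short below) makes deep-learning code very concise, but most people never use many of its more complex features, and pl sometimes wraps things so deeply that it feels somewhat inflexible. Usually, once your model is built, most of the functionality lives in a class called Trainer. The features that are tedious to write yourself but usually needed are: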

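1. Saving checkpoints
2. Printing log information
3. Resuming training, i.e. reloading a run so that it continues from the last epoch
4. Recording the training process (usually with TensorBoard)
5. Setting the seed, so that a training run can be reproduced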

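Fortunately, pl already implements all of these.

Because many explanations in the docs are not very clear and there are not many examples online, here are a few notes from my own experience.

First, setting the global seed: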

              
```python
from pytorch_lightning import seed_everything

# Set seed
seed = 42
seed_everything(seed)
```

You only need to import the seed_everything function shown above. It should be equivalent to the following function:

              
```python
import random

import numpy as np
import torch


def seed_all(seed_value):
    random.seed(seed_value)  # Python
    np.random.seed(seed_value)  # cpu vars
    torch.manual_seed(seed_value)  # cpu vars

    if torch.cuda.is_available():
        print('CUDA is available')
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # gpu vars
        torch.backends.cudnn.deterministic = True  # needed
        torch.backends.cudnn.benchmark = False


seed = 42
seed_all(seed)
```

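In my tests, however, pl's seed_everything seemed to be somewhat more thorough. Note that the two cudnn lines are not seeding as such: they switch cuDNN to deterministic kernels, which in pl is requested through the Trainer's deterministic=True argument rather than by seed_everything.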
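
A quick way to check that seeding works is to seed, draw a few random numbers, re-seed with the same value and draw again: the two sets of draws should be identical. The minimal sketch below does this for Python's random module, NumPy and PyTorch on the CPU; the helper name seeded_draws is hypothetical and only used for illustration.

```python
import random

import numpy as np
import torch


def seeded_draws(seed):
    """Seed Python, NumPy and PyTorch (CPU), then draw a few numbers from each.

    seeded_draws is a hypothetical helper used only to illustrate the check.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return (
        random.random(),
        np.random.rand(3).tolist(),
        torch.rand(3).tolist(),
    )


first = seeded_draws(42)
second = seeded_draws(42)  # same seed as `first`
other = seeded_draws(7)    # a different seed

print(first == second)  # True: re-seeding reproduces the draws
print(first == other)
```

On a GPU the same check applies to CUDA tensors, but identical seeds alone do not rule out non-deterministic cuDNN kernels; that is what the cudnn flags in seed_all are for.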

Below, a concrete example illustrates some of the usage:

First, install and import the necessary packages and download the dataset:

              
```python
# The Trainer arguments used below (gpus, resume_from_checkpoint) belong to the
# pl 1.x API of the time; newer releases changed them
!pip install pytorch-lightning
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q hymenoptera_data.zip
!rm hymenoptera_data.zip

import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
```

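In the code above, the lines starting with ! are shell (terminal) commands: in Google Colab, a Linux command must be prefixed with ! to run from a notebook cell.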
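
Outside a notebook, for example in a plain Python script, the ! lines are not valid Python. A minimal standard-library sketch of the same three steps (download, unzip, delete the archive) follows; fetch_and_extract is a hypothetical helper, the URL in the usage comment is the one from the commands above, and the pip install step still has to be run from a shell.

```python
import os
import urllib.request
import zipfile


def fetch_and_extract(url, dest_dir="."):
    """Download a zip archive, extract it into dest_dir, then delete the archive.

    Hypothetical helper mirroring the !wget, !unzip -q and !rm commands.
    """
    archive_path = os.path.join(dest_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, archive_path)  # like `wget`
    with zipfile.ZipFile(archive_path) as zf:      # like `unzip -q`
        zf.extractall(dest_dir)
    os.remove(archive_path)                        # like `rm`


# Usage:
# fetch_and_extract("https://download.pytorch.org/tutorial/hymenoptera_data.zip")
```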

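If you use Google Colab, note that it runs on a virtual machine whose files are not kept after the session ends, so if you need to save anything it is worth mounting your own Google Drive with the following code: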

              
```python
from google.colab import drive
# Mount at the absolute path /content/drive so that the chdir below finds it
drive.mount('/content/drive')

import os
os.chdir("/content/drive/My Drive/")
```

First, define the LightningModule and the main function as follows:

```python
class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()

        self.params = hparams
        self.data_dir = self.params.data_dir
        self.num_classes = self.params.num_classes

        ########## define the model ##########
        arch = torchvision.models.resnet18(pretrained=True)  # ImageNet-pretrained weights
        num_ftrs = arch.fc.in_features

        modules = list(arch.children())[:-1]  # ResNet18 has 10 children; drop the final fc
        self.backbone = torch.nn.Sequential(*modules)  # output: [bs, 512, 1, 1]

        # New classification head. The original ended with torch.nn.Softmax(dim=1);
        # it is dropped because F.cross_entropy applies log-softmax itself and expects
        # raw logits. torch.max over logits still gives the same predictions.
        self.final = torch.nn.Sequential(
            torch.nn.Linear(num_ftrs, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, self.num_classes))

    def forward(self, x):
        x = self.backbone(x)
        x = x.reshape(x.size(0), -1)  # flatten [bs, 512, 1, 1] -> [bs, 512]
        x = self.final(x)
        return x

    def configure_optimizers(self):
        # REQUIRED
        # Two parameter groups: pretrained backbone at lr=1e-3, new head at lr=1e-2
        optimizer = torch.optim.SGD([
            {'params': self.backbone.parameters()},
            {'params': self.final.parameters(), 'lr': 1e-2}
        ], lr=1e-3, momentum=0.9)

        # Multiply every group's learning rate by 0.1 every 7 epochs
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        return [optimizer], [exp_lr_scheduler]

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, dim=1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)

        self.log('train_loss', loss)
        self.log('train_acc', acc)

        return {'loss': loss, 'train_acc': acc}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, 1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)

        self.log('val_loss', loss)
        self.log('val_acc', acc)

        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, 1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)

        return {'test_loss': loss, 'test_acc': acc}

    def train_dataloader(self):
        # REQUIRED
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

        return train_loader

    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        # Validation data does not need shuffling (the original used shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False, num_workers=4)

        return val_loader

    def test_dataloader(self):
        # hymenoptera_data has no separate test split, so the 'val' images are reused;
        # shuffle=True makes next(iter(...)) return a random batch of 8 images
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=8, shuffle=True, num_workers=4)

        return val_loader


def main(hparams):
    model = CoolSystem(hparams)

    # pl 1.x-era arguments as in the original; newer releases replace gpus=1
    # with accelerator='gpu', devices=1
    trainer = pl.Trainer(
        max_epochs=hparams.epochs,
        gpus=1,
        accelerator='dp'
    )

    trainer.fit(model)
```
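
configure_optimizers gives the pretrained backbone and the new head different learning rates through SGD parameter groups, and StepLR multiplies every group's rate by gamma=0.1 every 7 scheduler steps (by default pl steps the scheduler once per epoch). The minimal sketch below reproduces that schedule with two small stand-in nn.Linear modules (assumptions, in place of the ResNet backbone and head) so the numbers can be checked without any data.

```python
import torch

# Hypothetical stand-ins for self.backbone and self.final
backbone = torch.nn.Linear(4, 4)
head = torch.nn.Linear(4, 2)

optimizer = torch.optim.SGD([
    {'params': backbone.parameters()},          # inherits the default lr=1e-3
    {'params': head.parameters(), 'lr': 1e-2},  # overrides lr for the head
], lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs_per_epoch = []
for epoch in range(10):
    # Record the learning rate each group uses during this epoch
    lrs_per_epoch.append([group['lr'] for group in optimizer.param_groups])
    optimizer.step()  # a real loop would compute a loss and call backward() first
    scheduler.step()  # one scheduler step per epoch, as pl does by default

for epoch, lrs in enumerate(lrs_per_epoch):
    print(epoch, lrs)
```

With the first run's 5 epochs the decay never fires; it would first take effect at epoch 7 of the resumed run (max_epochs=10), assuming the scheduler state is restored from the checkpoint along with the model.
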

Here is the part that runs it:
              
```python
from argparse import Namespace

args = {
    'num_classes': 2,
    'epochs': 5,
    'data_dir': "/content/hymenoptera_data",
}

hyperparams = Namespace(**args)

if __name__ == '__main__':
    main(hyperparams)
```
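
The Namespace is built by hand here because a notebook has no command line. In a standalone script (run as, say, python train.py --epochs 10, where train.py is a hypothetical file name) the same object can come from argparse. The sketch below uses flag names that mirror the dict keys above; the helper parse_hparams is hypothetical.

```python
from argparse import ArgumentParser


def parse_hparams(argv=None):
    """Build the same hyperparameter Namespace from command-line flags.

    argv=None reads sys.argv; passing a list is handy for testing.
    """
    parser = ArgumentParser(description="Fine-tune ResNet18 on hymenoptera_data")
    parser.add_argument('--num_classes', type=int, default=2)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--data_dir', type=str, default="/content/hymenoptera_data")
    return parser.parse_args(argv)


# Equivalent to running the script with `--epochs 10`
hyperparams = parse_hparams(['--epochs', '10'])
print(hyperparams)
```

The result exposes the same attributes (hyperparams.epochs, hyperparams.data_dir, ...) that CoolSystem and main read, so nothing else has to change.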

To resume training from a checkpoint, you can do the following:

              
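```python
# resume training

RESUME = True

if RESUME:
    resume_checkpoint_dir = './lightning_logs/version_0/checkpoints/'
    # os.listdir order is arbitrary: this assumes the folder holds a single checkpoint
    checkpoint_path = os.listdir(resume_checkpoint_dir)[0]
    resume_checkpoint_path = resume_checkpoint_dir + checkpoint_path

    args = {
        'num_classes': 2,
        'data_dir': "/content/hymenoptera_data"}

    hparams = Namespace(**args)

    model = CoolSystem(hparams)

    # Training continues from the saved epoch up to max_epochs=10.
    # resume_from_checkpoint is the pl 1.x argument; newer releases use
    # trainer.fit(model, ckpt_path=...) instead.
    trainer = pl.Trainer(gpus=1, max_epochs=10, accelerator='dp',
                         resume_from_checkpoint=resume_checkpoint_path)

    trainer.fit(model)
```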
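
Because os.listdir returns entries in arbitrary order, os.listdir(...)[0] is only reliable when the folder contains exactly one checkpoint. A slightly more robust sketch, assuming the lightning_logs/version_N/checkpoints/*.ckpt layout used above, picks the most recently modified .ckpt file; latest_checkpoint is a hypothetical helper.

```python
import glob
import os


def latest_checkpoint(checkpoint_dir):
    """Return the newest .ckpt file in checkpoint_dir, or None if there is none."""
    candidates = glob.glob(os.path.join(checkpoint_dir, '*.ckpt'))
    if not candidates:
        return None
    # Pick by modification time: file names that embed an epoch number
    # (e.g. "epoch=9" vs "epoch=10") do not sort numerically as strings
    return max(candidates, key=os.path.getmtime)


# Usage with the directory from the resume example:
# resume_checkpoint_path = latest_checkpoint('./lightning_logs/version_0/checkpoints/')
```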

To load the model from a checkpoint and use it for prediction, do the following:

              
```python
import matplotlib.pyplot as plt
import numpy as np


# function to show an image: undo the Normalize transform, then plot
def imshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()


classes = ['ants', 'bees']

checkpoint_dir = 'lightning_logs/version_1/checkpoints/'
checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]

# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model_infer = CoolSystem(hparams)
model_infer.load_state_dict(checkpoint['state_dict'])
model_infer.eval()  # inference mode: BatchNorm uses its running statistics

try_dataloader = model_infer.test_dataloader()

inputs, labels = next(iter(try_dataloader))

# print images and ground truth
imshow(torchvision.utils.make_grid(inputs))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))

# inference (no gradients needed)
with torch.no_grad():
    outputs = model_infer(inputs)

_, preds = torch.max(outputs, dim=1)
# print(preds)
print(torch.sum(preds == labels.data) / (labels.shape[0] * 1.0))

print('Predicted: ', ' '.join('%5s' % classes[preds[j]] for j in range(8)))
```
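
A Lightning checkpoint is a dictionary whose 'state_dict' entry holds the module's weights, which is why load_state_dict(checkpoint['state_dict']) works above. The self-contained sketch below imitates that round trip with a tiny stand-in model (an assumption: it is not CoolSystem, and it saves only a 'state_dict' key, whereas a real Lightning checkpoint stores more), and shows the usual inference pattern: eval() so BatchNorm uses its stored statistics, and torch.no_grad() so no autograd graph is built.

```python
import os
import tempfile

import torch

torch.manual_seed(0)  # make the random inputs repeatable


def make_model():
    # Hypothetical stand-in model; BatchNorm behaves differently in train and eval mode
    return torch.nn.Sequential(
        torch.nn.Linear(4, 8),
        torch.nn.BatchNorm1d(8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    )


trained = make_model()
trained(torch.randn(16, 4))  # a train-mode forward pass updates BatchNorm's running stats

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'demo.ckpt')
    # Mimic only the 'state_dict' entry of a Lightning checkpoint
    torch.save({'state_dict': trained.state_dict()}, path)

    checkpoint = torch.load(path, map_location='cpu')
    restored = make_model()
    restored.load_state_dict(checkpoint['state_dict'])

trained.eval()
restored.eval()

inputs = torch.randn(8, 4)
with torch.no_grad():
    outputs = restored(inputs)
    reference = trained(inputs)

_, preds = torch.max(outputs, dim=1)
print(torch.allclose(outputs, reference))  # True: the weights were restored exactly
print(preds)
```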

The code prints the ground-truth labels, the accuracy on this batch and the predicted labels.

To monitor the training process (the first run plus the resumed run), use:

              
```python
# tensorboard
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs
```

The training process is recorded in TensorBoard: version_0 is the first training run and version_1 is the resumed run.

The complete code is here:

https://colab.research.google.com/gist/calibertytz/a9de31175ce15f384dead94c2a9fad4d/pl_tutorials_1.ipynb
