<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          使用pytorch-lightning漂亮地進(jìn)行深度學(xué)習(xí)研究

          共 9222字,需瀏覽 19分鐘

           ·

          2021-01-18 23:39

          你好,我是云哥。最近研究了一下pytorch-lightning,寫(xiě)了兩篇博客記錄,這是其一,其二有更多的驚喜,敬請(qǐng)期待。????

          公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:源碼,獲取本文源代碼!

          pytorch-lightning 是建立在pytorch之上的高層次模型接口。

          pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow。

          通過(guò)使用 pytorch-lightning,用戶(hù)無(wú)需編寫(xiě)自定義訓(xùn)練循環(huán)就可以非常簡(jiǎn)潔地在CPU、單GPU、多GPU、乃至多TPU上訓(xùn)練模型。

          無(wú)需考慮模型和數(shù)據(jù)在cpu,cuda之間的移動(dòng),并且可以通過(guò)回調(diào)函數(shù)實(shí)現(xiàn)CheckPoint參數(shù)保存,實(shí)現(xiàn)斷點(diǎn)續(xù)訓(xùn)功能。

          一般按照如下方式 安裝和 引入 pytorch-lightning 庫(kù)。

          #安裝
          pip install pytorch-lightning
          #引入
          import pytorch_lightning as pl 

          顧名思義,它可以幫助我們漂亮(pl)地進(jìn)行深度學(xué)習(xí)研究。????

          一,pytorch-lightning的設(shè)計(jì)哲學(xué)

          pytorch-lightning 的核心設(shè)計(jì)哲學(xué)是將 深度學(xué)習(xí)項(xiàng)目中的 研究代碼(定義模型) 和 工程代碼 (訓(xùn)練模型) 相互分離。

          用戶(hù)只需專(zhuān)注于研究代碼(pl.LightningModule)的實(shí)現(xiàn),而工程代碼借助訓(xùn)練工具類(lèi)(pl.Trainer)統(tǒng)一實(shí)現(xiàn)。

          更詳細(xì)地說(shuō),深度學(xué)習(xí)項(xiàng)目代碼可以分成如下4部分:

          • 研究代碼 (Research code),用戶(hù)繼承LightningModule實(shí)現(xiàn)。
          • 工程代碼 (Engineering code),用戶(hù)無(wú)需關(guān)注通過(guò)調(diào)用Trainer實(shí)現(xiàn)。
          • 非必要代碼 (Non-essential research code,logging, etc...),用戶(hù)通過(guò)調(diào)用Callbacks實(shí)現(xiàn)。
          • 數(shù)據(jù) (Data),用戶(hù)通過(guò)torch.utils.data.DataLoader實(shí)現(xiàn)。

          二,pytorch-lightning使用范例

          下面我們使用minist圖片分類(lèi)問(wèn)題為例,演示pytorch-lightning的最佳實(shí)踐。

          1,準(zhǔn)備數(shù)據(jù)

          import torch 
          from torch import nn 

          import torchvision 
          from torchvision import transforms

          transform = transforms.Compose([transforms.ToTensor()])

          ds_train = torchvision.datasets.MNIST(root="./minist/",train=True,download=True,transform=transform)
          ds_valid = torchvision.datasets.MNIST(root="./minist/",train=False,download=True,transform=transform)

          dl_train =  torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
          dl_valid =  torch.utils.data.DataLoader(ds_valid, batch_size=128, shuffle=False, num_workers=4)

          print(len(ds_train))
          print(len(ds_valid))

          Done!
          60000
          10000

          2,定義模型

          import pytorch_lightning as pl 
          import datetime

          class Model(pl.LightningModule):
              
              def __init__(self):
                  super().__init__()
                  self.layers = nn.ModuleList([
                      nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
                      nn.MaxPool2d(kernel_size = 2,stride = 2),
                      nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
                      nn.MaxPool2d(kernel_size = 2,stride = 2),
                      nn.Dropout2d(p = 0.1),
                      nn.AdaptiveMaxPool2d((1,1)),
                      nn.Flatten(),
                      nn.Linear(64,32),
                      nn.ReLU(),
                      nn.Linear(32,10)]
                  )
                  
              def forward(self,x):
                  for layer in self.layers:
                      x = layer(x)
                  return x
              
              #定義loss,以及可選的各種metrics
              def training_step(self, batch, batch_idx):
                  x, y = batch
                  prediction = self(x)
                  loss = nn.CrossEntropyLoss()(prediction,y)
                  return loss
              
              #定義optimizer,以及可選的lr_scheduler
              def configure_optimizers(self):
                  optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
                  return {"optimizer":optimizer}
              
              def validation_step(self, batch, batch_idx):
                  loss = self.training_step(batch,batch_idx)
                  return {"val_loss":loss}
              
              def test_step(self, batch, batch_idx):
                  loss = self.training_step(batch,batch_idx)
                  return {"test_loss":loss}
              
            

          3,訓(xùn)練模型

          pl.seed_everything(1234)
          model = Model() 


          ckpt_callback = pl.callbacks.ModelCheckpoint(
              monitor='val_loss',
              save_top_k=1,
              mode='min'
          )

          # gpus=0 則使用cpu訓(xùn)練,gpus=1則使用1個(gè)gpu訓(xùn)練,gpus=2則使用2個(gè)gpu訓(xùn)練,gpus=-1則使用所有g(shù)pu訓(xùn)練,
          # gpus=[0,1]則指定使用0號(hào)和1號(hào)gpu訓(xùn)練, gpus="0,1,2,3"則使用0,1,2,3號(hào)gpu訓(xùn)練
          # tpus=1 則使用1個(gè)tpu訓(xùn)練

          trainer = pl.Trainer(max_epochs=5,gpus=0,callbacks = [ckpt_callback]) 

          #斷點(diǎn)續(xù)訓(xùn)
          #trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

          trainer.fit(model,dl_train,dl_valid)

          Global seed set to 1234
          GPU available: False, used: False
          TPU available: None, using: 0 TPU cores

            | Name   | Type       | Params
          --------------------------------------
          0 | layers | ModuleList | 54.0 K
          --------------------------------------
          54.0 K    Trainable params
          0         Non-trainable params
          54.0 K    Total params
          Epoch 4: 100% >>>>>>>>>>>>>>>>>>>>>>>>>>>> 158/158 [00:19<00:00, 8.08it/s, loss=0.138, v_num=34]

          4,評(píng)估模型

          result = trainer.test(model, test_dataloaders=dl_valid)
          print(result)
          --------------------------------------------------------------------------------
          DATALOADER:0 TEST RESULTS
          {'test_loss': tensor(0.0047)}
          --------------------------------------------------------------------------------
          [{'test_loss': 0.004680501762777567}]

          5,使用模型

          data,label = next(iter(dl_valid))
          model.eval()
          prediction = model(data)
          print(prediction)

          tensor([[ -5.1149,  -6.1142,   2.0591,  ...,   7.0609,  -5.4144,   0.5222],
                  [ -2.2989,  -5.6076,   3.7343,  ...,  -1.8391,  -6.4941,  -3.4076],
                  [  0.9215,   6.9357,  -1.9887,  ...,  -2.2996,  -0.8034,  -3.2993],
                  ...,
                  [ -4.5674,  -6.0223,  -0.9309,  ...,  -3.5468,   0.3367,   4.5473],
                  [  4.3023,  -4.1629,  -1.2742,  ...,  -4.2527,  -2.3449,  -2.5585],
                  [ -3.8913, -10.3790,  -1.7804,  ...,  -4.6757,  -0.7428,   1.0305]],
                 grad_fn= )

          6,保存模型

          最優(yōu)模型默認(rèn)保存在 trainer.checkpoint_callback.best_model_path 的目錄下,可以直接加載。

          print(trainer.checkpoint_callback.best_model_path)
          print(trainer.checkpoint_callback.best_model_score)
          /Users/liangyun/CodeFiles/PythonAiRoad/lightning_logs/version_34/checkpoints/epoch=04-val_loss=0.00.ckpt
          tensor(0.0047)

          model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
          trainer_clone = pl.Trainer(max_epochs=3
          result = trainer_clone.test(model_clone,dl_valid)
          print(result)

          --------------------------------------------------------------------------------
          DATALOADER:0 TEST RESULTS
          {'test_loss': tensor(0.0047)}
          --------------------------------------------------------------------------------
          [{'test_loss': 0.004680501762777567}]

          如果對(duì)本文內(nèi)容理解上有需要進(jìn)一步和作者交流的地方,歡迎在公眾號(hào)"算法美食屋"下留言。作者時(shí)間和精力有限,會(huì)酌情予以回復(fù)。

          也可以在公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:加群,加入讀者交流群和大家討論。

          瀏覽 29
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  色视频网| 91成人三级 | 色婷婷在线小视频 | 午夜福利剧场 | 中文字幕无码不卡 |