使用pytorch-lightning漂亮地進(jìn)行深度學(xué)習(xí)研究
你好,我是云哥。最近研究了一下pytorch-lightning,寫(xiě)了兩篇博客記錄,這是其一,其二有更多的驚喜,敬請(qǐng)期待。????
公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:源碼,獲取本文源代碼!
pytorch-lightning 是建立在pytorch之上的高層次模型接口。
pytorch-lightning 之于 pytorch,就如同keras之于 tensorflow。
通過(guò)使用 pytorch-lightning,用戶(hù)無(wú)需編寫(xiě)自定義訓(xùn)練循環(huán)就可以非常簡(jiǎn)潔地在CPU、單GPU、多GPU、乃至多TPU上訓(xùn)練模型。
無(wú)需考慮模型和數(shù)據(jù)在cpu,cuda之間的移動(dòng),并且可以通過(guò)回調(diào)函數(shù)實(shí)現(xiàn)CheckPoint參數(shù)保存,實(shí)現(xiàn)斷點(diǎn)續(xù)訓(xùn)功能。
一般按照如下方式 安裝和 引入 pytorch-lightning 庫(kù)。
#安裝
pip install pytorch-lightning
#引入
import pytorch_lightning as pl
顧名思義,它可以幫助我們漂亮(pl)地進(jìn)行深度學(xué)習(xí)研究。????
一,pytorch-lightning的設(shè)計(jì)哲學(xué)
pytorch-lightning 的核心設(shè)計(jì)哲學(xué)是將 深度學(xué)習(xí)項(xiàng)目中的 研究代碼(定義模型) 和 工程代碼 (訓(xùn)練模型) 相互分離。
用戶(hù)只需專(zhuān)注于研究代碼(pl.LightningModule)的實(shí)現(xiàn),而工程代碼借助訓(xùn)練工具類(lèi)(pl.Trainer)統(tǒng)一實(shí)現(xiàn)。
更詳細(xì)地說(shuō),深度學(xué)習(xí)項(xiàng)目代碼可以分成如下4部分:
-
研究代碼 (Research code),用戶(hù)繼承LightningModule實(shí)現(xiàn)。 -
工程代碼 (Engineering code),用戶(hù)無(wú)需關(guān)注通過(guò)調(diào)用Trainer實(shí)現(xiàn)。 -
非必要代碼 (Non-essential research code,logging, etc...),用戶(hù)通過(guò)調(diào)用Callbacks實(shí)現(xiàn)。 -
數(shù)據(jù) (Data),用戶(hù)通過(guò)torch.utils.data.DataLoader實(shí)現(xiàn)。
二,pytorch-lightning使用范例
下面我們使用minist圖片分類(lèi)問(wèn)題為例,演示pytorch-lightning的最佳實(shí)踐。
1,準(zhǔn)備數(shù)據(jù)
import torch
from torch import nn
import torchvision
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor()])
ds_train = torchvision.datasets.MNIST(root="./minist/",train=True,download=True,transform=transform)
ds_valid = torchvision.datasets.MNIST(root="./minist/",train=False,download=True,transform=transform)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=128, shuffle=True, num_workers=4)
dl_valid = torch.utils.data.DataLoader(ds_valid, batch_size=128, shuffle=False, num_workers=4)
print(len(ds_train))
print(len(ds_valid))
Done!
60000
10000
2,定義模型
import pytorch_lightning as pl
import datetime
class Model(pl.LightningModule):
def __init__(self):
super().__init__()
self.layers = nn.ModuleList([
nn.Conv2d(in_channels=1,out_channels=32,kernel_size = 3),
nn.MaxPool2d(kernel_size = 2,stride = 2),
nn.Conv2d(in_channels=32,out_channels=64,kernel_size = 5),
nn.MaxPool2d(kernel_size = 2,stride = 2),
nn.Dropout2d(p = 0.1),
nn.AdaptiveMaxPool2d((1,1)),
nn.Flatten(),
nn.Linear(64,32),
nn.ReLU(),
nn.Linear(32,10)]
)
def forward(self,x):
for layer in self.layers:
x = layer(x)
return x
#定義loss,以及可選的各種metrics
def training_step(self, batch, batch_idx):
x, y = batch
prediction = self(x)
loss = nn.CrossEntropyLoss()(prediction,y)
return loss
#定義optimizer,以及可選的lr_scheduler
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return {"optimizer":optimizer}
def validation_step(self, batch, batch_idx):
loss = self.training_step(batch,batch_idx)
return {"val_loss":loss}
def test_step(self, batch, batch_idx):
loss = self.training_step(batch,batch_idx)
return {"test_loss":loss}
3,訓(xùn)練模型
pl.seed_everything(1234)
model = Model()
ckpt_callback = pl.callbacks.ModelCheckpoint(
monitor='val_loss',
save_top_k=1,
mode='min'
)
# gpus=0 則使用cpu訓(xùn)練,gpus=1則使用1個(gè)gpu訓(xùn)練,gpus=2則使用2個(gè)gpu訓(xùn)練,gpus=-1則使用所有g(shù)pu訓(xùn)練,
# gpus=[0,1]則指定使用0號(hào)和1號(hào)gpu訓(xùn)練, gpus="0,1,2,3"則使用0,1,2,3號(hào)gpu訓(xùn)練
# tpus=1 則使用1個(gè)tpu訓(xùn)練
trainer = pl.Trainer(max_epochs=5,gpus=0,callbacks = [ckpt_callback])
#斷點(diǎn)續(xù)訓(xùn)
#trainer = pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')
trainer.fit(model,dl_train,dl_valid)
Global seed set to 1234
GPU available: False, used: False
TPU available: None, using: 0 TPU cores
| Name | Type | Params
--------------------------------------
0 | layers | ModuleList | 54.0 K
--------------------------------------
54.0 K Trainable params
0 Non-trainable params
54.0 K Total params
Epoch 4: 100% >>>>>>>>>>>>>>>>>>>>>>>>>>>> 158/158 [00:19<00:00, 8.08it/s, loss=0.138, v_num=34]
4,評(píng)估模型
result = trainer.test(model, test_dataloaders=dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]
5,使用模型
data,label = next(iter(dl_valid))
model.eval()
prediction = model(data)
print(prediction)
tensor([[ -5.1149, -6.1142, 2.0591, ..., 7.0609, -5.4144, 0.5222],
[ -2.2989, -5.6076, 3.7343, ..., -1.8391, -6.4941, -3.4076],
[ 0.9215, 6.9357, -1.9887, ..., -2.2996, -0.8034, -3.2993],
...,
[ -4.5674, -6.0223, -0.9309, ..., -3.5468, 0.3367, 4.5473],
[ 4.3023, -4.1629, -1.2742, ..., -4.2527, -2.3449, -2.5585],
[ -3.8913, -10.3790, -1.7804, ..., -4.6757, -0.7428, 1.0305]],
grad_fn=
)
6,保存模型
最優(yōu)模型默認(rèn)保存在 trainer.checkpoint_callback.best_model_path 的目錄下,可以直接加載。
print(trainer.checkpoint_callback.best_model_path)
print(trainer.checkpoint_callback.best_model_score)
/Users/liangyun/CodeFiles/PythonAiRoad/lightning_logs/version_34/checkpoints/epoch=04-val_loss=0.00.ckpt
tensor(0.0047)
model_clone = Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
trainer_clone = pl.Trainer(max_epochs=3)
result = trainer_clone.test(model_clone,dl_valid)
print(result)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.0047)}
--------------------------------------------------------------------------------
[{'test_loss': 0.004680501762777567}]
如果對(duì)本文內(nèi)容理解上有需要進(jìn)一步和作者交流的地方,歡迎在公眾號(hào)"算法美食屋"下留言。作者時(shí)間和精力有限,會(huì)酌情予以回復(fù)。
也可以在公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:加群,加入讀者交流群和大家討論。
