Hands-on Tutorial | Using Pytorch-lightning

Pytorch-lightning (pl for short) lets you build deep learning code very concisely. In practice, most people never need its more advanced features, and because pl wraps things quite deeply, it can occasionally feel a bit inflexible. In general, once your model is defined, most of the training machinery is encapsulated in a class called Trainer. The features that are tedious to write by hand but that you usually do need are:
- Saving checkpoints
- Writing log output
- Resuming training, i.e. restarting from the last epoch and continuing
- Recording the training process (usually with TensorBoard)
- Setting the seed so that training runs are reproducible
Fortunately, all of these are already implemented in pl; a minimal sketch of how they map onto the Trainer follows below.
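The following is only a sketch under assumptions: it uses the 1.x-era pl API that the rest of this article relies on (gpus, resume_from_checkpoint); MyLightningModule, my_logs and the monitored metric val_loss are placeholder names, and the exact Trainer arguments vary a little between pl versions.

import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

seed_everything(42)  # reproducible runs

# keep the best checkpoint according to the logged validation loss
checkpoint_callback = ModelCheckpoint(monitor='val_loss', save_top_k=1, mode='min')

# write TensorBoard event files under ./my_logs/
logger = TensorBoardLogger(save_dir='my_logs', name='example_run')

trainer = pl.Trainer(
    max_epochs=5,
    gpus=1,
    callbacks=[checkpoint_callback],
    logger=logger,
    # resume_from_checkpoint='path/to/last.ckpt',  # uncomment to resume a previous run
)
# trainer.fit(MyLightningModule(...))  # MyLightningModule is a placeholder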
Because many things in the docs are not explained very clearly, and there are not that many examples online, I will share some of my own experience below.
First, setting the global seed:
from pytorch_lightning import seed_everything

# Set seed
seed = 42
seed_everything(seed)
You only need to import the seed_everything function as shown above. It should be roughly equivalent to the following:
import random
import numpy as np
import torch

def seed_all(seed_value):
    random.seed(seed_value)  # Python
    np.random.seed(seed_value)  # cpu vars
    torch.manual_seed(seed_value)  # cpu vars
    if torch.cuda.is_available():
        print('CUDA is available')
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)  # gpu vars
        torch.backends.cudnn.deterministic = True  # needed
        torch.backends.cudnn.benchmark = False

seed = 42
seed_all(seed)
From my testing, though, pl's seed_everything seems to be slightly more thorough.
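As a side note (not in the original code, and only available in newer pl releases), seed_everything can also seed DataLoader worker processes:

from pytorch_lightning import seed_everything

# workers=True additionally derives seeds for DataLoader worker processes (pl >= 1.3)
seed_everything(42, workers=True)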
The rest of this post walks through a concrete example.
First, install pytorch-lightning, import the required packages, and download the dataset:
!pip install pytorch-lightning
!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q hymenoptera_data.zip
!rm hymenoptera_data.zip

import pytorch_lightning as pl
import os
import numpy as np
import random
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
Lines prefixed with ! above are shell commands; in Google Colab you need to prepend ! to run Linux commands.
If you are using Google Colab, keep in mind that it runs in a virtual machine whose files are not persisted, so if you want to keep your outputs it is worth mounting your own Google Drive:
from google.colab import drive

drive.mount('/content/drive')

import os
os.chdir("/content/drive/My Drive/")
Define the LightningModule and the main function as follows.
class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()

        self.params = hparams
        self.data_dir = self.params.data_dir
        self.num_classes = self.params.num_classes

        ########## define the model ##########
        arch = torchvision.models.resnet18(pretrained=True)
        num_ftrs = arch.fc.in_features

        modules = list(arch.children())[:-1]  # ResNet18 has 10 children
        self.backbone = torch.nn.Sequential(*modules)  # [bs, 512, 1, 1]
        self.final = torch.nn.Sequential(
            torch.nn.Linear(num_ftrs, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, self.num_classes),
            torch.nn.Softmax(dim=1))

    def forward(self, x):
        x = self.backbone(x)
        x = x.reshape(x.size(0), -1)
        x = self.final(x)
        return x

    def configure_optimizers(self):
        # REQUIRED
        optimizer = torch.optim.SGD([
            {'params': self.backbone.parameters()},
            {'params': self.final.parameters(), 'lr': 1e-2}
        ], lr=1e-3, momentum=0.9)
        exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return [optimizer], [exp_lr_scheduler]

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, dim=1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return {'loss': loss, 'train_acc': acc}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, 1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return {'val_loss': loss, 'val_acc': acc}

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        _, preds = torch.max(y_hat, 1)
        acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)
        return {'test_loss': loss, 'test_acc': acc}

    def train_dataloader(self):
        # REQUIRED
        transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)
        return train_loader

    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=4)
        return val_loader

    def test_dataloader(self):
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=8, shuffle=True, num_workers=4)
        return val_loader


def main(hparams):
    model = CoolSystem(hparams)
    trainer = pl.Trainer(
        max_epochs=hparams.epochs,
        gpus=1,
        accelerator='dp')
    trainer.fit(model)
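One detail worth noting: F.cross_entropy already applies log-softmax internally, so with the Softmax layer at the end of self.final the loss is computed on already-normalized outputs. The model above still trains, but if you prefer to work with raw logits, a possible variant of the head (my own suggestion, not part of the original notebook) looks like this:

self.final = torch.nn.Sequential(
    torch.nn.Linear(num_ftrs, 128),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(128, self.num_classes))  # raw logits; F.cross_entropy applies the softmax

# the argmax used for accuracy is unchanged, since softmax preserves the ordering of the outputs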
Next, the part that actually runs the training:
from argparse import Namespace

args = {
    'num_classes': 2,
    'epochs': 5,
    'data_dir': "/content/hymenoptera_data",
}
hyperparams = Namespace(**args)

if __name__ == '__main__':
    main(hyperparams)
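The notebook itself never invokes the test loop, but since test_step and test_dataloader are defined you can run them through the Trainer as well. A small sketch under the same setup (note that test_step above only returns its metrics; to have them appear in the printed results table you would also need to self.log them):

model = CoolSystem(hyperparams)
trainer = pl.Trainer(gpus=1, max_epochs=hyperparams.epochs, accelerator='dp')
trainer.fit(model)
trainer.test(model)  # runs test_step over the batches from test_dataloader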
If you want to resume training, you can do it like this:
# resume training
RESUME = True

if RESUME:
    resume_checkpoint_dir = './lightning_logs/version_0/checkpoints/'
    checkpoint_path = os.listdir(resume_checkpoint_dir)[0]
    resume_checkpoint_path = resume_checkpoint_dir + checkpoint_path

    args = {
        'num_classes': 2,
        'data_dir': "/content/hymenoptera_data"
    }
    hparams = Namespace(**args)
    model = CoolSystem(hparams)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=10,
        accelerator='dp',
        resume_from_checkpoint=resume_checkpoint_path)
    trainer.fit(model)
If you want to load the model from a checkpoint and use it for inference, you can do the following:
import matplotlib.pyplot as plt
import numpy as np

# function to show an image
def imshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()

classes = ['ants', 'bees']

checkpoint_dir = 'lightning_logs/version_1/checkpoints/'
checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]
checkpoint = torch.load(checkpoint_path)

model_infer = CoolSystem(hparams)
model_infer.load_state_dict(checkpoint['state_dict'])

try_dataloader = model_infer.test_dataloader()
inputs, labels = next(iter(try_dataloader))

# print images and ground truth
imshow(torchvision.utils.make_grid(inputs))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))

# inference
outputs = model_infer(inputs)
_, preds = torch.max(outputs, dim=1)
# print(preds)
print(torch.sum(preds == labels.data) / (labels.shape[0] * 1.0))
print('Predicted: ', ' '.join('%5s' % classes[preds[j]] for j in range(8)))

The prediction results are shown above.
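As an alternative to calling torch.load and load_state_dict by hand, LightningModule also provides a load_from_checkpoint class method. A sketch under the assumption that CoolSystem does not call save_hyperparameters (so its hparams argument has to be supplied again); the eval() and no_grad() lines are my additions for inference, not part of the original notebook:

model_infer = CoolSystem.load_from_checkpoint(checkpoint_path, hparams=hparams)
model_infer.eval()        # switch to inference mode (dropout off, batch-norm statistics frozen)
with torch.no_grad():     # gradients are not needed for prediction
    outputs = model_infer(inputs)
    _, preds = torch.max(outputs, dim=1)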
If you want to inspect the training process (the first run plus the resumed run), do the following:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs

The training process is recorded by TensorBoard: version_0 is the first training run and version_1 is the run after resuming.
The complete code is available here:
https://colab.research.google.com/gist/calibertytz/a9de31175ce15f384dead94c2a9fad4d/pl_tutorials_1.ipynb

