PyTorch Lightning工具學習
【GiantPandaCV導語】Pytorch Lightning是在Pytorch基礎上進行封裝的庫(可以理解為keras之于tensorflow),為了讓用戶能夠脫離PyTorch一些繁瑣的細節(jié),專注于核心代碼的構建,提供了許多實用工具,可以讓實驗更加高效。本文將介紹安裝方法、設計邏輯、轉化的例子等內容。
PyTorch Lightning中提供了以下比較方便的功能:
multi-GPU訓練 半精度訓練 TPU 訓練 將訓練細節(jié)進行抽象,從而可以快速迭代

1. 簡單介紹
PyTorch lightning 是為AI相關的專業(yè)的研究人員、研究生、博士等人群開發(fā)的。PyTorch就是William Falcon在他的博士階段創(chuàng)建的,目標是讓AI研究擴展性更強,忽略一些耗費時間的細節(jié)。
目前PyTorch Lightning庫已經有了一定的影響力,star已經1w+,同時有超過1千多的研究人員在一起維護這個框架。

同時PyTorch Lightning也在隨著PyTorch版本的更新也在不停迭代。

官方文檔也有支持,正在不斷更新:

下面介紹一下如何安裝。
2. 安裝方法
Pytorch Lightning安裝非常方便,推薦使用conda環(huán)境進行安裝。
source?activate?you_env
pip?install?pytorch-lightning
或者直接用pip安裝:
pip?install?pytorch-lightning
或者通過conda安裝:
conda?install?pytorch-lightning?-c?conda-forge
3. Lightning的設計思想
Lightning將大部分AI相關代碼分為三個部分:
研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。
工程代碼,這部分代碼重復性強,比如16位精度,分布式訓練。被抽象為Trainer類。
非必要代碼,這部分代碼和實驗沒有直接關系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。
Lightning將研究代碼劃分為以下幾個組件:
模型 數(shù)據(jù)處理 損失函數(shù) 優(yōu)化器
以上四個組件都將集成到LightningModule類中,是在Module類之上進行了擴展,進行了功能性補充,比如原來優(yōu)化器使用在main函數(shù)中,是一種面向過程的用法,現(xiàn)在集成到LightningModule中,作為一個類的方法。
4. LightningModule生命周期
這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
在這個模塊中,將PyTorch代碼按照五個部分進行組織:
Computations(init) 初始化相關計算 Train Loop(training_step) 每個step中執(zhí)行的代碼 Validation Loop(validation_step) 在一個epoch訓練完以后執(zhí)行Valid Test Loop(test_step) 在整個訓練完成以后執(zhí)行Test Optimizer(configure_optimizers) 配置優(yōu)化器等
展示一個最簡代碼:
>>>?import?pytorch_lightning?as?pl
>>>?class?LitModel(pl.LightningModule):
...
...?????def?__init__(self):
...?????????super().__init__()
...?????????self.l1?=?torch.nn.Linear(28?*?28,?10)
...
...?????def?forward(self,?x):
...?????????return?torch.relu(self.l1(x.view(x.size(0),?-1)))
...
...?????def?training_step(self,?batch,?batch_idx):
...?????????x,?y?=?batch
...?????????y_hat?=?self(x)
...?????????loss?=?F.cross_entropy(y_hat,?y)
...?????????return?loss
...
...?????def?configure_optimizers(self):
...?????????return?torch.optim.Adam(self.parameters(),?lr=0.02)
那么整個生命周期流程是如何組織的?
4.1 準備工作
這部分包括LightningModule的初始化、準備數(shù)據(jù)、配置優(yōu)化器。每次只執(zhí)行一次,相當于構造函數(shù)的作用。
__init__()(初始化 LightningModule )prepare_data()(準備數(shù)據(jù),包括下載數(shù)據(jù)、預處理等等)configure_optimizers()(配置優(yōu)化器)
4.2 測試 驗證部分
實際運行代碼前,會隨即初始化模型,然后運行一次驗證代碼,這樣可以防止在你訓練了幾個epoch之后要進行Valid的時候發(fā)現(xiàn)驗證部分出錯。主要測試下面幾個函數(shù):
val_dataloader()validation_step()validation_epoch_end()
4.3 加載數(shù)據(jù)
調用以下方法進行加載數(shù)據(jù)。
train_dataloader()val_dataloader()
4.4 訓練
每個batch的訓練被稱為一個step,故先運行train_step函數(shù)。
當經過多個batch, 默認49個step的訓練后,會進行驗證,運行validation_step函數(shù)。
當完成一個epoch的訓練以后,會對整個epoch結果進行驗證,運行validation_epoch_end函數(shù)
(option)如果需要的話,可以調用測試部分代碼:
test_dataloader() test_step() test_epoch_end()
5. 示例
以MNIST為例,將PyTorch版本代碼轉為PyTorch Lightning。
5.1 PyTorch版本訓練MNIST
對于一個PyTorch的代碼來說,一般是這樣構建網絡(源碼來自PyTorch中的example庫)。
class?Net(nn.Module):
????def?__init__(self):
????????super(Net,?self).__init__()
????????self.conv1?=?nn.Conv2d(1,?32,?3,?1)
????????self.conv2?=?nn.Conv2d(32,?64,?3,?1)
????????self.dropout1?=?nn.Dropout(0.25)
????????self.dropout2?=?nn.Dropout(0.5)
????????self.fc1?=?nn.Linear(9216,?128)
????????self.fc2?=?nn.Linear(128,?10)
????def?forward(self,?x):
????????x?=?self.conv1(x)
????????x?=?F.relu(x)
????????x?=?self.conv2(x)
????????x?=?F.relu(x)
????????x?=?F.max_pool2d(x,?2)
????????x?=?self.dropout1(x)
????????x?=?torch.flatten(x,?1)
????????x?=?self.fc1(x)
????????x?=?F.relu(x)
????????x?=?self.dropout2(x)
????????x?=?self.fc2(x)
????????output?=?F.log_softmax(x,?dim=1)
????????return?output
還有兩個主要工作是構建訓練函數(shù)和測試函數(shù)。
在訓練函數(shù)中需要完成:
數(shù)據(jù)獲取 data, target = data.to(device), target.to(device)清空優(yōu)化器梯度 optimizer.zero_grad()前向傳播 output = model(data)計算損失函數(shù) loss = F.nll_loss(output, target)反向傳播 loss.backward()優(yōu)化器進行單次優(yōu)化 optimizer.step()
def?train(args,?model,?device,?train_loader,?optimizer,?epoch):
????model.train()
????for?batch_idx,?(data,?target)?in?enumerate(train_loader):
????????data,?target?=?data.to(device),?target.to(device)
????????optimizer.zero_grad()
????????output?=?model(data)
????????loss?=?F.nll_loss(output,?target)
????????loss.backward()
????????optimizer.step()
????????if?batch_idx?%?args.log_interval?==?0:
????????????print('Train?Epoch:?{}?[{}/{}?({:.0f}%)]\tLoss:?{:.6f}'.format(
????????????????epoch,?batch_idx?*?len(data),?len(train_loader.dataset),
????????????????100.?*?batch_idx?/?len(train_loader),?loss.item()))
????????????if?args.dry_run:
????????????????break
def?test(model,?device,?test_loader):
????model.eval()
????test_loss?=?0
????correct?=?0
????with?torch.no_grad():
????????for?data,?target?in?test_loader:
????????????data,?target?=?data.to(device),?target.to(device)
????????????output?=?model(data)
????????????test_loss?+=?F.nll_loss(output,?target,?reduction='sum').item()??#?sum?up?batch?loss
????????????pred?=?output.argmax(dim=1,?keepdim=True)??#?get?the?index?of?the?max?log-probability
????????????correct?+=?pred.eq(target.view_as(pred)).sum().item()
????test_loss?/=?len(test_loader.dataset)
????print('\nTest?set:?Average?loss:?{:.4f},?Accuracy:?{}/{}?({:.0f}%)\n'.format(
????????test_loss,?correct,?len(test_loader.dataset),
????????100.?*?correct?/?len(test_loader.dataset)))
其他部分比如數(shù)據(jù)加載、數(shù)據(jù)增廣、優(yōu)化器、訓練流程都是在main中執(zhí)行的,采用的是一種面向過程的方法。
def?main():
????#?Training?settings
????parser?=?argparse.ArgumentParser(description='PyTorch?MNIST?Example')
????parser.add_argument('--batch-size',?type=int,?default=64,?metavar='N',
????????????????????????help='input?batch?size?for?training?(default:?64)')
????parser.add_argument('--test-batch-size',?type=int,?default=1000,?metavar='N',
????????????????????????help='input?batch?size?for?testing?(default:?1000)')
????parser.add_argument('--epochs',?type=int,?default=14,?metavar='N',
????????????????????????help='number?of?epochs?to?train?(default:?14)')
????parser.add_argument('--lr',?type=float,?default=1.0,?metavar='LR',
????????????????????????help='learning?rate?(default:?1.0)')
????parser.add_argument('--gamma',?type=float,?default=0.7,?metavar='M',
????????????????????????help='Learning?rate?step?gamma?(default:?0.7)')
????parser.add_argument('--no-cuda',?action='store_true',?default=False,
????????????????????????help='disables?CUDA?training')
????parser.add_argument('--dry-run',?action='store_true',?default=False,
????????????????????????help='quickly?check?a?single?pass')
????parser.add_argument('--seed',?type=int,?default=1,?metavar='S',
????????????????????????help='random?seed?(default:?1)')
????parser.add_argument('--log-interval',?type=int,?default=10,?metavar='N',
????????????????????????help='how?many?batches?to?wait?before?logging?training?status')
????parser.add_argument('--save-model',?action='store_true',?default=False,
????????????????????????help='For?Saving?the?current?Model')
????args?=?parser.parse_args()
????use_cuda?=?not?args.no_cuda?and?torch.cuda.is_available()
????torch.manual_seed(args.seed)
????device?=?torch.device("cuda"?if?use_cuda?else?"cpu")
????train_kwargs?=?{'batch_size':?args.batch_size}
????test_kwargs?=?{'batch_size':?args.test_batch_size}
????if?use_cuda:
????????cuda_kwargs?=?{'num_workers':?1,
???????????????????????'pin_memory':?True,
???????????????????????'shuffle':?True}
????????train_kwargs.update(cuda_kwargs)
????????test_kwargs.update(cuda_kwargs)
????transform=transforms.Compose([
????????transforms.ToTensor(),
????????transforms.Normalize((0.1307,),?(0.3081,))
????????])
????dataset1?=?datasets.MNIST('../data',?train=True,?download=True,
???????????????????????transform=transform)
????dataset2?=?datasets.MNIST('../data',?train=False,
???????????????????????transform=transform)
????train_loader?=?torch.utils.data.DataLoader(dataset1,**train_kwargs)
????test_loader?=?torch.utils.data.DataLoader(dataset2,?**test_kwargs)
????model?=?Net().to(device)
????optimizer?=?optim.Adadelta(model.parameters(),?lr=args.lr)
????scheduler?=?StepLR(optimizer,?step_size=1,?gamma=args.gamma)
????for?epoch?in?range(1,?args.epochs?+?1):
????????train(args,?model,?device,?train_loader,?optimizer,?epoch)
????????test(model,?device,?test_loader)
????????scheduler.step()
????if?args.save_model:
????????torch.save(model.state_dict(),?"mnist_cnn.pt")
5.2 Lightning版本訓練MNIST
第一部分,也就是歸為研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。
class?LitClassifier(pl.LightningModule):
????def?__init__(self,?hidden_dim=128,?learning_rate=1e-3):
????????super().__init__()
????????self.save_hyperparameters()
????????self.l1?=?torch.nn.Linear(28?*?28,?self.hparams.hidden_dim)
????????self.l2?=?torch.nn.Linear(self.hparams.hidden_dim,?10)
????def?forward(self,?x):
????????x?=?x.view(x.size(0),?-1)
????????x?=?torch.relu(self.l1(x))
????????x?=?torch.relu(self.l2(x))
????????return?x
????def?training_step(self,?batch,?batch_idx):
????????x,?y?=?batch
????????y_hat?=?self(x)
????????loss?=?F.cross_entropy(y_hat,?y)
????????return?loss
????def?validation_step(self,?batch,?batch_idx):
????????x,?y?=?batch
????????y_hat?=?self(x)
????????loss?=?F.cross_entropy(y_hat,?y)
????????self.log('valid_loss',?loss)
????def?test_step(self,?batch,?batch_idx):
????????x,?y?=?batch
????????y_hat?=?self(x)
????????loss?=?F.cross_entropy(y_hat,?y)
????????self.log('test_loss',?loss)
????def?configure_optimizers(self):
????????return?torch.optim.Adam(self.parameters(),?lr=self.hparams.learning_rate)
????@staticmethod
????def?add_model_specific_args(parent_parser):
????????parser?=?ArgumentParser(parents=[parent_parser],?add_help=False)
????????parser.add_argument('--hidden_dim',?type=int,?default=128)
????????parser.add_argument('--learning_rate',?type=float,?default=0.0001)
????????return?parser
可以看出,和PyTorch版本最大的不同之處在于多了幾個流程處理函數(shù):
training_step,相當于訓練過程中處理一個batch的內容 validation_step,相當于驗證過程中處理一個batch的內容 test_step, 同上 configure_optimizers, 這部分用于處理optimizer和scheduler add_module_specific_args代表這部分控制的是與模型相關的參數(shù)
除此以外,main函數(shù)主要有以下幾個部分:
args參數(shù)處理 data部分 model部分 訓練部分 測試部分
def?cli_main():
????pl.seed_everything(1234)?#?這個是用于固定seed用
????#?args
????parser?=?ArgumentParser()
????parser?=?pl.Trainer.add_argparse_args(parser)
????parser?=?LitClassifier.add_model_specific_args(parser)
????parser?=?MNISTDataModule.add_argparse_args(parser)
????args?=?parser.parse_args()
????#?data
????dm?=?MNISTDataModule.from_argparse_args(args)
????#?model
????model?=?LitClassifier(args.hidden_dim,?args.learning_rate)
????#?training
????trainer?=?pl.Trainer.from_argparse_args(args)
????trainer.fit(model,?datamodule=dm)
????result?=?trainer.test(model,?datamodule=dm)
????pprint(result)
可以看出Lightning版本的代碼代碼量略低于PyTorch版本,但是同時將一些細節(jié)忽略了,比如訓練的具體流程直接使用fit搞定,這樣不會出現(xiàn)忘記清空optimizer等低級錯誤。
6. 評價
總體來說,PyTorch Lightning是一個發(fā)展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務,但是遇到錯誤的時候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。
筆者使用這個框架大概一周了,從使用者角度來談談優(yōu)缺點:
6.1 優(yōu)點
簡化了部分代碼,之前如果要轉到GPU上,需要用to(device)方法判斷,然后轉過去。有了PyTorch lightning的幫助,可以自動幫你處理,通過設置trainer中的gpus參數(shù)即可。 提供了一些有用的工具,比如混合精度訓練、分布式訓練、Horovod 代碼移植更加容易 API比較完善,大部分都有例子,少部分講的不夠詳細。 社區(qū)還是比較活躍的,如果有問題,可以在issue中提問。 實驗結果整理的比較好,將每次實驗劃分為version 0-n,同時可以用tensorboard比較多個實驗,非常友好。
6.2 缺點
引入了一些新的概念,進一步加大了使用者的學習成本,比如pl_bolts 很多原本習慣于在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發(fā)現(xiàn)在configure_optimizers函數(shù)中實現(xiàn),然后模仿demo實現(xiàn),因此也帶來了一定的門檻。 有些報錯比較迷,筆者曾遇到過執(zhí)行的時候發(fā)現(xiàn)多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續(xù)去Issue里找答案。
7. 參考
【1】 https://zhuanlan.zhihu.com/p/120331610
【2】https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html
【3】https://github.com/pytorch/examples/blob/master/mnist/main.py
【4】 https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py
