<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch Lightning工具學習

          共 11367字,需瀏覽 23分鐘

           ·

          2020-12-09 00:50

          【GiantPandaCV導語】Pytorch Lightning是在Pytorch基礎上進行封裝的庫(可以理解為keras之于tensorflow),為了讓用戶能夠脫離PyTorch一些繁瑣的細節(jié),專注于核心代碼的構建,提供了許多實用工具,可以讓實驗更加高效。本文將介紹安裝方法、設計邏輯、轉化的例子等內容。

          PyTorch Lightning中提供了以下比較方便的功能:

          • multi-GPU訓練
          • 半精度訓練
          • TPU 訓練
          • 將訓練細節(jié)進行抽象,從而可以快速迭代
          Pytorch Lightning

          1. 簡單介紹

          PyTorch lightning 是為AI相關的專業(yè)的研究人員、研究生、博士等人群開發(fā)的。PyTorch就是William Falcon在他的博士階段創(chuàng)建的,目標是讓AI研究擴展性更強,忽略一些耗費時間的細節(jié)。

          目前PyTorch Lightning庫已經有了一定的影響力,star已經1w+,同時有超過1千多的研究人員在一起維護這個框架。

          PyTorch Lightning庫

          同時PyTorch Lightning也在隨著PyTorch版本的更新也在不停迭代。

          版本支持情況

          官方文檔也有支持,正在不斷更新:

          官方文檔

          下面介紹一下如何安裝。

          2. 安裝方法

          Pytorch Lightning安裝非常方便,推薦使用conda環(huán)境進行安裝。

          source?activate?you_env
          pip?install?pytorch-lightning

          或者直接用pip安裝:

          pip?install?pytorch-lightning

          或者通過conda安裝:

          conda?install?pytorch-lightning?-c?conda-forge

          3. Lightning的設計思想

          Lightning將大部分AI相關代碼分為三個部分:

          • 研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。

          • 工程代碼,這部分代碼重復性強,比如16位精度,分布式訓練。被抽象為Trainer類。

          • 非必要代碼,這部分代碼和實驗沒有直接關系,不加也可以,加上可以輔助,比如梯度檢查,log輸出等。被抽象為Callbacks類。

          Lightning將研究代碼劃分為以下幾個組件:

          • 模型
          • 數(shù)據(jù)處理
          • 損失函數(shù)
          • 優(yōu)化器

          以上四個組件都將集成到LightningModule類中,是在Module類之上進行了擴展,進行了功能性補充,比如原來優(yōu)化器使用在main函數(shù)中,是一種面向過程的用法,現(xiàn)在集成到LightningModule中,作為一個類的方法。

          4. LightningModule生命周期

          這部分參考了https://zhuanlan.zhihu.com/p/120331610 和 官方文檔 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

          在這個模塊中,將PyTorch代碼按照五個部分進行組織:

          • Computations(init) 初始化相關計算
          • Train Loop(training_step) 每個step中執(zhí)行的代碼
          • Validation Loop(validation_step) 在一個epoch訓練完以后執(zhí)行Valid
          • Test Loop(test_step) 在整個訓練完成以后執(zhí)行Test
          • Optimizer(configure_optimizers) 配置優(yōu)化器等

          展示一個最簡代碼:

          >>>?import?pytorch_lightning?as?pl
          >>>?class?LitModel(pl.LightningModule):
          ...
          ...?????def?__init__(self):
          ...?????????super().__init__()
          ...?????????self.l1?=?torch.nn.Linear(28?*?28,?10)
          ...
          ...?????def?forward(self,?x):
          ...?????????return?torch.relu(self.l1(x.view(x.size(0),?-1)))
          ...
          ...?????def?training_step(self,?batch,?batch_idx):
          ...?????????x,?y?=?batch
          ...?????????y_hat?=?self(x)
          ...?????????loss?=?F.cross_entropy(y_hat,?y)
          ...?????????return?loss
          ...
          ...?????def?configure_optimizers(self):
          ...?????????return?torch.optim.Adam(self.parameters(),?lr=0.02)

          那么整個生命周期流程是如何組織的?

          4.1 準備工作

          這部分包括LightningModule的初始化、準備數(shù)據(jù)、配置優(yōu)化器。每次只執(zhí)行一次,相當于構造函數(shù)的作用。

          • __init__()(初始化 LightningModule )
          • prepare_data() (準備數(shù)據(jù),包括下載數(shù)據(jù)、預處理等等)
          • configure_optimizers() (配置優(yōu)化器)

          4.2 測試 驗證部分

          實際運行代碼前,會隨即初始化模型,然后運行一次驗證代碼,這樣可以防止在你訓練了幾個epoch之后要進行Valid的時候發(fā)現(xiàn)驗證部分出錯。主要測試下面幾個函數(shù):

          • val_dataloader()
          • validation_step()
          • validation_epoch_end()

          4.3 加載數(shù)據(jù)

          調用以下方法進行加載數(shù)據(jù)。

          • train_dataloader()
          • val_dataloader()

          4.4 訓練

          • 每個batch的訓練被稱為一個step,故先運行train_step函數(shù)。

          • 當經過多個batch, 默認49個step的訓練后,會進行驗證,運行validation_step函數(shù)。

          • 當完成一個epoch的訓練以后,會對整個epoch結果進行驗證,運行validation_epoch_end函數(shù)

          • (option)如果需要的話,可以調用測試部分代碼:

            • test_dataloader()
            • test_step()
            • test_epoch_end()

          5. 示例

          以MNIST為例,將PyTorch版本代碼轉為PyTorch Lightning。

          5.1 PyTorch版本訓練MNIST

          對于一個PyTorch的代碼來說,一般是這樣構建網絡(源碼來自PyTorch中的example庫)。

          class?Net(nn.Module):
          ????def?__init__(self):
          ????????super(Net,?self).__init__()
          ????????self.conv1?=?nn.Conv2d(1,?32,?3,?1)
          ????????self.conv2?=?nn.Conv2d(32,?64,?3,?1)
          ????????self.dropout1?=?nn.Dropout(0.25)
          ????????self.dropout2?=?nn.Dropout(0.5)
          ????????self.fc1?=?nn.Linear(9216,?128)
          ????????self.fc2?=?nn.Linear(128,?10)

          ????def?forward(self,?x):
          ????????x?=?self.conv1(x)
          ????????x?=?F.relu(x)
          ????????x?=?self.conv2(x)
          ????????x?=?F.relu(x)
          ????????x?=?F.max_pool2d(x,?2)
          ????????x?=?self.dropout1(x)
          ????????x?=?torch.flatten(x,?1)
          ????????x?=?self.fc1(x)
          ????????x?=?F.relu(x)
          ????????x?=?self.dropout2(x)
          ????????x?=?self.fc2(x)
          ????????output?=?F.log_softmax(x,?dim=1)
          ????????return?output

          還有兩個主要工作是構建訓練函數(shù)和測試函數(shù)。

          在訓練函數(shù)中需要完成:

          • 數(shù)據(jù)獲取 data, target = data.to(device), target.to(device)
          • 清空優(yōu)化器梯度 optimizer.zero_grad()
          • 前向傳播 output = model(data)
          • 計算損失函數(shù) loss = F.nll_loss(output, target)
          • 反向傳播 loss.backward()
          • 優(yōu)化器進行單次優(yōu)化 optimizer.step()
          def?train(args,?model,?device,?train_loader,?optimizer,?epoch):
          ????model.train()
          ????for?batch_idx,?(data,?target)?in?enumerate(train_loader):
          ????????data,?target?=?data.to(device),?target.to(device)
          ????????optimizer.zero_grad()
          ????????output?=?model(data)
          ????????loss?=?F.nll_loss(output,?target)
          ????????loss.backward()
          ????????optimizer.step()
          ????????if?batch_idx?%?args.log_interval?==?0:
          ????????????print('Train?Epoch:?{}?[{}/{}?({:.0f}%)]\tLoss:?{:.6f}'.format(
          ????????????????epoch,?batch_idx?*?len(data),?len(train_loader.dataset),
          ????????????????100.?*?batch_idx?/?len(train_loader),?loss.item()))
          ????????????if?args.dry_run:
          ????????????????break

          def?test(model,?device,?test_loader):
          ????model.eval()
          ????test_loss?=?0
          ????correct?=?0
          ????with?torch.no_grad():
          ????????for?data,?target?in?test_loader:
          ????????????data,?target?=?data.to(device),?target.to(device)
          ????????????output?=?model(data)
          ????????????test_loss?+=?F.nll_loss(output,?target,?reduction='sum').item()??#?sum?up?batch?loss
          ????????????pred?=?output.argmax(dim=1,?keepdim=True)??#?get?the?index?of?the?max?log-probability
          ????????????correct?+=?pred.eq(target.view_as(pred)).sum().item()

          ????test_loss?/=?len(test_loader.dataset)

          ????print('\nTest?set:?Average?loss:?{:.4f},?Accuracy:?{}/{}?({:.0f}%)\n'.format(
          ????????test_loss,?correct,?len(test_loader.dataset),
          ????????100.?*?correct?/?len(test_loader.dataset)))

          其他部分比如數(shù)據(jù)加載、數(shù)據(jù)增廣、優(yōu)化器、訓練流程都是在main中執(zhí)行的,采用的是一種面向過程的方法。

          def?main():
          ????#?Training?settings
          ????parser?=?argparse.ArgumentParser(description='PyTorch?MNIST?Example')
          ????parser.add_argument('--batch-size',?type=int,?default=64,?metavar='N',
          ????????????????????????help='input?batch?size?for?training?(default:?64)')
          ????parser.add_argument('--test-batch-size',?type=int,?default=1000,?metavar='N',
          ????????????????????????help='input?batch?size?for?testing?(default:?1000)')
          ????parser.add_argument('--epochs',?type=int,?default=14,?metavar='N',
          ????????????????????????help='number?of?epochs?to?train?(default:?14)')
          ????parser.add_argument('--lr',?type=float,?default=1.0,?metavar='LR',
          ????????????????????????help='learning?rate?(default:?1.0)')
          ????parser.add_argument('--gamma',?type=float,?default=0.7,?metavar='M',
          ????????????????????????help='Learning?rate?step?gamma?(default:?0.7)')
          ????parser.add_argument('--no-cuda',?action='store_true',?default=False,
          ????????????????????????help='disables?CUDA?training')
          ????parser.add_argument('--dry-run',?action='store_true',?default=False,
          ????????????????????????help='quickly?check?a?single?pass')
          ????parser.add_argument('--seed',?type=int,?default=1,?metavar='S',
          ????????????????????????help='random?seed?(default:?1)')
          ????parser.add_argument('--log-interval',?type=int,?default=10,?metavar='N',
          ????????????????????????help='how?many?batches?to?wait?before?logging?training?status')
          ????parser.add_argument('--save-model',?action='store_true',?default=False,
          ????????????????????????help='For?Saving?the?current?Model')
          ????args?=?parser.parse_args()
          ????use_cuda?=?not?args.no_cuda?and?torch.cuda.is_available()

          ????torch.manual_seed(args.seed)

          ????device?=?torch.device("cuda"?if?use_cuda?else?"cpu")

          ????train_kwargs?=?{'batch_size':?args.batch_size}
          ????test_kwargs?=?{'batch_size':?args.test_batch_size}
          ????if?use_cuda:
          ????????cuda_kwargs?=?{'num_workers':?1,
          ???????????????????????'pin_memory':?True,
          ???????????????????????'shuffle':?True}
          ????????train_kwargs.update(cuda_kwargs)
          ????????test_kwargs.update(cuda_kwargs)

          ????transform=transforms.Compose([
          ????????transforms.ToTensor(),
          ????????transforms.Normalize((0.1307,),?(0.3081,))
          ????????])
          ????dataset1?=?datasets.MNIST('../data',?train=True,?download=True,
          ???????????????????????transform=transform)
          ????dataset2?=?datasets.MNIST('../data',?train=False,
          ???????????????????????transform=transform)
          ????train_loader?=?torch.utils.data.DataLoader(dataset1,**train_kwargs)
          ????test_loader?=?torch.utils.data.DataLoader(dataset2,?**test_kwargs)

          ????model?=?Net().to(device)
          ????optimizer?=?optim.Adadelta(model.parameters(),?lr=args.lr)

          ????scheduler?=?StepLR(optimizer,?step_size=1,?gamma=args.gamma)
          ????for?epoch?in?range(1,?args.epochs?+?1):
          ????????train(args,?model,?device,?train_loader,?optimizer,?epoch)
          ????????test(model,?device,?test_loader)
          ????????scheduler.step()

          ????if?args.save_model:
          ????????torch.save(model.state_dict(),?"mnist_cnn.pt")

          5.2 Lightning版本訓練MNIST

          第一部分,也就是歸為研究代碼,主要是模型的結構、訓練等部分。被抽象為LightningModule類。

          class?LitClassifier(pl.LightningModule):
          ????def?__init__(self,?hidden_dim=128,?learning_rate=1e-3):
          ????????super().__init__()
          ????????self.save_hyperparameters()

          ????????self.l1?=?torch.nn.Linear(28?*?28,?self.hparams.hidden_dim)
          ????????self.l2?=?torch.nn.Linear(self.hparams.hidden_dim,?10)

          ????def?forward(self,?x):
          ????????x?=?x.view(x.size(0),?-1)
          ????????x?=?torch.relu(self.l1(x))
          ????????x?=?torch.relu(self.l2(x))
          ????????return?x

          ????def?training_step(self,?batch,?batch_idx):
          ????????x,?y?=?batch
          ????????y_hat?=?self(x)
          ????????loss?=?F.cross_entropy(y_hat,?y)
          ????????return?loss

          ????def?validation_step(self,?batch,?batch_idx):
          ????????x,?y?=?batch
          ????????y_hat?=?self(x)
          ????????loss?=?F.cross_entropy(y_hat,?y)
          ????????self.log('valid_loss',?loss)

          ????def?test_step(self,?batch,?batch_idx):
          ????????x,?y?=?batch
          ????????y_hat?=?self(x)
          ????????loss?=?F.cross_entropy(y_hat,?y)
          ????????self.log('test_loss',?loss)

          ????def?configure_optimizers(self):
          ????????return?torch.optim.Adam(self.parameters(),?lr=self.hparams.learning_rate)

          ????@staticmethod
          ????def?add_model_specific_args(parent_parser):
          ????????parser?=?ArgumentParser(parents=[parent_parser],?add_help=False)
          ????????parser.add_argument('--hidden_dim',?type=int,?default=128)
          ????????parser.add_argument('--learning_rate',?type=float,?default=0.0001)
          ????????return?parser

          可以看出,和PyTorch版本最大的不同之處在于多了幾個流程處理函數(shù):

          • training_step,相當于訓練過程中處理一個batch的內容
          • validation_step,相當于驗證過程中處理一個batch的內容
          • test_step, 同上
          • configure_optimizers, 這部分用于處理optimizer和scheduler
          • add_module_specific_args代表這部分控制的是與模型相關的參數(shù)

          除此以外,main函數(shù)主要有以下幾個部分:

          • args參數(shù)處理
          • data部分
          • model部分
          • 訓練部分
          • 測試部分
          def?cli_main():
          ????pl.seed_everything(1234)?#?這個是用于固定seed用

          ????#?args
          ????parser?=?ArgumentParser()
          ????parser?=?pl.Trainer.add_argparse_args(parser)
          ????parser?=?LitClassifier.add_model_specific_args(parser)
          ????parser?=?MNISTDataModule.add_argparse_args(parser)
          ????args?=?parser.parse_args()

          ????#?data
          ????dm?=?MNISTDataModule.from_argparse_args(args)

          ????#?model
          ????model?=?LitClassifier(args.hidden_dim,?args.learning_rate)

          ????#?training
          ????trainer?=?pl.Trainer.from_argparse_args(args)
          ????trainer.fit(model,?datamodule=dm)

          ????result?=?trainer.test(model,?datamodule=dm)
          ????pprint(result)

          可以看出Lightning版本的代碼代碼量略低于PyTorch版本,但是同時將一些細節(jié)忽略了,比如訓練的具體流程直接使用fit搞定,這樣不會出現(xiàn)忘記清空optimizer等低級錯誤。

          6. 評價

          總體來說,PyTorch Lightning是一個發(fā)展迅速的框架,如同fastai、keras、ignite等二次封裝的框架一樣,雖然易用性得到了提升,讓用戶可以通過更短的代碼完成任務,但是遇到錯誤的時候,往往就需要查看API甚至涉及框架源碼才能夠解決。前者降低門檻,后者略微提升了門檻。

          筆者使用這個框架大概一周了,從使用者角度來談談優(yōu)缺點:

          6.1 優(yōu)點

          • 簡化了部分代碼,之前如果要轉到GPU上,需要用to(device)方法判斷,然后轉過去。有了PyTorch lightning的幫助,可以自動幫你處理,通過設置trainer中的gpus參數(shù)即可。
          • 提供了一些有用的工具,比如混合精度訓練、分布式訓練、Horovod
          • 代碼移植更加容易
          • API比較完善,大部分都有例子,少部分講的不夠詳細。
          • 社區(qū)還是比較活躍的,如果有問題,可以在issue中提問。
          • 實驗結果整理的比較好,將每次實驗劃分為version 0-n,同時可以用tensorboard比較多個實驗,非常友好。

          6.2 缺點

          • 引入了一些新的概念,進一步加大了使用者的學習成本,比如pl_bolts
          • 很多原本習慣于在Pytorch中使用的功能,在PyTorch Lightning中必須查API才能使用,比如我想用scheduler,就需要去查API,然后發(fā)現(xiàn)在configure_optimizers函數(shù)中實現(xiàn),然后模仿demo實現(xiàn),因此也帶來了一定的門檻。
          • 有些報錯比較迷,筆者曾遇到過執(zhí)行的時候發(fā)現(xiàn)多線程出問題,比較難以排查,最后通過更改distributed_backend得到了解決。遇到新的坑要去API里找答案,如果沒有解決繼續(xù)去Issue里找答案。

          7. 參考

          • 【1】 https://zhuanlan.zhihu.com/p/120331610

          • 【2】https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html

          • 【3】https://github.com/pytorch/examples/blob/master/mnist/main.py

          • 【4】 https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/simple_image_classifier.py

          瀏覽 56
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  大香蕉尹人在看 | 亚洲最大免费在线观看视频 | 麻豆成人在线 | 黄片网站进入口 | 欧美干在线观看 |