PyTorch Lightning: The Complete Guide!
極市導(dǎo)讀
?作者實(shí)踐中發(fā)現(xiàn)Pytorch-Lightning庫(kù)并不容易學(xué)習(xí),當(dāng)出現(xiàn)了一些稍微高階的要求時(shí)會(huì)在相似工程代碼上花費(fèi)大量時(shí)間,Debug也是這些代碼花的時(shí)間最多,同時(shí)核心的訓(xùn)練邏輯也漸漸被這些工程代碼蓋過(guò)。那么有沒(méi)有更好的解決方案,甚至能一鍵解決所有這些問(wèn)題呢?請(qǐng)看本文。?
寫(xiě)在前面
Pytorch-Lightning這個(gè)庫(kù)我“發(fā)現(xiàn)”過(guò)兩次。第一次發(fā)現(xiàn)時(shí),感覺(jué)它很重很難學(xué),而且似乎自己也用不上。但是后面隨著做的項(xiàng)目開(kāi)始出現(xiàn)了一些稍微高階的要求,我發(fā)現(xiàn)我總是不斷地在相似工程代碼上花費(fèi)大量時(shí)間,Debug也是這些代碼花的時(shí)間最多,而且漸漸產(chǎn)生了一個(gè)矛盾之處:如果想要更多更好的功能,如TensorBoard支持,Early Stop,LR Scheduler,分布式訓(xùn)練,快速測(cè)試等,代碼就無(wú)可避免地變得越來(lái)越長(zhǎng),看起來(lái)也越來(lái)越亂,同時(shí)核心的訓(xùn)練邏輯也漸漸被這些工程代碼蓋過(guò)。那么有沒(méi)有更好的解決方案,甚至能一鍵解決所有這些問(wèn)題呢?
于是我第二次發(fā)現(xiàn)了Pytorch-Lightning。
真香。
但是問(wèn)題還是來(lái)了。這個(gè)框架并沒(méi)有因?yàn)橄愣兊酶右讓W(xué)。官網(wǎng)的教程很豐富,可以看出來(lái)開(kāi)發(fā)者們?cè)谂ψ隽恕5呛芏嘞噙B的知識(shí)點(diǎn)都被分布在了不同的版塊里,還有一些核心的理解要點(diǎn)并沒(méi)有被強(qiáng)調(diào)出來(lái),而是小字帶過(guò),這讓我想做一個(gè)普惠的教程,包含所有我在學(xué)習(xí)過(guò)程中認(rèn)為重要的概念,好用的參數(shù),一些注意點(diǎn)、坑點(diǎn),大量的示例代碼段和一些核心問(wèn)題的集中講解。
最后,第三部分提供了一個(gè)我總結(jié)出來(lái)的易用于大型項(xiàng)目、容易遷移、易于復(fù)用的模板,有興趣的可以去GitHub—?https://github.com/miracleyoo/pytorch-lightning-template?試用。
Core Idea
Pytorch-Lighting 的一大特點(diǎn)是把模型和系統(tǒng)分開(kāi)來(lái)看。模型是像Resnet18, RNN之類的純模型, 而系統(tǒng)定義了一組模型如何相互交互,如GAN(生成器網(wǎng)絡(luò)與判別器網(wǎng)絡(luò))、Seq2Seq(Encoder與Decoder網(wǎng)絡(luò))和Bert。同時(shí),有時(shí)候問(wèn)題只涉及一個(gè)模型,那么這個(gè)系統(tǒng)則可以是一個(gè)通用的系統(tǒng),用于描述模型如何使用,并可以被復(fù)用到很多其他項(xiàng)目。 Pytorch-Lighting 的核心設(shè)計(jì)思想是“自給自足”。每個(gè)網(wǎng)絡(luò)也同時(shí)包含了如何訓(xùn)練、如何測(cè)試、優(yōu)化器定義等內(nèi)容。

推薦使用方法
這一部分放在最前面,因?yàn)槿膬?nèi)容太長(zhǎng),如果放后面容易忽略掉這部分精華。
Pytorch-Lightning 是一個(gè)很好的庫(kù),或者說(shuō)是pytorch的抽象和包裝。它的好處是可復(fù)用性強(qiáng),易維護(hù),邏輯清晰等。缺點(diǎn)也很明顯,這個(gè)包需要學(xué)習(xí)和理解的內(nèi)容還是挺多的,或者換句話說(shuō),很重。如果直接按照官方的模板寫(xiě)代碼,小型project還好,如果是大型項(xiàng)目,有復(fù)數(shù)個(gè)需要調(diào)試驗(yàn)證的模型和數(shù)據(jù)集,那就不太好辦,甚至更加麻煩了。經(jīng)過(guò)幾天的摸索和調(diào)試,我總結(jié)出了下面這樣一套好用的模板,也可以說(shuō)是對(duì)Pytorch-Lightning的進(jìn)一步抽象。
歡迎大家嘗試這一套代碼風(fēng)格,如果用習(xí)慣的話還是相當(dāng)方便復(fù)用的,也不容易半道退坑。
root-
    |-data
        |-__init__.py
        |-data_interface.py
        |-xxxdataset1.py
        |-xxxdataset2.py
        |-...
    |-model
        |-__init__.py
        |-model_interface.py
        |-xxxmodel1.py
        |-xxxmodel2.py
        |-...
    |-main.py
如果對(duì)每個(gè)模型直接上plmodule,對(duì)于已有項(xiàng)目、別人的代碼等的轉(zhuǎn)換將相當(dāng)耗時(shí)。另外,這樣的話,你需要給每個(gè)模型都加上一些相似的代碼,如training_step,validation_step。顯然,這并不是我們想要的,如果真的這樣做,不但不易于維護(hù),反而可能會(huì)更加雜亂。同理,如果把每個(gè)數(shù)據(jù)集類都直接轉(zhuǎn)換成pl的DataModule,也會(huì)面臨相似的問(wèn)題。基于這樣的考量,我建議使用上述架構(gòu):
主目錄下只放一個(gè) main.py文件。data和modle兩個(gè)文件夾中放入__init__.py文件,做成包。這樣方便導(dǎo)入。兩個(gè)init文件分別是:from .data_interface import DInterface和from .model_interface import MInterface在 data_interface中建立一個(gè)class DInterface(pl.LightningDataModule):用作所有數(shù)據(jù)集文件的接口。__init__()函數(shù)中import相應(yīng)Dataset類,setup()進(jìn)行實(shí)例化,并老老實(shí)實(shí)加入所需要的的train_dataloader,?val_dataloader,?test_dataloader函數(shù)。這些函數(shù)往往都是相似的,可以用幾個(gè)輸入args控制不同的部分。同理,在 model_interface中建立class MInterface(pl.LightningModule):類,作為模型的中間接口。__init__()函數(shù)中import相應(yīng)模型類,然后老老實(shí)實(shí)加入configure_optimizers,?training_step,?validation_step等函數(shù),用一個(gè)接口類控制所有模型。不同部分使用輸入?yún)?shù)控制。main.py函數(shù)只負(fù)責(zé):定義parser,添加parse項(xiàng);選好需要的callback函數(shù);實(shí)例化MInterface,?DInterface,?Trainer。
That's it.
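As a rough illustration, here is a minimal sketch of what DInterface could look like (the class, file, and argument names are assumptions taken from this layout rather than fixed Lightning API; the full template linked below selects datasets and models dynamically by name):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class DInterface(pl.LightningDataModule):
    def __init__(self, dataset='xxxdataset1', batch_size=64, num_workers=8, **kwargs):
        super().__init__()
        self.dataset_name = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs

    def setup(self, stage=None):
        # pick and instantiate the concrete torch Dataset by name
        if self.dataset_name == 'xxxdataset1':
            from .xxxdataset1 import XxxDataset1  # hypothetical Dataset class
            self.trainset = XxxDataset1(train=True, **self.kwargs)
            self.valset = XxxDataset1(train=False, **self.kwargs)

    def train_dataloader(self):
        return DataLoader(self.trainset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.valset, batch_size=self.batch_size,
                          num_workers=self.num_workers)

MInterface mirrors this on the model side: its __init__() imports and instantiates the chosen model class, and configure_optimizers / training_step / validation_step are written once against that generic inner model.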
The complete template can be found on GitHub: https://github.com/miracleyoo/pytorch-lightning-template
Lightning Module
簡(jiǎn)介
主頁(yè):https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
三個(gè)核心組件:
- the model
- the optimizer(s)
- the train/val/test steps
數(shù)據(jù)流偽代碼:
outs = []
for batch in data:
    out = training_step(batch)
    outs.append(out)
training_epoch_end(outs)
等價(jià)Lightning代碼:
def training_step(self, batch, batch_idx):
    prediction = ...
    return prediction

def training_epoch_end(self, training_step_outputs):
    for prediction in training_step_outputs:
        # do something with these predictions
        ...
我們需要做的,就是像填空一樣,填這些函數(shù)。
組件與函數(shù)
API頁(yè)面:https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html%23lightningmodule-api
一個(gè)Pytorch-Lighting 模型必須含有的部件是:
- `__init__`: initialization, including the definition of the model and the system.
- `training_step(self, batch, batch_idx)`: the function that processes each batch.
參數(shù):
- batch (Tensor | (Tensor, …) | [Tensor, …]) – The output of your DataLoader. A tensor, tuple or list.
- batch_idx (int) – Integer displaying index of this batch
- optimizer_idx (int) – When using multiple optimizers, this argument will also be present.
- hiddens (Tensor) – Passed in if truncated_bptt_steps > 0.
Return value: any of
- Tensor – the loss tensor
- dict – a dictionary, which may contain any keys but must contain the key 'loss'
- None – training skips to the next batch
返回值無(wú)論如何也需要有一個(gè)loss量。如果是字典,要有這個(gè)key。沒(méi)loss這個(gè)batch就被跳過(guò)了。例:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {'loss': loss, 'hiddens': hiddens}
`configure_optimizers`: defines the optimizers. It returns a single optimizer, several optimizers, or two lists (optimizers, schedulers). For example:
# (Adam comes from torch.optim; ExponentialLR / CosineAnnealingLR from torch.optim.lr_scheduler)

# most cases
def configure_optimizers(self):
    opt = Adam(self.parameters(), lr=1e-3)
    return opt

# multiple optimizer case (e.g.: GAN)
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    return generator_opt, discriminator_opt

# example with learning rate schedulers
def configure_optimizers(self):
    generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
    discriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
    discriminator_sched = CosineAnnealingLR(discriminator_opt, T_max=10)
    return [generator_opt, discriminator_opt], [discriminator_sched]

# example with step-based learning rate schedulers
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
                 'interval': 'step'}  # called after each training step
    dis_sched = CosineAnnealingLR(dis_opt, T_max=10)  # called every epoch
    return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1}
    )
Parts that can optionally be defined:
- `forward`: same as in a regular nn.Module, used for inference. Internally it is called as `y = self(batch)`.
- `training_step_end`: only needed when training on multiple GPUs/nodes and the result requires a joint computation over all outputs, such as a softmax. Likewise `validation_step_end` / `test_step_end`.
- `training_epoch_end`: called at the end of a training epoch; its input is a list containing what each `training_step()` returned; it returns None.
- `validation_step(self, batch, batch_idx)` / `test_step(self, batch, batch_idx)`: no constraints on the return value; you do not have to output a `val_loss`. A small sketch follows this list.
- `validation_epoch_end` / `test_epoch_end`
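As an example, a minimal validation_step / validation_epoch_end pair could look like the following (just a sketch; it assumes the usual `import torch` and `import torch.nn.functional as F`, and the metric names are placeholders):

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)                         # goes through forward()
    val_loss = F.cross_entropy(y_hat, y)
    self.log('val_loss', val_loss)          # logged per epoch by default in validation
    return val_loss

def validation_epoch_end(self, validation_step_outputs):
    # receives the list of values returned by validation_step over the epoch
    avg_loss = torch.stack(validation_step_outputs).mean()
    self.log('val_loss_epoch', avg_loss)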
工具函數(shù)有:
- `freeze`: freezes all weights for inference. Use it only after training is finished and only testing remains.
- `print`: the built-in print also works, but on a distributed system it would print once per process; `self.print()` prints only once.
- `log`: for loggers such as TensorBoard, every logged scalar has a corresponding x-axis value, which may be a batch number or an epoch number. `on_step=True` means the logged value uses the current batch/step as its x-coordinate, while `on_epoch=True` means the value is accumulated over the whole epoch and logged with the current epoch as its x-coordinate.
| LightningModule Hook | on_step | on_epoch | prog_bar | logger |
|---|---|---|---|---|
| training_step | T | F | F | T |
| training_step_end | T | F | F | T |
| training_epoch_end | F | T | F | T |
| validation_step | F | T | F | T |
| validation_step_end | F | T | F | T |
| validation_epoch_end* | F | T | F | T |
* also applies to the test loop
參數(shù):
- name (str) – key name
- value (Any) – value name
- prog_bar (bool) – if True logs to the progress bar
- logger (bool) – if True logs to the logger
- on_step (Optional[bool]) – if True logs at this step. None auto-logs at the training_step but not validation/test_step
- on_epoch (Optional[bool]) – if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
- reduce_fx (Callable) – reduction function over step values for end of epoch. torch.mean by default
- tbptt_reduce_fx (Callable) – function to reduce on truncated back prop
- tbptt_pad_token (int) – token to use for padding
- enable_graph (bool) – if True, will not auto detach the graph
- sync_dist (bool) – if True, reduces the metric across GPUs/TPUs
- sync_dist_op (Union[Any, str]) – the op to sync across GPUs/TPUs
- sync_dist_group (Optional[Any]) – the ddp group
- `log_dict`: the only difference from `log` is that the name and value are replaced by a dictionary, so several values are logged at once, e.g. `values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}; self.log_dict(values)`.
- `save_hyperparameters`: stores all hyperparameters passed to `__init__`. They can later be accessed via `self.hparams.argX`, and the hyperparameter table is also saved to file (see the sketch after this list).
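For instance, a small sketch of save_hyperparameters (argument names are placeholders):

import torch
from torch import nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, lr=1e-3, hidden_dim=128):
        super().__init__()
        self.save_hyperparameters()   # lr and hidden_dim now live in self.hparams
        self.layer = nn.Linear(self.hparams.hidden_dim, 10)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

The saved hyperparameters also end up inside the checkpoint, which is what later allows load_from_checkpoint to rebuild the model without re-passing the arguments.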
函數(shù)內(nèi)建變量:
- `device`: use `self.device` to build device-agnostic tensors, e.g. `z = torch.rand(2, 3, device=self.device)`.
- `hparams`: holds all hyperparameters saved earlier.
- `precision`: the precision in use, typically 32 or 16.
要點(diǎn)
如果準(zhǔn)備使用DataParallel,在寫(xiě)training_step的時(shí)候需要調(diào)用forward函數(shù),z=self(x)
Template
class LitModel(pl.LightningModule):
    def __init__(...):
    def forward(...):
    def training_step(...):
    def training_step_end(...):
    def training_epoch_end(...):
    def validation_step(...):
    def validation_step_end(...):
    def validation_epoch_end(...):
    def test_step(...):
    def test_step_end(...):
    def test_epoch_end(...):
    def configure_optimizers(...):
    def any_extra_hook(...):
Trainer
基礎(chǔ)使用
model = MyLightningModule()
trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
如果連validation_step都沒(méi)有,那val_dataloader也就算了。
偽代碼與hooks
Hooks頁(yè)面:https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html%23hooks
def fit(...):
    on_fit_start()
    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()
    for gpu/tpu in gpu/tpus:
        train_on_device(model.copy())
    on_fit_end()

def train_on_device(model):
    # setup is called PER DEVICE
    setup()
    configure_optimizers()
    on_pretrain_routine_start()
    for epoch in epochs:
        train_loop()
    teardown()

def train_loop():
    on_train_epoch_start()
    train_outs = []
    for train_batch in train_dataloader():
        on_train_batch_start()
        # ----- train_step methods -------
        out = training_step(batch)
        train_outs.append(out)
        loss = out.loss
        backward()
        on_after_backward()
        optimizer_step()
        on_before_zero_grad()
        optimizer_zero_grad()
        on_train_batch_end(out)
        if should_check_val:
            val_loop()
    # end training epoch
    logs = training_epoch_end(outs)

def val_loop():
    model.eval()
    torch.set_grad_enabled(False)
    on_validation_epoch_start()
    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()
        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)
        on_validation_batch_end(out)
    validation_epoch_end(val_outs)
    on_validation_epoch_end()
    # set up for train
    model.train()
    torch.set_grad_enabled(True)
推薦參數(shù)
參數(shù)介紹(附視頻)—?https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html%23trainer-flags
類定義與默認(rèn)參數(shù)—?https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html%23trainer-class-api
default_root_dir: the default save location. All experiment variables and weights are stored in this folder. I recommend a separate folder per model; every new training run creates a new version_x subfolder inside it.
max_epochs: the maximum number of training epochs. `trainer = Trainer(max_epochs=1000)`
min_epochs: the minimum number of training epochs, useful together with early stopping.
auto_scale_batch_size: automatically pick a suitable batch size before any training starts.
# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size='binsearch')

# call tune to find the batch size
trainer.tune(model)
auto_select_gpus: automatically choose suitable GPUs. This is especially useful when some GPUs are running in exclusive mode, for example:
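A small sketch following the same pattern as the other flags (here `gpus=2` just means "give me two GPUs"):

# default: take GPU indices 0 and 1 regardless of whether they are busy
trainer = Trainer(gpus=2, auto_select_gpus=False)

# pick two GPUs that are actually available (useful with exclusive mode)
trainer = Trainer(gpus=2, auto_select_gpus=True)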
auto_lr_find: automatically find a good initial learning rate, using the technique from the paper https://arxiv.org/abs/1506.01186. It only takes effect when `trainer.tune(model)` is executed.
# run learning rate finder, results override hparams.learning_rate
trainer = Trainer(auto_lr_find=True)

# run learning rate finder, results override hparams.my_lr_arg
trainer = Trainer(auto_lr_find='my_lr_arg')

# call tune to find the lr
trainer.tune(model)
precision: numerical precision. The default is 32; using 16 reduces memory consumption and allows larger batches.
# default used by the Trainer
trainer = Trainer(precision=32)

# 16-bit precision
trainer = Trainer(precision=16, gpus=1)
val_check_interval: how often to run validation. The default is 1.0 (once per training epoch); 0.25 validates 4 times per epoch, and 1000 validates every 1000 training batches.
- Use a float to check within a training epoch: the value is the fraction of the epoch after which validation runs.
- Use an int to check every n steps (batches): validation runs every n training batches.
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)

# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)

# check validation set every 1000 training batches
# use this when using iterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
gpus: controls the number of GPUs used. When set to None, the CPU is used.
# default used by the Trainer (ie: train on CPU)
trainer = Trainer(gpus=None)

# equivalent
trainer = Trainer(gpus=0)

# int: train on 2 gpus
trainer = Trainer(gpus=2)

# list: train on GPUs 1, 4 (by bus ordering)
trainer = Trainer(gpus=[1, 4])
trainer = Trainer(gpus='1, 4')  # equivalent

# -1: train on all gpus
trainer = Trainer(gpus=-1)
trainer = Trainer(gpus='-1')  # equivalent

# combine with num_nodes to train on multiple GPUs across nodes
# uses 8 gpus in total
trainer = Trainer(gpus=2, num_nodes=4)

# train only on GPUs 1 and 4 across nodes
trainer = Trainer(gpus=[1, 4], num_nodes=4)
limit_train_batches: the fraction of the training data to use. Handy when there is too much data or while debugging. A float between 0 and 1 selects a fraction, an int a fixed number of batches. There are also limit_test_batches and limit_val_batches.
# default used by the Trainer
trainer = Trainer(limit_train_batches=1.0)

# run through only 25% of the training set each epoch
trainer = Trainer(limit_train_batches=0.25)

# run through only 10 batches of the training set each epoch
trainer = Trainer(limit_train_batches=10)
fast_dev_run: if True, runs just one batch of train, val and test and then stops; an int n runs n batches (see below). For debugging only. Setting this argument disables the tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like LearningRateLogger, and runs for only 1 epoch.
# default used by the Trainer
trainer = Trainer(fast_dev_run=False)

# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)
.fit()函數(shù)
Trainer.fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None): the first argument must be the model; after that you can pass either a LightningDataModule or a plain training DataLoader. If a validation step is defined, a validation DataLoader is needed as well.
參數(shù):
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- model (LightningModule) – Model to fit.
- train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
- val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
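A short sketch of the two typical ways to call it (assuming LitModel, MNISTDataModule and the two DataLoaders are defined elsewhere):

model = LitModel()
dm = MNISTDataModule()
trainer = Trainer(gpus=1, max_epochs=10)

# 1. pass plain DataLoaders
trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

# 2. or hand over a LightningDataModule that provides the loaders itself
trainer.fit(model, datamodule=dm)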
其他要點(diǎn)
- .test() does not run unless you call it directly: `trainer.test()` (a short example follows this list).
- .test() automatically loads the best model.
- model.eval() and torch.no_grad() are called automatically during testing.
- By default, Trainer() runs on the CPU.
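Continuing the sketch above, for example:

trainer.fit(model, datamodule=dm)

# not run automatically; by default this reloads the best checkpoint found during fit
trainer.test()

# or test an explicit model / datamodule
trainer.test(model, datamodule=dm)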
使用樣例
1.手動(dòng)添加命令行參數(shù):
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer

def main(hparams):
    model = LightningModule()
    trainer = Trainer(gpus=hparams.gpus)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--gpus', default=None)
    args = parser.parse_args()
    main(args)
2.自動(dòng)添加所有Trainer會(huì)用到的命令行參數(shù):
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer

def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    # add all Trainer arguments (--gpus, --max_epochs, ...) to the parser
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)
3. Mixed: use the Trainer arguments together with some custom ones, such as model hyperparameters:
from argparse import ArgumentParser
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer

def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', type=int, default=128)
    # add all Trainer arguments to the same parser
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)
所有參數(shù)
Trainer.__init__(logger=True, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, num_processes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, log_gpu_memory=None, progress_bar_refresh_rate=None, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=None, min_epochs=None, max_steps=None, min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, limit_predict_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=50, accelerator=None, sync_batchnorm=False, precision=32, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, plugins=None, amp_backend='native', amp_level='O2', distributed_backend=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', stochastic_weight_avg=False)
What Logging and Returning the Loss Actually Do
To add a training loop use the training_step method.
class LitClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
無(wú)論是training_step,還是validation_step,test_step返回值都是loss。返回的loss會(huì)被用一個(gè)list收集起來(lái)。
Under the hood, Lightning does the following (pseudocode):
# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # forward
    loss = training_step(batch)
    losses.append(loss.detach())

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()
Training epoch-level metrics
If you want to calculate epoch-level metrics and log them, use the .log method.
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss
If you call .log() inside an x_step function, that value is recorded step by step. Every logged variable is recorded: each step produces a dict of the logged values, and over an epoch these dicts are collected into a list of dicts.
The .log object automatically reduces the requested metrics across the full epoch. Here’s the pseudocode of what it does under the hood:
outs = []
for batch in train_dataloader:
    # forward
    out = training_step(batch)
    outs.append(out)

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
Train epoch-level operations
If you need to do something with all the outputs of each training_step, override training_epoch_end yourself.
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
    for pred in training_step_outputs:
        # do something
        ...
The matching pseudocode is:
outs = []
for batch in train_dataloader:
    # forward
    out = training_step(batch)
    outs.append(out)

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

training_epoch_end(outs)
DataModule
主頁(yè):https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
介紹
首先,這個(gè)DataModule和之前寫(xiě)的Dataset完全不沖突。前者是后者的一個(gè)包裝,并且這個(gè)包裝可以被用于多個(gè)torch Dataset 中。在我看來(lái),其最大的作用就是把各種train/val/test劃分、DataLoader初始化之類的重復(fù)代碼通過(guò)包裝類的方式得以被簡(jiǎn)單的復(fù)用。
具體作用項(xiàng)目:
- Download instructions
- Processing instructions
- Split instructions
- Train dataloader
- Val dataloader(s)
- Test dataloader(s)
Second, pl.LightningDataModule behaves like a torch Dataset with extra capabilities, which include:
prepare_data(self):
最最開(kāi)始的時(shí)候,進(jìn)行一些無(wú)論GPU有多少只要執(zhí)行一次的操作,如寫(xiě)入磁盤(pán)的下載操作、分詞操作(tokenize)等。 這里是一勞永逸式準(zhǔn)備數(shù)據(jù)的函數(shù)。 由于只在單線程中調(diào)用,不要在這個(gè)函數(shù)中進(jìn)行 self.x=y似的賦值操作。但如果是自己用而不是給大眾分發(fā)的話,這個(gè)函數(shù)可能并不需要調(diào)用,因?yàn)閿?shù)據(jù)提前處理好就好了。
setup(self, stage=None):
實(shí)例化數(shù)據(jù)集(Dataset),并進(jìn)行相關(guān)操作,如:清點(diǎn)類數(shù),劃分train/val/test集合等。 參數(shù) stage用于指示是處于訓(xùn)練周期(fit)還是測(cè)試周期(test),其中,fit周期需要構(gòu)建train和val兩者的數(shù)據(jù)集。setup函數(shù)不需要返回值。初始化好的train/val/test set直接賦值給self即可。
train_dataloader/val_dataloader/test_dataloader:
Initializes and returns a DataLoader.
Example
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)
要點(diǎn)
If you define a self.dims attribute in the DataModule, you can later call dm.size() to retrieve it.
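For the MNIST DataModule above, for instance (a tiny sketch; size() simply reads self.dims):

dm = MNISTDataModule()
dm.size()        # (1, 28, 28)
dm.num_classes   # 10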
Saving and Loading
主頁(yè):https://pytorch-lightning.readthedocs.io/en/latest/common/weights_loading.html
Saving
ModelCheckpoint docs: https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.callbacks.ModelCheckpoint.html%23pytorch_lightning.callbacks.ModelCheckpoint
ModelCheckpoint: 自動(dòng)儲(chǔ)存的callback module。默認(rèn)情況下training過(guò)程中只會(huì)自動(dòng)儲(chǔ)存最新的模型與相關(guān)參數(shù),而用戶可以通過(guò)這個(gè)module自定義。如觀測(cè)一個(gè)val_loss的量,并儲(chǔ)存top 3好的模型,且同時(shí)儲(chǔ)存最后一個(gè)epoch的模型,等等。例:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='sample-mnist-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
    save_last=True
)

trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20, callbacks=[checkpoint_callback])
另外,也可以手動(dòng)存儲(chǔ)checkpoint:? trainer.save_checkpoint("example.ckpt")ModelCheckpoint?Callback中,如果save_weights_only =True,那么將會(huì)只儲(chǔ)存模型的權(quán)重(相當(dāng)于model.save_weights(filepath)),反之會(huì)儲(chǔ)存整個(gè)模型(相當(dāng)于model.save(filepath))。
Loading
load一個(gè)模型,包括它的weights、biases和超參數(shù):
model = MyLightningModule.load_from_checkpoint(PATH)
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint

model.eval()
y_hat = model(x)
load模型時(shí)替換一些超參數(shù):
from torch import nn
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)

# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)

# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)

# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
Restoring the full training state: this loads everything about the model plus everything related to training, such as the model, epoch, step, LR schedulers, apex state, and so on.
model = LitModel()
trainer = Trainer(resume_from_checkpoint='some/path/to/my_checkpoint.ckpt')

# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model)
Callbacks
Callback 是一個(gè)自包含的程序,可以與訓(xùn)練流程交織在一起,而不會(huì)污染主要的研究邏輯。
Callback 并非只會(huì)在epoch結(jié)尾調(diào)用。pytorch-lightning 提供了數(shù)十個(gè)hook(接口,調(diào)用位置)可供選擇,也可以自定義callback,實(shí)現(xiàn)任何想實(shí)現(xiàn)的模塊。
推薦使用方式是,隨問(wèn)題和項(xiàng)目變化的操作,這些函數(shù)寫(xiě)到lightning module里面,而相對(duì)獨(dú)立,相對(duì)輔助性的,需要復(fù)用的內(nèi)容則可以定義單獨(dú)的模塊,供后續(xù)方便地插拔使用。
Callbacks推薦
內(nèi)建 Callbacks:https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html%23built-in-callbacks
EarlyStopping(monitor='early_stop_on', min_delta=0.0, patience=3, verbose=False, mode='min', strict=True): stops training early when a monitored quantity has shown no improvement for several epochs.
參數(shù):
- monitor (str) – quantity to be monitored. Default: 'early_stop_on'.
- min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement. Default: 0.0.
- patience (int) – number of validation epochs with no improvement after which training will be stopped. Default: 3.
- verbose (bool) – verbosity mode. Default: False.
- mode (str) – one of 'min', 'max'. In 'min' mode, training stops when the monitored quantity has stopped decreasing; in 'max' mode it stops when the quantity has stopped increasing.
- strict (bool) – whether to crash the training if monitor is not found in the validation metrics. Default: True.
Example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping('val_loss')
trainer = Trainer(callbacks=[early_stopping])
- ModelCheckpoint: see Saving and Loading above.
- PrintTableMetricsCallback (from pl_bolts): prints a summary table of the results after every epoch.
from pl_bolts.callbacks import PrintTableMetricsCallback

callback = PrintTableMetricsCallback()
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(...)

# ------------------------------
# at the end of every epoch it will print
# ------------------------------
# loss│train_loss│val_loss│epoch
# ──────────────────────────────
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
Logging
The default logger is TensorBoard, but you can specify any mainstream logger framework, such as Comet.ml, MLflow, Neptune, or plain CSV files. Several loggers can also be used at the same time.
import os
from pytorch_lightning import loggers as pl_loggers

# Default
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir=os.getcwd(),
    version=None,
    name='lightning_logs'
)
trainer = Trainer(logger=tb_logger)

# Or use the same format as others
tb_logger = pl_loggers.TensorBoardLogger('logs/')

# One Logger
comet_logger = pl_loggers.CometLogger(save_dir='logs/')
trainer = Trainer(logger=comet_logger)

# Save code snapshot
logger = pl_loggers.TestTubeLogger('logs/', create_git_tag=True)

# Multiple Loggers
tb_logger = pl_loggers.TensorBoardLogger('logs/')
comet_logger = pl_loggers.CometLogger(save_dir='logs/')
trainer = Trainer(logger=[tb_logger, comet_logger])
默認(rèn)情況下,每50個(gè)batch log一次,可以通過(guò)調(diào)整參數(shù)。
If you want to log something that is not a scalar, such as images, text, or histograms, you can call self.logger.experiment.add_xxx() directly to do whatever you need.
def training_step(...):
    ...
    # the logger you used (in this case tensorboard)
    tensorboard = self.logger.experiment
    tensorboard.add_image(...)
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)
Using the logs: for TensorBoard, run `tensorboard --logdir ./lightning_logs`. In a Jupyter notebook you can use:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/
在行內(nèi)打開(kāi)TensorBoard。
Tip: if TensorBoard is started on a machine inside a LAN, add the --bind_all flag to access it by hostname:
`tensorboard --logdir lightning_logs --bind_all` -> `http://SERVER-NAME:6006/`
Transfer Learning
主頁(yè):https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html%23transfer-learning
import torch
from torch import nn
import torchvision.models as models
from pytorch_lightning import LightningModule

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...
關(guān)于device操作
LightningModules know what device they are on! Construct tensors on the device directly to avoid CPU->Device transfer.
# bad
t = torch.rand(2, 2).cuda()

# good (self is LightningModule)
t = torch.rand(2, 2, device=self.device)
For tensors that need to be model attributes, it is best practice to register them as buffers in the module's __init__ method:
# bad
self.t = torch.rand(2, 2, device=self.device)

# good
self.register_buffer("t", torch.rand(2, 2))
前面兩段是教程中的文本。然而實(shí)際上有一個(gè)暗坑:
如果你使用了一個(gè)中繼的pl.LightningModule,而這個(gè)module里面實(shí)例化了某個(gè)普通的nn.Module,而這個(gè)模型中又需要內(nèi)部生成一些tensor,比如圖片每個(gè)通道的mean,std之類,那么如果你從pl.LightningModule中pass一個(gè)self.device,實(shí)際上在一開(kāi)始這個(gè)self.device永遠(yuǎn)是cpu。所以如果你在調(diào)用的nn.Module的__init__()中初始化,使用to(device)或干脆什么都不用,結(jié)果就是它永遠(yuǎn)都在cpu上。
但是,經(jīng)過(guò)實(shí)驗(yàn),雖然pl.LightningModule在__init__()階段self.device還是cpu,當(dāng)進(jìn)入了training_step()之后,就迅速變?yōu)榱?code style="margin-right: 2px;margin-left: 2px;padding: 2px 4px;font-size: 14px;overflow-wrap: break-word;border-radius: 4px;color: rgb(30, 107, 184);background-color: rgba(27, 31, 35, 0.05);font-family: 'Operator Mono', Consolas, Monaco, Menlo, monospace;word-break: break-all;">cuda。所以,對(duì)于子模塊,最佳方案是,使用一個(gè)forward中傳入的量,如x,作為一個(gè)reference變量,用type_as函數(shù)將在模型中生成的tensor都放到和這個(gè)參考變量相同的device上即可。
class RDNFuse(nn.Module):
    ...

    def init_norm_func(self, ref):
        # create the tensor on the same device / dtype as the reference tensor
        self.mean = torch.tensor(np.array(self.mean_sen), dtype=torch.float32).type_as(ref)

    def forward(self, x):
        if not hasattr(self, 'mean'):
            self.init_norm_func(x)
Points
pl.seed_everything(1234): fixes the seed of every relevant source of randomness.
When using an LR scheduler, you do not need to call .step() yourself; the Trainer handles that automatically as well.
Relevant docs: https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html%3Fhighlight%3Dscheduler%23
# Single optimizer
for epoch in epochs:
    for batch in data:
        loss = model.training_step(batch, batch_idx, ...)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    for scheduler in schedulers:
        scheduler.step()

# Multiple optimizers
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            disable_grads_for_other_optimizers()
            train_step(opt)
            opt.step()
    for scheduler in schedulers:
        scheduler.step()
關(guān)于劃分train和val集合的方法。與PL無(wú)關(guān),但很常用,兩個(gè)例子:random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
如下:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
Parameters:
- dataset (Dataset) – Dataset to be split (https://pytorch.org/docs/stable/data.html%23torch.utils.data.Dataset)
- lengths – lengths of splits to be produced
- generator (Generator) – Generator used for the random permutation (https://pytorch.org/docs/stable/generated/torch.Generator.html%23torch.Generator)
干貨學(xué)習(xí),點(diǎn)贊三連↓
