實(shí)踐教程 | Pytorch中模型的保存與遷移

極市導(dǎo)讀
?在本篇文章中,筆者首先介紹了模型復(fù)用的幾種典型場景;然后介紹了如何查看Pytorch模型中的相關(guān)參數(shù)信息;接著介紹了如何載入模型、如何進(jìn)行追加訓(xùn)練以及進(jìn)行模型的遷移學(xué)習(xí)等。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿

1 引言
各位朋友大家好,今天要和大家介紹的內(nèi)容是如何在Pytorch框架中對模型進(jìn)行保存和載入、以及模型的遷移和再訓(xùn)練。
一般來說,最常見的場景就是模型完成訓(xùn)練后的推斷過程。一個(gè)網(wǎng)絡(luò)模型在完成訓(xùn)練后通常都需要對新樣本進(jìn)行預(yù)測,此時(shí)就只需要構(gòu)建模型的前向傳播過程,然后載入已訓(xùn)練好的參數(shù)初始化網(wǎng)絡(luò)即可。
第2個(gè)場景就是模型的再訓(xùn)練過程。一個(gè)模型在一批數(shù)據(jù)上訓(xùn)練完成之后需要將其保存到本地,并且可能過了一段時(shí)間后又收集到了一批新的數(shù)據(jù),因此這個(gè)時(shí)候就需要將之前的模型載入進(jìn)行在新數(shù)據(jù)上進(jìn)行增量訓(xùn)練(或者是在整個(gè)數(shù)據(jù)上進(jìn)行全量訓(xùn)練)。
第3個(gè)應(yīng)用場景就是模型的遷移學(xué)習(xí)。這個(gè)時(shí)候就是將別人已經(jīng)訓(xùn)練好的預(yù)模型拿過來,作為你自己網(wǎng)絡(luò)模型參數(shù)的一部分進(jìn)行初始化。例如:你自己在Bert模型的基礎(chǔ)上加了幾個(gè)全連接層來做分類任務(wù),那么你就需要將原始BERT模型中的參數(shù)載入并以此來初始化你的網(wǎng)絡(luò)中的Bert部分的權(quán)重參數(shù)。
在接下來的這篇文章中,筆者就以上述3個(gè)場景為例來介紹如何利用Pytorch框架來完成上述過程。
2 模型的保存與復(fù)用
在Pytorch中,我們可以通過torch.save()和torch.load()來完成上述場景中的主要步驟。下面,筆者將以之前介紹的LeNet5網(wǎng)絡(luò)模型為例來分別進(jìn)行介紹。不過在這之前,我們先來看看Pytorch中模型參數(shù)的保存形式。
2.1 查看網(wǎng)絡(luò)模型參數(shù)
(1)查看參數(shù)
首先定義好LeNet5的網(wǎng)絡(luò)模型結(jié)構(gòu),如下代碼所示:
class?LeNet5(nn.Module):
????def?__init__(self,?):
????????super(LeNet5,?self).__init__()
????????self.conv?=?nn.Sequential(??#?[n,1,28,28]
????????????nn.Conv2d(1,?6,?5,?padding=2),??#?in_channels,?out_channels,?kernel_size
????????????nn.ReLU(),??#?[n,6,24,24]
????????????nn.MaxPool2d(2,?2),??#?kernel_size,?stride??[n,6,14,14]
????????????nn.Conv2d(6,?16,?5),??#?[n,16,10,10]
????????????nn.ReLU(),
????????????nn.MaxPool2d(2,?2))??#?[n,16,5,5]
????????self.fc?=?nn.Sequential(
????????????nn.Flatten(),
????????????nn.Linear(16?*?5?*?5,?120),
????????????nn.ReLU(),
????????????nn.Linear(120,?84),
????????????nn.ReLU(),
????????????nn.Linear(84,?10))
????def?forward(self,?img):
????????output?=?self.conv(img)
????????output?=?self.fc(output)
????????return?output
在定義好LeNet5這個(gè)網(wǎng)絡(luò)結(jié)構(gòu)的類之后,只要我們完成了這個(gè)類的實(shí)例化操作,那么網(wǎng)絡(luò)中對應(yīng)的權(quán)重參數(shù)也都完成了初始化的工作,即有了一個(gè)初始值。同時(shí),我們可以通過如下方式來訪問:
#?Print?model's?state_dict
print("Model's?state_dict:")
for?param_tensor?in?model.state_dict():
????print(param_tensor,?"\t",?model.state_dict()[param_tensor].size())
其輸出的結(jié)果為:
conv.0.weight???torch.Size([6,?1,?5,?5])
conv.0.bias???torch.Size([6])
conv.3.weight???torch.Size([16,?6,?5,?5])
....
....
可以發(fā)現(xiàn),網(wǎng)絡(luò)模型中的參數(shù)model.state_dict()其實(shí)是以字典的形式(實(shí)質(zhì)上是collections模塊中的OrderedDict)保存下來的:
print(model.state_dict().keys())
#?odict_keys(['conv.0.weight',?'conv.0.bias',?'conv.3.weight',?
'conv.3.bias',?'fc.1.weight',?'fc.1.bias',?'fc.3.weight',?'fc.3.bias',?
'fc.5.weight',?'fc.5.bias'])
(2)自定義參數(shù)前綴
同時(shí),這里值得注意的地方有兩點(diǎn):①參數(shù)名中的fc和conv前綴是根據(jù)你在上面定義nn.Sequential()時(shí)的名字所確定的;②參數(shù)名中的數(shù)字表示每個(gè)Sequential()中網(wǎng)絡(luò)層所在的位置。例如將網(wǎng)絡(luò)結(jié)構(gòu)定義成如下形式:
class?LeNet5(nn.Module):
????def?__init__(self,?):
????????super(LeNet5,?self).__init__()
????????self.moon?=?nn.Sequential(??#?[n,1,28,28]
????????????nn.Conv2d(1,?6,?5,?padding=2),??#?in_channels,?out_channels,?kernel_size
????????????nn.ReLU(),??#?[n,6,24,24]
????????????nn.MaxPool2d(2,?2),??#?kernel_size,?stride??[n,6,14,14]
????????????nn.Conv2d(6,?16,?5),??#?[n,16,10,10]
????????????nn.ReLU(),
????????????nn.MaxPool2d(2,?2),
????????????nn.Flatten(),
????????????nn.Linear(16?*?5?*?5,?120),
????????????nn.ReLU(),
????????????nn.Linear(120,?84),
????????????nn.ReLU(),
????????????nn.Linear(84,?10))
那么其參數(shù)名則為:
print(model.state_dict().keys())
odict_keys(['moon.0.weight',?'moon.0.bias',?'moon.3.weight',
?'moon.3.bias',?'moon.7.weight',?'moon.7.bias',?'moon.9.weight',?
'moon.9.bias',?'moon.11.weight',?'moon.11.bias'])
理解了這一點(diǎn)對于后續(xù)我們?nèi)ソ馕龊洼d入一些預(yù)訓(xùn)練模型很有幫助。
除此之外,對于中的優(yōu)化器等,其同樣有對應(yīng)的state_dict()方法來獲取對于的參數(shù),例如:
optimizer?=?torch.optim.SGD(model.parameters(),?lr=0.001,?momentum=0.9)
print("Optimizer's?state_dict:")
for?var_name?in?optimizer.state_dict():
???print(var_name,?"\t",?optimizer.state_dict()[var_name])
????
#
Optimizer's?state_dict:
state???{}
param_groups???[{'lr':?0.001,?'momentum':?0.9,?'dampening':?0,?
'weight_decay':?0,?'nesterov':?False,?
'params':?[140239245300504,?140239208339784,?140239245311360,?
140239245310856,?140239266942480,?140239266942552,?140239266942624,?
140239266942696,?140239266942912,?140239267041352]}]
在介紹完模型參數(shù)的查看方法后,就可以進(jìn)入到模型復(fù)用階段的內(nèi)容介紹了。
2.2 載入模型進(jìn)行推斷
(1) 模型保存
在Pytorch中,對于模型的保存來說是非常簡單的,通常來說通過如下兩行代碼便可以實(shí)現(xiàn):
model_save_path?=?os.path.join(model_save_dir,?'model.pt')
torch.save(model.state_dict(),?model_save_path)
在指定保存的模型名稱時(shí)Pytorch官方建議的后綴為.pt或者.pth(當(dāng)然也不是強(qiáng)制的)。最后,只需要在合適的地方加入第2行代碼即可完成模型的保存。
同時(shí),如果想要在訓(xùn)練過程中保存某個(gè)條件下的最優(yōu)模型,那么應(yīng)該通過如下方式:
best_model_state?=?deepcopy(model.state_dict())?
torch.save(best_model_state,?model_save_path)
而不是:
best_model_state?=?model.state_dict()?
torch.save(best_model_state,?model_save_path)
因?yàn)楹笳?code style="margin-right: 2px;margin-left: 2px;padding: 2px 4px;font-size: 14px;overflow-wrap: break-word;border-radius: 4px;color: rgb(30, 107, 184);background-color: rgba(27, 31, 35, 0.05);font-family: "Operator Mono", Consolas, Monaco, Menlo, monospace;word-break: break-all;">best_model_state得到只是model.state_dict()的引用,它依舊會隨著訓(xùn)練過程而發(fā)生改變。
(2)復(fù)用模型進(jìn)行推斷
在推斷過程中,首先需要完成網(wǎng)絡(luò)的初始化,然后再載入已有的模型參數(shù)來覆蓋網(wǎng)絡(luò)中的權(quán)重參數(shù)即可,示例代碼如下:
def?inference(data_iter,?device,?model_save_dir='./MODEL'):???
????model?=?LeNet5()??#?初始化現(xiàn)有模型的權(quán)重參數(shù)????
????model.to(device)????
????model_save_path?=?os.path.join(model_save_dir,?'model.pt')????
????if?os.path.exists(model_save_path):????????
????????loaded_paras?=?torch.load(model_save_path)????????
????model.load_state_dict(loaded_paras)??#?用本地已有模型來重新初始化網(wǎng)絡(luò)權(quán)重參數(shù)?????
????model.eval()?#?注意不要忘記????
????with?torch.no_grad():????????
????????acc_sum,?n?=?0.0,?0????????
????????for?x,?y?in?data_iter:????????????
????????x,?y?=?x.to(device),?y.to(device)????????????
????????logits?=?model(x)????????????
????????acc_sum?+=?(logits.argmax(1)?==?y).float().sum().item()????????????
????????n?+=?len(y)????????
????????print("Accuracy?in?test?data?is?:",?acc_sum?/?n)
在上述代碼中,4-7行便是用來載入本地模型參數(shù),并用其覆蓋網(wǎng)絡(luò)模型中原有的參數(shù)。這樣,便可以進(jìn)行后續(xù)的推斷工作:
Accuracy?in?test?data?is?:?0.8851
2.3 載入模型進(jìn)行訓(xùn)練
在介紹完模型的保存與復(fù)用之后,對于網(wǎng)絡(luò)的追加訓(xùn)練就很簡單了。最簡便的一種方式就是在訓(xùn)練過程中只保存網(wǎng)絡(luò)權(quán)重,然后在后續(xù)進(jìn)行追加訓(xùn)練時(shí)只載入網(wǎng)絡(luò)權(quán)重參數(shù)初始化網(wǎng)絡(luò)進(jìn)行訓(xùn)練即可,示例如下(完整代碼參見[2]):
??def?train(self):
????????#......
????????model_save_path?=?os.path.join(self.model_save_dir,?'model.pt')
????????if?os.path.exists(model_save_path):
????????????loaded_paras?=?torch.load(model_save_path)
????????????self.model.load_state_dict(loaded_paras)
????????????print("####?成功載入已有模型,進(jìn)行追加訓(xùn)練...")
????????optimizer?=?torch.optim.Adam(self.model.parameters(),?lr=self.learning_rate)??#?定義優(yōu)化器
???????#......
????????for?epoch?in?range(self.epochs):
????????????for?i,?(x,?y)?in?enumerate(train_iter):
????????????????x,?y?=?x.to(device),?y.to(device)
????????????????logits?=?self.model(x)
????????????????#?......
????????????print("Epochs[{}/{}]--acc?on?test?{:.4}".format(epoch,?self.epochs,
??????????????????????????????????????????????self.evaluate(test_iter,?self.model,?device)))
????????????torch.save(self.model.state_dict(),?model_save_path)
這樣,便完成了模型的追加訓(xùn)練:
####?成功載入已有模型,進(jìn)行追加訓(xùn)練...
Epochs[0/5]---batch[938/0]---acc?0.9062---loss?0.2926
Epochs[0/5]---batch[938/100]---acc?0.9375---loss?0.1598
......
除此之外,你也可以在保存參數(shù)的時(shí)候,將優(yōu)化器參數(shù)、損失值等一同保存下來,然后在恢復(fù)模型的時(shí)候連同其它參數(shù)一起恢復(fù),示例如下:
model_save_path?=?os.path.join(model_save_dir,?'model.pt')
torch.save({
????????????'epoch':?epoch,
????????????'model_state_dict':?model.state_dict(),
????????????'optimizer_state_dict':?optimizer.state_dict(),
????????????'loss':?loss,
????????????...
????????????},?model_save_path)
載入方式如下:
checkpoint?=?torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch?=?checkpoint['epoch']
loss?=?checkpoint['loss']
2.4 載入模型進(jìn)行遷移
(1)定義新模型
到目前為止,對于前面兩種應(yīng)用場景的介紹就算完成了,可以發(fā)現(xiàn)總體上并不復(fù)雜。但是對于第3中場景的應(yīng)用來說就會略微復(fù)雜一點(diǎn)。
假設(shè)現(xiàn)在有一個(gè)LeNet6網(wǎng)絡(luò)模型,它是在LeNet5的基礎(chǔ)最后多加了一個(gè)全連接層,其定義如下:
class?LeNet6(nn.Module):
????def?__init__(self,?):
????????super(LeNet6,?self).__init__()
????????self.conv?=?nn.Sequential(??#?[n,1,28,28]
????????????nn.Conv2d(1,?6,?5,?padding=2),??#?in_channels,?out_channels,?kernel_size
????????????nn.ReLU(),??#?[n,6,24,24]
????????????nn.MaxPool2d(2,?2),??#?kernel_size,?stride??[n,6,14,14]
????????????nn.Conv2d(6,?16,?5),??#?[n,16,10,10]
????????????nn.ReLU(),
????????????nn.MaxPool2d(2,?2))??#?[n,16,5,5]
????????self.fc?=?nn.Sequential(
????????????nn.Flatten(),
????????????nn.Linear(16?*?5?*?5,?120),
????????????nn.ReLU(),
????????????nn.Linear(120,?84),
????????????nn.ReLU(),
????????????nn.Linear(84,?64),?
????????????nn.ReLU(),
????????????nn.Linear(64,?10)?)?#?新加入的全連接層
接下來,我們需要將在LeNet5上訓(xùn)練得到的權(quán)重參數(shù)遷移到LeNet6網(wǎng)絡(luò)中去。從上面LeNet6的定義可以發(fā)現(xiàn),此時(shí)盡管只是多加了一個(gè)全連接層,但是倒數(shù)第2層參數(shù)的維度也發(fā)生了變換。因此,對于LeNet6來說只能復(fù)用LeNet5網(wǎng)絡(luò)前面4層的權(quán)重參數(shù)。
(2)查看模型參數(shù)
在拿到一個(gè)模型參數(shù)后,首先我們可以將其載入,然查看相關(guān)參數(shù)的信息:
model_save_path?=?os.path.join('./MODEL',?'model.pt')
loaded_paras?=?torch.load(model_save_path)
for?param_tensor?in?loaded_paras:
????print(param_tensor,?"\t",?loaded_paras[param_tensor].size())
#----?可復(fù)用部分
conv.0.weight???torch.Size([6,?1,?5,?5])
conv.0.bias???torch.Size([6])
conv.3.weight???torch.Size([16,?6,?5,?5])
conv.3.bias???torch.Size([16])
fc.1.weight???torch.Size([120,?400])
fc.1.bias???torch.Size([120])
fc.3.weight???torch.Size([84,?120])
fc.3.bias???torch.Size([84])
#-----?不可復(fù)用部分
fc.5.weight???torch.Size([10,?84])
fc.5.bias???torch.Size([10])
同時(shí),對于LeNet6網(wǎng)絡(luò)的參數(shù)信息為:
model?=?LeNet6()
for?param_tensor?in?model.state_dict():
????print(param_tensor,?"\t",?model.state_dict()[param_tensor].size())
#
conv.0.weight???torch.Size([6,?1,?5,?5])
conv.0.bias???torch.Size([6])
conv.3.weight???torch.Size([16,?6,?5,?5])
conv.3.bias???torch.Size([16])
fc.1.weight???torch.Size([120,?400])
fc.1.bias???torch.Size([120])
fc.3.weight???torch.Size([84,?120])
fc.3.bias???torch.Size([84])
#------?新加入部分
fc.5.weight???torch.Size([64,?84])
fc.5.bias???torch.Size([64])
fc.7.weight???torch.Size([10,?64])
fc.7.bias???torch.Size([10])
在理清楚了新舊模型的參數(shù)后,下面就可以將LeNet5中我們需要的參數(shù)給取出來,然后再換到LeNet6的網(wǎng)絡(luò)中。
(3)模型遷移
雖然本地載入的模型參數(shù)(上面的loaded_paras)和模型初始化后的參數(shù)(上面的model.state_dict())都是一個(gè)字典的形式,但是我們并不能夠直接改變model.state_dict()中的權(quán)重參數(shù)。這里需要先構(gòu)造一個(gè)state_dict然后通過model.load_state_dict()方法來重新初始化網(wǎng)絡(luò)中的參數(shù)。
同時(shí),在這個(gè)過程中我們需要篩選掉本地模型中不可復(fù)用的部分,具體代碼如下:
def?para_state_dict(model,?model_save_dir):
????state_dict?=?deepcopy(model.state_dict())
????model_save_path?=?os.path.join(model_save_dir,?'model.pt')
????if?os.path.exists(model_save_path):
????????loaded_paras?=?torch.load(model_save_path)
????????for?key?in?state_dict:??#?在新的網(wǎng)絡(luò)模型中遍歷對應(yīng)參數(shù)
????????????if?key?in?loaded_paras?and?state_dict[key].size()?==?loaded_paras[key].size():
????????????????print("成功初始化參數(shù):",?key)
????????????????state_dict[key]?=?loaded_paras[key]
????return?state_dict
在上述代碼中,第2行的作用是先拷貝網(wǎng)絡(luò)中(LeNet6)原有的參數(shù);第6-9行則是用本地的模型參數(shù)(LeNet5)中可以復(fù)用的替換掉LeNet6中的對應(yīng)部分,其中第7行就是判斷可用的條件。同時(shí)需要注意的是在不同的情況下篩選的方式可能不一樣,因此具體情況需要具體分析,但是整體邏輯是一樣的。
最后,我們只需要在模型訓(xùn)練之前調(diào)用該函數(shù),然后重新初始化LeNet6中的部分權(quán)重參數(shù)即可[2]:
state_dict?=?para_state_dict(self.model,?self.model_save_dir)
self.model.load_state_dict(state_dict)
訓(xùn)練結(jié)果如下:
成功初始化參數(shù):?conv.0.weight
成功初始化參數(shù):?conv.0.bias
成功初始化參數(shù):?conv.3.weight
成功初始化參數(shù):?conv.3.bias
成功初始化參數(shù):?fc.1.weight
成功初始化參數(shù):?fc.1.bias
成功初始化參數(shù):?fc.3.weight
成功初始化參數(shù):?fc.3.bias
####?成功載入已有模型,進(jìn)行追加訓(xùn)練...
Epochs[0/5]---batch[938/0]---acc?0.1094---loss?2.512
Epochs[0/5]---batch[938/100]---acc?0.9375---loss?0.2141
Epochs[0/5]---batch[938/200]---acc?0.9219---loss?0.2729
Epochs[0/5]---batch[938/300]---acc?0.8906---loss?0.2958
......
Epochs[0/5]---batch[938/900]---acc?0.8906---loss?0.2828
Epochs[0/5]--acc?on?test?0.8808
可以發(fā)現(xiàn),在大約100個(gè)batch之后,模型的準(zhǔn)確率就提升上來了。
3 總結(jié)
在本篇文章中,筆者首先介紹了模型復(fù)用的幾種典型場景;然后介紹了如何查看Pytorch模型中的相關(guān)參數(shù)信息;接著介紹了如何載入模型、如何進(jìn)行追加訓(xùn)練以及進(jìn)行模型的遷移學(xué)習(xí)等。
感謝您的閱讀!
引用
[1]?SAVING AND LOADING MODELS?https://pytorch.org/tutorials/beginner/saving_loading_models.html
[2]?示例代碼?https://github.com/moon-hotel/DeepLearningWithMe
如果覺得有用,就請分享到朋友圈吧!
公眾號后臺回復(fù)“CVPR21檢測”獲取CVPR2021目標(biāo)檢測論文下載~

#?CV技術(shù)社群邀請函?#

備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)
即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動(dòng)駕駛/超分辨率/姿態(tài)估計(jì)/ReID/GAN/圖像增強(qiáng)/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實(shí)項(xiàng)目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與?10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動(dòng)交流~

