<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch 深度剖析:如何保存和加載PyTorch模型?

          共 6824字,需瀏覽 14分鐘

           ·

          2021-11-30 16:58

          點(diǎn)擊上方視學(xué)算法”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          作者丨科技猛獸
          編輯丨極市平臺(tái)

          導(dǎo)讀

          ?

          本文詳解了PyTorch 模型的保存與加載方法。

          目錄

          1 需要掌握3個(gè)重要的函數(shù)

          2 state_dict
          2.1 state_dict 介紹
          2.2 保存和加載 state_dict (已經(jīng)訓(xùn)練完,無(wú)需繼續(xù)訓(xùn)練)
          2.3 保存和加載整個(gè)模型 (已經(jīng)訓(xùn)練完,無(wú)需繼續(xù)訓(xùn)練)
          2.4 保存和加載 state_dict (沒(méi)有訓(xùn)練完,還會(huì)繼續(xù)訓(xùn)練)
          2.5 把多個(gè)模型存進(jìn)一個(gè)文件
          2.6 使用其他模型的參數(shù)暖啟動(dòng)自己的模型
          2.7 保存在 GPU, 加載到 CPU
          2.8 保存在 GPU, 加載到 GPU
          2.9 保存在 CPU, 加載到 GPU

          1 需要掌握3個(gè)重要的函數(shù)

          1) torch.save: 將一個(gè)序列化的對(duì)象保存到磁盤(pán)。這個(gè)函數(shù)使用 Python 的 pickle 工具進(jìn)行序列化。模型 (model)、張量 (tensor)各種對(duì)象的字典 (dict) 都可以用這個(gè)函數(shù)保存。

          2) torch.load: 將 pickled 對(duì)象文件反序列化到內(nèi)存,也便于將數(shù)據(jù)加載到設(shè)備中。

          3) torch.nn.Module.load_state_dict(): 加載模型的參數(shù)。

          2 state_dict

          2.1 state_dict 介紹

          PyTorch 中,torch.nn.Module里面的可學(xué)習(xí)的參數(shù) (weights 和 biases) 都放在model.parameters()里面。而 state_dict 是一個(gè) Python dictionary object,將每一層映射到它的 parameter tensor 上。注意:只有含有可學(xué)習(xí)參數(shù)的層 (convolutional layers, linear layers),或者含有 registered buffers 的層 (batchnorm's running_mean) 才有 state_dict。優(yōu)化器的對(duì)象 (torch.optim) 也有 state_dict,存儲(chǔ)了優(yōu)化器的狀態(tài)和它的超參數(shù)。

          因?yàn)?state_dict 是一個(gè) Python dictionary object,所以保存,加載,更新它比較容易。

          下面我們通過(guò)一個(gè)例子直觀感受下 state_dict 的用法:

          # Define model
          class TheModelClass(nn.Module):
          def __init__(self):
          super(TheModelClass, self).__init__()
          self.conv1 = nn.Conv2d(3, 6, 5)
          self.pool = nn.MaxPool2d(2, 2)
          self.conv2 = nn.Conv2d(6, 16, 5)
          self.fc1 = nn.Linear(16 * 5 * 5, 120)
          self.fc2 = nn.Linear(120, 84)
          self.fc3 = nn.Linear(84, 10)

          def forward(self, x):
          x = self.pool(F.relu(self.conv1(x)))
          x = self.pool(F.relu(self.conv2(x)))
          x = x.view(-1, 16 * 5 * 5)
          x = F.relu(self.fc1(x))
          x = F.relu(self.fc2(x))
          x = self.fc3(x)
          return x

          # Initialize model
          model = TheModelClass()

          # Initialize optimizer
          optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

          # Print model's state_dict
          print("Model's state_dict:")
          for param_tensor in model.state_dict():
          print(param_tensor, "\t", model.state_dict()[param_tensor].size())

          # Print optimizer's state_dict
          print("Optimizer's state_dict:")
          for var_name in optimizer.state_dict():
          print(var_name, "\t", optimizer.state_dict()[var_name])

          輸出:

          Model's state_dict:
          conv1.weight torch.Size([6, 3, 5, 5])
          conv1.bias torch.Size([6])
          conv2.weight torch.Size([16, 6, 5, 5])
          conv2.bias torch.Size([16])
          fc1.weight torch.Size([120, 400])
          fc1.bias torch.Size([120])
          fc2.weight torch.Size([84, 120])
          fc2.bias torch.Size([84])
          fc3.weight torch.Size([10, 84])
          fc3.bias torch.Size([10])

          Optimizer's state_dict:
          state {}
          param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

          2.2 保存和加載 state_dict (已經(jīng)訓(xùn)練完,無(wú)需繼續(xù)訓(xùn)練)

          保存:

          torch.save(model.state_dict(), PATH)

          加載:

          model = TheModelClass(*args, **kwargs)
          model.load_state_dict(torch.load(PATH))
          model.eval()

          一般保存為.pt.pth 格式的文件。

          注意:

          1. 可以使用model.eval()將 dropout 和 batch normalization 層設(shè)置成 evaluation 模式。
          2. load_state_dict()函數(shù)需要一個(gè) dict 類(lèi)型的輸入,而不是保存模型的 PATH。所以這樣 model.load_state_dict(PATH)是錯(cuò)誤的,而應(yīng)該model.load_state_dict(torch.load(PATH))。
          3. 如果你想保存驗(yàn)證機(jī)上表現(xiàn)最好的模型,那么這樣best_model_state=model.state_dict()是錯(cuò)誤的。因?yàn)檫@屬于淺復(fù)制,也就是說(shuō)此時(shí)這個(gè) best_model_state 會(huì)隨著后續(xù)的訓(xùn)練過(guò)程而不斷被更新,最后保存的其實(shí)是個(gè) overfit 的模型。所以正確的做法應(yīng)該是best_model_state=deepcopy(model.state_dict())。

          2.3 保存和加載整個(gè)模型 (已經(jīng)訓(xùn)練完,無(wú)需繼續(xù)訓(xùn)練)

          保存:

          torch.save(model, PATH)

          加載:

          # Model class must be defined somewhere
          model = torch.load(PATH)
          model.eval()

          一般保存為.pt.pth格式的文件。

          注意:

          1. 可以使用model.eval()將 dropout 和 batch normalization 層設(shè)置成 evaluation 模式。

          2.4 保存和加載 state_dict (沒(méi)有訓(xùn)練完,還會(huì)繼續(xù)訓(xùn)練)

          保存:

          torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'loss': loss,
          ...
          }, PATH)

          與2.2的不同是除了保存 model_state_dict 之外,還需要保存:optimizer_state_dict,epoch 和 loss,因?yàn)槔^續(xù)訓(xùn)練時(shí)要知道優(yōu)化器的狀態(tài),epoch 等等。

          加載:

          model = TheModelClass(*args, **kwargs)
          optimizer = TheOptimizerClass(*args, **kwargs)

          checkpoint = torch.load(PATH)
          model.load_state_dict(checkpoint['model_state_dict'])
          optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
          epoch = checkpoint['epoch']
          loss = checkpoint['loss']

          model.eval()
          # - or -
          model.train()

          與2.2的不同是除了加載 model_state_dict 之外,還需要加載:optimizer_state_dict,epoch 和 loss。

          2.5 把多個(gè)模型存進(jìn)一個(gè)文件

          保存:

          torch.save({
          'modelA_state_dict': modelA.state_dict(),
          'modelB_state_dict': modelB.state_dict(),
          'optimizerA_state_dict': optimizerA.state_dict(),
          'optimizerB_state_dict': optimizerB.state_dict(),
          ...
          }, PATH)

          把模型 A 和 B 的 state_dict 和 optimizer 都存進(jìn)一個(gè)文件中。

          加載:

          modelA = TheModelAClass(*args, **kwargs)
          modelB = TheModelBClass(*args, **kwargs)
          optimizerA = TheOptimizerAClass(*args, **kwargs)
          optimizerB = TheOptimizerBClass(*args, **kwargs)

          checkpoint = torch.load(PATH)
          modelA.load_state_dict(checkpoint['modelA_state_dict'])
          modelB.load_state_dict(checkpoint['modelB_state_dict'])
          optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
          optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

          modelA.eval()
          modelB.eval()
          # - or -
          modelA.train()
          modelB.train()

          2.6 使用其他模型的參數(shù)暖啟動(dòng)自己的模型

          有時(shí)候訓(xùn)練一個(gè)新的復(fù)雜模型時(shí),需要加載它的一部分預(yù)訓(xùn)練的權(quán)重。即使只有幾個(gè)可用的參數(shù),也會(huì)有助于 warmstart 訓(xùn)練過(guò)程,幫助模型更快達(dá)到收斂。

          如果手里有的這個(gè) state_dict 缺乏一些 keys,或者多了一些 keys,只要設(shè)置strict參數(shù)為 False,就能夠把 state_dict 能夠匹配的 keys 加載進(jìn)去,而忽略掉那些 non-matching keys。

          保存模型 A 的 state_dict :

          torch.save(modelA.state_dict(), PATH)

          加載到模型 B:

          modelB = TheModelBClass(*args, **kwargs)
          modelB.load_state_dict(torch.load(PATH), strict=False)

          2.7 保存在 GPU, 加載到 CPU

          保存:

          torch.save(model.state_dict(), PATH)

          加載:

          device = torch.device('cpu')
          model = TheModelClass(*args, **kwargs)
          model.load_state_dict(torch.load(PATH, map_location=device))

          這種情況 model.state_dict() 保存之后在 GPU,直接 torch.load(PATH) 會(huì)加載進(jìn) GPU 中。所以若想加載到 CPU 中,需要加 map_location=torch.device('cpu')。

          2.8 保存在 GPU, 加載到 GPU

          保存:

          torch.save(model.state_dict(), PATH)

          加載:

          map_location="cuda:0"device = torch.device("cuda")
          model = TheModelClass(*args, **kwargs)
          model.load_state_dict(torch.load(PATH))
          model.to(device)
          # Make sure to call input = input.to(device) on any input tensors that you feed to the model

          這種情況 model.state_dict() 保存之后在 GPU,直接 torch.load(PATH) 會(huì)加載進(jìn) GPU 中。所以若想加載到 GPU 中,不需要加 map_location=device。因?yàn)樽詈笠虞d到 GPU 里面,model 是重新初始化的 (在 CPU 里面),所以要 model.to(device)。

          2.9 保存在 CPU, 加載到 GPU

          保存:

          torch.save(model.state_dict(), PATH)

          加載:

          device = torch.device("cuda")
          model = TheModelClass(*args, **kwargs)
          model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
          model.to(device)
          # Make sure to call input = input.to(device) on any input tensors that you feed to the model

          這種情況 model.state_dict() 保存之后在 CPU,直接 torch.load(PATH) 會(huì)加載進(jìn) CPU 中。所以若想加載到 GPU 中,需要加 map_location="cuda:0" 。因?yàn)樽詈笠虞d到 GPU 里面,model 是重新初始化的 (在 CPU 里面),所以要 model.to(device)。


          如果覺(jué)得有用,就請(qǐng)分享到朋友圈吧!


          點(diǎn)個(gè)在看 paper不斷!

          瀏覽 59
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  成人精品影院 | 安徽妇女BBBWBBBwm | 黄片一级二级三级 | 国产福利无码视频 | 国产精品97 |