PyTorch: pretrained weights, freeze training, and resuming from checkpoints
Source: Zhihu — 吵雞兇鴨OvO
01
If I have seen further, it is by standing on the shoulders of giants.
02
When the pretrained weights do not match the current network exactly (for example, a different number of classes), load only the parameters whose shapes agree with the current model:

import numpy as np
import torch

# Step 1: read the current model's parameters
model_dict = model.state_dict()
# Step 2: load the pretrained weights
pretrained_dict = torch.load(model_path, map_location=device)
# Keep only the entries that exist in the current model with the same shape
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and np.shape(model_dict[k]) == np.shape(v)}
# Step 3: update the current model's parameters with the pretrained ones
model_dict.update(pretrained_dict)
# Step 4: load the merged parameters into the model
model.load_state_dict(model_dict)
A more defensive variant wraps the shape check in try/except, so keys that are missing from the current model are simply skipped:

model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
temp = {}
for k, v in pretrained_dict.items():
    try:
        # Skip parameters whose shapes differ from the current model
        if np.shape(model_dict[k]) == np.shape(v):
            temp[k] = v
    except KeyError:
        # The pretrained weight has no counterpart in the current model
        pass
model_dict.update(temp)
model.load_state_dict(model_dict)
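As a minimal end-to-end sketch of the recipe above (the ResNet-18 model, the 10-class head, and the weight-file name are illustrative assumptions, not from the original post), mismatched classifier weights are skipped while everything else is transferred:

import numpy as np
import torch
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.resnet18(num_classes=10).to(device)
model_path = "resnet18_pretrained.pth"  # hypothetical path to a saved state_dict

model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location=device)
loaded, skipped = {}, []
for k, v in pretrained_dict.items():
    # Only tensors with a matching name and shape are transferred; the 10-class
    # fc layer keeps its random initialization because its shape differs.
    if k in model_dict and np.shape(model_dict[k]) == np.shape(v):
        loaded[k] = v
    else:
        skipped.append(k)
model_dict.update(loaded)
model.load_state_dict(model_dict)
print("loaded %d tensors, skipped %s" % (len(loaded), skipped))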
03
A common recipe is to freeze the backbone for a first stage (larger learning rate and batch size), then unfreeze it and continue training with smaller values; see the optimizer sketch after this block.

# Freeze-stage hyperparameters: learning_rate and batch_size can be set larger
Init_Epoch = 0
Freeze_Epoch = 50
Freeze_batch_size = 8
Freeze_lr = 1e-3
# Unfreeze-stage hyperparameters: learning_rate and batch_size should be smaller
UnFreeze_Epoch = 100
Unfreeze_batch_size = 4
Unfreeze_lr = 1e-4
# A flag to control whether freeze training is used at all
Freeze_Train = True

# Stage 1: freeze part of the network and train
batch_size = Freeze_batch_size
lr = Freeze_lr
start_epoch = Init_Epoch
end_epoch = Freeze_Epoch
if Freeze_Train:
    for param in model.backbone.parameters():
        param.requires_grad = False

# Stage 2: train after unfreezing
batch_size = Unfreeze_batch_size
lr = Unfreeze_lr
start_epoch = Freeze_Epoch
end_epoch = UnFreeze_Epoch
if Freeze_Train:
    for param in model.backbone.parameters():
        param.requires_grad = True
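One thing the snippet above leaves implicit is how the frozen parameters interact with the optimizer. A minimal sketch, assuming an SGD optimizer and the same model.backbone attribute: build the optimizer from the parameters that still require gradients, and rebuild it with the smaller learning rate once the backbone is unfrozen.

import torch.optim as optim

# Freeze stage: only parameters with requires_grad=True are handed to the optimizer.
for param in model.backbone.parameters():
    param.requires_grad = False
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=Freeze_lr, momentum=0.9)
# ... run epochs Init_Epoch .. Freeze_Epoch with this optimizer ...

# Unfreeze stage: re-enable gradients and rebuild the optimizer so the backbone
# is actually updated, now with the smaller unfreeze learning rate.
for param in model.backbone.parameters():
    param.requires_grad = True
optimizer = optim.SGD(model.parameters(), lr=Unfreeze_lr, momentum=0.9)
# ... run epochs Freeze_Epoch .. UnFreeze_Epoch with the new optimizer ...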
04
To save only the model weights:

torch.save(model.state_dict(), "path/to/save.pth")  # the path you want to save to

05
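For resuming an interrupted run, saving the bare state_dict is usually not enough. A minimal sketch of one common pattern (the checkpoint file name, the saved fields, and the optimizer variable are assumptions): save the model weights, the optimizer state, and the last finished epoch together, then restore all three before continuing.

# Save a full checkpoint at the end of an epoch (fields and path are illustrative).
checkpoint = {
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint_last.pth")

# Resume: build the model and optimizer first, then restore their states.
checkpoint = torch.load("checkpoint_last.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue training from the next epoch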
