PyTorch Basics | Comparing model.train() and model.eval()
01
1.1 model.train()
Puts the model into training mode, so that layers such as Dropout and BatchNorm behave as they should during training.
1.2 model.eval()
Puts the model into evaluation (inference) mode, so that those same layers switch to their deterministic test-time behavior.
1.3 Analyzing the Reason
import torch.nn as nn
import torch.nn.functional as F

# Define a network
class Net(nn.Module):
    def __init__(self, l1=120, l2=84):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, l1)
        self.fc2 = nn.Linear(l1, l2)
        self.fc3 = nn.Linear(l2, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the network
model = Net()

# Training mode: use .train()
model.train(mode=True)

# Evaluation mode: use .eval()
model.eval()
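For context, a quick sketch of my own (not from the original article): train() and eval() simply flip the boolean training flag on the module and, recursively, on all of its submodules; eval() is equivalent to train(False).

# Minimal sketch: train()/eval() toggle the `training` flag recursively.
model.train()
print(model.training, model.conv1.training)   # True True
model.eval()                                  # equivalent to model.train(False)
print(model.training, model.conv1.training)   # False False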

def train(model, optimizer, epoch, train_loader, validation_loader):
    model.train()  # WRONG position: only takes effect until test() below switches the model to eval mode
    for batch_idx, (data, target) in experiment.batch_loop(iterable=train_loader):
        # model.train()  # CORRECT position: guarantees every batch is processed in train mode
        data, target = Variable(data), Variable(target)
        # Inference
        output = model(data)
        loss_t = F.nll_loss(output, target)
        # The iconic grad-back-step trio
        optimizer.zero_grad()
        loss_t.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            train_loss = loss_t.item()
            train_accuracy = get_correct_count(output, target) * 100.0 / len(target)
            experiment.add_metric(LOSS_METRIC, train_loss)
            experiment.add_metric(ACC_METRIC, train_accuracy)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx, len(train_loader),
                100. * batch_idx / len(train_loader), train_loss))
            with experiment.validation():
                # test() calls model.eval(), so the model stays in eval mode
                # for all following batches unless model.train() is called again
                val_loss, val_accuracy = test(model, validation_loader)
                experiment.add_metric(LOSS_METRIC, val_loss)
                experiment.add_metric(ACC_METRIC, val_accuracy)
def test(model, test_loader):
    model.eval()
    # ...
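To make the failure mode concrete, here is a small self-contained sketch (the toy model is illustrative, not from the article): once a helper such as test() calls model.eval(), the training flag stays False, so any later batches in the loop would silently run with Dropout and BatchNorm frozen.

import torch.nn as nn

# Hypothetical toy model, just to observe the flag
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

def test(model):
    model.eval()  # flips model.training to False

model.train()
print(model.training)  # True
test(model)
print(model.training)  # False: the next "training" batch would run in eval
                       # mode unless model.train() is called inside the loop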
02
In train mode, a Dropout layer randomly zeroes activations according to the configured parameter p (in PyTorch, p is the probability of dropping a unit, so each unit is kept with probability 1 - p), and BatchNorm layers keep computing the mean and variance of each batch and updating their running statistics.
In eval mode, the Dropout layer lets all activations pass through unchanged, while BatchNorm layers stop computing and updating the batch mean and variance and instead use the running statistics accumulated during training.
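A small runnable sketch (my own illustration, not from the article) makes both behaviors visible: in train mode Dropout zeroes roughly a fraction p of the inputs and BatchNorm updates its running statistics; in eval mode Dropout is the identity and the running statistics stay fixed.

import torch
import torch.nn as nn

torch.manual_seed(0)

drop = nn.Dropout(p=0.5)
bn = nn.BatchNorm1d(3)
x = torch.ones(4, 3)

# Train mode: Dropout zeroes units with probability p (survivors are
# scaled by 1/(1-p)); BatchNorm updates running_mean / running_var.
drop.train(); bn.train()
print(drop(x))              # some entries are 0, the rest are 2.0
bn(torch.randn(4, 3))
print(bn.running_mean)      # has moved away from the initial zeros

# Eval mode: Dropout is the identity; BatchNorm uses the stored
# running statistics and no longer updates them.
drop.eval(); bn.eval()
print(drop(x))              # identical to x
before = bn.running_mean.clone()
bn(torch.randn(4, 3))
print(torch.equal(before, bn.running_mean))  # True: stats unchanged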
