輕松學(xué)Pytorch-遷移學(xué)習(xí)實(shí)現(xiàn)表面缺陷檢查
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
本文轉(zhuǎn)載自:OpenCV學(xué)堂
引言 ·
大家好,今天給大家更新的是如何基于torchvision自帶的模型完成圖像分類任務(wù)的遷移學(xué)習(xí),前面我們已經(jīng)完成了對對象檢測任務(wù)的遷移學(xué)習(xí),這里補(bǔ)上針對圖像分類任務(wù)的遷移學(xué)習(xí),官方的文檔比較啰嗦,看了之后其實(shí)可操作性很低,特別是對于初學(xué)者,估計(jì)看了之后就發(fā)懵的那種。本人重新改寫了一波,代碼簡潔易懂,然后把訓(xùn)練結(jié)果導(dǎo)出ONNX,使用OpenCV DNN調(diào)用部署,非常實(shí)用!
數(shù)據(jù)集
東北大學(xué)熱軋帶鋼表面缺陷數(shù)據(jù)集,該數(shù)據(jù)集是東北大學(xué)的宋克臣等幾位老師收集的,一共包含了三類數(shù)據(jù)。這里使用(NEU surface defect database),數(shù)據(jù)集收集了夾雜、劃痕、壓入氧化皮、裂紋、麻點(diǎn)和斑塊總計(jì)6種缺陷,每種缺陷300張,圖像尺寸為200×200。部分示例如下:

基于該數(shù)據(jù)集,實(shí)現(xiàn)pytorch數(shù)據(jù)類,完成數(shù)據(jù)集的加載與預(yù)處理的代碼如下:
class SurfaceDefectDataset(Dataset):
def __init__(self, root_dir):
self.transform = transforms.Compose([transforms.ToTensor()])
img_files = os.listdir(root_dir)
self.defect_types = []
self.images = []
index = 0
for file_name in img_files:
defect_attrs = file_name.split("_")
d_index = defect_labels.index(defect_attrs[0])
self.images.append(os.path.join(root_dir, file_name))
self.defect_types.append(d_index)
index += 1
def __len__(self):
return len(self.images)
def num_of_samples(self):
return len(self.images)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image_path = self.images[idx]
else:
image_path = self.images[idx]
img = cv.imread(image_path) # BGR order
h, w, c = img.shape
# rescale
img = cv.resize(img, (200, 200))
img = (np.float32(img) /255.0 - 0.5) / 0.5
# H, W C to C, H, W
img = img.transpose((2, 0, 1))
sample = {'image': torch.from_numpy(img), 'defect': self.defect_types[idx]}
return sample怎么下載該數(shù)據(jù)集,后臺(tái)回復(fù)"NEU"關(guān)鍵字即可獲取下載地址
模型使用
Pytorchvison支持多種圖像分類模型,這里我們選擇殘差網(wǎng)絡(luò)模型作為遷移學(xué)習(xí)的基礎(chǔ)模型,對輸出層(最后一層)改為六個(gè)類別,其它特征層選擇在訓(xùn)練時(shí)候微調(diào)參數(shù)。常見的ResNet網(wǎng)絡(luò)模型如下:

基于ResNet18完成網(wǎng)絡(luò)模型修改,最終的模型實(shí)現(xiàn)代碼如下:
class SurfaceDefectResNet(torch.nn.Module):
def __init__(self):
super(SurfaceDefectResNet, self).__init__()
self.cnn_layers = torchvision.models.resnet18(pretrained=True)
num_ftrs = self.cnn_layers.fc.in_features
self.cnn_layers.fc = torch.nn.Linear(num_ftrs, 6)
def forward(self, x):
# stack convolution layers
out = self.cnn_layers(x)
return out模型訓(xùn)練與測試
模型訓(xùn)練跟前面講的一些圖像分類模型訓(xùn)練方式并無不同,基于交叉熵?fù)p失,完成訓(xùn)練,每個(gè)批次4張圖像或者8張圖,訓(xùn)練15個(gè)epoch之后,保存模型。然后使用模型測試35張測試圖像,發(fā)現(xiàn)有兩張預(yù)測錯(cuò)誤,其余均正確。訓(xùn)練模型的代碼如下:
# 訓(xùn)練模型的次數(shù)
num_epochs = 15
# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
model.train()
# 損失函數(shù)
cross_loss = torch.nn.CrossEntropyLoss()
index = 0
for epoch in range(num_epochs):
train_loss = 0.0
for i_batch, sample_batched in enumerate(dataloader):
images_batch, label_batch = \
sample_batched['image'], sample_batched['defect']
if train_on_gpu:
images_batch, label_batch= images_batch.cuda(), label_batch.cuda()
optimizer.zero_grad()
# forward pass: compute predicted outputs by passing inputs to the model
m_label_out_ = model(images_batch)
label_batch = label_batch.long()
# calculate the batch loss
loss = cross_loss(m_label_out_, label_batch)
# backward pass: compute gradient of the loss with respect to model parameters
loss.backward()
# perform a single optimization step (parameter update)
optimizer.step()
# update training loss
train_loss += loss.item()
if index % 100 == 0:
print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
index += 1
# 計(jì)算平均損失
train_loss = train_loss / num_train_samples
# 顯示訓(xùn)練集與驗(yàn)證集的損失函數(shù)
print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
# save model
model.eval()
torch.save(model, 'surface_defect_model.pt')轉(zhuǎn)為為ONNX模式,OpenCV DNN部署調(diào)用,代碼如下:
defect_net = cv.dnn.readNetFromONNX("surface_defect_resnet18.onnx")
root_dir = "D:/pytorch/enu_surface_defect/test"
fileNames = os.listdir(root_dir)
for f in fileNames:
image = cv.imread(os.path.join(root_dir, f))
blob = cv.dnn.blobFromImage(image, 0.00392, (200, 200), (127, 127, 127)) / 0.5
defect_net.setInput(blob)
res = defect_net.forward()
idx = np.argmax(np.reshape(res, (6)))
defect_txt = defect_labels[idx]
cv.putText(image, defect_txt, (10, 25), cv.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)
cv.imshow("input", image)
print(f, defect_txt)
cv.waitKey(0)
cv.destroyAllWindows()預(yù)測運(yùn)行結(jié)果如下:

運(yùn)行結(jié)果與pytorch調(diào)用模型運(yùn)行結(jié)果保持一致。由于這個(gè)是一個(gè)專欄,很多代碼在以前的文章中已經(jīng)給出了,這里就沒有重復(fù)貼代碼!
好消息!
小白學(xué)視覺知識(shí)星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會(huì)逐漸細(xì)分),請掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會(huì)請出群,謝謝理解~

