<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          輕松學(xué)Pytorch-遷移學(xué)習(xí)實(shí)現(xiàn)表面缺陷檢查

          共 9001字,需瀏覽 19分鐘

           ·

          2022-07-26 21:36

          點(diǎn)擊上方小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          本文轉(zhuǎn)載自:OpenCV學(xué)堂

           引言 ·


          大家好,今天給大家更新的是如何基于torchvision自帶的模型完成圖像分類任務(wù)的遷移學(xué)習(xí),前面我們已經(jīng)完成了對對象檢測任務(wù)的遷移學(xué)習(xí),這里補(bǔ)上針對圖像分類任務(wù)的遷移學(xué)習(xí),官方的文檔比較啰嗦,看了之后其實(shí)可操作性很低,特別是對于初學(xué)者,估計(jì)看了之后就發(fā)懵的那種。本人重新改寫了一波,代碼簡潔易懂,然后把訓(xùn)練結(jié)果導(dǎo)出ONNX,使用OpenCV DNN調(diào)用部署,非常實(shí)用!

          數(shù)據(jù)集

          東北大學(xué)熱軋帶鋼表面缺陷數(shù)據(jù)集,該數(shù)據(jù)集是東北大學(xué)的宋克臣等幾位老師收集的,一共包含了三類數(shù)據(jù)。這里使用(NEU surface defect database),數(shù)據(jù)集收集了夾雜、劃痕、壓入氧化皮、裂紋、麻點(diǎn)和斑塊總計(jì)6種缺陷,每種缺陷300張,圖像尺寸為200×200。部分示例如下:

          基于該數(shù)據(jù)集,實(shí)現(xiàn)pytorch數(shù)據(jù)類,完成數(shù)據(jù)集的加載與預(yù)處理的代碼如下:

          class SurfaceDefectDataset(Dataset):
              def __init__(self, root_dir):
                  self.transform = transforms.Compose([transforms.ToTensor()])
                  img_files = os.listdir(root_dir)
                  self.defect_types = []
                  self.images = []
                  index = 0
                  for file_name in img_files:
                      defect_attrs = file_name.split("_")
                      d_index = defect_labels.index(defect_attrs[0])
                      self.images.append(os.path.join(root_dir, file_name))
                      self.defect_types.append(d_index)
                      index += 1

              def __len__(self):
                  return len(self.images)

              def num_of_samples(self):
                  return len(self.images)

              def __getitem__(self, idx):
                  if torch.is_tensor(idx):
                      idx = idx.tolist()
                      image_path = self.images[idx]
                  else:
                      image_path = self.images[idx]
                  img = cv.imread(image_path)  # BGR order
                  h, w, c = img.shape
                  # rescale
                  img = cv.resize(img, (200200))
                  img = (np.float32(img) /255.0 - 0.5) / 0.5
                  # H, W C to C, H, W
                  img = img.transpose((201))
                  sample = {'image': torch.from_numpy(img), 'defect'self.defect_types[idx]}
                  return sample

          怎么下載該數(shù)據(jù)集,后臺(tái)回復(fù)"NEU"關(guān)鍵字即可獲取下載地址

          模型使用

          Pytorchvison支持多種圖像分類模型,這里我們選擇殘差網(wǎng)絡(luò)模型作為遷移學(xué)習(xí)的基礎(chǔ)模型,對輸出層(最后一層)改為六個(gè)類別,其它特征層選擇在訓(xùn)練時(shí)候微調(diào)參數(shù)。常見的ResNet網(wǎng)絡(luò)模型如下:

          基于ResNet18完成網(wǎng)絡(luò)模型修改,最終的模型實(shí)現(xiàn)代碼如下:

          class SurfaceDefectResNet(torch.nn.Module):

              def __init__(self):
                  super(SurfaceDefectResNet, self).__init__()
                  self.cnn_layers = torchvision.models.resnet18(pretrained=True)
                  num_ftrs = self.cnn_layers.fc.in_features
                  self.cnn_layers.fc = torch.nn.Linear(num_ftrs, 6)

              def forward(self, x):
                  # stack convolution layers
                  out = self.cnn_layers(x)
                  return out

          模型訓(xùn)練與測試

          模型訓(xùn)練跟前面講的一些圖像分類模型訓(xùn)練方式并無不同,基于交叉熵?fù)p失,完成訓(xùn)練,每個(gè)批次4張圖像或者8張圖,訓(xùn)練15個(gè)epoch之后,保存模型。然后使用模型測試35張測試圖像,發(fā)現(xiàn)有兩張預(yù)測錯(cuò)誤,其余均正確。訓(xùn)練模型的代碼如下:

          # 訓(xùn)練模型的次數(shù)
          num_epochs = 15
          # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
          optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
          model.train()

          # 損失函數(shù)
          cross_loss = torch.nn.CrossEntropyLoss()
          index = 0
          for epoch in  range(num_epochs):
              train_loss = 0.0
              for i_batch, sample_batched in enumerate(dataloader):
                  images_batch, label_batch = \
                      sample_batched['image'], sample_batched['defect']
                  if train_on_gpu:
                      images_batch, label_batch= images_batch.cuda(), label_batch.cuda()
                  optimizer.zero_grad()

                  # forward pass: compute predicted outputs by passing inputs to the model
                  m_label_out_ = model(images_batch)
                  label_batch = label_batch.long()

                  # calculate the batch loss
                  loss = cross_loss(m_label_out_, label_batch)

                  # backward pass: compute gradient of the loss with respect to model parameters
                  loss.backward()

                  # perform a single optimization step (parameter update)
                  optimizer.step()

                  # update training loss
                  train_loss += loss.item()
                  if index % 100 == 0:
                      print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
                  index += 1

                  # 計(jì)算平均損失
              train_loss = train_loss / num_train_samples

              # 顯示訓(xùn)練集與驗(yàn)證集的損失函數(shù)
              print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))

          # save model
          model.eval()
          torch.save(model, 'surface_defect_model.pt')


          轉(zhuǎn)為為ONNX模式,OpenCV DNN部署調(diào)用,代碼如下:

          defect_net = cv.dnn.readNetFromONNX("surface_defect_resnet18.onnx")
          root_dir = "D:/pytorch/enu_surface_defect/test"
          fileNames = os.listdir(root_dir)
          for f in fileNames:
              image = cv.imread(os.path.join(root_dir, f))
              blob = cv.dnn.blobFromImage(image, 0.00392, (200200), (127127127)) / 0.5
              defect_net.setInput(blob)
              res = defect_net.forward()
              idx = np.argmax(np.reshape(res, (6)))
              defect_txt = defect_labels[idx]
              cv.putText(image, defect_txt, (1025), cv.FONT_HERSHEY_SIMPLEX, 1, (25500), 2)
              cv.imshow("input", image)
              print(f, defect_txt)
              cv.waitKey(0)
          cv.destroyAllWindows()


          預(yù)測運(yùn)行結(jié)果如下:

          運(yùn)行結(jié)果與pytorch調(diào)用模型運(yùn)行結(jié)果保持一致。由于這個(gè)是一個(gè)專欄,很多代碼在以前的文章中已經(jīng)給出了,這里就沒有重復(fù)貼代碼!


          好消息!

          小白學(xué)視覺知識(shí)星球

          開始面向外開放啦??????




          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會(huì)逐漸細(xì)分),請掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會(huì)請出群,謝謝理解~


          瀏覽 33
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲高清无码视频在线免费观看 | 人怕香蕉网| 日韩免费高清一区二区 | 九月色婷婷 | 欧美久久国产精品 |