<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          輕松學(xué)pytorch – 使用多標(biāo)簽損失函數(shù)訓(xùn)練卷積網(wǎng)絡(luò)

          共 13275字,需瀏覽 27分鐘

           ·

          2022-06-24 10:55

          點(diǎn)擊上方小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂

          重磅干貨,第一時(shí)間送達(dá)

          大家好,我還在堅(jiān)持繼續(xù)寫(xiě),如果我沒(méi)有記錯(cuò)的話,這個(gè)是系列文章的第十五篇,pytorch中有很多非常方便使用的損失函數(shù),本文就演示了如何通過(guò)多標(biāo)簽損失函數(shù)訓(xùn)練驗(yàn)證碼識(shí)別網(wǎng)絡(luò),實(shí)現(xiàn)驗(yàn)證碼識(shí)別。 


          數(shù)據(jù)集


          這個(gè)數(shù)據(jù)是來(lái)自Kaggle上的一個(gè)驗(yàn)證碼識(shí)別例子,作者采用的是遷移學(xué)習(xí),基于ResNet18做到的訓(xùn)練。

          https://www.kaggle.com/anjalichoudhary12/captcha-with-pytorch

          這個(gè)數(shù)據(jù)集總計(jì)有1070張驗(yàn)證碼圖像,我把其中的1040張用作訓(xùn)練,30張作為測(cè)試,使用pytorch自定義了一個(gè)數(shù)據(jù)集類,代碼如下:

           1import torch
          2import numpy as np
          3from torch.utils.data import Dataset, DataLoader
          4from torchvision import transforms
          5import os
          6import cv2 as cv
          7
          8NUMBER = ['0''1''2''3''4''5''6''7''8''9']
          9ALPHABET = ['a''b''c''d''e''f''g''h''i''j''k''l''m''n''o''p''q''r''s''t''u''v''w''x''y''z']
          10ALL_CHAR_SET = NUMBER + ALPHABET
          11ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
          12MAX_CAPTCHA = 5
          13
          14
          15def output_nums():
          16    return MAX_CAPTCHA * ALL_CHAR_SET_LEN
          17
          18
          19def encode(a):
          20    onehot = [0]*ALL_CHAR_SET_LEN
          21    idx = ALL_CHAR_SET.index(a)
          22    onehot[idx] += 1
          23    return onehot
          24
          25
          26class CapchaDataset(Dataset):
          27    def __init__(self, root_dir):
          28        self.transform = transforms.Compose([transforms.ToTensor()])
          29        img_files = os.listdir(root_dir)
          30        self.txt_labels = []
          31        self.encodes = []
          32        self.images = []
          33        for file_name in img_files:
          34            label = file_name[:-4]
          35            label_oh = []
          36            for i in label:
          37                label_oh += encode(i)
          38            self.images.append(os.path.join(root_dir, file_name))
          39            self.encodes.append(np.array(label_oh))
          40            self.txt_labels.append(label)
          41
          42    def __len__(self):
          43        return len(self.images)
          44
          45    def num_of_samples(self):
          46        return len(self.images)
          47
          48    def __getitem__(self, idx):
          49        if torch.is_tensor(idx):
          50            idx = idx.tolist()
          51            image_path = self.images[idx]
          52        else:
          53            image_path = self.images[idx]
          54        img = cv.imread(image_path)  # BGR order
          55        h, w, c = img.shape
          56        # rescale
          57        img = cv.resize(img, (12832))
          58        img = (np.float32(img) /255.0 - 0.5) / 0.5
          59        # H, W C to C, H, W
          60        img = img.transpose((201))
          61        sample = {'image': torch.from_numpy(img), 'encode': self.encodes[idx], 'label': self.txt_labels[idx]}
          62        return sample

           

          模型實(shí)現(xiàn)

           

          基于ResNet的block結(jié)構(gòu),我實(shí)現(xiàn)了一個(gè)比較簡(jiǎn)單的殘差網(wǎng)絡(luò),最后加一個(gè)全連接層輸出多個(gè)標(biāo)簽。驗(yàn)證碼是有5個(gè)字符的,每個(gè)字符的是小寫(xiě)26個(gè)字母加上0~9十個(gè)數(shù)字,總計(jì)36個(gè)類別,所以5個(gè)字符就有5x36=180個(gè)輸出,其中每個(gè)字符是獨(dú)熱編碼,這個(gè)可以從數(shù)據(jù)集類的實(shí)現(xiàn)看到。模型的輸入與輸出格式:

          輸入:NCHW=Nx3x32x128
          卷積層最終輸出:NCHW=Nx256x1x4
          全連接層:Nx(256x4)
          最終輸出層:Nx180

          代碼實(shí)現(xiàn)如下:


           1class CapchaResNet(torch.nn.Module):
          2    def __init__(self):
          3        super(CapchaResNet, self).__init__()
          4        self.cnn_layers = torch.nn.Sequential(
          5            # 卷積層 (128x32x3)
          6            ResidualBlock(3321),
          7            ResidualBlock(32642),
          8            ResidualBlock(64642),
          9            ResidualBlock(641282),
          10            ResidualBlock(1282562),
          11            ResidualBlock(2562562),
          12        )
          13
          14        self.fc_layers = torch.nn.Sequential(
          15            torch.nn.Linear(256 * 4, output_nums()),
          16        )
          17
          18    def forward(self, x):
          19        # stack convolution layers
          20        x = self.cnn_layers(x)
          21        out = x.view(-14 * 256)
          22        out = self.fc_layers(out)
          23        return out

           

          模型訓(xùn)練與測(cè)試


          使用多標(biāo)簽損失函數(shù),Adam優(yōu)化器,代碼實(shí)現(xiàn)如下:

           1model = CapchaResNet()
          2print(model)
          3
          4# 使用GPU
          5if train_on_gpu:
          6    model.cuda()
          7
          8ds = CapchaDataset("D:/python/pytorch_tutorial/capcha/samples")
          9num_train_samples = ds.num_of_samples()
          10bs = 16
          11dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
          12
          13# 訓(xùn)練模型的次數(shù)
          14num_epochs = 25
          15# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
          16optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
          17model.train()
          18
          19# 損失函數(shù)
          20mul_loss = torch.nn.MultiLabelSoftMarginLoss()
          21index = 0
          22for epoch in range(num_epochs):
          23    train_loss = 0.0
          24    for i_batch, sample_batched in enumerate(dataloader):
          25        images_batch, oh_labels = \
          26            sample_batched['image'], sample_batched['encode']
          27        if train_on_gpu:
          28            images_batch, oh_labels= images_batch.cuda(), oh_labels.cuda()
          29        optimizer.zero_grad()
          30
          31        # forward pass: compute predicted outputs by passing inputs to the model
          32        m_label_out_ = model(images_batch)
          33        oh_labels = torch.autograd.Variable(oh_labels.float())
          34
          35        # calculate the batch loss
          36        loss = mul_loss(m_label_out_, oh_labels)
          37
          38        # backward pass: compute gradient of the loss with respect to model parameters
          39        loss.backward()
          40
          41        # perform a single optimization step (parameter update)
          42        optimizer.step()
          43
          44        # update training loss
          45        train_loss += loss.item()
          46        if index % 100 == 0:
          47            print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
          48        index += 1
          49
          50        # 計(jì)算平均損失
          51    train_loss = train_loss / num_train_samples
          52
          53    # 顯示訓(xùn)練集與驗(yàn)證集的損失函數(shù)
          54    print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
          55
          56# save model
          57model.eval()
          58torch.save(model, 'capcha_recognize_model.pt')

          調(diào)用保存之后的模型,對(duì)圖像測(cè)試代碼如下:

           1cnn_model = torch.load("./capcha_recognize_model.pt")
          2root_dir = "D:/python/pytorch_tutorial/capcha/testdata"
          3files = os.listdir(root_dir)
          4one_hot_len = ALL_CHAR_SET_LEN
          5for file in files:
          6    if os.path.isfile(os.path.join(root_dir, file)):
          7        image = cv.imread(os.path.join(root_dir, file))
          8        h, w, c = image.shape
          9        img = cv.resize(image, (12832))
          10        img = (np.float32(img) /255.0 - 0.5) / 0.5
          11        img = img.transpose((201))
          12        x_input = torch.from_numpy(img).view(1332128)
          13        probs = cnn_model(x_input.cuda())
          14        mul_pred_labels = probs.squeeze().cpu().tolist()
          15        c0 = ALL_CHAR_SET[np.argmax(mul_pred_labels[0:one_hot_len])]
          16        c1 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len:one_hot_len*2])]
          17        c2 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*2:one_hot_len*3])]
          18        c3 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*3:one_hot_len*4])]
          19        c4 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*4:one_hot_len*5])]
          20        pred_txt = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
          21        cv.putText(image, pred_txt, (1020), cv.FONT_HERSHEY_PLAIN, 1.5, (00255), 2)
          22        print("current code : %s, predict code : %s "%(file[:-4], pred_txt))
          23        cv.imshow("capcha predict", image)
          24        cv.waitKey(0)

          其中對(duì)輸入結(jié)果,要根據(jù)每個(gè)字符的獨(dú)熱編碼,截取成五個(gè)獨(dú)立的字符分類標(biāo)簽,然后使用argmax獲取index根據(jù)index查找類別標(biāo)簽,得到最終的驗(yàn)證碼預(yù)測(cè)字符串,代碼運(yùn)行結(jié)果如下:

          好消息!

          小白學(xué)視覺(jué)知識(shí)星球

          開(kāi)始面向外開(kāi)放啦??????




          下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
          在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。

          下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
          小白學(xué)視覺(jué)公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

          交流群


          歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~


          瀏覽 110
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  蜜桃视频在线无码播放 | 中文娱乐在线视频 | 淫性综合| 欧美国产成人精品一区二区三区 | 美女国产精品 |