<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【小白學(xué)習(xí)PyTorch教程】十四、遷移學(xué)習(xí):微調(diào)ResNet實(shí)現(xiàn)男人和女人圖像分類(lèi)

          共 10390字,需瀏覽 21分鐘

           ·

          2021-08-20 08:27

          「@Author:Runsen」

          上次微調(diào)了Alexnet,這次微調(diào)ResNet實(shí)現(xiàn)男人和女人圖像分類(lèi)。

          ResNet是 Residual Networks 的縮寫(xiě),是一種經(jīng)典的神經(jīng)網(wǎng)絡(luò),用作許多計(jì)算機(jī)視覺(jué)任務(wù)。

          • ResNet論文參見(jiàn)此處:

          https://arxiv.org/abs/1512.03385

          該模型是 2015 年 ImageNet 挑戰(zhàn)賽的獲勝者。ResNet 的根本性突破是它使我們能夠成功訓(xùn)練 150 層以上的極深神經(jīng)網(wǎng)絡(luò)。

          下面是resnet18的整個(gè)網(wǎng)絡(luò)結(jié)構(gòu):

          Resnet 18 是在 ImageNet 數(shù)據(jù)集上預(yù)訓(xùn)練的圖像分類(lèi)模型。

          這次使用Resnet 18 實(shí)現(xiàn)分類(lèi)性別數(shù)據(jù)集,

          該性別分類(lèi)數(shù)據(jù)集共有58,658 張圖像。(train:47,009 / val:11,649)

          female
          male
          • Dataset: Kaggle Gender Classification Dataset

          加載數(shù)據(jù)集

          設(shè)置圖像目錄路徑并初始化 PyTorch 數(shù)據(jù)加載器。和之前一樣的模板套路

          import torch
          import torch.nn as nn
          import torch.optim as optim

          import torchvision
          from torchvision import datasets, models, transforms

          import numpy as np
          import matplotlib.pyplot as plt

          import time
          import os


          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu"# device object


          transforms_train = transforms.Compose([
              transforms.Resize((224224)),
              transforms.RandomHorizontalFlip(), # data augmentation
              transforms.ToTensor(),
              transforms.Normalize([0.4850.4560.406], [0.2290.2240.225]) # normalization
          ])

          transforms_val = transforms.Compose([
              transforms.Resize((224224)),
              transforms.ToTensor(),
              transforms.Normalize([0.4850.4560.406], [0.2290.2240.225])
          ])

          data_dir = './gender_classification_dataset'
          train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Training'), transforms_train)
          val_datasets = datasets.ImageFolder(os.path.join(data_dir, 'Validation'), transforms_val)

          train_dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=16, shuffle=True, num_workers=4)
          val_dataloader = torch.utils.data.DataLoader(val_datasets, batch_size=16, shuffle=True, num_workers=4)

          print('Train dataset size:', len(train_datasets))
          print('Validation dataset size:', len(val_datasets))

          class_names = train_datasets.classes
          print('Class names:', class_names)
          plt.rcParams['figure.figsize'] = [128]
          plt.rcParams['figure.dpi'] = 60
          plt.rcParams.update({'font.size'20})


          def imshow(input, title):
              # torch.Tensor => numpy
              input = input.numpy().transpose((120))
              # undo image normalization
              mean = np.array([0.4850.4560.406])
              std = np.array([0.2290.2240.225])
              input = std * input + mean
              input = np.clip(input, 01)
              # display images
              plt.imshow(input)
              plt.title(title)
              plt.show()


          # load a batch of train image
          iterator = iter(train_dataloader)

          # visualize a batch of train image
          inputs, classes = next(iterator)
          out = torchvision.utils.make_grid(inputs[:4])
          imshow(out, title=[class_names[x] for x in classes[:4]])

          定義模型

          我們使用遷移學(xué)習(xí)方法,只需要修改最后的輸出即可。

          model = models.resnet18(pretrained=True)
          num_features = model.fc.in_features
          model.fc = nn.Linear(num_features, 2# binary classification (num_of_class == 2)
          model = model.to(device)

          criterion = nn.CrossEntropyLoss()
          optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

          訓(xùn)練階段

          由于ResNet18網(wǎng)絡(luò)非常復(fù)雜,深,這里只訓(xùn)練num_epochs = 3

          num_epochs = 3
          start_time = time.time()

          for epoch in range(num_epochs):
              """ Training  """
              model.train()

              running_loss = 0.
              running_corrects = 0

              # load a batch data of images
              for i, (inputs, labels) in enumerate(train_dataloader):
                  inputs = inputs.to(device)
                  labels = labels.to(device)

                  optimizer.zero_grad()
                  outputs = model(inputs)
                  _, preds = torch.max(outputs, 1)
                  loss = criterion(outputs, labels)

                  # get loss value and update the network weights
                  loss.backward()
                  optimizer.step()

                  running_loss += loss.item() * inputs.size(0)
                  running_corrects += torch.sum(preds == labels.data)

              epoch_loss = running_loss / len(train_datasets)
              epoch_acc = running_corrects / len(train_datasets) * 100.
              print('[Train #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

              """ Validation"""
              model.eval()

              with torch.no_grad():
                  running_loss = 0.
                  running_corrects = 0

                  for inputs, labels in val_dataloader:
                      inputs = inputs.to(device)
                      labels = labels.to(device)

                      outputs = model(inputs)
                      _, preds = torch.max(outputs, 1)
                      loss = criterion(outputs, labels)

                      running_loss += loss.item() * inputs.size(0)
                      running_corrects += torch.sum(preds == labels.data)

                  epoch_loss = running_loss / len(val_datasets)
                  epoch_acc = running_corrects / len(val_datasets) * 100.
                  print('[Validation #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

          「保存訓(xùn)練好的模型文件」

          save_path = 'face_gender_classification_transfer_learning_with_ResNet18.pth'
          torch.save(model.state_dict(), save_path)

          「訓(xùn)練好的模型文件加載」

          model = models.resnet18(pretrained=True)
          num_features = model.fc.in_features
          model.fc = nn.Linear(num_features, 2
          model.load_state_dict(torch.load(save_path))
          model.to(device)

          model.eval()
          start_time = time.time()

          with torch.no_grad():
              running_loss = 0.
              running_corrects = 0

              for i, (inputs, labels) in enumerate(val_dataloader):
                  inputs = inputs.to(device)
                  labels = labels.to(device)

                  outputs = model(inputs)
                  _, preds = torch.max(outputs, 1)
                  loss = criterion(outputs, labels)

                  running_loss += loss.item() * inputs.size(0)
                  running_corrects += torch.sum(preds == labels.data)

                  if i == 0:
                      print('[Prediction Result Examples]')
                      images = torchvision.utils.make_grid(inputs[:4])
                      imshow(images.cpu(), title=[class_names[x] for x in labels[:4]])
                      images = torchvision.utils.make_grid(inputs[4:8])
                      imshow(images.cpu(), title=[class_names[x] for x in labels[4:8]])

              epoch_loss = running_loss / len(val_datasets)
              epoch_acc = running_corrects / len(val_datasets) * 100.
              print('[Validation #{}] Loss: {:.4f} Acc: {:.4f}% Time: {:.4f}s'.format(epoch, epoch_loss, epoch_acc, time.time() - start_time))

          在最后的測(cè)試結(jié)果中,ACC達(dá)到了97,但是模型太復(fù)雜,運(yùn)行太慢了,在項(xiàng)目中往往不可取。

          往期精彩回顧




          本站qq群851320808,加入微信群請(qǐng)掃碼:
          瀏覽 72
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  毛片网站在线 | 亚洲精品卡一卡二 | 午夜性爱在线 | 毛片电影免费看 | 囯产精品久久久久久久久久免费 |