
[Beginner's PyTorch Tutorial] 13. Transfer Learning: Fine-tuning AlexNet for ant and bee Image Classification


          2021-07-29 09:28

Author: Runsen

Last time we fine-tuned VGG19; this time we fine-tune AlexNet to classify images of ants and bees.

Over the years many CNN variants have been developed, producing several well-known CNN architectures. The most common are:

          1. LeNet-5 (1998)

          2. AlexNet (2012)

          3. ZFNet (2013)

4. GoogLeNet / Inception (2014)

          5. VGGNet (2014)

          6. ResNet (2015)

This post is a tutorial on AlexNet, one of the most popular of these CNN architectures.

          AlexNet

AlexNet was designed primarily by Alex Krizhevsky and published together with Ilya Sutskever and Krizhevsky's PhD advisor, Geoffrey Hinton. It is a convolutional neural network (CNN).

AlexNet rose to fame after winning the ImageNet Large Scale Visual Recognition Challenge, achieving a top-5 accuracy of 84.6% on the classification task, while the runner-up team reached 73.8%. Because computing power was very limited in 2012, Alex trained the network across 2 GPUs.

[Figure: the AlexNet architecture from the 2012 ImageNet challenge. The input is commonly drawn as 224x224x3, but is effectively 227x227x3.]
1. The AlexNet architecture consists of 5 convolutional layers, 3 max-pooling layers, 2 normalization layers, 2 fully connected layers, and 1 softmax layer.

2. Each convolutional layer consists of convolutional filters and the nonlinear activation function ReLU.

3. The pooling layers perform max pooling.

4. Because of the fully connected layers, the input size is fixed.

5. The input size is usually quoted as 224x224x3, but with the padding involved it effectively becomes 227x227x3 (the size arithmetic is sketched right after this list).

6. AlexNet has about 60 million parameters in total.
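All of the feature-map sizes in the parameter table below follow the standard convolution/pooling output-size formula. A minimal sketch in plain Python (nothing here is model-specific):

# output size of a conv or pooling layer:
# floor((input - kernel + 2*padding) / stride) + 1
def out_size(inp, kernel, stride, padding=0):
    return (inp - kernel + 2 * padding) // stride + 1

print(out_size(227, 11, 4))  # Conv1: 227x227 -> 55x55
print(out_size(55, 3, 2))    # Max pooling: 55x55 -> 27x27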

Below are the model parameters for AlexNet with a 227x227x3 input:

| Size / Operation | Filter | Depth | Stride | Padding | Number of Parameters | Forward Computation |
|---|---|---|---|---|---|---|
| 3 * 227 * 227 | | | | | | |
| Conv1 + ReLU | 11 * 11 | 96 | 4 | | (11*11*3 + 1) * 96 = 34,944 | (11*11*3 + 1) * 96 * 55 * 55 = 105,705,600 |
| 96 * 55 * 55 | | | | | | |
| Max Pooling | 3 * 3 | | 2 | | | |
| 96 * 27 * 27 | | | | | | |
| Norm | | | | | | |
| Conv2 + ReLU | 5 * 5 | 256 | 1 | 2 | (5*5*96 + 1) * 256 = 614,656 | (5*5*96 + 1) * 256 * 27 * 27 = 448,084,224 |
| 256 * 27 * 27 | | | | | | |
| Max Pooling | 3 * 3 | | 2 | | | |
| 256 * 13 * 13 | | | | | | |
| Norm | | | | | | |
| Conv3 + ReLU | 3 * 3 | 384 | 1 | 1 | (3*3*256 + 1) * 384 = 885,120 | (3*3*256 + 1) * 384 * 13 * 13 = 149,585,280 |
| 384 * 13 * 13 | | | | | | |
| Conv4 + ReLU | 3 * 3 | 384 | 1 | 1 | (3*3*384 + 1) * 384 = 1,327,488 | (3*3*384 + 1) * 384 * 13 * 13 = 224,345,472 |
| 384 * 13 * 13 | | | | | | |
| Conv5 + ReLU | 3 * 3 | 256 | 1 | 1 | (3*3*384 + 1) * 256 = 884,992 | (3*3*384 + 1) * 256 * 13 * 13 = 149,563,648 |
| 256 * 13 * 13 | | | | | | |
| Max Pooling | 3 * 3 | | 2 | | | |
| 256 * 6 * 6 | | | | | | |
| Dropout (rate 0.5) | | | | | | |
| FC6 + ReLU | | | | | 256 * 6 * 6 * 4096 = 37,748,736 | 256 * 6 * 6 * 4096 = 37,748,736 |
| 4096 | | | | | | |
| Dropout (rate 0.5) | | | | | | |
| FC7 + ReLU | | | | | 4096 * 4096 = 16,777,216 | 4096 * 4096 = 16,777,216 |
| 4096 | | | | | | |
| FC8 + ReLU | | | | | 4096 * 1000 = 4,096,000 | 4096 * 1000 = 4,096,000 |
| 1000 classes | | | | | | |
| Overall | | | | | 62,369,152 ≈ 62.3 million | 1,135,906,176 ≈ 1.1 billion |
| Conv vs FC | | | | | Conv: 3.7 million (6%), FC: 58.6 million (94%) | Conv: 1.08 billion (95%), FC: 58.6 million (5%) |
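As a sanity check on the totals above, you can count the parameters of torchvision's AlexNet directly. A small sketch (note that torchvision's variant differs slightly from the original 2012 two-GPU network, e.g. 64 filters in Conv1, so the total is close to, but not exactly, 62.3 million):

from torchvision import models

model = models.alexnet()
total = sum(p.numel() for p in model.parameters())
conv = sum(p.numel() for p in model.features.parameters())
fc = sum(p.numel() for p in model.classifier.parameters())
print(f"total: {total:,}, conv: {conv:,}, fc: {fc:,}")
# As in the table, the fully connected layers dominate the parameter count.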

Dataset

We use the classic introductory PyTorch dataset of ant and bee images, with one folder of images per class.

Data source: the PyTorch deep learning quick-start tutorial (absolute-beginner friendly!) by 小土堆.

See the end of this post for the dataset and code.

1. Loading the data

Here we resize the images to 224x224.

import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms, models

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# transforms: augmentation for training, plain resize + normalize for validation
transform_train = transforms.Compose([transforms.Resize((224, 224)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                                      transforms.ColorJitter(brightness=1, contrast=1, saturation=1),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])

transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])
root_train = 'ants_and_bees/train'
root_val = 'ants_and_bees/val'

# use the augmented transform for training, the plain one for validation
training_dataset = datasets.ImageFolder(root=root_train, transform=transform_train)
validation_dataset = datasets.ImageFolder(root=root_val, transform=transform)
classes = training_dataset.classes  # class names inferred from the folder names
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=20, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=20, shuffle=False)
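ImageFolder infers the class labels from the directory layout, so the data is expected to be organized like this (folder names assumed from the standard ants/bees dataset):

ants_and_bees/
  train/
    ants/   <- one subfolder per class; the folder name becomes the label
    bees/
  val/
    ants/
    bees/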
2. Displaying the data
dataiter = iter(training_loader)
images, labels = next(dataiter)  # dataiter.next() was removed in newer PyTorch
fig = plt.figure(figsize=(25, 6))

def im_convert(tensor):
  image = tensor.cpu().clone().detach().numpy()
  image = image.transpose(1, 2, 0)  # CHW -> HWC
  # de-normalize: multiply by std and add mean
  image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
  image = image.clip(0, 1)
  return image

for idx in np.arange(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  plt.imshow(im_convert(images[idx]))
  ax.set_title(classes[labels[idx].item()])
plt.show()
3. Fine-tuning AlexNet
model = models.alexnet(pretrained=True)
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
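A side note: recent torchvision releases (0.13+) deprecate the pretrained flag in favor of an explicit weights argument; the equivalent call would be:

from torchvision.models import alexnet, AlexNet_Weights

model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)  # same ImageNet weights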

With transfer learning we reuse the features extracted by the convolutional layers, so we only need to change the last layer's out_features from 1000 to 2.

Because our model only classifies ants and bees, the output should be 2 rather than the 1000 classes specified in AlexNet's output layer. We therefore replace the sixth element of AlexNet's classifier.

for param in model.features.parameters():
  param.requires_grad = False  # freeze the convolutional feature extractor

n_inputs = model.classifier[6].in_features  # 4096
last_layer = nn.Linear(n_inputs, len(classes))
model.classifier[6] = last_layer
model.to(device)

print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=2, bias=True)
  )
)
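A quick check (a sketch, not in the original post) that the freeze behaves as intended, i.e. only the classifier head will be updated during training:

# all features.* parameters are frozen; only classifier.* remain trainable
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)  # expect only classifier.1, classifier.4, classifier.6 weights/biases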
4. Training and testing the model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

epochs = 5
losses = []
accuracy = []
val_losses = []
val_accuracies = []

for e in range(epochs):
  running_loss = 0.0
  running_accuracy = 0.0
  val_running_loss = 0.0
  val_accuracy = 0.0

  model.train()  # enable dropout during training
  for images, labels in training_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    _, preds = torch.max(outputs, 1)
    running_accuracy += torch.sum(preds == labels.data)
    running_loss += loss.item()

  # validation runs once per epoch, after the training loop;
  # no gradients are needed for the validation set, and dropout is switched off
  model.eval()
  with torch.no_grad():
    for val_images, val_labels in validation_loader:
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)
      val_outputs = model(val_images)
      loss = criterion(val_outputs, val_labels)  # don't shadow the accumulator

      _, val_preds = torch.max(val_outputs, 1)
      val_accuracy += torch.sum(val_preds == val_labels.data)
      val_running_loss += loss.item()

  # metrics for training data
  epoch_loss = running_loss/len(training_loader.dataset)
  epoch_accuracy = running_accuracy.float()/len(training_loader.dataset)
  losses.append(epoch_loss)
  accuracy.append(epoch_accuracy.item())
  # metrics for validation data
  val_epoch_loss = val_running_loss/len(validation_loader.dataset)
  val_epoch_accuracy = val_accuracy.float()/len(validation_loader.dataset)
  val_losses.append(val_epoch_loss)
  val_accuracies.append(val_epoch_accuracy.item())
  # print the training and validation metrics
  print("epoch:", e+1)
  print('training loss: {:.6f}, acc {:.6f}'.format(epoch_loss, epoch_accuracy.item()))
  print('validation loss: {:.6f}, acc {:.6f}'.format(val_epoch_loss, val_epoch_accuracy.item()))
plt.plot(losses, label='training loss')
plt.plot(val_losses, label='validation loss')
plt.legend()
plt.show()
plt.plot(accuracy, label='training accuracy')
plt.plot(val_accuracies, label='validation accuracy')
plt.legend()
plt.show()
dataiter = iter(validation_loader)
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)
output = model(images)
_, preds = torch.max(output, 1)

fig = plt.figure(figsize=(25, 4))

for idx in np.arange(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  plt.imshow(im_convert(images[idx]))
  # title shows "prediction (true label)"; green if correct, red otherwise
  ax.set_title("{} ({})".format(str(classes[preds[idx].item()]), str(classes[labels[idx].item()])), color=("green" if preds[idx]==labels[idx] else "red"))

plt.show()
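Finally, a common follow-up (a sketch, not part of the original post; the file name is illustrative) is to save the fine-tuned weights for later inference:

# save only the state_dict, the usual PyTorch convention
torch.save(model.state_dict(), 'alexnet_ants_bees.pth')

# to reuse it later, rebuild the same modified architecture, then:
# model.load_state_dict(torch.load('alexnet_ants_bees.pth', map_location=device))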

The official PyTorch documentation for AlexNet:

          • https://pytorch.org/hub/pytorch_vision_alexnet/

Code and data download:

Link: https://pan.baidu.com/s/1KKcl4I97kIcv83HLZVoHDg  Extraction code: tun1
