[Beginner's PyTorch Tutorial] 13. Transfer Learning: Fine-tuning AlexNet to Classify Ants and Bees
「@Author:Runsen」
Last time we fine-tuned VGG19; this time we fine-tune AlexNet to classify images of ants and bees.
Over the years many CNN variants have evolved, producing several well-known architectures. The most common are:
LeNet-5 (1998)
AlexNet (2012)
ZFNet (2013)
GoogLeNet / Inception (2014)
VGGNet (2014)
ResNet (2015)
This post is a tutorial on AlexNet, one of the most popular of these CNN architectures.
AlexNet
AlexNet was designed primarily by Alex Krizhevsky. Published together with Ilya Sutskever and Krizhevsky's PhD advisor Geoffrey Hinton, it is a convolutional neural network (CNN).
AlexNet rose to fame after winning the 2012 ImageNet Large Scale Visual Recognition Challenge: it achieved 84.6% top-5 accuracy on the classification task, while the runner-up managed only 73.8%. Because computing power was very limited in 2012, Alex trained the network across 2 GPUs.
(Figure: the AlexNet architecture from the 2012 ImageNet challenge.)
The AlexNet architecture consists of 5 convolutional layers, 3 max-pooling layers, 2 normalization layers, 2 fully connected layers, and 1 softmax layer.
Each convolutional layer consists of convolutional filters followed by the nonlinear ReLU activation.
The pooling layers perform max pooling.
Because of the fully connected layers, the input size is fixed.
The input size is usually quoted as 224x224x3, but once the padding is accounted for it works out to 227x227x3.
In total, AlexNet has about 60 million parameters.
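To see why 227 is the size that makes the arithmetic work, we can check the standard convolution output-size formula, output = floor((n - k + 2p) / s) + 1, with a quick sketch (not part of the tutorial code):

def conv_out(size, kernel, stride, padding=0):
    # output size = floor((n - k + 2*p) / s) + 1
    return (size - kernel + 2 * padding) // stride + 1

print(conv_out(227, kernel=11, stride=4))  # 55 -> Conv1 output is 96 x 55 x 55
print(conv_out(55, kernel=3, stride=2))    # 27 -> 96 x 27 x 27 after max pooling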
Below are the model parameters for the 227x227x3 input:
| Size / Operation | Filter | Depth | Stride | Padding | Number of Parameters | Forward Computation |
|---|---|---|---|---|---|---|
| 3 × 227 × 227 | | | | | | |
| Conv1 + ReLU | 11 × 11 | 96 | 4 | | (11 × 11 × 3 + 1) × 96 = 34,944 | (11 × 11 × 3 + 1) × 96 × 55 × 55 = 105,705,600 |
| 96 × 55 × 55 | | | | | | |
| Max pooling | 3 × 3 | | 2 | | | |
| 96 × 27 × 27 | | | | | | |
| Norm | | | | | | |
| Conv2 + ReLU | 5 × 5 | 256 | 1 | 2 | (5 × 5 × 96 + 1) × 256 = 614,656 | (5 × 5 × 96 + 1) × 256 × 27 × 27 = 448,084,224 |
| 256 × 27 × 27 | | | | | | |
| Max pooling | 3 × 3 | | 2 | | | |
| 256 × 13 × 13 | | | | | | |
| Norm | | | | | | |
| Conv3 + ReLU | 3 × 3 | 384 | 1 | 1 | (3 × 3 × 256 + 1) × 384 = 885,120 | (3 × 3 × 256 + 1) × 384 × 13 × 13 = 149,585,280 |
| 384 × 13 × 13 | | | | | | |
| Conv4 + ReLU | 3 × 3 | 384 | 1 | 1 | (3 × 3 × 384 + 1) × 384 = 1,327,488 | (3 × 3 × 384 + 1) × 384 × 13 × 13 = 224,345,472 |
| 384 × 13 × 13 | | | | | | |
| Conv5 + ReLU | 3 × 3 | 256 | 1 | 1 | (3 × 3 × 384 + 1) × 256 = 884,992 | (3 × 3 × 384 + 1) × 256 × 13 × 13 = 149,563,648 |
| 256 × 13 × 13 | | | | | | |
| Max pooling | 3 × 3 | | 2 | | | |
| 256 × 6 × 6 | | | | | | |
| Dropout (rate 0.5) | | | | | | |
| FC6 + ReLU | | | | | 256 × 6 × 6 × 4096 = 37,748,736 | 256 × 6 × 6 × 4096 = 37,748,736 |
| 4096 | | | | | | |
| Dropout (rate 0.5) | | | | | | |
| FC7 + ReLU | | | | | 4096 × 4096 = 16,777,216 | 4096 × 4096 = 16,777,216 |
| 4096 | | | | | | |
| FC8 + ReLU | | | | | 4096 × 1000 = 4,096,000 | 4096 × 1000 = 4,096,000 |
| 1000 classes | | | | | | |
| Overall | | | | | 62,369,152 ≈ 62.3 million | 1,135,906,176 ≈ 1.1 billion |
| Conv vs. FC | | | | | Conv: 3.7 million (6%), FC: 58.6 million (94%) | Conv: 1.08 billion (95%), FC: 58.6 million (5%) |
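As a sanity check on the table's totals, the per-layer formulas can be summed directly (conv layers count weights plus one bias per filter; the FC rows, as in the table, count weights only):

# parameters of each conv layer: (kh * kw * in_channels + 1) * out_channels
conv_params = [
    (11 * 11 * 3 + 1) * 96,    # Conv1
    (5 * 5 * 96 + 1) * 256,    # Conv2
    (3 * 3 * 256 + 1) * 384,   # Conv3
    (3 * 3 * 384 + 1) * 384,   # Conv4
    (3 * 3 * 384 + 1) * 256,   # Conv5
]
# fully connected layers (weights only, as in the table)
fc_params = [256 * 6 * 6 * 4096, 4096 * 4096, 4096 * 1000]
print(sum(conv_params))                   # 3747200   ~3.7 million (6%)
print(sum(fc_params))                     # 58621952  ~58.6 million (94%)
print(sum(conv_params) + sum(fc_params))  # 62369152  ~62.3 million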
Dataset
This is the ant-and-bee dataset used in the introductory PyTorch tutorials, split into train and val sets of ant and bee images.
Data source: the PyTorch deep learning quick-start tutorial (by 小土堆).
「The dataset and code are linked at the end of this post.」
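datasets.ImageFolder, used below, infers class labels from subfolder names, so the dataset presumably has a layout like this (subfolder names assumed; adjust to whatever the download uses):

ants_and_bees/
├── train/
│   ├── ants/
│   └── bees/
└── val/
    ├── ants/
    └── bees/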
- Loading the data
Here the images are resized to 224×224, AlexNet's expected input size.
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms, models

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# transforms: augmentation for training, plain resizing for validation
# (note: torchvision's pretrained models were trained with ImageNet
# normalization statistics; (0.5, 0.5, 0.5) is kept from the original post)
transform_train = transforms.Compose([transforms.Resize((224, 224)),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                                      transforms.ColorJitter(brightness=1, contrast=1, saturation=1),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ])
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                               ])

root_train = 'ants_and_bees/train'
root_val = 'ants_and_bees/val'
# use the augmented transform for training and the plain one for validation
training_dataset = datasets.ImageFolder(root=root_train, transform=transform_train)
validation_dataset = datasets.ImageFolder(root=root_val, transform=transform)
# class names are inferred from the subfolder names
classes = training_dataset.classes

training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=20, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=20, shuffle=False)
- Visualizing the data
dataiter = iter(training_loader)
images, labels = next(dataiter)
fig = plt.figure(figsize=(25, 6))

def im_convert(tensor):
  image = tensor.cpu().clone().detach().numpy()
  image = image.transpose(1, 2, 0)  # CHW -> HWC, i.e. 224 x 224 x 3
  # de-normalise: multiply by std and add mean
  image = image * np.array((0.5, 0.5, 0.5)) + np.array((0.5, 0.5, 0.5))
  image = image.clip(0, 1)
  return image

for idx in np.arange(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  plt.imshow(im_convert(images[idx]))
  ax.set_title(classes[labels[idx].item()])
plt.show()

- Fine-tuning AlexNet
model = models.alexnet(pretrained=True)
print(model)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
With transfer learning, we reuse the features extracted by the convolutional layers and retrain only the classifier head.
Because our model distinguishes just ants and bees, the output should be 2 classes rather than the 1000 specified in AlexNet's output layer. So we replace element 6 of AlexNet's classifier, changing out_features=1000 to out_features=2.
# freeze the convolutional feature extractor
for param in model.features.parameters():
  param.requires_grad = False

n_inputs = model.classifier[6].in_features  # 4096
# replace the final layer with a 2-class head (ants vs bees)
last_layer = nn.Linear(n_inputs, len(classes))
model.classifier[6] = last_layer
model.to(device)
print(model)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=2, bias=True)
  )
)
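To confirm that only the new classifier head will be trained, a quick check with standard PyTorch attributes (numel counts the elements of each tensor):

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'trainable: {trainable} / total: {total}')  # only the classifier parameters remain trainable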
- Training and testing the model
criterion = nn.CrossEntropyLoss()
# the frozen feature parameters have requires_grad=False, so only the
# classifier head is actually updated
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
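# Optional variant: since the feature extractor is frozen, one can
# equivalently hand Adam only the trainable parameters:
# optimizer = torch.optim.Adam(
#     (p for p in model.parameters() if p.requires_grad), lr=0.0001)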
epochs = 5
losses = []
accuracy = []
val_losses = []
val_accuracies = []

for e in range(epochs):
  running_loss = 0.0
  running_accuracy = 0.0
  val_loss = 0.0
  val_accuracy = 0.0

  model.train()
  for images, labels in training_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = model(images)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    _, preds = torch.max(outputs, 1)
    running_accuracy += torch.sum(preds == labels.data)
    running_loss += loss.item()

  # no gradients are needed for the validation set; eval() also
  # disables the dropout layers
  model.eval()
  with torch.no_grad():
    for val_images, val_labels in validation_loader:
      val_images = val_images.to(device)
      val_labels = val_labels.to(device)
      val_outputs = model(val_images)
      batch_loss = criterion(val_outputs, val_labels)
      _, val_preds = torch.max(val_outputs, 1)
      val_accuracy += torch.sum(val_preds == val_labels.data)
      val_loss += batch_loss.item()

  # metrics for training data
  epoch_loss = running_loss/len(training_loader.dataset)
  epoch_accuracy = running_accuracy.float()/len(training_loader.dataset)
  losses.append(epoch_loss)
  accuracy.append(epoch_accuracy)
  # metrics for validation data
  val_epoch_loss = val_loss/len(validation_loader.dataset)
  val_epoch_accuracy = val_accuracy.float()/len(validation_loader.dataset)
  val_losses.append(val_epoch_loss)
  val_accuracies.append(val_epoch_accuracy)
  # print the training and validation metrics
  print("epoch:", e+1)
  print('training loss: {:.6f}, acc {:.6f}'.format(epoch_loss, epoch_accuracy.item()))
  print('validation loss: {:.6f}, acc {:.6f}'.format(val_epoch_loss, val_epoch_accuracy.item()))

plt.plot(losses, label='training loss')
plt.plot(val_losses, label='validation loss')
plt.legend()
plt.show()

plt.plot(accuracy, label='training accuracy')
plt.plot(val_accuracies, label='validation accuracy')
plt.legend()
plt.show()

dataiter = iter(validation_loader)
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)
with torch.no_grad():
  output = model(images)
_, preds = torch.max(output, 1)

fig = plt.figure(figsize=(25, 4))
for idx in np.arange(20):
  ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
  plt.imshow(im_convert(images[idx]))
  # title shows "prediction (true label)", green if correct, red if wrong
  ax.set_title("{} ({})".format(str(classes[preds[idx].item()]), str(classes[labels[idx].item()])),
               color=("green" if preds[idx] == labels[idx] else "red"))
plt.show()
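As a final usage sketch, the fine-tuned model can classify a single image file; the path below is hypothetical, so point it at any image from the val folder:

from PIL import Image

img = Image.open('ants_and_bees/val/ants/example.jpg').convert('RGB')  # hypothetical path
x = transform(img).unsqueeze(0).to(device)  # preprocess and add a batch dimension
model.eval()
with torch.no_grad():
    pred = model(x).argmax(dim=1).item()
print(classes[pred])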

Official PyTorch documentation for AlexNet:
- https://pytorch.org/hub/pytorch_vision_alexnet/
Code and data download:
Link: https://pan.baidu.com/s/1KKcl4I97kIcv83HLZVoHDg (extraction code: tun1)
