【深度學習】ResNet——CNN經(jīng)典網(wǎng)絡(luò)模型詳解(pytorch實現(xiàn))
建議大家可以實踐下,代碼都很詳細,有不清楚的地方評論區(qū)見~
1、前言
ResNet(Residual Neural Network)由微軟研究院的Kaiming He等四名華人提出,通過使用ResNet Unit成功訓練出了152層的神經(jīng)網(wǎng)絡(luò),并在ILSVRC2015比賽中取得冠軍,在top5上的錯誤率為3.57%,同時參數(shù)量比VGGNet低,效果非常突出。ResNet的結(jié)構(gòu)可以極快的加速神經(jīng)網(wǎng)絡(luò)的訓練,模型的準確率也有比較大的提升。同時ResNet的推廣性非常好,甚至可以直接用到InceptionNet網(wǎng)絡(luò)中。
下圖是ResNet34層模型的結(jié)構(gòu)簡圖。
2、ResNet詳解
在ResNet網(wǎng)絡(luò)中有如下幾個亮點:
提出residual結(jié)構(gòu)(殘差結(jié)構(gòu)),并搭建超深的網(wǎng)絡(luò)結(jié)構(gòu)(突破1000層) 使用Batch Normalization加速訓練(丟棄dropout)
在ResNet網(wǎng)絡(luò)提出之前,傳統(tǒng)的卷積神經(jīng)網(wǎng)絡(luò)都是通過將一系列卷積層與下采樣層進行堆疊得到的。但是當堆疊到一定網(wǎng)絡(luò)深度時,就會出現(xiàn)兩個問題。
梯度消失或梯度爆炸。 退化問題(degradation problem)。
在ResNet論文中說通過數(shù)據(jù)的預處理以及在網(wǎng)絡(luò)中使用BN(Batch Normalization)層能夠解決梯度消失或者梯度爆炸問題。如果不了解BN層可參考這個鏈接。但是對于退化問題(隨著網(wǎng)絡(luò)層數(shù)的加深,效果還會變差,如下圖所示)并沒有很好的解決辦法。
所以ResNet論文提出了residual結(jié)構(gòu)(殘差結(jié)構(gòu))來減輕退化問題。下圖是使用residual結(jié)構(gòu)的卷積網(wǎng)絡(luò),可以看到隨著網(wǎng)絡(luò)的不斷加深,效果并沒有變差,反而變的更好了。
殘差結(jié)構(gòu)(residual)
殘差指的是什么?其中ResNet提出了兩種mapping:一種是identity mapping,指的就是下圖中”彎彎的曲線”,另一種residual mapping,指的就是除了”彎彎的曲線“那部分,所以最后的輸出是 y=F(x)+x
identity mapping
顧名思義,就是指本身,也就是公式中的x,而residual mapping指的是“差”,也就是y?x,所以殘差指的就是F(x)部分。
下圖是論文中給出的兩種殘差結(jié)構(gòu)。左邊的殘差結(jié)構(gòu)是針對層數(shù)較少網(wǎng)絡(luò),例如ResNet18層和ResNet34層網(wǎng)絡(luò)。右邊是針對網(wǎng)絡(luò)層數(shù)較多的網(wǎng)絡(luò),例如ResNet101,ResNet152等。為什么深層網(wǎng)絡(luò)要使用右側(cè)的殘差結(jié)構(gòu)呢。因為,右側(cè)的殘差結(jié)構(gòu)能夠減少網(wǎng)絡(luò)參數(shù)與運算量。同樣輸入一個channel為256的特征矩陣,如果使用左側(cè)的殘差結(jié)構(gòu)需要大約1170648個參數(shù),但如果使用右側(cè)的殘差結(jié)構(gòu)只需要69632個參數(shù)。明顯搭建深層網(wǎng)絡(luò)時,使用右側(cè)的殘差結(jié)構(gòu)更合適。
我們先對左側(cè)的殘差結(jié)構(gòu)(針對ResNet18/34)進行一個分析。
如下圖所示,該殘差結(jié)構(gòu)的主分支是由兩層3x3的卷積層組成,而殘差結(jié)構(gòu)右側(cè)的連接線是shortcut分支也稱捷徑分支(注意為了讓主分支上的輸出矩陣能夠與我們捷徑分支上的輸出矩陣進行相加,必須保證這兩個輸出特征矩陣有相同的shape)。如果剛剛仔細觀察了ResNet34網(wǎng)絡(luò)結(jié)構(gòu)圖的同學,應該能夠發(fā)現(xiàn)圖中會有一些虛線的殘差結(jié)構(gòu)。在原論文中作者只是簡單說了這些虛線殘差結(jié)構(gòu)有降維的作用,并在捷徑分支上通過1x1的卷積核進行降維處理。而下圖右側(cè)給出了詳細的虛線殘差結(jié)構(gòu),注意下每個卷積層的步距stride,以及捷徑分支上的卷積核的個數(shù)(與主分支上的卷積核個數(shù)相同)。
接著我們再來分析下針對ResNet50/101/152的殘差結(jié)構(gòu),如下圖所示。在該殘差結(jié)構(gòu)當中,主分支使用了三個卷積層,第一個是1x1的卷積層用來壓縮channel維度,第二個是3x3的卷積層,第三個是1x1的卷積層用來還原channel維度(注意主分支上第一層卷積層和第二次卷積層所使用的卷積核個數(shù)是相同的,第三次是第一層的4倍)。該殘差結(jié)構(gòu)所對應的虛線殘差結(jié)構(gòu)如下圖右側(cè)所示,同樣在捷徑分支上有一層1x1的卷積層,它的卷積核個數(shù)與主分支上的第三層卷積層卷積核個數(shù)相同,注意每個卷積層的步距。
為什么殘差學習相對更容易,從直觀上看殘差學習需要學習的內(nèi)容少,因為殘差一般會比較小,學習難度小點。不過我們可以從數(shù)學的角度來分析這個問題,首先殘差單元可以表示為:
其中 XL和 XL+1分別表示的是第L個殘差單元的輸入和輸出,注意每個殘差單元一般包含多層結(jié)構(gòu)。F是殘差函數(shù),表示學習到的殘差,而 h(XL)=XL表示恒等映射, F是ReLU激活函數(shù)?;谏鲜?,我們求得從淺層 l到深層 L 的學習特征為:
式子的第一個因子表示的損失函數(shù)到達L的梯度,小括號中的1表明短路機制可以無損地傳播梯度,而另外一項殘差梯度則需要經(jīng)過帶有weights的層,梯度不是直接傳遞過來的。殘差梯度不會那么巧全為-1,而且就算其比較小,有1的存在也不會導致梯度消失。所以殘差學習會更容易。要注意上面的推導并不是嚴格的證明。
下面這幅圖是原論文給出的不同深度的ResNet網(wǎng)絡(luò)結(jié)構(gòu)配置,注意表中的殘差結(jié)構(gòu)給出了主分支上卷積核的大小與卷積核個數(shù),表中的xN表示將該殘差結(jié)構(gòu)重復N次。那到底哪些殘差結(jié)構(gòu)是虛線殘差結(jié)構(gòu)呢。
對于我們ResNet18/34/50/101/152,表中conv3_x, conv4_x, conv5_x所對應的一系列殘差結(jié)構(gòu)的第一層殘差結(jié)構(gòu)都是虛線殘差結(jié)構(gòu)。因為這一系列殘差結(jié)構(gòu)的第一層都有調(diào)整輸入特征矩陣shape的使命(將特征矩陣的高和寬縮減為原來的一半,將深度channel調(diào)整成下一層殘差結(jié)構(gòu)所需要的channel)。為了方便理解,下面給出了ResNet34的網(wǎng)絡(luò)結(jié)構(gòu)圖,圖中簡單標注了一些信息。
對于我們ResNet50/101/152,其實在conv2_x所對應的一系列殘差結(jié)構(gòu)的第一層也是虛線殘差結(jié)構(gòu)。因為它需要調(diào)整輸入特征矩陣的channel,根據(jù)表格可知通過3x3的max pool之后輸出的特征矩陣shape應該是[56, 56, 64],但我們conv2_x所對應的一系列殘差結(jié)構(gòu)中的實線殘差結(jié)構(gòu)它們期望的輸入特征矩陣shape是[56, 56, 256](因為這樣才能保證輸入輸出特征矩陣shape相同,才能將捷徑分支的輸出與主分支的輸出進行相加)。所以第一層殘差結(jié)構(gòu)需要將shape從[56, 56, 64] --> [56, 56, 256]。注意,這里只調(diào)整channel維度,高和寬不變(而conv3_x, conv4_x, conv5_x所對應的一系列殘差結(jié)構(gòu)的第一層虛線殘差結(jié)構(gòu)不僅要調(diào)整channel還要將高和寬縮減為原來的一半)。
代碼
注:
本次訓練集下載在AlexNet博客有詳細解說:https://blog.csdn.net/weixin_44023658/article/details/105798326 使用遷移學習方法實現(xiàn)收錄在我的這篇blog中:遷移學習 TransferLearning—通俗易懂地介紹(pytorch實例)
#model.py
import?torch.nn?as?nn
import?torch
#18/34
class?BasicBlock(nn.Module):
????expansion?=?1?#每一個conv的卷積核個數(shù)的倍數(shù)
????def?__init__(self,?in_channel,?out_channel,?stride=1,?downsample=None):#downsample對應虛線殘差結(jié)構(gòu)
????????super(BasicBlock,?self).__init__()
????????self.conv1?=?nn.Conv2d(in_channels=in_channel,?out_channels=out_channel,
???????????????????????????????kernel_size=3,?stride=stride,?padding=1,?bias=False)
????????self.bn1?=?nn.BatchNorm2d(out_channel)#BN處理
????????self.relu?=?nn.ReLU()
????????self.conv2?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel,
???????????????????????????????kernel_size=3,?stride=1,?padding=1,?bias=False)
????????self.bn2?=?nn.BatchNorm2d(out_channel)
????????self.downsample?=?downsample
????def?forward(self,?x):
????????identity?=?x?#捷徑上的輸出值
????????if?self.downsample?is?not?None:
????????????identity?=?self.downsample(x)
????????out?=?self.conv1(x)
????????out?=?self.bn1(out)
????????out?=?self.relu(out)
????????out?=?self.conv2(out)
????????out?=?self.bn2(out)
????????out?+=?identity
????????out?=?self.relu(out)
????????return?out
#50,101,152
class?Bottleneck(nn.Module):
????expansion?=?4#4倍
????def?__init__(self,?in_channel,?out_channel,?stride=1,?downsample=None):
????????super(Bottleneck,?self).__init__()
????????self.conv1?=?nn.Conv2d(in_channels=in_channel,?out_channels=out_channel,
???????????????????????????????kernel_size=1,?stride=1,?bias=False)??#?squeeze?channels
????????self.bn1?=?nn.BatchNorm2d(out_channel)
????????self.relu?=?nn.ReLU(inplace=True)
????????#?-----------------------------------------
????????self.conv2?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel,
???????????????????????????????kernel_size=3,?stride=stride,?bias=False,?padding=1)
????????self.bn2?=?nn.BatchNorm2d(out_channel)
????????self.relu?=?nn.ReLU(inplace=True)
????????#?-----------------------------------------
????????self.conv3?=?nn.Conv2d(in_channels=out_channel,?out_channels=out_channel*self.expansion,#輸出*4
???????????????????????????????kernel_size=1,?stride=1,?bias=False)??#?unsqueeze?channels
????????self.bn3?=?nn.BatchNorm2d(out_channel*self.expansion)
????????self.relu?=?nn.ReLU(inplace=True)
????????self.downsample?=?downsample
????def?forward(self,?x):
????????identity?=?x
????????if?self.downsample?is?not?None:
????????????identity?=?self.downsample(x)
????????out?=?self.conv1(x)
????????out?=?self.bn1(out)
????????out?=?self.relu(out)
????????out?=?self.conv2(out)
????????out?=?self.bn2(out)
????????out?=?self.relu(out)
????????out?=?self.conv3(out)
????????out?=?self.bn3(out)
????????out?+=?identity
????????out?=?self.relu(out)
????????return?out
class?ResNet(nn.Module):
????def?__init__(self,?block,?blocks_num,?num_classes=1000,?include_top=True):#block殘差結(jié)構(gòu)?include_top為了之后搭建更加復雜的網(wǎng)絡(luò)
????????super(ResNet,?self).__init__()
????????self.include_top?=?include_top
????????self.in_channel?=?64
????????self.conv1?=?nn.Conv2d(3,?self.in_channel,?kernel_size=7,?stride=2,
???????????????????????????????padding=3,?bias=False)
????????self.bn1?=?nn.BatchNorm2d(self.in_channel)
????????self.relu?=?nn.ReLU(inplace=True)
????????self.maxpool?=?nn.MaxPool2d(kernel_size=3,?stride=2,?padding=1)
????????self.layer1?=?self._make_layer(block,?64,?blocks_num[0])
????????self.layer2?=?self._make_layer(block,?128,?blocks_num[1],?stride=2)
????????self.layer3?=?self._make_layer(block,?256,?blocks_num[2],?stride=2)
????????self.layer4?=?self._make_layer(block,?512,?blocks_num[3],?stride=2)
????????if?self.include_top:
????????????self.avgpool?=?nn.AdaptiveAvgPool2d((1,?1))??#?output?size?=?(1,?1)自適應
????????????self.fc?=?nn.Linear(512?*?block.expansion,?num_classes)
????????for?m?in?self.modules():
????????????if?isinstance(m,?nn.Conv2d):
????????????????nn.init.kaiming_normal_(m.weight,?mode='fan_out',?nonlinearity='relu')
????def?_make_layer(self,?block,?channel,?block_num,?stride=1):
????????downsample?=?None
????????if?stride?!=?1?or?self.in_channel?!=?channel?*?block.expansion:
????????????downsample?=?nn.Sequential(
????????????????nn.Conv2d(self.in_channel,?channel?*?block.expansion,?kernel_size=1,?stride=stride,?bias=False),
????????????????nn.BatchNorm2d(channel?*?block.expansion))
????????layers?=?[]
????????layers.append(block(self.in_channel,?channel,?downsample=downsample,?stride=stride))
????????self.in_channel?=?channel?*?block.expansion
????????for?_?in?range(1,?block_num):
????????????layers.append(block(self.in_channel,?channel))
????????return?nn.Sequential(*layers)
????def?forward(self,?x):
????????x?=?self.conv1(x)
????????x?=?self.bn1(x)
????????x?=?self.relu(x)
????????x?=?self.maxpool(x)
????????x?=?self.layer1(x)
????????x?=?self.layer2(x)
????????x?=?self.layer3(x)
????????x?=?self.layer4(x)
????????if?self.include_top:
????????????x?=?self.avgpool(x)
????????????x?=?torch.flatten(x,?1)
????????????x?=?self.fc(x)
????????return?x
def?resnet34(num_classes=1000,?include_top=True):
????return?ResNet(BasicBlock,?[3,?4,?6,?3],?num_classes=num_classes,?include_top=include_top)
def?resnet101(num_classes=1000,?include_top=True):
????return?ResNet(Bottleneck,?[3,?4,?23,?3],?num_classes=num_classes,?include_top=include_top)
#train.py
import?torch
import?torch.nn?as?nn
from?torchvision?import?transforms,?datasets
import?json
import?matplotlib.pyplot?as?plt
import?os
import?torch.optim?as?optim
from?model?import?resnet34,?resnet101
import?torchvision.models.resnet
device?=?torch.device("cuda:0"?if?torch.cuda.is_available()?else?"cpu")
print(device)
data_transform?=?{
????"train":?transforms.Compose([transforms.RandomResizedCrop(224),
?????????????????????????????????transforms.RandomHorizontalFlip(),
?????????????????????????????????transforms.ToTensor(),
?????????????????????????????????transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])]),#來自官網(wǎng)參數(shù)
????"val":?transforms.Compose([transforms.Resize(256),#將最小邊長縮放到256
???????????????????????????????transforms.CenterCrop(224),
???????????????????????????????transforms.ToTensor(),
???????????????????????????????transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])}
data_root?=?os.getcwd()
image_path?=?data_root?+?"/flower_data/"??#?flower?data?set?path
train_dataset?=?datasets.ImageFolder(root=image_path?+?"train",
?????????????????????????????????????transform=data_transform["train"])
train_num?=?len(train_dataset)
#?{'daisy':0,?'dandelion':1,?'roses':2,?'sunflower':3,?'tulips':4}
flower_list?=?train_dataset.class_to_idx
cla_dict?=?dict((val,?key)?for?key,?val?in?flower_list.items())
#?write?dict?into?json?file
json_str?=?json.dumps(cla_dict,?indent=4)
with?open('class_indices.json',?'w')?as?json_file:
????json_file.write(json_str)
batch_size?=?16
train_loader?=?torch.utils.data.DataLoader(train_dataset,
???????????????????????????????????????????batch_size=batch_size,?shuffle=True,
???????????????????????????????????????????num_workers=0)
validate_dataset?=?datasets.ImageFolder(root=image_path?+?"/val",
????????????????????????????????????????transform=data_transform["val"])
val_num?=?len(validate_dataset)
validate_loader?=?torch.utils.data.DataLoader(validate_dataset,
??????????????????????????????????????????????batch_size=batch_size,?shuffle=False,
??????????????????????????????????????????????num_workers=0)
#net?=?resnet34()
net?=?resnet34(num_classes=5)
#?load?pretrain?weights
#?model_weight_path?=?"./resnet34-pre.pth"
#?missing_keys,?unexpected_keys?=?net.load_state_dict(torch.load(model_weight_path),?strict=False)#載入模型參數(shù)
#?for?param?in?net.parameters():
#?????param.requires_grad?=?False
#?change?fc?layer?structure
#?inchannel?=?net.fc.in_features
#?net.fc?=?nn.Linear(inchannel,?5)
net.to(device)
loss_function?=?nn.CrossEntropyLoss()
optimizer?=?optim.Adam(net.parameters(),?lr=0.0001)
best_acc?=?0.0
save_path?=?'./resNet34.pth'
for?epoch?in?range(3):
????#?train
????net.train()
????running_loss?=?0.0
????for?step,?data?in?enumerate(train_loader,?start=0):
????????images,?labels?=?data
????????optimizer.zero_grad()
????????logits?=?net(images.to(device))
????????loss?=?loss_function(logits,?labels.to(device))
????????loss.backward()
????????optimizer.step()
????????#?print?statistics
????????running_loss?+=?loss.item()
????????#?print?train?process
????????rate?=?(step+1)/len(train_loader)
????????a?=?"*"?*?int(rate?*?50)
????????b?=?"."?*?int((1?-?rate)?*?50)
????????print("\rtrain?loss:?{:^3.0f}%[{}->{}]{:.4f}".format(int(rate*100),?a,?b,?loss),?end="")
????print()
????#?validate
????net.eval()
????acc?=?0.0??#?accumulate?accurate?number?/?epoch
????with?torch.no_grad():
????????for?val_data?in?validate_loader:
????????????val_images,?val_labels?=?val_data
????????????outputs?=?net(val_images.to(device))??#?eval?model?only?have?last?output?layer
????????????#?loss?=?loss_function(outputs,?test_labels)
????????????predict_y?=?torch.max(outputs,?dim=1)[1]
????????????acc?+=?(predict_y?==?val_labels.to(device)).sum().item()
????????val_accurate?=?acc?/?val_num
????????if?val_accurate?>?best_acc:
????????????best_acc?=?val_accurate
????????????torch.save(net.state_dict(),?save_path)
????????print('[epoch?%d]?train_loss:?%.3f??test_accuracy:?%.3f'?%
??????????????(epoch?+?1,?running_loss?/?step,?val_accurate))
print('Finished?Training')

#predict.py
import?torch
from?model?import?resnet34
from?PIL?import?Image
from?torchvision?import?transforms
import?matplotlib.pyplot?as?plt
import?json
data_transform?=?transforms.Compose(
????[transforms.Resize(256),
?????transforms.CenterCrop(224),
?????transforms.ToTensor(),
?????transforms.Normalize([0.485,?0.456,?0.406],?[0.229,?0.224,?0.225])])
#?load?image
img?=?Image.open("./roses.jpg")
plt.imshow(img)
#?[N,?C,?H,?W]
img?=?data_transform(img)
#?expand?batch?dimension
img?=?torch.unsqueeze(img,?dim=0)
#?read?class_indict
try:
????json_file?=?open('./class_indices.json',?'r')
????class_indict?=?json.load(json_file)
except?Exception?as?e:
????print(e)
????exit(-1)
#?create?model
model?=?resnet34(num_classes=5)
#?load?model?weights
model_weight_path?=?"./resNet34.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with?torch.no_grad():
????#?predict?class
????output?=?torch.squeeze(model(img))
????predict?=?torch.softmax(output,?dim=0)
????predict_cla?=?torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)],?predict[predict_cla].numpy())
plt.show()

往期精彩回顧
獲取一折本站知識星球優(yōu)惠券,復制鏈接直接打開:
https://t.zsxq.com/662nyZF
本站qq群1003271085。
加入微信群請掃碼進群(如果是博士或者準備讀博士請說明):
