深度學習論文精讀[2]:UNet網(wǎng)絡(luò)
FCN雖然做出了開創(chuàng)性的工作,F(xiàn)CN-8s相較于此前的SOTA分割表現(xiàn),已經(jīng)取得了巨大的優(yōu)勢。但從分割效果上看還很粗糙,對圖像的細節(jié)處理還很不成熟,也沒有考慮到像素與像素之間的上下文(context)關(guān)系,所以FCN更像是一項拋磚引玉式的工作,隨著U形的編解碼結(jié)構(gòu)成為通用的語義分割網(wǎng)絡(luò)設(shè)計范式,各種網(wǎng)絡(luò)如雨后春筍般涌現(xiàn)。UNet是U形網(wǎng)絡(luò)結(jié)構(gòu)最經(jīng)典和最主要的代表網(wǎng)絡(luò),因其網(wǎng)絡(luò)結(jié)構(gòu)是一個U形而得名,這類編解碼的結(jié)構(gòu)也因而被稱之為U形結(jié)構(gòu)。提出UNet的論文為U-Net: Convolutional Networks for Biomedical Image Segmentation,與FCN提出時間相差了兩個月,其結(jié)構(gòu)設(shè)計在FCN基礎(chǔ)上做了進一步的改進,設(shè)計初衷主要是用于醫(yī)學圖像的分割。截至到本書寫稿,UNet在谷歌學術(shù)上的引用次數(shù)已達44772次,堪稱深度學習語義分割領(lǐng)域的里程碑式的工作。

在醫(yī)學圖像領(lǐng)域,具體到更加細分的醫(yī)學圖像識別任務(wù)時,大量的帶有高質(zhì)量標注的圖像數(shù)據(jù)十分難得,在此之前的通常做法是采用滑動窗口卷積(類似于圖像分塊)的方式來進行圖像局部預測,這么做的好處是可以做圖像像素做到一定程度定位,其次就是滑窗分塊能夠使得訓練樣本量增多。但缺點也很明顯,一個是滑窗操作非常耗時,推理的時候效率低下,其次就是不能兼顧定位精度和像素上下文信息的利用率。UNet在FCN的基礎(chǔ)上,完整地給出了U形的編解碼結(jié)構(gòu),如下圖所示。

UNet結(jié)構(gòu)包括編碼器下采樣、解碼器上采樣和同層跳躍連接三個組成部分。編碼器由4組卷積、ReLU激活和最大池化構(gòu)成,每一組均有兩次3*3的卷積,每個卷積層后面都有一次ReLU激活函數(shù),然后再進行一次步長為2的2*2最大池化進行下采樣,如第一組操作輸入圖像大小為572*572,兩輪3*3的卷積之后的特征圖大小為568*568,再經(jīng)過22最大池化后的輸出尺寸為284*284。解碼器由4組2*2轉(zhuǎn)置卷積、3*3卷積構(gòu)成和一個ReLU激活函數(shù)構(gòu)成,在最后的輸出層又補充了一個1*1卷積。最后是同層跳躍連接,這也是UNet的特色操作之一,指的是將下采樣時每一層的輸出裁剪后連接到同層的上采樣層做融合。每一次下采樣都會有一個跳躍連接與對應的上采樣進行融合,這種不同尺度的特征融合對上采樣恢復像素大有幫助,具體來說就是高層(淺層)下采樣倍數(shù)小,特征圖具備更加細致的圖特征,低層(深層)下采樣倍數(shù)大,信息經(jīng)過大量濃縮,空間損失大,但有助于目標區(qū)域(分類)判斷,當高層和低層的特征進行融合時,分割效果往往會非常好。從某種程度上講,這種跳躍連接也可以視為一種深度監(jiān)督。
我們將UNet結(jié)構(gòu)按照編碼器、解碼器和同層跳躍連接進行簡化,如下圖所示。編碼器下采樣用于特征提取和語義信息濃縮,解碼器上采樣用于圖像像素恢復,跳躍連接則用于信息補充。自此,基于U形結(jié)構(gòu)的編解碼設(shè)計成為深度學習語義分割中的奠基性的網(wǎng)絡(luò)結(jié)構(gòu),經(jīng)過近幾年的發(fā)展,語義分割雖然取得了長足的進步,但UNet和編解碼結(jié)構(gòu)一直是新的模型設(shè)計的參照對象。

下述代碼給出了UNet結(jié)構(gòu)的一個簡易實現(xiàn)版本。我們先分別搭建了包含卷積和ReLU的編碼塊和解碼塊,然后在編解碼塊的基礎(chǔ)上搭建完整的UNet結(jié)構(gòu),在前向計算流程中補充同層跳躍連接。
# 導入PyTorch相關(guān)模塊import torchimport torch.nn as nnimport torch.nn.functional as F### 編碼塊class UNetEnc(nn.Module):def __init__(self, in_channels, out_channels, dropout=False):super().__init__()# 每一個編碼塊中的結(jié)構(gòu)layers = [nn.Conv2d(in_channels, out_channels, 3, dilation=2),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, dilation=2),nn.ReLU(inplace=True),]if dropout:layers += [nn.Dropout(.5)]layers += [nn.MaxPool2d(2, stride=2, ceil_mode=True)]self.down = nn.Sequential(*layers)# 編碼塊前向計算流程def forward(self, x):return self.down(x)### 解碼塊class UNetDec(nn.Module):def __init__(self, in_channels, features, out_channels):super().__init__()# 每一個解碼塊中的結(jié)構(gòu)self.up = nn.Sequential(nn.Conv2d(in_channels, features, 3),nn.ReLU(inplace=True),nn.Conv2d(features, features, 3),nn.ReLU(inplace=True),nn.ConvTranspose2d(features, out_channels, 2, stride=2),nn.ReLU(inplace=True),)# 解碼塊前向計算流程def forward(self, x):return self.up(x)### 基于編解碼的U-Netclass UNet(nn.Module):def __init__(self, num_classes):super().__init__()# 四個編碼塊self.enc1 = UNetEnc(3, 64)self.enc2 = UNetEnc(64, 128)self.enc3 = UNetEnc(128, 256)self.enc4 = UNetEnc(256, 512, dropout=True)# 中間部分(U形底部)self.center = nn.Sequential(nn.Conv2d(512, 1024, 3),nn.ReLU(inplace=True),nn.Conv2d(1024, 1024, 3),nn.ReLU(inplace=True),nn.Dropout(),nn.ConvTranspose2d(1024, 512, 2, stride=2),nn.ReLU(inplace=True),)# 四個解碼塊self.dec4 = UNetDec(1024, 512, 256)self.dec3 = UNetDec(512, 256, 128)self.dec2 = UNetDec(256, 128, 64)self.dec1 = nn.Sequential(nn.Conv2d(128, 64, 3),nn.ReLU(inplace=True),nn.Conv2d(64, 64, 3),nn.ReLU(inplace=True),)self.final = nn.Conv2d(64, num_classes, 1)# 前向傳播過程def forward(self, x):enc1 = self.enc1(x)enc2 = self.enc2(enc1)enc3 = self.enc3(enc2)enc4 = self.enc4(enc3)center = self.center(enc4)# 包含了同層分辨率級聯(lián)的解碼塊dec4 = self.dec4(torch.cat([center, F.upsample_bilinear(enc4, center.size()[2:])], 1))dec3 = self.dec3(torch.cat([dec4, F.upsample_bilinear(enc3, dec4.size()[2:])], 1))dec2 = self.dec2(torch.cat([dec3, F.upsample_bilinear(enc2, dec3.size()[2:])], 1))dec1 = self.dec1(torch.cat([dec2, F.upsample_bilinear(enc1, dec2.size()[2:])], 1))return F.upsample_bilinear(self.final(dec1), x.size()[2:])
往期精彩:
深度學習論文精讀[1]:FCN全卷積網(wǎng)絡(luò)
講解視頻來了!機器學習 公式推導與代碼實現(xiàn)開錄!
完結(jié)!《機器學習 公式推導與代碼實現(xiàn)》全書1-26章PPT下載
