U-Net模型PyTorch實現(xiàn)【含代碼+視頻】
來源:投稿 作者:卷舒 編輯:學(xué)姐
模型總覽
編碼器結(jié)構(gòu) 解碼器結(jié)構(gòu) 輸入與輸出 代碼復(fù)現(xiàn)
Conv Block DownSample UpSample U-Net模型 Reference
前面說了過多的理論知識,可能有些乏味?,F(xiàn)在我們來通過PyTorch來復(fù)現(xiàn)U-Net
模型總覽

如上圖(藍色方塊上方顯示的是通道數(shù),左下角顯示的是數(shù)據(jù)的高寬)所示,U-Net的模型結(jié)構(gòu)符合我們前面說的編碼器/解碼器結(jié)構(gòu) (Encoder/Decoder structure)
左邊的contracting path就是編碼器,從圖片提取出特征;右邊的expansive path就是解碼器。
編碼器結(jié)構(gòu)
左邊的編碼器和典型的卷積網(wǎng)絡(luò)結(jié)構(gòu)相似,它由兩個3×3沒有填充的卷積操作和2×2步長為2的max pooling不斷重復(fù)組成。并且每個卷積操作后面都有一個ReLU激活函數(shù)。
由于3×3卷積操作沒有進行padding,所以每次卷積操作之后數(shù)據(jù)的寬高都會減少(k-1),k是卷積核的大小。如圖,最初是的輸入數(shù)據(jù)的寬高為572×572,經(jīng)過一次3×3沒有填充的卷積之后變成了570×570。
在每次max pooling的下采樣中,數(shù)據(jù)的通道數(shù)會翻倍,但是寬高變?yōu)?span style="cursor:pointer;"> 表示輸入形狀,k是卷積核大小,s是步長。將k與s帶入,可以知道,每次下采樣數(shù)據(jù)的高寬都會減半。
解碼器結(jié)構(gòu)
右邊的解碼器與編碼器相比有兩點差異。
其一,編碼器中max pooling的下采樣改成了步長為2的 2×2 的轉(zhuǎn)置卷積來進行上采樣。這里數(shù)據(jù)的通道數(shù)會減半,同時數(shù)據(jù)的寬高都會變?yōu)?span style="cursor:pointer;">。這里s步長,表示輸入形狀,k是卷積核大小。將k與 s 帶入,可以得知,每次上采樣數(shù)據(jù)的高寬都會翻倍 。
其二,在每次上采樣之后有一個名為skip connection的操作,即圖中的copy and crop。即將左側(cè)對應(yīng)的特征圖與上采樣的輸出進行concatenation。
注意:
這里由于padding、stride與kernel size的選擇,每次卷積操作,邊界像素都會有損失。所以左側(cè)的特征圖高寬是大于右側(cè)對應(yīng)特征圖的,所以這里論文中對左側(cè)特征圖先進行了crop,然后再與右側(cè)特征圖進行連接。而最后輸出結(jié)果的形狀遠小于輸入數(shù)據(jù)形狀的原因也是因為卷積操作中邊界像素的損失。
[同時,你也可以考慮對解碼器的特征圖做線性插值或者padding操作后再進行concatenation。或者在每次卷積操作中加入為1的padding,即可使卷積操作不損失邊界且左右編碼器解碼器對應(yīng)的特征圖高寬一致(但是由于四次下采樣每次數(shù)據(jù)高寬都減半,所以使用這種方法需要確保模型輸入數(shù)據(jù)高寬是$2^4$的倍數(shù))]
輸入與輸出
U-Net論文中的數(shù)據(jù)是單通道的灰度圖,所以輸入數(shù)據(jù)的通道數(shù)為1(如果是RGB圖像即為3)輸入后經(jīng)過第一個卷積操作直接轉(zhuǎn)換成了64通道的特征圖,與后面的通道數(shù)翻倍增加不同。
最后得到的輸出會經(jīng)過1×1的卷積操作將64通道的特征圖映射成所需的類別數(shù)。
代碼復(fù)現(xiàn)
如圖所示,U-Net主要由連續(xù)的兩個conv 3×3 + ReLu,copy and crop,max pool下采樣,up-conv轉(zhuǎn)置卷積上采樣和conv 1×1組成。
下面我們將分別實現(xiàn)連續(xù)的兩個conv3×3+ReLu,下采樣和上采樣。
首先,我們導(dǎo)入必要的庫
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
Conv Block
這里實現(xiàn)連續(xù)的兩個conv3×3+ReLu
class conv_block(nn.Module):
def __init__(self, in_channels, out_channels, padding=0):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=1,padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1,padding=padding),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.conv(x)
return x
DownSample
這里的下采樣包括max pool下采樣和連續(xù)的兩個conv3×3+ReLu。
class DownSample(nn.Module):
def __init__(self, in_channels, out_channels, padding=0):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
conv_block(in_channels, out_channels, padding=padding)
)
def forward(self, x):
return self.maxpool_conv(x)
UpSample
這里的上采樣包括轉(zhuǎn)置卷積上采樣,并與左側(cè)對應(yīng)編碼器的特征圖concatenation。之后進行連續(xù)的兩個conv3×3+ReLu。
class UpSample(nn.Module):
def __init__(self, in_channels, out_channels, concat=0):
super().__init__()
"""
concat=0 -> do center crop
concat=1 -> padding decoder feature map
concat=2 -> padding=1 in conv_block
"""
self.concat = concat
if self.concat not in [0, 1, 2]:
raise Exception('concat not in list of [0, 1, 2]')
if self.concat == 2:
padding = 1
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = conv_block(in_channels, out_channels, padding=padding)
def forward(self, x, x_copy):
x = self.up(x)
if self.concat == 0:
B, C, H, W = x.shape
x_copy = torchvision.transforms.CenterCrop([H, W])(x_copy)
elif self.concat == 1:
diffY = x_copy.size()[2] - x.size()[2]
diffX = x_copy.size()[3] - x.size()[3]
x = F.pad(x, [
diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2
])
x = torch.cat([x_copy, x], dim=1)
return self.conv(x)
U-Net模型
前面通過PyTorch構(gòu)造了U-Net模型編碼器與解碼器的各個模塊,現(xiàn)在只需要將其拼接在一起就可以組成U-Net模型了。
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, concat=0):
super().__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.concat = concat
if concat == 2:
padding = 1
else:
padding = 0
expansion = 2
inplanes = 64
chns = [inplanes, inplanes*expansion, inplanes*expansion**2, inplanes*expansion**3, inplanes*expansion**4]
self.inc = conv_block(n_channels, chns[0], padding)
self.down1 = DownSample(chns[0], chns[1], padding)
self.down2 = DownSample(chns[1], chns[2], padding)
self.down3 = DownSample(chns[2], chns[3], padding)
self.down4 = DownSample(chns[3], chns[4], padding)
self.up1 = UpSample(chns[-1], chns[-2], concat)
self.up2 = UpSample(chns[-2], chns[-3], concat)
self.up3 = UpSample(chns[-3], chns[-4], concat)
self.up4 = UpSample(chns[-4], chns[-5], concat)
self.outc = nn.Conv2d(chns[-5], n_classes, kernel_size=1)
def forward(self, x):
e1 = self.inc(x)
e2 = self.down1(e1)
e3 = self.down2(e2)
e4 = self.down3(e3)
e5 = self.down4(e4)
x = self.up1(e5, e4)
x = self.up2(x, e3)
x = self.up3(x, e2)
x = self.up4(x, e1)
logits = self.outc(x)
return logits
以上就是U-Net模型PyTorch的實現(xiàn)。
Reference
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.APA
Milesial. “U-Net: Semantic segmentation with PyTorch” https://github.com/milesial/Pytorch-UNet
推薦閱讀
全網(wǎng)最全速查表:Python 機器學(xué)習(xí) 搭建完美的Python 機器學(xué)習(xí)開發(fā)環(huán)境 訓(xùn)練集,驗證集,測試集,交叉驗證
