<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          U-Net模型PyTorch實現(xiàn)【含代碼+視頻】

          共 10334字,需瀏覽 21分鐘

           ·

          2022-11-22 18:02

          來源:投稿 作者:卷舒 編輯:學(xué)姐

          • 模型總覽

            1. 編碼器結(jié)構(gòu)
            2. 解碼器結(jié)構(gòu)
            3. 輸入與輸出
          • 代碼復(fù)現(xiàn)

            1. Conv Block
            2. DownSample
            3. UpSample
            4. U-Net模型
          • Reference

          前面說了過多的理論知識,可能有些乏味?,F(xiàn)在我們來通過PyTorch來復(fù)現(xiàn)U-Net

          模型總覽

          如上圖(藍色方塊上方顯示的是通道數(shù),左下角顯示的是數(shù)據(jù)的高寬)所示,U-Net的模型結(jié)構(gòu)符合我們前面說的編碼器/解碼器結(jié)構(gòu) (Encoder/Decoder structure)

          左邊的contracting path就是編碼器,從圖片提取出特征;右邊的expansive path就是解碼器。

          編碼器結(jié)構(gòu)

          左邊的編碼器和典型的卷積網(wǎng)絡(luò)結(jié)構(gòu)相似,它由兩個3×3沒有填充的卷積操作和2×2步長為2的max pooling不斷重復(fù)組成。并且每個卷積操作后面都有一個ReLU激活函數(shù)。

          由于3×3卷積操作沒有進行padding,所以每次卷積操作之后數(shù)據(jù)的寬高都會減少(k-1),k是卷積核的大小。如圖,最初是的輸入數(shù)據(jù)的寬高為572×572,經(jīng)過一次3×3沒有填充的卷積之后變成了570×570。

          在每次max pooling的下采樣中,數(shù)據(jù)的通道數(shù)會翻倍,但是寬高變?yōu)?span style="cursor:pointer;"> 表示輸入形狀,k是卷積核大小,s是步長。將k與s帶入,可以知道,每次下采樣數(shù)據(jù)的高寬都會減半。

          解碼器結(jié)構(gòu)

          右邊的解碼器與編碼器相比有兩點差異。

          • 其一,編碼器中max pooling的下采樣改成了步長為2的 2×2 的轉(zhuǎn)置卷積來進行上采樣。這里數(shù)據(jù)的通道數(shù)會減半,同時數(shù)據(jù)的寬高都會變?yōu)?span style="cursor:pointer;">。這里s步長,表示輸入形狀,k是卷積核大小。將k與 s 帶入,可以得知,每次上采樣數(shù)據(jù)的高寬都會翻倍 。

          • 其二,在每次上采樣之后有一個名為skip connection的操作,即圖中的copy and crop。即將左側(cè)對應(yīng)的特征圖與上采樣的輸出進行concatenation。

          注意:

          這里由于paddingstridekernel size的選擇,每次卷積操作,邊界像素都會有損失。所以左側(cè)的特征圖高寬是大于右側(cè)對應(yīng)特征圖的,所以這里論文中對左側(cè)特征圖先進行了crop,然后再與右側(cè)特征圖進行連接。而最后輸出結(jié)果的形狀遠小于輸入數(shù)據(jù)形狀的原因也是因為卷積操作中邊界像素的損失。

          [同時,你也可以考慮對解碼器的特征圖做線性插值或者padding操作后再進行concatenation。或者在每次卷積操作中加入為1的padding,即可使卷積操作不損失邊界且左右編碼器解碼器對應(yīng)的特征圖高寬一致(但是由于四次下采樣每次數(shù)據(jù)高寬都減半,所以使用這種方法需要確保模型輸入數(shù)據(jù)高寬是$2^4$的倍數(shù))]

          輸入與輸出

          U-Net論文中的數(shù)據(jù)是單通道的灰度圖,所以輸入數(shù)據(jù)的通道數(shù)為1(如果是RGB圖像即為3)輸入后經(jīng)過第一個卷積操作直接轉(zhuǎn)換成了64通道的特征圖,與后面的通道數(shù)翻倍增加不同。

          最后得到的輸出會經(jīng)過1×1的卷積操作將64通道的特征圖映射成所需的類別數(shù)。

          代碼復(fù)現(xiàn) 如圖所示,U-Net主要由連續(xù)的兩個conv 3×3 + ReLu,copy and crop,max pool下采樣,up-conv轉(zhuǎn)置卷積上采樣和conv 1×1組成。

          下面我們將分別實現(xiàn)連續(xù)的兩個conv3×3+ReLu,下采樣和上采樣。

          首先,我們導(dǎo)入必要的庫

          import torch
          import torch.nn as nn
          import torch.nn.functional as F
          import torchvision

          Conv Block

          這里實現(xiàn)連續(xù)的兩個conv3×3+ReLu

          class conv_block(nn.Module):
              def __init__(self, in_channels, out_channels, padding=0):
                  super().__init__()
                  self.conv = nn.Sequential(
                      nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=1,padding=padding),
                      nn.BatchNorm2d(out_channels),
                      nn.ReLU(inplace=True),
                      nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1,padding=padding),
                      nn.BatchNorm2d(out_channels),
                      nn.ReLU(inplace=True)
                  )

              def forward(self,x):
                  x = self.conv(x)
                  return x

          DownSample

          這里的下采樣包括max pool下采樣和連續(xù)的兩個conv3×3+ReLu。

          class DownSample(nn.Module):
              def __init__(self, in_channels, out_channels, padding=0):
                  super().__init__()
                  self.maxpool_conv = nn.Sequential(
                      nn.MaxPool2d(kernel_size=2, stride=2),
                      conv_block(in_channels, out_channels, padding=padding)
                  )

              def forward(self, x):
                  return self.maxpool_conv(x)

          UpSample

          這里的上采樣包括轉(zhuǎn)置卷積上采樣,并與左側(cè)對應(yīng)編碼器的特征圖concatenation。之后進行連續(xù)的兩個conv3×3+ReLu。

          class UpSample(nn.Module):
              def __init__(self, in_channels, out_channels, concat=0):
                  super().__init__()
                  """
                  concat=0 -> do center crop
                  concat=1 -> padding decoder feature map
                  concat=2 -> padding=1 in conv_block
                  "
          ""
                  self.concat = concat
                  if self.concat not in [0, 1, 2]:
                      raise Exception('concat not in list of [0, 1, 2]')
                  if self.concat == 2:
                      padding = 1

                  self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
                  self.conv = conv_block(in_channels, out_channels, padding=padding)

              def forward(self, x, x_copy):
                  x = self.up(x)

                  if self.concat == 0:
                      B, C, H, W = x.shape
                      x_copy = torchvision.transforms.CenterCrop([H, W])(x_copy)
                      
                  elif self.concat == 1:
                      diffY = x_copy.size()[2] - x.size()[2]
                      diffX = x_copy.size()[3] - x.size()[3]
                      x = F.pad(x, [
                          diffX // 2, diffX - diffX // 2, 
                          diffY // 2, diffY - diffY // 2
                          ])

                  x = torch.cat([x_copy, x], dim=1)
                  return self.conv(x)

          U-Net模型

          前面通過PyTorch構(gòu)造了U-Net模型編碼器與解碼器的各個模塊,現(xiàn)在只需要將其拼接在一起就可以組成U-Net模型了。


          class UNet(nn.Module):
              def __init__(self, n_channels, n_classes, concat=0):
                  super().__init__()
                  self.n_channels = n_channels
                  self.n_classes = n_classes
                  self.concat = concat

                  if concat == 2:
                      padding = 1
                  else:
                      padding = 0

                  expansion = 2
                  inplanes = 64
                  chns = [inplanes, inplanes*expansion, inplanes*expansion**2, inplanes*expansion**3, inplanes*expansion**4]

                  self.inc = conv_block(n_channels, chns[0], padding)
                  self.down1 = DownSample(chns[0], chns[1], padding)
                  self.down2 = DownSample(chns[1], chns[2], padding)
                  self.down3 = DownSample(chns[2], chns[3], padding)
                  self.down4 = DownSample(chns[3], chns[4], padding)

                  self.up1 = UpSample(chns[-1], chns[-2], concat)
                  self.up2 = UpSample(chns[-2], chns[-3], concat)
                  self.up3 = UpSample(chns[-3], chns[-4], concat)
                  self.up4 = UpSample(chns[-4], chns[-5], concat)
                  self.outc = nn.Conv2d(chns[-5], n_classes, kernel_size=1)

              def forward(self, x):
                  e1 = self.inc(x)
                  e2 = self.down1(e1)
                  e3 = self.down2(e2)
                  e4 = self.down3(e3)
                  e5 = self.down4(e4)
                  
                  x = self.up1(e5, e4)
                  x = self.up2(x, e3)
                  x = self.up3(x, e2)
                  x = self.up4(x, e1)
                  logits = self.outc(x)

                  return logits

          以上就是U-Net模型PyTorch的實現(xiàn)。

          Reference

          1. Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.APA

          2. Milesial. “U-Net: Semantic segmentation with PyTorch” https://github.com/milesial/Pytorch-UNet

          推薦閱讀

          全網(wǎng)最全速查表:Python 機器學(xué)習(xí)
          搭建完美的Python 機器學(xué)習(xí)開發(fā)環(huán)境
          訓(xùn)練集,驗證集,測試集,交叉驗證

          瀏覽 65
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲色婷婷国产无码av | 午夜1级操逼视频 | 天天爽天天撸 | 伊人色图吧 | a级视频在线观看 |