Three Excellent Semantic Segmentation Frameworks, Implemented in PyTorch

[Overview]
Building on the Dive into Deep Learning (d2l) project, this article walks through semantic segmentation of natural images with an FCN and also runs experiments with U-Net and DeepLab. The code and pretrained models are open-sourced on GitHub and Google Drive; the training and prediction scripts are already packaged, so readers can download and use them directly.

1 Preface
The models were trained on Colab Pro, so you can simply download them and run prediction. Download the VOC dataset and place the JPEGImages and SegmentationClass folders under the data folder. Then switch to the project directory in a terminal and run python train.py -h to see the training options:
(torch) qust116-jq@qustx-X299-WU8:~/語義分割$ python train.py -h
usage: train.py [-h] [-m {Unet,FCN,Deeplab}] [-g GPU]

choose the model

optional arguments:
  -h, --help            show this help message and exit
  -m {Unet,FCN,Deeplab}, --model {Unet,FCN,Deeplab}
                        the model name
  -g GPU, --gpu GPU     the GPU to use

Training is then launched with, for example, python train.py -m Unet -g 0. For prediction, the model has to be changed manually in predict.py.
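For reference, here is a minimal sketch (not the repository's actual train.py) of the argument parsing that would produce the help text above:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='choose the model')
    parser.add_argument('-m', '--model', choices=['Unet', 'FCN', 'Deeplab'],
                        default='Unet', help='the model name')
    parser.add_argument('-g', '--gpu', default='0', help='the GPU to use')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    # e.g. "Unet 0" for: python train.py -m Unet -g 0
    print(args.model, args.gpu)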
The explanation that follows, up to the last section, is based on d2l (Dive into Deep Learning).

2 Dataset
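The experiments use the Pascal VOC2012 segmentation set. As a minimal sketch, the data can also be loaded with d2l's built-in helper; note that this helper downloads VOC2012 on its own rather than using the repository's data folder, and the batch size and crop size here are illustrative:

from d2l import torch as d2l

# Pascal VOC2012 semantic segmentation, randomly cropped to 320x480 patches.
batch_size, crop_size = 32, (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

for imgs, labels in train_iter:
    print(imgs.shape)    # torch.Size([32, 3, 320, 480])
    print(labels.shape)  # torch.Size([32, 320, 480]), one class index per pixel
    break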



全卷積網(wǎng)絡(luò)將中間層特征圖的高和寬變換回輸入圖像的尺寸:這是通過中引入的轉(zhuǎn)置卷積(transposed convolution)層實現(xiàn)的。因此,輸出的類別預(yù)測與輸入圖像在像素級別上具有一一對應(yīng)關(guān)系:給定空間維上的位置,通道維的輸出即該位置對應(yīng)像素的類別預(yù)測。%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
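To illustrate the upsampling claim, a quick sketch (not part of the original walkthrough, using the imports above): a transposed convolution with kernel size 64, stride 32 and padding 16 maps a 10×15 feature map back to 320×480, the same spatial size as the input image that produced it.

# Output size: (in - 1) * stride - 2 * padding + kernel_size
# (10 - 1) * 32 - 2 * 16 + 64 = 320, and (15 - 1) * 32 - 2 * 16 + 64 = 480.
upsample = nn.ConvTranspose2d(21, 21, kernel_size=64, padding=16, stride=32)
feat = torch.rand(1, 21, 10, 15)
print(upsample(feat).shape)  # torch.Size([1, 21, 320, 480])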
3.1 Network structure

We use a ResNet-18 model pretrained on ImageNet to extract image features and refer to it as pretrained_net. The last few layers of this model are a global average pooling layer and a fully connected layer, neither of which is needed in a fully convolutional network.

pretrained_net = torchvision.models.resnet18(pretrained=True)
list(pretrained_net.children())[-3:]
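(For reference, the three children printed here are the last residual stage layer4, an AdaptiveAvgPool2d(output_size=(1, 1)) layer, and a Linear(in_features=512, out_features=1000) classifier head; the latter two are exactly what the fully convolutional network drops.)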

net。它復(fù)制了Resnet-18中大部分的預(yù)訓(xùn)練層,但除去最終的全局平均匯聚層和最接近輸出的全連接層。net = nn.Sequential(*list(pretrained_net.children())[:-2])
The forward pass of net reduces the height and width of the input to 1/32 of their original values, i.e. to 10 and 15 for a 320×480 input.

X = torch.rand(size=(1, 3, 320, 480))
net(X).shape

A 1×1 convolution then maps the 512 output channels to the 21 classes of Pascal VOC2012, and a transposed convolution upsamples the feature maps by a factor of 32, back to the height and width of the input image.

num_classes = 21
net.add_module('final_conv', nn.Conv2d(512, num_classes, kernel_size=1))
net.add_module('transpose_conv', nn.ConvTranspose2d(num_classes, num_classes,
                                                    kernel_size=64, padding=16, stride=32))
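As a quick sanity check (not in the original text), the modified network now maps the 320×480 input to per-pixel class scores of the same spatial size:

# X is the (1, 3, 320, 480) random input defined above.
print(net(X).shape)  # torch.Size([1, 21, 320, 480]): one score per class per pixel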
3.2 Initializing the transposed convolution layer
The transposed convolution layer is initialized with the weights of bilinear interpolation upsampling, constructed by the helper below.

def bilinear_kernel(in_channels, out_channels, kernel_size):
    # Build a weight tensor that makes a transposed convolution perform
    # bilinear-interpolation upsampling, channel by channel.
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = (torch.arange(kernel_size).reshape(-1, 1),
          torch.arange(kernel_size).reshape(1, -1))
    filt = (1 - torch.abs(og[0] - center) / factor) * \
           (1 - torch.abs(og[1] - center) / factor)
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

# A small transposed convolution that doubles height and width,
# initialized with the bilinear kernel.
conv_trans = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2,
                                bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));

# Initialize the transposed convolution in net the same way.
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);
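A quick, illustrative check (not in the original code): the bilinear-initialized conv_trans doubles the spatial resolution of whatever it is given.

img = torch.rand(1, 3, 160, 240)
out = conv_trans(img)
print(out.shape)  # torch.Size([1, 3, 320, 480]): height and width doubled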
3.3 Training
The loss is a per-pixel cross-entropy, averaged over both spatial dimensions so that each image contributes one loss value.

def loss(inputs, targets):
    # inputs: (batch, num_classes, H, W) scores; targets: (batch, H, W) class indices.
    return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)

# train_iter and test_iter are the Pascal VOC loaders, e.g.
# train_iter, test_iter = d2l.load_data_voc(batch_size=32, crop_size=(320, 480))
num_epochs, lr, wd, devices = 5, 0.001, 1e-3, d2l.try_all_gpus()
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
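After training, per-pixel class predictions are obtained by taking the argmax over the channel dimension. A minimal sketch, assuming X is a normalized image tensor of shape (1, 3, H, W) on the same device as net (in practice the packaged predict.py script handles this):

def predict(net, X):
    # Returns a (1, H, W) tensor of per-pixel class indices.
    net.eval()
    with torch.no_grad():
        return net(X).argmax(dim=1)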
4 Open-source code and dataset




!python3 train.py -m Unet -g 0





5 Summary
6 References