TransUNet:基于 Transformer 和 CNN 的混合編碼網(wǎng)絡(luò)
Visual Transformer
Author:louwill
Machine Learning Lab
在深度學(xué)習(xí)醫(yī)學(xué)圖像分割領(lǐng)域,UNet結(jié)構(gòu)一直以來都牢牢占據(jù)著主導(dǎo)地位。自從2015年提出U形結(jié)構(gòu)以來,后續(xù)在UNet基礎(chǔ)上做出的魔改網(wǎng)絡(luò)不可計數(shù)。Tranformer結(jié)構(gòu)逐漸開始用于視覺領(lǐng)域之后,基于UNet和Tranformer結(jié)合的相關(guān)結(jié)構(gòu)和研究逐漸興起。
UNet用了這么多年,效果好是毋庸置疑的。但硬要是找一些缺點,也不是找不到。由于CNN的平移不變性和捕捉長期依賴能力的不足,UNet在一定程度上依然有較大的提升空間。而Tranformer正好以捕捉序列之間的長期依賴而見長,將Tranformer結(jié)構(gòu)融入到以CNN為主體的UNet中,能否進(jìn)一步發(fā)揮UNet的威力呢?
答案是肯定的。今天我們要介紹的網(wǎng)絡(luò)叫做TransUNet,正是一種充分結(jié)合UNet和Tranformer這兩種結(jié)構(gòu)的醫(yī)學(xué)圖像分割模型。提出TransUNet的論文為TransUNet:Transformers make strong encoders for medical image segmentation,發(fā)表于2021年2月,由約翰霍普金斯大學(xué)和電子科技大學(xué)等學(xué)校聯(lián)合提出。
TransUNet結(jié)構(gòu)
TransUNet完整結(jié)構(gòu)如圖1所示。

其中圖(a)是一層Transformer結(jié)構(gòu)示意圖,圖(b)是完整的TransUNet架構(gòu)。Transformer結(jié)構(gòu)不多說,對于圖像塊嵌入后,行常規(guī)的Layer Norm+MSA+MLP+殘差連接結(jié)構(gòu)處理。
我們重點看一下圖(b)的TransUNet完整架構(gòu)。完整的結(jié)構(gòu)仍然是U形的編解碼結(jié)構(gòu)。先來看編碼器部分,這也是TransUNet的關(guān)鍵部分。編碼器部分先是對輸入圖像做了三層卷積下采樣,對CNN得到的特征圖進(jìn)行圖像塊嵌入,同樣也是要加位置編碼,然后將塊嵌入后的一維向量輸入到12層Transformer結(jié)構(gòu)中。所以TransUNet編碼器的策略是CNN和Transformer混合構(gòu)建編碼器。這也是論文題目中make strong encoders的含義所在。
為什么要混合編碼呢?這也是為了各自利用Transformer和CNN的優(yōu)點來考慮的。Transformer更在注重全局信息,但容易忽略低分辨率下的圖像細(xì)節(jié),這對于解碼器恢復(fù)像素尺寸傷害比較大,會導(dǎo)致分割結(jié)果很粗糙。而CNN正好可以彌補Transformer的這個缺點。所以混合編碼在作者看來是大有裨益的。
然后是解碼器,解碼器比較簡單,就是常規(guī)的轉(zhuǎn)置卷積上采樣恢復(fù)圖像像素。同時從編碼器的CNN下采樣對應(yīng)過來同層分辨率的級聯(lián)。這些都屬于原始的UNet的固有操作。
TransUNet實驗
作者分別在Synapse多器官分割數(shù)據(jù)集和ACDC (自動化心臟診斷挑戰(zhàn)賽)上實驗了TransUNet的效果。具體地,對于混合編碼器,論文中使用ResNet-50和ViT分別作為CNN和Transformer的backbone,并且都經(jīng)過了ImageNet的預(yù)訓(xùn)練處理。
表1是TransUNet與VNet等模型的效果對比。

除了直接的模型精度比對之外,論文中還做了大量的消融實驗研究。TransUNet的消融實驗主要包括四個方面:1)跳躍連接數(shù),2)輸入圖像分辨率,3)序列長度和圖像分塊大小,4)模型大小。
下面我們僅從第一個和第三個方面來看一下TransUNet的消融實驗。第一個方面是嘗試不同的跳躍連接數(shù)來觀測模型分割的dice精度。對TransUNet網(wǎng)絡(luò)分別不做添加、添加1和3條跳躍連接后的實驗對比效果如圖2所示。

實驗結(jié)果也再一次強化了跳躍連接對于U形結(jié)構(gòu)分割網(wǎng)絡(luò)的強大效果。
消融實驗的第三個方面是關(guān)于圖像分塊大小和序列長度對于模型精度影響的。當(dāng)然這兩個說的是一回事,圖像分塊尺寸越小,圖像分塊數(shù)量就越多,也就是序列越長。一般認(rèn)為,patch size越小,Transformer序列越長,就越能編碼出更為復(fù)雜的依賴關(guān)系。論文中分別實驗了32、16和8三個尺寸的patch size,實驗效果如表2所示。

圖3顯示了TransUNet、R50-ViT-CUP、AttentionUNet和UNet四個模型在多器官分割數(shù)據(jù)上的可視化效果。從視覺效果上的對比來看,TransUNet無疑是跟Ground Truth最為接近的了。

TransUNet代碼實現(xiàn)
TransUNet完整代碼實現(xiàn)可參考論文作者提供的倉庫:
https://github.com/Beckschen/TransUNet
按照圖1的模型架構(gòu),TransUNet最后的搭建代碼如下所示。
class TransUNet(nn.Module):def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):super(VisionTransformer, self).__init__()self.num_classes = num_classesself.zero_head = zero_headself.classifier = config.classifierself.transformer = Transformer(config, img_size, vis)self.decoder = DecoderCup(config)self.segmentation_head = SegmentationHead(in_channels=config['decoder_channels'][-1],out_channels=config['n_classes'],kernel_size=3,)self.config = configdef forward(self, x):if x.size()[1] == 1:x = x.repeat(1,3,1,1)x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)x = self.decoder(x, features)logits = self.segmentation_head(x)return logits
總結(jié)
TransUNet是率先將Transformer結(jié)構(gòu)用于醫(yī)學(xué)圖像分割工作的研究。TransUNet將重視全局信息的Transformer結(jié)構(gòu)和底層圖像特征的CNN一起進(jìn)行混合編碼,能夠更大程度上提升UNet的分割效果。
參考資料:
Chen J, Lu Y, Yu Q, et al. Transunet: Transformers make strong encoders for medical image segmentation[J]. arXiv preprint arXiv:2102.04306, 2021.
往期精彩:
ViT:視覺Transformer backbone網(wǎng)絡(luò)ViT論文與代碼詳解
【原創(chuàng)首發(fā)】機器學(xué)習(xí)公式推導(dǎo)與代碼實現(xiàn)30講.pdf
【原創(chuàng)首發(fā)】深度學(xué)習(xí)語義分割理論與實戰(zhàn)指南.pdf
求個在看
