如何從頭訓(xùn)練一個(gè)一鍵摳圖模型
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
摳圖是圖像編輯的基礎(chǔ)功能之一,在摳圖的基礎(chǔ)上可以發(fā)展出很多有意思的玩法和特效。比如一鍵更換背景、一鍵任務(wù)卡通化、一鍵人物素描化等。正是因?yàn)檫@些有意思的玩法,CVPy網(wǎng)站上的一鍵摳圖功能上線以來,從贊數(shù)來看,人氣之高已經(jīng)遙遙領(lǐng)先于CV派內(nèi)其他高手,可見此模型的受歡迎程度。
筆者最近也是對(duì)此模型背后的U-2-Net網(wǎng)絡(luò)很感興趣,收集數(shù)據(jù)訓(xùn)練了人臉?biāo)孛杌P?,盡管受限于數(shù)據(jù)集,只能在人臉圖片上轉(zhuǎn)換成功,但自己仍然玩的不亦樂乎。不僅樂于玩模型的有意思的效果,更樂在訓(xùn)練模型過程中,以及遇到問題解決問題過程中,對(duì)模型理解的不斷加深。
筆者最近對(duì)一鍵扣圖模型從頭訓(xùn)練了一遍,并在訓(xùn)練過程中持續(xù)測(cè)試了不同階段模型的表現(xiàn),看著模型一點(diǎn)點(diǎn)的收斂,摳圖效果慢慢變好。

此處記錄下訓(xùn)練過程以及訓(xùn)練的效果。也可以對(duì)后來者有一個(gè)參考。
提前說一聲,模型訓(xùn)練很耗時(shí)!
2.1 代碼
代碼是U-2-Net的開源代碼,可以從Github下載:https://github.com/NathanUA/U-2-Net。這個(gè)模型本來是做顯著性檢測(cè)的,但是當(dāng)成一鍵扣圖模型也很好玩。
需要注意的地方是,如果是安裝的最新的Pytorch,獲取loss值的時(shí)候,需要將loss.data[0] 修改為loss.data.item()。
筆者在訓(xùn)練過程中曾嘗試修改Loss函數(shù)為其他的,比如改成BCE和SSIM的加權(quán)(參考U-2-Net作者去年的文章BASNet),未見明顯提升。也曾修改輸出通道訓(xùn)練其他模型,暫無好玩的結(jié)果,就當(dāng)是積累經(jīng)驗(yàn)了。
2.2 數(shù)據(jù)
數(shù)據(jù)集我們就用論文中提到的DUTS數(shù)據(jù)集,已經(jīng)分好了訓(xùn)練集和測(cè)試集。網(wǎng)上搜一下直接下載即可。
當(dāng)然,也可以用自己的數(shù)據(jù)集,按照DUTS的格式重新組織下數(shù)據(jù)集即可。
然后在訓(xùn)練代碼里面把數(shù)據(jù)讀取部分的路徑更換為自己準(zhǔn)備的數(shù)據(jù)的路徑。
2.3 機(jī)器
群里一土豪贊助了一臺(tái)4卡的機(jī)器,4塊 RTX 5000,每張卡16G內(nèi)存。跑起來確實(shí)比單卡爽多了。
然后基于Anaconda安裝訓(xùn)練所需的Python環(huán)境,創(chuàng)建虛擬環(huán)境,安裝pytorch, torchvision, skimage, opencv等等,直接pip install或者conda install即可。不多說。
另外多卡的話,代碼還需要有一些細(xì)微的改動(dòng),在構(gòu)建模型之后,將代碼:
????if?torch.cuda.is_available():
????????net.cuda()
修改為
????if?torch.cuda.is_available():
????????net.cuda()
????????net?=?nn.DataParallel(net)
3.1 模型訓(xùn)練
以上代碼、數(shù)據(jù)、機(jī)器和運(yùn)行環(huán)境都已經(jīng)準(zhǔn)備好之后,就可以開始訓(xùn)練了。多卡訓(xùn)練的命令大概長(zhǎng)下面這樣:
CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python3 -u u2net_train.py > log_train.log &
然后tail命令查看日志文件log_train.log,如果看到下面這樣的輸出,說明跑起來了:

再用命令watch -n 1 nvidia-smi查看GPU的情況,可以看到四張卡都被充分利用起來了。

模型訓(xùn)練將近一周,達(dá)到了接近論文的效果。
另外,由于中間保存過多,為了節(jié)省空間,筆者刪掉了太多前期模型,以下展示的前期效果是另外一次訓(xùn)練的前期模型的效果。
3.2 各階段模型測(cè)試
筆者微調(diào)測(cè)試代碼結(jié)構(gòu),把測(cè)試轉(zhuǎn)移到了Jupyter里,這樣畫圖看效果更加直觀。
筆者測(cè)試模型的時(shí)候,每張圖都會(huì)畫出三個(gè)圖:黑色背景的摳圖結(jié)果、模型輸出的Mask或稱Alpha,原圖。這樣對(duì)比來看結(jié)果一目了然。這里每張圖都展示了四個(gè)階段模型的測(cè)試效果。顯然,以下圖片都不在訓(xùn)練集里面。
四個(gè)階段對(duì)比著看,能更加直觀地感受到模型的收斂過程。
從以下四個(gè)階段的對(duì)比圖可以看出,隨著訓(xùn)練的進(jìn)行
前景逐漸變亮,背景逐漸變暗,即前景收斂于1,背景收斂于0。前兩幅圖之間的對(duì)比最為明顯。 前景的輪廓從模糊到清晰細(xì)銳,輪廓處的不確定區(qū)域,越來越少。 注意指縫和發(fā)梢部分的Mask的變化,細(xì)節(jié)越來越清晰。




下面這幅圖請(qǐng)注意這個(gè)卡通人物背后背的那個(gè)是蝸牛還是啥的東西的輪廓的細(xì)化過程。以及其嘴角的一撮小胡子。這個(gè)圖美中不足的是兩腳之間的背景沒有被識(shí)別出來。




下面這張圖值得關(guān)注的應(yīng)該就是其發(fā)梢的摳圖細(xì)化過程、腰部的亮度變化過程。還有就是其手中的衣服了,對(duì)于要不要把一副也給摳出來,模型看起來也很糾結(jié)啊。




這個(gè)圖最引人矚目的莫過于這位美女在風(fēng)中凌亂的發(fā)絲,這不是難為模型嗎?說實(shí)話,如果不是看到Mask里胸前多出的東西,我都沒注意到這個(gè)東西,衣服的胸結(jié)還是啥。




這大概就是訓(xùn)練了五天左右的效果,模型仍然在緩慢的收斂中,故事仍然在繼續(xù)......
直到我實(shí)在是受不了越來越慢的收斂速度,等不及訓(xùn)練其他魔改的模型,終止了訓(xùn)練任務(wù)......
本著報(bào)喜不報(bào)憂的原則,下面再放幾張測(cè)試效果還不錯(cuò)的圖片,效果不怎么樣的就不拿出來獻(xiàn)丑了。






上面的摳圖效果還是有待提高,比如頭發(fā)等邊緣處,還是可見部分背景未分離。前幾天剛轉(zhuǎn)發(fā)了動(dòng)物摳圖的新論文,邊緣和毛發(fā)的摳圖效果很贊。其單開一條支路專門做輪廓邊緣處的摳圖的思路值得參考。
不過,作者暫時(shí)開源了測(cè)試代碼,并沒有訓(xùn)練代碼。我昨晚肝到十二點(diǎn)半,終于根據(jù)論文實(shí)現(xiàn)了一版訓(xùn)練代碼,但是貌似收斂的更慢,這個(gè)優(yōu)化還是慢慢來吧。就這訓(xùn)練速度,想快也快不起來啊。反正就是玩,好玩就行。
交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請(qǐng)按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~
