PyTorch模型訓(xùn)練特征圖可視化(TensorboardX)

極市導(dǎo)讀
?本文介紹了Loss可視化、輸入圖片和標(biāo)簽的可視化、單通道特征圖的可視化、多通道特征圖的可視化,并分析了make_grid()通道數(shù)的問(wèn)題。>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺(jué)的最前沿
0、前言
本文所有代碼解讀均基于PyTorch 1.0,Python3; 本文為原創(chuàng)文章,初次完成于2019.03,最后更新于2019.09;
最近復(fù)現(xiàn)的一篇論文一直都難以work,上了特征圖可視化后一下子就找到了問(wèn)題所在,所以今天想梳理一下PyTorch里面的特征圖可視化。
大家都知道Tensorflow有一款非常優(yōu)秀的可視化工具Tensorboard,而PyTorch自身沒(méi)有可視化功能,但是我們可以尋找替代品,即TensorBoardX。安裝過(guò)程不多介紹,詳見(jiàn)下面的參考鏈接,里面相應(yīng)有比較豐富的介紹。
tensor-yu/PyTorch_Tutorial
https://github.com/tensor-yu/PyTorch_Tutorial
1、Loss可視化
最常見(jiàn)的可視化就是loss曲線作圖,這個(gè)實(shí)現(xiàn)相對(duì)比較簡(jiǎn)單,不多做介紹了
tb_logger.add_scalar('loss_train', loss, curr_step)

2、輸入圖片和標(biāo)簽的可視化
模型不work,第一個(gè)應(yīng)該檢查的就是輸入輸出有沒(méi)有沒(méi)給對(duì),因此我們需要將傳遞給model的 input 和 label 可視化一下。
傳遞給網(wǎng)絡(luò)的圖片格式往往是 [B,C,H,W] ,范圍[0, 1],數(shù)據(jù)類(lèi)型tensor.FloatTensor,但是add_iamge() 能夠接受的格式是[C,H,W], 范圍[0,1],數(shù)據(jù)類(lèi)型tensor.FloatTensor。
一個(gè)是三維的,一個(gè)是四維的,這很好解決,我們把每個(gè)batch的第一張圖拿出來(lái)就行了:input[0]的形狀就是[C,H,W],符合輸入要求。
tb_logger.add_image('image', input[0], curr_step)
tensoroard里面如果出現(xiàn)了貓咪本尊的正確可視化結(jié)果,就說(shuō)明輸入圖片沒(méi)問(wèn)題

如果你在Dataloader里對(duì)輸入圖片做了Normalize,顯示會(huì)出現(xiàn)問(wèn)題,出現(xiàn)如下所示的亂碼,此時(shí)需要通過(guò)make_grid()函數(shù)做一些處理,函數(shù)用法具體可見(jiàn)后面的描述。

3、單通道特征圖的可視化
有時(shí)候我們需要把網(wǎng)絡(luò)內(nèi)部分節(jié)點(diǎn)的特征圖可視化出來(lái),這時(shí)候上面的方法就不能用了,因?yàn)樘卣鲌D的每個(gè)像素點(diǎn)上的數(shù)值范圍不是[0,1],而是可正可負(fù),可大可小,因此需要做一些特殊處理。這里就要用到 torchvision.utils.make_grid( )函數(shù),把輸入的特征圖做一個(gè)歸一化,把參數(shù)normalize設(shè)置為T(mén)rue即可,它能幫我們把數(shù)據(jù)的輸入范圍調(diào)整至[0, 1]之間
def make_grid(tensor, nrow=8, padding=2,normalize=False, range=None, scale_each=False, pad_value=0):
更多其他參數(shù)的用法參見(jiàn)源碼:
https://github.com/pytorch/vision/blob/master/torchvision/utils.py
這里我把三個(gè)中間特征圖拼在了一塊顯示:
from?torchvision.utils?import?make_gridtb_logger.add_image('feature_map', make_grid([feature_map1, feature_map2, fetare_map3], padding=20, normalize=True, scale_each=True, pad_value=1), curr_step)
需要注意的是:
make_grid() 輸入的是Tensor,而不是numpy.ndarray torchvision.utils.make_grid() 將一組圖片繪制到一個(gè)窗口,其本質(zhì)是將一組圖片拼接成一張圖片
4、多通道特征圖的可視化
多通道的特征圖的顯示和上面的單通道存在一些區(qū)別,假設(shè)我們從batsh_size=16,channel=20的一個(gè)tensor想取出一個(gè)多通道特征圖可視化,只需要如下操作
feature_map[0].deatch().cpu().unsqueeze(dim=1)
這樣就能把一個(gè)形狀為 [16,20, H, W] 的tesnor取出并轉(zhuǎn)換為 [20, 1, H ,W] 的形狀,這與為什么要這么轉(zhuǎn)換,詳解第五章節(jié)。
完整代碼和如下:
tb_logger.add_image('channels',?make_grid(feature_map[0].detach().cpu().unsqueeze(dim=1),?nrow=5,?padding=20,?normalize=False,?pad_value=1),?curr_step)5、make_grid()通道數(shù)的問(wèn)題
測(cè)試發(fā)現(xiàn),輸入 [1, H, W] 的數(shù)據(jù)沒(méi)問(wèn)題,但是[20, H, W] 就不行,[20, 1, H, W] 就可以
這是因?yàn)閱瓮ǖ?[1, H, W] 不存在歧義,但是多通道就不行,比如說(shuō)[3, H, W] 到底是一張三通道的圖還是三張單通道的圖,存在歧義
因此想要顯示一張多通道的特征圖可以這么轉(zhuǎn)換:[1, C, H, W] --> [C, 1, H, W],顯性地指明tensor形狀。
6、總結(jié)
特征圖可視化在模型復(fù)現(xiàn)過(guò)程中十分有用,可用于定位模型錯(cuò)誤所在,但是在tensor的數(shù)據(jù)格式、尺寸、維度上存在許多講究,使用時(shí)需要額外小心。
附錄一:相關(guān)函數(shù)源碼
其實(shí)想要熟練使用,還是多看看make_grid的源碼和樣例吧:
https://github.com/pytorch/vision/blob/master/torchvision/utils.py
https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91
附錄二:網(wǎng)絡(luò)結(jié)構(gòu)可視化工具
Caffe網(wǎng)絡(luò)可視化工具
Netscope
https://ethereon.github.io/netscope/%23/editor
PyTorch等網(wǎng)絡(luò)的可視化工具
https://github.com/waleedka/hiddenlayer
大概是這么個(gè)效果,相對(duì)清晰一些

推薦閱讀
你真的理解Faster RCNN嗎?捋一捋Pytorch官方Faster RCNN代碼
PyTorch Lightning 1.0 正式發(fā)布!從0到1,有這9大特點(diǎn)
在PyTorch中進(jìn)行雙線性采樣:原理和代碼詳解

