PyTorch 提取中間層特征?
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
來源:機(jī)器學(xué)習(xí)算法與自然語言處理
編輯:憶臻
https://www.zhihu.com/question/68384370
本文僅作為學(xué)術(shù)分享,如果侵權(quán),會刪文處理
PyTorch提取中間層特征?
作者:澀醉
https://www.zhihu.com/question/68384370/answer/751212803
通過pytorch的hook機(jī)制簡單實(shí)現(xiàn)了一下,只輸出conv層的特征圖。
import torchfrom torchvision.models import resnet18import torch.nn as nnfrom torchvision import transformsimport matplotlib.pyplot as pltdef viz(module, input):x = input[0][0]#最多顯示4張圖min_num = np.minimum(4, x.size()[0])for i in range(min_num):plt.subplot(1, 4, i+1)plt.imshow(x[i])plt.show()import cv2import numpy as npdef main():t = transforms.Compose([transforms.ToPILImage(),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = resnet18(pretrained=True).to(device)for name, m in model.named_modules():# if not isinstance(m, torch.nn.ModuleList) and# not isinstance(m, torch.nn.Sequential) and# type(m) in torch.nn.__dict__.values():# 這里只對卷積層的feature map進(jìn)行顯示if isinstance(m, torch.nn.Conv2d):m.register_forward_pre_hook(viz)img = cv2.imread('/Users/edgar/Desktop/cat.jpeg')img = t(img).unsqueeze(0).to(device)with torch.no_grad():model(img)if __name__ == '__main__':main()
打印的特征圖大概是這個(gè)樣子,取了第一層以及第四層的特征圖。


作者:袁坤
https://www.zhihu.com/question/68384370/answer/419741762
建議使用hook,在不改變網(wǎng)絡(luò)forward函數(shù)的基礎(chǔ)上提取所需的特征或者梯度,在調(diào)用階段對module使用即可獲得所需梯度或者特征。
inter_feature = {}inter_gradient = {}def make_hook(name, flag):if flag == 'forward':def hook(m, input, output):inter_feature[name] = inputreturn hookelif flag == 'backward':def hook(m, input, output):inter_gradient[name] = outputreturn hookelse:assert Falsem.register_forward_hook(make_hook(name, 'forward'))m.register_backward_hook(make_hook(name, 'backward'))
在前向計(jì)算和反向計(jì)算的時(shí)候即可達(dá)到類似鉤子的作用,中間變量已經(jīng)被放置于inter_feature 和 inter_gradient。
output = model(input) # achieve intermediate featureloss = criterion(output, target)loss.backward() # achieve backward intermediate gradients
最后可根據(jù)需求是否釋放hook。
hook.remove()作者:羅一成
https://www.zhihu.com/question/68384370/answer/263120790
提取中間特征是指把中間的weights給提出來嗎?這樣不是直接訪問那個(gè)矩陣不就好了嗎? pytorch在存參數(shù)的時(shí)候, 其實(shí)就是給所有的weights bias之類的起個(gè)名字然后存在了一個(gè)字典里面. 不然你看看state_dict.keys(), 找到相對應(yīng)的key拿出來就好了.
然后你說的慎用也是一個(gè)很奇怪的問題啊..
就算用modules下面的class, 你存模型的時(shí)候因?yàn)槟愕腶ctivation function上面本身沒有參數(shù), 所以也不會被存進(jìn)去. 不然你可以試試在Sequential里面把relu換成sigmoid, 你還是可以把之前存的state_dict給load回去.
不能說是慎用functional吧, 我覺得其他的設(shè)置是應(yīng)該分開也存一份的(假設(shè)你把這些當(dāng)做超參的話)
利益相關(guān): 給pytorch提過PR
好消息!
小白學(xué)視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

