深入理解Pytroch中的hook機制
【GiantPandaCV導語】Pytorch 中的 hook 機制可以很方便的讓用戶往計算圖中注入控制代碼,這樣就可以通過自定義各種操作來修改計算圖中的張量。
點擊小程序觀看視頻(時長22分)
視頻太長不看版:
Pytorch 中的 hook 機制可以很方便的讓用戶往計算圖中注入控制代碼(注入的代碼也可以刪除),這樣用戶就可以通過自定義各種操作來修改計算圖中的張量。
Pytroch 中主要有兩種hook,分別是注冊在Tensor上的hook和注冊在Module上的 hook。
注冊在 Tensor 上的 hook,可以在反向回傳過程中對梯度作修改,分為兩種:
葉子節(jié)點上的hook
中間張量上的hook
在輸出梯度傳入 backward 函數(shù)計算輸入梯度之前,調用注冊的hook的函數(shù)對梯度做一些操作
會在 AccumulateGrad 之前對梯度做一些操作
注意:
最好不要在hook函數(shù)中對梯度做 inplace 修改,因為會直接修改該梯度張量,
如果該op有多個輸入,比如 add op,那么在反向階段,如果其中一個張量上注冊的hook函數(shù)對梯度做了inplace修改,那么就會有可能影響到另一個輸入張量的梯度。
注冊在 Module 上的 hook,則可以在前后過程中對張量作修改,主要有三種:
在module的前向被調用之前調用的hook函數(shù)
在module的前向被調用之后調用的hook函數(shù)
后向過程會調用的hook
對Module的輸入張量做一些操作
對Module的輸入和輸出張量做一些操作
可以打印module輸入張量的梯度,但是目前還有bug,建議不要用。
github上相關的討論:https://github.com/pytorch/pytorch/issues/598
為了感謝讀者的長期支持,今天我們將送出三本由 機械工業(yè)出版社 提供的:《分布式人工智能:基于TensorFlow、RTOS與群體智能體系》 。點擊下方抽獎助手參與抽獎。沒抽到并且對本書有興趣的也可以使用下方鏈接進行購買。
《分布式人工智能:基于TensorFlow、RTOS與群體智能體系》抽獎鏈接

