地址:https://zhuanlan.zhihu.com/p/490742250編輯:人工智能前沿講習(xí)
昨晚在看對比學(xué)習(xí)算法 MoCo[1] 的源代碼時,中間有一個涉及Pytorch中CrossEntropyLoss的計算問題困擾了我較長時間,因此記錄下來加深一下印象:MoCo 中 contrastive loss 的組成是由query正樣本對相似度(代碼圖中的 l_pos),以及query與一系列queue中的負(fù)樣本相似度(代碼圖中的 l_neg)共同構(gòu)成的:
在經(jīng)過拼接后,logits 為一個N*(1+K) 的矩陣,矩陣的第一列為正樣本對間的相似度,而其他剩余K列為正負(fù)樣本對之間的相似度,因此我會直觀地認(rèn)為,在對應(yīng)到標(biāo)簽計算CrossEntropyLoss時,第一列的標(biāo)簽應(yīng)該為1,而其余K列的標(biāo)簽都為0。但在算法實(shí)現(xiàn)的時候,可以明顯地看到此處的labels為一個值全為0的張量,這是為什么?這個labels張量不應(yīng)該是第一個元素為1,其他元素都為0嗎?
我在反復(fù)品讀GitHub issue中其他人關(guān)于這個問題的解答(鏈接如下),以及pytorch文檔中CrossEntropyLoss的計算方法后,總算意識到自己之前理解的錯誤所在:labels中的0元素并不是指代正負(fù)樣本對,而是告訴CrossEntropyLoss輸入第一維的標(biāo)簽為1(ground truth),也就是第0維指代的是正樣本對。https://github.com/facebookresearch/moco/issues/24#issuecomment-625508418上面這句話理解起來可能仍然有點(diǎn)抽象,因此舉個簡單例子說明一下:logits = [[0.5, 0.2, 0.2, 0.1]矩陣的行表示不同的數(shù)據(jù)樣本;第一列是正樣本對間的相似度,其他列表示正樣本與負(fù)樣本之間的相似度。注意這里labels的長度,是與logits的第一維也就是樣本數(shù)量是一致的。labels中的元素實(shí)際上意味著在進(jìn)行CrossEntropyLoss計算時,標(biāo)簽為1的ground truth的索引是多少,以logits中第一個樣本為例的話,此時0號元素為ground truth,即數(shù)值0.5對應(yīng)的標(biāo)簽值為1,其他數(shù)值對應(yīng)的標(biāo)簽值為0,在進(jìn)行CrossEntropyLoss計算時,會由 logits [0.5, 0.2, 0.2, 0.1] 與 label [1, 0, 0, 0] 來計算loss的數(shù)值。之前會有錯誤理解的原因在于對Pytorch中CrossEntropyLoss的計算方法理解還不夠深,在弄明白它的計算方法后自然就不會產(chǎn)生這樣的疑問啦。[1] He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (2020). Momentum contrast for unsupervised visual representation learning. InProceedings of the IEEE/CVF conference on computer vision and pattern recognition(pp. 9729-9738).
猜您喜歡:
拆解組新的GAN:解耦表征MixNMatch
StarGAN第2版:多域多樣性圖像生成
附下載 |?《可解釋的機(jī)器學(xué)習(xí)》中文版
附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實(shí)戰(zhàn)》
附下載 |《計算機(jī)視覺中的數(shù)學(xué)方法》分享
《基于深度學(xué)習(xí)的表面缺陷檢測方法綜述》
《零樣本圖像分類綜述: 十年進(jìn)展》
《基于深度神經(jīng)網(wǎng)絡(luò)的少樣本學(xué)習(xí)綜述》