<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          損失函數(shù)技術(shù)總結(jié)及Pytorch使用示例

          共 1743字,需瀏覽 4分鐘

           ·

          2022-04-12 02:47

          ↑ 點擊藍字?關(guān)注極市平臺

          作者丨仿佛若有光
          來源丨CV技術(shù)指南
          編輯丨極市平臺

          極市導(dǎo)讀

          ?

          本文對損失函數(shù)的類別和應(yīng)用場景,常見的損失函數(shù),常見損失函數(shù)的表達式,特性,應(yīng)用場景和使用示例作了詳細的總結(jié)。?>>加入極市CV技術(shù)交流群,走在計算機視覺的最前沿

          前言

          一直想寫損失函數(shù)的技術(shù)總結(jié),但網(wǎng)上已經(jīng)有諸多關(guān)于損失函數(shù)綜述的文章或博客,考慮到這點就一直拖著沒寫,直到有一天,我將一個二分類項目修改為多分類,簡簡單單地修改了損失函數(shù),結(jié)果一直有問題,后來才發(fā)現(xiàn)是不同函數(shù)的標簽的設(shè)置方式并不相同。

          為了避免讀者也出現(xiàn)這樣的問題,本文中會給出每個損失函數(shù)的pytorch使用示例,這也是本文與其它相關(guān)綜述文章或博客的區(qū)別所在。希望讀者在閱讀本文時,重點關(guān)注一下每個損失函數(shù)的使用示例中的target的設(shè)置問題。

          本文對損失函數(shù)的類別和應(yīng)用場景,常見的損失函數(shù),常見損失函數(shù)的表達式,特性,應(yīng)用場景和使用示例作了詳細的總結(jié)。

          主要涉及到L1 loss、L2 loss、Negative Log-Likelihood loss、Cross-Entropy loss、Hinge Embedding loss、Margin Ranking Loss、Triplet Margin loss、KL Divergence.

          損失函數(shù)分類與應(yīng)用場景

          損失函數(shù)可以分為三類:回歸損失函數(shù)(Regression loss)、分類損失函數(shù)(Classification loss)和排序損失函數(shù)(Ranking loss)。

          應(yīng)用場景:回歸損失:用于預(yù)測連續(xù)的值。如預(yù)測房價、年齡等。分類損失:用于預(yù)測離散的值。如圖像分類,語義分割等。排序損失:用于預(yù)測輸入數(shù)據(jù)之間的相對距離。如行人重識別。

          L1 loss

          也稱Mean Absolute Error,簡稱MAE,計算實際值和預(yù)測值之間的絕對差之和的平均值。

          表達式如下:

          y表示標簽,pred表示預(yù)測值。

          應(yīng)用場合:回歸問題。

          根據(jù)損失函數(shù)的表達式很容易了解它的特性:當目標變量的分布具有異常值時,即與平均值相差很大的值,它被認為對異常值具有很好的魯棒行。

          使用示例:

          input = torch.randn(3, 5, requires_grad=True)
          target = torch.randn(3, 5)

          mae_loss = torch.nn.L1Loss()
          output = mae_loss(input, target)

          L2 loss

          也稱為Mean Squared Error,簡稱MSE,計算實際值和預(yù)測值之間的平方差的平均值。

          表達式如下:

          應(yīng)用場合:對大部分回歸問題,pytorch默認使用L2,即MSE。

          使用平方意味著當預(yù)測值離目標值更遠時在平方后具有更大的懲罰,預(yù)測值離目標值更近時在平方后懲罰更小,因此,當異常值與樣本平均值相差格外大時,模型會因為懲罰更大而開始偏離,相比之下,L1對異常值的魯棒性更好。

          使用示例:

          input = torch.randn(3, 5, requires_grad=True)
          target = torch.randn(3, 5)
          mse_loss = torch.nn.MSELoss()
          output = mse_loss(input, target)

          Negative Log-Likelihood

          簡稱NLL。表達式如下:

          應(yīng)用場景:多分類問題。

          注:NLL要求網(wǎng)絡(luò)最后一層使用softmax作為激活函數(shù)。通過softmax將輸出值映射為每個類別的概率值。

          根據(jù)表達式,它的特性是懲罰預(yù)測準確而預(yù)測概率不高的情況。

          NLL 使用負號,因為概率(或似然)在 0 和 1 之間變化,并且此范圍內(nèi)的值的對數(shù)為負。最后,損失值變?yōu)檎怠?/p>

          在 NLL 中,最小化損失函數(shù)有助于獲得更好的輸出。從近似最大似然估計 (MLE) 中檢索負對數(shù)似然。這意味著嘗試最大化模型的對數(shù)似然,從而最小化 NLL。

          使用示例

          # size of input (N x C) is = 3 x 5
          input = torch.randn(3, 5, requires_grad=True)
          # every element in target should have 0 <= value < C
          target = torch.tensor([1, 0, 4])

          m = nn.LogSoftmax(dim=1)
          nll_loss = torch.nn.NLLLoss()
          output = nll_loss(m(input), target)

          Cross-Entropy

          此損失函數(shù)計算提供的一組出現(xiàn)次數(shù)或隨機變量的兩個概率分布之間的差異。它用于計算預(yù)測值與實際值之間的平均差異的分數(shù)。

          表達式:

          應(yīng)用場景:二分類及多分類。

          特性:負對數(shù)似然損失不對預(yù)測置信度懲罰,與之不同的是,交叉熵懲罰不正確但可信的預(yù)測,以及正確但不太可信的預(yù)測。

          交叉熵函數(shù)有很多種變體,其中最常見的類型是Binary Cross-Entropy (BCE)。BCE Loss 主要用于二分類模型;也就是說,模型只有 2 個類。

          使用示例

          input = torch.randn(3, 5, requires_grad=True)
          target = torch.empty(3, dtype=torch.long).random_(5)

          cross_entropy_loss = torch.nn.CrossEntropyLoss()
          output = cross_entropy_loss(input, target)

          Hinge Embedding

          表達式:

          其中y為1或-1。

          應(yīng)用場景:

          分類問題,特別是在確定兩個輸入是否不同或相似時。

          學(xué)習非線性嵌入或半監(jiān)督學(xué)習任務(wù)。

          使用示例

          input = torch.randn(3, 5, requires_grad=True)
          target = torch.randn(3, 5)

          hinge_loss = torch.nn.HingeEmbeddingLoss()
          output = hinge_loss(input, target)

          Margin Ranking Loss

          Margin Ranking Loss 計算一個標準來預(yù)測輸入之間的相對距離。這與其他損失函數(shù)(如 MSE 或交叉熵)不同,后者學(xué)習直接從給定的輸入集進行預(yù)測。

          表達式:

          標簽張量 y(包含 1 或 -1)。當 y == 1 時,第一個輸入將被假定為更大的值。它將排名高于第二個輸入。如果 y == -1,則第二個輸入將排名更高。

          應(yīng)用場景:排名問題

          使用示例

          input_one = torch.randn(3, requires_grad=True)
          input_two = torch.randn(3, requires_grad=True)
          target = torch.randn(3).sign()

          ranking_loss = torch.nn.MarginRankingLoss()
          output = ranking_loss(input_one, input_two, target)

          Triplet Margin Loss

          計算三元組的損失。

          表達式:

          三元組由a (anchor),p (正樣本) 和 n (負樣本)組成.

          應(yīng)用場景:

          確定樣本之間的相對相似性

          用于基于內(nèi)容的檢索問題

          使用示例

          anchor = torch.randn(100, 128, requires_grad=True)
          positive = torch.randn(100, 128, requires_grad=True)
          negative = torch.randn(100, 128, requires_grad=True)

          triplet_margin_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2)
          output = triplet_margin_loss(anchor, positive, negative)

          KL Divergence Loss

          計算兩個概率分布之間的差異。

          表達式:

          輸出表示兩個概率分布的接近程度。如果預(yù)測的概率分布與真實的概率分布相差很遠,就會導(dǎo)致很大的損失。如果 KL Divergence 的值為零,則表示概率分布相同。

          KL Divergence 與交叉熵損失的關(guān)鍵區(qū)別在于它們?nèi)绾翁幚眍A(yù)測概率和實際概率。交叉熵根據(jù)預(yù)測的置信度懲罰模型,而 KL Divergence 則沒有。KL Divergence 僅評估概率分布預(yù)測與ground truth分布的不同之處。

          應(yīng)用場景:逼近復(fù)雜函數(shù)多類分類任務(wù)確保預(yù)測的分布與訓(xùn)練數(shù)據(jù)的分布相似

          使用示例

          input = torch.randn(2, 3, requires_grad=True)
          target = torch.randn(2, 3)

          kl_loss = torch.nn.KLDivLoss(reduction = 'batchmean')
          output = kl_loss(input, target)

          原文鏈接:https://neptune.ai/blog/pytorch-loss-functions
          本文在此鏈接的基礎(chǔ)上進行一部分而來修改。


          △點擊卡片關(guān)注極市平臺,獲取最新CV干貨


          極市干貨
          YOLO教程:一文讀懂YOLO V5 與 YOLO V4大盤點|YOLO 系目標檢測算法總覽全面解析YOLO V4網(wǎng)絡(luò)結(jié)構(gòu)
          實操教程:PyTorch vs LibTorch:網(wǎng)絡(luò)推理速度誰更快?只用兩行代碼,我讓Transformer推理加速了50倍
          算法技巧(trick):深度學(xué)習訓(xùn)練tricks總結(jié)(有實驗支撐)深度強化學(xué)習調(diào)參Tricks合集


          #?CV技術(shù)社群邀請函?#

          △長按添加極市小助手
          添加極市小助手微信(ID : cvmart2)

          備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標檢測-深圳)


          即可申請加入極市目標檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術(shù)交流群


          每月大咖直播分享、真實項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與?10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~



          覺得有用麻煩給個在看啦~??
          瀏覽 43
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  哪里可以看AV片 | 日本东京热高清 | 激情另类视频 | 中文字幕日亚州 | 成人少妇AV |