<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          獨(dú)家 | 使用TensorFlow 2創(chuàng)建自定義損失函數(shù)

          共 3723字,需瀏覽 8分鐘

           ·

          2021-05-23 14:02


          作者:Arjun Sarkar

          翻譯:陳之炎

          校對(duì):歐陽(yáng)錦


          本文約1900字,建議閱讀8分鐘
          本文帶你學(xué)習(xí)使用Python中的wrapper函數(shù)和OOP來(lái)編寫(xiě)自定義損失函數(shù)。
                  
          標(biāo)簽:TensorFlow 2,損失函數(shù)
           

          圖1:梯度下降算法(來(lái)源:公共域,https://commons.wikimedia.org/w/index.php?curid=521422 )

          神經(jīng)網(wǎng)絡(luò)利用訓(xùn)練數(shù)據(jù),將一組輸入映射成一組輸出,它通過(guò)使用某種形式的優(yōu)化算法,如梯度下降、隨機(jī)梯度下降、AdaGrad、AdaDelta等等來(lái)實(shí)現(xiàn),其中最新的算法包括Adam、Nadam或RMSProp。梯度下降中的“梯度”是指誤差梯度。每次迭代之后,網(wǎng)絡(luò)將其預(yù)測(cè)輸出與實(shí)際輸出進(jìn)行比較,然后計(jì)算出“誤差”。

          通常,對(duì)于神經(jīng)網(wǎng)絡(luò),尋求的是將誤差最小化。將誤差最小化的目標(biāo)函數(shù)通常稱之為成本函數(shù)或損失函數(shù),由“損失函數(shù)”計(jì)算出的值稱為“損失”。在各種問(wèn)題中使用的典型損失函數(shù)有:
           
          • 均方誤差;

          • 均方對(duì)數(shù)誤差;

          • 二元交叉熵;

          • 分類交叉熵;

          • 稀疏分類交叉熵。


          Tensorflow已經(jīng)包含了上述損失函數(shù),直接調(diào)用它們即可,如下所示:

          1. 將損失函數(shù)當(dāng)作字符串進(jìn)行調(diào)用


          model.compile (loss = ‘binary_crossentropy’,optimizer = ‘a(chǎn)dam’, metrics = [‘a(chǎn)ccuracy’])

          2. 將損失函數(shù)當(dāng)作對(duì)象進(jìn)行調(diào)用

          from tensorflow.keras.losses importmean_squared_errormodel.compile(loss = mean_squared_error,optimizer=’sgd’)

          將損失函數(shù)當(dāng)作對(duì)象進(jìn)行調(diào)用的優(yōu)點(diǎn)是可以在損失函數(shù)中傳遞閾值等參數(shù)。

          from tensorflow.keras.losses import mean_squared_errormodel.compile (loss=mean_squared_error(param=value),optimizer = ‘sgd’)

          利用現(xiàn)有函數(shù)創(chuàng)建自定義損失函數(shù):


          利用現(xiàn)有函數(shù)創(chuàng)建損失函數(shù),首先需要定義損失函數(shù),它將接受兩個(gè)參數(shù),y_true(真實(shí)標(biāo)簽/輸出)和y_pred(預(yù)測(cè)標(biāo)簽/輸出)。

          def loss_function(y_true, y_pred):***some calculation***return loss

          創(chuàng)建均方誤差損失函數(shù) (RMSE):


          定義損失函數(shù)名稱-my_rmse。目的是返回目標(biāo)(y_true)與預(yù)測(cè)(y_pred)之間的均方誤差。

          RMSE的公式為:


          • 誤差:真實(shí)標(biāo)簽與預(yù)測(cè)標(biāo)簽之間的差異。

          • sqr_error:誤差的平方。

          • mean_sqr_error:誤差平方的均值。

          • sqrt_mean_sqr_error:誤差平方均值的平方根(均方根誤差)。



          創(chuàng)建Huber損失函數(shù):


          圖2:Huber損失函數(shù)(綠色)和平方誤差損失函數(shù)(藍(lán)色)(來(lái)源:Qwertyus— Own work,CCBY-SA4.0,https://commons.wikimedia.org/w/index.php?curid=34836380)

          Huber損失函數(shù)的計(jì)算公式:


          在此處,δ是閾值,a是誤差(將計(jì)算出a,即實(shí)際標(biāo)簽和預(yù)測(cè)標(biāo)簽之間的差異)。

          當(dāng)|a|≤δ時(shí),loss = 1/2*(a)2
          當(dāng) |a|>δ時(shí),loss = δ(|a|—(1/2)*δ)

          源代碼:


          詳細(xì)說(shuō)明:

          首先,定義一個(gè)函數(shù)—— my huber loss,它需要兩個(gè)參數(shù):y_true和y_pred,
          設(shè)置閾值threshold = 1。

          計(jì)算誤差error a = y_true-y_pred。接下來(lái),檢查誤差的絕對(duì)值是否小于或等于閾值,is_small_error返回一個(gè)布爾值(真或假)。

          當(dāng)|a|≤δ時(shí),loss= 1/2*(a)2,計(jì)算small_error_loss, 誤差的平方除以2。否則,當(dāng)|a| >δ時(shí),則損失等于δ(|a|-(1/2)*δ),用big_error_loss來(lái)計(jì)算這個(gè)值。

          最后,在返回語(yǔ)句中,首先檢查is_small_error是真還是假,如果它為真,函數(shù)返回small_error_loss,否則返回big_error_loss,使用tf.where來(lái)實(shí)現(xiàn)。

          可以使用下述代碼來(lái)編譯模型:


          在上述代碼中,將閾值設(shè)為1。

          如果需要調(diào)整超參數(shù)(閾值),并在編譯過(guò)程中加入一個(gè)新的閾值的話,必須使用wrapper函數(shù)進(jìn)行封裝,也就是說(shuō),將損失函數(shù)封裝成另一個(gè)外部函數(shù)。在這里需要用到封裝函數(shù)(wrapper function),因?yàn)閾p失函數(shù)在默認(rèn)情況下只能接受y_true和y_pred值,而且不能向原始損失函數(shù)添加任何其他參數(shù)。
           

          使用封裝后的Huber損失函數(shù)


          封裝函數(shù)的源代碼:


          此時(shí),閾值不是硬編碼,可以在模型編譯過(guò)程中傳遞該閾值。


          使用類實(shí)現(xiàn)Huber損失函數(shù)(OOP)



          其中,MyHuberLoss是類名稱,隨后從tensorflow.keras.losses繼承父類“Loss”, MyHuberLoss繼承了Loss類,之后可以將MyHuberLoss當(dāng)作損失函數(shù)來(lái)使用。

          __init__   初始化該類中的對(duì)象。執(zhí)行類實(shí)例化對(duì)象時(shí)調(diào)用函數(shù),init函數(shù)返回閾值,調(diào)用函數(shù)得到y(tǒng)_true和y_pred參數(shù),將閾值聲明為一個(gè)類變量,可以給它賦一個(gè)初始值。

          在__init__函數(shù)中,將閾值設(shè)置為self.threshold。在調(diào)用函數(shù)中,self.threshold引用所有的閾值類變量。在model.compile中使用這個(gè)損失函數(shù):


          創(chuàng)建對(duì)比性損失(用于Siamese網(wǎng)絡(luò)):



          Siamese網(wǎng)絡(luò)可以用來(lái)比較兩幅圖像是否相似,Siamese網(wǎng)絡(luò)使用的損失函數(shù)為對(duì)比性損失。

          在上文的公式中,Y_true是關(guān)于圖像相似性細(xì)節(jié)的張量,如果圖像相似,則為1,如果圖像不相似,則為0。

          D是圖像對(duì)之間的歐氏距離的張量。邊際為一個(gè)常量,用它來(lái)設(shè)置將圖像區(qū)別為相似或不同的最小距離。如果為Y_true=1,則方程的第一部分為D2,第二部分為0,所以,當(dāng)Y_true接近1時(shí),D2的權(quán)重則更重。

          如果Y_true=0,則方程的第一部分變?yōu)?,第二部分會(huì)產(chǎn)生一些結(jié)果,這給了最大項(xiàng)更多的權(quán)重,給了D平方項(xiàng)更少的權(quán)重,此時(shí),最大項(xiàng)在損失計(jì)算中占了優(yōu)勢(shì)。

          使用封裝器函數(shù)實(shí)現(xiàn)對(duì)比損失函數(shù):



          結(jié)論


          在Tensorflow中沒(méi)有的損失函數(shù)都可以利用函數(shù)、包裝函數(shù)或類似的類來(lái)創(chuàng)建。
           
          原文標(biāo)題:
          Creating custom Loss functionsusing TensorFlow 2
          原文鏈接:
          https://towardsdatascience.com/creating-custom-loss-functions-using-tensorflow-2-96c123d5ce6c   

          編輯:黃繼彥
          校對(duì):林亦霖




          譯者簡(jiǎn)介




          陳之炎,北京交通大學(xué)通信與控制工程專業(yè)畢業(yè),獲得工學(xué)碩士學(xué)位,歷任長(zhǎng)城計(jì)算機(jī)軟件與系統(tǒng)公司工程師,大唐微電子公司工程師,現(xiàn)任北京吾譯超群科技有限公司技術(shù)支持。目前從事智能化翻譯教學(xué)系統(tǒng)的運(yùn)營(yíng)和維護(hù),在人工智能深度學(xué)習(xí)和自然語(yǔ)言處理(NLP)方面積累有一定的經(jīng)驗(yàn)。業(yè)余時(shí)間喜愛(ài)翻譯創(chuàng)作,翻譯作品主要有:IEC-ISO 7816、伊拉克石油工程項(xiàng)目、新財(cái)稅主義宣言等等,其中中譯英作品“新財(cái)稅主義宣言”在GLOBAL TIMES正式發(fā)表。能夠利用業(yè)余時(shí)間加入到THU 數(shù)據(jù)派平臺(tái)的翻譯志愿者小組,希望能和大家一起交流分享,共同進(jìn)步。

          翻譯組招募信息

          工作內(nèi)容:需要一顆細(xì)致的心,將選取好的外文文章翻譯成流暢的中文。如果你是數(shù)據(jù)科學(xué)/統(tǒng)計(jì)學(xué)/計(jì)算機(jī)類的留學(xué)生,或在海外從事相關(guān)工作,或?qū)ψ约和庹Z(yǔ)水平有信心的朋友歡迎加入翻譯小組。

          你能得到:定期的翻譯培訓(xùn)提高志愿者的翻譯水平,提高對(duì)于數(shù)據(jù)科學(xué)前沿的認(rèn)知,海外的朋友可以和國(guó)內(nèi)技術(shù)應(yīng)用發(fā)展保持聯(lián)系,THU數(shù)據(jù)派產(chǎn)學(xué)研的背景為志愿者帶來(lái)好的發(fā)展機(jī)遇。

          其他福利:來(lái)自于名企的數(shù)據(jù)科學(xué)工作者,北大清華以及海外等名校學(xué)生他們都將成為你在翻譯小組的伙伴。


          點(diǎn)擊文末“閱讀原文”加入數(shù)據(jù)派團(tuán)隊(duì)~



          轉(zhuǎn)載須知

          如需轉(zhuǎn)載,請(qǐng)?jiān)陂_(kāi)篇顯著位置注明作者和出處(轉(zhuǎn)自:數(shù)據(jù)派ID:DatapiTHU),并在文章結(jié)尾放置數(shù)據(jù)派醒目二維碼。有原創(chuàng)標(biāo)識(shí)文章,請(qǐng)發(fā)送【文章名稱-待授權(quán)公眾號(hào)名稱及ID】至聯(lián)系郵箱,申請(qǐng)白名單授權(quán)并按要求編輯。

          發(fā)布后請(qǐng)將鏈接反饋至聯(lián)系郵箱(見(jiàn)下方)。未經(jīng)許可的轉(zhuǎn)載以及改編者,我們將依法追究其法律責(zé)任。



          點(diǎn)擊“閱讀原文”擁抱組織



          瀏覽 56
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  裸体美女A | 俺去也俺来也在线www官网 | 青青草原视频在线 | 女人高潮视频免费观看网站 | 无码欧美人XXXXX日本无码 |