<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          TorchMetrics:PyTorch的指標(biāo)度量庫

          共 3715字,需瀏覽 8分鐘

           ·

          2021-04-19 16:20

          編譯 | ronghuaiyang

          轉(zhuǎn)自 | AI公園


          找出你需要評估的指標(biāo)是深度學(xué)習(xí)的關(guān)鍵。有各種各樣的指標(biāo),我們可以評估ML算法的性能。TorchMetrics是一個PyTorch度量的實現(xiàn)的集合,是PyTorch Lightning高性能深度學(xué)習(xí)的框架的一部分。在本文中,我們將介紹如何使用TorchMetrics評估你的深度學(xué)習(xí)模型,甚至使用一個簡單易用的API創(chuàng)建你自己的度量。

          什么是TorchMetrics?

          TorchMetrics是一個開源的PyTorch原生的函數(shù)和度量模塊的集合,用于簡單的性能評估。你可以使用開箱即用的實現(xiàn)來實現(xiàn)常見的指標(biāo),如準(zhǔn)確性,召回率,精度,AUROC, RMSE, R2等,或者創(chuàng)建你自己的指標(biāo)。我們目前支持超過25個指標(biāo),并不斷增加更多的通用任務(wù)和特定領(lǐng)域的標(biāo)準(zhǔn)(目標(biāo)檢測,NLP等)。

          TorchMetrics最初是作為Pytorch Lightning (PL)的一部分創(chuàng)建的,被設(shè)計為分布式硬件兼容,并在默認(rèn)情況下與DistributedDataParalel(DDP)一起工作。所有指標(biāo)都在cpu和gpu上經(jīng)過嚴(yán)格測試。

          使用TorchMetrics

          安裝

          這個包可以通過以下方式從PyPI簡單安裝:

          pip install torchmetrics

          或者直接從GitHub倉庫的源代碼安裝:

          # with git
          pip install git+https://github.com/PytorchLightning/metrics.git@master

          函數(shù)形式的metrics

          類似于torch.nn,大多數(shù)度量指標(biāo)都有基于模塊和函數(shù)的版本。函數(shù)版本實現(xiàn)了計算每個度量所需的基本操作。它們是作為輸入的簡單的python函數(shù)。并返回相應(yīng)的torch.tensor的指標(biāo)。下面的代碼片段展示了一個使用函數(shù)接口計算精度的簡單示例:

          模塊形式的metrics

          幾乎所有函數(shù)metrics都有一個對應(yīng)的基于模塊的metrics,該度量將其稱為底層的函數(shù)等價模塊?;谀K的度量的特點是有一個或多個內(nèi)部度量狀態(tài)(類似于PyTorch模塊的參數(shù)),允許它們提供額外的功能:

          • 多批次積累
          • 多臺設(shè)備間自動同步
          • 度量算法

          下面的代碼展示了如何使用基于模塊的接口:

          每次調(diào)用度量的forward函數(shù)時,我們同時計算當(dāng)前看到的一批數(shù)據(jù)上的度量值,并更新內(nèi)部度量狀態(tài),以跟蹤到目前為止看到的所有數(shù)據(jù)。內(nèi)部狀態(tài)需要在不同時期之間重置,不應(yīng)該在訓(xùn)練、驗證和測試之間混合。因此我們強(qiáng)烈建議按如下方式重新初始化度量:

          Lightning中使用TorchMetrics

          下面的例子展示了如何在你的LightningModule中使用metric :

          雖然TorchMetrics被構(gòu)建為與原生的PyTorch一起使用,但TorchMetrics與Lightning一起使用提供了額外的好處:

          • 當(dāng)在LightningModule中正確定義模塊metrics 時,模塊metrics會自動放置在正確的設(shè)備上。這意味著你的數(shù)據(jù)將始終與你的metrics 放在相同的設(shè)備上。
          • 在Lightning中支持使用原生的self.log,Lightning會根據(jù)on_stepon_epoch標(biāo)志來記錄metric,如果on_epoch=True,logger 會在epoch結(jié)束的時候自動調(diào)用.compute()。
          • metric 的.reset()方法的度量在一個epoch結(jié)束后自動被調(diào)用。

          Lightning的轉(zhuǎn)換

          已經(jīng)熟悉Lightning的metric接口的用戶應(yīng)該能夠輕松地適應(yīng)TorchMetrics。簡單地替換:

          from pytorchlightning import metrics

          with:

          import torchmetrics

          注意,在1.3版本之前,metrics將是PyTorchLightning的一部分,但不再接收任何更新。我們強(qiáng)烈建議用戶切換到TorchMetrics,以得到我們可能實現(xiàn)的所有的bug修復(fù)和增強(qiáng)。

          實現(xiàn)自己的metrics

          如果你想使用一個還不被支持的指標(biāo),你可以使用TorchMetrics的API來實現(xiàn)你自己的自定義指標(biāo),只需子類化torchmetrics.Metric并實現(xiàn)以下方法:

          1. __init__():每個狀態(tài)變量都應(yīng)該使用self.add_state(…)調(diào)用。
          2. update():任何需要更新內(nèi)部度量狀態(tài)的代碼。
          3. compute():從度量值的狀態(tài)計算一個最終值。

          例子:均方根誤差

          均方根誤差是一個很好的例子,說明了為什么許多度量計算需要劃分為兩個函數(shù)。定義為:

          為了正確地計算RMSE,我們需要兩個度量狀態(tài):sum_squared_error來跟蹤目標(biāo)y和預(yù)測y之間的平方誤差,以及n_observations來知道我們有多少觀測結(jié)果。

          因為sqrt(a+b) != sqrt(a) + sqrt(b),我們不能把這個度量實現(xiàn)為每個batch計算的RMSE分?jǐn)?shù)的簡單平均值,而是需要實現(xiàn)更新步驟中需要在平方根之前發(fā)生的所有邏輯,以及在compute步驟中需要實現(xiàn)剩余的邏輯。

          為你的模型選擇正確的度量

          選擇正確的度量對于確定你的模型是否按照應(yīng)該的方式運行,或者是否有什么地方出了問題非常重要。

          預(yù)測冠狀病毒

          假設(shè)你的任務(wù)是建立一個分類網(wǎng)絡(luò),可以通過一套非侵入性測量來確定患者是否是冠狀病毒陽性。你會得到數(shù)千份觀察報告,并使用你最喜歡的網(wǎng)絡(luò)架構(gòu),優(yōu)化以正確識別哪些患者感染了冠狀病毒。這種模式可用于確保檢測呈陽性的患者被隔離,以避免傳播病毒并迅速得到治療。

          為了評估你的模型,你計算了4個指標(biāo):準(zhǔn)確性、混淆矩陣、精確度和召回率。你得到了以下結(jié)果:

          準(zhǔn)確率: 99.9%

          混淆矩陣

          精確率: 1.0

          召回率:0.28

          評估得分

          你怎么看?這個模型足夠好嗎?讓我們更深入地了解這些指標(biāo)的含義。在分類中,準(zhǔn)確率是指我們的模型得到正確預(yù)測的比例。

          我們的模型得到了非常高的準(zhǔn)確率:99.9%??磥砭W(wǎng)絡(luò)正在做你要求它做的事情,你可以準(zhǔn)確地檢測到患者是否感染了冠狀病毒。

          對于二元分類,另一個有用的度量是混淆矩陣,這給了我們下面的真、假陽性和陰性的組合。

          我們可以從混淆矩陣中快速確定兩件事:

          • 陰性患者的數(shù)量遠(yuǎn)遠(yuǎn)少于陽性患者的數(shù)量 —> 這意味著你的數(shù)據(jù)集是高度不平衡的。
          • 有5名患者檢測失敗

          從準(zhǔn)確性來看,這個模型似乎表現(xiàn)得很好,但考慮到混淆矩陣,我們發(fā)現(xiàn)這個模型過于專注于預(yù)測陰性患者,而未能預(yù)測陽性患者。在這種設(shè)置下,它應(yīng)該清楚正確識別新冠患者和正確識別非新冠患者之間的巨大的區(qū)別,正確識別患者將確保患者得到早期治療,最重要的是隔離,不要傳染給別人。

          為什么準(zhǔn)確率指標(biāo)沒有顯示出模型有什么問題?準(zhǔn)確率捕獲了整體性能,以正確地預(yù)測所有類,在這種情況下,我們感興趣的是捕獲我們預(yù)測的ground truth的情況有多好。因此,你可以將注意力轉(zhuǎn)向精確率和召回率。

          精確率定義為實際正確的正樣本的比例。

          其中TP和FP分別表示true p positive個數(shù),false positive個數(shù)。一個有0個誤報的模型的精確率為1.0,而一個模型輸出的結(jié)果都是陽性,而實際上都是假的模型的精度分?jǐn)?shù)為0。

          Recall定義為真實的陽性被正確識別的比例。

          其中TP和FN分別表示true positives數(shù),false negatives數(shù)。類似地,如果沒有錯誤否定,一個模型的召回分?jǐn)?shù)將為1.0。從定義上我們可以得出結(jié)論,精確率聚焦于在不能識別所有假陽性的“成本”上,而召回率聚焦在不能識別所有假陰性的“成本”上。因為我們在這里感興趣的是假陰性,所以我們應(yīng)該在recall metric下重新評估我們的模型,現(xiàn)在我們得到了0.28的分?jǐn)?shù)?,F(xiàn)在,你已經(jīng)量化了模型的性能不佳,并且在訓(xùn)練機(jī)器學(xué)習(xí)算法時可能需要處理數(shù)據(jù)集中存在的巨大類不平衡。

          這個小例子展示了選擇正確度量來評估機(jī)器學(xué)習(xí)算法的重要性。通常,建議使用一組度量標(biāo)準(zhǔn)來評估算法,因為它們都關(guān)注數(shù)據(jù)和模型預(yù)測的不同方面。


          英文原文:https://pytorch-lightning.medium.com/torchmetrics-pytorch-metrics-built-to-scale-7091b1bec919


          往期精彩:

          【原創(chuàng)首發(fā)】機(jī)器學(xué)習(xí)公式推導(dǎo)與代碼實現(xiàn)30講.pdf

          【原創(chuàng)首發(fā)】深度學(xué)習(xí)語義分割理論與實戰(zhàn)指南.pdf

           談中小企業(yè)算法崗面試

           算法工程師研發(fā)技能表

           真正想做算法的,不要害怕內(nèi)卷

           算法工程師的日常,一定不能脫離產(chǎn)業(yè)實踐

           技術(shù)學(xué)習(xí)不能眼高手低

           技術(shù)人要學(xué)會自我營銷

           做人不能過擬合

          求個在看

          瀏覽 64
          點贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲精品成人在线视频久久 | 国产性情网站在线看 | 国产精品永久无码AV毛片18禁 | 亚洲成人网站视频 | 激情网站五月天 |