<p id="m2nkj"><option id="m2nkj"><big id="m2nkj"></big></option></p>
    <strong id="m2nkj"></strong>
    <ruby id="m2nkj"></ruby>

    <var id="m2nkj"></var>
  • TorchMetrics:PyTorch的指標(biāo)度量庫

    共 3715字,需瀏覽 8分鐘

     ·

    2021-04-19 16:20

    編譯 | ronghuaiyang

    轉(zhuǎn)自 | AI公園


    找出你需要評估的指標(biāo)是深度學(xué)習(xí)的關(guān)鍵。有各種各樣的指標(biāo),我們可以評估ML算法的性能。TorchMetrics是一個PyTorch度量的實現(xiàn)的集合,是PyTorch Lightning高性能深度學(xué)習(xí)的框架的一部分。在本文中,我們將介紹如何使用TorchMetrics評估你的深度學(xué)習(xí)模型,甚至使用一個簡單易用的API創(chuàng)建你自己的度量。

    什么是TorchMetrics?

    TorchMetrics是一個開源的PyTorch原生的函數(shù)和度量模塊的集合,用于簡單的性能評估。你可以使用開箱即用的實現(xiàn)來實現(xiàn)常見的指標(biāo),如準(zhǔn)確性,召回率,精度,AUROC, RMSE, R2等,或者創(chuàng)建你自己的指標(biāo)。我們目前支持超過25個指標(biāo),并不斷增加更多的通用任務(wù)和特定領(lǐng)域的標(biāo)準(zhǔn)(目標(biāo)檢測,NLP等)。

    TorchMetrics最初是作為Pytorch Lightning (PL)的一部分創(chuàng)建的,被設(shè)計為分布式硬件兼容,并在默認(rèn)情況下與DistributedDataParalel(DDP)一起工作。所有指標(biāo)都在cpu和gpu上經(jīng)過嚴(yán)格測試。

    使用TorchMetrics

    安裝

    這個包可以通過以下方式從PyPI簡單安裝:

    pip install torchmetrics

    或者直接從GitHub倉庫的源代碼安裝:

    # with git
    pip install git+https://github.com/PytorchLightning/metrics.git@master

    函數(shù)形式的metrics

    類似于torch.nn,大多數(shù)度量指標(biāo)都有基于模塊和函數(shù)的版本。函數(shù)版本實現(xiàn)了計算每個度量所需的基本操作。它們是作為輸入的簡單的python函數(shù)。并返回相應(yīng)的torch.tensor的指標(biāo)。下面的代碼片段展示了一個使用函數(shù)接口計算精度的簡單示例:

    模塊形式的metrics

    幾乎所有函數(shù)metrics都有一個對應(yīng)的基于模塊的metrics,該度量將其稱為底層的函數(shù)等價模塊?;谀K的度量的特點是有一個或多個內(nèi)部度量狀態(tài)(類似于PyTorch模塊的參數(shù)),允許它們提供額外的功能:

    • 多批次積累
    • 多臺設(shè)備間自動同步
    • 度量算法

    下面的代碼展示了如何使用基于模塊的接口:

    每次調(diào)用度量的forward函數(shù)時,我們同時計算當(dāng)前看到的一批數(shù)據(jù)上的度量值,并更新內(nèi)部度量狀態(tài),以跟蹤到目前為止看到的所有數(shù)據(jù)。內(nèi)部狀態(tài)需要在不同時期之間重置,不應(yīng)該在訓(xùn)練、驗證和測試之間混合。因此我們強(qiáng)烈建議按如下方式重新初始化度量:

    Lightning中使用TorchMetrics

    下面的例子展示了如何在你的LightningModule中使用metric :

    雖然TorchMetrics被構(gòu)建為與原生的PyTorch一起使用,但TorchMetrics與Lightning一起使用提供了額外的好處:

    • 當(dāng)在LightningModule中正確定義模塊metrics 時,模塊metrics會自動放置在正確的設(shè)備上。這意味著你的數(shù)據(jù)將始終與你的metrics 放在相同的設(shè)備上。
    • 在Lightning中支持使用原生的self.log,Lightning會根據(jù)on_stepon_epoch標(biāo)志來記錄metric,如果on_epoch=True,logger 會在epoch結(jié)束的時候自動調(diào)用.compute()。
    • metric 的.reset()方法的度量在一個epoch結(jié)束后自動被調(diào)用。

    Lightning的轉(zhuǎn)換

    已經(jīng)熟悉Lightning的metric接口的用戶應(yīng)該能夠輕松地適應(yīng)TorchMetrics。簡單地替換:

    from pytorchlightning import metrics

    with:

    import torchmetrics

    注意,在1.3版本之前,metrics將是PyTorchLightning的一部分,但不再接收任何更新。我們強(qiáng)烈建議用戶切換到TorchMetrics,以得到我們可能實現(xiàn)的所有的bug修復(fù)和增強(qiáng)。

    實現(xiàn)自己的metrics

    如果你想使用一個還不被支持的指標(biāo),你可以使用TorchMetrics的API來實現(xiàn)你自己的自定義指標(biāo),只需子類化torchmetrics.Metric并實現(xiàn)以下方法:

    1. __init__():每個狀態(tài)變量都應(yīng)該使用self.add_state(…)調(diào)用。
    2. update():任何需要更新內(nèi)部度量狀態(tài)的代碼。
    3. compute():從度量值的狀態(tài)計算一個最終值。

    例子:均方根誤差

    均方根誤差是一個很好的例子,說明了為什么許多度量計算需要劃分為兩個函數(shù)。定義為:

    為了正確地計算RMSE,我們需要兩個度量狀態(tài):sum_squared_error來跟蹤目標(biāo)y和預(yù)測y之間的平方誤差,以及n_observations來知道我們有多少觀測結(jié)果。

    因為sqrt(a+b) != sqrt(a) + sqrt(b),我們不能把這個度量實現(xiàn)為每個batch計算的RMSE分?jǐn)?shù)的簡單平均值,而是需要實現(xiàn)更新步驟中需要在平方根之前發(fā)生的所有邏輯,以及在compute步驟中需要實現(xiàn)剩余的邏輯。

    為你的模型選擇正確的度量

    選擇正確的度量對于確定你的模型是否按照應(yīng)該的方式運行,或者是否有什么地方出了問題非常重要。

    預(yù)測冠狀病毒

    假設(shè)你的任務(wù)是建立一個分類網(wǎng)絡(luò),可以通過一套非侵入性測量來確定患者是否是冠狀病毒陽性。你會得到數(shù)千份觀察報告,并使用你最喜歡的網(wǎng)絡(luò)架構(gòu),優(yōu)化以正確識別哪些患者感染了冠狀病毒。這種模式可用于確保檢測呈陽性的患者被隔離,以避免傳播病毒并迅速得到治療。

    為了評估你的模型,你計算了4個指標(biāo):準(zhǔn)確性、混淆矩陣、精確度和召回率。你得到了以下結(jié)果:

    準(zhǔn)確率: 99.9%

    混淆矩陣

    精確率: 1.0

    召回率:0.28

    評估得分

    你怎么看?這個模型足夠好嗎?讓我們更深入地了解這些指標(biāo)的含義。在分類中,準(zhǔn)確率是指我們的模型得到正確預(yù)測的比例。

    我們的模型得到了非常高的準(zhǔn)確率:99.9%??磥砭W(wǎng)絡(luò)正在做你要求它做的事情,你可以準(zhǔn)確地檢測到患者是否感染了冠狀病毒。

    對于二元分類,另一個有用的度量是混淆矩陣,這給了我們下面的真、假陽性和陰性的組合。

    我們可以從混淆矩陣中快速確定兩件事:

    • 陰性患者的數(shù)量遠(yuǎn)遠(yuǎn)少于陽性患者的數(shù)量 —> 這意味著你的數(shù)據(jù)集是高度不平衡的。
    • 有5名患者檢測失敗

    從準(zhǔn)確性來看,這個模型似乎表現(xiàn)得很好,但考慮到混淆矩陣,我們發(fā)現(xiàn)這個模型過于專注于預(yù)測陰性患者,而未能預(yù)測陽性患者。在這種設(shè)置下,它應(yīng)該清楚正確識別新冠患者和正確識別非新冠患者之間的巨大的區(qū)別,正確識別患者將確保患者得到早期治療,最重要的是隔離,不要傳染給別人。

    為什么準(zhǔn)確率指標(biāo)沒有顯示出模型有什么問題?準(zhǔn)確率捕獲了整體性能,以正確地預(yù)測所有類,在這種情況下,我們感興趣的是捕獲我們預(yù)測的ground truth的情況有多好。因此,你可以將注意力轉(zhuǎn)向精確率和召回率。

    精確率定義為實際正確的正樣本的比例。

    其中TP和FP分別表示true p positive個數(shù),false positive個數(shù)。一個有0個誤報的模型的精確率為1.0,而一個模型輸出的結(jié)果都是陽性,而實際上都是假的模型的精度分?jǐn)?shù)為0。

    Recall定義為真實的陽性被正確識別的比例。

    其中TP和FN分別表示true positives數(shù),false negatives數(shù)。類似地,如果沒有錯誤否定,一個模型的召回分?jǐn)?shù)將為1.0。從定義上我們可以得出結(jié)論,精確率聚焦于在不能識別所有假陽性的“成本”上,而召回率聚焦在不能識別所有假陰性的“成本”上。因為我們在這里感興趣的是假陰性,所以我們應(yīng)該在recall metric下重新評估我們的模型,現(xiàn)在我們得到了0.28的分?jǐn)?shù)?,F(xiàn)在,你已經(jīng)量化了模型的性能不佳,并且在訓(xùn)練機(jī)器學(xué)習(xí)算法時可能需要處理數(shù)據(jù)集中存在的巨大類不平衡。

    這個小例子展示了選擇正確度量來評估機(jī)器學(xué)習(xí)算法的重要性。通常,建議使用一組度量標(biāo)準(zhǔn)來評估算法,因為它們都關(guān)注數(shù)據(jù)和模型預(yù)測的不同方面。


    英文原文:https://pytorch-lightning.medium.com/torchmetrics-pytorch-metrics-built-to-scale-7091b1bec919


    往期精彩:

    【原創(chuàng)首發(fā)】機(jī)器學(xué)習(xí)公式推導(dǎo)與代碼實現(xiàn)30講.pdf

    【原創(chuàng)首發(fā)】深度學(xué)習(xí)語義分割理論與實戰(zhàn)指南.pdf

     談中小企業(yè)算法崗面試

     算法工程師研發(fā)技能表

     真正想做算法的,不要害怕內(nèi)卷

     算法工程師的日常,一定不能脫離產(chǎn)業(yè)實踐

     技術(shù)學(xué)習(xí)不能眼高手低

     技術(shù)人要學(xué)會自我營銷

     做人不能過擬合

    求個在看

    瀏覽 64
    點贊
    評論
    收藏
    分享

    手機(jī)掃一掃分享

    分享
    舉報
    評論
    圖片
    表情
    推薦
    點贊
    評論
    收藏
    分享

    手機(jī)掃一掃分享

    分享
    舉報
    <p id="m2nkj"><option id="m2nkj"><big id="m2nkj"></big></option></p>
    <strong id="m2nkj"></strong>
    <ruby id="m2nkj"></ruby>

    <var id="m2nkj"></var>
  • 天天插一插| 欧美精品偷拍 | www.天天射 | 亚洲 欧美 国产 另类 | 色欲综合网 | 又粗又长又大的黄视频 | 国内免费精品视频 | 婷婷毛片 | 国产69精品久久久久久久久久 | 人人妻人人操青青 |