深度學(xué)習(xí)中的“不確定性基線”


發(fā)布人:Google Research Brain 團(tuán)隊(duì)研究工程師 Zachary Nado 和研究員 Dustin Tran
機(jī)器學(xué)習(xí) (ML) 越來越多地被用于實(shí)際應(yīng)用,因此了解模型的不確定性和穩(wěn)健性對(duì)于確保其在實(shí)踐中的性能很有必要。例如,將模型部署到與訓(xùn)練數(shù)據(jù)不同的數(shù)據(jù)上時(shí),其表現(xiàn)如何?模型在可能出錯(cuò)時(shí)如何發(fā)出信號(hào)?
不確定性和穩(wěn)健性
https://slideslive.com/38935801/practical-uncertainty-estimation-outofdistribution-robustness-in-deep-learning
為掌握 ML 模型的行為,我們通常會(huì)根據(jù)目標(biāo)任務(wù)的基線來衡量其性能。對(duì)于每個(gè)基線,研究人員必須嘗試僅使用相應(yīng)論文中的描述來重現(xiàn)結(jié)果,這為復(fù)現(xiàn)帶來了嚴(yán)峻挑戰(zhàn)。在實(shí)驗(yàn)代碼得到完好記錄和維護(hù)的前提下,查看這些代碼可能更有用。但是,基線必須經(jīng)過嚴(yán)格驗(yàn)證,因此僅僅查看代碼還不夠。
帶來了嚴(yán)峻挑戰(zhàn)
https://paperswithcode.com/rc2020
例如,在對(duì)一系列研究 [1、2、3] 進(jìn)行回顧性分析時(shí),作者會(huì)發(fā)現(xiàn)簡單且經(jīng)過優(yōu)化的基線往往優(yōu)于更復(fù)雜的方法。為真正了解模型之間的相對(duì)表現(xiàn),且讓研究人員能夠衡量新理念是否切實(shí)取得有意義的進(jìn)展,必須將目標(biāo)模型與共同基線進(jìn)行比較。
1
https://arxiv.org/abs/1707.05589
2
https://arxiv.org/abs/1807.04720
3
https://arxiv.org/abs/2102.06356
在“不確定性基線:深度學(xué)習(xí)中不確定性和穩(wěn)健性的基準(zhǔn)?(Uncertainty Baselines: Benchmarks for Uncertainty & Robustness in Deep Learning) ”一文中,我們介紹了“不確定性基線”,這是針對(duì)各種任務(wù)的標(biāo)準(zhǔn)化和先進(jìn)深度學(xué)習(xí)方法的高質(zhì)量實(shí)現(xiàn)合集,旨在促使不確定性和穩(wěn)健性的相關(guān)研究更具可重復(fù)性。
不確定性基線:深度學(xué)習(xí)中不確定性和穩(wěn)健性的基準(zhǔn)
https://arxiv.org/abs/2106.04015
不確定性基線
https://github.com/google/uncertainty-baselines
該合集涵蓋 9 大任務(wù)的 19 種方法,每種方法有至少五項(xiàng)指標(biāo)。每個(gè)基線都是獨(dú)立的實(shí)驗(yàn)流水線,具有易于重復(fù)使用且可擴(kuò)展的組件,并且在其編寫框架之外具有最小的依賴性。內(nèi)含的流水線可在 TensorFlow、PyTorch 和 Jax 中得到實(shí)現(xiàn)。此外,每個(gè)基線的超參數(shù)都已在多次迭代中經(jīng)過廣泛調(diào)整,可提供更有力的結(jié)果。
TensorFlow
https://tensorflow.google.cn/
PyTorch
https://pytorch.org/
Jax
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html


至撰寫本文時(shí),不確定性基線共提供了 83 個(gè)基線,包括 19 種方法,涵蓋九個(gè)數(shù)據(jù)集的標(biāo)準(zhǔn)和最新策略。示例方法包括 BatchEnsemble(批集成)、Deep Ensembles(深度集成)、Rank-1 Bayesian Neural Nets(1 階貝葉斯神經(jīng)網(wǎng)絡(luò))、Monte Carlo Dropout 和 Spectral-normalized Neural Gaussian Processes(光譜歸一化神經(jīng)高斯過程)。
BatchEnsemble
https://arxiv.org/abs/2002.06715
Deep Ensembles
https://arxiv.org/abs/1612.01474
Rank-1 Bayesian Neural Nets
https://arxiv.org/abs/2005.07186
Monte Carlo Dropout
https://arxiv.org/abs/1506.02142
Spectral-normalized Neural Gaussian Processes
https://arxiv.org/abs/2006.10108
不確定性基線可以作為繼任者,合并社區(qū)中如下流行基準(zhǔn):您可以相信模型的不確定性嗎?、BDL 基準(zhǔn)和 Edward2 基線。
您可以相信模型的不確定性嗎?
https://github.com/google-research/google-research/tree/master/uq_benchmark_2019
BDL 基準(zhǔn)
https://github.com/OATML/bdl-benchmarks
Edward2 基線
https://github.com/google/edward2/tree/master/baselines
數(shù)據(jù)集 | 輸入 | 輸出 | 訓(xùn)練示例 | 測試 數(shù)據(jù)集 |
CIFAR | RGB 圖像 | 10 類分布 | 50,000 | 3 |
ImageNet | RGB 圖像 | 1000 類分布 | 1,281,167 | 6 |
CLINC 意圖檢測 | 對(duì)話框系統(tǒng) 查詢文本 | 150 類分布 (10 個(gè)網(wǎng)域) | 15,000 | 2 |
Kaggle 糖尿病性視網(wǎng)膜病變檢測 | RGB 圖像 | 糖尿病性視網(wǎng)膜病變的概率 | 35,126 | 1 |
維基百科 毒性 | 維基百科評(píng)論文本 | 毒性概率 | 159,571 | 3 |
CIFAR
https://www.cs.toronto.edu/~kriz/cifar.html
ImageNet
https://image-net.org/
CLINC 意圖檢測
https://github.com/clinc/oos-eval
Kaggle 糖尿病性視網(wǎng)膜病變檢測
https://www.kaggle.com/c/diabetic-retinopathy-detection
維基百科毒性
https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
我們共為 9 個(gè)可用數(shù)據(jù)集提供基線,上表顯示其中 5 個(gè)數(shù)據(jù)集的子集。數(shù)據(jù)集涵蓋了表格、文本和圖像模態(tài)。
不確定性基線根據(jù)選擇的基礎(chǔ)模型、訓(xùn)練數(shù)據(jù)集和一套評(píng)估指標(biāo)設(shè)置各個(gè)基線。然后通過超參數(shù)對(duì)各個(gè)基線進(jìn)行調(diào)整,以最大限度地提高這些指標(biāo)的性能。上述三個(gè)軸線的可用基線各不相同:
基礎(chǔ)模型(架構(gòu))包括 Wide ResNet 28-10、ResNet-50、BERT 和簡單的全連接網(wǎng)絡(luò)。
Wide ResNet 28-10
https://arxiv.org/abs/1605.07146
ResNet-50
https://arxiv.org/abs/1512.03385
BERT
https://arxiv.org/abs/1810.04805
訓(xùn)練數(shù)據(jù)集包括標(biāo)準(zhǔn)機(jī)器學(xué)習(xí)數(shù)據(jù)集(CIFAR、ImageNet 和 UCI)以及更多現(xiàn)實(shí)問題(Clinc 意圖檢測、Kaggle 糖尿病性視網(wǎng)膜病變檢測和維基百科毒性)。
UCI
https://archive.ics.uci.edu/ml/datasets.php
Clinc 意圖檢測
https://tensorflow.google.cn/datasets/catalog/clinc_oos
維基百科毒性
https://tensorflow.google.cn/datasets/catalog/wikipedia_toxicity_subtypes
評(píng)估包括預(yù)測性指標(biāo)(如準(zhǔn)確率)、不確定性指標(biāo)(如選擇性預(yù)測和校準(zhǔn)誤差)、計(jì)算指標(biāo)(推斷延遲)以及分布內(nèi)外數(shù)據(jù)集的性能。


為便于研究人員使用基線并在其基礎(chǔ)上進(jìn)行構(gòu)建,我們特意對(duì)其進(jìn)行優(yōu)化,盡可能采用模塊化設(shè)計(jì),并實(shí)現(xiàn)最小化。如下方工作流圖所示,不確定性基線沒有引入新的類抽象,而是重復(fù)使用生態(tài)系統(tǒng)中預(yù)先存在的類(例如 TensorFlow 的?tf.data.Dataset)。各個(gè)基線的訓(xùn)練/評(píng)估流水線均包含在實(shí)驗(yàn)的獨(dú)立 Python 文件(可以在 CPU、GPU 或 Google Cloud TPU 上運(yùn)行)中。由于基線之間的這種獨(dú)立性,我們得以在 TensorFlow、PyTorch 或 JAX 任意一者中開發(fā)基線。
tf.data.Dataset
https://tensorflow.google.cn/api_docs/python/tf/data/Dataset
PyTorch
https://github.com/google/uncertainty-baselines/blob/master/baselines/diabetic_retinopathy_detection/torch_dropout.py
JAX
https://github.com/google/uncertainty-baselines/blob/master/baselines/jft/deterministic.py

工作流示意圖:不確定性基線不同組成部分的構(gòu)造方式。所有數(shù)據(jù)集都是 BaseDataset 類的子類,BaseDataset 類提供的簡單 API 可用于使用任何受支持框架編寫的基線。然后,任何基線的輸出均可使用穩(wěn)健性指標(biāo)庫進(jìn)行分析
穩(wěn)健性指標(biāo)
https://github.com/google-research/robustness_metrics/
研究工程師對(duì)如何管理超參數(shù)和其他實(shí)驗(yàn)配置值(很輕松就能達(dá)到幾十個(gè))存在爭議。我們沒有使用針對(duì)該問題構(gòu)建的任意一個(gè)框架,不想冒用戶必須學(xué)習(xí)另一個(gè)庫的風(fēng)險(xiǎn),因此選擇僅使用 Python 標(biāo)志,這些標(biāo)志通過遵循 Python 約定的 Abseil 定義。大多數(shù)研究人員應(yīng)該非常熟悉該技術(shù),其很容易擴(kuò)展并插入其他流水線。
Abseil
https://abseil.io/docs/python/guides/flags


除了能夠使用記錄的命令運(yùn)行我們的所有基線并獲得相同的報(bào)告結(jié)果之外,我們還力求發(fā)布超參數(shù)調(diào)整結(jié)果和最終模型檢查點(diǎn),以進(jìn)一步實(shí)現(xiàn)可重復(fù)性。目前,我們只針對(duì)糖尿病性視網(wǎng)膜病變基線完全開源上述內(nèi)容,但我們會(huì)在運(yùn)行基線的過程中繼續(xù)上傳更多結(jié)果。此外,我們提供的基線示例在硬件確定性方面完全可重復(fù)。
糖尿病性視網(wǎng)膜病變基線
https://github.com/google/uncertainty-baselines/tree/master/baselines/diabetic_retinopathy_detection
基線示例
https://github.com/google/uncertainty-baselines/blob/df320d4987deddf2e23a8a7cb45eda87d3c5f210/baselines/cifar/deterministic.py#L132


代碼庫中包含的所有基線都經(jīng)過了廣泛的超參數(shù)調(diào)整,我們希望研究人員可以輕松地重復(fù)使用這些基線,而無需進(jìn)行昂貴的重新訓(xùn)練或重新調(diào)整。此外,我們希望避免流水線實(shí)現(xiàn)中影響基線比較的細(xì)微差異。
不確定性基線已被用于眾多研究項(xiàng)目。如果您是一名研究人員,想要貢獻(xiàn)其他方法或數(shù)據(jù)集,您可以在 GitHub 上創(chuàng)建一個(gè)議題,開啟討論!
眾多研究項(xiàng)目
https://github.com/google/uncertainty-baselines#papers-using-uncertainty-baselines


衷心感謝各位共同開發(fā)的人員,以及提供指導(dǎo)和/或幫助審核本文的人員:Neil Band、Mark Collier、Josip Djolonga、Michael W. Dusenberry、Sebastian Farquhar、Angelos Filos、Marton Havasi、Rodolphe Jenatton、Ghassen Jerfel、Jeremiah Liu、Zelda Mariet、Jeremy Nixon、Shreyas Padhy、Jie Ren、Tim G. J. Rudner、Yeming Wen、Florian Wenzel、Kevin Murphy、D. Sculley、Balaji Lakshminarayanan、Jasper Snoek、Yarin Gal。
推薦閱讀
輔助模塊加速收斂,精度大幅提升!移動(dòng)端實(shí)時(shí)的NanoDet-Plus來了!
SSD的torchvision版本實(shí)現(xiàn)詳解
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)


不要忘記“一鍵三連”哦~

分享

點(diǎn)贊

在看
