GitHub項目推薦|基于強化學習的自動化剪枝模型

極市導讀
GitHub上最新開源的一個基于強化學習的自動化剪枝模型,本模型在圖像識別的實驗證明了能夠有效減少計算量,同時還能提高模型的精度。 >>加入極市CV技術交流群,走在計算機視覺的最前沿
今天為大家介紹一個GitHub上最新開源的一個基于強化學習的自動化剪枝模型,本模型在圖像識別的實驗證明了能夠有效減少計算量,同時還能提高模型的精度。
項目地址:
https://github.com/freefuiiismyname/cv-automatic-pruning-transformer
介紹
目前的強化學習工作很多集中在利用外部環(huán)境的反饋訓練agent,忽略了模型本身就是一種能夠獲得反饋的環(huán)境。本項目的核心思想是:將模型視為環(huán)境,構建附生于模型的 agent ,以輔助模型進一步擬合真實樣本。

大多數(shù)領域的模型都可以采用這種方式來優(yōu)化,如cv/多模態(tài)等。它至少能夠以三種方式工作:
1.過濾噪音信息,如刪減語音或圖像特征;
2.進一步豐富表征信息,如高效引用外部信息;
3.實現(xiàn)記憶、聯(lián)想、推理等復雜工作,如構建重要信息的記憶池。
這里推出一款早期完成的裁剪機制transformer版本(后面稱為APT),實現(xiàn)了一種更高效的訓練模式,能夠優(yōu)化模型指標;此外,可以使用動態(tài)圖丟棄大量的不必要單元,在指標基本不變的情況下,大幅降低計算量。
該項目希望為大家拋磚引玉。

為什么要做自動剪枝
在具體任務中,往往存在大量毫無價值的信息和過渡性信息,有時不但對任務無益,還會成為噪聲。比如:表述會存在冗余/無關片段以及過渡性信息;動物圖像識別中,有時候背景無益于辨別動物主體,即使是動物部分圖像,也僅有小部分是關鍵的特征。

以transformer為例,在進行self-attention計算時其復雜度與序列長度平方成正比。長度為10,復雜度為100;長度為9,復雜度為81。
利用強化學習構建agent,能夠精準且自動化地動態(tài)裁剪已喪失意義部分,甚至能將長序列信息壓縮到50-100之內(實驗中有從500+的序列長度壓縮到個位數(shù)的示例),以大幅減少計算量。
實驗中,發(fā)現(xiàn)與裁剪agent聯(lián)合訓練的模型比普通方法訓練的模型效果要更好。
模型介紹及實驗
模型主體
基于transformer的視覺預訓練模型ViT是本項目的模型主體,具體細節(jié)可以查看論文:《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》

自動化裁剪的智能體
對于強化學習agent來說,最關鍵的問題之一是如何衡量動作帶來的反饋。為了評估單次動作所帶來的影響,使用了以下三步驟:
1、使用一個普通模型(無裁剪模塊)進行預測;
2、使用一個帶裁剪器的模型(執(zhí)行一次裁剪動作)進行預測;
3、對比兩次預測的結果,若裁剪后損失相對更小,則說明該裁剪動作幫助了模型進一步擬合真實狀況,應該得到獎勵;反之,應該受到懲罰。
但是在實際預測過程中,模型是同時裁剪多個單元的,這或將因為多個裁剪的連鎖反應而導致模型失效。訓練過程中需要構建一個帶裁剪器的模型(可執(zhí)行多次裁剪動作),以減小該問題所帶來的影響。
綜上,本模型使用的是三通道模式進行訓練。

關于裁剪器的模型結構設計,本模型中認為如何衡量一個信息單元是否對模型有意義,建立于其自身的信息及它與任務的相關性上。
因此以信息單元本身及它與CLS單元的交互作為agent的輸入信息。

實驗
數(shù)據集 | ViT | APT(pruning) | APT(no pruning) |
CIFAR-100 | 92.3 | 92.6 | 93.03 |
CIFAR-10 | 99.08 | 98.93 | 98.92 |
以上加載的均為ViT-B_16,resolution為224*224。
使用說明
環(huán)境

下載經過預先訓練的模型(來自Google官方)
本項目使用的型號:ViT-B_16(您也可以選擇其它型號進行測試)

訓練與推理
下載好預訓練模型就可以跑了。

CIFAR-10和CIFAR-100會自動下載和培訓。如果使用其他數(shù)據集,您需要自定義data_utils.py。
在裁剪模式的推理過程中,預期您將看到如下格式的輸出。

默認的batch size為72、gradient_accumulation_steps為3。當GPU內存不足時,您可以通過它們來進行訓練。
注:相較于原始的ViT,APT(Automatic pruning transformer)的訓練步數(shù)、訓練耗時都會上升。原因是使用pruning agent的模型由于總會丟失部分信息,使得收斂速度變慢,同時為了訓練pruning agent,也需要多次的觀測、行動、反饋。
致謝
感謝基于pytorch的圖像分類項目(https://github.com/jeonsworld/ViT-pytorch),本項目是在此基礎上做的研發(fā)。
最后再附上一次項目地址,歡迎感興趣的讀者Star?
https://github.com/freefuiiismyname/cv-automatic-pruning-transformer
如果覺得有用,就請分享到朋友圈吧!
公眾號后臺回復“目標檢測競賽”獲取目標檢測競賽經驗資源~

# CV技術社群邀請函 #
備注:姓名-學校/公司-研究方向-城市(如:小極-北大-目標檢測-深圳)
即可申請加入極市目標檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術交流群
每月大咖直播分享、真實項目需求對接、求職內推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~

