知乎 | 寫深度學(xué)習(xí)代碼需要遵守哪些順序?
來源 | 知乎問答
地址?|?https://www.zhihu.com/question/498167513
本文僅作學(xué)術(shù)分享,若侵權(quán)請(qǐng)聯(lián)系后臺(tái)刪文處理
回答一:作者-三思但不猶豫
前段時(shí)間剛重寫了一個(gè) dl 任務(wù),在此說下心得體會(huì):
順序上,先 dataset,檢查基本的 transform,再搭 model,構(gòu)建 head 和 loss,就可以把一個(gè)基礎(chǔ)的、可以跑的網(wǎng)絡(luò)就能跑起來了(這點(diǎn)很重要); 可視化很重要,如果是本地開發(fā)機(jī),善用 cv.imshow 直觀、便捷地可視化處理的結(jié)果; 一個(gè)基礎(chǔ)的 train/inference 流程跑通后,分別構(gòu)建 1 張、10 張的數(shù)據(jù)用于 debug,確保任意改動(dòng)后,可以 overfit; 調(diào)試代碼階段避免隨機(jī)性、避免數(shù)據(jù)增強(qiáng),一定用 tensorboard 之類的工具觀察 loss 下降是否合理; 一般數(shù)據(jù)集最好處理成 coco 的格式,我的任務(wù)跟傳統(tǒng)任務(wù)不太一樣,但也盡量仿照 coco 來設(shè)計(jì),寫 dataset 的時(shí)候可以參考開源實(shí)現(xiàn); 善用開源框架,比如 Open-MMLab,Detectron2 之類的,好處是方便實(shí)驗(yàn),在框架里寫不容易出現(xiàn)難以察覺的 bug,壞處是開源框架為了適配各種網(wǎng)絡(luò),代碼復(fù)雜程度會(huì)高一點(diǎn),建議從第一版入手了解框架,然后基于最新的一邊閱讀一邊開發(fā)。
最后,想要更穩(wěn)健的開發(fā)流程,參考 Karpathy 大神的:
http://karpathy.github.io/2019/04/25/recipe/
回答二:作者-撿到一束光
先給結(jié)論:以我寫了兩三年pytorch代碼的經(jīng)驗(yàn)而言,比較好的順序是先寫model,再寫dataset,最后寫train。
在討論碼組件的具體順序前,我們先分析每一個(gè)組件背后的目的和邏輯。
model構(gòu)成了整個(gè)深度學(xué)習(xí)訓(xùn)練與推斷系統(tǒng)骨架,也確定了整個(gè)AI模型的輸入和輸出格式。對(duì)于視覺任務(wù),模型架構(gòu)多為卷積神經(jīng)網(wǎng)絡(luò)或是最新的ViT模型;對(duì)于NLP任務(wù),模型架構(gòu)多為Transformer以及Bert;對(duì)于時(shí)間序列預(yù)測,模型架構(gòu)多為RNN或LSTM。不同的model對(duì)應(yīng)了不同的數(shù)據(jù)輸入格式,如ResNet一般是輸入多通道二維矩陣,而ViT則需要輸入帶有位置信息的圖像patchs。確定了用什么樣的model后,數(shù)據(jù)的輸入格式也就確定下來。根據(jù)確定的輸入格式,我們才能構(gòu)建對(duì)應(yīng)的dataset。
dataset構(gòu)建了整個(gè)AI模型的輸入與輸出格式。在寫作dataset組件時(shí),我們需要考慮數(shù)據(jù)的存儲(chǔ)位置與存儲(chǔ)方式,如數(shù)據(jù)是否是分布式存儲(chǔ)的,模型是否要在多機(jī)多卡的情況下運(yùn)行,讀寫速度是否存在瓶頸,如果機(jī)械硬盤帶來了讀寫瓶頸則需要將數(shù)據(jù)預(yù)加載進(jìn)內(nèi)存等。在寫dataset組件時(shí),我們也要反向微調(diào)model組件。例如,確定了分布式訓(xùn)練的數(shù)據(jù)讀寫后,需要用nn.DataParallel或者nn.DistributedDataParallel等模塊包裹model,使模型能夠在多機(jī)多卡上運(yùn)行。此外,dataset組件的寫作也會(huì)影響訓(xùn)練策略,這也為構(gòu)建train組件做了鋪墊。比如根據(jù)顯存大小,我們需要確定相應(yīng)的BatchSize,而BatchSize則直接影響學(xué)習(xí)率的大小。再比如根據(jù)數(shù)據(jù)的分布情況,我們需要選擇不同的采樣策略進(jìn)行Feature Balance,而這也會(huì)體現(xiàn)在訓(xùn)練策略中。
train構(gòu)建了模型的訓(xùn)練策略以及評(píng)估方法,它是最重要也是最復(fù)雜的組件。先構(gòu)建model與dataset可以添加限制,減少train組件的復(fù)雜度。在train組件中,我們需要根據(jù)訓(xùn)練環(huán)境(單機(jī)多卡,多機(jī)多卡或是聯(lián)邦學(xué)習(xí))確定模型更新的策略,以及確定訓(xùn)練總時(shí)長epochs,優(yōu)化器的類型,學(xué)習(xí)率的大小與衰減策略,參數(shù)的初始化方法,模型損失函數(shù)。此外,為了對(duì)抗過擬合,提升泛化性,還需要引入合適的正則化方法,如Dropout,BatchNorm,L2-Regularization,Data Augmentation等。有些提升泛化性能的方法可以直接在train組件中實(shí)現(xiàn)(如添加L2-Reg,Mixup),有些則需要添加進(jìn)model中(如Dropout與BatchNorm),還有些需要添加進(jìn)dataset中(如Data Augmentation)。此處安利一下我們的專欄教程:數(shù)據(jù)增廣的方法與代碼實(shí)現(xiàn)(https://zhuanlan.zhihu.com/p/439206910)。train還需要記錄訓(xùn)練過程的一些重要信息,并將這些信息可視化出來,比如在每個(gè)epoch上記錄訓(xùn)練集的平均損失以及測試集精度,并將這些信息寫入tensorboard,然后在網(wǎng)頁端實(shí)時(shí)監(jiān)控。在構(gòu)建train組件中,我們需要隨時(shí)根據(jù)模型表現(xiàn)進(jìn)行參數(shù)微調(diào),并根據(jù)結(jié)果改進(jìn)model和dataset兩個(gè)組件。
model-dataset-train的順序進(jìn)行構(gòu)建,實(shí)現(xiàn)了單機(jī)多卡,聯(lián)邦學(xué)習(xí)等訓(xùn)練環(huán)境:在Cifar10與Cifar100上采用各種ResNet,以Mixup作為數(shù)據(jù)增廣策略,實(shí)現(xiàn)監(jiān)督分類與無監(jiān)督學(xué)習(xí)(https://github.com/FengHZ/mixupfamily)。關(guān)于數(shù)據(jù)增廣策略Mixup的科普也可以移步我們的專欄Mixup的一個(gè)綜述(https://zhuanlan.zhihu.com/p/439205252)。 在5種Bencnmark數(shù)據(jù)集上實(shí)現(xiàn)聯(lián)邦遷移學(xué)習(xí)(https://github.com/FengHZ/KD3A)。
回答三:作者-芙蘭朵露
往期精彩:
?時(shí)隔一年!深度學(xué)習(xí)語義分割理論與代碼實(shí)踐指南.pdf第二版來了!
?基于 docker 和 Flask 的深度學(xué)習(xí)模型部署!
?新書預(yù)告 | 《機(jī)器學(xué)習(xí)公式推導(dǎo)與代碼實(shí)現(xiàn)》出版在即!
