Pytorch 數(shù)據(jù)流中常見Trick總結(jié)
重磅干貨,第一時間送達(dá)
前言
在使用Pytorch建模時,常見的流程為先寫Model,再寫Dataset,最后寫Trainer。Dataset 是整個項目開發(fā)中投入時間第二多,也是中間關(guān)鍵的步驟。往往需要事先對于其設(shè)計有明確的思考,不然可能會因為Dataset的一些問題又要去調(diào)整Model,Trainer。本文將目前開發(fā)中的一些思考以及遇到的問題做一個總結(jié),提供給各位讀者一個比較通用的模版,拋磚引玉~
一、Dataset的定義
from?torch.utils.data?import?Dataset,?DataLoader,?RandomSampler
對于不同類型的建模任務(wù),模型的輸入各不相同。自然語言,多模態(tài),點擊率預(yù)估,往往這些場景輸入模型的數(shù)據(jù)并不是來自于單一文件,而且可能無法全部存入內(nèi)存。Dataset需要整合項目的數(shù)據(jù),對于單條樣本涉及到的數(shù)據(jù)做一個提取與歸納。不但如此,項目可能還涉及到多種模型,任務(wù)的訓(xùn)練。Dataset需要為不同的模型以及訓(xùn)練任務(wù)提供不同的單條樣本輸入,作為一個數(shù)據(jù)生成器,把后續(xù)模型訓(xùn)練任務(wù)需要的所有基礎(chǔ)數(shù)據(jù),標(biāo)簽全返回了。所以往往我們可以定義一個BaseDataset類,繼承torch.utils.data.Dataset,這個類可以初始化一些文件路徑,配置等。后面不同的模型訓(xùn)練任務(wù)定義相應(yīng)的Dataset類繼承BaseDataset。
Dataset通用的結(jié)構(gòu)為:
class?BaseDataset(Dataset):
????def?__init__(self,?config):
????????self.config?=?config
????????if?os.path.isfile(config.file_path)?is?False:
????????????raise?ValueError(f"Input?file?path?{config.file_path}?not?found")
????????logger.info(f"Creating?features?from?dataset?file?at?{config.file_path}")
????????#?一次性全讀進(jìn)內(nèi)存
????????self.data?=?joblib.load(config.file_path)
????????self.nums?=?len(self.data)
????def?__len__(self):
????????return?self.nums
????def?__getitem__(self,?i)?->?Dict[str,?tensor]:
????????sample_i?=?self.data[i]
????????return?{"f1":torch.tensor(sample_i["f1"]).long(),"f2":torch.tensor(sample_i["f2"]).long(),torch.LongTensor([sample_i["label"]])}
如果無法全部讀取進(jìn)內(nèi)存需要再__getitem__方法內(nèi)構(gòu)建數(shù)據(jù),做自然語言則可以吧tokenizer初始化到該類中,在__getitem__方法內(nèi)完成tokenizer。改方法的輸出推薦做成字典形式。
對于不同的訓(xùn)練任務(wù)可以通過以下方法返回響應(yīng)的數(shù)據(jù)生成器
def?build_dataset(task_type,?features,?**kwargs):
????assert?task_type?in?['task1',?'task2'],?'task?mismatch'
????if?task_type?==?'task1':
????????dataset?=?task1Dataset(features))
????else:
????????dataset?=?task2Dataset(features)
????return?dataset
有時模型的訓(xùn)練任務(wù)需要做數(shù)據(jù)增強(qiáng),對比學(xué)習(xí),構(gòu)造多種的預(yù)訓(xùn)練任務(wù)輸入。Dataset的職能邊界是提供一套基礎(chǔ)的單樣本數(shù)據(jù)輸入生成器。如果是MLM任務(wù),可以在Dataset內(nèi)生成maskposition以及l(fā)abel。如果是在batch內(nèi)的對比學(xué)習(xí)則應(yīng)該在DataLoader生產(chǎn)batch數(shù)據(jù)后再進(jìn)行。
二、DataLoader的定義
DataLoader的作用是對Dataset進(jìn)行多進(jìn)程高效地構(gòu)建每個訓(xùn)練批次的數(shù)據(jù)。傳入的數(shù)據(jù)可以認(rèn)為是長度為batch大小的多個__getitem__ 方法返回的字典list。DataLoader的職能邊界是根據(jù)Dataset提供的單條樣本數(shù)據(jù)有選擇的構(gòu)建一個batch的模型輸入數(shù)據(jù)。
其通常的結(jié)構(gòu)為對Train,Valid,Test分別建立:
train_sampler?=?RandomSampler(train_dataset)
train_loader?=?DataLoader(dataset=train_dataset,
??????????????????????????????batch_size=args.train_batch_size,
??????????????????????????????sampler=train_sampler,
??????????????????????????????shuffle=(train_sampler?is?None)
??????????????????????????????collate_fn=None,?#?一般不用設(shè)置
??????????????????????????????num_workers=4)
首先對于sampler 還有一種定義方式:
sampler?=?torch.utils.data.distributed.DistributedSampler(dataset)
至于batch內(nèi)數(shù)據(jù)是否需要做shuffle也需要根據(jù)損失函數(shù)確定(對比學(xué)習(xí)慎用)
DataLoader會自動合并__getitem__ 方法返回的字典內(nèi)每個key內(nèi)每個tensor,在tensor的第0維度新增一個batch大小的維度。如果該方法返回的每條樣本長度不同無法拼接,batchsize>1就會報錯。但是又一些任務(wù)在還沒有確定后續(xù)的批樣本對應(yīng)的任務(wù)時,Dataset可能返回的字典里每個key可能就是長度不同的tensor,甚至是list,這時候需要使用collate_fn參數(shù)告訴DataLoader如何取樣。我們可以定義自己的函數(shù)來準(zhǔn)確地實現(xiàn)想要的功能。
如果__getitem__方法返回的是tuple((list, list)) 可以使用:
def?merge_sample(x):
????return?zip(*x)
train_loader?=?DataLoader(dataset=train_dataset,
??????????????????????????????batch_size=args.train_batch_size,
??????????????????????????????sampler=train_sampler,
??????????????????????????????shuffle=(train_sampler?is?None)
??????????????????????????????collate_fn=merge_sample,
??????????????????????????????num_workers=4)
拼接數(shù)據(jù),后續(xù)再做進(jìn)一步處理。(此時list內(nèi)數(shù)據(jù)還是不等長,無法轉(zhuǎn)為tensor)
如果__getitem_方法返回的是Dict[str,tensor],自定義的collate_fn方法內(nèi)需要實現(xiàn):List[Dict[str,tensor(xx)]]->Dict[str,tensor(bs,xx)]的操作,pad_sequence過程也可以在自定義方法內(nèi)實現(xiàn)。(總之collate_fn中不但可以處理不等長數(shù)據(jù),還可以對一個batch的數(shù)據(jù)做精修。當(dāng)然也可以在DataLoader之后再做修改batch內(nèi)的數(shù)據(jù)。)
值得注意的是在cpu環(huán)境下,如果要自定義collate_fn,num_workers必須設(shè)置為0,不然就會有問題..
通過以下方式可以檢查一下輸入后續(xù)模型的數(shù)據(jù)是否已經(jīng)是想要的格式
for?step,?batch_data?in?enumerate(train_loader):
????if?step?1:
????????print(batch_data)
????else:
????????break
之后數(shù)據(jù)將數(shù)據(jù)放入gpu device, 一個batch的數(shù)據(jù)進(jìn)入device端后就與內(nèi)存上的數(shù)據(jù)不再互相干擾。之后數(shù)據(jù)就可以喂給模型了:
for?key?in?batch_data.keys():
????batch_data[key]?=?batch_data[key].to(device)
loss?=?model(**batch_data)交流群
歡迎加入公眾號讀者群一起和同行交流,目前有美顏、三維視覺、計算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群
個人微信(如果沒有備注不拉群!) 請注明:地區(qū)+學(xué)校/企業(yè)+研究方向+昵稱
下載1:何愷明頂會分享
在「AI算法與圖像處理」公眾號后臺回復(fù):何愷明,即可下載。總共有6份PDF,涉及 ResNet、Mask RCNN等經(jīng)典工作的總結(jié)分析
下載2:終身受益的編程指南:Google編程風(fēng)格指南
在「AI算法與圖像處理」公眾號后臺回復(fù):c++,即可下載。歷經(jīng)十年考驗,最權(quán)威的編程規(guī)范!
下載3 CVPR2021 在「AI算法與圖像處理」公眾號后臺回復(fù):CVPR,即可下載1467篇CVPR?2020論文 和 CVPR 2021 最新論文

