點(diǎn)擊上方“機(jī)器學(xué)習(xí)與生成對(duì)抗網(wǎng)絡(luò)”,關(guān)注星標(biāo)
獲取有趣、好玩的前沿干貨!
來源:知乎—Mario 編輯 人工智能前沿講習(xí)
地址:https://zhuanlan.zhihu.com/p/117270644
為啥突然要寫一下pytorch的dataloader呢,首先來說說事情的來龍去脈。起初,我最開始單獨(dú)訓(xùn)練一個(gè)網(wǎng)絡(luò)來完成landmark點(diǎn)回歸任務(wù)和分類任務(wù),訓(xùn)練的數(shù)據(jù)是txt格式,在訓(xùn)練之前對(duì)數(shù)據(jù)進(jìn)行分析,發(fā)現(xiàn)分類任務(wù)中存在嚴(yán)重的數(shù)據(jù)樣本不均衡的問題,那么我事先針對(duì)性的進(jìn)行數(shù)據(jù)采樣均衡操作,重新得到訓(xùn)練和測(cè)試的txt數(shù)據(jù)和標(biāo)簽,保證了整個(gè)訓(xùn)練和測(cè)試數(shù)據(jù)的樣本均衡性。由于我的整個(gè)項(xiàng)目是檢測(cè)+點(diǎn)回歸+分類,起初檢測(cè)和點(diǎn)回歸+分類是分兩步實(shí)現(xiàn)的,檢測(cè)是通過讀取XML格式來進(jìn)行訓(xùn)練,現(xiàn)在要統(tǒng)一整個(gè)項(xiàng)目的訓(xùn)練和測(cè)試過程,要將點(diǎn)回歸+分類的訓(xùn)練測(cè)試過程也按照讀取XML格式來進(jìn)行,那么就遇到一個(gè)問題,如何針對(duì)性的去給樣本偏少的樣本進(jìn)行均衡,由于在dataset類中,返回的圖像和標(biāo)簽都是針對(duì)每個(gè)index返回一個(gè)結(jié)果,在dataset類中進(jìn)行操作似乎不太可行,那么就想到在dataloader中進(jìn)行操作,通過dataloader中的參數(shù)sample來完成針對(duì)性采樣。還有一個(gè)問題是關(guān)于num_workers的設(shè)置,因?yàn)槲矣袑?duì)比過,在我的單機(jī)RTX 2080Ti上和八卡服務(wù)器TITAN RTX上(僅使用單卡,其它卡有在跑其它任務(wù)),使用相同的num_workers,在單機(jī)上的訓(xùn)練速度反而更快,于是猜想可能和CPU或者內(nèi)存有關(guān)系,下面會(huì)具體分析。首先來看下下dataloader中的各個(gè)參數(shù)的含義。類的定義為:torch.utils.data.DataLoader ,其中包含的參數(shù)有:torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \ batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \ drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
dataset:定義的dataset類返回的結(jié)果。
batchsize:每個(gè)bacth要加載的樣本數(shù),默認(rèn)為1。shuffle:在每個(gè)epoch中對(duì)整個(gè)數(shù)據(jù)集data進(jìn)行shuffle重排,默認(rèn)為False。sample:定義從數(shù)據(jù)集中加載數(shù)據(jù)所采用的策略,如果指定的話,shuffle必須為False;batch_sample類似,表示一次返回一個(gè)batch的index。num_workers:表示開啟多少個(gè)線程數(shù)去加載你的數(shù)據(jù),默認(rèn)為0,代表只使用主進(jìn)程。collate_fn:表示合并樣本列表以形成小批量的Tensor對(duì)象。pin_memory:表示要將load進(jìn)來的數(shù)據(jù)是否要拷貝到pin_memory區(qū)中,其表示生成的Tensor數(shù)據(jù)是屬于內(nèi)存中的鎖頁(yè)內(nèi)存區(qū),這樣將Tensor數(shù)據(jù)轉(zhuǎn)義到GPU中速度就會(huì)快一些,默認(rèn)為False。drop_last:當(dāng)你的整個(gè)數(shù)據(jù)長(zhǎng)度不能夠整除你的batchsize,選擇是否要丟棄最后一個(gè)不完整的batch,默認(rèn)為False。注:這里簡(jiǎn)單科普下pin_memory,通常情況下,數(shù)據(jù)在內(nèi)存中要么以鎖頁(yè)的方式存在,要么保存在虛擬內(nèi)存(磁盤)中,設(shè)置為True后,數(shù)據(jù)直接保存在鎖頁(yè)內(nèi)存中,后續(xù)直接傳入cuda;否則需要先從虛擬內(nèi)存中傳入鎖頁(yè)內(nèi)存中,再傳入cuda,這樣就比較耗時(shí)了,但是對(duì)于內(nèi)存的大小要求比較高。下面針對(duì)num_workers,sample和collate_fn分別進(jìn)行說明:01
pytorch中dataloader一次性創(chuàng)建num_workers個(gè)子線程,然后用batch_sampler將指定batch分配給指定worker,worker將它負(fù)責(zé)的batch加載進(jìn)RAM,dataloader就可以直接從RAM中找本輪迭代要用的batch。如果num_worker設(shè)置得大,好處是尋batch速度快,因?yàn)橄乱惠喌腷atch很可能在上一輪/上上一輪...迭代時(shí)已經(jīng)加載好了。壞處是內(nèi)存開銷大,也加重了CPU負(fù)擔(dān)(worker加載數(shù)據(jù)到RAM的進(jìn)程是進(jìn)行CPU復(fù)制)。如果num_worker設(shè)為0,意味著每一輪迭代時(shí),dataloader不再有自主加載數(shù)據(jù)到RAM這一步驟,只有當(dāng)你需要的時(shí)候再加載相應(yīng)的batch,當(dāng)然速度就更慢。num_workers的經(jīng)驗(yàn)設(shè)置值是自己電腦/服務(wù)器的CPU核心數(shù),如果CPU很強(qiáng)、RAM也很充足,就可以設(shè)置得更大些,對(duì)于單機(jī)來說,單跑一個(gè)任務(wù)的話,直接設(shè)置為CPU的核心數(shù)最好。02
定義sample:(假設(shè)dataset類返回的是:data, label)from torch.utils.data.sampler import WeightedRandomSampler## 如果label為1,那么對(duì)應(yīng)的該類別被取出來的概率是另外一個(gè)類別的2倍weights = [2 if label == 1 else 1 for data, label in dataset]sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
PyTorch中提供的這個(gè)sampler模塊,用來對(duì)數(shù)據(jù)進(jìn)行采樣。默認(rèn)采用SequentialSampler,它會(huì)按順序一個(gè)一個(gè)進(jìn)行采樣。常用的有隨機(jī)采樣器:RandomSampler,當(dāng)dataloader的shuffle參數(shù)為True時(shí),系統(tǒng)會(huì)自動(dòng)調(diào)用這個(gè)采樣器,實(shí)現(xiàn)打亂數(shù)據(jù)。這里使用另外一個(gè)很有用的采樣方法:WeightedRandomSampler,它會(huì)根據(jù)每個(gè)樣本的權(quán)重選取數(shù)據(jù),在樣本比例不均衡的問題中,可用它來進(jìn)行重采樣。replacement用于指定是否可以重復(fù)選取某一個(gè)樣本,默認(rèn)為True,即允許在一個(gè)epoch中重復(fù)采樣某一個(gè)數(shù)據(jù)。03
def detection_collate(batch): """Custom collate fn for dealing with batches of images that have a different number of associated object annotations (bounding boxes).
Arguments: batch: (tuple) A tuple of tensor images and lists of annotations
Return: A tuple containing: 1) (tensor) batch of images stacked on their 0 dim 2) (list of tensors) annotations for a given image are stacked on 0 dim """ targets = [] imgs = [] for sample in batch: imgs.append(sample[0]) targets.append(torch.FloatTensor(sample[1])) return torch.stack(imgs, 0), targets
使用dataloader時(shí)加入collate_fn參數(shù),即可合并樣本列表以形成小批量的Tensor對(duì)象,如果你的標(biāo)簽不止一個(gè)的話,還可以支持自定義,在上述方法中再額外添加對(duì)應(yīng)的label即可。data_loader = torch.utils.data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers, sampler=sampler, shuffle=False, collate_fn=detection_collate, pin_memory=True, drop_last=True)
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader
猜您喜歡:
超100篇!CVPR 2020最全GAN論文梳理匯總!
拆解組新的GAN:解耦表征MixNMatch
StarGAN第2版:多域多樣性圖像生成
附下載 | 《可解釋的機(jī)器學(xué)習(xí)》中文版
附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實(shí)戰(zhàn)》
附下載 |《計(jì)算機(jī)視覺中的數(shù)學(xué)方法》分享
《基于深度學(xué)習(xí)的表面缺陷檢測(cè)方法綜述》
《零樣本圖像分類綜述: 十年進(jìn)展》
《基于深度神經(jīng)網(wǎng)絡(luò)的少樣本學(xué)習(xí)綜述》