Best Practice in PyTorch: 如何控制dataloader的隨機(jī)shuffle

極市導(dǎo)讀
?在使用PyTorch進(jìn)行訓(xùn)練或者測試的過程中,一般來說dataloader在每個epoch返回的樣本順序是不一樣的,但在某些特殊情況中,我們可能希望dataloader按照固定的順序進(jìn)行多個epoch。本文作者給出了一個簡單方便的實(shí)現(xiàn)思路,附詳解代碼。?>>加入極市CV技術(shù)交流群,走在計(jì)算機(jī)視覺的最前沿
問題背景:
在使用PyTorch進(jìn)行訓(xùn)練或者測試的過程中,一般來說dataloader在每個epoch返回的樣本順序是不一樣的,但在某些特殊情況中,我們可能希望dataloader按照固定的順序進(jìn)行多個epoch, 或者說,在一個epoch中按照固定的順序進(jìn)行多次的樣本循環(huán)iteration。
現(xiàn)有Sampler:
默認(rèn)的 RandomSampler 在生成iteration的時候會重新做一次random shuffle,所以無法直接實(shí)現(xiàn)這個需求。
????def?__iter__(self)?->?Iterator[int]:
????????n?=?len(self.data_source)
????????if?self.generator?is?None:
????????????seed?=?int(torch.empty((),?dtype=torch.int64).random_().item())
????????????generator?=?torch.Generator()
????????????generator.manual_seed(seed)
????????else:
????????????generator?=?self.generator
????????if?self.replacement:
????????????for?_?in?range(self.num_samples?//?32):
????????????????yield?from?torch.randint(high=n,?size=(32,),?dtype=torch.int64,?generator=generator).tolist()
????????????yield?from?torch.randint(high=n,?size=(self.num_samples?%?32,),?dtype=torch.int64,?generator=generator).tolist()
????????else:
????????????for?_?in?range(self.num_samples?//?n):
????????????????yield?from?torch.randperm(n,?generator=generator).tolist()
????????????yield?from?torch.randperm(n,?generator=generator).tolist()[:self.num_samples?%?n]
上面的代碼是RandomSampler中最重要的__iter__函數(shù),我們可以看到每次調(diào)用這個函數(shù)或者新的iter時會得到一個新的隨機(jī)順序的iteration。
再看看另一個常用的sampler,也就是 SequentialSampler。我們在test的時候經(jīng)常會設(shè)置shuffle=false,這時候就相當(dāng)于使用了SequentialSampler:
class?SequentialSampler(Sampler[int]):
????r"""Samples?elements?sequentially,?always?in?the?same?order.
????Args:
????????data_source?(Dataset):?dataset?to?sample?from
????"""
????data_source:?Sized
????def?__init__(self,?data_source:?Sized)?->?None:
????????self.data_source?=?data_source
????def?__iter__(self)?->?Iterator[int]:
????????return?iter(range(len(self.data_source)))
????def?__len__(self)?->?int:
????????return?len(self.data_source)
在代碼中可以看到,這個sampler就是簡單地創(chuàng)造并返回一個range序列,無法對其進(jìn)行shuffle操作。
解決方案:
結(jié)合上面兩個現(xiàn)有的sampler,我們可以簡單地自定義一個新的sampler來實(shí)現(xiàn)我們的需求。也就是說,我們希望能夠手動控制何時進(jìn)行shuffle操作,在沒有shuffle時我們希望sampler按照前面的順序返回iteration。
下面是我的實(shí)現(xiàn):
class?MySequentialSampler(SequentialSampler):
????def?__init__(self,?data_source,?num_data=None):
????????self.data_source?=?data_source
????????self.my_list?=?list(range(len(self.data_source)))
????????random.shuffle(self.my_list)
????????if?num_data?is?None:
????????????self.num_data?=?len(self.my_list)
????????else:
????????????self.num_data?=?num_data
????????????self.my_list?=?self.my_list[:num_data]
????def?__iter__(self):
????????return?iter(self.my_list)
????def?__len__(self):
????????return?self.num_data
????def?shuffle(self):
????????self.my_list?=?list(range(len(self.data_source)))
????????random.shuffle(self.my_list)
????????self.my_list?=?self.my_list[:self.num_data]
這個實(shí)現(xiàn)非常簡單而且使用方便。在默認(rèn)情況下基本等同于SequentialSampler (去掉init函數(shù)中的shuffle即完全一致)。當(dāng)我們需要重新shuffle序列的時候,只需要調(diào)用shuffle函數(shù)即可,比如:dataloader.sampler.shuffle(). 通過這個自定義sampler,我們就可以實(shí)現(xiàn)在指定的時候進(jìn)行shuffle操作,而不是固定在每個iteration結(jié)束時進(jìn)行shuffle。
ps: 理論上也可以直接通過對dataset進(jìn)行shuffle,但這樣操作的缺點(diǎn)是會改變對應(yīng)的index,另外一般我們在train或者test函數(shù)中不會獲取到dataset,而只能從loader進(jìn)行操作(dataloader.dataset一般只能獲取到length)。因此,修改sampler可以說是對原訓(xùn)練方法流程最少的方式。
公眾號后臺回復(fù)“目標(biāo)檢測競賽”獲取競賽經(jīng)驗(yàn)分享~


