PyTorch 源碼解讀之 torch.utils.data:解析數(shù)據(jù)處理全流程
目錄
0 前言
1 Dataset
1.1 Map-style dataset
1.2 Iterable-style dataset
1.3 其他 dataset
2 Sampler
3 DataLoader
3.1 三者關(guān)系 (Dataset, Sampler, Dataloader)
3.2 批處理
3.2.1 自動(dòng)批處理(默認(rèn))
3.2.2 關(guān)閉自動(dòng)批處理
3.2.3 collate_fn
3.3 多進(jìn)程處理 (multi-process)
4 單進(jìn)程
5 多進(jìn)程
6 鎖頁內(nèi)存 (Memory Pinning)
7 預(yù)取 (prefetch)
8 代碼講解
0 前言
本文涉及的源碼以 PyTorch 1.7 為準(zhǔn)
迭代器
理解 Python 的迭代器是解讀 PyTorch 中 torch.utils.data 模塊的關(guān)鍵。在 Dataset, Sampler 和 DataLoader 這三個(gè)類中都會(huì)用到 python 抽象類的魔法方法,包括 __len__(self) ,__getitem__(self) 和 __iter__(self)
__len__(self): 定義當(dāng)被 len() 函數(shù)調(diào)用時(shí)的行為,一般返回迭代器中元素的個(gè)數(shù)__getitem__(self): 定義獲取容器中指定元素時(shí)的行為,相當(dāng)于self[key],即允許類對象擁有索引操作__iter__(self): 定義當(dāng)?shù)萜髦械脑貢r(shí)的行為
迭代的意思類似于循環(huán),每一次重復(fù)的過程被稱為一次迭代的過程,而每一次迭代得到的結(jié)果會(huì)被用來作為下一次迭代的初始值。提供迭代方法的容器稱為迭代器,通常接觸的迭代器有序列(列表、元組和字符串)還有字典,這些數(shù)據(jù)結(jié)構(gòu)都支持迭代操作。
實(shí)現(xiàn)迭代器的魔法方法有兩個(gè):__iter__(self) 和 __next__(self)
一個(gè)容器如果是迭代器,那就必須實(shí)現(xiàn) __iter__(self) 魔法方法,這個(gè)方法實(shí)際上是返回是一個(gè)迭代器(通常是迭代器本身)。接下來重點(diǎn)要實(shí)現(xiàn)的是 __next__(self) 魔法方法,因?yàn)樗鼪Q定了迭代的規(guī)則。
class Fibs:
def __init__(self, n=20):
self.a = 0
self.b = 1
self.n = n
def __iter__(self):
return self
def __next__(self):
self.a, self.b = self.b, self.a + self.b
if self.a > self.n:
raise StopIteration
return self.a
fibs = Fibs()
for each in fibs:
print(each)
# 輸出
# 1 1 2 3 5 8 13
一般來說,迭代器滿足以下幾種特性:
迭代器是?個(gè)對象 迭代器可以被 next() 函數(shù)調(diào)?,并返回?個(gè)值 迭代器可以被 iter() 函數(shù)調(diào)?,并返回一個(gè)迭代器(可以是自身) 連續(xù)被 next() 調(diào)?時(shí)依次返回?系列的值 如果到了迭代的末尾,則拋出 StopIteration 異常 迭代器也可以沒有末尾,只要被 next() 調(diào)?,就?定會(huì)返回?個(gè)值 Python 中, next() 內(nèi)置函數(shù)調(diào)?的是對象的 next() ?法 Python 中, iter() 內(nèi)置函數(shù)調(diào)?的是對象的 iter() ?法 ?個(gè)實(shí)現(xiàn)了迭代器協(xié)議的的對象可以被 for 語句循環(huán)迭代直到終?
了解了什么是迭代器后,我們就可以開始解讀 torch.utils.data 模塊
對于 torch.utils.data 而言,重點(diǎn)是其 Dataset, Sampler, DataLoader 模塊,輔以 collate, fetch, pin_memory 等組件對特定功能予以支持。
1 Dataset
Dataset 負(fù)責(zé)對 raw data source 封裝,將其封裝成 Python 可識別的數(shù)據(jù)結(jié)構(gòu),其必須提供提取數(shù)據(jù)個(gè)體的接口。
Dataset 共有 Map-style datasets 和 Iterable-style datasets 兩種:
1.1 Map-style datasettorch.utils.data.Dataset
它是一種通過實(shí)現(xiàn)__getitem__()和__len()__來獲取數(shù)據(jù)的 Dataset,它表示從(可能是非整數(shù))索引/關(guān)鍵字到數(shù)據(jù)樣本的映射。訪問時(shí),這樣的數(shù)據(jù)集用 dataset[idx] 訪問 idx 對應(yīng)的數(shù)據(jù)。通常我們使用 Map-style 類型的 dataset 居多,其數(shù)據(jù)接口定義如下:
class Dataset(Generic[T_co]):
# Generic is an Abstract base class for generic types.
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
PyTorch 中所有定義的 Dataset 都是其子類。
對于一般計(jì)算機(jī)視覺任務(wù),我們通常會(huì)在其中進(jìn)行一些 resize, crop, flip 等預(yù)處理的操作 值得一提的是,PyTorch 源碼中并沒有提供默認(rèn)的 len() 方法實(shí)現(xiàn),原因是 return NotImplemented 或者 raise NotImplementedError() 之類的默認(rèn)實(shí)現(xiàn)都會(huì)存在各自的問題,這點(diǎn)在其源碼中也有注釋加以體現(xiàn)。
1.2 Iterable-style datasettorch.utils.data.IterableDataset
它是一種實(shí)現(xiàn) iter() 來獲取數(shù)據(jù)的 Dataset,這種類型的數(shù)據(jù)集特別適用于以下情況:隨機(jī)讀取代價(jià)很大甚至不大可能,且 batch size 取決于獲取的數(shù)據(jù)。其接口定義如下:
class IterableDataset(Dataset[T_co]):
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
def __add__(self, other: Dataset[T_co]):
return ChainDataset([self, other])
特別地,當(dāng) DataLoader 的 num_workers > 0 時(shí), 每個(gè) worker 都將具有數(shù)據(jù)對象的不同樣本。因此需要獨(dú)立地對每個(gè)副本進(jìn)行配置,以防止每個(gè) worker 產(chǎn)生的數(shù)據(jù)不重復(fù)。同時(shí),數(shù)據(jù)加載順序完全由用戶定義的可迭代樣式控制。這允許更容易地實(shí)現(xiàn)塊讀取和動(dòng)態(tài)批次大小(例如,通過每次產(chǎn)生一個(gè)批次的樣本)
1.3 其他 Dataset
除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基礎(chǔ)上提供了其他類型的 Dataset 子類
torch.utils.data.ConcatDataset: 用于連接多個(gè) ConcatDataset 數(shù)據(jù)集
torch.utils.data.ChainDataset : 用于連接多個(gè) IterableDataset 數(shù)據(jù)集,在 IterableDataset 的 add() 方法中被調(diào)用
torch.utils.data.Subset: 用于獲取指定一個(gè)索引序列對應(yīng)的子數(shù)據(jù)集
class Subset(Dataset[T_co]):
dataset: Dataset[T_co]
indices: Sequence[int]
def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
self.dataset = dataset
self.indices = indices
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
def __len__(self):
return len(self.indices)
torch.utils.data.TensorDataset: 用于獲取封裝成 tensor 的數(shù)據(jù)集,每一個(gè)樣本都通過索引張量來獲得。class TensorDataset(Dataset):
def __init__(self, *tensor):
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
def __getitem__(self, index):
return tuple(tensor[index] for tensor in tensors
def __len__(self):
return self.tensors[0].size(0)
2 Sampler
torch.utils.data.Sampler 負(fù)責(zé)提供一種遍歷數(shù)據(jù)集所有元素索引的方式。可支持用戶自定義,也可以用 PyTorch 提供的,基類接口定義如下:
lass Sampler(Generic[T_co]):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
特別地,__len__() 方法不是必要的,但是當(dāng) DataLoader 需要計(jì)算 len() 的時(shí)候必須定義,這點(diǎn)在其源碼中也有注釋加以體現(xiàn)。
同樣,PyTorch 也在此基礎(chǔ)上提供了其他類型的 Sampler 子類
torch.utils.data.SequentialSampler : 順序采樣樣本,始終按照同一個(gè)順序 torch.utils.data.RandomSampler: 可指定有無放回地,進(jìn)行隨機(jī)采樣樣本元素 torch.utils.data.SubsetRandomSampler: 無放回地按照給定的索引列表采樣樣本元素 torch.utils.data.WeightedRandomSampler: 按照給定的概率來采樣樣本。樣本元素來自 [0,…,len(weights)-1] , 給定概率(權(quán)重) torch.utils.data.BatchSampler: 在一個(gè)batch中封裝一個(gè)其他的采樣器, 返回一個(gè) batch 大小的 index 索引 torch.utils.data.DistributedSample: 將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。與 torch.nn.parallel.DistributedDataParallel 結(jié)合使用。在這種情況下,每個(gè)進(jìn)程都可以將 DistributedSampler 實(shí)例作為 DataLoader 采樣器傳遞
3 DataLoader
torch.utils.data.DataLoader 是 PyTorch 數(shù)據(jù)加載的核心,負(fù)責(zé)加載數(shù)據(jù),同時(shí)支持 Map-style 和 Iterable-style Dataset,支持單進(jìn)程/多進(jìn)程,還可以設(shè)置 loading order, batch size, pin memory 等加載參數(shù)。其接口定義如下:
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
對于每個(gè)參數(shù)的含義,以下給出一個(gè)表格進(jìn)行對應(yīng)介紹:
| attribute | meaning | default value | type |
|---|---|---|---|
| dataset | 加載數(shù)據(jù)的數(shù)據(jù)集 | Dataset | |
| batch_size | 每個(gè) batch 加載多少個(gè)樣本 | 1 | int |
| shuffle | 設(shè)置為 True 時(shí),調(diào)用 RandomSampler 進(jìn)行隨機(jī)索引 | False | bool |
| sampler | 定義從數(shù)據(jù)集中提取樣本的策略 如果指定了, shuffle 參數(shù)必須為 False,(否則會(huì)和 RandomSampler 互斥) | None | Sampler, Iterable |
| batch_sampler | 和 sampler 類似,但是一般傳入 BatchSampler,每次返回一個(gè) batch 大小的索引 其和 batch_size, shuffle 等參數(shù)是互斥的 | None | Sampler, Iterable |
| num_workers | 要用于數(shù)據(jù)加載的子進(jìn)程數(shù),0 表示將在主進(jìn)程中加載數(shù)據(jù) | 0 | int |
| collate_fn | 在將 Map-style datase t 取出的數(shù)據(jù)整合成 batch 時(shí)使用,合并樣本列表以形成一個(gè) batch | None | callable |
| pin_memory | 如果為 True,則 DataLoader 在將張量返回之前將其復(fù)制到 CUDA 固定的內(nèi)存中 | False | bool |
| drop_last | 設(shè)置為 True 刪除最后一個(gè)不完整的批次,如果該數(shù)據(jù)集大小不能被該批次大小整除。如果 False 并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小 | False | bool |
| timeout | 如果為正,則為從 worker 收集 batch 的超時(shí)值,應(yīng)始終為非負(fù)數(shù) 超過這個(gè)時(shí)間還沒讀取到數(shù)據(jù)的話就會(huì)報(bào)錯(cuò) | 0 | numeric |
| worker_init_fn | 如果不為 None,它將會(huì)被每個(gè) worker 子進(jìn)程調(diào)用, 以 worker id ([0, num_workers - 1] 內(nèi)的整形) 為輸入 | None | callable |
| prefetch_factor | 每個(gè) worker 提前加載 的 sample 數(shù)量 | 2 | int |
| persistent_workers | 如果為 True,dataloader 將不會(huì)終止 worker 進(jìn)程,直到 dataset 迭代完成 | False | bool |
從參數(shù)定義中,我們可以看到 DataLoader 主要支持以下幾個(gè)功能
支持加載 map-style 和 iterable-style 的 dataset,主要涉及到的參數(shù)是 dataset 自定義數(shù)據(jù)加載順序,主要涉及到的參數(shù)有 shuffle, sampler, batch_sampler, collate_fn 自動(dòng)把數(shù)據(jù)整理成batch序列,主要涉及到的參數(shù)有 batch_size, batch_sampler, collate_fn, drop_last 單進(jìn)程和多進(jìn)程的數(shù)據(jù)加載,主要涉及到的參數(shù)有 num_workers, worker_init_fn 自動(dòng)進(jìn)行鎖頁內(nèi)存讀取 (memory pinning),主要涉及到的參數(shù) pin_memory 支持?jǐn)?shù)據(jù)預(yù)加載,主要涉及的參數(shù) prefetch_factor
3.1 三者關(guān)系 (Dataset, Sampler, Dataloader)
通過以上介紹的三者工作內(nèi)容不難推出其內(nèi)在關(guān)系:
設(shè)置 Dataset,將數(shù)據(jù) data source 包裝成 Dataset 類,暴露提取接口。
設(shè)置 Sampler,決定采樣方式。我們是能從 Dataset 中提取元素了,還是需要設(shè)置 Sampler 告訴程序提取 Dataset 的策略。
將設(shè)置好的 Dataset 和 Sampler 傳入 DataLoader,同時(shí)可以設(shè)置 shuffle, batch_size 等參數(shù)。使用 DataLoader 對象可以方便快捷地在數(shù)據(jù)集上遍歷。
總結(jié)來說,即 Dataloader 負(fù)責(zé)總的調(diào)度,命令 Sampler 定義遍歷索引的方式,然后用索引去 Dataset 中提取元素。于是就實(shí)現(xiàn)了對給定數(shù)據(jù)集的遍歷。
3.2 批處理
3.2.1 自動(dòng)批處理(默認(rèn))
DataLoader 支持通過參數(shù)batch_size, drop_last, batch_sampler,自動(dòng)地把取出的數(shù)據(jù)整理 (collate) 成批次樣本 (batch)
batch_size 和 drop_last 參數(shù)用于指定 DataLoader 如何獲取 dataset 的 key。特別地,對于 map-style 類型的 dataset,用戶可以選擇指定 batch_sample參數(shù),一次就生成一個(gè) keys list
在使用 sampler 產(chǎn)生的 indices 獲取采樣到的數(shù)據(jù)時(shí),DataLoader 使用 collate_fn 參數(shù)將樣本列表整理成 batch。抽象這個(gè)過程,其表示方式大致如下
# For Map-style
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
3.2.2 關(guān)閉自動(dòng)批處理
當(dāng)用戶想用 dataset 代碼手動(dòng)處理 batch,或僅加載單個(gè) sample data 時(shí),可將 batch_size 和 batch_sampler 設(shè)為 None, 將關(guān)閉自動(dòng)批處理。此時(shí),由 Dataset 產(chǎn)生的 sample 將會(huì)直接被 collate_fn 處理。抽象這個(gè)過程,其表示方式大致如下:
# For Map-style
for index in sampler:
yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
yield collate_fn(data)
3.2.3 collate_fn當(dāng)關(guān)閉自動(dòng)批處理 (automatic batching) 時(shí),collate_fn 作用于單個(gè)數(shù)據(jù)樣本,只是在 PyTorch 張量中轉(zhuǎn)換 NumPy 數(shù)組。
當(dāng)開啟自動(dòng)批處理 (automatic batching) 時(shí),collate_fn 作用于數(shù)據(jù)樣本列表,將輸入樣本整理為一個(gè) batch,一般做下面 3 件事情
添加新的批次維度(一般是第一維) 它會(huì)自動(dòng)將 NumPy 數(shù)組和 Python 數(shù)值轉(zhuǎn)換為 PyTorch 張量 它保留數(shù)據(jù)結(jié)構(gòu),例如,如果每個(gè)樣本都是 dict,則輸出具有相同鍵集但批處理過的張量作為值的字典(或list,當(dāng)不能轉(zhuǎn)換的時(shí)候)。list, tuples, namedtuples 同樣適用
自定義 collate_fn 可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長度,添加對自定義數(shù)據(jù)類型的支持等。
3.3 多進(jìn)程處理 (multi-process)
為了避免在加載數(shù)據(jù)時(shí)阻塞計(jì)算代碼,PyTorch 提供了一個(gè)簡單的開關(guān),只需將參數(shù)設(shè)置 num_workers 為正整數(shù)即可執(zhí)行多進(jìn)程數(shù)據(jù)加載,設(shè)置為 0 時(shí)執(zhí)行單線程數(shù)據(jù)加載。
4. 單進(jìn)程
在單進(jìn)程模式下,DataLoader 初始化的進(jìn)程和取數(shù)據(jù)的進(jìn)程是一樣的 。因此,數(shù)據(jù)加載可能會(huì)阻止計(jì)算。
但是,當(dāng)用于在進(jìn)程之間共享數(shù)據(jù)的資源(例如共享內(nèi)存,文件描述符)有限時(shí),或者當(dāng)整個(gè)數(shù)據(jù)集很小并且可以完全加載到內(nèi)存中時(shí),此模式可能是首選。
此外,單進(jìn)程加載通常顯示更多可讀的錯(cuò)誤跟蹤,因此對于調(diào)試很有用。
5. 多進(jìn)程
在多進(jìn)程模式下,每次 DataLoader 創(chuàng)建 iterator 時(shí)(例如,當(dāng)調(diào)用時(shí)enumerate(dataloader)),都會(huì)創(chuàng)建 num_workers 工作進(jìn)程。dataset, collate_fn, worker_init_fn 都會(huì)被傳到每個(gè)worker中,每個(gè)worker都用獨(dú)立的進(jìn)程。
對于 map-style 數(shù)據(jù),主線程會(huì)用 Sampler 產(chǎn)生 indice,并將它們送到 worker 里。因此,shuffle是在主線程做的
對于 iterable-style 數(shù)據(jù),因?yàn)槊總€(gè) worker 都有相同的 data 復(fù)制樣本,并在各個(gè)進(jìn)程里進(jìn)行不同的操作,以防止每個(gè)進(jìn)程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來進(jìn)行輔助處理。
這里,torch.utils.data.get_worker_info() 返回worker進(jìn)程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None
注意,通常不建議在多進(jìn)程加載中返回CUDA張量,因?yàn)樵谑褂肅UDA和在多處理中共享CUDA張量時(shí)存在許多微妙之處(文檔中提出:只要接收過程保留張量的副本,就需要發(fā)送過程來保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡而言之,不建議在使用多線程的情況下返回CUDA的tensor。
6 鎖頁內(nèi)存 (Memory Pinning)
這里首先解釋一下鎖頁內(nèi)存的概念。
主機(jī)中的內(nèi)存,有兩種存在方式,一是鎖頁,二是不鎖頁,鎖頁內(nèi)存存放的內(nèi)容在任何情況下都不會(huì)與主機(jī)的虛擬內(nèi)存進(jìn)行交換(注:虛擬內(nèi)存就是硬盤),而不鎖頁內(nèi)存在主機(jī)內(nèi)存不足時(shí),數(shù)據(jù)會(huì)存放在虛擬內(nèi)存中。主機(jī)到GPU副本源自固定(頁面鎖定)內(nèi)存時(shí),速度要快得多。CPU張量和存儲暴露了一種 pin_memory() 方法,該方法返回對象的副本,并將數(shù)據(jù)放在固定的區(qū)域中。
而顯卡中的顯存全部是鎖頁內(nèi)存!當(dāng)計(jì)算機(jī)的內(nèi)存充足的時(shí)候,可以設(shè)置 pin_memory=True。設(shè)置 pin_memory=True,則意味著生成的 Tensor 數(shù)據(jù)最開始是屬于內(nèi)存中的鎖頁內(nèi)存,這樣將內(nèi)存的Tensor轉(zhuǎn)義到GPU的顯存就會(huì)更快一些。同時(shí),由于 pin_memory 的作用是將張量返回之前將其復(fù)制到 CUDA 固定的內(nèi)存中,所以只有在 CUDA 環(huán)境支持下才有用。
PyTorch 原生的 pin_memory 方法如下,其支持大部分 python 數(shù)據(jù)類型的處理:
def pin_memory(data):
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, container_abcs.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, container_abcs.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
默認(rèn)情況下,如果固定邏輯看到一個(gè)屬于自定義類型 (custom type) 的batch(如果有一個(gè) collate_fn 返回自定義批處理類型的批處理,則會(huì)發(fā)生),或者如果該批處理的每個(gè)元素都是 custom type,則固定邏輯將無法識別它們,它將返回該批處理(或那些元素)而無需固定內(nèi)存。要為自定義批處理或數(shù)據(jù)類型啟用內(nèi)存固定,需 pin_memory() 在自定義類型上定義一個(gè)方法。如下
class SimpleCustomBatch:
# 自定義一個(gè)類,該類不能被PyTorch原生的pin_memory方法所支持
def __init__(self, data):
transposed_data = list(zip(*data))
self.inp = torch.stack(transposed_data[0], 0)
self.tgt = torch.stack(transposed_data[1], 0)
# custom memory pinning method on custom type
def pin_memory(self):
self.inp = self.inp.pin_memory()
self.tgt = self.tgt.pin_memory()
return self
def collate_wrapper(batch):
return SimpleCustomBatch(batch)
inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
pin_memory=True)
for batch_ndx, sample in enumerate(loader):
print(sample.inp.is_pinned()) # True
print(sample.tgt.is_pinned()) # True
7 預(yù)取 (prefetch)
DataLoader 通過指定 prefetch_factor (默認(rèn)為 2)來進(jìn)行數(shù)據(jù)的預(yù)取。
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
...
self._reset(loader, first_iter=True)
def _reset(self, loader, first_iter=False):
...
# prime the prefetch loop
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
通過源碼可以看到,prefetch 功能僅適用于 多進(jìn)程 加載中(下面會(huì)由多進(jìn)程 dataloader 的代碼分析)
8 代碼詳解
讓我們來看看具體的代碼調(diào)用流程:
for data, label in train_loader:
......
for 循環(huán)會(huì)調(diào)用 dataloader 的 iter(self) 方法,以此獲得迭代器來遍歷 dataset
class DataLoader(Generic[T_co]):
...
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0:
if self._iterator is None:
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else:
return self._get_iterator()
在 iter(self) 方法中,dataloader 調(diào)用了 self._get_iterator() 方法,根據(jù) num_worker 獲得迭代器,并指示進(jìn)行單進(jìn)程還是多進(jìn)程
class DataLoader(Generic[T_co]):
...
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
為了描述清晰,我們只考慮單進(jìn)程的代碼。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter) ,以及其父類 class _BaseDataLoaderIter(object): 的重點(diǎn)代碼片段:
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
# 初始化賦值一些 DataLoader 參數(shù),
# 以及用戶輸入合法性進(jìn)行校驗(yàn)
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._index_sampler = loader._index_sampler
...
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 重點(diǎn)代碼行,通過此獲取數(shù)據(jù)
self._num_yielded += 1
...
return data
next = __next__ # Python 2 compatibility
def __len__(self) -> int:
return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)
def __getstate__(self):
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
_BaseDataLoaderIter 是所有 DataLoaderIter 的父類。dataloader獲得了迭代器之后,for 循環(huán)需要調(diào)用 next() 來獲得下一個(gè)對象,從而實(shí)現(xiàn)遍歷。通過 next 方法調(diào)用 _next_data() 獲取數(shù)據(jù)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
從 _SingleProcessDataLoaderIter 的初始化參數(shù)可以看到,其在父類 _BaseDataLoaderIter 的基礎(chǔ)上定義了 _dataset_fetcher, 并傳入 _dataset, _auto_collation, _collate_fn 等參數(shù),用于定義獲取數(shù)據(jù)的方式。其具體實(shí)現(xiàn)會(huì)在稍后解釋。
在 _next_data() 被調(diào)用后,其需要 next_index() 獲取 index,并通過獲得的 index 傳入 _dataset_fetcher 中獲取對應(yīng)樣本
class DataLoader(Generic[T_co]):
...
@property
def _auto_collation(self):
return self.batch_sampler is not None
@property
def _index_sampler(self):
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
class _BaseDataLoaderIter(object):
...
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
...
def _next_index(self):
# sampler_iter 來自于 index_sampler
return next(self._sampler_iter) # may raise StopIteration
從這里看出,dataloader 提供了 sampler (可以是batch_sampler 或者是其他 sampler 子類),然后 _SingleProcessDataLoaderIter 迭代sampler獲得索引
下面我們來看看 fetcher,fetcher 需要 index 來獲取元素,并同時(shí)支持 Map-style dataset(對應(yīng) _MapDatasetFetcher)和 Iterable-style dataset(對應(yīng) _IterableDatasetFetcher),使其在Dataloader內(nèi)能使用相同的接口 fetch,代碼更加簡潔。
對于 Map-style:直接輸入索引 index,作為 map 的 key,獲得對應(yīng)的樣本(即 value)
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 有batch_sampler,_auto_collation就為True,
# 就優(yōu)先使用batch_sampler,對應(yīng)在fetcher中傳入的就是一個(gè)batch的索引
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
對于 Iterable-style: init 方法內(nèi)設(shè)置了 dataset 初始的迭代器,fetch 方法內(nèi)獲取元素,index 其實(shí)已經(jīng)沒有多大作用了
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 對于batch_sampler(即auto_collation==True)
# 直接使用往后遍歷并提取len(possibly_batched_index)個(gè)樣本(即1個(gè)batch的樣本)
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
# 對于sampler,直接往后遍歷并提取1個(gè)樣本
data = next(self.dataset_iter)
return self.collate_fn(data)
最后,我們通過索引傳入 fetcher,fetch 得到想要的樣本 因此,整個(gè)過程調(diào)用關(guān)系總結(jié) 如下:
loader.__iter__ --> self._get_iterator()--> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter--> __next__() --> self._next_data()--> self._next_index()-->next(self._sampler_iter)即 next(iter(self._index_sampler)) --> 獲得 index --> self._dataset_fetcher.fetch(index)--> 獲得 data
對于多進(jìn)程而言,借用 PyTorch 內(nèi)源碼的注釋,其運(yùn)行流程解釋如下
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
首先 dataloader 基于 multiprocessing 產(chǎn)生多進(jìn)程,每個(gè)子進(jìn)程的輸入輸出通過兩個(gè)主要的隊(duì)列 (multiprocessing.Queue() 類) 產(chǎn)生,分別為:
index_queue: 每個(gè)子進(jìn)程的隊(duì)列中需要處理的任務(wù)的下標(biāo) _worker_result_queue: 返回時(shí)處理完任務(wù)的下標(biāo) data_queue: 表明經(jīng)過 pin_memory 處理后的數(shù)據(jù)隊(duì)列
并且有以下這些比較重要的 flag 參數(shù)來協(xié)調(diào)各個(gè) worker 之間的工作:
_send_idx: 發(fā)送索引,用來記錄這次要放 index_queue 中 batch 的 idx _rcvd_idx: 接受索引,記錄要從 data_queue 中取出的 batch 的 idx _task_info: 存儲將要產(chǎn)生的 data 信息的 dict,key為 task idx(由 0 開始的整形索引),value 為 (worker_id,) 或 (worker_id, data),分別對應(yīng)數(shù)據(jù) 未取 和 已取 的情況 _tasks_outstanding: 整形,代表已經(jīng)準(zhǔn)備好的 task/batch 的數(shù)量(可能有些正在準(zhǔn)備中)
每個(gè) worker 一次產(chǎn)生一個(gè) batch 的數(shù)據(jù),返回 batch 數(shù)據(jù)前放入下一個(gè)批次要處理的數(shù)據(jù)下標(biāo),對應(yīng)構(gòu)造函數(shù)子進(jìn)程初始化如下
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
...
self._worker_result_queue = multiprocessing_context.Queue() # 把該worker取出的數(shù)放入該隊(duì)列,用于進(jìn)程間通信
...
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = []
self._workers = []
for i in range(self._num_workers):
index_queue = multiprocessing_context.Queue() # 索引隊(duì)列,每個(gè)子進(jìn)程一個(gè)隊(duì)列放要處理的下標(biāo)
index_queue.cancel_join_thread()
# _worker_loop 的作用是:從index_queue中取索引,然后通過collate_fn處理數(shù)據(jù),
# 然后再將處理好的 batch 數(shù)據(jù)放到 data_queue 中。(發(fā)送到隊(duì)列中的idx是self.send_idx)
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop, # 每個(gè)worker子進(jìn)程循環(huán)執(zhí)行的函數(shù),主要將數(shù)據(jù)以(idx, data)的方式傳入_worker_result_queue中
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed + i, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
w.daemon = True
w.start()
self._index_queues.append(index_queue)
self._workers.append(w)
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
self._data_queue = queue.Queue() # 用于存取出的數(shù)據(jù)進(jìn)行 pin_memory 操作后的結(jié)果
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue
...
self._reset(loader, first_iter=True)
def _reset(self, loader, first_iter=False):
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers,發(fā)送索引,用來記錄這次要放 index_queue 中 batch 的 idx
self._rcvd_idx = 0 # idx of the next task to be returned in __next__,接受索引,記錄要從 data_queue 中取出的 batch 的 idx
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
# _tasks_outstanding 指示當(dāng)前已經(jīng)準(zhǔn)備好的 task/batch 的數(shù)量(可能有些正在準(zhǔn)備中)
# 初始值為 0, 在 self._try_put_index() 中 +1,在 self._next_data 中-1
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
# this indicates status that a worker still has work to do *for this epoch*.
self._workers_status = [True for i in range(self._num_workers)]
# We resume the prefetching in case it was enabled
if not first_iter:
for idx in range(self._num_workers):
self._index_queues[idx].put(_utils.worker._ResumeIteration())
resume_iteration_cnt = self._num_workers
while resume_iteration_cnt > 0:
data = self._get_data()
if isinstance(data, _utils.worker._ResumeIteration):
resume_iteration_cnt -= 1
...
# 初始化的時(shí)候,就將 2*num_workers 個(gè) (batch_idx, sampler_indices) 放到 index_queue 中
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index() # 進(jìn)行預(yù)取
dataloader 初始化的時(shí)候,每個(gè) worker 的 index_queue 默認(rèn)會(huì)放入兩個(gè) batch 的 index,從 index_queue 中取出要處理的下標(biāo)
def _try_put_index(self):
# self._prefetch_factor 默認(rèn)為 2
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
try:
index = self._next_index()
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]:
break
else:
# not found (i.e., didn't break)
return
self._index_queues[worker_queue_idx].put((self._send_idx, index)) # 放入 任務(wù)下標(biāo) 和 數(shù)據(jù)下標(biāo)
self._task_info[self._send_idx] = (worker_queue_idx,)
# _tasks_outstanding + 1,表明預(yù)備好的batch個(gè)數(shù)+1
self._tasks_outstanding += 1
# send_idx 發(fā)送索引, 記錄從sample_iter中發(fā)送索引到index_queue的次數(shù)
self._send_idx += 1
調(diào)用 _next_data 方法進(jìn)行數(shù)據(jù)讀取,其中 _process_data 用于返回?cái)?shù)據(jù)
def _next_data(self):
while True:
while self._rcvd_idx < self._send_idx: # 確保待處理的任務(wù)(待取的batch)下標(biāo) > 處理完畢要返回的任務(wù)(已經(jīng)取完的batch)下標(biāo)
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
if not self._persistent_workers:
self._shutdown_workers()
raise StopIteration
# Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data)
assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data() # 調(diào)用 self._try_get_data() 從 self._data_queue 中取數(shù)
self._tasks_outstanding -= 1 # 表明預(yù)備好的batch個(gè)數(shù)需要減1
if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
if self._persistent_workers:
self._workers_status[data.worker_id] = False
else:
self._mark_worker_as_unavailable(data.worker_id)
self._try_put_index()
continue
if idx != self._rcvd_idx:
# store out-of-order samples
self._task_info[idx] += (data,)
else:
del self._task_info[idx]
return self._process_data(data) # 返回?cái)?shù)據(jù)
def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index() # 同上,主要放入隊(duì)列索引 以及 更新flag
if isinstance(data, ExceptionWrapper):
data.reraise()
return data
這樣,多線程的 dataloader 就能通過多個(gè) worker 的協(xié)作來共同完成數(shù)據(jù)的加載。
參考
https://pytorch.org/docs/stable/data.html https://www.zhihu.com/search?type=content&q=dataloader https://www.dazhuanlan.com/2019/12/05/5de8104ce9491/ https://blog.csdn.net/g11d111/article/details/81504637
- The End -
長按二維碼關(guān)注我們
本公眾號專注:
1. 技術(shù)分享;
2. 學(xué)術(shù)交流;
3. 資料共享。
歡迎關(guān)注我們,一起成長!
