<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch 源碼解讀之 torch.utils.data:解析數(shù)據(jù)處理全流程

          共 52227字,需瀏覽 105分鐘

           ·

          2021-06-09 12:30

          作者丨OpenMMLab
          來源丨h(huán)ttps://zhuanlan.zhihu.com/p/337850513
          編輯丨GiantPandaCV


          目錄

          0 前言

          1 Dataset

          1.1 Map-style dataset

          1.2 Iterable-style dataset

          1.3 其他 dataset

          2 Sampler

          3 DataLoader

          3.1 三者關(guān)系 (Dataset, Sampler, Dataloader)

          3.2 批處理

            3.2.1 自動(dòng)批處理(默認(rèn))

          3.2.2 關(guān)閉自動(dòng)批處理

          3.2.3 collate_fn

          3.3 多進(jìn)程處理 (multi-process)

          4 單進(jìn)程

          5 多進(jìn)程

          6 鎖頁內(nèi)存 (Memory Pinning)

          7 預(yù)取 (prefetch)

          8 代碼講解

          0 前言

          本文涉及的源碼以 PyTorch 1.7 為準(zhǔn)

          迭代器

          理解 Python 的迭代器是解讀 PyTorch 中 torch.utils.data 模塊的關(guān)鍵。在 Dataset, SamplerDataLoader 這三個(gè)類中都會(huì)用到 python 抽象類的魔法方法,包括 __len__(self)__getitem__(self)__iter__(self)

          • __len__(self) :  定義當(dāng)被 len() 函數(shù)調(diào)用時(shí)的行為,一般返回迭代器中元素的個(gè)數(shù)
          • __getitem__(self): 定義獲取容器中指定元素時(shí)的行為,相當(dāng)于self[key] ,即允許類對象擁有索引操作
          • __iter__(self) : 定義當(dāng)?shù)萜髦械脑貢r(shí)的行為

          迭代的意思類似于循環(huán),每一次重復(fù)的過程被稱為一次迭代的過程,而每一次迭代得到的結(jié)果會(huì)被用來作為下一次迭代的初始值。提供迭代方法的容器稱為迭代器,通常接觸的迭代器有序列(列表、元組和字符串)還有字典,這些數(shù)據(jù)結(jié)構(gòu)都支持迭代操作。

          實(shí)現(xiàn)迭代器的魔法方法有兩個(gè):__iter__(self)  和 __next__(self)

          一個(gè)容器如果是迭代器,那就必須實(shí)現(xiàn) __iter__(self)  魔法方法,這個(gè)方法實(shí)際上是返回是一個(gè)迭代器(通常是迭代器本身)。接下來重點(diǎn)要實(shí)現(xiàn)的是  __next__(self) 魔法方法,因?yàn)樗鼪Q定了迭代的規(guī)則。

          class Fibs:
              def __init__(self, n=20):
                  self.a = 0
                  self.b = 1
                  self.n = n
              def __iter__(self):
                  return self
              def __next__(self):
                  self.a, self.b = self.b, self.a + self.b
                  if self.a &gt; self.n:
                      raise StopIteration
                  return self.a

          fibs = Fibs()
          for each in fibs:
              print(each)

          # 輸出 
          # 1 1 2 3 5 8 13

          一般來說,迭代器滿足以下幾種特性:

          • 迭代器是?個(gè)對象
          • 迭代器可以被 next() 函數(shù)調(diào)?,并返回?個(gè)值
          • 迭代器可以被 iter() 函數(shù)調(diào)?,并返回一個(gè)迭代器(可以是自身)
          • 連續(xù)被 next() 調(diào)?時(shí)依次返回?系列的值
          • 如果到了迭代的末尾,則拋出 StopIteration 異常
          • 迭代器也可以沒有末尾,只要被 next() 調(diào)?,就?定會(huì)返回?個(gè)值
          • Python 中, next() 內(nèi)置函數(shù)調(diào)?的是對象的 next() ?法
          • Python 中, iter() 內(nèi)置函數(shù)調(diào)?的是對象的 iter() ?法
          • ?個(gè)實(shí)現(xiàn)了迭代器協(xié)議的的對象可以被 for 語句循環(huán)迭代直到終?

          了解了什么是迭代器后,我們就可以開始解讀 torch.utils.data 模塊

          對于 torch.utils.data 而言,重點(diǎn)是其 Dataset, Sampler, DataLoader 模塊,輔以 collate, fetch, pin_memory 等組件對特定功能予以支持。

          1 Dataset

          Dataset 負(fù)責(zé)對 raw data source 封裝,將其封裝成 Python 可識別的數(shù)據(jù)結(jié)構(gòu),其必須提供提取數(shù)據(jù)個(gè)體的接口。

          Dataset 共有 Map-style datasets 和 Iterable-style datasets 兩種:

          1.1 Map-style datasettorch.utils.data.Dataset

          它是一種通過實(shí)現(xiàn)__getitem__()__len()__來獲取數(shù)據(jù)的 Dataset,它表示從(可能是非整數(shù))索引/關(guān)鍵字到數(shù)據(jù)樣本的映射。訪問時(shí),這樣的數(shù)據(jù)集用 dataset[idx] 訪問 idx 對應(yīng)的數(shù)據(jù)。通常我們使用 Map-style 類型的 dataset 居多,其數(shù)據(jù)接口定義如下:

          class Dataset(Generic[T_co]):
              # Generic is an Abstract base class for generic types.

              def __getitem__(self, index) -&gtT_co:
                  raise NotImplementedError
              
              def __add__(self, other: 'Dataset[T_co]') -&gt; 'ConcatDataset[T_co]':
                  return ConcatDataset([self, other])

          PyTorch 中所有定義的 Dataset 都是其子類。

          對于一般計(jì)算機(jī)視覺任務(wù),我們通常會(huì)在其中進(jìn)行一些 resize, crop, flip 等預(yù)處理的操作 值得一提的是,PyTorch 源碼中并沒有提供默認(rèn)的 len() 方法實(shí)現(xiàn),原因是 return NotImplemented 或者 raise NotImplementedError() 之類的默認(rèn)實(shí)現(xiàn)都會(huì)存在各自的問題,這點(diǎn)在其源碼中也有注釋加以體現(xiàn)。

          1.2 Iterable-style datasettorch.utils.data.IterableDataset

          它是一種實(shí)現(xiàn) iter() 來獲取數(shù)據(jù)的 Dataset,這種類型的數(shù)據(jù)集特別適用于以下情況:隨機(jī)讀取代價(jià)很大甚至不大可能,且 batch size 取決于獲取的數(shù)據(jù)。其接口定義如下:

          class IterableDataset(Dataset[T_co]):

              def __iter__(self) -&gt; Iterator[T_co]:
                  raise NotImplementedError
              
              def __add__(self, other: Dataset[T_co]):
                  return ChainDataset([self, other])

          特別地,當(dāng) DataLoader 的 num_workers > 0 時(shí), 每個(gè) worker 都將具有數(shù)據(jù)對象的不同樣本。因此需要獨(dú)立地對每個(gè)副本進(jìn)行配置,以防止每個(gè) worker 產(chǎn)生的數(shù)據(jù)不重復(fù)。同時(shí),數(shù)據(jù)加載順序完全由用戶定義的可迭代樣式控制。這允許更容易地實(shí)現(xiàn)塊讀取和動(dòng)態(tài)批次大小(例如,通過每次產(chǎn)生一個(gè)批次的樣本)

          1.3 其他 Dataset

          除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基礎(chǔ)上提供了其他類型的 Dataset 子類

          • torch.utils.data.ConcatDataset: 用于連接多個(gè) ConcatDataset 數(shù)據(jù)集

          • torch.utils.data.ChainDataset : 用于連接多個(gè) IterableDataset 數(shù)據(jù)集,在 IterableDataset 的 add() 方法中被調(diào)用

          • torch.utils.data.Subset: 用于獲取指定一個(gè)索引序列對應(yīng)的子數(shù)據(jù)集

          class Subset(Dataset[T_co]):

              dataset: Dataset[T_co]
              indices: Sequence[int]
              
              def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -&gt; None:
                  self.dataset = dataset
                  self.indices = indices
              
              def __getitem__(self, idx):
                  return self.dataset[self.indices[idx]]
              
              def __len__(self):
                  return len(self.indices)
          • torch.utils.data.TensorDataset: 用于獲取封裝成 tensor 的數(shù)據(jù)集,每一個(gè)樣本都通過索引張量來獲得。class TensorDataset(Dataset):
              def __init__(self, *tensor):
                  assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
                  self.tensors = tensors
              
              def __getitem__(self, index):
                  return tuple(tensor[index] for tensor in tensors
              
              def __len__(self):
                  return self.tensors[0].size(0)

          2 Sampler

          torch.utils.data.Sampler 負(fù)責(zé)提供一種遍歷數(shù)據(jù)集所有元素索引的方式。可支持用戶自定義,也可以用 PyTorch 提供的,基類接口定義如下:

          lass Sampler(Generic[T_co]):
              r"""Base class for all Samplers.

              Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
              way to iterate over indices of dataset elements, and a :meth:`__len__` method
              that returns the length of the returned iterators.
              
              .. note:: The :meth:`__len__` method isn't strictly required by
                        :class:`~torch.utils.data.DataLoader`, but is expected in any
                        calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
              "
          ""
              
              def __init__(self, data_source: Optional[Sized]) -&gt; None:
                  pass
              
              def __iter__(self) -&gt; Iterator[T_co]:
                  raise NotImplementedError

          特別地,__len__() 方法不是必要的,但是當(dāng) DataLoader 需要計(jì)算 len() 的時(shí)候必須定義,這點(diǎn)在其源碼中也有注釋加以體現(xiàn)。

          同樣,PyTorch 也在此基礎(chǔ)上提供了其他類型的 Sampler 子類

          • torch.utils.data.SequentialSampler : 順序采樣樣本,始終按照同一個(gè)順序
          • torch.utils.data.RandomSampler: 可指定有無放回地,進(jìn)行隨機(jī)采樣樣本元素
          • torch.utils.data.SubsetRandomSampler: 無放回地按照給定的索引列表采樣樣本元素
          • torch.utils.data.WeightedRandomSampler: 按照給定的概率來采樣樣本。樣本元素來自 [0,…,len(weights)-1] , 給定概率(權(quán)重)
          • torch.utils.data.BatchSampler: 在一個(gè)batch中封裝一個(gè)其他的采樣器, 返回一個(gè) batch 大小的 index 索引
          • torch.utils.data.DistributedSample: 將數(shù)據(jù)加載限制為數(shù)據(jù)集子集的采樣器。與 torch.nn.parallel.DistributedDataParallel 結(jié)合使用。在這種情況下,每個(gè)進(jìn)程都可以將 DistributedSampler 實(shí)例作為 DataLoader 采樣器傳遞

          3 DataLoader

          torch.utils.data.DataLoader 是 PyTorch 數(shù)據(jù)加載的核心,負(fù)責(zé)加載數(shù)據(jù),同時(shí)支持 Map-style 和 Iterable-style Dataset,支持單進(jìn)程/多進(jìn)程,還可以設(shè)置 loading order, batch size, pin memory 等加載參數(shù)。其接口定義如下:

          DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                     batch_sampler=None, num_workers=0, collate_fn=None,
                     pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, *, prefetch_factor=2,
                     persistent_workers=False)

          對于每個(gè)參數(shù)的含義,以下給出一個(gè)表格進(jìn)行對應(yīng)介紹:

          attributemeaningdefault valuetype
          dataset加載數(shù)據(jù)的數(shù)據(jù)集
          Dataset
          batch_size每個(gè) batch 加載多少個(gè)樣本1int
          shuffle設(shè)置為 True 時(shí),調(diào)用 RandomSampler 進(jìn)行隨機(jī)索引Falsebool
          sampler定義從數(shù)據(jù)集中提取樣本的策略 如果指定了, shuffle 參數(shù)必須為 False,(否則會(huì)和 RandomSampler 互斥)NoneSampler, Iterable
          batch_sampler和 sampler 類似,但是一般傳入 BatchSampler,每次返回一個(gè) batch 大小的索引 其和 batch_size, shuffle 等參數(shù)是互斥的NoneSampler, Iterable
          num_workers要用于數(shù)據(jù)加載的子進(jìn)程數(shù),0 表示將在主進(jìn)程中加載數(shù)據(jù)0int
          collate_fn在將 Map-style datase t 取出的數(shù)據(jù)整合成 batch 時(shí)使用,合并樣本列表以形成一個(gè) batchNonecallable
          pin_memory如果為 True,則 DataLoader 在將張量返回之前將其復(fù)制到 CUDA 固定的內(nèi)存中Falsebool
          drop_last設(shè)置為 True 刪除最后一個(gè)不完整的批次,如果該數(shù)據(jù)集大小不能被該批次大小整除。如果 False 并且數(shù)據(jù)集的大小不能被批次大小整除,那么最后一批將較小Falsebool
          timeout如果為正,則為從 worker 收集 batch 的超時(shí)值,應(yīng)始終為非負(fù)數(shù) 超過這個(gè)時(shí)間還沒讀取到數(shù)據(jù)的話就會(huì)報(bào)錯(cuò)0numeric
          worker_init_fn如果不為 None,它將會(huì)被每個(gè) worker 子進(jìn)程調(diào)用, 以 worker id ([0, num_workers - 1] 內(nèi)的整形) 為輸入Nonecallable
          prefetch_factor每個(gè) worker 提前加載 的 sample 數(shù)量2int
          persistent_workers如果為 True,dataloader 將不會(huì)終止 worker 進(jìn)程,直到 dataset 迭代完成Falsebool

          從參數(shù)定義中,我們可以看到 DataLoader 主要支持以下幾個(gè)功能

          • 支持加載 map-style 和 iterable-style 的 dataset,主要涉及到的參數(shù)是 dataset
          • 自定義數(shù)據(jù)加載順序,主要涉及到的參數(shù)有 shuffle, sampler, batch_sampler, collate_fn
          • 自動(dòng)把數(shù)據(jù)整理成batch序列,主要涉及到的參數(shù)有 batch_size, batch_sampler, collate_fn, drop_last
          • 單進(jìn)程和多進(jìn)程的數(shù)據(jù)加載,主要涉及到的參數(shù)有 num_workers, worker_init_fn
          • 自動(dòng)進(jìn)行鎖頁內(nèi)存讀取 (memory pinning),主要涉及到的參數(shù) pin_memory
          • 支持?jǐn)?shù)據(jù)預(yù)加載,主要涉及的參數(shù) prefetch_factor

          3.1 三者關(guān)系 (Dataset, Sampler, Dataloader)

          通過以上介紹的三者工作內(nèi)容不難推出其內(nèi)在關(guān)系:

          1. 設(shè)置 Dataset,將數(shù)據(jù) data source 包裝成 Dataset 類,暴露提取接口。

          2. 設(shè)置 Sampler,決定采樣方式。我們是能從 Dataset 中提取元素了,還是需要設(shè)置 Sampler 告訴程序提取 Dataset 的策略。

          3. 將設(shè)置好的 Dataset 和 Sampler 傳入 DataLoader,同時(shí)可以設(shè)置 shuffle, batch_size 等參數(shù)。使用 DataLoader 對象可以方便快捷地在數(shù)據(jù)集上遍歷。

          總結(jié)來說,即 Dataloader 負(fù)責(zé)總的調(diào)度,命令 Sampler 定義遍歷索引的方式,然后用索引去 Dataset 中提取元素。于是就實(shí)現(xiàn)了對給定數(shù)據(jù)集的遍歷。

          3.2 批處理

          3.2.1 自動(dòng)批處理(默認(rèn))

          DataLoader 支持通過參數(shù)batch_size, drop_last, batch_sampler,自動(dòng)地把取出的數(shù)據(jù)整理 (collate) 成批次樣本 (batch)

          batch_size 和 drop_last 參數(shù)用于指定 DataLoader 如何獲取 dataset 的 key。特別地,對于 map-style 類型的 dataset,用戶可以選擇指定 batch_sample參數(shù),一次就生成一個(gè) keys list

          在使用 sampler 產(chǎn)生的 indices 獲取采樣到的數(shù)據(jù)時(shí),DataLoader 使用 collate_fn 參數(shù)將樣本列表整理成 batch。抽象這個(gè)過程,其表示方式大致如下

          # For Map-style
          for indices in batch_sampler:
              yield collate_fn([dataset[i] for i in indices])

          # For Iterable-style
          dataset_iter = iter(dataset)
          for indices in batch_sampler:
              yield collate_fn([next(dataset_iter) for _ in indices])

          3.2.2 關(guān)閉自動(dòng)批處理

          當(dāng)用戶想用 dataset 代碼手動(dòng)處理 batch,或僅加載單個(gè) sample data 時(shí),可將 batch_size 和 batch_sampler 設(shè)為 None, 將關(guān)閉自動(dòng)批處理。此時(shí),由 Dataset 產(chǎn)生的 sample 將會(huì)直接被 collate_fn 處理。抽象這個(gè)過程,其表示方式大致如下:

          # For Map-style
          for index in sampler:
              yield collate_fn(dataset[index])

          # For Iterable-style
          for data in iter(dataset):
              yield collate_fn(data)

          3.2.3 collate_fn當(dāng)關(guān)閉自動(dòng)批處理 (automatic batching) 時(shí),collate_fn 作用于單個(gè)數(shù)據(jù)樣本,只是在 PyTorch 張量中轉(zhuǎn)換 NumPy 數(shù)組。

          當(dāng)開啟自動(dòng)批處理 (automatic batching) 時(shí),collate_fn 作用于數(shù)據(jù)樣本列表,將輸入樣本整理為一個(gè) batch,一般做下面 3 件事情

          • 添加新的批次維度(一般是第一維)
          • 它會(huì)自動(dòng)將 NumPy 數(shù)組和 Python 數(shù)值轉(zhuǎn)換為 PyTorch 張量
          • 它保留數(shù)據(jù)結(jié)構(gòu),例如,如果每個(gè)樣本都是 dict,則輸出具有相同鍵集但批處理過的張量作為值的字典(或list,當(dāng)不能轉(zhuǎn)換的時(shí)候)。list, tuples, namedtuples 同樣適用

          自定義 collate_fn 可用于自定義排序規(guī)則,例如,將順序數(shù)據(jù)填充到批處理的最大長度,添加對自定義數(shù)據(jù)類型的支持等。

          3.3 多進(jìn)程處理 (multi-process)

          為了避免在加載數(shù)據(jù)時(shí)阻塞計(jì)算代碼,PyTorch 提供了一個(gè)簡單的開關(guān),只需將參數(shù)設(shè)置 num_workers 為正整數(shù)即可執(zhí)行多進(jìn)程數(shù)據(jù)加載,設(shè)置為 0 時(shí)執(zhí)行單線程數(shù)據(jù)加載。

          4. 單進(jìn)程

          在單進(jìn)程模式下,DataLoader 初始化的進(jìn)程和取數(shù)據(jù)的進(jìn)程是一樣的 。因此,數(shù)據(jù)加載可能會(huì)阻止計(jì)算。

          但是,當(dāng)用于在進(jìn)程之間共享數(shù)據(jù)的資源(例如共享內(nèi)存,文件描述符)有限時(shí),或者當(dāng)整個(gè)數(shù)據(jù)集很小并且可以完全加載到內(nèi)存中時(shí),此模式可能是首選。

          此外,單進(jìn)程加載通常顯示更多可讀的錯(cuò)誤跟蹤,因此對于調(diào)試很有用。

          5. 多進(jìn)程

          在多進(jìn)程模式下,每次 DataLoader 創(chuàng)建 iterator 時(shí)(例如,當(dāng)調(diào)用時(shí)enumerate(dataloader)),都會(huì)創(chuàng)建 num_workers 工作進(jìn)程。dataset, collate_fn, worker_init_fn 都會(huì)被傳到每個(gè)worker中,每個(gè)worker都用獨(dú)立的進(jìn)程。

          對于 map-style 數(shù)據(jù),主線程會(huì)用 Sampler 產(chǎn)生 indice,并將它們送到 worker 里。因此,shuffle是在主線程做的

          對于 iterable-style 數(shù)據(jù),因?yàn)槊總€(gè) worker 都有相同的 data 復(fù)制樣本,并在各個(gè)進(jìn)程里進(jìn)行不同的操作,以防止每個(gè)進(jìn)程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來進(jìn)行輔助處理。

          這里,torch.utils.data.get_worker_info() 返回worker進(jìn)程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None

          注意,通常不建議在多進(jìn)程加載中返回CUDA張量,因?yàn)樵谑褂肅UDA和在多處理中共享CUDA張量時(shí)存在許多微妙之處(文檔中提出:只要接收過程保留張量的副本,就需要發(fā)送過程來保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡而言之,不建議在使用多線程的情況下返回CUDA的tensor。

          6 鎖頁內(nèi)存 (Memory Pinning)

          這里首先解釋一下鎖頁內(nèi)存的概念。

          主機(jī)中的內(nèi)存,有兩種存在方式,一是鎖頁,二是不鎖頁,鎖頁內(nèi)存存放的內(nèi)容在任何情況下都不會(huì)與主機(jī)的虛擬內(nèi)存進(jìn)行交換(注:虛擬內(nèi)存就是硬盤),而不鎖頁內(nèi)存在主機(jī)內(nèi)存不足時(shí),數(shù)據(jù)會(huì)存放在虛擬內(nèi)存中。主機(jī)到GPU副本源自固定(頁面鎖定)內(nèi)存時(shí),速度要快得多。CPU張量和存儲暴露了一種 pin_memory() 方法,該方法返回對象的副本,并將數(shù)據(jù)放在固定的區(qū)域中。

          而顯卡中的顯存全部是鎖頁內(nèi)存!當(dāng)計(jì)算機(jī)的內(nèi)存充足的時(shí)候,可以設(shè)置 pin_memory=True。設(shè)置 pin_memory=True,則意味著生成的 Tensor 數(shù)據(jù)最開始是屬于內(nèi)存中的鎖頁內(nèi)存,這樣將內(nèi)存的Tensor轉(zhuǎn)義到GPU的顯存就會(huì)更快一些。同時(shí),由于 pin_memory 的作用是將張量返回之前將其復(fù)制到 CUDA 固定的內(nèi)存中,所以只有在 CUDA 環(huán)境支持下才有用。

          PyTorch 原生的 pin_memory 方法如下,其支持大部分 python 數(shù)據(jù)類型的處理:

          def pin_memory(data):
              if isinstance(data, torch.Tensor):
                  return data.pin_memory()
              elif isinstance(data, string_classes):
                  return data
              elif isinstance(data, container_abcs.Mapping):
                  return {k: pin_memory(sample) for k, sample in data.items()}
              elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
                  return type(data)(*(pin_memory(sample) for sample in data))
              elif isinstance(data, container_abcs.Sequence):
                  return [pin_memory(sample) for sample in data]
              elif hasattr(data, "pin_memory"):
                  return data.pin_memory()
              else:
                  return data

          默認(rèn)情況下,如果固定邏輯看到一個(gè)屬于自定義類型 (custom type) 的batch(如果有一個(gè) collate_fn 返回自定義批處理類型的批處理,則會(huì)發(fā)生),或者如果該批處理的每個(gè)元素都是 custom type,則固定邏輯將無法識別它們,它將返回該批處理(或那些元素)而無需固定內(nèi)存。要為自定義批處理或數(shù)據(jù)類型啟用內(nèi)存固定,需 pin_memory() 在自定義類型上定義一個(gè)方法。如下

          class SimpleCustomBatch:
              # 自定義一個(gè)類,該類不能被PyTorch原生的pin_memory方法所支持

              def __init__(self, data):
                  transposed_data = list(zip(*data))
                  self.inp = torch.stack(transposed_data[0], 0)
                  self.tgt = torch.stack(transposed_data[1], 0)

              # custom memory pinning method on custom type
              def pin_memory(self):
                  self.inp = self.inp.pin_memory()
                  self.tgt = self.tgt.pin_memory()
                  return self

          def collate_wrapper(batch):
              return SimpleCustomBatch(batch)

          inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
          tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
          dataset = TensorDataset(inps, tgts)

          loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                              pin_memory=True)

          for batch_ndx, sample in enumerate(loader):
              print(sample.inp.is_pinned())  # True
              print(sample.tgt.is_pinned())  # True

          7 預(yù)取 (prefetch)

          DataLoader 通過指定 prefetch_factor (默認(rèn)為 2)來進(jìn)行數(shù)據(jù)的預(yù)取。

          class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
              def __init__(self, loader):
                  ...
                  self._reset(loader, first_iter=True)

              def _reset(self, loader, first_iter=False):
                  ...
                  # prime the prefetch loop
                  for _ in range(self._prefetch_factor * self._num_workers):
                      self._try_put_index()

          通過源碼可以看到,prefetch 功能僅適用于 多進(jìn)程 加載中(下面會(huì)由多進(jìn)程 dataloader 的代碼分析)

          8 代碼詳解

          讓我們來看看具體的代碼調(diào)用流程:

          for data, label in train_loader:
              ......

          for 循環(huán)會(huì)調(diào)用 dataloader 的 iter(self) 方法,以此獲得迭代器來遍歷 dataset

          class DataLoader(Generic[T_co]):
              ...
              def __iter__(self) -&gt; '_BaseDataLoaderIter':

                  if self.persistent_workers and self.num_workers &gt; 0:
                      if self._iterator is None:
                          self._iterator = self._get_iterator()
                      else:
                          self._iterator._reset(self)
                      return self._iterator
                  else:
                      return self._get_iterator()

          iter(self) 方法中,dataloader 調(diào)用了 self._get_iterator() 方法,根據(jù) num_worker 獲得迭代器,并指示進(jìn)行單進(jìn)程還是多進(jìn)程

          class DataLoader(Generic[T_co]):
              ...
              def _get_iterator(self) -&gt; '_BaseDataLoaderIter':
                  if self.num_workers == 0:
                      return _SingleProcessDataLoaderIter(self)
                  else:
                      self.check_worker_number_rationality()
                      return _MultiProcessingDataLoaderIter(self)

          為了描述清晰,我們只考慮單進(jìn)程的代碼。下面是 class _SingleProcessDataLoaderIter(_BaseDataLoaderIter) ,以及其父類 class _BaseDataLoaderIter(object): 的重點(diǎn)代碼片段:

          class _BaseDataLoaderIter(object):
              def __init__(self, loader: DataLoader) -&gt; None:
                  # 初始化賦值一些 DataLoader 參數(shù),
                  # 以及用戶輸入合法性進(jìn)行校驗(yàn)
                  self._dataset = loader.dataset
                  self._dataset_kind = loader._dataset_kind
                  self._index_sampler = loader._index_sampler
                  ...

              def __iter__(self) -&gt; '_BaseDataLoaderIter':
                  return self
              
              def _reset(self, loader, first_iter=False):
                  self._sampler_iter = iter(self._index_sampler)
                  self._num_yielded = 0
                  self._IterableDataset_len_called = loader._IterableDataset_len_called
              
              def _next_index(self):
                  return next(self._sampler_iter)  # may raise StopIteration
              
              def _next_data(self):
                  raise NotImplementedError
              
              def __next__(self) -&gt; Any:
                  with torch.autograd.profiler.record_function(self._profile_name):
                      if self._sampler_iter is None:
                          self._reset()
                      data = self._next_data() # 重點(diǎn)代碼行,通過此獲取數(shù)據(jù)
                      self._num_yielded += 1
                      ...
                      return data
              
              next = __next__  # Python 2 compatibility
              
              def __len__(self) -&gt; int:
                  return len(self._index_sampler) # len(_BaseDataLoaderIter) == len(self._index_sampler)
              
              def __getstate__(self):
                  raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)

          _BaseDataLoaderIter 是所有 DataLoaderIter 的父類。dataloader獲得了迭代器之后,for 循環(huán)需要調(diào)用 next() 來獲得下一個(gè)對象,從而實(shí)現(xiàn)遍歷。通過 next 方法調(diào)用 _next_data() 獲取數(shù)據(jù)

          class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
              def __init__(self, loader):
                  super(_SingleProcessDataLoaderIter, self).__init__(loader)
                  assert self._timeout == 0
                  assert self._num_workers == 0

                  self._dataset_fetcher = _DatasetKind.create_fetcher(
                      self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
              
              def _next_data(self):
                  index = self._next_index()  # may raise StopIteration
                  data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
                  if self._pin_memory:
                      data = _utils.pin_memory.pin_memory(data)
                  return data

          從 _SingleProcessDataLoaderIter 的初始化參數(shù)可以看到,其在父類 _BaseDataLoaderIter 的基礎(chǔ)上定義了 _dataset_fetcher, 并傳入 _dataset, _auto_collation, _collate_fn 等參數(shù),用于定義獲取數(shù)據(jù)的方式。其具體實(shí)現(xiàn)會(huì)在稍后解釋。

          在 _next_data() 被調(diào)用后,其需要 next_index() 獲取 index,并通過獲得的 index 傳入 _dataset_fetcher 中獲取對應(yīng)樣本

          class DataLoader(Generic[T_co]):
              ...
              @property
              def _auto_collation(self):
                  return self.batch_sampler is not None

              @property
              def _index_sampler(self):
                  if self._auto_collation:
                      return self.batch_sampler
                  else:
                      return self.sampler

          class _BaseDataLoaderIter(object):
              ...
              def _reset(self, loader, first_iter=False):
                  self._sampler_iter = iter(self._index_sampler)
                  ...

              def _next_index(self):
                  # sampler_iter 來自于 index_sampler
                  return next(self._sampler_iter)  # may raise StopIteration

          從這里看出,dataloader 提供了 sampler (可以是batch_sampler 或者是其他 sampler 子類),然后 _SingleProcessDataLoaderIter 迭代sampler獲得索引

          下面我們來看看 fetcher,fetcher 需要 index 來獲取元素,并同時(shí)支持 Map-style dataset(對應(yīng) _MapDatasetFetcher)和 Iterable-style dataset(對應(yīng) _IterableDatasetFetcher),使其在Dataloader內(nèi)能使用相同的接口 fetch,代碼更加簡潔。

          對于 Map-style:直接輸入索引 index,作為 map 的 key,獲得對應(yīng)的樣本(即 value)

          class _MapDatasetFetcher(_BaseDatasetFetcher):
              def __init__(self, dataset, auto_collation, collate_fn, drop_last):
                  super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

              def fetch(self, possibly_batched_index):
                  if self.auto_collation:
                      # 有batch_sampler,_auto_collation就為True,
                      # 就優(yōu)先使用batch_sampler,對應(yīng)在fetcher中傳入的就是一個(gè)batch的索引
                      data = [self.dataset[idx] for idx in possibly_batched_index]
                  else:
                      data = self.dataset[possibly_batched_index]
                  return self.collate_fn(data)

          對于 Iterable-style: init 方法內(nèi)設(shè)置了 dataset 初始的迭代器,fetch 方法內(nèi)獲取元素,index 其實(shí)已經(jīng)沒有多大作用了

          class _IterableDatasetFetcher(_BaseDatasetFetcher):
              def __init__(self, dataset, auto_collation, collate_fn, drop_last):
                  super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
                  self.dataset_iter = iter(dataset)
              
              def fetch(self, possibly_batched_index):
                  if self.auto_collation:
                      # 對于batch_sampler(即auto_collation==True)
                      # 直接使用往后遍歷并提取len(possibly_batched_index)個(gè)樣本(即1個(gè)batch的樣本)
                      data = []
                      for _ in possibly_batched_index:
                          try:
                              data.append(next(self.dataset_iter))
                          except StopIteration:
                              break
                      if len(data) == 0 or (self.drop_last and len(data) &lt; len(possibly_batched_index)):
                          raise StopIteration
                  else:
                      # 對于sampler,直接往后遍歷并提取1個(gè)樣本
                      data = next(self.dataset_iter)
                  return self.collate_fn(data)        

          最后,我們通過索引傳入 fetcher,fetch 得到想要的樣本 因此,整個(gè)過程調(diào)用關(guān)系總結(jié) 如下:

          loader.__iter__ --> self._get_iterator()--> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter--> __next__() --> self._next_data()--> self._next_index()-->next(self._sampler_iter)next(iter(self._index_sampler)) --> 獲得 index --> self._dataset_fetcher.fetch(index)--> 獲得 data

          對于多進(jìn)程而言,借用 PyTorch 內(nèi)源碼的注釋,其運(yùn)行流程解釋如下

          # Our data model looks like this (queues are indicated with curly brackets):
          #
          #                main process                              ||
          #                     |                                    ||
          #               {index_queue}                              ||
          #                     |                                    ||
          #              worker processes                            ||     DATA
          #                     |                                    ||
          #            {worker_result_queue}                         ||     FLOW
          #                     |                                    ||
          #      pin_memory_thread of main process                   ||   DIRECTION
          #                     |                                    ||
          #               {data_queue}                               ||
          #                     |                                    ||
          #                data output                               \/
          #
          # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
          #      `pin_memory=False`.

          首先 dataloader 基于 multiprocessing 產(chǎn)生多進(jìn)程,每個(gè)子進(jìn)程的輸入輸出通過兩個(gè)主要的隊(duì)列 (multiprocessing.Queue() 類) 產(chǎn)生,分別為:

          • index_queue: 每個(gè)子進(jìn)程的隊(duì)列中需要處理的任務(wù)的下標(biāo)
          • _worker_result_queue: 返回時(shí)處理完任務(wù)的下標(biāo)
          • data_queue: 表明經(jīng)過 pin_memory 處理后的數(shù)據(jù)隊(duì)列

          并且有以下這些比較重要的 flag 參數(shù)來協(xié)調(diào)各個(gè) worker 之間的工作:

          • _send_idx: 發(fā)送索引,用來記錄這次要放 index_queue 中 batch 的 idx
          • _rcvd_idx: 接受索引,記錄要從 data_queue 中取出的 batch 的 idx
          • _task_info: 存儲將要產(chǎn)生的 data 信息的 dict,key為 task idx(由 0 開始的整形索引),value 為 (worker_id,) 或 (worker_id, data),分別對應(yīng)數(shù)據(jù) 未取 和 已取 的情況
          • _tasks_outstanding: 整形,代表已經(jīng)準(zhǔn)備好的 task/batch 的數(shù)量(可能有些正在準(zhǔn)備中)

          每個(gè) worker 一次產(chǎn)生一個(gè) batch 的數(shù)據(jù),返回 batch 數(shù)據(jù)前放入下一個(gè)批次要處理的數(shù)據(jù)下標(biāo),對應(yīng)構(gòu)造函數(shù)子進(jìn)程初始化如下

          class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
              def __init__(self, loader):
                  super(_MultiProcessingDataLoaderIter, self).__init__(loader)
                  ...
                  self._worker_result_queue = multiprocessing_context.Queue()  # 把該worker取出的數(shù)放入該隊(duì)列,用于進(jìn)程間通信
                  ...
                  self._workers_done_event = multiprocessing_context.Event()

                  self._index_queues = []
                  self._workers = []
                  for i in range(self._num_workers):
                      index_queue = multiprocessing_context.Queue()  # 索引隊(duì)列,每個(gè)子進(jìn)程一個(gè)隊(duì)列放要處理的下標(biāo)
                      index_queue.cancel_join_thread()
                      # _worker_loop 的作用是:從index_queue中取索引,然后通過collate_fn處理數(shù)據(jù),
                      # 然后再將處理好的 batch 數(shù)據(jù)放到 data_queue 中。(發(fā)送到隊(duì)列中的idx是self.send_idx)
                      w = multiprocessing_context.Process(
                          target=_utils.worker._worker_loop,  # 每個(gè)worker子進(jìn)程循環(huán)執(zhí)行的函數(shù),主要將數(shù)據(jù)以(idx, data)的方式傳入_worker_result_queue中
                          args=(self._dataset_kind, self._dataset, index_queue, 
                                self._worker_result_queue, self._workers_done_event,
                                self._auto_collation, self._collate_fn, self._drop_last,
                                self._base_seed + i, self._worker_init_fn, i, self._num_workers,
                                self._persistent_workers))
                      w.daemon = True
                      w.start()
                      self._index_queues.append(index_queue)
                      self._workers.append(w)
              
                  if self._pin_memory:
                      self._pin_memory_thread_done_event = threading.Event()
              
                      self._data_queue = queue.Queue()  # 用于存取出的數(shù)據(jù)進(jìn)行 pin_memory 操作后的結(jié)果
                      pin_memory_thread = threading.Thread(
                          target=_utils.pin_memory._pin_memory_loop,
                          args=(self._worker_result_queue, self._data_queue,
                                torch.cuda.current_device(),
                                self._pin_memory_thread_done_event))
                      pin_memory_thread.daemon = True
                      pin_memory_thread.start()
                      # Similar to workers (see comment above), we only register
                      # pin_memory_thread once it is started.
                      self._pin_memory_thread = pin_memory_thread
                  else:
                      self._data_queue = self._worker_result_queue
                  ...
                  self._reset(loader, first_iter=True)
              
              def _reset(self, loader, first_iter=False):
                  super()._reset(loader, first_iter)
                  self._send_idx = 0  # idx of the next task to be sent to workers,發(fā)送索引,用來記錄這次要放 index_queue 中 batch 的 idx
                  self._rcvd_idx = 0  # idx of the next task to be returned in __next__,接受索引,記錄要從 data_queue 中取出的 batch 的 idx
              
                  # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
                  # map: task idx =&gt; - (worker_id,)        if data isn't fetched (outstanding)
                  #                  \ (worker_id, data)   if data is already fetched (out-of-order)
                  self._task_info = {}
              
                  # _tasks_outstanding 指示當(dāng)前已經(jīng)準(zhǔn)備好的 task/batch 的數(shù)量(可能有些正在準(zhǔn)備中)
                  # 初始值為 0, 在 self._try_put_index() 中 +1,在 self._next_data 中-1
                  self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
              
                  # this indicates status that a worker still has work to do *for this epoch*.
                  self._workers_status = [True for i in range(self._num_workers)] 
                  # We resume the prefetching in case it was enabled
                  if not first_iter:
                      for idx in range(self._num_workers):
                          self._index_queues[idx].put(_utils.worker._ResumeIteration())
                      resume_iteration_cnt = self._num_workers
                      while resume_iteration_cnt &gt; 0:
                          data = self._get_data()
                          if isinstance(data, _utils.worker._ResumeIteration):
                              resume_iteration_cnt -= 1
                  ...
                  # 初始化的時(shí)候,就將 2*num_workers 個(gè) (batch_idx, sampler_indices) 放到 index_queue 中
                  for _ in range(self._prefetch_factor * self._num_workers):
                      self._try_put_index() # 進(jìn)行預(yù)取

          dataloader 初始化的時(shí)候,每個(gè) worker 的 index_queue 默認(rèn)會(huì)放入兩個(gè) batch 的 index,從 index_queue 中取出要處理的下標(biāo)

          def _try_put_index(self):
                  # self._prefetch_factor 默認(rèn)為 2
                  assert self._tasks_outstanding &lt; self._prefetch_factor * self._num_workers

                  try:
                      index = self._next_index()
                  except StopIteration:
                      return
                  for _ in range(self._num_workers):  # find the next active worker, if any
                      worker_queue_idx = next(self._worker_queue_idx_cycle)
                      if self._workers_status[worker_queue_idx]:
                          break
                  else:
                      # not found (i.e., didn't break)
                      return
              
                  self._index_queues[worker_queue_idx].put((self._send_idx, index)) # 放入 任務(wù)下標(biāo) 和 數(shù)據(jù)下標(biāo)
                  self._task_info[self._send_idx] = (worker_queue_idx,)
                  # _tasks_outstanding + 1,表明預(yù)備好的batch個(gè)數(shù)+1
                  self._tasks_outstanding += 1
                  # send_idx 發(fā)送索引, 記錄從sample_iter中發(fā)送索引到index_queue的次數(shù)
                  self._send_idx += 1

          調(diào)用 _next_data 方法進(jìn)行數(shù)據(jù)讀取,其中 _process_data 用于返回?cái)?shù)據(jù)

          def _next_data(self):
                  while True:

                      while self._rcvd_idx &lt; self._send_idx: # 確保待處理的任務(wù)(待取的batch)下標(biāo) &gt; 處理完畢要返回的任務(wù)(已經(jīng)取完的batch)下標(biāo)
                          info = self._task_info[self._rcvd_idx]
                          worker_id = info[0]
                          if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                              break
                          del self._task_info[self._rcvd_idx]
                          self._rcvd_idx += 1
                      else:
                          # no valid `self._rcvd_idx` is found (i.e., didn't break)
                          if not self._persistent_workers:
                              self._shutdown_workers()
                          raise StopIteration
              
                      # Now `self._rcvd_idx` is the batch index we want to fetch
              
                      # Check if the next sample has already been generated
                      if len(self._task_info[self._rcvd_idx]) == 2:
                          data = self._task_info.pop(self._rcvd_idx)[1]
                          return self._process_data(data)
              
                      assert not self._shutdown and self._tasks_outstanding &gt; 0
                      idx, data = self._get_data() # 調(diào)用 self._try_get_data() 從 self._data_queue 中取數(shù)
                      self._tasks_outstanding -= 1  # 表明預(yù)備好的batch個(gè)數(shù)需要減1
                      if self._dataset_kind == _DatasetKind.Iterable:
                          # Check for _IterableDatasetStopIteration
                          if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                              if self._persistent_workers:
                                  self._workers_status[data.worker_id] = False
                              else:
                                  self._mark_worker_as_unavailable(data.worker_id)
                              self._try_put_index()
                              continue
              
                      if idx != self._rcvd_idx:
                          # store out-of-order samples
                          self._task_info[idx] += (data,)
                      else:
                          del self._task_info[idx]
                          return self._process_data(data) # 返回?cái)?shù)據(jù)
              
              def _process_data(self, data):
                  self._rcvd_idx += 1
                  self._try_put_index() # 同上,主要放入隊(duì)列索引 以及 更新flag
                  if isinstance(data, ExceptionWrapper):
                      data.reraise()
                  return data

          這樣,多線程的 dataloader 就能通過多個(gè) worker 的協(xié)作來共同完成數(shù)據(jù)的加載。

          參考

          • https://pytorch.org/docs/stable/data.html
          • https://www.zhihu.com/search?type=content&q=dataloader
          • https://www.dazhuanlan.com/2019/12/05/5de8104ce9491/
          • https://blog.csdn.net/g11d111/article/details/81504637



          - The End -


          GiantPandaCV

          長按二維碼關(guān)注我們

          本公眾號專注:

          1. 技術(shù)分享;

          2. 學(xué)術(shù)交流

          3. 資料共享

          歡迎關(guān)注我們,一起成長!



          瀏覽 147
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評論
          圖片
          表情
          推薦
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  青娱乐大香蕉 | 淫色视频在线观看 | 国产最新在线播放 | 国产精品操逼片 | 成人豆花视频在线观看 |