<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          聊聊Pytorch中的dataloader

          共 4188字,需瀏覽 9分鐘

           ·

          2021-09-30 00:18

          點擊左上方藍字關(guān)注我們



          一個專注于目標檢測與深度學習知識分享的公眾號

          編者薦語
          首先簡單介紹一下DataLoader,它是PyTorch中數(shù)據(jù)讀取的一個重要接口,該接口定義在dataloader.py中,只要是用PyTorch來訓練模型基本都會用到該接口(除非用戶重寫…),該接口的目的:將自定義的Dataset根據(jù)batch size大小、是否shuffle等封裝成一個Batch Size大小的Tensor,用于后面的訓練。
          作者 | Mario@知乎
          鏈接 | https://zhuanlan.zhihu.com/p/117270644

          為啥突然要寫一下pytorch的dataloader呢,首先來說說事情的來龍去脈。

          起初,我最開始單獨訓練一個網(wǎng)絡(luò)來完成landmark點回歸任務(wù)和分類任務(wù),訓練的數(shù)據(jù)是txt格式,在訓練之前對數(shù)據(jù)進行分析,發(fā)現(xiàn)分類任務(wù)中存在嚴重的數(shù)據(jù)樣本不均衡的問題,那么我事先針對性的進行數(shù)據(jù)采樣均衡操作,重新得到訓練和測試的txt數(shù)據(jù)和標簽,保證了整個訓練和測試數(shù)據(jù)的樣本均衡性。

          由于我的整個項目是檢測+點回歸+分類,起初檢測和點回歸+分類是分兩步實現(xiàn)的,檢測是通過讀取XML格式來進行訓練,現(xiàn)在要統(tǒng)一整個項目的訓練和測試過程,要將點回歸+分類的訓練測試過程也按照讀取XML格式來進行,那么就遇到一個問題,如何針對性的去給樣本偏少的樣本進行均衡。

          由于在dataset類中,返回的圖像和標簽都是針對每個index返回一個結(jié)果,在dataset類中進行操作似乎不太可行,那么就想到在dataloader中進行操作,通過dataloader中的參數(shù)sample來完成針對性采樣。

          還有一個問題是關(guān)于num_workers的設(shè)置,因為我有對比過,在我的單機RTX 2080Ti上和八卡服務(wù)器TITAN RTX上(僅使用單卡,其它卡有在跑其它任務(wù)),使用相同的num_workers,在單機上的訓練速度反而更快,于是猜想可能和CPU或者內(nèi)存有關(guān)系,下面會具體分析。


          dataloader中的各個參數(shù)的含義


          類的定義為:torch.utils.data.DataLoader ,其中包含的參數(shù)有:


          torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, \
          batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, \
          drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)



          dataset:定義的dataset類返回的結(jié)果。

          batchsize:每個bacth要加載的樣本數(shù),默認為1。

          shuffle:在每個epoch中對整個數(shù)據(jù)集data進行shuffle重排,默認為False。

          sample:定義從數(shù)據(jù)集中加載數(shù)據(jù)所采用的策略,如果指定的話,shuffle必須為False;batch_sample類似,表示一次返回一個batch的index。

          num_workers:表示開啟多少個線程數(shù)去加載你的數(shù)據(jù),默認為0,代表只使用主進程。

          collate_fn:表示合并樣本列表以形成小批量的Tensor對象。

          pin_memory:表示要將load進來的數(shù)據(jù)是否要拷貝到pin_memory區(qū)中,其表示生成的Tensor數(shù)據(jù)是屬于內(nèi)存中的鎖頁內(nèi)存區(qū),這樣將Tensor數(shù)據(jù)轉(zhuǎn)義到GPU中速度就會快一些,默認為False。

          drop_last:當你的整個數(shù)據(jù)長度不能夠整除你的batchsize,選擇是否要丟棄最后一個不完整的batch,默認為False。

          注:這里簡單科普下pin_memory,通常情況下,數(shù)據(jù)在內(nèi)存中要么以鎖頁的方式存在,要么保存在虛擬內(nèi)存(磁盤)中,設(shè)置為True后,數(shù)據(jù)直接保存在鎖頁內(nèi)存中,后續(xù)直接傳入cuda;否則需要先從虛擬內(nèi)存中傳入鎖頁內(nèi)存中,再傳入cuda,這樣就比較耗時了,但是對于內(nèi)存的大小要求比較高。


          下面針對num_workers,sample和collate_fn分別進行說明:

          1. 設(shè)置num_workers


          pytorch中dataloader一次性創(chuàng)建num_workers個子線程,然后用batch_sampler將指定batch分配給指定worker,worker將它負責的batch加載進RAM,dataloader就可以直接從RAM中找本輪迭代要用的batch。

          如果num_worker設(shè)置得大,好處是尋batch速度快,因為下一輪迭代的batch很可能在上一輪/上上一輪...迭代時已經(jīng)加載好了。壞處是內(nèi)存開銷大,也加重了CPU負擔(worker加載數(shù)據(jù)到RAM的進程是進行CPU復制)。

          如果num_worker設(shè)為0,意味著每一輪迭代時,dataloader不再有自主加載數(shù)據(jù)到RAM這一步驟,只有當你需要的時候再加載相應(yīng)的batch,當然速度就更慢。

          num_workers的經(jīng)驗設(shè)置值是自己電腦/服務(wù)器的CPU核心數(shù),如果CPU很強、RAM也很充足,就可以設(shè)置得更大些,對于單機來說,單跑一個任務(wù)的話,直接設(shè)置為CPU的核心數(shù)最好。


          2. 定義sample


          (假設(shè)dataset類返回的是:data, label)


          from torch.utils.data.sampler import WeightedRandomSampler
          ## 如果label為1,那么對應(yīng)的該類別被取出來的概率是另外一個類別的2倍
          weights = [2 if label == 1 else 1 for data, label in dataset]
          sampler = WeightedRandomSampler(weights,num_samples=10, replacement=True)
          dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)



          PyTorch中提供的這個sampler模塊,用來對數(shù)據(jù)進行采樣。默認采用SequentialSampler,它會按順序一個一個進行采樣。常用的有隨機采樣器:RandomSampler,當dataloader的shuffle參數(shù)為True時,系統(tǒng)會自動調(diào)用這個采樣器,實現(xiàn)打亂數(shù)據(jù)。

          這里使用另外一個很有用的采樣方法:WeightedRandomSampler,它會根據(jù)每個樣本的權(quán)重選取數(shù)據(jù),在樣本比例不均衡的問題中,可用它來進行重采樣。replacement用于指定是否可以重復選取某一個樣本,默認為True,即允許在一個epoch中重復采樣某一個數(shù)據(jù)。


          3. 定義collate_fn



          def detection_collate(batch):
          """Custom collate fn for dealing with batches of images that have a different
          number of associated object annotations (bounding boxes).

          Arguments:
          batch: (tuple) A tuple of tensor images and lists of annotations

          Return:
          A tuple containing:
          1) (tensor) batch of images stacked on their 0 dim
          2) (list of tensors) annotations for a given image are stacked on
          0 dim
          """
          targets = []
          imgs = []
          for sample in batch:
          imgs.append(sample[0])
          targets.append(torch.FloatTensor(sample[1]))
          return torch.stack(imgs, 0), targets



          使用dataloader時加入collate_fn參數(shù),即可合并樣本列表以形成小批量的Tensor對象,如果你的標簽不止一個的話,還可以支持自定義,在上述方法中再額外添加對應(yīng)的label即可。


          data_loader = torch.utils.data.DataLoader(dataset, args.batch_size,
          num_workers=args.num_workers, sampler=sampler, shuffle=False,
          collate_fn=detection_collate, pin_memory=True, drop_last=True)



          參考鏈接

          https://link.zhihu.com/?target=https%3A//pytorch.org/docs/stable/data.html%3Fhighlight%3Ddataloader%23torch.utils.data.DataLoader

          https://link.zhihu.com/?target=https%3A//discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813


          END



          雙一流大學研究生團隊創(chuàng)建,專注于目標檢測與深度學習,希望可以將分享變成一種習慣!

          瀏覽 67
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  国产一级无码免费视频 | 竹菊影视一区二区三区四区 | 一道本最新无码视频 | 爱爱动态欧美 | 日韩熟妇视频 |