<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          PyTorch 數(shù)據(jù)集隨機值的完美實踐

          共 4242字,需瀏覽 9分鐘

           ·

          2021-09-17 09:31

          點擊上方“程序員大白”,選擇“星標”公眾號

          重磅干貨,第一時間送達

          作者 | Elvanth@知乎
          來源 | https://zhuanlan.zhihu.com/p/377155682
          編輯 | 極市平臺
          本文僅作學術交流,版權歸原作者所有,如有侵權請聯(lián)系刪除。

          極市導讀

           

          本文所分析的問題與解決方案將在最近發(fā)布的pytorch版本中解決;因此解決所有煩惱的根源是方法,更新pytorch~ 

          一個快捷的解決方案:

          def worker_init_fn(worker_id):
          worker_seed = torch.initial_seed() % 2**32
          np.random.seed(worker_seed)
          random.seed(worker_seed)

          ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

          01 關于pytorch數(shù)據(jù)集隨機種子的基本認識

          在pytorch中random、torch.random等隨機值產(chǎn)生方法一般沒有問題,只有少數(shù)工人運行也可以保障其不同的最終值.

          np.random.seed 會出現(xiàn)問題的原因是,當多處理采用 fork 方式產(chǎn)生子進程時,numpy 不會對不同的子進程產(chǎn)生不同的隨機值.

          換言之,當沒有多處理使用時,numpy 不會出現(xiàn)隨機種子的不同的問題;實驗代碼的可復現(xiàn)性要求一個是工人種子 ,即工人內包括numpy,random,torch.random所有的隨機表現(xiàn);另一個是Base ,即程序運行后的初始隨機值,其可以通過以下兩種方式產(chǎn)生

          1. torch.manual_seed(base_seed)

          2. 由特定的seed generator設置

          generator = torch. Generator()
          g.manual_seed(base_seed)
          DataLoader(dataset, ..., generator=generator)

          使用spawn模式可以斬斷以上所有煩惱.

          02 直接在網(wǎng)上搜這個問題會得到什么答案

          參考很多的解決方案時,往往會提出以下功能:

          def worker_init_fn(worker_id):
          np.random.seed(np.random.get_state()[1][0] + worker_id)

          讓我們看看它的輸出結果:
          (第0,3列是索引,第1,4列是np.random的結果,第2,5列是random.randint的結果)

          epoch 0
          tensor([[ 0, 5125, 13588, 0, 15905, 23182],
          [ 1, 7204, 19825, 0, 13653, 25225]])
          tensor([[ 2, 1709, 11504, 0, 12842, 23238],
          [ 3, 5715, 14058, 0, 15236, 28033]])
          tensor([[ 4, 1062, 11239, 0, 10142, 29869],
          [ 5, 6574, 15672, 0, 19623, 25600]])
          ============================================================
          epoch 1
          tensor([[ 0, 5125, 18134, 0, 15905, 28990],
          [ 1, 7204, 13206, 0, 13653, 25106]])
          tensor([[ 2, 1709, 15512, 0, 12842, 29703],
          [ 3, 5715, 14201, 0, 15236, 27696]])
          tensor([[ 4, 1062, 13994, 0, 10142, 23411],
          [ 5, 6574, 18532, 0, 19623, 21744]])
          ============================================================

          假設上述方案對一個時代內可以防止不同的工人出現(xiàn)隨機值相同的情況,但不同的時代之間,其最終的隨機種子仍然是不變的。

          03 那應該如何解決

          來自pytorch官方的解決方案:

          https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350

          def worker_init_fn(worker_id):
          worker_seed = torch.initial_seed() % 2**32
          np.random.seed(worker_seed)
          random.seed(worker_seed)

          ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

          來自numpy.random原作者的解決方案:

          https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562

          def worker_init_fn(id):
          process_seed = torch.initial_seed()
          # Back out the base_seed so we can use all the bits.
          base_seed = process_seed - id
          ss = np.random.SeedSequence([id, base_seed])
          # More than 128 bits (4 32-bit words) would be overkill.
          np.random.seed(ss.generate_state(4))

          ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

          一個更簡單但不保證正確性的解決方案:

          def worker_init_fn(worker_id):
          np.random.seed((worker_id + torch.initial_seed()) % np.iinfo(np.int32).max)

          ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

          04 附上可運行的完整文件

          import numpy as np
          import random
          import torch

          # np.random.seed(0)

          class Transform(object):
          def __init__(self):
          pass

          def __call__(self, item = None):
          return [np.random.randint(10000, 20000), random.randint(20000,30000)]

          class RandomDataset(object):
          def __init__(self):
          pass

          def __getitem__(self, ind):
          item = [ind, np.random.randint(1, 10000), random.randint(10000, 20000), 0]
          tsfm =Transform()(item)
          return np.array(item + tsfm)
          def __len__(self):
          return 20

          from torch.utils.data import DataLoader

          def worker_init_fn(worker_id):
          np.random.seed(np.random.get_state()[1][0] + worker_id)

          ds = RandomDataset()
          ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

          for epoch in range(2):
          print("epoch {}".format(epoch))
          np.random.seed()
          for batch in ds:
                  print(batch)

          “拍一拍” 能撤回了 !?。?/a>

          5款Chrome插件,第1款絕對良心!

          為開發(fā)色情游戲,這家公司赴日尋找AV女優(yōu)拍攝,期望暴力賺錢結果...

          拼多多終于釀成慘劇

          華為阿里下班時間曝光:所有的光鮮,都有加班的味道




          ,,西,,[],


          瀏覽 79
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  亚洲成人影片在线免费看 | 国产精品无码卡一卡二卡三 | 91插插插插插插 | 欧美日韩国产成人一区 | 色多多在线网址 |