實(shí)踐:PyTorch 數(shù)據(jù)集隨機(jī)值
點(diǎn)擊上方“機(jī)器學(xué)習(xí)與生成對抗網(wǎng)絡(luò)”,關(guān)注星標(biāo)
獲取有趣、好玩的前沿干貨!
作者 Elvanth@知乎
https://zhuanlan.zhihu.com/p/377155682
文僅交流,侵刪
一個(gè)快捷的解決方案:
def worker_init_fn(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
01 關(guān)于pytorch數(shù)據(jù)集隨機(jī)種子的基本認(rèn)識
在pytorch中random、torch.random等隨機(jī)值產(chǎn)生方法一般沒有問題,只有少數(shù)工人運(yùn)行也可以保障其不同的最終值.
np.random.seed 會(huì)出現(xiàn)問題的原因是,當(dāng)多處理采用 fork 方式產(chǎn)生子進(jìn)程時(shí),numpy 不會(huì)對不同的子進(jìn)程產(chǎn)生不同的隨機(jī)值.
換言之,當(dāng)沒有多處理使用時(shí),numpy 不會(huì)出現(xiàn)隨機(jī)種子的不同的問題;實(shí)驗(yàn)代碼的可復(fù)現(xiàn)性要求一個(gè)是工人種子 ,即工人內(nèi)包括numpy,random,torch.random所有的隨機(jī)表現(xiàn);另一個(gè)是Base ,即程序運(yùn)行后的初始隨機(jī)值,其可以通過以下兩種方式產(chǎn)生
torch.manual_seed(base_seed)
由特定的seed generator設(shè)置
generator = torch. Generator()
g.manual_seed(base_seed)
DataLoader(dataset, ..., generator=generator)
使用spawn模式可以斬?cái)嘁陨纤袩?
02 直接在網(wǎng)上搜這個(gè)問題會(huì)得到什么答案
參考很多的解決方案時(shí),往往會(huì)提出以下功能:
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
讓我們看看它的輸出結(jié)果:
(第0,3列是索引,第1,4列是np.random的結(jié)果,第2,5列是random.randint的結(jié)果)
epoch 0
tensor([[ 0, 5125, 13588, 0, 15905, 23182],
[ 1, 7204, 19825, 0, 13653, 25225]])
tensor([[ 2, 1709, 11504, 0, 12842, 23238],
[ 3, 5715, 14058, 0, 15236, 28033]])
tensor([[ 4, 1062, 11239, 0, 10142, 29869],
[ 5, 6574, 15672, 0, 19623, 25600]])
============================================================
epoch 1
tensor([[ 0, 5125, 18134, 0, 15905, 28990],
[ 1, 7204, 13206, 0, 13653, 25106]])
tensor([[ 2, 1709, 15512, 0, 12842, 29703],
[ 3, 5715, 14201, 0, 15236, 27696]])
tensor([[ 4, 1062, 13994, 0, 10142, 23411],
[ 5, 6574, 18532, 0, 19623, 21744]])
============================================================
假設(shè)上述方案對一個(gè)時(shí)代內(nèi)可以防止不同的工人出現(xiàn)隨機(jī)值相同的情況,但不同的時(shí)代之間,其最終的隨機(jī)種子仍然是不變的。
03 那應(yīng)該如何解決
來自pytorch官方的解決方案:
https://github.com/pytorch/pytorch/pull/56488#issuecomment-825128350
def worker_init_fn(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
來自numpy.random原作者的解決方案:
https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
def worker_init_fn(id):
process_seed = torch.initial_seed()
# Back out the base_seed so we can use all the bits.
base_seed = process_seed - id
ss = np.random.SeedSequence([id, base_seed])
# More than 128 bits (4 32-bit words) would be overkill.
np.random.seed(ss.generate_state(4))
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
一個(gè)更簡單但不保證正確性的解決方案:
def worker_init_fn(worker_id):
np.random.seed((worker_id + torch.initial_seed()) % np.iinfo(np.int32).max)
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
04 附上可運(yùn)行的完整文件
import numpy as np
import random
import torch
# np.random.seed(0)
class Transform(object):
def __init__(self):
pass
def __call__(self, item = None):
return [np.random.randint(10000, 20000), random.randint(20000,30000)]
class RandomDataset(object):
def __init__(self):
pass
def __getitem__(self, ind):
item = [ind, np.random.randint(1, 10000), random.randint(10000, 20000), 0]
tsfm =Transform()(item)
return np.array(item + tsfm)
def __len__(self):
return 20
from torch.utils.data import DataLoader
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
ds = RandomDataset()
ds = DataLoader(ds, 10, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
for epoch in range(2):
print("epoch {}".format(epoch))
np.random.seed()
for batch in ds:
print(batch)
猜您喜歡:
等你著陸!【GAN生成對抗網(wǎng)絡(luò)】知識星球!
CVPR 2021 | GAN的說話人驅(qū)動(dòng)、3D人臉論文匯總
CVPR 2021 | 圖像轉(zhuǎn)換 今如何?幾篇GAN論文
CVPR 2021生成對抗網(wǎng)絡(luò)GAN部分論文匯總
最新最全20篇!基于 StyleGAN 改進(jìn)或應(yīng)用相關(guān)論文
附下載 | 經(jīng)典《Think Python》中文版
附下載 | 《Pytorch模型訓(xùn)練實(shí)用教程》
附下載 | 最新2020李沐《動(dòng)手學(xué)深度學(xué)習(xí)》
附下載 | 《可解釋的機(jī)器學(xué)習(xí)》中文版
附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實(shí)戰(zhàn)》
