Pytorch DDP Training (分布式并行訓(xùn)練)
來源:知乎—就是不吃草的羊
作者:https://zhuanlan.zhihu.com/p/527360059
01
模型被拆分到不同GPU, 模型太大了,基本用不到
模型放在一個(gè),數(shù)據(jù)拆分不同GPU,torch.dataparallel
基本不會(huì)報(bào)bug
sync bc要自己準(zhǔn)備
模型和數(shù)據(jù)在不同gpu上各有一份, torch.distributeddataparallel
bug多,各進(jìn)程間數(shù)據(jù)不共享,訪問文件先后不確定,在日志系統(tǒng),數(shù)據(jù)集預(yù)處理,模型loss放在指定cuda等地方要仔細(xì)設(shè)計(jì)。
sync 是pytorch現(xiàn)有的庫
原理和效果理論上和 2 一致,都是用更大的batchsize,速度確實(shí)比 2 快,好像顯著減少了數(shù)據(jù)to cuda的時(shí)
支持多機(jī)
卡太多,網(wǎng)絡(luò)跑的時(shí)間短的情況,實(shí)際還不如 2
02
增大bs,就會(huì)帶來增大bs的相關(guān)弊端
過擬合,使用warm-up緩解,需要探索一下增大到多少不會(huì)影響泛化性
對(duì)應(yīng)倍增學(xué)習(xí)率,數(shù)據(jù)一個(gè)epoch減少了n倍,和學(xué)習(xí)率的影響抵消
DP匯總梯度,但是bn是根據(jù)單個(gè)gpu數(shù)據(jù)計(jì)算的,會(huì)有不正確的情況,要用sync bn
map-reduce,每個(gè)gpu得到上一個(gè)的,傳給下一個(gè)
一共兩輪,第一輪讓每個(gè)卡上有全部的數(shù)據(jù),第二輪讓數(shù)據(jù)同步給所有卡
每次只需要1/N的數(shù)據(jù),需要2N-2次,所以理論上與GPU個(gè)數(shù)無關(guān)
模型buffer, 不是參數(shù),其優(yōu)化不是反向傳播而是其他途徑,如bn的variance 和 moving mean
03
可以調(diào)用dist來查看當(dāng)前的rank,之后log等不需要重復(fù)的任務(wù)都在rank=0進(jìn)行
默認(rèn)不用時(shí)候rank=0
先用一張卡debug
使用wandb的話,需要顯式調(diào)用wandb.finish()
import torch.distributed as distfrom torch.nn.parallel import DistributedDataParallel as DDPimport torch.multiprocessing as mpdef demo_fn(rank, world_size):dist.init_process_group("nccl", rank=rank, world_size=world_size)# lots of code.if dist.get_rank() == 0:train_sampler = torch.utils.data.distributed.DistributedSampler(my_trainset)trainloader = torch.utils.data.DataLoader(my_trainset,batch_size=16, num_workers=2, sampler=train_sampler)model = ToyModel()#.to(local_rank)# DDP: Load模型要在構(gòu)造DDP模型之前,且只需要在master上加載就行了。# ckpt_path = None# if dist.get_rank() == 0 and ckpt_path is not None:# model.load_state_dict(torch.load(ckpt_path))model = DDP(model, device_ids=[local_rank], output_device=local_rank)# DDP:需要注意的是,這里的batch_size指的是每個(gè)進(jìn)程下的batch_size。# 也就是說,總batch_size是這里的batch_size再乘以并行數(shù)(world_size)。# torch.cuda.set_device(local_rank)# dist.init_process_group(backend='nccl')loss_func = nn.CrossEntropyLoss().to(local_rank)trainloader.sampler.set_epoch(epoch)data, label = data.to(local_rank), label.to(local_rank)if dist.get_rank() == 0:torch.save(model.module.state_dict(), "%d.ckpt" % epoch)def run_demo(demo_fn, world_size):mp.spawn(demo_fn,args=(world_size,),nprocs=world_size,join=True)
猜您喜歡:
戳我,查看GAN的系列專輯~!附下載 | 《可解釋的機(jī)器學(xué)習(xí)》中文版
附下載 |《TensorFlow 2.0 深度學(xué)習(xí)算法實(shí)戰(zhàn)》
附下載 |《計(jì)算機(jī)視覺中的數(shù)學(xué)方法》分享
