Distributed Training with PyTorch
size (world_size in the code): the number of GPU devices used for training
rank: a sequential ID assigned to each GPU device, from 0 to size - 1
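When training spans several machines, launchers usually distinguish the global rank (unique across all machines) from the local rank (the GPU index within one machine). The sketch below shows the common arithmetic, assuming every node runs the same number of processes (nproc_per_node) and ranks are numbered node by node; the helper name `rank_layout` is hypothetical.

```python
def rank_layout(nnodes: int, nproc_per_node: int):
    """Hypothetical helper: return (global_rank, node_rank, local_rank) per process.

    Assumes each node runs nproc_per_node processes, one per GPU.
    """
    world_size = nnodes * nproc_per_node  # total number of processes ("size")
    layout = []
    for node_rank in range(nnodes):
        for local_rank in range(nproc_per_node):
            # Ranks are numbered node by node, from 0 to world_size - 1
            global_rank = node_rank * nproc_per_node + local_rank
            layout.append((global_rank, node_rank, local_rank))
    assert len(layout) == world_size
    return layout


for g, n, l in rank_layout(nnodes=2, nproc_per_node=4):
    print(f"rank {g}: node {n}, cuda:{l}")
```

On a single machine the global rank and the local rank coincide, which is why the snippets in this article can use `rank` directly as the CUDA device index.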
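# Standard (non-distributed) data loading
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# batch_size is assumed to be defined elsewhere
# ToTensor turns the PIL images into tensors the DataLoader can batch
transform = transforms.ToTensor()

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data', download=True, train=True, transform=transform)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)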
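# Distributed version: each process loads only its own shard of the data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

# world_size, rank and batch_size are assumed to be defined elsewhere
transform = transforms.ToTensor()

# Download and initialize MNIST train dataset
train_dataset = datasets.MNIST('./mnist_data', download=True, train=True, transform=transform)
# Create distributed sampler pinned to rank
sampler = DistributedSampler(train_dataset,
                             num_replicas=world_size,
                             rank=rank,
                             shuffle=True,  # May be True: the sampler does the shuffling
                             seed=42)
# Wrap train dataset into DataLoader
train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=False,  # Must be False! DataLoader rejects shuffle=True with a sampler
                          num_workers=4,
                          sampler=sampler,
                          pin_memory=True)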
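To see how DistributedSampler splits the data, the sketch below builds one sampler per rank for a toy 10-item dataset, assuming 3 replicas. Each rank receives ceil(10/3) = 4 indices, interleaved (rank r takes positions r, r + 3, ...), and the index list is padded by repeating its first entries so every shard has the same length. Passing `num_replicas` and `rank` explicitly means no process group is needed; `shuffle=False` is used only to make the output easy to read.

```python
from torch.utils.data.distributed import DistributedSampler

# A toy "dataset" of 10 items; the sampler only needs len()
toy_dataset = list(range(10))
world_size = 3  # assumed number of processes; 10 items do not divide evenly

shards = []
for rank in range(world_size):
    # Explicit num_replicas/rank: works without init_process_group
    sampler = DistributedSampler(toy_dataset, num_replicas=world_size,
                                 rank=rank, shuffle=False)
    shards.append(list(sampler))
    print(f"rank {rank}: {shards[-1]}")
```

This prints `rank 0: [0, 3, 6, 9]`, `rank 1: [1, 4, 7, 0]` and `rank 2: [2, 5, 8, 1]`: indices 0 and 1 are visited twice per epoch because of the padding. Passing `drop_last=True` to the sampler drops the uneven tail instead.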
import torch.nn as nn

def create_model():
    model = nn.Sequential(
        nn.Flatten(),             # (N, 1, 28, 28) -> (N, 784); no-op for already flat input
        nn.Linear(28 * 28, 128),  # MNIST images are 28x28 pixels
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10, bias=False)  # 10 classes to predict
    )
    return model

# Initialize the model
model = create_model()
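A quick sanity check of the model's shapes and size can help before distributing it. The sketch below feeds a fake batch of 4 MNIST-sized images through the network and counts its parameters. It puts an `nn.Flatten()` layer in front, an assumption made because `nn.Linear(28*28, 128)` expects flattened 784-value vectors while `ToTensor` produces images of shape (N, 1, 28, 28).

```python
import torch
import torch.nn as nn


def create_model():
    # Same layers as create_model() above, plus a leading nn.Flatten
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, 10, bias=False),
    )


model = create_model()
x = torch.zeros(4, 1, 28, 28)  # a fake batch of 4 MNIST-sized images
out = model(x)
n_params = sum(p.numel() for p in model.parameters())
print(out.shape)  # torch.Size([4, 10])
print(n_params)   # 118272 = (784*128 + 128) + (128*128 + 128) + 128*10
```

DDP all-reduces one gradient value per parameter on every step, so for this model that is 118,272 floats, roughly 473 KB in float32.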
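import torch
from torch.nn.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(...) has already been called
# Initialize the model
model = create_model()
# Create CUDA device: one process per GPU on a single node, so the rank doubles as the GPU index
device = torch.device(f'cuda:{rank}')
# Send model parameters to the device
model = model.to(device)
# Wrap the model in the DDP wrapper
model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)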
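The snippet above needs CUDA GPUs and an already initialized process group. To try the DDP wrapper without GPUs, here is a minimal single-process sketch, assuming the CPU `gloo` backend, `world_size=1`, and a temporary file as the rendezvous point instead of a network address. It also fills in one possible "do the training" step; SGD, cross-entropy and `lr=0.01` are arbitrary choices, not the article's.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def create_model():
    # Same layers as above, plus nn.Flatten for (N, 1, 28, 28) inputs
    return nn.Sequential(
        nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10, bias=False),
    )


# Assumption: a single CPU process (rank 0 of 1) using gloo and file-based rendezvous
init_file = os.path.join(tempfile.mkdtemp(), "ddp_init")
dist.init_process_group(backend="gloo", init_method=f"file://{init_file}",
                        rank=0, world_size=1)

model = DistributedDataParallel(create_model())  # CPU module: no device_ids
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 1, 28, 28)    # fake batch of images
y = torch.randint(0, 10, (8,))   # fake labels
optimizer.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()                  # DDP all-reduces gradients across ranks here
optimizer.step()
print(f"loss: {loss.item():.4f}")

dist.destroy_process_group()
```

With more processes, each would call `init_process_group` with its own `rank` and the same `world_size`, and the backward pass would average gradients across all of them.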
for i in range(epochs):
    for x, y in train_loader:
        # do the training...
        ...
for i in range(epochs):
    # Tell the sampler which epoch this is, so each epoch uses a new shuffle order
    train_loader.sampler.set_epoch(i)
    for x, y in train_loader:
        # do the training...
        ...
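The reason for `set_epoch`: DistributedSampler seeds its shuffle with `seed + epoch`, so all ranks agree on one permutation per epoch and then take disjoint slices of it. Without the call, the epoch stays at 0 and every epoch repeats the same order. A sketch with a toy 100-item dataset and 2 assumed replicas:

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(100))  # toy dataset; the sampler only needs len()
world_size = 2              # assumed number of processes

sampler = DistributedSampler(dataset, num_replicas=world_size, rank=0,
                             shuffle=True, seed=42)

# Without set_epoch the epoch counter stays at 0, so every pass repeats one order
first_pass = list(sampler)
second_pass = list(sampler)

# set_epoch reseeds the shuffle with seed + epoch, giving a new order
sampler.set_epoch(1)
third_pass = list(sampler)

# Another rank with the same seed and epoch draws the same permutation,
# then keeps the complementary positions, so the two shards do not overlap
other = DistributedSampler(dataset, num_replicas=world_size, rank=1,
                           shuffle=True, seed=42)
other.set_epoch(1)
other_pass = list(other)

print(first_pass == second_pass)                            # True
print(sorted(third_pass + other_pass) == list(range(100)))  # True
```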
import argparse

# torch.distributed.launch passes each process its local rank as --local_rank
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int)
args = parser.parse_args()
rank = args.local_rank
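Newer launchers deliver the local rank differently: from PyTorch 2.0, `torch.distributed.launch` passes the dashed `--local-rank`, and `torchrun` exports it as the `LOCAL_RANK` environment variable instead of a flag. A hedged sketch that accepts all three; the function name `get_local_rank` and the fallback to 0 are assumptions for illustration.

```python
import argparse
import os


def get_local_rank(argv=None):
    """Hypothetical helper: return this process's local rank."""
    parser = argparse.ArgumentParser()
    # Accept both spellings; dest="local_rank" stores either as args.local_rank
    parser.add_argument("--local-rank", "--local_rank", dest="local_rank",
                        type=int, default=None)
    # parse_known_args ignores the script's other arguments
    args, _ = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    # torchrun sets LOCAL_RANK in the environment instead of passing a flag
    return int(os.environ.get("LOCAL_RANK", "0"))


print(get_local_rank(["--local_rank=3"]))    # 3
print(get_local_rank(["--local-rank", "1"]))  # 1
```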
import torch

# Save from rank 0 only; model.module is the original model inside the DDP wrapper
if rank == 0:
    torch.save(model.module.state_dict(), 'model.pt')
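Saving `model.module.state_dict()` rather than `model.state_dict()` keeps the DDP wrapper's `module.` key prefix out of the checkpoint, so the file loads straight into a plain, non-distributed model. A sketch of the loading side, assuming the same layers as `create_model()` (plus `nn.Flatten`), with a stand-in checkpoint written to a temporary file and a hypothetical helper `strip_ddp_prefix` for checkpoints that were saved from the wrapper by mistake:

```python
import os
import tempfile

import torch
import torch.nn as nn


def create_model():
    # Same layers as above, plus nn.Flatten for (N, 1, 28, 28) inputs
    return nn.Sequential(
        nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10, bias=False),
    )


def strip_ddp_prefix(state_dict):
    """Hypothetical helper: drop the 'module.' prefix DDP adds to every key."""
    prefix = "module."
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}


# Stand-in for the checkpoint rank 0 writes during training
path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(create_model().state_dict(), path)

# map_location="cpu" loads the weights regardless of the GPU they were saved from
state = torch.load(path, map_location="cpu")
restored = create_model()
restored.load_state_dict(strip_ddp_prefix(state))
restored.eval()  # disable dropout for inference
```

If other ranks need to read `model.pt` during the same run, they should first wait for rank 0 to finish writing it, for example with `torch.distributed.barrier()`.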
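Launch one process per GPU (4 here) with:
python -m torch.distributed.launch --nproc_per_node=4 ddp_tutorial_multi_gpu.py
In recent PyTorch releases torch.distributed.launch is deprecated in favor of torchrun (torchrun --nproc_per_node=4 ddp_tutorial_multi_gpu.py), which provides the local rank through the LOCAL_RANK environment variable rather than a --local_rank argument.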