PyTorch Multi-GPU Parallel Training in Practice: A Summary (with Code)

Source | 記憶的迷谷    Produced by | 對(duì)白的算法屋
Today I'm sharing a hands-on summary of the practical details of multi-GPU parallel training in PyTorch.
Why use multi-GPU parallel training?
In short, there are two reasons. First, a model may not fit on a single GPU, while the complete model can run across two or more GPUs (as with the early AlexNet, which was split across two cards). Second, computing on multiple GPUs in parallel speeds up training. If you want to become a "master alchemist" of deep learning, multi-GPU parallel training is an indispensable skill.
Common multi-GPU training approaches:


Whichever approach you choose, two questions have to be answered:
How are error gradients communicated between devices?
How is BatchNorm (BN) synchronized across devices?
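These two questions are exactly what DistributedDataParallel and SyncBatchNorm take care of: DDP all-reduces (averages) the gradients across processes after every backward pass, and SyncBatchNorm computes BN statistics over the batches of all GPUs rather than each GPU's local batch. As a rough sketch of the gradient communication that DDP automates for you (assuming the process group has already been initialized):

import torch.distributed as dist

def average_gradients(model):
    # Conceptually what DDP does after loss.backward():
    # sum every parameter's gradient over all processes, then divide by the
    # world size so each replica applies the same averaged update.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= world_size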

Two GPU training methods: DataParallel and DistributedDataParallel
DataParallel is single-process, multi-threaded and only works on a single machine. DistributedDataParallel is multi-process and works on a single machine or across multiple machines. DataParallel is usually slower than DistributedDataParallel, so DistributedDataParallel is the mainstream choice today.
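As a quick illustration of how differently the two are used (a minimal sketch, not the article's training script; the tiny nn.Linear model is just a stand-in):

import os
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for a real network

# DataParallel: single process, multi-threaded, single machine only.
# One line to enable, but the primary GPU gathers every output and becomes a bottleneck.
dp_model = nn.DataParallel(model.cuda())

# DistributedDataParallel: one process per GPU, single machine or multiple machines.
# The process group must be initialized first (see init_distributed_mode below);
# LOCAL_RANK is set by the launcher and identifies this process's GPU.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl")
ddp_model = nn.parallel.DistributedDataParallel(model.cuda(local_rank), device_ids=[local_rank])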
Common ways to launch multi-GPU training in PyTorch:
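The usual options are either a launcher that spawns one process per GPU and sets RANK / WORLD_SIZE / LOCAL_RANK for you, e.g. torchrun --nproc_per_node=8 train.py (or the older python -m torch.distributed.launch --nproc_per_node=8 train.py), or creating the processes yourself with torch.multiprocessing.spawn. A minimal sketch of the latter (the script name train.py, the worker body and the port number are placeholders, not part of the article's code):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # one process per GPU; mp.spawn passes the process index in as `rank`
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:23456",  # any free port
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)
    # ... build the model, wrap it with DistributedDataParallel, train ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

The code below, which the rest of the article walks through, follows the launcher style: it reads the environment variables set by torchrun / torch.distributed.launch.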

import os
import sys
import math
import tempfile

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm


def init_distributed_mode(args):
    # For multi-machine multi-GPU training, WORLD_SIZE is the total number of processes
    # and RANK identifies each process globally.
    # For single-machine multi-GPU training, WORLD_SIZE is the number of GPUs and
    # RANK / LOCAL_RANK identify which GPU this process drives.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        # LOCAL_RANK is the GPU index on the current machine
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)  # pin this process to its GPU
    args.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
    # initialize the default process group; 'env://' reads MASTER_ADDR / MASTER_PORT,
    # RANK and WORLD_SIZE from the environment variables set by the launcher
    dist.init_process_group(backend=args.dist_backend, init_method='env://',
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()  # wait here until every process has reached this point
def main(args):
    if torch.cuda.is_available() is False:
        raise EnvironmentError("not find GPU device for training.")

    # initialize the distributed environment for each process
    init_distributed_mode(args=args)

    rank = args.rank
    device = torch.device(args.device)
    batch_size = args.batch_size
    num_classes = args.num_classes
    weights_path = args.weights
    args.lr *= args.world_size  # scale the learning rate by the number of parallel GPUs

    # give each rank (process) its own subset of the training sample indices
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data_set)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data_set)
    # group the sample indices into lists of batch_size elements
    train_batch_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size, drop_last=True)


    # number of DataLoader worker processes (nw is not defined in the excerpt;
    # this is a reasonable default)
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])

    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,  # pinned (page-locked) host memory speeds up host-to-GPU copies
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             sampler=val_sampler,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)
    # resnet34 is the model constructor used by the article (torchvision.models.resnet34 also works)
    model = resnet34(num_classes=num_classes).to(device)

    if os.path.exists(weights_path):
        # load pretrained weights, keeping only tensors whose shapes match the current model
        weights_dict = torch.load(weights_path, map_location=device)
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        # without pretrained weights, every process must still start from identical parameters:
        # rank 0 saves its randomly initialized weights and all other ranks load them
        checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
        dist.barrier()
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # SyncBatchNorm only makes sense for networks that contain BN layers
        if args.syncBN:
            # converting to SyncBatchNorm makes training slower
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # wrap the model with DistributedDataParallel
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # optimizer: SGD with a cosine annealing learning-rate schedule
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
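To see what this lambda does: at epoch 0 the multiplier is ((1 + cos 0) / 2) * (1 - lrf) + lrf = 1, so training starts at the full args.lr, and at the final epoch cos(pi) = -1 drives the multiplier down to args.lrf. A quick sanity check (the values for args.epochs and args.lrf are only illustrative):

import math

epochs, lrf = 100, 0.1  # stand-ins for args.epochs and args.lrf
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf

print(round(lf(0), 3))    # 1.0  -> start at the full base learning rate
print(round(lf(50), 3))   # 0.55 -> halfway through the cosine decay
print(round(lf(100), 3))  # 0.1  -> end at lrf times the base learning rate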
    for epoch in range(args.epochs):
        # reshuffle the data differently across processes every epoch
        # (without set_epoch, DistributedSampler reuses the same ordering)
        train_sampler.set_epoch(epoch)

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        sum_num = evaluate(model=model,
                           data_loader=val_loader,
                           device=device)
        acc = sum_num / val_sampler.total_size
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    # show the progress bar only in the main process (rank 0)
    if is_main_process():
        data_loader = tqdm(data_loader)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        loss = reduce_value(loss, average=True)  # average the loss across all GPUs for logging
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # running mean over the epoch

        if is_main_process():
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    # wait for all processes to finish
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    return mean_loss.item()


def reduce_value(value, average=True):
    world_size = get_world_size()
    if world_size < 2:  # single GPU: nothing to reduce
        return value

    with torch.no_grad():
        dist.all_reduce(value)  # sum the tensor across all processes
        if average:
            value /= world_size

    return value
@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # number of correctly predicted samples; each GPU counts its own share
    sum_num = torch.zeros(1).to(device)

    # show the validation progress bar only in the main process (rank 0)
    if is_main_process():
        data_loader = tqdm(data_loader)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    # wait for all processes to finish
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    sum_num = reduce_value(sum_num, average=False)  # total correct predictions across all GPUs

    return sum_num.item()
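The helpers is_main_process() and get_world_size() used above come from a small distributed_utils module that the article does not reproduce; a minimal sketch of what they need to do, assuming the process group set up by init_distributed_mode:

import torch.distributed as dist

def is_dist_avail_and_initialized():
    # guard so the same code also runs in plain single-GPU mode
    return dist.is_available() and dist.is_initialized()

def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()

def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()

def is_main_process():
    # rank 0 is treated as the main process (progress bars, logging, checkpoints)
    return get_rank() == 0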
        # back inside the epoch loop of main(): only rank 0 prints, writes TensorBoard
        # scalars (tb_writer is a SummaryWriter created on rank 0) and saves checkpoints
        if rank == 0:
            print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))

            tags = ["loss", "accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], acc, epoch)
            tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

            # save the underlying model's weights (model.module), not the DDP wrapper's
            torch.save(model.module.state_dict(), "./weights/model-{}.pth".format(epoch))
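Because the checkpoint stores model.module.state_dict() (the raw model without the DDP wrapper), it can later be loaded straight into an ordinary single-GPU model with no key renaming. A short usage sketch (the file name and num_classes are illustrative):

model = resnet34(num_classes=num_classes)
state_dict = torch.load("./weights/model-9.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()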
The excerpt ends with one final if rank == 0: block at the end of main(); the snippet is cut off there. What typically follows is cleanup, sketched below.
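A plausible completion of that final block, assuming the temporary checkpoint created earlier (not shown in the article's truncated line): rank 0 deletes the initial-weights file and the process group is shut down.

    if rank == 0:
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)  # remove the temporary initial weights saved for syncing

    dist.destroy_process_group()  # tear down the distributed process group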