A Summary of Multi-GPU Parallel Training (Taking PyTorch as an Example)

Why use multi-GPU parallel training?
Common multi-GPU training methods


How are the error gradients communicated between devices?
At the end of each training step, the loss gradients computed on every GPU are averaged, rather than each GPU keeping only its own independently computed gradients.
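As a minimal sketch of that idea (assuming a process group has already been initialized, as in the init_distributed_mode function shown later), averaging the gradients by hand would look like this; in practice DistributedDataParallel performs this all-reduce automatically during backward():

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module):
    # Sum each parameter's gradient over all processes, then divide by the number
    # of processes so that every GPU ends up with the same averaged gradient.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size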
How is BN (Batch Normalization) synchronized across devices?
With synchronized BatchNorm (SyncBatchNorm): during training, the BN mean and variance are computed over the samples on all GPUs instead of only the local mini-batch.


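A minimal sketch of enabling synchronized BatchNorm; the helper name to_sync_bn_ddp and the local_rank argument are illustrative, and the conversion must happen before the model is wrapped with DDP (the same torch.nn.SyncBatchNorm.convert_sync_batchnorm call appears in the full script below):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def to_sync_bn_ddp(model: nn.Module, local_rank: int) -> nn.Module:
    # Replace every BatchNorm layer with SyncBatchNorm so that the mean/variance
    # statistics are computed over the samples on all GPUs, then wrap with DDP.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model).to(f'cuda:{local_rank}')
    return DDP(model, device_ids=[local_rank])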
Two GPU training approaches
DataParallel and DistributedDataParallel
DataParallel is single-process, multi-threaded, and only works on a single machine. DistributedDataParallel is multi-process and works on a single machine or across multiple machines.
DataParallel is usually slower than DistributedDataParallel, so DistributedDataParallel is currently the mainstream approach.
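For comparison, a minimal sketch of the two wrappers (the training script below uses only the DDP form). nn.DataParallel needs a single line in a single process, while DistributedDataParallel expects one process per GPU with a process group already initialized; the local_rank variable here is illustrative.

import torch
import torch.nn as nn

# a toy model, assuming at least one visible GPU
model = nn.Linear(10, 2).cuda()

# DataParallel: single process, multiple threads, single machine only;
# each input batch is split across all visible GPUs
dp_model = nn.DataParallel(model)

# DistributedDataParallel: one process per GPU, single or multiple machines;
# requires torch.distributed.init_process_group(...) to have been called first
# ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])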
Common ways to launch GPU training in PyTorch

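Two common ways of starting one process per GPU are: (1) a command-line launcher such as python -m torch.distributed.launch --nproc_per_node=8 train.py (or the newer torchrun), which sets the RANK / WORLD_SIZE / LOCAL_RANK environment variables that the init_distributed_mode function below reads; and (2) spawning the processes yourself with torch.multiprocessing.spawn. A minimal sketch of the spawn style, with illustrative function and argument names:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(local_rank: int, world_size: int):
    # Each spawned process joins the process group as one member.
    dist.init_process_group(backend='nccl',
                            init_method='tcp://127.0.0.1:23456',
                            world_size=world_size,
                            rank=local_rank)
    torch.cuda.set_device(local_rank)
    # ... build the model / data and run the training loop here ...
    dist.destroy_process_group()

if __name__ == '__main__':
    n_gpus = torch.cuda.device_count()
    # spawn launches n_gpus processes, each calling worker(rank, n_gpus)
    mp.spawn(worker, args=(n_gpus,), nprocs=n_gpus)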
The full training script is shown below; it assumes the helpers resnet34, is_main_process, and get_world_size from the project's own modules.

import os
import sys
import math
import tempfile

import torch
import torch.distributed as dist
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm

def init_distributed_mode(args):
    # For multi-machine multi-GPU training, WORLD_SIZE is the total number of processes
    # and RANK identifies the current process; for single-machine multi-GPU training,
    # WORLD_SIZE equals the number of GPUs and RANK / LOCAL_RANK identify which GPU is used.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        # LOCAL_RANK is the index of the GPU on the current machine
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True
    torch.cuda.set_device(args.gpu)  # bind this process to its GPU
    args.dist_backend = 'nccl'  # communication backend; NCCL is recommended for NVIDIA GPUs
    # initialize the default process group (args.dist_url is typically 'env://')
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    dist.barrier()  # wait until every process has reached this point before continuing
def main(args):
    if not torch.cuda.is_available():
        raise EnvironmentError("No GPU device found for training.")

    # initialize the distributed environment for this process
    init_distributed_mode(args=args)

    rank = args.rank
    device = torch.device(args.device)
    batch_size = args.batch_size
    num_classes = args.num_classes
    weights_path = args.weights
    args.lr *= args.world_size  # scale the learning rate with the number of parallel GPUs

    # train_data_set / val_data_set (Dataset objects) and nw (the number of DataLoader
    # workers) are assumed to be created here; the dataset code is omitted.
    # give each rank (process) its own subset of the training sample indices
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data_set)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_data_set)
    # group the sample indices into lists of batch_size elements
    train_batch_sampler = torch.utils.data.BatchSampler(
        train_sampler, batch_size, drop_last=True)


    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_sampler=train_batch_sampler,
                                               pin_memory=True,  # pinned (page-locked) host memory speeds up CPU-to-GPU transfers
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             sampler=val_sampler,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_data_set.collate_fn)
    # instantiate the model
    model = resnet34(num_classes=num_classes).to(device)

    if os.path.exists(weights_path):
        # load pretrained weights if they exist
        weights_dict = torch.load(weights_path, map_location=device)
        # keep only the layers whose parameter counts match the current model
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if model.state_dict()[k].numel() == v.numel()}
        model.load_state_dict(load_weights_dict, strict=False)
    else:
        checkpoint_path = os.path.join(tempfile.gettempdir(), "initial_weights.pt")
        # without pretrained weights, save the weights of process 0 and let the other
        # processes load them, so that every process starts from the same initialization
        if rank == 0:
            torch.save(model.state_dict(), checkpoint_path)
        dist.barrier()
        # note: map_location must be specified, otherwise the first GPU will use extra memory
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)
    else:
        # SyncBatchNorm is only meaningful when the network actually contains BN layers
        if args.syncBN:
            # training is slower with SyncBatchNorm enabled
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)

    # wrap the model with DistributedDataParallel (DDP)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])

    # optimizer: SGD with a cosine annealing learning-rate schedule
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)  # reshuffle so each epoch splits the data differently across GPUs

        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        sum_num = evaluate(model=model,
                           data_loader=val_loader,
                           device=device)
        acc = sum_num / val_sampler.total_size
The helper functions train_one_epoch, evaluate, and reduce_value used above are defined as follows:

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    # show the training progress bar only in process 0
    if is_main_process():
        data_loader = tqdm(data_loader)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        loss = reduce_value(loss, average=True)  # no-op on a single GPU; with multiple GPUs it returns the mean loss across all of them
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update the running mean loss

        # print the mean loss only in process 0
        if is_main_process():
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    # wait until every process has finished its computation
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    return mean_loss.item()


def reduce_value(value, average=True):
    world_size = get_world_size()
    if world_size < 2:  # single-GPU case
        return value

    with torch.no_grad():
        dist.all_reduce(value)  # sum the value over all devices
        if average:  # divide by world_size to get the mean across GPUs
            value /= world_size

        return value


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # number of correctly predicted samples; each GPU counts its own share
    sum_num = torch.zeros(1).to(device)

    # show the validation progress bar only in process 0
    if is_main_process():
        data_loader = tqdm(data_loader)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    # wait until every process has finished its computation
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    sum_num = reduce_value(sum_num, average=False)  # total number of correct predictions over all GPUs

    return sum_num.item()
Back inside the epoch loop of main(), process 0 prints the accuracy, writes the TensorBoard scalars (tb_writer is assumed to be a SummaryWriter created on rank 0), and saves the weights; after the loop, the temporary checkpoint file is removed and the process group is destroyed:

        if rank == 0:
            print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))

            tags = ["loss", "accuracy", "learning_rate"]
            tb_writer.add_scalar(tags[0], mean_loss, epoch)
            tb_writer.add_scalar(tags[1], acc, epoch)
            tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

            # save only the underlying model's weights (model.module), not the DDP wrapper
            torch.save(model.module.state_dict(), "./weights/model-{}.pth".format(epoch))

    # delete the temporary checkpoint file created for weight synchronization
    if rank == 0:
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)

    dist.destroy_process_group()  # destroy the process group and release its resources
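The script above reads all of its hyperparameters from args. A minimal sketch of a matching entry point; the argument names are inferred from the attributes used above (args.device, args.syncBN, args.lrf, ...) and the defaults are illustrative:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--lrf', type=float, default=0.1)   # final learning-rate factor for the cosine schedule
    parser.add_argument('--syncBN', type=bool, default=True)
    parser.add_argument('--freeze_layers', type=bool, default=False)
    parser.add_argument('--weights', type=str, default='')  # path to pretrained weights, empty to skip
    parser.add_argument('--device', default='cuda')
    # filled in by init_distributed_mode / the launcher
    parser.add_argument('--world_size', type=int, default=4)
    parser.add_argument('--dist_url', type=str, default='env://')
    args = parser.parse_args()
    main(args)

With the launcher style described earlier, this would be started with something like torchrun --nproc_per_node=8 train_multi_gpu.py (the file name is illustrative).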