PyTorch分布式訓練進階:這些細節(jié)你都注意到了嗎?

導(dǎo)語?|?pytorch作為目前主流的深度學習訓練框架之一,可以說是每個算法同學工作中的必備技能。此外,pytorch提供了極其方便的API用來進行分布式訓練,由于最近做的工作涉及到一些分布式訓練的細節(jié),在使用中發(fā)現(xiàn)一些之前完全不會care的點,現(xiàn)記錄于此,希望對有需求的同學有所幫助。
本文包含:
pytorch分布式訓練的工作原理介紹。
一些大家平時使用時可能不太注意的點,這些點并不會導(dǎo)致直觀的bug或者訓練中斷,但可能會導(dǎo)致訓練結(jié)果的偏差以及效率的降低。
同時結(jié)合某些場景,介紹更為細粒度(group)的分布式交互方式。
名詞解釋 :
DP: DataParallel
DDP:DistributedDataParaller
基于DDP的多機單卡模型
world_size:并行的節(jié)點數(shù)
rank:節(jié)點的index,從0開始
group_size:并行g(shù)roup的節(jié)點數(shù)
group_ws:group數(shù)量
group_rank:group的index,從0開始
local_group_rank:一個group內(nèi)部的節(jié)點index,從0開始
group_rank_base:一個group內(nèi)local_group_rank為0的節(jié)點的rank
舉例:
使用6節(jié)點,group_size=3,則group_ws=2則各個參數(shù)的對應(yīng)關(guān)系如下:

group 0的group_rank_base為0,group 1的group_rank_base為3。
一、DataParallel和DistributedDataParallel
pytorch提供了兩種分布式訓練的接口,DataParallel(單機多卡)和DistributedDataParallel(多機單卡,多機多卡)。
(一)DataParallel(DP)
先看下DataParallel的工作原理:

module:即要進行的并行的模型,為nn.Module子類實例
device_ids:需要進行并行的卡
output_device:模型最終輸出進行匯總的卡,默認是local_rank=0的卡(以下簡稱“卡0”)
以單機4卡為例,當接到一個batch size=128的數(shù)據(jù)時,卡0會將128的個數(shù)分成32*4,然后將模型拷貝到1~3卡,分別推理32個數(shù)據(jù)后,然后在output_device(默認為卡0)上進行輸出匯總,因為每次推理都會需要進行模型的拷貝,整體效率較低。
注意:
當使用DP的時候,會發(fā)現(xiàn)卡0的顯存占用會比其他的卡更多,原因便在于默認情況下,卡0需要進行輸出的匯總,如果模型的輸出是一個很大tensor,可能會導(dǎo)致卡0負載極其不均衡爆顯存,從而不得不降低整體的bs導(dǎo)致其他卡的顯存利用率低。
解決方案:
由于卡0進行輸出的匯總,因此我們可以把loss的求解放到模型內(nèi)部,這樣模型的輸出就是一個scalar,能夠極大的降低卡0匯總帶來的顯存負載。
(二)DistributedDataParallel(DDP)

其他的參數(shù)含義類似DP,這里重點說下:
broadcast_buffers:在每次調(diào)用forward之前是否進行buffer的同步,比如bn中的mean和var,如果你實現(xiàn)了自己的SyncBn可以設(shè)置為False。
find_unused_parameters:是否查找模型中未參與loss“生成”的參數(shù),簡單說就是模型中定義了一些參數(shù)但是沒用上時需要設(shè)置,后面會詳細介紹。
process_group:并行的group,默認的global group,后面細粒度分布式交互時會詳細介紹。
DistributedDataParallel的則很好的解決了DP推理效率低的問題,這里以多機單卡為例:DDP會在初始化時記錄模型的參數(shù)和buffer等相關(guān)信息,然后進行一次參數(shù)和buffer的同步,這樣在每次迭代時,只需要進行梯度的平均就能保證參數(shù)和buffer在不同的機器上完全一致。
多機多卡情況下,在一個機器內(nèi)部的工作原理和DP一致,這也是為什么torch官方會說多機單卡是效率最高的方式。
目前主要使用DDP的多機單卡模式進行分布式訓練,后文都將基于該設(shè)置進行介紹。
DDP訓練中需要注意的點:
由于DDP在初始化會遍歷模型獲取所有需要進行同步操作的參數(shù)和buffer并記錄,因此,一旦初始化了DDP就不要再對內(nèi)部模型的參數(shù)或者buffer進行增刪,否則會導(dǎo)致新增的參數(shù)或buffer無法被優(yōu)化,但是訓練不會報錯。
如果你是做類似NAS這種需要進行子圖推理的任務(wù)或者模型定義了未使用參數(shù),則必須設(shè)置find_unused_parameters為True,否則設(shè)置為False。如果是后者,請檢查模型刪除無用的參數(shù),find_unused_parameterss設(shè)置為True時會有額外的開銷。
buffer是在forward前進行同步的,所以其實訓練最后一個iter結(jié)束時,不同卡上的buffer是不一樣的(雖然這個差異很小),如果需要完全一致,可以手動調(diào)用DDP._sync_params_and_buffers()
類似NAS這種動態(tài)子圖,且你的優(yōu)化器設(shè)置了momentum等除了grad以外其他需要參與梯度更新的參數(shù)時需要特別注意:在pytorch中,required_grad=False的參數(shù)在進行參數(shù)更新的時候,grad為None,所以torch中優(yōu)化器的step中有一個p.grad is not None的判斷用來跳過這些參數(shù):
for group in self.param_groups:....for p in group['params']:if p.grad is not None:params_with_grad.append(p)d_p_list.append(p.grad)state = self.state[p]if 'momentum_buffer' not in state:momentum_buffer_list.append(None)else:momentum_buffer_list.append(state['momentum_buffer'])....
正常訓練沒有任何問題,但是使用動態(tài)子圖時,即使對當前iter沒有優(yōu)化的子圖的參數(shù)設(shè)置required_grad=False,如果該子圖之前曾經(jīng)被優(yōu)化過,則它的grad會變成全0而不是None。例如有兩個子圖A和B,優(yōu)化順序為A->B->A:1.第一次優(yōu)化A時,B的grad為None,一切正常;2.第一個優(yōu)化B時,由于A已經(jīng)被優(yōu)化過,此時A的grad為0,優(yōu)化器的判斷無法過濾到該參數(shù),因此會沿著第一次優(yōu)化A時的buffer(如momentum)進行錯誤的優(yōu)化。如果子圖數(shù)量很多的話,某一個子圖可能會被錯誤的優(yōu)化成千上萬次。解決方案有兩個:一個是把優(yōu)化器中的
if p.grad is not None:改成
if p.grad is not None (p.grad == 0).all():或者在每次調(diào)用optim.step()之前,加一句:
for p in model.parameters():if p.grad is not None and (p.grad == 0).all():p.grad = None
DDP的梯度匯總使用的是avg,因此如果loss的計算使用的reduce_mean的話,我們不需要再對loss或者grad進行/ world_size的操作。
二、使用DDP時的數(shù)據(jù)讀取
DDP不同于DP需要用卡0進行數(shù)據(jù)分發(fā),它在每個node會有一個獨立的dataloader進行數(shù)據(jù)讀取,一般通過DistributedSampler(DS)來實現(xiàn):

DS會將range(len(dataset))的indices拆分成num_replicas(一般為word_size),不同rank的節(jié)點讀取不同的數(shù)據(jù)進行訓練,一個簡單的分布式訓練示例:
from torch import distributed as distfrom torch.utils.data.distributed import DistributedSamplerimport torch.utils.data as Dataassert torch.cuda.is_available()if not dist.is_initialized():dist.init_process_group(backend='nccl')rank = dist.get_rank()world_size = dist.get_world_size()model = MyModel().cuda()ddp_model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]).cuda()dataset = MyDataset()sampler = DistributedSampler(dataset, rank, world_size, shuffle=True)dataloader = Data.DataLoader(dataset, batch_size, drop_last=False, sampler=sampler, shuffle=False, num_workers=8, pin_memory=True)# training
注意:
如果你的模型使用了分布式評估:
評估需要用到所有測試數(shù)據(jù)的結(jié)果進行整體統(tǒng)計。
精度的計算和數(shù)據(jù)順序相關(guān),則你需要注意DS中:
初始化時會對數(shù)據(jù)進行padding,padding后的數(shù)量為:
real_data_num = int(math.ceil(len(dataset) * 1.0 / world_size)) * world_size因此直接評估可能會使得某些樣本被重復(fù)評估導(dǎo)致精度結(jié)果誤差,尤其是測試數(shù)據(jù)量不大,測試數(shù)據(jù)樣本之間難易程度差距較大時
slice的方式為等間距slice,step為world size,因此直接將不同rank的輸出拼接的話,順序和原始的datast并不一致。
注意:?
可以看到,上述代碼示例中DataLoader的pin_memory設(shè)置為True,torch會在返回數(shù)據(jù)前將數(shù)據(jù)直接放到CUDA的pinned memory里面,從而在訓練時避免從一次從cpu拷貝到gpu的開銷。但是只設(shè)置該參數(shù)不太會導(dǎo)致數(shù)據(jù)讀取速度變快,原因是該參數(shù)需要搭配使用,要將代碼中的數(shù)據(jù)拷貝由.cuda()變更為.cuda(non_blocking=True)
三、分布式訓練進階:Group
根據(jù)上述介紹,基本可以滿足常規(guī)的分布式訓練了。但是像諸如nas這種可能需要同時訓練多個網(wǎng)絡(luò)時,考慮到用戶的不同需求(子網(wǎng)絡(luò)可能需要并行,也可能并不需要并行),我們需要對分布式過程進行更加細粒度的控制,這種控制也可以讓我們能在數(shù)據(jù)讀取和通信開銷做trade off。
在torch的分布式api中基本都包含group(或process_group)這個參數(shù),只不過一般情況下不太需要關(guān)注。它的作用簡言之就是對分布式的節(jié)點數(shù)進行劃分成組,可以在組內(nèi)進行分布式通信的相關(guān)操作。初始化api如下:
ranks = [0,1,2,3]gp = dist.new_group(ranks, backend='nccl')
上述代碼會將節(jié)點[0,1,2,3]作為一個group,在后續(xù)的分布式操作(如:broadcast/reduce/gather/barrier)中,我們只需傳入group=gp參數(shù),就能控制該操作只會在[0,1,2,3]中進行而不會影響其他的節(jié)點。
注意:
在所有的節(jié)點上都需要進行所有g(shù)roup的初始化,而不是只初始化當前rank所屬的group,如使用12卡,group size設(shè)置為4,則12/4=3個group對應(yīng)的rank分別為[0,1,2,3][4,5,6,7][8,9,10,11],這12個節(jié)點都需要初始化三個group,而不是rank0,1,2,3只用初始化group0:
rank = dist.get_rank()group_ranks = [[0,1,2,3], [4,5,6,7],[8,9,10,11]]cur_gp = Nonefor g_ranks in group_ranks:gp = dist.new_groups(g_ranks)if rank in g_ranks:cur_gp = gp# 后續(xù)使用cur_gp即可
注意:
如果進行兼容性考慮的話,比如group_size=1或者group_size=world_size,此時不需要創(chuàng)建group,但是為了代碼的一致性,所有的分布式操作都需要傳入group參數(shù),需要注意的是新版本的torch,分布式op的group參數(shù)缺省值為None,當檢測到None會自動獲取group.WORLD,但是舊版本的缺省參數(shù)為group.WORLD,傳入None會報錯,可以嘗試做以下兼容(具體從哪個版本開始變更沒有嘗試過,以下僅為sample):
import torchfrom torch.distributed.distributed_c10d import _get_default_groupdef get_group(group_size, *args, **kwargs):rank = dist.get_rank()world_size = dist.get_world_size()if group_size == 1:# 后續(xù)不會涉及到分布式的操作return Noneelif group_size == world_size:v = float(torch.__version__.rsplit('.', 1)[0])if v >= 1.8:return Noneelse:return _get_default_group()else:# 返回當前rank對應(yīng)的group
(一)模型在group內(nèi)的并行
只需在DDP初始化的時候把gp賦值給process_group即可
(二)數(shù)據(jù)在group內(nèi)的讀取
使用帶group的DDP訓練時,數(shù)據(jù)讀取依舊使用DS,不同的是num_replicas和rank參數(shù)不再等于world_size和節(jié)點的真實rank,而要變更為group_size和local_group_rank(見名詞解釋部分)。這個也很好理解,舉個例子:
6卡,group_size為3,每個group內(nèi)有3個節(jié)點,模型在這3個節(jié)點上并行。
訓練該模型相應(yīng)的數(shù)據(jù)也應(yīng)只在這3個節(jié)點上進行,所以DS的num_replicas變更為group_size。
另外,DS中的rank參數(shù)決定了當前節(jié)點讀取哪些數(shù)據(jù)(用來進行indices劃分),因此,對于一個group內(nèi)部而言,該參數(shù)需要變更為當前節(jié)點在當前group的序號,即local_group_rank。

四、某些分布式訓練場景下IO瓶頸
這里只介紹多機單卡場景(即一個scheduler和多個worker,且scheduler和每個worker只有一張GPU),且針對某些對于小文件io密集型不太友好的文件系統(tǒng):
對應(yīng)數(shù)據(jù)集不大的,可以考慮做成lmdb或者運行時將數(shù)據(jù)拷貝到docker的路徑下。
數(shù)據(jù)集大,無法采用上述方案時,如果進行大規(guī)模分布式,io問題會更加嚴重:調(diào)度系統(tǒng)可能將worker映射到物理機上,可能導(dǎo)致多個worker都映射到同一臺物理機器,雖然設(shè)置的cpu核心和內(nèi)存,不同的node還是會進行資源搶占,導(dǎo)致速度變慢,為此需要進行數(shù)據(jù)分發(fā):
方式一:group0中的對應(yīng)節(jié)點進行數(shù)據(jù)讀取,然后分發(fā)到其他group的對應(yīng)節(jié)點上,即rank0,1,2各自讀取1/3的數(shù)據(jù),然后通過broadcast將數(shù)據(jù)廣播,rank0的數(shù)據(jù)廣播至rank3,rank1至rank4以此類推。
方式二:rank0的節(jié)點讀取所有數(shù)據(jù),然后在group0內(nèi)進行scatter,然后使用方式一broadcast到其他group。
采用方式一還是二取決于你的數(shù)據(jù)讀取開銷,如果group size很大,那么group0的資源搶占可能就很嚴重,導(dǎo)致速度降低,如果只有rank0進行數(shù)據(jù)讀取的話,雖然不會存在資源搶占(gemini的scheduler不會和worker映射到同一臺機器),但是bs會增大可能會導(dǎo)致讀取變慢。
在gpu正常的情況下,數(shù)據(jù)broadcast的開銷相對較小。
注意:?
使用數(shù)據(jù)broadcast自然需要dataset返回的所有數(shù)據(jù)均是tensor,meta信息諸如字符串類型的數(shù)據(jù)無法broadcast。
進行數(shù)據(jù)broadcast時需要新建一系列的data group,因為它的維度和模型并行的維度不一樣,模型是在[0,1,2]和[3,4,5]上并行,數(shù)據(jù)是在0->3,1->4,2->5上broadcast,因此需要新建三個group[0,3][1,4][2,5]
broadcast自然需要知道數(shù)據(jù)維度,結(jié)合前面講到的DS補齊操作,注意每個epoch最后一個batch數(shù)據(jù)的bs可能不到設(shè)置的bs(drop_last=False時),因此broadcast需要進行額外的處理。
當不同的group之間代碼的邏輯可能不一樣時,使用broadcast需要額外注意,比如group0訓練1個網(wǎng)絡(luò),group1訓練2個網(wǎng)絡(luò),數(shù)據(jù)由group0進行broadcast,group0訓完第一個網(wǎng)絡(luò)就break,導(dǎo)致group1訓練第二個網(wǎng)絡(luò)時接受不到broadcast的數(shù)據(jù)而卡死。
?推薦閱讀
從0到1詳解ZooKeeper的應(yīng)用場景及架構(gòu)原理!

