干貨|Pytorch彈性訓(xùn)練極簡(jiǎn)實(shí)現(xiàn)( 附源碼)
點(diǎn)擊上方“視學(xué)算法”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
導(dǎo)讀
?作者將以往抽象的分布式訓(xùn)練的概念以代碼的形式展現(xiàn)出來(lái),并保證每個(gè)代碼可執(zhí)行、可驗(yàn)證、可復(fù)現(xiàn),并貢獻(xiàn)出來(lái)源碼讓大家相互交流。本例中會(huì)先在Node0上啟動(dòng)4 GPU的worker group ,等其訓(xùn)練一段時(shí)間后,會(huì)在Node1上再啟動(dòng)4 GPU的workers,并與Node1上的workers構(gòu)成一個(gè)新的worker group,最終構(gòu)成一個(gè)2機(jī)8卡的分布式訓(xùn)練。
由于工作需要,最近在補(bǔ)充分布式訓(xùn)練方面的知識(shí)。經(jīng)過(guò)一番理論學(xué)習(xí)后仍覺(jué)得意猶未盡,很多知識(shí)點(diǎn)無(wú)法準(zhǔn)確get到(例如:分布式原語(yǔ)scatter、all reduce等代碼層面應(yīng)該是什么樣的,ring all reduce 算法在梯度同步時(shí)是怎么使用的,parameter server參數(shù)是如何部分更新的)。
著名物理學(xué)家,諾貝爾獎(jiǎng)得主Richard Feynman辦公室的黑板上寫(xiě)了:"What I cannot create, I do not understand."。在程序員界也經(jīng)常有"show me the code"的口號(hào)。因此,我打算寫(xiě)一系列的分布式訓(xùn)練的文章,將以往抽象的分布式訓(xùn)練的概念以代碼的形式展現(xiàn)出來(lái),并保證每個(gè)代碼可執(zhí)行、可驗(yàn)證、可復(fù)現(xiàn),并貢獻(xiàn)出來(lái)源碼讓大家相互交流。
經(jīng)過(guò)調(diào)研發(fā)現(xiàn)pytorch對(duì)于分布式訓(xùn)練做好很好的抽象且接口完善,因此本系列文章將以pytorch為主要框架進(jìn)行,文章中的例子很多都來(lái)自pytorch的文檔,并在此基礎(chǔ)上進(jìn)行了調(diào)試和擴(kuò)充。
最后,由于分布式訓(xùn)練的理論介紹網(wǎng)絡(luò)上已經(jīng)很多了,理論部分的介紹不會(huì)是本系列文章的重點(diǎn),我會(huì)將重點(diǎn)放在代碼層面的介紹上面。
Pytorch - 分布式訓(xùn)練極簡(jiǎn)體驗(yàn):https://zhuanlan.zhihu.com/p/477073906
Pytorch - 分布式通信原語(yǔ)(附源碼):https://zhuanlan.zhihu.com/p/478953028
Pytorch - 手寫(xiě)allreduce分布式訓(xùn)練(附源碼):https://zhuanlan.zhihu.com/p/482557067
Pytorch - 算子間并行極簡(jiǎn)實(shí)現(xiàn)(附源碼):https://zhuanlan.zhihu.com/p/483640235
Pytorch - 多機(jī)多卡極簡(jiǎn)實(shí)現(xiàn)(附源碼):https://zhuanlan.zhihu.com/p/486130584
1. 介紹
Pytorch在1.9.0引入了torchrun,用其替代1.9.0以前版本的torch.distributed.launch。torchrun在torch.distributed.launch 功能的基礎(chǔ)上主要新增了兩個(gè)功能:
Failover: 當(dāng)worker訓(xùn)練失敗時(shí),會(huì)自動(dòng)重新啟動(dòng)所有worker繼續(xù)進(jìn)行訓(xùn)練; Elastic: 可以動(dòng)態(tài)增加或或刪除node節(jié)點(diǎn),本文將通過(guò)一個(gè)例子說(shuō)明Elastic Training應(yīng)該如何使用;
本例中會(huì)先在Node0上啟動(dòng)4 GPU的worker group ,等其訓(xùn)練一段時(shí)間后,會(huì)在Node1上再啟動(dòng)4 GPU的workers,并與Node1上的workers構(gòu)成一個(gè)新的worker group,最終構(gòu)成一個(gè)2機(jī)8卡的分布式訓(xùn)練。

2. 模型構(gòu)建
一個(gè)簡(jiǎn)單的全連接模型神經(jīng)網(wǎng)絡(luò)模型
class?ToyModel(nn.Module):
????def?__init__(self):
????????super(ToyModel,?self).__init__()
????????self.net1?=?nn.Linear(10,?10)
????????self.relu?=?nn.ReLU()
????????self.net2?=?nn.Linear(10,?5)
????def?forward(self,?x):
????????return?self.net2(self.relu(self.net1(x)))
3. checkpoint 處理
由于再每次增加或刪除node時(shí),會(huì)將所有worker kill掉,然后再重新啟動(dòng)所有worker進(jìn)行訓(xùn)練。因此,在訓(xùn)練代碼中要對(duì)訓(xùn)練的狀態(tài)進(jìn)行保存,以保證重啟后能接著上次的狀態(tài)繼續(xù)訓(xùn)練。
需要保存的信息一般有如下內(nèi)容:
model :模型的參數(shù)信息 optimizer :優(yōu)化器的參數(shù)信心 epoch:當(dāng)前執(zhí)行到第幾個(gè)epoch
save和load的代碼如下所示
torch.save:利用python的pickle將python的object 進(jìn)行序列化,并保存到本地文件;torch.load: 將torch.save后的本地文件進(jìn)行反序列化,并加載到內(nèi)存中;model.state_dict():存儲(chǔ)了model 每個(gè)layer和其對(duì)應(yīng)的param信息optimizer.state_dict():存儲(chǔ)了優(yōu)化器的參數(shù)信信息
def?save_checkpoint(epoch,?model,?optimizer,?path):
????torch.save({
????"epoch":?epoch,
????"model_state_dict":?model.state_dict(),
????"optimize_state_dict":?optimizer.state_dict(),
},?path)
def?load_checkpoint(path):
????checkpoint?=?torch.load(path)
????return?checkpoint
4. 訓(xùn)練代碼
初始化邏輯如下:
1~3行: 輸出當(dāng)前worker的關(guān)鍵環(huán)境變量,用于后面的結(jié)果展示 5~8行:創(chuàng)建模型、優(yōu)化器和損失函數(shù) 10~12行:初始化參數(shù)信息 14~19行:如果存在checkpoint,則加載checkpoint,并賦值給model、optimizer和firt_epoch
????local_rank?=?int(os.environ["LOCAL_RANK"])
????rank?=?int(os.environ["RANK"])
????print(f"[{os.getpid()}]?(rank?=?{rank},?local_rank?=?{local_rank})?train?worker?starting...")
????
????model?=?ToyModel().cuda(local_rank)
????ddp_model?=?DDP(model,?[local_rank])
????loss_fn?=?nn.MSELoss()
????optimizer?=?optim.SGD(ddp_model.parameters(),?lr=0.001)
????optimizer.zero_grad()
????max_epoch?=?100
????first_epoch?=?0
????ckp_path?=?"checkpoint.pt"
????
????if?os.path.exists(ckp_path):
????????print(f"load?checkpoint?from?{ckp_path}")
????????checkpoint?=?load_checkpoint(ckp_path)
????????model.load_state_dict(checkpoint["model_state_dict"])
????????optimizer.load_state_dict(checkpoint["optimize_state_dict"])
????????first_epoch?=?checkpoint["epoch"]
訓(xùn)練邏輯:
1行:epoch執(zhí)行的次數(shù)為first_epoch到max_epoch,以便能夠在worker被重啟后繼續(xù)原有的epoch繼續(xù)訓(xùn)練; 2行:為了展示動(dòng)態(tài)添加node效果,這里添加sleep函數(shù)來(lái)降低訓(xùn)練的速度; 3~8行:模型訓(xùn)練流程; 9行:為了簡(jiǎn)單,文本每個(gè)epoch進(jìn)行一次checkpoint保存;將當(dāng)前的epoch,model和optimizer保存到checkpoint中;
????for?i?in?range(first_epoch,?max_epoch):
????????time.sleep(1)?#?為了展示動(dòng)態(tài)添加node效果,這里添加sleep函數(shù)來(lái)降低訓(xùn)練的速度
????????outputs?=?ddp_model(torch.randn(20,?10).to(local_rank))
????????labels?=?torch.randn(20,?5).to(local_rank)
????????loss?=?loss_fn(outputs,?labels)
????????loss.backward()
????????print(f"[{os.getpid()}]?epoch?{i}?(rank?=?{rank},?local_rank?=?{local_rank})?loss?=?{loss.item()}\n")
????????optimizer.step()
????????save_checkpoint(i,?model,?optimizer,?ckp_path)
5. 啟動(dòng)方式
由于我們使用torchrun來(lái)啟動(dòng)多機(jī)多卡任務(wù),無(wú)需使用spawn接口來(lái)啟動(dòng)多個(gè)進(jìn)程(torchrun會(huì)負(fù)責(zé)將我們的python script啟動(dòng)為一個(gè)process),因此直接調(diào)用上文編寫(xiě)的train函數(shù),并在前后分別添加DistributedDataParallel的初始化和效果函數(shù)即可。
下面代碼描述了上文train接口的調(diào)用。
def?run():
????env_dict?=?{
????????key:?os.environ[key]
????????for?key?in?("MASTER_ADDR",?"MASTER_PORT",?"WORLD_SIZE",?"LOCAL_WORLD_SIZE")
????}
????print(f"[{os.getpid()}]?Initializing?process?group?with:?{env_dict}")
????dist.init_process_group(backend="nccl")
????train()
????dist.destroy_process_group()
if?__name__?==?"__main__":
????run()
本例中使用torchrun來(lái)執(zhí)行多機(jī)多卡的分布式訓(xùn)練任務(wù)(注:torch.distributed.launch 已經(jīng)被pytorch淘汰了,盡量不要再使用)。啟動(dòng)腳本描述如下(注:node0和node1均通過(guò)該腳本進(jìn)行啟動(dòng))
--nnodes=1:3:表示當(dāng)前訓(xùn)練任務(wù)接受最少1個(gè)node,最多3個(gè)node參與分布式訓(xùn)練;--nproc_per_node=4:表示每個(gè)node上節(jié)點(diǎn)有4個(gè)process--max_restarts=3: worker group最大的重啟次數(shù);這里需要注意的是,node fail、node scale down和node scale up都會(huì)導(dǎo)致restart;--rdzv_id=1:一個(gè)unique的job id,所有node均使用同一個(gè)job id;--rdzv_backend: rendezvous的backend實(shí)現(xiàn),默認(rèn)支持c10d和etcd兩種;rendezvous用于多個(gè)node之間的通信和協(xié)調(diào);--rdzv_endpoint:rendezvous的地址,應(yīng)該為一個(gè)node的host ip和port;
torchrun \
--nnodes=1:3\
--nproc_per_node=4\
--max_restarts=3\
--rdzv_id=1\
--rdzv_backend=c10d\
--rdzv_endpoint="192.0.0.1:1234"\
train_elastic.py
6. 結(jié)果分析
代碼:BetterDL - train_elastic.py:https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py
運(yùn)行環(huán)境: 2臺(tái)4卡 v100機(jī)器
image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
gpu: v100
先在node0上執(zhí)行執(zhí)行啟動(dòng)腳本
torchrun \
--nnodes=1:3\
--nproc_per_node=4\
--max_restarts=3\
--rdzv_id=1\
--rdzv_backend=c10d\
--rdzv_endpoint="192.0.0.1:1234"\
train_elastic.py
得到如下結(jié)果
2~5行:當(dāng)前啟動(dòng)的是單機(jī)4卡的訓(xùn)練任務(wù),因此WORLD_SIZE為4, LOCAL_WORKD_SIZE也為4 6~9行:共有4個(gè)rank參與了分布式訓(xùn)練,rank0~rank3 10~18行: rank0~rank3 均從epoch=0開(kāi)始訓(xùn)練
r/workspace/DDP# sh run_elastic.sh
[4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
[4029] (rank = 0, local_rank = 0) train worker starting...
[4030] (rank = 1, local_rank = 1) train worker starting...
[4032] (rank = 3, local_rank = 3) train worker starting...
[4031] (rank = 2, local_rank = 2) train worker starting...
[4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
[4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
[4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
[4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
[4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
[4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
[4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
[4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
...
在node1上執(zhí)行與上面相同的腳本
torchrun \
--nnodes=1:3\
--nproc_per_node=4\
--max_restarts=3\
--rdzv_id=1\
--rdzv_backend=c10d\
--rdzv_endpoint="192.0.0.1:1234"\
train_elastic.py
node1上結(jié)果如下:
2~5行:由于添加node1,當(dāng)前執(zhí)行的是2機(jī)8卡的分布式訓(xùn)練任務(wù),因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4 6~9行:當(dāng)前node1上workers的rank為rank4 ~rank7 13~20行: 由于node1是在node0上work訓(xùn)練到epoch35的時(shí)候加入的,因此其接著epoch 35開(kāi)始訓(xùn)練
/workspace/DDP# sh run_elastic.sh
[696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[697] (rank = 7, local_rank = 3) train worker starting...
[695] (rank = 5, local_rank = 1) train worker starting...
[694] (rank = 4, local_rank = 0) train worker starting...
[696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.ptload checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
[694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
[695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
[696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
[696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
[694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
[695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
[697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
...
node0上結(jié)果如下:
6~9行: node0上的works在執(zhí)行到epoch 35時(shí),node1上執(zhí)行了訓(xùn)練腳本,請(qǐng)求加入到訓(xùn)練任務(wù)中 10~13行:所有workers重新啟動(dòng),由于添加了node1,當(dāng)前執(zhí)行的是2機(jī)8卡的分布式訓(xùn)練任務(wù),因此WORLD_SIZE=8, LOCAL_WORLD_SIZE=4 14~17行:當(dāng)前node1上works的rank為rank0~rank3 18~21行:加載checkpoint 22~30行:接著checkpoint中的model、optimizer和epoch繼續(xù)訓(xùn)練
...
[4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
[4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
[4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
[4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
[4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
[4162] (rank = 0, local_rank = 0) train worker starting...
[4163] (rank = 1, local_rank = 1) train worker starting...
[4164] (rank = 2, local_rank = 2) train worker starting...
[4165] (rank = 3, local_rank = 3) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
[4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
[4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
[4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
[4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
[4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
[4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
[4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
[4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
...
點(diǎn)個(gè)在看 paper不斷!
