Hands-on | A Minimal Implementation of PyTorch Elastic Training (with Source Code)


          2022-04-25 19:03


Author: 顏挺帥 @ Zhihu (republished with permission)
Source: https://zhuanlan.zhihu.com/p/489892744
Editor: 極市平臺(tái)

Overview

The author turns the usually abstract concepts of distributed training into code, keeping every example executable, verifiable, and reproducible, and shares the source so readers can exchange ideas. In this example, a worker group of 4 GPUs is first started on Node0; after it has trained for a while, 4 more GPU workers are started on Node1, which together with the workers on Node0 form a new worker group, turning the job into a 2-node, 8-GPU distributed training run.

Due to the needs of my work, I have recently been filling in my knowledge of distributed training. After a round of theory I still felt something was missing: many points were hard to really pin down (for example, what the distributed primitives scatter and all-reduce look like at the code level, how the ring all-reduce algorithm is used during gradient synchronization, and how a parameter server performs partial parameter updates).

The blackboard in the office of Richard Feynman, the famous physicist and Nobel laureate, read: "What I cannot create, I do not understand." Programmers have a similar saying: "show me the code." So I decided to write a series of articles on distributed training that turn its usually abstract concepts into code, keeping every example executable, verifiable, and reproducible, and to share the source so readers can exchange ideas.

After some research I found that PyTorch provides a well-designed abstraction and complete interfaces for distributed training, so this series uses PyTorch as its main framework. Many of the examples come from the PyTorch documentation and were debugged and extended on top of it.

Finally, since plenty of introductions to the theory of distributed training can already be found online, theory is not the focus of this series; I will concentrate on the code. The previous articles in the series are:

PyTorch - A Minimal Hands-on Introduction to Distributed Training: https://zhuanlan.zhihu.com/p/477073906

PyTorch - Distributed Communication Primitives (with source code): https://zhuanlan.zhihu.com/p/478953028

PyTorch - Hand-written All-Reduce Distributed Training (with source code): https://zhuanlan.zhihu.com/p/482557067

PyTorch - A Minimal Implementation of Inter-operator Parallelism (with source code): https://zhuanlan.zhihu.com/p/483640235

PyTorch - A Minimal Multi-node Multi-GPU Example (with source code): https://zhuanlan.zhihu.com/p/486130584

1. Introduction

PyTorch 1.9.0 introduced torchrun as a replacement for the torch.distributed.launch used in earlier versions. On top of torch.distributed.launch's functionality, torchrun mainly adds two features:

• Failover: when a worker fails during training, all workers are automatically restarted and training continues;
• Elastic: nodes can be added to or removed from the job dynamically. This article uses an example to show how elastic training is used.

In this example, a worker group of 4 GPUs is first started on Node0; after it has trained for a while, 4 more GPU workers are started on Node1, which together with the workers on Node0 form a new worker group, turning the job into a 2-node, 8-GPU distributed training run.

2. Model definition

We use a simple fully connected neural network model:

import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))
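
As a quick sanity check, the model maps 10-dimensional inputs to 5-dimensional outputs, which matches the shapes of the random inputs and labels used in the training loop later (a small illustrative snippet, not part of the original script):

import torch

model = ToyModel()
out = model(torch.randn(20, 10))  # a batch of 20 samples with 10 features each
print(out.shape)                  # torch.Size([20, 5])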

3. Checkpoint handling

Whenever a node is added or removed, all workers are killed and then restarted to continue training. The training code therefore has to save its state, so that after a restart it can pick up from where it left off.

The information that typically needs to be saved is:

• model: the model parameters
• optimizer: the optimizer state
• epoch: which epoch training has currently reached

The save and load code is shown below:

• torch.save: serializes a Python object with Python's pickle and writes it to a local file;
• torch.load: deserializes a file produced by torch.save and loads it back into memory;
• model.state_dict(): holds every layer of the model together with its parameters;
• optimizer.state_dict(): holds the optimizer's state and hyperparameters.

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path):
    checkpoint = torch.load(path)
    return checkpoint
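
One caveat worth noting: during an elastic event the workers are killed with SIGTERM (see the logs in Section 6), so a worker could in principle be interrupted in the middle of torch.save and leave a truncated checkpoint.pt behind. A possible hedge, not used in the example above, is to write to a temporary file and rename it atomically; a minimal sketch:

import os
import torch

def save_checkpoint_atomic(epoch, model, optimizer, path):
    tmp_path = path + ".tmp"  # temporary file next to the real checkpoint
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimize_state_dict": optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on the same filesystem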

4. Training code

The initialization logic is as follows:

• Lines 1–3: print this worker's key environment variables, which are used later to interpret the output
• Lines 5–8: create the model, wrap it in DDP, and create the loss function and optimizer
• Lines 10–12: initialize the training parameters
• Lines 14–19: if a checkpoint exists, load it and restore model, optimizer, and first_epoch from it

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) train worker starting...")

    model = ToyModel().cuda(local_rank)
    ddp_model = DDP(model, [local_rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    optimizer.zero_grad()
    max_epoch = 100
    first_epoch = 0
    ckp_path = "checkpoint.pt"

    if os.path.exists(ckp_path):
        print(f"load checkpoint from {ckp_path}")
        checkpoint = load_checkpoint(ckp_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimize_state_dict"])
        first_epoch = checkpoint["epoch"]
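
Another detail to keep in mind: by default torch.load restores tensors to the device they were saved from, so when a checkpoint written by one GPU is read by workers on other GPUs (or on a newly joined node), it is safer to pass map_location explicitly. A minimal sketch of such a variant of load_checkpoint (load_checkpoint_to is a name made up for illustration):

def load_checkpoint_to(path, local_rank):
    # map all tensors onto this worker's GPU, regardless of the device they were saved from
    return torch.load(path, map_location=f"cuda:{local_rank}")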

The training loop:

• Line 1: epochs run from first_epoch to max_epoch, so that after a worker restart training continues from the original epoch;
• Line 2: a sleep is added to slow training down, so that the effect of dynamically adding a node is easier to see;
• Lines 3–8: the actual model training steps;
• Line 9: for simplicity, a checkpoint is saved once per epoch, recording the current epoch, model, and optimizer.

    for i in range(first_epoch, max_epoch):
        time.sleep(1)  # slow training down so the effect of dynamically adding a node is visible
        outputs = ddp_model(torch.randn(20, 10).to(local_rank))
        labels = torch.randn(20, 5).to(local_rank)
        loss = loss_fn(outputs, labels)
        loss.backward()
        print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")
        optimizer.step()
        save_checkpoint(i, model, optimizer, ckp_path)
        optimizer.zero_grad()  # clear gradients so they do not accumulate across epochs
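
In this loop every rank writes the same checkpoint.pt. That keeps the example simple, but it does redundant I/O, and concurrent writes to one file can interleave. A common refinement, not used here and assuming checkpoint.pt sits on storage visible to all nodes, is to save only from rank 0 and synchronize the other ranks before they move on; a sketch of how the checkpoint step could look:

        if rank == 0:
            save_checkpoint(i, model, optimizer, ckp_path)
        dist.barrier()  # wait until the checkpoint is on disk before any rank continues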

5. Launching

Because torchrun is used to launch the multi-node, multi-GPU job, there is no need to start multiple processes through the spawn API (torchrun launches our Python script as a process itself). We simply call the train function written above and wrap it with process-group initialization before it and teardown after it (dist.init_process_group / dist.destroy_process_group).

The code below shows how the train function is invoked:

def run():
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
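
For reference, here is a minimal sketch of the imports the snippets above assume (the authoritative version is in the train_elastic.py linked in Section 6), together with one optional addition: with the NCCL backend it is common practice to pin each process to its GPU via torch.cuda.set_device before creating the process group, which the original run() does not do:

import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def run():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # bind this process to its own GPU
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()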

This example uses torchrun to run the multi-node, multi-GPU distributed training job (note: torch.distributed.launch has been deprecated by PyTorch, so avoid using it). The launch script is described below (note: both node0 and node1 are started with this script).

• --nnodes=1:3: the job accepts at least 1 and at most 3 nodes participating in distributed training;
• --nproc_per_node=4: each node runs 4 processes;
• --max_restarts=3: the maximum number of worker-group restarts; note that node failure, node scale-down, and node scale-up all trigger a restart;
• --rdzv_id=1: a unique job id, shared by all nodes;
• --rdzv_backend: the rendezvous backend implementation, with c10d and etcd supported out of the box; rendezvous handles communication and coordination between the nodes;
• --rdzv_endpoint: the rendezvous address, the host IP and port of one of the nodes.

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py
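
Before involving a second machine, it can help to smoke-test the script on a single node (a usage note, not part of the original walkthrough); torchrun's --standalone flag starts a local rendezvous so no --rdzv_endpoint is needed:

torchrun --standalone --nproc_per_node=4 train_elastic.py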

6. Results

Code: BetterDL - train_elastic.py: https://github.com/tingshua-yts/BetterDL/blob/master/test/pytorch/DDP/train_elastic.py

Environment: two machines, each with 4 V100 GPUs

          image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime

          gpu: v100

First, run the launch script on node0:

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

This produces the following output:

• Lines 2–5: at this point this is a single-node, 4-GPU job, so WORLD_SIZE is 4 and LOCAL_WORLD_SIZE is also 4
• Lines 6–9: four ranks take part in the distributed training, rank0 through rank3
• Lines 10–18: rank0 through rank3 all start training from epoch 0

/workspace/DDP# sh run_elastic.sh
          [4031] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
          [4029] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
          [4030] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
          [4032] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '44901', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '4'}
          [4029] (rank = 0, local_rank = 0) train worker starting...
          [4030] (rank = 1, local_rank = 1) train worker starting...
          [4032] (rank = 3, local_rank = 3) train worker starting...
          [4031] (rank = 2, local_rank = 2) train worker starting...
          [4101] epoch 0 (rank = 1, local_rank = 1) loss = 0.9288564920425415
          [4103] epoch 0 (rank = 3, local_rank = 3) loss = 0.9711472988128662
          [4102] epoch 0 (rank = 2, local_rank = 2) loss = 1.0727070569992065
          [4100] epoch 0 (rank = 0, local_rank = 0) loss = 0.9402943253517151
          [4100] epoch 1 (rank = 0, local_rank = 0) loss = 1.0327017307281494
          [4101] epoch 1 (rank = 1, local_rank = 1) loss = 1.4485043287277222
          [4103] epoch 1 (rank = 3, local_rank = 3) loss = 1.0959293842315674
          [4102] epoch 1 (rank = 2, local_rank = 2) loss = 1.0669530630111694
          ...

Run the same script on node1:

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=4 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.0.0.1:1234" \
    train_elastic.py

The output on node1 is as follows:

• Lines 2–5: with node1 added, the job is now a 2-node, 8-GPU distributed run, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4
• Lines 6–9: the workers on node1 are assigned rank4 through rank7
• Lines 10–13: the existing checkpoint is loaded
• Lines 14–21: node1 joined while the workers on node0 had trained up to epoch 35, so it resumes training from epoch 35

          /workspace/DDP# sh run_elastic.sh
          [696] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [697] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [695] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [694] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [697] (rank = 7, local_rank = 3) train worker starting...
          [695] (rank = 5, local_rank = 1) train worker starting...
          [694] (rank = 4, local_rank = 0) train worker starting...
          [696] (rank = 6, local_rank = 2) train worker starting...
load checkpoint from checkpoint.pt
load checkpoint from checkpoint.pt
          load checkpoint from checkpoint.pt
          load checkpoint from checkpoint.pt
          [697] epoch 35 (rank = 7, local_rank = 3) loss = 1.1888569593429565
          [694] epoch 35 (rank = 4, local_rank = 0) loss = 0.8916441202163696
          [695] epoch 35 (rank = 5, local_rank = 1) loss = 1.5685604810714722
          [696] epoch 35 (rank = 6, local_rank = 2) loss = 1.11683189868927
          [696] epoch 36 (rank = 6, local_rank = 2) loss = 1.3724170923233032
          [694] epoch 36 (rank = 4, local_rank = 0) loss = 1.061527967453003
          [695] epoch 36 (rank = 5, local_rank = 1) loss = 0.96876460313797
          [697] epoch 36 (rank = 7, local_rank = 3) loss = 0.8060566782951355
          ...

The output on node0 is as follows:

• Lines 6–9: while the workers on node0 were at epoch 35 (lines 2–5), the training script was run on node1 and asked to join the job, so the existing workers are sent SIGTERM
• Lines 10–13: all workers are restarted; with node1 added, the job is now a 2-node, 8-GPU distributed run, so WORLD_SIZE=8 and LOCAL_WORLD_SIZE=4
• Lines 14–17: the workers on node0 keep ranks rank0 through rank3
• Lines 18–21: the checkpoint is loaded
• Lines 22–30: training continues from the model, optimizer, and epoch stored in the checkpoint

          ...
          [4100] epoch 35 (rank = 0, local_rank = 0) loss = 1.0746158361434937
          [4101] epoch 35 (rank = 1, local_rank = 1) loss = 1.1712706089019775
          [4103] epoch 35 (rank = 3, local_rank = 3) loss = 1.1774182319641113
          [4102] epoch 35 (rank = 2, local_rank = 2) loss = 1.0898035764694214
          WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4100 closing signal SIGTERM
          WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4101 closing signal SIGTERM
          WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4102 closing signal SIGTERM
          WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 4103 closing signal SIGTERM
          [4164] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [4165] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [4162] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [4163] Initializing process group with: {'MASTER_ADDR': '192.0.0.1', 'MASTER_PORT': '42913', 'WORLD_SIZE': '8', 'LOCAL_WORLD_SIZE': '4'}
          [4162] (rank = 0, local_rank = 0) train worker starting...
          [4163] (rank = 1, local_rank = 1) train worker starting...
          [4164] (rank = 2, local_rank = 2) train worker starting...
          [4165] (rank = 3, local_rank = 3) train worker starting...
          load checkpoint from checkpoint.pt
          load checkpoint from checkpoint.pt
          load checkpoint from checkpoint.pt
          load checkpoint from checkpoint.pt
          [4165] epoch 35 (rank = 3, local_rank = 3) loss = 1.3437936305999756
          [4162] epoch 35 (rank = 0, local_rank = 0) loss = 1.5693414211273193
          [4163] epoch 35 (rank = 1, local_rank = 1) loss = 1.199862003326416
          [4164] epoch 35 (rank = 2, local_rank = 2) loss = 1.0465545654296875
          [4163] epoch 36 (rank = 1, local_rank = 1) loss = 0.9741991758346558
          [4162] epoch 36 (rank = 0, local_rank = 0) loss = 1.3609280586242676
          [4164] epoch 36 (rank = 2, local_rank = 2) loss = 0.9585908055305481
          [4165] epoch 36 (rank = 3, local_rank = 3) loss = 0.9169824123382568
          ...

