
PyTorch 1.10 released: ZeroRedundancyOptimizer and Join


          2021-10-26 05:00



PyTorch 1.10 was released recently. On the distributed-training side, this release officially ships ZeroRedundancyOptimizer, PyTorch's counterpart to ZeRO from Microsoft's DeepSpeed. It can wrap any ordinary optimizer such as SGD or Adam, and it shards the optimizer state across ranks during DDP training, reducing the memory used on each node (process or device). The release also ships Join, a context manager for handling uneven inputs in distributed training; both DDP and ZeroRedundancyOptimizer support it.

          ZeroRedundancyOptimizer

ZeRO is an optimization strategy proposed by Microsoft for large-scale distributed training. It reduces memory usage by partitioning the model states, which consist of the optimizer states, the gradients, and the parameters. ZeroRedundancyOptimizer implements the partitioning of the optimizer states, i.e. the extra tensors an optimizer keeps: SGD needs a momentum buffer the same size as the model parameters, while Adam needs exp_avg and exp_avg_sq, twice the size of the parameters. For a large model, the optimizer states are a significant memory cost. In DDP, every rank (node, process, device) holds a full copy of the optimizer, and in every iteration each copy does the same work: it updates the model parameters with the all-reduced gradients, which keeps the parameters identical across ranks. This can be optimized by sharding the optimizer states: each rank's optimizer keeps only the states for 1/world_size of the model parameters and is responsible for updating only that shard. Once a rank has updated its shard, it broadcasts the new parameters to the other ranks so that all ranks stay consistent. ZeroRedundancyOptimizer is essentially ZeRO-DP-1, the simplest variant of ZeRO; see the earlier post (how to train a trillion-parameter model with data parallelism?) for more details. Facebook's fairscale library also implements a more complete ZeRO variant: FSDP.
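As a rough back-of-the-envelope check (a minimal sketch, assuming fp32 tensors and ignoring allocator overhead), the per-rank optimizer-state memory for Adam can be estimated as follows; num_params here corresponds to the 20-layer MLP used in the example that follows:

def adam_state_mb(num_params, world_size=1, bytes_per_elem=4):
    # Adam keeps exp_avg and exp_avg_sq, i.e. two extra tensors the same
    # size as the parameters; ZeRO-style sharding splits them across ranks
    return 2 * num_params * bytes_per_elem / world_size / 1e6

num_params = 20 * (2000 * 2000 + 2000)          # ~80M parameters
print(adam_state_mb(num_params))                # ~640 MB on every rank without sharding
print(adam_state_mb(num_params, world_size=2))  # ~320 MB per rank with sharding across 2 ranks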

ZeroRedundancyOptimizer is easy to use: just wrap a regular optimizer with it. A minimal example is shown below:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        # simply wrap the base optimizer class
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")


def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__ == "__main__":
    main()

## output
=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875

With or without ZeroRedundancyOptimizer, the memory used to create the model is the same, and the final parameter sums also match. The difference appears after optimizer.step(): with ZeroRedundancyOptimizer, the extra memory consumed by the step is roughly halved, because the optimizer states are split evenly across the 2 ranks.
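The numbers above are roughly consistent with a back-of-the-envelope estimate (a sketch, assuming fp32 and ignoring temporary buffers and allocator overhead):

num_params = 20 * (2000 * 2000 + 2000)   # 20 x nn.Linear(2000, 2000), ~80M parameters
param_mb = num_params * 4 / 1e6          # ~320 MB, close to the 335 MB reported above
adam_states_mb = 2 * param_mb            # exp_avg + exp_avg_sq, ~640 MB in total
print(param_mb, adam_states_mb, adam_states_mb / 2)
# Without ZeRO, step() materializes the full ~640 MB of states on each rank
# (observed: 1697 - 992 = 705 MB); with ZeRO each of the 2 ranks only
# holds ~320 MB of states (observed: 1361 - 992 = 369 MB).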

          Join

In DDP training, backward triggers an all-reduce behind the scenes to synchronize gradients across ranks. This is a collective communication, and every collective requires all ranks to participate; if one rank has fewer inputs, the other ranks will wait for it, or even error out. The Join context manager handles exactly this case of uneven inputs in distributed training: roughly speaking, it lets the ranks that have run out of inputs (i.e. have already joined) shadow the collective communications issued by the ranks that are still running (not yet joined). This sounds abstract, but the DDP example below makes it easier to understand:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

## output
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

Here rank 0 has only 5 inputs while rank 1 has 6, so the data is uneven. Without Join, rank 1 would hang on its last input; with Join, this case is handled correctly (the mechanism behind it is explained later). Here with Join([model]): is equivalent to with model.join(): (a short sketch of that form follows the next example), but the former is more flexible because it can coordinate several classes at once, for example when ZeroRedundancyOptimizer is also used:

from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()

    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
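For comparison, when DDP is the only class that needs to handle uneven inputs, the model.join() form mentioned earlier is a drop-in replacement inside the same worker (a minimal sketch reusing the model, inputs, and num_inputs defined above):

# Equivalent to `with Join([model]):` when no other Joinable is involved
with model.join():
    for input in inputs:
        num_inputs += 1
        loss = model(input).sum()
        loss.backward()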

In addition, Join accepts keyword arguments that modify the behavior of the classes it manages, for example DDP's divide_by_initial_world_size, which determines whether gradients are divided by the initial world_size or by the effective world_size (the number of ranks that have not yet joined). It is used as follows:

with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...

Understanding how Join works requires two classes: Joinable and JoinHook. A class passed into Join must be a Joinable, i.e. it must inherit from Joinable and implement the following three members (join_hook as a method, the other two as properties):

• join_hook(self, **kwargs) -> JoinHook: returns a JoinHook that defines how the ranks that have already joined shadow the collective operations of the ranks that have not;
• join_device(self) -> torch.device;
• join_process_group(self) -> ProcessGroup

          后面兩個(gè)方法是Join來處理集群通信所必須的,而join_hook決定了具體行為。DistributedDataParallelZeroRedundancyOptimizer之所以能用在Join上是因?yàn)樗鼈円呀?jīng)繼承Joinable,并實(shí)現(xiàn)了這三個(gè)方法。

A JoinHook provides two methods:

• main_hook(self) -> None: while there are still ranks that have not joined, a joined rank runs this method once per collective operation, i.e. it defines how the joined rank shadows the collectives issued by the ranks still running;
• post_hook(self, is_last_joiner: bool) -> None: runs once after all ranks have joined; the is_last_joiner argument tells the rank whether it was among the last to join (note that there may be more than one last joiner).


For ZeroRedundancyOptimizer, the main_hook performs an optimizer step, because even a joined rank is still responsible for updating and synchronizing the shard of parameters assigned to it. For DistributedDataParallel, the post_hook broadcasts the model parameters from the last joined rank to the other ranks, keeping the parameters consistent.

Here is a simple case showing how Joinable and JoinHook work in practice:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.

    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.

        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # Determine which rank to synchronize from once all ranks have joined;
    # since more than one rank may be a "last joiner", pick the largest such rank
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)

    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]

    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()

    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

## output
# Since rank 0 sees 5 inputs and rank 1 sees 6, this yields the output:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

Counter here is a Joinable that implements a distributed counter, and its CounterJoinHook handles the uneven inputs: the main_hook shadows the counter's collective by all-reducing a zero tensor, while the post_hook synchronizes the maximum count; a keyword argument (sync_max_count) is also used here.

          后話

ZeroRedundancyOptimizer was actually already available in PyTorch 1.8; this release refines it, for example by adding Join support. That said, the current ZeroRedundancyOptimizer only implements ZeRO-DP-1, so there is room for further optimization, and Join is still evolving, with more updates expected in future releases.




