PyTorch 1.10 Released: ZeroRedundancyOptimizer and Join
PyTorch 1.10 was released recently. On the distributed-training side, this version officially ships ZeroRedundancyOptimizer, PyTorch's counterpart to the ZeRO technique Microsoft released in DeepSpeed. It can wrap any ordinary optimizer such as SGD or Adam, and its main job is to shard the optimizer state across ranks during DDP training, reducing the memory each node (process or device) needs. The release also adds Join, a context manager for handling uneven inputs across ranks in distributed training; both DDP and ZeroRedundancyOptimizer support it.
ZeroRedundancyOptimizer
ZeRO is an optimization strategy proposed by Microsoft for distributed training of large models. It reduces memory usage by partitioning the model states, which consist of the optimizer states, gradients, and parameters. ZeroRedundancyOptimizer implements the partitioning of the optimizer states, i.e. the per-parameter buffers that the optimizer itself needs: SGD keeps a momentum buffer as large as the model parameters, while Adam keeps exp_avg and exp_avg_sq, twice the size of the parameters, so for large models the optimizer states are a substantial memory cost. In plain DDP, every rank (node, process, device) holds its own copy of the optimizer, and in each iteration they all do the same thing: update the model parameters with the all-reduced gradients, which keeps the parameters identical across ranks. This can be optimized by sharding the optimizer states: each rank's optimizer keeps only the states for a 1/world_size slice of the model parameters and is responsible only for updating that slice. Once a rank has updated its slice, it broadcasts the result to the other ranks so that all ranks end up with the same parameters. ZeroRedundancyOptimizer is essentially ZeRO-DP-1, the simplest variant of ZeRO; see the earlier post on training trillion-parameter models with data parallelism for more details. Facebook's fairscale library already offers a more complete ZeRO implementation: FSDP.
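To make the sharding idea concrete, here is a minimal, purely illustrative sketch of a ZeRO-DP-1-style update step, not PyTorch's actual implementation: each rank steps its local optimizer over the parameter slice it owns, then every parameter is broadcast from its owner so all ranks end up with identical weights. The function name sharded_step and the round-robin ownership rule are made up for illustration.

import torch.distributed as dist

def sharded_step(params, local_optimizer, world_size):
    # Assumes the process group is initialized, gradients have already been
    # all-reduced into each parameter's .grad, and local_optimizer was built
    # only over the parameters this rank "owns" under the rule below.
    local_optimizer.step()                 # update only the locally owned slice
    for i, p in enumerate(params):
        owner = i % world_size             # illustrative round-robin ownership
        dist.broadcast(p.data, src=owner)  # receive each updated parameter from its owner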
Using ZeroRedundancyOptimizer is straightforward: simply wrap a regular optimizer with it. A minimal example:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        # simply wrap the base optimizer class
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()
    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)
    print(f"params sum is: {sum(model.parameters()).sum()}")

def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)
    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()
## output
=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
With or without ZeroRedundancyOptimizer, the memory used to build the model is identical and the final parameter sums agree, but the difference appears after optimizer.step(): with ZeroRedundancyOptimizer the memory added by the step is roughly halved, because the optimizer states are split evenly across the 2 ranks.
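These numbers roughly match a back-of-the-envelope estimate (a rough sanity check, not taken from the release notes): the model has 20 Linear(2000, 2000) layers, i.e. about 80M fp32 parameters (~320 MB), so the full Adam states (exp_avg plus exp_avg_sq) take ~640 MB, close to the ~705 MB increase observed without sharding, while the ~320 MB per-rank share explains the ~369 MB increase with ZeroRedundancyOptimizer.

num_params = 20 * (2000 * 2000 + 2000)   # 20 Linear(2000, 2000) layers: ~80M parameters
param_mb = num_params * 4 / 1e6          # fp32 parameters: ~320 MB
adam_state_mb = 2 * param_mb             # exp_avg + exp_avg_sq: ~640 MB
print(adam_state_mb)      # approximate extra memory of a full Adam step on every rank
print(adam_state_mb / 2)  # approximate per-rank share when sharded across 2 ranks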
Join
In DDP training, backward implicitly runs all-reduce to synchronize the gradients across ranks. This is a collective communication, and every collective requires all ranks to participate; if one rank has fewer inputs, the remaining ranks will block waiting for it, or even fail. The Join context manager is designed to handle such uneven inputs in distributed training: in short, it lets the ranks that have run out of inputs (and have therefore joined) keep shadowing the collective communications issued by the ranks that have not yet joined. This sounds abstract, but the DDP example below makes it easy to follow:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    num_inputs = 0
    with Join([model]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

## output
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!

Here rank 0 has only 5 inputs while rank 1 has 6, i.e. the inputs are uneven. Without Join, rank 1 would hang while processing its last input; with Join this situation is handled correctly (how it works is explained below). Here with Join([model]) is equivalent to with model.join():, but the former is more flexible because it can manage several classes at once, for example when ZeroRedundancyOptimizer is also involved:
from torch.distributed.optim import ZeroRedundancyOptimizer as ZeRO
from torch.optim import Adam

def worker(rank):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    optim = ZeRO(model.parameters(), Adam, lr=0.01)
    # Rank 1 gets one more input than rank 0
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    num_inputs = 0
    # Pass both `model` and `optim` into `Join()`
    with Join([model, optim]):
        for input in inputs:
            num_inputs += 1
            loss = model(input).sum()
            loss.backward()
            optim.step()
    print(f"Rank {rank} has exhausted all {num_inputs} of its inputs!")
In addition, Join accepts keyword arguments that modify the behavior of the classes it manages. For example, DDP's divide_by_initial_world_size determines whether gradients are divided by the initial world_size or by the effective world_size (the number of ranks that have not yet joined). It is used as follows:
with Join([model, optim], divide_by_initial_world_size=False):
    for input in inputs:
        ...
To understand how Join works under the hood, you need to know two classes: Joinable and JoinHook. Every class passed into Join must be a Joinable, i.e. it must inherit from Joinable and implement three methods:

join_hook(self, **kwargs) -> JoinHook: returns a JoinHook that defines how an already-joined rank shadows the collective operations issued by the ranks that have not yet joined;
join_device(self) -> torch.device;
join_process_group(self) -> ProcessGroup.

The latter two are what Join needs in order to issue the collective communications itself, while join_hook determines the concrete behavior. DistributedDataParallel and ZeroRedundancyOptimizer work with Join precisely because they already inherit from Joinable and implement these three methods.
JoinHook provides two methods:

main_hook(self) -> None: while some ranks have not yet joined, an already-joined rank runs this method once for each round of collective operations, i.e. it defines how the joined rank shadows the collectives performed by the not-yet-joined ranks;
post_hook(self, is_last_joiner: bool) -> None: runs once after all ranks have joined; is_last_joiner tells the rank whether it was among the last to join (note that there may be more than one last joiner).

For ZeroRedundancyOptimizer, the main_hook simply performs one optimizer step: even after joining, each rank is still responsible for updating and synchronizing the parameter shard assigned to it. For DistributedDataParallel, the post_hook broadcasts the model parameters from the last joined rank to all the other ranks, so that the final model is consistent everywhere.

Here is a simple case showing how Joinable and JoinHook work together in practice:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join, Joinable, JoinHook

BACKEND = "nccl"
WORLD_SIZE = 2
NUM_INPUTS = 5

class CounterJoinHook(JoinHook):
    r"""
    Join hook for :class:`Counter`.
    Arguments:
        counter (Counter): the :class:`Counter` object using this hook.
        sync_max_count (bool): whether to sync the max count once all ranks
            join.
    """
    def __init__(
        self,
        counter,
        sync_max_count
    ):
        self.counter = counter
        self.sync_max_count = sync_max_count

    def main_hook(self):
        r"""
        Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.
        """
        t = torch.zeros(1, device=self.counter.device)
        dist.all_reduce(t)

    def post_hook(self, is_last_joiner: bool):
        r"""
        Synchronizes the max count across all :class:`Counter` s if
        ``sync_max_count=True``.
        """
        if not self.sync_max_count:
            return
        rank = dist.get_rank(self.counter.process_group)
        common_rank = self.counter.find_common_rank(rank, is_last_joiner)
        if rank == common_rank:
            self.counter.max_count = self.counter.count.detach().clone()
        dist.broadcast(self.counter.max_count, src=common_rank)

class Counter(Joinable):
    r"""
    Example :class:`Joinable` that counts the number of training iterations
    that it participates in.
    """
    def __init__(self, device, process_group):
        super(Counter, self).__init__()
        self.device = device
        self.process_group = process_group
        self.count = torch.tensor([0], device=device).float()
        self.max_count = torch.tensor([0], device=device).float()

    def __call__(self):
        r"""
        Counts the number of inputs processed on this iteration by all ranks
        by all-reducing a dim-1 one tensor; increments its own internal count.
        """
        Join.notify_join_context(self)
        t = torch.ones(1, device=self.device).float()
        dist.all_reduce(t)
        self.count += t

    def join_hook(self, **kwargs) -> JoinHook:
        r"""
        Return a join hook that shadows the all-reduce in :meth:`__call__`.
        This join hook supports the following keyword arguments:
            sync_max_count (bool, optional): whether to synchronize the maximum
                count across all ranks once all ranks join; default is ``False``.
        """
        sync_max_count = kwargs.get("sync_max_count", False)
        return CounterJoinHook(self, sync_max_count)

    @property
    def join_device(self) -> torch.device:
        return self.device

    @property
    def join_process_group(self):
        return self.process_group

    # Determine which rank to use for synchronization among the last joined
    # ranks; since there may be more than one, pick the largest rank.
    def find_common_rank(self, rank, to_consider):
        r"""
        Returns the max rank of the ones to consider over the process group.
        """
        common_rank = torch.tensor([rank if to_consider else -1], device=self.device)
        dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
        common_rank = common_rank.item()
        return common_rank

def worker(rank):
    assert torch.cuda.device_count() >= WORLD_SIZE
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(BACKEND, rank=rank, world_size=WORLD_SIZE)
    counter = Counter(torch.device(f"cuda:{rank}"), dist.group.WORLD)
    inputs = [torch.tensor([1]).float() for _ in range(NUM_INPUTS + rank)]
    with Join([counter], sync_max_count=True):
        for _ in inputs:
            counter()
    print(f"{int(counter.count.item())} inputs processed before rank {rank} joined!")
    print(f"{int(counter.max_count.item())} inputs processed across all ranks!")

def main():
    mp.spawn(worker, nprocs=WORLD_SIZE, join=True)

if __name__ == "__main__":
    main()

# output
# Since rank 0 sees 5 inputs and rank 1 sees 6, this yields the output:
10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!

Counter here is a Joinable that implements a distributed counter, and its CounterJoinHook handles the uneven inputs: the main_hook shadows the counter's all-reduce by all-reducing a zero tensor, and the post_hook synchronizes the maximum count across ranks, controlled by the sync_max_count keyword argument.

Closing Remarks

ZeroRedundancyOptimizer was actually already available in PyTorch 1.8; this release refines it, for example by adding Join support. As noted above, the current ZeroRedundancyOptimizer only implements ZeRO-DP-1, so there is still room for further optimization. Join itself is still being iterated on, and more updates should follow in later releases.