Illustrated LLM Training: Megatron Source Code Walkthrough 2, Model Parallelism
This source-code walkthrough series reads through the pretrain part of Megatron together.
In the first post of the series we covered "distributed environment initialization": grouping the processes by DP/TP/PP and assigning a GPU to each process. In this post we read the "model parallelism" part: how the model is split and moved into the DP/TP/PP groups defined by that distributed environment.

This article provides:
- Detailed diagrams that lay out the design of the code and explain what it is trying to do.
- Detailed code annotations: building on the diagrams, the core code is extracted and commented.
How to use this article to read the source more efficiently:
- First read parts I to III to understand the design ideas behind model parallelism, the overall framework, and the entry function.
- Open the Megatron source, find the entry function, and start reading.
- For every detail you run into while reading, refer back to parts IV to VIII.
Prerequisites for this article: the earlier posts in this series (distributed environment initialization, and the Megatron theory post).
Table of contents:
I. Model Overview
II. What Model Splitting Actually Does
  2.1 The design ideas behind model splitting
  2.2 Random seeds
III. The Model-Parallel Framework
  3.1 The entry function for model parallelism
  3.2 Defining and moving the model
  3.3 The distributed model: CodeGeeX
IV. MegatronModule
V. Embedding
VI. VocabParallelEmbedding
VII. ParallelSelfAttention: the General Recipe for a Distributed Block
  7.1 Column splitting: ColumnParallelLinear
  7.2 Row splitting: RowParallelLinear
  7.3 ParallelSelfAttention
VIII. CrossEntropy
  8.1 Computing the logits
  8.2 Computing the cross entropy
  8.3 The code
IX. Summary
X. References (the source code and papers for this article)
I. Model Overview
As mentioned before, many open-source large models are trained with Megatron; the one we use is CodeGeeX, the code-generation model open-sourced by THUDM (comparable to OpenAI Codex). We chose it because it is fully open source and comes with clear diagrams of the model architecture and the pretraining configuration, which makes reading the source much more efficient. Let's look at those two figures again.
Model architecture

Pretraining configuration

As the figures show, CodeGeeX is pretrained with 8-way TP (the 8 GPUs within one node form a TP group and together hold one full copy of the model) and 192-way DP (data parallelism across the 192 nodes), i.e. 8 x 192 = 1536 GPUs in total.
Reading hint: if you are already familiar with the GPT architecture, you do not need to study the CodeGeeX architecture diagram in detail to follow this article; it is only there to consult when model details come up.
II. What Model Splitting Actually Does
2.1 The design ideas behind model splitting
Recall that when initializing the distributed environment, we set up and grouped the processes by DP/TP/PP, which already determines how the model will be split, as in the figure below:
(Note: this is not CodeGeeX's actual layout but a more general example; see the previous post for details.)
With this framework in place we can now split the model. PyTorch defines a model (nn.Module) on the CPU by default, so we define and initialize the model on the CPU and then move it onto the GPU assigned to the current process. The whole flow is shown below:
First, we program from the viewpoint of a single process: the entire script describes what happens on one process. The benefit is that we only maintain one script; shipping it to every GPU of every machine and running it there gives us global parallelism.
However, each process handles a different part of the model. Take GPT: its pre layer involves the embedding computation and its post layer involves softmax and the loss, so the sub-model each process handles differs. What then? Remember that we can read the process rank (global, or within the DP/TP/PP group), so we can branch with if...else... on the rank to handle these per-process differences.
With this idea clear, we can start writing code. There are two ways to split the model:
Option 1: define the complete model and initialize its parameters first, then extract the sub-model belonging to the current process by its rank and move it to the GPU.
Option 2: build the sub-model for the current process directly from its rank, initialize its parameters, and move it to the GPU.
The key difference between the two lies in how the random seeds are set.
2.2 Random seeds
In distributed training, random seeds matter a great deal: they determine whether a run can be reproduced. For example, when we use activation checkpointing to save GPU memory, the backward pass recomputes the forward pass to regenerate the activations, and that recomputation must reproduce the earlier forward exactly; likewise, all parameter initializations must match exactly.
Let's look at a few examples.
Example 1: word embedding
WE1 and WE2 must use different random seeds. If they used the same seed, WE1 and WE2 would come out identical, which is not equivalent to randomly initializing the full WE and then splitting it.
Example 2: dropout
The two dropouts in the left box must be initialized with different random seeds, because only then is it equivalent to initializing one full dropout and then splitting it. The dropout in the right box must use the same random seed on both GPUs (only one dropout is drawn on the right, but there are really two, one per GPU: at that point the outputs of the two GPUs have already been AllReduced and are identical, and after the AllReduce each GPU continues computing independently, so in practice there are two dropouts).
A general rule for setting random seeds
From these examples we can draw a general conclusion: within a TP/PP group, use different random seeds; within a DP group, use the same random seed. This is only a rule of thumb and can be adjusted to the actual situation.
Finally, back to model splitting: option 1 (initialize the whole thing, then split) is called "initialization on CPU" in the code (_initialize_affine_weight_cpu), and option 2 (initialize the local shard directly) is called "initialization on GPU" (_initialize_affine_weight_gpu). We will see both frequently in the splitting code.
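To make the seeding rule concrete, here is a minimal sketch (not Megatron's actual implementation; the offsets are made up) of how one might derive per-process seeds so that DP replicas match while TP/PP peers differ:

import torch

def set_parallel_seeds(base_seed: int, tp_rank: int, pp_rank: int) -> int:
    # DP replicas of the same model shard pass in identical (tp_rank, pp_rank),
    # so they derive the same seed and draw identical initial weights / dropout masks.
    # Different TP/PP ranks derive different seeds, which keeps "initialize locally"
    # equivalent to "initialize the full tensor, then split it".
    seed = base_seed + 2718 * tp_rank + 100 * pp_rank  # hypothetical offsets
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    return seed

# e.g. every DP replica of (tp_rank=0, pp_rank=0) calls set_parallel_seeds(1234, 0, 0)
# and therefore reproduces exactly the same initialization.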
III. The Model-Parallel Framework
Now we can look at the actual code.
3.1 The entry function for model parallelism
The entry point for the model-parallel code is still the pretrain function in megatron/training.py:
def pretrain(
    train_valid_test_dataset_provider,
    model_provider,
    forward_step_func,
    valid_forward_step_func=None,
    extra_args_provider=None,
    args_defaults={},
):
    # 1. Initialize the distributed environment (covered in walkthrough 1)
    initialize_megatron(
        extra_args_provider=extra_args_provider, args_defaults=args_defaults
    )
    ...
    # 2. Model parallelism: define the model architecture and split it (the focus of this post)
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    ...

    # 3. Build the train/valid/test datasets (next post)
    ...
    (
        train_data_iterator,
        valid_data_iterator,
        test_data_iterator,
    ) = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)

    ...
    # 4. Training (the post after next)
    iteration = train(
            forward_step_func,
            valid_forward_step_func,
            model,
            optimizer,
            lr_scheduler,
            train_data_iterator,
            valid_data_iterator,
        )

    ...
As the code shows, setup_model_and_optimizer is the entry function for the whole model-parallel part. As the figure below shows, it consists of three pieces: "define the model architecture and split the model", "set up the optimizer", and "set up the learning-rate schedule". Our focus is on the first piece (get_model).

3.2 Defining and moving the model
The content of get_model can be summarized by the figure below:

The get_model function does two main things:
- Define the model on CPU. PyTorch defines models (nn.Module) on the CPU by default. model_provider is a function; calling it returns the CPU version of the model, i.e. a CodeGeeX instance, which is the focus of the rest of this post.
- Move the model from CPU to GPU. There are two ways to manage this:
  - Option 1: let deepspeed manage it. As mentioned in walkthrough 1, in the spirit of "everything can be wrapped", following the deepspeed tutorial and inserting the corresponding code into a few Megatron files lets deepspeed take over the distributed setup and the DP-group memory optimizations; the same applies here.
  - Option 2: move and manage it by hand, which requires us to do the following:
    - Explicit move: manually move the model onto the GPU assigned to the current process.
    - Weight precision: following the ZeRO line of thinking, lowering the weight precision from fp32 to fp16 during training is a good way to save GPU memory. If we decide to use this optimization, we convert the precision after moving the model onto the GPU.
    - Initialize the DP group: define the forward, backward, and gradient computation/communication methods across the DP group. In Megatron these methods are hand-written for the TP and PP groups (they are set up while defining the CPU model, as we will see in the CodeGeeX details below), whereas for the DP group an off-the-shelf solution (torch's DistributedDataParallel) can be used: either (1) call DistributedDataParallel directly, or (2) build on top of it, e.g. adding management of fragmented memory and precision control for gradient computation.
The core code of get_model is below (everything is in the comments):
def get_model(model_provider_func):
    """Build the model."""
    args = get_args()
    # 1. Define and build the CPU version of the model
    if (  # 1.1 The framework uses a virtual pipeline (a later NVIDIA optimization of Megatron; can be ignored for now)
        mpu.get_pipeline_model_parallel_world_size() > 1
        and args.virtual_pipeline_model_parallel_size is not None
    ):
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process, post_process=post_process
            )
            model.append(this_model)
    else:  # 1.2 All other cases
        # Is the current process the first one in its PP group (e.g. g0 in the figure of part I)?
        pre_process = mpu.is_pipeline_first_stage()
        # Is the current process the last one in its PP group (e.g. g12 in the figure of part I)?
        post_process = mpu.is_pipeline_last_stage()
        # Build the CPU version of the CodeGeeX model
        model = model_provider_func(pre_process=pre_process, post_process=post_process)

    ...

    # 2. Move the model from CPU to GPU
    # 2.1 With Megatron-DeepSpeed, just return the model; the move, data parallelism, etc.
    # will be handled by deepspeed
    # ref: https://www.deepspeed.ai/tutorials/megatron/
    if args.deepspeed:
        return model
    # Move the model maintained by the current process from CPU to GPU
    # (the GPU assigned to this process during initialization)
    print(f" > moving model to GPU ...", flush=True)
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())
    print(f" > moving to GPU done", flush=True)
    # fp16 conversion (PyTorch parameters default to fp32; convert to fp16 if desired, to save memory)
    if args.fp16 or args.bf16:
        print(f" > converting model to fp16 ...", flush=True)
        model = [Float16Module(model_module, args) for model_module in model]
        print(f" > converting to fp16 done", flush=True)

    # Use PyTorch's DistributedDataParallel to manage data parallelism
    if args.DDP_impl == "torch":
        i = torch.cuda.current_device()
        model = [
            torchDDP(
                model_module,
                device_ids=[i],
                output_device=i,
                process_group=mpu.get_data_parallel_group(),  # the data-parallel group
            )
            for model_module in model
        ]
        return model

    # Use the custom DistributedDataParallel to manage data parallelism,
    # i.e. build on PyTorch's DDP with custom memory management, gradient precision, etc.
    # to use GPU memory more effectively
    if args.DDP_impl == "local":  # the custom data-parallel class lives in megatron/model/distributed.py
        print(f" > creating DDP model ...", flush=True)
        model = [
            LocalDDP(
                model_module,
                args.accumulate_allreduce_grads_in_fp32,
                args.use_contiguous_buffers_in_ddp,
            )
            for model_module in model
        ]
        print(f" > creating DDP model done", flush=True)
        return model
    raise NotImplementedError(
        "Unknown DDP implementation specified: {}. " "Exiting.".format(args.DDP_impl)
    )
A special note: as mentioned earlier, the first and last layers of the model may differ from the middle layers, which is why pre_process and post_process are used to distinguish them. (You could also just branch on the process rank; since the first and last layers come up so often, they get explicit flags.) For CodeGeeX, its pretraining configuration shows a PP degree of 1, i.e. one GPU covers the model from the first layer to the last, so pre_process and post_process are not really exercised. If you are interested, read the bert and gpt2 pretraining code in the NVIDIA Megatron repo to see what pre_process and post_process do when a model is defined.
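To make the role of the two flags concrete, here is a hypothetical, stripped-down model_provider sketch (not the real bert/gpt2 code; names and sizes are made up) showing what pre_process and post_process typically gate:

import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self, pre_process: bool, post_process: bool, hidden: int = 16, vocab: int = 100):
        super().__init__()
        # Only the first pipeline stage owns the input embedding ...
        self.embed = nn.Embedding(vocab, hidden) if pre_process else None
        # ... every stage owns its share of transformer blocks (a stand-in here) ...
        self.block = nn.Linear(hidden, hidden)
        # ... and only the last stage owns the output head / loss machinery.
        self.head = nn.Linear(hidden, vocab) if post_process else None

def model_provider(pre_process=True, post_process=True):
    return ToyStage(pre_process=pre_process, post_process=post_process)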
3.3 The distributed model: CodeGeeX
Now we come to the core of it all: the distributed model, the CodeGeeX class.
As said before, one script describes what happens on one process, and one process holds one part of the model. The architecture on a single process looks like this:

Every box in the figure (except the topmost one) is an nn.Module class defined in the source. Concretely:
- CodeGeeX: the model held by one GPU. It is built from two core classes, TransformerLanguageModel and _VocabParallelCrossEntropy.
  - TransformerLanguageModel: the structure of the input embedding layer and the intermediate blocks on each GPU.
    - Embedding: the structure and computation of the input embedding layer on each GPU; its output is already AllReduced (within the TP group).
    - ParallelTransformer: the structure and computation of all the intermediate blocks on each GPU; output already AllReduced (within the TP group).
      - ParallelTransformerLayer: the structure and computation of a single block on each GPU; output already AllReduced (within the TP group).
        - ParallelSelfAttention: the structure and computation of the attention inside a single block on each GPU; output already AllReduced (within the TP group).
        - ParallelMLP: the structure and computation of the MLP inside a single block on each GPU; output already AllReduced (within the TP group).
  - _VocabParallelCrossEntropy: a torch.autograd.Function that defines, on each GPU, the output-layer embedding, softmax, loss, etc. and the corresponding computation.
Why do the outputs need an AllReduce? Recall from the Megatron theory post that when splitting the model "vertically", Megatron designs the split under the assumption that the input X is complete. Therefore the output of every layer has to be AllReduced within the TP group so that the next layer also receives a complete input. The "Parallel" in the class names likewise means "parallel within the TP group", as shown below:

At this point we have finally walked through the overall flow of model splitting. Although we use CodeGeeX as the example, this flow can be treated as generic: different models differ only in the concrete model structure and in how the DP/TP/PP groups are configured; the parallel framework itself is shared. Next, let's dig into the details of each class drawn in the figure.
IV. MegatronModule
The classes drawn above do not inherit from nn.Module directly; they all inherit from the custom class MegatronModule(torch.nn.Module). As we said, GPT-style models share a single word embedding between the input and output layers. The main job of this class is therefore to make the first and the last process of a PP group satisfy this constraint (I do not quite understand why this constraint is enforced in a big base class; the design feels a bit odd). The overall structure of MegatronModule is shown below:

Note in particular that initialize_word_embeddings is not a concrete WE initialization method; it only performs the enforcement described in the figure.
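As a single-GPU illustration of the shared word embedding (independent of pipelining), the same nn.Embedding weight can serve both as the input lookup table and as the output projection; this is exactly the invariant that MegatronModule enforces across the first and last PP stages:

import torch
import torch.nn.functional as F

vocab, hidden = 10, 4
wte = torch.nn.Embedding(vocab, hidden)   # the shared word embedding (WE)

tokens = torch.tensor([[1, 3, 5]])
h = wte(tokens)                           # input side: ids -> vectors, shape (1, 3, hidden)
logits = F.linear(h, wte.weight)          # output side: h @ WE^T, shape (1, 3, vocab)
print(logits.shape)                       # torch.Size([1, 3, 10])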
The MegatronModule code is below (everything is in the comments):
class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""
    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        # Whether the input and output layers share one WE
        self.share_word_embeddings = share_word_embeddings
    def state_dict_for_save_checkpoint(
        self, destination=None, prefix="", keep_vars=False
    ):
        """Use this function to override the state dict for
        saving checkpoints."""
        # During training, save the parameters to a given location in time (checkpointing),
        # so that when something goes wrong we can reload from the checkpoint and keep training
        return self.state_dict(destination, prefix, keep_vars)
    def word_embeddings_weight(self):
        """Fetch the word embedding"""
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception(  # sharing one embedding is mandatory
                    "word_embeddings_weight() called for last "
                    "stage, but share_word_embeddings is false"
                )
            return self.word_embeddings.weight  # see the WE defined in initialize_word_embeddings
        raise Exception(  # a middle process of the PP group holds no WE, so there is nothing to fetch
            "word_embeddings_weight() should be " "called for first and last stage only"
        )
    def initialize_word_embeddings(self, init_method_normal):
        """Force the last process of the PP group to initialize its WE
        directly from the WE of the first process of the PP group"""
        args = get_args()
        if not self.share_word_embeddings:  # sharing the embedding is mandatory
            raise Exception(
                "initialize_word_embeddings() was called but "
                "share_word_embeddings is false"
            )
        # If the PP degree is 1, the first and last layers live on the same GPU and
        # naturally share the WE; nothing to enforce
        if args.pipeline_model_parallel_size == 1:
            return
        # ---------------------------------------------------
        # If the pipeline-parallel degree is not 1, three things happen:
        # [At initialization]:
        # 1. On the last process of the PP group, initialize a WE filled with zeros
        # 2. AllReduce between the first and last processes of the PP group so their WEs match exactly
        # [During training]:
        # 3. Whenever the first/last process of the PP group uses the WE, communicate once
        #    to keep the two copies identical

        if mpu.is_pipeline_last_stage():  # current process is the last one in its PP group
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = "word_embeddings_for_head"
            # Initialize a WE (already split along the vocab dimension;
            # see the Megatron theory post on WE).
            # VocabParallelEmbedding is explained in detail below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size,  # vocab_size
                args.hidden_size,  # embed_dim
                init_method=init_method_normal(args.init_method_std),  # init method (in model/utils.py)
            )
            # Fill the WE with zeros (it will take the first process's values after the AllReduce below)
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():  # first or last process of the PP group
                # AllReduce between the two processes so they use exactly the same WE.
                # mpu.get_embedding_group: covered in walkthrough 1, a process group set up
                # in addition to DP/TP/PP, used mainly for communication about the WE
                torch.distributed.all_reduce(
                    self.word_embeddings_weight().data, group=mpu.get_embedding_group()
                )
        else:
            print(
                "WARNING! Distributed processes aren't initialized, so "
                "word embeddings in the last layer are not initialized. "
                "If you are just manipulating a model this is fine, but "
                "this needs to be handled manually. If you are training "
                "something is definitely wrong."
            )
V. Embedding
The Embedding class defines the word/position/segment embeddings and how the input X is pushed through the embedding layer. Its key attributes and methods are shown below:

self.word_embeddings: comes from the custom VocabParallelEmbedding (detailed below). "Parallel" in the name means the parameter is split across the TP group, so self.word_embeddings is the already-split WE: each process holds the WE slice determined by its own rank (e.g. WE1 and WE2 in the figure below, taken from the Megatron theory post):

self.position_embeddings and self.tokentype_embeddings: both relate only to the input X, and X is not split, so neither of them needs to be split.
state_dict_for_save_checkpoint and load_state_dict: the source comments label them "easy load" and "customized load", which is not a very fitting description. In practice the former is used during training to read the current parameters so they can be saved promptly (checkpointing); the latter is generally used to reload a model, e.g. when a run crashes halfway, we initialize a fresh model and reload the weights saved at the last checkpoint.
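The two uses boil down to the standard save/reload pattern; a minimal, framework-agnostic sketch (the file name is made up):

import torch

model = torch.nn.Linear(8, 8)
# "save": periodically dump the current parameters as a checkpoint
torch.save(model.state_dict(), "ckpt.pt")

# "reload": build a fresh model and restore the saved weights, e.g. after a crash
restored = torch.nn.Linear(8, 8)
restored.load_state_dict(torch.load("ckpt.pt"))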
The Embedding layer code is below (everything is in the comments):
class Embedding(MegatronModule):
    """Language model embeddings.
    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """
    def __init__(
        self,
        hidden_size,  # vector dimension of each token
        vocab_size,  # vocabulary size
        max_sequence_length,  # maximum sequence length
        embedding_dropout_prob,  # dropout probability for embeddings
        init_method,  # weight initialization method
        num_tokentypes=0,  # similar to segment type in BERT
    ):
        super(Embedding, self).__init__()

        args = get_args()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes
        self.max_sequence_length = max_sequence_length

        # WE size: (vocab_size // TP_N, hidden_size)
        # TP_N is the TP-group model-parallel degree
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size, init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        self.vocab_size = vocab_size
        # PE size: (max_seq_len, hidden_size)
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self.position_embeddings = self.position_embeddings.half()
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)
        # TE size: (num_tokentypes, hidden_size)
        # TE is similar to BERT's segment embedding
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
    def add_tokentype_embeddings(self, num_tokentypes):
        """If TE was not defined during pretraining but is wanted during fine-tuning,
        it can be added with this function
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        self.init_method(self.tokentype_embeddings.weight)
    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """How the input X goes through the embedding layer
        """
        # words_embeddings size = (b, seq_len, hidden_size)
        # Note again: the final output of self.word_embeddings' forward is already AllReduced (see figure above)
        words_embeddings = self.word_embeddings(input_ids)
        # position_embeddings size = (b, seq_len, hidden_size)
        position_embeddings = self.position_embeddings(position_ids)
        # embedding = WE + PE
        # embedding size = (b, seq_len, hidden_size)
        embeddings = words_embeddings + position_embeddings
        # Add TE if needed
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None
        # Dropout.
        embeddings = self.embedding_dropout(embeddings)
        return embeddings
    def state_dict_for_save_checkpoint(
        self, destination=None, prefix='', keep_vars=False,
    ):
        """For easy load.
        Read the current parameters during training so they can be saved promptly (checkpointing).
        Details omitted for brevity
        """
        ...
    def load_state_dict(self, state_dict, strict=True):
        """Customized load.
        Used to reload the model, e.g. after a crash we initialize a fresh model
        and reload the weights saved at the last checkpoint.
        Details omitted for brevity
        """
        ...
六、VocabParallelEmbedding
該類用于定義分布式的word embedding,整體架構(gòu)如下,同樣只列舉了核心屬性和方法:

Below, after a quick single-process simulation of its forward logic, comes the actual code. Pay special attention to the initialization and forward parts, and ideally read the theory post's detailed explanation of this process alongside them (everything is in the comments).
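The simulation: the WE is split along the vocab dimension across two simulated TP ranks, each rank looks up only the tokens it owns and zeroes the rest, and the TP-group AllReduce is replaced by a plain sum; the result matches a lookup in the full WE.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, hidden, tp = 10, 4, 2
full_we = torch.randn(vocab, hidden)                  # the "complete" WE
shards = full_we.chunk(tp, dim=0)                     # WE1, WE2 (5 rows each)
input_ids = torch.tensor([[2, 7, 1, 5]])

partials = []
for rank, shard in enumerate(shards):
    start, end = rank * vocab // tp, (rank + 1) * vocab // tp
    mask = (input_ids < start) | (input_ids >= end)   # tokens this rank does not own
    local_ids = input_ids.clone() - start
    local_ids[mask] = 0                               # dummy index for masked tokens
    out = F.embedding(local_ids, shard)
    out[mask, :] = 0.0                                # zero the rows it does not own
    partials.append(out)

output = sum(partials)                                # the AllReduce step, written as a plain sum
assert torch.allclose(output, F.embedding(input_ids, full_we))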
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.
    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """
    def __init__(self, num_embeddings, embedding_dim, init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings  # vocab_size
        self.embedding_dim = embedding_dim  # hidden_state
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.0
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        # Total number of processes in the current process's TP group
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Based on the rank within the TP group, determine which WE slice this process maintains;
        # WE is split along the vocab dimension.
        # E.g. rank 0 maintains vocab ids [0, 5); rank 1 maintains [5, 10)
        (
            self.vocab_start_index,
            self.vocab_end_index,
        ) = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings,
            get_tensor_model_parallel_rank(),
            self.tensor_model_parallel_size,
        )
        # Size of the vocabulary slice this process maintains
        self.num_embeddings_per_partition = (
            self.vocab_end_index - self.vocab_start_index
        )
        # Initialize the WE
        args = get_args()  # read the pretraining configuration
        if args.use_cpu_initialization:  # initialize on CPU
            self.weight = Parameter(  # allocate this process's WE slice on CPU
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    dtype=args.params_dtype,
                    # dtype=torch.float32,
                )
            )
            # Build the full WE on CPU and split it, copying this process's slice into self.weight
            # (the random seed was fixed during distributed init; no need to change it)
            _initialize_affine_weight_cpu(
                self.weight,
                self.num_embeddings,
                self.embedding_dim,
                self.num_embeddings_per_partition,
                0,
                init_method,  # weight init method, e.g. xavier
            )
        else:  # initialize on GPU
            self.weight = Parameter(  # allocate an already-split WE
                torch.empty(
                    self.num_embeddings_per_partition,
                    self.embedding_dim,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype,
                    # dtype=torch.float32,
                )
            )
            # Initialize on GPU; note that different processes in a TP group use different random seeds
            _initialize_affine_weight_gpu(
                self.weight, init_method, partition_dim=0, stride=1
            )
    def forward(self, input_):
        """How the input X goes through the WE; the output is already AllReduced"""
        if self.tensor_model_parallel_size > 1:  # if TP is used
            # If a token cannot be found in the WE slice maintained by this process, set its position to 0.
            # E.g. the current token ids are [2, 7, 1, 5] and this process maintains vocab [0, 3)
            # (start_index=0, end_index=3); the masked data becomes [2, 0, 1, 0]
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | (
                input_ >= self.vocab_end_index
            )
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_

        # The result of pushing X through the WE slice maintained by this process
        output_parallel = F.embedding(
            masked_input,  # tensor containing indices into the embedding matrix
            self.weight,  # the split word-embedding weight
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Zero out the rows for tokens this vocabulary slice does not maintain
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0

        # AllReduce the results across the GPUs of the TP group
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
def _initialize_affine_weight_cpu(...):
    """CPU-side weight initialization. Not hard; read it yourself."""
    ...

def _initialize_affine_weight_gpu(...):
    """GPU-side weight initialization. Pay special attention to how the random seed is set."""
    ...
    # Manipulate the random seed via deepspeed or the custom get_cuda_rng_tracker
    # (see the source for get_cuda_rng_tracker details)
    if ds_checkpointing.is_configured():
        global get_cuda_rng_tracker
        get_cuda_rng_tracker = ds_checkpointing.get_cuda_rng_tracker
    with get_cuda_rng_tracker().fork():
        init_method(weight)
VII. ParallelSelfAttention: the General Recipe for a Distributed Block
Reading hints for this part:
- keep the CodeGeeX framework figure from part I at hand;
- keep the matrix-splitting discussion from the Megatron theory post at hand.
First, the figure for splitting attention. As it shows, the QKV matrices use "column splitting" while the output linear matrix B uses "row splitting". The benefit of this design is that after the QKV computation, each process can carry on with the linear computation without communicating, and only the very last step needs an AllReduce, which lowers the communication cost:

Let's first look at the implementations of "column splitting" and "row splitting" on their own. Megatron defines them as two nn.Module classes.
7.1 Column splitting: ColumnParallelLinear
The column-splitting figure:

f and g are a pair of conjugate operators and can be understood as two torch.autograd.Function classes, in which we can override forward and backward as needed:
- f: in forward, copy the input unchanged; in backward, AllReduce the gradient. In the code this is class _CopyToModelParallelRegion(torch.autograd.Function).
- g: in forward, all-gather the outputs; in backward, split the gradient (after the all-gather every GPU holds the complete Y, so the gradient computed starting from Y just needs a column-wise split to recover the gradients of Y1 and Y2). In the code this is class _GatherFromModelParallelRegion(torch.autograd.Function).
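A quick single-process numeric check of this picture before the class itself (note that the actual class stores the transposed weight slice, because F.linear computes X @ A^T):

import torch

torch.manual_seed(0)
X = torch.randn(3, 8)            # (batch, input_size)
A = torch.randn(8, 6)            # full weight, Y = X @ A
A1, A2 = A.chunk(2, dim=1)       # column split: A = [A1, A2]

Y1, Y2 = X @ A1, X @ A2          # each TP rank computes its own output slice
Y = torch.cat([Y1, Y2], dim=1)   # the g operator: all-gather along the columns
assert torch.allclose(Y, X @ A, atol=1e-6)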
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.
    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gether on output and make Y avaiable
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
    """
    # This class defines the split weight W; e.g. in the figure above, W1 and W2
    # can each be seen as an instance of this class
    def __init__(
        self,
        input_size,  # first dimension of W
        output_size,  # second dimension of W
        bias=True,  # whether to add a bias
        gather_output=True,  # whether to all-gather Y1 and Y2
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        params_dtype=None,
        skip_init=False,
        device=None,
    ):
        super(ColumnParallelLinear, self).__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        # Total number of processes in this process's TP group
        world_size = get_tensor_model_parallel_world_size()
        # Output slice maintained per GPU = original output_size // TP-group size
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.params_dtype = params_dtype
        self.device = device
        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # Initialize weight.
        args = get_args()  # fetch all command-line arguments
        if not skip_init:
            if args.use_cpu_initialization:  # initialize on CPU
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.output_size_per_partition,
                    0,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                )
            else:  # initialize on GPU
                self.weight = Parameter(
                    torch.empty(
                        self.output_size_per_partition,
                        self.input_size,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                _initialize_affine_weight_gpu(
                    self.weight, init_method, partition_dim=0, stride=stride
                )
        else:
            self.register_parameter("weight", None)
        # Handle the bias the same way as the weight
        if bias and not skip_init:
            if args.use_cpu_initialization:  # initialize on CPU
                self.bias = Parameter(
                    torch.empty(self.output_size_per_partition,
                                dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
                )
            else:
                self.bias = Parameter(  # initialize on GPU
                    torch.empty(
                        self.output_size_per_partition,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )

            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
    def forward(self, input_):
        # The f operator of the column split.
        # Calling copy_to_tensor_model_parallel_region creates a _CopyToModelParallelRegion instance (see below)
        input_parallel = copy_to_tensor_model_parallel_region(input_)
        bias = self.bias if not self.skip_bias_add else None  # the bias
        output_parallel = F.linear(input_parallel, self.weight, bias)  # X times the split weight
        # Decide whether to all-gather the outputs from every process
        if self.gather_output:
            # The g operator of the column split.
            # Calling gather_from_tensor_model_parallel_region creates a _GatherFromModelParallelRegion instance (see below)
            output = gather_from_tensor_model_parallel_region(output_parallel)  # gather the GPUs' outputs along the columns as the final output
        else:
            output = output_parallel  # otherwise the final output is just this GPU's own slice
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
# f and g of the column split
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""
    # The f operator of the column split
    # forward: copy the input
    # backward: AllReduce the gradient
    @staticmethod
    def symbolic(graph, input_):
        return input_
    @staticmethod
    def forward(ctx, input_):
        return input_
    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from model parallel region and concatinate."""
    # The g operator of the column split
    # forward: all-gather the outputs
    # backward: split the gradient along the column dimension
    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)
    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)
    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)
7.2 Row splitting: RowParallelLinear

- f: in forward, split the input along its columns; in backward, all-gather the gradient.
- g: in forward, AllReduce the outputs; in backward, pass the gradient through unchanged, no communication needed (after g's forward every GPU already holds Yi and Y, so by the backward formula for g in the figure each GPU can compute the gradient independently).
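Again a quick single-process numeric check of the row-split picture, with the AllReduce written as a plain sum:

import torch

torch.manual_seed(0)
X = torch.randn(3, 8)
A = torch.randn(8, 6)
X1, X2 = X.chunk(2, dim=1)       # the f operator: split X along its columns
A1, A2 = A.chunk(2, dim=0)       # row split of A

Y1, Y2 = X1 @ A1, X2 @ A2        # partial results on each TP rank
Y = Y1 + Y2                      # the g operator: all-reduce (a sum)
assert torch.allclose(Y, X @ A, atol=1e-6)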
The code:
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.
    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
    """
    def __init__(
        self,
        input_size,
        output_size,
        bias=True,
        input_is_parallel=False,
        init_method=init.xavier_normal_,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        params_dtype=None,
        skip_init=False,
        device=None,
    ):
        super(RowParallelLinear, self).__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.params_dtype = params_dtype
        self.device = device

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if not skip_init:
            if args.use_cpu_initialization:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size,
                        self.input_size_per_partition,
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight,
                    self.output_size,
                    self.input_size,
                    self.input_size_per_partition,
                    1,
                    init_method,
                    stride=stride,
                    return_master_weight=keep_master_weight_for_test,
                )
            else:
                self.weight = Parameter(
                    torch.empty(
                        self.output_size,
                        self.input_size_per_partition,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
                _initialize_affine_weight_gpu(
                    self.weight, init_method, partition_dim=1, stride=stride
                )
        else:
            self.register_parameter("weight", None)

        if bias and not skip_init:
            if args.use_cpu_initialization:
                self.bias = Parameter(
                    torch.empty(self.output_size,
                                dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype)
                )
            else:
                self.bias = Parameter(
                    torch.empty(
                        self.output_size,
                        device=self.device if self.device is not None else torch.cuda.current_device(),
                        dtype=self.params_dtype if self.params_dtype is not None else args.params_dtype,
                    )
                )
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter("bias", None)
    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
# The f and g operators of the row split
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the corresponding chuck to the rank."""
    # The f operator of the row split
    # forward: split the input along its columns
    # backward: all-gather the gradient
    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)
    @staticmethod
    def forward(ctx, input_):
        return _split(input_)
    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)

class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""
    # The g operator of the row split
    # forward: AllReduce the outputs
    # backward: compute the gradient as usual; no communication between GPUs is needed
    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)
    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
7.3 ParallelSelfAttention
The structure of this class is shown below:

The key messages of this figure:
- Each process holds the column-split slices of the QKV matrices; the processes compute independently, and the QKV outputs are generally not AllReduced.
- Each process holds the row-split slice of the dense (linear) matrix; the result of pushing the attention output through this linear layer is AllReduced.
- When setting up attention_dropout, get_cuda_rng_tracker is again used so that processes within a TP group get different random seeds.
- Finally, you may wonder where the dropout after dense went: the code defines it inside ParallelTransformerLayer (which amounts to attention + mlp).
With the explanation above, this code should not be hard to read, so for space reasons it is not shown here. You can follow the CodeGeeX architecture diagram to see how multi-head attention is computed.
ParallelMLP, ParallelTransformerLayer and ParallelTransformer all follow the same recipe, so they are skipped as well.
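For orientation, here is a structural sketch (single GPU, plain nn.Linear stand-ins, made-up sizes) of how the attention block is wired in this TP scheme; it is not the real ParallelSelfAttention code:

import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    def __init__(self, hidden=1024, heads=16, tp=2):
        super().__init__()
        # In Megatron this role is played by a ColumnParallelLinear with gather_output=False:
        # each TP rank holds a hidden -> 3*hidden/tp slice of the fused QKV projection
        # and keeps its partial output (no communication yet).
        self.qkv = nn.Linear(hidden, 3 * hidden // tp)
        # Each rank therefore computes attention for heads/tp of the heads, independently.
        self.heads_per_rank = heads // tp
        # In Megatron this role is played by a RowParallelLinear with input_is_parallel=True:
        # each rank multiplies its hidden/tp slice by its row slice of the dense weight,
        # and only then are the partial outputs AllReduced across the TP group.
        self.dense = nn.Linear(hidden // tp, hidden)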
VIII. CrossEntropy
Now we can finally look at the model's last layer: the parallel computation of the cross entropy. The core class is _VocabParallelCrossEntropy.
The theory post already discussed the parallel cross-entropy computation; its key optimization is cutting the communication volume from b*s*v down to b*s. The cross entropy as defined in the Megatron code, however, is a bit more involved and differs slightly from the cross entropy as we usually write it, so let's first go through the computation flow the code defines, with a diagram:
Notes:
- For X and Y_i, the (b, s, h) shape should really be drawn as a cube; for brevity, b is flattened here.
- For the other matrices whose shapes involve b, b is drawn normally, i.e. row = b.

8.1 Computing the logits
First, before computing the cross entropy with _VocabParallelCrossEntropy, we need the logits. We call the parallel_lm_logits function, which multiplies the model's last-layer output X (recall: this X has already been AllReduced within the TP group) by the transpose of the input-layer WE slice held by the current process (recall: the input and output layers share one embedding), producing this process's logits Y_i. Note that we choose NOT to AllReduce the output logits.
You might wonder: in the original Transformer, the output layer trains an extra linear matrix to compute the logits; why can a GPT-style model use the transpose of the input-layer WE in place of that matrix?
The answer is also key to understanding Megatron's cross-entropy computation. We can read the result of X*WE^T as "the similarity between X and WE": for Y1, each logit in its first row is the similarity between the first token and one word of the vocabulary.
Keep in mind that each process only holds part of WE. For example, suppose the vocabulary has 10 words in total, WE1 holds the first 5 and WE2 holds the last 5. More precisely, then: for Y1, each logit in its first row is the similarity between the first token and one of the first 5 words of the vocabulary; for Y2, each logit in its first row is the similarity between the first token and one of the last 5 words. Keep this meaning in mind.
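Shape-wise, what each process produces is just X times its own WE slice transposed, deliberately left un-reduced; a minimal sketch with made-up sizes:

import torch
import torch.nn.functional as F

b, s, h, v, tp = 2, 3, 4, 10, 2
X = torch.randn(b, s, h)         # last-layer output, already AllReduced within the TP group
WE_i = torch.randn(v // tp, h)   # this process's slice of the shared word embedding

Y_i = F.linear(X, WE_i)          # X @ WE_i^T, shape (b, s, v/tp); deliberately NOT AllReduced
print(Y_i.shape)                 # torch.Size([2, 3, 5])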
8.2 Computing the cross entropy
Knowing what the logits mean, let's look at the cross-entropy computation.
First, a series of max computations produces the global max(logit), and then orig_logit - max(logit) gives the shifted result. This step is easy to understand; its main purpose is to avoid numerical overflow.
Next, the loss is computed from the logits.
Every process holds a copy of the (b, s)-shaped ground truth, which says which word (by id) each token's true label is. Using it, we pick out the logit at the true-label position in Y_i. For example: with seq_length = 3, i.e. three tokens to predict, suppose the true labels of the first two tokens fall inside WE1 held by the first process, and the true label of the last token falls inside WE2 held by the second process. Then from the first two rows of Y1 we take the logits at the true-label positions (each such logit is "the similarity between the token and its true label"), and we do the same on the last row of Y2.
This gives L1 and L2, with zeros filled in wherever the true label does not fall into that process's slice. AllReducing L1 and L2 gives L, where each entry is "the similarity between a token and its true label".
Now we go back to Y1 and Y2 and compute sum(e^logit) over each row, giving e1 and e2. AllReducing e1 and e2 gives e, where each entry is "the total (exponentiated) similarity between a token and all the words of the vocabulary".
We want the exponentiated similarity between a token and its true label to take up as large a share of that total as possible; the final per-token loss is the negative log of that share, i.e. loss = log(e) - L, which is exactly what the code below computes.
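The pieces above assemble exactly the usual cross entropy; here is a single-device check of the identity the code implements:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, v = 2, 3, 10
logits = torch.randn(b, s, v)
target = torch.randint(0, v, (b, s))

m = logits.max(dim=-1, keepdim=True)[0]                           # the "global max" (one device: just the max)
shifted = logits - m                                              # subtract the max for stability
predicted = shifted.gather(-1, target.unsqueeze(-1)).squeeze(-1)  # logit at the true token (the "L" above)
loss = torch.log(shifted.exp().sum(dim=-1)) - predicted           # log(e) - L

ref = F.cross_entropy(logits.view(-1, v), target.view(-1), reduction="none").view(b, s)
assert torch.allclose(loss, ref, atol=1e-6)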
8.3 The code
With that clear, we can read the code (everything is in the comments). If anything here is still unclear, it helps to write a small test script and print out the intermediate results:
class _VocabParallelCrossEntropy(torch.autograd.Function):
    """
    Distributed loss computation
    """
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # 1. logit - global max(logit), mainly to prevent overflow
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]  # (b, s)
        torch.distributed.all_reduce(  # (b, s)
            logits_max,
            op=torch.distributed.ReduceOp.MAX,  # global max
            group=get_tensor_model_parallel_group(),
        )
        # Subtract the maximum value.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))  # subtract each row's global max from this GPU's logits (prevents overflow)
        # 2. Based on the current rank, fetch the vocabulary range this process maintains
        # Function that returns the start_index and end_index of the vocabulary slice of the current process
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        # Size of the last dimension of the logits on this GPU, equal to its vocabulary slice size (v/N)
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        # Rank of the current process within its TP group
        rank = get_tensor_model_parallel_rank()
        # Total number of processes in the TP group
        world_size = get_tensor_model_parallel_world_size()
        # start_index and end_index of the vocabulary slice maintained by this process
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size
        )
        # 3. Using the ground truth, pick each token's logit at the true-label position
        #    (i.e. its similarity with the true label)
        # Create a mask of valid vocab ids (1 means it needs to be masked)
        target_mask = (target < vocab_start_index) | (target >= vocab_end_index)  # target = (b, s)
        masked_target = target.clone() - vocab_start_index
        masked_target[target_mask] = 0
        # Get predicted-logits = logits[target].
        # For Simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)  # (b*s, v/N)
        masked_target_1d = masked_target.view(-1)  # (b*s)
        arange_1d = torch.arange(  # [b*s]
            start=0, end=logits_2d.size()[0], device=logits_2d.device
        )
        # logits_2d[arange_1d, masked_target_1d]:
        # tensor indexing; arange_1d selects every row, masked_target_1d selects the logit within each row
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]  # (b*s)
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(target)  # (b, s)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(  # after the all-reduce, the (b, s) matrix holds, at each position, the predicted logit at the true-label position
            predicted_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )
        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = vocab_parallel_logits  # (b, s, v/N)
        torch.exp(vocab_parallel_logits, out=exp_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)  # (b, s)
        torch.distributed.all_reduce(
            sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )
        # 4. Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits  # (b, s)
        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
        return loss
    @staticmethod
    def backward(ctx, grad_output):
        # Retreive tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors
        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)
        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))
        return grad_input, None
IX. Summary
How do I even write this summary... I'm exhausted, but it is finally done. I hope it helps.
X. References
1. CodeGeeX GitHub: https://github.com/THUDM/CodeGeeX/tree/7365d9df242d87a5583d3f203e4b6c547dc6240e
2. NVIDIA Megatron GitHub: https://github.com/NVIDIA/Megatron-LM/tree/2c493fb3fd37e5ecac068607b408ed5724d80fcc
3. torch distributed tutorial: https://pytorch.org/docs/stable/distributed.html
4. init_process_group: https://www.cnblogs.com/rossixyz/p/15553670.html
5. DeepSpeed Megatron tutorial: https://www.deepspeed.ai/tutorials/megatron/
6. CodeGeeX paper: https://arxiv.org/abs/2303.17568
