Transformer Series | DeLighT: A Deeper, Stronger, and Lighter Transformer

1 Introduction
This paper proposes DeLighT, a deeper and lighter-weight Transformer that matches or exceeds the performance of the standard Transformer while using, on average, 2 to 3 times fewer parameters. DeLighT allocates parameters more efficiently within and across Transformer blocks:
- 1) within each block, it applies the DeLighT transformation, a deep and light-weight transformation;
- 2) across blocks, it applies block-wise scaling, which places shallower and narrower DeLighT blocks near the input and deeper and wider DeLighT blocks near the output.
Overall, DeLighT networks are 2.5 to 4 times deeper than standard Transformers, yet use fewer parameters and operations. Experiments on machine translation and language modeling show that DeLighT improves upon the baseline Transformer while using, on average, 2 to 3 times fewer parameters.
2 Related Work
2.1 Improving transformers
The first line of research addresses the cost of computing self-attention over long input sequences. These methods can be combined with the architecture proposed here.
The second line of research focuses on interpreting multi-head attention. Studies have shown that increasing the number of attention heads can lead to redundant representations, and that fixed attention heads with predefined patterns, or synthesized attention matrices, can improve performance.
The third line of research focuses on improving Transformers by learning better representations. These works aim to increase the expressiveness of the Transformer with different transformations, for example convolutions, gated linear units, or multi-branch feature extractors. This paper belongs to this category. Unlike previous work, it shows that parameters can be allocated efficiently both within blocks, using the DeLighT transformation, and across blocks, using block-wise scaling.
2.2 Model scaling
Model scaling is a standard way to improve the performance of sequence models. In width-wise scaling the model dimensions are increased, while in depth-wise scaling more blocks are stacked. In both cases (and in their combination), the parameters inside each block of the network are the same, which may lead to suboptimal solutions. To further improve the performance of sequence models, this paper introduces block-wise scaling, which allows blocks of variable size and an efficient allocation of parameters across the network.
The results of this paper show that:
- 1) DeLighT blocks that are shallower and narrower near the input and deeper and wider near the output deliver the best performance;
- 2) models using block-wise scaling achieve better performance than models using model scaling alone.
It is also worth noting that convolutional neural networks (CNNs) likewise learn shallower and narrower representations near the input and deeper and wider representations near the output. Unlike CNNs, which perform a fixed number of operations at every convolutional layer, the proposed block-wise scaling uses a variable number of operations in each layer and block.
2.3 Improving sequence models
There has also been significant recent work on other related methods for improving sequence models, including (1) improving accuracy with better token-level representations, for example using BPE, adaptive inputs and outputs, and DeFINE, and (2) improving efficiency with compression, pruning, and distillation.
The closest work to this paper is the DeFINE transformation, which also learns representations with an expand-reduce strategy. The key difference between the DeFINE transformation (Figure 1c) and the DeLighT transformation (Figure 1d) is that the DeLighT transformation distributes parameters more efficiently across the expansion and reduction layers.

Unlike DeFINE, which uses fewer groups in its group linear transformations to learn more robust representations, the DeLighT transformation uses more groups to learn wider representations with fewer parameters. The DeLighT transformation achieves performance comparable to the DeFINE transformation, but with far fewer parameters.
3 DeLighT Transformer
A standard Transformer block is shown in Figure 1a:

It consists of multi-head attention, which uses queries, keys, and values to model relationships between sequence tokens, and a feed-forward network (FFN) that learns wider representations.
Multi-head attention obtains the queries, keys, and values by applying three projections to the input. Each projection consists of $h$ linear layers (or heads) that map the $d_m$-dimensional input into a $d_h$-dimensional space, where $d_h = d_m / h$ is the head dimension.
The FFN consists of the following two linear layers:
- Step 1: expand the dimension from $d_m$ to $4 d_m$;
- Step 2: reduce the dimension from $4 d_m$ back to $d_m$.
The depth of a Transformer block is 4. In general, Transformer-based networks stack Transformer blocks sequentially to increase capacity and depth.
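To make these dimensions concrete, the small sketch below (illustrative only, not code from the paper) prints the quantities involved in a standard Transformer block, assuming the usual head split $d_h = d_m / h$ and the 4x FFN expansion described above:

def standard_block_dims(d_m=512, h=8):
    # per-head dimension: d_h = d_m / h
    d_h = d_m // h
    # standard FFN expands d_m -> 4 * d_m and reduces it back to d_m
    ffn_hidden = 4 * d_m
    # sequential depth of the block, as stated above
    depth = 4
    return {'head_dim': d_h, 'ffn_hidden': ffn_hidden, 'block_depth': depth}

print(standard_block_dims())  # {'head_dim': 64, 'ffn_hidden': 2048, 'block_depth': 4}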
3.1 DeLighT Transformation
The DeLighT transformation first maps a $d_m$-dimensional input vector into a high-dimensional space (expansion), and then reduces it to a $d_o$-dimensional output vector (reduction) using $N$ layers of group transformations, as shown in Figure 1d.

In the expansion-reduction phases, the DeLighT transformation uses group linear transformations (GLTs), because they learn local representations by deriving the output from specific parts of the input and are more efficient than standard linear transformations. To learn global representations, the DeLighT transformation shares information between the different groups of the group linear transformation using feature shuffling, similar to channel shuffling in convolutional networks.
A standard way to increase the expressiveness and capacity of a Transformer is to increase the input dimension $d_m$. However, increasing it linearly also increases the complexity of multi-head attention in the standard Transformer block ($O(d_m n^2)$, where $n$ is the sequence length). In contrast, to increase the expressiveness and capacity of the DeLighT block, this paper increases the depth and width of the intermediate DeLighT transformation using the expansion and reduction phases. This allows DeLighT to compute attention with smaller dimensions and fewer operations.
The DeLighT transformation is controlled by five configuration parameters:
- (1) the number of GLT layers $N$,
- (2) the width multiplier $w_m$,
- (3) the input dimension $d_m$,
- (4) the output dimension $d_o$,
- (5) the maximum number of groups $g_{max}$ in a GLT.
In the expansion phase, the DeLighT transformation projects the $d_m$-dimensional input into a high-dimensional space of dimension $d_{max} = w_m d_m$, using the first $\lceil N/2 \rceil$ GLT layers.
In the reduction phase, the DeLighT transformation projects the $d_{max}$-dimensional vector into a $d_o$-dimensional space using the remaining $N - \lceil N/2 \rceil$ GLT layers.
Mathematically, the output $Y^l$ of the $l$-th GLT layer is defined as:

$$
Y^l =
\begin{cases}
\mathcal{F}\left(X,\, W^l,\, b^l,\, g^l\right), & l = 1 \\
\mathcal{F}\left(\mathcal{H}\left(X,\, Y^{l-1}\right),\, W^l,\, b^l,\, g^l\right), & \text{otherwise}
\end{cases}
$$

where $W^l = \{W^l_1, \dots, W^l_{g^l}\}$ and $b^l = \{b^l_1, \dots, b^l_{g^l}\}$ are the weights and biases of the group transformation function $\mathcal{F}$ with $g^l$ groups at layer $l$. In short, $\mathcal{F}$ takes the input $X$ and splits it into $g^l$ non-overlapping groups, such that $X = \mathrm{Concat}(X_1, \dots, X_{g^l})$. It then linearly transforms each $X_i$ with weights $W^l_i$ and bias $b^l_i$ to produce the output $Y^l_i$.
The outputs of all groups are then concatenated to produce the output $Y^l$. The function $\mathcal{H}$ first shuffles the outputs of the groups in $Y^{l-1}$ and then combines them with the input $X$ using the input mixer connection of Mehta et al. to avoid vanishing gradients.

Figure 2 visualizes the expansion phase of the DeLighT transformation with group linear transformations, feature shuffling, and the input mixer connection. The number of groups at the $l$-th GLT in the DeLighT transformation is computed as:

$$
g^l =
\begin{cases}
\min\left(2^{l-1},\, g_{max}\right), & 1 \le l \le \lceil N/2 \rceil \\
g^{N-l}, & \text{otherwise}
\end{cases}
$$

In the experiments, the authors use $g_{max} = \lceil d_m / 32 \rceil$, so that each group has at least 32 input elements.
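As a concrete illustration of this group schedule and of the grouped transformation $\mathcal{F}$ described above, here is a minimal sketch (an illustrative re-implementation, not the official DExTraUnit; the guard for the last reduction layer is an assumption made here):

import math
import torch

def glt_groups(N, d_m, g_max=None):
    # Group schedule: doubling during the first ceil(N/2) (expansion) layers,
    # then mirroring the expansion schedule during the reduction layers.
    if g_max is None:
        g_max = math.ceil(d_m / 32)  # each group sees at least 32 input elements
    half = math.ceil(N / 2)
    g = {}
    for l in range(1, N + 1):
        if l <= half:
            g[l] = min(2 ** (l - 1), g_max)
        else:
            g[l] = g[max(N - l, 1)]  # mirror; max(..., 1) guards the last layer
    return [g[l] for l in range(1, N + 1)]

def group_linear(x, weights):
    # One GLT step: split features into len(weights) non-overlapping groups,
    # transform each group with its own weight matrix, and concatenate.
    chunks = x.chunk(len(weights), dim=-1)
    return torch.cat([c @ w for c, w in zip(chunks, weights)], dim=-1)

print(glt_groups(N=8, d_m=512))  # [1, 2, 4, 8, 4, 2, 1, 1] with g_max = 16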
3.2 DeLighT Block
Figure 1b shows how the DeLighT transformation is integrated into a Transformer block to improve its efficiency. The $d_m$-dimensional inputs are first fed to the DeLighT transformation, which produces $d_o$-dimensional outputs. These $d_o$-dimensional outputs are then fed to a single-head attention module, followed by a light-weight FFN, to model their relationships.
DeLighT layer and Single-Head Attention
Suppose there is a sequence of $n$ input tokens, each of dimensionality $d_m$. These $n$ $d_m$-dimensional inputs are first fed to the DeLighT transformation to produce $n$ $d_o$-dimensional outputs, where $d_o < d_m$.
These $n$ $d_o$-dimensional outputs are then projected simultaneously by three linear layers to produce the $d_o$-dimensional queries $Q$, keys $K$, and values $V$. Scaled dot-product attention is then used to model the contextual relationships between the $n$ tokens. To enable the residual connection, the $d_o$-dimensional output of this attention operation is linearly projected back into a $d_m$-dimensional space:

$$
\text{Attention}(K, Q, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_o}}\right) V
$$

The hypothesis is that the ability of DeLighT to learn wider representations allows multi-head attention to be replaced with single-head attention. The computational costs of computing attention in the standard Transformer and in the DeLighT block are $O(d_m n^2)$ and $O(d_o n^2)$, respectively, where $d_o < d_m$;
therefore, the DeLighT block reduces the cost of computing attention by a factor of $d_m / d_o$. In the experiments, $d_o = d_m / 2$ is used, so the DeLighT block requires 2 times fewer multiply-add operations for attention than the Transformer architecture.
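As a rough check of this factor, the sketch below (illustrative only; it counts just the $QK^T$ and attention-weighted-sum multiplications) compares the attention cost of a standard block operating on $d_m$ with a DeLighT block operating on $d_o = d_m / 2$:

def attention_macs(n, d):
    # scores Q @ K^T cost ~ d * n^2, and the weighted sum over V costs ~ d * n^2 again
    return 2 * d * n * n

n, d_m = 128, 512
standard = attention_macs(n, d_m)      # multi-head attention still operates on d_m in total
delight = attention_macs(n, d_m // 2)  # single-head attention on d_o = d_m / 2
print(standard / delight)              # -> 2.0, i.e. the d_m / d_o reduction factor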
Light-weight FFN
Similar to the FFN in the Transformer, this block also consists of two linear layers. Since the DeLighT block has already incorporated wider representations through the DeLighT transformation, the functionality of the Transformer's FFN layers can be inverted. The first layer reduces the dimension of the input from $d_m$ to $d_m / r$, and the second layer expands it from $d_m / r$ back to $d_m$, where $r$ is the reduction factor (see Figure 1b). The light-weight FFN thus reduces the number of parameters and operations in the FFN. In the standard Transformer, the FFN expands the dimension by a factor of 4. The experiments use $r = 4$; the light-weight FFN therefore reduces the number of FFN parameters by a factor of 16.
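This factor of 16 can be checked by counting weights directly. Below is a minimal sketch (bias terms are ignored for simplicity) comparing the standard FFN with the light-weight FFN for $r = 4$:

import torch.nn as nn

d_m, r = 512, 4

# standard Transformer FFN: expand d_m -> 4 * d_m, then reduce back to d_m
standard_ffn = nn.Sequential(nn.Linear(d_m, 4 * d_m, bias=False), nn.ReLU(),
                             nn.Linear(4 * d_m, d_m, bias=False))

# light-weight FFN: reduce d_m -> d_m / r, then expand back to d_m
light_ffn = nn.Sequential(nn.Linear(d_m, d_m // r, bias=False), nn.ReLU(),
                          nn.Linear(d_m // r, d_m, bias=False))

def count_params(m):
    return sum(p.numel() for p in m.parameters())

print(count_params(standard_ffn) / count_params(light_ffn))  # -> 16.0 for r = 4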
Block depth
A DeLighT block stacks:
- 1) one DeLighT transformation with $N$ GLTs,
- 2) three parallel linear layers for keys, queries, and values,
- 3) one projection layer,
- 4) the two linear layers of the light-weight FFN.
Therefore, the depth of a DeLighT block is $N + 4$. Compared with the standard Transformer block (depth 4), the DeLighT block is deeper.
3.3 Block-Wise Scaling
Standard approaches to improving the performance of sequence models include increasing the model dimensions (width-wise scaling), stacking more blocks (depth-wise scaling), or both. However, such scaling is not very effective on small datasets.
For example, on the WMT'16 En-Ro corpus, replacing a Transformer-based network ($d_m = 512$) with a large Transformer ($d_m = 1024$) increases the number of parameters by roughly 4 times while barely changing performance (BLEU: 34.28 vs. 34.35). The hypothesis is that scaling model width and depth allocates parameters uniformly across blocks, which may lead to learning redundant parameters. To create deeper and wider networks, the authors scale the model at the block level (see Figure 3 below).

Scaling the DeLighT block
The DeLighT block learns deep and wide representations with the DeLighT transformation, whose depth and width are controlled by two configuration parameters: the number of GLT layers $N$ and the width multiplier $w_m$ (Figure 3a). These configuration parameters allow the number of learnable parameters inside a DeLighT block to be increased independently of the input and output dimensions. Such calibration is not possible with the standard Transformer block, whose expressiveness and capacity are a function of the input (input dimension = number of heads × head dimension).
Here, the authors introduce block-wise scaling, which creates a network of variable-sized DeLighT blocks, allocating shallower and narrower DeLighT blocks near the input and deeper and wider DeLighT blocks near the output.
To this end, two network-wide configuration parameters are introduced: the minimum $N_{min}$ and the maximum $N_{max}$ number of GLTs in a DeLighT transformation. For the $b$-th DeLighT block, the number of GLTs $N^b$ and the width multiplier $w_m^b$ of the DeLighT transformation are computed with linear scaling. With this scaling, each DeLighT block has a different depth and width (Figure 3a):

$$
N^b = N_{min} + \frac{(N_{max} - N_{min})\, b}{B - 1},
\qquad
w_m^b = w_m + \frac{(N_{max} - N_{min})\, b}{N_{min}\,(B - 1)},
\qquad 0 \le b \le B - 1
$$

where $B$ is the number of DeLighT blocks in the network.
Network depth
The depth of a Transformer block is fixed, i.e., depth = 4. Prior work has therefore associated the depth of Transformer-based networks with the number of Transformer blocks. This paper offers a different perspective for learning deeper representations, in which each block has a different size. To compute network depth, the standard definition used across different domains, including computer vision and theoretical machine learning, is adopted: the depth of a network is the number of sequential learnable layers (e.g., convolutional, linear, or group linear). Accordingly, the depths of DeLighT and Transformer networks with $B$ blocks are $\sum_{b=0}^{B-1} (N^b + 4)$ and $4B$, respectively.
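The sketch below (illustrative only, not the authors' code) applies the linear block-wise scaling rule above together with this depth definition; the integer rounding of $N^b$ is an assumption made here for illustration:

import math

def blockwise_scaling(B, N_min, N_max, w_m):
    # per-block GLT depth N_b and width multiplier w_b under linear block-wise scaling
    blocks = []
    for b in range(B):
        N_b = N_min + (N_max - N_min) * b / (B - 1)
        w_b = w_m + (N_max - N_min) * b / (N_min * (B - 1))
        blocks.append((math.ceil(N_b), round(w_b, 2)))
    # network depth: each DeLighT block contributes N_b + 4 sequential layers
    depth = sum(n + 4 for n, _ in blocks)
    return blocks, depth

blocks, depth = blockwise_scaling(B=6, N_min=4, N_max=8, w_m=2)
print(blocks)  # shallower/narrower blocks near the input, deeper/wider near the output
print(depth)   # compare with 4 * B = 24 for a standard Transformer with 6 blocks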
# NOTE: the encoder/decoder layers below are from the official DeLighT repository
# (https://github.com/sacmehta/delight, reference [2]); helpers such as DExTraUnit,
# SingleHeadAttention, get_norm_layer, get_activation_layer, get_weight_layer and the
# DEFAULT_* constants are defined in that repository.
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class DeLighTTransformerEncoderLayer(nn.Module):
    """DeLighT Encoder layer"""

    def __init__(self, args, embed_dim, width_multiplier=DEFAULT_WIDTH_MULTIPLIER,
                 dextra_depth=DEFAULT_MIN_DEXTRA_LAYERS, dextra_proj=2):
        super().__init__()
        self.embed_dim = embed_dim
        assert embed_dim % dextra_proj == 0
        self.proj_dim = embed_dim // dextra_proj

        # DeLighT (DExTra) transformation: expand-reduce with group linear transformations
        self.dextra_layer = DExTraUnit(in_features=self.embed_dim,
                                       in_proj_features=self.proj_dim,
                                       out_features=self.proj_dim,
                                       width_multiplier=width_multiplier,
                                       dextra_depth=dextra_depth,
                                       dextra_dropout=args.delight_dropout,
                                       max_glt_groups=args.delight_enc_max_groups,
                                       act_type=args.act_type,
                                       use_bias=True,
                                       norm_type=args.norm_type,
                                       glt_shuffle=args.glt_shuffle,
                                       is_iclr_version=args.define_iclr)

        # Single-head attention operating on the reduced (proj_dim) representation
        self.self_attn = SingleHeadAttention(q_in_dim=self.proj_dim,
                                             kv_in_dim=self.proj_dim,
                                             proj_dim=self.proj_dim,
                                             out_dim=self.embed_dim,
                                             dropout=args.attention_dropout,
                                             bias=True,
                                             self_attention=True,
                                             encoder_decoder_attention=False)

        self.self_attn_layer_norm = get_norm_layer(name=args.norm_type, out_features=self.embed_dim)
        self.dropout = args.dropout
        self.norm_fn = args.norm_type
        self.act_type = args.act_type
        self.activation_fn = get_activation_layer(name=args.act_type)
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.encoder_normalize_before

        # Light-weight FFN (reduce-expand instead of the standard expand-reduce)
        self.ffn_dropout = args.ffn_dropout
        ffn_red_factor = args.delight_enc_ffn_red
        assert self.embed_dim % ffn_red_factor == 0, '{}/{} should be a perfect divisor'.format(self.embed_dim,
                                                                                                ffn_red_factor)
        light_ffn_dim = self.embed_dim // ffn_red_factor
        self.fc1 = get_weight_layer(name='linear',
                                    in_features=self.embed_dim,
                                    out_features=light_ffn_dim,
                                    use_bias=True)
        self.fc2 = get_weight_layer(name='linear',
                                    in_features=light_ffn_dim,
                                    out_features=self.embed_dim,
                                    use_bias=True)
        self.final_layer_norm = get_norm_layer(name=args.norm_type, out_features=self.embed_dim)

    def __repr__(self):
        s = '{name}(in_features={embed_dim}, out_features={embed_dim}, dropout={dropout},' \
            'activation_dropout={activation_dropout}, ffn_dropout={ffn_dropout}, ' \
            'activation_fn={act_type}, norm_fn={norm_fn})'
        s += '\n \t Dextra Layer: \n \t \t {}'.format(self.dextra_layer)
        s += '\n \t Self Attention: \n \t \t {}'.format(self.self_attn)
        s += '\n \t     Light-weight FFN: \n \t     |---- {} \n \t     |---- {}'.format(self.fc1, self.fc2)
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]

    def forward(self, x, encoder_padding_mask, attn_mask: Optional[Tensor] = None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape (T_tgt, T_src), where
                T_tgt is the length of query, while T_src is the length of key,
                though here both query and key is x here,
                attn_mask[t_tgt, t_src] = 1 means when calculating embedding
                for t_tgt, t_src is excluded (or masked out), =0 means it is
                included in attention

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), -1e8)
        # DeLighT transformation followed by single-head attention
        x = self.dextra_layer(x)
        x, _ = self.self_attn(
            query=x,
            key_value=None,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Light-weight FFN
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.ffn_dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x

    def compute_macs_params(self, S=1):
        macs = 0
        n_params = 0
        macs_attn = 0

        # Layer Norms
        # MACS are zero for LayerNorm because they can be fused
        n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()])

        # Dextra layer
        dextra_layer = self.dextra_layer.compute_macs_params()
        n_params += dextra_layer['params']
        macs += (dextra_layer['macs'] * S)

        # Attn
        self_attn_layer = self.self_attn.compute_macs_params(T=S, S=S)
        macs += self_attn_layer['macs']
        n_params += self_attn_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # FFN
        fc1_layer = self.fc1.compute_macs_params()
        # scale MACS by S because S tokens can be processed in parallel
        macs += (fc1_layer['macs'] * S)
        n_params += fc1_layer['params']

        fc2_layer = self.fc2.compute_macs_params()
        # scale MACS by S because S tokens can be processed in parallel
        macs += (fc2_layer['macs'] * S)
        n_params += fc2_layer['params']

        n_params += sum([p.numel() for p in self.final_layer_norm.parameters()])

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn
        }
class DeLighTTransformerDecoderLayer(nn.Module):
    """DeLighT Decoder layer"""

    def __init__(self, args, embed_dim, width_multiplier=DEFAULT_WIDTH_MULTIPLIER,
                 dextra_depth=DEFAULT_MIN_DEXTRA_LAYERS, no_encoder_attn=False, dextra_proj=2,
                 *unused_args, **unused_kwargs):
        super().__init__()
        self.embed_dim = embed_dim
        assert embed_dim % dextra_proj == 0
        self.proj_dim = embed_dim // dextra_proj
        self.norm_fn = args.norm_type
        self.act_type = args.act_type
        self.dextra_layer_sa = DExTraUnit(in_features=self.embed_dim,
                                          in_proj_features=self.proj_dim,
                                          out_features=self.proj_dim,
                                          width_multiplier=width_multiplier,
                                          dextra_depth=dextra_depth,
                                          dextra_dropout=args.delight_dropout,
                                          max_glt_groups=args.delight_dec_max_groups,
                                          act_type=args.act_type,
                                          use_bias=True,
                                          norm_type=args.norm_type,
                                          glt_shuffle=args.glt_shuffle,
                                          is_iclr_version=args.define_iclr)
        self.self_attn = SingleHeadAttention(q_in_dim=self.proj_dim,
                                             kv_in_dim=self.proj_dim,
                                             proj_dim=self.proj_dim,
                                             out_dim=self.embed_dim,
                                             dropout=args.attention_dropout,
                                             bias=True,
                                             self_attention=True,
                                             encoder_decoder_attention=False)
        self.dropout = args.dropout
        self.activation_fn = get_activation_layer(name=args.act_type)
        self.activation_dropout = getattr(args, "activation_dropout", 0)
        if self.activation_dropout == 0:
            # for backwards compatibility with models that use args.relu_dropout
            self.activation_dropout = getattr(args, "relu_dropout", 0)
        self.normalize_before = args.decoder_normalize_before
        self.self_attn_layer_norm = get_norm_layer(name=args.norm_type, out_features=self.embed_dim)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            q_embed_dim = self.embed_dim
            self.encoder_attn = SingleHeadAttention(q_in_dim=q_embed_dim,
                                                    kv_in_dim=self.embed_dim,
                                                    proj_dim=self.proj_dim,
                                                    out_dim=self.embed_dim,
                                                    dropout=args.attention_dropout,
                                                    bias=True,
                                                    encoder_decoder_attention=True,
                                                    self_attention=False)
            self.encoder_attn_layer_norm = get_norm_layer(name=args.norm_type, out_features=self.embed_dim)

        self.ffn_dropout = args.ffn_dropout
        ffn_red_factor = args.delight_dec_ffn_red
        assert self.embed_dim % ffn_red_factor == 0, '{}/{} should be a perfect divisor'.format(self.embed_dim,
                                                                                                ffn_red_factor)
        # Feed forward network
        light_ffn_dim = self.embed_dim // ffn_red_factor
        self.fc1 = get_weight_layer(name='linear',
                                    in_features=self.embed_dim,
                                    out_features=light_ffn_dim,
                                    use_bias=True)
        self.fc2 = get_weight_layer(name='linear',
                                    in_features=light_ffn_dim,
                                    out_features=self.embed_dim,
                                    use_bias=True)
        self.final_layer_norm = get_norm_layer(name=args.norm_type, out_features=self.embed_dim)
        self.need_attn = True
        self.onnx_trace = False

    def __repr__(self):
        s = '{name}(in_features={embed_dim}, out_features={embed_dim}, dropout={dropout}, ' \
            'activation_dropout={activation_dropout}, ffn_dropout={ffn_dropout}, ' \
            'activation_fn={act_type}, norm_fn={norm_fn})'
        s += '\n \t     Dextra Layer (Query): \n \t \t {}'.format(self.dextra_layer_sa)
        s += '\n \t     Self Attention (Decoder): \n \t \t {}'.format(self.self_attn)
        if self.encoder_attn is not None:
            s += '\n \t     Encoder-Decoder Attention: \n \t \t {}'.format(self.encoder_attn)
        s += '\n \t     Light-weight FFN: \n \t     |---- {} \n \t     |---- {}'.format(self.fc1, self.fc2)
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def forward(
            self,
            x,
            encoder_out: Optional[torch.Tensor] = None,
            encoder_padding_mask: Optional[torch.Tensor] = None,
            incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
            prev_self_attn_state: Optional[List[torch.Tensor]] = None,
            prev_attn_state: Optional[List[torch.Tensor]] = None,
            self_attn_mask: Optional[torch.Tensor] = None,
            self_attn_padding_mask: Optional[torch.Tensor] = None,
            need_attn: bool = False,
            need_head_weights: bool = False,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        # apply dextra layer
        x = self.dextra_layer_sa(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        x, attn = self.self_attn(
            query=x,
            key_value=None,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
            x, attn = self.encoder_attn(
                query=x,
                key_value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # Light-weight FFN
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = F.dropout(x, p=float(self.activation_dropout), training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.ffn_dropout, training=self.training)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)

        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    def compute_macs_params(self, T=1, S=1):
        macs = 0
        n_params = 0
        macs_attn = 0

        # LayerNorm
        n_params += sum([p.numel() for p in self.self_attn_layer_norm.parameters()])

        # self attention
        self_attn_layer = self.self_attn.compute_macs_params(T=T, S=T)
        dextra_layer = self.dextra_layer_sa.compute_macs_params()
        macs += self_attn_layer['macs'] + (dextra_layer['macs'] * T)
        n_params += self_attn_layer['params'] + dextra_layer['params']
        macs_attn += self_attn_layer['macs_attn']

        # Encoder-decoder attn
        if self.encoder_attn is not None:
            # self attention scaled-dot-product Attn
            n_params += sum([p.numel() for p in self.encoder_attn_layer_norm.parameters()])
            enc_attn = self.encoder_attn.compute_macs_params(T=T, S=S)
            macs += enc_attn['macs']
            n_params += enc_attn['params']
            macs_attn += enc_attn['macs_attn']

        # FFN
        fc1_layer = self.fc1.compute_macs_params()
        macs += (fc1_layer['macs'] * T)
        n_params += fc1_layer['params']

        fc2_layer = self.fc2.compute_macs_params()
        macs += (fc2_layer['macs'] * T)
        n_params += fc2_layer['params']

        n_params += sum([p.numel() for p in self.final_layer_norm.parameters()])

        return {
            'name': self.__class__.__name__,
            'macs': macs,
            'params': n_params,
            'macs_attn': macs_attn
        }


if __name__ == '__main__':
    pass
4 Experiments
4.1 Machine Translation


4.2 Language Modeling

Without a doubt: faster and stronger!
5 References
[1]. DeLighT: Deep and Light-weight Transformer
[2]. https://github.com/sacmehta/delight