<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【NLP】BERT蒸餾完全指南|原理/技巧/代碼

          共 3836字,需瀏覽 8分鐘

           ·

          2020-11-11 05:12

          小朋友,關(guān)于模型蒸餾,你是否有很多問號:

          • 蒸餾是什么?怎么蒸BERT?
          • BERT蒸餾有什么技巧?如何調(diào)參?
          • 蒸餾代碼怎么寫?有現(xiàn)成的嗎?

          今天rumor就結(jié)合Distilled BiLSTM/BERT-PKD/DistillBERT/TinyBERT/MobileBERT/MiniLM六大經(jīng)典模型,帶大家把BERT蒸餾整到明明白白!

          模型蒸餾原理

          Hinton在NIPS2014[1]提出了知識蒸餾(Knowledge Distillation)的概念,旨在把一個大模型或者多個模型ensemble學(xué)到的知識遷移到另一個輕量級單模型上,方便部署。簡單的說就是用小模型去學(xué)習(xí)大模型的預(yù)測結(jié)果,而不是直接學(xué)習(xí)訓(xùn)練集中的label。

          在蒸餾的過程中,我們將原始大模型稱為教師模型(teacher),新的小模型稱為學(xué)生模型(student),訓(xùn)練集中的標(biāo)簽稱為hard label,教師模型預(yù)測的概率輸出為soft label,temperature(T)是用來調(diào)整soft label的超參數(shù)。

          蒸餾這個概念之所以work,核心思想是因?yàn)?strong style="color: black;">好模型的目標(biāo)不是擬合訓(xùn)練數(shù)據(jù),而是學(xué)習(xí)如何泛化到新的數(shù)據(jù)。所以蒸餾的目標(biāo)是讓學(xué)生模型學(xué)習(xí)到教師模型的泛化能力,理論上得到的結(jié)果會比單純擬合訓(xùn)練數(shù)據(jù)的學(xué)生模型要好。

          如何蒸餾

          蒸餾發(fā)展到今天,有各種各樣的花式方法,我們先從最基本的說起。

          之前提到學(xué)生模型需要通過教師模型的輸出學(xué)習(xí)泛化能力,那對于簡單的二分類任務(wù)來說,直接拿教師預(yù)測的0/1結(jié)果會與訓(xùn)練集差不多,沒什么意義,那拿概率值是不是好一些?于是Hinton采用了教師模型的輸出概率q,同時為了更好地控制輸出概率的平滑程度,給教師模型的softmax中加了一個參數(shù)T。

          有了教師模型的輸出后,學(xué)生模型的目標(biāo)就是盡可能擬合教師模型的輸出,新loss就變成了:

          其中CE是交叉熵(Cross-Entropy),y是真實(shí)label,p是學(xué)生模型的預(yù)測結(jié)果,是蒸餾loss的權(quán)重。這里要注意的是,因?yàn)閷W(xué)生模型要擬合教師模型的分布,所以在求p時的也要使用一樣的參數(shù)T。另外,因?yàn)樵谇筇荻葧r新的目標(biāo)函數(shù)會導(dǎo)致梯度是以前的 ,所以要再乘上,不然T變了的話hard label不減?。═=1),但soft label會變。

          有同學(xué)可能會疑惑:如果可以擬合prob,那直接擬合logits可以嗎?

          當(dāng)然可以,Hinton在論文中進(jìn)行了證明,如果T很大,且logits分布的均值為0時,優(yōu)化概率交叉熵和logits的平方差是等價的。

          BERT蒸餾

          在BERT提出后,如何瘦身就成了一個重要分支。主流的方法主要有剪枝、蒸餾和量化。量化的提升有限,因此免不了采用剪枝+蒸餾的融合方法來獲取更好的效果。接下來將介紹BERT蒸餾的主要發(fā)展脈絡(luò),從各個研究看來,蒸餾的提升一方面來源于從精調(diào)階段蒸餾->預(yù)訓(xùn)練階段蒸餾,另一方面則來源于蒸餾最后一層知識->蒸餾隱層知識->蒸餾注意力矩陣

          Distilled BiLSTM

          Distilled BiLSTM[2]于2019年5月提出,作者將BERT-large蒸餾到了單層的BiLSTM中,參數(shù)量減少了100倍,速度提升了15倍,效果雖然比BERT差不少,但可以和ELMo打成平手。

          Distilled BiLSTM的教師模型采用精調(diào)過的BERT-large,學(xué)生模型采用BiLSTM+ReLU,蒸餾的目標(biāo)是hard labe的交叉熵+logits之間的MSE(作者經(jīng)過實(shí)驗(yàn)發(fā)現(xiàn)MSE比上文的更好)。

          同時因?yàn)槿蝿?wù)數(shù)據(jù)有限,作者基于以下規(guī)則進(jìn)行了10+倍的數(shù)據(jù)擴(kuò)充:

          • 用[MASK]隨機(jī)替換單詞
          • 基于POS標(biāo)簽替換單詞
          • 從樣本中隨機(jī)取出n-gram作為新的樣本

          但由于沒有消融實(shí)驗(yàn),無法知道數(shù)據(jù)增強(qiáng)給模型提升了多少最終效果。

          BERT-PKD (EMNLP2019)

          既然BERT有那么多層,是不是可以蒸餾中間層的知識,讓學(xué)生模型更好地?cái)M合呢?

          BERT-PKD[3]不同于之前的研究,提出了Patient Knowledge Distillation,即從教師模型的中間層提取知識,避免在蒸餾最后一層時擬合過快的現(xiàn)象(有過擬合的風(fēng)險(xiǎn))。

          對于中間層的蒸餾,作者采用了歸一化之后MSE,稱為PT loss。

          教師模型采用精調(diào)好的BERT-base,學(xué)生模型一個6層一個3層。為了初始化一個更好的學(xué)生模型,作者提出了兩種策略,一種是PKD-skip,即用BERT-base的第[2,4,6,8,10]層,另一種是PKD-last,采用第[7,8,9,10,11]層。最終實(shí)驗(yàn)顯示PKD-skip要略好一點(diǎn)點(diǎn)(<0.01)。

          DistillBERT (NIPS2019)

          之前的工作都是對精調(diào)后的BERT進(jìn)行蒸餾,學(xué)生模型學(xué)到的都是任務(wù)相關(guān)的知識。HuggingFace則提出了DistillBERT[4],在預(yù)訓(xùn)練階段進(jìn)行蒸餾。將尺寸減小了40%,速度提升60%,效果好于BERT-PKD,為教師模型的97%。

          DistillBERT的教師模型采用了預(yù)訓(xùn)練好的BERT-base,學(xué)生模型則是6層transformer,采用了PKD-skip的方式進(jìn)行初始化。和之前蒸餾目標(biāo)不同的是,為了調(diào)整教師和學(xué)生的隱層向量方向,作者新增了一個cosine embedding loss,蒸餾最后一層hidden的。最終損失函數(shù)由MLM loss、教師-學(xué)生最后一層的交叉熵、隱層之間的cosine loss組成。從消融實(shí)驗(yàn)可以看出,MLM loss對于學(xué)生模型的表現(xiàn)影響較小,同時初始化也是影響效果的重要因素:

          TinyBERT(EMNLP2019)

          既然精調(diào)階段、預(yù)訓(xùn)練階段都分別被蒸餾過了,理論上兩步聯(lián)合起來的效果可能會更好。

          TinyBERT[5]就提出了two-stage learning框架,分別在預(yù)訓(xùn)練和精調(diào)階段蒸餾教師模型,得到了參數(shù)量減少7.5倍,速度提升9.4倍的4層BERT,效果可以達(dá)到教師模型的96.8%,同時這種方法訓(xùn)出的6層模型甚至接近BERT-base,超過了BERT-PKD和DistillBERT。

          TinyBERT的教師模型采用BERT-base。作者參考其他研究的結(jié)論,即注意力矩陣可以捕獲到豐富的知識,提出了注意力矩陣的蒸餾,采用教師-學(xué)生注意力矩陣logits的MSE作為損失函數(shù)(這里不取attention prob是實(shí)驗(yàn)表明前者收斂更快)。另外,作者還對embedding進(jìn)行了蒸餾,同樣是采用MSE作為損失。

          于是整體的loss計(jì)算可以用下式表示:

          其中m表示層數(shù)。表示教師-學(xué)生最后一層logits的交叉熵。

          最后的實(shí)驗(yàn)中,預(yù)訓(xùn)練階段只對中間層進(jìn)行了蒸餾;精調(diào)階段則先對中間層蒸餾20個epochs,再對最后一層蒸餾3個epochs。

          上圖是各個階段的消融實(shí)驗(yàn)。GD(General Distillation)表示預(yù)訓(xùn)練蒸餾,TD(Task Distillation)表示精調(diào)階段蒸餾,DA(Data Augmentation)表示數(shù)據(jù)增強(qiáng),主要用于精調(diào)階段。從消融實(shí)驗(yàn)來看GD帶來的提升不如TD或者DA,TD和DA對最終結(jié)果的影響差不多(有種蒸了這么半天還不如多標(biāo)點(diǎn)數(shù)據(jù)的感覺=.=)。

          MobileBERT(ACL2020)

          前文介紹的模型都是層次剪枝+蒸餾的操作,MobileBERT[6]則致力于減少每層的維度,在保留24層的情況下,減少了4.3倍的參數(shù),速度提升5.5倍,在GLUE上平均只比BERT-base低了0.6個點(diǎn),效果好于TinyBERT和DistillBERT。

          MobileBERT壓縮維度的主要思想在于bottleneck機(jī)制,如下圖所示:

          其中a是標(biāo)準(zhǔn)的BERT,b是加入bottleneck的BERT-large,作為教師模型,c是加入bottleneck的學(xué)生模型。Bottleneck的原理是在transformer的輸入輸出各加入一個線性層,實(shí)現(xiàn)維度的縮放。對于教師模型,embedding的維度是512,進(jìn)入transformer后擴(kuò)大為1024,而學(xué)生模型則是從512縮小至128,使得參數(shù)量驟減。

          另外,作者發(fā)現(xiàn)在標(biāo)準(zhǔn)BERT中,多頭注意力機(jī)制MHA和非線性層FFN的參數(shù)比為1:2,這個參數(shù)比相比其他比例更好。所以為了維持比例,會在學(xué)生模型中多加幾層FFN。

          MobileBERT的蒸餾中,作者先用b的結(jié)構(gòu)預(yù)訓(xùn)練一個BERT-large,再蒸餾到24層學(xué)生模型中。蒸餾的loss有多個:

          • Feature Map Transfer:隱層的MSE
          • Attention Transfer:注意力矩陣的KL散度
          • Pre-training Distillation:

          同時作者還研究了三種不同的蒸餾策略:直接蒸餾所有層、先蒸餾中間層再蒸餾最后一層、逐層蒸餾。如下圖:

          最后的結(jié)論是逐層蒸餾效果最好,但差距最大才0.5個點(diǎn),性價比有些低了。。

          MobileBERT還有一點(diǎn)不同于之前的TinyBERT,就是預(yù)訓(xùn)練階段蒸餾之后,作者直接在MobileBERT上用任務(wù)數(shù)據(jù)精調(diào),而不需要再進(jìn)行精調(diào)階段的蒸餾,方便了很多。

          MiniLM

          之前的各種模型基本上把BERT里面能蒸餾的都蒸了個遍,但MiniLM[7]還是找到了新的藍(lán)海——蒸餾Value-Value矩陣:

          Value-Relation Transfer可以讓學(xué)生模型更深入地模仿教師模型,實(shí)驗(yàn)表明可以帶來1-2個點(diǎn)的提升。同時作者考慮到學(xué)生模型的層數(shù)、維度都可能和教師模型不同,在實(shí)驗(yàn)中只蒸餾最后一層,并且只蒸餾這兩個矩陣的KL散度,簡直是懶癌福音。

          另外,作者還引入了助教機(jī)制。當(dāng)學(xué)生模型的層數(shù)、維度都小很多時,先用一個維度小但層數(shù)和教師模型一致的助教模型蒸餾,之后再把助教的知識傳遞給學(xué)生。

          最終采用BERT-base作為教師,實(shí)驗(yàn)下來6層的學(xué)生模型比起TinyBERT和DistillBERT好了不少,基本是20年性價比數(shù)一數(shù)二的蒸餾了。

          BERT蒸餾技巧

          介紹了BERT蒸餾的幾個經(jīng)典模型之后,真正要上手前還是要把幾個問題都考慮清楚,下面就來討論一些蒸餾中的變量。

          剪層還是減維度?

          這個選擇取決于是預(yù)訓(xùn)練蒸餾還是精調(diào)蒸餾。預(yù)訓(xùn)練蒸餾的數(shù)據(jù)比較充分,可以參考MiniLM、MobileBERT或者TinyBERT那樣進(jìn)行剪層+維度縮減,如果想蒸餾中間層,又不想像MobileBERT一樣增加bottleneck機(jī)制重新訓(xùn)練一個教師模型的話可以參考TinyBERT,在計(jì)算隱層loss時增加一個線性變換,擴(kuò)大學(xué)生模型的維度:

          對于針對某項(xiàng)任務(wù)、只想蒸餾精調(diào)后BERT的情況,則推薦進(jìn)行剪層,同時利用教師模型的層對學(xué)生模型進(jìn)行初始化。從BERT-PKD以及DistillBERT的結(jié)論來看,采用skip(每隔n層選一層)的初始化策略會優(yōu)于只選前k層或后k層。

          用哪個Loss?

          看完原理后相信大家也發(fā)現(xiàn)了,基本上每個模型蒸餾都用的是不同的損失函數(shù),CE、KL、MSE、Cos魔幻組合,自己蒸餾時都不知道選哪個好。。于是rumor我強(qiáng)行梳理了一番,大家可以根據(jù)自己的任務(wù)目標(biāo)挑選:

          對于hard label,使用KL和CE是一樣的,因?yàn)?span style="cursor:pointer;">,訓(xùn)練集不變時label分布是一定的。但對于soft label則不同了,不過表中不少模型還是采用了CE,只有Distilled BiLSTM發(fā)現(xiàn)更好。個人認(rèn)為可以CE/MSE/KL都試一下,但MSE有個好處是可以避免T的調(diào)參。

          中間層輸出的蒸餾,大多數(shù)模型都采用了MSE,只有DistillBERT加入了cosine loss來對齊方向。

          注意力矩陣的蒸餾loss則比較統(tǒng)一,如果要蒸餾softmax之前的attention logits可以采用MSE,之后的attention prob可以用KL散度。

          T和如何設(shè)置?

          超參數(shù)主要控制soft label和hard label的loss比例,Distilled BiLSTM在實(shí)驗(yàn)中發(fā)現(xiàn)只使用soft label會得到最好的效果。個人建議讓soft label占比更多一些,一方面是強(qiáng)迫學(xué)生更多的教師知識,另一方面實(shí)驗(yàn)證實(shí)soft target可以起到正則化的作用,讓學(xué)生模型更穩(wěn)定地收斂。

          超參數(shù)T主要控制預(yù)測分布的平滑程度,TinyBERT實(shí)驗(yàn)發(fā)現(xiàn)T=1更好,BERT-PKD的搜索空間則是{5, 10, 20}。因此建議在1~20之間多嘗試幾次,T越大越能學(xué)到teacher模型的泛化信息。比如MNIST在對2的手寫圖片分類時,可能給2分配0.9的置信度,3是1e-6,7是1e-9,從這個分布可以看出2和3有一定的相似度,這種時候可以調(diào)大T,讓概率分布更平滑,展示teacher更多的泛化能力。

          需要逐層蒸餾嗎?

          如果不是特別追求零點(diǎn)幾個點(diǎn)的提升,建議無腦一次性蒸餾,從MobileBERT來看這個操作性價比太低了。

          蒸餾代碼實(shí)戰(zhàn)

          目前Pytorch版本的模型蒸餾有一個非常贊的開源工具TextBrewer[8],在它的src/textbrewer/losses.py文件下可以看到各種loss的實(shí)現(xiàn)。

          最后輸出層的CE/KL/MSE loss比較簡單,只需要將兩者的logits除temperature之后正常計(jì)算就可以了,以CE為例:

          def?kd_ce_loss(logits_S,?logits_T,?temperature=1):
          ????'''
          ????Calculate?the?cross?entropy?between?logits_S?and?logits_T
          ????:param?logits_S:?Tensor?of?shape?(batch_size,?length,?num_labels)?or?(batch_size,?num_labels)
          ????:param?logits_T:?Tensor?of?shape?(batch_size,?length,?num_labels)?or?(batch_size,?num_labels)
          ????:param?temperature:?A?float?or?a?tensor?of?shape?(batch_size,?length)?or?(batch_size,)
          ????'''

          ????if?isinstance(temperature,?torch.Tensor)?and?temperature.dim()?>?0:
          ????????temperature?=?temperature.unsqueeze(-1)
          ????beta_logits_T?=?logits_T?/?temperature
          ????beta_logits_S?=?logits_S?/?temperature
          ????p_T?=?F.softmax(beta_logits_T,?dim=-1)
          ????loss?=?-(p_T?*?F.log_softmax(beta_logits_S,?dim=-1)).sum(dim=-1).mean()
          ????return?loss

          對于hidden MSE的蒸餾loss,則需要去除被mask的部分,另外如果維度不一致,需要額外加一個線性變換,TextBrewer默認(rèn)輸入維度是一致的:

          def?hid_mse_loss(state_S,?state_T,?mask=None):
          ????'''
          ????*?Calculates?the?mse?loss?between?`state_S`?and?`state_T`,?which?are?the?hidden?state?of?the?models.
          ????*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.
          ????*?If?the?hidden?sizes?of?student?and?teacher?are?different,?'proj'?option?is?required?in?`inetermediate_matches`?to?match?the?dimensions.
          ????:param?torch.Tensor?state_S:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*)
          ????:param?torch.Tensor?state_T:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*)
          ????:param?torch.Tensor?mask:????tensor?of?shape??(*batch_size*,?*length*)
          ????'''

          ????if?mask?is?None:
          ????????loss?=?F.mse_loss(state_S,?state_T)
          ????else:
          ????????mask?=?mask.to(state_S)
          ????????valid_count?=?mask.sum()?*?state_S.size(-1)
          ????????loss?=?(F.mse_loss(state_S,?state_T,?reduction='none')?*?mask.unsqueeze(-1)).sum()?/?valid_count
          ????return?loss

          蒸餾attention矩陣則也要考慮mask,但注意這里要處理的維度是N*N:

          def?att_mse_loss(attention_S,?attention_T,?mask=None):
          ????'''
          ????*?Calculates?the?mse?loss?between?`attention_S`?and?`attention_T`.
          ????*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.
          ????:param?torch.Tensor?logits_S:?tensor?of?shape??(*batch_size*,?*num_heads*,?*length*,?*length*)
          ????:param?torch.Tensor?logits_T:?tensor?of?shape??(*batch_size*,?*num_heads*,?*length*,?*length*)
          ????:param?torch.Tensor?mask:?tensor?of?shape??(*batch_size*,?*length*)
          ????'''

          ????if?mask?is?None:
          ????????attention_S_select?=?torch.where(attention_S?<=?-1e-3,?torch.zeros_like(attention_S),?attention_S)
          ????????attention_T_select?=?torch.where(attention_T?<=?-1e-3,?torch.zeros_like(attention_T),?attention_T)
          ????????loss?=?F.mse_loss(attention_S_select,?attention_T_select)
          ????else:
          ????????mask?=?mask.to(attention_S).unsqueeze(1).expand(-1,?attention_S.size(1),?-1)?#?(bs,?num_of_heads,?len)
          ????????valid_count?=?torch.pow(mask.sum(dim=2),2).sum()
          ????????loss?=?(F.mse_loss(attention_S,?attention_T,?reduction='none')?*?mask.unsqueeze(-1)?*?mask.unsqueeze(2)).sum()?/?valid_count
          ????return?loss

          最后是只在DistillBERT中出現(xiàn)的cosine loss,可以直接使用pytorch的默認(rèn)接口:

          def?cos_loss(state_S,?state_T,?mask=None):
          ????'''
          ????*?Computes?the?cosine?similarity?loss?between?the?inputs.?This?is?the?loss?used?in?DistilBERT,?see?`DistilBERT?`_
          ????*?If?the?`inputs_mask`?is?given,?masks?the?positions?where?``input_mask==0``.
          ????*?If?the?hidden?sizes?of?student?and?teacher?are?different,?'proj'?option?is?required?in?`inetermediate_matches`?to?match?the?dimensions.
          ????:param?torch.Tensor?state_S:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*)
          ????:param?torch.Tensor?state_T:?tensor?of?shape??(*batch_size*,?*length*,?*hidden_size*)
          ????:param?torch.Tensor?mask:????tensor?of?shape??(*batch_size*,?*length*)
          ????'''

          ????if?mask?is??None:
          ????????state_S?=?state_S.view(-1,state_S.size(-1))
          ????????state_T?=?state_T.view(-1,state_T.size(-1))
          ????else:
          ????????mask?=?mask.to(state_S).unsqueeze(-1).expand_as(state_S).to(mask_dtype)?#(bs,len,dim)
          ????????state_S?=?torch.masked_select(state_S,?mask).view(-1,?mask.size(-1))??#(bs?*?select,?dim)
          ????????state_T?=?torch.masked_select(state_T,?mask).view(-1,?mask.size(-1))??#?(bs?*?select,?dim)

          ????target?=?state_S.new(state_S.size(0)).fill_(1)
          ????loss?=?F.cosine_embedding_loss(state_S,?state_T,?target,?reduction='mean')
          ????return?loss

          關(guān)于更多的蒸餾實(shí)戰(zhàn)經(jīng)驗(yàn),可以參考知乎@邱震宇同學(xué)的模型蒸餾技巧小結(jié)[9]。

          總結(jié)

          短暫的學(xué)習(xí)就要結(jié)束了,蒸餾雖然費(fèi)勁,但確實(shí)是目前小模型提升效果的主要方法之一,在很多研究中都有用到。另外,模型蒸餾有一個好處是可以利用大批量的無監(jiān)督數(shù)據(jù),只要能找到任務(wù)相關(guān)的,就可以蒸餾提升模型的泛化能力。標(biāo)注數(shù)據(jù)少的同學(xué)還等什么?快去試試叭!

          往期精彩回顧





          獲取一折本站知識星球優(yōu)惠券,復(fù)制鏈接直接打開:

          https://t.zsxq.com/y7uvZF6

          本站qq群704220115。

          加入微信群請掃碼:

          瀏覽 81
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評論
          圖片
          表情
          推薦
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  美女全裸18禁 | 豆花视频一区二区 | 天天爱天天日天天干 | 日逼网站黄色 | 大香蕉在线精品视频 |