<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          如何寫好BERT知識(shí)蒸餾的損失函數(shù)代碼(一)

          共 6936字,需瀏覽 14分鐘

           ·

          2021-03-26 15:46

          大家好,我是DASOU;

          今天從代碼角度深入了解一下知識(shí)蒸餾,主要核心部分就是分析一下在知識(shí)蒸餾中損失函數(shù)是如何實(shí)現(xiàn)的;

          之前寫過一個(gè)關(guān)于BERT知識(shí)蒸餾的理論的文章,感興趣的朋友可以去看一下:Bert知識(shí)蒸餾系列(一):什么是知識(shí)蒸餾

          知識(shí)蒸餾一個(gè)簡(jiǎn)單的脈絡(luò)可以這么去梳理:學(xué)什么,從哪里學(xué),怎么學(xué)?

          學(xué)什么:學(xué)的是老師的知識(shí),體現(xiàn)在網(wǎng)絡(luò)的參數(shù)上;

          從哪里學(xué):輸入層,中間層,輸出層;

          怎么學(xué):損失函數(shù)度量老師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的差異性;

          從架構(gòu)上來說,BERT可以蒸餾到簡(jiǎn)單的TextCNN,LSTM等,也就可以蒸餾到TRM架構(gòu)模型,比如12層BERT到4層BERT;

          之前工作中用到的是BERT蒸餾到TextCNN;

          最近在往TRM蒸餾靠近,使用的是 Textbrewer 這個(gè)庫(這個(gè)庫太強(qiáng)大了);

          接下來,我從代碼的角度來梳理一下知識(shí)蒸餾的核心步驟,其實(shí)最主要的就是分析一下?lián)p失函數(shù)那塊的代碼形式。

          我以一個(gè)文本分類的任務(wù)為例子,在閱讀理解的過程中,最需要注意的一點(diǎn)是數(shù)據(jù)的流入流出的Shape,這個(gè)很重要,在自己寫代碼的時(shí)候,最重要的其實(shí)就是這個(gè);

          首先使用的是MNLI任務(wù),也就是一個(gè)文本分類任務(wù),三個(gè)標(biāo)簽;

          輸入為Batch_data:[32,128]---[Batch_size,seq_len];

          老師網(wǎng)絡(luò):BERT_base:12層,Hidden_size為768;

          學(xué)生網(wǎng)絡(luò):BERT_base:4層,Hidden_size為312;

          首先第一個(gè)步驟是訓(xùn)練一個(gè)老師網(wǎng)絡(luò),這個(gè)沒啥可說。

          其次是初始化學(xué)生網(wǎng)絡(luò),然后將輸入Batch_data流經(jīng)兩個(gè)網(wǎng)絡(luò);

          在初始化學(xué)生網(wǎng)絡(luò)的時(shí)候,之前有的同學(xué)問到是如何初始化的一個(gè)BERT模型的;

          關(guān)于這個(gè),最主要的是修改Config文件那里的層數(shù),由正常的12改為4,然后如果你不是從本地load參數(shù)到學(xué)生網(wǎng)絡(luò),BERT模型的類會(huì)自動(dòng)調(diào)用初始化;

          關(guān)于代碼實(shí)現(xiàn),我之前寫過一個(gè)文章,大家可以看這里的代碼解析,更加的清洗一點(diǎn):Pytorch代碼驗(yàn)證--如何讓Bert在finetune小數(shù)據(jù)集時(shí)更“穩(wěn)”一點(diǎn)

          然后我們來說數(shù)據(jù)首先流經(jīng)學(xué)生網(wǎng)絡(luò),我們得到兩個(gè)東西,一個(gè)是最后一層【CLS】的輸出,此時(shí)未經(jīng)softmax操作,所以是logits,維度為:[32,3]-[batch_size,label_size];

          第二個(gè)東西是中間隱層的輸出,維度為:[5,32,128,312],也就是 [隱層數(shù)量,batch_size,seq_len,Hidden_size];

          需要注意的是這里的隱層數(shù)量是5,因?yàn)檎5碾[層在模型定義的時(shí)候是4,然后這里是加上了embedding層;

          還有一點(diǎn)需要注意的是,在度量學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)隱層差異的時(shí)候,這里是度量的seq_len,也就是對(duì)每個(gè)token的輸出都做了操作;

          如果在這里我們想做類似【CLS】的輸出的時(shí)候,只需要提取最開始的一個(gè)[32,312]的向量就可以;不過,一般來說我們不這么做;

          其次流經(jīng)老師網(wǎng)絡(luò),我們同樣得到兩個(gè)東西,一個(gè)是最后一層【CLS】的輸出,此時(shí)未經(jīng)softmax操作,所以是logits,維度為:[32,3]-[batch_size,label_size];

          第二個(gè)東西是中間隱層的輸出,維度為:[5,32,128,768],也就是 [隱層數(shù)量,batch_size,seq_len,Hidden_size];

          這里需要注意的是老師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)隱層數(shù)量不一樣,一個(gè)是768,一個(gè)是312。

          這其實(shí)是一個(gè)很常見的現(xiàn)象;就是我們的學(xué)生網(wǎng)絡(luò)在減少參數(shù)的時(shí)候,不僅會(huì)變矮,有時(shí)候我們也想讓它變窄,也就是隱層的輸出會(huì)發(fā)生變化,從768變?yōu)?12;

          這個(gè)維度的變化需要注意兩點(diǎn),首先就是在學(xué)生模型初始化的時(shí)候,不能套用老師網(wǎng)絡(luò)的對(duì)應(yīng)層的參數(shù),因?yàn)殡[層Hidden_size發(fā)生了變化。所以一般調(diào)用的是BERT自帶的初始化方式;

          其次就是在度量學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)差異性的時(shí)候,因?yàn)榫仃嚧笮〔灰恢?,不能直接做MSE。在代碼層面上,需要做一個(gè)線性映射,才能做MSE。

          而且還需要注意的一點(diǎn)是,由于老師網(wǎng)絡(luò)已經(jīng)固定不動(dòng)了,所以在做映射的時(shí)候我們是要對(duì)學(xué)生網(wǎng)路的312加一個(gè)線性層轉(zhuǎn)化到768層,也就是說這個(gè)線性層是加在了學(xué)生網(wǎng)絡(luò);

          整個(gè)架構(gòu)的損失函數(shù)可以分為三種:首先對(duì)于【CLS】的輸出,使用KL散度度量差異;對(duì)于隱層輸出使用MSE和MMD損失函數(shù)進(jìn)行度量;

          對(duì)于損失函數(shù)這塊的選擇,其實(shí)我覺得沒啥經(jīng)驗(yàn)可說,只能試一試;

          看了很多論文加上自己的經(jīng)驗(yàn),一般來說在最后面使用KL,中間層使用MSE會(huì)更好一點(diǎn);當(dāng)然有的實(shí)驗(yàn)也會(huì)在最后一層直接用MSE;玄學(xué)。

          在初看代碼的時(shí)候,MMD這個(gè)之前我沒接觸過,還特意去看了一下,關(guān)于理論我就不多說了,一會(huì)看代碼吧。

          首先對(duì)【CLS】的輸出,代碼如下:

          def kd_ce_loss(logits_S, logits_T, temperature=1):
          if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
          temperature = temperature.unsqueeze(-1)
          beta_logits_T = logits_T / temperature
          beta_logits_S = logits_S / temperature
          p_T = F.softmax(beta_logits_T, dim=-1)
          loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
          return loss

          首先對(duì)于 logits_S,就是學(xué)生網(wǎng)絡(luò)的【CLS】的輸出,logits_T就是老師網(wǎng)絡(luò)【CLS】的輸出,temperature 在代碼中默認(rèn)參數(shù)是1,例子中設(shè)置為了8;

          整個(gè)代碼其實(shí)很簡(jiǎn)單,就是先做Temp的一個(gè)轉(zhuǎn)化,注意這里我們對(duì)學(xué)生網(wǎng)絡(luò)的輸出和老師網(wǎng)絡(luò)的輸出都做了轉(zhuǎn)化,然后做loss計(jì)算;

          其次我們來看比較復(fù)雜的中間層的度量;

          首先需要掌握一點(diǎn),就是學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)層之間的對(duì)應(yīng)關(guān)系;

          學(xué)生網(wǎng)絡(luò)是4層,老師網(wǎng)絡(luò)12層,那么在對(duì)應(yīng)的時(shí)候,簡(jiǎn)單的對(duì)應(yīng)關(guān)系就是這樣的:

          layer_T : 0, layer_S : 0,
          layer_T : 3, layer_S : 1,
          layer_T : 6, layer_S : 2,
          layer_T : 9, layer_S : 3,
          layer_T : 12, layer_S : 4,

          這個(gè)對(duì)應(yīng)關(guān)系是需要我們認(rèn)為去設(shè)定的,將學(xué)生網(wǎng)絡(luò)的1層對(duì)應(yīng)到老師網(wǎng)絡(luò)的12層可不可以?當(dāng)然可以,但是效果不一定好;

          一般來說等間隔的對(duì)應(yīng)上就好;

          這個(gè)對(duì)應(yīng)關(guān)系其實(shí)還有一個(gè)用處,就是學(xué)生網(wǎng)絡(luò)在初始化的時(shí)候【假如沒有變窄,只是變矮,也就是層數(shù)變低了】,那么可以從依據(jù)這個(gè)對(duì)應(yīng)關(guān)系把權(quán)重copy過來;

          學(xué)生網(wǎng)絡(luò)的隱層輸出為:[5,32,128,312],老師網(wǎng)絡(luò)隱層輸出為[5,32,128,768]

          那么在代碼實(shí)現(xiàn)的時(shí)候,需要做一個(gè)zip函數(shù)把對(duì)應(yīng)層映射過去,然后每一層計(jì)算MSE,然后加起來作為損失函數(shù);

          我們來看代碼:

          inters_T = {feature: results_T.get(feature,[]) for feature in FEATURES}
          inters_S = {feature: results_S.get(feature,[]) for feature in FEATURES}

          for ith,inter_match in enumerate(self.d_config.intermediate_matches):
          if type(layer_S) is list and type(layer_T) is list: ## MMD損失函數(shù)對(duì)應(yīng)的情況
          inter_S = [inters_S[feature][s] for s in layer_S]
          inter_T = [inters_T[feature][t] for t in layer_T]
          name_S = '-'.join(map(str,layer_S))
          name_T = '-'.join(map(str,layer_T))
          if self.projs[ith]: ## 這里失去做學(xué)生網(wǎng)絡(luò)隱層的映射
          #inter_T = [self.projs[ith](t) for t in inter_T]
          inter_S = [self.projs[ith](s) for s in inter_S]
          else:## MSE 損失函數(shù)
          inter_S = inters_S[feature][layer_S]
          inter_T = inters_T[feature][layer_T]
          name_S = str(layer_S)
          name_T = str(layer_T)
          if self.projs[ith]:
          inter_S = self.projs[ith](inter_S) # 需要注意的是隱層輸出是312,但是老師網(wǎng)絡(luò)是768,所以這里要做一個(gè)linear投影到更高維,方便計(jì)算損失函數(shù)

          intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S) ## loss = F.mse_loss(state_S, state_T)
          total_loss += intermediate_loss * match_weight

          這個(gè)代碼里面比如迷糊的是【self.d_config.intermediate_matches】,打印出來發(fā)現(xiàn)是這個(gè)東西:

          IntermediateMatch: layer_T : 0, layer_S : 0, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}], 
          IntermediateMatch: layer_T : 3, layer_S : 1, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
          IntermediateMatch: layer_T : 6, layer_S : 2, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
          IntermediateMatch: layer_T : 9, layer_S : 3, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
          IntermediateMatch: layer_T : 12, layer_S : 4, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
          IntermediateMatch: layer_T : [0, 0], layer_S : [0, 0], feature : hidden, weight : 1, loss : mmd, proj : None,
          IntermediateMatch: layer_T : [3, 3], layer_S : [1, 1], feature : hidden, weight : 1, loss : mmd, proj : None,
          IntermediateMatch: layer_T : [6, 6], layer_S : [2, 2], feature : hidden, weight : 1, loss : mmd, proj : None,
          IntermediateMatch: layer_T : [9, 9], layer_S : [3, 3], feature : hidden, weight : 1, loss : mmd, proj : None,
          IntermediateMatch: layer_T : [12, 12], layer_S : [4, 4], feature : hidden, weight : 1, loss : mmd, proj : None

          簡(jiǎn)單說,這個(gè)變量存儲(chǔ)的就是上面我們談到的層與層之間的對(duì)應(yīng)關(guān)系。前面5行就是MSE損失函數(shù)度量,后面那個(gè)注意看,層數(shù)對(duì)應(yīng)的時(shí)候是一個(gè)列表,對(duì)應(yīng)的是MMD損失函數(shù);

          我們來看一下MMD損失的代碼形式:

          def mmd_loss(state_S, state_T, mask=None):
          state_S_0 = state_S[0] # (batch_size , length, hidden_dim_S)
          state_S_1 = state_S[1] # (batch_size , length, hidden_dim_S)
          state_T_0 = state_T[0] # (batch_size , length, hidden_dim_T)
          state_T_1 = state_T[1] # (batch_size , length, hidden_dim_T)
          if mask is None:
          gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2) # (batch_size, length, length)
          gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
          loss = F.mse_loss(gram_S, gram_T)
          else:
          mask = mask.to(state_S[0])
          valid_count = torch.pow(mask.sum(dim=1), 2).sum()
          gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2) # (batch_size, length, length)
          gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
          loss = (F.mse_loss(gram_S, gram_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(1)).sum() / valid_count
          return loss

          看最重要的代碼就可以:

          state_S_0 = state_S[0]#  32 128 312 (batch_size , length, hidden_dim_S)
          state_T_0 = state_T[0] # 32 128 768 (batch_size , length, hidden_dim_T)
          gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)
          gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)

          簡(jiǎn)單說就是現(xiàn)在自己內(nèi)部計(jì)算bmm,然后兩個(gè)矩陣之間做mse;這里如果我沒理解錯(cuò)使用的是一個(gè)線性核函數(shù);

          損失函數(shù)代碼大致就是這樣,之后有時(shí)間我寫個(gè)簡(jiǎn)單的repository,梳理一下整個(gè)流程;

          瀏覽 112
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  污的网站国产 | 综合操逼网 | 青娱乐 欧美在线视频 | 国产一区亚洲天堂 | 欧美一级特黄真人做受 |