如何寫好BERT知識(shí)蒸餾的損失函數(shù)代碼(一)
大家好,我是DASOU;
今天從代碼角度深入了解一下知識(shí)蒸餾,主要核心部分就是分析一下在知識(shí)蒸餾中損失函數(shù)是如何實(shí)現(xiàn)的;
之前寫過一個(gè)關(guān)于BERT知識(shí)蒸餾的理論的文章,感興趣的朋友可以去看一下:Bert知識(shí)蒸餾系列(一):什么是知識(shí)蒸餾。
知識(shí)蒸餾一個(gè)簡(jiǎn)單的脈絡(luò)可以這么去梳理:學(xué)什么,從哪里學(xué),怎么學(xué)?
學(xué)什么:學(xué)的是老師的知識(shí),體現(xiàn)在網(wǎng)絡(luò)的參數(shù)上;
從哪里學(xué):輸入層,中間層,輸出層;
怎么學(xué):損失函數(shù)度量老師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的差異性;
從架構(gòu)上來說,BERT可以蒸餾到簡(jiǎn)單的TextCNN,LSTM等,也就可以蒸餾到TRM架構(gòu)模型,比如12層BERT到4層BERT;
之前工作中用到的是BERT蒸餾到TextCNN;
最近在往TRM蒸餾靠近,使用的是 Textbrewer 這個(gè)庫(這個(gè)庫太強(qiáng)大了);
接下來,我從代碼的角度來梳理一下知識(shí)蒸餾的核心步驟,其實(shí)最主要的就是分析一下?lián)p失函數(shù)那塊的代碼形式。
我以一個(gè)文本分類的任務(wù)為例子,在閱讀理解的過程中,最需要注意的一點(diǎn)是數(shù)據(jù)的流入流出的Shape,這個(gè)很重要,在自己寫代碼的時(shí)候,最重要的其實(shí)就是這個(gè);
首先使用的是MNLI任務(wù),也就是一個(gè)文本分類任務(wù),三個(gè)標(biāo)簽;
輸入為Batch_data:[32,128]---[Batch_size,seq_len];
老師網(wǎng)絡(luò):BERT_base:12層,Hidden_size為768;
學(xué)生網(wǎng)絡(luò):BERT_base:4層,Hidden_size為312;
首先第一個(gè)步驟是訓(xùn)練一個(gè)老師網(wǎng)絡(luò),這個(gè)沒啥可說。
其次是初始化學(xué)生網(wǎng)絡(luò),然后將輸入Batch_data流經(jīng)兩個(gè)網(wǎng)絡(luò);
在初始化學(xué)生網(wǎng)絡(luò)的時(shí)候,之前有的同學(xué)問到是如何初始化的一個(gè)BERT模型的;
關(guān)于這個(gè),最主要的是修改Config文件那里的層數(shù),由正常的12改為4,然后如果你不是從本地load參數(shù)到學(xué)生網(wǎng)絡(luò),BERT模型的類會(huì)自動(dòng)調(diào)用初始化;
關(guān)于代碼實(shí)現(xiàn),我之前寫過一個(gè)文章,大家可以看這里的代碼解析,更加的清洗一點(diǎn):Pytorch代碼驗(yàn)證--如何讓Bert在finetune小數(shù)據(jù)集時(shí)更“穩(wěn)”一點(diǎn);
然后我們來說數(shù)據(jù)首先流經(jīng)學(xué)生網(wǎng)絡(luò),我們得到兩個(gè)東西,一個(gè)是最后一層【CLS】的輸出,此時(shí)未經(jīng)softmax操作,所以是logits,維度為:[32,3]-[batch_size,label_size];
第二個(gè)東西是中間隱層的輸出,維度為:[5,32,128,312],也就是 [隱層數(shù)量,batch_size,seq_len,Hidden_size];
需要注意的是這里的隱層數(shù)量是5,因?yàn)檎5碾[層在模型定義的時(shí)候是4,然后這里是加上了embedding層;
還有一點(diǎn)需要注意的是,在度量學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)隱層差異的時(shí)候,這里是度量的seq_len,也就是對(duì)每個(gè)token的輸出都做了操作;
如果在這里我們想做類似【CLS】的輸出的時(shí)候,只需要提取最開始的一個(gè)[32,312]的向量就可以;不過,一般來說我們不這么做;
其次流經(jīng)老師網(wǎng)絡(luò),我們同樣得到兩個(gè)東西,一個(gè)是最后一層【CLS】的輸出,此時(shí)未經(jīng)softmax操作,所以是logits,維度為:[32,3]-[batch_size,label_size];
第二個(gè)東西是中間隱層的輸出,維度為:[5,32,128,768],也就是 [隱層數(shù)量,batch_size,seq_len,Hidden_size];
這里需要注意的是老師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)隱層數(shù)量不一樣,一個(gè)是768,一個(gè)是312。
這其實(shí)是一個(gè)很常見的現(xiàn)象;就是我們的學(xué)生網(wǎng)絡(luò)在減少參數(shù)的時(shí)候,不僅會(huì)變矮,有時(shí)候我們也想讓它變窄,也就是隱層的輸出會(huì)發(fā)生變化,從768變?yōu)?12;
這個(gè)維度的變化需要注意兩點(diǎn),首先就是在學(xué)生模型初始化的時(shí)候,不能套用老師網(wǎng)絡(luò)的對(duì)應(yīng)層的參數(shù),因?yàn)殡[層Hidden_size發(fā)生了變化。所以一般調(diào)用的是BERT自帶的初始化方式;
其次就是在度量學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)差異性的時(shí)候,因?yàn)榫仃嚧笮〔灰恢?,不能直接做MSE。在代碼層面上,需要做一個(gè)線性映射,才能做MSE。
而且還需要注意的一點(diǎn)是,由于老師網(wǎng)絡(luò)已經(jīng)固定不動(dòng)了,所以在做映射的時(shí)候我們是要對(duì)學(xué)生網(wǎng)路的312加一個(gè)線性層轉(zhuǎn)化到768層,也就是說這個(gè)線性層是加在了學(xué)生網(wǎng)絡(luò);
整個(gè)架構(gòu)的損失函數(shù)可以分為三種:首先對(duì)于【CLS】的輸出,使用KL散度度量差異;對(duì)于隱層輸出使用MSE和MMD損失函數(shù)進(jìn)行度量;
對(duì)于損失函數(shù)這塊的選擇,其實(shí)我覺得沒啥經(jīng)驗(yàn)可說,只能試一試;
看了很多論文加上自己的經(jīng)驗(yàn),一般來說在最后面使用KL,中間層使用MSE會(huì)更好一點(diǎn);當(dāng)然有的實(shí)驗(yàn)也會(huì)在最后一層直接用MSE;玄學(xué)。
在初看代碼的時(shí)候,MMD這個(gè)之前我沒接觸過,還特意去看了一下,關(guān)于理論我就不多說了,一會(huì)看代碼吧。
首先對(duì)【CLS】的輸出,代碼如下:
def kd_ce_loss(logits_S, logits_T, temperature=1):
if isinstance(temperature, torch.Tensor) and temperature.dim() > 0:
temperature = temperature.unsqueeze(-1)
beta_logits_T = logits_T / temperature
beta_logits_S = logits_S / temperature
p_T = F.softmax(beta_logits_T, dim=-1)
loss = -(p_T * F.log_softmax(beta_logits_S, dim=-1)).sum(dim=-1).mean()
return loss
首先對(duì)于 logits_S,就是學(xué)生網(wǎng)絡(luò)的【CLS】的輸出,logits_T就是老師網(wǎng)絡(luò)【CLS】的輸出,temperature 在代碼中默認(rèn)參數(shù)是1,例子中設(shè)置為了8;
整個(gè)代碼其實(shí)很簡(jiǎn)單,就是先做Temp的一個(gè)轉(zhuǎn)化,注意這里我們對(duì)學(xué)生網(wǎng)絡(luò)的輸出和老師網(wǎng)絡(luò)的輸出都做了轉(zhuǎn)化,然后做loss計(jì)算;
其次我們來看比較復(fù)雜的中間層的度量;
首先需要掌握一點(diǎn),就是學(xué)生網(wǎng)絡(luò)和老師網(wǎng)絡(luò)層之間的對(duì)應(yīng)關(guān)系;
學(xué)生網(wǎng)絡(luò)是4層,老師網(wǎng)絡(luò)12層,那么在對(duì)應(yīng)的時(shí)候,簡(jiǎn)單的對(duì)應(yīng)關(guān)系就是這樣的:
layer_T : 0, layer_S : 0,
layer_T : 3, layer_S : 1,
layer_T : 6, layer_S : 2,
layer_T : 9, layer_S : 3,
layer_T : 12, layer_S : 4,
這個(gè)對(duì)應(yīng)關(guān)系是需要我們認(rèn)為去設(shè)定的,將學(xué)生網(wǎng)絡(luò)的1層對(duì)應(yīng)到老師網(wǎng)絡(luò)的12層可不可以?當(dāng)然可以,但是效果不一定好;
一般來說等間隔的對(duì)應(yīng)上就好;
這個(gè)對(duì)應(yīng)關(guān)系其實(shí)還有一個(gè)用處,就是學(xué)生網(wǎng)絡(luò)在初始化的時(shí)候【假如沒有變窄,只是變矮,也就是層數(shù)變低了】,那么可以從依據(jù)這個(gè)對(duì)應(yīng)關(guān)系把權(quán)重copy過來;
學(xué)生網(wǎng)絡(luò)的隱層輸出為:[5,32,128,312],老師網(wǎng)絡(luò)隱層輸出為[5,32,128,768]
那么在代碼實(shí)現(xiàn)的時(shí)候,需要做一個(gè)zip函數(shù)把對(duì)應(yīng)層映射過去,然后每一層計(jì)算MSE,然后加起來作為損失函數(shù);
我們來看代碼:
inters_T = {feature: results_T.get(feature,[]) for feature in FEATURES}
inters_S = {feature: results_S.get(feature,[]) for feature in FEATURES}
for ith,inter_match in enumerate(self.d_config.intermediate_matches):
if type(layer_S) is list and type(layer_T) is list: ## MMD損失函數(shù)對(duì)應(yīng)的情況
inter_S = [inters_S[feature][s] for s in layer_S]
inter_T = [inters_T[feature][t] for t in layer_T]
name_S = '-'.join(map(str,layer_S))
name_T = '-'.join(map(str,layer_T))
if self.projs[ith]: ## 這里失去做學(xué)生網(wǎng)絡(luò)隱層的映射
#inter_T = [self.projs[ith](t) for t in inter_T]
inter_S = [self.projs[ith](s) for s in inter_S]
else:## MSE 損失函數(shù)
inter_S = inters_S[feature][layer_S]
inter_T = inters_T[feature][layer_T]
name_S = str(layer_S)
name_T = str(layer_T)
if self.projs[ith]:
inter_S = self.projs[ith](inter_S) # 需要注意的是隱層輸出是312,但是老師網(wǎng)絡(luò)是768,所以這里要做一個(gè)linear投影到更高維,方便計(jì)算損失函數(shù)
intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S) ## loss = F.mse_loss(state_S, state_T)
total_loss += intermediate_loss * match_weight
這個(gè)代碼里面比如迷糊的是【self.d_config.intermediate_matches】,打印出來發(fā)現(xiàn)是這個(gè)東西:
IntermediateMatch: layer_T : 0, layer_S : 0, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 3, layer_S : 1, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 6, layer_S : 2, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 9, layer_S : 3, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : 12, layer_S : 4, feature : hidden, weight : 1, loss : hidden_mse, proj : ['linear', 312, 768, {}],
IntermediateMatch: layer_T : [0, 0], layer_S : [0, 0], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [3, 3], layer_S : [1, 1], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [6, 6], layer_S : [2, 2], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [9, 9], layer_S : [3, 3], feature : hidden, weight : 1, loss : mmd, proj : None,
IntermediateMatch: layer_T : [12, 12], layer_S : [4, 4], feature : hidden, weight : 1, loss : mmd, proj : None
簡(jiǎn)單說,這個(gè)變量存儲(chǔ)的就是上面我們談到的層與層之間的對(duì)應(yīng)關(guān)系。前面5行就是MSE損失函數(shù)度量,后面那個(gè)注意看,層數(shù)對(duì)應(yīng)的時(shí)候是一個(gè)列表,對(duì)應(yīng)的是MMD損失函數(shù);
我們來看一下MMD損失的代碼形式:
def mmd_loss(state_S, state_T, mask=None):
state_S_0 = state_S[0] # (batch_size , length, hidden_dim_S)
state_S_1 = state_S[1] # (batch_size , length, hidden_dim_S)
state_T_0 = state_T[0] # (batch_size , length, hidden_dim_T)
state_T_1 = state_T[1] # (batch_size , length, hidden_dim_T)
if mask is None:
gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2) # (batch_size, length, length)
gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
loss = F.mse_loss(gram_S, gram_T)
else:
mask = mask.to(state_S[0])
valid_count = torch.pow(mask.sum(dim=1), 2).sum()
gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2) # (batch_size, length, length)
gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
loss = (F.mse_loss(gram_S, gram_T, reduction='none') * mask.unsqueeze(-1) * mask.unsqueeze(1)).sum() / valid_count
return loss
看最重要的代碼就可以:
state_S_0 = state_S[0]# 32 128 312 (batch_size , length, hidden_dim_S)
state_T_0 = state_T[0] # 32 128 768 (batch_size , length, hidden_dim_T)
gram_S = torch.bmm(state_S_0, state_S_1.transpose(1, 2)) / state_S_1.size(2)
gram_T = torch.bmm(state_T_0, state_T_1.transpose(1, 2)) / state_T_1.size(2)
簡(jiǎn)單說就是現(xiàn)在自己內(nèi)部計(jì)算bmm,然后兩個(gè)矩陣之間做mse;這里如果我沒理解錯(cuò)使用的是一個(gè)線性核函數(shù);
損失函數(shù)代碼大致就是這樣,之后有時(shí)間我寫個(gè)簡(jiǎn)單的repository,梳理一下整個(gè)流程;
