使用DistilBERT 蒸餾類(lèi) BERT 模型的代碼實(shí)現(xiàn)

來(lái)源:DeepHub IMBA 本文約2700字,建議閱讀9分鐘
本文帶你進(jìn)入Distil細(xì)節(jié),并給出完整的代碼實(shí)現(xiàn)。本文為你詳細(xì)介紹DistilBERT,并給出完整的代碼實(shí)現(xiàn)。
機(jī)器學(xué)習(xí)模型已經(jīng)變得越來(lái)越大,即使使用經(jīng)過(guò)訓(xùn)練的模型當(dāng)硬件不符合模型對(duì)它應(yīng)該運(yùn)行的期望時(shí),推理的時(shí)間和內(nèi)存成本也會(huì)飆升。為了緩解這個(gè)問(wèn)題是使用蒸餾可以將網(wǎng)絡(luò)縮小到合理的大小,同時(shí)最大限度地減少性能損失。
我們?cè)谝郧暗奈恼轮薪榻B過(guò) DistilBERT [1] 如何引入一種簡(jiǎn)單而有效的蒸餾技術(shù),該技術(shù)可以輕松應(yīng)用于任何類(lèi)似 BERT 的模型,但沒(méi)有給出任何的代碼實(shí)現(xiàn),在本篇文章中我們將進(jìn)入細(xì)節(jié),并給出完整的代碼實(shí)現(xiàn)。
學(xué)生模型的初始化
由于我們想從現(xiàn)有模型初始化一個(gè)新模型,所以需要訪問(wèn)舊模型的權(quán)重。本文將使用Hugging Face 提供的 RoBERTa [2] large 作為我們的教師模型,要獲得模型權(quán)重,必須知道如何訪問(wèn)它們。
Hugging Face的模型結(jié)構(gòu)
可以嘗試的第一件事是打印模型,這應(yīng)該讓我們深入了解它是如何工作的。當(dāng)然,我們也可以深入研究 Hugging Face 文檔 [3],但這太繁瑣了。
from transformers import AutoModelForMaskedLMroberta = AutoModelForMaskedLM.from_pretrained("roberta-large")print(roberta)
運(yùn)行此代碼后得到:

在 Hugging Face 模型中,可以使用 .children() 生成器訪問(wèn)模塊的子組件。因此,如果我們想使用整個(gè)模型,我們需要在它上面調(diào)用 .children() ,并在每個(gè)子節(jié)點(diǎn)上調(diào)用,這是一個(gè)遞歸函數(shù),代碼如下:
from typing import Anyfrom transformers import AutoModelForMaskedLMroberta = AutoModelForMaskedLM.from_pretrained("roberta-large")def visualize_children(object : Any,level : int = 0,) -> None:"""Prints the children of (object) and their children too, if there are any.Uses the current depth (level) to print things in a ordonnate manner."""print(f"{' ' * level}{level}- {type(object).__name__}")try:for child in object.children():visualize_children(child, level + 1)except:passvisualize_children(roberta)
這樣獲得了如下輸出:

看起來(lái) RoBERTa 模型的結(jié)構(gòu)與其他類(lèi)似 BERT 的模型一樣,如下所示:

復(fù)制教師模型的權(quán)重
要以 DistilBERT [1] 的方式初始化一個(gè)類(lèi)似 BERT 的模型,我們只需要復(fù)制除最深層的 Roberta 層之外的所有內(nèi)容,并且刪除其中的一半。所以這里的步驟如下:首先,我們需要?jiǎng)?chuàng)建學(xué)生模型,其架構(gòu)與教師模型相同,但隱藏層數(shù)減半。只需要使用教師模型的配置,這是一個(gè)類(lèi)似字典的對(duì)象,描述了Hugging Face模型的架構(gòu)。查看 roberta.config 屬性時(shí),我們可以看到以下內(nèi)容:

我們感興趣的是numhidden -layers屬性。讓我們寫(xiě)一個(gè)函數(shù)來(lái)復(fù)制這個(gè)配置,通過(guò)將其除以2來(lái)改變屬性,然后用新的配置創(chuàng)建一個(gè)新的模型:
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaConfigdef distill_roberta(teacher_model : RobertaPreTrainedModel,) -> RobertaPreTrainedModel:"""Distilates a RoBERTa (teacher_model) like would DistilBERT for a BERT model.The student model has the same configuration, except for the number of hidden layers, which is // by 2.The student layers are initilized by copying one out of two layers of the teacher, starting with layer 0.The head of the teacher is also copied."""# Get teacher configuration as a dictionnaryconfiguration = teacher_model.config.to_dict()# Half the number of hidden layerconfiguration['num_hidden_layers'] //= 2# Convert the dictionnary to the student configurationconfiguration = RobertaConfig.from_dict(configuration)# Create uninitialized student modelstudent_model = type(teacher_model)(configuration)# Initialize the student's weightsdistill_roberta_weights(teacher=teacher_model, student=student_model)# Return the student modelreturn student_model
這個(gè)函數(shù)distill_roberta_weights函數(shù)將把教師的一半權(quán)重放在學(xué)生層中,所以仍然需要對(duì)它進(jìn)行編碼。由于遞歸在探索教師模型方面工作得很好,可以使用相同的思想來(lái)探索和復(fù)制某些部分。這里將同時(shí)在老師和學(xué)生的模型中迭代,并將其從一個(gè)到另一個(gè)進(jìn)行復(fù)制。唯一需要注意的是隱藏層的部分,只復(fù)制一半。
函數(shù)如下:
from transformers.models.roberta.modeling_roberta import RobertaEncoder, RobertaModelfrom torch.nn import Moduledef distill_roberta_weights(teacher : Module,student : Module,) -> None:"""Recursively copies the weights of the (teacher) to the (student).This function is meant to be first called on a RobertaFor... model, but is then called on every children of that model recursively.The only part that's not fully copied is the encoder, of which only half is copied."""# If the part is an entire RoBERTa model or a RobertaFor..., unpack and iterateif isinstance(teacher, RobertaModel) or type(teacher).__name__.startswith('RobertaFor'):for teacher_part, student_part in zip(teacher.children(), student.children()):distill_roberta_weights(teacher_part, student_part)# Else if the part is an encoder, copy one out of every layerelif isinstance(teacher, RobertaEncoder):teacher_encoding_layers = [layer for layer in next(teacher.children())]student_encoding_layers = [layer for layer in next(student.children())]for i in range(len(student_encoding_layers)):student_encoding_layers[i].load_state_dict(teacher_encoding_layers[2*i].state_dict())# Else the part is a head or something else, copy the state_dictelse:student.load_state_dict(teacher.state_dict())
這個(gè)函數(shù)通過(guò)遞歸和類(lèi)型檢查,確保學(xué)生模型與 Roberta 層的教師安全模型相同。如果想在初始化的時(shí)候改變復(fù)制哪些層,只需要更改encoder部分的for循環(huán)就可以了。
現(xiàn)在我們有了學(xué)生模型,我們需要對(duì)其進(jìn)行訓(xùn)練。這部分相對(duì)簡(jiǎn)單,主要的問(wèn)題就是使用的損失函數(shù)。
自定義損失函數(shù)
作為對(duì) DistilBERT 訓(xùn)練過(guò)程的回顧,先看一下下圖:

請(qǐng)把注意力轉(zhuǎn)向上面寫(xiě)著“損失”的紅色大盒子。但是在詳細(xì)介紹里面是什么之前,需要知道如何收集我們要喂給它的東西。在這張圖中可以看到需要 3 個(gè)東西:標(biāo)簽、學(xué)生和教師的嵌入。標(biāo)簽已經(jīng)有了,因?yàn)槭怯斜O(jiān)督的學(xué)習(xí)。現(xiàn)在看啊可能如何得到另外兩個(gè)。
教師和學(xué)生的輸入
在這里需要一個(gè)函數(shù),給定一個(gè)類(lèi) BERT 模型的輸入,包括兩個(gè)張量 input_ids 和 attention_mask 以及模型本身,然后函數(shù)將返回該模型的 logits。由于我們使用的是 Hugging Face,這非常簡(jiǎn)單,我們需要的唯一知識(shí)就是能看懂下面的代碼:
from torch import Tensordef get_logits(model : RobertaPreTrainedModel,input_ids : Tensor,attention_mask : Tensor,) -> Tensor:"""Given a RoBERTa (model) for classification and the couple of (input_ids) and (attention_mask),returns the logits corresponding to the prediction."""return model.classifier(model.roberta(input_ids, attention_mask)[0])
學(xué)生和老師都可以使用這個(gè)函數(shù),但是第一個(gè)有梯度,第二個(gè)沒(méi)有。
損失函數(shù)的代碼實(shí)現(xiàn)
損失函數(shù)具體的介紹請(qǐng)見(jiàn)我們上次發(fā)布的文章,這里使用下面的圖片進(jìn)行解釋?zhuān)?/span>

我們所說(shuō)的“‘converging cosine-loss(收斂余弦損失)”是用于對(duì)齊兩個(gè)輸入向量的常規(guī)余弦損失。這是代碼:
import torchfrom torch.nn import CrossEntropyLoss, CosineEmbeddingLossdef distillation_loss(teacher_logits : Tensor,student_logits : Tensor,labels : Tensor,temperature : float = 1.0,) -> Tensor:"""The distillation loss for distilating a BERT-like model.The loss takes the (teacher_logits), (student_logits) and (labels) for various losses.The (temperature) can be given, otherwise it's set to 1 by default."""# Temperature and sotfmaxstudent_logits, teacher_logits = (student_logits / temperature).softmax(1), (teacher_logits / temperature).softmax(1)# Classification loss (problem-specific loss)loss = CrossEntropyLoss()(student_logits, labels)# CrossEntropy teacher-student lossloss = loss + CrossEntropyLoss()(student_logits, teacher_logits)# Cosine lossloss = loss + CosineEmbeddingLoss()(teacher_logits, student_logits, torch.ones(teacher_logits.size()[0]))# Average the loss and return itloss = loss / 3return loss
以上就是 DistilBERT 的所有關(guān)鍵思想的實(shí)現(xiàn),但是還缺少一些東西,比如 GPU 支持、整個(gè)訓(xùn)練例程等,所以最后完整的代碼會(huì)在文章的最后提供,如果需要實(shí)際使用,建議使用最后的 Distillator 類(lèi)。
結(jié)果
以這種方式提煉出來(lái)的模型最終表現(xiàn)如何呢?對(duì)于 DistilBERT,可以閱讀原始論文 [1]。對(duì)于 RoBERTa,Hugging Face 上已經(jīng)存在類(lèi)似 DistilBERT 的蒸餾版本。在 GLUE 基準(zhǔn) [4] 上,我們可以比較兩個(gè)模型:

至于時(shí)間和內(nèi)存成本,這個(gè)模型大約是 roberta-base 大小的三分之二,速度是兩倍。
總結(jié)
通過(guò)以上的代碼我們可以蒸餾任何類(lèi)似 BERT 的模型。?除此以外還有很多其他更好的方法,例如 TinyBERT [5] 或 MobileBERT [6]。如果你認(rèn)為其中一篇更適合您的需求,你應(yīng)該閱讀這些文章。甚至是完全嘗試一種新的蒸餾方法,因?yàn)檫@是一個(gè)日益發(fā)展的領(lǐng)域。
