【小白學(xué)習(xí)PyTorch教程】十六、在多標(biāo)簽分類(lèi)任務(wù)上 微調(diào)BERT模型
「@Author:Runsen」
BERT模型在NLP各項(xiàng)任務(wù)中大殺四方,那么我們?nèi)绾问褂眠@一利器來(lái)為我們?nèi)粘5腘LP任務(wù)來(lái)服務(wù)呢?首先介紹使用BERT做文本多標(biāo)簽分類(lèi)任務(wù)。
文本多標(biāo)簽分類(lèi)是常見(jiàn)的NLP任務(wù),文本介紹了如何使用Bert模型完成文本多標(biāo)簽分類(lèi),并給出了各自的步驟。
參考官方教程:https://pytorch.org/tutorials/intermediate/dynamic_quantization_bert_tutorial.html
復(fù)旦大學(xué)邱錫鵬老師課題組的研究論文《How to Fine-Tune BERT for Text Classification?》。
論文: https://arxiv.org/pdf/1905.05583.pdf
這篇論文的主要目的在于在文本分類(lèi)任務(wù)上探索不同的BERT微調(diào)方法并提供一種通用的BERT微調(diào)解決方法。這篇論文從三種路線進(jìn)行了探索:
(1) BERT自身的微調(diào)策略,包括長(zhǎng)文本處理、學(xué)習(xí)率、不同層的選擇等方法; (2) 目標(biāo)任務(wù)內(nèi)、領(lǐng)域內(nèi)及跨領(lǐng)域的進(jìn)一步預(yù)訓(xùn)練BERT; (3) 多任務(wù)學(xué)習(xí)。微調(diào)后的BERT在七個(gè)英文數(shù)據(jù)集及搜狗中文數(shù)據(jù)集上取得了當(dāng)前最優(yōu)的結(jié)果。
作者的實(shí)現(xiàn)代碼: https://github.com/xuyige/BERT4doc-Classification
數(shù)據(jù)集來(lái)源:https://www.kaggle.com/shivanandmn/multilabel-classification-dataset?select=train.csv
該數(shù)據(jù)集包含 6 個(gè)不同的標(biāo)簽(計(jì)算機(jī)科學(xué)、物理、數(shù)學(xué)、統(tǒng)計(jì)學(xué)、生物學(xué)、金融),以根據(jù)摘要和標(biāo)題對(duì)研究論文進(jìn)行分類(lèi)。標(biāo)簽列中的值 1 表示標(biāo)簽屬于該標(biāo)簽。每個(gè)論文有多個(gè)標(biāo)簽為 1。
Bert模型加載
Transformer 為我們提供了一個(gè)基于 Transformer 的可以微調(diào)的預(yù)訓(xùn)練網(wǎng)絡(luò)。
由于數(shù)據(jù)集是英文, 因此這里選擇加載bert-base-uncased。
具體下載鏈接:https://huggingface.co/bert-base-uncased/tree/main
from transformers import BertTokenizerFast as BertTokenizer
# 直接下載很很慢,建議下載到文件夾中
# BERT_MODEL_NAME = "bert-base-uncased"
BERT_MODEL_NAME = "model/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)
微調(diào)BERT模型
bert微調(diào)就是在預(yù)訓(xùn)練模型bert的基礎(chǔ)上只需更新后面幾層的參數(shù),這相對(duì)于從頭開(kāi)始訓(xùn)練可以節(jié)省大量時(shí)間,甚至可以提高性能,通常情況下在模型的訓(xùn)練過(guò)程中,我們也會(huì)更新bert的參數(shù),這樣模型的性能會(huì)更好。
微調(diào)BERT模型主要在D_out進(jìn)行相關(guān)的改變,去除segment層,直接采用了字符輸入,不再需要segment層。
下面是微調(diào)BERT的主要代碼
class BertClassifier(nn.Module):
def __init__(self, num_labels: int, BERT_MODEL_NAME, freeze_bert=False):
super().__init__()
self.num_labels = num_labels
self.bert = BertModel.from_pretrained(BERT_MODEL_NAME)
# hidden size of BERT, hidden size of our classifier, and number of labels to classify
D_in, H, D_out = self.bert.config.hidden_size, 50, num_labels
# Instantiate an one-layer feed-forward classifier
self.classifier = nn.Sequential(
nn.Dropout(p=0.3),
nn.Linear(D_in, H),
nn.ReLU(),
nn.Dropout(p=0.3),
nn.Linear(H, D_out),
)
# loss
self.loss_func = nn.BCEWithLogitsLoss()
if freeze_bert:
print("freezing bert parameters")
for param in self.bert.parameters():
param.requires_grad = False
def forward(self, input_ids, attention_mask, labels=None):
outputs = self.bert(input_ids, attention_mask=attention_mask)
last_hidden_state_cls = outputs[0][:, 0, :]
logits = self.classifier(last_hidden_state_cls)
if labels is not None:
predictions = torch.sigmoid(logits)
loss = self.loss_func(
predictions.view(-1, self.num_labels), labels.view(-1, self.num_labels)
)
return loss
else:
return logits
其他
關(guān)于數(shù)據(jù)預(yù)處理,DataLoader等代碼有點(diǎn)多,這里不一一列舉,需要代碼的在公眾號(hào)回復(fù):”「bert」“ 。
最后的訓(xùn)練結(jié)果如下所示:



