【NLP】文本分類與LoRA
共 7331字,需瀏覽 15分鐘
·
2024-06-22 11:00
在這篇博客中,我們逐步進(jìn)行參數(shù)高效微調(diào)(Parameter Efficient Fine Tuning,簡稱PEFT),使用大語言模型(LLM)的低秩適配(Low Rank Adaptation,LoRA)。我們將了解如何使用參數(shù)高效微調(diào)來針對(duì)特定應(yīng)用微調(diào)選定的可訓(xùn)練參數(shù),以最低的成本和最少的基礎(chǔ)設(shè)施實(shí)現(xiàn)這一目標(biāo)。
為什么需要參數(shù)高效微調(diào)?
大語言模型(LLM)已經(jīng)針對(duì)某些任務(wù)進(jìn)行了預(yù)訓(xùn)練;我們可以在應(yīng)用程序中使用LLM來執(zhí)行它們已經(jīng)訓(xùn)練過的任何任務(wù)。然而,這些LLM在我們的環(huán)境中運(yùn)行需要非常昂貴的資源,因此需要參數(shù)高效微調(diào)。
假設(shè)我們能夠以經(jīng)濟(jì)有效的方式在我們的系統(tǒng)上使用大語言模型。這可以通過使用PEFT庫來實(shí)現(xiàn),因?yàn)樗试S我們單獨(dú)使用LLM的一些參數(shù)。
PEFT(參數(shù)高效微調(diào))
參數(shù)高效微調(diào)(Parameter Efficient Fine Tuning,簡稱PEFT)是一個(gè)庫,它允許我們在不對(duì)完整模型進(jìn)行微調(diào)的情況下使用大語言模型(LLM)來執(zhí)行任務(wù),而是對(duì)一些(額外的)參數(shù)進(jìn)行微調(diào)。完整模型的微調(diào)通常需要昂貴的計(jì)算成本,而PEFT通過微調(diào)額外參數(shù)顯著減少了計(jì)算和存儲(chǔ)成本。
參數(shù)高效微調(diào)的優(yōu)勢
-
計(jì)算和存儲(chǔ)成本降低:微調(diào)額外參數(shù)顯著減少了計(jì)算和存儲(chǔ)成本。 -
性能保持一致:與完全微調(diào)的LLM模型相比,性能沒有下降。 -
適用于CPU支持的硬件:額外參數(shù)的微調(diào)使得在CPU支持的硬件上訓(xùn)練和存儲(chǔ)LLM變得更加容易。 -
易于集成:PEFT與諸如transformers和diffusers等庫的集成,使得加載、訓(xùn)練和使用LLM進(jìn)行推理變得更加容易。
LoRA(低秩適配)
LoRA(Low Rank Adaptation)是一種低秩分解方法,旨在減少可訓(xùn)練參數(shù)的數(shù)量,從而在微調(diào)大語言模型(LLM)時(shí)降低內(nèi)存消耗。通過使用LoRA,可以更加輕松地進(jìn)行LLM的微調(diào),同時(shí)顯著減少所需的計(jì)算和存儲(chǔ)資源。
在PEFT(參數(shù)高效微調(diào))中,LoRA配置通過get_peft_model()函數(shù)封裝,以創(chuàng)建一個(gè)可訓(xùn)練的PeftModel。通過調(diào)整LoraConfig中的init_lora_weights參數(shù),可以增加或減少模型權(quán)重,從而優(yōu)化模型的性能和資源消耗。
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# 加載預(yù)訓(xùn)練的模型和分詞器
model_name = "t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# 配置LoRA參數(shù)
lora_config = LoraConfig(
r=8, # 低秩矩陣的秩
lora_alpha=32, # LoRA的alpha參數(shù)
lora_dropout=0.1, # Dropout率
init_lora_weights=0.02 # 初始化LoRA權(quán)重
)
# 獲取PEFT模型
peft_model = get_peft_model(model, lora_config)
# 對(duì)輸入進(jìn)行編碼
input_text = "Translate English to French: The weather is nice today."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# 使用PEFT模型進(jìn)行推理
outputs = peft_model.generate(input_ids)
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Translated text:", output_text)
-
r:低秩矩陣的秩。較低的秩會(huì)減少參數(shù)數(shù)量,從而降低內(nèi)存消耗。 -
lora_alpha:LoRA的alpha參數(shù),用于調(diào)整模型的學(xué)習(xí)率。 -
lora_dropout:Dropout率,有助于防止模型過擬合。 -
init_lora_weights:初始化LoRA權(quán)重的值,可以根據(jù)需要增加或減少模型權(quán)重。
LoRA文本分類案例
-
安裝依賴環(huán)境
!pip install transformers datasets evaluate accelerate peft
-
加載數(shù)據(jù)集
import torch
from transformers import RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
peft_model_name = 'roberta-base-peft'
modified_base = 'roberta-base-modified'
base_model = 'roberta-base'
dataset = load_dataset('ag_news')
tokenizer = RobertaTokenizer.from_pretrained(base_model)
def preprocess(examples):
tokenized = tokenizer(examples['text'], truncation=True, padding=True)
return tokenized
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["text"])
train_dataset=tokenized_dataset['train']
eval_dataset=tokenized_dataset['test'].shard(num_shards=2, index=0)
test_dataset=tokenized_dataset['test'].shard(num_shards=2, index=1)
# Extract the number of classess and their names
num_labels = dataset['train'].features['label'].num_classes
class_names = dataset["train"].features["label"].names
print(f"number of labels: {num_labels}")
print(f"the labels: {class_names}")
# Create an id2label mapping
# We will need this for our classifier.
id2label = {i: label for i, label in enumerate(class_names)}
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
-
常規(guī)微調(diào)
# use the same Training args for all models
training_args = TrainingArguments(
output_dir='./results',
evaluation_strategy='steps',
learning_rate=5e-5,
num_train_epochs=1,
per_device_train_batch_size=16,
)
def get_trainer(model):
return Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
full_finetuning_trainer = get_trainer(
AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label),
)
full_finetuning_trainer.train()
-
LoRA微調(diào)
model = AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label)
peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
peft_lora_finetuning_trainer = get_trainer(peft_model)
peft_lora_finetuning_trainer.train()
往期精彩回顧
交流群
歡迎加入機(jī)器學(xué)習(xí)愛好者微信群一起和同行交流,目前有機(jī)器學(xué)習(xí)交流群、博士群、博士申報(bào)交流、CV、NLP等微信群,請掃描下面的微信號(hào)加群,備注:”昵稱-學(xué)校/公司-研究方向“,例如:”張小明-浙大-CV“。請按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會(huì)請出群,謝謝理解~(也可以加入機(jī)器學(xué)習(xí)交流qq群772479961)
