Training and Finetuning Embedding Models with Sentence Transformers v3
2024-06-07 15:33
- Sentence Transformers https://sbert.net/
Finetuning Sentence Transformers now involves several components: datasets, loss functions, training arguments, evaluators, and the new trainer itself. I'll go through each of these components in detail and provide examples of how to use them to train effective models.
為什么進(jìn)行微調(diào)? Why Finetune?
Finetuning Sentence Transformer models can significantly improve their performance on specific tasks. This is because each task requires a unique notion of similarity. Let's take a couple of news article headlines as an example:
- "Apple launches the new iPad"
- "NVIDIA is gearing up for the next GPU generation"
Depending on the use case, we might want these texts to have similar or dissimilar embeddings. For example, a classification model for news articles might treat these texts as similar because they both belong to the Technology category. On the other hand, a semantic textual similarity or retrieval model should treat them as dissimilar, because they have different meanings.
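To make this concrete, here is a minimal sketch that encodes the two headlines with a pretrained model and computes their cosine similarity; the model shown (all-mpnet-base-v2) is just an example, and the exact score is incidental:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
# Any pretrained embedding model works here; all-mpnet-base-v2 is only an example.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
embeddings = model.encode([
    "Apple launches the new iPad",
    "NVIDIA is gearing up for the next GPU generation",
])
# How "similar" these headlines are depends on what the model was trained for.
print(cos_sim(embeddings[0], embeddings[1]))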
Training Components
Training Sentence Transformer models involves the following components:
- Dataset: the data used for training and evaluation.
- Loss Function: a function that quantifies model performance and guides the optimization process.
- Training Arguments (optional): parameters that influence training performance and tracking/debugging.
- Evaluator (optional): a tool for evaluating the model before, during, or after training.
- Trainer: brings together the model, dataset, loss function, and other components for training.
Now, let's take a closer look at each of these components.
Dataset
The SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer) trains and evaluates using datasets.Dataset (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.Dataset) or datasets.DatasetDict (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.DatasetDict) instances.
Note: Many Hugging Face datasets that work out of the box with Sentence Transformers have been tagged with sentence-transformers, so you can easily find them by browsing https://hf.co/datasets?other=sentence-transformers.
Data on the Hugging Face Hub
To load data from a dataset on the Hugging Face Hub, use the load_dataset function (https://hf.co/docs/datasets/main/en/packagereference/loadingmethods#datasets.loaddataset):
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")
print(train_dataset)
"""
Dataset({
features: ['premise', 'hypothesis', 'label'],
num_rows: 942069
})
"""
Some datasets, including sentence-transformers/all-nli (https://hf.co/datasets/sentence-transformers/all-nli), have multiple subsets with different data formats, so you need to specify a subset name (such as "pair-class" above) alongside the dataset name.
Local Data (CSV, JSON, Parquet, Arrow, SQL)
If you have local data in a common file format, you can also load it with load_dataset (https://hf.co/docs/datasets/main/en/packagereference/loadingmethods#datasets.loaddataset):
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")
Local Data that Requires Preprocessing
If your local data requires preprocessing, you can use datasets.Dataset.from_dict (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.Dataset.fromdict) to initialize a dataset with lists of values:
from datasets import Dataset
anchors = []
positives = []
# Open a file, perform preprocessing, filtering, cleaning, etc.
# and append to the lists
dataset = Dataset.from_dict({
"anchor": anchors,
"positive": positives,
})
Each key in the dictionary becomes a column in the resulting dataset.
Dataset Format
It is crucial that your dataset format matches your chosen loss function. This involves checking two things:
- If your loss function requires a label (as indicated in the Loss Overview table, https://sbert.net/docs/sentencetransformer/lossoverview.html), your dataset must have a column named "label" or "score".
- All columns other than "label" or "score" are treated as inputs (as indicated in the Loss Overview table, https://sbert.net/docs/sentencetransformer/lossoverview.html). The number of these columns must match the number of valid inputs for your chosen loss function. The names of the columns are irrelevant; only their order matters.
For example, if your loss function accepts (anchor, positive, negative) triplets, then the first, second, and third dataset columns correspond to anchor, positive, and negative respectively. This means that your first and second columns must contain texts that should embed closely, and your first and third columns must contain texts that should embed far apart. That is why, depending on your loss function, the column order of your dataset matters.
Consider a dataset with columns ["text1", "text2", "label"], where the "label" column contains floating-point similarity scores. This dataset can be used with CoSENTLoss, AnglELoss, and CosineSimilarityLoss because:
- the dataset has a "label" column, which is required by these loss functions; and
- the dataset has 2 non-label columns, matching the number of inputs required by these loss functions.
If the columns in your dataset are not ordered correctly, use Dataset.select_columns (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.Dataset.selectcolumns) to reorder them, and Dataset.remove_columns (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.Dataset.removecolumns) to drop any extraneous columns, as they would otherwise be treated as inputs.
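For instance, here is a minimal sketch of preparing a dataset for a triplet loss with those two methods; the column names and texts are hypothetical:
from datasets import Dataset
# Hypothetical dataset whose columns are not yet in (anchor, positive, negative) order
dataset = Dataset.from_dict({
    "id": [0, 1],
    "negative": ["It is cold", "The dog sleeps"],
    "anchor": ["It is warm", "A cat is playing"],
    "positive": ["It is hot", "A kitten is playing"],
})
# Drop columns that are not inputs, then put the inputs in the order the loss expects
dataset = dataset.remove_columns(["id"])
dataset = dataset.select_columns(["anchor", "positive", "negative"])
print(dataset.column_names)  # ['anchor', 'positive', 'negative']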
Loss Functions
Loss functions measure how well a model performs on a given batch of data and guide the optimization process. The choice of loss function depends on your available data and your target task. See the Loss Overview (https://sbert.net/docs/sentencetransformer/lossoverview.html) to find out which loss functions are compatible with your dataset format.
Most loss functions can be initialized with just the SentenceTransformer model that you're training:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CoSENTLoss
# Load a model to train/finetune
model = SentenceTransformer("FacebookAI/xlm-roberta-base")
# Initialize the CoSENTLoss
# This loss requires pairs of text and a floating point similarity score as a label
loss = CoSENTLoss(model)
# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
"""
Dataset({
features: ['sentence1', 'sentence2', 'label'],
num_rows: 942069
})
"""
Training Arguments
The SentenceTransformerTrainingArguments (https://sbert.net/docs/packagereference/sentencetransformer/trainingargs.html#sentencetransformertrainingarguments) class lets you specify parameters that influence training performance as well as tracking/debugging parameters.
In the Sentence Transformers documentation, I've outlined some of the most useful training arguments. I recommend reading Training Overview > Training Arguments (https://sbert.net/docs/sentencetransformer/trainingoverview.html#training-arguments).
Here is an example of how SentenceTransformerTrainingArguments can be initialized:
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir="models/mpnet-base-all-nli-triplet",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
warmup_ratio=0.1,
fp16=True, # Set to False if your GPU can't handle FP16
bf16=False, # Set to True if your GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # Losses using "in-batch negatives" benefit from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="mpnet-base-all-nli-triplet", # Used in W&B if `wandb` is installed
)
Note: eval_strategy was introduced in transformers version 4.41.0. Earlier versions should use evaluation_strategy instead.
Evaluator
You can provide the SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer) with an evaluator to compute useful metrics before, during, or after training, in addition to the evaluation loss. These are the implemented evaluators that ship with Sentence Transformers:
| Evaluator | Required Data |
|---|---|
| BinaryClassificationEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#binaryclassificationevaluator | Pairs with class labels |
| EmbeddingSimilarityEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#embeddingsimilarityevaluator | Pairs with similarity scores |
| InformationRetrievalEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#informationretrievalevaluator | Queries (qid => question), corpus (cid => document), and relevant documents (qid => Set[cid]) |
| MSEEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#mseevaluator | Source sentences to embed with the teacher model and target sentences to embed with the student model. Can be the same texts. |
| ParaphraseMiningEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#paraphraseminingevaluator | Mapping of IDs to sentences, plus pairs of IDs of duplicate sentences. |
| RerankingEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#rerankingevaluator | List of {'query': '..', 'positive': [...], 'negative': [...]} dictionaries. |
| TranslationEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#translationevaluator | Pairs of sentences in two different languages. |
| TripletEvaluator https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#tripletevaluator | (anchor, positive, negative) triplets. |
Additionally, you can use SequentialEvaluator (https://sbert.net/docs/packagereference/sentencetransformer/evaluation.html#sequentialevaluator) to combine multiple evaluators into one, which can then be passed to the SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer); see the sketch after the two evaluator examples below.
If you don't have the necessary evaluation data but still want to track your model's performance on common benchmarks, you can use these evaluators with data from Hugging Face.
EmbeddingSimilarityEvaluator with STSb
The STS Benchmark (also known as STSb) is a commonly used benchmark dataset for measuring a model's understanding of the semantic textual similarity of short texts such as "A man is feeding a mouse to a snake.".
Feel free to browse the sentence-transformers/stsb dataset (https://hf.co/datasets/sentence-transformers/stsb) on Hugging Face.
from datasets import load_dataset
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
# Load the STSB dataset
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
scores=eval_dataset["score"],
main_similarity=SimilarityFunction.COSINE,
name="sts-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))
# Later, you can provide this evaluator to the trainer to get results during training
TripletEvaluator with AllNLI
AllNLI is a combination of the SNLI (https://hf.co/datasets/stanfordnlp/snli) and MultiNLI (https://hf.co/datasets/nyu-mll/multinli) datasets, both of which are Natural Language Inference datasets.
In this snippet, it is used to evaluate how often the model considers the anchor text and the entailing text to be more similar than the anchor text and the contradicting text. An example text is "An older man is drinking orange juice at a restaurant.".
Feel free to browse the sentence-transformers/all-nli dataset (https://hf.co/datasets/sentence-transformers/all-nli) on Hugging Face.
from datasets import load_dataset
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction
# Load triplets from the AllNLI dataset
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split=f"dev[:{max_samples}]")
# Initialize the evaluator
dev_evaluator = TripletEvaluator(
anchors=eval_dataset["anchor"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
main_distance_function=SimilarityFunction.COSINE,
name=f"all-nli-{max_samples}-dev",
)
# Run evaluation manually:
# print(dev_evaluator(model))
# Later, you can provide this evaluator to the trainer to get results during training
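As mentioned above, multiple evaluators can be combined with SequentialEvaluator. A minimal sketch, assuming the STSb and AllNLI evaluators defined above are stored in separate variables (here called sts_evaluator and triplet_evaluator to avoid the dev_evaluator name clash):
from sentence_transformers.evaluation import SequentialEvaluator
# Combine several evaluators into one; it runs them in sequence and reports all metrics.
combined_evaluator = SequentialEvaluator([sts_evaluator, triplet_evaluator])
# Run manually, or pass it to the trainer via the `evaluator` argument:
# print(combined_evaluator(model))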
Trainer
The SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer) brings the model, dataset(s), loss function(s), and the other components covered above together for training:
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator
# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer(
"microsoft/mpnet-base",
model_card_data=SentenceTransformerModelCardData(
language="en",
license="apache-2.0",
model_name="MPNet base trained on AllNLI triplets",
)
)
# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]
# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)
# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir="models/mpnet-base-all-nli-triplet",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
warmup_ratio=0.1,
fp16=True, # Set to False if GPU can't handle FP16
bf16=False, # Set to True if GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="mpnet-base-all-nli-triplet", # Used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
anchors=eval_dataset["anchor"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
name="all-nli-dev",
)
dev_evaluator(model)
# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# (Optional) Evaluate the trained model on the test set, after training completes
test_evaluator = TripletEvaluator(
anchors=test_dataset["anchor"],
positives=test_dataset["positive"],
negatives=test_dataset["negative"],
name="all-nli-test",
)
test_evaluator(model)
# 8. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli-triplet/final")
# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub("mpnet-base-all-nli-triplet")
In this example, I'm finetuning from microsoft/mpnet-base (https://hf.co/microsoft/mpnet-base), a base model that is not yet a Sentence Transformer model. This requires more training data than finetuning an existing Sentence Transformer model such as all-mpnet-base-v2 (https://hf.co/sentence-transformers/all-mpnet-base-v2).
After running this script, the resulting model was uploaded as tomaarsen/mpnet-base-all-nli-triplet (https://hf.co/tomaarsen/mpnet-base-all-nli-triplet), finetuned from microsoft/mpnet-base (https://hf.co/microsoft/mpnet-base).
All of this information is stored in the automatically generated model card, including the base model, language, license, evaluation results, training and evaluation dataset information, hyperparameters, training logs, and more. Without any extra effort, your uploaded models should contain all the information potential users need to decide whether your model is suitable for them.
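Once training finishes, the saved (or pushed) model loads like any other Sentence Transformer. A minimal sketch, assuming the local output path used in the script above:
from sentence_transformers import SentenceTransformer
# Load the finetuned model from the path passed to `model.save_pretrained(...)` above;
# the Hub model ID from `push_to_hub` would work here as well.
model = SentenceTransformer("models/mpnet-base-all-nli-triplet/final")
embeddings = model.encode([
    "A man is feeding a mouse to a snake.",
    "An older man is drinking orange juice at a restaurant.",
])
print(embeddings.shape)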
Callbacks
The Sentence Transformers trainer supports various transformers.TrainerCallback subclasses (https://hf.co/docs/transformers/mainclasses/callback#transformers.TrainerCallback), including:
- WandbCallback: logs training metrics to W&B if wandb is installed (https://hf.co/docs/transformers/en/mainclasses/callback#transformers.integrations.WandbCallback)
- TensorBoardCallback: logs training metrics to TensorBoard if tensorboard is accessible (https://hf.co/docs/transformers/en/mainclasses/callback#transformers.integrations.TensorBoardCallback)
- CodeCarbonCallback: tracks carbon emissions during training if codecarbon is installed (https://hf.co/docs/transformers/en/mainclasses/callback#transformers.integrations.CodeCarbonCallback)
These callbacks are used automatically, without you having to specify anything, as long as the required dependencies are installed.
For more information on these callbacks and how to create your own, see the Transformers Callbacks documentation (https://hf.co/docs/transformers/en/mainclasses/callback).
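As an illustration, here is a minimal sketch of a custom callback that prints metrics whenever an evaluation runs; the class name and behaviour are purely illustrative, and it would be passed to the trainer via the standard callbacks argument:
from transformers import TrainerCallback

class PrintMetricsCallback(TrainerCallback):
    # Called after each evaluation; `metrics` holds the evaluation results.
    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        print(f"Step {state.global_step}: {metrics}")

# Hypothetical usage:
# trainer = SentenceTransformerTrainer(..., callbacks=[PrintMetricsCallback()])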
Multi-Dataset Training
The best-performing models are often trained on multiple datasets at once. The SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer) supports this:
- Use a dictionary of datasets.Dataset (https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.Dataset) instances (or a datasets.DatasetDict, https://hf.co/docs/datasets/main/en/packagereference/mainclasses#datasets.DatasetDict) as the train_dataset and eval_dataset.
- (Optional) Use a dictionary mapping dataset names to loss functions if you want to use different losses for different datasets.
Each training/evaluation batch will only contain samples from one of the datasets. The order in which batches are sampled from the multiple datasets is defined by the MultiDatasetBatchSamplers (https://sbert.net/docs/packagereference/sentencetransformer/trainingargs.html#sentencetransformers.trainingargs.MultiDatasetBatchSamplers) enum, which can be passed to the SentenceTransformerTrainingArguments (https://sbert.net/docs/packagereference/sentencetransformer/trainingargs.html#sentencetransformertrainingarguments) via multi_dataset_batch_sampler. Valid options are:
- MultiDatasetBatchSamplers.ROUND_ROBIN: samples from each dataset in round-robin fashion until one of them is exhausted. This strategy may not use every sample from every dataset, but it ensures equal sampling from each dataset.
- MultiDatasetBatchSamplers.PROPORTIONAL (default): samples from each dataset in proportion to its size. This strategy ensures that all samples from every dataset are used, and that larger datasets are sampled from more frequently.
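For example, here is a minimal sketch of selecting the round-robin strategy via the training arguments; the output directory is a placeholder:
from sentence_transformers.training_args import (
    MultiDatasetBatchSamplers,
    SentenceTransformerTrainingArguments,
)

args = SentenceTransformerTrainingArguments(
    output_dir="models/multi-dataset-example",  # placeholder output directory
    # Sample equally from each dataset instead of proportionally to its size
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.ROUND_ROBIN,
)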
Multi-task training has proven to be highly effective. For example, Huang et al. 2024 (https://arxiv.org/pdf/2405.06932) employed MultipleNegativesRankingLoss (https://sbert.net/docs/packagereference/sentencetransformer/losses.html#multiplenegativesrankingloss) and CoSENTLoss (https://sbert.net/docs/packagereference/sentencetransformer/losses.html#cosentloss), together with a variation of MultipleNegativesRankingLoss (https://sbert.net/docs/packagereference/sentencetransformer/losses.html#multiplenegativesrankingloss), to reach state-of-the-art performance on Chinese embedding benchmarks. They additionally applied MatryoshkaLoss (https://sbert.net/docs/packagereference/sentencetransformer/losses.html#matryoshkaloss) so that the model can produce Matryoshka Embeddings (https://hf.co/blog/matryoshka).
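Here is a minimal sketch of that wrapping pattern; the base model and the list of dimensions are just examples:
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss

model = SentenceTransformer("microsoft/mpnet-base")
# The base loss is applied at several truncated embedding sizes, so the
# embeddings stay useful even when truncated to fewer dimensions.
base_loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256, 128, 64])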
Here is an example of multi-dataset training:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import CoSENTLoss, MultipleNegativesRankingLoss, SoftmaxLoss
# 1. Load a model to finetune
model = SentenceTransformer("bert-base-uncased")
# 2. Load several Datasets to train with
# (anchor, positive)
all_nli_pair_train = load_dataset("sentence-transformers/all-nli", "pair", split="train[:10000]")
# (premise, hypothesis) + label
all_nli_pair_class_train = load_dataset("sentence-transformers/all-nli", "pair-class", split="train[:10000]")
# (sentence1, sentence2) + score
all_nli_pair_score_train = load_dataset("sentence-transformers/all-nli", "pair-score", split="train[:10000]")
# (anchor, positive, negative)
all_nli_triplet_train = load_dataset("sentence-transformers/all-nli", "triplet", split="train[:10000]")
# (sentence1, sentence2) + score
stsb_pair_score_train = load_dataset("sentence-transformers/stsb", split="train[:10000]")
# (anchor, positive)
quora_pair_train = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[:10000]")
# (query, answer)
natural_questions_train = load_dataset("sentence-transformers/natural-questions", split="train[:10000]")
# Combine all datasets into a dictionary with dataset names to datasets
train_dataset = {
"all-nli-pair": all_nli_pair_train,
"all-nli-pair-class": all_nli_pair_class_train,
"all-nli-pair-score": all_nli_pair_score_train,
"all-nli-triplet": all_nli_triplet_train,
"stsb": stsb_pair_score_train,
"quora": quora_pair_train,
"natural-questions": natural_questions_train,
}
# 3. Load several Datasets to evaluate with
# (anchor, positive, negative)
all_nli_triplet_dev = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
# (sentence1, sentence2, score)
stsb_pair_score_dev = load_dataset("sentence-transformers/stsb", split="validation")
# (anchor, positive)
quora_pair_dev = load_dataset("sentence-transformers/quora-duplicates", "pair", split="train[10000:11000]")
# (query, answer)
natural_questions_dev = load_dataset("sentence-transformers/natural-questions", split="train[10000:11000]")
# Use a dictionary for the evaluation dataset too, or just use one dataset or none at all
eval_dataset = {
"all-nli-triplet": all_nli_triplet_dev,
"stsb": stsb_pair_score_dev,
"quora": quora_pair_dev,
"natural-questions": natural_questions_dev,
}
# 4. Load several loss functions to train with
# (anchor, positive), (anchor, positive, negative)
mnrl_loss = MultipleNegativesRankingLoss(model)
# (sentence_A, sentence_B) + class
softmax_loss = SoftmaxLoss(model)
# (sentence_A, sentence_B) + score
cosent_loss = CoSENTLoss(model)
# Create a mapping with dataset names to loss functions, so the trainer knows which loss to apply where
# Note: You can also just use one loss if all your training/evaluation datasets use the same loss
losses = {
"all-nli-pair": mnrl_loss,
"all-nli-pair-class": softmax_loss,
"all-nli-pair-score": cosent_loss,
"all-nli-triplet": mnrl_loss,
"stsb": cosent_loss,
"quora": mnrl_loss,
"natural-questions": mnrl_loss,
}
# 5. Define a simple trainer, although it's recommended to use one with args & evaluators
trainer = SentenceTransformerTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=losses,
)
trainer.train()
# 6. Save the trained model and optionally push it to the Hugging Face Hub
model.save_pretrained("bert-base-all-nli-stsb-quora-nq")
model.push_to_hub("bert-base-all-nli-stsb-quora-nq")
Deprecation
Prior to the Sentence Transformers v3 release, all models were trained using the SentenceTransformer.fit method (https://sbert.net/docs/packagereference/sentencetransformer/SentenceTransformer.html#sentencetransformers.SentenceTransformer.fit). From v3.0 onwards, that method uses the new SentenceTransformerTrainer (https://sbert.net/docs/packagereference/sentencetransformer/trainer.html#sentencetransformers.trainer.SentenceTransformerTrainer) behind the scenes, so older training scripts that call fit continue to work.
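For reference, here is a minimal sketch of that legacy fit-style workflow; the example data and hyperparameters are purely illustrative:
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

model = SentenceTransformer("microsoft/mpnet-base")
# The legacy API wraps InputExample objects in a standard PyTorch DataLoader
train_examples = [
    InputExample(texts=["It is warm", "It is hot"]),
    InputExample(texts=["A cat is playing", "A kitten is playing"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)
# Since v3.0, model.fit delegates to the new SentenceTransformerTrainer internally
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)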
Additional Resources
Training Examples
The following pages contain training examples with explanations and links to the code. We recommend browsing them to familiarize yourself with the training loop:
- Semantic Textual Similarity https://sbert.net/examples/training/sts/README.html
- Natural Language Inference https://sbert.net/examples/training/nli/README.html
- Paraphrases https://sbert.net/examples/training/paraphrases/README.html
- Quora Duplicate Questions https://sbert.net/examples/training/quoraduplicatequestions/README.html
- Matryoshka Embeddings https://sbert.net/examples/training/matryoshka/README.html
- Adaptive Layer Models https://sbert.net/examples/training/adaptivelayer/README.html
- Multilingual Models https://sbert.net/examples/training/multilingual/README.html
- Model Distillation https://sbert.net/examples/training/distillation/README.html
- Augmented Sentence Transformers https://sbert.net/examples/training/dataaugmentation/README.html
Documentation
Additionally, the following pages may be helpful for learning more about Sentence Transformers:
- Installation https://sbert.net/docs/installation.html
- Quickstart https://sbert.net/docs/quickstart.html
- Usage https://sbert.net/docs/sentencetransformer/usage/usage.html
- Pretrained Models https://sbert.net/docs/sentencetransformer/pretrainedmodels.html
- Training Overview (this blog post is a distillation of the Training Overview documentation) https://sbert.net/docs/sentencetransformer/trainingoverview.html
- Dataset Overview https://sbert.net/docs/sentencetransformer/datasetoverview.html
- Loss Overview https://sbert.net/docs/sentencetransformer/lossoverview.html
- API Reference https://sbert.net/docs/packagereference/sentencetransformer/index.html
Finally, here are some advanced pages you might find interesting:
- Hyperparameter Optimization https://sbert.net/examples/training/hpo/README.html
- Distributed Training https://sbert.net/docs/sentencetransformer/training/distributed.html
Original English article: https://hf.co/blog/train-sentence-transformers
Original author: Tom Aarsen
Translator: innovation64
