[Nested NER with Biaffine] Things You Didn't Know
Author: 楊夕 (Yang Xi)
Project: https://github.com/km1994/nlp_paper_study
Paper: https://www.aclweb.org/anthology/2020.acl-main.577/
Code: https://github.com/juntaoy/biaffine-ner
Code (Chinese): https://github.com/suolyer/PyTorch_BERT_Biaffine_NER
About the author: Hi everyone, I'm Yang Xi. This project collects what I have seen, thought, and learned while reading top-conference papers and reproducing classic ones. It may contain misunderstandings, and corrections are very welcome.
[Nested NER with Biaffine] Things You Didn't Know
Abstract
1. Data Processing Module
1.1 Raw data format
1.2 Data preprocessing module data_pre()
1.2.1 Main preprocessing function
1.2.2 Loading the training data: load_data(file_path)
1.2.3 Encoding the data: encoder(sentence, argument)
1.3 Converting the data into MyDataset objects
1.4 Building the data iterator
1.5 Final data format
2. Model Construction Module
2.1 Overall architecture
2.2 Embedding layer
2.3 BiLSTM
2.4 FFNN
2.5 Biaffine model
2.6 Conflict resolution
2.7 Loss function
3. Learning Rate Decay Module
4. Loss Function Definitions
4.1 span_loss definition
4.2 focal_loss definition
5. Model Training
References
Abstract
Motivation: NER research has focused on flat NER and largely ignored the problem of nested entities.
Method: In this paper, we use ideas from graph-based dependency parsing to provide our model with a global view of the input via a biaffine model. The biaffine model scores pairs of start and end tokens in a sentence, and we use these scores to explore all spans, so that the model can predict named entities accurately.
This work: We reformulate NER as the task of identifying start and end indices and assigning a category to the span defined by each such pair. Our system uses a biaffine model on top of a multi-layer BiLSTM to assign scores to all possible spans in a sentence. Instead of building a dependency tree, we then rank the candidate spans by their scores and return the top-ranked spans that comply with the constraints of flat or nested NER.
Results: We evaluate our system on three nested NER benchmarks (ACE 2004, ACE 2005, GENIA) and five flat NER corpora (CoNLL 2002 (Dutch, Spanish), CoNLL 2003 (English, German), and OntoNotes). The results show that our system achieves SoTA results on all three nested NER corpora and all five flat NER corpora, with absolute gains of up to 2.2% over the previous SoTA.
1. Data Processing Module
1.1 Raw data format
The raw data looks like this:
{
"text": "當希望工程救助的百萬兒童成長起來,科教興國蔚然成風時,今天有收藏價值的書你沒買,明日就叫你悔不當初!",
"entity_list": []
}
{
"text": "藏書本來就是所有傳統(tǒng)收藏門類中的第一大戶,只是我們結束溫飽的時間太短而已。",
"entity_list": []
}
{
"text": "因有關日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。",
"entity_list":
[
{"type": "ns", "argument": "北京"}
]
}
...
1.2 Data preprocessing module data_pre()
1.2.1 Main preprocessing function
Steps:
load the data;
encode the data and convert it into the training format.
Code:
def data_pre(file_path):
sentences, arguments = load_data(file_path)
data = []
for i in tqdm(range(len(sentences))):
encode_sent, token_type_ids, attention_mask, span_label, span_mask = encoder(
sentences[i], arguments[i])
tmp = {}
tmp['input_ids'] = encode_sent
tmp['input_seg'] = token_type_ids
tmp['input_mask'] = attention_mask
tmp['span_label'] = span_label
tmp['span_mask'] = span_mask
data.append(tmp)
return data
Output:
data[0:2]:
[
{
'input_ids': [
101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_seg': [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_mask': [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'span_label': array(
[
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]
]
),
'span_mask': [
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], ...
]
}, ...
]
1.2.2 Loading the training data: load_data(file_path)
Code:
def load_data(file_path):
with open(file_path, 'r', encoding='utf8') as f:
lines = f.readlines()
sentences = []
arguments = []
for line in lines:
data = json.loads(line)
text = data['text']
entity_list = data['entity_list']
args_dict={}
if entity_list != []:
for entity in entity_list:
entity_type = entity['type']
entity_argument=entity['argument']
args_dict[entity_type] = entity_argument
sentences.append(text)
arguments.append(args_dict)
return sentences, arguments
Output:
print(f"sentences[0:2]:{sentences[0:2]}")
print(f"arguments[0:2]:{arguments[0:2]}")
>>>
sentences[0:2]:['因有關日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。', '我們藏有一冊1945年 6月油印的《北京文物保存保管狀態(tài)之調查報告》,調查范圍涉及故宮、歷博、古研所、北大清華圖書館、北圖、日偽資料庫等二十幾家,言及文物二十萬件以上,洋洋三萬余言,是珍貴的北京史料。']
arguments[0:2]:[{'ns': '北京'}, {'ns': '北京', 'nt': '古研所'}]
1.2.3 Encoding the data: encoder(sentence, argument)
Code:
# step 1: get the BERT tokenizer
tokenizer = tools.get_tokenizer()
# step 2: get the label <-> id mappings
label2id, id2label, num_labels = tools.load_schema()

def encoder(sentence, argument):
    # step 3: encode the sentence with the tokenizer
    encode_dict = tokenizer.encode_plus(
        sentence,
        max_length=args.max_length,
        pad_to_max_length=True,
        truncation=True
    )
    encode_sent = encode_dict['input_ids']
    token_type_ids = encode_dict['token_type_ids']
    attention_mask = encode_dict['attention_mask']
    # step 4: build span_mask (one attention_mask row per real token, all-zero rows for padding)
    zero = [0 for i in range(args.max_length)]
    span_mask = [attention_mask for i in range(sum(attention_mask))]
    span_mask.extend([zero for i in range(sum(attention_mask), args.max_length)])
    # step 5: build span_label, a max_length x max_length matrix whose cell (start, end)
    # holds the entity-type id of the span
    span_label = [0 for i in range(args.max_length)]
    span_label = [span_label for i in range(args.max_length)]
    span_label = np.array(span_label)
    for entity_type, arg in argument.items():
        encode_arg = tokenizer.encode(arg)
        start_idx = tools.search(encode_arg[1:-1], encode_sent)
        end_idx = start_idx + len(encode_arg[1:-1]) - 1
        span_label[start_idx, end_idx] = label2id[entity_type] + 1
    return encode_sent, token_type_ids, attention_mask, span_label, span_mask
Steps:
get the BERT tokenizer;
get the label <-> id mappings;
encode_plus returns all the encoding information.
encode_dict:
{
'input_ids': [101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
}
Notes:
'input_ids': the ids of the tokens in the vocabulary;
'token_type_ids': segment ids used to distinguish the two sentences of a pair;
'attention_mask': marks which positions are real tokens (1) that take part in self-attention, versus padding (0).
span_mask generation (a small illustration is given at the end of this subsection);
span_label generation
Description: this step builds an args.max_length x args.max_length matrix that locates each span in the sentence by its start and end positions: the row index is the span's start position, the column index is its end position, and the cell value is the id of the span's entity type.
Example:
>>>
import numpy as np
span_label = [0 for i in range(10)]
span_label = [span_label for i in range(10)]
span_label = np.array(span_label)
start = [1, 3, 7]
end = [ 2,9, 9]
label2id = [1,2,4]
for i in range(len(label2id)):
span_label[start[i], end[i]] = label2id[i]
>>>
array( [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
> Note: the row index is the start position, the column index is the end position, and the value is the label id.
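For span_mask, here is a similar small illustration (not from the repo; a minimal sketch assuming max_length = 5 and a sentence with 3 real tokens) of the construction used in step 4 of encoder(): each row corresponding to a real token repeats attention_mask, and the remaining rows are all zeros, so only (start, end) cells where both positions are real tokens survive the mask.
import numpy as np

max_length = 5
attention_mask = [1, 1, 1, 0, 0]   # 3 real tokens, 2 padding positions

zero = [0 for _ in range(max_length)]
span_mask = [attention_mask for _ in range(sum(attention_mask))]
span_mask.extend([zero for _ in range(sum(attention_mask), max_length)])

print(np.array(span_mask))
# [[1 1 1 0 0]
#  [1 1 1 0 0]
#  [1 1 1 0 0]
#  [0 0 0 0 0]
#  [0 0 0 0 0]]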
1.3 Converting the data into MyDataset objects
Convert each field into a torch.tensor:
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
one_data = {
"input_ids": torch.tensor(item['input_ids']).long(),
"input_seg": torch.tensor(item['input_seg']).long(),
"input_mask": torch.tensor(item['input_mask']).float(),
"span_label": torch.tensor(item['span_label']).long(),
"span_mask": torch.tensor(item['span_mask']).long()
}
return one_data
1.4 Building the data iterator
def yield_data(file_path):
tmp = MyDataset(data_pre(file_path))
return DataLoader(tmp, batch_size=args.batch_size, shuffle=True)
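To sanity-check the shapes the model will receive, one can peek at a single batch (an illustrative snippet; it assumes args.max_length is the padded length used above):
train_loader = yield_data(args.train_path)
batch = next(iter(train_loader))
print(batch["input_ids"].shape)    # (batch_size, max_length)
print(batch["input_mask"].shape)   # (batch_size, max_length)
print(batch["span_label"].shape)   # (batch_size, max_length, max_length)
print(batch["span_mask"].shape)    # (batch_size, max_length, max_length)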
1.5 Final data format
data[0:2]:
[
{
'input_ids': [
101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_seg': [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_mask': [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'span_label': array(
[
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]
]
),
'span_mask': [
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], ...
]
}, ...
]
2. Model Construction Module
2.1 Overall architecture

The model consists of four parts: an embedding layer, a BiLSTM, two FFNNs, and a biaffine model. A simplified sketch of how these parts connect is given below.
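As a rough overview (a simplified sketch, not the repo's exact myModel; it reuses the biaffine class defined in section 2.5 and illustrative hidden sizes), the forward pass chains the four parts like this:
import torch
import torch.nn as nn
from transformers import BertModel

class BiaffineNERSketch(nn.Module):
    """Sketch: BERT embeddings -> BiLSTM -> start/end FFNNs -> biaffine scores."""
    def __init__(self, pre_train_dir: str, num_label: int):
        super().__init__()
        self.encoder = BertModel.from_pretrained(pre_train_dir)     # embedding layer
        self.lstm = nn.LSTM(768, 768, num_layers=1,
                            batch_first=True, bidirectional=True)   # BiLSTM
        self.start_layer = nn.Sequential(nn.Linear(2 * 768, 128), nn.ReLU())  # FFNN for span starts
        self.end_layer = nn.Sequential(nn.Linear(2 * 768, 128), nn.ReLU())    # FFNN for span ends
        self.biaffine_layer = biaffine(128, num_label)              # biaffine scorer (section 2.5)

    def forward(self, input_ids, input_mask, input_seg):
        rep = self.encoder(input_ids=input_ids, attention_mask=input_mask,
                           token_type_ids=input_seg)[0]             # (batch, seq_len, 768)
        rep, _ = self.lstm(rep)                                     # (batch, seq_len, 2*768)
        hs = self.start_layer(rep)                                  # (batch, seq_len, 128)
        he = self.end_layer(rep)                                    # (batch, seq_len, 128)
        return self.biaffine_layer(hs, he)                          # (batch, seq_len, seq_len, num_label)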
2.2 Embedding layer
BERT: following Kantor and Globerson (2019), context-dependent embeddings are obtained for each target token using 64 surrounding tokens on each side;
character-based word embeddings: a CNN encodes the characters of each token.
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
super().__init__()
self.roberta_encoder = BertModel.from_pretrained(pre_train_dir)
self.roberta_encoder.resize_token_embeddings(len(tokenizer))
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
bert_output = self.roberta_encoder(input_ids=input_ids,
attention_mask=input_mask,
token_type_ids=input_seg)
encoder_rep = bert_output[0]
...
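Note that this reimplementation relies on BERT alone; the character-based word embeddings used in the original paper are not included. A minimal sketch of such a char-CNN encoder (hypothetical, for illustration only; char_vocab_size, char_dim and out_dim are assumed values) could look like this:
import torch
import torch.nn as nn

class CharCNN(nn.Module):
    """Character-level word encoder: embed the characters of each word, run a
    1-D convolution over them, and max-pool to one vector per word (a sketch,
    not part of the repo's code)."""
    def __init__(self, char_vocab_size=128, char_dim=50, out_dim=50, kernel_size=3):
        super().__init__()
        self.char_emb = nn.Embedding(char_vocab_size, char_dim, padding_idx=0)
        self.conv = nn.Conv1d(char_dim, out_dim, kernel_size, padding=kernel_size // 2)

    def forward(self, char_ids):                          # (batch, n_words, n_chars)
        b, w, c = char_ids.shape
        x = self.char_emb(char_ids.view(b * w, c))        # (b*w, n_chars, char_dim)
        x = self.conv(x.transpose(1, 2))                  # (b*w, out_dim, n_chars)
        x = torch.max(x, dim=-1).values                   # max-pool over the characters
        return x.view(b, w, -1)                           # (batch, n_words, out_dim)
In the paper this per-word character vector is concatenated with the word embeddings before the BiLSTM (see section 2.3).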
2.3 BiLSTM
The character and word embeddings are concatenated and fed into a BiLSTM to obtain the word representations.
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
super().__init__()
...
self.lstm=torch.nn.LSTM(input_size=768,hidden_size=768, \
num_layers=1,batch_first=True, \
dropout=0.5,bidirectional=True)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
encoder_rep,_ = self.lstm(encoder_rep)
...
2.4 FFNN
After obtaining the word representations from the BiLSTM, two separate FFNNs create different representations (hs / he) for the start and end of each span. Using different representations for span starts and ends lets the system learn to identify the two positions separately. This improves accuracy compared to models that use the LSTM outputs directly, since the contexts of entity starts and ends differ.
$$h_s(i) = \mathrm{FFNN}_s(x_{s_i}), \qquad h_e(i) = \mathrm{FFNN}_e(x_{e_i})$$
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
self.start_layer = torch.nn.Sequential(
torch.nn.Linear(in_features=2*768, out_features=128),
torch.nn.ReLU()
)
self.end_layer = torch.nn.Sequential(
torch.nn.Linear(in_features=2*768, out_features=128),
torch.nn.ReLU()
)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
start_logits = self.start_layer(encoder_rep)
end_logits = self.end_layer(encoder_rep)
...
2.5 Biaffine model
A biaffine model over the sentence produces an l x l x c scoring tensor r_m, where l is the length of the sentence and c is the number of NER categories + 1 (for non-entity):
$$r_m(i) = h_s(i)^\top U_m\, h_e(i) + W_m\,(h_s(i) \oplus h_e(i)) + b_m$$
where s_i and e_i are the start and end indices of span i, U_m is a d x c x d tensor, W_m is a 2d x c matrix, and b_m is the bias.
Definition:
class biaffine(nn.Module):
    def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
        super().__init__()
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.out_size = out_size
        # U has shape (in_size + 1, out_size, in_size + 1) when both biases are used
        self.U = torch.nn.Parameter(torch.Tensor(in_size + int(bias_x), out_size, in_size + int(bias_y)))

    def forward(self, x, y):
        # append a constant 1 feature so the biaffine product also covers the
        # linear terms and the bias of the scoring function
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
        # (b, l, d+1) x (d+1, c, d+1) x (b, l, d+1) -> (b, l, l, c)
        bilinar_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
        return bilinar_mapping
Invocation:
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
self.biaffne_layer = biaffine(128,num_label)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
span_logits = self.biaffne_layer(start_logits,end_logits)
span_logits = span_logits.contiguous()
...
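A quick shape check of the biaffine layer (an illustrative snippet with assumed sizes d = 128 and num_label = 10; the parameter U is left uninitialized here, so only the shapes are meaningful) confirms the l x l x c scoring tensor described above:
import torch

batch_size, seq_len, d, num_label = 2, 8, 128, 10
layer = biaffine(in_size=d, out_size=num_label)

hs = torch.randn(batch_size, seq_len, d)   # start representations
he = torch.randn(batch_size, seq_len, d)   # end representations
scores = layer(hs, he)
print(scores.shape)                        # torch.Size([2, 8, 8, 10])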
2.6 Conflict resolution
The tensor r_m provides scores for all possible spans that could form a named entity, under the constraint s_i <= e_i (the start of an entity precedes its end). We assign each span a NER category y'(i):
$$y'(i) = \arg\max\, r_m(i)$$
We then rank all spans whose predicted category is not "non-entity" in descending order of their category score and apply the following post-processing constraint: for nested NER, a span is selected as long as its boundaries do not clash with those of a higher-ranked selected entity. Entity i clashes with entity j if s_i < s_j <= e_i < e_j or s_j < s_i <= e_j < e_i; in that case only the span with the higher category score is selected.
e.g.:
In the sentence "In the Bank of China", the boundaries of the entity "the Bank" clash with those of the entity "Bank of China".
Note: for flat NER we apply one further constraint: any entity that contains, or is contained within, a higher-ranked selected entity is not selected. The learning objective of the named entity recognizer is to assign the correct category (including non-entity) to each valid span. A minimal decoding sketch of this post-processing is given below.
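The repo itself does not ship this decoding step as a separate function; the following is a minimal sketch of the post-processing described above (a hypothetical helper operating on a single sentence's softmax output of shape l x l x c, where class 0 is "non-entity"):
import torch

def decode_spans(span_prob, flat_ner=False):
    """Rank candidate spans by their best class score and greedily keep those
    that do not clash with an already selected, higher-ranked span."""
    scores, labels = span_prob.max(dim=-1)              # best class and its score per (start, end)
    length = span_prob.size(0)
    candidates = [
        (s, e, labels[s, e].item(), scores[s, e].item())
        for s in range(length)
        for e in range(s, length)                       # enforce s <= e
        if labels[s, e].item() != 0                     # drop cells predicted as non-entity
    ]
    candidates.sort(key=lambda x: x[3], reverse=True)   # descending category score

    selected = []
    for s, e, lab, sc in candidates:
        clash = False
        for s2, e2, _, _ in selected:
            # partial (boundary) overlap is never allowed
            if s < s2 <= e < e2 or s2 < s <= e2 < e:
                clash = True
            # for flat NER, containment is not allowed either
            if flat_ner and (s2 <= s <= e <= e2 or s <= s2 <= e2 <= e):
                clash = True
        if not clash:
            selected.append((s, e, lab, sc))
    return selected                                     # list of (start, end, label_id, score)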
2.7 Loss function
Since the task is a multi-class classification problem, the span scores are normalized with a softmax over the categories:

class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
span_prob = torch.nn.functional.softmax(span_logits, dim=-1)
if is_training:
return span_logits
else:
return span_prob
3. Learning Rate Decay Module
class WarmUp_LinearDecay:
def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_epoch, decay_epoch, min_lr_rate=1e-8):
self.optimizer = optimizer
self.init_rate = init_rate
self.epoch_step = train_data_length / args.batch_size
self.warm_up_steps = self.epoch_step * warm_up_epoch
self.decay_steps = self.epoch_step * decay_epoch
self.min_lr_rate = min_lr_rate
self.optimizer_step = 0
self.all_steps = args.epoch*(train_data_length/args.batch_size)
def step(self):
self.optimizer_step += 1
if self.optimizer_step <= self.warm_up_steps:
rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate
elif self.warm_up_steps < self.optimizer_step <= self.decay_steps:
rate = self.init_rate
else:
rate = (1.0 - ((self.optimizer_step - self.decay_steps) / (self.all_steps-self.decay_steps))) * self.init_rate
if rate < self.min_lr_rate:
rate = self.min_lr_rate
for p in self.optimizer.param_groups:
p["lr"] = rate
self.optimizer.step()
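Note that schedule.step() sets the new learning rate on every parameter group and then calls optimizer.step() itself, so the training loop in section 5 calls schedule.step() instead of optimizer.step(). A minimal usage sketch (assuming args, train_data_length and the model from the other sections are available; compute_span_loss is a placeholder):
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
schedule = WarmUp_LinearDecay(
    optimizer=optimizer,
    init_rate=args.learning_rate,
    warm_up_epoch=args.warm_up_epoch,
    decay_epoch=args.decay_epoch
)

for batch in train_data:
    loss = compute_span_loss(batch)   # placeholder for the loss computation in section 5
    optimizer.zero_grad()
    loss.backward()
    schedule.step()                   # warm-up/decay lr update, then optimizer.step()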
4. Loss Function Definitions
4.1 span_loss definition
Core idea: for the start and end positions of all entities learned by the model, construct a start-end matching task: decide whether a given start position and end position pair up to form an entity (predict 1 if they do, 0 otherwise). This essentially turns span extraction into a classification problem over (start, end) cells, where the positive samples are the start-end pairs of real entities and the negative samples are position pairs that do not form an entity.
import torch
from torch import nn
from utils.arguments_parse import args
from data_preprocessing import tools
label2id,id2label,num_labels=tools.load_schema()
num_label = num_labels+1
class Span_loss(nn.Module):
def __init__(self):
super().__init__()
self.loss_func = torch.nn.CrossEntropyLoss(reduction="none")
def forward(self,span_logits,span_label,seq_mask):
# batch_size,seq_len,hidden=span_label.shape
span_label = span_label.view(size=(-1,))
span_logits = span_logits.view(size=(-1, num_label))
span_loss = self.loss_func(input=span_logits, target=span_label)
# start_extend = seq_mask.unsqueeze(2).expand(-1, -1, seq_len)
# end_extend = seq_mask.unsqueeze(1).expand(-1, seq_len, -1)
span_mask = seq_mask.view(size=(-1,))
span_loss *=span_mask
avg_se_loss = torch.sum(span_loss) / seq_mask.size()[0]
# avg_se_loss = torch.sum(sum_loss) / bsz
return avg_se_loss
Note: view() reshapes a tensor, analogous to reshape() in NumPy.
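As a small illustration of the reshaping (illustrative sizes only): flattening the (batch, l, l, c) logits and the (batch, l, l) label matrix lets CrossEntropyLoss treat every (start, end) cell as one classification example:
import torch

batch_size, seq_len, num_label = 2, 4, 5
span_logits = torch.randn(batch_size, seq_len, seq_len, num_label)
span_label = torch.zeros(batch_size, seq_len, seq_len, dtype=torch.long)

flat_logits = span_logits.view(-1, num_label)   # (2*4*4, 5) = (32, 5)
flat_label = span_label.view(-1)                # (32,)
loss = torch.nn.CrossEntropyLoss(reduction="none")(flat_logits, flat_label)
print(loss.shape)                               # torch.Size([32])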
Reference: A Unified MRC Framework for Named Entity Recognition
4.2 focal_loss definition
Goal: a loss designed to handle class imbalance and differences in classification difficulty;
Idea: down-weight the large number of easy negative samples during training; it can also be viewed as a form of hard example mining.
Loss function:
Focal loss is a modification of the cross-entropy loss. First, recall the binary cross-entropy loss:
$$\mathrm{CE}(y, y') = -y \log y' - (1 - y) \log(1 - y')$$
Here y' is the output after the activation function, so it lies in (0, 1). With plain cross-entropy, the loss of a positive sample decreases as the predicted probability increases, and the loss of a negative sample decreases as the predicted probability decreases. With a large number of easy samples, training is slow and may never reach the optimum. How does focal loss improve on this?
$$\mathrm{FL}(y, y') = -(1 - y')^{\gamma}\, y \log y' - (y')^{\gamma}\,(1 - y) \log(1 - y')$$
A modulating factor is added to the cross-entropy terms; with gamma > 0 it shrinks the loss of easy samples, so training focuses more on hard, misclassified ones.
For example, with gamma = 2: for a positive sample, a prediction of 0.95 is clearly easy, so (1 - 0.95)^gamma becomes very small and the loss shrinks accordingly, while a sample predicted at 0.3 keeps a relatively large loss. The same holds for negative samples: a prediction of 0.1 incurs a far smaller loss than a prediction of 0.7. For a prediction of 0.5 the loss is only scaled by 0.25, so these hard-to-separate samples receive relatively more attention. This reduces the influence of easy samples, so that only a very large number of low-loss easy samples can accumulate a comparable effect.
In addition, a balancing factor alpha is introduced to address the imbalance between positive and negative samples themselves:
$$\mathrm{FL}(y, y') = -\alpha\,(1 - y')^{\gamma}\, y \log y' - (1 - \alpha)\,(y')^{\gamma}\,(1 - y) \log(1 - y')$$
Adding alpha alone balances the importance of positive and negative samples, but it does not address the easy-versus-hard problem.
Gamma controls how quickly the weight of easy samples decays: with gamma = 0 the loss reduces to cross-entropy, and the influence of the modulating factor grows as gamma increases. Experiments found gamma = 2 to work best.
Implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
'''Multi-class Focal loss implementation'''
def __init__(self, gamma=2, weight=None, ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index = ignore_index
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
logpt = F.log_softmax(input, dim=1)
pt = torch.exp(logpt)
logpt = (1 - pt) ** self.gamma * logpt
loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
return loss
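A quick usage example (illustrative shapes; if focal loss were swapped in for the span loss, the flattened span logits and labels from section 4.1 would be passed the same way):
import torch

criterion = FocalLoss(gamma=2)
logits = torch.randn(32, 5)            # [N, C]: 32 flattened (start, end) cells, 5 classes
targets = torch.randint(0, 5, (32,))   # [N]: gold class id per cell
print(criterion(logits, targets).item())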
Reference: Focal Loss for Dense Object Detection
5. Model Training
def train():
    # step 1: data preprocessing
train_data = data_prepro.yield_data(args.train_path)
test_data = data_prepro.yield_data(args.test_path)
    # step 2: model definition
model = myModel(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)
    # step 3: optimizer definition
# model.load_state_dict(torch.load(args.checkpoints))
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
'weight_decay_rate': 0.0}
]
optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)
schedule = WarmUp_LinearDecay(
optimizer = optimizer,
init_rate = args.learning_rate,
warm_up_epoch = args.warm_up_epoch,
decay_epoch = args.decay_epoch
)
    # step 4: span_loss function definition
span_loss_func = span_loss.Span_loss().to(device)
span_acc = metrics.metrics_span().to(device)
    # step 5: training
step=0
best=0
for epoch in range(args.epoch):
for item in train_data:
step+=1
input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
span_label,span_mask = item['span_label'],item["span_mask"]
optimizer.zero_grad()
span_logits = model(
input_ids=input_ids.to(device),
input_mask=input_mask.to(device),
input_seg=input_seg.to(device),
is_training=True
)
span_loss_v = span_loss_func(span_logits,span_label.to(device),span_mask.to(device))
loss = span_loss_v
loss = loss.float().mean().type_as(loss)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
schedule.step()
# optimizer.step()
if step%100 == 0:
span_logits = torch.nn.functional.softmax(span_logits, dim=-1)
recall,precise,span_f1=span_acc(span_logits,span_label.to(device))
logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))
with torch.no_grad():
count=0
span_f1=0
recall=0
precise=0
for item in test_data:
count+=1
input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
span_label,span_mask = item['span_label'],item["span_mask"]
optimizer.zero_grad()
span_logits = model(
input_ids=input_ids.to(device),
input_mask=input_mask.to(device),
input_seg=input_seg.to(device),
is_training=False
)
tmp_recall,tmp_precise,tmp_span_f1=span_acc(span_logits,span_label.to(device))
span_f1+=tmp_span_f1
recall+=tmp_recall
precise+=tmp_precise
span_f1 = span_f1/count
recall=recall/count
precise=precise/count
logger.info('-----eval----')
logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))
logger.info('-----eval----')
if best < span_f1:
best=span_f1
torch.save(model.state_dict(), f=args.checkpoints)
logger.info('-----save the best model----')
References
Named Entity Recognition as Dependency Parsing, Juntao Yu, Bernd Bohnet, Massimo Poesio, ACL 2020. https://www.aclweb.org/anthology/2020.acl-main.577/
