【關于 Bert 源碼解析II 之 預訓練篇 】 那些的你不知道的事
作者:楊夕
論文鏈接:https://arxiv.org/pdf/1810.04805.pdf
本文鏈接:https://github.com/km1994/nlp_paper_study
個人介紹:大佬們好,我叫楊夕,該項目主要是本人在研讀頂會論文和復現經典論文過程中,所見、所思、所想、所聞,可能存在一些理解錯誤,希望大佬們多多指正。
【注:手機閱讀可能圖片打不開?。?!】
目錄

一、動機
之前給 小伙伴們 寫過 一篇 【【關于Bert】 那些的你不知道的事】后,有一些小伙伴聯系我,說對?【Bert】?里面的很多細節(jié)性問題都沒看懂,不清楚他怎么實現的。針對該問題,小菜雞的我 也 意識到自己的不足,所以就 想 研讀一下?【Bert】?的 源碼,并針對 之前小伙伴 的一些 問題 進行 回答和解釋,能力有限,希望對大家有幫助。
二、本文框架
本文 將?【Bert】?的 源碼分成以下模塊:
【關于 Bert 源碼解析 之 主體篇 】 那些的你不知道的事
【關于 Bert 源碼解析 之 預訓練篇 】 那些的你不知道的事【本章】
【關于 Bert 源碼解析 之 微調篇 】 那些的你不知道的事
【關于 Bert 源碼解析IV 之 句向量生成篇 】 那些的你不知道的事
【關于 Bert 源碼解析V 之 文本相似度篇 】 那些的你不知道的事
分模塊 進行解讀。
三、前言
本文 主要 解讀 Bert 模型的 預訓練 模塊代碼:
tokenization.py:主要用于對原始句子內容進行解析,分為 BasicTokenizer和WordpieceTokenizer 兩種;
create_pretraining_data.py:用于將 原始語料 轉化為 模型 所需要的 訓練格式;
run_pretraining.py:模型預訓練;
四、原始語料 預處理模塊 (tokenization.py)
4.1 動機
由于 原始語料 可能多種多樣,所以需要將 原始語料 轉化為 Bert 所需要的訓練數據格式。
4.2 類別
預處理模塊主要分為:
BasicTokenizer
WordpieceTokenizer
4.3 BasicTokenizer
動機:對原始語料進行處理
處理操作:unicode轉換、標點符號分割、小寫轉換、中文字符分割、去除重音符號等操作,最后返回的是關于詞的數組(中文是字的數組);
代碼解析
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: 是否將 query 字母都轉化為小寫
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
# step 1:將 text 從 Unicode 轉化為 utf-8
text = convert_to_unicode(text)
# step 2:去除無意義字符以及空格
text = self._clean_text(text)
# step 3:增加中文支持
text = self._tokenize_chinese_chars(text)
# step 4:在一段文本上運行基本的空格清除和拆分
orig_tokens = whitespace_tokenize(text)
# step 5:用標點切分
split_tokens = []
for token in orig_tokens:
# 是否轉小寫
if self.do_lower_case:
token = token.lower()
# 對text進行歸一化
token = self._run_strip_accents(token)
# 用標點切分
split_tokens.extend(self._run_split_on_punc(token))
# step 5:在一段文本上運行基本的空格清除和拆分
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""這個函數去除掉text中的非間距字符"""
# step 1: 對text進行歸一化
# 標準化對于任何需要以一致的方式處理Unicode文本的程序都是非常重要的。
# 當處理來自用戶輸入的字符串而你很難去控制編碼的時候尤其如此。
# normalize() 將文本標準化,第一個參數指定字符串標準化的方式,NFD表示字符應該分解為多個組合字符表示
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
# # category() 返回字符在UNICODE里分類的類型
cat = unicodedata.category(char)
# 判斷cat 是否為 Mn,Mark, Nonspacing 指示字符是非間距字符,這指示基字符的修改。
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""用標點對 文本 進行切分,返回list"""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
""" 按字切分中文,實現就是在字兩側添加空格
Adds whitespace around any CJK character. """
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
""" 判斷是否是漢字
Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""
去除無意義字符以及空格
Performs invalid character removal and whitespace cleanup on text.
"""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
4.4 WordpieceTokenizer
動機:對于 詞中 可能是 未登錄詞、時態(tài)問題等;
操作:將BasicTokenizer的結果進一步做更細粒度的切分,將合成詞分解成類似詞根一樣的詞片。例如將"unwanted"分解成["un", "##want", "##ed"]
目的:去除未登錄詞對模型效果的影響。防止因為詞的過于生僻沒有被收錄進詞典最后只能以[UNK]代替的局面,因為英語當中這樣的合成詞非常多,詞典不可能全部收錄。
代碼講解
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""使用貪心的最大正向匹配算法
例如:
For example:
input = \"unaffable\"
output = [\"un\", \"##aff\", \"##able\"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer.
Returns:
A list of wordpiece tokens.
"""
# step 1:將 text 從 Unicode 轉化為 utf-8
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
舉例說明:
假設輸入是”unaffable”。
我們跳到while循環(huán)部分,這是start=0,end=len(chars)=9,也就是先看看unaffable在不在詞典里,如果在,那么直接作為一個WordPiece,如果不再,那么end-=1,也就是看unaffabl在不在詞典里,最終發(fā)現”un”在詞典里,把un加到結果里。
接著start=2,看affable在不在,不在再看affabl,…,最后發(fā)現 ##aff 在詞典里。注意:##表示這個詞是接著前面的,這樣使得WordPiece切分是可逆的——我們可以恢復出“真正”的詞。
4.5 FullTokenizer
功能:對一個文本段進行以上兩種解析,最后返回詞(字)的數組,同時還提供token到id的索引以及id到token的索引。這里的token可以理解為文本段處理過后的最小單元。
代碼講解
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
# 加載詞表文件為字典形式
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
# 調用BasicTokenizer粗粒度分詞
for token in self.basic_tokenizer.tokenize(text):
# 調用WordpieceTokenizer細粒度分詞
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
五、訓練數據生成(create_pretraining_data.py)
5.1 作用
將原始輸入語料轉換成模型預訓練所需要的數據格式TFRecoed。
5.2 參數設置
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
參數介紹
input_file::代表輸入的源語料文件地址
output_file :代表處理過的預料文件地址
do_lower_case:是否全部轉為小寫字母
vocab_file:詞典文件
max_seq_length:最大序列長度
dupe_factor: 重復參數,默認重復10次,目的是可以生成不同情況的masks;舉例:對于同一個句子,我們可以設置不同位置的【MASK】次數。比如對于句子Hello world, this is bert.,為了充分利用數據,第一次可以mask成Hello [MASK], this is bert.,第二次可以變成Hello world, this is [MASK].
max_predictions_per_seq: 一個句子里最多有多少個[MASK]標記
masked_lm_prob: 多少比例的Token被MASK掉
short_seq_prob: 長度小于“max_seq_length”的樣本比例。因為在fine-tune過程里面輸入的target_seq_length是可變的(小于等于max_seq_length),那么為了防止過擬合也需要在pre-train的過程當中構造一些短的樣本。
5.3 main 入口
思路:
構造tokenizer,構造tokenizer對輸入語料進行分詞處理 ;
構造instances,經過create_training_instances函數構造訓練instance;
保存instances,調用write_instance_to_example_files函數以TFRecord格式保存數據 ;
代碼講解
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# step 1:構造tokenizer,構造tokenizer對輸入語料進行分詞處理 ;
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Reading from input files ***")
for input_file in input_files:
tf.logging.info(" %s", input_file)
# step 2:構造instances,經過create_training_instances函數構造訓練instance;
rng = random.Random(FLAGS.random_seed)
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
output_files = FLAGS.output_file.split(",")
tf.logging.info("*** Writing to output files ***")
for output_file in output_files:
tf.logging.info(" %s", output_file)
# step 3:保存instances,調用write_instance_to_example_files函數以TFRecord格式保存數據
;
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
5.4 定義訓練樣本類 (TrainingInstance)
作用:構建訓練樣本
代碼講解:
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids # 指的形式為[0,0,0...1,1,111] 0的個數為i+1個,1的個數為max_seq_length - (i+1) 對應到模型輸入就是token_type
self.is_random_next = is_random_next # 其實就是上圖的Label,0.5的概率為True(和當只有一個segment的時候),如果為True則B和A不屬于同一document。剩下的情況為False,則B為A同一document的后續(xù)句子。
self.masked_lm_positions = masked_lm_positions # 序列里被[MASK]的位置;
self.masked_lm_labels = masked_lm_labels # 序列里被[MASK]的token
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
5.5 構建訓練實例 (create_training_instances)
功能:讀取數據,并構建實例
訓練樣本輸入格式說明(sample_text.txt):
This text is included to make sure Unicode is handled properly: 力加勝北區(qū)??????????
Text should be one-sentence-per-line, with empty lines between documents.
This sample text is public domain and was randomly selected from Project Guttenberg.
The rain had only ceased with the gray streaks of morning at Blazing Star, and the settlement awoke to a moral sense of cleanliness, and the finding of forgotten knives, tin cups, and smaller camp utensils, where the heavy showers had washed away the debris and dust heaps before the cabin doors.
Indeed, it was recorded in Blazing Star that a fortunate early riser had once picked up on the highway a solid chunk of gold quartz which the rain had freed from its incumbering soil, and washed into immediate and glittering popularity.
...
說明:
不同句子 用 換行符 分割,也就是一個句子一行;
不同 文檔 中間 用 兩個換行符 分割;
同一篇文檔的上下句 之間 操作關系;
代碼講解:
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# step 1:加載數據
# Input file format:
# (1) 每行一句話。理想情況下,這些應該是實際句子,而不是整個段落或文本的任意跨度。 (因為我們將句子邊界用于“下一句預測”任務)。
# (2) 文檔之間的空白行。需要文檔邊界,以便“下一個句子預測”任務不會跨越文檔之間。
for input_file in input_files:
with tf.gfile.GFile(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# step 2:清除 空文檔
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
# step 3:重復dupe_factor次,目的是可以生成不同情況的masks
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
# step 4:數據打亂
rng.shuffle(instances)
return instances
5.6 從 document 中抽取 實例(create_instances_from_document)
5.6.1 作用
作用:實現從一個文檔中抽取多個訓練樣本
5.6.2 代碼講解
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# step 1:為[CLS], [SEP], [SEP]預留三個空位
max_num_tokens = max_seq_length - 3
# step 2:以short_seq_prob的概率隨機生成(2~max_num_tokens)的長度
# 我們*通常*想要填充整個序列,因為無論如何我們都要填充到“ max_seq_length”,因此短序列通常會浪費計算量。但是,我們有時*(即short_seq_prob == 0.1 == 10%的時間)希望使用較短的序列來最大程度地減少預訓練和微調之間的不匹配。但是,“ target_seq_length”只是一個粗略的目標,而“ max_seq_length”是一個硬限制。
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# step 3:根據用戶輸入提供的實際“句子”將輸入分為“ A”和“ B”兩段
# 我們不只是將文檔中的所有標記連接成一個較長的序列,并選擇一個任意的分割點,因為這會使下一個句子的預測任務變得太容易了。相反,我們根據用戶輸入提供的實際“句子”將輸入分為“ A”和“ B”兩段。
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
# 將句子依次加入current_chunk中,直到加完或者達到限制的最大長度
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end`是第一個句子A結束的下標
a_end = 1
# 隨機選取切分邊界
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
# step 4:構建 NSP 任務
tokens_b = []
# 是否隨機next
is_random_next = False
# 構建隨機的下一句
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# 隨機的挑選另外一篇文檔的隨機開始的句子
# 但是理論上有可能隨機到的文檔就是當前文檔,因此需要一個while循環(huán)
# 這里只while循環(huán)10次,理論上還是有重復的可能性,但是我們忽略
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# 對于上述構建的隨機下一句,我們并沒有真正地使用它們
# 所以為了避免數據浪費,我們將其“放回”
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# 構建真實的下一句
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
# 如果太多了,隨機去掉一些
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
# 處理句子A
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
# 句子A結束,加上【SEP】
tokens.append("[SEP]")
segment_ids.append(0)
# 處理句子B
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
# 句子B結束,加上【SEP】
tokens.append("[SEP]")
segment_ids.append(1)
# step 5:構建 MLN 任務
# 調用 create_masked_lm_predictions 來隨機對某些Token進行mask
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
5.6.3 流程
算法首先會維護一個chunk,不斷加入document中的元素,也就是句子(segment),直到加載完或者chunk中token數大于等于最大限制,這樣做的目的是使得padding的盡量少,訓練效率更高。
現在chunk建立完畢之后,假設包括了前三個句子,算法會隨機選擇一個切分點,比如2。接下來構建predict next判斷:
如果是正樣本,前兩個句子當成是句子A,后一個句子當成是句子B;
如果是負樣本,前兩個句子當成是句子A,無關的句子從其他文檔中隨機抽取
得到句子A和句子B之后,對其填充tokens和segment_ids,這里會加入特殊的[CLS]和[SEP]標記
模板
[CLS] A [SEP] B [SEP]
A = [token_0, token_1, ..., token_i]
B = [token_i+1, token_i+2, ... , token_n-1]
其中:
2<= n < max_seq_length - 3 (in short_seq_prob)
n=max_seq_length - 3 (in 1-short_seq_prob)
結果舉例
Input = [CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]
Label = IsNext
Input = [CLS] the man [MASK] to the store [SEP] he penguin [MASK] are flight ##less birds [SEP]
Label = NotNext
在create_masked_lm_predictions函數里,一個序列在指定MASK數量之后,有80%被真正MASK,10%還是保留原來token,10%被隨機替換成其他token。
5.7 隨機MASK(create_masked_lm_predictions)
介紹
創(chuàng)新點:對Tokens進行隨機mask
原因:為了防止模型在雙向循環(huán)訓練的過程中“預見自身”;
文章中選取的策略是對輸入序列中15%的詞使用[MASK]標記掩蓋掉,然后通過上下文去預測這些被mask的token。但是為了防止模型過擬合地學習到【MASK】這個標記,對15%mask掉的詞進一步優(yōu)化。
操作
原始句子:the man went to a store, he bought a gallon milk.
對于 句子中 15% 的 Token,采用以下一種方式優(yōu)化:
以80%的概率用[MASK]替換:
the man went to a [MASK], he bought a gallon milk.以10%的概率隨機替換:
the man went to a shop, he bought a gallon milk.以10%的概率保持原狀:
the man went to a store, he bought a gallon milk.
代碼解析
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
# step 1:[CLS]和[SEP]不能用于MASK
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
# Whole Word Masking means that if we mask all of the wordpieces
# corresponding to an original word. When a word has been split into
# WordPieces, the first token does not have any marker and any subsequence
# tokens are prefixed with ##. So whenever we see the ## token, we
# append it to the previous set of word indexes.
#
# Note that Whole Word Masking does *not* change the training code
# at all -- we still predict each WordPiece independently, softmaxed
# over the entire vocabulary.
if (FLAGS.do_whole_word_mask and len(cand_indexes) >= 1 and
token.startswith("##")):
cand_indexes[-1].append(i)
else:
cand_indexes.append([i])
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
# step 2 : mask 操作
masked_lms = []
covered_indexes = set()
for index_set in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
# If adding a whole-word mask would exceed the maximum number of
# predictions, then just skip this candidate.
if len(masked_lms) + len(index_set) > num_to_predict:
continue
is_any_index_covered = False
for index in index_set:
if index in covered_indexes:
is_any_index_covered = True
break
if is_any_index_covered:
continue
for index in index_set:
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
assert len(masked_lms) <= num_to_predict
# step 3:按照下標重排,保證是原來句子中出現的順序
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
5.8 保存instance(write_instance_to_example_files)
代碼講解
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
writers = []
for output_file in output_files:
writers.append(tf.python_io.TFRecordWriter(output_file))
writer_index = 0
total_written = 0
for (inst_index, instance) in enumerate(instances):
# 將輸入轉成word-ids
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
# 記錄實際句子長度
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
# padding
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
features["next_sentence_labels"] = create_int_feature([next_sentence_label])
# 生成訓練樣本
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
# 輸出到文件
writers[writer_index].write(tf_example.SerializeToString())
writer_index = (writer_index + 1) % len(writers)
total_written += 1
# 打印前20個樣本
if inst_index < 20:
tf.logging.info("*** Example ***")
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
for feature_name in features.keys():
feature = features[feature_name]
values = []
if feature.int64_list.value:
values = feature.int64_list.value
elif feature.float_list.value:
values = feature.float_list.value
tf.logging.info(
"%s: %s" % (feature_name, " ".join([str(x) for x in values])))
for writer in writers:
writer.close()
tf.logging.info("Wrote %d total instances", total_written)
六、預訓練
6.1 Masked LM 訓練 (get_masked_lm_output)
作用:針對的是語言模型對MASK起來的標簽的預測,即上下文語境預測當前詞,并計算 MLM 的 訓練 loss
def get_masked_lm_output(
bert_config, # bert 配置
input_tensor, # BertModel的最后一層sequence_output輸出([batch_size, seq_length, hidden_size])
output_weights, # embedding_table,用來反embedding,這樣就映射到token的學習了
positions,
label_ids,
label_weights):
""" 獲取 MLM 的 loss 和 log probs
Get loss and log probs for the masked LM.
"""
# step 1:在一個小批量的特定位置收集向量。
input_tensor = gather_indexes(input_tensor, positions)
with tf.variable_scope("cls/predictions"):
# step 2:在輸出之前添加一個非線性變換,只在預訓練階段起作用
with tf.variable_scope("transform"):
input_tensor = tf.layers.dense(
input_tensor,
units=bert_config.hidden_size,
activation=modeling.get_activation(bert_config.hidden_act),
kernel_initializer=modeling.create_initializer(
bert_config.initializer_range))
input_tensor = modeling.layer_norm(input_tensor)
# step 3:output_weights是和傳入的word embedding一樣的
# 但是在輸出中有一個對應每個 token 的權重
output_bias = tf.get_variable(
"output_bias",
shape=[bert_config.vocab_size],
initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
# step 4:label_ids表示mask掉的Token的id
label_ids = tf.reshape(label_ids, [-1])
label_weights = tf.reshape(label_weights, [-1])
# step 5:關于 label 的一些格式處理,處理完之后把 label 轉化成 one hot 類型的輸出。
one_hot_labels = tf.one_hot(
label_ids, depth=bert_config.vocab_size, dtype=tf.float32)
# step 6:`positions` tensor可以補零(如果序列太短而無法獲得最大預測數)。 對于每個真實的預測,`label_weights` tensor 的值為1.0,對于填充預測,其值為0.0。
# 但是由于實際MASK的可能不到20,比如只MASK18,那么label_ids有2個0(padding)
# 而label_weights=[1, 1, ...., 0, 0],說明后面兩個label_id是padding的,計算loss要去掉。
per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
numerator = tf.reduce_sum(label_weights * per_example_loss)
denominator = tf.reduce_sum(label_weights) + 1e-5
loss = numerator / denominator
return (loss, per_example_loss, log_probs)
6.2 獲取 next sentence prediction(下一句預測) 部分的 loss 以及 log probs (get_next_sentence_output)
作用:用于計算 NSP 的訓練loss
def get_next_sentence_output(bert_config, input_tensor, labels):
"""獲取 NSP 的 loss 和 log probs"""
# 二分類任務
# 0 is "next sentence" and 1 is "random sentence".
# 這個分類器的參數在實際Fine-tuning階段會丟棄掉
with tf.variable_scope("cls/seq_relationship"):
output_weights = tf.get_variable(
"output_weights",
shape=[2, bert_config.hidden_size],
initializer=modeling.create_initializer(bert_config.initializer_range))
output_bias = tf.get_variable(
"output_bias", shape=[2], initializer=tf.zeros_initializer())
logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
log_probs = tf.nn.log_softmax(logits, axis=-1)
labels = tf.reshape(labels, [-1])
one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
loss = tf.reduce_mean(per_example_loss)
return (loss, per_example_loss, log_probs)
七、測試
python create_pretraining_data.py \
--input_file=./sample_text_zh.txt \
--output_file=/tmp/tf_examples.tfrecord \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--do_lower_case=True \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--masked_lm_prob=0.15 \
--random_seed=12345 \
--dupe_factor=5
八、總結
本章 主要介紹了 利用 Bert fineturn,代碼比較簡單。
【關于 Bert 源碼解析 之 主體篇 】 那些的你不知道的事
【關于 Bert 源碼解析 之 預訓練篇 】 那些的你不知道的事【本章】
【關于 Bert 源碼解析 之 微調篇 】 那些的你不知道的事
【關于 Bert 源碼解析IV 之 句向量生成篇 】 那些的你不知道的事
【關于 Bert 源碼解析V 之 文本相似度篇 】 那些的你不知道的事
分模塊 進行解讀。
參考文檔
Bert系列(三)——源碼解讀之Pre-train
BERT源碼分析PART II


