10分鐘掌握Bert源碼(PyTorch版)

Bert在生產(chǎn)環(huán)境的應(yīng)用需要進(jìn)行壓縮,這就要求對(duì)Bert結(jié)構(gòu)很了解,這個(gè)倉(cāng)庫(kù)會(huì)一步步解讀Bert源代碼(pytorch版本)。倉(cāng)庫(kù)地址在
https://github.com/DA-southampton/NLP_ability
代碼和數(shù)據(jù)介紹
首先 對(duì)代碼來(lái)說(shuō),借鑒的是這個(gè)倉(cāng)庫(kù)
我直接把代碼clone過(guò)來(lái),放到了本倉(cāng)庫(kù),重新命名為bert_read_step_to_step。
我會(huì)使用這個(gè)代碼,一步步運(yùn)行bert關(guān)于文本分類的代碼,然后同時(shí)記錄下各種細(xì)節(jié)包括自己實(shí)現(xiàn)的情況。
運(yùn)行之前,需要做兩個(gè)事情。
準(zhǔn)備預(yù)訓(xùn)練模型
一個(gè)是預(yù)訓(xùn)練模型的準(zhǔn)備,我使用的是谷歌的中文預(yù)訓(xùn)練模型:chinese_L-12_H-768_A-12.zip,模型有點(diǎn)大,我就不上傳了,如果本地不存在,就點(diǎn)擊這里直接下載,或者直接命令行運(yùn)行
wget?https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
預(yù)訓(xùn)練模型下載下來(lái)之后,進(jìn)行解壓,然后將tf模型轉(zhuǎn)為對(duì)應(yīng)的pytorch版本即可。對(duì)應(yīng)代碼如下:
export?BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12
python?convert_tf_checkpoint_to_pytorch.py?\
??--tf_checkpoint_path?$BERT_BASE_DIR/bert_model.ckpt?\
??--bert_config_file?$BERT_BASE_DIR/bert_config.json?\
??--pytorch_dump_path?$BERT_BASE_DIR/pytorch_model.bin
轉(zhuǎn)化成功之后,將模型放入到倉(cāng)庫(kù)對(duì)應(yīng)位置:
Read_Bert_Code/bert_read_step_to_step/prev_trained_model/
并重新命名為:
bert-base-chinese
準(zhǔn)備文本分類訓(xùn)練數(shù)據(jù)
第二個(gè)事情就是準(zhǔn)備訓(xùn)練數(shù)據(jù),這里我準(zhǔn)備做一個(gè)文本分類任務(wù),使用的是Tnews數(shù)據(jù)集,這個(gè)數(shù)據(jù)集來(lái)源是這里,分為訓(xùn)練,測(cè)試和開(kāi)發(fā)集,我已經(jīng)上傳到了倉(cāng)庫(kù)中,具體位置在
Read_Bert_Code/bert_read_step_to_step/chineseGLUEdatasets/tnews
需要注意的一點(diǎn)是,因?yàn)槲抑皇菫榱肆私鈨?nèi)部代碼情況,所以準(zhǔn)確度不是在我的考慮范圍之內(nèi),所以我只是取其中的一部分?jǐn)?shù)據(jù),其中訓(xùn)練數(shù)據(jù)使用1k,測(cè)試數(shù)據(jù)使用1k,開(kāi)發(fā)數(shù)據(jù)1k。
準(zhǔn)備就緒,使用pycharm導(dǎo)入項(xiàng)目,準(zhǔn)備調(diào)試,我的調(diào)試文件是 run_classifier.py文件,對(duì)應(yīng)的參數(shù)為
--model_type=bert?--model_name_or_path=prev_trained_model/bert-base-chinese?--task_name="tnews"?--do_train?--do_eval?--do_lower_case?--data_dir=./chineseGLUEdatasets/tnews?--max_seq_length=128?--per_gpu_train_batch_size=16?--per_gpu_eval_batch_size=16?--learning_rate=2e-5?--num_train_epochs=4.0?--logging_steps=100?--save_steps=100?--output_dir=./outputs/tnews_output/?--overwrite_output_dir
然后對(duì)run_classifier.py 進(jìn)行調(diào)試,我會(huì)在下面是調(diào)試的細(xì)節(jié)
1.main函數(shù)進(jìn)入
首先是主函數(shù)位置打入斷點(diǎn),位置在這里,然后進(jìn)入看一下主函數(shù)的情況
##主函數(shù)打上斷點(diǎn)
if?__name__?==?"__main__":
????main()##主函數(shù)進(jìn)入
2.解析命令行參數(shù)
從這里到這里就是在解析命令行參數(shù),是常規(guī)操作,主要是什么模型名稱,模型地址,是否進(jìn)行測(cè)試等等。比較簡(jiǎn)單,直接過(guò)就可以了。
3.判斷一些情況
從這里到這里是一些常規(guī)的判斷:
判斷是否存在輸出文件夾
判斷是否需要遠(yuǎn)程debug
判斷單機(jī)cpu訓(xùn)練還是單機(jī)多gpu訓(xùn)練,還是多機(jī)分布式gpu訓(xùn)練,這個(gè)有兩個(gè)參數(shù)進(jìn)行控制
具體可以看代碼如下:
if?args.local_rank?==?-1?or?args.no_cuda:
????device?=?torch.device("cuda"?if?torch.cuda.is_available()?and?not?args.no_cuda?else?"cpu")
????args.n_gpu?=?torch.cuda.device_count()
else:??#?Initializes?the?distributed?backend?which?will?take?care?of?sychronizing?nodes/GPUs
????torch.cuda.set_device(args.local_rank)
????device?=?torch.device("cuda",?args.local_rank)
????torch.distributed.init_process_group(backend='nccl')
????args.n_gpu?=?1
4.獲取任務(wù)對(duì)應(yīng)Processor
獲取任務(wù)對(duì)應(yīng)的相應(yīng)processor,這個(gè)對(duì)應(yīng)的函數(shù)就是需要我們自己去定義的處理我們自己輸入文件的函數(shù),位置在這里,代碼如下:
processor?=?processors[args.task_name]()
這里我們使用的是,這個(gè)結(jié)果返回的是一個(gè)類,我們使用的是如下的類:
TnewsProcessor(DataProcessor)
具體代碼位置在這里,
4.1 TnewsProcessor
仔細(xì)分析一下TnewsProcessor,首先繼承自DataProcessor
點(diǎn)擊此處打開(kāi)折疊代碼
## DataProcessor在整個(gè)項(xiàng)目的位置:processors.utils.DataProcessor
class?DataProcessor(object):
????def?get_train_examples(self,?data_dir):
????????raise?NotImplementedError()
????def?get_dev_examples(self,?data_dir):
????????raise?NotImplementedError()
????def?get_labels(self):
????????raise?NotImplementedError()
????@classmethod
????def?_read_tsv(cls,?input_file,?quotechar=None):
????????with?open(input_file,?"r",?encoding="utf-8-sig")?as?f:
????????????reader?=?csv.reader(f,?delimiter="\t",?quotechar=quotechar)
????????????lines?=?[]
????????????for?line?in?reader:
????????????????lines.append(line)
????????????return?lines
????@classmethod
????def?_read_txt(cls,?input_file):
????????"""Reads?a?tab?separated?value?file."""
????????with?open(input_file,?"r")?as?f:
????????????reader?=?f.readlines()
????????????lines?=?[]
????????????for?line?in?reader:
????????????????lines.append(line.strip().split("_!_"))
????????????return?lines
然后它自己包含五個(gè)函數(shù),分別是讀取訓(xùn)練集,開(kāi)發(fā)集數(shù)據(jù),獲取返回label,制作bert需要的格式的數(shù)據(jù)
接下來(lái)看一下 TnewsProcessor代碼格式:
點(diǎn)擊此處打開(kāi)折疊代碼
class?TnewsProcessor(DataProcessor):
????def?get_train_examples(self,?data_dir):
????????"""See?base?class."""
????????return?self._create_examples(
????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_train.txt")),?"train")
????def?get_dev_examples(self,?data_dir):
????????"""See?base?class."""
????????return?self._create_examples(
????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_dev.txt")),?"dev")
????def?get_test_examples(self,?data_dir):
????????"""See?base?class."""
????????return?self._create_examples(
????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_test.txt")),?"test")
????def?get_labels(self):
????????"""See?base?class."""
????????labels?=?[]
????????for?i?in?range(17):
????????????if?i?==?5?or?i?==?11:
????????????????continue
????????????labels.append(str(100?+?i))
????????return?labels
????def?_create_examples(self,?lines,?set_type):
????????"""Creates?examples?for?the?training?and?dev?sets."""
????????examples?=?[]
????????for?(i,?line)?in?enumerate(lines):
????????????guid?=?"%s-%s"?%?(set_type,?i)
????????????text_a?=?line[3]
????????????if?set_type?==?'test':
????????????????label?=?'0'
????????????else:
????????????????label?=?line[1]
????????????examples.append(
????????????????InputExample(guid=guid,?text_a=text_a,?text_b=None,?label=label))
????????return?examples
這里有一點(diǎn)需要提醒大家,如果說(shuō)我們使用自己的訓(xùn)練數(shù)據(jù),有兩個(gè)方法,第一個(gè)就是把數(shù)據(jù)格式變化成和我們測(cè)試用例一樣的數(shù)據(jù),第二個(gè)就是我們?cè)谶@里更改源代碼,去讀取我們自己的數(shù)據(jù)格式
5.加載預(yù)訓(xùn)練模型
代碼比較簡(jiǎn)單,就是調(diào)用預(yù)訓(xùn)練模型,不詳細(xì)介紹了
點(diǎn)擊此處打開(kāi)折疊代碼
config_class,?model_class,?tokenizer_class?=?MODEL_CLASSES[args.model_type]
config?=?config_class.from_pretrained(args.config_name?if?args.config_name?else?args.model_name_or_path,?num_labels=num_labels,?finetuning_task=args.task_name)
tokenizer?=?tokenizer_class.from_pretrained(args.tokenizer_name?if?args.tokenizer_name?else?args.model_name_or_path,?do_lower_case=args.do_lower_case)
model?=?model_class.from_pretrained(args.model_name_or_path,?from_tf=bool('.ckpt'?in?args.model_name_or_path),?config=config)
6.訓(xùn)練模型-也是最重要的部分
訓(xùn)練模型,從主函數(shù)這里看就是兩個(gè)步驟,一個(gè)是加載需要的數(shù)據(jù)集,一個(gè)是進(jìn)行訓(xùn)練,代碼位置在這里。大概代碼就是這樣:
train_dataset?=?load_and_cache_examples(args,?args.task_name,?tokenizer,?data_type='train')
global_step,?tr_loss?=?train(args,?train_dataset,?model,?tokenizer)
兩個(gè)函數(shù),我們一個(gè)個(gè)看:
6.1 加載訓(xùn)練集
我們先看一下第一個(gè)函數(shù),load_and_cache_examples 就是加載訓(xùn)練數(shù)據(jù)集,代碼位置在這里。大概看一下這個(gè)代碼,核心操作有三個(gè)。
第一個(gè)核心操作,位置在這里,代碼如下:
examples?=?processor.get_train_examples(args.data_dir)
這個(gè)代碼是為了利用processor讀取訓(xùn)練集,很簡(jiǎn)單。
這里得到的example大概是這樣的(這個(gè)返回形式在上面看processor的時(shí)候很清楚的展示了):
guid='train-0'
label='104'
text_a='今天股票形式不怎么樣啊'
text_b=None
第二個(gè)核心操作是convert_examples_to_features講數(shù)據(jù)進(jìn)行轉(zhuǎn)化,也很簡(jiǎn)單。
代碼位置在這里。代碼如下:
features?=?convert_examples_to_features(examples,tokenizer,label_list=label_list,max_length=args.max_seq_length,output_mode=output_mode,pad_on_left=bool(args.model_type?in?['xlnet']),????????????????????????????????????????????????pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4?if?args.model_type?in?['xlnet']?else?0,
我們進(jìn)入這個(gè)函數(shù)看一看里面究竟是咋回事,位置在:
processors.glue.glue_convert_examples_to_features
做了一個(gè)標(biāo)簽的映射,'100'->0 ?'101'->1...
接著獲取輸入文本的序列化表達(dá):input_ids, token_type_ids;形式大概如此:
'input_ids'=[101, 5500, 4873, 704, 4638, 4960, 4788, 2501, 2578, 102]
'token_type_ids'=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
獲取attention mask:attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
結(jié)果形式如下:[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
計(jì)算出當(dāng)前長(zhǎng)度,獲取pading長(zhǎng)度,比如我們現(xiàn)在長(zhǎng)度是10,那么需要補(bǔ)到128,pad就需要118個(gè)0.
這個(gè)時(shí)候,我們的input_ids 就變成了上面的列表后面加上128個(gè)0.然后我們的attention_mask就變成了上面的形式加上118個(gè)0,因?yàn)檠a(bǔ)長(zhǎng)的并不是我們的第二個(gè)句子,我們壓根沒(méi)第二個(gè)句子,所以token_type_ids是總共128個(gè)0
每操作一個(gè)數(shù)據(jù)之后,我們需要做的是
features.append(InputFeatures(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
label=label,
input_len=input_len))##長(zhǎng)度為原始長(zhǎng)度,這里應(yīng)該是10,不是128
InputFeatures 在這里就是將轉(zhuǎn)化之后的特征存儲(chǔ)到一個(gè)新的變量中
在將所有原始數(shù)據(jù)進(jìn)行特征轉(zhuǎn)化之后,我們得到了features列表,然后將其中的元素轉(zhuǎn)化為tensor形式,隨后
第三個(gè)是將轉(zhuǎn)化之后的新數(shù)據(jù)tensor化,然后使用TensorDataset構(gòu)造最終的數(shù)據(jù)集并返回,
dataset?=?TensorDataset(all_input_ids,?all_attention_mask,?all_token_type_ids,?all_lens,all_labels)
6.2 訓(xùn)練模型-Train函數(shù)
我們來(lái)看第二個(gè)函數(shù),就是train的操作。
6.2.1 常規(guī)操作
首先都是一些常規(guī)操作。
對(duì)數(shù)據(jù)隨機(jī)采樣:RandomSampler
DataLoader讀取數(shù)據(jù)
計(jì)算總共訓(xùn)練步數(shù)(梯度累計(jì)),warm_up 參數(shù)設(shè)定,優(yōu)化器,是否fp16等等
然后一個(gè)batch一個(gè)batch進(jìn)行訓(xùn)練就好了。這里最核心的代碼就是下面的把數(shù)據(jù)和參數(shù)送入到模型中去:
outputs?=?model(**inputs)
我們是在進(jìn)行一個(gè)文本分類的demo操作,使用的是Bert中對(duì)應(yīng)的 BertForSequenceClassification 這個(gè)類。
我們直接進(jìn)入這個(gè)類看一下里面函數(shù)究竟是啥情況。
6.2.2 Bert分類模型:BertForSequenceClassification
主要代碼代碼如下:
點(diǎn)擊此處打開(kāi)折疊代碼
##reference:?transformers.modeling_bert.BertForSequenceClassification?
class?BertForSequenceClassification(BertPreTrainedModel):
????def?__init__(self,?config):
????????????????...
????????...
????????self.bert?=?BertModel(config)
????????self.dropout?=?nn.Dropout(config.hidden_dropout_prob)
????????self.classifier?=?nn.Linear(config.hidden_size,?self.config.num_labels)
????def?forward(self,?input_ids,?attention_mask=None,?token_type_ids=None,position_ids=None,?head_mask=None,?labels=None):
????????outputs?=?self.bert(input_ids,
????????????????????????????attention_mask=attention_mask,
????????????????????????????token_type_ids=token_type_ids,
????????????????????????????position_ids=position_ids,?
????????????????????????????head_mask=head_mask)
????????##zida注解:注意看init中,定義了self.bert就是BertModel,所以我們需要的就是看一看BertModel中數(shù)據(jù)怎么進(jìn)入的
????????pooled_output?=?outputs[1]
????????pooled_output?=?self.dropout(pooled_output)
????????????????...
????????...
????????return?outputs??#?(loss),?logits,?(hidden_states),?(attentions)
這個(gè)類最核心的有兩個(gè)部分,第一個(gè)部分就是使用了 BertModel 獲取Bert的原始輸出,然后使用 cls的輸出繼續(xù)做后續(xù)的分類操作。比較重要的是 BertModel,我們直接進(jìn)入看 BertModel 這個(gè)類的內(nèi)部情況。代碼如下:
然后我們看一下BertModel這個(gè)模型究竟是怎么樣的
6.2.1.1 BertModel
代碼如下:
點(diǎn)擊此處打開(kāi)折疊代碼
##?reference:?transformers.modeling_bert.BertModel??
class?BertModel(BertPreTrainedModel):
????def?__init__(self,?config):
????????self.embeddings?=?BertEmbeddings(config)
????????self.encoder?=?BertEncoder(config)
????????self.pooler?=?BertPooler(config)
????????????????...
????def?forward(self,?input_ids,?attention_mask=None,?token_type_ids=None,position_ids=None,?head_mask=None):
????????????????...
????????###?第一部分,對(duì)?attention_mask?進(jìn)行操作,并對(duì)輸入做embedding
????????extended_attention_mask?=?attention_mask.unsqueeze(1).unsqueeze(2)
????????extended_attention_mask?=?extended_attention_mask.to(dtype=next(self.parameters()).dtype)?#?fp16?compatibility
????????extended_attention_mask?=?(1.0?-?extended_attention_mask)?*?-10000.0
????????embedding_output?=?self.embeddings(input_ids,?position_ids=position_ids,?token_type_ids=token_type_ids)
????????###?第二部分?進(jìn)入?encoder?進(jìn)行編碼
????????encoder_outputs?=?self.encoder(embedding_output,
???????????????????????????????????????extended_attention_mask,
???????????????????????????????????????head_mask=head_mask)
????????????????...
????????return?outputs
對(duì)于BertModel ,我們可以把它分成兩個(gè)部分,第一個(gè)部分是對(duì) attention_mask 進(jìn)行操作,并對(duì)輸入做embedding,第二個(gè)部分是進(jìn)入encoder進(jìn)行編碼,這里的encoder使用的是BertEncoder。我們直接進(jìn)去看一下
6.2.1.1.1 BertEncoder
代碼如下:
##reference:transformers.modeling_bert.BertEncoder
class?BertEncoder(nn.Module):
????def?__init__(self,?config):
????????super(BertEncoder,?self).__init__()
????????self.output_attentions?=?config.output_attentions
????????self.output_hidden_states?=?config.output_hidden_states
????????self.layer?=?nn.ModuleList([BertLayer(config)?for?_?in?range(config.num_hidden_layers)])
????def?forward(self,?hidden_states,?attention_mask=None,?head_mask=None):
????????all_hidden_states?=?()
????????all_attentions?=?()
????????for?i,?layer_module?in?enumerate(self.layer):
????????????if?self.output_hidden_states:
????????????????all_hidden_states?=?all_hidden_states?+?(hidden_states,)
????????????layer_outputs?=?layer_module(hidden_states,?attention_mask,?head_mask[i])
????????????hidden_states?=?layer_outputs[0]
????????????if?self.output_attentions:
????????????????all_attentions?=?all_attentions?+?(layer_outputs[1],)
????????#?Add?last?layer
????????if?self.output_hidden_states:
????????????all_hidden_states?=?all_hidden_states?+?(hidden_states,)
????????outputs?=?(hidden_states,)
????????if?self.output_hidden_states:
????????????outputs?=?outputs?+?(all_hidden_states,)
????????if?self.output_attentions:
????????????outputs?=?outputs?+?(all_attentions,)
????????return?outputs??#?last-layer?hidden?state,?(all?hidden?states),?(all?attentions)
有一個(gè)BertEncoder小細(xì)節(jié),就是如果output_hidden_states為T(mén)rue,會(huì)把每一層的結(jié)果都輸出,也包含詞向量,所以如果十二層的話,輸出是是13層,第一層為word-embedding結(jié)果,每層結(jié)果都是[batchsize,seqlength,Hidden_size](除了第一層,[batchsize,seqlength,embedding_size])
當(dāng)然embedding_size在維度上是和隱層維度一樣的。
還有一點(diǎn)需要注意的就是,我們需要在這里看到一個(gè)細(xì)節(jié),就是我們是可以做head_mask,這個(gè)head_mask我記得有個(gè)論文是在做哪個(gè)head對(duì)結(jié)果的影響,這個(gè)好像能夠?qū)崿F(xiàn)。
BertEncoder 中間最重要的是BertLayer
BertLayer
BertLayer分為兩個(gè)操作,BertAttention和BertIntermediate。BertAttention分為BertSelfAttention和BertSelfOutput。我們一個(gè)個(gè)來(lái)看
BertAttention BertSelfAttention
def?forward(self,?hidden_states,?attention_mask=None,?head_mask=None):
??##?接受參數(shù)如上
??mixed_query_layer?=?self.query(hidden_states)?##?生成query?[16,32,768],16是batch_size,32是這個(gè)batch中每個(gè)句子的長(zhǎng)度,768是維度
??mixed_key_layer?=?self.key(hidden_states)
??mixed_value_layer?=?self.value(hidden_states)
??query_layer?=?self.transpose_for_scores(mixed_query_layer)##?將上面生成的query進(jìn)行維度轉(zhuǎn)化,現(xiàn)在維度:?[16,12,32,64]:[Batch_size,Num_head,Seq_len,每個(gè)頭維度]
??key_layer?=?self.transpose_for_scores(mixed_key_layer)
??value_layer?=?self.transpose_for_scores(mixed_value_layer)
??#?Take?the?dot?product?between?"query"?and?"key"?to?get?the?raw?attention?scores.
??attention_scores?=?torch.matmul(query_layer,?key_layer.transpose(-1,?-2))
??##?上面操作之后?attention_scores?維度為torch.Size([16,?12,?32,?32])
??attention_scores?=?attention_scores?/?math.sqrt(self.attention_head_size)
??if?attention_mask?is?not?None:
??#?Apply?the?attention?mask?is?(precomputed?for?all?layers?in?BertModel?forward()?function)
??attention_scores?=?attention_scores?+?attention_mask
??##?這里直接就是相加了,pad的部分直接為非常大的負(fù)值,下面softmax的時(shí)候,直接就為接近0
??#?Normalize?the?attention?scores?to?probabilities.
??attention_probs?=?nn.Softmax(dim=-1)(attention_scores)
??#?This?is?actually?dropping?out?entire?tokens?to?attend?to,?which?might
??#?seem?a?bit?unusual,?but?is?taken?from?the?original?Transformer?paper.
??attention_probs?=?self.dropout(attention_probs)##維度torch.Size([16,?12,?32,?32])
??#?Mask?heads?if?we?want?to
??if?head_mask?is?not?None:
????attention_probs?=?attention_probs?*?head_mask
??context_layer?=?torch.matmul(attention_probs,?value_layer)##維度torch.Size([16,?12,?32,?64])
??context_layer?=?context_layer.permute(0,?2,?1,?3).contiguous()##維度torch.Size([16,?32,?12,?64])
??new_context_layer_shape?=?context_layer.size()[:-2]?+?(self.all_head_size,)## new_context_layer_shape:torch.Size([16, 32, 768])
??context_layer?=?context_layer.view(*new_context_layer_shape)
##維度變成torch.Size([16,?32,?768])
??outputs?=?(context_layer,?attention_probs)?if?self.output_attentions?else?(context_layer,)
??return?outputs
這個(gè)時(shí)候 BertSelfAttention 返回結(jié)果維度為 torch.Size([16, 32, 768]),這個(gè)結(jié)果作為BertSelfOutput的輸入
BertSelfOutput
class?BertSelfOutput(nn.Module):
????def?__init__(self,?config):
????????super(BertSelfOutput,?self).__init__()
????????self.dense?=?nn.Linear(config.hidden_size,?config.hidden_size)
????????##?做了一個(gè)linear?維度沒(méi)變
????????self.LayerNorm?=?BertLayerNorm(config.hidden_size,?eps=config.layer_norm_eps)
????????self.dropout?=?nn.Dropout(config.hidden_dropout_prob)
????def?forward(self,?hidden_states,?input_tensor):
????????hidden_states?=?self.dense(hidden_states)
????????hidden_states?=?self.dropout(hidden_states)
????????hidden_states?=?self.LayerNorm(hidden_states?+?input_tensor)
????????return?hidden_states
上面兩個(gè)函數(shù)BertSelfAttention 和BertSelfOutput之后,返回attention的結(jié)果,接下來(lái)仍然是BertLayer的下一個(gè)操作:BertIntermediate
BertIntermediate
這個(gè)函數(shù)比較簡(jiǎn)單,經(jīng)過(guò)一個(gè)Linear,經(jīng)過(guò)一個(gè)Gelu激活函數(shù)
輸入結(jié)果維度為 torch.Size([16, 32, 3072])
這個(gè)結(jié)果 接下來(lái)進(jìn)入 BertOutput 這個(gè)模型
BertOutput
也比較簡(jiǎn)單,Liner+BertLayerNorm+Dropout,輸出結(jié)果維度為 torch.Size([16, 32, 768])
BertOutput的輸出結(jié)果返回給BertEncoder 類
BertEncoder 結(jié)果返回 BertModel 類 作為 encoder_outputs,維度大小 torch.Size([16, 32, 768])
BertModel的返回為 outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
sequence_output:torch.Size([16, 32, 768])
pooled_output:torch.Size([16, 768]) 是cls的輸出經(jīng)過(guò)一個(gè)pool層(其實(shí)就是linear維度不變+tanh)的輸出
outputs返回給BertForSequenceClassification,也就是對(duì)pooled_output 做分類
更多閱讀
特別推薦

點(diǎn)擊下方閱讀原文加入社區(qū)會(huì)員
