<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          10分鐘掌握Bert源碼(PyTorch版)

          共 15588字,需瀏覽 32分鐘

           ·

          2021-01-19 08:19

          Bert在生產(chǎn)環(huán)境的應(yīng)用需要進(jìn)行壓縮,這就要求對(duì)Bert結(jié)構(gòu)很了解,這個(gè)倉(cāng)庫(kù)會(huì)一步步解讀Bert源代碼(pytorch版本)。倉(cāng)庫(kù)地址在

          https://github.com/DA-southampton/NLP_ability

          代碼和數(shù)據(jù)介紹

          首先 對(duì)代碼來(lái)說(shuō),借鑒的是這個(gè)倉(cāng)庫(kù)

          我直接把代碼clone過(guò)來(lái),放到了本倉(cāng)庫(kù),重新命名為bert_read_step_to_step。

          我會(huì)使用這個(gè)代碼,一步步運(yùn)行bert關(guān)于文本分類的代碼,然后同時(shí)記錄下各種細(xì)節(jié)包括自己實(shí)現(xiàn)的情況。

          運(yùn)行之前,需要做兩個(gè)事情。

          準(zhǔn)備預(yù)訓(xùn)練模型

          一個(gè)是預(yù)訓(xùn)練模型的準(zhǔn)備,我使用的是谷歌的中文預(yù)訓(xùn)練模型:chinese_L-12_H-768_A-12.zip,模型有點(diǎn)大,我就不上傳了,如果本地不存在,就點(diǎn)擊這里直接下載,或者直接命令行運(yùn)行

          wget?https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip

          預(yù)訓(xùn)練模型下載下來(lái)之后,進(jìn)行解壓,然后將tf模型轉(zhuǎn)為對(duì)應(yīng)的pytorch版本即可。對(duì)應(yīng)代碼如下:

          export?BERT_BASE_DIR=/path/to/bert/chinese_L-12_H-768_A-12

          python?convert_tf_checkpoint_to_pytorch.py?\
          ??--tf_checkpoint_path?$BERT_BASE_DIR/bert_model.ckpt?\
          ??--bert_config_file?$BERT_BASE_DIR/bert_config.json?\
          ??--pytorch_dump_path?$BERT_BASE_DIR/pytorch_model.bin

          轉(zhuǎn)化成功之后,將模型放入到倉(cāng)庫(kù)對(duì)應(yīng)位置:

          Read_Bert_Code/bert_read_step_to_step/prev_trained_model/

          并重新命名為:

          bert-base-chinese

          準(zhǔn)備文本分類訓(xùn)練數(shù)據(jù)

          第二個(gè)事情就是準(zhǔn)備訓(xùn)練數(shù)據(jù),這里我準(zhǔn)備做一個(gè)文本分類任務(wù),使用的是Tnews數(shù)據(jù)集,這個(gè)數(shù)據(jù)集來(lái)源是這里,分為訓(xùn)練,測(cè)試和開(kāi)發(fā)集,我已經(jīng)上傳到了倉(cāng)庫(kù)中,具體位置在

          Read_Bert_Code/bert_read_step_to_step/chineseGLUEdatasets/tnews

          需要注意的一點(diǎn)是,因?yàn)槲抑皇菫榱肆私鈨?nèi)部代碼情況,所以準(zhǔn)確度不是在我的考慮范圍之內(nèi),所以我只是取其中的一部分?jǐn)?shù)據(jù),其中訓(xùn)練數(shù)據(jù)使用1k,測(cè)試數(shù)據(jù)使用1k,開(kāi)發(fā)數(shù)據(jù)1k。

          準(zhǔn)備就緒,使用pycharm導(dǎo)入項(xiàng)目,準(zhǔn)備調(diào)試,我的調(diào)試文件是 run_classifier.py文件,對(duì)應(yīng)的參數(shù)為

          --model_type=bert?--model_name_or_path=prev_trained_model/bert-base-chinese?--task_name="tnews"?--do_train?--do_eval?--do_lower_case?--data_dir=./chineseGLUEdatasets/tnews?--max_seq_length=128?--per_gpu_train_batch_size=16?--per_gpu_eval_batch_size=16?--learning_rate=2e-5?--num_train_epochs=4.0?--logging_steps=100?--save_steps=100?--output_dir=./outputs/tnews_output/?--overwrite_output_dir

          然后對(duì)run_classifier.py 進(jìn)行調(diào)試,我會(huì)在下面是調(diào)試的細(xì)節(jié)

          1.main函數(shù)進(jìn)入

          首先是主函數(shù)位置打入斷點(diǎn),位置在這里,然后進(jìn)入看一下主函數(shù)的情況

          ##主函數(shù)打上斷點(diǎn)
          if?__name__?==?"__main__":
          ????main()##主函數(shù)進(jìn)入

          2.解析命令行參數(shù)

          從這里到這里就是在解析命令行參數(shù),是常規(guī)操作,主要是什么模型名稱,模型地址,是否進(jìn)行測(cè)試等等。比較簡(jiǎn)單,直接過(guò)就可以了。

          3.判斷一些情況

          從這里到這里是一些常規(guī)的判斷:

          判斷是否存在輸出文件夾

          判斷是否需要遠(yuǎn)程debug

          判斷單機(jī)cpu訓(xùn)練還是單機(jī)多gpu訓(xùn)練,還是多機(jī)分布式gpu訓(xùn)練,這個(gè)有兩個(gè)參數(shù)進(jìn)行控制

          具體可以看代碼如下:

          if?args.local_rank?==?-1?or?args.no_cuda:
          ????device?=?torch.device("cuda"?if?torch.cuda.is_available()?and?not?args.no_cuda?else?"cpu")
          ????args.n_gpu?=?torch.cuda.device_count()
          else:??#?Initializes?the?distributed?backend?which?will?take?care?of?sychronizing?nodes/GPUs
          ????torch.cuda.set_device(args.local_rank)
          ????device?=?torch.device("cuda",?args.local_rank)
          ????torch.distributed.init_process_group(backend='nccl')
          ????args.n_gpu?=?1

          4.獲取任務(wù)對(duì)應(yīng)Processor

          獲取任務(wù)對(duì)應(yīng)的相應(yīng)processor,這個(gè)對(duì)應(yīng)的函數(shù)就是需要我們自己去定義的處理我們自己輸入文件的函數(shù),位置在這里,代碼如下:

          processor?=?processors[args.task_name]()

          這里我們使用的是,這個(gè)結(jié)果返回的是一個(gè)類,我們使用的是如下的類:

          TnewsProcessor(DataProcessor)

          具體代碼位置在這里,

          4.1 TnewsProcessor

          仔細(xì)分析一下TnewsProcessor,首先繼承自DataProcessor

          點(diǎn)擊此處打開(kāi)折疊代碼

          ## DataProcessor在整個(gè)項(xiàng)目的位置:processors.utils.DataProcessor
          class?DataProcessor(object):
          ????def?get_train_examples(self,?data_dir):
          ????????raise?NotImplementedError()

          ????def?get_dev_examples(self,?data_dir):
          ????????raise?NotImplementedError()

          ????def?get_labels(self):
          ????????raise?NotImplementedError()

          ????@classmethod
          ????def?_read_tsv(cls,?input_file,?quotechar=None):
          ????????with?open(input_file,?"r",?encoding="utf-8-sig")?as?f:
          ????????????reader?=?csv.reader(f,?delimiter="\t",?quotechar=quotechar)
          ????????????lines?=?[]
          ????????????for?line?in?reader:
          ????????????????lines.append(line)
          ????????????return?lines

          ????@classmethod
          ????def?_read_txt(cls,?input_file):
          ????????"""Reads?a?tab?separated?value?file."""
          ????????with?open(input_file,?"r")?as?f:
          ????????????reader?=?f.readlines()
          ????????????lines?=?[]
          ????????????for?line?in?reader:
          ????????????????lines.append(line.strip().split("_!_"))
          ????????????return?lines

          然后它自己包含五個(gè)函數(shù),分別是讀取訓(xùn)練集,開(kāi)發(fā)集數(shù)據(jù),獲取返回label,制作bert需要的格式的數(shù)據(jù)

          接下來(lái)看一下 TnewsProcessor代碼格式:

          點(diǎn)擊此處打開(kāi)折疊代碼

          class?TnewsProcessor(DataProcessor):

          ????def?get_train_examples(self,?data_dir):
          ????????"""See?base?class."""
          ????????return?self._create_examples(
          ????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_train.txt")),?"train")

          ????def?get_dev_examples(self,?data_dir):
          ????????"""See?base?class."""
          ????????return?self._create_examples(
          ????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_dev.txt")),?"dev")

          ????def?get_test_examples(self,?data_dir):
          ????????"""See?base?class."""
          ????????return?self._create_examples(
          ????????????self._read_txt(os.path.join(data_dir,?"toutiao_category_test.txt")),?"test")

          ????def?get_labels(self):
          ????????"""See?base?class."""
          ????????labels?=?[]
          ????????for?i?in?range(17):
          ????????????if?i?==?5?or?i?==?11:
          ????????????????continue
          ????????????labels.append(str(100?+?i))
          ????????return?labels

          ????def?_create_examples(self,?lines,?set_type):
          ????????"""Creates?examples?for?the?training?and?dev?sets."""
          ????????examples?=?[]
          ????????for?(i,?line)?in?enumerate(lines):
          ????????????guid?=?"%s-%s"?%?(set_type,?i)
          ????????????text_a?=?line[3]
          ????????????if?set_type?==?'test':
          ????????????????label?=?'0'
          ????????????else:
          ????????????????label?=?line[1]
          ????????????examples.append(
          ????????????????InputExample(guid=guid,?text_a=text_a,?text_b=None,?label=label))
          ????????return?examples

          這里有一點(diǎn)需要提醒大家,如果說(shuō)我們使用自己的訓(xùn)練數(shù)據(jù),有兩個(gè)方法,第一個(gè)就是把數(shù)據(jù)格式變化成和我們測(cè)試用例一樣的數(shù)據(jù),第二個(gè)就是我們?cè)谶@里更改源代碼,去讀取我們自己的數(shù)據(jù)格式

          5.加載預(yù)訓(xùn)練模型

          代碼比較簡(jiǎn)單,就是調(diào)用預(yù)訓(xùn)練模型,不詳細(xì)介紹了

          點(diǎn)擊此處打開(kāi)折疊代碼

          config_class,?model_class,?tokenizer_class?=?MODEL_CLASSES[args.model_type]
          config?=?config_class.from_pretrained(args.config_name?if?args.config_name?else?args.model_name_or_path,?num_labels=num_labels,?finetuning_task=args.task_name)
          tokenizer?=?tokenizer_class.from_pretrained(args.tokenizer_name?if?args.tokenizer_name?else?args.model_name_or_path,?do_lower_case=args.do_lower_case)
          model?=?model_class.from_pretrained(args.model_name_or_path,?from_tf=bool('.ckpt'?in?args.model_name_or_path),?config=config)

          6.訓(xùn)練模型-也是最重要的部分

          訓(xùn)練模型,從主函數(shù)這里看就是兩個(gè)步驟,一個(gè)是加載需要的數(shù)據(jù)集,一個(gè)是進(jìn)行訓(xùn)練,代碼位置在這里。大概代碼就是這樣:

          train_dataset?=?load_and_cache_examples(args,?args.task_name,?tokenizer,?data_type='train')
          global_step,?tr_loss?=?train(args,?train_dataset,?model,?tokenizer)

          兩個(gè)函數(shù),我們一個(gè)個(gè)看:

          6.1 加載訓(xùn)練集

          我們先看一下第一個(gè)函數(shù),load_and_cache_examples 就是加載訓(xùn)練數(shù)據(jù)集,代碼位置在這里。大概看一下這個(gè)代碼,核心操作有三個(gè)。

          第一個(gè)核心操作,位置在這里,代碼如下:

          examples?=?processor.get_train_examples(args.data_dir)

          這個(gè)代碼是為了利用processor讀取訓(xùn)練集,很簡(jiǎn)單。

          這里得到的example大概是這樣的(這個(gè)返回形式在上面看processor的時(shí)候很清楚的展示了):

          guid='train-0'
          label='104'
          text_a='今天股票形式不怎么樣啊'
          text_b=None

          第二個(gè)核心操作是convert_examples_to_features講數(shù)據(jù)進(jìn)行轉(zhuǎn)化,也很簡(jiǎn)單。

          代碼位置在這里。代碼如下:

          features?=?convert_examples_to_features(examples,tokenizer,label_list=label_list,max_length=args.max_seq_length,output_mode=output_mode,pad_on_left=bool(args.model_type?in?['xlnet']),????????????????????????????????????????????????pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
          pad_token_segment_id=4?if?args.model_type?in?['xlnet']?else?0,

          我們進(jìn)入這個(gè)函數(shù)看一看里面究竟是咋回事,位置在:

          processors.glue.glue_convert_examples_to_features

          做了一個(gè)標(biāo)簽的映射,'100'->0 ?'101'->1...

          接著獲取輸入文本的序列化表達(dá):input_ids, token_type_ids;形式大概如此:

          'input_ids'=[101, 5500, 4873, 704, 4638, 4960, 4788, 2501, 2578, 102]

          'token_type_ids'=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

          獲取attention mask:attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

          結(jié)果形式如下:[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

          計(jì)算出當(dāng)前長(zhǎng)度,獲取pading長(zhǎng)度,比如我們現(xiàn)在長(zhǎng)度是10,那么需要補(bǔ)到128,pad就需要118個(gè)0.

          這個(gè)時(shí)候,我們的input_ids 就變成了上面的列表后面加上128個(gè)0.然后我們的attention_mask就變成了上面的形式加上118個(gè)0,因?yàn)檠a(bǔ)長(zhǎng)的并不是我們的第二個(gè)句子,我們壓根沒(méi)第二個(gè)句子,所以token_type_ids是總共128個(gè)0

          每操作一個(gè)數(shù)據(jù)之后,我們需要做的是

          features.append(InputFeatures(input_ids=input_ids,
          attention_mask=attention_mask,
          token_type_ids=token_type_ids,
          label=label,
          input_len=input_len))##長(zhǎng)度為原始長(zhǎng)度,這里應(yīng)該是10,不是128

          InputFeatures 在這里就是將轉(zhuǎn)化之后的特征存儲(chǔ)到一個(gè)新的變量中

          在將所有原始數(shù)據(jù)進(jìn)行特征轉(zhuǎn)化之后,我們得到了features列表,然后將其中的元素轉(zhuǎn)化為tensor形式,隨后

          第三個(gè)是將轉(zhuǎn)化之后的新數(shù)據(jù)tensor化,然后使用TensorDataset構(gòu)造最終的數(shù)據(jù)集并返回,

          dataset?=?TensorDataset(all_input_ids,?all_attention_mask,?all_token_type_ids,?all_lens,all_labels)

          6.2 訓(xùn)練模型-Train函數(shù)

          我們來(lái)看第二個(gè)函數(shù),就是train的操作。

          6.2.1 常規(guī)操作

          首先都是一些常規(guī)操作。

          對(duì)數(shù)據(jù)隨機(jī)采樣:RandomSampler

          DataLoader讀取數(shù)據(jù)

          計(jì)算總共訓(xùn)練步數(shù)(梯度累計(jì)),warm_up 參數(shù)設(shè)定,優(yōu)化器,是否fp16等等

          然后一個(gè)batch一個(gè)batch進(jìn)行訓(xùn)練就好了。這里最核心的代碼就是下面的把數(shù)據(jù)和參數(shù)送入到模型中去:

          outputs?=?model(**inputs)

          我們是在進(jìn)行一個(gè)文本分類的demo操作,使用的是Bert中對(duì)應(yīng)的 BertForSequenceClassification 這個(gè)類。

          我們直接進(jìn)入這個(gè)類看一下里面函數(shù)究竟是啥情況。

          6.2.2 Bert分類模型:BertForSequenceClassification

          主要代碼代碼如下:

          點(diǎn)擊此處打開(kāi)折疊代碼

          ##reference:?transformers.modeling_bert.BertForSequenceClassification?
          class?BertForSequenceClassification(BertPreTrainedModel):
          ????def?__init__(self,?config):
          ????????????????...
          ????????...
          ????????self.bert?=?BertModel(config)
          ????????self.dropout?=?nn.Dropout(config.hidden_dropout_prob)
          ????????self.classifier?=?nn.Linear(config.hidden_size,?self.config.num_labels)

          ????def?forward(self,?input_ids,?attention_mask=None,?token_type_ids=None,position_ids=None,?head_mask=None,?labels=None):
          ????????outputs?=?self.bert(input_ids,
          ????????????????????????????attention_mask=attention_mask,
          ????????????????????????????token_type_ids=token_type_ids,
          ????????????????????????????position_ids=position_ids,?
          ????????????????????????????head_mask=head_mask)
          ????????##zida注解:注意看init中,定義了self.bert就是BertModel,所以我們需要的就是看一看BertModel中數(shù)據(jù)怎么進(jìn)入的
          ????????pooled_output?=?outputs[1]
          ????????pooled_output?=?self.dropout(pooled_output)
          ????????????????...
          ????????...
          ????????return?outputs??#?(loss),?logits,?(hidden_states),?(attentions)

          這個(gè)類最核心的有兩個(gè)部分,第一個(gè)部分就是使用了 BertModel 獲取Bert的原始輸出,然后使用 cls的輸出繼續(xù)做后續(xù)的分類操作。比較重要的是 BertModel,我們直接進(jìn)入看 BertModel 這個(gè)類的內(nèi)部情況。代碼如下:

          然后我們看一下BertModel這個(gè)模型究竟是怎么樣的

          6.2.1.1 BertModel

          代碼如下:

          點(diǎn)擊此處打開(kāi)折疊代碼

          ##?reference:?transformers.modeling_bert.BertModel??
          class?BertModel(BertPreTrainedModel):
          ????def?__init__(self,?config):

          ????????self.embeddings?=?BertEmbeddings(config)
          ????????self.encoder?=?BertEncoder(config)
          ????????self.pooler?=?BertPooler(config)
          ????????????????...
          ????def?forward(self,?input_ids,?attention_mask=None,?token_type_ids=None,position_ids=None,?head_mask=None):
          ????????????????...
          ????????###?第一部分,對(duì)?attention_mask?進(jìn)行操作,并對(duì)輸入做embedding
          ????????extended_attention_mask?=?attention_mask.unsqueeze(1).unsqueeze(2)
          ????????extended_attention_mask?=?extended_attention_mask.to(dtype=next(self.parameters()).dtype)?#?fp16?compatibility
          ????????extended_attention_mask?=?(1.0?-?extended_attention_mask)?*?-10000.0
          ????????embedding_output?=?self.embeddings(input_ids,?position_ids=position_ids,?token_type_ids=token_type_ids)
          ????????###?第二部分?進(jìn)入?encoder?進(jìn)行編碼
          ????????encoder_outputs?=?self.encoder(embedding_output,
          ???????????????????????????????????????extended_attention_mask,
          ???????????????????????????????????????head_mask=head_mask)
          ????????????????...
          ????????return?outputs

          對(duì)于BertModel ,我們可以把它分成兩個(gè)部分,第一個(gè)部分是對(duì) attention_mask 進(jìn)行操作,并對(duì)輸入做embedding,第二個(gè)部分是進(jìn)入encoder進(jìn)行編碼,這里的encoder使用的是BertEncoder。我們直接進(jìn)去看一下

          6.2.1.1.1 BertEncoder

          代碼如下:

          ##reference:transformers.modeling_bert.BertEncoder
          class?BertEncoder(nn.Module):
          ????def?__init__(self,?config):
          ????????super(BertEncoder,?self).__init__()
          ????????self.output_attentions?=?config.output_attentions
          ????????self.output_hidden_states?=?config.output_hidden_states
          ????????self.layer?=?nn.ModuleList([BertLayer(config)?for?_?in?range(config.num_hidden_layers)])

          ????def?forward(self,?hidden_states,?attention_mask=None,?head_mask=None):
          ????????all_hidden_states?=?()
          ????????all_attentions?=?()
          ????????for?i,?layer_module?in?enumerate(self.layer):
          ????????????if?self.output_hidden_states:
          ????????????????all_hidden_states?=?all_hidden_states?+?(hidden_states,)

          ????????????layer_outputs?=?layer_module(hidden_states,?attention_mask,?head_mask[i])
          ????????????hidden_states?=?layer_outputs[0]

          ????????????if?self.output_attentions:
          ????????????????all_attentions?=?all_attentions?+?(layer_outputs[1],)

          ????????#?Add?last?layer
          ????????if?self.output_hidden_states:
          ????????????all_hidden_states?=?all_hidden_states?+?(hidden_states,)

          ????????outputs?=?(hidden_states,)
          ????????if?self.output_hidden_states:
          ????????????outputs?=?outputs?+?(all_hidden_states,)
          ????????if?self.output_attentions:
          ????????????outputs?=?outputs?+?(all_attentions,)
          ????????return?outputs??#?last-layer?hidden?state,?(all?hidden?states),?(all?attentions)

          有一個(gè)BertEncoder小細(xì)節(jié),就是如果output_hidden_states為T(mén)rue,會(huì)把每一層的結(jié)果都輸出,也包含詞向量,所以如果十二層的話,輸出是是13層,第一層為word-embedding結(jié)果,每層結(jié)果都是[batchsize,seqlength,Hidden_size](除了第一層,[batchsize,seqlength,embedding_size])

          當(dāng)然embedding_size在維度上是和隱層維度一樣的。

          還有一點(diǎn)需要注意的就是,我們需要在這里看到一個(gè)細(xì)節(jié),就是我們是可以做head_mask,這個(gè)head_mask我記得有個(gè)論文是在做哪個(gè)head對(duì)結(jié)果的影響,這個(gè)好像能夠?qū)崿F(xiàn)。

          BertEncoder 中間最重要的是BertLayer

          • BertLayer

          BertLayer分為兩個(gè)操作,BertAttention和BertIntermediate。BertAttention分為BertSelfAttention和BertSelfOutput。我們一個(gè)個(gè)來(lái)看

          • BertAttention
          • BertSelfAttention
          def?forward(self,?hidden_states,?attention_mask=None,?head_mask=None):
          ??##?接受參數(shù)如上
          ??mixed_query_layer?=?self.query(hidden_states)?##?生成query?[16,32,768],16是batch_size,32是這個(gè)batch中每個(gè)句子的長(zhǎng)度,768是維度
          ??mixed_key_layer?=?self.key(hidden_states)
          ??mixed_value_layer?=?self.value(hidden_states)

          ??query_layer?=?self.transpose_for_scores(mixed_query_layer)##?將上面生成的query進(jìn)行維度轉(zhuǎn)化,現(xiàn)在維度:?[16,12,32,64]:[Batch_size,Num_head,Seq_len,每個(gè)頭維度]
          ??key_layer?=?self.transpose_for_scores(mixed_key_layer)
          ??value_layer?=?self.transpose_for_scores(mixed_value_layer)

          ??#?Take?the?dot?product?between?"query"?and?"key"?to?get?the?raw?attention?scores.
          ??attention_scores?=?torch.matmul(query_layer,?key_layer.transpose(-1,?-2))
          ??##?上面操作之后?attention_scores?維度為torch.Size([16,?12,?32,?32])
          ??attention_scores?=?attention_scores?/?math.sqrt(self.attention_head_size)
          ??if?attention_mask?is?not?None:
          ??#?Apply?the?attention?mask?is?(precomputed?for?all?layers?in?BertModel?forward()?function)
          ??attention_scores?=?attention_scores?+?attention_mask
          ??##?這里直接就是相加了,pad的部分直接為非常大的負(fù)值,下面softmax的時(shí)候,直接就為接近0

          ??#?Normalize?the?attention?scores?to?probabilities.
          ??attention_probs?=?nn.Softmax(dim=-1)(attention_scores)

          ??#?This?is?actually?dropping?out?entire?tokens?to?attend?to,?which?might
          ??#?seem?a?bit?unusual,?but?is?taken?from?the?original?Transformer?paper.
          ??attention_probs?=?self.dropout(attention_probs)##維度torch.Size([16,?12,?32,?32])

          ??#?Mask?heads?if?we?want?to
          ??if?head_mask?is?not?None:
          ????attention_probs?=?attention_probs?*?head_mask

          ??context_layer?=?torch.matmul(attention_probs,?value_layer)##維度torch.Size([16,?12,?32,?64])

          ??context_layer?=?context_layer.permute(0,?2,?1,?3).contiguous()##維度torch.Size([16,?32,?12,?64])
          ??new_context_layer_shape?=?context_layer.size()[:-2]?+?(self.all_head_size,)## new_context_layer_shape:torch.Size([16, 32, 768])
          ??context_layer?=?context_layer.view(*new_context_layer_shape)
          ##維度變成torch.Size([16,?32,?768])
          ??outputs?=?(context_layer,?attention_probs)?if?self.output_attentions?else?(context_layer,)
          ??return?outputs

          這個(gè)時(shí)候 BertSelfAttention 返回結(jié)果維度為 torch.Size([16, 32, 768]),這個(gè)結(jié)果作為BertSelfOutput的輸入

          • BertSelfOutput
          class?BertSelfOutput(nn.Module):
          ????def?__init__(self,?config):
          ????????super(BertSelfOutput,?self).__init__()
          ????????self.dense?=?nn.Linear(config.hidden_size,?config.hidden_size)
          ????????##?做了一個(gè)linear?維度沒(méi)變
          ????????self.LayerNorm?=?BertLayerNorm(config.hidden_size,?eps=config.layer_norm_eps)
          ????????self.dropout?=?nn.Dropout(config.hidden_dropout_prob)

          ????def?forward(self,?hidden_states,?input_tensor):
          ????????hidden_states?=?self.dense(hidden_states)
          ????????hidden_states?=?self.dropout(hidden_states)
          ????????hidden_states?=?self.LayerNorm(hidden_states?+?input_tensor)
          ????????return?hidden_states

          上面兩個(gè)函數(shù)BertSelfAttention 和BertSelfOutput之后,返回attention的結(jié)果,接下來(lái)仍然是BertLayer的下一個(gè)操作:BertIntermediate

          • BertIntermediate

          這個(gè)函數(shù)比較簡(jiǎn)單,經(jīng)過(guò)一個(gè)Linear,經(jīng)過(guò)一個(gè)Gelu激活函數(shù)

          輸入結(jié)果維度為 torch.Size([16, 32, 3072])

          這個(gè)結(jié)果 接下來(lái)進(jìn)入 BertOutput 這個(gè)模型

          • BertOutput

          也比較簡(jiǎn)單,Liner+BertLayerNorm+Dropout,輸出結(jié)果維度為 torch.Size([16, 32, 768])

          BertOutput的輸出結(jié)果返回給BertEncoder 類

          BertEncoder 結(jié)果返回 BertModel 類 作為 encoder_outputs,維度大小 torch.Size([16, 32, 768])

          BertModel的返回為 outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]

          sequence_output:torch.Size([16, 32, 768])

          pooled_output:torch.Size([16, 768]) 是cls的輸出經(jīng)過(guò)一個(gè)pool層(其實(shí)就是linear維度不變+tanh)的輸出

          outputs返回給BertForSequenceClassification,也就是對(duì)pooled_output 做分類

          更多閱讀



          2020 年最佳流行 Python 庫(kù) Top 10


          2020 Python中文社區(qū)熱門(mén)文章 Top 10


          Top 10 沙雕又有趣的 GitHub 程序

          特別推薦




          點(diǎn)擊下方閱讀原文加入社區(qū)會(huì)員

          瀏覽 58
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  这里只有免费精品6 | 无码三级乱伦 | 国产一级二级三级在线观看 | 美女av免费 | 色五月婷婷五月 |