文字識(shí)別:一文讀懂 Transformer OCR
共 4595字,需瀏覽 10分鐘
·
2022-02-09 17:36
深度學(xué)習(xí)時(shí)代的文字識(shí)別:行識(shí)別,主流有兩種算法,一種是CRNN 算法,一種是attention 算法。
CRNN:CNN+RNN+CTC
白裳:一文讀懂CRNN+CTC文字識(shí)別attention :CNN+Seq2Seq+Attention
白裳:完全解析RNN, Seq2Seq, Attention注意力機(jī)制兩種算法都比較成熟,互聯(lián)網(wǎng)上也有很多講解的文章。
Attention Is All You Need (Transformer)這篇文章,設(shè)計(jì)了一種新型self-attention結(jié)構(gòu),取代了 RNN(LSTM\GRU) 的結(jié)構(gòu),在眾多nlp相關(guān)任務(wù)上取得了效果上的突破,而后來(lái)的BERT、GPT等模型亦是來(lái)源于這篇文章。有關(guān)Transformer 的講解也有很多。
大師兄:詳解Transformer (Attention Is All You Need)OCR行識(shí)別本身也是seq2seq的序列識(shí)別問題,這里講下如何利用transformer結(jié)構(gòu)進(jìn)行OCR識(shí)別,本文介紹的Transformer OCR 基于以下代碼
https://github.com/saberSabersaber/transformer_OCR目前僅支持寬、高固定的定長(zhǎng)識(shí)別(高度固定,寬度需要padding到最大長(zhǎng)度),以下假設(shè)輸入圖像高度固定為32, 最大寬度為100。
Transformer OCR 模型結(jié)構(gòu),核心分為兩個(gè)部分:backbone + transfromer
- Backbone:
backbone一般是由卷積層+pooling層堆疊而成, 這里以crnn 中 VGG_FeatureExtractor 為例,一共是經(jīng)過(guò)4次pooling層,其中 2 個(gè)2*2 pooling ,2個(gè) 2 * 1 的pooling,經(jīng)過(guò)backbone之后輸出為512*2*25(其中512是channel數(shù)),為了將高度pooling 到1 ,最后又額外利用卷積操作將圖像變?yōu)?12*1*24。(并不必要,因?yàn)楹竺鏁?huì)接一個(gè)AdaptiveAvgPool2d 層,以保證不同backbone提取到的特征高度均為1)。
2. Transformer:
transformer 主要有兩個(gè)結(jié)構(gòu),encoder 和 decoder.
2.1 Encoder:
encoder 結(jié)構(gòu)由多個(gè)MultiHeadAttentionLayer 和 PositionwiseFeedforwardLayer 堆疊而成的block組成。其輸入為經(jīng)過(guò)backbone之后提取的特征,這里記為 ,為了保證transformer 輸入的時(shí)序性,transformer 額外計(jì)算了位置編碼(pos_embedding),這里記為 ,那么 encoder內(nèi)部輸入給MultiHeadAttentionLayer 的輸入可以表示為: 。encoder 第一個(gè)block(MultiHeadAttentionLayer + PositionwiseFeedforwardLayer)輸入是有圖像特征和位置編碼得到的 ,后面block 的輸入為上一個(gè)block 的輸出。
Self-attention
Self-attention 是 transformer 的核心結(jié)構(gòu),在這里數(shù)據(jù)首先會(huì)經(jīng)過(guò)一個(gè)叫做self-attention的模塊得到一個(gè)加權(quán)之后的特征向量 ,這個(gè) 便是論文公式1中的 :
在encoder 中, 全部來(lái)自于輸入 :
其中 是模型參數(shù)。Self-attention 可以看到是將輸入經(jīng)過(guò)三個(gè)不同的線性映射得到三個(gè)中間變量 ,然后利用矩陣乘法來(lái)模擬RNN 的時(shí)序計(jì)算。RNN,包括LSTM、GRU的一個(gè)問題是隨著時(shí)間步的變長(zhǎng),時(shí)序依賴性會(huì)減弱,即 和 ,在 比較小的時(shí)候互相之間依賴較強(qiáng),當(dāng) 越來(lái)越大時(shí),兩個(gè)時(shí)刻之間的聯(lián)系也會(huì)越來(lái)越弱。但是在矩陣乘法中,每一列也可以認(rèn)為是一個(gè)時(shí)間步,而任意兩列在矩陣乘法中距離是一樣的,都會(huì)得到計(jì)算,故而對(duì)于序列識(shí)別來(lái)說(shuō),transformer打破了時(shí)序長(zhǎng)時(shí)依賴的障礙。但是,由于矩陣乘法是無(wú)序的,而OCR識(shí)別輸入的圖像是有序的,所以需要通過(guò)位置編碼來(lái)彌補(bǔ)。
MultiHeadAttentionLayer
Multi-Head Attention,其實(shí)是多個(gè)self-attention的集成,這里采用多個(gè)self-attention能夠豐富特征,代買設(shè)置的參數(shù)為 。Multi-Head Attention的輸出分成3步:
- 將輸入 分別輸入到8個(gè)self-attention中,得到8個(gè)加權(quán)后的特征矩陣 ,將這8個(gè)輸出矩陣直接拼成一個(gè)大的特征矩陣,最后再利用一層全連接后得到輸出 。
PositionwiseFeedforward
這個(gè)全連接有兩層,第一層的激活函數(shù)是ReLU,第二層是一個(gè)線性激活函數(shù),可以表示為:
圖像特征 經(jīng)過(guò)encoder之后得到的特征記為 。
2.2 Decoder
Decoder 結(jié)構(gòu)和encoder 稍有不同,是由 MASK MultiHeadAttentionLayer 、encoder-decoder MultiHeadAttentionLayer 和 PositionwiseFeedforwardLayer 三部分組成的多個(gè)block 堆疊而成。
Decoder 的輸入由兩部分組成,其中之一是encoder 輸出 ,另外一個(gè)輸入是target, 由label 得到,而MASK MultiHeadAttentionLayer 結(jié)構(gòu)輸入就是target。
MASK MultiHeadAttentionLayer
這里首先要講一下target 的生成方式,以下圖為例,其真實(shí)label為L(zhǎng)ondon,在計(jì)算loss 時(shí),會(huì)在字符串最后補(bǔ)一個(gè)終止符[end], 由于是batch 訓(xùn)練,多個(gè)樣本的label長(zhǎng)度不同,這里也采用pading 的方式對(duì)齊到max_len,即label為L(zhǎng)ondon[end][pad][pad]....,而target 為[begin]London[end][pad]...., 即在label 前插入一個(gè)起始符,[begin] 和 [end] 用途稍后會(huì)說(shuō)。
MASK MultiHeadAttentionLayer 與 MASK MultiHeadAttentionLayer 區(qū)別在于mask,而利用mask 的原因是在序列運(yùn)算的時(shí)候,后面的字符可以看到前面的字符,但是前面的字符看不到后面的字符,比如,當(dāng)預(yù)測(cè)o的時(shí)候,L是已經(jīng)預(yù)測(cè)出來(lái)的,可以用于輸入,但是o、n都不能用于輸入,所以要生成三角矩陣來(lái)進(jìn)行mask。操作起來(lái)也比較簡(jiǎn)單,就是在公示(1)中 得到的矩陣對(duì)應(yīng)位置上填充-inf即可。其余地方和 MultiHeadAttentionLayer 一致。將target 經(jīng)過(guò)MASK MultiHeadAttentionLayer 之后得到的特征記為 。
encoder-decoder MultiHeadAttentionLayer
encoder-decoder MultiHeadAttentionLayer 運(yùn)算過(guò)程和MultiHeadAttentionLayer 也是一樣的,只不過(guò)這里的 來(lái)自于 , 而 來(lái)自于 , 其余部分與MultiHeadAttentionLayer 完全相同。其意思為每個(gè)時(shí)刻輸入一個(gè)query,通過(guò)query,key,和 value的矩陣乘法得到當(dāng)前時(shí)刻的輸出。target在最開始的時(shí)候插入[begin]就是為了輸入的query 領(lǐng)先于需要預(yù)測(cè)的label,即query為[begin],預(yù)測(cè)得到L,query為L(zhǎng), 預(yù)測(cè)得到o。最后的PositionwiseFeedforwardLayer 與encoder中相同。
Tips:
最后針對(duì)一些常見的問題,給出一些解釋:
1)label 為何要在最后添加[end] 符號(hào)?
這是因?yàn)樵跍y(cè)試的時(shí)候我們不知道最終結(jié)果的長(zhǎng)度,比如London這個(gè)單詞,當(dāng)預(yù)測(cè)到字符n的時(shí)候,程序并不會(huì)終止,會(huì)繼續(xù)預(yù)測(cè),添加終止符號(hào)的情況下,模型會(huì)學(xué)到預(yù)測(cè)到[end],這時(shí)候可以終止預(yù)測(cè),或者通過(guò)簡(jiǎn)單的后處理提取[end]之前有效字符。
2)target 為何在最開始添加[begin] 符號(hào)?
這個(gè)和attention OCR 中類似,只不過(guò)在運(yùn)算中由于都是矩陣運(yùn)算體現(xiàn)的不明顯(在代碼中測(cè)試的時(shí)候采用的是循環(huán)的方式,這種方式更容易理解)
從decoder 輸入可以看到,decoder 同時(shí)輸入了圖像和label(target 是在lable 最開始插入了[begin]符號(hào)),但是實(shí)際情況我們是沒有l(wèi)abel的,只有圖像,仍然以London 為例,當(dāng)預(yù)測(cè)第一個(gè)字符的L的時(shí)候,如果不補(bǔ)[begin]會(huì)發(fā)生什么?模型能看到整張圖像的特征以及l(fā)abel第一個(gè)字符L 提取的特征,當(dāng)預(yù)測(cè)o的時(shí)候,模型能看到整張圖像的特征以及l(fā)abel 前兩個(gè)符號(hào)Lo的特征,這樣顯然是有問題的。接下來(lái)在label 最開始的位置插入一個(gè)[begin] 符號(hào),那么當(dāng)預(yù)測(cè)第一個(gè)字符的L的時(shí)候, 模型能看到整張圖像的特征以及[begin]的特征,當(dāng)預(yù)測(cè)o的時(shí)候,模型能看到整張圖像的特征以及l(fā)abel 前兩個(gè)符號(hào)[begin][L] 的特征,在訓(xùn)練的時(shí)候 L 來(lái)自label,當(dāng)預(yù)測(cè)的時(shí)候,L 是前一次預(yù)測(cè)的結(jié)果。
3)encoder-decoder MultiHeadAttentionLayer 需不需要用mask?
在encoder-decoder MultiHeadAttentionLayer 中計(jì)算是不需要計(jì)算mask的,這是因?yàn)檫@個(gè)模塊的輸入是MASK MultiHeadAttentionLayer的輸出,這里的輸出已經(jīng)保證了 時(shí)刻的特征看不到 時(shí)刻的特征。
4)transformer VS RNN(LSTM、GRU)有哪些優(yōu)勢(shì)?為什么要用位置編碼?
這里把兩個(gè)問題放在一起,之前也簡(jiǎn)單提過(guò),不過(guò)這里只是個(gè)人理解,不保證個(gè)人理解一定正確。
RNN 是利用隱含層h記錄之前時(shí)刻的狀態(tài),LSTM、GRU 通過(guò)記錄更多的額外狀態(tài)以期保留時(shí)間跨度更長(zhǎng)的信息,但都不可避免的是, 和 ,在 比較小的時(shí)候互相之間依賴較強(qiáng),當(dāng) 越來(lái)越大時(shí),兩個(gè)時(shí)刻之間的聯(lián)系也會(huì)越來(lái)越弱。Transformer 里面是 的矩陣乘法,而矩陣乘法任意兩列(兩個(gè)時(shí)刻)都會(huì)計(jì)算。以Beijing is the captial of china 為例,在RNN 這種結(jié)構(gòu)中,Beijing 和 china 距離較遠(yuǎn),所以在預(yù)測(cè)china 的時(shí)候RNN中幾乎沒有任何關(guān)于Beijing特征的信息,但是在transformer 中,Beijing is the captial of china 這句話和 is the captial of Beijing china 是一樣的,所以在長(zhǎng)時(shí)間依賴上,transformer 能夠優(yōu)于RNN,也正是因?yàn)檫@個(gè)原因,需要利用位置編碼將位置信息加給模型。
