LSTM原理及生成藏頭詩(Python)
一、基礎(chǔ)介紹
1.1 神經(jīng)網(wǎng)絡(luò)模型
常見的神經(jīng)網(wǎng)絡(luò)模型結(jié)構(gòu)有前饋神經(jīng)網(wǎng)絡(luò)(DNN)、RNN(常用于文本 / 時間系列任務(wù))、CNN(常用于圖像任務(wù))等等。具體可以看之前文章:一文概覽神經(jīng)網(wǎng)絡(luò)模型。
前饋神經(jīng)網(wǎng)絡(luò)是神經(jīng)網(wǎng)絡(luò)模型中最為常見的,信息從輸入層開始輸入,每層的神經(jīng)元接收前一級輸入,并輸出到下一級,直至輸出層。整個網(wǎng)絡(luò)信息輸入傳輸中無反饋(循環(huán))。即任何層的輸出都不會影響同級層,可用一個有向無環(huán)圖表示。
1.2 RNN 介紹
循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)是基于序列數(shù)據(jù)(如語言、語音、時間序列)的遞歸性質(zhì)而設(shè)計的,是一種反饋類型的神經(jīng)網(wǎng)絡(luò),它專門用于處理序列數(shù)據(jù),如逐字生成文本或預(yù)測時間序列數(shù)據(jù)(例如股票價格、詩歌生成)。
RNN和全連接神經(jīng)網(wǎng)絡(luò)的本質(zhì)差異在于“輸入是帶有反饋信息的”,RNN除了接受每一步的輸入x(t) ,同時還有輸入上一步的歷史反饋信息——隱藏狀態(tài)h (t-1) ,也就是當前時刻的隱藏狀態(tài)h(t) 或決策輸出O(t) 由當前時刻的輸入 x(t) 和上一時刻的隱藏狀態(tài)h (t-1) 共同決定。從某種程度,RNN和大腦的決策很像,大腦接受當前時刻感官到的信息(外部的x(t) )和之前的想法(內(nèi)部的h (t-1) )的輸入一起決策。

RNN的結(jié)構(gòu)原理可以簡要概述為兩個公式,具體介紹可以看下【一文詳解RNN】:
RNN的隱藏狀態(tài)為:h(t) = f( U * x(t) + W * h(t-1) + b1), ?f為激活函數(shù),常用tanh、relu;?
RNN的輸出為:o(t) = g( V * h(t) + b2),g為激活函數(shù),當用于分類任務(wù),一般用softmax;
1.3 從RNN到LSTM
但是在實際中,RNN在長序列數(shù)據(jù)處理中,容易導(dǎo)致梯度爆炸或者梯度消失,也就是長期依賴(long-term dependencies)問題,其根本原因就是模型“記憶”的序列信息太長了,都會一股腦地記憶和學(xué)習(xí),時間一長,就容易忘掉更早的信息(梯度消失)或者崩潰(梯度爆炸)。
梯度消失:歷史時間步的信息距離當前時間步越長,反饋的梯度信號就會越弱(甚至為0)的現(xiàn)象,梯度被近距離梯度主導(dǎo),導(dǎo)致模型難以學(xué)到遠距離的依賴關(guān)系。
改善措施:可以使用 ReLU 激活函數(shù);門控RNN 如GRU、LSTM 以改善梯度消失。
梯度爆炸:網(wǎng)絡(luò)層之間的梯度(值大于 1)重復(fù)相乘導(dǎo)致的指數(shù)級增長會產(chǎn)生梯度爆炸,導(dǎo)致模型無法有效學(xué)習(xí)。
改善措施:可以使用 梯度截斷;引導(dǎo)信息流的正則化;ReLU 激活函數(shù);門控RNN 如GRU、LSTM(和普通 RNN 相比多經(jīng)過了很多次導(dǎo)數(shù)都小于 1激活函數(shù),因此 LSTM 發(fā)生梯度爆炸的頻率要低得多)以改善梯度爆炸。
所以,如果我們能讓 RNN 在接受上一時刻的狀態(tài)和當前時刻的輸入時,有選擇地記憶和遺忘一部分內(nèi)容(或者說信息),問題就可以解決了。比如上上句話提及”我去考試了“,然后后面提及”我考試通過了“,那么在此之前說的”我去考試了“的內(nèi)容就沒那么重要,選擇性地遺忘就好了。這也就是長短期記憶網(wǎng)絡(luò)(Long Short-Term Memory, LSTM)的基本思想。
二、LSTM原理
LSTM是種特殊RNN網(wǎng)絡(luò),在RNN的基礎(chǔ)上引入了“門控”的選擇性機制,分別是遺忘門、輸入門和輸出門,從而有選擇性地保留或刪除信息,以能夠較好地學(xué)習(xí)長期依賴關(guān)系。如下圖RNN(上) 對比 LSTM(下):

2.1 LSTM的核心
在RNN基礎(chǔ)上引入門控后的LSTM,結(jié)構(gòu)看起來好復(fù)雜!但其實LSTM作為一種反饋神經(jīng)網(wǎng)絡(luò),核心還是歷史的隱藏狀態(tài)信息的反饋,也就是下圖的Ct:
對標RNN的ht隱藏狀態(tài)的更新,LSTM的Ct只是多個些“門控”刪除或添加信息到狀態(tài)信息。由下面依次介紹LSTM的“門控”:遺忘門,輸入門,輸出門的功能,LSTM的原理也就好理解了。
2.2 遺忘門
LSTM 的第一步是通過"遺忘門"從上個時間點的狀態(tài)Ct-1中丟棄哪些信息。
具體來說,輸入Ct-1,會先根據(jù)上一個時間點的輸出ht-1和當前時間點的輸入xt,并通過sigmoid激活函數(shù)的輸出結(jié)果ft來確定要讓Ct-1,來忘記多少,sigmoid后等于1表示要保存多一些Ct-1的比重,等于0表示完全忘記之前的Ct-1。
2.3 輸入門
下一步是通過輸入門,決定我們將在狀態(tài)中存儲哪些新信息。
我們根據(jù)上一個時間點的輸出ht-1和當前時間點的輸入xt 生成兩部分信息i t 及C~t,通過sigmoid輸出i t,用tanh輸出C~t。之后通過把i t 及C~t兩個部分相乘,共同決定在狀態(tài)中存儲哪些新信息。
在輸入門 + 遺忘門控制下,當前時間點狀態(tài)信息Ct為:

2.4 輸出門
最后,我們根據(jù)上一個時間點的輸出ht-1和當前時間點的輸入xt 通過sigmid 輸出Ot,再根據(jù)Ot 與 tanh控制的當前時間點狀態(tài)信息Ct 相乘作為最終的輸出。
綜上,一張圖可以說清LSTM原理:
三、LSTM簡單寫詩
本節(jié)項目利用深層LSTM模型,學(xué)習(xí)大小為10M的詩歌數(shù)據(jù)集,自動可以生成詩歌。
如下代碼構(gòu)建LSTM模型。
##?本項目完整代碼:github.com/aialgorithm/Blog
#?或“算法進階”公眾號文末閱讀原文可見
model?=?tf.keras.Sequential([
????#?不定長度的輸入
????tf.keras.layers.Input((None,)),
????#?詞嵌入層
????tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size,?output_dim=128),
????#?第一個LSTM層,返回序列作為下一層的輸入
????tf.keras.layers.LSTM(128,?dropout=0.5,?return_sequences=True),
????#?第二個LSTM層,返回序列作為下一層的輸入
????tf.keras.layers.LSTM(128,?dropout=0.5,?return_sequences=True),
????#?對每一個時間點的輸出都做softmax,預(yù)測下一個詞的概率
????tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size,?activation='softmax')),
])
#?查看模型結(jié)構(gòu)
model.summary()
#?配置優(yōu)化器和損失函數(shù)
model.compile(optimizer=tf.keras.optimizers.Adam(),?loss=tf.keras.losses.categorical_crossentropy)
模型訓(xùn)練,考慮訓(xùn)練時長,就簡單訓(xùn)練2個epoch。
class?Evaluate(tf.keras.callbacks.Callback):
????"""
????訓(xùn)練過程評估,在每個epoch訓(xùn)練完成后,保留最優(yōu)權(quán)重,并隨機生成SHOW_NUM首古詩展示
????"""
????def?__init__(self):
????????super().__init__()
????????#?給loss賦一個較大的初始值
????????self.lowest?=?1e10
????def?on_epoch_end(self,?epoch,?logs=None):
????????#?在每個epoch訓(xùn)練完成后調(diào)用
????????#?如果當前l(fā)oss更低,就保存當前模型參數(shù)
????????if?logs['loss']?<=?self.lowest:
????????????self.lowest?=?logs['loss']
????????????model.save(BEST_MODEL_PATH)
????????#?隨機生成幾首古體詩測試,查看訓(xùn)練效果
????????print("cun'h")
????????for?i?in?range(SHOW_NUM):
????????????print(generate_acrostic(tokenizer,?model,?head="春花秋月"))
#?創(chuàng)建數(shù)據(jù)集
data_generator?=?PoetryDataGenerator(poetry,?random=True)
#?開始訓(xùn)練
model.fit_generator(data_generator.for_fit(),?steps_per_epoch=data_generator.steps,?epochs=TRAIN_EPOCHS,
????????????????????callbacks=[Evaluate()])
加載簡單訓(xùn)練的LSTM模型,輸入關(guān)鍵字(如:算法進階)后,自動生成藏頭詩??梢钥闯鲈娋浯致钥瓷先ネ?yōu)雅,但實際上經(jīng)不起推敲。后面增加訓(xùn)練的epoch及數(shù)據(jù)集應(yīng)該可以更好些。
#?加載訓(xùn)練好的模型
model?=?tf.keras.models.load_model(BEST_MODEL_PATH)
keywords?=?input('輸入關(guān)鍵字:\n')
#?生成藏頭詩
for?i?in?range(SHOW_NUM):
????print(generate_acrostic(tokenizer,?model,?head=keywords),'\n')

- END -參考資料:https://colah.github.io/posts/2015-08-Understanding-LSTMs/ https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21 https://www.zhihu.com/question/34878706
文章首發(fā)公眾號“算法進階”,文末閱讀原文可訪問文章相關(guān)代碼
