<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Keras 實(shí)戰(zhàn)系列之知識(shí)蒸餾(Knowledge Distilling)

          共 4327字,需瀏覽 9分鐘

           ·

          2021-12-25 21:26

          前言

          深度學(xué)習(xí)在這兩年的發(fā)展可謂是突飛猛進(jìn),為了提升模型性能,模型的參數(shù)量變得越來(lái)越多,模型自身也變得越來(lái)越大。在圖像領(lǐng)域中基于Resnet的卷積神經(jīng)網(wǎng)絡(luò)模型,不斷延伸著網(wǎng)絡(luò)深度。而在自然語(yǔ)言處理領(lǐng)域(NLP)領(lǐng)域,BERT,GPT等超大模型的誕生也緊隨其后。這些巨型模型在準(zhǔn)確性上大部分時(shí)候都吊打其他一眾小參數(shù)量模型,可是它們?cè)诓渴痣A段,往往需要占用巨大內(nèi)存資源,同時(shí)運(yùn)行起來(lái)也極其耗時(shí),這與工業(yè)界對(duì)模型吃資源少,低延時(shí)的要求完全背道而馳。所以很多在學(xué)術(shù)界呼風(fēng)喚雨的強(qiáng)大模型在企業(yè)的運(yùn)用過(guò)程中卻沒(méi)有那么順風(fēng)順?biāo)?/p>

          知識(shí)蒸餾

          為解決上述問(wèn)題,我們需要將參數(shù)量巨大的模型,壓縮成小參數(shù)量模型,這樣就可以在不失精度的情況下,使得模型占用資源少,運(yùn)行快,所以如何將這些大模型壓縮,同時(shí)保持住頂尖的準(zhǔn)確率,成了學(xué)術(shù)界一個(gè)專(zhuān)門(mén)的研究領(lǐng)域。2015年Geoffrey Hinton 發(fā)表的Distilling the Knowledge in a Neural Network的論文中提出了知識(shí)蒸餾技術(shù),就是為了解決模型壓而生的。至于文章的細(xì)節(jié)這里博主不做過(guò)多介紹,想了解的同學(xué)們可以好好研讀原文。不過(guò)這篇文章的主要思想就如下方圖片所示:用一個(gè)老師模型(大參數(shù)模型)去教一個(gè)學(xué)生模型(小參數(shù)模型),在實(shí)做上就是用讓學(xué)生模型去學(xué)習(xí)已經(jīng)在目標(biāo)數(shù)據(jù)集上訓(xùn)練過(guò)的老師模型。盡管學(xué)生模型最終依然達(dá)不到老師模型的準(zhǔn)確性,但是被老師教過(guò)的學(xué)生模型會(huì)比自己?jiǎn)为?dú)訓(xùn)練的學(xué)生模型更加強(qiáng)大。

          這里大家可能會(huì)產(chǎn)生疑惑,為什么讓學(xué)生模型去學(xué)習(xí)目標(biāo)數(shù)據(jù)集會(huì)比被老師模型教出來(lái)的差。產(chǎn)生這種結(jié)果可能原因是因?yàn)?strong>老師模型的輸出提供了比目標(biāo)數(shù)據(jù)集更加豐富的信息,如下圖所示,老師模型的輸出,不僅提供了輸入圖片上的數(shù)字是數(shù)字1的信息,而且還附帶著數(shù)字1和數(shù)字7和9比較像等額外信息。

          知識(shí)蒸餾

          ?

          知識(shí)蒸餾具體流程

          接下來(lái)博主介紹一下知識(shí)蒸餾在實(shí)做上的具體流程。

          • (1)定義一個(gè)參數(shù)量較大(強(qiáng)大的)的老師模型,和一個(gè)參數(shù)量較小(弱小的)的學(xué)生模型,

          • (2)讓老師模型在目標(biāo)數(shù)據(jù)集上訓(xùn)練到最佳,

          • (3)將目標(biāo)數(shù)據(jù)的label替換成老師模型最后一個(gè)全連接層的輸出,讓學(xué)生模型學(xué)習(xí)老師模型的輸出,希望學(xué)生模型的輸出和老師模型輸出之間的交叉熵越小越好。

          了解到知識(shí)蒸餾的具體步驟之后,我們采用keras在mnist數(shù)據(jù)集上進(jìn)行一次簡(jiǎn)單的實(shí)驗(yàn)。

          知識(shí)蒸餾實(shí)戰(zhàn)

          包導(dǎo)入

          導(dǎo)入一下必要的python 包,同時(shí)載入數(shù)據(jù)。

          ?

          1. from keras.datasets import mnist

          2. from keras.layers import *

          3. from keras import Model

          4. from sklearn.metrics import accuracy_score

          5. import numpy as np

          6. (data_train,label_train),(data_test,label_test )= mnist.load_data()

          7. data_train = np.expand_dims(data_train,axis=3)

          8. data_test = np.expand_dims(data_test,axis=3)

          定義老師模型和學(xué)生模型

          在下方代碼中,博主定義了一個(gè)包含3層卷積層的CNN模型作為老師模型(參數(shù)量6萬(wàn)),定義了一個(gè)包含512個(gè)神經(jīng)元的全連接層作為學(xué)生模型(參數(shù)量4萬(wàn),比老師模型少了2萬(wàn))。

          ?

          1. #####定義老師模型——包含三層卷積層的CNN模型

          2. def teacher_model():

          3. input_ = Input(shape=(28,28,1))

          4. x = Conv2D(32,(3,3),padding = "same")(input_)

          5. x = Activation("relu")(x)

          6. print(x)

          7. x = MaxPool2D((2,2))(x)

          8. x = Conv2D(64,(3,3),padding= "same")(x)

          9. x = Activation("relu")(x)

          10. x = MaxPool2D((2,2))(x)

          11. x = Conv2D(64,(3,3),padding= "same")(x)

          12. x = Activation("relu")(x)

          13. x = MaxPool2D((2,2))(x)

          14. x = Flatten()(x)

          15. out = Dense(10,activation = "softmax")(x)

          16. model = Model(inputs=input_,outputs=out)

          17. model.compile(loss="sparse_categorical_crossentropy",

          18. optimizer="adam",

          19. metrics=["accuracy"])

          20. model.summary()

          21. return model


          22. ###定義學(xué)生模型——— 一層含512個(gè)神經(jīng)元的全連接層

          23. def student_model():

          24. input_ = Input(shape=(28,28,1))

          25. x = Flatten()(input_)

          26. x = Dense(512,activation="sigmoid")(x)

          27. out = Dense(10,activation = "softmax")(x)

          28. model = Model(inputs=input_,outputs=out)

          29. model.compile(loss="sparse_categorical_crossentropy",

          30. optimizer="adam",

          31. metrics=["accuracy"])

          32. model.summary()

          33. return model

          訓(xùn)練老師模型

          接下來(lái)開(kāi)始訓(xùn)練老師模型,由于mnist數(shù)據(jù)集較為簡(jiǎn)單,在三層的CNN模型上,我設(shè)定只訓(xùn)練2個(gè)epoch。這里需要注意的是,如下圖所示:三層卷積的CNN的有6萬(wàn)多個(gè)參數(shù)

          ?

          1. t_model = teacher_model()

          2. t_model.fit(data_train,label_train,batch_size=64,epochs=2,validation_data=(data_test,label_test))

          teacher model

          ?

          訓(xùn)練結(jié)果如下圖所示:兩個(gè)epoch,CNN模型就在測(cè)試集上做到了98%的準(zhǔn)確性。

          ?

          teacher result

          訓(xùn)練學(xué)生模型

          在512個(gè)神經(jīng)元的全連接層上訓(xùn)練mnist數(shù)據(jù)集,學(xué)生模型的參數(shù)量如下圖所示:參數(shù)量只有4萬(wàn)個(gè),參數(shù)量比老師模型少了2萬(wàn)個(gè)

          ?

          1. s_model = student_model()

          2. s_model.fit(data_train,label_train,batch_size=64,epochs=10,validation_data=(data_test,label_test))

          student model

          ?

          在學(xué)生模型上訓(xùn)練了10個(gè)epoch之后,測(cè)試機(jī)準(zhǔn)確率最高也才達(dá)到0.9460,遠(yuǎn)低于CNN老師模型的0.98

          ?

          student result

          老師模型教學(xué)生模型

          最后我們用老師模型教學(xué)生模型,進(jìn)行知識(shí)蒸餾。
          首先我們采用下方代碼將目標(biāo)數(shù)據(jù)集的label替換成老師模型的輸出。

          ?

          t_out = t_model.predict(data_train)

          然后用學(xué)生模型去學(xué)習(xí)老師模型的輸出。

          ?

          1. def teach_student(teacher_out, student_model,data_train,data_test,label_test):

          2. t_out = teacher_out


          3. s_model = student_model

          4. for l in s_model.layers:

          5. l.trainable = True


          6. label_test = keras.utils.to_categorical(label_test)


          7. model = Model(s_model.input,s_model.output)

          8. model.compile(loss="categorical_crossentropy",

          9. optimizer="adam")

          10. model.fit(data_train,t_out,batch_size= 64,epochs = 5)


          11. s_predict = np.argmax(model.predict(data_test),axis=1)

          12. s_label = np.argmax(label_test,axis=1)

          13. print(accuracy_score(s_predict,s_label))

          最終得到的實(shí)驗(yàn)結(jié)果如下圖所示:學(xué)生模型的性能提升到了0.9511,相比于學(xué)生模型在目標(biāo)數(shù)據(jù)集上的最好成績(jī)0.9460提升了千分之6個(gè)點(diǎn)。這也證明我們知識(shí)蒸餾確實(shí)起作用了。

          ?

          result of student model after being taught

          結(jié)語(yǔ)

          當(dāng)然我們也發(fā)現(xiàn),我們的實(shí)驗(yàn)提升的幅度并不大,離老師模型的準(zhǔn)確度還有巨大的差距,而要想優(yōu)化知識(shí)蒸餾的性能,我們可以采取升溫技術(shù),升溫技術(shù)的原理圖如下圖所示:將老師模型的輸出在softmax激活函數(shù)之前初上一個(gè)數(shù)值大于1的數(shù)字T,這樣會(huì)使得老師模型輸出的個(gè)類(lèi)別概率值變得較為接近。

          升溫技術(shù)

          ?

          確實(shí)升溫技術(shù)的主要目的就是將老師模型輸出的各類(lèi)型的概率,變得較為接近,這樣老師模型的輸出信息將變得更加豐富,得學(xué)生模型學(xué)會(huì)分辨出個(gè)類(lèi)別之間細(xì)微的區(qū)別。當(dāng)然知識(shí)蒸餾的優(yōu)化方法并不只上述的升溫技術(shù)這一種,這里博主只是拋磚引玉,知識(shí)蒸餾還有更多的奧秘等著大家去探索,去學(xué)習(xí)。希望讀者能夠有所收獲的同時(shí),心中的好奇心也能夠被激發(fā),主動(dòng)的學(xué)習(xí)知識(shí)蒸餾這門(mén)技術(shù)。




          Python“寶藏級(jí)”公眾號(hào)【Python之王】專(zhuān)注于Python領(lǐng)域,會(huì)爬蟲(chóng),數(shù)分,C++,tensorflow和Pytorch等等。

          近 2年共原創(chuàng) 100+ 篇技術(shù)文章。創(chuàng)作的精品文章系列有:

          日常收集整理了一批不錯(cuò)的?Python?學(xué)習(xí)資料,有需要的小伙可以自行免費(fèi)領(lǐng)取。

          獲取方式如下:公眾號(hào)回復(fù)資料領(lǐng)取Python等系列筆記,項(xiàng)目,書(shū)籍,直接套上模板就可以用了。資料包含算法、python、算法小抄、力扣刷題手冊(cè)和 C++ 等學(xué)習(xí)資料!


          ??


          瀏覽 42
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  A片视频在线观看 | 95嫩模主播酒店约 | 超碰免费在线97 | 色天天男人天堂婷婷 | 久久aa|