<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          有手就會(huì)-用MNIST訓(xùn)練一個(gè)CNN模型并識(shí)別自己手寫數(shù)字

          共 6613字,需瀏覽 14分鐘

           ·

          2022-04-13 01:01


          大家好,我是小伍哥,今天我們學(xué)點(diǎn)視覺的東西。很多人學(xué)圖片算法的時(shí)候,MNIST手寫數(shù)字識(shí)別都是第一個(gè)練手的項(xiàng)目,其實(shí)干跑也沒啥意思,我們今天訓(xùn)練一個(gè)模型用來識(shí)別自己的手寫數(shù)字,看看能不能實(shí)現(xiàn),這樣學(xué)起來更有參與感點(diǎn),也更實(shí)用,過程介紹也比較詳細(xì),適合初學(xué)者。


          一、CNN模型構(gòu)建

          from keras import layersfrom keras import modelsmodel = models.Sequential()model.add(layers.Conv2D(32, (3, 3),           activation='relu',           input_shape=(28, 28, 1))          )model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.MaxPooling2D((2, 2)))model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.Flatten())model.add(layers.Dropout(0.25))model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))# 十個(gè)分類

          卷積神經(jīng)網(wǎng)絡(luò)接收形狀為(image_height, image_width, image_channels)的輸入張量(不包括批量維度)。本例中設(shè)置卷積神經(jīng)網(wǎng)絡(luò)處理大小為(28, 28, 1) 的輸入張量,這正是MNIST 圖像的格式。我們向第一層傳入?yún)?shù)input_shape=(28, 28, 1) 來完成此設(shè)置。我們來看一下目前卷積神經(jīng)網(wǎng)絡(luò)的架構(gòu)。

          model.summary()Model: "sequential"_________________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================conv2d (Conv2D)              (None, 26, 26, 32)        320       _________________________________________________________________max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         _________________________________________________________________conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     _________________________________________________________________max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         _________________________________________________________________conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     _________________________________________________________________flatten (Flatten)            (None, 576)               0         _________________________________________________________________dropout (Dropout)            (None, 576)               0         _________________________________________________________________dense (Dense)                (None, 64)                36928     _________________________________________________________________dense_1 (Dense)              (None, 10)                650       =================================================================Total params: 93,322Trainable params: 93,322Non-trainable params: 0_________________________________________________________________

          這里需要理解其中的具體結(jié)構(gòu),比如參數(shù)個(gè)數(shù)18496,這個(gè)的算法是(3*3*32+1)*64 得來的,這里需要充分的理解什么事參數(shù),什么事偏置。每個(gè)卷積核單元就是一個(gè)訓(xùn)練參數(shù),3*3的就有9個(gè),上一層有32個(gè)深度,需要32個(gè)3*3的卷積核,卷積乘完了還需要加一個(gè)偏置。所以有了上面的參數(shù)個(gè)數(shù)。


          二、圖片下載與查看

          第一步我們需要獲取訓(xùn)練數(shù)據(jù),mnist這個(gè)數(shù)據(jù)集,已經(jīng)內(nèi)置到Keras包里了,直接下載就可以,具體的代碼如下。

          from keras.datasets import mnistfrom keras.utils    import to_categoricalfrom keras          import datasets
          # 加載數(shù)據(jù)集(train_images,train_labels), (test_images,test_labels) = mnist.load_data()‘’‘我們可以看到下載的進(jìn)度Using TensorFlow backend.Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz 540672/11490434 [>.............................] - ETA: 9:00 ‘’‘ # 訓(xùn)練集有60000個(gè)樣本train_images.shape(60000, 28, 28)
          # 測(cè)試集有10000個(gè)樣本test_images.shape(10000, 28, 28)
          train_images[1].shape (28, 28)
          # 看看一個(gè)數(shù)字的像素點(diǎn)長(zhǎng)啥樣,選取了部分train_images[1] array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 159, 253, 159, 50, 0, 0, 0, 0, 0, 0, 0, 0],····]]
          #可以看看內(nèi)置的所有數(shù)據(jù)集print(dir(datasets))['absolute_import', 'boston_housing', 'cifar', 'cifar10', 'cifar100','fashion_mnist', 'imdb', 'mnist', 'reuters']

          把矩陣打印出來看看,第200個(gè),是數(shù)字1

          import matplotlib.pyplot as pltplt.imshow(train_images[200] , cmap=plt.cm.binary)plt.show()

          第1220個(gè),是數(shù)字4

          plt.imshow(train_images[1220] , cmap=plt.cm.binary)plt.show()

          三、模型訓(xùn)練&準(zhǔn)確率評(píng)估

          我們開始訓(xùn)練模型,第一步是要調(diào)整圖片的格式,通道1,并除以255歸一化,將像素值轉(zhuǎn)換到0-1之間,方便反向傳播數(shù)據(jù)的更新。

          train_images = train_images.reshape((60000, 28, 28, 1))train_images = train_images.astype('float32') / 255train_labels = to_categorical(train_labels)
          test_images = test_images.reshape((10000, 28, 28, 1))test_images = test_images.astype('float32') / 255test_labels = to_categorical(test_labels)

          model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])model.fit(train_images, train_labels, epochs=5, ???????????batch_size=64)

          訓(xùn)練完成了,我們?cè)跍y(cè)試集上測(cè)試下模型的準(zhǔn)確率,可以看到,這么一個(gè)簡(jiǎn)單的模型,我們的準(zhǔn)確率就達(dá)到了99.14%,深度學(xué)習(xí)還是非常強(qiáng)大的

          test_loss, test_acc = model.evaluate(test_images, test_labels)test_acc 0.9914000034332275

          看看預(yù)測(cè)的到底準(zhǔn)不準(zhǔn)呢,我們看看預(yù)測(cè)的細(xì)節(jié)

          import numpy as npimport matplotlib.pyplot as plt
          y_pred = model.predict(test_images)pred = np.argmax(y_pred, axis=1)
          #看看第2990個(gè)數(shù)字是啥,我們預(yù)測(cè)的是8,看看圖片也是8,挺準(zhǔn)的steps = 2990print('pred: ',pred[steps])pred: 8
          plt.imshow(test_images[steps] , cmap=plt.cm.binary)plt.show()


          四、預(yù)測(cè)自己的手寫數(shù)字

          在測(cè)試集效果好,那在實(shí)際應(yīng)用中到底好不好呢,我們自己手寫幾個(gè)測(cè)試下,模型訓(xùn)練好了就可以保存著以后用了,預(yù)測(cè)的時(shí)候直接加載就行,如果預(yù)測(cè)樣本沒有發(fā)生比較大的變化,那訓(xùn)練好的模型理論上可以一直使用,大概的預(yù)測(cè)過程如下:

          根據(jù)上面的訓(xùn)練,模型的準(zhǔn)確率還挺高得,但是實(shí)際有沒有用呢,還需要用自己的數(shù)據(jù)進(jìn)行測(cè)試,打開自己在畫圖板或者在筆記本上隨便寫幾個(gè)數(shù)字,然后單個(gè)截圖保存后進(jìn)行預(yù)測(cè)。

          分別截圖后保存成img3、img4、img5......,下面進(jìn)行預(yù)處理,處理成和模型訓(xùn)練一樣的數(shù)據(jù)才能預(yù)測(cè)。

          from keras.preprocessing.image import load_img,img_to_arrayimport matplotlib.pyplot as pltimport numpy as np
          #讀取圖片、調(diào)整圖片大小,轉(zhuǎn)換成灰度 help(load_img)path = 'C:/Users/伍正祥/Desktop/img5.jpg'img = load_img(path, target_size=(28, 28),color_mode="grayscale")
          #255-為了調(diào)成白底,系統(tǒng)灰度轉(zhuǎn)換自動(dòng)給處理成黑底了,所以做個(gè)反轉(zhuǎn)img = 255-img_to_array(img)
          #查看自己加載的圖片plt.imshow(img , cmap=plt.cm.binary)plt.show()
          #圖片形狀調(diào)整,需要調(diào)整到和訓(xùn)練集一樣的格式img = img.astype('float32')/255img = img.reshape((1, 28, 28, 1))
          #進(jìn)行圖片進(jìn)行預(yù)測(cè)y_pred = model.predict(img)
          print('預(yù)測(cè)數(shù)字:',np.argmax(y_pred, axis=1)[0]print('預(yù)測(cè)概率:',y_pred)

          讀取5的的手寫圖片并進(jìn)行預(yù)測(cè),可以看到預(yù)測(cè)的結(jié)果為5.

          看看預(yù)測(cè)的概率分布,是5的概率幾乎接近于1


          讀取3的的手寫圖片并預(yù)測(cè)

          讀取4的的手寫圖片并預(yù)測(cè)

          從測(cè)試的結(jié)果來看,預(yù)測(cè)效果還是非常不錯(cuò)的,多試幾次,也有預(yù)測(cè)錯(cuò)的,但是錯(cuò)的概率比較小。網(wǎng)絡(luò)結(jié)構(gòu)比較簡(jiǎn)單,如果對(duì)于接觸的不多的同學(xué),理解卷積還是比較困難的。特別是每一層的具體細(xì)節(jié)以及參數(shù)個(gè)數(shù)等,大家可以多看看一些可視化CNN的文章,充分理解。

          ···? END? ···

          往期精彩:

          風(fēng)控難題之無監(jiān)督風(fēng)險(xiǎn)感知:腦力、想象力、第六感、黑洞、星座、面相···

          情侶、基友、渣男和狗-基于時(shí)空關(guān)聯(lián)規(guī)則的影子賬戶挖掘

          情侶、基友、渣男和狗-基于SynchroTrap+LPA算法的團(tuán)伙賬戶挖掘

          孤立森林,一個(gè)通過XJB亂分進(jìn)行異常檢測(cè)的算法

          風(fēng)控策略的自動(dòng)化生成-利用決策樹分分鐘生成上千條策略

          基于密集行為的欺詐檢測(cè)算法-LockInfer

          關(guān)聯(lián)規(guī)則-策略挖掘中必不可少的算法

          異常檢測(cè)算法之(HBOS)-Histogram-based Outlier Score

          團(tuán)伙挖掘中的14大關(guān)系類型

          風(fēng)控難題之無監(jiān)督風(fēng)險(xiǎn)感知:腦力、想象力、第六感、黑洞、星座、面相···

          SynchroTrap-基于松散行為相似度的欺詐賬戶檢測(cè)算法

          瀏覽 132
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  毛片視| 神马午夜福利视频 | 骚逼随便操 | 国产 激情 视频 在线 | 色婷婷在线无码精品 |