TensorFlow實(shí)現(xiàn)對(duì)花朵數(shù)據(jù)集的圖片分類(lèi)
點(diǎn)擊下方卡片,關(guān)注“新機(jī)器視覺(jué)”公眾號(hào)
重磅干貨,第一時(shí)間送達(dá)
轉(zhuǎn)載自:古月居
轉(zhuǎn)載自:古月居
前言
利用TensorFlow實(shí)現(xiàn)對(duì)花朵數(shù)據(jù)集的圖片分類(lèi)
提示:以下是本篇文章正文內(nèi)容,下面案例可供參考
一、數(shù)據(jù)集
數(shù)據(jù)集是五個(gè)分別存放著對(duì)應(yīng)類(lèi)別花朵圖片的五個(gè)文件夾,包括daisy(雛菊)633張;dandelion(蒲公英)898張,rose(玫瑰)641張,sunflower(向日葵)699張,tulips(郁金香)799張。
二、代碼
1、下載數(shù)據(jù)集
import tensorflow as tfAUTOTUNE = tf.data.experimental.AUTOTUNEimport pathlibdata_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', fname='flower_photos', untar=True)data_root = pathlib.Path(data_root_orig)print(data_root)for item in data_root.iterdir(): ?? ?print(item)
打印下載后的文件路徑和文件成員:
output:
C:\Users\Administrator.keras\datasets\flower_photos
C:\Users\Administrator.keras\datasets\flower_photos\daisy
C:\Users\Administrator.keras\datasets\flower_photos\dandelion
C:\Users\Administrator.keras\datasets\flower_photos\LICENSE.txt
C:\Users\Administrator.keras\datasets\flower_photos\roses
C:\Users\Administrator.keras\datasets\flower_photos\sunflowers
C:\Users\Administrator.keras\datasets\flower_photos\tulips
2、統(tǒng)計(jì)并觀察數(shù)據(jù)
#獲取五個(gè)文件夾的名字label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())label_names
output:[‘daisy’, ‘dandelion’, ‘roses’, ‘sunflowers’, ‘tulips’]
#將文件夾的名字(即花的分類(lèi))標(biāo)上序號(hào)label_to_index = dict((name, index) for index, name in enumerate(label_names))label_to_index
{‘daisy’: 0, ‘dandelion’: 1, ‘roses’: 2, ‘sunflowers’: 3, ‘tulips’: 4}
#獲取所以圖片的標(biāo)簽(0,1,2,3,4)import randomall_image_paths = list(data_root.glob('*/*'))all_image_paths = [str(path) for path in all_image_paths]random.shuffle(all_image_paths)image_count = len(all_image_paths)image_countall_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]print("First 10 labels indices: ", all_image_labels[:10])
3670
First 10 labels indices: [0, 2, 3, 4, 2, 1, 4, 1, 4, 0]
下面我們先來(lái)觀察一張圖片
#觀察第一張圖片img_path = all_image_paths[0]img_path
‘C:\Users\Administrator\.keras\datasets\flower_photos\daisy\11124324295_503f3a0804.jpg’
#讀取原圖img_raw = tf.io.read_file(img_path)#轉(zhuǎn)換成TensorFlow可以使用的tensor類(lèi)型img_tensor = tf.image.decode_image(img_raw)print(img_tensor.shape)print(img_tensor.dtype)
(309, 500, 3)
#對(duì)圖片按要求進(jìn)行轉(zhuǎn)換,這里將size規(guī)定到【192,192】;值域映射到【0,1】img_final = tf.image.resize(img_tensor, [192, 192])img_final = img_final/255.0print(img_final.shape)print(img_final.numpy().min())print(img_final.numpy().max())
(192, 192, 3)
0.0
0.99984366
定義預(yù)處理和加載函數(shù)
def preprocess_image(image): ?? ?image = tf.image.decode_jpeg(image, channels=3) ?? ?image = tf.image.resize(image, [192, 192]) ?? ?image /= 255.0 ?? ?return imagedef load_and_preprocess_image(path): ?? ?image = tf.io.read_file(path) ?? ?return preprocess_image(image)
導(dǎo)入matpoltlib進(jìn)行畫(huà)圖(導(dǎo)入失敗的解決方案見(jiàn)我的另一篇博文)
import matplotlib.pyplot as pltimage_path = all_image_paths[0]label = all_image_labels[0]print (load_and_preprocess_image(img_path))plt.imshow(load_and_preprocess_image(img_path))plt.grid(False)
3、構(gòu)建數(shù)據(jù)集

使用tf.data.Dataset來(lái)構(gòu)建規(guī)范的數(shù)據(jù)集
#“from_tensor_slices ”方法使用張量的切片元素構(gòu)建圖片路徑的數(shù)據(jù)集path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)#同理,構(gòu)建標(biāo)簽數(shù)據(jù)集,并用tf.cast轉(zhuǎn)換成int64數(shù)據(jù)類(lèi)型label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))#根據(jù)路徑獲取圖片,并經(jīng)過(guò)加載和預(yù)處理得到圖片數(shù)據(jù)集image_ds = path_ds.map(load_and_preprocess_image )image_ds
將image_ds和label_ds打包成新的數(shù)據(jù)集
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))print(image_label_ds)
畫(huà)圖查看我們構(gòu)造完成的新數(shù)據(jù)
plt.figure(figsize=(8,8))for n,image_label in enumerate(image_label_ds.take(4)):? ?plt.subplot(2,2,n+1)? ?plt.imshow(image_label[0])? ?plt.grid(False)

4、遷移學(xué)習(xí)進(jìn)行分類(lèi)
接下來(lái)用創(chuàng)建的數(shù)據(jù)集訓(xùn)練一個(gè)分類(lèi)模型,簡(jiǎn)單起見(jiàn),直接用tf.keras.applications包中訓(xùn)練好的模型,并將其遷移到我們的圖片分類(lèi)問(wèn)題上來(lái)。這里使用的模型是MobileNetV2模型
#遷移MobileNetV2模型,并且不加載頂層base_model=tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False,weights='imagenet',input_shape=(192,192,3))inputs=tf.keras.layers.Input(shape=(192,192,3))#模型可視化1,使用model.summary()方法base_model.summary()
接下來(lái),我們打亂一下數(shù)據(jù)集,并定義好訓(xùn)練過(guò)程中每個(gè)批次(Batch)數(shù)據(jù)的大小
#使用shuffle方法打亂數(shù)據(jù)集image_count = len(all_image_paths)ds = image_label_ds.shuffle(buffer_size = image_count)#讓數(shù)據(jù)集重復(fù)多次ds = ds.repeat()#設(shè)置每個(gè)批次的大小BATCH_SIZE = 32ds = ds.batch(BATCH_SIZE)#通過(guò)prefetch方法讓模型的訓(xùn)練和每個(gè)批次數(shù)據(jù)的加載并行ds = ds.prefetch(buffer_size = AUTOTUNE)
然后,針對(duì)MobileNetV2改變一樣數(shù)據(jù)集的取值范圍,因?yàn)镸obileNetV2接受輸入的數(shù)據(jù)值域是【-1,1】,而我們之前的預(yù)處理函數(shù)將圖片的像素值映射到【0,1】
def change_range(image,label):? ?return 2*image-1,labelkeras_ds = ds.map(change_range)
接下來(lái)定義模型,由于預(yù)訓(xùn)練好的MobileNetV2返回的數(shù)據(jù)維度是(32,6,6,128),其中32是一個(gè)批次Batch的大小,“6,6”是輸出的特征的大小為6*6,1280代表該層使用的1280個(gè)卷積核。為了使用花朵分類(lèi)問(wèn)題,需要做一下調(diào)整
model = tf.keras.Sequential([? ?base_model,? ?tf.keras.layers.GlobalAveragePooling2D(),? ?tf.keras.layers.Dense(len(label_names),activation="softmax")? ?])
如上代碼,我們用Sequentail建立我們的網(wǎng)絡(luò)結(jié)構(gòu),base_model是遷移過(guò)來(lái)的模型,我們添加了全局評(píng)價(jià)池化層GlobalAveragePooling,經(jīng)過(guò)此操作6*6的特征被降維,變?yōu)椋?2,1280)。
最后,由于該分類(lèi)問(wèn)題有五個(gè)結(jié)果,我們?cè)黾右粋€(gè)全連接層(Dense)將維度變?yōu)椋?2,5)。
最后,編譯一下模型,同時(shí)制定使用的優(yōu)化器,損失函數(shù)和評(píng)價(jià)標(biāo)準(zhǔn)
model.compile(optimizer = tf.keras.optimizers.Adam(),? ? ? ? ? ? loss='sparse_categorical_crossentropy',? ? ? ? ? ? metrics=['accuracy'])model.summary()

使用model.fit訓(xùn)練模型,epochs是訓(xùn)練的回合數(shù),step_per_epoch代表每個(gè)回合要去多少個(gè)批次數(shù)據(jù)。通常等于我們數(shù)據(jù)集大小除以批次大小后取證(3670/32≈10)
model.fit(ds,epochs=10,steps_per_epoch=100)
雖然沒(méi)有跑完整個(gè)代碼,但是已經(jīng)能看出來(lái)準(zhǔn)確度達(dá)到一個(gè)很高的程度,并在逐步上升。
本文僅做學(xué)術(shù)分享,如有侵權(quán),請(qǐng)聯(lián)系刪文。
