視覺進階 | 用于圖像降噪的卷積自編碼器
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
本文轉(zhuǎn)自:磐創(chuàng)AI

作者|Dataman
編譯|Arno
來源|Analytics Vidhya
這篇文章的目的是介紹關(guān)于利用自動編碼器實現(xiàn)圖像降噪的內(nèi)容。
在神經(jīng)網(wǎng)絡(luò)世界中,對圖像數(shù)據(jù)進行建模需要特殊的方法。其中最著名的是卷積神經(jīng)網(wǎng)絡(luò)(CNN或ConvNet)或稱為卷積自編碼器。并非所有的讀者都了解圖像數(shù)據(jù),那么我先簡要介紹圖像數(shù)據(jù)(如果你對這方面已經(jīng)很清楚了,可以跳過)。然后,我會介紹標準神經(jīng)網(wǎng)絡(luò)。這個標準神經(jīng)網(wǎng)絡(luò)用于圖像數(shù)據(jù),比較簡單。這解釋了處理圖像數(shù)據(jù)時為什么首選的是卷積自編碼器。最重要的是,我將演示卷積自編碼器如何減少圖像噪聲。這篇文章將用上Keras模塊和MNIST數(shù)據(jù)。Keras用Python編寫,并且能夠在TensorFlow上運行,是高級的神經(jīng)網(wǎng)絡(luò)API。

from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()
它們看起來怎么樣?我們用繪圖庫及其圖像功能imshow()展示前十條記錄。
import matplotlib.pyplot as plt
n = 10 # 顯示的記錄數(shù)
plt.figure(figsize=(20, 4))
for i in range(n):
# 顯示原始圖片
ax = plt.subplot(2, n, i + 1)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

圖像數(shù)據(jù)的堆疊,用于訓練




1. 卷積層


1.1填充
1.2步長
2.線性整流步驟
3.最大池化層


from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
# 編碼過程
input_img = Input(shape=(28, 28, 1))
############
# 編碼 #
############
# Conv1 #
x = Conv2D(filters = 16, kernel_size = (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D(pool_size = (2, 2), padding='same')(x)
# Conv2 #
x = Conv2D(filters = 8, kernel_size = (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D(pool_size = (2, 2), padding='same')(x)
# Conv 3 #
x = Conv2D(filters = 8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D(pool_size = (2, 2), padding='same')(x)
# 注意:
# padding 是一個超參數(shù),值'valid' or 'same'.
# "valid" 意味不需要填充
# "same" 填充輸入,使輸出具有與原始輸入相同的長度。
然后,解碼過程繼續(xù)。因此,下面解碼部分已全部完成編碼和解碼過程。
############
# 解碼 #
############
# DeConv1
x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
# DeConv2
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
# Deconv3
x = Conv2D(16, (3, 3), activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
該Keras API需要模型和優(yōu)化方法的聲明:
# 聲明模型
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')
# 訓練模型
autoencoder.fit(x_train, x_train,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(x_test, x_test)
)
decoded_imgs = autoencoder.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
# 顯示原始圖像
ax = plt.subplot(2, n, i + 1)
plt.imshow(x_test[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# 顯示重構(gòu)后的圖像
ax = plt.subplot(2, n, i+1+n)
plt.imshow(decoded_imgs[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

noise_factor = 0.4
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)
x_train_noisy = np.clip(x_train_noisy, 0., 1.)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
前十張噪聲圖像如下所示:
n = 10
plt.figure(figsize=(20, 2))
for i in range(n):
ax = plt.subplot(1, n, i+1)
plt.imshow(x_test_noisy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

然后,我們訓練模型時將輸入噪聲數(shù)據(jù),輸出干凈的數(shù)據(jù)。
autoencoder.fit(x_train_noisy, x_train,
epochs=100,
batch_size=128,
shuffle=True,
validation_data=(x_test_noisy, x_test)
)
最后,我們打印出前十個噪點圖像以及相應(yīng)的降噪圖像。
decoded_imgs = autoencoder.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
# 顯示原始圖像
ax = plt.subplot(2, n, i + 1)
plt.imshow(x_test_noisy[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# 顯示重構(gòu)后的圖像
ax = plt.subplot(2, n, i+1+n)
plt.imshow(decoded_imgs[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()

交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

