圖像去模糊算法代碼實踐!

來源 | Datawhale
1.起源:GAN
結(jié)構(gòu)與原理
在介紹DeblurGANv2之前,我們需要大概了解一下GAN,GAN最初的應(yīng)用是圖片生成,即根據(jù)訓(xùn)練集生成圖片,如生成手寫數(shù)字圖像、人臉圖像、動物圖像等等,其主要結(jié)構(gòu)如下:

我們先由上圖的左下方開始,假設(shè)現(xiàn)在只有一個樣本,即batch size為1,則Random noise是一個由服從標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)數(shù)組成的向量。首先,我們將Random noise輸入Generator,最原始GAN的Generator是一個多層感知機(jī),其輸入是一個向量,輸出也是一個向量,然后我們將輸出的向量reshape成一個矩陣,這個矩陣就是一張圖片(一個矩陣是因為MNIST手寫數(shù)據(jù)集中的圖片是單通道的灰度圖,如果想生成彩色圖像就reshape成三個矩陣),即與上圖的“8”對應(yīng)。我們稱Generator生成的圖像為fake image,訓(xùn)練集中的圖片為real image。
上圖中的Distriminator為判別器,它是一個二分類的多層感知機(jī),輸出只有一個數(shù),由于多層感知機(jī)只接受向量為其輸入,我們將一張圖片由矩陣展開為向量后再輸入Discriminator,經(jīng)過一系列運算后輸出一個0~1之間的數(shù),這個數(shù)越接近于0,代表著判別器認(rèn)為這張圖片是fake image;反之,假如輸出的數(shù)越接近于1,則判別器認(rèn)為這張圖片是real image。為了方便,我們將Generator簡稱為G,Distriminator簡稱為D。
總而言之,G的目的是讓自己生成的fake image盡可能欺騙D,而D的任務(wù)是盡可能辨別出fake image和real image,二者不停博弈。最終理想情況下,G生成的數(shù)據(jù)與真實數(shù)據(jù)非常接近,而D無論輸入fake image還是real image都輸出0.5。
損失函數(shù)
GAN的損失函數(shù)是Binary cross entropy loss,簡稱為BCELoss,其主要利用了極大似然的思想,實際上就是二分類對應(yīng)的交叉熵?fù)p失函數(shù)。公式如下:
其中是樣本數(shù),是第個樣本的真實值,是第個樣本的預(yù)測值。對于第個樣本來說,由于取值只能是0或1,此時只看第個樣本,所以。當(dāng)時,,而的取值范圍為0~1,故當(dāng)時,=0,當(dāng)時,,我們的目標(biāo)是使的值越小越好,即當(dāng)越接近0時,的值越小。反之,當(dāng)時,,越接近1時,的值越小??傊?dāng)越接近于時,的值越小。
那么BCELoss和GAN有什么關(guān)系呢?
我們將GAN的Loss分為和,即生成器的損失和判別器的損失。
對于生成器來說,它希望自己生成的圖片能騙過判別器,即希望D(fake)越接近1越好,D(fake)就是G生成的圖片輸入D后的輸出值,D(fake)接近于1意味著G生成的圖片可以以假亂真來欺騙判別器,所以GLoss的公式如下所示:
當(dāng)越接近1,越小,意味著生成器騙過了判別器;
對于判別器來說,它的損失分為兩部分,首先,它不希望自己被fake image欺騙,即與相反,這里用表示:
當(dāng)越接近0,越小,意味著判別器分辨出了fake image;
其次,判別器做出判斷必須有依據(jù),所以它需要知道真實圖片是什么樣的才能正確地辨別假圖片,這里用表示:
當(dāng)越接近1,越小,意味著判別器辨別出了real image。
其實就是這兩個損失值的平均值:
優(yōu)化器
介紹完GAN的損失函數(shù)后,我們還剩下最后一個問題:怎么使損失函數(shù)的值越來越???
這里就需要說一下優(yōu)化器(Optimizer),優(yōu)化器就是使損失函數(shù)值越來越小的工具,常用的優(yōu)化器有SGD、NAG、RMSProp、Adagrad、Adam和Adam的一些變種,其中最常用的是Adam。
最終結(jié)果

由上圖我們可以清楚地看出來,隨著訓(xùn)練輪數(shù)增加,G生成的fake image越來越接近手寫數(shù)字。
目前GAN有很多應(yīng)用,每個應(yīng)用對應(yīng)的論文和Pytorch代碼可以參考下面的鏈接,其中也有GAN的代碼,大家可以根據(jù)代碼進(jìn)一步理解GAN:https://github.com/eriklindernoren/PyTorch-GAN
2.圖像去模糊算法:DeblurGANv2
數(shù)據(jù)集
圖像去模糊的數(shù)據(jù)集通常由許多組圖像組成,每組圖像就是一張清晰圖像和與之對應(yīng)的模糊圖像。然而,其數(shù)據(jù)集的制作并不容易,目前常用的方法有兩種,第一種是用高幀數(shù)的攝像機(jī)拍攝視頻,從視頻中找到連續(xù)幀中的模糊圖片和清晰圖片作為一組數(shù)據(jù);第二種方法是用已知或隨機(jī)生成的運動模糊核對清晰圖片進(jìn)行模糊操作,生成對應(yīng)的一組數(shù)據(jù)。albumentations是Python中常用的數(shù)據(jù)擴(kuò)增庫,可以對圖片進(jìn)行旋轉(zhuǎn)、縮放、裁剪等操作,我們也可以使用albumentations給圖像增加運動模糊,具體操作如下:
首先安裝albumentations庫,在cmd或虛擬環(huán)境中輸入:
python -m pip install albumentations
為了給圖像添加運動模糊,我們需要用matplotlib庫來讀取、顯示和保存圖片。
import albumentations as A
from matplotlib import pyplot as plt
# 讀取和顯示原圖
img = plt.imread('./images/ywxd.jpg')
plt.imshow(img)
plt.axis('off')
plt.show()

albumentations添加運動模糊操作如下,其中blur_limit是卷積核大小的范圍,這里卷積核大小在150到180之間,卷積核越大,模糊效果越明顯;p是進(jìn)行運動模糊操作概率。
aug = A.MotionBlur(blur_limit=(50, 80), p=1.0)
aug_img = aug(image=img)['image']
plt.imshow(aug_img)
plt.axis('off')
plt.show()

如果想查看對應(yīng)的模糊核,我們可以對aug這個實例調(diào)用get_params方法,這里為了大家觀看方便,我使用的是3*3的卷積核。
aug = A.MotionBlur(blur_limit=(3, 3), p=1.0)
aug.get_params(){'kernel': array([[0. , 0. , 0.33333334],
[0.33333334, 0.33333334, 0. ],
[0. , 0. , 0. ]], dtype=float32)}
我使用的數(shù)據(jù)集是DeblurGANv1的數(shù)據(jù)集,鏈接:https://gas.graviti.cn/dataset/datawhale/BlurredSharp
模糊圖片:

清晰圖片:

網(wǎng)絡(luò)結(jié)構(gòu)
DeblurGANv2的思路與GAN大致相同,區(qū)別之處在于其對GAN做了大量優(yōu)化,我們先來看Generator的結(jié)構(gòu):

觀察上圖可以發(fā)現(xiàn),G主要有兩個改變:
輸入用模糊的圖片替代了GAN中的隨機(jī)向量
網(wǎng)絡(luò)結(jié)構(gòu)引入了目標(biāo)檢測中的FPN結(jié)構(gòu),融合了多尺度的特征
另外,在特征提取部分作者提供了三種網(wǎng)絡(luò)主干:MobileNetv2、inceptionresnetv2和densenet121,經(jīng)過作者實驗得出,inceptionresnetv2的效果最好,但模型較大,而MobilNetv2在不降低太大效果的基礎(chǔ)上大大減少了網(wǎng)絡(luò)參數(shù),網(wǎng)絡(luò)主干在上圖中對應(yīng)部分如下所示:

最后,將fpn的輸出與原圖進(jìn)行按元素相加操作得到最終輸出。
DeblurGANv2的判別器由全局和局部兩部分組成,全局判別器輸入的是整張圖片,局部判別器輸入的是隨機(jī)裁剪后的圖片,將輸入圖片經(jīng)過一系列卷積操作后輸出一個數(shù),這個數(shù)代表判別器認(rèn)為其為real image的概率,判別器的結(jié)構(gòu)如下所示:

損失函數(shù)
DeblurGANv2與GAN差別最大的部分就是它的損失函數(shù),我們首先看看D的loss:
D的目的是為了辨別圖片的真假,所以D(fake)越小,D(real)越大時,代表D能很好地判斷圖片的真假,故對于D來說,越小越好
為了防止過擬合,后面還會加上一個L2懲罰項:
G的loss較D復(fù)雜很多,它由和組成,其實就是一個perceptual loss,它其實就是將real image和fake image分別輸入vgg19,將輸出的特征圖做MSELoss(均方誤差),而作者在perceptual loss的基礎(chǔ)上又做了一些改變,公式可以總結(jié)為下式:
由公式可以很容易推斷,的作用就是讓G生成的圖片和原圖盡可能相似來達(dá)到去模糊的目的。
對于來說,其可以總結(jié)為下面公式:
由于G的目的是盡可能以假亂真騙過D,所以和越接近于1越好,即越小越好。
最后,G的loss如下所示:
作者給出的lambda為0.001,可以看出作者更注重生成圖像與原圖的相似性。
3.代碼實踐
訓(xùn)練自己的數(shù)據(jù)集
(目前僅支持gpu訓(xùn)練!)
github項目地址:https://github.com/VITA-Group/DeblurGANv2
數(shù)據(jù)地址:https://gas.graviti.cn/dataset/datawhale/BlurredSharp
首先將數(shù)據(jù)文件夾和項目文件夾按照下面結(jié)構(gòu)放置:

安裝python環(huán)境,在cmd中輸入:
conda create -n deblur python=3.9
conda activate deblur
python -m pip install -r requirements.txt
修改config文件夾中的配置文件config.yaml:
project: deblur_gan
experiment_desc: fpn
train:
files_a: &FILES_A ./dataset/train/blurred/*.png
files_b: &FILES_B ./dataset/train/sharp/*.png
size: &SIZE 256
crop: random
preload: &PRELOAD false
preload_size: &PRELOAD_SIZE 0
bounds: [0, .9]
scope: geometric
corrupt: &CORRUPT
- name: cutout
prob: 0.5
num_holes: 3
max_h_size: 25
max_w_size: 25
- name: jpeg
quality_lower: 70
quality_upper: 90
- name: motion_blur
- name: median_blur
- name: gamma
- name: rgb_shift
- name: hsv_shift
- name: sharpen
val:
files_a: &FILE_A ./dataset/val/blurred/*.png
files_b: &FILE_B ./dataset/val/sharp/*.png
size: *SIZE
scope: geometric
crop: center
preload: *PRELOAD
preload_size: *PRELOAD_SIZE
bounds: [.9, 1]
corrupt: *CORRUPT
phase: train
warmup_num: 3
model:
g_name: resnet
blocks: 9
d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
d_layers: 3
content_loss: perceptual
adv_lambda: 0.001
disc_loss: wgan-gp
learn_residual: True
norm_layer: instance
dropout: True
num_epochs: 200
train_batches_per_epoch: 1000
val_batches_per_epoch: 100
batch_size: 1
image_size: [256, 256]
optimizer:
name: adam
lr: 0.0001
scheduler:
name: linear
start_epoch: 50
min_lr: 0.0000001
如果是windows系統(tǒng)需要刪除train.py第180行
然后在cmd中cd到項目路徑并輸入:
python train.py
訓(xùn)練結(jié)果可以在tensorboard中可視化出來:
驗證集ssim(結(jié)構(gòu)相似性):

驗證集GLoss:

驗證集PSNR(峰值信噪比):

測試(CPU、GPU均可)
GPU
將測試圖片以test.png保存到DeblurGANv2-master文件夾下,在CMD中輸入:
python predict.py test.png
運行成功后結(jié)果submit文件夾中,predict.py中的模型文件默認(rèn)為best_fpn.h5,大家也可以在DeblurGANv2的github中下載作者訓(xùn)練好的模型文件,保存在項目文件夾后將predict.py文件中的第93行改為想要用的模型文件即可,如將'best_fpn.h5'改為'fpn_inception.h5',但是需要將config.yaml中model對應(yīng)的g_name改為相應(yīng)模型,如想使用'fpn_mobilenet.h5',就將'fpn_inception'改為'fpn_mobilenet'
CPU
將predict.py文件中第21行、22和65行改為下面代碼即可
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))['model'])
self.model = model
inputs = [img]
運行后就可以得到下面效果:


DeblurGAN的應(yīng)用:優(yōu)化YOLOv5性能


由上圖可以看出,圖片去模糊不僅可以提高YOLOv5的檢測置信度,還可以使檢測更準(zhǔn)確。以Mobilenetv2為backbone的DeblurGANv2能達(dá)到圖片實時去模糊的要求,進(jìn)而可以使用到視頻質(zhì)量增強(qiáng)等方向。
線上訓(xùn)練
如果我們不想把數(shù)據(jù)集下載到本地的話可以考慮格物鈦(Graviti)的線上訓(xùn)練功能,在原項目的基礎(chǔ)上改幾行代碼即可。
首先我們打開項目文件夾中的dataset.py文件,在第一行導(dǎo)入tensorbay和PIL(如果沒有安裝tensorbay需要先pip install):
from tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDataset
from PIL import Image
我們主要修改的是PairedDatasetOnline類還有_read_img函數(shù),為了保留原來的類,我們新建一個類,將下面代碼復(fù)制粘貼到dataset.py文件中即可(記得將ACCESS_KEY改為自己空間的 Graviti AccessKey):
class PairedDatasetOnline(Dataset):
def __init__(self,
files_a: Tuple[str],
files_b: Tuple[str],
transform_fn: Callable,
normalize_fn: Callable,
corrupt_fn: Optional[Callable] = None,
preload: bool = True,
preload_size: Optional[int] = 0,
verbose=True):
assert len(files_a) == len(files_b)
self.preload = preload
self.data_a = files_a
self.data_b = files_b
self.verbose = verbose
self.corrupt_fn = corrupt_fn
self.transform_fn = transform_fn
self.normalize_fn = normalize_fn
logger.info(f'Dataset has been created with {len(self.data_a)} samples')
if preload:
preload_fn = partial(self._bulk_preload, preload_size=preload_size)
if files_a == files_b:
self.data_a = self.data_b = preload_fn(self.data_a)
else:
self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
self.preload = True
def _bulk_preload(self, data: Iterable[str], preload_size: int):
jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
@staticmethod
def _preload(x: str, preload_size: int):
img = _read_img(x)
if preload_size:
h, w, *_ = img.shape
h_scale = preload_size / h
w_scale = preload_size / w
scale = max(h_scale, w_scale)
img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
return img
def _preprocess(self, img, res):
def transpose(x):
return np.transpose(x, (2, 0, 1))
return map(transpose, self.normalize_fn(img, res))
def __len__(self):
return len(self.data_a)
def __getitem__(self, idx):
a, b = self.data_a[idx], self.data_b[idx]
if not self.preload:
a, b = map(_read_img, (a, b))
a, b = self.transform_fn(a, b)
if self.corrupt_fn is not None:
a = self.corrupt_fn(a)
a, b = self._preprocess(a, b)
return {'a': a, 'b': b}
@staticmethod
def from_config(config):
config = deepcopy(config)
# files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
segment_name = 'train' if 'train' in config['files_a'] else 'val'
ACCESS_KEY = "yours"
gas = GAS(ACCESS_KEY)
dataset = TensorBayDataset("BlurredSharp", gas)
segment = dataset[segment_name]
files_a = [i for i in segment if 'blurred' == i.path.split('/')[2]]
files_b = [i for i in segment if 'sharp' == i.path.split('/')[2]]
transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
normalize_fn = aug.get_normalize()
corrupt_fn = aug.get_corrupt_function(config['corrupt'])
# ToDo: add more hash functions
verbose = config.get('verbose', True)
return PairedDatasetOnline(files_a=files_a,
files_b=files_b,
preload=config['preload'],
preload_size=config['preload_size'],
corrupt_fn=corrupt_fn,
normalize_fn=normalize_fn,
transform_fn=transform_fn,
verbose=verbose)
再將_read_img改為:
def _read_img(x):
with x.open() as fp:
img = cv2.cvtColor(np.asarray(Image.open(fp)), cv2.COLOR_RGB2BGR)
if img is None:
logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
img = imread(x)[:, :, ::-1]
return img
最后一步將train.py第184行的datasets = map(PairedDataset.from_config, datasets)改為datasets = map(PairedDatasetOnline.from_config, datasets)即可。
