Pytorch實(shí)現(xiàn)圖像修復(fù):GAN+上下文自編碼器
點(diǎn)擊上方“程序員大白”,選擇“星標(biāo)”公眾號
重磅干貨,第一時(shí)間送達(dá)
作者:Hmrishav Bandyopadhyay 編譯:公眾號 ronghuaiyang AI公園
一篇比較經(jīng)典的圖像復(fù)原的文章。
你知道在那個(gè)滿是灰塵的相冊里的童年舊照片是可以復(fù)原的嗎?是啊,就是那種每個(gè)人都手牽著手,盡情享受生活的那種!不相信我嗎?看看這個(gè):
圖像修復(fù)是人工智能研究的一個(gè)活躍領(lǐng)域,人工智能已經(jīng)能夠得出比大多數(shù)藝術(shù)家更好的修復(fù)結(jié)果。在本文中,我們將討論使用神經(jīng)網(wǎng)絡(luò),特別是上下文編碼器的圖像修復(fù)。本文解釋并實(shí)現(xiàn)了在CVPR 2016中提出的關(guān)于上下文編碼器的研究工作。
上下文編碼器
為了開始使用上下文編碼器,我們必須了解什么是“自編碼器”。自編碼器在結(jié)構(gòu)上由編碼器、解碼器以及一個(gè)bottleneck組成。一般的自編碼器的目的是通過忽略圖像中的噪聲來減小圖像的尺寸。然而,自編碼器不是特定于圖像,也可以擴(kuò)展到其他數(shù)據(jù)。自編碼器有特定的變體來完成特定的任務(wù)。

既然我們已經(jīng)了解了自編碼器,我們就可以將上下文編碼器比作自編碼器。上下文編碼器是一種卷積神經(jīng)網(wǎng)絡(luò),經(jīng)過訓(xùn)練,根據(jù)周圍環(huán)境生成任意圖像區(qū)域的內(nèi)容:即上下文編碼器接收圖像區(qū)域周圍的數(shù)據(jù),并嘗試生成適合該圖像區(qū)域的東西。就像我們小的時(shí)候拼拼圖一樣 —— 只是我們不需要生成拼圖的碎片。
我們這里的上下文編碼器由一個(gè)編碼器和一個(gè)解碼器組成,前者將圖像的上下文捕獲為一個(gè)緊湊的潛在特征表示,后者使用該表示來生成缺失的圖像內(nèi)容。由于我們需要一個(gè)龐大的數(shù)據(jù)集來訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò),我們不能只處理修復(fù)問題圖像。因此,我們從正常的圖像數(shù)據(jù)集中分割出部分圖像,創(chuàng)建一個(gè)修復(fù)問題,并將圖像提供給神經(jīng)網(wǎng)絡(luò),從而在我們分割的區(qū)域創(chuàng)建缺失的圖像內(nèi)容。
這里需要注意的是,輸入到神經(jīng)網(wǎng)絡(luò)的圖像有太多的缺失部分,經(jīng)典的修復(fù)方法根本無法工作。
GAN的使用
GANs或生成對抗網(wǎng)絡(luò)已被證明對圖像生成極為有用。生成對抗網(wǎng)絡(luò)運(yùn)行的基本原理是:一個(gè)生成器試圖“愚弄”一個(gè)鑒別器,一個(gè)確定的鑒別器試圖區(qū)分出生成器生成的圖像。換句話說,兩個(gè)網(wǎng)絡(luò)試圖分別使損失函數(shù)最小化和最大化。
區(qū)域掩碼
區(qū)域掩模是我們所屏蔽的圖像的一部分,這樣我們就可以將生成的修復(fù)問題提供給模型。通過屏蔽,我們將該圖像區(qū)域的像素值設(shè)置為0。有三種方法:
中心區(qū)域:對圖像數(shù)據(jù)進(jìn)行遮擋,最簡單的方法是將中心的正方形斑塊設(shè)為零。雖然網(wǎng)絡(luò)學(xué)習(xí)修復(fù),但我們面臨著泛化的問題。該網(wǎng)絡(luò)不能很好地泛化,只能學(xué)習(xí)到低層次的特征。 隨機(jī)塊:為了應(yīng)對網(wǎng)絡(luò)“鎖定”到掩碼區(qū)域邊界的問題,如在中央?yún)^(qū)域掩碼中,掩碼過程是隨機(jī)的。不是選擇一個(gè)單一的正方形貼片作為掩碼,而是設(shè)置多個(gè)重疊的正方形掩碼,最多占圖像的1/4。 隨機(jī)區(qū)域:然而,隨機(jī)塊掩蔽仍然有清晰的邊界供網(wǎng)絡(luò)捕捉。為了解決這個(gè)問題,任意的形狀必須從圖像中移除??梢詮腜ASCAL VOC 2012數(shù)據(jù)集中獲得任意形狀,并在任意圖像位置進(jìn)行變形和作為掩模放置。

在這里,我只實(shí)現(xiàn)了中心區(qū)域掩蔽方法,因?yàn)檫@只是一個(gè)指南,讓你開始用AI修復(fù)繪畫。你可以嘗試其他屏蔽方法,并在評論中告訴我結(jié)果!
結(jié)構(gòu)
現(xiàn)在,你應(yīng)該對模型有了一些了解。讓我們看看你是否正確。
該模型由一個(gè)編碼器和一個(gè)解碼器部分組成,構(gòu)建了模型的上下文編碼器部分。這部分還充當(dāng)生成數(shù)據(jù)和試圖愚弄鑒別器的生成器。該鑒別器由卷積網(wǎng)絡(luò)和一個(gè)最終給出一個(gè)標(biāo)量作為輸出的Sigmoid函數(shù)組成。
損失
模型的損失函數(shù)分為2部分:
1、重建損失:重建損失是L2損失函數(shù)。它有助于捕捉缺失區(qū)域的整體結(jié)構(gòu)和與其上下文相關(guān)的連貫性。數(shù)學(xué)上,它被表示為:

這里需要注意的是,僅使用L2損耗會(huì)使圖像變得模糊。因?yàn)槟:膱D像減少了平均像素的誤差,因此L2損失是最小的,但這不是我們想要的。
2、對抗損失:這試圖使預(yù)測“看起來”真實(shí)(記住生成器必須可以欺騙鑒別器!),這幫助我們在克服L2損失會(huì)導(dǎo)致我們得到模糊的圖像。數(shù)學(xué)上,我們可以把它表示為:

這里有一個(gè)有趣的觀察:對抗損失鼓勵(lì)整個(gè)輸出看起來真實(shí),而不僅僅是缺失的部分。換句話說,對抗性網(wǎng)絡(luò)給了整個(gè)圖像一個(gè)真實(shí)的外觀。
總的損失函數(shù):

我們來構(gòu)建這個(gè)模型!
現(xiàn)在,因?yàn)槲覀円呀?jīng)清楚了網(wǎng)絡(luò)的主要的要點(diǎn),讓我們開始構(gòu)建模型。我將首先建立模型結(jié)構(gòu),然后進(jìn)入訓(xùn)練和損失函數(shù)部分。該模型使用PyTorch進(jìn)行構(gòu)建。
讓我們從生成網(wǎng)絡(luò)開始:
import torch
from torch import nn
class generator(nn.Module):
#generator model
def __init__(self):
super(generator,self).__init__()
self.t1=nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
nn.LeakyReLU(0.2,in_place=True)
)
self.t2=nn.Sequential(
nn.Conv2d(in_channels=64,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2,in_place=True)
)
self.t3=nn.Sequential(
nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2,in_place=True)
)
self.t4=nn.Sequential(
nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2,in_place=True)
)
self.t5=nn.Sequential(
nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2,in_place=True)
)
self.t6=nn.Sequential(
nn.Conv2d(512,4000,kernel_size=(4,4))#bottleneck
nn.BatchNorm2d(4000),
nn.ReLU()
)
self.t7=nn.Sequential(
nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(256),
nn.ReLU()
)
self.t8=nn.Sequential(
nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(128),
nn.ReLU()
)
self.t9=nn.Sequential(
nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.t10=nn.Sequential(
nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=(4,4),stride=2,padding=1),
nn.Tanh()
)
def forward(self,x):
x=self.t1(x)
x=self.t2(x)
x=self.t3(x)
x=self.t4(x)
x=self.t5(x)
x=self.t6(x)
x=self.t7(x)
x=self.t8(x)
x=self.t9(x)
x=self.t10(x)
return x #output of generator
現(xiàn)在,是鑒別器網(wǎng)絡(luò):
import torch
from torch import nn
class discriminator(nn.Module):
#discriminator model
def __init__(self):
super(discriminator,self).__init__()
self.t1=nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
nn.LeakyReLU(0.2,in_place=True)
)
self.t2=nn.Sequential(
nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2,in_place=True)
)
self.t3=nn.Sequential(
nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2,in_place=True)
)
self.t4=nn.Sequential(
nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2,in_place=True)
)
self.t5=nn.Sequential(
nn.Conv2d(in_channels=512,out_channels=1,kernel_size=(4,4),stride=1,padding=0),
nn.Sigmoid()
)
def forward(self,x):
x=self.t1(x)
x=self.t2(x)
x=self.t3(x)
x=self.t4(x)
x=self.t5(x)
return x #output of discriminator
現(xiàn)在讓我們開始訓(xùn)練網(wǎng)絡(luò)。我們將batch size設(shè)置為64,epoch的數(shù)量設(shè)置為100。學(xué)習(xí)速率設(shè)置為0.0002。
from model import generator, discriminator
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from model import _netlocalD,_netG
import utils
epochs=100
Batch_Size=64
lr=0.0002
beta1=0.5
over=4
parser = argparse.ArgumentParser()
parser.add_argument('--dataroot', default='dataset/train', help='path to dataset')
opt = parser.parse_args()
try:
os.makedirs("result/train/cropped")
os.makedirs("result/train/real")
os.makedirs("result/train/recon")
os.makedirs("model")
except OSError:
pass
transform = transforms.Compose([transforms.Scale(128),
transforms.CenterCrop(128),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = dset.ImageFolder(root=opt.dataroot, transform=transform )
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=Batch_Size,
shuffle=True, num_workers=2)
ngpu = int(opt.ngpu)
wtl2 = 0.999
# custom weights initialization called on netG and netD
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
resume_epoch=0
netG = generator()
netG.apply(weights_init)
netD = discriminator()
netD.apply(weights_init)
criterion = nn.BCELoss()
criterionMSE = nn.MSELoss()
input_real = torch.FloatTensor(Batch_Size, 3, 128, 128)
input_cropped = torch.FloatTensor(Batch_Size, 3, 128, 128)
label = torch.FloatTensor(Batch_Size)
real_label = 1
fake_label = 0
real_center = torch.FloatTensor(Batch_Size, 3, 64,64)
netD.cuda()
netG.cuda()
criterion.cuda()
criterionMSE.cuda()
input_real, input_cropped,label = input_real.cuda(),input_cropped.cuda(), label.cuda()
real_center = real_center.cuda()
input_real = Variable(input_real)
input_cropped = Variable(input_cropped)
label = Variable(label)
real_center = Variable(real_center)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
for epoch in range(resume_epoch,epochs):
for i, data in enumerate(dataloader, 0):
real_cpu, _ = data
real_center_cpu = real_cpu[:,:,int(128/4):int(128/4)+int(128/2),int(128/4):int(128/4)+int(128/2)]
batch_size = real_cpu.size(0)
with torch.no_grad():
input_real.resize_(real_cpu.size()).copy_(real_cpu)
input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
input_cropped[:,0,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*117.0/255.0 - 1.0
input_cropped[:,1,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*104.0/255.0 - 1.0
input_cropped[:,2,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*123.0/255.0 - 1.0
#start the discriminator by training with real data---
netD.zero_grad()
with torch.no_grad():
label.resize_(batch_size).fill_(real_label)
output = netD(real_center)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.data.mean()
# train the discriminator with fake data---
fake = netG(input_cropped)
label.data.fill_(fake_label)
output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.data.mean()
errD = errD_real + errD_fake
optimizerD.step()
#train the generator now---
netG.zero_grad()
label.data.fill_(real_label) # fake labels are real for generator cost
output = netD(fake)
errG_D = criterion(output, label)
wtl2Matrix = real_center.clone()
wtl2Matrix.data.fill_(wtl2*10)
wtl2Matrix.data[:,:,int(over):int(128/2 - over),int(over):int(128/2 - over)] = wtl2
errG_l2 = (fake-real_center).pow(2)
errG_l2 = errG_l2 * wtl2Matrix
errG_l2 = errG_l2.mean()
errG = (1-wtl2) * errG_D + wtl2 * errG_l2
errG.backward()
D_G_z2 = output.data.mean()
optimizerG.step()
print('[%d / %d][%d / %d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
% (epoch, epochs, i, len(dataloader),
errD.data, errG_D.data,errG_l2.data, D_x,D_G_z1, ))
if i % 100 == 0:
vutils.save_image(real_cpu,
'result/train/real/real_samples_epoch_%03d.png' % (epoch))
vutils.save_image(input_cropped.data,
'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
recon_image = input_cropped.clone()
recon_image.data[:,:,int(128/4):int(128/4+128/2),int(128/4):int(128/4+128/2)] = fake.data
vutils.save_image(recon_image.data,
'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
結(jié)果
讓我們看一下我們的模型能夠構(gòu)建出什么來?第0個(gè)epoch時(shí)候的圖像(噪聲):

第100個(gè)epoch時(shí)候:

我們看下輸入模型的是什么:

英文原文:https://towardsdatascience.com/inpainting-with-ai-get-back-your-images-pytorch-a68f689128e5
推薦閱讀
國產(chǎn)小眾瀏覽器因屏蔽視頻廣告,被索賠100萬(后續(xù))
年輕人“不講武德”:因看黃片上癮,把網(wǎng)站和786名女主播起訴了
關(guān)于程序員大白
程序員大白是一群哈工大,東北大學(xué),西湖大學(xué)和上海交通大學(xué)的碩士博士運(yùn)營維護(hù)的號,大家樂于分享高質(zhì)量文章,喜歡總結(jié)知識,歡迎關(guān)注[程序員大白],大家一起學(xué)習(xí)進(jìn)步!


