<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          Pytorch實(shí)現(xiàn)圖像修復(fù):GAN+上下文自編碼器

          共 21216字,需瀏覽 43分鐘

           ·

          2021-03-11 00:31

          點(diǎn)擊上方“程序員大白”,選擇“星標(biāo)”公眾號

          重磅干貨,第一時(shí)間送達(dá)

          作者:Hmrishav Bandyopadhyay
          編譯:公眾號 ronghuaiyang  AI公園
          導(dǎo)讀

          一篇比較經(jīng)典的圖像復(fù)原的文章。


          你知道在那個(gè)滿是灰塵的相冊里的童年舊照片是可以復(fù)原的嗎?是啊,就是那種每個(gè)人都手牽著手,盡情享受生活的那種!不相信我嗎?看看這個(gè):

          圖像修復(fù)是人工智能研究的一個(gè)活躍領(lǐng)域,人工智能已經(jīng)能夠得出比大多數(shù)藝術(shù)家更好的修復(fù)結(jié)果。在本文中,我們將討論使用神經(jīng)網(wǎng)絡(luò),特別是上下文編碼器的圖像修復(fù)。本文解釋并實(shí)現(xiàn)了在CVPR 2016中提出的關(guān)于上下文編碼器的研究工作。

          上下文編碼器

          為了開始使用上下文編碼器,我們必須了解什么是“自編碼器”。自編碼器在結(jié)構(gòu)上由編碼器、解碼器以及一個(gè)bottleneck組成。一般的自編碼器的目的是通過忽略圖像中的噪聲來減小圖像的尺寸。然而,自編碼器不是特定于圖像,也可以擴(kuò)展到其他數(shù)據(jù)。自編碼器有特定的變體來完成特定的任務(wù)。

          自編碼器結(jié)構(gòu)

          既然我們已經(jīng)了解了自編碼器,我們就可以將上下文編碼器比作自編碼器。上下文編碼器是一種卷積神經(jīng)網(wǎng)絡(luò),經(jīng)過訓(xùn)練,根據(jù)周圍環(huán)境生成任意圖像區(qū)域的內(nèi)容:即上下文編碼器接收圖像區(qū)域周圍的數(shù)據(jù),并嘗試生成適合該圖像區(qū)域的東西。就像我們小的時(shí)候拼拼圖一樣 —— 只是我們不需要生成拼圖的碎片。

          我們這里的上下文編碼器由一個(gè)編碼器和一個(gè)解碼器組成,前者將圖像的上下文捕獲為一個(gè)緊湊的潛在特征表示,后者使用該表示來生成缺失的圖像內(nèi)容。由于我們需要一個(gè)龐大的數(shù)據(jù)集來訓(xùn)練一個(gè)神經(jīng)網(wǎng)絡(luò),我們不能只處理修復(fù)問題圖像。因此,我們從正常的圖像數(shù)據(jù)集中分割出部分圖像,創(chuàng)建一個(gè)修復(fù)問題,并將圖像提供給神經(jīng)網(wǎng)絡(luò),從而在我們分割的區(qū)域創(chuàng)建缺失的圖像內(nèi)容。

          這里需要注意的是,輸入到神經(jīng)網(wǎng)絡(luò)的圖像有太多的缺失部分,經(jīng)典的修復(fù)方法根本無法工作。

          GAN的使用

          GANs或生成對抗網(wǎng)絡(luò)已被證明對圖像生成極為有用。生成對抗網(wǎng)絡(luò)運(yùn)行的基本原理是:一個(gè)生成器試圖“愚弄”一個(gè)鑒別器,一個(gè)確定的鑒別器試圖區(qū)分出生成器生成的圖像。換句話說,兩個(gè)網(wǎng)絡(luò)試圖分別使損失函數(shù)最小化和最大化。

          區(qū)域掩碼

          區(qū)域掩模是我們所屏蔽的圖像的一部分,這樣我們就可以將生成的修復(fù)問題提供給模型。通過屏蔽,我們將該圖像區(qū)域的像素值設(shè)置為0。有三種方法:

          1. 中心區(qū)域:對圖像數(shù)據(jù)進(jìn)行遮擋,最簡單的方法是將中心的正方形斑塊設(shè)為零。雖然網(wǎng)絡(luò)學(xué)習(xí)修復(fù),但我們面臨著泛化的問題。該網(wǎng)絡(luò)不能很好地泛化,只能學(xué)習(xí)到低層次的特征。
          2. 隨機(jī)塊:為了應(yīng)對網(wǎng)絡(luò)“鎖定”到掩碼區(qū)域邊界的問題,如在中央?yún)^(qū)域掩碼中,掩碼過程是隨機(jī)的。不是選擇一個(gè)單一的正方形貼片作為掩碼,而是設(shè)置多個(gè)重疊的正方形掩碼,最多占圖像的1/4。
          3. 隨機(jī)區(qū)域:然而,隨機(jī)塊掩蔽仍然有清晰的邊界供網(wǎng)絡(luò)捕捉。為了解決這個(gè)問題,任意的形狀必須從圖像中移除??梢詮腜ASCAL VOC 2012數(shù)據(jù)集中獲得任意形狀,并在任意圖像位置進(jìn)行變形和作為掩模放置。

          從左到右,a)中心掩碼,b)隨機(jī)塊掩碼,c)隨機(jī)區(qū)域掩碼

          在這里,我只實(shí)現(xiàn)了中心區(qū)域掩蔽方法,因?yàn)檫@只是一個(gè)指南,讓你開始用AI修復(fù)繪畫。你可以嘗試其他屏蔽方法,并在評論中告訴我結(jié)果!

          結(jié)構(gòu)

          現(xiàn)在,你應(yīng)該對模型有了一些了解。讓我們看看你是否正確。

          該模型由一個(gè)編碼器和一個(gè)解碼器部分組成,構(gòu)建了模型的上下文編碼器部分。這部分還充當(dāng)生成數(shù)據(jù)和試圖愚弄鑒別器的生成器。該鑒別器由卷積網(wǎng)絡(luò)和一個(gè)最終給出一個(gè)標(biāo)量作為輸出的Sigmoid函數(shù)組成。

          損失

          模型的損失函數(shù)分為2部分:

          1、重建損失:重建損失是L2損失函數(shù)。它有助于捕捉缺失區(qū)域的整體結(jié)構(gòu)和與其上下文相關(guān)的連貫性。數(shù)學(xué)上,它被表示為:

          這里需要注意的是,僅使用L2損耗會(huì)使圖像變得模糊。因?yàn)槟:膱D像減少了平均像素的誤差,因此L2損失是最小的,但這不是我們想要的。

          2、對抗損失:這試圖使預(yù)測“看起來”真實(shí)(記住生成器必須可以欺騙鑒別器!),這幫助我們在克服L2損失會(huì)導(dǎo)致我們得到模糊的圖像。數(shù)學(xué)上,我們可以把它表示為:

          這里有一個(gè)有趣的觀察:對抗損失鼓勵(lì)整個(gè)輸出看起來真實(shí),而不僅僅是缺失的部分。換句話說,對抗性網(wǎng)絡(luò)給了整個(gè)圖像一個(gè)真實(shí)的外觀。

          總的損失函數(shù):

          我們來構(gòu)建這個(gè)模型!

          現(xiàn)在,因?yàn)槲覀円呀?jīng)清楚了網(wǎng)絡(luò)的主要的要點(diǎn),讓我們開始構(gòu)建模型。我將首先建立模型結(jié)構(gòu),然后進(jìn)入訓(xùn)練和損失函數(shù)部分。該模型使用PyTorch進(jìn)行構(gòu)建。

          讓我們從生成網(wǎng)絡(luò)開始:

          import torch
          from torch import nn
          class generator(nn.Module):

              #generator model
              def __init__(self):
                  super(generator,self).__init__()
                  

                  self.t1=nn.Sequential(
                      nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  
                  self.t2=nn.Sequential(
                      nn.Conv2d(in_channels=64,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(64),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  self.t3=nn.Sequential(
                      nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(128),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  self.t4=nn.Sequential(
                      nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(256),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  self.t5=nn.Sequential(
                      nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(512),
                      nn.LeakyReLU(0.2,in_place=True)
                      
                  )
                  self.t6=nn.Sequential(
                      nn.Conv2d(512,4000,kernel_size=(4,4))#bottleneck
                      nn.BatchNorm2d(4000),
                      nn.ReLU()
                  )
                  self.t7=nn.Sequential(
                      nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(256),
                      nn.ReLU()
                      )
                  self.t8=nn.Sequential(
                      nn.ConvTranspose2d(in_channels=256,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(128),
                      nn.ReLU()
                      )
                  self.t9=nn.Sequential(
                      nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(64),
                      nn.ReLU()
                      )
                  self.t10=nn.Sequential(
                      nn.ConvTranspose2d(in_channels=64,out_channels=3,kernel_size=(4,4),stride=2,padding=1),
                      nn.Tanh()
                      )
                          
              def forward(self,x):
               x=self.t1(x)
               x=self.t2(x)
               x=self.t3(x)
               x=self.t4(x)
               x=self.t5(x)
               x=self.t6(x)
               x=self.t7(x)
               x=self.t8(x)
               x=self.t9(x)
               x=self.t10(x)
               return x #output of generator
          網(wǎng)絡(luò)的生成器模型

          現(xiàn)在,是鑒別器網(wǎng)絡(luò):

          import torch
          from torch import nn
          class discriminator(nn.Module):

              #discriminator model
              def __init__(self):
                  super(discriminator,self).__init__()
                  
                  self.t1=nn.Sequential(
                      nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(4,4),stride=2,padding=1),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  
                  self.t2=nn.Sequential(
                      nn.Conv2d(in_channels=64,out_channels=128,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(128),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  
                  self.t3=nn.Sequential(
                      nn.Conv2d(in_channels=128,out_channels=256,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(256),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  self.t4=nn.Sequential(
                      nn.Conv2d(in_channels=256,out_channels=512,kernel_size=(4,4),stride=2,padding=1),
                      nn.BatchNorm2d(512),
                      nn.LeakyReLU(0.2,in_place=True)
                  )
                  self.t5=nn.Sequential(
                      nn.Conv2d(in_channels=512,out_channels=1,kernel_size=(4,4),stride=1,padding=0),
                      nn.Sigmoid()
                  )        
              
              def forward(self,x):
               x=self.t1(x)
               x=self.t2(x)
               x=self.t3(x)
               x=self.t4(x)
               x=self.t5(x)
               return x #output of discriminator
          鑒別器網(wǎng)絡(luò)

          現(xiàn)在讓我們開始訓(xùn)練網(wǎng)絡(luò)。我們將batch size設(shè)置為64,epoch的數(shù)量設(shè)置為100。學(xué)習(xí)速率設(shè)置為0.0002。

          from model import generator, discriminator
          import argparse
          import os
          import random
          import torch
          import torch.nn as nn
          import torch.nn.parallel
          import torch.backends.cudnn as cudnn
          import torch.optim as optim
          import torch.utils.data
          import torchvision.datasets as dset
          import torchvision.transforms as transforms
          import torchvision.utils as vutils
          from torch.autograd import Variable

          from model import _netlocalD,_netG
          import utils
          epochs=100
          Batch_Size=64
          lr=0.0002
          beta1=0.5
          over=4
          parser = argparse.ArgumentParser()
          parser.add_argument('--dataroot',  default='dataset/train', help='path to dataset')
          opt = parser.parse_args()
          try:
              os.makedirs("result/train/cropped")
              os.makedirs("result/train/real")
              os.makedirs("result/train/recon")
              os.makedirs("model")
          except OSError:
              pass

          transform = transforms.Compose([transforms.Scale(128),
                                          transforms.CenterCrop(128),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.50.50.5), (0.50.50.5))])
          dataset = dset.ImageFolder(root=opt.dataroot, transform=transform )
          assert dataset
          dataloader = torch.utils.data.DataLoader(dataset, batch_size=Batch_Size,
                                                   shuffle=True, num_workers=2)

          ngpu = int(opt.ngpu)

          wtl2 = 0.999

          # custom weights initialization called on netG and netD
          def weights_init(m):
              classname = m.__class__.__name__
              if classname.find('Conv') != -1:
                  m.weight.data.normal_(0.00.02)
              elif classname.find('BatchNorm') != -1:
                  m.weight.data.normal_(1.00.02)
                  m.bias.data.fill_(0)


          resume_epoch=0

          netG = generator()
          netG.apply(weights_init)


          netD = discriminator()
          netD.apply(weights_init)

          criterion = nn.BCELoss()
          criterionMSE = nn.MSELoss()

          input_real = torch.FloatTensor(Batch_Size, 3128128)
          input_cropped = torch.FloatTensor(Batch_Size, 3128128)
          label = torch.FloatTensor(Batch_Size)
          real_label = 1
          fake_label = 0

          real_center = torch.FloatTensor(Batch_Size, 364,64)


          netD.cuda()
          netG.cuda()
          criterion.cuda()
          criterionMSE.cuda()
          input_real, input_cropped,label = input_real.cuda(),input_cropped.cuda(), label.cuda()
          real_center = real_center.cuda()


          input_real = Variable(input_real)
          input_cropped = Variable(input_cropped)
          label = Variable(label)


          real_center = Variable(real_center)

          optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
          optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

          for epoch in range(resume_epoch,epochs):
              for i, data in enumerate(dataloader, 0):
                  real_cpu, _ = data
                  real_center_cpu = real_cpu[:,:,int(128/4):int(128/4)+int(128/2),int(128/4):int(128/4)+int(128/2)]
                  batch_size = real_cpu.size(0)
                  with torch.no_grad():
                      input_real.resize_(real_cpu.size()).copy_(real_cpu)
                      input_cropped.resize_(real_cpu.size()).copy_(real_cpu)
                      real_center.resize_(real_center_cpu.size()).copy_(real_center_cpu)
                      input_cropped[:,0,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*117.0/255.0 - 1.0
                      input_cropped[:,1,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*104.0/255.0 - 1.0
                      input_cropped[:,2,int(128/4+over):int(128/4+128/2-over),int(128/4+over):int(128/4+128/2-over)] = 2*123.0/255.0 - 1.0

                  #start the discriminator by training with real data---
                  netD.zero_grad()
                  with torch.no_grad():
                      label.resize_(batch_size).fill_(real_label)

                  output = netD(real_center)
                  errD_real = criterion(output, label)
                  errD_real.backward()
                  D_x = output.data.mean()

                  # train the discriminator with fake data---
                  fake = netG(input_cropped)
                  label.data.fill_(fake_label)
                  output = netD(fake.detach())
                  errD_fake = criterion(output, label)
                  errD_fake.backward()
                  D_G_z1 = output.data.mean()
                  errD = errD_real + errD_fake
                  optimizerD.step()


                  #train the generator now---
                  netG.zero_grad()
                  label.data.fill_(real_label)  # fake labels are real for generator cost
                  output = netD(fake)
                  errG_D = criterion(output, label)

                  wtl2Matrix = real_center.clone()
                  wtl2Matrix.data.fill_(wtl2*10)
                  wtl2Matrix.data[:,:,int(over):int(128/2 - over),int(over):int(128/2 - over)] = wtl2

                  errG_l2 = (fake-real_center).pow(2)
                  errG_l2 = errG_l2 * wtl2Matrix
                  errG_l2 = errG_l2.mean()

                  errG = (1-wtl2) * errG_D + wtl2 * errG_l2

                  errG.backward()

                  D_G_z2 = output.data.mean()
                  optimizerG.step()

                  print('[%d / %d][%d / %d] Loss_D: %.4f Loss_G: %.4f / %.4f l_D(x): %.4f l_D(G(z)): %.4f'
                        % (epoch, epochs, i, len(dataloader),
                           errD.data, errG_D.data,errG_l2.data, D_x,D_G_z1, ))

                  if i % 100 == 0:

                      vutils.save_image(real_cpu,
                              'result/train/real/real_samples_epoch_%03d.png' % (epoch))
                      vutils.save_image(input_cropped.data,
                              'result/train/cropped/cropped_samples_epoch_%03d.png' % (epoch))
                      recon_image = input_cropped.clone()
                      recon_image.data[:,:,int(128/4):int(128/4+128/2),int(128/4):int(128/4+128/2)] = fake.data
                      vutils.save_image(recon_image.data,
                              'result/train/recon/recon_center_samples_epoch_%03d.png' % (epoch))
          訓(xùn)練生成器和鑒別器的訓(xùn)練模塊

          結(jié)果

          讓我們看一下我們的模型能夠構(gòu)建出什么來?第0個(gè)epoch時(shí)候的圖像(噪聲):

          第100個(gè)epoch時(shí)候:

          我們看下輸入模型的是什么:


          END

          英文原文:https://towardsdatascience.com/inpainting-with-ai-get-back-your-images-pytorch-a68f689128e5


          國產(chǎn)小眾瀏覽器因屏蔽視頻廣告,被索賠100萬(后續(xù))

          年輕人“不講武德”:因看黃片上癮,把網(wǎng)站和786名女主播起訴了

          中國聯(lián)通官網(wǎng)被發(fā)現(xiàn)含木馬腳本,可向用戶推廣色情APP

          張一鳴:每個(gè)逆襲的年輕人,都具備的底層能力


          關(guān)


          學(xué),西學(xué)學(xué)運(yùn)護(hù)質(zhì),結(jié),關(guān)[]學(xué)習(xí)進(jìn)!


          瀏覽 104
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評論
          圖片
          表情
          推薦
          點(diǎn)贊
          評論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  美国黄色电影AA | 免费高潮视频 | 国产成人免费在线观看 | 大香蕉思思精品在线 | 狼友视频官网免费 |