<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          火爆全球的GAN,究竟是何方神圣?

          共 22168字,需瀏覽 45分鐘

           ·

          2021-03-06 18:04


          故事時(shí)間

          從前有一個(gè)人,他希望通過(guò)制造假幣來(lái)發(fā)家致富。

          于是,他開(kāi)始學(xué)習(xí)制造假幣。

          一開(kāi)始,他的技術(shù)太菜,制作的假幣剛流入市場(chǎng)就被警察發(fā)現(xiàn)了。

          他不甘心,于是繼續(xù)學(xué)習(xí)來(lái)提升造假幣技術(shù),這一次,假幣并沒(méi)有被發(fā)現(xiàn),他很開(kāi)心的數(shù)著錢。

          可是,過(guò)了一段時(shí)間,敏銳的警察使用剛剛學(xué)習(xí)到的新知識(shí),破獲了他的假幣。

          但他還是不甘示弱,繼續(xù)提升造假幣的技術(shù)

          警察也繼續(xù)學(xué)習(xí)新的假幣鑒別技術(shù)

          就這樣,他的造假幣技術(shù)一直在提升,警察鑒別假幣的技術(shù)也在不斷提升

          在互相抗衡很久以后,他的造假幣技術(shù)到了爐火純青的地步,以至于警察都難以鑒別。

          GAN是什么?

          生成對(duì)抗網(wǎng)絡(luò)(Generative adversarial network, GAN)由生成器(一般用表示)和判別器(一般用表示)組成,常用于生成"假"的東西,比如假的文本,假的人臉圖像等等,本文以圖像生成為例進(jìn)行敘述。

          生成器負(fù)責(zé)將從某分布中隨機(jī)采樣的噪聲通過(guò)神經(jīng)網(wǎng)絡(luò)映射為"生成圖像";判別器負(fù)責(zé)鑒定給定的圖像是真實(shí)圖像還是生成器生成的圖像。

          在上面的故事中,警察充當(dāng)著判別器的角色,而造假幣的人充當(dāng)著生成器的角色。

          造假幣的人希望自己的假幣能夠騙過(guò)警察,而警察希望自己能夠精準(zhǔn)區(qū)分真錢幣和假幣,于是他們互相博弈,與彼此相對(duì)抗,最終,造假幣者造出來(lái)的假幣太過(guò)真實(shí),就連警察也不能正確鑒別了,此時(shí),就表明造假幣的人成功了。

          去掉故事的外衣,就是生成對(duì)抗網(wǎng)絡(luò)的思想了:

          生成器希望自己生成的假圖像能夠騙過(guò)判別器,而希望自己能夠精準(zhǔn)區(qū)分真實(shí)的圖像與生成的圖像,于是它們互相博弈,與彼此相對(duì)抗,最終,生成的的假圖像太過(guò)真實(shí),就連也不能正確鑒別了,此時(shí),就表明我們的生成對(duì)抗網(wǎng)絡(luò)訓(xùn)練成功了。

          之后在做“假”圖像生成的時(shí)候,只需將采樣得到的隨機(jī)噪聲序列輸入生成器,等待輸出即可。

          (求生欲:故事僅僅是為了更形象的介紹GAN,并無(wú)其他含義,不要多想)

          GAN原理解析

          生成器負(fù)責(zé)將從某分布中隨機(jī)采樣得到的噪聲序列映射為與真實(shí)圖像相似的生成圖像,自然希望生成圖像與真實(shí)圖像越像越好。

          這里,兩者的相似度用生成圖像所服從的分布與真實(shí)圖像所服從的分布之間的距離來(lái)度量,距離越小,表明兩個(gè)分布越相似。

          那如何度量?jī)蓚€(gè)分布之間的距離呢?干脆直接用萬(wàn)能的神經(jīng)網(wǎng)絡(luò)來(lái)衡量?jī)蓚€(gè)分布之間的距離好了。

          將兩個(gè)分布之間的距離度量記作,由上面所講可知,生成器希望生成圖像所服從的分布與真實(shí)圖像所服從的分布之間的越小越好,這樣生成的圖像才會(huì)更加接近真實(shí)圖像。

          而對(duì)于判別器來(lái)說(shuō),要分兩種情況。第一,如果判別器的輸入是真實(shí)圖像,那么判別器希望此時(shí)輸入圖像所服從的分布與真實(shí)圖像所服從的分布之間的越小越好;第二,如果判別器的輸入是生成圖像,那么判別器希望此時(shí)輸入圖像所服從的分布與真實(shí)圖像所服從的分布之間的越大越好,因?yàn)橹挥羞@樣判別器才能夠正確地將真實(shí)圖像與生成圖像區(qū)分開(kāi)來(lái)。

          以上用文字描述了半天,其實(shí)完全可以由下面的公式來(lái)表示:

          這個(gè)公式就是GAN的優(yōu)化目標(biāo)函數(shù),它將我們上面所講的內(nèi)容信息整合到了一起,其中的就體現(xiàn)了“對(duì)抗”的思想。

          用PyTorch寫(xiě)一個(gè)GAN

          分別用0和1表示生成圖像和真實(shí)圖像的標(biāo)簽,根據(jù)上一部分的原理講解,損失函數(shù)就有了。具體來(lái)說(shuō),生成器希望判別器誤將生成圖像(label:0)判定為真實(shí)圖像(label:1),因此希望生成圖像的判別結(jié)果與1越接近越好;而判別器則希望真實(shí)圖像(label:1)的判別結(jié)果與1越接近越好,生成圖像(label:0)的判別結(jié)果與0越接近越好,這樣就能夠很好的區(qū)分開(kāi)兩者了。用交叉熵度量以上損失即可。

          現(xiàn)在,來(lái)實(shí)現(xiàn)基于卷積神經(jīng)網(wǎng)絡(luò)的GAN(也叫DCGAN),并使用它生成人臉。這是PyTorch官方的給出的例子,我們動(dòng)手過(guò)一遍。

          導(dǎo)入所需庫(kù)

          from __future__ import print_function
          #%matplotlib inline
          import argparse
          import os
          import random
          import torch
          import torch.nn as nn
          import torch.nn.parallel
          import torch.backends.cudnn as cudnn
          import torch.optim as optim
          import torch.utils.data
          import torchvision.datasets as dset
          import torchvision.transforms as transforms
          import torchvision.utils as vutils
          import numpy as np
          import matplotlib.pyplot as plt
          import matplotlib.animation as animation
          from IPython.display import HTML

          必要參數(shù)設(shè)置

          # 設(shè)置隨機(jī)種子
          manualSeed = 999
          random.seed(manualSeed)
          torch.manual_seed(manualSeed)

          # 數(shù)據(jù)下載到指定目錄
          dataroot = "data/celeba"

          # Number of workers for dataloader
          workers = 2

          # 批量大小
          batch_size = 128

          # 將圖像rezize到指定尺寸
          image_size = 64

          # 通道數(shù),彩圖為3
          nc = 3

          # 隨機(jī)噪聲序列的長(zhǎng)度
          nz = 100

          # 生成器中特征圖的個(gè)數(shù)
          ngf = 64

          # 判別器中特征圖的個(gè)數(shù)
          ndf = 64

          # 訓(xùn)練迭代輪數(shù)
          num_epochs = 5

          # 學(xué)習(xí)率
          lr = 0.0002

          # Beta1 hyperparam for Adam optimizers
          beta1 = 0.5

          # Number of GPUs available. Use 0 for CPU mode.
          ngpu = 1

          # Decide which device we want to run on
          device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0else "cpu")

          數(shù)據(jù)準(zhǔn)備

          https://drive.google.com/drive/folders/0B7EVK8r0v71pTUZsaXdaSnZBZzg下載數(shù)據(jù)集到本地,然后解壓,路徑如下

          然后開(kāi)始數(shù)據(jù)預(yù)處理

          # We can use an image folder dataset the way we have it setup.
          # 創(chuàng)建數(shù)據(jù)集
          dataset = dset.ImageFolder(root=dataroot,
                                     transform=transforms.Compose([
                                         transforms.Resize(image_size),
                                         transforms.CenterCrop(image_size),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.50.50.5), (0.50.50.5)),
                                     ]))
          # Create the dataloader
          dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                   shuffle=True, num_workers=workers)

          畫(huà)出第一個(gè)batch的前64張圖片看一下

          # Plot some training images
          real_batch = next(iter(dataloader))
          plt.figure(figsize=(8,8))
          plt.axis("off")
          plt.title("Training Images")
          plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

          設(shè)置權(quán)值初始化方案

          # 不同的層使用不同的權(quán)值初始化方案
          def weights_init(m):
              classname = m.__class__.__name__
              if classname.find('Conv') != -1:
                  nn.init.normal_(m.weight.data, 0.00.02)
              elif classname.find('BatchNorm') != -1:
                  nn.init.normal_(m.weight.data, 1.00.02)
                  nn.init.constant_(m.bias.data, 0)

          搭建生成器和判別器網(wǎng)絡(luò)

          # 生成器網(wǎng)絡(luò)
          class Generator(nn.Module):
              def __init__(self, ngpu):
                  super(Generator, self).__init__()
                  self.ngpu = ngpu#設(shè)置是否使用gpu,1表示使用
                  self.main = nn.Sequential(
                      # 輸入隨機(jī)噪聲z,轉(zhuǎn)置卷積進(jìn)行上采樣
                      nn.ConvTranspose2d( nz, ngf * 8410, bias=False),
                      nn.BatchNorm2d(ngf * 8),
                      nn.ReLU(True),
                      # state size. (ngf*8) x 4 x 4
                      nn.ConvTranspose2d(ngf * 8, ngf * 4421, bias=False),
                      nn.BatchNorm2d(ngf * 4),
                      nn.ReLU(True),
                      # state size. (ngf*4) x 8 x 8
                      nn.ConvTranspose2d( ngf * 4, ngf * 2421, bias=False),
                      nn.BatchNorm2d(ngf * 2),
                      nn.ReLU(True),
                      # state size. (ngf*2) x 16 x 16
                      nn.ConvTranspose2d( ngf * 2, ngf, 421, bias=False),
                      nn.BatchNorm2d(ngf),
                      nn.ReLU(True),
                      # state size. (ngf) x 32 x 32
                      nn.ConvTranspose2d( ngf, nc, 421, bias=False),
                      nn.Tanh()
                      # state size. (nc) x 64 x 64
                  )

              def forward(self, input):
                  return self.main(input)

          # Create the generator
          netG = Generator(ngpu).to(device)

          # Handle multi-gpu if desired
          if (device.type == 'cuda'and (ngpu > 1):
              netG = nn.DataParallel(netG, list(range(ngpu)))

          # Apply the weights_init function to randomly initialize all weights
          #  to mean=0, stdev=0.2.
          netG.apply(weights_init)

          # Print the model
          print(netG)
          #判別器網(wǎng)絡(luò)
          class Discriminator(nn.Module):
              def __init__(self, ngpu):
                  super(Discriminator, self).__init__()
                  self.ngpu = ngpu
                  self.main = nn.Sequential(
                      # input is (nc) x 64 x 64
                      nn.Conv2d(nc, ndf, 421, bias=False),
                      nn.LeakyReLU(0.2, inplace=True),
                      # state size. (ndf) x 32 x 32
                      nn.Conv2d(ndf, ndf * 2421, bias=False),
                      nn.BatchNorm2d(ndf * 2),
                      nn.LeakyReLU(0.2, inplace=True),
                      # state size. (ndf*2) x 16 x 16
                      nn.Conv2d(ndf * 2, ndf * 4421, bias=False),
                      nn.BatchNorm2d(ndf * 4),
                      nn.LeakyReLU(0.2, inplace=True),
                      # state size. (ndf*4) x 8 x 8
                      nn.Conv2d(ndf * 4, ndf * 8421, bias=False),
                      nn.BatchNorm2d(ndf * 8),
                      nn.LeakyReLU(0.2, inplace=True),
                      # state size. (ndf*8) x 4 x 4
                      nn.Conv2d(ndf * 81410, bias=False),
                      nn.Sigmoid()
                  )

              def forward(self, input):
                  return self.main(input)

          # Create the Discriminator
          netD = Discriminator(ngpu).to(device)

          # Handle multi-gpu if desired
          if (device.type == 'cuda'and (ngpu > 1):
              netD = nn.DataParallel(netD, list(range(ngpu)))
              
          # Apply the weights_init function to randomly initialize all weights
          #  to mean=0, stdev=0.2.
          netD.apply(weights_init)

          # Print the model
          print(netD)

          設(shè)置優(yōu)化器等

          # 交叉熵?fù)p失函數(shù)
          criterion = nn.BCELoss()

          # Create batch of latent vectors that we will use to visualize
          #  the progression of the generator
          # 從正態(tài)分布中采樣64個(gè)nz長(zhǎng)度的隨機(jī)噪聲序列
          fixed_noise = torch.randn(64, nz, 11, device=device)

          # Establish convention for real and fake labels during training
          real_label = 1.
          fake_label = 0.

          # 設(shè)置優(yōu)化器
          optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
          optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

          開(kāi)始訓(xùn)練

          # 訓(xùn)練

          # Lists to keep track of progress
          img_list = []
          G_losses = []
          D_losses = []
          iters = 0

          print("Starting Training Loop...")
          # For each epoch
          for epoch in range(num_epochs):
              # For each batch in the dataloader
              for i, data in enumerate(dataloader, 0):
                  
                  ############################
                  # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                  ###########################
                  ## Train with all-real batch
                  netD.zero_grad()
                  # Format batch
                  real_cpu = data[0].to(device)
                  b_size = real_cpu.size(0)
                  label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
                  # Forward pass real batch through D
                  output = netD(real_cpu).view(-1)
                  # Calculate loss on all-real batch
                  errD_real = criterion(output, label)
                  # Calculate gradients for D in backward pass
                  errD_real.backward()
                  D_x = output.mean().item()

                  ## Train with all-fake batch
                  # Generate batch of latent vectors
                  noise = torch.randn(b_size, nz, 11, device=device)
                  # Generate fake image batch with G
                  fake = netG(noise)
                  label.fill_(fake_label)
                  # Classify all fake batch with D
                  output = netD(fake.detach()).view(-1)
                  # Calculate D's loss on the all-fake batch
                  errD_fake = criterion(output, label)
                  # Calculate the gradients for this batch
                  errD_fake.backward()
                  D_G_z1 = output.mean().item()
                  # Add the gradients from the all-real and all-fake batches
                  errD = errD_real + errD_fake
                  # Update D
                  optimizerD.step()

                  ############################
                  # (2) Update G network: maximize log(D(G(z)))
                  ###########################
                  netG.zero_grad()
                  label.fill_(real_label)  # fake labels are real for generator cost
                  # Since we just updated D, perform another forward pass of all-fake batch through D
                  output = netD(fake).view(-1)
                  # Calculate G's loss based on this output
                  errG = criterion(output, label)
                  # Calculate gradients for G
                  errG.backward()
                  D_G_z2 = output.mean().item()
                  # Update G
                  optimizerG.step()
                  
                  # Output training stats
                  if i % 50 == 0:
                      print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                            % (epoch, num_epochs, i, len(dataloader),
                               errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
                  
                  # Save Losses for plotting later
                  G_losses.append(errG.item())
                  D_losses.append(errD.item())
                  
                  # Check how the generator is doing by saving G's output on fixed_noise
                  if (iters % 500 == 0or ((epoch == num_epochs-1and (i == len(dataloader)-1)):
                      with torch.no_grad():
                          fake = netG(fixed_noise).detach().cpu()
                      img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
                      
                  iters += 1

          漫長(zhǎng)的等待過(guò)后,終于訓(xùn)練好了。

          loss可視化

          plt.figure(figsize=(10,5))
          plt.title("Generator and Discriminator Loss During Training")
          plt.plot(G_losses,label="G")
          plt.plot(D_losses,label="D")
          plt.xlabel("iterations")
          plt.ylabel("Loss")
          plt.legend()
          plt.show()

          生成圖像的質(zhì)量演變過(guò)程可視化

          #%%capture
          fig = plt.figure(figsize=(8,8))
          plt.axis("off")
          ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
          ani = animation.ArtistAnimation(fig, ims, interval=100, repeat_delay=1000, blit=True)
          ani.save("pendulum.gif", writer='pillow')
          #HTML(ani.to_jshtml())

          上圖呈現(xiàn)了由噪聲圖像到人臉圖像的演變過(guò)程。注意,我們僅僅訓(xùn)練了5個(gè)epoch,因此生成的效果并不算太好,但總體能看出是人臉圖像。你可以嘗試增加epoch,再次訓(xùn)練。

          寫(xiě)在后面

          在GAN被提出以后,各式各樣的對(duì)GAN的改進(jìn)方案層出不窮,生成圖像的質(zhì)量也越來(lái)越好,甚至我們無(wú)法用肉眼分辨真實(shí)圖像和生成圖像。

          GAN也有許多有趣好玩的應(yīng)用,比如照片"去雜物",圖像超分辨率,老照片修復(fù),前段時(shí)間很火的AI還原皇帝,以及大家所熟知的AI換臉deepfake等等。

          相信GAN的前途一片光明!也相信現(xiàn)在的你會(huì)點(diǎn)個(gè)/在看的,對(duì)吧?

          深度學(xué)習(xí)資源下載

          在NLP情報(bào)局公眾號(hào)后臺(tái)回復(fù)“三件套”,即可獲取深度學(xué)習(xí)三件套:

          《PyTorch深度學(xué)習(xí)》,《Hands-on Machine Learning》,《Python深度學(xué)習(xí)》


          推 薦 閱 讀

          參 考 資 料

          • [1]https://sthalles.github.io/intro-to-gans/
          • [2]https://www.researchgate.net/publication/331756737_Recent_Progress_on_Generative_Adversarial_Networks_GANs_A_Survey
          • [3]https://spaces.ac.cn/archives/4439
          • [4]https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

          歡 迎 關(guān) 注 ??

          原創(chuàng)不易,有收獲的話請(qǐng)幫忙點(diǎn)擊分享、點(diǎn)贊在看??

          瀏覽 107
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  日韩人妻无码中文字幕 | 大香蕉视97 | 激情性爱网站 | 欧洲亚洲无码视频 | 4438全国最大无码视频 |