<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          【小白學(xué)習(xí)PyTorch教程】十一、基于MNIST數(shù)據(jù)集訓(xùn)練第一個(gè)生成性對(duì)抗網(wǎng)絡(luò)

          共 14778字,需瀏覽 30分鐘

           ·

          2021-08-15 11:58

          「@Author:Runsen」

          GAN 是使用兩個(gè)神經(jīng)網(wǎng)絡(luò)模型訓(xùn)練的生成模型。一種模型稱為生成網(wǎng)絡(luò)模型,它學(xué)習(xí)生成新的似是而非的樣本。另一個(gè)模型被稱為判別網(wǎng)絡(luò),它學(xué)習(xí)區(qū)分生成的例子和真實(shí)的例子。

          生成性對(duì)抗網(wǎng)絡(luò)

          2014,蒙特利爾大學(xué)的Ian Goodfellow和他的朋友發(fā)明了生成性對(duì)抗網(wǎng)絡(luò)(GAN)。自它出版以來(lái),有許多它的變體和客觀功能來(lái)解決它的問(wèn)題

          論文在這里找到.

          論文提出了兩種模型:生成模型和判別模型。兩個(gè)模型競(jìng)爭(zhēng),以產(chǎn)生真實(shí)和假的樣本。2016年,Yann LeCun將GANs描述為“過(guò)去二十年機(jī)器學(xué)習(xí)中最酷的想法”。

          GAN 的大部分研究和應(yīng)用都集中在計(jì)算機(jī)視覺(jué)領(lǐng)域。

          其原因是卷積神經(jīng)網(wǎng)絡(luò) (CNN) 等深度學(xué)習(xí)模型在過(guò)去 5 到 7 年中在計(jì)算機(jī)視覺(jué)領(lǐng)域取得了巨大成功,例如在具有挑戰(zhàn)性的任務(wù)(如對(duì)象檢測(cè)和人臉識(shí)別。

          GAN 的典型例子是生成新的逼真的照片,最令人吃驚的是生成照片般逼真的人臉的例子。

          在本教程中,我們將實(shí)現(xiàn)一個(gè)簡(jiǎn)單的GAN生成假的MNIST樣本。

          import torch
          import torch.nn as nn
          import torch.optim as optim
          from torch.utils.data import DataLoader

          import torchvision
          import torchvision.datasets as datasets
          import torchvision.transforms as transforms
          import torchvision.utils as utils

          import numpy as np
          import matplotlib.pyplot as plt
          # CPU / GPU Setting
          device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
          print(device)  #cuda

          使用MNIST數(shù)據(jù)集,具有最小大小的數(shù)據(jù)集。

          它由60000個(gè)訓(xùn)練圖像和10000個(gè)測(cè)試圖像組成,每個(gè)圖像有28*28的大小和一個(gè)彩色通道。

          # Define a transform 
          transform = transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize(mean = (0.5, ), std = (0.5, ))
          ])

          # batch_size是一個(gè)前向和后向傳播過(guò)程中的圖像數(shù)。
          batch_size = 100

          mnist = datasets.MNIST('./data/MNIST'
                                 download = True
                                 train = True
                                 transform = transform)

          mnist_loader = DataLoader(dataset = mnist, 
                                    batch_size = batch_size, 
                                    shuffle = True)
          # CPU
          def imshow(img, title):
              img = utils.make_grid(img.cpu().detach())
              img = (img+1)/2
              npimg = img.detach().numpy()
              plt.imshow(np.transpose(npimg, (120)))
              plt.title(title)
              plt.show()
          #GPU
          def imshow(img, title):
              npimg = img.detach().numpy()
              fig = plt.figure(figsize = (1010))
              plt.imshow(np.transpose(npimg, (120)))
              plt.title(title)
              plt.show()

          images, labels = iter(mnist_loader).next()
          imshow(images[0:16, :, :], "MNIST Images")

          建立一個(gè)GANs模型。一個(gè)Generator和Discriminator

          GANs由完全連接的層組成。它將從100維高斯分布采樣的噪聲轉(zhuǎn)換為MNIST圖像。鑒別器網(wǎng)絡(luò)也由完全連接的層組成,用于區(qū)分輸入數(shù)據(jù)是真是假。

          class Generator(nn.Module):
              def __init__(self):
                  super(Generator, self).__init__()
                  
                  latent_size = 100
                  output = 28*28
                  
                  self.main = nn.Sequential(
                      nn.Linear(latent_size, 128),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(128256),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(256512),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(512, output),
                      nn.Tanh()
                  )
                  
              def forward(self, x):
                  out = self.main(x)
                  out = out.view(-112828)
                  return out


          class Discriminator(nn.Module):
              def __init__(self):
                  super(Discriminator, self).__init__()
                  
                  n_features = 28 * 28
                  n_out = 1
                  
                  self.main = nn.Sequential(
                      nn.Linear(n_features, 512),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(512256),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(256128),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(12864),
                      nn.ReLU(inplace=True),
                      
                      nn.Linear(64, n_out),
                      nn.Sigmoid()        
                  )
                  
              def forward(self, x):
                  x = x.view(-128*28)
                  out = self.main(x)
                  return out

          G = Generator().to(device)
          D = Discriminator().to(device)

          生成性對(duì)抗網(wǎng)絡(luò)訓(xùn)練過(guò)程的損失函數(shù)是二進(jìn)制交叉熵?fù)p失,由torch.nn.BCELoss實(shí)現(xiàn)。

          這兩種模型都使用torch.optim.Adam作為優(yōu)化工具,學(xué)習(xí)率設(shè)置為0.002。

          # Objective Function
          criterion = nn.BCELoss()

          # Optimizer
          G_optimizer = optim.Adam(G.parameters(), lr = 0.0002)
          D_optimizer = optim.Adam(D.parameters(), lr = 0.0002)

          # Constants
          noise_dim = 100
          num_epochs = 50
          total_batch = len(mnist_loader)

          # Lists
          G_losses = []
          D_losses = []

          # Noise
          sample_size = 16
          fixed_noise = torch.randn(sample_size, noise_dim).to(device)

          # Train
          for epoch in range(num_epochs):
              for i, (images, labels) in enumerate(mnist_loader):
                  
                  # Images #
                  images = images.reshape(batch_size, -1).float().to(device)
                  
                  # Labels #
                  ones = torch.ones(batch_size, 1).to(device)
                  zeros = torch.zeros(batch_size, 1).to(device)
                  
                  # Noise #
                  noise = torch.randn(batch_size, noise_dim).to(device)
                  
                  # Initialize Optimizers
                  D_optimizer.zero_grad()
                  G_optimizer.zero_grad()
                  
                  #######################
                  # Train Discriminator #
                  #######################
                  
                  # Forward Images #
                  prob_real = D(images)
                  D_real_loss = criterion(prob_real, ones)
                  
                  # Generate Samples #
                  fake_images = G(noise)
                  prob_fake = D(fake_images)
                  
                  # Forward Fake Samples and Calculate Discriminator Loss #
                  D_fake_loss = criterion(prob_fake, zeros)
                  D_loss = (D_real_loss + D_fake_loss).mean()
                  
                  # Back Propagation and Update
                  D_loss.backward()
                  D_optimizer.step()
                  
                  ###################
                  # Train Generator #
                  ###################
                  
                  fake_images = G(noise)
                  prob_fake = D(fake_images)
                  
                  # According to the section 3 in paper,
                  # early in learning, when G is very poor, D can reject samples from G.
                  # In this case, log(1-D(G(z))) saturates. 
                  # thus, train G to maximiaze log(D(G(z))) instead of minimizing log(1-D(G(z)))
                  G_loss = criterion(prob_fake, ones)
                  
                  # Back Propagation and Update
                  G_loss.backward()
                  G_optimizer.step()
                  
                  # Save Losses for Plotting Later
                  G_losses.append(G_loss.item())
                  D_losses.append(D_loss.item())
                  
                  # Print Statistics #
                  if (i + 1) % 100 == 0:
                      print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                            %(epoch+1, num_epochs, i+1, total_batch, D_loss.item(), G_loss.item()))
              
              # Generate Samples #
              if epoch % 1 == 0:
                  fake_samples = G(fixed_noise)
                  imshow(fake_samples, "Generated MNIST Images")
              
          # Save Model Weights for Digit Generation
          torch.save(G.state_dict(), './data/GAN.pkl')
          plt.figure(figsize = (86))
          plt.title("Generator and Discriminator Loss During Training")
          plt.plot(G_losses, label="Generator")
          plt.plot(D_losses, label="Discriminator")
          plt.xlabel("Iterations")
          plt.ylabel("Losses")
          plt.legend()
          plt.show()
          sample_size = 64
          noise_dim = 100

          noise = torch.randn(sample_size, noise_dim).to(device)

          G.load_state_dict(torch.load('GAN.pkl'))
          fake_samples = G(fixed_noise)
          imshow(fake_samples, "Generated MNIST Images")

          GAN生成性對(duì)抗網(wǎng)絡(luò)的運(yùn)用

          • 將語(yǔ)義圖像翻譯成城市景觀和建筑物的照片。
          • 將衛(wèi)星照片翻譯成地圖。
          • 從白天到晚上的照片翻譯。
          • 將黑白照片翻譯成彩色。

          - 論文在這里找到:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

          - 上述代碼的論文:https://arxiv.org/abs/1511.06434

          - 上述代碼:https://github.com/yihui-he/GAN-MNIST

          往期精彩回顧




          本站qq群851320808,加入微信群請(qǐng)掃碼:
          瀏覽 53
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  北条麻妃一区二区三区成人片 | 97国产在线观看 | 樱桃视频一区二区 | 一区二区三区高清无码在线 | 日女人毛片 |