【小白學習PyTorch教程】十一、基於MNIST數據集訓練第一個生成性對抗網絡
「@Author:Runsen」
GAN 是使用兩個神經網絡模型訓練的生成模型。一個模型稱為生成網絡,它學習生成新的、看似真實的樣本。另一個模型稱為判別網絡,它學習區分生成的樣本和真實的樣本。
生成性對抗網絡
2014年,蒙特利爾大學的Ian Goodfellow和他的同事提出了生成性對抗網絡(GAN)。自論文發表以來,已經出現了許多變體和新的目標函數,用來解決它存在的問題。
論文見:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
論文提出了兩種模型:生成模型和判別模型。兩個模型相互競爭:生成模型力求產生以假亂真的樣本,判別模型則力求分辨樣本的真假。2016年,Yann LeCun將GANs描述為“過去二十年機器學習中最酷的想法”。

GAN 的大部分研究和應用都集中在計算機視覺領域。
其原因是卷積神經網絡 (CNN) 等深度學習模型在過去 5 到 7 年中在計算機視覺領域取得了巨大成功,例如在對象檢測和人臉識別等具有挑戰性的任務上。
GAN 的典型應用是生成新的逼真照片,其中最令人吃驚的例子是生成照片般逼真的人臉。

在本教程中,我們將實現一個簡單的GAN,用來生成假的MNIST樣本。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as utils
import numpy as np
import matplotlib.pyplot as plt
# Device selection: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)  # shows "cuda" on a GPU machine, "cpu" otherwise
這裡使用MNIST數據集,它是一個規模很小的數據集。
它由60000個訓練圖像和10000個測試圖像組成,每個圖像的大小為28*28,只有一個(灰度)通道。
# Number of images processed in one forward/backward pass.
batch_size = 100

# Preprocessing pipeline: convert each image to a tensor, then normalise its
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split stored under ./data/MNIST (download requested).
mnist = datasets.MNIST(
    './data/MNIST',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batch loader over the training split.
mnist_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
def imshow(img, title):
    """Display an image tensor, or a batch of image tensors, with matplotlib.

    This single definition works for tensors on any device. A batch is tiled
    into one grid image before plotting.

    Args:
        img: Tensor of shape (C, H, W) or (N, C, H, W). Values are expected
            to lie in [-1, 1], as produced by ``Normalize(0.5, 0.5)`` or a
            ``Tanh`` output layer.
        title: Title drawn above the figure.
    """
    # Drop autograd history and copy to host memory first: ``Tensor.numpy()``
    # raises for CUDA tensors and for tensors that require grad.
    grid = utils.make_grid(img.detach().cpu())
    # Map [-1, 1] back to [0, 1] so matplotlib renders the intensities as-is.
    grid = (grid + 1) / 2
    npimg = grid.numpy()
    plt.figure(figsize=(10, 10))
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.title(title)
    plt.show()
# Preview the first 16 images of one mini-batch.
# The built-in next() is used because Python 3 iterators have no .next() method.
images, labels = next(iter(mnist_loader))
imshow(images[:16], "MNIST Images")

建立一個GAN模型,它由一個Generator(生成器)和一個Discriminator(鑒別器)組成。
生成器由全連接層組成,它將從100維高斯分布采樣的噪聲轉換為MNIST圖像。鑒別器網絡也由全連接層組成,用於區分輸入數據是真是假。
class Generator(nn.Module):
    """Fully connected generator: latent noise vector -> 1x28x28 image.

    The network widens through 128, 256 and 512 hidden units and ends in a
    ``Tanh``, so every output pixel lies in [-1, 1].

    Args:
        latent_size: Length of the input noise vector. Defaults to 100.
    """

    def __init__(self, latent_size=100):
        super().__init__()
        self.latent_size = latent_size
        n_pixels = 28 * 28
        # The attribute stays named ``main`` so state_dict keys
        # ("main.0.weight", ...) do not change.
        self.main = nn.Sequential(
            nn.Linear(latent_size, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, n_pixels),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape (N, latent_size) to images of shape (N, 1, 28, 28)."""
        out = self.main(x)
        # Unflatten the 784 outputs into a single-channel 28x28 image.
        out = out.view(-1, 1, 28, 28)
        return out
class Discriminator(nn.Module):
    """Fully connected discriminator: image -> probability in [0, 1].

    The input is flattened, passed through hidden layers of 512, 256, 128 and
    64 units, and squashed by a ``Sigmoid`` into one value per sample.

    Args:
        n_features: Number of values per flattened input sample.
            Defaults to 28 * 28.
    """

    def __init__(self, n_features=28 * 28):
        super().__init__()
        self.n_features = n_features
        n_out = 1
        # The attribute stays named ``main`` so state_dict keys do not change.
        self.main = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, n_out),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one probability per sample, shape (N, 1)."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.n_features)
        out = self.main(x)
        return out
# Instantiate both networks and place their parameters on the selected device.
# nn.Module.to() moves the module in place, so the names stay bound to the same objects.
G = Generator()
G.to(device)
D = Discriminator()
D.to(device)
生成性對抗網絡訓練過程的損失函數是二元交叉熵損失,由torch.nn.BCELoss實現。
這兩個模型都使用torch.optim.Adam作為優化器,學習率設置為0.0002。
# Loss: binary cross-entropy between predicted probabilities and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimiser per network, both using the same learning rate.
D_optimizer = optim.Adam(D.parameters(), lr=2e-4)
G_optimizer = optim.Adam(G.parameters(), lr=2e-4)

# Training constants.
num_epochs = 50
noise_dim = 100
total_batch = len(mnist_loader)  # number of mini-batches per epoch

# Loss histories, one entry appended per training iteration.
G_losses, D_losses = [], []

# A fixed set of latent vectors, reused so generated samples can be compared
# from one epoch to the next.
sample_size = 16
fixed_noise = torch.randn(sample_size, noise_dim)
fixed_noise = fixed_noise.to(device)
# Train: alternate one discriminator update and one generator update per
# mini-batch, then show samples from the fixed noise after every epoch.
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(mnist_loader):
        # Size everything from this batch rather than the global batch_size,
        # so a smaller final batch cannot break the reshape or the labels.
        current_batch = images.size(0)

        # Flattened real images, target labels and latent noise.
        images = images.reshape(current_batch, -1).float().to(device)
        ones = torch.ones(current_batch, 1).to(device)
        zeros = torch.zeros(current_batch, 1).to(device)
        noise = torch.randn(current_batch, noise_dim).to(device)

        #######################
        # Train Discriminator #
        #######################
        D_optimizer.zero_grad()

        # Real images should be classified as 1.
        prob_real = D(images)
        D_real_loss = criterion(prob_real, ones)

        # Generated images should be classified as 0. detach() cuts the graph
        # at G's output so this backward pass leaves no gradients on G.
        fake_images = G(noise).detach()
        prob_fake = D(fake_images)
        D_fake_loss = criterion(prob_fake, zeros)

        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        D_optimizer.step()

        ###################
        # Train Generator #
        ###################
        # Start from clean gradients so the update uses G_loss only.
        G_optimizer.zero_grad()

        fake_images = G(noise)
        prob_fake = D(fake_images)

        # According to section 3 of the paper, early in learning, when G is
        # very poor, D can reject samples from G and log(1 - D(G(z)))
        # saturates. So train G to maximize log(D(G(z))) instead of
        # minimizing log(1 - D(G(z))).
        G_loss = criterion(prob_fake, ones)
        G_loss.backward()
        G_optimizer.step()

        # Save losses for plotting later.
        G_losses.append(G_loss.item())
        D_losses.append(D_loss.item())

        # Print statistics every 100 iterations.
        if (i + 1) % 100 == 0:
            print("Epoch [%d/%d] Iter [%d/%d], D_Loss: %.4f G_Loss: %.4f"
                  % (epoch + 1, num_epochs, i + 1, total_batch, D_loss.item(), G_loss.item()))

    # Generate samples once per epoch. No autograd graph is needed here.
    with torch.no_grad():
        fake_samples = G(fixed_noise)
    imshow(fake_samples, "Generated MNIST Images")

# Save the generator weights for later digit generation.
torch.save(G.state_dict(), './data/GAN.pkl')

# Plot the recorded generator and discriminator losses on one set of axes.
fig, ax = plt.subplots(figsize=(8, 6))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="Generator")
ax.plot(D_losses, label="Discriminator")
ax.set_xlabel("Iterations")
ax.set_ylabel("Losses")
ax.legend()
plt.show()

# Generate a fresh grid of 64 digits from the saved generator weights.
sample_size = 64
noise_dim = 100
noise = torch.randn(sample_size, noise_dim).to(device)
# Load from the path the weights were saved to. map_location keeps the load
# working when the current device differs from the one used for saving.
G.load_state_dict(torch.load('./data/GAN.pkl', map_location=device))
# Feed the newly drawn noise (64 samples), not the 16-sample fixed_noise.
with torch.no_grad():
    fake_samples = G(noise)
imshow(fake_samples, "Generated MNIST Images")

GAN生成性對抗網絡的應用
將語義圖像轉換成城市景觀和建築物的照片。 將衛星照片轉換成地圖。 將白天的照片轉換成夜晚的照片。 將黑白照片轉換成彩色照片。

- GAN論文:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
- 上述代碼的論文:https://arxiv.org/abs/1511.06434
- 上述代碼:https://github.com/yihui-he/GAN-MNIST
