CycleGAN 生成對(duì)抗網(wǎng)絡(luò)圖像處理工具

1. GAN簡(jiǎn)介
通過(guò)對(duì)抗的方式,去學(xué)習(xí)數(shù)據(jù)分布的生成式模型。GAN是無(wú)監(jiān)督的過(guò)程,能夠捕捉數(shù)據(jù)集的分布,以便于可以從隨機(jī)噪聲中生成同樣分布的數(shù)據(jù)
D判別式模型:學(xué)習(xí)真假邊界,判斷數(shù)據(jù)是真的還是假的 G生成式模型:學(xué)習(xí)數(shù)據(jù)分布并生成數(shù)據(jù)


2. 實(shí)戰(zhàn)cycleGAN 風(fēng)格轉(zhuǎn)換

2.1 cycleGAN簡(jiǎn)介

兩路GAN:兩個(gè)生成器[ G:X->Y , F:Y->X ] 和兩個(gè)判別器[Dx, Dy], G和Dy目的是生成的對(duì)象,Dy(正類是Y領(lǐng)域)無(wú)法判別。同理F和Dx也是一樣的。 cycle consistency:G是生成Y的生成器, F是生成X的生成器,cycle consistency是為了約束G和F生成的對(duì)象的范圍, 是的G生成的對(duì)象通過(guò)F生成器能夠回到原始的領(lǐng)域如:x->G(x)->F(G(x))=x



2.2 實(shí)現(xiàn)cycleGAN
2.2.1 生成器
Perceptual Losses for Real-Time Style Transfer and Super-Resolution: Supplementary Material。大致可以分為:下采樣 + residual 殘差block + 上采樣,如下圖(摘自論文):
# 殘差block
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
self.block = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features),
)
def forward(self, x):
return x + self.block(x)
class GeneratorResNet(nn.Module):
def __init__(self, input_shape, num_residual_blocks):
super(GeneratorResNet, self).__init__()
channels = input_shape[0]
# Initial convolution block
out_features = 64
model = [
nn.ReflectionPad2d(channels),
nn.Conv2d(channels, out_features, 7),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Downsampling
for _ in range(2):
out_features *= 2
model += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Residual blocks
for _ in range(num_residual_blocks):
model += [ResidualBlock(out_features)]
# Upsampling
for _ in range(2):
out_features //= 2
model += [
nn.Upsample(scale_factor=2),
nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Output layer
model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
2.2.2 判別器
class Discriminator(nn.Module):
def __init__(self, input_shape):
super(Discriminator, self).__init__()
channels, height, width = input_shape
# Calculate output shape of image discriminator (PatchGAN)
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
def discriminator_block(in_filters, out_filters, normalize=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(channels, 64, normalize=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1)
)
def forward(self, img):
return self.model(img)

2.2.3 訓(xùn)練
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
cuda = torch.cuda.is_available()
input_shape = (opt.channels, opt.img_height, opt.img_width)
# Initialize generator and discriminator
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
# Optimizers
optimizer_G = torch.optim.Adam(
itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Learning rate update schedulers
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
optimizer_G, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_A, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
optimizer_D_B, lr_lambda=LambdaLR(opt.n_epochs, opt.epoch, opt.decay_epoch).step
)
訓(xùn)練數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的。A是原圖,B是風(fēng)格圖 生成器訓(xùn)練 GAN loss:判別器判別A和B生成的兩個(gè)圖fake_A、fake_B與GT的loss Cycle loss:反過(guò)來(lái)fake_A和fake_B 生成的圖與A和B像素上差異 判別器訓(xùn)練: loss_real: 判別A/B和GT的MSELoss loss_fake:判別生成的fake_A/fake_B與GT的MSELoss
for epoch in range(opt.epoch, opt.n_epochs):
for i, batch in enumerate(dataloader):
# 數(shù)據(jù)是成對(duì)的數(shù)據(jù),但是是非配對(duì)的數(shù)據(jù),即A和B是沒有直接的聯(lián)系的
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))
# Adversarial ground truths
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
# ------------------
# Train Generators
# ------------------
G_AB.train()
G_BA.train()
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()
# -----------------------
# Train Discriminator A
# -----------------------
optimizer_D_A.zero_grad()
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
# fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
# -----------------------
# Train Discriminator B
# -----------------------
optimizer_D_B.zero_grad()
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
# fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
loss_D_B.backward()
optimizer_D_B.step()
loss_D = (loss_D_A + loss_D_B) / 2
# --------------
# Log Progress
# --------------
# Determine approximate time left
batches_done = epoch * len(dataloader) + i
batches_left = opt.n_epochs * len(dataloader) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()
# Update learning rates
lr_scheduler_G.step()
lr_scheduler_D_A.step()
lr_scheduler_D_B.step()2.2.4 結(jié)果展示


2.2.5 cycleGAN其他用途


3. 總結(jié)
GAN是學(xué)習(xí)數(shù)據(jù)中分布,并生成同樣分布但全新的數(shù)據(jù) CycleGAN是兩路GAN:兩個(gè)生成器和兩個(gè)判別器;為了保證生成器的生成的圖片與輸入圖存在一定的關(guān)系,不是隨機(jī)生產(chǎn)的圖片, 引入cycle consistency,判定A->fake_B->recove_A和A的差異 生成器:下采樣 + residual 殘差block + 上采樣 判別器: 不是一個(gè)圖生成一個(gè)判定值,而是patchGAN方式,生成很N*N個(gè)值,而后取均值
作者簡(jiǎn)介:wedo實(shí)驗(yàn)君, 數(shù)據(jù)分析師;熱愛生活,熱愛寫作
贊 賞 作 者

更多閱讀
特別推薦

點(diǎn)擊下方閱讀原文加入社區(qū)會(huì)員
評(píng)論
圖片
表情
