PyTorch: Implementing Data Replication with DCGAN
When Ian J. Goodfellow first proposed the GAN, the generator and discriminator were still plain neural networks rather than deep convolutional ones. A generative adversarial network built on deep convolutional networks was later proposed, and that is DCGAN. Compared with the original GAN, DCGAN makes the following changes to the generator and discriminator networks:
1. Replace pooling with strided convolutions and transposed convolutions for down- and up-sampling (a quick shape check is sketched after this list).
2. Use batch normalization (BN) layers in both the generator and the discriminator.
3. Remove the fully-connected layers.
4. Use ReLU as the activation function in the generator, with tanh in the final layer.
5. Use LeakyReLU as the activation function in the discriminator.
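As a quick illustration of point 1 (not part of the original article), the snippet below checks the shapes produced by a strided convolution and a transposed convolution with the kernel size 4, stride 2, padding 1 configuration used throughout the networks in this post:

import torch
import torch.nn as nn

# Strided convolution: halves the spatial size (replaces pooling in the discriminator)
down = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
print(down(torch.randn(1, 3, 64, 64)).shape)   # torch.Size([1, 64, 32, 32])

# Transposed convolution: doubles the spatial size (up-sampling in the generator)
up = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)
print(up(torch.randn(1, 64, 32, 32)).shape)    # torch.Size([1, 3, 64, 64])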
The generator network is shown below:

The DCGAN is trained on the celebA face dataset (roughly 200,000 face images), and the generator model is saved at the end. Below are the DCGAN code implementation, the training, and the "replication" of face data using the trained generator.
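The code in this post relies on a number of globals (nc, nz, ngf, ndf, lr, beta1, num_epochs, ngpu, device, dataloader) that are not shown in the excerpt. A minimal sketch of that setup, with illustrative paths and values that follow the defaults of the official PyTorch DCGAN tutorial, might look like this:

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Hyper-parameters assumed by the networks and training loop below
nc, nz = 3, 100          # image channels, latent vector size
ngf, ndf = 64, 64        # feature-map widths of generator / discriminator
image_size, batch_size = 64, 128
num_epochs, lr, beta1 = 3, 0.0002, 0.5
ngpu = 1

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# celebA images are expected in a subfolder of dataroot, e.g. dataroot/img_align_celeba/*.jpg
dataroot = "./data/celeba"   # illustrative path
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)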
The generator convolutional network is implemented as follows:
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)
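A quick sanity check (assuming the globals from the setup sketch above): a batch of latent vectors of shape (N, nz, 1, 1) is mapped to (N, 3, 64, 64) images in the tanh range [-1, 1].

# assumes torch, nz and the Generator class defined above
netG_check = Generator(ngpu=1)
z = torch.randn(16, nz, 1, 1)
print(netG_check(z).shape)   # torch.Size([16, 3, 64, 64])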
The discriminator convolutional network is implemented as follows:
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

Initialization and model training are implemented as follows:
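The training code below creates netD, calls weights_init, and optimizes netG, but neither weights_init nor the creation of netG appears in this excerpt. A sketch consistent with the DCGAN paper's initialization (normal distribution with mean 0 and std 0.02 for conv layers, mean 1 for BatchNorm scales) would be:

import torch.nn as nn

# Custom weight initialization applied to netG and netD (assumed, following the DCGAN paper)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Create the generator (the discriminator is created in the code below)
netG = Generator(ngpu).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))
netG.apply(weights_init)
print(netG)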
# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)


# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))


# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

if __name__ == "__main__":
    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):

            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            netD.zero_grad()
            # Format batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = criterion(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()
            D_x = output.mean().item()

            ## Train with all-fake batch
            # Generate batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate fake image batch with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify all fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost
            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save Losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                # keep a grid of samples for later inspection (requires: import torchvision.utils as vutils)
                img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
            iters += 1

        # save model
        netG.eval()
        torch.save(netG, 'generate_model.pt')
        netG.train()  # switch back to training mode before the next epoch

Please forgive my laziness here: this example is essentially the code from the official PyTorch DCGAN tutorial, copy-pasted and then heavily modified, which is why all the comments are in English. I trained it on the celebA face dataset; since my machine only has a GTX 1050 Ti, I stopped after just 3 epochs (mainly to save electricity and be eco-friendly @_@) and then saved the generator model.
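Since G_losses and D_losses are collected during training "for plotting later", a short optional matplotlib snippet (not part of the original code) can be used to visualize how both losses evolve:

import matplotlib.pyplot as plt

# assumes G_losses and D_losses from the training loop above
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("loss")
plt.legend()
plt.show()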
Loading the saved DCGAN generator model directly and feeding it random noise (100 numbers per input sample) produces the face images shown below:




As the images above show, the results are already starting to take shape; if you are interested you can keep training. The demo code is as follows:
import torch
import cv2 as cv
import numpy as np
from dcgan_model import Generator
from torchvision.utils import save_image


def dcgan_generate_face_demo():
    netG = torch.load("./generate_model.pt")
    netG.cuda()
    for i in range(4):
        noise = torch.randn(64, 100, 1, 1, device="cuda")
        # Generate fake image batch with G
        generated = netG(noise)
        print(generated.size())
        # normalize=True rescales the tanh output from [-1, 1] to [0, 1] before saving
        save_image(generated.view(generated.size(0), 3, 64, 64),
                   'D:/sample_%d' % i + '.png', normalize=True)


if __name__ == "__main__":
    dcgan_generate_face_demo()

