輕松學Pytorch – 構(gòu)建生成對抗網(wǎng)絡(luò)
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
又好久沒有繼續(xù)寫了,這個是我寫的第21篇文章,我還在繼續(xù)堅持寫下去,雖然經(jīng)常各種拖延癥,但是我還記得,一直沒有敢忘記!今天給大家分享一下Pytorch生成對抗網(wǎng)絡(luò)代碼實現(xiàn)。
?
Ian J. Goodfellow在2014年提出生成對抗網(wǎng)絡(luò),從此打開了深度學習中另外一個重要分支,讓生成對抗網(wǎng)絡(luò)(GAN)成為與卷積神經(jīng)網(wǎng)絡(luò)(CNN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN/LSTM)可以并駕齊驅(qū)的分支領(lǐng)域。今天GAN仍然是計算機視覺領(lǐng)域研究熱點之一,每年還有大量相關(guān)的論文產(chǎn)生,GAN已經(jīng)被用在視覺任務(wù)的很多方面,主要包括:
圖像合成與數(shù)據(jù)增廣
圖像翻譯與變換
缺陷檢測
圖像去噪與重建
圖像分割
但是GAN最基本的核心思想還是2014年Ian J. Goodfellow在論文中提到的兩個基本的模型分別是:生成器與判別器

生成器(G):
根據(jù)輸入噪聲Z生成輸出樣本G(z)目標:通過生成樣本與目標樣本分布一致,成功欺騙鑒別器
判別器(D):
根據(jù)輸入樣本數(shù)據(jù)來分辨真實樣本概率從數(shù)據(jù)中學習樣本數(shù)據(jù)的差異性
從a到d,可以看到輸入噪聲的生成分布越來越接近真實分布X,最終達到一種平衡狀態(tài),這種穩(wěn)定的平衡狀態(tài)叫納什均衡,還有一部電影跟這個有關(guān)系叫《美麗心靈》。
?
下面的代碼實現(xiàn)了基于Mnist數(shù)據(jù)集實現(xiàn)判別器與生成器,最終通過生成器可以自動生成手寫數(shù)字識別的圖像,輸入的z=100是隨機噪聲,輸出的是784個數(shù)據(jù)表示28x28大小的手寫數(shù)字樣本,損失主要來自兩個部分,生成器生成損失,判別器分別判別真實與虛構(gòu)樣本概率,基于反向傳播訓練兩個網(wǎng)絡(luò),設(shè)置epoch=100,得到最終的生成器生成結(jié)果如下:

生成器與判別器代碼實現(xiàn)如下
判別器與生成器代碼:(后面文字忽略)2004論文中提出,其主要思想可以通過下面一張圖像解釋:
1transform?=?tv.transforms.Compose([tv.transforms.ToTensor(),
2???????????????????????????????????tv.transforms.Normalize((0.5,),?(0.5,))])
3train_ts?=?tv.datasets.MNIST(root='./data',?train=True,?download=True,?transform=transform)
4test_ts?=?tv.datasets.MNIST(root='./data',?train=False,?download=True,?transform=transform)
5train_dl?=?DataLoader(train_ts,?batch_size=128,?shuffle=True,?drop_last=False)
6test_dl?=?DataLoader(test_ts,?batch_size=128,?shuffle=True,?drop_last=False)
7
8
9class?Generator(t.nn.Module):
10????def?__init__(self,?g_input_dim,?g_output_dim):
11????????super(Generator,?self).__init__()
12????????self.fc1?=?t.nn.Linear(g_input_dim,?256)
13????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?*?2)
14????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?*?2)
15????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?g_output_dim)
16
17????#?forward?method
18????def?forward(self,?x):
19????????x?=?F.leaky_relu(self.fc1(x),?0.2)
20????????x?=?F.leaky_relu(self.fc2(x),?0.2)
21????????x?=?F.leaky_relu(self.fc3(x),?0.2)
22????????return?t.tanh(self.fc4(x))
23
24
25class?Discriminator(t.nn.Module):
26????def?__init__(self,?d_input_dim):
27????????super(Discriminator,?self).__init__()
28????????self.fc1?=?t.nn.Linear(d_input_dim,?1024)
29????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?//?2)
30????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?//?2)
31????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?1)
32
33????#?forward?method
34????def?forward(self,?x):
35????????x?=?F.leaky_relu(self.fc1(x),?0.2)
36????????x?=?F.dropout(x,?0.3)
37????????x?=?F.leaky_relu(self.fc2(x),?0.2)
38????????x?=?F.dropout(x,?0.3)
39????????x?=?F.leaky_relu(self.fc3(x),?0.2)
40????????x?=?F.dropout(x,?0.3)
41????????return?t.sigmoid(self.fc4(x))損失與訓練代碼如下
分別定義生成網(wǎng)絡(luò)訓練與鑒別網(wǎng)絡(luò)的訓練方法,然后開始訓練即可,代碼實現(xiàn)如下:
1#?生成者與判別者
2bs?=?128
3z_dim?=?100
4mnist_dim?=?784
5#?loss
6criterion?=?t.nn.BCELoss()
7
8#?optimizer
9device?=?"cuda"
10gnet?=?Generator(g_input_dim?=?z_dim,?g_output_dim?=?mnist_dim).to(device)
11dnet?=?Discriminator(mnist_dim).to(device)
12lr?=?0.0002
13G_optimizer?=?t.optim.Adam(gnet.parameters(),?lr=lr)
14D_optimizer?=?t.optim.Adam(dnet.parameters(),?lr=lr)
15
16
17def?D_train(x):
18????#?=======================Train?the?discriminator=======================#
19????dnet.zero_grad()
20
21????#?train?discriminator?on?real
22????x_real,?y_real?=?x.view(-1,?mnist_dim),?t.ones(bs,?1)
23????x_real,?y_real?=?Variable(x_real.to(device)),?Variable(y_real.to(device))
24
25????D_output?=?dnet(x_real)
26????D_real_loss?=?criterion(D_output,?y_real)
27
28????#?train?discriminator?on?facke
29????z?=?Variable(t.randn(bs,?z_dim).to(device))
30????x_fake,?y_fake?=?gnet(z),?Variable(t.zeros(bs,?1).to(device))
31
32????D_output?=?dnet(x_fake)
33????D_fake_loss?=?criterion(D_output,?y_fake)
34
35????#?gradient?backprop?&?optimize?ONLY?D's?parameters
36????D_loss?=?D_real_loss?+?D_fake_loss
37????D_loss.backward()
38????D_optimizer.step()
39
40????return?D_loss.data.item()
41
42
43def?G_train(x):
44????#?=======================Train?the?generator=======================#
45????gnet.zero_grad()
46
47????z?=?Variable(t.randn(bs,?z_dim).to(device))
48????y?=?Variable(t.ones(bs,?1).to(device))
49
50????G_output?=?gnet(z)
51????D_output?=?dnet(G_output)
52????G_loss?=?criterion(D_output,?y)
53
54????#?gradient?backprop?&?optimize?ONLY?G's?parameters
55????G_loss.backward()
56????G_optimizer.step()
57
58????return?G_loss.data.item()
59
60
61n_epoch?=?100
62for?epoch?in?range(1,?n_epoch+1):
63????D_losses,?G_losses?=?[],?[]
64????for?batch_idx,?(x,?_)?in?enumerate(train_dl):
65????????bs_,?_,_,_?=?x.size()
66????????bs?=?bs_
67????????D_losses.append(D_train(x))
68????????G_losses.append(G_train(x))
69
70????print('[%d/%d]:?loss_d:?%.3f,?loss_g:?%.3f'?%?(
71????????????(epoch),?n_epoch,?t.mean(t.FloatTensor(D_losses)),?t.mean(t.FloatTensor(G_losses))))下載1:OpenCV-Contrib擴展模塊中文版教程 在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。 下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

