<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          輕松學Pytorch – 構(gòu)建生成對抗網(wǎng)絡(luò)

          共 5229字,需瀏覽 11分鐘

           ·

          2022-05-24 10:10

          點擊上方小白學視覺”,選擇加"星標"或“置頂

          重磅干貨,第一時間送達

          又好久沒有繼續(xù)寫了,這個是我寫的第21篇文章,我還在繼續(xù)堅持寫下去,雖然經(jīng)常各種拖延癥,但是我還記得,一直沒有敢忘記!今天給大家分享一下Pytorch生成對抗網(wǎng)絡(luò)代碼實現(xiàn)。

          ?

          01.什么是生成對抗網(wǎng)絡(luò)


          Ian J. Goodfellow在2014年提出生成對抗網(wǎng)絡(luò),從此打開了深度學習中另外一個重要分支,讓生成對抗網(wǎng)絡(luò)(GAN)成為與卷積神經(jīng)網(wǎng)絡(luò)(CNN)、循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN/LSTM)可以并駕齊驅(qū)的分支領(lǐng)域。今天GAN仍然是計算機視覺領(lǐng)域研究熱點之一,每年還有大量相關(guān)的論文產(chǎn)生,GAN已經(jīng)被用在視覺任務(wù)的很多方面,主要包括:

          • 圖像合成與數(shù)據(jù)增廣

          • 圖像翻譯與變換

          • 缺陷檢測

          • 圖像去噪與重建

          • 圖像分割

          但是GAN最基本的核心思想還是2014年Ian J. Goodfellow在論文中提到的兩個基本的模型分別是:生成器與判別器

          生成器(G):

          根據(jù)輸入噪聲Z生成輸出樣本G(z)目標:通過生成樣本與目標樣本分布一致,成功欺騙鑒別器

          判別器(D):

          根據(jù)輸入樣本數(shù)據(jù)來分辨真實樣本概率從數(shù)據(jù)中學習樣本數(shù)據(jù)的差異性

          從a到d,可以看到輸入噪聲的生成分布越來越接近真實分布X,最終達到一種平衡狀態(tài),這種穩(wěn)定的平衡狀態(tài)叫納什均衡,還有一部電影跟這個有關(guān)系叫《美麗心靈》。

          ?

          02.GAN代碼實現(xiàn)


          下面的代碼實現(xiàn)了基于Mnist數(shù)據(jù)集實現(xiàn)判別器與生成器,最終通過生成器可以自動生成手寫數(shù)字識別的圖像,輸入的z=100是隨機噪聲,輸出的是784個數(shù)據(jù)表示28x28大小的手寫數(shù)字樣本,損失主要來自兩個部分,生成器生成損失,判別器分別判別真實與虛構(gòu)樣本概率,基于反向傳播訓練兩個網(wǎng)絡(luò),設(shè)置epoch=100,得到最終的生成器生成結(jié)果如下:


          生成器與判別器代碼實現(xiàn)如下


          判別器與生成器代碼:(后面文字忽略)2004論文中提出,其主要思想可以通過下面一張圖像解釋:

           1transform?=?tv.transforms.Compose([tv.transforms.ToTensor(),
          2???????????????????????????????????tv.transforms.Normalize((0.5,),?(0.5,))])
          3train_ts?=?tv.datasets.MNIST(root='./data',?train=True,?download=True,?transform=transform)
          4test_ts?=?tv.datasets.MNIST(root='./data',?train=False,?download=True,?transform=transform)
          5train_dl?=?DataLoader(train_ts,?batch_size=128,?shuffle=True,?drop_last=False)
          6test_dl?=?DataLoader(test_ts,?batch_size=128,?shuffle=True,?drop_last=False)
          7
          8
          9class?Generator(t.nn.Module):
          10????def?__init__(self,?g_input_dim,?g_output_dim):
          11????????super(Generator,?self).__init__()
          12????????self.fc1?=?t.nn.Linear(g_input_dim,?256)
          13????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?*?2)
          14????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?*?2)
          15????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?g_output_dim)
          16
          17????#?forward?method
          18????def?forward(self,?x):
          19????????x?=?F.leaky_relu(self.fc1(x),?0.2)
          20????????x?=?F.leaky_relu(self.fc2(x),?0.2)
          21????????x?=?F.leaky_relu(self.fc3(x),?0.2)
          22????????return?t.tanh(self.fc4(x))
          23
          24
          25class?Discriminator(t.nn.Module):
          26????def?__init__(self,?d_input_dim):
          27????????super(Discriminator,?self).__init__()
          28????????self.fc1?=?t.nn.Linear(d_input_dim,?1024)
          29????????self.fc2?=?t.nn.Linear(self.fc1.out_features,?self.fc1.out_features?//?2)
          30????????self.fc3?=?t.nn.Linear(self.fc2.out_features,?self.fc2.out_features?//?2)
          31????????self.fc4?=?t.nn.Linear(self.fc3.out_features,?1)
          32
          33????#?forward?method
          34????def?forward(self,?x):
          35????????x?=?F.leaky_relu(self.fc1(x),?0.2)
          36????????x?=?F.dropout(x,?0.3)
          37????????x?=?F.leaky_relu(self.fc2(x),?0.2)
          38????????x?=?F.dropout(x,?0.3)
          39????????x?=?F.leaky_relu(self.fc3(x),?0.2)
          40????????x?=?F.dropout(x,?0.3)
          41????????return?t.sigmoid(self.fc4(x))


          損失與訓練代碼如下


          分別定義生成網(wǎng)絡(luò)訓練與鑒別網(wǎng)絡(luò)的訓練方法,然后開始訓練即可,代碼實現(xiàn)如下:

           1#?生成者與判別者
          2bs?=?128
          3z_dim?=?100
          4mnist_dim?=?784
          5#?loss
          6criterion?=?t.nn.BCELoss()
          7
          8#?optimizer
          9device?=?"cuda"
          10gnet?=?Generator(g_input_dim?=?z_dim,?g_output_dim?=?mnist_dim).to(device)
          11dnet?=?Discriminator(mnist_dim).to(device)
          12lr?=?0.0002
          13G_optimizer?=?t.optim.Adam(gnet.parameters(),?lr=lr)
          14D_optimizer?=?t.optim.Adam(dnet.parameters(),?lr=lr)
          15
          16
          17def?D_train(x):
          18????#?=======================Train?the?discriminator=======================#
          19????dnet.zero_grad()
          20
          21????#?train?discriminator?on?real
          22????x_real,?y_real?=?x.view(-1,?mnist_dim),?t.ones(bs,?1)
          23????x_real,?y_real?=?Variable(x_real.to(device)),?Variable(y_real.to(device))
          24
          25????D_output?=?dnet(x_real)
          26????D_real_loss?=?criterion(D_output,?y_real)
          27
          28????#?train?discriminator?on?facke
          29????z?=?Variable(t.randn(bs,?z_dim).to(device))
          30????x_fake,?y_fake?=?gnet(z),?Variable(t.zeros(bs,?1).to(device))
          31
          32????D_output?=?dnet(x_fake)
          33????D_fake_loss?=?criterion(D_output,?y_fake)
          34
          35????#?gradient?backprop?&?optimize?ONLY?D's?parameters
          36????D_loss?=?D_real_loss?+?D_fake_loss
          37????D_loss.backward()
          38????D_optimizer.step()
          39
          40????return?D_loss.data.item()
          41
          42
          43def?G_train(x):
          44????#?=======================Train?the?generator=======================#
          45????gnet.zero_grad()
          46
          47????z?=?Variable(t.randn(bs,?z_dim).to(device))
          48????y?=?Variable(t.ones(bs,?1).to(device))
          49
          50????G_output?=?gnet(z)
          51????D_output?=?dnet(G_output)
          52????G_loss?=?criterion(D_output,?y)
          53
          54????#?gradient?backprop?&?optimize?ONLY?G's?parameters
          55????G_loss.backward()
          56????G_optimizer.step()
          57
          58????return?G_loss.data.item()
          59
          60
          61n_epoch?=?100
          62for?epoch?in?range(1,?n_epoch+1):
          63????D_losses,?G_losses?=?[],?[]
          64????for?batch_idx,?(x,?_)?in?enumerate(train_dl):
          65????????bs_,?_,_,_?=?x.size()
          66????????bs?=?bs_
          67????????D_losses.append(D_train(x))
          68????????G_losses.append(G_train(x))
          69
          70????print('[%d/%d]:?loss_d:?%.3f,?loss_g:?%.3f'?%?(
          71????????????(epoch),?n_epoch,?t.mean(t.FloatTensor(D_losses)),?t.mean(t.FloatTensor(G_losses))))



          下載1:OpenCV-Contrib擴展模塊中文版教程
          在「小白學視覺」公眾號后臺回復:擴展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

          下載2:Python視覺實戰(zhàn)項目52講
          小白學視覺公眾號后臺回復:Python視覺實戰(zhàn)項目即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。

          下載3:OpenCV實戰(zhàn)項目20講
          小白學視覺公眾號后臺回復:OpenCV實戰(zhàn)項目20講即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。

          交流群


          歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三?+?上海交大?+?視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


          瀏覽 51
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  三四级日逼视频关看 | 日韩无码免费视频 | 亚洲午夜无码久久久久蜜桃AV | 黄色α视频| 久热福利在线 |