<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          使用Pytorch實(shí)現(xiàn)頻譜歸一化生成對(duì)抗網(wǎng)絡(luò)(SN-GAN)

          共 12138字,需瀏覽 25分鐘

           ·

          2023-11-02 09:12

             
             
          來源:DeepHub IMBA

          本文約3800字,建議閱讀5分鐘

          自從擴(kuò)散模型發(fā)布以來,GAN的關(guān)注度和論文是越來越少了,但是它們里面的一些思路還是值得我們了解和學(xué)習(xí)。所以本文我們來使用Pytorch 來實(shí)現(xiàn)SN-GAN。

          譜歸一化生成對(duì)抗網(wǎng)絡(luò)是一種生成對(duì)抗網(wǎng)絡(luò),它使用譜歸一化技術(shù)來穩(wěn)定鑒別器的訓(xùn)練。譜歸一化是一種權(quán)值歸一化技術(shù),它約束了鑒別器中每一層的譜范數(shù)。這有助于防止鑒別器變得過于強(qiáng)大,從而導(dǎo)致不穩(wěn)定和糟糕的結(jié)果。
          SN-GAN由Miyato等人(2018)在論文“生成對(duì)抗網(wǎng)絡(luò)的譜歸一化”中提出,作者證明了sn - gan在各種圖像生成任務(wù)上比其他gan具有更好的性能。
          SN-GAN的訓(xùn)練方式與其他gan相同。生成器網(wǎng)絡(luò)學(xué)習(xí)生成與真實(shí)圖像無法區(qū)分的圖像,而鑒別器網(wǎng)絡(luò)學(xué)習(xí)區(qū)分真實(shí)圖像和生成圖像。這兩個(gè)網(wǎng)絡(luò)以競爭的方式進(jìn)行訓(xùn)練,它們最終達(dá)到一個(gè)點(diǎn),即生成器能夠產(chǎn)生逼真的圖像,從而欺騙鑒別器。
          以下是SN-GAN相對(duì)于其他gan的優(yōu)勢(shì)總結(jié):
          • 更穩(wěn)定,更容易訓(xùn)練
          • 可以生成更高質(zhì)量的圖像
          • 更通用,可以用來生成更廣泛的內(nèi)容。


          模式崩潰


          模式崩潰是生成對(duì)抗網(wǎng)絡(luò)(GANs)訓(xùn)練中常見的問題。當(dāng)GAN的生成器網(wǎng)絡(luò)無法產(chǎn)生多樣化的輸出,而是陷入特定的模式時(shí),就會(huì)發(fā)生模式崩潰。這會(huì)導(dǎo)致生成的輸出出現(xiàn)重復(fù),缺乏多樣性和細(xì)節(jié),有時(shí)甚至與訓(xùn)練數(shù)據(jù)完全無關(guān)。
          GAN中發(fā)生模式崩潰有幾個(gè)原因。一個(gè)原因是生成器網(wǎng)絡(luò)可能對(duì)訓(xùn)練數(shù)據(jù)過擬合。如果訓(xùn)練數(shù)據(jù)不夠多樣化,或者生成器網(wǎng)絡(luò)太復(fù)雜,就會(huì)發(fā)生這種情況。另一個(gè)原因是生成器網(wǎng)絡(luò)可能陷入損失函數(shù)的局部最小值。如果學(xué)習(xí)率太高,或者損失函數(shù)定義不明確,就會(huì)發(fā)生這種情況。
          以前有許多技術(shù)可以用來防止模式崩潰。比如使用更多樣化的訓(xùn)練數(shù)據(jù)集。或者使用正則化技術(shù),例如dropout或批處理歸一化,使用合適的學(xué)習(xí)率和損失函數(shù)也很重要。

          Wassersteian損失


          Wasserstein損失,也稱為Earth Mover’s Distance(EMD)或Wasserstein GAN (WGAN)損失,是一種用于生成對(duì)抗網(wǎng)絡(luò)(GAN)的損失函數(shù)。引入它是為了解決與傳統(tǒng)GAN損失函數(shù)相關(guān)的一些問題,例如Jensen-Shannon散度和Kullback-Leibler散度。
          Wasserstein損失測量真實(shí)數(shù)據(jù)和生成數(shù)據(jù)的概率分布之間的差異,同時(shí)確保它具有一定的數(shù)學(xué)性質(zhì)。他的思想是最小化這兩個(gè)分布之間的Wassersteian距離(也稱為地球移動(dòng)者距離)。Wasserstein距離可以被認(rèn)為是將一個(gè)分布轉(zhuǎn)換為另一個(gè)分布所需的最小“成本”,其中“成本”被定義為將概率質(zhì)量從一個(gè)位置移動(dòng)到另一個(gè)位置所需的“工作量”。
          Wasserstein損失的數(shù)學(xué)定義如下:
          對(duì)于生成器G和鑒別器D, Wasserstein損失(Wasserstein距離)可以表示為:
          Jensen-Shannon散度(JSD): Jensen-Shannon散度是一種對(duì)稱度量,用于量化兩個(gè)概率分布之間的差異
          對(duì)于概率分布P和Q, JSD定義如下:
             
             
           JSD(P∥Q)=1/2(KL(P∥M)+KL(Q∥M))
          M為平均分布,KL為Kullback-Leibler散度,P∥Q為分布P與分布Q之間的JSD。
          JSD總是非負(fù)的,在0和1之間有界,并且對(duì)稱(JSD(P|Q) = JSD(Q|P))。它可以被解釋為KL散度的“平滑”版本。
          Kullback-Leibler散度(KL散度):Kullback-Leibler散度,通常被稱為KL散度或相對(duì)熵,通過量化“額外信息”來測量兩個(gè)概率分布之間的差異,這些“額外信息”需要使用另一個(gè)分布作為參考來編碼一個(gè)分布。
          對(duì)于兩個(gè)概率分布P和Q,從Q到P的KL散度定義為:KL(P∥Q)=∑x P(x)log(Q(x)/P(x))。KL散度是非負(fù)非對(duì)稱的,即KL(P∥Q)≠KL(Q∥P)。當(dāng)且僅當(dāng)P和Q相等時(shí)它為零。KL散度是無界的,可以用來衡量分布之間的不相似性。

          1-Lipschitz Contiunity

          1- lipschitz函數(shù)是斜率的絕對(duì)值以1為界的函數(shù)。這意味著對(duì)于任意兩個(gè)輸入x和y,函數(shù)輸出之間的差不超過輸入之間的差。
          數(shù)學(xué)上函數(shù)f是1-Lipschitz,如果對(duì)于f定義域內(nèi)的所有x和y,以下不等式成立:
             
             
           |f(x) — f(y)| <= |x — y|
          在生成對(duì)抗網(wǎng)絡(luò)(GANs)中強(qiáng)制Lipschitz連續(xù)性是一種用于穩(wěn)定訓(xùn)練和防止與傳統(tǒng)GANs相關(guān)的一些問題的技術(shù),例如模式崩潰和訓(xùn)練不穩(wěn)定。在GAN中實(shí)現(xiàn)Lipschitz連續(xù)性的主要方法是通過使用Lipschitz約束或正則化,一種常用的方法是Wasserstein GAN (WGAN)。
          在標(biāo)準(zhǔn)gan中,鑒別器(也稱為WGAN中的批評(píng)家)被訓(xùn)練來區(qū)分真實(shí)和虛假數(shù)據(jù)。為了加強(qiáng)Lipschitz連續(xù)性,WGAN增加了一個(gè)約束,即鑒別器函數(shù)應(yīng)該是Lipschitz連續(xù)的,這意味著函數(shù)的梯度不應(yīng)該增長得太大。在數(shù)學(xué)上,它被限制為:
             
             
           ∥∣D(x)?D(y)∣≤K?∥x?y∥
          其中D(x)是評(píng)論家對(duì)數(shù)據(jù)點(diǎn)x的輸出,D(y)是y的輸出,K是Lipschitz 常數(shù)。
          WGAN的權(quán)重裁剪:在原始的WGAN中,通過在每個(gè)訓(xùn)練步驟后將鑒別器網(wǎng)絡(luò)的權(quán)重裁剪到一個(gè)小范圍(例如,[-0.01,0.01])來強(qiáng)制執(zhí)行該約束。權(quán)重裁剪確保了鑒別器的梯度保持在一定范圍內(nèi),并加強(qiáng)了利普希茨連續(xù)性。
          WGAN的梯度懲罰: WGAN的一種變體,稱為WGAN-GP,它使用梯度懲罰而不是權(quán)值裁剪來強(qiáng)制Lipschitz約束。WGAN-GP基于鑒別器的輸出相對(duì)于真實(shí)和虛假數(shù)據(jù)之間的隨機(jī)點(diǎn)的梯度,在損失函數(shù)中添加了一個(gè)懲罰項(xiàng)。這種懲罰鼓勵(lì)了Lipschitz約束,而不需要權(quán)重裁剪。

          譜范數(shù)


          從符號(hào)上看矩陣??的譜范數(shù)通常表示為:對(duì)于神經(jīng)網(wǎng)絡(luò)??矩陣表示網(wǎng)絡(luò)層中的一個(gè)權(quán)重矩陣。矩陣的譜范數(shù)是矩陣的最大奇異值,可以通過奇異值分解(SVD)得到。
          奇異值分解是特征分解的推廣,用于將矩陣分解為
          其中??,q為正交矩陣,Σ為其對(duì)角線上的奇異值矩陣。注意Σ不一定是正方形的。
          其中??1和??分別為最大奇異值和最小奇異值。更大的值對(duì)應(yīng)于一個(gè)矩陣可以應(yīng)用于另一個(gè)向量的更大的拉伸量。依此表示,??(??)=??1.
          SVD在譜歸一化中的應(yīng)用
          為了對(duì)權(quán)矩陣進(jìn)行頻譜歸一化,將矩陣中的每個(gè)值除以它的頻譜范數(shù)。譜歸一化矩陣可以表示為
          計(jì)算??is的SVD非常昂貴,所以SN-GAN論文的作者做了一些簡化。它們通過冪次迭代來近似左、右奇異向量??和??,分別為:??)≈??

          代碼實(shí)現(xiàn)


          現(xiàn)在我們開始使用Pytorch實(shí)現(xiàn)
             
             
           import torch from torch import nn from tqdm.auto import tqdm from torchvision import transforms from torchvision.datasets import MNIST from torchvision.utils import make_grid from torch.utils.data import DataLoader import matplotlib.pyplot as plt torch.manual_seed(0)  def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):    image_tensor = (image_tensor + 1) / 2    image_unflat = image_tensor.detach().cpu()    image_grid = make_grid(image_unflat[:num_images], nrow=5)    plt.imshow(image_grid.permute(1, 2, 0).squeeze())    plt.show()

          生成器:


          class Generator(nn.Module): def __init__(self,z_dim=10,im_chan = 1,hidden_dim = 64): super(Generatoe,self).__init__() self.gen = nn.Sequential( self.make_gen_block(z_dim,hidden_dim * 4), self.make_gen_block(hidden_dim*4,hidden_dim * 2,kernel_size = 4,stride =1), self.make_gen_block(hidden_dim * 2,hidden_dim), self.make_gen_block(hidden_dim,im_chan,kernel_size=4,final_layer = True), ) def make_gen_block(self,input_channels,output_channels,kernel_size=3,stride=2,final_layer = False): if not final_layer : return nn.Sequential(nn.ConvTranspose2D(input_layer,output_layer,kernel_size,stride), nn.BatchNorm2d(output_channels), nn.ReLU(inplace = True), ) else: return nn.Sequential(nn.ConvTranspose2D(input_layer,output_layer,kernel_size,stride), nn.Tanh(),) def unsqueeze_noise(): return noise.view(len(noise), self.z_dim, 1, 1) def forward(self,noise): x = self.unsqueeze_noise(noise) return self.gen(x) def get_noise(n_samples, z_dim, device='cpu'): return torch.randn(n_samples, z_dim, device=device)
          鑒頻器
          對(duì)于鑒別器,我們可以使用spectral_norm對(duì)每個(gè)Conv2D 進(jìn)行處理。除了??之外,還引入了??、??、和其他的參數(shù),這樣在運(yùn)行時(shí)就可以計(jì)算出????的二進(jìn)制二進(jìn)制運(yùn)算符:??、y、y、y、y
          因?yàn)镻ytorch還提供 nn.utils. spectral_norm,nn.utils. remove_spectral_norm函數(shù),所以我們操作起來很方便。
          我們只在推理期間將nn.utils. remove_spectral_norm應(yīng)用于卷積層,以提高運(yùn)行速度。
          值得注意的是,譜范數(shù)并不能消除對(duì)批范數(shù)的需要。譜范數(shù)影響每一層的權(quán)重,批范數(shù)影響每一層的激活度。
             
             
           class Discriminator(nn.Module):      def __init__(self, im_chan=1, hidden_dim=16):        super(Discriminator, self).__init__()        self.disc = nn.Sequential(            self.make_disc_block(im_chan, hidden_dim),            self.make_disc_block(hidden_dim, hidden_dim * 2),            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),        )      def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):        if not final_layer:            return nn.Sequential(                nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),                nn.BatchNorm2d(output_channels),                nn.LeakyReLU(0.2, inplace=True),            )        else:            return nn.Sequential(                nn.utils.spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size, stride)),            )    def forward(self, image):        disc_pred = self.disc(image)        return disc_pred.view(len(disc_pred), -1)


          訓(xùn)練


          我們這里使用MNIST數(shù)據(jù)集,bcewithlogitsloss()函數(shù)計(jì)算logit和目標(biāo)標(biāo)簽之間的二進(jìn)制交叉熵?fù)p失。二值交叉熵?fù)p失是對(duì)兩個(gè)分布差異程度的度量。在二元分類中,這兩種分布分別是邏輯的分布和目標(biāo)標(biāo)簽的分布。
             
             
           criterion = nn.BCEWithLogitsLoss() n_epochs = 50 z_dim = 64 display_step = 500 batch_size = 128 # A learning rate of 0.0002 works well on DCGAN lr = 0.0002 beta_1 = 0.5 beta_2 = 0.999 device = 'cuda' transform = transforms.Compose([    transforms.ToTensor(),    transforms.Normalize((0.5,), (0.5,)), ])  dataloader = DataLoader(    MNIST(".", download=True, transform=transform),    batch_size=batch_size,    shuffle=True)

          創(chuàng)建生成器和鑒別器
             
             
           gen = Generator(z_dim).to(device) gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2)) disc = Discriminator().to(device) disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))  # initialize the weights to the normal distribution # with mean 0 and standard deviation 0.02 def weights_init(m):    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):        torch.nn.init.normal_(m.weight, 0.0, 0.02)    if isinstance(m, nn.BatchNorm2d):        torch.nn.init.normal_(m.weight, 0.0, 0.02)        torch.nn.init.constant_(m.bias, 0) gen = gen.apply(weights_init) disc = disc.apply(weights_init)

          下面是訓(xùn)練步驟
             
             
           cur_step = 0 mean_generator_loss = 0 mean_discriminator_loss = 0 for epoch in range(n_epochs):    # Dataloader returns the batches    for real, _ in tqdm(dataloader):        cur_batch_size = len(real)        real = real.to(device)         ## Update Discriminator ##        disc_opt.zero_grad()        fake_noise = get_noise(cur_batch_size, z_dim, device=device)        fake = gen(fake_noise)        disc_fake_pred = disc(fake.detach())        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))        disc_real_pred = disc(real)        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))        disc_loss = (disc_fake_loss + disc_real_loss) / 2         # Keep track of the average discriminator loss        mean_discriminator_loss += disc_loss.item() / display_step        # Update gradients        disc_loss.backward(retain_graph=True)        # Update optimizer        disc_opt.step()         ## Update Generator ##        gen_opt.zero_grad()        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)        fake_2 = gen(fake_noise_2)        disc_fake_pred = disc(fake_2)        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))        gen_loss.backward()        gen_opt.step()         # Keep track of the average generator loss        mean_generator_loss += gen_loss.item() / display_step         ## Visualization code ##        if cur_step % display_step == 0 and cur_step > 0:            print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")            show_tensor_images(fake)            show_tensor_images(real)            mean_generator_loss = 0            mean_discriminator_loss = 0        cur_step += 1

          訓(xùn)練結(jié)果如下:


          總結(jié)


          本文我們介紹了SN-GAN的原理和簡單的代碼實(shí)現(xiàn),SN-GAN已經(jīng)被廣泛應(yīng)用于圖像生成任務(wù),包括圖像合成、風(fēng)格遷移和超分辨率等領(lǐng)域。它在改善生成模型的性能和穩(wěn)定性方面取得了顯著的成果,所以學(xué)習(xí)他的代碼對(duì)我們理解會(huì)更有幫助。

          編輯:文婧

          瀏覽 340
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  五月天激情综合网 | 国产永久免费视频 | 黄色视频在线看网站 | 蜜桃久久网 | 免费A片在线免费观看 |