<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          超越GAN?OpenAI提出可逆生成模型Glow!圖像生成太逼真

          共 7066字,需瀏覽 15分鐘

           ·

          2021-08-24 18:40

          點(diǎn)擊下方卡片,關(guān)注“CVer”公眾號(hào)

          AI/CV重磅干貨,第一時(shí)間送達(dá)

          作者:Aryansh Omray,微軟數(shù)據(jù)科學(xué)工程師,Medium技術(shù)博主

          機(jī)器學(xué)習(xí)領(lǐng)域的一個(gè)基本問(wèn)題就是如何學(xué)習(xí)復(fù)雜數(shù)據(jù)的表征是機(jī)器學(xué)習(xí)。這項(xiàng)任務(wù)的重要性在于,現(xiàn)存的大量非結(jié)構(gòu)化和無(wú)標(biāo)簽的數(shù)據(jù),只有通過(guò)無(wú)監(jiān)督式學(xué)習(xí)才能理解。密度估計(jì)、異常檢測(cè)、文本總結(jié)、數(shù)據(jù)聚類(lèi)、生物信息學(xué)、DNA建模等各方面的應(yīng)用均需要完成這項(xiàng)任務(wù)。多年來(lái),研究人員發(fā)明了許多方法來(lái)學(xué)習(xí)大型數(shù)據(jù)集的概率分布,包括生成對(duì)抗網(wǎng)絡(luò)(GAN)、變分自編碼器(VAE)和Normalizing Flow等。本文即向大家介紹Normalizing Flow這一為了克服GAN和VAE的不足而提出的方法。

          Glow模型的輸出樣例

          https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf


          GAN和VAE的能力本已十分驚人,它們都能通過(guò)簡(jiǎn)單的推理方法學(xué)習(xí)十分復(fù)雜的數(shù)據(jù)分布。然而,GAN和VAE都缺乏對(duì)概率分布的精確評(píng)估和推理,這往往導(dǎo)致VAE中的模糊結(jié)果質(zhì)量不高,GAN訓(xùn)練也面臨著如模式崩潰和后置崩潰等挑戰(zhàn)。因此,Normalizing Flow應(yīng)運(yùn)而生,試圖通過(guò)使用可逆函數(shù)來(lái)解決目前GAN和VAE存在的許多問(wèn)題。


          Normalizing Flow

          簡(jiǎn)單地說(shuō),Normalizing Flow就是一系列的可逆函數(shù),或者說(shuō)這些函數(shù)的解析逆是可以計(jì)算的。例如,f(x)=x+2是一個(gè)可逆函數(shù),因?yàn)槊總€(gè)輸入都有且僅有一個(gè)唯一的輸出,并且反之亦然,而f(x)=x2則不是一個(gè)可逆函數(shù)。這樣的函數(shù)也被稱(chēng)為雙射函數(shù)。

          圖源作者

          從上圖可以看出,Normalizing Flow可以將復(fù)雜的數(shù)據(jù)點(diǎn)(如MNIST中的圖像)轉(zhuǎn)化為簡(jiǎn)單的高斯分布,反之亦然。和GAN非常不一樣的地方是,GAN輸入的是一個(gè)隨機(jī)向量,而輸出的是一個(gè)圖像,基于流(Flow)的模型則是將數(shù)據(jù)點(diǎn)轉(zhuǎn)化為簡(jiǎn)單分布。在上圖的MNIST一例中,我們從高斯分布中抽取隨機(jī)樣本,均可重新獲得其對(duì)應(yīng)的MNIST圖像。

          基于流的模型使用負(fù)對(duì)數(shù)可能性損失函數(shù)進(jìn)行訓(xùn)練,其中p(z)是概率函數(shù)。下面的損失函數(shù)就是使用統(tǒng)計(jì)學(xué)中的變量變化公式得到的。

          https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf


          Normalizing Flow的優(yōu)勢(shì)

          與GAN和VAE相比,Normalizing Flow具有各種優(yōu)勢(shì),包括:

          • Normalizing Flow模型不需要在輸出中放入噪聲,因此可以有更強(qiáng)大的局部方差模型(local variance model);

          • 與GAN相比,基于流的模型訓(xùn)練過(guò)程非常穩(wěn)定,GAN則需要仔細(xì)調(diào)整生成器和判別器的超參數(shù);

          • 與GAN和VAE相比,Normalizing Flow更容易收斂。

          Normalizing Flow的不足

          雖然基于流的模型有其優(yōu)勢(shì),但它們也有一些缺點(diǎn):

          • 基于流的模型在密度估計(jì)等任務(wù)上的表現(xiàn)不盡如人意;

          • 基于流的模型要求保留變換的體積(volume preservation over transformations),這往往會(huì)產(chǎn)生非常高維的潛在空間,通常會(huì)導(dǎo)致解釋性變差;

          • 基于流的模型產(chǎn)生的樣本通常沒(méi)有GAN和VAE的好。

          為了更好地理解Normalizing Flow,我們以Glow架構(gòu)為例進(jìn)行解釋。Glow是OpenAI在2018年提出的一個(gè)基于流的模型。下圖展示了Glow的架構(gòu)。

          Glow的架構(gòu)

          https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf

          Glow架構(gòu)由多個(gè)表層(superficial layers)組合而成。首先我們來(lái)看看Glow模型的多尺度框架。Glow模型由一系列的重復(fù)層(命名為尺度)組成。每個(gè)尺度包括一個(gè)擠壓函數(shù)和一個(gè)流步驟,每個(gè)流步驟包含ActNorm、1x1 Convolution和Coupling Layer,流步驟后是分割函數(shù)。分割函數(shù)在通道維度上將輸入分成兩個(gè)相等的部分。其中一半進(jìn)入之后的層,另一半則進(jìn)入損失函數(shù)。分割是為了減少梯度消失的影響,梯度消失會(huì)在模型以端到端方式(end-to-end)訓(xùn)練時(shí)出現(xiàn)。

          如下圖所示,擠壓函數(shù)(squeeze function)通過(guò)橫向重塑張量,將大小為[c, h, w]的輸入張量轉(zhuǎn)換為大小為[4c, h/2, w/2]的張量。此外,在測(cè)試階段可以采用重塑函數(shù),將輸入的[4c, h/2, w/2]重塑為大小為[c, h, w]的張量。

          https://arxiv.org/pdf/1605.08803.pdf

          其他層,如ActNorm、1x1 Convolution和Affine Coupling層,可以從下表理解。該表展示了每層的功能(包括正向和反向)。

          https://arxiv.org/pdf/1605.08803.pdf


          實(shí)現(xiàn)

          在了解了Normalizing Flow和Glow模型的基礎(chǔ)知識(shí)后,我們將介紹如何使用PyTorch實(shí)現(xiàn)該模型,并在MNIST數(shù)據(jù)集上進(jìn)行訓(xùn)練。

          Glow模型

          首先,我們將使用PyTorch和nflows實(shí)現(xiàn)Glow架構(gòu)。為了節(jié)省時(shí)間,我們使用nflows包含所有層的實(shí)現(xiàn)。

          import torchimport torch.nn as nnimport torch.nn.functional as Ffrom nflows import transformsimport numpy as npfrom torchvision.transforms.functional import resizefrom nflows.transforms.base import Transform
          class Net(nn.Module):
          def __init__(self, in_channel, out_channels): super().__init__() self.net = nn.Sequential( nn.Conv2d(in_channel, 64, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(64, 64, 1), nn.ReLU(inplace=True), ZeroConv2d(64, out_channels), )
          def forward(self, inp, context=None): return self.net(inp)

          def getGlowStep(num_channels, crop_size, i): mask = [1] * num_channels
          if i % 2 == 0: mask[::2] = [-1] * (len(mask[::2])) else: mask[1::2] = [-1] * (len(mask[1::2]))
          def getNet(in_channel, out_channels): return Net(in_channel, out_channels)
          return transforms.CompositeTransform([ transforms.ActNorm(num_channels), transforms.OneByOneConvolution(num_channels), transforms.coupling.AffineCouplingTransform(mask, getNet) ])


          def getGlowScale(num_channels, num_flow, crop_size): z = [getGlowStep(num_channels, crop_size, i) for i in range(num_flow)] return transforms.CompositeTransform([ transforms.SqueezeTransform(), *z ])

          def getGLOW(): num_channels = 1 * 4 num_flow = 32 num_scale = 3 crop_size = 28 // 2 transform = transforms.MultiscaleCompositeTransform(num_scale) for i in range(num_scale): next_input = transform.add_transform(getGlowScale(num_channels, num_flow, crop_size), [num_channels, crop_size, crop_size]) num_channels *= 2 crop_size //= 2
          return transform
          Glow_model = getGLOW()


          我們可以用各種數(shù)據(jù)集來(lái)訓(xùn)練Glow模型,如MNIST、CIFAR-10、ImageNet等。本文為了演示方便,使用的是MNIST數(shù)據(jù)集。

          像MNIST(https://gas.graviti.cn/dataset/data-decorators/MNIST)這樣的數(shù)據(jù)集可以很容易地從格物鈦開(kāi)放數(shù)據(jù)集平臺(tái)https://gas.graviti.cn/open-datasets)獲取,該平臺(tái)包含了機(jī)器學(xué)習(xí)中所有常用的開(kāi)放數(shù)據(jù)集,如分類(lèi)、密度估計(jì)、物體檢測(cè)和基于文本的分類(lèi)數(shù)據(jù)集等。


          要訪問(wèn)數(shù)據(jù)集,我們只需要在格物鈦的平臺(tái)上創(chuàng)建賬戶(hù),就可以直接fork想要的數(shù)據(jù)集,可以直接下載或者使用格物鈦提供的pipeline導(dǎo)入數(shù)據(jù)集?;镜拇a和相關(guān)文檔可在TensorBay的支持網(wǎng)頁(yè)上獲得(graviti.cn/tensorBay)。

          結(jié)合格物鈦TensorBay的Python SDK,我們可以很方便地導(dǎo)入MNIST數(shù)據(jù)集到PyTorch中:

          from PIL import Imagefrom torch.utils.data import DataLoader, Datasetfrom torchvision import transforms
          from tensorbay import GASfrom tensorbay.dataset import Dataset as TensorBayDataset
          class MNISTSegment(Dataset):
          def __init__(self, gas, segment_name, transform): super().__init__() self.dataset = TensorBayDataset("MNIST", gas) self.segment = self.dataset[segment_name] self.category_to_index = self.dataset.catalog.classification.get_category_to_index() self.transform = transform
          def __len__(self): return len(self.segment)
          def __getitem__(self, idx): data = self.segment[idx] with data.open() as fp: image_tensor = self.transform(Image.open(fp))
          return image_tensor, self.category_to_index[data.label.classification.category]


          模型訓(xùn)練

          模型訓(xùn)練可以通過(guò)下面的代碼簡(jiǎn)單開(kāi)始。該代碼使用格物鈦TensorBay提供的Pipeline創(chuàng)建數(shù)據(jù)加載器,其中的ACCESS_KEY可以在TensorBay的賬戶(hù)設(shè)置中獲得。

          from nflows.distributions import normal
          ACCESS_KEY = "Accesskey-*****"EPOCH = 100
          to_tensor = transforms.ToTensor()normalization = transforms.Normalize(mean=[0.485], std=[0.229])my_transforms = transforms.Compose([to_tensor, normalization])
          train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms)train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=4)
          optimizer = torch.optim.Adam(Glow_model.parameters(), 1e-3)
          for epoch in range(EPOCH): for index, (image, label) in enumerate(train_dataloader): if index == 0: image_size = image.shaape[2] channels = image.shape[1] image = image.cuda() output, logabsdet = Glow_model._transform(image) shape = output.shape[1:] log_z = normal.StandardNormal(shape=shape).log_prob(output) loss = log_z + logabsdet loss = -loss.mean()/(image_size * image_size * channels) optimizer.zero_grad() loss.backward() optimizer.step() print(f"Epoch:{epoch+1}/{EPOCH} Loss:{loss}")

          上面代碼用的是MNIST數(shù)據(jù)集,要想使用其他數(shù)據(jù)集我們可以直接替換該數(shù)據(jù)集的數(shù)據(jù)加載器。

          樣例生成

          模型訓(xùn)練完成之后,我們可以通過(guò)下面的代碼來(lái)生成樣例:

          samples = Glow_model.sample(25)display(samples)

          使用nflows庫(kù)之后,我們只需要用一行代碼就可以生成樣例,而display函數(shù)則能在一個(gè)網(wǎng)格中顯示生成的樣本。

          用MNIST訓(xùn)練模型之后生成的樣例

          結(jié)語(yǔ)

          本文向大家介紹了Normalizing Flow的基本知識(shí),并與GAN和VAE進(jìn)行了比較,同時(shí)向大家展示了Glow模型的基本工作方式。我們還講解了如何簡(jiǎn)單實(shí)現(xiàn)Glow模型,并使用MNIST數(shù)據(jù)集進(jìn)行訓(xùn)練。在格物鈦公開(kāi)數(shù)據(jù)集平臺(tái)的幫助下,數(shù)據(jù)集訪問(wèn)變得十分便捷。


          關(guān)于「格物鈦」

          格物鈦定位為面向機(jī)器學(xué)習(xí)的數(shù)據(jù)平臺(tái),幫助AI開(kāi)發(fā)者解決日益增長(zhǎng)的非結(jié)構(gòu)化數(shù)據(jù)難題。借助非結(jié)構(gòu)化數(shù)據(jù)管理平臺(tái)TensorBay和開(kāi)源數(shù)據(jù)集社區(qū)Open Datasets,機(jī)器學(xué)習(xí)團(tuán)隊(duì)和個(gè)人可進(jìn)行數(shù)據(jù)管理、查詢(xún)、協(xié)同、可視化和版本控制等高效操作,降低高質(zhì)量數(shù)據(jù)獲取、存儲(chǔ)和處理成本,加速AI開(kāi)發(fā)和產(chǎn)品創(chuàng)新。


          Open Datasets  ??

          格物鈦|公開(kāi)數(shù)據(jù)集

          graviti.cn/open-datasets


          訂閱號(hào):格物鈦  ??

          微信號(hào)|Graviti_2019

          微博|格物鈦

          https://www.graviti.cn/




          點(diǎn)擊閱讀原文 / 訪問(wèn)格物鈦官網(wǎng)

          瀏覽 127
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  婷婷精品伊人婷婷精品一区的 | 三级毛骗免费看电影 | 亚洲婷婷综合网 | 亚洲成人毛片 | 91精品久久久成人无码 |