超越GAN?OpenAI提出可逆生成模型Glow!圖像生成太逼真
點(diǎn)擊下方卡片,關(guān)注“CVer”公眾號(hào)
AI/CV重磅干貨,第一時(shí)間送達(dá)
作者:Aryansh Omray,微軟數(shù)據(jù)科學(xué)工程師,Medium技術(shù)博主
機(jī)器學(xué)習(xí)領(lǐng)域的一個(gè)基本問(wèn)題就是如何學(xué)習(xí)復(fù)雜數(shù)據(jù)的表征是機(jī)器學(xué)習(xí)。這項(xiàng)任務(wù)的重要性在于,現(xiàn)存的大量非結(jié)構(gòu)化和無(wú)標(biāo)簽的數(shù)據(jù),只有通過(guò)無(wú)監(jiān)督式學(xué)習(xí)才能理解。密度估計(jì)、異常檢測(cè)、文本總結(jié)、數(shù)據(jù)聚類(lèi)、生物信息學(xué)、DNA建模等各方面的應(yīng)用均需要完成這項(xiàng)任務(wù)。多年來(lái),研究人員發(fā)明了許多方法來(lái)學(xué)習(xí)大型數(shù)據(jù)集的概率分布,包括生成對(duì)抗網(wǎng)絡(luò)(GAN)、變分自編碼器(VAE)和Normalizing Flow等。本文即向大家介紹Normalizing Flow這一為了克服GAN和VAE的不足而提出的方法。

Glow模型的輸出樣例
https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf
GAN和VAE的能力本已十分驚人,它們都能通過(guò)簡(jiǎn)單的推理方法學(xué)習(xí)十分復(fù)雜的數(shù)據(jù)分布。然而,GAN和VAE都缺乏對(duì)概率分布的精確評(píng)估和推理,這往往導(dǎo)致VAE中的模糊結(jié)果質(zhì)量不高,GAN訓(xùn)練也面臨著如模式崩潰和后置崩潰等挑戰(zhàn)。因此,Normalizing Flow應(yīng)運(yùn)而生,試圖通過(guò)使用可逆函數(shù)來(lái)解決目前GAN和VAE存在的許多問(wèn)題。
Normalizing Flow
簡(jiǎn)單地說(shuō),Normalizing Flow就是一系列的可逆函數(shù),或者說(shuō)這些函數(shù)的解析逆是可以計(jì)算的。例如,f(x)=x+2是一個(gè)可逆函數(shù),因?yàn)槊總€(gè)輸入都有且僅有一個(gè)唯一的輸出,并且反之亦然,而f(x)=x2則不是一個(gè)可逆函數(shù)。這樣的函數(shù)也被稱(chēng)為雙射函數(shù)。

圖源作者
從上圖可以看出,Normalizing Flow可以將復(fù)雜的數(shù)據(jù)點(diǎn)(如MNIST中的圖像)轉(zhuǎn)化為簡(jiǎn)單的高斯分布,反之亦然。和GAN非常不一樣的地方是,GAN輸入的是一個(gè)隨機(jī)向量,而輸出的是一個(gè)圖像,基于流(Flow)的模型則是將數(shù)據(jù)點(diǎn)轉(zhuǎn)化為簡(jiǎn)單分布。在上圖的MNIST一例中,我們從高斯分布中抽取隨機(jī)樣本,均可重新獲得其對(duì)應(yīng)的MNIST圖像。
基于流的模型使用負(fù)對(duì)數(shù)可能性損失函數(shù)進(jìn)行訓(xùn)練,其中p(z)是概率函數(shù)。下面的損失函數(shù)就是使用統(tǒng)計(jì)學(xué)中的變量變化公式得到的。

https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf
Normalizing Flow的優(yōu)勢(shì)
與GAN和VAE相比,Normalizing Flow具有各種優(yōu)勢(shì),包括:
Normalizing Flow模型不需要在輸出中放入噪聲,因此可以有更強(qiáng)大的局部方差模型(local variance model);
與GAN相比,基于流的模型訓(xùn)練過(guò)程非常穩(wěn)定,GAN則需要仔細(xì)調(diào)整生成器和判別器的超參數(shù);
與GAN和VAE相比,Normalizing Flow更容易收斂。
Normalizing Flow的不足
雖然基于流的模型有其優(yōu)勢(shì),但它們也有一些缺點(diǎn):
基于流的模型在密度估計(jì)等任務(wù)上的表現(xiàn)不盡如人意;
基于流的模型要求保留變換的體積(volume preservation over transformations),這往往會(huì)產(chǎn)生非常高維的潛在空間,通常會(huì)導(dǎo)致解釋性變差;
基于流的模型產(chǎn)生的樣本通常沒(méi)有GAN和VAE的好。
為了更好地理解Normalizing Flow,我們以Glow架構(gòu)為例進(jìn)行解釋。Glow是OpenAI在2018年提出的一個(gè)基于流的模型。下圖展示了Glow的架構(gòu)。

Glow的架構(gòu)
https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf
Glow架構(gòu)由多個(gè)表層(superficial layers)組合而成。首先我們來(lái)看看Glow模型的多尺度框架。Glow模型由一系列的重復(fù)層(命名為尺度)組成。每個(gè)尺度包括一個(gè)擠壓函數(shù)和一個(gè)流步驟,每個(gè)流步驟包含ActNorm、1x1 Convolution和Coupling Layer,流步驟后是分割函數(shù)。分割函數(shù)在通道維度上將輸入分成兩個(gè)相等的部分。其中一半進(jìn)入之后的層,另一半則進(jìn)入損失函數(shù)。分割是為了減少梯度消失的影響,梯度消失會(huì)在模型以端到端方式(end-to-end)訓(xùn)練時(shí)出現(xiàn)。
如下圖所示,擠壓函數(shù)(squeeze function)通過(guò)橫向重塑張量,將大小為[c, h, w]的輸入張量轉(zhuǎn)換為大小為[4c, h/2, w/2]的張量。此外,在測(cè)試階段可以采用重塑函數(shù),將輸入的[4c, h/2, w/2]重塑為大小為[c, h, w]的張量。

https://arxiv.org/pdf/1605.08803.pdf
其他層,如ActNorm、1x1 Convolution和Affine Coupling層,可以從下表理解。該表展示了每層的功能(包括正向和反向)。

https://arxiv.org/pdf/1605.08803.pdf
實(shí)現(xiàn)
在了解了Normalizing Flow和Glow模型的基礎(chǔ)知識(shí)后,我們將介紹如何使用PyTorch實(shí)現(xiàn)該模型,并在MNIST數(shù)據(jù)集上進(jìn)行訓(xùn)練。
Glow模型
首先,我們將使用PyTorch和nflows實(shí)現(xiàn)Glow架構(gòu)。為了節(jié)省時(shí)間,我們使用nflows包含所有層的實(shí)現(xiàn)。
import torchimport torch.nn as nnimport torch.nn.functional as Ffrom nflows import transformsimport numpy as npfrom torchvision.transforms.functional import resizefrom nflows.transforms.base import Transformclass Net(nn.Module):def __init__(self, in_channel, out_channels):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channel, 64, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, 1),nn.ReLU(inplace=True),ZeroConv2d(64, out_channels),)def forward(self, inp, context=None):return self.net(inp)def getGlowStep(num_channels, crop_size, i):mask = [1] * num_channelsif i % 2 == 0:mask[::2] = [-1] * (len(mask[::2]))else:mask[1::2] = [-1] * (len(mask[1::2]))def getNet(in_channel, out_channels):return Net(in_channel, out_channels)return transforms.CompositeTransform([transforms.ActNorm(num_channels),transforms.OneByOneConvolution(num_channels),transforms.coupling.AffineCouplingTransform(mask, getNet)])def getGlowScale(num_channels, num_flow, crop_size):z = [getGlowStep(num_channels, crop_size, i) for i in range(num_flow)]return transforms.CompositeTransform([transforms.SqueezeTransform(),*z])def getGLOW():num_channels = 1 * 4num_flow = 32num_scale = 3crop_size = 28 // 2transform = transforms.MultiscaleCompositeTransform(num_scale)for i in range(num_scale):next_input = transform.add_transform(getGlowScale(num_channels, num_flow, crop_size),[num_channels, crop_size, crop_size])num_channels *= 2crop_size //= 2return transformGlow_model = getGLOW()
我們可以用各種數(shù)據(jù)集來(lái)訓(xùn)練Glow模型,如MNIST、CIFAR-10、ImageNet等。本文為了演示方便,使用的是MNIST數(shù)據(jù)集。
像MNIST(https://gas.graviti.cn/dataset/data-decorators/MNIST)這樣的數(shù)據(jù)集可以很容易地從格物鈦開(kāi)放數(shù)據(jù)集平臺(tái)(https://gas.graviti.cn/open-datasets)獲取,該平臺(tái)包含了機(jī)器學(xué)習(xí)中所有常用的開(kāi)放數(shù)據(jù)集,如分類(lèi)、密度估計(jì)、物體檢測(cè)和基于文本的分類(lèi)數(shù)據(jù)集等。

要訪問(wèn)數(shù)據(jù)集,我們只需要在格物鈦的平臺(tái)上創(chuàng)建賬戶(hù),就可以直接fork想要的數(shù)據(jù)集,可以直接下載或者使用格物鈦提供的pipeline導(dǎo)入數(shù)據(jù)集?;镜拇a和相關(guān)文檔可在TensorBay的支持網(wǎng)頁(yè)上獲得(graviti.cn/tensorBay)。

結(jié)合格物鈦TensorBay的Python SDK,我們可以很方便地導(dǎo)入MNIST數(shù)據(jù)集到PyTorch中:
from PIL import Imagefrom torch.utils.data import DataLoader, Datasetfrom torchvision import transformsfrom tensorbay import GASfrom tensorbay.dataset import Dataset as TensorBayDatasetclass MNISTSegment(Dataset):def __init__(self, gas, segment_name, transform):super().__init__()self.dataset = TensorBayDataset("MNIST", gas)self.segment = self.dataset[segment_name]self.category_to_index = self.dataset.catalog.classification.get_category_to_index()self.transform = transformdef __len__(self):return len(self.segment)def __getitem__(self, idx):data = self.segment[idx]with data.open() as fp:image_tensor = self.transform(Image.open(fp))return image_tensor, self.category_to_index[data.label.classification.category]
模型訓(xùn)練
模型訓(xùn)練可以通過(guò)下面的代碼簡(jiǎn)單開(kāi)始。該代碼使用格物鈦TensorBay提供的Pipeline創(chuàng)建數(shù)據(jù)加載器,其中的ACCESS_KEY可以在TensorBay的賬戶(hù)設(shè)置中獲得。
from nflows.distributions import normalACCESS_KEY = "Accesskey-*****"EPOCH = 100to_tensor = transforms.ToTensor()normalization = transforms.Normalize(mean=[0.485], std=[0.229])my_transforms = transforms.Compose([to_tensor, normalization])train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms)train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=4)optimizer = torch.optim.Adam(Glow_model.parameters(), 1e-3)for epoch in range(EPOCH):for index, (image, label) in enumerate(train_dataloader):if index == 0:image_size = image.shaape[2]channels = image.shape[1]image = image.cuda()output, logabsdet = Glow_model._transform(image)shape = output.shape[1:]log_z = normal.StandardNormal(shape=shape).log_prob(output)loss = log_z + logabsdetloss = -loss.mean()/(image_size * image_size * channels)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch:{epoch+1}/{EPOCH} Loss:{loss}")
上面代碼用的是MNIST數(shù)據(jù)集,要想使用其他數(shù)據(jù)集我們可以直接替換該數(shù)據(jù)集的數(shù)據(jù)加載器。
樣例生成
模型訓(xùn)練完成之后,我們可以通過(guò)下面的代碼來(lái)生成樣例:
samples = Glow_model.sample(25)display(samples)
使用nflows庫(kù)之后,我們只需要用一行代碼就可以生成樣例,而display函數(shù)則能在一個(gè)網(wǎng)格中顯示生成的樣本。

用MNIST訓(xùn)練模型之后生成的樣例
結(jié)語(yǔ)
本文向大家介紹了Normalizing Flow的基本知識(shí),并與GAN和VAE進(jìn)行了比較,同時(shí)向大家展示了Glow模型的基本工作方式。我們還講解了如何簡(jiǎn)單實(shí)現(xiàn)Glow模型,并使用MNIST數(shù)據(jù)集進(jìn)行訓(xùn)練。在格物鈦公開(kāi)數(shù)據(jù)集平臺(tái)的幫助下,數(shù)據(jù)集訪問(wèn)變得十分便捷。

關(guān)于「格物鈦」
格物鈦定位為面向機(jī)器學(xué)習(xí)的數(shù)據(jù)平臺(tái),幫助AI開(kāi)發(fā)者解決日益增長(zhǎng)的非結(jié)構(gòu)化數(shù)據(jù)難題。借助非結(jié)構(gòu)化數(shù)據(jù)管理平臺(tái)TensorBay和開(kāi)源數(shù)據(jù)集社區(qū)Open Datasets,機(jī)器學(xué)習(xí)團(tuán)隊(duì)和個(gè)人可進(jìn)行數(shù)據(jù)管理、查詢(xún)、協(xié)同、可視化和版本控制等高效操作,降低高質(zhì)量數(shù)據(jù)獲取、存儲(chǔ)和處理成本,加速AI開(kāi)發(fā)和產(chǎn)品創(chuàng)新。
Open Datasets ??
格物鈦|公開(kāi)數(shù)據(jù)集
graviti.cn/open-datasets
訂閱號(hào):格物鈦 ??
微信號(hào)|Graviti_2019
微博|格物鈦
https://www.graviti.cn/
點(diǎn)擊閱讀原文 / 訪問(wèn)格物鈦官網(wǎng)
