擴(kuò)散模型之DDIM:為擴(kuò)散模型的生成過(guò)程提速!
點(diǎn)藍(lán)色字關(guān)注
“機(jī)器學(xué)習(xí)算法工程師
”
設(shè)為 星標(biāo) ,干貨直達(dá)!
“What I cannot create, I do not understand.” -- Richard Feynman
上一篇文章擴(kuò)散模型之DDPM帶你深入理解擴(kuò)散模型DDPM介紹了經(jīng)典擴(kuò)散模型DDPM的原理和實(shí)現(xiàn),對(duì)于擴(kuò)散模型來(lái)說(shuō),一個(gè)最大的缺點(diǎn)是需要設(shè)置較長(zhǎng)的擴(kuò)散步數(shù)才能得到好的效果,這導(dǎo)致了生成樣本的速度較慢,比如擴(kuò)散步數(shù)為1000的話,那么生成一個(gè)樣本就要模型推理1000次。這篇文章我們將介紹另外一種擴(kuò)散模型DDIM(Denoising Diffusion Implicit Models),DDIM和DDPM有相同的訓(xùn)練目標(biāo),但是它不再限制擴(kuò)散過(guò)程必須是一個(gè)馬爾卡夫鏈,這使得DDIM可以采用更小的采樣步數(shù)來(lái)加速生成過(guò)程,DDIM的另外是一個(gè)特點(diǎn)是從一個(gè)隨機(jī)噪音生成樣本的過(guò)程是一個(gè)確定的過(guò)程(中間沒(méi)有加入隨機(jī)噪音)。
DDIM原理
在介紹DDIM之前,先來(lái)回顧一下DDPM。在DDPM中,擴(kuò)散過(guò)程(前向過(guò)程)定義為一個(gè)馬爾卡夫鏈:
注意,在DDIM的論文中,其實(shí)是DDPM論文中的,那么DDPM論文中的前向過(guò)程就為:
擴(kuò)散過(guò)程的一個(gè)重要特性是可以直接用來(lái)對(duì)任意的進(jìn)行采樣:
而DDPM的反向過(guò)程也定義為一個(gè)馬爾卡夫鏈:
這里用神經(jīng)網(wǎng)絡(luò)來(lái)擬合真實(shí)的分布。DDPM的前向過(guò)程和反向過(guò)程如下所示:
我們近一步發(fā)現(xiàn)后驗(yàn)分布是一個(gè)可獲取的高斯分布:
其中這個(gè)高斯分布的方差是定值,而均值是一個(gè)依賴和的組合函數(shù):
然后我們基于變分法得到如下的優(yōu)化目標(biāo):
根據(jù)兩個(gè)高斯分布的KL公式,我們近一步得到:
根據(jù)擴(kuò)散過(guò)程的特性,我們通過(guò)重參數(shù)化可以近一步簡(jiǎn)化上述目標(biāo):
如果去掉系數(shù),那么就能得到更簡(jiǎn)化的優(yōu)化目標(biāo):
仔細(xì)分析DDPM的優(yōu)化目標(biāo)會(huì)發(fā)現(xiàn),DDPM其實(shí)僅僅依賴邊緣分布,而并不是直接作用在聯(lián)合分布。這帶來(lái)的一個(gè)啟示是:DDPM這個(gè)隱變量模型可以有很多推理分布來(lái)選擇,只要推理分布滿足邊緣分布條件(擴(kuò)散過(guò)程的特性)即可,而且這些推理過(guò)程并不一定要是馬爾卡夫鏈。但值得注意的一個(gè)點(diǎn)是,我們要得到DDPM的優(yōu)化目標(biāo),還需要知道分布,之前我們?cè)诟鶕?jù)貝葉斯公式推導(dǎo)這個(gè)分布時(shí)是知道分布的,而且依賴了前向過(guò)程的馬爾卡夫鏈特性。如果要解除對(duì)前向過(guò)程的依賴,那么我們就需要直接定義這個(gè)分布。 基于上述分析,DDIM論文中將推理分布定義為:
這里要同時(shí)滿足以及對(duì)于所有的有:
這里的方差是一個(gè)實(shí)數(shù),不同的設(shè)置就是不一樣的分布,所以其實(shí)是一系列的推理分布??梢钥吹竭@里分布的均值也定義為一個(gè)依賴和的組合函數(shù),之所以定義為這樣的形式,是因?yàn)楦鶕?jù),我們可以通過(guò)數(shù)學(xué)歸納法證明,對(duì)于所有的均滿足:
這部分的證明見(jiàn)DDIM論文的附錄部分,另外博客生成擴(kuò)散模型漫談(四):DDIM = 高觀點(diǎn)DDPM也從待定系數(shù)法來(lái)證明了分布要構(gòu)造的形式。 可以看到這里定義的推理分布并沒(méi)有直接定義前向過(guò)程,但這里滿足了我們前面要討論的兩個(gè)條件:邊緣分布,同時(shí)已知后驗(yàn)分布。同樣地,我們可以按照和DDPM的一樣的方式去推導(dǎo)優(yōu)化目標(biāo),最終也會(huì)得到同樣的
(雖然VLB的系數(shù)不同,論文3.2部分也證明了這個(gè)結(jié)論)。
論文也給出了一個(gè)前向過(guò)程是非馬爾可夫鏈的示例,如下圖所示,這里前向過(guò)程是,由于生成不僅依賴,而且依賴,所以是一個(gè)非馬爾可夫鏈:
注意,這里只是一個(gè)前向過(guò)程的示例,而實(shí)際上我們上述定義的推理分布并不需要前向過(guò)程就可以得到和DDPM一樣的優(yōu)化目標(biāo)。與DDPM一樣,這里也是用神經(jīng)網(wǎng)絡(luò)來(lái)預(yù)測(cè)噪音,那么根據(jù)的形式,在生成階段,我們可以用如下公式來(lái)從生成:
這里將生成過(guò)程分成三個(gè)部分:一是由預(yù)測(cè)的來(lái)產(chǎn)生的,二是由指向的部分,三是隨機(jī)噪音(這里是與無(wú)關(guān)的噪音)。論文將近一步定義為:
這里考慮兩種情況,一是,此時(shí),此時(shí)生成過(guò)程就和DDPM一樣了。另外一種情況是,這個(gè)時(shí)候生成過(guò)程就沒(méi)有隨機(jī)噪音了,是一個(gè)確定性的過(guò)程,論文將這種情況下的模型稱為DDIM(denoising diffusion implicit model),一旦最初的隨機(jī)噪音確定了,那么DDIM的樣本生成就變成了確定的過(guò)程。
上面我們終于得到了DDIM模型,那么我們現(xiàn)在來(lái)看如何來(lái)加速生成過(guò)程。雖然DDIM和DDPM的訓(xùn)練過(guò)程一樣,但是我們前面已經(jīng)說(shuō)了,DDIM并沒(méi)有明確前向過(guò)程,這意味著我們可以定義一個(gè)更短的步數(shù)的前向過(guò)程。具體地,這里我們從原始的序列采樣一個(gè)長(zhǎng)度為的子序列,我們將的前向過(guò)程定義為一個(gè)馬爾卡夫鏈,并且它們滿足:。下圖展示了一個(gè)具體的示例:
那么生成過(guò)程也可以用這個(gè)子序列的反向馬爾卡夫鏈來(lái)替代,由于可以設(shè)置比原來(lái)的步數(shù)要小,那么就可以加速生成過(guò)程。這里的生成過(guò)程變成:
其實(shí)上述的加速,我們是將前向過(guò)程按如下方式進(jìn)行了分解:
其中。這包含了兩個(gè)圖:其中一個(gè)就是由組成的馬爾可夫鏈,另外一個(gè)是剩余的變量組成的星狀圖。同時(shí)生成過(guò)程,我們也只用馬爾可夫鏈的那部分來(lái)生成:
論文共設(shè)計(jì)了兩種方法來(lái)采樣子序列,分別是:
- Linear:采用線性的序列;
- Quadratic:采樣二次方的序列;
這里的是一個(gè)定值,它的設(shè)定使得最接近。論文中只對(duì)CIFAR10數(shù)據(jù)集采用Quadratic序列,其它數(shù)據(jù)集均采用Linear序列。
實(shí)驗(yàn)結(jié)果
下表為不同的下以及不同采樣步數(shù)下的對(duì)比結(jié)果,可以看到DDIM()在較短的步數(shù)下就能得到比較好的效果,媲美DDPM()的生成效果。如果設(shè)置為50,那么相比原來(lái)的生成過(guò)程就可以加速20倍。
代碼實(shí)現(xiàn)
DDIM和DDPM的訓(xùn)練過(guò)程一樣,所以可以直接在DDPM的基礎(chǔ)上加一個(gè)新的生成方法(這里主要參考了DDIM官方代碼以及diffusers庫(kù)),具體代碼如下所示:
class?GaussianDiffusion:
????def?__init__(self,?timesteps=1000,?beta_schedule='linear'):
?????pass
????#?...
????????
?#?use?ddim?to?sample
[email protected]_grad()
????def?ddim_sample(
????????self,
????????model,
????????image_size,
????????batch_size=8,
????????channels=3,
????????ddim_timesteps=50,
????????ddim_discr_method="uniform",
????????ddim_eta=0.0,
????????clip_denoised=True):
????????#?make?ddim?timestep?sequence
????????if?ddim_discr_method?==?'uniform':
????????????c?=?self.timesteps?//?ddim_timesteps
????????????ddim_timestep_seq?=?np.asarray(list(range(0,?self.timesteps,?c)))
????????elif?ddim_discr_method?==?'quad':
????????????ddim_timestep_seq?=?(
????????????????(np.linspace(0,?np.sqrt(self.timesteps?*?.8),?ddim_timesteps))?**?2
????????????).astype(int)
????????else:
????????????raise?NotImplementedError(f'There?is?no?ddim?discretization?method?called?"{ddim_discr_method}"')
????????#?add?one?to?get?the?final?alpha?values?right?(the?ones?from?first?scale?to?data?during?sampling)
????????ddim_timestep_seq?=?ddim_timestep_seq?+?1
????????#?previous?sequence
????????ddim_timestep_prev_seq?=?np.append(np.array([0]),?ddim_timestep_seq[:-1])
????????
????????device?=?next(model.parameters()).device
????????#?start?from?pure?noise?(for?each?example?in?the?batch)
????????sample_img?=?torch.randn((batch_size,?channels,?image_size,?image_size),?device=device)
????????for?i?in?tqdm(reversed(range(0,?ddim_timesteps)),?desc='sampling?loop?time?step',?total=ddim_timesteps):
????????????t?=?torch.full((batch_size,),?ddim_timestep_seq[i],?device=device,?dtype=torch.long)
????????????prev_t?=?torch.full((batch_size,),?ddim_timestep_prev_seq[i],?device=device,?dtype=torch.long)
????????????
????????????#?1.?get?current?and?previous?alpha_cumprod
????????????alpha_cumprod_t?=?self._extract(self.alphas_cumprod,?t,?sample_img.shape)
????????????alpha_cumprod_t_prev?=?self._extract(self.alphas_cumprod,?prev_t,?sample_img.shape)
????
????????????#?2.?predict?noise?using?model
????????????pred_noise?=?model(sample_img,?t)
????????????
????????????#?3.?get?the?predicted?x_0
????????????pred_x0?=?(sample_img?-?torch.sqrt((1.?-?alpha_cumprod_t))?*?pred_noise)?/?torch.sqrt(alpha_cumprod_t)
????????????if?clip_denoised:
????????????????pred_x0?=?torch.clamp(pred_x0,?min=-1.,?max=1.)
????????????
????????????#?4.?compute?variance:?"sigma_t(η)"?->?see?formula?(16)
????????????#?σ_t?=?sqrt((1???α_t?1)/(1???α_t))?*?sqrt(1???α_t/α_t?1)
????????????sigmas_t?=?ddim_eta?*?torch.sqrt(
????????????????(1?-?alpha_cumprod_t_prev)?/?(1?-?alpha_cumprod_t)?*?(1?-?alpha_cumprod_t?/?alpha_cumprod_t_prev))
????????????
????????????#?5.?compute?"direction?pointing?to?x_t"?of?formula?(12)
????????????pred_dir_xt?=?torch.sqrt(1?-?alpha_cumprod_t_prev?-?sigmas_t**2)?*?pred_noise
????????????
????????????#?6.?compute?x_{t-1}?of?formula?(12)
????????????x_prev?=?torch.sqrt(alpha_cumprod_t_prev)?*?pred_x0?+?pred_dir_xt?+?sigmas_t?*?torch.randn_like(sample_img)
????????????sample_img?=?x_prev
????????????
????????return?sample_img.cpu().numpy()
這里以MNIST數(shù)據(jù)集為例,訓(xùn)練的擴(kuò)散步數(shù)為500,直接采用DDPM(即推理500次)生成的樣本如下所示:
同樣的模型,我們采用DDIM來(lái)加速生成過(guò)程,這里DDIM的采樣步數(shù)為50,其生成的樣本質(zhì)量和500步的DDPM相當(dāng):
完整的代碼示例見(jiàn)https://github.com/xiaohu2015/nngen。
小結(jié)
如果從直觀上看,DDIM的加速方式非常簡(jiǎn)單,直接采樣一個(gè)子序列,其實(shí)論文DDPM+也采用了類似的方式來(lái)加速。另外DDIM和其它擴(kuò)散模型的一個(gè)較大的區(qū)別是其生成過(guò)程是確定性的。
參考
- Denoising Diffusion Implicit Models
- https://github.com/ermongroup/ddim
- https://github.com/openai/improved-diffusion
- https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py
- https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddim.py
- https://kexue.fm/archives/9181
推薦閱讀
輔助模塊加速收斂,精度大幅提升!移動(dòng)端實(shí)時(shí)的NanoDet-Plus來(lái)了!
SSD的torchvision版本實(shí)現(xiàn)詳解
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

