<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          擴(kuò)散模型之DDIM:為擴(kuò)散模型的生成過(guò)程提速!

          共 7031字,需瀏覽 15分鐘

           ·

          2022-09-20 22:06

          fdc66e3ee1831499b3f2b5bd32fd7321.webp點(diǎn)藍(lán)色字關(guān)注 “機(jī)器學(xué)習(xí)算法工程師

          設(shè)為 星標(biāo) ,干貨直達(dá)!


          “What I cannot create, I do not understand.” -- Richard Feynman

          上一篇文章擴(kuò)散模型之DDPM帶你深入理解擴(kuò)散模型DDPM介紹了經(jīng)典擴(kuò)散模型DDPM的原理和實(shí)現(xiàn),對(duì)于擴(kuò)散模型來(lái)說(shuō),一個(gè)最大的缺點(diǎn)是需要設(shè)置較長(zhǎng)的擴(kuò)散步數(shù)才能得到好的效果,這導(dǎo)致了生成樣本的速度較慢,比如擴(kuò)散步數(shù)為1000的話,那么生成一個(gè)樣本就要模型推理1000次。這篇文章我們將介紹另外一種擴(kuò)散模型DDIM(Denoising Diffusion Implicit Models),DDIM和DDPM有相同的訓(xùn)練目標(biāo),但是它不再限制擴(kuò)散過(guò)程必須是一個(gè)馬爾卡夫鏈,這使得DDIM可以采用更小的采樣步數(shù)來(lái)加速生成過(guò)程,DDIM的另外是一個(gè)特點(diǎn)是從一個(gè)隨機(jī)噪音生成樣本的過(guò)程是一個(gè)確定的過(guò)程(中間沒(méi)有加入隨機(jī)噪音)。

          DDIM原理

          在介紹DDIM之前,先來(lái)回顧一下DDPM。在DDPM中,擴(kuò)散過(guò)程(前向過(guò)程)定義為一個(gè)馬爾卡夫鏈:

          注意,在DDIM的論文中,其實(shí)是DDPM論文中的,那么DDPM論文中的前向過(guò)程就為:

          擴(kuò)散過(guò)程的一個(gè)重要特性是可以直接用來(lái)對(duì)任意的進(jìn)行采樣:

          而DDPM的反向過(guò)程也定義為一個(gè)馬爾卡夫鏈:

          這里用神經(jīng)網(wǎng)絡(luò)來(lái)擬合真實(shí)的分布。DDPM的前向過(guò)程和反向過(guò)程如下所示:f4d928802a45806d5de2af79e69dbff0.webp我們近一步發(fā)現(xiàn)后驗(yàn)分布是一個(gè)可獲取的高斯分布:

          其中這個(gè)高斯分布的方差是定值,而均值是一個(gè)依賴和的組合函數(shù):

          然后我們基于變分法得到如下的優(yōu)化目標(biāo):

          根據(jù)兩個(gè)高斯分布的KL公式,我們近一步得到:

          根據(jù)擴(kuò)散過(guò)程的特性,我們通過(guò)重參數(shù)化可以近一步簡(jiǎn)化上述目標(biāo):

          如果去掉系數(shù),那么就能得到更簡(jiǎn)化的優(yōu)化目標(biāo):

          仔細(xì)分析DDPM的優(yōu)化目標(biāo)會(huì)發(fā)現(xiàn),DDPM其實(shí)僅僅依賴邊緣分布,而并不是直接作用在聯(lián)合分布。這帶來(lái)的一個(gè)啟示是:DDPM這個(gè)隱變量模型可以有很多推理分布來(lái)選擇,只要推理分布滿足邊緣分布條件(擴(kuò)散過(guò)程的特性)即可,而且這些推理過(guò)程并不一定要是馬爾卡夫鏈。但值得注意的一個(gè)點(diǎn)是,我們要得到DDPM的優(yōu)化目標(biāo),還需要知道分布,之前我們?cè)诟鶕?jù)貝葉斯公式推導(dǎo)這個(gè)分布時(shí)是知道分布的,而且依賴了前向過(guò)程的馬爾卡夫鏈特性。如果要解除對(duì)前向過(guò)程的依賴,那么我們就需要直接定義這個(gè)分布。 基于上述分析,DDIM論文中將推理分布定義為:

          這里要同時(shí)滿足以及對(duì)于所有的有:

          這里的方差是一個(gè)實(shí)數(shù),不同的設(shè)置就是不一樣的分布,所以其實(shí)是一系列的推理分布??梢钥吹竭@里分布的均值也定義為一個(gè)依賴和的組合函數(shù),之所以定義為這樣的形式,是因?yàn)楦鶕?jù),我們可以通過(guò)數(shù)學(xué)歸納法證明,對(duì)于所有的均滿足:

          這部分的證明見(jiàn)DDIM論文的附錄部分,另外博客生成擴(kuò)散模型漫談(四):DDIM = 高觀點(diǎn)DDPM也從待定系數(shù)法來(lái)證明了分布要構(gòu)造的形式。 可以看到這里定義的推理分布并沒(méi)有直接定義前向過(guò)程,但這里滿足了我們前面要討論的兩個(gè)條件:邊緣分布,同時(shí)已知后驗(yàn)分布。同樣地,我們可以按照和DDPM的一樣的方式去推導(dǎo)優(yōu)化目標(biāo),最終也會(huì)得到同樣的

          (雖然VLB的系數(shù)不同,論文3.2部分也證明了這個(gè)結(jié)論)。 論文也給出了一個(gè)前向過(guò)程是非馬爾可夫鏈的示例,如下圖所示,這里前向過(guò)程是,由于生成不僅依賴,而且依賴,所以是一個(gè)非馬爾可夫鏈:0025597021062efd9f20d5ba499061a8.webp注意,這里只是一個(gè)前向過(guò)程的示例,而實(shí)際上我們上述定義的推理分布并不需要前向過(guò)程就可以得到和DDPM一樣的優(yōu)化目標(biāo)。與DDPM一樣,這里也是用神經(jīng)網(wǎng)絡(luò)來(lái)預(yù)測(cè)噪音,那么根據(jù)的形式,在生成階段,我們可以用如下公式來(lái)從生成:



          這里將生成過(guò)程分成三個(gè)部分:一是由預(yù)測(cè)的來(lái)產(chǎn)生的,二是由指向的部分,三是隨機(jī)噪音(這里是與無(wú)關(guān)的噪音)。論文將近一步定義為:


          這里考慮兩種情況,一是,此時(shí),此時(shí)生成過(guò)程就和DDPM一樣了。另外一種情況是,這個(gè)時(shí)候生成過(guò)程就沒(méi)有隨機(jī)噪音了,是一個(gè)確定性的過(guò)程,論文將這種情況下的模型稱為DDIMdenoising diffusion implicit model),一旦最初的隨機(jī)噪音確定了,那么DDIM的樣本生成就變成了確定的過(guò)程。

          上面我們終于得到了DDIM模型,那么我們現(xiàn)在來(lái)看如何來(lái)加速生成過(guò)程。雖然DDIM和DDPM的訓(xùn)練過(guò)程一樣,但是我們前面已經(jīng)說(shuō)了,DDIM并沒(méi)有明確前向過(guò)程,這意味著我們可以定義一個(gè)更短的步數(shù)的前向過(guò)程。具體地,這里我們從原始的序列采樣一個(gè)長(zhǎng)度為的子序列,我們將的前向過(guò)程定義為一個(gè)馬爾卡夫鏈,并且它們滿足:。下圖展示了一個(gè)具體的示例:

          3dc33b6ae4d50ae9d5b82d319ca4c316.webp那么生成過(guò)程也可以用這個(gè)子序列的反向馬爾卡夫鏈來(lái)替代,由于可以設(shè)置比原來(lái)的步數(shù)要小,那么就可以加速生成過(guò)程。這里的生成過(guò)程變成:


          其實(shí)上述的加速,我們是將前向過(guò)程按如下方式進(jìn)行了分解:


          其中。這包含了兩個(gè)圖:其中一個(gè)就是由組成的馬爾可夫鏈,另外一個(gè)是剩余的變量組成的星狀圖。同時(shí)生成過(guò)程,我們也只用馬爾可夫鏈的那部分來(lái)生成:


          論文共設(shè)計(jì)了兩種方法來(lái)采樣子序列,分別是:

          • Linear:采用線性的序列;
          • Quadratic:采樣二次方的序列;

          這里的是一個(gè)定值,它的設(shè)定使得最接近。論文中只對(duì)CIFAR10數(shù)據(jù)集采用Quadratic序列,其它數(shù)據(jù)集均采用Linear序列。

          實(shí)驗(yàn)結(jié)果

          下表為不同的下以及不同采樣步數(shù)下的對(duì)比結(jié)果,可以看到DDIM()在較短的步數(shù)下就能得到比較好的效果,媲美DDPM()的生成效果。如果設(shè)置為50,那么相比原來(lái)的生成過(guò)程就可以加速20倍。98a3d08e19023508fc6af45cf14304e8.webp

          代碼實(shí)現(xiàn)

          DDIM和DDPM的訓(xùn)練過(guò)程一樣,所以可以直接在DDPM的基礎(chǔ)上加一個(gè)新的生成方法(這里主要參考了DDIM官方代碼以及diffusers庫(kù)),具體代碼如下所示:

              class?GaussianDiffusion:
          ????def?__init__(self,?timesteps=1000,?beta_schedule='linear'):
          ?????pass

          ????#?...
          ????????
          ?#?use?ddim?to?sample
          [email protected]_grad()
          ????def?ddim_sample(
          ????????self,
          ????????model,
          ????????image_size,
          ????????batch_size=8,
          ????????channels=3,
          ????????ddim_timesteps=50,
          ????????ddim_discr_method="uniform",
          ????????ddim_eta=0.0,
          ????????clip_denoised=True)
          :

          ????????#?make?ddim?timestep?sequence
          ????????if?ddim_discr_method?==?'uniform':
          ????????????c?=?self.timesteps?//?ddim_timesteps
          ????????????ddim_timestep_seq?=?np.asarray(list(range(0,?self.timesteps,?c)))
          ????????elif?ddim_discr_method?==?'quad':
          ????????????ddim_timestep_seq?=?(
          ????????????????(np.linspace(0,?np.sqrt(self.timesteps?*?.8),?ddim_timesteps))?**?2
          ????????????).astype(int)
          ????????else:
          ????????????raise?NotImplementedError(f'There?is?no?ddim?discretization?method?called?"{ddim_discr_method}"')
          ????????#?add?one?to?get?the?final?alpha?values?right?(the?ones?from?first?scale?to?data?during?sampling)
          ????????ddim_timestep_seq?=?ddim_timestep_seq?+?1
          ????????#?previous?sequence
          ????????ddim_timestep_prev_seq?=?np.append(np.array([0]),?ddim_timestep_seq[:-1])
          ????????
          ????????device?=?next(model.parameters()).device
          ????????#?start?from?pure?noise?(for?each?example?in?the?batch)
          ????????sample_img?=?torch.randn((batch_size,?channels,?image_size,?image_size),?device=device)
          ????????for?i?in?tqdm(reversed(range(0,?ddim_timesteps)),?desc='sampling?loop?time?step',?total=ddim_timesteps):
          ????????????t?=?torch.full((batch_size,),?ddim_timestep_seq[i],?device=device,?dtype=torch.long)
          ????????????prev_t?=?torch.full((batch_size,),?ddim_timestep_prev_seq[i],?device=device,?dtype=torch.long)
          ????????????
          ????????????#?1.?get?current?and?previous?alpha_cumprod
          ????????????alpha_cumprod_t?=?self._extract(self.alphas_cumprod,?t,?sample_img.shape)
          ????????????alpha_cumprod_t_prev?=?self._extract(self.alphas_cumprod,?prev_t,?sample_img.shape)
          ????
          ????????????#?2.?predict?noise?using?model
          ????????????pred_noise?=?model(sample_img,?t)
          ????????????
          ????????????#?3.?get?the?predicted?x_0
          ????????????pred_x0?=?(sample_img?-?torch.sqrt((1.?-?alpha_cumprod_t))?*?pred_noise)?/?torch.sqrt(alpha_cumprod_t)
          ????????????if?clip_denoised:
          ????????????????pred_x0?=?torch.clamp(pred_x0,?min=-1.,?max=1.)
          ????????????
          ????????????#?4.?compute?variance:?"sigma_t(η)"?->?see?formula?(16)
          ????????????#?σ_t?=?sqrt((1???α_t?1)/(1???α_t))?*?sqrt(1???α_t/α_t?1)
          ????????????sigmas_t?=?ddim_eta?*?torch.sqrt(
          ????????????????(1?-?alpha_cumprod_t_prev)?/?(1?-?alpha_cumprod_t)?*?(1?-?alpha_cumprod_t?/?alpha_cumprod_t_prev))
          ????????????
          ????????????#?5.?compute?"direction?pointing?to?x_t"?of?formula?(12)
          ????????????pred_dir_xt?=?torch.sqrt(1?-?alpha_cumprod_t_prev?-?sigmas_t**2)?*?pred_noise
          ????????????
          ????????????#?6.?compute?x_{t-1}?of?formula?(12)
          ????????????x_prev?=?torch.sqrt(alpha_cumprod_t_prev)?*?pred_x0?+?pred_dir_xt?+?sigmas_t?*?torch.randn_like(sample_img)

          ????????????sample_img?=?x_prev
          ????????????
          ????????return?sample_img.cpu().numpy()

          這里以MNIST數(shù)據(jù)集為例,訓(xùn)練的擴(kuò)散步數(shù)為500,直接采用DDPM(即推理500次)生成的樣本如下所示:2aac388d226aa9d2144d60d99aa164a0.webp同樣的模型,我們采用DDIM來(lái)加速生成過(guò)程,這里DDIM的采樣步數(shù)為50,其生成的樣本質(zhì)量和500步的DDPM相當(dāng):b459dfaa4666c3207e615b2361709c4c.webp完整的代碼示例見(jiàn)https://github.com/xiaohu2015/nngen。

          小結(jié)

          如果從直觀上看,DDIM的加速方式非常簡(jiǎn)單,直接采樣一個(gè)子序列,其實(shí)論文DDPM+也采用了類似的方式來(lái)加速。另外DDIM和其它擴(kuò)散模型的一個(gè)較大的區(qū)別是其生成過(guò)程是確定性的。

          參考

          • Denoising Diffusion Implicit Models
          • https://github.com/ermongroup/ddim
          • https://github.com/openai/improved-diffusion
          • https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py
          • https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddim.py
          • https://kexue.fm/archives/9181


          推薦閱讀

          深入理解生成模型VAE

          DropBlock的原理和實(shí)現(xiàn)

          SOTA模型Swin Transformer是如何煉成的!

          有碼有顏!你要的生成模型VQ-VAE來(lái)了!

          集成YYDS!讓你的模型更快更準(zhǔn)!

          輔助模塊加速收斂,精度大幅提升!移動(dòng)端實(shí)時(shí)的NanoDet-Plus來(lái)了!

          SimMIM:一種更簡(jiǎn)單的MIM方法

          SSD的torchvision版本實(shí)現(xiàn)詳解


          機(jī)器學(xué)習(xí)算法工程師


          ? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

          1244a8e58c3fe6dce4fa94e68b68cc55.webp


          瀏覽 827
          1點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          評(píng)論
          圖片
          表情
          推薦
          1點(diǎn)贊
          評(píng)論
          收藏
          分享

          手機(jī)掃一掃分享

          分享
          舉報(bào)
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  99精品在线| 尻屄AV| 在线看毛片的网站 | 日韩特级毛片在线视频 | 免费三级片网址 |