擴(kuò)散模型全新課程:擴(kuò)散模型從0到1實現(xiàn)!
前言
于 11 月底正式開課的擴(kuò)散模型課程正在火熱進(jìn)行中,在中國社區(qū)成員們的幫助下,我們組織了「抱抱臉中文本地化志愿者小組」并完成了擴(kuò)散模型課程的中文翻譯,感謝 @darcula1993、@XhrLeokk、@hoi2022、@SuSung-boy 對課程的翻譯!
如果你還沒有開始課程的學(xué)習(xí),我們建議你從 第一單元:擴(kuò)散模型簡介 開始。
擴(kuò)散模型從零到一
這個 Notebook 我們將展示相同的步驟(向數(shù)據(jù)添加噪聲、創(chuàng)建模型、訓(xùn)練和采樣),并盡可能簡單地在 PyTorch 中從頭開始實現(xiàn)。然后,我們將這個「玩具示例」與 diffusers 版本進(jìn)行比較,并關(guān)注兩者的區(qū)別以及改進(jìn)之處。這里的目標(biāo)是熟悉不同的組件和其中的設(shè)計決策,以便在查看新的實現(xiàn)時能夠快速確定關(guān)鍵思想。
讓我們開始吧!
有時,只考慮一些事務(wù)最簡單的情況會有助于更好地理解其工作原理。我們將在本筆記本中嘗試這一點(diǎn),從“玩具”擴(kuò)散模型開始,看看不同的部分是如何工作的,然后再檢查它們與更復(fù)雜的實現(xiàn)有何不同。
你將跟隨本文的 Notebook 學(xué)習(xí)到
- 損壞過程(向數(shù)據(jù)添加噪聲)
- 什么是 UNet,以及如何從零開始實現(xiàn)一個極小的 UNet
- 擴(kuò)散模型訓(xùn)練
- 抽樣理論
然后,我們將比較我們的版本與 diffusers 庫中的 DDPM 實現(xiàn)的區(qū)別
- 對小型 UNet 的改進(jìn)
- DDPM 噪聲計劃
- 訓(xùn)練目標(biāo)的差異
- timestep 調(diào)節(jié)
- 抽樣方法
這個筆記本相當(dāng)深入,如果你對從零開始的深入研究不感興趣,可以放心地跳過!
還值得注意的是,這里的大多數(shù)代碼都是出于說明的目的,我不建議直接將其用于您自己的工作(除非您只是為了學(xué)習(xí)目的而嘗試改進(jìn)這里展示的示例)。
準(zhǔn)備環(huán)境與導(dǎo)入:
!pip?install?-q?diffusers
import?torch
import?torchvision
from?torch?import?nn
from?torch.nn?import?functional?as?F
from?torch.utils.data?import?DataLoader
from?diffusers?import?DDPMScheduler,?UNet2DModel
from?matplotlib?import?pyplot?as?plt
device?=?torch.device("cuda"?if?torch.cuda.is_available()?else?"cpu")
print(f'Using?device:?{device}')
數(shù)據(jù)
在這里,我們將使用一個非常小的經(jīng)典數(shù)據(jù)集 mnist 來進(jìn)行測試。如果您想在不改變?nèi)魏纹渌麅?nèi)容的情況下給模型一個稍微困難一點(diǎn)的挑戰(zhàn),請使用 torchvision.dataset,F(xiàn)ashionMNIST 應(yīng)作為替代品。
dataset?=?torchvision.datasets.MNIST(root="mnist/",?train=True,?download=True,?transform=torchvision.transforms.ToTensor())
train_dataloader?=?DataLoader(dataset,?batch_size=8,?shuffle=True)
x,?y?=?next(iter(train_dataloader))
print('Input?shape:',?x.shape)
print('Labels:',?y)
plt.imshow(torchvision.utils.make_grid(x)[0],?cmap='Greys');
該數(shù)據(jù)集中的每張圖都是一個數(shù)字的 28x28 像素的灰度圖,像素值的范圍是從 0 到 1。
損壞過程
假設(shè)你沒有讀過任何擴(kuò)散模型的論文,但你知道這個過程會增加噪聲。你會怎么做?
我們可能想要一個簡單的方法來控制損壞的程度。那么,如果我們要引入一個參數(shù)來控制輸入的“噪聲量”,那么我們會這么做:
noise = torch.rand_like(x)
noisy_x = (1-amount)*x + amount*noise
如果 amount = 0,則返回輸入而不做任何更改。如果 amount = 1,我們將得到一個純粹的噪聲。通過這種方式將輸入與噪聲混合,我們將輸出保持在相同的范圍(0 to 1)。
我們可以很容易地實現(xiàn)這一點(diǎn)(但是要注意 tensor 的 shape,以防被廣播 (broadcasting) 機(jī)制不正確的影響到):
def?corrupt(x,?amount):
??"""Corrupt?the?input?`x`?by?mixing?it?with?noise?according?to?`amount`"""
??noise?=?torch.rand_like(x)
??amount?=?amount.view(-1,?1,?1,?1)?#?Sort?shape?so?broadcasting?works
??return?x*(1-amount)?+?noise*amount?
讓我們來可視化一下輸出的結(jié)果,以了解是否符合我們的預(yù)期:
#?Plotting?the?input?data
fig,?axs?=?plt.subplots(2,?1,?figsize=(12,?5))
axs[0].set_title('Input?data')
axs[0].imshow(torchvision.utils.make_grid(x)[0],?cmap='Greys')
#?Adding?noise
amount?=?torch.linspace(0,?1,?x.shape[0])?#?Left?to?right?->?more?corruption
noised_x?=?corrupt(x,?amount)
#?Plottinf?the?noised?version
axs[1].set_title('Corrupted?data?(--?amount?increases?-->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0],?cmap='Greys');
當(dāng)噪聲量接近 1 時,我們的數(shù)據(jù)開始看起來像純隨機(jī)噪聲。但對于大多數(shù)的噪聲情況下,您還是可以很好地識別出數(shù)字。你認(rèn)為這是最佳的嗎?
模型
我們想要一個模型,它可以接收 28px 的噪聲圖像,并輸出相同形狀的預(yù)測。一個比較流行的選擇是一個叫做 UNet 的架構(gòu)。最初被發(fā)明用于醫(yī)學(xué)圖像中的分割任務(wù),UNet 由一個“壓縮路徑”和一個“擴(kuò)展路徑”組成?!皦嚎s路徑”會使通過該路徑的數(shù)據(jù)被壓縮,而通過“擴(kuò)展路徑”會將數(shù)據(jù)擴(kuò)展回原始維度(類似于自動編碼器)。模型中的殘差連接也允許信息和梯度在不同層級之間流動。
一些 UNet 的設(shè)計在每個階段都有復(fù)雜的 blocks,但對于這個玩具 demo,我們只會構(gòu)建一個最簡單的示例,它接收一個單通道圖像,并通過下行路徑上的三個卷積層(圖和代碼中的 down_layers)和上行路徑上的 3 個卷積層,在下行和上行層之間具有殘差連接。我們將使用 max pooling 進(jìn)行下采樣和 nn.Upsample 用于上采樣。某些比較復(fù)雜的 UNets 的設(shè)計會使用帶有可學(xué)習(xí)參數(shù)的上采樣和下采樣 layer。下面的結(jié)構(gòu)圖大致展示了每個 layer 的輸出通道數(shù):

代碼實現(xiàn)如下:
class?BasicUNet(nn.Module):
??
"""A?minimal?UNet?implementation."""
??
def?__init__(self,?in_channels=1,?out_channels=1):
????super().__init__()
????self.down_layers?=?torch.nn.ModuleList([?
??????nn.Conv2d(in_channels,?
32
,?kernel_size=
5
,?padding=
2
),
??????nn.Conv2d(
32
,?
64
,?kernel_size=
5
,?padding=
2
),
??????nn.Conv2d(
64
,?
64
,?kernel_size=
5
,?padding=
2
),
????])
????self.up_layers?=?torch.nn.ModuleList([
??????nn.Conv2d(
64
,?
64
,?kernel_size=
5
,?padding=
2
),
??????nn.Conv2d(
64
,?
32
,?kernel_size=
5
,?padding=
2
),
??????nn.Conv2d(
32
,?out_channels,?kernel_size=
5
,?padding=
2
),?
????])
????self.act?=?nn.SiLU()?
#?The?activation?function
????self.downscale?=?nn.MaxPool2d(
2
)
????self.upscale?=?nn.Upsample(scale_factor=
2
)
??
def?forward(self,?x):
????h?=?[]
????
for
?i,?l?
in
?enumerate(self.down_layers):
??????x?=?self.act(l(x))?
#?Through?the?layer?n?the?activation?function
??????
if
?i?<?
2
:?
#?For?all?but?the?third?(final)?down?layer:
????????h.append(x)?
#?Storing?output?for?skip?connection
????????x?=?self.downscale(x)?
#?Downscale?ready?for?the?next?layer
??????????????
????
for
?i,?l?
in
?enumerate(self.up_layers):
??????
if
?i?>?
0
:?
#?For?all?except?the?first?up?layer
????????x?=?self.upscale(x)?
#?Upscale
????????x?+=?h.pop()?
#?Fetching?stored?output?(skip?connection)
????????x?=?self.act(l(x))?
#?Through?the?layer?n?the?activation?function
????????????
????
return
?x
我們可以驗證輸出 shape 是否如我們期望的那樣與輸入相同:
net?=?BasicUNet()
x?=?torch.rand(8,?1,?28,?28)
net(x).shape
torch.Size([8, 1, 28, 28])
該網(wǎng)絡(luò)有 30 多萬個參數(shù):
sum([p.numel()?for?p?in?net.parameters()])
309057
您可以嘗試更改每個 layer 中的通道數(shù)或嘗試不同的結(jié)構(gòu)設(shè)計。
訓(xùn)練模型
那么,模型到底應(yīng)該做什么呢?同樣,對這個問題有各種不同的看法,但對于這個演示,讓我們選擇一個簡單的框架:給定一個損壞的輸入 noisy_x,模型應(yīng)該輸出它對原本 x 的最佳猜測。我們將通過均方誤差將預(yù)測與真實值進(jìn)行比較。
我們現(xiàn)在可以嘗試訓(xùn)練網(wǎng)絡(luò)了。
- 獲取一批數(shù)據(jù)
- 添加隨機(jī)噪聲
- 將數(shù)據(jù)輸入模型
- 將模型預(yù)測與干凈圖像進(jìn)行比較,以計算 loss
- 更新模型的參數(shù)
你可以自由進(jìn)行修改來嘗試獲得更好的結(jié)果!
#?Dataloader?(you?can?mess?with?batch?size)
batch_size?=?128
train_dataloader?=?DataLoader(dataset,?batch_size=batch_size,?shuffle=True)
#?How?many?runs?through?the?data?should?we?do?
n_epochs?=?3
#?Create?the?network
net?=?BasicUNet()
net.to(device)
#?Our?loss?finction
loss_fn?=?nn.MSELoss()
#?The?optimizer
opt?=?torch.optim.Adam(net.parameters(),?lr=1e-3)?
#?Keeping?a?record?of?the?losses?for?later?viewing
losses?=?[]
#?The?training?loop
for?epoch?in?range(n_epochs):
for?x,?y?in?train_dataloader:
? ?#?Get?some?data?and?prepare?the?corrupted?version
? ?x?=?x.to(device)?#?Data?on?the?GPU
? ?noise_amount?=?torch.rand(x.shape[0]).to(device)?#?Pick?random?noise?amounts
? ?noisy_x?=?corrupt(x,?noise_amount)?#?Create?our?noisy?x
? ?#?Get?the?model?prediction
? ?pred?=?net(noisy_x)
? ?#?Calculate?the?loss
? ?loss?=?loss_fn(pred,?x)?#?How?close?is?the?output?to?the?true?'clean'?x?
? ?#?Backprop?and?update?the?params:
? ?opt.zero_grad()
? ?loss.backward()
? ?opt.step()
? ?#?Store?the?loss?for?later
? ?losses.append(loss.item())
? ?#?Print?our?the?average?of?the?loss?values?for?this?epoch:
????avg_loss?=?sum(losses[-len(train_dataloader):])/len(train_dataloader)
????print(f'Finished?epoch?{epoch}.?Average?loss?for?this?epoch:?{avg_loss:05f}')
#?View?the?loss?curve
plt.plot(losses)
plt.ylim(0,?0.1);
Finished epoch 0. Average loss for this epoch: 0.026736
Finished epoch 1. Average loss for this epoch: 0.020692
Finished epoch 2. Average loss for this epoch: 0.018887

我們可以嘗試通過抓取一批數(shù)據(jù),以不同的數(shù)量損壞數(shù)據(jù),然后喂進(jìn)模型獲得預(yù)測來觀察結(jié)果:
#@markdown?Visualizing?model?predictions?on?noisy?inputs:
#?Fetch?some?data
x,?y?=?next(iter(train_dataloader))
x?=?x[:8]?#?Only?using?the?first?8?for?easy?plotting
#?Corrupt?with?a?range?of?amounts
amount?=?torch.linspace(0,?1,?x.shape[0])?#?Left?to?right?->?more?corruption
noised_x?=?corrupt(x,?amount)
#?Get?the?model?predictions
with?torch.no_grad():
??preds?=?net(noised_x.to(device)).detach().cpu()
#?Plot
fig,?axs?=?plt.subplots(3,?1,?figsize=(12,?7))
axs[0].set_title('Input?data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0,?1),?cmap='Greys')
axs[1].set_title('Corrupted?data')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0].clip(0,?1),?cmap='Greys')
axs[2].set_title('Network?Predictions')
axs[2].imshow(torchvision.utils.make_grid(preds)[0].clip(0,?1),?cmap='Greys');

你可以看到,對于較低的噪聲水平數(shù)量,預(yù)測的結(jié)果相當(dāng)不錯!但是,當(dāng)噪聲水平非常高時,模型能夠獲得的信息就開始逐漸減少。而當(dāng)我們達(dá)到 amount = 1 時,模型會輸出一個模糊的預(yù)測,該預(yù)測會很接近數(shù)據(jù)集的平均值。模型通過這樣的方式來猜測原始輸入。
取樣(采樣)
如果我們在高噪聲水平下的預(yù)測不是很好,我們?nèi)绾尾拍苌蓤D像呢?
如果我們從完全隨機(jī)的噪聲開始,檢查一下模型預(yù)測的結(jié)果,然后只朝著預(yù)測方向移動一小部分,比如說 20%?,F(xiàn)在我們有一個噪聲很多的圖像,其中可能隱藏了一些關(guān)于輸入數(shù)據(jù)的結(jié)構(gòu)的提示,我們可以將其輸入到模型中以獲得新的預(yù)測。希望這個新的預(yù)測比第一個稍微好一點(diǎn)(因為我們這一次的輸入稍微減少了一點(diǎn)噪聲),所以我們可以用這個新的更好的預(yù)測再往前邁出一小步。
如果一切順利的話,以上過程重復(fù)幾次以后我們就會得到一個新的圖像!以下圖例是迭代了五次以后的結(jié)果,左側(cè)是每個階段的模型輸入的可視化,右側(cè)則是預(yù)測的去噪圖像。請注意,即使模型在第 1 步就預(yù)測了去噪圖像,我們也只是將輸入向去噪圖像變換了一小部分。重復(fù)幾次以后,圖像的結(jié)構(gòu)開始逐漸出現(xiàn)并得到改善 , 直到獲得我們的最終結(jié)果為止。
#@markdown?Sampling?strategy:?Break?the?process?into?5?steps?and?move?1/5'th?of?the?way?there?each?time:
n_steps?=?5
x?=?torch.rand(8,?1,?28,?28).to(device)?#?Start?from?random
step_history?=?[x.detach().cpu()]
pred_output_history?=?[]
for?i?in?range(n_steps):
?with?torch.no_grad():?#?No?need?to?track?gradients?during?inference
? ?pred?=?net(x)?#?Predict?the?denoised?x0
?pred_output_history.append(pred.detach().cpu())?#?Store?model?output?for?plotting
?mix_factor?=?1/(n_steps?-?i)?#?How?much?we?move?towards?the?prediction
?x?=?x*(1-mix_factor)?+?pred*mix_factor?#?Move?part?of?the?way?there
?step_history.append(x.detach().cpu())?#?Store?step?for?plotting
fig,?axs?=?plt.subplots(n_steps,?2,?figsize=(9,?4),?sharex=True)
axs[0,0].set_title('x?(model?input)')
axs[0,1].set_title('model?prediction')
for?i?in?range(n_steps):
?axs[i,?0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0,?1),?cmap='Greys')
?axs[i,?1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0,?1),?cmap='Greys')

我們可以將流程分成更多步驟,并希望通過這種方式獲得更好的圖像:
#@markdown?Showing?more?results,?using?40?sampling?steps
n_steps?=?40
x?=?torch.rand(64,?1,?28,?28).to(device)
for?i?in?range(n_steps):
??noise_amount?=?torch.ones((x.shape[0],?)).to(device)?*?(1-(i/n_steps))?#?Starting?high?going?low
??with?torch.no_grad():
? ?pred?=?net(x)
??mix_factor?=?1/(n_steps?-?i)
??x?=?x*(1-mix_factor)?+?pred*mix_factor
fig,?ax?=?plt.subplots(1,?1,?figsize=(12,?12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu(),?nrow=8)[0].clip(0,?1),?cmap='Greys')
<matplotlib.image.AxesImage at 0x7f27567d8210>

結(jié)果并不是非常好,但是已經(jīng)出現(xiàn)了一些可以被認(rèn)出來的數(shù)字!您可以嘗試訓(xùn)練更長時間(例如,10 或 20 個 epoch),并調(diào)整模型配置、學(xué)習(xí)率、優(yōu)化器等。此外,如果您想嘗試稍微困難一點(diǎn)的數(shù)據(jù)集,您可以嘗試一下 fashionMNIST,只需要一行代碼的替換就可以了。
與 DDPM 做比較
在本節(jié)中,我們將看看我們的“玩具”實現(xiàn)與其他筆記本中使用的基于 DDPM 論文的方法有何不同: 擴(kuò)散器簡介 Notebook。
-
擴(kuò)散器簡介 Notebook:
https://github.com/huggingface/diffusion-models-class/blob/main/unit1/01_introduction_to_diffusers.ipynb
我們將會看到的
- 模型的表現(xiàn)受限于隨迭代周期 (timesteps) 變化的控制條件,在前向傳導(dǎo)中時間步 (t) 是作為一個參數(shù)被傳入的
- 有很多不同的取樣策略可選擇,可能會比我們上面所使用的最簡單的版本更好
-
diffusers
UNet2DModel比我們的 BasicUNet 更先進(jìn) - 損壞過程的處理方式不同
- 訓(xùn)練目標(biāo)不同,包括預(yù)測噪聲而不是去噪圖像
- 該模型通過調(diào)節(jié) timestep 來調(diào)節(jié)噪聲水平 , 其中 t 作為一個附加參數(shù)傳入前向過程中。
- 有許多不同的采樣策略可供選擇,它們應(yīng)該比我們上面簡單的版本更有效。
自 DDPM 論文發(fā)表以來,已經(jīng)有人提出了許多改進(jìn)建議,但這個例子對于不同的可用設(shè)計決策具有指導(dǎo)意義。讀完這篇文章后,你可能會想要深入了解這篇論文《Elucidating the Design Space of Diffusion-Based Generative Models》,它對所有這些組件進(jìn)行了詳細(xì)的探討,并就如何獲得最佳性能提出了新的建議。
Elucidating the Design Space of Diffusion-Based Generative Models 論文鏈接:
https://arxiv.org/abs/2206.00364
如果你覺得這些內(nèi)容對你來說太過深奧了,請不要擔(dān)心!你可以隨意跳過本筆記本的其余部分或?qū)⑵浔4嬉詡洳粫r之需。
UNet
diffusers 中的 UNet2DModel 模型比上述基本 UNet 模型有許多改進(jìn):
- GroupNorm 層對每個 blocks 的輸入進(jìn)行了組標(biāo)準(zhǔn)化(group normalization)
- Dropout 層能使訓(xùn)練更平滑
- 每個塊有多個 resnet 層(如果 layers_per_block 未設(shè)置為 1)
- 注意機(jī)制(通常僅用于輸入分辨率較低的 blocks)
- timestep 的調(diào)節(jié)。
- 具有可學(xué)習(xí)參數(shù)的下采樣和上采樣塊
讓我們來創(chuàng)建并仔細(xì)研究一下 UNet2DModel:
model?=?UNet2DModel(
??sample_size=28,???????????#?the?target?image?resolution
??in_channels=1,????????????#?the?number?of?input?channels,?3?for?RGB?images
??out_channels=1,???????????#?the?number?of?output?channels
??layers_per_block=2,???????#?how?many?ResNet?layers?to?use?per?UNet?block
??block_out_channels=(32,?64,?64),?#?Roughly?matching?our?basic?unet?example
??down_block_types=(?
????"DownBlock2D",????????#?a?regular?ResNet?downsampling?block
??? "AttnDownBlock2D",????#?a?ResNet?downsampling?block?w/?spatial?self-attention
????"AttnDownBlock2D",
??),?
??up_block_types=(
????"AttnUpBlock2D",?
????"AttnUpBlock2D",??????#?a?ResNet?upsampling?block?with?spatial?self-attention
????"UpBlock2D",??????????#?a?regular?ResNet?upsampling?block
??),
)
print(model)
UNet2DModel(
(conv_in): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_proj): Timesteps()
(time_embedding): TimestepEmbedding(
(linear_1): Linear(in_features=32, out_features=128, bias=True)
(act): SiLU()
(linear_2): Linear(in_features=128, out_features=128, bias=True)
)
(down_blocks): ModuleList(
(0): DownBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(1): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
(downsamplers): ModuleList(
(0): Downsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)
)
)
(2): AttnDownBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
)
(up_blocks): ModuleList(
(0): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(1): AttnUpBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(1): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
(2): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 128, eps=1e-05, affine=True)
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(upsamplers): ModuleList(
(0): Upsample2D(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
)
(2): UpBlock2D(
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 96, eps=1e-05, affine=True)
(conv1): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
(2): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
(norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
(conv_shortcut): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
)
)
(mid_block): UNetMidBlock2D(
(attentions): ModuleList(
(0): AttentionBlock(
(group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
(query): Linear(in_features=64, out_features=64, bias=True)
(key): Linear(in_features=64, out_features=64, bias=True)
(value): Linear(in_features=64, out_features=64, bias=True)
(proj_attn): Linear(in_features=64, out_features=64, bias=True)
)
)
(resnets): ModuleList(
(0): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
(1): ResnetBlock2D(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(time_emb_proj): Linear(in_features=128, out_features=64, bias=True)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(dropout): Dropout(p=0.0, inplace=False)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(nonlinearity): SiLU()
)
)
)
(conv_norm_out): GroupNorm(32, 32, eps=1e-05, affine=True)
(conv_act): SiLU()
(conv_out): Conv2d(32, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
正如你所看到的,還有更多!它比我們的 BasicUNet 有多得多的參數(shù)量:
sum([p.numel()?for?p?in?model.parameters()])?#?1.7M?vs?the?~309k?parameters?of?the?BasicUNet
1707009
我們可以用這個模型代替原來的模型來重復(fù)一遍上面展示的訓(xùn)練過程。我們需要將 x 和 timestep 傳遞給模型(這里我會傳遞 t = 0,以表明它在沒有 timestep 條件的情況下工作,并保持采樣代碼簡單,但您也可以嘗試輸入 (amount*1000),使 timestep 與噪聲水平相當(dāng))。如果要檢查代碼,更改的行將顯示為“#<<<。
#@markdown?Trying?UNet2DModel?instead?of?BasicUNet:
#?Dataloader?(you?can?mess?with?batch?size)
batch_size?=?128
train_dataloader?=?DataLoader(dataset,?batch_size=batch_size,?shuffle=True)
#?How?many?runs?through?the?data?should?we?do?
n_epochs?=?3
#?Create?the?network
net?=?UNet2DModel(
??sample_size=28,??#?the?target?image?resolution
??in_channels=1,??#?the?number?of?input?channels,?3?for?RGB?images
??out_channels=1,??#?the?number?of?output?channels
??layers_per_block=2,??#?how?many?ResNet?layers?to?use?per?UNet?block
??block_out_channels=(32,?64,?64),??#?Roughly?matching?our?basic?unet?example
??down_block_types=(?
????"DownBlock2D",??#?a?regular?ResNet?downsampling?block
????"AttnDownBlock2D",??#?a?ResNet?downsampling?block?with?spatial?self-attention
????"AttnDownBlock2D",
??),?
??up_block_types=(
????"AttnUpBlock2D",?
????"AttnUpBlock2D",??#?a?ResNet?upsampling?block?with?spatial?self-attention
????"UpBlock2D",???#?a?regular?ResNet?upsampling?block
??),
)?#<<<
net.to(device)
#?Our?loss?finction
loss_fn?=?nn.MSELoss()
#?The?optimizer
opt?=?torch.optim.Adam(net.parameters(),?lr=1e-3)?
#?Keeping?a?record?of?the?losses?for?later?viewing
losses?=?[]
#?The?training?loop
for?epoch?in?range(n_epochs):
??for?x,?y?in?train_dataloader:
????#?Get?some?data?and?prepare?the?corrupted?version
????x?=?x.to(device)?#?Data?on?the?GPU
????noise_amount?=?torch.rand(x.shape[0]).to(device)?#?Pick?random?noise?amounts
????noisy_x?=?corrupt(x,?noise_amount)?#?Create?our?noisy?x
????#?Get?the?model?prediction
????pred?=?net(noisy_x,?0).sample?#<<<?Using?timestep?0?always,?adding?.sample
????#?Calculate?the?loss
????loss?=?loss_fn(pred,?x)?#?How?close?is?the?output?to?the?true?'clean'?x?
????#?Backprop?and?update?the?params:
????opt.zero_grad()
????loss.backward()
????opt.step()
????#?Store?the?loss?for?later
????losses.append(loss.item())
????#?Print?our?the?average?of?the?loss?values?for?this?epoch:
????avg_loss?=?sum(losses[-len(train_dataloader):])/len(train_dataloader)
????print(f'Finished?epoch?{epoch}.?Average?loss?for?this?epoch:?{avg_loss:05f}')
#?Plot?losses?and?some?samples
fig,?axs?=?plt.subplots(1,?2,?figsize=(12,?5))
#?Losses
axs[0].plot(losses)
axs[0].set_ylim(0,?0.1)
axs[0].set_title('Loss?over?time')
#?Samples
n_steps?=?40
x?=?torch.rand(64,?1,?28,?28).to(device)
for?i?in?range(n_steps):
??noise_amount?=?torch.ones((x.shape[0],?)).to(device)?*?(1-(i/n_steps))?#?Starting?high?going?low
??with?torch.no_grad():
????pred?=?net(x,?0).sample
??mix_factor?=?1/(n_steps?-?i)
??x?=?x*(1-mix_factor)?+?pred*mix_factor
axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(),?nrow=8)[0].clip(0,?1),?cmap='Greys')
axs[1].set_title('Generated?Samples');
Finished epoch 0. Average loss for this epoch: 0.018925
Finished epoch 1. Average loss for this epoch: 0.012785
Finished epoch 2. Average loss for this epoch: 0.011694

這看起來比我們的第一組結(jié)果好多了!您可以嘗試調(diào)整 UNet 配置或更長時間的訓(xùn)練,以獲得更好的性能。
損壞過程
DDPM 論文描述了一個為每個“timestep”添加少量噪聲的損壞過程。為某些 timestep 給定 , 我們可以得到一個噪聲稍稍增加的 :

這就是說,我們?nèi)?, 給他一個 的系數(shù),然后加上帶有 系數(shù)的噪聲。這里 是根據(jù)一些管理器來為每一個 t 設(shè)定的,來決定每一個迭代周期中添加多少噪聲。現(xiàn)在,我們不想把這個推演進(jìn)行 500 次來得到 ,所以我們用另一個公式來根據(jù)給出的 計算得到任意 t 時刻的 :

數(shù)學(xué)符號看起來總是很嚇人!幸運(yùn)的是,調(diào)度器為我們處理了所有這些(取消下一個單元格的注釋以檢查代碼)。我們可以畫出 (標(biāo)記為 sqrt_alpha_prod) 和 (標(biāo)記為 sqrt_one_minus_alpha_prod) 來看一下輸入 (x) 與噪聲是如何在不同迭代周期中量化和疊加的 :
#??noise_scheduler.add_noise
noise_scheduler?=?DDPMScheduler(num_train_timesteps=1000)
plt.plot(noise_scheduler.alphas_cumprod.cpu()?**?0.5,?label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1?-?noise_scheduler.alphas_cumprod.cpu())?**?0.5,?label=r"$\sqrt{(1?-?\bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

一開始 , 噪聲 x 里絕大部分都是 x 自身的值 ?(sqrt_alpha_prod ~= 1),但是隨著時間的推移,x 的成分逐漸降低而噪聲的成分逐漸增加。與我們根據(jù) amount 對 x 和噪聲進(jìn)行線性混合不同,這個噪聲的增加相對較快。我們可以在一些數(shù)據(jù)上看到這一點(diǎn):
#@markdown?visualize?the?DDPM?noising?process?for?different?timesteps:
#?Noise?a?batch?of?images?to?view?the?effect
fig,?axs?=?plt.subplots(3,?1,?figsize=(16,?10))
xb,?yb?=?next(iter(train_dataloader))
xb?=?xb.to(device)[:8]
xb?=?xb?*?2.?-?1.?#?Map?to?(-1,?1)
print('X?shape',?xb.shape)
#?Show?clean?inputs
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(),?cmap='Greys')
axs[0].set_title('Clean?X')
#?Add?noise?with?scheduler
timesteps?=?torch.linspace(0,?999,?8).long().to(device)
noise?=?torch.randn_like(xb)?#?<<?NB:?randn?not?rand
noisy_xb?=?noise_scheduler.add_noise(xb,?noise,?timesteps)
print('Noisy?X?shape',?noisy_xb.shape)
#?Show?noisy?version?(with?and?without?clipping)
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1,?1),??cmap='Greys')
axs[1].set_title('Noisy?X?(clipped?to?(-1,?1)')
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(),??cmap='Greys')
axs[2].set_title('Noisy?X');
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])

在運(yùn)行中的另一個變化:在 DDPM 版本中,加入的噪聲是取自一個高斯分布(來自均值 0 方差 1 的 torch.randn),而不是在我們原始 corrupt 函數(shù)中使用的 0-1 之間的均勻分布(torch.rand),當(dāng)然對訓(xùn)練數(shù)據(jù)做正則化也可以理解。在另一篇筆記中,你會看到 Normalize(0.5, 0.5) 函數(shù)在變化列表中,它把圖片數(shù)據(jù)從 (0, 1) 區(qū)間映射到 (-1, 1),對我們的目標(biāo)來說也‘足夠用了’。我們在此篇筆記中沒使用這個方法,但在上面的可視化中為了更好的展示添加了這種做法。
訓(xùn)練目標(biāo)
在我們的玩具示例中,我們讓模型嘗試預(yù)測去噪圖像。在 DDPM 和許多其他擴(kuò)散模型實現(xiàn)中,模型則會預(yù)測損壞過程中使用的噪聲(在縮放之前,因此是單位方差噪聲)。在代碼中,它看起來像是這樣:
noise?=?torch.randn_like(xb)?#?<<?NB:?randn?not?rand
noisy_x?=?noise_scheduler.add_noise(x,?noise,?timesteps)
model_prediction?=?model(noisy_x,?timesteps).sample
loss?=?mse_loss(model_prediction,?noise)?#?noise?as?the?target
你可能認(rèn)為預(yù)測噪聲(我們可以從中得出去噪圖像的樣子)等同于直接預(yù)測去噪圖像。那么,為什么要這么做呢?這僅僅是為了數(shù)學(xué)上的方便嗎?
這里其實還有另一些精妙之處。我們在訓(xùn)練過程中,會計算不同(隨機(jī)選擇)timestep 的 loss。這些不同的目標(biāo)將導(dǎo)致這些 loss 的不同的“隱含權(quán)重”,其中預(yù)測噪聲會將更多的權(quán)重放在較低的噪聲水平上。你可以選擇更復(fù)雜的目標(biāo)來改變這種“隱性損失權(quán)重”?;蛘撸x擇的噪聲管理器將在較高的噪聲水平下產(chǎn)生更多的示例。也許你讓模型設(shè)計成預(yù)測 “velocity” v,我們將其定義為由噪聲水平影響的圖像和噪聲組合(請參閱“擴(kuò)散模型快速采樣的漸進(jìn)蒸餾”- 'PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS')。也許你將模型設(shè)計成預(yù)測噪聲,然后基于某些因子來對 loss 進(jìn)行縮放:比如有些理論指出可以參考噪聲水平(參見“擴(kuò)散模型的感知優(yōu)先訓(xùn)練”-'Perception Prioritized Training of Diffusion Models'),或者基于一些探索模型最佳噪聲水平的實驗(參見“基于擴(kuò)散的生成模型的設(shè)計空間說明”-'Elucidating the Design Space of Diffusion-Based Generative Models')。
一句話解釋:選擇目標(biāo)對模型性能有影響,現(xiàn)在有許多研究者正在探索“最佳”選項是什么。目前,預(yù)測噪聲(epsilon 或 eps)是最流行的方法,但隨著時間的推移,我們很可能會看到庫中支持的其他目標(biāo),并在不同的情況下使用。
迭代周期(Timestep)調(diào)節(jié)
UNet2DModel 以 x 和 timestep 為輸入。后者被轉(zhuǎn)化為一個嵌入(embedding),并在多個地方被輸入到模型中。
這背后的理論支持是這樣的:通過向模型提供有關(guān)噪聲水平的信息,它可以更好地執(zhí)行任務(wù)。雖然在沒有這種 timestep 條件的情況下也可以訓(xùn)練模型,但在某些情況下,它似乎確實有助于性能,目前來說絕大多數(shù)的模型實現(xiàn)都包括了這一輸入。
取樣(采樣)
有一個模型可以用來預(yù)測在帶噪樣本中的噪聲(或者說能預(yù)測其去噪版本),我們怎么用它來生成圖像呢?
我們可以給入純噪聲,然后就希望模型能一步就輸出一個不帶噪聲的好圖像。但是,就我們上面所見到的來看,這通常行不通。所以,我們在模型預(yù)測的基礎(chǔ)上使用足夠多的小步,迭代著來每次去除一點(diǎn)點(diǎn)噪聲。
具體我們怎么走這些小步,取決于使用上面取樣方法。我們不會去深入討論太多的理論細(xì)節(jié),但是一些頂層想法是這樣:
- 每一步你想走多大?也就是說,你遵循什么樣的“噪聲計劃(噪聲管理)”?
- 你只使用模型當(dāng)前步的預(yù)測結(jié)果來指導(dǎo)下一步的更新方向嗎(像 DDPM,DDIM 或是其他的什么那樣)?你是否要使用模型來多預(yù)測幾次來估計一個更高階的梯度來更新一步更大更準(zhǔn)確的結(jié)果(更高階的方法和一些離散 ODE 處理器)?或者保留歷史預(yù)測值來嘗試更好的指導(dǎo)當(dāng)前步的更新(線性多步或遺傳取樣器)?
- 你是否會在取樣過程中額外再加一些隨機(jī)噪聲,或你完全已知的(deterministic)來添加噪聲?許多取樣器通過參數(shù)(如 DDIM 中的 'eta')來供用戶選擇。
對于擴(kuò)散模型取樣器的研究演進(jìn)的很快,隨之開發(fā)出了越來越多可以使用更少步就找到好結(jié)果的方法。勇敢和有好奇心的人可能會在瀏覽 diffusers library 中不同部署方法時感到非常有意思,可以查看 Schedulers 代碼 或看看 Schedulers 文檔,這里經(jīng)常有一些相關(guān)的論文。
-
Schedulers 代碼:
https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers -
Schedulers 文檔:
https://huggingface.co/docs/diffusers/main/en/api/schedulers
結(jié)語
希望這可以從一些不同的角度來審視擴(kuò)散模型提供一些幫助。這篇筆記是 Jonathan Whitaker 為 Hugging Face 課程所寫的,如果你對從噪聲和約束分類來生成樣本的例子感興趣。問題與 bug 可以通過 GitHub issues 或 Discord 來交流。
致謝第一單元第二部分社區(qū)貢獻(xiàn)者
感謝社區(qū)成員們對本課程的貢獻(xiàn):
@darcula1993、@XhrLeokk:魔都強(qiáng)人工智能孵化者,二里街調(diào)參記錄保持人,一切興趣使然的 AIGC 色圖創(chuàng)作家的庇護(hù)者,圖靈神在五角場的唯一指定路上行走。
感謝茶葉蛋蛋對本文貢獻(xiàn)設(shè)計素材!
歡迎通過鏈接加入我們的本地化小組與大家共同交流:
https://bit.ly/3G40j6U
推薦閱讀
輔助模塊加速收斂,精度大幅提升!移動端實時的NanoDet-Plus來了!
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個用心的公眾號

