TransGAN:兩個Transformer可以構造一個強大的GAN
01
GAN
不寫GAN的優(yōu)化公式,看起來更迷糊,直接把GAN的訓練過程闡述一下就清清楚楚了。

GAN由一個生成器G和一個判別器D構成。隨機輸入通過G得到生成圖片,然后將真實圖片和生成圖片都送入判別器D中進行判斷。
GAN的訓練過程:

第一步先將生成器G參數(shù)固定住,然后將隨機輸入通過生成器G得到生成圖片,最后更新判別器D的參數(shù),將真實圖片盡可能的判別為1,將生成圖片盡可能的判別為0。(即需要得到更強判別力的判別器)

第二步先將判別器D參數(shù)固定住,然后將隨機輸入通過需要更新參數(shù)的生成器G得到生成圖片,最后通過固定參數(shù)的判別器D盡可能的將生成圖片判別為1。(即需要得到當前判別器認為的更接近真實圖片的生成圖片)
然后循環(huán)上述兩個步驟,生成器G能夠產(chǎn)生越來越接近真實圖片的生成圖片。
02
TransGAN

TransGAN整體的原理與GAN相同,主要的不同是Generator和Discriminator都是用Transformer構造的。
Memory-Friendly Generator
Generator的設計是顯存友好的,由多個stage組成,每個stage由幾個Transformer Encoder堆疊形成。隨著stage的增加,不斷的增加feature map的分辨率,直到和目標分辨率相同。以左圖為例,將隨機noise作為輸入,通過多個MLP轉化成長度為8x8xC的向量,然后將向量reshape成8x8分辨率的feature map,每個點的embedding維度為C。隨后將64個維度為C的tokens和可學習的positional encoding相加,送入Transoformer Encoder。為了逐漸得到更大分辨率的feature map,在每個stage之后插入一個UpScaling操作(除了最后一個stage)。最后一個stage后面接一個Linear Unflatten將每個token的維度轉化為3,然后reshape成二維RGB圖片(左圖最終的目標維度為32x32x3)。
其中UpScaling操作由兩個reshape和一個pixelshuffle組成。先reshape將一維tokens序列重新組成二維feature map(假設維度維HxWxC),然后通過一個pixelshuffle操作將二維的feature map維度變成2Hx2WxC/4,最后再reshape成4HW個tokens序列,每個token的維度為C/4。
Tokenized-Input For Discriminator
Discriminator部分如右圖所示。先將輸入圖片拆分成8x8個patches,然后通過Linear Flatten轉化成64個維度為C的token embeddings,然后加上可學習的positional encoding,并且在tokens序列增加一個[cls] token,最后在[cls] token對應的輸出位置增加head來判斷是否為真實圖片。
Evaluation Of Transformer-Based GAN
簡單說明一下GAN的評估指標IS和FID。IS指標可以用來衡量生成樣本的多樣性和準確性。FID指標可以用來衡量真實樣本和生成樣本特征空間的距離。所以IS指標越大越好,F(xiàn)ID指標越小越好。

作者通過AutoGAN和Transformer的排列組合,發(fā)現(xiàn)Transformer+AutoGAN的IS最好,AutoGAN+AutoGAN的FID最好,說明Generator使用Transformer是有效的,但是Discriminator使用Transformer會損害GAN的性能。
Three Tricks
那么能不能Generator和Discriminator都使用Transformer,并且提升GAN的性能呢?作者通過三個tricks來進一步的進行探索。
Data Augmentation is Crucial for TransGAN

作者對三個CNN-based GAN和TransGAN進行數(shù)據(jù)增強的實驗,發(fā)現(xiàn)數(shù)據(jù)增強對于TransGAN的收益是最大的。
Co-Training with Self-Supervised Auxiliary Task

作者在TransGAN中增加了超分任務的輔助訓練,具體的是在stage2同時輸入低分辨率的圖片,然后stage3同時輸出高分辨率的圖片。
Locality-Aware Initialization for Self-Attention

作者還提出了一個隨著訓練epoch的增加逐漸增加Self-Attention感受野的trick,具體的在訓練早期Self-Attention操作mask掉大部分區(qū)域(即紅點query做相關性計算的key范圍是非mask區(qū)域),這些mask的部分不進行計算,然后訓練中期mask區(qū)域逐漸減小,直到訓練末期不進行mask。

作者進行消融實驗發(fā)現(xiàn)輔助訓練和Self-Attention局部初始化都能穩(wěn)定提升TransGAN的性能。
03
實驗結果

最終TransGAN在STL-10數(shù)據(jù)集上達到SOTA性能,在cifar10和CelebA數(shù)據(jù)集上達到和SOTA接近的性能。

最終可視化TransGAN的結果,生成圖像的質量還是非常細膩的。
總體上TransGAN把基于Transformer的GAN做work了,同時提出了3個有效的tricks,效果跟之前最好的方法相當。
Reference
[1] Generative Adversarial Networks | Generative Models (analyticsvidhya.com)
[2] TransGAN: Two Transformers Can Make One Strong GAN
?------------------------------------------------
雙一流大學研究生團隊創(chuàng)建,一個專注于目標檢測與深度學習的組織,希望可以將分享變成一種習慣。
整理不易,點贊三連!
