何愷明一作MAE收錄CVPR2022 oral!(附源碼實(shí)現(xiàn))
點(diǎn)藍(lán)色字關(guān)注“機(jī)器學(xué)習(xí)算法工程師”
設(shè)為星標(biāo),干貨直達(dá)!
近日,F(xiàn)AIR的最新論文Masked Autoencoders Are Scalable Vision Learners(何愷明一作)提出了一種更簡(jiǎn)單有效的用于ViT無(wú)監(jiān)督訓(xùn)練的方法MAE,并在ImageNet-1K數(shù)據(jù)集上的top-1 acc達(dá)到新的SOTA:87.8%(無(wú)額外訓(xùn)練數(shù)據(jù))。自從ViT火了之后,一些研究者就開始嘗試研究ViT的無(wú)監(jiān)督學(xué)習(xí),比如Mocov3用對(duì)比學(xué)習(xí)的方法無(wú)監(jiān)督訓(xùn)練ViT,此外也有一些研究開始借鑒BERT中的MLM(masked language modeling)方法,比如BEiT提出了用于圖像的無(wú)監(jiān)督學(xué)習(xí)方法:MIM(masked image modeling)。無(wú)疑,MAE方法也落在MIM的范疇,但整個(gè)論文會(huì)給人更震撼之感,因?yàn)镸AE方法更簡(jiǎn)單有效。
NLP領(lǐng)域的BERT提出的預(yù)訓(xùn)練方法本質(zhì)上也是一種masked autoencoding:去除數(shù)據(jù)的一部分然后學(xué)習(xí)恢復(fù)。這種masked autoencoding方法也很早就在圖像領(lǐng)域應(yīng)用,比如Stacked Denoising Autoencoders。但是NLP領(lǐng)域已經(jīng)在BERT之后采用這種方法在無(wú)監(jiān)督學(xué)習(xí)上取得非常大的進(jìn)展,比如目前已經(jīng)可以訓(xùn)練超過1000億參數(shù)的大模型,但是圖像領(lǐng)域卻遠(yuǎn)遠(yuǎn)落后,而且目前主流的無(wú)監(jiān)督訓(xùn)練還是對(duì)比學(xué)習(xí)。那么究竟是什么造成了masked autoencoding方法在NLP和CV上的差異呢?MAE論文從三個(gè)方面做了分析,這也是MAE方法的立意:
圖像的主流模型是CNN,而NLP的主流模型是transformer,CNN和transformer的架構(gòu)不同導(dǎo)致NLP的BERT很難直接遷移到CV。但是vision transformer的出現(xiàn)已經(jīng)解決這個(gè)問題; 圖像和文本的信息密度不同,文本是高語(yǔ)義的人工創(chuàng)造的符號(hào),而圖像是一種自然信號(hào),兩者采用masked autoencoding建模任務(wù)難度就不一樣,從句子中預(yù)測(cè)丟失的詞本身就是一種復(fù)雜的語(yǔ)言理解任務(wù),但是圖像存在很大的信息冗余,一個(gè)丟失的圖像塊很容易利用周邊的圖像區(qū)域進(jìn)行恢復(fù); 用于重建的decoder在圖像和文本任務(wù)發(fā)揮的角色有區(qū)別,從句子中預(yù)測(cè)單詞屬于高語(yǔ)義任務(wù),encoder和decoder的gap小,所以BERT的decoder部分微不足道(只需要一個(gè)MLP),而對(duì)圖像重建像素屬于低語(yǔ)義任務(wù)(相比圖像分類),encoder需要發(fā)揮更大作用:將高語(yǔ)義的中間表征恢復(fù)成低語(yǔ)義的像素值。
基于這三個(gè)的分析,論文提出了一種用于圖像領(lǐng)域(ViT模型)的更簡(jiǎn)單有效的無(wú)監(jiān)督訓(xùn)練方法:MAE(masked autoencoder),隨機(jī)mask掉部分patchs然后進(jìn)行重建,其整體架構(gòu)如下所示。MAE采用encoder-decoder結(jié)構(gòu)(分析3,需要單獨(dú)的decoder),但屬于非對(duì)稱結(jié)構(gòu),一方面decoder采用比encoder更輕量級(jí)設(shè)計(jì),另外一方面encoder只處理一部分patchs(visible patchs,除了masked patchs之外的patchs),而decoder處理所有的patchs。一個(gè)很重要的點(diǎn),MAE采用很高的masking ratio(比如75%甚至更高),這契合分析2,這樣構(gòu)建的學(xué)習(xí)任務(wù)大大降低了信息冗余,也使得encoder能學(xué)習(xí)到更高級(jí)的特征。由于encoder只處理visible patchs,所以很高的masking ratio可以大大降低計(jì)算量。

MAE采用的masking策略是簡(jiǎn)單的隨機(jī)mask:基于均勻分布從圖像的patchs隨機(jī)抽樣一部分patchs進(jìn)行mask。每個(gè)被mask的patch采用mask token來替代,mask token是一個(gè)共享且可學(xué)習(xí)的向量。MAE的encoder采用ViT模型,只處理visible patchs,visible patchs通過linear projection得到patch embedding輸入到ViT的transformer blocks進(jìn)行處理;而decoder是一個(gè)輕量級(jí)模塊,主體包含幾個(gè)transformer blocks,而最后一層是一個(gè)linear層(輸出是和一個(gè)patch像素?cái)?shù)一致),用來直接預(yù)測(cè)masked patch的像素值。decoder的輸入是所有的tokens:encoded visible patchs和mask tokens,它們要加上對(duì)應(yīng)的positional embeddings。訓(xùn)練的loss采用簡(jiǎn)單的MSE:計(jì)算預(yù)測(cè)像素值和原始像素值的均方誤差,不過loss只計(jì)算masked patchs。MAE的實(shí)現(xiàn)非常簡(jiǎn)單:首先對(duì)輸入的patch進(jìn)行l(wèi)inear projection得到patch embeddings,并加上positional embeddings(采用sine-cosine版本);然后對(duì)tokens列表進(jìn)行random shuffle,根據(jù)masking ratio去掉列表中后面的一部分tokens,然后送入encoder中,這里注意ViT中需要一個(gè)class token來做圖像分類,所以這里的輸入也要增加一個(gè)dummy token(如果最后分類采用global avg pooling就不需要這個(gè));encoder處理后,在tokens列表后面補(bǔ)足mask tokens,然后通過unshuffle來恢復(fù)tokens列表中tokens的原始位置,然后再加上positional embeddings(mask tokens本身并無(wú)位置信息,所以還要此操作)送入decoder中進(jìn)行處理。具體的代碼實(shí)現(xiàn)如下:
class?MaskedAutoencoderViT(nn.Module):
????"""?Masked?Autoencoder?with?VisionTransformer?backbone
????"""
????def?__init__(self,?img_size=224,?patch_size=16,?in_chans=3,
?????????????????embed_dim=1024,?depth=24,?num_heads=16,
?????????????????decoder_embed_dim=512,?decoder_depth=8,?decoder_num_heads=16,
?????????????????mlp_ratio=4.,?norm_layer=nn.LayerNorm,?norm_pix_loss=False):
????????super().__init__()
????????#?--------------------------------------------------------------------------
????????#?MAE?encoder?specifics
????????self.patch_embed?=?PatchEmbed(img_size,?patch_size,?in_chans,?embed_dim)
????????num_patches?=?self.patch_embed.num_patches
????????self.cls_token?=?nn.Parameter(torch.zeros(1,?1,?embed_dim))
????????self.pos_embed?=?nn.Parameter(torch.zeros(1,?num_patches?+?1,?embed_dim),?requires_grad=False)??#?fixed?sin-cos?embedding
????????self.blocks?=?nn.ModuleList([
????????????Block(embed_dim,?num_heads,?mlp_ratio,?qkv_bias=True,?qk_scale=None,?norm_layer=norm_layer)
????????????for?i?in?range(depth)])
????????self.norm?=?norm_layer(embed_dim)
????????#?--------------------------------------------------------------------------
????????#?--------------------------------------------------------------------------
????????#?MAE?decoder?specifics
????????self.decoder_embed?=?nn.Linear(embed_dim,?decoder_embed_dim,?bias=True)
????????self.mask_token?=?nn.Parameter(torch.zeros(1,?1,?decoder_embed_dim))
????????self.decoder_pos_embed?=?nn.Parameter(torch.zeros(1,?num_patches?+?1,?decoder_embed_dim),?requires_grad=False)??#?fixed?sin-cos?embedding
????????self.decoder_blocks?=?nn.ModuleList([
????????????Block(decoder_embed_dim,?decoder_num_heads,?mlp_ratio,?qkv_bias=True,?qk_scale=None,?norm_layer=norm_layer)
????????????for?i?in?range(decoder_depth)])
????????self.decoder_norm?=?norm_layer(decoder_embed_dim)
????????self.decoder_pred?=?nn.Linear(decoder_embed_dim,?patch_size**2?*?in_chans,?bias=True)?#?encoder?to?decoder
????????#?--------------------------------------------------------------------------
????????self.norm_pix_loss?=?norm_pix_loss
????????self.initialize_weights()
????def?patchify(self,?imgs):
????????"""
????????imgs:?(N,?3,?H,?W)
????????x:?(N,?L,?patch_size**2?*3)
????????"""
????????p?=?self.patch_embed.patch_size[0]
????????assert?imgs.shape[2]?==?imgs.shape[3]?and?imgs.shape[2]?%?p?==?0
????????h?=?w?=?imgs.shape[2]?//?p
????????x?=?imgs.reshape(shape=(imgs.shape[0],?3,?h,?p,?w,?p))
????????x?=?torch.einsum('nchpwq->nhwpqc',?x)
????????x?=?x.reshape(shape=(imgs.shape[0],?h?*?w,?p**2?*?3))
????????return?x
????def?unpatchify(self,?x):
????????"""
????????x:?(N,?L,?patch_size**2?*3)
????????imgs:?(N,?3,?H,?W)
????????"""
????????p?=?self.patch_embed.patch_size[0]
????????h?=?w?=?int(x.shape[1]**.5)
????????assert?h?*?w?==?x.shape[1]
????????
????????x?=?x.reshape(shape=(x.shape[0],?h,?w,?p,?p,?3))
????????x?=?torch.einsum('nhwpqc->nchpwq',?x)
????????imgs?=?x.reshape(shape=(x.shape[0],?3,?h?*?p,?h?*?p))
????????return?imgs
????def?random_masking(self,?x,?mask_ratio):
????????"""
????????Perform?per-sample?random?masking?by?per-sample?shuffling.
????????Per-sample?shuffling?is?done?by?argsort?random?noise.
????????x:?[N,?L,?D],?sequence
????????"""
????????N,?L,?D?=?x.shape??#?batch,?length,?dim
????????len_keep?=?int(L?*?(1?-?mask_ratio))
????????
????????noise?=?torch.rand(N,?L,?device=x.device)??#?noise?in?[0,?1]
????????
????????#?sort?noise?for?each?sample
????????ids_shuffle?=?torch.argsort(noise,?dim=1)??#?ascend:?small?is?keep,?large?is?remove
????????ids_restore?=?torch.argsort(ids_shuffle,?dim=1)
????????#?keep?the?first?subset
????????ids_keep?=?ids_shuffle[:,?:len_keep]
????????x_masked?=?torch.gather(x,?dim=1,?index=ids_keep.unsqueeze(-1).repeat(1,?1,?D))
????????#?generate?the?binary?mask:?0?is?keep,?1?is?remove
????????mask?=?torch.ones([N,?L],?device=x.device)
????????mask[:,?:len_keep]?=?0
????????#?unshuffle?to?get?the?binary?mask
????????mask?=?torch.gather(mask,?dim=1,?index=ids_restore)
????????return?x_masked,?mask,?ids_restore
????def?forward_encoder(self,?x,?mask_ratio):
????????#?embed?patches
????????x?=?self.patch_embed(x)
????????#?add?pos?embed?w/o?cls?token
????????x?=?x?+?self.pos_embed[:,?1:,?:]
????????#?masking:?length?->?length?*?mask_ratio
????????x,?mask,?ids_restore?=?self.random_masking(x,?mask_ratio)
????????#?append?cls?token
????????cls_token?=?self.cls_token?+?self.pos_embed[:,?:1,?:]
????????cls_tokens?=?cls_token.expand(x.shape[0],?-1,?-1)
????????x?=?torch.cat((cls_tokens,?x),?dim=1)
????????#?apply?Transformer?blocks
????????for?blk?in?self.blocks:
????????????x?=?blk(x)
????????x?=?self.norm(x)
????????return?x,?mask,?ids_restore
????def?forward_decoder(self,?x,?ids_restore):
????????#?embed?tokens
????????x?=?self.decoder_embed(x)
????????#?append?mask?tokens?to?sequence
????????mask_tokens?=?self.mask_token.repeat(x.shape[0],?ids_restore.shape[1]?+?1?-?x.shape[1],?1)
????????x_?=?torch.cat([x[:,?1:,?:],?mask_tokens],?dim=1)??#?no?cls?token
????????x_?=?torch.gather(x_,?dim=1,?index=ids_restore.unsqueeze(-1).repeat(1,?1,?x.shape[2]))??#?unshuffle
????????x?=?torch.cat([x[:,?:1,?:],?x_],?dim=1)??#?append?cls?token
????????#?add?pos?embed
????????x?=?x?+?self.decoder_pos_embed
????????#?apply?Transformer?blocks
????????for?blk?in?self.decoder_blocks:
????????????x?=?blk(x)
????????x?=?self.decoder_norm(x)
????????#?predictor?projection
????????x?=?self.decoder_pred(x)
????????#?remove?cls?token
????????x?=?x[:,?1:,?:]
????????return?x
????def?forward_loss(self,?imgs,?pred,?mask):
????????"""
????????imgs:?[N,?3,?H,?W]
????????pred:?[N,?L,?p*p*3]
????????mask:?[N,?L],?0?is?keep,?1?is?remove,?
????????"""
????????target?=?self.patchify(imgs)
????????if?self.norm_pix_loss:
????????????mean?=?target.mean(dim=-1,?keepdim=True)
????????????var?=?target.var(dim=-1,?keepdim=True)
????????????target?=?(target?-?mean)?/?(var?+?1.e-6)**.5
????????loss?=?(pred?-?target)?**?2
????????loss?=?loss.mean(dim=-1)??#?[N,?L],?mean?loss?per?patch
????????loss?=?(loss?*?mask).sum()?/?mask.sum()??#?mean?loss?on?removed?patches
????????return?loss
????def?forward(self,?imgs,?mask_ratio=0.75):
????????latent,?mask,?ids_restore?=?self.forward_encoder(imgs,?mask_ratio)
????????pred?=?self.forward_decoder(latent,?ids_restore)??#?[N,?L,?p*p*3]
????????loss?=?self.forward_loss(imgs,?pred,?mask)
????????return?loss,?pred,?mask
論文選擇ViT-Large(ViT-L/16)作為encoder在ImageNet-1K上實(shí)驗(yàn),首先進(jìn)行無(wú)監(jiān)督預(yù)訓(xùn)練,然后進(jìn)行監(jiān)督訓(xùn)練以評(píng)估encoder的表征能力,包括常用linear probing和finetune兩個(gè)實(shí)驗(yàn)結(jié)果。下表是baseline MAE方法的實(shí)驗(yàn)結(jié)果,可以看到經(jīng)過MAE預(yù)訓(xùn)練后finetune的效果要超過直接從頭訓(xùn)練(84.9 vs 82.5):
更重要的是,論文做了MAE各個(gè)部分的不同設(shè)置對(duì)比實(shí)驗(yàn),這些實(shí)驗(yàn)?zāi)軌蚪沂綧AE更多的特性。首先是masking ratio,從下圖可以看到,最優(yōu)的設(shè)置是75%的masking ratio,此時(shí)linear probing和finetune效果最好,這比之前的研究要高很多,比如BEiT的masking ratio是40%。另外也可以看到linear probing和finetune的表現(xiàn)不一樣,linear probing效果隨著masking ratio的增加逐漸提高直至一個(gè)峰值后出現(xiàn)下降,而finetune效果在不同making ratio下差異小,masking ratio在40%~80%范圍內(nèi)均能表現(xiàn)較好。
這么高的masking ratio,模型到底能學(xué)習(xí)到什么?這里采用預(yù)訓(xùn)練好的模型在驗(yàn)證集進(jìn)行重建,效果如下所示,可以看到decoder重建出來的圖像還是比較讓人驚艷的(95%的masking ratio竟然也能work?。?,這或許說明模型已經(jīng)學(xué)習(xí)到比較好的特征。
第二個(gè)是encoder的設(shè)計(jì),這里主要探討decoder的深度(transformer blocks數(shù)量)和寬度(channels數(shù)量)對(duì)效果的影響,實(shí)驗(yàn)結(jié)果如下表所示。首先,要想得到比較好的linear probing效果,就需要一個(gè)比較深的decoder,這不難理解,前面說過重建圖像和圖像識(shí)別兩個(gè)任務(wù)的gap較大,如果decoder比較深,那么decoder就有足夠的容量學(xué)習(xí)到重建能力,這樣encoder可以更專注于提取特征。但是不同的深度對(duì)finetune效果影響較小,只用一個(gè)transformer block就可以work。相比之下,網(wǎng)絡(luò)寬度對(duì)linear probing影響比網(wǎng)絡(luò)深度要小一點(diǎn)。論文選擇的默認(rèn)設(shè)置是:8個(gè)blocks,width為512,一個(gè)token的FLOPs只有encoder的9%。
第三個(gè)是mask token,這里探討的是encoder是否處理mask tokens帶來的影響,從對(duì)比實(shí)驗(yàn)來看,encoder不處理mask tokens不僅效果更好而且訓(xùn)練更高效,首先linear probing的效果差異非常大,如果encoder也處理mask tokens,此時(shí)linear probing的效果較差,這主要是訓(xùn)練和測(cè)試的不一致帶來的,因?yàn)闇y(cè)試時(shí)都是正常的圖像,但經(jīng)過finetune后也能得到較好的效果。最重要的是,不處理mask tokens模型的FLOPs大大降低(3.3x),而且訓(xùn)練也能加速2.8倍,這里也可以看到采用較小的decoder可以進(jìn)一步加速訓(xùn)練。
第四個(gè)是探討不同的重建目標(biāo)對(duì)效果的影響,從對(duì)比實(shí)驗(yàn)看,如果對(duì)像素值做歸一化處理(用patch所有像素點(diǎn)的MAEn和std),效果有一定提升,采用PCA處理效果無(wú)提升。這里也實(shí)驗(yàn)了BEiT采用的dVAE tokenizer,此時(shí)訓(xùn)練loss是交叉熵,從效果上看比baseline有一定提升(finetune有提升,但是linear probing下降),但不如歸一化處理的結(jié)果。注意的是dVAE tokenizer需要非常大的數(shù)據(jù)來單獨(dú)訓(xùn)練,這是非常不方便的。
第五個(gè)是數(shù)據(jù)增強(qiáng)的影響,這里讓人驚奇的是MAE在無(wú)數(shù)據(jù)增強(qiáng)下(center crop)依然可以表現(xiàn)出好的效果,如果采用random crop(固定size或隨機(jī)size)+random horizontal flipping(其實(shí)也屬于輕量級(jí))效果有微弱的提升,但加上color jit效果反而有所下降。相比之下,對(duì)比學(xué)習(xí)往往需要非常heavy的數(shù)據(jù)增強(qiáng)。這差異的背后主要是因?yàn)镸AE采用的random mask patch已經(jīng)起到了數(shù)據(jù)增強(qiáng)的效果。
第六個(gè)是mask sampling策略的影響,相比BEiT采用的block-wise或grid-wise方式,random sampling效果最好。
另外,論文也發(fā)現(xiàn)MAE和對(duì)比學(xué)習(xí)方法在training schedule上也存在差異,之前的實(shí)驗(yàn)都是基于800 epoch的訓(xùn)練時(shí)長(zhǎng),而實(shí)驗(yàn)發(fā)現(xiàn)訓(xùn)練到更長(zhǎng)的epoch(1600 epoch+),模型的linear probing性能依然還在上升,而MoCoV3在300 epoch后就飽和了。不過,MAE在75%的masking ratio下每個(gè)epoch其實(shí)只相當(dāng)于見了25%的數(shù)據(jù),而對(duì)比學(xué)習(xí)往往學(xué)習(xí)two-crop和multi-crop,每個(gè)epoch見到的數(shù)據(jù)在200%以上,這也意味著MAE可以訓(xùn)練更多的epoch。雖然MAE訓(xùn)練更長(zhǎng),但是由于其特殊的設(shè)置,基于ViT-L的MAE訓(xùn)練1600 epoch的時(shí)長(zhǎng)比MoCoV3訓(xùn)練300 epoch還要短(31h vs 36h)。

MAE與其它無(wú)監(jiān)督方法的對(duì)比如下所示,可以看到在同樣條件下MAE要比BEiT更好,而且也超過有監(jiān)督訓(xùn)練,其中ViT-H在448大小finetune后在ImageNet上達(dá)到了87.8%的top1 acc。不過MAE的效果還是比谷歌采用JFT300M訓(xùn)練的ViT要差一些,這說明訓(xùn)練數(shù)據(jù)量可能是一個(gè)瓶頸。在linear probing方面,MAE要比其它的MIMI方法要好很多,前面已經(jīng)說過,這主要?dú)w功于encoder不處理mask tokens。
在魯棒性方面,論文測(cè)試了幾種ImageNet數(shù)據(jù)集的變種,從下表可以看到,相比直接有監(jiān)督訓(xùn)練模型,基于MAE先預(yù)訓(xùn)練再finetune的模型魯棒性更好。比如在ImageNet-A數(shù)據(jù)集上,基于MAE的ViT-H模型的top1-acc遠(yuǎn)高于有監(jiān)督模型(68.2% vs 33.1%)。
同時(shí),論文也對(duì)比了MAE訓(xùn)練的encoder在下游任務(wù)(檢測(cè)和分割)的遷移能力,同等條件下,MAE均能超過有監(jiān)督訓(xùn)練或者其它無(wú)監(jiān)督訓(xùn)練方法:
這里要注意的一點(diǎn)是檢測(cè)和分割模型需要多尺度的特征(即FPN),而ViT模型只輸出一種尺度的特征(比如1/16大小特征),這里采用XCiT所提出的一種簡(jiǎn)單策略來產(chǎn)生多尺度特征,即對(duì)ViT的中間特征進(jìn)行上采樣和下采樣。這里以Mask R-CNN模型為例,它需要提出backbone的1/4,1/8,1/16和1/32共4個(gè)level的特征,而ViT16只輸出1/16的特征,這里將ViT的transformer blocks均分成4個(gè)部分,假定d為ViT的blocks數(shù)量,那么分別用位置為d/4,2d/4,3d/4和d的block的輸出來提取特征,這里位置為d/4的block的輸出需要上采樣4x才能得到1/4大小的特征,可以通過兩個(gè)stride=2的2x2反卷積操作來實(shí)現(xiàn)(第一個(gè)反卷積后接GN和GeLU),而位置為2d/4的block的輸出只需要一個(gè)stride=2的2x2反卷積就能得到1/8大小的特征,對(duì)于位置為3d/4的block的輸出則不需要任何操作,最后一個(gè)block的輸出可以通過stride=2的2x2 max-pooling來產(chǎn)生1/32特征。(具體見論文Benchmarking Detection Transfer Learning with Vision Transformers)

論文最后還有一個(gè)額外的部分,那就是對(duì)linear probing評(píng)估方式的討論。從前面的實(shí)驗(yàn)我們看到,雖然MAE訓(xùn)練的encoder在finetune下能取得比較SOTA的結(jié)果,但是其linear probing和finetune效果存在不小的差異,單從linear probing效果來看,MAE并不比MoCoV3要好(ViT-L:73.5 vs 77.6)。雖然linear probing一直是無(wú)監(jiān)督訓(xùn)練的最常用的評(píng)估方法,但是它追求的是encoder提取特征的線性可分能力,這不并能成為唯一的一個(gè)評(píng)價(jià)指標(biāo),而且linear probing也不能很好地和下游任務(wù)遷移能力關(guān)聯(lián)起來。所以論文額外做了partial fine-tuning的實(shí)驗(yàn),這里可以看到如果僅對(duì)encoder的最后一個(gè)block進(jìn)行finetune的話,MAE就能達(dá)到和MoCoV3一樣的效果,如果finetune更多的blocks,MAE就會(huì)超過MoCoV3。這說明雖然MAE得到的特征線性可分能力差了點(diǎn),但是它其實(shí)是更強(qiáng)的非線性特征。
最后談一點(diǎn)自己對(duì)MAE的認(rèn)識(shí):首先MAE并不是第一個(gè)基于MIM方法做無(wú)監(jiān)督訓(xùn)練,之前微軟的BEiT基于MIM也取得了很好的效果,還有MST和iBOT等工作。但是MAE讓人看起來更簡(jiǎn)單有效,比如BEiT需要單獨(dú)訓(xùn)練的tokenizer,而其它的一些工作往往引入了對(duì)比學(xué)習(xí)的類似設(shè)計(jì)。對(duì)于MAE的成功,我覺得是一些突破常規(guī)的設(shè)計(jì),比如很高的masking ratio,這是很難想象會(huì)work的,但MAE卻證明了這是成功的關(guān)鍵。
參考
Mocov3: An Empirical Study of Training Self-Supervised Vision Transformers DINO: Emerging Properties in Self-Supervised Vision Transformers MST: Masked Self-Supervised Transformer for Visual Representation BEiT: BERT Pre-Training of Image Transformers EsViT: Efficient Self-supervised Vision Transformers for Representation Learning Image BERT Pre-training with Online Tokenizer Masked Autoencoders Are Scalable Vision Learners https://github.com/facebookresearch/mae
推薦閱讀
RegNet:設(shè)計(jì)網(wǎng)絡(luò)設(shè)計(jì)空間
PyTorch1.10發(fā)布:ZeroRedundancyOptimizer和Join
谷歌AI用30億數(shù)據(jù)訓(xùn)練了一個(gè)20億參數(shù)Vision Transformer模型,在ImageNet上達(dá)到新的SOTA!
Transformer在語(yǔ)義分割上的應(yīng)用
"未來"的經(jīng)典之作ViT:transformer is all you need!
PVT:可用于密集任務(wù)backbone的金字塔視覺transformer!
漲點(diǎn)神器FixRes:兩次超越ImageNet數(shù)據(jù)集上的SOTA
不妨試試MoCo,來替換ImageNet上pretrain模型!
機(jī)器學(xué)習(xí)算法工程師
? ??? ? ? ? ? ? ? ? ? ? ????????? ??一個(gè)用心的公眾號(hào)

