漲點明顯 | 港中文等提出SplitNet結合Co-Training提升Backbone性能

??新智元報道??
??新智元報道??
來源:AI人工智能初學者
作者:王浩帆
【新智元導讀】本文提出網(wǎng)絡的"數(shù)目"應該是有效模型縮放的新維度,也因此提出了SplitNet,這是首次提出此類討論的工作,可引入現(xiàn)有CNN網(wǎng)絡中,漲點明顯!

背景介紹
神經(jīng)網(wǎng)絡的寬度很重要,因為增加寬度必然會增加模型的容量。但是,網(wǎng)絡的性能不會隨著寬度的增加而線性提高,并且很快就會飽和。

為了解決這個問題,作者提出增加網(wǎng)絡的數(shù)量,而不是單純地擴大寬度。為了證明這一點,將一個大型網(wǎng)絡劃分為幾個小型網(wǎng)絡,每個小型網(wǎng)絡都具有原始參數(shù)的一小部分。
然后將這些小型網(wǎng)絡一起訓練,并使它們看到相同數(shù)據(jù)的各種視圖,以學習不同的補充知識。在此共同訓練過程中,網(wǎng)絡也可以互相學習。

實驗結果表明,與沒有或沒有額外參數(shù)或FLOP的大型網(wǎng)絡相比,小型網(wǎng)絡可以獲得更好的整體性能。
這表明,除了深度/寬度/分辨率之外,網(wǎng)絡的數(shù)量是有效模型縮放的新維度。通過在不同設備上同時運行,小型網(wǎng)絡也可以比大型網(wǎng)絡實現(xiàn)更快的推理速度。
相關工作
2.1 神經(jīng)網(wǎng)絡結構設計
自從AlexNet的得到非常好的效果以后,深度學習方法便成為計算機視覺領域的主導,神經(jīng)網(wǎng)絡設計也成為一個核心話題。AlexNet之后出現(xiàn)了許多優(yōu)秀的架構,如NIN、VGGNet、Inception、ResNet、Xception等。很多研究者設計了高效的模型,如1*1卷積核、用小核堆疊卷積層、不同卷積與池化操作的組合、殘差連接、深度可分離卷積等。
近年來,神經(jīng)網(wǎng)絡結構搜索(NAS)越來越受歡迎。人們希望通過機器學習方法自動學習或搜索某些任務的最佳神經(jīng)結構。在這里只列舉幾個,基于強化學習的NAS方法、漸進式神經(jīng)架構搜索(PNASNet)、可微分架構搜索(DARTS),等等。
2.2 協(xié)同學習
協(xié)同學習最初是教育中的一個總稱,指的是學生或教師共同努力學習的教育方法。它被正式引入深度學習中,用來描述同一網(wǎng)絡中多個分類器同時訓練的情況。
然而,按照其最初的定義,涉及兩種或兩種以上模型共同學習的作品,也可以稱為協(xié)作學習,如深度相互學習(deep mutual learning, DML)、聯(lián)合訓練、互意教學、合作學習、知識蒸餾等。雖然有不同的名字,但核心思想是相似的,即通過一些同伴或老師的訓練來提高一個或所有的模型的性能。
本文方法
這里先簡單回顧了常用的深度圖像分類框架。給定一個神經(jīng)網(wǎng)絡模型M和N個訓練樣本有C類,訓練M模型最常用的損失函數(shù)是交叉熵損失函數(shù):

其中y為ground truth label;p為估計概率,通常由M的最后softmax層給出。
3.1 分割模型

上圖中根據(jù)網(wǎng)絡的寬度將一個大網(wǎng)絡M劃分為S個小網(wǎng)絡。當按寬度劃分時,實際上指的是按參數(shù)量或FLOPs劃分。
3.1.1、如何計算參數(shù)量和FLOPs?
例如,如果要把M分成兩個網(wǎng)絡,或的參數(shù)數(shù)量應該大約是的一半。這里給出計算神經(jīng)網(wǎng)絡的參數(shù)和FLOPs的方法。
通常是卷積層的累加,因此在這里只討論如何計算一個卷積層的參數(shù)量和FLOPs。在PyTorch的conv-layer的定義中,其卷積核大小是,輸入和輸出通道的特征圖為和, d為累加卷積的數(shù)量,這意味著每個輸入通道將與的卷積核進行卷積。在這種情況下,這個conv-layer的參數(shù)量(Params)和FLOPs:

其中是輸出特征圖的大小,-1是因為次加法只需要次操作。為了簡潔起見,省略了偏差。對于Depthwise卷積來說。
3.1.2、如何切分網(wǎng)絡?
通常,其中為常數(shù)。因此,如果想通過除以因子S來分割卷積層,只需要用除以:

1)切分ResNet
舉例如下:如果我們想分割一個ResNet的Bottleneck Block:

通過除以2可以得到如下的4個小的Blocks:

這里每個Block只有原來Block的四分之一的參數(shù)量和FLOPs。
在實際應用中,網(wǎng)絡中特征映射的輸出通道數(shù)具有最大公約數(shù)(GCD)。大多數(shù)ResNet變體的GCD是第一卷積層的。
對于其他網(wǎng)絡,比如EfficienctNet,他們的GCD是8或者其倍數(shù)。一般來說,當想一個網(wǎng)絡切分為S塊時,只需要找到它的GCD,然后用就可以了。
#?The?below?is?the?same?as?max(widen_factor?/?(split_factor?**?0.5)?+?0.4,?1.0)
?????if?arch?==?'wide_resnet50_2'?and?split_factor?==?2:
??????self.inplanes?=?64
??????width_per_group?=?64
??????print('INFO:PyTorch:?Dividing?wide_resnet50_2,?change?base_width?from?{}?'
????????'to?{}.'.format(64?*?2,?64))
?????if?arch?==?'wide_resnet50_3'?and?split_factor?==?2:
??????self.inplanes?=?64
??????width_per_group?=?64?*?2
??????print('INFO:PyTorch:?Dividing?wide_resnet50_3,?change?base_width?from?{}?'
????????'to?{}.'.format(64?*?3,?64?*?2))
2)切分ResNeXt
對于ResNeXt網(wǎng)絡,當固定時,即,其中為常數(shù),則式上一個等式需要變化為:

這意味著只需要通過channel數(shù)量除以便可以得到d組小的Block。
self.dropout?=?None
??if?'cifar'?in?dataset:
???if?arch?in?['resnext29_16x64d',?'resnext29_8x64d',?'wide_resnet16_8',?'wide_resnet40_10']:
????if?dropout_p?is?not?None:
?????dropout_p?=?dropout_p?/?split_factor
?????#?You?can?also?use?the?below?code.
?????#?dropout_p?=?dropout_p?/?(split_factor?**?0.5)
?????print('INFO:PyTorch:?Using?dropout?with?ratio?{}'.format(dropout_p))
?????self.dropout?=?nn.Dropout(dropout_p)
??elif?'imagenet'?in?dataset:
???if?dropout_p?is?not?None:
????dropout_p?=?dropout_p?/?split_factor
????#?You?can?also?use?the?below?code.
????#?dropout_p?=?dropout_p?/?(split_factor?**?0.5)
????print('INFO:PyTorch:?Using?dropout?with?ratio?{}'.format(dropout_p))
????self.dropout?=?nn.Dropout(dropout_p)??
3.1.3、權重衰減
對于權重衰減問題,由于其內在機理尚不清楚,所以有一點復雜。本工作中使用了2種劃分策略:無劃分和指數(shù)劃分:

其中為原權重衰減值,為除法后的新權重值。不除意味著權重衰減值保持不變。
如上所述,權重衰減的潛在機制尚不清楚,因此很難找到最佳的、普遍的解決方案。以上兩種劃分策略只是經(jīng)驗標準。在實踐中,現(xiàn)在最好的方法是嘗試。
3.1.4、小型網(wǎng)絡的并發(fā)運行
盡管小網(wǎng)絡具有更好的集成性能,但在大多數(shù)情況下,小網(wǎng)絡也可以通過在不同的設備上部署不同的小型模型,并且通過并發(fā)運行來實現(xiàn)比大網(wǎng)絡更快的推理速度。如圖所示。

典型的設備是NVIDIA的GPU。理論上,如果一個GPU有足夠的處理單元,例如流處理器、CUDA核等,小型網(wǎng)絡也可以在一個GPU內并發(fā)運行。
然而,一個小的網(wǎng)絡已經(jīng)能夠占用大部分的計算資源,不同的網(wǎng)絡只能按順序運行。因此,本文只討論多設備的方式。
小網(wǎng)絡并發(fā)推理的成功也表明了訓練并發(fā)性的可能性。目前,小網(wǎng)絡在訓練過程中是按順序運行的,導致訓練時間比大網(wǎng)絡長。
然而,設計一個靈活的、可伸縮的框架是相當困難的,它能夠支持在多個設備上對多個模型進行異步訓練,并且在前向推理和反向傳播的過程中也需要進行通信。
3.2 聯(lián)合訓練
一個大網(wǎng)絡M分割后變成S個小網(wǎng)絡?,F(xiàn)在的問題是如何讓這些小網(wǎng)絡從數(shù)據(jù)中學習不同的互補知識。在引入聯(lián)合訓練部分之前,強調劃分和聯(lián)合訓練的設計只是為了說明增加網(wǎng)絡數(shù)量是有效模型縮放的一個新維度的核心思想。
而聯(lián)合訓練部分是由deep mutual learning(DML)、co-training和mutual mean-teaching(MMT)所啟發(fā)的。
不同的初始化方式和數(shù)據(jù)views
一個基本的理解是學習一些相同的網(wǎng)絡是沒有意義的。相比之下,小型網(wǎng)絡則需要學習有關數(shù)據(jù)的互補的知識,以獲得一個全面的數(shù)據(jù)理解。

為此,首先,對小網(wǎng)絡進行不同權值的初始化。然后,在輸入訓練數(shù)據(jù)時,對不同網(wǎng)絡的相同數(shù)據(jù)使用不同的數(shù)據(jù)轉換器,如上圖所示。這樣,小模型便可以在不同的變換域下進行學習和訓練。
在實際應用中,不同的數(shù)據(jù)域是由數(shù)據(jù)增廣隨機性產(chǎn)生的。除了常用的隨機調整/剪切/翻轉策略外,作者還進一步介紹了隨機擦除和AutoAugment 策略。AutoAugment 有14個圖像變換操作,如剪切,平移,旋轉,自動對比度等。
該算法針對不同的數(shù)據(jù)集搜索了幾十種由兩種轉換操作組成的策略,并在數(shù)據(jù)擴充過程中隨機選擇一種策略。
源碼如下:
??train_transform?=?transforms.Compose([transforms.RandomCrop(32,?padding=4),
?????????????transforms.RandomHorizontalFlip(),
?????????????CIFAR10Policy(),
?????????????transforms.ToTensor(),
?????????????transforms.Normalize((0.4914,?0.4822,?0.4465),
??????????????????(0.2023,?0.1994,?0.2010)),
?????????????transforms.RandomErasing(p=erase_p,
???????????????????scale=(0.125,?0.2),
???????????????????ratio=(0.99,?1.0),
???????????????????value=0,?inplace=False),
????????????])
3.3 聯(lián)合訓練損失函數(shù)
遵循半監(jiān)督學習中共同訓練的假設,小網(wǎng)絡雖然對的view不同,但對x的預測是一致的:

因此,在目標函數(shù)中加入預測概率分布之間的Jensen-Shannon(JS)散度,即聯(lián)合訓練損失函數(shù):

其中是一個小網(wǎng)絡的估計概率,是分布p的Shannon熵。聯(lián)合訓練損失也用于DML中,但具體形式不同,即DML使用的是兩種預測之間的KullbacLeibler(KL)散度。
通過這種聯(lián)合訓練的方式,一個網(wǎng)絡還可以從它的同伴那里學到一些有價值的東西,因為預測的概率包含有關于物體的有意義的信息。
例如,一個將一個物體分類為Chihuahua的model可能也會對Japanese spaniel有很高的信心,因為它們都是狗。這是有價值的信息,定義了對象上豐富的相似結構。
總體目標函數(shù)為:

式中,是通過交叉驗證選擇權重因子。
損失函數(shù)源碼如下:
def?_co_training_loss(self,?outputs,?loss_choose,?epoch=0):
??"""calculate?the?co-training?loss?between?outputs?of?different?small?networks
??"""
??weight_now?=?self.cot_weight
??if?self.is_cot_weight_warm_up?and?epoch????weight_now?=?max(self.cot_weight?*?epoch?/?self.cot_weight_warm_up_epochs,?0.005)
??if?loss_choose?==?'js_divergence':
???#?the?Jensen-Shannon?divergence?between?p(x1),?p(x2),?p(x3)...
???#?https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence
???outputs_all?=?torch.stack(outputs,?dim=0)
???p_all?=?F.softmax(outputs_all,?dim=-1)
???p_mean?=?torch.mean(p_all,?dim=0)
???H_mean?=?(-?p_mean?*?torch.log(p_mean)).sum(-1).mean()
???H_sep?=?(-?p_all?*?F.log_softmax(outputs_all,?dim=-1)).sum(-1).mean()
???cot_loss?=?weight_now?*?(H_mean?-?H_sep)
??else:
???raise?NotImplementedError
??return?cot_loss??
實驗
4.1、各個Backbone的增益

通過上圖可以看出:增加網(wǎng)絡的數(shù)量比單純增加網(wǎng)絡的寬度/深度更有效。
4.2、性能對比

通過上圖可以看出:整體性能與個體性能密切相關。
4.3、CIFAR-100實驗結果

通過上表可以看出:必要的網(wǎng)絡寬度/深度很重要。
4.4、序列和并發(fā)之間的推斷延遲實驗

參考:
[1].SplitNet: Divide and Co-training
[2].https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training


