ThiNet:模型通道結(jié)構(gòu)化剪枝
【GiantPandaCV】ThiNet是一種結(jié)構(gòu)化剪枝,核心思路是找到一個(gè)channel的子集可以近似全集,那么就可以丟棄剩下的channel,對應(yīng)的就是剪掉剩下的channel對應(yīng)的filters。剪枝算法還是三步剪枝:train-prune-finetune,而且是layer by layer的剪枝。本文由作者授權(quán)首發(fā)于GiantPandaCV公眾號。
0、 介紹
ThiNet是南京大學(xué)lamda實(shí)驗(yàn)室出品,是ICCV 2017的文章,文章全名《ThiNet: A Filter Level Pruning Method for Deep Neural Network Compression》。
文章的主要思路是:ThiNet是基于filter剪枝,將filter剪枝操作形式化地定義為一個(gè)優(yōu)化問題,通過下一層的統(tǒng)計(jì)信息來指導(dǎo)當(dāng)前層的剪枝。如果移除當(dāng)前層(記為)filter(記為),那么 層channel和同樣被丟棄;但是如果層的filter的數(shù)量不變,則層的輸出(也是層的輸入)維度不變。也就是發(fā)現(xiàn)這樣的剪枝對層的輸出(也是層的輸入)很小影響,作者提出ThiNet剪枝。大白話就是找到一組channel的輸出跟全部channel的輸出之間的誤差最小(采用均方誤差/最小二乘法去衡量),那么就可以用這組channel來代替全部channel。
ThiNet剪枝流程:選擇channel子集、剪枝、finetune,如下如圖

所以算法的實(shí)現(xiàn)的核心在于如何進(jìn)行channel選擇,一個(gè)channel是一個(gè)filter的計(jì)算結(jié)果,所以二者相互對應(yīng)。
ThiNet有三個(gè)要點(diǎn):
1、如何進(jìn)行通道選擇,通道的子集與全部通道的全集之間的最小二次乘法誤差來做通道重要性判斷依據(jù)
2、最小化重構(gòu)誤差,相當(dāng)于給finetune一個(gè)初始化卷積核參數(shù)
3、對殘差網(wǎng)絡(luò)的剪枝做了適配
一、通道選擇(channel selection)
文章采用貪心算法選擇channel子集(也就是留下來的filter)。ThiNet是迭代式layer by layer的剪枝。
思路1(正向思路):根據(jù)通道重要性判斷找到重要的channel,保留下來,然后迭代式剪枝進(jìn)進(jìn)行直到壓縮率達(dá)到預(yù)設(shè)要求,見公式5。
為什么會有思路1?因?yàn)檎撐牡闹饕悸肥?,找到一組channel的子集可以近似該層channel的全集,那么就是要找到可以留下來的channel,對應(yīng)的就是該channel對應(yīng)的filter;這就是論文的正向思路。
思路1的方法會有一個(gè)問題就是,留下來filter的數(shù)量是從大到小的變化的,那么按照思路1計(jì)算量會很大,因?yàn)榱粝聛淼膄ilter(記為S)在剪枝一開始的時(shí)候要比被移除的filter(記為T)多,所以有
思路2:根據(jù)通道重要性判斷找到要剪枝(丟棄)的filter,然后迭代式剪枝進(jìn)行直到壓縮率達(dá)到預(yù)設(shè)要求(丟棄一定數(shù)量的filter),見公式6。
ThiNet通道重要性判斷是:找到一組通道子集近似通道全集的結(jié)果。

下面公式1-5我是根據(jù)論文寫的,會有點(diǎn)繞,但對復(fù)現(xiàn)這篇論文不是那么重要,核心思路就是上面提到選取一部分channel來近似。
公式1:
公式2:
其中,是第層輸入張量,
是從中隨機(jī)采樣得到的,
是 卷積核的集合,
是對應(yīng)的滑動窗口,
是channels, 是行, 是列,是輸出的通道數(shù), 是bias
公式1和公式2,可以簡化為公式3:
,
公式1~3是為了簡化公式表示的等效變換。
基于通道在中是獨(dú)立的,只取決于,不依賴于, ,則有
公式4:,是channel的子集。
公式5是為了最小化留下來的channel的計(jì)算結(jié)果與原來channel全集的計(jì)算結(jié)果,即為思路1:
變?yōu)楣?,即為思路2:
其中,S ∪ T = {1, 2, . . . , C},S ∩ T = ?,r是壓縮率,C是filter數(shù)量。
基于貪心算法選擇filter子集的算法如圖:

def?channel_selection(inputs,?module,?sparsity=0.5,?method='greedy'):
????"""
????選擇當(dāng)前模塊的輸入通道,以及高度重要的通道。
????找到可以使現(xiàn)有輸出最接近的輸入通道。
????
????:param?inputs:?torch.Tensor,?input?features?map
????:param?module:?torch.nn.module,?layer
????:param?sparsity:?float,?0?~?1?how?many?prune?channel?of?output?of?this?layer
????:param?method:?str,?how?to?select?the?channel
????:return:
????????list?of?int,?indices?of?channel?to?be?selected?and?pruned
????"""
????num_channel?=?inputs.size(1)??#?通道數(shù)
????num_pruned?=?int(math.ceil(num_channel?*?sparsity))??#??輸入需要?jiǎng)h除的通道數(shù)
????num_stayed?=?num_channel?-?num_pruned
????print('num_pruned',?num_pruned)
????if?method?==?'greedy':
????????indices_pruned?=?[]
????????while?len(indices_pruned)?????????????min_diff?=?1e10
????????????min_idx?=?0
????????????for?idx?in?range(num_channel):
????????????????if?idx?in?indices_pruned:
????????????????????continue
????????????????indices_try?=?indices_pruned?+?[idx]
????????????????inputs_try?=?torch.zeros_like(inputs)
????????????????inputs_try[:,?indices_try,?...]?=?inputs[:,?indices_try,?...]
????????????????output_try?=?module(inputs_try)
????????????????output_try_norm?=?output_try.norm(2)?#這里就是公式6
????????????????if?output_try_norm?????????????????????min_diff?=?output_try_norm
????????????????????min_idx?=?idx
????????????indices_pruned.append(min_idx)
????????indices_stayed?=?list(set([i?for?i?in?range(num_channel)])?-?set(indices_pruned))
????????
????inputs?=?inputs.cuda()
????module?=?module.cuda()
????return?indices_stayed,?indices_pruned
二、最小化重構(gòu)誤差(Minimize the reconstruction error)
首先先來看看numpy.linalg.lstsq(),是線性矩陣方程的最小二乘法求解。
最小二乘法的公式為:
| 方法 | 描述 |
|---|---|
| linalg.lstsq(a, b[, rcond]) | 返回線性矩陣方程的最小二乘解 |
numpy.linalg.lstsq(a,?b,?rcond='warn')
#?將least-squares解返回線性矩陣方程。
其中, 是通道選擇后的訓(xùn)練樣本,可以通過 求解
該方法是每一個(gè)通道賦予權(quán)重來進(jìn)一步地減小重構(gòu)誤差。文章說這相當(dāng)于給finetune一個(gè)很好的初始化。
def?weight_reconstruction(module,?inputs,?outputs,?use_gpu=False):
????"""
????reconstruct?the?weight?of?the?next?layer?to?the?one?being?pruned
????:param?module:?torch.nn.module,?module?of?the?this?layer
????:param?inputs:?torch.Tensor,?new?input?feature?map?of?the?this?layer
????:param?outputs:?torch.Tensor,?original?output?feature?map?of?the?this?layer
????:param?use_gpu:?bool,?whether?done?in?gpu
????:return:?void
????"""
????if?module.bias?is?not?None:
????????bias_size?=?[1]?*?outputs.dim()
????????bias_size[1]?=?-1
????????outputs?-=?module.bias.view(bias_size)??#?從?output?feature?中減去?bias?(y?-?b)
????if?isinstance(module,?torch.nn.Conv2d):
????????unfold?=?torch.nn.Unfold(kernel_size=module.kernel_size,?dilation=module.dilation,
?????????????????????????????????padding=module.padding,?stride=module.stride)
????????unfold.eval()
????????x?=?unfold(inputs)??#?展開到以一個(gè)面片(reception?field)為列的三維數(shù)組?(N?*?KKC?*?L?(number?of?fields))
????????x?=?x.transpose(1,?2)??#??transpose?(N?*?KKC?*?L)?->?(N?*?L?*?KKC)
????????num_fields?=?x.size(0)?*?x.size(1)
????????x?=?x.reshape(num_fields,?-1)??#?x:?(NL?*?KKC)
????????y?=?outputs.view(outputs.size(0),?outputs.size(1),?-1)??#?將一個(gè)特征映射展開為一行數(shù)組?(N?*?C?*?WH)
????????y?=?y.transpose(1,?2)??#??transpose?(N?*?C?*?HW)?->?(N?*?HW?*?C),?L?==?HW
????????y?=?y.reshape(-1,?y.size(2))??#?y:?(NHW?*?C),??(NHW)?==?(NL)
????????if?x.size(0)?1)?or?use_gpu?is?False:
????????????x,?y?=?x.cpu(),?y.cpu()
????????????
?#上面一系列的reshape的操作是為了調(diào)用np.linalg.lstsq這個(gè)函數(shù),利用最小二乘法求解weight
????param,?residuals,?rank,?s?=?np.linalg.lstsq(x.detach().cpu().numpy(),y.detach().cpu().numpy(),rcond=-1)
????param?=?param[0:x.size(1),?:].clone().t().contiguous().view(y.size(1),?-1)
????if?isinstance(module,?torch.nn.Conv2d):
????????param?=?param.view(module.out_channels,?module.in_channels,?*module.kernel_size)
????del?module.weight
????module.weight?=?torch.nn.Parameter(param)
三、對于VGG-16的ThiNet剪枝策略
1、對前面10層剪枝力度大,因?yàn)榍懊?0層的feature map比較大,F(xiàn)LOPs占據(jù)了超過90%
2、全連接層占據(jù)了 86.41%的模型參數(shù),所以將其改成global average pooling layer
3、剪枝是layer by layer,每剪完一個(gè)layer finetune一個(gè)epoch,學(xué)習(xí)率設(shè)為0.001,到最后一層剪完 finetune 12個(gè)epoch,學(xué)習(xí)率設(shè)為0.0001.
4、在Imagenet上, VGG更具體的剪枝細(xì)節(jié)可以看論文4.2部分。
四、對于ResNet的剪枝策略
對于殘差塊的剪枝,因?yàn)橛袀€(gè)add的操作,相加時(shí)候維度必須保持一致,所以殘差塊最后一層輸出的filter不改變而只剪枝前面兩層,如下所示:

每剪完一個(gè)layer finetune一個(gè)epoch,固定學(xué)習(xí)率為0.0001,到最后一層剪完 finetune 9個(gè)epoch ,學(xué)習(xí)率從0.001到0.00001變換,其余的與VGG-16中一樣。ResNet更具體細(xì)節(jié),請查看論文4.3部分
五、參考鏈接
原作中文解讀:http://www.lamda.nju.edu.cn/luojh/project/ThiNet_ICCV17/ThiNet_ICCV17_CN.html
論文:https://arxiv.org/abs/1707.06342
代碼:https://github.com/Roll920/ThiNet
https://github.com/kkeono2/Channel-Pruning-using-Thinet-LASSO-
歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識。( ? ?ω?? )?
有對文章相關(guān)的問題,或者想要加入交流群,歡迎添加BBuf微信:
為了方便讀者獲取資料以及我們公眾號的作者發(fā)布一些Github工程的更新,我們成立了一個(gè)QQ群,二維碼如下,感興趣可以加入。
