詳解卷積中的Winograd加速算法
?1. 為什么會引入WinoGrad?「GiantPandaCV導語」:這篇文章為大家介紹一下用來加速卷積運算的WinoGrad算法的原理,工程實現(xiàn)以及相關優(yōu)化思路,如果你對卷積加速算法感興趣可以看看這篇文章。算法的完整實現(xiàn)請到MsnhNet的github倉庫查看,地址為:https://github.com/msnh2012/Msnhnet
?
做過ACM/OI的朋友大家應該對FFT并不陌生,我們知道對于兩個序列的乘法通過FFT可以從原始O(n^2)復雜度變成O(nlogn),所以我們就會想著FFT這個算法是否可以應用到我們計算卷積中來呢?當然是可以的,但是FFT的計算有個問題哦,會引入復數(shù)。而移動端是不好處理復數(shù)的,對于小卷積核可能減少的計算量和復數(shù)運算帶來的降速效果是不好說誰會主導的。所以在這種情況下,針對卷積的WinoGrad算法出現(xiàn)了,它不僅可以類似FFT一樣降低計算量,它還不會引入復數(shù),使得卷積的運算加速成為了可能。因此,本文嘗試從工程實現(xiàn)的角度來看一下WinoGrad,希望對從事算法加速的小伙伴有一些幫助。
2. 為什么會有這篇文章?最近嘗試給MsnhNet做卷積的WinoGrad實現(xiàn),然后開始了解這個算法,并嘗試參考著NCNN來理解和動手寫一下。參考了多篇優(yōu)秀的講解文章和NCNN源碼,感覺算是對這個算法有了較為清楚的認識,這篇文章就記錄一下我在實現(xiàn)并且步長為的WinoGrad卷積時的一些理解。這篇文章的重點是WinoGrad卷積的實現(xiàn),關于WinoGrad卷積里面的變化矩陣如何推導可以看梁德澎作者的文章:詳解Winograd變換矩陣生成原理 (聽說后續(xù)他會做個視頻來仔細講講QAQ),現(xiàn)在就假設我們知道了WinoGrad的幾個變換矩陣。如果你不知道也沒關系,因為有一個Python工具包可以直接幫我們計算,地址為:https://github.com/andravin/wincnn 。然后現(xiàn)在我們就要用拿到的這幾個矩陣來實現(xiàn)WinoGrad算法,聽起來比較簡單,但我們還是得一步步理清楚是不。
WinoGrad算法起源于1980年,是Shmuel Winograd提出用來減少FIR濾波器計算量的一個算法。它指出,對于輸出個數(shù)為,參數(shù)個數(shù)為的FIR濾波器,不需要次乘法計算,而只需要次乘法計算即可。
下面是一個經(jīng)典例子,以1維卷積為例,輸入信號,卷積核,則卷積可以寫成如下矩陣乘法形式:
式子1如果這個計算過程使用普通的矩陣乘法,則一共需要「6次乘法和4次加法」 。
但是,我們仔細觀察一下,卷積運算中輸入信號轉換得到的矩陣不是任意矩陣,其有規(guī)律的分布著大量的重復元素,例如第一行的和,卷積轉換成的矩陣乘法比一般乘法的問題域更小,所以這就讓優(yōu)化存為了可能。
然后WinoGrad的做法就是:
式子2其中,
等式3我們知道,在CNN的推理階段,卷積核上的元素是固定的,所以上式中和相關的式子可以提前算好,在預測階段只用計算一次,可以忽略。所以這里一共需要「4次乘法加4次加法」。
相比于普通的矩陣乘法,使用WinoGrad算法之后乘法次數(shù)減少了,這樣就可以達到加速的目的了。
這個例子實際上是「1D的WinoGrad算法」,我們將上面的計算過程寫成矩陣的形式如下:
式子4其中,表示element-wise multiplication(Hadamard product)對應位置相乘。其中,
相關矩陣解釋- :表示卷積核
- :表示輸入信號
- :卷積核變換矩陣,尺寸為
- :輸入變換矩陣,尺寸
- :輸出變換矩陣,尺寸
所以整個計算過程可以分為4步:
- 輸入變換
- 卷積核變換
- 外積
- 輸出變換
然后我們將1D的WinoGrad擴展到2D,就可以實現(xiàn)卷積的加速了,那么如何從1維擴展到2維呢?公式如下:
式子5其中,為的卷積核,為的圖像塊,我們把上面的擴展到,先寫成矩陣乘法的方式:
F(2x2,3x3) 圖片來自https://www.cnblogs.com/shine-lee/p/10906535.html上圖表示我們將卷積核的元素拉成了一列,將輸入信號每個滑動窗口中的元素拉成了一行。注意圖中紅線分成的矩陣塊,每個矩陣塊中重復元素的位置與一維相同,即:
二維和一維的WinoGrad矩陣關系然后,令,即圖像窗口中的第0行元素,然后表示第行,,然后可以推導:
2D WinoGrad矩陣形式計算推導在上面的推導中,表示長度為4的和長度為的卷積結果,結果為長度為2的列向量,其中和均為長度為4的列向量。
進一步,可以看成3對長度為4的列向量兩兩對應位置相乘再相加,結果為長度為4的列向量,也可以看成是4組長度為3的行向量的點積運算。
同樣,也是3對長度為4的列向量的內積運算。
然后類似1D WinoGrad算法,我們考慮兩者的重疊部分和,剛好對應1D WinoGrad中的每一行在的對應行上進行1維卷積,基于上面推導的1D WinoGrad公式,行向量的卷積只需要將所有左乘的變換矩陣轉置后變成右乘即可。
然后上面的推導就做完了。
下圖表示2D WinoGrad的示意圖:
2D WinoGrad示意圖這個時候,WinoGrad算法的乘法次數(shù)為,而如果直接卷積乘法次數(shù)為,「降低了2.25倍的乘法計算復雜度」。
4. 從工程角度來看WinoGrad下面我們就從一個實際例子來說,如何利用WinoGrad來實現(xiàn)并且步長為1的卷積運算?;谏厦娼榻B的2D WinoGrad的原理,我們現(xiàn)在只需要分4步即可實現(xiàn)WnoGrad算法:
- 第一步就是對輸入卷積核的變換:
- 第二步就是對輸入數(shù)據(jù)的變換:
- 第三步就是對M矩陣的計算:
- 最后一步就是結果的計算:
接下來我們就以WinoGrad實現(xiàn)并且步長為1的卷積計算為例子,來理解一下WinoGrad的工程實現(xiàn)。
4.1 對輸入卷積核進行變換
這一步就是對卷積核進行變化,公式為:,其中表示輸出通道標號,表示輸入通道標號,一個對應卷積核的一個。由于我們要實現(xiàn)的是,因此是一個的矩陣,我們不難寫出這部分代碼(其中,矩陣可以通過https://github.com/andravin/wincnn 這個工具進行計算):
//?矩陣G
????????const?float?ktm[8][3]?=?{
????????????{1.0f,??????0.0f,??????0.0f},
????????????{-2.0f?/?9,?-2.0f?/?9,?-2.0f?/?9},
????????????{-2.0f?/?9,?2.0f?/?9,?-2.0f?/?9},
????????????{1.0f?/?90,?1.0f?/?45,?2.0f?/?45},
????????????{1.0f?/?90,?-1.0f?/?45,?2.0f?/?45},
????????????{1.0f?/?45,?1.0f?/?90,?1.0f?/?180},
????????????{1.0f?/?45,?-1.0f?/?90,?1.0f?/?180},
????????????{0.0f,?0.0f,?1.0f}
????????};
????????const?int?kernelTmSize?=?inChannel?*?8?*?8;
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif
????????for(int?outc?=?0;?outc?????????????for(int?inc?=?0;?inc?????????????????const?float*?kernel0?=?(const?float*)kernel?+?outc?*?inChannel?*?9?+?inc?*?9;
????????????????float?*kernel_tm0?=?kernel_tm?+?outc?*?kernelTmSize?+?inc?*?64;
????????????????//需要變換的卷積核
????????????????const?float*?k0?=?kernel0;
????????????????const?float*?k1?=?kernel0?+?3;
????????????????const?float*?k2?=?kernel0?+?6;
????????????????float?tmpG[8][3];????//?tmp?=?G*g
????????????????for(int?i?=?0;?i?8;?i++){
????????????????????tmpG[i][0]?=?k0[0]?*?ktm[i][0]?+?k0[1]?*?ktm[i][1]?+?k0[2]?*?ktm[i][2];
????????????????????tmpG[i][1]?=?k1[0]?*?ktm[i][0]?+?k1[1]?*?ktm[i][1]?+?k1[2]?*?ktm[i][2];
????????????????????tmpG[i][2]?=?k2[0]?*?ktm[i][0]?+?k2[1]?*?ktm[i][1]?+?k2[2]?*?ktm[i][2];
????????????????}
????????????????//U?=?kernel_tm0?=?G*g*G^T
????????????????//[8*3]?x?[3*8]
????????????????for(int?i?=?0;?i?8;?i++){
????????????????????float?*tmpPtr?=?&tmpG[i][0];
????????????????????for(int?j?=?0;?j?8;?j++){
????????????????????????kernel_tm0[i?*?8?+?j]?=?tmpPtr[0]?*?ktm[j][0]?+?tmpPtr[1]?*?ktm[j][1]?+?tmpPtr[2]?*?ktm[j][2];
????????????????????}
????????????????}
????????????}
????????}
通過這段代碼,所有的卷積核都被轉換成了U,存放在了kernel_tm上,一行代表一個,kernel_tm的內存排布如下圖所示:
U_{k,c}的內存排布其中W=64的原因是因為F(6x6,3x3)需要每一個輸入圖像塊(tile)的大小為,權重塊也對應,這樣才可以做卷積運算(eltwise_mult)。
然后上次我們講到數(shù)據(jù)Pack的優(yōu)勢詳解Im2Col+Pack+Sgemm策略更好的優(yōu)化卷積運算,所以這里仍然使用NCNN的Pack策略來獲得更好的訪存,即將上面的kernel_tm進行一次重排,將維度全部壓到維度上,另外再對維度做一個額外的4倍壓縮,來獲得更好的訪存。
將H的維度全部壓到維度上示意圖:
將kernel_tm的H維度全部壓到W維度變成一個扁平的Blob然后在這個基礎上,將C維度進行進一步壓縮,這個時候還需要注意的是對于每一個輸出通道,我們在這個平面上是同時拿出了2行,也就是拿出了128個數(shù)據(jù),然后進行交織排列,最后獲得kernel_tm2。這里以輸出通道的前4個為例,即剛好處理8個U矩陣之后結果矩陣kernel_tm2應該是長什么樣子,如下圖所示:
![Pack策略之后的矩陣kernel_tm2就長這個樣子
這部分的代碼實現(xiàn)如下:
int?nnOutchannel?=?outChannel?>>?2;
????????int?remainOutChannel?=?nnOutchannel?<2;
????????
????????int?packOutChannel?=?nnOutchannel?+?(outChannel?%?4?+?3)?/?4;
????????int?packOutH?=?1;
????????int?packOutW?=?(8?*?8?*?inChannel?*?4);
????????//float?*kernel_tm2?=?new?float[packOutChannel?*?packOutH?*?packOutW];
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif???????
????????for(int?cc?=?0;?cc?????????????int?c?=?cc?<2;
????????????float?*ktm2?=?kernel_tm2?+?cc?*?packOutH?*?packOutW;
????????????
????????????const?float?*kernel0_tm?=?kernel_tm?+?c?*?kernelTmSize;
????????????const?float?*kernel1_tm?=?kernel_tm?+?(c?+?1)?*?kernelTmSize;
????????????const?float?*kernel2_tm?=?kernel_tm?+?(c?+?2)?*?kernelTmSize;
????????????const?float?*kernel3_tm?=?kernel_tm?+?(c?+?3)?*?kernelTmSize;
????????????int?q?=?0;
????????????for(;?q?+?1?2){
????????????????const?float?*k00?=?kernel0_tm?+?q?*?64;
????????????????const?float?*k01?=?kernel0_tm?+?(q?+?1)?*?64;
????????????????const?float?*k10?=?kernel1_tm?+?q?*?64;
????????????????const?float?*k11?=?kernel1_tm?+?(q?+?1)?*?64;
????????????????const?float?*k20?=?kernel2_tm?+?q?*?64;
????????????????const?float?*k21?=?kernel2_tm?+?(q?+?1)?*?64;
????????????????const?float?*k30?=?kernel3_tm?+?q?*?64;
????????????????const?float?*k31?=?kernel3_tm?+?(q?+?1)?*?64;
????????????????for(int?i?=?0;?i?16;?i++){
????????????????????for(int?j?=?0;?j?4;?j++){
????????????????????????ktm2[0?+?j]?=?k00[j];
????????????????????????ktm2[4?+?j]?=?k01[j];
????????????????????????ktm2[8?+?j]?=?k10[j];
????????????????????????ktm2[12?+?j]?=?k11[j];
????????????????????????ktm2[16?+?j]?=?k20[j];
????????????????????????ktm2[20?+?j]?=?k21[j];
????????????????????????ktm2[24?+?j]?=?k30[j];
????????????????????????ktm2[28?+?j]?=?k31[j];
????????????????????}
????????????????????k00?+=?4;
????????????????????k01?+=?4;
????????????????????k10?+=?4;
????????????????????k11?+=?4;
????????????????????k20?+=?4;
????????????????????k21?+=?4;
????????????????????k30?+=?4;
????????????????????k31?+=?4;
????????????????????ktm2?+=?32;
????????????????}
????????????}
????????????//inChannel方向的拖尾部分
????????????for(;?q?????????????????const?float?*k00?=?kernel0_tm?+?q?*?64;
????????????????const?float?*k10?=?kernel1_tm?+?q?*?64;
????????????????const?float?*k20?=?kernel2_tm?+?q?*?64;
????????????????const?float?*k30?=?kernel3_tm?+?q?*?64;
????????????????for(int?i?=?0;?i?16;?i++){
????????????????????for(int?j?=?0;?j?4;?j++){
????????????????????????ktm2[0?+?j]?=?k00[j];
????????????????????????ktm2[4?+?j]?=?k10[j];
????????????????????????ktm2[8?+?j]?=?k20[j];
????????????????????????ktm2[12?+?j]?=?k30[j];
????????????????????}
????????????????????k00?+=?4;
????????????????????k10?+=?4;
????????????????????k20?+=?4;
????????????????????k30?+=?4;
????????????????????ktm2?+=?16;
????????????????}
????????????}
????????}
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif??????
????????for(int?cc?=?remainOutChannel;?cc?????????????float?*ktm2?=?kernel_tm2??+?nnOutchannel?*?packOutH?*?packOutW?+?8?*?8?*?inChannel?*?(cc?-?remainOutChannel);
????????????const?float*?kernel0_tm?=?kernel_tm?+?cc?*?kernelTmSize;
????????????int?q?=?0;
????????????for(;?q?????????????????const?float*?k00?=?kernel0_tm?+?q?*?64;
????????????????for(int?i?=?0;?i?16;?i++){
????????????????????for(int?j?=?0;?j?4;?j++){
????????????????????????ktm2[j]?=?k00[j];
????????????????????}
????????????????????k00?+=?4;
????????????????????ktm2?+=?4;
????????????????}
????????????}
????????}????????
4.2 對輸入數(shù)據(jù)進行變換
對卷積核進行變換之后,接下來就輪到對輸入矩陣進行變換了,即對V矩陣進行計算,。上面我們已經(jīng)提到過,對于卷積核獲得的每一個,我們都需要一個對應的的圖像塊(tile)和它做卷積運算(eltwise_multiply)。所以這里我們首先需要確定輸入數(shù)據(jù)可以被拆成多少個圖像塊,并且我們需要為變換矩陣V申請空間,從第三節(jié)可知:輸入變換矩陣,尺寸為,即每個小塊的變換矩陣都為,但是輸入特征圖長寬不一定會被8整除,這個時候就需要對輸入特征圖進行擴展(padding),這部分預處理的代碼實現(xiàn)如下:
//?Vc,b?=?B^Td_{c,b}B
????????
????????//?輸出特征圖如果長寬不夠需要Padding
????????int?outW?=?(outWidth?+?5)?/?6?*?6;
????????int?outH?=?(outHeight?+?5)?/?6?*?6;
????????int?W?=?outW?+?2;
????????int?H?=?outH?+?2;
????????int?Top?=?0;
????????int?Left?=?0;
????????int?Bottom?=?H;
????????int?Right?=?W;
????????int?PadHeight?=?Bottom?-?Top;
????????int?PadWidth?=?Right?-?Left;
????????int?PadSize?=?PadHeight?*?PadWidth;
????????float?*srcPadding?=?new?float[PadHeight?*?PadWidth?*?inChannel];
????????PaddingLayerArm?now;
????????now.padding(src,?inWidth,?inHeight,?inChannel,?srcPadding,?0,?H?-?inHeight,?0,?W?-?inWidth,?0);
????????
????????int?w_tm?=?outW?/?6?*?8;
????????int?h_tm?=?outH?/?6?*?8;
????????int?tiles?=?w_tm?/?8?*?h_tm?/?8;
????????int?src_tm_channel?=?inChannel;
????????int?src_tm_h?=?16?*?w_tm?/?8?*?h_tm?/?8;
????????int?src_tm_w?=?4;
????????
????????int?src_tm_size?=?src_tm_h?*?src_tm_w;
????????float?*src_tm??=?new?float[src_tm_channel?*?src_tm_h?*?src_tm_w];
注意上面src_tm的形狀,這是考慮到了卷積核變換矩陣已經(jīng)執(zhí)行了Pack策略,所以這里為了方便后續(xù)的卷積計算和進行指令集加速,同樣將src_tm進行Pack,這個Pack是直接規(guī)定計算完之后4個4個岔開存儲的方式來實現(xiàn)的。另外,輸入Blob的一個Channel對應了輸出Blob的一個Channel。
然后我們再通過WinCNN工具可以獲得B矩陣和B的轉置矩陣,并確定V矩陣更好的計算策略(指的是可以復用一些中間變量)。
//?BT?=?
????????//??1???0????-21/4????0????21/4?????0????-1??0?
????????//?????????????????????????????????????????????
????????//??0???1??????1????-17/4??-17/4????1????1???0?
????????//?????????????????????????????????????????????
????????//??0???-1?????1????17/4???-17/4???-1????1???0?
????????//?????????????????????????????????????????????
????????//??0??1/2????1/4???-5/2???-5/4?????2????1???0?
????????//?????????????????????????????????????????????
????????//??0??-1/2???1/4????5/2???-5/4????-2????1???0?
????????//?????????????????????????????????????????????
????????//??0???2??????4????-5/2????-5?????1/2???1???0?
????????//?????????????????????????????????????????????
????????//??0???-2?????4?????5/2????-5????-1/2???1???0?
????????//?????????????????????????????????????????????
????????//??0???-1?????0????21/4?????0????-21/4??0???1?
????????//B?=?
????????//??1?????0?????0????0????0???0?????0???0?????
?????//??0?????1?????-1????1/2????-1/2???2????-2???-1????
?????//??-21/4?1?????1????1/4????1/4???4?????4???0?????
?????//??0?????-17/4?17/4???-5/2????5/2???-5/2?5/2???21/4??
?????//??21/4?-17/4?-17/4??-5/4???-5/4???-5?-5???0????????
?????//??0?????1?????-1????2????2???1/2?-1/2??-21/4?
?????//??-1?????1?????1????1????1???1?????1???0?????
?????//??0?????0?????0????0????0???0?????0???1?????
????????//?0?=?r00?-?r06?+?(r04?-?r02)?*?5.25
????????//?7?=?r07?-?r01?+?(r03?-?r05)?*?5.25
????????//?1?=?(r02?+?r06?-?r04?*?4.25)?+?(r01?-?r03?*?4.25?+?r05)
????????//?2?=?(r02?+?r06?-?r04?*?4.25)?-?(r01?-?r03?*?4.25?+?r05)
????????//?3?=?(r06?+?r02?*?0.25?-?r04?*?1.25)?+?(r01?*?0.5?-?r03?*?2.5?+?r05?*?2)
????????//?4?=?(r06?+?r02?*?0.25?-?r04?*?1.25)?-?(r01?*?0.5?-?r03?*?2.5?+?r05?*?2)
????????//?reuse?r04?*?1.25
????????//?reuse?r03?*?2.5
????????//?5?=?(r06?+?(r02?-?r04?*?1.25)?*?4)?+?(r01?*?2?-?r03?*?2.5?+?r05?*?0.5)
????????//?6?=?(r06?+?(r02?-?r04?*?1.25)?*?4)?-?(r01?*?2?-?r03?*?2.5?+?r05?*?0.5)
接下來我們就可以開始計算V矩陣了,代碼如下:
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif
????????for(int?q?=?0;?q?????????????const?float?*padptr?=?srcPadding?+?q?*?PadSize;
????????????float?*srcptr?=?src_tm?+?q?*?src_tm_size;
????????????float?tmpV[8][8];
????????????//tile
????????????for(int?i?=?0;?i?8;?i++){
????????????????for(int?j?=?0;?j?8;?j++){
????????????????????float?*r0?=?padptr?+?i?*?6?*?PadWidth?+?j?*?6;
????????????????????
????????????????????//?Bd_{c,b}
????????????????????for(int?m?=?0;?m?8;?m++){
????????????????????????tmpV[0][m]?=?r0[0]?-?r0[6]?+?(r0[4]?-?r0[2])?*?5.25f;
????????????????????????tmpV[7][m]?=?r0[7]?-?r0[1]?+?(r0[3]?-?r0[5])?*?5.25f;
????????????????????????float?t1?=?(r0[2]?+?r0[6]?-?r0[4]?*?4.25f);
????????????????????????float?t2?=?(r0[1]?+?r0[5]?-?r0[3]?*?4.25f);
????????????????????????tmpV[1][m]?=?t1?+?t2;
????????????????????????tmpV[2][m]?=?t1?-?t2;
????????????????????????float?t3?=?(r0[6]?+?r0[2]?*?0.25f?-?r0[4]?*?1.25f);
????????????????????????float?t4?=?(r0[1]?*?0.5f?-?r0[3]?*?2.5f?+?r0[5]?*?2.f);
????????????????????????tmpV[3][m]?=?t3?+?t4;
????????????????????????tmpV[4][m]?=?t3?-?t4;
????????????????????????float?t5?=?(r0[6]?+?(r0[2]?-?r0[4]?*?1.25f)?*?4.f);
????????????????????????float?t6?=?(r0[1]?*?2.f?-?r0[3]?*?2.5f?+?r0[5]?*?0.5f);
????????????????????????tmpV[5][m]?=?t5?+?t6;
????????????????????????tmpV[6][m]?=?t5?-?t6;
????????????????????????r0?+=?PadWidth;
????????????????????}
????????????????????//Bd_{c,b}B^T
????????????????????float?*r00?=?srcptr?+?(i?*?w_tm?/?8?+?j)?*?src_tm_w;
????????????????????float?*r04?=?srcptr?+?(i?*?w_tm?/8?+?j?+?tiles)?*?src_tm_w;
????????????????????for(int?m?=?0;?m?8;?m++){
????????????????????????float*?tmpVPtr?=?tmpV[m];
????????????????????????r00[0]?=?tmpVPtr[0]?-?tmpVPtr[6]?+?(tmpVPtr[4]?-?tmpVPtr[2])?*?5.25f;
????????????????????????r04[3]?=?tmpVPtr[7]?-?tmpVPtr[1]?+?(tmpVPtr[3]?-?tmpVPtr[5])?*?5.25f;
????????????????????????
????????????????????????float?t1?=??(tmpVPtr[2]?+?tmpVPtr[6]?-?tmpVPtr[4]?*?4.25f);
????????????????????????float?t2?=??(tmpVPtr[1]?-?tmpVPtr[3]?*?4.25f?+?tmpVPtr[5]);
????????????????????????r00[1]?=?t1?+?t2;
????????????????????????r00[2]?=?t1?-?t2;
????????????????????????float?t3?=?(tmpVPtr[6]?+?tmpVPtr[2]?*?0.25f?-?tmpVPtr[4]?*?1.25);
????????????????????????float?t4?=?(tmpVPtr[1]?*?0.5f?-?tmpVPtr[3]?*?2.5f?+?tmpVPtr[5]?*?2.f);
????????????????????????r00[3]?=?t3?+?t4;
????????????????????????r04[0]?=?t3?-?t4;
????????????????????????float?t5?=?(tmpVPtr[6]?+?(tmpVPtr[2]?-?tmpVPtr[4]?*?1.25f)?*?4.f);
????????????????????????float?t6?=?(tmpVPtr[1]?*?2.f?-?tmpVPtr[3]?*?2.5f?+?tmpVPtr[5]?*?0.5f);
????????????????????????r04[1]?=?t5?+?t6;
????????????????????????r04[2]?=?t5?-?t6;
????????????????????????r00?+=?2?*?tiles?*?src_tm_w;
????????????????????????r04?+=?2?*?tiles?*?src_tm_w;
????????????????????}
????????????????}
????????????}
????????}
????????delete?[]?srcPadding;
可以看到這個地方不僅計算了V矩陣,并在存儲時就對V矩陣進行了重新排列,以適應卷積核變化矩陣的Pack結果,方便后面進行卷積計算的加速同時獲得更好的訪存,這個過程如下圖所示:
對輸入矩陣進行變換的過程4.3 計算M矩陣
M矩陣的計算公式為:
其中,k代表輸出通道數(shù),b表示tile序號。
由于上面輸入圖像塊已經(jīng)執(zhí)行了Pack策略,這里只需要將對應小塊進行乘加操作即完成了M矩陣的計算,這部分的代碼實現(xiàn)如下:
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif
????????for(int?cc?=?0;?cc?????????????int?c?=?cc?*?4;
????????????float?*dest0?=?dest_tm?+?c?*?dst_tm_size;
????????????float?*dest1?=?dest_tm?+?(c?+?1)?*?dst_tm_size;
????????????float?*dest2?=?dest_tm?+?(c?+?2)?*?dst_tm_size;
????????????float?*dest3?=?dest_tm?+?(c?+?3)?*?dst_tm_size;
????????????const?float?*ktm?=?kernel?+?cc?*?kernelSize;
????????????int?q?=?0;
????????????
????????????for(;?q?+?1?2){
????????????????const?float*?r0?=?src_tm?+?q?*?src_tm_size;
????????????????const?float*?r1?=?src_tm?+?(q?+?1)?*?src_tm_size;
????????????????
????????????????float*?destptr0?=?dest0;
????????????????float?*destptr1?=?dest1;
????????????????float?*destptr2?=?dest2;
????????????????float?*destptr3?=?dest3;
????????????????for(int?r?=?0;?r?16;?r++){
????????????????????for(int?t?=?0;?t?????????????????????????for(int?m?=?0;?m?4;?m++){
????????????????????????????destptr0[m]?+=?r0[m]?*?ktm[m];
????????????????????????????destptr0[m]?+=?r1[m]?*?ktm[m?+?4];
????????????????????????????destptr1[m]?+=?r0[m]?*?ktm[m?+?8];
????????????????????????????destptr1[m]?+=?r1[m]?*?ktm[m?+?12];
????????????????????????????destptr2[m]?+=?r0[m]?*?ktm[m?+?16];
????????????????????????????destptr2[m]?+=?r1[m]?*?ktm[m?+?20];
????????????????????????????destptr3[m]?+=?r0[m]?*?ktm[m?+?24];
????????????????????????????destptr3[m]?+=?r1[m]?*?ktm[m?+?28];??
????????????????????????}
????????????????????????r0?+=?4;
????????????????????????r1?+=?4;
????????????????????????destptr0?+=?4;
????????????????????????destptr1?+=?4;
????????????????????????destptr2?+=?4;
????????????????????????destptr3?+=?4;
????????????????????}
????????????????????ktm?+=?32;
????????????????}
????????????}
????????????for(;?q?????????????????const?float?*r0?=?src_tm?+?q?*?src_tm_size;
????????????????float*?destptr0?=?dest0;
????????????????float?*destptr1?=?dest1;
????????????????float?*destptr2?=?dest2;
????????????????float?*destptr3?=?dest3;
????????????????for(int?r?=?0;?r?16;?r++){
????????????????????for(int?t?=?0;?t?????????????????????????for(int?m?=?0;?m?4;?m++){
????????????????????????????destptr0[m]?+=?r0[m]?*?ktm[m];
????????????????????????????destptr1[m]?+=?r0[m]?*?ktm[m?+?4];
????????????????????????????destptr2[m]?+=?r0[m]?*?ktm[m?+?8];
????????????????????????????destptr3[m]?+=?r0[m]?*?ktm[m?+?12];
????????????????????????}
????????????????????????r0?+=?4;
????????????????????????destptr0?+=?4;
????????????????????????destptr1?+=?4;
????????????????????????destptr2?+=?4;
????????????????????????destptr3?+=?4;
????????????????????}
????????????????????ktm?+=?16;
????????????????}
????????????}
????????}
????????
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif
????????for(int?cc?=?remainOutChannel;?cc?????????????int?c?=?cc;
????????????float?*dest0?=?dest_tm?+?c?*?dst_tm_size;
????????????const?float?*ktm?=?kernel?+?nnOutChannel?*?kernelSize?+?8?*?8?*?inChannel?*?(c?-?remainOutChannel);
????????????int?q?=?0;
????????????for(;?q?????????????????const?float*?r0?=?src_tm?+?q?*?src_tm_size;
????????????????float*?destptr0?=?dest0;
????????????????for(int?r?=?0;?r?16;?r++){
????????????????????for(int?i?=?0;?i?????????????????????????for(int?m?=?0;?m?4;?m++){
????????????????????????????destptr0[m]?+=?r0[m]?*?ktm[m];
????????????????????????}
????????????????????????r0?+=?4;
????????????????????????destptr0?+=?4;
????????????????????}
????????????????????ktm?+=?4;
????????????????}
????????????}
????????}
至此,我們獲得了M矩陣,矩陣大概長下面這樣子,它仍然是交錯排列的:
M矩陣長得和V矩陣有點像,主要是通道維度變了4.4 計算結果Y矩陣
現(xiàn)在就到了最后一步了,我們需要計算結果矩陣Y,公式為:
其中表示輸出通道數(shù),b表示tile標號,這部分和上面卷積核的計算類似,代碼如下:
//?Yk,b=A^TMk,bA
//?AT=
//??1??1??1???1????1????1??????1????0?
//????????????????????????????????????
//??0??1??-1??2???-2???1/2???-1/2???0?
//????????????????????????????????????
//??0??1??1???4????4???1/4????1/4???0?
//????????????????????????????????????
//??0??1??-1??8???-8???1/8???-1/8???0?
//????????????????????????????????????
//??0??1??1???16??16???1/16??1/16???0?
//????????????????????????????????????
//??0??1??-1??32??-32??1/32??-1/32??1?
????????//?0?=?r0?+?(r1?+?r2)?+?(r3?+?r4)?????+?(r5?+?r6)?*?32
????????//?1?=??????(r1?-?r2)?+?(r3?-?r4)?*?2?+?(r5?-?r6)?*?16
????????//?2?=??????(r1?+?r2)?+?(r3?+?r4)?*?4?+?(r5?+?r6)?*?8
????????//?3?=??????(r1?-?r2)?+?(r3?-?r4)?*?8?+?(r5?-?r6)?*?4
????????//?4?=??????(r1?+?r2)?+?(r3?+?r4)?*?16+?(r5?+?r6)?*?2
????????//?5?=?r7?+?(r1?-?r2)?+?(r3?-?r4)?*?32+?(r5?-?r6)
????????float?*dest_tm2?=?new?float[outW?*?outH?*?outChannel];
????????const?int?dst_tm_size2?=?outW?*?outH;
????????
????????const?int?outSize?=?outHeight?*?outWidth;
#if?USE_OMP
????#pragma?omp?parallel?for?num_threads(OMP_THREAD)
#endif
????????for(int?cc?=?0;?cc?????????????float?*destptr?=?dest_tm?+?cc?*?dst_tm_size;
????????????float?*outptr?=?dest_tm2?+?cc?*?dst_tm_size2;
????????????float?tmpA[6][8];
????????????for(int?i?=?0;?i?6;?i++){
????????????????for(int?j?=?0;?j?6;?j++){
????????????????????float?*destptr0?=?destptr?+?(i?*?w_tm?/?8?+?j)?*?dst_tm_w;
????????????????????float?*destptr4?=?destptr?+?(i?*?w_tm?/?8?+?j?+?tiles)?*?dst_tm_w;
????????????????????for(int?m?=?0;?m?8;?m++){
????????????????????????float?t1?=?destptr0[1]?+?destptr0[2];
????????????????????????float?t2?=?destptr0[1]?-?destptr0[2];
????????????????????????float?t3?=?destptr0[3]?+?destptr4[0];
????????????????????????float?t4?=?destptr0[3]?-?destptr4[0];
????????????????????????float?t5?=?destptr4[1]?+?destptr4[2];
????????????????????????float?t6?=?destptr4[1]?-?destptr4[2];
????????????????????????tmpA[0][m]?=?destptr0[0]?+?t1?+?t3?+?t5?*?32;
????????????????????????tmpA[2][m]?=?t1?+?t3?*?4?+?t5?*?8;
????????????????????????tmpA[4][m]?=?t1?+?t3?*?16?+?t5?+?t5;
????????????????????????tmpA[1][m]?=?t2?+?t4?+?t4?+?t6?*?16;
????????????????????????tmpA[3][m]?=?t2?+?t4?*?8?+?t6?*?4;
????????????????????????tmpA[5][m]?=?destptr4[3]?+?t2?+?t4?*?32?+?t6;
????????????????????????destptr0?+=?dst_tm_w?*?2?*?tiles;
????????????????????????destptr4?+=?dst_tm_w?*?2?*?tiles;
????????????????????}
????????????????????float?*outptr0?=?outptr?+?(i?*?6)?*?outW?+?j?*?6;
????????????????????for(int?m?=?0;?m?6;?m++){
????????????????????????const?float*?tmp0?=?tmpA[m];
????????????????????????float?t1?=?tmp0[1]?+?tmp0[2];
????????????????????????float?t2?=?tmp0[1]?-?tmp0[2];
????????????????????????float?t3?=?tmp0[3]?+?tmp0[4];
????????????????????????float?t4?=?tmp0[3]?-?tmp0[4];
????????????????????????float?t5?=?tmp0[5]?+?tmp0[6];
????????????????????????float?t6?=?tmp0[5]?-?tmp0[6];
????????????????????????outptr0[0]?=?tmp0[0]?+?t1?+?t3?+?t5?*?32;
????????????????????????outptr0[2]?=?t1?+?t3?*?4?+?t5?*?8;
????????????????????????outptr0[4]?=?t1?+?t3?*?16?+?t5?+?t5;
????????????????????????outptr0[1]?=?t2?+?t4?+?t4?+?t6?*?16;
????????????????????????outptr0[3]?=?t2?+?t4?*?8?+?t6?*?4;
????????????????????????outptr0[5]?=?tmp0[7]?+?t2?+?t4?*?32?+?t6;
????????????????????????outptr0?+=?outW;
????????????????????}
????????????????}
????????????}
????????}?
這部分代碼就實現(xiàn)了M矩陣匯聚并利用A矩陣獲得了最終的結果Y。這個過程上一節(jié)圖中已經(jīng)畫了,這里主要實現(xiàn)的是圖中的右半部分:
Y矩陣匯聚存放獲得輸出Blob但是需要注意的是這里獲得的Y有可能是多了幾行或者幾列,也就是拖尾為0的部分,所以需要把這一部分Crop掉,才能獲得我們最終的結果特征圖。Crop部分的代碼如下:
//crop
????????for(int?cc?=?0;?cc?????????????float?*outptr?=?dest_tm2?+?cc?*?dst_tm_size2;
????????????float?*outptr2?=?dest?+?cc?*?outHeight?*?outWidth;
????????????for(int?i?=?0;?i?????????????????for(int?j?=?0;?j?????????????????????outptr2[0]?=?outptr[0];
????????????????????outptr2++;
????????????????????outptr++;
????????????????}
????????????????outptr?+=?(outW?-?outWidth);
????????????}
????????}
至此,WinoGrad的算法流程結束,我們獲得了最后的卷積計算結果。
5. WinoGrad算法進一步加速上面無論是針對U,V,M還是Y矩陣的計算我們使用的都是暴力計算,所以接下來可以使用Neon Instrics和Neon Assembly技術進行優(yōu)化。介于篇幅原因,這里就不貼代碼了,有需要學習的可以關注后續(xù)MsnhNet的WinoGrad代碼部分https://github.com/msnh2012/Msnhnet/blob/master/src/layers/arm/MsnhConvolution3x3s1Winograd.cpp。這個代碼實現(xiàn)的思路取自開源框架NCNN,在此表示感謝NCNN這一優(yōu)秀工作(github:https://github.com/Tencent/ncnn)。
和Sgemm用于卷積一樣,我們也需要思考WinoGrad在何種情況下是適用的,或者說是有明顯加速的。這篇文章介紹的WinoGrad卷積是針對NCHW這種內存排布的,然后我們來看一下NCNN在基于NCHW這種內存排布下,是在何種情況下啟用WinoGrad()?
通過查看NCNN的源碼(https://github.com/Tencent/ncnn/blob/master/src/layer/arm/convolution_arm.cpp)可以發(fā)現(xiàn),只有在輸入輸出通道均>=16,并且特征圖長寬均小于等于120的條件下才會啟用WinoGrad卷積。
那么這個條件是如何得出的,除了和手工優(yōu)化的conv3x3s1(https://github.com/msnh2012/Msnhnet/blob/master/src/layers/arm/MsnhConvolution3x3s1.cpp)在不同條件下做速度對比測試之外,我們也可以感性的分析一下。
第一,WinoGrad算法設計到幾個矩陣變換,如果計算量不大,這幾個矩陣變換的成本占計算總成本的比例就越大,所以WinoGrad應當是在計算量比較大時才能有效,如VGG16。
第二,當計算量比較大的時候,又要考慮到Cache命中率的問題,這個時候WinoGrad訪存可能會比直接手動優(yōu)化更差,導致速度上不去。
7. 速度測試由于筆者還未實現(xiàn)完整Neon Instrics和Assembly部分,所以暫時無法給出速度對比。嘗試從NCNN的BenchMark中找到WinoGrad的加速效果大概是什么樣的,但只能找到各個網(wǎng)絡在各種板子上的整體推理速度,沒有WinoGrad F(6,3)單獨的速度對比,等國慶爭取補上來吧。
8. 結語關于WinoGrad的原理介紹還有工程實現(xiàn)(基于NCNN)暫時就講到這里了,有問題歡迎在評論區(qū)討論哦。我剛入門移動端優(yōu)化幾個月還有非常多知識需要學習,nihui,蟲叔,白牛,大老師他們都是高人,這幾個月從他們的文章受益良多,非常感謝!
9. 致謝- https://zhuanlan.zhihu.com/p/72149270
- https://www.cnblogs.com/shine-lee/p/10906535.html
- https://zhuanlan.zhihu.com/p/81201840
MsnhNet是一款基于純c++的輕量級推理框架,本框架受到darknet啟發(fā)。
項目地址:https://github.com/msnh2012/Msnhnet ,歡迎一鍵三連。
本框架目前已經(jīng)支持了X86、Cuda、Arm端的推理(支持的OP有限,正努力開發(fā)中),并且可以直接將Pytorch模型(后面也會嘗試接入更多框架)轉為本框架的模型進行部署,歡迎對前向推理框架感興趣的同學試用或者加入我們一起維護這個輪子。
最后,歡迎加入Msnhnet開發(fā)QQ交流群,有對項目的建議或者說個人的需求都可以在里面或者github issue提出。
交流群圖片歡迎關注GiantPandaCV, 在這里你將看到獨家的深度學習分享,堅持原創(chuàng),每天分享我們學習到的新鮮知識。( ? ?ω?? )?
有對文章相關的問題,或者想要加入交流群,歡迎添加BBuf微信:
二維碼