【詳細(xì)圖解】再次理解im2col
一句話(huà):im2col是將一個(gè)[C,H,W]矩陣變成一個(gè)[H,W]矩陣的一個(gè)方法,其原理是利用了行列式進(jìn)行等價(jià)轉(zhuǎn)換。
為什么要做im2col? 減少調(diào)用gemm的次數(shù)。
重要:本次的代碼只是為了方便理解im2col,不是用來(lái)做加速,所以代碼寫(xiě)的很簡(jiǎn)單且沒(méi)有做任何優(yōu)化。
一、卷積的可視化
例子是一個(gè)[1, 6, 6]的輸入,卷積核是[1, 3, 3],stride等于1,padding等于0。那么卷積的過(guò)程可視化如下圖,一共需要做16次卷積計(jì)算,每次卷積計(jì)算有9次乘法和8次加法。

輸出的公式如下,即Output_height = (6 - 3 + 2*0)/1 ?+ 1 = 4 = Output_width

二、行列式

乘號(hào)左邊的橫條,跟乘號(hào)右邊的豎條進(jìn)行點(diǎn)乘(即每個(gè)元素對(duì)應(yīng)相乘后再全部加起來(lái))。
關(guān)于行列式,大家都清楚的一點(diǎn),一根橫條的元素個(gè)數(shù)要等于一根豎條的元素個(gè)數(shù)(這樣才可以讓做點(diǎn)乘的時(shí)候能一一對(duì)應(yīng)起來(lái),不會(huì)讓小方塊落單)。豎條有多少條,出來(lái)的結(jié)果就有多少個(gè)小方塊(在橫條的個(gè)數(shù)為1的情況下)。
出來(lái)的結(jié)果(等號(hào)的右邊)的行數(shù)等于乘號(hào)左邊的橫條的行數(shù),出來(lái)的結(jié)果(等號(hào)的右邊)的列數(shù)等于乘號(hào)右邊的橫條的列數(shù),公式表示就是[row, ?x] * [x, col] = [row, col]。舉個(gè)例子[3, 8] * [8, 4] = [3, 4]

三、[1, H, W]的im2col


展開(kāi)后,就可以直接做兩個(gè)數(shù)組的矩陣乘積了

import??numpy?as?np
scr?=?np.array(np.arange(0,7**2).reshape(7,?7))
intH,?intW=?scr.shape
kernel?=?np.array([-0.2589,??0.2106,?-0.1583,?-0.0107,??0.1177,??0.1693,?-0.1582,?-0.3048,?-0.1946]).reshape(3,3)
KHeight,?KWeight?=?kernel.shape
row_num?=?intH?-?KHeight?+?1
col_num?=?intW?-?KWeight?+?1
OutScrIm2Col?=?np.zeros([row_num*col_num,KHeight*KWeight])?
ii,?jj?=?0,?0
col_cnt,?row_cnt?=?0,?0
for?i?in?range(0,?row_num):
????for?j?in?range(0,?col_num):?#?這倆個(gè)for是為了遍歷列,即乘了多少次,這里完全可以merge成一個(gè)for循環(huán),只需要提前計(jì)算好就行
????????ii?=?i
????????jj?=?j
????????for?iii?in?range(0,?KHeight):?#?這倆個(gè)for是為了取出一次?一橫?*?一豎?的?行列式,這里完全可以mege成一個(gè)for循環(huán),只需要提前計(jì)算好就行
????????????for?jjj?in?range(0,?KHeight):
????????????????OutScrIm2Col[row_cnt][col_cnt]?=?scr[ii][jj]
????????????????jj?+=1
????????????????col_cnt?+=?1
????????????ii?+=?1
????????????jj?=?j
????????col_cnt?=?0
????????row_cnt?+=?1
im2col_kernel?=?im2col_kernel.reshape(-1,9)
OutScrIm2Col?=?OutScrIm2Col.T
out?=?np.matmul(im2col_kernel,OutScrIm2Col)?#?這步就是做兩個(gè)數(shù)組的矩陣乘積
中間倆個(gè)for循環(huán)是來(lái)填滿(mǎn)展開(kāi)的數(shù)組/矩陣的每一列,即卷積核對(duì)應(yīng)的元素,其個(gè)數(shù)等于卷積核的元素個(gè)數(shù),舉個(gè)例子,[1, 3, 3]的卷積核,那么該卷積核的元素個(gè)數(shù)等于9;最外層的兩個(gè)for循環(huán)是用來(lái)填滿(mǎn)展開(kāi)的數(shù)組/矩陣的每一行,即列數(shù),也就是卷積核在輸入滑動(dòng)了多少次

pytorch來(lái)做驗(yàn)證
import?torch
from?torch?import?nn
import?numpy?as?np
torch.manual_seed(100)
net?=?nn.Conv2d(1,?1,?3,?padding=0,?bias=False)
scr?=?np.array(np.arange(0,?7**2).reshape(1,?1,?7,?7)).astype(np.float32)
scr?=?torch.from_numpy(scr)
print(net.weight.data)?#?把這里的weight的值復(fù)制到上面numpy的代碼來(lái)做驗(yàn)證
print(net(scr))
#?print的信息
tensor([[[[-0.2589,??0.2106,?-0.1583],
??????????[-0.0107,??0.1177,??0.1693],
??????????[-0.1582,?-0.3048,?-0.1946]]]])
tensor([[[[?-7.6173,??-8.2053,??-8.7934,??-9.3815,??-9.9695],
??????????[-11.7337,?-12.3217,?-12.9098,?-13.4978,?-14.0859],
??????????[-15.8500,?-16.4381,?-17.0261,?-17.6142,?-18.2022],
??????????[-19.9664,?-20.5545,?-21.1425,?-21.7306,?-22.3186],
??????????[-24.0828,?-24.6708,?-25.2589,?-25.8469,?-26.4350]]]],
???????grad_fn=)
四、[C, H, W]的im2col







前面一堆圖,是我故意不寫(xiě)文字,希望大家能夠通過(guò)圖能夠看明白。前面卷積核只有一行的情況,跟[1, H, W]的情況基本一摸一樣,只是這一行的元素個(gè)數(shù)等于卷積核的元素個(gè)數(shù)即可5x3x3=45,展開(kāi)的特征圖的每一個(gè)豎條也是45。
當(dāng)卷積核函數(shù)等于3的時(shí)候,就是對(duì)應(yīng)的只要增加卷積核的橫條數(shù)即可,展開(kāi)的特征圖沒(méi)有改變。這里希望大家用行列式的計(jì)算和普通卷積的過(guò)程聯(lián)想起來(lái),你會(huì)發(fā)現(xiàn)是一摸一樣的計(jì)算過(guò)程。
代碼其實(shí)跟[1,H, W]只有一初不同,就是從特征圖里面取數(shù)據(jù)的時(shí)候多了個(gè)維度,需要取對(duì)應(yīng)的通道。這里為什么要取對(duì)應(yīng)的通道數(shù)呢?原因是行列式的計(jì)算中,橫條和豎條是元素一一對(duì)應(yīng)做乘法。
import??numpy?as?np
np.set_printoptions(threshold=np.inf)
src?=?np.array(np.arange(0,?9**3))[0:5*9*9]
src?=?np.tile(src,?5)
src?=?src.reshape(-1,?5,?9,?9)
kernel?=?np.array([[[[-0.1158,??0.0942,?-0.0708],
??????????[-0.0048,??0.0526,??0.0757],
??????????[-0.0708,?-0.1363,?-0.0870]],
?????????[[-0.1139,?-0.1128,??0.0702],
??????????[?0.0631,??0.0857,?-0.0244],
??????????[?0.1197,??0.1481,??0.0765]],
?????????[[-0.0823,?-0.0589,?-0.0959],
??????????[?0.0966,??0.0166,??0.1422],
??????????[-0.0167,??0.1335,??0.0729]],
?????????[[-0.0032,?-0.0768,??0.0597],
??????????[?0.0083,?-0.0754,??0.0867],
??????????[-0.0228,?-0.1440,?-0.0832]],
?????????[[?0.1352,??0.0615,?-0.1005],
??????????[?0.1163,??0.0049,?-0.1384],
??????????[?0.0440,?-0.0468,?-0.0542]]]])
scrN,?srcChannel,?intH,?intW=?src.shape
KoutChannel,?KinChannel,?kernel_H,?kernel_W?=?kernel.shape
im2col_kernel?=?kernel.reshape(KoutChannel,?-1)
outChannel,?outH,?outW?=??KoutChannel,?(intH?-?kernel_H?+?1)?,?(intW?-?kernel_W?+?1)
OutScrIm2Col?=?np.zeros(?[?kernel_H*kernel_W*KinChannel,?outH*outW?]?)
row_num,?col_num?=?OutScrIm2Col.shape
ii,?jj,?cnt_row,?cnt_col?=?0,?0,?0,?0
#?卷積核的reshape準(zhǔn)備?:outchannel, k*k*inchannel
im2col_kernel?=?kernel.reshape(KoutChannel,?-1)
#?輸入的reshape準(zhǔn)備?:outH =?(intH - k + 2*pading)/stride + 1
outChannel,?outH,?outW?=??KoutChannel,?(intH?-?kernel_H?+?1)?,?(intW?-?kernel_W?+?1)
i_id?=?-1
cnt_col?=?-1
cnr?=?0
for?Outim2colCol_H?in?range(0,?outH):
????i_id?+=?1
????j_id?=?-1
????cnt_row??=?-1
????for?Outim2colCol_W?in?range(0,?outW):
????????j_id,?cnt_col?+=?1,??+=?1
????????cnt_row?=?0
????????for?c?in?range(0,?srcChannel):?#?取一次卷積的數(shù)據(jù),放到一列
????????????for?iii?in?range(0,?kernel_H):
????????????????i_number?=?iii?+?i_id
????????????????for?jjj?in?range(0,?kernel_W):
????????????????????j_number?=?jjj?+?j_id
????????????????????OutScrIm2Col[cnt_row][cnt_col]?=?src[bs][c][i_number][j_number]
????????????????????cnr?+=1
????????????????????cnt_row?+=?1
????????????????????
Out?=??np.matmul(im2col_kernel,?OutScrIm2Col)
Out.reshape(outChannel,?outH,?outW)
print(Out.shape)
print(outChannel,?outH,?outW)
pytorch代碼的驗(yàn)證
import?torch
from?torch?import?nn
import?numpy?as?np
torch.manual_seed(100)
net?=?nn.Conv2d(in_channels=5,?out_channels=1,?kernel_size=3,?padding=0,?bias=False)
print(net.weight.data.shape)
print(net.weight.data)
scr?=?np.array(np.arange(0,?9**3))[:9*9*5].reshape(1,?-1,?9,?9).astype(np.float32)
scr?=?torch.from_numpy(src)
print("data:",?scr.shape)
scr?=?torch.from_numpy(scr)
print("data:",?scr.shape)
Out?=?net(scr)
print("Our:",?Out.shape)
print(Out)
五、[B, C, H, W]的im2col

問(wèn)題:如何bs=9的情況呢,要怎么做im2col+gemm呢?方法 1:把filter攤平的shape變成[3,5339],把input攤平的shape變成[5339,16]
– output的shape就為[3,16]了 - ?
方法 2:把filter攤平的shape變成[39,533],把input攤平的shape變成[533,16],output的shape就為[39,16]了
– 隱患:如何filter數(shù)量是51233這種數(shù)量,那么非常占用顯存/內(nèi)存
方法 3:im2col+gemm外面加一層關(guān)于bs的for循環(huán)
– 隱患:加一層for循環(huán)嵌套非常耗時(shí)
經(jīng)過(guò)簡(jiǎn)單分析,發(fā)現(xiàn)采取for循環(huán)的方式來(lái)進(jìn)行im2col是相對(duì)合適的情況。我向msnh2012的作者穆士凝魂請(qǐng)教,得到的答案是,是用加一層for循環(huán)的方式居多,而且由于可以并發(fā),多一層循環(huán)的開(kāi)銷(xiāo)比想象中小一些。如果是推理框架的話(huà),有部分情況bs是等于1的,所以可以規(guī)避這個(gè)問(wèn)題。
歡迎關(guān)注GiantPandaCV
歡迎聯(lián)系&投稿
