一種更加高效的卷積計算策略:Im2Col+GEMM的改進方法MEC

極市導讀
?本文介紹了一種Im2Col+GEMM的改進版本——MEC,并記錄了對它進行復現嘗試后的測試結果。>>加入極市CV技術交流群,走在計算機視覺的最前沿
1. 前言
2. 背景介紹

所以,MEC改進了Im2Col+GEMM的策略,目的是減少它的內存消耗同時提升一點速度。
3. MEC算法原理


下面的Algorithm1展示了這個算法的流程:

3.2 MEC算法高級版

然后下面的Figure3是它的示例圖:

從偽代碼里可以看到這里有2種計算方法:
Solution 1:Algorithm2中的第9-19行和Algorithm1中的方法完全一致,然后14-19行是對臨時結果對做排列變化,即Figure3中的上半部分。 Solution 2:Algorithm2中的第21-25行。每次循環(huán)處理一個樣本,不需要做額外的排列變化,即Figure3中的下半部分。
4. 實驗對比



5. 復現嘗試(暫時只針對X86 CPU)
// 原始的Im2Col
void im2col_cpu(float** src, const int &inHeight, const int &intWidth, const int &kHeight,
const int &kWidth, float* srcIm2col){
const int outHeight = inHeight - kHeight + 1;
const int outWidth = intWidth - kWidth + 1;
int cnt = 0;
for(int i = 0; i < kHeight; i++){
for(int j = 0; j < kWidth; j++){
int id = i * kWidth + j;
int ii = i;
for(int x = 0; x < outHeight; x++){
int jj = j;
for(int y = 0; y < outWidth; y++){
srcIm2col[cnt] = src[ii][jj];
jj += 1;
cnt++;
}
ii += 1;
}
}
}
}
cblas_sgemm接口,關于OpenBlas的介紹以及計算方式,函數接口可以查看參考中的資料2,這里就不過多介紹了。// 構造輸入矩陣
float **src = new float*[inHeight];
for(int i = 0; i < inHeight; i++){
src[i] = new float[inWidth];
for(int j = 0; j < inWidth; j++){
src[i][j] = 0.1;
}
}
// 構造kernel矩陣
float **kernel[kernel_num];
for(int i = 0; i < kernel_num; i++){
kernel[i] = new float*[kernel_h];
for(int j = 0; j < kernel_h; j++){
kernel[i][j] = new float[kernel_w];
for(int k = 0; k < kernel_w; k++){
kernel[i][j][k] = 0.2;
}
}
}
// 開始計時
struct timeval tstart, tend;
gettimeofday(&tstart, NULL);
// 對kernel進行Im2col
float* kernel2col = new float[kernel_num*kernel_h*kernel_w];
int cnt = 0;
for(int i = 0; i < kernel_num; i++){
for(int j = 0; j < kernel_h; j++){
for(int k = 0; k < kernel_w; k++){
kernel2col[cnt++] = kernel[i][j][k];
}
}
}
// 對輸入矩陣Im2col
int outHeight = inHeight - kernel_h + 1;
int outWidth = inWidth - kernel_w + 1;
float *srcIm2col = new float[kernel_w * kernel_h * outWidth * outHeight];
im2col_cpu(src, inHeight, inWidth, kernel_h, kernel_w, srcIm2col);
cblas_sgemm函數接口即可完成卷積層的計算,這個地方加入了計時函數,統(tǒng)計Im2Col+gemm的運行時間:// 使用Blas庫實現矩陣乘法
float *output = new float[kernel_num * outHeight * outWidth];
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,kernel_num,
outHeight*outWidth, kernel_w*kernel_h, 1,
kernel2col, kernel_h*kernel_w,
srcIm2col,outHeight * outWidth, 0, output, outHeight * outWidth);
// 結束計時
gettimeofday(&tend, NULL);
cout<<"im2colOrigin Total time cost: "<<(tend.tv_sec-tstart.tv_sec)*1000 + (tend.tv_usec-tstart.tv_usec)/1000<<" ms"<
// MEC
void im2col_mec(float** src, const int &inHeight, const int &intWidth, const int &kHeight,
const int &kWidth, float* srcIm2col){
const int outHeight = inHeight - kHeight + 1;
const int outWidth = intWidth - kWidth + 1;
#pragma omp parallel for num_threads(THREAD_NUM)
for(int i = 0; i < outWidth; i++){
int outrow = 0;
for(int j = 0; j < inHeight; j++){
for(int k = i; k < i + kWidth; k++){
srcIm2col[outrow * outWidth + i] = src[j][k];
outrow++;
}
}
}
}
// 構造輸入矩陣
float **src = new float*[inHeight];
for(int i = 0; i < inHeight; i++){
src[i] = new float[inWidth];
for(int j = 0; j < inWidth; j++){
src[i][j] = 0.1;
}
}
// 構造kernel矩陣
float **kernel[kernel_num];
for(int i = 0; i < kernel_num; i++){
kernel[i] = new float*[kernel_h];
for(int j = 0; j < kernel_h; j++){
kernel[i][j] = new float[kernel_w];
for(int k = 0; k < kernel_w; k++){
kernel[i][j][k] = 0.2;
}
}
}
// 開始計時
struct timeval tstart, tend;
gettimeofday(&tstart, NULL);
// 對kernel進行Im2col
float* kernel2col = new float[kernel_num*kernel_h*kernel_w];
int cnt = 0;
for(int i = 0; i < kernel_num; i++){
for(int j = 0; j < kernel_h; j++){
for(int k = 0; k < kernel_w; k++){
kernel2col[cnt++] = kernel[i][j][k];
}
}
}
// 對輸入矩陣Im2col
int outHeight = inHeight - kernel_h + 1;
int outWidth = inWidth - kernel_w + 1;
float *srcIm2col = new float[outWidth * inHeight * kernel_w];
im2col_mec(src, inHeight, inWidth, kernel_h, kernel_w, srcIm2col);
// 使用Blas庫實現矩陣乘法
float **output = new float*[outHeight];
#pragma omp parallel for num_threads(THREAD_NUM)
for(int i = 0; i < outHeight; i++){
output[i] = new float [kernel_num * outWidth];
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,kernel_num,
outWidth, kernel_w * kernel_h,1,
kernel2col, kernel_h * kernel_w,
srcIm2col + i * outWidth, outWidth, 0, output[i], outWidth);
}
// 結束計時
gettimeofday(&tend, NULL);
cout<<"MEC Total time cost: "<<(tend.tv_sec-tstart.tv_sec)*1000 + (tend.tv_usec-tstart.tv_usec)/1000<<" ms"<
https://github.com/BBuf/Memory-efficient-Convolution-for-Deep-Neural-Network6. 效果

參考資料
推薦閱讀
?ACCV 2020國際細粒度網絡圖像識別競賽正式開賽!

評論
圖片
表情
