CUDA優(yōu)化之LayerNorm性能優(yōu)化實踐

撰文 | 郭冉、姚遲、鄭澤康、柳俊丞



以 PyTorch 為例,LayerNorm 的接口為:
torch.nn.LayerNorm(normalized_shape,?eps=1e-05,?elementwise_affine=True,?device=None,?dtype=None)
其中 input 形狀為:[?, normalized_shape[0], normalized_shape[1], …,normalized_shape[?1]]



LayerNorm 中求方差的方法
1.two-pass方法
使用的公式是:


這種方法是一種 single pass 方法,在計算方差時只需要遍歷一遍數(shù)據(jù)累加 x 的平方及累加 x,最后按上述公式計算得到方差。這種方法只需要遍歷一遍數(shù)據(jù),相比 two-pass 的算法,更容易達(dá)到好的性能,但是上面的 Wiki 參考鏈接中介紹由于 SumSquare 和 (Sum×Sum)/n 可能非常接近,可能會導(dǎo)致計算結(jié)果損失精度較大,因此這種方法不建議在實踐中使用。
3.Welford 算法
使用的公式是:

Welford 算法也是一種 single pass 方法,且數(shù)值穩(wěn)定性很好,因此現(xiàn)在很多框架都采用這種方法。本文的代碼中采用的也是 Welford 方法。
OneFlow 深度優(yōu)化 LayerNorm CUDA Kernel 的技巧
和 Softmax 一樣,LayerNorm 也采用分段函數(shù)優(yōu)化,對于不同的 num_cols 范圍,采用不同的實現(xiàn),以在各種情況下都能達(dá)到較高的有效帶寬。
在每種實現(xiàn)中都采用了一個公共的優(yōu)化:向量化訪存,NVIDIA 性能優(yōu)化的博客 Increase Performance with Vectorized Memory Access 中提到可以通過向量化內(nèi)存操作來提高 CUDA Kernel 性能,很多 CUDA Kernel 都是帶寬受限的,使用向量化內(nèi)存操作可以減少總的指令數(shù),減少延遲,提高帶寬利用率。
理論上來說,在計算 LayerNorm 的過程中,輸入 x 需要被讀兩次,第一次用于計算均值和方差。第二次用于得到均值和方差后的計算過程。而對 Global Memory 的訪問操作是昂貴的,如果能將輸入 x 先存起來,不重復(fù)讀,就可以提升性能。在 GPU 中將輸入 x 存起來可以使用寄存器或 Shared memory,但是寄存器資源和 Shared memory 資源都是有限的,如果 num_cols 過大,就會超出資源的使用限制,因此我們針對不同 num_cols 采用不同的實現(xiàn),下面分別進(jìn)行介紹:
1.num_cols <= 1024 的情況
針對 num_cols <= 1024 的情況,以 Warp 為單位處理一行或兩行,將輸入 x 存儲到寄存器中。


WelfordWarpAllReduce 由 WelfordWarpReduce 和 Broadcast 操作完成,WelfordWarpReduce 借助 Warp 級別同步原語 __shfl_down_sync 實現(xiàn),Broadcast操作借助 __shfl_sync 實現(xiàn),代碼如下:
template
T,?int?thread_group_width?=?kWarpSize>
__inline__?__device__?void?WelfordWarpReduce(T?thread_mean,?T?thread_m2,?T?thread_count,?T*?mean,
?????????????????????????????????????????????T*?m2,?T*?count)?{
??*mean?=?thread_mean;
??*m2?=?thread_m2;
??*count?=?thread_count;
??for?(int?mask?=?thread_group_width?/?2;?mask?>?0;?mask?/=?2)?{
????T?b_mean?=?__shfl_down_sync(0xffffffff,?*mean,?mask);
????T?b_m2?=?__shfl_down_sync(0xffffffff,?*m2,?mask);
????T?b_count?=?__shfl_down_sync(0xffffffff,?*count,?mask);
????WelfordCombine(b_mean,?b_m2,?b_count,?mean,?m2,?count);
??}
}
templateT,?int?thread_group_width?=?kWarpSize>
__inline__?__device__?void?WelfordWarpAllReduce(T?thread_mean,?T?thread_m2,?T?thread_count,?T*?mean,
????????????????????????????????????????????????T*?m2,?T*?count)?{
??WelfordWarpReduce<T,?thread_group_width>(thread_mean,?thread_m2,?thread_count,?mean,?m2,?count);
??*mean?=?__shfl_sync(0xffffffff,?*mean,?0,?thread_group_width);
??*m2?=?__shfl_sync(0xffffffff,?*m2,?0,?thread_group_width);
??*count?=?__shfl_sync(0xffffffff,?*count,?0,?thread_group_width);
}
在這里有個模板參數(shù) thread_group_width,當(dāng) num_cols > pack_size * WarpSize 時,thread_group_width 為 WarpSize。當(dāng) num_cols 太小,即 num_cols

將 pack_size 個元素 pack 成更大的數(shù)據(jù)類型讀入,但是 x 還要參與計算。因此我們定義一個union 結(jié)構(gòu)的 Pack 類型,storage 用于從 Global Memory中讀寫,做計算時用 elem[i] 取每個元素參與計算,Pack 類型定義如下:
template<typename?T,?int?N>
union?Pack?{
??PackType?storage;
??T?elem[N];
};
LayerNormWarpImpl Kernel 代碼如下:
template<typename?LOAD,?typename?STORE,?typename?ComputeType,?int?pack_size,?int?cols_per_thread,
?????????int?thread_group_width,?int?rows_per_access,?bool?padding>
__global__?void?LayerNormWarpImpl(LOAD?load,?STORE?store,?const?int64_t?rows,?const?int64_t?cols,
??????????????????????????????????const?double?epsilon,?ComputeType*?mean,
??????????????????????????????????ComputeType*?inv_variance)?{
??static_assert(cols_per_thread?%?pack_size?==?0,?"");
??static_assert(thread_group_width?<=?kWarpSize,?"");
??static_assert(kWarpSize?%?thread_group_width?==?0,?"");
??constexpr?int?num_packs?=?cols_per_thread?/?pack_size;
??assert(cols?<=?cols_per_thread?*?thread_group_width);
??ComputeType?buf[rows_per_access][cols_per_thread];
??const?int64_t?global_thread_group_id?=?blockIdx.x?*?blockDim.y?+?threadIdx.y;
??const?int64_t?num_global_thread_group?=?gridDim.x?*?blockDim.y;
??const?int64_t?lane_id?=?threadIdx.x;
??for?(int64_t?row?=?global_thread_group_id?*?rows_per_access;?row????????row?+=?num_global_thread_group?*?rows_per_access)?{
????ComputeType?thread_mean[rows_per_access];
????ComputeType?thread_m2[rows_per_access];
????ComputeType?thread_count[rows_per_access];
#pragma?unroll
????for?(int?row_id?=?0;?row_id???????thread_mean[row_id]?=?0;
??????thread_m2[row_id]?=?0;
??????thread_count[row_id]?=?0;
??????ComputeType*?row_buf?=?buf[row_id];
#pragma?unroll
??????for?(int?pack_id?=?0;?pack_id?????????const?int?col?=?(pack_id?*?thread_group_width?+?lane_id)?*?pack_size;
????????const?int?pack_offset?=?pack_id?*?pack_size;
????????if?(!padding?||?col???????????load.template?load(row_buf?+?pack_offset,?row?+?row_id,?col);
#pragma?unroll
??????????for?(int?i?=?0;?i?????????????WelfordCombine(row_buf[pack_offset?+?i],?thread_mean?+?row_id,?thread_m2?+?row_id,
???????????????????????????thread_count?+?row_id);
??????????}
????????}?else?{
#pragma?unroll
??????????for?(int?i?=?0;?i?0;?}
????????}
??????}
????}
????ComputeType?warp_mean[rows_per_access];
????ComputeType?warp_m2[rows_per_access];
????ComputeType?warp_count[rows_per_access];
#pragma?unroll
????for?(int?row_id?=?0;?row_id???????int?global_row_id?=?row?+?row_id;
??????ComputeType*?row_buf?=?buf[row_id];
??????WelfordWarpAllReduce(
??????????thread_mean[row_id],?thread_m2[row_id],?thread_count[row_id],?warp_mean?+?row_id,
??????????warp_m2?+?row_id,?warp_count?+?row_id);
??????ComputeType?row_mean?=?warp_mean[row_id];
??????ComputeType?row_variance?=
??????????max(Div(warp_m2[row_id],?warp_count[row_id]),?static_cast(0.0));
??????ComputeType?row_inv_var?=?Rsqrt(row_variance?+?static_cast(epsilon));
??????if?(lane_id?==?0)?{
????????mean[global_row_id]?=?row_mean;
????????inv_variance[global_row_id]?=?row_inv_var;
??????}
#pragma?unroll
??????for?(int?i?=?0;?i?????????row_buf[i]?=?(row_buf[i]?-?row_mean)?*?row_inv_var;
??????}
#pragma?unroll
??????for?(int?i?=?0;?i?????????const?int?col?=?(i?*?thread_group_width?+?lane_id)?*?pack_size;
????????if?(!padding?||?col???????????store.template?store(row_buf?+?i?*?pack_size,?global_row_id,?col);
????????}
??????}
????}
??}
}
LOAD、STORE 分別代表輸入輸出,使用load.template load (ptr, row_id, col_id); 和 store.template store(ptr, row_id, col_id); 進(jìn)行讀取和寫入。使用 LOAD 和 STORE 有兩個好處:a) 可以在 CUDA Kernel中只關(guān)心計算類型 ComputeType,而不用關(guān)心具體的數(shù)據(jù)類型 T。b) 只需要加幾行代碼就可以快速支持 LayerNorm 和其他 Kernel Fuse,減少帶寬需求,提升整體性能。ComputeType 代表計算類型。pack_size 代表向量化訪存操作的 pack 元素的個數(shù),我們將幾個元素 pack 起來讀寫,提升帶寬利用率。 cols_per_thread 代表每個線程處理的元素個數(shù)。 thread_group_width 代表處理元素的線程組的寬度,當(dāng) cols > pack_size * warp_size 時,thread_group_width 就是warp_size,即32。當(dāng) cols < pack_size * warp_size 時,就根據(jù) cols 大小用 1/2個warp 或 1/4個warp 來處理每行的元素。采用更小的 thread_group_width 后,WarpAllReduce需要執(zhí)行的輪次也相應(yīng)減少。 rows_per_access 代表每個 thread_group 一次處理的行數(shù),當(dāng) cols 較小且 thread_group_width 小于warp_size時,若 rows 能被2整除,我們就讓每個線程處理2行來增加指令并行度,從而提升性能。 padding 代表當(dāng)前是否做了 padding,若 cols 不是 warp_size 的整數(shù)倍,我們會把它padding 到最近的整數(shù)倍處理。
2.num_cols > 1024 的情況


LayerNormBlockSMemImpl Kernel的代碼如下:
template<typename?LOAD,?typename?STORE,?typename?ComputeType,?int?pack_size,?int?block_size>
__global__?void?LayerNormBlockSMemImpl(LOAD?load,?STORE?store,?const?int64_t?rows,
???????????????????????????????????????const?int64_t?cols,?const?double?epsilon,?ComputeType*?mean,
???????????????????????????????????????ComputeType*?inv_variance)?{
??extern?__shared__?__align__(sizeof(double))?unsigned?char?shared_buf[];
??auto*?buf?=?reinterpret_cast(shared_buf);
??const?int?tid?=?threadIdx.x;
??assert(cols?%?pack_size?==?0);
??const?int?num_packs?=?cols?/?pack_size;
??for?(int64_t?row?=?blockIdx.x;?row?????ComputeType?thread_mean?=?0;
????ComputeType?thread_m2?=?0;
????ComputeType?thread_count?=?0;
????for?(int?pack_id?=?tid;?pack_id???????ComputeType?pack[pack_size];
??????load.template?load(pack,?row,?pack_id?*?pack_size);
#pragma?unroll
??????for?(int?i?=?0;?i?????????buf[i?*?num_packs?+?pack_id]?=?pack[i];
????????WelfordCombine(pack[i],?&thread_mean,?&thread_m2,?&thread_count);
??????}
????}
????ComputeType?row_mean?=?0;
????ComputeType?row_m2?=?0;
????ComputeType?row_count?=?0;
????WelfordBlockAllReduce(thread_mean,?thread_m2,?thread_count,?&row_mean,?&row_m2,
???????????????????????????????????????&row_count);
????ComputeType?row_variance?=?max(Div(row_m2,?row_count),?static_cast(0.0));
????ComputeType?row_inv_var?=?Rsqrt(row_variance?+?static_cast(epsilon));
????if?(threadIdx.x?==?0)?{
??????mean[row]?=?row_mean;
??????inv_variance[row]?=?row_inv_var;
????}
????for?(int?pack_id?=?tid;?pack_id???????ComputeType?pack[pack_size];
#pragma?unroll
??????for?(int?i?=?0;?i?????????pack[i]?=?(buf[i?*?num_packs?+?pack_id]?-?row_mean)?*?row_inv_var;
??????}
??????store.template?store(pack,?row,?pack_id?*?pack_size);
????}
??}
}
3.num_cols 較大時,不使用 Shared Memory 的情況
當(dāng) num_cols 較大,當(dāng)前硬件資源條件下使用Shared Memory的方法無法成功Launch Kernel時,使用這種實現(xiàn):一個 Block 處理一行的元素,不使用 Shared Memory,重復(fù)讀輸入 x。
這種方法和前面第二種情況線程和元素對應(yīng)關(guān)系一致,唯一的區(qū)別在于,第二種方法將輸入 x 存儲到Shared Memory 中,本方法不存儲 x,在每次計算時需要再從 Global Memory 中讀入 x。這種方法雖然需要多讀一份 x,但是在實際執(zhí)行時,部分輸入可以被 Cache 緩存起來,不會實際增加很多時間。值得注意的是,在這種實現(xiàn)中,block_size 越大,SM 中能同時并行執(zhí)行的 block 數(shù)就越少,對 Cache 的需求就越少,就有更多機(jī)會命中 Cache,因此我們使用較大的 block_size。
LayerNormBlockUncachedImpl 代碼如下:
template<typename?LOAD,?typename?STORE,?typename?ComputeType,?int?pack_size,?int?block_size>
__global__?void?LayerNormBlockUncachedImpl(LOAD?load,?STORE?store,?const?int64_t?rows,
???????????????????????????????????????????const?int64_t?cols,?const?double?epsilon,
???????????????????????????????????????????ComputeType*?mean,?ComputeType*?inv_variance)?{
??const?int?tid?=?threadIdx.x;
??assert(cols?%?pack_size?==?0);
??const?int?num_packs?=?cols?/?pack_size;
??for?(int64_t?row?=?blockIdx.x;?row?????ComputeType?thread_mean?=?0;
????ComputeType?thread_m2?=?0;
????ComputeType?thread_count?=?0;
????for?(int?pack_id?=?tid;?pack_id???????ComputeType?pack[pack_size];
??????load.template?load(pack,?row,?pack_id?*?pack_size);
#pragma?unroll
??????for?(int?i?=?0;?i?????????WelfordCombine(pack[i],?&thread_mean,?&thread_m2,?&thread_count);
??????}
????}
????ComputeType?row_mean?=?0;
????ComputeType?row_m2?=?0;
????ComputeType?row_count?=?0;
????WelfordBlockAllReduce(thread_mean,?thread_m2,?thread_count,?&row_mean,?&row_m2,
???????????????????????????????????????&row_count);
????ComputeType?row_variance?=?max(Div(row_m2,?row_count),?static_cast(0.0));
????ComputeType?row_inv_var?=?Rsqrt(row_variance?+?static_cast(epsilon));
????if?(threadIdx.x?==?0)?{
??????mean[row]?=?row_mean;
??????inv_variance[row]?=?row_inv_var;
????}
????for?(int?pack_id?=?tid;?pack_id???????ComputeType?pack[pack_size];
??????const?int?pack_offset?=?pack_id?*?pack_size;
??????load.template?load(pack,?row,?pack_offset);
#pragma?unroll
??????for?(int?i?=?0;?i???????store.template?store(pack,?row,?pack_offset);
????}
??}
}
3 OneFlow Softmax 庫
????oneflow::cuda::softmax::DirectLoad float>?load(in,?cols);
????oneflow::cuda::softmax::DirectStore<float,?half>?store(out,?cols);
????oneflow::cuda::softmax::DispatchSoftmax<decltype(load),?decltype(store),?float>(
????????cuda_stream,?load,?store,?rows,?cols);
性能優(yōu)勢,可見之前的文章分享。此外,最近一年進(jìn)一步優(yōu)化了小的 num_cols 下的性能。
同時支持了 Softmax 和 LogSoftmax,適用場景更廣。
輸入輸出通過 Load/Store 結(jié)構(gòu)傳遞,解耦數(shù)據(jù)IO和計算,只需要加幾行代碼就可以快速支持 Softmax 和其他 Kernel Fuse,減少帶寬需求,帶來很高的性能收益。
