Dropout算子的bitmask優(yōu)化
背景
在某個風和日麗,適合寫bug的早晨,老大甩給我一個鏈接,里面是onnxruntime對dropout算子使用bitmask的優(yōu)化,思路還是很巧妙的,下面簡單解析下
代碼地址:[CUDA] Implement BitmaskDropout, BitmaskBiasDropout and BitmaskDropoutGrad
Naive Dropout Kernel
Dropout的操作就是生成一個(0, 1)之間的隨機數(shù),當大于dropout_rate的時候,則設置mask=1,否則則設置為mask=0,這個mask值我們也需要保存下來用于后向,一段簡化版本的樸素代碼:
template<typename?T>
__global__?naive_dropout(const?T*?x,?T*?y,?int8_t*?mask,?float?rate,?const?int64_t?elem_cnt){
??//?curand_init...
??CUDA_1D_KERNEL_LOOP(i,?elem_cnt){
????float?random_val?=?curand_uniform(&state);?
????bool?mask_val?=?random_val?>?rate;?
????y[i]?=?x[i]?*?static_cast(mask_val);
????mask[i]?=?mask_val;?
??}
}
其中隨機數(shù)生成用的是NV的cuRand隨機數(shù)生成庫,而閱讀官網(wǎng)文檔后,在Philox算法下,可以一次性生成4個隨機數(shù),從算子的邏輯來看,這是一個memory-bound的算子,這樣我們就可以應用向量化手段來提高讀寫帶寬,大部分框架內部都做了向量化的優(yōu)化,這里我們用curand_uniform4來一次性生成4個隨機數(shù):
????rand_uniform_pack4.storage?=?curand_uniform4(&state);
????const?LoadType*?x_load?=?reinterpret_cast<const?LoadType*>(x?+?linear_index);
????LoadPack?x_vec;
????x_vec.storage?=?*x_load;
????MaskPack?mask_vec;
????LoadPack?y_vec;
#pragma?unroll
????for?(int?i?=?0;?i???????mask_vec.elem[i]?=?rand_uniform_pack4.elem[i]?>?rate;
??????T?tmp_float_mask?=?static_cast<float>(mask_vec.elem[i]);
??????y_vec.elem[i]?=?x_vec.elem[i]?*?tmp_float_mask?*?t_scale;
????}
Bitmask
在正式介紹OnnxRuntime優(yōu)化的算子前,我們先簡單引入bitmask的概念。顧名思義,bitmask就是用比特位來表示mask,每一個bit可以取值為0和1,那么在dropout里,我們就可以用一個bit的狀態(tài)來表示該元素是否被dropout掉。
相比我們用int8_t類型來保存mask,這無疑能節(jié)省很多顯存。(原來一個int8只能保存1個mask,但如果用bitmask那么一個int8就可以保存8個mask)
使用Bitmask優(yōu)化的Dropout
這里我們選取該PR的dropout_impl.cu文件作為示例:
template?<typename?T,?bool?UseBitmask>
__global__?void?DropoutKernel(const?CUDA_LONG?N,?const?CUDA_LONG?mask_element_count,?const?int?step_size,
??????????????????????????????const?int?steps_per_thread,?const?fast_divmod?fdm_bits_per_element,?const?float?ratio,
??????????????????????????????const?std::pair<uint64_t,?uint64_t>?seeds,?const?T*?X_data,?T*?Y_data,?void*?mask_data)?{
??CUDA_LONG?idx?=?blockDim.x?*?blockIdx.x?+?threadIdx.x;
??const?float?p?=?1.0f?-?ratio;
??const?float?scale?=?1.0f?/?p;
??curandStatePhilox4_32_10_t?state;
??curand_init(seeds.first,?idx,?seeds.second,?&state);
??//???The?Philox_4x32_10?algorithm?is?closely?tied?to?the?thread?and?block?count.
??//???Each?thread?computes?4?random?numbers?in?the?same?time?thus?the?most?efficient
??//???use?of?Philox_4x32_10?is?to?generate?a?multiple?of?4?times?number?of?threads.
??for?(int?i?=?0;?i?????CUDA_LONG?id?=?idx?*?kNumUnroll?+?i?*?step_size;
????rand?=?curand_uniform4(&state);
????BitmaskElementType?thread_bitmask?=?0;
//?actual?computation
#pragma?unroll
????for?(int?i?=?0;?i???????CUDA_LONG?li?=?id?+?i;
??????if?(li?????????bool?mask?=?(&rand.x)[i]?????????Y_data[li]?=?static_cast(static_cast<float>(X_data[li])?*?mask?*?scale);
????????if?(UseBitmask)?{
??????????thread_bitmask?|=?(mask?<????????}?else?{
??????????reinterpret_cast<bool*>(mask_data)[li]?=?mask;
????????}
??????}
????}
????if?(UseBitmask)?{
??????SetBitmask(id,?mask_element_count,?fdm_bits_per_element,?thread_bitmask,
?????????????????????????????reinterpret_cast(mask_data));
????}
????__syncthreads();
??}
}
這個kernel其實也是做了向量化的優(yōu)化,其中kNumUnroll=4,我們著重看向量化循環(huán)展開的這部分邏輯:
??uint32_t?thread_bitmask;?
??for?(int?i?=?0;?i???????CUDA_LONG?li?=?id?+?i;
??????if?(li?????????bool?mask?=?(&rand.x)[i]?????????Y_data[li]?=?static_cast(static_cast<float>(X_data[li])?*?mask?*?scale);
????????if?(UseBitmask)?{
??????????thread_bitmask?|=?(mask?<????????}?...
??????}
????}
當使用bitmask的時候,將mask值進行左移,并通過邏輯或的操作,賦進thread_bitmask里的其中一個bit,這樣循環(huán)結束后,每個線程的thread_bitmask就存儲了其處理的4個元素的mask值。
假設我們的處理的4個元素的mask值分別是1 0 1 1,那么示意圖如下:

每個線程計算好mask后,下一步就是怎么把各個mask存儲進變量中,對應的是bitmask.cuh中的SetBitmask函數(shù)
template?<int?NumUnroll>
__device__?__forceinline__?void?SetBitmask(const?CUDA_LONG?id,?const?CUDA_LONG?mask_element_count,
???????????????????????????????????????????const?fast_divmod?fdm_bits_per_element,?BitmaskElementType?thread_bitmask,
???????????????????????????????????????????BitmaskElementType*?mask_data)?{
??int?bitmask_idx,?bitmask_shift;
??fdm_bits_per_element.divmod(id,?bitmask_idx,?bitmask_shift);
??BitmaskElementType?bitmask?=?(thread_bitmask?<??
#if?defined(USE_CUDA)?&&?__CUDA_ARCH__?>=?800
??BitmaskElementType?thread_mask?=?__match_any_sync(0xFFFFFFFF,?bitmask_idx);
??bitmask?=?__reduce_or_sync(thread_mask,?bitmask);
#else
??#pragma?unroll
??for?(int?stride?=?kNumBitsPerBitmaskElement?/?(NumUnroll?*?2);?stride?>?0;?stride?/=?2)?{
????bitmask?|=?WARP_SHFL_DOWN(bitmask,?stride);
??}
??//?Choose?a?single?from?the?"thread?mask"?group?to?perform?the?output?write.
??if?(bitmask_shift?==?0?&&?bitmask_idx?????mask_data[bitmask_idx]?=?bitmask;
??}
首先fdm_bits_per_element是一個快速除法的操作,除數(shù)設置為32(因為這里用uint32_t存儲32個bit),他的操作等價于:
bitmask_idx?=?id?/?32;?表示該線程的bitmask應該寫到第幾個mask_data中
bitmask_shift?=?id?%?32;?表示該線程的bitmask應該偏移到?1個mask中的哪個bit位
而前面我們每個線程處理4個元素,那么對應的id是:
id:?0?4?8?12?...?28
bitmask_idx:?0?0?0?0?0?
bitmask_shift:?0?4?8?12
由于每個線程的thread_bitmask都只有前4位有效,而我們要想把多個線程的thread_bitmask放到一個uint32_t變量中,就需要對其做偏移。1個uint32_t可以存儲8個線程的thread_bitmask,一個示意圖如下:

最后就是將所有線程給結合起來,筆者對__match_any_snyc不太熟悉,我們看warp_shfl_down版本的操作,它將stride設置為kNumBitsPerBitmaskElement / (NumUnroll * 2),這里kNumBitsPerBitmaskElemen=32,NumUnroll=4,那是對每8個線程放一起做warp級別的reduce和邏輯或操作,一個線程reduce示意圖如下:

我們取第一次reduce中,0號線程和4號線程的操作具體分析:

這樣就將所有線程的bitmask結合到一起,最后選擇第一個線程負責寫入到mask_data中
筆者認為這里可能存在部分線程不活躍的情況,warp_shfl_down不應該所有線程參與操作,而是應該用__activemask()
性能數(shù)據(jù)
OnnxRuntime的PR也有對應的Profile數(shù)據(jù):

選取了Bert模型,對于峰值顯存有10%的減少,而帶寬也有10%的提升(一方面是用了bitmask寫入數(shù)據(jù)變少了,另一方面說一般用了向量化優(yōu)化基本都可以打滿帶寬)

