點(diǎn)擊上方“視學(xué)算法”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/383115932今天一翻朋友圈,發(fā)現(xiàn)好多人轉(zhuǎn)發(fā)一個(gè)業(yè)內(nèi)大佬寫的開(kāi)源項(xiàng)目。內(nèi)容很簡(jiǎn)單,就是在CPU上實(shí)現(xiàn)單精度矩陣乘法??戳艘幌拢Y(jié)果非常好:CPU的利用率很高。更可貴的是核心代碼只有很短不到200行。之前總覺(jué)得自己很了解高性能計(jì)算,無(wú)外乎就是“局部性+向量”隨便搞一搞。但是嘴上說(shuō)說(shuō)和實(shí)際實(shí)現(xiàn)自然有很大差別??赐炅舜罄械拇a覺(jué)得受益匪淺,在這里總結(jié)了一下,當(dāng)作自己的讀書(shū)筆記了。最前面自然是要放項(xiàng)目鏈接,強(qiáng)烈推薦大家讀一讀源代碼:https://github.com/pigirons/sgemm_hsw
問(wèn)題描述:給定兩個(gè)矩陣,其shape分別為(m,k)和(k, 24),求矩陣相乘的結(jié)果。為了方便理解,這里直接把m和k弄了一個(gè)數(shù)值帶了進(jìn)去。所以我們的問(wèn)題如下:輸入是棕色矩陣A和藍(lán)色矩陣B,求紅色矩陣C我們知道一般矩陣乘法就是一堆循環(huán)的嵌套,這個(gè)也不例外。在代碼里,最外層結(jié)果是輸出矩陣的行遍歷。又因?yàn)闀?huì)有向量化的操作,所以最終結(jié)果是:最外層的循環(huán)每次算4行輸出(PS:這里面的4是固定的,并不是我為了方便隨便設(shè)的)。現(xiàn)在我們拆開(kāi)來(lái)看每輪循環(huán):我們每輪會(huì)算4行,24列的輸出。在這里,我們把輸出用12個(gè)向量寄存器表示。現(xiàn)在可以隱約看出來(lái)為什么大佬要固定24這個(gè)數(shù)字了:因?yàn)閥mm寄存器只有16個(gè),我們又希望行數(shù)可以比較整,那么我們每次處理4行比較合適,處理4行的話,每行可以有16/4=4個(gè)寄存器。但是我們要做向量運(yùn)算的話,那我們一定又要有向量寄存器當(dāng)作運(yùn)算符,所以我們不能把這16個(gè)寄存器都用來(lái)存output。所以權(quán)衡一下,那我們每行用3個(gè)寄存器好了,這樣總共12個(gè)寄存器存結(jié)果,剩下4個(gè)用來(lái)搞搞計(jì)算。因?yàn)閥mm是256bit的,可以存8個(gè)float類型,所以我們每列就應(yīng)該是24確定了計(jì)算的目標(biāo),下面我們繼續(xù)更進(jìn)一步,來(lái)看我們?cè)诿總€(gè)內(nèi)存循環(huán)都要做什么。還記得我們之前剩了4個(gè)ymm寄存器么?現(xiàn)在我們把它們都利用上:先來(lái)思考下我們能不能直接在A矩陣用ymm?如果用的話,那么我們會(huì)把A矩陣一行的連續(xù)數(shù)據(jù)存到一起。這些數(shù)據(jù)會(huì)和誰(shuí)運(yùn)算呢?是B的一列數(shù)據(jù),也就是圖中黑色的部分。一般來(lái)說(shuō)我們假設(shè)矩陣都是列連續(xù)的。那么訪問(wèn)黑色的部分,locality就會(huì)很差:我們要把這些數(shù)字一個(gè)一個(gè)讀出來(lái),塞到一個(gè)ymm里面和A的ymm進(jìn)行運(yùn)算。用排除法,我們別無(wú)選擇,只能把ymm用到B上面。B也是24列,我們用3個(gè)ymm就存下了。還剩一個(gè),我們先把A的第一行第一列的數(shù)字讀出來(lái),把它復(fù)制8份拓展成一個(gè)ymm,然后和這三個(gè)B的ymm作element-wise的乘法,把結(jié)果累加到y(tǒng)mm0~ymm2里。現(xiàn)在發(fā)現(xiàn)這個(gè)算法的精妙了么?對(duì)的!他正好把16個(gè)ymm都用上了,一個(gè)不多一個(gè)不少之后我們?cè)摳陕??其?shí)有很多選擇,比如我們把ymm12~ymm14往下移動(dòng)一行,和第一行第二列的數(shù)字做乘法,如下圖:正確性上來(lái)說(shuō),上面的做法沒(méi)問(wèn)題。但我們來(lái)看看下圖里大佬是怎么做的:相比于之前我們說(shuō)的循環(huán)到A的第一行第二列,大佬循環(huán)到了第二行第一列:在這種情況下我們只需要重新構(gòu)造ymm15,原來(lái)的ymm12~ymm14完全都不需要變,不需要讀新的數(shù)值,只需要改變輸出位置,從原來(lái)寫到y(tǒng)mm0~ymm2變成了ymm3~ymm5。但因?yàn)槭菍懠拇嫫鞫莾?nèi)存,所以都一樣。說(shuō)到這兒,大概也把循環(huán)捋清楚了:最內(nèi)層是按照A的列來(lái)迭代:(1)把A的第一行第一列讀出來(lái)構(gòu)造ymm15做計(jì)算,(2)把A的第二行第一列讀出來(lái)構(gòu)造ymm15做計(jì)算。。。。一直讀到A的第四行第一列(為什么是第四行?因?yàn)槲覀冚敵鍪撬男械募拇嫫鳎?,然后開(kāi)始讀A的第一行第二列構(gòu)造ymm,然后讀A的第二行第二列構(gòu)造ymm。。。(1)寫并行計(jì)算,感覺(jué)就像在下國(guó)際象棋:你有很多種走法,這些走法都合法,但是最優(yōu)的只有一種。(2)實(shí)際上寫高性能的程序就是在湊數(shù):在這個(gè)代碼里,我們根據(jù)體系結(jié)構(gòu)里ymm的寬度和ymm的寄存器個(gè)數(shù),推導(dǎo)出我們輸出矩陣每行得有24列。然后又繼續(xù)湊湊湊,得到了4步的步長(zhǎng)的循環(huán)。雖然都是湊數(shù),但是大佬的代碼湊的很好:每一個(gè)ymm都被利用到了,這就是人家的水平。