(附代碼)大佬是如何優(yōu)雅實(shí)現(xiàn)矩陣乘法的?
點(diǎn)擊左上方藍(lán)字關(guān)注我們

之前總覺得自己很了解高性能計(jì)算,無外乎就是“局部性+向量”隨便搞一搞。但是嘴上說說和實(shí)際實(shí)現(xiàn)自然有很大差別。看完了大佬的代碼覺得受益匪淺,在這里總結(jié)了一下,當(dāng)作自己的讀書筆記了。
最前面自然是要放項(xiàng)目鏈接,強(qiáng)烈推薦大家讀一讀源代碼:
https://github.com/pigirons/sgemm_hsw
=========================正文===============================
問題描述:給定兩個矩陣,其shape分別為(m,k)和(k, 24),求矩陣相乘的結(jié)果。
為了方便理解,這里直接把m和k弄了一個數(shù)值帶了進(jìn)去。所以我們的問題如下:輸入是棕色矩陣A和藍(lán)色矩陣B,求紅色矩陣C

我們知道一般矩陣乘法就是一堆循環(huán)的嵌套,這個也不例外。在代碼里,最外層結(jié)果是輸出矩陣的行遍歷。又因?yàn)闀邢蛄炕牟僮鳎宰罱K結(jié)果是:最外層的循環(huán)每次算4行輸出(PS:這里面的4是固定的,并不是我為了方便隨便設(shè)的)。
就是下面的情況:

現(xiàn)在我們拆開來看每輪循環(huán):我們每輪會算4行,24列的輸出。在這里,我們把輸出用12個向量寄存器表示。
現(xiàn)在可以隱約看出來為什么大佬要固定24這個數(shù)字了:因?yàn)閥mm寄存器只有16個,我們又希望行數(shù)可以比較整,那么我們每次處理4行比較合適,處理4行的話,每行可以有16/4=4個寄存器。但是我們要做向量運(yùn)算的話,那我們一定又要有向量寄存器當(dāng)作運(yùn)算符,所以我們不能把這16個寄存器都用來存output。所以權(quán)衡一下,那我們每行用3個寄存器好了,這樣總共12個寄存器存結(jié)果,剩下4個用來搞搞計(jì)算。因?yàn)閥mm是256bit的,可以存8個float類型,所以我們每列就應(yīng)該是24。

確定了計(jì)算的目標(biāo),下面我們繼續(xù)更進(jìn)一步,來看我們在每個內(nèi)存循環(huán)都要做什么。還記得我們之前剩了4個ymm寄存器么?現(xiàn)在我們把它們都利用上:先來思考下我們能不能直接在A矩陣用ymm?如果用的話,那么我們會把A矩陣一行的連續(xù)數(shù)據(jù)存到一起。這些數(shù)據(jù)會和誰運(yùn)算呢?是B的一列數(shù)據(jù),也就是圖中黑色的部分。一般來說我們假設(shè)矩陣都是列連續(xù)的。那么訪問黑色的部分,locality就會很差:我們要把這些數(shù)字一個一個讀出來,塞到一個ymm里面和A的ymm進(jìn)行運(yùn)算。

用排除法,我們別無選擇,只能把ymm用到B上面。B也是24列,我們用3個ymm就存下了。還剩一個,我們先把A的第一行第一列的數(shù)字讀出來,把它復(fù)制8份拓展成一個ymm,然后和這三個B的ymm作element-wise的乘法,把結(jié)果累加到y(tǒng)mm0~ymm2里。
現(xiàn)在發(fā)現(xiàn)這個算法的精妙了么?對的!他正好把16個ymm都用上了,一個不多一個不少。

之后我們該干嘛?其實(shí)有很多選擇,比如我們把ymm12~ymm14往下移動一行,和第一行第二列的數(shù)字做乘法,如下圖:


相比于之前我們說的循環(huán)到A的第一行第二列,大佬循環(huán)到了第二行第一列:在這種情況下我們只需要重新構(gòu)造ymm15,原來的ymm12~ymm14完全都不需要變,不需要讀新的數(shù)值,只需要改變輸出位置,從原來寫到y(tǒng)mm0~ymm2變成了ymm3~ymm5。但因?yàn)槭菍懠拇嫫鞫莾?nèi)存,所以都一樣。
說到這兒,大概也把循環(huán)捋清楚了:最內(nèi)層是按照A的列來迭代:(1)把A的第一行第一列讀出來構(gòu)造ymm15做計(jì)算,(2)把A的第二行第一列讀出來構(gòu)造ymm15做計(jì)算。。。。一直讀到A的第四行第一列(為什么是第四行?因?yàn)槲覀冚敵鍪撬男械募拇嫫鳎缓箝_始讀A的第一行第二列構(gòu)造ymm,然后讀A的第二行第二列構(gòu)造ymm。。。
總結(jié):
(1)寫并行計(jì)算,感覺就像在下國際象棋:你有很多種走法,這些走法都合法,但是最優(yōu)的只有一種。
(2)實(shí)際上寫高性能的程序就是在湊數(shù):在這個代碼里,我們根據(jù)體系結(jié)構(gòu)里ymm的寬度和ymm的寄存器個數(shù),推導(dǎo)出我們輸出矩陣每行得有24列。然后又繼續(xù)湊湊湊,得到了4步的步長的循環(huán)。雖然都是湊數(shù),但是大佬的代碼湊的很好:每一個ymm都被利用到了,這就是人家的水平
END
整理不易,點(diǎn)贊三連↓
