玩轉(zhuǎn)張量必備利器之 einops
視覺(jué) Transformer 中免不了要對(duì)小批次圖像張量以及特征張量一頓操作,比如圖像分塊、多頭自注意力機(jī)制等。而這些操作如果借用 einops?以及 einsum 的話往往可以事半功倍。

官方給出的全稱(chēng)是 Einstein-Inspired Notation for operations,也可以看成 Einstein operations。它支持廣泛使用的張量包(如?numpy、pytorch、chainer、gluon、tensorflow)。
靈感來(lái)自愛(ài)因斯坦求和約定(einsum),而這個(gè)強(qiáng)大工具可以參考下面這篇。
兩大利器?einsum?和?einops?在手,可謂無(wú)往不利。
掌握了它們,不僅方便了張量操作,甚至還可以跨包寫(xiě)通用代碼。
.安裝 .
!pip?install?einops
1基本套路
在學(xué)習(xí) NumPy 的多維數(shù)組時(shí),我們知道軸這個(gè)概念。對(duì)于同一個(gè)數(shù)組來(lái)說(shuō),它的幾個(gè)軸的順序是可以不同的。
比如,深度學(xué)習(xí)中經(jīng)常碰到表示一批次圖像數(shù)據(jù)的張量,往往有不同的表示形式。

再比如,一張圖像往往可以表示為一個(gè)三維數(shù)組,如果將 1 軸和 2 軸交換順序也是可以的,只是此時(shí)同一個(gè)元素的索引(下標(biāo))變了。

上述變換在 NumPy 中可以通過(guò)下面代碼來(lái)實(shí)現(xiàn),
y?=?x.transpose(0,?2,?1)
但如果使用 einops,將會(huì)是下面這個(gè)樣子。
y?=?rearrange(x,?'c?h?w?->?c?w?h')
形式上明顯是借鑒了 einsum,但咋一看,沒(méi)比 transpose 方便啊。
這個(gè)例子是沒(méi)體現(xiàn)出它的優(yōu)勢(shì),但實(shí)際上它還有更多功能,比如多個(gè)軸組合、軸分解,分解再組合,約簡(jiǎn)除軸或者增加新軸等。
2操作單張圖片
import?numpy?as?np
from?PIL?import?Image
import?matplotlib.pyplot?as?plt
plt.rcParams['font.sans-serif']?=?[u'SimHei']
plt.rcParams['axes.unicode_minus']?=?False
from?utils?import?display_np_arrays_as_images
display_np_arrays_as_images()
我們加載一幅小扎元宇宙的圖片來(lái)測(cè)試。
img?=?np.array(Image.open('./resources/meta_verse_256.jpg'))/255
img.shape,?img.dtype
((256, 460, 3), dtype('float64'))
img

看這個(gè)圖的大小,(256, 460, 3)。這三個(gè)數(shù)字對(duì)應(yīng)高度、寬度、通道數(shù),即 h w c。
from?einops?import?rearrange,?reduce,?repeat
.元素重排.
einops.rearrange 是一種對(duì)多維數(shù)組(張量)進(jìn)行元素重排的十分強(qiáng)大的操作。
該操作包括轉(zhuǎn)置(軸置換)、reshape、擠壓(squeeze)、解壓(unsqueeze)、堆疊(stack)、拼接(concatenate)等操作。
轉(zhuǎn)置
這個(gè)操作很方便,高和寬兩個(gè)軸交換順序即可,即 h w c -> w h c。
out?=?rearrange(img,?'h?w?c?->?w?h?c')
out.shape,?out.dtype
((460, 256, 3), dtype('float64'))
out

out?=?rearrange(img,?'h?w?(c?cs)?->?(h?cs)?w?c',?cs=3)
out.shape
(768, 460, 1)
out[...,0]

上面不是將三個(gè)通道沿著高度方向拼接,而是將原高度和通道兩個(gè)軸合并了。是三個(gè)通道穿插起來(lái)了,表現(xiàn)出來(lái)是整個(gè)圖高度方向被拉長(zhǎng)了。
那么怎么做到將三個(gè)通道沿著高度方向拼接呢?
先將三個(gè)軸重新排序?yàn)? c h w格式,一般深度學(xué)習(xí)的包中會(huì)以這種格式處理數(shù)據(jù)。
out?=?rearrange(img,?'h?w?c?->?c?h?w')
out.shape
(3, 256, 460)
然后將上面結(jié)果的三個(gè)通道沿 h方向拼接,這個(gè)目標(biāo)只需要合并c和h兩個(gè)軸來(lái)實(shí)現(xiàn)。
rearrange(out,?'c?h?w?->?(c?h)?w')

上述代碼中,c 軸在 h 軸前面,因此 c 是 0 軸,h 是 1 軸,因此數(shù)據(jù)總體是按三個(gè)通道分開(kāi)排列。
也可以將上述操作合并成一步
rearrange(img,?'h?w?c?->?(c?h)?w')

將圖片展平為一維數(shù)組, 353280 = 256 x?460 x?3。
rearrange(img,?'h?w?c?->?(c?h?w)').shape
(353280,)
3操作一批圖片
上面僅僅是對(duì)一張圖片進(jìn)行操作,但像在深度學(xué)習(xí)中往往是對(duì)一批圖片下手。
我們不妨也來(lái)一試,但不想另外再找圖,咱們就地取材,從上面這張圖中取小圖來(lái)操作一番。
將每個(gè)圖像分成 8 個(gè)更小的圖像塊
bhwc?=?rearrange(img,?'(h1?h)?(w1?w)?c?->?(h1?w1)?h?w?c',?h1=2,?w1=4)
bhwc.shape
(8, 128, 115, 3)
def?subfig(bhwc,?hs,?ws):
????fig,?ax?=?plt.subplots(hs,?ws,?figsize=(12,?6))
????for?i,?axi?in?enumerate(ax.flat):
????????axi.imshow(bhwc[i])
????????axi.set(xticks=[],?yticks=[],?xlabel='第?'+str(i+1)+'?個(gè)子圖')
subfig(bhwc,?hs=2,?ws=4)

8 個(gè)小圖塊橫排
rearrange(bhwc,?'b?h?w?c?->?h?(b?w)?c')

8 個(gè)小圖塊縱排
rearrange(bhwc,?'b?h?w?c?->?(b?h)?w?c')


bhwc.shape
(8, 128, 115, 3)
空域到通道轉(zhuǎn)換
res?=?rearrange(img,?'(h1?h)?(w1?w)?c?->?(h1?w1)?h?w?c',?h1=2,?w1=2)
space2deep?=?rearrange(res,?'b?(h1?h)?(w1?w)?c?->?b?h?w?(c?h1?w1)',?h1=2,?w1=2)
space2deep.shape
(4, 64, 115, 12)
上述代碼將 4 個(gè)子圖取出來(lái)沿通道軸拼接得到 12 個(gè)通道。
將每個(gè)圖展平
flt?=?rearrange(bhwc,?'b?h?w?c?->?b?(h?w?c)')
flt.shape
(8, 44160)
4Reduce 和 Repeat
einops 也可以作約簡(jiǎn)(reduce)運(yùn)算,即沿著某個(gè)軸聚合,減少軸;當(dāng)然也能反過(guò)來(lái),沿著新的軸 repeat,增加軸。
對(duì) c軸求均值
out1?=?reduce(bhwc,?'b?h?w?c?->?b?h?w',?reduction='mean')
out1.shape
(8, 128, 115)
上述代碼將每個(gè)子圖的三個(gè)通道作了均值約簡(jiǎn),變成單通道灰度圖。
下面這樣子可以保留被約簡(jiǎn)的軸。
mean1?=?reduce(bhwc,?'b?h?w?c?->?b?h?w?1',?'mean')
mean1.shape
(8, 128, 115, 1)
下面代碼將其重新組裝成完整的三通道灰度圖。
out2?=?repeat(out1,?'(bh?bw)?h?w?->?(bh?h)?(bw?w)?c',?c=3,?bh=2,?bw=4)
out2.shape
(256, 460, 3)
out2

reduce 和 repeat 聯(lián)合使用實(shí)現(xiàn)馬賽克效果。
res?=?reduce(img,?'(h?hs)?(w?ws)?c?->?h?w?c',?reduction='mean',?hs=8,?ws=5)
repeat(res,?'h?w?c->?(h?hs)?(w?ws)?c',?c=3,?hs=8,?ws=5)

.求差 .
min1?=?reduce(bhwc,?'b?h?w?c?->?b?h?w?()',?'min')
min1.shape
(8, 128, 115, 1)
dif?=?bhwc?-?min1
subfig(dif,?hs=2,?ws=4)

5加新軸
out3?=?rearrange(bhwc,?'b?h?w?c?->?1?b?h?w?1?c')
out3.shape
(1, 8, 128, 115, 1, 3)
6圖像 Patch
將一張圖劃分為若干個(gè)圖像子塊。
images?=?rearrange(bhwc,?'b?(h?ph)?(w?pw)?c?->?(b?h?w)?ph?pw?c',?ph?=?32,?pw?=?23)?
images.shape
(160, 32, 23, 3)
subfig(images,?hs=4,?ws=5)

是不是特別適合于拿來(lái)實(shí)現(xiàn)圖像的 Transformer 呢!
7小結(jié)
rearrange不改變?cè)氐臄?shù)量,它的功能涵蓋了不同的 numpy 函數(shù)(如transpose、reshape、stack、concatenate、squeeze和expand_dims)reduce實(shí)現(xiàn)了約簡(jiǎn)操作(如mean、min、max、sum、prod等)repeat實(shí)現(xiàn)重復(fù)(repeating)和平鋪(tiling)。軸的組合和分解是基石,它們可以并且應(yīng)該一起使用來(lái)操作張量。
相關(guān)閱讀
