一個函數(shù)打天下,einsum

極市導讀
?在實現(xiàn)算法需要轉為代碼實現(xiàn)的時候遇到復雜的函數(shù),不僅困難還容易出錯,本文介紹了einsum函數(shù)幫助大家解決這個問題。文章共舉例五個函數(shù)通過einsum標記法表示,讓大家對該函數(shù)的運用更加恰當。>>加入極市CV技術交流群,走在計算機視覺的最前沿
einsum全稱Einstein summation convention(愛因斯坦求和約定),又稱為愛因斯坦標記法,是愛因斯坦1916年提出的一種標記約定,簡單的說就是省去求和式中的求和符號,例如下面的公式:

以einsum的寫法就是:

后者將?
?符號給省去了,顯得更加簡潔;再比如:
?(1)
?(2)
上面兩個栗子換成einsum的寫法就變成:
?(1)
?(2)
在實現(xiàn)一些算法時,數(shù)學表達式已經求出來了,需要將之轉換為代碼實現(xiàn),簡單的一些還好,有時碰到例如矩陣轉置、矩陣乘法、求跡、張量乘法、數(shù)組求和等等,若是以分別以transopse、sum、trace、tensordot等函數(shù)實現(xiàn)的話,不但復雜,還容易出錯
現(xiàn)在,這些問題你統(tǒng)統(tǒng)可以一個函數(shù)搞定,沒錯,就是einsum,einsum函數(shù)就是根據(jù)上面的標記法實現(xiàn)的一種函數(shù),可以根據(jù)給定的表達式進行運算,可以替代但不限于以下函數(shù):
矩陣求跡:trace
求矩陣對角線:diag
張量(沿軸)求和:sum
張量轉置:transopose
矩陣乘法:dot
張量乘法:tensordot
向量內積:inner
外積:outer
該函數(shù)在numpy、tensorflow、pytorch上都有實現(xiàn),用法基本一樣,定義如下:
einsum(equation, *operands)
equation是字符串的表達式,operands是操作數(shù),是一個元組參數(shù),并不是只能有兩個,所以只要是能夠通過einsum標記法表示的乘法求和公式,都可以用一個einsum解決,下面以numpy舉幾個栗子:
# 沿軸計算張量元素之和:
c = a.sum(axis=0)
上面的以sum函數(shù)的實現(xiàn)代碼,設?
?為三維張量,上面代碼用公式來表達的話就是:

換成einsum標記法:

然后根據(jù)此式使用einsum函數(shù)實現(xiàn)等價功能:
c = np.einsum('ijk->jk', a)
# 作用與 c = a.sum(axis=0) 一樣
更進一步的,如果?
?不止是三維,可以將下標?
?換成省略號,以表示剩下的所有維度:
c = np.einsum('i...->...', a)
這種寫法pytorch與tensorflow同樣支持,如果不是很理解的話,可以查看其對應的公式:

# 矩陣乘法
c = np.dot(a, b)
矩陣乘法的公式為:

然后是einsum對應的實現(xiàn):
c = np.einsum('ij,jk->ik', a, b)
最后再舉一個張量乘法栗子:
# 張量乘法
c = np.tensordot(a, b, ([0, 1], [0, 1]))
如果?
?是三維的,對應的公式為:

對應的einsum實現(xiàn):
c = np.einsum('ijk,ijl->kl', a, b)
下面以numpy做一下測試,對比einsum與各種函數(shù)的速度,這里使用python內建的timeit模塊進行時間測試,先測試(四維)兩張量相乘然后求所有元素之和,對應的公式為:

然后是測試代碼:
from timeit import Timer
import numpy as np
# 定義兩個全局變量
a = np.random.rand(64, 128, 128, 64)
b = np.random.rand(64, 128, 128, 64)
# 定義使用einsum與sum的函數(shù)
def einsum():
temp = np.einsum('ijkl,ijkl->', a, b)
def npsum():
temp = (a * b).sum()
# 打印運行時間
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("npsum cost:", Timer("npsum()", "from __main__ import npsum").timeit(20))
上面Timer是timeit模塊內的一個類
Timer(stmt, setup).timeit(number)
# stmt: 要測試的語句
# setup: 傳入stmt的運行環(huán)境,比如stmt中要導入的模塊等。
# 可以寫一行語句,也可以寫多行語句,寫多行語句時要用分號;隔開語句
# number: 執(zhí)行次數(shù)
將兩個函數(shù)各執(zhí)行20遍,最后的結果為,單位為秒:
einsum cost: 1.5560735
npsum cost: 8.0874927
可以看到,einsum比sum快了幾乎一個量級,接下來測試單個張量求和:

將上面的代碼改一下:
def einsum():
temp = np.einsum('ijkl->', a)
def npsum():
temp = a.sum()
相應的運行時間為:
einsum cost: 3.2716003
npsum cost: 6.7865246
還是einsum更快,所以哪怕是單個張量求和,numpy上也可以用einsum替代,同樣,求均值(mean)、方差(var)、標準差(std)也是一樣
接下來測試einsum與dot函數(shù),首先列一下矩陣乘法的公式以以及einsum表達式:


然后是測試代碼:
a = np.random.rand(2024, 2024)
b = np.random.rand(2024, 2024)
# einsum與dot比較
def einsum():
res = np.einsum('ik,kj->ij', a, b)
def dot():
res = np.dot(a, b)
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("dot cost:", Timer("dot()", "from __main__ import dot").timeit(20))
# einsum cost: 80.2403851
# dot cost: 2.0842243
這就很尷尬了,比dot慢了40倍(并且差距隨著矩陣規(guī)模的平方增加),這還怎么打天下?不過在numpy的實現(xiàn)里,einsum是可以進行優(yōu)化的,去掉不必要的中間結果,減少不必要的轉置、變形等等,可以提升很大的性能,將einsum的實現(xiàn)改一下:
def einsum():
res = np.einsum('ik,kj->ij', a, b, optimize=True)
加了一個參數(shù)optimize=True,官方文檔上該參數(shù)是可選參數(shù),接受4個值:
optimize : {False, True, ‘greedy’, ‘optimal’}, optional
optimize默認為False,如果設為True,這默認選擇‘greedy(貪心)’方式,再看看速度:
einsum cost: 2.0330937
dot cost: 1.9866218
可以看到,通過優(yōu)化,雖然還是稍慢一些,但是einsum的速度與dot達到了一個量級;不過numpy官方手冊上有個einsum_path,說是可以進一步提升速度,但是我在自己電腦上(i7-9750H)測試效果并不穩(wěn)定,這里簡單的介紹一下該函數(shù)的用法為:
path = np.einsum_path('ik,kj->ij', a, b)[0]
np.einsum('ik,kj->ij', a, b, optimize=path)
einsum_path返回一個einsum可使用的優(yōu)化路徑列表,一般使用第一個優(yōu)化路徑;另外,optimize及einsum_path函數(shù)只有numpy實現(xiàn)了,tensorflow和pytorch上至少現(xiàn)在沒有
最后,再測試einsum與另一個常用的函數(shù)tensordot,首先定義兩個四維張量的及tensordot函數(shù):
a = np.random.rand(128, 128, 64, 64)
b = np.random.rand(128, 128, 64, 64)
def tensordot():
res = np.tensordot(a, b, ([0, 1], [0, 1]))
該實現(xiàn)對應的公式為:

所以einsum函數(shù)的實現(xiàn)為:
def einsum():
res = np.einsum('ijkl,ijmn->klmn', a, b, optimize=True)
tensordot也是鏈接到BLAS實現(xiàn)的函數(shù),所以不加optimize肯定比不了,最后結果為:
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(1))
print("tensordot cost:", Timer("tensordot()", "from __main__ import tensordot").timeit(1))
# einsum cost: 4.2361331
# tensordot cost: 4.2580409
def einsum():
temp = einsum('...->', a, optimize=True)
def test():
temp = a.sum()
np.einsum('...->', a, optimize=True) # 正常運行
np.einsum('...->', a) # 報錯


np.einsum('i...->...', a) # 正常
np.einsum('...,...->...', a, b) # 正常

np.einsum('i...->i', a, optimize=True) # 必須加optimize,不然報錯
c = (a * b).sum()
# 如果不知道a, b的維數(shù),使用einsum實現(xiàn)上面的功能也必須要加optimize
c = einsum('...,...->', a, b, optimize=True)
推薦閱讀

