【深度學(xué)習(xí)】編寫同時在PyTorch和Tensorflow上工作的代碼
作者 | Ram Sagar?
編譯 | VK?
來源 | Analytics In Diamag
?“庫開發(fā)人員不再需要在框架之間進(jìn)行選擇?!?/p>?
來自德國圖賓根人工智能中心的研究人員介紹了一種新的Python框架EagerPy,EagerPy允許開發(fā)人員編寫?yīng)毩⒂赑yTorch和TensorFlow等流行框架的代碼。
在最近發(fā)表的一篇關(guān)于EagerPy的文章中,研究人員寫道,庫開發(fā)人員不再關(guān)注框架依賴性。他們的新Python框架,急切地解決了它們的重新實現(xiàn)和代碼復(fù)制障礙。
例如,F(xiàn)oolbox是一個構(gòu)建在EagerPy之上的Python庫。該庫是用EagerPy而不是NumPy重寫的,以實現(xiàn)在PyTorch和TensorFlow中開發(fā),并且只有一個代碼庫,沒有代碼重復(fù)。Foolbox是一個對機(jī)器學(xué)習(xí)模型進(jìn)行對抗性攻擊的庫。
框架無關(guān)的重要性
為了解決框架之間的差異,作者探索了句法偏差。在PyTorch的情況下,使用In-place的梯度需要使用**_grad_()「,而反向傳播是使用」backward**()調(diào)用的。
然而,TensorFlow提供了一個高級管理器和像「tape.gradient」這樣的函數(shù)來查詢梯度。即使在句法層面,這兩個框架也有很大的不同。例如,對于參數(shù),dim vs axis;對于函數(shù),sum vs reduce_sum。
這就是“EagerPy ”發(fā)揮作用的地方。它通過提供一個統(tǒng)一的API來解決PyTorch和TensorFlow之間的差異,該API透明地映射到各種底層框架,而無需計算開銷。
?“EagerPy允許你編寫自動使用PyTorch、TensorFlow、JAX和NumPy的代碼?!?/p>?
研究人員寫道,EagerPy專注于Eager執(zhí)行,此外,它的方法是透明的,用戶可以將與框架無關(guān)的EagerPy代碼與特定于框架的代碼結(jié)合起來。
TensorFlow引入的eager執(zhí)行模塊和PyTorch的相似特性使eager執(zhí)行成為主流,框架更加相似。然而,盡管PyTorch和TensorFlow2之間有這些相似之處,但編寫框架無關(guān)的代碼并不簡單。在語法層面,這些框架中用于自動微分的api是不同的。
自動微分是指用算法求解微分方程。它的工作原理是鏈?zhǔn)揭?guī)則,也就是說,求解函數(shù)的導(dǎo)數(shù)可以歸結(jié)為基本的數(shù)學(xué)運算(加、減、乘、除)。這些算術(shù)運算可以用圖形格式表示。EagerPy特別使用了一種函數(shù)式的方法來自動區(qū)分。
下面是一段來自文檔的代碼片段:
import?eagerpy?as?ep
x?=?ep.astensor(x)
def?loss_fn(x):
?#這個函數(shù)接受并返回一個eager張量
????return?x.square().sum()
print(loss_fn(x))
#?PyTorchTensor(tensor(14.))
print(ep.value_and_grad(loss_fn,?x))
首先定義第一個函數(shù),然后根據(jù)其輸入進(jìn)行微分。然后傳遞給「ep.value_and_grad」 來得到函數(shù)的值及其梯度。
此外,norm函數(shù)現(xiàn)在可以與PyTorch、TensorFlow、JAX和NumPy中的原生張量和數(shù)組一起使用,與本機(jī)代碼相比幾乎沒有任何開銷。它也適用于GPU張量。
import?torch
norm(torch.tensor([1.,?2.,?3.]))
import?tensorflow?as?tf
norm(tf.constant([1.,?2.,?3.]))
總之,EagerPy 旨在提供以下功能:
為快速執(zhí)行提供統(tǒng)一的API
維護(hù)框架的本機(jī)性能
完全可鏈接的API
全面的類型檢查支持
研究人員聲稱,這些屬性使得使用這些屬性比底層框架特定的api更容易、更安全。盡管有這些變化和改進(jìn),但EagerPy 背后的團(tuán)隊還是確保了eagerpy API遵循了NumPy、PyTorch和JAX設(shè)置的標(biāo)準(zhǔn)。
入門「EagerPy」:
使用pip從PyPI安裝最新版本:
python3?-m?pip?install?eagerpy
import?eagerpy?as?ep
def?norm(x):
????x?=?ep.astensor(x)
????result?=?x.square().sum().sqrt()
????return?result.raw
了解更多關(guān)于“eagerpy”的信息:https://eagerpy.jonasrauber.de/guide/autodiff.html
原文鏈接:https://analyticsindiamag.com/eagerpy-pytorch-tensorflow-coding/
往期精彩回顧
獲取一折本站知識星球優(yōu)惠券,復(fù)制鏈接直接打開:
https://t.zsxq.com/662nyZF
本站qq群704220115。
加入微信群請掃碼進(jìn)群(如果是博士或者準(zhǔn)備讀博士請說明):
