JAX介紹和快速入門示例

來源:DeepHub IMBA 本文約3300字,建議閱讀10+分鐘
本文中,我們了解了 JAX 是什么,并了解了它的一些基本概念。
它可以被視為 GPU 和 TPU 上運行的NumPy , jax.numpy提供了與numpy非常相似API接口。 它與 NumPy API 非常相似,幾乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。 由于使用XLA(一種加速線性代數(shù)計算的編譯器)將Python和JAX代碼JIT編譯成優(yōu)化的內(nèi)核,可以在不同設(shè)備(例如gpu和tpu)上運行。而優(yōu)化的內(nèi)核是為高吞吐量設(shè)備(例如gpu和tpu)進行編譯,它與主程序分離但可以被主程序調(diào)用。JIT編譯可以用jax.jit()觸發(fā)。 它對自動微分有很好的支持,對機器學習研究很有用。可以使用 jax.grad() 觸發(fā)自動區(qū)分。 JAX 鼓勵函數(shù)式編程,因為它是面向函數(shù)的。與 NumPy 數(shù)組不同,JAX 數(shù)組始終是不可變的。 JAX提供了一些在編寫數(shù)字處理時非常有用的程序轉(zhuǎn)換,例如JIT . JAX()用于JIT編譯和加速代碼,JIT .grad()用于求導,以及JIT .vmap()用于自動向量化或批處理。 JAX 可以進行異步調(diào)度。所以需要調(diào)用 .block_until_ready() 以確保計算已經(jīng)實際發(fā)生。
自動:在執(zhí)行 JAX 函數(shù)的庫調(diào)用時,默認情況下 JIT 編譯會在后臺進行。 手動:您可以使用 jax.jit() 手動請求對自己的 Python 函數(shù)進行 JIT 編譯。

JAX 使用示例
pip install jax
import jaximport jax.numpy as jnpfrom jax import randomfrom jax import grad, jitimport numpy as npkey = random.PRNGKey(0)
# runs on CPU - numpysize = 5000x = np.random.normal(size=(size, size)).astype(np.float32)%timeit np.dot(x, x.T)# 1 loop, best of 5: 1.61 s per loop
# runs on CPU - JAXsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%timeit jnp.dot(x, x.T).block_until_ready()# 1 loop, best of 5: 3.49 s per loop
# runs on GPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 102 μs, sys: 42 μs, total: 144 μs# Wall time: 155 μs# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s# Wall time: 2.16 s# 3. 10 loops, best of 5: 68.9 ms per loop
設(shè)備傳輸時間:將矩陣傳輸?shù)?GPU 所經(jīng)過的時間。耗時 0.155 毫秒。編譯時間:JIT 編譯經(jīng)過的時間。耗時 2.16 秒。運行時間:有效的代碼運行時間。耗時 68.9 毫秒。
# runs on TPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 131 μs, sys: 72 μs, total: 203 μs# Wall time: 164 μs# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms# Wall time: 837 ms# 3. 100 loops, best of 5: 16.5 ms per loop
import jax.tools.colab_tpujax.tools.colab_tpu.setup_tpu()
XLA
def selu_np(x, alpha=1.67, lmbda=1.05):return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)def selu_jax(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
# runs on the CPU - numpyx = np.random.normal(size=(1000000,)).astype(np.float32)%timeit selu_np(x)# 100 loops, best of 5: 7.6 ms per loop
# runs on the CPU - JAXx = random.normal(key, (1000000,))%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms# Wall time: 124 ms# 2. 100 loops, best of 5: 4.8 ms per loop
# runs on the GPUx = random.normal(key, (1000000,))%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 103 μs, sys: 0 ns, total: 103 μs# Wall time: 109 μs# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms# Wall time: 447 ms# 3. 1000 loops, best of 5: 1.21 ms per loop
# runs on the GPUx = random.normal(key, (1000000,))selu_jax_jit = jit(selu_jax)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 70 μs, sys: 28 μs, total: 98 μs# Wall time: 104 μs# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms# Wall time: 122 ms# 3. 10000 loops, best of 5: 130 μs per loop
CPU 上的 NumPy:7.6 毫秒。 CPU 上的 JAX:4.8 毫秒(x1.58 加速)。 沒有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。 帶有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。
@jitdef selu_jax_jit(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
使用 jax.grad 自動微分
def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))x_small = jnp.arange(3.)derivative_fn = grad(sum_logistic)print(derivative_fn(x_small))# [0.25, 0.19661197, 0.10499357]
總結(jié)
編輯:黃繼彥
評論
圖片
表情
