<kbd id="afajh"><form id="afajh"></form></kbd>
<strong id="afajh"><dl id="afajh"></dl></strong>
    <del id="afajh"><form id="afajh"></form></del>
        1. <th id="afajh"><progress id="afajh"></progress></th>
          <b id="afajh"><abbr id="afajh"></abbr></b>
          <th id="afajh"><progress id="afajh"></progress></th>

          JAX介紹和快速入門示例

          共 6647字,需瀏覽 14分鐘

           ·

          2022-06-20 02:56


          來源:DeepHub IMBA

          本文約3300字,建議閱讀10+分鐘

          本文中,我們了解了 JAX 是什么,并了解了它的一些基本概念。


          JAX 是一個由 Google 開發(fā)的用于優(yōu)化科學計算Python 庫:

          • 它可以被視為 GPU 和 TPU 上運行的NumPy , jax.numpy提供了與numpy非常相似API接口。
          • 它與 NumPy API 非常相似,幾乎任何可以用 numpy 完成的事情都可以用 jax.numpy 完成。
          • 由于使用XLA(一種加速線性代數(shù)計算的編譯器)將Python和JAX代碼JIT編譯成優(yōu)化的內(nèi)核,可以在不同設(shè)備(例如gpu和tpu)上運行。而優(yōu)化的內(nèi)核是為高吞吐量設(shè)備(例如gpu和tpu)進行編譯,它與主程序分離但可以被主程序調(diào)用。JIT編譯可以用jax.jit()觸發(fā)。
          • 它對自動微分有很好的支持,對機器學習研究很有用。可以使用 jax.grad() 觸發(fā)自動區(qū)分。
          • JAX 鼓勵函數(shù)式編程,因為它是面向函數(shù)的。與 NumPy 數(shù)組不同,JAX 數(shù)組始終是不可變的。
          • JAX提供了一些在編寫數(shù)字處理時非常有用的程序轉(zhuǎn)換,例如JIT . JAX()用于JIT編譯和加速代碼,JIT .grad()用于求導,以及JIT .vmap()用于自動向量化或批處理。
          • JAX 可以進行異步調(diào)度。所以需要調(diào)用 .block_until_ready() 以確保計算已經(jīng)實際發(fā)生。

          JAX 使用 JIT 編譯有兩種方式:

          • 自動:在執(zhí)行 JAX 函數(shù)的庫調(diào)用時,默認情況下 JIT 編譯會在后臺進行。
          • 手動:您可以使用 jax.jit() 手動請求對自己的 Python 函數(shù)進行 JIT 編譯。


          JAX 使用示例


          我們可以使用 pip 安裝庫。

          pip install jax


          導入需要的包,這里我們也繼續(xù)使用 NumPy ,這樣可以執(zhí)行一些基準測試。

          import jaximport jax.numpy as jnpfrom jax import randomfrom jax import grad, jitimport numpy as np
          key = random.PRNGKey(0)

          與 import numpy as np 類似,我們可以 import jax.numpy as jnp 并將代碼中的所有 np 替換為 jnp 。如果 NumPy 代碼是用函數(shù)式編程風格編寫的,那么新的 JAX 代碼就可以直接使用。但是,如果有可用的GPU,JAX則可以直接使用。

          JAX 中隨機數(shù)的生成方式與 NumPy 不同。JAX需要創(chuàng)建一個 jax.random.PRNGKey 。我們稍后會看到如何使用它。

          我們在 Google Colab 上做一個簡單的基準測試,這樣我們就可以輕松訪問 GPU 和 TPU。我們首先初始化一個包含 25M 元素的隨機矩陣,然后將其乘以它的轉(zhuǎn)置。使用針對 CPU 優(yōu)化的 NumPy,矩陣乘法平均需要 1.61 秒。

          # runs on CPU - numpysize = 5000x = np.random.normal(size=(size, size)).astype(np.float32)%timeit np.dot(x, x.T)# 1 loop, best of 5: 1.61 s per loop

          在 CPU 上使用 JAX 執(zhí)行相同的操作平均需要大約 3.49 秒。

          # runs on CPU - JAXsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%timeit jnp.dot(x, x.T).block_until_ready()# 1 loop, best of 5: 3.49 s per loop

          在 CPU 上運行時,JAX 通常比 NumPy 慢,因為 NumPy 已針對CPU進行了非常多的優(yōu)化。但是,當使用加速器時這種情況會發(fā)生變化,所以讓我們嘗試使用 GPU 進行矩陣乘法。

          # runs on GPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 102 μs, sys: 42 μs, total: 144 μs#   Wall time: 155 μs# 2. CPU times: user 1.3 s, sys: 195 ms, total: 1.5 s#   Wall time: 2.16 s# 3. 10 loops, best of 5: 68.9 ms per loop

          從示例中可以看出,要進行公平的基準比較,我們需要使用 JAX 測量不同的步驟:

          設(shè)備傳輸時間:將矩陣傳輸?shù)?GPU 所經(jīng)過的時間。耗時 0.155 毫秒。編譯時間:JIT 編譯經(jīng)過的時間。耗時 2.16 秒。運行時間:有效的代碼運行時間。耗時 68.9 毫秒。

          在 GPU 上使用 JAX 進行單個矩陣乘法的總耗時約為 2.23 秒,高于 NumPy 的總時間 1.61 秒。但是對于每個額外的矩陣乘法,JAX 只需要 68.9 毫秒,而 NumPy 需要 1.61 秒,快了 22 倍多!因此,如果多次執(zhí)行線性代數(shù)運算,那么使用 JAX 是有意義的。

          讓我們測試使用 TPU 進行矩陣乘法。

          # runs on TPUsize = 5000x = random.normal(key, (size, size), dtype=jnp.float32)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time jnp.dot(x_jax, x_jax.T).block_until_ready() # 2. measure JAX compilation time%timeit jnp.dot(x_jax, x_jax.T).block_until_ready() # 3. measure JAX running time# 1. CPU times: user 131 μs, sys: 72 μs, total: 203 μs#   Wall time: 164 μs# 2. CPU times: user 190 ms, sys: 302 ms, total: 492 ms#   Wall time: 837 ms# 3. 100 loops, best of 5: 16.5 ms per loop

          忽略設(shè)備傳輸時間和編譯時間,每個矩陣乘法平均需要 16.5 毫秒:GPU 相比快了4倍,與 CPU 的 NumPy相比快了88倍。需要說明的是,當乘以不同大小的矩陣時,獲得相同的加速效果也不同:相乘的矩陣越大,GPU可以優(yōu)化操作的越多,加速也越大。

          為了在 Google Colab 上復制上述基準,需要運行以下代碼讓 JAX 知道有可用的 TPU。

          import jax.tools.colab_tpujax.tools.colab_tpu.setup_tpu()

          讓我們看看 XLA 編譯器。

          XLA


          XLA 是 JAX(和其他庫,例如 TensorFlow,TPU的Pytorch)使用的線性代數(shù)的編譯器,它通過創(chuàng)建自定義優(yōu)化內(nèi)核來保證最快的在程序中運行線性代數(shù)運算。XLA 最大的好處是可以讓我們在應(yīng)用中自定義內(nèi)核,該部分使用線性代數(shù)運算,以便它可以進行最多的優(yōu)化。

          XLA 最重要的優(yōu)化是融合,即可以在同一個內(nèi)核中進行多個線性代數(shù)運算,將中間輸出保存到 GPU 寄存器中,而不將它們具體化到內(nèi)存中。這可以顯著增加我們的“計算強度”,即所做的工作量與負載和存儲數(shù)量的比例。融合還可以讓我們完全省略僅在內(nèi)存中shuffle 的操作(例如reshape)。

          下面我們看看如何使用 XLA 和 jax.jit 手動觸發(fā) JIT 編譯。

          使用 jax.jit 進行即時編譯

          這里有一些新的基準來測試 jax.jit 的性能。我們定義了兩個實現(xiàn) SELU(Scaled Exponential Linear Unit)的函數(shù):一個使用 NumPy,一個使用 JAX。暫時先不考慮 jax.jitat

          def selu_np(x, alpha=1.67, lmbda=1.05):return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
          def selu_jax(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

          然后,我們使用 NumPy 在 1M 個元素的向量上運行它。

          # runs on the CPU - numpyx = np.random.normal(size=(1000000,)).astype(np.float32)%timeit selu_np(x)# 100 loops, best of 5: 7.6 ms per loop

          平均需要 7.6 毫秒。現(xiàn)在讓我們在 CPU 上使用 JAX。

          # runs on the CPU - JAXx = random.normal(key, (1000000,))%time selu_jax(x).block_until_ready() # 1. measure JAX compilation time%timeit selu_jax(x).block_until_ready() # 2. measure JAX runtime# 1. CPU times: user 124 ms, sys: 5.01 ms, total: 129 ms#   Wall time: 124 ms# 2. 100 loops, best of 5: 4.8 ms per loop

          現(xiàn)在平均需要 4.8 毫秒,在這種情況下比 NumPy 快。下一個測試是在 GPU 上使用 JAX。

          # runs on the GPUx = random.normal(key, (1000000,))%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 103 μs, sys: 0 ns, total: 103 μs#   Wall time: 109 μs# 2. CPU times: user 148 ms, sys: 9.09 ms, total: 157 ms#   Wall time: 447 ms# 3. 1000 loops, best of 5: 1.21 ms per loop

          函數(shù)運行時間為1.21毫秒。下面我們用 jax.jit 測試它,觸發(fā) JIT 編譯器使用 XLA 將 SELU 函數(shù)編譯到優(yōu)化的 GPU 內(nèi)核中,同時優(yōu)化函數(shù)內(nèi)部的所有操作。

          # runs on the GPUx = random.normal(key, (1000000,))selu_jax_jit = jit(selu_jax)%time x_jax = jax.device_put(x) # 1. measure JAX device transfer time%time selu_jax_jit(x_jax).block_until_ready() # 2. measure JAX compilation time%timeit selu_jax_jit(x_jax).block_until_ready() # 3. measure JAX runtime# 1. CPU times: user 70 μs, sys: 28 μs, total: 98 μs#   Wall time: 104 μs# 2. CPU times: user 66.6 ms, sys: 1.18 ms, total: 67.8 ms#   Wall time: 122 ms# 3. 10000 loops, best of 5: 130 μs per loop

          使用編譯內(nèi)核,函數(shù)運行時間為0.13毫秒!

          讓我們回顧一下不同的運行時間:

          • CPU 上的 NumPy:7.6 毫秒。
          • CPU 上的 JAX:4.8 毫秒(x1.58 加速)。
          • 沒有 JIT 的 GPU 上的 JAX:1.21 毫秒(x6.28 加速)。
          • 帶有 JIT 的 GPU 上的 JAX:0.13 毫秒(x58.46 加速)。

          使用 JIT 編譯避免從 GPU 寄存器中移動數(shù)據(jù)這樣給我們帶來了非常大的加速。一般來說在不同類型的內(nèi)存之間移動數(shù)據(jù)與代碼執(zhí)行相比非常慢,因此在實際使用時應(yīng)該盡量避免!

          將 SELU 函數(shù)應(yīng)用于不同大小的向量時,您可能會獲得不同的結(jié)果。矢量越大,加速器越能優(yōu)化操作,加速也越大。

          除了執(zhí)行 selu_jax_jit = jit(selu_jax) 之外,還可以使用 @jit 裝飾器對函數(shù)進行 JIT 編譯,如下所示。

          @jitdef selu_jax_jit(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

          JIT 編譯可以加速,為什么我們不能全部都這樣做呢?因為并非所有代碼都可以 JIT 編譯,JIT要求數(shù)組形狀是靜態(tài)的并且在編譯時已知。另外就是引入jax.jit 也會帶來一些開銷。因此通常只有編譯的函數(shù)比較復雜并且需要多次運行才能節(jié)省時間。但是這在機器學習中很常見,例如我們傾編譯一個大而復雜的模型,然后運行它進行數(shù)百萬次訓練、損失函數(shù)和指標的計算。

          使用 jax.grad 自動微分


          另一個 JAX 轉(zhuǎn)換是使用 jit.grad() 函數(shù)的自動微分。

          借助 Autograd ,JAX 可以自動對原生 Python 和 NumPy 代碼進行微分。并且支持 Python 的大部分特性,包括循環(huán)、if、遞歸和閉包。

          下面看看一個帶有 jit.grad() 的代碼示例,我們計算一個自定義的包含 JAX 函數(shù)的Python 函數(shù)的導數(shù)。

          def sum_logistic(x):return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
          x_small = jnp.arange(3.)derivative_fn = grad(sum_logistic)print(derivative_fn(x_small))# [0.25, 0.19661197, 0.10499357]


          總結(jié)


          在本文中,我們了解了 JAX 是什么,并了解了它的一些基本概念:NumPy 接口、JIT 編譯、XLA、優(yōu)化內(nèi)核、程序轉(zhuǎn)換、自動微分和函數(shù)式編程。在 JAX 之上,開源社區(qū)為機器學習構(gòu)建了更多高級庫,例如 Flax 和 Haiku。有興趣的可以搜索查看。

          編輯:黃繼彥





          瀏覽 22
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          評論
          圖片
          表情
          推薦
          點贊
          評論
          收藏
          分享

          手機掃一掃分享

          分享
          舉報
          <kbd id="afajh"><form id="afajh"></form></kbd>
          <strong id="afajh"><dl id="afajh"></dl></strong>
            <del id="afajh"><form id="afajh"></form></del>
                1. <th id="afajh"><progress id="afajh"></progress></th>
                  <b id="afajh"><abbr id="afajh"></abbr></b>
                  <th id="afajh"><progress id="afajh"></progress></th>
                  PANS私拍在线一区二区 | 成人日皮精品视频 | 亚州人妻偷拍成人理伦 | 久久久午夜福利 | 日韩欧美人妻无码精品 |