JAX初体验

  |   0 评论   |   0 浏览

背景

JAX是一个包含可组合函数变换的数值计算库,可以用于深度学习。

JAX处于函数变换(function transformations)和科学计算的交界处,所以也有能力训练神经网络模型,但不止于训练。

JAX最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起,借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。[1]

JAX优点

  • 加速NumPy
  • XLA: Accelerated Linear Algebra
  • JIT: just in time
  • 自动求导
  • 深度学习
  • 通用可微分编程范式

初体验

环境配置

pip install --upgrade "jax[cpu]" -i https://mirrors.aliyun.com/pypi/simple

JAX是加速版的Numpy

import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit


def fn(x):
    return x + x * x + x * x * x


if __name__ == '__main__':
    jax_fn = jit(fn)

    x = np.random.randn(10000, 10000)
    x2 = jnp.array(x)

    m = timeit.timeit(lambda: fn(x), number=1)
    print(m)

    m = timeit.timeit(lambda: jax_fn(x2), number=1)
    print(m)

结果

2.4264526699999998
0.9130655179999998

可见 jax 比 numpy快几倍。

自动微分

import jax.numpy as jnp
from jax import grad


def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)


grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))  # Evaluate it at x = 1.0
# prints 0.4199743

print(grad(grad(grad(tanh)))(1.0))
# prints 0.6216266

参考

  1. 2022年再不学JAX就晚了!GitHub超1.6万星
  2. JAX@github:丰富的文档 + 示例