JAX初体验
背景
JAX是一个包含可组合函数变换的数值计算库,可以用于深度学习。
JAX处于函数变换(function transformations)和科学计算的交界处,所以也有能力训练神经网络模型,但不止于训练。
JAX最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人发起,借助 Autograd 的更新版本,并且结合了 XLA,可对 Python 程序与 NumPy 运算执行自动微分,支持循环、分支、递归、闭包函数求导,也可以求三阶导数;依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序;通过 grad,可以支持自动模式反向传播和正向传播,且二者可以任意组合成任何顺序。[1]
JAX优点
- 加速NumPy
- XLA: Accelerated Linear Algebra
- JIT: just in time
- 自动求导
- 深度学习
- 通用可微分编程范式
初体验
环境配置
pip install --upgrade "jax[cpu]" -i https://mirrors.aliyun.com/pypi/simple
JAX是加速版的Numpy
import timeit
import jax.numpy as jnp
import numpy as np
from jax import jit
def fn(x):
return x + x * x + x * x * x
if __name__ == '__main__':
jax_fn = jit(fn)
x = np.random.randn(10000, 10000)
x2 = jnp.array(x)
m = timeit.timeit(lambda: fn(x), number=1)
print(m)
m = timeit.timeit(lambda: jax_fn(x2), number=1)
print(m)
结果
2.4264526699999998
0.9130655179999998
可见 jax 比 numpy快几倍。
自动微分
import jax.numpy as jnp
from jax import grad
def tanh(x): # Define a function
y = jnp.exp(-2.0 * x)
return (1.0 - y) / (1.0 + y)
grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.4199743
print(grad(grad(grad(tanh)))(1.0))
# prints 0.6216266
参考
- 2022年再不学JAX就晚了!GitHub超1.6万星
- JAX@github:丰富的文档 + 示例