使用 JAX 启动并运行

主要观点:

  • JAX 是谷歌研发的高性能数值计算库,结合了 Python 和 NumPy 的易用性与 XLA 的速度效率,适用于机器学习研究和数值计算。
  • 核心是通过自动微分扩展 NumPy 功能,采用函数式编程,强调不可变性和纯函数,代码更易维护、可重现且性能高效。
  • 介绍了 JAX 的三个重要特性:即时编译、向量化变换和自动微分,分别进行了详细阐述和代码示例。

关键信息:

  • Functorch 最初将类似 JAX 的可组合函数转换带到 PyTorch,后集成到 PyTorch 2.0 核心。
  • JAX 核心功能包括自动微分、即时编译(通过jax.jit@jit装饰器)和向量化变换(vmap)。
  • 即时编译在首次调用时有轻微开销,但后续调用快,输入形状和类型固定时效果最佳,可与gradvmappmap结合使用。
  • vmap可自动向量化操作,无需显式写循环进行批量处理,通过jax.random处理随机数生成以确保功能纯粹性。
  • 自动微分通过grad计算标量值函数的梯度,jacfwdjacrev计算雅可比矩阵,hessian计算二阶导数,适用于不同的应用场景。

重要细节:

  • 计算大圆距离的 Haversine 公式示例,展示了即时编译的使用和速度提升。
  • 生成随机坐标数组并应用vmap进行向量化计算,对比了向量化与原生循环的速度。
  • 指数分布的 CDF 和 PDF 示例,使用grad计算导数并与 Scipy 结果对比。
  • 还可对grad的结果再次求导,展示二阶导数的计算。
阅读 7
0 条评论