当前位置: 首页 > 工具软件 > JAX > 使用案例 >

Ubuntu下jax安装与使用

臧亦
2023-12-01

目录

安装说明

pip安装

conda安装

参考网址


注:该项目目前仍然没有官方的Windows支持,需要自己编译。

安装说明

该库安装时分为两部分:

  1. jaxlib,该库平台相关,目前没有官方的编译
  2. jax,该库依赖jaxlib,平台无关,可以直接安装。

找到目前一个还活跃的jaxlib非官方编译服务:

https://github.com/cloudhan/jax-windows-builder

pip安装

要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

在 Linux 上,通常需要先更新pip到支持 manylinux2014轮子的版本。 这些pip安装不适用于 Windows,并且可能会静默失败;见 上文

如果要安装同时支持 CPU 和 NVidia GPU 的 JAX,则必须首先安装CUDA和 CuDNN(如果尚未安装)。与其他一些流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为pip 软件包的一部分。

JAX仅为 Linux提供预构建的 CUDA 兼容轮子,带有 CUDA 11.1 或更高版本,以及 CuDNN 8.0.5 或更高版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建

  • 需要CUDA 11.1 或更新版本。
  • 预建轮子支持的 cuDNN 版本是:
    • cuDNN 8.2 或更高版本。如果您的 cuDNN 安装足够新,我们建议使用 cuDNN 8.2 轮,因为它支持附加功能。
    • cuDNN 8.0.5 或更高版本。
  • 必须使用至少与您的 CUDA 工具包的相应驱动程序版本一样新的 NVidia 驱动程序版本。例如,如果您安装了 CUDA 11.4 update 4,则在 Linux 上必须使用 NVidia 驱动程序 470.82.01 或更新版本。这是一个严格的要求,因为 JAX 依赖于 JIT 编译代码;较旧的驱动程序可能会导致故障。
    • 如果您需要将较新的 CUDA 工具包与较旧的驱动程序一起使用,例如在无法轻松更新 NVidia 驱动程序的集群上,您可以使用 NVidia 为此目的提供的CUDA 前向兼容性包。

接下来,运行

pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

这些pip安装不适用于 Windows,并且可能会静默失败;

jaxlib 版本必须与您要使用的现有 CUDA 安装的版本相对应。您可以为 jaxlib 显式指定特定的 CUDA 和 CuDNN 版本:

pip install --upgrade pip

# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

具体版本
pip install --upgrade jax==0.3.15 jaxlib==0.3.15+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

conda安装

有一个社区支持的 Conda 构建jax。要安装 using conda,只需运行

conda install jax -c conda-forge

要在具有 NVidia GPU 的机器上安装,请运行

conda install jax cuda-nvcc -c conda-forge -c nvidia

请注意cudatoolkitDistributed by conda-forgeis missing ptxas,这是 JAX 要求的。因此,您必须cuda-nvcc从频道安装软件包nvidia,或者在您的机器上单独安装 CUDA,以便ptxas 在您的路径中。上面的频道顺序很重要(conda-forge之前 nvidia)。我们正在努力简化这一点。

如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 网站提示和技巧 部分中的说明进行操作conda-forge

参考网址

https://github.com/google/jax

 类似资料: