目录
注:该项目目前仍然没有官方的Windows支持,需要自己编译。
该库安装时分为两部分:
找到目前一个还活跃的jaxlib非官方编译服务:
https://github.com/cloudhan/jax-windows-builder
要安装仅 CPU 版本的 JAX,这可能对在笔记本电脑上进行本地开发很有用,您可以运行
pip install --upgrade pip pip install --upgrade "jax[cpu]"
在 Linux 上,通常需要先更新pip
到支持 manylinux2014
轮子的版本。 这些pip
安装不适用于 Windows,并且可能会静默失败;见 上文。
如果要安装同时支持 CPU 和 NVidia GPU 的 JAX,则必须首先安装CUDA和 CuDNN(如果尚未安装)。与其他一些流行的深度学习系统不同,JAX 没有将 CUDA 或 CuDNN 捆绑为pip
软件包的一部分。
JAX仅为 Linux提供预构建的 CUDA 兼容轮子,带有 CUDA 11.1 或更高版本,以及 CuDNN 8.0.5 或更高版本。操作系统、CUDA 和 CuDNN 的其他组合是可能的,但需要从源代码构建。
接下来,运行
pip install --upgrade pip # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer. # Note: wheels only available on linux. pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
这些pip
安装不适用于 Windows,并且可能会静默失败;
jaxlib 版本必须与您要使用的现有 CUDA 安装的版本相对应。您可以为 jaxlib 显式指定特定的 CUDA 和 CuDNN 版本:
pip install --upgrade pip # Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2 pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5 pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 具体版本 pip install --upgrade jax==0.3.15 jaxlib==0.3.15+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
有一个社区支持的 Conda 构建jax
。要安装 using conda
,只需运行
conda install jax -c conda-forge
要在具有 NVidia GPU 的机器上安装,请运行
conda install jax cuda-nvcc -c conda-forge -c nvidia
请注意cudatoolkit
Distributed by conda-forge
is missing ptxas
,这是 JAX 要求的。因此,您必须cuda-nvcc
从频道安装软件包nvidia
,或者在您的机器上单独安装 CUDA,以便ptxas
在您的路径中。上面的频道顺序很重要(conda-forge
之前 nvidia
)。我们正在努力简化这一点。
如果您想覆盖 JAX 使用的 CUDA 版本,或者在没有 GPU 的机器上安装 CUDA 构建,请按照 网站提示和技巧 部分中的说明进行操作conda-forge
。