Skip to content

Commit

Permalink
Add jax[cuda12] install variation for using cuda plugin.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590342149
  • Loading branch information
Jieying Luo authored and jax authors committed Dec 12, 2023
1 parent 7305b64 commit 4fe9e59
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ def generate_proto(source):
"nvidia-nvjitlink-cu12>=12.2",
],

'cuda12': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-cuda12-plugin=={_current_jaxlib_version}",
"nvidia-cublas-cu12>=12.2.5.6",
"nvidia-cuda-cupti-cu12>=12.2.142",
"nvidia-cuda-nvcc-cu12>=12.2.140",
"nvidia-cuda-runtime-cu12>=12.2.140",
"nvidia-cudnn-cu12>=8.9",
"nvidia-cufft-cu12>=11.0.8.103",
"nvidia-cusolver-cu12>=11.5.2",
"nvidia-cusparse-cu12>=12.1.2.141",
"nvidia-nccl-cu12>=2.18.3",

# nvjitlink is not a direct dependency of JAX, but it is a transitive
# dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages
# do not have a version constraint on their dependencies, so the
# package doesn't get upgraded even though not doing that can cause
# problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196)
# Until NVIDIA add version constraints, add an version constraint
# here.
"nvidia-nvjitlink-cu12>=12.2",
],

# Target that does not depend on the CUDA pip wheels, for those who want
# to use a preinstalled CUDA.
'cuda11_local': [
Expand Down

0 comments on commit 4fe9e59

Please sign in to comment.