Skip to content

Commit

Permalink
[FIX] Added jaxlib version guard for CUDA compute capability check
Browse files Browse the repository at this point in the history
  • Loading branch information
Micky774 committed Mar 13, 2024
1 parent 1ed27ec commit ac2c522
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jaxlib import version

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -333,8 +334,11 @@ def make_gpu_client(
)
if platform_name == "cuda":
_check_cuda_versions()
devices_to_check = allowed_devices if allowed_devices else range(cuda_versions.cuda_device_count())
_check_cuda_compute_capability(devices_to_check)
# TODO(micky774): remove this check when minimum jaxlib is v0.4.26
if version.__version_info__ >= (0, 4, 26):
devices_to_check = allowed_devices if allowed_devices else range(
cuda_versions.cuda_device_count())
_check_cuda_compute_capability(devices_to_check)

return xla_client.make_gpu_client(
distributed_client=distributed.global_state.client,
Expand Down

0 comments on commit ac2c522

Please sign in to comment.