Skip to content

Commit

Permalink
Change determination of cloud TPU to check for TPU chips.
Browse files Browse the repository at this point in the history
This is useful in the case of ahead of time compilation, when libtpu is present but there may not be any TPU chips, so we shouldn't attempt to initialize a TPU backend.

PiperOrigin-RevId: 608175011
  • Loading branch information
jax authors committed Mar 16, 2024
1 parent a53e99a commit 0b2a333
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def cloud_tpu_init() -> None:
"""
global running_in_cloud_tpu_vm

# We assume we are in a correctly-configured Cloud TPU environment
# if the following hold: a) libtpu is installed b) JAX_FORCE_TPU_INIT is set
# Exit early if we're not running on Cloud TPU.
libtpu_module = maybe_import_libtpu()
if libtpu_module is None and not jax_force_tpu_init():
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init():
return

running_in_cloud_tpu_vm = True
Expand Down

0 comments on commit 0b2a333

Please sign in to comment.