Skip to content

Commit

Permalink
Merge pull request jax-ml#24751 from Stella-S-Yan:feature/default_dev…
Browse files Browse the repository at this point in the history
…ice_str

PiperOrigin-RevId: 696560063
  • Loading branch information
Google-ML-Automation committed Nov 14, 2024
2 parents 05716b5 + afa518a commit cea8176
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
11 changes: 6 additions & 5 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,20 +1570,21 @@ def _update_default_device_thread_local(val):


def _validate_default_device(val):
if val is not None and not isinstance(val, xla_client.Device):
if (val is not None and
not isinstance(val, xla_client.Device) and
val not in ['cpu', 'gpu', 'tpu']):
# TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when
# all JAX backends use a single C++ device interface.
if 'Device' in str(type(val)):
logger.info(
'Allowing non-`xla_client.Device` default device: %s, type: %s',
repr(val), type(val))
return
raise ValueError('jax.default_device must be passed a Device object (e.g. '
f"`jax.devices('cpu')[0]`), got: {val!r}")
raise ValueError('jax.default_device must be passed either a Device object (e.g. '
f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'"
f", got: {val!r}")


# TODO(skye): default_device only accepts devices for now. Make it work with
# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]).
default_device = string_or_object_state(
name='jax_default_device',
default=None,
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,7 +1711,10 @@ class DeviceAssignmentMismatchError(Exception):


def _get_default_device() -> xc.Device:
return config.default_device.value or xb.local_devices()[0]
if isinstance(config.default_device.value, str):
return xb.get_backend(config.default_device.value).local_devices()[0]
else:
return config.default_device.value or xb.local_devices()[0]


def _get_and_check_device_assignment(
Expand Down
12 changes: 7 additions & 5 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,15 @@ def test_jit_default_device(self, module):
self.assertEqual(f(sticky).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)

# TODO(skye): make this work!
def test_jit_default_platform(self):
with self.assertRaisesWithLiteralMatch(
ValueError, "jax.default_device must be passed a Device object "
"(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"):
with jax.default_device("cpu"):
jax.jit(lambda x: x + 1)(1)
result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, "cpu")
self.assertEqual(result.device, jax.local_devices(backend="cpu")[0])

result = jax.jit(lambda x: x + 1)(1)
self.assertEqual(result.device.platform, jax.default_backend())
self.assertEqual(result.device, jax.local_devices()[0])

def test_complex_support(self):
self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)
Expand Down

0 comments on commit cea8176

Please sign in to comment.