diff --git a/jax/_src/config.py b/jax/_src/config.py index a44b0125a210..72f394dba76f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1570,7 +1570,9 @@ def _update_default_device_thread_local(val): def _validate_default_device(val): - if val is not None and not isinstance(val, xla_client.Device): + if (val is not None and + not isinstance(val, xla_client.Device) and + val not in ['cpu', 'gpu', 'tpu']): # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when # all JAX backends use a single C++ device interface. if 'Device' in str(type(val)): @@ -1578,12 +1580,11 @@ def _validate_default_device(val): 'Allowing non-`xla_client.Device` default device: %s, type: %s', repr(val), type(val)) return - raise ValueError('jax.default_device must be passed a Device object (e.g. ' - f"`jax.devices('cpu')[0]`), got: {val!r}") + raise ValueError('jax.default_device must be passed either a Device object (e.g. ' + f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'" + f", got: {val!r}") -# TODO(skye): default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). default_device = string_or_object_state( name='jax_default_device', default=None, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index caa414741666..6c9e54441f8e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1711,7 +1711,10 @@ class DeviceAssignmentMismatchError(Exception): def _get_default_device() -> xc.Device: - return config.default_device.value or xb.local_devices()[0] + if isinstance(config.default_device.value, str): + return xb.get_backend(config.default_device.value).local_devices()[0] + else: + return config.default_device.value or xb.local_devices()[0] def _get_and_check_device_assignment( diff --git a/tests/api_test.py b/tests/api_test.py index 8ab5d90f6e07..49cd33ee464c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -287,13 +287,15 @@ def test_jit_default_device(self, module): self.assertEqual(f(sticky).devices(), system_default_devices) self.assertEqual(f(1).devices(), system_default_devices) - # TODO(skye): make this work! def test_jit_default_platform(self): - with self.assertRaisesWithLiteralMatch( - ValueError, "jax.default_device must be passed a Device object " - "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"): with jax.default_device("cpu"): - jax.jit(lambda x: x + 1)(1) + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) + + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + self.assertEqual(result.device, jax.local_devices()[0]) def test_complex_support(self): self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j)