diff --git a/CHANGELOG.md b/CHANGELOG.md index db2cc986abe1..8e9656db069a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list * {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D solves with `b.ndim > 1`. In the future these will be treated as batched 2D solves. + * Conversion of a non-scalar array to a Python scalar now raises an error, regardless + of the size of the array. Previously a deprecation warning was raised in the case of + non-scalar arrays of size 1. This follows a similar deprecation in NumPy. ## jaxlib 0.4.25 diff --git a/jax/_src/core.py b/jax/_src/core.py index e0ad9dc29555..cb1537d197c9 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -639,12 +639,9 @@ def escaped_tracer_error(tracer, detail=None): def check_scalar_conversion(arr: Array): - if arr.size != 1: - raise TypeError("Only length-1 arrays can be converted to Python scalars.") - if arr.shape != (): - # Added 2023 September 18. - warnings.warn("Conversion of an array with ndim > 0 to a scalar is deprecated, " - "and will error in future.", DeprecationWarning, stacklevel=3) + if arr.ndim > 0: + raise TypeError("Only scalar arrays can be converted to Python scalars; " + f"got {arr.ndim=}") def check_integer_conversion(arr: Array): diff --git a/tests/api_test.py b/tests/api_test.py index 35ffc610006a..81b04d5849a7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -4495,37 +4495,35 @@ def test_scalar_conversion_errors(self): array_int = jnp.arange(10, dtype=int) scalar_float = jnp.float32(0) scalar_int = jnp.int32(0) + empty_int = jnp.arange(0, dtype='int32') array1_float = jnp.arange(1, dtype='float32') assertIntError = partial(self.assertRaisesRegex, TypeError, "Only integer scalar arrays can be converted to a scalar index.") for func in [operator.index, hex, oct]: assertIntError(func, array_int) + assertIntError(func, empty_int) assertIntError(func, scalar_float) assertIntError(jax.jit(func), array_int) + assertIntError(jax.jit(func), empty_int) assertIntError(jax.jit(func), scalar_float) self.assertRaises(TracerIntegerConversionError, jax.jit(func), scalar_int) _ = func(scalar_int) # no error assertScalarError = partial(self.assertRaisesRegex, TypeError, - "Only length-1 arrays can be converted to Python scalars.") + "Only scalar arrays can be converted to Python scalars.") for func in [int, float, complex]: assertScalarError(func, array_int) assertScalarError(jax.jit(func), array_int) self.assertRaises(ConcretizationTypeError, jax.jit(func), scalar_int) _ = func(scalar_int) # no error - # TODO(jakevdp): remove this ignore warning when possible - with jtu.ignore_warning(category=DeprecationWarning): - self.assertRaises(ConcretizationTypeError, jax.jit(func), array1_float) - _ = func(array1_float) # no error - - # TODO(jakevdp): add these tests once these deprecated operations error. - # empty_int = jnp.arange(0, dtype='int32') - # assertEmptyBoolError = partial( - # self.assertRaisesRegex, ValueError, - # "The truth value of an empty array is ambiguous.") - # assertEmptyBoolError(bool, empty_int) - # assertEmptyBoolError(jax.jit(bool), empty_int) + assertScalarError(func, array1_float) + + assertEmptyBoolError = partial( + self.assertRaisesRegex, ValueError, + "The truth value of an empty array is ambiguous.") + assertEmptyBoolError(bool, empty_int) + assertEmptyBoolError(jax.jit(bool), empty_int) assertBoolError = partial( self.assertRaisesRegex, ValueError,