Skip to content

Commit

Permalink
Merge pull request #19181 from jakevdp:scalar-conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607763395
  • Loading branch information
jax authors committed Feb 16, 2024
2 parents 0c92f55 + 1fe46aa commit ceb1985
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
solves with `b.ndim > 1`. In the future these will be treated as batched 2D
solves.
* Conversion of a non-scalar array to a Python scalar now raises an error, regardless
of the size of the array. Previously a deprecation warning was raised in the case of
non-scalar arrays of size 1. This follows a similar deprecation in NumPy.

## jaxlib 0.4.25

Expand Down
9 changes: 3 additions & 6 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,9 @@ def escaped_tracer_error(tracer, detail=None):


def check_scalar_conversion(arr: Array):
if arr.size != 1:
raise TypeError("Only length-1 arrays can be converted to Python scalars.")
if arr.shape != ():
# Added 2023 September 18.
warnings.warn("Conversion of an array with ndim > 0 to a scalar is deprecated, "
"and will error in future.", DeprecationWarning, stacklevel=3)
if arr.ndim > 0:
raise TypeError("Only scalar arrays can be converted to Python scalars; "
f"got {arr.ndim=}")


def check_integer_conversion(arr: Array):
Expand Down
24 changes: 11 additions & 13 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4495,37 +4495,35 @@ def test_scalar_conversion_errors(self):
array_int = jnp.arange(10, dtype=int)
scalar_float = jnp.float32(0)
scalar_int = jnp.int32(0)
empty_int = jnp.arange(0, dtype='int32')
array1_float = jnp.arange(1, dtype='float32')

assertIntError = partial(self.assertRaisesRegex, TypeError,
"Only integer scalar arrays can be converted to a scalar index.")
for func in [operator.index, hex, oct]:
assertIntError(func, array_int)
assertIntError(func, empty_int)
assertIntError(func, scalar_float)
assertIntError(jax.jit(func), array_int)
assertIntError(jax.jit(func), empty_int)
assertIntError(jax.jit(func), scalar_float)
self.assertRaises(TracerIntegerConversionError, jax.jit(func), scalar_int)
_ = func(scalar_int) # no error

assertScalarError = partial(self.assertRaisesRegex, TypeError,
"Only length-1 arrays can be converted to Python scalars.")
"Only scalar arrays can be converted to Python scalars.")
for func in [int, float, complex]:
assertScalarError(func, array_int)
assertScalarError(jax.jit(func), array_int)
self.assertRaises(ConcretizationTypeError, jax.jit(func), scalar_int)
_ = func(scalar_int) # no error
# TODO(jakevdp): remove this ignore warning when possible
with jtu.ignore_warning(category=DeprecationWarning):
self.assertRaises(ConcretizationTypeError, jax.jit(func), array1_float)
_ = func(array1_float) # no error

# TODO(jakevdp): add these tests once these deprecated operations error.
# empty_int = jnp.arange(0, dtype='int32')
# assertEmptyBoolError = partial(
# self.assertRaisesRegex, ValueError,
# "The truth value of an empty array is ambiguous.")
# assertEmptyBoolError(bool, empty_int)
# assertEmptyBoolError(jax.jit(bool), empty_int)
assertScalarError(func, array1_float)

assertEmptyBoolError = partial(
self.assertRaisesRegex, ValueError,
"The truth value of an empty array is ambiguous.")
assertEmptyBoolError(bool, empty_int)
assertEmptyBoolError(jax.jit(bool), empty_int)

assertBoolError = partial(
self.assertRaisesRegex, ValueError,
Expand Down

0 comments on commit ceb1985

Please sign in to comment.