Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
and we encourage all new code to use `jax.shard_map` directly. See the
[migration guide](https://docs.jax.dev/en/latest/deprecate_pmap.html) for
more information.
* JAX no longer allows passing objects that support `__jax_array__` directly
to, e.g. `jit`-ed functions. Call `jax.numpy.asarray` on them first.
* {func}`jax.numpy.cov` is now returns NaN for empty arrays ({jax-issue}`#32305`),
and matches NumPy 2.2 behavior for single-row design matrices ({jax-issue}`#32308`).
* JAX no longer accepts `Array` values where a `dtype` value is expected. Call
Expand Down
2 changes: 0 additions & 2 deletions jax/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,6 @@ pytype_strict_library(
],
deps = [
":config",
":deprecations",
":dtypes",
":effects",
":layout",
Expand Down Expand Up @@ -725,7 +724,6 @@ pytype_strict_library(
],
deps = [
":config",
":deprecations",
":literals",
":traceback_util",
":typing",
Expand Down
28 changes: 12 additions & 16 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

import numpy as np

from jax._src import deprecations
from jax._src import dtypes
from jax._src import config
from jax._src import effects
Expand Down Expand Up @@ -1742,13 +1741,11 @@ def shaped_abstractify(x):
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
deprecations.warn(
'jax-abstract-dunder-array',
('Triggering of __jax_array__() during abstractification is deprecated.'
' To avoid this error, either explicitly convert your object using'
' jax.numpy.array(), or register your object as a pytree.'),
stacklevel=6)
return shaped_abstractify(x.__jax_array__())
raise ValueError(
'Triggering __jax_array__() during abstractification is no longer'
' supported. To avoid this error, either explicitly convert your object'
' using jax.numpy.array(), or register your object as a pytree.'
)
if hasattr(x, 'dtype'):
aval = ShapedArray(
np.shape(x),
Expand All @@ -1767,21 +1764,20 @@ def abstractify(x):
return get_aval(x)


def get_aval(x):
# TODO(phawkins): the return type should be AbstractValue.
def get_aval(x: Any) -> Any:
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
deprecations.warn(
'jax-abstract-dunder-array',
('Triggering of __jax_array__() during abstractification is deprecated.'
' To avoid this error, either explicitly convert your object using'
' jax.numpy.array(), or register your object as a pytree.'),
stacklevel=6)
return get_aval(x.__jax_array__())
raise ValueError(
'Triggering __jax_array__() during abstractification is no longer'
' supported. To avoid this error, either explicitly convert your object'
' using jax.numpy.array(), or register your object as a pytree.'
)
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

typeof = get_aval
Expand Down
1 change: 0 additions & 1 deletion jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,5 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
register('jax-numpy-linalg-matrix_rank-tol')
register('jax-numpy-linalg-pinv-rcond')
register('jax-scipy-special-sph-harm')
register('jax-abstract-dunder-array')
register('safer-randint-config')
register('jax-pmap-no-rank-reduction')
15 changes: 4 additions & 11 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
import numpy as np

from jax._src import config
from jax._src import deprecations
from jax._src import literals
from jax._src.typing import Array, DType, DTypeLike
from jax._src.util import set_module, StrictABC
Expand Down Expand Up @@ -375,17 +374,11 @@ def canonicalize_value(x):
if handler:
return handler(x)
if hasattr(x, '__jax_array__'):
deprecations.warn(
'jax-abstract-dunder-array',
(
'Triggering of __jax_array__() during abstractification is'
' deprecated. To avoid this error, either explicitly convert your'
' object using jax.numpy.array(), or register your object as a'
' pytree.'
),
stacklevel=6,
raise ValueError(
'Triggering __jax_array__() during abstractification is no longer'
' supported. To avoid this error, either explicitly convert your object'
' using jax.numpy.array(), or register your object as a pytree.'
)
return canonicalize_value(x.__jax_array__())
raise InvalidInputException(
f"Argument '{x}' of type {type(x)} is not a valid JAX type."
)
Expand Down
7 changes: 5 additions & 2 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4199,8 +4199,11 @@ def __jax_array__(self):

f = jax.jit(lambda x: x)
a = AlexArray(jnp.arange(4))
msg = r"Triggering of __jax_array__\(\) during abstractification is deprecated."
with self.assertDeprecationWarnsOrRaises('jax-abstract-dunder-array', msg):
msg = (
r"Triggering __jax_array__\(\) during abstractification is no longer"
r" supported."
)
with self.assertRaisesRegex(ValueError, msg):
f(a)

@jtu.thread_unsafe_test() # count_jit_tracing_cache_miss() isn't thread-safe
Expand Down
Loading