Skip to content

Commit

Permalink
[array api] Finalize array API in jax.numpy & deprecate jax.experimen…
Browse files Browse the repository at this point in the history
…tal.array_api
  • Loading branch information
jakevdp committed Aug 1, 2024
1 parent b3924da commit 14fa062
Show file tree
Hide file tree
Showing 18 changed files with 128 additions and 325 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/maps.py
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/jax-array-api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ jobs:
python -m pip install -r array-api-tests/requirements.txt
- name: Run the test suite
env:
ARRAY_API_TESTS_MODULE: jax.experimental.array_api
ARRAY_API_TESTS_MODULE: jax.numpy
JAX_ENABLE_X64: 'true'
run: |
cd ${GITHUB_WORKSPACE}/array-api-tests
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/jax/experimental/array_api/skips.txt
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.32

* Changes
* {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard.
See {ref}`python-array-api` for more information.

* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
`stablehlo` dialect instead.
Expand All @@ -23,6 +27,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.lib.xla_bridge.xla_client`: use {mod}`jax.lib.xla_client` directly.
* `jax.lib.xla_bridge.get_backend`: use {func}`jax.extend.backend.get_backend`.
* `jax.lib.xla_bridge.default_backend`: use {func}`jax.extend.backend.default_backend`.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.

## jaxlib 0.4.32

Expand Down
26 changes: 25 additions & 1 deletion docs/jax.experimental.array_api.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,28 @@
``jax.experimental.array_api`` module
=====================================

.. automodule:: jax.experimental.array_api
.. note::
The ``jax.experimental.array_api`` module is deprecated as of JAX v0.4.32, and
importing ``jax.experimental.array_api`` is no longer necessary. {mod}`jax.numpy`
implements the array API standard directly by default. See :ref:`python-array-api`
for details.

This module includes experimental JAX support for the `Python array API standard`_.
Support for this is currently experimental and not fully complete.

Example Usage::

>>> from jax.experimental import array_api as xp

>>> xp.__array_api_version__
'2023.12'

>>> arr = xp.arange(1000)

>>> arr.sum()
Array(499500, dtype=int32)

The ``xp`` namespace is the array API compliant analog of :mod:`jax.numpy`,
and implements most of the API listed in the standard.

.. _Python array API standard: https://data-apis.org/array-api/
33 changes: 33 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,36 @@ This is because in general, pickling and unpickling may take place in different
environments, and there is no general way to map the device IDs of one runtime
to the device IDs of another. If :mod:`pickle` is used in traced/JIT-compiled code,
it will result in a :class:`~jax.errors.ConcretizationTypeError`.

.. _python-array-api:

Python Array API standard
-------------------------

.. note::

Prior to JAX v0.4.32, you must ``import jax.experimental.array_api`` in order
to enable the array API for JAX arrays. After JAX v0.4.32, importing this
module is no longer required, and will raise a deprecation warning.

Starting with JAX v0.4.32, :class:`jax.Array` and :mod:`jax.numpy` are compatible
with the `Python Array API Standard`_. You can access the Array API namespace via
:meth:`jax.Array.__array_namespace__`::

>>> def f(x):
... nx = x.__array_namespace__()
... return nx.sin(x) ** 2 + nx.cos(x) ** 2

>>> import jax.numpy as jnp
>>> x = jnp.arange(5)
>>> f(x).round()
Array([1., 1., 1., 1., 1.], dtype=float32)

JAX departs from the standard in a few places, namely because JAX arrays are
immutable, in-place updates are not supported. Some of these incompatibilities
are being addressed via the `array-api-compat`_ module.

For more information, refer to the `Python Array API Standard`_ documentation.

.. _Python Array API Standard: https://data-apis.org/array-api
.. _array-api-compat: https://github.com/data-apis/array-api-compat
10 changes: 4 additions & 6 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -987,13 +987,11 @@ pytype_library(

pytype_library(
name = "experimental_array_api",
srcs = glob(
[
"experimental/array_api/*.py",
],
),
visibility = [":internal"] + jax_visibility("array_api"),
deps = [":jax"],
deps = [
":experimental",
":jax",
],
)

pytype_library(
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import abc
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, Union
import numpy as np

Expand Down Expand Up @@ -48,6 +49,8 @@ class Array(abc.ABC):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")

def __array_namespace__(self, *, api_version: None | str = ...) -> ModuleType: ...

def __getitem__(self, key) -> Array: ...
def __setitem__(self, key, value) -> None: ...
def __len__(self) -> int: ...
Expand Down
28 changes: 7 additions & 21 deletions jax/_src/numpy/array_api_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,30 @@
"""
from __future__ import annotations

import importlib
from types import ModuleType

import jax
from jax._src.sharding import Sharding
from jax._src.lib import xla_client as xc
from jax._src import dtypes as _dtypes, config


# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
__array_api_version__ = '2023.12'


# TODO(jakevdp, vfdev-5): export this in jax.numpy once migration is complete.
def __array_namespace_info__() -> ArrayNamespaceInfo:
return ArrayNamespaceInfo()


def _array_namespace_property(self):
# TODO(jakevdp): clean this up once numpy fully supports the array API.
# In some environments, jax.experimental.array_api is not available.
# We return an AttributeError in this case, because some callers use
# hasattr checks to check for array API compatibility.
if not importlib.util.find_spec('jax.experimental.array_api'):
raise AttributeError("__array_namespace__ requires jax.experimental.array_api")
return __array_namespace__


def __array_namespace__(*, api_version: None | str = None):
def __array_namespace__(self, *, api_version: None | str = None) -> ModuleType:
"""Return the `Python array API`_ namespace for JAX.
.. _Python array API: https://data-apis.org/array-api/
"""
if api_version is not None and api_version != __array_api_version__:
raise ValueError(f"{api_version=!r} is not available; "
f"available versions are: {[__array_api_version__]}")
# TODO(jakevdp, vfdev-5): change this to jax.numpy once migration is complete.
import jax.experimental.array_api
return jax.experimental.array_api # pytype: disable=module-attr
return jax.numpy


def __array_namespace_info__() -> ArrayNamespaceInfo:
return ArrayNamespaceInfo()


class ArrayNamespaceInfo:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
}

_array_methods = {
"__array_namespace__": array_api_metadata.__array_namespace__,
"all": reductions.all,
"any": reductions.any,
"argmax": lax_numpy.argmax,
Expand Down Expand Up @@ -719,7 +720,6 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False,
}

_array_properties = {
"__array_namespace__": array_api_metadata._array_namespace_property,
"flat": _notimplemented_flat,
"T": lax_numpy.transpose,
"mT": lax_numpy.matrix_transpose,
Expand Down
17 changes: 17 additions & 0 deletions jax/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,20 @@
from jax._src.earray import (
EArray as EArray
)

from jax import numpy as _array_api


_deprecations = {
# Deprecated 01 Aug 2024
"array_api": (
"jax.experimental.array_api import is no longer required as of JAX v0.4.32; "
"jax.numpy supports the array API by default.",
_array_api
),
}

from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _array_api
Loading

0 comments on commit 14fa062

Please sign in to comment.