Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update jnp.clip to Array API 2023 standard and introduces jax.experimental.array_api.clip #20550

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Remember to align the itemized text with the first line of an item within a list
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
lowering pass via Triton Python APIs has been removed and the
`JAX_TRITON_COMPILE_VIA_XLA` environment variable no longer has any effect.
* {func}`jax.numpy.clip` has a new argument signature: `a`, `a_min`, and
`a_max` are deprecated in favor of `x` (positonal only), `min`, and
`max` ({jax-issue}`20550`).


## jaxlib 0.4.27
Expand Down
5 changes: 2 additions & 3 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,11 @@ def _itemsize(arr: ArrayLike) -> int:


def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None,
out: None = None) -> Array:
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.

Refer to :func:`jax.numpy.clip` for full documentation."""
return lax_numpy.clip(number, a_min=min, a_max=max, out=out)
return lax_numpy.clip(number, min=min, max=max)


def _transpose(a: Array, *args: Any) -> Array:
Expand Down
74 changes: 60 additions & 14 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@
from jax._src.numpy import ufuncs
from jax._src.numpy import util
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DType, DTypeLike, Shape
from jax._src.typing import (
Array, ArrayLike, DimSize, DuckTypedArray,
DType, DTypeLike, Shape, DeprecatedArg
)
from jax._src.util import (unzip2, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis,
Expand Down Expand Up @@ -1293,20 +1296,63 @@ def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | Array
axis: int = 0) -> list[Array]:
return _split("array_split", ary, indices_or_sections, axis=axis)

@util.implements(np.clip, skip_params=['out'])

_DEPRECATED_CLIP_ARG = DeprecatedArg()
@util.implements(
np.clip,
skip_params=['a', 'a_min'],
extra_params=_dedent("""
x : array_like
Array containing elements to clip.
min : array_like, optional
Minimum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``min`` is broadcast against x.
max : array_like, optional
Maximum value. If ``None``, clipping is not performed on the
corresponding edge. The value of ``max`` is broadcast against x.
""")
)
@jit
def clip(a: ArrayLike, a_min: ArrayLike | None = None,
a_max: ArrayLike | None = None, out: None = None) -> Array:
util.check_arraylike("clip", a)
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
if a_min is None and a_max is None:
raise ValueError("At most one of a_min and a_max may be None")
if a_min is not None:
a = ufuncs.maximum(a_min, a)
if a_max is not None:
a = ufuncs.minimum(a_max, a)
return asarray(a)
def clip(
x: ArrayLike | None = None, # Default to preserve backwards compatability
/,
min: ArrayLike | None = None,
max: ArrayLike | None = None,
*,
a: ArrayLike | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_min: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG,
a_max: ArrayLike | None | DeprecatedArg = _DEPRECATED_CLIP_ARG
Micky774 marked this conversation as resolved.
Show resolved Hide resolved
) -> Array:
# TODO(micky774): deprecated 2024-4-2, remove after deprecation expires.
x = a if not isinstance(a, DeprecatedArg) else x
if x is None:
raise ValueError("No input was provided to the clip function.")
min = a_min if not isinstance(a_min, DeprecatedArg) else min
max = a_max if not isinstance(a_max, DeprecatedArg) else max
if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)):
warnings.warn(
"Passing arguments 'a', 'a_min', or 'a_max' to jax.numpy.clip is "
"deprecated. Please use 'x', 'min', and 'max' respectively instead.",
DeprecationWarning,
stacklevel=2,
)

util.check_arraylike("clip", x)
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
# TODO(micky774): Deprecated 2024-4-2, remove after deprecation expires.
warnings.warn(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Attempting to clip using complex numbers is deprecated and will soon "
"raise a ValueError. Please convert to a real value or array by taking "
"the real or imaginary components via jax.numpy.real/imag respectively.",
DeprecationWarning, stacklevel=2,
)
if min is not None:
x = ufuncs.maximum(min, x)
if max is not None:
x = ufuncs.minimum(max, x)
return asarray(x)

@util.implements(np.around, skip_params=['out'])
@partial(jit, static_argnames=('decimals',))
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _zeta_series_expansion(x: ArrayLike, q: ArrayLike | None = None) -> Array:
m = jnp.expand_dims(np.arange(2 * M, dtype=M.dtype), tuple(range(s.ndim)))
s_over_a = (s_ + m) / (a_ + N)
T1 = jnp.cumprod(s_over_a, -1)[..., ::2]
T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
T1 = jnp.clip(T1, max=jnp.finfo(dtype).max)
coefs = np.expand_dims(np.array(_BERNOULLI_COEFS[:T1.shape[-1]], dtype=dtype),
tuple(range(a.ndim)))
T1 = T1 / coefs
Expand Down
6 changes: 6 additions & 0 deletions jax/_src/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ def shape(self) -> Shape: ...
# JAX array (i.e. not including future non-standard array types like KeyArray and BInt).
# It's different than np.typing.ArrayLike in that it doesn't accept arbitrary sequences,
# nor does it accept string data.

# We use a class for deprecated args to avoid using Any/object types which can
# introduce complications and mistakes in static analysis
class DeprecatedArg:
def __repr__(self):
return "Deprecated"
1 change: 1 addition & 0 deletions jax/experimental/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
bitwise_right_shift as bitwise_right_shift,
bitwise_xor as bitwise_xor,
ceil as ceil,
clip as clip,
conj as conj,
cos as cos,
cosh as cosh,
Expand Down
16 changes: 16 additions & 0 deletions jax/experimental/array_api/_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,22 @@ def ceil(x, /):
return jax.numpy.ceil(x)


def clip(x, /, min=None, max=None):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("clip", x)

# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
if any(jax.numpy.iscomplexobj(t) for t in (x, min, max)):
raise ValueError(
"Clip received a complex value either through the input or the min/max "
"keywords. Complex values have no ordering and cannot be clipped. "
"Please convert to a real value or array by taking the real or "
"imaginary components via jax.numpy.real/imag respectively."
)
return jax.numpy.clip(x, min=min, max=max)


def conj(x, /):
"""Returns the complex conjugate for each element x_i of the input array x."""
x, = _promote_dtypes("conj", x)
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/jax2tf_limitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ def dot_column_wise(a, b):
# values like 1.0000001 on float32, which are clipped to 1.0. It is
# possible that anything other than `cos_angular_diff` can be outside
# the interval [0, 1] due to roundoff.
cos_angular_diff = jnp.clip(cos_angular_diff, a_min=0.0, a_max=1.0)
cos_angular_diff = jnp.clip(cos_angular_diff, min=0.0, max=1.0)

angular_diff = jnp.arccos(cos_angular_diff)

Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def body_fun(state):
next_t = t + dt
error_ratio = mean_error_ratio(next_y_error, rtol, atol, y, next_y)
new_interp_coeff = interp_fit_dopri(y, next_y, k, dt)
dt = jnp.clip(optimal_step_size(dt, error_ratio), a_min=0., a_max=hmax)
dt = jnp.clip(optimal_step_size(dt, error_ratio), min=0., max=hmax)

new = [i + 1, next_y, next_f, next_t, dt, t, new_interp_coeff]
old = [i + 1, y, f, t, dt, last_t, interp_coeff]
Expand All @@ -214,7 +214,7 @@ def body_fun(state):
return carry, y_target

f0 = func_(y0, ts[0])
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), a_min=0., a_max=hmax)
dt = jnp.clip(initial_step_size(func_, ts[0], y0, 4, rtol, atol, f0), min=0., max=hmax)
interp_coeff = jnp.array([y0] * 5)
init_carry = [y0, f0, ts[0], dt, ts[0], interp_coeff]
_, ys = lax.scan(scan_fun, init_carry, ts[1:])
Expand Down
16 changes: 13 additions & 3 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ from jax._src import dtypes as _dtypes
from jax._src.lax.lax import PrecisionLike
from jax._src.lax.slicing import GatherScatterMode
from jax._src.numpy.index_tricks import _Mgrid, _Ogrid, CClass as _CClass, RClass as _RClass
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape
from jax._src.typing import (
Array, ArrayLike, DType, DTypeLike,
DimSize, DuckTypedArray, Shape, DeprecatedArg
)
from jax.numpy import fft as fft, linalg as linalg
from jax.sharding import Sharding as _Sharding
import numpy as _np
Expand Down Expand Up @@ -181,8 +184,15 @@ def ceil(x: ArrayLike, /) -> Array: ...
character = _np.character
def choose(a: ArrayLike, choices: Sequence[ArrayLike],
out: None = ..., mode: str = ...) -> Array: ...
def clip(a: ArrayLike, a_min: Optional[ArrayLike] = ...,
a_max: Optional[ArrayLike] = ..., out: None = ...) -> Array: ...
def clip(
x: ArrayLike | None = ...,
/,
min: Optional[ArrayLike] = ...,
max: Optional[ArrayLike] = ...,
a: ArrayLike | DeprecatedArg | None = ...,
a_min: ArrayLike | DeprecatedArg | None = ...,
a_max: ArrayLike | DeprecatedArg | None = ...
) -> Array: ...
def column_stack(
tup: Union[_np.ndarray, Array, Sequence[ArrayLike]]
) -> Array: ...
Expand Down
23 changes: 23 additions & 0 deletions tests/array_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
'broadcast_to',
'can_cast',
'ceil',
'clip',
'complex128',
'complex64',
'concat',
Expand Down Expand Up @@ -233,5 +234,27 @@ def test_array_namespace_method(self):
self.assertIs(x.__array_namespace__(), array_api)


class ArrayAPIErrors(absltest.TestCase):
"""Test that our array API implementations raise errors where required"""

# TODO(micky774): Remove when jnp.clip deprecation is completed
# (began 2024-4-2) and default behavior is Array API 2023 compliant
def test_clip_complex(self):
x = array_api.arange(5, dtype=array_api.complex64)
complex_msg = "Complex values have no ordering and cannot be clipped"
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x)

with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=x)

x = array_api.arange(5, dtype=array_api.int32)
with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, min=-1+5j)

with self.assertRaisesRegex(ValueError, complex_msg):
array_api.clip(x, max=-1+5j)


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion tests/lax_metal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
Expand Down
44 changes: 39 additions & 5 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,14 +872,45 @@ def testClipStaticBounds(self, shape, dtype, a_min, a_max):
a_max = None if a_max is None else abs(a_max)
rng = jtu.rand_default(self.rng())
np_fun = lambda x: np.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, a_min=a_min, a_max=a_max)
jnp_fun = lambda x: jnp.clip(x, min=a_min, max=a_max)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

def testClipError(self):
with self.assertRaisesRegex(ValueError, "At most one of a_min and a_max.*"):
jnp.clip(jnp.zeros((3,)))

@jtu.sample_product(
shape=all_shapes,
dtype=default_dtypes + unsigned_dtypes,
)
def testClipNone(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
self.assertArraysEqual(jnp.clip(x), x)


# TODO(micky774): Check for ValueError instead of DeprecationWarning when
# jnp.clip deprecation is completed (began 2024-4-2) and default behavior is
# Array API 2023 compliant
@jtu.sample_product(shape=all_shapes)
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
@jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion
def testClipComplexInputDeprecation(self, shape):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype=jnp.complex64)
msg = "Complex values have no ordering and cannot be clipped"
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x)

with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=x)

x = rng(shape, dtype=jnp.int32)
with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, min=-1+5j)

with self.assertWarns(DeprecationWarning, msg=msg):
jnp.clip(x, max=jnp.array([-1+5j]))


@jtu.sample_product(
[dict(shape=shape, dtype=dtype)
Expand Down Expand Up @@ -5772,7 +5803,7 @@ def testWrappedSignaturesMatch(self):
'argpartition': ['kind', 'order'],
'asarray': ['like'],
'broadcast_to': ['subok'],
'clip': ['kwargs'],
'clip': ['kwargs', 'out'],
'copy': ['subok'],
'corrcoef': ['ddof', 'bias', 'dtype'],
'cov': ['dtype'],
Expand Down Expand Up @@ -5809,6 +5840,9 @@ def testWrappedSignaturesMatch(self):
}

extra_params = {
# TODO(micky774): Remove when np.clip has adopted the Array API 2023
# standard
'clip': ['x', 'max', 'min'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'take_along_axis': ['mode'],
Expand Down