diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fab6fc1cce3..9073caea2f45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ Remember to align the itemized text with the first line of an item within a list * In {func}`jax.jit`, passing invalid `static_argnums` or `static_argnames` now leads to an error rather than a warning. * The minimum jaxlib version is now 0.4.23. + * The {func}`jax.numpy.hypot` function now issues a deprecation warning when + passing complex-valued inputs to it. This will raise an error when the + deprecation is completed. ## jaxlib 0.4.27 diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index a9e956a7db11..ca2ff69257b4 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -22,6 +22,7 @@ import operator from textwrap import dedent from typing import Any, Callable, overload +import warnings import numpy as np @@ -730,12 +731,22 @@ def heaviside(x1: ArrayLike, x2: ArrayLike, /) -> Array: @implements(np.hypot, module='numpy') @jit def hypot(x1: ArrayLike, x2: ArrayLike, /) -> Array: - check_arraylike("hypot", x1, x2) - x1, x2 = promote_dtypes_inexact(x1, x2) - x1 = lax.abs(x1) - x2 = lax.abs(x2) + x1, x2 = promote_args_inexact("hypot", x1, x2) + + # TODO(micky774): Promote to ValueError when deprecation is complete + # (began 2024-4-14). + if dtypes.issubdtype(x1.dtype, np.complexfloating): + warnings.warn( + "Passing complex-valued inputs to hypot is deprecated and will raise a " + "ValueError in the future. Please convert to real values first, such as " + "by using jnp.real or jnp.imag to take the real or imaginary components " + "respectively.", + DeprecationWarning, stacklevel=2) + x1, x2 = lax.abs(x1), lax.abs(x2) + idx_inf = lax.bitwise_or(isposinf(x1), isposinf(x2)) x1, x2 = maximum(x1, x2), minimum(x1, x2) - return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax._ones(x1), x1))))) + x = _where(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, _where(x1 == 0, lax._ones(x1), x1))))) + return _where(idx_inf, _lax_const(x, np.inf), x) @implements(np.reciprocal, module='numpy') diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 876b17dbe8b4..14405e67f3b9 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -125,6 +125,7 @@ floor_divide as floor_divide, greater as greater, greater_equal as greater_equal, + hypot as hypot, imag as imag, isfinite as isfinite, isinf as isinf, diff --git a/jax/experimental/array_api/_elementwise_functions.py b/jax/experimental/array_api/_elementwise_functions.py index c34e9d93cfb0..f6f184dcf726 100644 --- a/jax/experimental/array_api/_elementwise_functions.py +++ b/jax/experimental/array_api/_elementwise_functions.py @@ -13,6 +13,7 @@ # limitations under the License. import jax +from jax._src.dtypes import issubdtype from jax.experimental.array_api._data_type_functions import ( result_type as _result_type, isdtype as _isdtype, @@ -214,6 +215,20 @@ def greater_equal(x1, x2, /): return jax.numpy.greater_equal(x1, x2) +def hypot(x1, x2, /): + """Computes the square root of the sum of squares for each element x1_i of the input array x1 with the respective element x2_i of the input array x2.""" + x1, x2 = _promote_dtypes("hypot", x1, x2) + + # TODO(micky774): Remove when jnp.hypot deprecation is completed + # (began 2024-4-14) and default behavior is Array API 2023 compliant + if issubdtype(x1.dtype, jax.numpy.complexfloating): + raise ValueError( + "hypot does not support complex-valued inputs. Please convert to real " + "values first, such as by using jnp.real or jnp.imag to take the real " + "or imaginary components respectively.") + return jax.numpy.hypot(x1, x2) + + def imag(x, /): """Returns the imaginary component of a complex number for each element x_i of the input array x.""" x, = _promote_dtypes("imag", x) diff --git a/tests/array_api_test.py b/tests/array_api_test.py index 9871c100b3ec..f4dcfb74cc78 100644 --- a/tests/array_api_test.py +++ b/tests/array_api_test.py @@ -89,6 +89,7 @@ 'full_like', 'greater', 'greater_equal', + 'hypot', 'iinfo', 'imag', 'inf', diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index c1c04935f5fe..79866d8ee22f 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -221,7 +221,7 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, op_record("fmod", 2, default_dtypes, all_shapes, jtu.rand_some_nan, []), op_record("heaviside", 2, default_dtypes, all_shapes, jtu.rand_default, [], inexact=True), - op_record("hypot", 2, default_dtypes, all_shapes, jtu.rand_default, [], + op_record("hypot", 2, real_dtypes, all_shapes, jtu.rand_default, [], inexact=True), op_record("kron", 2, number_dtypes, nonempty_shapes, jtu.rand_default, []), op_record("outer", 2, number_dtypes, all_shapes, jtu.rand_default, []), diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c957ed669b3a..24bd01247c97 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -914,6 +914,26 @@ def testClipComplexInputDeprecation(self, shape): jnp.clip(x, max=jnp.array([-1+5j])) + # TODO(micky774): Check for ValueError instead of DeprecationWarning when + # jnp.hypot deprecation is completed (began 2024-4-2) and default behavior is + # Array API 2023 compliant + @jtu.sample_product(shape=all_shapes) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. + @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion + def testHypotComplexInputDeprecation(self, shape): + rng = jtu.rand_default(self.rng()) + x = rng(shape, dtype=jnp.complex64) + msg = "Passing complex-valued inputs to hypot" + # jit is disabled so we don't miss warnings due to caching. + with jax.disable_jit(): + with self.assertWarns(DeprecationWarning, msg=msg): + jnp.hypot(x, x) + + with self.assertWarns(DeprecationWarning, msg=msg): + y = jnp.ones_like(x) + jnp.hypot(x, y) + + @jtu.sample_product( [dict(shape=shape, dtype=dtype) for shape, dtype in _shape_and_dtypes(all_shapes, number_dtypes)],