From fb56224ae08ab391c705f2596deb69a866cd01f4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 17 Jan 2024 08:59:40 -0800 Subject: [PATCH] jnp.sign: use x/abs(x) for complex arguments --- CHANGELOG.md | 4 +++- jax/_src/numpy/ufuncs.py | 13 +------------ tests/lax_numpy_operators_test.py | 21 ++++++++++++++++++++- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 12d33b922b56..6c69c3138ec3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,9 +48,11 @@ Remember to align the itemized text with the first line of an item within a list * {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices reshaped to the dimension of the input, following a similar change to {func}`numpy.unique` in NumPy 2.0. + * {func}`jax.numpy.sign` now returns `x / abs(x)` for nonzero complex inputs. This is + consistent with the behavior of {func}`numpy.sign` in NumPy version 2.0. * {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0 convention for the complex sign, `x / abs(x)`. This is consistent with the behavior - of the function in SciPy v1.13. + of {func}`scipy.special.logsumexp` in SciPy v1.13. * Deprecations & Removals * A number of previously deprecated functions have been removed, following a diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index ceb45f1898ca..40032e285f54 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -173,6 +173,7 @@ def _arccosh(x: ArrayLike, /) -> Array: arccosh = _one_to_one_unop(np.arccosh, _arccosh, True) tanh = _one_to_one_unop(np.tanh, lax.tanh, True) arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) +sign = _one_to_one_unop(np.sign, lax.sign) sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) @@ -257,18 +258,6 @@ def rint(x: ArrayLike, /) -> Array: return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) -@_wraps(np.sign, module='numpy') -@jit -def sign(x: ArrayLike, /) -> Array: - check_arraylike('sign', x) - dtype = dtypes.dtype(x) - if dtypes.issubdtype(dtype, np.complexfloating): - re = lax.real(x) - return lax.complex( - lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) - return lax.sign(x) - - @_wraps(np.copysign, module='numpy') @jit def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index c16dbb467c47..c1c04935f5fe 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -56,6 +56,7 @@ default_dtypes = float_dtypes + int_dtypes inexact_dtypes = float_dtypes + complex_dtypes number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes +real_dtypes = float_dtypes + int_dtypes + unsigned_dtypes all_dtypes = number_dtypes + bool_dtypes @@ -272,7 +273,9 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes, []), op_record("rint", 1, int_dtypes + unsigned_dtypes, all_shapes, jtu.rand_default, [], check_dtypes=False), - op_record("sign", 1, number_dtypes, all_shapes, jtu.rand_some_inf_and_nan, []), + # numpy < 2.0.0 has a different convention for complex sign. + op_record("sign", 1, real_dtypes if jtu.numpy_version() < (2, 0, 0) else number_dtypes, + all_shapes, jtu.rand_some_inf_and_nan, []), # numpy 1.16 has trouble mixing uint and bfloat16, so we test these separately. op_record("copysign", 2, default_dtypes + unsigned_dtypes, all_shapes, jtu.rand_some_inf_and_nan, [], check_dtypes=False), @@ -646,6 +649,22 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): self._CompileAndCheck(op, args_maker) self._CheckAgainstNumpy(np_op, op, args_maker) + # This test can be deleted once we test against NumPy 2.0. + @jtu.sample_product( + shape=all_shapes, + dtype=complex_dtypes + ) + def testSignComplex(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + if jtu.numpy_version() >= (2, 0, 0): + np_fun = np.sign + else: + np_fun = lambda x: (x / np.where(x == 0, 1, abs(x))).astype(np.result_type(x)) + jnp_fun = jnp.sign + args_maker = lambda: [rng(shape, dtype)] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + def testDeferToNamedTuple(self): class MyArray(NamedTuple): arr: jax.Array