Skip to content

Commit

Permalink
jax.numpy: add trig aliases acos(h), asin(h), atan(h), atan2
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 19, 2023
1 parent 2b54527 commit 3df811c
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
28 changes: 19 additions & 9 deletions jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ def op(*args):
return bitwise_op(*promote_args(np_op.__name__, *args))
return op

@jit
def _arccosh(x: ArrayLike, /) -> Array:
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
out = lax.acosh(*promote_args_inexact("arccosh", x))
if dtypes.issubdtype(out.dtype, np.complexfloating):
out = _where(real(out) < 0, lax.neg(out), out)
return out

fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
Expand All @@ -159,6 +167,7 @@ def op(*args):
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
arccosh = _one_to_one_unop(np.arccosh, _arccosh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
Expand Down Expand Up @@ -189,15 +198,16 @@ def op(*args):
logical_or: BinOp = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor: BinOp = _logical_op(np.logical_xor, lax.bitwise_xor)

@_wraps(np.arccosh, module='numpy')
@jit
def arccosh(x: ArrayLike, /) -> Array:
# Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
# convention than np.arccosh.
out = lax.acosh(*promote_args_inexact("arccosh", x))
if dtypes.issubdtype(out.dtype, np.complexfloating):
out = _where(real(out) < 0, lax.neg(out), out)
return out
# Array API aliases
# TODO(jakevdp): directly reference np_fun when minimum numpy version is 2.0
acos = _one_to_one_unop(getattr(np, "acos", np.arccos), lax.acos, True)
acosh = _one_to_one_unop(getattr(np, "acosh", np.arccosh), _arccosh, True)
asin = _one_to_one_unop(getattr(np, "asin", np.arcsin), lax.asin, True)
asinh = _one_to_one_unop(getattr(np, "asinh", np.arcsinh), lax.asinh, True)
atan = _one_to_one_unop(getattr(np, "atan", np.arctan), lax.atan, True)
atanh = _one_to_one_unop(getattr(np, "atanh", np.arctanh), lax.atanh, True)
atan2 = _one_to_one_binop(getattr(np, "atan2", np.arctan2), lax.atan2, True)


@_wraps(getattr(np, 'bitwise_count', None), module='numpy')
@jit
Expand Down
7 changes: 7 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@
from jax._src.numpy.ufuncs import (
abs as abs,
absolute as absolute,
acos as acos,
acosh as acosh,
add as add,
arccos as arccos,
arccosh as arccosh,
Expand All @@ -331,6 +333,11 @@
arctan as arctan,
arctan2 as arctan2,
arctanh as arctanh,
asin as asin,
asinh as asinh,
atan as atan,
atanh as atanh,
atan2 as atan2,
bitwise_and as bitwise_and,
bitwise_count as bitwise_count,
bitwise_not as bitwise_not,
Expand Down
8 changes: 7 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ ComplexWarning: type
_deprecations: dict[str, tuple[str, Any]]
def abs(x: ArrayLike, /) -> Array: ...
def absolute(x: ArrayLike, /) -> Array: ...
def acos(x: ArrayLike, /) -> Array: ...
def acosh(x: ArrayLike, /) -> Array: ...
def add(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def amax(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: bool = ..., initial: Optional[ArrayLike] = ...,
Expand Down Expand Up @@ -99,8 +101,12 @@ array_str = _np.array_str
def asarray(
a: Any, dtype: Optional[DTypeLike] = ..., order: Optional[str] = ...
) -> Array: ...
def asin(x: ArrayLike, /) -> Array: ...
def asinh(x: ArrayLike, /) -> Array: ...
def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: bool = ...) -> Array: ...

def atan(x: ArrayLike, /) -> Array: ...
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def atanh(x: ArrayLike, /) -> Array: ...
@overload
def atleast_1d() -> list[Array]: ...
@overload
Expand Down
27 changes: 21 additions & 6 deletions tests/lax_numpy_operators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,15 @@ def _valid_dtypes_for_shape(shape, dtypes):
OpRecord = collections.namedtuple(
"OpRecord",
["name", "nargs", "dtypes", "shapes", "rng_factory", "diff_modes",
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs"])
"test_name", "check_dtypes", "tolerance", "inexact", "kwargs", "alias"])

def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name=None, check_dtypes=True,
tolerance=None, inexact=False, kwargs=None):
tolerance=None, inexact=False, kwargs=None,
alias=None):
test_name = test_name or name
return OpRecord(name, nargs, dtypes, shapes, rng_factory, diff_modes,
test_name, check_dtypes, tolerance, inexact, kwargs)
test_name, check_dtypes, tolerance, inexact, kwargs, alias)

JAX_ONE_TO_ONE_OP_RECORDS = [
op_record("abs", 1, all_dtypes,
Expand Down Expand Up @@ -168,6 +169,20 @@ def op_record(name, nargs, dtypes, shapes, rng_factory, diff_modes,
inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}),
op_record("arctanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, tolerance={np.float64: 1e-9}),
op_record("asin", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, tolerance={np.complex128: 2e-15}, alias="arcsin"),
op_record("acos", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, alias="arccos"),
op_record("atan", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, alias="arctan"),
op_record("atan2", 2, float_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, check_dtypes=False, alias="arctan2"),
op_record("asinh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance={np.complex64: 2E-4, np.complex128: 2E-14}, alias="arcsinh"),
op_record("acosh", 1, number_dtypes, all_shapes, jtu.rand_default, ["rev"],
inexact=True, tolerance={np.complex64: 2E-2, np.complex128: 2E-12}, alias="arccosh"),
op_record("atanh", 1, number_dtypes, all_shapes, jtu.rand_small, ["rev"],
inexact=True, tolerance={np.float64: 1e-9}, alias="arctanh"),
]

JAX_COMPOUND_OP_RECORDS = [
Expand Down Expand Up @@ -418,7 +433,7 @@ def f():
jtu.sample_product_testcases(
[dict(op_name=rec.name, rng_factory=rec.rng_factory,
check_dtypes=rec.check_dtypes, tolerance=rec.tolerance,
inexact=rec.inexact, kwargs=rec.kwargs or {})],
inexact=rec.inexact, kwargs=rec.kwargs or {}, alias=rec.alias)],
[dict(shapes=shapes, dtypes=dtypes)
for shapes in filter(
_shapes_are_broadcast_compatible,
Expand All @@ -430,8 +445,8 @@ def f():
JAX_COMPOUND_OP_RECORDS)))
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testOp(self, op_name, rng_factory, shapes, dtypes, check_dtypes,
tolerance, inexact, kwargs):
np_op = partial(getattr(np, op_name), **kwargs)
tolerance, inexact, kwargs, alias):
np_op = partial(getattr(np, op_name) if hasattr(np, op_name) else getattr(np, alias), **kwargs)
jnp_op = partial(getattr(jnp, op_name), **kwargs)
np_op = jtu.ignore_warning(category=RuntimeWarning,
message="invalid value.*")(np_op)
Expand Down

0 comments on commit 3df811c

Please sign in to comment.