Skip to content

Commit

Permalink
Merge pull request #20825 from jakevdp:gammasgn
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626347182
  • Loading branch information
jax authors committed Apr 19, 2024
2 parents 41fa67c + 568db10 commit c7517b8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ jax.scipy.special
gammainc
gammaincc
gammaln
gammasgn
hyp1f1
i0
i0e
Expand Down
21 changes: 18 additions & 3 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,22 @@ def gammaln(x: ArrayLike) -> Array:
return lax.lgamma(x)


def _gamma_sign(x: Array) -> Array:
def gammasgn(x: ArrayLike) -> Array:
"""Sign of the gamma function.
JAX implementation of :func:`scipy.special.gammasgn`.
Args:
x: arraylike, real valued.
Returns:
array containing 1.0 where gamma(x) is positive, and -1.0 where
gamma(x) is negative.
See Also:
:func:`jax.scipy.special.gamma`
"""
x, = promote_args_inexact("gammasgn", x)
floor_x = lax.floor(x)
return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0)

Expand All @@ -53,7 +68,7 @@ def _gamma_sign(x: Array) -> Array:
The JAX version only accepts real-valued inputs.""")
def gamma(x: ArrayLike) -> Array:
x, = promote_args_inexact("gamma", x)
return _gamma_sign(x) * lax.exp(lax.lgamma(x))
return gammasgn(x) * lax.exp(lax.lgamma(x))

betaln = implements(
osp_special.betaln,
Expand All @@ -73,7 +88,7 @@ def factorial(n: ArrayLike, exact: bool = False) -> Array:
@implements(osp_special.beta, module='scipy.special')
def beta(x: ArrayLike, y: ArrayLike) -> Array:
x, y = promote_args_inexact("beta", x, y)
sign = _gamma_sign(x) * _gamma_sign(y) * _gamma_sign(x + y)
sign = gammasgn(x) * gammasgn(y) * gammasgn(x + y)
return sign * lax.exp(betaln(x, y))


Expand Down
1 change: 1 addition & 0 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
gammainc as gammainc,
gammaincc as gammaincc,
gammaln as gammaln,
gammasgn as gammasgn,
gamma as gamma,
i0 as i0,
i0e as i0e,
Expand Down
3 changes: 3 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t
op_record(
"gammaincc", 2, float_dtypes, jtu.rand_positive, True
),
op_record(
"gammasgn", 1, float_dtypes, jtu.rand_default, True
),
op_record(
"erf", 1, float_dtypes, jtu.rand_small_positive, True
),
Expand Down

0 comments on commit c7517b8

Please sign in to comment.