
add an exp2 primitive and lax.exp2 #16883

Merged Jul 28, 2023 (1 commit)

1 change: 1 addition & 0 deletions jax/_src/internal_test_util/lax_test_util.py
@@ -167,6 +167,7 @@ def lax_ops():
      ),
      op_record("is_finite", 1, float_dtypes, test_util.rand_small),
      op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
      op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),
      # TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
      # precision.
      op_record(
15 changes: 13 additions & 2 deletions jax/_src/lax/lax.py
@@ -302,6 +302,10 @@ def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)

def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
return exp2_p.bind(x)

def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
return expm1_p.bind(x)
@@ -1757,10 +1761,17 @@ def _round_lower(ctx, x, *, rounding_method):

exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))

exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
def _exp2_lower(ctx, x):
  x_aval, = ctx.avals_in
  log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
  log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
  return hlo.ExpOp(hlo.MulOp(log2, x).result).results
mlir.register_lowering(exp2_p, _exp2_lower)

log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
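
Quick sanity check (not part of the diff; a minimal sketch assuming a JAX build where `lax.exp2` has landed) of what the primitive and the JVP rule above compute:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.float32(3.0)

# The primitive computes 2**x; the lowering above emits exp(x * log 2).
print(lax.exp2(x))                 # ~8.0

# The JVP rule above encodes d/dx 2**x = log(2) * 2**x.
print(jax.grad(lax.exp2)(x))       # ~log(2) * 8.0 ~ 5.545
print(jnp.log(2.0) * lax.exp2(x))  # same value
```
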
1 change: 1 addition & 0 deletions jax/_src/lax_reference.py
@@ -45,6 +45,7 @@ def round(x):
is_finite = np.isfinite

exp = np.exp
exp2 = np.exp2
expm1 = np.expm1
log = np.log
log1p = np.log1p
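
Since the reference implementation just reuses NumPy, the primitive can be checked directly against `np.exp2` (a sketch, again assuming a JAX build with this change):

```python
import numpy as np
from jax import lax

x = np.linspace(-3.0, 3.0, 7).astype(np.float32)
# lax.exp2 lowers to exp(x * log 2); compare against the NumPy reference.
np.testing.assert_allclose(lax.exp2(x), np.exp2(x), rtol=1e-5)
```
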
3 changes: 2 additions & 1 deletion jax/_src/numpy/ufuncs.py
@@ -429,8 +429,9 @@ def log10(x: ArrayLike, /) -> Array:
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
  assert False
  x, = promote_args_inexact("exp2", x)
  return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
  return lax.exp2(x)


@_wraps(np.signbit, module='numpy')
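
With `jnp.exp2` dispatching to the primitive, a single `exp2` op shows up in the jaxpr instead of the old log/mul/exp composition. A sketch, assuming a JAX build where `jnp.exp2` dispatches cleanly to the primitive (the printed form may differ slightly between versions):

```python
import jax
import jax.numpy as jnp

print(jax.make_jaxpr(jnp.exp2)(jnp.float32(1.0)))
# roughly: { lambda ; a:f32[]. let b:f32[] = exp2 a in (b,) }
```
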
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1575,6 +1575,8 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],

tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax_internal.exp2_p] = lambda x: \
    tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
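
The TF rule above means converted functions that use `jnp.exp2` / `lax.exp2` keep round-tripping through jax2tf. A minimal sketch (assumes TensorFlow is installed):

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

# exp2 lowers to tf.exp(log(2) * x) via the rule registered above.
f_tf = jax2tf.convert(jnp.exp2)
print(f_tf(tf.constant(3.0)))  # ~8.0
```
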
1 change: 1 addition & 0 deletions jax/lax/__init__.py
@@ -93,6 +93,7 @@
  eq_p as eq_p,
  exp as exp,
  exp_p as exp_p,
  exp2 as exp2,
  expand_dims as expand_dims,
  expm1 as expm1,
  expm1_p as expm1_p,
6 changes: 5 additions & 1 deletion tests/lax_autodiff_test.py
@@ -68,6 +68,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):

    grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
@@ -79,7 +81,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
    grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes, tol=1e-5),
    grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes, tol=1e-4),
                   dtypes=grad_inexact_dtypes, tol=2e-4),
    grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
    grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
@@ -213,6 +215,8 @@ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
      raise SkipTest("pow grad imprecise on tpu")
    if op is lax.cos:
      order = 1  # 2nd-order gradient is imprecise on TPU.
    if op is lax.log:
      order = 1  # 2nd-order gradient is imprecise on TPU.

    tol = jtu.join_tolerance(1.5e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
    args = tuple(rng(shape, dtype) for shape in shapes)
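
The new `grad_test_spec` entry exercises first- and second-order derivatives of `lax.exp2` numerically; for reference, the second derivative has the closed form d^2/dx^2 2**x = log(2)**2 * 2**x. A hand-run sketch (not part of the test suite):

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.5)
print(jax.grad(jax.grad(lax.exp2))(x))  # second-order gradient
print(jnp.log(2.0) ** 2 * lax.exp2(x))  # closed form: log(2)**2 * 2**x
```
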