Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add an exp2 primitive and lax.exp2 #16883

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/internal_test_util/lax_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def lax_ops():
),
op_record("is_finite", 1, float_dtypes, test_util.rand_small),
op_record("exp", 1, float_dtypes + complex_dtypes, test_util.rand_small),
op_record("exp2", 1, float_dtypes + complex_dtypes, test_util.rand_small),
# TODO(b/142975473): on CPU, expm1 for float64 is only accurate to ~float32
# precision.
op_record(
Expand Down
15 changes: 13 additions & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,10 @@ def exp(x: ArrayLike) -> Array:
r"""Elementwise exponential: :math:`e^x`."""
return exp_p.bind(x)

def exp2(x: ArrayLike) -> Array:
r"""Elementwise base-2 exponential: :math:`2^x`."""
return exp2_p.bind(x)

def expm1(x: ArrayLike) -> Array:
r"""Elementwise :math:`e^{x} - 1`."""
return expm1_p.bind(x)
Expand Down Expand Up @@ -1757,10 +1761,17 @@ def _round_lower(ctx, x, *, rounding_method):

exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
# For exp_p it is more efficient to use the reconstructed output for the vjp
epiqueras marked this conversation as resolved.
Show resolved Hide resolved
# rule instead of computing it again from the input.
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))

exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
def _exp2_lower(ctx, x):
x_aval, = ctx.avals_in
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
return hlo.ExpOp(hlo.MulOp(log2, x).result).results
mlir.register_lowering(exp2_p, _exp2_lower)

log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
Expand Down
1 change: 1 addition & 0 deletions jax/_src/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def round(x):
is_finite = np.isfinite

exp = np.exp
exp2 = np.exp2
expm1 = np.expm1
log = np.log
log1p = np.log1p
Expand Down
3 changes: 2 additions & 1 deletion jax/_src/numpy/ufuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,9 @@ def log10(x: ArrayLike, /) -> Array:
@_wraps(np.exp2, module='numpy')
@partial(jit, inline=True)
def exp2(x: ArrayLike, /) -> Array:
assert False
x, = promote_args_inexact("exp2", x)
return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x))
return lax.exp2(x)


@_wraps(np.signbit, module='numpy')
Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1575,6 +1575,8 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],

tf_impl_with_avals[lax.integer_pow_p] = _integer_pow
tf_impl[lax.exp_p] = tf.math.exp
tf_impl[lax_internal.exp2_p] = lambda x: \
tf.math.exp(tf.math.multiply(tf.math.log(tf.constant(2, x.dtype)), x))
tf_impl[lax.expm1_p] = tf.math.expm1
tf_impl[lax.log_p] = tf.math.log
tf_impl[lax.log1p_p] = tf.math.log1p
Expand Down
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
eq_p as eq_p,
exp as exp,
exp_p as exp_p,
exp2 as exp2,
expand_dims as expand_dims,
expm1 as expm1,
expm1_p as expm1_p,
Expand Down
6 changes: 5 additions & 1 deletion tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):

grad_test_spec(lax.exp, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.exp2, nargs=1, order=2, rng_factory=jtu.rand_small,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.expm1, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes),
grad_test_spec(lax.log, nargs=1, order=2, rng_factory=jtu.rand_positive,
Expand All @@ -79,7 +81,7 @@ def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
grad_test_spec(lax.cosh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=1e-5),
grad_test_spec(lax.tanh, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol=1e-4),
dtypes=grad_inexact_dtypes, tol=2e-4),
grad_test_spec(lax.sin, nargs=1, order=2, rng_factory=jtu.rand_default,
dtypes=grad_inexact_dtypes, tol={np.float32: 5e-1}),
grad_test_spec(lax.cos, nargs=1, order=2, rng_factory=jtu.rand_default,
Expand Down Expand Up @@ -213,6 +215,8 @@ def testOpGrad(self, op, rng_factory, shapes, dtype, order, tol):
raise SkipTest("pow grad imprecise on tpu")
if op is lax.cos:
order = 1 # 2nd-order gradient is imprecise on TPU.
if op is lax.log:
order = 1 # 2nd-order gradient is imprecise on TPU.

tol = jtu.join_tolerance(1.5e-1, tol) if jtu.num_float_bits(dtype) == 32 else tol
args = tuple(rng(shape, dtype) for shape in shapes)
Expand Down
Loading