diff --git a/jax_triton/pallas/triton_lowering.py b/jax_triton/pallas/triton_lowering.py index 866d8659..51146d2c 100644 --- a/jax_triton/pallas/triton_lowering.py +++ b/jax_triton/pallas/triton_lowering.py @@ -31,6 +31,7 @@ from jax._src import pjit from jax._src import state from jax._src import util +from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop from jax._src.lib import gpu_triton as triton_kernel_call_lib from jax._src.lib.mlir import ir @@ -384,6 +385,13 @@ def _exp_lowering_rule(ctx: TritonLoweringRuleContext, a): triton_lowering_rules[lax.exp_p] = _exp_lowering_rule +def _exp2_lowering_rule(ctx: TritonLoweringRuleContext, a): + return tl.math.exp2(a, _builder=ctx.builder) + + +triton_lowering_rules[lax_internal.exp2_p] = _exp2_lowering_rule + + def _log_lowering_rule(ctx: TritonLoweringRuleContext, a): return tl.log(a, _builder=ctx.builder)