Skip to content

Commit

Permalink
add lax.exp2_p -> tl.math.exp2
Browse files Browse the repository at this point in the history
part of #204
  • Loading branch information
mattjj committed Jul 28, 2023
1 parent 46991ed commit 16c2fb2
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jax_triton/pallas/triton_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from jax._src import pjit
from jax._src import state
from jax._src import util
from jax._src.lax import lax as lax_internal
from jax._src.lax.control_flow import for_loop
from jax._src.lib import gpu_triton as triton_kernel_call_lib
from jax._src.lib.mlir import ir
Expand Down Expand Up @@ -384,6 +385,13 @@ def _exp_lowering_rule(ctx: TritonLoweringRuleContext, a):
triton_lowering_rules[lax.exp_p] = _exp_lowering_rule


def _exp2_lowering_rule(ctx: TritonLoweringRuleContext, a):
return tl.math.exp2(a, _builder=ctx.builder)


triton_lowering_rules[lax_internal.exp2_p] = _exp2_lowering_rule


def _log_lowering_rule(ctx: TritonLoweringRuleContext, a):
return tl.log(a, _builder=ctx.builder)

Expand Down

0 comments on commit 16c2fb2

Please sign in to comment.