Understanding the source code of jax.lax | tanh #24240

Answered by dfm
Rohanjames1997 asked this question in Q&A

It's not so straightforward to follow the full stack, but the basic flow is that a JAX function gets lowered to StableHLO, which you can inspect as follows:

import jax
import jax.numpy as jnp

print(jax.jit(jnp.tanh).lower(jnp.linspace(-1, 1, 5)).as_text())
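Because the question is about jax.lax.tanh specifically, it can help to look one level above StableHLO as well. The following is a minimal sketch, assuming a recent JAX release. It uses jax.make_jaxpr to print the traced JAX program (the jaxpr) for jnp.tanh, and then confirms that jax.lax.tanh binds a single primitive named tanh. The exact printed form of the jnp.tanh jaxpr can vary between JAX versions, so no output is claimed here.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1, 1, 5)

# A jaxpr is JAX's own intermediate representation, one level above
# StableHLO. Tracing jnp.tanh shows which primitive it ends up binding.
print(jax.make_jaxpr(jnp.tanh)(x))

# jax.lax.tanh binds its primitive directly, so its jaxpr has a single
# equation; collect the primitive names to check.
lax_jaxpr = jax.make_jaxpr(jax.lax.tanh)(x)
prims = [eqn.primitive.name for eqn in lax_jaxpr.jaxpr.eqns]
print(prims)
```

It is this tanh primitive that then gets lowered to the stablehlo.tanh op.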

Since StableHLO has a tanh intrinsic, the intermediate representation is (relatively) simple:

module @jit_tanh attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<5xf32> {mhlo.layout_mode = "default"}) -> (tensor<5xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.tanh %arg0 : tensor<5xf32>
    return %0 : tensor<5xf32>
  }
}
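StableHLO is not the end of the pipeline. The lowered object can also be compiled for whatever backend JAX is running on, and the optimized HLO that XLA produced can be printed. The sketch below uses the ahead-of-time API in jax.stages (Lowered.compile and Compiled.as_text). Assume the printed HLO differs by backend and JAX version, and that as_text may return None where a backend does not support it.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(-1, 1, 5)

lowered = jax.jit(jnp.tanh).lower(x)
# The StableHLO module handed to the compiler (the module shown above).
stablehlo_text = lowered.as_text()

# Compile for the default backend and print the optimized HLO that XLA
# produced for it; its contents depend on the backend.
compiled = lowered.compile()
print(jax.default_backend())
print(compiled.as_text())

# The compiled executable can be called directly and still computes tanh.
result = compiled(x)
print(result)
```

Comparing this output across CPU and GPU is one way to see where the backend-specific handling of tanh happens.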

Then, depending on…

Answer selected by Rohanjames1997