Skip to content

Commit

Permalink
Include source info as ir.Locations when lowering Pallas kernels on GPU
Browse files Browse the repository at this point in the history
For simplicity, I decided to leave out the name stacks for now, but we might
want to add them in the future.

PiperOrigin-RevId: 614644216
  • Loading branch information
superbobry authored and jax authors committed Mar 11, 2024
1 parent 477a5aa commit 7863508
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
17 changes: 13 additions & 4 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import source_info_util
from jax._src import state
from jax._src import util
from jax._src.interpreters import mlir
Expand All @@ -44,7 +45,6 @@
from jax._src.pallas import core as pallas_core
from jax._src.pallas import primitives
from jax._src.pallas import utils as pallas_utils
from jax._src.state import AbstractRef
from jax._src.state import discharge
from jax._src.state import indexing
from jax._src.state import primitives as sp
Expand Down Expand Up @@ -73,6 +73,7 @@ class ModuleContext:
name: str
grid_mapping: GridMapping
program_ids: Sequence[ir.Value]
traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False)


@dataclasses.dataclass
Expand Down Expand Up @@ -269,7 +270,9 @@ def lower_jaxpr_to_triton_module(
for i, pid in enumerate(program_ids)
if i not in grid_mapping.mapped_dims
]
ctx = ModuleContext(name, grid_mapping, local_program_ids)
ctx = ModuleContext(
name, grid_mapping, local_program_ids, mlir.TracebackCaches()
)
if grid_mapping.num_index_operands:
raise NotImplementedError(
"Scalar prefetch not supported in Triton lowering."
Expand Down Expand Up @@ -336,9 +339,13 @@ def write_env(var: jax_core.Var, val):
avals_in = [v.aval for v in eqn.invars]
avals_out = [v.aval for v in eqn.outvars]
eqn_block_infos = map(read_block_info_env, eqn.invars)
loc = mlir._source_info_to_location(
ctx, eqn.primitive, eqn.params, eqn.source_info
)
rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos)
try:
outvals = rule(rule_ctx, *invals, **eqn.params)
with source_info_util.user_context(eqn.source_info.traceback), loc:
outvals = rule(rule_ctx, *invals, **eqn.params)
except LoweringError:
raise # We only add the extra info to the innermost exception.
except Exception as e:
Expand Down Expand Up @@ -2039,7 +2046,9 @@ def _for_lowering_rule(
step = _i32_constant(1)
init_args = map(_ensure_ir_value, args, ctx.avals_in)
# Partially discharge state from jaxpr for non-pointers
should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in]
should_discharge = [
not isinstance(a, state.AbstractRef) for a in ctx.avals_in
]
discharged_jaxpr, () = discharge.discharge_state(
jaxpr, (), should_discharge=[True, *should_discharge]
)
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/pallas/triton/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ def _pallas_call_ttir_lowering(
lowering_result = lowering.lower_jaxpr_to_triton_module(
jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options
)
module_op = lowering_result.module.operation
if debug:
print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True))
lowering_result.module.dump()

grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid)
Expand All @@ -214,7 +216,7 @@ def _pallas_call_ttir_lowering(
for shape in out_shapes
]
buf = io.BytesIO()
lowering_result.module.operation.write_bytecode(buf)
module_op.write_bytecode(buf)
backend_config = dict(
name=ir.StringAttr.get(name),
ir=ir.StringAttr.get(buf.getvalue()),
Expand Down

0 comments on commit 7863508

Please sign in to comment.