diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 976e51a71e82..d3721f83b06b 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -32,6 +32,7 @@ from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import pjit +from jax._src import source_info_util from jax._src import state from jax._src import util from jax._src.interpreters import mlir @@ -44,7 +45,6 @@ from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils -from jax._src.state import AbstractRef from jax._src.state import discharge from jax._src.state import indexing from jax._src.state import primitives as sp @@ -73,6 +73,7 @@ class ModuleContext: name: str grid_mapping: GridMapping program_ids: Sequence[ir.Value] + traceback_caches: mlir.TracebackCaches = dataclasses.field(repr=False) @dataclasses.dataclass @@ -269,7 +270,9 @@ def lower_jaxpr_to_triton_module( for i, pid in enumerate(program_ids) if i not in grid_mapping.mapped_dims ] - ctx = ModuleContext(name, grid_mapping, local_program_ids) + ctx = ModuleContext( + name, grid_mapping, local_program_ids, mlir.TracebackCaches() + ) if grid_mapping.num_index_operands: raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." @@ -336,9 +339,13 @@ def write_env(var: jax_core.Var, val): avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) + loc = mlir._source_info_to_location( + ctx, eqn.primitive, eqn.params, eqn.source_info + ) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: - outvals = rule(rule_ctx, *invals, **eqn.params) + with source_info_util.user_context(eqn.source_info.traceback), loc: + outvals = rule(rule_ctx, *invals, **eqn.params) except LoweringError: raise # We only add the extra info to the innermost exception. except Exception as e: @@ -2039,7 +2046,9 @@ def _for_lowering_rule( step = _i32_constant(1) init_args = map(_ensure_ir_value, args, ctx.avals_in) # Partially discharge state from jaxpr for non-pointers - should_discharge = [not isinstance(a, AbstractRef) for a in ctx.avals_in] + should_discharge = [ + not isinstance(a, state.AbstractRef) for a in ctx.avals_in + ] discharged_jaxpr, () = discharge.discharge_state( jaxpr, (), should_discharge=[True, *should_discharge] ) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index aa7b90b1bc15..cb8250d95693 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -205,7 +205,9 @@ def _pallas_call_ttir_lowering( lowering_result = lowering.lower_jaxpr_to_triton_module( jaxpr, (*in_shapes, *out_shapes), grid_mapping, name, cuda_options ) + module_op = lowering_result.module.operation if debug: + print(module_op.get_asm(enable_debug_info=True, pretty_debug_info=True)) lowering_result.module.dump() grid_x, grid_y, grid_z = normalize_grid(lowering_result.grid) @@ -214,7 +216,7 @@ def _pallas_call_ttir_lowering( for shape in out_shapes ] buf = io.BytesIO() - lowering_result.module.operation.write_bytecode(buf) + module_op.write_bytecode(buf) backend_config = dict( name=ir.StringAttr.get(name), ir=ir.StringAttr.get(buf.getvalue()),