Skip to content

Commit

Permalink
#sdy Move Shardy mesh lift inlining pass after verification.
Browse files Browse the repository at this point in the history
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Jan 8, 2025
1 parent 0389d61 commit 5c097c8
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,10 +1205,6 @@ def lower_jaxpr_to_module(
arg_layouts=in_layouts,
result_layouts=out_layouts,
propagated_out_mem_kinds=propagated_out_mem_kinds)
if config.use_shardy_partitioner.value:
pipeline = passmanager.PassManager.parse(
'builtin.module(sdy-lift-inlined-meshes)')
pipeline.run(ctx.module.operation)

try:
if not ctx.module.operation.verify():
Expand All @@ -1227,6 +1223,12 @@ def emit_diagnostic_info(d):
raise ValueError("\n".join(msg_lines) + "\n" +
dump_module_message(ctx.module, "verification")) from e

if config.use_shardy_partitioner.value:
with ctx.context:
pipeline = passmanager.PassManager.parse(
'builtin.module(sdy-lift-inlined-meshes)')
pipeline.run(ctx.module.operation)

return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks,
ctx.shape_poly_state)

Expand Down

0 comments on commit 5c097c8

Please sign in to comment.