From 5c097c8f62cc1233cb6ffa85ed4aefe310075706 Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Wed, 8 Jan 2025 09:17:16 -0800 Subject: [PATCH] #sdy Move Shardy mesh lift inlining pass after verification. Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example ``` File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module pipeline.run(ctx.module.operation) MLIRError: Failure while executing pass pipeline: error: ... 'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2 ... see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64> ``` PiperOrigin-RevId: 713314555 --- jax/_src/interpreters/mlir.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 8e3372113202..f552a779a844 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1205,10 +1205,6 @@ def lower_jaxpr_to_module( arg_layouts=in_layouts, result_layouts=out_layouts, propagated_out_mem_kinds=propagated_out_mem_kinds) - if config.use_shardy_partitioner.value: - pipeline = passmanager.PassManager.parse( - 'builtin.module(sdy-lift-inlined-meshes)') - pipeline.run(ctx.module.operation) try: if not ctx.module.operation.verify(): @@ -1227,6 +1223,12 @@ def emit_diagnostic_info(d): raise ValueError("\n".join(msg_lines) + "\n" + dump_module_message(ctx.module, "verification")) from e + if config.use_shardy_partitioner.value: + with ctx.context: + pipeline = passmanager.PassManager.parse( + 'builtin.module(sdy-lift-inlined-meshes)') + pipeline.run(ctx.module.operation) + return LoweringResult(ctx.module, ctx.keepalives, ctx.host_callbacks, ctx.shape_poly_state)