Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#sdy Move Shardy mesh lift inlining pass after verification.
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example ``` File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module pipeline.run(ctx.module.operation) MLIRError: Failure while executing pass pipeline: error: ... 'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2 ... see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64> ``` PiperOrigin-RevId: 713314555
- Loading branch information