Commit
PR #18254: [XLA:GPU] Simplifies All-Reduce if they can be simplified or are degenerated.

Imported from GitHub PR #18254

This PR adds the AllReduceSimplifier pass back to the GPU compiler, because in some cases JAX programs add degenerate psum operations to the HLO. Code pointers in JAX:
1) https://cs.opensource.google/jax/jax/+/main:jax/experimental/shard_map.py;drc=875f44c63a6c0f082ef3b228559e86e1d0a398d1;l=1674
2) pbroadcast to psum conversions.

The compiler should generally remove these unnecessary all-reduces (ARs), since doing so unblocks many optimizations such as the collective pipeliner. Leaving them in also causes unnecessary runtime thunks to be emitted.

Copybara import of the project:

--
b89df55 by Yunlong Liu <yunlongl@x.ai>:

Removes degenerated all reduce.

--
055c409 by Yunlong Liu <yunlongl@x.ai>:

Creates a new e2e test target for GPU.

--
93934eb by Yunlong Liu <yunlongl@x.ai>:

E2E test passed.

Merging this change closes #18254

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18254 from yliu120:ar-simplify 93934eb
PiperOrigin-RevId: 685835366
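
To make "degenerate all-reduce" concrete, here is a minimal sketch in plain Python. It is a toy model, not XLA's AllReduceSimplifier: the `Param` and `AllReduce` classes and the `is_degenerate` and `simplify` helpers are hypothetical names introduced only for illustration. It assumes that an all-reduce is degenerate when every replica group contains exactly one participant, so the reduction combines nothing and the op can be replaced by its operand.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Param:
    """Hypothetical leaf node standing in for an HLO parameter."""
    name: str


@dataclass(frozen=True)
class AllReduce:
    """Hypothetical all-reduce node; replica_groups is a tuple of device-id tuples."""
    operand: "Node"
    replica_groups: Tuple[Tuple[int, ...], ...]


Node = Union[Param, AllReduce]


def is_degenerate(ar: AllReduce) -> bool:
    # Assumption: an empty group list means "all devices in one group",
    # so it is not treated as degenerate in this toy model.
    return len(ar.replica_groups) > 0 and all(
        len(group) == 1 for group in ar.replica_groups
    )


def simplify(node: Node) -> Node:
    """Recursively drop all-reduces whose groups each hold a single device."""
    if isinstance(node, AllReduce):
        operand = simplify(node.operand)
        if is_degenerate(node):
            return operand  # reducing over one participant is a no-op
        return AllReduce(operand, node.replica_groups)
    return node
```

For example, `simplify(AllReduce(Param("x"), ((0,), (1,))))` returns `Param("x")`, while an all-reduce over the group `((0, 1),)` is left in place because it still combines values from two devices.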