Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #18254: [XLA:GPU] Simplifies All-Reduce if they can be simplified or are degenerated. #18297

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

copybara-service[bot]
Copy link

PR #18254: [XLA:GPU] Simplifies All-Reduce if they can be simplified or are degenerated.

Imported from GitHub PR #18254

This PR adds back the AllReduceSimplifier pass to the GPU compiler as in some cases we see JAX programs will add some degenerated psum to the HLOs,

Code pointer in JAX:

  1. https://cs.opensource.google/jax/jax/+/main:jax/experimental/shard_map.py;drc=875f44c63a6c0f082ef3b228559e86e1d0a398d1;l=1674
  2. pbroadcast to psum conversions.

The compiler should generally remove these unnecessary ARs so that it unblocks many optimizations like collective pipeliner. It also emits unncessary runtime thunks as well.
Copybara import of the project:

--
b89df55 by Yunlong Liu yunlongl@x.ai:

Removes degenerated all reduce.

--
055c409 by Yunlong Liu yunlongl@x.ai:

Creates a new e2e test target for GPU.

--
93934eb by Yunlong Liu yunlongl@x.ai:

E2E test passed.

Merging this change closes #18254

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18254 from yliu120:ar-simplify 93934eb

@copybara-service copybara-service bot force-pushed the test_685835366 branch 11 times, most recently from 84d548b to 8d42868 Compare October 21, 2024 09:15
…or are degenerated.

Imported from GitHub PR #18254

This PR adds back the AllReduceSimplifier pass to the GPU compiler as in some cases we see JAX programs will add some degenerated psum to the HLOs,

Code pointer in JAX:
1) https://cs.opensource.google/jax/jax/+/main:jax/experimental/shard_map.py;drc=875f44c63a6c0f082ef3b228559e86e1d0a398d1;l=1674
2) pbroadcast to psum conversions.

The compiler should generally remove these unnecessary ARs so that it unblocks many optimizations like collective pipeliner. It also emits unncessary runtime thunks as well.
Copybara import of the project:

--
b89df55 by Yunlong Liu <yunlongl@x.ai>:

Removes degenerated all reduce.

--
055c409 by Yunlong Liu <yunlongl@x.ai>:

Creates a new e2e test target for GPU.

--
93934eb by Yunlong Liu <yunlongl@x.ai>:

E2E test passed.

Merging this change closes #18254

FUTURE_COPYBARA_INTEGRATE_REVIEW=#18254 from yliu120:ar-simplify 93934eb
PiperOrigin-RevId: 685835366
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant