[JAX] Wrapper for Permutation Triton kernel #2419
base: main
Conversation
Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci L0
Greptile Summary
Adds JAX-Triton wrappers for the permutation kernels to enable MOE support on a single GPU, completing step 2 of the multi-step implementation plan.
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant JAX as JAX Wrapper<br/>(jax/triton/permutation.py)
    participant JT as jax_triton
    participant Triton as Triton Kernel<br/>(common/triton/permutation.py)
    participant GPU as GPU
    Note over User,GPU: MOE Token Permutation Flow
    User->>JAX: make_row_id_map(routing_map, num_tokens, num_experts)
    JAX->>JAX: Compute strides
    JAX->>JT: triton_call(_row_id_map_pass_1_kernel)
    JT->>Triton: Execute Pass 1 (block cumsum)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map_pass1, workspace
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map_pass1, workspace
    JAX->>JT: triton_call(_row_id_map_pass_2_kernel)
    JT->>Triton: Execute Pass 2 (cumsum all)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map_pass2
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map_pass2
    JAX->>JAX: Initialize columns [num_experts:] to -1
    JAX->>JT: triton_call(_row_id_map_pass_3_kernel)
    JT->>Triton: Execute Pass 3 (sparse to dense)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map (final)
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map
    JAX-->>User: Return row_id_map
    User->>JAX: permute_with_mask_map(inp, row_id_map, probs, ...)
    JAX->>JAX: Compute strides & create dummy tensors
    JAX->>JT: triton_call(_permute_kernel)
    JT->>Triton: Execute permutation
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: output, permuted_probs
    Triton-->>JT: Return results
    JT-->>JAX: output, permuted_probs
    JAX-->>User: Return output, permuted_probs
    User->>JAX: unpermute_with_mask_map(inp, row_id_map, merging_probs, ...)
    JAX->>JAX: Compute strides & create dummy tensors
    JAX->>JT: triton_call(_unpermute_kernel)
    JT->>Triton: Execute unpermutation
    Triton->>GPU: Parallel kernel execution (accumulate)
    GPU-->>Triton: output, unpermuted_probs
    Triton-->>JT: Return results
    JT-->>JAX: output, unpermuted_probs
    JAX-->>User: Return output, unpermuted_probs
```
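The three-pass row ID map construction above can be sketched as a dense pure-JAX reference (the real kernels run in Triton via jax_triton; the function name `reference_row_id_map` and the expert-major destination layout are assumptions for illustration, not the PR's actual code):

```python
import jax.numpy as jnp

def reference_row_id_map(routing_map):
    # Hypothetical reference. routing_map: (num_tokens, num_experts) 0/1 mask
    # of token-to-expert routing. Returns a (num_tokens, num_experts) map whose
    # entry [t, e] is the destination row of token t in the permuted output,
    # or -1 if token t is not routed to expert e.
    mask = routing_map.astype(jnp.int32)
    # Pass 1/2 analogue: cumulative count of routed tokens within each expert.
    within_expert = jnp.cumsum(mask, axis=0) - 1
    # Start row of each expert's contiguous slice (expert-major layout assumed).
    expert_offsets = jnp.concatenate(
        [jnp.zeros((1,), jnp.int32), jnp.cumsum(mask.sum(axis=0))[:-1]]
    )
    dest = within_expert + expert_offsets[None, :]
    # Pass 3 analogue: dense map with -1 marking unrouted (token, expert) pairs.
    return jnp.where(mask == 1, dest, -1)
```

With three tokens and two experts, a token routed to both experts gets one destination row in each expert's slice, and unrouted pairs stay at -1.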
4 files reviewed, no comments
Which?
To eliminate pytorch failures
/te-ci L0 pytorch
For anyone else looking at this who would like to try:
Description
Step 2 in a multi-step process to have JAX execute Triton kernels to support MOE on a single GPU.
Steps:
Fixes # (issue)
Type of change
Changes
1/ Write jax-triton calls for each permutation kernel in common/triton
2/ Add test_permutation to JAX, covering:
- make row ID mapping
- chunk sorting
- permute
- unpermute
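The permute and unpermute cases being tested could be checked against a small pure-JAX reference like the sketch below. The helper names and the dense row-ID-map convention (destination row or -1 per (token, expert) entry) are assumptions for illustration, not the PR's actual test code:

```python
import jax.numpy as jnp

def reference_permute(inp, row_id_map):
    # Gather tokens into expert-major order using a dense row ID map.
    # inp: (num_tokens, hidden); row_id_map: (num_tokens, num_experts),
    # entry [t, e] = destination row in the permuted output, or -1.
    num_out = int((row_id_map >= 0).sum())
    tok, exp = jnp.nonzero(row_id_map >= 0)
    dest = row_id_map[tok, exp]
    out = jnp.zeros((num_out, inp.shape[1]), inp.dtype)
    # Each routed (token, expert) pair owns one unique destination row.
    return out.at[dest].set(inp[tok])

def reference_unpermute(permuted, row_id_map, num_tokens):
    # Scatter-add permuted rows back, accumulating the copies of any token
    # that was routed to multiple experts (matches the accumulate step above).
    tok, exp = jnp.nonzero(row_id_map >= 0)
    dest = row_id_map[tok, exp]
    out = jnp.zeros((num_tokens, permuted.shape[1]), permuted.dtype)
    return out.at[tok].add(permuted[dest])
```

A round trip through both helpers sums each token's per-expert copies, which is why the unpermuted output of a multi-expert token is the sum of its permuted rows rather than the original row.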
Checklist: