
Conversation


@tdophung tdophung commented Nov 25, 2025

Description

Step 2 of a multi-step effort to have JAX execute Triton kernels, enabling MoE support on a single GPU.
Steps:

  • Move Triton kernels to common
  • Use jax-triton to call the Triton kernels
  • Write a JAX Primitive for this op.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

1/ Add jax-triton calls for each permutation kernel in common/triton
2/ Add test_permutation to JAX, covering:
- make row ID mapping
- chunk sorting
- permute
- unpermute
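For readers unfamiliar with the mask-map layout, here is a minimal NumPy sketch of the permute/unpermute semantics these tests exercise. The names (`ref_permute`, `ref_unpermute`) and the expert-major destination ordering are illustrative assumptions, not the actual test code:

```python
import numpy as np

def ref_permute(inp, row_id_map):
    # row_id_map[t, e] is the destination row of token t's copy for expert e,
    # or -1 if token t is not routed to expert e (assumed layout).
    num_out = int((row_id_map >= 0).sum())
    out = np.zeros((num_out, inp.shape[1]), dtype=inp.dtype)
    for t in range(row_id_map.shape[0]):
        for e in range(row_id_map.shape[1]):
            dest = row_id_map[t, e]
            if dest >= 0:
                out[dest] = inp[t]  # scatter one copy per selected expert
    return out

def ref_unpermute(permuted, row_id_map):
    # Gather-and-accumulate: each token sums back all of its expert copies.
    out = np.zeros((row_id_map.shape[0], permuted.shape[1]), dtype=permuted.dtype)
    for t in range(row_id_map.shape[0]):
        for e in range(row_id_map.shape[1]):
            dest = row_id_map[t, e]
            if dest >= 0:
                out[t] += permuted[dest]
    return out
```

With single-expert routing, the roundtrip `ref_unpermute(ref_permute(x, m), m)` reproduces `x` exactly; with top-k routing each token comes back as the sum of its copies unless merging probabilities are applied.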

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung
Collaborator Author

/te-ci L0

@tdophung tdophung marked this pull request as draft November 25, 2025 02:16
@tdophung tdophung changed the title Jax wrapper for Permutation Triton kernel [JAX] Wrapper for Permutation Triton kernel Nov 25, 2025
@tdophung tdophung self-assigned this Nov 25, 2025
@tdophung tdophung added the MoE label Nov 25, 2025
@greptile-apps
Contributor

greptile-apps bot commented Nov 25, 2025

Greptile Overview

Greptile Summary

Adds JAX-Triton wrappers for permutation kernels to enable MoE support on a single GPU, completing step 2 of the multi-step implementation plan.

Key Changes

  • Created transformer_engine/jax/triton/permutation.py with JAX wrappers for 5 permutation operations
  • Refactored Triton kernel parameter ordering in common/triton/permutation.py to group inputs, outputs, strides, and metadata
  • Added comprehensive test suite with reference implementations and roundtrip validation tests
  • Created new module transformer_engine/jax/triton/ with proper exports

Implementation Details

  • JAX wrappers compute strides manually since JAX arrays lack .strides attribute
  • Uses dummy tensors for None pointers as jax-triton doesn't handle None correctly
  • Three-pass approach for make_row_id_map: block cumsum, global cumsum, sparse-to-dense conversion
  • Tests cover multiple dtypes (float32, bfloat16), various token/expert/hidden size combinations, and optional probability handling
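Since `jax.Array` exposes `.shape` but not `.strides`, the wrappers must derive strides themselves. For a C-contiguous array this is a right-to-left product over the shape; a minimal sketch (the helper name is hypothetical, not the function used in the PR):

```python
def contiguous_strides(shape):
    """Element strides of a C-contiguous array: the innermost dimension
    has stride 1, and each outer stride is the product of all inner sizes."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)
```

These are element strides, matching NumPy's byte strides divided by the itemsize; they can then be passed to the Triton kernels as ordinary scalar arguments.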

Confidence Score: 4/5

  • Safe to merge with minor verification recommended for parameter ordering in kernel calls
  • Well-structured implementation with comprehensive tests and clear documentation; score reflects need to verify the refactored parameter ordering works correctly across all JAX-Triton kernel invocations at runtime
  • Verify transformer_engine/jax/triton/permutation.py parameter ordering matches refactored Triton kernels - manual testing recommended

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/triton/permutation.py | 5/5 | Reordered kernel parameters to group input/output pointers, sizes, strides, and metas for better organization and consistency with JAX-Triton calling conventions |
| transformer_engine/jax/triton/permutation.py | 4/5 | New JAX wrapper for Triton permutation kernels with proper stride computation and dummy-tensor handling for None pointers; includes all 5 operations for MoE support |
| transformer_engine/jax/triton/__init__.py | 5/5 | New module initialization file that exports the 5 permutation functions for public API access |
| tests/jax/test_permutation.py | 5/5 | Comprehensive test suite with reference implementations for all 5 permutation operations, including roundtrip tests and various parameter combinations |

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant JAX as JAX Wrapper<br/>(jax/triton/permutation.py)
    participant JT as jax_triton
    participant Triton as Triton Kernel<br/>(common/triton/permutation.py)
    participant GPU as GPU

    Note over User,GPU: MOE Token Permutation Flow

    User->>JAX: make_row_id_map(routing_map, num_tokens, num_experts)
    JAX->>JAX: Compute strides
    JAX->>JT: triton_call(_row_id_map_pass_1_kernel)
    JT->>Triton: Execute Pass 1 (block cumsum)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map_pass1, workspace
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map_pass1, workspace
    
    JAX->>JT: triton_call(_row_id_map_pass_2_kernel)
    JT->>Triton: Execute Pass 2 (cumsum all)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map_pass2
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map_pass2
    
    JAX->>JAX: Initialize columns [num_experts:] to -1
    
    JAX->>JT: triton_call(_row_id_map_pass_3_kernel)
    JT->>Triton: Execute Pass 3 (sparse to dense)
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: row_id_map (final)
    Triton-->>JT: Return results
    JT-->>JAX: row_id_map
    JAX-->>User: Return row_id_map

    User->>JAX: permute_with_mask_map(inp, row_id_map, probs, ...)
    JAX->>JAX: Compute strides & create dummy tensors
    JAX->>JT: triton_call(_permute_kernel)
    JT->>Triton: Execute permutation
    Triton->>GPU: Parallel kernel execution
    GPU-->>Triton: output, permuted_probs
    Triton-->>JT: Return results
    JT-->>JAX: output, permuted_probs
    JAX-->>User: Return output, permuted_probs

    User->>JAX: unpermute_with_mask_map(inp, row_id_map, merging_probs, ...)
    JAX->>JAX: Compute strides & create dummy tensors
    JAX->>JT: triton_call(_unpermute_kernel)
    JT->>Triton: Execute unpermutation
    Triton->>GPU: Parallel kernel execution (accumulate)
    GPU-->>Triton: output, unpermuted_probs
    Triton-->>JT: Return results
    JT-->>JAX: output, unpermuted_probs
    JAX-->>User: Return output, unpermuted_probs
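The three-pass row-id-map construction in the diagram can be sketched in NumPy. This is an illustrative reconstruction under two assumptions: tokens are tiled into blocks (`BLOCK` here is an arbitrary choice, not the kernel's block size), and destinations are grouped expert-major; the real kernels run these passes in parallel on the GPU:

```python
import numpy as np

BLOCK = 4  # hypothetical tile size; the actual kernel block size may differ

def row_id_map_three_pass(routing_map):
    num_tokens, num_experts = routing_map.shape
    num_blocks = (num_tokens + BLOCK - 1) // BLOCK

    # Pass 1: rank of each selected (token, expert) pair within its block,
    # plus per-(expert, block) selection counts.
    rank_in_block = np.zeros((num_tokens, num_experts), dtype=np.int64)
    block_counts = np.zeros((num_experts, num_blocks), dtype=np.int64)
    for b in range(num_blocks):
        rows = routing_map[b * BLOCK:(b + 1) * BLOCK].astype(np.int64)
        rank_in_block[b * BLOCK:(b + 1) * BLOCK] = np.cumsum(rows, axis=0) - rows
        block_counts[:, b] = rows.sum(axis=0)

    # Pass 2: exclusive cumsum over all (expert, block) counts gives each
    # block's base offset in the expert-major permuted layout.
    flat = block_counts.ravel()
    base = (np.cumsum(flat) - flat).reshape(num_experts, num_blocks)

    # Pass 3: scatter the sparse ranks into the dense map (-1 = not routed).
    row_id_map = np.full((num_tokens, num_experts), -1, dtype=np.int64)
    for t in range(num_tokens):
        for e in range(num_experts):
            if routing_map[t, e]:
                row_id_map[t, e] = base[e, t // BLOCK] + rank_in_block[t, e]
    return row_id_map
```

Pass 1 only touches data within one block, pass 2 is a single cumulative sum over the small (num_experts x num_blocks) count matrix, and pass 3 is an independent scatter per element, which is what makes each pass easy to parallelize.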


@greptile-apps greptile-apps bot left a comment


4 files reviewed, no comments


@mingxu1067
Collaborator

Which jax-triton version works with this PR?
I tried pip install jax-triton but got TypeError: CUDABackend.make_ttir() missing 1 required positional argument: 'capability' when running test_permutation.py.

@tdophung
Collaborator Author

/te-ci L0 pytorch

@tdophung
Collaborator Author

Which jax-triton version works with this PR? I tried pip install jax-triton but got TypeError: CUDABackend.make_ttir() missing 1 required positional argument: 'capability' when running test_permutation.py.

For anyone else looking at this and wanting to try it:
You need to build jax-triton from source from this commit: #2419 in order to run test_permutation.py successfully.
