From 2b3bca72f8e22f927b489108dd0502fe5d66d61f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 1 Dec 2025 09:53:34 -0500 Subject: [PATCH 01/10] init triton binding Signed-off-by: Phuong Nguyen --- tests/jax/test_triton_custom_calls.py | 119 +++++++ .../jax/triton_extensions/__init__.py | 25 ++ .../jax/triton_extensions/utils.py | 332 ++++++++++++++++++ 3 files changed, 476 insertions(+) create mode 100644 tests/jax/test_triton_custom_calls.py create mode 100644 transformer_engine/jax/triton_extensions/__init__.py create mode 100644 transformer_engine/jax/triton_extensions/utils.py diff --git a/tests/jax/test_triton_custom_calls.py b/tests/jax/test_triton_custom_calls.py new file mode 100644 index 0000000000..ba421546c7 --- /dev/null +++ b/tests/jax/test_triton_custom_calls.py @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Tests for Triton-based custom calls in TE JAX.""" + +import jax +import jax.numpy as jnp +import pytest + +from utils import assert_allclose, pytest_parametrize_wrapper + +import triton +import triton.language as tl + +from transformer_engine.jax.cpp_extensions.base import BasePrimitive, register_primitive +from transformer_engine.jax.triton_extensions import triton_call_lowering + + +@pytest.fixture(autouse=True, scope="module") +def init(): + """WAR for CUDA uninitialize error""" + _ = jnp.zeros(0) + yield + + +class TestTritonBinding: + """Test Triton binding primitive.""" + + # Define autotuned Triton kernel + @staticmethod + @triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE": 256} + ), # Uses defaults: num_warps=4, num_stages=3 + triton.Config({"BLOCK_SIZE": 512}, num_warps=8), # Custom num_warps + ], + key=["n_elements"], # Autotune based on input size + ) + @triton.jit + def amax_kernel( + x_ptr, + amax_ptr, + n_elements: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + """Compute amax using Triton with autotuning.""" + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) + abs_x = tl.abs(x) + block_max = tl.max(abs_x) + + tl.atomic_max(amax_ptr, block_max) + + # Define test primitive + class AmaxTritonPrimitive(BasePrimitive): + """Test primitive using Triton kernel.""" + + name = "te_amax_triton_test" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(x_aval, *, block_size): + del block_size + return jax.core.ShapedArray((1,), jnp.float32) + + @staticmethod + def impl(x, block_size): + assert TestTritonBinding.AmaxTritonPrimitive.inner_primitive is not None + return TestTritonBinding.AmaxTritonPrimitive.inner_primitive.bind( + x, block_size=block_size + ) + + @staticmethod + def lowering(ctx, x, *, block_size): + """MLIR lowering using Triton kernel.""" + n_elements = 1 + for dim in ctx.avals_in[0].shape: + n_elements *= dim + + grid = (triton.cdiv(n_elements, block_size),) + + return triton_call_lowering( + ctx, + TestTritonBinding.amax_kernel, # Autotuned kernel + x, + grid=grid, + constexprs={"n_elements": n_elements}, + # BLOCK_SIZE comes from autotuner config, not passed here + ) + + register_primitive(AmaxTritonPrimitive) + + @staticmethod + def _triton_amax(x: jnp.ndarray, block_size: int = 1024) -> jnp.ndarray: + """Compute amax using Triton kernel.""" + return 
TestTritonBinding.AmaxTritonPrimitive.outer_primitive.bind( + x, block_size=block_size + ) + + @pytest_parametrize_wrapper("shape", [(1024, 1024)]) + @pytest_parametrize_wrapper("dtype", [jnp.bfloat16]) + def test_triton_amax(self, shape, dtype): + """Test Triton amax with JIT.""" + key = jax.random.PRNGKey(0) + x = jax.random.uniform(key, shape, dtype) + + expected = jnp.max(jnp.abs(x), keepdims=False).astype(jnp.float32) + jitted_amax = jax.jit(self._triton_amax) + result = jitted_amax(x) + + assert_allclose(result, expected, dtype=jnp.float32) diff --git a/transformer_engine/jax/triton_extensions/__init__.py b/transformer_engine/jax/triton_extensions/__init__.py new file mode 100644 index 0000000000..7ce6c476c2 --- /dev/null +++ b/transformer_engine/jax/triton_extensions/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +Triton extensions for Transformer Engine JAX. + +This module provides Triton kernel integration for TE primitives. + +IMPORTANT: This module requires Triton to be installed. If you don't have Triton, +use transformer_engine.jax.cpp_extensions instead (CUDA/FFI based primitives). + +Install Triton: pip install triton + + +Usage: + # Import utilities + from transformer_engine.jax.triton_extensions import triton_call_lowering + + # Use in your primitive's lowering + @staticmethod + def lowering(ctx, x, **kwargs): + return triton_call_lowering(ctx, my_kernel, x, ...) +""" + +from .utils import * diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py new file mode 100644 index 0000000000..aa14bb8fd1 --- /dev/null +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -0,0 +1,332 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +Triton utilities for JAX primitives. + +This module provides utility functions for integrating Triton kernels into +JAX primitives. Triton is only imported when this module is used. +""" + +import hashlib +from typing import Any, Callable, Mapping +import zlib + +from jax import core +import jax +import jax.numpy as jnp + + +try: + from jax._src.lib import gpu_triton + from triton.compiler import compiler as tc + from triton.backends.nvidia import compiler as cb + from triton.runtime import autotuner +except ImportError as e: + raise ImportError( + "Triton is required for transformer_engine.jax.triton_extensions. " + "Install with: pip install triton\n" + "If you don't need Triton, use transformer_engine.jax.cpp_extensions instead." + ) from e + + +__all__ = ["triton_call_lowering"] + +# Triton kernel cache (module-level, shared across all kernels) +_TRITON_KERNEL_CACHE = {} + + +def get_triton_dtype(aval): + """Convert JAX dtype to Triton type string. 
+ + Args: + aval: JAX ShapedArray + + Returns: + Triton type string (e.g., "*fp32" for pointer, "i32" for scalar) + """ + dtype_map = { + jnp.dtype("bfloat16"): "bf16", + jnp.dtype("float64"): "fp64", + jnp.dtype("float32"): "fp32", + jnp.dtype("float16"): "fp16", + jnp.dtype("float8_e4m3fn"): "fp8e4nv", + jnp.dtype("float8_e5m2"): "fp8e5", + jnp.dtype("int64"): "i64", + jnp.dtype("int32"): "i32", + jnp.dtype("int16"): "i16", + jnp.dtype("int8"): "i8", + jnp.dtype("bool"): "i1", + } + + assert isinstance(aval, core.ShapedArray), "aval must be a JAX ShapedArray" + return f"*{dtype_map[aval.dtype]}" + + +def compile_triton( + kernel_fn: Callable, + signature: Mapping[str, str], + constants: Mapping[str, Any], + num_warps: int, + num_stages: int, + num_ctas: int, + compute_capability: int, + enable_fp_fusion: bool = False, +): + """Compile a Triton kernel to PTX. + + Kernels are cached to avoid recompilation. + + Args: + kernel_fn: Triton kernel function (decorated with @triton.jit) + signature: Dict mapping arg names to types (e.g., {"x_ptr": "*fp32", "n": "i32"}) + constants: Dict of compile-time constants + num_warps: Number of warps per block + num_stages: Number of pipeline stages + num_ctas: Number of CTAs (cooperative thread arrays) + compute_capability: CUDA compute capability + enable_fp_fusion: Enable FP fusion optimizations (default False for accuracy) + + Returns: + TritonKernel object for JAX + """ + # Create cache key + cache_key = hashlib.md5( + str( + ( + kernel_fn.__name__, + tuple(sorted(signature.items())), + tuple(sorted(constants.items())), + num_warps, + num_stages, + num_ctas, + enable_fp_fusion, + compute_capability, + ) + ).encode() + ).hexdigest() + + if cache_key in _TRITON_KERNEL_CACHE: + return _TRITON_KERNEL_CACHE[cache_key] + + # Compile kernel + options = cb.CUDAOptions( + num_warps=num_warps, + num_stages=num_stages, + num_ctas=num_ctas, + cluster_dims=(1, 1, 1), + debug=False, + enable_fp_fusion=enable_fp_fusion, + ) + + # Mark constants as constexpr in signature + signature_with_constexpr = dict(signature) + for const_name in constants.keys(): + if const_name in signature_with_constexpr: + signature_with_constexpr[const_name] = "constexpr" + + src = tc.ASTSource( + fn=kernel_fn, + constexprs=constants, + signature=signature_with_constexpr, + ) + + compiled = tc.compile( + src, + target=tc.GPUTarget("cuda", compute_capability, 32), + options=options.__dict__, + ) + + # Create kernel object for JAX + kernel = gpu_triton.TritonKernel( + compiled.name, + num_warps, + compiled.metadata.shared, + compiled.asm["ptx"], + "", # ttir + compute_capability, + 1, + 1, + 1, # cluster_dims + ) + + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel + + +def triton_call_lowering( + ctx, + kernel_fn: Callable, + *array_args, + grid, + input_output_aliases: Mapping[int, int] = None, + constexprs: Mapping[str, Any] = None, +): + """Helper for MLIR lowering that calls a Triton kernel. + + Use this in your primitive's lowering method to call Triton kernels. 
+ + Args: + ctx: MLIR lowering context + kernel_fn: Triton kernel function + *array_args: Input arrays (from ctx) + grid: Grid dimensions (int or tuple) + input_output_aliases: Mapping of input to output aliases + constexprs: Compile-time constants for the kernel + + Returns: + MLIR lowering result + + Example: + @staticmethod + def lowering(ctx, x, *, block_size): + from ..triton_extensions import triton_call_lowering + n = ctx.avals_in[0].size + return triton_call_lowering( + ctx, my_kernel, x, + grid=(triton.cdiv(n, block_size),), + n_elements=n, + BLOCK_SIZE=block_size + ) + """ + # Get compute capability using gpu_triton + compute_capability = gpu_triton.get_compute_capability(0) # device 0 + + # Build signature dict: map arg names to types + # Get arg names from kernel function + if isinstance(kernel_fn, autotuner.Autotuner): + arg_names = kernel_fn.fn.arg_names + else: + arg_names = kernel_fn.arg_names + + # Build signature for inputs + outputs + all_avals = list(ctx.avals_in) + list(ctx.avals_out) + signature = {arg_names[i]: get_triton_dtype(aval) for i, aval in enumerate(all_avals)} + + # Normalize grid to 3D + if isinstance(grid, int): + grid_tuple = (grid, 1, 1) + elif len(grid) == 1: + grid_tuple = (grid[0], 1, 1) + elif len(grid) == 2: + grid_tuple = (grid[0], grid[1], 1) + else: + grid_tuple = grid[:3] + + # Default values for the kernel + actual_kernel_fn = kernel_fn + num_warps = 32 + num_stages = ( + 1 # TODO(Phuong): consider if it is beneficial to expose num_warps, num_stages, num_ctas + ) + num_ctas = 1 + kernel_constexprs = constexprs if constexprs is not None else {} + + # Handle autotuned kernels - compile all configs + if isinstance(kernel_fn, autotuner.Autotuner): + # Compile all configs for runtime selection + kernel_calls = [] + actual_kernel_fn = kernel_fn.fn + + for config in kernel_fn.configs: + # Extract parameters from config + config_num_warps = config.num_warps if config.num_warps is not None else num_warps + config_num_stages = config.num_stages if config.num_stages is not None else num_stages + config_num_ctas = config.num_ctas if config.num_ctas is not None else num_ctas + + # Merge config kwargs with user constexprs + config_constexprs = {**config.kwargs, **(constexprs if constexprs else {})} + + # Compile this config + config_kernel = compile_triton( + actual_kernel_fn, + signature, + config_constexprs, + config_num_warps, + config_num_stages, + config_num_ctas, + compute_capability, + enable_fp_fusion=False, + ) + + # Create kernel call for this config + config_params = [] + for _ in list(ctx.avals_in) + list(ctx.avals_out): + config_params.append(gpu_triton.create_array_parameter(0, 16)) + + config_call = gpu_triton.TritonKernelCall( + config_kernel, + grid_tuple[0], + grid_tuple[1], + grid_tuple[2], + config_params, + ) + + kernel_calls.append((config_call, str(config))) + + # Create autotuned kernel call + # Convert input_output_aliases to format with sizes + if input_output_aliases is None: + input_output_aliases = {} + + input_output_aliases_with_sizes = tuple( + ( + input_idx, + output_idx, + ctx.avals_in[input_idx].size * ctx.avals_in[input_idx].dtype.itemsize, + ) + for input_idx, output_idx in input_output_aliases.items() + ) + + kernel_call = gpu_triton.TritonAutotunedKernelCall( + f"{actual_kernel_fn.__name__}_autotuned", + kernel_calls, + input_output_aliases_with_sizes, + ) + + # Skip the single kernel call creation below + use_autotuned = True + else: + # Regular kernel: compile single config + kernel = compile_triton( + 
actual_kernel_fn, + signature, + kernel_constexprs, + num_warps, + num_stages, + num_ctas, + compute_capability, + enable_fp_fusion=False, + ) + use_autotuned = False + + # Create kernel call (if not already created by autotuner) + if not use_autotuned: + # Create kernel parameters for single config + kernel_params = [] + for aval in list(ctx.avals_in) + list(ctx.avals_out): + kernel_params.append(gpu_triton.create_array_parameter(0, 16)) + + kernel_call = gpu_triton.TritonKernelCall( + kernel, + grid_tuple[0], + grid_tuple[1], + grid_tuple[2], + kernel_params, + ) + + serialized_metadata = b"" + call_proto = kernel_call.to_proto(actual_kernel_fn.__name__, serialized_metadata) + + if input_output_aliases is None: + input_output_aliases = {} + + # Use JAX FFI lowering with compressed protobuf + rule = jax.ffi.ffi_lowering( + "triton_kernel_call", # Custom call target registered in gpu_triton.py + api_version=2, + backend_config=zlib.compress(call_proto), + operand_output_aliases=input_output_aliases, + ) + + return rule(ctx, *array_args) From 58f2f2ef3a4896679c6ab36708c81cbb2ecb1010 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 20:21:12 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_triton_custom_calls.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/jax/test_triton_custom_calls.py b/tests/jax/test_triton_custom_calls.py index ba421546c7..926074616f 100644 --- a/tests/jax/test_triton_custom_calls.py +++ b/tests/jax/test_triton_custom_calls.py @@ -30,9 +30,7 @@ class TestTritonBinding: @staticmethod @triton.autotune( configs=[ - triton.Config( - {"BLOCK_SIZE": 256} - ), # Uses defaults: num_warps=4, num_stages=3 + triton.Config({"BLOCK_SIZE": 256}), # Uses defaults: num_warps=4, num_stages=3 triton.Config({"BLOCK_SIZE": 512}, num_warps=8), # Custom num_warps ], key=["n_elements"], # Autotune based on input size @@ -101,9 +99,7 @@ def lowering(ctx, x, *, block_size): @staticmethod def _triton_amax(x: jnp.ndarray, block_size: int = 1024) -> jnp.ndarray: """Compute amax using Triton kernel.""" - return TestTritonBinding.AmaxTritonPrimitive.outer_primitive.bind( - x, block_size=block_size - ) + return TestTritonBinding.AmaxTritonPrimitive.outer_primitive.bind(x, block_size=block_size) @pytest_parametrize_wrapper("shape", [(1024, 1024)]) @pytest_parametrize_wrapper("dtype", [jnp.bfloat16]) From 31b104f64ebdcdf7d03bd2af94a089b4bbb5d069 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 1 Dec 2025 12:28:51 -0800 Subject: [PATCH 03/10] fix lint Signed-off-by: Phuong Nguyen --- transformer_engine/jax/triton_extensions/utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index aa14bb8fd1..77cfbca3dd 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -284,8 +284,6 @@ def lowering(ctx, x, *, block_size): input_output_aliases_with_sizes, ) - # Skip the single kernel call creation below - use_autotuned = True else: # Regular kernel: compile single config kernel = compile_triton( @@ -298,11 +296,7 @@ def lowering(ctx, x, *, block_size): compute_capability, enable_fp_fusion=False, ) - use_autotuned = False - # Create kernel call (if not already created by autotuner) - if not use_autotuned: - # Create kernel 
parameters for single config kernel_params = [] for aval in list(ctx.avals_in) + list(ctx.avals_out): kernel_params.append(gpu_triton.create_array_parameter(0, 16)) From b4fd76f68eb660d174f1c44be7540d1491f5af5f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 1 Dec 2025 15:30:01 -0500 Subject: [PATCH 04/10] More dtypes Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Phuong Nguyen --- transformer_engine/jax/triton_extensions/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 77cfbca3dd..daf5d334eb 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -56,6 +56,10 @@ def get_triton_dtype(aval): jnp.dtype("int32"): "i32", jnp.dtype("int16"): "i16", jnp.dtype("int8"): "i8", + jnp.dtype("uint64"): "u64", + jnp.dtype("uint32"): "u32", + jnp.dtype("uint16"): "u16", + jnp.dtype("uint8"): "u8", jnp.dtype("bool"): "i1", } From 79b21d1a714a91f3ba7e36c23c281f045c95efa3 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 1 Dec 2025 13:30:08 -0800 Subject: [PATCH 05/10] added triton as test dependency Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 1f9552eb69..df78bf3e2f 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -20,7 +20,7 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions.""" - return ["numpy"] + return ["numpy", "triton"] def xla_path() -> str: From 1d3730f7476fa3636be1a80fd37eb354da90b0c7 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 24 Nov 2025 16:51:47 -0800 Subject: [PATCH 06/10] cherry pick permutation from teddy/jax-triton-initial-commit Signed-off-by: tdophung --- tests/jax/test_permutation.py | 827 ++++++++++++++++++ .../common/triton/permutation.py | 2 +- transformer_engine/jax/triton/__init__.py | 21 + transformer_engine/jax/triton/permutation.py | 567 ++++++++++++ 4 files changed, 1416 insertions(+), 1 deletion(-) create mode 100644 tests/jax/test_permutation.py create mode 100644 transformer_engine/jax/triton/__init__.py create mode 100644 transformer_engine/jax/triton/permutation.py diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py new file mode 100644 index 0000000000..37c4208df7 --- /dev/null +++ b/tests/jax/test_permutation.py @@ -0,0 +1,827 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for permutation Triton kernels""" + +# Patch jax-triton for Triton 3.5.1 compatibility - MUST BE FIRST! + +import jax +import jax.numpy as jnp +import pytest +from jax import jit + +from transformer_engine.jax.triton.permutation import ( + make_row_id_map, + permute_with_mask_map, + unpermute_with_mask_map, + make_chunk_sort_map, + sort_chunks_by_map, +) +from utils import assert_allclose, dtype_tols + + +def reference_make_row_id_map( + routing_map: jnp.ndarray, + num_tokens: int, + num_experts: int, +) -> jnp.ndarray: + """ + Reference implementation of make_row_id_map using JAX primitives. + + Parameters + ---------- + routing_map : jnp.ndarray + Input tensor of shape [num_tokens, num_experts]. Mask indicating which experts + are routed to which tokens (1 = routed, 0 = not routed). 
+ num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts in the input tensor. + + Returns + ------- + row_id_map : jnp.ndarray + The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1]. + """ + row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) + + # For each expert, compute cumulative sum to get destination indices + cumsum_per_expert = jnp.cumsum(routing_map, axis=0) + + # Compute total tokens per expert + tokens_per_expert = jnp.sum(routing_map, axis=0) + expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) + + # Build the row_id_map + for token_idx in range(num_tokens): + routed_experts = jnp.where(routing_map[token_idx] == 1)[0] + n_routed = len(routed_experts) + + # Store number of routed experts in the last position + row_id_map = row_id_map.at[token_idx, -1].set(n_routed) + + # For each routed expert, compute destination row and store it + dest_rows = [] + expert_indices = [] + for expert_idx in routed_experts: + # Destination row = expert offset + (cumsum - 1) + dest_row = expert_offsets[expert_idx] + cumsum_per_expert[token_idx, expert_idx] - 1 + dest_rows.append(dest_row) + expert_indices.append(expert_idx) + + # Sort by destination row + if n_routed > 0: + sort_indices = jnp.argsort(-jnp.array(dest_rows)) # Negative for descending sort + sorted_dest_rows = jnp.array(dest_rows)[sort_indices] + sorted_expert_indices = jnp.array(expert_indices)[sort_indices] + + # Store sorted destination rows and expert indices + for i in range(n_routed): + row_id_map = row_id_map.at[token_idx, i].set(sorted_dest_rows[i]) + row_id_map = row_id_map.at[token_idx, num_experts + i].set(sorted_expert_indices[i]) + + return row_id_map + + +def reference_permute_with_mask_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: jnp.ndarray, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> tuple: + """ + Reference implementation of permute_with_mask_map using JAX primitives. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape [num_tokens, hidden_size]. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1]. + probs : jnp.ndarray + The probabilities of the input tensor. + num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts. + num_out_tokens : int + Number of tokens in the permuted tensor. + hidden_size : int + Hidden size of the input tensor. + + Returns + ------- + output : jnp.ndarray + Permuted output tensor of shape [num_out_tokens, hidden_size]. + permuted_probs : jnp.ndarray + Permuted probabilities if probs was provided, None otherwise. 
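+
+    Notes
+    -----
+    This reference assumes the row_id_map layout produced by make_row_id_map:
+    for each token, entry [-1] holds n_routed, entries [0:n_routed] hold the
+    destination rows, and entries [num_experts:num_experts + n_routed] hold the
+    matching expert indices.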
+ """ + output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) + permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype) + + for token_idx in range(num_tokens): + n_routed = int(row_id_map[token_idx, -1]) + for i in range(n_routed): + dest_row = int(row_id_map[token_idx, i]) + expert_idx = int(row_id_map[token_idx, num_experts + i]) + + # Get probability for this expert + if probs is not None: + if probs.ndim == 1: + prob = probs[token_idx] + else: + prob = probs[token_idx, expert_idx] + + # Match kernel behavior: if prob == 0.0, zero out the output (padding indicator) + if prob == 0.0: + output = output.at[dest_row].set(0.0) + else: + output = output.at[dest_row].set(inp[token_idx]) + + permuted_probs = permuted_probs.at[dest_row].set(prob) + else: + output = output.at[dest_row].set(inp[token_idx]) + + return output, permuted_probs + + +def reference_unpermute_with_mask_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + merging_probs: jnp.ndarray, + permuted_probs: jnp.ndarray, + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> tuple: + """ + Reference implementation of unpermute_with_mask_map using JAX primitives. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape [num_out_tokens, hidden_size]. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape [num_tokens, num_experts * 2 + 1]. + merging_probs : jnp.ndarray + The merging probabilities for weighted reduction. + permuted_probs : jnp.ndarray + The permuted probabilities. + num_tokens : int + Number of tokens. + num_experts : int + Number of experts. + hidden_size : int + Hidden size. + + Returns + ------- + output : jnp.ndarray + Unpermuted output tensor of shape [num_tokens, hidden_size]. + unpermuted_probs : jnp.ndarray + Unpermuted probabilities if permuted_probs was provided, None otherwise. + """ + output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + unpermuted_probs = None if permuted_probs is None else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) + + for token_idx in range(num_tokens): + n_routed = int(row_id_map[token_idx, -1]) + for i in range(n_routed): + src_row = int(row_id_map[token_idx, i]) + expert_idx = int(row_id_map[token_idx, num_experts + i]) + + if merging_probs is not None: + weight = merging_probs[token_idx, expert_idx] + output = output.at[token_idx].add(inp[src_row] * weight) + else: + output = output.at[token_idx].add(inp[src_row]) + + if permuted_probs is not None: + unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set(permuted_probs[src_row]) + + return output, unpermuted_probs + + +def reference_make_chunk_sort_map( + split_sizes: jnp.ndarray, + sorted_indices: jnp.ndarray, + num_tokens: int, + num_splits: int, +) -> jnp.ndarray: + """ + Reference implementation of make_chunk_sort_map using JAX primitives. + + Parameters + ---------- + split_sizes : jnp.ndarray + The sizes of the chunks of shape [num_splits,]. + sorted_indices : jnp.ndarray + The indices of the sorted chunks of shape [num_splits,]. + num_tokens : int + Number of tokens. + num_splits : int + Number of splits. + + Returns + ------- + row_id_map : jnp.ndarray + Row ID map for chunk sorting of shape [num_tokens,]. 
+ """ + row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32) + + # Compute cumulative positions + cumsum_sizes = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) + + # For each chunk, compute the destination indices + dest_offset = 0 + for sorted_idx in sorted_indices: + chunk_start = cumsum_sizes[sorted_idx] + chunk_end = cumsum_sizes[sorted_idx + 1] + chunk_size = chunk_end - chunk_start + + # Map source positions to destination positions + for i in range(chunk_size): + row_id_map = row_id_map.at[chunk_start + i].set(dest_offset + i) + + dest_offset += chunk_size + + return row_id_map + + +def reference_sort_chunks_by_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: jnp.ndarray, + num_tokens: int, + hidden_size: int, + is_forward: bool, +) -> tuple: + """ + Reference implementation of sort_chunks_by_map using JAX primitives. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape [num_tokens, hidden_size]. + row_id_map : jnp.ndarray + The token to destination mapping of shape [num_tokens,]. + probs : jnp.ndarray + The probabilities. + num_tokens : int + Number of tokens. + hidden_size : int + Hidden size. + is_forward : bool + Whether this is forward or backward. + + Returns + ------- + output : jnp.ndarray + Sorted output tensor of shape [num_tokens, hidden_size]. + permuted_probs : jnp.ndarray + Sorted probabilities if probs was provided, None otherwise. + """ + output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype) + + if is_forward: + # Forward: src -> dest + for src_idx in range(num_tokens): + dest_idx = int(row_id_map[src_idx]) + output = output.at[dest_idx].set(inp[src_idx]) + if probs is not None: + permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + else: + # Backward: dest -> src + for dest_idx in range(num_tokens): + src_idx = int(row_id_map[dest_idx]) + output = output.at[dest_idx].set(inp[src_idx]) + if probs is not None: + permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) + + return output, permuted_probs + + +class TestPermutation: + """Test permutation operations implementation""" + + @staticmethod + def generate_routing_map( + num_tokens: int, + num_experts: int, + tokens_per_expert: int = 2, + key: jax.Array = None, + use_fixed_per_token: bool = True, + ): + """Generate random routing map for testing + + Parameters + ---------- + num_tokens : int + Number of tokens + num_experts : int + Number of experts + tokens_per_expert : int + If use_fixed_per_token=True, each token gets exactly this many experts. 
+ If use_fixed_per_token=False, total routed connections = num_tokens * tokens_per_expert + key : jax.Array + Random key + use_fixed_per_token : bool + If True: each token routes to exactly tokens_per_expert experts (old behavior) + If False: randomly distribute routing like PyTorch (different n_routed per token) + """ + if key is None: + key = jax.random.PRNGKey(0) + + if use_fixed_per_token: + # Old behavior: each token routes to exactly tokens_per_expert experts + routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32) + + # Randomly assign each token to tokens_per_expert experts + for token_idx in range(num_tokens): + key, subkey = jax.random.split(key) + expert_indices = jax.random.choice( + subkey, num_experts, shape=(tokens_per_expert,), replace=False + ) + routing_map = routing_map.at[token_idx, expert_indices].set(1) + else: + # PyTorch-style: randomly distribute routing (varying n_routed per token) + num_out_tokens = num_tokens * tokens_per_expert + + # Create flat array with num_out_tokens ones + flat_array = jnp.zeros((num_tokens * num_experts,), dtype=jnp.int32) + flat_array = flat_array.at[:num_out_tokens].set(1) + + # Randomly permute + key, subkey = jax.random.split(key) + permuted_indices = jax.random.permutation(subkey, num_tokens * num_experts) + flat_array = flat_array[permuted_indices] + + # Reshape to routing_map + routing_map = flat_array.reshape((num_tokens, num_experts)) + + return routing_map + + # Test make_row_id_map + @pytest.mark.parametrize("num_tokens,num_experts,tokens_per_expert", [ + (32, 8, 2), + (64, 16, 3), + (128, 8, 1), + ]) + @pytest.mark.parametrize("use_fixed_per_token", [True, False]) + def test_make_row_id_map(self, num_tokens, num_experts, tokens_per_expert, use_fixed_per_token): + """Test make_row_id_map against reference implementation""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map( + num_tokens, num_experts, tokens_per_expert, key, use_fixed_per_token + ) + + # Test implementation + test_row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) + + # Reference implementation + ref_row_id_map = reference_make_row_id_map(routing_map, num_tokens, num_experts) + + # Pretty print for debugging + # print("\n" + "="*100) + # print(f"TEST: make_row_id_map (num_tokens={num_tokens}, num_experts={num_experts}, tokens_per_expert={tokens_per_expert})") + # print("="*100) + + # print("\n📊 ROUTING MAP (rows=tokens, cols=experts):") + # print("-"*100) + # print(routing_map) + + # print("\n📋 FULL ARRAYS:") + # print("-"*100) + # print("\n🔴 ACTUAL (JAX/Triton implementation) - AFTER PASS 3:") + # print(test_row_id_map) + # print("\n🔴 ACTUAL - Columns breakdown:") + # print(f" Sorted dest rows [0:{num_experts}]:") + # print(test_row_id_map[:, :num_experts]) + # print(f" Expert indices [{num_experts}:{2*num_experts}]:") + # print(test_row_id_map[:, num_experts:2*num_experts]) + # print(f" n_routed (last column):") + # print(test_row_id_map[:, -1]) + + # print("\n🟢 EXPECTED (Reference implementation):") + # print(ref_row_id_map) + # print("\n🟢 EXPECTED - Columns breakdown:") + # print(f" Sorted dest rows [0:{num_experts}]:") + # print(ref_row_id_map[:, :num_experts]) + # print(f" Expert indices [{num_experts}:{2*num_experts}]:") + # print(ref_row_id_map[:, num_experts:2*num_experts]) + # print(f" n_routed (last column):") + # print(ref_row_id_map[:, -1]) + + # print("\n🔍 DIFFERENCE (Actual - Expected):") + # diff = test_row_id_map - ref_row_id_map + # print(diff) + + # mismatch_count = 
jnp.sum(diff != 0) + # total_elements = test_row_id_map.size + # print(f"\n📊 STATISTICS:") + # print(f" Total elements: {total_elements}") + # print(f" Mismatched elements: {mismatch_count} ({100*mismatch_count/total_elements:.1f}%)") + # print(f" Max absolute difference: {jnp.max(jnp.abs(diff))}") + + # print("\n" + "="*100 + "\n") + + # Compare results - only compare valid positions (first n_routed in each section) + # Invalid positions may contain garbage (PyTorch) or -1 (JAX reference), but they're never accessed + for token_idx in range(num_tokens): + n_routed = int(ref_row_id_map[token_idx, -1]) + + # Compare valid dest rows [0:n_routed] + assert_allclose( + test_row_id_map[token_idx, :n_routed], + ref_row_id_map[token_idx, :n_routed], + rtol=0, atol=0, + err_msg=f"Mismatch in dest rows for token {token_idx}" + ) + + # Compare valid expert indices [num_experts:num_experts+n_routed] + assert_allclose( + test_row_id_map[token_idx, num_experts:num_experts+n_routed], + ref_row_id_map[token_idx, num_experts:num_experts+n_routed], + rtol=0, atol=0, + err_msg=f"Mismatch in expert indices for token {token_idx}" + ) + + # Compare n_routed (last column) + assert_allclose( + test_row_id_map[token_idx, -1], + ref_row_id_map[token_idx, -1], + rtol=0, atol=0, + err_msg=f"Mismatch in n_routed for token {token_idx}" + ) + + # # Optional: Also do a full comparison if both use -1 for invalid positions + # # This will help catch uninitialized memory issues + # if jnp.all((test_row_id_map == -1) | (test_row_id_map >= 0)): + # print("🔬 Both use -1 for invalid positions, doing full comparison...") + # assert_allclose(test_row_id_map, ref_row_id_map, rtol=0, atol=0) + + # Test permute_with_mask_map + @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ + (32, 8, 256, 2), + (64, 16, 512, 3), + # Add smaller test cases for easier debugging + (16, 4, 64, 2), # Small case for debugging + (8, 2, 32, 1), # Minimal case + ]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize("with_probs", [True, False]) + def test_permute_with_mask_map(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs): + """Test permute_with_mask_map against reference implementation""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + + # Generate row_id_map + row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) + + # Calculate number of output tokens + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform(inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) + + if with_probs: + probs = jax.random.uniform(prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0) + else: + probs = None + + # Test implementation + test_output, test_probs = permute_with_mask_map( + inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size + ) + + # Reference implementation + ref_output, ref_probs = reference_permute_with_mask_map( + inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size + ) + + # Debug output for bfloat16 failures + if dtype == jnp.bfloat16 and with_probs: + print(f"\n{'='*100}") + print(f"DEBUG: test_permute_with_mask_map (dtype=bfloat16, with_probs=True)") + print(f" num_tokens={num_tokens}, num_experts={num_experts}, hidden_size={hidden_size}") + print(f" 
num_out_tokens={num_out_tokens}, tokens_per_expert={tokens_per_expert}") + print(f"{'='*100}") + + # Check output differences (convert to float32 for printing) + output_diff = jnp.abs(test_output.astype(jnp.float32) - ref_output.astype(jnp.float32)) + print(f"\n📊 OUTPUT DIFFERENCES:") + print(f" Max diff: {float(jnp.max(output_diff)):.6f}") + print(f" Mean diff: {float(jnp.mean(output_diff)):.6f}") + print(f" Median diff: {float(jnp.median(output_diff)):.6f}") + print(f" Num elements with diff > 0.1: {int(jnp.sum(output_diff > 0.1))}") + print(f" Num elements with diff > 0.5: {int(jnp.sum(output_diff > 0.5))}") + print(f" Num elements with diff > 0.9: {int(jnp.sum(output_diff > 0.9))}") + + # Find worst mismatches + flat_diff = output_diff.flatten() + worst_indices = jnp.argsort(flat_diff)[-10:] # Top 10 worst + print(f"\n🔍 WORST MISMATCHES (flattened indices):") + for i, idx in enumerate(worst_indices): + row = int(idx // hidden_size) + col = int(idx % hidden_size) + actual_val = float(test_output.flatten()[idx]) + expected_val = float(ref_output.flatten()[idx]) + diff_val = float(flat_diff[idx]) + print(f" [{i}] position=({row},{col}), actual={actual_val:.4f}, " + f"expected={expected_val:.4f}, diff={diff_val:.4f}") + + # Check probs differences if present + if test_probs is not None and ref_probs is not None: + prob_diff = jnp.abs(test_probs.astype(jnp.float32) - ref_probs.astype(jnp.float32)) + print(f"\n📊 PROBS DIFFERENCES:") + print(f" Max diff: {float(jnp.max(prob_diff)):.6f}") + print(f" Mean diff: {float(jnp.mean(prob_diff)):.6f}") + print(f" Num elements with diff > 0.1: {int(jnp.sum(prob_diff > 0.1))}") + + # Check input data quality + print(f"\n📊 INPUT DATA STATS:") + print(f" inp range: [{float(jnp.min(inp)):.4f}, {float(jnp.max(inp)):.4f}]") + print(f" inp has NaN: {bool(jnp.any(jnp.isnan(inp)))}") + if probs is not None: + print(f" probs range: [{float(jnp.min(probs)):.4f}, {float(jnp.max(probs)):.4f}]") + print(f"\n{'='*100}\n") + + # Compare results + tols = dtype_tols(dtype) + assert_allclose(test_output, ref_output, **tols) + + if with_probs: + assert_allclose(test_probs, ref_probs, **tols) + + # Test unpermute_with_mask_map + @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ + (32, 8, 256, 2), + (64, 16, 512, 3), + ]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize("with_merging_probs", [True, False]) + @pytest.mark.parametrize("with_permuted_probs", [True, False]) + def test_unpermute_with_mask_map( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, + with_merging_probs, with_permuted_probs + ): + """Test unpermute_with_mask_map against reference implementation""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + + # Generate row_id_map + row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) + + # Calculate number of output tokens + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key, merge_key, prob_key = jax.random.split(key, 4) + inp = jax.random.uniform(inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) + + if with_merging_probs: + merging_probs = jax.random.uniform(merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0) + # Normalize merging probs per token + merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8) + else: + merging_probs = None + + 
if with_permuted_probs: + permuted_probs = jax.random.uniform(prob_key, (num_out_tokens,), dtype=dtype, minval=0.0, maxval=1.0) + else: + permuted_probs = None + + # Test implementation + test_output, test_unprobs = unpermute_with_mask_map( + inp, row_id_map, merging_probs, permuted_probs, num_tokens, num_experts, hidden_size + ) + + # Reference implementation + ref_output, ref_unprobs = reference_unpermute_with_mask_map( + inp, row_id_map, merging_probs, permuted_probs, num_tokens, num_experts, hidden_size + ) + + # Compare results + tols = dtype_tols(dtype) + # Use relaxed tolerances for unpermute due to accumulation + relaxed_tols = dtype_tols(dtype, rtol=tols["rtol"] * 5, atol=tols["atol"] * 5) + + assert_allclose(test_output, ref_output, **relaxed_tols) + + if with_permuted_probs: + assert_allclose(test_unprobs, ref_unprobs, **tols) + + # Test round-trip: permute -> unpermute + @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ + (32, 8, 256, 2), + (64, 16, 512, 3), + ]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_permute_unpermute_roundtrip(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype): + """Test that permute followed by unpermute recovers original input""" + key = jax.random.PRNGKey(42) + + # Generate routing map + routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) + + # Generate row_id_map + row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) + + # Calculate number of output tokens + num_out_tokens = int(jnp.sum(routing_map)) + + # Generate input data + key, inp_key = jax.random.split(key) + inp = jax.random.uniform(inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) + + # Create uniform merging probs (equal weight for all routed experts) + merging_probs = routing_map.astype(dtype) / jnp.maximum(jnp.sum(routing_map, axis=1, keepdims=True), 1.0) + + # Permute + permuted, _ = permute_with_mask_map( + inp, row_id_map, None, num_tokens, num_experts, num_out_tokens, hidden_size + ) + + # Unpermute with uniform merging + unpermuted, _ = unpermute_with_mask_map( + permuted, row_id_map, merging_probs, None, num_tokens, num_experts, hidden_size + ) + + # Compare with original input + tols = dtype_tols(dtype) + relaxed_tols = dtype_tols(dtype, rtol=tols["rtol"] * 10, atol=tols["atol"] * 10) + assert_allclose(unpermuted, inp, **relaxed_tols) + + # Test make_chunk_sort_map + @pytest.mark.parametrize("num_splits,total_tokens", [ + (4, 128), + (8, 256), + (16, 512), + ]) + def test_make_chunk_sort_map(self, num_splits, total_tokens): + """Test make_chunk_sort_map against reference implementation""" + key = jax.random.PRNGKey(42) + + # Generate random split sizes + key, size_key = jax.random.split(key) + split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) + # Adjust last split to match total_tokens + split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) + + # Generate sorted indices (permutation of 0..num_splits-1) + key, sort_key = jax.random.split(key) + sorted_indices = jax.random.permutation(sort_key, num_splits) + + # Test implementation + test_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) + + # Reference implementation + ref_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) + + # Compare results + assert_allclose(test_map, ref_map, rtol=0, atol=0) + + # Test sort_chunks_by_map + 
@pytest.mark.parametrize("num_splits,total_tokens,hidden_size", [ + (4, 128, 256), + (8, 256, 512), + ]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + @pytest.mark.parametrize("is_forward", [True, False]) + @pytest.mark.parametrize("with_probs", [True, False]) + def test_sort_chunks_by_map(self, num_splits, total_tokens, hidden_size, dtype, is_forward, with_probs): + """Test sort_chunks_by_map against reference implementation""" + key = jax.random.PRNGKey(42) + + # Generate random split sizes + key, size_key = jax.random.split(key) + split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) + split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) + + # Generate sorted indices + key, sort_key = jax.random.split(key) + sorted_indices = jax.random.permutation(sort_key, num_splits) + + # Generate row_id_map + row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) + + # Generate input data + key, inp_key, prob_key = jax.random.split(key, 3) + inp = jax.random.uniform(inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) + + if with_probs: + probs = jax.random.uniform(prob_key, (total_tokens,), dtype=dtype, minval=0.0, maxval=1.0) + else: + probs = None + + # Debug prints + # print("\n" + "="*100) + # print(f"TEST: sort_chunks_by_map (num_splits={num_splits}, total_tokens={total_tokens}, hidden_size={hidden_size})") + # print(f" dtype={dtype}, is_forward={is_forward}, with_probs={with_probs}") + # print("="*100) + + # print("\n📊 TEST INPUTS:") + # print("-"*100) + # print(f"split_sizes: {split_sizes}") + # print(f"sorted_indices: {sorted_indices}") + # print(f"row_id_map (first 10): {row_id_map[:10]}") + # print(f"inp.shape: {inp.shape}, inp.dtype: {inp.dtype}") + # print(f"inp[0, :5]: {inp[0, :5]}") # First 5 elements of first row + # if probs is not None: + # print(f"probs.shape: {probs.shape}, probs.dtype: {probs.dtype}") + # print(f"probs[:5]: {probs[:5]}") + + # Test implementation + test_output, test_probs = sort_chunks_by_map( + inp, row_id_map, probs, total_tokens, hidden_size, is_forward + ) + + # Reference implementation + ref_output, ref_probs = reference_sort_chunks_by_map( + inp, row_id_map, probs, total_tokens, hidden_size, is_forward + ) + + # print("\n📋 OUTPUT COMPARISON:") + # print("-"*100) + # print(f"\n🔴 ACTUAL (Triton implementation):") + # print(f"test_output.shape: {test_output.shape}, test_output.dtype: {test_output.dtype}") + # print(f"test_output[0, :5]: {test_output[0, :5]}") # First 5 elements of first row + # print(f"test_output has NaN: {jnp.any(jnp.isnan(test_output))}") + # print(f"test_output has Inf: {jnp.any(jnp.isinf(test_output))}") + # if test_probs is not None: + # print(f"test_probs[:5]: {test_probs[:5]}") + + # print(f"\n🟢 EXPECTED (Reference implementation):") + # print(f"ref_output.shape: {ref_output.shape}, ref_output.dtype: {ref_output.dtype}") + # print(f"ref_output[0, :5]: {ref_output[0, :5]}") # First 5 elements of first row + # if ref_probs is not None: + # print(f"ref_probs[:5]: {ref_probs[:5]}") + + # print("\n🔍 DIFFERENCE (Actual - Expected):") + # if not jnp.any(jnp.isnan(test_output)): + # diff = test_output - ref_output + # print(f"Max absolute difference: {jnp.max(jnp.abs(diff))}") + # print(f"Mean absolute difference: {jnp.mean(jnp.abs(diff))}") + # else: + # print("Cannot compute difference - test_output contains NaN values") + + # print("\n" + "="*100 + "\n") + + # Compare results + tols = dtype_tols(dtype) 
+ assert_allclose(test_output, ref_output, **tols) + + if with_probs: + assert_allclose(test_probs, ref_probs, **tols) + + # Test chunk sort round-trip + @pytest.mark.parametrize("num_splits,total_tokens,hidden_size", [ + (4, 128, 256), + (8, 256, 512), + ]) + @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) + def test_chunk_sort_roundtrip(self, num_splits, total_tokens, hidden_size, dtype): + """Test that forward sort followed by backward sort recovers original input""" + key = jax.random.PRNGKey(42) + + # Generate random split sizes + key, size_key = jax.random.split(key) + split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) + split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) + + # Generate sorted indices + key, sort_key = jax.random.split(key) + sorted_indices = jax.random.permutation(sort_key, num_splits) + + # Generate row_id_map + row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) + + # Generate input data + key, inp_key = jax.random.split(key) + inp = jax.random.uniform(inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) + + # Forward sort + sorted_output, _ = sort_chunks_by_map( + inp, row_id_map, None, total_tokens, hidden_size, is_forward=True + ) + + # Backward sort (should recover original) + recovered, _ = sort_chunks_by_map( + sorted_output, row_id_map, None, total_tokens, hidden_size, is_forward=False + ) + + # Compare with original input + tols = dtype_tols(dtype) + assert_allclose(recovered, inp, **tols) \ No newline at end of file diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index e8c43f52d2..4b06b9a7fe 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -603,4 +603,4 @@ def _sort_chunks_by_map_kernel( key=["hidden_size"], )(_sort_chunks_by_map_kernel) except RuntimeError: - pass + pass \ No newline at end of file diff --git a/transformer_engine/jax/triton/__init__.py b/transformer_engine/jax/triton/__init__.py new file mode 100644 index 0000000000..8c1ec6d18d --- /dev/null +++ b/transformer_engine/jax/triton/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025-2028, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""JAX wrappers for Triton kernels.""" + +from .permutation import ( + make_row_id_map, + permute_with_mask_map, + unpermute_with_mask_map, + make_chunk_sort_map, + sort_chunks_by_map, +) + +__all__ = [ + "make_row_id_map", + "permute_with_mask_map", + "unpermute_with_mask_map", + "make_chunk_sort_map", + "sort_chunks_by_map", +] diff --git a/transformer_engine/jax/triton/permutation.py b/transformer_engine/jax/triton/permutation.py new file mode 100644 index 0000000000..5a1c996f4f --- /dev/null +++ b/transformer_engine/jax/triton/permutation.py @@ -0,0 +1,567 @@ +# Copyright (c) 2025-2028, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""JAX wrapper functions for Permutation Triton kernels.""" + +from typing import Optional, Tuple, Union +import jax +import jax.numpy as jnp +from jax import ShapeDtypeStruct +import triton +import jax_triton as jt + +from transformer_engine.common.triton.permutation import ( + _row_id_map_pass_1_kernel, + _row_id_map_pass_2_kernel, + _row_id_map_pass_3_kernel, + _permute_kernel, + _unpermute_kernel, + _unpermute_bwd_with_merging_probs_kernel, + _make_chunk_sort_map_kernel, + _sort_chunks_by_map_kernel, +) + + +def make_row_id_map( + routing_map: jnp.ndarray, + num_tokens: int, + num_experts: int, +) -> jnp.ndarray: + """ + Prepare the row_id_map for the permutation using JAX-Triton. + + Parameters + ---------- + routing_map : jnp.ndarray + Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates + which experts are routed to which tokens. The values in it: 1 means the token is routed to + this expert and 0 means not. + num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts in the input tensor. + + Returns + ------- + row_id_map : jnp.ndarray + The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`. + For each token, the last item is the number of experts that are routed (n_routed). + The first n_routed items are the destination row indices in the permuted tokens. + The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding + to the first n_routed row indices above. + """ + row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) + block_size = 1024 + grid = (num_experts, triton.cdiv(num_tokens, block_size)) + workspace_tensor = jnp.zeros(grid, dtype=jnp.int32) + + # supposing num_tokens == 5, num_experts == 3, block_size == 3 + # and we have a routing_map like this: + # [[1, 1, 0], + # [1, 0, 1], + # [0, 0, 1], + # [1, 1, 0], + # [0, 0, 0]] + + # Pass 1: block cumsum + # for each expert, compute the cumsum of every block_size tokens + # the row_id_map will be like this after pass 1 (r means useless values): + # [[1, 1, 0, r, r, r, r], + # [2, 0, 1, r, r, r, r], + # [0, 0, 2, r, r, r, r], + # [1, 1, 0, r, r, r, r], + # [0, 0, 0, r, r, r, r]] + # Note: "r" = -1 in the triton common kernel implementation + + # Compute strides manually (JAX arrays don't have .strides attribute) + # For routing_map of shape [num_tokens, num_experts], C-contiguous: + routing_stride_token = num_experts # Move to next token + routing_stride_expert = 1 # Move to next expert (contiguous) + # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: + row_id_stride_token = num_experts * 2 + 1 # Move to next token + row_id_stride_expert = 1 # Move to next column (contiguous) + + # Pass 1: Block cumsum + row_id_map_pass1, workspace_tensor = jt.triton_call( + routing_map, # Input 0 (ptr): routing_map_ptr + num_tokens, # Scalar: num_tokens + routing_stride_token, # Scalar: stride_routing_map_token + routing_stride_expert, # Scalar: stride_routing_map_expert + row_id_stride_token, # Scalar: stride_row_id_map_token + row_id_stride_expert, # Scalar: stride_row_id_map_expert + kernel=_row_id_map_pass_1_kernel, + out_shape=[ + ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), + ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), + ], + grid=grid, + BLOCK_SIZE=block_size, # Constexpr - pass as keyword + ) + + # Pass 2: cumsum all and process the mask + # Strides remain the same as Pass 1 + # Note: Pass 2 takes the outputs from Pass 1 as 
inputs + row_id_map_pass2, workspace_tensor = jt.triton_call( + row_id_map_pass1, # Input 0 (ptr): row_id_map_ptr (from Pass 1) + workspace_tensor, # Input 1 (ptr): workspace_ptr (from Pass 1) + num_tokens, # Scalar: num_tokens + row_id_stride_token, # Scalar: stride_row_id_map_token + row_id_stride_expert, # Scalar: stride_row_id_map_expert + kernel=_row_id_map_pass_2_kernel, + out_shape=[ + ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), + ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), + ], + input_output_aliases={0: 0, 1: 1}, # row_id_map input→output, workspace input→output + grid=grid, + WORKSPACE_LOAD_WIDTH=triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), # Constexpr + BLOCK_SIZE=block_size, # Constexpr + ) + # Initialize columns [num_experts:] to -1 since Pass 1/2 only wrote to [0:num_experts] + # Reference implementation expects -1 for invalid entries, not garbage + row_id_map = row_id_map_pass2.at[:, num_experts:].set(-1) + + # Pass 3: make the row_id_map from sparse to dense structure + grid = (num_tokens,) + load_size = triton.next_power_of_2(num_experts) + row_id_map = jt.triton_call( + row_id_map, # Input 0 (ptr): row_id_map_ptr (from Pass 2, with -1 initialized) + row_id_stride_token, # Scalar 1: stride_row_id_map_token + row_id_stride_expert, # Scalar 2: stride_row_id_map_expert + kernel=_row_id_map_pass_3_kernel, + out_shape=[ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype)], + input_output_aliases={0: 0}, # row_id_map input→output + num_experts=num_experts, + grid=grid, + LOAD_SIZE=load_size, # Constexpr + )[0] + + return row_id_map + + +def permute_with_mask_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: Optional[jnp.ndarray], + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Permute the input tensor based on the row_id_map using JAX-Triton. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + probs : Optional[jnp.ndarray] + The probabilities of the input tensor. If it is not None, it will be permuted. + num_tokens : int + Number of tokens in the input tensor. + num_experts : int + Number of experts in the input tensor. + num_out_tokens : int + Number of tokens in the permuted tensor. + hidden_size : int + Hidden size of the input tensor. + + Returns + ------- + output : jnp.ndarray + Permuted output tensor of shape `[num_out_tokens, hidden_size]`. + permuted_probs : Optional[jnp.ndarray] + Permuted probabilities if probs was provided, None otherwise. 
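+
+    Examples
+    --------
+    Illustrative call sequence mirroring the accompanying tests; `x`, `probs`
+    and the size variables are placeholders rather than fixed names::
+
+        row_id_map = make_row_id_map(routing_map, num_tokens, num_experts)
+        num_out_tokens = int(jnp.sum(routing_map))
+        permuted, permuted_probs = permute_with_mask_map(
+            x, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size
+        )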
+ """ + # Compute strides manually (JAX arrays don't have .strides attribute) + # For inp of shape [num_tokens, hidden_size], C-contiguous: + inp_stride_token = hidden_size + inp_stride_hidden = 1 + # For output of shape [num_out_tokens, hidden_size], C-contiguous: + output_stride_token = hidden_size + output_stride_hidden = 1 + # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + # For probs: depends on dimensionality + if probs is not None: + if probs.ndim > 1: + # Shape [num_tokens, num_experts] + probs_stride_token = num_experts + probs_stride_expert = 1 + else: + # Shape [num_tokens] + probs_stride_token = 1 + probs_stride_expert = 1 + else: + probs_stride_token = 0 + probs_stride_expert = 0 + # For permuted_probs of shape [num_out_tokens], C-contiguous: + permuted_probs_stride_token = 1 + + # Grid: one block per token, multiple blocks for hidden dimension + def grid_fn(meta): + return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) + + if probs is not None: + # jax-triton doesn't handle None pointers correctly, create dummy tensors + # Make dummy tensors large enough to not cause out-of-bounds access + dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) + + output, permuted_probs = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + probs, # Input 2 (ptr): probs_ptr + dummy_scale, # Input 3 (ptr): scale_ptr (dummy, not used) + dummy_permuted_scale, # Input 4 (ptr): permuted_scale_ptr (dummy, not used) + 0, # Scalar 5: scale_hidden_dim (not used) + row_id_stride_token, # Scalar 6: stride_row_id_map_token + row_id_stride_expert, # Scalar 7: stride_row_id_map_expert + inp_stride_token, # Scalar 8: stride_input_token + inp_stride_hidden, # Scalar 9: stride_input_hidden + output_stride_token, # Scalar 10: stride_output_token + output_stride_hidden, # Scalar 11: stride_output_hidden + probs_stride_token, # Scalar 12: stride_probs_token + probs_stride_expert, # Scalar 13: stride_probs_expert + hidden_size, # Scalar 14: stride_scale_token (use actual stride) + 1, # Scalar 15: stride_scale_hidden + permuted_probs_stride_token, # Scalar 16: stride_permuted_probs_token + hidden_size, # Scalar 17: stride_permuted_scale_token (use actual stride) + 1, # Scalar 18: stride_permuted_scale_hidden + kernel=_permute_kernel, + out_shape=[ + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), # Positional: output_ptr + ShapeDtypeStruct((num_out_tokens,), probs.dtype), # Positional: permuted_probs_ptr + ], + grid=grid_fn, + num_experts=num_experts, # Keyword constexpr + hidden_size=hidden_size, # Keyword constexpr + PERMUTE_PROBS=True, # Keyword constexpr + PERMUTE_SCALE=False, # Keyword constexpr + # BLOCK_SIZE is keyword constexpr from autotune + ) + else: + # jax-triton doesn't handle None pointers correctly, create dummy tensors + # Make dummy tensors large enough to not cause out-of-bounds access + dummy_probs = jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) + dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) + + result = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + dummy_probs, # Input 2 (ptr): probs_ptr (dummy, not used) + dummy_scale, # Input 3 (ptr): scale_ptr (dummy, not used) + 
dummy_permuted_scale, # Input 4 (ptr): permuted_scale_ptr (dummy, not used) + 0, # Scalar 5: scale_hidden_dim (not used) + row_id_stride_token, # Scalar 6: stride_row_id_map_token + row_id_stride_expert, # Scalar 7: stride_row_id_map_expert + inp_stride_token, # Scalar 8: stride_input_token + inp_stride_hidden, # Scalar 9: stride_input_hidden + output_stride_token, # Scalar 10: stride_output_token + output_stride_hidden, # Scalar 11: stride_output_hidden + probs_stride_token, # Scalar 12: stride_probs_token (use actual) + probs_stride_expert, # Scalar 13: stride_probs_expert (use actual) + hidden_size, # Scalar 14: stride_scale_token (use actual stride) + 1, # Scalar 15: stride_scale_hidden + permuted_probs_stride_token, # Scalar 16: stride_permuted_probs_token (use actual) + hidden_size, # Scalar 17: stride_permuted_scale_token (use actual stride) + 1, # Scalar 18: stride_permuted_scale_hidden + kernel=_permute_kernel, + out_shape=[ + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), # Positional: output_ptr + ShapeDtypeStruct((num_out_tokens,), inp.dtype), # Positional: permuted_probs_ptr (dummy) + ], + grid=grid_fn, + num_experts=num_experts, # Keyword constexpr + hidden_size=hidden_size, # Keyword constexpr + PERMUTE_PROBS=False, # Keyword constexpr + PERMUTE_SCALE=False, # Keyword constexpr + # BLOCK_SIZE is keyword constexpr from autotune + ) + output = result[0] + permuted_probs = None + + return output, permuted_probs + + +def unpermute_with_mask_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + merging_probs: Optional[jnp.ndarray], + permuted_probs: Optional[jnp.ndarray], + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Unpermute the input tensor based on the row_id_map using JAX-Triton. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_out_tokens, hidden_size]`. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`. + merging_probs : Optional[jnp.ndarray] + The merging probabilities of the input tensor. If it is not None, it will be used as weights + to reduce the unpermuted tokens. + permuted_probs : Optional[jnp.ndarray] + The permuted probabilities of the input tensor. If it is not None, it will be unpermuted. + num_tokens : int + Number of tokens in the permuted tensor. + num_experts : int + Number of experts in the permuted tensor. + hidden_size : int + Hidden size of the permuted tensor. + + Returns + ------- + output : jnp.ndarray + Unpermuted output tensor of shape `[num_tokens, hidden_size]`. + unpermuted_probs : Optional[jnp.ndarray] + Unpermuted probabilities if permuted_probs was provided, None otherwise. 
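+
+    Notes
+    -----
+    Conceptually (mirroring the pure-JAX reference used in the unit tests),
+    each destination token sums the routed copies it finds via row_id_map:
+    `output[t] = sum_e merging_probs[t, e] * inp[src_row(t, e)]` when
+    merging_probs is given, and an unweighted sum over the routed rows
+    otherwise. If permuted_probs is given, `unpermuted_probs[t, e]` receives
+    `permuted_probs[src_row(t, e)]`.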
+ """ + # Compute strides manually (JAX arrays don't have .strides attribute) + # For inp of shape [num_out_tokens, hidden_size], C-contiguous: + inp_stride_token = hidden_size + inp_stride_hidden = 1 + # For output of shape [num_tokens, hidden_size], C-contiguous: + output_stride_token = hidden_size + output_stride_hidden = 1 + # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 + # For merging_probs of shape [num_tokens, num_experts] if present: + if merging_probs is not None: + merging_probs_stride_token = num_experts + merging_probs_stride_expert = 1 + else: + merging_probs_stride_token = 0 + merging_probs_stride_expert = 0 + # For permuted_probs of shape [num_out_tokens] if present: + permuted_probs_stride_token = 1 + # For unpermuted_probs of shape [num_tokens, num_experts] (output): + unpermuted_probs_stride_token = num_experts + unpermuted_probs_stride_expert = 1 + + # Grid: one block per token, multiple blocks for hidden dimension + def grid_fn(meta): + return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) + + if permuted_probs is not None: + # Ensure merging_probs is not None (use dummy if needed) + merging_probs_arg = merging_probs if merging_probs is not None else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) + + output, unpermuted_probs = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + merging_probs_arg, # Input 2 (ptr): merging_probs_ptr (real or dummy) + permuted_probs, # Input 3 (ptr): permuted_probs_ptr + row_id_stride_token, # Scalar 4: stride_row_id_map_token + row_id_stride_expert, # Scalar 5: stride_row_id_map_expert + inp_stride_token, # Scalar 6: stride_input_token + inp_stride_hidden, # Scalar 7: stride_input_hidden + output_stride_token, # Scalar 8: stride_output_token + output_stride_hidden, # Scalar 9: stride_output_hidden + merging_probs_stride_token, # Scalar 10: stride_merging_probs_token + merging_probs_stride_expert, # Scalar 11: stride_merging_probs_expert + permuted_probs_stride_token, # Scalar 12: stride_permuted_probs_token + unpermuted_probs_stride_token, # Scalar 13: stride_unpermuted_probs_token + unpermuted_probs_stride_expert, # Scalar 14: stride_unpermuted_probs_expert + kernel=_unpermute_kernel, + out_shape=[ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional: output_ptr + ShapeDtypeStruct((num_tokens, num_experts), permuted_probs.dtype), # Positional: unpermuted_probs_ptr + ], + grid=grid_fn, + num_experts=num_experts, # Keyword constexpr + hidden_size=hidden_size, # Keyword constexpr + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), # Keyword constexpr + WITH_MERGING_PROBS=merging_probs is not None, # Keyword constexpr + PERMUTE_PROBS=True, # Keyword constexpr + # BLOCK_SIZE is keyword constexpr from autotune + ) + else: + # jax-triton doesn't handle None pointers correctly, create dummy tensors if needed + dummy_permuted_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) # Proper size dummy + merging_probs_arg = merging_probs if merging_probs is not None else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) + + result = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + merging_probs_arg, # Input 2 (ptr): merging_probs_ptr (real or dummy) + dummy_permuted_probs, # Input 3 (ptr): permuted_probs_ptr (dummy, not used) + row_id_stride_token, # Scalar 4: stride_row_id_map_token + row_id_stride_expert, # Scalar 5: 
stride_row_id_map_expert + inp_stride_token, # Scalar 6: stride_input_token + inp_stride_hidden, # Scalar 7: stride_input_hidden + output_stride_token, # Scalar 8: stride_output_token + output_stride_hidden, # Scalar 9: stride_output_hidden + merging_probs_stride_token, # Scalar 10: stride_merging_probs_token + merging_probs_stride_expert, # Scalar 11: stride_merging_probs_expert + 1, # Scalar 12: stride_permuted_probs_token (dummy stride) + 0, # Scalar 13: stride_unpermuted_probs_token + 0, # Scalar 14: stride_unpermuted_probs_expert + kernel=_unpermute_kernel, + out_shape=[ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional: output_ptr + ShapeDtypeStruct((num_tokens, num_experts), inp.dtype), # Positional: unpermuted_probs_ptr (dummy) + ], + grid=grid_fn, + num_experts=num_experts, # Keyword constexpr + hidden_size=hidden_size, # Keyword constexpr + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), # Keyword constexpr + WITH_MERGING_PROBS=merging_probs is not None, # Keyword constexpr + PERMUTE_PROBS=False, # Keyword constexpr + # BLOCK_SIZE is keyword constexpr from autotune + ) + output = result[0] + unpermuted_probs = None + + return output, unpermuted_probs + + +def make_chunk_sort_map( + split_sizes: jnp.ndarray, + sorted_indices: jnp.ndarray, + num_tokens: int, + num_splits: int, +) -> jnp.ndarray: + """ + Make a row_id_map for chunk sort using JAX-Triton. + + Parameters + ---------- + split_sizes : jnp.ndarray + The sizes of the chunks of shape `[num_splits,]`. + sorted_indices : jnp.ndarray + The indices of the sorted chunks of shape `[num_splits,]`. + num_tokens : int + Number of tokens in the input tensor. + num_splits : int + Number of splits of split_sizes and sorted_indices. + + Returns + ------- + row_id_map : jnp.ndarray + Row ID map for chunk sorting of shape `[num_tokens,]`. + """ + grid = (num_tokens,) + + row_id_map = jt.triton_call( + split_sizes, # Input 0 (ptr): split_sizes_ptr + sorted_indices, # Input 1 (ptr): sorted_indices_ptr + kernel=_make_chunk_sort_map_kernel, + out_shape=[ShapeDtypeStruct((num_tokens,), jnp.int32)], + grid=grid, + num_splits=num_splits, # Constexpr + IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), # Constexpr + )[0] + + return row_id_map + + +def sort_chunks_by_map( + inp: jnp.ndarray, + row_id_map: jnp.ndarray, + probs: Optional[jnp.ndarray], + num_tokens: int, + hidden_size: int, + is_forward: bool, +) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """ + Sort chunks with row_id_map using JAX-Triton. + + Parameters + ---------- + inp : jnp.ndarray + Input tensor of shape `[num_tokens, hidden_size]`. + row_id_map : jnp.ndarray + The token to expert mapping tensor of shape `[num_tokens,]`. + probs : Optional[jnp.ndarray] + The probabilities of the input tensor. If it is not None, it will be permuted. + num_tokens : int + Number of tokens in the input tensor. + hidden_size : int + Hidden size of the input tensor. + is_forward : bool + Whether the sort is for forward or backward. + + Returns + ------- + output : jnp.ndarray + Sorted output tensor of shape `[num_tokens, hidden_size]`. + permuted_probs : Optional[jnp.ndarray] + Sorted probabilities if probs was provided, None otherwise. 
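+
+    Notes
+    -----
+    Illustrative sketch only (toy sizes; assumes a CUDA device with Triton
+    available). Sorting forward and then backward with the same map restores
+    the input, which is the property the round-trip unit test checks:
+
+    >>> import jax.numpy as jnp
+    >>> split_sizes = jnp.array([2, 1], dtype=jnp.int32)
+    >>> sorted_indices = jnp.array([1, 0], dtype=jnp.int32)
+    >>> row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, 3, 2)
+    >>> x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
+    >>> y, _ = sort_chunks_by_map(x, row_id_map, None, 3, 4, is_forward=True)
+    >>> z, _ = sort_chunks_by_map(y, row_id_map, None, 3, 4, is_forward=False)
+    >>> bool(jnp.allclose(z, x))
+    True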
+ """ + # Compute strides manually (JAX arrays don't have .strides attribute) + # For inp and output of shape [num_tokens, hidden_size], C-contiguous: + inp_stride_token = hidden_size + inp_stride_hidden = 1 + output_stride_token = hidden_size + output_stride_hidden = 1 + # For probs and permuted_probs of shape [num_tokens], C-contiguous: + probs_stride_token = 1 + permuted_probs_stride_token = 1 + + # Grid: one block per token, multiple blocks for hidden dimension + def grid_fn(meta): + return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) + + if probs is not None: + output, permuted_probs = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + probs, # Input 2 (ptr): probs_ptr + inp_stride_token, # Scalar 3: stride_input_token + inp_stride_hidden, # Scalar 4: stride_input_hidden + output_stride_token, # Scalar 5: stride_output_token + output_stride_hidden, # Scalar 6: stride_output_hidden + probs_stride_token, # Scalar 7: stride_probs_token + permuted_probs_stride_token, # Scalar 8: stride_permuted_probs_token + kernel=_sort_chunks_by_map_kernel, + out_shape=[ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Added at end: output_ptr + ShapeDtypeStruct((num_tokens,), probs.dtype), # Added at end: permuted_probs_ptr + ], + grid=grid_fn, + hidden_size=hidden_size, # Constexpr 9: hidden_size + PERMUTE_PROBS=True, # Constexpr 10: PERMUTE_PROBS + # BLOCK_SIZE is Constexpr 11, provided by autotune + FORWARD=is_forward, # Constexpr 12: FORWARD + # output_ptr and permuted_probs_ptr (13-14) are added automatically by jax-triton from out_shape + ) + else: + # Note: jax-triton might not handle None correctly, so create a dummy probs tensor + dummy_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) + + result = jt.triton_call( + inp, # Input 0 (ptr): input_ptr + row_id_map, # Input 1 (ptr): row_id_map_ptr + dummy_probs, # Input 2 (ptr): probs_ptr (dummy, not used by kernel) + inp_stride_token, # Scalar 3: stride_input_token + inp_stride_hidden, # Scalar 4: stride_input_hidden + output_stride_token, # Scalar 5: stride_output_token + output_stride_hidden, # Scalar 6: stride_output_hidden + probs_stride_token, # Scalar 7: stride_probs_token (use actual stride) + permuted_probs_stride_token, # Scalar 8: stride_permuted_probs_token + kernel=_sort_chunks_by_map_kernel, + out_shape=[ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional after strides: output_ptr + ShapeDtypeStruct((num_tokens,), inp.dtype), # Positional after strides: permuted_probs_ptr (dummy) + ], + grid=grid_fn, + hidden_size=hidden_size, # Keyword constexpr + PERMUTE_PROBS=False, # Keyword constexpr + FORWARD=is_forward, # Keyword constexpr + # BLOCK_SIZE is added by autotune as keyword constexpr + ) + output = result[0] + permuted_probs = None + + return output, permuted_probs \ No newline at end of file From c804e5f28cf7c0d7c343d95354b768a64f2fc3a6 Mon Sep 17 00:00:00 2001 From: tdophung Date: Mon, 24 Nov 2025 18:00:25 -0800 Subject: [PATCH 07/10] Clean up for MR Signed-off-by: tdophung --- tests/jax/test_permutation.py | 560 +++++++----------- .../common/triton/permutation.py | 2 +- transformer_engine/jax/triton/permutation.py | 437 +++++++------- 3 files changed, 452 insertions(+), 547 deletions(-) diff --git a/tests/jax/test_permutation.py b/tests/jax/test_permutation.py index 37c4208df7..2490737ac6 100644 --- a/tests/jax/test_permutation.py +++ b/tests/jax/test_permutation.py @@ -4,8 +4,6 @@ """Tests for permutation Triton kernels""" -# Patch 
jax-triton for Triton 3.5.1 compatibility - MUST BE FIRST! - import jax import jax.numpy as jnp import pytest @@ -28,7 +26,7 @@ def reference_make_row_id_map( ) -> jnp.ndarray: """ Reference implementation of make_row_id_map using JAX primitives. - + Parameters ---------- routing_map : jnp.ndarray @@ -38,29 +36,29 @@ def reference_make_row_id_map( Number of tokens in the input tensor. num_experts : int Number of experts in the input tensor. - + Returns ------- row_id_map : jnp.ndarray The row_id_map for the permutation of shape [num_tokens, num_experts * 2 + 1]. """ row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) - + # For each expert, compute cumulative sum to get destination indices cumsum_per_expert = jnp.cumsum(routing_map, axis=0) - + # Compute total tokens per expert tokens_per_expert = jnp.sum(routing_map, axis=0) expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]]) - + # Build the row_id_map for token_idx in range(num_tokens): routed_experts = jnp.where(routing_map[token_idx] == 1)[0] n_routed = len(routed_experts) - + # Store number of routed experts in the last position row_id_map = row_id_map.at[token_idx, -1].set(n_routed) - + # For each routed expert, compute destination row and store it dest_rows = [] expert_indices = [] @@ -69,18 +67,18 @@ def reference_make_row_id_map( dest_row = expert_offsets[expert_idx] + cumsum_per_expert[token_idx, expert_idx] - 1 dest_rows.append(dest_row) expert_indices.append(expert_idx) - + # Sort by destination row if n_routed > 0: sort_indices = jnp.argsort(-jnp.array(dest_rows)) # Negative for descending sort sorted_dest_rows = jnp.array(dest_rows)[sort_indices] sorted_expert_indices = jnp.array(expert_indices)[sort_indices] - + # Store sorted destination rows and expert indices for i in range(n_routed): row_id_map = row_id_map.at[token_idx, i].set(sorted_dest_rows[i]) row_id_map = row_id_map.at[token_idx, num_experts + i].set(sorted_expert_indices[i]) - + return row_id_map @@ -95,7 +93,7 @@ def reference_permute_with_mask_map( ) -> tuple: """ Reference implementation of permute_with_mask_map using JAX primitives. - + Parameters ---------- inp : jnp.ndarray @@ -112,7 +110,7 @@ def reference_permute_with_mask_map( Number of tokens in the permuted tensor. hidden_size : int Hidden size of the input tensor. - + Returns ------- output : jnp.ndarray @@ -122,30 +120,30 @@ def reference_permute_with_mask_map( """ output = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) permuted_probs = None if probs is None else jnp.zeros((num_out_tokens,), dtype=probs.dtype) - + for token_idx in range(num_tokens): n_routed = int(row_id_map[token_idx, -1]) for i in range(n_routed): dest_row = int(row_id_map[token_idx, i]) expert_idx = int(row_id_map[token_idx, num_experts + i]) - + # Get probability for this expert if probs is not None: if probs.ndim == 1: prob = probs[token_idx] else: prob = probs[token_idx, expert_idx] - + # Match kernel behavior: if prob == 0.0, zero out the output (padding indicator) if prob == 0.0: output = output.at[dest_row].set(0.0) else: output = output.at[dest_row].set(inp[token_idx]) - + permuted_probs = permuted_probs.at[dest_row].set(prob) else: output = output.at[dest_row].set(inp[token_idx]) - + return output, permuted_probs @@ -160,7 +158,7 @@ def reference_unpermute_with_mask_map( ) -> tuple: """ Reference implementation of unpermute_with_mask_map using JAX primitives. 
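+    Serves as the ground truth for the Triton-backed wrapper: each token
+    gathers its routed rows back via row_id_map and, when merging_probs is
+    given, combines them as a weighted sum.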
- + Parameters ---------- inp : jnp.ndarray @@ -177,7 +175,7 @@ def reference_unpermute_with_mask_map( Number of experts. hidden_size : int Hidden size. - + Returns ------- output : jnp.ndarray @@ -186,23 +184,29 @@ def reference_unpermute_with_mask_map( Unpermuted probabilities if permuted_probs was provided, None otherwise. """ output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - unpermuted_probs = None if permuted_probs is None else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) - + unpermuted_probs = ( + None + if permuted_probs is None + else jnp.zeros((num_tokens, num_experts), dtype=permuted_probs.dtype) + ) + for token_idx in range(num_tokens): n_routed = int(row_id_map[token_idx, -1]) for i in range(n_routed): src_row = int(row_id_map[token_idx, i]) expert_idx = int(row_id_map[token_idx, num_experts + i]) - + if merging_probs is not None: weight = merging_probs[token_idx, expert_idx] output = output.at[token_idx].add(inp[src_row] * weight) else: output = output.at[token_idx].add(inp[src_row]) - + if permuted_probs is not None: - unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set(permuted_probs[src_row]) - + unpermuted_probs = unpermuted_probs.at[token_idx, expert_idx].set( + permuted_probs[src_row] + ) + return output, unpermuted_probs @@ -214,7 +218,7 @@ def reference_make_chunk_sort_map( ) -> jnp.ndarray: """ Reference implementation of make_chunk_sort_map using JAX primitives. - + Parameters ---------- split_sizes : jnp.ndarray @@ -225,30 +229,30 @@ def reference_make_chunk_sort_map( Number of tokens. num_splits : int Number of splits. - + Returns ------- row_id_map : jnp.ndarray Row ID map for chunk sorting of shape [num_tokens,]. """ row_id_map = jnp.zeros((num_tokens,), dtype=jnp.int32) - + # Compute cumulative positions cumsum_sizes = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)]) - + # For each chunk, compute the destination indices dest_offset = 0 for sorted_idx in sorted_indices: chunk_start = cumsum_sizes[sorted_idx] chunk_end = cumsum_sizes[sorted_idx + 1] chunk_size = chunk_end - chunk_start - + # Map source positions to destination positions for i in range(chunk_size): row_id_map = row_id_map.at[chunk_start + i].set(dest_offset + i) - + dest_offset += chunk_size - + return row_id_map @@ -262,7 +266,7 @@ def reference_sort_chunks_by_map( ) -> tuple: """ Reference implementation of sort_chunks_by_map using JAX primitives. - + Parameters ---------- inp : jnp.ndarray @@ -277,7 +281,7 @@ def reference_sort_chunks_by_map( Hidden size. is_forward : bool Whether this is forward or backward. 
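+        Forward applies the chunk sort and backward its inverse; a forward
+        pass followed by a backward pass is expected to restore the original
+        order, as the round-trip test exercises.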
- + Returns ------- output : jnp.ndarray @@ -287,7 +291,7 @@ def reference_sort_chunks_by_map( """ output = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) permuted_probs = None if probs is None else jnp.zeros((num_tokens,), dtype=probs.dtype) - + if is_forward: # Forward: src -> dest for src_idx in range(num_tokens): @@ -302,13 +306,13 @@ def reference_sort_chunks_by_map( output = output.at[dest_idx].set(inp[src_idx]) if probs is not None: permuted_probs = permuted_probs.at[dest_idx].set(probs[src_idx]) - + return output, permuted_probs class TestPermutation: """Test permutation operations implementation""" - + @staticmethod def generate_routing_map( num_tokens: int, @@ -318,7 +322,7 @@ def generate_routing_map( use_fixed_per_token: bool = True, ): """Generate random routing map for testing - + Parameters ---------- num_tokens : int @@ -336,11 +340,11 @@ def generate_routing_map( """ if key is None: key = jax.random.PRNGKey(0) - + if use_fixed_per_token: - # Old behavior: each token routes to exactly tokens_per_expert experts + # Each token is routed to the same number of experts. The experts are chosen randomly routing_map = jnp.zeros((num_tokens, num_experts), dtype=jnp.int32) - + # Randomly assign each token to tokens_per_expert experts for token_idx in range(num_tokens): key, subkey = jax.random.split(key) @@ -349,479 +353,373 @@ def generate_routing_map( ) routing_map = routing_map.at[token_idx, expert_indices].set(1) else: - # PyTorch-style: randomly distribute routing (varying n_routed per token) + # Varying n_routed per token num_out_tokens = num_tokens * tokens_per_expert - + # Create flat array with num_out_tokens ones flat_array = jnp.zeros((num_tokens * num_experts,), dtype=jnp.int32) flat_array = flat_array.at[:num_out_tokens].set(1) - + # Randomly permute key, subkey = jax.random.split(key) permuted_indices = jax.random.permutation(subkey, num_tokens * num_experts) flat_array = flat_array[permuted_indices] - + # Reshape to routing_map routing_map = flat_array.reshape((num_tokens, num_experts)) - + return routing_map - - # Test make_row_id_map - @pytest.mark.parametrize("num_tokens,num_experts,tokens_per_expert", [ - (32, 8, 2), - (64, 16, 3), - (128, 8, 1), - ]) + + @pytest.mark.parametrize( + "num_tokens,num_experts,tokens_per_expert", + [ + (32, 8, 2), + (64, 16, 3), + (128, 8, 1), + ], + ) @pytest.mark.parametrize("use_fixed_per_token", [True, False]) def test_make_row_id_map(self, num_tokens, num_experts, tokens_per_expert, use_fixed_per_token): """Test make_row_id_map against reference implementation""" key = jax.random.PRNGKey(42) - + # Generate routing map routing_map = self.generate_routing_map( num_tokens, num_experts, tokens_per_expert, key, use_fixed_per_token ) - - # Test implementation + test_row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - - # Reference implementation + ref_row_id_map = reference_make_row_id_map(routing_map, num_tokens, num_experts) - - # Pretty print for debugging - # print("\n" + "="*100) - # print(f"TEST: make_row_id_map (num_tokens={num_tokens}, num_experts={num_experts}, tokens_per_expert={tokens_per_expert})") - # print("="*100) - - # print("\n📊 ROUTING MAP (rows=tokens, cols=experts):") - # print("-"*100) - # print(routing_map) - - # print("\n📋 FULL ARRAYS:") - # print("-"*100) - # print("\n🔴 ACTUAL (JAX/Triton implementation) - AFTER PASS 3:") - # print(test_row_id_map) - # print("\n🔴 ACTUAL - Columns breakdown:") - # print(f" Sorted dest rows [0:{num_experts}]:") - # print(test_row_id_map[:, :num_experts]) - # 
print(f" Expert indices [{num_experts}:{2*num_experts}]:") - # print(test_row_id_map[:, num_experts:2*num_experts]) - # print(f" n_routed (last column):") - # print(test_row_id_map[:, -1]) - - # print("\n🟢 EXPECTED (Reference implementation):") - # print(ref_row_id_map) - # print("\n🟢 EXPECTED - Columns breakdown:") - # print(f" Sorted dest rows [0:{num_experts}]:") - # print(ref_row_id_map[:, :num_experts]) - # print(f" Expert indices [{num_experts}:{2*num_experts}]:") - # print(ref_row_id_map[:, num_experts:2*num_experts]) - # print(f" n_routed (last column):") - # print(ref_row_id_map[:, -1]) - - # print("\n🔍 DIFFERENCE (Actual - Expected):") - # diff = test_row_id_map - ref_row_id_map - # print(diff) - - # mismatch_count = jnp.sum(diff != 0) - # total_elements = test_row_id_map.size - # print(f"\n📊 STATISTICS:") - # print(f" Total elements: {total_elements}") - # print(f" Mismatched elements: {mismatch_count} ({100*mismatch_count/total_elements:.1f}%)") - # print(f" Max absolute difference: {jnp.max(jnp.abs(diff))}") - - # print("\n" + "="*100 + "\n") - - # Compare results - only compare valid positions (first n_routed in each section) - # Invalid positions may contain garbage (PyTorch) or -1 (JAX reference), but they're never accessed + + # Compare results only at valid positions (first n_routed in each section) for token_idx in range(num_tokens): n_routed = int(ref_row_id_map[token_idx, -1]) - + # Compare valid dest rows [0:n_routed] assert_allclose( test_row_id_map[token_idx, :n_routed], ref_row_id_map[token_idx, :n_routed], - rtol=0, atol=0, - err_msg=f"Mismatch in dest rows for token {token_idx}" + rtol=0, + atol=0, + err_msg=f"Mismatch in dest rows for token {token_idx}", ) - + # Compare valid expert indices [num_experts:num_experts+n_routed] assert_allclose( - test_row_id_map[token_idx, num_experts:num_experts+n_routed], - ref_row_id_map[token_idx, num_experts:num_experts+n_routed], - rtol=0, atol=0, - err_msg=f"Mismatch in expert indices for token {token_idx}" + test_row_id_map[token_idx, num_experts : num_experts + n_routed], + ref_row_id_map[token_idx, num_experts : num_experts + n_routed], + rtol=0, + atol=0, + err_msg=f"Mismatch in expert indices for token {token_idx}", ) - + # Compare n_routed (last column) assert_allclose( test_row_id_map[token_idx, -1], ref_row_id_map[token_idx, -1], - rtol=0, atol=0, - err_msg=f"Mismatch in n_routed for token {token_idx}" + rtol=0, + atol=0, + err_msg=f"Mismatch in n_routed for token {token_idx}", ) - - # # Optional: Also do a full comparison if both use -1 for invalid positions - # # This will help catch uninitialized memory issues - # if jnp.all((test_row_id_map == -1) | (test_row_id_map >= 0)): - # print("🔬 Both use -1 for invalid positions, doing full comparison...") - # assert_allclose(test_row_id_map, ref_row_id_map, rtol=0, atol=0) - + # Test permute_with_mask_map - @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ - (32, 8, 256, 2), - (64, 16, 512, 3), - # Add smaller test cases for easier debugging - (16, 4, 64, 2), # Small case for debugging - (8, 2, 32, 1), # Minimal case - ]) + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert", + [ + (32, 8, 256, 2), + (64, 16, 512, 3), + # Smaller test cases for easier debugging + # (8, 2, 32, 1), + ], + ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize("with_probs", [True, False]) - def test_permute_with_mask_map(self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, 
with_probs): + def test_permute_with_mask_map( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, with_probs + ): """Test permute_with_mask_map against reference implementation""" key = jax.random.PRNGKey(42) - + # Generate routing map routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - - # Generate row_id_map + row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - - # Calculate number of output tokens num_out_tokens = int(jnp.sum(routing_map)) - + # Generate input data key, inp_key, prob_key = jax.random.split(key, 3) - inp = jax.random.uniform(inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) - + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + if with_probs: - probs = jax.random.uniform(prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0) + probs = jax.random.uniform( + prob_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + ) else: probs = None - - # Test implementation + test_output, test_probs = permute_with_mask_map( inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size ) - - # Reference implementation + ref_output, ref_probs = reference_permute_with_mask_map( inp, row_id_map, probs, num_tokens, num_experts, num_out_tokens, hidden_size ) - - # Debug output for bfloat16 failures - if dtype == jnp.bfloat16 and with_probs: - print(f"\n{'='*100}") - print(f"DEBUG: test_permute_with_mask_map (dtype=bfloat16, with_probs=True)") - print(f" num_tokens={num_tokens}, num_experts={num_experts}, hidden_size={hidden_size}") - print(f" num_out_tokens={num_out_tokens}, tokens_per_expert={tokens_per_expert}") - print(f"{'='*100}") - - # Check output differences (convert to float32 for printing) - output_diff = jnp.abs(test_output.astype(jnp.float32) - ref_output.astype(jnp.float32)) - print(f"\n📊 OUTPUT DIFFERENCES:") - print(f" Max diff: {float(jnp.max(output_diff)):.6f}") - print(f" Mean diff: {float(jnp.mean(output_diff)):.6f}") - print(f" Median diff: {float(jnp.median(output_diff)):.6f}") - print(f" Num elements with diff > 0.1: {int(jnp.sum(output_diff > 0.1))}") - print(f" Num elements with diff > 0.5: {int(jnp.sum(output_diff > 0.5))}") - print(f" Num elements with diff > 0.9: {int(jnp.sum(output_diff > 0.9))}") - - # Find worst mismatches - flat_diff = output_diff.flatten() - worst_indices = jnp.argsort(flat_diff)[-10:] # Top 10 worst - print(f"\n🔍 WORST MISMATCHES (flattened indices):") - for i, idx in enumerate(worst_indices): - row = int(idx // hidden_size) - col = int(idx % hidden_size) - actual_val = float(test_output.flatten()[idx]) - expected_val = float(ref_output.flatten()[idx]) - diff_val = float(flat_diff[idx]) - print(f" [{i}] position=({row},{col}), actual={actual_val:.4f}, " - f"expected={expected_val:.4f}, diff={diff_val:.4f}") - - # Check probs differences if present - if test_probs is not None and ref_probs is not None: - prob_diff = jnp.abs(test_probs.astype(jnp.float32) - ref_probs.astype(jnp.float32)) - print(f"\n📊 PROBS DIFFERENCES:") - print(f" Max diff: {float(jnp.max(prob_diff)):.6f}") - print(f" Mean diff: {float(jnp.mean(prob_diff)):.6f}") - print(f" Num elements with diff > 0.1: {int(jnp.sum(prob_diff > 0.1))}") - - # Check input data quality - print(f"\n📊 INPUT DATA STATS:") - print(f" inp range: [{float(jnp.min(inp)):.4f}, {float(jnp.max(inp)):.4f}]") - print(f" inp has NaN: {bool(jnp.any(jnp.isnan(inp)))}") - if probs is not None: - print(f" 
probs range: [{float(jnp.min(probs)):.4f}, {float(jnp.max(probs)):.4f}]") - print(f"\n{'='*100}\n") - + # Compare results tols = dtype_tols(dtype) assert_allclose(test_output, ref_output, **tols) - + if with_probs: assert_allclose(test_probs, ref_probs, **tols) - + # Test unpermute_with_mask_map - @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ]) + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert", + [ + (32, 8, 256, 2), + (64, 16, 512, 3), + ], + ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize("with_merging_probs", [True, False]) @pytest.mark.parametrize("with_permuted_probs", [True, False]) def test_unpermute_with_mask_map( - self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype, - with_merging_probs, with_permuted_probs + self, + num_tokens, + num_experts, + hidden_size, + tokens_per_expert, + dtype, + with_merging_probs, + with_permuted_probs, ): """Test unpermute_with_mask_map against reference implementation""" key = jax.random.PRNGKey(42) - + # Generate routing map routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - + # Generate row_id_map row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - + # Calculate number of output tokens num_out_tokens = int(jnp.sum(routing_map)) - + # Generate input data key, inp_key, merge_key, prob_key = jax.random.split(key, 4) - inp = jax.random.uniform(inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) - + inp = jax.random.uniform( + inp_key, (num_out_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + if with_merging_probs: - merging_probs = jax.random.uniform(merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0) + merging_probs = jax.random.uniform( + merge_key, (num_tokens, num_experts), dtype=dtype, minval=0.0, maxval=1.0 + ) # Normalize merging probs per token merging_probs = merging_probs / (jnp.sum(merging_probs, axis=1, keepdims=True) + 1e-8) else: merging_probs = None - + if with_permuted_probs: - permuted_probs = jax.random.uniform(prob_key, (num_out_tokens,), dtype=dtype, minval=0.0, maxval=1.0) + permuted_probs = jax.random.uniform( + prob_key, (num_out_tokens,), dtype=dtype, minval=0.0, maxval=1.0 + ) else: permuted_probs = None - - # Test implementation + test_output, test_unprobs = unpermute_with_mask_map( inp, row_id_map, merging_probs, permuted_probs, num_tokens, num_experts, hidden_size ) - - # Reference implementation + ref_output, ref_unprobs = reference_unpermute_with_mask_map( inp, row_id_map, merging_probs, permuted_probs, num_tokens, num_experts, hidden_size ) - + # Compare results tols = dtype_tols(dtype) # Use relaxed tolerances for unpermute due to accumulation relaxed_tols = dtype_tols(dtype, rtol=tols["rtol"] * 5, atol=tols["atol"] * 5) - + assert_allclose(test_output, ref_output, **relaxed_tols) - + if with_permuted_probs: assert_allclose(test_unprobs, ref_unprobs, **tols) - + # Test round-trip: permute -> unpermute - @pytest.mark.parametrize("num_tokens,num_experts,hidden_size,tokens_per_expert", [ - (32, 8, 256, 2), - (64, 16, 512, 3), - ]) + @pytest.mark.parametrize( + "num_tokens,num_experts,hidden_size,tokens_per_expert", + [ + (32, 8, 256, 2), + (64, 16, 512, 3), + ], + ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) - def test_permute_unpermute_roundtrip(self, num_tokens, num_experts, hidden_size, tokens_per_expert, 
dtype): + def test_permute_unpermute_roundtrip( + self, num_tokens, num_experts, hidden_size, tokens_per_expert, dtype + ): """Test that permute followed by unpermute recovers original input""" key = jax.random.PRNGKey(42) - + # Generate routing map routing_map = self.generate_routing_map(num_tokens, num_experts, tokens_per_expert, key) - + # Generate row_id_map row_id_map = make_row_id_map(routing_map, num_tokens, num_experts) - + # Calculate number of output tokens num_out_tokens = int(jnp.sum(routing_map)) - + # Generate input data key, inp_key = jax.random.split(key) - inp = jax.random.uniform(inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) - + inp = jax.random.uniform( + inp_key, (num_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + # Create uniform merging probs (equal weight for all routed experts) - merging_probs = routing_map.astype(dtype) / jnp.maximum(jnp.sum(routing_map, axis=1, keepdims=True), 1.0) - + merging_probs = routing_map.astype(dtype) / jnp.maximum( + jnp.sum(routing_map, axis=1, keepdims=True), 1.0 + ) + # Permute permuted, _ = permute_with_mask_map( inp, row_id_map, None, num_tokens, num_experts, num_out_tokens, hidden_size ) - + # Unpermute with uniform merging unpermuted, _ = unpermute_with_mask_map( permuted, row_id_map, merging_probs, None, num_tokens, num_experts, hidden_size ) - + # Compare with original input tols = dtype_tols(dtype) relaxed_tols = dtype_tols(dtype, rtol=tols["rtol"] * 10, atol=tols["atol"] * 10) assert_allclose(unpermuted, inp, **relaxed_tols) - - # Test make_chunk_sort_map - @pytest.mark.parametrize("num_splits,total_tokens", [ - (4, 128), - (8, 256), - (16, 512), - ]) + + @pytest.mark.parametrize( + "num_splits,total_tokens", + [ + (4, 128), + (8, 256), + (16, 512), + ], + ) def test_make_chunk_sort_map(self, num_splits, total_tokens): """Test make_chunk_sort_map against reference implementation""" key = jax.random.PRNGKey(42) - + # Generate random split sizes key, size_key = jax.random.split(key) split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) # Adjust last split to match total_tokens split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) - + # Generate sorted indices (permutation of 0..num_splits-1) key, sort_key = jax.random.split(key) sorted_indices = jax.random.permutation(sort_key, num_splits) - - # Test implementation + test_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) - - # Reference implementation - ref_map = reference_make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) - - # Compare results + ref_map = reference_make_chunk_sort_map( + split_sizes, sorted_indices, total_tokens, num_splits + ) + assert_allclose(test_map, ref_map, rtol=0, atol=0) - - # Test sort_chunks_by_map - @pytest.mark.parametrize("num_splits,total_tokens,hidden_size", [ - (4, 128, 256), - (8, 256, 512), - ]) + + @pytest.mark.parametrize( + "num_splits,total_tokens,hidden_size", + [ + (4, 128, 256), + (8, 256, 512), + ], + ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) @pytest.mark.parametrize("is_forward", [True, False]) @pytest.mark.parametrize("with_probs", [True, False]) - def test_sort_chunks_by_map(self, num_splits, total_tokens, hidden_size, dtype, is_forward, with_probs): + def test_sort_chunks_by_map( + self, num_splits, total_tokens, hidden_size, dtype, is_forward, with_probs + ): """Test sort_chunks_by_map against reference implementation""" key = 
jax.random.PRNGKey(42) - + # Generate random split sizes key, size_key = jax.random.split(key) split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) - + # Generate sorted indices key, sort_key = jax.random.split(key) sorted_indices = jax.random.permutation(sort_key, num_splits) - - # Generate row_id_map + row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) - - # Generate input data + key, inp_key, prob_key = jax.random.split(key, 3) - inp = jax.random.uniform(inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) - + inp = jax.random.uniform( + inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + if with_probs: - probs = jax.random.uniform(prob_key, (total_tokens,), dtype=dtype, minval=0.0, maxval=1.0) + probs = jax.random.uniform( + prob_key, (total_tokens,), dtype=dtype, minval=0.0, maxval=1.0 + ) else: probs = None - - # Debug prints - # print("\n" + "="*100) - # print(f"TEST: sort_chunks_by_map (num_splits={num_splits}, total_tokens={total_tokens}, hidden_size={hidden_size})") - # print(f" dtype={dtype}, is_forward={is_forward}, with_probs={with_probs}") - # print("="*100) - - # print("\n📊 TEST INPUTS:") - # print("-"*100) - # print(f"split_sizes: {split_sizes}") - # print(f"sorted_indices: {sorted_indices}") - # print(f"row_id_map (first 10): {row_id_map[:10]}") - # print(f"inp.shape: {inp.shape}, inp.dtype: {inp.dtype}") - # print(f"inp[0, :5]: {inp[0, :5]}") # First 5 elements of first row - # if probs is not None: - # print(f"probs.shape: {probs.shape}, probs.dtype: {probs.dtype}") - # print(f"probs[:5]: {probs[:5]}") - - # Test implementation + test_output, test_probs = sort_chunks_by_map( inp, row_id_map, probs, total_tokens, hidden_size, is_forward ) - - # Reference implementation + ref_output, ref_probs = reference_sort_chunks_by_map( inp, row_id_map, probs, total_tokens, hidden_size, is_forward ) - - # print("\n📋 OUTPUT COMPARISON:") - # print("-"*100) - # print(f"\n🔴 ACTUAL (Triton implementation):") - # print(f"test_output.shape: {test_output.shape}, test_output.dtype: {test_output.dtype}") - # print(f"test_output[0, :5]: {test_output[0, :5]}") # First 5 elements of first row - # print(f"test_output has NaN: {jnp.any(jnp.isnan(test_output))}") - # print(f"test_output has Inf: {jnp.any(jnp.isinf(test_output))}") - # if test_probs is not None: - # print(f"test_probs[:5]: {test_probs[:5]}") - - # print(f"\n🟢 EXPECTED (Reference implementation):") - # print(f"ref_output.shape: {ref_output.shape}, ref_output.dtype: {ref_output.dtype}") - # print(f"ref_output[0, :5]: {ref_output[0, :5]}") # First 5 elements of first row - # if ref_probs is not None: - # print(f"ref_probs[:5]: {ref_probs[:5]}") - - # print("\n🔍 DIFFERENCE (Actual - Expected):") - # if not jnp.any(jnp.isnan(test_output)): - # diff = test_output - ref_output - # print(f"Max absolute difference: {jnp.max(jnp.abs(diff))}") - # print(f"Mean absolute difference: {jnp.mean(jnp.abs(diff))}") - # else: - # print("Cannot compute difference - test_output contains NaN values") - - # print("\n" + "="*100 + "\n") - - # Compare results + tols = dtype_tols(dtype) assert_allclose(test_output, ref_output, **tols) - + if with_probs: assert_allclose(test_probs, ref_probs, **tols) - - # Test chunk sort round-trip - @pytest.mark.parametrize("num_splits,total_tokens,hidden_size", [ - (4, 128, 256), - (8, 256, 512), - ]) + + 
@pytest.mark.parametrize( + "num_splits,total_tokens,hidden_size", + [ + (4, 128, 256), + (8, 256, 512), + ], + ) @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16]) def test_chunk_sort_roundtrip(self, num_splits, total_tokens, hidden_size, dtype): """Test that forward sort followed by backward sort recovers original input""" key = jax.random.PRNGKey(42) - + # Generate random split sizes key, size_key = jax.random.split(key) split_sizes = jax.random.randint(size_key, (num_splits,), 10, total_tokens // num_splits) split_sizes = split_sizes.at[-1].set(total_tokens - jnp.sum(split_sizes[:-1])) - + # Generate sorted indices key, sort_key = jax.random.split(key) sorted_indices = jax.random.permutation(sort_key, num_splits) - + # Generate row_id_map row_id_map = make_chunk_sort_map(split_sizes, sorted_indices, total_tokens, num_splits) - + # Generate input data key, inp_key = jax.random.split(key) - inp = jax.random.uniform(inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0) - + inp = jax.random.uniform( + inp_key, (total_tokens, hidden_size), dtype=dtype, minval=-1.0, maxval=1.0 + ) + # Forward sort sorted_output, _ = sort_chunks_by_map( inp, row_id_map, None, total_tokens, hidden_size, is_forward=True ) - + # Backward sort (should recover original) recovered, _ = sort_chunks_by_map( sorted_output, row_id_map, None, total_tokens, hidden_size, is_forward=False ) - + # Compare with original input tols = dtype_tols(dtype) - assert_allclose(recovered, inp, **tols) \ No newline at end of file + assert_allclose(recovered, inp, **tols) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 4b06b9a7fe..e8c43f52d2 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -603,4 +603,4 @@ def _sort_chunks_by_map_kernel( key=["hidden_size"], )(_sort_chunks_by_map_kernel) except RuntimeError: - pass \ No newline at end of file + pass diff --git a/transformer_engine/jax/triton/permutation.py b/transformer_engine/jax/triton/permutation.py index 5a1c996f4f..27d17f24a2 100644 --- a/transformer_engine/jax/triton/permutation.py +++ b/transformer_engine/jax/triton/permutation.py @@ -30,7 +30,7 @@ def make_row_id_map( ) -> jnp.ndarray: """ Prepare the row_id_map for the permutation using JAX-Triton. - + Parameters ---------- routing_map : jnp.ndarray @@ -41,7 +41,7 @@ def make_row_id_map( Number of tokens in the input tensor. num_experts : int Number of experts in the input tensor. 
- + Returns ------- row_id_map : jnp.ndarray @@ -55,7 +55,7 @@ def make_row_id_map( block_size = 1024 grid = (num_experts, triton.cdiv(num_tokens, block_size)) workspace_tensor = jnp.zeros(grid, dtype=jnp.int32) - + # supposing num_tokens == 5, num_experts == 3, block_size == 3 # and we have a routing_map like this: # [[1, 1, 0], @@ -73,51 +73,52 @@ def make_row_id_map( # [1, 1, 0, r, r, r, r], # [0, 0, 0, r, r, r, r]] # Note: "r" = -1 in the triton common kernel implementation - + # Compute strides manually (JAX arrays don't have .strides attribute) - # For routing_map of shape [num_tokens, num_experts], C-contiguous: - routing_stride_token = num_experts # Move to next token - routing_stride_expert = 1 # Move to next expert (contiguous) - # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: - row_id_stride_token = num_experts * 2 + 1 # Move to next token - row_id_stride_expert = 1 # Move to next column (contiguous) - + # For routing_map of shape [num_tokens, num_experts] + routing_stride_token = num_experts + routing_stride_expert = 1 + # For row_id_map of shape [num_tokens, num_experts * 2 + 1] + row_id_stride_token = num_experts * 2 + 1 + row_id_stride_expert = 1 # Move to next column (contiguous) + # Pass 1: Block cumsum row_id_map_pass1, workspace_tensor = jt.triton_call( - routing_map, # Input 0 (ptr): routing_map_ptr - num_tokens, # Scalar: num_tokens - routing_stride_token, # Scalar: stride_routing_map_token - routing_stride_expert, # Scalar: stride_routing_map_expert - row_id_stride_token, # Scalar: stride_row_id_map_token - row_id_stride_expert, # Scalar: stride_row_id_map_expert + routing_map, + num_tokens, + routing_stride_token, + routing_stride_expert, + row_id_stride_token, + row_id_stride_expert, kernel=_row_id_map_pass_1_kernel, out_shape=[ ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), ], grid=grid, - BLOCK_SIZE=block_size, # Constexpr - pass as keyword + BLOCK_SIZE=block_size, ) # Pass 2: cumsum all and process the mask - # Strides remain the same as Pass 1 - # Note: Pass 2 takes the outputs from Pass 1 as inputs row_id_map_pass2, workspace_tensor = jt.triton_call( - row_id_map_pass1, # Input 0 (ptr): row_id_map_ptr (from Pass 1) - workspace_tensor, # Input 1 (ptr): workspace_ptr (from Pass 1) - num_tokens, # Scalar: num_tokens - row_id_stride_token, # Scalar: stride_row_id_map_token - row_id_stride_expert, # Scalar: stride_row_id_map_expert + row_id_map_pass1, + workspace_tensor, + num_tokens, + row_id_stride_token, + row_id_stride_expert, kernel=_row_id_map_pass_2_kernel, out_shape=[ ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), ], - input_output_aliases={0: 0, 1: 1}, # row_id_map input→output, workspace input→output + input_output_aliases={0: 0, 1: 1}, grid=grid, - WORKSPACE_LOAD_WIDTH=triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), # Constexpr - BLOCK_SIZE=block_size, # Constexpr + WORKSPACE_LOAD_WIDTH=triton.next_power_of_2( + num_experts * triton.cdiv(num_tokens, block_size) + ), + BLOCK_SIZE=block_size, ) + # Initialize columns [num_experts:] to -1 since Pass 1/2 only wrote to [0:num_experts] # Reference implementation expects -1 for invalid entries, not garbage row_id_map = row_id_map_pass2.at[:, num_experts:].set(-1) @@ -126,17 +127,17 @@ def make_row_id_map( grid = (num_tokens,) load_size = triton.next_power_of_2(num_experts) row_id_map = jt.triton_call( - row_id_map, 
# Input 0 (ptr): row_id_map_ptr (from Pass 2, with -1 initialized) - row_id_stride_token, # Scalar 1: stride_row_id_map_token - row_id_stride_expert, # Scalar 2: stride_row_id_map_expert + row_id_map, + row_id_stride_token, + row_id_stride_expert, kernel=_row_id_map_pass_3_kernel, out_shape=[ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype)], - input_output_aliases={0: 0}, # row_id_map input→output + input_output_aliases={0: 0}, num_experts=num_experts, grid=grid, - LOAD_SIZE=load_size, # Constexpr + LOAD_SIZE=load_size, )[0] - + return row_id_map @@ -151,7 +152,7 @@ def permute_with_mask_map( ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """ Permute the input tensor based on the row_id_map using JAX-Triton. - + Parameters ---------- inp : jnp.ndarray @@ -168,7 +169,7 @@ def permute_with_mask_map( Number of tokens in the permuted tensor. hidden_size : int Hidden size of the input tensor. - + Returns ------- output : jnp.ndarray @@ -176,116 +177,117 @@ def permute_with_mask_map( permuted_probs : Optional[jnp.ndarray] Permuted probabilities if probs was provided, None otherwise. """ + # Compute strides manually (JAX arrays don't have .strides attribute) - # For inp of shape [num_tokens, hidden_size], C-contiguous: + + # [num_tokens, hidden_size] inp_stride_token = hidden_size inp_stride_hidden = 1 - # For output of shape [num_out_tokens, hidden_size], C-contiguous: + # [num_out_tokens, hidden_size] output_stride_token = hidden_size output_stride_hidden = 1 - # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: + # [num_tokens, num_experts * 2 + 1] row_id_stride_token = num_experts * 2 + 1 row_id_stride_expert = 1 - # For probs: depends on dimensionality + if probs is not None: if probs.ndim > 1: - # Shape [num_tokens, num_experts] + # [num_tokens, num_experts] probs_stride_token = num_experts probs_stride_expert = 1 else: - # Shape [num_tokens] + # [num_tokens] probs_stride_token = 1 probs_stride_expert = 1 else: probs_stride_token = 0 probs_stride_expert = 0 - # For permuted_probs of shape [num_out_tokens], C-contiguous: + + # [num_out_tokens] permuted_probs_stride_token = 1 - - # Grid: one block per token, multiple blocks for hidden dimension + + # one block per token, multiple blocks for hidden dimension def grid_fn(meta): - return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) - + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + if probs is not None: # jax-triton doesn't handle None pointers correctly, create dummy tensors - # Make dummy tensors large enough to not cause out-of-bounds access dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - + output, permuted_probs = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - probs, # Input 2 (ptr): probs_ptr - dummy_scale, # Input 3 (ptr): scale_ptr (dummy, not used) - dummy_permuted_scale, # Input 4 (ptr): permuted_scale_ptr (dummy, not used) - 0, # Scalar 5: scale_hidden_dim (not used) - row_id_stride_token, # Scalar 6: stride_row_id_map_token - row_id_stride_expert, # Scalar 7: stride_row_id_map_expert - inp_stride_token, # Scalar 8: stride_input_token - inp_stride_hidden, # Scalar 9: stride_input_hidden - output_stride_token, # Scalar 10: stride_output_token - output_stride_hidden, # Scalar 11: stride_output_hidden - probs_stride_token, # Scalar 12: stride_probs_token - probs_stride_expert, # Scalar 13: stride_probs_expert - hidden_size, # 
Scalar 14: stride_scale_token (use actual stride) - 1, # Scalar 15: stride_scale_hidden - permuted_probs_stride_token, # Scalar 16: stride_permuted_probs_token - hidden_size, # Scalar 17: stride_permuted_scale_token (use actual stride) - 1, # Scalar 18: stride_permuted_scale_hidden + inp, + row_id_map, + probs, + dummy_scale, + dummy_permuted_scale, + 0, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + probs_stride_expert, + hidden_size, + 1, + permuted_probs_stride_token, + hidden_size, + 1, kernel=_permute_kernel, out_shape=[ - ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), # Positional: output_ptr - ShapeDtypeStruct((num_out_tokens,), probs.dtype), # Positional: permuted_probs_ptr + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_out_tokens,), probs.dtype), ], grid=grid_fn, - num_experts=num_experts, # Keyword constexpr - hidden_size=hidden_size, # Keyword constexpr - PERMUTE_PROBS=True, # Keyword constexpr - PERMUTE_SCALE=False, # Keyword constexpr + num_experts=num_experts, + hidden_size=hidden_size, + PERMUTE_PROBS=True, + PERMUTE_SCALE=False, # BLOCK_SIZE is keyword constexpr from autotune ) else: # jax-triton doesn't handle None pointers correctly, create dummy tensors - # Make dummy tensors large enough to not cause out-of-bounds access dummy_probs = jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - + result = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - dummy_probs, # Input 2 (ptr): probs_ptr (dummy, not used) - dummy_scale, # Input 3 (ptr): scale_ptr (dummy, not used) - dummy_permuted_scale, # Input 4 (ptr): permuted_scale_ptr (dummy, not used) - 0, # Scalar 5: scale_hidden_dim (not used) - row_id_stride_token, # Scalar 6: stride_row_id_map_token - row_id_stride_expert, # Scalar 7: stride_row_id_map_expert - inp_stride_token, # Scalar 8: stride_input_token - inp_stride_hidden, # Scalar 9: stride_input_hidden - output_stride_token, # Scalar 10: stride_output_token - output_stride_hidden, # Scalar 11: stride_output_hidden - probs_stride_token, # Scalar 12: stride_probs_token (use actual) - probs_stride_expert, # Scalar 13: stride_probs_expert (use actual) - hidden_size, # Scalar 14: stride_scale_token (use actual stride) - 1, # Scalar 15: stride_scale_hidden - permuted_probs_stride_token, # Scalar 16: stride_permuted_probs_token (use actual) - hidden_size, # Scalar 17: stride_permuted_scale_token (use actual stride) - 1, # Scalar 18: stride_permuted_scale_hidden + inp, + row_id_map, + dummy_probs, + dummy_scale, + dummy_permuted_scale, + 0, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + probs_stride_expert, + hidden_size, + 1, + permuted_probs_stride_token, + hidden_size, + 1, kernel=_permute_kernel, out_shape=[ - ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), # Positional: output_ptr - ShapeDtypeStruct((num_out_tokens,), inp.dtype), # Positional: permuted_probs_ptr (dummy) + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_out_tokens,), inp.dtype), ], grid=grid_fn, - num_experts=num_experts, # Keyword constexpr - hidden_size=hidden_size, # Keyword constexpr - 
PERMUTE_PROBS=False, # Keyword constexpr - PERMUTE_SCALE=False, # Keyword constexpr + num_experts=num_experts, + hidden_size=hidden_size, + PERMUTE_PROBS=False, + PERMUTE_SCALE=False, # BLOCK_SIZE is keyword constexpr from autotune ) output = result[0] permuted_probs = None - + return output, permuted_probs @@ -300,7 +302,7 @@ def unpermute_with_mask_map( ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """ Unpermute the input tensor based on the row_id_map using JAX-Triton. - + Parameters ---------- inp : jnp.ndarray @@ -318,7 +320,7 @@ def unpermute_with_mask_map( Number of experts in the permuted tensor. hidden_size : int Hidden size of the permuted tensor. - + Returns ------- output : jnp.ndarray @@ -327,102 +329,109 @@ def unpermute_with_mask_map( Unpermuted probabilities if permuted_probs was provided, None otherwise. """ # Compute strides manually (JAX arrays don't have .strides attribute) - # For inp of shape [num_out_tokens, hidden_size], C-contiguous: + # [num_out_tokens, hidden_size], inp_stride_token = hidden_size inp_stride_hidden = 1 - # For output of shape [num_tokens, hidden_size], C-contiguous: + # [num_tokens, hidden_size], output_stride_token = hidden_size output_stride_hidden = 1 - # For row_id_map of shape [num_tokens, num_experts * 2 + 1], C-contiguous: + # [num_tokens, num_experts * 2 + 1], row_id_stride_token = num_experts * 2 + 1 row_id_stride_expert = 1 - # For merging_probs of shape [num_tokens, num_experts] if present: + # [num_tokens, num_experts] if present: if merging_probs is not None: merging_probs_stride_token = num_experts merging_probs_stride_expert = 1 else: merging_probs_stride_token = 0 merging_probs_stride_expert = 0 - # For permuted_probs of shape [num_out_tokens] if present: + # [num_out_tokens] if present: permuted_probs_stride_token = 1 - # For unpermuted_probs of shape [num_tokens, num_experts] (output): + # [num_tokens, num_experts] (output): unpermuted_probs_stride_token = num_experts unpermuted_probs_stride_expert = 1 - - # Grid: one block per token, multiple blocks for hidden dimension + + # One block per token, multiple blocks for hidden dimension def grid_fn(meta): - return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) - + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + if permuted_probs is not None: - # Ensure merging_probs is not None (use dummy if needed) - merging_probs_arg = merging_probs if merging_probs is not None else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) - + merging_probs_arg = ( + merging_probs + if merging_probs is not None + else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) + ) + output, unpermuted_probs = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - merging_probs_arg, # Input 2 (ptr): merging_probs_ptr (real or dummy) - permuted_probs, # Input 3 (ptr): permuted_probs_ptr - row_id_stride_token, # Scalar 4: stride_row_id_map_token - row_id_stride_expert, # Scalar 5: stride_row_id_map_expert - inp_stride_token, # Scalar 6: stride_input_token - inp_stride_hidden, # Scalar 7: stride_input_hidden - output_stride_token, # Scalar 8: stride_output_token - output_stride_hidden, # Scalar 9: stride_output_hidden - merging_probs_stride_token, # Scalar 10: stride_merging_probs_token - merging_probs_stride_expert, # Scalar 11: stride_merging_probs_expert - permuted_probs_stride_token, # Scalar 12: stride_permuted_probs_token - unpermuted_probs_stride_token, # Scalar 13: stride_unpermuted_probs_token - unpermuted_probs_stride_expert, 
# Scalar 14: stride_unpermuted_probs_expert + inp, + row_id_map, + merging_probs_arg, + permuted_probs, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + merging_probs_stride_token, + merging_probs_stride_expert, + permuted_probs_stride_token, + unpermuted_probs_stride_token, + unpermuted_probs_stride_expert, kernel=_unpermute_kernel, out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional: output_ptr - ShapeDtypeStruct((num_tokens, num_experts), permuted_probs.dtype), # Positional: unpermuted_probs_ptr + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens, num_experts), permuted_probs.dtype), ], grid=grid_fn, - num_experts=num_experts, # Keyword constexpr - hidden_size=hidden_size, # Keyword constexpr - PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), # Keyword constexpr - WITH_MERGING_PROBS=merging_probs is not None, # Keyword constexpr - PERMUTE_PROBS=True, # Keyword constexpr + num_experts=num_experts, + hidden_size=hidden_size, + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), + WITH_MERGING_PROBS=merging_probs is not None, + PERMUTE_PROBS=True, # BLOCK_SIZE is keyword constexpr from autotune ) else: # jax-triton doesn't handle None pointers correctly, create dummy tensors if needed - dummy_permuted_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) # Proper size dummy - merging_probs_arg = merging_probs if merging_probs is not None else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) - + dummy_permuted_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) + merging_probs_arg = ( + merging_probs + if merging_probs is not None + else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) + ) + result = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - merging_probs_arg, # Input 2 (ptr): merging_probs_ptr (real or dummy) - dummy_permuted_probs, # Input 3 (ptr): permuted_probs_ptr (dummy, not used) - row_id_stride_token, # Scalar 4: stride_row_id_map_token - row_id_stride_expert, # Scalar 5: stride_row_id_map_expert - inp_stride_token, # Scalar 6: stride_input_token - inp_stride_hidden, # Scalar 7: stride_input_hidden - output_stride_token, # Scalar 8: stride_output_token - output_stride_hidden, # Scalar 9: stride_output_hidden - merging_probs_stride_token, # Scalar 10: stride_merging_probs_token - merging_probs_stride_expert, # Scalar 11: stride_merging_probs_expert - 1, # Scalar 12: stride_permuted_probs_token (dummy stride) - 0, # Scalar 13: stride_unpermuted_probs_token - 0, # Scalar 14: stride_unpermuted_probs_expert + inp, + row_id_map, + merging_probs_arg, + dummy_permuted_probs, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + merging_probs_stride_token, + merging_probs_stride_expert, + 1, + 0, + 0, kernel=_unpermute_kernel, out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional: output_ptr - ShapeDtypeStruct((num_tokens, num_experts), inp.dtype), # Positional: unpermuted_probs_ptr (dummy) + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens, num_experts), inp.dtype), ], grid=grid_fn, - num_experts=num_experts, # Keyword constexpr - hidden_size=hidden_size, # Keyword constexpr - PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), # Keyword constexpr - WITH_MERGING_PROBS=merging_probs is not None, # Keyword constexpr - 
PERMUTE_PROBS=False, # Keyword constexpr + num_experts=num_experts, + hidden_size=hidden_size, + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), + WITH_MERGING_PROBS=merging_probs is not None, + PERMUTE_PROBS=False, # BLOCK_SIZE is keyword constexpr from autotune ) output = result[0] unpermuted_probs = None - + return output, unpermuted_probs @@ -434,7 +443,7 @@ def make_chunk_sort_map( ) -> jnp.ndarray: """ Make a row_id_map for chunk sort using JAX-Triton. - + Parameters ---------- split_sizes : jnp.ndarray @@ -445,24 +454,24 @@ def make_chunk_sort_map( Number of tokens in the input tensor. num_splits : int Number of splits of split_sizes and sorted_indices. - + Returns ------- row_id_map : jnp.ndarray Row ID map for chunk sorting of shape `[num_tokens,]`. """ grid = (num_tokens,) - + row_id_map = jt.triton_call( - split_sizes, # Input 0 (ptr): split_sizes_ptr - sorted_indices, # Input 1 (ptr): sorted_indices_ptr + split_sizes, + sorted_indices, kernel=_make_chunk_sort_map_kernel, out_shape=[ShapeDtypeStruct((num_tokens,), jnp.int32)], grid=grid, - num_splits=num_splits, # Constexpr - IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), # Constexpr + num_splits=num_splits, + IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), )[0] - + return row_id_map @@ -476,7 +485,7 @@ def sort_chunks_by_map( ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: """ Sort chunks with row_id_map using JAX-Triton. - + Parameters ---------- inp : jnp.ndarray @@ -491,7 +500,7 @@ def sort_chunks_by_map( Hidden size of the input tensor. is_forward : bool Whether the sort is for forward or backward. - + Returns ------- output : jnp.ndarray @@ -500,68 +509,66 @@ def sort_chunks_by_map( Sorted probabilities if probs was provided, None otherwise. """ # Compute strides manually (JAX arrays don't have .strides attribute) - # For inp and output of shape [num_tokens, hidden_size], C-contiguous: + # [num_tokens, hidden_size] inp_stride_token = hidden_size inp_stride_hidden = 1 output_stride_token = hidden_size output_stride_hidden = 1 - # For probs and permuted_probs of shape [num_tokens], C-contiguous: + # [num_tokens] probs_stride_token = 1 permuted_probs_stride_token = 1 - - # Grid: one block per token, multiple blocks for hidden dimension + def grid_fn(meta): - return (num_tokens, triton.cdiv(hidden_size, meta['BLOCK_SIZE'])) - + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + if probs is not None: output, permuted_probs = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - probs, # Input 2 (ptr): probs_ptr - inp_stride_token, # Scalar 3: stride_input_token - inp_stride_hidden, # Scalar 4: stride_input_hidden - output_stride_token, # Scalar 5: stride_output_token - output_stride_hidden, # Scalar 6: stride_output_hidden - probs_stride_token, # Scalar 7: stride_probs_token - permuted_probs_stride_token, # Scalar 8: stride_permuted_probs_token + inp, + row_id_map, + probs, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + permuted_probs_stride_token, kernel=_sort_chunks_by_map_kernel, out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Added at end: output_ptr - ShapeDtypeStruct((num_tokens,), probs.dtype), # Added at end: permuted_probs_ptr + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens,), probs.dtype), ], grid=grid_fn, - hidden_size=hidden_size, # Constexpr 9: hidden_size - PERMUTE_PROBS=True, # Constexpr 10: PERMUTE_PROBS - # 
BLOCK_SIZE is Constexpr 11, provided by autotune - FORWARD=is_forward, # Constexpr 12: FORWARD - # output_ptr and permuted_probs_ptr (13-14) are added automatically by jax-triton from out_shape + hidden_size=hidden_size, + PERMUTE_PROBS=True, + # BLOCK_SIZE is provided by autotune + FORWARD=is_forward, ) else: - # Note: jax-triton might not handle None correctly, so create a dummy probs tensor + dummy_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) - + result = jt.triton_call( - inp, # Input 0 (ptr): input_ptr - row_id_map, # Input 1 (ptr): row_id_map_ptr - dummy_probs, # Input 2 (ptr): probs_ptr (dummy, not used by kernel) - inp_stride_token, # Scalar 3: stride_input_token - inp_stride_hidden, # Scalar 4: stride_input_hidden - output_stride_token, # Scalar 5: stride_output_token - output_stride_hidden, # Scalar 6: stride_output_hidden - probs_stride_token, # Scalar 7: stride_probs_token (use actual stride) - permuted_probs_stride_token, # Scalar 8: stride_permuted_probs_token + inp, + row_id_map, + dummy_probs, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + permuted_probs_stride_token, kernel=_sort_chunks_by_map_kernel, out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), # Positional after strides: output_ptr - ShapeDtypeStruct((num_tokens,), inp.dtype), # Positional after strides: permuted_probs_ptr (dummy) + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens,), inp.dtype), ], grid=grid_fn, - hidden_size=hidden_size, # Keyword constexpr - PERMUTE_PROBS=False, # Keyword constexpr - FORWARD=is_forward, # Keyword constexpr + hidden_size=hidden_size, + PERMUTE_PROBS=False, + FORWARD=is_forward, # BLOCK_SIZE is added by autotune as keyword constexpr ) output = result[0] permuted_probs = None - - return output, permuted_probs \ No newline at end of file + + return output, permuted_probs From 0fb64329f0c3dd8047fa20b2c2fb2c82024ef2e8 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Wed, 26 Nov 2025 15:49:58 -0800 Subject: [PATCH 08/10] Set 0 as the size of dummy tensors to reduce memory usage. --- transformer_engine/jax/triton/permutation.py | 352 ++++++++----------- 1 file changed, 140 insertions(+), 212 deletions(-) diff --git a/transformer_engine/jax/triton/permutation.py b/transformer_engine/jax/triton/permutation.py index 27d17f24a2..c104feb461 100644 --- a/transformer_engine/jax/triton/permutation.py +++ b/transformer_engine/jax/triton/permutation.py @@ -51,10 +51,10 @@ def make_row_id_map( The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding to the first n_routed row indices above. 
""" - row_id_map = jnp.full((num_tokens, num_experts * 2 + 1), -1, dtype=jnp.int32) + row_id_map_shape = (num_tokens, num_experts * 2 + 1) block_size = 1024 grid = (num_experts, triton.cdiv(num_tokens, block_size)) - workspace_tensor = jnp.zeros(grid, dtype=jnp.int32) + workspace_tensor_shape = grid # supposing num_tokens == 5, num_experts == 3, block_size == 3 # and we have a routing_map like this: @@ -92,8 +92,8 @@ def make_row_id_map( row_id_stride_expert, kernel=_row_id_map_pass_1_kernel, out_shape=[ - ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), - ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), + ShapeDtypeStruct(row_id_map_shape, jnp.int32), + ShapeDtypeStruct(workspace_tensor_shape, jnp.int32), ], grid=grid, BLOCK_SIZE=block_size, @@ -108,8 +108,8 @@ def make_row_id_map( row_id_stride_expert, kernel=_row_id_map_pass_2_kernel, out_shape=[ - ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype), - ShapeDtypeStruct(workspace_tensor.shape, workspace_tensor.dtype), + ShapeDtypeStruct(row_id_map_shape, jnp.int32), + ShapeDtypeStruct(workspace_tensor_shape, jnp.int32), ], input_output_aliases={0: 0, 1: 1}, grid=grid, @@ -131,12 +131,12 @@ def make_row_id_map( row_id_stride_token, row_id_stride_expert, kernel=_row_id_map_pass_3_kernel, - out_shape=[ShapeDtypeStruct(row_id_map.shape, row_id_map.dtype)], + out_shape=ShapeDtypeStruct(row_id_map_shape, jnp.int32), input_output_aliases={0: 0}, num_experts=num_experts, grid=grid, LOAD_SIZE=load_size, - )[0] + ) return row_id_map @@ -177,9 +177,12 @@ def permute_with_mask_map( permuted_probs : Optional[jnp.ndarray] Permuted probabilities if probs was provided, None otherwise. """ + # one block per token, multiple blocks for hidden dimension + def grid_fn(meta): + return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + with_probs = probs is not None # Compute strides manually (JAX arrays don't have .strides attribute) - # [num_tokens, hidden_size] inp_stride_token = hidden_size inp_stride_hidden = 1 @@ -190,7 +193,10 @@ def permute_with_mask_map( row_id_stride_token = num_experts * 2 + 1 row_id_stride_expert = 1 - if probs is not None: + # [num_out_tokens] + permuted_probs_stride_token = 1 + + if with_probs: if probs.ndim > 1: # [num_tokens, num_experts] probs_stride_token = num_experts @@ -199,93 +205,56 @@ def permute_with_mask_map( # [num_tokens] probs_stride_token = 1 probs_stride_expert = 1 + out_shape = [ + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_out_tokens,), probs.dtype), + ] else: probs_stride_token = 0 probs_stride_expert = 0 + probs = jnp.zeros((0,), dtype=inp.dtype) + out_shape = [ + ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((0,), inp.dtype), + ] - # [num_out_tokens] - permuted_probs_stride_token = 1 + dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) + dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - # one block per token, multiple blocks for hidden dimension - def grid_fn(meta): - return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) + outputs = jt.triton_call( + inp, + row_id_map, + probs, + dummy_scale, # scale + dummy_permuted_scale, # permuted_scale + 0, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + probs_stride_expert, + hidden_size, + 1, + permuted_probs_stride_token, + hidden_size, + 1, + kernel=_permute_kernel, + out_shape=out_shape, + 
grid=grid_fn, + num_experts=num_experts, + hidden_size=hidden_size, + PERMUTE_PROBS=with_probs, + PERMUTE_SCALE=False, + # BLOCK_SIZE is keyword constexpr from autotune + ) - if probs is not None: - # jax-triton doesn't handle None pointers correctly, create dummy tensors - dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - - output, permuted_probs = jt.triton_call( - inp, - row_id_map, - probs, - dummy_scale, - dummy_permuted_scale, - 0, - row_id_stride_token, - row_id_stride_expert, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - probs_stride_token, - probs_stride_expert, - hidden_size, - 1, - permuted_probs_stride_token, - hidden_size, - 1, - kernel=_permute_kernel, - out_shape=[ - ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_out_tokens,), probs.dtype), - ], - grid=grid_fn, - num_experts=num_experts, - hidden_size=hidden_size, - PERMUTE_PROBS=True, - PERMUTE_SCALE=False, - # BLOCK_SIZE is keyword constexpr from autotune - ) + output = outputs[0] + if with_probs: + permuted_probs = outputs[1] else: - # jax-triton doesn't handle None pointers correctly, create dummy tensors - dummy_probs = jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) - dummy_scale = jnp.zeros((num_tokens, hidden_size), dtype=inp.dtype) - dummy_permuted_scale = jnp.zeros((num_out_tokens, hidden_size), dtype=inp.dtype) - - result = jt.triton_call( - inp, - row_id_map, - dummy_probs, - dummy_scale, - dummy_permuted_scale, - 0, - row_id_stride_token, - row_id_stride_expert, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - probs_stride_token, - probs_stride_expert, - hidden_size, - 1, - permuted_probs_stride_token, - hidden_size, - 1, - kernel=_permute_kernel, - out_shape=[ - ShapeDtypeStruct((num_out_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_out_tokens,), inp.dtype), - ], - grid=grid_fn, - num_experts=num_experts, - hidden_size=hidden_size, - PERMUTE_PROBS=False, - PERMUTE_SCALE=False, - # BLOCK_SIZE is keyword constexpr from autotune - ) - output = result[0] permuted_probs = None return output, permuted_probs @@ -328,6 +297,9 @@ def unpermute_with_mask_map( unpermuted_probs : Optional[jnp.ndarray] Unpermuted probabilities if permuted_probs was provided, None otherwise. 
""" + with_merging_probs = merging_probs is not None + with_probs = permuted_probs is not None + # Compute strides manually (JAX arrays don't have .strides attribute) # [num_out_tokens, hidden_size], inp_stride_token = hidden_size @@ -339,7 +311,7 @@ def unpermute_with_mask_map( row_id_stride_token = num_experts * 2 + 1 row_id_stride_expert = 1 # [num_tokens, num_experts] if present: - if merging_probs is not None: + if with_merging_probs: merging_probs_stride_token = num_experts merging_probs_stride_expert = 1 else: @@ -355,81 +327,50 @@ def unpermute_with_mask_map( def grid_fn(meta): return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) - if permuted_probs is not None: - merging_probs_arg = ( - merging_probs - if merging_probs is not None - else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) - ) - - output, unpermuted_probs = jt.triton_call( - inp, - row_id_map, - merging_probs_arg, - permuted_probs, - row_id_stride_token, - row_id_stride_expert, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - merging_probs_stride_token, - merging_probs_stride_expert, - permuted_probs_stride_token, - unpermuted_probs_stride_token, - unpermuted_probs_stride_expert, - kernel=_unpermute_kernel, - out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_tokens, num_experts), permuted_probs.dtype), - ], - grid=grid_fn, - num_experts=num_experts, - hidden_size=hidden_size, - PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), - WITH_MERGING_PROBS=merging_probs is not None, - PERMUTE_PROBS=True, - # BLOCK_SIZE is keyword constexpr from autotune - ) + merging_probs = merging_probs if with_merging_probs else jnp.zeros((0,), dtype=inp.dtype) + permuted_probs = permuted_probs if with_probs else jnp.zeros((0,), dtype=inp.dtype) + + if with_probs: + out_shape = [ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens, num_experts), permuted_probs.dtype), + ] + else: + out_shape = [ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((0,), inp.dtype), + ] + + outputs = jt.triton_call( + inp, + row_id_map, + merging_probs, + permuted_probs, + row_id_stride_token, + row_id_stride_expert, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + merging_probs_stride_token, + merging_probs_stride_expert, + 1, + 0, + 0, + kernel=_unpermute_kernel, + out_shape=out_shape, + grid=grid_fn, + num_experts=num_experts, + hidden_size=hidden_size, + PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), + WITH_MERGING_PROBS=with_merging_probs, + PERMUTE_PROBS=with_probs, + # BLOCK_SIZE is keyword constexpr from autotune + ) + output = outputs[0] + if with_probs: + unpermuted_probs = outputs[1] else: - # jax-triton doesn't handle None pointers correctly, create dummy tensors if needed - dummy_permuted_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) - merging_probs_arg = ( - merging_probs - if merging_probs is not None - else jnp.zeros((num_tokens, num_experts), dtype=inp.dtype) - ) - - result = jt.triton_call( - inp, - row_id_map, - merging_probs_arg, - dummy_permuted_probs, - row_id_stride_token, - row_id_stride_expert, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - merging_probs_stride_token, - merging_probs_stride_expert, - 1, - 0, - 0, - kernel=_unpermute_kernel, - out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_tokens, num_experts), inp.dtype), - ], - 
grid=grid_fn, - num_experts=num_experts, - hidden_size=hidden_size, - PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts), - WITH_MERGING_PROBS=merging_probs is not None, - PERMUTE_PROBS=False, - # BLOCK_SIZE is keyword constexpr from autotune - ) - output = result[0] unpermuted_probs = None return output, unpermuted_probs @@ -521,54 +462,41 @@ def sort_chunks_by_map( def grid_fn(meta): return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) - if probs is not None: - output, permuted_probs = jt.triton_call( - inp, - row_id_map, - probs, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - probs_stride_token, - permuted_probs_stride_token, - kernel=_sort_chunks_by_map_kernel, - out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_tokens,), probs.dtype), - ], - grid=grid_fn, - hidden_size=hidden_size, - PERMUTE_PROBS=True, - # BLOCK_SIZE is provided by autotune - FORWARD=is_forward, - ) + with_probs = probs is not None + if with_probs: + out_shape = [ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((num_tokens,), probs.dtype), + ] + else: + out_shape = [ + ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), + ShapeDtypeStruct((0,), inp.dtype), + ] + probs = jnp.zeros((0,), dtype=inp.dtype) + + outputs = jt.triton_call( + inp, + row_id_map, + probs, + inp_stride_token, + inp_stride_hidden, + output_stride_token, + output_stride_hidden, + probs_stride_token, + permuted_probs_stride_token, + kernel=_sort_chunks_by_map_kernel, + out_shape=out_shape, + grid=grid_fn, + hidden_size=hidden_size, + PERMUTE_PROBS=with_probs, + # BLOCK_SIZE is provided by autotune + FORWARD=is_forward, + ) + output = outputs[0] + if with_probs: + permuted_probs = outputs[1] else: - - dummy_probs = jnp.zeros((num_tokens,), dtype=inp.dtype) - - result = jt.triton_call( - inp, - row_id_map, - dummy_probs, - inp_stride_token, - inp_stride_hidden, - output_stride_token, - output_stride_hidden, - probs_stride_token, - permuted_probs_stride_token, - kernel=_sort_chunks_by_map_kernel, - out_shape=[ - ShapeDtypeStruct((num_tokens, hidden_size), inp.dtype), - ShapeDtypeStruct((num_tokens,), inp.dtype), - ], - grid=grid_fn, - hidden_size=hidden_size, - PERMUTE_PROBS=False, - FORWARD=is_forward, - # BLOCK_SIZE is added by autotune as keyword constexpr - ) - output = result[0] permuted_probs = None return output, permuted_probs From 010fd8fec7c25fa47472451b3a78551ba6d46478 Mon Sep 17 00:00:00 2001 From: Ming Huang Date: Fri, 28 Nov 2025 16:45:45 -0800 Subject: [PATCH 09/10] Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation --- transformer_engine/jax/triton/permutation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/triton/permutation.py b/transformer_engine/jax/triton/permutation.py index c104feb461..b10f1a6fb5 100644 --- a/transformer_engine/jax/triton/permutation.py +++ b/transformer_engine/jax/triton/permutation.py @@ -354,9 +354,9 @@ def grid_fn(meta): output_stride_hidden, merging_probs_stride_token, merging_probs_stride_expert, - 1, - 0, - 0, + permuted_probs_stride_token, + unpermuted_probs_stride_token, + unpermuted_probs_stride_expert, kernel=_unpermute_kernel, out_shape=out_shape, grid=grid_fn, From 4346a0e126e239f707391434d2a52b9cb999257a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Dec 2025 21:17:01 
+0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/triton/permutation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/triton/permutation.py b/transformer_engine/jax/triton/permutation.py index b10f1a6fb5..9f112bd847 100644 --- a/transformer_engine/jax/triton/permutation.py +++ b/transformer_engine/jax/triton/permutation.py @@ -177,6 +177,7 @@ def permute_with_mask_map( permuted_probs : Optional[jnp.ndarray] Permuted probabilities if probs was provided, None otherwise. """ + # one block per token, multiple blocks for hidden dimension def grid_fn(meta): return (num_tokens, triton.cdiv(hidden_size, meta["BLOCK_SIZE"])) @@ -225,8 +226,8 @@ def grid_fn(meta): inp, row_id_map, probs, - dummy_scale, # scale - dummy_permuted_scale, # permuted_scale + dummy_scale, # scale + dummy_permuted_scale, # permuted_scale 0, row_id_stride_token, row_id_stride_expert,
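
For readers following the series, below is a minimal, self-contained sketch of the jt.triton_call conventions these patches rely on: positional array inputs come first, plain integer scalars (the manually computed strides) follow, the arrays declared in out_shape are appended by jax-triton as trailing output pointers, and the remaining keyword arguments become Triton constexprs (in the patches BLOCK_SIZE comes from the autotuner; the sketch passes it explicitly). It also illustrates the zero-size dummy-tensor trick from patch 08 for optional pointers that a constexpr flag guards. The kernel, the double() helper, and every name in it are illustrative assumptions, not part of TransformerEngine.

# Illustrative sketch only (not part of this series): a toy kernel showing the
# jax-triton calling convention used above.
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl


@triton.jit
def _double_kernel(x_ptr, bias_ptr, n_elements, out_ptr,
                   HAS_BIAS: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    if HAS_BIAS:
        # Only compiled in when HAS_BIAS is True, so a zero-size dummy
        # pointer is never dereferenced in the False case.
        x = x + tl.load(bias_ptr + offsets, mask=mask, other=0.0)
    tl.store(out_ptr + offsets, x + x, mask=mask)


def double(x: jnp.ndarray, bias=None, block_size: int = 256) -> jnp.ndarray:
    has_bias = bias is not None
    # Zero-size dummy stands in for the optional pointer, as in patch 08,
    # so no real buffer is allocated when the feature is disabled.
    bias_arg = bias if has_bias else jnp.zeros((0,), dtype=x.dtype)
    n_elements = x.size
    return jt.triton_call(
        x,                    # positional input pointer 0
        bias_arg,             # positional input pointer 1 (real or zero-size dummy)
        n_elements,           # scalar argument
        kernel=_double_kernel,
        # Arrays in out_shape become trailing output pointers (out_ptr here).
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(triton.cdiv(n_elements, block_size),),
        HAS_BIAS=has_bias,    # keyword constexprs
        BLOCK_SIZE=block_size,
    )

Usage follows the same shape as the permutation wrappers: calling double(x) with no bias passes the zero-size dummy and compiles the kernel with HAS_BIAS=False, while double(x, bias) reuses the same call site with the real pointer.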