diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py
index db54015..dd92a58 100644
--- a/cuequivariance/cuequivariance/__init__.py
+++ b/cuequivariance/cuequivariance/__init__.py
@@ -57,6 +57,7 @@ from cuequivariance import group_theory as group_theory
 from cuequivariance.group_theory import descriptors as descriptors
 
+
 __all__ = [
     "__version__",
     "Operation",
diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/_layer_norm_transpose.py b/cuequivariance_jax/cuequivariance_jax/triangle/_layer_norm_transpose.py
index bc66474..73c56a1 100644
--- a/cuequivariance_jax/cuequivariance_jax/triangle/_layer_norm_transpose.py
+++ b/cuequivariance_jax/cuequivariance_jax/triangle/_layer_norm_transpose.py
@@ -23,12 +23,13 @@ from cuequivariance_jax.triangle._naive_batching import naive_batching_rule
 
 try:
-    import jax_triton as jt
     import triton
 
-    HAS_JAX_TRITON = True
+    from .triton_utils import triton_call
+
+    HAS_TRITON = True
 except ImportError:
-    HAS_JAX_TRITON = False
+    HAS_TRITON = False
 
 
 # copy from cuequivariance_ops to avoid requiring cuequivariance_ops to be installed
@@ -228,8 +229,8 @@ def layer_norm_transpose_reference_forward(x, w, b, eps, elementwise_affine, lay
 
 def _layer_norm_forward_impl(x, w, b, eps, elementwise_affine, layout):
     """Triton implementation of forward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import layer_norm_transpose_forward_kernel
@@ -241,7 +242,7 @@ def _layer_norm_forward_impl(x, w, b, eps, elementwise_affine, layout):
 
     NEEDS_INT64 = B * N * D >= 2**31 - 1
 
-    out, mean, rstd = jt.triton_call(
+    out, mean, rstd = triton_call(
         x,
         w,
         b,
@@ -271,8 +272,8 @@ def _layer_norm_backward_impl(
     grad_out, x, w, b, mean, rstd, eps, elementwise_affine, layout
 ):
     """Triton implementation of backward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import layer_norm_transpose_backward_kernel
@@ -286,7 +287,7 @@ def _layer_norm_backward_impl(
 
     NEEDS_INT64 = B * N * D >= 2**31 - 1
 
-    grad_x, grad_w_tiles, grad_b_tiles = jt.triton_call(
+    grad_x, grad_w_tiles, grad_b_tiles = triton_call(
         grad_out,
         x,
         w,
diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/_sigmoid_gated_dual_gemm.py b/cuequivariance_jax/cuequivariance_jax/triangle/_sigmoid_gated_dual_gemm.py
index 6c8a4d3..a0e5a1f 100644
--- a/cuequivariance_jax/cuequivariance_jax/triangle/_sigmoid_gated_dual_gemm.py
+++ b/cuequivariance_jax/cuequivariance_jax/triangle/_sigmoid_gated_dual_gemm.py
@@ -27,12 +27,13 @@ from ._utils import Precision
 
 try:
-    import jax_triton as jt
     import triton
 
-    HAS_JAX_TRITON = True
+    from .triton_utils import triton_call
+
+    HAS_TRITON = True
 except ImportError:
-    HAS_JAX_TRITON = False
+    HAS_TRITON = False
 
 
 # Unified JAX primitives
@@ -239,8 +240,8 @@ def fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper(
     num_warps: int = 4,
 ):
     """Triton implementation of forward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import fused_sigmoid_gated_dual_gemm_forward_kernel
@@ -263,7 +264,7 @@ def fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper(
     out_shape = (N, M) if transpose_out else (M, N)
     dummy = jnp.zeros((), dtype=dtype)
 
-    return jt.triton_call(
+    return triton_call(
         x1,
         x2 if two_inputs else dummy,
         w1,
@@ -271,12 +272,12 @@ def fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper(
         b1 if has_b1 else dummy,
         b2 if has_b2 else dummy,
         mask if has_mask else dummy,
-        M,
-        N,
-        K,
         kernel=fused_sigmoid_gated_dual_gemm_forward_kernel,
         out_shape=[jax.ShapeDtypeStruct(shape=out_shape, dtype=x1.dtype)],
         grid=(triton.cdiv(M, TILE_M), triton.cdiv(N, TILE_N), 1),
+        M=M,
+        N=N,
+        K=K,
         TILE_M=TILE_M,
         TILE_N=TILE_N,
         TILE_K=TILE_K,
@@ -314,8 +315,8 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
     num_warps: int = 4,
 ):
     """Triton implementation of backward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import (
         fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
@@ -346,7 +347,7 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
     ]
     dummy = jnp.zeros((), dtype=dtype)
 
-    grad_xw1, grad_xw2, grad_mask = jt.triton_call(
+    grad_xw1, grad_xw2, grad_mask = triton_call(
         grad_out,
         x1,
         x2 if two_inputs else dummy,
@@ -355,12 +356,12 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
         b1 if has_b1 else dummy,
         b2 if has_b2 else dummy,
         mask if has_mask else dummy,
-        M,
-        N,
-        K,
         kernel=fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
         out_shape=out_shapes,
         grid=(triton.cdiv(M, TILE_M), triton.cdiv(N, TILE_N), 1),
+        M=M,
+        N=N,
+        K=K,
         TILE_M=TILE_M,
         TILE_N=TILE_N,
         TILE_K=TILE_K,
diff --git a/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py
new file mode 100644
index 0000000..8deb9bf
--- /dev/null
+++ b/cuequivariance_jax/cuequivariance_jax/triangle/triton_utils.py
@@ -0,0 +1,359 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Minimal Triton utilities for JAX primitives.
+
+This module provides a lightweight helper for integrating Triton kernels into JAX.
+
+This implementation was inspired by NVIDIA TransformerEngine
+(https://github.com/NVIDIA/TransformerEngine).
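+
+Example: a minimal usage sketch. The kernel below is hypothetical and only
+illustrates the calling convention; any ``triton.jit`` kernel whose array
+arguments (inputs followed by outputs) are passed positionally and whose
+compile-time parameters are passed as keyword arguments follows the same
+pattern::
+
+    import jax
+    import jax.numpy as jnp
+    import triton
+    import triton.language as tl
+
+    @triton.jit
+    def _double_kernel(x_ptr, out_ptr, N: tl.constexpr, TILE: tl.constexpr):
+        # Hypothetical kernel used only for illustration.
+        pid = tl.program_id(0)
+        offs = pid * TILE + tl.arange(0, TILE)
+        mask = offs < N
+        tl.store(out_ptr + offs, 2.0 * tl.load(x_ptr + offs, mask=mask), mask=mask)
+
+    x = jnp.arange(1024, dtype=jnp.float32)
+    out = triton_call(
+        x,
+        kernel=_double_kernel,
+        out_shape=[jax.ShapeDtypeStruct(x.shape, x.dtype)],
+        grid=(triton.cdiv(1024, 128), 1, 1),
+        N=1024,
+        TILE=128,
+    )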
+""" + +from __future__ import annotations + +import hashlib +import os +import re +import subprocess +import zlib +from typing import Any, Callable, Mapping, Sequence + +import jax +import jax.extend.core as jex_core +import jax.numpy as jnp +from jax import core +from jax._src.lib import gpu_triton +from jax.interpreters import mlir +from packaging import version +from triton.backends.nvidia import compiler as cb +from triton.compiler import compiler as tc +from triton.compiler.errors import CompilationError +from triton.runtime import cache as triton_cache + +__all__ = ["triton_call_lowering", "triton_call"] + +try: + import triton + + TRITON_VERSION = version.parse(triton.__version__) +except (ImportError, AttributeError): + TRITON_VERSION = None + +# Configure Triton cache directory +default_cache_dir = os.path.join(os.path.expanduser("~"), ".triton", "cache") +try: + cache_dir = ( + triton_cache.knobs.cache.dir + if hasattr(triton_cache, "knobs") + else triton_cache.default_cache_dir() + ) + if not cache_dir: + cache_dir = default_cache_dir + if hasattr(triton_cache, "knobs"): + triton_cache.knobs.cache.dir = cache_dir +except Exception: + cache_dir = default_cache_dir +os.makedirs(cache_dir, exist_ok=True) + +_TRITON_KERNEL_CACHE = {} + +_DTYPE_MAP = { + jnp.dtype("bfloat16"): "bf16", + jnp.dtype("float64"): "fp64", + jnp.dtype("float32"): "fp32", + jnp.dtype("float16"): "fp16", + jnp.dtype("int64"): "i64", + jnp.dtype("int32"): "i32", + jnp.dtype("int16"): "i16", + jnp.dtype("int8"): "i8", + jnp.dtype("uint64"): "u64", + jnp.dtype("uint32"): "u32", + jnp.dtype("uint16"): "u16", + jnp.dtype("uint8"): "u8", + jnp.dtype("bool"): "i1", +} + +_PTXAS_VERSION_CACHE = None + + +def _get_max_ptx_version(): + """Detects the maximum PTX version supported by the available ptxas.""" + global _PTXAS_VERSION_CACHE + if _PTXAS_VERSION_CACHE is not None: + return _PTXAS_VERSION_CACHE + + try: + # Check ptxas version + result = subprocess.run( + ["ptxas", "--version"], capture_output=True, text=True, check=False + ) + if result.returncode == 0: + # Example output: "Cuda compilation tools, release 12.6, V12.6.68" + match = re.search(r"release (\d+)\.(\d+)", result.stdout) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + # Map CUDA version to PTX version + if major == 12: + if minor >= 8: + version = 87 + elif minor >= 5: + version = 85 + else: + version = 80 + minor + elif major == 11: + if minor >= 8: + version = 78 + else: + version = 70 + minor + else: + version = None + _PTXAS_VERSION_CACHE = version + except Exception: + pass + + return _PTXAS_VERSION_CACHE + + +def _get_triton_dtype(aval: core.ShapedArray) -> str: + """Convert JAX dtype to Triton type string.""" + return f"*{_DTYPE_MAP[aval.dtype]}" + + +def _compile_triton( + kernel_fn: Callable, + signature: Mapping[str, str], + constants: Mapping[str, Any], + num_warps: int, + num_stages: int, + compute_capability: int, + enable_fp_fusion: bool = False, +): + """Compile a Triton kernel to PTX with caching.""" + # Include source code in cache key to handle edits + cache_key = hashlib.md5( + str( + ( + kernel_fn.__name__, + getattr(kernel_fn, "src", ""), + tuple(sorted(signature.items())), + tuple(sorted(constants.items())), + num_warps, + num_stages, + compute_capability, + enable_fp_fusion, + ) + ).encode() + ).hexdigest() + + if cache_key in _TRITON_KERNEL_CACHE: + return _TRITON_KERNEL_CACHE[cache_key] + + # Detect maximum supported PTX version + max_ptx_version = _get_max_ptx_version() + + cuda_options_kwargs = { + 
"num_warps": num_warps, + "num_stages": num_stages, + "num_ctas": 1, + "cluster_dims": (1, 1, 1), + "debug": False, + "enable_fp_fusion": enable_fp_fusion, + } + + # Try adding ptx_version if detected and supported by Triton + if max_ptx_version is not None: + try: + # Check if CUDAOptions accepts ptx_version + cb.CUDAOptions(**cuda_options_kwargs, ptx_version=max_ptx_version) + except TypeError: + pass + else: + cuda_options_kwargs["ptx_version"] = max_ptx_version + + options = cb.CUDAOptions(**cuda_options_kwargs) + + # Triton 3.3.x is known to be incompatible due to constexpr handling bugs + if TRITON_VERSION is not None and ( + TRITON_VERSION.major == 3 and TRITON_VERSION.minor == 3 + ): + raise ImportError( + f"Triton version {TRITON_VERSION} is not supported due to known issues. " + "Please upgrade to Triton 3.4+ or downgrade to Triton 3.2.x." + ) + + # Helper to try compilation with specific arguments + def try_compile(signature_dict, **kwargs): + try: + src = tc.ASTSource(fn=kernel_fn, signature=signature_dict, **kwargs) + return tc.compile( + src, + target=tc.GPUTarget("cuda", compute_capability, 32), + options=options.__dict__, + ) + except (TypeError, AttributeError, CompilationError): + return None + + # 1. Try Triton 3.4.0+ API: constexprs should not be in signature + compiled = try_compile(signature, constexprs=constants) + + signature_with_constexpr = {**signature, **{k: "constexpr" for k in constants}} + + # 2. Try Triton 3.1.0-3.2.0: constexprs should be in signature as "constexpr" + if compiled is None: + compiled = try_compile(signature_with_constexpr, constexprs=constants) + + # 3. Try Triton 3.0.0: uses 'constants' instead of 'constexprs' + if compiled is None: + compiled = try_compile(signature_with_constexpr, constants=constants) + + if compiled is None: + raise RuntimeError("Failed to compile Triton kernel with any API version") + + # fmt: off + args = ( + (compiled.name, num_warps, 1, compiled.metadata.shared, compiled.asm["ptx"], "", compute_capability) + if version.parse(jax.__version__) >= version.parse("0.8.2") else + (compiled.name, num_warps, compiled.metadata.shared, compiled.asm["ptx"], "", compute_capability, 1, 1, 1) + ) + # fmt: on + kernel = gpu_triton.TritonKernel(*args) + + _TRITON_KERNEL_CACHE[cache_key] = kernel + return kernel + + +def triton_call_lowering( + ctx, + kernel_fn: Callable, + *array_args, + grid, + num_warps: int = 4, + num_stages: int = 3, + input_output_aliases: Mapping[int, int] | None = None, + constexprs: Mapping[str, Any] | None = None, + enable_fp_fusion: bool = False, +): + """Helper for MLIR lowering that calls a Triton kernel.""" + compute_capability = gpu_triton.get_compute_capability(0) + + all_avals = list(ctx.avals_in) + list(ctx.avals_out) + constexprs = constexprs or {} + tensor_arg_names = [n for n in kernel_fn.arg_names if n not in constexprs] + + signature = {n: _get_triton_dtype(a) for n, a in zip(tensor_arg_names, all_avals)} + + # Normalize grid to tuple of 3 ints + g = tuple(grid) if isinstance(grid, (tuple, list)) else (grid,) + grid_tuple = g[:3] + (1,) * (3 - len(g[:3])) + + kernel = _compile_triton( + kernel_fn, + signature, + constexprs, + num_warps, + num_stages, + compute_capability, + enable_fp_fusion, + ) + + kernel_params = [gpu_triton.create_array_parameter(0, 16) for _ in all_avals] + + # WARNING: Must use explicit indexing (grid_tuple[0], grid_tuple[1], grid_tuple[2]) + # instead of unpacking (*grid_tuple). Unpacking causes UnexpectedTracerError. 
+    call_proto = gpu_triton.TritonKernelCall(
+        kernel, grid_tuple[0], grid_tuple[1], grid_tuple[2], kernel_params
+    ).to_proto(kernel_fn.__name__, b"")
+
+    return jax.ffi.ffi_lowering(
+        "triton_kernel_call",
+        api_version=2,
+        backend_config=zlib.compress(call_proto),
+        operand_output_aliases=input_output_aliases or {},
+    )(ctx, *array_args)
+
+
+# Define global Triton kernel call primitive
+_triton_kernel_call_p = jex_core.Primitive("triton_kernel_call")
+_triton_kernel_call_p.multiple_results = True
+
+
+def _triton_abstract_eval(*avals, out_shape, **unused_kwargs):
+    """Abstract evaluation for Triton kernel call."""
+    return tuple(core.ShapedArray(s.shape, s.dtype) for s in out_shape)
+
+
+def _triton_lowering_rule(
+    ctx,
+    *mlir_args,
+    kernel,
+    grid,
+    num_warps,
+    num_stages,
+    constexprs,
+    out_shape,
+    enable_fp_fusion,
+):
+    """Lowering rule for Triton kernel call."""
+    return triton_call_lowering(
+        ctx,
+        kernel,
+        *mlir_args,
+        grid=grid,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        constexprs=dict(constexprs),
+        enable_fp_fusion=enable_fp_fusion,
+    )
+
+
+# Register primitive
+_triton_kernel_call_p.def_abstract_eval(_triton_abstract_eval)
+mlir.register_lowering(_triton_kernel_call_p, _triton_lowering_rule, platform="gpu")
+
+
+def triton_call(
+    *args,
+    kernel: Callable,
+    out_shape: Sequence[jax.ShapeDtypeStruct],
+    grid,
+    num_warps: int = 4,
+    num_stages: int = 3,
+    enable_fp_fusion: bool = False,
+    **kwargs,
+):
+    """High-level API to call a Triton kernel from JAX."""
+    # WARNING: Do NOT inline these variables into the bind() call below.
+    # Inlining causes UnexpectedTracerError because JAX tracers leak into
+    # primitive parameters. These intermediate variables must be created
+    # separately to avoid tracer leaks during JAX transformations.
+    constexprs_tuple = tuple(sorted(kwargs.items()))
+    out_shape_tuple = tuple(out_shape)
+    grid_tuple = tuple(grid) if isinstance(grid, (list, tuple)) else (grid,)
+
+    result = _triton_kernel_call_p.bind(
+        *args,
+        kernel=kernel,
+        grid=grid_tuple,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        constexprs=constexprs_tuple,
+        out_shape=out_shape_tuple,
+        enable_fp_fusion=enable_fp_fusion,
+    )
+
+    num_outputs = len(out_shape)
+    if num_outputs == 1:
+        return result[0] if isinstance(result, tuple) else result
+    return result
diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py
index 0af5e98..4899394 100644
--- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py
+++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py
@@ -70,7 +70,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Args:
             x (torch.Tensor): The input tensor.
             use_fallback (bool, optional): If `None` (default), a CUDA kernel will be used if available.
-                If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
+                If `False`, a CUDA kernel will be attempted and the PyTorch fallback will be used if it is not available.
                 If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
 
         Returns:
@@ -93,21 +93,15 @@ def __init__(
 
         self.f = None
         if info is not None:
-            import_error = None
             if use_fallback is False or use_fallback is None:
                 try:
                     import cuequivariance_ops_torch  # noqa: F401
-                except ImportError as e:
-                    import_error = e
+                except ImportError:
+                    pass
                 else:
                     if torch.cuda.is_available():
                         self.f = _transpose(info).to(device=device)
 
-            if use_fallback is False and self.f is None:
-                raise RuntimeError(
-                    f"CUDA kernel not available for TransposeSegments: {import_error}"
-                )
-
             if self.f is None:
                 self.f = _transpose_segments_fx(segments).to(device=device)
         else:
@@ -126,18 +120,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             The input tensor to be transposed.
         use_fallback : Optional[bool], optional
            If `None` (default), a CUDA kernel will be used if available.
-            If `False`, a CUDA kernel will be used, and an exception is raised if it's not available.
+            If `False`, a CUDA kernel will be attempted and the PyTorch fallback will be used if it is not available.
            If `True`, a PyTorch fallback method is used regardless of CUDA kernel availability.
 
         Returns
         -------
         torch.Tensor
             The transposed tensor.
-
-        Raises
-        ------
-        RuntimeError
-            If `use_fallback` is `False` and a CUDA kernel is not available or the input is not on CUDA.
         """
         return self.f(x)
diff --git a/docs/noxfile.py b/docs/noxfile.py
index 8349d28..1204ee3 100644
--- a/docs/noxfile.py
+++ b/docs/noxfile.py
@@ -16,9 +16,6 @@ def docs(session: nox.Session):
     session.run("python", "-m", "uv", "pip", "install", "../cuequivariance_jax")
     session.run("python", "-m", "uv", "pip", "install", "../cuequivariance_torch")
 
-    # Explicitly install jax-triton (workaround for dependency resolution issue)
-    session.run("python", "-m", "pip", "install", "jax-triton")
-
     # Install docs requirements - now in same directory
     session.run("python", "-m", "uv", "pip", "install", "-r", "requirements.txt")