1 change: 1 addition & 0 deletions cuequivariance/cuequivariance/__init__.py
@@ -57,6 +57,7 @@
 from cuequivariance import group_theory as group_theory
 from cuequivariance.group_theory import descriptors as descriptors
 
+
 __all__ = [
     "__version__",
     "Operation",
@@ -23,12 +23,13 @@
 from cuequivariance_jax.triangle._naive_batching import naive_batching_rule
 
 try:
-    import jax_triton as jt
     import triton
 
-    HAS_JAX_TRITON = True
+    from .triton_utils import triton_call
+
+    HAS_TRITON = True
 except ImportError:
-    HAS_JAX_TRITON = False
+    HAS_TRITON = False
 
 
 # copy from cuequivariance_ops to avoid requiring cuequivariance_ops to be installed
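The rename above replaces the optional jax_triton dependency with the in-repo triton_call helper: the module still probes for triton at import time, records the result in HAS_TRITON, and only fails when a GPU code path is actually requested. Below is a minimal, self-contained sketch of that guard pattern; the helper gpu_forward is hypothetical and only illustrates how the flag is consumed, it is not part of this PR.

# Sketch of the optional-dependency guard used in the diff above (illustrative only).
try:
    import triton  # noqa: F401  # only its availability matters here

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False


def gpu_forward(x):
    """Hypothetical call site: fail lazily, only when the GPU path is requested."""
    if not HAS_TRITON:
        raise ImportError("triton is required for GPU implementation")
    # ...dispatch to a Triton kernel here...
    return x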
@@ -228,8 +229,8 @@ def layer_norm_transpose_reference_forward(x, w, b, eps, elementwise_affine, layout

 def _layer_norm_forward_impl(x, w, b, eps, elementwise_affine, layout):
     """Triton implementation of forward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import layer_norm_transpose_forward_kernel
 
@@ -241,7 +242,7 @@ def _layer_norm_forward_impl(x, w, b, eps, elementwise_affine, layout):

     NEEDS_INT64 = B * N * D >= 2**31 - 1
 
-    out, mean, rstd = jt.triton_call(
+    out, mean, rstd = triton_call(
         x,
         w,
         b,
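One unchanged context line worth noting: NEEDS_INT64 = B * N * D >= 2**31 - 1 switches the kernel to 64-bit offsets once the flattened element count reaches the signed 32-bit limit. A quick arithmetic check of that threshold (the shapes below are invented for illustration):

# Illustrative only: when do flat offsets stop fitting in signed int32?
INT32_MAX = 2**31 - 1                 # 2_147_483_647

B, N, D = 64, 262_144, 128            # hypothetical batch, rows, channels
total = B * N * D                     # 2_147_483_648 elements
NEEDS_INT64 = total >= INT32_MAX      # True -> use 64-bit indexing

print(total, NEEDS_INT64)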
@@ -271,8 +272,8 @@ def _layer_norm_backward_impl(
     grad_out, x, w, b, mean, rstd, eps, elementwise_affine, layout
 ):
     """Triton implementation of backward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import layer_norm_transpose_backward_kernel
 
@@ -286,7 +287,7 @@ def _layer_norm_backward_impl(

     NEEDS_INT64 = B * N * D >= 2**31 - 1
 
-    grad_x, grad_w_tiles, grad_b_tiles = jt.triton_call(
+    grad_x, grad_w_tiles, grad_b_tiles = triton_call(
         grad_out,
         x,
         w,
@@ -27,12 +27,13 @@
 from ._utils import Precision
 
 try:
-    import jax_triton as jt
     import triton
 
-    HAS_JAX_TRITON = True
+    from .triton_utils import triton_call
+
+    HAS_TRITON = True
 except ImportError:
-    HAS_JAX_TRITON = False
+    HAS_TRITON = False
 
 
 # Unified JAX primitives
@@ -239,8 +240,8 @@ def fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper(
     num_warps: int = 4,
 ):
     """Triton implementation of forward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import fused_sigmoid_gated_dual_gemm_forward_kernel
 
@@ -263,20 +264,20 @@ def fused_sigmoid_gated_dual_gemm_forward_kernel_wrapper(
     out_shape = (N, M) if transpose_out else (M, N)
     dummy = jnp.zeros((), dtype=dtype)
 
-    return jt.triton_call(
+    return triton_call(
         x1,
         x2 if two_inputs else dummy,
         w1,
         w2,
         b1 if has_b1 else dummy,
         b2 if has_b2 else dummy,
         mask if has_mask else dummy,
-        M,
-        N,
-        K,
         kernel=fused_sigmoid_gated_dual_gemm_forward_kernel,
         out_shape=[jax.ShapeDtypeStruct(shape=out_shape, dtype=x1.dtype)],
         grid=(triton.cdiv(M, TILE_M), triton.cdiv(N, TILE_N), 1),
+        M=M,
+        N=N,
+        K=K,
         TILE_M=TILE_M,
         TILE_N=TILE_N,
         TILE_K=TILE_K,
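Besides swapping jt.triton_call for the local triton_call, the hunk above moves the scalar dimensions M, N, K from positional operands to keyword arguments, so every remaining positional argument is an array operand and all scalars (M, N, K, TILE_*) travel as keywords. The stub below sketches what such an interface could look like; it is an assumption for illustration, not the PR's actual triton_utils.triton_call.

# Assumption / sketch only: not the real triton_utils.triton_call implementation.
def triton_call_stub(*array_operands, kernel, out_shape, grid, **scalar_params):
    """Mirrors the call sites in this diff: arrays are positional, while scalars
    and tile sizes (M=M, N=N, K=K, TILE_M=..., ...) are passed by keyword."""
    raise NotImplementedError("illustrative stub; real dispatch lives in triton_utils")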
@@ -314,8 +315,8 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
     num_warps: int = 4,
 ):
     """Triton implementation of backward pass."""
-    if not HAS_JAX_TRITON:
-        raise ImportError("jax_triton is required for GPU implementation")
+    if not HAS_TRITON:
+        raise ImportError("triton is required for GPU implementation")
 
     from cuequivariance_ops.triton import (
         fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
@@ -346,7 +347,7 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
     ]
 
     dummy = jnp.zeros((), dtype=dtype)
-    grad_xw1, grad_xw2, grad_mask = jt.triton_call(
+    grad_xw1, grad_xw2, grad_mask = triton_call(
         grad_out,
         x1,
         x2 if two_inputs else dummy,
@@ -355,12 +356,12 @@ def fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel_wrapper(
         b1 if has_b1 else dummy,
         b2 if has_b2 else dummy,
         mask if has_mask else dummy,
-        M,
-        N,
-        K,
         kernel=fused_sigmoid_gated_dual_gemm_backward_pregemm_kernel,
         out_shape=out_shapes,
         grid=(triton.cdiv(M, TILE_M), triton.cdiv(N, TILE_N), 1),
+        M=M,
+        N=N,
+        K=K,
         TILE_M=TILE_M,
         TILE_N=TILE_N,
         TILE_K=TILE_K,