From 6900f579a4b16262999435c5ed1a488ca9ecaca6 Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Thu, 26 Feb 2026 18:42:26 +0800 Subject: [PATCH 1/8] [Feat][GLA] Add GLA (Gated Linear Attention) Forward Operator (L2) Implements chunked GLA forward pass with: - Stage 1+2 (PyTorch): within-chunk gate cumsum + inter-chunk hidden state recurrence - Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT] - Stage 4 (TileLang): output combining inter-chunk and intra-chunk contributions Files added: - tileops/kernels/gla/gla_fwd.py -- GLAFwdKernel (sm90a) - tileops/kernels/gla/__init__.py - tileops/ops/gla.py -- GLAFwdOp - tests/ops/test_gla.py -- 7 test cases (fp16 + bf16, with/without initial_state) Closes tile-ai/TileOPs#213 Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py Co-Authored-By: Claude Sonnet 4.6 --- tests/ops/test_gla.py | 203 +++++++++++++++ tileops/kernels/gla/__init__.py | 3 + tileops/kernels/gla/gla_fwd.py | 434 ++++++++++++++++++++++++++++++++ tileops/ops/__init__.py | 98 ++++---- tileops/ops/gla.py | 88 +++++++ 5 files changed, 778 insertions(+), 48 deletions(-) create mode 100644 tests/ops/test_gla.py create mode 100644 tileops/kernels/gla/__init__.py create mode 100644 tileops/kernels/gla/gla_fwd.py create mode 100644 tileops/ops/gla.py diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py new file mode 100644 index 00000000..03093196 --- /dev/null +++ b/tests/ops/test_gla.py @@ -0,0 +1,203 @@ +"""Correctness unit tests for GLAFwdOp. + +Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py +""" + +import pytest +import torch +import torch.nn.functional as F + +from tileops.ops import GLAFwdOp + + +def ref_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + initial_state: torch.Tensor | None = None, + chunk_size: int = 64, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure PyTorch reference for GLA chunked forward. + + Implements the same 4-stage algorithm as the TileLang kernel: + 1. Within-chunk cumulative sum of log-space gates + 2. Inter-chunk hidden state recurrence with gated decay + 3. Intra-chunk causal attention matrix + 4. Output = inter-chunk (q*exp(g_cs) @ h) + intra-chunk (A @ v) + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + v: [B, T, H, V] + g: [B, T, H, K] log-space gates + scale: query scale factor + initial_state: [B, H, K, V] float32, optional + chunk_size: BT + + Returns: + (o [B, T, H, V], final_state [B, H, K, V] float32) + """ + B, T, H, K = q.shape + V = v.shape[-1] + NT = (T + chunk_size - 1) // chunk_size + + # work in float32 for numerical stability + q = q.float() + k = k.float() + v = v.float() + g = g.float() + + # Stage 1: within-chunk cumulative sum of gates + # g_cs[b, t, h, k] = sum of g[b, chunk_start..t, h, k] + g_cs = torch.zeros_like(g) + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + g_cs[:, cs:ce] = torch.cumsum(g[:, cs:ce], dim=1) + + # Stage 2: inter-chunk hidden state recurrence + # h[b, i_c, h, K, V] = state entering chunk i_c + h_states = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=q.device) + b_h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + if initial_state is not None: + b_h = initial_state.float().clone() + + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + h_states[:, i_c] = b_h + + # g_last: cumsum at last position of this chunk [B, H, K] + g_last = g_cs[:, ce - 1] # [B, H, K] + + # decay existing state + # b_h[b, h, k, v] *= exp(g_last[b, h, k]) + b_h = b_h * torch.exp(g_last).unsqueeze(-1) # [B, H, K, V] + + # accumulate: sum over t in chunk of k_adj[t]^T @ v[t] + # k_adj[b, t, h, k] = k[b, t, h, k] * exp(g_last[b, h, k] - g_cs[b, t, h, k]) + k_chunk = k[:, cs:ce] # [B, L, H, K] + v_chunk = v[:, cs:ce] # [B, L, H, V] + g_cs_chunk = g_cs[:, cs:ce] # [B, L, H, K] + g_last_exp = torch.exp(g_last).unsqueeze(1) # [B, 1, H, K] + k_adj = k_chunk * (g_last_exp / torch.exp(g_cs_chunk).clamp(min=1e-30)) + # b_h += einsum('blhk,blhv->bhkv', k_adj, v_chunk) + b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) + + final_state = b_h # [B, H, K, V] + + # Stage 3 + 4: intra-chunk attention and output + o = torch.zeros(B, T, H, V, dtype=torch.float32, device=q.device) + + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + L = ce - cs + + q_c = q[:, cs:ce] # [B, L, H, K] + k_c = k[:, cs:ce] # [B, L, H, K] + v_c = v[:, cs:ce] # [B, L, H, V] + g_cs_c = g_cs[:, cs:ce] # [B, L, H, K] + h_c = h_states[:, i_c] # [B, H, K, V] + + # intra-chunk attention matrix A[b, i, h, j] = scale * sum_k( + # q[i,k]*exp(g_cs[i,k]) * k[j,k]*exp(-g_cs[j,k]) ), causal + q_gated = q_c * torch.exp(g_cs_c) # [B, L, H, K] + k_gated = k_c * torch.exp(-g_cs_c) # [B, L, H, K] + # A[b, h, i, j] = scale * q_gated[b,i,h,:] @ k_gated[b,j,h,:]^T + # rearrange to [B, H, L, K] for bmm + qg = q_gated.permute(0, 2, 1, 3) # [B, H, L, K] + kg = k_gated.permute(0, 2, 1, 3) # [B, H, L, K] + A = scale * torch.bmm(qg.reshape(B * H, L, K), + kg.reshape(B * H, L, K).transpose(1, 2)).reshape(B, H, L, + L) # [B, H, L, L] + + # causal mask + causal_mask = torch.tril(torch.ones(L, L, device=q.device, dtype=torch.bool)) + A = A * causal_mask.unsqueeze(0).unsqueeze(0) # [B, H, L, L] + + # intra-chunk output: A @ v [B, H, L, V] + vc = v_c.permute(0, 2, 1, 3).reshape(B * H, L, V) + o_intra = torch.bmm(A.reshape(B * H, L, L), + vc).reshape(B, H, L, V).permute(0, 2, 1, 3) # [B, L, H, V] + + # inter-chunk output: scale * (q*exp(g_cs)) @ h [B, L, H, V] + # q_gated [B, L, H, K], h_c [B, H, K, V] + o_inter = scale * torch.einsum('blhk,bhkv->blhv', q_gated, h_c) + + o[:, cs:ce] = o_intra + o_inter + + return o, final_state + + +@pytest.mark.parametrize( + "batch, seq_len, heads, dim_k, dim_v, chunk_size, output_final_state, dtype, tune", + [ + (1, 64, 4, 64, 64, 64, False, torch.float16, False), + (2, 128, 8, 64, 64, 64, False, torch.bfloat16, False), + (1, 256, 4, 128, 128, 64, False, torch.float16, False), + (2, 128, 8, 64, 128, 32, False, torch.bfloat16, False), + (4, 256, 16, 64, 64, 64, False, torch.float16, True), + (1, 64, 4, 64, 64, 64, True, torch.float16, False), # with initial_state + final_state + (2, 128, 8, 64, 64, 64, True, torch.bfloat16, False), + ], +) +def test_gla_fwd( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + output_final_state: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + torch.manual_seed(42) + + scale = dim_k**-0.5 + + op = GLAFwdOp( + batch=batch, + seq_len=seq_len, + heads=heads, + dim_k=dim_k, + dim_v=dim_v, + chunk_size=chunk_size, + scale=scale, + output_final_state=output_final_state, + dtype=dtype, + tune=tune, + ) + + q = torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype) + k = torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype) + v = torch.randn(batch, seq_len, heads, dim_v, device='cuda', dtype=dtype) + g = F.logsigmoid(torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype)) + + initial_state = None + if output_final_state: + initial_state = torch.randn( + batch, heads, dim_k, dim_v, device='cuda', dtype=torch.float32) * 0.1 + + with torch.no_grad(): + out, out_final = op(q, k, v, g, initial_state) + + ref_o, ref_final = ref_gla_fwd( + q, k, v, g, scale=scale, initial_state=initial_state, chunk_size=chunk_size) + ref_o = ref_o.to(dtype) + + assert torch.allclose(out, ref_o, atol=1e-2, rtol=1e-2), \ + f"output mismatch: max err = {(out.float() - ref_o.float()).abs().max():.6f}" + + if output_final_state: + assert out_final is not None + assert torch.allclose(out_final.float(), ref_final.float(), atol=1e-2, rtol=1e-2), \ + f"final_state mismatch: max err = {(out_final.float() - ref_final.float()).abs().max():.6f}" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/gla/__init__.py b/tileops/kernels/gla/__init__.py new file mode 100644 index 00000000..7f1e093f --- /dev/null +++ b/tileops/kernels/gla/__init__.py @@ -0,0 +1,3 @@ +from .gla_fwd import GLAFwdKernel + +__all__ = ["GLAFwdKernel"] diff --git a/tileops/kernels/gla/gla_fwd.py b/tileops/kernels/gla/gla_fwd.py new file mode 100644 index 00000000..59f4c045 --- /dev/null +++ b/tileops/kernels/gla/gla_fwd.py @@ -0,0 +1,434 @@ +"""GLA (Gated Linear Attention) Forward Kernel. + +Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py + +Algorithm: Chunked GLA forward in 4 stages: + 1. Within-chunk cumulative sum of log-space gates g -> g_cumsum [B, T, H, K] + (computed in PyTorch inside forward() — sequential scan, not GPU-bound) + 2. Inter-chunk hidden state recurrence -> h [B, NT, H, K, V], ht [B, H, K, V] + (computed in PyTorch inside forward() — sequential over chunks) + 3. Intra-chunk causal attention matrix -> A [B, T, H, BT] (TileLang) + 4. Output: o = inter-chunk (q*exp(g_cumsum) @ h) + intra-chunk (A @ v) (TileLang) + +Inputs: + q [B, T, H, K] fp16/bf16 queries + k [B, T, H, K] fp16/bf16 keys + v [B, T, H, V] fp16/bf16 values + g [B, T, H, K] fp16/bf16 log-space forget gates (e.g. F.logsigmoid(...)) + initial_state [B, H, K, V] float32 optional initial hidden state + +Outputs: + o [B, T, H, V] fp16/bf16 + final_state [B, H, K, V] float32 (only when output_final_state=True) +""" + +import torch +from typing import Optional, Any, Callable + +import tilelang +from tilelang import language as T + +from tileops.kernels.kernel import Kernel + +LOG2_E = 1.44269504 + +# --------------------------------------------------------------------------- +# Stage 3: Intra-chunk causal attention matrix A [B, T, H, BT] +# --------------------------------------------------------------------------- + + +def _gla_fwd_intra_kernel( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + chunk_size: int, + scale: float, + dtype: str, +) -> Callable: + num_chunks = (seq_len + chunk_size - 1) // chunk_size + q_shape = [batch, seq_len, heads, dim_k] + k_shape = [batch, seq_len, heads, dim_k] + g_cumsum_shape = [batch, seq_len, heads, dim_k] + A_shape = [batch, seq_len, heads, chunk_size] + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def _func(threads: int): + + @T.prim_func + def _kernel( + q: T.Tensor(q_shape, dtype), + k: T.Tensor(k_shape, dtype), + g_cumsum: T.Tensor(g_cumsum_shape, "float32"), + A: T.Tensor(A_shape, "float32"), + ): + with T.Kernel(batch * heads, num_chunks, threads=threads) as (bx, by): + i_b = bx // heads + i_h = bx % heads + i_c = by + chunk_start = i_c * chunk_size + + # Shared buffers for inputs + q_shared = T.alloc_shared([chunk_size, dim_k], dtype) + k_shared = T.alloc_shared([chunk_size, dim_k], dtype) + g_shared = T.alloc_shared([chunk_size, dim_k], "float32") + + # Shared buffers for gated q/k (float32 for gemm) + q_gated = T.alloc_shared([chunk_size, dim_k], "float32") + k_gated = T.alloc_shared([chunk_size, dim_k], "float32") + + # Fragment accumulator for A [BT, BT] + acc = T.alloc_fragment([chunk_size, chunk_size], "float32") + + # Load inputs + T.copy(q[i_b, chunk_start:chunk_start + chunk_size, i_h, :], q_shared) + T.copy(k[i_b, chunk_start:chunk_start + chunk_size, i_h, :], k_shared) + T.copy(g_cumsum[i_b, chunk_start:chunk_start + chunk_size, i_h, :], g_shared) + + # q_gated[t, k] = q[t, k] * exp(g_cumsum[t, k]) + # k_gated[t, k] = k[t, k] * exp(-g_cumsum[t, k]) + for i_t, i_k in T.Parallel(chunk_size, dim_k): + q_gated[i_t, i_k] = ( + T.cast(q_shared[i_t, i_k], "float32") * T.exp2(g_shared[i_t, i_k] * LOG2_E)) + k_gated[i_t, i_k] = ( + T.cast(k_shared[i_t, i_k], "float32") * + T.exp2(-g_shared[i_t, i_k] * LOG2_E)) + + # A = q_gated @ k_gated^T [BT, BT] + T.fill(acc, 0.0) + T.gemm(q_gated, k_gated, acc, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + # Apply causal mask and scale, write to A + for i_t, i_j in T.Parallel(chunk_size, chunk_size): + A[i_b, chunk_start + i_t, i_h, + i_j] = T.if_then_else(i_j <= i_t, acc[i_t, i_j] * scale, 0.0) + + return _kernel + + return _func + + +@torch.library.custom_op("gla::gla_fwd_intra", mutates_args=()) +def _gla_fwd_intra_wrapped( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + q: torch.Tensor, + k: torch.Tensor, + g_cumsum: torch.Tensor, +) -> torch.Tensor: + return _gla_fwd_intra_kernel(batch, seq_len, heads, dim_k, chunk_size, scale, + dtype)(threads)(q, k, g_cumsum) + + +@_gla_fwd_intra_wrapped.register_fake +def _( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + *inputs: tuple[Any], +) -> torch.Tensor: + _ = (dim_k, scale, dtype, threads) + return torch.empty([batch, seq_len, heads, chunk_size], + dtype=torch.float32, + device=inputs[0].device) + + +# --------------------------------------------------------------------------- +# Stage 4: Output computation o [B, T, H, V] +# --------------------------------------------------------------------------- + + +def _gla_fwd_o_kernel( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, +) -> Callable: + num_chunks = (seq_len + chunk_size - 1) // chunk_size + q_shape = [batch, seq_len, heads, dim_k] + v_shape = [batch, seq_len, heads, dim_v] + g_cumsum_shape = [batch, seq_len, heads, dim_k] + A_shape = [batch, seq_len, heads, chunk_size] + h_shape = [batch, num_chunks, heads, dim_k, dim_v] + o_shape = [batch, seq_len, heads, dim_v] + + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def _func(threads: int): + + @T.prim_func + def _kernel( + q: T.Tensor(q_shape, dtype), + v: T.Tensor(v_shape, dtype), + g_cumsum: T.Tensor(g_cumsum_shape, "float32"), + A: T.Tensor(A_shape, "float32"), + h: T.Tensor(h_shape, "float32"), + o: T.Tensor(o_shape, dtype), + ): + with T.Kernel(batch * heads, num_chunks, threads=threads) as (bx, by): + i_b = bx // heads + i_h = bx % heads + i_c = by + chunk_start = i_c * chunk_size + + # Shared buffers for inputs + q_shared = T.alloc_shared([chunk_size, dim_k], dtype) + v_shared = T.alloc_shared([chunk_size, dim_v], "float32") + g_cs_shared = T.alloc_shared([chunk_size, dim_k], "float32") + A_shared = T.alloc_shared([chunk_size, chunk_size], "float32") + h_shared = T.alloc_shared([dim_k, dim_v], "float32") + + # Shared buffer for gated q (float32 for gemm) + q_gated = T.alloc_shared([chunk_size, dim_k], "float32") + + # Fragment accumulator [BT, BV] + acc = T.alloc_fragment([chunk_size, dim_v], "float32") + + # Load inputs + T.copy(q[i_b, chunk_start:chunk_start + chunk_size, i_h, :], q_shared) + T.copy(g_cumsum[i_b, chunk_start:chunk_start + chunk_size, i_h, :], g_cs_shared) + T.copy(A[i_b, chunk_start:chunk_start + chunk_size, i_h, :], A_shared) + T.copy(h[i_b, i_c, i_h, :, :], h_shared) + + # Load v as float32 + v_raw = T.alloc_shared([chunk_size, dim_v], dtype) + T.copy(v[i_b, chunk_start:chunk_start + chunk_size, i_h, :], v_raw) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + v_shared[i_t, i_v] = T.cast(v_raw[i_t, i_v], "float32") + + # q_gated[t, k] = q[t, k] * exp(g_cumsum[t, k]) + for i_t, i_k in T.Parallel(chunk_size, dim_k): + q_gated[i_t, i_k] = ( + T.cast(q_shared[i_t, i_k], "float32") * + T.exp2(g_cs_shared[i_t, i_k] * LOG2_E)) + + # inter-chunk: acc = scale * q_gated @ h [BT, BV] + T.fill(acc, 0.0) + T.gemm(q_gated, h_shared, acc, policy=T.GemmWarpPolicy.FullRow) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + acc[i_t, i_v] = acc[i_t, i_v] * scale + + # intra-chunk: acc += A @ v [BT, BV] + T.gemm(A_shared, v_shared, acc, policy=T.GemmWarpPolicy.FullRow) + + # Write output (cast back to dtype) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + o[i_b, chunk_start + i_t, i_h, i_v] = T.cast(acc[i_t, i_v], dtype) + + return _kernel + + return _func + + +@torch.library.custom_op("gla::gla_fwd_o", mutates_args=()) +def _gla_fwd_o_wrapped( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + q: torch.Tensor, + v: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, +) -> torch.Tensor: + return _gla_fwd_o_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, scale, + dtype)(threads)(q, v, g_cumsum, A, h) + + +@_gla_fwd_o_wrapped.register_fake +def _( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + *inputs: tuple[Any], +) -> torch.Tensor: + _ = (dim_k, chunk_size, scale, dtype, threads) + return torch.empty([batch, seq_len, heads, dim_v], + dtype=inputs[0].dtype, + device=inputs[0].device) + + +# --------------------------------------------------------------------------- +# Kernel class +# --------------------------------------------------------------------------- + + +class GLAFwdKernel(Kernel): + """GLA (Gated Linear Attention) forward kernel. + + Implements chunked GLA forward: + Stage 1 (PyTorch): within-chunk cumulative sum of log-space gates + Stage 2 (PyTorch): inter-chunk hidden state recurrence + Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT] + Stage 4 (TileLang): output o = inter-chunk + intra-chunk contributions + + Args: + batch: Batch size B. + seq_len: Sequence length T. Must be divisible by chunk_size. + heads: Number of query heads H. + dim_k: Key/query head dimension K. + dim_v: Value head dimension V. + chunk_size: Chunk size BT (default 64). + scale: Query scale factor (default 1/sqrt(K)). + output_final_state: Whether to return the final hidden state. + dtype: Input tensor dtype (torch.float16 or torch.bfloat16). + config: Optional kernel config dict (e.g. {"threads": 128}). + tune: Whether to run autotuning. + + Inputs to forward(): + q [B, T, H, K] fp16/bf16 + k [B, T, H, K] fp16/bf16 + v [B, T, H, V] fp16/bf16 + g [B, T, H, K] fp16/bf16 log-space gates + initial_state [B, H, K, V] float32 optional + + Returns: + (o [B, T, H, V], final_state [B, H, K, V] or None) + + Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py + """ + + supported_archs: list[int] = [90] + + def __init__( + self, + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int = 64, + scale: float = -1.0, + output_final_state: bool = False, + dtype: torch.dtype = torch.float16, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + self.batch = batch + self.seq_len = seq_len + self.heads = heads + self.dim_k = dim_k + self.dim_v = dim_v + self.chunk_size = chunk_size + self.scale = scale if scale > 0 else dim_k**-0.5 + self.output_final_state = output_final_state + self.dtype_name = str(dtype).split('.')[-1] + self.init_config(config, tune) + # GLAFwdKernel has no single self.kernel to autotune; fall back to default_config + if not self.config: + self.config = self.default_config + + @property + def default_config(self) -> dict: + return {"threads": 128} + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + dtype_str = self.dtype_name + threads = self.config["threads"] + B, T, H, K = self.batch, self.seq_len, self.heads, self.dim_k + V = self.dim_v + BT = self.chunk_size + NT = (T + BT - 1) // BT + dtype_torch = getattr(torch, dtype_str) + + q = q.to(dtype_torch) + k = k.to(dtype_torch) + v = v.to(dtype_torch) + g = g.to(dtype_torch) + + use_initial_state = initial_state is not None + if not use_initial_state: + b_h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + else: + b_h = initial_state.to(torch.float32).clone() + + # Stage 1: within-chunk cumulative sum of gates (PyTorch) + g_f32 = g.float() + g_cumsum = torch.empty_like(g_f32) + for i_c in range(NT): + cs = i_c * BT + ce = min(cs + BT, T) + g_cumsum[:, cs:ce] = torch.cumsum(g_f32[:, cs:ce], dim=1) + + # Stage 2: inter-chunk hidden state recurrence (PyTorch) + # h_states[b, i_c, h, K, V] = state entering chunk i_c + h_states = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=q.device) + k_f32 = k.float() + v_f32 = v.float() + for i_c in range(NT): + cs = i_c * BT + ce = min(cs + BT, T) + h_states[:, i_c] = b_h + + # g_last: g_cumsum at last position of this chunk [B, H, K] + g_last = g_cumsum[:, ce - 1] # [B, H, K] + + # Decay: b_h[b, h, k, v] *= exp(g_last[b, h, k]) + b_h = b_h * torch.exp(g_last).unsqueeze(-1) + + # Accumulate: b_h += k_adj^T @ v + # k_adj[b, t, h, k] = k[b, t, h, k] * exp(g_last[b, h, k] - g_cumsum[b, t, h, k]) + k_chunk = k_f32[:, cs:ce] # [B, L, H, K] + v_chunk = v_f32[:, cs:ce] # [B, L, H, V] + g_cs_chunk = g_cumsum[:, cs:ce] # [B, L, H, K] + k_adj = k_chunk * torch.exp(g_last.unsqueeze(1) - g_cs_chunk) + b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) + + final_state = b_h if self.output_final_state else None + + # Stage 3: intra-chunk attention matrix (TileLang) + A = _gla_fwd_intra_wrapped(B, T, H, K, BT, self.scale, dtype_str, threads, q, k, g_cumsum) + + # Stage 4: output (TileLang) + o = _gla_fwd_o_wrapped(B, T, H, K, V, BT, self.scale, dtype_str, threads, q, v, g_cumsum, A, + h_states) + + return o, final_state diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index d963a62a..6dc6e6bf 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -1,48 +1,50 @@ -from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp -from .fp8_lighting_indexer import Fp8LightingIndexerOp -from .topk_selector import TopkSelectorOp -from .fp8_quant import Fp8QuantOp -from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp -from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp -from .gemm import GemmOp -from .gemv import GemvOp -from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp -from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp -from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp -from .grouped_gemm import GroupedGemmNNOp, GroupedGemmNTOp, GroupedGemmTNOp, GroupedGemmTTOp -from .mha import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp -from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp -from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp -from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp -from .mhc_post import ManifoldConstrainedHyperConnectionPostOp -from .op import Op # noqa: F401 - -__all__ = [ - "Op", - "MultiHeadAttentionFwdOp", - "MultiHeadAttentionBwdOp", - "GroupQueryAttentionFwdOp", - "GroupQueryAttentionBwdOp", - "GemmOp", - "GemvOp", - "MultiHeadAttentionDecodeWithKVCacheOp", - "MultiHeadAttentionDecodePagedWithKVCacheOp", - "GroupQueryAttentionDecodeWithKVCacheOp", - "GroupQueryAttentionDecodePagedWithKVCacheOp", - "GroupedGemmNTOp", - "GroupedGemmNNOp", - "GroupedGemmTNOp", - "GroupedGemmTTOp", - "MultiHeadLatentAttentionDecodeWithKVCacheOp", - "DeepSeekSparseAttentionDecodeWithKVCacheOp", - "Fp8LightingIndexerOp", - "TopkSelectorOp", - "Fp8QuantOp", - "MeanPoolingForwardOp", - "NSATopkVarlenOp", - "NSAFwdVarlenOp", - "NSACmpFwdVarlenOp", - "GQAWindowSlidingOp", - "ManifoldConstrainedHyperConnectionPreOp", - "ManifoldConstrainedHyperConnectionPostOp", -] +from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp +from .fp8_lighting_indexer import Fp8LightingIndexerOp +from .topk_selector import TopkSelectorOp +from .fp8_quant import Fp8QuantOp +from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp +from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp +from .gemm import GemmOp +from .gemv import GemvOp +from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp +from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp +from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp +from .grouped_gemm import GroupedGemmNNOp, GroupedGemmNTOp, GroupedGemmTNOp, GroupedGemmTTOp +from .mha import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp +from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp +from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp +from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp +from .mhc_post import ManifoldConstrainedHyperConnectionPostOp +from .gla import GLAFwdOp +from .op import Op # noqa: F401 + +__all__ = [ + "Op", + "MultiHeadAttentionFwdOp", + "MultiHeadAttentionBwdOp", + "GroupQueryAttentionFwdOp", + "GroupQueryAttentionBwdOp", + "GemmOp", + "GemvOp", + "MultiHeadAttentionDecodeWithKVCacheOp", + "MultiHeadAttentionDecodePagedWithKVCacheOp", + "GroupQueryAttentionDecodeWithKVCacheOp", + "GroupQueryAttentionDecodePagedWithKVCacheOp", + "GroupedGemmNTOp", + "GroupedGemmNNOp", + "GroupedGemmTNOp", + "GroupedGemmTTOp", + "MultiHeadLatentAttentionDecodeWithKVCacheOp", + "DeepSeekSparseAttentionDecodeWithKVCacheOp", + "Fp8LightingIndexerOp", + "TopkSelectorOp", + "Fp8QuantOp", + "MeanPoolingForwardOp", + "NSATopkVarlenOp", + "NSAFwdVarlenOp", + "NSACmpFwdVarlenOp", + "GQAWindowSlidingOp", + "ManifoldConstrainedHyperConnectionPreOp", + "ManifoldConstrainedHyperConnectionPostOp", + "GLAFwdOp", +] diff --git a/tileops/ops/gla.py b/tileops/ops/gla.py new file mode 100644 index 00000000..76f02270 --- /dev/null +++ b/tileops/ops/gla.py @@ -0,0 +1,88 @@ +"""GLA (Gated Linear Attention) Forward Op.""" + +from typing import Any, Dict, Optional + +import torch + +from tileops.ops.op import Op + + +class GLAFwdOp(Op): + """Op wrapper for GLA (Gated Linear Attention) forward pass. + + Dispatches to GLAFwdKernel which implements the 4-stage chunked forward: + cumulative gate sum -> inter-chunk recurrence -> intra-chunk attention -> output. + + Args: + batch: Batch size B. + seq_len: Sequence length T. Must be divisible by chunk_size. + heads: Number of query/key/value heads H. + dim_k: Key/query head dimension K. + dim_v: Value head dimension V. + chunk_size: Chunk size BT (default 64). + scale: Query scale factor. Defaults to 1/sqrt(dim_k). + output_final_state: If True, also return the final hidden state. + dtype: Input tensor dtype (default torch.float16). + tune: Whether to run kernel autotuning. + kernel_map: Optional override for the kernel dispatch map. + + Example: + >>> op = GLAFwdOp(batch=2, seq_len=128, heads=8, dim_k=64, dim_v=64) + >>> q = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> k = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> v = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> g = torch.nn.functional.logsigmoid( + ... torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)) + >>> o, final_state = op(q, k, v, g) + """ + + def __init__( + self, + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int = 64, + scale: Optional[float] = None, + output_final_state: bool = False, + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map', '__class__')} + # resolve default scale before storing + if params['scale'] is None: + params['scale'] = dim_k**-0.5 + for k, v in params.items(): + setattr(self, k, v) + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["gla_fwd"](**params) + + @property + def default_kernel_map(self) -> Dict[str, Any]: + from tileops.kernels.gla import GLAFwdKernel + return {"gla_fwd": GLAFwdKernel} + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run GLA forward. + + Args: + q: Queries [B, T, H, K]. + k: Keys [B, T, H, K]. + v: Values [B, T, H, V]. + g: Log-space forget gates [B, T, H, K]. + initial_state: Optional initial hidden state [B, H, K, V] float32. + + Returns: + Tuple of (o [B, T, H, V], final_state [B, H, K, V] or None). + """ + return self.kernel(q, k, v, g, initial_state) From 322e5f1e79f6a1b700c02c94373ac03ee3b7b27b Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Thu, 26 Feb 2026 18:55:03 +0800 Subject: [PATCH 2/8] [Feat][GLA] Add GLA forward operator with skill.md files - Add seq_len % chunk_size == 0 assertion in GLAFwdOp to prevent OOB writes in TileLang kernels on non-divisible sequence lengths - Cast k/v to float32 per-chunk in GLAFwdKernel.forward to reduce peak memory usage - Fix k_adj formula in ref_gla_fwd to use log-space subtraction (matching GLAFwdKernel) instead of division with clamp - Add test_gla_fwd_non_divisible_seq_len to verify the assertion fires - Add skill.md files for create-new-kernel, create-new-op, create-new-op-test, creating-pull-request, migrating-new-op Co-Authored-By: Claude Sonnet 4.6 --- .claude/create-new-kernel/skill.md | 392 +++++++++++++++++++++++ .claude/create-new-op-test/skill.md | 192 +++++++++++ .claude/create-new-op/skill.md | 248 ++++++++++++++ .claude/skills/migrating-new-op/SKILL.md | 326 +++++++++---------- tests/ops/test_gla.py | 11 +- tileops/kernels/gla/gla_fwd.py | 6 +- tileops/ops/gla.py | 2 + 7 files changed, 1008 insertions(+), 169 deletions(-) create mode 100644 .claude/create-new-kernel/skill.md create mode 100644 .claude/create-new-op-test/skill.md create mode 100644 .claude/create-new-op/skill.md diff --git a/.claude/create-new-kernel/skill.md b/.claude/create-new-kernel/skill.md new file mode 100644 index 00000000..78447e17 --- /dev/null +++ b/.claude/create-new-kernel/skill.md @@ -0,0 +1,392 @@ +# Skill: Creating a New TileOps Kernel + +Reference implementation: `tileops/kernels/deepseek_nsa/nsa_topk.py` +Base class: `tileops/kernels/kernel.py` + +## Development Environment + +| Item | Version | +|---|---| +| GPU | H200 (sm90a) | +| CUDA | 12.9 | +| TileLang | 0.1.7.post1 | + +Target architecture is `sm90a`. All kernels are developed and validated on this environment. Do not assume compatibility with older architectures unless explicitly tested. + +--- + +## File Location and Naming + +``` +tileops/kernels//.py +``` + +- Group related kernels under a feature subdirectory (e.g. `tileops/kernels/deepseek_nsa/`) +- File name: `.py`, all lowercase with underscores (e.g. `nsa_topk.py`) +- Kernel function: `__kernel` (e.g. `_nsa_topk_varlen_kernel`) +- Wrapper function: `__wrapped_kernel` +- Class name: CamelCase + `Kernel` suffix (e.g. `NSATopkVarlenKernel`) +- Export the class from `tileops/kernels//__init__.py` + +--- + +## File Structure + +A kernel file contains five parts in order: + +1. Imports +2. Kernel function (tilelang implementation) +3. Wrapper function (`__wrapped_kernel`) +4. `register_fake` for the wrapper +5. `Kernel` subclass + +--- + +## Part 1: Imports + +```python +import torch +from typing import Optional, Any, Callable + +import tilelang +from tilelang import language as T + +from tileops.kernels.kernel import Kernel +``` + +Tilelang kernel implementations generally need no additional imports beyond these. + +--- + +## Part 2: Kernel Function + +The kernel is implemented as a two-level closure: + +```python +def __kernel( + # Fixed kernel parameters: shapes, dtypes, algorithm constants + param1: int, + param2: int, + dtype: str, + accum_dtype: str, +) -> Callable: + # Precompute constants and define tensor shapes here + # e.g. q_shape = [seq_len, heads, dim] + + @tilelang.jit( + out_idx=[-1], # index of output tensor in @T.prim_func's parameter list + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def __func(threads: int): # auto-tune parameters go here + + @T.prim_func + def _main( + input1: T.Tensor(shape1, dtype), + ..., + output: T.Tensor(output_shape, accum_dtype), + ): + with T.Kernel(grid_x, grid_y, threads=threads) as (bx, by): + # kernel body + + return _main + + return __func +``` + +### Dtype conventions + +- Default compute dtype: `float16` (`"float16"`); exposed as `dtype: str` in the outermost kernel function so callers can override. +- Accumulator dtype: always `float32`, hardcoded internally as `accum_dtype = "float32"`. Do not expose it as a parameter. + +### Memory hierarchy + +- Always move data from global memory to **shared memory** (`T.alloc_shared`) before computation. +- For reductions, move data from shared memory to **fragment** (`T.alloc_fragment`) first, then call `T.reduce_sum` / `T.reduce_max` on the fragment. +- Never reduce directly on shared memory. + +Typical pattern: +``` +global → T.copy → shared → T.gemm / T.copy → fragment → T.reduce_* → fragment +``` + +### Code organisation + +- Keep `@T.prim_func` minimal — it should read like a high-level algorithm outline. +- Extract non-trivial logic (sorting, masking, multi-step computations) into `@T.macro` helpers defined inside `__func`, and call them from `_main`. + +### Tilelang coding guidelines + +| Concern | Guidance | +|---|---| +| Exponentiation | Use `T.exp2(x * LOG2_E)` instead of `T.exp(x)` for performance (`LOG2_E = 1.44269504`) | +| Data movement | Use `T.copy()` for copies; `T.gemm()` for matrix multiply | +| Reduction | Use `T.reduce_sum()` / `T.reduce_max()` on fragments (not shared memory) | +| Parallelism | Prefer `T.Parallel` over `T.Serial`; only use `T.Serial` when sequential dependency is unavoidable | +| Loop structure | Prefer `T.Pipelined` for outer loops to enable software pipelining | +| Bug workaround | Tilelang may have bugs with nested `T.Parallel` / `T.Serial` / `T.Pipelined`. If compilation or runtime errors occur, try reordering the nesting levels as a first debugging step. | + +--- + +## Part 3: Wrapper Function + +The wrapper registers the kernel with `torch.library` so it is visible to `torch.compile` and `torch.export`. The `Kernel` class calls this wrapper in `forward`. + +### 3a. custom_op wrapper + +```python +@torch.library.custom_op("top::_wrapped_kernel", mutates_args=()) +def __wrapped_kernel( + # All scalar params first (int, float, str for dtype names) + param1: int, + dtype: str, + threads: int, + # Tensor inputs last + input1: torch.Tensor, + ... +) -> torch.Tensor: + return __kernel(param1, dtype)(threads)(input1, ...) +``` + +The call chain is: `_kernel(fixed_params)(autotune_params)(tensor_inputs)`. + +Note: `accum_dtype` is hardcoded inside `__kernel` and is not a parameter of the wrapper. + +### 3b. register_fake + +Provides a shape/dtype inference rule for tracing (no actual computation): + +```python +@__wrapped_kernel.register_fake +def _( + param1: int, + dtype: str, + threads: int, + *inputs: tuple[Any], +) -> torch.Tensor: + _ = (param1, dtype, threads) # suppress unused warnings + # Return an empty tensor with the correct output shape and dtype. + # Shape must match what the real kernel produces. + return torch.empty( + [output_dim0, output_dim1, ...], + dtype=inputs[0].dtype, + device=inputs[0].device, + ) +``` + +> The output shape here must exactly match the real kernel's output. Derive it from the scalar parameters (e.g. `c_seq_len`, `heads`, `selected_block_num`), not from input tensor shapes. + +--- + +## Part 4: Kernel Class + +- Naming: CamelCase, descriptive, ending in `Kernel` (e.g. `NSATopkVarlenKernel`) +- Inherit from `Kernel` (see `tileops/kernels/kernel.py`) +- Declare `supported_archs` for GPU arch gating; current target is H200 (sm90a), so use `[90]` + +The class **must** include a docstring with three sections: + +1. Input tensor shapes — list each tensor parameter with its layout (e.g. `[batch, seqlen, heads, dim]`) +2. Computation logic — a brief description of what the kernel computes +3. Reference — URL to the official PyTorch / Triton / paper implementation this is based on + +Example: + +```python +class Kernel(Kernel): + """. + + Args: + q: Query tensor, shape [batch, seqlen_q, heads, dim] + k: Key tensor, shape [batch, seqlen_k, heads_kv, dim] + v: Value tensor, shape [batch, seqlen_k, heads_kv, dim_v] + offsets: Sequence boundary offsets, shape [seq_num + 1], dtype int32 + ... + + Computation: + + + Reference: + + """ + supported_archs: list[int] = [90] + + def __init__(self, + param1: int, + ..., + dtype: torch.dtype = torch.float16, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.param1 = param1 + ... + # Store dtype as string for tilelang; accum_dtype is hardcoded in the kernel + self.dtype_name = str(dtype).split('.')[-1] + self.init_config(config, tune) # must be called last + + @property + def default_config(self) -> dict: + return {"threads": 32} + + @property + def autotune_configs(self) -> list[dict]: + return [{"threads": t} for t in [32, 64, 128]] + + def forward(self, input1: torch.Tensor, ...) -> torch.Tensor: + return __wrapped_kernel( + self.param1, + ..., + self.dtype_name, + self.config["threads"], + input1.to(getattr(torch, self.dtype_name)), + ..., + ) +``` + +Key points: +- `forward` calls `__wrapped_kernel` (the registered wrapper), not the raw kernel function directly +- `init_config(config, tune)` must be the last call in `__init__`; it reads `default_config` and `autotune_configs` +- `accum_dtype` is not stored on the class; it is hardcoded inside the kernel function +- Only cast index/offset tensors (e.g. `offsets`, `token_indices`) to `torch.int32` in `forward`; do NOT cast floating-point tensors + +--- + +## Attention Kernel Conventions + +Attention kernels have additional conventions beyond the general rules above. + +### Kernel variants + +Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: + +| Variant | File suffix | Class suffix | Description | +|---|---|---|---| +| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | +| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | +| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | + +Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` + +### `causal` parameter + +All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. + +```python +def __kernel( + batch: int, + heads: int, + seqlen_q: int, + seqlen_kv: int, + dim: int, + is_causal: bool, # required for all attention kernels + dtype: str, +) -> Callable: + ... +``` + +### Decode kernel: split-K design + +Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: + +- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) +- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) + +Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. + +`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. + +Structure inside `__func`: + +```python +def __decode_func(block_M, block_N, num_split, num_stages, threads): + + @T.macro + def __no_split(Q, K, V, real_seqlen_kv, Output): + # single-pass attention over full KV + ... + + @T.macro + def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): + # attention over one KV chunk; writes partial output + log-sum-exp + ... + + @T.macro + def combine(glse, Output_partial, Output): + # merge partial outputs using LSE rescaling + ... + + @T.prim_func + def _decode_no_split(Q, K, V, real_seqlen_kv, Output): + __no_split(Q, K, V, real_seqlen_kv, Output) + + @T.prim_func + def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): + __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) + combine(glse, Output_partial, Output) + + if num_split > 1: + return _decode_split + else: + return _decode_no_split +``` + +The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: + +```python +@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) +def __decode_wrapped_kernel( + ..., num_split: int, + Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + glse: torch.Tensor, Output_partial: torch.Tensor, +) -> torch.Tensor: + split_length = ... # compute per-split chunk sizes + if split_length[0] == 0: + num_split = 1 + if num_split == 1: + return __decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) + return __decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( + Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) +``` + +The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: + +```python +def forward(self, Q, K, V, real_seqlen_kv): + glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) + Output_partial = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) + return __decode_wrapped_kernel(..., self.config["num_split"], Q, K, V, glse, Output_partial) +``` + +### Decode kernel: paged KV variant + +For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: +- Adds `page_size: int` to the outermost kernel function +- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` +- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool +- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar + +Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` + +- [ ] `out_idx` in `@tilelang.jit` points to the correct output tensor position +- [ ] `register_fake` output shape matches the real kernel output +- [ ] `custom_op` name is unique: `"top::_wrapped_kernel"` +- [ ] File placed under `tileops/kernels//.py` and exported from its `__init__.py` +- [ ] Class name is CamelCase and ends with `Kernel` +- [ ] Class has docstring with: input tensor shapes, computation logic, reference URL +- [ ] `supported_archs` is set appropriately +- [ ] `init_config` is called at the end of `__init__` +- [ ] `forward` calls the wrapper (`__wrapped_kernel`), not the raw kernel directly +- [ ] Index/offset tensors (e.g. `offsets`, `token_indices`) are cast to `torch.int32` in `forward`; do NOT cast floating-point tensors +- [ ] Default dtype is `float16` (exposed); accumulator is `float32` (hardcoded internally) +- [ ] Global memory is copied to shared memory before computation +- [ ] Reductions operate on fragments, not shared memory +- [ ] `@T.prim_func` is kept minimal; complex logic is in `@T.macro` helpers +- [ ] (Attention only) `is_causal: bool` is a parameter of the outermost kernel function +- [ ] (Attention only) File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` +- [ ] (Decode only) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` +- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` diff --git a/.claude/create-new-op-test/skill.md b/.claude/create-new-op-test/skill.md new file mode 100644 index 00000000..34c5d05f --- /dev/null +++ b/.claude/create-new-op-test/skill.md @@ -0,0 +1,192 @@ +# Skill: Writing Op Correctness Unit Tests + +Reference tests: `tests/ops/` + +--- + +## Overview + +Each op has a pytest file at `tests/ops/test_.py`. The test: + +1. Instantiates the `Op` with given parameters +2. Generates random input tensors +3. Runs the op and a PyTorch reference implementation +4. Asserts numerical closeness with `torch.allclose` + +No benchmark infrastructure is needed — correctness only. + +--- + +## File Location and Naming + +``` +tests/ops/test_.py +``` + +File name mirrors the op module name (e.g. `tileops/ops/gqa_decode.py` → `tests/ops/test_gqa_decode.py`). + +--- + +## Test File Structure + +```python +import pytest +import torch + +from tileops.ops import Op + + +def ref_(*inputs, **params) -> torch.Tensor: + # Pure PyTorch reference implementation + # Use torch.nn.functional or F.scaled_dot_product_attention + ... + + +@pytest.mark.parametrize("", [ + (...), # case 1 + (...), # case 2 + # ... at least 5 cases +]) +def test_(param1: int, ..., dtype: torch.dtype, tune: bool) -> None: + torch.manual_seed(42) + + op = Op(param1, ..., dtype=dtype, tune=tune) + + # generate inputs + x = torch.randn(..., device='cuda', dtype=dtype) + ... + + # run op + with torch.no_grad(): + out = op(x, ...) + + # run reference + ref = ref_(x, ..., param1=param1, ...) + + assert torch.allclose(out, ref, atol=, rtol=), \ + f"max err: {(out - ref).abs().max()}" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) +``` + +--- + +## Test Case Requirements + +### Minimum coverage + +- At least **5 parametrized test cases** per test function +- Vary structural parameters across cases: batch, seq_len, heads, dim, etc. +- Include at least one `tune=True` case if the op supports autotuning + +### dtype coverage + +| Op type | Required dtypes | +|---|---| +| General ops | `torch.float16`, `torch.bfloat16` | +| Quantization ops (fp8, fp4, etc.) | Include the target quantized dtype (e.g. `torch.float8_e4m3fn`) | +| Mixed-precision ops | Cover all relevant input/output dtype combinations | + +### Random seed + +Fix `torch.manual_seed(42)` at the top of every test function, before any tensor creation. + +### Tolerance guidelines + +| Op type | `atol` | `rtol` | +|---|---|---| +| Attention fwd (fp16/bf16) | `5e-3` | `1e-5` | +| Attention decode | `1e-2` | `1e-2` | +| GEMM / linear | `1e-3` | `1e-3` | +| Elementwise / quantization | `1e-1` or custom | — | + +### Reference implementation + +- If `torch.nn` or `torch.nn.functional` provides the operation directly, use it. +- If not, implement the reference manually using basic PyTorch ops (`torch.matmul`, `torch.softmax`, elementwise ops, etc.), following the algorithm described in the kernel's docstring and the official reference URL. +- For attention fwd: use `SDPBackend.FLASH_ATTENTION` +- For attention decode: use `SDPBackend.MATH` (flash attention does not support single-token decode) +- Never use another TileOps op as the reference +- If the op has an official PyTorch or Triton reference implementation (e.g. in the kernel's docstring `Reference:` URL), consult it when writing `ref_program` to ensure the reference matches the intended algorithm exactly — including scale factors, layout conventions, and masking logic. + +### Fallback: cosine similarity test + +If numerical error persists after exhausting all debugging steps (see `.claude/create-new-op/skill.md` debugging protocol), the `torch.allclose` assertion may be replaced with a cosine similarity check. The threshold is **0.999 and must not be relaxed**: + +```python +def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float: + a_f = a.float().flatten() + b_f = b.float().flatten() + return torch.nn.functional.cosine_similarity(a_f, b_f, dim=0).item() + +# in the test: +sim = cosine_sim(out, ref) +assert sim >= 0.999, f"cosine similarity {sim:.6f} < 0.999" +``` + +Add a comment explaining why `torch.allclose` was replaced and what debugging was attempted. + +--- + +## Example: attention fwd op + +```python +import pytest +import torch +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + +from tileops.ops import MultiHeadAttentionFwdOp + + +def ref_mha_fwd(q, k, v, is_causal): + # input layout: [batch, seqlen, heads, dim] → transpose to [batch, heads, seqlen, dim] + q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=is_causal) + return out.transpose(1, 2).contiguous() + + +@pytest.mark.parametrize("batch, seq_len, heads, dim, is_causal, dtype, tune", [ + (1, 1024, 8, 64, False, torch.float16, False), + (4, 2048, 16, 128, False, torch.bfloat16, False), + (8, 4096, 16, 128, True, torch.float16, False), + (2, 1024, 32, 64, True, torch.bfloat16, False), + (4, 2048, 16, 128, False, torch.bfloat16, True), +]) +def test_mha_fwd(batch, seq_len, heads, dim, is_causal, dtype, tune): + torch.manual_seed(42) + op = MultiHeadAttentionFwdOp(batch, heads, seq_len, dim, is_causal, dtype, tune=tune) + + q = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) + k = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) + v = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) + + with torch.no_grad(): + out = op(q, k, v) + ref = ref_mha_fwd(q, k, v, is_causal) + + assert torch.allclose(out, ref, atol=5e-3, rtol=1e-5), \ + f"max err: {(out - ref).abs().max()}" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) +``` + +--- + +## Checklist + +- [ ] File placed at `tests/ops/test_.py` +- [ ] `torch.manual_seed(42)` at the top of every test function +- [ ] At least 5 parametrized test cases +- [ ] Both `torch.float16` and `torch.bfloat16` covered (unless op is dtype-specific) +- [ ] Quantization ops include the target quantized dtype +- [ ] Structural parameters varied across cases +- [ ] At least one `tune=True` case if op supports autotuning +- [ ] `atol` / `rtol` appropriate for the op type +- [ ] Reference uses PyTorch built-ins only +- [ ] Test file ends with `if __name__ == "__main__": pytest.main([__file__, "-vvs"])` diff --git a/.claude/create-new-op/skill.md b/.claude/create-new-op/skill.md new file mode 100644 index 00000000..56a7fb5f --- /dev/null +++ b/.claude/create-new-op/skill.md @@ -0,0 +1,248 @@ +# Skill: Creating a New TileOps Op + +Reference implementations: `tileops/ops/mha.py`, `tileops/ops/gqa_decode.py`, `tileops/ops/deepseek_nsa.py` +Base class: `tileops/ops/op.py` + +--- + +## Overview + +An `Op` is a thin orchestration layer that: +1. Holds kernel instances (one or more `Kernel` subclasses) +2. Dispatches to the correct kernel based on hardware via `dispatch_kernel` +3. Exposes a `forward` method that calls the kernel(s) and handles any pre/post-processing + +An `Op` does **not** implement GPU computation itself — that lives in the `Kernel`. + +--- + +## File Location and Naming + +``` +tileops/ops/.py +``` + +- File name: `.py`, all lowercase with underscores (e.g. `gqa_decode.py`, `deepseek_nsa.py`) +- Class name: CamelCase + `Op` suffix (e.g. `GroupQueryAttentionDecodeWithKVCacheOp`) +- Group multiple related ops in one file when they share the same kernel set (e.g. `mha.py` contains both `MultiHeadAttentionFwdOp` and `MultiHeadAttentionBwdOp`) +- After creating the file, register all new classes in `tileops/ops/__init__.py` + +--- + +## File Structure + +``` +tileops/ops/.py +``` + +A single op file contains: +1. Imports +2. `__all__` declaration +3. One or more `Op` subclasses + +After creating the file, register the new class in `tileops/ops/__init__.py`. + +--- + +## Part 1: Imports + +```python +from typing import Dict, Optional + +import torch + +from tileops.kernels. import +from tileops.kernels.kernel import Kernel + +from .op import Op +``` + +Only import `is_hopper` from `tileops.utils` if the op needs to dispatch different kernels per architecture. + +--- + +## Part 2: Op Class + +### Naming + +- CamelCase, descriptive, ending in `Op` +- Examples: `MultiHeadAttentionFwdOp`, `NSATopkVarlenOp`, `GroupQueryAttentionDecodeWithKVCacheOp` + +### `__init__` — two patterns + +**Pattern A: simple (one kernel, few params)** + +Assign each param explicitly, then call `dispatch_kernel` and instantiate the kernel: + +```python +class Op(Op): + + def __init__(self, + param1: int, + param2: int, + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None) -> None: + self.param1 = param1 + self.param2 = param2 + self.dtype = dtype + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map[""]( + param1, param2, dtype, tune=tune) +``` + +**Pattern B: many params (use locals() shortcut)** + +When there are many parameters, use `locals()` to avoid repetition: + +```python + def __init__(self, + param1: int, + ..., + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None, + ) -> None: + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} + for key, value in params.items(): + setattr(self, key, value) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map[""](**params) +``` + +> `kernel_map` is always the last parameter, after `tune`. When using `locals()`, exclude both `self` and `kernel_map` from `params`; `tune` is included and passed through to the kernel. `accum_dtype` is never a parameter of the Op — it is hardcoded inside the Kernel. + +### `default_kernel_map` + +Maps string keys to `Kernel` classes. This is what `dispatch_kernel` uses to resolve the actual kernel type: + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"": } +``` + +For ops that need different kernels per architecture: + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"": if is_hopper() else } +``` + +For ops with multiple kernels (e.g. fwd + pre/post-process): + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return { + "preprocess_kernel": PreprocessKernel, + "main_kernel": MainWgmmaPipelinedKernel if is_hopper() else MainKernel, + "postprocess_kernel": PostprocessKernel if not is_hopper() else None, + } +``` + +### `forward` + +Calls `self.kernel(...)` directly (the `Kernel.__call__` delegates to `Kernel.forward`). + +```python + def forward(self, input1: torch.Tensor, ...) -> torch.Tensor: + return self.kernel(input1, ...) +``` + +For ops with pre/post-processing: + +```python + def forward(self, ...) -> tuple[torch.Tensor, ...]: + intermediate = self.prep_kernel(...) + result = self.kernel(..., intermediate) + return self.post_kernel(result) +``` + +For ops that need input validation or padding before calling the kernel (see `gqa_decode.py`): + +```python + def forward(self, q: torch.Tensor, k: torch.Tensor, ...) -> torch.Tensor: + # e.g. pad k/v to match declared seqlen + if k.shape[1] < self.seqlen_kv: + k = F.pad(k, ...) + return self.kernel(q, k, ...) +``` + +--- + +## Part 3: Register in `__init__.py` + +Add the import and the class name to `__all__` in `tileops/ops/__init__.py`: + +```python +# import +from . import Op + +# __all__ +__all__ = [ + ... + "Op", +] +``` + +--- + +## Key Points + +- `dispatch_kernel(kernel_map)` must be called before instantiating any kernel; it validates arch compatibility via `Kernel.supported_archs` +- `self.kernel` is the primary kernel; additional kernels (pre/post) are stored as `self.prep_kernel`, `self.post_kernel`, etc. +- The `Op` does not store `accum_dtype` — that is internal to the `Kernel` +- `dtype` defaults to `torch.float16` unless the op has a specific reason to require explicit dtype +- `kernel_map` parameter allows callers to inject alternative kernel implementations (for testing or custom dispatch); always default to `None` +- `tune=False` is always before `kernel_map`; `kernel_map` is always the last `__init__` parameter +- `accum_dtype` is never a parameter of the Op — it is hardcoded inside the Kernel + +--- + +## Unit Test Requirement + +Every new op **must** have a passing unit test in `tests/ops/test_.py` before it is considered complete. See `.claude/create-new-op-test/skill.md` for how to write the test. + +### Correctness debugging protocol + +When the unit test fails due to numerical mismatch, follow this process in order: + +1. **Do not loosen tolerances.** A large numerical error indicates a real bug, not a precision issue. The PyTorch reference is the ground truth. +2. **Check the algorithm first.** Compare the tilelang implementation step-by-step against the official reference URL in the kernel's docstring — scale factors, layout (BSHD vs BHSD), masking logic, softmax normalization, etc. +3. **Check memory layout and indexing.** Verify that tensor shapes, strides, and index expressions in `T.copy` / `T.gemm` match the intended layout. +4. **Isolate the stage.** If the kernel has multiple stages (e.g. QK gemm → softmax → PV gemm), add intermediate checks to identify which stage produces the wrong result. +5. **Replace tilelang primitives with scalar equivalents for debugging.** If the above steps do not reveal the bug, temporarily replace tilelang primitives with simpler equivalents to rule out compiler/primitive bugs: + + - Replace `T.Pipelined` with `T.serial` (removes software pipelining) + - Replace `T.copy(src, dst)` with a `T.Parallel` loop and element-wise assignment: + ```python + # DEBUG: replaced T.copy with explicit parallel assignment + for i, j in T.Parallel(M, N): + dst[i, j] = src[i, j] + ``` + - Replace `T.gemm(A, B, C)` with explicit `T.Serial` + `T.Parallel` accumulation: + ```python + # DEBUG: replaced T.gemm with explicit serial-parallel matmul + for k in T.serial(K): + for i, j in T.Parallel(M, N): + C[i, j] += A[i, k] * B[j, k] # adjust indices for transpose_B + ``` + - Mark all such changes with a `# DEBUG:` comment so they are easy to find and revert. +6. **Revert debug changes** once the bug is found and fixed. Do not leave `T.serial` / explicit loops in production code. + +--- + +- [ ] Class name is CamelCase and ends with `Op` +- [ ] File placed at `tileops/ops/.py` following lowercase underscore naming +- [ ] `dispatch_kernel(kernel_map)` is called before kernel instantiation +- [ ] `default_kernel_map` is implemented and returns at least one entry +- [ ] `forward` calls `self.kernel(...)`, not the raw kernel function +- [ ] Unit test at `tests/ops/test_.py` exists and passes +- [ ] Op is imported and added to `__all__` in `tileops/ops/__init__.py` +- [ ] `accum_dtype` is not stored on the Op class +- [ ] `kernel_map: Optional[Dict[str, Kernel]] = None` is the last `__init__` parameter, after `tune` +- [ ] `accum_dtype` is not a parameter of the Op diff --git a/.claude/skills/migrating-new-op/SKILL.md b/.claude/skills/migrating-new-op/SKILL.md index 7998544a..8240c3ab 100644 --- a/.claude/skills/migrating-new-op/SKILL.md +++ b/.claude/skills/migrating-new-op/SKILL.md @@ -1,162 +1,164 @@ ---- -name: migrating-new-op -description: Use when adding or migrating an operator to TileOPs with the required kernel->op->function->layer delivery path and validation checklist ---- - -## When to use - -- Add a new op into TileOPs -- Migrate an op from `cuda`/`triton` into TileOPs with TileLang kernels -- Standardize implementation and tests to the project architecture order - -______________________________________________________________________ - -## Level Confirmation (Ask First) - -When this skill is invoked, ask the user first: - -"What implementation level do you want for this operator?" - -- **L1 (Kernel only)** - - Deliverables: kernel implementation + kernel/functional correctness checks - - Typical paths: `tileops/kernels//`, minimal verification script/tests -- **L2 (Kernel + Op)** - - Deliverables: L1 + op interface + op unit tests + benchmark script - - Typical paths: `tileops/ops/.py`, `tests/ops/test_.py`, `benchmarks/benchmark_.py` -- **L3 (Kernel + Op + Function)** - - Deliverables: L2 + functional API (`torch.autograd.Function` when needed) + function tests/grad checks - - Typical paths: `tileops/functions/.py`, `tests/functions/test__func.py` -- **L4 (Full stack: Kernel + Op + Function + Layer)** - - Deliverables: L3 + `nn.Module` layer wrapper + layer tests + export synchronization - - Typical paths: `tileops/layers/.py`, `tests/layers/test__layer.py` - -If user does not specify a level, default to **L2** and state this assumption explicitly. - -After level confirmation, execute only the required layers for that level (do not over-deliver). - -______________________________________________________________________ - -## Workflow - -### Phase A: Requirement + Reference Alignment - -1. Confirm target behavior, API, and expected test scope. -1. Read the external reference implementation and extract: - - kernel stages - - input/output semantics - - accumulation and edge-case behavior - -### Phase B: Delivery Design by Layer - -1. Confirm target level (L1/L2/L3/L4) with the user before coding. -1. Plan implementation in strict order: `kernel -> op -> function -> layer`. -1. Define per-layer responsibilities and avoid cross-layer shortcuts. -1. List only level-required files/tests before coding (see checklist below). - -### Phase C: Kernel Migration (TileLang) - -1. Implement TileLang kernels by stage (same logical decomposition as reference). -1. Keep output semantics compatible with existing interfaces. -1. Handle core edge cases early (for example: empty paths / `num_topk == 0`). - -### Phase D: Upper Layer Wiring - -1. Add `op` wrapper for kernel invocation + runtime contract. -1. Add `function` API for reusable composition and shape/dtype validation. -1. Add `layer` abstraction only when module-style integration is needed. -1. Keep dependency direction one-way: `layer -> function -> op -> kernel`. - -### Phase E: Tests and Validation - -1. Ensure layer-matched tests exist: - -- `tests/ops` for op behavior -- `tests/functions` for functional integration -- `tests/layers` when a layer is introduced - -2. Add reference comparison for correctness checks. -1. Run incremental tests first, then the relevant chain. - -### Phase F: Cleanup and Finalization - -1. Remove obsolete wrappers/helpers/classes. -1. Update exports/imports (`__init__.py`) to avoid stale symbols. -1. Re-run the main regression chain and summarize results. - -______________________________________________________________________ - -## Software Organization Guidelines - -- **Kernel layer (`tileops/kernels`)**: compute logic and kernel configs only. -- **Op layer (`tileops/ops`)**: wraps kernel call contract and runtime dispatch. -- **Function layer (`tileops/functions`)**: reusable API composition and validation. -- **Layer (`tileops/layers`)**: module-level integration (`nn.Module` style). -- Keep dependency direction strict: `layer -> function -> op -> kernel`. -- After refactoring, always sync exports to prevent import-time failures. - -______________________________________________________________________ - -## Minimum Deliverables Checklist - -Before opening a PR, verify all required items are present: - -1. **L1: Kernel (`tileops/kernels`)** - -- [ ] New/updated kernel implementation exists -- [ ] Kernel handles documented edge cases - -2. **L2: Op (`tileops/ops`)** - -- [ ] Op API wraps kernel with stable argument contract -- [ ] Op-level tests added/updated in `tests/ops` -- [ ] Benchmark script added/updated in `benchmarks` - -3. **L3: Function (`tileops/functions`)** - -- [ ] Function API added/updated for composable usage -- [ ] Function tests added/updated in `tests/functions` -- [ ] `gradcheck` path is added when autograd is expected - -4. **L4: Layer (`tileops/layers`, if needed)** - -- [ ] Layer class added only when module abstraction is required -- [ ] Layer tests added/updated in `tests/layers` - -5. **Project Hygiene** - -- [ ] `__init__.py` exports are synchronized -- [ ] Relevant tests pass locally -- [ ] Migration notes / behavior deltas are documented - -______________________________________________________________________ - -## Test Flow - -### Recommended command pattern - -```bash -PYTHONPATH="$PWD" python -m pytest -v tests/ops/test_xxx.py -PYTHONPATH="$PWD" python -m pytest -v tests/functions/test_xxx_func.py -PYTHONPATH="$PWD" python -m pytest -v tests/layers/test_xxx_layer.py -``` - -### Main-chain regression - -```bash -PYTHONPATH="$PWD" python -m pytest -v \ - tests/ops/test_xxx.py \ - tests/functions/test_xxx_func.py \ - tests/layers/test_xxx_layer.py -``` - -______________________________________________________________________ - -## Done Criteria - -Migration is considered complete when: - -- kernel/op/function/layer (if needed) are delivered in correct order -- API behavior stays compatible -- stale legacy wrappers are removed -- relevant ops/functions/layers tests pass -- migration notes are documented for reuse +--- +name: migrating-new-op +description: Use when adding or migrating an operator to TileOPs with the required kernel->op->function->layer delivery path and validation checklist +--- + +## When to use + +- Add a new op into TileOPs +- Migrate an op from `cuda`/`triton` into TileOPs with TileLang kernels +- Standardize implementation and tests to the project architecture order + +______________________________________________________________________ + +## Level Confirmation (Ask First) + +When this skill is invoked, ask the user first: + +"What implementation level do you want for this operator?" + +- **L1 (Kernel only)** + - Deliverables: kernel implementation + kernel/functional correctness checks + - Typical paths: `tileops/kernels//`, minimal verification script/tests +- **L2 (Kernel + Op)** + - Deliverables: L1 + op interface + op unit tests + benchmark script + - Typical paths: `tileops/ops/.py`, `tests/ops/test_.py`, `benchmarks/benchmark_.py` +- **L3 (Kernel + Op + Function)** + - Deliverables: L2 + functional API (`torch.autograd.Function` when needed) + function tests/grad checks + - Typical paths: `tileops/functions/.py`, `tests/functions/test__func.py` +- **L4 (Full stack: Kernel + Op + Function + Layer)** + - Deliverables: L3 + `nn.Module` layer wrapper + layer tests + export synchronization + - Typical paths: `tileops/layers/.py`, `tests/layers/test__layer.py` + +If user does not specify a level, default to **L2** and state this assumption explicitly. + +After level confirmation, execute only the required layers for that level (do not over-deliver). + +______________________________________________________________________ + +## Workflow + +### Phase A: Requirement + Reference Alignment + +1. Confirm target behavior, API, and expected test scope. +1. Read the external reference implementation and extract: + - kernel stages + - input/output semantics + - accumulation and edge-case behavior + +### Phase B: Delivery Design by Layer + +1. Confirm target level (L1/L2/L3/L4) with the user before coding. +1. Plan implementation in strict order: `kernel -> op -> function -> layer`. +1. Define per-layer responsibilities and avoid cross-layer shortcuts. +1. List only level-required files/tests before coding (see checklist below). + +### Phase C: Kernel Migration (TileLang) + +1. Implement TileLang kernels by stage (same logical decomposition as reference). +1. Keep output semantics compatible with existing interfaces. +1. Handle core edge cases early (for example: empty paths / `num_topk == 0`). +1. Follow the conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md): file layout, naming, dtype rules, memory hierarchy, attention variants, wrapper/register_fake pattern, and docstring requirements. + +### Phase D: Upper Layer Wiring + +1. Add `op` wrapper for kernel invocation + runtime contract. Follow [`.claude/create-new-op/skill.md`](../create-new-op/skill.md) for Op class structure, kernel dispatch, and `__init__.py` registration. +1. Add `function` API for reusable composition and shape/dtype validation. +1. Add `layer` abstraction only when module-style integration is needed. +1. Keep dependency direction one-way: `layer -> function -> op -> kernel`. + +### Phase E: Tests and Validation + +1. Ensure layer-matched tests exist: + +- `tests/ops` for op behavior — follow [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) for test structure, parametrization, dtype coverage, and debugging protocol +- `tests/functions` for functional integration +- `tests/layers` when a layer is introduced + +2. Add reference comparison for correctness checks. +1. Run incremental tests first, then the relevant chain. + +### Phase F: Cleanup and Finalization + +1. Remove obsolete wrappers/helpers/classes. +1. Update exports/imports (`__init__.py`) to avoid stale symbols. +1. Re-run the main regression chain and summarize results. + +______________________________________________________________________ + +## Software Organization Guidelines + +- **Kernel layer (`tileops/kernels`)**: compute logic and kernel configs only. +- **Op layer (`tileops/ops`)**: wraps kernel call contract and runtime dispatch. +- **Function layer (`tileops/functions`)**: reusable API composition and validation. +- **Layer (`tileops/layers`)**: module-level integration (`nn.Module` style). +- Keep dependency direction strict: `layer -> function -> op -> kernel`. +- After refactoring, always sync exports to prevent import-time failures. + +______________________________________________________________________ + +## Minimum Deliverables Checklist + +Before opening a PR, verify all required items are present: + +1. **L1: Kernel (`tileops/kernels`)** + +- [ ] New/updated kernel implementation exists +- [ ] Kernel handles documented edge cases +- [ ] Follows conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md) + +2. **L2: Op (`tileops/ops`)** + +- [ ] Op API wraps kernel with stable argument contract +- [ ] Op-level tests added/updated in `tests/ops` — follows [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) +- [ ] Benchmark script added/updated in `benchmarks` + +3. **L3: Function (`tileops/functions`)** + +- [ ] Function API added/updated for composable usage +- [ ] Function tests added/updated in `tests/functions` +- [ ] `gradcheck` path is added when autograd is expected + +4. **L4: Layer (`tileops/layers`, if needed)** + +- [ ] Layer class added only when module abstraction is required +- [ ] Layer tests added/updated in `tests/layers` + +5. **Project Hygiene** + +- [ ] `__init__.py` exports are synchronized +- [ ] Relevant tests pass locally +- [ ] Migration notes / behavior deltas are documented + +______________________________________________________________________ + +## Test Flow + +### Recommended command pattern + +```bash +PYTHONPATH="$PWD" python -m pytest -v tests/ops/test_xxx.py +PYTHONPATH="$PWD" python -m pytest -v tests/functions/test_xxx_func.py +PYTHONPATH="$PWD" python -m pytest -v tests/layers/test_xxx_layer.py +``` + +### Main-chain regression + +```bash +PYTHONPATH="$PWD" python -m pytest -v \ + tests/ops/test_xxx.py \ + tests/functions/test_xxx_func.py \ + tests/layers/test_xxx_layer.py +``` + +______________________________________________________________________ + +## Done Criteria + +Migration is considered complete when: + +- kernel/op/function/layer (if needed) are delivered in correct order +- API behavior stays compatible +- stale legacy wrappers are removed +- relevant ops/functions/layers tests pass +- migration notes are documented for reuse diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py index 03093196..069ab281 100644 --- a/tests/ops/test_gla.py +++ b/tests/ops/test_gla.py @@ -82,8 +82,7 @@ def ref_gla_fwd( k_chunk = k[:, cs:ce] # [B, L, H, K] v_chunk = v[:, cs:ce] # [B, L, H, V] g_cs_chunk = g_cs[:, cs:ce] # [B, L, H, K] - g_last_exp = torch.exp(g_last).unsqueeze(1) # [B, 1, H, K] - k_adj = k_chunk * (g_last_exp / torch.exp(g_cs_chunk).clamp(min=1e-30)) + k_adj = k_chunk * torch.exp(g_last.unsqueeze(1) - g_cs_chunk) # b_h += einsum('blhk,blhv->bhkv', k_adj, v_chunk) b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) @@ -199,5 +198,11 @@ def test_gla_fwd( f"final_state mismatch: max err = {(out_final.float() - ref_final.float()).abs().max():.6f}" +def test_gla_fwd_non_divisible_seq_len() -> None: + """GLAFwdOp must reject seq_len not divisible by chunk_size.""" + with pytest.raises(AssertionError, match="must be divisible"): + GLAFwdOp(batch=1, seq_len=100, heads=4, dim_k=64, dim_v=64, chunk_size=64) + + if __name__ == "__main__": - pytest.main([__file__, "-vvs"]) + pytest.main([__file__, "-vvs"]) \ No newline at end of file diff --git a/tileops/kernels/gla/gla_fwd.py b/tileops/kernels/gla/gla_fwd.py index 59f4c045..11c74a45 100644 --- a/tileops/kernels/gla/gla_fwd.py +++ b/tileops/kernels/gla/gla_fwd.py @@ -401,8 +401,6 @@ def forward( # Stage 2: inter-chunk hidden state recurrence (PyTorch) # h_states[b, i_c, h, K, V] = state entering chunk i_c h_states = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=q.device) - k_f32 = k.float() - v_f32 = v.float() for i_c in range(NT): cs = i_c * BT ce = min(cs + BT, T) @@ -416,8 +414,8 @@ def forward( # Accumulate: b_h += k_adj^T @ v # k_adj[b, t, h, k] = k[b, t, h, k] * exp(g_last[b, h, k] - g_cumsum[b, t, h, k]) - k_chunk = k_f32[:, cs:ce] # [B, L, H, K] - v_chunk = v_f32[:, cs:ce] # [B, L, H, V] + k_chunk = k[:, cs:ce].float() # [B, L, H, K] + v_chunk = v[:, cs:ce].float() # [B, L, H, V] g_cs_chunk = g_cumsum[:, cs:ce] # [B, L, H, K] k_adj = k_chunk * torch.exp(g_last.unsqueeze(1) - g_cs_chunk) b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) diff --git a/tileops/ops/gla.py b/tileops/ops/gla.py index 76f02270..f51cd38d 100644 --- a/tileops/ops/gla.py +++ b/tileops/ops/gla.py @@ -51,6 +51,8 @@ def __init__( kernel_map: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() + assert seq_len % chunk_size == 0, ( + f"seq_len ({seq_len}) must be divisible by chunk_size ({chunk_size})") params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map', '__class__')} # resolve default scale before storing if params['scale'] is None: From 18d06bada6d8794617c118567e8a605afa40b33f Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Thu, 26 Feb 2026 19:35:35 +0800 Subject: [PATCH 3/8] [Feat][GLA] Fix pre-commit: mdformat skill.md files, add EOF newline Co-Authored-By: Claude Sonnet 4.6 --- .claude/create-new-kernel/skill.md | 77 +++--- .claude/create-new-op-test/skill.md | 71 ++--- .claude/create-new-op/skill.md | 43 +-- .claude/skills/migrating-new-op/SKILL.md | 328 +++++++++++------------ tests/ops/test_gla.py | 2 +- 5 files changed, 273 insertions(+), 248 deletions(-) diff --git a/.claude/create-new-kernel/skill.md b/.claude/create-new-kernel/skill.md index 78447e17..6832be21 100644 --- a/.claude/create-new-kernel/skill.md +++ b/.claude/create-new-kernel/skill.md @@ -5,15 +5,15 @@ Base class: `tileops/kernels/kernel.py` ## Development Environment -| Item | Version | -|---|---| -| GPU | H200 (sm90a) | -| CUDA | 12.9 | -| TileLang | 0.1.7.post1 | +| Item | Version | +| -------- | ------------ | +| GPU | H200 (sm90a) | +| CUDA | 12.9 | +| TileLang | 0.1.7.post1 | Target architecture is `sm90a`. All kernels are developed and validated on this environment. Do not assume compatibility with older architectures unless explicitly tested. ---- +______________________________________________________________________ ## File Location and Naming @@ -28,19 +28,19 @@ tileops/kernels//.py - Class name: CamelCase + `Kernel` suffix (e.g. `NSATopkVarlenKernel`) - Export the class from `tileops/kernels//__init__.py` ---- +______________________________________________________________________ ## File Structure A kernel file contains five parts in order: 1. Imports -2. Kernel function (tilelang implementation) -3. Wrapper function (`__wrapped_kernel`) -4. `register_fake` for the wrapper -5. `Kernel` subclass +1. Kernel function (tilelang implementation) +1. Wrapper function (`__wrapped_kernel`) +1. `register_fake` for the wrapper +1. `Kernel` subclass ---- +______________________________________________________________________ ## Part 1: Imports @@ -56,7 +56,7 @@ from tileops.kernels.kernel import Kernel Tilelang kernel implementations generally need no additional imports beyond these. ---- +______________________________________________________________________ ## Part 2: Kernel Function @@ -108,6 +108,7 @@ def __kernel( - Never reduce directly on shared memory. Typical pattern: + ``` global → T.copy → shared → T.gemm / T.copy → fragment → T.reduce_* → fragment ``` @@ -119,16 +120,16 @@ global → T.copy → shared → T.gemm / T.copy → fragment → T.reduce_* → ### Tilelang coding guidelines -| Concern | Guidance | -|---|---| -| Exponentiation | Use `T.exp2(x * LOG2_E)` instead of `T.exp(x)` for performance (`LOG2_E = 1.44269504`) | -| Data movement | Use `T.copy()` for copies; `T.gemm()` for matrix multiply | -| Reduction | Use `T.reduce_sum()` / `T.reduce_max()` on fragments (not shared memory) | -| Parallelism | Prefer `T.Parallel` over `T.Serial`; only use `T.Serial` when sequential dependency is unavoidable | -| Loop structure | Prefer `T.Pipelined` for outer loops to enable software pipelining | +| Concern | Guidance | +| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Exponentiation | Use `T.exp2(x * LOG2_E)` instead of `T.exp(x)` for performance (`LOG2_E = 1.44269504`) | +| Data movement | Use `T.copy()` for copies; `T.gemm()` for matrix multiply | +| Reduction | Use `T.reduce_sum()` / `T.reduce_max()` on fragments (not shared memory) | +| Parallelism | Prefer `T.Parallel` over `T.Serial`; only use `T.Serial` when sequential dependency is unavoidable | +| Loop structure | Prefer `T.Pipelined` for outer loops to enable software pipelining | | Bug workaround | Tilelang may have bugs with nested `T.Parallel` / `T.Serial` / `T.Pipelined`. If compilation or runtime errors occur, try reordering the nesting levels as a first debugging step. | ---- +______________________________________________________________________ ## Part 3: Wrapper Function @@ -159,7 +160,7 @@ Note: `accum_dtype` is hardcoded inside `__kernel` and is not a par Provides a shape/dtype inference rule for tracing (no actual computation): ```python -@__wrapped_kernel.register_fake +@_ < kernel_name > _wrapped_kernel.register_fake def _( param1: int, dtype: str, @@ -178,7 +179,7 @@ def _( > The output shape here must exactly match the real kernel's output. Derive it from the scalar parameters (e.g. `c_seq_len`, `heads`, `selected_block_num`), not from input tensor shapes. ---- +______________________________________________________________________ ## Part 4: Kernel Class @@ -189,8 +190,8 @@ def _( The class **must** include a docstring with three sections: 1. Input tensor shapes — list each tensor parameter with its layout (e.g. `[batch, seqlen, heads, dim]`) -2. Computation logic — a brief description of what the kernel computes -3. Reference — URL to the official PyTorch / Triton / paper implementation this is based on +1. Computation logic — a brief description of what the kernel computes +1. Reference — URL to the official PyTorch / Triton / paper implementation this is based on Example: @@ -249,12 +250,13 @@ class Kernel(Kernel): ``` Key points: + - `forward` calls `__wrapped_kernel` (the registered wrapper), not the raw kernel function directly - `init_config(config, tune)` must be the last call in `__init__`; it reads `default_config` and `autotune_configs` - `accum_dtype` is not stored on the class; it is hardcoded inside the kernel function - Only cast index/offset tensors (e.g. `offsets`, `token_indices`) to `torch.int32` in `forward`; do NOT cast floating-point tensors ---- +______________________________________________________________________ ## Attention Kernel Conventions @@ -264,11 +266,11 @@ Attention kernels have additional conventions beyond the general rules above. Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: -| Variant | File suffix | Class suffix | Description | -|---|---|---|---| -| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | -| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | -| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | +| Variant | File suffix | Class suffix | Description | +| -------- | ------------ | -------------- | --------------------------------------------- | +| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | +| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | +| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` @@ -358,13 +360,22 @@ The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers ```python def forward(self, Q, K, V, real_seqlen_kv): glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) - Output_partial = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) - return __decode_wrapped_kernel(..., self.config["num_split"], Q, K, V, glse, Output_partial) + Output_partial = torch.empty( + (..., self.config["num_split"], ...), dtype=..., device=Q.device + ) + return ( + _ + < attn_name + > _decode_wrapped_kernel( + ..., self.config["num_split"], Q, K, V, glse, Output_partial + ) + ) ``` ### Decode kernel: paged KV variant For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: + - Adds `page_size: int` to the outermost kernel function - Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` - Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool diff --git a/.claude/create-new-op-test/skill.md b/.claude/create-new-op-test/skill.md index 34c5d05f..3fe5d8e5 100644 --- a/.claude/create-new-op-test/skill.md +++ b/.claude/create-new-op-test/skill.md @@ -2,20 +2,20 @@ Reference tests: `tests/ops/` ---- +______________________________________________________________________ ## Overview Each op has a pytest file at `tests/ops/test_.py`. The test: 1. Instantiates the `Op` with given parameters -2. Generates random input tensors -3. Runs the op and a PyTorch reference implementation -4. Asserts numerical closeness with `torch.allclose` +1. Generates random input tensors +1. Runs the op and a PyTorch reference implementation +1. Asserts numerical closeness with `torch.allclose` No benchmark infrastructure is needed — correctness only. ---- +______________________________________________________________________ ## File Location and Naming @@ -25,7 +25,7 @@ tests/ops/test_.py File name mirrors the op module name (e.g. `tileops/ops/gqa_decode.py` → `tests/ops/test_gqa_decode.py`). ---- +______________________________________________________________________ ## Test File Structure @@ -71,7 +71,7 @@ if __name__ == "__main__": pytest.main([__file__, "-vvs"]) ``` ---- +______________________________________________________________________ ## Test Case Requirements @@ -83,11 +83,11 @@ if __name__ == "__main__": ### dtype coverage -| Op type | Required dtypes | -|---|---| -| General ops | `torch.float16`, `torch.bfloat16` | +| Op type | Required dtypes | +| --------------------------------- | --------------------------------------------------------------- | +| General ops | `torch.float16`, `torch.bfloat16` | | Quantization ops (fp8, fp4, etc.) | Include the target quantized dtype (e.g. `torch.float8_e4m3fn`) | -| Mixed-precision ops | Cover all relevant input/output dtype combinations | +| Mixed-precision ops | Cover all relevant input/output dtype combinations | ### Random seed @@ -95,12 +95,12 @@ Fix `torch.manual_seed(42)` at the top of every test function, before any tensor ### Tolerance guidelines -| Op type | `atol` | `rtol` | -|---|---|---| -| Attention fwd (fp16/bf16) | `5e-3` | `1e-5` | -| Attention decode | `1e-2` | `1e-2` | -| GEMM / linear | `1e-3` | `1e-3` | -| Elementwise / quantization | `1e-1` or custom | — | +| Op type | `atol` | `rtol` | +| -------------------------- | ---------------- | ------ | +| Attention fwd (fp16/bf16) | `5e-3` | `1e-5` | +| Attention decode | `1e-2` | `1e-2` | +| GEMM / linear | `1e-3` | `1e-3` | +| Elementwise / quantization | `1e-1` or custom | — | ### Reference implementation @@ -121,6 +121,7 @@ def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float: b_f = b.float().flatten() return torch.nn.functional.cosine_similarity(a_f, b_f, dim=0).item() + # in the test: sim = cosine_sim(out, ref) assert sim >= 0.999, f"cosine similarity {sim:.6f} < 0.999" @@ -128,7 +129,7 @@ assert sim >= 0.999, f"cosine similarity {sim:.6f} < 0.999" Add a comment explaining why `torch.allclose` was replaced and what debugging was attempted. ---- +______________________________________________________________________ ## Example: attention fwd op @@ -149,34 +150,40 @@ def ref_mha_fwd(q, k, v, is_causal): return out.transpose(1, 2).contiguous() -@pytest.mark.parametrize("batch, seq_len, heads, dim, is_causal, dtype, tune", [ - (1, 1024, 8, 64, False, torch.float16, False), - (4, 2048, 16, 128, False, torch.bfloat16, False), - (8, 4096, 16, 128, True, torch.float16, False), - (2, 1024, 32, 64, True, torch.bfloat16, False), - (4, 2048, 16, 128, False, torch.bfloat16, True), -]) +@pytest.mark.parametrize( + "batch, seq_len, heads, dim, is_causal, dtype, tune", + [ + (1, 1024, 8, 64, False, torch.float16, False), + (4, 2048, 16, 128, False, torch.bfloat16, False), + (8, 4096, 16, 128, True, torch.float16, False), + (2, 1024, 32, 64, True, torch.bfloat16, False), + (4, 2048, 16, 128, False, torch.bfloat16, True), + ], +) def test_mha_fwd(batch, seq_len, heads, dim, is_causal, dtype, tune): torch.manual_seed(42) - op = MultiHeadAttentionFwdOp(batch, heads, seq_len, dim, is_causal, dtype, tune=tune) + op = MultiHeadAttentionFwdOp( + batch, heads, seq_len, dim, is_causal, dtype, tune=tune + ) - q = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) - k = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) - v = torch.randn(batch, seq_len, heads, dim, device='cuda', dtype=dtype) + q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) + k = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) + v = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) with torch.no_grad(): out = op(q, k, v) ref = ref_mha_fwd(q, k, v, is_causal) - assert torch.allclose(out, ref, atol=5e-3, rtol=1e-5), \ - f"max err: {(out - ref).abs().max()}" + assert torch.allclose( + out, ref, atol=5e-3, rtol=1e-5 + ), f"max err: {(out - ref).abs().max()}" if __name__ == "__main__": pytest.main([__file__, "-vvs"]) ``` ---- +______________________________________________________________________ ## Checklist diff --git a/.claude/create-new-op/skill.md b/.claude/create-new-op/skill.md index 56a7fb5f..c4f5fb2d 100644 --- a/.claude/create-new-op/skill.md +++ b/.claude/create-new-op/skill.md @@ -3,18 +3,19 @@ Reference implementations: `tileops/ops/mha.py`, `tileops/ops/gqa_decode.py`, `tileops/ops/deepseek_nsa.py` Base class: `tileops/ops/op.py` ---- +______________________________________________________________________ ## Overview An `Op` is a thin orchestration layer that: + 1. Holds kernel instances (one or more `Kernel` subclasses) -2. Dispatches to the correct kernel based on hardware via `dispatch_kernel` -3. Exposes a `forward` method that calls the kernel(s) and handles any pre/post-processing +1. Dispatches to the correct kernel based on hardware via `dispatch_kernel` +1. Exposes a `forward` method that calls the kernel(s) and handles any pre/post-processing An `Op` does **not** implement GPU computation itself — that lives in the `Kernel`. ---- +______________________________________________________________________ ## File Location and Naming @@ -27,7 +28,7 @@ tileops/ops/.py - Group multiple related ops in one file when they share the same kernel set (e.g. `mha.py` contains both `MultiHeadAttentionFwdOp` and `MultiHeadAttentionBwdOp`) - After creating the file, register all new classes in `tileops/ops/__init__.py` ---- +______________________________________________________________________ ## File Structure @@ -36,13 +37,14 @@ tileops/ops/.py ``` A single op file contains: + 1. Imports -2. `__all__` declaration -3. One or more `Op` subclasses +1. `__all__` declaration +1. One or more `Op` subclasses After creating the file, register the new class in `tileops/ops/__init__.py`. ---- +______________________________________________________________________ ## Part 1: Imports @@ -59,7 +61,7 @@ from .op import Op Only import `is_hopper` from `tileops.utils` if the op needs to dispatch different kernels per architecture. ---- +______________________________________________________________________ ## Part 2: Op Class @@ -172,7 +174,7 @@ For ops that need input validation or padding before calling the kernel (see `gq return self.kernel(q, k, ...) ``` ---- +______________________________________________________________________ ## Part 3: Register in `__init__.py` @@ -189,7 +191,7 @@ __all__ = [ ] ``` ---- +______________________________________________________________________ ## Key Points @@ -201,7 +203,7 @@ __all__ = [ - `tune=False` is always before `kernel_map`; `kernel_map` is always the last `__init__` parameter - `accum_dtype` is never a parameter of the Op — it is hardcoded inside the Kernel ---- +______________________________________________________________________ ## Unit Test Requirement @@ -212,10 +214,14 @@ Every new op **must** have a passing unit test in `tests/ops/test_.py` When the unit test fails due to numerical mismatch, follow this process in order: 1. **Do not loosen tolerances.** A large numerical error indicates a real bug, not a precision issue. The PyTorch reference is the ground truth. -2. **Check the algorithm first.** Compare the tilelang implementation step-by-step against the official reference URL in the kernel's docstring — scale factors, layout (BSHD vs BHSD), masking logic, softmax normalization, etc. -3. **Check memory layout and indexing.** Verify that tensor shapes, strides, and index expressions in `T.copy` / `T.gemm` match the intended layout. -4. **Isolate the stage.** If the kernel has multiple stages (e.g. QK gemm → softmax → PV gemm), add intermediate checks to identify which stage produces the wrong result. -5. **Replace tilelang primitives with scalar equivalents for debugging.** If the above steps do not reveal the bug, temporarily replace tilelang primitives with simpler equivalents to rule out compiler/primitive bugs: + +1. **Check the algorithm first.** Compare the tilelang implementation step-by-step against the official reference URL in the kernel's docstring — scale factors, layout (BSHD vs BHSD), masking logic, softmax normalization, etc. + +1. **Check memory layout and indexing.** Verify that tensor shapes, strides, and index expressions in `T.copy` / `T.gemm` match the intended layout. + +1. **Isolate the stage.** If the kernel has multiple stages (e.g. QK gemm → softmax → PV gemm), add intermediate checks to identify which stage produces the wrong result. + +1. **Replace tilelang primitives with scalar equivalents for debugging.** If the above steps do not reveal the bug, temporarily replace tilelang primitives with simpler equivalents to rule out compiler/primitive bugs: - Replace `T.Pipelined` with `T.serial` (removes software pipelining) - Replace `T.copy(src, dst)` with a `T.Parallel` loop and element-wise assignment: @@ -232,9 +238,10 @@ When the unit test fails due to numerical mismatch, follow this process in order C[i, j] += A[i, k] * B[j, k] # adjust indices for transpose_B ``` - Mark all such changes with a `# DEBUG:` comment so they are easy to find and revert. -6. **Revert debug changes** once the bug is found and fixed. Do not leave `T.serial` / explicit loops in production code. ---- +1. **Revert debug changes** once the bug is found and fixed. Do not leave `T.serial` / explicit loops in production code. + +______________________________________________________________________ - [ ] Class name is CamelCase and ends with `Op` - [ ] File placed at `tileops/ops/.py` following lowercase underscore naming diff --git a/.claude/skills/migrating-new-op/SKILL.md b/.claude/skills/migrating-new-op/SKILL.md index 8240c3ab..650c5238 100644 --- a/.claude/skills/migrating-new-op/SKILL.md +++ b/.claude/skills/migrating-new-op/SKILL.md @@ -1,164 +1,164 @@ ---- -name: migrating-new-op -description: Use when adding or migrating an operator to TileOPs with the required kernel->op->function->layer delivery path and validation checklist ---- - -## When to use - -- Add a new op into TileOPs -- Migrate an op from `cuda`/`triton` into TileOPs with TileLang kernels -- Standardize implementation and tests to the project architecture order - -______________________________________________________________________ - -## Level Confirmation (Ask First) - -When this skill is invoked, ask the user first: - -"What implementation level do you want for this operator?" - -- **L1 (Kernel only)** - - Deliverables: kernel implementation + kernel/functional correctness checks - - Typical paths: `tileops/kernels//`, minimal verification script/tests -- **L2 (Kernel + Op)** - - Deliverables: L1 + op interface + op unit tests + benchmark script - - Typical paths: `tileops/ops/.py`, `tests/ops/test_.py`, `benchmarks/benchmark_.py` -- **L3 (Kernel + Op + Function)** - - Deliverables: L2 + functional API (`torch.autograd.Function` when needed) + function tests/grad checks - - Typical paths: `tileops/functions/.py`, `tests/functions/test__func.py` -- **L4 (Full stack: Kernel + Op + Function + Layer)** - - Deliverables: L3 + `nn.Module` layer wrapper + layer tests + export synchronization - - Typical paths: `tileops/layers/.py`, `tests/layers/test__layer.py` - -If user does not specify a level, default to **L2** and state this assumption explicitly. - -After level confirmation, execute only the required layers for that level (do not over-deliver). - -______________________________________________________________________ - -## Workflow - -### Phase A: Requirement + Reference Alignment - -1. Confirm target behavior, API, and expected test scope. -1. Read the external reference implementation and extract: - - kernel stages - - input/output semantics - - accumulation and edge-case behavior - -### Phase B: Delivery Design by Layer - -1. Confirm target level (L1/L2/L3/L4) with the user before coding. -1. Plan implementation in strict order: `kernel -> op -> function -> layer`. -1. Define per-layer responsibilities and avoid cross-layer shortcuts. -1. List only level-required files/tests before coding (see checklist below). - -### Phase C: Kernel Migration (TileLang) - -1. Implement TileLang kernels by stage (same logical decomposition as reference). -1. Keep output semantics compatible with existing interfaces. -1. Handle core edge cases early (for example: empty paths / `num_topk == 0`). -1. Follow the conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md): file layout, naming, dtype rules, memory hierarchy, attention variants, wrapper/register_fake pattern, and docstring requirements. - -### Phase D: Upper Layer Wiring - -1. Add `op` wrapper for kernel invocation + runtime contract. Follow [`.claude/create-new-op/skill.md`](../create-new-op/skill.md) for Op class structure, kernel dispatch, and `__init__.py` registration. -1. Add `function` API for reusable composition and shape/dtype validation. -1. Add `layer` abstraction only when module-style integration is needed. -1. Keep dependency direction one-way: `layer -> function -> op -> kernel`. - -### Phase E: Tests and Validation - -1. Ensure layer-matched tests exist: - -- `tests/ops` for op behavior — follow [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) for test structure, parametrization, dtype coverage, and debugging protocol -- `tests/functions` for functional integration -- `tests/layers` when a layer is introduced - -2. Add reference comparison for correctness checks. -1. Run incremental tests first, then the relevant chain. - -### Phase F: Cleanup and Finalization - -1. Remove obsolete wrappers/helpers/classes. -1. Update exports/imports (`__init__.py`) to avoid stale symbols. -1. Re-run the main regression chain and summarize results. - -______________________________________________________________________ - -## Software Organization Guidelines - -- **Kernel layer (`tileops/kernels`)**: compute logic and kernel configs only. -- **Op layer (`tileops/ops`)**: wraps kernel call contract and runtime dispatch. -- **Function layer (`tileops/functions`)**: reusable API composition and validation. -- **Layer (`tileops/layers`)**: module-level integration (`nn.Module` style). -- Keep dependency direction strict: `layer -> function -> op -> kernel`. -- After refactoring, always sync exports to prevent import-time failures. - -______________________________________________________________________ - -## Minimum Deliverables Checklist - -Before opening a PR, verify all required items are present: - -1. **L1: Kernel (`tileops/kernels`)** - -- [ ] New/updated kernel implementation exists -- [ ] Kernel handles documented edge cases -- [ ] Follows conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md) - -2. **L2: Op (`tileops/ops`)** - -- [ ] Op API wraps kernel with stable argument contract -- [ ] Op-level tests added/updated in `tests/ops` — follows [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) -- [ ] Benchmark script added/updated in `benchmarks` - -3. **L3: Function (`tileops/functions`)** - -- [ ] Function API added/updated for composable usage -- [ ] Function tests added/updated in `tests/functions` -- [ ] `gradcheck` path is added when autograd is expected - -4. **L4: Layer (`tileops/layers`, if needed)** - -- [ ] Layer class added only when module abstraction is required -- [ ] Layer tests added/updated in `tests/layers` - -5. **Project Hygiene** - -- [ ] `__init__.py` exports are synchronized -- [ ] Relevant tests pass locally -- [ ] Migration notes / behavior deltas are documented - -______________________________________________________________________ - -## Test Flow - -### Recommended command pattern - -```bash -PYTHONPATH="$PWD" python -m pytest -v tests/ops/test_xxx.py -PYTHONPATH="$PWD" python -m pytest -v tests/functions/test_xxx_func.py -PYTHONPATH="$PWD" python -m pytest -v tests/layers/test_xxx_layer.py -``` - -### Main-chain regression - -```bash -PYTHONPATH="$PWD" python -m pytest -v \ - tests/ops/test_xxx.py \ - tests/functions/test_xxx_func.py \ - tests/layers/test_xxx_layer.py -``` - -______________________________________________________________________ - -## Done Criteria - -Migration is considered complete when: - -- kernel/op/function/layer (if needed) are delivered in correct order -- API behavior stays compatible -- stale legacy wrappers are removed -- relevant ops/functions/layers tests pass -- migration notes are documented for reuse +--- +name: migrating-new-op +description: Use when adding or migrating an operator to TileOPs with the required kernel->op->function->layer delivery path and validation checklist +--- + +## When to use + +- Add a new op into TileOPs +- Migrate an op from `cuda`/`triton` into TileOPs with TileLang kernels +- Standardize implementation and tests to the project architecture order + +______________________________________________________________________ + +## Level Confirmation (Ask First) + +When this skill is invoked, ask the user first: + +"What implementation level do you want for this operator?" + +- **L1 (Kernel only)** + - Deliverables: kernel implementation + kernel/functional correctness checks + - Typical paths: `tileops/kernels//`, minimal verification script/tests +- **L2 (Kernel + Op)** + - Deliverables: L1 + op interface + op unit tests + benchmark script + - Typical paths: `tileops/ops/.py`, `tests/ops/test_.py`, `benchmarks/benchmark_.py` +- **L3 (Kernel + Op + Function)** + - Deliverables: L2 + functional API (`torch.autograd.Function` when needed) + function tests/grad checks + - Typical paths: `tileops/functions/.py`, `tests/functions/test__func.py` +- **L4 (Full stack: Kernel + Op + Function + Layer)** + - Deliverables: L3 + `nn.Module` layer wrapper + layer tests + export synchronization + - Typical paths: `tileops/layers/.py`, `tests/layers/test__layer.py` + +If user does not specify a level, default to **L2** and state this assumption explicitly. + +After level confirmation, execute only the required layers for that level (do not over-deliver). + +______________________________________________________________________ + +## Workflow + +### Phase A: Requirement + Reference Alignment + +1. Confirm target behavior, API, and expected test scope. +1. Read the external reference implementation and extract: + - kernel stages + - input/output semantics + - accumulation and edge-case behavior + +### Phase B: Delivery Design by Layer + +1. Confirm target level (L1/L2/L3/L4) with the user before coding. +1. Plan implementation in strict order: `kernel -> op -> function -> layer`. +1. Define per-layer responsibilities and avoid cross-layer shortcuts. +1. List only level-required files/tests before coding (see checklist below). + +### Phase C: Kernel Migration (TileLang) + +1. Implement TileLang kernels by stage (same logical decomposition as reference). +1. Keep output semantics compatible with existing interfaces. +1. Handle core edge cases early (for example: empty paths / `num_topk == 0`). +1. Follow the conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md): file layout, naming, dtype rules, memory hierarchy, attention variants, wrapper/register_fake pattern, and docstring requirements. + +### Phase D: Upper Layer Wiring + +1. Add `op` wrapper for kernel invocation + runtime contract. Follow [`.claude/create-new-op/skill.md`](../create-new-op/skill.md) for Op class structure, kernel dispatch, and `__init__.py` registration. +1. Add `function` API for reusable composition and shape/dtype validation. +1. Add `layer` abstraction only when module-style integration is needed. +1. Keep dependency direction one-way: `layer -> function -> op -> kernel`. + +### Phase E: Tests and Validation + +1. Ensure layer-matched tests exist: + +- `tests/ops` for op behavior — follow [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) for test structure, parametrization, dtype coverage, and debugging protocol +- `tests/functions` for functional integration +- `tests/layers` when a layer is introduced + +2. Add reference comparison for correctness checks. +1. Run incremental tests first, then the relevant chain. + +### Phase F: Cleanup and Finalization + +1. Remove obsolete wrappers/helpers/classes. +1. Update exports/imports (`__init__.py`) to avoid stale symbols. +1. Re-run the main regression chain and summarize results. + +______________________________________________________________________ + +## Software Organization Guidelines + +- **Kernel layer (`tileops/kernels`)**: compute logic and kernel configs only. +- **Op layer (`tileops/ops`)**: wraps kernel call contract and runtime dispatch. +- **Function layer (`tileops/functions`)**: reusable API composition and validation. +- **Layer (`tileops/layers`)**: module-level integration (`nn.Module` style). +- Keep dependency direction strict: `layer -> function -> op -> kernel`. +- After refactoring, always sync exports to prevent import-time failures. + +______________________________________________________________________ + +## Minimum Deliverables Checklist + +Before opening a PR, verify all required items are present: + +1. **L1: Kernel (`tileops/kernels`)** + +- [ ] New/updated kernel implementation exists +- [ ] Kernel handles documented edge cases +- [ ] Follows conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md) + +2. **L2: Op (`tileops/ops`)** + +- [ ] Op API wraps kernel with stable argument contract +- [ ] Op-level tests added/updated in `tests/ops` — follows [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) +- [ ] Benchmark script added/updated in `benchmarks` + +3. **L3: Function (`tileops/functions`)** + +- [ ] Function API added/updated for composable usage +- [ ] Function tests added/updated in `tests/functions` +- [ ] `gradcheck` path is added when autograd is expected + +4. **L4: Layer (`tileops/layers`, if needed)** + +- [ ] Layer class added only when module abstraction is required +- [ ] Layer tests added/updated in `tests/layers` + +5. **Project Hygiene** + +- [ ] `__init__.py` exports are synchronized +- [ ] Relevant tests pass locally +- [ ] Migration notes / behavior deltas are documented + +______________________________________________________________________ + +## Test Flow + +### Recommended command pattern + +```bash +PYTHONPATH="$PWD" python -m pytest -v tests/ops/test_xxx.py +PYTHONPATH="$PWD" python -m pytest -v tests/functions/test_xxx_func.py +PYTHONPATH="$PWD" python -m pytest -v tests/layers/test_xxx_layer.py +``` + +### Main-chain regression + +```bash +PYTHONPATH="$PWD" python -m pytest -v \ + tests/ops/test_xxx.py \ + tests/functions/test_xxx_func.py \ + tests/layers/test_xxx_layer.py +``` + +______________________________________________________________________ + +## Done Criteria + +Migration is considered complete when: + +- kernel/op/function/layer (if needed) are delivered in correct order +- API behavior stays compatible +- stale legacy wrappers are removed +- relevant ops/functions/layers tests pass +- migration notes are documented for reuse diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py index 069ab281..fffc1d66 100644 --- a/tests/ops/test_gla.py +++ b/tests/ops/test_gla.py @@ -205,4 +205,4 @@ def test_gla_fwd_non_divisible_seq_len() -> None: if __name__ == "__main__": - pytest.main([__file__, "-vvs"]) \ No newline at end of file + pytest.main([__file__, "-vvs"]) From cf9b43a702d54f751e74da2db0030d883f9f639a Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Fri, 27 Feb 2026 11:28:33 +0800 Subject: [PATCH 4/8] [Chore][Skill] Split attention conventions into create-new-op-attention skill, add YAML frontmatter and auto-invoke to all skills Co-Authored-By: Claude Sonnet 4.6 --- .claude/create-new-kernel/skill.md | 136 +-------------------- .claude/create-new-op-attention/skill.md | 144 +++++++++++++++++++++++ .claude/create-new-op-test/skill.md | 5 + .claude/create-new-op/skill.md | 5 + 4 files changed, 159 insertions(+), 131 deletions(-) create mode 100644 .claude/create-new-op-attention/skill.md diff --git a/.claude/create-new-kernel/skill.md b/.claude/create-new-kernel/skill.md index 6832be21..eaf21561 100644 --- a/.claude/create-new-kernel/skill.md +++ b/.claude/create-new-kernel/skill.md @@ -1,3 +1,8 @@ +--- +name: create-new-kernel +description: Create a new TileOps GPU kernel — file layout, tilelang kernel function, wrapper registration, and Kernel subclass. For attention-specific conventions (variants, causal masking, split-K decode, paged KV), also use the create-new-op and create-new-op-attention skills. Auto-invoke when the user asks to create, implement, or add a new kernel in TileOps. +--- + # Skill: Creating a New TileOps Kernel Reference implementation: `tileops/kernels/deepseek_nsa/nsa_topk.py` @@ -256,133 +261,6 @@ Key points: - `accum_dtype` is not stored on the class; it is hardcoded inside the kernel function - Only cast index/offset tensors (e.g. `offsets`, `token_indices`) to `torch.int32` in `forward`; do NOT cast floating-point tensors -______________________________________________________________________ - -## Attention Kernel Conventions - -Attention kernels have additional conventions beyond the general rules above. - -### Kernel variants - -Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: - -| Variant | File suffix | Class suffix | Description | -| -------- | ------------ | -------------- | --------------------------------------------- | -| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | -| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | -| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | - -Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` - -### `causal` parameter - -All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. - -```python -def __kernel( - batch: int, - heads: int, - seqlen_q: int, - seqlen_kv: int, - dim: int, - is_causal: bool, # required for all attention kernels - dtype: str, -) -> Callable: - ... -``` - -### Decode kernel: split-K design - -Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: - -- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) -- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) - -Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. - -`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. - -Structure inside `__func`: - -```python -def __decode_func(block_M, block_N, num_split, num_stages, threads): - - @T.macro - def __no_split(Q, K, V, real_seqlen_kv, Output): - # single-pass attention over full KV - ... - - @T.macro - def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): - # attention over one KV chunk; writes partial output + log-sum-exp - ... - - @T.macro - def combine(glse, Output_partial, Output): - # merge partial outputs using LSE rescaling - ... - - @T.prim_func - def _decode_no_split(Q, K, V, real_seqlen_kv, Output): - __no_split(Q, K, V, real_seqlen_kv, Output) - - @T.prim_func - def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): - __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) - combine(glse, Output_partial, Output) - - if num_split > 1: - return _decode_split - else: - return _decode_no_split -``` - -The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: - -```python -@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) -def __decode_wrapped_kernel( - ..., num_split: int, - Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - glse: torch.Tensor, Output_partial: torch.Tensor, -) -> torch.Tensor: - split_length = ... # compute per-split chunk sizes - if split_length[0] == 0: - num_split = 1 - if num_split == 1: - return __decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) - return __decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( - Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) -``` - -The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: - -```python -def forward(self, Q, K, V, real_seqlen_kv): - glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) - Output_partial = torch.empty( - (..., self.config["num_split"], ...), dtype=..., device=Q.device - ) - return ( - _ - < attn_name - > _decode_wrapped_kernel( - ..., self.config["num_split"], Q, K, V, glse, Output_partial - ) - ) -``` - -### Decode kernel: paged KV variant - -For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: - -- Adds `page_size: int` to the outermost kernel function -- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` -- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool -- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar - -Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` - - [ ] `out_idx` in `@tilelang.jit` points to the correct output tensor position - [ ] `register_fake` output shape matches the real kernel output - [ ] `custom_op` name is unique: `"top::_wrapped_kernel"` @@ -397,7 +275,3 @@ Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` - [ ] Global memory is copied to shared memory before computation - [ ] Reductions operate on fragments, not shared memory - [ ] `@T.prim_func` is kept minimal; complex logic is in `@T.macro` helpers -- [ ] (Attention only) `is_causal: bool` is a parameter of the outermost kernel function -- [ ] (Attention only) File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` -- [ ] (Decode only) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` -- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` diff --git a/.claude/create-new-op-attention/skill.md b/.claude/create-new-op-attention/skill.md new file mode 100644 index 00000000..de1f7022 --- /dev/null +++ b/.claude/create-new-op-attention/skill.md @@ -0,0 +1,144 @@ +--- +name: create-new-op-attention +description: Attention-specific conventions for creating a new TileOps attention kernel — variants (fwd/decode/bwd), causal masking, split-K decode design, and paged KV cache. Use together with create-new-kernel (kernel structure) and create-new-op (op registration). Auto-invoke when the user asks to create or implement an attention kernel (MHA, GQA, MLA, flash attention, decode attention, etc.) in TileOps. +--- + +# Skill: Attention Kernel Conventions + +Attention kernels have additional conventions beyond the general rules in `create-new-kernel` and `create-new-op`. + +______________________________________________________________________ + +## Kernel variants + +Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: + +| Variant | File suffix | Class suffix | Description | +| -------- | ------------ | -------------- | --------------------------------------------- | +| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | +| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | +| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | + +Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` + +______________________________________________________________________ + +## `causal` parameter + +All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. + +```python +def __kernel( + batch: int, + heads: int, + seqlen_q: int, + seqlen_kv: int, + dim: int, + is_causal: bool, # required for all attention kernels + dtype: str, +) -> Callable: + ... +``` + +______________________________________________________________________ + +## Decode kernel: split-K design + +Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: + +- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) +- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) + +Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. + +`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. + +Structure inside `__func`: + +```python +def __decode_func(block_M, block_N, num_split, num_stages, threads): + + @T.macro + def __no_split(Q, K, V, real_seqlen_kv, Output): + # single-pass attention over full KV + ... + + @T.macro + def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): + # attention over one KV chunk; writes partial output + log-sum-exp + ... + + @T.macro + def combine(glse, Output_partial, Output): + # merge partial outputs using LSE rescaling + ... + + @T.prim_func + def _decode_no_split(Q, K, V, real_seqlen_kv, Output): + __no_split(Q, K, V, real_seqlen_kv, Output) + + @T.prim_func + def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): + __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) + combine(glse, Output_partial, Output) + + if num_split > 1: + return _decode_split + else: + return _decode_no_split +``` + +The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: + +```python +@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) +def __decode_wrapped_kernel( + ..., num_split: int, + Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + glse: torch.Tensor, Output_partial: torch.Tensor, +) -> torch.Tensor: + split_length = ... # compute per-split chunk sizes + if split_length[0] == 0: + num_split = 1 + if num_split == 1: + return __decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) + return __decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( + Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) +``` + +The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: + +```python +def forward(self, Q, K, V, real_seqlen_kv): + glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) + Output_partial = torch.empty( + (..., self.config["num_split"], ...), dtype=..., device=Q.device + ) + return __decode_wrapped_kernel( + ..., self.config["num_split"], Q, K, V, glse, Output_partial + ) +``` + +______________________________________________________________________ + +## Decode kernel: paged KV variant + +For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: + +- Adds `page_size: int` to the outermost kernel function +- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` +- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool +- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar + +Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` + +______________________________________________________________________ + +## Checklist + +- [ ] File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` +- [ ] `is_causal: bool` is a parameter of the outermost kernel function +- [ ] (Decode) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` and `autotune_configs` +- [ ] (Decode) Wrapper computes `split_length` and dispatches to the correct path +- [ ] (Decode) `Kernel.forward` allocates `glse` and `Output_partial` before calling the wrapper +- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` diff --git a/.claude/create-new-op-test/skill.md b/.claude/create-new-op-test/skill.md index 3fe5d8e5..336dfe3f 100644 --- a/.claude/create-new-op-test/skill.md +++ b/.claude/create-new-op-test/skill.md @@ -1,3 +1,8 @@ +--- +name: create-new-op-test +description: Write correctness unit tests for a new TileOps Op — pytest structure, parametrization, dtype coverage, tolerance guidelines, reference implementation, and debugging protocol. Auto-invoke when the user asks to write, add, or update tests for a TileOps op. +--- + # Skill: Writing Op Correctness Unit Tests Reference tests: `tests/ops/` diff --git a/.claude/create-new-op/skill.md b/.claude/create-new-op/skill.md index c4f5fb2d..9e06a390 100644 --- a/.claude/create-new-op/skill.md +++ b/.claude/create-new-op/skill.md @@ -1,3 +1,8 @@ +--- +name: create-new-op +description: Create a new TileOps Op — Op class structure, kernel dispatch via dispatch_kernel, forward method, and __init__.py registration. For attention-specific conventions, also use create-new-op-attention. Auto-invoke when the user asks to create, implement, or add a new Op in TileOps. +--- + # Skill: Creating a New TileOps Op Reference implementations: `tileops/ops/mha.py`, `tileops/ops/gqa_decode.py`, `tileops/ops/deepseek_nsa.py` From af6441d8d521e5b5eaf6b183406abb5d4e0aa656 Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Fri, 27 Feb 2026 11:38:47 +0800 Subject: [PATCH 5/8] [Chore][Skill] Fix mdformat formatting in skill.md files Co-Authored-By: Claude Sonnet 4.6 --- .claude/create-new-kernel/skill.md | 16 ++ .claude/create-new-op-attention/skill.md | 292 ++++++++++++----------- 2 files changed, 164 insertions(+), 144 deletions(-) diff --git a/.claude/create-new-kernel/skill.md b/.claude/create-new-kernel/skill.md index eaf21561..b41fdb4e 100644 --- a/.claude/create-new-kernel/skill.md +++ b/.claude/create-new-kernel/skill.md @@ -257,21 +257,37 @@ class Kernel(Kernel): Key points: - `forward` calls `__wrapped_kernel` (the registered wrapper), not the raw kernel function directly + - `init_config(config, tune)` must be the last call in `__init__`; it reads `default_config` and `autotune_configs` + - `accum_dtype` is not stored on the class; it is hardcoded inside the kernel function + - Only cast index/offset tensors (e.g. `offsets`, `token_indices`) to `torch.int32` in `forward`; do NOT cast floating-point tensors - [ ] `out_idx` in `@tilelang.jit` points to the correct output tensor position + - [ ] `register_fake` output shape matches the real kernel output + - [ ] `custom_op` name is unique: `"top::_wrapped_kernel"` + - [ ] File placed under `tileops/kernels//.py` and exported from its `__init__.py` + - [ ] Class name is CamelCase and ends with `Kernel` + - [ ] Class has docstring with: input tensor shapes, computation logic, reference URL + - [ ] `supported_archs` is set appropriately + - [ ] `init_config` is called at the end of `__init__` + - [ ] `forward` calls the wrapper (`__wrapped_kernel`), not the raw kernel directly + - [ ] Index/offset tensors (e.g. `offsets`, `token_indices`) are cast to `torch.int32` in `forward`; do NOT cast floating-point tensors + - [ ] Default dtype is `float16` (exposed); accumulator is `float32` (hardcoded internally) + - [ ] Global memory is copied to shared memory before computation + - [ ] Reductions operate on fragments, not shared memory + - [ ] `@T.prim_func` is kept minimal; complex logic is in `@T.macro` helpers diff --git a/.claude/create-new-op-attention/skill.md b/.claude/create-new-op-attention/skill.md index de1f7022..47e3b4a5 100644 --- a/.claude/create-new-op-attention/skill.md +++ b/.claude/create-new-op-attention/skill.md @@ -1,144 +1,148 @@ ---- -name: create-new-op-attention -description: Attention-specific conventions for creating a new TileOps attention kernel — variants (fwd/decode/bwd), causal masking, split-K decode design, and paged KV cache. Use together with create-new-kernel (kernel structure) and create-new-op (op registration). Auto-invoke when the user asks to create or implement an attention kernel (MHA, GQA, MLA, flash attention, decode attention, etc.) in TileOps. ---- - -# Skill: Attention Kernel Conventions - -Attention kernels have additional conventions beyond the general rules in `create-new-kernel` and `create-new-op`. - -______________________________________________________________________ - -## Kernel variants - -Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: - -| Variant | File suffix | Class suffix | Description | -| -------- | ------------ | -------------- | --------------------------------------------- | -| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | -| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | -| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | - -Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` - -______________________________________________________________________ - -## `causal` parameter - -All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. - -```python -def __kernel( - batch: int, - heads: int, - seqlen_q: int, - seqlen_kv: int, - dim: int, - is_causal: bool, # required for all attention kernels - dtype: str, -) -> Callable: - ... -``` - -______________________________________________________________________ - -## Decode kernel: split-K design - -Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: - -- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) -- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) - -Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. - -`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. - -Structure inside `__func`: - -```python -def __decode_func(block_M, block_N, num_split, num_stages, threads): - - @T.macro - def __no_split(Q, K, V, real_seqlen_kv, Output): - # single-pass attention over full KV - ... - - @T.macro - def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): - # attention over one KV chunk; writes partial output + log-sum-exp - ... - - @T.macro - def combine(glse, Output_partial, Output): - # merge partial outputs using LSE rescaling - ... - - @T.prim_func - def _decode_no_split(Q, K, V, real_seqlen_kv, Output): - __no_split(Q, K, V, real_seqlen_kv, Output) - - @T.prim_func - def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): - __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) - combine(glse, Output_partial, Output) - - if num_split > 1: - return _decode_split - else: - return _decode_no_split -``` - -The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: - -```python -@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) -def __decode_wrapped_kernel( - ..., num_split: int, - Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, - glse: torch.Tensor, Output_partial: torch.Tensor, -) -> torch.Tensor: - split_length = ... # compute per-split chunk sizes - if split_length[0] == 0: - num_split = 1 - if num_split == 1: - return __decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) - return __decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( - Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) -``` - -The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: - -```python -def forward(self, Q, K, V, real_seqlen_kv): - glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) - Output_partial = torch.empty( - (..., self.config["num_split"], ...), dtype=..., device=Q.device - ) - return __decode_wrapped_kernel( - ..., self.config["num_split"], Q, K, V, glse, Output_partial - ) -``` - -______________________________________________________________________ - -## Decode kernel: paged KV variant - -For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: - -- Adds `page_size: int` to the outermost kernel function -- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` -- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool -- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar - -Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` - -______________________________________________________________________ - -## Checklist - -- [ ] File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` -- [ ] `is_causal: bool` is a parameter of the outermost kernel function -- [ ] (Decode) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` and `autotune_configs` -- [ ] (Decode) Wrapper computes `split_length` and dispatches to the correct path -- [ ] (Decode) `Kernel.forward` allocates `glse` and `Output_partial` before calling the wrapper -- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` +--- +name: create-new-op-attention +description: Attention-specific conventions for creating a new TileOps attention kernel — variants (fwd/decode/bwd), causal masking, split-K decode design, and paged KV cache. Use together with create-new-kernel (kernel structure) and create-new-op (op registration). Auto-invoke when the user asks to create or implement an attention kernel (MHA, GQA, MLA, flash attention, decode attention, etc.) in TileOps. +--- + +# Skill: Attention Kernel Conventions + +Attention kernels have additional conventions beyond the general rules in `create-new-kernel` and `create-new-op`. + +______________________________________________________________________ + +## Kernel variants + +Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: + +| Variant | File suffix | Class suffix | Description | +| -------- | ------------ | -------------- | --------------------------------------------- | +| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | +| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | +| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | + +Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` + +______________________________________________________________________ + +## `causal` parameter + +All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. + +```python +def __kernel( + batch: int, + heads: int, + seqlen_q: int, + seqlen_kv: int, + dim: int, + is_causal: bool, # required for all attention kernels + dtype: str, +) -> Callable: + ... +``` + +______________________________________________________________________ + +## Decode kernel: split-K design + +Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: + +- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) +- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) + +Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. + +`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. + +Structure inside `__func`: + +```python +def __decode_func(block_M, block_N, num_split, num_stages, threads): + + @T.macro + def __no_split(Q, K, V, real_seqlen_kv, Output): + # single-pass attention over full KV + ... + + @T.macro + def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): + # attention over one KV chunk; writes partial output + log-sum-exp + ... + + @T.macro + def combine(glse, Output_partial, Output): + # merge partial outputs using LSE rescaling + ... + + @T.prim_func + def _decode_no_split(Q, K, V, real_seqlen_kv, Output): + __no_split(Q, K, V, real_seqlen_kv, Output) + + @T.prim_func + def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): + __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) + combine(glse, Output_partial, Output) + + if num_split > 1: + return _decode_split + else: + return _decode_no_split +``` + +The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: + +```python +@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) +def __decode_wrapped_kernel( + ..., num_split: int, + Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + glse: torch.Tensor, Output_partial: torch.Tensor, +) -> torch.Tensor: + split_length = ... # compute per-split chunk sizes + if split_length[0] == 0: + num_split = 1 + if num_split == 1: + return __decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) + return __decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( + Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) +``` + +The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: + +```python +def forward(self, Q, K, V, real_seqlen_kv): + glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) + Output_partial = torch.empty( + (..., self.config["num_split"], ...), dtype=..., device=Q.device + ) + return ( + _ + < attn_name + > _decode_wrapped_kernel( + ..., self.config["num_split"], Q, K, V, glse, Output_partial + ) + ) +``` + +______________________________________________________________________ + +## Decode kernel: paged KV variant + +For paged KV cache support, create a separate file `_decode_paged.py`. The paged variant: + +- Adds `page_size: int` to the outermost kernel function +- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` +- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool +- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar + +Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` + +______________________________________________________________________ + +## Checklist + +- [ ] File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` +- [ ] `is_causal: bool` is a parameter of the outermost kernel function +- [ ] (Decode) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` and `autotune_configs` +- [ ] (Decode) Wrapper computes `split_length` and dispatches to the correct path +- [ ] (Decode) `Kernel.forward` allocates `glse` and `Output_partial` before calling the wrapper +- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` From 76cd1e1024c163e8c6b6043b2cc9d02a78734f47 Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Fri, 27 Feb 2026 13:16:28 +0800 Subject: [PATCH 6/8] [Refactor][GLA] Rewrite gla_fwd.py to comply with TileOps kernel conventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Single @T.prim_func with 4 @T.macro stages in one T.Serial(num_chunks) loop - Stages run in order 1→3→4→2 so stage4 reads pre-decay h_s before stage2 updates it - Hoist all shared buffers into _main and pass as parameters to eliminate duplicate allocations (stays within 232448 byte optin limit) - Move shape lists inside _gla_fwd_func so outer closure only captures serializable scalars (fixes autotuner assertion) - Add self.kernel assignment in __init__ to support autotune - Fix custom_op namespace to top:: and add autotune_configs - forward() only allocates buffers and calls wrapper; no PyTorch compute Co-Authored-By: Claude Sonnet 4.6 --- tileops/kernels/gla/gla_fwd.py | 542 +++++++++++++++------------------ 1 file changed, 242 insertions(+), 300 deletions(-) diff --git a/tileops/kernels/gla/gla_fwd.py b/tileops/kernels/gla/gla_fwd.py index 11c74a45..971a8f64 100644 --- a/tileops/kernels/gla/gla_fwd.py +++ b/tileops/kernels/gla/gla_fwd.py @@ -1,28 +1,3 @@ -"""GLA (Gated Linear Attention) Forward Kernel. - -Reference: - https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py - -Algorithm: Chunked GLA forward in 4 stages: - 1. Within-chunk cumulative sum of log-space gates g -> g_cumsum [B, T, H, K] - (computed in PyTorch inside forward() — sequential scan, not GPU-bound) - 2. Inter-chunk hidden state recurrence -> h [B, NT, H, K, V], ht [B, H, K, V] - (computed in PyTorch inside forward() — sequential over chunks) - 3. Intra-chunk causal attention matrix -> A [B, T, H, BT] (TileLang) - 4. Output: o = inter-chunk (q*exp(g_cumsum) @ h) + intra-chunk (A @ v) (TileLang) - -Inputs: - q [B, T, H, K] fp16/bf16 queries - k [B, T, H, K] fp16/bf16 keys - v [B, T, H, V] fp16/bf16 values - g [B, T, H, K] fp16/bf16 log-space forget gates (e.g. F.logsigmoid(...)) - initial_state [B, H, K, V] float32 optional initial hidden state - -Outputs: - o [B, T, H, V] fp16/bf16 - final_state [B, H, K, V] float32 (only when output_final_state=True) -""" - import torch from typing import Optional, Any, Callable @@ -33,147 +8,39 @@ LOG2_E = 1.44269504 -# --------------------------------------------------------------------------- -# Stage 3: Intra-chunk causal attention matrix A [B, T, H, BT] -# --------------------------------------------------------------------------- - -def _gla_fwd_intra_kernel( +def _gla_fwd_kernel( batch: int, seq_len: int, heads: int, dim_k: int, + dim_v: int, chunk_size: int, scale: float, dtype: str, ) -> Callable: - num_chunks = (seq_len + chunk_size - 1) // chunk_size - q_shape = [batch, seq_len, heads, dim_k] - k_shape = [batch, seq_len, heads, dim_k] - g_cumsum_shape = [batch, seq_len, heads, dim_k] - A_shape = [batch, seq_len, heads, chunk_size] - - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }) - def _func(threads: int): - - @T.prim_func - def _kernel( - q: T.Tensor(q_shape, dtype), - k: T.Tensor(k_shape, dtype), - g_cumsum: T.Tensor(g_cumsum_shape, "float32"), - A: T.Tensor(A_shape, "float32"), - ): - with T.Kernel(batch * heads, num_chunks, threads=threads) as (bx, by): - i_b = bx // heads - i_h = bx % heads - i_c = by - chunk_start = i_c * chunk_size - - # Shared buffers for inputs - q_shared = T.alloc_shared([chunk_size, dim_k], dtype) - k_shared = T.alloc_shared([chunk_size, dim_k], dtype) - g_shared = T.alloc_shared([chunk_size, dim_k], "float32") - - # Shared buffers for gated q/k (float32 for gemm) - q_gated = T.alloc_shared([chunk_size, dim_k], "float32") - k_gated = T.alloc_shared([chunk_size, dim_k], "float32") - - # Fragment accumulator for A [BT, BT] - acc = T.alloc_fragment([chunk_size, chunk_size], "float32") - - # Load inputs - T.copy(q[i_b, chunk_start:chunk_start + chunk_size, i_h, :], q_shared) - T.copy(k[i_b, chunk_start:chunk_start + chunk_size, i_h, :], k_shared) - T.copy(g_cumsum[i_b, chunk_start:chunk_start + chunk_size, i_h, :], g_shared) - - # q_gated[t, k] = q[t, k] * exp(g_cumsum[t, k]) - # k_gated[t, k] = k[t, k] * exp(-g_cumsum[t, k]) - for i_t, i_k in T.Parallel(chunk_size, dim_k): - q_gated[i_t, i_k] = ( - T.cast(q_shared[i_t, i_k], "float32") * T.exp2(g_shared[i_t, i_k] * LOG2_E)) - k_gated[i_t, i_k] = ( - T.cast(k_shared[i_t, i_k], "float32") * - T.exp2(-g_shared[i_t, i_k] * LOG2_E)) - - # A = q_gated @ k_gated^T [BT, BT] - T.fill(acc, 0.0) - T.gemm(q_gated, k_gated, acc, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - # Apply causal mask and scale, write to A - for i_t, i_j in T.Parallel(chunk_size, chunk_size): - A[i_b, chunk_start + i_t, i_h, - i_j] = T.if_then_else(i_j <= i_t, acc[i_t, i_j] * scale, 0.0) - - return _kernel - - return _func - - -@torch.library.custom_op("gla::gla_fwd_intra", mutates_args=()) -def _gla_fwd_intra_wrapped( - batch: int, - seq_len: int, - heads: int, - dim_k: int, - chunk_size: int, - scale: float, - dtype: str, - threads: int, - q: torch.Tensor, - k: torch.Tensor, - g_cumsum: torch.Tensor, -) -> torch.Tensor: - return _gla_fwd_intra_kernel(batch, seq_len, heads, dim_k, chunk_size, scale, - dtype)(threads)(q, k, g_cumsum) - - -@_gla_fwd_intra_wrapped.register_fake -def _( - batch: int, - seq_len: int, - heads: int, - dim_k: int, - chunk_size: int, - scale: float, - dtype: str, - threads: int, - *inputs: tuple[Any], -) -> torch.Tensor: - _ = (dim_k, scale, dtype, threads) - return torch.empty([batch, seq_len, heads, chunk_size], - dtype=torch.float32, - device=inputs[0].device) - + """GLA (Gated Linear Attention) forward kernel. -# --------------------------------------------------------------------------- -# Stage 4: Output computation o [B, T, H, V] -# --------------------------------------------------------------------------- + Implements chunked GLA forward in 4 stages within a single @T.prim_func: + Stage 1: within-chunk cumulative sum of log-space gates -> g_cumsum + Stage 3: intra-chunk causal attention matrix A = q_gated @ k_gated^T + Stage 4: output o = scale * q_gated @ h_state + A @ v + Stage 2: inter-chunk hidden state recurrence -> h_state (carried across chunks) + Stages run in order 1→3→4→2 so that stage4 reads h_s before stage2 decays it. -def _gla_fwd_o_kernel( - batch: int, - seq_len: int, - heads: int, - dim_k: int, - dim_v: int, - chunk_size: int, - scale: float, - dtype: str, -) -> Callable: - num_chunks = (seq_len + chunk_size - 1) // chunk_size - q_shape = [batch, seq_len, heads, dim_k] - v_shape = [batch, seq_len, heads, dim_v] - g_cumsum_shape = [batch, seq_len, heads, dim_k] - A_shape = [batch, seq_len, heads, chunk_size] - h_shape = [batch, num_chunks, heads, dim_k, dim_v] - o_shape = [batch, seq_len, heads, dim_v] + Args: + q [B, T, H, K] fp16/bf16 queries + k [B, T, H, K] fp16/bf16 keys + v [B, T, H, V] fp16/bf16 values + g [B, T, H, K] fp16/bf16 log-space forget gates + initial_state [B, H, K, V] float32 initial hidden state (zeros if unused) + h_out [B, NT, H, K, V] float32 per-chunk hidden states (output) + o [B, T, H, V] fp16/bf16 output + Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py + """ @tilelang.jit( out_idx=[-1], pass_configs={ @@ -181,74 +48,197 @@ def _gla_fwd_o_kernel( tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, }) - def _func(threads: int): + def _gla_fwd_func(threads: int): + # Shape lists defined inside _gla_fwd_func so the outer closure only + # captures serializable scalars (int/float/str), satisfying tilelang autotuner. + accum_dtype = "float32" + num_chunks = (seq_len + chunk_size - 1) // chunk_size + + q_shape = [batch, seq_len, heads, dim_k] + k_shape = [batch, seq_len, heads, dim_k] + v_shape = [batch, seq_len, heads, dim_v] + g_shape = [batch, seq_len, heads, dim_k] + init_state_shape = [batch, heads, dim_k, dim_v] + h_out_shape = [batch, num_chunks, heads, dim_k, dim_v] + o_shape = [batch, seq_len, heads, dim_v] + + @T.macro + def stage1_cumsum( + g: T.Tensor(g_shape, dtype), + i_b: int, + i_h: int, + i_c: int, + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + ): + """Within-chunk cumulative sum of log-space gates. + Reads g directly from global memory to avoid an extra shared buffer. + """ + chunk_start = i_c * chunk_size + for i_k in T.Parallel(dim_k): + g_cumsum_s[0, i_k] = T.cast(g[i_b, chunk_start, i_h, i_k], accum_dtype) + for i_t in T.Serial(1, chunk_size): + for i_k in T.Parallel(dim_k): + g_cumsum_s[i_t, i_k] = g_cumsum_s[i_t - 1, i_k] + T.cast( + g[i_b, chunk_start + i_t, i_h, i_k], accum_dtype) + + @T.macro + def stage3_intra( + q: T.Tensor(q_shape, dtype), + k: T.Tensor(k_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + A_s: T.Buffer([chunk_size, chunk_size], accum_dtype), + qf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + kf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Intra-chunk causal attention matrix A = q_gated @ k_gated^T. + Reads q and k directly from global memory into qf32_s / kf32_s. + """ + chunk_start = i_c * chunk_size + for i_t, i_k in T.Parallel(chunk_size, dim_k): + qf32_s[i_t, i_k] = T.cast(q[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + g_cumsum_s[i_t, i_k] * LOG2_E) + kf32_s[i_t, i_k] = T.cast(k[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + -g_cumsum_s[i_t, i_k] * LOG2_E) + + A_frag = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + T.fill(A_frag, 0.0) + T.gemm(qf32_s, kf32_s, A_frag, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + for i_t, i_j in T.Parallel(chunk_size, chunk_size): + A_s[i_t, i_j] = T.if_then_else(i_j <= i_t, A_frag[i_t, i_j] * scale, 0.0) + + @T.macro + def stage4_output( + q: T.Tensor(q_shape, dtype), + v: T.Tensor(v_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + A_s: T.Buffer([chunk_size, chunk_size], accum_dtype), + h_s: T.Buffer([dim_k, dim_v], accum_dtype), + o: T.Tensor(o_shape, dtype), + qf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + vf32_s: T.Buffer([chunk_size, dim_v], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Output: o = scale * q_gated @ h_s + A @ v. + Called before stage2 so h_s is the pre-decay state entering this chunk. + Reads q and v directly from global memory. + """ + chunk_start = i_c * chunk_size + for i_t, i_k in T.Parallel(chunk_size, dim_k): + qf32_s[i_t, i_k] = T.cast(q[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + g_cumsum_s[i_t, i_k] * LOG2_E) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + vf32_s[i_t, i_v] = T.cast(v[i_b, chunk_start + i_t, i_h, i_v], accum_dtype) + + acc = T.alloc_fragment([chunk_size, dim_v], accum_dtype) + T.fill(acc, 0.0) + T.gemm(qf32_s, h_s, acc, policy=T.GemmWarpPolicy.FullRow) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + acc[i_t, i_v] = acc[i_t, i_v] * scale + T.gemm(A_s, vf32_s, acc, policy=T.GemmWarpPolicy.FullRow) + + for i_t, i_v in T.Parallel(chunk_size, dim_v): + o[i_b, chunk_start + i_t, i_h, i_v] = T.cast(acc[i_t, i_v], dtype) + + @T.macro + def stage2_recurrence( + k: T.Tensor(k_shape, dtype), + v: T.Tensor(v_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + h_s: T.Buffer([dim_k, dim_v], accum_dtype), + h_out: T.Tensor(h_out_shape, accum_dtype), + kf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + vf32_s: T.Buffer([chunk_size, dim_v], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Inter-chunk hidden state recurrence. h_s is carried across chunks. + Saves pre-decay h_s to h_out, then decays and accumulates. + Reads k and v directly from global memory. + """ + chunk_start = i_c * chunk_size + # Save h entering this chunk to h_out + T.copy(h_s, h_out[i_b, i_c, i_h, :, :]) + + g_last = T.alloc_fragment([dim_k], accum_dtype) + for i_k in T.Parallel(dim_k): + g_last[i_k] = g_cumsum_s[chunk_size - 1, i_k] + + # Decay h: h[k, v] *= exp(g_last[k]) + for i_k, i_v in T.Parallel(dim_k, dim_v): + h_s[i_k, i_v] = h_s[i_k, i_v] * T.exp2(g_last[i_k] * LOG2_E) + + # k_adj[t, k] = k[t, k] * exp(g_last[k] - g_cumsum[t, k]) + for i_t, i_k in T.Parallel(chunk_size, dim_k): + kf32_s[i_t, i_k] = T.cast(k[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + (g_last[i_k] - g_cumsum_s[i_t, i_k]) * LOG2_E) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + vf32_s[i_t, i_v] = T.cast(v[i_b, chunk_start + i_t, i_h, i_v], accum_dtype) + + # h += k_adj^T @ v [K, V] + delta_h = T.alloc_fragment([dim_k, dim_v], accum_dtype) + T.fill(delta_h, 0.0) + T.gemm(kf32_s, vf32_s, delta_h, transpose_A=True, policy=T.GemmWarpPolicy.FullRow) + for i_k, i_v in T.Parallel(dim_k, dim_v): + h_s[i_k, i_v] = h_s[i_k, i_v] + delta_h[i_k, i_v] @T.prim_func - def _kernel( - q: T.Tensor(q_shape, dtype), - v: T.Tensor(v_shape, dtype), - g_cumsum: T.Tensor(g_cumsum_shape, "float32"), - A: T.Tensor(A_shape, "float32"), - h: T.Tensor(h_shape, "float32"), - o: T.Tensor(o_shape, dtype), + def _main( + q: T.Tensor(q_shape, dtype), + k: T.Tensor(k_shape, dtype), + v: T.Tensor(v_shape, dtype), + g: T.Tensor(g_shape, dtype), + initial_state: T.Tensor(init_state_shape, accum_dtype), + h_out: T.Tensor(h_out_shape, accum_dtype), + o: T.Tensor(o_shape, dtype), ): - with T.Kernel(batch * heads, num_chunks, threads=threads) as (bx, by): + with T.Kernel(batch * heads, threads=threads) as bx: i_b = bx // heads i_h = bx % heads - i_c = by - chunk_start = i_c * chunk_size - - # Shared buffers for inputs - q_shared = T.alloc_shared([chunk_size, dim_k], dtype) - v_shared = T.alloc_shared([chunk_size, dim_v], "float32") - g_cs_shared = T.alloc_shared([chunk_size, dim_k], "float32") - A_shared = T.alloc_shared([chunk_size, chunk_size], "float32") - h_shared = T.alloc_shared([dim_k, dim_v], "float32") - - # Shared buffer for gated q (float32 for gemm) - q_gated = T.alloc_shared([chunk_size, dim_k], "float32") - - # Fragment accumulator [BT, BV] - acc = T.alloc_fragment([chunk_size, dim_v], "float32") - - # Load inputs - T.copy(q[i_b, chunk_start:chunk_start + chunk_size, i_h, :], q_shared) - T.copy(g_cumsum[i_b, chunk_start:chunk_start + chunk_size, i_h, :], g_cs_shared) - T.copy(A[i_b, chunk_start:chunk_start + chunk_size, i_h, :], A_shared) - T.copy(h[i_b, i_c, i_h, :, :], h_shared) - # Load v as float32 - v_raw = T.alloc_shared([chunk_size, dim_v], dtype) - T.copy(v[i_b, chunk_start:chunk_start + chunk_size, i_h, :], v_raw) - for i_t, i_v in T.Parallel(chunk_size, dim_v): - v_shared[i_t, i_v] = T.cast(v_raw[i_t, i_v], "float32") + # Persistent buffers across the chunk loop. + # Shared memory budget (dim_k=128, dim_v=128, chunk_size=64): + # h_s(65536) + g_cumsum_s(32768) + A_s(16384) + # + qf32_s(32768) + kf32_s(32768) + vf32_s(32768) = 212992 < 232448 + h_s = T.alloc_shared([dim_k, dim_v], accum_dtype) + g_cumsum_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + A_s = T.alloc_shared([chunk_size, chunk_size], accum_dtype) + # Scratch buffers reused across stages each iteration + qf32_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + kf32_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + vf32_s = T.alloc_shared([chunk_size, dim_v], accum_dtype) - # q_gated[t, k] = q[t, k] * exp(g_cumsum[t, k]) - for i_t, i_k in T.Parallel(chunk_size, dim_k): - q_gated[i_t, i_k] = ( - T.cast(q_shared[i_t, i_k], "float32") * - T.exp2(g_cs_shared[i_t, i_k] * LOG2_E)) + T.copy(initial_state[i_b, i_h, :, :], h_s) - # inter-chunk: acc = scale * q_gated @ h [BT, BV] - T.fill(acc, 0.0) - T.gemm(q_gated, h_shared, acc, policy=T.GemmWarpPolicy.FullRow) - for i_t, i_v in T.Parallel(chunk_size, dim_v): - acc[i_t, i_v] = acc[i_t, i_v] * scale + for i_c in T.Serial(num_chunks): + stage1_cumsum(g, i_b, i_h, i_c, g_cumsum_s) + stage3_intra(q, k, g_cumsum_s, A_s, qf32_s, kf32_s, i_b, i_h, i_c) + # stage4 runs before stage2: h_s is still the pre-decay state + stage4_output(q, v, g_cumsum_s, A_s, h_s, o, qf32_s, vf32_s, i_b, i_h, i_c) + # stage2 decays h_s and accumulates k^T v; saves pre-decay to h_out + stage2_recurrence(k, v, g_cumsum_s, h_s, h_out, kf32_s, vf32_s, i_b, i_h, + i_c) - # intra-chunk: acc += A @ v [BT, BV] - T.gemm(A_shared, v_shared, acc, policy=T.GemmWarpPolicy.FullRow) + # Overwrite the last h_out slot with the fully-updated final state + T.copy(h_s, h_out[i_b, num_chunks - 1, i_h, :, :]) - # Write output (cast back to dtype) - for i_t, i_v in T.Parallel(chunk_size, dim_v): - o[i_b, chunk_start + i_t, i_h, i_v] = T.cast(acc[i_t, i_v], dtype) + return _main - return _kernel + return _gla_fwd_func - return _func - -@torch.library.custom_op("gla::gla_fwd_o", mutates_args=()) -def _gla_fwd_o_wrapped( +@torch.library.custom_op("top::gla_fwd_wrapped_kernel", mutates_args=()) +def _gla_fwd_wrapped_kernel( batch: int, seq_len: int, heads: int, @@ -259,16 +249,17 @@ def _gla_fwd_o_wrapped( dtype: str, threads: int, q: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, - g_cumsum: torch.Tensor, - A: torch.Tensor, - h: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor, + h_out: torch.Tensor, ) -> torch.Tensor: - return _gla_fwd_o_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, scale, - dtype)(threads)(q, v, g_cumsum, A, h) + return _gla_fwd_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, scale, + dtype)(threads)(q, k, v, g, initial_state, h_out) -@_gla_fwd_o_wrapped.register_fake +@_gla_fwd_wrapped_kernel.register_fake def _( batch: int, seq_len: int, @@ -287,42 +278,24 @@ def _( device=inputs[0].device) -# --------------------------------------------------------------------------- -# Kernel class -# --------------------------------------------------------------------------- - - class GLAFwdKernel(Kernel): """GLA (Gated Linear Attention) forward kernel. - Implements chunked GLA forward: - Stage 1 (PyTorch): within-chunk cumulative sum of log-space gates - Stage 2 (PyTorch): inter-chunk hidden state recurrence - Stage 3 (TileLang): intra-chunk causal attention matrix A [B, T, H, BT] - Stage 4 (TileLang): output o = inter-chunk + intra-chunk contributions - Args: - batch: Batch size B. - seq_len: Sequence length T. Must be divisible by chunk_size. - heads: Number of query heads H. - dim_k: Key/query head dimension K. - dim_v: Value head dimension V. - chunk_size: Chunk size BT (default 64). - scale: Query scale factor (default 1/sqrt(K)). - output_final_state: Whether to return the final hidden state. - dtype: Input tensor dtype (torch.float16 or torch.bfloat16). - config: Optional kernel config dict (e.g. {"threads": 128}). - tune: Whether to run autotuning. - - Inputs to forward(): - q [B, T, H, K] fp16/bf16 - k [B, T, H, K] fp16/bf16 - v [B, T, H, V] fp16/bf16 - g [B, T, H, K] fp16/bf16 log-space gates - initial_state [B, H, K, V] float32 optional - - Returns: - (o [B, T, H, V], final_state [B, H, K, V] or None) + q: Query tensor, shape [batch, seq_len, heads, dim_k] + k: Key tensor, shape [batch, seq_len, heads, dim_k] + v: Value tensor, shape [batch, seq_len, heads, dim_v] + g: Log-space forget gates, shape [batch, seq_len, heads, dim_k] + initial_state: Optional initial hidden state, shape [batch, heads, dim_k, dim_v] + + Computation: + Chunked GLA forward in 4 TileLang stages per chunk, fused in a single + T.Serial(NT) loop. Stages run in order 1→3→4→2 so that stage4 reads + the pre-decay hidden state before stage2 updates it: + 1. Within-chunk cumulative sum of gates -> g_cumsum + 3. Intra-chunk causal attention A = scale * q_gated @ k_gated^T (causal masked) + 4. Output o = scale * q_gated @ h + A @ v (h is pre-decay) + 2. Inter-chunk hidden state recurrence h += k_adj^T @ v (h carried in shared memory) Reference: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py @@ -354,15 +327,18 @@ def __init__( self.scale = scale if scale > 0 else dim_k**-0.5 self.output_final_state = output_final_state self.dtype_name = str(dtype).split('.')[-1] + self.kernel = _gla_fwd_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, self.scale, + self.dtype_name) self.init_config(config, tune) - # GLAFwdKernel has no single self.kernel to autotune; fall back to default_config - if not self.config: - self.config = self.default_config @property def default_config(self) -> dict: return {"threads": 128} + @property + def autotune_configs(self) -> list[dict]: + return [{"threads": t} for t in [64, 128, 256]] + def forward( self, q: torch.Tensor, @@ -371,62 +347,28 @@ def forward( g: torch.Tensor, initial_state: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - dtype_str = self.dtype_name - threads = self.config["threads"] B, T, H, K = self.batch, self.seq_len, self.heads, self.dim_k V = self.dim_v BT = self.chunk_size NT = (T + BT - 1) // BT - dtype_torch = getattr(torch, dtype_str) - - q = q.to(dtype_torch) - k = k.to(dtype_torch) - v = v.to(dtype_torch) - g = g.to(dtype_torch) + dtype_torch = getattr(torch, self.dtype_name) - use_initial_state = initial_state is not None - if not use_initial_state: - b_h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + if initial_state is None: + init_state = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) else: - b_h = initial_state.to(torch.float32).clone() - - # Stage 1: within-chunk cumulative sum of gates (PyTorch) - g_f32 = g.float() - g_cumsum = torch.empty_like(g_f32) - for i_c in range(NT): - cs = i_c * BT - ce = min(cs + BT, T) - g_cumsum[:, cs:ce] = torch.cumsum(g_f32[:, cs:ce], dim=1) - - # Stage 2: inter-chunk hidden state recurrence (PyTorch) - # h_states[b, i_c, h, K, V] = state entering chunk i_c - h_states = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=q.device) - for i_c in range(NT): - cs = i_c * BT - ce = min(cs + BT, T) - h_states[:, i_c] = b_h - - # g_last: g_cumsum at last position of this chunk [B, H, K] - g_last = g_cumsum[:, ce - 1] # [B, H, K] - - # Decay: b_h[b, h, k, v] *= exp(g_last[b, h, k]) - b_h = b_h * torch.exp(g_last).unsqueeze(-1) - - # Accumulate: b_h += k_adj^T @ v - # k_adj[b, t, h, k] = k[b, t, h, k] * exp(g_last[b, h, k] - g_cumsum[b, t, h, k]) - k_chunk = k[:, cs:ce].float() # [B, L, H, K] - v_chunk = v[:, cs:ce].float() # [B, L, H, V] - g_cs_chunk = g_cumsum[:, cs:ce] # [B, L, H, K] - k_adj = k_chunk * torch.exp(g_last.unsqueeze(1) - g_cs_chunk) - b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) - - final_state = b_h if self.output_final_state else None - - # Stage 3: intra-chunk attention matrix (TileLang) - A = _gla_fwd_intra_wrapped(B, T, H, K, BT, self.scale, dtype_str, threads, q, k, g_cumsum) - - # Stage 4: output (TileLang) - o = _gla_fwd_o_wrapped(B, T, H, K, V, BT, self.scale, dtype_str, threads, q, v, g_cumsum, A, - h_states) + init_state = initial_state.to(torch.float32) + + h_out = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=q.device) + + o = _gla_fwd_wrapped_kernel( + B, T, H, K, V, BT, self.scale, self.dtype_name, self.config["threads"], + q.to(dtype_torch), + k.to(dtype_torch), + v.to(dtype_torch), + g.to(dtype_torch), + init_state, + h_out, + ) + final_state = h_out[:, -1] if self.output_final_state else None return o, final_state From d6b30e23023029a19269a4b9a1808843a75c959f Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Fri, 27 Feb 2026 14:30:48 +0800 Subject: [PATCH 7/8] [Chore][Skill] Sync creating-pull-request and migrating-new-op with main Co-Authored-By: Claude Sonnet 4.6 --- .claude/skills/migrating-new-op/SKILL.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.claude/skills/migrating-new-op/SKILL.md b/.claude/skills/migrating-new-op/SKILL.md index 650c5238..7998544a 100644 --- a/.claude/skills/migrating-new-op/SKILL.md +++ b/.claude/skills/migrating-new-op/SKILL.md @@ -58,11 +58,10 @@ ______________________________________________________________________ 1. Implement TileLang kernels by stage (same logical decomposition as reference). 1. Keep output semantics compatible with existing interfaces. 1. Handle core edge cases early (for example: empty paths / `num_topk == 0`). -1. Follow the conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md): file layout, naming, dtype rules, memory hierarchy, attention variants, wrapper/register_fake pattern, and docstring requirements. ### Phase D: Upper Layer Wiring -1. Add `op` wrapper for kernel invocation + runtime contract. Follow [`.claude/create-new-op/skill.md`](../create-new-op/skill.md) for Op class structure, kernel dispatch, and `__init__.py` registration. +1. Add `op` wrapper for kernel invocation + runtime contract. 1. Add `function` API for reusable composition and shape/dtype validation. 1. Add `layer` abstraction only when module-style integration is needed. 1. Keep dependency direction one-way: `layer -> function -> op -> kernel`. @@ -71,7 +70,7 @@ ______________________________________________________________________ 1. Ensure layer-matched tests exist: -- `tests/ops` for op behavior — follow [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) for test structure, parametrization, dtype coverage, and debugging protocol +- `tests/ops` for op behavior - `tests/functions` for functional integration - `tests/layers` when a layer is introduced @@ -105,12 +104,11 @@ Before opening a PR, verify all required items are present: - [ ] New/updated kernel implementation exists - [ ] Kernel handles documented edge cases -- [ ] Follows conventions in [`.claude/create-new-kernel/skill.md`](../create-new-kernel/skill.md) 2. **L2: Op (`tileops/ops`)** - [ ] Op API wraps kernel with stable argument contract -- [ ] Op-level tests added/updated in `tests/ops` — follows [`.claude/create-new-op-test/skill.md`](../create-new-op-test/skill.md) +- [ ] Op-level tests added/updated in `tests/ops` - [ ] Benchmark script added/updated in `benchmarks` 3. **L3: Function (`tileops/functions`)** From b2783e15b45726ba0a35c713172a4e3a13a8156f Mon Sep 17 00:00:00 2001 From: SuperAngGao Date: Fri, 27 Feb 2026 17:02:34 +0800 Subject: [PATCH 8/8] [Chore][Skill] Move skill directories into .claude/skills/ Co-Authored-By: Claude Sonnet 4.6 --- .claude/{ => skills}/create-new-kernel/skill.md | 0 .claude/{ => skills}/create-new-op-attention/skill.md | 0 .claude/{ => skills}/create-new-op-test/skill.md | 0 .claude/{ => skills}/create-new-op/skill.md | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename .claude/{ => skills}/create-new-kernel/skill.md (100%) rename .claude/{ => skills}/create-new-op-attention/skill.md (100%) rename .claude/{ => skills}/create-new-op-test/skill.md (100%) rename .claude/{ => skills}/create-new-op/skill.md (100%) diff --git a/.claude/create-new-kernel/skill.md b/.claude/skills/create-new-kernel/skill.md similarity index 100% rename from .claude/create-new-kernel/skill.md rename to .claude/skills/create-new-kernel/skill.md diff --git a/.claude/create-new-op-attention/skill.md b/.claude/skills/create-new-op-attention/skill.md similarity index 100% rename from .claude/create-new-op-attention/skill.md rename to .claude/skills/create-new-op-attention/skill.md diff --git a/.claude/create-new-op-test/skill.md b/.claude/skills/create-new-op-test/skill.md similarity index 100% rename from .claude/create-new-op-test/skill.md rename to .claude/skills/create-new-op-test/skill.md diff --git a/.claude/create-new-op/skill.md b/.claude/skills/create-new-op/skill.md similarity index 100% rename from .claude/create-new-op/skill.md rename to .claude/skills/create-new-op/skill.md