diff --git a/.claude/skills/create-new-kernel/skill.md b/.claude/skills/create-new-kernel/skill.md new file mode 100644 index 00000000..b41fdb4e --- /dev/null +++ b/.claude/skills/create-new-kernel/skill.md @@ -0,0 +1,293 @@ +--- +name: create-new-kernel +description: Create a new TileOps GPU kernel — file layout, tilelang kernel function, wrapper registration, and Kernel subclass. For attention-specific conventions (variants, causal masking, split-K decode, paged KV), also use the create-new-op and create-new-op-attention skills. Auto-invoke when the user asks to create, implement, or add a new kernel in TileOps. +--- + +# Skill: Creating a New TileOps Kernel + +Reference implementation: `tileops/kernels/deepseek_nsa/nsa_topk.py` +Base class: `tileops/kernels/kernel.py` + +## Development Environment + +| Item | Version | +| -------- | ------------ | +| GPU | H200 (sm90a) | +| CUDA | 12.9 | +| TileLang | 0.1.7.post1 | + +Target architecture is `sm90a`. All kernels are developed and validated on this environment. Do not assume compatibility with older architectures unless explicitly tested. + +______________________________________________________________________ + +## File Location and Naming + +``` +tileops/kernels//.py +``` + +- Group related kernels under a feature subdirectory (e.g. `tileops/kernels/deepseek_nsa/`) +- File name: `.py`, all lowercase with underscores (e.g. `nsa_topk.py`) +- Kernel function: `__kernel` (e.g. `_nsa_topk_varlen_kernel`) +- Wrapper function: `__wrapped_kernel` +- Class name: CamelCase + `Kernel` suffix (e.g. `NSATopkVarlenKernel`) +- Export the class from `tileops/kernels//__init__.py` + +______________________________________________________________________ + +## File Structure + +A kernel file contains five parts in order: + +1. Imports +1. Kernel function (tilelang implementation) +1. Wrapper function (`__wrapped_kernel`) +1. `register_fake` for the wrapper +1. 
`Kernel` subclass + +______________________________________________________________________ + +## Part 1: Imports + +```python +import torch +from typing import Optional, Any, Callable + +import tilelang +from tilelang import language as T + +from tileops.kernels.kernel import Kernel +``` + +Tilelang kernel implementations generally need no additional imports beyond these. + +______________________________________________________________________ + +## Part 2: Kernel Function + +The kernel is implemented as a two-level closure: + +```python +def __kernel( + # Fixed kernel parameters: shapes, dtypes, algorithm constants + param1: int, + param2: int, + dtype: str, + accum_dtype: str, +) -> Callable: + # Precompute constants and define tensor shapes here + # e.g. q_shape = [seq_len, heads, dim] + + @tilelang.jit( + out_idx=[-1], # index of output tensor in @T.prim_func's parameter list + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def __func(threads: int): # auto-tune parameters go here + + @T.prim_func + def _main( + input1: T.Tensor(shape1, dtype), + ..., + output: T.Tensor(output_shape, accum_dtype), + ): + with T.Kernel(grid_x, grid_y, threads=threads) as (bx, by): + # kernel body + + return _main + + return __func +``` + +### Dtype conventions + +- Default compute dtype: `float16` (`"float16"`); exposed as `dtype: str` in the outermost kernel function so callers can override. +- Accumulator dtype: always `float32`, hardcoded internally as `accum_dtype = "float32"`. Do not expose it as a parameter. + +### Memory hierarchy + +- Always move data from global memory to **shared memory** (`T.alloc_shared`) before computation. +- For reductions, move data from shared memory to **fragment** (`T.alloc_fragment`) first, then call `T.reduce_sum` / `T.reduce_max` on the fragment. +- Never reduce directly on shared memory. 
+ +Typical pattern: + +``` +global → T.copy → shared → T.gemm / T.copy → fragment → T.reduce_* → fragment +``` + +### Code organisation + +- Keep `@T.prim_func` minimal — it should read like a high-level algorithm outline. +- Extract non-trivial logic (sorting, masking, multi-step computations) into `@T.macro` helpers defined inside `__func`, and call them from `_main`. + +### Tilelang coding guidelines + +| Concern | Guidance | +| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| Exponentiation | Use `T.exp2(x * LOG2_E)` instead of `T.exp(x)` for performance (`LOG2_E = 1.44269504`) | +| Data movement | Use `T.copy()` for copies; `T.gemm()` for matrix multiply | +| Reduction | Use `T.reduce_sum()` / `T.reduce_max()` on fragments (not shared memory) | +| Parallelism | Prefer `T.Parallel` over `T.Serial`; only use `T.Serial` when sequential dependency is unavoidable | +| Loop structure | Prefer `T.Pipelined` for outer loops to enable software pipelining | +| Bug workaround | Tilelang may have bugs with nested `T.Parallel` / `T.Serial` / `T.Pipelined`. If compilation or runtime errors occur, try reordering the nesting levels as a first debugging step. | + +______________________________________________________________________ + +## Part 3: Wrapper Function + +The wrapper registers the kernel with `torch.library` so it is visible to `torch.compile` and `torch.export`. The `Kernel` class calls this wrapper in `forward`. + +### 3a. custom_op wrapper + +```python +@torch.library.custom_op("top::_wrapped_kernel", mutates_args=()) +def __wrapped_kernel( + # All scalar params first (int, float, str for dtype names) + param1: int, + dtype: str, + threads: int, + # Tensor inputs last + input1: torch.Tensor, + ... +) -> torch.Tensor: + return __kernel(param1, dtype)(threads)(input1, ...) 
+``` + +The call chain is: `_<kernel_name>_kernel(fixed_params)(autotune_params)(tensor_inputs)`. + +Note: `accum_dtype` is hardcoded inside `_<kernel_name>_kernel` and is not a parameter of the wrapper. + +### 3b. register_fake + +Provides a shape/dtype inference rule for tracing (no actual computation): + +```python +@_<kernel_name>_wrapped_kernel.register_fake +def _( + param1: int, + dtype: str, + threads: int, + *inputs: tuple[Any], +) -> torch.Tensor: + _ = (param1, dtype, threads) # suppress unused warnings + # Return an empty tensor with the correct output shape and dtype. + # Shape must match what the real kernel produces. + return torch.empty( + [output_dim0, output_dim1, ...], + dtype=inputs[0].dtype, + device=inputs[0].device, + ) +``` + +> The output shape and dtype here must exactly match the real kernel's output. Derive the shape from the scalar parameters (e.g. `c_seq_len`, `heads`, `selected_block_num`), not from input tensor shapes. + +______________________________________________________________________ + +## Part 4: Kernel Class + +- Naming: CamelCase, descriptive, ending in `Kernel` (e.g. `NSATopkVarlenKernel`) +- Inherit from `Kernel` (see `tileops/kernels/kernel.py`) +- Declare `supported_archs` for GPU arch gating; current target is H200 (sm90a), so use `[90]` + +The class **must** include a docstring with three sections: + +1. Input tensor shapes — list each tensor parameter with its layout (e.g. `[batch, seqlen, heads, dim]`) +1. Computation logic — a brief description of what the kernel computes +1. Reference — URL to the official PyTorch / Triton / paper implementation this is based on + +Example: + +```python +class <KernelName>Kernel(Kernel): + """<One-line summary of the kernel>. + + Args: + q: Query tensor, shape [batch, seqlen_q, heads, dim] + k: Key tensor, shape [batch, seqlen_k, heads_kv, dim] + v: Value tensor, shape [batch, seqlen_k, heads_kv, dim_v] + offsets: Sequence boundary offsets, shape [seq_num + 1], dtype int32 + ... 
+ + Computation: + <brief description of what the kernel computes> + + Reference: + <URL of the official PyTorch / Triton / paper implementation> + """ + supported_archs: list[int] = [90] + + def __init__(self, + param1: int, + ..., + dtype: torch.dtype = torch.float16, + config: Optional[dict] = None, + tune: bool = False) -> None: + super().__init__() + self.param1 = param1 + ... + # Store dtype as string for tilelang; accum_dtype is hardcoded in the kernel + self.dtype_name = str(dtype).split('.')[-1] + self.init_config(config, tune) # must be called last + + @property + def default_config(self) -> dict: + return {"threads": 32} + + @property + def autotune_configs(self) -> list[dict]: + return [{"threads": t} for t in [32, 64, 128]] + + def forward(self, input1: torch.Tensor, ...) -> torch.Tensor: + return _<kernel_name>_wrapped_kernel( + self.param1, + ..., + self.dtype_name, + self.config["threads"], + input1, # floating-point inputs are passed through uncast (see key points below) + ..., + ) +``` + +Key points: + +- `forward` calls `_<kernel_name>_wrapped_kernel` (the registered wrapper), not the raw kernel function directly + +- `init_config(config, tune)` must be the last call in `__init__`; it reads `default_config` and `autotune_configs` + +- `accum_dtype` is not stored on the class; it is hardcoded inside the kernel function + +- Only cast index/offset tensors (e.g. 
`offsets`, `token_indices`) to `torch.int32` in `forward`; do NOT cast floating-point tensors + +______________________________________________________________________ + +## Checklist + +- [ ] `out_idx` in `@tilelang.jit` points to the correct output tensor position + +- [ ] `register_fake` output shape matches the real kernel output + +- [ ] `custom_op` name is unique: `"top::<kernel_name>_wrapped_kernel"` + +- [ ] File placed under `tileops/kernels/<feature_name>/<kernel_name>.py` and exported from its `__init__.py` + +- [ ] Class name is CamelCase and ends with `Kernel` + +- [ ] Class has docstring with: input tensor shapes, computation logic, reference URL + +- [ ] `supported_archs` is set appropriately + +- [ ] `init_config` is called at the end of `__init__` + +- [ ] `forward` calls the wrapper (`_<kernel_name>_wrapped_kernel`), not the raw kernel directly + +- [ ] Index/offset tensors (e.g. 
`offsets`, `token_indices`) are cast to `torch.int32` in `forward`; do NOT cast floating-point tensors + +- [ ] Default dtype is `float16` (exposed); accumulator is `float32` (hardcoded internally) + +- [ ] Global memory is copied to shared memory before computation + +- [ ] Reductions operate on fragments, not shared memory + +- [ ] `@T.prim_func` is kept minimal; complex logic is in `@T.macro` helpers diff --git a/.claude/skills/create-new-op-attention/skill.md b/.claude/skills/create-new-op-attention/skill.md new file mode 100644 index 00000000..47e3b4a5 --- /dev/null +++ b/.claude/skills/create-new-op-attention/skill.md @@ -0,0 +1,148 @@ +--- +name: create-new-op-attention +description: Attention-specific conventions for creating a new TileOps attention kernel — variants (fwd/decode/bwd), causal masking, split-K decode design, and paged KV cache. Use together with create-new-kernel (kernel structure) and create-new-op (op registration). Auto-invoke when the user asks to create or implement an attention kernel (MHA, GQA, MLA, flash attention, decode attention, etc.) in TileOps. +--- + +# Skill: Attention Kernel Conventions + +Attention kernels have additional conventions beyond the general rules in `create-new-kernel` and `create-new-op`. 
+ +______________________________________________________________________ + +## Kernel variants + +Every attention kernel must be classified into one of three variants, which determines its file name, class name, and internal structure: + +| Variant | File suffix | Class suffix | Description | +| -------- | ------------ | -------------- | --------------------------------------------- | +| Forward | `_fwd.py` | `FwdKernel` | Full-sequence prefill / training forward pass | +| Decode | `_decode.py` | `DecodeKernel` | Single-token decode with KV cache | +| Backward | `_bwd.py` | `BwdKernel` | Training backward pass | + +Examples: `mha_fwd.py` → `MhaFwdKernel`, `mha_decode.py` → `MhaDecodeKernel` + +______________________________________________________________________ + +## `causal` parameter + +All attention kernels **must** expose `is_causal: bool` as a parameter of the outermost kernel function (the two-level closure). It controls the causal masking logic inside the `@T.prim_func`. + +```python +def __kernel( + batch: int, + heads: int, + seqlen_q: int, + seqlen_kv: int, + dim: int, + is_causal: bool, # required for all attention kernels + dtype: str, +) -> Callable: + ... +``` + +______________________________________________________________________ + +## Decode kernel: split-K design + +Decode kernels must support both a **no-split** and a **split-K** execution path, selected at runtime by `num_split`: + +- `num_split = 1` → use the no-split `@T.prim_func` (single pass over KV) +- `num_split > 1` → use the split `@T.prim_func` (parallel over KV chunks, then combine) + +Both paths are implemented as `@T.macro` functions inside `__func`, and the outer `@T.prim_func` simply calls the appropriate macro. The wrapper function (`__wrapped_kernel`) computes `split_length` and dispatches to the correct path. + +`num_split` is a tunable parameter and must appear in `default_config` and `autotune_configs`. 
+ +Structure inside `__func`: + +```python +def __decode_func(block_M, block_N, num_split, num_stages, threads): + + @T.macro + def __no_split(Q, K, V, real_seqlen_kv, Output): + # single-pass attention over full KV + ... + + @T.macro + def __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length): + # attention over one KV chunk; writes partial output + log-sum-exp + ... + + @T.macro + def combine(glse, Output_partial, Output): + # merge partial outputs using LSE rescaling + ... + + @T.prim_func + def _decode_no_split(Q, K, V, real_seqlen_kv, Output): + __no_split(Q, K, V, real_seqlen_kv, Output) + + @T.prim_func + def _decode_split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length, Output): + __split(Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) + combine(glse, Output_partial, Output) + + if num_split > 1: + return _decode_split + else: + return _decode_no_split +``` + +The wrapper allocates `glse` and `Output_partial` buffers and computes `split_length` before dispatching: + +```python +@torch.library.custom_op("top::_decode_wrapped_kernel", mutates_args=()) +def __decode_wrapped_kernel( + ..., num_split: int, + Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, + glse: torch.Tensor, Output_partial: torch.Tensor, +) -> torch.Tensor: + split_length = ... 
# compute per-split chunk sizes + if split_length[0] == 0: + num_split = 1 + if num_split == 1: + return _<attn_name>_decode_kernel(...)(block_M, block_N, 1, num_stages, threads)(Q, K, V, real_seqlen_kv) + return _<attn_name>_decode_kernel(...)(block_M, block_N, num_split, num_stages, threads)( + Q, K, V, real_seqlen_kv, glse, Output_partial, split_length) +``` + +The `Kernel.forward` allocates `glse` and `Output_partial` as temporary buffers before calling the wrapper: + +```python +def forward(self, Q, K, V, real_seqlen_kv): + glse = torch.empty((..., self.config["num_split"], ...), dtype=..., device=Q.device) + Output_partial = torch.empty( + (..., self.config["num_split"], ...), dtype=..., device=Q.device + ) + return _<attn_name>_decode_wrapped_kernel( + ..., self.config["num_split"], Q, K, V, glse, Output_partial + ) +``` + +______________________________________________________________________ + +## Decode kernel: paged KV variant + +For paged KV cache support, create a separate file `<attn_name>_decode_paged.py`. 
The paged variant: + +- Adds `page_size: int` to the outermost kernel function +- Replaces the flat KV shape `[batch, seqlen_kv, heads, dim]` with a paged pool shape `[total_pages * page_size, heads, dim]` +- Adds a `block_table: T.Tensor([batch, max_pages], T.int32)` parameter to index into the page pool +- `real_seqlen_kv` becomes a per-batch tensor `T.Tensor([batch], T.int32)` instead of a scalar + +Reference: `tileops/kernels/flash_decode/mha_decode_paged.py` + +______________________________________________________________________ + +## Checklist + +- [ ] File/class name includes variant suffix: `_fwd` / `_decode` / `_bwd` +- [ ] `is_causal: bool` is a parameter of the outermost kernel function +- [ ] (Decode) Both no-split and split-K `@T.prim_func` are implemented; `num_split` is in `default_config` and `autotune_configs` +- [ ] (Decode) Wrapper computes `split_length` and dispatches to the correct path +- [ ] (Decode) `Kernel.forward` allocates `glse` and `Output_partial` before calling the wrapper +- [ ] (Decode paged) Separate `_decode_paged.py` file; uses `page_size`, `block_table`, and per-batch `real_seqlen_kv` diff --git a/.claude/skills/create-new-op-test/skill.md b/.claude/skills/create-new-op-test/skill.md new file mode 100644 index 00000000..336dfe3f --- /dev/null +++ b/.claude/skills/create-new-op-test/skill.md @@ -0,0 +1,204 @@ +--- +name: create-new-op-test +description: Write correctness unit tests for a new TileOps Op — pytest structure, parametrization, dtype coverage, tolerance guidelines, reference implementation, and debugging protocol. Auto-invoke when the user asks to write, add, or update tests for a TileOps op. +--- + +# Skill: Writing Op Correctness Unit Tests + +Reference tests: `tests/ops/` + +______________________________________________________________________ + +## Overview + +Each op has a pytest file at `tests/ops/test_.py`. The test: + +1. Instantiates the `Op` with given parameters +1. Generates random input tensors +1. 
Runs the op and a PyTorch reference implementation +1. Asserts numerical closeness with `torch.allclose` + +No benchmark infrastructure is needed — correctness only. + +______________________________________________________________________ + +## File Location and Naming + +``` +tests/ops/test_.py +``` + +File name mirrors the op module name (e.g. `tileops/ops/gqa_decode.py` → `tests/ops/test_gqa_decode.py`). + +______________________________________________________________________ + +## Test File Structure + +```python +import pytest +import torch + +from tileops.ops import Op + + +def ref_(*inputs, **params) -> torch.Tensor: + # Pure PyTorch reference implementation + # Use torch.nn.functional or F.scaled_dot_product_attention + ... + + +@pytest.mark.parametrize("", [ + (...), # case 1 + (...), # case 2 + # ... at least 5 cases +]) +def test_(param1: int, ..., dtype: torch.dtype, tune: bool) -> None: + torch.manual_seed(42) + + op = Op(param1, ..., dtype=dtype, tune=tune) + + # generate inputs + x = torch.randn(..., device='cuda', dtype=dtype) + ... + + # run op + with torch.no_grad(): + out = op(x, ...) + + # run reference + ref = ref_(x, ..., param1=param1, ...) + + assert torch.allclose(out, ref, atol=, rtol=), \ + f"max err: {(out - ref).abs().max()}" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) +``` + +______________________________________________________________________ + +## Test Case Requirements + +### Minimum coverage + +- At least **5 parametrized test cases** per test function +- Vary structural parameters across cases: batch, seq_len, heads, dim, etc. +- Include at least one `tune=True` case if the op supports autotuning + +### dtype coverage + +| Op type | Required dtypes | +| --------------------------------- | --------------------------------------------------------------- | +| General ops | `torch.float16`, `torch.bfloat16` | +| Quantization ops (fp8, fp4, etc.) | Include the target quantized dtype (e.g. 
`torch.float8_e4m3fn`) | +| Mixed-precision ops | Cover all relevant input/output dtype combinations | + +### Random seed + +Fix `torch.manual_seed(42)` at the top of every test function, before any tensor creation. + +### Tolerance guidelines + +| Op type | `atol` | `rtol` | +| -------------------------- | ---------------- | ------ | +| Attention fwd (fp16/bf16) | `5e-3` | `1e-5` | +| Attention decode | `1e-2` | `1e-2` | +| GEMM / linear | `1e-3` | `1e-3` | +| Elementwise / quantization | `1e-1` or custom | — | + +### Reference implementation + +- If `torch.nn` or `torch.nn.functional` provides the operation directly, use it. +- If not, implement the reference manually using basic PyTorch ops (`torch.matmul`, `torch.softmax`, elementwise ops, etc.), following the algorithm described in the kernel's docstring and the official reference URL. +- For attention fwd: use `SDPBackend.FLASH_ATTENTION` +- For attention decode: use `SDPBackend.MATH` (flash attention does not support single-token decode) +- Never use another TileOps op as the reference +- If the op has an official PyTorch or Triton reference implementation (e.g. in the kernel's docstring `Reference:` URL), consult it when writing the reference function (`ref_<op_name>`) to ensure the reference matches the intended algorithm exactly — including scale factors, layout conventions, and masking logic. + +### Fallback: cosine similarity test + +If numerical error persists after exhausting all debugging steps (see the "Correctness debugging protocol" section of `.claude/skills/create-new-op/skill.md`), the `torch.allclose` assertion may be replaced with a cosine similarity check. 
The threshold is **0.999 and must not be relaxed**: + +```python +def cosine_sim(a: torch.Tensor, b: torch.Tensor) -> float: + a_f = a.float().flatten() + b_f = b.float().flatten() + return torch.nn.functional.cosine_similarity(a_f, b_f, dim=0).item() + + +# in the test: +sim = cosine_sim(out, ref) +assert sim >= 0.999, f"cosine similarity {sim:.6f} < 0.999" +``` + +Add a comment explaining why `torch.allclose` was replaced and what debugging was attempted. + +______________________________________________________________________ + +## Example: attention fwd op + +```python +import pytest +import torch +import torch.nn.functional as F +from torch.nn.attention import SDPBackend, sdpa_kernel + +from tileops.ops import MultiHeadAttentionFwdOp + + +def ref_mha_fwd(q, k, v, is_causal): + # input layout: [batch, seqlen, heads, dim] → transpose to [batch, heads, seqlen, dim] + q_t, k_t, v_t = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) + with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=is_causal) + return out.transpose(1, 2).contiguous() + + +@pytest.mark.parametrize( + "batch, seq_len, heads, dim, is_causal, dtype, tune", + [ + (1, 1024, 8, 64, False, torch.float16, False), + (4, 2048, 16, 128, False, torch.bfloat16, False), + (8, 4096, 16, 128, True, torch.float16, False), + (2, 1024, 32, 64, True, torch.bfloat16, False), + (4, 2048, 16, 128, False, torch.bfloat16, True), + ], +) +def test_mha_fwd(batch, seq_len, heads, dim, is_causal, dtype, tune): + torch.manual_seed(42) + op = MultiHeadAttentionFwdOp( + batch, heads, seq_len, dim, is_causal, dtype, tune=tune + ) + + q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) + k = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) + v = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=dtype) + + with torch.no_grad(): + out = op(q, k, v) + ref = ref_mha_fwd(q, k, v, is_causal) + + assert 
torch.allclose( + out, ref, atol=5e-3, rtol=1e-5 + ), f"max err: {(out - ref).abs().max()}" + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) +``` + +______________________________________________________________________ + +## Checklist + +- [ ] File placed at `tests/ops/test_.py` +- [ ] `torch.manual_seed(42)` at the top of every test function +- [ ] At least 5 parametrized test cases +- [ ] Both `torch.float16` and `torch.bfloat16` covered (unless op is dtype-specific) +- [ ] Quantization ops include the target quantized dtype +- [ ] Structural parameters varied across cases +- [ ] At least one `tune=True` case if op supports autotuning +- [ ] `atol` / `rtol` appropriate for the op type +- [ ] Reference uses PyTorch built-ins only +- [ ] Test file ends with `if __name__ == "__main__": pytest.main([__file__, "-vvs"])` diff --git a/.claude/skills/create-new-op/skill.md b/.claude/skills/create-new-op/skill.md new file mode 100644 index 00000000..9e06a390 --- /dev/null +++ b/.claude/skills/create-new-op/skill.md @@ -0,0 +1,260 @@ +--- +name: create-new-op +description: Create a new TileOps Op — Op class structure, kernel dispatch via dispatch_kernel, forward method, and __init__.py registration. For attention-specific conventions, also use create-new-op-attention. Auto-invoke when the user asks to create, implement, or add a new Op in TileOps. +--- + +# Skill: Creating a New TileOps Op + +Reference implementations: `tileops/ops/mha.py`, `tileops/ops/gqa_decode.py`, `tileops/ops/deepseek_nsa.py` +Base class: `tileops/ops/op.py` + +______________________________________________________________________ + +## Overview + +An `Op` is a thin orchestration layer that: + +1. Holds kernel instances (one or more `Kernel` subclasses) +1. Dispatches to the correct kernel based on hardware via `dispatch_kernel` +1. 
Exposes a `forward` method that calls the kernel(s) and handles any pre/post-processing + +An `Op` does **not** implement GPU computation itself — that lives in the `Kernel`. + +______________________________________________________________________ + +## File Location and Naming + +``` +tileops/ops/.py +``` + +- File name: `.py`, all lowercase with underscores (e.g. `gqa_decode.py`, `deepseek_nsa.py`) +- Class name: CamelCase + `Op` suffix (e.g. `GroupQueryAttentionDecodeWithKVCacheOp`) +- Group multiple related ops in one file when they share the same kernel set (e.g. `mha.py` contains both `MultiHeadAttentionFwdOp` and `MultiHeadAttentionBwdOp`) +- After creating the file, register all new classes in `tileops/ops/__init__.py` + +______________________________________________________________________ + +## File Structure + +``` +tileops/ops/.py +``` + +A single op file contains: + +1. Imports +1. `__all__` declaration +1. One or more `Op` subclasses + +After creating the file, register the new class in `tileops/ops/__init__.py`. + +______________________________________________________________________ + +## Part 1: Imports + +```python +from typing import Dict, Optional + +import torch + +from tileops.kernels. import +from tileops.kernels.kernel import Kernel + +from .op import Op +``` + +Only import `is_hopper` from `tileops.utils` if the op needs to dispatch different kernels per architecture. 
+ +______________________________________________________________________ + +## Part 2: Op Class + +### Naming + +- CamelCase, descriptive, ending in `Op` +- Examples: `MultiHeadAttentionFwdOp`, `NSATopkVarlenOp`, `GroupQueryAttentionDecodeWithKVCacheOp` + +### `__init__` — two patterns + +**Pattern A: simple (one kernel, few params)** + +Assign each param explicitly, then call `dispatch_kernel` and instantiate the kernel: + +```python +class Op(Op): + + def __init__(self, + param1: int, + param2: int, + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None) -> None: + self.param1 = param1 + self.param2 = param2 + self.dtype = dtype + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map[""]( + param1, param2, dtype, tune=tune) +``` + +**Pattern B: many params (use locals() shortcut)** + +When there are many parameters, use `locals()` to avoid repetition: + +```python + def __init__(self, + param1: int, + ..., + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Kernel]] = None, + ) -> None: + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map')} + for key, value in params.items(): + setattr(self, key, value) + + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map[""](**params) +``` + +> `kernel_map` is always the last parameter, after `tune`. When using `locals()`, exclude both `self` and `kernel_map` from `params`; `tune` is included and passed through to the kernel. `accum_dtype` is never a parameter of the Op — it is hardcoded inside the Kernel. + +### `default_kernel_map` + +Maps string keys to `Kernel` classes. 
This is what `dispatch_kernel` uses to resolve the actual kernel type: + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"": } +``` + +For ops that need different kernels per architecture: + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"": if is_hopper() else } +``` + +For ops with multiple kernels (e.g. fwd + pre/post-process): + +```python + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return { + "preprocess_kernel": PreprocessKernel, + "main_kernel": MainWgmmaPipelinedKernel if is_hopper() else MainKernel, + "postprocess_kernel": PostprocessKernel if not is_hopper() else None, + } +``` + +### `forward` + +Calls `self.kernel(...)` directly (the `Kernel.__call__` delegates to `Kernel.forward`). + +```python + def forward(self, input1: torch.Tensor, ...) -> torch.Tensor: + return self.kernel(input1, ...) +``` + +For ops with pre/post-processing: + +```python + def forward(self, ...) -> tuple[torch.Tensor, ...]: + intermediate = self.prep_kernel(...) + result = self.kernel(..., intermediate) + return self.post_kernel(result) +``` + +For ops that need input validation or padding before calling the kernel (see `gqa_decode.py`): + +```python + def forward(self, q: torch.Tensor, k: torch.Tensor, ...) -> torch.Tensor: + # e.g. pad k/v to match declared seqlen + if k.shape[1] < self.seqlen_kv: + k = F.pad(k, ...) + return self.kernel(q, k, ...) +``` + +______________________________________________________________________ + +## Part 3: Register in `__init__.py` + +Add the import and the class name to `__all__` in `tileops/ops/__init__.py`: + +```python +# import +from . import Op + +# __all__ +__all__ = [ + ... 
+ "<OpName>Op", +] +``` + +______________________________________________________________________ + +## Key Points + +- `dispatch_kernel(kernel_map)` must be called before instantiating any kernel; it validates arch compatibility via `Kernel.supported_archs` +- `self.kernel` is the primary kernel; additional kernels (pre/post) are stored as `self.prep_kernel`, `self.post_kernel`, etc. +- The `Op` does not store `accum_dtype` — that is internal to the `Kernel` +- `dtype` defaults to `torch.float16` unless the op has a specific reason to require explicit dtype +- The `kernel_map` parameter allows callers to inject alternative kernel implementations (for testing or custom dispatch); always default it to `None` +- `tune=False` is always before `kernel_map`; `kernel_map` is always the last `__init__` parameter +- `accum_dtype` is never a parameter of the Op — it is hardcoded inside the Kernel + +______________________________________________________________________ + +## Unit Test Requirement + +Every new op **must** have a passing unit test in `tests/ops/test_<op_name>.py` before it is considered complete. See `.claude/skills/create-new-op-test/skill.md` for how to write the test. + +### Correctness debugging protocol + +When the unit test fails due to numerical mismatch, follow this process in order: + +1. **Do not loosen tolerances.** A large numerical error indicates a real bug, not a precision issue. The PyTorch reference is the ground truth. + +1. **Check the algorithm first.** Compare the tilelang implementation step-by-step against the official reference URL in the kernel's docstring — scale factors, layout (BSHD vs BHSD), masking logic, softmax normalization, etc. + +1. **Check memory layout and indexing.** Verify that tensor shapes, strides, and index expressions in `T.copy` / `T.gemm` match the intended layout. + +1. **Isolate the stage.** If the kernel has multiple stages (e.g. QK gemm → softmax → PV gemm), add intermediate checks to identify which stage produces the wrong result. + +1. 
**Replace tilelang primitives with scalar equivalents for debugging.** If the above steps do not reveal the bug, temporarily replace tilelang primitives with simpler equivalents to rule out compiler/primitive bugs: + + - Replace `T.Pipelined` with `T.serial` (removes software pipelining) + - Replace `T.copy(src, dst)` with a `T.Parallel` loop and element-wise assignment: + ```python + # DEBUG: replaced T.copy with explicit parallel assignment + for i, j in T.Parallel(M, N): + dst[i, j] = src[i, j] + ``` + - Replace `T.gemm(A, B, C)` with explicit `T.serial` + `T.Parallel` accumulation: + ```python + # DEBUG: replaced T.gemm with explicit serial-parallel matmul + for k in T.serial(K): + for i, j in T.Parallel(M, N): + C[i, j] += A[i, k] * B[j, k] # adjust indices for transpose_B + ``` + - Mark all such changes with a `# DEBUG:` comment so they are easy to find and revert. + +1. **Revert debug changes** once the bug is found and fixed. Do not leave `T.serial` / explicit loops in production code. + +______________________________________________________________________ + +## Checklist + +- [ ] Class name is CamelCase and ends with `Op` +- [ ] File placed at `tileops/ops/<op_name>.py` following lowercase underscore naming +- [ ] `dispatch_kernel(kernel_map)` is called before kernel instantiation +- [ ] `default_kernel_map` is implemented and returns at least one entry +- [ ] `forward` calls `self.kernel(...)`, not the raw kernel function +- [ ] Unit test at `tests/ops/test_<op_name>.py` exists and passes +- [ ] Op is imported and added to `__all__` in `tileops/ops/__init__.py` +- [ ] `accum_dtype` is not stored on the Op class +- [ ] `kernel_map: Optional[Dict[str, Kernel]] = None` is the last `__init__` parameter, after `tune` +- [ ] `accum_dtype` is not a parameter of the Op diff --git a/tests/ops/test_gla.py b/tests/ops/test_gla.py new file mode 100644 index 00000000..fffc1d66 --- /dev/null +++ b/tests/ops/test_gla.py @@ -0,0 +1,208 @@ +"""Correctness unit tests for GLAFwdOp. 
+ +Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py +""" + +import pytest +import torch +import torch.nn.functional as F + +from tileops.ops import GLAFwdOp + + +def ref_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + initial_state: torch.Tensor | None = None, + chunk_size: int = 64, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pure PyTorch reference for GLA chunked forward. + + Implements the same 4-stage algorithm as the TileLang kernel: + 1. Within-chunk cumulative sum of log-space gates + 2. Inter-chunk hidden state recurrence with gated decay + 3. Intra-chunk causal attention matrix + 4. Output = inter-chunk (q*exp(g_cs) @ h) + intra-chunk (A @ v) + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + v: [B, T, H, V] + g: [B, T, H, K] log-space gates + scale: query scale factor + initial_state: [B, H, K, V] float32, optional + chunk_size: BT + + Returns: + (o [B, T, H, V], final_state [B, H, K, V] float32) + """ + B, T, H, K = q.shape + V = v.shape[-1] + NT = (T + chunk_size - 1) // chunk_size + + # work in float32 for numerical stability + q = q.float() + k = k.float() + v = v.float() + g = g.float() + + # Stage 1: within-chunk cumulative sum of gates + # g_cs[b, t, h, k] = sum of g[b, chunk_start..t, h, k] + g_cs = torch.zeros_like(g) + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + g_cs[:, cs:ce] = torch.cumsum(g[:, cs:ce], dim=1) + + # Stage 2: inter-chunk hidden state recurrence + # h[b, i_c, h, K, V] = state entering chunk i_c + h_states = torch.zeros(B, NT, H, K, V, dtype=torch.float32, device=q.device) + b_h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + if initial_state is not None: + b_h = initial_state.float().clone() + + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + h_states[:, i_c] = b_h + + # g_last: cumsum at last position of this chunk [B, H, K] + g_last = g_cs[:, ce - 
1] # [B, H, K] + + # decay existing state + # b_h[b, h, k, v] *= exp(g_last[b, h, k]) + b_h = b_h * torch.exp(g_last).unsqueeze(-1) # [B, H, K, V] + + # accumulate: sum over t in chunk of k_adj[t]^T @ v[t] + # k_adj[b, t, h, k] = k[b, t, h, k] * exp(g_last[b, h, k] - g_cs[b, t, h, k]) + k_chunk = k[:, cs:ce] # [B, L, H, K] + v_chunk = v[:, cs:ce] # [B, L, H, V] + g_cs_chunk = g_cs[:, cs:ce] # [B, L, H, K] + k_adj = k_chunk * torch.exp(g_last.unsqueeze(1) - g_cs_chunk) + # b_h += einsum('blhk,blhv->bhkv', k_adj, v_chunk) + b_h = b_h + torch.einsum('blhk,blhv->bhkv', k_adj, v_chunk) + + final_state = b_h # [B, H, K, V] + + # Stage 3 + 4: intra-chunk attention and output + o = torch.zeros(B, T, H, V, dtype=torch.float32, device=q.device) + + for i_c in range(NT): + cs = i_c * chunk_size + ce = min(cs + chunk_size, T) + L = ce - cs + + q_c = q[:, cs:ce] # [B, L, H, K] + k_c = k[:, cs:ce] # [B, L, H, K] + v_c = v[:, cs:ce] # [B, L, H, V] + g_cs_c = g_cs[:, cs:ce] # [B, L, H, K] + h_c = h_states[:, i_c] # [B, H, K, V] + + # intra-chunk attention matrix A[b, i, h, j] = scale * sum_k( + # q[i,k]*exp(g_cs[i,k]) * k[j,k]*exp(-g_cs[j,k]) ), causal + q_gated = q_c * torch.exp(g_cs_c) # [B, L, H, K] + k_gated = k_c * torch.exp(-g_cs_c) # [B, L, H, K] + # A[b, h, i, j] = scale * q_gated[b,i,h,:] @ k_gated[b,j,h,:]^T + # rearrange to [B, H, L, K] for bmm + qg = q_gated.permute(0, 2, 1, 3) # [B, H, L, K] + kg = k_gated.permute(0, 2, 1, 3) # [B, H, L, K] + A = scale * torch.bmm(qg.reshape(B * H, L, K), + kg.reshape(B * H, L, K).transpose(1, 2)).reshape(B, H, L, + L) # [B, H, L, L] + + # causal mask + causal_mask = torch.tril(torch.ones(L, L, device=q.device, dtype=torch.bool)) + A = A * causal_mask.unsqueeze(0).unsqueeze(0) # [B, H, L, L] + + # intra-chunk output: A @ v [B, H, L, V] + vc = v_c.permute(0, 2, 1, 3).reshape(B * H, L, V) + o_intra = torch.bmm(A.reshape(B * H, L, L), + vc).reshape(B, H, L, V).permute(0, 2, 1, 3) # [B, L, H, V] + + # inter-chunk output: scale * 
(q*exp(g_cs)) @ h [B, L, H, V] + # q_gated [B, L, H, K], h_c [B, H, K, V] + o_inter = scale * torch.einsum('blhk,bhkv->blhv', q_gated, h_c) + + o[:, cs:ce] = o_intra + o_inter + + return o, final_state + + +@pytest.mark.parametrize( + "batch, seq_len, heads, dim_k, dim_v, chunk_size, output_final_state, dtype, tune", + [ + (1, 64, 4, 64, 64, 64, False, torch.float16, False), + (2, 128, 8, 64, 64, 64, False, torch.bfloat16, False), + (1, 256, 4, 128, 128, 64, False, torch.float16, False), + (2, 128, 8, 64, 128, 32, False, torch.bfloat16, False), + (4, 256, 16, 64, 64, 64, False, torch.float16, True), + (1, 64, 4, 64, 64, 64, True, torch.float16, False), # with initial_state + final_state + (2, 128, 8, 64, 64, 64, True, torch.bfloat16, False), + ], +) +def test_gla_fwd( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + output_final_state: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + torch.manual_seed(42) + + scale = dim_k**-0.5 + + op = GLAFwdOp( + batch=batch, + seq_len=seq_len, + heads=heads, + dim_k=dim_k, + dim_v=dim_v, + chunk_size=chunk_size, + scale=scale, + output_final_state=output_final_state, + dtype=dtype, + tune=tune, + ) + + q = torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype) + k = torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype) + v = torch.randn(batch, seq_len, heads, dim_v, device='cuda', dtype=dtype) + g = F.logsigmoid(torch.randn(batch, seq_len, heads, dim_k, device='cuda', dtype=dtype)) + + initial_state = None + if output_final_state: + initial_state = torch.randn( + batch, heads, dim_k, dim_v, device='cuda', dtype=torch.float32) * 0.1 + + with torch.no_grad(): + out, out_final = op(q, k, v, g, initial_state) + + ref_o, ref_final = ref_gla_fwd( + q, k, v, g, scale=scale, initial_state=initial_state, chunk_size=chunk_size) + ref_o = ref_o.to(dtype) + + assert torch.allclose(out, ref_o, atol=1e-2, rtol=1e-2), \ + f"output mismatch: max err = 
{(out.float() - ref_o.float()).abs().max():.6f}" + + if output_final_state: + assert out_final is not None + assert torch.allclose(out_final.float(), ref_final.float(), atol=1e-2, rtol=1e-2), \ + f"final_state mismatch: max err = {(out_final.float() - ref_final.float()).abs().max():.6f}" + + +def test_gla_fwd_non_divisible_seq_len() -> None: + """GLAFwdOp must reject seq_len not divisible by chunk_size.""" + with pytest.raises(AssertionError, match="must be divisible"): + GLAFwdOp(batch=1, seq_len=100, heads=4, dim_k=64, dim_v=64, chunk_size=64) + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/gla/__init__.py b/tileops/kernels/gla/__init__.py new file mode 100644 index 00000000..7f1e093f --- /dev/null +++ b/tileops/kernels/gla/__init__.py @@ -0,0 +1,3 @@ +from .gla_fwd import GLAFwdKernel + +__all__ = ["GLAFwdKernel"] diff --git a/tileops/kernels/gla/gla_fwd.py b/tileops/kernels/gla/gla_fwd.py new file mode 100644 index 00000000..971a8f64 --- /dev/null +++ b/tileops/kernels/gla/gla_fwd.py @@ -0,0 +1,374 @@ +import torch +from typing import Optional, Any, Callable + +import tilelang +from tilelang import language as T + +from tileops.kernels.kernel import Kernel + +LOG2_E = 1.44269504 + + +def _gla_fwd_kernel( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, +) -> Callable: + """GLA (Gated Linear Attention) forward kernel. + + Implements chunked GLA forward in 4 stages within a single @T.prim_func: + Stage 1: within-chunk cumulative sum of log-space gates -> g_cumsum + Stage 3: intra-chunk causal attention matrix A = q_gated @ k_gated^T + Stage 4: output o = scale * q_gated @ h_state + A @ v + Stage 2: inter-chunk hidden state recurrence -> h_state (carried across chunks) + + Stages run in order 1→3→4→2 so that stage4 reads h_s before stage2 decays it. 
+ + Args: + q [B, T, H, K] fp16/bf16 queries + k [B, T, H, K] fp16/bf16 keys + v [B, T, H, V] fp16/bf16 values + g [B, T, H, K] fp16/bf16 log-space forget gates + initial_state [B, H, K, V] float32 initial hidden state (zeros if unused) + h_out [B, NT, H, K, V] float32 per-chunk hidden states (output) + o [B, T, H, V] fp16/bf16 output + + Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py + """ + @tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + def _gla_fwd_func(threads: int): + # Shape lists defined inside _gla_fwd_func so the outer closure only + # captures serializable scalars (int/float/str), satisfying tilelang autotuner. + accum_dtype = "float32" + num_chunks = (seq_len + chunk_size - 1) // chunk_size + + q_shape = [batch, seq_len, heads, dim_k] + k_shape = [batch, seq_len, heads, dim_k] + v_shape = [batch, seq_len, heads, dim_v] + g_shape = [batch, seq_len, heads, dim_k] + init_state_shape = [batch, heads, dim_k, dim_v] + h_out_shape = [batch, num_chunks, heads, dim_k, dim_v] + o_shape = [batch, seq_len, heads, dim_v] + + @T.macro + def stage1_cumsum( + g: T.Tensor(g_shape, dtype), + i_b: int, + i_h: int, + i_c: int, + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + ): + """Within-chunk cumulative sum of log-space gates. + Reads g directly from global memory to avoid an extra shared buffer. 
+ """ + chunk_start = i_c * chunk_size + for i_k in T.Parallel(dim_k): + g_cumsum_s[0, i_k] = T.cast(g[i_b, chunk_start, i_h, i_k], accum_dtype) + for i_t in T.Serial(1, chunk_size): + for i_k in T.Parallel(dim_k): + g_cumsum_s[i_t, i_k] = g_cumsum_s[i_t - 1, i_k] + T.cast( + g[i_b, chunk_start + i_t, i_h, i_k], accum_dtype) + + @T.macro + def stage3_intra( + q: T.Tensor(q_shape, dtype), + k: T.Tensor(k_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + A_s: T.Buffer([chunk_size, chunk_size], accum_dtype), + qf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + kf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Intra-chunk causal attention matrix A = q_gated @ k_gated^T. + Reads q and k directly from global memory into qf32_s / kf32_s. + """ + chunk_start = i_c * chunk_size + for i_t, i_k in T.Parallel(chunk_size, dim_k): + qf32_s[i_t, i_k] = T.cast(q[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + g_cumsum_s[i_t, i_k] * LOG2_E) + kf32_s[i_t, i_k] = T.cast(k[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + -g_cumsum_s[i_t, i_k] * LOG2_E) + + A_frag = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + T.fill(A_frag, 0.0) + T.gemm(qf32_s, kf32_s, A_frag, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + for i_t, i_j in T.Parallel(chunk_size, chunk_size): + A_s[i_t, i_j] = T.if_then_else(i_j <= i_t, A_frag[i_t, i_j] * scale, 0.0) + + @T.macro + def stage4_output( + q: T.Tensor(q_shape, dtype), + v: T.Tensor(v_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + A_s: T.Buffer([chunk_size, chunk_size], accum_dtype), + h_s: T.Buffer([dim_k, dim_v], accum_dtype), + o: T.Tensor(o_shape, dtype), + qf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + vf32_s: T.Buffer([chunk_size, dim_v], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Output: o = scale * q_gated @ h_s + A @ v. 
+ Called before stage2 so h_s is the pre-decay state entering this chunk. + Reads q and v directly from global memory. + """ + chunk_start = i_c * chunk_size + for i_t, i_k in T.Parallel(chunk_size, dim_k): + qf32_s[i_t, i_k] = T.cast(q[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + g_cumsum_s[i_t, i_k] * LOG2_E) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + vf32_s[i_t, i_v] = T.cast(v[i_b, chunk_start + i_t, i_h, i_v], accum_dtype) + + acc = T.alloc_fragment([chunk_size, dim_v], accum_dtype) + T.fill(acc, 0.0) + T.gemm(qf32_s, h_s, acc, policy=T.GemmWarpPolicy.FullRow) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + acc[i_t, i_v] = acc[i_t, i_v] * scale + T.gemm(A_s, vf32_s, acc, policy=T.GemmWarpPolicy.FullRow) + + for i_t, i_v in T.Parallel(chunk_size, dim_v): + o[i_b, chunk_start + i_t, i_h, i_v] = T.cast(acc[i_t, i_v], dtype) + + @T.macro + def stage2_recurrence( + k: T.Tensor(k_shape, dtype), + v: T.Tensor(v_shape, dtype), + g_cumsum_s: T.Buffer([chunk_size, dim_k], accum_dtype), + h_s: T.Buffer([dim_k, dim_v], accum_dtype), + h_out: T.Tensor(h_out_shape, accum_dtype), + kf32_s: T.Buffer([chunk_size, dim_k], accum_dtype), + vf32_s: T.Buffer([chunk_size, dim_v], accum_dtype), + i_b: int, + i_h: int, + i_c: int, + ): + """Inter-chunk hidden state recurrence. h_s is carried across chunks. + Saves pre-decay h_s to h_out, then decays and accumulates. + Reads k and v directly from global memory. 
+ """ + chunk_start = i_c * chunk_size + # Save h entering this chunk to h_out + T.copy(h_s, h_out[i_b, i_c, i_h, :, :]) + + g_last = T.alloc_fragment([dim_k], accum_dtype) + for i_k in T.Parallel(dim_k): + g_last[i_k] = g_cumsum_s[chunk_size - 1, i_k] + + # Decay h: h[k, v] *= exp(g_last[k]) + for i_k, i_v in T.Parallel(dim_k, dim_v): + h_s[i_k, i_v] = h_s[i_k, i_v] * T.exp2(g_last[i_k] * LOG2_E) + + # k_adj[t, k] = k[t, k] * exp(g_last[k] - g_cumsum[t, k]) + for i_t, i_k in T.Parallel(chunk_size, dim_k): + kf32_s[i_t, i_k] = T.cast(k[i_b, chunk_start + i_t, i_h, i_k], + accum_dtype) * T.exp2( + (g_last[i_k] - g_cumsum_s[i_t, i_k]) * LOG2_E) + for i_t, i_v in T.Parallel(chunk_size, dim_v): + vf32_s[i_t, i_v] = T.cast(v[i_b, chunk_start + i_t, i_h, i_v], accum_dtype) + + # h += k_adj^T @ v [K, V] + delta_h = T.alloc_fragment([dim_k, dim_v], accum_dtype) + T.fill(delta_h, 0.0) + T.gemm(kf32_s, vf32_s, delta_h, transpose_A=True, policy=T.GemmWarpPolicy.FullRow) + for i_k, i_v in T.Parallel(dim_k, dim_v): + h_s[i_k, i_v] = h_s[i_k, i_v] + delta_h[i_k, i_v] + + @T.prim_func + def _main( + q: T.Tensor(q_shape, dtype), + k: T.Tensor(k_shape, dtype), + v: T.Tensor(v_shape, dtype), + g: T.Tensor(g_shape, dtype), + initial_state: T.Tensor(init_state_shape, accum_dtype), + h_out: T.Tensor(h_out_shape, accum_dtype), + o: T.Tensor(o_shape, dtype), + ): + with T.Kernel(batch * heads, threads=threads) as bx: + i_b = bx // heads + i_h = bx % heads + + # Persistent buffers across the chunk loop. 
+ # Shared memory budget (dim_k=128, dim_v=128, chunk_size=64): + # h_s(65536) + g_cumsum_s(32768) + A_s(16384) + # + qf32_s(32768) + kf32_s(32768) + vf32_s(32768) = 212992 < 232448 + h_s = T.alloc_shared([dim_k, dim_v], accum_dtype) + g_cumsum_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + A_s = T.alloc_shared([chunk_size, chunk_size], accum_dtype) + # Scratch buffers reused across stages each iteration + qf32_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + kf32_s = T.alloc_shared([chunk_size, dim_k], accum_dtype) + vf32_s = T.alloc_shared([chunk_size, dim_v], accum_dtype) + + T.copy(initial_state[i_b, i_h, :, :], h_s) + + for i_c in T.Serial(num_chunks): + stage1_cumsum(g, i_b, i_h, i_c, g_cumsum_s) + stage3_intra(q, k, g_cumsum_s, A_s, qf32_s, kf32_s, i_b, i_h, i_c) + # stage4 runs before stage2: h_s is still the pre-decay state + stage4_output(q, v, g_cumsum_s, A_s, h_s, o, qf32_s, vf32_s, i_b, i_h, i_c) + # stage2 decays h_s and accumulates k^T v; saves pre-decay to h_out + stage2_recurrence(k, v, g_cumsum_s, h_s, h_out, kf32_s, vf32_s, i_b, i_h, + i_c) + + # Overwrite the last h_out slot with the fully-updated final state + T.copy(h_s, h_out[i_b, num_chunks - 1, i_h, :, :]) + + return _main + + return _gla_fwd_func + + +@torch.library.custom_op("top::gla_fwd_wrapped_kernel", mutates_args=("h_out",)) # h_out is written in place by the kernel; only o is returned +def _gla_fwd_wrapped_kernel( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor, + h_out: torch.Tensor, +) -> torch.Tensor: + return _gla_fwd_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, scale, + dtype)(threads)(q, k, v, g, initial_state, h_out) + + +@_gla_fwd_wrapped_kernel.register_fake +def _( + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int, + scale: float, + dtype: str, + threads: int, + *inputs: tuple[Any], +) 
-> torch.Tensor: + _ = (dim_k, chunk_size, scale, dtype, threads) + return torch.empty([batch, seq_len, heads, dim_v], + dtype=inputs[0].dtype, + device=inputs[0].device) + + +class GLAFwdKernel(Kernel): + """GLA (Gated Linear Attention) forward kernel. + + Args: + q: Query tensor, shape [batch, seq_len, heads, dim_k] + k: Key tensor, shape [batch, seq_len, heads, dim_k] + v: Value tensor, shape [batch, seq_len, heads, dim_v] + g: Log-space forget gates, shape [batch, seq_len, heads, dim_k] + initial_state: Optional initial hidden state, shape [batch, heads, dim_k, dim_v] + + Computation: + Chunked GLA forward in 4 TileLang stages per chunk, fused in a single + T.Serial(NT) loop. Stages run in order 1→3→4→2 so that stage4 reads + the pre-decay hidden state before stage2 updates it: + 1. Within-chunk cumulative sum of gates -> g_cumsum + 3. Intra-chunk causal attention A = scale * q_gated @ k_gated^T (causal masked) + 4. Output o = scale * q_gated @ h + A @ v (h is pre-decay) + 2. Inter-chunk hidden state recurrence h += k_adj^T @ v (h carried in shared memory) + + Reference: + https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gla/chunk.py + """ + + supported_archs: list[int] = [90] + + def __init__( + self, + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int = 64, + scale: float = -1.0, + output_final_state: bool = False, + dtype: torch.dtype = torch.float16, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + self.batch = batch + self.seq_len = seq_len + self.heads = heads + self.dim_k = dim_k + self.dim_v = dim_v + self.chunk_size = chunk_size + self.scale = scale if scale > 0 else dim_k**-0.5 + self.output_final_state = output_final_state + self.dtype_name = str(dtype).split('.')[-1] + self.kernel = _gla_fwd_kernel(batch, seq_len, heads, dim_k, dim_v, chunk_size, self.scale, + self.dtype_name) + self.init_config(config, tune) + + @property + def default_config(self) -> 
dict: + return {"threads": 128} + + @property + def autotune_configs(self) -> list[dict]: + return [{"threads": t} for t in [64, 128, 256]] + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + B, T, H, K = self.batch, self.seq_len, self.heads, self.dim_k + V = self.dim_v + BT = self.chunk_size + NT = (T + BT - 1) // BT + dtype_torch = getattr(torch, self.dtype_name) + + if initial_state is None: + init_state = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + else: + init_state = initial_state.to(torch.float32) + + h_out = torch.empty(B, NT, H, K, V, dtype=torch.float32, device=q.device) + + o = _gla_fwd_wrapped_kernel( + B, T, H, K, V, BT, self.scale, self.dtype_name, self.config["threads"], + q.to(dtype_torch), + k.to(dtype_torch), + v.to(dtype_torch), + g.to(dtype_torch), + init_state, + h_out, + ) + + final_state = h_out[:, -1] if self.output_final_state else None + return o, final_state diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index d963a62a..6dc6e6bf 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -1,48 +1,50 @@ -from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp -from .fp8_lighting_indexer import Fp8LightingIndexerOp -from .topk_selector import TopkSelectorOp -from .fp8_quant import Fp8QuantOp -from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp -from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp -from .gemm import GemmOp -from .gemv import GemvOp -from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp -from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp -from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp -from .grouped_gemm import GroupedGemmNNOp, GroupedGemmNTOp, GroupedGemmTNOp, 
GroupedGemmTTOp -from .mha import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp -from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp -from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp -from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp -from .mhc_post import ManifoldConstrainedHyperConnectionPostOp -from .op import Op # noqa: F401 - -__all__ = [ - "Op", - "MultiHeadAttentionFwdOp", - "MultiHeadAttentionBwdOp", - "GroupQueryAttentionFwdOp", - "GroupQueryAttentionBwdOp", - "GemmOp", - "GemvOp", - "MultiHeadAttentionDecodeWithKVCacheOp", - "MultiHeadAttentionDecodePagedWithKVCacheOp", - "GroupQueryAttentionDecodeWithKVCacheOp", - "GroupQueryAttentionDecodePagedWithKVCacheOp", - "GroupedGemmNTOp", - "GroupedGemmNNOp", - "GroupedGemmTNOp", - "GroupedGemmTTOp", - "MultiHeadLatentAttentionDecodeWithKVCacheOp", - "DeepSeekSparseAttentionDecodeWithKVCacheOp", - "Fp8LightingIndexerOp", - "TopkSelectorOp", - "Fp8QuantOp", - "MeanPoolingForwardOp", - "NSATopkVarlenOp", - "NSAFwdVarlenOp", - "NSACmpFwdVarlenOp", - "GQAWindowSlidingOp", - "ManifoldConstrainedHyperConnectionPreOp", - "ManifoldConstrainedHyperConnectionPostOp", -] +from .deepseek_dsa_decode import DeepSeekSparseAttentionDecodeWithKVCacheOp +from .fp8_lighting_indexer import Fp8LightingIndexerOp +from .topk_selector import TopkSelectorOp +from .fp8_quant import Fp8QuantOp +from .deepseek_mla_decode import MultiHeadLatentAttentionDecodeWithKVCacheOp +from .deepseek_nsa import MeanPoolingForwardOp, NSAFwdVarlenOp, NSATopkVarlenOp, NSACmpFwdVarlenOp, GQAWindowSlidingOp +from .gemm import GemmOp +from .gemv import GemvOp +from .gqa import GroupQueryAttentionBwdOp, GroupQueryAttentionFwdOp +from .gqa_decode import GroupQueryAttentionDecodeWithKVCacheOp +from .gqa_decode_paged import GroupQueryAttentionDecodePagedWithKVCacheOp +from .grouped_gemm import GroupedGemmNNOp, GroupedGemmNTOp, GroupedGemmTNOp, GroupedGemmTTOp +from .mha import MultiHeadAttentionBwdOp, 
MultiHeadAttentionFwdOp +from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp +from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp +from .mhc_pre import ManifoldConstrainedHyperConnectionPreOp +from .mhc_post import ManifoldConstrainedHyperConnectionPostOp +from .gla import GLAFwdOp +from .op import Op # noqa: F401 + +__all__ = [ + "Op", + "MultiHeadAttentionFwdOp", + "MultiHeadAttentionBwdOp", + "GroupQueryAttentionFwdOp", + "GroupQueryAttentionBwdOp", + "GemmOp", + "GemvOp", + "MultiHeadAttentionDecodeWithKVCacheOp", + "MultiHeadAttentionDecodePagedWithKVCacheOp", + "GroupQueryAttentionDecodeWithKVCacheOp", + "GroupQueryAttentionDecodePagedWithKVCacheOp", + "GroupedGemmNTOp", + "GroupedGemmNNOp", + "GroupedGemmTNOp", + "GroupedGemmTTOp", + "MultiHeadLatentAttentionDecodeWithKVCacheOp", + "DeepSeekSparseAttentionDecodeWithKVCacheOp", + "Fp8LightingIndexerOp", + "TopkSelectorOp", + "Fp8QuantOp", + "MeanPoolingForwardOp", + "NSATopkVarlenOp", + "NSAFwdVarlenOp", + "NSACmpFwdVarlenOp", + "GQAWindowSlidingOp", + "ManifoldConstrainedHyperConnectionPreOp", + "ManifoldConstrainedHyperConnectionPostOp", + "GLAFwdOp", +] diff --git a/tileops/ops/gla.py b/tileops/ops/gla.py new file mode 100644 index 00000000..f51cd38d --- /dev/null +++ b/tileops/ops/gla.py @@ -0,0 +1,90 @@ +"""GLA (Gated Linear Attention) Forward Op.""" + +from typing import Any, Dict, Optional + +import torch + +from tileops.ops.op import Op + + +class GLAFwdOp(Op): + """Op wrapper for GLA (Gated Linear Attention) forward pass. + + Dispatches to GLAFwdKernel which implements the 4-stage chunked forward: + cumulative gate sum -> inter-chunk recurrence -> intra-chunk attention -> output. + + Args: + batch: Batch size B. + seq_len: Sequence length T. Must be divisible by chunk_size. + heads: Number of query/key/value heads H. + dim_k: Key/query head dimension K. + dim_v: Value head dimension V. + chunk_size: Chunk size BT (default 64). + scale: Query scale factor. 
Defaults to 1/sqrt(dim_k). + output_final_state: If True, also return the final hidden state. + dtype: Input tensor dtype (default torch.float16). + tune: Whether to run kernel autotuning. + kernel_map: Optional override for the kernel dispatch map. + + Example: + >>> op = GLAFwdOp(batch=2, seq_len=128, heads=8, dim_k=64, dim_v=64) + >>> q = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> k = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> v = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16) + >>> g = torch.nn.functional.logsigmoid( + ... torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)) + >>> o, final_state = op(q, k, v, g) + """ + + def __init__( + self, + batch: int, + seq_len: int, + heads: int, + dim_k: int, + dim_v: int, + chunk_size: int = 64, + scale: Optional[float] = None, + output_final_state: bool = False, + dtype: torch.dtype = torch.float16, + tune: bool = False, + kernel_map: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + assert seq_len % chunk_size == 0, ( + f"seq_len ({seq_len}) must be divisible by chunk_size ({chunk_size})") + params = {k: v for k, v in locals().items() if k not in ('self', 'kernel_map', '__class__')} + # resolve default scale before storing + if params['scale'] is None: + params['scale'] = dim_k**-0.5 + for k, v in params.items(): + setattr(self, k, v) + self.dispatch_kernel(kernel_map) + self.kernel = self.kernel_map["gla_fwd"](**params) + + @property + def default_kernel_map(self) -> Dict[str, Any]: + from tileops.kernels.gla import GLAFwdKernel + return {"gla_fwd": GLAFwdKernel} + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run GLA forward. + + Args: + q: Queries [B, T, H, K]. + k: Keys [B, T, H, K]. + v: Values [B, T, H, V]. + g: Log-space forget gates [B, T, H, K]. 
+ initial_state: Optional initial hidden state [B, H, K, V] float32. + + Returns: + Tuple of (o [B, T, H, V], final_state [B, H, K, V] or None). + """ + return self.kernel(q, k, v, g, initial_state)