
Conversation

@sustcsonglin sustcsonglin (Collaborator) commented Nov 22, 2025

python benchmarks/ops/benchmark_kda_intra.py on H200

| Shape | Token (Original) | Recursive (New) | Recurrent | Speedup (Rec/Tok) |
| --- | --- | --- | --- | --- |
| B=8, T=4096, H=16, K=128 | 1.206 ms | 1.246 ms | 1.361 ms | 0.97x |
| B=1, T=8192, H=16, K=128 | 0.330 ms | 0.346 ms | 0.373 ms | 0.95x |
| B=8, T=4096, H=32, K=64 | 1.743 ms | 1.502 ms | 1.905 ms | 1.16x |
| B=1, T=8192, H=32, K=64 | 0.466 ms | 0.406 ms | 0.504 ms | 1.15x |
| B=32, T=512, H=12, K=64 | 0.363 ms | 0.322 ms | 0.387 ms | 1.13x |
| B=2, T=4096, H=8, K=256 | 0.273 ms | 0.677 ms | 0.333 ms | 0.40x |

Summary by CodeRabbit

  • New Features

    • Added benchmarking capabilities for KDA intra-chunk forward passes across multiple implementation strategies.
    • Enabled multi-head key-value grouping support in linear attention layer with improved parameter validation.
    • Introduced configurable computation strategies for intra-chunk operations with auto-detection.
  • Refactor

    • Updated computation paths for KDA intra-chunk operations to support multiple implementation strategies.


…election.

- Merged Token Parallel and Recursive Block implementations into chunk_intra.py.
- Added 'recurrent' (naive sequential) implementation for verification.
- Updated chunk_kda_fwd_intra to use 'impl_type' argument and automatically select between 'token' and 'recursive' based on head dimension K (threshold K=128).
- Added comprehensive benchmark script in benchmarks/ops/benchmark_kda_intra.py covering various shapes.
- Performance: Recursive Block achieves ~15% speedup for small K (<=64), while Token Parallel remains superior for large K.
@coderabbitai coderabbitai bot commented Nov 22, 2025

Walkthrough

This PR adds a new KDA intra-chunk benchmarking script, introduces KV-group factorization support to ReBasedLinearAttention with divisibility validations and adjusted projection dimensions, and extends the intra-chunk KDA kernel with multiple implementation strategies (token-parallel, recursive, recurrent) selectable via an impl_type parameter.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Benchmarking Infrastructure<br/>benchmarks/ops/benchmark_kda_intra.py | New script that benchmarks three KDA intra-chunk forward implementations (token, recursive, recurrent) with configurable tensor sizes, warmup loops, and formatted timing output including speedup metrics. |
| Layer Architecture Updates<br/>fla/layers/rebased.py | Added KV-group support with input validation (divisibility checks), new attributes (num_kv_groups, head_dim), and adjusted projection shapes to use num_key_value_heads instead of num_heads; KV tensors expanded via repeat when multiple groups exist. |
| Kernel Implementation Variants<br/>fla/ops/kda/chunk_intra.py | Replaced use_token_parallel flag with configurable impl_type parameter ("auto", "token", "recursive", "recurrent"); added new recurrent and token-parallel kernels; dispatches to appropriate implementation based on impl_type and feature dimension; fixed block size parameterization with BC=16. |

Sequence Diagram

sequenceDiagram
    participant User
    participant chunk_kda_fwd_intra as chunk_kda_fwd_intra<br/>(impl_type)
    participant TokenParallel as Token-Parallel<br/>Kernel
    participant Recurrent as Recurrent<br/>Kernel
    participant Recursive as Recursive<br/>Kernel (default)

    User->>chunk_kda_fwd_intra: Call with q, k, g, impl_type
    
    alt impl_type == "token"
        chunk_kda_fwd_intra->>TokenParallel: Dispatch
        TokenParallel->>TokenParallel: Compute with token-parallel strategy
    else impl_type == "recurrent"
        chunk_kda_fwd_intra->>Recurrent: Dispatch (3D grid: NT, NC, B×H)
        Recurrent->>Recurrent: Per-token recurrence accumulation<br/>with gating parameters
    else impl_type == "auto" or "recursive"
        chunk_kda_fwd_intra->>chunk_kda_fwd_intra: Auto-detect:<br/>K≥128 → "token",<br/>else → "recursive"
        alt Auto-selected "token"
            chunk_kda_fwd_intra->>TokenParallel: Dispatch
        else Auto-selected "recursive"
            chunk_kda_fwd_intra->>Recursive: Dispatch (original path)
        end
    end
    
    rect rgb(200, 220, 255)
        note over TokenParallel,Recursive: BC=16 fixed,<br/>BK computed accordingly
    end
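
In plain Python, the selection logic sketched in the diagram reduces to something like the snippet below. This is only a hypothetical distillation of the dispatch described above; the real chunk_kda_fwd_intra in fla/ops/kda/chunk_intra.py takes the q/k/g/beta tensors and launches the corresponding Triton kernels.

```python
def resolve_impl(impl_type: str, K: int) -> str:
    """Mirror of the auto-selection heuristic described in this PR (sketch only)."""
    if impl_type == "auto":
        # Token-parallel wins for large head dimensions; recursive wins for small ones.
        return "token" if K >= 128 else "recursive"
    if impl_type not in ("token", "recursive", "recurrent"):
        raise ValueError(f"Unsupported impl_type={impl_type!r}")
    return impl_type

assert resolve_impl("auto", K=128) == "token"
assert resolve_impl("auto", K=64) == "recursive"
```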

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • fla/ops/kda/chunk_intra.py: New kernels with recurrent/gating logic and modified dispatch control flow require careful verification of accumulation semantics and kernel grid configurations.
  • fla/layers/rebased.py: Input validation and KV-group expansion via repeat need attention to ensure shape correctness in downstream attention computations.
  • benchmarks/ops/benchmark_kda_intra.py: Low complexity but should verify benchmark configurations align with layer changes and kernel expectations.

Poem

🐰 Chunking atoms, kernels bright,
With KV groups in perfect sight,
Token, recursive, recurrence's call,
Three paths forward, we benchmark them all!
Fast and faster, the benchmarks roll,
Helping us reach our performance goal. ✨

Pre-merge checks

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 10.00%, which is below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title Check | ✅ Passed | The title '[kda] add recursive block intra implementation' directly and specifically describes the main change: adding a recursive block intra implementation to the KDA module, as evidenced by the new kernels, the impl_type parameter, and the benchmarking script in the changeset. |

@gemini-code-assist gemini-code-assist bot commented

Summary of Changes

Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the KDA implementation by introducing a new "recursive block intra" approach, which demonstrates performance improvements in several benchmark scenarios. It also refactors the KDA intra-chunk forward pass to support multiple implementation types, including the new recursive method, and updates the ReBasedLinearAttention module to incorporate Grouped Query Attention (GQA) capabilities. These changes aim to boost efficiency and flexibility in attention computations.

Highlights

  • New Recursive KDA Intra Implementation: Introduced a novel "recursive block intra" implementation for the KDA attention mechanism, designed for improved performance in specific configurations.
  • Performance Benchmarking: Added a new benchmark script (benchmark_kda_intra.py) to compare the performance of "token", "recursive", and "recurrent" KDA intra implementations across various input shapes and dimensions.
  • Grouped Query Attention (GQA) Support: Enhanced the ReBasedLinearAttention module to support Grouped Query Attention (GQA) by modifying key/value projection and replication logic for more efficient attention computations.
  • Flexible Implementation Selection: The chunk_kda_fwd_intra function now allows explicit selection of the KDA intra implementation type ("token", "recursive", "recurrent") or automatic selection based on the K dimension.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new recursive implementation for the intra-chunk KDA kernel, which demonstrates performance benefits for smaller head dimensions. It also refactors the kernel selection logic, allowing for dynamic switching between token, recursive, and recurrent implementations. Furthermore, the changes correctly implement Grouped-Query Attention (GQA) in the ReBasedLinearAttention layer, addressing some important dimension-related bugs. The new benchmark script is a valuable addition for performance verification. My review includes a few suggestions to enhance code clarity, maintainability, and robustness, particularly within the new Triton kernel and the implementation selection logic.

Comment on lines +174 to +176
# For BC=64, we need to handle span=32 (log2=5).
# Starting from 6 is safe for BC up to 128.
for log_span in range(3, -1, -1):

Severity: high

The comment on lines 174-175 is misleading as it refers to BC=64, but the implementation and the loop on line 176 are for BC=16. To improve clarity and maintainability, the comment should be updated, and the loop range should be derived from the BC constant rather than being hardcoded.

Suggested change
-# For BC=64, we need to handle span=32 (log2=5).
-# Starting from 6 is safe for BC up to 128.
-for log_span in range(3, -1, -1):
+# For BC=16, we need to handle span=8 (log2=3).
+# The loop range is `log2(BC)-1` down to 0.
+for log_span in range(tl.static_log2(BC) - 1, -1, -1):

b_g = tl.load(p_g, boundary_check=(0, 1))

b_k = b_k * tl.load(beta + (i_t * BT + i_i * BC + o_i) * H, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504

Severity: medium

The magic number 1.44269504 is used here. This corresponds to log2(e) and is used to convert the base of the exponentiation from e to 2. It would be beneficial for readability to either define this as a named constant at the module level (e.g., LOG2_E = 1.44269504) or add a comment explaining its purpose.

Suggested change
-b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504
+b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504  # log2(e)
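
For context, the identity behind the constant is exp(x) = 2^(x · log2(e)), which is why scaling the gates by 1.44269504 before tl.exp2 reproduces the original natural-exponential formulation:

```python
import math

LOG2_E = 1.44269504  # ≈ math.log2(math.e), the constant used in the kernel
x = 0.37
assert abs(math.exp(x) - 2.0 ** (x * LOG2_E)) < 1e-8
```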

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
fla/layers/rebased.py (2)

147-152: forward_reference breaks when num_key_value_heads != num_heads

Here q, k, v are reshaped as:

  • q: [B, T, feature_dim * num_heads] → [B, num_heads, T, feature_dim]
  • k: [B, T, feature_dim * num_key_value_heads] → [B, num_key_value_heads, T, feature_dim]
  • v: [B, T, num_key_value_heads * head_dim] → [B, num_key_value_heads, T, head_dim]

The subsequent attention computation assumes the same head dimension for q, k, and v, so when num_key_value_heads != num_heads the math and broadcasting are wrong (and can be outright invalid in PyTorch).

You should either:

  • Restrict forward_reference to configurations with num_key_value_heads == num_heads (and assert that), or
  • Update it to apply the same KV-group expansion (repeat) that forward uses so that k and v are broadcast to num_heads.

103-118: Unpack return values from both function calls

The verification confirms the issue. Both chunk_linear_attn() and fused_chunk_linear_attn() have return type tuple[torch.Tensor, torch.Tensor] and always return (o, final_state). The current code assigns these tuples directly to o without unpacking, but line ~131 calls rearrange(o, "... h d -> ... (h d)"), which expects a tensor and will fail at runtime.

Fix the assignments at lines 104-107 and 109-112:

o, _ = fused_chunk_linear_attn(
    q=q,
    k=k,
    v=v,
    normalize=True,
    scale=1,
)
# and
o, _ = chunk_linear_attn(
    q=q,
    k=k,
    v=v,
    normalize=True,
    scale=1,
)
🧹 Nitpick comments (9)
fla/layers/rebased.py (1)

98-102: Feature-map + KV repeat ordering is correct, but consider documenting

Applying feature_map to q and k before repeating k/v over KV groups is efficient and avoids redundant normalization work. Because this subtly changes where normalization happens relative to the grouping, a brief inline comment (e.g. “normalize once per KV-head, then share across groups”) would help future readers.

benchmarks/ops/benchmark_kda_intra.py (2)

19-33: Use torch.no_grad() to be explicit about benchmarking forward-only

Even though torch.randn defaults to requires_grad=False, wrapping benchmarked calls in a torch.no_grad() context can make it explicit this is a pure-forward perf script and avoid accidental grad tracking if the code changes.

For example:

with torch.no_grad():
    for _ in range(10):
        ...

1-3: Address trailing whitespace / pre-commit lint

GitHub Actions reports trailing whitespace stripped by pre-commit in this file. After applying code changes, please run the repo’s pre-commit hooks (or the configured formatter) locally to keep CI green.

fla/ops/kda/chunk_intra.py (6)

159-223: Recursive sub-intra kernel logic is internally consistent

Within chunk_kda_fwd_kernel_intra_sub_intra:

  • Converting b_g to base‑2 exponents via * 1.44269504 and then using tl.exp2 is mathematically consistent with the previous exp-based formulation and keeps everything in float32.
  • Diagonal contributions are seeded via b_Aqk_diag before the span loop, and the subsequent log_span loop builds larger spans bottom‑up using masks (is_q, is_k, same_block).
  • Final triangular masks
    • Aqk: o_i[:, None] >= o_i[None, :]
    • Akk: o_i[:, None] > o_i[None, :] with b_beta
      enforce causal structure and zero out unwanted regions before storing.

The only nit is that comments still mention BC=64/log2=5 while the implementation hardcodes BC=16 and loops log_span in range(3, -1, -1). Updating the comment would avoid confusion.


450-520: New recurrent sub-intra kernel appears shape- and index-safe

chunk_kda_fwd_kernel_intra_sub_intra_recurrent:

  • Correctly handles varlen vs fixed‑len via chunk_indices/cu_seqlens and early‑returns when i_t * BT + i_i * BC >= T.
  • Offsets q/k/g/beta/Aqk/Akk in the same way as other kernels, keeping head/time mapping consistent.
  • Loop over j advances p_kt/p_gk by H*K per step, which matches the [T, H, K] physical layout.
  • Final zeroing of:
    • Aqk for o_i[:, None] < o_i[None, :] (upper triangle)
    • Akk for o_i[:, None] <= o_i[None, :] (diag + upper)
      ensures the expected causal triangular structure after the per‑row recurrence.

Given the mathematical subtlety, it would be good to add or extend tests that compare impl_type="recurrent" against "token"/"recursive" on small random inputs to validate numerical equivalence.
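
A possible shape for such a test is sketched below; it assumes chunk_kda_fwd_intra returns the (Aqk, Akk) pair and accepts the same arguments used in the benchmark snippet later in this review, so adapt the call and tolerances to the actual API.

```python
import torch

from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra


def test_kda_intra_impls_match(B=2, T=256, H=4, K=64, chunk_size=64):
    torch.manual_seed(0)
    q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    g = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
    beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)

    # Treat the naive sequential path as the reference.
    ref = chunk_kda_fwd_intra(q, k, g, beta, scale=K**-0.5, chunk_size=chunk_size, impl_type="recurrent")
    for impl in ("token", "recursive"):
        out = chunk_kda_fwd_intra(q, k, g, beta, scale=K**-0.5, chunk_size=chunk_size, impl_type=impl)
        for a, b in zip(ref, out):
            # Loose tolerances to absorb bf16 accumulation differences.
            torch.testing.assert_close(a.float(), b.float(), rtol=1e-2, atol=1e-2)
```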


724-761: impl_type API and auto-selection heuristic are reasonable

  • Adding impl_type: str = "auto" gives a clean way to choose between "token", "recursive", and "recurrent" without changing call sites.
  • The auto-heuristic impl_type = "token" if K >= 128 else "recursive" is simple and matches the benchmark intent (different performance regimes by head dim).

Consider validating impl_type values and raising a clear error on unknown strings to avoid silently falling into the else branch:

valid_impls = {"auto", "token", "recursive", "recurrent"}
if impl_type not in valid_impls:
    raise ValueError(f"Unsupported impl_type={impl_type!r}, expected one of {sorted(valid_impls)}")

767-769: Hardcoding BC = 16 is fine but deserves a short comment

Forward now fixes BC = 16 while backward still computes BC = min(16, BT). That’s fine for the current configs (BT=64 in benchmarks), but a brief comment like “BC fixed to 16 for fwd kernels; adjust if chunk_size is tuned beyond 64” would document the coupling between forward and backward tiling.


842-844: Minor consistency nit between forward and backward BC/BK choices

Forward uses BC = 16 and BK = max(triton.next_power_of_2(K), 16), while backward uses BC = min(16, BT) and BK = min(32, triton.next_power_of_2(K)). That’s likely intentional for numerical stability/performance, but it’s worth double-checking that all production call sites stay within the tuned regimes (e.g., BT >= 16, K <= 256 as asserted) and documenting the assumptions in a module-level comment.


1-5: Re-run pre-commit to satisfy trailing-whitespace lint

CI reports trailing whitespace removed near the top of this file by pre-commit. After applying kernel changes, please re-run the repo’s formatting/lint hooks locally to keep this file aligned with style checks.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 71260ec and c6bd19e.

📒 Files selected for processing (3)
  • benchmarks/ops/benchmark_kda_intra.py (1 hunks)
  • fla/layers/rebased.py (5 hunks)
  • fla/ops/kda/chunk_intra.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
benchmarks/ops/benchmark_kda_intra.py (1)
fla/ops/kda/chunk_intra.py (1)
  • chunk_kda_fwd_intra (714-852)
fla/layers/rebased.py (4)
fla/modules/feature_map.py (1)
  • RebasedFeatureMap (204-243)
fla/ops/linear_attn/chunk.py (1)
  • chunk_linear_attn (11-74)
fla/ops/linear_attn/fused_chunk.py (1)
  • fused_chunk_linear_attn (11-59)
fla/ops/rebased/parallel.py (1)
  • parallel_rebased (439-463)
fla/ops/kda/chunk_intra.py (2)
fla/ops/kda/chunk_intra_token_parallel.py (3)
  • chunk_kda_fwd_kernel_intra_token_parallel (25-155)
  • chunk_kda_fwd_intra_token_parallel (158-219)
  • grid (200-202)
fla/ops/utils/index.py (1)
  • prepare_chunk_indices (114-119)
🪛 Flake8 (7.3.0)
benchmarks/ops/benchmark_kda_intra.py

[error] 41-41: local variable 'e' is assigned to but never used

(F841)

🪛 GitHub Actions: lint
benchmarks/ops/benchmark_kda_intra.py

[error] 1-1: Trailing whitespace detected and removed by pre-commit trailing-whitespace hook. Re-run pre-commit to confirm.

fla/ops/kda/chunk_intra.py

[error] 1-1: Trailing whitespace detected and removed by pre-commit trailing-whitespace hook. Re-run pre-commit to confirm.

🪛 Ruff (0.14.5)
benchmarks/ops/benchmark_kda_intra.py

41-41: Do not catch blind exception: Exception

(BLE001)


41-41: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


67-67: Do not catch blind exception: Exception

(BLE001)

fla/layers/rebased.py

54-54: Avoid specifying long messages outside the exception class

(TRY003)


56-56: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
  • GitHub Check: check-pt-python-compatibility
🔇 Additional comments (10)
fla/layers/rebased.py (5)

11-11: Import of repeat is appropriate and scoped correctly

Using einops.repeat for KV-group expansion is idiomatic here and keeps head-axis reshaping readable.


53-57: Divisibility checks for heads/KV-heads look correct

Validating hidden_size % num_heads == 0 and num_heads % num_key_value_heads == 0 early is the right guardrail for KV-grouping; messages are clear enough for debugging.


60-62: New KV-grouping attributes are consistent

num_key_value_heads, num_kv_groups = num_heads // num_key_value_heads, and head_dim = hidden_size // num_heads are coherent and match the projection shapes below.


71-75: Projection shapes align with GQA design

  • q_proj: hidden_size → feature_dim * num_heads
  • k_proj: hidden_size → feature_dim * num_key_value_heads
  • v_proj: hidden_size → num_key_value_heads * head_dim
  • o_proj: num_heads * head_dim → hidden_size

These choices are consistent with grouped KV (GQA) and with the later rearrange/repeat usage.


86-97: KV grouping expansion strategy looks sound

  • k is reshaped with h=self.num_key_value_heads.
  • v is reshaped with h=self.num_key_value_heads, d=self.head_dim.
  • When num_kv_groups > 1, repeating along h to (h * num_kv_groups) for both k and v correctly broadcasts KV heads to all query heads.

This keeps all three tensors at H = num_heads before entering the attention kernels.
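
As a standalone illustration of that expansion (shapes, layout, and the einops patterns here are assumptions chosen for clarity, not code lifted from the layer):

```python
import torch
from einops import rearrange, repeat

# Illustrative configuration; the layer derives these from its constructor arguments.
B, T, hidden_size = 2, 8, 256
num_heads, num_key_value_heads, feature_dim = 8, 2, 16
num_kv_groups = num_heads // num_key_value_heads
head_dim = hidden_size // num_heads

k = torch.randn(B, T, feature_dim * num_key_value_heads)
v = torch.randn(B, T, num_key_value_heads * head_dim)

# Split projection outputs into per-KV-head layout.
k = rearrange(k, "b t (h d) -> b h t d", h=num_key_value_heads)
v = rearrange(v, "b t (h d) -> b h t d", h=num_key_value_heads, d=head_dim)

# Broadcast each KV head across its query-head group so both tensors end up with num_heads heads.
if num_kv_groups > 1:
    k = repeat(k, "b h t d -> b (h g) t d", g=num_kv_groups)
    v = repeat(v, "b h t d -> b (h g) t d", g=num_kv_groups)

assert k.shape[1] == num_heads and v.shape[1] == num_heads
```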

benchmarks/ops/benchmark_kda_intra.py (2)

6-18: Benchmark setup is reasonable for perf comparisons

Random q/k/g/beta allocations with bfloat16 on CUDA and a fixed chunk_size give a clean microbenchmark of the three impl_type variants. Warmup and reuse of the same tensors across runs are appropriate.


64-68: Same unused-exception pattern at top-level benchmark loop

Apply the same adjustment here to satisfy F841/BLE001:

-    for B, T, H, K in configs:
-        try:
-            benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64)
-        except Exception as e:
-            print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {e}")
+    for B, T, H, K in configs:
+        try:
+            benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64)
+        except Exception as exc:
+            print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {exc}")

Here it’s actually useful to keep the exception object and log it.

Likely an incorrect or invalid review comment.

fla/ops/kda/chunk_intra.py (3)

523-552: Token-parallel Triton kernel grid and varlen handling look correct

In chunk_kda_fwd_kernel_intra_token_parallel:

  • Grid (total_tokens, H/BH) with autotuned BH matches the intended “each token × head‑group” mapping.
  • Varlen path binary-searches cu_seqlens to recover (sequence, local_t) and guards out-of-range tokens.
  • Fixed‑len path computes (i_b, i_t) via div/mod on T, with an early return when i_t >= T.
  • The subchunk computation uses BT/BC=16 to localize work to the appropriate chunk and subchunk.
  • Masking on m_h and m_k prevents out‑of‑range head/K accesses, and writes to Aqk/Akk follow the documented [B, T, H, BT] layout.

Overall this matches the intent described in the docstring for the Python wrapper.
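
For intuition, the varlen index recovery is just a binary search over the cu_seqlens prefix sums; a minimal Python analogue (a hypothetical helper, not the kernel code) looks like this:

```python
import bisect

def locate_token(i: int, cu_seqlens: list[int]) -> tuple[int, int]:
    """Map a flattened token index to (sequence index, local position).

    cu_seqlens holds cumulative sequence lengths, e.g. [0, 4, 10] for lengths 4 and 6.
    """
    seq = bisect.bisect_right(cu_seqlens, i) - 1
    return seq, i - cu_seqlens[seq]

assert locate_token(3, [0, 4, 10]) == (0, 3)
assert locate_token(4, [0, 4, 10]) == (1, 0)
assert locate_token(9, [0, 4, 10]) == (1, 5)
```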


668-712: Python wrapper for token-parallel impl is thin and correct

chunk_kda_fwd_intra_token_parallel:

  • Derives total_tokens and B_kernel correctly for both fixed‑len and varlen inputs.
  • Uses a grid(meta) closure to expose autotuned BH back to the launch grid, which is the standard Triton pattern.
  • Forwards all required arguments, including BT=chunk_size and USE_EXP2, to the kernel.

No issues from a shape or launch perspective.
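
For readers unfamiliar with that pattern, here is a generic toy example (not the fla kernel) of an autotuned Triton kernel whose launch grid reads the tuned BH back through a grid(meta) closure:

```python
import torch
import triton
import triton.language as tl


@triton.autotune(configs=[triton.Config({"BH": 1}), triton.Config({"BH": 2})], key=["H"])
@triton.jit
def double_kernel(x, y, H, BH: tl.constexpr):
    # Grid: (tokens, head-groups); each program handles BH heads of one token.
    i_t, i_hg = tl.program_id(0), tl.program_id(1)
    o_h = i_hg * BH + tl.arange(0, BH)
    m_h = o_h < H
    v = tl.load(x + i_t * H + o_h, mask=m_h, other=0.0)
    tl.store(y + i_t * H + o_h, v * 2.0, mask=m_h)


def double(x: torch.Tensor) -> torch.Tensor:
    T, H = x.shape
    y = torch.empty_like(x)
    # The closure lets the autotuned BH shape the second grid dimension at launch time.
    grid = lambda meta: (T, triton.cdiv(H, meta["BH"]))
    double_kernel[grid](x, y, H)
    return y
```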


793-825: impl_type dispatch wiring matches intended kernels

  • "token" → chunk_kda_fwd_intra_token_parallel on top of the common chunk_kda_fwd_kernel_intra_sub_inter.
  • "recurrent" → new chunk_kda_fwd_kernel_intra_sub_intra_recurrent with grid (NT, NC, B*H) and the same BC/BK as default.
  • Anything else (currently "recursive") → chunk_kda_fwd_kernel_intra_sub_intra with identical grid and tiling.

This preserves output shapes and reuses the same Aqk/Akk buffers, so higher-level callers don’t need to change.

Comment on lines +35 to +43
    try:
        ms_recurrent = triton.testing.do_bench(
            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
            quantiles=quantiles
        )
        t_recurrent = ms_recurrent[0]
    except Exception as e:
        t_recurrent = float('nan')


⚠️ Potential issue | 🟠 Major

Fix unused exception variable and consider narrowing the catch

You currently catch a broad Exception and bind it to e without using it, which triggers F841/BLE001 and fails lint.

A minimal fix that keeps the behavior is:

-    try:
-        ms_recurrent = triton.testing.do_bench(
-            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
-            quantiles=quantiles
-        )
-        t_recurrent = ms_recurrent[0]
-    except Exception as e:
-        t_recurrent = float('nan')
+    try:
+        ms_recurrent = triton.testing.do_bench(
+            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
+            quantiles=quantiles,
+        )
+        t_recurrent = ms_recurrent[0]
+    except Exception:
+        # Recurrent path not supported/stable for this config; report NaN.
+        t_recurrent = float("nan")

This removes the unused variable and documents why the broad catch is acceptable in a benchmark context.


🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda_intra.py around lines 35 to 43, the except block
currently uses "except Exception as e" but never uses "e", causing an
unused-variable lint error and implicitly catching all exceptions; change it to
"except Exception:" (or better, catch specific expected exceptions if known) and
add a short comment explaining why a broad catch is acceptable in this
benchmarking context so the intent is clear to reviewers and linters.

@yzhangcs yzhangcs force-pushed the main branch 2 times, most recently from ddd8f23 to 91d2f46 on December 25, 2025 08:19
@zhiyuan1i zhiyuan1i force-pushed the main branch 3 times, most recently from 2b3db51 to 53dda79 on January 22, 2026 07:04
