[kda] add recursive block intra implementation #656
base: main
Conversation
…election.
- Merged Token Parallel and Recursive Block implementations into chunk_intra.py.
- Added 'recurrent' (naive sequential) implementation for verification.
- Updated chunk_kda_fwd_intra to use 'impl_type' argument and automatically select between 'token' and 'recursive' based on head dimension K (threshold K=128).
- Added comprehensive benchmark script in benchmarks/ops/benchmark_kda_intra.py covering various shapes.
- Performance: Recursive Block achieves ~15% speedup for small K (<=64), while Token Parallel remains superior for large K.
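A minimal sketch of the selection rule described above (the helper name `select_impl` is hypothetical; only the `impl_type` values and the K=128 threshold come from this PR):

```python
def select_impl(K: int, impl_type: str = "auto") -> str:
    """Pick the intra-chunk kernel variant, mirroring the auto-selection rule."""
    valid = {"auto", "token", "recursive", "recurrent"}
    if impl_type not in valid:
        raise ValueError(f"unknown impl_type {impl_type!r}, expected one of {sorted(valid)}")
    if impl_type == "auto":
        # large head dims favor the token-parallel kernel, small ones the recursive block kernel
        return "token" if K >= 128 else "recursive"
    return impl_type


assert select_impl(256) == "token"
assert select_impl(64) == "recursive"
assert select_impl(64, "recurrent") == "recurrent"
```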
Walkthrough
This PR adds a new KDA intra-chunk benchmarking script, introduces KV-group factorization support to ReBasedLinearAttention with divisibility validations and adjusted projection dimensions, and extends the intra-chunk KDA kernel with multiple implementation strategies (token-parallel, recursive, recurrent) selectable via an `impl_type` argument.
Sequence Diagram
sequenceDiagram
participant User
participant chunk_kda_fwd_intra as chunk_kda_fwd_intra<br/>(impl_type)
participant TokenParallel as Token-Parallel<br/>Kernel
participant Recurrent as Recurrent<br/>Kernel
participant Recursive as Recursive<br/>Kernel (default)
User->>chunk_kda_fwd_intra: Call with q, k, g, impl_type
alt impl_type == "token"
chunk_kda_fwd_intra->>TokenParallel: Dispatch
TokenParallel->>TokenParallel: Compute with token-parallel strategy
else impl_type == "recurrent"
chunk_kda_fwd_intra->>Recurrent: Dispatch (3D grid: NT, NC, B×H)
Recurrent->>Recurrent: Per-token recurrence accumulation<br/>with gating parameters
else impl_type == "auto" or "recursive"
chunk_kda_fwd_intra->>chunk_kda_fwd_intra: Auto-detect:<br/>K≥128 → "token",<br/>else → "recursive"
alt Auto-selected "token"
chunk_kda_fwd_intra->>TokenParallel: Dispatch
else Auto-selected "recursive"
chunk_kda_fwd_intra->>Recursive: Dispatch (original path)
end
end
rect rgb(200, 220, 255)
note over TokenParallel,Recursive: BC=16 fixed,<br/>BK computed accordingly
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes
Hello @sustcsonglin, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the KDA (Kimi Delta Attention) implementation by introducing a new "recursive block intra" approach, which demonstrates performance improvements in several benchmark scenarios. It also refactors the KDA intra-chunk forward pass to support multiple implementation types, including the new recursive method, and updates the ReBasedLinearAttention layer with grouped KV (GQA) support.
Code Review
This pull request introduces a new recursive implementation for the intra-chunk KDA kernel, which demonstrates performance benefits for smaller head dimensions. It also refactors the kernel selection logic, allowing for dynamic switching between token, recursive, and recurrent implementations. Furthermore, the changes correctly implement Grouped-Query Attention (GQA) in the ReBasedLinearAttention layer, addressing some important dimension-related bugs. The new benchmark script is a valuable addition for performance verification. My review includes a few suggestions to enhance code clarity, maintainability, and robustness, particularly within the new Triton kernel and the implementation selection logic.
# For BC=64, we need to handle span=32 (log2=5).
# Starting from 6 is safe for BC up to 128.
for log_span in range(3, -1, -1):
The comment on lines 174-175 is misleading as it refers to BC=64, but the implementation and the loop on line 176 are for BC=16. To improve clarity and maintainability, the comment should be updated, and the loop range should be derived from the BC constant rather than being hardcoded.
- # For BC=64, we need to handle span=32 (log2=5).
- # Starting from 6 is safe for BC up to 128.
- for log_span in range(3, -1, -1):
+ # For BC=16, we need to handle span=8 (log2=3).
+ # The loop range is `log2(BC)-1` down to 0.
+ for log_span in range(tl.static_log2(BC) - 1, -1, -1):
b_g = tl.load(p_g, boundary_check=(0, 1))

b_k = b_k * tl.load(beta + (i_t * BT + i_i * BC + o_i) * H, mask=m_A, other=0)[:, None]
b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504
The magic number 1.44269504 is used here. This corresponds to log2(e) and is used to convert the base of the exponentiation from e to 2. It would be beneficial for readability to either define this as a named constant at the module level (e.g., LOG2_E = 1.44269504) or add a comment explaining its purpose.
- b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504
+ b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) * 1.44269504  # log2(e)
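As a quick sanity check of the identity behind that constant (independent of the kernel code), exp(x) equals 2 raised to x·log2(e):

```python
import math

x = 0.73
# 1.44269504 is log2(e) truncated; exp(x) == 2 ** (x * log2(e))
assert math.isclose(math.exp(x), 2 ** (x * 1.44269504), rel_tol=1e-7)
```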
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
fla/layers/rebased.py (2)
147-152: `forward_reference` breaks when `num_key_value_heads != num_heads`
Here `q`, `k`, `v` are reshaped as:
- q: [B, T, feature_dim * num_heads] → [B, num_heads, T, feature_dim]
- k: [B, T, feature_dim * num_key_value_heads] → [B, num_key_value_heads, T, feature_dim]
- v: [B, T, num_key_value_heads * head_dim] → [B, num_key_value_heads, T, head_dim]
The subsequent attention computation assumes the same head dimension for `q`, `k`, and `v`, so when `num_key_value_heads != num_heads` the math and broadcasting are wrong (and can be outright invalid in PyTorch). You should either:
- Restrict `forward_reference` to configurations with `num_key_value_heads == num_heads` (and assert that), or
- Update it to apply the same KV-group expansion (`repeat`) that `forward` uses so that `k` and `v` are broadcast to `num_heads`.
103-118: Unpack return values from both function calls
The verification confirms the issue. Both `chunk_linear_attn()` and `fused_chunk_linear_attn()` have return type `tuple[torch.Tensor, torch.Tensor]` and always return `(o, final_state)`. The current code assigns these tuples directly to `o` without unpacking, but line ~131 calls `rearrange(o, "... h d -> ... (h d)")`, which expects a tensor and will fail at runtime. Fix the assignments at lines 104-107 and 109-112:

o, _ = fused_chunk_linear_attn(
    q=q,
    k=k,
    v=v,
    normalize=True,
    scale=1,
)
# and
o, _ = chunk_linear_attn(
    q=q,
    k=k,
    v=v,
    normalize=True,
    scale=1,
)
🧹 Nitpick comments (9)
fla/layers/rebased.py (1)
98-102: Feature-map + KV repeat ordering is correct, but consider documenting
Applying `feature_map` to `q` and `k` before repeating `k`/`v` over KV groups is efficient and avoids redundant normalization work. Because this subtly changes where normalization happens relative to the grouping, a brief inline comment (e.g. "normalize once per KV-head, then share across groups") would help future readers.
benchmarks/ops/benchmark_kda_intra.py (2)
19-33: Use `torch.no_grad()` to be explicit about benchmarking forward-only
Even though `torch.randn` defaults to `requires_grad=False`, wrapping benchmarked calls in a `torch.no_grad()` context can make it explicit this is a pure-forward perf script and avoid accidental grad tracking if the code changes. For example:

with torch.no_grad():
    for _ in range(10):
        ...
1-3: Address trailing whitespace / pre-commit lint
GitHub Actions reports trailing whitespace stripped by pre-commit in this file. After applying code changes, please run the repo's pre-commit hooks (or the configured formatter) locally to keep CI green.
fla/ops/kda/chunk_intra.py (6)
159-223: Recursive sub-intra kernel logic is internally consistent
Within `chunk_kda_fwd_kernel_intra_sub_intra`:
- Converting `b_g` to base-2 exponents via `* 1.44269504` and then using `tl.exp2` is mathematically consistent with the previous `exp`-based formulation and keeps everything in `float32`.
- Diagonal contributions are seeded via `b_Aqk_diag` before the span loop, and the subsequent `log_span` loop builds larger spans bottom-up using masks (`is_q`, `is_k`, `same_block`).
- Final triangular masks
  - `Aqk`: `o_i[:, None] >= o_i[None, :]`
  - `Akk`: `o_i[:, None] > o_i[None, :]` with `b_beta`
  enforce causal structure and zero out unwanted regions before storing.
The only nit is that comments still mention `BC=64`/`log2=5` while the implementation hardcodes `BC=16` and loops `log_span in range(3, -1, -1)`. Updating the comment would avoid confusion.
450-520: New recurrent sub-intra kernel appears shape- and index-safe
`chunk_kda_fwd_kernel_intra_sub_intra_recurrent`:
- Correctly handles varlen vs fixed-len via `chunk_indices`/`cu_seqlens` and early-returns when `i_t * BT + i_i * BC >= T`.
- Offsets `q`/`k`/`g`/`beta`/`Aqk`/`Akk` in the same way as other kernels, keeping head/time mapping consistent.
- Loop over `j` advances `p_kt`/`p_gk` by `H*K` per step, which matches the `[T, H, K]` physical layout.
- Final zeroing of:
  - `Aqk` for `o_i[:, None] < o_i[None, :]` (upper triangle)
  - `Akk` for `o_i[:, None] <= o_i[None, :]` (diag + upper)
  ensures the expected causal triangular structure after the per-row recurrence.
Given the mathematical subtlety, it would be good to add or extend tests that compare `impl_type="recurrent"` against `"token"`/`"recursive"` on small random inputs to validate numerical equivalence (a sketch of such a check follows below).
724-761: `impl_type` API and auto-selection heuristic are reasonable
- Adding `impl_type: str = "auto"` gives a clean way to choose between `"token"`, `"recursive"`, and `"recurrent"` without changing call sites.
- The auto-heuristic `impl_type = "token" if K >= 128 else "recursive"` is simple and matches the benchmark intent (different performance regimes by head dim).
Consider validating `impl_type` values and raising a clear error on unknown strings to avoid silently falling into the `else` branch:

valid_impls = {"auto", "token", "recursive", "recurrent"}
if impl_type not in valid_impls:
    raise ValueError(f"Unsupported impl_type={impl_type!r}, expected one of {sorted(valid_impls)}")
767-769: Hardcoding `BC = 16` is fine but deserves a short comment
Forward now fixes `BC = 16` while backward still computes `BC = min(16, BT)`. That's fine for the current configs (`BT=64` in benchmarks), but a brief comment like "BC fixed to 16 for fwd kernels; adjust if chunk_size is tuned beyond 64" would document the coupling between forward and backward tiling.
842-844: Minor consistency nit between forward and backward BC/BK choices
Forward uses `BC = 16` and `BK = max(triton.next_power_of_2(K), 16)`, while backward uses `BC = min(16, BT)` and `BK = min(32, triton.next_power_of_2(K))`. That's likely intentional for numerical stability/performance, but it's worth double-checking that all production call sites stay within the tuned regimes (e.g., `BT >= 16`, `K <= 256` as asserted) and documenting the assumptions in a module-level comment.
1-5: Re-run pre-commit to satisfy trailing-whitespace lint
CI reports trailing whitespace removed near the top of this file by pre-commit. After applying kernel changes, please re-run the repo's formatting/lint hooks locally to keep this file aligned with style checks.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/ops/benchmark_kda_intra.py (1 hunks)
- fla/layers/rebased.py (5 hunks)
- fla/ops/kda/chunk_intra.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
benchmarks/ops/benchmark_kda_intra.py (1)
  fla/ops/kda/chunk_intra.py (1)
    - chunk_kda_fwd_intra (714-852)
fla/layers/rebased.py (4)
  fla/modules/feature_map.py (1)
    - RebasedFeatureMap (204-243)
  fla/ops/linear_attn/chunk.py (1)
    - chunk_linear_attn (11-74)
  fla/ops/linear_attn/fused_chunk.py (1)
    - fused_chunk_linear_attn (11-59)
  fla/ops/rebased/parallel.py (1)
    - parallel_rebased (439-463)
fla/ops/kda/chunk_intra.py (2)
  fla/ops/kda/chunk_intra_token_parallel.py (3)
    - chunk_kda_fwd_kernel_intra_token_parallel (25-155)
    - chunk_kda_fwd_intra_token_parallel (158-219)
    - grid (200-202)
  fla/ops/utils/index.py (1)
    - prepare_chunk_indices (114-119)
🪛 Flake8 (7.3.0)
benchmarks/ops/benchmark_kda_intra.py
[error] 41-41: local variable 'e' is assigned to but never used
(F841)
🪛 GitHub Actions: lint
benchmarks/ops/benchmark_kda_intra.py
[error] 1-1: Trailing whitespace detected and removed by pre-commit trailing-whitespace hook. Re-run pre-commit to confirm.
fla/ops/kda/chunk_intra.py
[error] 1-1: Trailing whitespace detected and removed by pre-commit trailing-whitespace hook. Re-run pre-commit to confirm.
🪛 Ruff (0.14.5)
benchmarks/ops/benchmark_kda_intra.py
41-41: Do not catch blind exception: Exception
(BLE001)
41-41: Local variable e is assigned to but never used
Remove assignment to unused variable e
(F841)
67-67: Do not catch blind exception: Exception
(BLE001)
fla/layers/rebased.py
54-54: Avoid specifying long messages outside the exception class
(TRY003)
56-56: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
- GitHub Check: check-pt-python-compatibility
🔇 Additional comments (10)
fla/layers/rebased.py (5)
11-11: Import of `repeat` is appropriate and scoped correctly
Using `einops.repeat` for KV-group expansion is idiomatic here and keeps head-axis reshaping readable.
53-57: Divisibility checks for heads/KV-heads look correct
Validating `hidden_size % num_heads == 0` and `num_heads % num_key_value_heads == 0` early is the right guardrail for KV-grouping; messages are clear enough for debugging.
60-62: New KV-grouping attributes are consistent
`num_key_value_heads`, `num_kv_groups = num_heads // num_key_value_heads`, and `head_dim = hidden_size // num_heads` are coherent and match the projection shapes below.
71-75: Projection shapes align with GQA design
- q_proj: hidden_size → feature_dim * num_heads
- k_proj: hidden_size → feature_dim * num_key_value_heads
- v_proj: hidden_size → num_key_value_heads * head_dim
- o_proj: num_heads * head_dim → hidden_size
These choices are consistent with grouped KV (GQA) and with the later `rearrange`/`repeat` usage.
86-97: KV grouping expansion strategy looks sound
- `k` is reshaped with `h=self.num_key_value_heads`.
- `v` is reshaped with `h=self.num_key_value_heads`, `d=self.head_dim`.
- When `num_kv_groups > 1`, repeating along `h` to `(h * num_kv_groups)` for both `k` and `v` correctly broadcasts KV heads to all query heads.
This keeps all three tensors at `H = num_heads` before entering the attention kernels (see the sketch below).
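As a minimal illustration of that expansion (not the layer's actual code; the shapes, the `b t h d` ordering, and the group placement are assumptions for this sketch):

```python
import torch
from einops import rearrange, repeat

B, T = 2, 8
num_heads, num_key_value_heads, feature_dim = 8, 2, 16
num_kv_groups = num_heads // num_key_value_heads  # 4 query heads share each KV head

# output of a hypothetical k_proj: feature_dim channels per KV head
k = torch.randn(B, T, feature_dim * num_key_value_heads)
k = rearrange(k, "b t (h d) -> b t h d", h=num_key_value_heads)

# broadcast each KV head across its query-head group
if num_kv_groups > 1:
    k = repeat(k, "b t h d -> b t (h g) d", g=num_kv_groups)

# k now has one head per query head, ready to pair with q of shape [B, T, num_heads, feature_dim]
assert k.shape == (B, T, num_heads, feature_dim)
```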
benchmarks/ops/benchmark_kda_intra.py (2)
6-18: Benchmark setup is reasonable for perf comparisons
Random `q`/`k`/`g`/`beta` allocations with `bfloat16` on CUDA and a fixed `chunk_size` give a clean microbenchmark of the three `impl_type` variants. Warmup and reuse of the same tensors across runs are appropriate.
64-68: Same unused-exception pattern at top-level benchmark loop
Apply the same adjustment here to satisfy F841/BLE001:
- for B, T, H, K in configs:
-     try:
-         benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64)
-     except Exception as e:
-         print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {e}")
+ for B, T, H, K in configs:
+     try:
+         benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64)
+     except Exception as exc:
+         print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {exc}")
Here it's actually useful to keep the exception object and log it.
Likely an incorrect or invalid review comment.
fla/ops/kda/chunk_intra.py (3)
523-552: Token-parallel Triton kernel grid and varlen handling look correct
In `chunk_kda_fwd_kernel_intra_token_parallel`:
- Grid `(total_tokens, H/BH)` with autotuned `BH` matches the intended "each token × head-group" mapping.
- Varlen path binary-searches `cu_seqlens` to recover `(sequence, local_t)` and guards out-of-range tokens (this lookup is sketched below).
- Fixed-len path computes `(i_b, i_t)` via div/mod on `T`, with an early return when `i_t >= T`.
- The subchunk computation uses `BT`/`BC=16` to localize work to the appropriate chunk and subchunk.
- Masking on `m_h` and `m_k` prevents out-of-range head/K accesses, and writes to `Aqk`/`Akk` follow the documented `[B, T, H, BT]` layout.
Overall this matches the intent described in the docstring for the Python wrapper.
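For intuition only, here is a plain-Python sketch of that varlen lookup (the actual kernel does this with Triton ops; `cu_seqlens` holds cumulative sequence lengths starting at 0):

```python
def locate_token(cu_seqlens, global_t):
    # find i_n such that cu_seqlens[i_n] <= global_t < cu_seqlens[i_n + 1]
    lo, hi = 0, len(cu_seqlens) - 1
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if cu_seqlens[mid] <= global_t:
            lo = mid
        else:
            hi = mid
    i_n = lo                              # owning sequence index
    local_t = global_t - cu_seqlens[lo]   # position within that sequence
    return i_n, local_t


# token 8 falls in the third sequence (tokens 7..11), at local position 1
assert locate_token([0, 3, 7, 12], 8) == (2, 1)
```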
668-712: Python wrapper for token-parallel impl is thin and correct
`chunk_kda_fwd_intra_token_parallel`:
- Derives `total_tokens` and `B_kernel` correctly for both fixed-len and varlen inputs.
- Uses a `grid(meta)` closure to expose autotuned `BH` back to the launch grid, which is the standard Triton pattern (illustrated below).
- Forwards all required arguments, including `BT=chunk_size` and `USE_EXP2`, to the kernel.
No issues from a shape or launch perspective.
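A generic illustration of that pattern, not the PR's wrapper: the launch grid is a closure over the autotuner's chosen meta-parameters (here a made-up kernel with an autotuned `BH`):

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[triton.Config({'BH': bh}) for bh in (1, 2, 4)],
    key=['H'],
)
@triton.jit
def dummy_kernel(x_ptr, H: tl.constexpr, BH: tl.constexpr):
    # one program per (token, group of BH heads); body elided for brevity
    i_t = tl.program_id(0)
    i_h = tl.program_id(1)
    _ = i_t + i_h  # placeholder so the program ids are referenced


def launch(x, total_tokens, H):
    # evaluated after autotuning, so meta['BH'] is the selected block size
    grid = lambda meta: (total_tokens, triton.cdiv(H, meta['BH']))
    dummy_kernel[grid](x, H=H)
```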
793-825: `impl_type` dispatch wiring matches intended kernels
- `"token"` → `chunk_kda_fwd_intra_token_parallel` on top of the common `chunk_kda_fwd_kernel_intra_sub_inter`.
- `"recurrent"` → new `chunk_kda_fwd_kernel_intra_sub_intra_recurrent` with grid `(NT, NC, B*H)` and the same BC/BK as default.
- Anything else (currently `"recursive"`) → `chunk_kda_fwd_kernel_intra_sub_intra` with identical grid and tiling.
This preserves output shapes and reuses the same `Aqk`/`Akk` buffers, so higher-level callers don't need to change.
try:
    ms_recurrent = triton.testing.do_bench(
        lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
        quantiles=quantiles
    )
    t_recurrent = ms_recurrent[0]
except Exception as e:
    t_recurrent = float('nan')
Fix unused exception variable and consider narrowing the catch
You currently catch a broad Exception and bind it to e without using it, which triggers F841/BLE001 and fails lint.
A minimal fix that keeps the behavior is:
- try:
-     ms_recurrent = triton.testing.do_bench(
-         lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
-         quantiles=quantiles
-     )
-     t_recurrent = ms_recurrent[0]
- except Exception as e:
-     t_recurrent = float('nan')
+ try:
+     ms_recurrent = triton.testing.do_bench(
+         lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
+         quantiles=quantiles,
+     )
+     t_recurrent = ms_recurrent[0]
+ except Exception:
+     # Recurrent path not supported/stable for this config; report NaN.
+     t_recurrent = float("nan")
+ t_recurrent = float("nan")This removes the unused variable and documents why the broad catch is acceptable in a benchmark context.
`python benchmarks/ops/benchmark_kda_intra.py` on H200:

| Shape | Token (Original) | Recursive (New) | Recurrent | Speedup (Rec/Tok) |
|---|---|---|---|---|
| B=8, T=4096, H=16, K=128 | 1.206 ms | 1.246 ms | 1.361 ms | 0.97x |
| B=1, T=8192, H=16, K=128 | 0.330 ms | 0.346 ms | 0.373 ms | 0.95x |
| B=8, T=4096, H=32, K=64 | 1.743 ms | 1.502 ms | 1.905 ms | 1.16x |
| B=1, T=8192, H=32, K=64 | 0.466 ms | 0.406 ms | 0.504 ms | 1.15x |
| B=32, T=512, H=12, K=64 | 0.363 ms | 0.322 ms | 0.387 ms | 1.13x |
| B=2, T=4096, H=8, K=256 | 0.273 ms | 0.677 ms | 0.333 ms | 0.40x |