
Conversation

@AwesomeSeq (Contributor) commented Jan 27, 2026

from hujiaxi@moonshot.cn

Summary by CodeRabbit

  • New Features

    • Two optimized gated OJA implementations: a chunked path with configurable chunk size, and a fused recurrent path for efficient sequential processing.
    • Triton-accelerated kernels with variable-length sequence support, optional per-tensor L2 normalization, and initial/final state propagation.
  • Tests

    • Extensive validation suite with reference implementations covering forward and backward behavior, gradients, varlen cases, and multiple dtypes/configurations.

@coderabbitai bot (Contributor) commented Jan 27, 2026

Walkthrough

Adds a complete gated OJA-rule implementation: chunked and fused-recurrent Python bindings plus many Triton kernels (KKT, hidden-state, output, WY recompute/prepare), variable-length sequence support, public API exports, and comprehensive tests validating forward and backward behavior.
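
For orientation, here is a minimal usage sketch of the chunked path. It assumes the (B, T, H, D) tensor layout used by the kernels and the argument order shown in the sequence diagrams below; keyword names such as output_final_state and the log-space convention for the gv gate are assumptions, not confirmed from the diff.

    import torch
    import torch.nn.functional as F

    from fla.ops.gated_oja_rule import chunk_gated_oja_rule

    B, T, H, K, V = 2, 1024, 4, 64, 64
    q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    # gv is applied through exp() in the kernels, so a log-space (non-positive) gate is assumed here
    gv = F.logsigmoid(torch.randn(B, T, H, V, device='cuda', dtype=torch.float32)).requires_grad_()
    beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16, requires_grad=True)

    # chunked training path: per-token outputs plus the final recurrent state
    o, final_state = chunk_gated_oja_rule(q, k, v, gv, beta, output_final_state=True)
    o.sum().backward()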

Changes

  • Public API Exports (fla/ops/gated_oja_rule/__init__.py): Exports chunk_gated_oja_rule and fused_recurrent_gated_oja_rule via __all__.
  • Chunked Python Orchestration (fla/ops/gated_oja_rule/chunk.py): Adds chunked forward/backward (chunk_oja_fwd / chunk_oja_bwd), an autograd Function, and the top-level chunk_gated_oja_rule wrapper with L2-norm and cu_seqlens handling.
  • Triton: hidden-state kernels (fla/ops/gated_oja_rule/chunk_h.py): New Triton kernels/wrappers for computing/propagating hidden states (forward/backward), block tiling, GV conditioning, varlen support, and kernel autotuning.
  • Triton: KKT kernels (fla/ops/gated_oja_rule/chunk_kkt.py): Triton kernels and wrappers for chunked scaled-dot KKT forward and backward (including gated gk), chunk/grid orchestration, and varlen support.
  • Triton: output & grads kernels (fla/ops/gated_oja_rule/chunk_o.py): Triton kernels and host wrappers to compute outputs (o) and accumulate gradients (dA, dq/dk, dv, etc.) with inter/intra-block tiling and varlen support.
  • Fused recurrent path (fla/ops/gated_oja_rule/fused_recurrent.py): Adds the fused_recurrent kernel binding, an autograd Function (forward only), and the fused_recurrent_gated_oja_rule API with defaults/validation; backward is not implemented.
  • WY recompute / bwd prep (fla/ops/gated_oja_rule/wy_fast.py): Adds recompute_w_u_fwd and prepare_wy_repr_bwd Triton kernels/wrappers for WY-representation (w/u/vg) forward/backward preparation.
  • Tests & refs (tests/ops/test_oja.py): Adds reference implementations (recurrent/chunk refs) and extensive parameterized tests comparing the fused and chunked implementations for forward and backward (including varlen).

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant ChunkAPI as chunk_gated_oja_rule
    participant ChunkFunc as ChunkOJAFunction
    participant KKT as chunk_scaled_dot_kkt_fwd
    participant WY as recompute_w_u_fwd
    participant Hkern as chunk_oja_fwd_h
    participant Okern as chunk_oja_fwd_o

    User->>ChunkAPI: q,k,v,gv,beta,...
    ChunkAPI->>ChunkFunc: apply forward
    ChunkFunc->>KKT: compute A
    KKT-->>ChunkFunc: A
    ChunkFunc->>WY: recompute w/u (from k,v,A)
    WY-->>ChunkFunc: w,u,vg
    ChunkFunc->>Hkern: compute h (hidden states)
    Hkern-->>ChunkFunc: h,final_state
    ChunkFunc->>Okern: compute o
    Okern-->>ChunkFunc: o
    ChunkFunc-->>User: o, final_state
sequenceDiagram
    actor User
    participant UserAPI as fused_recurrent_gated_oja_rule
    participant FusedFunc as FusedRecurrentFunction
    participant FusedKernel as fused_recurrent_oja_fwd_kernel

    User->>UserAPI: q,k,v,gv,beta,initial_state,...
    UserAPI->>FusedFunc: forward (prepare, validate)
    FusedFunc->>FusedKernel: per-timestep fused kernel (update h, o)
    FusedKernel-->>FusedFunc: o, final_state
    FusedFunc-->>User: o, final_state
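
For the varlen support mentioned in the walkthrough, a hedged sketch of how packed sequences would be passed, assuming the usual flash-linear-attention convention that cu_seqlens holds cumulative sequence offsets over sequences concatenated along the time axis with batch size 1:

    import torch

    from fla.ops.gated_oja_rule import chunk_gated_oja_rule

    # three sequences of lengths 100, 256 and 156 packed into a single row of total length 512
    lengths = [100, 256, 156]
    T = sum(lengths)
    cu_seqlens = torch.tensor([0, 100, 356, 512], device='cuda', dtype=torch.long)

    H, K, V = 4, 64, 64
    q = torch.randn(1, T, H, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(1, T, H, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(1, T, H, V, device='cuda', dtype=torch.bfloat16)
    gv = torch.randn(1, T, H, V, device='cuda').sigmoid().log()
    beta = torch.rand(1, T, H, device='cuda', dtype=torch.bfloat16)

    # each packed sequence evolves its own hidden state; final_state holds one state per sequence
    o, final_state = chunk_gated_oja_rule(
        q, k, v, gv, beta, output_final_state=True, cu_seqlens=cu_seqlens
    )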

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested reviewers

  • yzhangcs

Poem

🐰 A hoppy ode to kernels new and bright,

Tiles and chunks hopping through day and night,
Forward states stitched, backward traces unfurled,
Gradients nibble and round the training world.
🥕✨

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 2.08%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title '[OJA] Integrate Gated OJA Rule' directly matches the main objective of the PR (integrating gated OJA rule functionality), as evidenced by the substantial additions across multiple new modules (chunk.py, chunk_h.py, chunk_kkt.py, chunk_o.py, fused_recurrent.py, wy_fast.py) and comprehensive tests.


@gemini-code-assist bot (Contributor)

Summary of Changes

Hello @AwesomeSeq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the library's capabilities by introducing a new fla/ops/gated_oja_rule module. It provides both a chunk-based and a fused recurrent implementation of the Gated Oja Rule, designed for efficient processing of sequence data. The new operators are backed by highly optimized Triton kernels and are thoroughly tested to ensure correctness across various scenarios, including variable-length inputs.

Highlights

  • New Gated Oja Rule Operators: Introduced a new module fla/ops/gated_oja_rule containing implementations for the Gated Oja Rule.
  • Chunked Implementation: Added chunk_gated_oja_rule for efficient chunk-wise processing, including forward and backward passes with specialized Triton kernels for hidden states, KKT matrix computations, and output gradients.
  • Fused Recurrent Implementation: Provided fused_recurrent_gated_oja_rule for a fused recurrent approach, though its backward pass is noted as not yet implemented (see the usage sketch after this list).
  • Triton Kernels: Developed multiple Triton kernels (chunk_oja_fwd_h, chunk_oja_bwd_dhu, chunk_oja_bwd_dvwg_h, chunk_scaled_dot_kkt_fwd, chunk_scaled_dot_kkt_bwd_gk, chunk_oja_fwd_o, chunk_oja_bwd_dA, chunk_oja_bwd_dqk, chunk_oja_bwd_dv_o, recompute_w_u_fwd, prepare_wy_repr_bwd) to optimize performance.
  • Comprehensive Testing: Included extensive unit tests covering both chunked and fused recurrent Oja rules, validating forward and backward passes, and supporting variable-length sequences.
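
Because the backward pass of the fused recurrent path is not implemented, it is effectively an inference/decoding primitive. A hedged sketch of step-wise use, with argument names taken from the second sequence diagram above (initial_state/output_final_state are assumed keyword names, and the reviewers note a V <= 128 limit on this path):

    import torch

    from fla.ops.gated_oja_rule import fused_recurrent_gated_oja_rule

    B, H, K, V = 2, 4, 64, 64
    state = None  # initial_state for the first step; thereafter the previous final_state

    with torch.no_grad():  # backward is not implemented for the fused recurrent path
        for _ in range(16):  # decode 16 tokens, one timestep (T = 1) at a time
            q = torch.randn(B, 1, H, K, device='cuda', dtype=torch.bfloat16)
            k = torch.randn(B, 1, H, K, device='cuda', dtype=torch.bfloat16)
            v = torch.randn(B, 1, H, V, device='cuda', dtype=torch.bfloat16)
            gv = torch.randn(B, 1, H, V, device='cuda').sigmoid().log()
            beta = torch.rand(B, 1, H, device='cuda', dtype=torch.bfloat16)
            o, state = fused_recurrent_gated_oja_rule(
                q, k, v, gv, beta, initial_state=state, output_final_state=True
            )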


@gemini-code-assist bot left a comment

Code Review

The pull request introduces gated Oja operator implementations, including chunked and fused recurrent versions, along with corresponding tests. The overall structure is well-organized, separating forward and backward passes into distinct functions and Triton kernels. The addition of comprehensive test cases, including variable-length sequences and backward pass checks, is highly commendable. However, several critical bugs related to indexing and conditional variable usage in Triton kernels have been identified, which need immediate attention to ensure correctness and prevent potential runtime errors.

Comment on lines +152 to +163
            if K > 64:
                o_v2 = 64 + o_v1
                b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
                b_h2 *= exp(b_gk_last2)[None, :]
            if K > 128:
                o_v3 = 128 + o_v1
                b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
                b_h3 *= exp(b_gk_last3)[None, :]
            if K > 192:
                o_v4 = 192 + o_v1
                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
                b_h4 *= exp(b_gk_last4)[None, :]

critical

The if conditions for loading b_gk_lastX are checking K (e.g., if K > 64) but the offsets o_vX and the b_hX variables are related to the V dimension. This is a logical error and can lead to incorrect memory access or calculations if K and V dimensions do not align with these hardcoded checks. The conditions should be based on V.

Suggested change
-            if K > 64:
-                o_v2 = 64 + o_v1
-                b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
-                b_h2 *= exp(b_gk_last2)[None, :]
-            if K > 128:
-                o_v3 = 128 + o_v1
-                b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
-                b_h3 *= exp(b_gk_last3)[None, :]
-            if K > 192:
-                o_v4 = 192 + o_v1
-                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
-                b_h4 *= exp(b_gk_last4)[None, :]
+            b_h1 *= exp(b_gk_last1)[None, :]
+            if V > 64:
+                o_v2 = 64 + o_v1
+                b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
+                b_h2 *= exp(b_gk_last2)[None, :]
+            if V > 128:
+                o_v3 = 128 + o_v1
+                b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
+                b_h3 *= exp(b_gk_last3)[None, :]
+            if V > 192:
+                o_v4 = 192 + o_v1
+                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
+                b_h4 *= exp(b_gk_last4)[None, :]

Comment on lines 745 to 756
    b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0)

    if USE_GV:
        b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

    p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))

    b_dgv_last += tl.sum(b_dv * b_v, axis=0)

critical

The variables b_gn and b_dv are used in these calculations (b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0) and b_dgv_last += tl.sum(b_dv * b_v, axis=0)) even when USE_GV is false, in which case they are never initialized. This will lead to a runtime error or undefined behavior if gv is None.

Suggested change
-    b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0)
-    if USE_GV:
-        b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
-    p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-    p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-    p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-    p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-    b_v = tl.load(p_v, boundary_check=(0, 1))
-    b_dgv_last += tl.sum(b_dv * b_v, axis=0)
+    if USE_GV:
+        b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
+        b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0)
+        b_dgv_last += tl.sum(b_dv * b_v, axis=0)
+    else:
+        b_dv = b_dvg
+        b_dgv_last += tl.sum(b_h * b_dh, axis=0)
+        b_dgv_last += tl.sum(b_dv * b_v, axis=0)

return

p_g = tl.make_block_ptr(gv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
p_gn = gv + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v

critical

The calculation of p_gn uses min(i_t * BT + i_i * BC, T). This likely intends to get the last valid index of the current block. However, it should typically be min(i_t * BT + i_i * BC + BC - 1, T - 1) to correctly represent the last element's index within the block, ensuring it doesn't go out of bounds if T is exactly the end of the sequence or if the block is not full. Using T directly can lead to out-of-bounds access if T is the total length and not an index.

Suggested change
p_gn = gv + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
p_gn = gv + (bos + min(i_t * BT + i_i * BC + BC - 1, T - 1)) * H*V + i_h * V + o_v

    if i_i > i_j:
        p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gv = tl.make_block_ptr(gv + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
        p_gn = gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v

critical

Similar to chunk_oja_fwd_intra, the calculation of p_gn here (gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v) is likely incorrect for getting the last valid index of the current block. It should be min(i_t * BT + i_i * BC + BC - 1, T - 1) to ensure correct indexing and prevent potential out-of-bounds access.

Suggested change
p_gn = gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
p_gn = gv + (bos + min(i_t*BT + i_i*BC + BC - 1, T - 1)) * H*V + i_h * V + o_v

Comment on lines 398 to 445
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

medium

In the chunk_oja_bwd_kernel_dhu_blockdim64 kernel, the b_dhX variables are multiplied by exp(b_gv_lastX) unconditionally if USE_GV is true. However, b_gv_lastX (for X=2,3,4) are loaded only if V is greater than a certain threshold (e.g., V > 64). If V is smaller, these b_gv_lastX variables might contain uninitialized or garbage values, leading to incorrect calculations. Each multiplication should be guarded by the corresponding if V > ... condition.

Suggested change
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1[None, :])
b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
if V > 64:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2[None, :])
b_do *= exp(b_gv)
b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 128:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3[None, :])
b_do *= exp(b_gv)
b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if V > 192:
p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
b_do = tl.load(p_do, boundary_check=(0, 1))
p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_w = tl.load(p_w, boundary_check=(0, 1))
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0)) # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4[None, :])
b_do *= exp(b_gv)
b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
if USE_GV:
o_v1 = tl.arange(0, 64)
b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
b_dh1 *= exp(b_gv_last1)[None, :]
b_do *= exp(b_gv)
if V > 64 and USE_GV:
o_v2 = 64 + o_v1
b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
b_dh2 *= exp(b_gv_last2)[None, :]
b_do *= exp(b_gv)
if V > 128 and USE_GV:
o_v3 = 128 + o_v1
b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
b_dh3 *= exp(b_gv_last3)[None, :]
b_do *= exp(b_gv)
if V > 192 and USE_GV:
o_v4 = 192 + o_v1
b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
b_dh4 *= exp(b_gv_last4)[None, :]
b_do *= exp(b_gv)

@coderabbitai bot left a comment

Actionable comments posted: 10

🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Line 93: Remove the unused BV variable assignment (BV=64) in chunk_h.py:
delete the BV definition since the kernel uses a hardcoded 64 and BV is never
referenced elsewhere; ensure no other code in the module refers to BV and run
lint/tests to confirm no remaining references (look for the symbol BV in
chunk_h.py to locate the line to remove).
- Around line 148-163: The conditional blocks that apply gv scaling incorrectly
compare against K instead of V; change all occurrences of "if K > 64/128/192" to
"if V > 64/128/192" and update the final load mask from "(o_v4 < K)" to "(o_v4 <
V)" so the gv loads and masks (e.g., in the blocks computing
b_gk_last1..b_gk_last4 and multiplying b_h1..b_h4) correctly use the
value-dimension V rather than the key-dimension K.
- Around line 747-767: The code uses b_dv unconditionally but only defines it
inside the if USE_GV branch; to fix, ensure b_dv is always assigned: when USE_GV
is True compute b_dv = b_dvg * exp(b_gn[None, :] - b_gv) as before, otherwise
initialize b_dv to a zero tensor with the same shape and dtype used later (shape
[BT, BV] matching b_v and p_dv.element_ty) so subsequent operations (b_dgv_last
update, tl.store(p_dv, ...), and interaction with b_v) work correctly; update
the block so b_dv, b_dvg, b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the
referenced symbols.
- Around line 396-403: The code unconditionally creates and loads p_gv and b_gv
(using gv, p_gv, b_gv) inside the V>0 handling even when gv may be None if
USE_GV is False; wrap the creation of p_gv and any tl.load(gv + ...) or
tl.load(p_gv, ...) calls with a guard on USE_GV (same pattern used elsewhere
when gv is offset at line 319) so that all accesses to gv happen only when
USE_GV is True, and apply the same guard pattern to the other V>64, V>128, and
V>192 blocks to prevent null pointer/runtime loads when gv is not provided.
- Around line 531-651: The function chunk_gsa_bwd_k_kernel_dqkvg defined in this
file is dead/duplicated and should be removed: delete the entire
chunk_gsa_bwd_k_kernel_dqkvg(...) definition from
fla/ops/gated_oja_rule/chunk_h.py so the codebase uses the single implementation
in fla/ops/gsa/chunk.py; after removal, run tests and search for any local
references to chunk_gsa_bwd_k_kernel_dqkvg to ensure no callers depend on this
definition and update imports/call sites to reference the gsa implementation if
needed.

In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 8-15: Remove the redundant and unused imports: delete the import
of exp from fla.ops.utils.op and the import of chunk_local_cumsum from
fla.ops.utils.cumsum, keeping the intended tl.exp assignment (exp = tl.exp) as
the single definition of exp; ensure no other code depends on
fla.ops.utils.op.exp or chunk_local_cumsum in this file (references to exp
should use the tl-backed exp symbol).

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 1-2: The file header lines have duplicated comment markers ("#
#"), so remove the extra '#' characters in those header comments: change the
leading "# # -*- coding: utf-8 -*-" to "# -*- coding: utf-8 -*-" and similarly
change "# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to "# Copyright (c)
2023-2025, Songlin Yang, Yu Zhang" to restore proper comment syntax.

In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 95-97: The load of gv must apply the mask for partial vector
blocks to avoid OOB reads: when USE_GV is true, change the load of p_gv (symbol
b_gv) to use mask_v (the mask for the last V block) instead of an unconditional
tl.load; keep the subsequent scaling of b_h (symbol b_h *= exp(b_gv[None, :]))
the same so that only valid lanes are loaded and used when V % BV != 0.

In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 199-237: Update the return type annotation of recompute_w_u_fwd to
match the actual returned values (w, u, vg): change the declared return from
Tuple[torch.Tensor, torch.Tensor] to Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] and import Optional if not already present; ensure the
function signature and any callers/types align with the new signature for
recompute_w_u_fwd.
- Around line 247-261: The gv parameter is declared without Optional typing and
the code unconditionally allocates dgv with torch.empty_like(gv), which will
crash if gv is None; update the function signature to annotate gv as
Optional[torch.Tensor] and change the local dgv to be Optional[torch.Tensor] (or
torch.Tensor | None) and only allocate dgv when gv is not None (e.g., after
checking gv) — leave dgv as None otherwise; ensure any later uses of dgv handle
the None case or assert/raise if those code paths require gv to be present.
🧹 Nitpick comments (12)
fla/ops/gated_oja_rule/chunk_kkt.py (1)

131-143: Inconsistent naming convention for block pointer.

The variable b_kt at line 131 is a block pointer (created via tl.make_block_ptr), but follows the b_ prefix convention used for block tensors throughout this file. Consider renaming to p_kt for consistency with other pointers (p_k, p_g, etc.).

♻️ Suggested naming fix
-        b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+        p_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))

And at line 143:

-        b_kt = tl.load(b_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
+        b_kt = tl.load(p_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
fla/ops/gated_oja_rule/wy_fast.py (2)

11-11: Remove unused import.

The static analysis correctly identifies that chunk_local_cumsum is imported but not used in this file.

♻️ Proposed fix
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices

193-193: Remove or complete the commented-out code.

Line 193 contains a commented-out conditional # if USE_GV:. This appears to be either dead code or an incomplete TODO. Please remove it or implement the intended logic.

fla/ops/gated_oja_rule/chunk.py (1)

283-287: Add stacklevel to warnings.warn.

Per best practices, specify stacklevel=2 so the warning points to the caller's location rather than this line.

♻️ Proposed fix
     if 'head_first' in kwargs:
         warnings.warn(
             "head_first is deprecated and will be removed in a future version. "
-            "Please use head_first=False for now instead."
+            "Please use head_first=False for now instead.",
+            stacklevel=2
         )
fla/ops/gated_oja_rule/fused_recurrent.py (1)

133-133: Document or relax the V <= 128 constraint.

The assertion assert V <= 128 limits the value dimension without explanation. Consider adding a comment explaining why this limit exists, or raising a more informative error.

♻️ Suggested improvement
-    assert V <= 128
+    if V > 128:
+        raise ValueError(
+            f"fused_recurrent_oja_fwd currently supports V <= 128, got V={V}. "
+            "Use chunk_gated_oja_rule for larger value dimensions."
+        )
tests/ops/test_oja.py (4)

4-4: Remove unused imports.

Optional from typing and repeat from einops are imported but never used.

♻️ Proposed fix
-from typing import List, Optional
+from typing import List
-from einops import rearrange, repeat
+from einops import rearrange

82-82: Rename ambiguous variable l.

The variable l at line 82 is flagged by linters as ambiguous (looks like 1). Consider renaming to seq_len or L for clarity.

♻️ Proposed fix
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape
     d_v = v.shape[-1]
     q = q * scale # B H T D
-    assert l % chunk_size == 0
+    assert seq_len % chunk_size == 0

And update other usages of l (lines 85, 121) to seq_len.


341-341: Remove debug print statement.

The print statement at line 341 appears to be debug output. Consider removing it or using proper logging.

♻️ Proposed fix
-    print('================== Running forward and backward ==================')

412-412: Consider isolating environment variable modification.

Setting os.environ['TRITON_F32_DEFAULT'] at line 412 persists beyond this test and may affect subsequent tests. Consider using a fixture or context manager to ensure cleanup.

♻️ Suggested approach
@pytest.fixture(autouse=True)
def set_triton_f32_default():
    old_value = os.environ.get('TRITON_F32_DEFAULT')
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
    yield
    if old_value is None:
        del os.environ['TRITON_F32_DEFAULT']
    else:
        os.environ['TRITON_F32_DEFAULT'] = old_value
fla/ops/gated_oja_rule/chunk_h.py (2)

480-484: Unused chunk_indices computation.

chunk_indices is computed but never passed to the kernel. Consider removing this unnecessary computation.

Proposed fix
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
-        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
+        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
+        N = len(cu_seqlens) - 1
+        NT = chunk_offsets[-1].item()  # or compute directly

386-386: Use ASCII commas in comments for consistency.

The comment contains fullwidth commas (,) which triggers linter warnings. Consider using standard ASCII commas or translating comments to English.

fla/ops/gated_oja_rule/chunk_o.py (1)

453-461: Redundant computation of attention matrix A.

The attention matrix A is computed identically in each i_k block and then summed (line 540), which is wasteful. Since A = dot(q*scale, k.T) is the same regardless of which K-block is being processed, this results in NK redundant computations.

Consider computing A once in a separate kernel or only in the first K-block.

Comment on lines +148 to +163
        if USE_GV:
            o_v1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
            b_h1 *= exp(b_gk_last1)[None, :]
            if K > 64:
                o_v2 = 64 + o_v1
                b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
                b_h2 *= exp(b_gk_last2)[None, :]
            if K > 128:
                o_v3 = 128 + o_v1
                b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
                b_h3 *= exp(b_gk_last3)[None, :]
            if K > 192:
                o_v4 = 192 + o_v1
                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
                b_h4 *= exp(b_gk_last4)[None, :]

⚠️ Potential issue | 🔴 Critical

Critical bug: Incorrect dimension check for gv scaling.

The conditionals check K (key dimension) but should check V (value dimension). The gv tensor has shape [..., V], not [..., K]. This causes incorrect scaling when K != V.

Proposed fix
         if USE_GV:
             o_v1 = tl.arange(0, 64)
             b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
             b_h1 *= exp(b_gk_last1)[None, :]
-            if K > 64:
+            if V > 64:
                 o_v2 = 64 + o_v1
                 b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
                 b_h2 *= exp(b_gk_last2)[None, :]
-            if K > 128:
+            if V > 128:
                 o_v3 = 128 + o_v1
                 b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
                 b_h3 *= exp(b_gk_last3)[None, :]
-            if K > 192:
+            if V > 192:
                 o_v4 = 192 + o_v1
-                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
+                b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
                 b_h4 *= exp(b_gk_last4)[None, :]

Comment on lines 396 to 403
            p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
            b_gv = tl.load(p_gv, boundary_check=(0, 1))
            if USE_GV:
                o_v1 = tl.arange(0, 64)
                b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
                b_dh1 *= exp(b_gv_last1[None, :])
                b_do *= exp(b_gv)
            b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

⚠️ Potential issue | 🟠 Major

Potential null pointer access when USE_GV is False.

The code loads from gv (lines 396-397) unconditionally within the if V > 0 block, but gv is only offset at line 319 when USE_GV is True. When USE_GV is False and gv is None, this will cause a runtime error.

Proposed fix: Guard gv loads with USE_GV check
         if V > 0:
             p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
             b_do = tl.load(p_do, boundary_check=(0, 1))
             p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
             b_w = tl.load(p_w, boundary_check=(0, 1))
-            p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
-            b_gv = tl.load(p_gv, boundary_check=(0, 1))
             if USE_GV:
+                p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0)) # [BT, BV]
+                b_gv = tl.load(p_gv, boundary_check=(0, 1))
                 o_v1 = tl.arange(0, 64)
                 b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
                 b_dh1 *= exp(b_gv_last1[None, :])
                 b_do *= exp(b_gv)
             b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w) # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

Apply similar changes for the if V > 64, if V > 128, and if V > 192 blocks.


Comment on lines 531 to 651
def chunk_gsa_bwd_k_kernel_dqkvg(
q,
k,
v,
h,
g,
A,
do,
dh,
dq,
dk,
dv,
dg,
dgv,
dA,
cu_seqlens,
chunk_indices,
scale,
T,
B: tl.constexpr,
HQ: tl.constexpr,
H: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
NG: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_hq = i_bh // HQ, i_bh % HQ
i_h = i_hq // NG
if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
all = T
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
all = B * T

o_i = tl.arange(0, BT)
o_t = min(i_t * BT + BT, T)
m_s = o_i[:, None] >= o_i[None, :]

p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BT, BT]
b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
b_A = tl.where(m_s, b_A, 0.)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
for i_v in range(tl.cdiv(V, BV)):
o_v = i_v * BV + tl.arange(0, BV)
p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
m_v = o_v < V

# [BV,]
b_gn = tl.load(p_gn, mask=m_v, other=0)
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))
b_gv = exp(b_gn[None, :] - b_g)
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))
b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
# [BK, BV]
b_dh = tl.load(p_dh, boundary_check=(0, 1))
# [BV]
b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

b_dh = b_dh.to(b_k.dtype)
# [BT, BK]
b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
# [BT, BV]
b_dv = tl.dot(b_k, b_dh) * b_gv
# [BV]
b_dg += tl.sum(b_dv * b_v, 0)

if i_k == 0:
b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
else:
b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
# [BT, BT]
b_dA = tl.load(p_dA, boundary_check=(0, 1))
# [BT, BK]
b_dq += tl.dot(b_dA, b_k)
b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

rg -n "chunk_gsa_bwd_k_kernel_dqkvg" --type py

Repository: fla-org/flash-linear-attention

Length of output: 269


🏁 Script executed:

# Check the kernel definition in both files
echo "=== gated_oja_rule/chunk_h.py definition ===" 
sed -n '531,533p' fla/ops/gated_oja_rule/chunk_h.py

echo -e "\n=== gsa/chunk.py definition ===" 
sed -n '297,299p' fla/ops/gsa/chunk.py

# Get some context about imports and structure in gated_oja_rule/chunk_h.py
echo -e "\n=== gated_oja_rule/chunk_h.py file overview ===" 
head -30 fla/ops/gated_oja_rule/chunk_h.py

# Check if the kernel is imported from gsa or defined locally
echo -e "\n=== Checking for imports of chunk_gsa_bwd_k_kernel_dqkvg in gated_oja_rule ===" 
rg "import.*chunk_gsa_bwd_k_kernel_dqkvg|from.*chunk_gsa" fla/ops/gated_oja_rule/

Repository: fla-org/flash-linear-attention

Length of output: 1349


🏁 Script executed:

# Check if chunk_gsa_bwd_k_kernel_dqkvg is called anywhere in gated_oja_rule/chunk_h.py
echo "=== Searching for calls to chunk_gsa_bwd_k_kernel_dqkvg in gated_oja_rule ===" 
rg "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.py

# Check what's defined in gated_oja_rule/chunk_h.py (list all function/kernel definitions)
echo -e "\n=== All kernel definitions in gated_oja_rule/chunk_h.py ===" 
rg "^def chunk.*kernel" fla/ops/gated_oja_rule/chunk_h.py

# See the whole file length
echo -e "\n=== File length ===" 
wc -l fla/ops/gated_oja_rule/chunk_h.py

# Check what's exported or used from gated_oja_rule
echo -e "\n=== Functions called/exported in gated_oja_rule ===" 
rg "^def [a-z_]+\(" fla/ops/gated_oja_rule/chunk_h.py | head -20

Repository: fla-org/flash-linear-attention

Length of output: 608


🏁 Script executed:

# Compare the function definition at the end to see if there's a wrapper function
echo "=== Last 30 lines of gated_oja_rule/chunk_h.py ===" 
tail -30 fla/ops/gated_oja_rule/chunk_h.py

# Check if there's any Python function wrapper that calls this kernel
echo -e "\n=== Checking for wrapper functions that might call the kernel ===" 
rg -A 5 "def chunk_gsa_bwd_k" fla/ops/gated_oja_rule/chunk_h.py | head -40

Repository: fla-org/flash-linear-attention

Length of output: 950


🏁 Script executed:

# Get line count of the kernel definition to understand its size
echo "=== Getting the kernel definition size ===" 
awk '/^def chunk_gsa_bwd_k_kernel_dqkvg/,/^def / {print NR": "$0}' fla/ops/gated_oja_rule/chunk_h.py | tail -5

# Compare the definitions more carefully - get line ranges
echo -e "\n=== Checking what comes after the kernel definition ===" 
sed -n '650,655p' fla/ops/gated_oja_rule/chunk_h.py

# Check if there are any public wrapper functions that use this kernel
echo -e "\n=== Public functions that might export or use the kernel ===" 
rg "^def chunk_gsa_bwd" fla/ops/gated_oja_rule/chunk_h.py

Repository: fla-org/flash-linear-attention

Length of output: 403


🏁 Script executed:

# Let me check the exports or __all__ to see if it's intentionally exposed
echo "=== Check __all__ or public exports in gated_oja_rule/chunk_h.py ===" 
rg "__all__|^from|^import" fla/ops/gated_oja_rule/chunk_h.py | head -20

# Check if the gated_oja_rule module imports from gsa
echo -e "\n=== Check if gated_oja_rule imports the kernel from gsa ===" 
rg "from.*gsa.*import|import.*gsa" fla/ops/gated_oja_rule/

# Verify the exact line where the kernel ends in gated_oja_rule
echo -e "\n=== Lines 645-655 to see what comes after kernel ===" 
sed -n '645,655p' fla/ops/gated_oja_rule/chunk_h.py

Repository: fla-org/flash-linear-attention

Length of output: 815


🏁 Script executed:

# Check the wrapper function pattern in gsa/chunk.py around the kernel call
echo "=== Context around the kernel call in gsa/chunk.py (line 739) ===" 
sed -n '730,750p' fla/ops/gsa/chunk.py

# Check if there's a public wrapper function in gated_oja_rule that might call chunk_gsa_bwd_k_kernel_dqkvg
echo -e "\n=== All public functions (non-kernel) in gated_oja_rule/chunk_h.py ===" 
rg "^def [a-z_]+\(" fla/ops/gated_oja_rule/chunk_h.py

# Check if any of these functions contain the kernel grid call
echo -e "\n=== Check if chunk_gsa_bwd_k_kernel_dqkvg is called via grid ===" 
rg "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.py

Repository: fla-org/flash-linear-attention

Length of output: 848


🏁 Script executed:

# Double check: search for any function that might be using the kernel with array indexing
echo "=== Final verification: any use of chunk_gsa_bwd_k_kernel_dqkvg[ in gated_oja_rule ===" 
grep -n "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.py || echo "Not found"

# Check module structure - is this module even exposed?
echo -e "\n=== Check __init__.py to see what's exported from gated_oja_rule ===" 
find fla/ops/gated_oja_rule -name "__init__.py" -exec cat {} \;

Repository: fla-org/flash-linear-attention

Length of output: 420


Remove chunk_gsa_bwd_k_kernel_dqkvg as it is unused dead code in this module.

This kernel is defined in fla/ops/gsa/chunk.py and actively used there, but the definition in fla/ops/gated_oja_rule/chunk_h.py is never called. It should be removed to avoid maintaining duplicate code and confusion about the module's purpose (OJA operations).


Comment on lines 747 to 767
    if USE_GV:
        b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

    p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))

    b_dgv_last += tl.sum(b_dv * b_v, axis=0)

    # interface reserved for GSA2
    if HAVE_GK:
        dgk += (bos * H + i_h) * V
        p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
        b_dgv_last = b_dgk + b_dgv_last[None, :]
    else:
        b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]

    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

⚠️ Potential issue | 🔴 Critical

Critical bug: b_dv undefined when USE_GV is False.

When USE_GV is False, b_dv is never assigned (line 748 is inside if USE_GV), but it's used unconditionally at lines 756 and 767. This will cause a runtime error.

Proposed fix
     if USE_GV:
         b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
-    
+    else:
+        b_dv = b_dvg
+
     p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

Comment on lines 8 to 15
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp

⚠️ Potential issue | 🟡 Minor

Remove redundant import and duplicate exp definition.

exp is imported from fla.ops.utils.op at line 8 but immediately overwritten with tl.exp at line 15. Also, chunk_local_cumsum is imported but never used. Remove the unused imports.

Proposed fix
 from fla.ops.utils import prepare_chunk_indices
-from fla.ops.utils.op import exp
 from fla.utils import check_shared_mem, is_nvidia_hopper
-from fla.ops.utils.cumsum import chunk_local_cumsum
 
 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
 
 exp = tl.exp
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused

(F401)


[error] 15-15: redefinition of unused 'exp' from line 8

(F811)

🪛 GitHub Actions: lint

[error] 13-13: Ruff: F811 Redefinition of unused 'exp' from line 7.

🪛 Ruff (0.14.14)

15-15: Redefinition of unused exp from line 8: exp redefined here

(F811)


Comment on lines 1 to 2
# # -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

⚠️ Potential issue | 🟡 Minor

Fix comment syntax.

Line 1 has doubled comment markers # # which appears to be a typo.

🐛 Proposed fix
-# # -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
🧰 Tools
🪛 GitHub Actions: lint

[error] 1-1: Trailing whitespace detected by pre-commit; file was modified.


Comment on lines +95 to +97
        if USE_GV:
            b_gv = tl.load(p_gv).to(tl.float32)
            b_h *= exp(b_gv[None, :])

⚠️ Potential issue | 🟡 Minor

Add mask to gv load for partial blocks.

When V is not a multiple of BV, the load at line 96 should use mask_v to avoid reading out-of-bounds memory.

🐛 Proposed fix
         if USE_GV:
-            b_gv = tl.load(p_gv).to(tl.float32)
+            b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
             b_h *= exp(b_gv[None, :])
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/fused_recurrent.py` around lines 95 - 97, The load of
gv must apply the mask for partial vector blocks to avoid OOB reads: when USE_GV
is true, change the load of p_gv (symbol b_gv) to use mask_v (the mask for the
last V block) instead of an unconditional tl.load; keep the subsequent scaling
of b_h (symbol b_h *= exp(b_gv[None, :])) the same so that only valid lanes are
loaded and used when V % BV != 0.
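
For illustration, a minimal self-contained Triton sketch of the masked-load pattern (the kernel and names below are hypothetical, not the PR's fused_recurrent kernel): lanes whose index reaches V are masked out, so a block size that does not divide V never reads out of bounds.

# Hypothetical sketch; assumes a CUDA device with Triton installed.
import torch
import triton
import triton.language as tl


@triton.jit
def masked_copy_kernel(x, y, V, BV: tl.constexpr):
    o_v = tl.program_id(0) * BV + tl.arange(0, BV)  # per-lane offsets in this block
    mask_v = o_v < V                                 # lanes past V are out of bounds
    b_x = tl.load(x + o_v, mask=mask_v, other=0.)    # masked lanes read the neutral value
    tl.store(y + o_v, b_x, mask=mask_v)              # and are never written back


x = torch.randn(100, device='cuda')                  # V=100 is not a multiple of BV=64
y = torch.empty_like(x)
masked_copy_kernel[(triton.cdiv(100, 64),)](x, y, V=100, BV=64)
assert torch.equal(x, y)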

Comment on lines 199 to 237
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    gv: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    w = torch.empty_like(v)
    u = torch.empty_like(k)
    vg = torch.empty_like(v) if gv is not None else None
    recompute_w_u_fwd_kernel[(NT, B*H)](
        k=k,
        v=v,
        vg=vg,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gv=gv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u, vg

⚠️ Potential issue | 🟡 Minor

Return type annotation is incorrect.

The function returns three values (w, u, vg) at line 237, but the type hint at line 206 declares Tuple[torch.Tensor, torch.Tensor]. Update to reflect the actual return type.

🐛 Proposed fix
 def recompute_w_u_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     gv: Optional[torch.Tensor] = None,
     cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 199 - 237, Update the return
type annotation of recompute_w_u_fwd to match the actual returned values (w, u,
vg): change the declared return from Tuple[torch.Tensor, torch.Tensor] to
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] and import Optional if
not already present; ensure the function signature and any callers/types align
with the new signature for recompute_w_u_fwd.

Comment on lines 247 to 261
    gv: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = 64
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v, dtype=torch.float)

    dgv = torch.empty_like(gv, dtype=torch.float)

⚠️ Potential issue | 🟡 Minor

Add type hint and guard for gv parameter.

The gv parameter at line 247 lacks a proper type hint (defaults to None but no Optional annotation). Additionally, dgv is unconditionally allocated at line 261 using torch.empty_like(gv), which will fail if gv is None.

🐛 Proposed fix
 def prepare_wy_repr_bwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     dw: torch.Tensor,
     du: torch.Tensor,
-    gv: torch.Tensor = None,
+    gv: Optional[torch.Tensor] = None,
     cu_seqlens: Optional[torch.LongTensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv is required for prepare_wy_repr_bwd")
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 247 - 261, The gv parameter
is declared without Optional typing and the code unconditionally allocates dgv
with torch.empty_like(gv), which will crash if gv is None; update the function
signature to annotate gv as Optional[torch.Tensor] and change the local dgv to
be Optional[torch.Tensor] (or torch.Tensor | None) and only allocate dgv when gv
is not None (e.g., after checking gv) — leave dgv as None otherwise; ensure any
later uses of dgv handle the None case or assert/raise if those code paths
require gv to be present.
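
As a minimal sketch of the guarded allocation described above (the helper name is hypothetical, not part of the PR): annotate the optional input and only allocate its gradient buffer when it is present, mirroring the vg allocation already used in recompute_w_u_fwd.

from typing import Optional

import torch


def alloc_optional_grad(gv: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Allocate a float32 gradient buffer only when the gated path is active.
    return torch.empty_like(gv, dtype=torch.float) if gv is not None else None


assert alloc_optional_grad(None) is None                             # no crash when gv is absent
assert alloc_optional_grad(torch.zeros(1, 8, 2, 16)).dtype == torch.float32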

@zhiyuan1i changed the title from "Upload gated_oja_operator and test" to "[OJA] Integrate Gated OJA Rule" on Feb 4, 2026
@ntumm120

ntumm120 commented Feb 9, 2026

Cool work @AwesomeSeq! Have you trained any 340M/1.3B models with this recurrence yet? If so, is there a paper I can refer to?


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 194-238: The return type annotation of chunk_oja_fwd_h is wrong:
the function actually returns three values (h, k_new, final_state) where k_new
and final_state can be None depending on save_new_key and output_final_state;
update the function signature's return type to reflect three elements (e.g.
tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]) and ensure any
callers or tests expecting the old two-tuple are adjusted accordingly; reference
the chunk_oja_fwd_h definition and the variables h, k_new, final_state in your
change.

In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 178-241: The function chunk_oja_fwd_o has a return type mismatch:
its annotation declares four tensors but the implementation returns only A and
o; update the function signature's return annotation to match the actual return
(tuple[torch.Tensor, torch.Tensor]) or modify the body to return the additional
tensors if intended; locate chunk_oja_fwd_o and change the annotated return type
to reflect only A and o (or add the missing tensors to the return) and ensure
callers expect the corrected shape.

In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 287-297: The error messages use adjacent f-strings that get
implicitly concatenated without a separating space; update the messages in the
cu_seqlens check (where variables q, cu_seqlens, and initial_state are
referenced) to ensure proper spacing — either merge the two f-strings into one
or insert an explicit leading/trailing space or punctuation between them so the
resulting strings read correctly (do the same fix in the analogous checks in
fused_recurrent.py around the initial_state/cu_seqlens validation).

In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 61-79: The kernel unconditionally loads from gv (e.g.,
tl.load(p_gv) and tl.load(gv + ...)) which will crash if gv is None; update the
kernels (recompute_w_u_fwd_kernel and prepare_wy_repr_bwd_kernel) to either
(preferred) add a compile-time/use-time guard like a boolean USE_GV and wrap all
gv loads and STORE_VG-dependent logic (the p_gv/tl.load uses and computing
b_vb/b_vg) behind if USE_GV so the code never dereferences gv when absent, or
alternatively make gv a required parameter in the Python wrappers so callers
cannot pass None; ensure referenced symbols include gv, p_gv, b_gv, b_gn,
STORE_VG, and vg when applying the guard so no tl.load or tl.store touches gv/
vg unless USE_GV is true.

In `@tests/ops/test_oja.py`:
- Around line 404-407: Fix two issues: update the skip reason text and avoid
mutating global env. Change the pytest.skip call (condition using
is_intel_alchemist and D) to use the correct message 'chunk_gated_oja_rule'
instead of 'chunk_gated_delta_rule'; and replace the direct
os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
🧹 Nitpick comments (2)
fla/ops/gated_oja_rule/fused_recurrent.py (1)

117-168: Use explicit T | None for optional parameters.

Several parameters use implicit Optional (PEP 484 violation): scale at line 123, initial_state at line 124. This also applies to FusedRecurrentFunction.forward (line 182) and fused_recurrent_gated_oja_rule (lines 221-222).

Proposed fix (for the wrapper)
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
tests/ops/test_oja.py (1)

337-337: Remove leftover debug print statement.

Line 337 contains a print(...) that shouldn't be in committed test code. Also, there are leftover # breakpoint() comments at lines 37 and 369.

Proposed fix
-    print('================== Running forward and backward ==================')

Comment on lines +194 to +238
def chunk_oja_fwd_h(
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    gv: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
    save_new_key: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, V, K = *v.shape, u.shape[-1]
    BT = chunk_size

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    assert V <= 256, "current kernel does not support head dimension larger than 256."

    h = v.new_empty(B, NT, H, K, V)
    final_state = v.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    k_new = torch.empty_like(u) if save_new_key else None
    def grid(meta): return (triton.cdiv(K, meta['BK']), N*H)
    chunk_oja_fwd_kernel_h_blockdim64[grid](
        v=v,
        u=u,
        w=w,
        k_new=k_new,
        gv=gv,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT
    )
    return h, k_new, final_state

⚠️ Potential issue | 🟡 Minor

Return type annotation is incorrect on chunk_oja_fwd_h.

The function returns three values (h, k_new, final_state) at line 238, but the type hint at line 204 declares tuple[torch.Tensor, torch.Tensor].

🐛 Proposed fix
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 194 - 238, The return type
annotation of chunk_oja_fwd_h is wrong: the function actually returns three
values (h, k_new, final_state) where k_new and final_state can be None depending
on save_new_key and output_final_state; update the function signature's return
type to reflect three elements (e.g. tuple[torch.Tensor, torch.Tensor | None,
torch.Tensor | None]) and ensure any callers or tests expecting the old
two-tuple are adjusted accordingly; reference the chunk_oja_fwd_h definition and
the variables h, k_new, final_state in your change.

Comment on lines +178 to +241
def chunk_oja_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gv: torch.Tensor,
    h: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    NG = HQ // H

    o = v.new_empty(B, T, HQ, V)
    A = q.new_empty(B, T, HQ, BT)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_oja_fwd_inter[grid](
        q,
        k,
        h,
        gv,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
    )

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_oja_fwd_intra[grid](
        v,
        gv,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    return A, o

⚠️ Potential issue | 🟡 Minor

Return type annotation declares 4 tensors but only 2 are returned.

Line 187 declares -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] but line 241 returns A, o (2 tensors).

🐛 Proposed fix
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 178 - 241, The function
chunk_oja_fwd_o has a return type mismatch: its annotation declares four tensors
but the implementation returns only A and o; update the function signature's
return annotation to match the actual return (tuple[torch.Tensor, torch.Tensor])
or modify the body to return the additional tensors if intended; locate
chunk_oja_fwd_o and change the annotated return type to reflect only A and o (or
add the missing tensors to the return) and ensure callers expect the corrected
shape.

Comment on lines +287 to +297
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )

⚠️ Potential issue | 🟡 Minor

Missing space between concatenated f-strings in error messages.

The two adjacent f-strings at lines 290–291 and 295–296 are implicitly concatenated without a separator, producing messages like "...cu_seqlens.Please flatten...". The same issue exists in `fused_recurrent.py` lines 237–239.

🐛 Proposed fix
         if q.shape[0] != 1:
             raise ValueError(
-                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                 f"Please flatten variable-length inputs before processing."
             )
         if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
             raise ValueError(
-                f"The number of initial states is expected to be equal to the number of input sequences, "
+                f"The number of initial states is expected to be equal to the number of input sequences, "
                 f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
             )
🧰 Tools
🪛 Ruff (0.14.14)

[warning] 289-292: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 294-297: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 287 - 297, The error messages
use adjacent f-strings that get implicitly concatenated without a separating
space; update the messages in the cu_seqlens check (where variables q,
cu_seqlens, and initial_state are referenced) to ensure proper spacing — either
merge the two f-strings into one or insert an explicit leading/trailing space or
punctuation between them so the resulting strings read correctly (do the same
fix in the analogous checks in fused_recurrent.py around the
initial_state/cu_seqlens validation).
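
A quick plain-Python illustration of the behavior being flagged (not project code): adjacent string literals, f-strings included, are concatenated with no separator, so the trailing period runs straight into the next word unless a space is added.

# Implicit concatenation of adjacent (f-)string literals inserts nothing between them.
batch_size = 4
msg = (
    f"The batch size is expected to be 1 rather than {batch_size} when using `cu_seqlens`."
    f"Please flatten variable-length inputs before processing."
)
print(msg)  # "...when using `cu_seqlens`.Please flatten..."; note the missing space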

Comment on lines +61 to +79
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = b_v * b_b[:, None]

        p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vb *= exp(b_gv)
        if STORE_VG:
            last_idx = min(i_t * BT + BT, T) - 1

            o_v = i_v * BV + tl.arange(0, BV)
            m_v = o_v < V
            b_gn = tl.load(gv + ((bos + last_idx) * H + i_h) * V + o_v, mask=m_v, other=0.)
            b_vg = b_v * exp(b_gn - b_gv)

            p_vg = tl.make_block_ptr(vg + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_vg, b_vg.to(p_vg.dtype.element_ty), boundary_check=(0, 1))

⚠️ Potential issue | 🟡 Minor

gv is loaded unconditionally in the kernel — will crash if gv is None.

Both recompute_w_u_fwd_kernel (line 67) and prepare_wy_repr_bwd_kernel (line 152) load from gv without guarding on whether gv is actually provided. Although callers in chunk.py always pass a valid gv, the Python wrapper signatures allow gv=None. Either add a USE_GV heuristic guard in the kernels or make gv a required parameter in the wrappers to prevent a latent null-pointer crash.

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 61 - 79, The kernel
unconditionally loads from gv (e.g., tl.load(p_gv) and tl.load(gv + ...)) which
will crash if gv is None; update the kernels (recompute_w_u_fwd_kernel and
prepare_wy_repr_bwd_kernel) to either (preferred) add a compile-time/use-time
guard like a boolean USE_GV and wrap all gv loads and STORE_VG-dependent logic
(the p_gv/tl.load uses and computing b_vb/b_vg) behind if USE_GV so the code
never dereferences gv when absent, or alternatively make gv a required parameter
in the Python wrappers so callers cannot pass None; ensure referenced symbols
include gv, p_gv, b_gv, b_gn, STORE_VG, and vg when applying the guard so no
tl.load or tl.store touches gv/vg unless USE_GV is true.
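
For reference, a minimal self-contained sketch of the constexpr-guard pattern described above (the kernel, wrapper, and names are illustrative, not the PR's wy_fast kernels): when the wrapper passes gv=None it also passes USE_GV=False, so the gv loads are compiled out and the null pointer is never dereferenced.

# Hypothetical sketch of a compile-time USE_GV guard; assumes a CUDA device with Triton.
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x, gv, y, V, BV: tl.constexpr, USE_GV: tl.constexpr):
    o_v = tl.program_id(0) * BV + tl.arange(0, BV)
    m_v = o_v < V
    b_x = tl.load(x + o_v, mask=m_v, other=0.)
    if USE_GV:  # gv is only dereferenced when the flag is set at compile time
        b_gv = tl.load(gv + o_v, mask=m_v, other=0.)
        b_x *= tl.exp(b_gv)
    tl.store(y + o_v, b_x, mask=m_v)


def scale(x: torch.Tensor, gv: torch.Tensor | None = None) -> torch.Tensor:
    y = torch.empty_like(x)
    V, BV = x.numel(), 64
    # Passing gv=None is safe because the kernel never touches it when USE_GV is False.
    scale_kernel[(triton.cdiv(V, BV),)](x, gv, y, V, BV=BV, USE_GV=gv is not None)
    return y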

Comment on lines +404 to +407
    if is_intel_alchemist and D > 128:
        pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
    torch.manual_seed(42)
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'

⚠️ Potential issue | 🟡 Minor

Two issues: incorrect skip message and global env var side-effect.

  1. Line 405: The skip reason says chunk_gated_delta_rule but should say chunk_gated_oja_rule.
  2. Line 407: Setting os.environ['TRITON_F32_DEFAULT'] = 'ieee' modifies global process state, which can leak into other tests. Consider scoping this with monkeypatch or restoring the original value in a fixture.
Proposed fix for the skip message
-        pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
+        pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')
🤖 Prompt for AI Agents
In `@tests/ops/test_oja.py` around lines 404 - 407, Fix two issues: update the
skip reason text and avoid mutating global env. Change the pytest.skip call
(condition using is_intel_alchemist and D) to use the correct message
'chunk_gated_oja_rule' instead of 'chunk_gated_delta_rule'; and replace the
direct os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
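
A minimal sketch of the monkeypatch-scoped variant (test name and body are illustrative, not the actual test): pytest reverts the environment variable automatically at teardown, so other tests are unaffected.

# Hypothetical sketch: scope TRITON_F32_DEFAULT to a single test via the monkeypatch fixture.
import torch


def test_chunk_gated_oja_rule(monkeypatch):
    monkeypatch.setenv('TRITON_F32_DEFAULT', 'ieee')  # restored after the test finishes
    torch.manual_seed(42)
    # ... the rest of the test body stays as in the file ...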
