[OJA] Integrate Gated OJA Rule #730
base: main
Conversation
Walkthrough
Adds a complete gated OJA-rule implementation: chunked and fused-recurrent Python bindings plus many Triton kernels (KKT, hidden-state, output, WY recompute/prepare), variable-length sequence support, public API exports, and comprehensive tests validating forward and backward behavior.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    actor User
    participant ChunkAPI as chunk_gated_oja_rule
    participant ChunkFunc as ChunkOJAFunction
    participant KKT as chunk_scaled_dot_kkt_fwd
    participant WY as recompute_w_u_fwd
    participant Hkern as chunk_oja_fwd_h
    participant Okern as chunk_oja_fwd_o
    User->>ChunkAPI: q,k,v,gv,beta,...
    ChunkAPI->>ChunkFunc: apply forward
    ChunkFunc->>KKT: compute A
    KKT-->>ChunkFunc: A
    ChunkFunc->>WY: recompute w/u (from k,v,A)
    WY-->>ChunkFunc: w,u,vg
    ChunkFunc->>Hkern: compute h (hidden states)
    Hkern-->>ChunkFunc: h,final_state
    ChunkFunc->>Okern: compute o
    Okern-->>ChunkFunc: o
    ChunkFunc-->>User: o, final_state
```
```mermaid
sequenceDiagram
    actor User
    participant UserAPI as fused_recurrent_gated_oja_rule
    participant FusedFunc as FusedRecurrentFunction
    participant FusedKernel as fused_recurrent_oja_fwd_kernel
    User->>UserAPI: q,k,v,gv,beta,initial_state,...
    UserAPI->>FusedFunc: forward (prepare, validate)
    FusedFunc->>FusedKernel: per-timestep fused kernel (update h, o)
    FusedKernel-->>FusedFunc: o, final_state
    FusedFunc-->>User: o, final_state
```
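For orientation, here is a rough usage sketch of the two entry points. The argument order (q, k, v, gv, beta) follows the diagrams above; the import path, keyword names, tensor layout, and dtypes are assumptions for illustration, not the final API.

```python
import torch
import torch.nn.functional as F

# Assumed export location; the PR adds these to the public fla.ops namespace.
from fla.ops.gated_oja_rule import chunk_gated_oja_rule, fused_recurrent_gated_oja_rule

B, T, H, K, V = 2, 256, 4, 64, 64                      # illustrative sizes
q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device='cuda', dtype=torch.bfloat16)
gv = F.logsigmoid(torch.randn(B, T, H, V, device='cuda'))   # log-space gate (assumed)
beta = torch.rand(B, T, H, device='cuda', dtype=torch.bfloat16)

# Chunked path (training): returns outputs and, optionally, the final recurrent state.
o, final_state = chunk_gated_oja_rule(q, k, v, gv, beta, output_final_state=True)

# Fused recurrent path (step-by-step reference/decoding), resuming from a prior state.
o_step, final_state = fused_recurrent_gated_oja_rule(
    q, k, v, gv, beta, initial_state=final_state, output_final_state=True
)
```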
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Summary of Changes
Hello @AwesomeSeq, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the library's capabilities by introducing a new gated OJA rule operator.
Code Review
The pull request introduces gated Oja operator implementations, including chunked and fused recurrent versions, along with corresponding tests. The overall structure is well-organized, separating forward and backward passes into distinct functions and Triton kernels. The addition of comprehensive test cases, including variable-length sequences and backward pass checks, is highly commendable. However, several critical bugs related to indexing and conditional variable usage in Triton kernels have been identified, which need immediate attention to ensure correctness and prevent potential runtime errors.
```python
if K > 64:
    o_v2 = 64 + o_v1
    b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
    b_h2 *= exp(b_gk_last2)[None, :]
if K > 128:
    o_v3 = 128 + o_v1
    b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
    b_h3 *= exp(b_gk_last3)[None, :]
if K > 192:
    o_v4 = 192 + o_v1
    b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
    b_h4 *= exp(b_gk_last4)[None, :]
```
The if conditions for loading b_gk_lastX are checking K (e.g., if K > 64) but the offsets o_vX and the b_hX variables are related to the V dimension. This is a logical error and can lead to incorrect memory access or calculations if K and V dimensions do not align with these hardcoded checks. The conditions should be based on V.
Suggested change:

```python
b_h1 *= exp(b_gk_last1)[None, :]
if V > 64:
    o_v2 = 64 + o_v1
    b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
    b_h2 *= exp(b_gk_last2)[None, :]
if V > 128:
    o_v3 = 128 + o_v1
    b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
    b_h3 *= exp(b_gk_last3)[None, :]
if V > 192:
    o_v4 = 192 + o_v1
    b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
    b_h4 *= exp(b_gk_last4)[None, :]
```
```python
b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0)

if USE_GV:
    b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dgv_last += tl.sum(b_dv * b_v, axis=0)
```
The variables b_gn and b_dv are used in calculations (b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0) and b_dgv_last += tl.sum(b_dv * b_v, axis=0)) without being initialized or conditionally used when USE_GV is false. This will lead to a runtime error or undefined behavior if gv is None.
Suggested change:

```python
if USE_GV:
    b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
    b_dgv_last += tl.sum((b_h * b_dh) * exp(b_gn), axis=0)
    b_dgv_last += tl.sum(b_dv * b_v, axis=0)
else:
    b_dv = b_dvg
    b_dgv_last += tl.sum(b_h * b_dh, axis=0)
    b_dgv_last += tl.sum(b_dv * b_v, axis=0)
```
```python
    return

p_g = tl.make_block_ptr(gv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0))
p_gn = gv + (bos + min(i_t * BT + i_i * BC, T)) * H*V + i_h * V + o_v
```
The calculation of p_gn uses min(i_t * BT + i_i * BC, T). This likely intends to get the last valid index of the current block. However, it should typically be min(i_t * BT + i_i * BC + BC - 1, T - 1) to correctly represent the last element's index within the block, ensuring it doesn't go out of bounds if T is exactly the end of the sequence or if the block is not full. Using T directly can lead to out-of-bounds access if T is the total length and not an index.
Suggested change:

```python
p_gn = gv + (bos + min(i_t * BT + i_i * BC + BC - 1, T - 1)) * H*V + i_h * V + o_v
```
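To make the off-by-one concrete, here is a small standalone check in plain Python. The values of BT, BC, and T are illustrative only and are not taken from the kernel's launch configuration:

```python
# Illustrative values only; the real kernel derives these from its launch/autotune config.
BT, BC, T = 64, 16, 40            # short sequence: the sub-chunk starting at 32 is partial
i_t, i_i = 0, 2                   # sub-chunk starting at position 32

start = i_t * BT + i_i * BC       # 32
original = min(start, T)          # 32 -> first row of the sub-chunk, not its last
proposed = min(start + BC - 1, T - 1)   # 39 -> last valid row, clamped to the sequence end

print(start, original, proposed)  # 32 32 39
```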
```python
if i_i > i_j:
    p_v = tl.make_block_ptr(v + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
    p_gv = tl.make_block_ptr(gv + (bos*H+i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t*BT + i_j*BC), (BV, BC), (0, 1))
    p_gn = gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v
```
Similar to chunk_oja_fwd_intra, the calculation of p_gn here (gv + (bos + i_t*BT + i_i*BC) * H*V + i_h * V + o_v) is likely incorrect for getting the last valid index of the current block. It should be min(i_t * BT + i_i * BC + BC - 1, T - 1) to ensure correct indexing and prevent potential out-of-bounds access.
Suggested change:

```python
p_gn = gv + (bos + min(i_t*BT + i_i*BC + BC - 1, T - 1)) * H*V + i_h * V + o_v
```
```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1[None, :])
    b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]

if V > 64:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 64), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v2 = 64 + o_v1
        b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
        b_dh2 *= exp(b_gv_last2[None, :])
        b_do *= exp(b_gv)
    b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 128:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 128), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v3 = 128 + o_v1
        b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
        b_dh3 *= exp(b_gv_last3[None, :])
        b_do *= exp(b_gv)
    b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)

if V > 192:
    p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))
    b_do = tl.load(p_do, boundary_check=(0, 1))
    p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_w = tl.load(p_w, boundary_check=(0, 1))
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 192), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    if USE_GV:
        o_v4 = 192 + o_v1
        b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
        b_dh4 *= exp(b_gv_last4[None, :])
        b_do *= exp(b_gv)
    b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)
```
In the chunk_oja_bwd_kernel_dhu_blockdim64 kernel, the b_dhX variables are multiplied by exp(b_gv_lastX) unconditionally if USE_GV is true. However, b_gv_lastX (for X=2,3,4) are loaded only if V is greater than a certain threshold (e.g., V > 64). If V is smaller, these b_gv_lastX variables might contain uninitialized or garbage values, leading to incorrect calculations. Each multiplication should be guarded by the corresponding if V > ... condition.
Suggested change:

```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1)[None, :]
    b_do *= exp(b_gv)
if V > 64 and USE_GV:
    o_v2 = 64 + o_v1
    b_gv_last2 = tl.load(gv + last_idx * H*V + o_v2, mask=(o_v2 < V), other=0.)
    b_dh2 *= exp(b_gv_last2)[None, :]
    b_do *= exp(b_gv)
if V > 128 and USE_GV:
    o_v3 = 128 + o_v1
    b_gv_last3 = tl.load(gv + last_idx * H*V + o_v3, mask=(o_v3 < V), other=0.)
    b_dh3 *= exp(b_gv_last3)[None, :]
    b_do *= exp(b_gv)
if V > 192 and USE_GV:
    o_v4 = 192 + o_v1
    b_gv_last4 = tl.load(gv + last_idx * H*V + o_v4, mask=(o_v4 < V), other=0.)
    b_dh4 *= exp(b_gv_last4)[None, :]
    b_do *= exp(b_gv)
```
Actionable comments posted: 10
🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Line 93: Remove the unused BV variable assignment (BV=64) in chunk_h.py:
delete the BV definition since the kernel uses a hardcoded 64 and BV is never
referenced elsewhere; ensure no other code in the module refers to BV and run
lint/tests to confirm no remaining references (look for the symbol BV in
chunk_h.py to locate the line to remove).
- Around line 148-163: The conditional blocks that apply gv scaling incorrectly
compare against K instead of V; change all occurrences of "if K > 64/128/192" to
"if V > 64/128/192" and update the final load mask from "(o_v4 < K)" to "(o_v4 <
V)" so the gv loads and masks (e.g., in the blocks computing
b_gk_last1..b_gk_last4 and multiplying b_h1..b_h4) correctly use the
value-dimension V rather than the key-dimension K.
- Around line 747-767: The code uses b_dv unconditionally but only defines it
inside the if USE_GV branch; to fix, ensure b_dv is always assigned: when USE_GV
is True compute b_dv = b_dvg * exp(b_gn[None, :] - b_gv) as before, otherwise
initialize b_dv to a zero tensor with the same shape and dtype used later (shape
[BT, BV] matching b_v and p_dv.element_ty) so subsequent operations (b_dgv_last
update, tl.store(p_dv, ...), and interaction with b_v) work correctly; update
the block so b_dv, b_dvg, b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the
referenced symbols.
- Around line 396-403: The code unconditionally creates and loads p_gv and b_gv
(using gv, p_gv, b_gv) inside the V>0 handling even when gv may be None if
USE_GV is False; wrap the creation of p_gv and any tl.load(gv + ...) or
tl.load(p_gv, ...) calls with a guard on USE_GV (same pattern used elsewhere
when gv is offset at line 319) so that all accesses to gv happen only when
USE_GV is True, and apply the same guard pattern to the other V>64, V>128, and
V>192 blocks to prevent null pointer/runtime loads when gv is not provided.
- Around line 531-651: The function chunk_gsa_bwd_k_kernel_dqkvg defined in this
file is dead/duplicated and should be removed: delete the entire
chunk_gsa_bwd_k_kernel_dqkvg(...) definition from
fla/ops/gated_oja_rule/chunk_h.py so the codebase uses the single implementation
in fla/ops/gsa/chunk.py; after removal, run tests and search for any local
references to chunk_gsa_bwd_k_kernel_dqkvg to ensure no callers depend on this
definition and update imports/call sites to reference the gsa implementation if
needed.
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 8-15: Remove the redundant and unused imports: delete the import
of exp from fla.ops.utils.op and the import of chunk_local_cumsum from
fla.ops.utils.cumsum, keeping the intended tl.exp assignment (exp = tl.exp) as
the single definition of exp; ensure no other code depends on
fla.ops.utils.op.exp or chunk_local_cumsum in this file (references to exp
should use the tl-backed exp symbol).
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 1-2: The file header lines have duplicated comment markers ("#
#"), so remove the extra '#' characters in those header comments: change the
leading "# # -*- coding: utf-8 -*-" to "# -*- coding: utf-8 -*-" and similarly
change "# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to "# Copyright (c)
2023-2025, Songlin Yang, Yu Zhang" to restore proper comment syntax.
In `@fla/ops/gated_oja_rule/fused_recurrent.py`:
- Around line 95-97: The load of gv must apply the mask for partial vector
blocks to avoid OOB reads: when USE_GV is true, change the load of p_gv (symbol
b_gv) to use mask_v (the mask for the last V block) instead of an unconditional
tl.load; keep the subsequent scaling of b_h (symbol b_h *= exp(b_gv[None, :]))
the same so that only valid lanes are loaded and used when V % BV != 0.
In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 199-237: Update the return type annotation of recompute_w_u_fwd to
match the actual returned values (w, u, vg): change the declared return from
Tuple[torch.Tensor, torch.Tensor] to Tuple[torch.Tensor, torch.Tensor,
Optional[torch.Tensor]] and import Optional if not already present; ensure the
function signature and any callers/types align with the new signature for
recompute_w_u_fwd.
- Around line 247-261: The gv parameter is declared without Optional typing and
the code unconditionally allocates dgv with torch.empty_like(gv), which will
crash if gv is None; update the function signature to annotate gv as
Optional[torch.Tensor] and change the local dgv to be Optional[torch.Tensor] (or
torch.Tensor | None) and only allocate dgv when gv is not None (e.g., after
checking gv) — leave dgv as None otherwise; ensure any later uses of dgv handle
the None case or assert/raise if those code paths require gv to be present.
🧹 Nitpick comments (12)
fla/ops/gated_oja_rule/chunk_kkt.py (1)

131-143: Inconsistent naming convention for block pointer. The variable `b_kt` at line 131 is a block pointer (created via `tl.make_block_ptr`), but follows the `b_` prefix convention used for block tensors throughout this file. Consider renaming to `p_kt` for consistency with other pointers (`p_k`, `p_g`, etc.).

♻️ Suggested naming fix

```diff
- b_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
+ p_kt = tl.make_block_ptr(k, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
```

And at line 143:

```diff
- b_kt = tl.load(b_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
+ b_kt = tl.load(p_kt, boundary_check=(0, 1)) * exp(b_gn[:, None] - b_gk)
```

fla/ops/gated_oja_rule/wy_fast.py (2)
11-11: Remove unused import. The static analysis correctly identifies that `chunk_local_cumsum` is imported but not used in this file.

♻️ Proposed fix

```diff
-from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices
+from fla.ops.utils import prepare_chunk_indices
```

193-193: Remove or complete the commented-out code. Line 193 contains a commented-out conditional `# if USE_GV:`. This appears to be either dead code or an incomplete TODO. Please remove it or implement the intended logic.

fla/ops/gated_oja_rule/chunk.py (1)
283-287: Add `stacklevel` to `warnings.warn`. Per best practices, specify `stacklevel=2` so the warning points to the caller's location rather than this line.

♻️ Proposed fix

```diff
 if 'head_first' in kwargs:
     warnings.warn(
         "head_first is deprecated and will be removed in a future version. "
-        "Please use head_first=False for now instead."
+        "Please use head_first=False for now instead.",
+        stacklevel=2
     )
```

fla/ops/gated_oja_rule/fused_recurrent.py (1)
133-133: Document or relax the V <= 128 constraint. The assertion `assert V <= 128` limits the value dimension without explanation. Consider adding a comment explaining why this limit exists, or raising a more informative error.

♻️ Suggested improvement

```diff
-    assert V <= 128
+    if V > 128:
+        raise ValueError(
+            f"fused_recurrent_oja_fwd currently supports V <= 128, got V={V}. "
+            "Use chunk_gated_oja_rule for larger value dimensions."
+        )
```

tests/ops/test_oja.py (4)
4-4: Remove unused imports. `Optional` from `typing` and `repeat` from `einops` are imported but never used.

♻️ Proposed fix

```diff
-from typing import List, Optional
+from typing import List
-from einops import rearrange, repeat
+from einops import rearrange
```
82-82: Rename ambiguous variable `l`. The variable `l` at line 82 is flagged by linters as ambiguous (looks like `1`). Consider renaming to `seq_len` or `L` for clarity.

♻️ Proposed fix

```diff
-    b, h, l, d_k = q.shape
+    b, h, seq_len, d_k = q.shape
     d_v = v.shape[-1]
     q = q * scale  # B H T D
-    assert l % chunk_size == 0
+    assert seq_len % chunk_size == 0
```

And update other usages of `l` (lines 85, 121) to `seq_len`.
341-341: Remove debug print statement.

♻️ Proposed fix

```diff
-    print('================== Running forward and backward ==================')
```
412-412: Consider isolating environment variable modification. Setting `os.environ['TRITON_F32_DEFAULT']` at line 412 persists beyond this test and may affect subsequent tests. Consider using a fixture or context manager to ensure cleanup.

♻️ Suggested approach

```python
@pytest.fixture(autouse=True)
def set_triton_f32_default():
    old_value = os.environ.get('TRITON_F32_DEFAULT')
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
    yield
    if old_value is None:
        del os.environ['TRITON_F32_DEFAULT']
    else:
        os.environ['TRITON_F32_DEFAULT'] = old_value
```

fla/ops/gated_oja_rule/chunk_h.py (2)
480-484: Unused `chunk_indices` computation. `chunk_indices` is computed but never passed to the kernel. Consider removing this unnecessary computation.

Proposed fix

```diff
-    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
     if cu_seqlens is None:
         N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
     else:
-        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
+        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
+        N = len(cu_seqlens) - 1
+        NT = chunk_offsets[-1].item()  # or compute directly
```
386-386: Use ASCII commas in comments for consistency. The comment contains fullwidth commas (，), which trigger linter warnings. Consider using standard ASCII commas or translating the comments to English.
fla/ops/gated_oja_rule/chunk_o.py (1)
453-461: Redundant computation of attention matrix `A`. The attention matrix `A` is computed identically in each `i_k` block and then summed (line 540), which is wasteful. Since `A = dot(q*scale, k.T)` is the same regardless of which K-block is being processed, this results in `NK` redundant computations. Consider computing `A` once in a separate kernel or only in the first K-block.
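To see why this is pure overhead, here is a small self-contained PyTorch illustration of the claim above (arbitrary sizes, not the kernel code): under the stated assumption that each K-block program forms the full causal score matrix, every program would produce an identical copy of the same result.

```python
import torch

# Splitting K into NK blocks and recomputing the full causal q @ k^T per block
# does NK times the work of computing it once, since every copy is identical.
T, K, NK = 64, 128, 2
q, k = torch.randn(T, K), torch.randn(T, K)

A_once = torch.tril(q @ k.T)                             # computed a single time
A_per_block = [torch.tril(q @ k.T) for _ in range(NK)]   # identical copy per K-block

assert all(torch.equal(A_once, a) for a in A_per_block)
```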
```python
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
    b_h1 *= exp(b_gk_last1)[None, :]
    if K > 64:
        o_v2 = 64 + o_v1
        b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
        b_h2 *= exp(b_gk_last2)[None, :]
    if K > 128:
        o_v3 = 128 + o_v1
        b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
        b_h3 *= exp(b_gk_last3)[None, :]
    if K > 192:
        o_v4 = 192 + o_v1
        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
        b_h4 *= exp(b_gk_last4)[None, :]
```
Critical bug: Incorrect dimension check for gv scaling.
The conditionals check K (key dimension) but should check V (value dimension). The gv tensor has shape [..., V], not [..., K]. This causes incorrect scaling when K != V.
Proposed fix
```diff
 if USE_GV:
     o_v1 = tl.arange(0, 64)
     b_gk_last1 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v1, mask=(o_v1 < V), other=0.)
     b_h1 *= exp(b_gk_last1)[None, :]
-    if K > 64:
+    if V > 64:
         o_v2 = 64 + o_v1
         b_gk_last2 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v2, mask=(o_v2 < V), other=0.)
         b_h2 *= exp(b_gk_last2)[None, :]
-    if K > 128:
+    if V > 128:
         o_v3 = 128 + o_v1
         b_gk_last3 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v3, mask=(o_v3 < V), other=0.)
         b_h3 *= exp(b_gk_last3)[None, :]
-    if K > 192:
+    if V > 192:
         o_v4 = 192 + o_v1
-        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < K), other=0.)
+        b_gk_last4 = tl.load(gv + (bos + last_idx) * H*V + i_h * V + o_v4, mask=(o_v4 < V), other=0.)
         b_h4 *= exp(b_gk_last4)[None, :]
```

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 148 - 163, The conditional
blocks that apply gv scaling incorrectly compare against K instead of V; change
all occurrences of "if K > 64/128/192" to "if V > 64/128/192" and update the
final load mask from "(o_v4 < K)" to "(o_v4 < V)" so the gv loads and masks
(e.g., in the blocks computing b_gk_last1..b_gk_last4 and multiplying
b_h1..b_h4) correctly use the value-dimension V rather than the key-dimension K.
fla/ops/gated_oja_rule/chunk_h.py
Outdated
```python
p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
b_gv = tl.load(p_gv, boundary_check=(0, 1))
if USE_GV:
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1[None, :])
    b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
```
Potential null pointer access when USE_GV is False.
The code loads from gv (lines 396-397) unconditionally within the if V > 0 block, but gv is only offset at line 319 when USE_GV is True. When USE_GV is False and gv is None, this will cause a runtime error.
Proposed fix: Guard gv loads with USE_GV check
```diff
 if V > 0:
     p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
     b_do = tl.load(p_do, boundary_check=(0, 1))
     p_w = tl.make_block_ptr(w, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
     b_w = tl.load(p_w, boundary_check=(0, 1))
-    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
-    b_gv = tl.load(p_gv, boundary_check=(0, 1))
     if USE_GV:
+        p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
+        b_gv = tl.load(p_gv, boundary_check=(0, 1))
         o_v1 = tl.arange(0, 64)
         b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
         b_dh1 *= exp(b_gv_last1[None, :])
         b_do *= exp(b_gv)
     b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
```

Apply similar changes for the if V > 64, if V > 128, and if V > 192 blocks.
📝 Committable suggestion
```python
if USE_GV:
    p_gv = tl.make_block_ptr(gv, (T, V), (stride_v, 1), (i_t * BT, 0), (BT, 64), (1, 0))  # [BT, BV]
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    o_v1 = tl.arange(0, 64)
    b_gv_last1 = tl.load(gv + last_idx * H*V + o_v1, mask=(o_v1 < V), other=0.)
    b_dh1 *= exp(b_gv_last1[None, :])
    b_do *= exp(b_gv)
b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(tl.trans(b_dk).to(b_w.dtype), b_w)  # [BK, BT] @ [BT, BV] - [BK, BT] @ [BT, BV]
```
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 396 - 403, The code
unconditionally creates and loads p_gv and b_gv (using gv, p_gv, b_gv) inside
the V>0 handling even when gv may be None if USE_GV is False; wrap the creation
of p_gv and any tl.load(gv + ...) or tl.load(p_gv, ...) calls with a guard on
USE_GV (same pattern used elsewhere when gv is offset at line 319) so that all
accesses to gv happen only when USE_GV is True, and apply the same guard pattern
to the other V>64, V>128, and V>192 blocks to prevent null pointer/runtime loads
when gv is not provided.
fla/ops/gated_oja_rule/chunk_h.py
Outdated
```python
def chunk_gsa_bwd_k_kernel_dqkvg(
    q,
    k,
    v,
    h,
    g,
    A,
    do,
    dh,
    dq,
    dk,
    dv,
    dg,
    dgv,
    dA,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    NG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_hq = i_bh // HQ, i_bh % HQ
    i_h = i_hq // NG
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        all = T
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
        all = B * T

    o_i = tl.arange(0, BT)
    o_t = min(i_t * BT + BT, T)
    m_s = o_i[:, None] >= o_i[None, :]

    p_q = tl.make_block_ptr(q + (bos*HQ+i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_k = tl.make_block_ptr(k + (bos*H+i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_A = tl.make_block_ptr(A + ((i_k*all+bos)*HQ+i_hq)*BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

    # [BT, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k))
    b_A = tl.where(m_s, b_A, 0.)
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        o_v = i_v * BV + tl.arange(0, BV)
        p_v = tl.make_block_ptr(v + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_g = tl.make_block_ptr(g + (bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_gn = g + (bos + o_t - 1) * H*V + i_h * V + o_v
        p_do = tl.make_block_ptr(do + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dv = tl.make_block_ptr(dv + ((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dg = tl.make_block_ptr(dg + (bos*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_dgv = tl.make_block_ptr(dgv+((i_k*all+bos)*HQ+i_hq)*V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
        p_dh = tl.make_block_ptr(dh + (i_tg * HQ + i_hq) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        m_v = o_v < V

        # [BV,]
        b_gn = tl.load(p_gn, mask=m_v, other=0)
        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_gv = exp(b_gn[None, :] - b_g)
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        b_do = (b_do * exp(b_g) * scale).to(b_do.dtype)
        # [BK, BV]
        b_dh = tl.load(p_dh, boundary_check=(0, 1))
        # [BV]
        b_dg = tl.sum(tl.trans(b_h) * b_dh, 0) * exp(b_gn)

        b_dh = b_dh.to(b_k.dtype)
        # [BT, BK]
        b_dq += tl.dot(b_do, b_h.to(b_k.dtype))
        b_dk += tl.dot((b_v * b_gv).to(b_v.dtype), tl.trans(b_dh))
        # [BT, BV]
        b_dv = tl.dot(b_k, b_dh) * b_gv
        # [BV]
        b_dg += tl.sum(b_dv * b_v, 0)

        if i_k == 0:
            b_dgv = tl.load(p_dg, boundary_check=(0, 1)) + b_dg[None, :]
        else:
            b_dgv = tl.zeros([BT, BV], dtype=tl.float32) + b_dg[None, :]

        tl.store(p_dgv, b_dgv.to(p_dgv.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
    p_dA = tl.make_block_ptr(dA + (bos*HQ + i_hq) * BT, (T, BT), (HQ*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    p_dq = tl.make_block_ptr(dq + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    p_dk = tl.make_block_ptr(dk + (bos*HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    # [BT, BT]
    b_dA = tl.load(p_dA, boundary_check=(0, 1))
    # [BT, BK]
    b_dq += tl.dot(b_dA, b_k)
    b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q)

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
```
🧩 Analysis chain
🏁 Script executed:
rg -n "chunk_gsa_bwd_k_kernel_dqkvg" --type pyRepository: fla-org/flash-linear-attention
Length of output: 269
🏁 Script executed:
# Check the kernel definition in both files
echo "=== gated_oja_rule/chunk_h.py definition ==="
sed -n '531,533p' fla/ops/gated_oja_rule/chunk_h.py
echo -e "\n=== gsa/chunk.py definition ==="
sed -n '297,299p' fla/ops/gsa/chunk.py
# Get some context about imports and structure in gated_oja_rule/chunk_h.py
echo -e "\n=== gated_oja_rule/chunk_h.py file overview ==="
head -30 fla/ops/gated_oja_rule/chunk_h.py
# Check if the kernel is imported from gsa or defined locally
echo -e "\n=== Checking for imports of chunk_gsa_bwd_k_kernel_dqkvg in gated_oja_rule ==="
rg "import.*chunk_gsa_bwd_k_kernel_dqkvg|from.*chunk_gsa" fla/ops/gated_oja_rule/Repository: fla-org/flash-linear-attention
Length of output: 1349
🏁 Script executed:
# Check if chunk_gsa_bwd_k_kernel_dqkvg is called anywhere in gated_oja_rule/chunk_h.py
echo "=== Searching for calls to chunk_gsa_bwd_k_kernel_dqkvg in gated_oja_rule ==="
rg "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.py
# Check what's defined in gated_oja_rule/chunk_h.py (list all function/kernel definitions)
echo -e "\n=== All kernel definitions in gated_oja_rule/chunk_h.py ==="
rg "^def chunk.*kernel" fla/ops/gated_oja_rule/chunk_h.py
# See the whole file length
echo -e "\n=== File length ==="
wc -l fla/ops/gated_oja_rule/chunk_h.py
# Check what's exported or used from gated_oja_rule
echo -e "\n=== Functions called/exported in gated_oja_rule ==="
rg "^def [a-z_]+\(" fla/ops/gated_oja_rule/chunk_h.py | head -20Repository: fla-org/flash-linear-attention
Length of output: 608
🏁 Script executed:
# Compare the function definition at the end to see if there's a wrapper function
echo "=== Last 30 lines of gated_oja_rule/chunk_h.py ==="
tail -30 fla/ops/gated_oja_rule/chunk_h.py
# Check if there's any Python function wrapper that calls this kernel
echo -e "\n=== Checking for wrapper functions that might call the kernel ==="
rg -A 5 "def chunk_gsa_bwd_k" fla/ops/gated_oja_rule/chunk_h.py | head -40
Repository: fla-org/flash-linear-attention
Length of output: 950
🏁 Script executed:
# Get line count of the kernel definition to understand its size
echo "=== Getting the kernel definition size ==="
awk '/^def chunk_gsa_bwd_k_kernel_dqkvg/,/^def / {print NR": "$0}' fla/ops/gated_oja_rule/chunk_h.py | tail -5
# Compare the definitions more carefully - get line ranges
echo -e "\n=== Checking what comes after the kernel definition ==="
sed -n '650,655p' fla/ops/gated_oja_rule/chunk_h.py
# Check if there are any public wrapper functions that use this kernel
echo -e "\n=== Public functions that might export or use the kernel ==="
rg "^def chunk_gsa_bwd" fla/ops/gated_oja_rule/chunk_h.pyRepository: fla-org/flash-linear-attention
Length of output: 403
🏁 Script executed:
# Let me check the exports or __all__ to see if it's intentionally exposed
echo "=== Check __all__ or public exports in gated_oja_rule/chunk_h.py ==="
rg "__all__|^from|^import" fla/ops/gated_oja_rule/chunk_h.py | head -20
# Check if the gated_oja_rule module imports from gsa
echo -e "\n=== Check if gated_oja_rule imports the kernel from gsa ==="
rg "from.*gsa.*import|import.*gsa" fla/ops/gated_oja_rule/
# Verify the exact line where the kernel ends in gated_oja_rule
echo -e "\n=== Lines 645-655 to see what comes after kernel ==="
sed -n '645,655p' fla/ops/gated_oja_rule/chunk_h.py
Repository: fla-org/flash-linear-attention
Length of output: 815
🏁 Script executed:
# Check the wrapper function pattern in gsa/chunk.py around the kernel call
echo "=== Context around the kernel call in gsa/chunk.py (line 739) ==="
sed -n '730,750p' fla/ops/gsa/chunk.py
# Check if there's a public wrapper function in gated_oja_rule that might call chunk_gsa_bwd_k_kernel_dqkvg
echo -e "\n=== All public functions (non-kernel) in gated_oja_rule/chunk_h.py ==="
rg "^def [a-z_]+\(" fla/ops/gated_oja_rule/chunk_h.py
# Check if any of these functions contain the kernel grid call
echo -e "\n=== Check if chunk_gsa_bwd_k_kernel_dqkvg is called via grid ==="
rg "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.pyRepository: fla-org/flash-linear-attention
Length of output: 848
🏁 Script executed:
# Double check: search for any function that might be using the kernel with array indexing
echo "=== Final verification: any use of chunk_gsa_bwd_k_kernel_dqkvg[ in gated_oja_rule ==="
grep -n "chunk_gsa_bwd_k_kernel_dqkvg\[" fla/ops/gated_oja_rule/chunk_h.py || echo "Not found"
# Check module structure - is this module even exposed?
echo -e "\n=== Check __init__.py to see what's exported from gated_oja_rule ==="
find fla/ops/gated_oja_rule -name "__init__.py" -exec cat {} \;
Repository: fla-org/flash-linear-attention
Length of output: 420
Remove chunk_gsa_bwd_k_kernel_dqkvg as it is unused dead code in this module.
This kernel is defined in fla/ops/gsa/chunk.py and actively used there, but the definition in fla/ops/gated_oja_rule/chunk_h.py is never called. It should be removed to avoid maintaining duplicate code and confusion about the module's purpose (OJA operations).
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 531 - 651, The function
chunk_gsa_bwd_k_kernel_dqkvg defined in this file is dead/duplicated and should
be removed: delete the entire chunk_gsa_bwd_k_kernel_dqkvg(...) definition from
fla/ops/gated_oja_rule/chunk_h.py so the codebase uses the single implementation
in fla/ops/gsa/chunk.py; after removal, run tests and search for any local
references to chunk_gsa_bwd_k_kernel_dqkvg to ensure no callers depend on this
definition and update imports/call sites to reference the gsa implementation if
needed.
```python
if USE_GV:
    b_dv = b_dvg * exp(b_gn[None, :] - b_gv)

p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dgv_last += tl.sum(b_dv * b_v, axis=0)

# interface reserved for GSA2
if HAVE_GK:
    dgk += (bos * H + i_h) * V
    p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
    b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
    b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]

tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
```
Critical bug: b_dv undefined when USE_GV is False.
When USE_GV is False, b_dv is never assigned (line 748 is inside if USE_GV), but it's used unconditionally at lines 756 and 767. This will cause a runtime error.
Proposed fix
```diff
 if USE_GV:
     b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
-
+ else:
+     b_dv = b_dvg
+
 p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
```

📝 Committable suggestion
```python
if USE_GV:
    b_dv = b_dvg * exp(b_gn[None, :] - b_gv)
else:
    b_dv = b_dvg

p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dw = tl.make_block_ptr(dw, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dgv_last = tl.make_block_ptr(dgv_last, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_v = tl.load(p_v, boundary_check=(0, 1))

b_dgv_last += tl.sum(b_dv * b_v, axis=0)

# interface reserved for GSA2
if HAVE_GK:
    dgk += (bos * H + i_h) * V
    p_dgk = tl.make_block_ptr(dgk, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_dgk = tl.load(p_dgk, boundary_check=(0, 1))
    b_dgv_last = b_dgk + b_dgv_last[None, :]
else:
    b_dgv_last = tl.zeros([BT, BV], dtype=tl.float32) + b_dgv_last[None, :]

tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
```
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 747 - 767, The code uses b_dv
unconditionally but only defines it inside the if USE_GV branch; to fix, ensure
b_dv is always assigned: when USE_GV is True compute b_dv = b_dvg *
exp(b_gn[None, :] - b_gv) as before, otherwise initialize b_dv to a zero tensor
with the same shape and dtype used later (shape [BT, BV] matching b_v and
p_dv.element_ty) so subsequent operations (b_dgv_last update, tl.store(p_dv,
...), and interaction with b_v) work correctly; update the block so b_dv, b_dvg,
b_gn, b_gv, p_dv, b_v, and b_dgv_last remain the referenced symbols.
fla/ops/gated_oja_rule/chunk_o.py
Outdated
```python
from fla.ops.utils.op import exp
from fla.utils import check_shared_mem, is_nvidia_hopper
from fla.ops.utils.cumsum import chunk_local_cumsum

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

exp = tl.exp
```
Remove redundant import and duplicate exp definition.
exp is imported from fla.ops.utils.op at line 8 but immediately overwritten with tl.exp at line 15. Also, chunk_local_cumsum is imported but never used. Remove the unused imports.
Proposed fix
```diff
 from fla.ops.utils import prepare_chunk_indices
-from fla.ops.utils.op import exp
 from fla.utils import check_shared_mem, is_nvidia_hopper
-from fla.ops.utils.cumsum import chunk_local_cumsum

 BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
 NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]

 exp = tl.exp
```

🧰 Tools
🪛 Flake8 (7.3.0)
[error] 10-10: 'fla.ops.utils.cumsum.chunk_local_cumsum' imported but unused
(F401)
[error] 15-15: redefinition of unused 'exp' from line 8
(F811)
🪛 GitHub Actions: lint
[error] 13-13: Ruff: F811 Redefinition of unused 'exp' from line 7.
🪛 Ruff (0.14.14)
15-15: Redefinition of unused exp from line 8: exp redefined here
(F811)
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 8 - 15, Remove the redundant
and unused imports: delete the import of exp from fla.ops.utils.op and the
import of chunk_local_cumsum from fla.ops.utils.cumsum, keeping the intended
tl.exp assignment (exp = tl.exp) as the single definition of exp; ensure no
other code depends on fla.ops.utils.op.exp or chunk_local_cumsum in this file
(references to exp should use the tl-backed exp symbol).
fla/ops/gated_oja_rule/chunk.py
Outdated
```python
# # -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
```
Fix comment syntax.
Line 1 has doubled comment markers # # which appears to be a typo.
🐛 Proposed fix
```diff
-# # -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
```

📝 Committable suggestion
```python
# -*- coding: utf-8 -*-
# # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
```
🧰 Tools
🪛 GitHub Actions: lint
[error] 1-1: Trailing whitespace detected by pre-commit; file was modified.
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 1 - 2, The file header lines
have duplicated comment markers ("# #"), so remove the extra '#' characters in
those header comments: change the leading "# # -*- coding: utf-8 -*-" to "# -*-
coding: utf-8 -*-" and similarly change "# # Copyright (c) 2023-2025, Songlin
Yang, Yu Zhang" to "# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang" to
restore proper comment syntax.
```python
if USE_GV:
    b_gv = tl.load(p_gv).to(tl.float32)
    b_h *= exp(b_gv[None, :])
```
Add mask to gv load for partial blocks.
When V is not a multiple of BV, the load at line 96 should use mask_v to avoid reading out-of-bounds memory.
🐛 Proposed fix
```diff
 if USE_GV:
-    b_gv = tl.load(p_gv).to(tl.float32)
+    b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
     b_h *= exp(b_gv[None, :])
```

📝 Committable suggestion
```python
if USE_GV:
    b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32)
    b_h *= exp(b_gv[None, :])
```
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/fused_recurrent.py` around lines 95 - 97, The load of
gv must apply the mask for partial vector blocks to avoid OOB reads: when USE_GV
is true, change the load of p_gv (symbol b_gv) to use mask_v (the mask for the
last V block) instead of an unconditional tl.load; keep the subsequent scaling
of b_h (symbol b_h *= exp(b_gv[None, :])) the same so that only valid lanes are
loaded and used when V % BV != 0.
```python
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    gv: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    w = torch.empty_like(v)
    u = torch.empty_like(k)
    vg = torch.empty_like(v) if gv is not None else None
    recompute_w_u_fwd_kernel[(NT, B*H)](
        k=k,
        v=v,
        vg=vg,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gv=gv,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u, vg
```
Return type annotation is incorrect.
The function returns three values (w, u, vg) at line 237, but the type hint at line 206 declares Tuple[torch.Tensor, torch.Tensor]. Update to reflect the actual return type.
🐛 Proposed fix
```diff
 def recompute_w_u_fwd(
     k: torch.Tensor,
     v: torch.Tensor,
     beta: torch.Tensor,
     A: torch.Tensor,
     gv: Optional[torch.Tensor] = None,
     cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
```

🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 199 - 237, Update the return
type annotation of recompute_w_u_fwd to match the actual returned values (w, u,
vg): change the declared return from Tuple[torch.Tensor, torch.Tensor] to
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] and import Optional if
not already present; ensure the function signature and any callers/types align
with the new signature for recompute_w_u_fwd.
```python
    gv: torch.Tensor = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = 64
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v, dtype=torch.float)

    dgv = torch.empty_like(gv, dtype=torch.float)
```
Add type hint and guard for gv parameter.
The gv parameter at line 247 lacks a proper type hint (defaults to None but no Optional annotation). Additionally, dgv is unconditionally allocated at line 261 using torch.empty_like(gv), which will fail if gv is None.
🐛 Proposed fix
def prepare_wy_repr_bwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
-    gv: torch.Tensor = None,
+    gv: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    if gv is None:
+        raise ValueError("gv is required for prepare_wy_repr_bwd")
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 247 - 261, The gv parameter
is declared without Optional typing and the code unconditionally allocates dgv
with torch.empty_like(gv), which will crash if gv is None; update the function
signature to annotate gv as Optional[torch.Tensor] and change the local dgv to
be Optional[torch.Tensor] (or torch.Tensor | None) and only allocate dgv when gv
is not None (e.g., after checking gv) — leave dgv as None otherwise; ensure any
later uses of dgv handle the None case or assert/raise if those code paths
require gv to be present.
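If the maintainers would rather keep `gv` genuinely optional than raise, the allocation itself can be guarded instead; a minimal sketch, assuming the downstream backward code tolerates `dgv` being `None`:

```python
# hypothetical alternative: allocate dgv only when a gate tensor was provided
dgv = torch.empty_like(gv, dtype=torch.float) if gv is not None else None
```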
Cool work @AwesomeSeq! Have you trained any 340M/1.3B models with this recurrence yet? If so, is there a paper I can refer to?
Actionable comments posted: 5
🤖 Fix all issues with AI agents
In `@fla/ops/gated_oja_rule/chunk_h.py`:
- Around line 194-238: The return type annotation of chunk_oja_fwd_h is wrong:
the function actually returns three values (h, k_new, final_state) where k_new
and final_state can be None depending on save_new_key and output_final_state;
update the function signature's return type to reflect three elements (e.g.
tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]) and ensure any
callers or tests expecting the old two-tuple are adjusted accordingly; reference
the chunk_oja_fwd_h definition and the variables h, k_new, final_state in your
change.
In `@fla/ops/gated_oja_rule/chunk_o.py`:
- Around line 178-241: The function chunk_oja_fwd_o has a return type mismatch:
its annotation declares four tensors but the implementation returns only A and
o; update the function signature's return annotation to match the actual return
(tuple[torch.Tensor, torch.Tensor]) or modify the body to return the additional
tensors if intended; locate chunk_oja_fwd_o and change the annotated return type
to reflect only A and o (or add the missing tensors to the return) and ensure
callers expect the corrected shape.
In `@fla/ops/gated_oja_rule/chunk.py`:
- Around line 287-297: The error messages use adjacent f-strings that get
implicitly concatenated without a separating space; update the messages in the
cu_seqlens check (where variables q, cu_seqlens, and initial_state are
referenced) to ensure proper spacing — either merge the two f-strings into one
or insert an explicit leading/trailing space or punctuation between them so the
resulting strings read correctly (do the same fix in the analogous checks in
fused_recurrent.py around the initial_state/cu_seqlens validation).
In `@fla/ops/gated_oja_rule/wy_fast.py`:
- Around line 61-79: The kernel unconditionally loads from gv (e.g.,
tl.load(p_gv) and tl.load(gv + ...)) which will crash if gv is None; update the
kernels (recompute_w_u_fwd_kernel and prepare_wy_repr_bwd_kernel) to either
(preferred) add a compile-time/use-time guard like a boolean USE_GV and wrap all
gv loads and STORE_VG-dependent logic (the p_gv/tl.load uses and computing
b_vb/b_vg) behind if USE_GV so the code never dereferences gv when absent, or
alternatively make gv a required parameter in the Python wrappers so callers
cannot pass None; ensure referenced symbols include gv, p_gv, b_gv, b_gn,
STORE_VG, and vg when applying the guard so no tl.load or tl.store touches gv/
vg unless USE_GV is true.
In `@tests/ops/test_oja.py`:
- Around line 404-407: Fix two issues: update the skip reason text and avoid
mutating global env. Change the pytest.skip call (condition using
is_intel_alchemist and D) to use the correct message 'chunk_gated_oja_rule'
instead of 'chunk_gated_delta_rule'; and replace the direct
os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
🧹 Nitpick comments (2)
fla/ops/gated_oja_rule/fused_recurrent.py (1)
117-168: Use explicit `T | None` for optional parameters. Several parameters use implicit `Optional` (PEP 484 violation): `scale` at line 123, `initial_state` at line 124. This also applies to `FusedRecurrentFunction.forward` (line 182) and `fused_recurrent_gated_oja_rule` (lines 221-222).
Proposed fix (for the wrapper)
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: float | None = None,
+    initial_state: torch.Tensor | None = None,
tests/ops/test_oja.py (1)
337-337: Remove leftover debug. Line 337 contains a `print(...)` that shouldn't be in committed test code. Also, there are leftover `# breakpoint()` comments at lines 37 and 369.
Proposed fix
- print('================== Running forward and backward ==================')
def chunk_oja_fwd_h(
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    gv: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_key: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, V, K = *v.shape, u.shape[-1]
    BT = chunk_size

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
    assert V <= 256, "current kernel does not support head dimension larger than 256."

    h = v.new_empty(B, NT, H, K, V)
    final_state = v.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    k_new = torch.empty_like(u) if save_new_key else None
    def grid(meta): return (triton.cdiv(K, meta['BK']), N*H)
    chunk_oja_fwd_kernel_h_blockdim64[grid](
        v=v,
        u=u,
        w=w,
        k_new=k_new,
        gv=gv,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT
    )
    return h, k_new, final_state
Return type annotation is incorrect on chunk_oja_fwd_h.
The function returns three values (h, k_new, final_state) at line 238, but the type hint at line 204 declares tuple[torch.Tensor, torch.Tensor].
🐛 Proposed fix
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_h.py` around lines 194 - 238, The return type
annotation of chunk_oja_fwd_h is wrong: the function actually returns three
values (h, k_new, final_state) where k_new and final_state can be None depending
on save_new_key and output_final_state; update the function signature's return
type to reflect three elements (e.g. tuple[torch.Tensor, torch.Tensor | None,
torch.Tensor | None]) and ensure any callers or tests expecting the old
two-tuple are adjusted accordingly; reference the chunk_oja_fwd_h definition and
the variables h, k_new, final_state in your change.
def chunk_oja_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gv: torch.Tensor,
    h: torch.Tensor,
    scale: float = 1.,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    BC = min(16, BT)
    BV = min(64, triton.next_power_of_2(V))
    HQ = q.shape[2]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    NC = triton.cdiv(BT, BC)
    NG = HQ // H

    o = v.new_empty(B, T, HQ, V)
    A = q.new_empty(B, T, HQ, BT)
    def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * HQ)
    chunk_oja_fwd_inter[grid](
        q,
        k,
        h,
        gv,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        HQ=HQ,
        H=H,
        K=K,
        V=V,
        BT=BT,
        NG=NG,
    )

    def grid(meta): return (triton.cdiv(V, meta['BV']), NT * NC, B * HQ)
    chunk_oja_fwd_intra[grid](
        v,
        gv,
        o,
        A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        HQ=HQ,
        H=H,
        V=V,
        BT=BT,
        BC=BC,
        BV=BV,
        NC=NC,
        NG=NG,
        num_warps=4,
        num_stages=2
    )
    return A, o
Return type annotation declares 4 tensors but only 2 are returned.
Line 187 declares -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] but line 241 returns A, o (2 tensors).
🐛 Proposed fix
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk_o.py` around lines 178 - 241, The function
chunk_oja_fwd_o has a return type mismatch: its annotation declares four tensors
but the implementation returns only A and o; update the function signature's
return annotation to match the actual return (tuple[torch.Tensor, torch.Tensor])
or modify the body to return the additional tensors if intended; locate
chunk_oja_fwd_o and change the annotated return type to reflect only A and o (or
add the missing tensors to the return) and ensure callers expect the corrected
shape.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
Missing space between concatenated f-strings in error messages.
The two adjacent f-strings at lines 290–291 and 295–296 are implicitly concatenated without a separator, producing messages like "...cu_seqlens.Please flatten...". The same issue exists in `fused_recurrent.py` lines 237–239.
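For reference, Python joins adjacent string literals (including f-strings) with no separator at all, which is exactly how the space goes missing; a tiny standalone illustration:

```python
who = "cu_seqlens"
msg = (
    f"The batch size is expected to be 1 when using `{who}`."
    f"Please flatten variable-length inputs before processing."
)
print(msg)  # ...using `cu_seqlens`.Please flatten...  <- no space between the sentences
```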
🐛 Proposed fix
if q.shape[0] != 1:
raise ValueError(
- f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
- f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
🧰 Tools
🪛 Ruff (0.14.14)
[warning] 289-292: Avoid specifying long messages outside the exception class (TRY003)
[warning] 294-297: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/chunk.py` around lines 287 - 297, The error messages
use adjacent f-strings that get implicitly concatenated without a separating
space; update the messages in the cu_seqlens check (where variables q,
cu_seqlens, and initial_state are referenced) to ensure proper spacing — either
merge the two f-strings into one or insert an explicit leading/trailing space or
punctuation between them so the resulting strings read correctly (do the same
fix in the analogous checks in fused_recurrent.py around the
initial_state/cu_seqlens validation).
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        p_w = tl.make_block_ptr(w + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = b_v * b_b[:, None]

        p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_gv = tl.load(p_gv, boundary_check=(0, 1))
        b_vb *= exp(b_gv)
        if STORE_VG:
            last_idx = min(i_t * BT + BT, T) - 1

            o_v = i_v * BV + tl.arange(0, BV)
            m_v = o_v < V
            b_gn = tl.load(gv + ((bos + last_idx) * H + i_h) * V + o_v, mask=m_v, other=0.)
            b_vg = b_v * exp(b_gn - b_gv)

            p_vg = tl.make_block_ptr(vg + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_vg, b_vg.to(p_vg.dtype.element_ty), boundary_check=(0, 1))
gv is loaded unconditionally in the kernel — will crash if gv is None.
Both recompute_w_u_fwd_kernel (line 67) and prepare_wy_repr_bwd_kernel (line 152) load from gv without guarding on whether gv is actually provided. Although callers in chunk.py always pass a valid gv, the Python wrapper signatures allow gv=None. Either add a USE_GV heuristic guard in the kernels or make gv a required parameter in the wrappers to prevent a latent null-pointer crash.
🤖 Prompt for AI Agents
In `@fla/ops/gated_oja_rule/wy_fast.py` around lines 61 - 79, The kernel
unconditionally loads from gv (e.g., tl.load(p_gv) and tl.load(gv + ...)) which
will crash if gv is None; update the kernels (recompute_w_u_fwd_kernel and
prepare_wy_repr_bwd_kernel) to either (preferred) add a compile-time/use-time
guard like a boolean USE_GV and wrap all gv loads and STORE_VG-dependent logic
(the p_gv/tl.load uses and computing b_vb/b_vg) behind if USE_GV so the code
never dereferences gv when absent, or alternatively make gv a required parameter
in the Python wrappers so callers cannot pass None; ensure referenced symbols
include gv, p_gv, b_gv, b_gn, STORE_VG, and vg when applying the guard so no
tl.load or tl.store touches gv/vg unless USE_GV is true.
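A minimal sketch of the preferred option, assuming the wrapper passes a `USE_GV` constexpr flag to the launch; the names mirror the kernel symbols quoted above, but this is illustrative rather than the exact kernel code:

```python
# wrapper side (hypothetical): recompute_w_u_fwd_kernel[(NT, B*H)](..., USE_GV=gv is not None)
# kernel side (hypothetical): USE_GV is a tl.constexpr, so the dead branch is compiled away
if USE_GV:
    p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_gv = tl.load(p_gv, boundary_check=(0, 1))
    b_vb *= exp(b_gv)
    if STORE_VG:
        ...  # gate-scaled vg path, unchanged from the existing kernel
# when USE_GV is False, gv and vg are never dereferenced
```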
    if is_intel_alchemist and D > 128:
        pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
    torch.manual_seed(42)
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'
Two issues: incorrect skip message and global env var side-effect.
- Line 405: The skip reason says `chunk_gated_delta_rule` but should say `chunk_gated_oja_rule`.
- Line 407: Setting `os.environ['TRITON_F32_DEFAULT'] = 'ieee'` modifies global process state, which can leak into other tests. Consider scoping this with `monkeypatch` or restoring the original value in a fixture.
Proposed fix for the skip message
- pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
+ pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')
🤖 Prompt for AI Agents
In `@tests/ops/test_oja.py` around lines 404 - 407, Fix two issues: update the
skip reason text and avoid mutating global env. Change the pytest.skip call
(condition using is_intel_alchemist and D) to use the correct message
'chunk_gated_oja_rule' instead of 'chunk_gated_delta_rule'; and replace the
direct os.environ['TRITON_F32_DEFAULT'] = 'ieee' side-effect with a test-scoped
environment change (use pytest's monkeypatch to setenv or save and restore the
original value around the test) so TRITON_F32_DEFAULT is not left modified for
other tests.
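A sketch of the test-scoped variant using pytest's built-in `monkeypatch` fixture; the test name and remaining parameters below are placeholders, not the actual signature in tests/ops/test_oja.py:

```python
# hypothetical: request the monkeypatch fixture instead of writing to os.environ directly
def test_chunk_gated_oja_rule(monkeypatch, B, T, H, D, scale, dtype):
    if is_intel_alchemist and D > 128:
        pytest.skip(reason='chunk_gated_oja_rule is not supported on alchemist for D>128')
    monkeypatch.setenv('TRITON_F32_DEFAULT', 'ieee')  # restored automatically when the test ends
    torch.manual_seed(42)
    ...
```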
from hujiaxi@moonshot.cn
Summary by CodeRabbit
New Features
Tests