68 changes: 68 additions & 0 deletions benchmarks/ops/benchmark_kda_intra.py
@@ -0,0 +1,68 @@

import torch
import triton
from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra

def benchmark_intra_chunk(B=8, T=4096, H=16, K=128, chunk_size=64):
    dtype = torch.bfloat16
    device = 'cuda'

    q = torch.randn(B, T, H, K, device=device, dtype=dtype)
    k = torch.randn(B, T, H, K, device=device, dtype=dtype)
    g = torch.randn(B, T, H, K, device=device, dtype=torch.float32)
    beta = torch.randn(B, T, H, device=device, dtype=dtype)

    scale = 1.0

    quantiles = [0.5, 0.2, 0.8]

    # Warmup
    for _ in range(10):
        chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="token")
        chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recursive")
        chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent")

    ms_token = triton.testing.do_bench(
        lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="token"),
        quantiles=quantiles
    )

    ms_recursive = triton.testing.do_bench(
        lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recursive"),
        quantiles=quantiles
    )

    try:
        ms_recurrent = triton.testing.do_bench(
            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
            quantiles=quantiles
        )
        t_recurrent = ms_recurrent[0]
    except Exception as e:
        t_recurrent = float('nan')

Comment on lines +35 to +43
Contributor

⚠️ Potential issue | 🟠 Major

Fix unused exception variable and consider narrowing the catch

You currently catch a broad Exception and bind it to e without using it, which triggers F841/BLE001 and fails lint.

A minimal fix that keeps the behavior is:

-    try:
-        ms_recurrent = triton.testing.do_bench(
-            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
-            quantiles=quantiles
-        )
-        t_recurrent = ms_recurrent[0]
-    except Exception as e:
-        t_recurrent = float('nan')
+    try:
+        ms_recurrent = triton.testing.do_bench(
+            lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"),
+            quantiles=quantiles,
+        )
+        t_recurrent = ms_recurrent[0]
+    except Exception:
+        # Recurrent path not supported/stable for this config; report NaN.
+        t_recurrent = float("nan")

This removes the unused variable and documents why the broad catch is acceptable in a benchmark context.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 41-41: local variable 'e' is assigned to but never used

(F841)

🪛 Ruff (0.14.5)

41-41: Do not catch blind exception: Exception

(BLE001)


41-41: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)

🤖 Prompt for AI Agents
In benchmarks/ops/benchmark_kda_intra.py around lines 35 to 43, the except block
currently uses "except Exception as e" but never uses "e", causing an
unused-variable lint error and implicitly catching all exceptions; change it to
"except Exception:" (or better, catch specific expected exceptions if known) and
add a short comment explaining why a broad catch is acceptable in this
benchmarking context so the intent is clear to reviewers and linters.

    # Format for table row
    # Shape | Token | Recursive | Recurrent | Rec vs Token
    row_str = f"B={B}, T={T}, H={H}, K={K}"
    print(f"{row_str:<30} | {ms_token[0]:.3f} ms | {ms_recursive[0]:.3f} ms | {t_recurrent:.3f} ms | {ms_token[0]/ms_recursive[0]:.2f}x ")

if __name__ == "__main__":
    configs = [
        (8, 4096, 16, 128),
        (1, 8192, 16, 128),
        (8, 4096, 32, 64),
        (1, 8192, 32, 64),
        # Large Batch
        (32, 512, 12, 64),
        # High Head Dim
        (2, 4096, 8, 256),
    ]

    print(f"{'Shape':<30} | {'Token (Original)':<20} | {'Recursive (New)':<20} | {'Recurrent':<15} | {'Speedup (Rec/Tok)':<15}")
    print("-" * 110)

    for B, T, H, K in configs:
        try:
            benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64)
        except Exception as e:
            print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {e}")
33 changes: 27 additions & 6 deletions fla/layers/rebased.py
@@ -8,14 +8,25 @@

 import torch
 import torch.nn as nn
-from einops import rearrange
+from einops import rearrange, repeat

 from fla.modules.feature_map import RebasedFeatureMap
 from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn
 from fla.ops.rebased import parallel_rebased


 class ReBasedLinearAttention(nn.Module):
+    r"""
+    Implementation of ReBased linear attention with optional grouped keys/values.
+
+    Args:
+        hidden_size (int): Model hidden size.
+        feature_dim (int): Dimensionality of the learnable quadratic feature map per head.
+        num_heads (int): Number of query heads.
+        num_key_value_heads (int): Number of unique key/value heads (GQA). Must divide `num_heads`.
+            When smaller than `num_heads`, keys and values are projected once per KV head and then
+            shared across ``num_heads // num_key_value_heads`` query heads.
+    """

     def __init__(
         self,
@@ -39,10 +50,16 @@ def __init__(
         self.mode = mode
         assert self.mode in ["fused_chunk", "parallel", 'chunk']

+        if hidden_size % num_heads != 0:
+            raise ValueError("`hidden_size` must be divisible by `num_heads`.")
+        if num_heads % num_key_value_heads != 0:
+            raise ValueError("`num_heads` must be divisible by `num_key_value_heads`.")
+
         self.feature_dim = feature_dim
-        self.num_key_value_heads = num_key_value_heads
         self.num_heads = num_heads
-        self.head_dim = self.hidden_size // self.num_key_value_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.num_kv_groups = self.num_heads // self.num_key_value_heads
+        self.head_dim = self.hidden_size // self.num_heads
         self.use_gamma = use_gamma
         self.use_beta = use_beta
         self.normalize = normalize
@@ -53,7 +70,7 @@

         self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize)
         self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_key_value_heads, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
         self.dropout = nn.Identity()
@@ -69,7 +86,7 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         k = rearrange(
             self.k_proj(hidden_states),
             "... (h d) -> ... h d",
-            h=self.num_heads,
+            h=self.num_key_value_heads,
             d=self.feature_dim,
         )
         v = rearrange(
@@ -78,7 +95,11 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
             h=self.num_key_value_heads,
             d=self.head_dim,
         )
-        q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel'))
+        q = self.feature_map(q, flatten=(mode != 'parallel'))
+        k = self.feature_map(k, flatten=(mode != 'parallel'))
+        if self.num_kv_groups > 1:
+            k = repeat(k, "... h d -> ... (h g) d", g=self.num_kv_groups)
+            v = repeat(v, "... h d -> ... (h g) d", g=self.num_kv_groups)
         if mode == "fused_chunk":
             o = fused_chunk_linear_attn(
                 q=q,
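
For readers unfamiliar with the grouped-query pattern this diff introduces, here is a minimal standalone sketch of the key/value sharing step described in the new docstring (shapes and values are made up for illustration; this is not the module's actual forward pass):

```python
# Sketch of grouped key/value sharing: keys are projected once per KV head,
# then each KV head is broadcast to its group of query heads via einops.repeat.
import torch
from einops import repeat

B, T = 2, 16                                        # batch size, sequence length
num_heads, num_key_value_heads = 8, 2               # 8 query heads, 2 KV heads
feature_dim = 4
num_kv_groups = num_heads // num_key_value_heads    # 4 query heads share each KV head

q = torch.randn(B, T, num_heads, feature_dim)
k = torch.randn(B, T, num_key_value_heads, feature_dim)   # one projection per KV head

# Repeat each KV head across its group so k lines up with q head-for-head.
k = repeat(k, "b t h d -> b t (h g) d", g=num_kv_groups)
assert k.shape == q.shape                            # (2, 16, 8, 4)
```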