-
Notifications
You must be signed in to change notification settings - Fork 380
[kda] add recursive block intra implementation #656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
sustcsonglin
wants to merge
2
commits into
main
Choose a base branch
from
feat/kda-recursive-block
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
|
|
||
| import torch | ||
| import triton | ||
| from fla.ops.kda.chunk_intra import chunk_kda_fwd_intra | ||
|
|
||
| def benchmark_intra_chunk(B=8, T=4096, H=16, K=128, chunk_size=64): | ||
| dtype = torch.bfloat16 | ||
| device = 'cuda' | ||
|
|
||
| q = torch.randn(B, T, H, K, device=device, dtype=dtype) | ||
| k = torch.randn(B, T, H, K, device=device, dtype=dtype) | ||
| g = torch.randn(B, T, H, K, device=device, dtype=torch.float32) | ||
| beta = torch.randn(B, T, H, device=device, dtype=dtype) | ||
|
|
||
| scale = 1.0 | ||
|
|
||
| quantiles = [0.5, 0.2, 0.8] | ||
|
|
||
| # Warmup | ||
| for _ in range(10): | ||
| chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="token") | ||
| chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recursive") | ||
| chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent") | ||
|
|
||
| ms_token = triton.testing.do_bench( | ||
| lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="token"), | ||
| quantiles=quantiles | ||
| ) | ||
|
|
||
| ms_recursive = triton.testing.do_bench( | ||
| lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recursive"), | ||
| quantiles=quantiles | ||
| ) | ||
|
|
||
| try: | ||
| ms_recurrent = triton.testing.do_bench( | ||
| lambda: chunk_kda_fwd_intra(q, k, g, beta, scale=scale, chunk_size=chunk_size, impl_type="recurrent"), | ||
| quantiles=quantiles | ||
| ) | ||
| t_recurrent = ms_recurrent[0] | ||
| except Exception as e: | ||
| t_recurrent = float('nan') | ||
|
|
||
| # Format for table row | ||
| # Shape | Token | Recursive | Recurrent | Rec vs Token | ||
| row_str = f"B={B}, T={T}, H={H}, K={K}" | ||
| print(f"{row_str:<30} | {ms_token[0]:.3f} ms | {ms_recursive[0]:.3f} ms | {t_recurrent:.3f} ms | {ms_token[0]/ms_recursive[0]:.2f}x ") | ||
|
|
||
| if __name__ == "__main__": | ||
| configs = [ | ||
| (8, 4096, 16, 128), | ||
| (1, 8192, 16, 128), | ||
| (8, 4096, 32, 64), | ||
| (1, 8192, 32, 64), | ||
| # Large Batch | ||
| (32, 512, 12, 64), | ||
| # High Head Dim | ||
| (2, 4096, 8, 256), | ||
| ] | ||
|
|
||
| print(f"{'Shape':<30} | {'Token (Original)':<20} | {'Recursive (New)':<20} | {'Recurrent':<15} | {'Speedup (Rec/Tok)':<15}") | ||
| print("-" * 110) | ||
|
|
||
| for B, T, H, K in configs: | ||
| try: | ||
| benchmark_intra_chunk(B=B, T=T, H=H, K=K, chunk_size=64) | ||
| except Exception as e: | ||
| print(f"Failed for shape B={B}, T={T}, H={H}, K={K}: {e}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix unused exception variable and consider narrowing the catch
You currently catch a broad
Exceptionand bind it toewithout using it, which triggers F841/BLE001 and fails lint.A minimal fix that keeps the behavior is:
This removes the unused variable and documents why the broad catch is acceptable in a benchmark context.
📝 Committable suggestion
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 41-41: local variable 'e' is assigned to but never used
(F841)
🪛 Ruff (0.14.5)
41-41: Do not catch blind exception:
Exception(BLE001)
41-41: Local variable
eis assigned to but never usedRemove assignment to unused variable
e(F841)
🤖 Prompt for AI Agents