
@timmy-feng (Contributor) commented Nov 20, 2025

Motivation

The two existing attention backends both have inefficiencies that degrade the training experience.

  • The sdpa backend materializes the full bsz x num_heads x q_len x kv_len attention score matrix in VRAM, which severely limits the maximum sequence length.
  • The flex_attention backend is sensitive to the Linux environment and often requires different compilation flags depending on package versions. We were not able to get this kernel to compile reliably on torch==2.8.0.

Using a log-sum-exp (LSE) trick, we can avoid materializing any attention score matrix while handling the TTT KV cache with minimal overhead. We build this on the flash attention backend, since it readily returns the LSE tensor alongside the O tensor. FlashAttention 4 is also state of the art for training on Blackwell; porting FA4 is out of scope for this PR, but supporting the flash attention interface is a first step toward it.
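For reference, the identity behind the trick (standard flash-attention math; the notation here is ours, not taken from the PR): if the KV cache is split into blocks with partial outputs $O_i$ and per-query log-sum-exp values $\ell_i$, the exact full-softmax output is recovered as

$$\ell = \log \sum_i e^{\ell_i}, \qquad O = \sum_i e^{\ell_i - \ell}\, O_i,$$

so no block's attention score matrix ever needs to be materialized.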

Modifications

Added a new LlamaFlashAttention module with the same API as LlamaAttention (using a manual hidden-state cache).

Within the forward pass, we:

  • Calculate the partial attention output with only the target model's KV cache using flash attention
  • Create singleton partial attention outputs for each of the successive TTT iterations
  • Combine all partials via a weighted sum using their LSEs (see the sketch after this list)
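A minimal sketch of that combination step, assuming stacked [bsz, num_heads, q_len, head_dim] outputs and [bsz, num_heads, q_len] LSEs (the tensor names and layouts below are illustrative, not the PR's actual code):

```python
import torch

def combine_partials(partial_outputs, partial_lses):
    """Merge per-block flash-attention outputs into the exact full-softmax output.

    partial_outputs: list of [bsz, num_heads, q_len, head_dim] tensors
    partial_lses:    list of [bsz, num_heads, q_len] log-sum-exp tensors
    """
    outs = torch.stack(partial_outputs, dim=0)          # [n, bsz, heads, q_len, d]
    lses = torch.stack(partial_lses, dim=0)             # [n, bsz, heads, q_len]
    total_lse = torch.logsumexp(lses, dim=0)            # [bsz, heads, q_len]
    weights = torch.exp(lses - total_lse.unsqueeze(0))  # per-block softmax weights
    return (weights.unsqueeze(-1) * outs).sum(dim=0)    # [bsz, heads, q_len, d]
```

Because the weights are exact softmax renormalization factors derived from the LSEs, the result matches full attention up to floating-point error rather than being an approximation.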

Added a test file test_flash_attention.py that verifies equivalence with the SDPA backend (within bf16 numerical tolerance).

Related Issues

Accuracy Test

Ran python -m tests.test_utils.test_flash_attention:

```
test_backward_pass_gradient_comparison (__main__.TestFlashAttention.test_backward_pass_gradient_comparison)
Test backward pass comparing gradients between LlamaAttention and LlamaFlashAttention. ... ok
test_forward_pass_comparison (__main__.TestFlashAttention.test_forward_pass_comparison)
Test forward pass comparison between LlamaAttention and LlamaFlashAttention. ... ok

----------------------------------------------------------------------
Ran 2 tests in 16.257s

OK
```

Benchmark & Profiling

Trained a speculator on custom data for GLM 4.5 on 8xH200 with a per-GPU batch size of 1 and a sequence length of 32K. Here is the performance comparison against flex attention:

| Method | VRAM Usage | Speed (s/it) |
| --- | --- | --- |
| flex-attention | 888 GB | 9.5 |
| flash-attention | 854 GB | 7.2 |

We also trained for one epoch on perfectblend and achieved an accept length of 3 on GSM8K with a 3-step chain spec.

GLM 4.5 support was added in a custom branch built on top of this PR here.

Checklist

@gemini-code-assist (Contributor)

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@sleepcoo (Collaborator)

How much of a performance improvement does this offer compared to flex attention?

@timmy-feng (Contributor, Author)

I ran a comparison on 8xH200 and added it to the benchmarks section. I saw an improvement over flex attention in both VRAM usage and speed (25% faster). We could probably push it further by supporting FA3 and FA4.

I was not able to get flex-attention to compile on B200, one of the core motivations for this feature.

@FrankLeeeee (Collaborator)

Thanks! I was not able to use flex-attention on B200 either. Also, could you run pre-commit on your code?

@timmy-feng requested a review from zyksir as a code owner on November 21, 2025 at 20:52.
@FrankLeeeee (Collaborator)

There is still a conflict with the main branch.

```python
torch.manual_seed(0)


def assert_similar(ref, out):
```
It looks like there are more significant numeric differences between the two approaches. How much of a difference do we see if we do a point-to-point comparison between ref and out?
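(Illustrative aside, not part of the thread: a point-to-point comparison like the one asked about could be reported with something along these lines; the helper name and the epsilon guard are placeholders.)

```python
import torch

def report_pointwise_diff(ref: torch.Tensor, out: torch.Tensor) -> None:
    """Print elementwise error statistics between a reference and a candidate tensor."""
    ref32, out32 = ref.float(), out.float()
    abs_err = (ref32 - out32).abs()
    rel_err = abs_err / ref32.abs().clamp_min(1e-6)  # guard against division by zero
    print(
        f"max abs err: {abs_err.max().item():.3e}, "
        f"mean abs err: {abs_err.mean().item():.3e}, "
        f"max rel err: {rel_err.max().item():.3e}"
    )
```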

@Abigbigbig

I trained qwen2.5-vl-7B-eagle3 using the latest specforge 0.1.1 and sglang 0.5.5, and encountered "AttributeError: 'Qwen2_5_VLForConditionalGeneration' object has no attribute 'set_aux_hidden_states_layers'". I didn't have this issue when using the version before the fix. What could be the reason?

@yubofredwang (Collaborator)

@Abigbigbig This looks like a different issue from this PR. Let's move it to a separate issue; I can point you to the fix.
