
Commit 0279e03

add fa4

1 parent c9dbc1f commit 0279e03

File tree: 4 files changed, +39 / -13 lines

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -15,3 +15,4 @@ pydantic
 sglang[all]==0.5.4
 openai-harmony
 flash-attn>=2.6.3
+flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@54d8aa6751fc9d5f0357854079261913d5df1f9d#subdirectory=flash_attn/cute
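The new pinned dependency brings in the CuTe-based FlashAttention kernels that back the new fa4 attention backend. As a minimal sketch (the FA4_AVAILABLE flag is illustrative, not part of the codebase), availability can be probed through the import path this commit uses in specforge/modeling/draft/llama3_eagle.py:

# Sketch: probe whether the fa4 (CuTe) kernel path is importable before
# selecting the backend. Assumes the flash-attn-cute pin above is installed.
try:
    from flash_attn.cute import flash_attn_func as flash_attn_func_v4
    FA4_AVAILABLE = True
except ImportError:
    FA4_AVAILABLE = False  # fall back to the existing "fa" backend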

specforge/core/eagle3.py

Lines changed: 2 additions & 2 deletions

@@ -140,7 +140,7 @@ def forward(
         plosses = []
         vlosses = []
         acces = []
-        if self.attention_backend in ["sdpa", "fa"]:
+        if self.attention_backend in ["sdpa", "fa", "fa4"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
@@ -517,7 +517,7 @@ def forward(
         plosses = []
         vlosses = []
         acces = []
-        if self.attention_backend in ["sdpa", "fa"]:
+        if self.attention_backend in ["sdpa", "fa", "fa4"]:
             cache_hidden = [[], []]
             past_key_values = None
         elif self.attention_backend == "flex_attention":
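In the training loop, fa4 only changes which kernel the draft attention calls later; it shares the manual cache setup with sdpa and fa. A standalone sketch of the branch above (the helper name init_draft_cache is illustrative, and the flex_attention case is elided because its body is not shown in this diff):

# Sketch of the cache-setup dispatch mirrored from the hunks above.
def init_draft_cache(attention_backend: str):
    if attention_backend in ["sdpa", "fa", "fa4"]:
        cache_hidden = [[], []]  # manual cache for past key and value states
        past_key_values = None
        return cache_hidden, past_key_values
    raise NotImplementedError(
        f"flex_attention and other backends are handled elsewhere: {attention_backend}"
    )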

specforge/modeling/draft/llama3_eagle.py

Lines changed: 22 additions & 8 deletions

@@ -10,6 +10,7 @@
 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
 from flash_attn import flash_attn_func
+from flash_attn.cute import flash_attn_func as flash_attn_func_v4
 
 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
@@ -425,7 +426,6 @@ def yarn_linear_ramp_mask(min_val, max_val, dim):
 
 
 class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
-
     def __init__(
         self,
         dim,
@@ -850,6 +850,10 @@ class LlamaFlashAttention(LlamaAttention):
     - cache_hidden: manual cache used for storing past key and value states
     """
 
+    def __init__(self, config, backend="fa"):
+        super().__init__(config)
+        self.backend = backend
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -866,9 +870,7 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        )
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
@@ -908,7 +910,12 @@ def forward(
             k0 = cache_k[0]
             v0 = cache_v[0]
 
-            attn_output, lse, _ = flash_attn_func(
+            if self.backend == "fa4":
+                attn_func = flash_attn_func_v4
+            else:
+                attn_func = flash_attn_func
+
+            attn_output, lse, _ = attn_func(
                 query_states,
                 k0,
                 v0,
@@ -921,7 +928,13 @@ def forward(
 
             lck = len(cache_k)
             if lck > 1:
-                q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+                q_shape_expanded = (
+                    bsz,
+                    q_len,
+                    self.num_key_value_heads,
+                    self.num_key_value_groups,
+                    self.head_dim,
+                )
                 attn_outputs = [attn_output.view(q_shape_expanded)]
                 lses = [lse.view(q_shape_expanded[:-1])]
@@ -1022,7 +1035,9 @@ def __init__(self, config, attention_backend: str = "sdpa"):
             print_with_rank("Using flex attention on draft model training!")
             self.self_attn = LlamaFlexAttention(config=config)
         elif attention_backend == "fa":
-            self.self_attn = LlamaFlashAttention(config=config)
+            self.self_attn = LlamaFlashAttention(config=config, backend="fa")
+        elif attention_backend == "fa4":
+            self.self_attn = LlamaFlashAttention(config=config, backend="fa4")
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")
@@ -1092,7 +1107,6 @@ def forward(
 
 
 class LlamaForCausalLMEagle3(Eagle3DraftModel):
-
     config_class = LlamaConfig
 
     def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
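Taken together, these changes thread a backend flag from the decoder layer into LlamaFlashAttention, which picks the kernel at call time. A condensed sketch of the dispatch, assuming only what the diff shows (the _select_attn_func helper is illustrative, and any keyword arguments to the kernel are omitted because they do not appear in the hunks above):

# Condensed view of the fa4 kernel selection added above.
from flash_attn import flash_attn_func
from flash_attn.cute import flash_attn_func as flash_attn_func_v4

def _select_attn_func(backend: str):
    # "fa4" routes to the CuTe (FlashAttention v4) kernel; "fa" and the
    # default keep the original flash_attn_func.
    return flash_attn_func_v4 if backend == "fa4" else flash_attn_func

# Inside forward(): attn_output, lse, _ = _select_attn_func(self.backend)(query_states, k0, v0, ...)

With the decoder-layer branch above, a draft model constructed with attention_backend="fa4" should instantiate LlamaFlashAttention(config, backend="fa4") and otherwise follow the same code path as the existing fa backend.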

tests/test_utils/test_flash_attention.py

Lines changed: 14 additions & 3 deletions

@@ -28,9 +28,12 @@ def assert_similar(ref, out):
 
 
 class TestFlashAttention(unittest.TestCase):
-
     def setUp(self):
+        import os
+
         torch.manual_seed(0)
+
+        self.backend = os.environ.get("FLASH_ATTN_BACKEND", "fa")
         self.config_dict = {
             "hidden_size": 128,
             "num_attention_heads": 8,
@@ -57,7 +60,11 @@ def test_forward_pass_comparison(self):
     def _test_forward_pass_comparison_for_seq_len(self, seq_len):
         """Helper method to test forward pass comparison for a specific sequence length."""
         attention = LlamaAttention(self.config).to("cuda").to(self.dtype)
-        flash_attention = LlamaFlashAttention(self.config).to("cuda").to(self.dtype)
+        flash_attention = (
+            LlamaFlashAttention(self.config, backend=self.backend)
+            .to("cuda")
+            .to(self.dtype)
+        )
 
         # Ensure same weights
         with torch.no_grad():
@@ -144,7 +151,11 @@ def test_backward_pass_gradient_comparison(self):
     def _test_backward_pass_gradient_comparison_for_seq_len(self, seq_len):
         """Helper method to test backward pass gradient comparison for a specific sequence length."""
         attention = LlamaAttention(self.config).to("cuda").to(self.dtype)
-        flash_attention = LlamaFlashAttention(self.config).to("cuda").to(self.dtype)
+        flash_attention = (
+            LlamaFlashAttention(self.config, backend=self.backend)
+            .to("cuda")
+            .to(self.dtype)
+        )
 
         # Ensure same weights
         with torch.no_grad():
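The parity tests now read the backend from the FLASH_ATTN_BACKEND environment variable and default to "fa". A sketch of running the same suite against the fa4 kernel, assuming a CUDA device and the flash-attn-cute dependency are available:

# Run the forward/backward parity tests against the fa4 backend.
import os
import unittest

os.environ["FLASH_ATTN_BACKEND"] = "fa4"  # setUp() falls back to "fa" if unset

suite = unittest.defaultTestLoader.discover(
    "tests/test_utils", pattern="test_flash_attention.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)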
