 from transformers.cache_utils import Cache
 from transformers.models.llama.configuration_llama import LlamaConfig
 from flash_attn import flash_attn_func
+from flash_attn.cute import flash_attn_func as flash_attn_func_v4

 from specforge.modeling.draft.flex_attention import (
     compile_friendly_create_block_mask,
@@ -425,7 +426,6 @@ def yarn_linear_ramp_mask(min_val, max_val, dim):


 class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding):
-
     def __init__(
         self,
         dim,
@@ -850,6 +850,10 @@ class LlamaFlashAttention(LlamaAttention):
     - cache_hidden: manual cache used for storing past key and value states
     """

+    def __init__(self, config, backend="fa"):
+        super().__init__(config)
+        self.backend = backend
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -866,9 +870,7 @@ def forward(
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        )
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
@@ -908,7 +910,12 @@ def forward(
             k0 = cache_k[0]
             v0 = cache_v[0]

-            attn_output, lse, _ = flash_attn_func(
+            if self.backend == "fa4":
+                attn_func = flash_attn_func_v4
+            else:
+                attn_func = flash_attn_func
+
+            attn_output, lse, _ = attn_func(
                 query_states,
                 k0,
                 v0,
@@ -921,7 +928,13 @@ def forward(

             lck = len(cache_k)
             if lck > 1:
-                q_shape_expanded = (bsz, q_len, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
+                q_shape_expanded = (
+                    bsz,
+                    q_len,
+                    self.num_key_value_heads,
+                    self.num_key_value_groups,
+                    self.head_dim,
+                )
                 attn_outputs = [attn_output.view(q_shape_expanded)]
                 lses = [lse.view(q_shape_expanded[:-1])]

@@ -1022,7 +1035,9 @@ def __init__(self, config, attention_backend: str = "sdpa"):
             print_with_rank("Using flex attention on draft model training!")
             self.self_attn = LlamaFlexAttention(config=config)
         elif attention_backend == "fa":
-            self.self_attn = LlamaFlashAttention(config=config)
+            self.self_attn = LlamaFlashAttention(config=config, backend="fa")
+        elif attention_backend == "fa4":
+            self.self_attn = LlamaFlashAttention(config=config, backend="fa4")
         else:
             raise ValueError(f"Unknown attention backend {attention_backend}")

@@ -1092,7 +1107,6 @@ def forward(


 class LlamaForCausalLMEagle3(Eagle3DraftModel):
-
     config_class = LlamaConfig

     def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
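Usage note (not part of the commit): a minimal sketch of selecting the new backend. The module path for LlamaForCausalLMEagle3 is an assumption, and "fa4" presumes a flash_attn build that ships the CuTe interface (flash_attn.cute); the option simply routes LlamaFlashAttention to flash_attn_func_v4 as shown above.

# Sketch only; the import path below is assumed, adjust to wherever
# LlamaForCausalLMEagle3 lives in specforge.
from transformers.models.llama.configuration_llama import LlamaConfig
from specforge.modeling.draft.llama3_eagle import LlamaForCausalLMEagle3  # hypothetical path

config = LlamaConfig(
    hidden_size=2048,
    num_attention_heads=32,
    num_key_value_heads=8,
)

# attention_backend defaults to "sdpa"; "fa" keeps the existing flash_attn kernel,
# while "fa4" requires flash_attn.cute to be importable in the environment.
draft_model = LlamaForCausalLMEagle3(config, attention_backend="fa4")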