@@ -516,28 +516,17 @@ def cuda_kernels_forward(
         self,
         hidden_states: torch.Tensor,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         # 1. Gated MLP's linear projection
-        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         projected_states = self.in_proj(hidden_states)

         # Set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )
-
         # getting projected states from cache if it exists
         if use_precomputed_states:
             gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
@@ -600,7 +589,7 @@ def cuda_kernels_forward(
                 A,
                 D=self.D,
                 chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                 activation=self.activation,
                 rmsnorm_weight=self.norm.weight,
                 rmsnorm_eps=self.norm.variance_epsilon,
@@ -684,29 +673,18 @@ def torch_forward(
         self,
         input_states,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
         projected_states = self.in_proj(input_states)
         gate, hidden_states_B_C, dt = projected_states.split(
             [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
         )

-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )

         # 2. Convolution sequence transformation
         if use_precomputed_states:
@@ -893,15 +871,27 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
     ):
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = (
+            cache_params is not None
+            and cache_params.has_previous_state
+            and seq_len == 1
+            and cache_params.conv_states[self.layer_idx].shape[0]
+            == cache_params.ssm_states[self.layer_idx].shape[0]
+            == batch_size
+            and cache_position is not None
+            and cache_position[0] > 0
+        )
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
-        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(
+                hidden_states, cache_params, attention_mask, seq_idx, use_precomputed_states
+            )
+        if seq_idx is not None:
+            raise ValueError("Non-trivial seq_idx only supported on cuda path.")
+        return self.torch_forward(hidden_states, cache_params, attention_mask, use_precomputed_states)


 class BambaMLP(nn.Module):
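Note: `use_precomputed_states` is now computed once in `forward` and passed down to both kernel paths. A minimal sketch of when the hoisted condition fires during generation, with illustrative values only (the real check also requires `cache_params is not None` and `cache_params.has_previous_state`):

import torch

# After a 17-token prefill, each decode step feeds a single new token, so the cached
# conv/ssm states can be reused as long as they were allocated for the same batch size.
seq_len = 1
batch_size = 2
conv_state_batch = ssm_state_batch = 2
cache_position = torch.tensor([17])

use_precomputed_states = (
    seq_len == 1
    and conv_state_batch == ssm_state_batch == batch_size
    and cache_position[0] > 0
)
print(use_precomputed_states)  # tensor(True): the single-step cached branch is taken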
@@ -940,10 +930,42 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


+def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
+    batch_size = position_ids.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    device = position_ids.device
+    idxs = torch.arange(1, position_ids.shape[1], device=device)
+    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
+    cu_seq_lens = torch.cat(
+        (
+            torch.tensor([0], device=device),
+            idxs[non_increasing_pos_id],
+            torch.tensor(position_ids[0].shape, device=device),
+        ),
+    )
+    return cu_seq_lens[None]
+
+
+def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    batch_size = cu_seq_lens.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    seq_idx = torch.cat(
+        [
+            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
+            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+        ]
+    )
+    return seq_idx[None]
+
+
 class BambaDecoderLayer(nn.Module):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()

+        # The `num_experts` code below is redundant, but it prevents modular_model_converter.py from
+        # generating an unwanted BambaSparseMoeBlock in modeling_bamba.py
         num_experts = 1
         ffn_layer_class = BambaMLP if num_experts == 1 else None
         self.feed_forward = ffn_layer_class(config)
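For reference, a minimal usage sketch of the two helpers added above, assuming a single batch row that packs three sequences of lengths 3, 2, and 4 (values are illustrative):

import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])  # positions restart at each packed sequence

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# tensor([[0, 3, 5, 9]]): cumulative boundaries of the packed sequences

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32): per-token sequence index,
# which is what gets forwarded as seq_idx to the CUDA fast path above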
@@ -968,7 +990,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -998,11 +1020,29 @@ def forward(

         # this is a hybrid decoder layer
         if self.layer_type == "mamba":
+            # Padding-free processing for efficient training. position_ids and FlashAttentionKwargs
+            # are ignored by mamba layers if not training.
+            if not self.training:
+                seq_idx = None
+            elif "cu_seq_lens_k" in kwargs:
+                seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
+            elif position_ids is not None:
+                cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
+                if len(cu_seq_lens[0]) == 2:
+                    # If cu_seq_lens only has two elements, then it is semantically equivalent to
+                    # `seq_idx=None`, which is more efficient.
+                    seq_idx = None
+                else:
+                    seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
+            else:
+                seq_idx = None
             hidden_states = self.mamba(
                 hidden_states=hidden_states,
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
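A small worked example of the shortcut above (hypothetical tensor values): when the batch packs only one sequence, the derived cu_seq_lens has exactly two entries, so the layer keeps seq_idx=None and skips the extra work:

import torch

position_ids = torch.tensor([[0, 1, 2, 3, 4]])  # one unpacked sequence
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# tensor([[0, 5]]): len(cu_seq_lens[0]) == 2, semantically the same as seq_idx=None,
# so no seq_idx tensor is materialized for self.mamba(...)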
@@ -1202,6 +1242,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1275,6 +1316,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
             )

             hidden_states = layer_outputs[0]