Commit eab1ae1

Add padding-free to bamba

1 parent b5aaf87

3 files changed: 218 additions, 72 deletions
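For context, "padding-free" here means packing several documents into a single batch row instead of right-padding each document to a common length; document boundaries are then recovered from metadata such as position_ids that restart at 0. A minimal illustration in Python (the token values are made up and the snippet is independent of this diff):

import torch

# Conventional padded batch: two documents of lengths 3 and 5, right-padded to 5.
padded_input_ids = torch.tensor([
    [11, 12, 13,  0,  0],
    [21, 22, 23, 24, 25],
])
padded_attention_mask = torch.tensor([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
])

# Padding-free equivalent: one row of 8 real tokens and no pad tokens at all.
# position_ids restart at 0 at each document boundary, which is what the new
# helpers in this commit use to recover cu_seq_lens and seq_idx.
packed_input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24, 25]])
packed_position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])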

src/transformers/models/bamba/modeling_bamba.py

Lines changed: 76 additions & 34 deletions
@@ -516,28 +516,17 @@ def cuda_kernels_forward(
         self,
         hidden_states: torch.Tensor,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         # 1. Gated MLP's linear projection
-        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         projected_states = self.in_proj(hidden_states)
 
         # Set up dimensions for reshapes later
         batch_size, seq_len, _ = hidden_states.shape
         groups_time_state_size = self.n_groups * self.ssm_state_size
 
-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )
-
         # getting projected states from cache if it exists
         if use_precomputed_states:
             gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
@@ -600,7 +589,7 @@ def cuda_kernels_forward(
                 A,
                 D=self.D,
                 chunk_size=self.chunk_size,
-                seq_idx=None,  # was seq_idx
+                seq_idx=seq_idx,
                 activation=self.activation,
                 rmsnorm_weight=self.norm.weight,
                 rmsnorm_eps=self.norm.variance_epsilon,
@@ -684,29 +673,18 @@ def torch_forward(
         self,
         input_states,
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        use_precomputed_states: bool = False,
     ):
         batch_size, seq_len, _ = input_states.shape
         dtype = input_states.dtype
 
         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
         projected_states = self.in_proj(input_states)
         gate, hidden_states_B_C, dt = projected_states.split(
             [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
         )
 
-        use_precomputed_states = (
-            cache_params is not None
-            and cache_params.has_previous_state
-            and seq_len == 1
-            and cache_params.conv_states[self.layer_idx].shape[0]
-            == cache_params.ssm_states[self.layer_idx].shape[0]
-            == batch_size
-            and cache_position is not None
-            and cache_position[0] > 0
-        )
 
         # 2. Convolution sequence transformation
         if use_precomputed_states:
@@ -893,15 +871,27 @@ def forward(
         cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        seq_idx: Optional[torch.Tensor] = None,
     ):
+        batch_size, seq_len, _ = hidden_states.shape
+        use_precomputed_states = (
+            cache_params is not None
+            and cache_params.has_previous_state
+            and seq_len == 1
+            and cache_params.conv_states[self.layer_idx].shape[0]
+            == cache_params.ssm_states[self.layer_idx].shape[0]
+            == batch_size
+            and cache_position is not None
+            and cache_position[0] > 0
+        )
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
-            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
-        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+            return self.cuda_kernels_forward(
+                hidden_states, cache_params, attention_mask, seq_idx, use_precomputed_states
+            )
+        if seq_idx is not None:
+            raise ValueError("Non-trivial seq_idx only supported on cuda path.")
+        return self.torch_forward(hidden_states, cache_params, attention_mask, use_precomputed_states)
 
 
 class BambaMLP(nn.Module):
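The batch/cache consistency check that used to live in cuda_kernels_forward and torch_forward is now computed once in the mixer's forward (above) and passed down. As a stand-in sketch of when that flag fires (hypothetical helper in plain Python, not part of the diff), it is only True for single-token decode steps against warm caches:

import torch

# Hypothetical helper mirroring the condition computed in forward above.
def would_use_precomputed_states(has_previous_state, seq_len, conv_bs, ssm_bs, batch_size, cache_position):
    return bool(
        has_previous_state
        and seq_len == 1
        and conv_bs == ssm_bs == batch_size
        and cache_position is not None
        and cache_position[0] > 0
    )

# Prefill of a 7-token prompt: full scan, caches are only being filled.
print(would_use_precomputed_states(False, 7, 2, 2, 2, torch.tensor([0, 1, 2, 3, 4, 5, 6])))  # False
# Decoding one new token against warm conv/ssm caches: the cached path is taken.
print(would_use_precomputed_states(True, 1, 2, 2, 2, torch.tensor([7])))  # True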
@@ -940,10 +930,42 @@ def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
 
 
+def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
+    batch_size = position_ids.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    device = position_ids.device
+    idxs = torch.arange(1, position_ids.shape[1], device=device)
+    non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
+    cu_seq_lens = torch.cat(
+        (
+            torch.tensor([0], device=device),
+            idxs[non_increasing_pos_id],
+            torch.tensor(position_ids[0].shape, device=device),
+        ),
+    )
+    return cu_seq_lens[None]
+
+
+def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
+    batch_size = cu_seq_lens.shape[0]
+    if batch_size != 1:
+        raise ValueError("Only batch size 1 is supported.")
+    seq_idx = torch.cat(
+        [
+            torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
+            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+        ]
+    )
+    return seq_idx[None]
+
+
 class BambaDecoderLayer(nn.Module):
     def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
         super().__init__()
 
+        # The `num_experts` code below is redundant, but it prevents modular_model_converter.py from
+        # generating an unwanted BambaSparseMoeBlock in modeling_bamba.py
         num_experts = 1
         ffn_layer_class = BambaMLP if num_experts == 1 else None
         self.feed_forward = ffn_layer_class(config)
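To make the two new module-level helpers concrete, here is a small worked example (pure tensor manipulation; it only assumes the two functions above are importable from modeling_bamba.py once this change lands):

import torch
from transformers.models.bamba.modeling_bamba import (
    get_cu_seq_lens_from_position_ids,
    get_seq_idx_from_cu_seq_lens,
)

# One packed row containing three documents of lengths 3, 2 and 4; the position
# ids restart at 0 at every document boundary.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
print(cu_seq_lens)  # tensor([[0, 3, 5, 9]]) -- cumulative document boundaries

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32)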
@@ -968,7 +990,7 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs,
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -998,11 +1020,29 @@ def forward(
 
         # this is a hybrid decoder layer
         if self.layer_type == "mamba":
+            # Padding-free processing for efficient training. position_ids and FlashAttentionKwargs
+            # are ignored by mamba layers if not training.
+            if not self.training:
+                seq_idx = None
+            elif "cu_seq_lens_k" in kwargs:
+                seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
+            elif position_ids is not None:
+                cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
+                if len(cu_seq_lens[0]) == 2:
+                    # If cu_seq_lens only has two elements, then it is semantically equivalent to
+                    # `seq_idx=None`, which is more efficient.
+                    seq_idx = None
+                else:
+                    seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
+            else:
+                seq_idx = None
             hidden_states = self.mamba(
                 hidden_states=hidden_states,
                 cache_params=past_key_value,
                 cache_position=cache_position,
                 attention_mask=attention_mask,
+                seq_idx=seq_idx,
+                **kwargs,
             )
             self_attn_weights = None
         elif self.layer_type == "attention":
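A couple of worked cases for the branch above (same helpers as before; the tensors are illustrative):

import torch
from transformers.models.bamba.modeling_bamba import (
    get_cu_seq_lens_from_position_ids,
    get_seq_idx_from_cu_seq_lens,
)

# Two packed documents: a real seq_idx tensor is built and handed to the mixer.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)  # tensor([[0, 3, 5]])
seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)            # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)

# A single document: cu_seq_lens collapses to two entries ([0, total_len]), so the
# layer takes the shortcut and passes seq_idx=None instead of an index tensor.
position_ids = torch.tensor([[0, 1, 2, 3]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)  # tensor([[0, 4]])
assert len(cu_seq_lens[0]) == 2

In eval mode (not self.training) seq_idx is always None, and FlashAttention-style cu_seq_lens_k kwargs, when present, take precedence over deriving the boundaries from position_ids.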
@@ -1202,6 +1242,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1275,6 +1316,7 @@ def forward(
                 use_cache=use_cache,
                 cache_position=cache_position,
                 position_embeddings=position_embeddings,
+                **flash_attn_kwargs,
             )
 
             hidden_states = layer_outputs[0]
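Putting the plumbing together, a hedged end-to-end sketch of padding-free training input with this change applied (the checkpoint name is illustrative; the fused CUDA path additionally requires a CUDA device and the mamba-ssm/causal-conv1d kernels):

import torch
from transformers import BambaForCausalLM

model = BambaForCausalLM.from_pretrained("ibm-ai-platform/Bamba-9B").cuda()
model.train()  # seq_idx is only derived in training mode

# Two documents of lengths 3 and 5 packed into one row; no pad tokens and no
# attention_mask. position_ids restart at 0 per document. For the attention
# layers to also respect document boundaries, the FlashAttention cu_seq_lens /
# max_length kwargs should be supplied as well (e.g. by a packing collator).
input_ids = torch.tensor([[101, 2023, 102, 101, 2003, 1037, 3231, 102]], device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]], device="cuda")

outputs = model(input_ids=input_ids, position_ids=position_ids, labels=input_ids)
outputs.loss.backward()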
