
Commit

Add padding-free to bamba
garrett361 committed Jan 24, 2025
1 parent 72d1a4c commit c4874af
Showing 3 changed files with 218 additions and 72 deletions.
110 changes: 76 additions & 34 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -516,28 +516,17 @@ def cuda_kernels_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
use_precomputed_states: bool = False,
):
# 1. Gated MLP's linear projection
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
projected_states = self.in_proj(hidden_states)

# Set up dimensions for reshapes later
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size

use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_params.conv_states[self.layer_idx].shape[0]
== cache_params.ssm_states[self.layer_idx].shape[0]
== batch_size
and cache_position is not None
and cache_position[0] > 0
)

# getting projected states from cache if it exists
if use_precomputed_states:
gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
@@ -600,7 +589,7 @@ def cuda_kernels_forward(
A,
D=self.D,
chunk_size=self.chunk_size,
seq_idx=None, # was seq_idx
seq_idx=seq_idx,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
@@ -684,29 +673,18 @@ def torch_forward(
self,
input_states,
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_precomputed_states: bool = False
):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype

# 1. Gated MLP's linear projection
input_states = apply_mask_to_padding_states(input_states, attention_mask)
projected_states = self.in_proj(input_states)
gate, hidden_states_B_C, dt = projected_states.split(
[self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)

use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_params.conv_states[self.layer_idx].shape[0]
== cache_params.ssm_states[self.layer_idx].shape[0]
== batch_size
and cache_position is not None
and cache_position[0] > 0
)

# 2. Convolution sequence transformation
if use_precomputed_states:
@@ -893,15 +871,27 @@ def forward(
cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
seq_idx: Optional[torch.Tensor] = None,
):
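        # Note: `use_precomputed_states` is True only for single-token decoding steps with an
        # already-initialized cache; it is computed once here and handed to both the CUDA and
        # torch code paths below.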
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = (
cache_params is not None
and cache_params.has_previous_state
and seq_len == 1
and cache_params.conv_states[self.layer_idx].shape[0]
== cache_params.ssm_states[self.layer_idx].shape[0]
== batch_size
and cache_position is not None
and cache_position[0] > 0
)
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
return self.cuda_kernels_forward(
hidden_states, cache_params, attention_mask, seq_idx, use_precomputed_states
)
if seq_idx is not None:
raise ValueError("Non-trivial seq_idx only supported on cuda path.")
return self.torch_forward(hidden_states, cache_params, attention_mask, use_precomputed_states)


class BambaMLP(nn.Module):
Expand Down Expand Up @@ -940,10 +930,42 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.LongTensor:
batch_size = position_ids.shape[0]
if batch_size != 1:
raise ValueError("Only batch size 1 is supported.")
device = position_ids.device
idxs = torch.arange(1, position_ids.shape[1], device=device)
non_increasing_pos_id = position_ids[0, 1:] <= position_ids[0, :-1]
cu_seq_lens = torch.cat(
(
torch.tensor([0], device=device),
idxs[non_increasing_pos_id],
torch.tensor(position_ids[0].shape, device=device),
),
)
return cu_seq_lens[None]
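# Illustrative example for get_cu_seq_lens_from_position_ids (values assumed): for a packed
# batch with position_ids [[0, 1, 2, 0, 1]], the position value resets at index 3, so this
# returns cu_seq_lens [[0, 3, 5]], the cumulative boundaries of the two packed sequences.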


def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
batch_size = cu_seq_lens.shape[0]
if batch_size != 1:
raise ValueError("Only batch size 1 is supported.")
seq_idx = torch.cat(
[
torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
]
)
return seq_idx[None]
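# Illustrative example for get_seq_idx_from_cu_seq_lens (values assumed): cu_seq_lens
# [[0, 3, 5]] yields seq_idx [[0, 0, 0, 1, 1]] (int32), labelling each flattened position
# with the index of the packed sequence it belongs to, which is the format the varlen CUDA
# kernel consumes so that packed sequences do not mix state.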


class BambaDecoderLayer(nn.Module):
def __init__(self, config: BambaConfig, layer_idx: int, layer_type: str = "mamba"):
super().__init__()

# The `num_experts` code below is redundant, but it prevents modular_model_converter.py from
# generating an unwanted BambaSparseMoeBlock in modeling_bamba.py
num_experts = 1
ffn_layer_class = BambaMLP if num_experts == 1 else None
self.feed_forward = ffn_layer_class(config)
@@ -968,7 +990,7 @@ def forward(
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -998,11 +1020,29 @@ def forward(

# this is a hybrid decoder layer
if self.layer_type == "mamba":
# Padding-free processing for efficient training. position_ids and FlashAttentionKwargs
# are ignored by mamba layers if not training.
if not self.training:
seq_idx = None
elif "cu_seq_lens_k" in kwargs:
seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
elif position_ids is not None:
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
if len(cu_seq_lens[0]) == 2:
# If cu_seq_lens only has two elements, then it is semantically equivalent to
# `seq_idx=None`, which is more efficient.
seq_idx = None
else:
seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
else:
seq_idx = None
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
seq_idx=seq_idx,
**kwargs,
)
self_attn_weights = None
elif self.layer_type == "attention":
@@ -1202,6 +1242,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -1275,6 +1316,7 @@ def forward(
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)

hidden_states = layer_outputs[0]

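For context, a minimal sketch of how inputs in this padding-free layout are typically produced on the data side, assuming transformers' DataCollatorWithFlattening with its default settings (the token values below are placeholders):

    from transformers import DataCollatorWithFlattening

    collator = DataCollatorWithFlattening()
    features = [{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}]
    batch = collator(features)
    # batch["input_ids"]    -> [[5, 6, 7, 8, 9]]   (sequences concatenated, batch size 1, no padding)
    # batch["position_ids"] -> [[0, 1, 2, 0, 1]]   (position restarts mark the boundaries that
    #                          get_cu_seq_lens_from_position_ids recovers during training)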