
Commit d35bcc6

test_attn_mask_position_ids_flash_attn_equality
1 parent 8d191c7 commit d35bcc6

3 files changed (+125, -42 lines)


src/transformers/models/bamba/modeling_bamba.py

Lines changed: 13 additions & 6 deletions
@@ -872,6 +872,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
@@ -944,17 +945,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
             torch.tensor(position_ids[0].shape, device=device),
         ),
     )
-    return cu_seq_lens[None]
+    return cu_seq_lens


 def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
-    batch_size = cu_seq_lens.shape[0]
-    if batch_size != 1:
-        raise ValueError("Only batch size 1 is supported.")
     seq_idx = torch.cat(
         [
             torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
-            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
         ]
     )
     return seq_idx[None]
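For context, a minimal usage sketch (not part of the commit) of what these helpers return after this change, using the import path exercised by the test file below; the example values are illustrative.

```python
import torch

from transformers.models.bamba.modular_bamba import (
    get_cu_seq_lens_from_position_ids,
    get_seq_idx_from_cu_seq_lens,
)

# Three packed sequences of lengths 3, 2, and 4 in a single row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# Now 1-D (the leading batch dim added by `[None]` is gone): tensor([0, 3, 5, 9])

seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
# Still returned with a batch dim: tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32)
```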
@@ -1028,7 +1026,7 @@ def forward(
                 seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
             elif position_ids is not None:
                 cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
-                if len(cu_seq_lens[0]) == 2:
+                if len(cu_seq_lens) == 2:
                     # If cu_seq_lens only has two elements, then it is semantically equivalent to
                     # `seq_idx=None`, which is more efficient.
                     seq_idx = None
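To illustrate the shortcut above, a small hedged sketch (not from the commit): for a single, unpacked sequence the boundary tensor has exactly two entries, so the model keeps `seq_idx=None`.

```python
import torch

from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids

# One ordinary (unpacked) sequence of length 6.
position_ids = torch.arange(6)[None]  # tensor([[0, 1, 2, 3, 4, 5]])

cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
# tensor([0, 6]): only the start and end boundaries, so the branch above
# leaves seq_idx as None instead of building an all-zeros seq_idx tensor.
assert len(cu_seq_lens) == 2
```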
@@ -1244,6 +1242,15 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if (
+            self.training
+            and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
+            and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
+        ):
+            raise ValueError(
+                "Padding-free training using position_ids or FlashAttentionKwargs requires "
+                "the flash_attention_2 attention implementation and mamba cuda and triton kernels."
+            )
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
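A hedged usage sketch of the path this guard protects; the checkpoint name and tensors are illustrative, not from the commit, and it assumes a CUDA GPU with flash-attn and the mamba kernels installed, which is exactly what the new error message asks for.

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any Bamba checkpoint would do here.
model = (
    AutoModelForCausalLM.from_pretrained(
        "ibm-ai-platform/Bamba-9B",
        attn_implementation="flash_attention_2",  # required by the new guard for padding-free training
        torch_dtype=torch.bfloat16,
    )
    .to("cuda")
    .train()
)

# Two sequences of lengths 3 and 4 packed into a single row, described via position_ids.
input_ids = torch.randint(0, model.config.vocab_size, (1, 7), device="cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]], device="cuda")

# With flash_attention_2 and the fast mamba path available this runs padding-free;
# with any other attention implementation the forward pass now raises the ValueError above.
logits = model(input_ids=input_ids, position_ids=position_ids).logits
```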

src/transformers/models/bamba/modular_bamba.py

Lines changed: 13 additions & 6 deletions
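These changes mirror the edits to modeling_bamba.py above: Bamba uses the transformers modular-model setup, in which modular_bamba.py is the hand-edited source and modeling_bamba.py is regenerated from it.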
@@ -220,17 +220,14 @@ def get_cu_seq_lens_from_position_ids(position_ids: torch.LongTensor) -> torch.L
             torch.tensor(position_ids[0].shape, device=device),
         ),
     )
-    return cu_seq_lens[None]
+    return cu_seq_lens


 def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
-    batch_size = cu_seq_lens.shape[0]
-    if batch_size != 1:
-        raise ValueError("Only batch size 1 is supported.")
     seq_idx = torch.cat(
         [
             torch.full((n,), idx, dtype=torch.int32, device=cu_seq_lens.device)
-            for idx, n in enumerate(torch.diff(cu_seq_lens[0], dim=-1))
+            for idx, n in enumerate(torch.diff(cu_seq_lens, dim=-1))
         ]
     )
     return seq_idx[None]
@@ -678,6 +675,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         seq_idx: Optional[torch.Tensor] = None,
+        **kwargs,
     ):
         batch_size, seq_len, _ = hidden_states.shape
         use_precomputed_states = (
@@ -776,7 +774,7 @@ def forward(
                 seq_idx = get_seq_idx_from_cu_seq_lens(kwargs["cu_seq_lens_k"])
             elif position_ids is not None:
                 cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)
-                if len(cu_seq_lens[0]) == 2:
+                if len(cu_seq_lens) == 2:
                     # If cu_seq_lens only has two elements, then it is semantically equivalent to
                     # `seq_idx=None`, which is more efficient.
                     seq_idx = None
@@ -992,6 +990,15 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if (
+            self.training
+            and (position_ids is not None or "cu_seq_lens_k" in flash_attn_kwargs)
+            and (self.config._attn_implementation != "flash_attention_2" or not is_fast_path_available)
+        ):
+            raise ValueError(
+                "Padding-free training using position_ids or FlashAttentionKwargs requires "
+                "the flash_attention_2 attention implementation and mamba cuda and triton kernels."
+            )
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

tests/models/bamba/test_modeling_bamba.py

Lines changed: 99 additions & 30 deletions
@@ -18,11 +18,15 @@
 import unittest

 import pytest
+from pytest import mark

 from transformers import AutoTokenizer, BambaConfig, is_torch_available
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids, get_seq_idx_from_cu_seq_lens
 from transformers.testing_utils import (
+    require_flash_attn,
     require_torch,
+    require_torch_gpu,
     slow,
     torch_device,
 )
@@ -482,6 +486,90 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
         # They should result in very similar logits
         torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)

+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_attn_mask_position_ids_flash_attn_equality(self):
+        r"""
+        Verify that the logits agree when using an attention mask, position_ids, or
+        FlashAttentionKwargs.
+        """
+        torch.manual_seed(42)
+        decoder_only_classes = []
+        for model_class in self.all_generative_model_classes:
+            config, _, _, _ = self.model_tester.prepare_config_and_inputs()
+            if config.is_encoder_decoder:
+                continue
+            else:
+                decoder_only_classes.append(model_class)
+        if len(decoder_only_classes) == 0:
+            self.skipTest(reason="No decoder-only architecture available for this model.")
+
+        # - Decoder-only architectures derived from encoder-decoder models could support it in theory, but we haven't
+        #   added support for it yet. We skip these models for now.
+        has_encoder_attributes = any(
+            attr_name
+            for attr_name in config.to_dict().keys()
+            if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
+        )
+        if has_encoder_attributes:
+            self.skipTest(
+                reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
+            )
+
+        for model_class in decoder_only_classes:
+            config, input_ids, input_mask, _ = self.model_tester.prepare_config_and_inputs()
+            # Padding-free requires training = True and attn_implementation="flash_attention_2"
+            model = (
+                model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
+                .to(torch_device)
+                .train()
+            )
+
+            non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask}
+            attn_mask_logits = model(**non_padding_free_inputs).logits
+
+            # Build up padding-free tensors
+            padding_free_input_ids = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1
+            )[None]
+            position_ids_list = [
+                torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask
+            ]
+            position_ids = torch.cat(position_ids_list, dim=-1)[None]
+            seq_lens = torch.cat(
+                [torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list],
+                dim=-1,
+            )
+            cu_seq_lens = torch.cat(
+                [
+                    torch.tensor([0], device=input_mask.device, dtype=torch.int32),
+                    seq_lens.cumsum(dim=-1, dtype=torch.int32),
+                ],
+                dim=-1,
+            )
+
+            position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids}
+            position_ids_logits = model(**position_ids_inputs).logits
+
+            flash_attn_kwargs = FlashAttentionKwargs(
+                cu_seq_lens_q=cu_seq_lens,
+                cu_seq_lens_k=cu_seq_lens,
+                max_length_q=input_ids.shape[-1],
+                max_length_k=input_ids.shape[-1],
+            )
+            flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits
+
+            attn_mask_logits_reshaped = torch.cat(
+                [batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0
+            )[None]
+
+            torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped)
+            # A higher tolerance is needed for the position_ids and FlashAttentionKwargs logits to
+            # match, for unknown reasons.
+            torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits, atol=1e-3, rtol=1e-1)
+            assert True
+

     @slow
     @require_torch
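For reference, a small worked example (illustrative values, not from the commit) of the padding-free tensors the new test builds from a right-padded batch, following the same construction as the test above.

```python
import torch

input_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 23, 24]])  # second row has no padding
input_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

# Drop padded positions and concatenate everything into a single row.
padding_free_input_ids = torch.cat(
    [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1
)[None]
# tensor([[11, 12, 13, 21, 22, 23, 24]])

# Per-sequence positions restart at 0 at each sequence boundary.
position_ids = torch.cat(
    [torch.arange(mask.sum(), dtype=torch.int32) for mask in input_mask], dim=-1
)[None]
# tensor([[0, 1, 2, 0, 1, 2, 3]], dtype=torch.int32)

# Cumulative sequence lengths, as consumed via FlashAttentionKwargs.
seq_lens = torch.tensor([3, 4], dtype=torch.int32)
cu_seq_lens = torch.cat(
    [torch.tensor([0], dtype=torch.int32), seq_lens.cumsum(dim=-1, dtype=torch.int32)], dim=-1
)
# tensor([0, 3, 7], dtype=torch.int32)
```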
@@ -598,60 +686,41 @@ def test_simple_batched_generate_with_padding(self):
 def test_cu_seq_lens_from_position_ids() -> None:
     seq_length = 256
     chunks_per_batch = 4
-    batch_size = 1

     # Split each batch into `chunks_per_batch` sequences.
-    eos_idxs = (
-        torch.stack([torch.randperm(seq_length) for _ in range(batch_size)], dim=0)[:, : chunks_per_batch - 1]
-        .sort(dim=-1)
-        .values
-    )
-    seq_lens = torch.cat(
-        (torch.full((batch_size, 1), -1), eos_idxs, torch.full((batch_size, 1), seq_length - 1)), dim=-1
-    ).diff(dim=-1)
+    eos_idxs = torch.randperm(seq_length)[: chunks_per_batch - 1].sort(dim=-1).values
+    seq_lens = torch.cat((torch.full((1,), -1), eos_idxs, torch.full((1,), seq_length - 1)), dim=-1).diff(dim=-1)

     # Create the corresponding position_ids and seq_idx
-    position_ids = torch.stack(
-        [
-            torch.cat(
-                [torch.arange(s, dtype=torch.int32) for s in sl],
-                dim=0,
-            )
-            for sl in seq_lens
-        ],
+    position_ids = torch.cat(
+        [torch.arange(s, dtype=torch.int32) for s in seq_lens],
         dim=0,
-    )
+    )[None]

     cu_seq_lens_pred = get_cu_seq_lens_from_position_ids(position_ids)
     assert torch.allclose(
         cu_seq_lens_pred,
-        torch.cat(
-            [torch.tensor([[0]], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1
-        ),
+        torch.cat([torch.tensor([0], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1),
     )


 def test_seq_idx_from_cu_seq_lens() -> None:
     n_chunks = 5
     max_chunk_len = 64
-    batch_size = 1

-    seq_lens = torch.randint(1, max_chunk_len, size=(batch_size, n_chunks))
-    cu_seq_lens = torch.cat([torch.tensor([[0]]), seq_lens.cumsum(dim=-1)], dim=-1)
+    seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,))
+    cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1)
     seq_idx = torch.cat(
         [
             torch.full(
-                (
-                    batch_size,
-                    n,
-                ),
+                (n,),
                 idx,
                 dtype=torch.int32,
                 device=cu_seq_lens.device,
             )
-            for idx, n in enumerate(seq_lens[0])
+            for idx, n in enumerate(seq_lens)
         ],
         dim=-1,
-    )
+    )[None]
     seq_idx_pred = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
     assert torch.allclose(seq_idx_pred, seq_idx)
