|
18 | 18 | import unittest
|
19 | 19 |
|
20 | 20 | import pytest
|
| 21 | +from pytest import mark |
21 | 22 |
|
22 | 23 | from transformers import AutoTokenizer, BambaConfig, is_torch_available
|
| 24 | +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
23 | 25 | from transformers.models.bamba.modular_bamba import get_cu_seq_lens_from_position_ids, get_seq_idx_from_cu_seq_lens
|
24 | 26 | from transformers.testing_utils import (
|
| 27 | + require_flash_attn, |
25 | 28 | require_torch,
|
| 29 | + require_torch_gpu, |
26 | 30 | slow,
|
27 | 31 | torch_device,
|
28 | 32 | )
|
@@ -482,6 +486,90 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
|
482 | 486 | # They should result in very similar logits
|
483 | 487 | torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
484 | 488 |
|
| 489 | + @require_flash_attn |
| 490 | + @require_torch_gpu |
| 491 | + @mark.flash_attn_test |
| 492 | + def test_attn_mask_position_ids_flash_attn_equality(self): |
| 493 | + r""" |
| 494 | + Verify that the logits agree when using an attention mask, position_ids, or |
| 495 | + FlashAttentionKwargs. |
| 496 | + """ |
| 497 | + torch.manual_seed(42) |
| 498 | + decoder_only_classes = [] |
| 499 | + for model_class in self.all_generative_model_classes: |
| 500 | + config, _, _, _ = self.model_tester.prepare_config_and_inputs() |
| 501 | + if config.is_encoder_decoder: |
| 502 | + continue |
| 503 | + else: |
| 504 | + decoder_only_classes.append(model_class) |
| 505 | + if len(decoder_only_classes) == 0: |
| 506 | + self.skipTest(reason="No decoder-only architecture available for this model.") |
| 507 | + |
| 508 | + # Decoder-only architectures derived from encoder-decoder models could support padding-free inputs in theory, |
| 509 | + # but we haven't added support for them yet, so we skip these models for now. |
| 510 | + has_encoder_attributes = any( |
| 511 | + attr_name |
| 512 | + for attr_name in config.to_dict().keys() |
| 513 | + if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size" |
| 514 | + ) |
| 515 | + if has_encoder_attributes: |
| 516 | + self.skipTest( |
| 517 | + reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding." |
| 518 | + ) |
| 519 | + |
| 520 | + for model_class in decoder_only_classes: |
| 521 | + config, input_ids, input_mask, _ = self.model_tester.prepare_config_and_inputs() |
| 522 | + # Padding-free requires the model to be in training mode and attn_implementation="flash_attention_2" |
| 523 | + model = ( |
| 524 | + model_class._from_config(config, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16) |
| 525 | + .to(torch_device) |
| 526 | + .train() |
| 527 | + ) |
| 528 | + |
| 529 | + non_padding_free_inputs = {"input_ids": input_ids, "attention_mask": input_mask} |
| 530 | + attn_mask_logits = model(**non_padding_free_inputs).logits |
| 531 | + |
| 532 | + # Build up padding-free tensors |
| 533 | + padding_free_input_ids = torch.cat( |
| 534 | + [batch[mask.bool()] for batch, mask in zip(input_ids, input_mask)], dim=-1 |
| 535 | + )[None] |
| 536 | + position_ids_list = [ |
| 537 | + torch.arange(mask.sum(), device=mask.device, dtype=torch.int32) for mask in input_mask |
| 538 | + ] |
| 539 | + position_ids = torch.cat(position_ids_list, dim=-1)[None] |
| 540 | + seq_lens = torch.cat( |
| 541 | + [torch.tensor([t.numel()], device=input_mask.device, dtype=torch.int32) for t in position_ids_list], |
| 542 | + dim=-1, |
| 543 | + ) |
| 544 | + cu_seq_lens = torch.cat( |
| 545 | + [ |
| 546 | + torch.tensor([0], device=input_mask.device, dtype=torch.int32), |
| 547 | + seq_lens.cumsum(dim=-1, dtype=torch.int32), |
| 548 | + ], |
| 549 | + dim=-1, |
| 550 | + ) |
| 551 | + |
| 552 | + position_ids_inputs = {"input_ids": padding_free_input_ids, "position_ids": position_ids} |
| 553 | + position_ids_logits = model(**position_ids_inputs).logits |
| 554 | + |
| 555 | + flash_attn_kwargs = FlashAttentionKwargs( |
| 556 | + cu_seq_lens_q=cu_seq_lens, |
| 557 | + cu_seq_lens_k=cu_seq_lens, |
| 558 | + max_length_q=input_ids.shape[-1], |
| 559 | + max_length_k=input_ids.shape[-1], |
| 560 | + ) |
| 561 | + flash_attn_kwargs_logits = model(input_ids=padding_free_input_ids, **flash_attn_kwargs).logits |
| 562 | + |
| 563 | + attn_mask_logits_reshaped = torch.cat( |
| 564 | + [batch[mask.bool()] for batch, mask in zip(attn_mask_logits, input_mask)], dim=0 |
| 565 | + )[None] |
| 566 | + |
| 567 | + torch.testing.assert_close(position_ids_logits, attn_mask_logits_reshaped) |
| 568 | + # A higher tolerance is needed for the position_ids and FlashAttentionKwargs logits to |
| 569 | + # match, for unknown reasons. |
| 570 | + torch.testing.assert_close(position_ids_logits, flash_attn_kwargs_logits, atol=1e-3, rtol=1e-1) |
| 572 | + |
485 | 573 |
|
486 | 574 | @slow
|
487 | 575 | @require_torch
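Note on the new test: it feeds the same model three equivalent ways — padded `input_ids` plus `attention_mask`, a packed padding-free row plus `position_ids`, and the packed row plus `FlashAttentionKwargs` — and asserts that the logits agree. As a minimal, self-contained sketch of the packing step (toy values only, not the test's fixtures), the construction looks roughly like this:

```python
import torch

# Toy padded batch (illustrative values, not the test fixtures).
input_ids = torch.tensor([[11, 12, 13, 0], [21, 22, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Pack the unpadded tokens of every row into a single batch-of-one row.
packed_ids = torch.cat([row[m.bool()] for row, m in zip(input_ids, attention_mask)])[None]

# position_ids restart at 0 for each original sequence.
position_ids = torch.cat([torch.arange(m.sum(), dtype=torch.int32) for m in attention_mask])[None]

# Cumulative sequence lengths delimit the packed sequences: [0, len_0, len_0 + len_1, ...].
seq_lens = attention_mask.sum(dim=-1).to(torch.int32)
cu_seq_lens = torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(dim=-1, dtype=torch.int32)])

print(packed_ids)    # tensor([[11, 12, 13, 21, 22]])
print(position_ids)  # tensor([[0, 1, 2, 0, 1]], dtype=torch.int32)
print(cu_seq_lens)   # tensor([0, 3, 5], dtype=torch.int32)
```

These `cu_seq_lens` offsets are what the test passes as `cu_seq_lens_q`/`cu_seq_lens_k` in `FlashAttentionKwargs`.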
|
@@ -598,60 +686,41 @@ def test_simple_batched_generate_with_padding(self):
|
598 | 686 | def test_cu_seq_lens_from_position_ids() -> None:
|
599 | 687 | seq_length = 256
|
600 | 688 | chunks_per_batch = 4
|
601 | | - batch_size = 1 |
602 | 689 |
|
603 | 690 | # Split each batch into `chunks_per_batch` sequences.
|
604 | | - eos_idxs = ( |
605 | | - torch.stack([torch.randperm(seq_length) for _ in range(batch_size)], dim=0)[:, : chunks_per_batch - 1] |
606 | | - .sort(dim=-1) |
607 | | - .values |
608 | | - ) |
609 | | - seq_lens = torch.cat( |
610 | | - (torch.full((batch_size, 1), -1), eos_idxs, torch.full((batch_size, 1), seq_length - 1)), dim=-1 |
611 | | - ).diff(dim=-1) |
| 691 | + eos_idxs = torch.randperm(seq_length)[: chunks_per_batch - 1].sort(dim=-1).values |
| 692 | + seq_lens = torch.cat((torch.full((1,), -1), eos_idxs, torch.full((1,), seq_length - 1)), dim=-1).diff(dim=-1) |
612 | 693 |
|
613 | 694 | # Create the corresponding position_ids and seq_idx
|
614 | | - position_ids = torch.stack( |
615 | | - [ |
616 | | - torch.cat( |
617 | | - [torch.arange(s, dtype=torch.int32) for s in sl], |
618 | | - dim=0, |
619 | | - ) |
620 | | - for sl in seq_lens |
621 | | - ], |
| 695 | + position_ids = torch.cat( |
| 696 | + [torch.arange(s, dtype=torch.int32) for s in seq_lens], |
622 | 697 | dim=0,
|
623 | | - ) |
| 698 | + )[None] |
624 | 699 |
|
625 | 700 | cu_seq_lens_pred = get_cu_seq_lens_from_position_ids(position_ids)
|
626 | 701 | assert torch.allclose(
|
627 | 702 | cu_seq_lens_pred,
|
628 | | - torch.cat( |
629 | | - [torch.tensor([[0]], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1 |
630 | | - ), |
| 703 | + torch.cat([torch.tensor([0], dtype=seq_lens.dtype, device=seq_lens.device), seq_lens.cumsum(dim=-1)], dim=-1), |
631 | 704 | )
|
632 | 705 |
|
633 | 706 |
|
634 | 707 | def test_seq_idx_from_cu_seq_lens() -> None:
|
635 | 708 | n_chunks = 5
|
636 | 709 | max_chunk_len = 64
|
637 | | - batch_size = 1 |
638 | 710 |
|
639 | | - seq_lens = torch.randint(1, max_chunk_len, size=(batch_size, n_chunks)) |
640 | | - cu_seq_lens = torch.cat([torch.tensor([[0]]), seq_lens.cumsum(dim=-1)], dim=-1) |
| 711 | + seq_lens = torch.randint(1, max_chunk_len, size=(n_chunks,)) |
| 712 | + cu_seq_lens = torch.cat([torch.tensor([0]), seq_lens.cumsum(dim=-1)], dim=-1) |
641 | 713 | seq_idx = torch.cat(
|
642 | 714 | [
|
643 | 715 | torch.full(
|
644 | | - ( |
645 | | - batch_size, |
646 | | - n, |
647 | | - ), |
| 716 | + (n,), |
648 | 717 | idx,
|
649 | 718 | dtype=torch.int32,
|
650 | 719 | device=cu_seq_lens.device,
|
651 | 720 | )
|
652 | | - for idx, n in enumerate(seq_lens[0]) |
| 721 | + for idx, n in enumerate(seq_lens) |
653 | 722 | ],
|
654 | 723 | dim=-1,
|
655 | | - ) |
| 724 | + )[None] |
656 | 725 | seq_idx_pred = get_seq_idx_from_cu_seq_lens(cu_seq_lens)
|
657 | 726 | assert torch.allclose(seq_idx_pred, seq_idx)
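Note on the helper tests: they pin down a simple contract — `get_cu_seq_lens_from_position_ids` should return the packed-sequence offsets `[0, *seq_lens.cumsum()]`, and `get_seq_idx_from_cu_seq_lens` should expand those offsets back into a per-token sequence index. A rough reference sketch of that relationship (hypothetical re-implementations for illustration, not the actual `modular_bamba` code) could be:

```python
import torch

def cu_seq_lens_from_position_ids_sketch(position_ids: torch.Tensor) -> torch.Tensor:
    """Sketch: assumes batch size 1 and that each packed sequence's position_ids restart at 0."""
    pos = position_ids[0]
    starts = torch.nonzero(pos == 0).flatten().to(torch.int32)
    return torch.cat([starts, torch.tensor([pos.numel()], dtype=torch.int32)])

def seq_idx_from_cu_seq_lens_sketch(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    """Sketch: expand cumulative offsets into the index of the sequence each token belongs to."""
    seq_lens = cu_seq_lens.diff()
    return torch.cat(
        [torch.full((n,), idx, dtype=torch.int32) for idx, n in enumerate(seq_lens.tolist())]
    )[None]

# Two packed sequences of lengths 3 and 2.
position_ids = torch.tensor([[0, 1, 2, 0, 1]], dtype=torch.int32)
print(cu_seq_lens_from_position_ids_sketch(position_ids))        # tensor([0, 3, 5], dtype=torch.int32)
print(seq_idx_from_cu_seq_lens_sketch(torch.tensor([0, 3, 5])))  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)
```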
|