
Add padding-free to bamba #35861

Open · garrett361 wants to merge 4 commits into main from bamba-hf-padding-free-pr

Conversation

@garrett361 commented Jan 23, 2025

What does this PR do?

Adds padding-free training to the BambaModel, enabling more efficient training with causal masking between disjoint sequences.

Performance: approximately a 2x throughput improvement over naive padding for supervised fine-tuning on the Tulu v3 dataset with open-instruct. Tokens/sec/GPU for batch_size_per_gpu = 4:

  • 8 A100s: 600 --> 1200 tok/s/GPU
  • 32 A100s: 450 --> 750 tok/s/GPU

CC @fabianlim

CC reviewers of #34982: @ArthurZucker @molbap

Notes on Code

  • Code changes only affect the mamba layers (BambaMixer). BambaAttention layers are untouched.
  • The padding-free path is only supported on CUDA and requires the mamba kernels.
  • Supports both the position_ids and FlashAttentionKwargs padding-free code paths; see the input-format sketch after this list.
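
For reference, here is a minimal sketch of what a padding-free batch looks like and what the mamba kernels conceptually expect. The tensor values are illustrative only; the cu_seq_lens/seq_idx names just mirror the conventions discussed in this PR.

```python
import torch

# Two sequences (lengths 3 and 5) packed into a single row with no padding tokens.
# Padding-free batches mark sequence boundaries by restarting position_ids at 0.
input_ids = torch.tensor([[11, 12, 13, 21, 22, 23, 24, 25]])
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])

# The FlashAttention-style kwargs describe the same boundaries via cumulative
# sequence lengths, while the mamba kernels want a per-token sequence index.
cu_seq_lens = torch.tensor([0, 3, 8], dtype=torch.int32)
seq_idx = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]], dtype=torch.int32)
```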

Notes on Tests

On both the latest main and this PR branch, the following tests in tests/models/bamba/test_modeling_bamba.py are failing (with RUN_SLOW=1):

BambaModelTest::test_eager_matches_fa2_generate
BambaModelTest::test_flash_attention_2_padding_matches_padding_free_with_position_ids
BambaModelTest::test_sdpa_can_compile_dynamic
BambaModelTest::test_torchscript_output_attentions
BambaModelTest::test_torchscript_output_hidden_state
BambaModelTest::test_torchscript_simple
BambaModelIntegrationTest::test_simple_generate
  • The test_eager_matches_fa2_generate test seems flaky: sometimes it passes, other times it fails.
  • For test_flash_attention_2_padding_matches_padding_free_with_position_ids:
    • On main, this test fails because padding-free is not implemented.
    • On this PR branch, the test fails because this PR only uses position_ids when model.training = True, while the test explicitly calls eval() on the model. I have checked that the test passes when model.training = True. Edit: see also BambaModelTest::test_attn_mask_position_ids_flash_attn_equality.
  • test_simple_generate appears to just need a simple edit for its expected text. It consistently fails with:
AssertionError: '<|be[35 chars]on this lovely evening? I hope you are all doing well. I am' != '<|be[35 chars]on this lovely evening? I hope you are all having a good time.'

where the generated and expected text differ only at the very end.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from cdaf1e6 to eab1ae1 Compare January 23, 2025 20:48
@garrett361 garrett361 closed this Jan 23, 2025
@garrett361 garrett361 reopened this Jan 24, 2025
@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from eab1ae1 to c4874af Compare January 24, 2025 14:53
@Rocketknight1
Member

cc @ArthurZucker for bamba, but let me know if you want me to take a look since it seems like quite an extensive PR!

@garrett361
Author

it seems like quite an extensive PR!

I don't think it's very many changes, ultimately! Basically it just adds two helper functions so that position_ids and FlashAttentionKwargs get properly converted to the seq_idx arg that mamba expects:

  • get_cu_seq_lens_from_position_ids
  • get_seq_idx_from_cu_seq_lens

So, basically the above, making sure **kwargs get passed everywhere they should, and a little code cleanup.
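
For illustration, a minimal sketch of what those two helpers could do. The actual signatures and implementations in the PR may differ; this version assumes sequence boundaries are marked by position_ids restarting at 0.

```python
import torch


def get_cu_seq_lens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Recover cumulative sequence lengths from packed position_ids.

    Assumes (sketch, not the PR's exact code) that a new sequence starts
    wherever position_ids resets to 0.
    """
    pos = position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()
    total = torch.tensor([pos.numel()], device=pos.device)
    return torch.cat([starts, total]).to(torch.int32)


def get_seq_idx_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    """Expand cumulative sequence lengths into a per-token sequence index."""
    seq_lens = (cu_seq_lens[1:] - cu_seq_lens[:-1]).long()
    seq_ids = torch.arange(seq_lens.numel(), device=cu_seq_lens.device, dtype=torch.int32)
    return torch.repeat_interleave(seq_ids, seq_lens)


# Example: two packed sequences of lengths 3 and 5.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4]])
cu_seq_lens = get_cu_seq_lens_from_position_ids(position_ids)  # tensor([0, 3, 8])
seq_idx = get_seq_idx_from_cu_seq_lens(cu_seq_lens)            # tensor([0, 0, 0, 1, 1, 1, 1, 1])
```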

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from 49d007c to d35bcc6 Compare January 24, 2025 20:04
@garrett361
Author

Added a commit with BambaModelTest::test_attn_mask_position_ids_flash_attn_equality which tests the various code paths against each other.

@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from d35bcc6 to 7a9e343 Compare January 24, 2025 20:47
@garrett361 garrett361 force-pushed the bamba-hf-padding-free-pr branch from 9577fb4 to 0534b8f Compare January 24, 2025 21:37