
Conversation

@mayank31398 commented Feb 8, 2026

The initialization for Mamba2 in the FLA repo is incorrect.
This PR fixes the issue.
After the fix, Mamba2 outperforms GDN at the 7B MoE scale (1B active params).
The difference between the wrong and the fixed init is significant.

Training curves: https://wandb.ai/mayank31398/mamba-test?nw=nwusermayank31398&panelDisplayName=train/lm_loss&panelSectionName=train

Summary by CodeRabbit

  • Refactor
    • Model initialization now uses stochastic per-head discretization parameters instead of a fixed deterministic sequence, increasing initialization variability across heads.
    • This is an internal change only: public interfaces, control flow, and external behavior are unchanged. May slightly affect training/run-to-run variability but requires no action from users.

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@gemini-code-assist bot commented

Summary of Changes

Hello @mayank31398, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical initialization error within the Mamba2 implementation in the FLA repository. By correcting the dt_bias and A parameter initialization, the Mamba2 model's performance is significantly improved, demonstrating superior results compared to GDN at a 7B MoE scale. This change is crucial for the model's stability and effectiveness.

Highlights

  • Mamba2 Initialization Fix: Corrected an incorrect initialization within the Mamba2 implementation in the FLA repository, specifically for the dt_bias and A parameters.
  • Improved dt_bias Initialization: The dt_bias parameter is now initialized by sampling a per-head time step between time_step_min and time_step_max, clamping it, and applying an inverse softplus, replacing the simple torch.ones initialization.
  • Revised A Parameter Initialization: The A parameter is now initialized with values uniformly sampled between 0 and 16, instead of an arithmetic progression.
  • Performance Improvement: This initialization fix leads to Mamba2 outperforming GDN at a 7B MoE scale (1B active parameters), highlighting the significant impact of the change.


Changelog
  • fla/layers/mamba2.py
    • Imported the math module to support new initialization calculations.
    • Modified the initialization of self.dt_bias to sample a per-head time step between time_step_min and time_step_max, clamp it, and apply an inverse softplus, replacing the simple torch.ones initialization.
    • Changed the initialization of the A parameter from torch.arange(1, self.num_heads + 1) to torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16); a combined sketch of both changes follows.
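
For quick reference, here is a minimal, self-contained sketch of the new per-head initialization described in the changelog above. The head count and the constants time_step_min=0.001, time_step_max=0.1, and dt_init_floor=1e-4 are assumed defaults for illustration only; the real values come from the layer's configuration, and the inverse-softplus line uses the standard identity x = y + log(-expm1(-y)) for y = softplus(x).

    import math

    import torch
    import torch.nn as nn

    num_heads = 8                              # illustrative head count
    time_step_min, time_step_max = 0.001, 0.1  # assumed defaults
    dt_init_floor = 1e-4                       # assumed default floor

    # Sample dt log-uniformly in [time_step_min, time_step_max) per head.
    dt = torch.exp(
        torch.rand(num_heads) * (math.log(time_step_max) - math.log(time_step_min))
        + math.log(time_step_min)
    )
    dt = torch.clamp(dt, min=dt_init_floor)

    # Inverse softplus, so that softplus(dt_bias) recovers dt at the start of training.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    dt_bias = nn.Parameter(inv_dt)

    # A is now drawn uniformly from [0, 16) per head (previously arange(1, num_heads + 1)).
    A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16)
    A_log = nn.Parameter(torch.log(A))
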
Activity
  • No human activity (comments, reviews, or progress updates) has been recorded for this pull request yet.

@coderabbitai bot commented Feb 8, 2026

Walkthrough

Modified the per-head state-space discretization initialization in Mamba2: dt_bias is now computed from a random per-head dt (clamped and passed through an inverse softplus), and A_log is now initialized from a random uniform per-head A in [0, 16) and then log-transformed. Changes occur in fla/layers/mamba2.py and fla/models/mamba2/modeling_mamba2.py.

Changes

  • Mamba2 Layer Init (fla/layers/mamba2.py): Replaced the fixed dt_bias initialization with a stochastic per-head computation: sample dt per head, clamp with a floor, transform via inverse softplus, and store as a parameter. Replaced the deterministic A_log init with a random uniform A ∈ [0, 16), then log-transformed.
  • Mamba2 Model Weight Init (fla/models/mamba2/modeling_mamba2.py): Changed the weight-init source for A_log from a deterministic arange(1, num_heads + 1) to a random uniform A in [0, 16), then log-transformed and conditionally copied into module.A_log as before.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Poem

🐰 I nudged each head with jitter and cheer,
A tiny dt, a log, a hop so near.
Softplus tucked the bias tight,
Random A danced into the night.
🥕—a rabbit applauds the new delight.

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title "[MAMBA2] fix initialization for mamba2" clearly describes the main change in the PR, which is fixing the initialization of the dt_bias and A_log parameters in the Mamba2 class. It is specific and directly related to the changeset.


No actionable comments were generated in the recent review. 🎉



@gemini-code-assist bot left a comment

Code Review

This pull request aims to correct the initialization of the Mamba2 layer. The changes to the initialization of dt_bias and A_log are a good step forward. However, I've identified a couple of issues. A magic number has been introduced for dt_init_floor, which should ideally be configurable. More importantly, there's a critical inconsistency in the A_log initialization between the Mamba2 layer and the Mamba2PreTrainedModel, which will cause the fix to be ineffective when initializing a new model. Please see my detailed comments below.

Comment on lines +184 to 185
A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A))
critical

The initialization of A is changed here, but the corresponding _init_weights method in Mamba2PreTrainedModel (in fla/models/mamba2/modeling_mamba2.py) still uses the old logic (torch.arange(1, module.num_heads + 1)). Since _init_weights is called after the layer's __init__ when creating a model, it will overwrite this new initialization, making the fix ineffective. The logic in _init_weights for A_log needs to be updated to match this change.
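
To make the requested change concrete, below is a sketch of what the A_log branch of _init_weights could look like once it mirrors the layer init. The helper name and the bare copy are assumptions for illustration, not the actual FLA code; the real method keeps its existing conditional around the copy, which is not reproduced here.

    import torch

    def init_a_log(module) -> None:
        # Hypothetical helper mirroring the A_log branch of _init_weights:
        # draw A uniformly in [0, 16) per head, matching the layer's own init,
        # instead of the old deterministic torch.arange(1, module.num_heads + 1).
        A = torch.empty(module.num_heads, dtype=torch.float32).uniform_(0, 16)
        with torch.no_grad():
            # The real method applies this copy conditionally; the guard is omitted here.
            module.A_log.copy_(torch.log(A))
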

@coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@fla/layers/mamba2.py`:
- Around lines 174-180: The long single-line expression that computes dt, i.e. dt = torch.exp(torch.rand(self.num_heads) * (math.log(self.time_step_max) - math.log(self.time_step_min)) + math.log(self.time_step_min)), should be split into smaller statements for PEP 8 readability: compute the random tensor (rand = torch.rand(self.num_heads)), the log range (log_range = math.log(self.time_step_max) - math.log(self.time_step_min)), and the log offset (log_min = math.log(self.time_step_min)), then compute dt = torch.exp(rand * log_range + log_min), followed by the existing clamp, inv_dt computation, and nn.Parameter assignment to self.dt_bias, preserving behavior and the same symbols (dt, rand, log_range, log_min, dt_init_floor, inv_dt, self.dt_bias). A concrete split is sketched below.
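
Applied literally, the suggestion corresponds to roughly the following method-body excerpt (a sketch, not a verified diff; self is the Mamba2 layer, and math, torch, and nn are assumed to be imported as in the file):

    # Inside Mamba2.__init__, after time_step_min, time_step_max and dt_init_floor are set.
    rand = torch.rand(self.num_heads)
    log_range = math.log(self.time_step_max) - math.log(self.time_step_min)
    log_min = math.log(self.time_step_min)
    dt = torch.exp(rand * log_range + log_min)   # log-uniform sample per head
    dt = torch.clamp(dt, min=dt_init_floor)      # existing clamp
    inv_dt = dt + torch.log(-torch.expm1(-dt))   # existing inverse softplus
    self.dt_bias = nn.Parameter(inv_dt)          # existing parameter assignment
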

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@mayank31398 (Author) commented

One big concern: if a user is training any model with FSDP-2, all parameters are DTensors, so the FLA init logic will skip them.
I don't know how to fix this.

@coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@fla/models/mamba2/modeling_mamba2.py`:
- Line 203: The CI lint failed on the formatting of the assignment to A (the line A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) in modeling_mamba2.py). Run the repository pre-commit hooks or autopep8 (e.g. pre-commit run --files fla/models/mamba2/modeling_mamba2.py, or autopep8 --in-place) so the file passes the style checks, then amend the commit.
- Line 203: The code creates A with torch.empty(self.num_heads, ...), but self.num_heads doesn't exist here; use the config property instead, i.e. A = torch.empty(self.config.num_heads, dtype=torch.float32).uniform_(0, 16), consistent with line 217, which already uses self.config.num_heads (update the creation of A in modeling_mamba2.py, within the Mamba2PreTrainedModel-derived method; see the sketch below).
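
For reference, the fix for the second point amounts to a one-line change (a sketch; the rest of _init_weights is untouched, and self here is the pretrained-model instance, which exposes the head count through its config):

    # The model object has no num_heads attribute of its own; read it from the config.
    A = torch.empty(self.config.num_heads, dtype=torch.float32).uniform_(0, 16)
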

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>
@mayank31398 (Author) commented

AFAIK this PR still doesn't fix the case where meta init is done with DTensors. I am not sure how to fix that.

@sustcsonglin (Collaborator) commented

Quoting @mayank31398: "One big concern: if a user is training any model with FSDP-2, all parameters are DTensors, so the FLA init logic will skip them. I don't know how to fix this."

@yzhangcs I remember you fixed this before. Any thoughts?

@sustcsonglin (Collaborator) commented

@mayank31398 thanks for your PR! Could you fix the lint errors?

Signed-off-by: Mayank Mishra <mayank31398@gmail.com>