Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions fla/layers/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -170,11 +171,21 @@ def __init__(

# time step projection (discretization)
# instantiate once and copy inv_dt in init_weights of PretrainedModel
self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
# hard coded for now
dt_init_floor = 1e-4
dt = torch.exp(
torch.rand(self.num_heads) * (
math.log(self.time_step_max) - math.log(self.time_step_min)
) + math.log(self.time_step_min)
)
dt = torch.clamp(dt, min=dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))
self.dt_bias = nn.Parameter(inv_dt)

# S4D real initialization. These are not discretized!
# The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
A = torch.arange(1, self.num_heads + 1)
A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
self.A_log = nn.Parameter(torch.log(A))
Comment on lines +188 to 189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The initialization of A is changed here, but the corresponding _init_weights method in Mamba2PreTrainedModel (in fla/models/mamba2/modeling_mamba2.py) still uses the old logic (torch.arange(1, module.num_heads + 1)). Since _init_weights is called after the layer's __init__ when creating a model, it will overwrite this new initialization, making the fix ineffective. The logic in _init_weights for A_log needs to be updated to match this change.

self.A_log._no_weight_decay = True
self.norm = RMSNormGated(
Expand Down
2 changes: 1 addition & 1 deletion fla/models/mamba2/modeling_mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _init_weights(
if isinstance(module, Mamba2):

# --- A_log ---
A = torch.arange(1, module.num_heads + 1)
A = torch.empty(module.num_heads, dtype=torch.float32).uniform_(0, 16)
with torch.no_grad():
if not isinstance(module.A_log, DTensor):
module.A_log.copy_(torch.log(A))
Expand Down