diff --git a/fla/layers/mamba2.py b/fla/layers/mamba2.py
index fd43ac63e2..dac6f1426c 100644
--- a/fla/layers/mamba2.py
+++ b/fla/layers/mamba2.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
+import math
 import warnings
 from typing import TYPE_CHECKING
 
@@ -170,11 +171,21 @@ def __init__(
 
         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
-        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+        # hard coded for now
+        dt_init_floor = 1e-4
+        dt = torch.exp(
+            torch.rand(self.num_heads) * (
+                math.log(self.time_step_max) - math.log(self.time_step_min)
+            ) + math.log(self.time_step_min)
+        )
+        dt = torch.clamp(dt, min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        self.dt_bias = nn.Parameter(inv_dt)
 
         # S4D real initialization. These are not discretized!
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
-        A = torch.arange(1, self.num_heads + 1)
+        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
         self.norm = RMSNormGated(
diff --git a/fla/models/mamba2/modeling_mamba2.py b/fla/models/mamba2/modeling_mamba2.py
index eddc6da8f8..06855b66dc 100644
--- a/fla/models/mamba2/modeling_mamba2.py
+++ b/fla/models/mamba2/modeling_mamba2.py
@@ -200,7 +200,7 @@ def _init_weights(
 
         if isinstance(module, Mamba2):
             # --- A_log ---
-            A = torch.arange(1, module.num_heads + 1)
+            A = torch.empty(module.num_heads, dtype=torch.float32).uniform_(0, 16)
             with torch.no_grad():
                 if not isinstance(module.A_log, DTensor):
                     module.A_log.copy_(torch.log(A))
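
Note on the `dt_bias` change: the added `inv_dt` line is the closed-form inverse of softplus, so that `softplus(dt_bias)` reproduces the sampled `dt` at initialization. Since softplus(x) = log(1 + e^x) and inv_dt = dt + log(1 − e^(−dt)), we get e^(inv_dt) = e^dt − 1, hence softplus(inv_dt) = dt exactly. A minimal standalone sketch checking this; `num_heads`, `time_step_min`, and `time_step_max` are placeholder values here, as the real ones come from the layer config:

```python
# Standalone check (not part of the patch): the inverse-softplus trick.
# softplus(x) = log(1 + exp(x)); with x = dt + log(1 - exp(-dt)) we get
# exp(x) = exp(dt) - 1, so softplus(x) = dt exactly.
import math

import torch
import torch.nn.functional as F

num_heads = 8                               # placeholder values; the real
time_step_min, time_step_max = 1e-3, 1e-1   # ones come from the layer config
dt_init_floor = 1e-4

# Sample dt log-uniformly in [time_step_min, time_step_max], as the patch does.
dt = torch.exp(
    torch.rand(num_heads) * (math.log(time_step_max) - math.log(time_step_min))
    + math.log(time_step_min)
)
dt = torch.clamp(dt, min=dt_init_floor)

# Inverse of softplus, so that softplus(dt_bias) recovers dt at init.
inv_dt = dt + torch.log(-torch.expm1(-dt))
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)
```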
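
Note on the `A_log` change: the deterministic ramp `arange(1, num_heads + 1)` is replaced by values drawn uniformly from (0, 16), in the spirit of the official Mamba2 initialization, which also samples `A` randomly at init; `A_log` still stores `log(A)`. A quick sketch of the per-head decay this implies, assuming the usual Mamba2 convention that the layer uses `-exp(A_log)` as the continuous-time state matrix (that convention is not shown in this diff):

```python
# Sketch of the decay behaviour implied by the new A initialization.
# Assumption (not shown in this diff): Mamba2 uses a = -exp(A_log) as the
# continuous-time state matrix and discretizes the decay as exp(dt * a).
import torch

num_heads = 8                                # placeholder
A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16)
A_log = torch.log(A)                         # what the patch stores

dt = torch.full((num_heads,), 1e-2)          # placeholder time steps
decay = torch.exp(dt * -torch.exp(A_log))    # == exp(-dt * A), in (0, 1)
# Larger sampled A -> smaller decay factor -> faster forgetting per step.
# Note: uniform_(0, 16) can draw A near 0, giving decay ~= 1 (long memory),
# unlike the old arange init whose smallest rate was A = 1.
print(decay)
```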