From 0b54cf778e8d85facf1c6cbe01a96e1fa0eea916 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 12:11:45 -0800
Subject: [PATCH 1/4] fix init

Signed-off-by: Mayank Mishra
---
 fla/layers/mamba2.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/fla/layers/mamba2.py b/fla/layers/mamba2.py
index fd43ac63e2..85bd7fcbfc 100644
--- a/fla/layers/mamba2.py
+++ b/fla/layers/mamba2.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import math
 import warnings
 from typing import TYPE_CHECKING
 
@@ -170,11 +171,17 @@ def __init__(
 
         # time step projection (discretization)
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
-        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+        # hard coded for now
+        dt_init_floor = 1e-4
+        dt = torch.exp(torch.rand(self.num_heads) * (math.log(self.time_step_max) - math.log(self.time_step_min)) + math.log(self.time_step_min))
+        dt = torch.clamp(dt, min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        self.dt_bias = nn.Parameter(inv_dt)
 
         # S4D real initialization. These are not discretized!
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
-        A = torch.arange(1, self.num_heads + 1)
+        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
         self.A_log = nn.Parameter(torch.log(A))
         self.A_log._no_weight_decay = True
         self.norm = RMSNormGated(

From 1364b38d6877ce26908b8fa9f5fbc7f2f9e8b485 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 12:17:53 -0800
Subject: [PATCH 2/4] fix init

Signed-off-by: Mayank Mishra
---
 fla/models/mamba2/modeling_mamba2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fla/models/mamba2/modeling_mamba2.py b/fla/models/mamba2/modeling_mamba2.py
index eddc6da8f8..ca7173b3a2 100644
--- a/fla/models/mamba2/modeling_mamba2.py
+++ b/fla/models/mamba2/modeling_mamba2.py
@@ -200,7 +200,7 @@ def _init_weights(
 
         if isinstance(module, Mamba2):
             # --- A_log ---
-            A = torch.arange(1, module.num_heads + 1)
+            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
             with torch.no_grad():
                 if not isinstance(module.A_log, DTensor):
                     module.A_log.copy_(torch.log(A))

From 5eb0fb593b1a937e216e3a5077d6f4e2ed2efb1e Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 12:27:35 -0800
Subject: [PATCH 3/4] fix init

Signed-off-by: Mayank Mishra
---
 fla/models/mamba2/modeling_mamba2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fla/models/mamba2/modeling_mamba2.py b/fla/models/mamba2/modeling_mamba2.py
index ca7173b3a2..06855b66dc 100644
--- a/fla/models/mamba2/modeling_mamba2.py
+++ b/fla/models/mamba2/modeling_mamba2.py
@@ -200,7 +200,7 @@ def _init_weights(
 
         if isinstance(module, Mamba2):
             # --- A_log ---
-            A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
+            A = torch.empty(module.num_heads, dtype=torch.float32).uniform_(0, 16)
             with torch.no_grad():
                 if not isinstance(module.A_log, DTensor):
                     module.A_log.copy_(torch.log(A))

From 4c01f1da22e791f333d726276770ae655e37c144 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Mon, 9 Feb 2026 01:05:44 -0800
Subject: [PATCH 4/4] fix linter

Signed-off-by: Mayank Mishra
---
 fla/layers/mamba2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/fla/layers/mamba2.py b/fla/layers/mamba2.py
index 85bd7fcbfc..dac6f1426c 100644
--- a/fla/layers/mamba2.py
+++ b/fla/layers/mamba2.py
@@ -173,7 +173,11 @@ def __init__(
         # instantiate once and copy inv_dt in init_weights of PretrainedModel
         # hard coded for now
         dt_init_floor = 1e-4
-        dt = torch.exp(torch.rand(self.num_heads) * (math.log(self.time_step_max) - math.log(self.time_step_min)) + math.log(self.time_step_min))
+        dt = torch.exp(
+            torch.rand(self.num_heads) * (
+                math.log(self.time_step_max) - math.log(self.time_step_min)
+            ) + math.log(self.time_step_min)
+        )
         dt = torch.clamp(dt, min=dt_init_floor)
         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
         inv_dt = dt + torch.log(-torch.expm1(-dt))
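
Note on the initialization these patches introduce: the sketch below is a minimal standalone illustration, not part of the patch series. It assumes placeholder values for num_heads, time_step_min, time_step_max, and dt_init_floor (the real ones come from the layer config), samples dt log-uniformly and applies the inverse-softplus from PATCH 1/4, checks that softplus(dt_bias) recovers dt at init, and draws A ~ Uniform(0, 16) for A_log as in PATCH 2/4 and 3/4.

    import math

    import torch
    import torch.nn.functional as F

    # Placeholder config values for illustration only (assumptions, not from the patches).
    num_heads = 8
    time_step_min, time_step_max = 1e-3, 1e-1
    dt_init_floor = 1e-4

    # Sample dt log-uniformly in [time_step_min, time_step_max], then floor it (PATCH 1/4).
    dt = torch.exp(
        torch.rand(num_heads) * (math.log(time_step_max) - math.log(time_step_min))
        + math.log(time_step_min)
    )
    dt = torch.clamp(dt, min=dt_init_floor)

    # Inverse of softplus, so that softplus(dt_bias) reproduces dt at initialization.
    inv_dt = dt + torch.log(-torch.expm1(-dt))
    assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)

    # S4D-style real init for A_log: A ~ Uniform(0, 16), store log(A) (PATCH 2/4 and 3/4).
    A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16)
    A_log = torch.log(A)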