From 86db0ea0a8712dc9e29704f8682593fb29ea31ad Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 04:17:08 -0800 Subject: [PATCH 01/24] add GDN efficient init Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/gated_deltanet.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 21affa061..19962481f 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -12,7 +12,10 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Replicate +from ....dtensors import tensor_to_dtensor from ....utils import divide_if_divisible, is_fla_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay @@ -223,8 +226,18 @@ def forward( return o + @torch.no_grad() def reset_parameters(self) -> None: A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) + + if isinstance(self.A_log, DTensor): + A = tensor_to_dtensor( + tensor=A, + device_mesh=self.A_log.device_mesh, + current_placement=[Replicate()] * len(self.A_log.placements), + desired_placement=self.A_log.placements, + ) + self.A_log.copy_(torch.log(A)) # hard coded for now @@ -236,6 +249,14 @@ def reset_parameters(self) -> None: # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) + if isinstance(self.dt_bias, DTensor): + inv_dt = tensor_to_dtensor( + tensor=inv_dt, + device_mesh=self.dt_bias.device_mesh, + current_placement=[Replicate()] * len(self.dt_bias.placements), + desired_placement=self.dt_bias.placements, + ) + self.dt_bias.copy_(inv_dt) mark_parameter_as_initialized(self.A_log) From db411cf329bd97c10334db87eef17f4e3f8c649f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 04:24:59 -0800 Subject: [PATCH 02/24] fix mamba2 init Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/mamba2.py | 57 +++++++++---------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index e4ab4f0cb..b646bf900 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -305,34 +305,31 @@ def _torch_forward( # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] # NOTE: S = 1 actually here - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - # B -> (B, G, 1, ssm_state_size / num_groups) - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - # B -> (B, G, N / G, ssm_state_size / num_groups) - B = B.reshape(batch_size, -1, B.shape[-1]) - # B -> (B, N, ssm_state_size / num_groups) - - # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size / num_groups) + B, C = [i.reshape(batch_size, self.n_groups, -1)[..., None, :] for i in (B, C)] + # B, C -> (B, G, 1, ssm_state_size) + B, C = [ + i.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, 
i.shape[-1]).contiguous() + for i in (B, C) + ] + # B, C -> (B, G, N / G, ssm_state_size) + B, C = [i.reshape(batch_size, -1, i.shape[-1]) for i in (B, C)] + # B, C -> (B, N, ssm_state_size) + + # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size) + # B is same as k and is shared across heads and dt is used to expand it dB = dt[..., None] * B[..., None, :] - # dB -> (B, N, head_dim, ssm_state_size / num_groups) + # dB -> (B, N, head_dim, ssm_state_size) # Discretize x into dB hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) # hidden_states -> (B, N, head_dim) dBx = (dB * hidden_states[..., None]).to(device=cache_device) - # dBx -> (B, N, head_dim, ssm_state_size / num_groups) + # dBx -> (B, N, head_dim, ssm_state_size) # State calculation ssm_state = ssm_state * dA + dBx cache_params.update(ssm_state=ssm_state, num_tokens_added=seq_len, layer_idx=self.layer_idx) - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - ssm_state = ssm_state.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_state.view( @@ -351,7 +348,7 @@ def _torch_forward( y = y.reshape(batch_size, -1)[:, None, ...] else: # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) + dt = F.softplus(dt + self.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() @@ -502,7 +499,8 @@ def _cuda_forward( dt_softplus=True, ) hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) - hidden_states = self.norm(hidden_states, gate) + hidden_states = hidden_states * F.silu(gate) + hidden_states = self.norm(hidden_states) # 4. Final linear projection out = self.out_proj(hidden_states)[:, None, ...] 
@@ -602,20 +600,21 @@ def _cuda_forward( @torch.no_grad() def reset_parameters(self) -> None: - A = torch.log(torch.arange(1, self.num_heads + 1)) + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log.copy_(torch.log(A)) - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp(torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) - self.A_log.copy_(A) + self.dt_bias.copy_(inv_dt) nn.init.ones_(self.D) - nn.init.ones_(self.dt_bias) mark_parameter_as_initialized(self.A_log) mark_parameter_as_initialized(self.D) From 06ce95f9758f0563394c0060ae5ac325289f55a2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 04:33:59 -0800 Subject: [PATCH 03/24] pass args Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index cbe4c1b4e..96f1a3ecd 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -139,12 +139,14 @@ def get_sequence_mixer( num_k_heads=block.num_k_heads, num_v_heads=block.num_v_heads, use_gate=block.use_gate, + attention_multiplier=block.attention_multiplier, allow_neg_eigval=block.allow_neg_eigval, - conv_size=block.conv_size, + conv_size=block.kernel_size, layer_idx=layer_idx, norm_eps=config.layer_norm_epsilon, init_method=config.init_method, initializer_range=config.initializer_range, + m_width=config.m_width, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, ) From 3a21404265144c64111bde6ae325ab0fcae05c0b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 16:59:24 -0800 Subject: [PATCH 04/24] hidden_states -> x Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/mamba2.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index b646bf900..85eab46b5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -601,6 +601,15 @@ def _cuda_forward( @torch.no_grad() def reset_parameters(self) -> None: A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + + if isinstance(self.A_log, DTensor): + A = tensor_to_dtensor( + tensor=A, + device_mesh=self.A_log.device_mesh, + current_placement=[Replicate()] * len(self.A_log.placements), + desired_placement=self.A_log.placements, + ) + self.A_log.copy_(torch.log(A)) # hard coded for now @@ -612,6 +621,14 @@ def reset_parameters(self) -> None: # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) + if isinstance(self.dt_bias, DTensor): + inv_dt = tensor_to_dtensor( + tensor=inv_dt, + device_mesh=self.dt_bias.device_mesh, + 
current_placement=[Replicate()] * len(self.dt_bias.placements), + desired_placement=self.dt_bias.placements, + ) + self.dt_bias.copy_(inv_dt) nn.init.ones_(self.D) From b6dd81b25c9abf0e8d97ee12f0eb7610799ea53b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:05:13 -0800 Subject: [PATCH 05/24] hidden_states -> x Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 10 ++++++++ .../sequence_mixer_blocks/__init__.py | 5 ++++ .../sequence_mixer_blocks/gated_deltanet.py | 23 ++++++++++++++----- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index d737d46c3..0e44ceadc 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -124,6 +124,16 @@ class _GatedDeltaNetArgs(BaseArgs): attention_multiplier: float | None = None allow_neg_eigval: bool kernel_size: int + A_init_min: float = 0 + A_init_max: float = 16 + dt_min: float = 0.001 + dt_max: float = 0.1 + dt_init_floor: float = 1e-4 def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "gated_deltanet" + + assert self.A_init_min >= 0 + assert self.A_init_min <= self.A_init_max + + assert self.dt_min <= self.dt_max diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 96f1a3ecd..4fc7b6904 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -147,6 +147,11 @@ def get_sequence_mixer( init_method=config.init_method, initializer_range=config.initializer_range, m_width=config.m_width, + A_init_min=block.A_init_min, + A_init_max=block.A_init_max, + dt_min=block.dt_min, + dt_max=block.dt_max, + dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 19962481f..8fbf455a5 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -47,6 +47,11 @@ def __init__( init_method: str, initializer_range: float, m_width: float | None, + A_init_min: float, + A_init_max: float, + dt_min: float, + dt_max: float, + dt_init_floor: float, num_layers: int, use_padding_free_transformer: bool, ) -> GatedDeltaNet: @@ -67,6 +72,13 @@ def __init__( self.k_head_dim = k_head_dim self.v_head_dim = v_head_dim + self.A_init_min = A_init_min + self.A_init_max = A_init_max + + self.dt_min = dt_min + self.dt_max = dt_max + self.dt_init_floor = dt_init_floor + self.key_dim = self.num_k_heads * self.k_head_dim self.value_dim = self.num_v_heads * self.v_head_dim self.layer_idx = layer_idx @@ -228,7 +240,7 @@ def forward( @torch.no_grad() def reset_parameters(self) -> None: - A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) + A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) if isinstance(self.A_log, DTensor): A = tensor_to_dtensor( @@ -241,11 +253,10 @@ def reset_parameters(self) -> None: self.A_log.copy_(torch.log(A)) # hard coded for now - dt_min = 0.001 - dt_max = 0.1 - dt_init_floor = 1e-4 - dt = 
torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) - dt = torch.clamp(dt, min=dt_init_floor) + dt = torch.exp( + torch.rand(self.num_v_heads) * (math.log(self.dt_max) - math.log(self.dt_min)) + math.log(self.dt_min) + ) + dt = torch.clamp(dt, min=self.dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) From 8dc2c7952d144d028c61e93fe5e013bbbd087cf8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:07:35 -0800 Subject: [PATCH 06/24] hidden_states -> x Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 57 ++++++++++---------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 0e44ceadc..1c497d649 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -49,24 +49,6 @@ def model_post_init(self, __context: Any) -> None: assert self.head_dim is not None -class _Mamba2Args(BaseArgs): - sequence_mixer_type: str = "mamba2" - state_size: int = 128 - intermediate_size: int - num_heads: int = 128 - conv_kernel_size: int = 4 - time_step_limit: tuple[float, float] = (0, float("inf")) - add_bias: bool = False - use_conv_bias: bool = True - activation_function: str = "silu" - num_groups: int = 8 - chunk_size: int = 256 - normalization_function: str | None = "rmsnorm" - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "mamba2" - - class _GRUArgs(BaseArgs): sequence_mixer_type: str = "gru" state_head_dim: int @@ -114,7 +96,20 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "causal_convolution" -class _GatedDeltaNetArgs(BaseArgs): +class _SoftPlusDecayArgs(BaseArgs): + A_init_min: float = 0 + A_init_max: float = 16 + dt_min: float = 0.001 + dt_max: float = 0.1 + dt_init_floor: float = 1e-4 + + def model_post_init(self, __context: Any) -> None: + assert self.A_init_min >= 0 + assert self.A_init_min <= self.A_init_max + assert self.dt_min <= self.dt_max + + +class _GatedDeltaNetArgs(_SoftPlusDecayArgs): sequence_mixer_type: str = "gated_deltanet" k_head_dim: int v_head_dim: int @@ -124,16 +119,24 @@ class _GatedDeltaNetArgs(BaseArgs): attention_multiplier: float | None = None allow_neg_eigval: bool kernel_size: int - A_init_min: float = 0 - A_init_max: float = 16 - dt_min: float = 0.001 - dt_max: float = 0.1 - dt_init_floor: float = 1e-4 def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "gated_deltanet" - assert self.A_init_min >= 0 - assert self.A_init_min <= self.A_init_max - assert self.dt_min <= self.dt_max +class _Mamba2Args(_SoftPlusDecayArgs): + sequence_mixer_type: str = "mamba2" + state_size: int = 128 + intermediate_size: int + num_heads: int = 128 + conv_kernel_size: int = 4 + time_step_limit: tuple[float, float] = (0, float("inf")) + add_bias: bool = False + use_conv_bias: bool = True + activation_function: str = "silu" + num_groups: int = 8 + chunk_size: int = 256 + normalization_function: str | None = "rmsnorm" + + def model_post_init(self, __context: Any) -> None: + assert self.sequence_mixer_type == "mamba2" From 3bec1f004b4fd0d95486de9927f6949bf72e2c1c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:10:51 -0800 Subject: [PATCH 07/24] hidden_states -> x Signed-off-by: Mayank Mishra --- 
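Note on the new module below: SoftplusDecayGate folds together the per-head A_log / dt_bias parameters and the inverse-softplus initialization that patches 01 and 02 set up separately inside GatedDeltaNet and Mamba2. A minimal standalone sketch of the math it implements, using plain torch only (the module itself is not imported here, and the (2, 4, num_heads) input shape is purely illustrative):

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_heads = 8
dt_init_min, dt_init_max, dt_init_floor = 1e-3, 0.1, 1e-4

# dt_bias init: sample dt log-uniformly in [dt_init_min, dt_init_max], clamp to the floor,
# then store the softplus inverse so that softplus(dt_bias) recovers dt at init
dt = torch.exp(torch.rand(num_heads) * (math.log(dt_init_max) - math.log(dt_init_min)) + math.log(dt_init_min))
dt = torch.clamp(dt, min=dt_init_floor)
inv_dt = dt + torch.log(-torch.expm1(-dt))  # inverse of softplus
assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-5)

# A_log init: A ~ U(A_init_min, A_init_max), stored as log(A)
A = torch.empty(num_heads, dtype=torch.float32).uniform_(0, 16)
A_log = torch.log(A)

# forward with has_projection=False and final_exponential=False (the path GDN uses):
# g = -exp(A_log) * softplus(x + dt_bias), a per-head log-decay that is always <= 0
x = torch.randn(2, 4, num_heads)
g = -A_log.exp() * F.softplus(x + inv_dt)
assert (g <= 0).all()
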
.../hf_models/modeling_utils/decay_gate.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 lm_engine/hf_models/modeling_utils/decay_gate.py diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py new file mode 100644 index 000000000..a76824eaf --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/decay_gate.py @@ -0,0 +1,118 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Replicate + +from ...dtensors import tensor_to_dtensor +from ..parameter import ( + mark_parameter_as_initialized, + mark_parameter_as_mup_learning_rate, + mark_parameter_as_no_weight_decay, +) +from .linear import ParameterizedLinear + + +class SoftplusDecayGate(nn.Module): + def __init__( + self, + hidden_size: int, + output_size: int, + std: float | None, + has_projection: bool = False, + A_init_min: float = 0, + A_init_max: float = 16, + dt_init_min: float = 1e-3, + dt_init_max: float = 0.1, + dt_init_floor: float = 1e-4, + ) -> SoftplusDecayGate: + super().__init__() + + self.output_size = output_size + self.has_projection = has_projection + + if has_projection: + self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std) + mark_parameter_as_mup_learning_rate(self.proj.weight) + + self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + + assert A_init_min >= 0 + assert A_init_max >= A_init_min + + self.A_init_min = A_init_min + self.A_init_max = A_init_max + + assert dt_init_min > 0 + assert dt_init_max >= dt_init_min + + self.dt_init_min = dt_init_min + self.dt_init_max = dt_init_max + self.dt_init_floor = dt_init_floor + + self.reset_parameters() + + mark_parameter_as_no_weight_decay(self.A_log) + mark_parameter_as_no_weight_decay(self.dt_bias) + + def forward( + self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + if self.has_projection: + x = self.proj(x) + + x = x.float() + x = x + self.dt_bias + x = F.softplus(x) + x = -self.A_log.float().exp() * x + + if final_exponential: + x = torch.exp(x) + + x = x.to(output_dtype) + + return x + + @torch.no_grad() + def reset_parameters(self) -> None: + A = torch.empty(self.output_size, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) + + if isinstance(self.A_log, DTensor): + A = tensor_to_dtensor( + tensor=A, + device_mesh=self.A_log.device_mesh, + current_placement=[Replicate()] * len(self.A_log.placements), + desired_placement=self.A_log.placements, + ) + + self.A_log.copy_(torch.log(A)) + + dt = torch.exp( + torch.rand(self.output_size) * (math.log(self.dt_init_max) - math.log(self.dt_init_min)) + + math.log(self.dt_init_min) + ) + dt = torch.clamp(dt, min=self.dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + if isinstance(self.dt_bias, DTensor): + inv_dt = tensor_to_dtensor( + tensor=inv_dt, + device_mesh=self.dt_bias.device_mesh, + current_placement=[Replicate()] * len(self.dt_bias.placements), + desired_placement=self.dt_bias.placements, + ) + + self.dt_bias.copy_(inv_dt) + + 
mark_parameter_as_initialized(self.A_log) + mark_parameter_as_initialized(self.dt_bias) From 52b14a334b2cebbb12aa81653887ee256ae75bf6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:20:24 -0800 Subject: [PATCH 08/24] use gate for GDN Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/decay_gate.py | 10 +-- .../sequence_mixer_blocks/gated_deltanet.py | 69 ++++--------------- 2 files changed, 21 insertions(+), 58 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py index a76824eaf..f9f8de738 100644 --- a/lm_engine/hf_models/modeling_utils/decay_gate.py +++ b/lm_engine/hf_models/modeling_utils/decay_gate.py @@ -24,7 +24,7 @@ class SoftplusDecayGate(nn.Module): def __init__( self, - hidden_size: int, + hidden_size: int | None, output_size: int, std: float | None, has_projection: bool = False, @@ -42,9 +42,14 @@ def __init__( if has_projection: self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std) mark_parameter_as_mup_learning_rate(self.proj.weight) + else: + assert hidden_size is None self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.A_log) + self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.dt_bias) assert A_init_min >= 0 assert A_init_max >= A_init_min @@ -61,9 +66,6 @@ def __init__( self.reset_parameters() - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_no_weight_decay(self.dt_bias) - def forward( self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32 ) -> torch.Tensor: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 8fbf455a5..437cea0cb 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -12,14 +12,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Replicate -from ....dtensors import tensor_to_dtensor from ....utils import divide_if_divisible, is_fla_available from ...cache import GenerationCache -from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..normalization import get_normalization_function from .causal_convolution import causal_convolution @@ -49,8 +46,8 @@ def __init__( m_width: float | None, A_init_min: float, A_init_max: float, - dt_min: float, - dt_max: float, + dt_init_min: float, + dt_init_max: float, dt_init_floor: float, num_layers: int, use_padding_free_transformer: bool, @@ -72,13 +69,6 @@ def __init__( self.k_head_dim = k_head_dim self.v_head_dim = v_head_dim - self.A_init_min = A_init_min - self.A_init_max = A_init_max - - self.dt_min = dt_min - self.dt_max = dt_max - self.dt_init_floor = dt_init_floor - self.key_dim = self.num_k_heads * self.k_head_dim self.value_dim = self.num_v_heads * self.v_head_dim self.layer_idx = layer_idx @@ -96,11 +86,17 @@ def __init__( hidden_size, 2 * self.num_v_heads + (self.value_dim if use_gate else 0), bias=False, std=std ) - self.A_log = nn.Parameter(torch.empty(self.num_v_heads, 
dtype=torch.float32)) - mark_parameter_as_no_weight_decay(self.A_log) - - self.dt_bias = nn.Parameter(torch.empty(self.num_v_heads)) - mark_parameter_as_no_weight_decay(self.dt_bias) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_v_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) self.conv_size = conv_size self.qkv_conv1d = ParameterizedConv1d( @@ -171,7 +167,7 @@ def forward( if self.allow_neg_eigval: beta = beta * 2.0 - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = self.decay_gate(x=a, final_exponential=False) if self.use_padding_free_transformer: assert cache_params is None @@ -237,38 +233,3 @@ def forward( o = self.o_proj(o) return o - - @torch.no_grad() - def reset_parameters(self) -> None: - A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) - - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - tensor=A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) - - self.A_log.copy_(torch.log(A)) - - # hard coded for now - dt = torch.exp( - torch.rand(self.num_v_heads) * (math.log(self.dt_max) - math.log(self.dt_min)) + math.log(self.dt_min) - ) - dt = torch.clamp(dt, min=self.dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - if isinstance(self.dt_bias, DTensor): - inv_dt = tensor_to_dtensor( - tensor=inv_dt, - device_mesh=self.dt_bias.device_mesh, - current_placement=[Replicate()] * len(self.dt_bias.placements), - desired_placement=self.dt_bias.placements, - ) - - self.dt_bias.copy_(inv_dt) - - mark_parameter_as_initialized(self.A_log) - mark_parameter_as_initialized(self.dt_bias) From b062d8676b726ca1f075be0af518913937ef4ecd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:23:38 -0800 Subject: [PATCH 09/24] use gate for mamba2 Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/mamba2.py | 47 +++---------------- 1 file changed, 7 insertions(+), 40 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 85eab46b5..cf7e5a384 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -168,23 +168,24 @@ def __init__( # instantiate once and copy inv_dt in init_weights of PretrainedModel # Initialize log dt bias self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.dt_bias) # S4D real initialization. These are not discretized! # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded self.A_log = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.A_log) + mark_parameter_as_mup_learning_rate(self.A_log) + self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) + self.D = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.D) + mark_parameter_as_mup_learning_rate(self.D) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) - mark_parameter_as_no_weight_decay(self.dt_bias) - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_no_weight_decay(self.D) - - mark_parameter_as_mup_learning_rate(self.A_log) - mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) mark_parameter_as_mup_learning_rate(self.in_proj.weight) mark_parameter_as_mup_learning_rate(self.out_proj.weight) @@ -600,39 +601,5 @@ def _cuda_forward( @torch.no_grad() def reset_parameters(self) -> None: - A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) - - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - tensor=A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) - - self.A_log.copy_(torch.log(A)) - - # hard coded for now - dt_min = 0.001 - dt_max = 0.1 - dt_init_floor = 1e-4 - dt = torch.exp(torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - if isinstance(self.dt_bias, DTensor): - inv_dt = tensor_to_dtensor( - tensor=inv_dt, - device_mesh=self.dt_bias.device_mesh, - current_placement=[Replicate()] * len(self.dt_bias.placements), - desired_placement=self.dt_bias.placements, - ) - - self.dt_bias.copy_(inv_dt) - nn.init.ones_(self.D) - - mark_parameter_as_initialized(self.A_log) mark_parameter_as_initialized(self.D) - mark_parameter_as_initialized(self.dt_bias) From 87e7a4432b0adef48eb174b6fde42dcc9682e5e3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:36:17 -0800 Subject: [PATCH 10/24] use gate for mamba2 Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 4 +- .../sequence_mixer_blocks/__init__.py | 9 +++- .../sequence_mixer_blocks/mamba2.py | 46 +++++++++++-------- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 1c497d649..9d128a162 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -99,8 +99,8 @@ def model_post_init(self, __context: Any) -> None: class _SoftPlusDecayArgs(BaseArgs): A_init_min: float = 0 A_init_max: float = 16 - dt_min: float = 0.001 - dt_max: float = 0.1 + dt_init_min: float = 0.001 + dt_init_max: float = 0.1 dt_init_floor: float = 1e-4 def model_post_init(self, __context: Any) -> None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 4fc7b6904..aa8e979e4 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -105,6 +105,11 @@ def get_sequence_mixer( 
init_method=config.init_method, normalization_function=block.normalization_function, m_width=config.m_width, + A_init_min=block.A_init_min, + A_init_max=block.A_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, + dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, layer_idx=layer_idx, ) @@ -149,8 +154,8 @@ def get_sequence_mixer( m_width=config.m_width, A_init_min=block.A_init_min, A_init_max=block.A_init_max, - dt_min=block.dt_min, - dt_max=block.dt_max, + dt_min=block.dt_init_min, + dt_max=block.dt_init_max, dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index cf7e5a384..531b158ff 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -24,6 +24,7 @@ ) from ..activations import get_activation_function from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..mlp_blocks.mlp import _get_std_for_linear from ..normalization import get_normalization_function @@ -115,6 +116,11 @@ def __init__( layer_norm_epsilon: float, initializer_range: float, m_width: float, + A_init_min: float, + A_init_max: float, + dt_init_min: float, + dt_init_max: float, + dt_init_floor: float, init_method: str, normalization_function: str | None, num_layers: int, @@ -162,19 +168,19 @@ def __init__( std=std, ) - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - # Initialize log dt bias - self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.dt_bias) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - self.A_log = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_mup_learning_rate(self.A_log) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) @@ -272,7 +278,7 @@ def _torch_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) + A = -torch.exp(self.decay_gate.A_log.float()) # hidden_states -> B, S, N, head_dim # A -> num_heads @@ -291,7 +297,7 @@ def _torch_forward( # dt -> (B, 1, N) dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) # dt -> (B, N, head_dim) - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + dt_bias = self.decay_gate.dt_bias[..., None].expand(self.decay_gate.dt_bias.shape[0], self.head_dim) dt = F.softplus(dt + dt_bias.to(dt.dtype)) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) @@ -349,7 +355,7 @@ def _torch_forward( y = y.reshape(batch_size, -1)[:, None, ...] 
else: # begin ssd naive implementation without einsums - dt = F.softplus(dt + self.dt_bias) + dt = F.softplus(dt + self.decay_gate.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() @@ -479,10 +485,10 @@ def _cuda_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) # (nheads,) + A = -torch.exp(self.decay_gate.A_log.float()) # (nheads,) A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) - dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + dt_bias = self.decay_gate.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) @@ -507,7 +513,7 @@ def _cuda_forward( out = self.out_proj(hidden_states)[:, None, ...] # Fused calculations or step by step if no initialized cache is found else: - A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + A = -torch.exp(self.decay_gate.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} # 2-4. Fused kernel for conv1d, SSM, and the final projection @@ -516,7 +522,7 @@ def _cuda_forward( projected_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, - self.dt_bias, + self.decay_gate.dt_bias, A, D=self.D, chunk_size=self.chunk_size, @@ -580,7 +586,7 @@ def _cuda_forward( z=None, seq_idx=None, return_final_states=True, - dt_bias=self.dt_bias, + dt_bias=self.decay_gate.dt_bias, dt_softplus=True, **dt_limit_kwargs, ) From d1050e5e814181674cfada372b73b2a92680dc18 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:42:22 -0800 Subject: [PATCH 11/24] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/model_conversion/granitemoehybrid.py | 12 ++++++------ tests/training/params_group/groups/mup.json | 4 ++-- tests/training/params_group/groups/normal.json | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/model_conversion/granitemoehybrid.py b/lm_engine/hf_models/model_conversion/granitemoehybrid.py index 6bae682e8..2f4f2f518 100644 --- a/lm_engine/hf_models/model_conversion/granitemoehybrid.py +++ b/lm_engine/hf_models/model_conversion/granitemoehybrid.py @@ -176,11 +176,11 @@ def _import_granitemoehybrid_state_dict( state_dict[f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias"] = ( safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.in_proj.bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.dt_bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.dt_bias" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.dt_bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.A_log"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.A_log" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.A_log") ) state_dict[f"transformer.h.{layer_idx}.sequence_mixer.D"] = 
safetensors_weights_manager.get_tensor( f"model.layers.{layer_idx}.mamba.D" @@ -404,10 +404,10 @@ def _export_granitemoehybrid_state_dict( f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias" ) state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.dt_bias" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias" ) state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.A_log" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log" ) state_dict[f"model.layers.{layer_idx}.mamba.D"] = safetensors_weights_manager.get_tensor( f"transformer.h.{layer_idx}.sequence_mixer.D" diff --git a/tests/training/params_group/groups/mup.json b/tests/training/params_group/groups/mup.json index f39be3fa9..cb6d202a7 100644 --- a/tests/training/params_group/groups/mup.json +++ b/tests/training/params_group/groups/mup.json @@ -16,7 +16,7 @@ "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias", @@ -42,9 +42,9 @@ "model.transformer.h.1.mlp_block.c_proj.weight", "model.transformer.h.1.mlp_block.c_proj_shared.weight", "model.transformer.h.1.mlp_block.gate.weight", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.weight", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", "model.transformer.h.1.sequence_mixer.in_proj.weight", "model.transformer.h.1.sequence_mixer.out_proj.weight", "model.transformer.h.2.mlp_block.c_fc.weight", diff --git a/tests/training/params_group/groups/normal.json b/tests/training/params_group/groups/normal.json index c14d0f0df..ae8f2a459 100644 --- a/tests/training/params_group/groups/normal.json +++ b/tests/training/params_group/groups/normal.json @@ -35,10 +35,10 @@ "model.transformer.h.1.mlp_block.c_fc_shared.bias", "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias", From 750f9c0131201ad359033d1fdef834c785448906 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:44:57 -0800 Subject: [PATCH 12/24] fix tests Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 62 ++++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 9d128a162..5b22610dc 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -49,6 +49,37 @@ def model_post_init(self, __context: Any) -> None: assert self.head_dim is not None +class 
_SoftPlusDecayArgs(BaseArgs): + A_init_min: float = 0 + A_init_max: float = 16 + dt_init_min: float = 0.001 + dt_init_max: float = 0.1 + dt_init_floor: float = 1e-4 + + def model_post_init(self, __context: Any) -> None: + assert self.A_init_min >= 0 + assert self.A_init_min <= self.A_init_max + assert self.dt_init_min <= self.dt_init_max + + +class _Mamba2Args(_SoftPlusDecayArgs): + sequence_mixer_type: str = "mamba2" + state_size: int = 128 + intermediate_size: int + num_heads: int = 128 + conv_kernel_size: int = 4 + time_step_limit: tuple[float, float] = (0, float("inf")) + add_bias: bool = False + use_conv_bias: bool = True + activation_function: str = "silu" + num_groups: int = 8 + chunk_size: int = 256 + normalization_function: str | None = "rmsnorm" + + def model_post_init(self, __context: Any) -> None: + assert self.sequence_mixer_type == "mamba2" + + class _GRUArgs(BaseArgs): sequence_mixer_type: str = "gru" state_head_dim: int @@ -96,19 +127,6 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "causal_convolution" -class _SoftPlusDecayArgs(BaseArgs): - A_init_min: float = 0 - A_init_max: float = 16 - dt_init_min: float = 0.001 - dt_init_max: float = 0.1 - dt_init_floor: float = 1e-4 - - def model_post_init(self, __context: Any) -> None: - assert self.A_init_min >= 0 - assert self.A_init_min <= self.A_init_max - assert self.dt_min <= self.dt_max - - class _GatedDeltaNetArgs(_SoftPlusDecayArgs): sequence_mixer_type: str = "gated_deltanet" k_head_dim: int @@ -122,21 +140,3 @@ class _GatedDeltaNetArgs(_SoftPlusDecayArgs): def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "gated_deltanet" - - -class _Mamba2Args(_SoftPlusDecayArgs): - sequence_mixer_type: str = "mamba2" - state_size: int = 128 - intermediate_size: int - num_heads: int = 128 - conv_kernel_size: int = 4 - time_step_limit: tuple[float, float] = (0, float("inf")) - add_bias: bool = False - use_conv_bias: bool = True - activation_function: str = "silu" - num_groups: int = 8 - chunk_size: int = 256 - normalization_function: str | None = "rmsnorm" - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "mamba2" From 8e5903e74666b35dfa4021a574003f08fe835256 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:52:10 -0800 Subject: [PATCH 13/24] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 531b158ff..dbd5f3d35 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -180,18 +180,17 @@ def __init__( dt_init_floor=dt_init_floor, ) - mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.D = nn.Parameter(torch.empty(self.num_heads)) mark_parameter_as_no_weight_decay(self.D) - mark_parameter_as_mup_learning_rate(self.D) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) + mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) 
mark_parameter_as_mup_learning_rate(self.in_proj.weight) mark_parameter_as_mup_learning_rate(self.out_proj.weight) From cc76e2449af88ea091be0dab2cb60493f4138a40 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:55:49 -0800 Subject: [PATCH 14/24] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index dbd5f3d35..a0bbdff05 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -9,10 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Replicate -from ....dtensors import tensor_to_dtensor from ....enums import Kernel from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_causal_conv1d_available, is_mamba_2_ssm_available From 60e092a3791202139e7e6434355b76749ead2f28 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:58:45 -0800 Subject: [PATCH 15/24] fix tests Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/mamba2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index a0bbdff05..544ec4393 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -177,15 +177,16 @@ def __init__( dt_init_floor=dt_init_floor, ) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) - self.D = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.D) + + self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + mark_parameter_as_no_weight_decay(self.D) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) From 740a636424c201906413d1d0c2c42f1b82d99197 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:04:26 -0800 Subject: [PATCH 16/24] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/gated_deltanet.py | 2 -- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 437cea0cb..c16a42f3e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -117,8 +117,6 @@ def __init__( std /= math.sqrt(m_width) self.o_proj = ParameterizedLinear(self.value_dim, hidden_size, bias=False, std=std) - self.reset_parameters() - def forward( self, hidden_states: torch.Tensor, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py 
b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 544ec4393..8866f93a3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -177,14 +177,13 @@ def __init__( dt_init_floor=dt_init_floor, ) - self.D = nn.Parameter(torch.empty(self.num_heads)) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + self.D = nn.Parameter(torch.empty(self.num_heads)) mark_parameter_as_no_weight_decay(self.D) mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) From 43e099c99aa1bc7767176fe84f68fb8e07d618f1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:13:01 -0800 Subject: [PATCH 17/24] merge Signed-off-by: Mayank Mishra --- .../tensor_parallel/tensor_parallel_forward_test.py | 8 +++++--- tests/hf_models/single_gpu/typecheck_test.py | 12 +++++++++++- tests/hf_models/test_common.py | 4 ---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 7fc50d93a..1ee3b74fc 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -8,7 +8,7 @@ import torch from parameterized import parameterized -from lm_engine.utils import is_flash_attention_2_available, torch_dtype_to_string +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available, torch_dtype_to_string from ...test_common import TestCommons @@ -17,7 +17,7 @@ class TensorParallelTest(TestCommons): @parameterized.expand( TestCommons.make_args_matrix( TestCommons.get_position_embedding_types(), - TestCommons.get_attention_implementations(), + ["sdpa", "flash_attention_2", "flash_attention_3"], TestCommons.get_dtypes(), [False, True], [False, True], @@ -41,7 +41,9 @@ def test_tensor_parallel_forward( self.skipTest("skipping test since running all takes too long") if attention_implementation == "flash_attention_2" and not is_flash_attention_2_available(): - self.skipTest("skipping test since flash-attn is unavialable") + self.skipTest("skipping test because flash attention 2 is unavailable") + elif attention_implementation == "flash_attention_3" and not is_flash_attention_3_available(): + self.skipTest("skipping test because flash attention 3 is unavailable") if use_padding_free_transformer and attention_implementation != "flash_attention_2": self.skipTest("skipping test since flash attention is needed for padding free transformer") diff --git a/tests/hf_models/single_gpu/typecheck_test.py b/tests/hf_models/single_gpu/typecheck_test.py index ded25bb6a..480c12d32 100644 --- a/tests/hf_models/single_gpu/typecheck_test.py +++ b/tests/hf_models/single_gpu/typecheck_test.py @@ -7,6 +7,7 @@ from lm_engine.enums import Kernel from lm_engine.kernels import enable_kernels +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available from ..test_common import TestCommons @@ -25,5 +26,14 @@ def test_no_attention_mask_flash_attention(self, device: torch.device) -> None: input_ids, _, labels = self.get_dummy_inputs(device, return_list=True) attention_mask = [[1] * len(i) for i in input_ids] - with 
enable_kernels([Kernel.flash_attention_2]): + kernel = None + if is_flash_attention_3_available(): + kernel = Kernel.flash_attention_3 + if is_flash_attention_2_available(): + kernel = Kernel.flash_attention_2 + + if kernel is None: + self.skipTest("skipping test because flash attention 2 or 3 is unavailable") + + with enable_kernels([kernel]): self.assertRaises(AssertionError, model, input_ids=input_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index cbc4152b1..914760684 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -20,10 +20,6 @@ class TestCommons(BaseTestCommons): - @staticmethod - def get_attention_implementations() -> list[str]: - return ["sdpa", "flash_attention_2"] - @staticmethod def get_position_embedding_types() -> list[str]: return ["learned_absolute", "rope"] From d1fd4fb6c06866972d413d42ee0db99c5f7e616f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:13:44 -0800 Subject: [PATCH 18/24] merge Signed-off-by: Mayank Mishra --- .../multi_gpu/tensor_parallel/tensor_parallel_forward_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 1ee3b74fc..3d6da4f19 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -37,6 +37,7 @@ def test_tensor_parallel_forward( if (attention_implementation, dtype) not in [ ("sdpa", torch.float32), ("flash_attention_2", torch.float16), + ("flash_attention_3", torch.float16), ]: self.skipTest("skipping test since running all takes too long") From 913e08d407ad1b06da7e548ba54cdcbeede3a2f0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:19:56 -0800 Subject: [PATCH 19/24] merge Signed-off-by: Mayank Mishra --- .../tensor_parallel_forward.py | 143 +++++++++--------- 1 file changed, 73 insertions(+), 70 deletions(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 431ebba7d..2698cc8a7 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -62,90 +62,93 @@ ], ) -enable_kernels( - [Kernel.scattermoe] + ([Kernel.flash_attention_2] if args.attention_implementation == "flash_attention_2" else []) -).__enter__() +kernels = [Kernel.scattermoe] +if args.attention_implementation == "flash_attention_2": + kernels.append(Kernel.flash_attention_2) +elif args.attention_implementation == "flash_attention_3": + kernels.append(Kernel.flash_attention_3) -if torch.distributed.get_rank() == 0: - with torch.device("meta"): - model = TestCommons.from_config(None, config) - - model = model.to_empty(device=torch.cuda.current_device()) - for _, param in model.named_parameters(): - param.data.normal_(0, 0.0125) +with enable_kernels(kernels): + if torch.distributed.get_rank() == 0: + with torch.device("meta"): + model = TestCommons.from_config(None, config) - model.eval() + model = model.to_empty(device=torch.cuda.current_device()) + for _, param in model.named_parameters(): + param.data.normal_(0, 0.0125) - model.save_pretrained(args.tmp_path, safe_serialization=True) - model = model.to(dtype) + model.eval() -Communication.barrier() + 
model.save_pretrained(args.tmp_path, safe_serialization=True) + model = model.to(dtype) -# use dummy tensors to avoid initializing model here -with torch.device("meta"): - # try sharding vocab matrices if really struggling for memory + Communication.barrier() - model_tp = get_model_parallel_class(config.model_type)._from_config( - config, - use_padding_free_transformer=args.use_padding_free_transformer, - sequence_parallel=args.sequence_parallel, - ) + # use dummy tensors to avoid initializing model here + with torch.device("meta"): + # try sharding vocab matrices if really struggling for memory -# copy to device without copying storage -model_tp = model_tp.to_empty(device=torch.cuda.current_device()) + model_tp = get_model_parallel_class(config.model_type)._from_config( + config, + use_padding_free_transformer=args.use_padding_free_transformer, + sequence_parallel=args.sequence_parallel, + ) -# load weights into tensor parallel model using SafeTensorsWeightsManager class -# this avoids loading multiple copies of the parameters in CPU memory -model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) + # copy to device without copying storage + model_tp = model_tp.to_empty(device=torch.cuda.current_device()) -# set model to eval mode -model_tp = model_tp.to(dtype) -model_tp.eval() + # load weights into tensor parallel model using SafeTensorsWeightsManager class + # this avoids loading multiple copies of the parameters in CPU memory + model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) -set_seed(42) + # set model to eval mode + model_tp = model_tp.to(dtype) + model_tp.eval() -batch_size = 4 -sequence_length = 512 + set_seed(42) -input_ids = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) -labels = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) + batch_size = 4 + sequence_length = 512 -if args.use_padding_free_transformer: - cu_seqlens = torch.arange( - 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + input_ids = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) - position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) - - output_tp = model_tp( - input_ids=input_ids.view(-1), - labels=labels.view(-1), - cu_seqlens=cu_seqlens, - max_seqlen=sequence_length, - position_ids=position_ids, + labels = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) -else: - output_tp = model_tp(input_ids=input_ids, labels=labels) - -loss_tp = output_tp.loss -logits_tp = output_tp.logits[..., : config.vocab_size] - -if torch.distributed.get_rank() == 0: - # loss computation hangs if we don't use dummy tensor parallel world size - with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): - output = model(input_ids=input_ids, labels=labels) - - loss = output.loss - logits = output.logits if args.use_padding_free_transformer: - logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) - - error = (logits - logits_tp).abs().max() - assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" - - error = (loss - loss_tp).abs().max() - assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" + 
cu_seqlens = torch.arange( + 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + ) + position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) + + output_tp = model_tp( + input_ids=input_ids.view(-1), + labels=labels.view(-1), + cu_seqlens=cu_seqlens, + max_seqlen=sequence_length, + position_ids=position_ids, + ) + else: + output_tp = model_tp(input_ids=input_ids, labels=labels) + + loss_tp = output_tp.loss + logits_tp = output_tp.logits[..., : config.vocab_size] + + if torch.distributed.get_rank() == 0: + # loss computation hangs if we don't use dummy tensor parallel world size + with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): + output = model(input_ids=input_ids, labels=labels) + + loss = output.loss + logits = output.logits + + if args.use_padding_free_transformer: + logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) + + error = (logits - logits_tp).abs().max() + assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" + + error = (loss - loss_tp).abs().max() + assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" From 037185f7842c31208b68c0d1fd4207f2ec50cda1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:30:42 -0800 Subject: [PATCH 20/24] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 4dfe572ab..e4d42e081 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -130,8 +130,8 @@ def get_sequence_mixer( m_width=config.m_width, A_init_min=block.A_init_min, A_init_max=block.A_init_max, - dt_min=block.dt_init_min, - dt_max=block.dt_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, From f0312863ba7bb1a7d3236918a27371c0df981cfd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 19:02:09 -0800 Subject: [PATCH 21/24] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/gated_deltanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index c16a42f3e..ed44b2808 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -163,7 +163,7 @@ def forward( beta = b.sigmoid() if self.allow_neg_eigval: - beta = beta * 2.0 + beta = beta * 2 g = self.decay_gate(x=a, final_exponential=False) From 8eb840a10e0e9456d973915767283b59e79fc964 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 19:44:45 -0800 Subject: [PATCH 22/24] drop post_init() Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 7 ------- lm_engine/hf_models/mixins/dense/main.py | 3 --- lm_engine/hf_models/mixins/dense_TP/base.py | 3 --- lm_engine/hf_models/mixins/dense_TP/main.py | 3 --- 4 files changed, 16 deletions(-) diff --git 
a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index a16a4673d..9582ee384 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -38,10 +38,6 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) - def _init_weights(self, module: nn.Module) -> None: - if hasattr(module, "reset_parameters"): - module.reset_parameters() - # FIXME typing def prepare_inputs_for_model( self, @@ -118,9 +114,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.position_embedding_type = config.position_embedding_type self._setup_positional_encoding() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | None = None, diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 32599cb0d..4ec016803 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -37,9 +37,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.m_width = config.m_width - # Initialize weights and apply final processing - self.post_init() - def get_input_embeddings(self) -> ParameterizedEmbedding: return self.transformer.wte diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index e3d53e3c9..700c2b153 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -92,9 +92,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.position_embedding_type = config.position_embedding_type self._setup_positional_encoding() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | None = None, diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 13951bdc5..f614b1cc0 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -52,9 +52,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, From 6271d83f5d5586f72491f0aebb91a3fd0da84371 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:03:08 -0800 Subject: [PATCH 23/24] drop buffer Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/mlp_blocks/moe.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 349e92b51..32bd6751c 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -80,14 +80,6 @@ def __init__( self.in_features = in_features self.out_features = out_features - self.register_buffer( - "N_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False - ) - - self.register_buffer( - "K_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False - ) - self.reset_parameters() mark_parameter_as_no_weight_decay(self.bias) @@ -139,15 +131,9 @@ def reset_parameters(self) -> None: if hasattr(self, "bias") and self.bias is 
not None: self.bias.zero_() - self.N_array.fill_(self.out_features) - self.K_array.fill_(self.in_features) - mark_parameter_as_initialized(self.weight) mark_parameter_as_initialized(self.bias) - mark_parameter_as_initialized(self.N_array) - mark_parameter_as_initialized(self.K_array) - class MoE(nn.Module): linear_class = ParameterizedExperts From 84286a07b3cc1283dc549f531df0d37ac9810b5b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:06:35 -0800 Subject: [PATCH 24/24] fix init check Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 27 ++++++++++++++++++++++----- lm_engine/hf_models/parameter.py | 3 --- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 952d89fb7..8fbd19d9c 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -30,8 +30,7 @@ from .containers import ModelContainer from .enums import Kernel from .gradient_checkpointing import apply_gradient_checkpointing -from .hf_models import CausalLMOutputWithPast -from .hf_models.parameter import _ALL_MARKERS +from .hf_models import CausalLMOutputWithPast, is_parameter_initialized from .kernels import is_kernel_allowed from .utils import ( Accelerator, @@ -120,13 +119,13 @@ def _get_fsdp_mixed_precision( return mixed_precision -def _get_parameter_marker_maps(model_container: ModelContainer) -> list[dict]: +def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]: marker_maps = [] for model in model_container: marker_maps.append({}) for param_name, param in model.named_parameters(): marker_maps[-1][param_name] = {} - for marker in _ALL_MARKERS: + for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers: marker_maps[-1][param_name][marker] = getattr(param, marker, False) return marker_maps @@ -222,7 +221,18 @@ def wrap_model_container_for_distributed_training( **args.distributed_args.gradient_checkpointing_args, ) - marker_maps = _get_parameter_marker_maps(model_container) + if efficient_initialization: + for model in model_container: + for param_name, parameter in model.named_parameters(): + parameter._is_initialized = False + + for param_name, parameter in model.named_buffers(): + parameter._is_initialized = False + + marker_maps = _get_parameter_marker_maps(model_container) + else: + marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"]) + accelerator = Accelerator.get_accelerator() if accelerator == Accelerator.tpu: @@ -382,6 +392,13 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: pipeline_stages = [] pipeline_schedule = None + for model in model_container: + for param_name, parameter in model.named_parameters(): + assert is_parameter_initialized(parameter), f"{param_name} is not initialized" + + for param_name, parameter in model.named_buffers(): + assert is_parameter_initialized(parameter), f"{param_name} is not initialized" + if num_pipeline_stages > 1: micro_batch_size = args.training_parameters.micro_batch_size sequence_length = args.datasets[0].class_args.get("sequence_length") diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index b86da83b2..4ab6e6834 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -5,9 +5,6 @@ import torch.nn as nn -_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"] - - def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: if 
parameter is not None: parameter._no_weight_decay = True
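

For reference, a minimal sketch of the initialization markers that the new pre-wrap check in distributed.py relies on. The bodies of mark_parameter_as_initialized and is_parameter_initialized are not shown in this series, so the sketch assumes they simply set and read a boolean _is_initialized attribute, mirroring mark_parameter_as_no_weight_decay above; it is an illustration under that assumption, not the library's code.

import torch
import torch.nn as nn


def mark_parameter_as_initialized(parameter: nn.Parameter | None) -> None:
    # assumed implementation: flag the tensor once reset_parameters() has run
    if parameter is not None:
        parameter._is_initialized = True


def is_parameter_initialized(parameter: nn.Parameter) -> bool:
    # assumed implementation: anything never marked counts as uninitialized
    return getattr(parameter, "_is_initialized", False)


class TinyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(4, 4))
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self) -> None:
        # initialize in place, then mark the parameter as done
        self.weight.normal_(0, 0.0125)
        mark_parameter_as_initialized(self.weight)


model = TinyBlock()

# mirrors the assertion added before wrapping the model for distributed training
for name, param in model.named_parameters():
    assert is_parameter_initialized(param), f"{name} is not initialized"

Under efficient initialization the same flag is first forced to False on every parameter and buffer, so only modules whose reset_parameters() actually ran will pass this check.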