27 changes: 22 additions & 5 deletions lm_engine/distributed.py
@@ -30,8 +30,7 @@
from .containers import ModelContainer
from .enums import Kernel
from .gradient_checkpointing import apply_gradient_checkpointing
from .hf_models import CausalLMOutputWithPast
from .hf_models.parameter import _ALL_MARKERS
from .hf_models import CausalLMOutputWithPast, is_parameter_initialized
from .kernels import is_kernel_allowed
from .utils import (
Accelerator,
@@ -120,13 +119,13 @@ def _get_fsdp_mixed_precision(
return mixed_precision


def _get_parameter_marker_maps(model_container: ModelContainer) -> list[dict]:
def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]:
marker_maps = []
for model in model_container:
marker_maps.append({})
for param_name, param in model.named_parameters():
marker_maps[-1][param_name] = {}
for marker in _ALL_MARKERS:
for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers:
marker_maps[-1][param_name][marker] = getattr(param, marker, False)

return marker_maps
@@ -222,7 +221,18 @@ def wrap_model_container_for_distributed_training(
**args.distributed_args.gradient_checkpointing_args,
)

marker_maps = _get_parameter_marker_maps(model_container)
if efficient_initialization:
for model in model_container:
for param_name, parameter in model.named_parameters():
parameter._is_initialized = False

for param_name, parameter in model.named_buffers():
parameter._is_initialized = False

marker_maps = _get_parameter_marker_maps(model_container)
else:
marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"])

accelerator = Accelerator.get_accelerator()

if accelerator == Accelerator.tpu:
@@ -382,6 +392,13 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
pipeline_stages = []
pipeline_schedule = None

for model in model_container:
for param_name, parameter in model.named_parameters():
assert is_parameter_initialized(parameter), f"{param_name} is not initialized"

for param_name, parameter in model.named_buffers():
assert is_parameter_initialized(parameter), f"{param_name} is not initialized"

if num_pipeline_stages > 1:
micro_batch_size = args.training_parameters.micro_batch_size
sequence_length = args.datasets[0].class_args.get("sequence_length")
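For readers unfamiliar with the marker mechanism, here is a minimal, self-contained sketch of the bookkeeping this hunk relies on. It is an illustration only, not the repository's implementation; the real code uses mark_parameter_as_initialized / is_parameter_initialized from lm_engine.hf_models, and _get_parameter_marker_maps snapshots the same attributes as plain booleans per parameter name, presumably so they remain available after the parameters are wrapped for distributed training.

import torch
import torch.nn as nn

def mark_all_uninitialized(model: nn.Module) -> None:
    # Under efficient (deferred) initialization, flag every parameter and buffer as pending.
    for tensor in list(model.parameters()) + list(model.buffers()):
        tensor._is_initialized = False

def mark_initialized(tensor: torch.Tensor) -> None:
    tensor._is_initialized = True

def assert_all_initialized(model: nn.Module) -> None:
    # Mirrors the assertion pass added before pipeline setup above.
    for name, tensor in list(model.named_parameters()) + list(model.named_buffers()):
        assert getattr(tensor, "_is_initialized", False), f"{name} is not initialized"

model = nn.Linear(4, 4)
mark_all_uninitialized(model)
model.reset_parameters()      # materialize the weights
for p in model.parameters():
    mark_initialized(p)       # normally done inside each module's reset_parameters
assert_all_initialized(model)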
40 changes: 38 additions & 2 deletions lm_engine/hf_models/config/sequence_mixer.py
@@ -26,7 +26,43 @@ def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "softmax_attention"


class _Mamba2Args(BaseArgs):
class _MultiHeadLatentAttentionArgs(BaseArgs):
sequence_mixer_type: str = "multihead_latent_attention"
num_attention_heads: int | None = None
softmax_dropout: float = 0
dropout: float = 0
add_bias: bool = False
attention_multiplier: float | None = None
sliding_window: int | None = None
query_compression_size: int | None = None
key_value_compression_size: int | None = None
head_dim: int | None = None
normalization_function: str = "layernorm"

def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "multihead_latent_attention"
assert self.num_attention_heads is not None
assert self.query_compression_size is not None
assert self.key_value_compression_size is not None
assert self.head_dim is not None


class _SoftPlusDecayArgs(BaseArgs):
A_init_min: float = 0
A_init_max: float = 16
dt_init_min: float = 0.001
dt_init_max: float = 0.1
dt_init_floor: float = 1e-4

def model_post_init(self, __context: Any) -> None:
assert self.A_init_min >= 0
assert self.A_init_min <= self.A_init_max
assert self.dt_init_min <= self.dt_init_max


class _Mamba2Args(_SoftPlusDecayArgs):
sequence_mixer_type: str = "mamba2"
state_size: int = 128
intermediate_size: int
@@ -91,7 +127,7 @@ def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "causal_convolution"


class _GatedDeltaNetArgs(BaseArgs):
class _GatedDeltaNetArgs(_SoftPlusDecayArgs):
sequence_mixer_type: str = "gated_deltanet"
k_head_dim: int
v_head_dim: int
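As a side note, the refactor above simply hoists the shared decay-gate init bounds into a common base class. A tiny stand-alone sketch of the same pattern with plain dataclasses (illustrative only; BaseArgs and its validation hooks are the repository's own):

from dataclasses import dataclass

@dataclass
class SoftPlusDecayFields:
    # Shared softplus decay-gate init bounds, mirroring _SoftPlusDecayArgs above.
    A_init_min: float = 0.0
    A_init_max: float = 16.0
    dt_init_min: float = 0.001
    dt_init_max: float = 0.1
    dt_init_floor: float = 1e-4

    def __post_init__(self) -> None:
        assert 0 <= self.A_init_min <= self.A_init_max
        assert self.dt_init_min <= self.dt_init_max

@dataclass
class Mamba2Fields(SoftPlusDecayFields):
    # mamba2 and gated_deltanet configs both inherit the bounds instead of redefining them.
    state_size: int = 128

print(Mamba2Fields(A_init_max=8.0))  # Mamba2Fields(A_init_min=0.0, A_init_max=8.0, ...)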
7 changes: 0 additions & 7 deletions lm_engine/hf_models/mixins/dense/base.py
@@ -38,10 +38,6 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi

self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks])

def _init_weights(self, module: nn.Module) -> None:
if hasattr(module, "reset_parameters"):
module.reset_parameters()

# FIXME typing
def prepare_inputs_for_model(
self,
@@ -118,9 +114,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.position_embedding_type = config.position_embedding_type
self._setup_positional_encoding()

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: torch.Tensor | None = None,
3 changes: 0 additions & 3 deletions lm_engine/hf_models/mixins/dense/main.py
@@ -37,9 +37,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:

self.m_width = config.m_width

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> ParameterizedEmbedding:
return self.transformer.wte

3 changes: 0 additions & 3 deletions lm_engine/hf_models/mixins/dense_TP/base.py
@@ -92,9 +92,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.position_embedding_type = config.position_embedding_type
self._setup_positional_encoding()

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: torch.Tensor | None = None,
3 changes: 0 additions & 3 deletions lm_engine/hf_models/mixins/dense_TP/main.py
@@ -52,9 +52,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:

self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: torch.Tensor | list[list[int]] | None = None,
12 changes: 6 additions & 6 deletions lm_engine/hf_models/model_conversion/granitemoehybrid.py
@@ -176,11 +176,11 @@ def _import_granitemoehybrid_state_dict(
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias"] = (
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.in_proj.bias")
)
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.dt_bias"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.mamba.dt_bias"
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias"] = (
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.dt_bias")
)
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.A_log"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.mamba.A_log"
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log"] = (
safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.A_log")
)
state_dict[f"transformer.h.{layer_idx}.sequence_mixer.D"] = safetensors_weights_manager.get_tensor(
f"model.layers.{layer_idx}.mamba.D"
@@ -404,10 +404,10 @@ def _export_granitemoehybrid_state_dict(
f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias"
)
state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.sequence_mixer.dt_bias"
f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias"
)
state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.sequence_mixer.A_log"
f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log"
)
state_dict[f"model.layers.{layer_idx}.mamba.D"] = safetensors_weights_manager.get_tensor(
f"transformer.h.{layer_idx}.sequence_mixer.D"
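The only mapping change in this converter is that dt_bias and A_log now live under the decay_gate submodule. If one needed to migrate an already-exported lm_engine checkpoint to the new layout, a hypothetical helper (not part of this PR) could be as simple as:

import re

def remap_decay_gate_keys(state_dict: dict) -> dict:
    # Rename transformer.h.{i}.sequence_mixer.{dt_bias,A_log} -> ...sequence_mixer.decay_gate.{dt_bias,A_log}
    pattern = re.compile(r"^(transformer\.h\.\d+\.sequence_mixer\.)(dt_bias|A_log)$")
    remapped = {}
    for key, value in state_dict.items():
        match = pattern.match(key)
        remapped[f"{match.group(1)}decay_gate.{match.group(2)}" if match else key] = value
    return remapped

old = {"transformer.h.0.sequence_mixer.A_log": 1, "transformer.h.0.sequence_mixer.D": 2}
print(remap_decay_gate_keys(old))
# {'transformer.h.0.sequence_mixer.decay_gate.A_log': 1, 'transformer.h.0.sequence_mixer.D': 2}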
120 changes: 120 additions & 0 deletions lm_engine/hf_models/modeling_utils/decay_gate.py
@@ -0,0 +1,120 @@
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate

from ...dtensors import tensor_to_dtensor
from ..parameter import (
mark_parameter_as_initialized,
mark_parameter_as_mup_learning_rate,
mark_parameter_as_no_weight_decay,
)
from .linear import ParameterizedLinear


class SoftplusDecayGate(nn.Module):
def __init__(
self,
hidden_size: int | None,
output_size: int,
std: float | None,
has_projection: bool = False,
A_init_min: float = 0,
A_init_max: float = 16,
dt_init_min: float = 1e-3,
dt_init_max: float = 0.1,
dt_init_floor: float = 1e-4,
) -> SoftplusDecayGate:
super().__init__()

self.output_size = output_size
self.has_projection = has_projection

if has_projection:
self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std)
mark_parameter_as_mup_learning_rate(self.proj.weight)
else:
assert hidden_size is None

self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
mark_parameter_as_no_weight_decay(self.A_log)

self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
mark_parameter_as_no_weight_decay(self.dt_bias)

assert A_init_min >= 0
assert A_init_max >= A_init_min

self.A_init_min = A_init_min
self.A_init_max = A_init_max

assert dt_init_min > 0
assert dt_init_max >= dt_init_min

self.dt_init_min = dt_init_min
self.dt_init_max = dt_init_max
self.dt_init_floor = dt_init_floor

self.reset_parameters()

def forward(
self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32
) -> torch.Tensor:
if self.has_projection:
x = self.proj(x)

x = x.float()
x = x + self.dt_bias
x = F.softplus(x)
x = -self.A_log.float().exp() * x

if final_exponential:
x = torch.exp(x)

x = x.to(output_dtype)

return x

@torch.no_grad()
def reset_parameters(self) -> None:
A = torch.empty(self.output_size, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max)

if isinstance(self.A_log, DTensor):
A = tensor_to_dtensor(
tensor=A,
device_mesh=self.A_log.device_mesh,
current_placement=[Replicate()] * len(self.A_log.placements),
desired_placement=self.A_log.placements,
)

self.A_log.copy_(torch.log(A))

dt = torch.exp(
torch.rand(self.output_size) * (math.log(self.dt_init_max) - math.log(self.dt_init_min))
+ math.log(self.dt_init_min)
)
dt = torch.clamp(dt, min=self.dt_init_floor)
# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
inv_dt = dt + torch.log(-torch.expm1(-dt))

if isinstance(self.dt_bias, DTensor):
inv_dt = tensor_to_dtensor(
tensor=inv_dt,
device_mesh=self.dt_bias.device_mesh,
current_placement=[Replicate()] * len(self.dt_bias.placements),
desired_placement=self.dt_bias.placements,
)

self.dt_bias.copy_(inv_dt)

mark_parameter_as_initialized(self.A_log)
mark_parameter_as_initialized(self.dt_bias)
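A quick usage sketch of the new module, assuming the package is importable as lm_engine; the shapes and the has_projection=False path follow the Mamba2-style usage where dt already comes out of in_proj:

import torch
from lm_engine.hf_models.modeling_utils.decay_gate import SoftplusDecayGate

gate = SoftplusDecayGate(hidden_size=None, output_size=8, std=None, has_projection=False)

dt = torch.randn(2, 16, 8)                   # (batch, sequence, num_heads)
decay = gate(dt, final_exponential=True)     # exp(-exp(A_log) * softplus(dt + dt_bias))
assert decay.shape == dt.shape
assert torch.all((decay > 0) & (decay < 1))  # per-step decay factor in (0, 1)

The inverse-softplus step in reset_parameters is what makes the bounds meaningful: since softplus(dt + log(-expm1(-dt))) == dt, initializing dt_bias this way puts softplus(dt_bias) back in [dt_init_min, dt_init_max] (floored at dt_init_floor) when the gate's input is zero.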
14 changes: 0 additions & 14 deletions lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py
@@ -80,14 +80,6 @@ def __init__(
self.in_features = in_features
self.out_features = out_features

self.register_buffer(
"N_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False
)

self.register_buffer(
"K_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False
)

self.reset_parameters()

mark_parameter_as_no_weight_decay(self.bias)
@@ -139,15 +131,9 @@ def reset_parameters(self) -> None:
if hasattr(self, "bias") and self.bias is not None:
self.bias.zero_()

self.N_array.fill_(self.out_features)
self.K_array.fill_(self.in_features)

mark_parameter_as_initialized(self.weight)
mark_parameter_as_initialized(self.bias)

mark_parameter_as_initialized(self.N_array)
mark_parameter_as_initialized(self.K_array)


class MoE(nn.Module):
linear_class = ParameterizedExperts
@@ -104,6 +104,11 @@ def get_sequence_mixer(
init_method=config.init_method,
normalization_function=block.normalization_function,
m_width=config.m_width,
A_init_min=block.A_init_min,
A_init_max=block.A_init_max,
dt_init_min=block.dt_init_min,
dt_init_max=block.dt_init_max,
dt_init_floor=block.dt_init_floor,
num_layers=config.num_layers,
layer_idx=layer_idx,
)
@@ -115,12 +120,19 @@
num_k_heads=block.num_k_heads,
num_v_heads=block.num_v_heads,
use_gate=block.use_gate,
attention_multiplier=block.attention_multiplier,
allow_neg_eigval=block.allow_neg_eigval,
conv_size=block.conv_size,
conv_size=block.kernel_size,
layer_idx=layer_idx,
norm_eps=config.layer_norm_epsilon,
init_method=config.init_method,
initializer_range=config.initializer_range,
m_width=config.m_width,
A_init_min=block.A_init_min,
A_init_max=block.A_init_max,
dt_init_min=block.dt_init_min,
dt_init_max=block.dt_init_max,
dt_init_floor=block.dt_init_floor,
num_layers=config.num_layers,
use_padding_free_transformer=use_padding_free_transformer,
)
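Putting the pieces together, the new kwargs this hunk forwards are exactly the constructor arguments of SoftplusDecayGate. A hypothetical stand-in mixer (not the repository's mamba2 / gated-deltanet code) showing the plumbing, which is also why the checkpoint keys above moved under sequence_mixer.decay_gate.*:

import torch.nn as nn
from lm_engine.hf_models.modeling_utils.decay_gate import SoftplusDecayGate

class ToyDecayMixer(nn.Module):
    # Stand-in: only the decay-gate wiring is shown.
    def __init__(self, num_heads: int, A_init_min: float, A_init_max: float,
                 dt_init_min: float, dt_init_max: float, dt_init_floor: float) -> None:
        super().__init__()
        self.decay_gate = SoftplusDecayGate(
            hidden_size=None, output_size=num_heads, std=None, has_projection=False,
            A_init_min=A_init_min, A_init_max=A_init_max,
            dt_init_min=dt_init_min, dt_init_max=dt_init_max, dt_init_floor=dt_init_floor,
        )

block_args = {"A_init_min": 0.0, "A_init_max": 16.0, "dt_init_min": 1e-3, "dt_init_max": 0.1, "dt_init_floor": 1e-4}
mixer = ToyDecayMixer(num_heads=8, **block_args)
print(mixer.decay_gate.dt_bias.shape)  # torch.Size([8])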