From 47241702fd839f1c4740a5168649e92417ecbe72 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Tue, 3 Feb 2026 11:22:01 -0800
Subject: [PATCH 01/99] add

Signed-off-by: Mayank Mishra
---
 accelerated-model-architectures  |  2 +-
 lm_engine/arguments.py           |  1 +
 lm_engine/distributed.py         | 20 +++++++++++++++++---
 lm_engine/hf_models/parameter.py |  3 ---
 4 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/accelerated-model-architectures b/accelerated-model-architectures
index 6432eda29..b25904976 160000
--- a/accelerated-model-architectures
+++ b/accelerated-model-architectures
@@ -1 +1 @@
-Subproject commit 6432eda2934e68183f8d7965b04e151e72314fdd
+Subproject commit b25904976adf0559bb5b5ff7e7f59ee8dea6067e
diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py
index 694332d6b..cbe885e41 100644
--- a/lm_engine/arguments.py
+++ b/lm_engine/arguments.py
@@ -57,6 +57,7 @@ def model_post_init(self, __context: Any) -> None:
         if self.model_name is None:
             _check_not_None([(self.pretrained_config, "pretrained_config")])
         else:
+            assert not self.efficient_initialization, "efficient_initialization is not supported with HF models"
             assert self.pretrained_config is None, "pretrained_config shouldn't be specified with model_name"
diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index 952d89fb7..7879f06c5 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -30,8 +30,7 @@
 from .containers import ModelContainer
 from .enums import Kernel
 from .gradient_checkpointing import apply_gradient_checkpointing
-from .hf_models import CausalLMOutputWithPast
-from .hf_models.parameter import _ALL_MARKERS
+from .hf_models import CausalLMOutputWithPast, is_parameter_initialized
 from .kernels import is_kernel_allowed
 from .utils import (
     Accelerator,
@@ -126,7 +125,7 @@ def _get_parameter_marker_maps(model_container: ModelContainer) -> list[dict]:
         marker_maps.append({})
         for param_name, param in model.named_parameters():
             marker_maps[-1][param_name] = {}
-            for marker in _ALL_MARKERS:
+            for marker in ["_no_weight_decay", "_has_mup_learning_rate"]:
                 marker_maps[-1][param_name][marker] = getattr(param, marker, False)
 
     return marker_maps
@@ -222,6 +221,14 @@ def wrap_model_container_for_distributed_training(
             **args.distributed_args.gradient_checkpointing_args,
         )
 
+    if efficient_initialization:
+        for model in model_container:
+            for param_name, parameter in model.named_parameters():
+                del parameter._is_initialized
+
+            for param_name, parameter in model.named_buffers():
+                del parameter._is_initialized
+
     marker_maps = _get_parameter_marker_maps(model_container)
     accelerator = Accelerator.get_accelerator()
 
@@ -382,6 +389,13 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
         pipeline_stages = []
         pipeline_schedule = None
 
+    for model in model_container:
+        for param_name, parameter in model.named_parameters():
+            assert is_parameter_initialized(parameter), f"{param_name} is not initialized"
+
+        for param_name, parameter in model.named_buffers():
+            assert is_parameter_initialized(parameter), f"{param_name} is not initialized"
+
     if num_pipeline_stages > 1:
         micro_batch_size = args.training_parameters.micro_batch_size
         sequence_length = args.datasets[0].class_args.get("sequence_length")
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index b86da83b2..4ab6e6834 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -5,9 +5,6 @@
 import torch.nn as nn
 
 
-_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"]
-
-
 def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None:
     if parameter is not None:
         parameter._no_weight_decay = True

From f49b5b10ae221e4f0b92b326fcee0d5d03831ed2 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 19:27:49 -0800
Subject: [PATCH 02/99] drop post_init()

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/mixins/dense/base.py    | 7 -------
 lm_engine/hf_models/mixins/dense/main.py    | 3 ---
 lm_engine/hf_models/mixins/dense_TP/base.py | 3 ---
 lm_engine/hf_models/mixins/dense_TP/main.py | 3 ---
 4 files changed, 16 deletions(-)

diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py
index a16a4673d..9582ee384 100644
--- a/lm_engine/hf_models/mixins/dense/base.py
+++ b/lm_engine/hf_models/mixins/dense/base.py
@@ -38,10 +38,6 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi
 
         self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks])
 
-    def _init_weights(self, module: nn.Module) -> None:
-        if hasattr(module, "reset_parameters"):
-            module.reset_parameters()
-
     # FIXME typing
     def prepare_inputs_for_model(
         self,
@@ -118,9 +114,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
         self.position_embedding_type = config.position_embedding_type
         self._setup_positional_encoding()
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def forward(
         self,
         input_ids: torch.Tensor | None = None,
diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py
index 32599cb0d..4ec016803 100644
--- a/lm_engine/hf_models/mixins/dense/main.py
+++ b/lm_engine/hf_models/mixins/dense/main.py
@@ -37,9 +37,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
 
         self.m_width = config.m_width
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def get_input_embeddings(self) -> ParameterizedEmbedding:
         return self.transformer.wte
 
diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py
index e3d53e3c9..700c2b153 100644
--- a/lm_engine/hf_models/mixins/dense_TP/base.py
+++ b/lm_engine/hf_models/mixins/dense_TP/base.py
@@ -92,9 +92,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
         self.position_embedding_type = config.position_embedding_type
         self._setup_positional_encoding()
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def forward(
         self,
         input_ids: torch.Tensor | None = None,
diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py
index 13951bdc5..f614b1cc0 100644
--- a/lm_engine/hf_models/mixins/dense_TP/main.py
+++ b/lm_engine/hf_models/mixins/dense_TP/main.py
@@ -52,9 +52,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
 
         self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
     def forward(
         self,
         input_ids: torch.Tensor | list[list[int]] | None = None,

From d4f1451cbf028ecdb803a390a78b3eef9357cdf7 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 19:29:13 -0800
Subject: [PATCH 03/99] add check for init

Signed-off-by: Mayank Mishra
---
 lm_engine/distributed.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index 7879f06c5..a86be9ebf 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -224,10 +224,10 @@ def wrap_model_container_for_distributed_training(
     if efficient_initialization:
         for model in model_container:
             for param_name, parameter in model.named_parameters():
-                del parameter._is_initialized
+                parameter._is_initialized = False
 
             for param_name, parameter in model.named_buffers():
-                del parameter._is_initialized
+                parameter._is_initialized = False
 
     marker_maps = _get_parameter_marker_maps(model_container)
     accelerator = Accelerator.get_accelerator()

From ba585b7178ff3aa6e6e8959a6dcb7070978ce96e Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 20:04:30 -0800
Subject: [PATCH 04/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/distributed.py                   |  6 +-
 .../hf_models/modeling_utils/decay_gate.py | 73 +++++++++++++++++++
 2 files changed, 77 insertions(+), 2 deletions(-)
 create mode 100644 lm_engine/hf_models/modeling_utils/decay_gate.py

diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index a86be9ebf..e0b09b8a1 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -119,13 +119,15 @@ def _get_fsdp_mixed_precision(
     return mixed_precision
 
 
-def _get_parameter_marker_maps(model_container: ModelContainer) -> list[dict]:
+def _get_parameter_marker_maps(
+    model_container: ModelContainer, markers: list[str] = ["_no_weight_decay", "_has_mup_learning_rate"]
+) -> list[dict]:
     marker_maps = []
     for model in model_container:
         marker_maps.append({})
         for param_name, param in model.named_parameters():
             marker_maps[-1][param_name] = {}
-            for marker in ["_no_weight_decay", "_has_mup_learning_rate"]:
+            for marker in markers:
                 marker_maps[-1][param_name][marker] = getattr(param, marker, False)
 
     return marker_maps
diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py
new file mode 100644
index 000000000..a59bdce64
--- /dev/null
+++ b/lm_engine/hf_models/modeling_utils/decay_gate.py
@@ -0,0 +1,73 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+from __future__ import annotations
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..parameter import (
+    mark_parameter_as_initialized,
+    mark_parameter_as_mup_learning_rate,
+    mark_parameter_as_no_weight_decay,
+)
+from .linear import ParameterizedLinear
+
+
+class SoftplusDecayGate(nn.Module):
+    def __init__(
+        self, hidden_size: int, output_size: int, std: float | None, has_projection: bool = False
+    ) -> SoftplusDecayGate:
+        super().__init__()
+
+        self.output_size = output_size
+        self.has_projection = has_projection
+
+        if has_projection:
+            self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std)
+            mark_parameter_as_mup_learning_rate(self.proj.weight)
+
+        self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
+        self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
+
+        self.reset_parameters()
+        mark_parameter_as_no_weight_decay(self.dt_bias)
+
+    def forward(self, x: torch.Tensor, final_exponential: bool) -> torch.Tensor:
+        dtype = x.dtype
+
+        if self.has_projection:
+            x = self.proj(x)
+
+        x = x.float()
+        x = x + self.dt_bias
+        x = F.softplus(x)
+        x = -self.A_log.float().exp() * x
+
+        if final_exponential:
+            x = torch.exp(x)
+
+        x = x.to(dtype)
+
+        return x
+
+    @torch.no_grad()
+    def reset_parameters(self) -> None:
+        A = torch.arange(1, self.output_size + 1, dtype=torch.float32)
+        self.A_log.copy_(torch.log(A))
+
+        dt_min = 0.001
+        dt_max = 0.1
+        dt_init_floor = 1e-4
+        dt = torch.exp(torch.rand(self.output_size) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
+        dt = torch.clamp(dt, min=dt_init_floor)
+
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
+        self.dt_bias.copy_(inv_dt)
+
+        mark_parameter_as_initialized(self.A_log)
+        mark_parameter_as_initialized(self.dt_bias)

From a0f677ceadbd8f20702438c5fbc55b6b7aa9d117 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 20:21:05 -0800
Subject: [PATCH 05/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/distributed.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index e0b09b8a1..8fbd19d9c 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -119,15 +119,13 @@ def _get_fsdp_mixed_precision(
     return mixed_precision
 
 
-def _get_parameter_marker_maps(
-    model_container: ModelContainer, markers: list[str] = ["_no_weight_decay", "_has_mup_learning_rate"]
-) -> list[dict]:
+def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]:
     marker_maps = []
     for model in model_container:
         marker_maps.append({})
         for param_name, param in model.named_parameters():
             marker_maps[-1][param_name] = {}
-            for marker in markers:
+            for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers:
                 marker_maps[-1][param_name][marker] = getattr(param, marker, False)
 
     return marker_maps
@@ -229,7 +227,10 @@ def wrap_model_container_for_distributed_training(
             for param_name, parameter in model.named_buffers():
                 parameter._is_initialized = False
 
-    marker_maps = _get_parameter_marker_maps(model_container)
+        marker_maps = _get_parameter_marker_maps(model_container)
+    else:
+        marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"])
+
     accelerator = Accelerator.get_accelerator()

From 58b9ac0b551d5994b5a4b0478f00218d920a06f7 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 20:37:00 -0800
Subject: [PATCH 06/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/dtensors.py            | 16 ++++++++++++++++
 lm_engine/hf_models/parameter.py |  3 +++
 2 files changed, 19 insertions(+)

diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py
index 7944c77d1..ff57a4406 100644
--- a/lm_engine/dtensors.py
+++ b/lm_engine/dtensors.py
@@ -8,6 +8,8 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import DeviceMesh
 
+from .hf_models.parameter import _ALL_MARKERS
+
 
 def tensor_to_dtensor(
     tensor: torch.Tensor,
@@ -15,6 +17,7 @@ def tensor_to_dtensor(
     current_placement: Placement | list[Placement],
     desired_placement: Placement | list[Placement] | None = None,
     run_check: bool = False,
+    copy_marker: bool = True,
 ) -> DTensor:
     if isinstance(tensor, DTensor):
         return tensor
@@ -30,6 +33,12 @@ def tensor_to_dtensor(
 
     dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True)
 
+    if copy_marker:
+        for marker in _ALL_MARKERS:
+            marker_value = getattr(dtensor, marker, None)
+            if marker_value is not None:
+                setattr(dtensor, marker, marker_value)
+
     return dtensor
 
 
@@ -38,6 +47,7 @@ def dtensor_to_tensor(
     dtensor: DTensor,
     device_mesh: DeviceMesh | None = None,
     desired_placement: Placement | list[Placement] | None = None,
     grad_placement: Placement | list[Placement] | None = None,
+    copy_marker: bool = True,
 ) -> torch.Tensor:
     if not isinstance(dtensor, DTensor):
         return dtensor
@@ -55,6 +65,12 @@ def dtensor_to_tensor(
     tensor = dtensor.to_local(grad_placements=grad_placement)
 
+    if copy_marker:
+        for marker in _ALL_MARKERS:
+            marker_value = getattr(tensor, marker, None)
+            if marker_value is not None:
+                setattr(tensor, marker, marker_value)
+
     return tensor
 
diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py
index 4ab6e6834..b86da83b2 100644
--- a/lm_engine/hf_models/parameter.py
+++ b/lm_engine/hf_models/parameter.py
@@ -5,6 +5,9 @@
 import torch.nn as nn
 
 
+_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"]
+
+
 def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None:
     if parameter is not None:
         parameter._no_weight_decay = True

From 6e6e074cf0b47dec0fb618c7ea48966b23c58a19 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:12:08 -0800
Subject: [PATCH 07/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/dtensors.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py
index ff57a4406..edcf55104 100644
--- a/lm_engine/dtensors.py
+++ b/lm_engine/dtensors.py
@@ -17,7 +17,6 @@ def tensor_to_dtensor(
     current_placement: Placement | list[Placement],
     desired_placement: Placement | list[Placement] | None = None,
     run_check: bool = False,
-    copy_marker: bool = True,
 ) -> DTensor:
     if isinstance(tensor, DTensor):
         return tensor
@@ -33,12 +32,6 @@ def tensor_to_dtensor(
 
     dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True)
 
-    if copy_marker:
-        for marker in _ALL_MARKERS:
-            marker_value = getattr(dtensor, marker, None)
-            if marker_value is not None:
-                setattr(dtensor, marker, marker_value)
-
     return dtensor
 
 
@@ -47,7 +40,6 @@ def dtensor_to_tensor(
     dtensor: DTensor,
     device_mesh: DeviceMesh | None = None,
     desired_placement: Placement | list[Placement] | None = None,
     grad_placement: Placement | list[Placement] | None = None,
-    copy_marker: bool = True,
 ) -> torch.Tensor:
     if not isinstance(dtensor, DTensor):
         return dtensor
@@ -65,12 +57,6 @@ def dtensor_to_tensor(
 
     tensor = dtensor.to_local(grad_placements=grad_placement)
 
-    if copy_marker:
-        for marker in _ALL_MARKERS:
-            marker_value = getattr(tensor, marker, None)
-            if marker_value is not None:
-                setattr(tensor, marker, marker_value)
-
     return tensor
 

From d775fc47af464647cbffb341b9a038f7af0662b0 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:12:44 -0800
Subject: [PATCH 08/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/dtensors.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py
index edcf55104..7944c77d1 100644
--- a/lm_engine/dtensors.py
+++ b/lm_engine/dtensors.py
@@ -8,8 +8,6 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import DeviceMesh
 
-from .hf_models.parameter import _ALL_MARKERS
-
 
 def tensor_to_dtensor(
     tensor: torch.Tensor,

From 81a10f1cd2a6142a1189eabdc94eeeb0538f9cfd Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:13:28 -0800
Subject: [PATCH 09/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/dtensors.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py
index 7944c77d1..ff57a4406 100644
--- a/lm_engine/dtensors.py
+++ b/lm_engine/dtensors.py
@@ -8,6 +8,8 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.distributed.device_mesh import DeviceMesh
 
+from .hf_models.parameter import _ALL_MARKERS
+
 
 def tensor_to_dtensor(
     tensor: torch.Tensor,
@@ -15,6 +17,7 @@ def tensor_to_dtensor(
     current_placement: Placement | list[Placement],
     desired_placement: Placement | list[Placement] | None = None,
     run_check: bool = False,
+    copy_marker: bool = True,
 ) -> DTensor:
     if isinstance(tensor, DTensor):
         return tensor
@@ -30,6 +33,12 @@ def tensor_to_dtensor(
 
     dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True)
 
+    if copy_marker:
+        for marker in _ALL_MARKERS:
+            marker_value = getattr(dtensor, marker, None)
+            if marker_value is not None:
+                setattr(dtensor, marker, marker_value)
+
     return dtensor
 
 
@@ -38,6 +47,7 @@ def dtensor_to_tensor(
     dtensor: DTensor,
     device_mesh: DeviceMesh | None = None,
     desired_placement: Placement | list[Placement] | None = None,
     grad_placement: Placement | list[Placement] | None = None,
+    copy_marker: bool = True,
 ) -> torch.Tensor:
     if not isinstance(dtensor, DTensor):
         return dtensor
@@ -55,6 +65,12 @@ def dtensor_to_tensor(
 
     tensor = dtensor.to_local(grad_placements=grad_placement)
 
+    if copy_marker:
+        for marker in _ALL_MARKERS:
+            marker_value = getattr(tensor, marker, None)
+            if marker_value is not None:
+                setattr(tensor, marker, marker_value)
+
     return tensor
 

From efde5ba815071a7cf3a8b6c18c9f91162280378b Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:17:09 -0800
Subject: [PATCH 10/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils_TP/embedding.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/lm_engine/hf_models/modeling_utils_TP/embedding.py b/lm_engine/hf_models/modeling_utils_TP/embedding.py
index ba1ec9bac..87726c6eb 100644
--- a/lm_engine/hf_models/modeling_utils_TP/embedding.py
+++ b/lm_engine/hf_models/modeling_utils_TP/embedding.py
@@ -13,6 +13,7 @@
 from ...dtensors import dtensor_to_tensor, tensor_to_dtensor
 from ...utils import ProcessGroupManager, divide_if_divisible
 from ..modeling_utils import ParameterizedEmbedding
+from ..parameter import mark_parameter_as_initialized
 from .dtensor_module import DTensorModule
 from .TP import get_module_placements
@@ -35,6 +36,7 @@ def __init__(
             num_embeddings
         )
 
+        self.std = std
         super().__init__(num_embeddings_per_tp_rank, embedding_dim, std=std)
 
         self.weight = nn.Parameter(
@@ -51,6 +53,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement)
         return input
 
+    @torch.no_grad()
+    def reset_parameters(self) -> None:
+        if self.std is None:
+            super().reset_parameters()
+        else:
+            self.weight.normal_(mean=0, std=self.std)
+
+        mark_parameter_as_initialized(self.weight)
+
 
 def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]:
     tp_rank = ProcessGroupManager.get_tensor_parallel_rank()

From 5f94a84a36dcd169129d506f05ed2bd5aed589c5 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:26:24 -0800
Subject: [PATCH 11/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/mixins/dense_TP/base.py  |  8 +-
 .../modeling_utils/dtensor_module.py         | 30 +++++++
 .../hf_models/modeling_utils/embedding.py    | 70 +++++++++++++++-
 .../hf_models/modeling_utils_TP/__init__.py  |  1 -
 .../hf_models/modeling_utils_TP/embedding.py | 79 ------------------
 .../hf_models/modeling_utils_TP/lm_head.py   |  4 +-
 6 files changed, 103 insertions(+), 89 deletions(-)
 create mode 100644 lm_engine/hf_models/modeling_utils/dtensor_module.py
 delete mode 100644 lm_engine/hf_models/modeling_utils_TP/embedding.py

diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py
index 700c2b153..28bedc56b 100644
--- a/lm_engine/hf_models/mixins/dense_TP/base.py
+++ b/lm_engine/hf_models/mixins/dense_TP/base.py
@@ -10,8 +10,8 @@
 from ....utils import ProcessGroupManager, divide_if_divisible
 from ...cache import GenerationCache
 from ...config import CommonConfig
-from ...modeling_utils import Dropout, RoPE, YaRNScaledRoPE
-from ...modeling_utils_TP import Embedding_TP, get_normalization_function_TP
+from ...modeling_utils import Dropout, ParameterizedEmbedding, RoPE, YaRNScaledRoPE
+from ...modeling_utils_TP import get_normalization_function_TP
 from ...utils import is_generation_cache_enabled
 from ..dense import BaseModelMixin, PreTrainedModelMixin
 from ..modeling_outputs import BaseModelOutputWithPast
@@ -54,7 +54,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
         self.layer_end_id = self.layers_per_stage * (self.pipeline_stage_id + 1)
 
         if self.is_first_stage:
-            self.wte = Embedding_TP(
+            self.wte = ParameterizedEmbedding(
                 config.vocab_size,
                 self.embed_dim,
                 std=self.initializer_range,
@@ -168,7 +168,7 @@ def _setup_positional_encoding(self) -> None:
 
         if self.position_embedding_type == "learned_absolute":
             if self.is_first_stage:
-                self.wpe = Embedding_TP(
+                self.wpe = ParameterizedEmbedding(
                     max_position_embeddings,
                     self.embed_dim,
                     std=self.initializer_range,
diff --git a/lm_engine/hf_models/modeling_utils/dtensor_module.py b/lm_engine/hf_models/modeling_utils/dtensor_module.py
new file mode 100644
index 000000000..f6b993747
--- /dev/null
+++ b/lm_engine/hf_models/modeling_utils/dtensor_module.py
@@ -0,0 +1,30 @@
+# **************************************************
+# Copyright (c) 2025, Mayank Mishra
+# **************************************************
+
+from __future__ import annotations
+
+from typing import Any, Mapping
+
+import torch.nn as nn
+
+from ...dtensors import modify_state_dict_to_dtensor_dict
+from ...utils import ProcessGroupManager
+
+
+class DTensorModule(nn.Module):
+    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None:
+        if ProcessGroupManager.is_tensor_parallel_enabled():
+            state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix="", strip_keys=False)
+
+        super().load_state_dict(state_dict, strict, assign)
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ) -> None:
+        if ProcessGroupManager.is_tensor_parallel_enabled():
+            state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix=prefix, strip_keys=True)
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index 1dc9525f3..00529f3ae 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -4,16 +4,64 @@
 
 from __future__ import annotations
 
+import math
+
 import torch
 import torch.nn as nn
+from torch.distributed._tensor.placement_types import Replicate, Shard
 
+from ...dtensors import dtensor_to_tensor, tensor_to_dtensor
+from ...utils import ProcessGroupManager, divide_if_divisible
 from ..parameter import mark_parameter_as_initialized
+from .dtensor_module import DTensorModule
+from .TP import get_module_placements
+
+
+class ParameterizedEmbedding(nn.Embedding, DTensorModule):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        std: float | None = None,
+        use_padding_free_transformer: bool = False,
+        sequence_parallel: bool = False,
+    ) -> ParameterizedEmbedding:
+        nn.Module.__init__(self)
+
+        self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled()
+
+        if self.is_tp_enabled:
+            self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()
+            self.use_padding_free_transformer = use_padding_free_transformer
+            self.sequence_parallel = sequence_parallel
+            self.vocab_start_index, self.vocab_end_index, num_embeddings_per_tp_rank = get_tensor_parallel_vocab_info(
+                num_embeddings
+            )
+
+            self.weight = nn.Parameter(
+                tensor_to_dtensor(
+                    self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0)
+                )
+            )
+        else:
+            self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
 
-class ParameterizedEmbedding(nn.Embedding):
-    def __init__(self, num_embeddings: int, embedding_dim: int, std: float | None = None) -> ParameterizedEmbedding:
         self.std = std
-        super().__init__(num_embeddings, embedding_dim)
+        self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
+
+        self.reset_parameters()
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.is_tp_enabled:
+            input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Replicate())
+
+        input = super().forward(input)
+
+        if self.is_tp_enabled:
+            input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement)
+
+        return input
 
     @torch.no_grad()
     def reset_parameters(self) -> None:
@@ -23,3 +71,19 @@ def reset_parameters(self) -> None:
             self.weight.normal_(mean=0, std=self.std)
 
         mark_parameter_as_initialized(self.weight)
+
+
+def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]:
+    tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
+    tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
+
+    divide_if_divisible(make_vocab_size_divisible_by, tp_world_size)
+
+    vocab_size_per_tensor_parallel_rank = (
+        make_vocab_size_divisible_by * math.ceil(vocab_size / make_vocab_size_divisible_by)
+    ) // tp_world_size
+
+    vocab_start_index = tp_rank * vocab_size_per_tensor_parallel_rank
+    vocab_end_index = min((tp_rank + 1) * vocab_size_per_tensor_parallel_rank, vocab_size)
+
+    return vocab_start_index, vocab_end_index, vocab_size_per_tensor_parallel_rank
diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py
index 4560064d9..799042884 100644
--- a/lm_engine/hf_models/modeling_utils_TP/__init__.py
+++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py
@@ -3,7 +3,6 @@
 # **************************************************
 
 from .dtensor_module import DTensorModule
-from .embedding import Embedding_TP, get_tensor_parallel_vocab_info
 from .linear import ColumnParallelLinear, RowParallelLinear
 from .lm_head import LMHead_TP
 from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP
diff --git a/lm_engine/hf_models/modeling_utils_TP/embedding.py b/lm_engine/hf_models/modeling_utils_TP/embedding.py
deleted file mode 100644
index 87726c6eb..000000000
--- a/lm_engine/hf_models/modeling_utils_TP/embedding.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# **************************************************
-# Copyright (c) 2025, Mayank Mishra
-# **************************************************
-
-from __future__ import annotations
-
-import math
-
-import torch
-import torch.nn as nn
-from torch.distributed._tensor.placement_types import Replicate, Shard
-
-from ...dtensors import dtensor_to_tensor, tensor_to_dtensor
-from ...utils import ProcessGroupManager, divide_if_divisible
-from ..modeling_utils import ParameterizedEmbedding
-from ..parameter import mark_parameter_as_initialized
-from .dtensor_module import DTensorModule
-from .TP import get_module_placements
-
-
-class Embedding_TP(ParameterizedEmbedding, DTensorModule):
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_dim: int,
-        std: float | None = None,
-        use_padding_free_transformer: bool = False,
-        sequence_parallel: bool = False,
-    ) -> Embedding_TP:
-        self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh()
-
-        self.use_padding_free_transformer = use_padding_free_transformer
-        self.sequence_parallel = sequence_parallel
-
-        self.vocab_start_index, self.vocab_end_index, num_embeddings_per_tp_rank = get_tensor_parallel_vocab_info(
-            num_embeddings
-        )
-
-        self.std = std
-        super().__init__(num_embeddings_per_tp_rank, embedding_dim, std=std)
-
-        self.weight = nn.Parameter(
-            tensor_to_dtensor(
-                self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0)
-            )
-        )
-
-        self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Replicate())
-        input = super().forward(input)
-        input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement)
-        return input
-
-    @torch.no_grad()
-    def reset_parameters(self) -> None:
-        if self.std is None:
-            super().reset_parameters()
-        else:
-            self.weight.normal_(mean=0, std=self.std)
-
-        mark_parameter_as_initialized(self.weight)
-
-
-def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]:
-    tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
-    tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
-
-    divide_if_divisible(make_vocab_size_divisible_by, tp_world_size, "")
-
-    vocab_size_per_tensor_parallel_rank = (
-        make_vocab_size_divisible_by * math.ceil(vocab_size / make_vocab_size_divisible_by)
-    ) // tp_world_size
-
-    vocab_start_index = tp_rank * vocab_size_per_tensor_parallel_rank
-    vocab_end_index = min((tp_rank + 1) * vocab_size_per_tensor_parallel_rank, vocab_size)
-
-    return vocab_start_index, vocab_end_index, vocab_size_per_tensor_parallel_rank
diff --git a/lm_engine/hf_models/modeling_utils_TP/lm_head.py b/lm_engine/hf_models/modeling_utils_TP/lm_head.py
index f3ae3a9c7..ed3578c90 100644
--- a/lm_engine/hf_models/modeling_utils_TP/lm_head.py
+++ b/lm_engine/hf_models/modeling_utils_TP/lm_head.py
@@ -8,11 +8,11 @@
 from torch.distributed.device_mesh import DeviceMesh
 
 from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel
-from .embedding import Embedding_TP
+from ..modeling_utils import ParameterizedEmbedding
 from .TP import get_module_placements
 
 
-class LMHead_TP(Embedding_TP):
+class LMHead_TP(ParameterizedEmbedding):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return self.compute_with_weight(
             input,

From 41ba4e9e34b5b1e3e6cdb5b2d3689c21fbb94604 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:27:08 -0800
Subject: [PATCH 12/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/__init__.py          | 2 +-
 lm_engine/hf_models/models/gpt_base_TP/weights/shard.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py
index 5c662b738..e9de3c937 100644
--- a/lm_engine/hf_models/modeling_utils/__init__.py
+++ b/lm_engine/hf_models/modeling_utils/__init__.py
@@ -5,7 +5,7 @@
 from .activations import get_activation_function, is_glu
 from .convolution import ParameterizedConv1d
 from .dropout import Dropout
-from .embedding import ParameterizedEmbedding
+from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info
 from .linear import ParameterizedLinear
 from .mlp_blocks import (
     MLP,
diff --git a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py
index 4d1f592ae..7b234c7d2 100644
--- a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py
+++ b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py
@@ -5,8 +5,8 @@
 import torch
 
 from .....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
-from ....modeling_utils import is_glu
-from ....modeling_utils_TP import get_tensor_parallel_vocab_info, tensor_parallel_split_safetensor_slice
+from ....modeling_utils import get_tensor_parallel_vocab_info, is_glu
+from ....modeling_utils_TP import tensor_parallel_split_safetensor_slice
 from ...gpt_base import GPTBaseConfig

From c15a72f823c91af1121ee733686847465c76f59a Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:28:22 -0800
Subject: [PATCH 13/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/embedding.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index 00529f3ae..108548720 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -41,7 +41,9 @@ def __init__(
 
             self.weight = nn.Parameter(
                 tensor_to_dtensor(
-                    self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0)
+                    torch.empty(num_embeddings_per_tp_rank, embedding_dim),
+                    device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(),
+                    current_placement=Shard(0),
                 )
             )

From b2b541ef811ff4b6c9427ff501694acc6a8fe16d Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:29:12 -0800
Subject: [PATCH 14/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/embedding.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index 108548720..62ecd1442 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -54,16 +54,16 @@ def __init__(
 
         self.reset_parameters()
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.is_tp_enabled:
-            input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Replicate())
+            x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Replicate())
 
-        input = super().forward(input)
+        x = super().forward(x)
 
         if self.is_tp_enabled:
-            input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement)
+            x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement)
 
-        return input
+        return x
 
     @torch.no_grad()
     def reset_parameters(self) -> None:

From ebb614397dfa6614bb01e4d81d7ac96b39e08349 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:31:04 -0800
Subject: [PATCH 15/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/embedding.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index 62ecd1442..f1f2dca90 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -46,11 +46,12 @@ def __init__(
                     current_placement=Shard(0),
                 )
             )
+
+            self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
         else:
             self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
 
         self.std = std
-        self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel)
 
         self.reset_parameters()

From babcaa49fce07899f784a22400d17ec3882867b4 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:45:26 -0800
Subject: [PATCH 16/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/embedding.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index f1f2dca90..fdd395465 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -8,6 +8,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._tensor.placement_types import Replicate, Shard
 
 from ...dtensors import dtensor_to_tensor, tensor_to_dtensor
@@ -17,7 +18,7 @@
 from .TP import get_module_placements
 
 
-class ParameterizedEmbedding(nn.Embedding, DTensorModule):
+class ParameterizedEmbedding(nn.Module, DTensorModule):
     def __init__(
         self,
         num_embeddings: int,
@@ -59,7 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.is_tp_enabled:
             x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Replicate())
 
-        x = super().forward(x)
+        x = F.embedding(x, weight=self.weight)
 
         if self.is_tp_enabled:
             x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement)

From dc4ef01845edc319e796ca75c96b85cb0be95d71 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Wed, 4 Feb 2026 22:46:31 -0800
Subject: [PATCH 17/99] add markers

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/modeling_utils/embedding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py
index fdd395465..1125b144e 100644
--- a/lm_engine/hf_models/modeling_utils/embedding.py
+++ b/lm_engine/hf_models/modeling_utils/embedding.py
@@ -18,7 +18,7 @@
 from .TP import get_module_placements
 
 
-class ParameterizedEmbedding(nn.Module, DTensorModule):
+class ParameterizedEmbedding(DTensorModule):
     def __init__(
         self,
         num_embeddings: int,

From 86db0ea0a8712dc9e29704f8682593fb29ea31ad Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 04:17:08 -0800
Subject: [PATCH 18/99] add GDN efficient init

Signed-off-by: Mayank Mishra
---
 .../sequence_mixer_blocks/gated_deltanet.py | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
index 21affa061..19962481f 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
@@ -12,7 +12,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.placement_types import Replicate
 
+from ....dtensors import tensor_to_dtensor
 from ....utils import divide_if_divisible, is_fla_available
 from ...cache import GenerationCache
 from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay
@@ -223,8 +226,18 @@ def forward(
 
         return o
 
+    @torch.no_grad()
     def reset_parameters(self) -> None:
         A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
+
+        if isinstance(self.A_log, DTensor):
+            A = tensor_to_dtensor(
+                tensor=A,
+                device_mesh=self.A_log.device_mesh,
+                current_placement=[Replicate()] * len(self.A_log.placements),
+                desired_placement=self.A_log.placements,
+            )
+
         self.A_log.copy_(torch.log(A))
 
         # hard coded for now
@@ -236,6 +249,14 @@ def reset_parameters(self) -> None:
         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
         inv_dt = dt + torch.log(-torch.expm1(-dt))
 
+        if isinstance(self.dt_bias, DTensor):
+            inv_dt = tensor_to_dtensor(
+                tensor=inv_dt,
+                device_mesh=self.dt_bias.device_mesh,
+                current_placement=[Replicate()] * len(self.dt_bias.placements),
+                desired_placement=self.dt_bias.placements,
+            )
+
         self.dt_bias.copy_(inv_dt)
 
         mark_parameter_as_initialized(self.A_log)

From db411cf329bd97c10334db87eef17f4e3f8c649f Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 04:24:59 -0800
Subject: [PATCH 19/99] fix mamba2 init

Signed-off-by: Mayank Mishra
---
 .../sequence_mixer_blocks/mamba2.py | 57 +++++++++----------
 1 file changed, 28 insertions(+), 29 deletions(-)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
index e4ab4f0cb..b646bf900 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
@@ -305,34 +305,31 @@ def _torch_forward(
                 # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
                 # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
                 # NOTE: S = 1 actually here
-                B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
-                # B -> (B, G, 1, ssm_state_size / num_groups)
-                B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
-                # B -> (B, G, N / G, ssm_state_size / num_groups)
-                B = B.reshape(batch_size, -1, B.shape[-1])
-                # B -> (B, N, ssm_state_size / num_groups)
-
-                # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size / num_groups)
+                B, C = [i.reshape(batch_size, self.n_groups, -1)[..., None, :] for i in (B, C)]
+                # B, C -> (B, G, 1, ssm_state_size)
+                B, C = [
+                    i.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, i.shape[-1]).contiguous()
+                    for i in (B, C)
+                ]
+                # B, C -> (B, G, N / G, ssm_state_size)
+                B, C = [i.reshape(batch_size, -1, i.shape[-1]) for i in (B, C)]
+                # B, C -> (B, N, ssm_state_size)
+
+                # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size)
+                # B is same as k and is shared across heads and dt is used to expand it
                 dB = dt[..., None] * B[..., None, :]
-                # dB -> (B, N, head_dim, ssm_state_size / num_groups)
+                # dB -> (B, N, head_dim, ssm_state_size)
 
                 # Discretize x into dB
                 hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
                 # hidden_states -> (B, N, head_dim)
                 dBx = (dB * hidden_states[..., None]).to(device=cache_device)
-                # dBx -> (B, N, head_dim, ssm_state_size / num_groups)
+                # dBx -> (B, N, head_dim, ssm_state_size)
 
                 # State calculation
                 ssm_state = ssm_state * dA + dBx
                 cache_params.update(ssm_state=ssm_state, num_tokens_added=seq_len, layer_idx=self.layer_idx)
 
-                # Subsequent output
-                # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
-                C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
-                C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
-                C = C.reshape(batch_size, -1, C.shape[-1])
-                # [bsz, num_heads, head_dim]
-
                 ssm_state = ssm_state.to(device=C.device, dtype=C.dtype)  # Shape: [b, h, d, n]
                 # Reshape ssm_states to merge the first two dimensions
                 ssm_states_reshaped = ssm_state.view(
@@ -351,7 +348,7 @@ def _torch_forward(
                 y = y.reshape(batch_size, -1)[:, None, ...]
             else:
                 # begin ssd naive implementation without einsums
-                dt = nn.functional.softplus(dt + self.dt_bias)
+                dt = F.softplus(dt + self.dt_bias)
                 dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1])
                 hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
                 B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
@@ -502,7 +499,8 @@ def _cuda_forward(
                 dt_softplus=True,
             )
             hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
-            hidden_states = self.norm(hidden_states, gate)
+            hidden_states = hidden_states * F.silu(gate)
+            hidden_states = self.norm(hidden_states)
 
             # 4. Final linear projection
             out = self.out_proj(hidden_states)[:, None, ...]
@@ -602,20 +600,21 @@ def _cuda_forward(
 
     @torch.no_grad()
     def reset_parameters(self) -> None:
-        A = torch.log(torch.arange(1, self.num_heads + 1))
+        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
+        self.A_log.copy_(torch.log(A))
 
-        if isinstance(self.A_log, DTensor):
-            A = tensor_to_dtensor(
-                A,
-                device_mesh=self.A_log.device_mesh,
-                current_placement=[Replicate()] * len(self.A_log.placements),
-                desired_placement=self.A_log.placements,
-            )
+        # hard coded for now
+        dt_min = 0.001
+        dt_max = 0.1
+        dt_init_floor = 1e-4
+        dt = torch.exp(torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
+        dt = torch.clamp(dt, min=dt_init_floor)
+        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+        inv_dt = dt + torch.log(-torch.expm1(-dt))
 
-        self.A_log.copy_(A)
+        self.dt_bias.copy_(inv_dt)
 
         nn.init.ones_(self.D)
-        nn.init.ones_(self.dt_bias)
 
         mark_parameter_as_initialized(self.A_log)
         mark_parameter_as_initialized(self.D)

From 06ce95f9758f0563394c0060ae5ac325289f55a2 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 04:33:59 -0800
Subject: [PATCH 20/99] pass args

Signed-off-by: Mayank Mishra
---
 .../modeling_utils/sequence_mixer_blocks/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
index cbe4c1b4e..96f1a3ecd 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
@@ -139,12 +139,14 @@ def get_sequence_mixer(
             num_k_heads=block.num_k_heads,
             num_v_heads=block.num_v_heads,
             use_gate=block.use_gate,
+            attention_multiplier=block.attention_multiplier,
             allow_neg_eigval=block.allow_neg_eigval,
-            conv_size=block.conv_size,
+            conv_size=block.kernel_size,
             layer_idx=layer_idx,
             norm_eps=config.layer_norm_epsilon,
             init_method=config.init_method,
             initializer_range=config.initializer_range,
+            m_width=config.m_width,
             num_layers=config.num_layers,
             use_padding_free_transformer=use_padding_free_transformer,
         )

From 3a21404265144c64111bde6ae325ab0fcae05c0b Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 16:59:24 -0800
Subject: [PATCH 21/99] hidden_states -> x

Signed-off-by: Mayank Mishra
---
 .../sequence_mixer_blocks/mamba2.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
index b646bf900..85eab46b5 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py
@@ -601,6 +601,15 @@ def reset_parameters(self) -> None:
         A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
+
+        if isinstance(self.A_log, DTensor):
+            A = tensor_to_dtensor(
+                tensor=A,
+                device_mesh=self.A_log.device_mesh,
+                current_placement=[Replicate()] * len(self.A_log.placements),
+                desired_placement=self.A_log.placements,
+            )
+
         self.A_log.copy_(torch.log(A))
 
         # hard coded for now
@@ -612,6 +621,14 @@ def reset_parameters(self) -> None:
         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
         inv_dt = dt + torch.log(-torch.expm1(-dt))
 
+        if isinstance(self.dt_bias, DTensor):
+            inv_dt = tensor_to_dtensor(
+                tensor=inv_dt,
+                device_mesh=self.dt_bias.device_mesh,
+                current_placement=[Replicate()] * len(self.dt_bias.placements),
+                desired_placement=self.dt_bias.placements,
+            )
+
         self.dt_bias.copy_(inv_dt)
 
         nn.init.ones_(self.D)

From b6dd81b25c9abf0e8d97ee12f0eb7610799ea53b Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 17:05:13 -0800
Subject: [PATCH 22/99] hidden_states -> x

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/config/sequence_mixer.py  | 10 ++++++++
 .../sequence_mixer_blocks/__init__.py         |  5 ++++
 .../sequence_mixer_blocks/gated_deltanet.py   | 23 ++++++++++++++-----
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py
index d737d46c3..0e44ceadc 100644
--- a/lm_engine/hf_models/config/sequence_mixer.py
+++ b/lm_engine/hf_models/config/sequence_mixer.py
@@ -124,6 +124,16 @@ class _GatedDeltaNetArgs(BaseArgs):
     attention_multiplier: float | None = None
     allow_neg_eigval: bool
     kernel_size: int
+    A_init_min: float = 0
+    A_init_max: float = 16
+    dt_min: float = 0.001
+    dt_max: float = 0.1
+    dt_init_floor: float = 1e-4
 
     def model_post_init(self, __context: Any) -> None:
         assert self.sequence_mixer_type == "gated_deltanet"
+
+        assert self.A_init_min >= 0
+        assert self.A_init_min <= self.A_init_max
+
+        assert self.dt_min <= self.dt_max
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
index 96f1a3ecd..4fc7b6904 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py
@@ -147,6 +147,11 @@ def get_sequence_mixer(
             init_method=config.init_method,
             initializer_range=config.initializer_range,
             m_width=config.m_width,
+            A_init_min=block.A_init_min,
+            A_init_max=block.A_init_max,
+            dt_min=block.dt_min,
+            dt_max=block.dt_max,
+            dt_init_floor=block.dt_init_floor,
             num_layers=config.num_layers,
             use_padding_free_transformer=use_padding_free_transformer,
         )
diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
index 19962481f..8fbf455a5 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py
@@ -47,6 +47,11 @@ def __init__(
         init_method: str,
         initializer_range: float,
         m_width: float | None,
+        A_init_min: float,
+        A_init_max: float,
+        dt_min: float,
+        dt_max: float,
+        dt_init_floor: float,
         num_layers: int,
         use_padding_free_transformer: bool,
     ) -> GatedDeltaNet:
@@ -67,6 +72,13 @@ def __init__(
         self.k_head_dim = k_head_dim
         self.v_head_dim = v_head_dim
 
+        self.A_init_min = A_init_min
+        self.A_init_max = A_init_max
+
+        self.dt_min = dt_min
+        self.dt_max = dt_max
+        self.dt_init_floor = dt_init_floor
+
         self.key_dim = self.num_k_heads * self.k_head_dim
         self.value_dim = self.num_v_heads * self.v_head_dim
         self.layer_idx = layer_idx
@@ -240,7 +252,7 @@ def reset_parameters(self) -> None:
-        A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
+        A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max)
 
         if isinstance(self.A_log, DTensor):
             A = tensor_to_dtensor(
@@ -253,11 +265,10 @@ def reset_parameters(self) -> None:
         self.A_log.copy_(torch.log(A))
 
         # hard coded for now
-        dt_min = 0.001
-        dt_max = 0.1
-        dt_init_floor = 1e-4
-        dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min))
-        dt = torch.clamp(dt, min=dt_init_floor)
+        dt = torch.exp(
+            torch.rand(self.num_v_heads) * (math.log(self.dt_max) - math.log(self.dt_min)) + math.log(self.dt_min)
+        )
+        dt = torch.clamp(dt, min=self.dt_init_floor)
         # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
         inv_dt = dt + torch.log(-torch.expm1(-dt))

From 3bec1f004b4fd0d95486de9927f6949bf72e2c1c Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 17:07:35 -0800
Subject: [PATCH 23/99] hidden_states -> x

Signed-off-by: Mayank Mishra
---
 lm_engine/hf_models/config/sequence_mixer.py | 57 ++++++++++----------
 1 file changed, 30 insertions(+), 27 deletions(-)

diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py
index 0e44ceadc..1c497d649 100644
--- a/lm_engine/hf_models/config/sequence_mixer.py
+++ b/lm_engine/hf_models/config/sequence_mixer.py
@@ -49,24 +49,6 @@ def model_post_init(self, __context: Any) -> None:
         assert self.head_dim is not None
 
 
-class _Mamba2Args(BaseArgs):
-    sequence_mixer_type: str = "mamba2"
-    state_size: int = 128
-    intermediate_size: int
-    num_heads: int = 128
-    conv_kernel_size: int = 4
-    time_step_limit: tuple[float, float] = (0, float("inf"))
-    add_bias: bool = False
-    use_conv_bias: bool = True
-    activation_function: str = "silu"
-    num_groups: int = 8
-    chunk_size: int = 256
-    normalization_function: str | None = "rmsnorm"
-
-    def model_post_init(self, __context: Any) -> None:
-        assert self.sequence_mixer_type == "mamba2"
-
-
 class _GRUArgs(BaseArgs):
     sequence_mixer_type: str = "gru"
     state_head_dim: int
@@ -114,7 +96,20 @@ def model_post_init(self, __context: Any) -> None:
         assert self.sequence_mixer_type == "causal_convolution"
 
 
-class _GatedDeltaNetArgs(BaseArgs):
+class _SoftPlusDecayArgs(BaseArgs):
+    A_init_min: float = 0
+    A_init_max: float = 16
+    dt_min: float = 0.001
+    dt_max: float = 0.1
+    dt_init_floor: float = 1e-4
+
+    def model_post_init(self, __context: Any) -> None:
+        assert self.A_init_min >= 0
+        assert self.A_init_min <= self.A_init_max
+        assert self.dt_min <= self.dt_max
+
+
+class _GatedDeltaNetArgs(_SoftPlusDecayArgs):
     sequence_mixer_type: str = "gated_deltanet"
     k_head_dim: int
     v_head_dim: int
@@ -124,16 +119,24 @@ class _GatedDeltaNetArgs(BaseArgs):
     attention_multiplier: float | None = None
     allow_neg_eigval: bool
     kernel_size: int
-    A_init_min: float = 0
-    A_init_max: float = 16
-    dt_min: float = 0.001
-    dt_max: float = 0.1
-    dt_init_floor: float = 1e-4
 
     def model_post_init(self, __context: Any) -> None:
         assert self.sequence_mixer_type == "gated_deltanet"
 
-        assert self.A_init_min >= 0
-        assert self.A_init_min <= self.A_init_max
-
-        assert self.dt_min <= self.dt_max
+
+class _Mamba2Args(_SoftPlusDecayArgs):
+    sequence_mixer_type: str = "mamba2"
+    state_size: int = 128
+    intermediate_size: int
+    num_heads: int = 128
+    conv_kernel_size: int = 4
+    time_step_limit: tuple[float, float] = (0, float("inf"))
+    add_bias: bool = False
+    use_conv_bias: bool = True
+    activation_function: str = "silu"
+    num_groups: int = 8
+    chunk_size: int = 256
+    normalization_function: str | None = "rmsnorm"
+
+    def model_post_init(self, __context: Any) -> None:
+        assert self.sequence_mixer_type == "mamba2"

From 8dc2c7952d144d028c61e93fe5e013bbbd087cf8 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Sun, 8 Feb 2026 17:10:51 -0800
Subject: [PATCH 24/99] hidden_states -> x

Signed-off-by: Mayank Mishra
---
.../hf_models/modeling_utils/decay_gate.py | 118 ++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 lm_engine/hf_models/modeling_utils/decay_gate.py diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py new file mode 100644 index 000000000..a76824eaf --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/decay_gate.py @@ -0,0 +1,118 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Replicate + +from ...dtensors import tensor_to_dtensor +from ..parameter import ( + mark_parameter_as_initialized, + mark_parameter_as_mup_learning_rate, + mark_parameter_as_no_weight_decay, +) +from .linear import ParameterizedLinear + + +class SoftplusDecayGate(nn.Module): + def __init__( + self, + hidden_size: int, + output_size: int, + std: float | None, + has_projection: bool = False, + A_init_min: float = 0, + A_init_max: float = 16, + dt_init_min: float = 1e-3, + dt_init_max: float = 0.1, + dt_init_floor: float = 1e-4, + ) -> SoftplusDecayGate: + super().__init__() + + self.output_size = output_size + self.has_projection = has_projection + + if has_projection: + self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std) + mark_parameter_as_mup_learning_rate(self.proj.weight) + + self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + + assert A_init_min >= 0 + assert A_init_max >= A_init_min + + self.A_init_min = A_init_min + self.A_init_max = A_init_max + + assert dt_init_min > 0 + assert dt_init_max >= dt_init_min + + self.dt_init_min = dt_init_min + self.dt_init_max = dt_init_max + self.dt_init_floor = dt_init_floor + + self.reset_parameters() + + mark_parameter_as_no_weight_decay(self.A_log) + mark_parameter_as_no_weight_decay(self.dt_bias) + + def forward( + self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + if self.has_projection: + x = self.proj(x) + + x = x.float() + x = x + self.dt_bias + x = F.softplus(x) + x = -self.A_log.float().exp() * x + + if final_exponential: + x = torch.exp(x) + + x = x.to(output_dtype) + + return x + + @torch.no_grad() + def reset_parameters(self) -> None: + A = torch.empty(self.output_size, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) + + if isinstance(self.A_log, DTensor): + A = tensor_to_dtensor( + tensor=A, + device_mesh=self.A_log.device_mesh, + current_placement=[Replicate()] * len(self.A_log.placements), + desired_placement=self.A_log.placements, + ) + + self.A_log.copy_(torch.log(A)) + + dt = torch.exp( + torch.rand(self.output_size) * (math.log(self.dt_init_max) - math.log(self.dt_init_min)) + + math.log(self.dt_init_min) + ) + dt = torch.clamp(dt, min=self.dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + if isinstance(self.dt_bias, DTensor): + inv_dt = tensor_to_dtensor( + tensor=inv_dt, + device_mesh=self.dt_bias.device_mesh, + current_placement=[Replicate()] * len(self.dt_bias.placements), + desired_placement=self.dt_bias.placements, + ) + + self.dt_bias.copy_(inv_dt) + + 
mark_parameter_as_initialized(self.A_log) + mark_parameter_as_initialized(self.dt_bias) From 52b14a334b2cebbb12aa81653887ee256ae75bf6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:20:24 -0800 Subject: [PATCH 25/99] use gate for GDN Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/decay_gate.py | 10 +-- .../sequence_mixer_blocks/gated_deltanet.py | 69 ++++--------------- 2 files changed, 21 insertions(+), 58 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py index a76824eaf..f9f8de738 100644 --- a/lm_engine/hf_models/modeling_utils/decay_gate.py +++ b/lm_engine/hf_models/modeling_utils/decay_gate.py @@ -24,7 +24,7 @@ class SoftplusDecayGate(nn.Module): def __init__( self, - hidden_size: int, + hidden_size: int | None, output_size: int, std: float | None, has_projection: bool = False, @@ -42,9 +42,14 @@ def __init__( if has_projection: self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std) mark_parameter_as_mup_learning_rate(self.proj.weight) + else: + assert hidden_size is None self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.A_log) + self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.dt_bias) assert A_init_min >= 0 assert A_init_max >= A_init_min @@ -61,9 +66,6 @@ def __init__( self.reset_parameters() - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_no_weight_decay(self.dt_bias) - def forward( self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32 ) -> torch.Tensor: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 8fbf455a5..437cea0cb 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -12,14 +12,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Replicate -from ....dtensors import tensor_to_dtensor from ....utils import divide_if_divisible, is_fla_available from ...cache import GenerationCache -from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..normalization import get_normalization_function from .causal_convolution import causal_convolution @@ -49,8 +46,8 @@ def __init__( m_width: float | None, A_init_min: float, A_init_max: float, - dt_min: float, - dt_max: float, + dt_init_min: float, + dt_init_max: float, dt_init_floor: float, num_layers: int, use_padding_free_transformer: bool, @@ -72,13 +69,6 @@ def __init__( self.k_head_dim = k_head_dim self.v_head_dim = v_head_dim - self.A_init_min = A_init_min - self.A_init_max = A_init_max - - self.dt_min = dt_min - self.dt_max = dt_max - self.dt_init_floor = dt_init_floor - self.key_dim = self.num_k_heads * self.k_head_dim self.value_dim = self.num_v_heads * self.v_head_dim self.layer_idx = layer_idx @@ -96,11 +86,17 @@ def __init__( hidden_size, 2 * self.num_v_heads + (self.value_dim if use_gate else 0), bias=False, std=std ) - self.A_log = nn.Parameter(torch.empty(self.num_v_heads, 
dtype=torch.float32)) - mark_parameter_as_no_weight_decay(self.A_log) - - self.dt_bias = nn.Parameter(torch.empty(self.num_v_heads)) - mark_parameter_as_no_weight_decay(self.dt_bias) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_v_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) self.conv_size = conv_size self.qkv_conv1d = ParameterizedConv1d( @@ -171,7 +167,7 @@ def forward( if self.allow_neg_eigval: beta = beta * 2.0 - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = self.decay_gate(x=a, final_exponential=False) if self.use_padding_free_transformer: assert cache_params is None @@ -237,38 +233,3 @@ def forward( o = self.o_proj(o) return o - - @torch.no_grad() - def reset_parameters(self) -> None: - A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) - - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - tensor=A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) - - self.A_log.copy_(torch.log(A)) - - # hard coded for now - dt = torch.exp( - torch.rand(self.num_v_heads) * (math.log(self.dt_max) - math.log(self.dt_min)) + math.log(self.dt_min) - ) - dt = torch.clamp(dt, min=self.dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - if isinstance(self.dt_bias, DTensor): - inv_dt = tensor_to_dtensor( - tensor=inv_dt, - device_mesh=self.dt_bias.device_mesh, - current_placement=[Replicate()] * len(self.dt_bias.placements), - desired_placement=self.dt_bias.placements, - ) - - self.dt_bias.copy_(inv_dt) - - mark_parameter_as_initialized(self.A_log) - mark_parameter_as_initialized(self.dt_bias) From b062d8676b726ca1f075be0af518913937ef4ecd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:23:38 -0800 Subject: [PATCH 26/99] use gate for mamba2 Signed-off-by: Mayank Mishra --- .../sequence_mixer_blocks/mamba2.py | 47 +++---------------- 1 file changed, 7 insertions(+), 40 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 85eab46b5..cf7e5a384 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -168,23 +168,24 @@ def __init__( # instantiate once and copy inv_dt in init_weights of PretrainedModel # Initialize log dt bias self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.dt_bias) # S4D real initialization. These are not discretized! # The core is to load them, compute the discrete states, then write the updated state. 
Keeps the memory bounded self.A_log = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.A_log) + mark_parameter_as_mup_learning_rate(self.A_log) + self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) + self.D = nn.Parameter(torch.empty(self.num_heads)) + mark_parameter_as_no_weight_decay(self.D) + mark_parameter_as_mup_learning_rate(self.D) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) - mark_parameter_as_no_weight_decay(self.dt_bias) - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_no_weight_decay(self.D) - - mark_parameter_as_mup_learning_rate(self.A_log) - mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) mark_parameter_as_mup_learning_rate(self.in_proj.weight) mark_parameter_as_mup_learning_rate(self.out_proj.weight) @@ -600,39 +601,5 @@ def _cuda_forward( @torch.no_grad() def reset_parameters(self) -> None: - A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) - - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - tensor=A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) - - self.A_log.copy_(torch.log(A)) - - # hard coded for now - dt_min = 0.001 - dt_max = 0.1 - dt_init_floor = 1e-4 - dt = torch.exp(torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - if isinstance(self.dt_bias, DTensor): - inv_dt = tensor_to_dtensor( - tensor=inv_dt, - device_mesh=self.dt_bias.device_mesh, - current_placement=[Replicate()] * len(self.dt_bias.placements), - desired_placement=self.dt_bias.placements, - ) - - self.dt_bias.copy_(inv_dt) - nn.init.ones_(self.D) - - mark_parameter_as_initialized(self.A_log) mark_parameter_as_initialized(self.D) - mark_parameter_as_initialized(self.dt_bias) From 87e7a4432b0adef48eb174b6fde42dcc9682e5e3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:36:17 -0800 Subject: [PATCH 27/99] use gate for mamba2 Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 4 +- .../sequence_mixer_blocks/__init__.py | 9 +++- .../sequence_mixer_blocks/mamba2.py | 46 +++++++++++-------- 3 files changed, 35 insertions(+), 24 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 1c497d649..9d128a162 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -99,8 +99,8 @@ def model_post_init(self, __context: Any) -> None: class _SoftPlusDecayArgs(BaseArgs): A_init_min: float = 0 A_init_max: float = 16 - dt_min: float = 0.001 - dt_max: float = 0.1 + dt_init_min: float = 0.001 + dt_init_max: float = 0.1 dt_init_floor: float = 1e-4 def model_post_init(self, __context: Any) -> None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 4fc7b6904..aa8e979e4 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -105,6 +105,11 @@ def get_sequence_mixer( 
init_method=config.init_method, normalization_function=block.normalization_function, m_width=config.m_width, + A_init_min=block.A_init_min, + A_init_max=block.A_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, + dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, layer_idx=layer_idx, ) @@ -149,8 +154,8 @@ def get_sequence_mixer( m_width=config.m_width, A_init_min=block.A_init_min, A_init_max=block.A_init_max, - dt_min=block.dt_min, - dt_max=block.dt_max, + dt_min=block.dt_init_min, + dt_max=block.dt_init_max, dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index cf7e5a384..531b158ff 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -24,6 +24,7 @@ ) from ..activations import get_activation_function from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..mlp_blocks.mlp import _get_std_for_linear from ..normalization import get_normalization_function @@ -115,6 +116,11 @@ def __init__( layer_norm_epsilon: float, initializer_range: float, m_width: float, + A_init_min: float, + A_init_max: float, + dt_init_min: float, + dt_init_max: float, + dt_init_floor: float, init_method: str, normalization_function: str | None, num_layers: int, @@ -162,19 +168,19 @@ def __init__( std=std, ) - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - # Initialize log dt bias - self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.dt_bias) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - self.A_log = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.A_log) - mark_parameter_as_mup_learning_rate(self.A_log) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) @@ -272,7 +278,7 @@ def _torch_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) + A = -torch.exp(self.decay_gate.A_log.float()) # hidden_states -> B, S, N, head_dim # A -> num_heads @@ -291,7 +297,7 @@ def _torch_forward( # dt -> (B, 1, N) dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) # dt -> (B, N, head_dim) - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + dt_bias = self.decay_gate.dt_bias[..., None].expand(self.decay_gate.dt_bias.shape[0], self.head_dim) dt = F.softplus(dt + dt_bias.to(dt.dtype)) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) @@ -349,7 +355,7 @@ def _torch_forward( y = y.reshape(batch_size, -1)[:, None, ...] 
else: # begin ssd naive implementation without einsums - dt = F.softplus(dt + self.dt_bias) + dt = F.softplus(dt + self.decay_gate.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() @@ -479,10 +485,10 @@ def _cuda_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) # (nheads,) + A = -torch.exp(self.decay_gate.A_log.float()) # (nheads,) A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) - dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + dt_bias = self.decay_gate.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) @@ -507,7 +513,7 @@ def _cuda_forward( out = self.out_proj(hidden_states)[:, None, ...] # Fused calculations or step by step if no initialized cache is found else: - A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + A = -torch.exp(self.decay_gate.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} # 2-4. Fused kernel for conv1d, SSM, and the final projection @@ -516,7 +522,7 @@ def _cuda_forward( projected_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, - self.dt_bias, + self.decay_gate.dt_bias, A, D=self.D, chunk_size=self.chunk_size, @@ -580,7 +586,7 @@ def _cuda_forward( z=None, seq_idx=None, return_final_states=True, - dt_bias=self.dt_bias, + dt_bias=self.decay_gate.dt_bias, dt_softplus=True, **dt_limit_kwargs, ) From d1050e5e814181674cfada372b73b2a92680dc18 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:42:22 -0800 Subject: [PATCH 28/99] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/model_conversion/granitemoehybrid.py | 12 ++++++------ tests/training/params_group/groups/mup.json | 4 ++-- tests/training/params_group/groups/normal.json | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lm_engine/hf_models/model_conversion/granitemoehybrid.py b/lm_engine/hf_models/model_conversion/granitemoehybrid.py index 6bae682e8..2f4f2f518 100644 --- a/lm_engine/hf_models/model_conversion/granitemoehybrid.py +++ b/lm_engine/hf_models/model_conversion/granitemoehybrid.py @@ -176,11 +176,11 @@ def _import_granitemoehybrid_state_dict( state_dict[f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias"] = ( safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.in_proj.bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.dt_bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.dt_bias" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.dt_bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.A_log"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.A_log" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.A_log") ) state_dict[f"transformer.h.{layer_idx}.sequence_mixer.D"] = 
safetensors_weights_manager.get_tensor( f"model.layers.{layer_idx}.mamba.D" @@ -404,10 +404,10 @@ def _export_granitemoehybrid_state_dict( f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias" ) state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.dt_bias" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias" ) state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.A_log" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log" ) state_dict[f"model.layers.{layer_idx}.mamba.D"] = safetensors_weights_manager.get_tensor( f"transformer.h.{layer_idx}.sequence_mixer.D" diff --git a/tests/training/params_group/groups/mup.json b/tests/training/params_group/groups/mup.json index f39be3fa9..cb6d202a7 100644 --- a/tests/training/params_group/groups/mup.json +++ b/tests/training/params_group/groups/mup.json @@ -16,7 +16,7 @@ "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias", @@ -42,9 +42,9 @@ "model.transformer.h.1.mlp_block.c_proj.weight", "model.transformer.h.1.mlp_block.c_proj_shared.weight", "model.transformer.h.1.mlp_block.gate.weight", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.weight", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", "model.transformer.h.1.sequence_mixer.in_proj.weight", "model.transformer.h.1.sequence_mixer.out_proj.weight", "model.transformer.h.2.mlp_block.c_fc.weight", diff --git a/tests/training/params_group/groups/normal.json b/tests/training/params_group/groups/normal.json index c14d0f0df..ae8f2a459 100644 --- a/tests/training/params_group/groups/normal.json +++ b/tests/training/params_group/groups/normal.json @@ -35,10 +35,10 @@ "model.transformer.h.1.mlp_block.c_fc_shared.bias", "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias", From 750f9c0131201ad359033d1fdef834c785448906 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:44:57 -0800 Subject: [PATCH 29/99] fix tests Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 62 ++++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 9d128a162..5b22610dc 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -49,6 +49,37 @@ def model_post_init(self, __context: Any) -> None: assert self.head_dim is not None +class 
_SoftPlusDecayArgs(BaseArgs): + A_init_min: float = 0 + A_init_max: float = 16 + dt_init_min: float = 0.001 + dt_init_max: float = 0.1 + dt_init_floor: float = 1e-4 + + def model_post_init(self, __context: Any) -> None: + assert self.A_init_min >= 0 + assert self.A_init_min <= self.A_init_max + assert self.dt_init_min <= self.dt_init_max + + +class _Mamba2Args(_SoftPlusDecayArgs): + sequence_mixer_type: str = "mamba2" + state_size: int = 128 + intermediate_size: int + num_heads: int = 128 + conv_kernel_size: int = 4 + time_step_limit: tuple[float, float] = (0, float("inf")) + add_bias: bool = False + use_conv_bias: bool = True + activation_function: str = "silu" + num_groups: int = 8 + chunk_size: int = 256 + normalization_function: str | None = "rmsnorm" + + def model_post_init(self, __context: Any) -> None: + assert self.sequence_mixer_type == "mamba2" + + class _GRUArgs(BaseArgs): sequence_mixer_type: str = "gru" state_head_dim: int @@ -96,19 +127,6 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "causal_convolution" -class _SoftPlusDecayArgs(BaseArgs): - A_init_min: float = 0 - A_init_max: float = 16 - dt_init_min: float = 0.001 - dt_init_max: float = 0.1 - dt_init_floor: float = 1e-4 - - def model_post_init(self, __context: Any) -> None: - assert self.A_init_min >= 0 - assert self.A_init_min <= self.A_init_max - assert self.dt_min <= self.dt_max - - class _GatedDeltaNetArgs(_SoftPlusDecayArgs): sequence_mixer_type: str = "gated_deltanet" k_head_dim: int @@ -122,21 +140,3 @@ class _GatedDeltaNetArgs(_SoftPlusDecayArgs): def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "gated_deltanet" - - -class _Mamba2Args(_SoftPlusDecayArgs): - sequence_mixer_type: str = "mamba2" - state_size: int = 128 - intermediate_size: int - num_heads: int = 128 - conv_kernel_size: int = 4 - time_step_limit: tuple[float, float] = (0, float("inf")) - add_bias: bool = False - use_conv_bias: bool = True - activation_function: str = "silu" - num_groups: int = 8 - chunk_size: int = 256 - normalization_function: str | None = "rmsnorm" - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "mamba2" From 8e5903e74666b35dfa4021a574003f08fe835256 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:52:10 -0800 Subject: [PATCH 30/99] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 531b158ff..dbd5f3d35 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -180,18 +180,17 @@ def __init__( dt_init_floor=dt_init_floor, ) - mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.D = nn.Parameter(torch.empty(self.num_heads)) mark_parameter_as_no_weight_decay(self.D) - mark_parameter_as_mup_learning_rate(self.D) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) + mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) 
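
# An illustrative aside on what these markers are for (the grouping code below is
# an assumption about typical usage, not code from this repository): the
# _no_weight_decay / _has_mup_learning_rate attributes set by mark_parameter_as_*
# are what the params_group fixtures above (mup.json / normal.json) exercise, and
# such markers are typically consumed when building optimizer parameter groups:

import torch.nn as nn

model = nn.Linear(8, 8)              # stand-in model; any nn.Module with marked parameters
model.bias._no_weight_decay = True   # what mark_parameter_as_no_weight_decay records

decay, no_decay = [], []
for _, p in model.named_parameters():
    (no_decay if getattr(p, "_no_weight_decay", False) else decay).append(p)

param_groups = [
    {"params": decay, "weight_decay": 0.1},    # 0.1 is a placeholder value
    {"params": no_decay, "weight_decay": 0.0},
]
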
mark_parameter_as_mup_learning_rate(self.in_proj.weight) mark_parameter_as_mup_learning_rate(self.out_proj.weight) From cc76e2449af88ea091be0dab2cb60493f4138a40 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:55:49 -0800 Subject: [PATCH 31/99] fix tests Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index dbd5f3d35..a0bbdff05 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -9,10 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Replicate -from ....dtensors import tensor_to_dtensor from ....enums import Kernel from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_causal_conv1d_available, is_mamba_2_ssm_available From 60e092a3791202139e7e6434355b76749ead2f28 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 17:58:45 -0800 Subject: [PATCH 32/99] fix tests Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/mamba2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index a0bbdff05..544ec4393 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -177,15 +177,16 @@ def __init__( dt_init_floor=dt_init_floor, ) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) - self.D = nn.Parameter(torch.empty(self.num_heads)) - mark_parameter_as_no_weight_decay(self.D) + + self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + mark_parameter_as_no_weight_decay(self.D) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) From 740a636424c201906413d1d0c2c42f1b82d99197 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:04:26 -0800 Subject: [PATCH 33/99] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/gated_deltanet.py | 2 -- .../hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 437cea0cb..c16a42f3e 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -117,8 +117,6 @@ def __init__( std /= math.sqrt(m_width) self.o_proj = ParameterizedLinear(self.value_dim, hidden_size, bias=False, std=std) - self.reset_parameters() - def forward( self, hidden_states: torch.Tensor, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py 
b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index 544ec4393..8866f93a3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -177,14 +177,13 @@ def __init__( dt_init_floor=dt_init_floor, ) - self.D = nn.Parameter(torch.empty(self.num_heads)) - self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) + self.D = nn.Parameter(torch.empty(self.num_heads)) mark_parameter_as_no_weight_decay(self.D) mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) From 43e099c99aa1bc7767176fe84f68fb8e07d618f1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:13:01 -0800 Subject: [PATCH 34/99] merge Signed-off-by: Mayank Mishra --- .../tensor_parallel/tensor_parallel_forward_test.py | 8 +++++--- tests/hf_models/single_gpu/typecheck_test.py | 12 +++++++++++- tests/hf_models/test_common.py | 4 ---- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 7fc50d93a..1ee3b74fc 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -8,7 +8,7 @@ import torch from parameterized import parameterized -from lm_engine.utils import is_flash_attention_2_available, torch_dtype_to_string +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available, torch_dtype_to_string from ...test_common import TestCommons @@ -17,7 +17,7 @@ class TensorParallelTest(TestCommons): @parameterized.expand( TestCommons.make_args_matrix( TestCommons.get_position_embedding_types(), - TestCommons.get_attention_implementations(), + ["sdpa", "flash_attention_2", "flash_attention_3"], TestCommons.get_dtypes(), [False, True], [False, True], @@ -41,7 +41,9 @@ def test_tensor_parallel_forward( self.skipTest("skipping test since running all takes too long") if attention_implementation == "flash_attention_2" and not is_flash_attention_2_available(): - self.skipTest("skipping test since flash-attn is unavialable") + self.skipTest("skipping test because flash attention 2 is unavailable") + elif attention_implementation == "flash_attention_3" and not is_flash_attention_3_available(): + self.skipTest("skipping test because flash attention 3 is unavailable") if use_padding_free_transformer and attention_implementation != "flash_attention_2": self.skipTest("skipping test since flash attention is needed for padding free transformer") diff --git a/tests/hf_models/single_gpu/typecheck_test.py b/tests/hf_models/single_gpu/typecheck_test.py index ded25bb6a..480c12d32 100644 --- a/tests/hf_models/single_gpu/typecheck_test.py +++ b/tests/hf_models/single_gpu/typecheck_test.py @@ -7,6 +7,7 @@ from lm_engine.enums import Kernel from lm_engine.kernels import enable_kernels +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available from ..test_common import TestCommons @@ -25,5 +26,14 @@ def test_no_attention_mask_flash_attention(self, device: torch.device) -> None: input_ids, _, labels = self.get_dummy_inputs(device, return_list=True) attention_mask = [[1] * len(i) for i in input_ids] - with 
enable_kernels([Kernel.flash_attention_2]): + kernel = None + if is_flash_attention_3_available(): + kernel = Kernel.flash_attention_3 + if is_flash_attention_2_available(): + kernel = Kernel.flash_attention_2 + + if kernel is None: + self.skipTest("skipping test because flash attention 2 or 3 is unavailable") + + with enable_kernels([kernel]): self.assertRaises(AssertionError, model, input_ids=input_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index cbc4152b1..914760684 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -20,10 +20,6 @@ class TestCommons(BaseTestCommons): - @staticmethod - def get_attention_implementations() -> list[str]: - return ["sdpa", "flash_attention_2"] - @staticmethod def get_position_embedding_types() -> list[str]: return ["learned_absolute", "rope"] From d1fd4fb6c06866972d413d42ee0db99c5f7e616f Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:13:44 -0800 Subject: [PATCH 35/99] merge Signed-off-by: Mayank Mishra --- .../multi_gpu/tensor_parallel/tensor_parallel_forward_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 1ee3b74fc..3d6da4f19 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -37,6 +37,7 @@ def test_tensor_parallel_forward( if (attention_implementation, dtype) not in [ ("sdpa", torch.float32), ("flash_attention_2", torch.float16), + ("flash_attention_3", torch.float16), ]: self.skipTest("skipping test since running all takes too long") From 913e08d407ad1b06da7e548ba54cdcbeede3a2f0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:19:56 -0800 Subject: [PATCH 36/99] merge Signed-off-by: Mayank Mishra --- .../tensor_parallel_forward.py | 143 +++++++++--------- 1 file changed, 73 insertions(+), 70 deletions(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 431ebba7d..2698cc8a7 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -62,90 +62,93 @@ ], ) -enable_kernels( - [Kernel.scattermoe] + ([Kernel.flash_attention_2] if args.attention_implementation == "flash_attention_2" else []) -).__enter__() +kernels = [Kernel.scattermoe] +if args.attention_implementation == "flash_attention_2": + kernels.append(Kernel.flash_attention_2) +elif args.attention_implementation == "flash_attention_3": + kernels.append(Kernel.flash_attention_3) -if torch.distributed.get_rank() == 0: - with torch.device("meta"): - model = TestCommons.from_config(None, config) - - model = model.to_empty(device=torch.cuda.current_device()) - for _, param in model.named_parameters(): - param.data.normal_(0, 0.0125) +with enable_kernels(kernels): + if torch.distributed.get_rank() == 0: + with torch.device("meta"): + model = TestCommons.from_config(None, config) - model.eval() + model = model.to_empty(device=torch.cuda.current_device()) + for _, param in model.named_parameters(): + param.data.normal_(0, 0.0125) - model.save_pretrained(args.tmp_path, safe_serialization=True) - model = model.to(dtype) + model.eval() -Communication.barrier() + 
model.save_pretrained(args.tmp_path, safe_serialization=True) + model = model.to(dtype) -# use dummy tensors to avoid initializing model here -with torch.device("meta"): - # try sharding vocab matrices if really struggling for memory + Communication.barrier() - model_tp = get_model_parallel_class(config.model_type)._from_config( - config, - use_padding_free_transformer=args.use_padding_free_transformer, - sequence_parallel=args.sequence_parallel, - ) + # use dummy tensors to avoid initializing model here + with torch.device("meta"): + # try sharding vocab matrices if really struggling for memory -# copy to device without copying storage -model_tp = model_tp.to_empty(device=torch.cuda.current_device()) + model_tp = get_model_parallel_class(config.model_type)._from_config( + config, + use_padding_free_transformer=args.use_padding_free_transformer, + sequence_parallel=args.sequence_parallel, + ) -# load weights into tensor parallel model using SafeTensorsWeightsManager class -# this avoids loading multiple copies of the parameters in CPU memory -model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) + # copy to device without copying storage + model_tp = model_tp.to_empty(device=torch.cuda.current_device()) -# set model to eval mode -model_tp = model_tp.to(dtype) -model_tp.eval() + # load weights into tensor parallel model using SafeTensorsWeightsManager class + # this avoids loading multiple copies of the parameters in CPU memory + model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) -set_seed(42) + # set model to eval mode + model_tp = model_tp.to(dtype) + model_tp.eval() -batch_size = 4 -sequence_length = 512 + set_seed(42) -input_ids = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) -labels = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) + batch_size = 4 + sequence_length = 512 -if args.use_padding_free_transformer: - cu_seqlens = torch.arange( - 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + input_ids = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) - position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) - - output_tp = model_tp( - input_ids=input_ids.view(-1), - labels=labels.view(-1), - cu_seqlens=cu_seqlens, - max_seqlen=sequence_length, - position_ids=position_ids, + labels = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) -else: - output_tp = model_tp(input_ids=input_ids, labels=labels) - -loss_tp = output_tp.loss -logits_tp = output_tp.logits[..., : config.vocab_size] - -if torch.distributed.get_rank() == 0: - # loss computation hangs if we don't use dummy tensor parallel world size - with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): - output = model(input_ids=input_ids, labels=labels) - - loss = output.loss - logits = output.logits if args.use_padding_free_transformer: - logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) - - error = (logits - logits_tp).abs().max() - assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" - - error = (loss - loss_tp).abs().max() - assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" + 
cu_seqlens = torch.arange( + 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + ) + position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) + + output_tp = model_tp( + input_ids=input_ids.view(-1), + labels=labels.view(-1), + cu_seqlens=cu_seqlens, + max_seqlen=sequence_length, + position_ids=position_ids, + ) + else: + output_tp = model_tp(input_ids=input_ids, labels=labels) + + loss_tp = output_tp.loss + logits_tp = output_tp.logits[..., : config.vocab_size] + + if torch.distributed.get_rank() == 0: + # loss computation hangs if we don't use dummy tensor parallel world size + with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): + output = model(input_ids=input_ids, labels=labels) + + loss = output.loss + logits = output.logits + + if args.use_padding_free_transformer: + logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) + + error = (logits - logits_tp).abs().max() + assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" + + error = (loss - loss_tp).abs().max() + assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" From 037185f7842c31208b68c0d1fd4207f2ec50cda1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 18:30:42 -0800 Subject: [PATCH 37/99] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 4dfe572ab..e4d42e081 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -130,8 +130,8 @@ def get_sequence_mixer( m_width=config.m_width, A_init_min=block.A_init_min, A_init_max=block.A_init_max, - dt_min=block.dt_init_min, - dt_max=block.dt_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, From f0312863ba7bb1a7d3236918a27371c0df981cfd Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 19:02:09 -0800 Subject: [PATCH 38/99] merge Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/gated_deltanet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index c16a42f3e..ed44b2808 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -163,7 +163,7 @@ def forward( beta = b.sigmoid() if self.allow_neg_eigval: - beta = beta * 2.0 + beta = beta * 2 g = self.decay_gate(x=a, final_exponential=False) From d7c6eff883d2cbcd900c183227d65d5ed792eb97 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 19:09:24 -0800 Subject: [PATCH 39/99] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 23 -------------------- lm_engine/train_utils.py | 14 ------------ 2 files changed, 37 deletions(-) diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 5b22610dc..4cd8537f8 
100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -26,29 +26,6 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "softmax_attention" -class _MultiHeadLatentAttentionArgs(BaseArgs): - sequence_mixer_type: str = "multihead_latent_attention" - num_attention_heads: int | None = None - softmax_dropout: float = 0 - dropout: float = 0 - add_bias: bool = False - attention_multiplier: float | None = None - sliding_window: int | None = None - query_compression_size: int | None = None - key_value_compression_size: int | None = None - num_attention_heads: int | None = None - head_dim: int | None = None - normalization_function: str = "layernorm" - - def model_post_init(self, __context: Any) -> None: - assert self.sequence_mixer_type == "multihead_latent_attention" - assert self.num_attention_heads is not None - assert self.query_compression_size is not None - assert self.key_value_compression_size is not None - assert self.num_attention_heads is not None - assert self.head_dim is not None - - class _SoftPlusDecayArgs(BaseArgs): A_init_min: float = 0 A_init_max: float = 16 diff --git a/lm_engine/train_utils.py b/lm_engine/train_utils.py index 4016732e0..201799d0d 100644 --- a/lm_engine/train_utils.py +++ b/lm_engine/train_utils.py @@ -138,20 +138,6 @@ def get_model_tflops( b * s, h, h, gradient_checkpointing=gradient_checkpointing_enabled ) - sequence_mixer_flops += _get_attention_flops(b, s, h) - elif sequence_mixer_type == "multihead_latent_attention": - # QKV down and up projection FLOPs - sequence_mixer_flops = 2 * _get_linear_flops( - b * s, - h, - block.query_compression_size + 2 * block.key_value_compression_size, - gradient_checkpointing=gradient_checkpointing_enabled, - ) - # output projection FLOPs - sequence_mixer_flops += _get_linear_flops( - b * s, h, h, gradient_checkpointing=gradient_checkpointing_enabled - ) - sequence_mixer_flops += _get_attention_flops(b, s, h) elif sequence_mixer_type == "mamba2": # NOTE taken from NexaAI's fork (might be incorrect) From 95cf3f9ad51da5ca69fefdf715f6f77d2b655ecc Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:22:24 -0800 Subject: [PATCH 40/99] merge Signed-off-by: Mayank Mishra --- .../modeling_utils_TP/dtensor_module.py | 22 +------------------ 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py b/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py index 761f5e8c1..e4fc599dd 100644 --- a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py +++ b/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py @@ -2,24 +2,4 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from __future__ import annotations - -from typing import Any, Mapping - -import torch.nn as nn - -from ...dtensors import modify_state_dict_to_dtensor_dict - - -class DTensorModule(nn.Module): - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None: - state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix="", strip_keys=False) - super().load_state_dict(state_dict, strict, assign) - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) -> None: - state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix=prefix, strip_keys=True) - super()._load_from_state_dict( - state_dict, prefix, 
local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) +from ..modeling_utils.dtensor_module import DTensorModule From 5aa75c1c9f246c785ec1c5d4aeb9901cf6b75bf1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:23:21 -0800 Subject: [PATCH 41/99] merge Signed-off-by: Mayank Mishra --- lm_engine/dtensors.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py index ff57a4406..edcf55104 100644 --- a/lm_engine/dtensors.py +++ b/lm_engine/dtensors.py @@ -17,7 +17,6 @@ def tensor_to_dtensor( current_placement: Placement | list[Placement], desired_placement: Placement | list[Placement] | None = None, run_check: bool = False, - copy_marker: bool = True, ) -> DTensor: if isinstance(tensor, DTensor): return tensor @@ -33,12 +32,6 @@ def tensor_to_dtensor( dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True) - if copy_marker: - for marker in _ALL_MARKERS: - marker_value = getattr(dtensor, marker, None) - if marker_value is not None: - setattr(dtensor, marker, marker_value) - return dtensor @@ -47,7 +40,6 @@ def dtensor_to_tensor( device_mesh: DeviceMesh | None = None, desired_placement: Placement | list[Placement] | None = None, grad_placement: Placement | list[Placement] | None = None, - copy_marker: bool = True, ) -> torch.Tensor: if not isinstance(dtensor, DTensor): return dtensor @@ -65,12 +57,6 @@ def dtensor_to_tensor( tensor = dtensor.to_local(grad_placements=grad_placement) - if copy_marker: - for marker in _ALL_MARKERS: - marker_value = getattr(tensor, marker, None) - if marker_value is not None: - setattr(tensor, marker, marker_value) - return tensor From c2a767534b8aed5c9f744cf3ba315f0b3198bb19 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:23:36 -0800 Subject: [PATCH 42/99] merge Signed-off-by: Mayank Mishra --- lm_engine/dtensors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py index edcf55104..7944c77d1 100644 --- a/lm_engine/dtensors.py +++ b/lm_engine/dtensors.py @@ -8,8 +8,6 @@ from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import DeviceMesh -from .hf_models.parameter import _ALL_MARKERS - def tensor_to_dtensor( tensor: torch.Tensor, From dcb892f195be29e9e63369f3942969881cb0c864 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:29:21 -0800 Subject: [PATCH 43/99] merge Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/__init__.py | 1 + lm_engine/hf_models/modeling_utils_TP/__init__.py | 1 - .../hf_models/modeling_utils_TP/dtensor_module.py | 5 ----- lm_engine/hf_models/modeling_utils_TP/linear.py | 3 +-- .../hf_models/modeling_utils_TP/mlp_blocks/moe.py | 11 +++++++++-- .../hf_models/modeling_utils_TP/normalization.py | 2 +- 6 files changed, 12 insertions(+), 11 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/dtensor_module.py diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index e9de3c937..97b265adb 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -5,6 +5,7 @@ from .activations import get_activation_function, is_glu from .convolution import ParameterizedConv1d from .dropout import Dropout +from .dtensor_module import DTensorModule from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info from .linear import 
ParameterizedLinear from .mlp_blocks import ( diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index 799042884..30884bbaa 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -2,7 +2,6 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from .dtensor_module import DTensorModule from .linear import ColumnParallelLinear, RowParallelLinear from .lm_head import LMHead_TP from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP diff --git a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py b/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py deleted file mode 100644 index e4fc599dd..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py +++ /dev/null @@ -1,5 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ..modeling_utils.dtensor_module import DTensorModule diff --git a/lm_engine/hf_models/modeling_utils_TP/linear.py b/lm_engine/hf_models/modeling_utils_TP/linear.py index 40be06ee2..0829855a7 100644 --- a/lm_engine/hf_models/modeling_utils_TP/linear.py +++ b/lm_engine/hf_models/modeling_utils_TP/linear.py @@ -10,8 +10,7 @@ from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel from ...utils import ProcessGroupManager, divide_if_divisible -from ..modeling_utils import ParameterizedLinear -from .dtensor_module import DTensorModule +from ..modeling_utils import DTensorModule, ParameterizedLinear from .TP import get_module_placements diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 2af05f021..24e9b2353 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -16,9 +16,16 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import ProcessGroupManager, divide_if_divisible, is_xma_available from ...loss import add_aux_loss -from ...modeling_utils import Dropout, MoE, ParameterizedExperts, ParameterizedLinear, get_activation_function, is_glu +from ...modeling_utils import ( + Dropout, + DTensorModule, + MoE, + ParameterizedExperts, + ParameterizedLinear, + get_activation_function, + is_glu, +) from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..dtensor_module import DTensorModule from ..linear import ColumnParallelLinear, RowParallelLinear diff --git a/lm_engine/hf_models/modeling_utils_TP/normalization.py b/lm_engine/hf_models/modeling_utils_TP/normalization.py index 737255d85..8efe72b54 100644 --- a/lm_engine/hf_models/modeling_utils_TP/normalization.py +++ b/lm_engine/hf_models/modeling_utils_TP/normalization.py @@ -12,7 +12,7 @@ from ...enums import Kernel from ...kernels import is_kernel_allowed, wait_for_ACT from ...utils import ProcessGroupManager, is_xma_available -from .dtensor_module import DTensorModule +from ..modeling_utils import DTensorModule from .TP import get_module_placements From 98986456c1523892b9cb22953989741a9ae043cf Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:30:08 -0800 Subject: [PATCH 44/99] merge Signed-off-by: Mayank Mishra --- tests/hf_models/multi_gpu/dcp/train.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/hf_models/multi_gpu/dcp/train.yml b/tests/hf_models/multi_gpu/dcp/train.yml index 
30acd2e83..a6253e911 100644 --- a/tests/hf_models/multi_gpu/dcp/train.yml +++ b/tests/hf_models/multi_gpu/dcp/train.yml @@ -179,6 +179,7 @@ model_args: - mlp_type: MLP activation_function: swiglu add_bias: false + efficient_initialization: true tuning_args: tuning_method: pretraining From 193c11f53c6abbd9e74be946f13530dddc33a2e1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:47:58 -0800 Subject: [PATCH 45/99] merge Signed-off-by: Mayank Mishra --- lm_engine/dtensors.py | 16 ++++++++++++++++ lm_engine/hf_models/mixins/dense/main.py | 13 ------------- lm_engine/hf_models/parameter.py | 3 +++ 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py index 7944c77d1..ff57a4406 100644 --- a/lm_engine/dtensors.py +++ b/lm_engine/dtensors.py @@ -8,6 +8,8 @@ from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import DeviceMesh +from .hf_models.parameter import _ALL_MARKERS + def tensor_to_dtensor( tensor: torch.Tensor, @@ -15,6 +17,7 @@ def tensor_to_dtensor( current_placement: Placement | list[Placement], desired_placement: Placement | list[Placement] | None = None, run_check: bool = False, + copy_marker: bool = True, ) -> DTensor: if isinstance(tensor, DTensor): return tensor @@ -30,6 +33,12 @@ def tensor_to_dtensor( dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True) + if copy_marker: + for marker in _ALL_MARKERS: + marker_value = getattr(dtensor, marker, None) + if marker_value is not None: + setattr(dtensor, marker, marker_value) + return dtensor @@ -38,6 +47,7 @@ def dtensor_to_tensor( device_mesh: DeviceMesh | None = None, desired_placement: Placement | list[Placement] | None = None, grad_placement: Placement | list[Placement] | None = None, + copy_marker: bool = True, ) -> torch.Tensor: if not isinstance(dtensor, DTensor): return dtensor @@ -55,6 +65,12 @@ def dtensor_to_tensor( tensor = dtensor.to_local(grad_placements=grad_placement) + if copy_marker: + for marker in _ALL_MARKERS: + marker_value = getattr(tensor, marker, None) + if marker_value is not None: + setattr(tensor, marker, marker_value) + return tensor diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 4ec016803..550a41e32 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -37,19 +37,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.m_width = config.m_width - def get_input_embeddings(self) -> ParameterizedEmbedding: - return self.transformer.wte - - def set_input_embeddings(self, value: ParameterizedEmbedding) -> None: - self.transformer.wte = value - - def get_output_embeddings(self) -> ParameterizedLinear: - return self.transformer.wte if self._tied_word_embeddings else self.lm_head - - def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: - if not self._tied_word_embeddings: - self.lm_head = new_embeddings - def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 4ab6e6834..b86da83b2 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -5,6 +5,9 @@ import torch.nn as nn +_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"] + + def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: if parameter is not 
None: parameter._no_weight_decay = True From 843f57d4c032daee7e1899c38dbca63307e239d5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Sun, 8 Feb 2026 22:57:35 -0800 Subject: [PATCH 46/99] count correctly Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 2 +- lm_engine/model_wrapper/base.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 550a41e32..98e0a1303 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -13,7 +13,7 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero -from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...modeling_utils import ParameterizedLinear from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin diff --git a/lm_engine/model_wrapper/base.py b/lm_engine/model_wrapper/base.py index 2fbc5841c..3284b8271 100644 --- a/lm_engine/model_wrapper/base.py +++ b/lm_engine/model_wrapper/base.py @@ -205,7 +205,11 @@ def _setup_model(self) -> None: def calculate_num_parameters(self) -> tuple[int, int]: model_kwargs = self._get_model_kwargs() - with torch.device("meta"): + with ( + torch.device("meta"), + ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ): if self.model_name is not None: model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs.pop("pretrained_model_name_or_path")) From 40efdca1c77d522a5493aec8c3b9254b1a3c8793 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 00:39:08 -0800 Subject: [PATCH 47/99] count correctly Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 53 ++++++++++++-------------------- lm_engine/hf_models/__init__.py | 2 ++ lm_engine/hf_models/parameter.py | 30 ++++++++++++++++++ 3 files changed, 51 insertions(+), 34 deletions(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 8fbd19d9c..15c0a456d 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -30,7 +30,12 @@ from .containers import ModelContainer from .enums import Kernel from .gradient_checkpointing import apply_gradient_checkpointing -from .hf_models import CausalLMOutputWithPast, is_parameter_initialized +from .hf_models import ( + CausalLMOutputWithPast, + get_parameter_marker_maps, + is_parameter_initialized, + set_parameter_marker_maps, +) from .kernels import is_kernel_allowed from .utils import ( Accelerator, @@ -119,36 +124,6 @@ def _get_fsdp_mixed_precision( return mixed_precision -def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]: - marker_maps = [] - for model in model_container: - marker_maps.append({}) - for param_name, param in model.named_parameters(): - marker_maps[-1][param_name] = {} - for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers: - marker_maps[-1][param_name][marker] = getattr(param, marker, False) - - return marker_maps - - -def _set_parameter_marker_maps(model_container: ModelContainer, marker_maps: list[dict]) -> None: - for model, _marker_map in zip(model_container, marker_maps): - for param_name, parameter in model.named_parameters(): - # handle FSDP for TPU - param_name = param_name.replace(_FSDP_TPU_SHARD_SEPARATOR, 
".") - param_name = param_name.replace(f"{_FSDP_TPU_SHARD}.", "") - param_name = param_name.replace(f"{_FSDP_TPU_FPW}.", "") - - # handle FSDP-1 - param_name = param_name.replace(f"{_FSDP_1_STRING}.", "") - - # handle torch compile - param_name = param_name.replace(f"{_TORCH_COMPILE_STRING}.", "") - - for marker, value in _marker_map[param_name].items(): - setattr(parameter, marker, value) - - def wrap_model_container_for_distributed_training( args: TrainingArgs, model_container: ModelContainer ) -> tuple[ModelContainer, _PipelineSchedule]: @@ -229,9 +204,9 @@ def wrap_model_container_for_distributed_training( for param_name, parameter in model.named_buffers(): parameter._is_initialized = False - marker_maps = _get_parameter_marker_maps(model_container) + marker_maps = get_parameter_marker_maps(model_container) else: - marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"]) + marker_maps = get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"]) accelerator = Accelerator.get_accelerator() @@ -387,7 +362,17 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: for i, model in enumerate(model_container): model_container[i] = torch.compile(model) - _set_parameter_marker_maps(model_container, marker_maps) + set_parameter_marker_maps( + model_container, + marker_maps, + replacement_patterns=[ + (_FSDP_TPU_SHARD_SEPARATOR, "."), + (f"{_FSDP_TPU_SHARD}.", ""), + (f"{_FSDP_TPU_FPW}.", ""), + (f"{_FSDP_1_STRING}.", ""), + (f"{_TORCH_COMPILE_STRING}.", ""), + ], + ) pipeline_stages = [] pipeline_schedule = None diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index 75341b75b..b84448623 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -23,12 +23,14 @@ PaLMModel, ) from .parameter import ( + get_parameter_marker_maps, is_parameter_initialized, is_parameter_with_mup_learning_rate, is_parameter_with_no_weight_decay, mark_parameter_as_initialized, mark_parameter_as_mup_learning_rate, mark_parameter_as_no_weight_decay, + set_parameter_marker_maps, ) from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index b86da83b2..471dccd16 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -39,3 +39,33 @@ def is_parameter_with_mup_learning_rate(parameter: nn.Parameter | None) -> bool: def is_parameter_initialized(parameter: nn.Parameter | None) -> bool: return getattr(parameter, "_is_initialized", False) + + +def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: list[str] = []) -> list[dict]: + if isinstance(model_container, nn.Module): + model_container = [model_container] + + marker_maps = [] + for model in model_container: + marker_maps.append({}) + for param_name, param in model.named_parameters(): + marker_maps[-1][param_name] = {} + for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers: + marker_maps[-1][param_name][marker] = getattr(param, marker, False) + + return marker_maps + + +def set_parameter_marker_maps( + model_container: list[nn.Module], marker_maps: list[dict], replacement_patterns: list[tuple[str]] = [] +) -> None: + if isinstance(model_container, nn.Module): + model_container = [model_container] + + for model, _marker_map in zip(model_container, marker_maps): + for param_name, parameter 
in model.named_parameters(): + for pattern, replacement in replacement_patterns: + param_name = param_name.replace(pattern, replacement) + + for marker, value in _marker_map[param_name].items(): + setattr(parameter, marker, value) From b049fde585628f310e7c273461e7da0ea6bb377a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 00:47:00 -0800 Subject: [PATCH 48/99] count correctly Signed-off-by: Mayank Mishra --- lm_engine/distributed.py | 3 ++- lm_engine/hf_models/__init__.py | 1 + lm_engine/hf_models/mixins/dense_TP/main.py | 4 ++++ lm_engine/hf_models/parameter.py | 8 +++++--- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 15c0a456d..5e162231f 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -31,6 +31,7 @@ from .enums import Kernel from .gradient_checkpointing import apply_gradient_checkpointing from .hf_models import ( + _INIT_MARKER, CausalLMOutputWithPast, get_parameter_marker_maps, is_parameter_initialized, @@ -206,7 +207,7 @@ def wrap_model_container_for_distributed_training( marker_maps = get_parameter_marker_maps(model_container) else: - marker_maps = get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"]) + marker_maps = get_parameter_marker_maps(model_container, extra_markers=[_INIT_MARKER]) accelerator = Accelerator.get_accelerator() diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py index b84448623..cd306993e 100644 --- a/lm_engine/hf_models/__init__.py +++ b/lm_engine/hf_models/__init__.py @@ -23,6 +23,7 @@ PaLMModel, ) from .parameter import ( + _INIT_MARKER, get_parameter_marker_maps, is_parameter_initialized, is_parameter_with_mup_learning_rate, diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index f614b1cc0..6af566025 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -21,6 +21,7 @@ is_aux_loss_zero, ) from ...modeling_utils_TP import LMHead_TP +from ...parameter import _INIT_MARKER, get_parameter_marker_maps, set_parameter_marker_maps from ..dense import CausalLMModelMixin from ..modeling_outputs import ( BaseModelOutputWithPast, @@ -186,7 +187,10 @@ def from_pretrained( with torch.device("meta"): # try sharding vocab matrices if really struggling for memory model = cls._from_config(config, **kwargs) + marker_maps = get_parameter_marker_maps([model], extra_markers=[_INIT_MARKER]) + model = model.to(dtype=dtype) + set_parameter_marker_maps([model], marker_maps) # copy to device without copying storage model = model.to_empty(device=torch.cuda.current_device()) diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 471dccd16..33427b9b3 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -5,7 +5,9 @@ import torch.nn as nn -_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"] +_INIT_MARKER = "_is_initialized" +_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"] +_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER] def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: @@ -38,7 +40,7 @@ def is_parameter_with_mup_learning_rate(parameter: nn.Parameter | None) -> bool: def is_parameter_initialized(parameter: nn.Parameter | None) -> bool: - return getattr(parameter, "_is_initialized", False) + return getattr(parameter, _INIT_MARKER, False) def 
get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: list[str] = []) -> list[dict]: @@ -50,7 +52,7 @@ def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: l marker_maps.append({}) for param_name, param in model.named_parameters(): marker_maps[-1][param_name] = {} - for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers: + for marker in _METADATA_MARKERS + extra_markers: marker_maps[-1][param_name][marker] = getattr(param, marker, False) return marker_maps From 137ca9d9551fd2849680a98302fa24f278d54117 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 00:52:45 -0800 Subject: [PATCH 49/99] count correctly Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense_TP/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 6af566025..0a03389d2 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -190,10 +190,11 @@ def from_pretrained( marker_maps = get_parameter_marker_maps([model], extra_markers=[_INIT_MARKER]) model = model.to(dtype=dtype) - set_parameter_marker_maps([model], marker_maps) # copy to device without copying storage model = model.to_empty(device=torch.cuda.current_device()) + set_parameter_marker_maps([model], marker_maps) + model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) return model From 79a69ddfd201970bde2250181b30886315cdd6c4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:25:44 -0800 Subject: [PATCH 50/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense_TP/base.py | 5 +- lm_engine/hf_models/mixins/dense_TP/layer.py | 7 +- .../hf_models/modeling_utils/normalization.py | 112 ++++++++++++++-- .../hf_models/modeling_utils_TP/__init__.py | 1 - .../modeling_utils_TP/normalization.py | 121 ------------------ 5 files changed, 107 insertions(+), 139 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/normalization.py diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index 28bedc56b..ddd504a75 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -10,8 +10,7 @@ from ....utils import ProcessGroupManager, divide_if_divisible from ...cache import GenerationCache from ...config import CommonConfig -from ...modeling_utils import Dropout, ParameterizedEmbedding, RoPE, YaRNScaledRoPE -from ...modeling_utils_TP import get_normalization_function_TP +from ...modeling_utils import Dropout, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function from ...utils import is_generation_cache_enabled from ..dense import BaseModelMixin, PreTrainedModelMixin from ..modeling_outputs import BaseModelOutputWithPast @@ -81,7 +80,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: ) if self.is_last_stage: - self.ln_f = get_normalization_function_TP( + self.ln_f = get_normalization_function( config.normalization_function, self.embed_dim, eps=config.layer_norm_epsilon, diff --git a/lm_engine/hf_models/mixins/dense_TP/layer.py b/lm_engine/hf_models/mixins/dense_TP/layer.py index d09109ce3..dae1fb966 100644 --- a/lm_engine/hf_models/mixins/dense_TP/layer.py +++ b/lm_engine/hf_models/mixins/dense_TP/layer.py @@ -7,7 +7,8 @@ import torch.nn as nn from ...config import CommonConfig -from 
...modeling_utils_TP import get_mlp_block_TP, get_normalization_function_TP, get_sequence_mixer_TP +from ...modeling_utils import get_normalization_function +from ...modeling_utils_TP import get_mlp_block_TP, get_sequence_mixer_TP from ..dense import Block @@ -25,7 +26,7 @@ def __init__( self.m_residual = config.m_residual self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type - self.ln_1 = get_normalization_function_TP( + self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon, @@ -39,7 +40,7 @@ def __init__( layer_idx=layer_idx, sequence_parallel=sequence_parallel, ) - self.ln_2 = get_normalization_function_TP( + self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon, diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py index c518577c2..2e7e93e8b 100644 --- a/lm_engine/hf_models/modeling_utils/normalization.py +++ b/lm_engine/hf_models/modeling_utils/normalization.py @@ -7,35 +7,102 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed._tensor.placement_types import Replicate +from ...dtensors import dtensor_to_tensor, tensor_to_dtensor from ...enums import Kernel -from ...kernels import is_kernel_allowed -from ...utils import is_xma_available +from ...kernels import is_kernel_allowed, wait_for_ACT +from ...utils import ProcessGroupManager, is_xma_available from ..parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay +from .dtensor_module import DTensorModule +from .TP import get_module_placements if is_xma_available(): from xma import rmsnorm -class LayerNorm(nn.LayerNorm): +class LayerNorm(nn.LayerNorm, DTensorModule): + def __init__( + self, + normalized_shape: int, + eps: float = 1e-6, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> LayerNorm: + super().__init__(normalized_shape, eps=eps) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + tensor_to_dtensor(self.weight, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + + self.bias = nn.Parameter( + tensor_to_dtensor(self.bias, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + + x = super().forward(x) + + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + + return x + def reset_parameters(self) -> None: super().reset_parameters() mark_parameter_as_initialized(self.weight) -class RMSNorm(nn.RMSNorm): +class RMSNorm(nn.RMSNorm, DTensorModule): + def __init__( + self, + normalized_shape: int, + eps: float = 1e-6, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> RMSNorm: + super().__init__(normalized_shape, eps=eps) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + 
tensor_to_dtensor(self.weight, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + if is_kernel_allowed(Kernel.rmsnorm) or is_kernel_allowed(Kernel.rmsnorm_memory_efficient): + x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) + x = rmsnorm( x=x, weight=self.weight, eps=self.eps, memory_efficient=is_kernel_allowed(Kernel.rmsnorm_memory_efficient), ) + + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) else: x = super().forward(x) + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + return x def reset_parameters(self) -> None: @@ -49,14 +116,22 @@ def __init__( normalized_shape: int, p: int, eps: float | None = None, - elementwise_affine=True, - device: torch.device = None, - dtype: torch.dtype = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> PNorm: self.p = p - super().__init__(normalized_shape, eps, elementwise_affine, device, dtype) + + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + dtype = x.dtype x = x.float() @@ -66,6 +141,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.weight is not None: x = self.weight * x + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + return x @@ -73,18 +151,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def get_normalization_function( - normalization_function: str, normalized_shape: int, eps: float = 1e-5, p: int | None = None + normalization_function: str, + normalized_shape: int, + eps: float = 1e-5, + p: int | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> LayerNorm | RMSNorm | PNorm: if normalization_function is None: return nn.Identity() + kwargs = { + "normalized_shape": normalized_shape, + "eps": eps, + "use_padding_free_transformer": use_padding_free_transformer, + "sequence_parallel": sequence_parallel, + } + if normalization_function in _NORMALIZATION_FUNCTIONS: if normalization_function == "p_norm": assert p is not None - normalization = _NORMALIZATION_FUNCTIONS[normalization_function](normalized_shape, eps=eps, p=p) + normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p) else: assert p is None - normalization = _NORMALIZATION_FUNCTIONS[normalization_function](normalized_shape, eps=eps) + normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs) else: raise ValueError(f"unexpected `normalization_function` {normalization_function}") diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index 30884bbaa..634ff1bcf 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -5,6 +5,5 @@ from .linear import ColumnParallelLinear, RowParallelLinear from .lm_head import LMHead_TP from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP -from .normalization import get_normalization_function_TP from .sequence_mixer_blocks import Attention_TP, 
get_sequence_mixer_TP from .TP import get_module_placements, tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils_TP/normalization.py b/lm_engine/hf_models/modeling_utils_TP/normalization.py deleted file mode 100644 index 8efe72b54..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/normalization.py +++ /dev/null @@ -1,121 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch -import torch.nn as nn -from torch.distributed._tensor.placement_types import Partial, Replicate - -from ...dtensors import dtensor_to_tensor, tensor_to_dtensor -from ...enums import Kernel -from ...kernels import is_kernel_allowed, wait_for_ACT -from ...utils import ProcessGroupManager, is_xma_available -from ..modeling_utils import DTensorModule -from .TP import get_module_placements - - -if is_xma_available(): - from xma import rmsnorm - - -class LayerNorm_TP(nn.LayerNorm, DTensorModule): - def __init__( - self, - normalized_shape: int, - eps: float = 1e-6, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> LayerNorm_TP: - super().__init__(normalized_shape, eps=eps) - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - current_placement=Replicate(), - ) - ) - self.bias = nn.Parameter( - tensor_to_dtensor( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=self.placement) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.placement) - return input - - -class RMSNorm_TP(nn.RMSNorm, DTensorModule): - def __init__( - self, - normalized_shape: int, - eps: float = 1e-6, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> RMSNorm_TP: - super().__init__(normalized_shape, eps=eps) - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - self.sequence_parallel = sequence_parallel - self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - rmsnorm_kernel_allowed = is_kernel_allowed(Kernel.rmsnorm) - rmsnorm_memory_efficient_kernel_allowed = is_kernel_allowed(Kernel.rmsnorm_memory_efficient) - - if rmsnorm_kernel_allowed or rmsnorm_memory_efficient_kernel_allowed: - input = wait_for_ACT(input, wait_in_forward=True, wait_in_backward=False) - input = rmsnorm( - x=input, - weight=dtensor_to_tensor( - self.weight, grad_placement=Partial() if self.sequence_parallel else Replicate() - ), - eps=self.eps, - memory_efficient=rmsnorm_memory_efficient_kernel_allowed, - ) - input = wait_for_ACT(input, wait_in_forward=False, wait_in_backward=True) - else: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=self.placement) - input = super().forward(input) - input = dtensor_to_tensor(input, 
device_mesh=self.tp_mesh, desired_placement=self.placement) - - return input - - -_NORMALIZATION_FUNCTIONS = {"layernorm": LayerNorm_TP, "rmsnorm": RMSNorm_TP} - - -def get_normalization_function_TP( - normalization_function: str, - normalized_shape: int, - eps: float = 1e-5, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, -) -> LayerNorm_TP | RMSNorm_TP: - if normalization_function in _NORMALIZATION_FUNCTIONS: - normalization = _NORMALIZATION_FUNCTIONS[normalization_function]( - normalized_shape, - eps=eps, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - else: - raise ValueError(f"unexpected `normalization_function` {normalization_function}") - - return normalization From 554ebe3e60b8034fc64f6689ca3a92ebfc578d65 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:27:08 -0800 Subject: [PATCH 51/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/normalization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py index 2e7e93e8b..f59424cf5 100644 --- a/lm_engine/hf_models/modeling_utils/normalization.py +++ b/lm_engine/hf_models/modeling_utils/normalization.py @@ -60,6 +60,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def reset_parameters(self) -> None: super().reset_parameters() mark_parameter_as_initialized(self.weight) + mark_parameter_as_initialized(self.bias) class RMSNorm(nn.RMSNorm, DTensorModule): From 60eadaa30e134f7a05549e2ad48f3a695428e363 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:32:48 -0800 Subject: [PATCH 52/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py index fe8c8470e..f207749da 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py @@ -53,7 +53,5 @@ def __init__( ) self.dropout = Dropout( - dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, + dropout, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel ) From 6394358677033aec23eb9440a6b0e68d29ec13e4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:36:17 -0800 Subject: [PATCH 53/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/normalization.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py index f59424cf5..5f7ecdd7f 100644 --- a/lm_engine/hf_models/modeling_utils/normalization.py +++ b/lm_engine/hf_models/modeling_utils/normalization.py @@ -46,6 +46,8 @@ def __init__( tensor_to_dtensor(self.bias, device_mesh=self.tp_mesh, current_placement=Replicate()) ) + self.reset_parameters() + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) @@ -83,6 +85,8 @@ def __init__( tensor_to_dtensor(self.weight, device_mesh=self.tp_mesh, current_placement=Replicate()) ) + self.reset_parameters() + def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, 
device_mesh=self.tp_mesh, current_placement=self.placement) From 0e29579535d93942f2390fb0f16ca1a58c206d80 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:38:58 -0800 Subject: [PATCH 54/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/linear.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lm_engine/hf_models/modeling_utils_TP/linear.py b/lm_engine/hf_models/modeling_utils_TP/linear.py index 0829855a7..bec0efed7 100644 --- a/lm_engine/hf_models/modeling_utils_TP/linear.py +++ b/lm_engine/hf_models/modeling_utils_TP/linear.py @@ -58,6 +58,8 @@ def __init__( self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + self.reset_parameters() + if use_async_tensor_parallel(): self.compile() @@ -121,6 +123,8 @@ def __init__( self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + self.reset_parameters() + if use_async_tensor_parallel(): self.compile() From 4420224b75946d56f5e44aa2c67eb7dbd31175e2 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:43:37 -0800 Subject: [PATCH 55/99] norm Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/__init__.py | 2 +- .../modeling_utils/linear/__init__.py | 7 ++ .../{linear.py => linear/base.py} | 2 +- .../linear/column.py} | 72 ++--------------- .../hf_models/modeling_utils/linear/row.py | 78 +++++++++++++++++++ .../hf_models/modeling_utils_TP/__init__.py | 1 - .../modeling_utils_TP/mlp_blocks/mlp.py | 3 +- .../modeling_utils_TP/mlp_blocks/moe.py | 3 +- .../sequence_mixer_blocks/attention.py | 10 ++- 9 files changed, 103 insertions(+), 75 deletions(-) create mode 100644 lm_engine/hf_models/modeling_utils/linear/__init__.py rename lm_engine/hf_models/modeling_utils/{linear.py => linear/base.py} (92%) rename lm_engine/hf_models/{modeling_utils_TP/linear.py => modeling_utils/linear/column.py} (50%) create mode 100644 lm_engine/hf_models/modeling_utils/linear/row.py diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index 97b265adb..b615c6101 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -7,7 +7,7 @@ from .dropout import Dropout from .dtensor_module import DTensorModule from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info -from .linear import ParameterizedLinear +from .linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear from .mlp_blocks import ( MLP, MoE, diff --git a/lm_engine/hf_models/modeling_utils/linear/__init__.py b/lm_engine/hf_models/modeling_utils/linear/__init__.py new file mode 100644 index 000000000..b4a6ac5e0 --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/__init__.py @@ -0,0 +1,7 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from .base import ParameterizedLinear +from .column import ColumnParallelLinear +from .row import RowParallelLinear diff --git a/lm_engine/hf_models/modeling_utils/linear.py b/lm_engine/hf_models/modeling_utils/linear/base.py similarity index 92% rename from lm_engine/hf_models/modeling_utils/linear.py rename to lm_engine/hf_models/modeling_utils/linear/base.py index 9ac15302a..f23847943 100644 --- a/lm_engine/hf_models/modeling_utils/linear.py +++ b/lm_engine/hf_models/modeling_utils/linear/base.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn -from ..parameter 
import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay +from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay class ParameterizedLinear(nn.Linear): diff --git a/lm_engine/hf_models/modeling_utils_TP/linear.py b/lm_engine/hf_models/modeling_utils/linear/column.py similarity index 50% rename from lm_engine/hf_models/modeling_utils_TP/linear.py rename to lm_engine/hf_models/modeling_utils/linear/column.py index bec0efed7..a08a4ef96 100644 --- a/lm_engine/hf_models/modeling_utils_TP/linear.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -8,10 +8,11 @@ import torch.nn as nn from torch.distributed._tensor.placement_types import Replicate, Shard -from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel -from ...utils import ProcessGroupManager, divide_if_divisible -from ..modeling_utils import DTensorModule, ParameterizedLinear -from .TP import get_module_placements +from ....dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel +from ....utils import ProcessGroupManager, divide_if_divisible +from ..dtensor_module import DTensorModule +from ..TP import get_module_placements +from .base import ParameterizedLinear class ColumnParallelLinear(ParameterizedLinear, DTensorModule): @@ -75,66 +76,3 @@ def extra_repr(self) -> str: return "in_features={}, out_features_per_device={}, bias={}".format( self.in_features, self.out_features_per_device, self.bias is not None ) - - -class RowParallelLinear(ParameterizedLinear, DTensorModule): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> RowParallelLinear: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.in_features_per_device = divide_if_divisible( - in_features, - tp_world_size, - f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - super().__init__( - in_features=self.in_features_per_device, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) - ) - ) - if bias: - self.bias = nn.Parameter( - tensor_to_dtensor( - self.bias, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - current_placement=Replicate(), - ) - ) - - self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - self.reset_parameters() - - if use_async_tensor_parallel(): - self.compile() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement) - return input - - def extra_repr(self) -> str: - return "in_features_per_device={}, out_features={}, bias={}".format( - self.in_features_per_device, self.out_features, self.bias is not None - ) diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py new file mode 100644 index 000000000..7cadf88a6 --- /dev/null +++ 
b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -0,0 +1,78 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributed._tensor.placement_types import Replicate, Shard + +from ....dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel +from ....utils import ProcessGroupManager, divide_if_divisible +from ..dtensor_module import DTensorModule +from ..TP import get_module_placements +from .base import ParameterizedLinear + + +class RowParallelLinear(ParameterizedLinear, DTensorModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + std: float | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> RowParallelLinear: + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + + self.in_features_per_device = divide_if_divisible( + in_features, + tp_world_size, + f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + super().__init__( + in_features=self.in_features_per_device, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + std=std, + ) + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) + ) + ) + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.reset_parameters() + + if use_async_tensor_parallel(): + self.compile() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + input = super().forward(input) + input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + return input + + def extra_repr(self) -> str: + return "in_features_per_device={}, out_features={}, bias={}".format( + self.in_features_per_device, self.out_features, self.bias is not None + ) diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index 634ff1bcf..b75386d3b 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -2,7 +2,6 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from .linear import ColumnParallelLinear, RowParallelLinear from .lm_head import LMHead_TP from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py index f207749da..ce8091ab8 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py @@ -8,9 +8,8 @@ import torch.nn as nn -from ...modeling_utils import MLP, Dropout, get_activation_function, is_glu +from ...modeling_utils import MLP, ColumnParallelLinear, Dropout, 
RowParallelLinear, get_activation_function, is_glu from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..linear import ColumnParallelLinear, RowParallelLinear class MLP_TP(MLP): diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 24e9b2353..0dcbffd8b 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -17,16 +17,17 @@ from ....utils import ProcessGroupManager, divide_if_divisible, is_xma_available from ...loss import add_aux_loss from ...modeling_utils import ( + ColumnParallelLinear, Dropout, DTensorModule, MoE, ParameterizedExperts, ParameterizedLinear, + RowParallelLinear, get_activation_function, is_glu, ) from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..linear import ColumnParallelLinear, RowParallelLinear if is_xma_available(): diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py index 8e7851b70..1b0944f94 100644 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py @@ -14,9 +14,15 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import ProcessGroupManager, divide_if_divisible from ...cache import GenerationCache -from ...modeling_utils import Attention, Dropout, apply_rotary_pos_emb, flash_attention +from ...modeling_utils import ( + Attention, + ColumnParallelLinear, + Dropout, + RowParallelLinear, + apply_rotary_pos_emb, + flash_attention, +) from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..linear import ColumnParallelLinear, RowParallelLinear class Attention_TP(Attention): From 462acb1fdc926cf725a4540fce80f1051dd2dbff Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:45:47 -0800 Subject: [PATCH 56/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/linear/column.py | 12 ++++++------ lm_engine/hf_models/modeling_utils/linear/row.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py index a08a4ef96..2003de0c4 100644 --- a/lm_engine/hf_models/modeling_utils/linear/column.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -64,13 +64,13 @@ def __init__( if use_async_tensor_parallel(): self.compile() - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor( - input, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = tensor_to_dtensor( + x, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() ) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) - return input + x = super().forward(x) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) + return x def extra_repr(self) -> str: return "in_features={}, out_features_per_device={}, bias={}".format( diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py index 7cadf88a6..d91015cdb 100644 --- a/lm_engine/hf_models/modeling_utils/linear/row.py +++ 
b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -66,11 +66,11 @@ def __init__( if use_async_tensor_parallel(): self.compile() - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement) - return input + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + x = super().forward(x) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + return x def extra_repr(self) -> str: return "in_features_per_device={}, out_features={}, bias={}".format( From ec473cb45f1d36bdb00ff686ea75fbdc69452539 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:51:08 -0800 Subject: [PATCH 57/99] norm Signed-off-by: Mayank Mishra --- tests/hf_models/multi_gpu/dcp/dcp.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/hf_models/multi_gpu/dcp/dcp.py b/tests/hf_models/multi_gpu/dcp/dcp.py index 81787a9dd..fcfd7d6c1 100644 --- a/tests/hf_models/multi_gpu/dcp/dcp.py +++ b/tests/hf_models/multi_gpu/dcp/dcp.py @@ -94,9 +94,14 @@ Communication.barrier() -_, _, consolidated_state_dict = load_checkpoint_and_unshard(unshard_config) - if global_rank == 0: + with ( + ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), + ): + _, _, consolidated_state_dict = load_checkpoint_and_unshard(unshard_config) original_state_dict = model_container[0].state_dict() assert consolidated_state_dict.keys() == original_state_dict.keys() From 4a1710c2feb6ad82e3d1c3dd652919e56df7a41a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:53:59 -0800 Subject: [PATCH 58/99] norm Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/linear/column.py | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py index 2003de0c4..19ef99dc0 100644 --- a/lm_engine/hf_models/modeling_utils/linear/column.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -28,7 +28,6 @@ def __init__( sequence_parallel: bool = False, ) -> ColumnParallelLinear: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.out_features_per_device = divide_if_divisible( out_features, @@ -45,31 +44,43 @@ def __init__( std=std, ) - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) - ) - ) - if bias: - self.bias = nn.Parameter( + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( tensor_to_dtensor( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) ) ) - self.input_placement = 
get_module_placements(use_padding_free_transformer, sequence_parallel) + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Shard(0), + ) + ) - self.reset_parameters() + if use_async_tensor_parallel(): + self.compile() - if use_async_tensor_parallel(): - self.compile() + self.reset_parameters() def forward(self, x: torch.Tensor) -> torch.Tensor: - x = tensor_to_dtensor( - x, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() - ) + if self.is_tp_enabled: + x = tensor_to_dtensor( + x, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() + ) + x = super().forward(x) - x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) + + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) + return x def extra_repr(self) -> str: From 3297ed3b8ba2520d85697cc3a9794e77d28c3f20 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 01:57:02 -0800 Subject: [PATCH 59/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/linear/column.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py index 19ef99dc0..94e23aa5b 100644 --- a/lm_engine/hf_models/modeling_utils/linear/column.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -29,7 +29,7 @@ def __init__( ) -> ColumnParallelLinear: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.out_features_per_device = divide_if_divisible( + self.out_features_per_tp_rank = divide_if_divisible( out_features, tp_world_size, f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", @@ -37,7 +37,7 @@ def __init__( super().__init__( in_features=in_features, - out_features=self.out_features_per_device, + out_features=self.out_features_per_tp_rank, bias=bias, device=device, dtype=dtype, @@ -84,6 +84,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x def extra_repr(self) -> str: - return "in_features={}, out_features_per_device={}, bias={}".format( - self.in_features, self.out_features_per_device, self.bias is not None + return "in_features={}, out_features_per_tp_rank={}, bias={}".format( + self.in_features, self.out_features_per_tp_rank, self.bias is not None ) From 93953c6a88862d638a8d272837b9a940a28468a4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:04:22 -0800 Subject: [PATCH 60/99] norm Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/linear/row.py | 51 ++++++++++------- .../modeling_utils/mlp_blocks/mlp.py | 21 +++++-- .../modeling_utils_TP/mlp_blocks/__init__.py | 6 +- .../modeling_utils_TP/mlp_blocks/mlp.py | 56 ------------------- 4 files changed, 49 insertions(+), 85 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py index d91015cdb..f58b79748 100644 --- a/lm_engine/hf_models/modeling_utils/linear/row.py +++ b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -28,16 +28,15 @@ def __init__( sequence_parallel: bool = False, ) -> RowParallelLinear: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.tp_mesh = 
ProcessGroupManager.get_tensor_parallel_mesh() - self.in_features_per_device = divide_if_divisible( + self.in_features_per_tp_rank = divide_if_divisible( in_features, tp_world_size, f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) super().__init__( - in_features=self.in_features_per_device, + in_features=self.in_features_per_tp_rank, out_features=out_features, bias=bias, device=device, @@ -45,34 +44,44 @@ def __init__( std=std, ) - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) - ) - ) - if bias: - self.bias = nn.Parameter( + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( tensor_to_dtensor( - self.bias, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - current_placement=Replicate(), + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) ) ) - self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + if use_async_tensor_parallel(): + self.compile() self.reset_parameters() - if use_async_tensor_parallel(): - self.compile() - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + x = super().forward(x) - x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + return x def extra_repr(self) -> str: - return "in_features_per_device={}, out_features={}, bias={}".format( - self.in_features_per_device, self.out_features, self.bias is not None + return "in_features_per_tp_rank={}, out_features={}, bias={}".format( + self.in_features_per_tp_rank, self.out_features, self.bias is not None ) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index c95747da9..c6a2c6c12 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -12,7 +12,7 @@ from ...parameter import mark_parameter_as_mup_learning_rate from ..activations import get_activation_function, is_glu from ..dropout import Dropout -from ..linear import ParameterizedLinear +from ..linear import ColumnParallelLinear, RowParallelLinear class MLP(nn.Module): @@ -27,25 +27,36 @@ def __init__( initializer_range: float, m_width: float, num_layers: int, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> MLP: super().__init__() std = _get_std_for_linear(initializer_range, init_method, m_width) - self.c_fc = ParameterizedLinear( + self.c_fc = ColumnParallelLinear( hidden_size, 2 * intermediate_size if is_glu(activation_function) else intermediate_size, bias=add_bias, std=std, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) 
self.act = get_activation_function(activation_function) - self.c_proj = ParameterizedLinear( - intermediate_size, hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=add_bias, + std=std / math.sqrt(2 * num_layers), + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) - self.dropout = Dropout(dropout) + self.dropout = Dropout( + dropout, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel + ) mark_parameter_as_mup_learning_rate(self.c_fc.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py index b74ba8585..cc96471c6 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py @@ -3,13 +3,13 @@ # ************************************************** from ...config import CommonConfig -from .mlp import MLP_TP +from ...modeling_utils import MLP from .moe import MoE_TP def get_mlp_block_TP( config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int -) -> MLP_TP | MoE_TP: +) -> MLP | MoE_TP: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -28,7 +28,7 @@ def get_mlp_block_TP( ) if mlp_type == "MLP": - mlp = MLP_TP(**kwargs) + mlp = MLP(**kwargs) elif mlp_type == "MoE": mlp = MoE_TP( **kwargs, diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py deleted file mode 100644 index ce8091ab8..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py +++ /dev/null @@ -1,56 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch.nn as nn - -from ...modeling_utils import MLP, ColumnParallelLinear, Dropout, RowParallelLinear, get_activation_function, is_glu -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear - - -class MLP_TP(MLP): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - activation_function: str, - add_bias: bool, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> MLP_TP: - nn.Module.__init__(self) - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - self.c_fc = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size if is_glu(activation_function) else intermediate_size, - bias=add_bias, - std=std, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.act = get_activation_function(activation_function) - - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=add_bias, - std=std / math.sqrt(2 * num_layers), - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.dropout = Dropout( - dropout, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel - ) From 278c7f65af68a190b04eab4973ff0c0d5844cbf3 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:17:19 -0800 Subject: [PATCH 61/99] norm Signed-off-by: 
Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/TP.py | 5 ----- lm_engine/hf_models/modeling_utils_TP/lm_head.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/TP.py diff --git a/lm_engine/hf_models/modeling_utils_TP/TP.py b/lm_engine/hf_models/modeling_utils_TP/TP.py deleted file mode 100644 index bc2df335f..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/TP.py +++ /dev/null @@ -1,5 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ..modeling_utils.TP import get_module_placements, tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils_TP/lm_head.py b/lm_engine/hf_models/modeling_utils_TP/lm_head.py index ed3578c90..18584dae1 100644 --- a/lm_engine/hf_models/modeling_utils_TP/lm_head.py +++ b/lm_engine/hf_models/modeling_utils_TP/lm_head.py @@ -9,7 +9,7 @@ from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel from ..modeling_utils import ParameterizedEmbedding -from .TP import get_module_placements +from ..modeling_utils.TP import get_module_placements class LMHead_TP(ParameterizedEmbedding): From ed561e819e168ec0c441890d11a81c9a7882a5dc Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:18:00 -0800 Subject: [PATCH 62/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index b75386d3b..f2e5c8118 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -3,6 +3,5 @@ # ************************************************** from .lm_head import LMHead_TP -from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP +from .mlp_blocks import MoE_TP, get_mlp_block_TP from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP -from .TP import get_module_placements, tensor_parallel_split_safetensor_slice From 57e7a87ea02cf9ab4b022cf9e3a0d4c97a3af142 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:19:05 -0800 Subject: [PATCH 63/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/__init__.py | 1 + lm_engine/hf_models/models/gpt_base_TP/weights/shard.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index b615c6101..b78cf3815 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -25,3 +25,4 @@ interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, ) +from .TP import tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py index 7b234c7d2..6024fa95e 100644 --- a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py +++ b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py @@ -5,8 +5,7 @@ import torch from .....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible -from ....modeling_utils import get_tensor_parallel_vocab_info, is_glu -from ....modeling_utils_TP import tensor_parallel_split_safetensor_slice +from ....modeling_utils import 
get_tensor_parallel_vocab_info, is_glu, tensor_parallel_split_safetensor_slice from ...gpt_base import GPTBaseConfig From f4ccfea8a30b8564f877ff6d926e1705dea1e1f1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:40:55 -0800 Subject: [PATCH 64/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/dropout.py | 7 ++- .../hf_models/modeling_utils/embedding.py | 7 ++- .../hf_models/modeling_utils/linear/base.py | 10 +--- .../hf_models/modeling_utils/linear/column.py | 18 ++----- .../modeling_utils/linear/replicated.py | 43 ++++++++++++++++ .../hf_models/modeling_utils/linear/row.py | 18 ++----- .../modeling_utils/mlp_blocks/moe.py | 21 ++++++++ .../hf_models/modeling_utils/normalization.py | 7 ++- .../modeling_utils_TP/mlp_blocks/moe.py | 51 ++++++++++--------- 9 files changed, 109 insertions(+), 73 deletions(-) create mode 100644 lm_engine/hf_models/modeling_utils/linear/replicated.py diff --git a/lm_engine/hf_models/modeling_utils/dropout.py b/lm_engine/hf_models/modeling_utils/dropout.py index d4d50806c..d5b887217 100644 --- a/lm_engine/hf_models/modeling_utils/dropout.py +++ b/lm_engine/hf_models/modeling_utils/dropout.py @@ -31,10 +31,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) - - x = super().forward(x) - - if self.is_tp_enabled: + x = super().forward(x) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + else: + x = super().forward(x) return x diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py index 1125b144e..aa0b48d38 100644 --- a/lm_engine/hf_models/modeling_utils/embedding.py +++ b/lm_engine/hf_models/modeling_utils/embedding.py @@ -59,11 +59,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Replicate()) - - x = F.embedding(x, weight=self.weight) - - if self.is_tp_enabled: + x = F.embedding(x, weight=self.weight) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + else: + x = F.embedding(x, weight=self.weight) return x diff --git a/lm_engine/hf_models/modeling_utils/linear/base.py b/lm_engine/hf_models/modeling_utils/linear/base.py index f23847943..8c340ae3d 100644 --- a/lm_engine/hf_models/modeling_utils/linear/base.py +++ b/lm_engine/hf_models/modeling_utils/linear/base.py @@ -12,16 +12,10 @@ class ParameterizedLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None ) -> ParameterizedLinear: self.std = std - super().__init__(in_features, out_features, bias, device, dtype) + super().__init__(in_features, out_features, bias) mark_parameter_as_no_weight_decay(self.bias) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py index 94e23aa5b..7af0dc8ff 100644 --- a/lm_engine/hf_models/modeling_utils/linear/column.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -21,8 +21,6 @@ def __init__( in_features: int, out_features: int, bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, std: float | None = None, 
use_padding_free_transformer: bool = False, sequence_parallel: bool = False, @@ -35,14 +33,7 @@ def __init__( f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) - super().__init__( - in_features=in_features, - out_features=self.out_features_per_tp_rank, - bias=bias, - device=device, - dtype=dtype, - std=std, - ) + super().__init__(in_features=in_features, out_features=self.out_features_per_tp_rank, bias=bias, std=std) self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() @@ -75,11 +66,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = tensor_to_dtensor( x, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() ) - - x = super().forward(x) - - if self.is_tp_enabled: + x = super().forward(x) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) + else: + x = super().forward(x) return x diff --git a/lm_engine/hf_models/modeling_utils/linear/replicated.py b/lm_engine/hf_models/modeling_utils/linear/replicated.py new file mode 100644 index 000000000..1f7a65a2c --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/replicated.py @@ -0,0 +1,43 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributed._tensor.placement_types import Replicate + +from ....dtensors import tensor_to_dtensor +from ....utils import ProcessGroupManager +from ..dtensor_module import DTensorModule +from .base import ParameterizedLinear + + +class ReplicatedLinear(ParameterizedLinear, DTensorModule): + def __init__( + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None + ) -> ReplicatedLinear: + super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: ... 
diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py index f58b79748..d7e80a94e 100644 --- a/lm_engine/hf_models/modeling_utils/linear/row.py +++ b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -21,8 +21,6 @@ def __init__( in_features: int, out_features: int, bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, std: float | None = None, use_padding_free_transformer: bool = False, sequence_parallel: bool = False, @@ -35,14 +33,7 @@ def __init__( f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) - super().__init__( - in_features=self.in_features_per_tp_rank, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - std=std, - ) + super().__init__(in_features=self.in_features_per_tp_rank, out_features=out_features, bias=bias, std=std) self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() @@ -73,11 +64,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - - x = super().forward(x) - - if self.is_tp_enabled: + x = super().forward(x) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + else: + x = super().forward(x) return x diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 32bd6751c..ac3b70e74 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -55,6 +55,27 @@ def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> return count +class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + std: float | None = None, + ) -> ReplicatedLinear_TP: + super().__init__( + in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype, std=std + ) + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() + ) + ) + + class ParameterizedExperts(nn.Module): def __init__( self, diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py index 5f7ecdd7f..089da664e 100644 --- a/lm_engine/hf_models/modeling_utils/normalization.py +++ b/lm_engine/hf_models/modeling_utils/normalization.py @@ -51,11 +51,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) - - x = super().forward(x) - - if self.is_tp_enabled: + x = super().forward(x) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + else: + x = super().forward(x) return x diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 0dcbffd8b..3efeb49a3 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -250,40 +250,38 @@ def __init__( torch.cuda.current_device() ) >= (9, 0) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> 
torch.Tensor: assert is_kernel_allowed(Kernel.scattermoe) if not self.use_padding_free_transformer: - batch_size, sequence_length, _ = hidden_states.shape + batch_size, sequence_length, _ = x.shape - hidden_states = hidden_states.view(-1, self.hidden_size) + x = x.view(-1, self.hidden_size) - hidden_states = tensor_to_dtensor(hidden_states, device_mesh=self.tp_mesh, current_placement=self.placement) + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) - router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) + router_logits, router_weights, selected_experts = self._compute_routing_weights(x) - hidden_states = dtensor_to_tensor( - hidden_states, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() - ) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial()) - moe_output, expert_frequency = self._compute_experts(hidden_states, router_weights, selected_experts) + moe_output, expert_frequency = self._compute_experts(x, router_weights, selected_experts) if self.shared_intermediate_size is None: - hidden_states = moe_output + x = moe_output else: - hidden_states = moe_output + self._compute_shared_experts(hidden_states) + x = moe_output + self._compute_shared_experts(x) del moe_output - hidden_states = tensor_to_dtensor(hidden_states, device_mesh=self.tp_mesh, current_placement=Partial()) - hidden_states = dtensor_to_tensor( - hidden_states, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Partial()) + x = dtensor_to_tensor( + x, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement ) if not self.use_padding_free_transformer: - hidden_states = hidden_states.reshape(batch_size, sequence_length, self.hidden_size) + x = x.reshape(batch_size, sequence_length, self.hidden_size) - hidden_states = self.dropout(hidden_states) + x = self.dropout(x) aux_loss = ( self._compute_switch_loss( @@ -295,20 +293,23 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: add_aux_loss(aux_loss) - return hidden_states + return x - def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]: - # hidden_states -> (total_q, hidden_size) - router_logits = self.gate(hidden_states) + def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor]: + # x -> (total_q, hidden_size) + router_logits = self.gate(x) router_logits = dtensor_to_tensor( router_logits, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() ) # router_logits -> (total_q, num_experts) - router_weights, selected_experts = self._get_topk(router_logits) - router_weights = F.softmax(router_weights.float(), dim=-1) - - # we cast back to the input dtype - router_weights = router_weights.type_as(hidden_states) + if self.normalized_topk: + router_weights, selected_experts = self._get_topk(router_logits) + router_weights = F.softmax(router_weights.float(), dim=-1) + router_weights = router_weights.type_as(x) + else: + router_weights = F.softmax(router_logits.float(), dim=-1) + router_weights = router_weights.type_as(x) + router_weights, selected_experts = self._get_topk(router_weights) return router_logits, router_weights, selected_experts From fd0b6b6bfb3557ddf6a1556ad69a058133331c56 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:44:03 -0800 Subject: [PATCH 65/99] 
norm Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/mlp_blocks/moe.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index ac3b70e74..cb36341d6 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -11,6 +11,7 @@ import torch.nn.functional as F from torch.distributed._functional_collectives import all_reduce +from ....dtensors import dtensor_to_tensor from ....enums import Kernel from ....kernels import is_kernel_allowed from ....utils import ProcessGroupManager, is_sonicmoe_available, is_xma_available @@ -22,7 +23,8 @@ ) from ..activations import get_activation_function, is_glu from ..dropout import Dropout -from ..linear import ParameterizedLinear +from ..dtensor_module import DTensorModule +from ..linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear from .mlp import _get_std_for_linear @@ -55,6 +57,16 @@ def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> return count +class SharedExpertsColumnParallelLinear(ColumnParallelLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) + + +class SharedExpertsRowParallelLinear(RowParallelLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) + + class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): def __init__( self, From 6772ab0dc015682b805923e676c2998f2f7f83f4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:44:39 -0800 Subject: [PATCH 66/99] norm Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils_TP/mlp_blocks/moe.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 3efeb49a3..524c50ed4 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -28,6 +28,7 @@ is_glu, ) from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear +from ...modeling_utils.mlp_blocks.moe import SharedExpertsColumnParallelLinear, SharedExpertsRowParallelLinear if is_xma_available(): @@ -158,16 +159,6 @@ def __init__( ) -class SharedExpertsColumnParallelLinear(ColumnParallelLinear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) - - -class SharedExpertsRowParallelLinear(RowParallelLinear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) - - class MoE_TP(MoE, DTensorModule): def __init__( self, From d19d9277a64144f0da94375fa363767ca4675aa5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:51:36 -0800 Subject: [PATCH 67/99] norm Signed-off-by: Mayank Mishra --- .../modeling_utils/mlp_blocks/moe.py | 17 ++++------- .../modeling_utils_TP/mlp_blocks/moe.py | 29 ++----------------- 2 files changed, 8 insertions(+), 38 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index cb36341d6..b668e602c 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py 
+++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -90,22 +90,15 @@ def __init__( class ParameterizedExperts(nn.Module): def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, num_experts: int, in_features: int, out_features: int, add_bias: bool = False, std: float | None = None ) -> ParameterizedExperts: super().__init__() - self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features, device=device, dtype=dtype)) + self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) self.bias = None if add_bias: - self.bias = nn.Parameter(torch.empty(num_experts, out_features, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(num_experts, out_features)) self.std = std @@ -226,7 +219,7 @@ def __init__( std=std, ) if self.shared_intermediate_size is not None: - self.c_fc_shared = ParameterizedLinear( + self.c_fc_shared = SharedExpertsColumnParallelLinear( in_features=self.hidden_size, out_features=( 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size @@ -248,7 +241,7 @@ def __init__( std=std, ) if self.shared_intermediate_size is not None: - self.c_proj_shared = ParameterizedLinear( + self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, out_features=self.hidden_size, bias=add_bias, diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 524c50ed4..caf4d9221 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -17,13 +17,11 @@ from ....utils import ProcessGroupManager, divide_if_divisible, is_xma_available from ...loss import add_aux_loss from ...modeling_utils import ( - ColumnParallelLinear, Dropout, DTensorModule, MoE, ParameterizedExperts, ParameterizedLinear, - RowParallelLinear, get_activation_function, is_glu, ) @@ -37,17 +35,9 @@ class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None ) -> ReplicatedLinear_TP: - super().__init__( - in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype, std=std - ) + super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) self.weight = nn.Parameter( tensor_to_dtensor( @@ -58,14 +48,7 @@ def __init__( class ColumnParallelExperts(ParameterizedExperts, DTensorModule): def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, num_experts: int, in_features: int, out_features: int, add_bias: bool = False, std: float | None = None ) -> ColumnParallelExperts: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() @@ -80,8 +63,6 @@ def __init__( in_features=in_features, out_features=self.out_features_per_device, add_bias=add_bias, - device=device, - dtype=dtype, std=std, ) @@ -129,8 +110,6 @@ def __init__( in_features: int, out_features: int, add_bias: bool = False, - device: 
torch.device | None = None, - dtype: torch.dtype | None = None, std: float | None = None, ) -> RowParallelExperts: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() @@ -147,8 +126,6 @@ def __init__( in_features=self.in_features_per_device, out_features=out_features, add_bias=add_bias, - device=device, - dtype=dtype, std=std, ) From eb22d4c2ee962eb59f0e7034d1f2117d4c432fb0 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 02:58:34 -0800 Subject: [PATCH 68/99] norm Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py | 2 +- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 7 +++---- .../hf_models/modeling_utils_TP/mlp_blocks/__init__.py | 3 +++ lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py | 7 ++++++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index 06877bc7e..e7f8d089a 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -21,6 +21,7 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye initializer_range=config.initializer_range, m_width=config.m_width, num_layers=config.num_layers, + use_padding_free_transformer=use_padding_free_transformer, ) if mlp_type == "MLP": @@ -34,7 +35,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye normalized_topk=block.normalized_topk, num_experts=block.num_experts, num_experts_per_tok=block.num_experts_per_tok, - use_padding_free_transformer=use_padding_free_transformer, ) else: raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index b668e602c..90ce97c3a 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -162,8 +162,6 @@ def reset_parameters(self) -> None: class MoE(nn.Module): - linear_class = ParameterizedExperts - def __init__( self, hidden_size: int, @@ -182,6 +180,7 @@ def __init__( m_width: float, num_layers: int, use_padding_free_transformer: bool, + sequence_parallel: bool = False, ) -> MoE: super().__init__() @@ -211,7 +210,7 @@ def __init__( in_features=self.hidden_size, out_features=1, bias=False, std=std ) - self.c_fc = self.linear_class( + self.c_fc = ColumnParallelExperts( num_experts=num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, @@ -233,7 +232,7 @@ def __init__( std /= math.sqrt(2 * num_layers) - self.c_proj = self.linear_class( + self.c_proj = RowParallelExperts( num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py index cc96471c6..f6f7b8ab6 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py @@ -33,6 +33,9 @@ def get_mlp_block_TP( mlp = MoE_TP( **kwargs, shared_intermediate_size=block.shared_intermediate_size, + use_interleaved_weights=block.use_interleaved_weights, + shared_expert_gating=block.shared_expert_gating, + normalized_topk=block.normalized_topk, num_experts=block.num_experts, 
num_experts_per_tok=block.num_experts_per_tok, ) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index caf4d9221..2f1823316 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -142,6 +142,9 @@ def __init__( hidden_size: int, intermediate_size: int, shared_intermediate_size: int, + use_interleaved_weights: bool, + shared_expert_gating: bool, + normalized_topk: bool, num_experts: int, num_experts_per_tok: int, activation_function: str, @@ -159,10 +162,12 @@ def __init__( self.num_experts = num_experts self.top_k = num_experts_per_tok self.use_padding_free_transformer = use_padding_free_transformer - self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.shared_intermediate_size = shared_intermediate_size + self.shared_expert_gating = shared_expert_gating + self.normalized_topk = normalized_topk + self.use_interleaved_weights = use_interleaved_weights std = _get_std_for_linear(initializer_range, init_method, m_width) From 402b23c96aa1a430f783b85ced8d411ba000ceda Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:27:57 -0800 Subject: [PATCH 69/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/mlp.py | 5 +- .../modeling_utils/mlp_blocks/__init__.py | 3 +- .../modeling_utils/mlp_blocks/moe.py | 172 +++++++++++++----- .../modeling_utils_TP/mlp_blocks/moe.py | 112 ++---------- 4 files changed, 139 insertions(+), 153 deletions(-) diff --git a/lm_engine/hf_models/config/mlp.py b/lm_engine/hf_models/config/mlp.py index 699493423..05ee4144d 100644 --- a/lm_engine/hf_models/config/mlp.py +++ b/lm_engine/hf_models/config/mlp.py @@ -18,8 +18,11 @@ def model_post_init(self, __context: Any) -> None: assert self.mlp_type == "MLP" -class _MoEArgs(_MLPArgs): +class _MoEArgs(BaseArgs): mlp_type: str = "MoE" + intermediate_size: int + activation_function: str = "gelu_pytorch_tanh" + dropout: float = 0 shared_intermediate_size: int | None = None num_experts: int = 8 use_interleaved_weights: bool = False diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index e7f8d089a..e10df95e8 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -15,7 +15,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye hidden_size=config.hidden_size, intermediate_size=block.intermediate_size, activation_function=block.activation_function, - add_bias=block.add_bias, dropout=block.dropout, init_method=config.init_method, initializer_range=config.initializer_range, @@ -25,7 +24,7 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye ) if mlp_type == "MLP": - mlp = MLP(**kwargs) + mlp = MLP(**kwargs, add_bias=block.add_bias) elif mlp_type == "MoE": mlp = MoE( **kwargs, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 90ce97c3a..7189e1704 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -10,17 +10,14 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed._functional_collectives import all_reduce +from torch.distributed._tensor.placement_types import Replicate, Shard -from ....dtensors import dtensor_to_tensor +from 
....dtensors import dtensor_to_tensor, tensor_to_dtensor from ....enums import Kernel -from ....kernels import is_kernel_allowed -from ....utils import ProcessGroupManager, is_sonicmoe_available, is_xma_available +from ....kernels import is_kernel_allowed, wait_for_ACT +from ....utils import ProcessGroupManager, divide_if_divisible, is_sonicmoe_available, is_xma_available from ...loss import add_aux_loss -from ...parameter import ( - mark_parameter_as_initialized, - mark_parameter_as_mup_learning_rate, - mark_parameter_as_no_weight_decay, -) +from ...parameter import mark_parameter_as_initialized, mark_parameter_as_mup_learning_rate from ..activations import get_activation_function, is_glu from ..dropout import Dropout from ..dtensor_module import DTensorModule @@ -69,17 +66,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None ) -> ReplicatedLinear_TP: - super().__init__( - in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype, std=std - ) + super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) self.weight = nn.Parameter( tensor_to_dtensor( @@ -90,16 +79,11 @@ def __init__( class ParameterizedExperts(nn.Module): def __init__( - self, num_experts: int, in_features: int, out_features: int, add_bias: bool = False, std: float | None = None + self, num_experts: int, in_features: int, out_features: int, std: float | None = None ) -> ParameterizedExperts: super().__init__() self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) - - self.bias = None - if add_bias: - self.bias = nn.Parameter(torch.empty(num_experts, out_features)) - self.std = std self.num_experts = num_experts @@ -108,11 +92,9 @@ def __init__( self.reset_parameters() - mark_parameter_as_no_weight_decay(self.bias) - def forward( self, - input: torch.Tensor, + x: torch.Tensor, num_experts_per_token: int | None = None, expert_frequency: torch.Tensor | None = None, sorted_expert_idxs: torch.Tensor | None = None, @@ -125,8 +107,8 @@ def forward( if is_kernel_allowed(Kernel.scattermoe): assert self.bias is None - input = scattered_experts( - inputs=input, + x = scattered_experts( + inputs=x, expert_weights=self.weight.permute(0, 2, 1), k=num_experts_per_token, sorted_expert_idxs=sorted_expert_idxs, @@ -137,14 +119,11 @@ def forward( grouped_out=grouped_out, ) else: - input = input.split(expert_frequency.tolist(), dim=0) - input = [ - F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i]) - for i in range(self.num_experts) - ] - input = torch.cat(input, dim=0) + x = x.split(expert_frequency.tolist(), dim=0) + x = [F.linear(x[i], self.weight[i]) for i in range(self.num_experts)] + x = torch.cat(x, dim=0) - return input + return x def extra_repr(self) -> str: return "num_experts={}, in_features={}, out_features={}".format( @@ -154,11 +133,110 @@ def extra_repr(self) -> str: @torch.no_grad() def reset_parameters(self) -> None: nn.init.normal_(self.weight, mean=0, std=self.std) - if hasattr(self, "bias") and self.bias is not None: - self.bias.zero_() - mark_parameter_as_initialized(self.weight) - mark_parameter_as_initialized(self.bias) + + +class 
ColumnParallelExperts(ParameterizedExperts, DTensorModule): + def __init__( + self, num_experts: int, in_features: int, out_features: int, std: float | None = None + ) -> ColumnParallelExperts: + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + + self.out_features_per_tp_rank = divide_if_divisible( + out_features, + tp_world_size, + f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + super().__init__( + num_experts=num_experts, in_features=in_features, out_features=self.out_features_per_tp_rank, std=std + ) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) + ) + ) + + self.reset_parameters() + + def forward( + self, + x: torch.Tensor, + num_experts_per_token: int | None = None, + expert_frequency: torch.Tensor | None = None, + sorted_expert_idxs: torch.Tensor | None = None, + sorted_scattered_idxs: torch.Tensor | None = None, + expert_offsets: torch.Tensor | None = None, + gates: torch.Tensor | None = None, + grouped_in: bool = False, + grouped_out: bool = False, + ) -> torch.Tensor: + if self.is_tp_enabled: + assert is_kernel_allowed(Kernel.scattermoe) + + x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) + + if is_kernel_allowed(Kernel.scattermoe): + assert self.bias is None + + x = scattered_experts( + inputs=x, + expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), + k=num_experts_per_token, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=gates, + grouped_in=grouped_in, + grouped_out=grouped_out, + ) + + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) + else: + x = x.split(expert_frequency.tolist(), dim=0) + x = [F.linear(x[i], self.weight[i]) for i in range(self.num_experts)] + x = torch.cat(x, dim=0) + + return x + + def extra_repr(self) -> str: + return "num_experts={}, in_features={}, out_features_per_tp_rank={}".format( + self.num_experts, self.in_features, self.out_features_per_tp_rank + ) + + +class RowParallelExperts(ColumnParallelExperts): + def __init__( + self, num_experts: int, in_features: int, out_features: int, std: float | None = None + ) -> RowParallelExperts: + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + + self.in_features_per_device = divide_if_divisible( + in_features, + tp_world_size, + f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + ParameterizedExperts.__init__( + self, num_experts=num_experts, in_features=self.in_features_per_device, out_features=out_features, std=std + ) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Shard(-1), + ) + ) + + self.reset_parameters() class MoE(nn.Module): @@ -173,7 +251,6 @@ def __init__( num_experts: int, num_experts_per_tok: int, activation_function: str, - add_bias: bool, dropout: float, init_method: str, initializer_range: float, @@ -214,16 +291,16 @@ def __init__( num_experts=num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, - add_bias=add_bias, 
std=std, ) + if self.shared_intermediate_size is not None: self.c_fc_shared = SharedExpertsColumnParallelLinear( in_features=self.hidden_size, out_features=( 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size ), - bias=add_bias, + bias=False, std=std, ) @@ -233,17 +310,14 @@ def __init__( std /= math.sqrt(2 * num_layers) self.c_proj = RowParallelExperts( - num_experts=num_experts, - in_features=self.intermediate_size, - out_features=self.hidden_size, - add_bias=add_bias, - std=std, + num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, std=std ) + if self.shared_intermediate_size is not None: self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, out_features=self.hidden_size, - bias=add_bias, + bias=False, std=std, ) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 2f1823316..00e470ed1 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -26,7 +26,12 @@ is_glu, ) from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ...modeling_utils.mlp_blocks.moe import SharedExpertsColumnParallelLinear, SharedExpertsRowParallelLinear +from ...modeling_utils.mlp_blocks.moe import ( + ColumnParallelExperts, + RowParallelExperts, + SharedExpertsColumnParallelLinear, + SharedExpertsRowParallelLinear, +) if is_xma_available(): @@ -46,96 +51,6 @@ def __init__( ) -class ColumnParallelExperts(ParameterizedExperts, DTensorModule): - def __init__( - self, num_experts: int, in_features: int, out_features: int, add_bias: bool = False, std: float | None = None - ) -> ColumnParallelExperts: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.out_features_per_device = divide_if_divisible( - out_features, - tp_world_size, - f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - super().__init__( - num_experts=num_experts, - in_features=in_features, - out_features=self.out_features_per_device, - add_bias=add_bias, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) - ) - ) - - def forward( - self, - input: torch.Tensor, - num_experts_per_token: int | None = None, - num_tokens_per_expert: torch.Tensor | None = None, - sorted_expert_idxs: torch.Tensor | None = None, - sorted_scattered_idxs: torch.Tensor | None = None, - expert_offsets: torch.Tensor | None = None, - gates: torch.Tensor | None = None, - grouped_in: bool = False, - grouped_out: bool = False, - ) -> torch.Tensor: - assert is_kernel_allowed(Kernel.scattermoe) - - input = scattered_experts( - inputs=wait_for_ACT(input, wait_in_forward=True, wait_in_backward=False), - expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), - k=num_experts_per_token, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - expert_offsets=expert_offsets, - gates=gates, - grouped_in=grouped_in, - grouped_out=grouped_out, - ) - - input = wait_for_ACT(input, wait_in_forward=False, wait_in_backward=True) - - return input - - -class RowParallelExperts(ColumnParallelExperts): - def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - std: float | None = None, - ) -> RowParallelExperts: - 
tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.in_features_per_device = divide_if_divisible( - in_features, - tp_world_size, - f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - ParameterizedExperts.__init__( - self, - num_experts=num_experts, - in_features=self.in_features_per_device, - out_features=out_features, - add_bias=add_bias, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(-1) - ) - ) - - class MoE_TP(MoE, DTensorModule): def __init__( self, @@ -148,7 +63,6 @@ def __init__( num_experts: int, num_experts_per_tok: int, activation_function: str, - add_bias: bool, dropout: float, init_method: str, initializer_range: float, @@ -184,16 +98,16 @@ def __init__( num_experts=num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, - add_bias=add_bias, std=std, ) + if self.shared_intermediate_size is not None: self.c_fc_shared = SharedExpertsColumnParallelLinear( in_features=self.hidden_size, out_features=( 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size ), - bias=add_bias, + bias=False, std=std, ) @@ -202,17 +116,13 @@ def __init__( std /= math.sqrt(2 * num_layers) self.c_proj = RowParallelExperts( - num_experts=num_experts, - in_features=self.intermediate_size, - out_features=self.hidden_size, - add_bias=add_bias, - std=std, + num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, std=std ) if self.shared_intermediate_size is not None: self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, out_features=self.hidden_size, - bias=add_bias, + bias=False, std=std, ) @@ -268,7 +178,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x - def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor]: + def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # x -> (total_q, hidden_size) router_logits = self.gate(x) router_logits = dtensor_to_tensor( From e14274bdaded30c4daf8b4832ea1b22a64f8ccab Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:29:05 -0800 Subject: [PATCH 70/99] fix linter Signed-off-by: Mayank Mishra --- .../multi_gpu/tensor_parallel/tensor_parallel_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 2698cc8a7..71a58ac1b 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -58,7 +58,7 @@ ], mlp_blocks=[ {"mlp_type": "MLP", "add_bias": False}, - {"mlp_type": "MoE", "add_bias": False}, + {"mlp_type": "MoE"}, ], ) From e7393d5d9118097999f8a76a8c5619ca8bc0e47b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:29:29 -0800 Subject: [PATCH 71/99] fix linter Signed-off-by: Mayank Mishra --- .../multi_gpu/tensor_parallel/tensor_parallel_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 71a58ac1b..ff1ef7dfd 100644 --- 
a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -70,7 +70,7 @@ with enable_kernels(kernels): if torch.distributed.get_rank() == 0: - with torch.device("meta"): + with torch.device("meta"), ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): model = TestCommons.from_config(None, config) model = model.to_empty(device=torch.cuda.current_device()) From ceb50c32ce40fba1f8e7f26d0f47acc631b4e09d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:30:40 -0800 Subject: [PATCH 72/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py index f6f7b8ab6..17dc539cb 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py @@ -17,7 +17,6 @@ def get_mlp_block_TP( hidden_size=config.hidden_size, intermediate_size=block.intermediate_size, activation_function=block.activation_function, - add_bias=block.add_bias, dropout=block.dropout, init_method=config.init_method, initializer_range=config.initializer_range, @@ -28,7 +27,7 @@ def get_mlp_block_TP( ) if mlp_type == "MLP": - mlp = MLP(**kwargs) + mlp = MLP(**kwargs, add_bias=block.add_bias) elif mlp_type == "MoE": mlp = MoE_TP( **kwargs, From de3d44f94db0c59aa3b543f518da51813074d0b5 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:32:07 -0800 Subject: [PATCH 73/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/models/gpt_base_TP/weights/shard.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py index 6024fa95e..6e4364da3 100644 --- a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py +++ b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py @@ -78,7 +78,6 @@ def get_gpt_base_model_parallel_state_dict( state_dict.update( _get_moe( activation_function=block.activation_function, - add_bias=block.add_bias, safetensors_weights_manager=safetensors_weights_manager, prefix=prefix + "mlp_block.", column_parallel_shard_dim=1, @@ -163,7 +162,6 @@ def _get_attention( def _get_moe( activation_function: str, - add_bias: bool, safetensors_weights_manager: SafeTensorsWeightsManager, prefix: str, column_parallel_shard_dim: int, @@ -171,12 +169,10 @@ def _get_moe( ) -> None: state_dict = {prefix + "gate.weight": safetensors_weights_manager.get_tensor(prefix + "gate.weight")} - assert not add_bias - state_dict.update( _get_mlp( activation_function=activation_function, - add_bias=add_bias, + add_bias=False, safetensors_weights_manager=safetensors_weights_manager, prefix=prefix, column_parallel_shard_dim=column_parallel_shard_dim, From 6fe2d8b54e022d98123f2f45ee7156ac162a5787 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:33:37 -0800 Subject: [PATCH 74/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 7189e1704..2317fbbd4 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ 
b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -423,7 +423,7 @@ def _compute_experts( expert_offsets = expert_frequency.cumsum(-1) x = self.c_fc( - input=x, + x=x, num_experts_per_token=self.top_k, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, @@ -434,7 +434,7 @@ def _compute_experts( x = self.act(x) x = self.c_proj( - input=x, + x=x, num_experts_per_token=1, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, @@ -453,9 +453,9 @@ def _compute_experts( x = x[batch_index] - x = self.c_fc(input=x, expert_frequency=expert_frequency) + x = self.c_fc(x=x, expert_frequency=expert_frequency) x = self.act(x) - x = self.c_proj(input=x, expert_frequency=expert_frequency) + x = self.c_proj(x=x, expert_frequency=expert_frequency) x = x * batch_gates.unsqueeze(-1) # [:, None] zeros = torch.zeros((T, self.hidden_size), dtype=x.dtype, device=x.device) From a1bc7191444e0f3b1290824e3a9937ffbb5fc30c Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:34:55 -0800 Subject: [PATCH 75/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 2317fbbd4..7598999fd 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -105,8 +105,6 @@ def forward( grouped_out: bool = False, ) -> torch.Tensor: if is_kernel_allowed(Kernel.scattermoe): - assert self.bias is None - x = scattered_experts( inputs=x, expert_weights=self.weight.permute(0, 2, 1), @@ -181,8 +179,6 @@ def forward( x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) if is_kernel_allowed(Kernel.scattermoe): - assert self.bias is None - x = scattered_experts( inputs=x, expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), From 422daf5f2178c013d28fd379e3d9dd0af9d6a7df Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:35:33 -0800 Subject: [PATCH 76/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 7598999fd..c50cea1ae 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -176,9 +176,9 @@ def forward( if self.is_tp_enabled: assert is_kernel_allowed(Kernel.scattermoe) - x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) - if is_kernel_allowed(Kernel.scattermoe): + x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) + x = scattered_experts( inputs=x, expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), From 7f942341bcc42b435b2835ec6431ea35e71b7b3b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 03:36:56 -0800 Subject: [PATCH 77/99] fix linter Signed-off-by: Mayank Mishra --- .../modeling_utils_TP/mlp_blocks/moe.py | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index 00e470ed1..bd46defb0 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -13,44 +13,20 @@ 
from ....dtensors import dtensor_to_tensor, tensor_to_dtensor from ....enums import Kernel -from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import ProcessGroupManager, divide_if_divisible, is_xma_available +from ....kernels import is_kernel_allowed +from ....utils import ProcessGroupManager from ...loss import add_aux_loss -from ...modeling_utils import ( - Dropout, - DTensorModule, - MoE, - ParameterizedExperts, - ParameterizedLinear, - get_activation_function, - is_glu, -) +from ...modeling_utils import Dropout, DTensorModule, MoE, get_activation_function, is_glu from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear from ...modeling_utils.mlp_blocks.moe import ( ColumnParallelExperts, + ReplicatedLinear_TP, RowParallelExperts, SharedExpertsColumnParallelLinear, SharedExpertsRowParallelLinear, ) -if is_xma_available(): - from xma.layers.moe import scattered_experts - - -class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): - def __init__( - self, in_features: int, out_features: int, bias: bool = True, std: float | None = None - ) -> ReplicatedLinear_TP: - super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - class MoE_TP(MoE, DTensorModule): def __init__( self, From 68b9df912f1648e6e51a0d38fd32055f5fca4bf1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:27:55 -0800 Subject: [PATCH 78/99] fix linter Signed-off-by: Mayank Mishra --- .../modeling_utils/mlp_blocks/moe.py | 25 ++++++++++++++++ .../modeling_utils_TP/mlp_blocks/__init__.py | 7 ++--- .../modeling_utils_TP/mlp_blocks/moe.py | 30 +++++++------------ 3 files changed, 38 insertions(+), 24 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index c50cea1ae..6fc56f678 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -318,6 +318,7 @@ def __init__( ) self.dropout = Dropout(dropout) + self.placement = Shard(0) if sequence_parallel else Replicate() self.is_hopper_or_newer_gpu = torch.cuda.is_available() and torch.cuda.get_device_capability( torch.cuda.current_device() @@ -333,15 +334,21 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_fc_shared.weight) mark_parameter_as_mup_learning_rate(self.c_proj_shared.weight) + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.use_padding_free_transformer: batch_size, sequence_length, _ = x.shape x = x.view(-1, self.hidden_size) + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + if is_kernel_allowed(Kernel.sonicmoe): assert self.use_interleaved_weights assert self.activation_function_string == "swiglu" + assert not self.is_tp_enabled moe_output, router_logits, expert_frequency = moe_TC_softmax_topk_layer( x=x, @@ -358,6 +365,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert not self.use_interleaved_weights router_logits, router_weights, selected_experts = self._compute_routing_weights(x) + + if self.is_tp_enabled: + x = dtensor_to_tensor( + x, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() + ) + moe_output, expert_frequency = self._compute_experts(x, router_weights, 
selected_experts) if self.shared_intermediate_size is None: @@ -367,6 +380,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: del moe_output + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Partial()) + x = dtensor_to_tensor( + x, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement + ) + if not self.use_padding_free_transformer: x = x.reshape(batch_size, sequence_length, self.hidden_size) @@ -387,6 +406,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # x -> (total_q, hidden_size) router_logits = self.gate(x) + + if self.is_tp_enabled: + router_logits = dtensor_to_tensor( + router_logits, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() + ) + # router_logits -> (total_q, num_experts) if self.normalized_topk: diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py index 17dc539cb..9b104e09f 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py @@ -3,13 +3,12 @@ # ************************************************** from ...config import CommonConfig -from ...modeling_utils import MLP -from .moe import MoE_TP +from ...modeling_utils import MLP, MoE def get_mlp_block_TP( config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int -) -> MLP | MoE_TP: +) -> MLP | MoE: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -29,7 +28,7 @@ def get_mlp_block_TP( if mlp_type == "MLP": mlp = MLP(**kwargs, add_bias=block.add_bias) elif mlp_type == "MoE": - mlp = MoE_TP( + mlp = MoE( **kwargs, shared_intermediate_size=block.shared_intermediate_size, use_interleaved_weights=block.use_interleaved_weights, diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py index bd46defb0..6c1f988aa 100644 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py @@ -46,7 +46,7 @@ def __init__( num_layers: int, use_padding_free_transformer: bool, sequence_parallel: bool = False, - ) -> MoE_TP: + ) -> MoE: nn.Module.__init__(self) self.num_experts = num_experts @@ -70,6 +70,13 @@ def __init__( std=std, ) + if self.shared_expert_gating: + assert shared_intermediate_size is not None + + self.shared_expert_gate = ReplicatedLinear_TP( + in_features=self.hidden_size, out_features=1, bias=False, std=std + ) + self.c_fc = ColumnParallelExperts( num_experts=num_experts, in_features=self.hidden_size, @@ -87,6 +94,7 @@ def __init__( std=std, ) + self.activation_function_string = activation_function self.act = get_activation_function(activation_function) std /= math.sqrt(2 * num_layers) @@ -94,6 +102,7 @@ def __init__( self.c_proj = RowParallelExperts( num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, std=std ) + if self.shared_intermediate_size is not None: self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, @@ -153,22 +162,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: add_aux_loss(aux_loss) return x - - def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # x -> (total_q, hidden_size) - router_logits 
= self.gate(x) - router_logits = dtensor_to_tensor( - router_logits, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() - ) - # router_logits -> (total_q, num_experts) - - if self.normalized_topk: - router_weights, selected_experts = self._get_topk(router_logits) - router_weights = F.softmax(router_weights.float(), dim=-1) - router_weights = router_weights.type_as(x) - else: - router_weights = F.softmax(router_logits.float(), dim=-1) - router_weights = router_weights.type_as(x) - router_weights, selected_experts = self._get_topk(router_weights) - - return router_logits, router_weights, selected_experts From 323ad5a577093b3bf8812e5dcd562582ef4e6ef1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:30:42 -0800 Subject: [PATCH 79/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils_TP/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index f2e5c8118..dfb08bb75 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -3,5 +3,5 @@ # ************************************************** from .lm_head import LMHead_TP -from .mlp_blocks import MoE_TP, get_mlp_block_TP +from .mlp_blocks import get_mlp_block_TP from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP From a9728afddada1435535ffa3f69ee82ac763227ba Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:35:30 -0800 Subject: [PATCH 80/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 5 +- lm_engine/hf_models/mixins/dense_TP/layer.py | 6 +- .../modeling_utils/mlp_blocks/__init__.py | 5 +- .../modeling_utils/mlp_blocks/moe.py | 3 + .../hf_models/modeling_utils_TP/__init__.py | 1 - .../modeling_utils_TP/mlp_blocks/__init__.py | 43 ----- .../modeling_utils_TP/mlp_blocks/moe.py | 164 ------------------ .../hf_models/models/gpt_crosslayer/layer.py | 5 +- lm_engine/hf_models/models/palm/layer.py | 5 +- 9 files changed, 22 insertions(+), 215 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py delete mode 100644 lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 0cceab740..bd85d42d0 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -30,7 +30,10 @@ def __init__( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, ) def forward( diff --git a/lm_engine/hf_models/mixins/dense_TP/layer.py b/lm_engine/hf_models/mixins/dense_TP/layer.py index dae1fb966..e6f19076d 100644 --- a/lm_engine/hf_models/mixins/dense_TP/layer.py +++ b/lm_engine/hf_models/mixins/dense_TP/layer.py @@ -7,8 +7,8 @@ import torch.nn as nn from ...config import CommonConfig -from ...modeling_utils import get_normalization_function -from ...modeling_utils_TP import get_mlp_block_TP, get_sequence_mixer_TP +from ...modeling_utils import get_mlp_block, get_normalization_function +from ...modeling_utils_TP import get_sequence_mixer_TP from ..dense import Block @@ -47,7 +47,7 @@ def 
__init__( use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, ) - self.mlp_block = get_mlp_block_TP( + self.mlp_block = get_mlp_block( config, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index e10df95e8..c3e408f8d 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -7,7 +7,9 @@ from .moe import MoE, ParameterizedExperts -def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE: +def get_mlp_block( + config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int +) -> MLP | MoE: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -21,6 +23,7 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye m_width=config.m_width, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) if mlp_type == "MLP": diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 6fc56f678..377cda7c7 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -336,6 +336,9 @@ def __init__( self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.use_padding_free_transformer: batch_size, sequence_length, _ = x.shape diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index dfb08bb75..a657eddc7 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -3,5 +3,4 @@ # ************************************************** from .lm_head import LMHead_TP -from .mlp_blocks import get_mlp_block_TP from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py deleted file mode 100644 index 9b104e09f..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ...config import CommonConfig -from ...modeling_utils import MLP, MoE - - -def get_mlp_block_TP( - config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int -) -> MLP | MoE: - block = config.mlp_blocks[layer_idx] - mlp_type = block.mlp_type - - kwargs = dict( - hidden_size=config.hidden_size, - intermediate_size=block.intermediate_size, - activation_function=block.activation_function, - dropout=block.dropout, - init_method=config.init_method, - initializer_range=config.initializer_range, - m_width=config.m_width, - num_layers=config.num_layers, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - if mlp_type == "MLP": - mlp = MLP(**kwargs, add_bias=block.add_bias) - elif mlp_type == "MoE": - mlp = MoE( - **kwargs, - 
shared_intermediate_size=block.shared_intermediate_size, - use_interleaved_weights=block.use_interleaved_weights, - shared_expert_gating=block.shared_expert_gating, - normalized_topk=block.normalized_topk, - num_experts=block.num_experts, - num_experts_per_tok=block.num_experts_per_tok, - ) - else: - raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") - - return mlp diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py deleted file mode 100644 index 6c1f988aa..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ /dev/null @@ -1,164 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed._tensor.placement_types import Partial, Replicate, Shard - -from ....dtensors import dtensor_to_tensor, tensor_to_dtensor -from ....enums import Kernel -from ....kernels import is_kernel_allowed -from ....utils import ProcessGroupManager -from ...loss import add_aux_loss -from ...modeling_utils import Dropout, DTensorModule, MoE, get_activation_function, is_glu -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ...modeling_utils.mlp_blocks.moe import ( - ColumnParallelExperts, - ReplicatedLinear_TP, - RowParallelExperts, - SharedExpertsColumnParallelLinear, - SharedExpertsRowParallelLinear, -) - - -class MoE_TP(MoE, DTensorModule): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - shared_intermediate_size: int, - use_interleaved_weights: bool, - shared_expert_gating: bool, - normalized_topk: bool, - num_experts: int, - num_experts_per_tok: int, - activation_function: str, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - use_padding_free_transformer: bool, - sequence_parallel: bool = False, - ) -> MoE: - nn.Module.__init__(self) - - self.num_experts = num_experts - self.top_k = num_experts_per_tok - self.use_padding_free_transformer = use_padding_free_transformer - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.shared_intermediate_size = shared_intermediate_size - self.shared_expert_gating = shared_expert_gating - self.normalized_topk = normalized_topk - self.use_interleaved_weights = use_interleaved_weights - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.gate = ReplicatedLinear_TP( - in_features=self.hidden_size, - out_features=num_experts, - bias=False, - std=std, - ) - - if self.shared_expert_gating: - assert shared_intermediate_size is not None - - self.shared_expert_gate = ReplicatedLinear_TP( - in_features=self.hidden_size, out_features=1, bias=False, std=std - ) - - self.c_fc = ColumnParallelExperts( - num_experts=num_experts, - in_features=self.hidden_size, - out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, - std=std, - ) - - if self.shared_intermediate_size is not None: - self.c_fc_shared = SharedExpertsColumnParallelLinear( - in_features=self.hidden_size, - out_features=( - 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size - ), - bias=False, - std=std, - ) - - self.activation_function_string = activation_function 
- self.act = get_activation_function(activation_function) - - std /= math.sqrt(2 * num_layers) - - self.c_proj = RowParallelExperts( - num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, std=std - ) - - if self.shared_intermediate_size is not None: - self.c_proj_shared = SharedExpertsRowParallelLinear( - in_features=self.shared_intermediate_size, - out_features=self.hidden_size, - bias=False, - std=std, - ) - - self.dropout = Dropout(dropout) - self.placement = Shard(0) if sequence_parallel else Replicate() - - self.is_hopper_or_newer_gpu = torch.cuda.is_available() and torch.cuda.get_device_capability( - torch.cuda.current_device() - ) >= (9, 0) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - assert is_kernel_allowed(Kernel.scattermoe) - - if not self.use_padding_free_transformer: - batch_size, sequence_length, _ = x.shape - - x = x.view(-1, self.hidden_size) - - x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) - - router_logits, router_weights, selected_experts = self._compute_routing_weights(x) - - x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial()) - - moe_output, expert_frequency = self._compute_experts(x, router_weights, selected_experts) - - if self.shared_intermediate_size is None: - x = moe_output - else: - x = moe_output + self._compute_shared_experts(x) - - del moe_output - - x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Partial()) - x = dtensor_to_tensor( - x, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement - ) - - if not self.use_padding_free_transformer: - x = x.reshape(batch_size, sequence_length, self.hidden_size) - - x = self.dropout(x) - - aux_loss = ( - self._compute_switch_loss( - logits=router_logits, probs=torch.softmax(router_logits, dim=-1), expert_frequency=expert_frequency - ) - if self.training - else 0 - ) - - add_aux_loss(aux_loss) - - return x diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 3413911ff..cfdfa6acc 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -53,7 +53,10 @@ def __init__( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, ) def forward( diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 8bafb9098..652e5e6b8 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -25,7 +25,10 @@ def __init__( ) self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, ) def forward( From 572d5ba979c5db9f7fc185d47a2b8ffb92923412 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:44:48 -0800 Subject: [PATCH 81/99] fix linter Signed-off-by: Mayank Mishra --- .../modeling_utils/linear/__init__.py | 1 + .../modeling_utils/linear/replicated.py | 3 --- 
.../modeling_utils/mlp_blocks/moe.py | 22 ++----------------- 3 files changed, 3 insertions(+), 23 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/linear/__init__.py b/lm_engine/hf_models/modeling_utils/linear/__init__.py index b4a6ac5e0..173bb73e1 100644 --- a/lm_engine/hf_models/modeling_utils/linear/__init__.py +++ b/lm_engine/hf_models/modeling_utils/linear/__init__.py @@ -4,4 +4,5 @@ from .base import ParameterizedLinear from .column import ColumnParallelLinear +from .replicated import ReplicatedLinear from .row import RowParallelLinear diff --git a/lm_engine/hf_models/modeling_utils/linear/replicated.py b/lm_engine/hf_models/modeling_utils/linear/replicated.py index 1f7a65a2c..5e769e33e 100644 --- a/lm_engine/hf_models/modeling_utils/linear/replicated.py +++ b/lm_engine/hf_models/modeling_utils/linear/replicated.py @@ -4,7 +4,6 @@ from __future__ import annotations -import torch import torch.nn as nn from torch.distributed._tensor.placement_types import Replicate @@ -39,5 +38,3 @@ def __init__( current_placement=Replicate(), ) ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: ... diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 377cda7c7..56ecd8fea 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -21,7 +21,7 @@ from ..activations import get_activation_function, is_glu from ..dropout import Dropout from ..dtensor_module import DTensorModule -from ..linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear +from ..linear import ColumnParallelLinear, ParameterizedLinear, ReplicatedLinear, RowParallelLinear from .mlp import _get_std_for_linear @@ -64,19 +64,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.linear(x, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) -class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): - def __init__( - self, in_features: int, out_features: int, bias: bool = True, std: float | None = None - ) -> ReplicatedLinear_TP: - super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - class ParameterizedExperts(nn.Module): def __init__( self, num_experts: int, in_features: int, out_features: int, std: float | None = None @@ -269,12 +256,7 @@ def __init__( std = _get_std_for_linear(initializer_range, init_method, m_width) - self.gate = ParameterizedLinear( - in_features=self.hidden_size, - out_features=num_experts, - bias=False, - std=std, - ) + self.gate = ReplicatedLinear(in_features=self.hidden_size, out_features=num_experts, bias=False, std=std) if self.shared_expert_gating: assert shared_intermediate_size is not None From 48f288fff81fe659150b41f0c50b6ab3b8ae7ac1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:45:23 -0800 Subject: [PATCH 82/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 56ecd8fea..eb08a7b49 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -10,7 +10,7 @@ import torch.nn as nn import 
torch.nn.functional as F from torch.distributed._functional_collectives import all_reduce -from torch.distributed._tensor.placement_types import Replicate, Shard +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard from ....dtensors import dtensor_to_tensor, tensor_to_dtensor from ....enums import Kernel From 0af2bdab9950aa84d59cc69f48fa98d5eb3748df Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:54:40 -0800 Subject: [PATCH 83/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/layer.py | 8 +- lm_engine/hf_models/mixins/dense_TP/layer.py | 5 +- .../sequence_mixer_blocks/__init__.py | 2 + .../sequence_mixer_blocks/attention.py | 81 +++++-- .../hf_models/modeling_utils_TP/__init__.py | 1 - .../sequence_mixer_blocks/__init__.py | 38 --- .../sequence_mixer_blocks/attention.py | 227 ------------------ lm_engine/hf_models/models/palm/layer.py | 8 +- tests/hf_models/single_gpu/weight_test.py | 2 +- 9 files changed, 78 insertions(+), 294 deletions(-) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py delete mode 100644 lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index bd85d42d0..1ca0214d6 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -25,7 +25,13 @@ def __init__( self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer( + config, + True, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, + ) self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) diff --git a/lm_engine/hf_models/mixins/dense_TP/layer.py b/lm_engine/hf_models/mixins/dense_TP/layer.py index e6f19076d..446069d64 100644 --- a/lm_engine/hf_models/mixins/dense_TP/layer.py +++ b/lm_engine/hf_models/mixins/dense_TP/layer.py @@ -7,8 +7,7 @@ import torch.nn as nn from ...config import CommonConfig -from ...modeling_utils import get_mlp_block, get_normalization_function -from ...modeling_utils_TP import get_sequence_mixer_TP +from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer from ..dense import Block @@ -33,7 +32,7 @@ def __init__( use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, ) - self.sequence_mixer = get_sequence_mixer_TP( + self.sequence_mixer = get_sequence_mixer( config, True, use_padding_free_transformer=use_padding_free_transformer, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index e4d42e081..643364b0a 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -23,6 +23,7 @@ def get_sequence_mixer( config: CommonConfig, causal: bool, use_padding_free_transformer: bool, + sequence_parallel: bool, layer_idx: int, ) -> SEQUENCE_MIXER_TYPE: block = config.sequence_mixer_blocks[layer_idx] @@ -160,6 +161,7 @@ def get_sequence_mixer( qkv_bias=block.qkv_bias, softmax_dropout=block.softmax_dropout, 
use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) else: raise ValueError(f"unexpected sequence_mixer_type ({sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index c18d10e68..0f72f19df 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -17,7 +17,7 @@ from ...parameter import mark_parameter_as_mup_learning_rate from ..chunk import contiguous_split from ..dropout import Dropout -from ..linear import ParameterizedLinear +from ..linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear from ..position_embedding import apply_rotary_pos_emb from .utils import flash_attention @@ -85,55 +85,91 @@ def __init__( m_width: float, num_layers: int, causal: bool, - layer_idx: int, - use_padding_free_transformer: bool, + layer_idx: int | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> Attention: super().__init__() + self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.causal = causal - self.hidden_size = hidden_size - self.num_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads + self.global_hidden_size = hidden_size + self.global_num_heads = num_attention_heads + self.global_num_key_value_heads = num_key_value_heads self.add_bias = add_bias self.qkv_bias = qkv_bias - self.use_padding_free_transformer = use_padding_free_transformer self.sliding_window = sliding_window - self.head_dim = divide_if_divisible( - self.hidden_size, - self.num_heads, - f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})", + self.use_padding_free_transformer = use_padding_free_transformer + self.sequence_parallel = sequence_parallel + + divide_if_divisible(self.global_hidden_size, self.global_num_heads) + + self.hidden_size = divide_if_divisible( + self.global_hidden_size, self.tp_world_size, "hidden_size should be divisible by TP world size" + ) + + self.num_heads = divide_if_divisible( + self.global_num_heads, self.tp_world_size, "num_heads must be divisible by TP world size" ) + self.head_dim = divide_if_divisible(self.hidden_size, self.num_heads, "") self.position_embedding_type = position_embedding_type self.attention_multiplier = attention_multiplier self.layer_idx = layer_idx divide_if_divisible( - self.num_heads, - self.num_key_value_heads, - f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` ({self.num_key_value_heads})", + self.global_num_heads, + self.global_num_key_value_heads, + f"`num_heads` ({self.global_num_heads}) should be a multiple of `num_key_value_heads` ({self.global_num_key_value_heads})", + ) + + self.num_key_value_heads = divide_if_divisible( + self.global_num_key_value_heads, + tp_world_size, + f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) std = initializer_range if init_method == "mup": std /= math.sqrt(m_width) - self.c_attn = ParameterizedLinear( - self.hidden_size, - self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, - bias=self.qkv_bias, + + self.c_attn = ColumnParallelLinear( + self.global_hidden_size, + self.global_hidden_size + 2 * self.global_num_key_value_heads * self.head_dim, + bias=self.add_bias, std=std, + 
use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) std = initializer_range / math.sqrt(2 * num_layers) if init_method == "mup": std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(self.hidden_size, self.hidden_size, bias=self.add_bias, std=std) + + self.c_proj = RowParallelLinear( + self.global_hidden_size, + self.global_hidden_size, + bias=self.add_bias, + std=std / math.sqrt(2 * num_layers), + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) self.softmax_dropout_p = softmax_dropout - self.softmax_dropout = Dropout(softmax_dropout) - self.dropout = Dropout(dropout) + self.softmax_dropout = Dropout( + softmax_dropout, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) + + self.dropout = Dropout( + dropout, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) mark_parameter_as_mup_learning_rate(self.c_attn.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) @@ -155,11 +191,12 @@ def forward( assert use_flash_attention_2 or use_flash_attention_3 assert past_key_values is None - T = hidden_states.size(0) + T = hidden_states.size(0) * (self.tp_world_size if self.sequence_parallel else 1) input_shape = (T, self.num_key_value_heads, -1) output_shape = (T, -1, self.head_dim) else: batch_size, query_length = hidden_states.shape[:-1] + query_length *= self.tp_world_size if self.sequence_parallel else 1 input_shape = (batch_size, query_length, self.num_key_value_heads, -1) output_shape = (batch_size, query_length, -1, self.head_dim) diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py index a657eddc7..b522b434a 100644 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ b/lm_engine/hf_models/modeling_utils_TP/__init__.py @@ -3,4 +3,3 @@ # ************************************************** from .lm_head import LMHead_TP -from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py deleted file mode 100644 index f14da7d8d..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ...config import CommonConfig -from .attention import Attention_TP - - -def get_sequence_mixer_TP( - config: CommonConfig, - causal: bool, - use_padding_free_transformer: bool, - layer_idx: int, - sequence_parallel: bool, -) -> Attention_TP: - block = config.sequence_mixer_blocks[layer_idx] - sequence_mixer_type = block.sequence_mixer_type - - sequence_mixer_kwargs = dict( - hidden_size=config.hidden_size, - num_attention_heads=block.num_attention_heads, - num_key_value_heads=block.num_key_value_heads, - attention_multiplier=block.attention_multiplier, - position_embedding_type=config.position_embedding_type, - add_bias=block.add_bias, - softmax_dropout=block.softmax_dropout, - dropout=block.dropout, - init_method=config.init_method, - initializer_range=config.initializer_range, - m_width=config.m_width, - num_layers=config.num_layers, - causal=causal, - layer_idx=layer_idx, - sequence_parallel=sequence_parallel, - ) - - if sequence_mixer_type == 
"softmax_attention": - return Attention_TP(**sequence_mixer_kwargs, use_padding_free_transformer=use_padding_free_transformer) diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py deleted file mode 100644 index 1b0944f94..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ /dev/null @@ -1,227 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ....enums import Kernel -from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import ProcessGroupManager, divide_if_divisible -from ...cache import GenerationCache -from ...modeling_utils import ( - Attention, - ColumnParallelLinear, - Dropout, - RowParallelLinear, - apply_rotary_pos_emb, - flash_attention, -) -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear - - -class Attention_TP(Attention): - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - attention_multiplier: float, - position_embedding_type: str, - add_bias: bool, - softmax_dropout: float, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - causal: bool, - layer_idx: int | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> Attention_TP: - nn.Module.__init__(self) - - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.causal = causal - self.global_hidden_size = hidden_size - self.global_num_heads = num_attention_heads - self.global_num_key_value_heads = num_key_value_heads - self.add_bias = add_bias - self.use_padding_free_transformer = use_padding_free_transformer - self.sequence_parallel = sequence_parallel - - divide_if_divisible( - self.global_hidden_size, - self.global_num_heads, - f"`embed_dim` ({self.global_hidden_size}) must be divisible by `num_heads` ({self.global_num_heads})", - ) - - self.hidden_size = divide_if_divisible( - self.global_hidden_size, tp_world_size, "hidden_size should be divisible by TP world size" - ) - - self.num_heads = divide_if_divisible( - self.global_num_heads, tp_world_size, "num_heads must be divisible by TP world size" - ) - - self.head_dim = divide_if_divisible(self.hidden_size, self.num_heads, "") - self.position_embedding_type = position_embedding_type - self.attention_multiplier = attention_multiplier - self.layer_idx = layer_idx - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - divide_if_divisible( - self.global_num_heads, - self.global_num_key_value_heads, - f"`num_heads` ({self.global_num_heads}) should be a multiple of `num_key_value_heads` ({self.global_num_key_value_heads})", - ) - - self.num_key_value_heads = divide_if_divisible( - self.global_num_key_value_heads, - tp_world_size, - f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - self.c_attn = ColumnParallelLinear( - self.global_hidden_size, - self.global_hidden_size + 2 * self.global_num_key_value_heads * self.head_dim, - bias=self.add_bias, - std=std, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.c_proj = RowParallelLinear( - 
self.global_hidden_size, - self.global_hidden_size, - bias=self.add_bias, - std=std / math.sqrt(2 * num_layers), - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.softmax_dropout_p = softmax_dropout - - self.softmax_dropout = Dropout( - softmax_dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.dropout = Dropout( - dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None - - total_q = hidden_states.shape[0] * (tp_world_size if self.sequence_parallel else 1) - input_shape = (total_q, self.num_key_value_heads, -1) - output_shape = (total_q, -1, self.head_dim) - else: - batch_size, query_length = hidden_states.shape[:-1] - query_length *= tp_world_size if self.sequence_parallel else 1 - - input_shape = (batch_size, query_length, self.num_key_value_heads, -1) - output_shape = (batch_size, query_length, -1, self.head_dim) - - hidden_states = self.c_attn(hidden_states) - - hidden_states = hidden_states.view(*input_shape) - - query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 - ) - - query = query.reshape(*output_shape) - - if not self.use_padding_free_transformer: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - if self.position_embedding_type == "rope": - query = apply_rotary_pos_emb(query, rope_cos_sin) - key = apply_rotary_pos_emb(key, rope_cos_sin) - - if past_key_values is not None: - key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) - - if use_flash_attention_2 or use_flash_attention_3: - if self.use_padding_free_transformer: - output_shape = (-1, self.hidden_size) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - output_shape = (batch_size, query_length, -1) - - query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) - key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) - value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) - - hidden_states = flash_attention( - q=query, - k=key, - v=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) - - del query, key, value - - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) - else: - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.softmax_dropout_p if 
self.training else 0, - is_causal=self.causal if attention_mask is None else False, - scale=self.attention_multiplier, - enable_gqa=True, - ) - - del query, key, value - - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) - - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 652e5e6b8..c54b686f3 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -23,7 +23,13 @@ def __init__( self.ln = get_normalization_function( config.normalization_function, config.hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer( + config, + True, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, + ) self.mlp_block = get_mlp_block( config, use_padding_free_transformer=use_padding_free_transformer, diff --git a/tests/hf_models/single_gpu/weight_test.py b/tests/hf_models/single_gpu/weight_test.py index cb488b9cf..78bcb5198 100644 --- a/tests/hf_models/single_gpu/weight_test.py +++ b/tests/hf_models/single_gpu/weight_test.py @@ -19,7 +19,7 @@ def test_query_key_value_weight_loading_and_saving(self) -> None: config = self.get_dense_test_config("learned_absolute") layer_idx = 1 - attention = get_sequence_mixer(config, True, False, layer_idx) + attention = get_sequence_mixer(config, True, False, False, layer_idx) num_key_value_heads = config.sequence_mixer_blocks[layer_idx].num_key_value_heads state_dict = attention.state_dict() From 68a2b0c214ea695ac003ed43645f929db7da0f97 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:57:49 -0800 Subject: [PATCH 84/99] fix linter Signed-off-by: Mayank Mishra --- .../modeling_utils/sequence_mixer_blocks/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 0f72f19df..50a8f7cf8 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -12,12 +12,12 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available +from ....utils import Accelerator, ProcessGroupManager, divide_if_divisible, is_torch_xla_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate from ..chunk import contiguous_split from ..dropout import Dropout -from ..linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear +from ..linear import ColumnParallelLinear, RowParallelLinear from ..position_embedding import apply_rotary_pos_emb from .utils import flash_attention From 62d6526a2e08e022f0a0ff4bc87a0fe2bd79c2e6 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 13:58:46 -0800 Subject: [PATCH 85/99] fix linter Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 50a8f7cf8..8db8b2b4c 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -127,7 +127,7 @@ def __init__( self.num_key_value_heads = divide_if_divisible( self.global_num_key_value_heads, - tp_world_size, + self.tp_world_size, f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) From 527dfeebc0047b326f9e6b056230b3093c058a8a Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:03:19 -0800 Subject: [PATCH 86/99] fix linter Signed-off-by: Mayank Mishra --- .../hf_models/modeling_utils/sequence_mixer_blocks/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 8db8b2b4c..2035c2722 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -128,7 +128,7 @@ def __init__( self.num_key_value_heads = divide_if_divisible( self.global_num_key_value_heads, self.tp_world_size, - f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({self.tp_world_size})", ) std = initializer_range From e2ec3b3f1f1bf9aa727c1fff7d0bb4c34983070d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:35:50 -0800 Subject: [PATCH 87/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/base.py | 18 +++++++++++++++- lm_engine/hf_models/mixins/dense/layer.py | 22 +++++++++++++++----- lm_engine/hf_models/mixins/dense_TP/base.py | 23 +-------------------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 9582ee384..4c8e66cf0 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -30,6 +30,14 @@ class PreTrainedModelMixin(PreTrainedModel): def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin: super().__init__(config, *args, **kwargs) + self.sequence_parallel = kwargs.get("sequence_parallel", False) + self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1) + self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0) + + self.is_first_stage = self.pipeline_stage_id == 0 + self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1 + self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 + assert self.config_class is not None self.generation_config = GenerationConfig.from_model_config(self.config) @@ -38,6 +46,9 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) + if self.is_pipeline_parallel_enabled and self._tied_word_embeddings: + raise NotImplementedError() + # FIXME typing def prepare_inputs_for_model( self, @@ -101,7 +112,12 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embedding_dropout = 
Dropout(config.embedding_dropout) self.h = nn.ModuleList( [ - self.layer_class(config, use_padding_free_transformer=self.use_padding_free_transformer, layer_idx=i) + self.layer_class( + config, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + layer_idx=i, + ) for i in range(config.num_layers) ] ) diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 1ca0214d6..e7d7a80de 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -14,7 +14,11 @@ class Block(nn.Module): def __init__( - self, config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int | None = None + self, + config: CommonConfig, + use_padding_free_transformer: bool, + layer_idx: int, + sequence_parallel: bool, ) -> Block: super().__init__() @@ -23,22 +27,30 @@ def __init__( self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type self.ln_1 = get_normalization_function( - config.normalization_function, hidden_size, eps=config.layer_norm_epsilon + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) self.sequence_mixer = get_sequence_mixer( config, True, use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=False, + sequence_parallel=sequence_parallel, layer_idx=layer_idx, ) self.ln_2 = get_normalization_function( - config.normalization_function, hidden_size, eps=config.layer_norm_epsilon + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) self.mlp_block = get_mlp_block( config, use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=False, + sequence_parallel=sequence_parallel, layer_idx=layer_idx, ) diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index ddd504a75..2daa66b88 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -14,30 +14,9 @@ from ...utils import is_generation_cache_enabled from ..dense import BaseModelMixin, PreTrainedModelMixin from ..modeling_outputs import BaseModelOutputWithPast -from .layer import Block_TP -class PreTrainedModelMixin_TP(PreTrainedModelMixin): - layer_class = Block_TP - _no_split_modules = ["Block_TP"] - - def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin_TP: - self.sequence_parallel = kwargs.get("sequence_parallel", False) - - self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1) - self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0) - - self.is_first_stage = self.pipeline_stage_id == 0 - self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1 - self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 - - super().__init__(config, *args, **kwargs) - - if self.is_pipeline_parallel_enabled and self._tied_word_embeddings: - raise NotImplementedError() - - -class BaseModelMixin_TP(PreTrainedModelMixin_TP, BaseModelMixin): +class BaseModelMixin_TP(PreTrainedModelMixin, BaseModelMixin): def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embed_dim = config.hidden_size self.max_position_embeddings = config.max_position_embeddings From 4abc04b1ad413db3d92191eabb99dc75b0ed7e0c Mon Sep 17 00:00:00 
2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:36:26 -0800 Subject: [PATCH 88/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/__init__.py | 2 +- .../hf_models/mixins/dense_TP/__init__.py | 3 +- lm_engine/hf_models/mixins/dense_TP/layer.py | 54 ------------------- 3 files changed, 2 insertions(+), 57 deletions(-) delete mode 100644 lm_engine/hf_models/mixins/dense_TP/layer.py diff --git a/lm_engine/hf_models/mixins/__init__.py b/lm_engine/hf_models/mixins/__init__.py index 2d9f6d03a..439bcccbb 100644 --- a/lm_engine/hf_models/mixins/__init__.py +++ b/lm_engine/hf_models/mixins/__init__.py @@ -3,7 +3,7 @@ # ************************************************** from .dense import BaseModelMixin, Block, CausalLMModelMixin, PreTrainedModelMixin -from .dense_TP import BaseModelMixin_TP, Block_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP +from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP from .modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, diff --git a/lm_engine/hf_models/mixins/dense_TP/__init__.py b/lm_engine/hf_models/mixins/dense_TP/__init__.py index 817c2815f..6d93899fa 100644 --- a/lm_engine/hf_models/mixins/dense_TP/__init__.py +++ b/lm_engine/hf_models/mixins/dense_TP/__init__.py @@ -2,6 +2,5 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from .base import BaseModelMixin_TP, PreTrainedModelMixin_TP -from .layer import Block_TP +from .base import BaseModelMixin_TP from .main import CausalLMModelMixin_TP diff --git a/lm_engine/hf_models/mixins/dense_TP/layer.py b/lm_engine/hf_models/mixins/dense_TP/layer.py deleted file mode 100644 index 446069d64..000000000 --- a/lm_engine/hf_models/mixins/dense_TP/layer.py +++ /dev/null @@ -1,54 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch.nn as nn - -from ...config import CommonConfig -from ...modeling_utils import get_mlp_block, get_normalization_function, get_sequence_mixer -from ..dense import Block - - -class Block_TP(Block): - def __init__( - self, - config: CommonConfig, - use_padding_free_transformer: bool, - layer_idx: int | None = None, - sequence_parallel: bool = False, - ) -> Block_TP: - nn.Module.__init__(self) - - hidden_size = config.hidden_size - self.m_residual = config.m_residual - self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type - - self.ln_1 = get_normalization_function( - config.normalization_function, - hidden_size, - eps=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - self.sequence_mixer = get_sequence_mixer( - config, - True, - use_padding_free_transformer=use_padding_free_transformer, - layer_idx=layer_idx, - sequence_parallel=sequence_parallel, - ) - self.ln_2 = get_normalization_function( - config.normalization_function, - hidden_size, - eps=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - self.mlp_block = get_mlp_block( - config, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - layer_idx=layer_idx, - ) From f9315211782512856a747c7149f553b924e30ef1 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:37:34 -0800 Subject: [PATCH 89/99] fix linter Signed-off-by: Mayank Mishra --- 
lm_engine/hf_models/mixins/dense_TP/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index 2daa66b88..af58dcef6 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -16,7 +16,7 @@ from ..modeling_outputs import BaseModelOutputWithPast -class BaseModelMixin_TP(PreTrainedModelMixin, BaseModelMixin): +class BaseModelMixin_TP(BaseModelMixin): def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embed_dim = config.hidden_size self.max_position_embeddings = config.max_position_embeddings From a47b79bf8fe7bffa71dfde061abff4dbc693fa52 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:42:49 -0800 Subject: [PATCH 90/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense_TP/main.py | 3 +-- lm_engine/hf_models/models/gpt_base_TP/base.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 0a03389d2..3f1da0773 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -29,10 +29,9 @@ PipelineParallelInput, PipelineParallelOutput, ) -from .base import PreTrainedModelMixin_TP -class CausalLMModelMixin_TP(PreTrainedModelMixin_TP, CausalLMModelMixin): +class CausalLMModelMixin_TP(CausalLMModelMixin): model_parallel_state_dict_function = None def _init_model(self, config: CommonConfig, **kwargs) -> None: diff --git a/lm_engine/hf_models/models/gpt_base_TP/base.py b/lm_engine/hf_models/models/gpt_base_TP/base.py index e5236b0ac..33fb7cd48 100644 --- a/lm_engine/hf_models/models/gpt_base_TP/base.py +++ b/lm_engine/hf_models/models/gpt_base_TP/base.py @@ -2,11 +2,11 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from ...mixins import BaseModelMixin_TP, PreTrainedModelMixin_TP +from ...mixins import BaseModelMixin_TP, PreTrainedModelMixin from ..gpt_base import GPTBaseConfig -class GPTBasePreTrainedModel_TP(PreTrainedModelMixin_TP): +class GPTBasePreTrainedModel_TP(PreTrainedModelMixin): config_class = GPTBaseConfig From 40eb0ad0b36a1b00f48b5b91f65e84162a1ea079 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:47:48 -0800 Subject: [PATCH 91/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense_TP/main.py | 6 ++-- .../hf_models/modeling_utils/__init__.py | 1 + .../lm_head.py | 32 +++++++++---------- .../hf_models/modeling_utils_TP/__init__.py | 5 --- 4 files changed, 19 insertions(+), 25 deletions(-) rename lm_engine/hf_models/{modeling_utils_TP => modeling_utils}/lm_head.py (75%) delete mode 100644 lm_engine/hf_models/modeling_utils_TP/__init__.py diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 3f1da0773..dd43d804a 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -20,7 +20,7 @@ get_aux_loss, is_aux_loss_zero, ) -from ...modeling_utils_TP import LMHead_TP +from ...modeling_utils import LMHead from ...parameter import _INIT_MARKER, get_parameter_marker_maps, set_parameter_marker_maps from ..dense import CausalLMModelMixin from ..modeling_outputs import ( @@ -40,7 +40,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: if self.is_last_stage: if not self._tied_word_embeddings: - self.lm_head = 
LMHead_TP( + self.lm_head = LMHead( self.vocab_size, config.hidden_size, std=config.initializer_range, @@ -165,7 +165,7 @@ def forward( def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return ( - LMHead_TP.compute_with_weight( + LMHead.compute_with_weight( hidden_states, weight=self.transformer.wte.weight, use_padding_free_transformer=self.use_padding_free_transformer, diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index b78cf3815..3404b87d3 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -8,6 +8,7 @@ from .dtensor_module import DTensorModule from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info from .linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear +from .lm_head import LMHead from .mlp_blocks import ( MLP, MoE, diff --git a/lm_engine/hf_models/modeling_utils_TP/lm_head.py b/lm_engine/hf_models/modeling_utils/lm_head.py similarity index 75% rename from lm_engine/hf_models/modeling_utils_TP/lm_head.py rename to lm_engine/hf_models/modeling_utils/lm_head.py index 18584dae1..a37966671 100644 --- a/lm_engine/hf_models/modeling_utils_TP/lm_head.py +++ b/lm_engine/hf_models/modeling_utils/lm_head.py @@ -12,10 +12,10 @@ from ..modeling_utils.TP import get_module_placements -class LMHead_TP(ParameterizedEmbedding): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class LMHead(ParameterizedEmbedding): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.compute_with_weight( - input, + x, self.weight, use_padding_free_transformer=self.use_padding_free_transformer, sequence_parallel=self.sequence_parallel, @@ -24,18 +24,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: @staticmethod def compute_with_weight( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, tp_mesh: DeviceMesh, ) -> torch.Tensor: - function = ( - LMHead_TP._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead_TP._compute_with_weight - ) + function = LMHead._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead._compute_with_weight return function( - input=input, + input=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, @@ -44,33 +42,33 @@ def compute_with_weight( @staticmethod def _compute_with_weight( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, tp_mesh: DeviceMesh, ) -> torch.Tensor: - input = tensor_to_dtensor( - input, + x = tensor_to_dtensor( + x, device_mesh=tp_mesh, current_placement=get_module_placements(use_padding_free_transformer, sequence_parallel), desired_placement=Replicate(), ) - input = F.linear(input, weight) - input = dtensor_to_tensor(input, device_mesh=tp_mesh, desired_placement=Shard(-1)) - return input + x = F.linear(x, weight) + x = dtensor_to_tensor(x, device_mesh=tp_mesh, desired_placement=Shard(-1)) + return x @torch.compile @staticmethod def _compute_with_weight_compiled( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, tp_mesh: DeviceMesh, ) -> torch.Tensor: - return LMHead_TP._compute_with_weight( - input=input, + return LMHead._compute_with_weight( + input=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, 
sequence_parallel=sequence_parallel, diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py deleted file mode 100644 index b522b434a..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from .lm_head import LMHead_TP From 98fc18a65641387705b468226562bd94eff5bb94 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 14:49:01 -0800 Subject: [PATCH 92/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/lm_head.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/lm_head.py b/lm_engine/hf_models/modeling_utils/lm_head.py index a37966671..5b9cdf394 100644 --- a/lm_engine/hf_models/modeling_utils/lm_head.py +++ b/lm_engine/hf_models/modeling_utils/lm_head.py @@ -33,7 +33,7 @@ def compute_with_weight( function = LMHead._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead._compute_with_weight return function( - input=x, + x=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, @@ -68,7 +68,7 @@ def _compute_with_weight_compiled( tp_mesh: DeviceMesh, ) -> torch.Tensor: return LMHead._compute_with_weight( - input=x, + x=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, From a7786aa553a591dc67275113e3d2659313dccc23 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 15:03:28 -0800 Subject: [PATCH 93/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 86 +++++++++++++++++++++ lm_engine/hf_models/mixins/dense_TP/main.py | 85 -------------------- 2 files changed, 86 insertions(+), 85 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 98e0a1303..3f4660b4d 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -10,6 +10,7 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed +from ....utils import ProcessGroupManager from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero @@ -241,3 +242,88 @@ def generate( ) return generated_tokens + + def get_dummy_input_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[torch.Tensor] | torch.Tensor: + if self.is_first_stage: + # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason + dummy_input = torch.empty( + micro_batch_size, sequence_length + 1, device=torch.cuda.current_device(), dtype=torch.long + ) + else: + dummy_input = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + dummy_input = ( + dummy_input, + torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype), + ) + + return dummy_input + + def get_dummy_output_tensor( + self, + micro_batch_size: int, + sequence_length: int, + intermediate_dtype: torch.dtype, + output_parallel_lm_logits_if_possible: bool, + ) -> tuple[torch.Tensor] | torch.Tensor: + if self.is_last_stage: + vocab_size = self.config.vocab_size + if 
output_parallel_lm_logits_if_possible: + vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") + + if self.use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + tensor = (tensor, torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype)) + + return tensor + + def _get_dummy_intermediate_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[torch.Tensor] | torch.Tensor: + sharded_sequence_length = ( + divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "") + if self.sequence_parallel + else sequence_length + ) + + hidden_size = self.config.hidden_size + + if self.use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + + return tensor diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index dd43d804a..24a64b547 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -213,88 +213,3 @@ def load_from_safetensors_weights_manager(self, safetensors_weights_manager: Saf ) self.load_state_dict(state_dict) - - def get_dummy_input_tensor( - self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype - ) -> tuple[torch.Tensor] | torch.Tensor: - if self.is_first_stage: - # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason - dummy_input = torch.empty( - micro_batch_size, sequence_length + 1, device=torch.cuda.current_device(), dtype=torch.long - ) - else: - dummy_input = self._get_dummy_intermediate_tensor( - micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype - ) - - dummy_input = ( - dummy_input, - torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype), - ) - - return dummy_input - - def get_dummy_output_tensor( - self, - micro_batch_size: int, - sequence_length: int, - intermediate_dtype: torch.dtype, - output_parallel_lm_logits_if_possible: bool, - ) -> tuple[torch.Tensor] | torch.Tensor: - if self.is_last_stage: - vocab_size = self.config.vocab_size - if output_parallel_lm_logits_if_possible: - vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") - - if self.use_padding_free_transformer: - tensor = torch.empty( - micro_batch_size * sequence_length, - vocab_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = torch.empty( - micro_batch_size, - sequence_length, - vocab_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = self._get_dummy_intermediate_tensor( - micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype - ) - - tensor = (tensor, torch.empty(1, 
device=torch.cuda.current_device(), dtype=intermediate_dtype)) - - return tensor - - def _get_dummy_intermediate_tensor( - self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype - ) -> tuple[torch.Tensor] | torch.Tensor: - sharded_sequence_length = ( - divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "") - if self.sequence_parallel - else sequence_length - ) - - hidden_size = self.config.hidden_size - - if self.use_padding_free_transformer: - tensor = torch.empty( - micro_batch_size * sharded_sequence_length, - hidden_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = torch.empty( - micro_batch_size, - sharded_sequence_length, - hidden_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - - return tensor From 7c056dc72e123cfebf2dc68198882ae9f50c0a7d Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 15:12:46 -0800 Subject: [PATCH 94/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/mixins/dense/main.py | 18 +++++++++++++++++- lm_engine/hf_models/mixins/dense_TP/main.py | 16 ---------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 3f4660b4d..e68a651e5 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -10,7 +10,7 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed -from ....utils import ProcessGroupManager +from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero @@ -327,3 +327,19 @@ def _get_dummy_intermediate_tensor( ) return tensor + + def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: + with torch.device(torch.cuda.current_device()): + position_embedding_type = self.config.position_embedding_type + + if position_embedding_type == "rope": + self.transformer.rope.reset_parameters() + + state_dict = self.__class__.model_parallel_state_dict_function( + config=self.config, + safetensors_weights_manager=safetensors_weights_manager, + num_pipeline_stages=self.num_pipeline_stages, + pipeline_stage_id=self.pipeline_stage_id, + ) + + self.load_state_dict(state_dict) diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 24a64b547..88c29add9 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -197,19 +197,3 @@ def from_pretrained( model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) return model - - def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: - with torch.device(torch.cuda.current_device()): - position_embedding_type = self.config.position_embedding_type - - if position_embedding_type == "rope": - self.transformer.rope.reset_parameters() - - state_dict = self.__class__.model_parallel_state_dict_function( - config=self.config, - safetensors_weights_manager=safetensors_weights_manager, - num_pipeline_stages=self.num_pipeline_stages, - pipeline_stage_id=self.pipeline_stage_id, - ) - - self.load_state_dict(state_dict) From 
f6b091078687e51e36f9d848f101c2149278ee2e Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 15:56:14 -0800 Subject: [PATCH 95/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/mlp.py | 5 +--- lm_engine/hf_models/mixins/dense/base.py | 16 ++++++++++-- lm_engine/hf_models/mixins/dense/main.py | 26 ++++++++++++++----- lm_engine/hf_models/mixins/dense_TP/main.py | 20 -------------- .../hf_models/modeling_utils/embedding.py | 2 +- lm_engine/hf_models/modeling_utils/lm_head.py | 24 +++++++++++------ .../modeling_utils/mlp_blocks/__init__.py | 3 ++- .../modeling_utils/mlp_blocks/moe.py | 5 ++-- lm_engine/utils/parallel.py | 3 +++ tests/hf_models/test_common.py | 1 - .../params_group/params_group_test.py | 11 +++----- 11 files changed, 63 insertions(+), 53 deletions(-) diff --git a/lm_engine/hf_models/config/mlp.py b/lm_engine/hf_models/config/mlp.py index 05ee4144d..699493423 100644 --- a/lm_engine/hf_models/config/mlp.py +++ b/lm_engine/hf_models/config/mlp.py @@ -18,11 +18,8 @@ def model_post_init(self, __context: Any) -> None: assert self.mlp_type == "MLP" -class _MoEArgs(BaseArgs): +class _MoEArgs(_MLPArgs): mlp_type: str = "MoE" - intermediate_size: int - activation_function: str = "gelu_pytorch_tanh" - dropout: float = 0 shared_intermediate_size: int | None = None num_experts: int = 8 use_interleaved_weights: bool = False diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 4c8e66cf0..9aa134a0e 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -107,7 +107,13 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: config.sequence_mixer_blocks[i].sequence_mixer_type for i in range(config.num_layers) ] - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) self.embedding_dropout = Dropout(config.embedding_dropout) self.h = nn.ModuleList( @@ -328,7 +334,13 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings if self.position_embedding_type == "learned_absolute": - self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range) + self.wpe = ParameterizedEmbedding( + max_position_embeddings, + self.embed_dim, + std=self.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=False, + ) elif self.position_embedding_type == "rope": if self.config.rope_scaling is None: self.rope = RoPE( diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index e68a651e5..b16ba0f1c 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -14,13 +14,14 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero -from ...modeling_utils import ParameterizedLinear +from ...modeling_utils import LMHead, ParameterizedLinear from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin class CausalLMModelMixin(PreTrainedModelMixin): base_model_class = None + model_parallel_state_dict_function = None 
def __init__(self, config: CommonConfig, **kwargs) -> CausalLMModelMixin: super().__init__(config, **kwargs) @@ -29,14 +30,25 @@ def __init__(self, config: CommonConfig, **kwargs) -> CausalLMModelMixin: self._init_model(config, **kwargs) def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.vocab_size = config.vocab_size self.transformer = self.base_model_class(config, **kwargs) - if not self._tied_word_embeddings: - self.lm_head = ParameterizedLinear( - config.hidden_size, config.vocab_size, bias=False, std=config.initializer_range - ) + if self.is_last_stage: + if not self._tied_word_embeddings: + self.lm_head = LMHead( + self.vocab_size, + config.hidden_size, + std=config.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) + + self.m_width = config.m_width + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - self.m_width = config.m_width + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() def forward( self, @@ -271,7 +283,7 @@ def get_dummy_output_tensor( output_parallel_lm_logits_if_possible: bool, ) -> tuple[torch.Tensor] | torch.Tensor: if self.is_last_stage: - vocab_size = self.config.vocab_size + vocab_size = self.vocab_size if output_parallel_lm_logits_if_possible: vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 88c29add9..673e33fc0 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -32,26 +32,6 @@ class CausalLMModelMixin_TP(CausalLMModelMixin): - model_parallel_state_dict_function = None - - def _init_model(self, config: CommonConfig, **kwargs) -> None: - self.vocab_size = config.vocab_size - self.transformer = self.base_model_class(config, **kwargs) - - if self.is_last_stage: - if not self._tied_word_embeddings: - self.lm_head = LMHead( - self.vocab_size, - config.hidden_size, - std=config.initializer_range, - use_padding_free_transformer=self.use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, - ) - - self.m_width = config.m_width - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py index aa0b48d38..f129d46cc 100644 --- a/lm_engine/hf_models/modeling_utils/embedding.py +++ b/lm_engine/hf_models/modeling_utils/embedding.py @@ -30,10 +30,10 @@ def __init__( nn.Module.__init__(self) self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + self.use_padding_free_transformer = use_padding_free_transformer if self.is_tp_enabled: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - self.use_padding_free_transformer = use_padding_free_transformer self.sequence_parallel = sequence_parallel self.vocab_start_index, self.vocab_end_index, num_embeddings_per_tp_rank = get_tensor_parallel_vocab_info( diff --git a/lm_engine/hf_models/modeling_utils/lm_head.py b/lm_engine/hf_models/modeling_utils/lm_head.py index 5b9cdf394..44a890e36 100644 --- a/lm_engine/hf_models/modeling_utils/lm_head.py +++ b/lm_engine/hf_models/modeling_utils/lm_head.py @@ -14,13 +14,18 @@ class LMHead(ParameterizedEmbedding): def forward(self, x: torch.Tensor) -> torch.Tensor: - return 
self.compute_with_weight( - x, - self.weight, - use_padding_free_transformer=self.use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, - tp_mesh=self.tp_mesh, - ) + if self.is_tp_enabled: + x = self.compute_with_weight( + x, + self.weight, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + tp_mesh=self.tp_mesh if self.is_tp_enabled else None, + ) + else: + x = F.linear(x, weight=self.weight) + + return x @staticmethod def compute_with_weight( @@ -28,8 +33,11 @@ def compute_with_weight( weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, - tp_mesh: DeviceMesh, + tp_mesh: DeviceMesh | None, ) -> torch.Tensor: + if tp_mesh is None: + return F.linear(x, weight=weight) + function = LMHead._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead._compute_with_weight return function( diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index c3e408f8d..1520672ec 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -17,6 +17,7 @@ def get_mlp_block( hidden_size=config.hidden_size, intermediate_size=block.intermediate_size, activation_function=block.activation_function, + add_bias=block.add_bias, dropout=block.dropout, init_method=config.init_method, initializer_range=config.initializer_range, @@ -27,7 +28,7 @@ def get_mlp_block( ) if mlp_type == "MLP": - mlp = MLP(**kwargs, add_bias=block.add_bias) + mlp = MLP(**kwargs) elif mlp_type == "MoE": mlp = MoE( **kwargs, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index eb08a7b49..8724bfecc 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -233,6 +233,7 @@ def __init__( normalized_topk: bool, num_experts: int, num_experts_per_tok: int, + add_bias: bool, activation_function: str, dropout: float, init_method: str, @@ -278,7 +279,7 @@ def __init__( out_features=( 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size ), - bias=False, + bias=add_bias, std=std, ) @@ -295,7 +296,7 @@ def __init__( self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, out_features=self.hidden_size, - bias=False, + bias=add_bias, std=std, ) diff --git a/lm_engine/utils/parallel.py b/lm_engine/utils/parallel.py index 689710b4b..904b6d216 100644 --- a/lm_engine/utils/parallel.py +++ b/lm_engine/utils/parallel.py @@ -216,6 +216,9 @@ def set_dummy_tensor_parallel_rank(rank: int): @staticmethod def get_tensor_parallel_world_size() -> int: + if not ProcessGroupManager.is_tensor_parallel_enabled(): + return 1 + global _TENSOR_PARALLEL_WORLD_SIZE if _TENSOR_PARALLEL_WORLD_SIZE is None: diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index 914760684..719a8df92 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -131,7 +131,6 @@ def get_moe_test_config( "num_experts_per_tok": num_experts_per_tok, "normalized_topk": normalized_topk, "activation_function": activation_function, - "add_bias": add_bias, "shared_intermediate_size": None if shared_n_inner is None else shared_n_inner, "shared_expert_gating": shared_expert_gating, } diff --git a/tests/training/params_group/params_group_test.py 
b/tests/training/params_group/params_group_test.py index 6a3dcc59a..20b3e6c4f 100644 --- a/tests/training/params_group/params_group_test.py +++ b/tests/training/params_group/params_group_test.py @@ -8,12 +8,9 @@ import torch from parameterized import parameterized -from lm_engine.distributed import ( - _get_parameter_marker_maps, - _set_parameter_marker_maps, - wrap_model_container_for_distributed_training, -) +from lm_engine.distributed import wrap_model_container_for_distributed_training from lm_engine.enums import ParamsGroupMethod +from lm_engine.hf_models import get_parameter_marker_maps, set_parameter_marker_maps from lm_engine.model_wrapper import get_model_container from lm_engine.optimization.params_group import get_param_groups_list from lm_engine.utils import ProcessGroupManager @@ -55,9 +52,9 @@ def test_mup_group( if use_fsdp: model_container, _ = wrap_model_container_for_distributed_training(args, model_container) elif use_torch_compile: - marker_maps = _get_parameter_marker_maps(model_container) + marker_maps = get_parameter_marker_maps(model_container) model_container = [torch.compile(model) for model in model_container] - _set_parameter_marker_maps(model_container, marker_maps) + set_parameter_marker_maps(model_container, marker_maps) params_groups = get_param_groups_list(model_container, args.optimizer_args.class_args, params_group_method)[0] From 7b7261bf79d34b913a4618efd1e6b7a3c805119b Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 17:11:43 -0800 Subject: [PATCH 96/99] fix linter Signed-off-by: Mayank Mishra --- lm_engine/hf_models/config/sequence_mixer.py | 5 - lm_engine/hf_models/mixins/dense/main.py | 21 +- lm_engine/hf_models/mixins/dense_TP/main.py | 18 + .../hf_models/model_conversion/__init__.py | 8 - .../hf_models/model_conversion/qwen2_moe.py | 345 ------------------ .../hf_models/modeling_utils/embedding.py | 14 +- .../modeling_utils/mlp_blocks/moe.py | 56 ++- .../sequence_mixer_blocks/__init__.py | 1 - .../sequence_mixer_blocks/attention.py | 2 - lm_engine/hf_models/parameter.py | 8 +- .../single_gpu/model_conversion_test.py | 30 -- tests/hf_models/test_common.py | 4 - .../params_group/params_group_test.py | 2 +- 13 files changed, 82 insertions(+), 432 deletions(-) delete mode 100644 lm_engine/hf_models/model_conversion/qwen2_moe.py diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 4cd8537f8..a2eec5652 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -16,13 +16,8 @@ class _SoftmaxAttentionArgs(BaseArgs): add_bias: bool = False attention_multiplier: float | None = None sliding_window: int | None = None - # needed for Qwen 2 MoE - qkv_bias: bool = None def model_post_init(self, __context: Any) -> None: - if self.qkv_bias is None: - self.qkv_bias = self.add_bias - assert self.sequence_mixer_type == "softmax_attention" diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index b16ba0f1c..273d2dc4f 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -21,7 +21,6 @@ class CausalLMModelMixin(PreTrainedModelMixin): base_model_class = None - model_parallel_state_dict_function = None def __init__(self, config: CommonConfig, **kwargs) -> CausalLMModelMixin: super().__init__(config, **kwargs) @@ -46,9 +45,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.m_width = config.m_width self.is_tp_enabled = 
ProcessGroupManager.is_tensor_parallel_enabled() - - if self.is_tp_enabled: - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() if self.is_tp_enabled else None def forward( self, @@ -339,19 +336,3 @@ def _get_dummy_intermediate_tensor( ) return tensor - - def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: - with torch.device(torch.cuda.current_device()): - position_embedding_type = self.config.position_embedding_type - - if position_embedding_type == "rope": - self.transformer.rope.reset_parameters() - - state_dict = self.__class__.model_parallel_state_dict_function( - config=self.config, - safetensors_weights_manager=safetensors_weights_manager, - num_pipeline_stages=self.num_pipeline_stages, - pipeline_stage_id=self.pipeline_stage_id, - ) - - self.load_state_dict(state_dict) diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 673e33fc0..9b23a9f9d 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -32,6 +32,8 @@ class CausalLMModelMixin_TP(CausalLMModelMixin): + model_parallel_state_dict_function = None + def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, @@ -177,3 +179,19 @@ def from_pretrained( model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) return model + + def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: + with torch.device(torch.cuda.current_device()): + position_embedding_type = self.config.position_embedding_type + + if position_embedding_type == "rope": + self.transformer.rope.reset_parameters() + + state_dict = self.__class__.model_parallel_state_dict_function( + config=self.config, + safetensors_weights_manager=safetensors_weights_manager, + num_pipeline_stages=self.num_pipeline_stages, + pipeline_stage_id=self.pipeline_stage_id, + ) + + self.load_state_dict(state_dict) diff --git a/lm_engine/hf_models/model_conversion/__init__.py b/lm_engine/hf_models/model_conversion/__init__.py index bb1214461..8c4820e94 100644 --- a/lm_engine/hf_models/model_conversion/__init__.py +++ b/lm_engine/hf_models/model_conversion/__init__.py @@ -22,12 +22,6 @@ _import_granitemoeshared_state_dict, ) from .llama import _export_llama_config, _export_llama_state_dict, _import_llama_config, _import_llama_state_dict -from .qwen2_moe import ( - _export_qwen2_moe_config, - _export_qwen2_moe_state_dict, - _import_qwen2_moe_config, - _import_qwen2_moe_state_dict, -) _MODEL_IMPORT_FUNCTIONS = { @@ -36,7 +30,6 @@ "granitemoeshared": (_import_granitemoeshared_config, _import_granitemoeshared_state_dict), "granitemoehybrid": (_import_granitemoehybrid_config, _import_granitemoehybrid_state_dict), "llama": (_import_llama_config, _import_llama_state_dict), - "qwen2_moe": (_import_qwen2_moe_config, _import_qwen2_moe_state_dict), } @@ -77,7 +70,6 @@ def import_from_huggingface( "granitemoeshared": (_export_granitemoeshared_config, _export_granitemoeshared_state_dict), "granitemoehybrid": (_export_granitemoehybrid_config, _export_granitemoehybrid_state_dict), "llama": (_export_llama_config, _export_llama_state_dict), - "qwen2_moe": (_export_qwen2_moe_config, _export_qwen2_moe_state_dict), } diff --git a/lm_engine/hf_models/model_conversion/qwen2_moe.py b/lm_engine/hf_models/model_conversion/qwen2_moe.py deleted file mode 100644 
index 9fe75906e..000000000 --- a/lm_engine/hf_models/model_conversion/qwen2_moe.py +++ /dev/null @@ -1,345 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch -from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM - -from ...utils import SafeTensorsWeightsManager, divide_if_divisible -from ..modeling_utils import ( - interleave_query_key_value_tensor_for_attention, - interleave_up_gate_tensor_for_mlp, - split_query_key_value_tensor_for_attention, - split_up_gate_tensor_for_mlp, -) -from ..models import GPTBaseConfig - - -def _import_qwen2_moe_config(original_config: Qwen2MoeConfig) -> GPTBaseConfig: - assert original_config.hidden_act == "silu" - - mlp_blocks = [] - for layer_idx in range(original_config.num_hidden_layers): - if (layer_idx not in original_config.mlp_only_layers) and ( - original_config.num_experts > 0 and (layer_idx + 1) % original_config.decoder_sparse_step == 0 - ): - mlp_block = { - "mlp_type": "MoE", - "intermediate_size": original_config.moe_intermediate_size, - "shared_intermediate_size": original_config.shared_expert_intermediate_size, - "shared_expert_gating": True, - "num_experts": original_config.num_experts, - "num_experts_per_tok": original_config.num_experts_per_tok, - "activation_function": "swiglu", - "add_bias": False, - "normalized_topk": original_config.norm_topk_prob, - } - else: - mlp_block = { - "mlp_type": "MLP", - "intermediate_size": original_config.intermediate_size, - "activation_function": "swiglu", - "add_bias": False, - } - - mlp_blocks.append(mlp_block) - - sequence_mixer_blocks = [] - for layer_idx in range(original_config.num_hidden_layers): - sliding_window = None - if original_config.use_sliding_window and layer_idx >= original_config.max_window_layers: - sliding_window = original_config.sliding_window - - sequence_mixer_block = { - "sequence_mixer_type": "softmax_attention", - "num_attention_heads": original_config.num_attention_heads, - "num_key_value_heads": original_config.num_key_value_heads, - "add_bias": False, - "sliding_window": sliding_window, - "qkv_bias": original_config.qkv_bias, - "softmax_dropout": original_config.attention_dropout, - } - - sequence_mixer_blocks.append(sequence_mixer_block) - - config = GPTBaseConfig( - vocab_size=original_config.vocab_size, - max_position_embeddings=original_config.max_position_embeddings, - hidden_size=original_config.hidden_size, - num_layers=original_config.num_hidden_layers, - position_embedding_type="rope", - normalization_function="rmsnorm", - layer_norm_epsilon=original_config.rms_norm_eps, - use_cache=original_config.use_cache, - tie_word_embeddings=original_config.tie_word_embeddings, - initializer_range=original_config.initializer_range, - rope_theta=original_config.rope_theta, - rope_scaling=original_config.rope_scaling, - router_aux_loss_coef=original_config.router_aux_loss_coef, - bos_token_id=original_config.bos_token_id, - eos_token_id=original_config.eos_token_id, - pad_token_id=original_config.pad_token_id, - sequence_mixer_blocks=sequence_mixer_blocks, - mlp_blocks=mlp_blocks, - ) - - return config - - -def _import_qwen2_moe_state_dict( - config: GPTBaseConfig, safetensors_weights_manager: SafeTensorsWeightsManager -) -> dict: - num_attention_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads") - num_key_value_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads") - 
qkv_bias = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias") - head_dim = divide_if_divisible(config.hidden_size, num_attention_heads, "") - - state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), - } - - if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") - - for layer_idx in range(config.num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" - ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ) - - # MoE - if safetensors_weights_manager.has_tensor(f"model.layers.{layer_idx}.mlp.gate.weight"): - state_dict[f"transformer.h.{layer_idx}.mlp_block.gate.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.gate.weight" - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.weight"] = torch.stack( - [ - interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" - ), - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" - ), - ) - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts) - ] - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"] = torch.stack( - [ - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" - ) - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts) - ] - ) - - if safetensors_weights_manager.has_tensor(f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"): - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight"] = ( - interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight" - ), - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight" - ), - ) - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj_shared.weight"] = ( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight" - ) - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.shared_expert_gate.weight"] = ( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight") - ) - # MLP - else: - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.weight"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"), - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.weight" - ) - - keys = ["weight"] + (["bias"] if qkv_bias else []) - for key in keys: - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_attn.{key}"] = ( - interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.{key}"), - 
safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.{key}"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.{key}"), - num_attention_heads, - num_key_value_heads, - head_dim, - ) - ) - - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" - ) - - return state_dict - - -def _export_qwen2_moe_config(config: GPTBaseConfig) -> Qwen2MoeConfig: - assert config.normalization_function == "rmsnorm" - assert config.position_embedding_type == "rope" - - config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "add_bias", False) - config.check_equal_for_all_and_get_value("mlp_blocks", "add_bias", False) - config.check_equal_for_all_and_get_value("mlp_blocks", "activation_function", "swiglu") - - mlp_only_layers = [ - layer_idx for layer_idx, mlp_block in enumerate(config.mlp_blocks) if mlp_block.mlp_type == "MLP" - ] - - max_window_layers = None - use_sliding_window = False - for layer_idx in range(config.num_layers): - block = config.sequence_mixer_blocks[layer_idx] - if config.sequence_mixer_blocks[layer_idx]: - use_sliding_window = use_sliding_window or block.sliding_window is not None - if max_window_layers is None and use_sliding_window: - max_window_layers = layer_idx - - original_config = Qwen2MoeConfig( - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_layers, - num_attention_heads=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads"), - num_key_value_heads=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads"), - intermediate_size=config.check_equal_for_all_and_get_value("mlp_blocks", "intermediate_size", mlp_type="MLP"), - moe_intermediate_size=config.check_equal_for_all_and_get_value( - "mlp_blocks", "intermediate_size", mlp_type="MoE" - ), - shared_expert_intermediate_size=config.check_equal_for_all_and_get_value( - "mlp_blocks", "shared_intermediate_size", mlp_type="MoE" - ), - hidden_act="silu", - rms_norm_eps=config.layer_norm_epsilon, - use_cache=config.use_cache, - use_sliding_window=use_sliding_window, - max_window_layers=max_window_layers, - tie_word_embeddings=config.tie_word_embeddings, - initializer_range=config.initializer_range, - rope_theta=config.rope_theta, - rope_scaling=config.rope_scaling, - attention_dropout=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "softmax_dropout"), - num_experts=config.check_equal_for_all_and_get_value("mlp_blocks", "num_experts", mlp_type="MoE"), - num_experts_per_tok=config.check_equal_for_all_and_get_value( - "mlp_blocks", "num_experts_per_tok", mlp_type="MoE" - ), - router_aux_loss_coef=config.router_aux_loss_coef, - bos_token_id=config.bos_token_id, - eos_token_id=config.eos_token_id, - pad_token_id=config.pad_token_id, - qkv_bias=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias"), - mlp_only_layers=mlp_only_layers, - norm_topk_prob=config.check_equal_for_all_and_get_value("mlp_blocks", "normalized_topk", mlp_type="MoE"), - architectures=[Qwen2MoeForCausalLM.__name__], - ) - - return original_config - - -def _export_qwen2_moe_state_dict( - config: GPTBaseConfig, safetensors_weights_manager: SafeTensorsWeightsManager -) -> dict: - num_attention_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads") 
- num_key_value_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads") - qkv_bias = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias") - - state_dict = { - "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), - } - - if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") - - for layer_idx in range(config.num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" - ) - state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") - ) - - # MoE layer - if safetensors_weights_manager.has_tensor(f"transformer.h.{layer_idx}.mlp_block.gate.weight"): - state_dict[f"model.layers.{layer_idx}.mlp.gate.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.gate.weight" - ) - - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts): - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.weight")[ - expert_idx - ] - ) - - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"] = gate_weight - - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_proj.weight")[ - expert_idx - ] - ) - - if safetensors_weights_manager.has_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight"): - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"] = gate_weight - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_proj_shared.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight"] = ( - safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.shared_expert_gate.weight" - ) - ) - # MLP layer - else: - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight - - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.c_proj.weight" - ) - - keys = ["weight"] + (["bias"] if qkv_bias else []) - for key in keys: - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.sequence_mixer.c_attn.{key}"), - num_attention_heads, - num_key_value_heads, - ) - state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.{key}"] = 
query_weight - state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.{key}"] = key_weight - state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.{key}"] = value_weight - - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight" - ) - - return state_dict diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py index f129d46cc..98119bba0 100644 --- a/lm_engine/hf_models/modeling_utils/embedding.py +++ b/lm_engine/hf_models/modeling_utils/embedding.py @@ -27,7 +27,10 @@ def __init__( use_padding_free_transformer: bool = False, sequence_parallel: bool = False, ) -> ParameterizedEmbedding: - nn.Module.__init__(self) + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() self.use_padding_free_transformer = use_padding_free_transformer @@ -68,13 +71,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.no_grad() def reset_parameters(self) -> None: - if self.std is None: - super().reset_parameters() - else: - self.weight.normal_(mean=0, std=self.std) - + self.weight.normal_(mean=0, std=1 if self.std is None else self.std) mark_parameter_as_initialized(self.weight) + def extra_repr(self) -> str: + return f"{self.num_embeddings}, {self.embedding_dim}" + def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]: tp_rank = ProcessGroupManager.get_tensor_parallel_rank() diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 8724bfecc..c5c8f9fb2 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -17,7 +17,11 @@ from ....kernels import is_kernel_allowed, wait_for_ACT from ....utils import ProcessGroupManager, divide_if_divisible, is_sonicmoe_available, is_xma_available from ...loss import add_aux_loss -from ...parameter import mark_parameter_as_initialized, mark_parameter_as_mup_learning_rate +from ...parameter import ( + mark_parameter_as_initialized, + mark_parameter_as_mup_learning_rate, + mark_parameter_as_no_weight_decay, +) from ..activations import get_activation_function, is_glu from ..dropout import Dropout from ..dtensor_module import DTensorModule @@ -66,13 +70,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class ParameterizedExperts(nn.Module): def __init__( - self, num_experts: int, in_features: int, out_features: int, std: float | None = None + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> ParameterizedExperts: super().__init__() - self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) self.std = std + self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) + + self.bias = None + if add_bias: + self.bias = nn.Parameter(torch.empty(num_experts, out_features)) + + mark_parameter_as_no_weight_decay(self.bias) + self.num_experts = num_experts self.in_features = in_features self.out_features = out_features @@ -118,12 +129,16 @@ def extra_repr(self) -> str: @torch.no_grad() def reset_parameters(self) -> None: nn.init.normal_(self.weight, mean=0, std=self.std) + if self.bias is not None: + nn.init.zeros_(self.bias) + mark_parameter_as_initialized(self.weight) + 
mark_parameter_as_initialized(self.bias) class ColumnParallelExperts(ParameterizedExperts, DTensorModule): def __init__( - self, num_experts: int, in_features: int, out_features: int, std: float | None = None + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> ColumnParallelExperts: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() @@ -134,12 +149,18 @@ def __init__( ) super().__init__( - num_experts=num_experts, in_features=in_features, out_features=self.out_features_per_tp_rank, std=std + num_experts=num_experts, + in_features=in_features, + out_features=self.out_features_per_tp_rank, + add_bias=add_bias, + std=std, ) self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() if self.is_tp_enabled: + assert not add_bias + self.weight = nn.Parameter( tensor_to_dtensor( self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) @@ -194,23 +215,30 @@ def extra_repr(self) -> str: class RowParallelExperts(ColumnParallelExperts): def __init__( - self, num_experts: int, in_features: int, out_features: int, std: float | None = None + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> RowParallelExperts: tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.in_features_per_device = divide_if_divisible( + self.in_features_per_tp_rank = divide_if_divisible( in_features, tp_world_size, f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", ) ParameterizedExperts.__init__( - self, num_experts=num_experts, in_features=self.in_features_per_device, out_features=out_features, std=std + self, + num_experts=num_experts, + in_features=self.in_features_per_tp_rank, + out_features=out_features, + add_bias=add_bias, + std=std, ) self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() if self.is_tp_enabled: + assert not add_bias + self.weight = nn.Parameter( tensor_to_dtensor( self.weight, @@ -221,6 +249,11 @@ def __init__( self.reset_parameters() + def extra_repr(self) -> str: + return "num_experts={}, in_features_per_tp_rank={}, out_features={}".format( + self.num_experts, self.in_features_per_tp_rank, self.out_features + ) + class MoE(nn.Module): def __init__( @@ -270,6 +303,7 @@ def __init__( num_experts=num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, + add_bias=add_bias, std=std, ) @@ -289,7 +323,11 @@ def __init__( std /= math.sqrt(2 * num_layers) self.c_proj = RowParallelExperts( - num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, std=std + num_experts=num_experts, + in_features=self.intermediate_size, + out_features=self.hidden_size, + add_bias=add_bias, + std=std, ) if self.shared_intermediate_size is not None: diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 643364b0a..628a25c97 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -158,7 +158,6 @@ def get_sequence_mixer( if sequence_mixer_type == "softmax_attention": return Attention( **sequence_mixer_kwargs, - qkv_bias=block.qkv_bias, softmax_dropout=block.softmax_dropout, use_padding_free_transformer=use_padding_free_transformer, 
sequence_parallel=sequence_parallel, diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index 2035c2722..e0d38fafc 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -77,7 +77,6 @@ def __init__( sliding_window: int | None, position_embedding_type: str, add_bias: bool, - qkv_bias: bool, softmax_dropout: float, dropout: float, init_method: str, @@ -98,7 +97,6 @@ def __init__( self.global_num_heads = num_attention_heads self.global_num_key_value_heads = num_key_value_heads self.add_bias = add_bias - self.qkv_bias = qkv_bias self.sliding_window = sliding_window self.use_padding_free_transformer = use_padding_free_transformer diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 33427b9b3..0a816f0a0 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -59,7 +59,10 @@ def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: l def set_parameter_marker_maps( - model_container: list[nn.Module], marker_maps: list[dict], replacement_patterns: list[tuple[str]] = [] + model_container: list[nn.Module], + marker_maps: list[dict], + replacement_patterns: list[tuple[str]] = [], + _trim_prefix: str | None = None, ) -> None: if isinstance(model_container, nn.Module): model_container = [model_container] @@ -69,5 +72,8 @@ def set_parameter_marker_maps( for pattern, replacement in replacement_patterns: param_name = param_name.replace(pattern, replacement) + if _trim_prefix is not None: + param_name = param_name.removeprefix(_trim_prefix) + for marker, value in _marker_map[param_name].items(): setattr(parameter, marker, value) diff --git a/tests/hf_models/single_gpu/model_conversion_test.py b/tests/hf_models/single_gpu/model_conversion_test.py index 6943cf312..9793cb7e4 100644 --- a/tests/hf_models/single_gpu/model_conversion_test.py +++ b/tests/hf_models/single_gpu/model_conversion_test.py @@ -104,33 +104,3 @@ def test_granitemoehybrid_model_conversion(self, device: torch.device, is_moe: b compare_loss=False, logits_atol_float32=2.5e-5, ) - - @parameterized.expand(TestCommons.make_args_matrix(TestCommons.get_all_devices(), [False, True], [False, True])) - def test_qwen2_moe_model_conversion(self, device: torch.device, qkv_bias: bool, use_sliding_window: bool) -> None: - lm_engine_config = self.get_moe_test_config( - "rope", - qkv_bias=qkv_bias, - shared_n_inner=36, - activation_function="swiglu", - normalization_function="rmsnorm", - shared_expert_gating=True, - normalized_topk=False, - ) - - for layer_idx in [3, 6]: - mlp_block = lm_engine_config.mlp_blocks[layer_idx] - lm_engine_config.mlp_blocks[layer_idx] = _MLPArgs( - intermediate_size=mlp_block.intermediate_size, activation_function=mlp_block.activation_function - ) - - if use_sliding_window: - for layer_idx in range(3, lm_engine_config.num_layers): - lm_engine_config.sequence_mixer_blocks[layer_idx].sliding_window = 4096 - - self.model_conversion_test( - lm_engine_config=lm_engine_config, - model_type="qwen2_moe", - device=device, - exact_match=False, - weight_test_only=use_sliding_window, - ) diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index 719a8df92..7c16609b9 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -94,7 +94,6 @@ def get_moe_test_config( num_attention_heads: int = 4, 
shared_expert_gating: bool = False, normalized_topk: bool = True, - qkv_bias: bool = None, ) -> GPTBaseConfig: num_key_value_heads = 2 @@ -106,9 +105,6 @@ def get_moe_test_config( "attention_multiplier": attention_multiplier, } - if qkv_bias is not None: - sequence_mixer["qkv_bias"] = qkv_bias - return GPTBaseConfig( vocab_size=2048, max_position_embeddings=1024, diff --git a/tests/training/params_group/params_group_test.py b/tests/training/params_group/params_group_test.py index 20b3e6c4f..414e15e47 100644 --- a/tests/training/params_group/params_group_test.py +++ b/tests/training/params_group/params_group_test.py @@ -54,7 +54,7 @@ def test_mup_group( elif use_torch_compile: marker_maps = get_parameter_marker_maps(model_container) model_container = [torch.compile(model) for model in model_container] - set_parameter_marker_maps(model_container, marker_maps) + set_parameter_marker_maps(model_container, marker_maps, _trim_prefix="_orig_mod.") params_groups = get_param_groups_list(model_container, args.optimizer_args.class_args, params_group_method)[0] From b12b17afa3fc0a12cc057f829b86a56d06c52b27 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 17:21:23 -0800 Subject: [PATCH 97/99] fix tp Signed-off-by: Mayank Mishra --- lm_engine/hf_models/modeling_utils/linear/column.py | 5 ++--- lm_engine/hf_models/modeling_utils/linear/row.py | 5 ++--- lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py | 10 ++++------ .../modeling_utils/sequence_mixer_blocks/attention.py | 9 ++++++++- lm_engine/utils/parallel.py | 3 --- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py index 7af0dc8ff..72c54a68e 100644 --- a/lm_engine/hf_models/modeling_utils/linear/column.py +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -25,7 +25,8 @@ def __init__( use_padding_free_transformer: bool = False, sequence_parallel: bool = False, ) -> ColumnParallelLinear: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 self.out_features_per_tp_rank = divide_if_divisible( out_features, @@ -35,8 +36,6 @@ def __init__( super().__init__(in_features=in_features, out_features=self.out_features_per_tp_rank, bias=bias, std=std) - self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - if self.is_tp_enabled: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py index d7e80a94e..eb360156a 100644 --- a/lm_engine/hf_models/modeling_utils/linear/row.py +++ b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -25,7 +25,8 @@ def __init__( use_padding_free_transformer: bool = False, sequence_parallel: bool = False, ) -> RowParallelLinear: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 self.in_features_per_tp_rank = divide_if_divisible( in_features, @@ -35,8 +36,6 @@ def __init__( super().__init__(in_features=self.in_features_per_tp_rank, out_features=out_features, bias=bias, std=std) - 
self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - if self.is_tp_enabled: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index c5c8f9fb2..00df1fe77 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -140,7 +140,8 @@ class ColumnParallelExperts(ParameterizedExperts, DTensorModule): def __init__( self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> ColumnParallelExperts: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 self.out_features_per_tp_rank = divide_if_divisible( out_features, @@ -156,8 +157,6 @@ def __init__( std=std, ) - self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - if self.is_tp_enabled: assert not add_bias @@ -217,7 +216,8 @@ class RowParallelExperts(ColumnParallelExperts): def __init__( self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> RowParallelExperts: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 self.in_features_per_tp_rank = divide_if_divisible( in_features, @@ -234,8 +234,6 @@ def __init__( std=std, ) - self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - if self.is_tp_enabled: assert not add_bias diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index e0d38fafc..832387ce6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -90,7 +90,14 @@ def __init__( ) -> Attention: super().__init__() - self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() + if ProcessGroupManager.is_initialized(): + self.tp_world_size = ( + ProcessGroupManager.get_tensor_parallel_world_size() + if ProcessGroupManager.is_tensor_parallel_enabled() + else 1 + ) + else: + self.tp_world_size = 1 self.causal = causal self.global_hidden_size = hidden_size diff --git a/lm_engine/utils/parallel.py b/lm_engine/utils/parallel.py index 904b6d216..689710b4b 100644 --- a/lm_engine/utils/parallel.py +++ b/lm_engine/utils/parallel.py @@ -216,9 +216,6 @@ def set_dummy_tensor_parallel_rank(rank: int): @staticmethod def get_tensor_parallel_world_size() -> int: - if not ProcessGroupManager.is_tensor_parallel_enabled(): - return 1 - global _TENSOR_PARALLEL_WORLD_SIZE if _TENSOR_PARALLEL_WORLD_SIZE is None: From 4668bf098b82125ff69a76fd61903962ed7b6427 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 17:35:50 -0800 Subject: [PATCH 98/99] fix tp Signed-off-by: Mayank Mishra --- lm_engine/dtensors.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py index ff57a4406..4fe9a3a76 100644 --- a/lm_engine/dtensors.py +++ b/lm_engine/dtensors.py @@ -8,7 
+8,11 @@ from torch.distributed._tensor.placement_types import Placement from torch.distributed.device_mesh import DeviceMesh -from .hf_models.parameter import _ALL_MARKERS + +def _get_all_markers(): + from .hf_models.parameter import _ALL_MARKERS + + return _ALL_MARKERS def tensor_to_dtensor( @@ -34,7 +38,7 @@ def tensor_to_dtensor( dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True) if copy_marker: - for marker in _ALL_MARKERS: + for marker in _get_all_markers(): marker_value = getattr(dtensor, marker, None) if marker_value is not None: setattr(dtensor, marker, marker_value) @@ -66,7 +70,7 @@ def dtensor_to_tensor( tensor = dtensor.to_local(grad_placements=grad_placement) if copy_marker: - for marker in _ALL_MARKERS: + for marker in _get_all_markers(): marker_value = getattr(tensor, marker, None) if marker_value is not None: setattr(tensor, marker, marker_value) From 9f9eddcc9f4d90a356a76d8fa969be6e83aef9b4 Mon Sep 17 00:00:00 2001 From: Mayank Mishra Date: Mon, 9 Feb 2026 18:05:22 -0800 Subject: [PATCH 99/99] fix tp Signed-off-by: Mayank Mishra --- tests/hf_models/multi_gpu/unsharding/unsharding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding.py b/tests/hf_models/multi_gpu/unsharding/unsharding.py index 2e927406f..18244ceed 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding.py @@ -60,7 +60,9 @@ if is_tp_first_rank: - model = TestCommons.from_config(None, config) + with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): + model = TestCommons.from_config(None, config) + model.save_pretrained(args.tmp_path, safe_serialization=True) Communication.barrier()
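
For readers following the marker-map changes above (patch 96 moves get_parameter_marker_maps / set_parameter_marker_maps into lm_engine.hf_models, and params_group_test.py now passes _trim_prefix="_orig_mod."): torch.compile wraps a module so that named_parameters() yields names prefixed with "_orig_mod.", which would otherwise no longer match a marker map captured before compilation. The sketch below shows that round-trip in isolation. It is illustrative only, not part of the series: ToyModel is a hypothetical stand-in for a model produced by get_model_container, and it assumes the two functions behave as defined in lm_engine/hf_models/parameter.py in this patch set.

import torch
import torch.nn as nn

from lm_engine.hf_models import get_parameter_marker_maps, set_parameter_marker_maps
from lm_engine.hf_models.parameter import mark_parameter_as_no_weight_decay


class ToyModel(nn.Module):
    # hypothetical stand-in for a model produced by get_model_container
    def __init__(self) -> None:
        super().__init__()
        self.wte = nn.Embedding(8, 4)
        self.lm_head = nn.Linear(4, 8, bias=False)


model_container = [ToyModel()]
mark_parameter_as_no_weight_decay(model_container[0].wte.weight)

# capture markers (e.g. _no_weight_decay, _has_mup_learning_rate) before compiling
marker_maps = get_parameter_marker_maps(model_container)

# torch.compile re-registers every parameter under an "_orig_mod." prefix
model_container = [torch.compile(model) for model in model_container]

# re-attach the markers; _trim_prefix strips "_orig_mod." so the saved names line up again
set_parameter_marker_maps(model_container, marker_maps, _trim_prefix="_orig_mod.")

for name, param in model_container[0].named_parameters():
    print(name, getattr(param, "_no_weight_decay", False))

This mirrors the updated test: markers are read off the eager model, the container is compiled, and the markers are restored on the compiled parameters so get_param_groups_list still sees them.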