Merged
Changes from all commits
103 commits
4724170
add
mayank31398 Feb 3, 2026
f49b5b1
drop post_init()
mayank31398 Feb 5, 2026
d4f1451
add check for init
mayank31398 Feb 5, 2026
ba585b7
add markers
mayank31398 Feb 5, 2026
a0f677c
add markers
mayank31398 Feb 5, 2026
58b9ac0
add markers
mayank31398 Feb 5, 2026
6e6e074
add markers
mayank31398 Feb 5, 2026
d775fc4
add markers
mayank31398 Feb 5, 2026
81a10f1
add markers
mayank31398 Feb 5, 2026
efde5ba
add markers
mayank31398 Feb 5, 2026
5f94a84
add markers
mayank31398 Feb 5, 2026
41ba4e9
add markers
mayank31398 Feb 5, 2026
c15a72f
add markers
mayank31398 Feb 5, 2026
b2b541e
add markers
mayank31398 Feb 5, 2026
ebb6143
add markers
mayank31398 Feb 5, 2026
babcaa4
add markers
mayank31398 Feb 5, 2026
dc4ef01
add markers
mayank31398 Feb 5, 2026
c2860ad
Merge branch 'main' into b
mayank31398 Feb 5, 2026
86db0ea
add GDN efficient init
mayank31398 Feb 8, 2026
db411cf
fix mamba2 init
mayank31398 Feb 8, 2026
06ce95f
pass args
mayank31398 Feb 8, 2026
3a21404
hidden_states -> x
mayank31398 Feb 9, 2026
b6dd81b
hidden_states -> x
mayank31398 Feb 9, 2026
8dc2c79
hidden_states -> x
mayank31398 Feb 9, 2026
3bec1f0
hidden_states -> x
mayank31398 Feb 9, 2026
52b14a3
use gate for GDN
mayank31398 Feb 9, 2026
b062d86
use gate for mamba2
mayank31398 Feb 9, 2026
87e7a44
use gate for mamba2
mayank31398 Feb 9, 2026
d1050e5
fix tests
mayank31398 Feb 9, 2026
750f9c0
fix tests
mayank31398 Feb 9, 2026
8e5903e
fix tests
mayank31398 Feb 9, 2026
cc76e24
fix tests
mayank31398 Feb 9, 2026
60e092a
fix tests
mayank31398 Feb 9, 2026
0685373
merge
mayank31398 Feb 9, 2026
740a636
merge
mayank31398 Feb 9, 2026
43e099c
merge
mayank31398 Feb 9, 2026
d1fd4fb
merge
mayank31398 Feb 9, 2026
913e08d
merge
mayank31398 Feb 9, 2026
037185f
merge
mayank31398 Feb 9, 2026
f031286
merge
mayank31398 Feb 9, 2026
a73e3dd
merge
mayank31398 Feb 9, 2026
d7c6eff
merge
mayank31398 Feb 9, 2026
636ce91
merge
mayank31398 Feb 9, 2026
95cf3f9
merge
mayank31398 Feb 9, 2026
5aa75c1
merge
mayank31398 Feb 9, 2026
c2a7675
merge
mayank31398 Feb 9, 2026
dcb892f
merge
mayank31398 Feb 9, 2026
9898645
merge
mayank31398 Feb 9, 2026
193c11f
merge
mayank31398 Feb 9, 2026
843f57d
count correctly
mayank31398 Feb 9, 2026
40efdca
count correctly
mayank31398 Feb 9, 2026
b049fde
count correctly
mayank31398 Feb 9, 2026
137ca9d
count correctly
mayank31398 Feb 9, 2026
79a69dd
norm
mayank31398 Feb 9, 2026
554ebe3
norm
mayank31398 Feb 9, 2026
60eadaa
norm
mayank31398 Feb 9, 2026
6394358
norm
mayank31398 Feb 9, 2026
0e29579
norm
mayank31398 Feb 9, 2026
4420224
norm
mayank31398 Feb 9, 2026
462acb1
norm
mayank31398 Feb 9, 2026
ec473cb
norm
mayank31398 Feb 9, 2026
4a1710c
norm
mayank31398 Feb 9, 2026
3297ed3
norm
mayank31398 Feb 9, 2026
93953c6
norm
mayank31398 Feb 9, 2026
278c7f6
norm
mayank31398 Feb 9, 2026
ed561e8
norm
mayank31398 Feb 9, 2026
57e7a87
norm
mayank31398 Feb 9, 2026
f4ccfea
norm
mayank31398 Feb 9, 2026
fd0b6b6
norm
mayank31398 Feb 9, 2026
6772ab0
norm
mayank31398 Feb 9, 2026
d19d927
norm
mayank31398 Feb 9, 2026
eb22d4c
norm
mayank31398 Feb 9, 2026
402b23c
fix linter
mayank31398 Feb 9, 2026
e14274b
fix linter
mayank31398 Feb 9, 2026
e7393d5
fix linter
mayank31398 Feb 9, 2026
ceb50c3
fix linter
mayank31398 Feb 9, 2026
de3d44f
fix linter
mayank31398 Feb 9, 2026
6fe2d8b
fix linter
mayank31398 Feb 9, 2026
a1bc719
fix linter
mayank31398 Feb 9, 2026
422daf5
fix linter
mayank31398 Feb 9, 2026
7f94234
fix linter
mayank31398 Feb 9, 2026
68b9df9
fix linter
mayank31398 Feb 9, 2026
323ad5a
fix linter
mayank31398 Feb 9, 2026
a9728af
fix linter
mayank31398 Feb 9, 2026
572d5ba
fix linter
mayank31398 Feb 9, 2026
48f288f
fix linter
mayank31398 Feb 9, 2026
0af2bda
fix linter
mayank31398 Feb 9, 2026
68a2b0c
fix linter
mayank31398 Feb 9, 2026
62d6526
fix linter
mayank31398 Feb 9, 2026
527dfee
fix linter
mayank31398 Feb 9, 2026
e2ec3b3
fix linter
mayank31398 Feb 9, 2026
4abc04b
fix linter
mayank31398 Feb 9, 2026
f931521
fix linter
mayank31398 Feb 9, 2026
a47b79b
fix linter
mayank31398 Feb 9, 2026
40eb0ad
fix linter
mayank31398 Feb 9, 2026
98fc18a
fix linter
mayank31398 Feb 9, 2026
a7786aa
fix linter
mayank31398 Feb 9, 2026
7c056dc
fix linter
mayank31398 Feb 9, 2026
f6b0910
fix linter
mayank31398 Feb 9, 2026
7b7261b
fix linter
mayank31398 Feb 10, 2026
b12b17a
fix tp
mayank31398 Feb 10, 2026
4668bf0
fix tp
mayank31398 Feb 10, 2026
9f9eddc
fix tp
mayank31398 Feb 10, 2026
1 change: 1 addition & 0 deletions lm_engine/arguments.py
@@ -57,6 +57,7 @@ def model_post_init(self, __context: Any) -> None:
if self.model_name is None:
_check_not_None([(self.pretrained_config, "pretrained_config")])
else:
assert not self.efficient_initialization, "efficient_initialization is not supported with HF models"
assert self.pretrained_config is None, "pretrained_config shouldn't be specified with model_name"


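For context, the guard added in this hunk can be reproduced with a toy pydantic model; everything below except the two assertions shown in the diff is an illustrative assumption, not lm_engine's real TrainingArgs:

from typing import Any

from pydantic import BaseModel


class _ToyModelArgs(BaseModel):
    model_name: str | None = None
    pretrained_config: dict | None = None
    efficient_initialization: bool = False

    def model_post_init(self, __context: Any) -> None:
        if self.model_name is None:
            assert self.pretrained_config is not None
        else:
            # mirrors the new check: loading a model by name now rules out efficient initialization
            assert not self.efficient_initialization, "efficient_initialization is not supported with HF models"
            assert self.pretrained_config is None, "pretrained_config shouldn't be specified with model_name"


_ToyModelArgs(model_name="some-org/some-hf-model")  # accepted
# _ToyModelArgs(model_name="some-org/some-hf-model", efficient_initialization=True)  # now raises AssertionError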
54 changes: 20 additions & 34 deletions lm_engine/distributed.py
@@ -30,7 +30,13 @@
from .containers import ModelContainer
from .enums import Kernel
from .gradient_checkpointing import apply_gradient_checkpointing
from .hf_models import CausalLMOutputWithPast, is_parameter_initialized
from .hf_models import (
_INIT_MARKER,
CausalLMOutputWithPast,
get_parameter_marker_maps,
is_parameter_initialized,
set_parameter_marker_maps,
)
from .kernels import is_kernel_allowed
from .utils import (
Accelerator,
@@ -119,36 +125,6 @@ def _get_fsdp_mixed_precision(
return mixed_precision


def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]:
marker_maps = []
for model in model_container:
marker_maps.append({})
for param_name, param in model.named_parameters():
marker_maps[-1][param_name] = {}
for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers:
marker_maps[-1][param_name][marker] = getattr(param, marker, False)

return marker_maps


def _set_parameter_marker_maps(model_container: ModelContainer, marker_maps: list[dict]) -> None:
for model, _marker_map in zip(model_container, marker_maps):
for param_name, parameter in model.named_parameters():
# handle FSDP for TPU
param_name = param_name.replace(_FSDP_TPU_SHARD_SEPARATOR, ".")
param_name = param_name.replace(f"{_FSDP_TPU_SHARD}.", "")
param_name = param_name.replace(f"{_FSDP_TPU_FPW}.", "")

# handle FSDP-1
param_name = param_name.replace(f"{_FSDP_1_STRING}.", "")

# handle torch compile
param_name = param_name.replace(f"{_TORCH_COMPILE_STRING}.", "")

for marker, value in _marker_map[param_name].items():
setattr(parameter, marker, value)


def wrap_model_container_for_distributed_training(
args: TrainingArgs, model_container: ModelContainer
) -> tuple[ModelContainer, _PipelineSchedule]:
@@ -229,9 +205,9 @@ def wrap_model_container_for_distributed_training(
for param_name, parameter in model.named_buffers():
parameter._is_initialized = False

marker_maps = _get_parameter_marker_maps(model_container)
marker_maps = get_parameter_marker_maps(model_container)
else:
marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"])
marker_maps = get_parameter_marker_maps(model_container, extra_markers=[_INIT_MARKER])

accelerator = Accelerator.get_accelerator()

@@ -387,7 +363,17 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
for i, model in enumerate(model_container):
model_container[i] = torch.compile(model)

_set_parameter_marker_maps(model_container, marker_maps)
set_parameter_marker_maps(
model_container,
marker_maps,
replacement_patterns=[
(_FSDP_TPU_SHARD_SEPARATOR, "."),
(f"{_FSDP_TPU_SHARD}.", ""),
(f"{_FSDP_TPU_FPW}.", ""),
(f"{_FSDP_1_STRING}.", ""),
(f"{_TORCH_COMPILE_STRING}.", ""),
],
)

pipeline_stages = []
pipeline_schedule = None
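The two module-level helpers deleted above now live in lm_engine/hf_models/parameter.py and are imported at the top of this file. A sketch of what the relocated, generalized versions plausibly look like, reconstructed from the deleted code and the new call sites (the real implementation and signatures may differ):

_INIT_MARKER = "_is_initialized"  # assumed value of the constant
_DEFAULT_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"]


def get_parameter_marker_maps(model_container, extra_markers: list[str] | None = None) -> list[dict]:
    # snapshot the marker attributes of every parameter before wrappers rename them
    extra_markers = [] if extra_markers is None else extra_markers
    marker_maps = []
    for model in model_container:
        marker_map = {}
        for param_name, param in model.named_parameters():
            marker_map[param_name] = {
                marker: getattr(param, marker, False) for marker in _DEFAULT_MARKERS + extra_markers
            }
        marker_maps.append(marker_map)
    return marker_maps


def set_parameter_marker_maps(
    model_container,
    marker_maps: list[dict],
    replacement_patterns: list[tuple[str, str]] | None = None,
) -> None:
    # restore the snapshot after wrapping; replacement_patterns strips the name prefixes
    # added by FSDP / FSDP-on-TPU / torch.compile so the lookup keys match again
    for model, marker_map in zip(model_container, marker_maps):
        for param_name, parameter in model.named_parameters():
            for pattern, replacement in replacement_patterns or []:
                param_name = param_name.replace(pattern, replacement)
            for marker, value in marker_map[param_name].items():
                setattr(parameter, marker, value)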
20 changes: 20 additions & 0 deletions lm_engine/dtensors.py
@@ -9,12 +9,19 @@
from torch.distributed.device_mesh import DeviceMesh


def _get_all_markers():
from .hf_models.parameter import _ALL_MARKERS

return _ALL_MARKERS


def tensor_to_dtensor(
tensor: torch.Tensor,
device_mesh: DeviceMesh,
current_placement: Placement | list[Placement],
desired_placement: Placement | list[Placement] | None = None,
run_check: bool = False,
copy_marker: bool = True,
) -> DTensor:
if isinstance(tensor, DTensor):
return tensor
@@ -30,6 +37,12 @@ def tensor_to_dtensor(

dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True)

if copy_marker:
# carry lm_engine parameter markers from the input tensor over to the new DTensor
for marker in _get_all_markers():
marker_value = getattr(tensor, marker, None)
if marker_value is not None:
setattr(dtensor, marker, marker_value)

return dtensor


@@ -38,6 +51,7 @@ def dtensor_to_tensor(
device_mesh: DeviceMesh | None = None,
desired_placement: Placement | list[Placement] | None = None,
grad_placement: Placement | list[Placement] | None = None,
copy_marker: bool = True,
) -> torch.Tensor:
if not isinstance(dtensor, DTensor):
return dtensor
@@ -55,6 +69,12 @@ def dtensor_to_tensor(

tensor = dtensor.to_local(grad_placements=grad_placement)

if copy_marker:
# carry lm_engine parameter markers from the source DTensor over to the local tensor
for marker in _get_all_markers():
marker_value = getattr(dtensor, marker, None)
if marker_value is not None:
setattr(tensor, marker, marker_value)

return tensor


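Assuming copy_marker is meant to carry lm_engine's parameter markers from the source tensor to the converted one, a round trip looks roughly like the sketch below (mesh is an already-built 1-D DeviceMesh; the Replicate import path differs in older torch releases):

import torch
from torch.distributed.tensor.placement_types import Replicate

from lm_engine.dtensors import dtensor_to_tensor, tensor_to_dtensor


def demo_marker_round_trip(mesh) -> None:
    weight = torch.randn(16, 16)
    weight._no_weight_decay = True  # in practice set via mark_parameter_as_no_weight_decay

    dweight = tensor_to_dtensor(weight, device_mesh=mesh, current_placement=Replicate())
    assert getattr(dweight, "_no_weight_decay", False)  # preserved because copy_marker defaults to True

    local = dtensor_to_tensor(dweight)
    assert getattr(local, "_no_weight_decay", False)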
3 changes: 3 additions & 0 deletions lm_engine/hf_models/__init__.py
@@ -23,12 +23,15 @@
PaLMModel,
)
from .parameter import (
_INIT_MARKER,
get_parameter_marker_maps,
is_parameter_initialized,
is_parameter_with_mup_learning_rate,
is_parameter_with_no_weight_decay,
mark_parameter_as_initialized,
mark_parameter_as_mup_learning_rate,
mark_parameter_as_no_weight_decay,
set_parameter_marker_maps,
)
from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes
from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
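The marker helpers re-exported here are what distributed.py snapshots and restores around model wrapping; an illustrative call pattern, with single-parameter signatures inferred from the names (an assumption, not a documented API):

import torch.nn as nn

linear = nn.Linear(8, 8)

mark_parameter_as_no_weight_decay(linear.bias)
mark_parameter_as_mup_learning_rate(linear.weight)
mark_parameter_as_initialized(linear.weight)

assert is_parameter_with_no_weight_decay(linear.bias)
assert is_parameter_with_mup_learning_rate(linear.weight)
assert is_parameter_initialized(linear.weight)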
28 changes: 0 additions & 28 deletions lm_engine/hf_models/config/sequence_mixer.py
@@ -16,39 +16,11 @@ class _SoftmaxAttentionArgs(BaseArgs):
add_bias: bool = False
attention_multiplier: float | None = None
sliding_window: int | None = None
# needed for Qwen 2 MoE
qkv_bias: bool = None

def model_post_init(self, __context: Any) -> None:
if self.qkv_bias is None:
self.qkv_bias = self.add_bias

assert self.sequence_mixer_type == "softmax_attention"


class _MultiHeadLatentAttentionArgs(BaseArgs):
sequence_mixer_type: str = "multihead_latent_attention"
num_attention_heads: int | None = None
softmax_dropout: float = 0
dropout: float = 0
add_bias: bool = False
attention_multiplier: float | None = None
sliding_window: int | None = None
query_compression_size: int | None = None
key_value_compression_size: int | None = None
num_attention_heads: int | None = None
head_dim: int | None = None
normalization_function: str = "layernorm"

def model_post_init(self, __context: Any) -> None:
assert self.sequence_mixer_type == "multihead_latent_attention"
assert self.num_attention_heads is not None
assert self.query_compression_size is not None
assert self.key_value_compression_size is not None
assert self.num_attention_heads is not None
assert self.head_dim is not None


class _SoftPlusDecayArgs(BaseArgs):
A_init_min: float = 0
A_init_max: float = 16
2 changes: 1 addition & 1 deletion lm_engine/hf_models/mixins/__init__.py
@@ -3,7 +3,7 @@
# **************************************************

from .dense import BaseModelMixin, Block, CausalLMModelMixin, PreTrainedModelMixin
from .dense_TP import BaseModelMixin_TP, Block_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP
from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP
from .modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
34 changes: 31 additions & 3 deletions lm_engine/hf_models/mixins/dense/base.py
@@ -30,6 +30,14 @@ class PreTrainedModelMixin(PreTrainedModel):
def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin:
super().__init__(config, *args, **kwargs)

self.sequence_parallel = kwargs.get("sequence_parallel", False)
self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1)
self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0)

self.is_first_stage = self.pipeline_stage_id == 0
self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1
self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1

assert self.config_class is not None
self.generation_config = GenerationConfig.from_model_config(self.config)

@@ -38,6 +46,9 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin:

self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks])

if self.is_pipeline_parallel_enabled and self._tied_word_embeddings:
raise NotImplementedError()

# FIXME typing
def prepare_inputs_for_model(
self,
@@ -96,12 +107,23 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
config.sequence_mixer_blocks[i].sequence_mixer_type for i in range(config.num_layers)
]

self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range)
self.wte = ParameterizedEmbedding(
config.vocab_size,
self.embed_dim,
std=self.initializer_range,
use_padding_free_transformer=self.use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)

self.embedding_dropout = Dropout(config.embedding_dropout)
self.h = nn.ModuleList(
[
self.layer_class(config, use_padding_free_transformer=self.use_padding_free_transformer, layer_idx=i)
self.layer_class(
config,
use_padding_free_transformer=self.use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
layer_idx=i,
)
for i in range(config.num_layers)
]
)
@@ -312,7 +334,13 @@ def _setup_positional_encoding(self) -> None:
max_position_embeddings = self.config.max_position_embeddings

if self.position_embedding_type == "learned_absolute":
self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range)
self.wpe = ParameterizedEmbedding(
max_position_embeddings,
self.embed_dim,
std=self.initializer_range,
use_padding_free_transformer=self.use_padding_free_transformer,
sequence_parallel=False,
)
elif self.position_embedding_type == "rope":
if self.config.rope_scaling is None:
self.rope = RoPE(
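The pipeline-stage bookkeeping added to PreTrainedModelMixin.__init__ is small enough to restate as a self-contained toy; the _StageFlags class below only mirrors the derivation shown in the diff and is not part of lm_engine:

class _StageFlags:
    def __init__(self, **kwargs) -> None:
        self.sequence_parallel = kwargs.get("sequence_parallel", False)
        self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1)
        self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0)

        self.is_first_stage = self.pipeline_stage_id == 0
        self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1
        self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1


flags = _StageFlags(num_pipeline_stages=4, pipeline_stage_id=2, sequence_parallel=True)
assert not flags.is_first_stage and not flags.is_last_stage and flags.is_pipeline_parallel_enabled
# with pipeline parallelism enabled, tied word embeddings now raise NotImplementedError in the mixin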
31 changes: 26 additions & 5 deletions lm_engine/hf_models/mixins/dense/layer.py
@@ -14,7 +14,11 @@

class Block(nn.Module):
def __init__(
self, config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int | None = None
self,
config: CommonConfig,
use_padding_free_transformer: bool,
layer_idx: int,
sequence_parallel: bool,
) -> Block:
super().__init__()

@@ -23,14 +27,31 @@ def __init__(
self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type

self.ln_1 = get_normalization_function(
config.normalization_function, hidden_size, eps=config.layer_norm_epsilon
config.normalization_function,
hidden_size,
eps=config.layer_norm_epsilon,
use_padding_free_transformer=use_padding_free_transformer,
sequence_parallel=sequence_parallel,
)
self.sequence_mixer = get_sequence_mixer(
config,
True,
use_padding_free_transformer=use_padding_free_transformer,
sequence_parallel=sequence_parallel,
layer_idx=layer_idx,
)
self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx)
self.ln_2 = get_normalization_function(
config.normalization_function, hidden_size, eps=config.layer_norm_epsilon
config.normalization_function,
hidden_size,
eps=config.layer_norm_epsilon,
use_padding_free_transformer=use_padding_free_transformer,
sequence_parallel=sequence_parallel,
)
self.mlp_block = get_mlp_block(
config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx
config,
use_padding_free_transformer=use_padding_free_transformer,
sequence_parallel=sequence_parallel,
layer_idx=layer_idx,
)

def forward(
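With the new signature, layer_idx and sequence_parallel are required, and both are threaded (along with use_padding_free_transformer) into the norms, the sequence mixer, and the MLP block; the call shape is now roughly the following, where config stands for a CommonConfig instance:

block = Block(
    config,
    use_padding_free_transformer=False,
    layer_idx=0,
    sequence_parallel=False,
)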