diff --git a/lm_engine/arguments.py b/lm_engine/arguments.py
index 694332d6b..cbe885e41 100644
--- a/lm_engine/arguments.py
+++ b/lm_engine/arguments.py
@@ -57,6 +57,7 @@ def model_post_init(self, __context: Any) -> None:
         if self.model_name is None:
             _check_not_None([(self.pretrained_config, "pretrained_config")])
         else:
+            assert not self.efficient_initialization, "efficient_initialization is not supported with HF models"
             assert self.pretrained_config is None, "pretrained_config shouldn't be specified with model_name"
 
 
diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py
index 8fbd19d9c..5e162231f 100644
--- a/lm_engine/distributed.py
+++ b/lm_engine/distributed.py
@@ -30,7 +30,13 @@
 from .containers import ModelContainer
 from .enums import Kernel
 from .gradient_checkpointing import apply_gradient_checkpointing
-from .hf_models import CausalLMOutputWithPast, is_parameter_initialized
+from .hf_models import (
+    _INIT_MARKER,
+    CausalLMOutputWithPast,
+    get_parameter_marker_maps,
+    is_parameter_initialized,
+    set_parameter_marker_maps,
+)
 from .kernels import is_kernel_allowed
 from .utils import (
     Accelerator,
@@ -119,36 +125,6 @@ def _get_fsdp_mixed_precision(
     return mixed_precision
 
 
-def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]:
-    marker_maps = []
-    for model in model_container:
-        marker_maps.append({})
-        for param_name, param in model.named_parameters():
-            marker_maps[-1][param_name] = {}
-            for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers:
-                marker_maps[-1][param_name][marker] = getattr(param, marker, False)
-
-    return marker_maps
-
-
-def _set_parameter_marker_maps(model_container: ModelContainer, marker_maps: list[dict]) -> None:
-    for model, _marker_map in zip(model_container, marker_maps):
-        for param_name, parameter in model.named_parameters():
-            # handle FSDP for TPU
-            param_name = param_name.replace(_FSDP_TPU_SHARD_SEPARATOR, ".")
-            param_name = param_name.replace(f"{_FSDP_TPU_SHARD}.", "")
-            param_name = param_name.replace(f"{_FSDP_TPU_FPW}.", "")
-
-            # handle FSDP-1
-            param_name = param_name.replace(f"{_FSDP_1_STRING}.", "")
-
-            # handle torch compile
-            param_name = param_name.replace(f"{_TORCH_COMPILE_STRING}.", "")
-
-            for marker, value in _marker_map[param_name].items():
-                setattr(parameter, marker, value)
-
-
 def wrap_model_container_for_distributed_training(
     args: TrainingArgs, model_container: ModelContainer
 ) -> tuple[ModelContainer, _PipelineSchedule]:
@@ -229,9 +205,9 @@ def wrap_model_container_for_distributed_training(
             for param_name, parameter in model.named_buffers():
                 parameter._is_initialized = False
 
-        marker_maps = _get_parameter_marker_maps(model_container)
+        marker_maps = get_parameter_marker_maps(model_container)
     else:
-        marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"])
+        marker_maps = get_parameter_marker_maps(model_container, extra_markers=[_INIT_MARKER])
 
     accelerator = Accelerator.get_accelerator()
 
@@ -387,7 +363,17 @@ def _sharding_function(parameter: nn.Parameter) -> Shard:
         for i, model in enumerate(model_container):
             model_container[i] = torch.compile(model)
 
-    _set_parameter_marker_maps(model_container, marker_maps)
+    set_parameter_marker_maps(
+        model_container,
+        marker_maps,
+        replacement_patterns=[
+            (_FSDP_TPU_SHARD_SEPARATOR, "."),
+            (f"{_FSDP_TPU_SHARD}.", ""),
+            (f"{_FSDP_TPU_FPW}.", ""),
+            (f"{_FSDP_1_STRING}.", ""),
+            (f"{_TORCH_COMPILE_STRING}.", ""),
+        ],
+    )
 
     pipeline_stages = []
     pipeline_schedule = None
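The distributed.py hunks above swap the private _get/_set_parameter_marker_maps helpers for get_parameter_marker_maps / set_parameter_marker_maps exported from lm_engine.hf_models, with the hard-coded name rewrites now passed explicitly as replacement_patterns. Below is a minimal sketch of the intended snapshot/restore round-trip, assuming only the signatures visible in this patch; the plain list of modules standing in for a ModelContainer and the "_orig_mod." torch.compile prefix are illustrative assumptions (the real call site passes _TORCH_COMPILE_STRING and the FSDP prefixes).

import torch
import torch.nn as nn

from lm_engine.hf_models import _INIT_MARKER, get_parameter_marker_maps, set_parameter_marker_maps

# Toy stand-in for a ModelContainer (assumption: the helpers only need an iterable
# of modules exposing named_parameters(), as the deleted implementation did).
models = [nn.Linear(8, 8), nn.Linear(8, 8)]

# Snapshot per-parameter markers (weight decay, muP learning rate and, via
# extra_markers, the initialization marker) before wrapping: FSDP-style wrappers
# re-create parameters and torch.compile prefixes their names, so the marker
# attributes would otherwise be lost or no longer match.
marker_maps = get_parameter_marker_maps(models, extra_markers=[_INIT_MARKER])

# A wrapping step that prefixes parameter names, e.g. torch.compile.
for i, model in enumerate(models):
    models[i] = torch.compile(model)

# Restore the markers onto the wrapped modules' parameters, stripping the wrapper
# prefix so the saved names line up again ("_orig_mod." is assumed here only for
# illustration).
set_parameter_marker_maps(models, marker_maps, replacement_patterns=[("_orig_mod.", "")])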
diff --git a/lm_engine/dtensors.py b/lm_engine/dtensors.py
index 7944c77d1..4fe9a3a76 100644
--- a/lm_engine/dtensors.py
+++ b/lm_engine/dtensors.py
@@ -9,12 +9,19 @@
 from torch.distributed.device_mesh import DeviceMesh
 
 
+def _get_all_markers():
+    from .hf_models.parameter import _ALL_MARKERS
+
+    return _ALL_MARKERS
+
+
 def tensor_to_dtensor(
     tensor: torch.Tensor,
     device_mesh: DeviceMesh,
     current_placement: Placement | list[Placement],
     desired_placement: Placement | list[Placement] | None = None,
     run_check: bool = False,
+    copy_marker: bool = True,
 ) -> DTensor:
     if isinstance(tensor, DTensor):
         return tensor
@@ -30,6 +37,12 @@ def tensor_to_dtensor(
 
     dtensor = dtensor.redistribute(device_mesh=device_mesh, placements=desired_placement, async_op=True)
 
+    if copy_marker:
+        for marker in _get_all_markers():
+            marker_value = getattr(dtensor, marker, None)
+            if marker_value is not None:
+                setattr(dtensor, marker, marker_value)
+
     return dtensor
 
 
@@ -38,6 +51,7 @@ def dtensor_to_tensor(
     device_mesh: DeviceMesh | None = None,
     desired_placement: Placement | list[Placement] | None = None,
     grad_placement: Placement | list[Placement] | None = None,
+    copy_marker: bool = True,
 ) -> torch.Tensor:
     if not isinstance(dtensor, DTensor):
         return dtensor
@@ -55,6 +69,12 @@ def dtensor_to_tensor(
 
     tensor = dtensor.to_local(grad_placements=grad_placement)
 
+    if copy_marker:
+        for marker in _get_all_markers():
+            marker_value = getattr(tensor, marker, None)
+            if marker_value is not None:
+                setattr(tensor, marker, marker_value)
+
     return tensor
 
 
diff --git a/lm_engine/hf_models/__init__.py b/lm_engine/hf_models/__init__.py
index 75341b75b..cd306993e 100644
--- a/lm_engine/hf_models/__init__.py
+++ b/lm_engine/hf_models/__init__.py
@@ -23,12 +23,15 @@
     PaLMModel,
 )
 from .parameter import (
+    _INIT_MARKER,
+    get_parameter_marker_maps,
     is_parameter_initialized,
     is_parameter_with_mup_learning_rate,
     is_parameter_with_no_weight_decay,
     mark_parameter_as_initialized,
     mark_parameter_as_mup_learning_rate,
     mark_parameter_as_no_weight_decay,
+    set_parameter_marker_maps,
 )
 from .register_hf import get_model_parallel_class, is_custom_model, register_model_classes
 from .unshard import fix_unsharded_state_dict, unshard_tensor_parallel_state_dicts
diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py
index 5b22610dc..a2eec5652 100644
--- a/lm_engine/hf_models/config/sequence_mixer.py
+++ b/lm_engine/hf_models/config/sequence_mixer.py
@@ -16,39 +16,11 @@ class _SoftmaxAttentionArgs(BaseArgs):
     add_bias: bool = False
     attention_multiplier: float | None = None
     sliding_window: int | None = None
-    # needed for Qwen 2 MoE
-    qkv_bias: bool = None
 
     def model_post_init(self, __context: Any) -> None:
-        if self.qkv_bias is None:
-            self.qkv_bias = self.add_bias
-
         assert self.sequence_mixer_type == "softmax_attention"
 
 
-class _MultiHeadLatentAttentionArgs(BaseArgs):
-    sequence_mixer_type: str = "multihead_latent_attention"
-    num_attention_heads: int | None = None
-    softmax_dropout: float = 0
-    dropout: float = 0
-    add_bias: bool = False
-    attention_multiplier: float | None = None
-    sliding_window: int | None = None
-    query_compression_size: int | None = None
-    key_value_compression_size: int | None = None
-    num_attention_heads: int | None = None
-    head_dim: int | None = None
-    normalization_function: str = "layernorm"
-
-    def model_post_init(self, __context: Any) -> None:
-        assert self.sequence_mixer_type == "multihead_latent_attention"
-        assert self.num_attention_heads is not None
-        assert 
self.query_compression_size is not None - assert self.key_value_compression_size is not None - assert self.num_attention_heads is not None - assert self.head_dim is not None - - class _SoftPlusDecayArgs(BaseArgs): A_init_min: float = 0 A_init_max: float = 16 diff --git a/lm_engine/hf_models/mixins/__init__.py b/lm_engine/hf_models/mixins/__init__.py index 2d9f6d03a..439bcccbb 100644 --- a/lm_engine/hf_models/mixins/__init__.py +++ b/lm_engine/hf_models/mixins/__init__.py @@ -3,7 +3,7 @@ # ************************************************** from .dense import BaseModelMixin, Block, CausalLMModelMixin, PreTrainedModelMixin -from .dense_TP import BaseModelMixin_TP, Block_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP +from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP from .modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index 9582ee384..9aa134a0e 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -30,6 +30,14 @@ class PreTrainedModelMixin(PreTrainedModel): def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin: super().__init__(config, *args, **kwargs) + self.sequence_parallel = kwargs.get("sequence_parallel", False) + self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1) + self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0) + + self.is_first_stage = self.pipeline_stage_id == 0 + self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1 + self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 + assert self.config_class is not None self.generation_config = GenerationConfig.from_model_config(self.config) @@ -38,6 +46,9 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) + if self.is_pipeline_parallel_enabled and self._tied_word_embeddings: + raise NotImplementedError() + # FIXME typing def prepare_inputs_for_model( self, @@ -96,12 +107,23 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: config.sequence_mixer_blocks[i].sequence_mixer_type for i in range(config.num_layers) ] - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, + self.embed_dim, + std=self.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) self.embedding_dropout = Dropout(config.embedding_dropout) self.h = nn.ModuleList( [ - self.layer_class(config, use_padding_free_transformer=self.use_padding_free_transformer, layer_idx=i) + self.layer_class( + config, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + layer_idx=i, + ) for i in range(config.num_layers) ] ) @@ -312,7 +334,13 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings if self.position_embedding_type == "learned_absolute": - self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range) + self.wpe = ParameterizedEmbedding( + max_position_embeddings, + self.embed_dim, + std=self.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=False, + ) elif 
self.position_embedding_type == "rope": if self.config.rope_scaling is None: self.rope = RoPE( diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index 0cceab740..e7d7a80de 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -14,7 +14,11 @@ class Block(nn.Module): def __init__( - self, config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int | None = None + self, + config: CommonConfig, + use_padding_free_transformer: bool, + layer_idx: int, + sequence_parallel: bool, ) -> Block: super().__init__() @@ -23,14 +27,31 @@ def __init__( self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type self.ln_1 = get_normalization_function( - config.normalization_function, hidden_size, eps=config.layer_norm_epsilon + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) + self.sequence_mixer = get_sequence_mixer( + config, + True, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + layer_idx=layer_idx, ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) self.ln_2 = get_normalization_function( - config.normalization_function, hidden_size, eps=config.layer_norm_epsilon + config.normalization_function, + hidden_size, + eps=config.layer_norm_epsilon, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + layer_idx=layer_idx, ) def forward( diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 4ec016803..273d2dc4f 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -10,10 +10,11 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed +from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero -from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...modeling_utils import LMHead, ParameterizedLinear from ..modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from .base import PreTrainedModelMixin @@ -28,27 +29,23 @@ def __init__(self, config: CommonConfig, **kwargs) -> CausalLMModelMixin: self._init_model(config, **kwargs) def _init_model(self, config: CommonConfig, **kwargs) -> None: + self.vocab_size = config.vocab_size self.transformer = self.base_model_class(config, **kwargs) - if not self._tied_word_embeddings: - self.lm_head = ParameterizedLinear( - config.hidden_size, config.vocab_size, bias=False, std=config.initializer_range - ) - - self.m_width = config.m_width - - def get_input_embeddings(self) -> ParameterizedEmbedding: - return self.transformer.wte - - def set_input_embeddings(self, value: ParameterizedEmbedding) -> None: - self.transformer.wte = value + if self.is_last_stage: + if not self._tied_word_embeddings: + self.lm_head = LMHead( + self.vocab_size, + config.hidden_size, + 
std=config.initializer_range, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + ) - def get_output_embeddings(self) -> ParameterizedLinear: - return self.transformer.wte if self._tied_word_embeddings else self.lm_head + self.m_width = config.m_width - def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: - if not self._tied_word_embeddings: - self.lm_head = new_embeddings + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() if self.is_tp_enabled else None def forward( self, @@ -254,3 +251,88 @@ def generate( ) return generated_tokens + + def get_dummy_input_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[torch.Tensor] | torch.Tensor: + if self.is_first_stage: + # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason + dummy_input = torch.empty( + micro_batch_size, sequence_length + 1, device=torch.cuda.current_device(), dtype=torch.long + ) + else: + dummy_input = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + dummy_input = ( + dummy_input, + torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype), + ) + + return dummy_input + + def get_dummy_output_tensor( + self, + micro_batch_size: int, + sequence_length: int, + intermediate_dtype: torch.dtype, + output_parallel_lm_logits_if_possible: bool, + ) -> tuple[torch.Tensor] | torch.Tensor: + if self.is_last_stage: + vocab_size = self.vocab_size + if output_parallel_lm_logits_if_possible: + vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") + + if self.use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sequence_length, + vocab_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = self._get_dummy_intermediate_tensor( + micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype + ) + + tensor = (tensor, torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype)) + + return tensor + + def _get_dummy_intermediate_tensor( + self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype + ) -> tuple[torch.Tensor] | torch.Tensor: + sharded_sequence_length = ( + divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "") + if self.sequence_parallel + else sequence_length + ) + + hidden_size = self.config.hidden_size + + if self.use_padding_free_transformer: + tensor = torch.empty( + micro_batch_size * sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + else: + tensor = torch.empty( + micro_batch_size, + sharded_sequence_length, + hidden_size, + device=torch.cuda.current_device(), + dtype=intermediate_dtype, + ) + + return tensor diff --git a/lm_engine/hf_models/mixins/dense_TP/__init__.py b/lm_engine/hf_models/mixins/dense_TP/__init__.py index 817c2815f..6d93899fa 100644 --- a/lm_engine/hf_models/mixins/dense_TP/__init__.py +++ b/lm_engine/hf_models/mixins/dense_TP/__init__.py @@ -2,6 +2,5 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** 
-from .base import BaseModelMixin_TP, PreTrainedModelMixin_TP -from .layer import Block_TP +from .base import BaseModelMixin_TP from .main import CausalLMModelMixin_TP diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index 700c2b153..af58dcef6 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -10,35 +10,13 @@ from ....utils import ProcessGroupManager, divide_if_divisible from ...cache import GenerationCache from ...config import CommonConfig -from ...modeling_utils import Dropout, RoPE, YaRNScaledRoPE -from ...modeling_utils_TP import Embedding_TP, get_normalization_function_TP +from ...modeling_utils import Dropout, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function from ...utils import is_generation_cache_enabled from ..dense import BaseModelMixin, PreTrainedModelMixin from ..modeling_outputs import BaseModelOutputWithPast -from .layer import Block_TP -class PreTrainedModelMixin_TP(PreTrainedModelMixin): - layer_class = Block_TP - _no_split_modules = ["Block_TP"] - - def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixin_TP: - self.sequence_parallel = kwargs.get("sequence_parallel", False) - - self.num_pipeline_stages = kwargs.get("num_pipeline_stages", 1) - self.pipeline_stage_id = kwargs.get("pipeline_stage_id", 0) - - self.is_first_stage = self.pipeline_stage_id == 0 - self.is_last_stage = self.pipeline_stage_id == self.num_pipeline_stages - 1 - self.is_pipeline_parallel_enabled = self.num_pipeline_stages > 1 - - super().__init__(config, *args, **kwargs) - - if self.is_pipeline_parallel_enabled and self._tied_word_embeddings: - raise NotImplementedError() - - -class BaseModelMixin_TP(PreTrainedModelMixin_TP, BaseModelMixin): +class BaseModelMixin_TP(BaseModelMixin): def _init_model(self, config: CommonConfig, **kwargs) -> None: self.embed_dim = config.hidden_size self.max_position_embeddings = config.max_position_embeddings @@ -54,7 +32,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.layer_end_id = self.layers_per_stage * (self.pipeline_stage_id + 1) if self.is_first_stage: - self.wte = Embedding_TP( + self.wte = ParameterizedEmbedding( config.vocab_size, self.embed_dim, std=self.initializer_range, @@ -81,7 +59,7 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: ) if self.is_last_stage: - self.ln_f = get_normalization_function_TP( + self.ln_f = get_normalization_function( config.normalization_function, self.embed_dim, eps=config.layer_norm_epsilon, @@ -168,7 +146,7 @@ def _setup_positional_encoding(self) -> None: if self.position_embedding_type == "learned_absolute": if self.is_first_stage: - self.wpe = Embedding_TP( + self.wpe = ParameterizedEmbedding( max_position_embeddings, self.embed_dim, std=self.initializer_range, diff --git a/lm_engine/hf_models/mixins/dense_TP/layer.py b/lm_engine/hf_models/mixins/dense_TP/layer.py deleted file mode 100644 index d09109ce3..000000000 --- a/lm_engine/hf_models/mixins/dense_TP/layer.py +++ /dev/null @@ -1,54 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch.nn as nn - -from ...config import CommonConfig -from ...modeling_utils_TP import get_mlp_block_TP, get_normalization_function_TP, get_sequence_mixer_TP -from ..dense import Block - - -class Block_TP(Block): - def __init__( - self, - 
config: CommonConfig, - use_padding_free_transformer: bool, - layer_idx: int | None = None, - sequence_parallel: bool = False, - ) -> Block_TP: - nn.Module.__init__(self) - - hidden_size = config.hidden_size - self.m_residual = config.m_residual - self.sequence_mixer_type = config.sequence_mixer_blocks[layer_idx].sequence_mixer_type - - self.ln_1 = get_normalization_function_TP( - config.normalization_function, - hidden_size, - eps=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - self.sequence_mixer = get_sequence_mixer_TP( - config, - True, - use_padding_free_transformer=use_padding_free_transformer, - layer_idx=layer_idx, - sequence_parallel=sequence_parallel, - ) - self.ln_2 = get_normalization_function_TP( - config.normalization_function, - hidden_size, - eps=config.layer_norm_epsilon, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - self.mlp_block = get_mlp_block_TP( - config, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - layer_idx=layer_idx, - ) diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index f614b1cc0..9b23a9f9d 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -20,7 +20,8 @@ get_aux_loss, is_aux_loss_zero, ) -from ...modeling_utils_TP import LMHead_TP +from ...modeling_utils import LMHead +from ...parameter import _INIT_MARKER, get_parameter_marker_maps, set_parameter_marker_maps from ..dense import CausalLMModelMixin from ..modeling_outputs import ( BaseModelOutputWithPast, @@ -28,30 +29,11 @@ PipelineParallelInput, PipelineParallelOutput, ) -from .base import PreTrainedModelMixin_TP -class CausalLMModelMixin_TP(PreTrainedModelMixin_TP, CausalLMModelMixin): +class CausalLMModelMixin_TP(CausalLMModelMixin): model_parallel_state_dict_function = None - def _init_model(self, config: CommonConfig, **kwargs) -> None: - self.vocab_size = config.vocab_size - self.transformer = self.base_model_class(config, **kwargs) - - if self.is_last_stage: - if not self._tied_word_embeddings: - self.lm_head = LMHead_TP( - self.vocab_size, - config.hidden_size, - std=config.initializer_range, - use_padding_free_transformer=self.use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, - ) - - self.m_width = config.m_width - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, @@ -165,7 +147,7 @@ def forward( def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: return ( - LMHead_TP.compute_with_weight( + LMHead.compute_with_weight( hidden_states, weight=self.transformer.wte.weight, use_padding_free_transformer=self.use_padding_free_transformer, @@ -186,10 +168,14 @@ def from_pretrained( with torch.device("meta"): # try sharding vocab matrices if really struggling for memory model = cls._from_config(config, **kwargs) + marker_maps = get_parameter_marker_maps([model], extra_markers=[_INIT_MARKER]) + model = model.to(dtype=dtype) # copy to device without copying storage model = model.to_empty(device=torch.cuda.current_device()) + set_parameter_marker_maps([model], marker_maps) + model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) return model @@ -209,88 +195,3 @@ def load_from_safetensors_weights_manager(self, 
safetensors_weights_manager: Saf ) self.load_state_dict(state_dict) - - def get_dummy_input_tensor( - self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype - ) -> tuple[torch.Tensor] | torch.Tensor: - if self.is_first_stage: - # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason - dummy_input = torch.empty( - micro_batch_size, sequence_length + 1, device=torch.cuda.current_device(), dtype=torch.long - ) - else: - dummy_input = self._get_dummy_intermediate_tensor( - micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype - ) - - dummy_input = ( - dummy_input, - torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype), - ) - - return dummy_input - - def get_dummy_output_tensor( - self, - micro_batch_size: int, - sequence_length: int, - intermediate_dtype: torch.dtype, - output_parallel_lm_logits_if_possible: bool, - ) -> tuple[torch.Tensor] | torch.Tensor: - if self.is_last_stage: - vocab_size = self.config.vocab_size - if output_parallel_lm_logits_if_possible: - vocab_size = divide_if_divisible(vocab_size, ProcessGroupManager.get_tensor_parallel_world_size(), "") - - if self.use_padding_free_transformer: - tensor = torch.empty( - micro_batch_size * sequence_length, - vocab_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = torch.empty( - micro_batch_size, - sequence_length, - vocab_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = self._get_dummy_intermediate_tensor( - micro_batch_size, sequence_length, intermediate_dtype=intermediate_dtype - ) - - tensor = (tensor, torch.empty(1, device=torch.cuda.current_device(), dtype=intermediate_dtype)) - - return tensor - - def _get_dummy_intermediate_tensor( - self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype - ) -> tuple[torch.Tensor] | torch.Tensor: - sharded_sequence_length = ( - divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "") - if self.sequence_parallel - else sequence_length - ) - - hidden_size = self.config.hidden_size - - if self.use_padding_free_transformer: - tensor = torch.empty( - micro_batch_size * sharded_sequence_length, - hidden_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - else: - tensor = torch.empty( - micro_batch_size, - sharded_sequence_length, - hidden_size, - device=torch.cuda.current_device(), - dtype=intermediate_dtype, - ) - - return tensor diff --git a/lm_engine/hf_models/model_conversion/__init__.py b/lm_engine/hf_models/model_conversion/__init__.py index bb1214461..8c4820e94 100644 --- a/lm_engine/hf_models/model_conversion/__init__.py +++ b/lm_engine/hf_models/model_conversion/__init__.py @@ -22,12 +22,6 @@ _import_granitemoeshared_state_dict, ) from .llama import _export_llama_config, _export_llama_state_dict, _import_llama_config, _import_llama_state_dict -from .qwen2_moe import ( - _export_qwen2_moe_config, - _export_qwen2_moe_state_dict, - _import_qwen2_moe_config, - _import_qwen2_moe_state_dict, -) _MODEL_IMPORT_FUNCTIONS = { @@ -36,7 +30,6 @@ "granitemoeshared": (_import_granitemoeshared_config, _import_granitemoeshared_state_dict), "granitemoehybrid": (_import_granitemoehybrid_config, _import_granitemoehybrid_state_dict), "llama": (_import_llama_config, _import_llama_state_dict), - "qwen2_moe": (_import_qwen2_moe_config, _import_qwen2_moe_state_dict), } @@ -77,7 +70,6 @@ def import_from_huggingface( 
"granitemoeshared": (_export_granitemoeshared_config, _export_granitemoeshared_state_dict), "granitemoehybrid": (_export_granitemoehybrid_config, _export_granitemoehybrid_state_dict), "llama": (_export_llama_config, _export_llama_state_dict), - "qwen2_moe": (_export_qwen2_moe_config, _export_qwen2_moe_state_dict), } diff --git a/lm_engine/hf_models/model_conversion/qwen2_moe.py b/lm_engine/hf_models/model_conversion/qwen2_moe.py deleted file mode 100644 index 9fe75906e..000000000 --- a/lm_engine/hf_models/model_conversion/qwen2_moe.py +++ /dev/null @@ -1,345 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -import torch -from transformers import Qwen2MoeConfig, Qwen2MoeForCausalLM - -from ...utils import SafeTensorsWeightsManager, divide_if_divisible -from ..modeling_utils import ( - interleave_query_key_value_tensor_for_attention, - interleave_up_gate_tensor_for_mlp, - split_query_key_value_tensor_for_attention, - split_up_gate_tensor_for_mlp, -) -from ..models import GPTBaseConfig - - -def _import_qwen2_moe_config(original_config: Qwen2MoeConfig) -> GPTBaseConfig: - assert original_config.hidden_act == "silu" - - mlp_blocks = [] - for layer_idx in range(original_config.num_hidden_layers): - if (layer_idx not in original_config.mlp_only_layers) and ( - original_config.num_experts > 0 and (layer_idx + 1) % original_config.decoder_sparse_step == 0 - ): - mlp_block = { - "mlp_type": "MoE", - "intermediate_size": original_config.moe_intermediate_size, - "shared_intermediate_size": original_config.shared_expert_intermediate_size, - "shared_expert_gating": True, - "num_experts": original_config.num_experts, - "num_experts_per_tok": original_config.num_experts_per_tok, - "activation_function": "swiglu", - "add_bias": False, - "normalized_topk": original_config.norm_topk_prob, - } - else: - mlp_block = { - "mlp_type": "MLP", - "intermediate_size": original_config.intermediate_size, - "activation_function": "swiglu", - "add_bias": False, - } - - mlp_blocks.append(mlp_block) - - sequence_mixer_blocks = [] - for layer_idx in range(original_config.num_hidden_layers): - sliding_window = None - if original_config.use_sliding_window and layer_idx >= original_config.max_window_layers: - sliding_window = original_config.sliding_window - - sequence_mixer_block = { - "sequence_mixer_type": "softmax_attention", - "num_attention_heads": original_config.num_attention_heads, - "num_key_value_heads": original_config.num_key_value_heads, - "add_bias": False, - "sliding_window": sliding_window, - "qkv_bias": original_config.qkv_bias, - "softmax_dropout": original_config.attention_dropout, - } - - sequence_mixer_blocks.append(sequence_mixer_block) - - config = GPTBaseConfig( - vocab_size=original_config.vocab_size, - max_position_embeddings=original_config.max_position_embeddings, - hidden_size=original_config.hidden_size, - num_layers=original_config.num_hidden_layers, - position_embedding_type="rope", - normalization_function="rmsnorm", - layer_norm_epsilon=original_config.rms_norm_eps, - use_cache=original_config.use_cache, - tie_word_embeddings=original_config.tie_word_embeddings, - initializer_range=original_config.initializer_range, - rope_theta=original_config.rope_theta, - rope_scaling=original_config.rope_scaling, - router_aux_loss_coef=original_config.router_aux_loss_coef, - bos_token_id=original_config.bos_token_id, - eos_token_id=original_config.eos_token_id, - 
pad_token_id=original_config.pad_token_id, - sequence_mixer_blocks=sequence_mixer_blocks, - mlp_blocks=mlp_blocks, - ) - - return config - - -def _import_qwen2_moe_state_dict( - config: GPTBaseConfig, safetensors_weights_manager: SafeTensorsWeightsManager -) -> dict: - num_attention_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads") - num_key_value_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads") - qkv_bias = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias") - head_dim = divide_if_divisible(config.hidden_size, num_attention_heads, "") - - state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), - } - - if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") - - for layer_idx in range(config.num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" - ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" - ) - - # MoE - if safetensors_weights_manager.has_tensor(f"model.layers.{layer_idx}.mlp.gate.weight"): - state_dict[f"transformer.h.{layer_idx}.mlp_block.gate.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.gate.weight" - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.weight"] = torch.stack( - [ - interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight" - ), - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight" - ), - ) - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts) - ] - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"] = torch.stack( - [ - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight" - ) - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts) - ] - ) - - if safetensors_weights_manager.has_tensor(f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"): - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight"] = ( - interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight" - ), - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight" - ), - ) - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj_shared.weight"] = ( - safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight" - ) - ) - - state_dict[f"transformer.h.{layer_idx}.mlp_block.shared_expert_gate.weight"] = ( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight") - ) - # MLP - else: - state_dict[f"transformer.h.{layer_idx}.mlp_block.c_fc.weight"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"), - ) - - 
state_dict[f"transformer.h.{layer_idx}.mlp_block.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.weight" - ) - - keys = ["weight"] + (["bias"] if qkv_bias else []) - for key in keys: - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_attn.{key}"] = ( - interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.{key}"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.{key}"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.{key}"), - num_attention_heads, - num_key_value_heads, - head_dim, - ) - ) - - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" - ) - - return state_dict - - -def _export_qwen2_moe_config(config: GPTBaseConfig) -> Qwen2MoeConfig: - assert config.normalization_function == "rmsnorm" - assert config.position_embedding_type == "rope" - - config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "add_bias", False) - config.check_equal_for_all_and_get_value("mlp_blocks", "add_bias", False) - config.check_equal_for_all_and_get_value("mlp_blocks", "activation_function", "swiglu") - - mlp_only_layers = [ - layer_idx for layer_idx, mlp_block in enumerate(config.mlp_blocks) if mlp_block.mlp_type == "MLP" - ] - - max_window_layers = None - use_sliding_window = False - for layer_idx in range(config.num_layers): - block = config.sequence_mixer_blocks[layer_idx] - if config.sequence_mixer_blocks[layer_idx]: - use_sliding_window = use_sliding_window or block.sliding_window is not None - if max_window_layers is None and use_sliding_window: - max_window_layers = layer_idx - - original_config = Qwen2MoeConfig( - vocab_size=config.vocab_size, - max_position_embeddings=config.max_position_embeddings, - hidden_size=config.hidden_size, - num_hidden_layers=config.num_layers, - num_attention_heads=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads"), - num_key_value_heads=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads"), - intermediate_size=config.check_equal_for_all_and_get_value("mlp_blocks", "intermediate_size", mlp_type="MLP"), - moe_intermediate_size=config.check_equal_for_all_and_get_value( - "mlp_blocks", "intermediate_size", mlp_type="MoE" - ), - shared_expert_intermediate_size=config.check_equal_for_all_and_get_value( - "mlp_blocks", "shared_intermediate_size", mlp_type="MoE" - ), - hidden_act="silu", - rms_norm_eps=config.layer_norm_epsilon, - use_cache=config.use_cache, - use_sliding_window=use_sliding_window, - max_window_layers=max_window_layers, - tie_word_embeddings=config.tie_word_embeddings, - initializer_range=config.initializer_range, - rope_theta=config.rope_theta, - rope_scaling=config.rope_scaling, - attention_dropout=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "softmax_dropout"), - num_experts=config.check_equal_for_all_and_get_value("mlp_blocks", "num_experts", mlp_type="MoE"), - num_experts_per_tok=config.check_equal_for_all_and_get_value( - "mlp_blocks", "num_experts_per_tok", mlp_type="MoE" - ), - router_aux_loss_coef=config.router_aux_loss_coef, - bos_token_id=config.bos_token_id, - eos_token_id=config.eos_token_id, - pad_token_id=config.pad_token_id, - qkv_bias=config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias"), - 
mlp_only_layers=mlp_only_layers, - norm_topk_prob=config.check_equal_for_all_and_get_value("mlp_blocks", "normalized_topk", mlp_type="MoE"), - architectures=[Qwen2MoeForCausalLM.__name__], - ) - - return original_config - - -def _export_qwen2_moe_state_dict( - config: GPTBaseConfig, safetensors_weights_manager: SafeTensorsWeightsManager -) -> dict: - num_attention_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_attention_heads") - num_key_value_heads = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "num_key_value_heads") - qkv_bias = config.check_equal_for_all_and_get_value("sequence_mixer_blocks", "qkv_bias") - - state_dict = { - "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), - } - - if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") - - for layer_idx in range(config.num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" - ) - state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") - ) - - # MoE layer - if safetensors_weights_manager.has_tensor(f"transformer.h.{layer_idx}.mlp_block.gate.weight"): - state_dict[f"model.layers.{layer_idx}.mlp.gate.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.gate.weight" - ) - - for expert_idx in range(config.mlp_blocks[layer_idx].num_experts): - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.weight")[ - expert_idx - ] - ) - - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight"] = gate_weight - - state_dict[f"model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_proj.weight")[ - expert_idx - ] - ) - - if safetensors_weights_manager.has_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight"): - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc_shared.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"] = gate_weight - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_proj_shared.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.shared_expert_gate.weight"] = ( - safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.shared_expert_gate.weight" - ) - ) - # MLP layer - else: - up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp_block.c_fc.weight") - ) - - state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight - state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight - - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = 
safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp_block.c_proj.weight" - ) - - keys = ["weight"] + (["bias"] if qkv_bias else []) - for key in keys: - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.sequence_mixer.c_attn.{key}"), - num_attention_heads, - num_key_value_heads, - ) - state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.{key}"] = query_weight - state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.{key}"] = key_weight - state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.{key}"] = value_weight - - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.c_proj.weight" - ) - - return state_dict diff --git a/lm_engine/hf_models/modeling_utils/__init__.py b/lm_engine/hf_models/modeling_utils/__init__.py index 5c662b738..3404b87d3 100644 --- a/lm_engine/hf_models/modeling_utils/__init__.py +++ b/lm_engine/hf_models/modeling_utils/__init__.py @@ -5,8 +5,10 @@ from .activations import get_activation_function, is_glu from .convolution import ParameterizedConv1d from .dropout import Dropout -from .embedding import ParameterizedEmbedding -from .linear import ParameterizedLinear +from .dtensor_module import DTensorModule +from .embedding import ParameterizedEmbedding, get_tensor_parallel_vocab_info +from .linear import ColumnParallelLinear, ParameterizedLinear, RowParallelLinear +from .lm_head import LMHead from .mlp_blocks import ( MLP, MoE, @@ -24,3 +26,4 @@ interleave_query_key_value_tensor_for_attention, split_query_key_value_tensor_for_attention, ) +from .TP import tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils/dropout.py b/lm_engine/hf_models/modeling_utils/dropout.py index d4d50806c..d5b887217 100644 --- a/lm_engine/hf_models/modeling_utils/dropout.py +++ b/lm_engine/hf_models/modeling_utils/dropout.py @@ -31,10 +31,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.is_tp_enabled: x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) - - x = super().forward(x) - - if self.is_tp_enabled: + x = super().forward(x) x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + else: + x = super().forward(x) return x diff --git a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py b/lm_engine/hf_models/modeling_utils/dtensor_module.py similarity index 66% rename from lm_engine/hf_models/modeling_utils_TP/dtensor_module.py rename to lm_engine/hf_models/modeling_utils/dtensor_module.py index 761f5e8c1..f6b993747 100644 --- a/lm_engine/hf_models/modeling_utils_TP/dtensor_module.py +++ b/lm_engine/hf_models/modeling_utils/dtensor_module.py @@ -9,17 +9,22 @@ import torch.nn as nn from ...dtensors import modify_state_dict_to_dtensor_dict +from ...utils import ProcessGroupManager class DTensorModule(nn.Module): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> None: - state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix="", strip_keys=False) + if ProcessGroupManager.is_tensor_parallel_enabled(): + state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix="", strip_keys=False) + super().load_state_dict(state_dict, strict, assign) def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) 
-> None: - state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix=prefix, strip_keys=True) + if ProcessGroupManager.is_tensor_parallel_enabled(): + state_dict = modify_state_dict_to_dtensor_dict(self, state_dict=state_dict, prefix=prefix, strip_keys=True) + super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) diff --git a/lm_engine/hf_models/modeling_utils/embedding.py b/lm_engine/hf_models/modeling_utils/embedding.py index 1dc9525f3..98119bba0 100644 --- a/lm_engine/hf_models/modeling_utils/embedding.py +++ b/lm_engine/hf_models/modeling_utils/embedding.py @@ -4,22 +4,91 @@ from __future__ import annotations +import math + import torch import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._tensor.placement_types import Replicate, Shard +from ...dtensors import dtensor_to_tensor, tensor_to_dtensor +from ...utils import ProcessGroupManager, divide_if_divisible from ..parameter import mark_parameter_as_initialized +from .dtensor_module import DTensorModule +from .TP import get_module_placements + + +class ParameterizedEmbedding(DTensorModule): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + std: float | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> ParameterizedEmbedding: + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + self.use_padding_free_transformer = use_padding_free_transformer + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.sequence_parallel = sequence_parallel + self.vocab_start_index, self.vocab_end_index, num_embeddings_per_tp_rank = get_tensor_parallel_vocab_info( + num_embeddings + ) + + self.weight = nn.Parameter( + tensor_to_dtensor( + torch.empty(num_embeddings_per_tp_rank, embedding_dim), + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Shard(0), + ) + ) + + self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + else: + self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim)) -class ParameterizedEmbedding(nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, std: float | None = None) -> ParameterizedEmbedding: self.std = std - super().__init__(num_embeddings, embedding_dim) - @torch.no_grad() - def reset_parameters(self) -> None: - if self.std is None: - super().reset_parameters() + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Replicate()) + x = F.embedding(x, weight=self.weight) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) else: - self.weight.normal_(mean=0, std=self.std) + x = F.embedding(x, weight=self.weight) + + return x + @torch.no_grad() + def reset_parameters(self) -> None: + self.weight.normal_(mean=0, std=1 if self.std is None else self.std) mark_parameter_as_initialized(self.weight) + + def extra_repr(self) -> str: + return f"{self.num_embeddings}, {self.embedding_dim}" + + +def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]: + tp_rank = ProcessGroupManager.get_tensor_parallel_rank() + tp_world_size = 
ProcessGroupManager.get_tensor_parallel_world_size() + + divide_if_divisible(make_vocab_size_divisible_by, tp_world_size) + + vocab_size_per_tensor_parallel_rank = ( + make_vocab_size_divisible_by * math.ceil(vocab_size / make_vocab_size_divisible_by) + ) // tp_world_size + + vocab_start_index = tp_rank * vocab_size_per_tensor_parallel_rank + vocab_end_index = min((tp_rank + 1) * vocab_size_per_tensor_parallel_rank, vocab_size) + + return vocab_start_index, vocab_end_index, vocab_size_per_tensor_parallel_rank diff --git a/lm_engine/hf_models/modeling_utils/linear/__init__.py b/lm_engine/hf_models/modeling_utils/linear/__init__.py new file mode 100644 index 000000000..173bb73e1 --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/__init__.py @@ -0,0 +1,8 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from .base import ParameterizedLinear +from .column import ColumnParallelLinear +from .replicated import ReplicatedLinear +from .row import RowParallelLinear diff --git a/lm_engine/hf_models/modeling_utils/linear.py b/lm_engine/hf_models/modeling_utils/linear/base.py similarity index 67% rename from lm_engine/hf_models/modeling_utils/linear.py rename to lm_engine/hf_models/modeling_utils/linear/base.py index 9ac15302a..8c340ae3d 100644 --- a/lm_engine/hf_models/modeling_utils/linear.py +++ b/lm_engine/hf_models/modeling_utils/linear/base.py @@ -7,21 +7,15 @@ import torch import torch.nn as nn -from ..parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay +from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay class ParameterizedLinear(nn.Linear): def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None ) -> ParameterizedLinear: self.std = std - super().__init__(in_features, out_features, bias, device, dtype) + super().__init__(in_features, out_features, bias) mark_parameter_as_no_weight_decay(self.bias) diff --git a/lm_engine/hf_models/modeling_utils/linear/column.py b/lm_engine/hf_models/modeling_utils/linear/column.py new file mode 100644 index 000000000..72c54a68e --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/column.py @@ -0,0 +1,78 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributed._tensor.placement_types import Replicate, Shard + +from ....dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel +from ....utils import ProcessGroupManager, divide_if_divisible +from ..dtensor_module import DTensorModule +from ..TP import get_module_placements +from .base import ParameterizedLinear + + +class ColumnParallelLinear(ParameterizedLinear, DTensorModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + std: float | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> ColumnParallelLinear: + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 + + self.out_features_per_tp_rank = 
divide_if_divisible( + out_features, + tp_world_size, + f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + super().__init__(in_features=in_features, out_features=self.out_features_per_tp_rank, bias=bias, std=std) + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) + ) + ) + + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Shard(0), + ) + ) + + if use_async_tensor_parallel(): + self.compile() + + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor( + x, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() + ) + x = super().forward(x) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) + else: + x = super().forward(x) + + return x + + def extra_repr(self) -> str: + return "in_features={}, out_features_per_tp_rank={}, bias={}".format( + self.in_features, self.out_features_per_tp_rank, self.bias is not None + ) diff --git a/lm_engine/hf_models/modeling_utils/linear/replicated.py b/lm_engine/hf_models/modeling_utils/linear/replicated.py new file mode 100644 index 000000000..5e769e33e --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/replicated.py @@ -0,0 +1,40 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch.nn as nn +from torch.distributed._tensor.placement_types import Replicate + +from ....dtensors import tensor_to_dtensor +from ....utils import ProcessGroupManager +from ..dtensor_module import DTensorModule +from .base import ParameterizedLinear + + +class ReplicatedLinear(ParameterizedLinear, DTensorModule): + def __init__( + self, in_features: int, out_features: int, bias: bool = True, std: float | None = None + ) -> ReplicatedLinear: + super().__init__(in_features=in_features, out_features=out_features, bias=bias, std=std) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) diff --git a/lm_engine/hf_models/modeling_utils/linear/row.py b/lm_engine/hf_models/modeling_utils/linear/row.py new file mode 100644 index 000000000..eb360156a --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/linear/row.py @@ -0,0 +1,76 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.distributed._tensor.placement_types import Replicate, Shard + +from ....dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel +from ....utils import ProcessGroupManager, 
divide_if_divisible +from ..dtensor_module import DTensorModule +from ..TP import get_module_placements +from .base import ParameterizedLinear + + +class RowParallelLinear(ParameterizedLinear, DTensorModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + std: float | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> RowParallelLinear: + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 + + self.in_features_per_tp_rank = divide_if_divisible( + in_features, + tp_world_size, + f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + super().__init__(in_features=self.in_features_per_tp_rank, out_features=out_features, bias=bias, std=std) + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) + ) + ) + + if bias: + self.bias = nn.Parameter( + tensor_to_dtensor( + self.bias, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Replicate(), + ) + ) + + if use_async_tensor_parallel(): + self.compile() + + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Shard(-1)) + x = super().forward(x) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.output_placement) + else: + x = super().forward(x) + + return x + + def extra_repr(self) -> str: + return "in_features_per_tp_rank={}, out_features={}, bias={}".format( + self.in_features_per_tp_rank, self.out_features, self.bias is not None + ) diff --git a/lm_engine/hf_models/modeling_utils_TP/lm_head.py b/lm_engine/hf_models/modeling_utils/lm_head.py similarity index 57% rename from lm_engine/hf_models/modeling_utils_TP/lm_head.py rename to lm_engine/hf_models/modeling_utils/lm_head.py index f3ae3a9c7..44a890e36 100644 --- a/lm_engine/hf_models/modeling_utils_TP/lm_head.py +++ b/lm_engine/hf_models/modeling_utils/lm_head.py @@ -8,34 +8,40 @@ from torch.distributed.device_mesh import DeviceMesh from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel -from .embedding import Embedding_TP -from .TP import get_module_placements +from ..modeling_utils import ParameterizedEmbedding +from ..modeling_utils.TP import get_module_placements -class LMHead_TP(Embedding_TP): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return self.compute_with_weight( - input, - self.weight, - use_padding_free_transformer=self.use_padding_free_transformer, - sequence_parallel=self.sequence_parallel, - tp_mesh=self.tp_mesh, - ) +class LMHead(ParameterizedEmbedding): + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = self.compute_with_weight( + x, + self.weight, + use_padding_free_transformer=self.use_padding_free_transformer, + sequence_parallel=self.sequence_parallel, + tp_mesh=self.tp_mesh if self.is_tp_enabled else None, + ) + else: + x = F.linear(x, weight=self.weight) + + return x @staticmethod def compute_with_weight( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, 
use_padding_free_transformer: bool, sequence_parallel: bool, - tp_mesh: DeviceMesh, + tp_mesh: DeviceMesh | None, ) -> torch.Tensor: - function = ( - LMHead_TP._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead_TP._compute_with_weight - ) + if tp_mesh is None: + return F.linear(x, weight=weight) + + function = LMHead._compute_with_weight_compiled if use_async_tensor_parallel() else LMHead._compute_with_weight return function( - input=input, + x=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, @@ -44,33 +50,33 @@ def compute_with_weight( @staticmethod def _compute_with_weight( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, tp_mesh: DeviceMesh, ) -> torch.Tensor: - input = tensor_to_dtensor( - input, + x = tensor_to_dtensor( + x, device_mesh=tp_mesh, current_placement=get_module_placements(use_padding_free_transformer, sequence_parallel), desired_placement=Replicate(), ) - input = F.linear(input, weight) - input = dtensor_to_tensor(input, device_mesh=tp_mesh, desired_placement=Shard(-1)) - return input + x = F.linear(x, weight) + x = dtensor_to_tensor(x, device_mesh=tp_mesh, desired_placement=Shard(-1)) + return x @torch.compile @staticmethod def _compute_with_weight_compiled( - input: torch.Tensor, + x: torch.Tensor, weight: torch.Tensor, use_padding_free_transformer: bool, sequence_parallel: bool, tp_mesh: DeviceMesh, ) -> torch.Tensor: - return LMHead_TP._compute_with_weight( - input=input, + return LMHead._compute_with_weight( + x=x, weight=weight, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel, diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py index 06877bc7e..1520672ec 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/__init__.py @@ -7,7 +7,9 @@ from .moe import MoE, ParameterizedExperts -def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE: +def get_mlp_block( + config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int +) -> MLP | MoE: block = config.mlp_blocks[layer_idx] mlp_type = block.mlp_type @@ -21,6 +23,8 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye initializer_range=config.initializer_range, m_width=config.m_width, num_layers=config.num_layers, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) if mlp_type == "MLP": @@ -34,7 +38,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye normalized_topk=block.normalized_topk, num_experts=block.num_experts, num_experts_per_tok=block.num_experts_per_tok, - use_padding_free_transformer=use_padding_free_transformer, ) else: raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py index c95747da9..c6a2c6c12 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/mlp.py @@ -12,7 +12,7 @@ from ...parameter import mark_parameter_as_mup_learning_rate from ..activations import get_activation_function, is_glu from ..dropout import Dropout -from ..linear import 
ParameterizedLinear +from ..linear import ColumnParallelLinear, RowParallelLinear class MLP(nn.Module): @@ -27,25 +27,36 @@ def __init__( initializer_range: float, m_width: float, num_layers: int, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> MLP: super().__init__() std = _get_std_for_linear(initializer_range, init_method, m_width) - self.c_fc = ParameterizedLinear( + self.c_fc = ColumnParallelLinear( hidden_size, 2 * intermediate_size if is_glu(activation_function) else intermediate_size, bias=add_bias, std=std, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) self.act = get_activation_function(activation_function) - self.c_proj = ParameterizedLinear( - intermediate_size, hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=add_bias, + std=std / math.sqrt(2 * num_layers), + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) - self.dropout = Dropout(dropout) + self.dropout = Dropout( + dropout, use_padding_free_transformer=use_padding_free_transformer, sequence_parallel=sequence_parallel + ) mark_parameter_as_mup_learning_rate(self.c_fc.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 32bd6751c..00df1fe77 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -10,10 +10,12 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed._functional_collectives import all_reduce +from torch.distributed._tensor.placement_types import Partial, Replicate, Shard +from ....dtensors import dtensor_to_tensor, tensor_to_dtensor from ....enums import Kernel -from ....kernels import is_kernel_allowed -from ....utils import ProcessGroupManager, is_sonicmoe_available, is_xma_available +from ....kernels import is_kernel_allowed, wait_for_ACT +from ....utils import ProcessGroupManager, divide_if_divisible, is_sonicmoe_available, is_xma_available from ...loss import add_aux_loss from ...parameter import ( mark_parameter_as_initialized, @@ -22,7 +24,8 @@ ) from ..activations import get_activation_function, is_glu from ..dropout import Dropout -from ..linear import ParameterizedLinear +from ..dtensor_module import DTensorModule +from ..linear import ColumnParallelLinear, ParameterizedLinear, ReplicatedLinear, RowParallelLinear from .mlp import _get_std_for_linear @@ -55,26 +58,31 @@ def compute_bincount(x: torch.Tensor, size: int, use_continuous_count: bool) -> return count +class SharedExpertsColumnParallelLinear(ColumnParallelLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) + + +class SharedExpertsRowParallelLinear(RowParallelLinear): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) + + class ParameterizedExperts(nn.Module): def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None ) -> ParameterizedExperts: 
super().__init__() - self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features, device=device, dtype=dtype)) + self.std = std + + self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) self.bias = None if add_bias: - self.bias = nn.Parameter(torch.empty(num_experts, out_features, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(num_experts, out_features)) - self.std = std + mark_parameter_as_no_weight_decay(self.bias) self.num_experts = num_experts self.in_features = in_features @@ -82,11 +90,9 @@ def __init__( self.reset_parameters() - mark_parameter_as_no_weight_decay(self.bias) - def forward( self, - input: torch.Tensor, + x: torch.Tensor, num_experts_per_token: int | None = None, expert_frequency: torch.Tensor | None = None, sorted_expert_idxs: torch.Tensor | None = None, @@ -97,10 +103,8 @@ def forward( grouped_out: bool = False, ) -> torch.Tensor: if is_kernel_allowed(Kernel.scattermoe): - assert self.bias is None - - input = scattered_experts( - inputs=input, + x = scattered_experts( + inputs=x, expert_weights=self.weight.permute(0, 2, 1), k=num_experts_per_token, sorted_expert_idxs=sorted_expert_idxs, @@ -111,14 +115,11 @@ def forward( grouped_out=grouped_out, ) else: - input = input.split(expert_frequency.tolist(), dim=0) - input = [ - F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i]) - for i in range(self.num_experts) - ] - input = torch.cat(input, dim=0) + x = x.split(expert_frequency.tolist(), dim=0) + x = [F.linear(x[i], self.weight[i]) for i in range(self.num_experts)] + x = torch.cat(x, dim=0) - return input + return x def extra_repr(self) -> str: return "num_experts={}, in_features={}, out_features={}".format( @@ -128,16 +129,131 @@ def extra_repr(self) -> str: @torch.no_grad() def reset_parameters(self) -> None: nn.init.normal_(self.weight, mean=0, std=self.std) - if hasattr(self, "bias") and self.bias is not None: - self.bias.zero_() + if self.bias is not None: + nn.init.zeros_(self.bias) mark_parameter_as_initialized(self.weight) mark_parameter_as_initialized(self.bias) -class MoE(nn.Module): - linear_class = ParameterizedExperts +class ColumnParallelExperts(ParameterizedExperts, DTensorModule): + def __init__( + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None + ) -> ColumnParallelExperts: + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 + + self.out_features_per_tp_rank = divide_if_divisible( + out_features, + tp_world_size, + f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + super().__init__( + num_experts=num_experts, + in_features=in_features, + out_features=self.out_features_per_tp_rank, + add_bias=add_bias, + std=std, + ) + if self.is_tp_enabled: + assert not add_bias + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) + ) + ) + + self.reset_parameters() + + def forward( + self, + x: torch.Tensor, + num_experts_per_token: int | None = None, + expert_frequency: torch.Tensor | None = None, + sorted_expert_idxs: torch.Tensor | None = None, + sorted_scattered_idxs: torch.Tensor | None = None, + expert_offsets: torch.Tensor | None = None, + gates: torch.Tensor | None = None, + grouped_in: bool = False, + grouped_out: bool = False, + ) -> 
torch.Tensor: + if self.is_tp_enabled: + assert is_kernel_allowed(Kernel.scattermoe) + + if is_kernel_allowed(Kernel.scattermoe): + x = wait_for_ACT(x, wait_in_forward=True, wait_in_backward=False) + + x = scattered_experts( + inputs=x, + expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), + k=num_experts_per_token, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=gates, + grouped_in=grouped_in, + grouped_out=grouped_out, + ) + + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) + else: + x = x.split(expert_frequency.tolist(), dim=0) + x = [F.linear(x[i], self.weight[i]) for i in range(self.num_experts)] + x = torch.cat(x, dim=0) + + return x + + def extra_repr(self) -> str: + return "num_experts={}, in_features={}, out_features_per_tp_rank={}".format( + self.num_experts, self.in_features, self.out_features_per_tp_rank + ) + + +class RowParallelExperts(ColumnParallelExperts): + def __init__( + self, num_experts: int, in_features: int, out_features: int, add_bias: bool, std: float | None = None + ) -> RowParallelExperts: + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() if self.is_tp_enabled else 1 + + self.in_features_per_tp_rank = divide_if_divisible( + in_features, + tp_world_size, + f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", + ) + + ParameterizedExperts.__init__( + self, + num_experts=num_experts, + in_features=self.in_features_per_tp_rank, + out_features=out_features, + add_bias=add_bias, + std=std, + ) + + if self.is_tp_enabled: + assert not add_bias + + self.weight = nn.Parameter( + tensor_to_dtensor( + self.weight, + device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), + current_placement=Shard(-1), + ) + ) + + self.reset_parameters() + + def extra_repr(self) -> str: + return "num_experts={}, in_features_per_tp_rank={}, out_features={}".format( + self.num_experts, self.in_features_per_tp_rank, self.out_features + ) + + +class MoE(nn.Module): def __init__( self, hidden_size: int, @@ -148,14 +264,15 @@ def __init__( normalized_topk: bool, num_experts: int, num_experts_per_tok: int, - activation_function: str, add_bias: bool, + activation_function: str, dropout: float, init_method: str, initializer_range: float, m_width: float, num_layers: int, use_padding_free_transformer: bool, + sequence_parallel: bool = False, ) -> MoE: super().__init__() @@ -171,12 +288,7 @@ def __init__( std = _get_std_for_linear(initializer_range, init_method, m_width) - self.gate = ParameterizedLinear( - in_features=self.hidden_size, - out_features=num_experts, - bias=False, - std=std, - ) + self.gate = ReplicatedLinear(in_features=self.hidden_size, out_features=num_experts, bias=False, std=std) if self.shared_expert_gating: assert shared_intermediate_size is not None @@ -185,15 +297,16 @@ def __init__( in_features=self.hidden_size, out_features=1, bias=False, std=std ) - self.c_fc = self.linear_class( + self.c_fc = ColumnParallelExperts( num_experts=num_experts, in_features=self.hidden_size, out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, add_bias=add_bias, std=std, ) + if self.shared_intermediate_size is not None: - self.c_fc_shared = ParameterizedLinear( + self.c_fc_shared = SharedExpertsColumnParallelLinear( in_features=self.hidden_size, out_features=( 2 * self.shared_intermediate_size if 
is_glu(activation_function) else self.shared_intermediate_size @@ -207,15 +320,16 @@ def __init__( std /= math.sqrt(2 * num_layers) - self.c_proj = self.linear_class( + self.c_proj = RowParallelExperts( num_experts=num_experts, in_features=self.intermediate_size, out_features=self.hidden_size, add_bias=add_bias, std=std, ) + if self.shared_intermediate_size is not None: - self.c_proj_shared = ParameterizedLinear( + self.c_proj_shared = SharedExpertsRowParallelLinear( in_features=self.shared_intermediate_size, out_features=self.hidden_size, bias=add_bias, @@ -223,6 +337,7 @@ def __init__( ) self.dropout = Dropout(dropout) + self.placement = Shard(0) if sequence_parallel else Replicate() self.is_hopper_or_newer_gpu = torch.cuda.is_available() and torch.cuda.get_device_capability( torch.cuda.current_device() @@ -238,15 +353,24 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_fc_shared.weight) mark_parameter_as_mup_learning_rate(self.c_proj_shared.weight) + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + def forward(self, x: torch.Tensor) -> torch.Tensor: if not self.use_padding_free_transformer: batch_size, sequence_length, _ = x.shape x = x.view(-1, self.hidden_size) + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + if is_kernel_allowed(Kernel.sonicmoe): assert self.use_interleaved_weights assert self.activation_function_string == "swiglu" + assert not self.is_tp_enabled moe_output, router_logits, expert_frequency = moe_TC_softmax_topk_layer( x=x, @@ -263,6 +387,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert not self.use_interleaved_weights router_logits, router_weights, selected_experts = self._compute_routing_weights(x) + + if self.is_tp_enabled: + x = dtensor_to_tensor( + x, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() + ) + moe_output, expert_frequency = self._compute_experts(x, router_weights, selected_experts) if self.shared_intermediate_size is None: @@ -272,6 +402,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: del moe_output + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=Partial()) + x = dtensor_to_tensor( + x, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement + ) + if not self.use_padding_free_transformer: x = x.reshape(batch_size, sequence_length, self.hidden_size) @@ -292,6 +428,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _compute_routing_weights(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # x -> (total_q, hidden_size) router_logits = self.gate(x) + + if self.is_tp_enabled: + router_logits = dtensor_to_tensor( + router_logits, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() + ) + # router_logits -> (total_q, num_experts) if self.normalized_topk: @@ -324,7 +466,7 @@ def _compute_experts( expert_offsets = expert_frequency.cumsum(-1) x = self.c_fc( - input=x, + x=x, num_experts_per_token=self.top_k, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, @@ -335,7 +477,7 @@ def _compute_experts( x = self.act(x) x = self.c_proj( - input=x, + x=x, num_experts_per_token=1, sorted_expert_idxs=sorted_expert_idxs, sorted_scattered_idxs=sorted_scattered_idxs, @@ -354,9 +496,9 @@ def _compute_experts( x = x[batch_index] - x = self.c_fc(input=x, 
expert_frequency=expert_frequency) + x = self.c_fc(x=x, expert_frequency=expert_frequency) x = self.act(x) - x = self.c_proj(input=x, expert_frequency=expert_frequency) + x = self.c_proj(x=x, expert_frequency=expert_frequency) x = x * batch_gates.unsqueeze(-1) # [:, None] zeros = torch.zeros((T, self.hidden_size), dtype=x.dtype, device=x.device) diff --git a/lm_engine/hf_models/modeling_utils/normalization.py b/lm_engine/hf_models/modeling_utils/normalization.py index c518577c2..089da664e 100644 --- a/lm_engine/hf_models/modeling_utils/normalization.py +++ b/lm_engine/hf_models/modeling_utils/normalization.py @@ -7,35 +7,106 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.distributed._tensor.placement_types import Replicate +from ...dtensors import dtensor_to_tensor, tensor_to_dtensor from ...enums import Kernel -from ...kernels import is_kernel_allowed -from ...utils import is_xma_available +from ...kernels import is_kernel_allowed, wait_for_ACT +from ...utils import ProcessGroupManager, is_xma_available from ..parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay +from .dtensor_module import DTensorModule +from .TP import get_module_placements if is_xma_available(): from xma import rmsnorm -class LayerNorm(nn.LayerNorm): +class LayerNorm(nn.LayerNorm, DTensorModule): + def __init__( + self, + normalized_shape: int, + eps: float = 1e-6, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> LayerNorm: + super().__init__(normalized_shape, eps=eps) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + tensor_to_dtensor(self.weight, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + + self.bias = nn.Parameter( + tensor_to_dtensor(self.bias, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + + self.reset_parameters() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + x = super().forward(x) + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + else: + x = super().forward(x) + + return x + def reset_parameters(self) -> None: super().reset_parameters() mark_parameter_as_initialized(self.weight) + mark_parameter_as_initialized(self.bias) -class RMSNorm(nn.RMSNorm): +class RMSNorm(nn.RMSNorm, DTensorModule): + def __init__( + self, + normalized_shape: int, + eps: float = 1e-6, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, + ) -> RMSNorm: + super().__init__(normalized_shape, eps=eps) + + self.is_tp_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + + if self.is_tp_enabled: + self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() + self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) + + self.weight = nn.Parameter( + tensor_to_dtensor(self.weight, device_mesh=self.tp_mesh, current_placement=Replicate()) + ) + + self.reset_parameters() + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + if is_kernel_allowed(Kernel.rmsnorm) or is_kernel_allowed(Kernel.rmsnorm_memory_efficient): + x = wait_for_ACT(x, 
wait_in_forward=True, wait_in_backward=False) + x = rmsnorm( x=x, weight=self.weight, eps=self.eps, memory_efficient=is_kernel_allowed(Kernel.rmsnorm_memory_efficient), ) + + x = wait_for_ACT(x, wait_in_forward=False, wait_in_backward=True) else: x = super().forward(x) + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + return x def reset_parameters(self) -> None: @@ -49,14 +120,22 @@ def __init__( normalized_shape: int, p: int, eps: float | None = None, - elementwise_affine=True, - device: torch.device = None, - dtype: torch.dtype = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> PNorm: self.p = p - super().__init__(normalized_shape, eps, elementwise_affine, device, dtype) + + super().__init__( + normalized_shape=normalized_shape, + eps=eps, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.is_tp_enabled: + x = tensor_to_dtensor(x, device_mesh=self.tp_mesh, current_placement=self.placement) + dtype = x.dtype x = x.float() @@ -66,6 +145,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.weight is not None: x = self.weight * x + if self.is_tp_enabled: + x = dtensor_to_tensor(x, device_mesh=self.tp_mesh, desired_placement=self.placement) + return x @@ -73,18 +155,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def get_normalization_function( - normalization_function: str, normalized_shape: int, eps: float = 1e-5, p: int | None = None + normalization_function: str, + normalized_shape: int, + eps: float = 1e-5, + p: int | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> LayerNorm | RMSNorm | PNorm: if normalization_function is None: return nn.Identity() + kwargs = { + "normalized_shape": normalized_shape, + "eps": eps, + "use_padding_free_transformer": use_padding_free_transformer, + "sequence_parallel": sequence_parallel, + } + if normalization_function in _NORMALIZATION_FUNCTIONS: if normalization_function == "p_norm": assert p is not None - normalization = _NORMALIZATION_FUNCTIONS[normalization_function](normalized_shape, eps=eps, p=p) + normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs, p=p) else: assert p is None - normalization = _NORMALIZATION_FUNCTIONS[normalization_function](normalized_shape, eps=eps) + normalization = _NORMALIZATION_FUNCTIONS[normalization_function](**kwargs) else: raise ValueError(f"unexpected `normalization_function` {normalization_function}") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index e4d42e081..628a25c97 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -23,6 +23,7 @@ def get_sequence_mixer( config: CommonConfig, causal: bool, use_padding_free_transformer: bool, + sequence_parallel: bool, layer_idx: int, ) -> SEQUENCE_MIXER_TYPE: block = config.sequence_mixer_blocks[layer_idx] @@ -157,9 +158,9 @@ def get_sequence_mixer( if sequence_mixer_type == "softmax_attention": return Attention( **sequence_mixer_kwargs, - qkv_bias=block.qkv_bias, softmax_dropout=block.softmax_dropout, use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) else: raise ValueError(f"unexpected sequence_mixer_type 
({sequence_mixer_type})") diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index c18d10e68..832387ce6 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -12,12 +12,12 @@ from ....enums import Kernel from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import Accelerator, divide_if_divisible, is_torch_xla_available +from ....utils import Accelerator, ProcessGroupManager, divide_if_divisible, is_torch_xla_available from ...cache import GenerationCache from ...parameter import mark_parameter_as_mup_learning_rate from ..chunk import contiguous_split from ..dropout import Dropout -from ..linear import ParameterizedLinear +from ..linear import ColumnParallelLinear, RowParallelLinear from ..position_embedding import apply_rotary_pos_emb from .utils import flash_attention @@ -77,7 +77,6 @@ def __init__( sliding_window: int | None, position_embedding_type: str, add_bias: bool, - qkv_bias: bool, softmax_dropout: float, dropout: float, init_method: str, @@ -85,55 +84,97 @@ m_width: float, num_layers: int, causal: bool, - layer_idx: int, - use_padding_free_transformer: bool, + layer_idx: int | None = None, + use_padding_free_transformer: bool = False, + sequence_parallel: bool = False, ) -> Attention: super().__init__() + if ProcessGroupManager.is_initialized(): + self.tp_world_size = ( + ProcessGroupManager.get_tensor_parallel_world_size() + if ProcessGroupManager.is_tensor_parallel_enabled() + else 1 + ) + else: + self.tp_world_size = 1 + self.causal = causal - self.hidden_size = hidden_size - self.num_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads + self.global_hidden_size = hidden_size + self.global_num_heads = num_attention_heads + self.global_num_key_value_heads = num_key_value_heads self.add_bias = add_bias - self.qkv_bias = qkv_bias - self.use_padding_free_transformer = use_padding_free_transformer self.sliding_window = sliding_window - self.head_dim = divide_if_divisible( - self.hidden_size, - self.num_heads, - f"`hidden_size` ({self.hidden_size}) must be divisible by `num_heads` ({self.num_heads})", + self.use_padding_free_transformer = use_padding_free_transformer + self.sequence_parallel = sequence_parallel + + divide_if_divisible(self.global_hidden_size, self.global_num_heads, "") + + self.hidden_size = divide_if_divisible( + self.global_hidden_size, self.tp_world_size, "hidden_size should be divisible by TP world size" + ) + + self.num_heads = divide_if_divisible( + self.global_num_heads, self.tp_world_size, "num_heads must be divisible by TP world size" ) + self.head_dim = divide_if_divisible(self.hidden_size, self.num_heads, "") self.position_embedding_type = position_embedding_type self.attention_multiplier = attention_multiplier self.layer_idx = layer_idx divide_if_divisible( - self.num_heads, - self.num_key_value_heads, - f"`num_heads` ({self.num_heads}) should be a multiple of `num_key_value_heads` ({self.num_key_value_heads})", + self.global_num_heads, + self.global_num_key_value_heads, + f"`num_heads` ({self.global_num_heads}) should be a multiple of `num_key_value_heads` ({self.global_num_key_value_heads})", + ) + + self.num_key_value_heads = divide_if_divisible( + self.global_num_key_value_heads, + self.tp_world_size, + f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by
`tensor_parallel_world_size` ({self.tp_world_size})", + ) std = initializer_range if init_method == "mup": std /= math.sqrt(m_width) - self.c_attn = ParameterizedLinear( - self.hidden_size, - self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, - bias=self.qkv_bias, + + self.c_attn = ColumnParallelLinear( + self.global_hidden_size, + self.global_hidden_size + 2 * self.global_num_key_value_heads * self.head_dim, + bias=self.add_bias, std=std, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, ) std = initializer_range / math.sqrt(2 * num_layers) if init_method == "mup": std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(self.hidden_size, self.hidden_size, bias=self.add_bias, std=std) + + self.c_proj = RowParallelLinear( + self.global_hidden_size, + self.global_hidden_size, + bias=self.add_bias, + std=std, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) self.softmax_dropout_p = softmax_dropout - self.softmax_dropout = Dropout(softmax_dropout) - self.dropout = Dropout(dropout) + self.softmax_dropout = Dropout( + softmax_dropout, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) + + self.dropout = Dropout( + dropout, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=sequence_parallel, + ) mark_parameter_as_mup_learning_rate(self.c_attn.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) @@ -155,11 +196,12 @@ def forward( assert use_flash_attention_2 or use_flash_attention_3 assert past_key_values is None - T = hidden_states.size(0) + T = hidden_states.size(0) * (self.tp_world_size if self.sequence_parallel else 1) input_shape = (T, self.num_key_value_heads, -1) output_shape = (T, -1, self.head_dim) else: batch_size, query_length = hidden_states.shape[:-1] + query_length *= self.tp_world_size if self.sequence_parallel else 1 input_shape = (batch_size, query_length, self.num_key_value_heads, -1) output_shape = (batch_size, query_length, -1, self.head_dim) diff --git a/lm_engine/hf_models/modeling_utils_TP/TP.py b/lm_engine/hf_models/modeling_utils_TP/TP.py deleted file mode 100644 index bc2df335f..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/TP.py +++ /dev/null @@ -1,5 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ..modeling_utils.TP import get_module_placements, tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils_TP/__init__.py b/lm_engine/hf_models/modeling_utils_TP/__init__.py deleted file mode 100644 index 4560064d9..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from .dtensor_module import DTensorModule -from .embedding import Embedding_TP, get_tensor_parallel_vocab_info -from .linear import ColumnParallelLinear, RowParallelLinear -from .lm_head import LMHead_TP -from .mlp_blocks import MLP_TP, MoE_TP, get_mlp_block_TP -from .normalization import get_normalization_function_TP -from .sequence_mixer_blocks import Attention_TP, get_sequence_mixer_TP -from .TP import get_module_placements, tensor_parallel_split_safetensor_slice diff --git a/lm_engine/hf_models/modeling_utils_TP/embedding.py
b/lm_engine/hf_models/modeling_utils_TP/embedding.py deleted file mode 100644 index ba1ec9bac..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/embedding.py +++ /dev/null @@ -1,68 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -from torch.distributed._tensor.placement_types import Replicate, Shard - -from ...dtensors import dtensor_to_tensor, tensor_to_dtensor -from ...utils import ProcessGroupManager, divide_if_divisible -from ..modeling_utils import ParameterizedEmbedding -from .dtensor_module import DTensorModule -from .TP import get_module_placements - - -class Embedding_TP(ParameterizedEmbedding, DTensorModule): - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - std: float | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> Embedding_TP: - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.use_padding_free_transformer = use_padding_free_transformer - self.sequence_parallel = sequence_parallel - - self.vocab_start_index, self.vocab_end_index, num_embeddings_per_tp_rank = get_tensor_parallel_vocab_info( - num_embeddings - ) - - super().__init__(num_embeddings_per_tp_rank, embedding_dim, std=std) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) - ) - ) - - self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Replicate()) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement) - return input - - -def get_tensor_parallel_vocab_info(vocab_size: int, make_vocab_size_divisible_by: int = 64) -> tuple[int, int, int]: - tp_rank = ProcessGroupManager.get_tensor_parallel_rank() - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - divide_if_divisible(make_vocab_size_divisible_by, tp_world_size, "") - - vocab_size_per_tensor_parallel_rank = ( - make_vocab_size_divisible_by * math.ceil(vocab_size / make_vocab_size_divisible_by) - ) // tp_world_size - - vocab_start_index = tp_rank * vocab_size_per_tensor_parallel_rank - vocab_end_index = min((tp_rank + 1) * vocab_size_per_tensor_parallel_rank, vocab_size) - - return vocab_start_index, vocab_end_index, vocab_size_per_tensor_parallel_rank diff --git a/lm_engine/hf_models/modeling_utils_TP/linear.py b/lm_engine/hf_models/modeling_utils_TP/linear.py deleted file mode 100644 index 40be06ee2..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/linear.py +++ /dev/null @@ -1,137 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch -import torch.nn as nn -from torch.distributed._tensor.placement_types import Replicate, Shard - -from ...dtensors import dtensor_to_tensor, tensor_to_dtensor, use_async_tensor_parallel -from ...utils import ProcessGroupManager, divide_if_divisible -from ..modeling_utils import ParameterizedLinear -from .dtensor_module import DTensorModule -from .TP import get_module_placements - - -class 
ColumnParallelLinear(ParameterizedLinear, DTensorModule): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> ColumnParallelLinear: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.out_features_per_device = divide_if_divisible( - out_features, - tp_world_size, - f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - super().__init__( - in_features=in_features, - out_features=self.out_features_per_device, - bias=bias, - device=device, - dtype=dtype, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) - ) - ) - if bias: - self.bias = nn.Parameter( - tensor_to_dtensor( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(0) - ) - ) - - self.input_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - if use_async_tensor_parallel(): - self.compile() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor( - input, device_mesh=self.tp_mesh, current_placement=self.input_placement, desired_placement=Replicate() - ) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=Shard(-1)) - return input - - def extra_repr(self) -> str: - return "in_features={}, out_features_per_device={}, bias={}".format( - self.in_features, self.out_features_per_device, self.bias is not None - ) - - -class RowParallelLinear(ParameterizedLinear, DTensorModule): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> RowParallelLinear: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.in_features_per_device = divide_if_divisible( - in_features, - tp_world_size, - f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - super().__init__( - in_features=self.in_features_per_device, - out_features=out_features, - bias=bias, - device=device, - dtype=dtype, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) - ) - ) - if bias: - self.bias = nn.Parameter( - tensor_to_dtensor( - self.bias, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - current_placement=Replicate(), - ) - ) - - self.output_placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - if use_async_tensor_parallel(): - self.compile() - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.output_placement) - return input - - def extra_repr(self) -> str: - return "in_features_per_device={}, 
out_features={}, bias={}".format( - self.in_features_per_device, self.out_features, self.bias is not None - ) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py deleted file mode 100644 index b74ba8585..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ...config import CommonConfig -from .mlp import MLP_TP -from .moe import MoE_TP - - -def get_mlp_block_TP( - config: CommonConfig, use_padding_free_transformer: bool, sequence_parallel: bool, layer_idx: int -) -> MLP_TP | MoE_TP: - block = config.mlp_blocks[layer_idx] - mlp_type = block.mlp_type - - kwargs = dict( - hidden_size=config.hidden_size, - intermediate_size=block.intermediate_size, - activation_function=block.activation_function, - add_bias=block.add_bias, - dropout=block.dropout, - init_method=config.init_method, - initializer_range=config.initializer_range, - m_width=config.m_width, - num_layers=config.num_layers, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - if mlp_type == "MLP": - mlp = MLP_TP(**kwargs) - elif mlp_type == "MoE": - mlp = MoE_TP( - **kwargs, - shared_intermediate_size=block.shared_intermediate_size, - num_experts=block.num_experts, - num_experts_per_tok=block.num_experts_per_tok, - ) - else: - raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})") - - return mlp diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py deleted file mode 100644 index fe8c8470e..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/mlp.py +++ /dev/null @@ -1,59 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch.nn as nn - -from ...modeling_utils import MLP, Dropout, get_activation_function, is_glu -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..linear import ColumnParallelLinear, RowParallelLinear - - -class MLP_TP(MLP): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - activation_function: str, - add_bias: bool, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> MLP_TP: - nn.Module.__init__(self) - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - self.c_fc = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size if is_glu(activation_function) else intermediate_size, - bias=add_bias, - std=std, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.act = get_activation_function(activation_function) - - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=add_bias, - std=std / math.sqrt(2 * num_layers), - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.dropout = Dropout( - dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) diff --git a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py 
b/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py deleted file mode 100644 index 2af05f021..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/mlp_blocks/moe.py +++ /dev/null @@ -1,306 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed._tensor.placement_types import Partial, Replicate, Shard - -from ....dtensors import dtensor_to_tensor, tensor_to_dtensor -from ....enums import Kernel -from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import ProcessGroupManager, divide_if_divisible, is_xma_available -from ...loss import add_aux_loss -from ...modeling_utils import Dropout, MoE, ParameterizedExperts, ParameterizedLinear, get_activation_function, is_glu -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..dtensor_module import DTensorModule -from ..linear import ColumnParallelLinear, RowParallelLinear - - -if is_xma_available(): - from xma.layers.moe import scattered_experts - - -class ReplicatedLinear_TP(ParameterizedLinear, DTensorModule): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - ) -> ReplicatedLinear_TP: - super().__init__( - in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype, std=std - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - -class ColumnParallelExperts(ParameterizedExperts, DTensorModule): - def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - ) -> ColumnParallelExperts: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.out_features_per_device = divide_if_divisible( - out_features, - tp_world_size, - f"`out_features` ({out_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - super().__init__( - num_experts=num_experts, - in_features=in_features, - out_features=self.out_features_per_device, - add_bias=add_bias, - device=device, - dtype=dtype, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(1) - ) - ) - - def forward( - self, - input: torch.Tensor, - num_experts_per_token: int | None = None, - num_tokens_per_expert: torch.Tensor | None = None, - sorted_expert_idxs: torch.Tensor | None = None, - sorted_scattered_idxs: torch.Tensor | None = None, - expert_offsets: torch.Tensor | None = None, - gates: torch.Tensor | None = None, - grouped_in: bool = False, - grouped_out: bool = False, - ) -> torch.Tensor: - assert is_kernel_allowed(Kernel.scattermoe) - - input = scattered_experts( - inputs=wait_for_ACT(input, wait_in_forward=True, wait_in_backward=False), - expert_weights=dtensor_to_tensor(self.weight).permute(0, 2, 1), - k=num_experts_per_token, - sorted_expert_idxs=sorted_expert_idxs, - sorted_scattered_idxs=sorted_scattered_idxs, - expert_offsets=expert_offsets, - gates=gates, - grouped_in=grouped_in, - grouped_out=grouped_out, - ) - - 
input = wait_for_ACT(input, wait_in_forward=False, wait_in_backward=True) - - return input - - -class RowParallelExperts(ColumnParallelExperts): - def __init__( - self, - num_experts: int, - in_features: int, - out_features: int, - add_bias: bool = False, - device: torch.device | None = None, - dtype: torch.dtype | None = None, - std: float | None = None, - ) -> RowParallelExperts: - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.in_features_per_device = divide_if_divisible( - in_features, - tp_world_size, - f"`in_features` ({in_features}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - ParameterizedExperts.__init__( - self, - num_experts=num_experts, - in_features=self.in_features_per_device, - out_features=out_features, - add_bias=add_bias, - device=device, - dtype=dtype, - std=std, - ) - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Shard(-1) - ) - ) - - -class SharedExpertsColumnParallelLinear(ColumnParallelLinear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) - - -class SharedExpertsRowParallelLinear(RowParallelLinear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - return F.linear(input, dtensor_to_tensor(self.weight), dtensor_to_tensor(self.bias)) - - -class MoE_TP(MoE, DTensorModule): - def __init__( - self, - hidden_size: int, - intermediate_size: int, - shared_intermediate_size: int, - num_experts: int, - num_experts_per_tok: int, - activation_function: str, - add_bias: bool, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - use_padding_free_transformer: bool, - sequence_parallel: bool = False, - ) -> MoE_TP: - nn.Module.__init__(self) - - self.num_experts = num_experts - self.top_k = num_experts_per_tok - self.use_padding_free_transformer = use_padding_free_transformer - - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.shared_intermediate_size = shared_intermediate_size - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.gate = ReplicatedLinear_TP( - in_features=self.hidden_size, - out_features=num_experts, - bias=False, - std=std, - ) - - self.c_fc = ColumnParallelExperts( - num_experts=num_experts, - in_features=self.hidden_size, - out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, - add_bias=add_bias, - std=std, - ) - if self.shared_intermediate_size is not None: - self.c_fc_shared = SharedExpertsColumnParallelLinear( - in_features=self.hidden_size, - out_features=( - 2 * self.shared_intermediate_size if is_glu(activation_function) else self.shared_intermediate_size - ), - bias=add_bias, - std=std, - ) - - self.act = get_activation_function(activation_function) - - std /= math.sqrt(2 * num_layers) - - self.c_proj = RowParallelExperts( - num_experts=num_experts, - in_features=self.intermediate_size, - out_features=self.hidden_size, - add_bias=add_bias, - std=std, - ) - if self.shared_intermediate_size is not None: - self.c_proj_shared = SharedExpertsRowParallelLinear( - in_features=self.shared_intermediate_size, - out_features=self.hidden_size, - bias=add_bias, - std=std, - ) - - self.dropout = Dropout(dropout) - self.placement = Shard(0) if sequence_parallel else Replicate() - - 
self.is_hopper_or_newer_gpu = torch.cuda.is_available() and torch.cuda.get_device_capability( - torch.cuda.current_device() - ) >= (9, 0) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - assert is_kernel_allowed(Kernel.scattermoe) - - if not self.use_padding_free_transformer: - batch_size, sequence_length, _ = hidden_states.shape - - hidden_states = hidden_states.view(-1, self.hidden_size) - - hidden_states = tensor_to_dtensor(hidden_states, device_mesh=self.tp_mesh, current_placement=self.placement) - - router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) - - hidden_states = dtensor_to_tensor( - hidden_states, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() - ) - - moe_output, expert_frequency = self._compute_experts(hidden_states, router_weights, selected_experts) - - if self.shared_intermediate_size is None: - hidden_states = moe_output - else: - hidden_states = moe_output + self._compute_shared_experts(hidden_states) - - del moe_output - - hidden_states = tensor_to_dtensor(hidden_states, device_mesh=self.tp_mesh, current_placement=Partial()) - hidden_states = dtensor_to_tensor( - hidden_states, device_mesh=self.tp_mesh, desired_placement=self.placement, grad_placement=self.placement - ) - - if not self.use_padding_free_transformer: - hidden_states = hidden_states.reshape(batch_size, sequence_length, self.hidden_size) - - hidden_states = self.dropout(hidden_states) - - aux_loss = ( - self._compute_switch_loss( - logits=router_logits, probs=torch.softmax(router_logits, dim=-1), expert_frequency=expert_frequency - ) - if self.training - else 0 - ) - - add_aux_loss(aux_loss) - - return hidden_states - - def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]: - # hidden_states -> (total_q, hidden_size) - router_logits = self.gate(hidden_states) - router_logits = dtensor_to_tensor( - router_logits, device_mesh=self.tp_mesh, desired_placement=Replicate(), grad_placement=Partial() - ) - # router_logits -> (total_q, num_experts) - - router_weights, selected_experts = self._get_topk(router_logits) - router_weights = F.softmax(router_weights.float(), dim=-1) - - # we cast back to the input dtype - router_weights = router_weights.type_as(hidden_states) - - return router_logits, router_weights, selected_experts diff --git a/lm_engine/hf_models/modeling_utils_TP/normalization.py b/lm_engine/hf_models/modeling_utils_TP/normalization.py deleted file mode 100644 index 737255d85..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/normalization.py +++ /dev/null @@ -1,121 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import torch -import torch.nn as nn -from torch.distributed._tensor.placement_types import Partial, Replicate - -from ...dtensors import dtensor_to_tensor, tensor_to_dtensor -from ...enums import Kernel -from ...kernels import is_kernel_allowed, wait_for_ACT -from ...utils import ProcessGroupManager, is_xma_available -from .dtensor_module import DTensorModule -from .TP import get_module_placements - - -if is_xma_available(): - from xma import rmsnorm - - -class LayerNorm_TP(nn.LayerNorm, DTensorModule): - def __init__( - self, - normalized_shape: int, - eps: float = 1e-6, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> LayerNorm_TP: - super().__init__(normalized_shape, eps=eps) 
- - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, - device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), - current_placement=Replicate(), - ) - ) - self.bias = nn.Parameter( - tensor_to_dtensor( - self.bias, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=self.placement) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.placement) - return input - - -class RMSNorm_TP(nn.RMSNorm, DTensorModule): - def __init__( - self, - normalized_shape: int, - eps: float = 1e-6, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> RMSNorm_TP: - super().__init__(normalized_shape, eps=eps) - - self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - - self.weight = nn.Parameter( - tensor_to_dtensor( - self.weight, device_mesh=ProcessGroupManager.get_tensor_parallel_mesh(), current_placement=Replicate() - ) - ) - - self.sequence_parallel = sequence_parallel - self.placement = get_module_placements(use_padding_free_transformer, sequence_parallel) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - rmsnorm_kernel_allowed = is_kernel_allowed(Kernel.rmsnorm) - rmsnorm_memory_efficient_kernel_allowed = is_kernel_allowed(Kernel.rmsnorm_memory_efficient) - - if rmsnorm_kernel_allowed or rmsnorm_memory_efficient_kernel_allowed: - input = wait_for_ACT(input, wait_in_forward=True, wait_in_backward=False) - input = rmsnorm( - x=input, - weight=dtensor_to_tensor( - self.weight, grad_placement=Partial() if self.sequence_parallel else Replicate() - ), - eps=self.eps, - memory_efficient=rmsnorm_memory_efficient_kernel_allowed, - ) - input = wait_for_ACT(input, wait_in_forward=False, wait_in_backward=True) - else: - input = tensor_to_dtensor(input, device_mesh=self.tp_mesh, current_placement=self.placement) - input = super().forward(input) - input = dtensor_to_tensor(input, device_mesh=self.tp_mesh, desired_placement=self.placement) - - return input - - -_NORMALIZATION_FUNCTIONS = {"layernorm": LayerNorm_TP, "rmsnorm": RMSNorm_TP} - - -def get_normalization_function_TP( - normalization_function: str, - normalized_shape: int, - eps: float = 1e-5, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, -) -> LayerNorm_TP | RMSNorm_TP: - if normalization_function in _NORMALIZATION_FUNCTIONS: - normalization = _NORMALIZATION_FUNCTIONS[normalization_function]( - normalized_shape, - eps=eps, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - else: - raise ValueError(f"unexpected `normalization_function` {normalization_function}") - - return normalization diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py deleted file mode 100644 index f14da7d8d..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from ...config import CommonConfig -from .attention 
import Attention_TP - - -def get_sequence_mixer_TP( - config: CommonConfig, - causal: bool, - use_padding_free_transformer: bool, - layer_idx: int, - sequence_parallel: bool, -) -> Attention_TP: - block = config.sequence_mixer_blocks[layer_idx] - sequence_mixer_type = block.sequence_mixer_type - - sequence_mixer_kwargs = dict( - hidden_size=config.hidden_size, - num_attention_heads=block.num_attention_heads, - num_key_value_heads=block.num_key_value_heads, - attention_multiplier=block.attention_multiplier, - position_embedding_type=config.position_embedding_type, - add_bias=block.add_bias, - softmax_dropout=block.softmax_dropout, - dropout=block.dropout, - init_method=config.init_method, - initializer_range=config.initializer_range, - m_width=config.m_width, - num_layers=config.num_layers, - causal=causal, - layer_idx=layer_idx, - sequence_parallel=sequence_parallel, - ) - - if sequence_mixer_type == "softmax_attention": - return Attention_TP(**sequence_mixer_kwargs, use_padding_free_transformer=use_padding_free_transformer) diff --git a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py deleted file mode 100644 index 8e7851b70..000000000 --- a/lm_engine/hf_models/modeling_utils_TP/sequence_mixer_blocks/attention.py +++ /dev/null @@ -1,221 +0,0 @@ -# ************************************************** -# Copyright (c) 2025, Mayank Mishra -# ************************************************** - -from __future__ import annotations - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ....enums import Kernel -from ....kernels import is_kernel_allowed, wait_for_ACT -from ....utils import ProcessGroupManager, divide_if_divisible -from ...cache import GenerationCache -from ...modeling_utils import Attention, Dropout, apply_rotary_pos_emb, flash_attention -from ...modeling_utils.mlp_blocks.mlp import _get_std_for_linear -from ..linear import ColumnParallelLinear, RowParallelLinear - - -class Attention_TP(Attention): - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - num_key_value_heads: int, - attention_multiplier: float, - position_embedding_type: str, - add_bias: bool, - softmax_dropout: float, - dropout: float, - init_method: str, - initializer_range: float, - m_width: float, - num_layers: int, - causal: bool, - layer_idx: int | None = None, - use_padding_free_transformer: bool = False, - sequence_parallel: bool = False, - ) -> Attention_TP: - nn.Module.__init__(self) - - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - self.causal = causal - self.global_hidden_size = hidden_size - self.global_num_heads = num_attention_heads - self.global_num_key_value_heads = num_key_value_heads - self.add_bias = add_bias - self.use_padding_free_transformer = use_padding_free_transformer - self.sequence_parallel = sequence_parallel - - divide_if_divisible( - self.global_hidden_size, - self.global_num_heads, - f"`embed_dim` ({self.global_hidden_size}) must be divisible by `num_heads` ({self.global_num_heads})", - ) - - self.hidden_size = divide_if_divisible( - self.global_hidden_size, tp_world_size, "hidden_size should be divisible by TP world size" - ) - - self.num_heads = divide_if_divisible( - self.global_num_heads, tp_world_size, "num_heads must be divisible by TP world size" - ) - - self.head_dim = divide_if_divisible(self.hidden_size, self.num_heads, "") - self.position_embedding_type = position_embedding_type - 
self.attention_multiplier = attention_multiplier - self.layer_idx = layer_idx - - std = _get_std_for_linear(initializer_range, init_method, m_width) - - divide_if_divisible( - self.global_num_heads, - self.global_num_key_value_heads, - f"`num_heads` ({self.global_num_heads}) should be a multiple of `num_key_value_heads` ({self.global_num_key_value_heads})", - ) - - self.num_key_value_heads = divide_if_divisible( - self.global_num_key_value_heads, - tp_world_size, - f"`num_key_value_heads` ({self.global_num_key_value_heads}) must be divisible by `tensor_parallel_world_size` ({tp_world_size})", - ) - - self.c_attn = ColumnParallelLinear( - self.global_hidden_size, - self.global_hidden_size + 2 * self.global_num_key_value_heads * self.head_dim, - bias=self.add_bias, - std=std, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.c_proj = RowParallelLinear( - self.global_hidden_size, - self.global_hidden_size, - bias=self.add_bias, - std=std / math.sqrt(2 * num_layers), - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.softmax_dropout_p = softmax_dropout - - self.softmax_dropout = Dropout( - softmax_dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - self.dropout = Dropout( - dropout, - use_padding_free_transformer=use_padding_free_transformer, - sequence_parallel=sequence_parallel, - ) - - def forward( - self, - hidden_states: torch.Tensor, - past_key_values: GenerationCache | None = None, - attention_mask: torch.Tensor | None = None, - rope_cos_sin: torch.Tensor | None = None, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - ) -> torch.Tensor: - use_flash_attention_2 = is_kernel_allowed(Kernel.flash_attention_2) - use_flash_attention_3 = is_kernel_allowed(Kernel.flash_attention_3) - - tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size() - - if self.use_padding_free_transformer: - assert use_flash_attention_2 or use_flash_attention_3 - assert past_key_values is None - - total_q = hidden_states.shape[0] * (tp_world_size if self.sequence_parallel else 1) - input_shape = (total_q, self.num_key_value_heads, -1) - output_shape = (total_q, -1, self.head_dim) - else: - batch_size, query_length = hidden_states.shape[:-1] - query_length *= tp_world_size if self.sequence_parallel else 1 - - input_shape = (batch_size, query_length, self.num_key_value_heads, -1) - output_shape = (batch_size, query_length, -1, self.head_dim) - - hidden_states = self.c_attn(hidden_states) - - hidden_states = hidden_states.view(*input_shape) - - query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 - ) - - query = query.reshape(*output_shape) - - if not self.use_padding_free_transformer: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - if self.position_embedding_type == "rope": - query = apply_rotary_pos_emb(query, rope_cos_sin) - key = apply_rotary_pos_emb(key, rope_cos_sin) - - if past_key_values is not None: - key, value = past_key_values.update(key_states=key, value_states=value, layer_idx=self.layer_idx) - - if use_flash_attention_2 or use_flash_attention_3: - if self.use_padding_free_transformer: - output_shape = (-1, self.hidden_size) - else: - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - output_shape = 
(batch_size, query_length, -1) - - query = wait_for_ACT(query, wait_in_forward=True, wait_in_backward=False) - key = wait_for_ACT(key, wait_in_forward=True, wait_in_backward=False) - value = wait_for_ACT(value, wait_in_forward=True, wait_in_backward=False) - - hidden_states = flash_attention( - q=query, - k=key, - v=value, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - attention_mask=attention_mask, - use_padding_free_transformer=self.use_padding_free_transformer, - causal=self.causal, - dropout=self.softmax_dropout_p if self.training else 0, - softmax_scale=self.attention_multiplier, - ) - - del query, key, value - - hidden_states = wait_for_ACT(hidden_states, wait_in_forward=False, wait_in_backward=True) - hidden_states = hidden_states.view(*output_shape) - else: - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask, - dropout_p=self.softmax_dropout_p if self.training else 0, - is_causal=self.causal if attention_mask is None else False, - scale=self.attention_multiplier, - enable_gqa=True, - ) - - del query, key, value - - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.transpose(1, 2) - hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim) - - hidden_states = self.c_proj(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states diff --git a/lm_engine/hf_models/models/gpt_base_TP/base.py b/lm_engine/hf_models/models/gpt_base_TP/base.py index e5236b0ac..33fb7cd48 100644 --- a/lm_engine/hf_models/models/gpt_base_TP/base.py +++ b/lm_engine/hf_models/models/gpt_base_TP/base.py @@ -2,11 +2,11 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from ...mixins import BaseModelMixin_TP, PreTrainedModelMixin_TP +from ...mixins import BaseModelMixin_TP, PreTrainedModelMixin from ..gpt_base import GPTBaseConfig -class GPTBasePreTrainedModel_TP(PreTrainedModelMixin_TP): +class GPTBasePreTrainedModel_TP(PreTrainedModelMixin): config_class = GPTBaseConfig diff --git a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py index 4d1f592ae..6e4364da3 100644 --- a/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py +++ b/lm_engine/hf_models/models/gpt_base_TP/weights/shard.py @@ -5,8 +5,7 @@ import torch from .....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible -from ....modeling_utils import is_glu -from ....modeling_utils_TP import get_tensor_parallel_vocab_info, tensor_parallel_split_safetensor_slice +from ....modeling_utils import get_tensor_parallel_vocab_info, is_glu, tensor_parallel_split_safetensor_slice from ...gpt_base import GPTBaseConfig @@ -79,7 +78,6 @@ def get_gpt_base_model_parallel_state_dict( state_dict.update( _get_moe( activation_function=block.activation_function, - add_bias=block.add_bias, safetensors_weights_manager=safetensors_weights_manager, prefix=prefix + "mlp_block.", column_parallel_shard_dim=1, @@ -164,7 +162,6 @@ def _get_attention( def _get_moe( activation_function: str, - add_bias: bool, safetensors_weights_manager: SafeTensorsWeightsManager, prefix: str, column_parallel_shard_dim: int, @@ -172,12 +169,10 @@ def _get_moe( ) -> None: state_dict = {prefix + "gate.weight": safetensors_weights_manager.get_tensor(prefix + "gate.weight")} - assert not add_bias - state_dict.update( _get_mlp( activation_function=activation_function, - add_bias=add_bias, + add_bias=False, 
safetensors_weights_manager=safetensors_weights_manager, prefix=prefix, column_parallel_shard_dim=column_parallel_shard_dim, diff --git a/lm_engine/hf_models/models/gpt_crosslayer/layer.py b/lm_engine/hf_models/models/gpt_crosslayer/layer.py index 3413911ff..cfdfa6acc 100644 --- a/lm_engine/hf_models/models/gpt_crosslayer/layer.py +++ b/lm_engine/hf_models/models/gpt_crosslayer/layer.py @@ -53,7 +53,10 @@ def __init__( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, ) def forward( diff --git a/lm_engine/hf_models/models/palm/layer.py b/lm_engine/hf_models/models/palm/layer.py index 8bafb9098..c54b686f3 100644 --- a/lm_engine/hf_models/models/palm/layer.py +++ b/lm_engine/hf_models/models/palm/layer.py @@ -23,9 +23,18 @@ def __init__( self.ln = get_normalization_function( config.normalization_function, config.hidden_size, eps=config.layer_norm_epsilon ) - self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.sequence_mixer = get_sequence_mixer( + config, + True, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, + ) self.mlp_block = get_mlp_block( - config, use_padding_free_transformer=use_padding_free_transformer, layer_idx=layer_idx + config, + use_padding_free_transformer=use_padding_free_transformer, + sequence_parallel=False, + layer_idx=layer_idx, ) def forward( diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index 4ab6e6834..0a816f0a0 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -5,6 +5,11 @@ import torch.nn as nn +_INIT_MARKER = "_is_initialized" +_METADATA_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate"] +_ALL_MARKERS = _METADATA_MARKERS + [_INIT_MARKER] + + def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: if parameter is not None: parameter._no_weight_decay = True @@ -35,4 +40,40 @@ def is_parameter_with_mup_learning_rate(parameter: nn.Parameter | None) -> bool: def is_parameter_initialized(parameter: nn.Parameter | None) -> bool: - return getattr(parameter, "_is_initialized", False) + return getattr(parameter, _INIT_MARKER, False) + + +def get_parameter_marker_maps(model_container: list[nn.Module], extra_markers: list[str] = []) -> list[dict]: + if isinstance(model_container, nn.Module): + model_container = [model_container] + + marker_maps = [] + for model in model_container: + marker_maps.append({}) + for param_name, param in model.named_parameters(): + marker_maps[-1][param_name] = {} + for marker in _METADATA_MARKERS + extra_markers: + marker_maps[-1][param_name][marker] = getattr(param, marker, False) + + return marker_maps + + +def set_parameter_marker_maps( + model_container: list[nn.Module], + marker_maps: list[dict], + replacement_patterns: list[tuple[str]] = [], + _trim_prefix: str | None = None, +) -> None: + if isinstance(model_container, nn.Module): + model_container = [model_container] + + for model, _marker_map in zip(model_container, marker_maps): + for param_name, parameter in model.named_parameters(): + for pattern, replacement in replacement_patterns: + param_name = param_name.replace(pattern, replacement) + + if _trim_prefix is not None: + param_name = 
param_name.removeprefix(_trim_prefix) + + for marker, value in _marker_map[param_name].items(): + setattr(parameter, marker, value) diff --git a/lm_engine/model_wrapper/base.py b/lm_engine/model_wrapper/base.py index 2fbc5841c..3284b8271 100644 --- a/lm_engine/model_wrapper/base.py +++ b/lm_engine/model_wrapper/base.py @@ -205,7 +205,11 @@ def _setup_model(self) -> None: def calculate_num_parameters(self) -> tuple[int, int]: model_kwargs = self._get_model_kwargs() - with torch.device("meta"): + with ( + torch.device("meta"), + ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ): if self.model_name is not None: model_kwargs["config"] = AutoConfig.from_pretrained(model_kwargs.pop("pretrained_model_name_or_path")) diff --git a/lm_engine/train_utils.py b/lm_engine/train_utils.py index 4016732e0..201799d0d 100644 --- a/lm_engine/train_utils.py +++ b/lm_engine/train_utils.py @@ -138,20 +138,6 @@ def get_model_tflops( b * s, h, h, gradient_checkpointing=gradient_checkpointing_enabled ) - sequence_mixer_flops += _get_attention_flops(b, s, h) - elif sequence_mixer_type == "multihead_latent_attention": - # QKV down and up projection FLOPs - sequence_mixer_flops = 2 * _get_linear_flops( - b * s, - h, - block.query_compression_size + 2 * block.key_value_compression_size, - gradient_checkpointing=gradient_checkpointing_enabled, - ) - # output projection FLOPs - sequence_mixer_flops += _get_linear_flops( - b * s, h, h, gradient_checkpointing=gradient_checkpointing_enabled - ) - sequence_mixer_flops += _get_attention_flops(b, s, h) elif sequence_mixer_type == "mamba2": # NOTE taken from NexaAI's fork (might be incorrect) diff --git a/tests/hf_models/multi_gpu/dcp/dcp.py b/tests/hf_models/multi_gpu/dcp/dcp.py index 81787a9dd..fcfd7d6c1 100644 --- a/tests/hf_models/multi_gpu/dcp/dcp.py +++ b/tests/hf_models/multi_gpu/dcp/dcp.py @@ -94,9 +94,14 @@ Communication.barrier() -_, _, consolidated_state_dict = load_checkpoint_and_unshard(unshard_config) - if global_rank == 0: + with ( + ProcessGroupManager.set_dummy_tensor_parallel_world_size(1), + ProcessGroupManager.set_dummy_tensor_parallel_rank(0), + ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1), + ProcessGroupManager.set_dummy_pipeline_parallel_rank(0), + ): + _, _, consolidated_state_dict = load_checkpoint_and_unshard(unshard_config) original_state_dict = model_container[0].state_dict() assert consolidated_state_dict.keys() == original_state_dict.keys() diff --git a/tests/hf_models/multi_gpu/dcp/train.yml b/tests/hf_models/multi_gpu/dcp/train.yml index 30acd2e83..a6253e911 100644 --- a/tests/hf_models/multi_gpu/dcp/train.yml +++ b/tests/hf_models/multi_gpu/dcp/train.yml @@ -179,6 +179,7 @@ model_args: - mlp_type: MLP activation_function: swiglu add_bias: false + efficient_initialization: true tuning_args: tuning_method: pretraining diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 2698cc8a7..ff1ef7dfd 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -58,7 +58,7 @@ ], mlp_blocks=[ {"mlp_type": "MLP", "add_bias": False}, - {"mlp_type": "MoE", "add_bias": False}, + {"mlp_type": "MoE"}, ], ) @@ -70,7 +70,7 @@ with enable_kernels(kernels): if torch.distributed.get_rank() == 0: - with torch.device("meta"): + with torch.device("meta"), 
ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): model = TestCommons.from_config(None, config) model = model.to_empty(device=torch.cuda.current_device()) diff --git a/tests/hf_models/multi_gpu/unsharding/unsharding.py b/tests/hf_models/multi_gpu/unsharding/unsharding.py index 2e927406f..18244ceed 100644 --- a/tests/hf_models/multi_gpu/unsharding/unsharding.py +++ b/tests/hf_models/multi_gpu/unsharding/unsharding.py @@ -60,7 +60,9 @@ if is_tp_first_rank: - model = TestCommons.from_config(None, config) + with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): + model = TestCommons.from_config(None, config) + model.save_pretrained(args.tmp_path, safe_serialization=True) Communication.barrier() diff --git a/tests/hf_models/single_gpu/model_conversion_test.py b/tests/hf_models/single_gpu/model_conversion_test.py index 6943cf312..9793cb7e4 100644 --- a/tests/hf_models/single_gpu/model_conversion_test.py +++ b/tests/hf_models/single_gpu/model_conversion_test.py @@ -104,33 +104,3 @@ def test_granitemoehybrid_model_conversion(self, device: torch.device, is_moe: b compare_loss=False, logits_atol_float32=2.5e-5, ) - - @parameterized.expand(TestCommons.make_args_matrix(TestCommons.get_all_devices(), [False, True], [False, True])) - def test_qwen2_moe_model_conversion(self, device: torch.device, qkv_bias: bool, use_sliding_window: bool) -> None: - lm_engine_config = self.get_moe_test_config( - "rope", - qkv_bias=qkv_bias, - shared_n_inner=36, - activation_function="swiglu", - normalization_function="rmsnorm", - shared_expert_gating=True, - normalized_topk=False, - ) - - for layer_idx in [3, 6]: - mlp_block = lm_engine_config.mlp_blocks[layer_idx] - lm_engine_config.mlp_blocks[layer_idx] = _MLPArgs( - intermediate_size=mlp_block.intermediate_size, activation_function=mlp_block.activation_function - ) - - if use_sliding_window: - for layer_idx in range(3, lm_engine_config.num_layers): - lm_engine_config.sequence_mixer_blocks[layer_idx].sliding_window = 4096 - - self.model_conversion_test( - lm_engine_config=lm_engine_config, - model_type="qwen2_moe", - device=device, - exact_match=False, - weight_test_only=use_sliding_window, - ) diff --git a/tests/hf_models/single_gpu/weight_test.py b/tests/hf_models/single_gpu/weight_test.py index cb488b9cf..78bcb5198 100644 --- a/tests/hf_models/single_gpu/weight_test.py +++ b/tests/hf_models/single_gpu/weight_test.py @@ -19,7 +19,7 @@ def test_query_key_value_weight_loading_and_saving(self) -> None: config = self.get_dense_test_config("learned_absolute") layer_idx = 1 - attention = get_sequence_mixer(config, True, False, layer_idx) + attention = get_sequence_mixer(config, True, False, False, layer_idx) num_key_value_heads = config.sequence_mixer_blocks[layer_idx].num_key_value_heads state_dict = attention.state_dict() diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index 914760684..7c16609b9 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -94,7 +94,6 @@ def get_moe_test_config( num_attention_heads: int = 4, shared_expert_gating: bool = False, normalized_topk: bool = True, - qkv_bias: bool = None, ) -> GPTBaseConfig: num_key_value_heads = 2 @@ -106,9 +105,6 @@ def get_moe_test_config( "attention_multiplier": attention_multiplier, } - if qkv_bias is not None: - sequence_mixer["qkv_bias"] = qkv_bias - return GPTBaseConfig( vocab_size=2048, max_position_embeddings=1024, @@ -131,7 +127,6 @@ def get_moe_test_config( "num_experts_per_tok": num_experts_per_tok, "normalized_topk": 
normalized_topk, "activation_function": activation_function, - "add_bias": add_bias, "shared_intermediate_size": None if shared_n_inner is None else shared_n_inner, "shared_expert_gating": shared_expert_gating, } diff --git a/tests/training/params_group/params_group_test.py b/tests/training/params_group/params_group_test.py index 6a3dcc59a..414e15e47 100644 --- a/tests/training/params_group/params_group_test.py +++ b/tests/training/params_group/params_group_test.py @@ -8,12 +8,9 @@ import torch from parameterized import parameterized -from lm_engine.distributed import ( - _get_parameter_marker_maps, - _set_parameter_marker_maps, - wrap_model_container_for_distributed_training, -) +from lm_engine.distributed import wrap_model_container_for_distributed_training from lm_engine.enums import ParamsGroupMethod +from lm_engine.hf_models import get_parameter_marker_maps, set_parameter_marker_maps from lm_engine.model_wrapper import get_model_container from lm_engine.optimization.params_group import get_param_groups_list from lm_engine.utils import ProcessGroupManager @@ -55,9 +52,9 @@ def test_mup_group( if use_fsdp: model_container, _ = wrap_model_container_for_distributed_training(args, model_container) elif use_torch_compile: - marker_maps = _get_parameter_marker_maps(model_container) + marker_maps = get_parameter_marker_maps(model_container) model_container = [torch.compile(model) for model in model_container] - _set_parameter_marker_maps(model_container, marker_maps) + set_parameter_marker_maps(model_container, marker_maps, _trim_prefix="_orig_mod.") params_groups = get_param_groups_list(model_container, args.optimizer_args.class_args, params_group_method)[0]
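
A minimal sketch of the marker round trip that the updated params_group_test.py performs around torch.compile, assuming a single nn.Module (the helpers also accept a list of modules); the nn.Linear here is only a stand-in for a real model container:

    import torch
    import torch.nn as nn

    from lm_engine.hf_models import (
        get_parameter_marker_maps,
        mark_parameter_as_no_weight_decay,
        set_parameter_marker_maps,
    )

    # stand-in model; any nn.Module with named parameters illustrates the flow
    model = nn.Linear(16, 16)
    mark_parameter_as_no_weight_decay(model.bias)

    # snapshot the metadata markers, keyed by parameter name, before compiling
    marker_maps = get_parameter_marker_maps(model)

    compiled = torch.compile(model)

    # torch.compile exposes parameters under an "_orig_mod." prefix, so trim it
    # before looking up and re-applying the saved markers
    set_parameter_marker_maps(compiled, marker_maps, _trim_prefix="_orig_mod.")

    for name, parameter in compiled.named_parameters():
        print(name, getattr(parameter, "_no_weight_decay", False))

For wrappers that rewrite parameter names in the middle of the qualified name rather than only at the front, the same restore step takes replacement_patterns, a list of (pattern, replacement) string pairs applied to each parameter name before the marker lookup.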
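
Several call sites above wrap model construction in ProcessGroupManager's dummy context managers so construction behaves as if the tensor- and pipeline-parallel world sizes were 1. A small sketch of that pattern, using the setters as standalone context managers the way the updated tests do, with a plain nn.Linear standing in for a real model build:

    import torch
    import torch.nn as nn

    from lm_engine.utils import ProcessGroupManager

    def build_model() -> nn.Module:
        # stand-in for constructing a real model from its config
        return nn.Linear(4096, 4096, bias=False)

    with (
        torch.device("meta"),  # parameters get shape and dtype but no storage
        ProcessGroupManager.set_dummy_tensor_parallel_world_size(1),
        ProcessGroupManager.set_dummy_pipeline_parallel_world_size(1),
    ):
        model = build_model()

    # numel() works on meta tensors, so counting parameters allocates nothing
    num_parameters = sum(parameter.numel() for parameter in model.parameters())
    print(num_parameters)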