diff --git a/lm_engine/distributed.py b/lm_engine/distributed.py index 952d89fb7..8fbd19d9c 100644 --- a/lm_engine/distributed.py +++ b/lm_engine/distributed.py @@ -30,8 +30,7 @@ from .containers import ModelContainer from .enums import Kernel from .gradient_checkpointing import apply_gradient_checkpointing -from .hf_models import CausalLMOutputWithPast -from .hf_models.parameter import _ALL_MARKERS +from .hf_models import CausalLMOutputWithPast, is_parameter_initialized from .kernels import is_kernel_allowed from .utils import ( Accelerator, @@ -120,13 +119,13 @@ def _get_fsdp_mixed_precision( return mixed_precision -def _get_parameter_marker_maps(model_container: ModelContainer) -> list[dict]: +def _get_parameter_marker_maps(model_container: ModelContainer, extra_markers: list[str] = []) -> list[dict]: marker_maps = [] for model in model_container: marker_maps.append({}) for param_name, param in model.named_parameters(): marker_maps[-1][param_name] = {} - for marker in _ALL_MARKERS: + for marker in ["_no_weight_decay", "_has_mup_learning_rate"] + extra_markers: marker_maps[-1][param_name][marker] = getattr(param, marker, False) return marker_maps @@ -222,7 +221,18 @@ def wrap_model_container_for_distributed_training( **args.distributed_args.gradient_checkpointing_args, ) - marker_maps = _get_parameter_marker_maps(model_container) + if efficient_initialization: + for model in model_container: + for param_name, parameter in model.named_parameters(): + parameter._is_initialized = False + + for param_name, parameter in model.named_buffers(): + parameter._is_initialized = False + + marker_maps = _get_parameter_marker_maps(model_container) + else: + marker_maps = _get_parameter_marker_maps(model_container, extra_markers=["_is_initialized"]) + accelerator = Accelerator.get_accelerator() if accelerator == Accelerator.tpu: @@ -382,6 +392,13 @@ def _sharding_function(parameter: nn.Parameter) -> Shard: pipeline_stages = [] pipeline_schedule = None + for model in model_container: + for param_name, parameter in model.named_parameters(): + assert is_parameter_initialized(parameter), f"{param_name} is not initialized" + + for param_name, parameter in model.named_buffers(): + assert is_parameter_initialized(parameter), f"{param_name} is not initialized" + if num_pipeline_stages > 1: micro_batch_size = args.training_parameters.micro_batch_size sequence_length = args.datasets[0].class_args.get("sequence_length") diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index 71766fc94..5b22610dc 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -26,7 +26,43 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "softmax_attention" -class _Mamba2Args(BaseArgs): +class _MultiHeadLatentAttentionArgs(BaseArgs): + sequence_mixer_type: str = "multihead_latent_attention" + num_attention_heads: int | None = None + softmax_dropout: float = 0 + dropout: float = 0 + add_bias: bool = False + attention_multiplier: float | None = None + sliding_window: int | None = None + query_compression_size: int | None = None + key_value_compression_size: int | None = None + num_attention_heads: int | None = None + head_dim: int | None = None + normalization_function: str = "layernorm" + + def model_post_init(self, __context: Any) -> None: + assert self.sequence_mixer_type == "multihead_latent_attention" + assert self.num_attention_heads is not None + assert self.query_compression_size is not None + assert self.key_value_compression_size is not None + assert self.num_attention_heads is not None + assert self.head_dim is not None + + +class _SoftPlusDecayArgs(BaseArgs): + A_init_min: float = 0 + A_init_max: float = 16 + dt_init_min: float = 0.001 + dt_init_max: float = 0.1 + dt_init_floor: float = 1e-4 + + def model_post_init(self, __context: Any) -> None: + assert self.A_init_min >= 0 + assert self.A_init_min <= self.A_init_max + assert self.dt_init_min <= self.dt_init_max + + +class _Mamba2Args(_SoftPlusDecayArgs): sequence_mixer_type: str = "mamba2" state_size: int = 128 intermediate_size: int @@ -91,7 +127,7 @@ def model_post_init(self, __context: Any) -> None: assert self.sequence_mixer_type == "causal_convolution" -class _GatedDeltaNetArgs(BaseArgs): +class _GatedDeltaNetArgs(_SoftPlusDecayArgs): sequence_mixer_type: str = "gated_deltanet" k_head_dim: int v_head_dim: int diff --git a/lm_engine/hf_models/mixins/dense/base.py b/lm_engine/hf_models/mixins/dense/base.py index a16a4673d..9582ee384 100644 --- a/lm_engine/hf_models/mixins/dense/base.py +++ b/lm_engine/hf_models/mixins/dense/base.py @@ -38,10 +38,6 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> PreTrainedModelMixi self._has_mamba2 = any([block.sequence_mixer_type == "mamba2" for block in self.config.sequence_mixer_blocks]) - def _init_weights(self, module: nn.Module) -> None: - if hasattr(module, "reset_parameters"): - module.reset_parameters() - # FIXME typing def prepare_inputs_for_model( self, @@ -118,9 +114,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.position_embedding_type = config.position_embedding_type self._setup_positional_encoding() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | None = None, diff --git a/lm_engine/hf_models/mixins/dense/main.py b/lm_engine/hf_models/mixins/dense/main.py index 32599cb0d..4ec016803 100644 --- a/lm_engine/hf_models/mixins/dense/main.py +++ b/lm_engine/hf_models/mixins/dense/main.py @@ -37,9 +37,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.m_width = config.m_width - # Initialize weights and apply final processing - self.post_init() - def get_input_embeddings(self) -> ParameterizedEmbedding: return self.transformer.wte diff --git a/lm_engine/hf_models/mixins/dense_TP/base.py b/lm_engine/hf_models/mixins/dense_TP/base.py index e3d53e3c9..700c2b153 100644 --- a/lm_engine/hf_models/mixins/dense_TP/base.py +++ b/lm_engine/hf_models/mixins/dense_TP/base.py @@ -92,9 +92,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.position_embedding_type = config.position_embedding_type self._setup_positional_encoding() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | None = None, diff --git a/lm_engine/hf_models/mixins/dense_TP/main.py b/lm_engine/hf_models/mixins/dense_TP/main.py index 13951bdc5..f614b1cc0 100644 --- a/lm_engine/hf_models/mixins/dense_TP/main.py +++ b/lm_engine/hf_models/mixins/dense_TP/main.py @@ -52,9 +52,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.tp_mesh = ProcessGroupManager.get_tensor_parallel_mesh() - # Initialize weights and apply final processing - self.post_init() - def forward( self, input_ids: torch.Tensor | list[list[int]] | None = None, diff --git a/lm_engine/hf_models/model_conversion/granitemoehybrid.py b/lm_engine/hf_models/model_conversion/granitemoehybrid.py index 6bae682e8..2f4f2f518 100644 --- a/lm_engine/hf_models/model_conversion/granitemoehybrid.py +++ b/lm_engine/hf_models/model_conversion/granitemoehybrid.py @@ -176,11 +176,11 @@ def _import_granitemoehybrid_state_dict( state_dict[f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias"] = ( safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.in_proj.bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.dt_bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.dt_bias" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.dt_bias") ) - state_dict[f"transformer.h.{layer_idx}.sequence_mixer.A_log"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mamba.A_log" + state_dict[f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log"] = ( + safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mamba.A_log") ) state_dict[f"transformer.h.{layer_idx}.sequence_mixer.D"] = safetensors_weights_manager.get_tensor( f"model.layers.{layer_idx}.mamba.D" @@ -404,10 +404,10 @@ def _export_granitemoehybrid_state_dict( f"transformer.h.{layer_idx}.sequence_mixer.in_proj.bias" ) state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.dt_bias" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.dt_bias" ) state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.sequence_mixer.A_log" + f"transformer.h.{layer_idx}.sequence_mixer.decay_gate.A_log" ) state_dict[f"model.layers.{layer_idx}.mamba.D"] = safetensors_weights_manager.get_tensor( f"transformer.h.{layer_idx}.sequence_mixer.D" diff --git a/lm_engine/hf_models/modeling_utils/decay_gate.py b/lm_engine/hf_models/modeling_utils/decay_gate.py new file mode 100644 index 000000000..f9f8de738 --- /dev/null +++ b/lm_engine/hf_models/modeling_utils/decay_gate.py @@ -0,0 +1,120 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Replicate + +from ...dtensors import tensor_to_dtensor +from ..parameter import ( + mark_parameter_as_initialized, + mark_parameter_as_mup_learning_rate, + mark_parameter_as_no_weight_decay, +) +from .linear import ParameterizedLinear + + +class SoftplusDecayGate(nn.Module): + def __init__( + self, + hidden_size: int | None, + output_size: int, + std: float | None, + has_projection: bool = False, + A_init_min: float = 0, + A_init_max: float = 16, + dt_init_min: float = 1e-3, + dt_init_max: float = 0.1, + dt_init_floor: float = 1e-4, + ) -> SoftplusDecayGate: + super().__init__() + + self.output_size = output_size + self.has_projection = has_projection + + if has_projection: + self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std) + mark_parameter_as_mup_learning_rate(self.proj.weight) + else: + assert hidden_size is None + + self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.A_log) + + self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32)) + mark_parameter_as_no_weight_decay(self.dt_bias) + + assert A_init_min >= 0 + assert A_init_max >= A_init_min + + self.A_init_min = A_init_min + self.A_init_max = A_init_max + + assert dt_init_min > 0 + assert dt_init_max >= dt_init_min + + self.dt_init_min = dt_init_min + self.dt_init_max = dt_init_max + self.dt_init_floor = dt_init_floor + + self.reset_parameters() + + def forward( + self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + if self.has_projection: + x = self.proj(x) + + x = x.float() + x = x + self.dt_bias + x = F.softplus(x) + x = -self.A_log.float().exp() * x + + if final_exponential: + x = torch.exp(x) + + x = x.to(output_dtype) + + return x + + @torch.no_grad() + def reset_parameters(self) -> None: + A = torch.empty(self.output_size, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max) + + if isinstance(self.A_log, DTensor): + A = tensor_to_dtensor( + tensor=A, + device_mesh=self.A_log.device_mesh, + current_placement=[Replicate()] * len(self.A_log.placements), + desired_placement=self.A_log.placements, + ) + + self.A_log.copy_(torch.log(A)) + + dt = torch.exp( + torch.rand(self.output_size) * (math.log(self.dt_init_max) - math.log(self.dt_init_min)) + + math.log(self.dt_init_min) + ) + dt = torch.clamp(dt, min=self.dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + + if isinstance(self.dt_bias, DTensor): + inv_dt = tensor_to_dtensor( + tensor=inv_dt, + device_mesh=self.dt_bias.device_mesh, + current_placement=[Replicate()] * len(self.dt_bias.placements), + desired_placement=self.dt_bias.placements, + ) + + self.dt_bias.copy_(inv_dt) + + mark_parameter_as_initialized(self.A_log) + mark_parameter_as_initialized(self.dt_bias) diff --git a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py index 349e92b51..32bd6751c 100644 --- a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py +++ b/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py @@ -80,14 +80,6 @@ def __init__( self.in_features = in_features self.out_features = out_features - self.register_buffer( - "N_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False - ) - - self.register_buffer( - "K_array", torch.empty((num_experts,), device=device, dtype=torch.uint32), persistent=False - ) - self.reset_parameters() mark_parameter_as_no_weight_decay(self.bias) @@ -139,15 +131,9 @@ def reset_parameters(self) -> None: if hasattr(self, "bias") and self.bias is not None: self.bias.zero_() - self.N_array.fill_(self.out_features) - self.K_array.fill_(self.in_features) - mark_parameter_as_initialized(self.weight) mark_parameter_as_initialized(self.bias) - mark_parameter_as_initialized(self.N_array) - mark_parameter_as_initialized(self.K_array) - class MoE(nn.Module): linear_class = ParameterizedExperts diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index 83d16afec..e4d42e081 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -104,6 +104,11 @@ def get_sequence_mixer( init_method=config.init_method, normalization_function=block.normalization_function, m_width=config.m_width, + A_init_min=block.A_init_min, + A_init_max=block.A_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, + dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, layer_idx=layer_idx, ) @@ -115,12 +120,19 @@ def get_sequence_mixer( num_k_heads=block.num_k_heads, num_v_heads=block.num_v_heads, use_gate=block.use_gate, + attention_multiplier=block.attention_multiplier, allow_neg_eigval=block.allow_neg_eigval, - conv_size=block.conv_size, + conv_size=block.kernel_size, layer_idx=layer_idx, norm_eps=config.layer_norm_epsilon, init_method=config.init_method, initializer_range=config.initializer_range, + m_width=config.m_width, + A_init_min=block.A_init_min, + A_init_max=block.A_init_max, + dt_init_min=block.dt_init_min, + dt_init_max=block.dt_init_max, + dt_init_floor=block.dt_init_floor, num_layers=config.num_layers, use_padding_free_transformer=use_padding_free_transformer, ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py index 21affa061..ed44b2808 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/gated_deltanet.py @@ -15,8 +15,8 @@ from ....utils import divide_if_divisible, is_fla_available from ...cache import GenerationCache -from ...parameter import mark_parameter_as_initialized, mark_parameter_as_no_weight_decay from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..normalization import get_normalization_function from .causal_convolution import causal_convolution @@ -44,6 +44,11 @@ def __init__( init_method: str, initializer_range: float, m_width: float | None, + A_init_min: float, + A_init_max: float, + dt_init_min: float, + dt_init_max: float, + dt_init_floor: float, num_layers: int, use_padding_free_transformer: bool, ) -> GatedDeltaNet: @@ -81,11 +86,17 @@ def __init__( hidden_size, 2 * self.num_v_heads + (self.value_dim if use_gate else 0), bias=False, std=std ) - self.A_log = nn.Parameter(torch.empty(self.num_v_heads, dtype=torch.float32)) - mark_parameter_as_no_weight_decay(self.A_log) - - self.dt_bias = nn.Parameter(torch.empty(self.num_v_heads)) - mark_parameter_as_no_weight_decay(self.dt_bias) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_v_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) self.conv_size = conv_size self.qkv_conv1d = ParameterizedConv1d( @@ -106,8 +117,6 @@ def __init__( std /= math.sqrt(m_width) self.o_proj = ParameterizedLinear(self.value_dim, hidden_size, bias=False, std=std) - self.reset_parameters() - def forward( self, hidden_states: torch.Tensor, @@ -154,9 +163,9 @@ def forward( beta = b.sigmoid() if self.allow_neg_eigval: - beta = beta * 2.0 + beta = beta * 2 - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + g = self.decay_gate(x=a, final_exponential=False) if self.use_padding_free_transformer: assert cache_params is None @@ -222,21 +231,3 @@ def forward( o = self.o_proj(o) return o - - def reset_parameters(self) -> None: - A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16) - self.A_log.copy_(torch.log(A)) - - # hard coded for now - dt_min = 0.001 - dt_max = 0.1 - dt_init_floor = 1e-4 - dt = torch.exp(torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)) - dt = torch.clamp(dt, min=dt_init_floor) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - self.dt_bias.copy_(inv_dt) - - mark_parameter_as_initialized(self.A_log) - mark_parameter_as_initialized(self.dt_bias) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py index e4ab4f0cb..8866f93a3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/mamba2.py @@ -9,10 +9,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributed._tensor.api import DTensor -from torch.distributed._tensor.placement_types import Replicate -from ....dtensors import tensor_to_dtensor from ....enums import Kernel from ....kernels import is_kernel_allowed from ....utils import divide_if_divisible, is_causal_conv1d_available, is_mamba_2_ssm_available @@ -24,6 +21,7 @@ ) from ..activations import get_activation_function from ..convolution import ParameterizedConv1d +from ..decay_gate import SoftplusDecayGate from ..linear import ParameterizedLinear from ..mlp_blocks.mlp import _get_std_for_linear from ..normalization import get_normalization_function @@ -115,6 +113,11 @@ def __init__( layer_norm_epsilon: float, initializer_range: float, m_width: float, + A_init_min: float, + A_init_max: float, + dt_init_min: float, + dt_init_max: float, + dt_init_floor: float, init_method: str, normalization_function: str | None, num_layers: int, @@ -162,28 +165,28 @@ def __init__( std=std, ) - # selective projection used to make dt, B and C input dependant - - # time step projection (discretization) - # instantiate once and copy inv_dt in init_weights of PretrainedModel - # Initialize log dt bias - self.dt_bias = nn.Parameter(torch.empty(self.num_heads)) + self.decay_gate = SoftplusDecayGate( + hidden_size=None, + output_size=self.num_heads, + std=None, + has_projection=False, + A_init_min=A_init_min, + A_init_max=A_init_max, + dt_init_min=dt_init_min, + dt_init_max=dt_init_max, + dt_init_floor=dt_init_floor, + ) - # S4D real initialization. These are not discretized! - # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded - self.A_log = nn.Parameter(torch.empty(self.num_heads)) self.norm = get_normalization_function(normalization_function, self.intermediate_size, eps=layer_norm_epsilon) - self.D = nn.Parameter(torch.empty(self.num_heads)) self.out_proj = ParameterizedLinear( self.intermediate_size, self.hidden_size, bias=add_bias, std=std / math.sqrt(2 * num_layers) ) - mark_parameter_as_no_weight_decay(self.dt_bias) - mark_parameter_as_no_weight_decay(self.A_log) + self.D = nn.Parameter(torch.empty(self.num_heads)) mark_parameter_as_no_weight_decay(self.D) - mark_parameter_as_mup_learning_rate(self.A_log) + mark_parameter_as_mup_learning_rate(self.decay_gate.A_log) mark_parameter_as_mup_learning_rate(self.D) mark_parameter_as_mup_learning_rate(self.conv1d.weight) mark_parameter_as_mup_learning_rate(self.in_proj.weight) @@ -271,7 +274,7 @@ def _torch_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) + A = -torch.exp(self.decay_gate.A_log.float()) # hidden_states -> B, S, N, head_dim # A -> num_heads @@ -290,7 +293,7 @@ def _torch_forward( # dt -> (B, 1, N) dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) # dt -> (B, N, head_dim) - dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + dt_bias = self.decay_gate.dt_bias[..., None].expand(self.decay_gate.dt_bias.shape[0], self.head_dim) dt = F.softplus(dt + dt_bias.to(dt.dtype)) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) @@ -305,34 +308,31 @@ def _torch_forward( # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] # NOTE: S = 1 actually here - B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] - # B -> (B, G, 1, ssm_state_size / num_groups) - B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() - # B -> (B, G, N / G, ssm_state_size / num_groups) - B = B.reshape(batch_size, -1, B.shape[-1]) - # B -> (B, N, ssm_state_size / num_groups) - - # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size / num_groups) + B, C = [i.reshape(batch_size, self.n_groups, -1)[..., None, :] for i in (B, C)] + # B, C -> (B, G, 1, ssm_state_size) + B, C = [ + i.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, i.shape[-1]).contiguous() + for i in (B, C) + ] + # B, C -> (B, G, N / G, ssm_state_size) + B, C = [i.reshape(batch_size, -1, i.shape[-1]) for i in (B, C)] + # B, C -> (B, N, ssm_state_size) + + # (B, N, head_dim, 1) * (B, N, 1, ssm_state_size) + # B is same as k and is shared across heads and dt is used to expand it dB = dt[..., None] * B[..., None, :] - # dB -> (B, N, head_dim, ssm_state_size / num_groups) + # dB -> (B, N, head_dim, ssm_state_size) # Discretize x into dB hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) # hidden_states -> (B, N, head_dim) dBx = (dB * hidden_states[..., None]).to(device=cache_device) - # dBx -> (B, N, head_dim, ssm_state_size / num_groups) + # dBx -> (B, N, head_dim, ssm_state_size) # State calculation ssm_state = ssm_state * dA + dBx cache_params.update(ssm_state=ssm_state, num_tokens_added=seq_len, layer_idx=self.layer_idx) - # Subsequent output - # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] - C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] - C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() - C = C.reshape(batch_size, -1, C.shape[-1]) - # [bsz, num_heads, head_dim] - ssm_state = ssm_state.to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] # Reshape ssm_states to merge the first two dimensions ssm_states_reshaped = ssm_state.view( @@ -351,7 +351,7 @@ def _torch_forward( y = y.reshape(batch_size, -1)[:, None, ...] else: # begin ssd naive implementation without einsums - dt = nn.functional.softplus(dt + self.dt_bias) + dt = F.softplus(dt + self.decay_gate.dt_bias) dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() @@ -481,10 +481,10 @@ def _cuda_forward( ) # 3. SSM transformation - A = -torch.exp(self.A_log.float()) # (nheads,) + A = -torch.exp(self.decay_gate.A_log.float()) # (nheads,) A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) dt = dt[:, :, None].expand(-1, -1, self.head_dim) - dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + dt_bias = self.decay_gate.dt_bias[:, None, ...].expand(-1, self.head_dim) D = self.D[:, None, ...].expand(-1, self.head_dim) B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) @@ -502,13 +502,14 @@ def _cuda_forward( dt_softplus=True, ) hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) - hidden_states = self.norm(hidden_states, gate) + hidden_states = hidden_states * F.silu(gate) + hidden_states = self.norm(hidden_states) # 4. Final linear projection out = self.out_proj(hidden_states)[:, None, ...] # Fused calculations or step by step if no initialized cache is found else: - A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + A = -torch.exp(self.decay_gate.A_log.float()) # (num_heads) or (intermediate_size, state_size) dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} # 2-4. Fused kernel for conv1d, SSM, and the final projection @@ -517,7 +518,7 @@ def _cuda_forward( projected_states, self.conv1d.weight.squeeze(1), self.conv1d.bias, - self.dt_bias, + self.decay_gate.dt_bias, A, D=self.D, chunk_size=self.chunk_size, @@ -581,7 +582,7 @@ def _cuda_forward( z=None, seq_idx=None, return_final_states=True, - dt_bias=self.dt_bias, + dt_bias=self.decay_gate.dt_bias, dt_softplus=True, **dt_limit_kwargs, ) @@ -602,21 +603,5 @@ def _cuda_forward( @torch.no_grad() def reset_parameters(self) -> None: - A = torch.log(torch.arange(1, self.num_heads + 1)) - - if isinstance(self.A_log, DTensor): - A = tensor_to_dtensor( - A, - device_mesh=self.A_log.device_mesh, - current_placement=[Replicate()] * len(self.A_log.placements), - desired_placement=self.A_log.placements, - ) - - self.A_log.copy_(A) - nn.init.ones_(self.D) - nn.init.ones_(self.dt_bias) - - mark_parameter_as_initialized(self.A_log) mark_parameter_as_initialized(self.D) - mark_parameter_as_initialized(self.dt_bias) diff --git a/lm_engine/hf_models/parameter.py b/lm_engine/hf_models/parameter.py index b86da83b2..4ab6e6834 100644 --- a/lm_engine/hf_models/parameter.py +++ b/lm_engine/hf_models/parameter.py @@ -5,9 +5,6 @@ import torch.nn as nn -_ALL_MARKERS = ["_no_weight_decay", "_has_mup_learning_rate", "_is_initialized"] - - def mark_parameter_as_no_weight_decay(parameter: nn.Parameter | None) -> nn.Parameter | None: if parameter is not None: parameter._no_weight_decay = True diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py index 431ebba7d..2698cc8a7 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward.py @@ -62,90 +62,93 @@ ], ) -enable_kernels( - [Kernel.scattermoe] + ([Kernel.flash_attention_2] if args.attention_implementation == "flash_attention_2" else []) -).__enter__() +kernels = [Kernel.scattermoe] +if args.attention_implementation == "flash_attention_2": + kernels.append(Kernel.flash_attention_2) +elif args.attention_implementation == "flash_attention_3": + kernels.append(Kernel.flash_attention_3) -if torch.distributed.get_rank() == 0: - with torch.device("meta"): - model = TestCommons.from_config(None, config) - - model = model.to_empty(device=torch.cuda.current_device()) - for _, param in model.named_parameters(): - param.data.normal_(0, 0.0125) +with enable_kernels(kernels): + if torch.distributed.get_rank() == 0: + with torch.device("meta"): + model = TestCommons.from_config(None, config) - model.eval() + model = model.to_empty(device=torch.cuda.current_device()) + for _, param in model.named_parameters(): + param.data.normal_(0, 0.0125) - model.save_pretrained(args.tmp_path, safe_serialization=True) - model = model.to(dtype) + model.eval() -Communication.barrier() + model.save_pretrained(args.tmp_path, safe_serialization=True) + model = model.to(dtype) -# use dummy tensors to avoid initializing model here -with torch.device("meta"): - # try sharding vocab matrices if really struggling for memory + Communication.barrier() - model_tp = get_model_parallel_class(config.model_type)._from_config( - config, - use_padding_free_transformer=args.use_padding_free_transformer, - sequence_parallel=args.sequence_parallel, - ) + # use dummy tensors to avoid initializing model here + with torch.device("meta"): + # try sharding vocab matrices if really struggling for memory -# copy to device without copying storage -model_tp = model_tp.to_empty(device=torch.cuda.current_device()) + model_tp = get_model_parallel_class(config.model_type)._from_config( + config, + use_padding_free_transformer=args.use_padding_free_transformer, + sequence_parallel=args.sequence_parallel, + ) -# load weights into tensor parallel model using SafeTensorsWeightsManager class -# this avoids loading multiple copies of the parameters in CPU memory -model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) + # copy to device without copying storage + model_tp = model_tp.to_empty(device=torch.cuda.current_device()) -# set model to eval mode -model_tp = model_tp.to(dtype) -model_tp.eval() + # load weights into tensor parallel model using SafeTensorsWeightsManager class + # this avoids loading multiple copies of the parameters in CPU memory + model_tp.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(args.tmp_path)) -set_seed(42) + # set model to eval mode + model_tp = model_tp.to(dtype) + model_tp.eval() -batch_size = 4 -sequence_length = 512 + set_seed(42) -input_ids = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) -labels = torch.randint( - 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False -) + batch_size = 4 + sequence_length = 512 -if args.use_padding_free_transformer: - cu_seqlens = torch.arange( - 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + input_ids = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) - position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) - - output_tp = model_tp( - input_ids=input_ids.view(-1), - labels=labels.view(-1), - cu_seqlens=cu_seqlens, - max_seqlen=sequence_length, - position_ids=position_ids, + labels = torch.randint( + 0, 50255, (batch_size, sequence_length), device=torch.cuda.current_device(), requires_grad=False ) -else: - output_tp = model_tp(input_ids=input_ids, labels=labels) - -loss_tp = output_tp.loss -logits_tp = output_tp.logits[..., : config.vocab_size] - -if torch.distributed.get_rank() == 0: - # loss computation hangs if we don't use dummy tensor parallel world size - with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): - output = model(input_ids=input_ids, labels=labels) - - loss = output.loss - logits = output.logits if args.use_padding_free_transformer: - logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) - - error = (logits - logits_tp).abs().max() - assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" - - error = (loss - loss_tp).abs().max() - assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" + cu_seqlens = torch.arange( + 0, input_ids.numel() + 1, sequence_length, dtype=torch.int32, device=torch.cuda.current_device() + ) + position_ids = torch.arange(0, sequence_length, 1, device=torch.cuda.current_device()).repeat(batch_size) + + output_tp = model_tp( + input_ids=input_ids.view(-1), + labels=labels.view(-1), + cu_seqlens=cu_seqlens, + max_seqlen=sequence_length, + position_ids=position_ids, + ) + else: + output_tp = model_tp(input_ids=input_ids, labels=labels) + + loss_tp = output_tp.loss + logits_tp = output_tp.logits[..., : config.vocab_size] + + if torch.distributed.get_rank() == 0: + # loss computation hangs if we don't use dummy tensor parallel world size + with ProcessGroupManager.set_dummy_tensor_parallel_world_size(1): + output = model(input_ids=input_ids, labels=labels) + + loss = output.loss + logits = output.logits + + if args.use_padding_free_transformer: + logits_tp = logits_tp.reshape(batch_size, sequence_length, -1) + + error = (logits - logits_tp).abs().max() + assert error < 5e-4, f"logits don't match for normal and tensor parallel model, error is ({error})" + + error = (loss - loss_tp).abs().max() + assert error < 1e-3, f"losses don't match for normal and tensor parallel model, error is ({error})" diff --git a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py index 7fc50d93a..3d6da4f19 100644 --- a/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py +++ b/tests/hf_models/multi_gpu/tensor_parallel/tensor_parallel_forward_test.py @@ -8,7 +8,7 @@ import torch from parameterized import parameterized -from lm_engine.utils import is_flash_attention_2_available, torch_dtype_to_string +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available, torch_dtype_to_string from ...test_common import TestCommons @@ -17,7 +17,7 @@ class TensorParallelTest(TestCommons): @parameterized.expand( TestCommons.make_args_matrix( TestCommons.get_position_embedding_types(), - TestCommons.get_attention_implementations(), + ["sdpa", "flash_attention_2", "flash_attention_3"], TestCommons.get_dtypes(), [False, True], [False, True], @@ -37,11 +37,14 @@ def test_tensor_parallel_forward( if (attention_implementation, dtype) not in [ ("sdpa", torch.float32), ("flash_attention_2", torch.float16), + ("flash_attention_3", torch.float16), ]: self.skipTest("skipping test since running all takes too long") if attention_implementation == "flash_attention_2" and not is_flash_attention_2_available(): - self.skipTest("skipping test since flash-attn is unavialable") + self.skipTest("skipping test because flash attention 2 is unavailable") + elif attention_implementation == "flash_attention_3" and not is_flash_attention_3_available(): + self.skipTest("skipping test because flash attention 3 is unavailable") if use_padding_free_transformer and attention_implementation != "flash_attention_2": self.skipTest("skipping test since flash attention is needed for padding free transformer") diff --git a/tests/hf_models/single_gpu/typecheck_test.py b/tests/hf_models/single_gpu/typecheck_test.py index ded25bb6a..480c12d32 100644 --- a/tests/hf_models/single_gpu/typecheck_test.py +++ b/tests/hf_models/single_gpu/typecheck_test.py @@ -7,6 +7,7 @@ from lm_engine.enums import Kernel from lm_engine.kernels import enable_kernels +from lm_engine.utils import is_flash_attention_2_available, is_flash_attention_3_available from ..test_common import TestCommons @@ -25,5 +26,14 @@ def test_no_attention_mask_flash_attention(self, device: torch.device) -> None: input_ids, _, labels = self.get_dummy_inputs(device, return_list=True) attention_mask = [[1] * len(i) for i in input_ids] - with enable_kernels([Kernel.flash_attention_2]): + kernel = None + if is_flash_attention_3_available(): + kernel = Kernel.flash_attention_3 + if is_flash_attention_2_available(): + kernel = Kernel.flash_attention_2 + + if kernel is None: + self.skipTest("skipping test because flash attention 2 or 3 is unavailable") + + with enable_kernels([kernel]): self.assertRaises(AssertionError, model, input_ids=input_ids, attention_mask=attention_mask, labels=labels) diff --git a/tests/hf_models/test_common.py b/tests/hf_models/test_common.py index cbc4152b1..914760684 100644 --- a/tests/hf_models/test_common.py +++ b/tests/hf_models/test_common.py @@ -20,10 +20,6 @@ class TestCommons(BaseTestCommons): - @staticmethod - def get_attention_implementations() -> list[str]: - return ["sdpa", "flash_attention_2"] - @staticmethod def get_position_embedding_types() -> list[str]: return ["learned_absolute", "rope"] diff --git a/tests/training/params_group/groups/mup.json b/tests/training/params_group/groups/mup.json index f39be3fa9..cb6d202a7 100644 --- a/tests/training/params_group/groups/mup.json +++ b/tests/training/params_group/groups/mup.json @@ -16,7 +16,7 @@ "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias", @@ -42,9 +42,9 @@ "model.transformer.h.1.mlp_block.c_proj.weight", "model.transformer.h.1.mlp_block.c_proj_shared.weight", "model.transformer.h.1.mlp_block.gate.weight", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.weight", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", "model.transformer.h.1.sequence_mixer.in_proj.weight", "model.transformer.h.1.sequence_mixer.out_proj.weight", "model.transformer.h.2.mlp_block.c_fc.weight", diff --git a/tests/training/params_group/groups/normal.json b/tests/training/params_group/groups/normal.json index c14d0f0df..ae8f2a459 100644 --- a/tests/training/params_group/groups/normal.json +++ b/tests/training/params_group/groups/normal.json @@ -35,10 +35,10 @@ "model.transformer.h.1.mlp_block.c_fc_shared.bias", "model.transformer.h.1.mlp_block.c_proj.bias", "model.transformer.h.1.mlp_block.c_proj_shared.bias", - "model.transformer.h.1.sequence_mixer.A_log", "model.transformer.h.1.sequence_mixer.D", "model.transformer.h.1.sequence_mixer.conv1d.bias", - "model.transformer.h.1.sequence_mixer.dt_bias", + "model.transformer.h.1.sequence_mixer.decay_gate.A_log", + "model.transformer.h.1.sequence_mixer.decay_gate.dt_bias", "model.transformer.h.1.sequence_mixer.in_proj.bias", "model.transformer.h.1.sequence_mixer.norm.weight", "model.transformer.h.1.sequence_mixer.out_proj.bias",