diff --git a/dolomite_engine/hf_models/mixins/dense_TP/main.py b/dolomite_engine/hf_models/mixins/dense_TP/main.py
index 26d6daea6..da6a7be7d 100644
--- a/dolomite_engine/hf_models/mixins/dense_TP/main.py
+++ b/dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -10,7 +10,7 @@
 from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
 from ...config import CommonConfig
 from ...enums import PositionEmbeddingType
-from ...loss import add_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
+from ...loss import add_aux_loss, clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
 from ...modeling_utils_TP import LMHead_TP
 from ..dense import CausalLMModelMixin
 from ..modeling_outputs import (
@@ -68,6 +68,8 @@ def forward(
         if self.is_pipeline_parallel_enabled:
             past_key_values = None
 
+        clear_aux_loss()
+
         if self.is_first_stage:
             assert pipeline_parallel_input is None, "first stage should not get pipeline_parallel_input"
             input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
@@ -84,6 +86,7 @@
             )
         else:
             assert input_ids is None
+            add_aux_loss(pipeline_parallel_input.aux_loss)
 
         transformer_outputs: BaseModelOutputWithPast = self.transformer(
             input_ids=input_ids if pipeline_parallel_input is None else pipeline_parallel_input.hidden_states,
@@ -105,6 +108,7 @@
 
         lm_logits = None
         loss = None
+        aux_loss = get_aux_loss()
 
         if self.is_last_stage:
             if labels is None:
@@ -142,8 +146,6 @@
                 lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
                 lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())
 
-        aux_loss = get_aux_loss()
-
         if loss is not None and aux_loss != 0:
             loss = loss + self.router_aux_loss_coef * aux_loss
 
@@ -154,7 +156,7 @@
                 last_hidden_state=hidden_states,
             )
         else:
-            output = PipelineParallelOutput(hidden_states=hidden_states)
+            output = PipelineParallelOutput(hidden_states=hidden_states, aux_loss=aux_loss)
 
         return output
 
diff --git a/dolomite_engine/hf_models/mixins/modeling_outputs.py b/dolomite_engine/hf_models/mixins/modeling_outputs.py
index 5ec7b93ad..75f7ca67b 100644
--- a/dolomite_engine/hf_models/mixins/modeling_outputs.py
+++ b/dolomite_engine/hf_models/mixins/modeling_outputs.py
@@ -21,8 +21,10 @@ class CausalLMOutputWithPast(ModelOutput):
 @dataclass
 class PipelineParallelInput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0
 
 
 @dataclass
 class PipelineParallelOutput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0
diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
index 450140e2f..b4d587f8a 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
@@ -3,7 +3,7 @@
 from ...config import CommonConfig
 from ...enums import InitMethod
 from .mlp import MLP, interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp
-from .moe import AuxFreeMoE, MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE
+from .moe import MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE
 
 
 def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE:
@@ -33,8 +33,5 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
             num_experts_per_tok=block.num_experts_per_tok,
             use_padding_free_transformer=use_padding_free_transformer,
         )
-    elif mlp_type == "AuxFreeMoE":
-        assert is_kernel_allowed(Kernel.scattermoe)
-        return AuxFreeMoE(config, use_padding_free_transformer)
     else:
         raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})")
diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
index 86e03654f..b0af3675c 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
@@ -1,3 +1,2 @@
-from .auxfree import AuxFreeMoE
 from .base import MoE, ParameterizedExperts
 from .scatter import ParameterizedScatteredExperts, ScatterMoE
diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py
deleted file mode 100755
index ee37a6f9f..000000000
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch.distributed._functional_collectives import all_reduce
-
-from .....utils import ProcessGroupManager
-from ....config import CommonConfig
-from .scatter import ScatterMoE
-
-
-class AuxFreeMoE(ScatterMoE):
-    def __init__(self, config: CommonConfig, use_padding_free_transformer: bool) -> None:
-        super().__init__(config, use_padding_free_transformer)
-        self.register_buffer("bias", torch.zeros(config.num_experts))
-        self.step_size = config.router_aux_loss_coef
-
-    def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
-        router_logits = self.gate(hidden_states)
-
-        with torch.no_grad():
-            _, selected_experts = self._get_topk(F.softmax(router_logits, dim=-1) + self.bias)
-
-        router_weights = router_logits[
-            torch.arange(hidden_states.size(0), device=hidden_states.device, dtype=torch.int32)[:, None],
-            selected_experts,
-        ]
-        router_weights = F.softmax(router_weights.float(), dim=-1)
-
-        # we cast back to the input dtype
-        router_weights = router_weights.type_as(hidden_states)
-
-        return router_logits, router_weights, selected_experts
-
-    def _compute_switch_loss(self, logits: torch.Tensor, probs: torch.Tensor, topk_idxs: torch.Tensor) -> torch.Tensor:
-        num_experts = logits.size(-1)
-        freq = torch.bincount(topk_idxs.flatten(), minlength=num_experts).to(dtype=logits.dtype)
-
-        if ProcessGroupManager.is_initialized() and ProcessGroupManager.get_data_parallel_world_size() > 1:
-            freq = all_reduce(freq, reduceOp="sum", group=ProcessGroupManager.get_data_parallel_group())
-
-        avg_counts = torch.mean(freq, dim=0, keepdim=True)
-
-        if self.training and self.step_size > 0:
-            self.bias += self.step_size * torch.sign(avg_counts - freq)
-
-        with torch.no_grad():
-            acc_probs = probs.sum(0)
-            switch_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
-
-        return switch_loss.detach()
diff --git a/dolomite_engine/model_wrapper/finetuning.py b/dolomite_engine/model_wrapper/finetuning.py
index 073c10a9c..ce008e013 100644
--- a/dolomite_engine/model_wrapper/finetuning.py
+++ b/dolomite_engine/model_wrapper/finetuning.py
@@ -25,18 +25,27 @@ def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTracking
         if ProcessGroupManager.is_tensor_parallel_enabled():
             batch = self._broadcast_inputs_for_tensor_parallel(batch)
 
-        if not self.is_custom_model:
+        if self.is_custom_model:
             assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute)
 
-        labels = batch.pop("labels")
-        model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            labels = batch.pop("labels")
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)
 
-        return self.get_loss(
-            model_outputs=model_outputs,
-            labels=labels,
-            cu_seqlens=batch.get("cu_seqlens", None),
-            lm_loss_multiplier=lm_loss_multiplier,
-        )
+            output = self.get_loss(
+                model_outputs=model_outputs,
+                labels=labels,
+                cu_seqlens=batch.get("cu_seqlens", None),
+                lm_loss_multiplier=lm_loss_multiplier,
+            )
+        else:
+            # use HF loss API for HF model since we can't get the aux loss outside
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            output = {"loss": model_outputs.loss}
+
+            if hasattr(model_outputs, "aux_loss"):
+                output["aux_loss"] = model_outputs.aux_loss
+
+        return output
 
     def get_loss(
         self,
diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py
index 7235ffe01..d6be25819 100644
--- a/dolomite_engine/model_wrapper/pretraining.py
+++ b/dolomite_engine/model_wrapper/pretraining.py
@@ -159,7 +159,7 @@ def get_loss(
         if tensor_parallel_enabled:
             aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate())
 
-        loss = _F.apply(lm_loss, aux_loss, self.router_aux_loss_coef)
+        loss = lm_loss + self.router_aux_loss_coef * aux_loss
 
         output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss}
         return output
@@ -289,15 +289,3 @@ def reset_parameters(self) -> None:
         assert (
             not self.reset_position_ids
         ), "currently reset_position_ids is only implemented for padding free transformer"
-
-
-class _F(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, lm_loss: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
-        ctx.router_aux_loss_coef = router_aux_loss_coef
-        return lm_loss + router_aux_loss_coef * aux_loss
-
-    @staticmethod
-    @torch._dynamo.disable
-    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None]:
-        return grad_output, ctx.router_aux_loss_coef * grad_output, None