From fb46dad75027e656439a36b332cdf6ee2c21d68e Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 01:38:09 -0500
Subject: [PATCH 01/11] cleanup

Signed-off-by: Mayank Mishra
---
 .../modeling_utils/mlp_blocks/__init__.py     |  6 +--
 .../modeling_utils/mlp_blocks/moe/__init__.py |  1 -
 .../modeling_utils/mlp_blocks/moe/auxfree.py  | 49 -------------------
 .../modeling_utils/mlp_blocks/moe/base.py     | 35 +++++++++++--
 dolomite_engine/model_wrapper/pretraining.py  | 14 +-----
 5 files changed, 33 insertions(+), 72 deletions(-)
 delete mode 100755 dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
index 450140e2f..f91654aee 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
@@ -3,7 +3,7 @@
 from ...config import CommonConfig
 from ...enums import InitMethod
 from .mlp import MLP, interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp
-from .moe import AuxFreeMoE, MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE
+from .moe import MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE
 
 
 def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE:
@@ -20,6 +20,7 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
         initializer_range=config.initializer_range,
         m_width=config.m_width,
         num_layers=config.num_layers,
+        router_aux_loss_coef=config.router_aux_loss_coef,
     )
 
     if mlp_type == "MLP":
@@ -33,8 +34,5 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
             num_experts_per_tok=block.num_experts_per_tok,
             use_padding_free_transformer=use_padding_free_transformer,
         )
-    elif mlp_type == "AuxFreeMoE":
-        assert is_kernel_allowed(Kernel.scattermoe)
-        return AuxFreeMoE(config, use_padding_free_transformer)
     else:
         raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})")

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
index 86e03654f..b0af3675c 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/__init__.py
@@ -1,3 +1,2 @@
-from .auxfree import AuxFreeMoE
 from .base import MoE, ParameterizedExperts
 from .scatter import ParameterizedScatteredExperts, ScatterMoE

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py
deleted file mode 100755
index ee37a6f9f..000000000
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch.distributed._functional_collectives import all_reduce
-
-from .....utils import ProcessGroupManager
-from ....config import CommonConfig
-from .scatter import ScatterMoE
-
-
-class AuxFreeMoE(ScatterMoE):
-    def __init__(self, config: CommonConfig, use_padding_free_transformer: bool) -> None:
-        super().__init__(config, use_padding_free_transformer)
-        self.register_buffer("bias", torch.zeros(config.num_experts))
-        self.step_size = config.router_aux_loss_coef
-
-    def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
-        router_logits = self.gate(hidden_states)
-
-        with torch.no_grad():
-            _, selected_experts = self._get_topk(F.softmax(router_logits, dim=-1) + self.bias)
-
-        router_weights = router_logits[
-            torch.arange(hidden_states.size(0), device=hidden_states.device, dtype=torch.int32)[:, None],
-            selected_experts,
-        ]
-        router_weights = F.softmax(router_weights.float(), dim=-1)
-
-        # we cast back to the input dtype
-        router_weights = router_weights.type_as(hidden_states)
-
-        return router_logits, router_weights, selected_experts
-
-    def _compute_switch_loss(self, logits: torch.Tensor, probs: torch.Tensor, topk_idxs: torch.Tensor) -> torch.Tensor:
-        num_experts = logits.size(-1)
-        freq = torch.bincount(topk_idxs.flatten(), minlength=num_experts).to(dtype=logits.dtype)
-
-        if ProcessGroupManager.is_initialized() and ProcessGroupManager.get_data_parallel_world_size() > 1:
-            freq = all_reduce(freq, reduceOp="sum", group=ProcessGroupManager.get_data_parallel_group())
-
-        avg_counts = torch.mean(freq, dim=0, keepdim=True)
-
-        if self.training and self.step_size > 0:
-            self.bias += self.step_size * torch.sign(avg_counts - freq)
-
-        with torch.no_grad():
-            acc_probs = probs.sum(0)
-            switch_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
-
-        return switch_loss.detach()

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
index 4dfa1dd03..7cc001949 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
@@ -18,6 +18,27 @@
     from cute_kernels.kernels.scattermoe.triton_implementation import bincount
 
 
+class _AuxLossBackprop(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
+        ctx.router_aux_loss_coef = router_aux_loss_coef
+        ctx.dtype = aux_loss.dtype
+        return aux_loss
+
+    @staticmethod
+    def backward(ctx, aux_loss_grad: torch.Tensor) -> tuple[torch.Tensor]:
+        return (
+            torch.tensor(ctx.router_aux_loss_coef, dtype=ctx.dtype, device=torch.cuda.current_device()),
+            None,
+        )
+
+
+def aux_loss_backprop(aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
+    aux_loss = _AuxLossBackprop.apply(aux_loss, router_aux_loss_coef)
+    aux_loss = aux_loss.detach()
+    return aux_loss
+
+
 class ParameterizedExperts(nn.Module):
     def __init__(
         self,
@@ -84,6 +105,7 @@ def __init__(
         m_width: float,
         num_layers: int,
         use_padding_free_transformer: bool,
+        router_aux_loss_coef: float,
     ) -> None:
         super().__init__()
 
@@ -146,6 +168,8 @@ def __init__(
             torch.cuda.current_device()
         ) >= (9, 0)
 
+        self.router_aux_loss_coef = router_aux_loss_coef
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if not self.use_padding_free_transformer:
             batch_size, sequence_length, _ = hidden_states.shape
@@ -167,13 +191,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         hidden_states = self.dropout(hidden_states)
 
-        aux_loss = (
-            self._compute_switch_loss(
+        aux_loss = 0
+        if self.training:
+            aux_loss = self._compute_switch_loss(
                 logits=router_logits, probs=torch.softmax(router_logits, dim=-1), topk_idxs=selected_experts
             )
-            if self.training
-            else 0
-        )
+
+        aux_loss = aux_loss_backprop(aux_loss, self.router_aux_loss_coef)
 
         add_aux_loss(aux_loss)
 
@@ -266,5 +290,6 @@ def _compute_switch_loss(self, logits: torch.Tensor, probs: torch.Tensor, topk_i
         z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()
 
         loss = switch_loss + 0.1 * z_loss
+        loss = aux_loss_backprop(loss, self.router_aux_loss_coef)
 
         return loss

diff --git a/dolomite_engine/model_wrapper/pretraining.py b/dolomite_engine/model_wrapper/pretraining.py
index 7235ffe01..d6be25819 100644
--- a/dolomite_engine/model_wrapper/pretraining.py
+++ b/dolomite_engine/model_wrapper/pretraining.py
@@ -159,7 +159,7 @@ def get_loss(
             if tensor_parallel_enabled:
                 aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate())
 
-            loss = _F.apply(lm_loss, aux_loss, self.router_aux_loss_coef)
+            loss = lm_loss + self.router_aux_loss_coef * aux_loss
 
         output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss}
         return output
@@ -289,15 +289,3 @@ def reset_parameters(self) -> None:
         assert (
             not self.reset_position_ids
         ), "currently reset_position_ids is only implemented for padding free transformer"
-
-
-class _F(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, lm_loss: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
-        ctx.router_aux_loss_coef = router_aux_loss_coef
-        return lm_loss + router_aux_loss_coef * aux_loss
-
-    @staticmethod
-    @torch._dynamo.disable
-    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None]:
-        return grad_output, ctx.router_aux_loss_coef * grad_output, None

From 1c70896f04dc9b7620c01387a59caa23432a293b Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 01:39:58 -0500
Subject: [PATCH 02/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/pretrain.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dolomite_engine/pretrain.py b/dolomite_engine/pretrain.py
index ad8da5398..c9d1d9ffe 100644
--- a/dolomite_engine/pretrain.py
+++ b/dolomite_engine/pretrain.py
@@ -15,7 +15,8 @@
 from .communication import Communication
 from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container
 from .data import ResumableDataLoader, get_next_batch, get_pretraining_dataloaders
-from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
+from .distributed import wrap_model_container_for_distributed_training
+from .dtensors import dtensor_to_tensor
 from .enums import Mode, TuningMethod
 from .hf_models import disable_generation_cache
 from .kernels import enable_kernels

From ac10be1e744d43a513424e8f4791d96a4c75b2fe Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 01:40:39 -0500
Subject: [PATCH 03/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/finetune.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/dolomite_engine/finetune.py b/dolomite_engine/finetune.py
index 5f04e8e1f..3073ef028 100644
--- a/dolomite_engine/finetune.py
+++ b/dolomite_engine/finetune.py
@@ -10,7 +10,8 @@
 from .checkpointing import ensure_last_checkpoint_is_saved, load_checkpoint_for_training, save_checkpoint
 from .containers import LRSchedulerContainer, ModelContainer, OptimizerContainer, log_model_optimizer_container
 from .data import ResumableDataLoader, custom_iterator, get_finetuning_dataloader, get_next_batch
-from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
+from .distributed import wrap_model_container_for_distributed_training
+from .dtensors import dtensor_to_tensor
 from .enums import DatasetSplit, Mode, TuningMethod
 from .hf_models import disable_generation_cache
 from .kernels import enable_kernels

From b8730f4ccf61283a735ddd46f13ea30425a5f2c2 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 01:41:42 -0500
Subject: [PATCH 04/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
index f91654aee..cc222142c 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
@@ -20,7 +20,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
         initializer_range=config.initializer_range,
         m_width=config.m_width,
         num_layers=config.num_layers,
-        router_aux_loss_coef=config.router_aux_loss_coef,
     )
 
     if mlp_type == "MLP":
@@ -33,6 +32,7 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
             num_experts=block.num_experts,
             num_experts_per_tok=block.num_experts_per_tok,
             use_padding_free_transformer=use_padding_free_transformer,
+            router_aux_loss_coef=config.router_aux_loss_coef,
         )
     else:
         raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})")

From 4a734433347c56d9381530d3e82c6364734b2a83 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 01:58:15 -0500
Subject: [PATCH 05/11] cleanup

Signed-off-by: Mayank Mishra
---
 .../modeling_utils/mlp_blocks/moe/base.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
index 7cc001949..5744b10cc 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
@@ -20,23 +20,26 @@ class _AuxLossBackprop(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
+    def forward(ctx, hidden_states: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
         ctx.router_aux_loss_coef = router_aux_loss_coef
         ctx.dtype = aux_loss.dtype
-        return aux_loss
+        return hidden_states, aux_loss
 
     @staticmethod
-    def backward(ctx, aux_loss_grad: torch.Tensor) -> tuple[torch.Tensor]:
+    def backward(ctx, hidden_states_grad: torch.Tensor, aux_loss_grad: torch.Tensor) -> tuple[torch.Tensor]:
         return (
+            hidden_states_grad,
             torch.tensor(ctx.router_aux_loss_coef, dtype=ctx.dtype, device=torch.cuda.current_device()),
             None,
         )
 
 
-def aux_loss_backprop(aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
-    aux_loss = _AuxLossBackprop.apply(aux_loss, router_aux_loss_coef)
+def aux_loss_backprop(
+    hidden_states: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float
+) -> torch.Tensor:
+    hidden_states, aux_loss = _AuxLossBackprop.apply(hidden_states, aux_loss, router_aux_loss_coef)
     aux_loss = aux_loss.detach()
-    return aux_loss
+    return hidden_states, aux_loss
 
 
 class ParameterizedExperts(nn.Module):
@@ -197,7 +200,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 logits=router_logits, probs=torch.softmax(router_logits, dim=-1), topk_idxs=selected_experts
             )
 
-        aux_loss = aux_loss_backprop(aux_loss, self.router_aux_loss_coef)
+        hidden_states, aux_loss = aux_loss_backprop(hidden_states, aux_loss, self.router_aux_loss_coef)
 
         add_aux_loss(aux_loss)
 
@@ -290,6 +293,5 @@ def _compute_switch_loss(self, logits: torch.Tensor, probs: torch.Tensor, topk_i
         z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()
 
         loss = switch_loss + 0.1 * z_loss
-        loss = aux_loss_backprop(loss, self.router_aux_loss_coef)
 
         return loss

From 92920db1433e51b56f39e7964e25c40d5743b704 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 02:03:46 -0500
Subject: [PATCH 06/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
index cc222142c..b4d587f8a 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/__init__.py
@@ -32,7 +32,6 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, laye
             num_experts=block.num_experts,
             num_experts_per_tok=block.num_experts_per_tok,
             use_padding_free_transformer=use_padding_free_transformer,
-            router_aux_loss_coef=config.router_aux_loss_coef,
         )
     else:
         raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})")

From 98fedcb230ae0ee8b0b1ee8f0c40751b38ea09d1 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 02:05:16 -0500
Subject: [PATCH 07/11] cleanup

Signed-off-by: Mayank Mishra
---
 .../hf_models/modeling_utils/mlp_blocks/moe/base.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
index 5744b10cc..7648034dc 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
@@ -108,7 +108,6 @@ def __init__(
         m_width: float,
         num_layers: int,
         use_padding_free_transformer: bool,
-        router_aux_loss_coef: float,
     ) -> None:
         super().__init__()
 
@@ -171,8 +170,6 @@ def __init__(
             torch.cuda.current_device()
         ) >= (9, 0)
 
-        self.router_aux_loss_coef = router_aux_loss_coef
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

From 5ddec076db39905d16c2cdba7ed4f2d3ef4573af Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 02:06:05 -0500
Subject: [PATCH 08/11] cleanup

Signed-off-by: Mayank Mishra
---
 .../modeling_utils/mlp_blocks/moe/base.py | 26 ------------------
 1 file changed, 26 deletions(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
index 7648034dc..4c48fc6e5 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
@@ -18,30 +18,6 @@
     from cute_kernels.kernels.scattermoe.triton_implementation import bincount
 
 
-class _AuxLossBackprop(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, hidden_states: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
-        ctx.router_aux_loss_coef = router_aux_loss_coef
-        ctx.dtype = aux_loss.dtype
-        return hidden_states, aux_loss
-
-    @staticmethod
-    def backward(ctx, hidden_states_grad: torch.Tensor, aux_loss_grad: torch.Tensor) -> tuple[torch.Tensor]:
-        return (
-            hidden_states_grad,
-            torch.tensor(ctx.router_aux_loss_coef, dtype=ctx.dtype, device=torch.cuda.current_device()),
-            None,
-        )
-
-
-def aux_loss_backprop(
-    hidden_states: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float
-) -> torch.Tensor:
-    hidden_states, aux_loss = _AuxLossBackprop.apply(hidden_states, aux_loss, router_aux_loss_coef)
-    aux_loss = aux_loss.detach()
-    return hidden_states, aux_loss
-
-
 class ParameterizedExperts(nn.Module):
     def __init__(
         self,
@@ -197,8 +173,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                 logits=router_logits, probs=torch.softmax(router_logits, dim=-1), topk_idxs=selected_experts
             )
 
-        hidden_states, aux_loss = aux_loss_backprop(hidden_states, aux_loss, self.router_aux_loss_coef)
-
         add_aux_loss(aux_loss)
 
         return hidden_states

From dfbacd6e8d87378ff77e129275fd0f2fea373457 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 02:06:45 -0500
Subject: [PATCH 09/11] fix simple-spec crash

Signed-off-by: Mayank Mishra
---
 .../hf_models/modeling_utils/mlp_blocks/moe/base.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
index 4c48fc6e5..4dfa1dd03 100644
--- a/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
+++ b/dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/base.py
@@ -167,11 +167,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         hidden_states = self.dropout(hidden_states)
 
-        aux_loss = 0
-        if self.training:
-            aux_loss = self._compute_switch_loss(
+        aux_loss = (
+            self._compute_switch_loss(
                 logits=router_logits, probs=torch.softmax(router_logits, dim=-1), topk_idxs=selected_experts
             )
+            if self.training
+            else 0
+        )
 
         add_aux_loss(aux_loss)

From 2cc2ea207065cfba6c07b0169033df536e2b53b6 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 02:19:32 -0500
Subject: [PATCH 10/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/hf_models/mixins/dense_TP/main.py    | 10 ++++++----
 dolomite_engine/hf_models/mixins/modeling_outputs.py |  2 ++
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/dolomite_engine/hf_models/mixins/dense_TP/main.py b/dolomite_engine/hf_models/mixins/dense_TP/main.py
index 26d6daea6..da6a7be7d 100644
--- a/dolomite_engine/hf_models/mixins/dense_TP/main.py
+++ b/dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -10,7 +10,7 @@
 from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
 from ...config import CommonConfig
 from ...enums import PositionEmbeddingType
-from ...loss import add_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
+from ...loss import add_aux_loss, clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
 from ...modeling_utils_TP import LMHead_TP
 from ..dense import CausalLMModelMixin
 from ..modeling_outputs import (
@@ -68,6 +68,8 @@ def forward(
         if self.is_pipeline_parallel_enabled:
             past_key_values = None
 
+        clear_aux_loss()
+
         if self.is_first_stage:
             assert pipeline_parallel_input is None, "first stage should not get pipeline_parallel_input"
             input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
@@ -84,6 +86,7 @@ def forward(
             )
         else:
             assert input_ids is None
+            add_aux_loss(pipeline_parallel_input.aux_loss)
 
         transformer_outputs: BaseModelOutputWithPast = self.transformer(
             input_ids=input_ids if pipeline_parallel_input is None else pipeline_parallel_input.hidden_states,
@@ -105,6 +108,7 @@ def forward(
 
         lm_logits = None
         loss = None
+        aux_loss = get_aux_loss()
 
         if self.is_last_stage:
             if labels is None:
@@ -142,8 +146,6 @@ def forward(
                 lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
                 lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())
 
-            aux_loss = get_aux_loss()
-
             if loss is not None and aux_loss != 0:
                 loss = loss + self.router_aux_loss_coef * aux_loss
 
@@ -154,7 +156,7 @@ def forward(
                 last_hidden_state=hidden_states,
             )
         else:
-            output = PipelineParallelOutput(hidden_states=hidden_states)
+            output = PipelineParallelOutput(hidden_states=hidden_states, aux_loss=aux_loss)
 
         return output

diff --git a/dolomite_engine/hf_models/mixins/modeling_outputs.py b/dolomite_engine/hf_models/mixins/modeling_outputs.py
index 5ec7b93ad..75f7ca67b 100644
--- a/dolomite_engine/hf_models/mixins/modeling_outputs.py
+++ b/dolomite_engine/hf_models/mixins/modeling_outputs.py
@@ -21,8 +21,10 @@ class CausalLMOutputWithPast(ModelOutput):
 @dataclass
 class PipelineParallelInput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0
 
 
 @dataclass
 class PipelineParallelOutput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0

From 3ecd7d3cd2d9d86339651a224597ae4788294e90 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 6 Mar 2025 12:36:40 -0500
Subject: [PATCH 11/11] cleanup

Signed-off-by: Mayank Mishra
---
 dolomite_engine/model_wrapper/finetuning.py | 27 ++++++++++++++-------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/dolomite_engine/model_wrapper/finetuning.py b/dolomite_engine/model_wrapper/finetuning.py
index 073c10a9c..ce008e013 100644
--- a/dolomite_engine/model_wrapper/finetuning.py
+++ b/dolomite_engine/model_wrapper/finetuning.py
@@ -25,18 +25,27 @@ def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTracking
         if ProcessGroupManager.is_tensor_parallel_enabled():
             batch = self._broadcast_inputs_for_tensor_parallel(batch)
 
-        if not self.is_custom_model:
+        if self.is_custom_model:
             assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute)
 
-        labels = batch.pop("labels")
-        model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            labels = batch.pop("labels")
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)
 
-        return self.get_loss(
-            model_outputs=model_outputs,
-            labels=labels,
-            cu_seqlens=batch.get("cu_seqlens", None),
-            lm_loss_multiplier=lm_loss_multiplier,
-        )
+            output = self.get_loss(
+                model_outputs=model_outputs,
+                labels=labels,
+                cu_seqlens=batch.get("cu_seqlens", None),
+                lm_loss_multiplier=lm_loss_multiplier,
+            )
+        else:
+            # use HF loss API for HF model since we can't get the aux loss outside
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            output = {"loss": model_outputs.loss}
+
+            if hasattr(model_outputs, "aux_loss"):
+                output["aux_loss"] = model_outputs.aux_loss
+
+        return output
 
     def get_loss(
         self,
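
Note: taken together, the series converges on one pattern for the MoE auxiliary loss: each MoE block computes a switch loss plus a z-loss during training and registers it via add_aux_loss, and the model wrapper combines it with the LM loss as loss = lm_loss + router_aux_loss_coef * aux_loss (the custom-autograd shortcut from patches 01-05 is removed again in patch 08). Below is a standalone sketch of that pattern. The helper names mirror those in the diffs, but their bodies here are illustrative assumptions, not the repository implementation.

import torch
import torch.nn.functional as F

# toy global registry, standing in for dolomite_engine.hf_models.loss
_AUX_LOSS: list[torch.Tensor] = []

def clear_aux_loss() -> None:
    _AUX_LOSS.clear()

def add_aux_loss(loss) -> None:
    # aux_loss is 0 (a plain int) in eval mode, so only record tensors
    if isinstance(loss, torch.Tensor):
        _AUX_LOSS.append(loss)

def get_aux_loss():
    return sum(_AUX_LOSS) if _AUX_LOSS else 0

def switch_loss(router_logits: torch.Tensor, topk_idxs: torch.Tensor) -> torch.Tensor:
    # router_logits: (tokens, num_experts); topk_idxs: (tokens, k)
    num_experts = router_logits.size(-1)
    probs = F.softmax(router_logits, dim=-1)
    # fraction of routing probability mass assigned to each expert
    acc_probs = F.normalize(probs.sum(0), p=1, dim=0)
    # fraction of top-k assignments dispatched to each expert
    freq = torch.bincount(topk_idxs.flatten(), minlength=num_experts).to(probs.dtype)
    freq = F.normalize(freq, p=1, dim=0)
    loss = num_experts * (acc_probs * freq).sum()
    # z-loss penalizes large router logits, weighted 0.1 as in the diff
    z_loss = (torch.logsumexp(router_logits, dim=-1) ** 2).mean()
    return loss + 0.1 * z_loss

# usage: one training step over a toy router
logits = torch.randn(16, 4, requires_grad=True)
topk = logits.topk(2, dim=-1).indices
clear_aux_loss()
add_aux_loss(switch_loss(logits, topk))
lm_loss = torch.tensor(2.5)  # stand-in for the real language-modeling loss
router_aux_loss_coef = 0.01
total = lm_loss + router_aux_loss_coef * get_aux_loss()
total.backward()  # gradient reaches the router only through the aux loss

Clearing the registry at the start of each forward pass and carrying the running total across pipeline stages (the aux_loss field added to PipelineParallelInput/PipelineParallelOutput in patch 10) is what keeps the loss correct when layers live on different pipeline ranks.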