10 changes: 6 additions & 4 deletions dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -10,7 +10,7 @@
 from ....utils import ProcessGroupManager, SafeTensorsWeightsManager, divide_if_divisible
 from ...config import CommonConfig
 from ...enums import PositionEmbeddingType
-from ...loss import add_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
+from ...loss import add_aux_loss, clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss
 from ...modeling_utils_TP import LMHead_TP
 from ..dense import CausalLMModelMixin
 from ..modeling_outputs import (
@@ -68,6 +68,8 @@ def forward(
         if self.is_pipeline_parallel_enabled:
             past_key_values = None

+        clear_aux_loss()
+
         if self.is_first_stage:
             assert pipeline_parallel_input is None, "first stage should not get pipeline_parallel_input"
             input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
@@ -84,6 +86,7 @@
             )
         else:
             assert input_ids is None
+            add_aux_loss(pipeline_parallel_input.aux_loss)

         transformer_outputs: BaseModelOutputWithPast = self.transformer(
             input_ids=input_ids if pipeline_parallel_input is None else pipeline_parallel_input.hidden_states,
@@ -105,6 +108,7 @@

         lm_logits = None
         loss = None
+        aux_loss = get_aux_loss()

         if self.is_last_stage:
             if labels is None:
@@ -142,8 +146,6 @@
                 lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
                 lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())

-            aux_loss = get_aux_loss()
-
             if loss is not None and aux_loss != 0:
                 loss = loss + self.router_aux_loss_coef * aux_loss

@@ -154,7 +156,7 @@
                 last_hidden_state=hidden_states,
             )
         else:
-            output = PipelineParallelOutput(hidden_states=hidden_states)
+            output = PipelineParallelOutput(hidden_states=hidden_states, aux_loss=aux_loss)

         return output

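The hunks above lean on the aux-loss accumulator exported by the `loss` package: `clear_aux_loss()` at the top of `forward` resets it, a non-first stage re-adds the upstream stage's running total, and `get_aux_loss()` reads the result once before the output is built. A minimal sketch of such an accumulator, assuming a module-level buffer (the real helpers in `dolomite_engine/hf_models/loss` are not part of this diff and may differ):

```python
import torch

# Hypothetical module-level buffer standing in for the real accumulator.
_AUX_LOSS: torch.Tensor | float = 0


def clear_aux_loss() -> None:
    # reset at the start of every forward pass so values never leak across steps
    global _AUX_LOSS
    _AUX_LOSS = 0


def add_aux_loss(aux_loss: torch.Tensor | float) -> None:
    # MoE blocks (and, with this change, upstream pipeline stages) accumulate here
    global _AUX_LOSS
    _AUX_LOSS = _AUX_LOSS + aux_loss


def get_aux_loss() -> torch.Tensor | float:
    # read once per forward pass, after every block has run
    return _AUX_LOSS
```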
2 changes: 2 additions & 0 deletions dolomite_engine/hf_models/mixins/modeling_outputs.py
@@ -21,8 +21,10 @@ class CausalLMOutputWithPast(ModelOutput):
 @dataclass
 class PipelineParallelInput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0


 @dataclass
 class PipelineParallelOutput(ModelOutput):
     hidden_states: torch.Tensor | None = None
+    aux_loss: torch.Tensor | float = 0
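With `aux_loss` on both dataclasses, a non-final stage can ship its accumulated aux loss alongside its hidden states, and the receiving stage re-injects it through `add_aux_loss` (see the first file). A sketch of that hand-off, assuming a simplified driver that calls the stage modules directly; the real pipeline-parallel schedule passes more state than this:

```python
from dolomite_engine.hf_models.mixins.modeling_outputs import (
    PipelineParallelInput,
    PipelineParallelOutput,
)


def feed_next_stage(stage_model, prev_output: PipelineParallelOutput):
    # Hand the previous stage's hidden states and running aux loss to the next
    # stage; its forward() then calls add_aux_loss(pipeline_parallel_input.aux_loss).
    stage_input = PipelineParallelInput(
        hidden_states=prev_output.hidden_states,
        aux_loss=prev_output.aux_loss,
    )
    return stage_model(pipeline_parallel_input=stage_input)
```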
@@ -3,7 +3,7 @@
 from ...config import CommonConfig
 from ...enums import InitMethod
 from .mlp import MLP, interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp
-from .moe import AuxFreeMoE, MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE
+from .moe import MoE, ParameterizedExperts, ParameterizedScatteredExperts, ScatterMoE


 def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE:
@@ -33,8 +33,5 @@ def get_mlp_block(config: CommonConfig, use_padding_free_transformer: bool, layer_idx: int) -> MLP | MoE:
             num_experts_per_tok=block.num_experts_per_tok,
             use_padding_free_transformer=use_padding_free_transformer,
         )
-    elif mlp_type == "AuxFreeMoE":
-        assert is_kernel_allowed(Kernel.scattermoe)
-        return AuxFreeMoE(config, use_padding_free_transformer)
     else:
         raise ValueError(f"invalid mlp_type ({mlp_type}) for layer ({layer_idx})")
@@ -1,3 +1,2 @@
-from .auxfree import AuxFreeMoE
 from .base import MoE, ParameterizedExperts
 from .scatter import ParameterizedScatteredExperts, ScatterMoE
49 changes: 0 additions & 49 deletions dolomite_engine/hf_models/modeling_utils/mlp_blocks/moe/auxfree.py

This file was deleted.

27 changes: 18 additions & 9 deletions dolomite_engine/model_wrapper/finetuning.py
@@ -25,18 +25,27 @@ def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTracking
         if ProcessGroupManager.is_tensor_parallel_enabled():
             batch = self._broadcast_inputs_for_tensor_parallel(batch)

-        if not self.is_custom_model:
+        if self.is_custom_model:
             assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute)

-        labels = batch.pop("labels")
-        model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            labels = batch.pop("labels")
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)

-        return self.get_loss(
-            model_outputs=model_outputs,
-            labels=labels,
-            cu_seqlens=batch.get("cu_seqlens", None),
-            lm_loss_multiplier=lm_loss_multiplier,
-        )
+            output = self.get_loss(
+                model_outputs=model_outputs,
+                labels=labels,
+                cu_seqlens=batch.get("cu_seqlens", None),
+                lm_loss_multiplier=lm_loss_multiplier,
+            )
+        else:
+            # use HF loss API for HF model since we can't get the aux loss outside
+            model_outputs: CausalLMOutputWithPast = self.model(**batch)
+            output = {"loss": model_outputs.loss}
+
+        if hasattr(model_outputs, "aux_loss"):
+            output["aux_loss"] = model_outputs.aux_loss
+
+        return output

     def get_loss(
         self,
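The custom-model branch keeps the explicit `get_loss` path, while plain HF models fall back to the loss computed inside the model, with the aux loss picked up opportunistically. A hedged, standalone sketch of that HF path (the checkpoint name is a placeholder; whether `aux_loss` is populated depends on the model class and its config, e.g. `output_router_logits` for Mixtral-style MoE models):

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; substitute any causal-LM (optionally MoE) model.
model = AutoModelForCausalLM.from_pretrained("some-moe-checkpoint")

batch = {
    "input_ids": torch.tensor([[1, 2, 3, 4]]),
    "labels": torch.tensor([[1, 2, 3, 4]]),
}
model_outputs = model(**batch)

# Mirror of the HF branch above: take the model's own loss, and keep the aux
# loss only if the output object actually carries one.
output = {"loss": model_outputs.loss}
if hasattr(model_outputs, "aux_loss"):
    output["aux_loss"] = model_outputs.aux_loss
```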
14 changes: 1 addition & 13 deletions dolomite_engine/model_wrapper/pretraining.py
@@ -159,7 +159,7 @@ def get_loss(
         if tensor_parallel_enabled:
             aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate())

-        loss = _F.apply(lm_loss, aux_loss, self.router_aux_loss_coef)
+        loss = lm_loss + self.router_aux_loss_coef * aux_loss
         output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss}

         return output
@@ -289,15 +289,3 @@ def reset_parameters(self) -> None:
         assert (
             not self.reset_position_ids
         ), "currently reset_position_ids is only implemented for padding free transformer"
-
-
-class _F(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, lm_loss: torch.Tensor, aux_loss: torch.Tensor, router_aux_loss_coef: float) -> torch.Tensor:
-        ctx.router_aux_loss_coef = router_aux_loss_coef
-        return lm_loss + router_aux_loss_coef * aux_loss
-
-    @staticmethod
-    @torch._dynamo.disable
-    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor | None]:
-        return grad_output, ctx.router_aux_loss_coef * grad_output, None
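The removed `_F` autograd function hard-coded the gradients that plain autograd already produces for `lm_loss + coef * aux_loss` (1 for `lm_loss`, `coef` for `aux_loss`), so the simplification is gradient-equivalent apart from the `torch._dynamo.disable` on the old backward. A quick standalone check, not part of the PR:

```python
import torch

router_aux_loss_coef = 0.01
lm_loss = torch.tensor(2.0, requires_grad=True)
aux_loss = torch.tensor(0.5, requires_grad=True)

# the expression that now replaces _F.apply(...)
loss = lm_loss + router_aux_loss_coef * aux_loss
loss.backward()

# matches _F.backward: grad_output for lm_loss, coef * grad_output for aux_loss
assert lm_loss.grad.item() == 1.0
assert abs(aux_loss.grad.item() - router_aux_loss_coef) < 1e-6
```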