From 78826059c8098b266d844c8a360e18c895a3730f Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Wed, 16 Jul 2025 19:13:44 +0000 Subject: [PATCH 01/19] Boilerplate. --- lm_engine/enums.py | 1 + lm_engine/hf_models/models/__init__.py | 1 + .../hf_models/models/diffusion/__init__.py | 7 + lm_engine/hf_models/models/diffusion/base.py | 13 ++ .../hf_models/models/diffusion/config.py | 8 + lm_engine/hf_models/models/diffusion/main.py | 167 +++++++++++++++ lm_engine/hf_models/register_hf.py | 4 + lm_engine/model_wrapper/__init__.py | 4 +- .../model_wrapper/pretraining_diffusion.py | 190 ++++++++++++++++++ lm_engine/pretrain.py | 3 +- 10 files changed, 396 insertions(+), 2 deletions(-) create mode 100644 lm_engine/hf_models/models/diffusion/__init__.py create mode 100644 lm_engine/hf_models/models/diffusion/base.py create mode 100644 lm_engine/hf_models/models/diffusion/config.py create mode 100644 lm_engine/hf_models/models/diffusion/main.py create mode 100644 lm_engine/model_wrapper/pretraining_diffusion.py diff --git a/lm_engine/enums.py b/lm_engine/enums.py index 66ee05fa2..6abb24823 100644 --- a/lm_engine/enums.py +++ b/lm_engine/enums.py @@ -33,6 +33,7 @@ class TuningMethod(Enum): """training method""" pretraining = "pretraining" + pretraining_diffusion = "pretraining_diffusion" full_finetuning = "full_finetuning" distillation = "distillation" diff --git a/lm_engine/hf_models/models/__init__.py b/lm_engine/hf_models/models/__init__.py index ecf78511d..346d969fe 100644 --- a/lm_engine/hf_models/models/__init__.py +++ b/lm_engine/hf_models/models/__init__.py @@ -13,3 +13,4 @@ from .ladder_residual import LadderResidualConfig, LadderResidualForCausalLM, LadderResidualModel from .ladder_residual_TP import LadderResidualForCausalLM_TP, LadderResidualModel_TP from .palm import PaLMConfig, PaLMForCausalLM, PaLMModel +from .diffusion import DiffusionConfig, DiffusionMaskedLM, DiffusionModel \ No newline at end of file diff --git a/lm_engine/hf_models/models/diffusion/__init__.py b/lm_engine/hf_models/models/diffusion/__init__.py new file mode 100644 index 000000000..5c95e6bad --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/__init__.py @@ -0,0 +1,7 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from .base import DiffusionModel +from .config import DiffusionConfig +from .main import DiffusionMaskedLM diff --git a/lm_engine/hf_models/models/diffusion/base.py b/lm_engine/hf_models/models/diffusion/base.py new file mode 100644 index 000000000..f4a055c44 --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/base.py @@ -0,0 +1,13 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from ...mixins import BaseModelMixin, PreTrainedModelMixin +from .config import DiffusionConfig + + +class DiffusionPreTrainedModel(PreTrainedModelMixin): + config_class = DiffusionConfig + + +class DiffusionModel(DiffusionPreTrainedModel, BaseModelMixin): ... 
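For orientation, the classes above follow the repo's usual three-layer pattern: a config subclass, a bare transformer, and a task head added in main.py further down. A minimal construction sketch, assuming the vocab_size / hidden_size / num_layers fields inherited from GPTBaseConfig and their defaults; the values are placeholders, not part of this commit:

    # Hypothetical usage of the classes introduced in this patch.
    from lm_engine.hf_models.models.diffusion import DiffusionConfig, DiffusionMaskedLM

    config = DiffusionConfig(vocab_size=50304, hidden_size=768, num_layers=12)  # assumed GPTBaseConfig fields
    model = DiffusionMaskedLM(config)  # bare DiffusionModel plus an (optionally tied) LM head
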
diff --git a/lm_engine/hf_models/models/diffusion/config.py b/lm_engine/hf_models/models/diffusion/config.py new file mode 100644 index 000000000..ab1e1c79a --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/config.py @@ -0,0 +1,8 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from ..gpt_base import GPTBaseConfig + +class DiffusionConfig(GPTBaseConfig): + model_type = "diffusion" diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py new file mode 100644 index 000000000..458ccd688 --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -0,0 +1,167 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from ...mixins import CausalLMModelMixin +from .base import DiffusionModel, DiffusionPreTrainedModel +from .config import DiffusionConfig + +import torch +import torch.nn.functional as F +from transformers import GenerationMixin + +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ...cache import GenerationCache +from ...config import CommonConfig +from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero +from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...mixins.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +# from .base import PreTrainedModelMixin + + + +class DiffusionMaskedLM(DiffusionPreTrainedModel): + def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedModel: + super().__init__(config, **kwargs) + + self.router_aux_loss_coef = getattr(config, "router_aux_loss_coef", 0) + self._init_model(config, **kwargs) + + def _init_model(self, config: DiffusionConfig, **kwargs) -> None: + self.transformer = DiffusionModel(config, **kwargs) + + if not self._tied_word_embeddings: + self.lm_head = ParameterizedLinear( + config.hidden_size, config.vocab_size, bias=False, std=config.initializer_range + ) + + self.m_width = config.m_width + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ParameterizedEmbedding: + return self.transformer.wte + + def set_input_embeddings(self, value: ParameterizedEmbedding) -> None: + self.transformer.wte = value + + def get_output_embeddings(self) -> ParameterizedLinear: + return self.transformer.wte if self._tied_word_embeddings else self.lm_head + + def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: + if not self._tied_word_embeddings: + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: GenerationCache | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + reduction: str = "mean", + ) -> CausalLMOutputWithPast: + assert return_dict + assert inputs_embeds is None + + input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + cu_seqlens=cu_seqlens, + 
max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + ) + + if labels is not None: + print(input_ids[:10]) + print(labels[:10]) + exit() + + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> None or (batch_size, key_length) + # ========================================================================================== + + clear_aux_loss() + + transformer_outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = transformer_outputs.last_hidden_state + past_key_values = transformer_outputs.past_key_values + del transformer_outputs + + lm_logits = None + loss = None + + if labels is None: + if is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute): + if self.m_width is not None: + hidden_states = hidden_states / self.m_width + else: + lm_logits = self.get_lm_logits(hidden_states) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + else: + assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + + lm_logits = self.get_lm_logits(hidden_states) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + loss = get_autoregressive_language_modeling_loss( + lm_logits=lm_logits, + labels=labels, + hidden_states=None, + vocab_weight=None, + cu_seqlens=cu_seqlens, + use_padding_free_transformer=self.use_padding_free_transformer, + reduction=reduction, + shift_logits_and_labels=True, + tensor_parallel_enabled=False, + ) + + aux_loss = get_aux_loss() + + if loss is not None and not is_aux_loss_zero(aux_loss): + loss = loss + self.router_aux_loss_coef * aux_loss + + return CausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=lm_logits, + past_key_values=past_key_values, + last_hidden_state=hidden_states, + ) + + def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + return ( + F.linear(hidden_states, self.transformer.wte.weight) + if self._tied_word_embeddings + else self.lm_head(hidden_states) + ) diff --git a/lm_engine/hf_models/register_hf.py b/lm_engine/hf_models/register_hf.py index c54d59393..1246191de 100644 --- a/lm_engine/hf_models/register_hf.py +++ b/lm_engine/hf_models/register_hf.py @@ -19,6 +19,9 @@ PaLMConfig, PaLMForCausalLM, PaLMModel, + DiffusionConfig, + DiffusionMaskedLM, + DiffusionModel ) @@ -28,6 +31,7 @@ (GPTCrossLayerConfig, GPTCrossLayerModel, GPTCrossLayerForCausalLM), (LadderResidualConfig, LadderResidualModel, LadderResidualForCausalLM), (PaLMConfig, PaLMModel, PaLMForCausalLM), + (DiffusionConfig, DiffusionModel, DiffusionMaskedLM) ] _CUSTOM_MODEL_TYPES = [] _CUSTOM_MODEL_CLASSES = [] diff --git a/lm_engine/model_wrapper/__init__.py b/lm_engine/model_wrapper/__init__.py index 12615bdde..0370807d2 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -11,11 +11,13 @@ from .distillation import ModelWrapperForDistillation from .finetuning import ModelWrapperForFinetuning from .pretraining import ModelWrapperForPretraining +from .pretraining_diffusion import ModelWrapperForPretrainingDiffusion from .utils import 
broadcast_tensor_parallel_input _MODEL_CLASS_MAPPING = { TuningMethod.pretraining: ModelWrapperForPretraining, + TuningMethod.pretraining_diffusion: ModelWrapperForPretrainingDiffusion, TuningMethod.full_finetuning: ModelWrapperForFinetuning, TuningMethod.distillation: ModelWrapperForDistillation, } @@ -49,7 +51,7 @@ def get_model_container( } # pretraining model wrapper needs some extra arguments for initialization - if tuning_method in [TuningMethod.pretraining, TuningMethod.distillation]: + if tuning_method in [TuningMethod.pretraining, TuningMethod.distillation, TuningMethod.pretraining_diffusion]: kwargs["micro_batch_size"] = args.training_parameters.micro_batch_size kwargs["sequence_length"] = args.datasets[0].class_args.get("sequence_length") kwargs["reset_attention_mask"] = args.model_args.reset_attention_mask diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py new file mode 100644 index 000000000..9f6a1fb5a --- /dev/null +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -0,0 +1,190 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +from torch.distributed._tensor.placement_types import Replicate +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from ..dtensors import tensor_to_dtensor +from ..enums import Kernel, Mode +from ..hf_models import ( + CausalLMOutputWithPast, + PipelineParallelInput, + PipelineParallelOutput, + get_autoregressive_language_modeling_loss, + is_aux_loss_zero, +) +from ..kernels import is_kernel_allowed +from ..utils import MetricsTrackingDict, ProcessGroupManager +from .base import ModelWrapper +from .utils import broadcast_tensor_parallel_input +from .pretraining import ModelWrapperForPretraining + + +class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): + + def forward( + self, + batch: dict | torch.Tensor, + aux_loss_from_pipeline_parallel: torch.Tensor | float = 0, + lm_loss_multiplier: float = 1, + ) -> dict: + """forward function for a batch + + Args: + batch (dict): a dict of key, value pairs for a batch + + Returns: + torch.Tensor: loss tensor + """ + + # for pretraining we compute loss externally here instead of relying on transformers. + # this is done because megatron's dataset returns batches of length (sequence_length + 1) + # instead of (sequence_length), so we need to trim the input_ids before forward pass. + # transformers does forward pass before however and then trims the tokens. 
+ + if not self.is_custom_model: + assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + + if isinstance(batch, torch.Tensor): + batch = {"text": batch} + + if self.is_pipeline_parallel_enabled: + batch["aux_loss_from_pipeline_parallel"] = aux_loss_from_pipeline_parallel + else: + assert aux_loss_from_pipeline_parallel == 0 + + batch = self._prepare_model_inputs(batch) + labels = batch.pop("labels") + output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) + + if self.is_pipeline_parallel_enabled: + # aux_loss is returned as a 0 dimensional tensor + aux_loss = output.aux_loss + use_aux_loss = not is_aux_loss_zero(aux_loss) + + if use_aux_loss and aux_loss.dim() == 0: + aux_loss = aux_loss.unsqueeze(0) + + if self.is_last_stage: + assert isinstance(output, CausalLMOutputWithPast) + output = output.logits + else: + assert isinstance(output, PipelineParallelOutput) + output = output.hidden_states + + if use_aux_loss: + output = (output, aux_loss) + else: + output = self.get_loss(output, labels, lm_loss_multiplier=lm_loss_multiplier) + + return output + + def get_loss( + self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, lm_loss_multiplier: float = 1 + ) -> torch.Tensor | dict: + tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + + lm_loss = get_autoregressive_language_modeling_loss( + lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, + labels=labels, + hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, + vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, + cu_seqlens=None, + use_padding_free_transformer=self.use_padding_free_transformer, + reduction="sum", + shift_logits_and_labels=False, + tensor_parallel_enabled=tensor_parallel_enabled, + ) + + lm_loss = lm_loss * lm_loss_multiplier + aux_loss = getattr(model_outputs, "aux_loss", 0) + + if is_aux_loss_zero(aux_loss): + loss = lm_loss + output = {"loss": loss, "lm_loss": loss} + else: + if self.is_pipeline_parallel_enabled: + self._extra_metrics = self._extra_metrics + {"aux_loss": aux_loss} + + if tensor_parallel_enabled: + aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate()) + + loss = _F.apply(lm_loss, aux_loss, self.router_aux_loss_coef) + output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss} + + return output + + def _prepare_model_inputs(self, batch: dict) -> dict: + if self.is_pipeline_parallel_enabled: + # when using pipeline parallel, we broadcast the input outside the model function + tokens = batch["text"] + aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] + + tokens = tokens.to(torch.cuda.current_device()) + + if self.is_first_stage: + input_ids = tokens[:, :-1] + pipeline_parallel_input = None + else: + input_ids = None + pipeline_parallel_input = PipelineParallelInput( + hidden_states=tokens, aux_loss=aux_loss_from_pipeline_parallel + ) + + batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} + else: + if ProcessGroupManager.is_tensor_parallel_enabled(): + tokens = broadcast_tensor_parallel_input( + None if batch is None else batch["text"], (self.micro_batch_size, self.sequence_length + 1) + ) + else: + tokens = batch["text"] + tokens = tokens.to(torch.cuda.current_device()) + + input_ids = 
tokens[:, :-1] + batch = {"labels": tokens[:, 1:]} + + if self.use_padding_free_transformer: + batch_size, sequence_length = input_ids.shape + input_ids = input_ids.reshape(-1) + + if self.reset_attention_mask: + num_tokens_in_batch = batch_size * sequence_length + + document_end_positions = input_ids == self.eos_token_id + for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): + document_end_positions[i] = 1 + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) + cu_seqlens = cu_seqlens.to(torch.int32) + + seqlen = cu_seqlens[1:] - cu_seqlens[:-1] + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() + + if self.reset_position_ids: + position_ids = torch.cat( + [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] + ) + else: + position_ids = self.position_ids + else: + cu_seqlens = self.cu_seqlens + max_seqlen = self.sequence_length + position_ids = self.position_ids + + batch["cu_seqlens"] = cu_seqlens + batch["max_seqlen"] = max_seqlen + batch["position_ids"] = position_ids + + batch["input_ids"] = input_ids + + if ProcessGroupManager.is_tensor_parallel_enabled(): + batch["output_parallel_lm_logits"] = True + + return batch diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index ca9ed1a4f..31740265b 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -576,7 +576,8 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No if args_class == TrainingArgs: assert ( - args.tuning_args.tuning_method == TuningMethod.pretraining + args.tuning_args.tuning_method == TuningMethod.pretraining or + args.tuning_args.tuning_method == TuningMethod.pretraining_diffusion ), f"unexpected tuning method ({args.tuning_args.tuning_method})" elif args_class == DistillationArgs: assert args.distributed_args.fsdp_algorithm == 2, "Distillation is only supported with FSDP-2" From 59d19951e816a2e08ed9a46623bc7e3509d3552f Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Wed, 16 Jul 2025 22:57:40 +0000 Subject: [PATCH 02/19] Working maksed modelling --- lm_engine/hf_models/config/__init__.py | 1 - lm_engine/hf_models/config/sequence_mixer.py | 1 + lm_engine/hf_models/mixins/dense/layer.py | 2 + .../sequence_mixer_blocks/__init__.py | 2 +- lm_engine/model_wrapper/__init__.py | 3 + .../model_wrapper/pretraining_diffusion.py | 95 +++++++++++++------ 6 files changed, 72 insertions(+), 32 deletions(-) diff --git a/lm_engine/hf_models/config/__init__.py b/lm_engine/hf_models/config/__init__.py index fe76d8fae..6f7ebe87b 100644 --- a/lm_engine/hf_models/config/__init__.py +++ b/lm_engine/hf_models/config/__init__.py @@ -215,7 +215,6 @@ def _set_sequence_mixer_blocks(self) -> None: sequence_mixer_block["intermediate_size"] = sequence_mixer_block.pop( "intermediate_size", 2 * self.hidden_size ) - sequence_mixer_blocks.append(_SEQUENCE_MIXER_CONFIG_CLASSES[sequence_mixer_type](**sequence_mixer_block)) self.sequence_mixer_blocks = sequence_mixer_blocks diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index da103eaed..d8b3396d0 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -14,6 +14,7 @@ class _SoftmaxAttentionArgs(BaseArgs): softmax_dropout: float = 0 dropout: float = 0 add_bias: bool = True + causal: bool = True attention_multiplier: 
float | None = None def model_post_init(self, __context: Any) -> None: diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index a0c3541cd..90fc34360 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -25,7 +25,9 @@ def __init__( self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) + self.sequence_mixer = get_sequence_mixer(config, True, use_padding_free_transformer, layer_idx) + self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index dfdfdfd32..901e457b2 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -126,7 +126,7 @@ def get_sequence_mixer( initializer_range=config.initializer_range, m_width=config.m_width, num_layers=config.num_layers, - causal=causal, + causal=causal if not hasattr(block, "causal") else block.causal, layer_idx=layer_idx, ) diff --git a/lm_engine/model_wrapper/__init__.py b/lm_engine/model_wrapper/__init__.py index 0370807d2..7880cf714 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -57,6 +57,9 @@ def get_model_container( kwargs["reset_attention_mask"] = args.model_args.reset_attention_mask kwargs["reset_position_ids"] = args.model_args.reset_position_ids + if tuning_method == TuningMethod.pretraining_diffusion: + print(args.model_args) + if tuning_method == TuningMethod.distillation: kwargs["teacher_model_name"] = args.teacher_args.model_name kwargs["teacher_model_class"] = args.teacher_args.model_class diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 9f6a1fb5a..c2511dc23 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -6,6 +6,7 @@ import torch from torch.distributed._tensor.placement_types import Replicate +from torch.nn import functional as F from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM from ..dtensors import tensor_to_dtensor @@ -20,8 +21,11 @@ from ..kernels import is_kernel_allowed from ..utils import MetricsTrackingDict, ProcessGroupManager from .base import ModelWrapper +from .pretraining import _F, ModelWrapperForPretraining from .utils import broadcast_tensor_parallel_input -from .pretraining import ModelWrapperForPretraining + + +FIM_MIDDLE = "" class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): @@ -87,19 +91,30 @@ def get_loss( self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, lm_loss_multiplier: float = 1 ) -> torch.Tensor | dict: tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() - use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) - - lm_loss = get_autoregressive_language_modeling_loss( - lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, - labels=labels, - hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, - vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, - cu_seqlens=None, - 
use_padding_free_transformer=self.use_padding_free_transformer, - reduction="sum", - shift_logits_and_labels=False, - tensor_parallel_enabled=tensor_parallel_enabled, + # use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + flat_logits = model_outputs.logits.flatten(0, -2) + flat_labels = labels.flatten() + # print(flat_logits.size(), flat_labels.size()) + lm_loss = ( + F.cross_entropy( + input=flat_logits, + target=flat_labels, + ignore_index=self.mask_token_id, + reduction="mean", + ) + * flat_labels.numel() ) + # lm_loss = get_autoregressive_language_modeling_loss( + # lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, + # labels=labels, + # hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, + # vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, + # cu_seqlens=None, + # use_padding_free_transformer=self.use_padding_free_transformer, + # reduction="sum", + # shift_logits_and_labels=False, + # tensor_parallel_enabled=tensor_parallel_enabled, + # ) lm_loss = lm_loss * lm_loss_multiplier aux_loss = getattr(model_outputs, "aux_loss", 0) @@ -119,24 +134,43 @@ def get_loss( return output - def _prepare_model_inputs(self, batch: dict) -> dict: - if self.is_pipeline_parallel_enabled: - # when using pipeline parallel, we broadcast the input outside the model function - tokens = batch["text"] - aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] + def _setup_tokenizer(self) -> None: + super()._setup_tokenizer() + # self.mask_token_id = self.tokenizer.mask_token_id + self.mask_token_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) + assert self.mask_token_id is not None - tokens = tokens.to(torch.cuda.current_device()) + def _forward_process(self, input_ids, eps=1e-3): + b, l = input_ids.shape + t = torch.rand(b, device=input_ids.device) + p_mask = (1 - eps) * t + eps + p_mask = p_mask[:, None].repeat(1, l) - if self.is_first_stage: - input_ids = tokens[:, :-1] - pipeline_parallel_input = None - else: - input_ids = None - pipeline_parallel_input = PipelineParallelInput( - hidden_states=tokens, aux_loss=aux_loss_from_pipeline_parallel - ) + masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask + print(masked_indices.int().sum() / masked_indices.numel()) + noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) + labels = torch.where(~masked_indices, self.mask_token_id, input_ids) + return noisy_batch, labels, p_mask - batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} + def _prepare_model_inputs(self, batch: dict) -> dict: + if self.is_pipeline_parallel_enabled: + raise NotImplementedError("No pipeline for diffusion yet.") + # # when using pipeline parallel, we broadcast the input outside the model function + # tokens = batch["text"] + # aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] + + # tokens = tokens.to(torch.cuda.current_device()) + + # if self.is_first_stage: + # input_ids = tokens + # pipeline_parallel_input = None + # else: + # input_ids = None + # pipeline_parallel_input = PipelineParallelInput( + # hidden_states=tokens, aux_loss=aux_loss_from_pipeline_parallel + # ) + + # batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} else: if ProcessGroupManager.is_tensor_parallel_enabled(): tokens = broadcast_tensor_parallel_input( @@ -146,8 +180,9 @@ def _prepare_model_inputs(self, 
batch: dict) -> dict: tokens = batch["text"] tokens = tokens.to(torch.cuda.current_device()) - input_ids = tokens[:, :-1] - batch = {"labels": tokens[:, 1:]} + unnoised_input_ids = tokens[:, :-1] + input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) + batch = {"labels": labels} if self.use_padding_free_transformer: batch_size, sequence_length = input_ids.shape From 9c96fbac93cc6d6048b3f4da5ebfd903844e41db Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Thu, 17 Jul 2025 16:33:03 +0000 Subject: [PATCH 03/19] Extra comment. --- lm_engine/hf_models/models/diffusion/main.py | 18 ++++++------------ .../model_wrapper/pretraining_diffusion.py | 5 ++--- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index 458ccd688..5a688143d 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -2,10 +2,6 @@ # Copyright (c) 2025, Mayank Mishra # ************************************************** -from ...mixins import CausalLMModelMixin -from .base import DiffusionModel, DiffusionPreTrainedModel -from .config import DiffusionConfig - import torch import torch.nn.functional as F from transformers import GenerationMixin @@ -15,11 +11,15 @@ from ...cache import GenerationCache from ...config import CommonConfig from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero -from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from ...mixins import CausalLMModelMixin from ...mixins.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -# from .base import PreTrainedModelMixin +from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from .base import DiffusionModel, DiffusionPreTrainedModel +from .config import DiffusionConfig +# from .base import PreTrainedModelMixin + class DiffusionMaskedLM(DiffusionPreTrainedModel): def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedModel: @@ -81,12 +81,6 @@ def forward( attention_mask=attention_mask, use_cache=use_cache, ) - - if labels is not None: - print(input_ids[:10]) - print(labels[:10]) - exit() - # ========================================================================================== # padding_free: # input_ids -> (total_q) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index c2511dc23..43642fd39 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -136,7 +136,7 @@ def get_loss( def _setup_tokenizer(self) -> None: super()._setup_tokenizer() - # self.mask_token_id = self.tokenizer.mask_token_id + # TODO (shawntan) Use FIM token for now. Figure out if there is a way to have actual mask token. 
self.mask_token_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) assert self.mask_token_id is not None @@ -147,7 +147,6 @@ def _forward_process(self, input_ids, eps=1e-3): p_mask = p_mask[:, None].repeat(1, l) masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask - print(masked_indices.int().sum() / masked_indices.numel()) noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) labels = torch.where(~masked_indices, self.mask_token_id, input_ids) return noisy_batch, labels, p_mask @@ -180,7 +179,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: tokens = batch["text"] tokens = tokens.to(torch.cuda.current_device()) - unnoised_input_ids = tokens[:, :-1] + unnoised_input_ids = tokens[:, 1:] input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) batch = {"labels": labels} From 84bceeb152e3ab6f8cb23e108639b2ba2cd1cc89 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 18 Jul 2025 02:56:55 +0000 Subject: [PATCH 04/19] Wrong weighting. --- .../model_wrapper/pretraining_diffusion.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 43642fd39..466082fd6 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -63,6 +63,7 @@ def forward( batch = self._prepare_model_inputs(batch) labels = batch.pop("labels") + p_mask = batch.pop("p_mask") output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) if self.is_pipeline_parallel_enabled: @@ -83,12 +84,16 @@ def forward( if use_aux_loss: output = (output, aux_loss) else: - output = self.get_loss(output, labels, lm_loss_multiplier=lm_loss_multiplier) + output = self.get_loss(output, labels, p_mask, lm_loss_multiplier=lm_loss_multiplier) return output def get_loss( - self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, lm_loss_multiplier: float = 1 + self, + model_outputs: CausalLMOutputWithPast, + labels: torch.Tensor, + p_mask: torch.Tensor, + lm_loss_multiplier: float = 1, ) -> torch.Tensor | dict: tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() # use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) @@ -100,10 +105,11 @@ def get_loss( input=flat_logits, target=flat_labels, ignore_index=self.mask_token_id, - reduction="mean", + reduction="none", ) - * flat_labels.numel() - ) + / p_mask.flatten() + ).sum() + # lm_loss = get_autoregressive_language_modeling_loss( # lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, # labels=labels, @@ -181,7 +187,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: unnoised_input_ids = tokens[:, 1:] input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) - batch = {"labels": labels} + batch = {"labels": labels, "p_mask": p_mask} if self.use_padding_free_transformer: batch_size, sequence_length = input_ids.shape From 65a4481236f09f1424963c218c714af5c0c5c431 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 18 Jul 2025 03:23:19 +0000 Subject: [PATCH 05/19] Updated without mode. 
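
The weighting fixed in the previous patch follows the usual masked-diffusion estimator: each masked position contributes its cross-entropy divided by the masking probability drawn for its sequence, while unmasked positions are excluded through the ignore index. A minimal standalone sketch of that objective, assuming flattened (N, vocab) logits and a per-token p_mask as in the wrapper; this is an illustration, not the engine's exact code path:

    import torch.nn.functional as F

    def masked_diffusion_loss(flat_logits, flat_labels, p_mask, ignore_index):
        # per-token cross-entropy; positions whose label equals ignore_index
        # (the unmasked ones) contribute zero
        per_token = F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index, reduction="none")
        # importance-weight each masked position by 1 / p_mask and sum
        return (per_token / p_mask).sum()
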
--- cute-kernels | 2 +- lm_engine/model_wrapper/pretraining_diffusion.py | 2 +- lm_engine/unshard.py | 4 ++++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cute-kernels b/cute-kernels index a7d23d004..265584c61 160000 --- a/cute-kernels +++ b/cute-kernels @@ -1 +1 @@ -Subproject commit a7d23d0047ad3b68eed5d6be18bb21bc3eaaab1c +Subproject commit 265584c615a5acae52b68102667009eca87c70d6 diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 466082fd6..5126bedce 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -10,7 +10,7 @@ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM from ..dtensors import tensor_to_dtensor -from ..enums import Kernel, Mode +from ..enums import Kernel from ..hf_models import ( CausalLMOutputWithPast, PipelineParallelInput, diff --git a/lm_engine/unshard.py b/lm_engine/unshard.py index e4f3424d8..2b2246cc0 100644 --- a/lm_engine/unshard.py +++ b/lm_engine/unshard.py @@ -12,7 +12,11 @@ def main() -> None: args = get_args(UnshardingArgs) +<<<<<<< HEAD model, _, state_dict = load_checkpoint_and_unshard(args) +======= + model, _, state_dict = load_checkpoint_and_unshard(args, allowed_meta_device=True) +>>>>>>> d2b9d0c (Updated without mode.) run_rank_n(model.save_pretrained, barrier=ProcessGroupManager.is_initialized())( args.unsharded_path, state_dict=state_dict ) From 895036e535e509f2a776779ef8e374ac4478a876 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 18 Jul 2025 21:10:36 +0000 Subject: [PATCH 06/19] Padding-free and attention reset training with diffusion. --- lm_engine/hf_models/models/diffusion/main.py | 2 - .../model_wrapper/pretraining_diffusion.py | 58 ++++++++++++------- lm_engine/unshard.py | 4 -- 3 files changed, 37 insertions(+), 27 deletions(-) diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index 5a688143d..86e4780a3 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -91,7 +91,6 @@ def forward( # attention_mask -> None or (batch_size, key_length) # position_ids -> None or (batch_size, key_length) # ========================================================================================== - clear_aux_loss() transformer_outputs: BaseModelOutputWithPast = self.transformer( @@ -139,7 +138,6 @@ def forward( shift_logits_and_labels=True, tensor_parallel_enabled=False, ) - aux_loss = get_aux_loss() if loss is not None and not is_aux_loss_zero(aux_loss): diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 5126bedce..d20998d17 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -160,22 +160,22 @@ def _forward_process(self, input_ids, eps=1e-3): def _prepare_model_inputs(self, batch: dict) -> dict: if self.is_pipeline_parallel_enabled: raise NotImplementedError("No pipeline for diffusion yet.") - # # when using pipeline parallel, we broadcast the input outside the model function - # tokens = batch["text"] - # aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] - - # tokens = tokens.to(torch.cuda.current_device()) - - # if self.is_first_stage: - # input_ids = tokens - # pipeline_parallel_input = None - # else: - # input_ids = None - # pipeline_parallel_input = PipelineParallelInput( - # hidden_states=tokens, 
aux_loss=aux_loss_from_pipeline_parallel - # ) - - # batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} + # when using pipeline parallel, we broadcast the input outside the model function + # tokens = batch["text"] + # aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] + + # tokens = tokens.to(torch.cuda.current_device()) + + # if self.is_first_stage: + # input_ids = tokens + # pipeline_parallel_input = None + # else: + # input_ids = None + # pipeline_parallel_input = PipelineParallelInput( + # hidden_states=tokens, aux_loss=aux_loss_from_pipeline_parallel + # ) + + # batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} else: if ProcessGroupManager.is_tensor_parallel_enabled(): tokens = broadcast_tensor_parallel_input( @@ -190,13 +190,12 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch = {"labels": labels, "p_mask": p_mask} if self.use_padding_free_transformer: - batch_size, sequence_length = input_ids.shape - input_ids = input_ids.reshape(-1) - + batch_size, sequence_length = unnoised_input_ids.shape + input_ids = input_ids.flatten() if self.reset_attention_mask: + unnoised_input_ids = unnoised_input_ids.flatten() num_tokens_in_batch = batch_size * sequence_length - - document_end_positions = input_ids == self.eos_token_id + document_end_positions = unnoised_input_ids == self.eos_token_id for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): document_end_positions[i] = 1 cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 @@ -221,6 +220,23 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch["cu_seqlens"] = cu_seqlens batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids + # if cu_seqlens.size(0) > 5: + # from transformers import PreTrainedTokenizer + # tokenizer: PreTrainedTokenizer = self.tokenizer + # for i in range(cu_seqlens.size(0) - 1): + # seq_in = input_ids.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] + # print(' '.join([ + # c if idx != self.mask_token_id else "_" + # for idx, c in zip(seq_in, tokenizer.convert_ids_to_tokens(seq_in)) + # ])) + # seq_out = labels.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] + # print(' '.join([ + # c if idx != self.mask_token_id else "_" + # for idx, c in zip(seq_out, tokenizer.convert_ids_to_tokens(seq_out)) + # ])) + # print(cu_seqlens.size()) + # print((unnoised_input_ids == self.eos_token_id).int().sum()) + # exit() batch["input_ids"] = input_ids diff --git a/lm_engine/unshard.py b/lm_engine/unshard.py index 2b2246cc0..e4f3424d8 100644 --- a/lm_engine/unshard.py +++ b/lm_engine/unshard.py @@ -12,11 +12,7 @@ def main() -> None: args = get_args(UnshardingArgs) -<<<<<<< HEAD model, _, state_dict = load_checkpoint_and_unshard(args) -======= - model, _, state_dict = load_checkpoint_and_unshard(args, allowed_meta_device=True) ->>>>>>> d2b9d0c (Updated without mode.) run_rank_n(model.save_pretrained, barrier=ProcessGroupManager.is_initialized())( args.unsharded_path, state_dict=state_dict ) From 1dae7782fea326bdfaeb2261ea18dbecb7c8d89b Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 20 Jul 2025 03:29:59 +0000 Subject: [PATCH 07/19] Randomly replacing sequences with pad. 
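
Roughly, the change below walks the packed sub-sequences and, with 1% probability, "deletes" one that does not start a sample: masked input positions keep the mask token, everything else becomes the pad token, and the labels at the masked positions are set to pad, presumably so the model learns to emit padding for the dropped span. A simplified per-span sketch of that replacement, with the boundary bookkeeping and the merge into the preceding sequence left to the diff itself:

    import torch

    def drop_span_to_pad(input_ids, labels, start_idx, end_idx, mask_token_id, pad_token_id):
        # operates in place on flattened (total_tokens,) input_ids / labels
        span = slice(start_idx, end_idx)
        is_masked = input_ids[span] == mask_token_id
        # keep mask tokens in the input, pad out the rest
        input_ids[span] = torch.where(is_masked, input_ids[span], pad_token_id)
        # train masked positions toward pad; unmasked positions keep their ignore value
        labels[span] = torch.where(~is_masked, labels[span], pad_token_id)
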
--- .../model_wrapper/pretraining_diffusion.py | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index d20998d17..881eabbe2 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -145,6 +145,8 @@ def _setup_tokenizer(self) -> None: # TODO (shawntan) Use FIM token for now. Figure out if there is a way to have actual mask token. self.mask_token_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) assert self.mask_token_id is not None + self.pad_token_id = self.tokenizer.pad_token_id + assert self.pad_token_id is not None def _forward_process(self, input_ids, eps=1e-3): b, l = input_ids.shape @@ -184,21 +186,42 @@ def _prepare_model_inputs(self, batch: dict) -> dict: else: tokens = batch["text"] tokens = tokens.to(torch.cuda.current_device()) - unnoised_input_ids = tokens[:, 1:] input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) batch = {"labels": labels, "p_mask": p_mask} if self.use_padding_free_transformer: + batch_size, sequence_length = unnoised_input_ids.shape input_ids = input_ids.flatten() + flat_labels = labels.view(-1) if self.reset_attention_mask: unnoised_input_ids = unnoised_input_ids.flatten() - num_tokens_in_batch = batch_size * sequence_length + batch_size * sequence_length document_end_positions = unnoised_input_ids == self.eos_token_id - for i in range(sequence_length - 1, num_tokens_in_batch, sequence_length): - document_end_positions[i] = 1 + # Add the end token for the end of sample also + document_end_positions[sequence_length - 1 :: sequence_length] = 1 cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + + # Randomly delete shorter sequences 1% of the time to train + deleted = False + for i in range(cu_seqlens.size(0) - 1): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + if start_idx % sequence_length != 0: # make sure not start of sample + if torch.rand(1) < 0.01: + deleted = True + document_end_positions[start_idx - 1] = 0 + token_mask = input_ids[start_idx:end_idx] == self.mask_token_id + input_ids[start_idx:end_idx] = torch.where( + token_mask, input_ids[start_idx:end_idx], self.pad_token_id + ) + flat_labels[start_idx:end_idx] = torch.where( + ~token_mask, flat_labels[start_idx:end_idx], self.pad_token_id + ) + if deleted: + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) cu_seqlens = cu_seqlens.to(torch.int32) @@ -220,7 +243,8 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch["cu_seqlens"] = cu_seqlens batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids - # if cu_seqlens.size(0) > 5: + + # if deleted: # from transformers import PreTrainedTokenizer # tokenizer: PreTrainedTokenizer = self.tokenizer # for i in range(cu_seqlens.size(0) - 1): @@ -235,7 +259,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: # for idx, c in zip(seq_out, tokenizer.convert_ids_to_tokens(seq_out)) # ])) # print(cu_seqlens.size()) - # print((unnoised_input_ids == self.eos_token_id).int().sum()) + # # print((unnoised_input_ids == self.eos_token_id).int().sum()) # exit() batch["input_ids"] = input_ids From a51acd0c4f4b6833a9a804a09dfe6b7a210528ee Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Sun, 27 Jul 2025 22:40:16 +0000 Subject: [PATCH 08/19] Commented out code that allows variable lengths --- 
lm_engine/hf_models/models/diffusion/base.py | 12 +- lm_engine/hf_models/models/diffusion/main.py | 5 +- .../model_wrapper/pretraining_diffusion.py | 300 +++++++++++------- 3 files changed, 203 insertions(+), 114 deletions(-) diff --git a/lm_engine/hf_models/models/diffusion/base.py b/lm_engine/hf_models/models/diffusion/base.py index f4a055c44..0552d095a 100644 --- a/lm_engine/hf_models/models/diffusion/base.py +++ b/lm_engine/hf_models/models/diffusion/base.py @@ -10,4 +10,14 @@ class DiffusionPreTrainedModel(PreTrainedModelMixin): config_class = DiffusionConfig -class DiffusionModel(DiffusionPreTrainedModel, BaseModelMixin): ... +class DiffusionModel(DiffusionPreTrainedModel, BaseModelMixin): + def __init__(self, config, **kwargs): + if "mask_token_id" in kwargs: + self.mask_token_id = kwargs.pop("mask_token_id") + super().__init__(config, **kwargs) + + def _get_initial_hidden_state(self, input_ids, position_ids): + hidden_state = super()._get_initial_hidden_state(input_ids, position_ids) + # mask = (input_ids == self.mask_token_id)[:, None] + # hidden_state = hidden_state.masked_fill_(mask, 0) + return hidden_state diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index 86e4780a3..af01e9b6a 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -23,12 +23,15 @@ class DiffusionMaskedLM(DiffusionPreTrainedModel): def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedModel: + if "mask_token_id" in kwargs: + self.mask_token_id = kwargs.pop("mask_token_id") super().__init__(config, **kwargs) - self.router_aux_loss_coef = getattr(config, "router_aux_loss_coef", 0) self._init_model(config, **kwargs) def _init_model(self, config: DiffusionConfig, **kwargs) -> None: + if hasattr(self, "mask_token_id"): + kwargs["mask_token_id"] = self.mask_token_id self.transformer = DiffusionModel(config, **kwargs) if not self._tied_word_embeddings: diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 881eabbe2..b6984501d 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -29,6 +29,52 @@ class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): + def __init__( + self, + model_name: str | None, + pretrained_config: dict | None, + model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, + dtype: torch.dtype, + efficient_initialization: bool, + use_padding_free_transformer: bool, + sequence_parallel: bool, + micro_batch_size: int, + sequence_length: int, + num_pipeline_stages: int, + pipeline_stage_id: int, + trust_remote_code: bool = False, + tokenizer_name: str | None = None, + additional_special_tokens: list[str] | None = None, + reset_attention_mask: bool = False, + reset_position_ids: bool = False, + keep_in_fp32: bool = True, + ) -> ModelWrapperForPretraining: + super().__init__( + model_name, + pretrained_config, + model_class, + dtype, + efficient_initialization, + use_padding_free_transformer, + sequence_parallel, + micro_batch_size, + sequence_length, + num_pipeline_stages, + pipeline_stage_id, + trust_remote_code, + tokenizer_name, + additional_special_tokens, + reset_attention_mask, + reset_position_ids, + keep_in_fp32, + ) + assert self.use_padding_free_transformer and self.reset_attention_mask + + def _get_model_kwargs(self): + kwargs = super()._get_model_kwargs() + if hasattr(self, "mask_token_id"): + 
kwargs["mask_token_id"] = self.mask_token_id + return kwargs def forward( self, @@ -104,24 +150,12 @@ def get_loss( F.cross_entropy( input=flat_logits, target=flat_labels, - ignore_index=self.mask_token_id, + ignore_index=self.ignore_token_id, reduction="none", ) / p_mask.flatten() ).sum() - # lm_loss = get_autoregressive_language_modeling_loss( - # lm_logits=None if use_fused_linear_cross_entropy_kernel else model_outputs.logits, - # labels=labels, - # hidden_states=model_outputs.last_hidden_state if use_fused_linear_cross_entropy_kernel else None, - # vocab_weight=self.model.get_output_embeddings().weight if use_fused_linear_cross_entropy_kernel else None, - # cu_seqlens=None, - # use_padding_free_transformer=self.use_padding_free_transformer, - # reduction="sum", - # shift_logits_and_labels=False, - # tensor_parallel_enabled=tensor_parallel_enabled, - # ) - lm_loss = lm_loss * lm_loss_multiplier aux_loss = getattr(model_outputs, "aux_loss", 0) @@ -147,37 +181,37 @@ def _setup_tokenizer(self) -> None: assert self.mask_token_id is not None self.pad_token_id = self.tokenizer.pad_token_id assert self.pad_token_id is not None + self.ignore_token_id = self.pad_token_id # self.mask_token_id + + # def _forward_process(self, input_ids, eps=1e-3): + # b, l = input_ids.shape + # t = torch.rand(b, device=input_ids.device) + # p_mask = (1 - eps) * t + eps + # p_mask = p_mask[:, None].repeat(1, l) + + # masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask + # noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) + # labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) + # return noisy_batch, labels, p_mask def _forward_process(self, input_ids, eps=1e-3): - b, l = input_ids.shape - t = torch.rand(b, device=input_ids.device) + t = torch.rand(1, device=input_ids.device) p_mask = (1 - eps) * t + eps - p_mask = p_mask[:, None].repeat(1, l) - - masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask + masked_indices = torch.rand_like(input_ids, dtype=t.dtype) < p_mask noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) - labels = torch.where(~masked_indices, self.mask_token_id, input_ids) - return noisy_batch, labels, p_mask + labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) + return noisy_batch, labels, ~masked_indices, p_mask + + def _forward_and_place(self, curr_seq, start_idx, end_idx, input_ids, labels, p_mask): + curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) + input_ids[start_idx:end_idx] = curr_seq_input + labels[start_idx:end_idx] = curr_seq_labels + p_mask[start_idx:end_idx] = curr_seq_p_mask + return curr_seq_input, curr_seq_labels, curr_seq_p_mask def _prepare_model_inputs(self, batch: dict) -> dict: if self.is_pipeline_parallel_enabled: raise NotImplementedError("No pipeline for diffusion yet.") - # when using pipeline parallel, we broadcast the input outside the model function - # tokens = batch["text"] - # aux_loss_from_pipeline_parallel = batch["aux_loss_from_pipeline_parallel"] - - # tokens = tokens.to(torch.cuda.current_device()) - - # if self.is_first_stage: - # input_ids = tokens - # pipeline_parallel_input = None - # else: - # input_ids = None - # pipeline_parallel_input = PipelineParallelInput( - # hidden_states=tokens, aux_loss=aux_loss_from_pipeline_parallel - # ) - - # batch = {"labels": None, "pipeline_parallel_input": pipeline_parallel_input} else: if ProcessGroupManager.is_tensor_parallel_enabled(): tokens = 
broadcast_tensor_parallel_input( @@ -186,83 +220,125 @@ def _prepare_model_inputs(self, batch: dict) -> dict: else: tokens = batch["text"] tokens = tokens.to(torch.cuda.current_device()) - unnoised_input_ids = tokens[:, 1:] - input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) - batch = {"labels": labels, "p_mask": p_mask} - - if self.use_padding_free_transformer: - - batch_size, sequence_length = unnoised_input_ids.shape - input_ids = input_ids.flatten() - flat_labels = labels.view(-1) - if self.reset_attention_mask: - unnoised_input_ids = unnoised_input_ids.flatten() - batch_size * sequence_length - document_end_positions = unnoised_input_ids == self.eos_token_id - # Add the end token for the end of sample also - document_end_positions[sequence_length - 1 :: sequence_length] = 1 - cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - - # Randomly delete shorter sequences 1% of the time to train - deleted = False - for i in range(cu_seqlens.size(0) - 1): - start_idx = cu_seqlens[i] - end_idx = cu_seqlens[i + 1] - if start_idx % sequence_length != 0: # make sure not start of sample - if torch.rand(1) < 0.01: - deleted = True - document_end_positions[start_idx - 1] = 0 - token_mask = input_ids[start_idx:end_idx] == self.mask_token_id - input_ids[start_idx:end_idx] = torch.where( - token_mask, input_ids[start_idx:end_idx], self.pad_token_id - ) - flat_labels[start_idx:end_idx] = torch.where( - ~token_mask, flat_labels[start_idx:end_idx], self.pad_token_id - ) - if deleted: - cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - - cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]) - cu_seqlens = cu_seqlens.to(torch.int32) - - seqlen = cu_seqlens[1:] - cu_seqlens[:-1] - # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers - max_seqlen = seqlen.max().item() - - if self.reset_position_ids: - position_ids = torch.cat( - [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] - ) - else: - position_ids = self.position_ids + if torch.rand(1, device=tokens.device) < 0.5: + unnoised_input_ids = tokens[:, 1:] else: - cu_seqlens = self.cu_seqlens - max_seqlen = self.sequence_length - position_ids = self.position_ids - - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen - batch["position_ids"] = position_ids - - # if deleted: - # from transformers import PreTrainedTokenizer - # tokenizer: PreTrainedTokenizer = self.tokenizer - # for i in range(cu_seqlens.size(0) - 1): - # seq_in = input_ids.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] - # print(' '.join([ - # c if idx != self.mask_token_id else "_" - # for idx, c in zip(seq_in, tokenizer.convert_ids_to_tokens(seq_in)) - # ])) - # seq_out = labels.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] - # print(' '.join([ - # c if idx != self.mask_token_id else "_" - # for idx, c in zip(seq_out, tokenizer.convert_ids_to_tokens(seq_out)) - # ])) - # print(cu_seqlens.size()) - # # print((unnoised_input_ids == self.eos_token_id).int().sum()) - # exit() - - batch["input_ids"] = input_ids + unnoised_input_ids = tokens[:, :-1] + # input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) + # batch = {"labels": labels, "p_mask": p_mask} + batch = {} + + batch_size, sequence_length = unnoised_input_ids.shape + unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length + + input_ids = torch.empty_like(unnoised_input_ids) + labels = torch.empty_like(unnoised_input_ids) + p_mask = 
torch.empty_like(unnoised_input_ids, dtype=torch.bfloat16) + + document_end_positions = unnoised_input_ids == self.eos_token_id + document_end_positions[sequence_length - 1 :: sequence_length] = ( + 1 # Add the end token for the end of sample also + ) + + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + for i in range(cu_seqlens.size(0)): + start_idx = 0 if i == 0 else cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + + # # Sometimes replace boundaries of short sequences with more to encourage learning to earlier. + # orig_start_idx = start_idx + # if i > 0 and not moved_previous: + # prev_start_idx = cu_seqlens[i - 2] + # curr_seq_len = end_idx - start_idx + # prev_seq_len = start_idx - prev_start_idx + # if (unnoised_input_ids[start_idx - 1] == self.eos_token_id and + # prev_seq_len * torch.rand(1, device=unnoised_input_ids.device) > curr_seq_len): # make sure shorter seq + # # Resize previous entry + # assert document_end_positions[end_idx - 1] + # # curr_seq = unnoised_input_ids[start_idx:end_idx] + + # # Extend previous sequence + # prevandcurr_seq = unnoised_input_ids[prev_start_idx:end_idx].clone() + # # Mask out all current sequence + # prevandcurr_seq[start_idx - prev_start_idx:] = self.eos_token_id + # # Renoise previous sequence + # prevandcurr_input, prevandcurr_labels, _, prevandcurr_mask \ + # = self._forward_process(prevandcurr_seq) + + # # Find first input that is EOS + # eos_mask = (prevandcurr_input == self.eos_token_id) + # if eos_mask.any(): + # first_input_eos = eos_mask.nonzero(as_tuple=True)[0].min() + # start_idx = prev_start_idx + first_input_eos + 1 + # document_end_positions[start_idx - 1] = True + # else: + # first_input_eos = prevandcurr_input.size(0) - 1 + # start_idx = end_idx + # input_ids[prev_start_idx:start_idx] = prevandcurr_input[:first_input_eos + 1] + # labels[prev_start_idx:start_idx] = prevandcurr_labels[:first_input_eos + 1] + # p_mask[prev_start_idx:start_idx] = prevandcurr_mask[:first_input_eos + 1] + # moved_previous = True + + # # Move barrier + # document_end_positions[orig_start_idx - 1] = False + # document_end_positions[start_idx - 1] = True + # else: + # moved_previous = False + + # moved_boundary = (start_idx != orig_start_idx) or moved_boundary + + curr_seq = unnoised_input_ids[start_idx:end_idx] + curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) + + input_ids[start_idx:end_idx] = curr_seq_input + labels[start_idx:end_idx] = curr_seq_labels + p_mask[start_idx:end_idx] = curr_seq_p_mask + + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=unnoised_input_ids.device), cu_seqlens]).to(torch.int32) + seqlen = cu_seqlens[1:] - cu_seqlens[:-1] + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() + + if self.reset_position_ids: + position_ids = torch.cat( + [torch.arange(0, i, 1, dtype=torch.int32, device=unnoised_input_ids.device) for i in seqlen] + ) + else: + position_ids = self.position_ids + + batch["input_ids"] = input_ids.flatten() + batch["labels"] = labels.flatten() + batch["p_mask"] = p_mask.flatten() + batch["cu_seqlens"] = cu_seqlens + batch["max_seqlen"] = max_seqlen + batch["position_ids"] = position_ids + + # if moved_boundary: + # from transformers import PreTrainedTokenizer + # tokenizer: PreTrainedTokenizer = self.tokenizer + # def to_token_list(seq): + # output = [] + # for idx in seq: + # if idx == 
self.ignore_token_id: + # c = '' + # elif idx == self.mask_token_id: + # c = '_' + # else: + # c = tokenizer._convert_id_to_token(idx) + # output.append(c) + # return output + # for i in range(cu_seqlens.size(0) - 1): + # seq_in = input_ids.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] + # seq_out = labels.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] + # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + # print() + # print(cu_seqlens[i], cu_seqlens[i+1]) + # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) + # print(cu_seqlens) + # exit() + # else: + # print("No deletions.") if ProcessGroupManager.is_tensor_parallel_enabled(): batch["output_parallel_lm_logits"] = True From 5189e3150194d15a23204a031012e2244d355b37 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 28 Jul 2025 04:15:07 +0000 Subject: [PATCH 09/19] Working. --- .../model_wrapper/pretraining_diffusion.py | 130 +++++++----------- 1 file changed, 51 insertions(+), 79 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index b6984501d..10078f443 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -210,6 +210,7 @@ def _forward_and_place(self, curr_seq, start_idx, end_idx, input_ids, labels, p_ return curr_seq_input, curr_seq_labels, curr_seq_p_mask def _prepare_model_inputs(self, batch: dict) -> dict: + device = torch.cuda.current_device() if self.is_pipeline_parallel_enabled: raise NotImplementedError("No pipeline for diffusion yet.") else: @@ -219,7 +220,7 @@ def _prepare_model_inputs(self, batch: dict) -> dict: ) else: tokens = batch["text"] - tokens = tokens.to(torch.cuda.current_device()) + tokens = tokens.to(device) if torch.rand(1, device=tokens.device) < 0.5: unnoised_input_ids = tokens[:, 1:] else: @@ -229,8 +230,9 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch = {} batch_size, sequence_length = unnoised_input_ids.shape - unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length + assert batch_size % 2 == 0 + unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length input_ids = torch.empty_like(unnoised_input_ids) labels = torch.empty_like(unnoised_input_ids) p_mask = torch.empty_like(unnoised_input_ids, dtype=torch.bfloat16) @@ -240,59 +242,26 @@ def _prepare_model_inputs(self, batch: dict) -> dict: 1 # Add the end token for the end of sample also ) - cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - for i in range(cu_seqlens.size(0)): - start_idx = 0 if i == 0 else cu_seqlens[i - 1] - end_idx = cu_seqlens[i] - - # # Sometimes replace boundaries of short sequences with more to encourage learning to earlier. 
- # orig_start_idx = start_idx - # if i > 0 and not moved_previous: - # prev_start_idx = cu_seqlens[i - 2] - # curr_seq_len = end_idx - start_idx - # prev_seq_len = start_idx - prev_start_idx - # if (unnoised_input_ids[start_idx - 1] == self.eos_token_id and - # prev_seq_len * torch.rand(1, device=unnoised_input_ids.device) > curr_seq_len): # make sure shorter seq - # # Resize previous entry - # assert document_end_positions[end_idx - 1] - # # curr_seq = unnoised_input_ids[start_idx:end_idx] - - # # Extend previous sequence - # prevandcurr_seq = unnoised_input_ids[prev_start_idx:end_idx].clone() - # # Mask out all current sequence - # prevandcurr_seq[start_idx - prev_start_idx:] = self.eos_token_id - # # Renoise previous sequence - # prevandcurr_input, prevandcurr_labels, _, prevandcurr_mask \ - # = self._forward_process(prevandcurr_seq) - - # # Find first input that is EOS - # eos_mask = (prevandcurr_input == self.eos_token_id) - # if eos_mask.any(): - # first_input_eos = eos_mask.nonzero(as_tuple=True)[0].min() - # start_idx = prev_start_idx + first_input_eos + 1 - # document_end_positions[start_idx - 1] = True - # else: - # first_input_eos = prevandcurr_input.size(0) - 1 - # start_idx = end_idx - # input_ids[prev_start_idx:start_idx] = prevandcurr_input[:first_input_eos + 1] - # labels[prev_start_idx:start_idx] = prevandcurr_labels[:first_input_eos + 1] - # p_mask[prev_start_idx:start_idx] = prevandcurr_mask[:first_input_eos + 1] - # moved_previous = True - - # # Move barrier - # document_end_positions[orig_start_idx - 1] = False - # document_end_positions[start_idx - 1] = True - # else: - # moved_previous = False - - # moved_boundary = (start_idx != orig_start_idx) or moved_boundary - - curr_seq = unnoised_input_ids[start_idx:end_idx] - curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) - - input_ids[start_idx:end_idx] = curr_seq_input - labels[start_idx:end_idx] = curr_seq_labels - p_mask[start_idx:end_idx] = curr_seq_p_mask + eps = 1e-3 + + def _apply_mask_and_fill(start_idx, end_idx, masked): + x = unnoised_input_ids[start_idx:end_idx] + p = masked.float().mean() + input_ids[start_idx:end_idx] = torch.where(masked, self.mask_token_id, x) + labels[start_idx:end_idx] = torch.where(~masked, self.ignore_token_id, x) + p_mask[start_idx:end_idx] = p + + for i in range(0, batch_size, 2): + t = torch.rand(1, device=input_ids.device)[0] + p = (1 - eps) * t + eps + mask_count = torch.round(p * sequence_length).to(torch.int32) + + masked = torch.randperm(sequence_length, device=device) < mask_count + _apply_mask_and_fill(start_idx=i * sequence_length, end_idx=(i + 1) * sequence_length, masked=masked) + + mask_count = sequence_length - mask_count + masked = torch.randperm(sequence_length, device=device) < mask_count + _apply_mask_and_fill(start_idx=(i + 1) * sequence_length, end_idx=(i + 2) * sequence_length, masked=masked) cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 cu_seqlens = torch.cat([torch.tensor([0], device=unnoised_input_ids.device), cu_seqlens]).to(torch.int32) @@ -313,30 +282,33 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch["cu_seqlens"] = cu_seqlens batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids - - # if moved_boundary: - # from transformers import PreTrainedTokenizer - # tokenizer: PreTrainedTokenizer = self.tokenizer - # def to_token_list(seq): - # output = [] - # for idx in seq: - # if idx == self.ignore_token_id: - # c = '' - # elif idx == self.mask_token_id: - # c = 
'_' - # else: - # c = tokenizer._convert_id_to_token(idx) - # output.append(c) - # return output - # for i in range(cu_seqlens.size(0) - 1): - # seq_in = input_ids.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] - # seq_out = labels.flatten()[cu_seqlens[i]:cu_seqlens[i+1]] - # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) - # print() - # print(cu_seqlens[i], cu_seqlens[i+1]) - # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) - # print(cu_seqlens) - # exit() + print((input_ids == self.mask_token_id).int().sum()) + + from transformers import PreTrainedTokenizer + + tokenizer: PreTrainedTokenizer = self.tokenizer + + def to_token_list(seq): + output = [] + for idx in seq: + if idx == self.ignore_token_id: + c = "" + elif idx == self.mask_token_id: + c = "_" + else: + c = tokenizer._convert_id_to_token(idx) + output.append(c) + return output + + for i in range(cu_seqlens.size(0) - 1): + seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + print() + print(cu_seqlens[i].item(), cu_seqlens[i + 1].item()) + print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) + print(cu_seqlens) + exit() # else: # print("No deletions.") From 51b24e7317220adcd679ee469bc4e435f1efdea3 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Tue, 29 Jul 2025 19:19:11 +0000 Subject: [PATCH 10/19] Bulk randperm for masking. --- lm_engine/hf_models/models/diffusion/main.py | 4 + .../model_wrapper/pretraining_diffusion.py | 173 +++++++++++------- 2 files changed, 111 insertions(+), 66 deletions(-) diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index af01e9b6a..6504792b6 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -70,6 +70,7 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: int | None = None, reduction: str = "mean", + masked_indices: torch.Tensor | None = None, ) -> CausalLMOutputWithPast: assert return_dict assert inputs_embeds is None @@ -113,6 +114,9 @@ def forward( lm_logits = None loss = None + if masked_indices is not None: + hidden_states = hidden_states[masked_indices] + if labels is None: if is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute): if self.m_width is not None: diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 10078f443..8a6fbe7d1 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -110,6 +110,7 @@ def forward( batch = self._prepare_model_inputs(batch) labels = batch.pop("labels") p_mask = batch.pop("p_mask") + masked_indices = batch["masked_indices"] output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) if self.is_pipeline_parallel_enabled: @@ -130,7 +131,8 @@ def forward( if use_aux_loss: output = (output, aux_loss) else: - output = self.get_loss(output, labels, p_mask, lm_loss_multiplier=lm_loss_multiplier) + assert (labels[batch["masked_indices"]] != self.ignore_token_id).all() + output = self.get_loss(output, labels, masked_indices, p_mask, lm_loss_multiplier=lm_loss_multiplier) return output @@ -138,13 +140,15 @@ def get_loss( self, model_outputs: CausalLMOutputWithPast, labels: torch.Tensor, + masked_indices: torch.Tensor, p_mask: torch.Tensor, lm_loss_multiplier: float = 1, ) -> 
torch.Tensor | dict: tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() # use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) flat_logits = model_outputs.logits.flatten(0, -2) - flat_labels = labels.flatten() + flat_labels = labels.flatten()[masked_indices] + flat_p_mask = p_mask.flatten()[masked_indices] # print(flat_logits.size(), flat_labels.size()) lm_loss = ( F.cross_entropy( @@ -153,7 +157,7 @@ def get_loss( ignore_index=self.ignore_token_id, reduction="none", ) - / p_mask.flatten() + / flat_p_mask ).sum() lm_loss = lm_loss * lm_loss_multiplier @@ -181,7 +185,7 @@ def _setup_tokenizer(self) -> None: assert self.mask_token_id is not None self.pad_token_id = self.tokenizer.pad_token_id assert self.pad_token_id is not None - self.ignore_token_id = self.pad_token_id # self.mask_token_id + self.ignore_token_id = -1 # self.pad_token_id # self.mask_token_id # def _forward_process(self, input_ids, eps=1e-3): # b, l = input_ids.shape @@ -193,21 +197,20 @@ def _setup_tokenizer(self) -> None: # noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) # labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) # return noisy_batch, labels, p_mask + # def _forward_process(self, input_ids, eps=1e-3): + # t = torch.rand(1, device=input_ids.device) + # p_mask = (1 - eps) * t + eps + # masked_indices = torch.rand_like(input_ids, dtype=t.dtype) < p_mask + # noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) + # labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) + # return noisy_batch, labels, ~masked_indices, p_mask - def _forward_process(self, input_ids, eps=1e-3): - t = torch.rand(1, device=input_ids.device) - p_mask = (1 - eps) * t + eps - masked_indices = torch.rand_like(input_ids, dtype=t.dtype) < p_mask - noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) - labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) - return noisy_batch, labels, ~masked_indices, p_mask - - def _forward_and_place(self, curr_seq, start_idx, end_idx, input_ids, labels, p_mask): - curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) - input_ids[start_idx:end_idx] = curr_seq_input - labels[start_idx:end_idx] = curr_seq_labels - p_mask[start_idx:end_idx] = curr_seq_p_mask - return curr_seq_input, curr_seq_labels, curr_seq_p_mask + # def _forward_and_place(self, curr_seq, start_idx, end_idx, input_ids, labels, p_mask): + # curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) + # input_ids[start_idx:end_idx] = curr_seq_input + # labels[start_idx:end_idx] = curr_seq_labels + # p_mask[start_idx:end_idx] = curr_seq_p_mask + # return curr_seq_input, curr_seq_labels, curr_seq_p_mask def _prepare_model_inputs(self, batch: dict) -> dict: device = torch.cuda.current_device() @@ -230,39 +233,74 @@ def _prepare_model_inputs(self, batch: dict) -> dict: batch = {} batch_size, sequence_length = unnoised_input_ids.shape - assert batch_size % 2 == 0 + + perm_idxs = torch.argsort(torch.rand_like(unnoised_input_ids, dtype=torch.bfloat16), dim=-1) unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length - input_ids = torch.empty_like(unnoised_input_ids) - labels = torch.empty_like(unnoised_input_ids) + input_ids = unnoised_input_ids.clone() + labels = torch.full_like(unnoised_input_ids, fill_value=self.ignore_token_id) p_mask = 
torch.empty_like(unnoised_input_ids, dtype=torch.bfloat16) - document_end_positions = unnoised_input_ids == self.eos_token_id - document_end_positions[sequence_length - 1 :: sequence_length] = ( - 1 # Add the end token for the end of sample also + assert batch_size % 2 == 0 + masked_ptr = 0 + masked_indices = ( + torch.zeros((batch_size // 2) * sequence_length, dtype=input_ids.dtype, device=input_ids.device) - 1 ) - eps = 1e-3 - - def _apply_mask_and_fill(start_idx, end_idx, masked): - x = unnoised_input_ids[start_idx:end_idx] - p = masked.float().mean() - input_ids[start_idx:end_idx] = torch.where(masked, self.mask_token_id, x) - labels[start_idx:end_idx] = torch.where(~masked, self.ignore_token_id, x) - p_mask[start_idx:end_idx] = p + document_end_positions = unnoised_input_ids == self.eos_token_id + document_end_positions[sequence_length - 1 :: sequence_length] = 1 + eps = 1e-4 + + def _apply_mask_and_fill(start_idx, end_idx, masked_idxs): + x = unnoised_input_ids[start_idx:end_idx].clone() + row_p = p_mask[start_idx:end_idx] + labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs] + input_ids[start_idx:end_idx][masked_idxs] = self.mask_token_id + # input_ids[start_idx:end_idx] = torch.where(masked, self.mask_token_id, x) + # labels[start_idx:end_idx] = torch.where(~masked, self.ignore_token_id, x) + + end_positions = x == self.eos_token_id + end_positions[-1] = True + cu_seqlens = end_positions.nonzero(as_tuple=True)[0] + 1 + + for i in range(cu_seqlens.size(0)): + doc_start = 0 if i == 0 else cu_seqlens[i - 1] + doc_end = cu_seqlens[i] + doc_mask = input_ids[start_idx:end_idx][doc_start:doc_end] == self.mask_token_id + row_p[doc_start:doc_end] = torch.clamp(doc_mask.float().mean(), min=eps) for i in range(0, batch_size, 2): t = torch.rand(1, device=input_ids.device)[0] - p = (1 - eps) * t + eps - mask_count = torch.round(p * sequence_length).to(torch.int32) + p = (1 - 2 * eps) * t + eps - masked = torch.randperm(sequence_length, device=device) < mask_count - _apply_mask_and_fill(start_idx=i * sequence_length, end_idx=(i + 1) * sequence_length, masked=masked) + mask_count = torch.round(p * sequence_length).to(torch.int32) + # masked: torch.Tensor = torch.zeros(sequence_length, device=input_ids.device, dtype=torch.bool) + masked_idxs_ = perm_idxs[i, :mask_count] + # masked[masked_idxs_] = True + # assert masked.int().sum() == mask_count, masked_idxs_ + _apply_mask_and_fill( + start_idx=i * sequence_length, end_idx=(i + 1) * sequence_length, masked_idxs=masked_idxs_ + ) + masked_indices[masked_ptr : masked_ptr + mask_count] = i * sequence_length + masked_idxs_ + masked_ptr += mask_count + # masked[:] = False mask_count = sequence_length - mask_count - masked = torch.randperm(sequence_length, device=device) < mask_count - _apply_mask_and_fill(start_idx=(i + 1) * sequence_length, end_idx=(i + 2) * sequence_length, masked=masked) - + # masked: torch.Tensor = torch.zeros(sequence_length, device=input_ids.device, dtype=torch.bool) + masked_idxs_ = perm_idxs[i + 1, :mask_count] + # masked[masked_idxs_] = True + # assert masked.int().sum() == mask_count, masked_idxs_ + _apply_mask_and_fill( + start_idx=(i + 1) * sequence_length, end_idx=(i + 2) * sequence_length, masked_idxs=masked_idxs_ + ) + masked_indices[masked_ptr : masked_ptr + mask_count] = (i + 1) * sequence_length + masked_idxs_ + masked_ptr += mask_count + + masked_indices, _ = torch.sort(masked_indices) + # idxs = torch.arange(masked_indices.size(0), device=masked_indices.device) + # consec_equal = 
masked_indices[idxs[:-1]] == masked_indices[idxs[1:]] + # assert not consec_equal.any(), masked_indices[:-1][consec_equal] + # assert masked_indices.size(0) == torch.unique(masked_indices).size(0) cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 cu_seqlens = torch.cat([torch.tensor([0], device=unnoised_input_ids.device), cu_seqlens]).to(torch.int32) seqlen = cu_seqlens[1:] - cu_seqlens[:-1] @@ -276,39 +314,42 @@ def _apply_mask_and_fill(start_idx, end_idx, masked): else: position_ids = self.position_ids + # masked_idxs = (labels != self.ignore_token_id).nonzero(as_tuple=True)[0] + # masked_idxs, _ = torch.sort(masked_idxs) + assert (labels[masked_indices] != self.ignore_token_id).all() + assert (input_ids[masked_indices] == self.mask_token_id).all() + batch["input_ids"] = input_ids.flatten() batch["labels"] = labels.flatten() batch["p_mask"] = p_mask.flatten() batch["cu_seqlens"] = cu_seqlens batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids - print((input_ids == self.mask_token_id).int().sum()) - - from transformers import PreTrainedTokenizer - - tokenizer: PreTrainedTokenizer = self.tokenizer - - def to_token_list(seq): - output = [] - for idx in seq: - if idx == self.ignore_token_id: - c = "" - elif idx == self.mask_token_id: - c = "_" - else: - c = tokenizer._convert_id_to_token(idx) - output.append(c) - return output - - for i in range(cu_seqlens.size(0) - 1): - seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) - print() - print(cu_seqlens[i].item(), cu_seqlens[i + 1].item()) - print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) - print(cu_seqlens) - exit() + batch["masked_indices"] = masked_indices + # from transformers import PreTrainedTokenizer + # tokenizer: PreTrainedTokenizer = self.tokenizer + # def to_token_list(seq): + # output = [] + # for idx in seq: + # if idx == self.ignore_token_id: + # c = "" + # elif idx == self.mask_token_id: + # c = "_" + # else: + # c = tokenizer._convert_id_to_token(idx) + # output.append(c) + # return output + # print((input_ids == self.mask_token_id).int().sum().item()) + # for i in range(cu_seqlens.size(0) - 1): + # seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + # seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + # assert p_mask[cu_seqlens[i]] == p_mask[cu_seqlens[i + 1] - 1] + # print() + # print(cu_seqlens[i].item(), cu_seqlens[i + 1].item(), p_mask[cu_seqlens[i + 1] - 1].item()) + # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) + # print(cu_seqlens) + # exit() # else: # print("No deletions.") From c4de13e8d5cf9f339cd0c954edd2081a1f334e8b Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Tue, 29 Jul 2025 22:01:00 +0000 Subject: [PATCH 11/19] Cleanup. 
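For context: taken together, the preceding patches sample one masking rate per source row, realize it with a random permutation, and give a second copy of the same row the complementary set of masked positions, so every token is supervised exactly once across the pair. A minimal, self-contained sketch of that idea follows; it assumes a single flat row of token ids, the ignore index of -1 and eps floor used in the diff above, and the helper name complementary_masks is illustrative rather than part of the codebase.

import torch

def complementary_masks(ids: torch.Tensor, mask_token_id: int,
                        ignore_index: int = -1, eps: float = 1e-4):
    # ids: 1-D LongTensor holding one packed row of token ids.
    L = ids.numel()
    t = torch.rand((), device=ids.device)
    p = (1 - 2 * eps) * t + eps              # masking rate kept inside (eps, 1 - eps)
    k = int(torch.round(p * L))              # number of positions to mask
    perm = torch.randperm(L, device=ids.device)
    splits = (perm[:k], perm[k:])            # complementary index sets

    def noise(idx):
        x = ids.clone()
        y = torch.full_like(ids, ignore_index)
        y[idx] = ids[idx]                    # loss is computed only at masked positions
        x[idx] = mask_token_id
        p_row = torch.full((L,), max(idx.numel() / L, eps), device=ids.device)
        return x, y, p_row                   # noised ids, labels, per-token p_mask

    return noise(splits[0]), noise(splits[1])

The per-copy p_mask returned here is what the loss later divides by, which keeps a token's expected contribution to the gradient independent of how heavily its row happened to be masked.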
--- lm_engine/model_wrapper/pretraining_diffusion.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 8a6fbe7d1..2285ff472 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -274,22 +274,15 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs): p = (1 - 2 * eps) * t + eps mask_count = torch.round(p * sequence_length).to(torch.int32) - # masked: torch.Tensor = torch.zeros(sequence_length, device=input_ids.device, dtype=torch.bool) masked_idxs_ = perm_idxs[i, :mask_count] - # masked[masked_idxs_] = True - # assert masked.int().sum() == mask_count, masked_idxs_ _apply_mask_and_fill( start_idx=i * sequence_length, end_idx=(i + 1) * sequence_length, masked_idxs=masked_idxs_ ) masked_indices[masked_ptr : masked_ptr + mask_count] = i * sequence_length + masked_idxs_ masked_ptr += mask_count - # masked[:] = False mask_count = sequence_length - mask_count - # masked: torch.Tensor = torch.zeros(sequence_length, device=input_ids.device, dtype=torch.bool) masked_idxs_ = perm_idxs[i + 1, :mask_count] - # masked[masked_idxs_] = True - # assert masked.int().sum() == mask_count, masked_idxs_ _apply_mask_and_fill( start_idx=(i + 1) * sequence_length, end_idx=(i + 2) * sequence_length, masked_idxs=masked_idxs_ ) From 21fe9da091c861d322698aaa23b76eaf78473484 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Wed, 30 Jul 2025 23:22:41 +0000 Subject: [PATCH 12/19] Double --- .../model_wrapper/pretraining_diffusion.py | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 2285ff472..ec402c3fe 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -158,7 +158,7 @@ def get_loss( reduction="none", ) / flat_p_mask - ).sum() + ).sum() / 2 lm_loss = lm_loss * lm_loss_multiplier aux_loss = getattr(model_outputs, "aux_loss", 0) @@ -232,21 +232,21 @@ def _prepare_model_inputs(self, batch: dict) -> dict: # batch = {"labels": labels, "p_mask": p_mask} batch = {} - batch_size, sequence_length = unnoised_input_ids.shape + orig_batch_size, sequence_length = unnoised_input_ids.shape + batch_size = orig_batch_size * 2 perm_idxs = torch.argsort(torch.rand_like(unnoised_input_ids, dtype=torch.bfloat16), dim=-1) - - unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length + unnoised_input_ids = unnoised_input_ids.repeat_interleave(2, 0).flatten() input_ids = unnoised_input_ids.clone() - labels = torch.full_like(unnoised_input_ids, fill_value=self.ignore_token_id) - p_mask = torch.empty_like(unnoised_input_ids, dtype=torch.bfloat16) + # unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length + labels = torch.full_like(input_ids, fill_value=self.ignore_token_id) + p_mask = torch.empty_like(input_ids, dtype=torch.bfloat16) - assert batch_size % 2 == 0 + # assert batch_size % 2 == 0 masked_ptr = 0 masked_indices = ( - torch.zeros((batch_size // 2) * sequence_length, dtype=input_ids.dtype, device=input_ids.device) - 1 + torch.zeros(batch_size * (sequence_length // 2), dtype=input_ids.dtype, device=input_ids.device) - 1 ) - document_end_positions = unnoised_input_ids == self.eos_token_id document_end_positions[sequence_length - 1 :: sequence_length] = 1 eps = 1e-4 @@ -269,24 +269,27 @@ def 
_apply_mask_and_fill(start_idx, end_idx, masked_idxs): doc_mask = input_ids[start_idx:end_idx][doc_start:doc_end] == self.mask_token_id row_p[doc_start:doc_end] = torch.clamp(doc_mask.float().mean(), min=eps) - for i in range(0, batch_size, 2): + # for i in range(0, batch_size, 2): + for i in range(orig_batch_size): t = torch.rand(1, device=input_ids.device)[0] p = (1 - 2 * eps) * t + eps mask_count = torch.round(p * sequence_length).to(torch.int32) masked_idxs_ = perm_idxs[i, :mask_count] _apply_mask_and_fill( - start_idx=i * sequence_length, end_idx=(i + 1) * sequence_length, masked_idxs=masked_idxs_ + start_idx=2 * i * sequence_length, end_idx=(2 * i + 1) * sequence_length, masked_idxs=masked_idxs_ ) - masked_indices[masked_ptr : masked_ptr + mask_count] = i * sequence_length + masked_idxs_ + masked_indices[masked_ptr : masked_ptr + mask_count] = 2 * i * sequence_length + masked_idxs_ masked_ptr += mask_count + masked_idxs_ = perm_idxs[i, mask_count:] mask_count = sequence_length - mask_count - masked_idxs_ = perm_idxs[i + 1, :mask_count] _apply_mask_and_fill( - start_idx=(i + 1) * sequence_length, end_idx=(i + 2) * sequence_length, masked_idxs=masked_idxs_ + start_idx=(2 * i + 1) * sequence_length, + end_idx=(2 * i + 2) * sequence_length, + masked_idxs=masked_idxs_, ) - masked_indices[masked_ptr : masked_ptr + mask_count] = (i + 1) * sequence_length + masked_idxs_ + masked_indices[masked_ptr : masked_ptr + mask_count] = (2 * i + 1) * sequence_length + masked_idxs_ masked_ptr += mask_count masked_indices, _ = torch.sort(masked_indices) @@ -319,6 +322,7 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs): batch["max_seqlen"] = max_seqlen batch["position_ids"] = position_ids batch["masked_indices"] = masked_indices + # from transformers import PreTrainedTokenizer # tokenizer: PreTrainedTokenizer = self.tokenizer # def to_token_list(seq): From 39ffab51489c391f7533afd932ce19bfde20ad96 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Thu, 7 Aug 2025 00:41:56 +0000 Subject: [PATCH 13/19] Working. 
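This revision routes only the masked positions through the LM head (via masked_indices) and normalizes the cross-entropy by each token's masking probability. Roughly, the loss reduces to the sketch below; it is a simplification that ignores tensor parallelism and auxiliary losses, and it assumes logits have already been gathered at the masked positions as in the model's forward pass.

import torch
import torch.nn.functional as F

def diffusion_masked_lm_loss(logits, labels, p_mask, masked_indices,
                             ignore_index: int = -1, lm_loss_multiplier: float = 1.0):
    # logits: (num_masked, vocab_size), taken at masked positions only.
    # labels, p_mask: flat (total_tokens,) tensors built alongside input_ids.
    flat_labels = labels.flatten()[masked_indices]
    flat_p_mask = p_mask.flatten()[masked_indices]
    ce = F.cross_entropy(logits.float(), flat_labels,
                         ignore_index=ignore_index, reduction="none")
    # weight by 1/p_mask, then halve: each row appears twice with complementary
    # masks, so every token is counted exactly once across the doubled batch.
    return (ce / flat_p_mask).sum() / 2 * lm_loss_multiplier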
--- lm_engine/model_wrapper/__init__.py | 3 - .../model_wrapper/pretraining_diffusion.py | 160 ++++++++---------- 2 files changed, 72 insertions(+), 91 deletions(-) diff --git a/lm_engine/model_wrapper/__init__.py b/lm_engine/model_wrapper/__init__.py index 7880cf714..0370807d2 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -57,9 +57,6 @@ def get_model_container( kwargs["reset_attention_mask"] = args.model_args.reset_attention_mask kwargs["reset_position_ids"] = args.model_args.reset_position_ids - if tuning_method == TuningMethod.pretraining_diffusion: - print(args.model_args) - if tuning_method == TuningMethod.distillation: kwargs["teacher_model_name"] = args.teacher_args.model_name kwargs["teacher_model_class"] = args.teacher_args.model_class diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index ec402c3fe..c182b7afd 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -187,31 +187,6 @@ def _setup_tokenizer(self) -> None: assert self.pad_token_id is not None self.ignore_token_id = -1 # self.pad_token_id # self.mask_token_id - # def _forward_process(self, input_ids, eps=1e-3): - # b, l = input_ids.shape - # t = torch.rand(b, device=input_ids.device) - # p_mask = (1 - eps) * t + eps - # p_mask = p_mask[:, None].repeat(1, l) - - # masked_indices = torch.rand((b, l), device=input_ids.device) < p_mask - # noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) - # labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) - # return noisy_batch, labels, p_mask - # def _forward_process(self, input_ids, eps=1e-3): - # t = torch.rand(1, device=input_ids.device) - # p_mask = (1 - eps) * t + eps - # masked_indices = torch.rand_like(input_ids, dtype=t.dtype) < p_mask - # noisy_batch = torch.where(masked_indices, self.mask_token_id, input_ids) - # labels = torch.where(~masked_indices, self.ignore_token_id, input_ids) - # return noisy_batch, labels, ~masked_indices, p_mask - - # def _forward_and_place(self, curr_seq, start_idx, end_idx, input_ids, labels, p_mask): - # curr_seq_input, curr_seq_labels, curr_unmasked_idxs, curr_seq_p_mask = self._forward_process(curr_seq) - # input_ids[start_idx:end_idx] = curr_seq_input - # labels[start_idx:end_idx] = curr_seq_labels - # p_mask[start_idx:end_idx] = curr_seq_p_mask - # return curr_seq_input, curr_seq_labels, curr_seq_p_mask - def _prepare_model_inputs(self, batch: dict) -> dict: device = torch.cuda.current_device() if self.is_pipeline_parallel_enabled: @@ -224,13 +199,14 @@ def _prepare_model_inputs(self, batch: dict) -> dict: else: tokens = batch["text"] tokens = tokens.to(device) - if torch.rand(1, device=tokens.device) < 0.5: - unnoised_input_ids = tokens[:, 1:] - else: - unnoised_input_ids = tokens[:, :-1] + # if torch.rand(1, device=tokens.device) < 0.5: + # unnoised_input_ids = tokens[:, 1:] + # else: + # unnoised_input_ids = tokens[:, :-1] + unnoised_input_ids = tokens # input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) # batch = {"labels": labels, "p_mask": p_mask} - batch = {} + # batch = {} orig_batch_size, sequence_length = unnoised_input_ids.shape batch_size = orig_batch_size * 2 @@ -238,65 +214,71 @@ def _prepare_model_inputs(self, batch: dict) -> dict: perm_idxs = torch.argsort(torch.rand_like(unnoised_input_ids, dtype=torch.bfloat16), dim=-1) unnoised_input_ids = unnoised_input_ids.repeat_interleave(2, 0).flatten() 
input_ids = unnoised_input_ids.clone() - # unnoised_input_ids = unnoised_input_ids.flatten() # batch_size * sequence_length labels = torch.full_like(input_ids, fill_value=self.ignore_token_id) p_mask = torch.empty_like(input_ids, dtype=torch.bfloat16) - # assert batch_size % 2 == 0 + assert batch_size % 2 == 0 masked_ptr = 0 masked_indices = ( torch.zeros(batch_size * (sequence_length // 2), dtype=input_ids.dtype, device=input_ids.device) - 1 ) + document_end_positions = unnoised_input_ids == self.eos_token_id document_end_positions[sequence_length - 1 :: sequence_length] = 1 eps = 1e-4 + moved_boundary = False - def _apply_mask_and_fill(start_idx, end_idx, masked_idxs): - x = unnoised_input_ids[start_idx:end_idx].clone() - row_p = p_mask[start_idx:end_idx] + def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): + nonlocal moved_boundary labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs] input_ids[start_idx:end_idx][masked_idxs] = self.mask_token_id - # input_ids[start_idx:end_idx] = torch.where(masked, self.mask_token_id, x) - # labels[start_idx:end_idx] = torch.where(~masked, self.ignore_token_id, x) - - end_positions = x == self.eos_token_id - end_positions[-1] = True - cu_seqlens = end_positions.nonzero(as_tuple=True)[0] + 1 - - for i in range(cu_seqlens.size(0)): - doc_start = 0 if i == 0 else cu_seqlens[i - 1] - doc_end = cu_seqlens[i] - doc_mask = input_ids[start_idx:end_idx][doc_start:doc_end] == self.mask_token_id - row_p[doc_start:doc_end] = torch.clamp(doc_mask.float().mean(), min=eps) + p_mask[start_idx:end_idx] = p + + prob = torch.rand(1, device=tokens.device) + if prob < 0.5: + end_positions = unnoised_input_ids[start_idx:end_idx] == self.eos_token_id + end_positions_noised = input_ids[start_idx:end_idx] == self.eos_token_id + # find mismatches + end_position_mismatch = (end_positions != end_positions_noised) & ( + input_ids[start_idx:end_idx] == self.mask_token_id + ) + if end_position_mismatch.any(): + movable_locs = torch.nonzero(end_position_mismatch, as_tuple=True)[0] + move_start_idx = movable_locs[torch.randint(movable_locs.size(0), size=(1,))[0]] + rest_unmasked = input_ids[start_idx:end_idx][move_start_idx:] != self.mask_token_id + if rest_unmasked.any(): + first_unmasked_idx = move_start_idx + torch.nonzero(rest_unmasked, as_tuple=True)[0].min() + if first_unmasked_idx - move_start_idx > 1: + moved_boundary = True + document_end_positions[start_idx:end_idx][move_start_idx] = False + document_end_positions[start_idx:end_idx][first_unmasked_idx] = True + input_ids[start_idx:end_idx][first_unmasked_idx] = self.eos_token_id + labels[start_idx:end_idx][move_start_idx:first_unmasked_idx] = self.eos_token_id - # for i in range(0, batch_size, 2): for i in range(orig_batch_size): t = torch.rand(1, device=input_ids.device)[0] p = (1 - 2 * eps) * t + eps - + sample_masked_idxs = perm_idxs[i] mask_count = torch.round(p * sequence_length).to(torch.int32) - masked_idxs_ = perm_idxs[i, :mask_count] + masked_idxs_ = sample_masked_idxs[:mask_count] _apply_mask_and_fill( - start_idx=2 * i * sequence_length, end_idx=(2 * i + 1) * sequence_length, masked_idxs=masked_idxs_ + start_idx=2 * i * sequence_length, end_idx=(2 * i + 1) * sequence_length, masked_idxs=masked_idxs_, p=p ) masked_indices[masked_ptr : masked_ptr + mask_count] = 2 * i * sequence_length + masked_idxs_ masked_ptr += mask_count - masked_idxs_ = perm_idxs[i, mask_count:] + masked_idxs_ = sample_masked_idxs[mask_count:] mask_count = sequence_length - mask_count 
_apply_mask_and_fill( start_idx=(2 * i + 1) * sequence_length, end_idx=(2 * i + 2) * sequence_length, masked_idxs=masked_idxs_, + p=1 - p, ) masked_indices[masked_ptr : masked_ptr + mask_count] = (2 * i + 1) * sequence_length + masked_idxs_ masked_ptr += mask_count masked_indices, _ = torch.sort(masked_indices) - # idxs = torch.arange(masked_indices.size(0), device=masked_indices.device) - # consec_equal = masked_indices[idxs[:-1]] == masked_indices[idxs[1:]] - # assert not consec_equal.any(), masked_indices[:-1][consec_equal] - # assert masked_indices.size(0) == torch.unique(masked_indices).size(0) cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 cu_seqlens = torch.cat([torch.tensor([0], device=unnoised_input_ids.device), cu_seqlens]).to(torch.int32) seqlen = cu_seqlens[1:] - cu_seqlens[:-1] @@ -315,38 +297,40 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs): assert (labels[masked_indices] != self.ignore_token_id).all() assert (input_ids[masked_indices] == self.mask_token_id).all() - batch["input_ids"] = input_ids.flatten() - batch["labels"] = labels.flatten() - batch["p_mask"] = p_mask.flatten() - batch["cu_seqlens"] = cu_seqlens - batch["max_seqlen"] = max_seqlen - batch["position_ids"] = position_ids - batch["masked_indices"] = masked_indices - - # from transformers import PreTrainedTokenizer - # tokenizer: PreTrainedTokenizer = self.tokenizer - # def to_token_list(seq): - # output = [] - # for idx in seq: - # if idx == self.ignore_token_id: - # c = "" - # elif idx == self.mask_token_id: - # c = "_" - # else: - # c = tokenizer._convert_id_to_token(idx) - # output.append(c) - # return output - # print((input_ids == self.mask_token_id).int().sum().item()) - # for i in range(cu_seqlens.size(0) - 1): - # seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - # seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) - # assert p_mask[cu_seqlens[i]] == p_mask[cu_seqlens[i + 1] - 1] - # print() - # print(cu_seqlens[i].item(), cu_seqlens[i + 1].item(), p_mask[cu_seqlens[i + 1] - 1].item()) - # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) - # print(cu_seqlens) - # exit() + batch = { + "input_ids": input_ids.flatten(), + "labels": labels.flatten(), + "p_mask": p_mask.flatten(), + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "position_ids": position_ids, + "masked_indices": masked_indices, + } + # if moved_boundary: + # from transformers import PreTrainedTokenizer + # tokenizer: PreTrainedTokenizer = self.tokenizer + # def to_token_list(seq): + # output = [] + # for idx in seq: + # if idx == self.ignore_token_id: + # c = "" + # elif idx == self.mask_token_id: + # c = "_" + # else: + # c = tokenizer._convert_id_to_token(idx) + # output.append(c) + # return output + # print((input_ids == self.mask_token_id).int().sum().item()) + # for i in range(cu_seqlens.size(0) - 1): + # seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + # seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] + # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + # assert p_mask[cu_seqlens[i]] == p_mask[cu_seqlens[i + 1] - 1] + # print() + # print(cu_seqlens[i].item(), cu_seqlens[i + 1].item(), p_mask[cu_seqlens[i + 1] - 1].item()) + # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) + # print(cu_seqlens) + # exit() # else: # print("No deletions.") From 6386a653246945773f3a614573aa85a72fa095f0 Mon Sep 17 
00:00:00 2001 From: Shawn Tan Date: Mon, 11 Aug 2025 19:27:19 +0000 Subject: [PATCH 14/19] Annealing code --- .../utils/flash_attention_utils.py | 13 ++ lm_engine/hf_models/models/diffusion/main.py | 19 ++- .../model_wrapper/pretraining_diffusion.py | 161 ++++++++++++------ lm_engine/optimization/optimizer.py | 16 +- 4 files changed, 151 insertions(+), 58 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py index 6a7cf4058..bcdb7b0e3 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py @@ -94,6 +94,18 @@ def flash_attention( causal=causal, ) else: + from ....model_wrapper.pretraining_diffusion import ANNEAL, get_annealing_step, get_max_annealing_steps + + if ANNEAL: + anneal_proportion = get_annealing_step() / get_max_annealing_steps() + if anneal_proportion < 0.5: + right_window_size = 0 + else: + right_window_size = int((anneal_proportion - 0.75) * 2 * max_seqlen) + # print("right_window_size =", right_window_size) + window_size = (-1, right_window_size) + else: + window_size = (-1, -1) attn_output = flash_attention_2_varlen( q=query, k=key, @@ -105,6 +117,7 @@ def flash_attention( dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, + window_size=window_size, ) else: if attention_mask is None: diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index 6504792b6..2fe745bc0 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -27,6 +27,7 @@ def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedMode self.mask_token_id = kwargs.pop("mask_token_id") super().__init__(config, **kwargs) self.router_aux_loss_coef = getattr(config, "router_aux_loss_coef", 0) + self.first_forward_pass = True self._init_model(config, **kwargs) def _init_model(self, config: DiffusionConfig, **kwargs) -> None: @@ -75,6 +76,13 @@ def forward( assert return_dict assert inputs_embeds is None + if self.training and self.first_forward_pass and hasattr(self, "mask_token_id"): + with torch.no_grad(): + self.transformer.wte.weight.data[self.mask_token_id] = torch.mean( + self.transformer.wte.weight.data, dim=0 + ) + self.first_forward_pass = False + input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( input_ids=input_ids, position_ids=position_ids, @@ -159,8 +167,17 @@ def forward( ) def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: - return ( + logits = ( F.linear(hidden_states, self.transformer.wte.weight) if self._tied_word_embeddings else self.lm_head(hidden_states) ) + assert hasattr(self, "mask_token_id") + logits[..., self.mask_token_id] + # print("mask_id_logit", + # mask_id_logit.min().item(), + # mask_id_logit.max().item(), + # mask_id_logit.median().item(), + # mask_id_logit.mean().item()) + logits[..., self.mask_token_id] = -1.0e5 + return logits diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index c182b7afd..37c49c665 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -27,6 +27,27 @@ FIM_MIDDLE = "" +ANNEAL = False + +MAX_ANNEALING_STEPS = 4096 +ANNEALING_STEP = 0 if ANNEAL else 4096 + + +def 
get_annealing_step(): + global ANNEALING_STEP + return ANNEALING_STEP + + +def get_max_annealing_steps(): + global MAX_ANNEALING_STEPS + return MAX_ANNEALING_STEPS + + +def update_annealing(): + global ANNEALING_STEP + ANNEALING_STEP = min(ANNEALING_STEP + 1, MAX_ANNEALING_STEPS) + # print("Updated to ", ANNEALING_STEP) + class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): def __init__( @@ -181,7 +202,7 @@ def get_loss( def _setup_tokenizer(self) -> None: super()._setup_tokenizer() # TODO (shawntan) Use FIM token for now. Figure out if there is a way to have actual mask token. - self.mask_token_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) + self.mask_token_id = 100351 # self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) assert self.mask_token_id is not None self.pad_token_id = self.tokenizer.pad_token_id assert self.pad_token_id is not None @@ -203,63 +224,64 @@ def _prepare_model_inputs(self, batch: dict) -> dict: # unnoised_input_ids = tokens[:, 1:] # else: # unnoised_input_ids = tokens[:, :-1] - unnoised_input_ids = tokens - # input_ids, labels, p_mask = self._forward_process(unnoised_input_ids) - # batch = {"labels": labels, "p_mask": p_mask} - # batch = {} - - orig_batch_size, sequence_length = unnoised_input_ids.shape + input_ids = tokens[:, :-1] + labels = tokens[:, 1:] + orig_batch_size, sequence_length = input_ids.shape batch_size = orig_batch_size * 2 - perm_idxs = torch.argsort(torch.rand_like(unnoised_input_ids, dtype=torch.bfloat16), dim=-1) - unnoised_input_ids = unnoised_input_ids.repeat_interleave(2, 0).flatten() - input_ids = unnoised_input_ids.clone() + perm_idxs = torch.argsort(torch.rand_like(input_ids[:, :-1], dtype=torch.bfloat16), dim=-1) + # unnoised_input_ids = unnoised_input_ids.repeat_interleave(2, 0).flatten() + # input_ids = unnoised_input_ids.clone() + input_ids = input_ids.repeat_interleave(2, 0).flatten() + orig_input_ids = input_ids.clone() + unmasked_labels = labels.repeat_interleave(2, 0).flatten() labels = torch.full_like(input_ids, fill_value=self.ignore_token_id) - p_mask = torch.empty_like(input_ids, dtype=torch.bfloat16) + p_mask = torch.ones_like(input_ids, dtype=torch.bfloat16) - assert batch_size % 2 == 0 + # assert batch_size % 2 == 0 masked_ptr = 0 masked_indices = ( - torch.zeros(batch_size * (sequence_length // 2), dtype=input_ids.dtype, device=input_ids.device) - 1 + torch.zeros((batch_size // 2) * (sequence_length - 1), dtype=input_ids.dtype, device=input_ids.device) - 1 ) - document_end_positions = unnoised_input_ids == self.eos_token_id + document_end_positions = unmasked_labels == self.eos_token_id document_end_positions[sequence_length - 1 :: sequence_length] = 1 eps = 1e-4 moved_boundary = False def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): nonlocal moved_boundary - labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs] - input_ids[start_idx:end_idx][masked_idxs] = self.mask_token_id + # assert ((masked_idxs - 1) >= 0).all() + labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs + 1] + input_ids[start_idx:end_idx][masked_idxs + 1] = self.mask_token_id p_mask[start_idx:end_idx] = p - prob = torch.rand(1, device=tokens.device) - if prob < 0.5: - end_positions = unnoised_input_ids[start_idx:end_idx] == self.eos_token_id - end_positions_noised = input_ids[start_idx:end_idx] == self.eos_token_id - # find mismatches - end_position_mismatch = (end_positions != end_positions_noised) & ( - input_ids[start_idx:end_idx] == self.mask_token_id - ) - if 
end_position_mismatch.any(): - movable_locs = torch.nonzero(end_position_mismatch, as_tuple=True)[0] - move_start_idx = movable_locs[torch.randint(movable_locs.size(0), size=(1,))[0]] - rest_unmasked = input_ids[start_idx:end_idx][move_start_idx:] != self.mask_token_id - if rest_unmasked.any(): - first_unmasked_idx = move_start_idx + torch.nonzero(rest_unmasked, as_tuple=True)[0].min() - if first_unmasked_idx - move_start_idx > 1: - moved_boundary = True - document_end_positions[start_idx:end_idx][move_start_idx] = False - document_end_positions[start_idx:end_idx][first_unmasked_idx] = True - input_ids[start_idx:end_idx][first_unmasked_idx] = self.eos_token_id - labels[start_idx:end_idx][move_start_idx:first_unmasked_idx] = self.eos_token_id + # prob = torch.rand(1, device=tokens.device) + # if prob < 0.5: + # end_positions = unnoised_input_ids[start_idx:end_idx] == self.eos_token_id + # end_positions_noised = input_ids[start_idx:end_idx] == self.eos_token_id + # # find mismatches + # end_position_mismatch = (end_positions != end_positions_noised) & ( + # input_ids[start_idx:end_idx] == self.mask_token_id + # ) + # if end_position_mismatch.any(): + # movable_locs = torch.nonzero(end_position_mismatch, as_tuple=True)[0] + # move_start_idx = movable_locs[torch.randint(movable_locs.size(0), size=(1,))[0]] + # rest_unmasked = input_ids[start_idx:end_idx][move_start_idx:] != self.mask_token_id + # if rest_unmasked.any(): + # first_unmasked_idx = move_start_idx + torch.nonzero(rest_unmasked, as_tuple=True)[0].min() + # if first_unmasked_idx - move_start_idx > 1: + # moved_boundary = True + # document_end_positions[start_idx:end_idx][move_start_idx] = False + # document_end_positions[start_idx:end_idx][first_unmasked_idx] = True + # input_ids[start_idx:end_idx][first_unmasked_idx] = self.eos_token_id + # labels[start_idx:end_idx][move_start_idx:first_unmasked_idx] = self.eos_token_id for i in range(orig_batch_size): t = torch.rand(1, device=input_ids.device)[0] p = (1 - 2 * eps) * t + eps sample_masked_idxs = perm_idxs[i] - mask_count = torch.round(p * sequence_length).to(torch.int32) + mask_count = torch.round(p * (sequence_length - 1)).to(torch.int32) masked_idxs_ = sample_masked_idxs[:mask_count] _apply_mask_and_fill( start_idx=2 * i * sequence_length, end_idx=(2 * i + 1) * sequence_length, masked_idxs=masked_idxs_, p=p @@ -268,7 +290,7 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): masked_ptr += mask_count masked_idxs_ = sample_masked_idxs[mask_count:] - mask_count = sequence_length - mask_count + mask_count = (sequence_length - 1) - mask_count _apply_mask_and_fill( start_idx=(2 * i + 1) * sequence_length, end_idx=(2 * i + 2) * sequence_length, @@ -277,36 +299,56 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): ) masked_indices[masked_ptr : masked_ptr + mask_count] = (2 * i + 1) * sequence_length + masked_idxs_ masked_ptr += mask_count + # assert (masked_indices != -1).any() masked_indices, _ = torch.sort(masked_indices) cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 - cu_seqlens = torch.cat([torch.tensor([0], device=unnoised_input_ids.device), cu_seqlens]).to(torch.int32) + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]).to(torch.int32) seqlen = cu_seqlens[1:] - cu_seqlens[:-1] # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers max_seqlen = seqlen.max().item() if self.reset_position_ids: position_ids = torch.cat( - [torch.arange(0, i, 1, dtype=torch.int32, 
device=unnoised_input_ids.device) for i in seqlen] + [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] ) else: position_ids = self.position_ids # masked_idxs = (labels != self.ignore_token_id).nonzero(as_tuple=True)[0] # masked_idxs, _ = torch.sort(masked_idxs) + # print(labels[masked_indices], masked_indices) assert (labels[masked_indices] != self.ignore_token_id).all() - assert (input_ids[masked_indices] == self.mask_token_id).all() - - batch = { - "input_ids": input_ids.flatten(), - "labels": labels.flatten(), - "p_mask": p_mask.flatten(), - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "position_ids": position_ids, - "masked_indices": masked_indices, - } - # if moved_boundary: + assert (input_ids[masked_indices + 1] == self.mask_token_id).all() + anneal_ratio = min((get_annealing_step() / get_max_annealing_steps()) / 0.5, 1) + if ANNEAL: + batch = { + # "input_ids": input_ids, + # "input_ids": orig_input_ids, + "input_ids": torch.where( + torch.rand_like(orig_input_ids, dtype=torch.bfloat16) < anneal_ratio, input_ids, orig_input_ids + ), + "labels": unmasked_labels, + # "labels": labels.flatten(), + "p_mask": p_mask.flatten(), + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "position_ids": position_ids, + "masked_indices": masked_indices, + } + update_annealing() + else: + batch = { + "input_ids": input_ids, + "labels": labels.flatten(), + "p_mask": p_mask.flatten(), + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "position_ids": position_ids, + "masked_indices": masked_indices, + } + + # if True: # from transformers import PreTrainedTokenizer # tokenizer: PreTrainedTokenizer = self.tokenizer # def to_token_list(seq): @@ -321,14 +363,21 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): # output.append(c) # return output # print((input_ids == self.mask_token_id).int().sum().item()) + # combine_seq = batch["input_ids"][1:].clone() + # mask = combine_seq == self.mask_token_id + # combine_seq[mask] = batch["labels"][:-1][mask] + # combine_seq = torch.cat([input_ids[:1], combine_seq], dim=0) + # for i in range(cu_seqlens.size(0) - 1): - # seq_in = input_ids.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - # seq_out = labels.flatten()[cu_seqlens[i] : cu_seqlens[i + 1]] - # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + # # seq_in = batch["input_ids"][cu_seqlens[i] : cu_seqlens[i + 1]] + # # seq_out = batch["labels"][cu_seqlens[i] : cu_seqlens[i + 1]] + # # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) + # seq = combine_seq[cu_seqlens[i] : cu_seqlens[i + 1]] # assert p_mask[cu_seqlens[i]] == p_mask[cu_seqlens[i + 1] - 1] # print() # print(cu_seqlens[i].item(), cu_seqlens[i + 1].item(), p_mask[cu_seqlens[i + 1] - 1].item()) - # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) + # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq)))) + # # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) # print(cu_seqlens) # exit() # else: diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d03fdea39..b84f06688 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -67,7 +67,21 @@ def get_optimizer_container( raise ImportError("relevant package for the optimizer is not installed") params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method) - + # TODO hack for length-extension + # for group in 
params_groups_list: + # for param in group: + # for p in param[1]: + # # p.parameter_name_map = {k: p.parameter_name_map[k] for k in p.parameter_name_map} + # # for k in p.parameter_name_map: + # # print(k) + # p.parameter_name_map = { + # k: p.parameter_name_map[k] + # for k in p.parameter_name_map + # if ( + # 'sequence_mixer' in k or + # 'wte' in k + # ) + # } if use_optimizer_with_backward_hook: for model, params_groups in zip(model_container, params_groups_list): for param_name, param in model.named_parameters(): From 1514ad4570663fa4cc53808698e92ada47541df0 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Mon, 11 Aug 2025 20:02:39 +0000 Subject: [PATCH 15/19] Delete annealing stuff --- .../utils/flash_attention_utils.py | 13 -- lm_engine/hf_models/models/diffusion/main.py | 16 -- .../model_wrapper/pretraining_diffusion.py | 140 +++--------------- 3 files changed, 17 insertions(+), 152 deletions(-) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py index bcdb7b0e3..6a7cf4058 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/utils/flash_attention_utils.py @@ -94,18 +94,6 @@ def flash_attention( causal=causal, ) else: - from ....model_wrapper.pretraining_diffusion import ANNEAL, get_annealing_step, get_max_annealing_steps - - if ANNEAL: - anneal_proportion = get_annealing_step() / get_max_annealing_steps() - if anneal_proportion < 0.5: - right_window_size = 0 - else: - right_window_size = int((anneal_proportion - 0.75) * 2 * max_seqlen) - # print("right_window_size =", right_window_size) - window_size = (-1, right_window_size) - else: - window_size = (-1, -1) attn_output = flash_attention_2_varlen( q=query, k=key, @@ -117,7 +105,6 @@ def flash_attention( dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, - window_size=window_size, ) else: if attention_mask is None: diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py index 2fe745bc0..d4b215703 100644 --- a/lm_engine/hf_models/models/diffusion/main.py +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -27,7 +27,6 @@ def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedMode self.mask_token_id = kwargs.pop("mask_token_id") super().__init__(config, **kwargs) self.router_aux_loss_coef = getattr(config, "router_aux_loss_coef", 0) - self.first_forward_pass = True self._init_model(config, **kwargs) def _init_model(self, config: DiffusionConfig, **kwargs) -> None: @@ -75,14 +74,6 @@ def forward( ) -> CausalLMOutputWithPast: assert return_dict assert inputs_embeds is None - - if self.training and self.first_forward_pass and hasattr(self, "mask_token_id"): - with torch.no_grad(): - self.transformer.wte.weight.data[self.mask_token_id] = torch.mean( - self.transformer.wte.weight.data, dim=0 - ) - self.first_forward_pass = False - input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( input_ids=input_ids, position_ids=position_ids, @@ -172,12 +163,5 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: if self._tied_word_embeddings else self.lm_head(hidden_states) ) - assert hasattr(self, "mask_token_id") - logits[..., self.mask_token_id] - # print("mask_id_logit", - # mask_id_logit.min().item(), - # mask_id_logit.max().item(), - # mask_id_logit.median().item(), 
- # mask_id_logit.mean().item()) logits[..., self.mask_token_id] = -1.0e5 return logits diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py index 37c49c665..9b6ea2ac4 100644 --- a/lm_engine/model_wrapper/pretraining_diffusion.py +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -25,30 +25,6 @@ from .utils import broadcast_tensor_parallel_input -FIM_MIDDLE = "" - -ANNEAL = False - -MAX_ANNEALING_STEPS = 4096 -ANNEALING_STEP = 0 if ANNEAL else 4096 - - -def get_annealing_step(): - global ANNEALING_STEP - return ANNEALING_STEP - - -def get_max_annealing_steps(): - global MAX_ANNEALING_STEPS - return MAX_ANNEALING_STEPS - - -def update_annealing(): - global ANNEALING_STEP - ANNEALING_STEP = min(ANNEALING_STEP + 1, MAX_ANNEALING_STEPS) - # print("Updated to ", ANNEALING_STEP) - - class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): def __init__( self, @@ -201,12 +177,14 @@ def get_loss( def _setup_tokenizer(self) -> None: super()._setup_tokenizer() - # TODO (shawntan) Use FIM token for now. Figure out if there is a way to have actual mask token. - self.mask_token_id = 100351 # self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) - assert self.mask_token_id is not None + assert hasattr( + self.tokenizer, "mask_token_id" + ), "Tokenizer must have `mask_token_id` for diffusion_pretraining" + self.mask_token_id = self.tokenizer.mask_token_id + assert self.mask_token_id is not None, "Tokenizer must have `mask_token_id` for diffusion_pretraining" self.pad_token_id = self.tokenizer.pad_token_id assert self.pad_token_id is not None - self.ignore_token_id = -1 # self.pad_token_id # self.mask_token_id + self.ignore_token_id = -1 def _prepare_model_inputs(self, batch: dict) -> dict: device = torch.cuda.current_device() @@ -220,18 +198,14 @@ def _prepare_model_inputs(self, batch: dict) -> dict: else: tokens = batch["text"] tokens = tokens.to(device) - # if torch.rand(1, device=tokens.device) < 0.5: - # unnoised_input_ids = tokens[:, 1:] - # else: - # unnoised_input_ids = tokens[:, :-1] + + # still shifted to facilitate adaptation workflow input_ids = tokens[:, :-1] labels = tokens[:, 1:] orig_batch_size, sequence_length = input_ids.shape batch_size = orig_batch_size * 2 perm_idxs = torch.argsort(torch.rand_like(input_ids[:, :-1], dtype=torch.bfloat16), dim=-1) - # unnoised_input_ids = unnoised_input_ids.repeat_interleave(2, 0).flatten() - # input_ids = unnoised_input_ids.clone() input_ids = input_ids.repeat_interleave(2, 0).flatten() orig_input_ids = input_ids.clone() unmasked_labels = labels.repeat_interleave(2, 0).flatten() @@ -251,32 +225,10 @@ def _prepare_model_inputs(self, batch: dict) -> dict: def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): nonlocal moved_boundary - # assert ((masked_idxs - 1) >= 0).all() labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs + 1] input_ids[start_idx:end_idx][masked_idxs + 1] = self.mask_token_id p_mask[start_idx:end_idx] = p - # prob = torch.rand(1, device=tokens.device) - # if prob < 0.5: - # end_positions = unnoised_input_ids[start_idx:end_idx] == self.eos_token_id - # end_positions_noised = input_ids[start_idx:end_idx] == self.eos_token_id - # # find mismatches - # end_position_mismatch = (end_positions != end_positions_noised) & ( - # input_ids[start_idx:end_idx] == self.mask_token_id - # ) - # if end_position_mismatch.any(): - # movable_locs = torch.nonzero(end_position_mismatch, as_tuple=True)[0] - # move_start_idx = 
movable_locs[torch.randint(movable_locs.size(0), size=(1,))[0]] - # rest_unmasked = input_ids[start_idx:end_idx][move_start_idx:] != self.mask_token_id - # if rest_unmasked.any(): - # first_unmasked_idx = move_start_idx + torch.nonzero(rest_unmasked, as_tuple=True)[0].min() - # if first_unmasked_idx - move_start_idx > 1: - # moved_boundary = True - # document_end_positions[start_idx:end_idx][move_start_idx] = False - # document_end_positions[start_idx:end_idx][first_unmasked_idx] = True - # input_ids[start_idx:end_idx][first_unmasked_idx] = self.eos_token_id - # labels[start_idx:end_idx][move_start_idx:first_unmasked_idx] = self.eos_token_id - for i in range(orig_batch_size): t = torch.rand(1, device=input_ids.device)[0] p = (1 - 2 * eps) * t + eps @@ -314,75 +266,17 @@ def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): ) else: position_ids = self.position_ids - - # masked_idxs = (labels != self.ignore_token_id).nonzero(as_tuple=True)[0] - # masked_idxs, _ = torch.sort(masked_idxs) - # print(labels[masked_indices], masked_indices) assert (labels[masked_indices] != self.ignore_token_id).all() assert (input_ids[masked_indices + 1] == self.mask_token_id).all() - anneal_ratio = min((get_annealing_step() / get_max_annealing_steps()) / 0.5, 1) - if ANNEAL: - batch = { - # "input_ids": input_ids, - # "input_ids": orig_input_ids, - "input_ids": torch.where( - torch.rand_like(orig_input_ids, dtype=torch.bfloat16) < anneal_ratio, input_ids, orig_input_ids - ), - "labels": unmasked_labels, - # "labels": labels.flatten(), - "p_mask": p_mask.flatten(), - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "position_ids": position_ids, - "masked_indices": masked_indices, - } - update_annealing() - else: - batch = { - "input_ids": input_ids, - "labels": labels.flatten(), - "p_mask": p_mask.flatten(), - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "position_ids": position_ids, - "masked_indices": masked_indices, - } - - # if True: - # from transformers import PreTrainedTokenizer - # tokenizer: PreTrainedTokenizer = self.tokenizer - # def to_token_list(seq): - # output = [] - # for idx in seq: - # if idx == self.ignore_token_id: - # c = "" - # elif idx == self.mask_token_id: - # c = "_" - # else: - # c = tokenizer._convert_id_to_token(idx) - # output.append(c) - # return output - # print((input_ids == self.mask_token_id).int().sum().item()) - # combine_seq = batch["input_ids"][1:].clone() - # mask = combine_seq == self.mask_token_id - # combine_seq[mask] = batch["labels"][:-1][mask] - # combine_seq = torch.cat([input_ids[:1], combine_seq], dim=0) - - # for i in range(cu_seqlens.size(0) - 1): - # # seq_in = batch["input_ids"][cu_seqlens[i] : cu_seqlens[i + 1]] - # # seq_out = batch["labels"][cu_seqlens[i] : cu_seqlens[i + 1]] - # # seq = torch.where(seq_out == self.ignore_token_id, seq_in, seq_out) - # seq = combine_seq[cu_seqlens[i] : cu_seqlens[i + 1]] - # assert p_mask[cu_seqlens[i]] == p_mask[cu_seqlens[i + 1] - 1] - # print() - # print(cu_seqlens[i].item(), cu_seqlens[i + 1].item(), p_mask[cu_seqlens[i + 1] - 1].item()) - # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq)))) - # # print(repr(tokenizer.convert_tokens_to_string(to_token_list(seq_out)))) - # print(cu_seqlens) - # exit() - # else: - # print("No deletions.") - + batch = { + "input_ids": input_ids, + "labels": labels.flatten(), + "p_mask": p_mask.flatten(), + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "position_ids": position_ids, + "masked_indices": masked_indices, + } if 
ProcessGroupManager.is_tensor_parallel_enabled():
            batch["output_parallel_lm_logits"] = True


From 83d8474045ae0b598a414efa869c528bd516454c Mon Sep 17 00:00:00 2001
From: Shawn Tan
Date: Tue, 19 Aug 2025 04:14:03 +0000
Subject: [PATCH 16/19] Use index_select for masked hidden states and index_fill_ for the mask-token logit.

---
 lm_engine/hf_models/models/diffusion/main.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py
index d4b215703..cbe798d47 100644
--- a/lm_engine/hf_models/models/diffusion/main.py
+++ b/lm_engine/hf_models/models/diffusion/main.py
@@ -107,15 +107,15 @@ def forward(
         )

         hidden_states = transformer_outputs.last_hidden_state
+        if masked_indices is not None:
+            hidden_states = torch.index_select(hidden_states, dim=0, index=masked_indices)
+
         past_key_values = transformer_outputs.past_key_values
         del transformer_outputs

         lm_logits = None
         loss = None

-        if masked_indices is not None:
-            hidden_states = hidden_states[masked_indices]
-
         if labels is None:
             if is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute):
                 if self.m_width is not None:
@@ -163,5 +163,8 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
            if self._tied_word_embeddings
            else self.lm_head(hidden_states)
        )
-        logits[..., self.mask_token_id] = -1.0e5
+        logits.index_fill_(
+            dim=-1, index=torch.tensor(self.mask_token_id, dtype=torch.int32, device=logits.device), value=-1.0e5
+        )
+        # logits[..., self.mask_token_id] = -1.0e5
         return logits

From a8137d15946d34450232fd6b5edb2df671857f57 Mon Sep 17 00:00:00 2001
From: Shawn Tan
Date: Tue, 19 Aug 2025 04:30:46 +0000
Subject: [PATCH 17/19] Added repr for attention blocks

---
 .../modeling_utils/sequence_mixer_blocks/attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
index a812473ac..57b59cced 100644
--- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
+++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py
@@ -128,6 +128,9 @@ def __init__(
         mark_parameter_as_mup_learning_rate(self.c_attn.weight)
         mark_parameter_as_mup_learning_rate(self.c_proj.weight)

+    def extra_repr(self):
+        return f"causal={self.causal}, num_heads={self.num_heads}, num_key_value_heads={self.num_key_value_heads},"
+
     def forward(
         self,
         hidden_states: torch.Tensor,

From 0abc6ecd6282f5258eec134c284a54bda25ba8a5 Mon Sep 17 00:00:00 2001
From: Shawn Tan
Date: Wed, 20 Aug 2025 13:47:33 +0000
Subject: [PATCH 18/19] Config files.
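
This patch adds an example setup for diffusion pretraining: a 24-layer, ~1B-parameter config
(diffusion-1b-24l.yml) with non-causal softmax attention in every block, a data preparation
driver (diffusion.sh), a tokenizer patch that appends mask and pad special tokens to the GPT-2
tokenizer (modify_tokenizer.py), and a Megatron-style preprocessing script (preprocess_data.py).
The special-token ids assumed in the YAML have to line up with the tokenizer that
modify_tokenizer.py writes out; a minimal sanity check could look like the sketch below (not
part of the patch; the path is the one diffusion.sh uses and is only an assumption here).

    from transformers import AutoTokenizer

    # Tokenizer written by examples/diffusion/modify_tokenizer.py
    # (diffusion.sh saves it under ../data/tokenizer).
    tokenizer = AutoTokenizer.from_pretrained("../data/tokenizer")

    # GPT-2 ships 50257 tokens; the mask and pad tokens are appended after them,
    # which is what the bos/eos/pad/vocab settings in diffusion-1b-24l.yml assume.
    assert tokenizer.eos_token_id == 50256   # bos_token_id / eos_token_id in the YAML
    assert tokenizer.mask_token_id == 50257  # passed to the model as mask_token_id
    assert tokenizer.pad_token_id == 50258   # pad_token_id in the YAML
    assert len(tokenizer) == 50259           # vocab_size in the YAML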
--- examples/diffusion/diffusion-1b-24l.yml | 352 ++++++++++++++++++++++++ examples/diffusion/diffusion.sh | 25 ++ examples/diffusion/modify_tokenizer.py | 27 ++ examples/diffusion/preprocess_data.py | 126 +++++++++ 4 files changed, 530 insertions(+) create mode 100755 examples/diffusion/diffusion-1b-24l.yml create mode 100755 examples/diffusion/diffusion.sh create mode 100755 examples/diffusion/modify_tokenizer.py create mode 100755 examples/diffusion/preprocess_data.py diff --git a/examples/diffusion/diffusion-1b-24l.yml b/examples/diffusion/diffusion-1b-24l.yml new file mode 100755 index 000000000..ca7f01c0f --- /dev/null +++ b/examples/diffusion/diffusion-1b-24l.yml @@ -0,0 +1,352 @@ +datasets: + # class_name - data_name & data_sampling_ratio are not used but need to be passed to avoid errors + - class_name: MegatronDataset + data_name: Megatron + data_sampling_ratio: 1 + class_args: + eval_steps: 2 + data_cache_path: /proj/checkpoints/shawntan/diffusion/release/data-cache + data_path: + - 1 # mixture ratio + - /proj/checkpoints/shawntan/diffusion/release/data/dclm-dedup-gpt2-tokenized/dclm_00_text # path prefix + split: 100,0,0 + sequence_length: 4096 # context length + + +tokenizer_args: + tokenizer_name: /proj/checkpoints/shawntan/diffusion/release/data/tokenizer + +kernel_args: + kernels: + - swiglu_packed_cute + - rmsnorm_cute + - scattermoe + - flash_attention_2 + +model_args: + model_class: AutoModelForCausalLM + pretrained_config: + initializer_range: 0.1 + layer_norm_epsilon: 1e-05 + model_type: diffusion + normalization_function: rmsnorm + position_embedding_type: rope + hidden_size: 2048 + m_width: 8 + m_emb: 12 + m_residual: 0.28577380332470415 + num_layers: 24 + init_method: mup + tie_word_embeddings: true + router_aux_loss_coef: 0.01 + bos_token_id: 50256 # ensure these are same in the tokenizer + eos_token_id: 50256 + pad_token_id: 50258 + vocab_size: 50259 + max_position_embeddings: 4096 + sequence_mixer_blocks: + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + 
add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + mlp_blocks: + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - 
mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + + + use_padding_free_transformer: true + # efficient_initialization: true + reset_attention_mask: true + reset_position_ids: true + +tuning_args: + tuning_method: pretraining_diffusion + +save_args: + save_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b + save_interval: 5000 + +# TODO restoring from last checkpoint +# load_args: +# load_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b + +logging_args: + log_interval: 10 +# experiments_tracker_name: wandb +# wandb_args: +# project: diffusion-release +# name: diffusion-1b-24l + + +training_parameters: + num_training_steps: 75000 + eval_interval: 1000000000 + micro_batch_size: 2 + gradient_accumulation_steps: 4 + eval_during_training: false + +optimizer_args: + params_group_method: mup + class_name: TorchAdamW + class_args: + lr: 0.01 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + eps: 1e-10 + +lr_scheduler_args: + lr_decay_style: power + num_warmup_steps: 5000 + num_constant_steps: 0 + num_decay_steps: 70000 + extra_lr_scheduler_args: + # 4 * global_batch_size + a: 4096 + # constant + b: -0.51 + # global_batch_size in number of tokens + c: 4194304 + +mixed_precision_args: + dtype: bf16 + +distributed_args: + fsdp_algorithm: 2 + torch_compile: true + stage: 0 diff --git a/examples/diffusion/diffusion.sh b/examples/diffusion/diffusion.sh new file mode 100755 index 000000000..7cbd91802 --- /dev/null +++ b/examples/diffusion/diffusion.sh @@ -0,0 +1,25 @@ +#!/bin/bash +set -x +DATASET="Zyphra/dclm-dedup" +BASE_TOKENIZER="openai-community/gpt2" +DATA_PATH="../data/" +mkdir -p $DATA_PATH +TRAIN_PATH="$DATA_PATH/dclm-dedup-gpt2-tokenized" +mkdir -p $TRAIN_PATH +TOKENIZER_PATH="$DATA_PATH/tokenizer" +mkdir -p $TOKENIZER_PATH + +python -u examples/diffusion/modify_tokenizer.py --tokenizer $BASE_TOKENIZER --output-path $TOKENIZER_PATH + +CHUNK=0 +CHUNK_SIZE=20000000 +START_IDX=$(($CHUNK * $CHUNK_SIZE)) +END_IDX=$(($START_IDX + $CHUNK_SIZE)) +SPLIT="train[$START_IDX:$END_IDX]" + +OUTPUT_FILE="$TRAIN_PATH/dclm_`printf '%02d' $CHUNK`" +python -u examples/diffusion/preprocess_data.py \ + --input Zyphra/dclm-dedup --split $SPLIT \ + --tokenizer $TOKENIZER_PATH \ + --output-prefix $OUTPUT_FILE \ + --workers 128 --chunk-size 8192 --append-eod diff --git a/examples/diffusion/modify_tokenizer.py b/examples/diffusion/modify_tokenizer.py new file mode 100755 index 000000000..8b7016c73 --- /dev/null +++ b/examples/diffusion/modify_tokenizer.py @@ -0,0 +1,27 @@ +import sys +from argparse import ArgumentParser, Namespace + 
+from transformers import AutoTokenizer, PreTrainedTokenizer + + +def get_args() -> Namespace: + parser = ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer") + group = parser.add_argument_group(title="output data") + group.add_argument("--output-path", type=str, required=True, help="Path to binary output file without suffix") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.tokenizer, model_max_length=4096) + tokenizer.add_special_tokens({"mask_token": ""}) + tokenizer.add_special_tokens({"pad_token": ""}) + tokenizer.model_max_length = sys.maxsize + print("bos_token_id", tokenizer.bos_token_id) + print("eos_token_id", tokenizer.eos_token_id) + print("pad_token_id", tokenizer.pad_token_id) + print("Vocab size:", len(tokenizer)) + tokenizer.save_pretrained(args.output_path) diff --git a/examples/diffusion/preprocess_data.py b/examples/diffusion/preprocess_data.py new file mode 100755 index 000000000..9848354db --- /dev/null +++ b/examples/diffusion/preprocess_data.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import json +import multiprocessing +from argparse import ArgumentParser, Namespace +from typing import List + +import datasets +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from lm_engine.data.megatron.indexed_dataset import DType, MMapIndexedDatasetBuilder + + +class Encoder: + def __init__(self, tokenizer: AutoTokenizer, json_keys: List[str], append_eod: bool, tokenizer_str: str) -> None: + self.tokenizer_str = tokenizer_str + self.tokenizer = None + self.json_keys = json_keys + self.append_eod = append_eod + + def _encode_data(self, data): + ids = {} + for key in self.json_keys: + text = data[key] + # text = text.encode('ascii','backslashreplace').decode('ascii') # TODO + document_ids = self.tokenizer.encode(text) + if len(document_ids) > 0: + if self.append_eod: + document_ids.append(self.tokenizer.eos_token_id) + # decoded_text = self.tokenizer.decode(document_ids) + # print(decoded_text) + # exit() + ids[key] = document_ids + return ids + + def encode(self, json_line): + data = json.loads(json_line) + return self._encode_data(data) + + def encode_jsonl_zstd(self, bytes_obj): + json_str = bytes_obj.decode("utf-8") + return self.encode(json_str) + + def load_tokenizer(self): + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_str) + + def encode_hf(self, sample): + self.load_tokenizer() + return self._encode_data(sample) + + +def get_args() -> Namespace: + parser = ArgumentParser() + + group = parser.add_argument_group(title="input data") + group.add_argument("--input", type=str, required=True, help="Path to input JSON/Arrow") + group.add_argument( + "--subset", type=str, default=None, help="Subset argument when loading input data from a HuggingFace dataset" + ) + group.add_argument( + "--split", type=str, default="train", help="Split argument when loading input data from a HuggingFace dataset" + ) + + group.add_argument( + "--json-keys", nargs="+", default=["text"], help="space separate listed of keys to extract from json" + ) + + group = parser.add_argument_group(title="tokenizer") + group.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer") + group.add_argument("--append-eod", 
action="store_true", help="Append an token to the end of a document.") + + group = parser.add_argument_group(title="output data") + group.add_argument("--output-prefix", type=str, required=True, help="Path to binary output file without suffix") + + group = parser.add_argument_group(title="runtime") + group.add_argument("--workers", type=int, required=True, help="Number of worker processes to launch") + group.add_argument("--chunk-size", type=int, required=True, help="Chunk size assigned to each worker process") + args = parser.parse_args() + + return args + + +def main() -> None: + args = get_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + del tokenizer.model_max_length + encoder = Encoder(tokenizer, args.json_keys, args.append_eod, tokenizer_str=args.tokenizer) + + def init(): + encoder.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + + print(args.input, args.subset, args.split) + pool = multiprocessing.Pool(args.workers, initializer=init) + # ds = load_dataset(args.input, use_auth_token=True, streaming=True, split=args.split, data_dir=args.subset) + ds = load_dataset( + args.input, + data_dir=args.subset, + split=args.split, + ) + + encoded_docs = pool.imap(encoder.encode_hf, ds, args.chunk_size) + + builders = { + key: MMapIndexedDatasetBuilder( + f"{args.output_prefix}_{key}.bin", dtype=DType.optimal_dtype(tokenizer.vocab_size) + ) + for key in args.json_keys + } + + for item in tqdm(encoded_docs): + for key, document in item.items(): + builders[key].add_item(torch.IntTensor(document)) + builders[key].end_document() + + print("Done! Now finalizing.") + + for key in args.json_keys: + builders[key].finalize(f"{args.output_prefix}_{key}.idx") + + +if __name__ == "__main__": + main() From 460af93404d4a89488d0b98fde32002ac6ed29e5 Mon Sep 17 00:00:00 2001 From: Shawn Tan Date: Fri, 29 Aug 2025 18:01:39 +0000 Subject: [PATCH 19/19] Eval scripts. 
--- examples/diffusion/diffusion_eval.py | 340 +++++++++++++++++++++++++++ examples/diffusion/eval.sh | 29 +++ 2 files changed, 369 insertions(+) create mode 100755 examples/diffusion/diffusion_eval.py create mode 100755 examples/diffusion/eval.sh diff --git a/examples/diffusion/diffusion_eval.py b/examples/diffusion/diffusion_eval.py new file mode 100755 index 000000000..fd6148d28 --- /dev/null +++ b/examples/diffusion/diffusion_eval.py @@ -0,0 +1,340 @@ +""" +This file is inspired by the code from https://github.com/ML-GSAI/SMDM +""" + +import math +import random + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +from datasets import Dataset +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from generate import generate +from lm_engine import hf_models +from lm_engine.kernels import Kernel, enable_kernels + + +enable_kernels( + [Kernel.mamba2_ssm, Kernel.scattermoe, Kernel.rmsnorm_cute, Kernel.swiglu_packed_cute, Kernel.flash_attention_2] +).__enter__() + + +def set_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@register_model("lm_engine_diffusion") +class LMEngineDiffusionEvalHarness(LM): + def __init__( + self, + pretrained="", + max_length=4096, + batch_size=32, + mc_num=128, + is_check_greedy=True, + cfg=0.0, + steps=1024, + gen_length=1024, + block_length=1024, + remasking="low_confidence", + device="cuda", + mask_id=None, + **kwargs, + ): + """ + Args: + model_path: LLaDA-8B-Base model path. + mask_id: The token id of [MASK] is 126336. + max_length: the max sequence length. + batch_size: mini batch size. + mc_num: Monte Carlo estimation iterations + is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer + is generated through greedy sampling conditioned on the prompt (note that this differs from conditional + generation). We implement this verification through the suffix_greedy_prediction() function, which + returns a True/False judgment used for accuracy calculation. + When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function. + However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality, + we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False + by default, significantly accelerating the evaluation process. + cfg_scale: Unsupervised classifier-free guidance scale. 
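+            steps, gen_length, block_length, remasking: sampling arguments; they are only forwarded to
+                generate() inside generate_until() and are not used by the likelihood-based tasks.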
+ """ + super().__init__() + + accelerator = accelerate.Accelerator() + if accelerator.num_processes > 1: + self.accelerator = accelerator + else: + self.accelerator = None + print("pretrained", pretrained) + self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True) + if mask_id is None: + self.mask_id = self.tokenizer.mask_token_id + else: + self.mask_id = mask_id + model_kwargs = {"mask_token_id": self.mask_id} + + if self.accelerator is not None: + model_kwargs.update({"device_map": {"": f"{self.accelerator.device}"}}) + self.model = AutoModelForCausalLM.from_pretrained( + pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs + ) + self.model = torch.compile(self.model) + self.model.eval() + + self.device = torch.device(device) + if self.accelerator is not None: + self.model = self.accelerator.prepare(self.model) + self.device = torch.device(f"{self.accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model = self.model.to(device) + + # self.mask_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) + + self.mc_num = mc_num + self.batch_size = int(batch_size) + assert mc_num % self.batch_size == 0 + self.sampling_eps = 0.0 + self.max_length = max_length + self.is_check_greedy = is_check_greedy + + self.cfg = cfg + self.steps = steps + self.gen_length = gen_length + self.block_length = block_length + self.remasking = remasking + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def _forward_process(self, batch_plus_one, prompt_index_plus_one): + batch = batch_plus_one[:, 1:] + prompt_index = prompt_index_plus_one[1:] + b, l = batch.shape + target_len = (l - prompt_index.sum()).item() + k = torch.randint(1, target_len + 1, (), device=batch.device) + x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long() + x = ((x - 1) % target_len) + 1 + assert x.min() >= 1 and x.max() <= target_len + + indices = torch.arange(target_len, device=batch.device).repeat(b, 1) + is_mask = indices < x.unsqueeze(1) + + for i in range(b): + is_mask[i] = is_mask[i][torch.randperm(target_len)] + + is_mask = torch.cat( + (torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1 + ) + + noisy_batch = torch.cat([batch_plus_one[:, :1], torch.where(is_mask, self.mask_id, batch)], dim=1) + + return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l + 1) + + @torch.no_grad() + def get_logits(self, batch, prompt_index): + if self.cfg > 0.0: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + logits = self.model(batch).logits + + if self.cfg > 0.0: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (self.cfg + 1) * (logits - un_logits) + return logits[:, : batch.shape[1]] + + @torch.no_grad() + def get_loglikelihood(self, prefix, target): + if prefix is not None: + seq = torch.concatenate([prefix, target])[None, :] + prefix_len = len(prefix) + else: + seq = target[None, :] + prefix_len = 0 + + seq = seq.repeat((self.batch_size, 1)).to(self.device) + + prompt_index = torch.arange(seq.shape[1], device=self.device) < prefix_len + + loss_acc = [] + for _ in range(self.mc_num // self.batch_size): + perturbed_seq, p_mask = 
self._forward_process(seq, prompt_index) + mask_indices = perturbed_seq == self.mask_id + pred_mask_indices = F.pad(mask_indices[:, 1:], (0, 1), value=0) + seq_ = perturbed_seq.clone() + seq_[mask_indices] = seq[pred_mask_indices] + # print(self.tokenizer.decode(seq_[0])) + logits = self.get_logits(perturbed_seq, prompt_index) + loss = ( + F.cross_entropy(logits[pred_mask_indices], seq[mask_indices], reduction="none") + / p_mask[pred_mask_indices] + ) + loss = loss.sum() / self.batch_size + + loss_acc.append(loss.item()) + return -sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def suffix_greedy_prediction(self, prefix, target): + if not self.is_check_greedy: + return False + + seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + prefix, target = prefix.to(self.device), target.to(self.device) + seq[0, : len(prefix)] = prefix + + for i in range(len(target)): + mask_index = seq == self.mask_id + logits = self.get_logits(seq, prompt_index)[mask_index] + x0 = torch.argmax(logits, dim=-1) + + p = torch.softmax(logits.to(torch.float32), dim=-1) + confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1) + _, index = torch.sort(confidence, descending=True) + x0[index[1:]] = self.mask_id + seq[mask_index] = x0.clone() + correct = target == seq[0, len(prefix) :] + correct = torch.all(correct) + return correct + + def _encode_pair(self, context, continuation): + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer(context + continuation)["input_ids"] + context_enc = self.tokenizer(context)["input_ids"] + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood(self, requests): + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [] + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] + + assert max(prompt_len) <= 4096 + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + + ll = self.get_loglikelihood(prefix, target) + + is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) + + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + torch.cuda.empty_cache() + return out + + def loglikelihood_rolling(self, requests): + chunk_size = 4096 + loglikelihoods = [] + for i in tqdm(range(len(requests))): + x = self.tokenizer(requests[i].args[0]) + x_seq = [self.tokenizer.eos_token_id] + x["input_ids"] + # x_seq.append() + x_seq = torch.tensor(x_seq, dtype=torch.long, device=torch.cuda.current_device()) + # chunks = ((len(x_seq) - 1) // chunk_size) + 1 + chunks = int(math.ceil((x_seq.size(0) - 1) / (chunk_size - 1))) + total_ll = 0.0 + start_idx = 0 + for c in range(chunks): + x_seq_chunk = x_seq[start_idx : start_idx + chunk_size] + ll = self.get_loglikelihood(prefix=None, target=x_seq_chunk) + total_ll += ll + start_idx += chunk_size - 1 + loglikelihoods.append(total_ll) + assert start_idx >= x_seq.size(0) + return 
loglikelihoods + + def generate_until(self, requests: list[Instance]): + def _tokenize(e): + return { + "question": self.tokenizer(e["question"])["input_ids"], + "question_text": e["question"], + "until": e["until"], + } + + ds = [{"question": req.args[0], "until": req.args[1]["until"]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + for elem in tqdm(ds, desc="Generating..."): + prompt = elem["question"].unsqueeze(0).to(self.device) + stop_tokens = elem["until"] + + generated_answer = generate( + self.model, + prompt, + steps=self.steps, + gen_length=self.gen_length, + block_length=self.block_length, + temperature=0, + cfg_scale=self.cfg, + remasking=self.remasking, + mask_id=self.mask_id, + ) + + generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1] :], skip_special_tokens=False) + for stop_seq in stop_tokens: + if stop_seq in generated_answer: + generated_answer = generated_answer.split(stop_seq)[0] + + # remove special tokens + generated_answer_ids = self.tokenizer(generated_answer)["input_ids"] + generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True) + out.append(generated_answer) + + self.accelerator.wait_for_everyone() + + return out + + +if __name__ == "__main__": + # python diffusion_eval.py --tasks wikitext --model llada_dist --batch_size 1 --model_args model_path='/proj/checkpoints/shawntan/statebreaking/diffusion-1b-24l/unsharded-10000',mc_num=128 + set_seed(1234) + cli_evaluate() diff --git a/examples/diffusion/eval.sh b/examples/diffusion/eval.sh new file mode 100755 index 000000000..b85c9b646 --- /dev/null +++ b/examples/diffusion/eval.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +MODEL_PATH="$1" +set -x +RESULT_PATH=$MODEL_PATH/results/ +mkdir -p $RESULT_PATH +export PYTHONPATH=./cute-kernels +# accelerate launch diffusion_eval.py --tasks wikitext \ +accelerate launch diffusion_eval.py --tasks wikitext \ + --model lm_engine_diffusion --batch_size 8 \ + --model_args pretrained=${MODEL_PATH},mc_num=128 | tee $RESULT_PATH/wikitext.log +exit +accelerate launch diffusion_eval.py --tasks hellaswag \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/hellaswag.log +accelerate launch diffusion_eval.py --tasks winogrande \ + --num_fewshot 5 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.0,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/winogrande.log +accelerate launch diffusion_eval.py --tasks arc_challenge \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/arc_challenge.log + +accelerate launch diffusion_eval.py --tasks arc_easy \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/arc_easy.log + +accelerate launch diffusion_eval.py --tasks mmlu --num_fewshot 5 --model llada_dist --batch_size 1 \ + --model_args model_path=${MODEL_PATH},cfg=0.0,is_check_greedy=False,mc_num=1 | tee $RESULT_PATH/mmlu.log +
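
As an alternative to the CLI flow in eval.sh, the registered class can also be driven
programmatically through lm-eval's simple_evaluate. This is only a sketch: it assumes lm-eval
0.4+, that it is run from examples/diffusion with the same PYTHONPATH eval.sh sets (so the
local generate module resolves), and a placeholder checkpoint path.

    import lm_eval
    from diffusion_eval import LMEngineDiffusionEvalHarness  # importing also registers the model

    # Placeholder path to an unsharded checkpoint exported by the trainer.
    lm = LMEngineDiffusionEvalHarness(
        pretrained="/path/to/unsharded-checkpoint", batch_size=8, mc_num=128, is_check_greedy=False
    )
    results = lm_eval.simple_evaluate(model=lm, tasks=["wikitext"])
    print(results["results"])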