diff --git a/cute-kernels b/cute-kernels index a7d23d004..265584c61 160000 --- a/cute-kernels +++ b/cute-kernels @@ -1 +1 @@ -Subproject commit a7d23d0047ad3b68eed5d6be18bb21bc3eaaab1c +Subproject commit 265584c615a5acae52b68102667009eca87c70d6 diff --git a/examples/diffusion/diffusion-1b-24l.yml b/examples/diffusion/diffusion-1b-24l.yml new file mode 100755 index 000000000..ca7f01c0f --- /dev/null +++ b/examples/diffusion/diffusion-1b-24l.yml @@ -0,0 +1,352 @@ +datasets: + # class_name - data_name & data_sampling_ratio are not used but need to be passed to avoid errors + - class_name: MegatronDataset + data_name: Megatron + data_sampling_ratio: 1 + class_args: + eval_steps: 2 + data_cache_path: /proj/checkpoints/shawntan/diffusion/release/data-cache + data_path: + - 1 # mixture ratio + - /proj/checkpoints/shawntan/diffusion/release/data/dclm-dedup-gpt2-tokenized/dclm_00_text # path prefix + split: 100,0,0 + sequence_length: 4096 # context length + + +tokenizer_args: + tokenizer_name: /proj/checkpoints/shawntan/diffusion/release/data/tokenizer + +kernel_args: + kernels: + - swiglu_packed_cute + - rmsnorm_cute + - scattermoe + - flash_attention_2 + +model_args: + model_class: AutoModelForCausalLM + pretrained_config: + initializer_range: 0.1 + layer_norm_epsilon: 1e-05 + model_type: diffusion + normalization_function: rmsnorm + position_embedding_type: rope + hidden_size: 2048 + m_width: 8 + m_emb: 12 + m_residual: 0.28577380332470415 + num_layers: 24 + init_method: mup + tie_word_embeddings: true + router_aux_loss_coef: 0.01 + bos_token_id: 50256 # ensure these are same in the tokenizer + eos_token_id: 50256 + pad_token_id: 50258 + vocab_size: 50259 + max_position_embeddings: 4096 + sequence_mixer_blocks: + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: 
softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + - sequence_mixer_type: softmax_attention + causal: false + num_attention_heads: 16 + num_key_value_heads: 16 + add_bias: false + attention_multiplier: 0.0078125 + mlp_blocks: + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + 
intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + - mlp_type: MLP + activation_function: swiglu + intermediate_size: 4096 + add_bias: false + + + use_padding_free_transformer: true + # efficient_initialization: true + reset_attention_mask: true + reset_position_ids: true + +tuning_args: + tuning_method: pretraining_diffusion + +save_args: + save_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b + save_interval: 5000 + +# TODO restoring from last checkpoint +# load_args: +# load_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b + +logging_args: + log_interval: 10 +# experiments_tracker_name: wandb +# wandb_args: +# project: diffusion-release +# name: diffusion-1b-24l + + +training_parameters: + num_training_steps: 75000 + eval_interval: 1000000000 + micro_batch_size: 2 + gradient_accumulation_steps: 4 + eval_during_training: false + +optimizer_args: + params_group_method: mup + class_name: TorchAdamW + class_args: + lr: 0.01 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + eps: 1e-10 + +lr_scheduler_args: + lr_decay_style: power + num_warmup_steps: 5000 + num_constant_steps: 0 + num_decay_steps: 70000 + extra_lr_scheduler_args: + # 4 * global_batch_size + a: 4096 + # constant + b: -0.51 + # global_batch_size in number of tokens + c: 4194304 + +mixed_precision_args: + dtype: bf16 + +distributed_args: + fsdp_algorithm: 2 + torch_compile: true + stage: 0 diff --git a/examples/diffusion/diffusion.sh b/examples/diffusion/diffusion.sh new file mode 100755 index 000000000..7cbd91802 --- /dev/null +++ b/examples/diffusion/diffusion.sh @@ -0,0 +1,25 @@ +#!/bin/bash +set -x +DATASET="Zyphra/dclm-dedup" +BASE_TOKENIZER="openai-community/gpt2" +DATA_PATH="../data/" +mkdir -p $DATA_PATH +TRAIN_PATH="$DATA_PATH/dclm-dedup-gpt2-tokenized" +mkdir -p $TRAIN_PATH +TOKENIZER_PATH="$DATA_PATH/tokenizer" +mkdir -p $TOKENIZER_PATH + +python -u examples/diffusion/modify_tokenizer.py --tokenizer $BASE_TOKENIZER --output-path $TOKENIZER_PATH + +CHUNK=0 +CHUNK_SIZE=20000000 +START_IDX=$(($CHUNK * $CHUNK_SIZE)) +END_IDX=$(($START_IDX + $CHUNK_SIZE)) +SPLIT="train[$START_IDX:$END_IDX]" + +OUTPUT_FILE="$TRAIN_PATH/dclm_`printf '%02d' $CHUNK`" +python -u examples/diffusion/preprocess_data.py \ + --input Zyphra/dclm-dedup --split $SPLIT \ + --tokenizer $TOKENIZER_PATH \ + --output-prefix $OUTPUT_FILE \ + --workers 128 --chunk-size 8192 --append-eod diff --git a/examples/diffusion/diffusion_eval.py b/examples/diffusion/diffusion_eval.py new file mode 100755 index 000000000..fd6148d28 --- /dev/null +++ b/examples/diffusion/diffusion_eval.py @@ -0,0 +1,340 @@ +""" +This file is inspired by the code from https://github.com/ML-GSAI/SMDM +""" + +import math +import random + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +from datasets import Dataset +from lm_eval.__main__ import cli_evaluate +from lm_eval.api.instance 
import Instance +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from generate import generate +from lm_engine import hf_models +from lm_engine.kernels import Kernel, enable_kernels + + +enable_kernels( + [Kernel.mamba2_ssm, Kernel.scattermoe, Kernel.rmsnorm_cute, Kernel.swiglu_packed_cute, Kernel.flash_attention_2] +).__enter__() + + +def set_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@register_model("lm_engine_diffusion") +class LMEngineDiffusionEvalHarness(LM): + def __init__( + self, + pretrained="", + max_length=4096, + batch_size=32, + mc_num=128, + is_check_greedy=True, + cfg=0.0, + steps=1024, + gen_length=1024, + block_length=1024, + remasking="low_confidence", + device="cuda", + mask_id=None, + **kwargs, + ): + """ + Args: + model_path: LLaDA-8B-Base model path. + mask_id: The token id of [MASK] is 126336. + max_length: the max sequence length. + batch_size: mini batch size. + mc_num: Monte Carlo estimation iterations + is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer + is generated through greedy sampling conditioned on the prompt (note that this differs from conditional + generation). We implement this verification through the suffix_greedy_prediction() function, which + returns a True/False judgment used for accuracy calculation. + When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function. + However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality, + we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False + by default, significantly accelerating the evaluation process. + cfg_scale: Unsupervised classifier-free guidance scale. 
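+ steps, gen_length, block_length: sampling settings forwarded to generate() when running generate_until tasks.
+ remasking: remasking strategy forwarded to generate() (defaults to "low_confidence").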
+ """ + super().__init__() + + accelerator = accelerate.Accelerator() + if accelerator.num_processes > 1: + self.accelerator = accelerator + else: + self.accelerator = None + print("pretrained", pretrained) + self.tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True) + if mask_id is None: + self.mask_id = self.tokenizer.mask_token_id + else: + self.mask_id = mask_id + model_kwargs = {"mask_token_id": self.mask_id} + + if self.accelerator is not None: + model_kwargs.update({"device_map": {"": f"{self.accelerator.device}"}}) + self.model = AutoModelForCausalLM.from_pretrained( + pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs + ) + self.model = torch.compile(self.model) + self.model.eval() + + self.device = torch.device(device) + if self.accelerator is not None: + self.model = self.accelerator.prepare(self.model) + self.device = torch.device(f"{self.accelerator.device}") + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + else: + self.model = self.model.to(device) + + # self.mask_id = self.tokenizer.convert_tokens_to_ids(FIM_MIDDLE) + + self.mc_num = mc_num + self.batch_size = int(batch_size) + assert mc_num % self.batch_size == 0 + self.sampling_eps = 0.0 + self.max_length = max_length + self.is_check_greedy = is_check_greedy + + self.cfg = cfg + self.steps = steps + self.gen_length = gen_length + self.block_length = block_length + self.remasking = remasking + + @property + def rank(self): + return self._rank + + @property + def world_size(self): + return self._world_size + + def _forward_process(self, batch_plus_one, prompt_index_plus_one): + batch = batch_plus_one[:, 1:] + prompt_index = prompt_index_plus_one[1:] + b, l = batch.shape + target_len = (l - prompt_index.sum()).item() + k = torch.randint(1, target_len + 1, (), device=batch.device) + x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long() + x = ((x - 1) % target_len) + 1 + assert x.min() >= 1 and x.max() <= target_len + + indices = torch.arange(target_len, device=batch.device).repeat(b, 1) + is_mask = indices < x.unsqueeze(1) + + for i in range(b): + is_mask[i] = is_mask[i][torch.randperm(target_len)] + + is_mask = torch.cat( + (torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1 + ) + + noisy_batch = torch.cat([batch_plus_one[:, :1], torch.where(is_mask, self.mask_id, batch)], dim=1) + + return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l + 1) + + @torch.no_grad() + def get_logits(self, batch, prompt_index): + if self.cfg > 0.0: + assert len(prompt_index) == batch.shape[1] + prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1) + un_batch = batch.clone() + un_batch[prompt_index] = self.mask_id + batch = torch.cat([batch, un_batch]) + + logits = self.model(batch).logits + + if self.cfg > 0.0: + logits, un_logits = torch.chunk(logits, 2, dim=0) + logits = un_logits + (self.cfg + 1) * (logits - un_logits) + return logits[:, : batch.shape[1]] + + @torch.no_grad() + def get_loglikelihood(self, prefix, target): + if prefix is not None: + seq = torch.concatenate([prefix, target])[None, :] + prefix_len = len(prefix) + else: + seq = target[None, :] + prefix_len = 0 + + seq = seq.repeat((self.batch_size, 1)).to(self.device) + + prompt_index = torch.arange(seq.shape[1], device=self.device) < prefix_len + + loss_acc = [] + for _ in range(self.mc_num // self.batch_size): + perturbed_seq, p_mask = 
self._forward_process(seq, prompt_index) + mask_indices = perturbed_seq == self.mask_id + pred_mask_indices = F.pad(mask_indices[:, 1:], (0, 1), value=0) + seq_ = perturbed_seq.clone() + seq_[mask_indices] = seq[pred_mask_indices] + # print(self.tokenizer.decode(seq_[0])) + logits = self.get_logits(perturbed_seq, prompt_index) + loss = ( + F.cross_entropy(logits[pred_mask_indices], seq[mask_indices], reduction="none") + / p_mask[pred_mask_indices] + ) + loss = loss.sum() / self.batch_size + + loss_acc.append(loss.item()) + return -sum(loss_acc) / len(loss_acc) + + @torch.no_grad() + def suffix_greedy_prediction(self, prefix, target): + if not self.is_check_greedy: + return False + + seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device) + prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix) + prefix, target = prefix.to(self.device), target.to(self.device) + seq[0, : len(prefix)] = prefix + + for i in range(len(target)): + mask_index = seq == self.mask_id + logits = self.get_logits(seq, prompt_index)[mask_index] + x0 = torch.argmax(logits, dim=-1) + + p = torch.softmax(logits.to(torch.float32), dim=-1) + confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1) + _, index = torch.sort(confidence, descending=True) + x0[index[1:]] = self.mask_id + seq[mask_index] = x0.clone() + correct = target == seq[0, len(prefix) :] + correct = torch.all(correct) + return correct + + def _encode_pair(self, context, continuation): + n_spaces = len(context) - len(context.rstrip()) + if n_spaces > 0: + continuation = context[-n_spaces:] + continuation + context = context[:-n_spaces] + + whole_enc = self.tokenizer(context + continuation)["input_ids"] + context_enc = self.tokenizer(context)["input_ids"] + + context_enc_len = len(context_enc) + continuation_enc = whole_enc[context_enc_len:] + + return context_enc, continuation_enc + + def loglikelihood(self, requests): + def _tokenize(e): + prefix, target = self._encode_pair(e["prefix"], e["target"]) + return { + "prefix_text": e["prefix"], + "target_text": e["target"], + "prefix": prefix, + "target": target, + } + + ds = [] + ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds] + + assert max(prompt_len) <= 4096 + + out = [] + with torch.no_grad(): + for elem in tqdm(ds, desc="Computing likelihood..."): + prefix = elem["prefix"] + target = elem["target"] + + ll = self.get_loglikelihood(prefix, target) + + is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target) + + out.append((ll, 1.0 if is_target_greedy_dec else 0.0)) + torch.cuda.empty_cache() + return out + + def loglikelihood_rolling(self, requests): + chunk_size = 4096 + loglikelihoods = [] + for i in tqdm(range(len(requests))): + x = self.tokenizer(requests[i].args[0]) + x_seq = [self.tokenizer.eos_token_id] + x["input_ids"] + # x_seq.append() + x_seq = torch.tensor(x_seq, dtype=torch.long, device=torch.cuda.current_device()) + # chunks = ((len(x_seq) - 1) // chunk_size) + 1 + chunks = int(math.ceil((x_seq.size(0) - 1) / (chunk_size - 1))) + total_ll = 0.0 + start_idx = 0 + for c in range(chunks): + x_seq_chunk = x_seq[start_idx : start_idx + chunk_size] + ll = self.get_loglikelihood(prefix=None, target=x_seq_chunk) + total_ll += ll + start_idx += chunk_size - 1 + loglikelihoods.append(total_ll) + assert start_idx >= x_seq.size(0) + return 
loglikelihoods + + def generate_until(self, requests: list[Instance]): + def _tokenize(e): + return { + "question": self.tokenizer(e["question"])["input_ids"], + "question_text": e["question"], + "until": e["until"], + } + + ds = [{"question": req.args[0], "until": req.args[1]["until"]} for req in requests] + ds = Dataset.from_list(ds) + ds = ds.map(_tokenize) + ds = ds.with_format("torch") + + out = [] + for elem in tqdm(ds, desc="Generating..."): + prompt = elem["question"].unsqueeze(0).to(self.device) + stop_tokens = elem["until"] + + generated_answer = generate( + self.model, + prompt, + steps=self.steps, + gen_length=self.gen_length, + block_length=self.block_length, + temperature=0, + cfg_scale=self.cfg, + remasking=self.remasking, + mask_id=self.mask_id, + ) + + generated_answer = self.tokenizer.decode(generated_answer[0][prompt.shape[1] :], skip_special_tokens=False) + for stop_seq in stop_tokens: + if stop_seq in generated_answer: + generated_answer = generated_answer.split(stop_seq)[0] + + # remove special tokens + generated_answer_ids = self.tokenizer(generated_answer)["input_ids"] + generated_answer = self.tokenizer.decode(generated_answer_ids, skip_special_tokens=True) + out.append(generated_answer) + + self.accelerator.wait_for_everyone() + + return out + + +if __name__ == "__main__": + # python diffusion_eval.py --tasks wikitext --model llada_dist --batch_size 1 --model_args model_path='/proj/checkpoints/shawntan/statebreaking/diffusion-1b-24l/unsharded-10000',mc_num=128 + set_seed(1234) + cli_evaluate() diff --git a/examples/diffusion/eval.sh b/examples/diffusion/eval.sh new file mode 100755 index 000000000..b85c9b646 --- /dev/null +++ b/examples/diffusion/eval.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +MODEL_PATH="$1" +set -x +RESULT_PATH=$MODEL_PATH/results/ +mkdir -p $RESULT_PATH +export PYTHONPATH=./cute-kernels +# accelerate launch diffusion_eval.py --tasks wikitext \ +accelerate launch diffusion_eval.py --tasks wikitext \ + --model lm_engine_diffusion --batch_size 8 \ + --model_args pretrained=${MODEL_PATH},mc_num=128 | tee $RESULT_PATH/wikitext.log +exit +accelerate launch diffusion_eval.py --tasks hellaswag \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/hellaswag.log +accelerate launch diffusion_eval.py --tasks winogrande \ + --num_fewshot 5 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.0,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/winogrande.log +accelerate launch diffusion_eval.py --tasks arc_challenge \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/arc_challenge.log + +accelerate launch diffusion_eval.py --tasks arc_easy \ + --num_fewshot 0 --model llada_dist --batch_size 8 \ + --model_args model_path=${MODEL_PATH},cfg=0.5,is_check_greedy=False,mc_num=128 | tee $RESULT_PATH/arc_easy.log + +accelerate launch diffusion_eval.py --tasks mmlu --num_fewshot 5 --model llada_dist --batch_size 1 \ + --model_args model_path=${MODEL_PATH},cfg=0.0,is_check_greedy=False,mc_num=1 | tee $RESULT_PATH/mmlu.log + diff --git a/examples/diffusion/modify_tokenizer.py b/examples/diffusion/modify_tokenizer.py new file mode 100755 index 000000000..8b7016c73 --- /dev/null +++ b/examples/diffusion/modify_tokenizer.py @@ -0,0 +1,27 @@ +import sys +from argparse import ArgumentParser, Namespace + +from transformers import AutoTokenizer, 
PreTrainedTokenizer + + +def get_args() -> Namespace: + parser = ArgumentParser() + group = parser.add_argument_group(title="input data") + group.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer") + group = parser.add_argument_group(title="output data") + group.add_argument("--output-path", type=str, required=True, help="Path to binary output file without suffix") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(args.tokenizer, model_max_length=4096) + tokenizer.add_special_tokens({"mask_token": ""}) + tokenizer.add_special_tokens({"pad_token": ""}) + tokenizer.model_max_length = sys.maxsize + print("bos_token_id", tokenizer.bos_token_id) + print("eos_token_id", tokenizer.eos_token_id) + print("pad_token_id", tokenizer.pad_token_id) + print("Vocab size:", len(tokenizer)) + tokenizer.save_pretrained(args.output_path) diff --git a/examples/diffusion/preprocess_data.py b/examples/diffusion/preprocess_data.py new file mode 100755 index 000000000..9848354db --- /dev/null +++ b/examples/diffusion/preprocess_data.py @@ -0,0 +1,126 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import json +import multiprocessing +from argparse import ArgumentParser, Namespace +from typing import List + +import datasets +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + +from lm_engine.data.megatron.indexed_dataset import DType, MMapIndexedDatasetBuilder + + +class Encoder: + def __init__(self, tokenizer: AutoTokenizer, json_keys: List[str], append_eod: bool, tokenizer_str: str) -> None: + self.tokenizer_str = tokenizer_str + self.tokenizer = None + self.json_keys = json_keys + self.append_eod = append_eod + + def _encode_data(self, data): + ids = {} + for key in self.json_keys: + text = data[key] + # text = text.encode('ascii','backslashreplace').decode('ascii') # TODO + document_ids = self.tokenizer.encode(text) + if len(document_ids) > 0: + if self.append_eod: + document_ids.append(self.tokenizer.eos_token_id) + # decoded_text = self.tokenizer.decode(document_ids) + # print(decoded_text) + # exit() + ids[key] = document_ids + return ids + + def encode(self, json_line): + data = json.loads(json_line) + return self._encode_data(data) + + def encode_jsonl_zstd(self, bytes_obj): + json_str = bytes_obj.decode("utf-8") + return self.encode(json_str) + + def load_tokenizer(self): + if self.tokenizer is None: + self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_str) + + def encode_hf(self, sample): + self.load_tokenizer() + return self._encode_data(sample) + + +def get_args() -> Namespace: + parser = ArgumentParser() + + group = parser.add_argument_group(title="input data") + group.add_argument("--input", type=str, required=True, help="Path to input JSON/Arrow") + group.add_argument( + "--subset", type=str, default=None, help="Subset argument when loading input data from a HuggingFace dataset" + ) + group.add_argument( + "--split", type=str, default="train", help="Split argument when loading input data from a HuggingFace dataset" + ) + + group.add_argument( + "--json-keys", nargs="+", default=["text"], help="space separate listed of keys to extract from json" + ) + + group = parser.add_argument_group(title="tokenizer") + group.add_argument("--tokenizer", type=str, required=True, help="Path to the tokenizer") + group.add_argument("--append-eod", action="store_true", help="Append an 
token to the end of a document.") + + group = parser.add_argument_group(title="output data") + group.add_argument("--output-prefix", type=str, required=True, help="Path to binary output file without suffix") + + group = parser.add_argument_group(title="runtime") + group.add_argument("--workers", type=int, required=True, help="Number of worker processes to launch") + group.add_argument("--chunk-size", type=int, required=True, help="Chunk size assigned to each worker process") + args = parser.parse_args() + + return args + + +def main() -> None: + args = get_args() + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + del tokenizer.model_max_length + encoder = Encoder(tokenizer, args.json_keys, args.append_eod, tokenizer_str=args.tokenizer) + + def init(): + encoder.tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + + print(args.input, args.subset, args.split) + pool = multiprocessing.Pool(args.workers, initializer=init) + # ds = load_dataset(args.input, use_auth_token=True, streaming=True, split=args.split, data_dir=args.subset) + ds = load_dataset( + args.input, + data_dir=args.subset, + split=args.split, + ) + + encoded_docs = pool.imap(encoder.encode_hf, ds, args.chunk_size) + + builders = { + key: MMapIndexedDatasetBuilder( + f"{args.output_prefix}_{key}.bin", dtype=DType.optimal_dtype(tokenizer.vocab_size) + ) + for key in args.json_keys + } + + for item in tqdm(encoded_docs): + for key, document in item.items(): + builders[key].add_item(torch.IntTensor(document)) + builders[key].end_document() + + print("Done! Now finalizing.") + + for key in args.json_keys: + builders[key].finalize(f"{args.output_prefix}_{key}.idx") + + +if __name__ == "__main__": + main() diff --git a/lm_engine/enums.py b/lm_engine/enums.py index 66ee05fa2..6abb24823 100644 --- a/lm_engine/enums.py +++ b/lm_engine/enums.py @@ -33,6 +33,7 @@ class TuningMethod(Enum): """training method""" pretraining = "pretraining" + pretraining_diffusion = "pretraining_diffusion" full_finetuning = "full_finetuning" distillation = "distillation" diff --git a/lm_engine/hf_models/config/__init__.py b/lm_engine/hf_models/config/__init__.py index fe76d8fae..6f7ebe87b 100644 --- a/lm_engine/hf_models/config/__init__.py +++ b/lm_engine/hf_models/config/__init__.py @@ -215,7 +215,6 @@ def _set_sequence_mixer_blocks(self) -> None: sequence_mixer_block["intermediate_size"] = sequence_mixer_block.pop( "intermediate_size", 2 * self.hidden_size ) - sequence_mixer_blocks.append(_SEQUENCE_MIXER_CONFIG_CLASSES[sequence_mixer_type](**sequence_mixer_block)) self.sequence_mixer_blocks = sequence_mixer_blocks diff --git a/lm_engine/hf_models/config/sequence_mixer.py b/lm_engine/hf_models/config/sequence_mixer.py index da103eaed..d8b3396d0 100644 --- a/lm_engine/hf_models/config/sequence_mixer.py +++ b/lm_engine/hf_models/config/sequence_mixer.py @@ -14,6 +14,7 @@ class _SoftmaxAttentionArgs(BaseArgs): softmax_dropout: float = 0 dropout: float = 0 add_bias: bool = True + causal: bool = True attention_multiplier: float | None = None def model_post_init(self, __context: Any) -> None: diff --git a/lm_engine/hf_models/mixins/dense/layer.py b/lm_engine/hf_models/mixins/dense/layer.py index a0c3541cd..90fc34360 100644 --- a/lm_engine/hf_models/mixins/dense/layer.py +++ b/lm_engine/hf_models/mixins/dense/layer.py @@ -25,7 +25,9 @@ def __init__( self.ln_1 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) + self.sequence_mixer = get_sequence_mixer(config, True, 
use_padding_free_transformer, layer_idx) + self.ln_2 = get_normalization_function( config.normalization_function, hidden_size, eps=config.layer_norm_epsilon ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py index dfdfdfd32..901e457b2 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/__init__.py @@ -126,7 +126,7 @@ def get_sequence_mixer( initializer_range=config.initializer_range, m_width=config.m_width, num_layers=config.num_layers, - causal=causal, + causal=causal if not hasattr(block, "causal") else block.causal, layer_idx=layer_idx, ) diff --git a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py index a812473ac..57b59cced 100644 --- a/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py +++ b/lm_engine/hf_models/modeling_utils/sequence_mixer_blocks/attention.py @@ -128,6 +128,9 @@ def __init__( mark_parameter_as_mup_learning_rate(self.c_attn.weight) mark_parameter_as_mup_learning_rate(self.c_proj.weight) + def extra_repr(self): + return f"causal={self.causal}, num_heads={self.num_heads}, num_key_value_heads={self.num_key_value_heads}," + def forward( self, hidden_states: torch.Tensor, diff --git a/lm_engine/hf_models/models/__init__.py b/lm_engine/hf_models/models/__init__.py index ecf78511d..346d969fe 100644 --- a/lm_engine/hf_models/models/__init__.py +++ b/lm_engine/hf_models/models/__init__.py @@ -13,3 +13,4 @@ from .ladder_residual import LadderResidualConfig, LadderResidualForCausalLM, LadderResidualModel from .ladder_residual_TP import LadderResidualForCausalLM_TP, LadderResidualModel_TP from .palm import PaLMConfig, PaLMForCausalLM, PaLMModel +from .diffusion import DiffusionConfig, DiffusionMaskedLM, DiffusionModel \ No newline at end of file diff --git a/lm_engine/hf_models/models/diffusion/__init__.py b/lm_engine/hf_models/models/diffusion/__init__.py new file mode 100644 index 000000000..5c95e6bad --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/__init__.py @@ -0,0 +1,7 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from .base import DiffusionModel +from .config import DiffusionConfig +from .main import DiffusionMaskedLM diff --git a/lm_engine/hf_models/models/diffusion/base.py b/lm_engine/hf_models/models/diffusion/base.py new file mode 100644 index 000000000..0552d095a --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/base.py @@ -0,0 +1,23 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from ...mixins import BaseModelMixin, PreTrainedModelMixin +from .config import DiffusionConfig + + +class DiffusionPreTrainedModel(PreTrainedModelMixin): + config_class = DiffusionConfig + + +class DiffusionModel(DiffusionPreTrainedModel, BaseModelMixin): + def __init__(self, config, **kwargs): + if "mask_token_id" in kwargs: + self.mask_token_id = kwargs.pop("mask_token_id") + super().__init__(config, **kwargs) + + def _get_initial_hidden_state(self, input_ids, position_ids): + hidden_state = super()._get_initial_hidden_state(input_ids, position_ids) + # mask = (input_ids == self.mask_token_id)[:, None] + # hidden_state = hidden_state.masked_fill_(mask, 0) + return 
hidden_state diff --git a/lm_engine/hf_models/models/diffusion/config.py b/lm_engine/hf_models/models/diffusion/config.py new file mode 100644 index 000000000..ab1e1c79a --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/config.py @@ -0,0 +1,8 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from ..gpt_base import GPTBaseConfig + +class DiffusionConfig(GPTBaseConfig): + model_type = "diffusion" diff --git a/lm_engine/hf_models/models/diffusion/main.py b/lm_engine/hf_models/models/diffusion/main.py new file mode 100644 index 000000000..cbe798d47 --- /dev/null +++ b/lm_engine/hf_models/models/diffusion/main.py @@ -0,0 +1,170 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +import torch +import torch.nn.functional as F +from transformers import GenerationMixin + +from ....enums import Kernel +from ....kernels import is_kernel_allowed +from ...cache import GenerationCache +from ...config import CommonConfig +from ...loss import clear_aux_loss, get_autoregressive_language_modeling_loss, get_aux_loss, is_aux_loss_zero +from ...mixins import CausalLMModelMixin +from ...mixins.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear +from .base import DiffusionModel, DiffusionPreTrainedModel +from .config import DiffusionConfig + + +# from .base import PreTrainedModelMixin + + +class DiffusionMaskedLM(DiffusionPreTrainedModel): + def __init__(self, config: DiffusionConfig, **kwargs) -> DiffusionPreTrainedModel: + if "mask_token_id" in kwargs: + self.mask_token_id = kwargs.pop("mask_token_id") + super().__init__(config, **kwargs) + self.router_aux_loss_coef = getattr(config, "router_aux_loss_coef", 0) + self._init_model(config, **kwargs) + + def _init_model(self, config: DiffusionConfig, **kwargs) -> None: + if hasattr(self, "mask_token_id"): + kwargs["mask_token_id"] = self.mask_token_id + self.transformer = DiffusionModel(config, **kwargs) + + if not self._tied_word_embeddings: + self.lm_head = ParameterizedLinear( + config.hidden_size, config.vocab_size, bias=False, std=config.initializer_range + ) + + self.m_width = config.m_width + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ParameterizedEmbedding: + return self.transformer.wte + + def set_input_embeddings(self, value: ParameterizedEmbedding) -> None: + self.transformer.wte = value + + def get_output_embeddings(self) -> ParameterizedLinear: + return self.transformer.wte if self._tied_word_embeddings else self.lm_head + + def set_output_embeddings(self, new_embeddings: ParameterizedLinear) -> None: + if not self._tied_word_embeddings: + self.lm_head = new_embeddings + + def forward( + self, + input_ids: torch.Tensor | list[list[int]] | None = None, + past_key_values: GenerationCache | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | list[list[int]] | None = None, + inputs_embeds: torch.Tensor | list[list[float]] | None = None, + labels: torch.Tensor | list[list[int]] | None = None, + use_cache: bool | None = None, + return_dict: bool = True, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: int | None = None, + reduction: str = "mean", + masked_indices: torch.Tensor | None = None, + ) -> CausalLMOutputWithPast: + assert return_dict + assert 
inputs_embeds is None + input_ids, position_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + ) + # ========================================================================================== + # padding_free: + # input_ids -> (total_q) + # attention_mask -> None + # position_ids -> (total_q) + # else: + # input_ids -> (batch_size, query_length) + # attention_mask -> None or (batch_size, key_length) + # position_ids -> None or (batch_size, key_length) + # ========================================================================================== + clear_aux_loss() + + transformer_outputs: BaseModelOutputWithPast = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + hidden_states = transformer_outputs.last_hidden_state + if masked_indices is not None: + hidden_states = torch.index_select(hidden_states, dim=0, index=masked_indices) + + past_key_values = transformer_outputs.past_key_values + del transformer_outputs + + lm_logits = None + loss = None + + if labels is None: + if is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute): + if self.m_width is not None: + hidden_states = hidden_states / self.m_width + else: + lm_logits = self.get_lm_logits(hidden_states) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + else: + assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + + lm_logits = self.get_lm_logits(hidden_states) + + if self.m_width is not None: + lm_logits = lm_logits / self.m_width + + loss = get_autoregressive_language_modeling_loss( + lm_logits=lm_logits, + labels=labels, + hidden_states=None, + vocab_weight=None, + cu_seqlens=cu_seqlens, + use_padding_free_transformer=self.use_padding_free_transformer, + reduction=reduction, + shift_logits_and_labels=True, + tensor_parallel_enabled=False, + ) + aux_loss = get_aux_loss() + + if loss is not None and not is_aux_loss_zero(aux_loss): + loss = loss + self.router_aux_loss_coef * aux_loss + + return CausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=lm_logits, + past_key_values=past_key_values, + last_hidden_state=hidden_states, + ) + + def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = ( + F.linear(hidden_states, self.transformer.wte.weight) + if self._tied_word_embeddings + else self.lm_head(hidden_states) + ) + logits.index_fill_( + dim=-1, index=torch.tensor(self.mask_token_id, dtype=torch.int32, device=logits.device), value=-1.0e5 + ) + # logits[..., self.mask_token_id] = -1.0e5 + return logits diff --git a/lm_engine/hf_models/register_hf.py b/lm_engine/hf_models/register_hf.py index c54d59393..1246191de 100644 --- a/lm_engine/hf_models/register_hf.py +++ b/lm_engine/hf_models/register_hf.py @@ -19,6 +19,9 @@ PaLMConfig, PaLMForCausalLM, PaLMModel, + DiffusionConfig, + DiffusionMaskedLM, + DiffusionModel ) @@ -28,6 +31,7 @@ (GPTCrossLayerConfig, GPTCrossLayerModel, GPTCrossLayerForCausalLM), (LadderResidualConfig, LadderResidualModel, LadderResidualForCausalLM), (PaLMConfig, PaLMModel, PaLMForCausalLM), + (DiffusionConfig, DiffusionModel, DiffusionMaskedLM) ] _CUSTOM_MODEL_TYPES = [] _CUSTOM_MODEL_CLASSES = [] diff --git a/lm_engine/model_wrapper/__init__.py 
b/lm_engine/model_wrapper/__init__.py index 12615bdde..0370807d2 100644 --- a/lm_engine/model_wrapper/__init__.py +++ b/lm_engine/model_wrapper/__init__.py @@ -11,11 +11,13 @@ from .distillation import ModelWrapperForDistillation from .finetuning import ModelWrapperForFinetuning from .pretraining import ModelWrapperForPretraining +from .pretraining_diffusion import ModelWrapperForPretrainingDiffusion from .utils import broadcast_tensor_parallel_input _MODEL_CLASS_MAPPING = { TuningMethod.pretraining: ModelWrapperForPretraining, + TuningMethod.pretraining_diffusion: ModelWrapperForPretrainingDiffusion, TuningMethod.full_finetuning: ModelWrapperForFinetuning, TuningMethod.distillation: ModelWrapperForDistillation, } @@ -49,7 +51,7 @@ def get_model_container( } # pretraining model wrapper needs some extra arguments for initialization - if tuning_method in [TuningMethod.pretraining, TuningMethod.distillation]: + if tuning_method in [TuningMethod.pretraining, TuningMethod.distillation, TuningMethod.pretraining_diffusion]: kwargs["micro_batch_size"] = args.training_parameters.micro_batch_size kwargs["sequence_length"] = args.datasets[0].class_args.get("sequence_length") kwargs["reset_attention_mask"] = args.model_args.reset_attention_mask diff --git a/lm_engine/model_wrapper/pretraining_diffusion.py b/lm_engine/model_wrapper/pretraining_diffusion.py new file mode 100644 index 000000000..9b6ea2ac4 --- /dev/null +++ b/lm_engine/model_wrapper/pretraining_diffusion.py @@ -0,0 +1,283 @@ +# ************************************************** +# Copyright (c) 2025, Mayank Mishra +# ************************************************** + +from __future__ import annotations + +import torch +from torch.distributed._tensor.placement_types import Replicate +from torch.nn import functional as F +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM + +from ..dtensors import tensor_to_dtensor +from ..enums import Kernel +from ..hf_models import ( + CausalLMOutputWithPast, + PipelineParallelInput, + PipelineParallelOutput, + get_autoregressive_language_modeling_loss, + is_aux_loss_zero, +) +from ..kernels import is_kernel_allowed +from ..utils import MetricsTrackingDict, ProcessGroupManager +from .base import ModelWrapper +from .pretraining import _F, ModelWrapperForPretraining +from .utils import broadcast_tensor_parallel_input + + +class ModelWrapperForPretrainingDiffusion(ModelWrapperForPretraining): + def __init__( + self, + model_name: str | None, + pretrained_config: dict | None, + model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM, + dtype: torch.dtype, + efficient_initialization: bool, + use_padding_free_transformer: bool, + sequence_parallel: bool, + micro_batch_size: int, + sequence_length: int, + num_pipeline_stages: int, + pipeline_stage_id: int, + trust_remote_code: bool = False, + tokenizer_name: str | None = None, + additional_special_tokens: list[str] | None = None, + reset_attention_mask: bool = False, + reset_position_ids: bool = False, + keep_in_fp32: bool = True, + ) -> ModelWrapperForPretraining: + super().__init__( + model_name, + pretrained_config, + model_class, + dtype, + efficient_initialization, + use_padding_free_transformer, + sequence_parallel, + micro_batch_size, + sequence_length, + num_pipeline_stages, + pipeline_stage_id, + trust_remote_code, + tokenizer_name, + additional_special_tokens, + reset_attention_mask, + reset_position_ids, + keep_in_fp32, + ) + assert self.use_padding_free_transformer and self.reset_attention_mask + + def _get_model_kwargs(self): + 
kwargs = super()._get_model_kwargs() + if hasattr(self, "mask_token_id"): + kwargs["mask_token_id"] = self.mask_token_id + return kwargs + + def forward( + self, + batch: dict | torch.Tensor, + aux_loss_from_pipeline_parallel: torch.Tensor | float = 0, + lm_loss_multiplier: float = 1, + ) -> dict: + """forward function for a batch + + Args: + batch (dict): a dict of key, value pairs for a batch + + Returns: + torch.Tensor: loss tensor + """ + + # for pretraining we compute loss externally here instead of relying on transformers. + # this is done because megatron's dataset returns batches of length (sequence_length + 1) + # instead of (sequence_length), so we need to trim the input_ids before forward pass. + # transformers does forward pass before however and then trims the tokens. + + if not self.is_custom_model: + assert not is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + + if isinstance(batch, torch.Tensor): + batch = {"text": batch} + + if self.is_pipeline_parallel_enabled: + batch["aux_loss_from_pipeline_parallel"] = aux_loss_from_pipeline_parallel + else: + assert aux_loss_from_pipeline_parallel == 0 + + batch = self._prepare_model_inputs(batch) + labels = batch.pop("labels") + p_mask = batch.pop("p_mask") + masked_indices = batch["masked_indices"] + output: CausalLMOutputWithPast | PipelineParallelOutput = self.model(**batch, return_dict=True) + + if self.is_pipeline_parallel_enabled: + # aux_loss is returned as a 0 dimensional tensor + aux_loss = output.aux_loss + use_aux_loss = not is_aux_loss_zero(aux_loss) + + if use_aux_loss and aux_loss.dim() == 0: + aux_loss = aux_loss.unsqueeze(0) + + if self.is_last_stage: + assert isinstance(output, CausalLMOutputWithPast) + output = output.logits + else: + assert isinstance(output, PipelineParallelOutput) + output = output.hidden_states + + if use_aux_loss: + output = (output, aux_loss) + else: + assert (labels[batch["masked_indices"]] != self.ignore_token_id).all() + output = self.get_loss(output, labels, masked_indices, p_mask, lm_loss_multiplier=lm_loss_multiplier) + + return output + + def get_loss( + self, + model_outputs: CausalLMOutputWithPast, + labels: torch.Tensor, + masked_indices: torch.Tensor, + p_mask: torch.Tensor, + lm_loss_multiplier: float = 1, + ) -> torch.Tensor | dict: + tensor_parallel_enabled = ProcessGroupManager.is_tensor_parallel_enabled() + # use_fused_linear_cross_entropy_kernel = is_kernel_allowed(Kernel.fused_linear_cross_entropy_cute) + flat_logits = model_outputs.logits.flatten(0, -2) + flat_labels = labels.flatten()[masked_indices] + flat_p_mask = p_mask.flatten()[masked_indices] + # print(flat_logits.size(), flat_labels.size()) + lm_loss = ( + F.cross_entropy( + input=flat_logits, + target=flat_labels, + ignore_index=self.ignore_token_id, + reduction="none", + ) + / flat_p_mask + ).sum() / 2 + + lm_loss = lm_loss * lm_loss_multiplier + aux_loss = getattr(model_outputs, "aux_loss", 0) + + if is_aux_loss_zero(aux_loss): + loss = lm_loss + output = {"loss": loss, "lm_loss": loss} + else: + if self.is_pipeline_parallel_enabled: + self._extra_metrics = self._extra_metrics + {"aux_loss": aux_loss} + + if tensor_parallel_enabled: + aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate()) + + loss = _F.apply(lm_loss, aux_loss, self.router_aux_loss_coef) + output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss} + + return output + + def _setup_tokenizer(self) -> None: + super()._setup_tokenizer() + assert hasattr( + self.tokenizer, "mask_token_id" + ), 
"Tokenizer must have `mask_token_id` for diffusion_pretraining" + self.mask_token_id = self.tokenizer.mask_token_id + assert self.mask_token_id is not None, "Tokenizer must have `mask_token_id` for diffusion_pretraining" + self.pad_token_id = self.tokenizer.pad_token_id + assert self.pad_token_id is not None + self.ignore_token_id = -1 + + def _prepare_model_inputs(self, batch: dict) -> dict: + device = torch.cuda.current_device() + if self.is_pipeline_parallel_enabled: + raise NotImplementedError("No pipeline for diffusion yet.") + else: + if ProcessGroupManager.is_tensor_parallel_enabled(): + tokens = broadcast_tensor_parallel_input( + None if batch is None else batch["text"], (self.micro_batch_size, self.sequence_length + 1) + ) + else: + tokens = batch["text"] + tokens = tokens.to(device) + + # still shifted to facilitate adaptation workflow + input_ids = tokens[:, :-1] + labels = tokens[:, 1:] + orig_batch_size, sequence_length = input_ids.shape + batch_size = orig_batch_size * 2 + + perm_idxs = torch.argsort(torch.rand_like(input_ids[:, :-1], dtype=torch.bfloat16), dim=-1) + input_ids = input_ids.repeat_interleave(2, 0).flatten() + orig_input_ids = input_ids.clone() + unmasked_labels = labels.repeat_interleave(2, 0).flatten() + labels = torch.full_like(input_ids, fill_value=self.ignore_token_id) + p_mask = torch.ones_like(input_ids, dtype=torch.bfloat16) + + # assert batch_size % 2 == 0 + masked_ptr = 0 + masked_indices = ( + torch.zeros((batch_size // 2) * (sequence_length - 1), dtype=input_ids.dtype, device=input_ids.device) - 1 + ) + + document_end_positions = unmasked_labels == self.eos_token_id + document_end_positions[sequence_length - 1 :: sequence_length] = 1 + eps = 1e-4 + moved_boundary = False + + def _apply_mask_and_fill(start_idx, end_idx, masked_idxs, p): + nonlocal moved_boundary + labels[start_idx:end_idx][masked_idxs] = input_ids[start_idx:end_idx][masked_idxs + 1] + input_ids[start_idx:end_idx][masked_idxs + 1] = self.mask_token_id + p_mask[start_idx:end_idx] = p + + for i in range(orig_batch_size): + t = torch.rand(1, device=input_ids.device)[0] + p = (1 - 2 * eps) * t + eps + sample_masked_idxs = perm_idxs[i] + mask_count = torch.round(p * (sequence_length - 1)).to(torch.int32) + masked_idxs_ = sample_masked_idxs[:mask_count] + _apply_mask_and_fill( + start_idx=2 * i * sequence_length, end_idx=(2 * i + 1) * sequence_length, masked_idxs=masked_idxs_, p=p + ) + masked_indices[masked_ptr : masked_ptr + mask_count] = 2 * i * sequence_length + masked_idxs_ + masked_ptr += mask_count + + masked_idxs_ = sample_masked_idxs[mask_count:] + mask_count = (sequence_length - 1) - mask_count + _apply_mask_and_fill( + start_idx=(2 * i + 1) * sequence_length, + end_idx=(2 * i + 2) * sequence_length, + masked_idxs=masked_idxs_, + p=1 - p, + ) + masked_indices[masked_ptr : masked_ptr + mask_count] = (2 * i + 1) * sequence_length + masked_idxs_ + masked_ptr += mask_count + # assert (masked_indices != -1).any() + + masked_indices, _ = torch.sort(masked_indices) + cu_seqlens = document_end_positions.nonzero(as_tuple=True)[0] + 1 + cu_seqlens = torch.cat([torch.tensor([0], device=input_ids.device), cu_seqlens]).to(torch.int32) + seqlen = cu_seqlens[1:] - cu_seqlens[:-1] + # we move to CPU here otherwise FlashAttention will move to CPU on every invocation i.e all layers + max_seqlen = seqlen.max().item() + + if self.reset_position_ids: + position_ids = torch.cat( + [torch.arange(0, i, 1, dtype=torch.int32, device=input_ids.device) for i in seqlen] + ) + else: + position_ids = 
self.position_ids + assert (labels[masked_indices] != self.ignore_token_id).all() + assert (input_ids[masked_indices + 1] == self.mask_token_id).all() + batch = { + "input_ids": input_ids, + "labels": labels.flatten(), + "p_mask": p_mask.flatten(), + "cu_seqlens": cu_seqlens, + "max_seqlen": max_seqlen, + "position_ids": position_ids, + "masked_indices": masked_indices, + } + if ProcessGroupManager.is_tensor_parallel_enabled(): + batch["output_parallel_lm_logits"] = True + + return batch diff --git a/lm_engine/optimization/optimizer.py b/lm_engine/optimization/optimizer.py index d03fdea39..b84f06688 100644 --- a/lm_engine/optimization/optimizer.py +++ b/lm_engine/optimization/optimizer.py @@ -67,7 +67,21 @@ def get_optimizer_container( raise ImportError("relevant package for the optimizer is not installed") params_groups_list = get_param_groups_list(model_container, optimizer_class_args, params_group_method) - + # TODO hack for length-extension + # for group in params_groups_list: + # for param in group: + # for p in param[1]: + # # p.parameter_name_map = {k: p.parameter_name_map[k] for k in p.parameter_name_map} + # # for k in p.parameter_name_map: + # # print(k) + # p.parameter_name_map = { + # k: p.parameter_name_map[k] + # for k in p.parameter_name_map + # if ( + # 'sequence_mixer' in k or + # 'wte' in k + # ) + # } if use_optimizer_with_backward_hook: for model, params_groups in zip(model_container, params_groups_list): for param_name, param in model.named_parameters(): diff --git a/lm_engine/pretrain.py b/lm_engine/pretrain.py index ca9ed1a4f..31740265b 100644 --- a/lm_engine/pretrain.py +++ b/lm_engine/pretrain.py @@ -576,7 +576,8 @@ def main(args_class: type[DistillationArgs | TrainingArgs] = TrainingArgs) -> No if args_class == TrainingArgs: assert ( - args.tuning_args.tuning_method == TuningMethod.pretraining + args.tuning_args.tuning_method == TuningMethod.pretraining or + args.tuning_args.tuning_method == TuningMethod.pretraining_diffusion ), f"unexpected tuning method ({args.tuning_args.tuning_method})" elif args_class == DistillationArgs: assert args.distributed_args.fsdp_algorithm == 2, "Distillation is only supported with FSDP-2"
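Note for reviewers (not part of the diff): a minimal standalone sketch of the complementary-masking objective that ModelWrapperForPretrainingDiffusion builds in _prepare_model_inputs and scores in get_loss. Each sequence is duplicated, a noise level p = (1 - 2*eps) * t + eps with t ~ U(0, 1) is drawn, the first round(p * L) positions of a random permutation are masked in one copy and the complementary positions in the other, and the masked-token cross-entropy is reweighted by 1 / p_mask and halved to average over the pair. The helper names and the MASK_ID / IGNORE_ID constants below are illustrative only; the shift-by-one label bookkeeping, padding-free flattening, and cu_seqlens handling of the real wrapper are omitted.

import torch
import torch.nn.functional as F

MASK_ID = 50257   # placeholder mask-token id; the real id comes from tokenizer.mask_token_id
IGNORE_ID = -1    # matches self.ignore_token_id in the wrapper


def make_complementary_masks(tokens: torch.Tensor, eps: float = 1e-4):
    # tokens: (seq_len,) long tensor of token ids
    seq_len = tokens.numel()
    t = torch.rand(())                    # noise level t ~ U(0, 1)
    p = (1 - 2 * eps) * t + eps           # masking probability, bounded away from 0 and 1
    perm = torch.randperm(seq_len)
    k = int(torch.round(p * seq_len))     # number of positions masked in the first copy

    inputs, labels, p_mask = [], [], []
    for masked_pos, prob in ((perm[:k], p), (perm[k:], 1 - p)):
        x = tokens.clone()
        y = torch.full_like(tokens, IGNORE_ID)
        y[masked_pos] = tokens[masked_pos]            # supervise only the masked positions
        x[masked_pos] = MASK_ID
        inputs.append(x)
        labels.append(y)
        p_mask.append(torch.full((seq_len,), float(prob)))
    return torch.stack(inputs), torch.stack(labels), torch.stack(p_mask)


def masked_diffusion_loss(logits: torch.Tensor, labels: torch.Tensor, p_mask: torch.Tensor):
    # logits: (2, seq_len, vocab); cross-entropy on masked tokens, reweighted by 1 / p_mask,
    # then halved to average over the two complementary copies (the `/ 2` in get_loss)
    masked = labels != IGNORE_ID
    loss = F.cross_entropy(logits[masked], labels[masked], reduction="none")
    return (loss / p_mask[masked]).sum() / 2

Pairing each noise level p with its complement 1 - p means every position of the original sequence is supervised exactly once across the two copies, which is what the sum() / 2 in get_loss averages over; this mirrors the repeat_interleave(2, 0) duplication and the masked_indices bookkeeping in the wrapper.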