From 4c6aa8dc126da46152527a74e6cc6d745c0da2b0 Mon Sep 17 00:00:00 2001 From: luguoshan Date: Tue, 11 Nov 2025 00:24:54 +0800 Subject: [PATCH 01/10] add_tpd --- README.md | 4 +- .../algo/trainable_parallel_decoding.rst | 106 ++++++++++++++++++ docs/source/index.rst | 1 + tasks/train_llada2_bd.py | 77 +++++++++++-- 4 files changed, 178 insertions(+), 10 deletions(-) create mode 100644 docs/source/algo/trainable_parallel_decoding.rst diff --git a/README.md b/README.md index 6d5e34e..484e947 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,8 @@ We are actively working on enhancing the project with new features and improvements. Our roadmap for the near future includes: -- [x] **Comprehensive Documentation**: A full documentation site is underway, which will feature in-depth tutorials, API references, and best practices. -- [ ] **Trainable Parallel Decoding**: Integration of support for trainable parallel decoding to enable more advanced use cases. +- [☑️] **Comprehensive Documentation**: A full documentation site is underway, which will feature in-depth tutorials, API references, and best practices. +- [☑️] **Trainable Parallel Decoding**: Integration of support for trainable parallel decoding to enable more advanced use cases. Stay tuned for these updates! diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst new file mode 100644 index 0000000..c3a8fb7 --- /dev/null +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -0,0 +1,106 @@ +Trainable Parallel Decoding +============================ + +Last updated: 2025-11-10 + +Trainable Parallel Decoding is a novel approach to accelerate Discrete Diffusion Language Models (DDLMs) by learning to decode multiple tokens simultaneously during training, thereby reducing inference latency while maintaining generation quality. + +Overview +-------- + +Traditional DDLMs suffer from high inference latency due to their iterative, multi-step sampling process. Trainable Parallel Decoding addresses this limitation by introducing a second-stage fine-tuning paradigm that teaches the model to predict multiple future tokens in a single forward pass. This approach transforms the sequential generation process into a more parallelizable one, significantly reducing the number of required sampling steps. + +The framework currently supports two complementary techniques: + +1. **Path Distillation (Trajectory Compression)**: Learning to jump between non-consecutive states in optimal generation trajectories +2. **DPARALLEL**: Entropy-based loss regularization to accelerate parallel decoding learning + +Path Distillation (Trajectory Compression) +------------------------------------------ + +Path Distillation is inspired by the observation in Seed Diffusion [arXiv:2508.02193] that training on high-quality generation paths can significantly improve model efficiency. This method consists of two main stages: + +High-Quality Trajectory Distillation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The first stage involves creating a dataset of "golden" trajectories through the following process: + +1. **Trajectory Generation**: Use a pre-trained DDLM to sample generation paths on a domain-specific dataset (e.g., 200,000 math problems) +2. **Quality Filtering**: Apply an external verifier to filter trajectories that produce correct outputs +3. 
**Dataset Construction**: Retain only high-quality trajectories that pass verification + +Mathematically, given a trajectory :math:`\tau = (s_N, s_{N-1}, \dots, s_0)` representing states from fully masked to final output, we filter: + +.. math:: + \mathcal{T}_{\text{gold}} = \{ \tau \in \mathcal{T} \,|\, V(s_0^{\tau}) = \text{True} \} + +where :math:`V(\cdot)` is the external verifier function. + +Compressed Transition Learning +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The second stage fine-tunes the model to predict multi-step transitions instead of single-step ones: + +1. **Training Instance Construction**: For each trajectory, randomly sample timestamps :math:`i` and :math:`j` where :math:`N \ge i > j \ge 0` +2. **Target Identification**: The model learns to predict tokens that are [MASK] in :math:`s_i` but revealed in :math:`s_j` +3. **Loss Optimization**: Minimize the negative log-likelihood of compressed transitions + +The fine-tuning objective is: + +.. math:: + \mathcal{L}_{\text{compress}}(\theta) = - \mathbb{E}_{\tau \in \mathcal{T}_{\text{gold}}, \, i,j \sim U(\tau)} \left[ \sum_{k \in \Delta_{i \to j}} \log p_\theta(x_k = s_j[k] \,|\, s_i) \right] + +where :math:`\Delta_{i \to j} = M_i \setminus M_j` represents the indices of tokens to be predicted. + +Implementation Details +^^^^^^^^^^^^^^^^^^^^^^ + +The data preparation process involves: + +1. **Offline Dataset Creation**: Generate and filter trajectories offline +2. **Data Format**: Prepare input_ids, noisy_input_ids, and labels for training +3. **Training Configuration**: Use standard SFT training with the compressed transition objective + +The training data format should include: + +- ``input_ids``: The starting state :math:`s_i` with appropriate masking +- ``noisy_input_ids``: The noised version of :math:`s_i` +- ``labels``: The target tokens to predict (tokens in :math:`s_j` that differ from :math:`s_i`) + +DPARALLEL: Learnable Parallel Decoding +-------------------------------------- + +DPARALLEL is inspired by the DPARALLEL: Learnable Parallel Decoding for DDLMs approach, which introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities. + +Methodology +~~~~~~~~~~~ + +The key insight is that by adding a confidence-based loss term during supervised fine-tuning, we can guide the model toward making confident, parallel predictions. This is achieved through: + +1. **Entropy Regularization**: Add a loss term based on the entropy of the model's predictions +2. **Confidence Scoring**: Use prediction confidence as a signal for parallel decoding quality +3. **Loss Balancing**: Combine the standard cross-entropy loss with the confidence-based term + +Configuration +~~~~~~~~~~~~~ + +To enable DPARALLEL, use the following training configuration: + +.. code:: bash + + --train.confidence_beta {confidence_beta} + +Where: +- ``confidence_beta`` controls the strength of the entropy regularization (recommended value: 2.0) +- Higher values encourage more aggressive parallel decoding +- The parameter balances between generation quality and speed-up + +Training Process +^^^^^^^^^^^^^^^^ + +The DPARALLEL training process: + +1. **Standard SFT Setup**: Begin with standard supervised fine-tuning +2. **Loss Modification**: Add the confidence-based regularization term +3. **Hyperparameter Tuning**: Adjust ``confidence_beta`` based on desired speed-quality trade-off +4. 
**Evaluation**: Monitor both generation quality and inference speed metrics \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index c51fad0..b290076 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -37,6 +37,7 @@ dFactory documentation algo/random_mask algo/block_diffusion + algo/trainable_parallel_decoding .. toctree:: diff --git a/tasks/train_llada2_bd.py b/tasks/train_llada2_bd.py index 7a46f86..ccf12ce 100644 --- a/tasks/train_llada2_bd.py +++ b/tasks/train_llada2_bd.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Literal, Tuple, Optional import torch +import torch.nn.functional as F import torch.distributed as dist import wandb from tqdm import trange @@ -98,6 +99,10 @@ class LLaDA2TrainingArguments(TrainingArguments): default=0.999, metadata={"help": "AdamW optimizer beta2"}, ) + confidence_beta: float = field( + default=0.0, + metadata={"help": "Weight for the confidence loss entropy of correct predictions. Set to 0 to disable."}, + ) block_diffusion_mode: bool = field( default=False, metadata={"help": "If train MDM in block_diffusion mode. True: use block_diffusion, False: full_attention"} @@ -165,7 +170,48 @@ def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): # **4. Combine Masks ** return block_diagonal | offset_block_causal | block_causal - +def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + 计算在模型预测正确位置的输出分布的平均熵。 + Args: + logits (torch.Tensor): 模型的原始输出 logits,形状为 (batch_size, seq_len, vocab_size)。 + labels (torch.Tensor): 真实标签,形状为 (batch_size, seq_len)。-100表示忽略。 + Returns: + torch.Tensor: 一个标量张量,表示 confidence loss。如果没有任何正确预测,则为 0。 + """ + # 确保 labels 和 logits 在同一个设备上 + labels = labels.to(logits.device) + + # 步骤 1: 找到所有有效的标签位置 + valid_mask = (labels != -100) + if not valid_mask.any(): + return torch.tensor(0.0, device=logits.device) + + # 步骤 2: 找到模型在每个位置上概率最高的 token + predicted_tokens = torch.argmax(logits, dim=-1) + + # 步骤 3: 找到模型预测正确的那些位置 (M_c) + # 我们只关心在有效标签位置上是否预测正确 + correct_mask = (predicted_tokens == labels) & valid_mask + + # 如果没有任何一个位置预测正确,那么这个损失为 0 + if correct_mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + # 步骤 4: 计算所有位置的概率分布和熵 + # 使用 log_softmax 以保证数值稳定性 + log_probs = F.log_softmax(logits, dim=-1) + probs = torch.exp(log_probs) + # 熵的计算公式: H(p) = - sum(p * log(p)) + entropy_per_token = -torch.sum(probs * log_probs, dim=-1) + + # 步骤 5: 筛选出 M_c 位置的熵,并计算平均值 + entropy_at_correct_positions = entropy_per_token[correct_mask] + + confidence_loss = entropy_at_correct_positions.mean() + + return confidence_loss + def main(): dist.init_process_group(backend=get_nccl_backend()) args = parse_args(Arguments) @@ -415,6 +461,10 @@ def main(): total_loss = 0 synchronize() start_time = time.time() + num_accumulation_steps = len(micro_batches) + total_consistency_loss = 0 + total_confidence_loss = 0 + for micro_batch in micro_batches: environ_meter.add(micro_batch) if args.data.enable_multisource: @@ -448,13 +498,20 @@ def main(): else: noisy_logits = logits + confidence_loss = torch.tensor(0.0, device=noisy_logits.device) + if args.train.confidence_beta > 0: + confidence_loss = compute_confidence_loss( + logits=noisy_logits, + labels=labels, + ) + if args.train.same_token_labels: unscaled_loss = torch.nn.functional.cross_entropy( noisy_logits.view(-1, noisy_logits.shape[-1]), labels.view(-1), reduction="none", - ) - loss = unscaled_loss.sum() / (labels != -100).sum() / len(micro_batches) + ).view(noisy_logits.shape[0], -1) + 
consistency_loss = unscaled_loss.sum() / (labels != -100).sum() else: shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() shifted_labels = labels[:, 1:].contiguous() @@ -463,12 +520,16 @@ def main(): shifted_labels.view(-1), reduction="none", ).view(shifted_noisy_logits.shape[0], -1) - loss = unscaled_loss.sum() / (shifted_labels != -100).sum() / len(micro_batches) - + consistency_loss = unscaled_loss.sum() / (shifted_labels != -100).sum() + + combined_loss = consistency_loss + confidence_loss * args.train.confidence_beta + loss = combined_loss / num_accumulation_steps with model_bwd_context: loss.backward() total_loss += loss.item() + total_consistency_loss += consistency_loss.item() / num_accumulation_steps + total_confidence_loss += confidence_loss.item() / num_accumulation_steps del micro_batch # Prefer model-provided clip_grad_norm_ (now both FSDP1 and FSDP2 registers custom grad norm clipping) @@ -494,13 +555,13 @@ def main(): lr = max(lr_scheduler.get_last_lr()) train_metrics = environ_meter.step(delta_time, global_step=global_step) - data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") + data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, cons: {total_consistency_loss:.2f}, conf: {total_confidence_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") data_loader_tqdm.update() if args.train.global_rank == 0: if args.train.use_wandb: train_metrics.update( - {"training/loss": total_loss, "training/grad_norm": grad_norm, "training/lr": lr} + {"training/loss": total_loss, "training/cons_loss": total_consistency_loss, "training/conf_loss": total_confidence_loss, "training/grad_norm": grad_norm, "training/lr": lr} ) wandb.log(train_metrics, step=global_step) @@ -569,4 +630,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 9e38ba01f187f5c06339fee55953419a5b25697e Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 00:33:18 +0800 Subject: [PATCH 02/10] fix --- tasks/train_llada2_bd.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tasks/train_llada2_bd.py b/tasks/train_llada2_bd.py index ccf12ce..2084cc5 100644 --- a/tasks/train_llada2_bd.py +++ b/tasks/train_llada2_bd.py @@ -172,40 +172,30 @@ def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: """ - 计算在模型预测正确位置的输出分布的平均熵。 + Calculate the average entropy of the output distribution at positions where the model predicts correctly. Args: - logits (torch.Tensor): 模型的原始输出 logits,形状为 (batch_size, seq_len, vocab_size)。 - labels (torch.Tensor): 真实标签,形状为 (batch_size, seq_len)。-100表示忽略。 + logits (torch.Tensor): The raw output logits from the model, with shape (batch_size, seq_len, vocab_size). + labels (torch.Tensor): The ground truth labels, with shape (batch_size, seq_len). -100 indicates positions to be ignored. Returns: - torch.Tensor: 一个标量张量,表示 confidence loss。如果没有任何正确预测,则为 0。 + torch.Tensor: A scalar tensor representing the confidence loss. Returns 0 if there are no correct predictions. 
""" - # 确保 labels 和 logits 在同一个设备上 labels = labels.to(logits.device) - # 步骤 1: 找到所有有效的标签位置 valid_mask = (labels != -100) if not valid_mask.any(): return torch.tensor(0.0, device=logits.device) - # 步骤 2: 找到模型在每个位置上概率最高的 token predicted_tokens = torch.argmax(logits, dim=-1) - # 步骤 3: 找到模型预测正确的那些位置 (M_c) - # 我们只关心在有效标签位置上是否预测正确 correct_mask = (predicted_tokens == labels) & valid_mask - # 如果没有任何一个位置预测正确,那么这个损失为 0 if correct_mask.sum() == 0: return torch.tensor(0.0, device=logits.device) - # 步骤 4: 计算所有位置的概率分布和熵 - # 使用 log_softmax 以保证数值稳定性 log_probs = F.log_softmax(logits, dim=-1) probs = torch.exp(log_probs) - # 熵的计算公式: H(p) = - sum(p * log(p)) entropy_per_token = -torch.sum(probs * log_probs, dim=-1) - # 步骤 5: 筛选出 M_c 位置的熵,并计算平均值 entropy_at_correct_positions = entropy_per_token[correct_mask] confidence_loss = entropy_at_correct_positions.mean() From a32ab88b12750cc1fbf08ecbb1898fbe4ba341f3 Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 12:11:54 +0800 Subject: [PATCH 03/10] fix --- .../algo/trainable_parallel_decoding.rst | 3 +- tasks/train_llada2_bd.py | 63 +- tasks/train_llada2_bd_with_dparallel.py | 623 ++++++++++++++++++ 3 files changed, 630 insertions(+), 59 deletions(-) create mode 100644 tasks/train_llada2_bd_with_dparallel.py diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst index c3a8fb7..04d136f 100644 --- a/docs/source/algo/trainable_parallel_decoding.rst +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -87,8 +87,7 @@ Configuration To enable DPARALLEL, use the following training configuration: .. code:: bash - - --train.confidence_beta {confidence_beta} + sh train.sh tasks/train_llada2_bd_with_dparallel.py configs/sft/llada2_mini_bd_sft.yaml --train.confidence_beta {confidence_beta} Where: - ``confidence_beta`` controls the strength of the entropy regularization (recommended value: 2.0) diff --git a/tasks/train_llada2_bd.py b/tasks/train_llada2_bd.py index 2084cc5..b21f0bc 100644 --- a/tasks/train_llada2_bd.py +++ b/tasks/train_llada2_bd.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Tuple, Optional import torch -import torch.nn.functional as F import torch.distributed as dist import wandb from tqdm import trange @@ -99,10 +98,6 @@ class LLaDA2TrainingArguments(TrainingArguments): default=0.999, metadata={"help": "AdamW optimizer beta2"}, ) - confidence_beta: float = field( - default=0.0, - metadata={"help": "Weight for the confidence loss entropy of correct predictions. Set to 0 to disable."}, - ) block_diffusion_mode: bool = field( default=False, metadata={"help": "If train MDM in block_diffusion mode. True: use block_diffusion, False: full_attention"} @@ -170,38 +165,7 @@ def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): # **4. Combine Masks ** return block_diagonal | offset_block_causal | block_causal -def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - """ - Calculate the average entropy of the output distribution at positions where the model predicts correctly. - Args: - logits (torch.Tensor): The raw output logits from the model, with shape (batch_size, seq_len, vocab_size). - labels (torch.Tensor): The ground truth labels, with shape (batch_size, seq_len). -100 indicates positions to be ignored. - Returns: - torch.Tensor: A scalar tensor representing the confidence loss. Returns 0 if there are no correct predictions. 
- """ - labels = labels.to(logits.device) - - valid_mask = (labels != -100) - if not valid_mask.any(): - return torch.tensor(0.0, device=logits.device) - - predicted_tokens = torch.argmax(logits, dim=-1) - - correct_mask = (predicted_tokens == labels) & valid_mask - if correct_mask.sum() == 0: - return torch.tensor(0.0, device=logits.device) - - log_probs = F.log_softmax(logits, dim=-1) - probs = torch.exp(log_probs) - entropy_per_token = -torch.sum(probs * log_probs, dim=-1) - - entropy_at_correct_positions = entropy_per_token[correct_mask] - - confidence_loss = entropy_at_correct_positions.mean() - - return confidence_loss - def main(): dist.init_process_group(backend=get_nccl_backend()) args = parse_args(Arguments) @@ -451,10 +415,6 @@ def main(): total_loss = 0 synchronize() start_time = time.time() - num_accumulation_steps = len(micro_batches) - total_consistency_loss = 0 - total_confidence_loss = 0 - for micro_batch in micro_batches: environ_meter.add(micro_batch) if args.data.enable_multisource: @@ -488,20 +448,13 @@ def main(): else: noisy_logits = logits - confidence_loss = torch.tensor(0.0, device=noisy_logits.device) - if args.train.confidence_beta > 0: - confidence_loss = compute_confidence_loss( - logits=noisy_logits, - labels=labels, - ) - if args.train.same_token_labels: unscaled_loss = torch.nn.functional.cross_entropy( noisy_logits.view(-1, noisy_logits.shape[-1]), labels.view(-1), reduction="none", - ).view(noisy_logits.shape[0], -1) - consistency_loss = unscaled_loss.sum() / (labels != -100).sum() + ) + loss = unscaled_loss.sum() / (labels != -100).sum() / len(micro_batches) else: shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() shifted_labels = labels[:, 1:].contiguous() @@ -510,16 +463,12 @@ def main(): shifted_labels.view(-1), reduction="none", ).view(shifted_noisy_logits.shape[0], -1) - consistency_loss = unscaled_loss.sum() / (shifted_labels != -100).sum() - - combined_loss = consistency_loss + confidence_loss * args.train.confidence_beta - loss = combined_loss / num_accumulation_steps + loss = unscaled_loss.sum() / (shifted_labels != -100).sum() / len(micro_batches) + with model_bwd_context: loss.backward() total_loss += loss.item() - total_consistency_loss += consistency_loss.item() / num_accumulation_steps - total_confidence_loss += confidence_loss.item() / num_accumulation_steps del micro_batch # Prefer model-provided clip_grad_norm_ (now both FSDP1 and FSDP2 registers custom grad norm clipping) @@ -545,13 +494,13 @@ def main(): lr = max(lr_scheduler.get_last_lr()) train_metrics = environ_meter.step(delta_time, global_step=global_step) - data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, cons: {total_consistency_loss:.2f}, conf: {total_confidence_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") + data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") data_loader_tqdm.update() if args.train.global_rank == 0: if args.train.use_wandb: train_metrics.update( - {"training/loss": total_loss, "training/cons_loss": total_consistency_loss, "training/conf_loss": total_confidence_loss, "training/grad_norm": grad_norm, "training/lr": lr} + {"training/loss": total_loss, "training/grad_norm": grad_norm, "training/lr": lr} ) wandb.log(train_metrics, step=global_step) diff --git a/tasks/train_llada2_bd_with_dparallel.py b/tasks/train_llada2_bd_with_dparallel.py new file mode 100644 index 0000000..2084cc5 --- /dev/null +++ b/tasks/train_llada2_bd_with_dparallel.py @@ -0,0 +1,623 @@ +import json +import os 
+import time +from dataclasses import asdict, dataclass, field +from functools import partial +from typing import Any, Dict, List, Literal, Tuple, Optional + +import torch +import torch.nn.functional as F +import torch.distributed as dist +import wandb +from tqdm import trange + +from veomni.checkpoint import build_checkpointer, ckpt_to_state_dict +from veomni.data import ( + build_dataloader, + build_iterative_dataset, + build_mapping_dataset, +) +from veomni.distributed.offloading import build_activation_offloading_context +from veomni.distributed.parallel_state import get_parallel_state, init_parallel_state +from veomni.distributed.torch_parallelize import build_parallelize_model +from veomni.models import build_foundation_model, build_tokenizer, save_model_assets, save_model_weights +from veomni.optim import build_lr_scheduler, build_optimizer +from veomni.utils import helper +from veomni.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args, save_args +from veomni.utils.device import ( + get_device_type, + get_nccl_backend, + get_torch_device, + synchronize, +) +from veomni.utils.dist_utils import all_reduce +from veomni.models.registry import ModelRegistry +ModelRegistry.register_modeling_path("models.llada2_moe") +from dataset.data_transform import process_mdm_tokenized_example, process_mdm_sft_example +from dataset import build_local_dataset + + +logger = helper.create_logger(__name__) + +@dataclass +class LLaDA2ModelArguments(ModelArguments): + attn_implementation: Optional[Literal["eager", "sdpa", "flex_attention"]] = field( + default="sdpa", + metadata={"help": "Attention implementation to use."}, + ) + + +@dataclass +class LLaDA2DataArguments(DataArguments): + data_type: Literal["conversation", "tokenid"] = field( + default="conversation", + metadata={"help": "Type of the training data."}, + ) + datasets_type: Literal["mapping", "local"] = field( + default="mapping", + metadata={"help": "Type of the datasets."}, + ) + text_keys: str = field( + default="messages", + metadata={"help": "Key to get text from the training data."}, + ) + noise_range_low: float = field( + default=0.3, + metadata={"help": "Noise level for random flip input_ids to mask_ids"} + ) + noise_range_high: float = field( + default=0.8, + metadata={"help": "Noise level for random flip input_ids to mask_ids"} + ) + + def __post_init__(self): + super().__post_init__() + if self.noise_range_low > self.noise_range_high: + raise ValueError( + f"noise_range_low ({self.noise_range_low}) " + f"cannot be greater than noise_range_high ({self.noise_range_high})." + ) + + if not (0.0 <= self.noise_range_low <= 1.0): + raise ValueError( + f"noise_range_low must be between 0.0 and 1.0, but got {self.noise_range_low}." + ) + + if not (0.0 <= self.noise_range_high <= 1.0): + raise ValueError( + f"noise_range_high must be between 0.0 and 1.0, but got {self.noise_range_high}." + ) + + +@dataclass +class LLaDA2TrainingArguments(TrainingArguments): + beta1: float = field( + default=0.9, + metadata={"help": "AdamW optimizer beta1."}, + ) + beta2: float = field( + default=0.999, + metadata={"help": "AdamW optimizer beta2"}, + ) + confidence_beta: float = field( + default=0.0, + metadata={"help": "Weight for the confidence loss entropy of correct predictions. Set to 0 to disable."}, + ) + block_diffusion_mode: bool = field( + default=False, + metadata={"help": "If train MDM in block_diffusion mode. 
True: use block_diffusion, False: full_attention"} + ) + block_size: int = field( + default=32, + metadata={"help": "The block size for block diffusion block size"} + ) + same_token_labels: bool = field( + default=False, + metadata={"help": "If use same token location labels. True: no shift, False: use next-token prediction shift."} + ) + + +@dataclass +class Arguments: + model: "LLaDA2ModelArguments" = field(default_factory=LLaDA2ModelArguments) + data: "LLaDA2DataArguments" = field(default_factory=LLaDA2DataArguments) + train: "LLaDA2TrainingArguments" = field(default_factory=LLaDA2TrainingArguments) + + +def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): + """ + Constructs the specialized block diffusion attention mask for training + composed of three masks: + - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks + - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context + - **Block Causal Mask (M_BC)**: Attention to update x0 + + Args: + b, h: Batch and head indices (ignored for mask logic). + q_idx, kv_idx: Query and Key indices. + seq_len: Total sequence length. + block_size: Defines the block structure. + + Returns: + A boolean attention mask. + """ + + # Indicate whether token belongs to xt or x0 + x0_flag_q = (q_idx >= n) + x0_flag_kv = (kv_idx >= n) + + # Compute block indices + block_q = torch.where(x0_flag_q == 1, + (q_idx - n) // block_size, + q_idx // block_size) + block_kv = torch.where(x0_flag_kv == 1, + (kv_idx - n) // block_size, + kv_idx // block_size) + + # **1. Block Diagonal Mask (M_BD) ** + block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv) + + # **2. Offset Block-Causal Mask (M_OBC) ** + offset_block_causal = ( + (block_q > block_kv) + & (x0_flag_kv == 1) + & (x0_flag_q == 0) + ) + + # **3. Block-Causal Mask (M_BC) ** + block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1) + + # **4. Combine Masks ** + return block_diagonal | offset_block_causal | block_causal + +def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + """ + Calculate the average entropy of the output distribution at positions where the model predicts correctly. + Args: + logits (torch.Tensor): The raw output logits from the model, with shape (batch_size, seq_len, vocab_size). + labels (torch.Tensor): The ground truth labels, with shape (batch_size, seq_len). -100 indicates positions to be ignored. + Returns: + torch.Tensor: A scalar tensor representing the confidence loss. Returns 0 if there are no correct predictions. 
+ """ + labels = labels.to(logits.device) + + valid_mask = (labels != -100) + if not valid_mask.any(): + return torch.tensor(0.0, device=logits.device) + + predicted_tokens = torch.argmax(logits, dim=-1) + + correct_mask = (predicted_tokens == labels) & valid_mask + + if correct_mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + log_probs = F.log_softmax(logits, dim=-1) + probs = torch.exp(log_probs) + entropy_per_token = -torch.sum(probs * log_probs, dim=-1) + + entropy_at_correct_positions = entropy_per_token[correct_mask] + + confidence_loss = entropy_at_correct_positions.mean() + + return confidence_loss + +def main(): + dist.init_process_group(backend=get_nccl_backend()) + args = parse_args(Arguments) + logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}") + logger.info_rank0(json.dumps(asdict(args), indent=2)) + get_torch_device().set_device(f"{get_device_type()}:{args.train.local_rank}") + helper.set_seed(args.train.seed, args.train.enable_full_determinism) + if args.train.local_rank == 0: + helper.enable_third_party_logging() + + if args.train.global_rank == 0: + save_args(args, args.train.output_dir) + + Checkpointer = build_checkpointer(dist_backend=args.train.data_parallel_mode, ckpt_manager=args.train.ckpt_manager) + + init_parallel_state( + dp_size=args.train.data_parallel_size, + dp_replicate_size=args.train.data_parallel_replicate_size, + dp_shard_size=args.train.data_parallel_shard_size, + tp_size=args.train.tensor_parallel_size, + ep_size=args.train.expert_parallel_size, + pp_size=args.train.pipeline_parallel_size, + cp_size=args.train.context_parallel_size, + ulysses_size=args.train.ulysses_parallel_size, + dp_mode=args.train.data_parallel_mode, + ) + + logger.info_rank0("Prepare data") + tokenizer = build_tokenizer(args.model.tokenizer_path) + if args.data.data_type == "conversation": + if not tokenizer.chat_template: + raise ValueError(f"No chat template found in the tokenizer.") + + transform = partial( + process_mdm_sft_example, + tokenizer=tokenizer, + max_seq_len=args.data.max_seq_len, + text_keys=args.data.text_keys, + noise_range=(args.data.noise_range_low, args.data.noise_range_high), + mask_token_id=156895, + ) + elif args.data.data_type == "tokenid": + transform = partial( + process_mdm_tokenized_example, + max_seq_len=args.data.max_seq_len, + text_keys=args.data.text_keys, + noise_range=(args.data.noise_range_low, args.data.noise_range_high), + mask_token_id=156895, + ) + else: + raise NotImplementedError(f"Unsupported data type: {args.data.data_type}.") + + if args.data.dataloader_type == "native": + if args.data.datasets_type == "iterable": + logger.info_rank0("Start building iterative dataset") + train_dataset = build_iterative_dataset(args.data.train_path, transform=transform, seed=args.train.seed) + elif args.data.datasets_type == "mapping": + logger.info_rank0("Start building mapping dataset") + train_dataset = build_mapping_dataset(args.data.train_path, transform=transform) + elif args.data.datasets_type == "local": + logger.info_rank0("Start building local dataset") + train_dataset = build_local_dataset(args.data.train_path, transform=transform) + + dataset_length = None if not hasattr(train_dataset, "__len__") else len(train_dataset) + if args.data.datasets_type == "mapping" or args.data.datasets_type == "local": + dataset_length = dataset_length / args.train.data_parallel_size + args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size, dataset_length) + + train_dataloader = 
build_dataloader( + dataset=train_dataset, + micro_batch_size=args.train.micro_batch_size, + global_batch_size=args.train.global_batch_size, + dataloader_batch_size=args.train.dataloader_batch_size, + seed=args.train.seed, + max_seq_len=args.data.max_seq_len, + train_steps=args.train.train_steps, + rmpad=args.train.rmpad, + rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, + bsz_warmup_ratio=args.train.bsz_warmup_ratio, + bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken, + dyn_bsz_margin=args.train.dyn_bsz_margin, + dyn_bsz_buffer_size=args.train.dyn_bsz_buffer_size, + num_workers=args.data.num_workers, + drop_last=args.data.drop_last, + pin_memory=args.data.pin_memory, + prefetch_factor=args.data.prefetch_factor, + ) + else: + raise NotImplementedError(f"Unsupported dataloader type: {args.data.dataloader_type}.") + + logger.info_rank0("Prepare model") + model = build_foundation_model( + config_path=args.model.config_path, + weights_path=args.model.model_path, + torch_dtype="float32" if args.train.enable_mixed_precision else "bfloat16", + attn_implementation=args.model.attn_implementation, + moe_implementation=args.model.moe_implementation, + init_device=args.train.init_device, + force_use_huggingface=args.model.force_use_huggingface, + ) + model_config = model.config + helper.print_device_mem_info("VRAM usage after building model") + + get_optimizer_pre_hook = getattr(model, "get_optimizer_pre_hook", None) + model = build_parallelize_model( + model, + init_device=args.train.init_device, + weights_path=args.model.model_path, + enable_full_shard=args.train.enable_full_shard, + enable_mixed_precision=args.train.enable_mixed_precision, + enable_gradient_checkpointing=args.train.enable_gradient_checkpointing, + enable_fsdp_offload=args.train.enable_fsdp_offload, + basic_modules=model._no_split_modules + args.model.basic_modules, + enable_reentrant=args.train.enable_reentrant, + enable_forward_prefetch=args.train.enable_forward_prefetch, + broadcast_model_weights_from_rank0=args.train.broadcast_model_weights_from_rank0 + ) + + optimizer = build_optimizer( + model, + lr=args.train.lr, + betas=(args.train.beta1, args.train.beta2), + weight_decay=args.train.weight_decay, + fused=True, + optimizer_type=args.train.optimizer, + ) + + if get_optimizer_pre_hook is not None: + optimizer_pre_hook = get_optimizer_pre_hook(model, model_config, args.train.data_parallel_mode) + optimizer.register_step_pre_hook(optimizer_pre_hook) + + lr_scheduler = build_lr_scheduler( + optimizer, + train_steps=args.train.train_steps * args.train.num_train_epochs, + lr=args.train.lr, + lr_min=args.train.lr_min, + lr_decay_style=args.train.lr_decay_style, + lr_decay_ratio=args.train.lr_decay_ratio, + lr_warmup_ratio=args.train.lr_warmup_ratio, + lr_start=args.train.lr_start, + ) + + if args.train.global_rank == 0: + if args.train.use_wandb: + wandb.init( + project=args.train.wandb_project, + name=args.train.wandb_name, + config={**vars(args.model), **vars(args.data), **vars(args.train)}, # flatten dict + ) + + # save model_assets before training + model_assets = [model_config, tokenizer] + save_model_assets(args.train.model_assets_dir, model_assets) + + if args.train.profile_this_rank: + profiler = helper.create_profiler( + start_step=args.train.profile_start_step, + end_step=args.train.profile_end_step, + trace_dir=args.train.profile_trace_dir, + record_shapes=args.train.profile_record_shapes, + profile_memory=args.train.profile_profile_memory, + with_stack=args.train.profile_with_stack, + 
global_rank=args.train.global_rank, + ) + profiler.start() + + start_epoch, start_step, global_step = 0, 0, 0 + save_checkpoint_path = None + environ_meter = helper.EnvironMeter( + config=model_config, + global_batch_size=args.train.global_batch_size, + rmpad=args.train.rmpad, + rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, + empty_cache_steps=args.train.empty_cache_steps, + enable_multisource=args.data.enable_multisource, + dataloader=train_dataloader, + data_path=args.data.train_path, + ) + + if args.train.load_checkpoint_path: + state = {"model": model, "optimizer": optimizer, "extra_state": {}} # cannot be None + Checkpointer.load(args.train.load_checkpoint_path, state) + global_step = state["extra_state"]["global_step"] + start_epoch = global_step // args.train.train_steps + start_step = global_step % args.train.train_steps + lr_scheduler.load_state_dict(state["extra_state"]["lr_scheduler"]) + train_dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) + environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) + torch.set_rng_state(state["extra_state"]["torch_rng_state"]) + if start_step == 0: # resume at the end of epoch + iter(train_dataloader) # clear resume state and prefetch data + + dist.barrier() + logger.info_rank0(f"Load distributed checkpoint from {args.train.load_checkpoint_path} successfully!") + + # Build block diffusion attention mask + if args.train.block_diffusion_mode: + bd_attn_full_len = args.data.max_seq_len * 2 + block_size = args.train.block_size + # NOTE: Boolean dtype block diffusion attention mask + block_diffusion_attn_mask_flag = block_diffusion_mask( + b=None, h=None, + q_idx=torch.arange(bd_attn_full_len)[:, None], + kv_idx=torch.arange(bd_attn_full_len)[None, :], + block_size=block_size, + n=args.data.max_seq_len + ).unsqueeze(0).unsqueeze(0) + + block_diffusion_attn_mask_prototype = torch.zeros_like( + block_diffusion_attn_mask_flag, + dtype=torch.float32 if args.train.enable_mixed_precision else torch.bfloat16 + ) + block_diffusion_attn_mask_prototype.masked_fill_(block_diffusion_attn_mask_flag.logical_not(), float("-inf")) + + helper.empty_cache() + model_fwd_context, model_bwd_context = build_activation_offloading_context( + args.train.enable_activation_offload, args.train.enable_gradient_checkpointing, args.train.activation_gpu_limit + ) + model.train() + logger.info( + f"rank{args.train.local_rank} Start training, train_steps: {args.train.train_steps}, epochs: {args.train.num_train_epochs}" + ) + for epoch in range(start_epoch, args.train.num_train_epochs): + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + data_loader_tqdm = trange( + args.train.train_steps, + desc=f"Epoch {epoch + 1}/{args.train.num_train_epochs}", + total=args.train.train_steps, + initial=start_step, + disable=args.train.local_rank != 0, + ) + data_iterator = iter(train_dataloader) + for _ in range(start_step, args.train.train_steps): + global_step += 1 + + try: + micro_batches: List[Dict[str, Any]] = next(data_iterator) + except StopIteration: + logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.drop_last}") + break + + if global_step == 1: + helper.print_example(example=micro_batches[0], rank=args.train.local_rank) + + total_loss = 0 + synchronize() + start_time = time.time() + num_accumulation_steps = len(micro_batches) + total_consistency_loss = 0 + total_confidence_loss = 0 + + for micro_batch in micro_batches: + environ_meter.add(micro_batch) + if args.data.enable_multisource: + 
micro_batch.pop("ds_idx", None) + micro_batch.pop("source_name", None) + + micro_batch = { + k: v.to(get_device_type(), non_blocking=True) if isinstance(v, torch.Tensor) else v + for k, v in micro_batch.items() + } + if args.train.block_diffusion_mode: + noisy_input_ids = micro_batch["noisy_input_ids"] + clean_input_ids = micro_batch["input_ids"] + batch_size = noisy_input_ids.shape[0] + full_input_ids = torch.cat([noisy_input_ids, clean_input_ids], dim=1) + noisy_position_ids = torch.arange(noisy_input_ids.shape[1], device=get_device_type(), dtype=torch.long) + clean_position_ids = torch.arange(clean_input_ids.shape[1], device=get_device_type(), dtype=torch.long) + position_ids = torch.cat([noisy_position_ids, clean_position_ids], dim=0).unsqueeze(0).expand(batch_size, -1).clone() + micro_batch["input_ids"] = full_input_ids + micro_batch["position_ids"] = position_ids + micro_batch["attention_mask"] = block_diffusion_attn_mask_prototype.expand(batch_size, -1, -1, -1) + else: + micro_batch["attention_mask"] = None + + labels = micro_batch.pop("labels", None) + + with model_fwd_context: + logits: "torch.Tensor" = model(**micro_batch, use_cache=False, output_router_logits=False).logits + if args.train.block_diffusion_mode: + noisy_logits = logits[:, :noisy_input_ids.shape[1]].contiguous() + else: + noisy_logits = logits + + confidence_loss = torch.tensor(0.0, device=noisy_logits.device) + if args.train.confidence_beta > 0: + confidence_loss = compute_confidence_loss( + logits=noisy_logits, + labels=labels, + ) + + if args.train.same_token_labels: + unscaled_loss = torch.nn.functional.cross_entropy( + noisy_logits.view(-1, noisy_logits.shape[-1]), + labels.view(-1), + reduction="none", + ).view(noisy_logits.shape[0], -1) + consistency_loss = unscaled_loss.sum() / (labels != -100).sum() + else: + shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() + shifted_labels = labels[:, 1:].contiguous() + unscaled_loss = torch.nn.functional.cross_entropy( + shifted_noisy_logits.view(-1, shifted_noisy_logits.shape[-1]), + shifted_labels.view(-1), + reduction="none", + ).view(shifted_noisy_logits.shape[0], -1) + consistency_loss = unscaled_loss.sum() / (shifted_labels != -100).sum() + + combined_loss = consistency_loss + confidence_loss * args.train.confidence_beta + loss = combined_loss / num_accumulation_steps + with model_bwd_context: + loss.backward() + + total_loss += loss.item() + total_consistency_loss += consistency_loss.item() / num_accumulation_steps + total_confidence_loss += confidence_loss.item() / num_accumulation_steps + del micro_batch + + # Prefer model-provided clip_grad_norm_ (now both FSDP1 and FSDP2 registers custom grad norm clipping) + if hasattr(model, "clip_grad_norm_"): + _gn = model.clip_grad_norm_(args.train.max_grad_norm) + grad_norm = _gn.item() if hasattr(_gn, "item") else float(_gn) + else: + logger.info_rank0( + "Can NOT find regitsered clip_grad_norm_ method in the model, using PyTorch default implementation.." 
+ ) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.train.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + if hasattr(grad_norm, "full_tensor"): + grad_norm = grad_norm.full_tensor().item() + + # collect mean loss across data parallel group + total_loss, grad_norm = all_reduce((total_loss, grad_norm), group=get_parallel_state().fsdp_group) + synchronize() + delta_time = time.time() - start_time + lr = max(lr_scheduler.get_last_lr()) + train_metrics = environ_meter.step(delta_time, global_step=global_step) + + data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, cons: {total_consistency_loss:.2f}, conf: {total_confidence_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") + data_loader_tqdm.update() + + if args.train.global_rank == 0: + if args.train.use_wandb: + train_metrics.update( + {"training/loss": total_loss, "training/cons_loss": total_consistency_loss, "training/conf_loss": total_confidence_loss, "training/grad_norm": grad_norm, "training/lr": lr} + ) + wandb.log(train_metrics, step=global_step) + + if args.train.profile_this_rank and global_step <= args.train.profile_end_step: + profiler.step() + if global_step == args.train.profile_end_step: + profiler.stop() + + if args.train.save_steps and global_step % args.train.save_steps == 0: + helper.empty_cache() + save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") + state = { + "model": model, + "optimizer": optimizer, + "extra_state": { + "global_step": global_step, + "lr_scheduler": lr_scheduler.state_dict(), + "train_dataloader": train_dataloader.state_dict(), + "environ_meter": environ_meter.state_dict(), + "torch_rng_state": torch.get_rng_state(), + }, + } + Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) + + dist.barrier() + logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") + + data_loader_tqdm.close() + start_step = 0 + helper.print_device_mem_info(f"VRAM usage after epoch {epoch + 1}") + if args.train.save_epochs and (epoch + 1) % args.train.save_epochs == 0: + helper.empty_cache() + save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") + state = { + "model": model, + "optimizer": optimizer, + "extra_state": { + "global_step": global_step, + "lr_scheduler": lr_scheduler.state_dict(), + "train_dataloader": train_dataloader.state_dict(), + "environ_meter": environ_meter.state_dict(), + "torch_rng_state": torch.get_rng_state(), + }, + } + Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) + dist.barrier() + logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") + + synchronize() + # release memory + del optimizer, lr_scheduler + helper.empty_cache() + # save model in huggingface's format + if args.train.global_rank == 0 and args.train.save_hf_weights and save_checkpoint_path is not None: + hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt") + model_state_dict = ckpt_to_state_dict( + save_checkpoint_path=save_checkpoint_path, + output_dir=args.train.output_dir, + ckpt_manager=args.train.ckpt_manager, + ) + save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets) + logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!") + + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() \ No newline at end of file From 
f301933c714f069bfeee6584b2a06a875ee5d246 Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 12:16:10 +0800 Subject: [PATCH 04/10] fix --- docs/source/algo/trainable_parallel_decoding.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst index 04d136f..5c3db73 100644 --- a/docs/source/algo/trainable_parallel_decoding.rst +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -87,6 +87,7 @@ Configuration To enable DPARALLEL, use the following training configuration: .. code:: bash + sh train.sh tasks/train_llada2_bd_with_dparallel.py configs/sft/llada2_mini_bd_sft.yaml --train.confidence_beta {confidence_beta} Where: From 77c26ff86b0cee55f1a6c13fe79a84652709f61a Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 12:20:26 +0800 Subject: [PATCH 05/10] fix --- docs/source/algo/trainable_parallel_decoding.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst index 5c3db73..a98e57e 100644 --- a/docs/source/algo/trainable_parallel_decoding.rst +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -3,12 +3,12 @@ Trainable Parallel Decoding Last updated: 2025-11-10 -Trainable Parallel Decoding is a novel approach to accelerate Discrete Diffusion Language Models (DDLMs) by learning to decode multiple tokens simultaneously during training, thereby reducing inference latency while maintaining generation quality. +Trainable Parallel Decoding is a novel approach to accelerate Diffusion Large Language Models (DLLMs) by learning to decode multiple tokens simultaneously during training, thereby reducing inference latency while maintaining generation quality. Overview -------- -Traditional DDLMs suffer from high inference latency due to their iterative, multi-step sampling process. Trainable Parallel Decoding addresses this limitation by introducing a second-stage fine-tuning paradigm that teaches the model to predict multiple future tokens in a single forward pass. This approach transforms the sequential generation process into a more parallelizable one, significantly reducing the number of required sampling steps. +Traditional DLLMs suffer from high inference latency due to their iterative, multi-step sampling process. Trainable Parallel Decoding addresses this limitation by introducing a second-stage fine-tuning paradigm that teaches the model to predict multiple future tokens in a single forward pass. This approach transforms the sequential generation process into a more parallelizable one, significantly reducing the number of required sampling steps. The framework currently supports two complementary techniques: @@ -70,7 +70,7 @@ The training data format should include: DPARALLEL: Learnable Parallel Decoding -------------------------------------- -DPARALLEL is inspired by the DPARALLEL: Learnable Parallel Decoding for DDLMs approach, which introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities. +DPARALLEL is inspired by the DPARALLEL: Learnable Parallel Decoding for DLLMs approach, which introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities. 
Methodology ~~~~~~~~~~~ From b7f5e9aee02f82d6a5d39a022d78b60fd6e8b82b Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 12:20:56 +0800 Subject: [PATCH 06/10] fix --- docs/source/algo/trainable_parallel_decoding.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst index a98e57e..2dedced 100644 --- a/docs/source/algo/trainable_parallel_decoding.rst +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -25,7 +25,7 @@ High-Quality Trajectory Distillation The first stage involves creating a dataset of "golden" trajectories through the following process: -1. **Trajectory Generation**: Use a pre-trained DDLM to sample generation paths on a domain-specific dataset (e.g., 200,000 math problems) +1. **Trajectory Generation**: Use a pre-trained DLLM to sample generation paths on a domain-specific dataset (e.g., 200,000 math problems) 2. **Quality Filtering**: Apply an external verifier to filter trajectories that produce correct outputs 3. **Dataset Construction**: Retain only high-quality trajectories that pass verification From b77508f1f3746b6c969f1e12aefe2f3702d3a729 Mon Sep 17 00:00:00 2001 From: Grason-Lu Date: Tue, 11 Nov 2025 12:28:46 +0800 Subject: [PATCH 07/10] fix --- docs/source/algo/trainable_parallel_decoding.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst index 2dedced..079c43e 100644 --- a/docs/source/algo/trainable_parallel_decoding.rst +++ b/docs/source/algo/trainable_parallel_decoding.rst @@ -70,7 +70,7 @@ The training data format should include: DPARALLEL: Learnable Parallel Decoding -------------------------------------- -DPARALLEL is inspired by the DPARALLEL: Learnable Parallel Decoding for DLLMs approach, which introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities. +DPARALLEL: Learnable Parallel Decoding for DLLMs [arXiv:2509.26488] is a novel approach that introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities. 
 Methodology
 ~~~~~~~~~~~
@@ -91,6 +91,7 @@ To enable DPARALLEL, use the following training configuration:
    sh train.sh tasks/train_llada2_bd_with_dparallel.py configs/sft/llada2_mini_bd_sft.yaml --train.confidence_beta {confidence_beta}
 
 Where:
+
 - ``confidence_beta`` controls the strength of the entropy regularization (recommended value: 2.0)
 - Higher values encourage more aggressive parallel decoding
 - The parameter balances between generation quality and speed-up

From c31ba614fdb7b8f92f7f7b834e1ec053542d848f Mon Sep 17 00:00:00 2001
From: Grason-Lu
Date: Tue, 11 Nov 2025 13:08:16 +0800
Subject: [PATCH 08/10] fix

---
 docs/source/algo/trainable_parallel_decoding.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst
index 079c43e..17232e8 100644
--- a/docs/source/algo/trainable_parallel_decoding.rst
+++ b/docs/source/algo/trainable_parallel_decoding.rst
@@ -18,7 +18,7 @@ The framework currently supports two complementary techniques:
 Path Distillation (Trajectory Compression)
 ------------------------------------------
 
-Path Distillation is inspired by the observation in Seed Diffusion [arXiv:2508.02193] that training on high-quality generation paths can significantly improve model efficiency. This method consists of two main stages:
+Path Distillation is inspired by the observation in `Song et al., 2025 <https://arxiv.org/abs/2508.02193>`_ that training on high-quality generation paths can significantly improve model efficiency. This method consists of two main stages:
 
 High-Quality Trajectory Distillation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -70,7 +70,7 @@ The training data format should include:
 DPARALLEL: Learnable Parallel Decoding
 --------------------------------------
 
-DPARALLEL: Learnable Parallel Decoding for DLLMs [arXiv:2509.26488] is a novel approach that introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities.
+`Chen et al., 2025 <https://arxiv.org/abs/2509.26488>`_ is a novel approach that introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities.
 
 Methodology
 ~~~~~~~~~~~

From 270533174a4d0fe7f86a84def9e07cce16ef9012 Mon Sep 17 00:00:00 2001
From: Grason-Lu
Date: Tue, 11 Nov 2025 13:13:50 +0800
Subject: [PATCH 09/10] fix

---
 docs/source/algo/trainable_parallel_decoding.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/algo/trainable_parallel_decoding.rst b/docs/source/algo/trainable_parallel_decoding.rst
index 17232e8..bcba9d7 100644
--- a/docs/source/algo/trainable_parallel_decoding.rst
+++ b/docs/source/algo/trainable_parallel_decoding.rst
@@ -18,7 +18,7 @@ The framework currently supports two complementary techniques:
 Path Distillation (Trajectory Compression)
 ------------------------------------------
 
-Path Distillation is inspired by the observation in `Song et al., 2025 <https://arxiv.org/abs/2508.02193>`_ that training on high-quality generation paths can significantly improve model efficiency. This method consists of two main stages:
+Path Distillation is motivated by the key observation from `Song et al., 2025 <https://arxiv.org/abs/2508.02193>`_ that training on high-quality generation paths can significantly improve model efficiency. The method consists of two main stages:
 
 High-Quality Trajectory Distillation
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -70,7 +70,7 @@ The training data format should include:
 DPARALLEL: Learnable Parallel Decoding
 --------------------------------------
 
-`Chen et al., 2025 <https://arxiv.org/abs/2509.26488>`_ is a novel approach that introduces an entropy-based regularization term to the training loss to encourage the model to learn parallel decoding capabilities.
+`Chen et al., 2025 <https://arxiv.org/abs/2509.26488>`_ introduce dParallel, a novel approach that incorporates an entropy-based regularization term into the training loss to encourage parallel decoding capabilities.
 
 Methodology
 ~~~~~~~~~~~

From 7f5e5143c79d979ce4f1c86ff8ca0d78cd5dce1e Mon Sep 17 00:00:00 2001
From: Grason-Lu
Date: Tue, 11 Nov 2025 13:17:09 +0800
Subject: [PATCH 10/10] fix

---
 tasks/train_llada2_bd.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tasks/train_llada2_bd.py b/tasks/train_llada2_bd.py
index b21f0bc..7a46f86 100644
--- a/tasks/train_llada2_bd.py
+++ b/tasks/train_llada2_bd.py
@@ -569,4 +569,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
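
The documentation page added in PATCH 01 specifies the compressed-transition data format (``input_ids``, ``noisy_input_ids``, ``labels``) only in prose. The sketch below shows one way such a training example could be assembled from a verified trajectory; the helper name, the list-of-tensors trajectory layout, and the use of the final state as the clean ``input_ids`` (following the clean/noisy convention of the block-diffusion trainer) are illustrative assumptions, not code from the patches above. The mask id 156895 and the -100 ignore index mirror the training scripts in this series.

.. code:: python

   import random

   import torch

   # These two constants mirror the training scripts in this patch series.
   MASK_TOKEN_ID = 156895
   IGNORE_INDEX = -100


   def build_compressed_transition_example(trajectory):
       """Turn one verified trajectory tau = (s_N, ..., s_0) into a training example.

       ``trajectory`` is assumed to be a list of equal-length 1-D LongTensors ordered
       from the fully masked state s_N down to the final output s_0, with every
       not-yet-revealed position holding MASK_TOKEN_ID.
       """
       # Sample a compressed transition s_i -> s_j with timestamps N >= i > j >= 0.
       # In list order (s_N first), the noisier state s_i appears before s_j.
       a, b = sorted(random.sample(range(len(trajectory)), 2))
       s_i, s_j = trajectory[a], trajectory[b]

       # Supervise exactly the tokens that are [MASK] in s_i but revealed in s_j.
       newly_revealed = (s_i == MASK_TOKEN_ID) & (s_j != MASK_TOKEN_ID)
       labels = torch.full_like(s_j, IGNORE_INDEX)
       labels[newly_revealed] = s_j[newly_revealed]

       return {
           "input_ids": trajectory[-1].clone(),   # clean final sequence s_0 (assumed convention)
           "noisy_input_ids": s_i.clone(),        # partially masked starting state fed to the model
           "labels": labels,                      # loss restricted to the Delta_{i -> j} positions
       }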