From 849394199e39e6d2a3ebce76bcbfbd46824325a9 Mon Sep 17 00:00:00 2001
From: shuibai
Date: Sun, 28 Dec 2025 01:34:40 +0000
Subject: [PATCH 1/2] fix dimension bug

---
 python/dinfer/decoding/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/dinfer/decoding/utils.py b/python/dinfer/decoding/utils.py
index fb178fb..4bde8a3 100644
--- a/python/dinfer/decoding/utils.py
+++ b/python/dinfer/decoding/utils.py
@@ -247,7 +247,9 @@ def __next__(self):
         current_block_end = min(current_block_start + self.block_length, self.x.total_length)
         assert current_block_end <= self.x.total_length
         self.iter += 1
-        return BlockLoc(current_block_start, current_block_end), self.x[current_block_start:current_block_end]
+        # self.x is batched ([B, L]); slice the sequence dimension explicitly
+        # so the batch dimension is preserved.
+        return BlockLoc(current_block_start, current_block_end), self.x[:, current_block_start:current_block_end]
 
 
 class BlockIteratorFactory:

From befe61f9fc813559f8209881937bd8059771909d Mon Sep 17 00:00:00 2001
From: shuibai
Date: Sun, 28 Dec 2025 01:42:59 +0000
Subject: [PATCH 2/2] add support for threshold based remask decoding

---
 .gitignore                                  |   5 +
 evaluations/eval_dinfer_sglang.py           |   9 +-
 evaluations/eval_llada_mini_remask.sh       |  41 ++++
 python/dinfer/decoding/parallel_strategy.py | 201 ++++++++++++++++++--
 4 files changed, 233 insertions(+), 23 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 evaluations/eval_llada_mini_remask.sh

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..653883f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+evaluations/outputs/
+build/
+python/dinfer.egg-info
+
+evaluations/tasks/mbpp_sanitized/__pycache__/*
\ No newline at end of file

diff --git a/evaluations/eval_dinfer_sglang.py b/evaluations/eval_dinfer_sglang.py
index c3e182e..8cdb950 100644
--- a/evaluations/eval_dinfer_sglang.py
+++ b/evaluations/eval_dinfer_sglang.py
@@ -97,9 +97,9 @@ def run_benchmark(world_size, rank, gpu_id, tokenizer, args):
 
     if args.parallel_decoding == 'threshold':
         if args.use_credit:
-            decoder = CreditThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id)
+            decoder = CreditThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id, enable_remask=args.enable_remask)
         else:
-            decoder = ThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id)
+            decoder = ThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id, enable_remask=args.enable_remask)
     else:
         decoder = HierarchyDecoder(temperature=0, threshold=args.threshold, low_threshold=args.low_threshold, mask_id=mask_id, eos_id=eos_id)
 
@@ -272,6 +272,7 @@ class EvalConfig:
     batch_size: int = 1
     save_samples: bool = False
     speed_path: str = ''
+    enable_remask: bool = False
 
 def set_seed(seed):
     torch.manual_seed(seed)
@@ -316,6 +317,7 @@ def __init__(
         use_shift = False,
         model_type = 'llada2',
         save_samples = False,
+        enable_remask = False,
         **kwargs
     ):
 
@@ -354,6 +356,7 @@ def __init__(
         self.use_shift = use_shift
         self.model_type = model_type
         self.save_samples = save_samples
+        self.enable_remask = enable_remask
 
         if self.model_type == 'llada2':
             self.mask_id = 156895
@@ -575,7 +578,7 @@ def cal_bucket_len(gen_len, all_input_ids):
         procs = []
         answers = []
         gpus = [int(gpu) for gpu in self.gpus.split(';')]
-        args = {"gpu": gpus, "batch_size": self.batch_size, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile,"use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "batch_size": self.batch_size, "speed_path": self.speed_path}
+        args = {"gpu": gpus, "batch_size": self.batch_size, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile, "use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "speed_path": self.speed_path, "enable_remask": self.enable_remask}
         args = EvalConfig(**args)
         args.tp_size = len(gpus)
         args.master_port = self.master_port

diff --git a/evaluations/eval_llada_mini_remask.sh b/evaluations/eval_llada_mini_remask.sh
new file mode 100644
index 0000000..53eb8f9
--- /dev/null
+++ b/evaluations/eval_llada_mini_remask.sh
@@ -0,0 +1,41 @@
+# Set the environment variables first before running the command.
+export HF_ALLOW_CODE_EVAL=1
+export HF_DATASETS_TRUST_REMOTE_CODE=1
+export TRANSFORMERS_TRUST_REMOTE_CODE=1
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+parallel_decoding='threshold' # 'threshold' or 'hierarchy'
+length=2048 # generation length
+block_length=32 # block length
+model_path='/root/inclusion_data/separate_expert_model/LLaDA2.0-mini-preview' # your model path
+threshold=0.98 # confidence threshold for parallel decoding
+low_threshold=0.62 # low threshold for parallel decoding when using the hierarchy mechanism
+cache='prefix' # 'prefix' for the prefix cache; '' to disable caching
+warmup_times=0 # warmup times for the cache
+prefix_look=0
+after_look=0
+cont_weight=0 # cont weight
+use_credit=False # enable the credit mechanism for threshold decoding
+use_compile=True # use torch.compile
+tp_size=4 # tensor parallel size
+gpus='0;1;2;3' # gpus for tensor parallel inference
+parallel='tp' # 'tp' for tensor parallel or 'dp' for data parallel
+output_dir='./outputs' # your custom output path
+model_type='llada2' # llada2 (for llada2-mini)
+use_bd=True # use block diffusion
+master_port="23458"
+save_samples=False # save samples
+enable_remask=True # enable remasking for the threshold decoder
+# for llada 1.5 use tasks gsm8k_llada1.5 mbpp_sanitized_llada1.5
+# for llada2_mini use tasks gsm8k_llada_mini mbpp_sanitized_llada_mini
+if [ "${parallel}" = "tp" ]; then
+    for task in mbpp_sanitized_llada_mini; do
+        output_path=${output_dir}/${task}
+        python eval_dinfer_sglang.py --tasks ${task} \
+        --confirm_run_unsafe_code --model dInfer_eval \
+        --model_args model_path=${model_path},gen_length=${length},block_length=${block_length},threshold=${threshold},low_threshold=${low_threshold},show_speed=True,save_dir=${output_path},parallel_decoding=${parallel_decoding},cache=${cache},warmup_times=${warmup_times},use_compile=${use_compile},tp_size=${tp_size},parallel=${parallel},cont_weight=${cont_weight},use_credit=${use_credit},prefix_look=${prefix_look},after_look=${after_look},gpus=${gpus},model_type=${model_type},use_bd=${use_bd},master_port=${master_port},save_samples=${save_samples},enable_remask=${enable_remask} \
+        --output_path ${output_path} --include_path ./tasks --apply_chat_template
+    done
+else
+    echo "parallel must be tp"
+fi

diff --git a/python/dinfer/decoding/parallel_strategy.py b/python/dinfer/decoding/parallel_strategy.py
index dc775bb..66ca387 100644
--- a/python/dinfer/decoding/parallel_strategy.py
+++ b/python/dinfer/decoding/parallel_strategy.py
@@ -385,6 +385,115 @@ def get_transfer_index_threshold(
     return x0, transfer_index
 
 
+@torch.no_grad()
+@torch.compile(dynamic=True)
+def get_transfer_index_threshold_remask(
+    logits,
+    temperature,
+    mask_index,
+    x,
+    mask_id,
+    threshold,
+    use_float64=False,
+    fix_mask=None,
+    **kwargs,
+):
+    """Similar to get_transfer_index_threshold but with remasking support.
+
+    Tokens that are already decoded but have confidence below the given threshold
+    will be remasked (set back to mask_id).
+    Uses the given threshold directly (no adaptive threshold calculation).
+    """
+    # Keep the original mask positions for the remasking decision.
+    orig_mask_index = mask_index
+
+    # 1) Sample token ids from Gumbel-noised logits (sampling decision).
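+    #    (Gumbel-max trick: argmax over logits plus Gumbel noise draws a sample
+    #    from the softmax distribution; with temperature == 0 the add_gumbel_noise
+    #    helper is expected to leave the ranking unchanged, i.e. greedy argmax.)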
+    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+    x0_sample = torch.argmax(logits_with_noise, dim=-1)  # b, l
+
+    # 2) Compute confidence from *clean* logits (no Gumbel noise):
+    #    - masked positions: probability of the sampled token (x0_sample from Gumbel sampling)
+    #    - unmasked positions: probability of the current token in x (used for the remasking decision)
+    if use_float64:
+        p = F.softmax(logits.to(torch.float64), dim=-1)
+    else:
+        p = F.softmax(logits.to(torch.float32), dim=-1)
+
+    # Probability of the *sampled* token at each position:
+    #    - token id comes from Gumbel sampling (x0_sample)
+    #    - probability is evaluated under clean logits (p)
+    # Used for thresholding masked positions (deciding which masks to fill this step).
+    x0_p = torch.squeeze(
+        torch.gather(
+            p, dim=-1, index=torch.unsqueeze(x0_sample.to(torch.long), -1)
+        ),
+        -1,
+    )  # b, l
+
+    # Probability of the *current* token already in the sequence (x) under clean logits (p).
+    # Used for remasking: unmasked positions with low confidence will be set back to mask_id.
+    x_p = torch.squeeze(
+        torch.gather(p, dim=-1, index=torch.unsqueeze(x.to(torch.long), -1)),
+        -1,
+    )  # b, l
+
+    # Step 1: Tentatively decode all masked positions
+    x0 = torch.where(orig_mask_index, x0_sample, x)  # [B, L]
+
+    # Step 2: Build a unified confidence map: masked positions use x0_p, unmasked positions use x_p
+    unified_confidence = torch.where(orig_mask_index, x0_p, x_p)  # [B, L]
+
+    # If fix_mask is provided, set the confidence of fixed positions to +inf to prevent remasking
+    if fix_mask is not None:
+        # fix_mask should have shape [B, L] matching unified_confidence.
+        # Fixed positions get +inf (highest confidence) so they are never selected for remasking.
+        unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, float('inf')), unified_confidence)
+
+    # Step 3: Use the given threshold directly (no adaptive threshold)
+    if isinstance(threshold, torch.Tensor):
+        threshold_tensor = threshold.to(device=unified_confidence.device, dtype=unified_confidence.dtype)
+        if threshold_tensor.dim() == 0:
+            threshold_tensor = threshold_tensor.unsqueeze(-1)
+    else:
+        threshold_tensor = torch.tensor(threshold, device=unified_confidence.device, dtype=unified_confidence.dtype).unsqueeze(-1)
+
+    # Step 4: Unified decision: confidence >= threshold means keep, < threshold means remask
+    keep_mask = unified_confidence >= threshold_tensor  # [B, L]
+
+    # Step 5: Remask low-confidence positions while keeping the mask count strictly decreasing
+    # Count how many masks were decoded this step
+    decoded_cnt = orig_mask_index.sum(dim=1)  # [B] - all original masks were decoded
+
+    # Candidates for remasking: positions with low confidence
+    remask_candidates = ~keep_mask  # [B, L]
+    cand_cnt = remask_candidates.sum(dim=1)  # [B]
+
+    # Limit the remask count: remask < decoded so the mask count strictly decreases.
+    # This ensures: final_mask = orig_mask - decoded + remask < orig_mask
+    max_remask = torch.minimum(decoded_cnt - 1, cand_cnt).clamp(min=0)  # [B]
+
+    # Sort candidates by confidence (lowest first) and select the top max_remask
+    cand_conf = torch.where(remask_candidates, unified_confidence, torch.inf)  # Non-candidates set to +inf
+    sorted_idx = torch.argsort(cand_conf, dim=1)  # [B, L]
+    ranks = torch.argsort(sorted_idx, dim=1)  # ranks[pos] = position's rank in sorted order
+
+    # Select the top max_remask candidates (lowest confidence)
+    max_remask_exp = max_remask.unsqueeze(1)  # [B, 1]
+    remask_index = (ranks < max_remask_exp) & remask_candidates  # [B, L]
+
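+    # Worked example (single row): cand_conf = [0.1, inf, 0.3] gives sorted_idx = [0, 2, 1]
+    # and ranks = [0, 2, 1]; with max_remask = 1, only position 0 (the lowest-confidence
+    # candidate) is remasked.
+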
+    # Ensure fix_mask positions are never remasked (second guard on top of the
+    # +inf confidence assigned above)
+    if fix_mask is not None:
+        remask_index = torch.logical_and(remask_index, ~fix_mask)  # Never remask fixed positions
+
+    # Step 6: Apply the remask - set low-confidence positions back to mask_id
+    x0 = torch.where(remask_index, mask_id, x0)
+
+    # transfer_index: all True, meaning x0 is the final updated block
+    transfer_index = torch.ones_like(orig_mask_index, dtype=torch.bool)
+
+    return x0, transfer_index
+
+
 class ThresholdParallelDecoder(ParallelDecoder):
     """Parallel decoding driven by a confidence threshold."""
     def __init__(
@@ -395,11 +504,27 @@ def __init__(
         mask_id=126336,
         eos_id=126081,
         use_float64=False,
+        enable_remask=False,
     ):
         super().__init__(temperature, remasking, mask_id)
         self.threshold = threshold
         self.eos_id = eos_id
         self.use_float64 = use_float64
+        # Whether already-decoded tokens may be remasked when their confidence drops
+        self.enable_remask = enable_remask
+
+    def block_init(self, block_x, block_id):
+        """Initialize fix_mask to protect prompt tokens from remasking.
+
+        If remasking is enabled, creates fix_mask to mark prompt tokens (non-mask_id tokens)
+        that should be protected from remasking operations.
+        """
+        if self.enable_remask:
+            # Create fix_mask: mark positions that are NOT mask_id (i.e., prompt tokens).
+            # These positions are protected from remasking.
+            self.fix_mask = (block_x != self.mask_id)  # [B, L], True for prompt tokens
+        else:
+            self.fix_mask = None
 
     def decode(self, logits, block_start, block_end, x, iter_threshold=None):
         """ Decode the logits in the same block of multiple samples.
@@ -410,16 +535,35 @@ def decode(self, logits, block_start, block_end, x, iter_threshold=None):
         assert mask_index.shape[1] == logits.shape[1]
 
         curr_x = x[:, block_start:block_end]
-        x0, transfer_index = get_transfer_index_threshold(
-            logits,
-            self.temperature,
-            mask_index,
-            curr_x,
-            self.mask_id,
-            threshold=iter_threshold,
-            use_float64=self.use_float64,
-        )
-        transfer_index = torch.logical_and(transfer_index, mask_index)
+
+        if self.enable_remask:
+            # Get fix_mask for the current block (protects prompt tokens from remasking)
+            fix_mask = getattr(self, 'fix_mask', None)
+            x0, transfer_index = get_transfer_index_threshold_remask(
+                logits,
+                self.temperature,
+                mask_index,
+                curr_x,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+                fix_mask=fix_mask,
+            )
+        else:
+            x0, transfer_index = get_transfer_index_threshold(
+                logits,
+                self.temperature,
+                mask_index,
+                curr_x,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+            )
+
+        # In the remasking case, transfer_index may include remasked positions;
+        # in the non-remasking case, only masked positions are updated.
+        if not self.enable_remask:
+            transfer_index = torch.logical_and(transfer_index, mask_index)
         assert transfer_index.dtype == torch.bool
         x[:, block_start:block_end] = torch.where(transfer_index, x0, curr_x)
         broadcast_if_needed(x.data)
@@ -438,17 +582,34 @@ def batch_decode(self, logits, block_start, x, block_length, iter_threshold=None):
 
         mask_index = (x_block == self.mask_id)
 
-        x0, transfer_index = get_transfer_index_threshold(
-            logits,
-            self.temperature,
-            mask_index,
-            x_block,
-            self.mask_id,
-            threshold=iter_threshold,
-            use_float64=self.use_float64,
-        )
+        if self.enable_remask:
+            # Get fix_mask for the current block (protects prompt tokens from remasking)
+            fix_mask = getattr(self, 'fix_mask', None)
+            x0, transfer_index = get_transfer_index_threshold_remask(
+                logits,
+                self.temperature,
+                mask_index,
+                x_block,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+                fix_mask=fix_mask,
+            )
+        else:
+            x0, transfer_index = get_transfer_index_threshold(
+                logits,
+                self.temperature,
+                mask_index,
+                x_block,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+            )
 
-        transfer_index = transfer_index & mask_index
+        # In the remasking case, transfer_index may include remasked positions;
+        # in the non-remasking case, only masked positions are updated.
+        if not self.enable_remask:
+            transfer_index = transfer_index & mask_index
 
         x_updated = torch.where(transfer_index, x0, x_block)
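
For reference, a minimal standalone sketch of how the new remask path can be exercised (not part of the patch: the toy sizes, token ids, and threshold below are invented for illustration, the import path assumes python/dinfer is installed as the dinfer package, and temperature=0 matches how the eval script constructs its decoders; the @torch.compile decorator triggers a one-off compile on the first call):

    import torch
    from dinfer.decoding.parallel_strategy import get_transfer_index_threshold_remask

    B, L, V = 1, 8, 32                      # toy batch size, block length, vocab size
    mask_id = 31                            # hypothetical mask token id for this toy vocab
    x = torch.full((B, L), mask_id)
    x[:, :3] = torch.tensor([1, 2, 3])      # pretend the first 3 tokens are already decoded
    mask_index = x == mask_id
    fix_mask = ~mask_index                  # protect non-mask tokens, as block_init does
    logits = torch.randn(B, L, V)

    x0, transfer_index = get_transfer_index_threshold_remask(
        logits, temperature=0.0, mask_index=mask_index, x=x,
        mask_id=mask_id, threshold=0.9, fix_mask=fix_mask,
    )
    # transfer_index is all True, so x0 is the fully updated block: positions at or
    # above the threshold keep their sampled tokens, low-confidence ones return to
    # mask_id, and max_remask <= decoded - 1 keeps the mask count strictly decreasing.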