Feature/threshold remasking #30

base: master
Changes from all commits
```diff
@@ -0,0 +1,5 @@
+evaluations/outputs/
+build/
+python/dinfer.egg-info
+
+evaluations/tasks/mbpp_sanitized/__pycache__/*
```
```diff
@@ -0,0 +1,41 @@
+# Set the environment variables first before running the command.
+export HF_ALLOW_CODE_EVAL=1
+export HF_DATASETS_TRUST_REMOTE_CODE=1
+export TRANSFORMERS_TRUST_REMOTE_CODE=1
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+parallel_decoding='threshold'  # or 'hierarchy'
+length=2048                    # generation length
+block_length=32                # block length
+model_path='/root/inclusion_data/separate_expert_model/LLaDA2.0-mini-preview'  # your model path
+threshold=0.98                 # threshold for parallel decoding
+low_threshold=0.62             # low threshold for parallel decoding when using the hierarchy mechanism
+cache='prefix'                 # 'prefix' for prefix cache, or '' to disable the cache
+warmup_times=0                 # warmup times for the cache
+prefix_look=0
+after_look=0
+cont_weight=0                  # cont weight
+use_credit=False               # enable credit for the threshold mechanism
+use_compile=True               # use compilation
+tp_size=4                      # tensor parallel size
+gpus='0;1;2;3'                 # gpus for tensor parallel inference
+parallel='tp'                  # 'tp' for tensor parallel or 'dp' for data parallel
+output_dir='./outputs'         # your custom output path
+model_type='llada2'            # llada2 (for llada2-mini)
+use_bd=True                    # use block diffusion
+master_port="23458"
+save_samples=False             # save samples
+enable_remask=True             # enable remasking for the threshold decoder
+# For LLaDA 1.5, use the tasks gsm8k_llada1.5 and mbpp_sanitized_llada1.5.
+# For LLaDA2-mini, use the tasks gsm8k_llada_mini and mbpp_sanitized_llada_mini.
+if [ parallel=='tp' ]; then
```
**Review comment:** The `[ parallel=='tp' ]` test never compares the `parallel` variable: it checks that the literal string `parallel=='tp'` is non-empty, which is always true, so the `else` branch is unreachable. The variable must be expanded and quoted.

Suggested change:

```diff
-if [ parallel=='tp' ]; then
+if [ "$parallel" == "tp" ]; then
```
```diff
+for task in mbpp_sanitized_llada_mini; do
+    output_path=${output_dir}/${task}
+    python eval_dinfer_sglang.py --tasks ${task} \
+        --confirm_run_unsafe_code --model dInfer_eval \
+        --model_args model_path=${model_path},gen_length=${length},block_length=${block_length},threshold=${threshold},low_threshold=${low_threshold},show_speed=True,save_dir=${output_path},parallel_decoding=${parallel_decoding},cache=${cache},warmup_times=${warmup_times},use_compile=${use_compile},tp_size=${tp_size},parallel=${parallel},cont_weight=${cont_weight},use_credit=${use_credit},prefix_look=${prefix_look},after_look=${after_look},gpus=${gpus},model_type=${model_type},use_bd=${use_bd},master_port=${master_port},save_samples=${save_samples},enable_remask=${enable_remask} \
+        --output_path ${output_path} --include_path ./tasks --apply_chat_template
+done
+else
+echo "parallel must be tp"
+fi
```
```diff
@@ -385,6 +385,115 @@ def get_transfer_index_threshold(
     return x0, transfer_index


+@torch.no_grad()
+@torch.compile(dynamic=True)
+def get_transfer_index_threshold_remask(
+    logits,
+    temperature,
+    mask_index,
+    x,
+    mask_id,
+    threshold,
+    use_float64=False,
+    fix_mask=None,
+    **kwargs,
+):
+    """Similar to get_transfer_index_threshold but with remasking support.
+
+    Tokens that are already decoded but have confidence below the given threshold
+    will be remasked (set back to mask_id).
+    Uses the given threshold directly (no adaptive threshold calculation).
+    """
+    # Keep the original mask positions for the remasking decision.
+    orig_mask_index = mask_index
+
+    # 1) Sample token ids from gumbel-noised logits (sampling decision).
+    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+    x0_sample = torch.argmax(logits_with_noise, dim=-1)  # b, l
+
+    # 2) Compute confidence from *clean* logits (no gumbel):
+    #    - masked positions: probability of the sampled token (x0_sample from gumbel sampling)
+    #    - unmasked positions: probability of the current token in x (used for the remasking decision)
+    if use_float64:
+        p = F.softmax(logits.to(torch.float64), dim=-1)
+    else:
+        p = F.softmax(logits.to(torch.float32), dim=-1)
+
+    # Probability of the *sampled* token at each position:
+    #    - the token id comes from gumbel sampling (x0_sample)
+    #    - the probability is evaluated under the clean logits (p)
+    # Used for thresholding masked positions (deciding which masks to fill this step).
+    x0_p = torch.squeeze(
+        torch.gather(
+            p, dim=-1, index=torch.unsqueeze(x0_sample.to(torch.long), -1)
+        ),
+        -1,
+    )  # b, l
+
+    # Probability of the *current* token already in the sequence (x) under the clean logits (p).
+    # Used for remasking: unmasked positions with low confidence will be set back to mask_id.
+    x_p = torch.squeeze(
+        torch.gather(p, dim=-1, index=torch.unsqueeze(x.to(torch.long), -1)),
+        -1,
+    )  # b, l
+
+    # Step 1: By default, decode all mask positions.
+    x0 = torch.where(orig_mask_index, x0_sample, x)  # [B, L]
+
+    # Step 2: Build a unified confidence: mask positions use x0_p, non-mask positions use x_p.
+    unified_confidence = torch.where(orig_mask_index, x0_p, x_p)  # [B, L]
+
+    # If fix_mask is provided, set the confidence of fixed positions to +inf to prevent remasking.
+    if fix_mask is not None:
+        # fix_mask should have shape [B, L], matching unified_confidence.
+        # Fixed positions get +inf (highest confidence) so they are never selected for remasking.
+        unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, float('inf')), unified_confidence)
```
**Review comment:** There's an inconsistent use of `float('inf')` here; the remask-candidate sort below uses `torch.inf` for the same purpose, so one spelling should be used throughout.

Suggested change:

```diff
-        unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, float('inf')), unified_confidence)
+        unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, torch.inf), unified_confidence)
```
```diff
+    # Step 3: Use the given threshold directly (no adaptive threshold).
+    if isinstance(threshold, torch.Tensor):
+        threshold_tensor = threshold.to(device=unified_confidence.device, dtype=unified_confidence.dtype)
+        if threshold_tensor.dim() == 0:
+            threshold_tensor = threshold_tensor.unsqueeze(-1)
+    else:
+        threshold_tensor = torch.tensor(threshold, device=unified_confidence.device, dtype=unified_confidence.dtype).unsqueeze(-1)
+
+    # Step 4: Unified decision: confidence >= threshold means keep; < threshold means remask.
+    keep_mask = unified_confidence >= threshold_tensor  # [B, L]
+
+    # Step 5: Remask positions with low confidence, but ensure the mask count is strictly decreasing.
+    # Count how many masks we decoded this step.
+    decoded_cnt = orig_mask_index.sum(dim=1)  # [B] - all original masks were decoded
+
+    # Candidates for remasking: positions with low confidence.
+    remask_candidates = ~keep_mask  # [B, L]
+    cand_cnt = remask_candidates.sum(dim=1)  # [B]
+
+    # Limit the remask count: remask < decoded ensures the mask count strictly decreases.
+    # This guarantees: final_mask = orig_mask - decoded + remask < orig_mask.
+    max_remask = torch.minimum(decoded_cnt - 1, cand_cnt).clamp(min=0)  # [B]
+
+    # Sort candidates by confidence (lowest first) and select the top max_remask.
+    cand_conf = torch.where(remask_candidates, unified_confidence, torch.inf)  # non-candidates set to +inf
+    sorted_idx = torch.argsort(cand_conf, dim=1)  # [B, L]
+    ranks = torch.argsort(sorted_idx, dim=1)  # ranks[pos] = the position's rank in sorted order
+
+    # Select the top max_remask candidates (lowest confidence).
+    max_remask_exp = max_remask.unsqueeze(1)  # [B, 1]
+    remask_index = (ranks < max_remask_exp) & remask_candidates  # [B, L]
+
+    # Ensure fix_mask positions are never remasked (double protection).
+    if fix_mask is not None:
+        remask_index = torch.logical_and(remask_index, ~fix_mask)  # never remask fixed positions
+
+    # Step 6: Apply the remask - set low-confidence positions back to mask_id.
+    x0 = torch.where(remask_index, mask_id, x0)
+
+    # transfer_index: all True, meaning x0 is the final updated block.
+    transfer_index = torch.ones_like(orig_mask_index, dtype=torch.bool)
+
+    return x0, transfer_index
+
+
 class ThresholdParallelDecoder(ParallelDecoder):
     """Parallel decoding driven by a confidence threshold."""
     def __init__(
```
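The double-`argsort` trick in Step 5 is the core of the remask selection, so a standalone demonstration may help. The sketch below is illustrative only, with hand-made numbers (all names are local to this example, not PR code):

```python
import torch

# Reproduce the Step 5 selection from get_transfer_index_threshold_remask.
conf = torch.tensor([[0.99, 0.40, 0.95, 0.10, 0.70]])  # unified confidence, [B=1, L=5]
candidates = conf < 0.98                                # below-threshold positions: [F, T, T, T, T]

decoded_cnt = torch.tensor([3])                         # pretend 3 masks were decoded this step
cand_cnt = candidates.sum(dim=1)                        # 4 candidates
max_remask = torch.minimum(decoded_cnt - 1, cand_cnt).clamp(min=0)  # min(2, 4) = 2

# Non-candidates get +inf so they sort last; two argsorts turn confidences into ranks.
cand_conf = torch.where(candidates, conf, torch.inf)
sorted_idx = torch.argsort(cand_conf, dim=1)            # [[3, 1, 4, 2, 0]]
ranks = torch.argsort(sorted_idx, dim=1)                # [[4, 1, 3, 0, 2]]

# Only the max_remask lowest-confidence candidates are remasked.
remask_index = (ranks < max_remask.unsqueeze(1)) & candidates
print(remask_index)  # tensor([[False,  True, False,  True, False]]) -> 0.40 and 0.10
```

Capping `max_remask` at `decoded_cnt - 1` is what makes the final mask count strictly smaller than the original one, so generation cannot stall.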
```diff
@@ -395,11 +504,27 @@ def __init__(
         mask_id=126336,
         eos_id=126081,
         use_float64=False,
+        enable_remask=False,
     ):
         super().__init__(temperature, remasking, mask_id)
         self.threshold = threshold
         self.eos_id = eos_id
         self.use_float64 = use_float64
+        # Enable remasking based on the enable_remask parameter.
+        self.enable_remask = enable_remask
+
+    def block_init(self, block_x, block_id):
+        """Initialize fix_mask to protect prompt tokens from remasking.
+
+        If remasking is enabled, creates fix_mask to mark the prompt tokens
+        (non-mask_id tokens) that should be protected from remasking operations.
+        """
+        if self.enable_remask:
+            # Create fix_mask: mark positions that are NOT mask_id (i.e., prompt tokens).
+            # These positions should be protected from remasking.
+            self.fix_mask = (block_x != self.mask_id)  # [B, L], True for prompt tokens
+        else:
+            self.fix_mask = None

     def decode(self, logits, block_start, block_end, x, iter_threshold=None):
         """ Decode the logits in the same block of multiple samples.
```
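For intuition, here is a minimal sketch (not from the PR) of what `block_init` computes when a block still contains prompt tokens; `mask_id=126336` matches the default in `__init__` above:

```python
import torch

mask_id = 126336  # default mask_id from ThresholdParallelDecoder.__init__
block_x = torch.tensor([[15, 27, mask_id, mask_id, mask_id]])  # 2 prompt tokens, 3 masked slots

# block_init marks prompt tokens so the remasking step can never touch them.
fix_mask = (block_x != mask_id)
print(fix_mask)  # tensor([[ True,  True, False, False, False]])
```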
```diff
@@ -410,16 +535,35 @@ def decode(self, logits, block_start, block_end, x, iter_threshold=None):
         assert mask_index.shape[1] == logits.shape[1]

         curr_x = x[:, block_start:block_end]
-        x0, transfer_index = get_transfer_index_threshold(
-            logits,
-            self.temperature,
-            mask_index,
-            curr_x,
-            self.mask_id,
-            threshold=iter_threshold,
-            use_float64=self.use_float64,
-        )
-        transfer_index = torch.logical_and(transfer_index, mask_index)
+        if self.enable_remask:
+            # Get fix_mask for the current block (protects prompt tokens from remasking).
+            fix_mask = getattr(self, 'fix_mask', None)
+            x0, transfer_index = get_transfer_index_threshold_remask(
+                logits,
+                self.temperature,
+                mask_index,
+                curr_x,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+                fix_mask=fix_mask,
+            )
+        else:
+            x0, transfer_index = get_transfer_index_threshold(
+                logits,
+                self.temperature,
+                mask_index,
+                curr_x,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+            )
+
+        # In the remasking case, transfer_index may include remasked positions;
+        # in the non-remasking case, only update masked positions.
+        if not self.enable_remask:
+            transfer_index = torch.logical_and(transfer_index, mask_index)
```
**Review comment on lines +539 to +566:** The logic to select the decoding function (`get_transfer_index_threshold_remask` vs. `get_transfer_index_threshold`) is duplicated in `decode` and `batch_decode`. Consider refactoring this shared logic into a private helper method within the `ThresholdParallelDecoder` class.

For example, you could create a helper like this:

```python
def _get_x0_and_transfer_index(self, logits, mask_index, curr_x, iter_threshold):
    if self.enable_remask:
        fix_mask = getattr(self, 'fix_mask', None)
        return get_transfer_index_threshold_remask(
            logits,
            self.temperature,
            mask_index,
            curr_x,
            self.mask_id,
            threshold=iter_threshold,
            use_float64=self.use_float64,
            fix_mask=fix_mask,
        )
    else:
        return get_transfer_index_threshold(
            logits,
            self.temperature,
            mask_index,
            curr_x,
            self.mask_id,
            threshold=iter_threshold,
            use_float64=self.use_float64,
        )
```

Then, both `decode` and `batch_decode` can call this helper.
```diff
         assert transfer_index.dtype == torch.bool
         x[:, block_start:block_end] = torch.where(transfer_index, x0, curr_x)
         broadcast_if_needed(x.data)
```
```diff
@@ -438,17 +582,34 @@ def batch_decode(self, logits, block_start, x, block_length, iter_threshold=None

         mask_index = (x_block == self.mask_id)

-        x0, transfer_index = get_transfer_index_threshold(
-            logits,
-            self.temperature,
-            mask_index,
-            x_block,
-            self.mask_id,
-            threshold=iter_threshold,
-            use_float64=self.use_float64,
-        )
+        if self.enable_remask:
+            # Get fix_mask for the current block (protects prompt tokens from remasking).
+            fix_mask = getattr(self, 'fix_mask', None)
+            x0, transfer_index = get_transfer_index_threshold_remask(
+                logits,
+                self.temperature,
+                mask_index,
+                x_block,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+                fix_mask=fix_mask,
+            )
+        else:
+            x0, transfer_index = get_transfer_index_threshold(
+                logits,
+                self.temperature,
+                mask_index,
+                x_block,
+                self.mask_id,
+                threshold=iter_threshold,
+                use_float64=self.use_float64,
+            )

-        transfer_index = transfer_index & mask_index
+        # In the remasking case, transfer_index may include remasked positions;
+        # in the non-remasking case, only update masked positions.
+        if not self.enable_remask:
+            transfer_index = transfer_index & mask_index

         x_updated = torch.where(transfer_index, x0, x_block)
```
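The reason both paths guard the `transfer_index & mask_index` restriction can be seen on a toy block. The sketch below uses hand-made values (not PR code) to show that restricting the write-back to originally masked positions would silently drop a remasked token:

```python
import torch

mask_id = 126336
x_block = torch.tensor([[11, mask_id, 42, mask_id]])   # token 42 was decoded in an earlier step
x0      = torch.tensor([[11, 7, mask_id, 9]])          # remask pass: position 2 re-masked, 1 and 3 decoded
mask_index = (x_block == mask_id)                      # [[False, True, False, True]]

# Remask path: an all-True transfer_index writes back the whole block, keeping the re-mask.
transfer_index = torch.ones_like(mask_index)
print(torch.where(transfer_index, x0, x_block))        # tensor([[11, 7, 126336, 9]])

# If it were ANDed with mask_index (the non-remask path), the re-mask would be lost.
print(torch.where(transfer_index & mask_index, x0, x_block))  # tensor([[11, 7, 42, 9]])
```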
**Review comment:** The `args` dictionary has a duplicate key for `"batch_size"`. While it doesn't cause a functional issue in this case, as both instances use the same value (`self.batch_size`), it makes the code confusing and prone to errors in the future. The redundant entry should be removed.