5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
evaluations/outputs/
build/
python/dinfer.egg-info

evaluations/tasks/mbpp_sanitized/__pycache__/*
9 changes: 6 additions & 3 deletions evaluations/eval_dinfer_sglang.py
@@ -97,9 +97,9 @@ def run_benchmark(world_size, rank, gpu_id, tokenizer, args):

if args.parallel_decoding == 'threshold':
if args.use_credit:
decoder = CreditThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id)
decoder = CreditThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id, enable_remask=args.enable_remask)
else:
decoder = ThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id)
decoder = ThresholdParallelDecoder(temperature=0, threshold=args.threshold, mask_id=mask_id, eos_id=eos_id, enable_remask=args.enable_remask)

else:
decoder = HierarchyDecoder(temperature=0, threshold=args.threshold, low_threshold=args.low_threshold, mask_id=mask_id, eos_id=eos_id)
@@ -272,6 +272,7 @@ class EvalConfig:
batch_size: int = 1
save_samples: bool = False
speed_path: str = ''
enable_remask: bool = False

def set_seed(seed):
torch.manual_seed(seed)
@@ -316,6 +317,7 @@ def __init__(
use_shift = False,
model_type = 'llada2',
save_samples = False,
enable_remask = False,
**kwargs
):

@@ -354,6 +356,7 @@ def __init__(
self.use_shift = use_shift
self.model_type = model_type
self.save_samples = save_samples
self.enable_remask = enable_remask

if self.model_type == 'llada2':
self.mask_id = 156895
@@ -575,7 +578,7 @@ def cal_bucket_len(gen_len, all_input_ids):
procs = []
answers = []
gpus = [int(gpu) for gpu in self.gpus.split(';')]
args = {"gpu": gpus, "batch_size": self.batch_size, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile,"use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "batch_size": self.batch_size, "speed_path": self.speed_path}
args = {"gpu": gpus, "batch_size": self.batch_size, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile,"use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "batch_size": self.batch_size, "speed_path": self.speed_path, "enable_remask": self.enable_remask}

medium

The args dictionary literal repeats the "batch_size" key. In a Python dict literal the last occurrence of a duplicate key silently wins; both entries here use the same value (self.batch_size), so there is no functional issue, but the redundancy is confusing and invites future errors. The redundant entry should be removed.

Suggested change
args = {"gpu": gpus, "batch_size": self.batch_size, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile,"use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "batch_size": self.batch_size, "speed_path": self.speed_path, "enable_remask": self.enable_remask}
args = {"gpu": gpus, "model_name": self.model_path, "gen_len": self.gen_length, "block_length": self.block_length, "prefix_look": self.prefix_look, "after_look": self.after_look, "warmup_times": self.warmup_times, "low_threshold": self.low_threshold, "threshold": self.threshold, "cont_weight": self.cont_weight, "use_credit": self.use_credit, "cache": self.cache, "parallel_decoding": self.parallel_decoding, "tp_size": self.tp_size, "save_path": self.save_path, "use_cudagraph": self.use_cudagraph, "use_compile": self.use_compile,"use_bd": self.use_bd, "use_shift": self.use_shift, "model_type": self.model_type, "vocab_size": self.vocab_size, "batch_size": self.batch_size, "speed_path": self.speed_path, "enable_remask": self.enable_remask}

args = EvalConfig(**args)
args.tp_size = len(gpus)
args.master_port = self.master_port
41 changes: 41 additions & 0 deletions evaluations/eval_llada_mini_remask.sh
@@ -0,0 +1,41 @@
# Set the environment variables before running the command.
export HF_ALLOW_CODE_EVAL=1
export HF_DATASETS_TRUST_REMOTE_CODE=1
export TRANSFORMERS_TRUST_REMOTE_CODE=1
export CUDA_VISIBLE_DEVICES=4,5,6,7

parallel_decoding='threshold' # or hierarchy
length=2048 # generate length
block_length=32 # block length
model_path='/root/inclusion_data/separate_expert_model/LLaDA2.0-mini-preview' # your model path
threshold=0.98 # threshold for parallel decoding
low_threshold=0.62 # low threshold for parallel decoding when using hierarchy mechanism
cache='prefix' # 'prefix' for prefix cache, or '' to disable caching
warmup_times=0 # warmup times for cache
prefix_look=0
after_look=0
cont_weight=0 # cont weight
use_credit=False # whether to enable credit for the threshold mechanism
use_compile=True # use compile
tp_size=4 # tensor parallel size
gpus='0;1;2;3' # gpus for tensor parallel inference
parallel='tp' # 'tp' for tensor parallel or 'dp' for data parallel
output_dir='./outputs' # your custom output path
model_type='llada2' # llada2 (for llada2-mini)
use_bd=True # use block diffusion
master_port="23458"
save_samples=False # save samples
enable_remask=True # enable remasking for threshold decoder
# for llada 1.5 use tasks gsm8k_llada1.5 mbpp_sanitized_llada1.5
# for llada2_mini use tasks gsm8k_llada_mini mbpp_sanitized_llada_mini
if [ parallel=='tp' ]; then

critical

The if condition uses incorrect syntax for a bash string comparison. [ parallel=='tp' ] always evaluates to true because the single word parallel=='tp' is treated as a non-empty string, so the tp branch runs even when parallel is not 'tp'. Use [ "$parallel" == 'tp' ] for a real string comparison.

Suggested change
if [ parallel=='tp' ]; then
if [ "$parallel" == 'tp' ]; then

for task in mbpp_sanitized_llada_mini; do
output_path=${output_dir}/${task}
python eval_dinfer_sglang.py --tasks ${task} \
--confirm_run_unsafe_code --model dInfer_eval \
--model_args model_path=${model_path},gen_length=${length},block_length=${block_length},threshold=${threshold},low_threshold=${low_threshold},show_speed=True,save_dir=${output_path},parallel_decoding=${parallel_decoding},cache=${cache},warmup_times=${warmup_times},use_compile=${use_compile},tp_size=${tp_size},parallel=${parallel},cont_weight=${cont_weight},use_credit=${use_credit},prefix_look=${prefix_look},after_look=${after_look},gpus=${gpus},model_type=${model_type},use_bd=${use_bd},master_port=${master_port},save_samples=${save_samples},enable_remask=${enable_remask} \
--output_path ${output_path} --include_path ./tasks --apply_chat_template
done
else
echo "parallel must be tp"
fi
201 changes: 181 additions & 20 deletions python/dinfer/decoding/parallel_strategy.py
@@ -385,6 +385,115 @@ def get_transfer_index_threshold(
return x0, transfer_index


@torch.no_grad()
@torch.compile(dynamic=True)
def get_transfer_index_threshold_remask(
logits,
temperature,
mask_index,
x,
mask_id,
threshold,
use_float64=False,
fix_mask=None,
**kwargs,
):
"""Similar to get_transfer_index_threshold but with remasking support.

Tokens that are already decoded but have confidence below the given threshold
will be remasked (set back to mask_id).
Uses the given threshold directly (no adaptive threshold calculation).
"""
# Keep the original mask positions for remasking decision.
orig_mask_index = mask_index

# 1) Sample token ids from gumbel-noised logits (sampling decision).
logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
x0_sample = torch.argmax(logits_with_noise, dim=-1) # b, l

# 2) Compute confidence from *clean* logits (no gumbel):
# - masked positions: probability of the sampled token (x0_sample from gumbel sampling)
# - unmasked positions: probability of current token in x (used for remasking decision)
if use_float64:
p = F.softmax(logits.to(torch.float64), dim=-1)
else:
p = F.softmax(logits.to(torch.float32), dim=-1)

# Probability of the *sampled* token at each position:
# - token id comes from gumbel sampling (x0_sample)
# - probability is evaluated under clean logits (p)
# Used for thresholding masked positions (deciding which masks to fill this step).
x0_p = torch.squeeze(
torch.gather(
p, dim=-1, index=torch.unsqueeze(x0_sample.to(torch.long), -1)
),
-1,
) # b, l

# Probability of the *current* token already in the sequence (x) under clean logits (p).
# Used for remasking: unmasked positions with low confidence will be set back to mask_id.
x_p = torch.squeeze(
torch.gather(p, dim=-1, index=torch.unsqueeze(x.to(torch.long), -1)),
-1,
) # b, l

# Step 1: Default decode all mask positions
x0 = torch.where(orig_mask_index, x0_sample, x) # [B, L]

# Step 2: Build unified confidence: mask positions use x0_p, non-mask positions use x_p
unified_confidence = torch.where(orig_mask_index, x0_p, x_p) # [B, L]

# If fix_mask is provided, set confidence of fixed positions to +inf to prevent remasking
if fix_mask is not None:
# fix_mask should have shape [B, L] matching unified_confidence
# Set fixed positions to +inf so they won't be selected for remasking (highest confidence)
unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, float('inf')), unified_confidence)

medium

float('inf') is used here while torch.inf is used elsewhere in this function (e.g., line 476). The two values are numerically identical (torch.inf is simply math.inf), so this is a consistency concern rather than a correctness one; sticking to torch.inf throughout keeps the code uniform.

Suggested change
unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, float('inf')), unified_confidence)
unified_confidence = torch.where(fix_mask, torch.full_like(unified_confidence, torch.inf), unified_confidence)
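A quick check confirms the two spellings produce identical tensors (illustrative only):

import math
import torch

assert torch.inf == float('inf') == math.inf   # same Python float
t = torch.full((2,), float('inf'))
u = torch.full((2,), torch.inf)
assert torch.equal(t, u)                        # identical tensors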


# Step 3: Use the given threshold directly (no adaptive threshold)
if isinstance(threshold, torch.Tensor):
threshold_tensor = threshold.to(device=unified_confidence.device, dtype=unified_confidence.dtype)
if threshold_tensor.dim() == 0:
threshold_tensor = threshold_tensor.unsqueeze(-1)
else:
threshold_tensor = torch.tensor(threshold, device=unified_confidence.device, dtype=unified_confidence.dtype).unsqueeze(-1)

# Step 4: Unified decision: confidence >= threshold means keep, < threshold means remask
keep_mask = unified_confidence >= threshold_tensor # [B, L]

# Step 5: Remask positions with low confidence, but ensure mask count is strictly decreasing
# Calculate how many masks we decoded this step
decoded_cnt = orig_mask_index.sum(dim=1) # [B] - all original masks were decoded

# Candidates for remasking: positions with low confidence
remask_candidates = ~keep_mask # [B, L]
cand_cnt = remask_candidates.sum(dim=1) # [B]

# Limit remask count: remask < decoded to ensure mask count strictly decreasing
# This ensures: final_mask = orig_mask - decoded + remask < orig_mask
max_remask = torch.minimum(decoded_cnt - 1, cand_cnt).clamp(min=0) # [B]

# Sort candidates by confidence (lowest first) and select top max_remask
cand_conf = torch.where(remask_candidates, unified_confidence, torch.inf) # Non-candidates set to +inf
sorted_idx = torch.argsort(cand_conf, dim=1) # [B, L]
ranks = torch.argsort(sorted_idx, dim=1) # ranks[pos] = position's rank in sorted order

# Select top max_remask candidates (lowest confidence)
max_remask_exp = max_remask.unsqueeze(1) # [B, 1]
remask_index = (ranks < max_remask_exp) & remask_candidates # [B, L]

# Ensure fix_mask positions are never remasked (double protection)
if fix_mask is not None:
remask_index = torch.logical_and(remask_index, ~fix_mask) # Never remask fixed positions

# Step 6: Apply remask - set low confidence positions back to mask_id
x0 = torch.where(remask_index, mask_id, x0)

# transfer_index: all True, meaning x0 is the final updated block
transfer_index = torch.ones_like(orig_mask_index, dtype=torch.bool)

return x0, transfer_index
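The argsort-of-argsort in Step 5 is the subtle part: ranks holds each position's rank in ascending-confidence order, which lets the code pick the max_remask lowest-confidence candidates without a loop. A self-contained toy run, with confidence values invented for illustration:

import torch

conf = torch.tensor([[0.99, 0.40, 0.95, 0.10, 0.97, 0.55]])
remask_candidates = conf < 0.9                 # eligible for remasking
cand_conf = torch.where(remask_candidates, conf, torch.full_like(conf, torch.inf))

sorted_idx = torch.argsort(cand_conf, dim=1)   # indices in ascending confidence
ranks = torch.argsort(sorted_idx, dim=1)       # rank of each position (0 = lowest)

max_remask = 2                                 # cap keeps mask count strictly decreasing
remask_index = (ranks < max_remask) & remask_candidates
print(remask_index)  # tensor([[False,  True, False,  True, False, False]])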


class ThresholdParallelDecoder(ParallelDecoder):
"""Parallel decoding driven by a confidence threshold."""
def __init__(
@@ -395,11 +504,27 @@ def __init__(
mask_id=126336,
eos_id=126081,
use_float64=False,
enable_remask=False,
):
super().__init__(temperature, remasking, mask_id)
self.threshold = threshold
self.eos_id = eos_id
self.use_float64 = use_float64
# Whether already-decoded tokens with confidence below the threshold may be set back to mask_id
self.enable_remask = enable_remask

def block_init(self, block_x, block_id):
"""Initialize fix_mask to protect prompt tokens from remasking.

If remasking is enabled, creates fix_mask to mark prompt tokens (non-mask_id tokens)
that should be protected from remasking operations.
"""
if self.enable_remask:
# Create fix_mask: mark positions that are NOT mask_id (i.e., prompt tokens)
# These positions should be protected from remasking
self.fix_mask = (block_x != self.mask_id) # [B, L], True for prompt tokens
else:
self.fix_mask = None

def decode(self, logits, block_start, block_end, x, iter_threshold=None):
""" Decode the logits in the same block of multiple samples.
@@ -410,16 +535,35 @@ def decode(self, logits, block_start, block_end, x, iter_threshold=None):
assert mask_index.shape[1] == logits.shape[1]

curr_x = x[:, block_start:block_end]
x0, transfer_index = get_transfer_index_threshold(
logits,
self.temperature,
mask_index,
curr_x,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
)
transfer_index = torch.logical_and(transfer_index, mask_index)

if self.enable_remask:
# Get fix_mask for current block (protect prompt tokens from remasking)
fix_mask = getattr(self, 'fix_mask', None)
x0, transfer_index = get_transfer_index_threshold_remask(
logits,
self.temperature,
mask_index,
curr_x,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
fix_mask=fix_mask,
)
else:
x0, transfer_index = get_transfer_index_threshold(
logits,
self.temperature,
mask_index,
curr_x,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
)

# For remasking case, transfer_index may include remasked positions
# For non-remasking case, only update masked positions
if not self.enable_remask:
transfer_index = torch.logical_and(transfer_index, mask_index)
Comment on lines +539 to +566

medium

The logic to select the decoding function (get_transfer_index_threshold_remask or get_transfer_index_threshold) and to conditionally process the transfer_index is duplicated in both the decode and batch_decode methods. This duplication increases maintenance overhead and the risk of introducing inconsistencies.

Consider refactoring this shared logic into a private helper method within the ThresholdParallelDecoder class. This would centralize the decision-making and improve code readability and maintainability.

For example, you could create a helper like this:

def _get_x0_and_transfer_index(self, logits, mask_index, curr_x, iter_threshold):
    if self.enable_remask:
        fix_mask = getattr(self, 'fix_mask', None)
        return get_transfer_index_threshold_remask(
            logits,
            self.temperature,
            mask_index,
            curr_x,
            self.mask_id,
            threshold=iter_threshold,
            use_float64=self.use_float64,
            fix_mask=fix_mask,
        )
    else:
        return get_transfer_index_threshold(
            logits,
            self.temperature,
            mask_index,
            curr_x,
            self.mask_id,
            threshold=iter_threshold,
            use_float64=self.use_float64,
        )

Then, both decode and batch_decode can call this helper to simplify their implementations.

assert transfer_index.dtype == torch.bool
x[:, block_start:block_end] = torch.where(transfer_index, x0, curr_x)
broadcast_if_needed(x.data)
@@ -438,17 +582,34 @@ def batch_decode(self, logits, block_start, x, block_length, iter_threshold=None):

mask_index = (x_block == self.mask_id)

x0, transfer_index = get_transfer_index_threshold(
logits,
self.temperature,
mask_index,
x_block,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
)
if self.enable_remask:
# Get fix_mask for current block (protect prompt tokens from remasking)
fix_mask = getattr(self, 'fix_mask', None)
x0, transfer_index = get_transfer_index_threshold_remask(
logits,
self.temperature,
mask_index,
x_block,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
fix_mask=fix_mask,
)
else:
x0, transfer_index = get_transfer_index_threshold(
logits,
self.temperature,
mask_index,
x_block,
self.mask_id,
threshold=iter_threshold,
use_float64=self.use_float64,
)

transfer_index = transfer_index & mask_index
# For remasking case, transfer_index may include remasked positions
# For non-remasking case, only update masked positions
if not self.enable_remask:
transfer_index = transfer_index & mask_index

x_updated = torch.where(transfer_index, x0, x_block)

4 changes: 3 additions & 1 deletion python/dinfer/decoding/utils.py
@@ -247,7 +247,9 @@ def __next__(self):
current_block_end = min(current_block_start + self.block_length, self.x.total_length)
assert current_block_end <= self.x.total_length
self.iter += 1
return BlockLoc(current_block_start, current_block_end), self.x[current_block_start:current_block_end]
# Slice along the sequence dimension with [:, start:end]; a bare
# [start:end] slice would index the batch dimension instead.
return BlockLoc(current_block_start, current_block_end), self.x[:, current_block_start:current_block_end]


class BlockIteratorFactory:
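The slicing fix is easy to verify on a toy 2-D tensor (shapes invented for illustration):

import torch

x = torch.arange(12).reshape(2, 6)   # batch of 2 sequences, length 6
print(x[1:3].shape)                  # torch.Size([1, 6]) -- slices batch rows
print(x[:, 1:3].shape)               # torch.Size([2, 2]) -- slices sequence positions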