Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1252,7 +1252,7 @@ def add_rollout_buffer_arguments(parser):
"--loss-mask-type",
type=str,
default="qwen",
choices=["qwen", "qwen3", "qwen3_5", "distill_qwen"],
choices=["qwen", "qwen3", "qwen3_5", "distill_qwen", "glm5"],
help="Loss mask type",
)
parser.add_argument(
Expand Down
94 changes: 94 additions & 0 deletions slime/utils/mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,98 @@ def gen_multi_turn_loss_mask_distill_qwen(
loss_mask = [0] * len(token_ids)
return token_ids, loss_mask

def gen_multi_turn_loss_mask_glm5(
self, messages: list[dict], tools: list[dict] = None
) -> tuple[list[int], list[int]]:
"""Generate loss masks for GLM-5 chat template.

GLM-5 uses role-token delimiters with no closing tags:
[gMASK]<sop><|system|>...<|user|>...<|assistant|></think>content...

Assistant messages start with ``<|assistant|></think>`` (or ``<|assistant|><think>...
</think>`` when thinking is enabled). We mask only the assistant content tokens
(everything after ``</think>`` until the next role token or end of sequence).
"""
rendered_text = self.tokenizer.apply_chat_template(
messages, tokenize=False, tools=tools, add_generation_prompt=False
)
tokenized = self.tokenizer(rendered_text, add_special_tokens=False, return_offsets_mapping=True)
token_ids = tokenized["input_ids"]
offset_mapping = tokenized.get("offset_mapping")

if offset_mapping is None:
raise ValueError(
"GLM-5 loss mask generation requires a fast tokenizer "
"with `return_offsets_mapping` support."
)

expected_token_ids = self.tokenizer(
self.tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=False),
add_special_tokens=False,
)["input_ids"]
if token_ids != expected_token_ids:
raise ValueError(
"GLM-5 rendered text tokenization does not match "
"re-tokenized output."
)

assistant_header = "<|assistant|>"
think_close = "</think>"
role_markers = ("<|user|>", "<|assistant|>", "<|system|>", "<|observation|>")

char_mask = [0] * len(rendered_text)
cursor = 0

for message in messages:
if message["role"] != "assistant":
continue

header_pos = rendered_text.find(assistant_header, cursor)
if header_pos < 0:
raise ValueError("Failed to locate <|assistant|> in rendered GLM-5 chat template output.")

content_start = header_pos + len(assistant_header)

# Find the end of this assistant message: next role token or end of string
end_pos = len(rendered_text)
for marker in role_markers:
marker_pos = rendered_text.find(marker, content_start)
if 0 <= marker_pos < end_pos:
end_pos = marker_pos

cursor = end_pos

if message.get("step_loss_mask", 1) != 1:
continue

# Skip past </think> or <think>...</think> block at the start of assistant content
mask_start = content_start
if rendered_text[mask_start : mask_start + len(think_close)] == think_close:
# Simple case: </think> immediately after <|assistant|>
mask_start += len(think_close)
elif rendered_text[mask_start : mask_start + len("<think>")] == "<think>":
# Thinking enabled: <think>...</think>
think_end = rendered_text.find(think_close, mask_start)
if think_end >= 0 and think_end < end_pos:
mask_start = think_end + len(think_close)

for pos in range(mask_start, end_pos):
char_mask[pos] = 1

# Convert char-level mask to token-level using offset mapping
char_mask_prefix_sum = [0]
for value in char_mask:
char_mask_prefix_sum.append(char_mask_prefix_sum[-1] + value)

loss_mask = []
for start, end in offset_mapping:
if end <= start:
loss_mask.append(0)
else:
loss_mask.append(1 if char_mask_prefix_sum[end] - char_mask_prefix_sum[start] > 0 else 0)

return token_ids, loss_mask

def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple[list[int], list[int]]:
if self.tokenizer_type == "qwen":
if "<|Assistant|>" in self.tokenizer.get_added_vocab():
Expand All @@ -225,6 +317,8 @@ def get_loss_mask(self, messages: list[dict], tools: list[dict] = None) -> tuple
return self.gen_multi_turn_loss_mask_qwen3_5(messages, tools)
elif self.tokenizer_type == "distill_qwen":
return self.gen_multi_turn_loss_mask_distill_qwen(messages, tools)
elif self.tokenizer_type == "glm5":
return self.gen_multi_turn_loss_mask_glm5(messages, tools)
else:
raise ValueError(f"Unsupported tokenizer type: {self.tokenizer_type}")

Expand Down
Loading