Description
trl/trl/experimental/gold/gold_trainer.py
Lines 342 to 366 in 1eb561c
```python
def _build_alignment_groups_from_ids(self, student_token_ids, teacher_token_ids):
    """
    Build alignment groups using a greedy substring-equality algorithm on decoded token pieces.

    Args:
        student_token_ids: List[int]
        teacher_token_ids: List[int]

    Returns:
        Tuple[List[List[int]], List[List[int]]]: student and teacher alignment groups
    """

    def to_canonical_pieces(tok, ids):
        pieces = []
        prev = ""
        for k in range(len(ids)):
            # IMPORTANT: Do NOT skip special tokens - we need to align them too
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            # Extract the incremental addition (may include spaces/ZWJ/etc.)
            pieces.append(cur[len(prev) :])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(self.student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(self.teacher_tokenizer, teacher_token_ids)
```
- **Current implementation produces incorrect decoding results in some cases**

  BPE-based tokenizers always split some multi-byte characters across several tokens, e.g. certain Chinese characters. In the current implementation, the algorithm adds one token per loop iteration, calls `decode()`, and takes the trailing incremental text as the corresponding piece. This produces decoding errors for the split bytes, with no subsequent validation or error-correction mechanism. To reproduce the issue quickly, prepare multilingual mixed-text samples, load a Qwen3 tokenizer, and `assert ''.join(pieces) == orig_text` to verify the splitting result for each sample (see the reproduction sketch after this list).
- **Calling `decode()` repeatedly causes significant computational overhead**

  In our team's practice, we modified the Python bindings in `tokenizers` (and correspondingly adjusted `tokenization_utils_fast.py` in `transformers`). We introduced an `offset_type` parameter in both the `encode()` and `batch_encode()` methods to directly obtain the offset sequences produced during the single Rust encoding pass. This approach has minimal computational overhead and ensures correctness.

  This practice also provides an additional advantage: when aligning two families of BPE tokenizers, we can achieve more accurate token alignment without relying on strings as an intermediate representation.
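A minimal reproduction sketch for the first point (the `Qwen/Qwen3-0.6B` checkpoint and the sample string are illustrative choices, not fixed by the trainer):

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any byte-level BPE tokenizer that splits
# multi-byte characters across tokens should behave similarly.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
orig_text = "GOLD aligns tokens 对齐词元 🙂"  # arbitrary multilingual sample

ids = tok.encode(orig_text, add_special_tokens=False)
pieces, prev = [], ""
for k in range(len(ids)):
    # Same incremental-decode scheme as to_canonical_pieces above.
    cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    pieces.append(cur[len(prev):])
    prev = cur

# Fails whenever a multi-byte character is split across tokens: a partially
# decoded prefix ends in U+FFFD, so the incremental pieces no longer
# concatenate back to the original text.
assert "".join(pieces) == orig_text
```

The offset-based alignment mentioned in the second point looks like this: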
```python
from typing import Dict, List, Tuple


def align_using_offsets(offsets_T: List[Tuple[int, int]], offsets_S: List[Tuple[int, int]]) -> List[Dict[str, int]]:
    """
    offsets_T / offsets_S look like:
        [(0, 1), (1, 3), (3, 8), ..., (start_i, end_i), ..., (start_I, TOTAL_BYTES_OF_ORIG_STRING)]  # len == I == len(input_ids_T)
        [(0, 3), (3, 5), (5, 8), ..., (start_j, end_j), ..., (start_J, TOTAL_BYTES_OF_ORIG_STRING)]  # len == J == len(input_ids_S)
    """
    constraint_ls = [{"T": 0, "S": 0}]
    cur_T, cur_S = 0, 0
    while True:
        if offsets_T[cur_T][1] == offsets_S[cur_S][1]:
            # Both tokens end at the same byte offset: record a constraint and advance both cursors.
            constraint_ls.append({"T": cur_T, "S": cur_S})
            cur_T, cur_S = cur_T + 1, cur_S + 1
        elif offsets_T[cur_T][1] < offsets_S[cur_S][1]:
            # Teacher token ends earlier: advance the teacher cursor.
            cur_T += 1
        else:  # offsets_T[cur_T][1] > offsets_S[cur_S][1]
            # Student token ends earlier: advance the student cursor.
            cur_S += 1
        if cur_T == len(offsets_T) - 1 or cur_S == len(offsets_S) - 1:
            # One side reached its last token: close with the final pair and stop.
            constraint_ls.append({"T": len(offsets_T) - 1, "S": len(offsets_S) - 1})
            break
    return constraint_ls


# `offset_type` is only available with the modified tokenizers / transformers bindings described above.
offsets_T = teacher_tokenizer.encode(gen_text, offset_type="byte")["offsets"]
offsets_S = student_tokenizer.encode(gen_text, offset_type="byte")["offsets"]
constraint_ls = align_using_offsets(offsets_T, offsets_S)
```
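For reference, a quick sanity check of the helper on the first three offset pairs from the docstring (toy values only; real offsets come from `encode(..., offset_type="byte")`):

```python
toy_T = [(0, 1), (1, 3), (3, 8)]
toy_S = [(0, 3), (3, 5), (5, 8)]
print(align_using_offsets(toy_T, toy_S))
# [{'T': 0, 'S': 0}, {'T': 1, 'S': 0}, {'T': 2, 'S': 2}]
```

Each appended entry marks a pair of token indices whose end byte-offsets coincide; the first and last entries are added unconditionally so the constraints always cover the full sequence.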