Problems with Cross-Tokenizer Alignment in Correctness and Efficiency #4393

@JqzChandler

Description

def _build_alignment_groups_from_ids(self, student_token_ids, teacher_token_ids):
    """
    Build alignment groups using a greedy substring-equality algorithm on decoded token pieces.

    Args:
        student_token_ids: List[int]
        teacher_token_ids: List[int]
    Returns:
        Tuple[List[List[int]], List[List[int]]]: student and teacher alignment groups
    """
    def to_canonical_pieces(tok, ids):
        pieces = []
        prev = ""
        for k in range(len(ids)):
            # IMPORTANT: Do NOT skip special tokens - we need to align them too
            cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
            # Extract the incremental addition (may include spaces/ZWJ/etc.)
            pieces.append(cur[len(prev):])
            prev = cur
        return pieces

    s_pieces = to_canonical_pieces(self.student_tokenizer, student_token_ids)
    t_pieces = to_canonical_pieces(self.teacher_tokenizer, teacher_token_ids)

  1. The current implementation produces incorrect decoding results in some cases
    Byte-level BPE tokenizers split some multi-byte characters (e.g. certain Chinese characters) across token boundaries. The current implementation adds one token per loop iteration, calls decode() on the growing prefix, and takes the trailing incremental text as that token's piece. When a prefix ends in the middle of a split character, the partial bytes decode to replacement characters, so the extracted pieces are corrupted, and there is no subsequent validation or error-correction mechanism. To reproduce this quickly, prepare multilingual mixed-text samples, load the Qwen3 tokenizer, and assert ''.join(pieces) == orig_text for each sample (see the reproduction sketch after this list).

  2. Calling decode() repeatedly causes significant computational overhead
    In our team's practice, we modified the Python bindings in tokenizers (and adjusted tokenization_utils_fast.py in transformers accordingly). We introduced an "offset_type" parameter in the encode() and batch_encode() methods so that the offset sequences produced during the single Rust-side encoding pass can be returned directly. This approach has minimal computational overhead and ensures correctness.
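A minimal reproduction sketch for problem 1 (the checkpoint name "Qwen/Qwen3-8B" and the sample text are illustrative assumptions; any byte-level BPE tokenizer that splits a character across tokens should behave the same):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # assumed checkpoint, for illustration

# Mixed multilingual sample; some characters may be split across BPE tokens.
orig_text = "Alignment check: 今天天气很好，エラーが発生しました。"
ids = tok(orig_text, add_special_tokens=False)["input_ids"]

pieces, prev = [], ""
for k in range(len(ids)):
    # Incremental prefix decode, as in the current implementation.
    cur = tok.decode(ids[: k + 1], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    pieces.append(cur[len(prev):])
    prev = cur

# Expected to fail on samples where a prefix ends mid-character: the partial bytes
# decode to U+FFFD, so the concatenated pieces no longer reproduce the original text.
assert "".join(pieces) == orig_text, "pieces do not reconstruct the original text"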

The offset-based approach also provides an additional advantage: when aligning tokenizers from two different BPE families, it yields more accurate token alignment results without relying on strings as an intermediate representation.

from typing import Dict, List, Tuple


def align_using_offsets(offsets_T: List[Tuple[int, int]], offsets_S: List[Tuple[int, int]]) -> List[Dict[str, int]]:
    """
    offsets_T / offsets_S look like:
    [(0, 1), (1, 3), (3, 8), ..., (start_i, end_i), ..., (start_I, TOTAL_BYTES_OF_ORIG_STRING)]  # len = I == len(input_ids_T)
    [(0, 3), (3, 5), (5, 8), ..., (start_j, end_j), ..., (start_J, TOTAL_BYTES_OF_ORIG_STRING)]  # len = J == len(input_ids_S)
    """
    # Each constraint records a pair of token indices whose spans end on the same byte,
    # i.e. a position where the teacher and student segmentations agree on a boundary.
    constraint_ls = [{"T": 0, "S": 0}]

    cur_T, cur_S = 0, 0
    while True:
        if offsets_T[cur_T][1] == offsets_S[cur_S][1]:
            # Both tokens end on the same byte: record the shared boundary and advance both sides.
            constraint_ls.append({"T": cur_T, "S": cur_S})
            cur_T, cur_S = cur_T + 1, cur_S + 1
        elif offsets_T[cur_T][1] < offsets_S[cur_S][1]:
            # Teacher token ends earlier: advance the teacher side.
            cur_T += 1
        else:  # offsets_T[cur_T][1] > offsets_S[cur_S][1]
            # Student token ends earlier: advance the student side.
            cur_S += 1

        if cur_T >= len(offsets_T) - 1 or cur_S >= len(offsets_S) - 1:
            # Both sequences end at the total byte count of the original string,
            # so the final tokens always share a boundary.
            constraint_ls.append({"T": len(offsets_T) - 1, "S": len(offsets_S) - 1})
            break

    return constraint_ls

# NOTE: offset_type="byte" is only available with the modified tokenizers bindings
# described above; the stock library does not expose this parameter.
offsets_T = teacher_tokenizer.encode(gen_text, offset_type="byte")['offsets']
offsets_S = student_tokenizer.encode(gen_text, offset_type="byte")['offsets']

constraint_ls = align_using_offsets(offsets_T, offsets_S)
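For readers without the modified bindings, a rough fallback sketch: stock fast tokenizers in transformers can return character-level offsets via return_offsets_mapping=True. Character offsets only approximate the byte offsets above (tokens that split a multi-byte character may report overlapping character spans, which is exactly why the byte-level offset_type was introduced), but they are enough to exercise align_using_offsets end to end. The checkpoint names below are illustrative placeholders.

from transformers import AutoTokenizer

teacher_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")          # placeholder teacher
student_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")  # placeholder student

gen_text = "Cross-tokenizer alignment: 今天天气很好。"

# return_offsets_mapping requires fast tokenizers; offsets are (char_start, char_end) per token.
enc_T = teacher_tokenizer(gen_text, add_special_tokens=False, return_offsets_mapping=True)
enc_S = student_tokenizer(gen_text, add_special_tokens=False, return_offsets_mapping=True)

constraint_ls = align_using_offsets(enc_T["offset_mapping"], enc_S["offset_mapping"])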
