From fe180ae1776b98de572437294ebac9609f33d1b8 Mon Sep 17 00:00:00 2001
From: Saibo Geng
Date: Thu, 29 Feb 2024 21:21:30 +0100
Subject: [PATCH] fix unicode mapping issue

---
 transformers_cfg/token_grammar_recognizer.py | 6 +++---
 transformers_cfg/trie.py                     | 8 ++++----
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py
index dc42589..c3246da 100644
--- a/transformers_cfg/token_grammar_recognizer.py
+++ b/transformers_cfg/token_grammar_recognizer.py
@@ -8,7 +8,7 @@
 from transformers_cfg.recognizer import StringRecognizer, AcceptState
 from transformers_cfg.parser import parse_ebnf
-from transformers_cfg.trie import Trie
+from transformers_cfg.trie import ByteTrie
 from transformers_cfg.utf8_utils import PartialUTF8
 from .vocab_struct import LEAF, TokenTrie
 from transformers_cfg.mapping import get_mapping
@@ -27,7 +27,7 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False
         self.token_trie = TokenTrie(tokenizer)
         self.tokenizer = tokenizer
         self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id)
-        self.trie = Trie.from_tokenizer(tokenizer)
+        self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode)
         self.mapping = get_mapping(tokenizer, unicode=unicode)
         assert len(self.mapping) == len(
             self.token_trie
@@ -131,7 +131,7 @@ def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device):
             accept_f = lambda x: self.string_recognizer._probe_bytes(
                 x, [stack], partial_utf8=partial_utf8
             )
-            token_acceptance = self.trie.get_token_acceptance(
+            token_acceptance = self.unicode_trie.get_token_acceptance(
                 accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id
             )
         else:
diff --git a/transformers_cfg/trie.py b/transformers_cfg/trie.py
index 1339220..4682220 100644
--- a/transformers_cfg/trie.py
+++ b/transformers_cfg/trie.py
@@ -24,7 +24,7 @@ def __init__(self):
         self.token_id = None
 
 
-class Trie:
+class ByteTrie:
     def __init__(self):
         self.root = TrieNode()
 
@@ -54,10 +54,10 @@ def start_with_prefix(self, prefix):
         return True
 
     @classmethod
-    def from_tokenizer(cls, tokenizer):
+    def from_tokenizer(cls, tokenizer, unicode=True):
         vocab: Dict[str, int] = tokenizer.get_vocab()
         trie = cls()
-        mapping = get_mapping(tokenizer)
+        mapping = get_mapping(tokenizer, unicode=unicode)
         for token_id in vocab.values():
             byte_repr = mapping.map(token_id)
             trie.insert(byte_repr, token_id)
@@ -165,7 +165,7 @@ def starts_with_prefix(prefix, target):
 
     tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)
 
-    trie = Trie.from_tokenizer(tokenizer)
+    trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
     print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")
 
     #
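
Note: for reference, a minimal usage sketch of the renamed class after this patch is applied. It mirrors the __main__ example in trie.py; it assumes transformers and transformers_cfg are installed, and anything not shown in the diff above is illustrative only.

    from transformers import AutoTokenizer
    from transformers_cfg.trie import ByteTrie

    tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)

    # Build the byte-level trie over the tokenizer vocabulary; unicode=True keeps
    # the UTF-8 byte mapping that this patch threads through from_tokenizer().
    trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)

    # Every vocabulary entry should have been inserted into the trie.
    print(f"length of trie: {len(trie)} == {len(tokenizer.get_vocab())}")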