fix unicode mapping issue
Saibo-creator committed Feb 29, 2024
1 parent 9a5d840 commit fe180ae
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions transformers_cfg/token_grammar_recognizer.py
@@ -8,7 +8,7 @@
 
 from transformers_cfg.recognizer import StringRecognizer, AcceptState
 from transformers_cfg.parser import parse_ebnf
-from transformers_cfg.trie import Trie
+from transformers_cfg.trie import ByteTrie
 from transformers_cfg.utf8_utils import PartialUTF8
 from .vocab_struct import LEAF, TokenTrie
 from transformers_cfg.mapping import get_mapping
@@ -27,7 +27,7 @@ def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False
         self.token_trie = TokenTrie(tokenizer)
         self.tokenizer = tokenizer
         self.string_recognizer = StringRecognizer(grammar_encoding, self.start_rule_id)
-        self.trie = Trie.from_tokenizer(tokenizer)
+        self.unicode_trie = ByteTrie.from_tokenizer(tokenizer, unicode=unicode)
         self.mapping = get_mapping(tokenizer, unicode=unicode)
         assert len(self.mapping) == len(
             self.token_trie
@@ -131,7 +131,7 @@ def get_token_acceptance_array_for_stack(self, stack, partial_utf8, device):
             accept_f = lambda x: self.string_recognizer._probe_bytes(
                 x, [stack], partial_utf8=partial_utf8
             )
-            token_acceptance = self.trie.get_token_acceptance(
+            token_acceptance = self.unicode_trie.get_token_acceptance(
                 accept=accept_f, accept_eos=False, eos_token_id=self.eos_token_id
             )
         else:
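For context, the renamed unicode_trie is used here to prune the vocabulary against the grammar: the recognizer supplies an accept predicate over byte prefixes, and the trie walk keeps only those token ids whose byte representation the grammar can still extend. Below is a minimal sketch of that pruning idea using plain dict nodes of the shape {"children": {byte: node}, "token_id": int | None}; the dict-based trie and the function name token_acceptance are illustrative assumptions, not the library's actual ByteTrie.get_token_acceptance implementation.

# Sketch of trie-based token acceptance: walk a byte trie depth-first,
# cutting off a whole subtree as soon as the accept predicate rejects
# the byte prefix leading into it. Illustrative only.
from typing import Callable, List, Optional


def token_acceptance(
    trie: dict,
    accept: Callable[[bytes], bool],
    vocab_size: int,
) -> List[bool]:
    acceptance = [False] * vocab_size

    def walk(node: dict, prefix: bytes) -> None:
        # A node is only reached if its prefix was accepted, so any
        # token id stored on it corresponds to an acceptable token.
        token_id: Optional[int] = node["token_id"]
        if token_id is not None:
            acceptance[token_id] = True
        for byte, child in node["children"].items():
            candidate = prefix + bytes([byte])
            if accept(candidate):  # prune: skip subtrees the grammar rejects
                walk(child, candidate)

    walk(trie, b"")
    return acceptance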
8 changes: 4 additions & 4 deletions transformers_cfg/trie.py
@@ -24,7 +24,7 @@ def __init__(self):
         self.token_id = None


-class Trie:
+class ByteTrie:
     def __init__(self):
         self.root = TrieNode()
 
@@ -54,10 +54,10 @@ def start_with_prefix(self, prefix):
         return True

     @classmethod
-    def from_tokenizer(cls, tokenizer):
+    def from_tokenizer(cls, tokenizer, unicode=True):
         vocab: Dict[str, int] = tokenizer.get_vocab()
         trie = cls()
-        mapping = get_mapping(tokenizer)
+        mapping = get_mapping(tokenizer, unicode=unicode)
         for token_id in vocab.values():
             byte_repr = mapping.map(token_id)
             trie.insert(byte_repr, token_id)
@@ -165,7 +165,7 @@ def starts_with_prefix(prefix, target):

     tokenizer = AutoTokenizer.from_pretrained("gpt2", fast=True)

-    trie = Trie.from_tokenizer(tokenizer)
+    trie = ByteTrie.from_tokenizer(tokenizer, unicode=True)
     print(f"length of trie: {len(trie)}=={len(tokenizer.vocab.items())}")

     #
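The from_tokenizer change threads the unicode flag through to get_mapping, so the trie is keyed on each token's byte representation rather than its surface string. A self-contained sketch of that construction, matching the dict-node shape used in the acceptance sketch above (the plain-dict nodes, build_byte_trie, and toy_mapping are illustrative assumptions, not the library's TrieNode or get_mapping):

# Sketch of building a byte-keyed trie from a tokenizer-style vocab.
# Each token id is mapped to its UTF-8 byte sequence and inserted
# byte by byte; the terminal node records the token id.
from typing import Dict


def build_byte_trie(byte_repr_by_id: Dict[int, bytes]) -> dict:
    root: dict = {"children": {}, "token_id": None}
    for token_id, byte_repr in byte_repr_by_id.items():
        node = root
        for byte in byte_repr:
            node = node["children"].setdefault(
                byte, {"children": {}, "token_id": None}
            )
        node["token_id"] = token_id
    return root


# Toy usage: tokens with shared byte prefixes and a multi-byte character.
toy_mapping = {0: "the".encode(), 1: "then".encode(), 2: "é".encode()}
trie = build_byte_trie(toy_mapping)
assert trie["children"][ord("t")]["children"][ord("h")]["children"][ord("e")]["token_id"] == 0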
