Add support for FastSentencepieceTokenizer.
The implementation of the tokenizer is here: https://www.tensorflow.org/text/api_docs/python/text/FastSentencepieceTokenizer

This change should enable small student models that use T5 as the backbone to use the fast tokenizer, which for long sequences can shave several milliseconds off inference time on accelerators, since tokenization happens on the device rather than on the host.
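
For illustration, a minimal sketch of how a pipeline could opt in; the model path and extra_ids value below are placeholders, not part of this change:

import tensorflow as tf
import seqio

# Placeholder model path; any serialized SentencePiece model works here.
vocab = seqio.SentencePieceVocabulary(
    "/path/to/sentencepiece.model",
    extra_ids=100,
    use_fast_tokenizer=True,  # the flag added by this change
)

# encode_tf builds TF ops, so tokenization can be traced into the serving
# graph and run on the accelerator instead of on the host.
token_ids = vocab.encode_tf(tf.constant("translate English to German: hello"))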

PiperOrigin-RevId: 589615741
SeqIO Team authored and committed Dec 10, 2023
1 parent cff6544 commit 6c34de7
Showing 3 changed files with 24 additions and 0 deletions.
2 changes: 2 additions & 0 deletions seqio/test_utils.py
@@ -1265,12 +1265,14 @@ def sentencepiece_vocab(
        sentencepiece_model_pb2.NormalizerSpec
    ] = None,
    reverse_extra_ids: bool = True,
    use_fast_tokenizer: bool = False,
):
  return vocabularies.SentencePieceVocabulary(
      os.path.join(TEST_DATA_DIR, "sentencepiece", "sentencepiece.model"),
      extra_ids=extra_ids,
      normalizer_spec_overrides=normalizer_spec_overrides,
      reverse_extra_ids=reverse_extra_ids,
      use_fast_tokenizer=use_fast_tokenizer,
  )


8 changes: 8 additions & 0 deletions seqio/vocabularies.py
@@ -285,6 +285,9 @@ def __init__(
          sentencepiece_model_pb2.NormalizerSpec
      ] = None,
      reverse_extra_ids: bool = True,
      # TODO(vladdoru): Flip this to True by default after we confirm there
      # is no delta in the behavior of the 2 implementations.
      use_fast_tokenizer: bool = False,
  ):
    """Create a SentencePieceVocabulary.
@@ -300,11 +303,14 @@ def __init__(
      reverse_extra_ids: if True, extra_ids are numbered in descending order, so
        the first extra_id has the highest number. This is done for
        compatibility with span_corruption mask generation in T5.
      use_fast_tokenizer: use the tf_text FastSentencepieceTokenizer
        implementation, which runs much faster.
    """
    self._sentencepiece_model_file = sentencepiece_model_file
    self._normalizer_spec_overrides = normalizer_spec_overrides
    self._reverse_extra_ids = reverse_extra_ids
    self._model: Optional[SentencePieceVocabulary._ModelContext] = None
    self._use_fast_tokenizer = use_fast_tokenizer

    super().__init__(extra_ids=extra_ids)

@@ -436,6 +442,8 @@ def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
  @property
  def tf_tokenizer(self):
    """Instantiate and return a TF tokenizer."""
    if self._use_fast_tokenizer:
      return tf_text.FastSentencepieceTokenizer(model=self.sp_model)
    return tf_text.SentencepieceTokenizer(model=self.sp_model)

  @property
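
For reference, here is roughly what the tf_tokenizer dispatch above amounts to when tf_text is used directly; a sketch, with the model path as a placeholder:

import tensorflow as tf
import tensorflow_text as tf_text

# Placeholder path to a serialized SentencePiece model proto.
with tf.io.gfile.GFile("/path/to/sentencepiece.model", "rb") as f:
  sp_model = f.read()

# Both tokenizers take the serialized model and expose the same tokenize()
# entry point, so the vocabulary can switch implementations behind one property.
slow = tf_text.SentencepieceTokenizer(model=sp_model)
fast = tf_text.FastSentencepieceTokenizer(model=sp_model)

print(slow.tokenize("hello world"))
print(fast.tokenize("hello world"))
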
14 changes: 14 additions & 0 deletions seqio/vocabularies_test.py
@@ -298,6 +298,20 @@ def test_extra_ids(self):
        test_tokens, tuple(vocab.encode_tf(test_string).numpy())
    )

  def test_fast_tokenizer(self):
    vocab = test_utils.sentencepiece_vocab(
        extra_ids=10, use_fast_tokenizer=True)
    self.assertEqual(36, vocab.vocab_size)
    self.assertEqual("v", vocab.decode([25]))
    test_string = "<extra_id_0> <extra_id_1> v <extra_id_9>"
    test_tokens = (35, 34, 3, 25, 26)
    self.assertEqual(test_string, vocab.decode(test_tokens))
    self.assertEqual(test_string, _decode_tf(vocab, test_tokens))
    self.assertSequenceEqual(test_tokens, vocab.encode(test_string))
    self.assertSequenceEqual(
        test_tokens, tuple(vocab.encode_tf(test_string).numpy())
    )

  def test_force_repeated_whitespace_preservation(self):
    test_string = "a a a a"  # string with repeated whitespaces

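Beyond the unit test above, the fast path can also be exercised in graph mode, where the latency benefit shows up; a sketch assuming the same test vocabulary helper:

import tensorflow as tf
from seqio import test_utils

vocab = test_utils.sentencepiece_vocab(extra_ids=10, use_fast_tokenizer=True)

# encode_tf is pure TF, so it can be mapped over a tf.data pipeline (or traced
# into a serving signature) without falling back to host-side Python.
ds = tf.data.Dataset.from_tensor_slices(["v v v", "<extra_id_0> v"])
ds = ds.map(vocab.encode_tf)
for ids in ds:
  print(ids.numpy())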
