
Commit 6c34de7

Browse files
SeqIO TeamSeqIO
SeqIO Team
authored and
SeqIO
committed
Add support for FastSentencepieceTokenizer.
The implementation of the tokenizer is here: https://www.tensorflow.org/text/api_docs/python/text/FastSentencepieceTokenizer

This change should enable small student models that use T5 as the backbone to use the fast tokenizer, which for long sequences can shave several milliseconds off inference time on accelerators, since tokenization happens on the device rather than on the host.

PiperOrigin-RevId: 589615741
1 parent cff6544 commit 6c34de7
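
As a quick orientation (not part of the diff below), a caller would opt in roughly like this once the change lands; the model path and extra_ids value here are placeholders:

import seqio

# Placeholder model path; any serialized SentencePiece model works here.
vocab = seqio.SentencePieceVocabulary(
    sentencepiece_model_file="/path/to/sentencepiece.model",
    extra_ids=100,
    use_fast_tokenizer=True,  # opt in to tf_text.FastSentencepieceTokenizer
)

# encode_tf / decode_tf now route through the fast, device-friendly tokenizer.
token_ids = vocab.encode_tf("translate English to German: hello")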

File tree

3 files changed: +24 -0 lines changed

seqio/test_utils.py
seqio/vocabularies.py
seqio/vocabularies_test.py

seqio/test_utils.py

Lines changed: 2 additions & 0 deletions
@@ -1265,12 +1265,14 @@ def sentencepiece_vocab(
         sentencepiece_model_pb2.NormalizerSpec
     ] = None,
     reverse_extra_ids: bool = True,
+    use_fast_tokenizer: bool = False,
 ):
   return vocabularies.SentencePieceVocabulary(
       os.path.join(TEST_DATA_DIR, "sentencepiece", "sentencepiece.model"),
       extra_ids=extra_ids,
       normalizer_spec_overrides=normalizer_spec_overrides,
       reverse_extra_ids=reverse_extra_ids,
+      use_fast_tokenizer=use_fast_tokenizer,
   )

12761278

seqio/vocabularies.py

Lines changed: 8 additions & 0 deletions
@@ -285,6 +285,9 @@ def __init__(
           sentencepiece_model_pb2.NormalizerSpec
       ] = None,
       reverse_extra_ids: bool = True,
+      # TODO(vladdoru): Flip this to True by default after we confirm there
+      # is no delta in the behavior of the 2 implementations.
+      use_fast_tokenizer: bool = False,
   ):
     """Create a SentencePieceVocabulary.

@@ -300,11 +303,14 @@ def __init__(
       reverse_extra_ids: if True, extra_ids are numbered in descending order, so
         the first extra_id has the highest number. This is done for
         compatibility with span_corruption mask generation in T5.
+      use_fast_tokenizer: use the tf_text fastsentencepiecetokenizer
+        implementation which runs much faster.
     """
     self._sentencepiece_model_file = sentencepiece_model_file
     self._normalizer_spec_overrides = normalizer_spec_overrides
     self._reverse_extra_ids = reverse_extra_ids
     self._model: Optional[SentencePieceVocabulary._ModelContext] = None
+    self._use_fast_tokenizer = use_fast_tokenizer

     super().__init__(extra_ids=extra_ids)

@@ -436,6 +442,8 @@ def tokenizer(self) -> sentencepiece_processor.SentencePieceProcessor:
   @property
   def tf_tokenizer(self):
     """Instantiate and return a TF tokenizer."""
+    if self._use_fast_tokenizer:
+      return tf_text.FastSentencepieceTokenizer(model=self.sp_model)
     return tf_text.SentencepieceTokenizer(model=self.sp_model)

   @property
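
The TODO above is what gates flipping the default. A standalone spot-check along these lines (not part of this commit; model_proto is an assumed serialized SentencePiece model proto) is one way to confirm the two tf_text tokenizers agree on a given string:

import tensorflow_text as tf_text

def tokenizations_match(model_proto: bytes, text: str) -> bool:
  """Returns True if the slow and fast tokenizers produce identical ids."""
  slow = tf_text.SentencepieceTokenizer(model=model_proto)
  fast = tf_text.FastSentencepieceTokenizer(model=model_proto)
  # Tokenize a one-element batch so both calls return a RaggedTensor,
  # then compare the plain Python id lists.
  return slow.tokenize([text]).to_list() == fast.tokenize([text]).to_list()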

seqio/vocabularies_test.py

Lines changed: 14 additions & 0 deletions
@@ -298,6 +298,20 @@ def test_extra_ids(self):
         test_tokens, tuple(vocab.encode_tf(test_string).numpy())
     )

+  def test_fast_tokenizer(self):
+    vocab = test_utils.sentencepiece_vocab(
+        extra_ids=10, use_fast_tokenizer=True)
+    self.assertEqual(36, vocab.vocab_size)
+    self.assertEqual("v", vocab.decode([25]))
+    test_string = "<extra_id_0> <extra_id_1> v <extra_id_9>"
+    test_tokens = (35, 34, 3, 25, 26)
+    self.assertEqual(test_string, vocab.decode(test_tokens))
+    self.assertEqual(test_string, _decode_tf(vocab, test_tokens))
+    self.assertSequenceEqual(test_tokens, vocab.encode(test_string))
+    self.assertSequenceEqual(
+        test_tokens, tuple(vocab.encode_tf(test_string).numpy())
+    )
+
   def test_force_repeated_whitespace_preservation(self):
     test_string = "a a a a"  # string with repeated whitespaces
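
Finally, the motivation in the commit message (tokenization on the accelerator rather than the host) relies on encode_tf being usable inside a traced TF graph. A rough sketch, assuming vocab was built with use_fast_tokenizer=True:

import tensorflow as tf
import seqio

def build_graph_tokenizer(vocab: seqio.SentencePieceVocabulary):
  """Wraps encode_tf in a tf.function so tokenization is traced into the graph."""

  @tf.function
  def tokenize(text: tf.Tensor) -> tf.Tensor:
    # With use_fast_tokenizer=True this traces the
    # tf_text.FastSentencepieceTokenizer ops instead of the classic ones.
    return vocab.encode_tf(text)

  return tokenize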

0 commit comments
