From e7588e366a1745af26a43344e6a5ce9bbab5064a Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Sat, 18 Jan 2025 23:47:17 -0800 Subject: [PATCH] Make `UnigramVocabulary._encode` method work for bytes as well as string inputs. PiperOrigin-RevId: 717164495 --- seqio/vocabularies.py | 4 +++- seqio/vocabularies_test.py | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/seqio/vocabularies.py b/seqio/vocabularies.py index 81c27a80..79af25b1 100644 --- a/seqio/vocabularies.py +++ b/seqio/vocabularies.py @@ -243,7 +243,9 @@ def __init__(self, unigrams: Sequence[str], split_on_space: bool = False): self._unigram_by_id_tf = tf.constant(self._unigram_by_id) self._split_on_space = split_on_space - def _encode(self, s: str) -> Sequence[int]: + def _encode(self, s: str | bytes) -> Sequence[int]: + if isinstance(s, bytes): + s = s.decode("utf-8") if self._split_on_space: return [ self._id_by_unigram.get(unigram, self.unk_id) diff --git a/seqio/vocabularies_test.py b/seqio/vocabularies_test.py index ec0c84f7..3c1ebd59 100644 --- a/seqio/vocabularies_test.py +++ b/seqio/vocabularies_test.py @@ -217,6 +217,14 @@ def test_encode_converts_unigrams_to_ints_correctly(self, split_on_space): self.assertEqual(vocabulary.encode("not that"), [4, 2]) else: self.assertEqual(vocabulary.encode("not that"), [vocabulary.unk_id]) + # validate that lookup works for bytes input. + self.assertEqual(vocabulary.encode(b"that"), [2]) + self.assertEqual(vocabulary.encode(b"not"), [4]) + self.assertEqual(vocabulary.encode(b"apple"), [vocabulary.unk_id]) + if split_on_space: + self.assertEqual(vocabulary.encode(b"not that"), [4, 2]) + else: + self.assertEqual(vocabulary.encode(b"not that"), [vocabulary.unk_id]) with self.subTest(name="tensorflow"): # Note that id 0 is reserved for padding. # Note that this test must pass under both TF1 and TF2, but the default