Skip to content

Commit

Permalink
Make UnigramVocabulary._encode method work for bytes as well string…
Browse files Browse the repository at this point in the history
… inputs.

PiperOrigin-RevId: 717164495
  • Loading branch information
SeqIO Team authored and SeqIO committed Jan 19, 2025
1 parent b26111b commit e7588e3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
4 changes: 3 additions & 1 deletion seqio/vocabularies.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def __init__(self, unigrams: Sequence[str], split_on_space: bool = False):
self._unigram_by_id_tf = tf.constant(self._unigram_by_id)
self._split_on_space = split_on_space

def _encode(self, s: str) -> Sequence[int]:
def _encode(self, s: str | bytes) -> Sequence[int]:
if isinstance(s, bytes):
s = s.decode("utf-8")
if self._split_on_space:
return [
self._id_by_unigram.get(unigram, self.unk_id)
Expand Down
8 changes: 8 additions & 0 deletions seqio/vocabularies_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ def test_encode_converts_unigrams_to_ints_correctly(self, split_on_space):
self.assertEqual(vocabulary.encode("not that"), [4, 2])
else:
self.assertEqual(vocabulary.encode("not that"), [vocabulary.unk_id])
# validate that lookup works for bytes input.
self.assertEqual(vocabulary.encode(b"that"), [2])
self.assertEqual(vocabulary.encode(b"not"), [4])
self.assertEqual(vocabulary.encode(b"apple"), [vocabulary.unk_id])
if split_on_space:
self.assertEqual(vocabulary.encode(b"not that"), [4, 2])
else:
self.assertEqual(vocabulary.encode(b"not that"), [vocabulary.unk_id])
with self.subTest(name="tensorflow"):
# Note that id 0 is reserved for padding.
# Note that this test must pass under both TF1 and TF2, but the default
Expand Down

0 comments on commit e7588e3

Please sign in to comment.