diff --git a/src/utils/text.py b/src/utils/text.py
index 27a9772..f9de144 100644
--- a/src/utils/text.py
+++ b/src/utils/text.py
@@ -127,9 +127,6 @@ def __call__(self, texts: Union[str, List[str]]) -> tf.Tensor:
         tf.Tensor
             A tensor of tokenized integer sequences.
         """
-        if isinstance(texts, str):
-            texts = [texts]
-
         split_texts = tf.strings.unicode_split(texts, 'UTF-8')
         return self.token_to_num(split_texts)
diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py
index 85a3810..51e74f5 100644
--- a/tests/test_dataloader.py
+++ b/tests/test_dataloader.py
@@ -48,7 +48,7 @@ def setUpClass(cls):
         cls.ResizeWithPadLayer = ResizeWithPadLayer

     def test_initialization(self):
-        tokenizer = self.Tokenizer(chars=list("ABC"), use_mask=False)
+        tokenizer = self.Tokenizer(tokens=list("ABC"))
         dg = self.DataLoader(tokenizer=tokenizer, height=64,
                              augment_model=None)
@@ -75,7 +75,7 @@ def test_load_images(self):
         image_info_tuples = list(zip(images, labels, sample_weights))

         dummy_augment_model = tf.keras.Sequential([])
-        tokenizer = self.Tokenizer(chars=vocab, use_mask=False)
+        tokenizer = self.Tokenizer(tokens=vocab)
         dg = self.DataLoader(tokenizer=tokenizer, height=64, channels=1,
                              augment_model=dummy_augment_model)
@@ -115,7 +115,7 @@ def test_load_images_with_augmentation(self):
         dummy_augment_model = tf.keras.Sequential(
             [self.ResizeWithPadLayer(70, additional_width=50)])
-        tokenizer = self.Tokenizer(chars=vocab, use_mask=False)
+        tokenizer = self.Tokenizer(tokens=vocab)
         dg = self.DataLoader(tokenizer=tokenizer, height=64, channels=4,
                              augment_model=dummy_augment_model,
                              is_training=True)
diff --git a/tests/test_datamanager.py b/tests/test_datamanager.py
index e93b1aa..8f82782 100644
--- a/tests/test_datamanager.py
+++ b/tests/test_datamanager.py
@@ -117,10 +117,11 @@ def test_initialization(self):
             "img_size": (256, 256, 3),
         })

+        tokenizer = self.Tokenizer(tokens=list("abc"))
         data_manager = self.DataManager(img_size=test_config["img_size"],
                                         config=test_config,
                                         augment_model=None,
-                                        charlist=list("abc"))
+                                        tokenizer=tokenizer)

         self.assertIsInstance(data_manager, self.DataManager,
                               "DataManager not instantiated correctly")
@@ -152,7 +153,7 @@ def test_create_data_simple(self):
         # Check the tokenizer
         self.assertIsInstance(data_manager.tokenizer, self.Tokenizer,
                               "Tokenizer not created correctly")
-        self.assertEqual(len(data_manager.tokenizer.charlist), 27,
+        self.assertEqual(len(data_manager.tokenizer), 29,
                          "Charlist length not as expected")

     def test_missing_files(self):
@@ -211,11 +212,6 @@ def test_unsupported_chars_in_eval(self):
                          self.sample_labels[0]+"!",
                          "Label not as expected")

-        # RK: This should not raise an error imho so why is it tested like this?
-        # with self.assertRaises(IndexError):
-        #     data_manager.get_filename("validation", 0)
-        #     data_manager.get_filename("evaluation", 3)
-
         # Remove the temporary file
         self._remove_temp_file(temp_sample_list_file)
@@ -232,11 +228,12 @@ def test_injected_charlist(self):
         })
         charlist = list(
            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789, ")
+        tokenizer = self.Tokenizer(tokens=charlist)
         data_manager = self.DataManager(img_size=test_config["img_size"],
                                         config=test_config,
                                         augment_model=tf.keras.Sequential(),
-                                        charlist=charlist)
+                                        tokenizer=tokenizer)

         # Check if the data is created correctly
         self.assertEqual(data_manager.get_filename("train", 2),
diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py
index 517c119..0fba169 100644
--- a/tests/test_tokenizer.py
+++ b/tests/test_tokenizer.py
@@ -1,14 +1,16 @@
 # Imports

-# > Third party dependencies
-import tensorflow as tf
-import numpy as np
-
 # > Standard library
-import logging
 import unittest
-from pathlib import Path
+import os
+import json
+from tempfile import TemporaryDirectory
+import logging
 import sys
+from pathlib import Path
+
+# > Third-party dependencies
+import tensorflow as tf


 class TestTokenizer(unittest.TestCase):
@@ -24,91 +26,113 @@ def setUpClass(cls):
         # Add the src directory to the path
         sys.path.append(str(Path(__file__).resolve().parents[1] / 'src'))

+        # Import Tokenizer class
         from utils.text import Tokenizer
         cls.Tokenizer = Tokenizer

-    def test_tokenizer_class(self):
-        # Test without mask and no oov indices
-        tokenizer = self.Tokenizer(chars=['a', 'b', 'c'], use_mask=False)
-        self.assertEqual(tokenizer.charlist, ['a', 'b', 'c'])
-
-        # Test with mask
-        tokenizer = self.Tokenizer(chars=['a', 'b', 'c'], use_mask=True)
-        self.assertTrue(isinstance(tokenizer.char_to_num,
-                                   tf.keras.layers.StringLookup))
-        self.assertTrue(tokenizer.char_to_num.mask_token, '')
-
-        # Test set_charlist function with no oov indices.
-        # Setting OOV indices to a value > 1 is broken.
-        tokenizer = self.Tokenizer(chars=['a', 'b', 'c', 'd'],
-                                   use_mask=False, num_oov_indices=0)
-        self.assertEqual(tokenizer.charlist, ['a', 'b', 'c', 'd'])
-        self.assertTrue(isinstance(tokenizer.char_to_num,
-                                   tf.keras.layers.StringLookup))
-
-    def test_ctc_decode_greedy(self):
-        # Mock data
-        y_pred = np.random.random((32, 10, 5))
-        input_length = np.random.randint(1, 10, size=(32,))
-
-        # Call the function with greedy=True
-        from utils.decoding import ctc_decode
-        decoded_dense, log_prob = ctc_decode(y_pred, input_length,
-                                             greedy=True)
-
-        # Verify that the output is as expected
-        self.assertTrue(isinstance(decoded_dense[0], tf.Tensor))
-        self.assertTrue(isinstance(log_prob, tf.Tensor))
-
-    def test_ctc_decode_beam(self):
-        # Mock data
-        y_pred = np.random.random((32, 10, 5))
-        input_length = np.random.randint(1, 10, size=(32,))
-        beam_width = 100
-
-        # Call the function with greedy=False
-        from utils.decoding import ctc_decode
-        decoded_dense, log_prob = ctc_decode(y_pred, input_length,
-                                             greedy=False,
-                                             beam_width=beam_width)
-
-        # Verify that the output is as expected
-        # Ensure that the output is a list of tensors
-        self.assertTrue(isinstance(decoded_dense, list))
-        self.assertTrue(isinstance(decoded_dense[0], tf.Tensor))
-        self.assertTrue(isinstance(log_prob, tf.Tensor))
-
-    def test_decode_batch(self):
-        chars = ['a', 'b', 'c']
-        tokenizer = self.Tokenizer(chars=chars, use_mask=False)
+    def test_initialize_string_lookup_layers(self):
+        # Test initialization with a basic token list
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)

-        # Mock data
-        y_pred = np.random.random((32, 10, 5))
+        self.assertEqual(tokenizer.token_list, [
+            '[PAD]', '[UNK]', 'a', 'b', 'c'])
+        self.assertIsInstance(tokenizer.token_to_num,
+                              tf.keras.layers.StringLookup)
+        self.assertIsInstance(tokenizer.num_to_token,
+                              tf.keras.layers.StringLookup)

-        # Call the function
-        from utils.decoding import decode_batch_predictions
-        result = decode_batch_predictions(y_pred, tokenizer)
+    def test_tokenizer_call(self):
+        # Test tokenizing a simple text string
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)

-        # Verify that the output is as expected
-        self.assertTrue(isinstance(result, list))
-        self.assertTrue(isinstance(result[0][0], np.float32))
-        self.assertTrue(isinstance(result[0][1], str))
+        text = 'abc'
+        tokenized_output = tokenizer(text)
+        expected_output = [2, 3, 4]  # Corresponding indices of 'a', 'b', 'c'

-    def test_decode_batch_with_beam(self):
-        chars = ['a', 'b', 'c']
-        tokenizer = self.Tokenizer(chars=chars, use_mask=False)
+        self.assertTrue(tf.reduce_all(
+            tf.equal(tokenized_output, expected_output)))
+
+    def test_tokenizer_decode(self):
+        # Test decoding a sequence of token indices back into text
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)
+
+        tokenized_input = tf.constant([2, 3, 4])  # Indices of 'a', 'b', 'c'
+        decoded_text = tokenizer.decode(tokenized_input)

-        # Mock data
-        y_pred = np.random.random((32, 10, 5))
+        self.assertEqual(decoded_text, 'abc')

-        # Call the function
-        from utils.decoding import decode_batch_predictions
-        result = decode_batch_predictions(y_pred, tokenizer, beam_width=100)
+    def test_load_from_file(self):
+        # Test loading from a JSON file
+        tokens = ['a', 'b', 'c']

-        # Verify that the output is as expected
-        self.assertTrue(isinstance(result, list))
-        self.assertTrue(isinstance(result[0][0], np.float32))
-        self.assertTrue(isinstance(result[0][1], str))
+        with TemporaryDirectory() as temp_dir:
+            json_path = os.path.join(temp_dir, 'tokenizer.json')
+            tokenizer = self.Tokenizer(tokens=tokens)
+            tokenizer.save_to_json(json_path)
+
+            loaded_tokenizer = self.Tokenizer.load_from_file(json_path)
+            self.assertEqual(loaded_tokenizer.token_list, tokenizer.token_list)
+
+    def test_load_from_legacy_file(self):
+        # Test loading from a legacy charlist.txt file and converting to JSON
+        chars = ['a', 'b', 'c']
+        with TemporaryDirectory() as temp_dir:
+            txt_path = os.path.join(temp_dir, 'charlist.txt')
+            with open(txt_path, 'w', encoding='utf-8') as f:
+                f.write(''.join(chars))
+
+            loaded_tokenizer = self.Tokenizer.load_from_file(txt_path)
+            # Skipping [PAD], [UNK]
+            self.assertEqual(loaded_tokenizer.token_list[2:], chars)
+            self.assertTrue(os.path.exists(
+                os.path.join(temp_dir, 'tokenizer.json')))
+
+    def test_save_to_json(self):
+        # Test saving tokenizer to a JSON file
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)
+
+        with TemporaryDirectory() as temp_dir:
+            json_path = os.path.join(temp_dir, 'tokenizer.json')
+            tokenizer.save_to_json(json_path)
+
+            with open(json_path, 'r', encoding='utf-8') as f:
+                data = json.load(f)
+            self.assertEqual([data[str(i)]
+                              for i in range(len(data))], tokenizer.token_list)
+
+    def test_add_tokens(self):
+        # Test adding new tokens
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)
+
+        tokenizer.add_tokens(['d', 'e'])
+        self.assertIn('d', tokenizer.token_list)
+        self.assertIn('e', tokenizer.token_list)
+
+    def test_empty_token_list(self):
+        # Test initializing the tokenizer with an empty token list
+        with self.assertRaises(ValueError):
+            self.Tokenizer(tokens=[])
+
+    def test_tokenizer_str(self):
+        # Test string representation of tokenizer
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)
+        tokenizer_str = str(tokenizer)
+
+        expected_str = json.dumps(
+            dict(enumerate(tokenizer.token_list)), ensure_ascii=False, indent=4)
+        self.assertEqual(tokenizer_str, expected_str)
+
+    def test_tokenizer_len(self):
+        # Test length of tokenizer
+        tokens = ['a', 'b', 'c']
+        tokenizer = self.Tokenizer(tokens=tokens)
+        self.assertEqual(len(tokenizer), len(tokenizer.token_list))

 if __name__ == '__main__':