
Commit

Initial tokenizer.json version
TimKoornstra committed Aug 27, 2024
1 parent febcb76 commit eec06d4
Showing 7 changed files with 173 additions and 122 deletions.
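Note: the diff below swaps the bare character list (charlist) for a Tokenizer object that is serialized to tokenizer.json. The following is a minimal sketch of the interface the changed call sites rely on, not the actual implementation in src/utils/text.py; the JSON field name and any member not exercised in the diff are assumptions.

import json
from typing import List, Optional


class Tokenizer:
    """Illustrative stand-in for utils.text.Tokenizer (assumed interface)."""

    def __init__(self, chars: Optional[List[str]] = None,
                 json_path: Optional[str] = None):
        # Either built from a sorted character list (DataManager) or loaded
        # from a previously saved tokenizer.json (main.py).
        if json_path is not None:
            with open(json_path, "r", encoding="utf-8") as f:
                chars = json.load(f)["characters"]  # field name is a guess
        self.charlist = list(chars or [])

    def __len__(self) -> int:
        # customize_model sizes the final layer from len(tokenizer).
        return len(self.charlist)

    def __str__(self) -> str:
        # main.py logs the tokenizer with %s.
        return "".join(self.charlist)

    def save_to_json(self, path: str) -> None:
        # main.py and LoghiCustomCallback persist the tokenizer with this.
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"characters": self.charlist}, f, ensure_ascii=False)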
5 changes: 3 additions & 2 deletions src/data/data_handling.py
@@ -11,10 +11,11 @@
# > Local dependencies
from data.manager import DataManager
from setup.config import Config
from utils.text import Tokenizer


def initialize_data_manager(config: Config,
charlist: List[str],
tokenizer: Tokenizer,
model: tf.keras.Model,
augment_model: tf.keras.Sequential) -> DataManager:
"""
@@ -46,7 +47,7 @@ def initialize_data_manager(config: Config,

return DataManager(
img_size=img_size,
charlist=charlist,
tokenizer=tokenizer,
augment_model=augment_model,
config=config
)
75 changes: 34 additions & 41 deletions src/data/manager.py
@@ -29,28 +29,22 @@ class DataManager:
The model used for data augmentation.
config : Config
The configuration dictionary containing various settings.
charlist : Optional[List[str]], optional
The list of characters to use for tokenization, by default None.
tokenizer : Optional[Tokenizer], optional
The tokenizer to use for encoding text data, by default None.
"""

def __init__(self,
img_size: Tuple[int, int, int],
augment_model: tf.keras.Sequential,
config: Config,
charlist: Optional[List[str]] = None,
tokenizer: Optional[Tokenizer] = None,
):

self.augment_model = augment_model
self.height = img_size[0]
self.channels = img_size[2]
self.config = config

# Determine the character list
if charlist and not config['replace_final_layer']:
logging.info('Using injected charlist')
self.charlist = sorted(list(charlist))
else:
self.charlist = []
self.tokenizer = tokenizer

# Determine the evaluation list
self.evaluation_list = None
@@ -60,8 +54,7 @@ def __init__(self,
# Process the raw data and create file names, labels, sample weights,
# and tokenizer
logging.info("Processing raw data...")
file_names, labels, sample_weights, self.tokenizer \
= self._process_raw_data()
file_names, labels, sample_weights, self.tokenizer = self._process_raw_data()

self.raw_data = {split: (file_names[split], labels[split],
sample_weights[split])
@@ -91,11 +84,11 @@ def _process_raw_data(self) -> Tuple[Dict[str, List[str]],
weights, and the tokenizer.
"""

# Initialize character set and data partitions with corresponding
# labels
# Initialize data partitions with corresponding labels
file_names_dict = defaultdict(list)
labels_dict = defaultdict(list)
sample_weights_dict = defaultdict(list)
characters = set()

for partition in ('train', 'evaluation', 'validation',
'test', 'inference'):
@@ -107,26 +100,26 @@ def _process_raw_data(self) -> Tuple[Dict[str, List[str]],
file_names, labels, sample_weights = self._create_data(
partition_name=partition,
text_file=partition_text_file,
characters=characters
)
if len(file_names) == 0:
raise ValueError("No data found for the specified "
f"{partition} list(s). Have you verified "
"the data paths?")

if partition == "train" and not self.charlist:
raise ValueError("Character list is empty after "
"creating training data. Did you "
"forget to provide a character list?")

# Fill the dictionary with the data
file_names_dict[partition] = file_names
labels_dict[partition] = labels
sample_weights_dict[partition] = sample_weights

# Initialize the tokenizer
tokenizer = Tokenizer(self.charlist, self.config['use_mask'])
# Initialize the tokenizer if not provided
if self.tokenizer is None:
if not characters:
raise ValueError(
"Character list is empty after creating training data.")
self.tokenizer = Tokenizer(sorted(characters))

return file_names_dict, labels_dict, sample_weights_dict, tokenizer
return file_names_dict, labels_dict, sample_weights_dict, self.tokenizer

def _fill_datasets_dict(self,
partitions: Dict[str, List[str]],
@@ -135,7 +128,7 @@ def _fill_datasets_dict(self,
-> Dict[str, tf.data.Dataset]:
"""
Initialize data generators for different dataset partitions and
update character set and tokenizer based on the dataset.
update tokenizer based on the dataset.
Parameters
----------
@@ -177,7 +170,8 @@ def _fill_datasets_dict(self,

def _create_data(self,
partition_name: str,
text_file: str) -> Tuple[List[str], List[str], List[str]]:
text_file: str,
characters: Set[str]) -> Tuple[List[str], List[str], List[str]]:
"""
Create data for a specific partition from a text file.
@@ -187,6 +181,8 @@ def _create_data(self,
The name of the partition.
text_file : str
The path to the text file containing the data.
characters : Set[str]
The set of characters to be used for the tokenizer.
Returns
-------
@@ -201,9 +197,6 @@ def _create_data(self,
faulty_lines = {}
flaw_counts = {}

# Define the character set
characters = set(self.charlist)

# Process each file in the data files list
for file_path in text_file.split():
if not os.path.exists(file_path):
@@ -229,16 +222,14 @@ def _create_data(self,
logging.warning("Faulty lines for %s:", partition_name)
for line, flaw in faulty_lines.items():
logging.warning("%s: %s", line.strip(), flaw)
logging.warning("Flaw counts for %s:", partition_name)
for flaw, count in flaw_counts.items():
logging.warning("%s: %d", flaw, count)
logging.warning("Flaw counts for %s:", partition_name)
for flaw, count in flaw_counts.items():
logging.warning("%s: %d", flaw, count)

# Update the charlist if it has changed
if not self.charlist:
self.charlist = sorted(list(characters))
logging.debug("Updated charlist: %s", self.charlist)
# Update the character set with any new characters found
characters.update(set("".join(labels)))

logging.info("Created data for %s with %s samples",
logging.info("Created data for %s with %d samples",
partition_name, len(partitions))

return partitions, labels, sample_weights
@@ -247,8 +238,7 @@ def _process_line(self,
line: str,
partition_name: str,
characters: Set[str]) \
-> Tuple[Optional[Tuple[str, str, float]],
Optional[str]]:
-> Tuple[Optional[Tuple[str, str, float]], Optional[str]]:
"""
Process a single line from the data file.
@@ -296,10 +286,13 @@ def _process_line(self,
# Normalize the ground truth if a normalization file is provided and
# the partition is either 'train' or 'evaluation'
if self.config['normalization_file'] and \
partition_name in ('train' or 'evaluation'):
partition_name in ('train', 'evaluation'):
ground_truth = normalize_text(ground_truth,
self.config['normalization_file'])

# Add characters to the set for tokenizer
characters.update(ground_truth)

# Check for unsupported characters in the ground truth
if not self._is_valid_ground_truth(ground_truth, partition_name,
characters):
@@ -372,7 +365,7 @@ def _is_valid_ground_truth(self,
if partition_name in ('validation', 'inference', 'test'):
return True

if partition_name == 'train' and not self.charlist:
if partition_name == 'train':
characters.update(unsupported_characters)
return True
return False
@@ -463,7 +456,7 @@ def _create_dataset(self,

dataset = (dataset
.map(data_loader.load_images,
num_parallel_calls=AUTOTUNE,
num_parallel_calls=tf.data.experimental.AUTOTUNE,
deterministic=not is_training)
.padded_batch(self.config["batch_size"],
padded_shapes=(
@@ -472,7 +465,7 @@ def _create_dataset(self,
tf.constant(-10, dtype=tf.float32),
tf.constant(0, dtype=tf.int64),
tf.constant(1.0, dtype=tf.float32)))
.prefetch(AUTOTUNE))\
.prefetch(tf.data.experimental.AUTOTUNE))\
.apply(tf.data.experimental.assert_cardinality(num_batches))

# Distribute the dataset if needed
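Note: the net effect of the manager.py changes is that DataManager no longer owns a charlist; _process_line adds every ground-truth character to a shared set, and _process_raw_data builds a Tokenizer from that set only when none was injected. A condensed sketch of that flow, with hypothetical helper names and the Tokenizer from the sketch above:

from typing import Dict, List, Set


def collect_characters(labels_per_partition: Dict[str, List[str]]) -> Set[str]:
    # Mirrors _process_line: every ground-truth string feeds the set.
    characters: Set[str] = set()
    for labels in labels_per_partition.values():
        for ground_truth in labels:
            characters.update(ground_truth)
    return characters


def build_tokenizer_if_missing(tokenizer, labels_per_partition):
    # Mirrors _process_raw_data: only build a tokenizer when none was injected.
    if tokenizer is None:
        characters = collect_characters(labels_per_partition)
        if not characters:
            raise ValueError(
                "Character list is empty after creating training data.")
        tokenizer = Tokenizer(sorted(characters))
    return tokenizer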
31 changes: 15 additions & 16 deletions src/main.py
@@ -48,13 +48,11 @@ def main():
if config["output"]:
os.makedirs(config["output"], exist_ok=True)

# Get the initial character list
if os.path.isdir(config["model"]) or config["charlist"]:
charlist, removed_padding = load_initial_charlist(
config["charlist"], config["model"],
config["output"], config["replace_final_layer"])
# Get the initial tokenizer
if os.path.isdir(config["model"]) or config["tokenizer"]:
tokenizer = Tokenizer(json_path=config["tokenizer"])
else:
charlist = []
tokenizer = None
removed_padding = False

# Set the custom objects
@@ -72,22 +70,22 @@ def main():
model.input_shape[-1])

# Initialize the DataManager
data_manager = initialize_data_manager(config, charlist, model,
data_manager = initialize_data_manager(config, tokenizer, model,
augmentation_model)

# Replace the charlist with the one from the data manager
charlist = data_manager.charlist
logging.info("Using charlist: %s", charlist)
logging.info("Charlist length: %s", len(charlist))
# Replace the tokenizer with the one from the data manager
tokenizer = data_manager.tokenizer
logging.info("Using tokenizer:\n%s", tokenizer)
logging.info("Tokenizer size: %s tokens", len(tokenizer))

# TODO: Continue from here
# Additional model customization such as freezing layers, replacing
# layers, or adjusting for float32
model = customize_model(model, config, charlist)
model = customize_model(model, config, tokenizer)

# Save the charlist
verify_charlist_length(charlist, model, config["use_mask"],
removed_padding)
save_charlist(charlist, config["output"])
tokenizer.save_to_json(os.path.join(config["output"],
"tokenizer.json"))

# Create the learning rate schedule
lr_schedule = create_learning_rate_schedule(
@@ -134,7 +132,8 @@ def main():
data_manager.datasets["evaluation"],
data_manager)
# Plot the training history
plot_training_history(history=history, output_path=config["output"],
plot_training_history(history=history,
output_path=config["output"],
plot_validation=bool(config["validation_list"]))

timestamps['Training'] = time.time() - tick
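Note: main.py now follows a load-or-derive-then-persist lifecycle for the tokenizer. The sketch below compresses that flow into one function; resolve_tokenizer and data_manager_factory are hypothetical names, while the config keys and Tokenizer calls come from the diff:

import os


def resolve_tokenizer(config, data_manager_factory):
    # Reuse an existing tokenizer.json when a saved model or tokenizer path
    # is given; otherwise let the DataManager derive one from the data.
    if os.path.isdir(config["model"]) or config["tokenizer"]:
        tokenizer = Tokenizer(json_path=config["tokenizer"])
    else:
        tokenizer = None

    data_manager = data_manager_factory(tokenizer)
    tokenizer = data_manager.tokenizer  # always adopt the manager's tokenizer

    # Persist next to the rest of the run output.
    tokenizer.save_to_json(os.path.join(config["output"], "tokenizer.json"))
    return tokenizer, data_manager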
17 changes: 8 additions & 9 deletions src/model/custom_callback.py
@@ -10,6 +10,7 @@
import tensorflow as tf

# > Local dependencies
from utils.text import Tokenizer
from setup.config import Config


@@ -26,8 +27,8 @@ class LoghiCustomCallback(tf.keras.callbacks.Callback):
If True, saves the model at the end of each epoch.
output : str
Directory path to save the model and additional files.
charlist : list of str, optional
List of characters used in the model, saved alongside the model.
tokenizer : Tokenizer
Tokenizer object to be saved with the model.
config : object, optional
Configuration object to be saved with the model.
normalization_file : str, optional
@@ -45,7 +46,7 @@ class LoghiCustomCallback(tf.keras.callbacks.Callback):
"""

def __init__(self, save_best: bool = True, save_checkpoint: bool = True,
output: str = "output", charlist: str = None,
output: str = "output", tokenizer: Tokenizer = None,
config: Config = None, normalization_file: str = None,
logging_level: str = "info"):
"""
Expand All @@ -56,7 +57,7 @@ def __init__(self, save_best: bool = True, save_checkpoint: bool = True,
self.save_best = save_best
self.save_checkpoint = save_checkpoint
self.output = output
self.charlist = charlist
self.tokenizer = tokenizer
self.config = config
self.normalization_file = normalization_file
self.logging_level = logging_level
@@ -116,11 +117,9 @@ def _save_model(self, subdir: str):
unfrozen_model.save(model_path)

# Save additional files
if self.charlist:
with open(os.path.join(outputdir, "charlist.txt"),
"w", encoding="utf-8") \
as chars_file:
chars_file.write("".join(self.charlist))
if self.tokenizer:
self.tokenizer.save_to_json(os.path.join(outputdir,
"tokenizer.json"))
if self.config:
self.config.save(os.path.join(outputdir, "config.json"))
if self.normalization_file:
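Note: LoghiCustomCallback now writes tokenizer.json next to each saved model instead of charlist.txt. A small wiring sketch; every value below is a placeholder, only the keyword names come from the diff:

import tensorflow as tf


def train_with_checkpointing(model: tf.keras.Model,
                             train_dataset: tf.data.Dataset,
                             tokenizer, config, output_dir: str):
    callback = LoghiCustomCallback(
        save_best=True,
        save_checkpoint=True,
        output=output_dir,
        tokenizer=tokenizer,   # saved as <output>/<subdir>/tokenizer.json
        config=config,         # saved as config.json
    )
    return model.fit(train_dataset, epochs=10, callbacks=[callback])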
19 changes: 10 additions & 9 deletions src/model/management.py
@@ -13,12 +13,12 @@
import tensorflow as tf

# > Local dependencies
from model.custom_model import build_custom_model
from model.conversion import convert_model
from model.replacing import replace_final_layer, replace_recurrent_layer
from model.vgsl_model_generator import VGSLModelGenerator
from setup.config import Config

from model.custom_model import build_custom_model
from utils.text import Tokenizer


def adjust_model_for_float32(model: tf.keras.Model) -> tf.keras.Model:
@@ -63,7 +63,7 @@ def adjust_model_for_float32(model: tf.keras.Model) -> tf.keras.Model:

def customize_model(model: tf.keras.Model,
config: Config,
charlist: List[str]) -> tf.keras.Model:
tokenizer: Tokenizer) -> tf.keras.Model:
"""
Customizes a Keras model based on various arguments including layer
replacement and freezing options.
@@ -74,8 +74,8 @@ def customize_model(model: tf.keras.Model,
The model to be customized.
config : Config
A set of arguments controlling how the model should be customized.
charlist : List[str]
A list of characters used for model customization.
tokenizer : Tokenizer
The tokenizer object used for tokenization.
Returns
-------
@@ -88,16 +88,17 @@ def customize_model(model: tf.keras.Model,
logging.info("Replacing recurrent layer with %s",
config["replace_recurrent_layer"])
model = replace_recurrent_layer(model,
len(charlist),
len(tokenizer),
config["replace_recurrent_layer"],
use_mask=config["use_mask"])

# Replace the final layer if specified
if config["replace_final_layer"] or not os.path.isdir(config["model"]):
new_classes = len(charlist) + \
2 if config["use_mask"] else len(charlist) + 1
new_classes = len(tokenizer) + \
2 if config["use_mask"] else len(
tokenizer) + 1 # TODO : replace use_mask
logging.info("Replacing final layer with %s classes", new_classes)
model = replace_final_layer(model, len(charlist), model.name,
model = replace_final_layer(model, len(tokenizer), model.name,
use_mask=config["use_mask"])

# Freeze layers if specified
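Note: in management.py the output layer is now sized from len(tokenizer). Presumably the extra classes are the CTC blank plus a mask token when use_mask is set, and just the blank otherwise; the diff does not spell this out. A worked example of the arithmetic, including the operator precedence of the conditional expression:

def output_classes(tokenizer_size: int, use_mask: bool) -> int:
    # Parses as (size + 2) if use_mask else (size + 1), matching the diff.
    return tokenizer_size + 2 if use_mask else tokenizer_size + 1


assert output_classes(80, use_mask=True) == 82
assert output_classes(80, use_mask=False) == 81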
2 of the 7 changed files are not shown above.