Rename DataCreator to DataManager

TimKoornstra committed Mar 7, 2024
1 parent 49e5734 commit bc3266b
Showing 9 changed files with 90 additions and 332 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/run_tests.yml
@@ -43,10 +43,10 @@ jobs:
         env:
           TF_CPP_MIN_LOG_LEVEL: '2'
 
-      - name: Test DataCreator
+      - name: Test DataManager
         if: always()
         run: |
-          python -m unittest tests/test_datacreator.py
+          python -m unittest tests/test_datamanager.py
         env:
           TF_CPP_MIN_LOG_LEVEL: '2'
30 changes: 12 additions & 18 deletions src/data/data_handling.py
@@ -9,48 +9,42 @@
 import tensorflow as tf
 
 # > Local dependencies
-from data.creator import DataCreator
+from data.manager import DataManager
 from setup.config import Config
 
 
-def initialize_data_loader(config: Config,
-                           charlist: List[str],
-                           model: tf.keras.Model,
-                           augment_model: tf.keras.Sequential) -> DataCreator:
+def initialize_data_manager(config: Config,
+                            charlist: List[str],
+                            model: tf.keras.Model,
+                            augment_model: tf.keras.Sequential) -> DataManager:
     """
-    Initializes a data loader with specified parameters and based on the input
+    Initializes a data manager with specified parameters and based on the input
     shape of a given model.
 
     Parameters
     ----------
     config : Config
-        A Config containing various arguments to configure the data loader
+        A Config containing various arguments to configure the data manager
         (e.g., batch size, image size, lists for training, validation, etc.).
     charlist : List[str]
-        A list of characters to be used by the data loader.
+        A list of characters to be used by the data manager.
     model : tf.keras.Model
-        The Keras model, used to derive input dimensions for the data loader.
+        The Keras model, used to derive input dimensions for the data manager.
     augment_model : tf.keras.Sequential
         The Keras model used for data augmentation.
 
     Returns
     -------
-    DataLoader
-        An instance of DataLoader configured as per the provided arguments and
+    DataManager
+        An instance of DataManager configured as per the provided arguments and
         model.
-
-    Notes
-    -----
-    The DataLoader is initialized with parameters like image size, batch size,
-    and various data augmentation options. These parameters are derived from
-    both the `args` namespace and the input shape of the provided `model`.
     """
 
     model_height = model.layers[0].input_shape[0][2]
     model_channels = model.layers[0].input_shape[0][3]
     img_size = (model_height, config["width"], model_channels)
 
-    return DataCreator(
+    return DataManager(
         img_size=img_size,
         charlist=charlist,
         augment_model=augment_model,
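For orientation, a minimal usage sketch of the renamed helper (a hypothetical illustration, not part of this commit: the Config construction is elided, and the charlist, model path, and no-op augmentation model are assumed values):

import tensorflow as tf

from data.data_handling import initialize_data_manager
from setup.config import Config

config: Config = ...          # construction omitted; must provide "width" etc.
charlist = list("abcdefghijklmnopqrstuvwxyz ")        # toy character list
model = tf.keras.models.load_model("path/to/model")   # placeholder path
augment_model = tf.keras.Sequential()                 # no-op augmentation

# img_size is derived inside from the model's input shape and config["width"];
# the returned DataManager exposes .charlist, .datasets, and
# .get_train_batches(), all of which are used elsewhere in this commit.
data_manager = initialize_data_manager(config, charlist, model, augment_model)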
2 changes: 1 addition & 1 deletion src/data/creator.py → src/data/manager.py
@@ -17,7 +17,7 @@
 from utils.text import Tokenizer, normalize_text
 
 
-class DataCreator:
+class DataManager:
     """
     Class for creating and managing datasets for training, validating, etc.
33 changes: 18 additions & 15 deletions src/main.py
@@ -7,8 +7,8 @@
 
 # > Local dependencies
 # Data handling
-from data.data_handling import load_initial_charlist, initialize_data_loader, \
-    save_charlist
+from data.data_handling import save_charlist, load_initial_charlist, \
+    initialize_data_manager
 
 # Model-specific
 from data.augmentation import make_augment_model, visualize_augments
@@ -71,12 +71,12 @@ def main():
     visualize_augments(augmentation_model, config["output"],
                        model.input_shape[-1])
 
-    # Initialize the Dataloader
-    loader = initialize_data_loader(config, charlist, model,
-                                    augmentation_model)
+    # Initialize the DataManager
+    data_manager = initialize_data_manager(config, charlist, model,
+                                           augmentation_model)
 
-    # Replace the charlist with the one from the data loader
-    charlist = loader.charlist
+    # Replace the charlist with the one from the data manager
+    charlist = data_manager.charlist
 
     # Additional model customization such as freezing layers, replacing
     # layers, or adjusting for float32
@@ -92,7 +92,7 @@ def main():
         learning_rate=config["learning_rate"],
         decay_rate=config["decay_rate"],
         decay_steps=config["decay_steps"],
-        train_batches=loader.get_train_batches(),
+        train_batches=data_manager.get_train_batches(),
         do_train=config["do_train"],
         warmup_ratio=config["warmup_ratio"],
         epochs=config["epochs"],
@@ -125,8 +125,11 @@ def main():
     if config["train_list"]:
         tick = time.time()
 
-        history = train_model(model, config, loader.datasets["train"],
-                              loader.datasets["evaluation"], loader)
+        history = train_model(model,
+                              config,
+                              data_manager.datasets["train"],
+                              data_manager.datasets["evaluation"],
+                              data_manager)
 
         # Plot the training history
         plot_training_history(history, config["output"],
@@ -139,23 +142,23 @@ def main():
         logging.warning("Validation results are without special markdown tags")
 
         tick = time.time()
-        perform_validation(config, model, charlist, loader)
+        perform_validation(config, model, charlist, data_manager)
        timestamps['Validation'] = time.time() - tick
 
     # Test the model
     if config["test_list"]:
         logging.warning("Test results are without special markdown tags")
 
         tick = time.time()
-        perform_test(config, model, loader.datasets["test"],
-                     charlist, loader)
+        perform_test(config, model, data_manager.datasets["test"],
+                     charlist, data_manager)
         timestamps['Test'] = time.time() - tick
 
     # Infer with the model
     if config["inference_list"]:
         tick = time.time()
-        perform_inference(config, model, loader.datasets["inference"],
-                          charlist, loader)
+        perform_inference(config, model, data_manager.datasets["inference"],
+                          charlist, data_manager)
         timestamps['Inference'] = time.time() - tick
 
     # Log the timestamps
25 changes: 13 additions & 12 deletions src/modes/inference.py
@@ -8,17 +8,18 @@
 import tensorflow as tf
 
 # > Local dependencies
-from data.loader import DataLoader
-from data.creator import DataCreator
+from data.manager import DataManager
 from model.management import get_prediction_model
 from setup.config import Config
 from utils.decoding import decode_batch_predictions
 from utils.text import Tokenizer
 
 
-def perform_inference(config: Config, model: tf.keras.Model,
-                      inference_dataset: DataLoader, charlist: List[str],
-                      loader: DataCreator) -> None:
+def perform_inference(config: Config,
+                      model: tf.keras.Model,
+                      inference_dataset: tf.data.Dataset,
+                      charlist: List[str],
+                      data_manager: DataManager) -> None:
     """
     Performs inference on a given dataset using a specified model and writes
     the results to a file.
@@ -31,12 +32,12 @@ def perform_inference(config: Config, model: tf.keras.Model,
     masking, the batch size, and parameters for decoding predictions.
     model : tf.keras.Model
         The Keras model to be used for inference.
-    inference_dataset : DataGenerator
+    inference_dataset : tf.data.Dataset
         The dataset on which inference is to be performed.
     charlist : List[str]
         A list of characters used in the model, for decoding predictions.
-    loader : DataLoader
-        A data loader object used for retrieving additional information needed
+    data_manager : DataManager
+        A data manager object used for retrieving additional information needed
         during inference (e.g., filenames).
 
     Notes
@@ -66,10 +67,10 @@ def perform_inference(config: Config, model: tf.keras.Model,
         prediction = prediction.strip().replace('', '')
 
         # Format the filename
-        filename = loader.get_filename('inference',
-                                       (batch_no *
-                                        config["batch_size"])
-                                       + index)
+        filename = data_manager.get_filename('inference',
+                                             (batch_no *
+                                              config["batch_size"])
+                                             + index)
 
         # Write the results to the results file
         result_str = f"{filename}\t{confidence}\t{prediction}"
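A note on the get_filename call above: the index arithmetic flattens a (batch number, position-in-batch) pair into a global example index. A standalone illustration with made-up values (get_filename itself lives in DataManager and is not shown in this diff):

batch_size = 4
for batch_no in range(2):
    for index in range(batch_size):
        # batch 0 covers examples 0..3, batch 1 covers examples 4..7
        global_index = batch_no * batch_size + index
        print(batch_no, index, global_index)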
21 changes: 10 additions & 11 deletions src/modes/test.py
@@ -8,8 +8,7 @@
 import tensorflow as tf
 
 # > Local dependencies
-from data.loader import DataLoader
-from data.creator import DataCreator
+from data.manager import DataManager
 from model.management import get_prediction_model
 from setup.config import Config
 from utils.calculate import calc_95_confidence_interval, calculate_cers, \
@@ -24,7 +23,7 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
                   tokenizer: Tokenizer,
                   config: Config,
                   wbs: Optional[Any],
-                  loader: DataCreator,
+                  data_manager: DataManager,
                   chars: List[str]) -> Dict[str, int]:
     """
     Processes a batch of data by predicting, calculating Character Error Rate
@@ -45,8 +44,8 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
     wbs : Optional[Any]
         An optional Word Beam Search object for advanced decoding, if
         applicable.
-    loader : DataLoader
-        A data loader object for additional operations like normalization.
+    data_manager : DataManager
+        A data manager object for additional operations like normalization.
     chars : List[str]
         A list of characters used in the model.
 
@@ -113,9 +112,9 @@ def process_batch(batch: Tuple[tf.Tensor, tf.Tensor],
 
 def perform_test(config: Config,
                  model: tf.keras.Model,
-                 test_dataset: DataLoader,
+                 test_dataset: tf.data.Dataset,
                  charlist: List[str],
-                 dataloader: DataCreator) -> None:
+                 data_manager: DataManager) -> None:
     """
     Performs a test run on a dataset using a given model and calculates
     various metrics like Character Error Rate (CER).
@@ -127,12 +126,12 @@ def perform_test(config: Config,
     mask usage and file paths.
     model : tf.keras.Model
         The Keras model to be validated.
-    test_dataset : DataGenerator
+    test_dataset : tf.data.Dataset
         The dataset to be used for testing.
     charlist : List[str]
         A list of characters used in the model.
-    dataloader : DataLoader
-        A data loader object for additional operations like normalization and
+    data_manager : DataManager
+        A data manager object for additional operations like normalization and
         Word Beam Search setup.
 
     Notes
@@ -162,7 +161,7 @@ def perform_test(config: Config,
 
         logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
         batch_counter = process_batch((X, y_true), prediction_model, tokenizer,
-                                      config, wbs, dataloader, charlist)
+                                      config, wbs, data_manager, charlist)
 
         # Update the total counter
         for key, value in batch_counter.items():
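The aggregation loop above is cut off by the diff view. A plausible completion, assumed rather than taken from the repository (the counter keys are hypothetical examples of process_batch's Dict[str, int] return value):

from collections import defaultdict

total_counter = defaultdict(int)
batch_counter = {"cer_errors": 3, "cer_chars": 120}  # hypothetical keys
for key, value in batch_counter.items():
    total_counter[key] += value  # accumulate per-batch counts into run totals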
10 changes: 5 additions & 5 deletions src/modes/training.py
@@ -9,7 +9,7 @@
 import tensorflow as tf
 
 # > Local dependencies
-from data.creator import DataCreator
+from data.manager import DataManager
 from setup.config import Config
 from model.custom_callback import LoghiCustomCallback
 
@@ -18,7 +18,7 @@ def train_model(model: tf.keras.Model,
                 config: Config,
                 training_dataset: tf.data.Dataset,
                 validation_dataset: tf.data.Dataset,
-                loader: DataCreator,
+                data_manager: DataManager,
                 num_workers: int = 20) -> Any:
     """
     Trains a Keras model using the provided training and validation datasets,
@@ -34,8 +34,8 @@ def train_model(model: tf.keras.Model,
         The dataset to be used for training.
     validation_dataset : tf.data.Dataset
         The dataset to be used for validation.
-    loader : DataLoader
-        A DataLoader containing additional information like character list.
+    data_manager : DataManager
+        A DataManager containing additional information like character list.
     num_workers : int, default 20
         Number of workers for data loading.
 
@@ -62,7 +62,7 @@ def train_model(model: tf.keras.Model,
         LoghiCustomCallback(save_best=True,
                             save_checkpoint=config["output_checkpoints"],
                             output=config["output"],
-                            charlist=loader.charlist,
+                            charlist=data_manager.charlist,
                             config=config,
                             normalization_file=config["normalization_file"])