Commit 9f7ee18: Remove prediction_model

TimKoornstra committed Aug 23, 2024
1 parent 8c159e1

Showing 4 changed files with 2 additions and 43 deletions.
36 changes: 0 additions & 36 deletions src/model/management.py
@@ -383,39 +383,3 @@ def verify_charlist_length(charlist: List[str],
             f"Charlist length ({len(charlist)}) does not match "
             f"model output length ({expected_length}). If the charlist "
             "is correct, try setting use_mask to True.")
-
-
-def get_prediction_model(model: tf.keras.Model) -> tf.keras.Model:
-    """
-    Extracts a prediction model from a given Keras model.
-
-    Parameters
-    ----------
-    model : tf.keras.Model
-        The complete Keras model from which the prediction model is to be
-        extracted.
-
-    Returns
-    -------
-    tf.keras.Model
-        The prediction model extracted from the given model, typically up to
-        the last dense layer.
-
-    Raises
-    ------
-    ValueError
-        If no dense layer is found in the given model.
-    """
-    last_dense_layer = None
-    for layer in reversed(model.layers):
-        if layer.name.startswith('dense'):
-            last_dense_layer = layer
-            break
-
-    if last_dense_layer is None:
-        raise ValueError("No dense layer found in the model")
-
-    prediction_model = tf.keras.models.Model(
-        model.get_layer(name="image").input, last_dense_layer.output
-    )
-    return prediction_model
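
For context, here is a minimal sketch of the pattern the deleted helper implemented, next to the direct call that replaces it at the call sites below. The toy model, its layer names, and the input shape are illustrative assumptions, not code from this repository; only the "image" input name and the "dense*" naming convention come from the helper itself.

    import numpy as np
    import tensorflow as tf

    # Toy model with the structure the helper assumed: an input layer named
    # "image" and at least one layer whose name starts with "dense".
    inputs = tf.keras.Input(shape=(32, 32, 1), name="image")
    x = tf.keras.layers.Flatten()(inputs)
    logits = tf.keras.layers.Dense(64, name="dense_logits")(x)
    outputs = tf.keras.layers.Softmax(name="softmax")(logits)
    model = tf.keras.Model(inputs, outputs)

    # What get_prediction_model did: walk the layers from the end, find the
    # last "dense*" layer, and build a sub-model that stops at its output.
    last_dense = next(layer for layer in reversed(model.layers)
                      if layer.name.startswith("dense"))
    prediction_model = tf.keras.models.Model(
        model.get_layer(name="image").input, last_dense.output)

    # What the call sites do after this commit: predict on the full model.
    batch = np.zeros((2, 32, 32, 1), dtype="float32")
    predictions = model.predict_on_batch(batch)

The two paths only agree when the model's own output is already the tensor the decoding step expects, which this commit evidently assumes for the models it loads.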
4 changes: 1 addition & 3 deletions src/modes/inference.py
@@ -9,7 +9,6 @@
 
 # > Local dependencies
 from data.manager import DataManager
-from model.management import get_prediction_model
 from setup.config import Config
 from utils.decoding import decode_batch_predictions
 from utils.text import Tokenizer
@@ -50,12 +49,11 @@ def perform_inference(config: Config,
     """
 
     tokenizer = Tokenizer(charlist, config["use_mask"])
-    prediction_model = get_prediction_model(model)
 
     with open(config["results_file"], "w", encoding="utf-8") as results_file:
         for batch_no, batch in enumerate(inference_dataset):
             # Get the predictions
-            predictions = prediction_model.predict_on_batch(batch[0])
+            predictions = model.predict_on_batch(batch[0])
             y_pred = decode_batch_predictions(predictions,
                                               tokenizer,
                                               config["greedy"],
4 changes: 1 addition & 3 deletions src/modes/test.py
@@ -9,7 +9,6 @@
 
 # > Local dependencies
 from data.manager import DataManager
-from model.management import get_prediction_model
 from setup.config import Config
 from utils.calculate import calc_95_confidence_interval, calculate_cers, \
     increment_counters, calculate_edit_distances
@@ -146,7 +145,6 @@ def perform_test(config: Config,
     logging.info("Performing test...")
 
     tokenizer = Tokenizer(charlist, config["use_mask"])
-    prediction_model = get_prediction_model(model)
 
     # Setup WordBeamSearch if needed
     wbs = setup_word_beam_search(config, charlist) \
@@ -161,7 +159,7 @@
         X, y_true, _ = batch
 
         logging.info("Batch %s/%s", batch_no + 1, len(test_dataset))
-        batch_counter = process_batch((X, y_true), prediction_model, tokenizer,
+        batch_counter = process_batch((X, y_true), model, tokenizer,
                                       config, wbs, data_manager, charlist)
 
         # Update the total counter
1 change: 0 additions & 1 deletion src/modes/validation.py
@@ -11,7 +11,6 @@
 
 # > Local dependencies
 from data.manager import DataManager
-from model.management import get_prediction_model
 from setup.config import Config
 from utils.calculate import calc_95_confidence_interval, \
     calculate_edit_distances, update_statistics, increment_counters
