Skip to content

Commit

Permalink
Add docstring for conversion function
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 23, 2024
1 parent 50910cb commit 8c159e1
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions src/model/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,9 @@ def load_model_from_directory(directory: str,
return tf.keras.models.load_model(model_file,
custom_objects=custom_objects,
compile=compile)
except (TypeError, ValueError) as e:
logging.error("Error loading model: %s", e)
logging.info("Attempting to convert the model to the new Keras "
"format.")
except (TypeError, ValueError):
logging.error("Error loading model. Attempting to convert the "
"model to the new format.")

# Convert the old model to the new format
model = _convert_old_model_to_new(model_file, custom_objects,
Expand All @@ -207,6 +206,24 @@ def load_model_from_directory(directory: str,
def _convert_old_model_to_new(model_file: str,
custom_objects: dict,
compile: bool = True) -> tf.keras.Model:
"""
Converts an old v2 Keras model to the new v3 format.
Parameters
----------
model_file : str
The path to the .keras file containing the old model.
custom_objects : dict
Custom objects required for model loading.
compile : bool, optional
Whether to compile the model after loading, by default True.
Returns
-------
tf.keras.Model
The converted Keras model.
"""

# Temporary directory to extract the .keras file contents
temp_dir = "/tmp/keras_model_extraction"

Expand Down

0 comments on commit 8c159e1

Please sign in to comment.