Skip to content

Commit

Permalink
Fix training + loading legacy models
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 23, 2024
1 parent 015f460 commit e5f5f88
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 9 deletions.
118 changes: 115 additions & 3 deletions src/model/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import logging
import os
from typing import Any, List, Dict, Optional
import json
import shutil
import warnings
import zipfile

# > Third-party dependencies
import tensorflow as tf
Expand Down Expand Up @@ -170,13 +174,121 @@ def load_model_from_directory(directory: str,
directory) if file.endswith(".keras")), None)

if model_file:
return tf.keras.models.load_model(model_file,
custom_objects=custom_objects,
compile=compile)
try:
return tf.keras.models.load_model(model_file,
custom_objects=custom_objects,
compile=compile)
except (TypeError, ValueError) as e:
logging.error("Error loading model: %s", e)
logging.info("Attempting to convert the model to the new Keras "
"format.")

# Convert the old model to the new format
model = _convert_old_model_to_new(model_file, custom_objects,
compile=compile)

# Save the converted model
# Rename the old model file
if not os.path.exists(model_file + ".old"):
logging.info("Renaming old model file to %s",
model_file + ".old")
old_model_file = model_file + ".old"
os.rename(model_file, old_model_file)

# Save the new model
logging.info("Saving new model to %s", model_file)
model.save(model_file)

return model

raise FileNotFoundError("No suitable model file found in the directory.")


def _convert_old_model_to_new(model_file: str,
                              custom_objects: dict,
                              compile: bool = True) -> tf.keras.Model:
    """
    Rebuild a model from a legacy ``.keras`` archive by patching its config.

    The archive is unpacked into a private temporary directory, its
    ``config.json`` is rewritten in memory, a model is instantiated from the
    rewritten config, and the weights stored in ``model.weights.h5`` are
    loaded into it. The following rewrites are applied to the config:

    - the top-level ``module`` is set to ``keras.src.models.functional``;
    - ``BatchNormalization`` layers whose ``axis`` is a list get the first
      element of that list as their ``axis``;
    - for ``Bidirectional`` layers, the wrapped layer's
      ``recurrent_initializer`` class name is set to
      ``OrthogonalInitializer`` and its ``time_major`` entry is dropped;
    - every ``Policy`` class name found under the layers is renamed to
      ``DTypePolicy``;
    - the optimizer in ``compile_config`` gets ``module`` set to
      ``keras.optimizers`` and loses ``jit_compile`` and
      ``is_legacy_optimizer``.

    Parameters
    ----------
    model_file : str
        Path to the ``.keras`` archive to convert.
    custom_objects : dict
        Custom objects forwarded to ``tf.keras.models.model_from_json``.
    compile : bool, optional
        When False, ``compile_config`` is removed from the config before the
        model is instantiated. Default is True.

    Returns
    -------
    tf.keras.Model
        The model built from the patched config, with weights loaded.

    Notes
    -----
    Weights are loaded with ``skip_mismatch=True`` and with warnings
    silenced, so weights that do not match the rebuilt model are skipped
    without any message.
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    import tempfile

    def patch_layer_config(layer_config):
        # Walks the top-level config and each entry of its 'layers' list,
        # fixing the layer settings the new Keras format rejects.
        if not isinstance(layer_config, dict):
            return

        class_name = layer_config.get('class_name')
        if class_name == 'BatchNormalization' \
                and isinstance(layer_config.get('config', {}).get('axis'),
                               list):
            layer_config['config']['axis'] \
                = layer_config['config']['axis'][0]
        elif class_name == 'Bidirectional':
            wrapped_config = layer_config['config']['layer']['config']
            initializer = wrapped_config.get('recurrent_initializer')
            # Only patch when the initializer is present as a serialized
            # dict; otherwise there is nothing to rename.
            if isinstance(initializer, dict):
                initializer['class_name'] = 'OrthogonalInitializer'
            wrapped_config.pop('time_major', None)
        elif 'layers' in layer_config:
            for sub_layer in layer_config['layers']:
                patch_layer_config(sub_layer)

    def rename_policy(obj):
        # Recursively replace the 'Policy' class name with 'DTypePolicy'
        # anywhere inside nested dicts and lists.
        if isinstance(obj, dict):
            if obj.get('class_name') == 'Policy':
                obj['class_name'] = 'DTypePolicy'
            for value in obj.values():
                rename_policy(value)
        elif isinstance(obj, list):
            for item in obj:
                rename_policy(item)

    # A unique scratch directory avoids clashes between concurrent
    # conversions and is removed even when a step below raises.
    with tempfile.TemporaryDirectory(
            prefix="keras_model_extraction_") as temp_dir:
        # Extract .keras file
        with zipfile.ZipFile(model_file, 'r') as zip_ref:
            zip_ref.extractall(temp_dir)

        # Load model architecture from config.json
        config_path = os.path.join(temp_dir, 'config.json')
        with open(config_path, 'r', encoding='utf-8') as json_file:
            model_config = json.load(json_file)

        model_config["module"] = "keras.src.models.functional"

        patch_layer_config(model_config["config"])
        rename_policy(model_config["config"]['layers'])

        if not compile:
            model_config.pop("compile_config", None)
        elif model_config.get("compile_config"):
            optimizer = model_config["compile_config"].get("optimizer")
            if isinstance(optimizer, dict):
                optimizer["module"] = "keras.optimizers"
                optimizer_config = optimizer.get("config")
                if isinstance(optimizer_config, dict):
                    optimizer_config.pop("jit_compile", None)
                    optimizer_config.pop("is_legacy_optimizer", None)

        # Build the model straight from the patched config; no need to
        # round-trip the JSON through a file on disk.
        model = tf.keras.models.model_from_json(
            json.dumps(model_config), custom_objects=custom_objects)

        # Load weights into the model while the extracted file still exists
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.load_weights(os.path.join(temp_dir, 'model.weights.h5'),
                               skip_mismatch=True)

    return model


def load_or_create_model(config: Config,
custom_objects: Dict[str, Any]) -> tf.keras.Model:
"""
Expand Down
5 changes: 2 additions & 3 deletions src/model/replacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def replace_final_layer(model: tf.keras.models.Model,
last_layer = layer.name

# Create a prediction model up to the last layer
prediction_model = keras.models.Model(
model.get_layer(name="image").input,
model.get_layer(name=last_layer).output
prediction_model = tf.keras.models.Model(
inputs=model.inputs, outputs=model.get_layer(last_layer).output
)

# Add a new dense layer with adjusted number of units based on use_mask
Expand Down
5 changes: 2 additions & 3 deletions src/modes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def train_model(model: tf.keras.Model,
epochs=config["epochs"],
callbacks=callbacks,
shuffle=True,
workers=num_workers,
max_queue_size=config["max_queue_size"],
steps_per_epoch=steps_per_epoch,
verbose=config["training_verbosity_mode"]
)
Expand Down Expand Up @@ -129,7 +127,8 @@ def plot_metric(metric, title, filename, plot_validation_metric):
plt.figure()
plt.plot(history.history[metric], label='Training ' + metric)
if plot_validation_metric:
plt.plot(history.history[f"val_{metric}"], label=f"Validation {metric}")
plt.plot(history.history[f"val_{metric}"],
label=f"Validation {metric}")
plt.title(title)
plt.xlabel("Epoch #")
plt.ylabel(metric)
Expand Down

0 comments on commit e5f5f88

Please sign in to comment.