Skip to content

Commit

Permalink
make training plot a bit nicer
Browse files Browse the repository at this point in the history
  • Loading branch information
rvankoert committed Aug 11, 2024
1 parent 39d1a2e commit e3ef59c
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/modes/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,23 @@ def plot_training_history(history: tf.keras.callbacks.History,
other for Character Error Rate (CER).
"""

def plot_metric(metric, title, filename):
def plot_metric(metric, title, filename, plot_validation_metric):
plt.style.use("ggplot")
plt.figure()
plt.plot(history.history[metric], label=metric)
if plot_validation:
plt.plot(history.history[f"val_{metric}"], label=f"val_{metric}")
plt.plot(history.history[metric], label='Training ' + metric)
if plot_validation_metric:
plt.plot(history.history[f"val_{metric}"], label=f"Validation {metric}")
plt.title(title)
plt.xlabel("Epoch #")
plt.ylabel("Loss/CER")
plt.legend(loc="lower left")
plt.ylabel(metric)
plt.legend(loc="upper right")
plt.savefig(os.path.join(output_path, filename))

plot_metric("loss", "Training Loss", 'loss_plot.png')
plot_metric("CER_metric", "Character Error Rate (CER)", 'cer_plot.png')
plot_metric(metric="loss",
title="Training Loss",
filename='loss_plot.png',
plot_validation_metric=plot_validation)
plot_metric(metric="CER_metric",
title="Character Error Rate (CER)",
filename='cer_plot.png',
plot_validation_metric=plot_validation)

0 comments on commit e3ef59c

Please sign in to comment.