Skip to content

Commit effa80c

Browse files
committed
Refactor plotting a tiny bit
1 parent b1c6c2e commit effa80c

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/base_trainer.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from tqdm import tqdm
2626

2727
from conf import project as project_conf
28-
from utils import blink_pbar, to_cuda, update_pbar_str
28+
from utils import blink_pbar, colorize, to_cuda, update_pbar_str
2929
from utils.helpers import BestNModelSaver
3030
from utils.training import visualize_model_predictions
3131

@@ -260,9 +260,12 @@ def train(
260260
"""
261261
if model_ckpt_path is not None:
262262
self._load_checkpoint(model_ckpt_path)
263-
if project_conf.PLOT_ENABLED:
264-
self._setup_plot()
265-
print(f"[*] Training for {epochs} epochs")
263+
print(
264+
colorize(
265+
f"[*] Training {self._run_name} for {epochs} epochs",
266+
project_conf.ANSI_COLORS["green"],
267+
)
268+
)
266269
self._viz_n_samples = visualize_n_samples
267270
train_losses: List[float] = []
268271
val_losses: List[float] = []
@@ -311,12 +314,16 @@ def train(
311314
)
312315

313316
@staticmethod
314-
def _setup_plot():
317+
def _setup_plot(run_name: str, log_scale: bool = False):
315318
"""Setup the plot for training and validation losses."""
316-
plt.title("Training and validation losses")
319+
plt.title(f"Training curves for {run_name}")
317320
plt.theme("dark")
318321
plt.xlabel("Epoch")
319-
plt.ylabel("Loss")
322+
if log_scale:
323+
plt.ylabel("Loss (log scale)")
324+
plt.yscale("log")
325+
else:
326+
plt.ylabel("Loss")
320327
plt.grid(True, True)
321328

322329
def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
@@ -329,18 +336,12 @@ def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
329336
None
330337
"""
331338
plt.clf()
332-
plt.theme("dark")
333-
plt.xlabel("Epoch")
334339
if project_conf.LOG_SCALE_PLOT:
335340
if any(loss_val <= 0 for loss_val in train_losses + val_losses):
336341
raise ValueError(
337342
"Cannot plot on a log scale if there are non-positive losses."
338343
)
339-
plt.ylabel("Loss (log scale)")
340-
plt.yscale("log")
341-
else:
342-
plt.ylabel("Loss")
343-
plt.grid(True, True)
344+
self._setup_plot(self._run_name, log_scale=project_conf.LOG_SCALE_PLOT)
344345
plt.plot(
345346
list(range(self._starting_epoch, epoch + 1)),
346347
train_losses,

0 commit comments

Comments
 (0)