25
25
from tqdm import tqdm
26
26
27
27
from conf import project as project_conf
28
- from utils import blink_pbar , to_cuda , update_pbar_str
28
+ from utils import blink_pbar , colorize , to_cuda , update_pbar_str
29
29
from utils .helpers import BestNModelSaver
30
30
from utils .training import visualize_model_predictions
31
31
@@ -260,9 +260,12 @@ def train(
260
260
"""
261
261
if model_ckpt_path is not None :
262
262
self ._load_checkpoint (model_ckpt_path )
263
- if project_conf .PLOT_ENABLED :
264
- self ._setup_plot ()
265
- print (f"[*] Training for { epochs } epochs" )
263
+ print (
264
+ colorize (
265
+ f"[*] Training { self ._run_name } for { epochs } epochs" ,
266
+ project_conf .ANSI_COLORS ["green" ],
267
+ )
268
+ )
266
269
self ._viz_n_samples = visualize_n_samples
267
270
train_losses : List [float ] = []
268
271
val_losses : List [float ] = []
@@ -311,12 +314,16 @@ def train(
311
314
)
312
315
313
316
@staticmethod
314
- def _setup_plot ():
317
+ def _setup_plot (run_name : str , log_scale : bool = False ):
315
318
"""Setup the plot for training and validation losses."""
316
- plt .title ("Training and validation losses " )
319
+ plt .title (f "Training curves for { run_name } " )
317
320
plt .theme ("dark" )
318
321
plt .xlabel ("Epoch" )
319
- plt .ylabel ("Loss" )
322
+ if log_scale :
323
+ plt .ylabel ("Loss (log scale)" )
324
+ plt .yscale ("log" )
325
+ else :
326
+ plt .ylabel ("Loss" )
320
327
plt .grid (True , True )
321
328
322
329
def _plot (self , epoch : int , train_losses : List [float ], val_losses : List [float ]):
@@ -329,18 +336,12 @@ def _plot(self, epoch: int, train_losses: List[float], val_losses: List[float]):
329
336
None
330
337
"""
331
338
plt .clf ()
332
- plt .theme ("dark" )
333
- plt .xlabel ("Epoch" )
334
339
if project_conf .LOG_SCALE_PLOT :
335
340
if any (loss_val <= 0 for loss_val in train_losses + val_losses ):
336
341
raise ValueError (
337
342
"Cannot plot on a log scale if there are non-positive losses."
338
343
)
339
- plt .ylabel ("Loss (log scale)" )
340
- plt .yscale ("log" )
341
- else :
342
- plt .ylabel ("Loss" )
343
- plt .grid (True , True )
344
+ self ._setup_plot (self ._run_name , log_scale = project_conf .LOG_SCALE_PLOT )
344
345
plt .plot (
345
346
list (range (self ._starting_epoch , epoch + 1 )),
346
347
train_losses ,
0 commit comments