diff --git a/src/speckcn2/io.py b/src/speckcn2/io.py
index 573d7bc..f71ce42 100644
--- a/src/speckcn2/io.py
+++ b/src/speckcn2/io.py
@@ -8,6 +8,7 @@ from __future__ import annotations
 
 import os
+import sys
 
 import torch
 import yaml
@@ -35,7 +36,9 @@ def load_config(config_file_path: str) -> dict:
     return config
 
 
-def save(model: torch.nn.Module, datadirectory: str) -> None:
+def save(model: torch.nn.Module,
+         datadirectory: str,
+         early_stop: bool = False) -> None:
     """Save the model state and the model itself to a specified directory.
 
     This function saves the model's state dictionary and other relevant information
@@ -47,6 +50,8 @@ def save(model: torch.nn.Module, datadirectory: str) -> None:
         The model to save
     datadirectory : str
         The directory where the data is stored
+    early_stop : bool
+        If True, the state is saved as the early-stop checkpoint
     """
     model_state = {
         'epoch': model.epoch,
@@ -56,13 +61,18 @@ def save(model: torch.nn.Module, datadirectory: str) -> None:
         'model_state_dict': model.state_dict(),
     }
 
-    torch.save(
-        model_state,
-        f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}.pth'
-    )
+    if early_stop:
+        savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}_earlystop.pth'
+    else:
+        savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}.pth'
+
+    torch.save(model_state, savename)
 
 
-def load(model: torch.nn.Module, datadirectory: str, epoch: int) -> None:
+def load(model: torch.nn.Module,
+         datadirectory: str,
+         epoch: int,
+         early_stop: bool = False) -> None:
     """Load the model state and the model itself from a specified directory
     and epoch.
 
@@ -77,9 +87,17 @@ def load(model: torch.nn.Module, datadirectory: str, epoch: int) -> None:
         The directory where the data is stored
     epoch : int
         The epoch of the model
+    early_stop : bool
+        If True, load the checkpoint saved when early stopping was triggered
     """
-    model_state = torch.load(
-        f'{datadirectory}/{model.name}_states/{model.name}_{epoch}.pth')
+    if early_stop:
+        model_state = torch.load(
+            f'{datadirectory}/{model.name}_states/{model.name}_{epoch}_earlystop.pth'
+        )
+        model.early_stop = True
+    else:
+        model_state = torch.load(
+            f'{datadirectory}/{model.name}_states/{model.name}_{epoch}.pth')
 
     model.epoch = model_state['epoch']
     model.loss = model_state['loss']
@@ -98,6 +116,8 @@ def load_model_state(model: torch.nn.Module,
     This function checks the specified directory for the latest model state file,
     loads it, and updates the model with the loaded state. If no state is found,
     it initializes the model state.
+    If a previous run was stopped by the early stop condition, this function
+    flags the model so that training is not resumed.
 
     Parameters
     ----------
@@ -118,30 +138,49 @@ def load_model_state(model: torch.nn.Module,
     model.nparams = sum(p.numel() for p in model.parameters())
     print(f'\n--> Nparams = {model.nparams}')
 
-    ensure_directory(f'{datadirectory}/{model.name}_states')
-
-    # Check what is the last model state
-    try:
-        last_model_state = sorted([
-            int(file_name.split('.pth')[0].split('_')[-1])
-            for file_name in os.listdir(f'{datadirectory}/{model.name}_states')
-        ])[-1]
-    except Exception as e:
-        print(f'Warning: {e}')
-        last_model_state = 0
-
-    if last_model_state > 0:
-        print(
-            f'Loading model at epoch {last_model_state}, from {datadirectory}')
-        load(model, datadirectory, last_model_state)
+    fulldirname = f'{datadirectory}/{model.name}_states'
+    ensure_directory(fulldirname)
+
+    # First check if there was an early stop
+    earlystop = [
+        filename for filename in os.listdir(fulldirname)
+        if 'earlystop' in filename
+    ]
+    if len(earlystop) == 0:
+        # If there was no early stop, check what is the last model state
+        try:
+            last_model_state = sorted([
+                int(file_name.split('.pth')[0].split('_')[-1])
+                for file_name in os.listdir(fulldirname)
+            ])[-1]
+        except Exception as e:
+            print(f'Warning: {e}')
+            last_model_state = 0
+
+        if last_model_state > 0:
+            print(
+                f'Loading model at epoch {last_model_state}, from {datadirectory}'
+            )
+            load(model, datadirectory, last_model_state)
+            return model, last_model_state
+        else:
+            print('No pretrained model to load')
+
+            # Initialize some model state measures
+            model.loss = []
+            model.val_loss = []
+            model.time = []
+            model.epoch = []
+
+            return model, 0
+    elif len(earlystop) == 1:
+        filename = earlystop[0]
+        print(f'Loading the early stop state {filename}')
+        last_model_state = int(filename.split('_')[-2])
+        load(model, datadirectory, last_model_state, early_stop=True)
        return model, last_model_state
     else:
-        print('No pretrained model to load')
-
-        # Initialize some model state measures
-        model.loss = []
-        model.val_loss = []
-        model.time = []
-        model.epoch = []
-
-        return model, 0
+        print(
+            f'Error: found more than one early stop state, expected at most one: {earlystop}'
+        )
+        sys.exit(1)
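The checkpoint layout introduced above can be summarized with a small self-contained sketch (illustration only; the helper find_latest_state is hypothetical and not part of this patch). The epoch is parsed from the checkpoint filename, and an '_earlystop' suffix marks the state saved when early stopping was triggered:

import os

def find_latest_state(statedir: str) -> tuple[int, bool]:
    # Mirror of load_model_state's file scan: return (epoch, early_stop).
    files = [f for f in os.listdir(statedir) if f.endswith('.pth')]
    earlystop = [f for f in files if 'earlystop' in f]
    if earlystop:
        # e.g. 'scnn_42_earlystop.pth' -> (42, True)
        return int(earlystop[0].split('_')[-2]), True
    epochs = [int(f.split('.pth')[0].split('_')[-1]) for f in files]
    # e.g. 'scnn_42.pth' -> (42, False); (0, False) means nothing to load
    return (max(epochs), False) if epochs else (0, False)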
diff --git a/src/speckcn2/mlmodels.py b/src/speckcn2/mlmodels.py
index 5c72645..8991e78 100644
--- a/src/speckcn2/mlmodels.py
+++ b/src/speckcn2/mlmodels.py
@@ -317,3 +317,46 @@ def get_scnn(config: dict) -> tuple[nn.Module, int]:
     scnn_model.name = model_name
 
     return load_model_state(scnn_model, datadirectory)
+
+
+class EarlyStopper:
+    """Early stopping: stop the training when the validation loss no longer
+    decreases."""
+
+    def __init__(self, patience: int = 1, min_delta: float = 0):
+        """Initializes the EarlyStopper.
+
+        Parameters
+        ----------
+        patience : int
+            Number of tolerated non-improving validation steps before stopping.
+        min_delta : float
+            Relative tolerance: a loss is non-improving if it exceeds the best loss by more than this fraction.
+        """
+
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.min_validation_loss = float('inf')
+
+    def early_stop(self, validation_loss: float) -> bool:
+        """Checks if the early stop condition is met at the current step.
+
+        Parameters
+        ----------
+        validation_loss : float
+            Current value of the validation loss
+
+        Returns
+        -------
+        bool
+            True if the training has met the stop condition.
+        """
+        if validation_loss < self.min_validation_loss:
+            self.min_validation_loss = validation_loss
+            self.counter = 0
+        elif validation_loss > self.min_validation_loss * (1 + self.min_delta):
+            self.counter += 1
+            if self.counter >= self.patience:
+                return True
+        return False
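To make the stopping rule concrete, here is a minimal usage sketch of EarlyStopper on synthetic losses (example values only, not part of the patch). With patience=2 and min_delta=0.1, a validation loss counts against the patience budget only when it exceeds the best loss seen so far by more than 10%, and any new best resets the counter:

from speckcn2.mlmodels import EarlyStopper

stopper = EarlyStopper(patience=2, min_delta=0.1)
for step, val_loss in enumerate([1.0, 0.8, 0.95, 0.93, 0.90]):
    if stopper.early_stop(val_loss):
        # Triggers at step 3: 0.95 and 0.93 both exceed 0.8 * 1.1 = 0.88
        print(f'stopping at step {step}')
        break

Losses that fall between the best value and best * (1 + min_delta) neither reset nor increment the counter, so mild oscillations around the minimum do not exhaust the patience budget.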
+ """ + if validation_loss < self.min_validation_loss: + self.min_validation_loss = validation_loss + self.counter = 0 + elif validation_loss > self.min_validation_loss * (1 + self.min_delta): + self.counter += 1 + if self.counter >= self.patience: + return True + return False diff --git a/src/speckcn2/mlops.py b/src/speckcn2/mlops.py index a59bba4..e5cd10c 100644 --- a/src/speckcn2/mlops.py +++ b/src/speckcn2/mlops.py @@ -8,7 +8,7 @@ from speckcn2.io import save from speckcn2.loss import ComposableLoss -from speckcn2.mlmodels import EnsembleModel +from speckcn2.mlmodels import EarlyStopper, EnsembleModel from speckcn2.plots import score_plot from speckcn2.preprocess import Normalizer from speckcn2.utils import ensure_directory @@ -51,10 +51,25 @@ def train(model: nn.Module, last_model_state: int, conf: dict, train_set: list, datadirectory = conf['speckle']['datadirectory'] batch_size = conf['hyppar']['batch_size'] + if getattr(model, 'early_stop', False): + print( + 'Warning: Training reached early stop in a previous training instance' + ) + return model, 0 + + print(f'Training the model from epoch {last_model_state} to {final_epoch}') + # Setup the EnsembleModel wrapper ensemble = EnsembleModel(conf, device) - print(f'Training the model from epoch {last_model_state} to {final_epoch}') + # Early stopper + early_stopping = conf['hyppar'].get('early_stopping', -1) + if early_stopping > 0: + print('Using early stopping') + min_delta = conf['hyppar'].get('early_stop_delta', 0.1) + early_stopper = EarlyStopper(patience=early_stopping, + min_delta=min_delta) + average_loss = 0.0 model.train() for epoch in range(last_model_state, final_epoch): @@ -110,6 +125,10 @@ def train(model: nn.Module, last_model_state: int, conf: dict, train_set: list, f'Test-Loss: {val_loss:.5f}') print(message, flush=True) + if early_stopper.early_stop(val_loss): + print('Early stopping triggered') + save(model, datadirectory, early_stop=True) + if (epoch + 1) % save_every == 0 or epoch == final_epoch - 1: # Save the model state save(model, datadirectory)