Skip to content

Commit

Permalink
add early stop condition
Browse files Browse the repository at this point in the history
  • Loading branch information
SCiarella committed Oct 23, 2024
1 parent e1bb010 commit 71827db
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 35 deletions.
105 changes: 72 additions & 33 deletions src/speckcn2/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import os
import sys

import torch
import yaml
Expand Down Expand Up @@ -35,7 +36,9 @@ def load_config(config_file_path: str) -> dict:
return config


def save(model: torch.nn.Module, datadirectory: str) -> None:
def save(model: torch.nn.Module,
datadirectory: str,
early_stop: bool = False) -> None:
"""Save the model state and the model itself to a specified directory.
This function saves the model's state dictionary and other relevant information
Expand All @@ -47,6 +50,8 @@ def save(model: torch.nn.Module, datadirectory: str) -> None:
The model to save
datadirectory : str
The directory where the data is stored
early_stop: bool
If True, the model corresponds to the moment when early stop was triggered
"""
model_state = {
'epoch': model.epoch,
Expand All @@ -56,13 +61,18 @@ def save(model: torch.nn.Module, datadirectory: str) -> None:
'model_state_dict': model.state_dict(),
}

torch.save(
model_state,
f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}.pth'
)
if not early_stop:
savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}.pth'
else:
savename = f'{datadirectory}/{model.name}_states/{model.name}_{model.epoch[-1]}_earlystop.pth'

torch.save(model_state, savename)


def load(model: torch.nn.Module, datadirectory: str, epoch: int) -> None:
def load(model: torch.nn.Module,
datadirectory: str,
epoch: int,
early_stop: bool = False) -> None:
"""Load the model state and the model itself from a specified directory and
epoch.
Expand All @@ -77,9 +87,17 @@ def load(model: torch.nn.Module, datadirectory: str, epoch: int) -> None:
The directory where the data is stored
epoch : int
The epoch of the model
early_stop: bool
If True, the last state reached the early stop condition
"""
model_state = torch.load(
f'{datadirectory}/{model.name}_states/{model.name}_{epoch}.pth')
if early_stop:
model_state = torch.load(
f'{datadirectory}/{model.name}_states/{model.name}_{epoch}_earlystop.pth'
)
model.early_stop = True
else:
model_state = torch.load(
f'{datadirectory}/{model.name}_states/{model.name}_{epoch}.pth')

model.epoch = model_state['epoch']
model.loss = model_state['loss']
Expand All @@ -98,6 +116,8 @@ def load_model_state(model: torch.nn.Module,
This function checks the specified directory for the latest model state file,
loads it, and updates the model with the loaded state. If no state is found,
it initializes the model state.
If the training was stopped after meeting an early stop condition, this function
signals that the training should not be continued.
Parameters
----------
Expand All @@ -118,30 +138,49 @@ def load_model_state(model: torch.nn.Module,
model.nparams = sum(p.numel() for p in model.parameters())
print(f'\n--> Nparams = {model.nparams}')

ensure_directory(f'{datadirectory}/{model.name}_states')

# Check what is the last model state
try:
last_model_state = sorted([
int(file_name.split('.pth')[0].split('_')[-1])
for file_name in os.listdir(f'{datadirectory}/{model.name}_states')
])[-1]
except Exception as e:
print(f'Warning: {e}')
last_model_state = 0

if last_model_state > 0:
print(
f'Loading model at epoch {last_model_state}, from {datadirectory}')
load(model, datadirectory, last_model_state)
fulldirname = f'{datadirectory}/{model.name}_states'
ensure_directory(fulldirname)

# First check if there was an early stop
earlystop = [
filename for filename in os.listdir(fulldirname)
if 'earlystop' in filename
]
if len(earlystop) == 0:
# If there was no early stop, check what is the last model state
try:
last_model_state = sorted([
int(file_name.split('.pth')[0].split('_')[-1])
for file_name in os.listdir(fulldirname)
])[-1]
except Exception as e:
print(f'Warning: {e}')
last_model_state = 0

if last_model_state > 0:
print(
f'Loading model at epoch {last_model_state}, from {datadirectory}'
)
load(model, datadirectory, last_model_state)
return model, last_model_state
else:
print('No pretrained model to load')

# Initialize some model state measures
model.loss = []
model.val_loss = []
model.time = []
model.epoch = []

return model, 0
elif len(earlystop) == 1:
filename = earlystop[0]
print(f'Loading the early stop state {filename}')
last_model_state = int(filename.split('_')[-2])
load(model, datadirectory, last_model_state, early_stop=True)
return model, last_model_state
else:
print('No pretrained model to load')

# Initialize some model state measures
model.loss = []
model.val_loss = []
model.time = []
model.epoch = []

return model, 0
print(
f'Error: more than one early stop state found. This is not correct. This is the list: {earlystop}'
)
sys.exit(0)
43 changes: 43 additions & 0 deletions src/speckcn2/mlmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,46 @@ def get_scnn(config: dict) -> tuple[nn.Module, int]:
scnn_model.name = model_name

return load_model_state(scnn_model, datadirectory)


class EarlyStopper:
    """Stop the training when the validation loss does not decrease anymore.

    The stopper remembers the lowest validation loss seen so far. Every
    call to :meth:`early_stop` that reports a loss worse than that minimum
    by more than the relative tolerance ``min_delta`` increases a counter;
    a new minimum resets the counter. Once the counter reaches
    ``patience`` the stop condition is met.
    """

    def __init__(self, patience: int = 1, min_delta: float = 0):
        """Initializes the EarlyStopper.

        Parameters
        ----------
        patience: int
            Number of epochs of tolerance before stopping.
        min_delta: float
            Relative tolerance, expressed as a fraction of the magnitude of
            the best loss (e.g. 0.1 means 10%), within which a loss that is
            not a new minimum is still considered acceptable.
        """

        self.patience = patience
        self.min_delta = min_delta
        # Number of out-of-tolerance epochs since the last new minimum.
        self.counter = 0
        # Best (lowest) validation loss observed so far.
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss: float) -> bool:
        """Computes if the early stop condition is met at the current step.

        Parameters
        ----------
        validation_loss: float
            Current value of the validation loss

        Returns
        -------
        bool
            It returns True if the training has met the stop condition.
        """
        best = self.min_validation_loss
        if validation_loss < best:
            # New minimum: remember it and restart the patience countdown.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        # The tolerance band is measured on the magnitude of the best loss,
        # so the threshold always lies above the minimum, even when the
        # loss is negative. For positive losses this is identical to
        # best * (1 + min_delta).
        threshold = best + abs(best) * self.min_delta
        if validation_loss > threshold:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        # Losses inside the tolerance band neither count nor reset.
        return False
23 changes: 21 additions & 2 deletions src/speckcn2/mlops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from speckcn2.io import save
from speckcn2.loss import ComposableLoss
from speckcn2.mlmodels import EnsembleModel
from speckcn2.mlmodels import EarlyStopper, EnsembleModel
from speckcn2.plots import score_plot
from speckcn2.preprocess import Normalizer
from speckcn2.utils import ensure_directory
Expand Down Expand Up @@ -51,10 +51,25 @@ def train(model: nn.Module, last_model_state: int, conf: dict, train_set: list,
datadirectory = conf['speckle']['datadirectory']
batch_size = conf['hyppar']['batch_size']

if getattr(model, 'early_stop', False):
print(
'Warning: Training reached early stop in a previous training instance'
)
return model, 0

print(f'Training the model from epoch {last_model_state} to {final_epoch}')

# Setup the EnsembleModel wrapper
ensemble = EnsembleModel(conf, device)

print(f'Training the model from epoch {last_model_state} to {final_epoch}')
# Early stopper
early_stopping = conf['hyppar'].get('early_stopping', -1)
if early_stopping > 0:
print('Using early stopping')
min_delta = conf['hyppar'].get('early_stop_delta', 0.1)
early_stopper = EarlyStopper(patience=early_stopping,
min_delta=min_delta)

average_loss = 0.0
model.train()
for epoch in range(last_model_state, final_epoch):
Expand Down Expand Up @@ -110,6 +125,10 @@ def train(model: nn.Module, last_model_state: int, conf: dict, train_set: list,
f'Test-Loss: {val_loss:.5f}')
print(message, flush=True)

if early_stopper.early_stop(val_loss):
print('Early stopping triggered')
save(model, datadirectory, early_stop=True)

if (epoch + 1) % save_every == 0 or epoch == final_epoch - 1:
# Save the model state
save(model, datadirectory)
Expand Down

0 comments on commit 71827db

Please sign in to comment.