From 30a0fdd8d34c87fdbc9ae4dec8eff4ba367d31c3 Mon Sep 17 00:00:00 2001 From: danielgibert Date: Wed, 16 Oct 2024 16:58:27 +0200 Subject: [PATCH 1/4] Added PyTorch trainer with early stopping --- .../pytorch/early_stopping_pytorch_trainer.py | 131 ++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py diff --git a/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py new file mode 100644 index 0000000..2bc79ce --- /dev/null +++ b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py @@ -0,0 +1,131 @@ +"""PyTorch model trainers with early stopping.""" + +import torch.nn +from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader + + +class EarlyStoppingPyTorchTrainer(BasePyTorchTrainer): + """Trainer for PyTorch models.""" + + def __init__(self, optimizer: torch.optim.Optimizer, epochs: int = 5, + loss: torch.nn.Module = None, scheduler: _LRScheduler = None) -> None: + """ + Create PyTorch trainer. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + Optimizer to use for training the model. + epochs : int, optional + Number of epochs, by default 5. + loss : torch.nn.Module, optional + Loss to minimize, by default None. + scheduler : _LRScheduler, optional + Scheduler for the optimizer, by default None. + """ + super().__init__(optimizer, epochs, loss, scheduler) + self._epochs = epochs + self._optimizer = optimizer + self._loss = loss if loss is not None else torch.nn.CrossEntropyLoss() + self._scheduler = scheduler + + def fit(self, model: torch.nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + patience: int) -> torch.nn.Module: + """ + Train model with given loaders and early stopping. + + Parameters + ---------- + model : torch.nn.Module + Pytorch model to be trained. + train_loader : DataLoader + Train data loader. + val_loader : DataLoader + Validation data loader. + patience : int + Number of epochs to wait before early stopping. + + Returns + ------- + torch.nn.Module + Trained model. + """ + best_loss = float("inf") + best_model = None + patience_counter = 0 + for _ in range(self._epochs): + model = self.train(model, train_loader) + val_loss = self.validate(model, val_loader) + if val_loss < best_loss: + best_loss = val_loss + best_model = model + patience_counter = 0 + else: + patience_counter += 1 + if patience_counter >= patience: + break + return best_model + + def train(self, + model: torch.nn.Module, + dataloader: DataLoader) -> torch.nn.Module: + """ + Train model for one epoch with given loader. + + Parameters + ---------- + model : torch.nn.Module + Pytorch model to be trained. + dataloader : DataLoader + Train data loader. + + Returns + ------- + torch.nn.Module + Trained model. + """ + device = next(model.parameters()).device + model = model.train() + for _, (x, y) in enumerate(dataloader): + x, y = x.to(device), y.to(device) + self._optimizer.zero_grad() + outputs = model(x) + loss = self._loss(outputs, y) + loss.backward() + self._optimizer.step() + if self._scheduler is not None: + self._scheduler.step() + return model + + def validate(self, + model: torch.nn.Module, + dataloader: DataLoader) -> torch.nn.Module: + """ + Validate model with given loader. + + Parameters + ---------- + model : torch.nn.Module + Pytorch model to be balidated. + dataloader : DataLoader + Validation data loader. + + Returns + ------- + float + Validation loss of the model. + """ + running_loss = 0 + device = next(model.parameters()).device + model = model.eval() + with torch.no_grad(): + for _, (x, y) in enumerate(dataloader): + x, y = x.to(device), y.to(device) + outputs = model(x) + loss = self._loss(outputs, y) + running_loss += loss.item() + return loss From 189904ae51491c4a52758f06f7282b606387c023 Mon Sep 17 00:00:00 2001 From: danielgibert Date: Sun, 20 Oct 2024 13:45:26 +0200 Subject: [PATCH 2/4] Add docs for Early Stopping Trainer --- docs/source/secmlt.models.pytorch.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/source/secmlt.models.pytorch.rst b/docs/source/secmlt.models.pytorch.rst index 7f72516..f1a53cc 100644 --- a/docs/source/secmlt.models.pytorch.rst +++ b/docs/source/secmlt.models.pytorch.rst @@ -20,6 +20,14 @@ secmlt.models.pytorch.base\_pytorch\_trainer module :undoc-members: :show-inheritance: +secmlt.models.pytorch.early\_stopping\_pytorch\_trainer module +-------------------------------------------------------------- + +.. automodule:: secmlt.models.pytorch.early_stopping_pytorch_trainer + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- From 93043bf2febb17bd47e89557c7b6d63d6fdd8637 Mon Sep 17 00:00:00 2001 From: danielgibert Date: Sun, 20 Oct 2024 13:48:34 +0200 Subject: [PATCH 3/4] Minor comment --- src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py index 2bc79ce..caab245 100644 --- a/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py +++ b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py @@ -7,7 +7,7 @@ class EarlyStoppingPyTorchTrainer(BasePyTorchTrainer): - """Trainer for PyTorch models.""" + """Trainer for PyTorch models with early stopping.""" def __init__(self, optimizer: torch.optim.Optimizer, epochs: int = 5, loss: torch.nn.Module = None, scheduler: _LRScheduler = None) -> None: From ce8268b64db176e262f47062ff687a8510c2ebdc Mon Sep 17 00:00:00 2001 From: danielgibert Date: Sun, 20 Oct 2024 13:58:29 +0200 Subject: [PATCH 4/4] Modified return type for validate function --- src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py index caab245..5ea62f5 100644 --- a/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py +++ b/src/secmlt/models/pytorch/early_stopping_pytorch_trainer.py @@ -103,7 +103,7 @@ def train(self, def validate(self, model: torch.nn.Module, - dataloader: DataLoader) -> torch.nn.Module: + dataloader: DataLoader) -> float: """ Validate model with given loader.