Add basic PyTorch Lightning module version
golmschenk committed Sep 2, 2024
1 parent 6da87f0 commit a14b31c
Showing 3 changed files with 247 additions and 0 deletions.
108 changes: 108 additions & 0 deletions src/qusi/internal/lightning_train_session.py
@@ -0,0 +1,108 @@
from __future__ import annotations

import logging
from warnings import warn

import lightning
from torch.nn import BCELoss, Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
from qusi.internal.logging import set_up_default_logger
from qusi.internal.module import QusiLightningModule
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.train_logging_configuration import TrainLoggingConfiguration
from qusi.internal.train_system_configuration import TrainSystemConfiguration

logger = logging.getLogger(__name__)


def train_session(
train_datasets: list[LightCurveDataset],
validation_datasets: list[LightCurveDataset],
model: Module,
optimizer: Optimizer | None = None,
loss_metric: Module | None = None,
logging_metrics: list[Module] | None = None,
*,
hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
system_configuration: TrainSystemConfiguration | None = None,
logging_configuration: TrainLoggingConfiguration | None = None,
# Deprecated keyword parameters.
loss_function: Module | None = None,
metric_functions: list[Module] | None = None,
) -> None:
"""
Runs a training session.
:param train_datasets: The datasets to train on.
:param validation_datasets: The datasets to validate on.
:param model: The model to train.
:param optimizer: The optimizer to be used during training.
:param loss_metric: The loss metric used to train the model.
:param logging_metrics: A list of metrics to log during the training process.
:param hyperparameter_configuration: The configuration of the hyperparameters.
:param system_configuration: The configuration of the system.
:param logging_configuration: The configuration of the logging.
"""
if loss_metric is not None and loss_function is not None:
raise ValueError('Both `loss_metric` and `loss_function` cannot be set at the same time.')
if logging_metrics is not None and metric_functions is not None:
raise ValueError('Both `logging_metrics` and `metric_functions` cannot be set at the same time.')
if loss_function is not None:
warn('`loss_function` is deprecated and will be removed in the future. '
'Please use `loss_metric` instead.', UserWarning)
loss_metric = loss_function
if metric_functions is not None:
warn('`metric_functions` is deprecated and will be removed in the future. '
'Please use `logging_metrics` instead.', UserWarning)
logging_metrics = metric_functions

if hyperparameter_configuration is None:
hyperparameter_configuration = TrainHyperparameterConfiguration.new()
if system_configuration is None:
system_configuration = TrainSystemConfiguration.new()
if loss_metric is None:
loss_metric = BCELoss()
if logging_metrics is None:
logging_metrics = [BinaryAccuracy(), BinaryAUROC()]

set_up_default_logger()
train_dataset = InterleavedDataset.new(*train_datasets)
workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
if workers_per_dataloader == 0:
prefetch_factor = None
persistent_workers = False
else:
prefetch_factor = 10
persistent_workers = True
train_dataloader = DataLoader(
train_dataset,
batch_size=hyperparameter_configuration.batch_size,
pin_memory=True,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
num_workers=workers_per_dataloader,
)
validation_dataloaders: list[DataLoader] = []
for validation_dataset in validation_datasets:
validation_dataloader = DataLoader(
validation_dataset,
batch_size=hyperparameter_configuration.batch_size,
pin_memory=True,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
num_workers=workers_per_dataloader,
)
validation_dataloaders.append(validation_dataloader)

lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
logging_metrics=logging_metrics)
trainer = lightning.Trainer(
max_epochs=hyperparameter_configuration.cycles,
limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
)
trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
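For orientation, a minimal call to this train_session might look like the sketch below. It reuses the toy dataset, model, and hyperparameter configuration exercised by the end-to-end test later in this commit; the explicit optimizer and logging metrics only restate the defaults, and the learning rate is an arbitrary illustration rather than anything this commit prescribes.

from functools import partial

from torch.optim import AdamW
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.light_curve_dataset import default_light_curve_observation_post_injection_transform
from qusi.internal.lightning_train_session import train_session
from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
from qusi.internal.toy_light_curve_collection import get_toy_dataset
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration

# Toy dataset whose light curves are transformed to a fixed length of 100.
dataset = get_toy_dataset()
dataset.post_injection_transform = partial(
    default_light_curve_observation_post_injection_transform, length=100)
model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
hyperparameter_configuration = TrainHyperparameterConfiguration.new(
    batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5)
train_session(
    train_datasets=[dataset],
    validation_datasets=[dataset],
    model=model,
    optimizer=AdamW(model.parameters(), lr=1e-4),  # Omitting this falls back to AdamW(model.parameters()).
    logging_metrics=[BinaryAccuracy(), BinaryAUROC()],  # Same as the defaults used when omitted.
    hyperparameter_configuration=hyperparameter_configuration,
)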
111 changes: 111 additions & 0 deletions src/qusi/internal/module.py
@@ -0,0 +1,111 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from lightning import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.nn import Module, BCELoss, ModuleList
from torch.optim import Optimizer, AdamW
from torchmetrics import Metric
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
from typing_extensions import Self

from qusi.internal.logging import get_metric_name


class QusiLightningModule(LightningModule):
@classmethod
def new(
cls,
model: Module,
optimizer: Optimizer | None,
loss_metric: Module | None = None,
logging_metrics: list[Module] | None = None,
) -> Self:
if optimizer is None:
optimizer = AdamW(model.parameters())
if loss_metric is None:
loss_metric = BCELoss()
if logging_metrics is None:
logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
state_based_logging_metrics: ModuleList = ModuleList()
functional_logging_metrics: list[Module] = []
for logging_metric in logging_metrics:
if isinstance(logging_metric, Metric):
state_based_logging_metrics.append(logging_metric)
else:
functional_logging_metrics.append(logging_metric)
instance = cls(model=model, optimizer=optimizer, loss_metric=loss_metric,
state_based_logging_metrics=state_based_logging_metrics,
functional_logging_metrics=functional_logging_metrics)
return instance

def __init__(
self,
model: Module,
optimizer: Optimizer,
loss_metric: Module,
state_based_logging_metrics: ModuleList,
functional_logging_metrics: list[Module],
):
super().__init__()
self.model: Module = model
self._optimizer: Optimizer = optimizer
self.loss_metric: Module = loss_metric
self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
self.functional_logging_metrics: list[Module] = functional_logging_metrics
self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics),
dtype=np.float32)
self._loss_cycle_total: int = 0
self._steps_run_in_cycle: int = 0

def forward(self, inputs: Any) -> Any:
return self.model(inputs)

def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
return self.compute_loss_and_metrics(batch)

def compute_loss_and_metrics(self, batch):
inputs, target = batch
predicted = self(inputs)
loss = self.loss_metric(predicted, target)
self._loss_cycle_total += loss
for state_based_logging_metric in self.state_based_logging_metrics:
state_based_logging_metric(predicted, target)
for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
functional_logging_metric_value = functional_logging_metric(predicted, target)
self._functional_logging_metric_cycle_totals[
functional_logging_metric_index] += functional_logging_metric_value
self._steps_run_in_cycle += 1
return loss

def on_train_epoch_end(self) -> None:
self.log_loss_and_metrics()

def log_loss_and_metrics(self, logging_name_prefix: str = ''):
for state_based_logging_metric in self.state_based_logging_metrics:
state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
self.log(name=logging_name_prefix + state_based_logging_metric_name,
value=state_based_logging_metric.compute(), sync_dist=True)
state_based_logging_metric.reset()
for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
functional_logging_metric_name = get_metric_name(functional_logging_metric)
functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[
functional_logging_metric_index])

functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_cycle
self.log(name=logging_name_prefix + functional_logging_metric_name,
value=functional_logging_metric_cycle_mean,
sync_dist=True)
mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_cycle
self.log(name=logging_name_prefix + 'loss',
value=mean_cycle_loss, sync_dist=True)
self._loss_cycle_total = 0
self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32)
self._steps_run_in_cycle = 0

def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
return self.compute_loss_and_metrics(batch)

def configure_optimizers(self):
return self._optimizer
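As a sketch of how QusiLightningModule.new routes logging metrics: torchmetrics Metric instances become state-based metrics (computed and reset at the end of each cycle), while any other metric module is treated as functional and its per-step values are averaged over the cycle. The stand-in model and the BinaryCrossEntropyMetric below are hypothetical placeholders, not part of this commit.

import torch
from torch.nn import Linear, Module, Sequential, Sigmoid, functional
from torchmetrics.classification import BinaryAccuracy

from qusi.internal.module import QusiLightningModule


class BinaryCrossEntropyMetric(Module):
    # A plain (non-torchmetrics) metric module; `new` will classify it as functional.
    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return functional.binary_cross_entropy(predicted, target)


model = Sequential(Linear(100, 1), Sigmoid())  # Hypothetical stand-in model.
lightning_module = QusiLightningModule.new(
    model=model,
    optimizer=None,  # Falls back to AdamW(model.parameters()).
    logging_metrics=[BinaryAccuracy(), BinaryCrossEntropyMetric()],
)
# BinaryAccuracy is a torchmetrics Metric, so it lands in the state-based ModuleList;
# the plain Module ends up in the functional list.
assert len(lightning_module.state_based_logging_metrics) == 1
assert len(lightning_module.functional_logging_metrics) == 1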
28 changes: 28 additions & 0 deletions tests/end_to_end_tests/test_toy_train_lightning_session.py
@@ -0,0 +1,28 @@
import os
from functools import partial

from qusi.internal.light_curve_dataset import (
default_light_curve_observation_post_injection_transform,
)
from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
from qusi.internal.toy_light_curve_collection import get_toy_dataset
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.lightning_train_session import train_session


def test_toy_train_session():
os.environ["WANDB_MODE"] = "disabled"
model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
dataset = get_toy_dataset()
dataset.post_injection_transform = partial(
default_light_curve_observation_post_injection_transform, length=100
)
train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
)
train_session(
train_datasets=[dataset],
validation_datasets=[dataset],
model=model,
hyperparameter_configuration=train_hyperparameter_configuration,
)
