Add basic PyTorch Lightning module version
1 parent 6da87f0 · commit a14b31c
Showing 3 changed files with 247 additions and 0 deletions.
File 1 of 3:
@@ -0,0 +1,108 @@
from __future__ import annotations

import logging
from warnings import warn

import lightning
from torch.nn import BCELoss, Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

from qusi.internal.light_curve_dataset import InterleavedDataset, LightCurveDataset
from qusi.internal.logging import set_up_default_logger
from qusi.internal.module import QusiLightningModule
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.train_logging_configuration import TrainLoggingConfiguration
from qusi.internal.train_system_configuration import TrainSystemConfiguration

logger = logging.getLogger(__name__)


def train_session(
        train_datasets: list[LightCurveDataset],
        validation_datasets: list[LightCurveDataset],
        model: Module,
        optimizer: Optimizer | None = None,
        loss_metric: Module | None = None,
        logging_metrics: list[Module] | None = None,
        *,
        hyperparameter_configuration: TrainHyperparameterConfiguration | None = None,
        system_configuration: TrainSystemConfiguration | None = None,
        logging_configuration: TrainLoggingConfiguration | None = None,
        # Deprecated keyword parameters.
        loss_function: Module | None = None,
        metric_functions: list[Module] | None = None,
) -> None:
    """
    Runs a training session.

    :param train_datasets: The datasets to train on.
    :param validation_datasets: The datasets to validate on.
    :param model: The model to train.
    :param optimizer: The optimizer to be used during training.
    :param loss_metric: The loss function to train the model on.
    :param logging_metrics: A list of metric functions to record during the training process.
    :param hyperparameter_configuration: The configuration of the hyperparameters.
    :param system_configuration: The configuration of the system.
    :param logging_configuration: The configuration of the logging.
    """
    if loss_metric is not None and loss_function is not None:
        raise ValueError('Both `loss_metric` and `loss_function` cannot be set at the same time.')
    if logging_metrics is not None and metric_functions is not None:
        raise ValueError('Both `logging_metrics` and `metric_functions` cannot be set at the same time.')
    if loss_function is not None:
        warn('`loss_function` is deprecated and will be removed in the future. '
             'Please use `loss_metric` instead.', UserWarning)
        loss_metric = loss_function
    if metric_functions is not None:
        warn('`metric_functions` is deprecated and will be removed in the future. '
             'Please use `logging_metrics` instead.', UserWarning)
        logging_metrics = metric_functions

    if hyperparameter_configuration is None:
        hyperparameter_configuration = TrainHyperparameterConfiguration.new()
    if system_configuration is None:
        system_configuration = TrainSystemConfiguration.new()
    if loss_metric is None:
        loss_metric = BCELoss()
    if logging_metrics is None:
        logging_metrics = [BinaryAccuracy(), BinaryAUROC()]

    set_up_default_logger()
    train_dataset = InterleavedDataset.new(*train_datasets)
    workers_per_dataloader = system_configuration.preprocessing_processes_per_train_process
    if workers_per_dataloader == 0:
        prefetch_factor = None
        persistent_workers = False
    else:
        prefetch_factor = 10
        persistent_workers = True
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=hyperparameter_configuration.batch_size,
        pin_memory=True,
        persistent_workers=persistent_workers,
        prefetch_factor=prefetch_factor,
        num_workers=workers_per_dataloader,
    )
    validation_dataloaders: list[DataLoader] = []
    for validation_dataset in validation_datasets:
        validation_dataloader = DataLoader(
            validation_dataset,
            batch_size=hyperparameter_configuration.batch_size,
            pin_memory=True,
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
            num_workers=workers_per_dataloader,
        )
        validation_dataloaders.append(validation_dataloader)

    lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
                                              logging_metrics=logging_metrics)
    trainer = lightning.Trainer(
        max_epochs=hyperparameter_configuration.cycles,
        limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
        limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
    )
    trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
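
A note on the dataloader setup above: PyTorch's DataLoader only accepts a non-None prefetch_factor (and persistent_workers=True) when num_workers is greater than zero, which is why the code branches on workers_per_dataloader. A minimal standalone sketch of the same pattern, with build_dataloader as an illustrative helper rather than part of the commit:

from torch.utils.data import DataLoader, Dataset


def build_dataloader(dataset: Dataset, batch_size: int, num_workers: int) -> DataLoader:
    # prefetch_factor and persistent_workers are only valid when worker processes exist.
    if num_workers == 0:
        prefetch_factor = None
        persistent_workers = False
    else:
        prefetch_factor = 10
        persistent_workers = True
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        prefetch_factor=prefetch_factor,
        persistent_workers=persistent_workers,
    )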
File 2 of 3:
@@ -0,0 +1,111 @@
from typing import Any

import numpy as np
import numpy.typing as npt
from lightning import LightningModule
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch.nn import Module, BCELoss, ModuleList
from torch.optim import Optimizer, AdamW
from torchmetrics import Metric
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC
from typing_extensions import Self

from qusi.internal.logging import get_metric_name


class QusiLightningModule(LightningModule):
    """A LightningModule wrapping a qusi network together with its optimizer, loss metric, and logging metrics."""

    @classmethod
    def new(
            cls,
            model: Module,
            optimizer: Optimizer | None,
            loss_metric: Module | None = None,
            logging_metrics: list[Module] | None = None,
    ) -> Self:
        if optimizer is None:
            optimizer = AdamW(model.parameters())
        if loss_metric is None:
            loss_metric = BCELoss()
        if logging_metrics is None:
            logging_metrics = [BinaryAccuracy(), BinaryAUROC()]
        # Split the logging metrics: torchmetrics `Metric` instances manage their own state,
        # while everything else is averaged manually over the steps of each cycle.
        state_based_logging_metrics: ModuleList = ModuleList()
        functional_logging_metrics: list[Module] = []
        for logging_metric in logging_metrics:
            if isinstance(logging_metric, Metric):
                state_based_logging_metrics.append(logging_metric)
            else:
                functional_logging_metrics.append(logging_metric)
        instance = cls(model=model, optimizer=optimizer, loss_metric=loss_metric,
                       state_based_logging_metrics=state_based_logging_metrics,
                       functional_logging_metrics=functional_logging_metrics)
        return instance

    def __init__(
            self,
            model: Module,
            optimizer: Optimizer,
            loss_metric: Module,
            state_based_logging_metrics: ModuleList,
            functional_logging_metrics: list[Module],
    ):
        super().__init__()
        self.model: Module = model
        self._optimizer: Optimizer = optimizer
        self.loss_metric: Module = loss_metric
        self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
        self.functional_logging_metrics: list[Module] = functional_logging_metrics
        self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics),
                                                                             dtype=np.float32)
        self._loss_cycle_total: int = 0
        self._steps_run_in_cycle: int = 0

    def forward(self, inputs: Any) -> Any:
        return self.model(inputs)

    def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
        return self.compute_loss_and_metrics(batch)

    def compute_loss_and_metrics(self, batch):
        inputs, target = batch
        predicted = self(inputs)
        loss = self.loss_metric(predicted, target)
        self._loss_cycle_total += loss
        for state_based_logging_metric in self.state_based_logging_metrics:
            state_based_logging_metric(predicted, target)
        for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
            functional_logging_metric_value = functional_logging_metric(predicted, target)
            self._functional_logging_metric_cycle_totals[
                functional_logging_metric_index] += functional_logging_metric_value
        self._steps_run_in_cycle += 1
        return loss

    def on_train_epoch_end(self) -> None:
        self.log_loss_and_metrics()

    def log_loss_and_metrics(self, logging_name_prefix: str = ''):
        for state_based_logging_metric in self.state_based_logging_metrics:
            state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
            self.log(name=logging_name_prefix + state_based_logging_metric_name,
                     value=state_based_logging_metric.compute(), sync_dist=True)
            state_based_logging_metric.reset()
        for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics):
            functional_logging_metric_name = get_metric_name(functional_logging_metric)
            functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[
                                                              functional_logging_metric_index])
            functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_cycle
            self.log(name=logging_name_prefix + functional_logging_metric_name,
                     value=functional_logging_metric_cycle_mean,
                     sync_dist=True)
        mean_cycle_loss = self._loss_cycle_total / self._steps_run_in_cycle
        self.log(name=logging_name_prefix + 'loss',
                 value=mean_cycle_loss, sync_dist=True)
        # Reset the running totals for the next cycle.
        self._loss_cycle_total = 0
        self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32)
        self._steps_run_in_cycle = 0

    def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
        return self.compute_loss_and_metrics(batch)

    def configure_optimizers(self):
        return self._optimizer
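
To illustrate the split performed in QusiLightningModule.new above: metrics that subclass torchmetrics' Metric accumulate their own state and are logged through compute() and reset() at the end of each cycle, while any other module is treated as a functional metric whose per-step values are summed and averaged manually. A hedged construction sketch, using a placeholder network and a hand-rolled functional metric (neither is part of the commit):

import torch
from torch.nn import Flatten, Linear, Module, Sequential, Sigmoid
from torchmetrics.classification import BinaryAccuracy

from qusi.internal.module import QusiLightningModule


class MeanAbsoluteError(Module):
    # A plain nn.Module rather than a torchmetrics Metric, so it lands in the functional group.
    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (predicted - target).abs().mean()


network = Sequential(Flatten(), Linear(100, 1), Sigmoid())
lightning_module = QusiLightningModule.new(
    model=network,
    optimizer=None,  # Falls back to AdamW(network.parameters()).
    logging_metrics=[BinaryAccuracy(), MeanAbsoluteError()],
)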
File 3 of 3: tests/end_to_end_tests/test_toy_train_lightning_session.py (28 additions & 0 deletions)
@@ -0,0 +1,28 @@
import os
from functools import partial

from qusi.internal.light_curve_dataset import (
    default_light_curve_observation_post_injection_transform,
)
from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel
from qusi.internal.toy_light_curve_collection import get_toy_dataset
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.lightning_train_session import train_session


def test_toy_train_session():
    os.environ["WANDB_MODE"] = "disabled"
    model = SingleDenseLayerBinaryClassificationModel.new(input_size=100)
    dataset = get_toy_dataset()
    dataset.post_injection_transform = partial(
        default_light_curve_observation_post_injection_transform, length=100
    )
    train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
        batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
    )
    train_session(
        train_datasets=[dataset],
        validation_datasets=[dataset],
        model=model,
        hyperparameter_configuration=train_hyperparameter_configuration,
    )
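
Setting WANDB_MODE to "disabled" suggests the default logging setup reports to Weights & Biases; turning it off keeps this end-to-end test offline. Assuming a standard pytest layout for the repository, the new test can presumably be run on its own with a sketch like:

import pytest

# Run only the new toy Lightning training test (assumes pytest is installed).
pytest.main(["tests/end_to_end_tests/test_toy_train_lightning_session.py"])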