From 168f7f7fa0e9810d453236739a25b47dc9e19ddc Mon Sep 17 00:00:00 2001
From: golmschenk
Date: Thu, 10 Oct 2024 13:54:09 -0400
Subject: [PATCH] Correct tests

---
 src/qusi/internal/lightning_train_session.py    |  5 ++---
 src/qusi/internal/train_system_configuration.py | 10 ++++++++--
 .../test_toy_lightning_train_session.py         |  3 +++
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py
index 5980d95..60f72f0 100644
--- a/src/qusi/internal/lightning_train_session.py
+++ b/src/qusi/internal/lightning_train_session.py
@@ -4,7 +4,6 @@
 from warnings import warn
 
 import lightning
-from lightning.pytorch.loggers import WandbLogger
 from torch.nn import BCELoss, Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -103,11 +102,11 @@ def train_session(
 
     lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
                                               logging_metrics=logging_metrics)
-    wandb_logger = WandbLogger(project=logging_configuration.wandb_project, entity=logging_configuration.wandb_entity)
     trainer = lightning.Trainer(
         max_epochs=hyperparameter_configuration.cycles,
         limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
         limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
-        logger=[wandb_logger]
+        log_every_n_steps=0,
+        accelerator=system_configuration.accelerator
     )
     trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
diff --git a/src/qusi/internal/train_system_configuration.py b/src/qusi/internal/train_system_configuration.py
index ac3153b..9e692c5 100644
--- a/src/qusi/internal/train_system_configuration.py
+++ b/src/qusi/internal/train_system_configuration.py
@@ -12,12 +12,14 @@ class TrainSystemConfiguration:
     """
 
     preprocessing_processes_per_train_process: int
+    accelerator: str
 
     @classmethod
     def new(
             cls,
             *,
-            preprocessing_processes_per_train_process: int = 10
+            preprocessing_processes_per_train_process: int = 10,
+            accelerator: str = 'auto',
     ):
         """
         Creates a `TrainSystemConfiguration`.
@@ -25,6 +27,10 @@ def new(
         :param preprocessing_processes_per_train_process: The number of processes that are started to preprocess the
             data per train process. The train session will create this many processes for each the train data and the
             validation data.
+        :param accelerator: A string identifying the Lightning accelerator to use.
         :return: The `TrainSystemConfiguration`.
         """
-        return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process)
+        return cls(
+            preprocessing_processes_per_train_process=preprocessing_processes_per_train_process,
+            accelerator=accelerator,
+        )
diff --git a/tests/end_to_end_tests/test_toy_lightning_train_session.py b/tests/end_to_end_tests/test_toy_lightning_train_session.py
index 750d747..ab2b409 100644
--- a/tests/end_to_end_tests/test_toy_lightning_train_session.py
+++ b/tests/end_to_end_tests/test_toy_lightning_train_session.py
@@ -8,6 +8,7 @@
 from qusi.internal.toy_light_curve_collection import get_toy_dataset
 from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
 from qusi.internal.lightning_train_session import train_session
+from qusi.internal.train_system_configuration import TrainSystemConfiguration
 
 
 def test_toy_train_session():
@@ -20,9 +21,11 @@ def test_toy_train_session():
     train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
         batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
     )
+    train_system_configuration = TrainSystemConfiguration.new(accelerator='cpu')
    train_session(
         train_datasets=[dataset],
         validation_datasets=[dataset],
         model=model,
         hyperparameter_configuration=train_hyperparameter_configuration,
+        system_configuration=train_system_configuration,
     )