Skip to content

Commit

Permalink
Correct tests
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Oct 10, 2024
1 parent d464747 commit 168f7f7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/qusi/internal/lightning_train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from warnings import warn

import lightning
from lightning.pytorch.loggers import WandbLogger
from torch.nn import BCELoss, Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -103,11 +102,11 @@ def train_session(

lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
logging_metrics=logging_metrics)
wandb_logger = WandbLogger(project=logging_configuration.wandb_project, entity=logging_configuration.wandb_entity)
trainer = lightning.Trainer(
max_epochs=hyperparameter_configuration.cycles,
limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
logger=[wandb_logger]
log_every_n_steps=0,
accelerator=system_configuration.accelerator
)
trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
10 changes: 8 additions & 2 deletions src/qusi/internal/train_system_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,25 @@ class TrainSystemConfiguration:
"""

preprocessing_processes_per_train_process: int
accelerator: str

@classmethod
def new(
cls,
*,
preprocessing_processes_per_train_process: int = 10
preprocessing_processes_per_train_process: int = 10,
accelerator: str = 'auto',
):
"""
Creates a `TrainSystemConfiguration`.
:param preprocessing_processes_per_train_process: The number of processes that are started to preprocess the data
per train process. The train session will create this many processes for each the train data and the validation
data.
:param accelerator: A string identifying the Lightning accelerator to use.
:return: The `TrainSystemConfiguration`.
"""
return cls(preprocessing_processes_per_train_process=preprocessing_processes_per_train_process)
return cls(
preprocessing_processes_per_train_process=preprocessing_processes_per_train_process,
accelerator=accelerator,
)
3 changes: 3 additions & 0 deletions tests/end_to_end_tests/test_toy_lightning_train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from qusi.internal.toy_light_curve_collection import get_toy_dataset
from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.internal.lightning_train_session import train_session
from qusi.internal.train_system_configuration import TrainSystemConfiguration


def test_toy_train_session():
Expand All @@ -20,9 +21,11 @@ def test_toy_train_session():
train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5
)
train_system_configuration = TrainSystemConfiguration.new(accelerator='cpu')
train_session(
train_datasets=[dataset],
validation_datasets=[dataset],
model=model,
hyperparameter_configuration=train_hyperparameter_configuration,
system_configuration=train_system_configuration,
)

0 comments on commit 168f7f7

Please sign in to comment.