diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py
index dafc74b..ba03607 100644
--- a/src/qusi/internal/light_curve_dataset.py
+++ b/src/qusi/internal/light_curve_dataset.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import copy
-import itertools
 import math
 import re
 import shutil
diff --git a/src/qusi/internal/lightning_train_session.py b/src/qusi/internal/lightning_train_session.py
index 60f72f0..e2f4594 100644
--- a/src/qusi/internal/lightning_train_session.py
+++ b/src/qusi/internal/lightning_train_session.py
@@ -1,9 +1,12 @@
 from __future__ import annotations
 
+import datetime
 import logging
+from pathlib import Path
 from warnings import warn
 
 import lightning
+from lightning.pytorch.loggers import CSVLogger, WandbLogger
 from torch.nn import BCELoss, Module
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -102,11 +105,18 @@ def train_session(
     lightning_model = QusiLightningModule.new(model=model, optimizer=optimizer, loss_metric=loss_metric,
                                               logging_metrics=logging_metrics)
+    sessions_directory_path = Path('sessions')
+    session_name = f'{datetime.datetime.now():%Y_%m_%d_%H_%M_%S}'
+    sessions_directory_path.mkdir(exist_ok=True, parents=True)
+    loggers = [
+        CSVLogger(save_dir=sessions_directory_path, name=session_name),
+        WandbLogger(save_dir=sessions_directory_path, name=session_name)]
     trainer = lightning.Trainer(
         max_epochs=hyperparameter_configuration.cycles,
         limit_train_batches=hyperparameter_configuration.train_steps_per_cycle,
         limit_val_batches=hyperparameter_configuration.validation_steps_per_cycle,
         log_every_n_steps=0,
-        accelerator=system_configuration.accelerator
+        accelerator=system_configuration.accelerator,
+        logger=loggers,
     )
     trainer.fit(model=lightning_model, train_dataloaders=train_dataloader, val_dataloaders=validation_dataloaders)
 
diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py
index 5f6621a..7878f7d 100644
--- a/src/qusi/internal/module.py
+++ b/src/qusi/internal/module.py
@@ -53,7 +53,7 @@ class QusiLightningModule(LightningModule):
     def new(
         cls,
         model: Module,
-        optimizer: Optimizer | None,
+        optimizer: Optimizer | None = None,
         loss_metric: Module | None = None,
         logging_metrics: list[Module] | None = None,
     ) -> Self:
@@ -101,7 +101,7 @@ def on_train_epoch_start(self) -> None:
         # Due to Lightning's inconsistent step ordering, performing this during the train epoch start gives the most
         # consistent results.
         self.cycle += 1
-        self.log(name='cycle', value=self.cycle)
+        self.log(name='cycle', value=self.cycle, reduce_fx=torch.max, rank_zero_only=True, on_step=False, on_epoch=True)
 
     def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
         return self.compute_loss_and_metrics(batch, self.train_metric_group)
@@ -122,13 +122,17 @@ def compute_loss_and_metrics(self, batch: tuple[Any, Any], metric_group: MetricG
         return loss
 
     def on_train_epoch_end(self) -> None:
-        self.log_loss_and_metrics(self.train_metric_group)
+        self.log_loss_and_metrics(self.train_metric_group, logging_name_prefix='')
 
     def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: str = ''):
+        mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
+        self.log(name=logging_name_prefix + 'loss',
+                 value=mean_cycle_loss, sync_dist=True, on_step=False,
+                 on_epoch=True)
         for state_based_logging_metric in metric_group.state_based_logging_metrics:
             state_based_logging_metric_name = get_metric_name(state_based_logging_metric)
             self.log(name=logging_name_prefix + state_based_logging_metric_name,
-                     value=state_based_logging_metric.compute(), sync_dist=True)
+                     value=state_based_logging_metric.compute(), sync_dist=True, on_step=False, on_epoch=True)
             state_based_logging_metric.reset()
         for functional_logging_metric_index, functional_logging_metric in enumerate(
                 metric_group.functional_logging_metrics):
@@ -139,10 +143,7 @@ def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: s
             functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / metric_group.steps_run_in_phase
             self.log(name=logging_name_prefix + functional_logging_metric_name,
                      value=functional_logging_metric_cycle_mean,
-                     sync_dist=True)
-        mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
-        self.log(name=logging_name_prefix + 'loss',
-                 value=mean_cycle_loss, sync_dist=True)
+                     sync_dist=True, on_step=False, on_epoch=True)
         metric_group.loss_cycle_total.zero_()
         metric_group.steps_run_in_phase.zero_()
         metric_group.functional_logging_metric_cycle_totals.zero_()
diff --git a/src/qusi/internal/toy_light_curve_collection.py b/src/qusi/internal/toy_light_curve_collection.py
index 858792d..dd1221b 100644
--- a/src/qusi/internal/toy_light_curve_collection.py
+++ b/src/qusi/internal/toy_light_curve_collection.py
@@ -112,7 +112,9 @@ def get_toy_dataset():
         standard_light_curve_collections=[
             get_toy_sine_wave_light_curve_observation_collection(),
             get_toy_flat_light_curve_observation_collection(),
-        ]
+        ],
+        post_injection_transform=partial(default_light_curve_observation_post_injection_transform,
+                                         length=100, randomize=False)
     )
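
With the lightning_train_session.py change, each run of train_session now logs to a timestamped directory under sessions/, through both a CSVLogger and a WandbLogger. A minimal sketch of reading those results back, assuming CSVLogger's default <save_dir>/<name>/version_<n>/metrics.csv layout; the reader below is illustrative, not part of qusi:

    import csv
    from pathlib import Path

    # Session names are zero-padded timestamps, so the lexicographic max is the latest run.
    latest_session_directory = max(Path('sessions').iterdir())
    # CSVLogger nests its output one level deeper, in a version_<n> directory.
    metrics_path = next(latest_session_directory.glob('version_*/metrics.csv'))
    with metrics_path.open() as metrics_file:
        for row in csv.DictReader(metrics_file):
            print(row)  # epoch-aggregated values such as 'cycle' and 'loss'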
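
The module.py changes standardize on epoch-level logging: on_step=False with on_epoch=True records one aggregated value per cycle, sync_dist=True reduces that value across distributed ranks, and the cycle counter is logged with reduce_fx=torch.max and rank_zero_only=True so it is reported once rather than summed or averaged across steps and processes. A minimal sketch of the pattern, using placeholder names that are not qusi code:

    import torch
    from lightning import LightningModule


    class EpochLoggingExample(LightningModule):
        # Placeholder module illustrating the logging pattern; not qusi code.
        def __init__(self):
            super().__init__()
            self.cycle = 0

        def on_train_epoch_start(self) -> None:
            self.cycle += 1
            # One value per epoch; max-reduce and rank_zero_only keep the counter
            # from being averaged across steps or duplicated across processes.
            self.log(name='cycle', value=self.cycle, reduce_fx=torch.max, rank_zero_only=True,
                     on_step=False, on_epoch=True)

        def training_step(self, batch, batch_index):
            loss = torch.as_tensor(0.0, requires_grad=True)  # placeholder loss
            # Aggregate over the epoch and synchronize across distributed ranks.
            self.log(name='loss', value=loss, sync_dist=True, on_step=False, on_epoch=True)
            return loss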
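
In toy_light_curve_collection.py, get_toy_dataset now passes an explicit post-injection transform pinned to a fixed length of 100 with randomization disabled, keeping the toy dataset deterministic. A sketch of the functools.partial mechanics, using a stand-in function whose body is assumed (only the length and randomize keywords come from the diff):

    from functools import partial

    # Stand-in mimicking the keyword interface of qusi's
    # default_light_curve_observation_post_injection_transform.
    def stand_in_post_injection_transform(observation, *, length, randomize):
        return observation  # assumed: the real transform standardizes each observation to `length` samples

    # partial pins the keyword arguments, leaving a single-argument callable
    # that the dataset can apply to each observation it yields.
    fixed_transform = partial(stand_in_post_injection_transform, length=100, randomize=False)
    fixed_transform('example_observation')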