From 0ffb59f866209d4786f07cf267bd57b6cda528a5 Mon Sep 17 00:00:00 2001 From: golmschenk Date: Thu, 10 Oct 2024 00:39:39 -0400 Subject: [PATCH] Switch the values in the metric group to be device movable tensors --- src/qusi/internal/module.py | 26 +++++++++-------- .../test_toy_train_lightning_session.py | 28 ------------------- .../photometric_database/test_tess_target.py | 2 +- 3 files changed, 16 insertions(+), 40 deletions(-) delete mode 100644 tests/end_to_end_tests/test_toy_train_lightning_session.py diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index daa185dd..d13d50af 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -1,10 +1,10 @@ import copy from typing import Any -import numpy as np -import numpy.typing as npt +import torch from lightning import LightningModule from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import Tensor, tensor from torch.nn import Module, BCELoss, ModuleList from torch.optim import Optimizer, AdamW from torchmetrics import Metric @@ -21,10 +21,15 @@ def __init__(self, loss_metric: Module, state_based_logging_metrics: ModuleList, self.loss_metric: Module = loss_metric self.state_based_logging_metrics: ModuleList = state_based_logging_metrics self.functional_logging_metrics: ModuleList = functional_logging_metrics - self.loss_cycle_total: float = 0 - self.steps_run_in_phase: int = 0 - self.functional_logging_metric_cycle_totals: npt.NDArray = np.zeros( - len(self.functional_logging_metrics), dtype=np.float32) + # Lightning requires tensors be registered to be automatically moved between devices. + # Then we assign it to itself to force IDE resolution. + self.register_buffer('loss_cycle_total', tensor(0, dtype=torch.float32)) + self.loss_cycle_total: Tensor = self.loss_cycle_total + self.register_buffer('steps_run_in_phase', tensor(0, dtype=torch.int32)) + self.steps_run_in_phase: Tensor = self.steps_run_in_phase + self.register_buffer('functional_logging_metric_cycle_totals', + torch.zeros(len(self.functional_logging_metrics), dtype=torch.float32)) + self.functional_logging_metric_cycle_totals: Tensor = self.functional_logging_metric_cycle_totals @classmethod def new( @@ -48,7 +53,7 @@ def new( model: Module, optimizer: Optimizer | None, loss_metric: Module | None = None, - logging_metrics: ModuleList | None = None, + logging_metrics: list[Module] | None = None, ) -> Self: if optimizer is None: optimizer = AdamW(model.parameters()) @@ -125,10 +130,9 @@ def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: s mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True) - metric_group.loss_cycle_total = 0 - metric_group.functional_logging_metric_cycle_totals = np.zeros(len(metric_group.functional_logging_metrics), - dtype=np.float32) - metric_group.steps_run_in_phase = 0 + metric_group.loss_cycle_total.zero_() + metric_group.steps_run_in_phase.zero_() + metric_group.functional_logging_metric_cycle_totals.zero_() def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0]) diff --git a/tests/end_to_end_tests/test_toy_train_lightning_session.py b/tests/end_to_end_tests/test_toy_train_lightning_session.py deleted file mode 100644 index 750d747e..00000000 --- a/tests/end_to_end_tests/test_toy_train_lightning_session.py +++ /dev/null @@ -1,28 +0,0 @@ -import os -from functools import partial - -from qusi.internal.light_curve_dataset import ( - default_light_curve_observation_post_injection_transform, -) -from qusi.internal.single_dense_layer_model import SingleDenseLayerBinaryClassificationModel -from qusi.internal.toy_light_curve_collection import get_toy_dataset -from qusi.internal.train_hyperparameter_configuration import TrainHyperparameterConfiguration -from qusi.internal.lightning_train_session import train_session - - -def test_toy_train_session(): - os.environ["WANDB_MODE"] = "disabled" - model = SingleDenseLayerBinaryClassificationModel.new(input_size=100) - dataset = get_toy_dataset() - dataset.post_injection_transform = partial( - default_light_curve_observation_post_injection_transform, length=100 - ) - train_hyperparameter_configuration = TrainHyperparameterConfiguration.new( - batch_size=3, cycles=2, train_steps_per_cycle=5, validation_steps_per_cycle=5 - ) - train_session( - train_datasets=[dataset], - validation_datasets=[dataset], - model=model, - hyperparameter_configuration=train_hyperparameter_configuration, - ) diff --git a/tests/unit_tests/photometric_database/test_tess_target.py b/tests/unit_tests/photometric_database/test_tess_target.py index 9fe80d20..0dc4e2cc 100644 --- a/tests/unit_tests/photometric_database/test_tess_target.py +++ b/tests/unit_tests/photometric_database/test_tess_target.py @@ -133,5 +133,5 @@ def test_retrieve_nearby_tic_target_data_frame_corrects_exofop_bug(self): nearby_target_data_frame["TIC ID"] == 231663902 ].iloc[0] assert row["PM RA (mas/yr)"] == pytest.approx(25.0485) - assert row["Separation (arcsec)"] == pytest.approx(18.1) + assert row["Separation (arcsec)"] == pytest.approx(17.5) assert row["Distance Err (pc)"] == pytest.approx(54)