Switch the values in the metric group to be device movable tensors
golmschenk committed Oct 10, 2024
1 parent cd77001 commit 0ffb59f
Showing 3 changed files with 16 additions and 40 deletions.
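
For context before the diff, a minimal sketch of the pattern this commit moves to (the Accumulator class and total attribute are illustrative names, not part of qusi): plain Python floats and NumPy arrays are invisible to PyTorch, so they stay behind when a module is moved to another device, while tensors registered with register_buffer follow the module's .to(device) calls and appear in its state_dict.

import torch
from torch import nn


class Accumulator(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A registered buffer is part of the module's state: it follows
        # .to(device) and is saved and restored with state_dict().
        self.register_buffer('total', torch.tensor(0.0))

    def update(self, value: torch.Tensor) -> None:
        self.total += value


module = Accumulator()
if torch.cuda.is_available():
    module.to('cuda')
    assert module.total.device.type == 'cuda'  # The buffer moved with the module.
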
26 changes: 15 additions & 11 deletions src/qusi/internal/module.py
@@ -1,10 +1,10 @@
 import copy
 from typing import Any
 
-import numpy as np
-import numpy.typing as npt
+import torch
 from lightning import LightningModule
 from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import Tensor, tensor
 from torch.nn import Module, BCELoss, ModuleList
 from torch.optim import Optimizer, AdamW
 from torchmetrics import Metric
@@ -21,10 +21,15 @@ def __init__(self, loss_metric: Module, state_based_logging_metrics: ModuleList,
         self.loss_metric: Module = loss_metric
         self.state_based_logging_metrics: ModuleList = state_based_logging_metrics
         self.functional_logging_metrics: ModuleList = functional_logging_metrics
-        self.loss_cycle_total: float = 0
-        self.steps_run_in_phase: int = 0
-        self.functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(
-            len(self.functional_logging_metrics), dtype=np.float32)
+        # Lightning requires tensors to be registered as buffers to be automatically moved between devices.
+        # Each buffer is then assigned to itself so that the IDE can resolve the attribute type.
+        self.register_buffer('loss_cycle_total', tensor(0, dtype=torch.float32))
+        self.loss_cycle_total: Tensor = self.loss_cycle_total
+        self.register_buffer('steps_run_in_phase', tensor(0, dtype=torch.int32))
+        self.steps_run_in_phase: Tensor = self.steps_run_in_phase
+        self.register_buffer('functional_logging_metric_cycle_totals',
+                             torch.zeros(len(self.functional_logging_metrics), dtype=torch.float32))
+        self.functional_logging_metric_cycle_totals: Tensor = self.functional_logging_metric_cycle_totals
 
     @classmethod
     def new(
@@ -48,7 +53,7 @@ def new(
             model: Module,
             optimizer: Optimizer | None,
             loss_metric: Module | None = None,
-            logging_metrics: ModuleList | None = None,
+            logging_metrics: list[Module] | None = None,
     ) -> Self:
         if optimizer is None:
             optimizer = AdamW(model.parameters())
@@ -125,10 +130,9 @@ def log_loss_and_metrics(self, metric_group: MetricGroup, logging_name_prefix: s
         mean_cycle_loss = metric_group.loss_cycle_total / metric_group.steps_run_in_phase
         self.log(name=logging_name_prefix + 'loss',
                  value=mean_cycle_loss, sync_dist=True)
-        metric_group.loss_cycle_total = 0
-        metric_group.functional_logging_metric_cycle_totals = np.zeros(len(metric_group.functional_logging_metrics),
-                                                                       dtype=np.float32)
-        metric_group.steps_run_in_phase = 0
+        metric_group.loss_cycle_total.zero_()
+        metric_group.steps_run_in_phase.zero_()
+        metric_group.functional_logging_metric_cycle_totals.zero_()
 
     def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
         return self.compute_loss_and_metrics(batch, self.validation_metric_groups[0])
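
A note on the reset in log_loss_and_metrics above: once an attribute is registered as a buffer, nn.Module rejects assigning a plain Python number to it, and replacing it with a fresh CPU tensor would discard the device placement, so the in-place zero_() call resets the same tensor on the same device. A small sketch of that behavior (CycleState and loss_total are hypothetical names, not qusi code):

import torch
from torch import nn


class CycleState(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer('loss_total', torch.tensor(0.0))


state = CycleState()
state.loss_total += 5.0
state.loss_total.zero_()  # In-place reset; device and buffer registration are preserved.
assert state.loss_total.item() == 0.0

try:
    state.loss_total = 0  # Assigning a non-tensor to a registered buffer raises.
except TypeError as error:
    print(error)  # e.g. "cannot assign 'int' as buffer 'loss_total' (torch.Tensor or None expected)"
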
28 changes: 0 additions & 28 deletions tests/end_to_end_tests/test_toy_train_lightning_session.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/unit_tests/photometric_database/test_tess_target.py
@@ -133,5 +133,5 @@ def test_retrieve_nearby_tic_target_data_frame_corrects_exofop_bug(self):
             nearby_target_data_frame["TIC ID"] == 231663902
         ].iloc[0]
         assert row["PM RA (mas/yr)"] == pytest.approx(25.0485)
-        assert row["Separation (arcsec)"] == pytest.approx(18.1)
+        assert row["Separation (arcsec)"] == pytest.approx(17.5)
         assert row["Distance Err (pc)"] == pytest.approx(54)
