diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py
index 1a08189..5f6621a 100644
--- a/src/qusi/internal/module.py
+++ b/src/qusi/internal/module.py
@@ -27,7 +27,7 @@ def __init__(self, loss_metric: Module, state_based_logging_metrics: ModuleList,
         # Then we assign it to itself to force IDE resolution.
         self.register_buffer('loss_cycle_total', tensor(0, dtype=torch.float32))
         self.loss_cycle_total: Tensor = self.loss_cycle_total
-        self.register_buffer('steps_run_in_phase', tensor(0, dtype=torch.int32))
+        self.register_buffer('steps_run_in_phase', tensor(0, dtype=torch.int64))
         self.steps_run_in_phase: Tensor = self.steps_run_in_phase
         self.register_buffer('functional_logging_metric_cycle_totals',
                              torch.zeros(len(self.functional_logging_metrics), dtype=torch.float32))
@@ -88,10 +88,21 @@ def __init__(
         self._optimizer: Optimizer = optimizer
         self.train_metric_group: MetricGroup = train_metric_group
         self.validation_metric_groups: ModuleList | list[MetricGroup] = validation_metric_groups
+        # Lightning requires tensors be registered to be automatically moved between devices.
+        # Then we assign it to itself to force IDE resolution.
+        # `cycle` is incremented and logged during the train epoch start, so it needs to start at -1.
+        self.register_buffer('cycle', tensor(-1, dtype=torch.int64))
+        self.cycle: Tensor = self.cycle

     def forward(self, inputs: Any) -> Any:
         return self.model(inputs)

+    def on_train_epoch_start(self) -> None:
+        # Due to Lightning's inconsistent step ordering, performing this during the train epoch start gives the most
+        # consistent results.
+        self.cycle += 1
+        self.log(name='cycle', value=self.cycle)
+
     def training_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT:
         return self.compute_loss_and_metrics(batch, self.train_metric_group)
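
For context on the pattern this diff adds, below is a minimal, self-contained sketch (not the qusi module itself) of how the registered `cycle` buffer and the `on_train_epoch_start` hook fit together in a Lightning module. The `CycleCountingModule` class, its placeholder linear model, the MSE loss, and the Adam optimizer are illustrative assumptions; only the buffer registration, the self-assignment for IDE resolution, and the epoch-start increment/log mirror the change above.

```python
# Illustrative sketch of the buffer + epoch-start logging pattern, assuming Lightning >= 2.0.
from typing import Any

import torch
from torch import Tensor, nn, tensor
from lightning.pytorch import LightningModule


class CycleCountingModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Linear(4, 1)  # Placeholder model for the sketch.
        # Registering the counter as a buffer lets Lightning move it between devices
        # and include it in checkpoints. Reassigning it to itself forces IDE resolution.
        # `cycle` is incremented and logged at train epoch start, so it starts at -1.
        self.register_buffer('cycle', tensor(-1, dtype=torch.int64))
        self.cycle: Tensor = self.cycle

    def forward(self, inputs: Any) -> Any:
        return self.model(inputs)

    def on_train_epoch_start(self) -> None:
        # Incrementing here keeps the logged `cycle` value aligned with the
        # training steps that follow it within the epoch.
        self.cycle += 1
        self.log(name='cycle', value=self.cycle)

    def training_step(self, batch: tuple[Tensor, Tensor], batch_index: int) -> Tensor:
        inputs, targets = batch
        return nn.functional.mse_loss(self(inputs), targets)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.Adam(self.parameters())
```

Because `cycle` is a registered buffer rather than a plain Python attribute, it survives checkpointing and device moves, which is why the counter can start at -1 and still be correct after resuming training.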