diff --git a/src/qusi/internal/module.py b/src/qusi/internal/module.py index cac86a4..c793aa3 100644 --- a/src/qusi/internal/module.py +++ b/src/qusi/internal/module.py @@ -52,10 +52,10 @@ def __init__( self.model: Module = model self._optimizer: Optimizer = optimizer self.loss_metric: Module = loss_metric - self.state_based_logging_metrics: ModuleList = state_based_logging_metrics - self.functional_logging_metrics: list[Module] = functional_logging_metrics - self._functional_logging_metric_cycle_totals: npt.NDArray = np.zeros(len(self.functional_logging_metrics), - dtype=np.float32) + self.train_state_based_logging_metrics: ModuleList = state_based_logging_metrics + self.train_functional_logging_metrics: list[Module] = functional_logging_metrics + self._train_functional_logging_metric_cycle_totals: npt.NDArray = np.zeros( + len(self.train_functional_logging_metrics), dtype=np.float32) self._loss_cycle_total: int = 0 self._steps_run_in_phase: int = 0 @@ -70,11 +70,11 @@ def compute_loss_and_metrics(self, batch): predicted = self(inputs) loss = self.loss_metric(predicted, target) self._loss_cycle_total += loss - for state_based_logging_metric in self.state_based_logging_metrics: + for state_based_logging_metric in self.train_state_based_logging_metrics: state_based_logging_metric(predicted, target) - for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics): + for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics): functional_logging_metric_value = functional_logging_metric(predicted, target) - self._functional_logging_metric_cycle_totals[ + self._train_functional_logging_metric_cycle_totals[ functional_logging_metric_index] += functional_logging_metric_value self._steps_run_in_phase += 1 return loss @@ -83,14 +83,14 @@ def on_train_epoch_end(self) -> None: self.log_loss_and_metrics() def log_loss_and_metrics(self, logging_name_prefix: str = ''): - for state_based_logging_metric in self.state_based_logging_metrics: + for state_based_logging_metric in self.train_state_based_logging_metrics: state_based_logging_metric_name = get_metric_name(state_based_logging_metric) self.log(name=logging_name_prefix + state_based_logging_metric_name, value=state_based_logging_metric.compute(), sync_dist=True) state_based_logging_metric.reset() - for functional_logging_metric_index, functional_logging_metric in enumerate(self.functional_logging_metrics): + for functional_logging_metric_index, functional_logging_metric in enumerate(self.train_functional_logging_metrics): functional_logging_metric_name = get_metric_name(functional_logging_metric) - functional_logging_metric_cycle_total = float(self._functional_logging_metric_cycle_totals[ + functional_logging_metric_cycle_total = float(self._train_functional_logging_metric_cycle_totals[ functional_logging_metric_index]) functional_logging_metric_cycle_mean = functional_logging_metric_cycle_total / self._steps_run_in_phase @@ -101,7 +101,7 @@ def log_loss_and_metrics(self, logging_name_prefix: str = ''): self.log(name=logging_name_prefix + 'loss', value=mean_cycle_loss, sync_dist=True) self._loss_cycle_total = 0 - self._functional_logging_metric_cycle_totals = np.zeros(len(self.functional_logging_metrics), dtype=np.float32) + self._train_functional_logging_metric_cycle_totals = np.zeros(len(self.train_functional_logging_metrics), dtype=np.float32) self._steps_run_in_phase = 0 def validation_step(self, batch: tuple[Any, Any], batch_index: int) -> STEP_OUTPUT: