diff --git a/src/base_trainer.py b/src/base_trainer.py
index f672d52..c2ed9d7 100644
--- a/src/base_trainer.py
+++ b/src/base_trainer.py
@@ -98,12 +98,12 @@ def _train_val_iteration(
             torch.Tensor: The loss for the batch.
             Dict[str, torch.Tensor]: The loss components for the batch.
         """
-        # x, y = batch
-        # y_hat = self._model(x)
-        # losses = self._training_loss(x, y, y_hat)
-        # loss = sum([v for v in losses.values()])
-        # return loss, losses
-        raise NotImplementedError
+        # TODO: You'll most likely want to override this method.
+        x, y = batch
+        y_hat = self._model(x)
+        losses: Dict[str, torch.Tensor] = self._training_loss(y_hat, y)
+        loss = sum([v for v in losses.values()])
+        return loss, losses
 
     def _train_epoch(
         self, description: str, visualize: bool, epoch: int, last_val_loss: float
@@ -134,7 +134,7 @@ def _train_epoch(
                 break
             self._opt.zero_grad()
             loss, loss_components = self._train_val_iteration(
-                batch, epoch
+                batch, epoch, validation=False
            )  # User implementation goes here (train.py)
             loss.backward()
             self._opt.step()
@@ -193,7 +193,8 @@ def _val_epoch(self, description: str, visualize: bool, epoch: int) -> float:
                 # Blink the progress bar to indicate that the validation loop is running
                 blink_pbar(i, self._pbar, 4)
                 loss, loss_components = self._train_val_iteration(
-                    batch
+                    batch,
+                    epoch,
                 )  # User implementation goes here (train.py)
                 val_loss.update(loss.item())
                 for k, v in loss_components.items():
diff --git a/src/losses/mse.py b/src/losses/mse.py
index 17246e5..56c6018 100644
--- a/src/losses/mse.py
+++ b/src/losses/mse.py
@@ -18,4 +18,8 @@ def __init__(self, reduction: str):
         self._reduction = reduction
 
     def __call__(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.mse_loss(y_pred, y_true, reduction=self._reduction)
+        return {
+            "mse": torch.nn.functional.mse_loss(
+                y_pred, y_true, reduction=self._reduction
+            )
+        }
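
Taken together, these changes define a new contract: a loss callable returns a `Dict[str, torch.Tensor]` of named components, and the default `_train_val_iteration` sums those components into the scalar used for backpropagation while keeping the dict so each term can be logged separately (see the `loss_components.items()` loop in `_val_epoch`). The sketch below illustrates that contract; it is a minimal example, and the `MSEWithL1` class plus the toy tensors are hypothetical assumptions, not part of the patch.

```python
import torch
from typing import Dict


# Hypothetical two-term loss following the new dict contract (illustrative,
# not part of the patch). Each component is returned under its own key so
# the trainer can log the terms individually.
class MSEWithL1:
    def __init__(self, reduction: str = "mean"):
        self._reduction = reduction

    def __call__(
        self, y_pred: torch.Tensor, y_true: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return {
            "mse": torch.nn.functional.mse_loss(
                y_pred, y_true, reduction=self._reduction
            ),
            "l1": torch.nn.functional.l1_loss(
                y_pred, y_true, reduction=self._reduction
            ),
        }


# Mirrors what the patched _train_val_iteration does with the returned dict:
y_true = torch.randn(4, 3)
y_pred = torch.randn(4, 3, requires_grad=True)
losses = MSEWithL1()(y_pred, y_true)
loss = sum(losses.values())  # scalar total used for loss.backward()
loss.backward()
print({k: v.item() for k, v in losses.items()})
```

One consequence of this design is that even single-term losses such as `MSE` now wrap their value in a one-entry dict, which keeps the trainer loop uniform across simple and composite losses.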