diff --git a/blazingai/metrics.py b/blazingai/metrics.py index 057b995..6a55617 100644 --- a/blazingai/metrics.py +++ b/blazingai/metrics.py @@ -27,16 +27,10 @@ def __init__( self.preds = torchmetrics.CatMetric() def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore - """Update state with predictions and targets. - Args: - preds: Predictions from model - target: Ground truth values - """ self.target.update(target) self.preds.update(preds) def compute(self) -> torch.Tensor: - """Computes mean squared error over state.""" preds = self.preds.compute() target = self.target.compute()