diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py index 99f531d..46d2288 100644 --- a/src/fsrs_optimizer/fsrs_optimizer.py +++ b/src/fsrs_optimizer/fsrs_optimizer.py @@ -392,7 +392,9 @@ def train(self, verbose: bool = True): real_batch_size = seq_lens.shape[0] outputs, _ = self.model(sequences) stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0] - retentions = power_forgetting_curve(delta_ts, stabilities, -self.model.w[19]) + retentions = power_forgetting_curve( + delta_ts, stabilities, -self.model.w[19] + ) loss = (self.loss_fn(retentions, labels) * weights).sum() penalty = torch.sum( torch.square(self.model.w - self.init_w_tensor)