Skip to content

Commit

Permalink
add zero_division
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Nov 11, 2024
1 parent 0d43b36 commit fc6f25b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/scportrait/ml/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def on_train_epoch_end(self, trainer, pl_module):
preds_1d = (probs_1d >= 0.5).astype(int)
labels = self.train_actual_labels.numpy()

f1 = f1_score(y_true=labels, y_pred=preds_1d)
f1 = f1_score(y_true=labels, y_pred=preds_1d, zero_division=0)
self.log("f1_score/train_accumulated", f1, sync_dist=True)

def on_validation_epoch_start(self, trainer, pl_module):
Expand All @@ -76,7 +76,7 @@ def on_validation_epoch_end(self, trainer, pl_module):
preds_1d = (probs_1d >= 0.5).astype(int)
labels = self.val_actual_labels.numpy()

f1 = f1_score(y_true=labels, y_pred=preds_1d)
f1 = f1_score(y_true=labels, y_pred=preds_1d, zero_division=0)
self.log("f1_score/val_accumulated", f1, sync_dist=True)

# calculate precision-recall curve
Expand Down Expand Up @@ -151,7 +151,7 @@ def on_test_epoch_end(self, trainer, pl_module):
preds_1d = (probs_1d >= 0.5).astype(int)
labels = self.test_actual_labels.numpy()

f1 = f1_score(y_true=labels, y_pred=preds_1d)
f1 = f1_score(y_true=labels, y_pred=preds_1d, zero_division=0)
self.log("f1_score/test_accumulated", f1, sync_dist=True)

# calculate precision-recall curve
Expand Down

0 comments on commit fc6f25b

Please sign in to comment.