diff --git a/nannyml/sampling_error/binary_classification.py b/nannyml/sampling_error/binary_classification.py index e8232ce9..bcd4a0f6 100644 --- a/nannyml/sampling_error/binary_classification.py +++ b/nannyml/sampling_error/binary_classification.py @@ -178,6 +178,10 @@ def f1_sampling_error_components(y_true_reference: pd.Series, y_pred_reference: tp_fp_fn = np.concatenate([TP, FN, FP]) + # If there's no true positives, false negatives or false positives, sampling error is NaN + if tp_fp_fn.size == 0: + return np.nan, 0 + correcting_factor = len(tp_fp_fn) / ((len(FN) + len(FP)) * 0.5 + len(TP)) obs_level_f1 = tp_fp_fn * correcting_factor fraction_of_relevant = len(tp_fp_fn) / len(y_pred_reference)