Skip to content

Commit

Permalink
cbpe linting
Browse files Browse the repository at this point in the history
  • Loading branch information
nikml committed Apr 24, 2024
1 parent b653620 commit ee2b1c8
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions nannyml/performance_estimation/confidence_based/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,10 +1429,10 @@ def _fit(self, reference_data: pd.DataFrame):
[self.y_true, self.y_pred]
)
if empty:
self._true_positive_sampling_error_components = np.NaN, 0, self.normalize_confusion_matrix
self._true_negative_sampling_error_components = np.NaN, 0, self.normalize_confusion_matrix
self._false_positive_sampling_error_components = np.NaN, 0, self.normalize_confusion_matrix
self._false_negative_sampling_error_components = np.NaN, 0, self.normalize_confusion_matrix
self._true_positive_sampling_error_components = np.NaN, 0., self.normalize_confusion_matrix
self._true_negative_sampling_error_components = np.NaN, 0., self.normalize_confusion_matrix
self._false_positive_sampling_error_components = np.NaN, 0., self.normalize_confusion_matrix
self._false_negative_sampling_error_components = np.NaN, 0., self.normalize_confusion_matrix
else:
self._true_positive_sampling_error_components = bse.true_positive_sampling_error_components(
y_true_reference=reference_data[self.y_true],
Expand Down Expand Up @@ -3076,15 +3076,14 @@ def __init__(
self._sampling_error_components: Tuple = ()

def _fit(self, reference_data: pd.DataFrame):
classes = class_labels(self.y_pred_proba)
_list_missing([self.y_true, self.y_pred], list(reference_data.columns))
# filter nans here
reference_data, empty = common_nan_removal(
reference_data[[self.y_true, self.y_pred]],
[self.y_true, self.y_pred]
)
if empty:
self._sampling_error_components = [(np.NaN, 0) for clazz in classes]
self._sampling_error_components = (np.NaN,)
else:
label_binarizer = LabelBinarizer()
binarized_y_true = label_binarizer.fit_transform(reference_data[self.y_true])
Expand Down Expand Up @@ -3176,10 +3175,11 @@ def _realized_performance(self, data: pd.DataFrame) -> float:
@MetricFactory.register('confusion_matrix', ProblemType.CLASSIFICATION_MULTICLASS)
class MulticlassClassificationConfusionMatrix(Metric):
"""CBPE multiclass classification confusion matrix Metric Class."""
y_pred_proba: Dict[str, str]

def __init__(
self,
y_pred_proba: ModelOutputsType,
y_pred_proba: Dict[str, str],
y_pred: str,
y_true: str,
chunker: Chunker,
Expand All @@ -3194,7 +3194,7 @@ def __init__(
"y_pred_proba must be a dictionary with class labels as keys and pred_proba column names as values"
)

self.classes = sorted(list(y_pred_proba.keys()))
self.classes: List[str] = sorted(list(y_pred_proba.keys()))

super().__init__(
name='confusion_matrix',
Expand Down Expand Up @@ -3336,7 +3336,7 @@ def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) ->
except InvalidArgumentsException as ex:
if "missing required columns" in str(ex):
self._logger.debug(str(ex))
return np.NaN
return np.full((len(self.classes), len(self.classes)), np.nan)
else:
raise ex

Expand All @@ -3347,7 +3347,7 @@ def _get_multiclass_confusion_matrix_estimate(self, chunk_data: pd.DataFrame) ->
if empty:
self._logger.debug(f"Not enough data to compute estimated {self.display_name}.")
warnings.warn(f"Not enough data to compute estimated {self.display_name}.")
return np.NaN
return np.full((len(self.classes), len(self.classes)), np.nan)

y_pred_proba = {key: chunk_data[value] for key, value in self.y_pred_proba.items()}

Expand Down

0 comments on commit ee2b1c8

Please sign in to comment.