# Names of the accumulated state tensors registered by the computation below.
CALIBRATION_NUM = "calibration_num"
CALIBRATION_DENOM = "calibration_denom"


class RecalibratedCalibrationMetricComputation(CalibrationMetricComputation):
    r"""
    Calibration computation that recalibrates predictions before they are
    accumulated, so that eval calibration can be estimated when negative
    downsampling was used during training.

    Each prediction ``p`` is mapped to ``p / (p + (1 - p) / c)`` where ``c`` is
    ``recalibration_coefficient``. With the default ``c = 1.0`` the predictions
    are left unchanged.

    The remaining constructor arguments are defined in RecMetricComputation.
    See the docstring of RecMetricComputation for more detail.

    Args:
        recalibration_coefficient (float): scalar ``c`` in the formula above.
            It is used as a divisor, so it must be non-zero.
    """

    def __init__(
        self, *args: Any, recalibration_coefficient: float = 1.0, **kwargs: Any
    ) -> None:
        self._recalibration_coefficient: float = recalibration_coefficient
        super().__init__(*args, **kwargs)
        # NOTE(review): the parent class presumably registers the same two
        # states; confirm whether registering them again here is required.
        for state_name in (CALIBRATION_NUM, CALIBRATION_DENOM):
            self._add_state(
                state_name,
                torch.zeros(self._n_tasks, dtype=torch.double),
                add_window_state=True,
                dist_reduce_fx="sum",
                persistent=True,
            )

    def _recalibrate(
        self,
        predictions: torch.Tensor,
        calibration_coef: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return ``predictions`` rescaled by ``calibration_coef``.

        Args:
            predictions: tensor of predicted probabilities.
            calibration_coef: tensor of coefficients combined element-wise
                with ``predictions``; ``None`` disables recalibration.

        Returns:
            ``predictions / (predictions + (1 - predictions) / calibration_coef)``,
            or ``predictions`` itself when ``calibration_coef`` is ``None``.
        """
        if calibration_coef is None:
            return predictions
        return predictions / (predictions + (1.0 - predictions) / calibration_coef)

    def update(
        self,
        *,
        predictions: Optional[torch.Tensor],
        labels: torch.Tensor,
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
        """Recalibrate ``predictions`` and add the batch to the metric states.

        Raises:
            RecMetricException: if ``predictions`` or ``weights`` is ``None``.
        """
        if predictions is None or weights is None:
            # The message names this class (the original named the parent).
            raise RecMetricException(
                "Inputs 'predictions' and 'weights' should not be None for RecalibratedCalibrationMetricComputation update"
            )
        # Expand the scalar coefficient to a tensor shaped like the predictions.
        coefficient = self._recalibration_coefficient * torch.ones_like(predictions)
        predictions = self._recalibrate(predictions, coefficient)
        num_samples = predictions.shape[-1]
        batch_states = get_calibration_states(labels, predictions, weights)
        for state_name, state_value in batch_states.items():
            # In-place add so the registered state tensor itself is updated.
            state = getattr(self, state_name)
            state += state_value
            self._aggregate_window_state(state_name, state_value, num_samples)


class RecalibratedCalibrationMetric(RecMetric):
    """RecMetric bound to the recalibrated calibration namespace and computation."""

    _namespace: MetricNamespace = MetricNamespace.RECALIBRATED_CALIBRATION
    _computation_class: Type[RecMetricComputation] = (
        RecalibratedCalibrationMetricComputation
    )
torchrec.metrics.auprc import AUPRCMetric from torchrec.metrics.cali_free_ne import CaliFreeNEMetric from torchrec.metrics.calibration import CalibrationMetric +from torchrec.metrics.calibration_with_recalibration import ( + RecalibratedCalibrationMetric, +) from torchrec.metrics.ctr import CTRMetric from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric from torchrec.metrics.mae import MAEMetric @@ -46,6 +49,7 @@ from torchrec.metrics.ndcg import NDCGMetric from torchrec.metrics.ne import NEMetric from torchrec.metrics.ne_positive import NEPositiveMetric +from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric from torchrec.metrics.output import OutputMetric from torchrec.metrics.precision import PrecisionMetric from torchrec.metrics.precision_session import PrecisionSessionMetric @@ -71,6 +75,8 @@ RecMetricEnum.NE: NEMetric, RecMetricEnum.NE_POSITIVE: NEPositiveMetric, RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric, + RecMetricEnum.RECALIBRATED_NE: RecalibratedNEMetric, + RecMetricEnum.RECALIBRATED_CALIBRATION: RecalibratedCalibrationMetric, RecMetricEnum.CTR: CTRMetric, RecMetricEnum.CALIBRATION: CalibrationMetric, RecMetricEnum.AUC: AUCMetric, diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index e85867862..ce30e3026 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -21,6 +21,8 @@ class RecMetricEnumBase(StrValueMixin, Enum): class RecMetricEnum(RecMetricEnumBase): NE = "ne" NE_POSITIVE = "ne_positive" + RECALIBRATED_NE = "recalibrated_ne" + RECALIBRATED_CALIBRATION = "recalibrated_calibration" SEGMENTED_NE = "segmented_ne" LOG_LOSS = "log_loss" CTR = "ctr" diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index 1afd83e60..177057c2a 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -97,6 +97,8 @@ class MetricNamespace(MetricNamespaceBase): NE = "ne" NE_POSITIVE = 
"ne_positive" SEGMENTED_NE = "segmented_ne" + RECALIBRATED_NE = "recalibrated_ne" + RECALIBRATED_CALIBRATION = "recalibrated_calibration" THROUGHPUT = "throughput" CTR = "ctr" CALIBRATION = "calibration" diff --git a/torchrec/metrics/ne_with_recalibration.py b/torchrec/metrics/ne_with_recalibration.py new file mode 100644 index 000000000..715ddd356 --- /dev/null +++ b/torchrec/metrics/ne_with_recalibration.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Dict, Optional, Type + +import torch + +from torchrec.metrics.metrics_namespace import MetricNamespace +from torchrec.metrics.ne import get_ne_states, NEMetricComputation +from torchrec.metrics.rec_metric import ( + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +class RecalibratedNEMetricComputation(NEMetricComputation): + r""" + This class implements the recalibration for NE that is required to correctly estimate eval NE if negative downsampling was used during training. + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + + Args: + include_logloss (bool): return vanilla logloss as one of metrics results, on top of NE. 
+ """ + + def __init__( + self, + *args: Any, + include_logloss: bool = False, + allow_missing_label_with_zero_weight: bool = False, + recalibration_coefficient: float = 1.0, + **kwargs: Any, + ) -> None: + self._recalibration_coefficient: float = recalibration_coefficient + self._include_logloss: bool = include_logloss + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + def _recalibrate( + self, + predictions: torch.Tensor, + calibration_coef: Optional[torch.Tensor], + ) -> torch.Tensor: + if calibration_coef is not None: + predictions = predictions / ( + predictions + (1.0 - predictions) / calibration_coef + ) + return predictions + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for RecalibratedNEMetricComputation update" + ) + + predictions = self._recalibrate( + predictions, self._recalibration_coefficient * torch.ones_like(predictions) + ) + states = get_ne_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in 
states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + +class RecalibratedNEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.RECALIBRATED_NE + _computation_class: Type[RecMetricComputation] = RecalibratedNEMetricComputation