Skip to content

Commit

Permalink
calibration ne (#2675)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2675

Recalibration is required if negative downsampling is used during training:
[1] https://johaupt.github.io/blog/downsampling_recalibration.html
[2] https://cseweb.ucsd.edu/~elkan/rescale.pdf

We define calibration_with_recalibration and ne_with_recalibration which re-weight  model predictions before computing NE and calibrations: {F1974271096}
where p' - original model predictions, w - recalibration parameter, p - predictions used in metrics computations.

Differential Revision: D67784713
  • Loading branch information
Mark Gluzman authored and facebook-github-bot committed Jan 10, 2025
1 parent 7833314 commit 0e92204
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 0 deletions.
96 changes: 96 additions & 0 deletions torchrec/metrics/calibration_with_recalibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Optional, Type

import torch
from torchrec.metrics.calibration import (
CalibrationMetricComputation,
get_calibration_states,
)
from torchrec.metrics.metrics_namespace import MetricNamespace
from torchrec.metrics.rec_metric import (
RecMetric,
RecMetricComputation,
RecMetricException,
)


CALIBRATION_NUM = "calibration_num"
CALIBRATION_DENOM = "calibration_denom"


class RecalibratedCalibrationMetricComputation(CalibrationMetricComputation):
r"""
This class implements the RecMetricComputation for Calibration that is required to correctly estimate eval NE if negative downsampling was used during training.
The constructor arguments are defined in RecMetricComputation.
See the docstring of RecMetricComputation for more detail.
"""

def __init__(
self, *args: Any, recalibration_coefficient: float = 1.0, **kwargs: Any
) -> None:
self._recalibration_coefficient: float = recalibration_coefficient
super().__init__(*args, **kwargs)
self._add_state(
CALIBRATION_NUM,
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
CALIBRATION_DENOM,
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)

def _recalibrate(
self,
predictions: torch.Tensor,
calibration_coef: Optional[torch.Tensor],
) -> torch.Tensor:
if calibration_coef is not None:
predictions = predictions / (
predictions + (1.0 - predictions) / calibration_coef
)
return predictions

def update(
self,
*,
predictions: Optional[torch.Tensor],
labels: torch.Tensor,
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' should not be None for CalibrationMetricComputation update"
)
predictions = self._recalibrate(
predictions, self._recalibration_coefficient * torch.ones_like(predictions)
)
num_samples = predictions.shape[-1]
for state_name, state_value in get_calibration_states(
labels, predictions, weights
).items():
state = getattr(self, state_name)
state += state_value
self._aggregate_window_state(state_name, state_value, num_samples)


class RecalibratedCalibrationMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.RECALIBRATED_CALIBRATION
_computation_class: Type[RecMetricComputation] = (
RecalibratedCalibrationMetricComputation
)
6 changes: 6 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from torchrec.metrics.auprc import AUPRCMetric
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
from torchrec.metrics.calibration import CalibrationMetric
from torchrec.metrics.calibration_with_recalibration import (
RecalibratedCalibrationMetric,
)
from torchrec.metrics.ctr import CTRMetric
from torchrec.metrics.hindsight_target_pr import HindsightTargetPRMetric
from torchrec.metrics.mae import MAEMetric
Expand All @@ -46,6 +49,7 @@
from torchrec.metrics.ndcg import NDCGMetric
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.ne_positive import NEPositiveMetric
from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric
from torchrec.metrics.output import OutputMetric
from torchrec.metrics.precision import PrecisionMetric
from torchrec.metrics.precision_session import PrecisionSessionMetric
Expand All @@ -71,6 +75,8 @@
RecMetricEnum.NE: NEMetric,
RecMetricEnum.NE_POSITIVE: NEPositiveMetric,
RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric,
RecMetricEnum.RECALIBRATED_NE: RecalibratedNEMetric,
RecMetricEnum.RECALIBRATED_CALIBRATION: RecalibratedCalibrationMetric,
RecMetricEnum.CTR: CTRMetric,
RecMetricEnum.CALIBRATION: CalibrationMetric,
RecMetricEnum.AUC: AUCMetric,
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class RecMetricEnumBase(StrValueMixin, Enum):
class RecMetricEnum(RecMetricEnumBase):
NE = "ne"
NE_POSITIVE = "ne_positive"
RECALIBRATED_NE = "recalibrated_ne"
RECALIBRATED_CALIBRATION = "recalibrated_calibration"
SEGMENTED_NE = "segmented_ne"
LOG_LOSS = "log_loss"
CTR = "ctr"
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class MetricNamespace(MetricNamespaceBase):
NE = "ne"
NE_POSITIVE = "ne_positive"
SEGMENTED_NE = "segmented_ne"
RECALIBRATED_NE = "recalibrated_ne"
RECALIBRATED_CALIBRATION = "recalibrated_calibration"
THROUGHPUT = "throughput"
CTR = "ctr"
CALIBRATION = "calibration"
Expand Down
116 changes: 116 additions & 0 deletions torchrec/metrics/ne_with_recalibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricNamespace
from torchrec.metrics.ne import get_ne_states, NEMetricComputation
from torchrec.metrics.rec_metric import (
RecMetric,
RecMetricComputation,
RecMetricException,
)


class RecalibratedNEMetricComputation(NEMetricComputation):
r"""
This class implements the recalibration for NE that is required to correctly estimate eval NE if negative downsampling was used during training.
The constructor arguments are defined in RecMetricComputation.
See the docstring of RecMetricComputation for more detail.
Args:
include_logloss (bool): return vanilla logloss as one of metrics results, on top of NE.
"""

def __init__(
self,
*args: Any,
include_logloss: bool = False,
allow_missing_label_with_zero_weight: bool = False,
recalibration_coefficient: float = 1.0,
**kwargs: Any,
) -> None:
self._recalibration_coefficient: float = recalibration_coefficient
self._include_logloss: bool = include_logloss
self._allow_missing_label_with_zero_weight: bool = (
allow_missing_label_with_zero_weight
)
super().__init__(*args, **kwargs)
self._add_state(
"cross_entropy_sum",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"weighted_num_samples",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"pos_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"neg_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self.eta = 1e-12

def _recalibrate(
self,
predictions: torch.Tensor,
calibration_coef: Optional[torch.Tensor],
) -> torch.Tensor:
if calibration_coef is not None:
predictions = predictions / (
predictions + (1.0 - predictions) / calibration_coef
)
return predictions

def update(
self,
*,
predictions: Optional[torch.Tensor],
labels: torch.Tensor,
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' should not be None for RecalibratedNEMetricComputation update"
)

predictions = self._recalibrate(
predictions, self._recalibration_coefficient * torch.ones_like(predictions)
)
states = get_ne_states(labels, predictions, weights, self.eta)
num_samples = predictions.shape[-1]

for state_name, state_value in states.items():
state = getattr(self, state_name)
state += state_value
self._aggregate_window_state(state_name, state_value, num_samples)


class RecalibratedNEMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.RECALIBRATED_NE
_computation_class: Type[RecMetricComputation] = RecalibratedNEMetricComputation

0 comments on commit 0e92204

Please sign in to comment.