From 6e3d2960508622973dff1da16fb60f84e65cddb0 Mon Sep 17 00:00:00 2001 From: Nam Khanh Pham Date: Tue, 4 Feb 2025 08:48:03 -0800 Subject: [PATCH] fix ndcg metrics in multitasks cases (#2720) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2720 Runs into some errors when using NDCG metrics with multiple tasks. This PR - adds unit tests for the multitask cases - fixes the ndcg implementations for multitask cases Reviewed By: iamzainhuda, venkatrsrinivas Differential Revision: D69057771 fbshipit-source-id: 7d33432f663d933bab17bce6e660d5350962d953 --- torchrec/metrics/ndcg.py | 12 +-- torchrec/metrics/tests/test_ndcg.py | 140 +++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/torchrec/metrics/ndcg.py b/torchrec/metrics/ndcg.py index 61be6525d..718e208f2 100644 --- a/torchrec/metrics/ndcg.py +++ b/torchrec/metrics/ndcg.py @@ -160,12 +160,12 @@ def _get_ndcg_states( ) # Expand these to be [num_task, num_sessions, batch_size] for masking to handle later. - expanded_sorted_labels_by_labels = sorted_labels_by_labels.expand( - (num_tasks, num_sessions, batch_size) - ) - expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.expand( + expanded_sorted_labels_by_labels = sorted_labels_by_labels.unsqueeze(1).expand( (num_tasks, num_sessions, batch_size) ) + expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.unsqueeze( + 1 + ).expand((num_tasks, num_sessions, batch_size)) # Make sure to correspondingly sort session IDs according to how we sorted labels above. session_ids_by_sorted_labels = torch.gather( @@ -188,10 +188,10 @@ def _get_ndcg_states( # Figure out after sorting which example indices belong to which session. sorted_session_ids_by_labels_mask = ( - task_to_session_to_examples == session_ids_by_sorted_labels + task_to_session_to_examples == session_ids_by_sorted_labels.unsqueeze(1) ).long() sorted_session_ids_by_predictions_mask = ( - task_to_session_to_examples == session_ids_by_sorted_predictions + task_to_session_to_examples == session_ids_by_sorted_predictions.unsqueeze(1) ).long() # Get the ranks (1, N] for each example in each session for every task. diff --git a/torchrec/metrics/tests/test_ndcg.py b/torchrec/metrics/tests/test_ndcg.py index 88c227725..add23b885 100644 --- a/torchrec/metrics/tests/test_ndcg.py +++ b/torchrec/metrics/tests/test_ndcg.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List import torch -from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.metrics_config import DefaultTaskInfo, RecComputeMode from torchrec.metrics.ndcg import NDCGMetric, SESSION_KEY from torchrec.metrics.test_utils import RecTaskInfo @@ -22,6 +22,27 @@ WORLD_SIZE = 4 BATCH_SIZE = 10 +DefaultTaskInfo0 = RecTaskInfo( + name="DefaultTask0", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + +DefaultTaskInfo1 = RecTaskInfo( + name="DefaultTask1", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + +DefaultTaskInfo2 = RecTaskInfo( + name="DefaultTask2", + label_name="label", + prediction_name="prediction", + weight_name="weight", +) + def get_test_case_single_session_within_batch() -> Dict[str, torch.Tensor]: return { @@ -117,6 +138,41 @@ def get_test_case_scale_by_weights_tensor() -> Dict[str, torch.Tensor]: } +def get_test_case_multitask() -> Dict[str, torch.Tensor]: + return { + "predictions": torch.tensor( + [ + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3], + ] + ), + "session_ids": torch.tensor( + [ + [1, 1, 1, 1, 1, 2, 2, 2], + [1, 1, 1, 1, 1, 2, 2, 2], + [1, 1, 1, 1, 1, 2, 2, 2], + ] + ), + "labels": torch.tensor( + [ + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0], + ] + ), + "weights": torch.tensor( + [ + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0], + ] + ), + "expected_ndcg_exp": torch.tensor([0.6748, 0.6748, 0.6748]), + "expected_ndcg_non_exp": torch.tensor([0.6463, 0.6463, 0.6463]), + } + + class NDCGMetricValueTest(unittest.TestCase): def generate_metric( self, @@ -130,6 +186,7 @@ def generate_metric( remove_single_length_sessions: bool = False, scale_by_weights_tensor: bool = False, report_ndcg_as_decreasing_curve: bool = True, + compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION, **kwargs: Dict[str, Any], ) -> NDCGMetric: return NDCGMetric( @@ -149,6 +206,7 @@ def generate_metric( report_ndcg_as_decreasing_curve=report_ndcg_as_decreasing_curve, # pyre-ignore[6] k=k, + compute_mode=compute_mode, # pyre-ignore[6] **kwargs, ) @@ -565,3 +623,83 @@ def test_case_report_as_increasing_ndcg_and_scale_by_weights_tensor(self) -> Non equal_nan=True, msg=f"Actual: {actual_metric}, Expected: {expected_metric}", ) + + def test_multitask_non_exp(self) -> None: + """ + Test NDCG with multiple tasks. + """ + model_output = get_test_case_multitask() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], + exponential_gain=False, + session_key=SESSION_KEY, + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + metric.update( + predictions=model_output["predictions"], + labels=model_output["labels"], + weights=model_output["weights"], + required_inputs={SESSION_KEY: model_output["session_ids"]}, + ) + output = metric.compute() + actual_metric = torch.stack( + [ + output[f"ndcg-{task.name}|lifetime_ndcg"] + for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] + ] + ) + expected_metric = model_output["expected_ndcg_non_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + ) + + def test_multitask_exp(self) -> None: + """ + Test NDCG with multiple tasks. + """ + model_output = get_test_case_multitask() + metric = self.generate_metric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2], + exponential_gain=True, + session_key=SESSION_KEY, + compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + ) + + metric.update( + predictions=model_output["predictions"], + labels=model_output["labels"], + weights=model_output["weights"], + required_inputs={SESSION_KEY: model_output["session_ids"]}, + ) + output = metric.compute() + actual_metric = torch.stack( + [ + output[f"ndcg-{task.name}|lifetime_ndcg"] + for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2] + ] + ) + expected_metric = model_output["expected_ndcg_exp"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + )