diff --git a/.github/workflows/build-wheels-linux.yml b/.github/workflows/build-wheels-linux.yml
index f709008e6..50dc5ccdc 100644
--- a/.github/workflows/build-wheels-linux.yml
+++ b/.github/workflows/build-wheels-linux.yml
@@ -33,7 +33,7 @@ jobs:
     steps:
       - uses: actions/setup-python@v4
       - name: Checkout torchrec repository
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
         with:
           repository: pytorch/torchrec
       - name: Filter Generated Built Matrix
diff --git a/.github/workflows/build_dynamic_embedding_wheels.yml b/.github/workflows/build_dynamic_embedding_wheels.yml
index 3c37d2aa0..a3d32d278 100644
--- a/.github/workflows/build_dynamic_embedding_wheels.yml
+++ b/.github/workflows/build_dynamic_embedding_wheels.yml
@@ -37,7 +37,7 @@ jobs:
         name: Check disk space
         run: df . -h

-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
         with:
           submodules: recursive

@@ -55,6 +55,6 @@ jobs:
         shell: bash

       - name: Upload wheels
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           path: wheelhouse/*.whl
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index bd5162a4b..360668d4b 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -23,7 +23,7 @@ jobs:
       - name: Check ldd --version
         run: ldd --version
       - name: Checkout
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
       # Update references
       - name: Update pip
         run: |
@@ -83,7 +83,7 @@ jobs:
           conda run -n build_binary make html
           cd ..
       - name: Upload Built-Docs
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: Built-Docs
           path: docs/build/html/
@@ -103,9 +103,9 @@ jobs:
     if: ${{ github.event_name == 'pull_request' }}
     steps:
       - name: Checkout
-        uses: actions/checkout@v3
+        uses: actions/checkout@v4
       - name: Download artifact
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           name: Built-Docs
           path: docs
diff --git a/.github/workflows/pre-commit.yaml b/.github/workflows/pre-commit.yaml
index d4358a0f0..85716982d 100644
--- a/.github/workflows/pre-commit.yaml
+++ b/.github/workflows/pre-commit.yaml
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
      - name: Setup Python
-       uses: actions/setup-python@v2
+       uses: actions/setup-python@v5
        with:
          python-version: 3.9
          architecture: x64
@@ -19,6 +19,6 @@ jobs:
          black==24.2.0
          usort==1.0.8
      - name: Checkout Torchrec
-       uses: actions/checkout@v2
+       uses: actions/checkout@v4
      - name: Run pre-commit
       uses: pre-commit/action@v2.0.3
diff --git a/.github/workflows/pyre.yml b/.github/workflows/pyre.yml
index ed25404e7..fd773c787 100644
--- a/.github/workflows/pyre.yml
+++ b/.github/workflows/pyre.yml
@@ -16,7 +16,7 @@ jobs:
         with:
           python-version: 3.9
       - name: Checkout Torchrec
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4
       - name: Install dependencies
         run: >
           pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu &&
diff --git a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
index 59bfd86c0..015d216e4 100644
--- a/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
+++ b/TorchRec_Interactive_Tutorial_Notebook_OSS_version.ipynb
@@ -1818,7 +1818,7 @@
    "source": [
     "### DistributedModelParallel\n",
     "\n",
-    "We have now explored sharding a single EmbeddingBagCollection! We were able to take the `EmbeddingBagCollectionSharder` and use the unsharded `EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection` module. This workflow is fine, but typically when doing model parallel, [`DistributedModelParallel`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel) (DMP) is used as the standard interface. When wrapping your model (in our case `ebc`), with DMP, the following will occur:\n",
+    "We have now explored sharding a single EmbeddingBagCollection! We were able to take the `EmbeddingBagCollectionSharder` and use the unsharded `EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection` module. This workflow is fine, but typically when doing model parallel, [`DistributedModelParallel`](https://pytorch.org/torchrec/model-parallel-api-reference.html#model-parallel) (DMP) is used as the standard interface. When wrapping your model (in our case `ebc`), with DMP, the following will occur:\n",
     "\n",
     "1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e, the EmbeddingBagCollection)\n",
     "2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).\n",
diff --git a/torchrec/metrics/ndcg.py b/torchrec/metrics/ndcg.py
index 61be6525d..718e208f2 100644
--- a/torchrec/metrics/ndcg.py
+++ b/torchrec/metrics/ndcg.py
@@ -160,12 +160,12 @@ def _get_ndcg_states(
     )

     # Expand these to be [num_task, num_sessions, batch_size] for masking to handle later.
-    expanded_sorted_labels_by_labels = sorted_labels_by_labels.expand(
-        (num_tasks, num_sessions, batch_size)
-    )
-    expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.expand(
+    expanded_sorted_labels_by_labels = sorted_labels_by_labels.unsqueeze(1).expand(
         (num_tasks, num_sessions, batch_size)
     )
+    expanded_sorted_labels_by_predictions = sorted_labels_by_predictions.unsqueeze(
+        1
+    ).expand((num_tasks, num_sessions, batch_size))

     # Make sure to correspondingly sort session IDs according to how we sorted labels above.
     session_ids_by_sorted_labels = torch.gather(
@@ -188,10 +188,10 @@

     # Figure out after sorting which example indices belong to which session.
     sorted_session_ids_by_labels_mask = (
-        task_to_session_to_examples == session_ids_by_sorted_labels
+        task_to_session_to_examples == session_ids_by_sorted_labels.unsqueeze(1)
     ).long()
     sorted_session_ids_by_predictions_mask = (
-        task_to_session_to_examples == session_ids_by_sorted_predictions
+        task_to_session_to_examples == session_ids_by_sorted_predictions.unsqueeze(1)
     ).long()

     # Get the ranks (1, N] for each example in each session for every task.
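A note on the torchrec/metrics/ndcg.py hunks above: Tensor.expand() aligns trailing dimensions, so a [num_tasks, batch_size] tensor cannot be expanded straight to [num_tasks, num_sessions, batch_size]; a singleton session dimension has to be inserted first, which is what the added unsqueeze(1) calls do. The sketch below (plain PyTorch with made-up sizes, not the metric code itself) illustrates that shape behavior; the same [num_tasks, 1, batch_size] vs [num_tasks, num_sessions, batch_size] broadcast is what the `== session_ids_by_sorted_labels.unsqueeze(1)` comparison in the second hunk relies on.

import torch

# Hypothetical sizes for illustration only.
num_tasks, num_sessions, batch_size = 3, 2, 8
sorted_labels = torch.zeros(num_tasks, batch_size)

# Fixed path: insert a session dim, then broadcast it across sessions.
expanded = sorted_labels.unsqueeze(1).expand(num_tasks, num_sessions, batch_size)
print(expanded.shape)  # torch.Size([3, 2, 8])

# Old path: expand() tries to match dim 0 (num_tasks=3) against num_sessions=2
# and raises for multi-task inputs.
try:
    sorted_labels.expand(num_tasks, num_sessions, batch_size)
except RuntimeError as err:
    print(f"expand without unsqueeze fails: {err}")

# A size-1 leading dim broadcasts fine, so the pre-fix expand() only succeeded in
# the single-task case (or when num_tasks happened to equal num_sessions).
print(torch.zeros(1, batch_size).expand(1, num_sessions, batch_size).shape)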
diff --git a/torchrec/metrics/tests/test_ndcg.py b/torchrec/metrics/tests/test_ndcg.py
index 88c227725..add23b885 100644
--- a/torchrec/metrics/tests/test_ndcg.py
+++ b/torchrec/metrics/tests/test_ndcg.py
@@ -13,7 +13,7 @@
 from typing import Any, Dict, List

 import torch
-from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torchrec.metrics.metrics_config import DefaultTaskInfo, RecComputeMode
 from torchrec.metrics.ndcg import NDCGMetric, SESSION_KEY
 from torchrec.metrics.test_utils import RecTaskInfo

@@ -22,6 +22,27 @@
 WORLD_SIZE = 4
 BATCH_SIZE = 10

+DefaultTaskInfo0 = RecTaskInfo(
+    name="DefaultTask0",
+    label_name="label",
+    prediction_name="prediction",
+    weight_name="weight",
+)
+
+DefaultTaskInfo1 = RecTaskInfo(
+    name="DefaultTask1",
+    label_name="label",
+    prediction_name="prediction",
+    weight_name="weight",
+)
+
+DefaultTaskInfo2 = RecTaskInfo(
+    name="DefaultTask2",
+    label_name="label",
+    prediction_name="prediction",
+    weight_name="weight",
+)
+

 def get_test_case_single_session_within_batch() -> Dict[str, torch.Tensor]:
     return {
@@ -117,6 +138,41 @@ def get_test_case_scale_by_weights_tensor() -> Dict[str, torch.Tensor]:
     }


+def get_test_case_multitask() -> Dict[str, torch.Tensor]:
+    return {
+        "predictions": torch.tensor(
+            [
+                [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3],
+                [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3],
+                [0.1, 0.2, 0.3, 0.4, 0.5, 0.1, 0.2, 0.3],
+            ]
+        ),
+        "session_ids": torch.tensor(
+            [
+                [1, 1, 1, 1, 1, 2, 2, 2],
+                [1, 1, 1, 1, 1, 2, 2, 2],
+                [1, 1, 1, 1, 1, 2, 2, 2],
+            ]
+        ),
+        "labels": torch.tensor(
+            [
+                [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0, 2.0, 2.0, 1.0, 0.0],
+            ]
+        ),
+        "weights": torch.tensor(
+            [
+                [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0],
+                [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0],
+                [1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 3.0],
+            ]
+        ),
+        "expected_ndcg_exp": torch.tensor([0.6748, 0.6748, 0.6748]),
+        "expected_ndcg_non_exp": torch.tensor([0.6463, 0.6463, 0.6463]),
+    }
+
+
 class NDCGMetricValueTest(unittest.TestCase):
     def generate_metric(
         self,
@@ -130,6 +186,7 @@
         remove_single_length_sessions: bool = False,
         scale_by_weights_tensor: bool = False,
         report_ndcg_as_decreasing_curve: bool = True,
+        compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
         **kwargs: Dict[str, Any],
     ) -> NDCGMetric:
         return NDCGMetric(
@@ -149,6 +206,7 @@
             report_ndcg_as_decreasing_curve=report_ndcg_as_decreasing_curve,
             # pyre-ignore[6]
             k=k,
+            compute_mode=compute_mode,
             # pyre-ignore[6]
             **kwargs,
         )
@@ -565,3 +623,83 @@ def test_case_report_as_increasing_ndcg_and_scale_by_weights_tensor(self) -> Non
             equal_nan=True,
             msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
         )
+
+    def test_multitask_non_exp(self) -> None:
+        """
+        Test NDCG with multiple tasks.
+        """
+        model_output = get_test_case_multitask()
+        metric = self.generate_metric(
+            world_size=WORLD_SIZE,
+            my_rank=0,
+            batch_size=BATCH_SIZE,
+            tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2],
+            exponential_gain=False,
+            session_key=SESSION_KEY,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        metric.update(
+            predictions=model_output["predictions"],
+            labels=model_output["labels"],
+            weights=model_output["weights"],
+            required_inputs={SESSION_KEY: model_output["session_ids"]},
+        )
+        output = metric.compute()
+        actual_metric = torch.stack(
+            [
+                output[f"ndcg-{task.name}|lifetime_ndcg"]
+                for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2]
+            ]
+        )
+        expected_metric = model_output["expected_ndcg_non_exp"]
+
+        torch.testing.assert_close(
+            actual_metric,
+            expected_metric,
+            atol=1e-4,
+            rtol=1e-4,
+            check_dtype=False,
+            equal_nan=True,
+            msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
+        )
+
+    def test_multitask_exp(self) -> None:
+        """
+        Test NDCG with multiple tasks.
+        """
+        model_output = get_test_case_multitask()
+        metric = self.generate_metric(
+            world_size=WORLD_SIZE,
+            my_rank=0,
+            batch_size=BATCH_SIZE,
+            tasks=[DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2],
+            exponential_gain=True,
+            session_key=SESSION_KEY,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        metric.update(
+            predictions=model_output["predictions"],
+            labels=model_output["labels"],
+            weights=model_output["weights"],
+            required_inputs={SESSION_KEY: model_output["session_ids"]},
+        )
+        output = metric.compute()
+        actual_metric = torch.stack(
+            [
+                output[f"ndcg-{task.name}|lifetime_ndcg"]
+                for task in [DefaultTaskInfo0, DefaultTaskInfo1, DefaultTaskInfo2]
+            ]
+        )
+        expected_metric = model_output["expected_ndcg_exp"]
+
+        torch.testing.assert_close(
+            actual_metric,
+            expected_metric,
+            atol=1e-4,
+            rtol=1e-4,
+            check_dtype=False,
+            equal_nan=True,
+            msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
+        )