From 61a849d6dabd916cb476fcefd00e729297630d5d Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 22 Jan 2025 11:34:53 +0000 Subject: [PATCH] 2025-01-22 nightly release (dd5457c43b0d4f9697b7b84a70ea4f97cfbfd6ad) --- .github/workflows/unittest_ci.yml | 3 + .github/workflows/unittest_ci_cpu.yml | 3 + torchrec/distributed/embedding.py | 13 +- torchrec/distributed/embedding_sharding.py | 5 +- torchrec/distributed/embeddingbag.py | 16 ++- .../distributed/test_utils/multi_process.py | 19 +-- .../distributed/test_utils/test_sharding.py | 32 ++++- .../tests/test_sequence_model_parallel.py | 41 +++++++ torchrec/distributed/tests/test_utils.py | 113 +++--------------- .../tests/pipeline_benchmarks.py | 4 +- torchrec/modules/embedding_modules.py | 10 +- torchrec/sparse/tests/test_tensor_dict.py | 22 ++-- 12 files changed, 147 insertions(+), 134 deletions(-) diff --git a/.github/workflows/unittest_ci.yml b/.github/workflows/unittest_ci.yml index 1f0b18dcd..89aa2768f 100644 --- a/.github/workflows/unittest_ci.yml +++ b/.github/workflows/unittest_ci.yml @@ -64,6 +64,9 @@ jobs: python-tag: "py312" cuda-tag: "cu124" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read with: runner: ${{ matrix.os }} timeout: 30 diff --git a/.github/workflows/unittest_ci_cpu.yml b/.github/workflows/unittest_ci_cpu.yml index 4c92e6027..a0b0ef9eb 100644 --- a/.github/workflows/unittest_ci_cpu.yml +++ b/.github/workflows/unittest_ci_cpu.yml @@ -36,6 +36,9 @@ jobs: python-version: '3.12' python-tag: "py312" uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + permissions: + id-token: write + contents: read with: runner: ${{ matrix.os }} timeout: 15 diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 93773cc1f..feb77a72a 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -26,6 +26,7 @@ ) import torch +from tensordict import TensorDict from torch import distributed as dist, nn from torch.autograd.profiler import record_function from torch.distributed._shard.sharding_spec import EnumerableShardingSpec @@ -90,6 +91,7 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -1198,8 +1200,15 @@ def _compute_sequence_vbe_context( def input_dist( self, ctx: EmbeddingCollectionContext, - features: KeyedJaggedTensor, + features: TypeUnion[KeyedJaggedTensor, TensorDict], ) -> Awaitable[Awaitable[KJTList]]: + need_permute: bool = True + if isinstance(features, TensorDict): + feature_keys = list(features.keys()) # pyre-ignore[6] + if self._features_order: + feature_keys = [feature_keys[i] for i in self._features_order] + need_permute = False + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] if self._has_uninitialized_input_dist: self._create_input_dist(input_feature_names=features.keys()) self._has_uninitialized_input_dist = False @@ -1209,7 +1218,7 @@ def input_dist( unpadded_features = features features = pad_vbe_kjt_lengths(unpadded_features) - if self._features_order: + if need_permute and self._features_order: features = features.permute( self._features_order, # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` diff --git 
a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index 38bb0dd4b..0f37e71a1 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -274,7 +274,10 @@ def bucketize_kjt_before_all2all( batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt), max_B=_fx_wrap_max_B(kjt), block_bucketize_pos=( - _fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths()) + [ + _fx_wrap_tensor_to_device_dtype(pos, kjt.values()) + for pos in block_bucketize_row_pos + ] if block_bucketize_row_pos is not None else None ), diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 8cfd16ae9..de3d495f2 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -27,6 +27,7 @@ import torch from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings +from tensordict import TensorDict from torch import distributed as dist, nn, Tensor from torch.autograd.profiler import record_function from torch.distributed._shard.sharded_tensor import TensorProperties @@ -94,6 +95,7 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -656,9 +658,7 @@ def __init__( self._inverse_indices_permute_indices: Optional[torch.Tensor] = None # to support mean pooling callback hook self._has_mean_pooling_callback: bool = ( - True - if PoolingType.MEAN.value in self._pooling_type_to_rs_features - else False + PoolingType.MEAN.value in self._pooling_type_to_rs_features ) self._dim_per_key: Optional[torch.Tensor] = None self._kjt_key_indices: Dict[str, int] = {} @@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices( # pyre-ignore [14] def input_dist( - self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor + self, + ctx: EmbeddingBagCollectionContext, + features: Union[KeyedJaggedTensor, TensorDict], ) -> Awaitable[Awaitable[KJTList]]: + if isinstance(features, TensorDict): + feature_keys = list(features.keys()) # pyre-ignore[6] + if len(self._features_order) > 0: + feature_keys = [feature_keys[i] for i in self._features_order] + self._has_features_permute = False # feature_keys are in order + features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] ctx.variable_batch_per_feature = features.variable_stride_per_key() ctx.inverse_indices = features.inverse_indices_or_none() if self._has_uninitialized_input_dist: diff --git a/torchrec/distributed/test_utils/multi_process.py b/torchrec/distributed/test_utils/multi_process.py index f3233e9b0..ac003d02b 100644 --- a/torchrec/distributed/test_utils/multi_process.py +++ b/torchrec/distributed/test_utils/multi_process.py @@ -9,6 +9,7 @@ #!/usr/bin/env python3 +import logging import multiprocessing import os import unittest @@ -24,11 +25,6 @@ ) -# AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail -# Therefore we use spawn for HIP runtime until AMD fixes the issue -_MP_INIT_MODE = "forkserver" if torch.version.hip is None else "spawn" - - class MultiProcessContext: def __init__( self, @@ -98,6 +94,15 @@ def __exit__(self, exc_type, exc_instance, traceback) -> None: class MultiProcessTestBase(unittest.TestCase): + def __init__( + self, 
methodName: str = "runTest", mp_init_mode: str = "forkserver" + ) -> None: + super().__init__(methodName) + + # AMD's HIP runtime doesn't seem to work with forkserver; hipMalloc will fail + # Therefore we use spawn for HIP runtime until AMD fixes the issue + self._mp_init_mode: str = mp_init_mode if torch.version.hip is None else "spawn" + logging.info(f"Using {self._mp_init_mode} for multiprocessing") @seed_and_log def setUp(self) -> None: @@ -131,7 +136,7 @@ def _run_multi_process_test( # pyre-ignore **kwargs, ) -> None: - ctx = multiprocessing.get_context(_MP_INIT_MODE) + ctx = multiprocessing.get_context(self._mp_init_mode) processes = [] for rank in range(world_size): kwargs["rank"] = rank @@ -157,7 +162,7 @@ def _run_multi_process_test_per_rank( world_size: int, kwargs_per_rank: List[Dict[str, Any]], ) -> None: - ctx = multiprocessing.get_context(_MP_INIT_MODE) + ctx = multiprocessing.get_context(self._mp_init_mode) processes = [] for rank in range(world_size): kwargs = {} diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index f2b65a833..48b9a90ab 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -147,6 +147,7 @@ def gen_model_and_input( long_indices: bool = True, global_constant_batch: bool = False, num_inputs: int = 1, + input_type: str = "kjt", # "kjt" or "td" ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: torch.manual_seed(0) if dedup_feature_names: @@ -177,9 +178,9 @@ def gen_model_and_input( feature_processor_modules=feature_processor_modules, ) inputs = [] - for _ in range(num_inputs): - inputs.append( - ( + if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input: + for _ in range(num_inputs): + inputs.append( cast(VariableBatchModelInputCallable, generate)( average_batch_size=batch_size, world_size=world_size, @@ -188,8 +189,26 @@ def gen_model_and_input( weighted_tables=weighted_tables or [], global_constant_batch=global_constant_batch, ) - if generate == ModelInput.generate_variable_batch_input - else cast(ModelInputCallable, generate)( + ) + elif generate == ModelInput.generate: + for _ in range(num_inputs): + inputs.append( + ModelInput.generate( + world_size=world_size, + tables=tables, + dedup_tables=dedup_tables, + weighted_tables=weighted_tables or [], + num_float_features=num_float_features, + variable_batch_size=variable_batch_size, + batch_size=batch_size, + long_indices=long_indices, + input_type=input_type, + ) + ) + else: + for _ in range(num_inputs): + inputs.append( + cast(ModelInputCallable, generate)( world_size=world_size, tables=tables, dedup_tables=dedup_tables, @@ -200,7 +219,6 @@ def gen_model_and_input( long_indices=long_indices, ) ) - ) return (model, inputs) @@ -297,6 +315,7 @@ def sharding_single_rank_test( global_constant_batch: bool = False, world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, + input_type: str = "kjt", # "kjt" or "td" ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: # Generate model & inputs. 
@@ -319,6 +338,7 @@ def sharding_single_rank_test( batch_size=batch_size, feature_processor_modules=feature_processor_modules, global_constant_batch=global_constant_batch, + input_type=input_type, ) global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index aec092354..d13d819c3 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -376,3 +376,44 @@ def _test_sharding( variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=True, ) + + +@skip_if_asan_class +class TDSequenceModelParallelTest(SequenceModelParallelTest): + + def test_sharding_variable_batch(self) -> None: + pass + + def _test_sharding( + self, + sharders: List[TestEmbeddingCollectionSharder], + backend: str = "gloo", + world_size: int = 2, + local_size: Optional[int] = None, + constraints: Optional[Dict[str, ParameterConstraints]] = None, + model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, + qcomms_config: Optional[QCommsConfig] = None, + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ] = None, + variable_batch_size: bool = False, + variable_batch_per_feature: bool = False, + ) -> None: + self._run_multi_process_test( + callable=sharding_single_rank_test, + world_size=world_size, + local_size=local_size, + model_class=model_class, + tables=self.tables, + embedding_groups=self.embedding_groups, + sharders=sharders, + optim=EmbOptimType.EXACT_SGD, + backend=backend, + constraints=constraints, + qcomms_config=qcomms_config, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + variable_batch_per_feature=variable_batch_per_feature, + global_constant_batch=True, + input_type="td", + ) diff --git a/torchrec/distributed/tests/test_utils.py b/torchrec/distributed/tests/test_utils.py index 3e299192e..bdffcf7a0 100644 --- a/torchrec/distributed/tests/test_utils.py +++ b/torchrec/distributed/tests/test_utils.py @@ -263,98 +263,6 @@ def block_bucketize_ref( class KJTBucketizeTest(unittest.TestCase): - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", - ) - # pyre-ignore[56] - @given( - index_type=st.sampled_from([torch.int, torch.long]), - offset_type=st.sampled_from([torch.int, torch.long]), - world_size=st.integers(1, 129), - num_features=st.integers(1, 15), - batch_size=st.integers(1, 15), - ) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) - def test_kjt_bucketize_before_all2all( - self, - index_type: torch.dtype, - offset_type: torch.dtype, - world_size: int, - num_features: int, - batch_size: int, - ) -> None: - MAX_BATCH_SIZE = 15 - MAX_LENGTH = 10 - # max number of rows needed for a given feature to have unique row index - MAX_ROW_COUNT = MAX_LENGTH * MAX_BATCH_SIZE - - lengths_list = [ - random.randrange(MAX_LENGTH + 1) for _ in range(num_features * batch_size) - ] - keys_list = [f"feature_{i}" for i in range(num_features)] - # for each feature, generate unrepeated row indices - indices_lists = [ - random.sample( - range(MAX_ROW_COUNT), - # number of indices needed is the length sum of all batches for a feature - sum( - lengths_list[ - feature_offset * batch_size : (feature_offset + 1) * batch_size - ] - ), - ) - for feature_offset in range(num_features) - ] - indices_list 
= list(itertools.chain(*indices_lists)) - - weights_list = [random.randint(1, 100) for _ in range(len(indices_list))] - - # for each feature, calculate the minimum block size needed to - # distribute all rows to the available trainers - block_sizes_list = [ - ( - math.ceil((max(feature_indices_list) + 1) / world_size) - if feature_indices_list - else 1 - ) - for feature_indices_list in indices_lists - ] - - kjt = KeyedJaggedTensor( - keys=keys_list, - lengths=torch.tensor(lengths_list, dtype=offset_type) - .view(num_features * batch_size) - .cuda(), - values=torch.tensor(indices_list, dtype=index_type).cuda(), - weights=torch.tensor(weights_list, dtype=torch.float).cuda(), - ) - """ - each entry in block_sizes identifies how many hashes for each feature goes - to every rank; we have three featues in `self.features` - """ - block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda() - - block_bucketized_kjt, _ = bucketize_kjt_before_all2all( - kjt=kjt, - num_buckets=world_size, - block_sizes=block_sizes, - ) - - expected_block_bucketized_kjt = block_bucketize_ref( - kjt, - world_size, - block_sizes, - ) - - self.assertTrue( - keyed_jagged_tensor_equals( - block_bucketized_kjt, - expected_block_bucketized_kjt, - is_pooled_features=True, - ) - ) - # pyre-ignore[56] @given( index_type=st.sampled_from([torch.int, torch.long]), @@ -363,9 +271,12 @@ def test_kjt_bucketize_before_all2all( num_features=st.integers(1, 15), batch_size=st.integers(1, 15), variable_bucket_pos=st.booleans(), + device=st.sampled_from( + ["cpu"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ), ) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) - def test_kjt_bucketize_before_all2all_cpu( + @settings(verbosity=Verbosity.verbose, max_examples=50, deadline=None) + def test_kjt_bucketize_before_all2all( self, index_type: torch.dtype, offset_type: torch.dtype, @@ -373,6 +284,7 @@ def test_kjt_bucketize_before_all2all_cpu( num_features: int, batch_size: int, variable_bucket_pos: bool, + device: str, ) -> None: MAX_BATCH_SIZE = 15 MAX_LENGTH = 10 @@ -423,17 +335,17 @@ def test_kjt_bucketize_before_all2all_cpu( kjt = KeyedJaggedTensor( keys=keys_list, - lengths=torch.tensor(lengths_list, dtype=offset_type).view( + lengths=torch.tensor(lengths_list, dtype=offset_type, device=device).view( num_features * batch_size ), - values=torch.tensor(indices_list, dtype=index_type), - weights=torch.tensor(weights_list, dtype=torch.float), + values=torch.tensor(indices_list, dtype=index_type, device=device), + weights=torch.tensor(weights_list, dtype=torch.float, device=device), ) """ each entry in block_sizes identifies how many hashes for each feature goes to every rank; we have three featues in `self.features` """ - block_sizes = torch.tensor(block_sizes_list, dtype=index_type) + block_sizes = torch.tensor(block_sizes_list, dtype=index_type, device=device) block_bucketized_kjt, _ = bucketize_kjt_before_all2all( kjt=kjt, num_buckets=world_size, @@ -442,7 +354,10 @@ def test_kjt_bucketize_before_all2all_cpu( ) expected_block_bucketized_kjt = block_bucketize_ref( - kjt, world_size, block_sizes, "cpu" + kjt, + world_size, + block_sizes, + device, ) self.assertTrue( diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index e8dc5eccb..fdb900fe0 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -160,7 +160,7 @@ 
def main( tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 1000, + num_embeddings=max(i + 1, 100) * 1000, embedding_dim=dim_emb, name="table_" + str(i), feature_names=["feature_" + str(i)], @@ -169,7 +169,7 @@ def main( ] weighted_tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 1000, + num_embeddings=max(i + 1, 100) * 1000, embedding_dim=dim_emb, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 307d66639..d110fd57f 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -19,6 +19,7 @@ pooling_type_to_str, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt @torch.fx.wrap @@ -218,7 +219,10 @@ def __init__( self._feature_names: List[List[str]] = [table.feature_names for table in tables] self.reset_parameters() - def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: + def forward( + self, + features: KeyedJaggedTensor, # can also take TensorDict as input + ) -> KeyedTensor: """ Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature. @@ -229,6 +233,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: KeyedTensor """ flat_feature_names: List[str] = [] + features = maybe_td_to_kjt(features, None) for names in self._feature_names: flat_feature_names.extend(names) inverse_indices = reorder_inverse_indices( @@ -448,7 +453,7 @@ def __init__( # noqa C901 def forward( self, - features: KeyedJaggedTensor, + features: KeyedJaggedTensor, # can also take TensorDict as input ) -> Dict[str, JaggedTensor]: """ Run the EmbeddingBagCollection forward pass. 
This method takes in a `KeyedJaggedTensor` @@ -461,6 +466,7 @@ def forward( Dict[str, JaggedTensor] """ + features = maybe_td_to_kjt(features, None) feature_embeddings: Dict[str, JaggedTensor] = {} jt_dict: Dict[str, JaggedTensor] = features.to_dict() for i, emb_module in enumerate(self.embeddings.values()): diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py index d243fc255..2fbcc0a66 100644 --- a/torchrec/sparse/tests/test_tensor_dict.py +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -17,14 +17,14 @@ from torchrec.sparse.tensor_dict import maybe_td_to_kjt -class TestTensorDIct(unittest.TestCase): - @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) +class TestTensorDict(unittest.TestCase): # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", + @given( + device_str=st.sampled_from( + ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ) ) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) def test_kjt_input(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) @@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None: features = maybe_td_to_kjt(kjt) self.assertEqual(features, kjt) - @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) - @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) # pyre-ignore[56] - @unittest.skipIf( - torch.cuda.device_count() <= 0, - "CUDA is not available", + @given( + device_str=st.sampled_from( + ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else []) + ) ) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) def test_td_kjt(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
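Usage sketch for the TensorDict input path introduced above (torchrec.sparse.tensor_dict.maybe_td_to_kjt plus the widened forward/input_dist signatures). This is illustrative only and not part of the patch: the per-feature jagged NestedTensor layout, the table name "t1", and the feature names "f1"/"f2" are assumptions modeled on the new test_tensor_dict.py tests.

    import torch
    from tensordict import TensorDict

    from torchrec.modules.embedding_configs import EmbeddingBagConfig
    from torchrec.modules.embedding_modules import EmbeddingBagCollection
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
    from torchrec.sparse.tensor_dict import maybe_td_to_kjt

    # A KeyedJaggedTensor is returned as-is by the helper.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1", "f2"],
        values=torch.tensor([0, 1, 2, 3]),
        lengths=torch.tensor([2, 0, 1, 1]),
    )
    same = maybe_td_to_kjt(kjt)

    # A TensorDict of jagged features (assumed layout: one NestedTensor per key)
    # can be converted explicitly with an ordered key list ...
    td = TensorDict(
        {
            "f1": torch.nested.nested_tensor_from_jagged(
                torch.tensor([0, 1]), offsets=torch.tensor([0, 2, 2])
            ),
            "f2": torch.nested.nested_tensor_from_jagged(
                torch.tensor([2, 3]), offsets=torch.tensor([0, 1, 2])
            ),
        },
        batch_size=[2],
    )
    features = maybe_td_to_kjt(td, ["f1", "f2"])  # -> KeyedJaggedTensor

    # ... or passed straight to the unsharded module, whose forward now calls
    # maybe_td_to_kjt internally (see the embedding_modules.py hunks above).
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1",
                embedding_dim=8,
                num_embeddings=16,
                feature_names=["f1", "f2"],
            )
        ]
    )
    pooled = ebc(td)  # KeyedTensor, equivalent to ebc(features)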