From b227f54dc9b6262cc2f7c210d7c494d0ff1df21c Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 23 Jan 2025 11:34:45 +0000 Subject: [PATCH] 2025-01-23 nightly release (4d7b7ffbb5ad0f3548d77907d6e41c0a1639f660) --- .github/scripts/install_fbgemm.sh | 4 ++ torchrec/distributed/embedding.py | 13 +----- torchrec/distributed/embeddingbag.py | 16 ++------ torchrec/distributed/model_parallel.py | 35 ++++------------ .../distributed/test_utils/test_sharding.py | 32 +++------------ .../tests/test_sequence_model_parallel.py | 41 ------------------- .../tests/pipeline_benchmarks.py | 4 +- torchrec/modules/embedding_modules.py | 10 +---- torchrec/schema/utils.py | 30 +++++++++++++- 9 files changed, 58 insertions(+), 127 deletions(-) diff --git a/.github/scripts/install_fbgemm.sh b/.github/scripts/install_fbgemm.sh index 0b0b3e347..dc643deaa 100644 --- a/.github/scripts/install_fbgemm.sh +++ b/.github/scripts/install_fbgemm.sh @@ -15,6 +15,10 @@ if [[ $CU_VERSION = cu* ]]; then echo "[NOVA] Setting LD_LIBRARY_PATH ..." conda env config vars set -p ${CONDA_ENV} \ LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}" +else + echo "[NOVA] Setting LD_LIBRARY_PATH ..." + conda env config vars set -p ${CONDA_ENV} \ + LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}" fi if [ "$CHANNEL" = "nightly" ]; then diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index feb77a72a..93773cc1f 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -26,7 +26,6 @@ ) import torch -from tensordict import TensorDict from torch import distributed as dist, nn from torch.autograd.profiler import record_function from torch.distributed._shard.sharding_spec import EnumerableShardingSpec @@ -91,7 +90,6 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor -from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context( def input_dist( self, ctx: EmbeddingCollectionContext, - features: TypeUnion[KeyedJaggedTensor, TensorDict], + features: KeyedJaggedTensor, ) -> Awaitable[Awaitable[KJTList]]: - need_permute: bool = True - if isinstance(features, TensorDict): - feature_keys = list(features.keys()) # pyre-ignore[6] - if self._features_order: - feature_keys = [feature_keys[i] for i in self._features_order] - need_permute = False - features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] if self._has_uninitialized_input_dist: self._create_input_dist(input_feature_names=features.keys()) self._has_uninitialized_input_dist = False @@ -1218,7 +1209,7 @@ def input_dist( unpadded_features = features features = pad_vbe_kjt_lengths(unpadded_features) - if need_permute and self._features_order: + if self._features_order: features = features.permute( self._features_order, # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index de3d495f2..8cfd16ae9 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -27,7 +27,6 @@ import torch from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings -from tensordict import TensorDict from torch import distributed as dist, 
nn, Tensor from torch.autograd.profiler import record_function from torch.distributed._shard.sharded_tensor import TensorProperties @@ -95,7 +94,6 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor -from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -658,7 +656,9 @@ def __init__( self._inverse_indices_permute_indices: Optional[torch.Tensor] = None # to support mean pooling callback hook self._has_mean_pooling_callback: bool = ( - PoolingType.MEAN.value in self._pooling_type_to_rs_features + True + if PoolingType.MEAN.value in self._pooling_type_to_rs_features + else False ) self._dim_per_key: Optional[torch.Tensor] = None self._kjt_key_indices: Dict[str, int] = {} @@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices( # pyre-ignore [14] def input_dist( - self, - ctx: EmbeddingBagCollectionContext, - features: Union[KeyedJaggedTensor, TensorDict], + self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor ) -> Awaitable[Awaitable[KJTList]]: - if isinstance(features, TensorDict): - feature_keys = list(features.keys()) # pyre-ignore[6] - if len(self._features_order) > 0: - feature_keys = [feature_keys[i] for i in self._features_order] - self._has_features_permute = False # feature_keys are in order - features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] ctx.variable_batch_per_feature = features.variable_stride_per_key() ctx.inverse_indices = features.inverse_indices_or_none() if self._has_uninitialized_input_dist: diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 5cbd2429b..f8b32106b 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -770,7 +770,7 @@ def _create_process_groups( ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: """ Creates process groups for sharding and replication, the process groups - are created in the same exact order on all ranks as per `dist.new_group` API. + are created using the DeviceMesh API. Args: global_rank (int): The global rank of the current process. @@ -781,37 +781,12 @@ def _create_process_groups( Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh, replication process group, and allreduce process group. """ - # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a peer_matrix = [] - sharding_pg, replica_pg = None, None step = world_size // local_size - my_group_rank = global_rank % step for group_rank in range(world_size // local_size): peers = [step * r + group_rank for r in range(local_size)] - backend = dist.get_backend(self._pg) - curr_pg = dist.new_group(backend=backend, ranks=peers) peer_matrix.append(peers) - if my_group_rank == group_rank: - logger.warning( - f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]" - ) - sharding_pg = curr_pg - assert sharding_pg is not None, "sharding_pg is not initialized!" 
- dist.barrier() - - my_inter_rank = global_rank // step - for inter_rank in range(local_size): - peers = [inter_rank * step + r for r in range(step)] - backend = dist.get_backend(self._pg) - curr_pg = dist.new_group(backend=backend, ranks=peers) - if my_inter_rank == inter_rank: - logger.warning( - f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]" - ) - replica_pg = curr_pg - assert replica_pg is not None, "replica_pg is not initialized!" - dist.barrier() mesh = DeviceMesh( device_type=self._device.type, @@ -819,6 +794,14 @@ def _create_process_groups( mesh_dim_names=("replicate", "shard"), ) logger.warning(f"[Connection] 2D Device Mesh created: {mesh}") + sharding_pg = mesh.get_group(mesh_dim="shard") + logger.warning( + f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]" + ) + replica_pg = mesh.get_group(mesh_dim="replicate") + logger.warning( + f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]" + ) return mesh, sharding_pg, replica_pg diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 48b9a90ab..f2b65a833 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -147,7 +147,6 @@ def gen_model_and_input( long_indices: bool = True, global_constant_batch: bool = False, num_inputs: int = 1, - input_type: str = "kjt", # "kjt" or "td" ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]: torch.manual_seed(0) if dedup_feature_names: @@ -178,9 +177,9 @@ def gen_model_and_input( feature_processor_modules=feature_processor_modules, ) inputs = [] - if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input: - for _ in range(num_inputs): - inputs.append( + for _ in range(num_inputs): + inputs.append( + ( cast(VariableBatchModelInputCallable, generate)( average_batch_size=batch_size, world_size=world_size, @@ -189,26 +188,8 @@ def gen_model_and_input( weighted_tables=weighted_tables or [], global_constant_batch=global_constant_batch, ) - ) - elif generate == ModelInput.generate: - for _ in range(num_inputs): - inputs.append( - ModelInput.generate( - world_size=world_size, - tables=tables, - dedup_tables=dedup_tables, - weighted_tables=weighted_tables or [], - num_float_features=num_float_features, - variable_batch_size=variable_batch_size, - batch_size=batch_size, - long_indices=long_indices, - input_type=input_type, - ) - ) - else: - for _ in range(num_inputs): - inputs.append( - cast(ModelInputCallable, generate)( + if generate == ModelInput.generate_variable_batch_input + else cast(ModelInputCallable, generate)( world_size=world_size, tables=tables, dedup_tables=dedup_tables, @@ -219,6 +200,7 @@ def gen_model_and_input( long_indices=long_indices, ) ) + ) return (model, inputs) @@ -315,7 +297,6 @@ def sharding_single_rank_test( global_constant_batch: bool = False, world_size_2D: Optional[int] = None, node_group_size: Optional[int] = None, - input_type: str = "kjt", # "kjt" or "td" ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: # Generate model & inputs. 
@@ -338,7 +319,6 @@ def sharding_single_rank_test( batch_size=batch_size, feature_processor_modules=feature_processor_modules, global_constant_batch=global_constant_batch, - input_type=input_type, ) global_model = global_model.to(ctx.device) global_input = inputs[0][0].to(ctx.device) diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py index d13d819c3..aec092354 100644 --- a/torchrec/distributed/tests/test_sequence_model_parallel.py +++ b/torchrec/distributed/tests/test_sequence_model_parallel.py @@ -376,44 +376,3 @@ def _test_sharding( variable_batch_per_feature=variable_batch_per_feature, global_constant_batch=True, ) - - -@skip_if_asan_class -class TDSequenceModelParallelTest(SequenceModelParallelTest): - - def test_sharding_variable_batch(self) -> None: - pass - - def _test_sharding( - self, - sharders: List[TestEmbeddingCollectionSharder], - backend: str = "gloo", - world_size: int = 2, - local_size: Optional[int] = None, - constraints: Optional[Dict[str, ParameterConstraints]] = None, - model_class: Type[TestSparseNNBase] = TestSequenceSparseNN, - qcomms_config: Optional[QCommsConfig] = None, - apply_optimizer_in_backward_config: Optional[ - Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] - ] = None, - variable_batch_size: bool = False, - variable_batch_per_feature: bool = False, - ) -> None: - self._run_multi_process_test( - callable=sharding_single_rank_test, - world_size=world_size, - local_size=local_size, - model_class=model_class, - tables=self.tables, - embedding_groups=self.embedding_groups, - sharders=sharders, - optim=EmbOptimType.EXACT_SGD, - backend=backend, - constraints=constraints, - qcomms_config=qcomms_config, - apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, - variable_batch_size=variable_batch_size, - variable_batch_per_feature=variable_batch_per_feature, - global_constant_batch=True, - input_type="td", - ) diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index fdb900fe0..e8dc5eccb 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -160,7 +160,7 @@ def main( tables = [ EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, + num_embeddings=(i + 1) * 1000, embedding_dim=dim_emb, name="table_" + str(i), feature_names=["feature_" + str(i)], @@ -169,7 +169,7 @@ def main( ] weighted_tables = [ EmbeddingBagConfig( - num_embeddings=max(i + 1, 100) * 1000, + num_embeddings=(i + 1) * 1000, embedding_dim=dim_emb, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index d110fd57f..307d66639 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -19,7 +19,6 @@ pooling_type_to_str, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor -from torchrec.sparse.tensor_dict import maybe_td_to_kjt @torch.fx.wrap @@ -219,10 +218,7 @@ def __init__( self._feature_names: List[List[str]] = [table.feature_names for table in tables] self.reset_parameters() - def forward( - self, - features: KeyedJaggedTensor, # can also take TensorDict as input - ) -> KeyedTensor: + def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: """ Run the EmbeddingBagCollection forward pass. 
This method takes in a `KeyedJaggedTensor` and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature. @@ -233,7 +229,6 @@ def forward( KeyedTensor """ flat_feature_names: List[str] = [] - features = maybe_td_to_kjt(features, None) for names in self._feature_names: flat_feature_names.extend(names) inverse_indices = reorder_inverse_indices( @@ -453,7 +448,7 @@ def __init__( # noqa C901 def forward( self, - features: KeyedJaggedTensor, # can also take TensorDict as input + features: KeyedJaggedTensor, ) -> Dict[str, JaggedTensor]: """ Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` @@ -466,7 +461,6 @@ def forward( Dict[str, JaggedTensor] """ - features = maybe_td_to_kjt(features, None) feature_embeddings: Dict[str, JaggedTensor] = {} jt_dict: Dict[str, JaggedTensor] = features.to_dict() for i, emb_module in enumerate(self.embeddings.values()): diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py index 0f9b897cb..b4f8a6075 100644 --- a/torchrec/schema/utils.py +++ b/torchrec/schema/utils.py @@ -8,6 +8,32 @@ # pyre-strict import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: object, curr: object) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True def is_signature_compatible( @@ -84,6 +110,8 @@ def is_signature_compatible( return False # TODO: Account for Union Types? - if current_signature.return_annotation != previous_signature.return_annotation: + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): return False return True
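
A small illustrative check of the recursive annotation comparison added in torchrec/schema/utils.py above: `_is_annot_compatible` falls back from plain equality to comparing `typing.get_origin` and `typing.get_args` pairwise, and `is_signature_compatible` now routes return annotations through it. The concrete annotations below are made up for the example.

import typing

from torchrec.schema.utils import _is_annot_compatible

# Identical annotations are trivially compatible (the equality fast path).
assert _is_annot_compatible(typing.Dict[str, int], typing.Dict[str, int])

# Parameterized generics are compared origin-first, then argument by argument,
# so a mismatch anywhere in the nesting is rejected.
assert not _is_annot_compatible(typing.Dict[str, int], typing.Dict[str, str])
assert not _is_annot_compatible(typing.List[int], typing.Dict[str, int])

# Bare, unequal annotations have no typing origin and are rejected outright.
assert not _is_annot_compatible(int, float)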
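
The `_create_process_groups` hunk replaces the manual `dist.new_group` loops with a single `DeviceMesh` that hands back both process groups. A minimal standalone sketch of that pattern, assuming `torch.distributed` is already initialized; `device_type`, `world_size`, and `local_size` are illustrative parameters, not the torchrec signature.

from torch.distributed.device_mesh import DeviceMesh


def create_2d_process_groups(device_type: str, world_size: int, local_size: int):
    # One row per sharding group, mirroring the peer_matrix built in the patch:
    # row `group_rank` holds ranks [group_rank, step + group_rank, ...].
    step = world_size // local_size
    peer_matrix = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]
    mesh = DeviceMesh(
        device_type=device_type,
        mesh=peer_matrix,
        mesh_dim_names=("replicate", "shard"),
    )
    # DeviceMesh creates the underlying process groups itself, so the explicit
    # dist.new_group calls and barriers from the old code are no longer needed.
    sharding_pg = mesh.get_group(mesh_dim="shard")      # size local_size
    replica_pg = mesh.get_group(mesh_dim="replicate")   # size world_size // local_size
    return mesh, sharding_pg, replica_pg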
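
With the TensorDict code paths removed from `EmbeddingBagCollection.forward`, `EmbeddingCollection.forward`, and the sharded `input_dist` methods, callers pass a `KeyedJaggedTensor` directly. A minimal sketch of that calling convention; the single-table config, names, and sizes are made up for illustration.

import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=8,
            num_embeddings=100,
            feature_names=["feature_0"],
        )
    ]
)

# Two-sample batch for "feature_0": sample 0 has ids [1, 2], sample 1 has [3].
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)

pooled = ebc(features)  # KeyedTensor; pooled["feature_0"] has shape [2, 8]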