From e6d5560f86c6a5e469d09a57f6b26da441261def Mon Sep 17 00:00:00 2001
From: Paul Zhang
Date: Fri, 24 Jan 2025 18:47:55 +0000
Subject: [PATCH] Revert "Revert "Revert "add NJT/TD support in test data generator (#2528)"""

This reverts commit 3441ac333f69c7acfb175a0f0d260c335d91831e.
---
 ...enchmark_split_table_batched_embeddings.py |   9 +-
 .../distributed/benchmark/benchmark_utils.py  |   5 +-
 .../distributed/test_utils/infer_utils.py     |   4 +-
 torchrec/distributed/test_utils/test_model.py | 123 +++++-------------
 .../distributed/tests/test_infer_shardings.py |   3 -
 .../tests/pipeline_benchmarks.py              |  12 +-
 .../tests/test_train_pipelines.py             |   6 +-
 .../keyed_jagged_tensor_benchmark_lib.py      |   1 -
 8 files changed, 44 insertions(+), 119 deletions(-)

diff --git a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
index 8af1f9a46..b03e7b417 100644
--- a/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
+++ b/torchrec/distributed/benchmark/benchmark_split_table_batched_embeddings.py
@@ -9,8 +9,6 @@

 #!/usr/bin/env python3

-from typing import Dict, List
-
 import click

 import torch
@@ -84,10 +82,9 @@ def op_bench(
     )

     def _func_to_benchmark(
-        kjts: List[Dict[str, KeyedJaggedTensor]],
+        kjt: KeyedJaggedTensor,
         model: torch.nn.Module,
     ) -> torch.Tensor:
-        kjt = kjts[0]["feature"]
         return model.forward(kjt.values(), kjt.offsets())

     # breakpoint() # import fbvscode; fbvscode.set_trace()
@@ -111,8 +108,8 @@ def _func_to_benchmark(

     result = benchmark_func(
         name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
-        bench_inputs=[{"feature": inputs}],
-        prof_inputs=[{"feature": inputs}],
+        bench_inputs=inputs,  # pyre-ignore
+        prof_inputs=inputs,  # pyre-ignore
         num_benchmarks=10,
         num_profiles=10,
         profile_dir=".",
diff --git a/torchrec/distributed/benchmark/benchmark_utils.py b/torchrec/distributed/benchmark/benchmark_utils.py
index 22af274d6..1878fdd1f 100644
--- a/torchrec/distributed/benchmark/benchmark_utils.py
+++ b/torchrec/distributed/benchmark/benchmark_utils.py
@@ -374,14 +374,11 @@ def get_inputs(

         if train:
             sparse_features_by_rank = [
-                model_input.idlist_features
-                for model_input in model_input_by_rank
-                if isinstance(model_input.idlist_features, KeyedJaggedTensor)
+                model_input.idlist_features for model_input in model_input_by_rank
             ]
             inputs_batch.append(sparse_features_by_rank)
         else:
             sparse_features = model_input_by_rank[0].idlist_features
-            assert isinstance(sparse_features, KeyedJaggedTensor)
             inputs_batch.append([sparse_features])

     # Transpose if train, as inputs_by_rank is currently in [B X R] format
diff --git a/torchrec/distributed/test_utils/infer_utils.py b/torchrec/distributed/test_utils/infer_utils.py
index 478e01bb2..0604f1c29 100644
--- a/torchrec/distributed/test_utils/infer_utils.py
+++ b/torchrec/distributed/test_utils/infer_utils.py
@@ -264,7 +264,6 @@ def model_input_to_forward_args_kjt(
     Optional[torch.Tensor],
 ]:
     kjt = mi.idlist_features
-    assert isinstance(kjt, KeyedJaggedTensor)
     return (
         kjt._keys,
         kjt._values,
@@ -292,8 +291,7 @@ def model_input_to_forward_args(
 ]:
     idlist_kjt = mi.idlist_features
     idscore_kjt = mi.idscore_features
-    assert isinstance(idlist_kjt, KeyedJaggedTensor)
-    assert isinstance(idscore_kjt, KeyedJaggedTensor)
+    assert idscore_kjt is not None
     return (
         mi.float_features,
         idlist_kjt._keys,
diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py
index 010abb459..3442b5dd3 100644
--- a/torchrec/distributed/test_utils/test_model.py
+++ b/torchrec/distributed/test_utils/test_model.py
@@ -14,7 +14,6 @@

 import torch
 import torch.nn as nn
-from tensordict import TensorDict
 from torchrec.distributed.embedding_tower_sharding import (
     EmbeddingTowerCollectionSharder,
     EmbeddingTowerSharder,
@@ -47,8 +46,8 @@
 @dataclass
 class ModelInput(Pipelineable):
     float_features: torch.Tensor
-    idlist_features: Union[KeyedJaggedTensor, TensorDict]
-    idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
+    idlist_features: KeyedJaggedTensor
+    idscore_features: Optional[KeyedJaggedTensor]
     label: torch.Tensor

     @staticmethod
@@ -77,13 +76,11 @@ def generate(
         randomize_indices: bool = True,
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[List[int]] = None,
-        input_type: str = "kjt",
     ) -> Tuple["ModelInput", List["ModelInput"]]:
         """
         Returns a global (single-rank training) batch
         and a list of local (multi-rank training) batches of world_size.
         """
-
         batch_size_by_rank = [batch_size] * world_size
         if variable_batch_size:
             batch_size_by_rank = [
@@ -202,26 +199,11 @@ def _validate_pooling_factor(
             )
             global_idlist_lengths.append(lengths)
             global_idlist_indices.append(indices)
-
-        if input_type == "kjt":
-            global_idlist_input = KeyedJaggedTensor(
-                keys=idlist_features,
-                values=torch.cat(global_idlist_indices),
-                lengths=torch.cat(global_idlist_lengths),
-            )
-        elif input_type == "td":
-            dict_of_nt = {
-                k: torch.nested.nested_tensor_from_jagged(
-                    values=values,
-                    lengths=lengths,
-                )
-                for k, values, lengths in zip(
-                    idlist_features, global_idlist_indices, global_idlist_lengths
-                )
-            }
-            global_idlist_input = TensorDict(source=dict_of_nt)
-        else:
-            raise ValueError(f"For IdList features, unknown input type {input_type}")
+        global_idlist_kjt = KeyedJaggedTensor(
+            keys=idlist_features,
+            values=torch.cat(global_idlist_indices),
+            lengths=torch.cat(global_idlist_lengths),
+        )

         for idx in range(len(idscore_ind_ranges)):
             ind_range = idscore_ind_ranges[idx]
@@ -263,25 +245,16 @@ def _validate_pooling_factor(
             global_idscore_lengths.append(lengths)
             global_idscore_indices.append(indices)
             global_idscore_weights.append(weights)
-
-        if input_type == "kjt":
-            global_idscore_input = (
-                KeyedJaggedTensor(
-                    keys=idscore_features,
-                    values=torch.cat(global_idscore_indices),
-                    lengths=torch.cat(global_idscore_lengths),
-                    weights=torch.cat(global_idscore_weights),
-                )
-                if global_idscore_indices
-                else None
+        global_idscore_kjt = (
+            KeyedJaggedTensor(
+                keys=idscore_features,
+                values=torch.cat(global_idscore_indices),
+                lengths=torch.cat(global_idscore_lengths),
+                weights=torch.cat(global_idscore_weights),
             )
-        elif input_type == "td":
-            assert (
-                len(idscore_features) == 0
-            ), "TensorDict does not support weighted features"
-            global_idscore_input = None
-        else:
-            raise ValueError(f"For weighted features, unknown input type {input_type}")
+            if global_idscore_indices
+            else None
+        )

         if randomize_indices:
             global_float = torch.rand(
@@ -330,48 +303,27 @@ def _validate_pooling_factor(
                     weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
                 )

-            if input_type == "kjt":
-                local_idlist_input = KeyedJaggedTensor(
-                    keys=idlist_features,
-                    values=torch.cat(local_idlist_indices),
-                    lengths=torch.cat(local_idlist_lengths),
-                )
-
-                local_idscore_input = (
-                    KeyedJaggedTensor(
-                        keys=idscore_features,
-                        values=torch.cat(local_idscore_indices),
-                        lengths=torch.cat(local_idscore_lengths),
-                        weights=torch.cat(local_idscore_weights),
-                    )
-                    if local_idscore_indices
-                    else None
-                )
-            elif input_type == "td":
-                dict_of_nt = {
-                    k: torch.nested.nested_tensor_from_jagged(
-                        values=values,
-                        lengths=lengths,
-                    )
-                    for k, values, lengths in zip(
-                        idlist_features, local_idlist_indices, local_idlist_lengths
-                    )
-                }
-                local_idlist_input = TensorDict(source=dict_of_nt)
-                assert (
-                    len(idscore_features) == 0
-                ), "TensorDict does not support weighted features"
-                local_idscore_input = None
+            local_idlist_kjt = KeyedJaggedTensor(
+                keys=idlist_features,
+                values=torch.cat(local_idlist_indices),
+                lengths=torch.cat(local_idlist_lengths),
+            )

-            else:
-                raise ValueError(
-                    f"For weighted features, unknown input type {input_type}"
+            local_idscore_kjt = (
+                KeyedJaggedTensor(
+                    keys=idscore_features,
+                    values=torch.cat(local_idscore_indices),
+                    lengths=torch.cat(local_idscore_lengths),
+                    weights=torch.cat(local_idscore_weights),
                 )
+                if local_idscore_indices
+                else None
+            )

             local_input = ModelInput(
                 float_features=global_float[r * batch_size : (r + 1) * batch_size],
-                idlist_features=local_idlist_input,
-                idscore_features=local_idscore_input,
+                idlist_features=local_idlist_kjt,
+                idscore_features=local_idscore_kjt,
                 label=global_label[r * batch_size : (r + 1) * batch_size],
             )
             local_inputs.append(local_input)
@@ -379,8 +331,8 @@ def _validate_pooling_factor(
         return (
             ModelInput(
                 float_features=global_float,
-                idlist_features=global_idlist_input,
-                idscore_features=global_idscore_input,
+                idlist_features=global_idlist_kjt,
+                idscore_features=global_idscore_kjt,
                 label=global_label,
             ),
             local_inputs,
@@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":

     def record_stream(self, stream: torch.Stream) -> None:
         self.float_features.record_stream(stream)
-        if isinstance(self.idlist_features, KeyedJaggedTensor):
-            self.idlist_features.record_stream(stream)
-        if isinstance(self.idscore_features, KeyedJaggedTensor):
+        self.idlist_features.record_stream(stream)
+        if self.idscore_features is not None:
             self.idscore_features.record_stream(stream)
         self.label.record_stream(stream)
@@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput:
         )

         # stride will be same but features will be joined
-        assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
-        assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
         modified_input.idlist_features = KeyedJaggedTensor.concat(
             [modified_input.idlist_features, self._extra_input.idlist_features]
         )
diff --git a/torchrec/distributed/tests/test_infer_shardings.py b/torchrec/distributed/tests/test_infer_shardings.py
index 83b4649ee..c7c6ef180 100755
--- a/torchrec/distributed/tests/test_infer_shardings.py
+++ b/torchrec/distributed/tests/test_infer_shardings.py
@@ -1987,7 +1986,6 @@ def test_sharded_quant_fp_ebc_tw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
-            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = torch.rand(
                 kjt._values.size(0), dtype=torch.float, device=local_device
@@ -2167,7 +2166,6 @@ def test_sharded_quant_mc_ec_rw(
         inputs = []
         for model_input in model_inputs:
             kjt = model_input.idlist_features
-            assert isinstance(kjt, KeyedJaggedTensor)
             kjt = kjt.to(local_device)
             weights = None
             inputs.append(
@@ -2303,7 +2301,6 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
         )
         inputs = []
         kjt = model_inputs[0].idlist_features
-        assert isinstance(kjt, KeyedJaggedTensor)
         kjt = kjt.to(local_device)
         weights = torch.rand(
             kjt._values.size(0), dtype=torch.float, device=local_device
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index e8dc5eccb..538264c04 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -75,11 +75,6 @@ def _gen_pipelines(
     default=100,
     help="Total number of sparse embeddings to be used.",
 )
-@click.option(
-    "--ratio_features_weighted",
-    default=0.4,
-    help="percentage of features weighted vs unweighted",
-)
 @click.option(
     "--dim_emb",
     type=int,
@@ -137,7 +132,6 @@ def _gen_pipelines(
 def main(
     world_size: int,
     n_features: int,
-    ratio_features_weighted: float,
     dim_emb: int,
     n_batches: int,
     batch_size: int,
@@ -155,9 +149,8 @@ def main(
     os.environ["MASTER_ADDR"] = str("localhost")
     os.environ["MASTER_PORT"] = str(get_free_port())

-    num_weighted_features = int(n_features * ratio_features_weighted)
-    num_features = n_features - num_weighted_features
-
+    num_features = n_features // 2
+    num_weighted_features = n_features // 2
     tables = [
         EmbeddingBagConfig(
             num_embeddings=(i + 1) * 1000,
@@ -264,7 +257,6 @@ def _generate_data(
             world_size=world_size,
             num_float_features=num_float_features,
             pooling_avg=pooling_factor,
-            input_type=input_type,
         )[1]
         for i in range(num_batches)
     ]
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index bf708b1f5..9c39b5384 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -306,11 +306,7 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
         # `parameters`.
         optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

-        data = [
-            i.idlist_features
-            for i in local_model_inputs
-            if isinstance(i.idlist_features, KeyedJaggedTensor)
-        ]
+        data = [i.idlist_features for i in local_model_inputs]
         dataloader = iter(data)
         pipeline = TrainPipelinePT2(
             model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
index 1c409fcf2..235495494 100644
--- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
+++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
@@ -169,7 +169,6 @@ def generate_kjt(
         randomize_indices=True,
         device=device,
     )[0]
-    assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
     return global_input.idlist_features