Skip to content

Commit

Permalink
Revert "Revert "Revert "add NJT/TD support in test data generator (#2528
Browse files Browse the repository at this point in the history
)"""

This reverts commit 3441ac3.
  • Loading branch information
PaulZhang12 committed Jan 24, 2025
1 parent 0ce7cc6 commit e6d5560
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 119 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

#!/usr/bin/env python3

from typing import Dict, List

import click

import torch
Expand Down Expand Up @@ -84,10 +82,9 @@ def op_bench(
)

def _func_to_benchmark(
kjts: List[Dict[str, KeyedJaggedTensor]],
kjt: KeyedJaggedTensor,
model: torch.nn.Module,
) -> torch.Tensor:
kjt = kjts[0]["feature"]
return model.forward(kjt.values(), kjt.offsets())

# breakpoint() # import fbvscode; fbvscode.set_trace()
Expand All @@ -111,8 +108,8 @@ def _func_to_benchmark(

result = benchmark_func(
name=f"SplitTableBatchedEmbeddingBagsCodegen-{num_embeddings}-{embedding_dim}-{num_tables}-{batch_size}-{bag_size}",
bench_inputs=[{"feature": inputs}],
prof_inputs=[{"feature": inputs}],
bench_inputs=inputs, # pyre-ignore
prof_inputs=inputs, # pyre-ignore
num_benchmarks=10,
num_profiles=10,
profile_dir=".",
Expand Down
5 changes: 1 addition & 4 deletions torchrec/distributed/benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,14 +374,11 @@ def get_inputs(

if train:
sparse_features_by_rank = [
model_input.idlist_features
for model_input in model_input_by_rank
if isinstance(model_input.idlist_features, KeyedJaggedTensor)
model_input.idlist_features for model_input in model_input_by_rank
]
inputs_batch.append(sparse_features_by_rank)
else:
sparse_features = model_input_by_rank[0].idlist_features
assert isinstance(sparse_features, KeyedJaggedTensor)
inputs_batch.append([sparse_features])

# Transpose if train, as inputs_by_rank is currently in [B X R] format
Expand Down
4 changes: 1 addition & 3 deletions torchrec/distributed/test_utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def model_input_to_forward_args_kjt(
Optional[torch.Tensor],
]:
kjt = mi.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
return (
kjt._keys,
kjt._values,
Expand Down Expand Up @@ -292,8 +291,7 @@ def model_input_to_forward_args(
]:
idlist_kjt = mi.idlist_features
idscore_kjt = mi.idscore_features
assert isinstance(idlist_kjt, KeyedJaggedTensor)
assert isinstance(idscore_kjt, KeyedJaggedTensor)
assert idscore_kjt is not None
return (
mi.float_features,
idlist_kjt._keys,
Expand Down
123 changes: 36 additions & 87 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrec.distributed.embedding_tower_sharding import (
EmbeddingTowerCollectionSharder,
EmbeddingTowerSharder,
Expand Down Expand Up @@ -47,8 +46,8 @@
@dataclass
class ModelInput(Pipelineable):
float_features: torch.Tensor
idlist_features: Union[KeyedJaggedTensor, TensorDict]
idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]]
idlist_features: KeyedJaggedTensor
idscore_features: Optional[KeyedJaggedTensor]
label: torch.Tensor

@staticmethod
Expand Down Expand Up @@ -77,13 +76,11 @@ def generate(
randomize_indices: bool = True,
device: Optional[torch.device] = None,
max_feature_lengths: Optional[List[int]] = None,
input_type: str = "kjt",
) -> Tuple["ModelInput", List["ModelInput"]]:
"""
Returns a global (single-rank training) batch
and a list of local (multi-rank training) batches of world_size.
"""

batch_size_by_rank = [batch_size] * world_size
if variable_batch_size:
batch_size_by_rank = [
Expand Down Expand Up @@ -202,26 +199,11 @@ def _validate_pooling_factor(
)
global_idlist_lengths.append(lengths)
global_idlist_indices.append(indices)

if input_type == "kjt":
global_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, global_idlist_indices, global_idlist_lengths
)
}
global_idlist_input = TensorDict(source=dict_of_nt)
else:
raise ValueError(f"For IdList features, unknown input type {input_type}")
global_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(global_idlist_indices),
lengths=torch.cat(global_idlist_lengths),
)

for idx in range(len(idscore_ind_ranges)):
ind_range = idscore_ind_ranges[idx]
Expand Down Expand Up @@ -263,25 +245,16 @@ def _validate_pooling_factor(
global_idscore_lengths.append(lengths)
global_idscore_indices.append(indices)
global_idscore_weights.append(weights)

if input_type == "kjt":
global_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),
)
if global_idscore_indices
else None
global_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(global_idscore_indices),
lengths=torch.cat(global_idscore_lengths),
weights=torch.cat(global_idscore_weights),
)
elif input_type == "td":
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
global_idscore_input = None
else:
raise ValueError(f"For weighted features, unknown input type {input_type}")
if global_idscore_indices
else None
)

if randomize_indices:
global_float = torch.rand(
Expand Down Expand Up @@ -330,57 +303,36 @@ def _validate_pooling_factor(
weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
)

if input_type == "kjt":
local_idlist_input = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)

local_idscore_input = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
)
if local_idscore_indices
else None
)
elif input_type == "td":
dict_of_nt = {
k: torch.nested.nested_tensor_from_jagged(
values=values,
lengths=lengths,
)
for k, values, lengths in zip(
idlist_features, local_idlist_indices, local_idlist_lengths
)
}
local_idlist_input = TensorDict(source=dict_of_nt)
assert (
len(idscore_features) == 0
), "TensorDict does not support weighted features"
local_idscore_input = None
local_idlist_kjt = KeyedJaggedTensor(
keys=idlist_features,
values=torch.cat(local_idlist_indices),
lengths=torch.cat(local_idlist_lengths),
)

else:
raise ValueError(
f"For weighted features, unknown input type {input_type}"
local_idscore_kjt = (
KeyedJaggedTensor(
keys=idscore_features,
values=torch.cat(local_idscore_indices),
lengths=torch.cat(local_idscore_lengths),
weights=torch.cat(local_idscore_weights),
)
if local_idscore_indices
else None
)

local_input = ModelInput(
float_features=global_float[r * batch_size : (r + 1) * batch_size],
idlist_features=local_idlist_input,
idscore_features=local_idscore_input,
idlist_features=local_idlist_kjt,
idscore_features=local_idscore_kjt,
label=global_label[r * batch_size : (r + 1) * batch_size],
)
local_inputs.append(local_input)

return (
ModelInput(
float_features=global_float,
idlist_features=global_idlist_input,
idscore_features=global_idscore_input,
idlist_features=global_idlist_kjt,
idscore_features=global_idscore_kjt,
label=global_label,
),
local_inputs,
Expand Down Expand Up @@ -671,9 +623,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":

def record_stream(self, stream: torch.Stream) -> None:
self.float_features.record_stream(stream)
if isinstance(self.idlist_features, KeyedJaggedTensor):
self.idlist_features.record_stream(stream)
if isinstance(self.idscore_features, KeyedJaggedTensor):
self.idlist_features.record_stream(stream)
if self.idscore_features is not None:
self.idscore_features.record_stream(stream)
self.label.record_stream(stream)

Expand Down Expand Up @@ -1880,8 +1831,6 @@ def forward(self, input: ModelInput) -> ModelInput:
)

# stride will be same but features will be joined
assert isinstance(modified_input.idlist_features, KeyedJaggedTensor)
assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor)
modified_input.idlist_features = KeyedJaggedTensor.concat(
[modified_input.idlist_features, self._extra_input.idlist_features]
)
Expand Down
3 changes: 0 additions & 3 deletions torchrec/distributed/tests/test_infer_shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -1987,7 +1987,6 @@ def test_sharded_quant_fp_ebc_tw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
Expand Down Expand Up @@ -2167,7 +2166,6 @@ def test_sharded_quant_mc_ec_rw(
inputs = []
for model_input in model_inputs:
kjt = model_input.idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = None
inputs.append(
Expand Down Expand Up @@ -2303,7 +2301,6 @@ def test_sharded_quant_fp_ebc_tw_meta(self, compute_device: str) -> None:
)
inputs = []
kjt = model_inputs[0].idlist_features
assert isinstance(kjt, KeyedJaggedTensor)
kjt = kjt.to(local_device)
weights = torch.rand(
kjt._values.size(0), dtype=torch.float, device=local_device
Expand Down
12 changes: 2 additions & 10 deletions torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ def _gen_pipelines(
default=100,
help="Total number of sparse embeddings to be used.",
)
@click.option(
"--ratio_features_weighted",
default=0.4,
help="percentage of features weighted vs unweighted",
)
@click.option(
"--dim_emb",
type=int,
Expand Down Expand Up @@ -137,7 +132,6 @@ def _gen_pipelines(
def main(
world_size: int,
n_features: int,
ratio_features_weighted: float,
dim_emb: int,
n_batches: int,
batch_size: int,
Expand All @@ -155,9 +149,8 @@ def main(
os.environ["MASTER_ADDR"] = str("localhost")
os.environ["MASTER_PORT"] = str(get_free_port())

num_weighted_features = int(n_features * ratio_features_weighted)
num_features = n_features - num_weighted_features

num_features = n_features // 2
num_weighted_features = n_features // 2
tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
Expand Down Expand Up @@ -264,7 +257,6 @@ def _generate_data(
world_size=world_size,
num_float_features=num_float_features,
pooling_avg=pooling_factor,
input_type=input_type,
)[1]
for i in range(num_batches)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,11 +306,7 @@ def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
# `parameters`.
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

data = [
i.idlist_features
for i in local_model_inputs
if isinstance(i.idlist_features, KeyedJaggedTensor)
]
data = [i.idlist_features for i in local_model_inputs]
dataloader = iter(data)
pipeline = TrainPipelinePT2(
model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
Expand Down
1 change: 0 additions & 1 deletion torchrec/sparse/tests/keyed_jagged_tensor_benchmark_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def generate_kjt(
randomize_indices=True,
device=device,
)[0]
assert isinstance(global_input.idlist_features, KeyedJaggedTensor)
return global_input.idlist_features


Expand Down

0 comments on commit e6d5560

Please sign in to comment.