2024-04-24 nightly release (76e854c)
pytorchbot committed Apr 24, 2024
1 parent b839d5e commit 2183c44
Showing 18 changed files with 1,045 additions and 75 deletions.
73 changes: 73 additions & 0 deletions torchrec/distributed/embedding_sharding.py
@@ -101,6 +101,12 @@ def _fx_wrap_gen_list_n_times(ls: List[str], n: int) -> List[str]:
return ret


@torch.fx.wrap
def _fx_wrap_gen_keys(ls: List[str], n: int) -> List[str]:
# Dynamo-friendly syntax (list multiplication instead of a generator over kjt.keys() * num_buckets)
return ls * n
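
A minimal sketch of what this wrapper produces, using two hypothetical feature keys and three buckets (illustrative only, not part of the commit):

# Illustrative sketch: the keys are repeated once per bucket, so the bucketized
# KJT carries num_buckets copies of every feature key.
keys = ["f1", "f2"]
num_buckets = 3
assert _fx_wrap_gen_keys(keys, num_buckets) == ["f1", "f2", "f1", "f2", "f1", "f2"]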


def bucketize_kjt_before_all2all(
kjt: KeyedJaggedTensor,
num_buckets: int,
@@ -172,6 +178,73 @@ def bucketize_kjt_before_all2all(
)


def bucketize_kjt_inference(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
output_permute: bool = False,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
return_bucket_mapping: bool = False,
) -> Tuple[KeyedJaggedTensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Bucketizes the `values` in KeyedJaggedTensor into `num_buckets` buckets,
`lengths` are readjusted based on the bucketization results.
Note: This function should be used only for row-wise sharding before calling
`KJTAllToAll`.
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
output_permute (bool): output the memory location mapping from the unbucketized
values to bucketized values or not.
bucketize_pos (bool): output the changed position of the bucketized values or
not.
block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature.
Returns:
Tuple[KeyedJaggedTensor, Optional[torch.Tensor]]: the bucketized `KeyedJaggedTensor` and the optional permute mapping from the unbucketized values to bucketized value.
"""

num_features = len(kjt.keys())
assert_fx_safe(
block_sizes.numel() == num_features,
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)
block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
(
bucketized_lengths,
bucketized_indices,
bucketized_weights,
pos,
unbucketize_permute,
bucket_mapping,
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
kjt.lengths().view(-1),
kjt.values(),
bucketize_pos=bucketize_pos,
sequence=output_permute,
block_sizes=block_sizes_new_type,
my_size=num_buckets,
weights=kjt.weights_or_none(),
max_B=_fx_wrap_max_B(kjt),
block_bucketize_pos=block_bucketize_row_pos, # each tensor should have the same dtype as kjt.lengths()
return_bucket_mapping=return_bucket_mapping,
)

return (
KeyedJaggedTensor(
keys=_fx_wrap_gen_keys(kjt.keys(), num_buckets),
values=bucketized_indices,
weights=pos if bucketize_pos else bucketized_weights,
lengths=bucketized_lengths.view(-1),
),
unbucketize_permute,
bucket_mapping,
)
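
A hedged usage sketch of the new inference bucketizer, assuming a recent fbgemm_gpu is installed; the single feature name, row ids, and block size below are made up for illustration:

import torch
from torchrec.distributed.embedding_sharding import bucketize_kjt_inference
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# One feature, six row ids, block size 3: ids 0-2 land in bucket 0, ids 3-5 in bucket 1.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([0, 3, 1, 4, 2, 5]),
    lengths=torch.tensor([2, 2, 2]),
)
bucketized_kjt, unbucketize_permute, bucket_mapping = bucketize_kjt_inference(
    kjt,
    num_buckets=2,
    block_sizes=torch.tensor([3]),
    output_permute=True,
    return_bucket_mapping=True,
)
# bucketized_kjt.keys() is ["f1", "f1"] (one copy per bucket), and
# unbucketize_permute maps the bucketized values back to their original order.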


def _get_weighted_avg_cache_load_factor(
embedding_tables: List[ShardedEmbeddingTable],
) -> Optional[float]:
102 changes: 90 additions & 12 deletions torchrec/distributed/planner/stats.py
@@ -7,9 +7,12 @@

# pyre-strict

import copy
import logging
import math
import statistics
from collections import defaultdict
from typing import Any, cast, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union

from torch import nn

@@ -218,8 +221,8 @@ def log(
if sharding_type not in stats[rank]["type"]:
stats[rank]["type"][sharding_type] = 0

rank_hbm = f"{round(used_hbm_gb, 1)} ({used_hbm_ratio:.0%})"
rank_ddr = f"{round(used_ddr_gb, 1)} ({used_ddr_ratio:.0%})"
rank_hbm = f"{round(used_hbm_gb, 3)} ({used_hbm_ratio:.0%})"
rank_ddr = f"{round(used_ddr_gb, 3)} ({used_ddr_ratio:.0%})"
rank_perf = _format_perf_breakdown(perf[rank])
rank_input = f"{round(stats[rank]['input_sizes'], 2)}"
rank_output = f"{round(stats[rank]['output_sizes'], 2)}"
@@ -519,14 +522,60 @@ def _log_max_perf_and_max_hbm(self, perfs: List[Perf], used_hbm: List[int]) -> None:
)
self._stats_table.append(f"# {sum_of_maxima_text : <{self._width-3}}#")

max_hbm = max(used_hbm)
max_hbm_indices = [i for i in range(len(used_hbm)) if used_hbm[i] == max_hbm]
rank_text = "ranks" if len(max_hbm_indices) > 1 else "rank"
max_hbm_indices = _collapse_consecutive_ranks(max_hbm_indices)
max_hbm_ranks = f"{rank_text} {','.join(max_hbm_indices)}"
peak_memory_pressure = f"Peak Memory Pressure: {round(bytes_to_gb(max_hbm), 3)} GB on {max_hbm_ranks}"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {peak_memory_pressure : <{self._width-3}}#")
self._stats_table.append(
f"# {'Estimated Sharding Distribution' : <{self._width-2}}#"
)
self._stats_table.append(
f"# {'Max HBM: '+_generate_rank_hbm_stats(used_hbm, max) : <{self._width-3}}#"
)
self._stats_table.append(
f"# {'Min HBM: '+_generate_rank_hbm_stats(used_hbm, min) : <{self._width-3}}#"
)
self._stats_table.append(
f"# {'Mean HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.mean) : <{self._width-3}}#"
)
self._stats_table.append(
f"# {'Low Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_low) : <{self._width-3}}#"
)
self._stats_table.append(
f"# {'High Median HBM: '+_generate_rank_hbm_stats(used_hbm, statistics.median_high) : <{self._width-3}}#"
)

self._stats_table.append(f"#{'' : ^{self._width-2}}#")
per_rank_hbm = copy.copy(used_hbm)
NUM_PEAK_RANK = 5
peak_memory_pressure = []

top_hbm_usage_estimation = f"Top HBM Memory Usage Estimation: {round(bytes_to_gb(max(used_hbm)), 3)} GB"
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {top_hbm_usage_estimation : <{self._width-3}}#")

for top in range(NUM_PEAK_RANK):
if not per_rank_hbm:
break
max_hbm = max(per_rank_hbm)
max_hbm_indices = [
i
for i in range(len(per_rank_hbm))
if math.isclose(
bytes_to_mb(per_rank_hbm[i]), bytes_to_mb(max_hbm), abs_tol=1.0
)
]
rank_text = "ranks" if len(max_hbm_indices) > 1 else "rank"
max_hbm_indices = _collapse_consecutive_ranks(max_hbm_indices)
max_hbm_ranks = f"{rank_text} {','.join(max_hbm_indices)}"
peak_memory_pressure.append(
f"Top Tier #{top+1} Estimated Peak HBM Pressure: {round(bytes_to_gb(max_hbm), 3)} GB on {max_hbm_ranks}"
)
per_rank_hbm = [
hbm
for hbm in per_rank_hbm
if not math.isclose(bytes_to_mb(hbm), bytes_to_mb(max_hbm), abs_tol=1.0)
]

for peak_rank in reversed(peak_memory_pressure):
self._stats_table.append(f"# {peak_rank : <{self._width-3}}#")

def _log_storage_reservation_stats(
self,
@@ -540,13 +589,21 @@ def _log_storage_reservation_stats(
usable_hbm = round(
bytes_to_gb(int((1 - reserved_hbm_percent) * device_storage.hbm)), 3
)
reserved_hbm = round(
bytes_to_gb(int(reserved_hbm_percent * device_storage.hbm)), 3
)
reserved_memory = f"HBM: {reserved_hbm} GB"
reserved_hbm_percentage = f"Percent of Total HBM: {reserved_hbm_percent:.0%}"
usable_ddr = round(bytes_to_gb(int(device_storage.ddr)), 3)
usable_memory = f"HBM: {usable_hbm} GB, DDR: {usable_ddr} GB"
usable_hbm_percentage = (
f"Percent of Total HBM: {(1 - reserved_hbm_percent):.0%}"
)
self._stats_table.append(f"#{'' : ^{self._width-2}}#")
self._stats_table.append(f"# {'Usable Memory:' : <{self._width-3}}#")
self._stats_table.append(f"# {'Reserved Memory:' : <{self._width-3}}#")
self._stats_table.append(f"# {reserved_memory : <{self._width-6}}#")
self._stats_table.append(f"# {reserved_hbm_percentage : <{self._width-6}}#")
self._stats_table.append(f"# {'Planning Memory:' : <{self._width-3}}#")
self._stats_table.append(f"# {usable_memory : <{self._width-6}}#")
self._stats_table.append(f"# {usable_hbm_percentage : <{self._width-6}}#")

@@ -582,7 +639,15 @@ def _log_imbalance_tables(self, best_plan: List[ShardingOption]) -> None:
f"# {'Top 5 Tables Causing Max HBM:' : <{self._width-3}}#"
)
for sharding_option in hbm_imbalance_tables[0:5]:
self._stats_table.append(f"# {sharding_option.name : <{self._width-6}}#")
storage = sharding_option.shards[0].storage
assert storage is not None # linter friendly optional check

rank_text = "ranks" if len(sharding_option.shards) > 1 else "rank"
top_table = (
f"{sharding_option.name}: {round(bytes_to_gb(storage.hbm),3)} GB on {rank_text} "
f"{[shard.rank for shard in sharding_option.shards]}"
)
self._stats_table.append(f"# {top_table : <{self._width-6}}#")

def _log_compute_kernel_stats(
self, compute_kernels_to_count: Dict[str, int]
@@ -597,6 +662,19 @@
self._stats_table.append(f"# {compute_kernel_count : <{self._width-6}}#")


def _generate_rank_hbm_stats(
per_rank_hbm: List[int], func: Callable[[Iterable[float]], float]
) -> str:
stats = round(func(per_rank_hbm))
stats_indicies = [
i
for i in range(len(per_rank_hbm))
if math.isclose(bytes_to_mb(per_rank_hbm[i]), bytes_to_mb(stats), abs_tol=1.0)
]
rank_text = "ranks" if len(stats_indicies) > 1 else "rank"
return f"{round(bytes_to_gb(stats), 3)} GB on {rank_text} {stats_indicies}"


def _generate_max_text(perfs: List[float]) -> str:
max_perf = max(perfs)

22 changes: 20 additions & 2 deletions torchrec/distributed/planner/tests/test_stats.py
@@ -15,7 +15,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.stats import NoopEmbeddingStats
from torchrec.distributed.planner.stats import EmbeddingStats, NoopEmbeddingStats
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import ModuleSharder, ShardingType
@@ -53,7 +53,9 @@ def test_embedding_stats_runs(self) -> None:
planner = EmbeddingShardingPlanner(topology=self.topology)
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 1)
stats: List[str] = planner._stats[0]._stats_table # pyre-ignore[16]
stats_data = planner._stats[0]
assert isinstance(stats_data, EmbeddingStats)
stats: List[str] = stats_data._stats_table
self.assertTrue(isinstance(stats, list))
self.assertTrue(stats[0].startswith("####"))

@@ -68,3 +70,19 @@ def test_noop_embedding_stats_runs(self) -> None:
)
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 1)

def test_embedding_stats_output_with_top_hbm_usage(self) -> None:
planner = EmbeddingShardingPlanner(topology=self.topology)
_ = planner.plan(module=self.model, sharders=[TWvsRWSharder()])
self.assertEqual(len(planner._stats), 1)
stats_data = planner._stats[0]
assert isinstance(stats_data, EmbeddingStats)
stats: List[str] = stats_data._stats_table
self.assertTrue(isinstance(stats, list))
top_hbm_usage_keyword = "Top HBM Memory Usage Estimation:"
self.assertTrue(any(top_hbm_usage_keyword in row for row in stats))
top_hbm_mem_usage = None
for row in stats:
if top_hbm_usage_keyword in row:
top_hbm_mem_usage = float(row.split(" ")[6])
self.assertIsNotNone(top_hbm_mem_usage)
12 changes: 11 additions & 1 deletion torchrec/distributed/sharding/rw_sharding.py
@@ -28,6 +28,7 @@
BaseEmbeddingLookup,
BaseSparseFeaturesDist,
bucketize_kjt_before_all2all,
bucketize_kjt_inference,
EmbeddingSharding,
EmbeddingShardingContext,
EmbeddingShardingInfo,
@@ -683,6 +684,8 @@ def __init__(
self._has_feature_processor = has_feature_processor
self._need_pos = need_pos
self.unbucketize_permute_tensor: Optional[torch.Tensor] = None
self.bucket_mapping_tensor: Optional[torch.Tensor] = None
self.bucketized_length_tensor: Optional[torch.Tensor] = None
self._embedding_shard_metadata = emb_sharding

def forward(
@@ -700,7 +703,8 @@ def forward(
(
bucketized_features,
self.unbucketize_permute_tensor,
) = bucketize_kjt_before_all2all(
bucket_mapping_tensor_opt,
) = bucketize_kjt_inference(
sparse_features,
num_buckets=self._world_size,
block_sizes=block_sizes,
@@ -711,6 +715,12 @@
else self._need_pos
),
block_bucketize_row_pos=block_bucketize_row_pos,
return_bucket_mapping=self._is_sequence,
)
self.bucket_mapping_tensor = bucket_mapping_tensor_opt
# 2D shape required
self.bucketized_length_tensor = bucketized_features.lengths().view(
self._world_size * self._num_features, -1
)
# KJTOneToAll
kjt = bucketized_features
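
A small sketch of the reshape above, assuming the bucket-major layout produced by the bucketizer (sizes are hypothetical):

import torch

# Hypothetical sizes: 2 ranks (buckets), 3 features, local batch size 4.
world_size, num_features, batch_size = 2, 3, 4
bucketized_lengths = torch.arange(world_size * num_features * batch_size)  # flat, as returned
two_d = bucketized_lengths.view(world_size * num_features, -1)  # shape (6, 4)
# Under the assumed bucket-major ordering, row r holds the per-sample lengths of
# feature (r % num_features) destined for rank (r // num_features).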
6 changes: 6 additions & 0 deletions torchrec/distributed/sharding/sequence_sharding.py
@@ -96,6 +96,8 @@ class InferSequenceShardingContext(Multistreamable):
features: KJTList
features_before_input_dist: Optional[KeyedJaggedTensor] = None
unbucketize_permute_tensor: Optional[torch.Tensor] = None
bucket_mapping_tensor: Optional[torch.Tensor] = None
bucketized_length: Optional[torch.Tensor] = None

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for feature in self.features:
@@ -104,3 +106,7 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
self.features_before_input_dist.record_stream(stream)
if self.unbucketize_permute_tensor is not None:
self.unbucketize_permute_tensor.record_stream(stream)
if self.bucket_mapping_tensor is not None:
self.bucket_mapping_tensor.record_stream(stream)
if self.bucketized_length is not None:
self.bucketized_length.record_stream(stream)
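
The two new record_stream calls follow the standard pattern for tensors that cross CUDA streams; a general sketch of why this is needed (not specific to this context class):

import torch

if torch.cuda.is_available():
    producer_output = torch.randn(1024, device="cuda")  # allocated on the current stream
    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        consumed = producer_output * 2  # read on a different stream
    # Without this call, the caching allocator may reuse producer_output's memory
    # once the current stream is done with it, racing with side_stream.
    producer_output.record_stream(side_stream)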
7 changes: 4 additions & 3 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -345,7 +345,7 @@ def test_should_group_together_with_prefetch(
]
),
)
@settings(max_examples=10, deadline=10000)
@settings(max_examples=100, deadline=10000)
def test_should_not_group_together(
self,
data_types: List[DataType],
@@ -411,8 +411,9 @@ def test_should_not_group_together(
)
return

# emb dim bucketizer is only in use when the compute kernel is caching
if distinct_key == "local_dim" and _prefetch_and_cached(tables[0]):
# emb dim bucketizer is only in use when the compute kernel is caching. Otherwise,
# tables are grouped into the same bucket even with different dimensions
if distinct_key == "local_dim" and not _prefetch_and_cached(tables[0]):
self.assertEqual(
_get_table_names_by_groups(tables),
[["table_0", "table_1"]],
5 changes: 5 additions & 0 deletions torchrec/distributed/tests/test_infer_shardings.py
@@ -1877,6 +1877,11 @@ def test_sharded_quant_fp_ebc_tw(self, weight_dtype: torch.dtype) -> None:
gm_script_output = gm_script(*inputs[0])
assert_close(sharded_output, gm_script_output)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
def test_sharded_quant_fp_ebc_tw_meta(self) -> None:
# Simulate inference, take unsharded cpu model and shard on meta
# Use PositionWeightedModuleCollection, FP used in production
5 changes: 5 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
@@ -311,6 +311,11 @@ def test_sharded_quant_ebc_non_strict_export(self) -> None:
# TODO: Fix Unflatten
# torch.export.unflatten(ep)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
def test_sharded_quant_fpebc_non_strict_export(self) -> None:
sharded_model, input_kjts = _sharded_quant_ebc_model(
local_device="cpu", compute_device="cpu", feature_processor=True
1 change: 0 additions & 1 deletion torchrec/models/tests/test_dlrm.py
@@ -1277,7 +1277,6 @@ def test_export_serialization(self) -> None:
# Run forward on ExportedProgram
ep_output = ep.module()(features, sparse_features)
self.assertEqual(ep_output.size(), (B, 1))
self.assertTrue(torch.allclose(logits, ep_output))

deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
deserialized_logits = deserialized_model(features, sparse_features)