Skip to content

Commit 1aa05c7

Browse files
isururanawaka authored and facebook-github-bot committed
Input distribution latency estimations (#3575)
Summary: This introduces input distribution latency estimations. Input distribution is a two-step communication that happens inside SDD pipelines. - Split exchange: exchanges the buffer sizes needed to receive input IDs from KJTs. Because this is a metadata-exchange phase whose cost does not depend on the input, this diff excludes it from the computations. - ID exchange: exchanges the actual IDs to look up. We estimated the cost by analyzing all-to-all comms. Differential Revision: D87389540
1 parent 3a1d5f3 commit 1aa05c7

File tree

4 files changed

+417
-1
lines changed

4 files changed

+417
-1
lines changed

torchrec/distributed/planner/shard_estimators.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,48 @@ def _get_expected_cache_prefetch_time(
477477
prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
478478
return prefetch_bytes / hbm_to_ddr_mem_bw
479479

480+
@classmethod
481+
def _input_dist_expected_latency(
482+
cls,
483+
batch_sizes: List[int],
484+
world_size: int,
485+
local_world_size: int,
486+
num_poolings: List[float],
487+
input_lengths: List[float],
488+
a2a_comm_data_type_size: float,
489+
comms_bandwidths: GeneralizedCommsBandwidth,
490+
is_weighted: bool = False,
491+
) -> float:
492+
"""
493+
Calculates the expected latency for A2A input dist.
494+
495+
Args:
496+
batch_sizes (int): The batch size for each input feature.
497+
world_size (int): The total number of devices in the distributed setup.
498+
local_world_size (int): The number of devices on a single host.
499+
num_poolings (List[float]): Number of poolings per sample for each input feature.
500+
input_lengths (List[float]): Average number of lookups per input feature.
501+
a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication.
502+
comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.
503+
504+
Returns:
505+
float: The expected latency (in seconds) for input distribution.
506+
"""
507+
batch_inputs = sum(
508+
[x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
509+
)
510+
input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size)
511+
512+
if is_weighted:
513+
input_read_size *= 2
514+
515+
comms_bw = comms_bandwidths.get_bw(
516+
world_size=world_size,
517+
local_world_size=local_world_size,
518+
collective_type=CollectiveType.ALL_TO_ALL,
519+
)
520+
return input_read_size / comms_bw
521+
480522
@classmethod
481523
def _get_tw_sharding_perf(
482524
cls,
@@ -567,6 +609,15 @@ def _get_tw_sharding_perf(
567609
hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
568610
)
569611

612+
input_dist_comms = cls._input_dist_expected_latency(
613+
batch_sizes=batch_sizes,
614+
world_size=world_size,
615+
local_world_size=local_world_size,
616+
num_poolings=num_poolings,
617+
input_lengths=input_lengths,
618+
a2a_comm_data_type_size=input_data_type_size,
619+
comms_bandwidths=comms_bandwidths,
620+
)
570621
# in order of model parallel execution, starting with:
571622
# BWD DP -> BWD MP ... FWD MP -> FWD DP
572623
return Perf(
@@ -575,6 +626,7 @@ def _get_tw_sharding_perf(
575626
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
576627
bwd_comms=bwd_comms,
577628
prefetch_compute=prefetch_compute,
629+
input_dist_comms=input_dist_comms,
578630
)
579631

580632
@classmethod
@@ -674,13 +726,23 @@ def _get_rw_sharding_perf(
674726
emb_dim,
675727
table_data_type_size,
676728
)
729+
input_dist_comms = cls._input_dist_expected_latency(
730+
batch_sizes=batch_sizes,
731+
world_size=world_size,
732+
local_world_size=local_world_size,
733+
num_poolings=num_poolings,
734+
input_lengths=input_lengths,
735+
a2a_comm_data_type_size=input_data_type_size,
736+
comms_bandwidths=comms_bandwidths,
737+
)
677738

678739
return Perf(
679740
fwd_compute=fwd_compute,
680741
fwd_comms=fwd_comms,
681742
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
682743
bwd_comms=bwd_comms + bwd_batched_copy,
683744
prefetch_compute=prefetch_compute,
745+
input_dist_comms=input_dist_comms,
684746
)
685747

686748
@classmethod
@@ -806,13 +868,23 @@ def _get_twrw_sharding_perf(
806868
emb_dim,
807869
table_data_type_size,
808870
)
871+
input_dist_comms = cls._input_dist_expected_latency(
872+
batch_sizes=batch_sizes,
873+
world_size=world_size,
874+
local_world_size=local_world_size,
875+
num_poolings=num_poolings,
876+
input_lengths=input_lengths,
877+
a2a_comm_data_type_size=input_data_type_size,
878+
comms_bandwidths=comms_bandwidths,
879+
)
809880

810881
return Perf(
811882
fwd_compute=fwd_compute,
812883
fwd_comms=fwd_comms,
813884
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
814885
bwd_comms=bwd_comms + bwd_batched_copy,
815886
prefetch_compute=prefetch_compute,
887+
input_dist_comms=input_dist_comms,
816888
)
817889

818890
@classmethod

0 commit comments

Comments
 (0)