@@ -477,6 +477,48 @@ def _get_expected_cache_prefetch_time(
477477 prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
478478 return prefetch_bytes / hbm_to_ddr_mem_bw
479479
480+ @classmethod
481+ def _input_dist_expected_latency (
482+ cls ,
483+ batch_sizes : List [int ],
484+ world_size : int ,
485+ local_world_size : int ,
486+ num_poolings : List [float ],
487+ input_lengths : List [float ],
488+ a2a_comm_data_type_size : float ,
489+ comms_bandwidths : GeneralizedCommsBandwidth ,
490+ is_weighted : bool = False ,
491+ ) -> float :
492+ """
493+ Calculates the expected latency for A2A input dist.
494+
495+ Args:
496+ batch_sizes (int): The batch size for each input feature.
497+ world_size (int): The total number of devices in the distributed setup.
498+ local_world_size (int): The number of devices on a single host.
499+ num_poolings (List[float]): Number of poolings per sample for each input feature.
500+ input_lengths (List[float]): Average number of lookups per input feature.
501+ a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication.
502+ comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.
503+
504+ Returns:
505+ float: The expected latency (in seconds) for input distribution.
506+ """
507+ batch_inputs = sum (
508+ [x * y * z for x , y , z in zip (input_lengths , num_poolings , batch_sizes )]
509+ )
510+ input_read_size = math .ceil (batch_inputs * world_size * a2a_comm_data_type_size )
511+
512+ if is_weighted :
513+ input_read_size *= 2
514+
515+ comms_bw = comms_bandwidths .get_bw (
516+ world_size = world_size ,
517+ local_world_size = local_world_size ,
518+ collective_type = CollectiveType .ALL_TO_ALL ,
519+ )
520+ return input_read_size / comms_bw
521+
480522 @classmethod
481523 def _get_tw_sharding_perf (
482524 cls ,
@@ -567,6 +609,15 @@ def _get_tw_sharding_perf(
567609 hbm_to_ddr_mem_bw , expected_cache_fetches , emb_dim , table_data_type_size
568610 )
569611
612+ input_dist_comms = cls ._input_dist_expected_latency (
613+ batch_sizes = batch_sizes ,
614+ world_size = world_size ,
615+ local_world_size = local_world_size ,
616+ num_poolings = num_poolings ,
617+ input_lengths = input_lengths ,
618+ a2a_comm_data_type_size = input_data_type_size ,
619+ comms_bandwidths = comms_bandwidths ,
620+ )
570621 # in order of model parallel execution, starting with:
571622 # BWD DP -> BWD MP ... FWD MP -> FWD DP
572623 return Perf (
@@ -575,6 +626,7 @@ def _get_tw_sharding_perf(
575626 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
576627 bwd_comms = bwd_comms ,
577628 prefetch_compute = prefetch_compute ,
629+ input_dist_comms = input_dist_comms ,
578630 )
579631
580632 @classmethod
@@ -674,13 +726,23 @@ def _get_rw_sharding_perf(
674726 emb_dim ,
675727 table_data_type_size ,
676728 )
729+ input_dist_comms = cls ._input_dist_expected_latency (
730+ batch_sizes = batch_sizes ,
731+ world_size = world_size ,
732+ local_world_size = local_world_size ,
733+ num_poolings = num_poolings ,
734+ input_lengths = input_lengths ,
735+ a2a_comm_data_type_size = input_data_type_size ,
736+ comms_bandwidths = comms_bandwidths ,
737+ )
677738
678739 return Perf (
679740 fwd_compute = fwd_compute ,
680741 fwd_comms = fwd_comms ,
681742 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
682743 bwd_comms = bwd_comms + bwd_batched_copy ,
683744 prefetch_compute = prefetch_compute ,
745+ input_dist_comms = input_dist_comms ,
684746 )
685747
686748 @classmethod
@@ -806,13 +868,23 @@ def _get_twrw_sharding_perf(
806868 emb_dim ,
807869 table_data_type_size ,
808870 )
871+ input_dist_comms = cls ._input_dist_expected_latency (
872+ batch_sizes = batch_sizes ,
873+ world_size = world_size ,
874+ local_world_size = local_world_size ,
875+ num_poolings = num_poolings ,
876+ input_lengths = input_lengths ,
877+ a2a_comm_data_type_size = input_data_type_size ,
878+ comms_bandwidths = comms_bandwidths ,
879+ )
809880
810881 return Perf (
811882 fwd_compute = fwd_compute ,
812883 fwd_comms = fwd_comms ,
813884 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
814885 bwd_comms = bwd_comms + bwd_batched_copy ,
815886 prefetch_compute = prefetch_compute ,
887+ input_dist_comms = input_dist_comms ,
816888 )
817889
818890 @classmethod
0 commit comments