@@ -139,7 +139,18 @@ def _get_event_busbw_factor(evt):
139139
140140    return  correction_factor_func (group_size )
141141
142- def  _calculate_busbw_for_uneven_all_to_all (evt , global_rank ):
def _is_uneven_all_to_all_evt(evt):
    """Tell whether ``evt`` is an all-to-all collective with explicit split sizes.

    Args:
        evt: Trace event dict. ``evt["args"]`` is looked up for
            "Collective name" and, for all-to-all events only, for
            "In split size" and "Out split size".
            NOTE(review): the split-size values are presumably strings
            holding Python list literals such as "[]" -- confirm against
            the trace format.

    Returns:
        ``False`` when the collective is not ``all_to_all``/``all_to_allv``.
        Otherwise the first truthy parsed split-size value, or the (falsy)
        parsed "Out split size" when both are empty. Callers should use the
        result only for its truthiness.
    """
    # NOTE(review): the third argument looks like the message used by
    # _get_dict_value when the key is missing -- confirm against the helper.
    coll_name = _get_dict_value(
        evt["args"],
        "Collective name",
        f'Missing "Collective name" in event: {evt}',
    )
    if coll_name not in ("all_to_all", "all_to_allv"):
        return False

    # Split sizes are consulted only for all-to-all events, so other
    # collectives are never required to carry these keys.
    args = evt["args"]
    return (
        ast.literal_eval(args["In split size"])
        or ast.literal_eval(args["Out split size"])
    )
152+ 
153+ def  _get_uneven_all_to_all_data_size (evt , global_rank ):
143154    group_size  =  evt ["args" ]["Group size" ]
144155    local_rank  =  _parse_ranks (evt ["args" ]["Process Group Ranks" ], group_size ).index (global_rank )
145156    in_elems_count  =  evt ["args" ]["In msg nelems" ]
@@ -158,7 +169,10 @@ def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
158169    else :
159170        recv_elems  =  out_elems_count  /  group_size  *  (group_size  -  1 )
160171
161-     return  round (max (send_elems , recv_elems ) *  dtype_size  /  evt ["dur" ] *  1e-3 , 2 )
172+     return  max (send_elems , recv_elems ) *  dtype_size 
173+ 
def _calculate_busbw_for_uneven_all_to_all(evt, global_rank):
    """Return the bus bandwidth of an uneven all-to-all event, rounded to 2 places.

    The data size comes from ``_get_uneven_all_to_all_data_size`` and is
    divided by ``evt["dur"]``, then scaled by ``1e-3``.
    NOTE(review): presumably bytes per microsecond scaled to GB/s -- confirm
    the units of ``evt["dur"]``.

    Args:
        evt: Trace event dict; must provide ``"dur"`` plus whatever
            ``_get_uneven_all_to_all_data_size`` reads from it.
        global_rank: Rank forwarded unchanged to the data-size helper.
    """
    data_size = _get_uneven_all_to_all_data_size(evt, global_rank)
    bandwidth = data_size / evt["dur"] * 1e-3
    return round(bandwidth, 2)
162176
163177def  calculate_bw_ (trace_data , global_rank ):
164178    nccl_events  =  [
@@ -184,10 +198,7 @@ def calculate_bw_(trace_data, global_rank):
184198
185199            algbw  =  _calculate_algbw (evt )
186200            busbw_factor  =  _get_event_busbw_factor (evt )
187-             if  (coll_name  in  ["all_to_all" , "all_to_allv" ] 
188-                 and  (ast .literal_eval (evt ['args' ]['In split size' ]) 
189-                     or  ast .literal_eval (evt ['args' ]['Out split size' ]))
190-                 ):
201+             if  _is_uneven_all_to_all_evt (evt ):
191202                # calculate busbw for uneven all_to_all 
192203                busbw  =  _calculate_busbw_for_uneven_all_to_all (evt , global_rank )
193204            else :
@@ -206,7 +217,7 @@ def calculate_bw_(trace_data, global_rank):
206217            logger .error (f"- Error: { err_msg }  )
207218
208219
209- def  calculate_sbw (trace_data ):
220+ def  calculate_sbw (trace_data ,  global_rank ):
210221    # calculate shared bw per rank 
211222    nccl_events  =  [
212223        i 
@@ -221,6 +232,8 @@ def calculate_sbw(trace_data):
221232
222233    total_data_size  =  sum (
223234        _calculate_event_data_size (evt ) *  _get_event_busbw_factor (evt )
235+         if  not  _is_uneven_all_to_all_evt (evt )
236+         else  _get_uneven_all_to_all_data_size (evt , global_rank )
224237        for  evt  in  nccl_events 
225238    )
226239
@@ -336,7 +349,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
336349        ) as  f :
337350            json .dump (trace , f )
338351
339-         sbw_lst .append (calculate_sbw (trace ))
352+         sbw_lst .append (calculate_sbw (trace ,  global_rank ))
340353
341354        pick_iter_e2e_time_ (trace , iter_e2e_time )
342355        pick_comm_bw_ (trace , comm_bw_data )
@@ -367,7 +380,7 @@ def analyze_profiler_trace(trace_dir: str, report_dir: str):
367380            f"avg. E2ETime of iters among all ranks: { sum (iter_e2e_time ) /  len (iter_e2e_time ) /  1e3  :.3f} \n " 
368381        )
369382        f .write (
370-             f"avg. SharedBW (i.e. sum(data_size * busbw_factor ) / GPU_comm_busy_time  per rank) among all ranks: { sum (sbw_lst ) /  len (sbw_lst ) :.3f} \n " 
383+             f"avg. SharedBW (i.e. sum(busbw_data_size ) / GPU_comm_busy_time  per rank) among all ranks: { sum (sbw_lst ) /  len (sbw_lst ) :.3f} \n " 
371384        )
372385
373386        f .write (
0 commit comments