Skip to content

Commit 0136c67

Browse files
authored
fix TE - Compute and Comms overlap (#690)
1 parent d7a0568 commit 0136c67

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

thunder/distributed/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,9 @@ def key(node: Node) -> int:
8484
case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
8585
return len(order_in_trace)
8686
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
87-
# Prefer larger communication ops over smaller ones
88-
return -node.bsym.args[0].numel
87+
# We want to keep the `reduce` close to it's producer
88+
# (which is close to the original place in the trace).
89+
return order_in_trace[node.bsym]
8990
case all_gather_prim_impl.id:
9091
return len(order_in_trace) + order_in_trace[node.bsym]
9192
case _:

0 commit comments

Comments
 (0)