We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d7a0568 commit 0136c67Copy full SHA for 0136c67
thunder/distributed/utils.py
@@ -84,8 +84,9 @@ def key(node: Node) -> int:
84
case wait_prim_impl.id | unpack_for_fsdp_prim_impl.id:
85
return len(order_in_trace)
86
case reduce_scatter_prim_impl.id | all_reduce_prim_impl.id:
87
- # Prefer larger communication ops over smaller ones
88
- return -node.bsym.args[0].numel
+ # We want to keep the `reduce` close to it's producer
+ # (which is close to the original place in the trace).
89
+ return order_in_trace[node.bsym]
90
case all_gather_prim_impl.id:
91
return len(order_in_trace) + order_in_trace[node.bsym]
92
case _:
0 commit comments