apply sort_waits if dist_prims.wait is in a trace #776
Conversation
Force-pushed from 18aa07c to 163e7b3
if bsym.sym.id in {all_gather_prim_impl.id, reduce_scatter_prim_impl.id}:
    comm_idx = idx
if bsym.sym.id == wait_prim_impl.id:
    self.assertGreater(idx, comm_idx + 2)
The goal of this check is to make sure that wait gets distant from the comm by sort_waits. In-place comms are expressed as a pair of thunder's functional comm and copy, so initially they are next to each other.
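The intent of the assertion can be sketched with a standalone scan over a flat list of symbol names. This is a minimal model of the test's intent, not thunder's actual BoundSymbol API; the op names are illustrative stand-ins:

```python
# After sort_waits, each wait should be scheduled more than two positions
# after its matching comm, so that independent work can overlap with the
# in-flight communication.
COMM_OPS = {"torch_all_gather_prim_impl", "torch_reduce_scatter_prim_impl"}
WAIT_OP = "torch_wait_prim_impl"

def check_wait_distance(symbol_names, min_gap=2):
    """Return True if every wait appears more than min_gap positions
    after the most recent comm."""
    comm_idx = None
    for idx, name in enumerate(symbol_names):
        if name in COMM_OPS:
            comm_idx = idx
        elif name == WAIT_OP:
            if comm_idx is None or idx <= comm_idx + min_gap:
                return False
    return True

# Before sorting: comm and wait are adjacent, since an in-place comm is
# decomposed into a functional comm immediately followed by its wait.
before = ["add", "torch_all_gather_prim_impl", "torch_wait_prim_impl",
          "mul", "mul", "copy_"]
# After sort_waits: independent work sits between the comm and its wait.
after = ["add", "torch_all_gather_prim_impl", "mul", "mul",
         "torch_wait_prim_impl", "copy_"]
```

The `min_gap` of 2 mirrors the `comm_idx + 2` bound in the test above.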
Force-pushed from 66b1dc1 to fbe9526
from thunder.distributed.utils import maybe_sort_waits

with langctxs.langctx(cd.langctx):
    tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
We also have _transform_for_operator_executor_execution in transform_for_execution; does it affect anything? I noticed the bucketing was applied before _transform_for_operator_executor_execution, and now it's after it. I don't know if it matters.
To effectively apply sort_waits to a trace, the waits need to be BoundSymbols. If a trace has a BoundSymbol representing an in-place distributed op, its subsymbols are an out-of-place dist op, a wait, and a copy, and there is no BoundSymbol whose sym is wait. The call to _transform_for_operator_executor_execution here flattens such bsyms of in-place dist ops, so tmp_comp_trc has bsyms of waits if computation_trc has bsyms of in-place dist ops.

Before _transform_for_operator_executor_execution:
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b, output):
# a: "cuda:1 f32[4, 2]"
# b: "cuda:1 f32[4, 2]"
# output: "cuda:1 f32[8, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:12: c = a + b
result = ltorch.add(a, b, alpha=None) # result: "cuda:1 f32[4, 2]"
# result = prims.add(a, b) # result: "cuda:1 f32[4, 2]"
t2 = ltorch.all_gather(output, result, _torch_distributed_distributed_c10d_ProcessGroup_0, True) # t2: "cuda:1 f32[8, 2]"
# p1 = thunder.distributed.prims.all_gather(result, _torch_distributed_distributed_c10d_ProcessGroup_0, True, None) # p1: "FUTURE thunder.devices.Device(type='cuda:1') f32[8, 2]"
# t2 = thunder.distributed.prims.wait(p1) # t2: "cuda:1 f32[8, 2]"
# t2 = ltorch.view(t2, (8, 2)) # t2: "cuda:1 f32[8, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:14: e = c + 1
e = ltorch.add(result, 1, alpha=None) # e: "cuda:1 f32[4, 2]"
# _ = prims.convert_element_type(1, float)
# e = prims.add(result, 1.0) # e: "cuda:1 f32[4, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:16: f = e * b
f = ltorch.mul(e, b) # f: "cuda:1 f32[4, 2]"
# f = prims.mul(e, b) # f: "cuda:1 f32[4, 2]"
t6 = ltorch.mul(t2, 2) # t6: "cuda:1 f32[8, 2]"
# t6 = ltorch.mul(t2, 2) # t6: "cuda:1 f32[8, 2]"
# _ = prims.convert_element_type(2, float)
# t6 = prims.mul(t2, 2.0) # t6: "cuda:1 f32[8, 2]"
prims.copy_(t6, output)
# /opt/pytorch/lightning-thunder/inplace_dist.py:17: output *= 2
return f
After _transform_for_operator_executor_execution:
# Constructed by Transform for operator executor execution (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b, output):
# a: "cuda:1 f32[4, 2]"
# b: "cuda:1 f32[4, 2]"
# output: "cuda:1 f32[8, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:12: c = a + b
result = ltorch.add(a, b, alpha=None) # result: "cuda:1 f32[4, 2]"
# result = prims.add(a, b) # result: "cuda:1 f32[4, 2]"
p1 = torch_all_gather_prim_impl(result, _torch_distributed_distributed_c10d_ProcessGroup_1, True, None) # p1: "FUTURE thunder.devices.Device(type='cuda:1') f32[8, 2]"
t2 = torch_wait_prim_impl(p1) # t2: "cuda:1 f32[8, 2]"
t2 = Tensor.view(t2, (8, 2)) # t2: "cuda:1 f32[8, 2]"
# t2 = ltorch.view(t2, (8, 2)) # t2: "cuda:1 f32[8, 2]"
# t2 = ltorch.reshape(t2, (8, 2)) # t2: "cuda:1 f32[8, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:14: e = c + 1
e = ltorch.add(result, 1, alpha=None) # e: "cuda:1 f32[4, 2]"
# _ = prims.convert_element_type(1, float)
# e = prims.add(result, 1.0) # e: "cuda:1 f32[4, 2]"
# /opt/pytorch/lightning-thunder/inplace_dist.py:16: f = e * b
f = ltorch.mul(e, b) # f: "cuda:1 f32[4, 2]"
# f = prims.mul(e, b) # f: "cuda:1 f32[4, 2]"
t6 = ltorch.mul(t2, 2) # t6: "cuda:1 f32[8, 2]"
# t6 = ltorch.mul(t2, 2) # t6: "cuda:1 f32[8, 2]"
# _ = prims.convert_element_type(2, float)
# t6 = prims.mul(t2, 2.0) # t6: "cuda:1 f32[8, 2]"
prims.copy_(t6, output)
# /opt/pytorch/lightning-thunder/inplace_dist.py:17: output *= 2
return f
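The difference between the two traces above can be modeled with nested symbol trees. This is a hypothetical (name, subsymbols) structure, not thunder's real BoundSymbol class, used only to show what the flattening exposes:

```python
# Before flattening, wait exists only as a subsymbol of the in-place
# all_gather; after flattening, it is a top-level entry, which is what
# sort_waits needs in order to move it within the trace.

def has_top_level(trace, name):
    """Return True if a bound symbol with this name exists at the top level."""
    return any(sym_name == name for sym_name, _ in trace)

before_flatten = [
    ("add", []),
    ("all_gather_", [      # in-place comm, decomposed only in its subsymbols
        ("all_gather", []),
        ("wait", []),
        ("copy_", []),
    ]),
    ("mul", []),
]

after_flatten = [
    ("add", []),
    ("all_gather", []),
    ("wait", []),          # now a top-level bound symbol
    ("mul", []),
    ("copy_", []),
]
```

In the before state `has_top_level(before_flatten, "wait")` is false even though a wait runs, which is why sort_waits alone cannot reorder it.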
bw_extrace = sort_waits(bw_extrace)
if (not use_ddp) and (not use_fsdp):
Is this sorting also needed for inplace comms when using ddp/fsdp?
Could you elaborate on it? I'm not quite following the question.
In-place comms are not currently used in thunder's ddp & fsdp.
This check is rather to avoid redundant application for ddp or unwanted application for fsdp.
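The guard can be sketched as follows. This is a minimal sketch with assumed names and a toy stand-in for maybe_sort_waits, not thunder's real signatures:

```python
def maybe_sort_waits(trace):
    """Hypothetical stand-in: report whether a top-level wait exists,
    i.e. whether sort_waits would have anything to move."""
    return ("wait" in trace), trace

def transform_backward(trace, *, use_ddp=False, use_fsdp=False):
    # ddp already sorts waits in its own pass, and fsdp has its own
    # scheduling (e.g. limit_in_flight_allgathers) that this sort must
    # not disturb, so the extra sort runs only on the plain path.
    if (not use_ddp) and (not use_fsdp):
        applied, trace = maybe_sort_waits(trace)
        return applied
    return False
```

With either flag set, the sort is skipped regardless of the trace contents.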
inplace comms are not used in thunder's ddp & fsdp, currently. This check is rather to avoid redundant application for ddp or unwanted application for fsdp.
I get it now, thanks for the above explanation and trace example.
It seems for fsdp, the maybe_sort_waits in thunder/__init__ will be applied after limit_in_flight_allgathers on the fwd trace. Will it reorder the allgathers again and break limit_in_flight_allgathers?
even when `use_ddp=False` and `use_fsdp=False`

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Force-pushed from fbe9526 to a893e96
from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits

with langctxs.langctx(cd.langctx):
I think langctxs.langctx(cd.langctx) should be applied as a decorator on the get_computation_and_inputs function as a whole. I will create a PR for this.
What does this PR do?
Fixes #765