apply sort_waits if dist_prims.wait is in a trace #776

Merged: 6 commits into main on Jul 19, 2024

Conversation

crcrpar (Collaborator) commented Jul 16, 2024

What does this PR do?

Fixes #765

@crcrpar crcrpar force-pushed the crpa/torch-native-dist-ops-with-waits-sorted branch 3 times, most recently from 18aa07c to 163e7b3 Compare July 17, 2024 11:22
@crcrpar crcrpar requested a review from kiya00 July 17, 2024 13:00
@crcrpar crcrpar marked this pull request as ready for review July 17, 2024 13:00
if bsym.sym.id in {all_gather_prim_impl.id, reduce_scatter_prim_impl.id}:
    comm_idx = idx
if bsym.sym.id == wait_prim_impl.id:
    self.assertGreater(idx, comm_idx + 2)
crcrpar (Collaborator, Author) commented:

The goal of this check is to make sure that sort_waits moves the wait away from the comm. In-place comms are expressed as a pair of thunder's functional comm and a copy, so initially the comm and its wait sit right next to each other.
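
For reference, a minimal standalone sketch of what this check does. The helper name assert_waits_are_sorted and the comm_ids, wait_id, and min_gap parameters are illustrative, not part of the actual test:

def assert_waits_are_sorted(bound_symbols, comm_ids, wait_id, min_gap=2):
    # Walk the trace in order; every wait must appear at least min_gap
    # bound symbols after the most recent communication op.
    comm_idx = None
    for idx, bsym in enumerate(bound_symbols):
        if bsym.sym.id in comm_ids:
            comm_idx = idx
        elif bsym.sym.id == wait_id:
            assert comm_idx is not None, "wait encountered before any comm"
            assert idx > comm_idx + min_gap, f"wait at {idx} too close to comm at {comm_idx}"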

@crcrpar crcrpar force-pushed the crpa/torch-native-dist-ops-with-waits-sorted branch 2 times, most recently from 66b1dc1 to fbe9526 Compare July 18, 2024 00:54
from thunder.distributed.utils import maybe_sort_waits

with langctxs.langctx(cd.langctx):
    tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list)
Collaborator:

We also have this _transform_for_operator_executor_execution in transform_for_execution; does it affect anything? I also noticed that bucketing used to be applied before _transform_for_operator_executor_execution and is now applied after it; I don't know if that matters.

crcrpar (Collaborator, Author) commented Jul 19, 2024:

To effectively apply sort_waits to a trace, the waits need to be BoundSymbols. If a trace has a BoundSymbol representing an in-place distributed op, its subsymbols are an out-of-place dist op, a wait, and a copy, and there is no BoundSymbol whose sym is wait.
The call to _transform_for_operator_executor_execution here flattens such bsyms of in-place dist ops, so tmp_comp_trc has bsyms of waits whenever computation_trc has bsyms of in-place dist ops.

Before _transform_for_operator_executor_execution:

# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b, output):
  # a: "cuda:1 f32[4, 2]"
  # b: "cuda:1 f32[4, 2]"
  # output: "cuda:1 f32[8, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:12:      c = a + b
  result = ltorch.add(a, b, alpha=None)  # result: "cuda:1 f32[4, 2]"
    # result = prims.add(a, b)  # result: "cuda:1 f32[4, 2]"
  t2 = ltorch.all_gather(output, result, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # t2: "cuda:1 f32[8, 2]"
    # p1 = thunder.distributed.prims.all_gather(result, _torch_distributed_distributed_c10d_ProcessGroup_0, True, None)  # p1: "FUTURE thunder.devices.Device(type='cuda:1') f32[8, 2]"
    # t2 = thunder.distributed.prims.wait(p1)  # t2: "cuda:1 f32[8, 2]"
    # t2 = ltorch.view(t2, (8, 2))  # t2: "cuda:1 f32[8, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:14:      e = c + 1
  e = ltorch.add(result, 1, alpha=None)  # e: "cuda:1 f32[4, 2]"
    # _ = prims.convert_element_type(1, float)
    # e = prims.add(result, 1.0)  # e: "cuda:1 f32[4, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:16:      f = e * b
  f = ltorch.mul(e, b)  # f: "cuda:1 f32[4, 2]"
    # f = prims.mul(e, b)  # f: "cuda:1 f32[4, 2]"
  t6 = ltorch.mul(t2, 2)  # t6: "cuda:1 f32[8, 2]"
    # t6 = ltorch.mul(t2, 2)  # t6: "cuda:1 f32[8, 2]"
      # _ = prims.convert_element_type(2, float)
      # t6 = prims.mul(t2, 2.0)  # t6: "cuda:1 f32[8, 2]"
  prims.copy_(t6, output)

  # /opt/pytorch/lightning-thunder/inplace_dist.py:17:      output *= 2
  return f

After _transform_for_operator_executor_execution:

# Constructed by Transform for operator executor execution (took 0 milliseconds)
import thunder
import thunder.core.prims as prims
import thunder.torch as ltorch
from torch import Tensor
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a, b, output):
  # a: "cuda:1 f32[4, 2]"
  # b: "cuda:1 f32[4, 2]"
  # output: "cuda:1 f32[8, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:12:      c = a + b
  result = ltorch.add(a, b, alpha=None)  # result: "cuda:1 f32[4, 2]"
    # result = prims.add(a, b)  # result: "cuda:1 f32[4, 2]"
  p1 = torch_all_gather_prim_impl(result, _torch_distributed_distributed_c10d_ProcessGroup_1, True, None)  # p1: "FUTURE thunder.devices.Device(type='cuda:1') f32[8, 2]"
  t2 = torch_wait_prim_impl(p1)  # t2: "cuda:1 f32[8, 2]"
  t2 = Tensor.view(t2, (8, 2))  # t2: "cuda:1 f32[8, 2]"
    # t2 = ltorch.view(t2, (8, 2))  # t2: "cuda:1 f32[8, 2]"
      # t2 = ltorch.reshape(t2, (8, 2))  # t2: "cuda:1 f32[8, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:14:      e = c + 1
  e = ltorch.add(result, 1, alpha=None)  # e: "cuda:1 f32[4, 2]"
    # _ = prims.convert_element_type(1, float)
    # e = prims.add(result, 1.0)  # e: "cuda:1 f32[4, 2]"

  # /opt/pytorch/lightning-thunder/inplace_dist.py:16:      f = e * b
  f = ltorch.mul(e, b)  # f: "cuda:1 f32[4, 2]"
    # f = prims.mul(e, b)  # f: "cuda:1 f32[4, 2]"
  t6 = ltorch.mul(t2, 2)  # t6: "cuda:1 f32[8, 2]"
    # t6 = ltorch.mul(t2, 2)  # t6: "cuda:1 f32[8, 2]"
      # _ = prims.convert_element_type(2, float)
      # t6 = prims.mul(t2, 2.0)  # t6: "cuda:1 f32[8, 2]"
  prims.copy_(t6, output)

  # /opt/pytorch/lightning-thunder/inplace_dist.py:17:      output *= 2
  return f
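
A rough sketch of the gating this enables (a hypothetical helper, not thunder's actual maybe_sort_waits): flatten first so in-place dist ops expose their wait subsymbols as top-level bound symbols, then sort only if a wait is actually present.

def sort_waits_if_present(trace, flatten_pass, sort_waits_pass, wait_sym_ids):
    # flatten_pass stands in for _transform_for_operator_executor_execution;
    # it decomposes in-place dist ops into comm, wait, and copy bound symbols.
    flat_trace = flatten_pass(trace)
    has_wait = any(bsym.sym.id in wait_sym_ids for bsym in flat_trace.bound_symbols)
    # Only reorder when there is a wait bound symbol to move.
    return sort_waits_pass(flat_trace) if has_wait else trace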

bw_extrace = sort_waits(bw_extrace)
if (not use_ddp) and (not use_fsdp):
Collaborator:

Is this sorting also needed for in-place comms when using ddp/fsdp?

crcrpar (Collaborator, Author) commented Jul 19, 2024:

Could you elaborate? I'm not quite following the question.

In-place comms are not currently used in thunder's ddp & fsdp. This check is rather to avoid redundant application for ddp and unwanted application for fsdp.
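
In other words (an illustrative helper, not the literal diff): ddp and fsdp already run sort_waits on their own traces, so the extra pass is only applied on the plain path that may contain torch-native in-place comms.

def maybe_apply_extra_sort(computation_trc, use_ddp, use_fsdp, sort_pass):
    if use_ddp or use_fsdp:
        # The ddp/fsdp transforms already call sort_waits where appropriate.
        return computation_trc
    # Plain path: sort waits introduced by torch-native in-place comms.
    return sort_pass(computation_trc)  # e.g. maybe_sort_waits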

Collaborator:

In-place comms are not currently used in thunder's ddp & fsdp. This check is rather to avoid redundant application for ddp and unwanted application for fsdp.

I get it now, thanks for the explanation and the trace example above.

kiya00 (Collaborator) commented Jul 19, 2024:

It seems that for fsdp, the maybe_sort_waits in thunder/__init__ will be applied after limit_in_flight_allgathers on the forward trace. Will it prioritize the allgathers again and break limit_in_flight_allgathers?

crcrpar added 6 commits July 19, 2024 18:05
even when `use_ddp=False` and `use_fsdp=False`

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar force-pushed the crpa/torch-native-dist-ops-with-waits-sorted branch from fbe9526 to a893e96 Compare July 19, 2024 09:05
t-vi (Collaborator) commented Jul 19, 2024:

Thank you @crcrpar @kiya00

@t-vi t-vi merged commit 721e28e into main Jul 19, 2024
36 checks passed
@t-vi t-vi deleted the crpa/torch-native-dist-ops-with-waits-sorted branch July 19, 2024 14:29
from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits

with langctxs.langctx(cd.langctx):
Collaborator:

I think langctxs.langctx(cd.langctx) should be applied as a decorator on the get_computation_and_inputs function as a whole. I will create a PR for this.
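
For illustration, a toy example of the decorator-vs-with-block difference being suggested, using a stand-in context manager rather than thunder's langctxs.langctx:

import contextlib

@contextlib.contextmanager
def toy_langctx(name):
    # Stand-in for langctxs.langctx; prints instead of switching language contexts.
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

# Current shape: only part of the body runs inside the context.
def get_computation_and_inputs_partial():
    with toy_langctx("torch"):
        pass  # _transform_for_operator_executor_execution(...)

# Suggested shape: the whole function body runs inside the context.
@toy_langctx("torch")
def get_computation_and_inputs_whole():
    pass  # entire body under the language context

In the real code the language context comes from cd.langctx at call time, so the actual change may need a small wrapper rather than a literal decorator applied at definition time.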

Development

Successfully merging this pull request may close these issues.

Apply thunder.distributed.utils.sort_waits appropriately when possible