all_reduce & wait_tensor of torch.ops._c10d_functional (#1063)
crcrpar authored Aug 29, 2024
1 parent ceba730 commit a89aa48
Showing 3 changed files with 64 additions and 21 deletions.
51 changes: 39 additions & 12 deletions thunder/core/transforms.py
@@ -1,3 +1,4 @@
from __future__ import annotations
from collections import namedtuple
from contextlib import nullcontext, contextmanager
from dataclasses import dataclass, replace
@@ -6,41 +7,33 @@
from functools import lru_cache, partial, wraps
import math
from numbers import Number
from typing import Any, Dict, Union, Optional
from types import NoneType
from typing import Any, TYPE_CHECKING
from collections.abc import Callable
from collections.abc import Hashable
from collections.abc import Sequence
import copy
import inspect
import time
from collections import deque
import os
import dataclasses

import thunder
import thunder.core.utils as utils
from thunder.core import dtypes, prims
import thunder.core.devices as devices
from thunder.core.devices import cpu, Device
from thunder.core.proxies import (
NumberProxy,
Proxy,
TensorProxy,
FloatProxy,
variableify,
unvariableify,
CollectionProxy,
FutureTensorProxy,
)
from thunder.core.baseutils import default_dataclass_params
from thunder.core.compile_data import get_compile_data
from thunder.core.langctxs import langctx, Languages
from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
from thunder.core.trace import TraceCtx as Trace, tracectx
from thunder.core.trace import TraceCtx as Trace
from thunder.core.trace import VariableInterface as Variable
from thunder.core.trace import detached_trace, get_tracectx, set_tracectx, reset_tracectx, from_trace, TraceProvenance
from thunder.core.trace import detached_trace, set_tracectx, reset_tracectx, from_trace, TraceProvenance
from thunder.core.utils import (
check,
flatten_func,
@@ -69,7 +62,8 @@
import thunder.torch as ltorch

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
import numpy as np


@@ -3072,6 +3066,39 @@ def index_put_aug_fwd(
return VJPDual(primal, residuals)


if torch.distributed.is_available():
from torch.distributed import ReduceOp
from torch.distributed import distributed_c10d as c10d
from torch._C._distributed_c10d import _resolve_process_group

if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from thunder.distributed.prims import DistributedReduceOps

@register_augmented_forward("torch.ops._c10d_functional.all_reduce")
def functional_all_reduce_augmented_forward(
a: TensorProxy,
/,
op: str | ReduceOp | DistributedReduceOps = ReduceOp.SUM,
group: None | ProcessGroup | str = None,
async_op: bool = False,
**kwargs,
) -> VJPDual:
from thunder.torch import all_reduce

if isinstance(group, str):
group = _resolve_process_group(group)
primal = all_reduce(a, op=op, group=group)
residuals = (op, group)
return VJPDual(primal, residuals)

@register_backward("torch.ops._c10d_functional.all_reduce")
def functional_all_backward(op, group, g) -> TensorProxy:
from thunder.torch import all_reduce

return all_reduce(g, op=op, group=group)


def sum_to(a: TensorProxy, shape: Sequence[int]) -> TensorProxy:
if not shape:
return a.sum()
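The augmented forward/backward pair registered above for torch.ops._c10d_functional.all_reduce encodes the standard VJP rule for a summing all-reduce: every rank's output is the same global sum, so the gradient with respect to each rank's input is the sum of all ranks' output cotangents, i.e. another all-reduce with the same op and group. A minimal, framework-free sketch of that identity (illustrative only, not part of the commit):

# Framework-free sketch of the VJP rule above; names and values are illustrative.
def all_reduce_sum(per_rank_values):
    """Simulate all_reduce with ReduceOp.SUM: every rank receives the global sum."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)

xs = [1.0, 2.0, 3.0]        # per-rank inputs x_0, x_1, x_2
ys = all_reduce_sum(xs)     # forward: every rank sees y = 6.0

gs = [0.1, 0.2, 0.3]        # per-rank cotangents dL_r/dy
grads = all_reduce_sum(gs)  # backward: dL/dx_r = sum_r dL_r/dy = 0.6 on every rank
assert all(abs(g - sum(gs)) < 1e-12 for g in grads)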
14 changes: 7 additions & 7 deletions thunder/distributed/prims.py
@@ -42,13 +42,13 @@ class PrimIDs(Enum):
# the tensor across processes.
class DistributedReduceOps(Enum):
SUM = auto()
# AVG = auto()
# PRODUCT = auto()
# MIN = auto()
# MAX = auto()
# BAND = auto()
# BOR = auto()
# BXOR = auto()
AVG = auto()
PRODUCT = auto()
MIN = auto()
MAX = auto()
BAND = auto()
BOR = auto()
BXOR = auto()
# PREMUL_SUM = auto()


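With the additional DistributedReduceOps members enabled above, each thunder value has a torch.distributed.ReduceOp counterpart of the same name. The table thunder actually consults is _reduceop_triples in thunder/torch/__init__.py (next file), which this commit extends to SUM and MAX only; the sketch below is a hedged illustration of pairing the remaining members by name, not project code.

# Hedged sketch (not part of the commit): pair enum members by name,
# assuming a PyTorch build with distributed support.
import torch.distributed as dist
from thunder.distributed.prims import DistributedReduceOps

if dist.is_available():
    by_name = {
        name: (getattr(dist.ReduceOp, name), member)
        for name, member in DistributedReduceOps.__members__.items()
        if hasattr(dist.ReduceOp, name)
    }
    # e.g. by_name["MAX"] -> (dist.ReduceOp.MAX, DistributedReduceOps.MAX)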
20 changes: 18 additions & 2 deletions thunder/torch/__init__.py
@@ -1875,6 +1875,11 @@ def copysign_(a, b, /):
return prims.copy_(copysign(a, b), a)


@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)


# TODO Implement div
@torchsymbol(torch.div, is_method=True)
def div(
@@ -5129,7 +5134,10 @@ def _unwrap_if_dead(tensor):
DistributedReduceOpLike = str | torch.distributed.ReduceOp | dist_prims.DistributedReduceOps

# string name, PyTorch enum value, thunder.jit enum value
_reduceop_triples = (("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),)
_reduceop_triples = (
("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),
("max", torch.distributed.ReduceOp.MAX, dist_prims.DistributedReduceOps.MAX),
)

def to_thunder_distributed_reduce_op(op: DistributedReduceOpLike | None):
if isinstance(op, str):
@@ -5206,16 +5214,23 @@ def all_gather_(
# This operation is based on torch.distributed.all_reduce, see:
# https://pytorch.org/docs/master/distributed.html#torch.distributed.all_reduce
@torchsymbol(
torch.ops._c10d_functional.all_reduce,
is_method=False,
id="functional_all_reduce",
)
def all_reduce(
a: TensorLike,
/,
op: DistributedReduceOpLike = torch.distributed.ReduceOp.SUM,
group: None | torch.distributed.ProcessGroup = None,
group: None | torch.distributed.ProcessGroup | str = None,
async_op: bool = False,
**kwargs,
) -> TensorLike | FutureTensorLike:
# note: torch.ops._c10d_functional takes name of group
if isinstance(group, str):
from torch._C._distributed_c10d import _resolve_process_group

group = _resolve_process_group(group_name=group)
op = to_thunder_distributed_reduce_op(op)
group = group if group is not None else torch.distributed.new_group()

@@ -5300,6 +5315,7 @@ def reduce_scatter_(
return prims.copy_(out.view(output.shape), output)

@torchsymbol(
torch.ops._c10d_functional.wait_tensor,
is_method=True,
id="torch.Tensor.wait",
)
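Taken together, the new torchsymbols let a thunder.jit'd function call the functional collective ops directly: the string group name is resolved to a ProcessGroup inside all_reduce, and wait_tensor is traced rather than falling back to eager. A hedged usage sketch, assuming an already-initialized process group (e.g. under torchrun) whose registered name is passed as group_name; both are illustrative assumptions, not shown in the commit:

import torch
import thunder

def fn(x: torch.Tensor, group_name: str) -> torch.Tensor:
    # Mapped to thunder.torch.all_reduce; "sum" goes through
    # to_thunder_distributed_reduce_op and group_name through
    # _resolve_process_group, as in the hunks above.
    y = torch.ops._c10d_functional.all_reduce(x, "sum", group_name)
    # wait_tensor is now a torchsymbol too, so the wait is traced as well.
    return torch.ops._c10d_functional.wait_tensor(y)

jfn = thunder.jit(fn)
# Under torchrun, with init_process_group() done and `name` being the
# registered name of the group (placeholder here):
# out = jfn(torch.ones(4, device="cuda"), name)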