diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 90cdf70201..745143df65 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from collections import namedtuple
 from contextlib import nullcontext, contextmanager
 from dataclasses import dataclass, replace
@@ -6,22 +7,17 @@
 from functools import lru_cache, partial, wraps
 import math
 from numbers import Number
-from typing import Any, Dict, Union, Optional
-from types import NoneType
+from typing import Any, TYPE_CHECKING
 from collections.abc import Callable
-from collections.abc import Hashable
 from collections.abc import Sequence
 import copy
 import inspect
 import time
-from collections import deque
-import os
 import dataclasses
 
 import thunder
 import thunder.core.utils as utils
 from thunder.core import dtypes, prims
-import thunder.core.devices as devices
 from thunder.core.devices import cpu, Device
 from thunder.core.proxies import (
     NumberProxy,
@@ -29,18 +25,15 @@
     TensorProxy,
     FloatProxy,
     variableify,
-    unvariableify,
-    CollectionProxy,
     FutureTensorProxy,
 )
-from thunder.core.baseutils import default_dataclass_params
 from thunder.core.compile_data import get_compile_data
 from thunder.core.langctxs import langctx, Languages
 from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
 from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
-from thunder.core.trace import TraceCtx as Trace, tracectx
+from thunder.core.trace import TraceCtx as Trace
 from thunder.core.trace import VariableInterface as Variable
-from thunder.core.trace import detached_trace, get_tracectx, set_tracectx, reset_tracectx, from_trace, TraceProvenance
+from thunder.core.trace import detached_trace, set_tracectx, reset_tracectx, from_trace, TraceProvenance
 from thunder.core.utils import (
     check,
     flatten_func,
@@ -69,7 +62,8 @@
 import thunder.torch as ltorch
 
 import torch
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+
+# from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 
 import numpy as np
 
@@ -3072,6 +3066,39 @@ def index_put_aug_fwd(
     return VJPDual(primal, residuals)
 
 
+if torch.distributed.is_available():
+    from torch.distributed import ReduceOp
+    from torch.distributed import distributed_c10d as c10d
+    from torch._C._distributed_c10d import _resolve_process_group
+
+    if TYPE_CHECKING:
+        from torch.distributed import ProcessGroup
+        from thunder.distributed.prims import DistributedReduceOps
+
+    @register_augmented_forward("torch.ops._c10d_functional.all_reduce")
+    def functional_all_reduce_augmented_forward(
+        a: TensorProxy,
+        /,
+        op: str | ReduceOp | DistributedReduceOps = ReduceOp.SUM,
+        group: None | ProcessGroup | str = None,
+        async_op: bool = False,
+        **kwargs,
+    ) -> VJPDual:
+        from thunder.torch import all_reduce
+
+        if isinstance(group, str):
+            group = _resolve_process_group(group)
+        primal = all_reduce(a, op=op, group=group)
+        residuals = (op, group)
+        return VJPDual(primal, residuals)
+
+    @register_backward("torch.ops._c10d_functional.all_reduce")
+    def functional_all_backward(op, group, g) -> TensorProxy:
+        from thunder.torch import all_reduce
+
+        return all_reduce(g, op=op, group=group)
+
+
 def sum_to(a: TensorProxy, shape: Sequence[int]) -> TensorProxy:
     if not shape:
         return a.sum()
diff --git a/thunder/distributed/prims.py b/thunder/distributed/prims.py
index ede7581a48..30f786e7dc 100644
--- a/thunder/distributed/prims.py
+++ b/thunder/distributed/prims.py
@@ -42,13 +42,13 @@ class PrimIDs(Enum):
 # the tensor across processes.
 class DistributedReduceOps(Enum):
     SUM = auto()
-    # AVG = auto()
-    # PRODUCT = auto()
-    # MIN = auto()
-    # MAX = auto()
-    # BAND = auto()
-    # BOR = auto()
-    # BXOR = auto()
+    AVG = auto()
+    PRODUCT = auto()
+    MIN = auto()
+    MAX = auto()
+    BAND = auto()
+    BOR = auto()
+    BXOR = auto()
     # PREMUL_SUM = auto()
 
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 562780e927..fb49bac892 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1875,6 +1875,11 @@ def copysign_(a, b, /):
     return prims.copy_(copysign(a, b), a)
 
 
+@torchsymbol(torch.Tensor.copy_, is_method=True)  # , tags=(prims.OpTags.IN_PLACE,))
+def copy_(a, b, /):
+    return prims.copy_(b, a)
+
+
 # TODO Implement div
 @torchsymbol(torch.div, is_method=True)
 def div(
@@ -5129,7 +5134,10 @@ def _unwrap_if_dead(tensor):
     DistributedReduceOpLike = str | torch.distributed.ReduceOp | dist_prims.DistributedReduceOps
 
     # string name, PyTorch enum value, thunder.jit enum value
-    _reduceop_triples = (("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),)
+    _reduceop_triples = (
+        ("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),
+        ("max", torch.distributed.ReduceOp.MAX, dist_prims.DistributedReduceOps.MAX),
+    )
 
     def to_thunder_distributed_reduce_op(op: DistributedReduceOpLike | None):
         if isinstance(op, str):
@@ -5206,6 +5214,7 @@ def all_gather_(
     # This operation is based on torch.distributed.all_reduce, see:
     #   https://pytorch.org/docs/master/distributed.html#torch.distributed.all_reduce
     @torchsymbol(
+        torch.ops._c10d_functional.all_reduce,
         is_method=False,
         id="functional_all_reduce",
     )
@@ -5213,9 +5222,15 @@ def all_reduce(
         a: TensorLike,
         /,
         op: DistributedReduceOpLike = torch.distributed.ReduceOp.SUM,
-        group: None | torch.distributed.ProcessGroup = None,
+        group: None | torch.distributed.ProcessGroup | str = None,
         async_op: bool = False,
+        **kwargs,
     ) -> TensorLike | FutureTensorLike:
+        # note: torch.ops._c10d_functional takes name of group
+        if isinstance(group, str):
+            from torch._C._distributed_c10d import _resolve_process_group
+
+            group = _resolve_process_group(group_name=group)
         op = to_thunder_distributed_reduce_op(op)
         group = group if group is not None else torch.distributed.new_group()
 
@@ -5300,6 +5315,7 @@ def reduce_scatter_(
         return prims.copy_(out.view(output.shape), output)
 
     @torchsymbol(
+        torch.ops._c10d_functional.wait_tensor,
        is_method=True,
         id="torch.Tensor.wait",
     )
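
For reviewers, a minimal usage sketch of what this change is meant to enable: tracing the functional collectives ops torch.ops._c10d_functional.all_reduce and torch.ops._c10d_functional.wait_tensor through thunder.jit via the new torchsymbol mappings and VJP rules. This sketch is not part of the patch; the torchrun launch, the gloo backend, the lowercase "sum" reduce-op string, and the ProcessGroup.group_name property are assumptions about the test environment and a recent PyTorch build, not something this diff adds.

# usage_sketch.py -- hypothetical example, run with e.g.:
#   torchrun --nproc_per_node=2 usage_sketch.py
import torch
import torch.distributed as dist

import thunder


def allreduce_mean(x: torch.Tensor) -> torch.Tensor:
    # The _c10d_functional op takes the reduce op as a string and the group by name;
    # ProcessGroup.group_name is assumed to exist (recent PyTorch releases).
    group_name = dist.group.WORLD.group_name
    y = torch.ops._c10d_functional.all_reduce(x, "sum", group_name)
    y = torch.ops._c10d_functional.wait_tensor(y)
    return y / dist.get_world_size()


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # MASTER_ADDR/MASTER_PORT set by torchrun
    jitted = thunder.jit(allreduce_mean)
    x = torch.ones(4) * (dist.get_rank() + 1)
    # With 2 ranks, every rank should print tensor([1.5000, 1.5000, 1.5000, 1.5000]).
    print(dist.get_rank(), jitted(x))
    dist.destroy_process_group()

Passing the group by its string name is why all_reduce in thunder/torch/__init__.py now accepts str and resolves it with _resolve_process_group before handing the ProcessGroup to the distributed prim.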