diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 90cdf70201..745143df65 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from collections import namedtuple
 from contextlib import nullcontext, contextmanager
 from dataclasses import dataclass, replace
@@ -6,22 +7,17 @@
 from functools import lru_cache, partial, wraps
 import math
 from numbers import Number
-from typing import Any, Dict, Union, Optional
-from types import NoneType
+from typing import Any, TYPE_CHECKING
 from collections.abc import Callable
-from collections.abc import Hashable
 from collections.abc import Sequence
 import copy
 import inspect
 import time
-from collections import deque
-import os
 import dataclasses
 
 import thunder
 import thunder.core.utils as utils
 from thunder.core import dtypes, prims
-import thunder.core.devices as devices
 from thunder.core.devices import cpu, Device
 from thunder.core.proxies import (
     NumberProxy,
@@ -29,18 +25,15 @@
     TensorProxy,
     FloatProxy,
     variableify,
-    unvariableify,
-    CollectionProxy,
     FutureTensorProxy,
 )
-from thunder.core.baseutils import default_dataclass_params
 from thunder.core.compile_data import get_compile_data
 from thunder.core.langctxs import langctx, Languages
 from thunder.core.pytree import tree_flatten, tree_map, tree_unflatten, tree_flatten_with_dataclass
 from thunder.core.symbol import BoundSymbol, BoundSymbolInterface, Symbol
-from thunder.core.trace import TraceCtx as Trace, tracectx
+from thunder.core.trace import TraceCtx as Trace
 from thunder.core.trace import VariableInterface as Variable
-from thunder.core.trace import detached_trace, get_tracectx, set_tracectx, reset_tracectx, from_trace, TraceProvenance
+from thunder.core.trace import detached_trace, set_tracectx, reset_tracectx, from_trace, TraceProvenance
 from thunder.core.utils import (
     check,
     flatten_func,
@@ -69,7 +62,8 @@
 import thunder.torch as ltorch
 
 import torch
-from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+
+# from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 import numpy as np
 
 
@@ -3072,6 +3066,39 @@ def index_put_aug_fwd(
     return VJPDual(primal, residuals)
 
 
+if torch.distributed.is_available():
+    from torch.distributed import ReduceOp
+    from torch.distributed import distributed_c10d as c10d
+    from torch._C._distributed_c10d import _resolve_process_group
+
+    if TYPE_CHECKING:
+        from torch.distributed import ProcessGroup
+        from thunder.distributed.prims import DistributedReduceOps
+
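+    # Register VJP rules for torch.ops._c10d_functional.all_reduce so that traces containing
+    # the functional collective can be differentiated via thunder.torch.all_reduce.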
+    @register_augmented_forward("torch.ops._c10d_functional.all_reduce")
+    def functional_all_reduce_augmented_forward(
+        a: TensorProxy,
+        /,
+        op: str | ReduceOp | DistributedReduceOps = ReduceOp.SUM,
+        group: None | ProcessGroup | str = None,
+        async_op: bool = False,
+        **kwargs,
+    ) -> VJPDual:
+        from thunder.torch import all_reduce
+
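+        # The functional collective passes the process group by name; resolve it to a ProcessGroup.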
+        if isinstance(group, str):
+            group = _resolve_process_group(group)
+        primal = all_reduce(a, op=op, group=group)
+        residuals = (op, group)
+        return VJPDual(primal, residuals)
+
+    @register_backward("torch.ops._c10d_functional.all_reduce")
+    def functional_all_reduce_backward(op, group, g) -> TensorProxy:
+        from thunder.torch import all_reduce
+
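+        # Backward of all_reduce: all-reduce the incoming gradient over the same group with the same op.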
+        return all_reduce(g, op=op, group=group)
+
+
 def sum_to(a: TensorProxy, shape: Sequence[int]) -> TensorProxy:
     if not shape:
         return a.sum()
diff --git a/thunder/distributed/prims.py b/thunder/distributed/prims.py
index ede7581a48..30f786e7dc 100644
--- a/thunder/distributed/prims.py
+++ b/thunder/distributed/prims.py
@@ -42,13 +42,13 @@ class PrimIDs(Enum):
 #   the tensor across processes.
 class DistributedReduceOps(Enum):
     SUM = auto()
-    # AVG = auto()
-    # PRODUCT = auto()
-    # MIN = auto()
-    # MAX = auto()
-    # BAND = auto()
-    # BOR = auto()
-    # BXOR = auto()
+    AVG = auto()
+    PRODUCT = auto()
+    MIN = auto()
+    MAX = auto()
+    BAND = auto()
+    BOR = auto()
+    BXOR = auto()
     # PREMUL_SUM = auto()
 
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 562780e927..fb49bac892 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -1875,6 +1875,11 @@ def copysign_(a, b, /):
     return prims.copy_(copysign(a, b), a)
 
 
+@torchsymbol(torch.Tensor.copy_, is_method=True)  # , tags=(prims.OpTags.IN_PLACE,))
+def copy_(a, b, /):
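+    # prims.copy_ takes (copy_from, copy_to), so this copies the value of b into a, matching torch.Tensor.copy_.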
+    return prims.copy_(b, a)
+
+
 # TODO Implement div
 @torchsymbol(torch.div, is_method=True)
 def div(
@@ -5129,7 +5134,10 @@ def _unwrap_if_dead(tensor):
     DistributedReduceOpLike = str | torch.distributed.ReduceOp | dist_prims.DistributedReduceOps
 
     # string name, PyTorch enum value, thunder.jit enum value
-    _reduceop_triples = (("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),)
+    _reduceop_triples = (
+        ("sum", torch.distributed.ReduceOp.SUM, dist_prims.DistributedReduceOps.SUM),
+        ("max", torch.distributed.ReduceOp.MAX, dist_prims.DistributedReduceOps.MAX),
+    )
 
     def to_thunder_distributed_reduce_op(op: DistributedReduceOpLike | None):
         if isinstance(op, str):
@@ -5206,6 +5214,7 @@ def all_gather_(
     # This operation is based on torch.distributed.all_reduce, see:
     #   https://pytorch.org/docs/master/distributed.html#torch.distributed.all_reduce
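+    # torch.ops._c10d_functional.all_reduce is also registered here so that thunder.jit maps the
+    # functional collective onto this symbol.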
     @torchsymbol(
+        torch.ops._c10d_functional.all_reduce,
         is_method=False,
         id="functional_all_reduce",
     )
@@ -5213,9 +5222,15 @@ def all_reduce(
         a: TensorLike,
         /,
         op: DistributedReduceOpLike = torch.distributed.ReduceOp.SUM,
-        group: None | torch.distributed.ProcessGroup = None,
+        group: None | torch.distributed.ProcessGroup | str = None,
         async_op: bool = False,
+        **kwargs,
     ) -> TensorLike | FutureTensorLike:
+        # Note: torch.ops._c10d_functional identifies the process group by name (a str); resolve it here.
+        if isinstance(group, str):
+            from torch._C._distributed_c10d import _resolve_process_group
+
+            group = _resolve_process_group(group_name=group)
         op = to_thunder_distributed_reduce_op(op)
         group = group if group is not None else torch.distributed.new_group()
 
@@ -5300,6 +5315,7 @@ def reduce_scatter_(
         return prims.copy_(out.view(output.shape), output)
 
     @torchsymbol(
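+        # torch.ops._c10d_functional.wait_tensor is likewise mapped onto this wait symbol.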
+        torch.ops._c10d_functional.wait_tensor,
         is_method=True,
         id="torch.Tensor.wait",
     )