Skip to content

Commit

Permalink
Use get/set_backward
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanYashchuk committed Oct 18, 2024
1 parent 03349ac commit 51a3267
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 108 deletions.
18 changes: 12 additions & 6 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,14 +392,16 @@ def add_edges(var):
return tuple(sorted(cut_nodes))


def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[TraceCtx, TraceCtx]:
def rematerialize_all_gather(fw_trace: TraceCtx) -> TraceCtx:
"""Insert new allgather+wait for backward trace and update the return statement for forward trace"""

from thunder.core.proxies import FutureTensorProxy
from thunder.core.trace import reset_tracectx, set_tracectx
from thunder.distributed.prims import PrimIDs as distPrimIDs
from thunder.executors.torchex import all_gather_prim_impl, wait_prim_impl
from thunder.transforms.torch_autograd import get_backward, set_backward

bw_trace = get_backward(fw_trace)
new_bw_trace = from_trace(bw_trace)
consumers = utils.consumers(fw_trace)

Expand Down Expand Up @@ -503,7 +505,8 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
new_fw_trace = from_trace(fw_trace)
new_fw_trace.bound_symbols = list(fw_trace.bound_symbols)
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
return new_fw_trace, new_bw_trace
set_backward(new_fw_trace, new_bw_trace)
return new_fw_trace


def rematerialize(trace: TraceCtx) -> TraceCtx:
Expand Down Expand Up @@ -570,21 +573,23 @@ def rematerialize(trace: TraceCtx) -> TraceCtx:
return rematerialized_trace


def rematerialize_forward_and_backward(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[TraceCtx, TraceCtx]:
def rematerialize_forward_and_backward(fw_trace: TraceCtx) -> TraceCtx:
"""Apply rematerialization optimization to the forward and backward traces.
Args:
fw_trace (TraceCtx): Forward trace.
bw_trace (TraceCtx): Backward trace.
Returns:
tuple[TraceCtx, TraceCtx]: Rematerialized forward and backward traces.
TraceCtx: Rematerialized forward trace, as returned by ``set_backward`` after attaching the rematerialized backward trace.
"""
# Circular dependency
from thunder.core.transforms import (
_update_backward_with_new_saved_for_backward,
_update_forward_with_new_saved_for_backward,
)
from thunder.transforms.torch_autograd import get_backward, set_backward

bw_trace = get_backward(fw_trace)

def joint_fn(args, kwargs, cotangents):
pass
Expand Down Expand Up @@ -654,7 +659,8 @@ def joint_fn(args, kwargs, cotangents):
# Update the call context
new_fw_trace = update_fusion_call_ctx(new_fw_trace)
new_bw_trace = update_fusion_call_ctx(new_bw_trace)
return new_fw_trace, new_bw_trace
new_fw_trace = set_backward(new_fw_trace, new_bw_trace)
return new_fw_trace


def replace_uniform(trace: TraceCtx) -> TraceCtx:
Expand Down
13 changes: 9 additions & 4 deletions thunder/distributed/transforms/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def _apply_bucketing_to_backward_all_gather(self, fsdp_bwd_trace: TraceCtx) -> T
check_num_comm_and_wait(updated_bwd_trace, _ALL_GATHER_SYM_IDS | _REDUCE_SCATTER_SYM_IDS)
return updated_bwd_trace

def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCtx:
def apply_bucketing_to_backward_trace(self, fsdp_fwd_trace: TraceCtx) -> TraceCtx:
"""Apply bucketing to reduce_scatter in fsdp bwd trace.
1. Collect unsharded gradient tensor proxies and create buckets for them based on forward's buckets' name.
Expand All @@ -716,6 +716,9 @@ def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCt
Returns:
- :class:`TraceCtx`
"""
from thunder.transforms.torch_autograd import get_backward, set_backward

fsdp_bwd_trace = get_backward(fsdp_fwd_trace)

if get_skip_data_parallel_grad_sync():
utils.check(
Expand All @@ -724,18 +727,20 @@ def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCt
)
if self.requires_bwd_bucketing_allgather:
fsdp_bwd_trace = self._apply_bucketing_to_backward_all_gather(fsdp_bwd_trace)
return stash_unsharded_grads_and_return_none_as_grads(
fsdp_bwd_trace = stash_unsharded_grads_and_return_none_as_grads(
fsdp_bwd_trace,
self.compile_data,
self.index_to_fqn,
)
return set_backward(fsdp_fwd_trace, fsdp_bwd_trace)

if not self.apply_bucketing:
return fsdp_bwd_trace
return fsdp_fwd_trace

# Apply bucketing to parameter unsharding (= AllGather)
if self.requires_bwd_bucketing_allgather:
fsdp_bwd_trace = self._apply_bucketing_to_backward_all_gather(fsdp_bwd_trace)

# Apply bucketing to gradient sharding (= ReduceScatter)
return self._apply_bucketing_to_backward_reduce_scatter(fsdp_bwd_trace)
fsdp_bwd_trace = self._apply_bucketing_to_backward_reduce_scatter(fsdp_bwd_trace)
return set_backward(fsdp_fwd_trace, fsdp_bwd_trace)
Loading

0 comments on commit 51a3267

Please sign in to comment.