Skip to content

Commit 51a3267

Browse files
committed
Use get/set_backward
1 parent 03349ac commit 51a3267

File tree

3 files changed

+145
-108
lines changed

3 files changed

+145
-108
lines changed

thunder/core/rematerialization.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,16 @@ def add_edges(var):
392392
return tuple(sorted(cut_nodes))
393393

394394

395-
def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[TraceCtx, TraceCtx]:
395+
def rematerialize_all_gather(fw_trace: TraceCtx) -> TraceCtx:
396396
"""Insert new allgather+wait for backward trace and update the return statement for forward trace"""
397397

398398
from thunder.core.proxies import FutureTensorProxy
399399
from thunder.core.trace import reset_tracectx, set_tracectx
400400
from thunder.distributed.prims import PrimIDs as distPrimIDs
401401
from thunder.executors.torchex import all_gather_prim_impl, wait_prim_impl
402+
from thunder.transforms.torch_autograd import get_backward, set_backward
402403

404+
bw_trace = get_backward(fw_trace)
403405
new_bw_trace = from_trace(bw_trace)
404406
consumers = utils.consumers(fw_trace)
405407

@@ -503,7 +505,8 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
503505
new_fw_trace = from_trace(fw_trace)
504506
new_fw_trace.bound_symbols = list(fw_trace.bound_symbols)
505507
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
506-
return new_fw_trace, new_bw_trace
508+
set_backward(new_fw_trace, new_bw_trace)
509+
return new_fw_trace
507510

508511

509512
def rematerialize(trace: TraceCtx) -> TraceCtx:
@@ -570,21 +573,23 @@ def rematerialize(trace: TraceCtx) -> TraceCtx:
570573
return rematerialized_trace
571574

572575

573-
def rematerialize_forward_and_backward(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[TraceCtx, TraceCtx]:
576+
def rematerialize_forward_and_backward(fw_trace: TraceCtx) -> TraceCtx:
574577
"""Apply rematerialization optimization to the forward and backward traces.
575578
576579
Args:
577580
fw_trace (TraceCtx): Forward trace.
578-
bw_trace (TraceCtx): Backward trace.
579581
580582
Returns:
581-
tuple[TraceCtx, TraceCtx]: Rematerialized forward and backward traces.
583+
TraceCtx: Rematerialized forward trace.
582584
"""
583585
# Circular dependency
584586
from thunder.core.transforms import (
585587
_update_backward_with_new_saved_for_backward,
586588
_update_forward_with_new_saved_for_backward,
587589
)
590+
from thunder.transforms.torch_autograd import get_backward, set_backward
591+
592+
bw_trace = get_backward(fw_trace)
588593

589594
def joint_fn(args, kwargs, cotangents):
590595
pass
@@ -654,7 +659,8 @@ def joint_fn(args, kwargs, cotangents):
654659
# Update the call context
655660
new_fw_trace = update_fusion_call_ctx(new_fw_trace)
656661
new_bw_trace = update_fusion_call_ctx(new_bw_trace)
657-
return new_fw_trace, new_bw_trace
662+
new_fw_trace = set_backward(new_fw_trace, new_bw_trace)
663+
return new_fw_trace
658664

659665

660666
def replace_uniform(trace: TraceCtx) -> TraceCtx:

thunder/distributed/transforms/fsdp.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def _apply_bucketing_to_backward_all_gather(self, fsdp_bwd_trace: TraceCtx) -> T
698698
check_num_comm_and_wait(updated_bwd_trace, _ALL_GATHER_SYM_IDS | _REDUCE_SCATTER_SYM_IDS)
699699
return updated_bwd_trace
700700

701-
def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCtx:
701+
def apply_bucketing_to_backward_trace(self, fsdp_fwd_trace: TraceCtx) -> TraceCtx:
702702
"""Apply bucketing to reduce_scatter in fsdp bwd trace.
703703
704704
1. Collect unsharded gradient tensor proxies and create buckets for them based on forward's buckets' name.
@@ -716,6 +716,9 @@ def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCt
716716
Returns:
717717
- :class:`TraceCtx`
718718
"""
719+
from thunder.transforms.torch_autograd import get_backward, set_backward
720+
721+
fsdp_bwd_trace = get_backward(fsdp_fwd_trace)
719722

720723
if get_skip_data_parallel_grad_sync():
721724
utils.check(
@@ -724,18 +727,20 @@ def apply_bucketing_to_backward_trace(self, fsdp_bwd_trace: TraceCtx) -> TraceCt
724727
)
725728
if self.requires_bwd_bucketing_allgather:
726729
fsdp_bwd_trace = self._apply_bucketing_to_backward_all_gather(fsdp_bwd_trace)
727-
return stash_unsharded_grads_and_return_none_as_grads(
730+
fsdp_bwd_trace = stash_unsharded_grads_and_return_none_as_grads(
728731
fsdp_bwd_trace,
729732
self.compile_data,
730733
self.index_to_fqn,
731734
)
735+
return set_backward(fsdp_fwd_trace, fsdp_bwd_trace)
732736

733737
if not self.apply_bucketing:
734-
return fsdp_bwd_trace
738+
return fsdp_fwd_trace
735739

736740
# Apply bucketing to parameter unsharding (= AllGather)
737741
if self.requires_bwd_bucketing_allgather:
738742
fsdp_bwd_trace = self._apply_bucketing_to_backward_all_gather(fsdp_bwd_trace)
739743

740744
# Apply bucketing to gradient sharding (= ReduceScatter)
741-
return self._apply_bucketing_to_backward_reduce_scatter(fsdp_bwd_trace)
745+
fsdp_bwd_trace = self._apply_bucketing_to_backward_reduce_scatter(fsdp_bwd_trace)
746+
return set_backward(fsdp_fwd_trace, fsdp_bwd_trace)

0 commit comments

Comments
 (0)