@@ -392,14 +392,16 @@ def add_edges(var):
392
392
return tuple (sorted (cut_nodes ))
393
393
394
394
395
- def rematerialize_all_gather (fw_trace : TraceCtx , bw_trace : TraceCtx ) -> tuple [ TraceCtx , TraceCtx ] :
395
+ def rematerialize_all_gather (fw_trace : TraceCtx ) -> TraceCtx :
396
396
"""Insert new allgather+wait for backward trace and update the return statement for forward trace"""
397
397
398
398
from thunder .core .proxies import FutureTensorProxy
399
399
from thunder .core .trace import reset_tracectx , set_tracectx
400
400
from thunder .distributed .prims import PrimIDs as distPrimIDs
401
401
from thunder .executors .torchex import all_gather_prim_impl , wait_prim_impl
402
+ from thunder .transforms .torch_autograd import get_backward , set_backward
402
403
404
+ bw_trace = get_backward (fw_trace )
403
405
new_bw_trace = from_trace (bw_trace )
404
406
consumers = utils .consumers (fw_trace )
405
407
@@ -503,7 +505,8 @@ def rematerialize_all_gather(fw_trace: TraceCtx, bw_trace: TraceCtx) -> tuple[Tr
503
505
new_fw_trace = from_trace (fw_trace )
504
506
new_fw_trace .bound_symbols = list (fw_trace .bound_symbols )
505
507
_update_forward_with_new_saved_for_backward (new_fw_trace , new_required_for_backward )
506
- return new_fw_trace , new_bw_trace
508
+ set_backward (new_fw_trace , new_bw_trace )
509
+ return new_fw_trace
507
510
508
511
509
512
def rematerialize (trace : TraceCtx ) -> TraceCtx :
@@ -570,21 +573,23 @@ def rematerialize(trace: TraceCtx) -> TraceCtx:
570
573
return rematerialized_trace
571
574
572
575
573
- def rematerialize_forward_and_backward (fw_trace : TraceCtx , bw_trace : TraceCtx ) -> tuple [ TraceCtx , TraceCtx ] :
576
+ def rematerialize_forward_and_backward (fw_trace : TraceCtx ) -> TraceCtx :
574
577
"""Apply rematerialization optimization to the forward and backward traces.
575
578
576
579
Args:
577
580
fw_trace (TraceCtx): Forward trace.
578
- bw_trace (TraceCtx): Backward trace.
579
581
580
582
Returns:
581
- tuple[ TraceCtx, TraceCtx] : Rematerialized forward and backward traces .
583
+ TraceCtx: Rematerialized forward trace .
582
584
"""
583
585
# Circular dependency
584
586
from thunder .core .transforms import (
585
587
_update_backward_with_new_saved_for_backward ,
586
588
_update_forward_with_new_saved_for_backward ,
587
589
)
590
+ from thunder .transforms .torch_autograd import get_backward , set_backward
591
+
592
+ bw_trace = get_backward (fw_trace )
588
593
589
594
def joint_fn (args , kwargs , cotangents ):
590
595
pass
@@ -654,7 +659,8 @@ def joint_fn(args, kwargs, cotangents):
654
659
# Update the call context
655
660
new_fw_trace = update_fusion_call_ctx (new_fw_trace )
656
661
new_bw_trace = update_fusion_call_ctx (new_bw_trace )
657
- return new_fw_trace , new_bw_trace
662
+ new_fw_trace = set_backward (new_fw_trace , new_bw_trace )
663
+ return new_fw_trace
658
664
659
665
660
666
def replace_uniform (trace : TraceCtx ) -> TraceCtx :
0 commit comments