Commit ee5d206

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e30f6f9 commit ee5d206
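
pre-commit.ci runs a repository's configured pre-commit hooks against each pull request and, when a hook rewrites files, pushes the result back as an auto-fix commit like this one. Judging from the diff below, the rewrites are consistent with Black-style line wrapping plus pyupgrade-style typing modernization, but the exact hook set depends on the repository's .pre-commit-config.yaml, which is not part of this commit.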

3 files changed (+52, -65 lines)


thunder/core/rematerialization.py

Lines changed: 6 additions & 2 deletions
@@ -687,15 +687,19 @@ def add_to_swapmap(p):
     )
     new_fw_trace.bound_symbols.append(replace(fw_trace.bound_symbols[-1], args=fw_trace.bound_symbols[-1].args))
 
-    # outputs required for backward may have different names between forward and backward.
+    # outputs required for backward may have different names between forward and backward.
     # Rematerialisation may remove some outs from the forward.
     old_saved_for_backward_fw = (*fw_trace.bound_symbols[-1].args[1][0], *fw_trace.bound_symbols[-1].args[1][1])
     old_saved_for_backward_bw = []
     for bsym in bw_trace.bound_symbols:
         if bsym.sym.id == PrimIDs.UNPACK_SEQUENCE:
             flattened_args = tree_flatten(bw_trace.args[1])[0]
             proxy_names = {y.name for y in flattened_args if isinstance(y, ProxyInterface)}
-            if all(not isinstance(out, CollectionProxy) and out.name not in proxy_names for out in bsym.flat_outs if out is not None):
+            if all(
+                not isinstance(out, CollectionProxy) and out.name not in proxy_names
+                for out in bsym.flat_outs
+                if out is not None
+            ):
                 old_saved_for_backward_bw += bsym.flat_outs
     assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
     new_required_for_bakward_fw_to_bw_map = {
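
Both changes in this hunk are formatting-only: the comment line is a whitespace-only fix (the -/+ pair is textually identical, which is typical of a trailing-whitespace hook), and the single-line all(...) call is expanded so each generator clause gets its own line. A minimal sketch of the same filter pattern, with hypothetical stand-ins for thunder's proxy machinery (Out plays the role of a proxy; CollectionProxy and bsym.flat_outs are not reproduced here):

# Sketch of the reformatted all(...) filter, with hypothetical stand-ins
# for thunder's proxy types.
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Out:
    name: str
    is_collection: bool = False  # stands in for isinstance(out, CollectionProxy)


def nothing_already_unpacked(flat_outs: list[Out | None], proxy_names: set[str]) -> bool:
    # Same shape as the diff: one generator clause per line once wrapped.
    return all(
        not out.is_collection and out.name not in proxy_names
        for out in flat_outs
        if out is not None
    )


outs = [Out("t0"), None, Out("t1")]
print(nothing_already_unpacked(outs, {"t1"}))  # False: t1 is already a known proxy name
print(nothing_already_unpacked(outs, {"t9"}))  # True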

thunder/core/transforms.py

Lines changed: 10 additions & 11 deletions
@@ -3254,33 +3254,32 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace, do_it=False
 
     # outputs required for backward may have different names between forward and backward.
     # Rematerialisation may remove some outs from the forward.
-    old_saved_for_backward_fw = (
-        *fwd_trace.bound_symbols[-1].args[1][0],
-        *fwd_trace.bound_symbols[-1].args[1][1]
-    )
+    old_saved_for_backward_fw = (*fwd_trace.bound_symbols[-1].args[1][0], *fwd_trace.bound_symbols[-1].args[1][1])
     old_saved_for_backward_bw = []
     for bsym in bwd_trace.bound_symbols:
         if bsym.sym.id == prims.PrimIDs.UNPACK_SEQUENCE:
             flattened_args = tree_flatten(bwd_trace.args[1])[0]
             proxy_names = {y.name for y in flattened_args if isinstance(y, Proxy)}
             if all(
                 not isinstance(out, CollectionProxy) and out.name not in proxy_names
-                for out in bsym.flat_outs if out is not None
+                for out in bsym.flat_outs
+                if out is not None
             ):
                 old_saved_for_backward_bw += bsym.flat_outs
     assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
     new_required_for_bakward_fw_to_bw_map = {
-        x.name: y for x, y in zip(old_saved_for_backward_fw, old_saved_for_backward_bw)
-        if x is not None
+        x.name: y for x, y in zip(old_saved_for_backward_fw, old_saved_for_backward_bw) if x is not None
     }
     new_required_for_bakward_fw_to_bw_map_mirror = {
-        y.name: x for x, y in zip(old_saved_for_backward_fw, old_saved_for_backward_bw)
-        if x is not None
+        y.name: x for x, y in zip(old_saved_for_backward_fw, old_saved_for_backward_bw) if x is not None
     }
     all_recomputable_proxies = all_recomputable_proxies.union(
         OrderedSet(
-            variableify(new_required_for_bakward_fw_to_bw_map[unvariableify(a).name])
-            if unvariableify(a).name in new_required_for_bakward_fw_to_bw_map else a
+            (
+                variableify(new_required_for_bakward_fw_to_bw_map[unvariableify(a).name])
+                if unvariableify(a).name in new_required_for_bakward_fw_to_bw_map
+                else a
+            )
             for a in all_recomputable_proxies
         )
     )
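
The notable rewrite at the end of this hunk is Black's treatment of a conditional expression used as a generator element: it gets its own enclosing parentheses so the if/else arms can stack vertically. A small self-contained sketch of the same swap-or-pass-through remap, using a plain set and dict in place of thunder's OrderedSet and the variableify/unvariableify helpers (all names here are hypothetical):

# Sketch: proxies that have a renamed counterpart in the fw->bw map are
# swapped for that counterpart; everything else passes through unchanged.
fw_to_bw_map = {"t12": "t3", "t15": "t7"}  # hypothetical renames
recomputable = {"t12", "t15", "t20"}

recomputable = recomputable.union(
    # the parenthesized conditional, stacked the way Black formats it
    (
        fw_to_bw_map[a]
        if a in fw_to_bw_map
        else a
    )
    for a in recomputable
)
print(sorted(recomputable))  # ['t12', 't15', 't20', 't3', 't7']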

thunder/tests/test_transforms.py

Lines changed: 36 additions & 52 deletions
@@ -670,9 +670,7 @@ def transform_module(self, model: thunder.ThunderModule):
             }
             model._overrides_buffers[n] = qb
 
-        def transform_traces_pre_prologue(
-            self, prologue_trace, computation_trace, epilogue_trace, **kwargs
-        ):
+        def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
             tm = self.thunder_module
             from thunder.core.trace import tracectx
 
@@ -681,19 +679,15 @@ def transform_traces_pre_prologue(
             prologue_proxy_map = {
                 get_param_bsym.output.name: dict(
                     shape=self.cast_states[model_weight_name]["qb.shape"],
-                    dtype=thunder.dtypes.to_dtype(
-                        self.cast_states[model_weight_name]["qb.dtype"]
-                    ),
+                    dtype=thunder.dtypes.to_dtype(self.cast_states[model_weight_name]["qb.dtype"]),
                 )
                 for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
                 if model_weight_name in self.cast_states
             }
 
             # here we switch the prologue_trace to a copy with new metadata
-            prologue_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    prologue_trace, prologue_proxy_map
-                )
+            prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                prologue_trace, prologue_proxy_map
             )
 
             checks = thunder.transforms.utils.get_checks(prologue_trace)
@@ -714,79 +708,58 @@ def transform_traces_pre_prologue(
                     shape=psym.shape,
                     dtype=psym.dtype,
                 )
-                for psym, csym in zip(
-                    prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args
-                )
+                for psym, csym in zip(prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args)
                 if psym.shape != csym.shape or psym.dtype != csym.dtype
             }
 
-            new_computation_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    computation_trace, computation_proxy_map
-                )
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                computation_trace, computation_proxy_map
             )
 
-            producers, consumers = thunder.core.utils.producers_and_consumers(
-                new_computation_trace
-            )
+            producers, consumers = thunder.core.utils.producers_and_consumers(new_computation_trace)
 
             bound_symbols = new_computation_trace.bound_symbols
             new_computation_trace.bound_symbols = []
 
-            new_computation_trace._siginfo.args = [
-                (a.name, None) for a in new_computation_trace.args
-            ]
+            new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
 
             computation_proxy_map = {}
             new_bound_symbols = []
             for bsym in bound_symbols:
-                if (
-                    bsym.sym == thunder.torch.to
-                    and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial
-                ):
+                if bsym.sym == thunder.torch.to and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial:
                     inp = bsym.args[0]
                     args = (inp, inp.dtype, *bsym.args[2:])
-                    computation_proxy_map[bsym.output.name] = dict(
-                        shape=inp.shape, dtype=inp.dtype
-                    )
+                    computation_proxy_map[bsym.output.name] = dict(shape=inp.shape, dtype=inp.dtype)
                     assert (
-                        len(bsym.subsymbols) == 1
-                        and bsym.subsymbols[0].sym
-                        == thunder.core.prims.convert_element_type
+                        len(bsym.subsymbols) == 1 and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
                     )
                     subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
-                    new_bound_symbols.append(
-                        bsym.from_bsym(args=args, subsymbols=subsymbols)
-                    )
+                    new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
                 else:
                     new_bound_symbols.append(bsym.from_bsym())
 
             new_computation_trace.bound_symbols = new_bound_symbols
 
-            new_computation_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    new_computation_trace, computation_proxy_map
-                )
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                new_computation_trace, computation_proxy_map
             )
 
-            new_computation_trace.set_provenance(
-                thunder.core.trace.TraceProvenance("Dtype Convert")
-            )
+            new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("Dtype Convert"))
             return prologue_trace, new_computation_trace, epilogue_trace
 
     class cast(nn.Module):
         def __init__(
             self,
-            k_shape: Tuple[int, int, int, int],
-            v_shape: Tuple[int, int, int, int],
-            device: Optional[torch.device] = None,
-            dtype: Optional[torch.dtype] = None,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
         ) -> None:
             super().__init__()
             self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
             self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
 
-        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             # move the buffer to the activation dtype for when AMP is used
             self.k = self.k.to(k.dtype)
             self.v = self.v.to(v.dtype)
@@ -795,23 +768,34 @@ def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch
 
     # BUG: issue: 1637
     class ParentModule(nn.Module):
-        def __init__(self, k_shape: Tuple[int, int, int, int], v_shape: Tuple[int, int, int, int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        def __init__(
+            self,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
+        ):
             super().__init__()
             self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)
 
-        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             return self.cast_module(k, v)
 
     with torch.device("cpu"):
         k_shape = (2, 3, 4, 5)
         v_shape = (2, 3, 4, 5)
         device = torch.device("cpu")
         dtype = torch.float32
-        model = (ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False))
+        model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
 
         k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
         v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
-        cast_jit = thunder.jit(model, transforms=[CastBuffers(),])
+        cast_jit = thunder.jit(
+            model,
+            transforms=[
+                CastBuffers(),
+            ],
+        )
         output_k, output_v = cast_jit(k, v)
 
     def check_dtypes(bsym):
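
Beyond the line wrapping, this file's diff rewrites Tuple[...] to the built-in tuple[...] and Optional[X] to X | None, the rewrites a pyupgrade-style hook produces (an inference from the diff; the repo's hook list is not shown in this commit). Built-in generics need Python 3.9 and PEP 604 unions need 3.10 when annotations are evaluated at runtime; from __future__ import annotations makes both parse on older 3.x. A minimal sketch of the target style:

# Hypothetical module written in the annotation style this commit moves to.
from __future__ import annotations  # lets tuple[...] / X | None parse before 3.9/3.10

import torch
from torch import nn


class KVCache(nn.Module):
    def __init__(
        self,
        k_shape: tuple[int, int, int, int],  # was: Tuple[int, int, int, int]
        device: torch.device | None = None,  # was: Optional[torch.device]
        dtype: torch.dtype | None = None,    # was: Optional[torch.dtype]
    ) -> None:
        super().__init__()
        self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype))


cache = KVCache((2, 3, 4, 5), dtype=torch.float32)
print(cache.k.shape)  # torch.Size([2, 3, 4, 5])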
