Skip to content

Commit 45519f0

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent df6f835 commit 45519f0

File tree

2 files changed: 43 additions and 55 deletions

2 files changed: 43 additions and 55 deletions

thunder/core/rematerialization.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ def add_to_swapmap(p):
687687
)
688688
new_fw_trace.bound_symbols.append(replace(fw_trace.bound_symbols[-1], args=fw_trace.bound_symbols[-1].args))
689689

690-
# outputs required for backward may have different names between forward and backward.
690+
# outputs required for backward may have different names between forward and backward.
691691
# Rematerialisation may remove some outs from the forward.
692692
old_saved_for_backward_fw = fw_trace.bound_symbols[-1].args[1][0]
693693
old_saved_for_backward_bw = []
@@ -696,8 +696,12 @@ def add_to_swapmap(p):
696696
old_saved_for_backward_bw = bsym.flat_outs
697697
break
698698
assert len(old_saved_for_backward_fw) == len(old_saved_for_backward_bw)
699-
new_required_for_bakward_fw_to_bw_map = {x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None}
700-
new_required_for_backward = tuple([new_required_for_bakward_fw_to_bw_map[a.name] for a in new_required_for_backward])
699+
new_required_for_bakward_fw_to_bw_map = {
700+
x.name: y for x, y in zip(old_saved_for_backward_bw, old_saved_for_backward_fw) if x is not None
701+
}
702+
new_required_for_backward = tuple(
703+
[new_required_for_bakward_fw_to_bw_map[a.name] for a in new_required_for_backward]
704+
)
701705
_update_forward_with_new_saved_for_backward(new_fw_trace, new_required_for_backward)
702706

703707
# prims.python_return was updated and now DCE can remove the unused

thunder/tests/test_transforms.py

Lines changed: 36 additions & 52 deletions
Original file line number | Diff line number | Diff line change
@@ -670,9 +670,7 @@ def transform_module(self, model: thunder.ThunderModule):
670670
}
671671
model._overrides_buffers[n] = qb
672672

673-
def transform_traces_pre_prologue(
674-
self, prologue_trace, computation_trace, epilogue_trace, **kwargs
675-
):
673+
def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
676674
tm = self.thunder_module
677675
from thunder.core.trace import tracectx
678676

@@ -681,19 +679,15 @@ def transform_traces_pre_prologue(
681679
prologue_proxy_map = {
682680
get_param_bsym.output.name: dict(
683681
shape=self.cast_states[model_weight_name]["qb.shape"],
684-
dtype=thunder.dtypes.to_dtype(
685-
self.cast_states[model_weight_name]["qb.dtype"]
686-
),
682+
dtype=thunder.dtypes.to_dtype(self.cast_states[model_weight_name]["qb.dtype"]),
687683
)
688684
for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
689685
if model_weight_name in self.cast_states
690686
}
691687

692688
# here we switch the prologue_trace to a copy with new metadata
693-
prologue_trace = (
694-
thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
695-
prologue_trace, prologue_proxy_map
696-
)
689+
prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
690+
prologue_trace, prologue_proxy_map
697691
)
698692

699693
checks = thunder.transforms.utils.get_checks(prologue_trace)
@@ -714,79 +708,58 @@ def transform_traces_pre_prologue(
714708
shape=psym.shape,
715709
dtype=psym.dtype,
716710
)
717-
for psym, csym in zip(
718-
prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args
719-
)
711+
for psym, csym in zip(prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args)
720712
if psym.shape != csym.shape or psym.dtype != csym.dtype
721713
}
722714

723-
new_computation_trace = (
724-
thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
725-
computation_trace, computation_proxy_map
726-
)
715+
new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
716+
computation_trace, computation_proxy_map
727717
)
728718

729-
producers, consumers = thunder.core.utils.producers_and_consumers(
730-
new_computation_trace
731-
)
719+
producers, consumers = thunder.core.utils.producers_and_consumers(new_computation_trace)
732720

733721
bound_symbols = new_computation_trace.bound_symbols
734722
new_computation_trace.bound_symbols = []
735723

736-
new_computation_trace._siginfo.args = [
737-
(a.name, None) for a in new_computation_trace.args
738-
]
724+
new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
739725

740726
computation_proxy_map = {}
741727
new_bound_symbols = []
742728
for bsym in bound_symbols:
743-
if (
744-
bsym.sym == thunder.torch.to
745-
and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial
746-
):
729+
if bsym.sym == thunder.torch.to and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial:
747730
inp = bsym.args[0]
748731
args = (inp, inp.dtype, *bsym.args[2:])
749-
computation_proxy_map[bsym.output.name] = dict(
750-
shape=inp.shape, dtype=inp.dtype
751-
)
732+
computation_proxy_map[bsym.output.name] = dict(shape=inp.shape, dtype=inp.dtype)
752733
assert (
753-
len(bsym.subsymbols) == 1
754-
and bsym.subsymbols[0].sym
755-
== thunder.core.prims.convert_element_type
734+
len(bsym.subsymbols) == 1 and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
756735
)
757736
subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
758-
new_bound_symbols.append(
759-
bsym.from_bsym(args=args, subsymbols=subsymbols)
760-
)
737+
new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
761738
else:
762739
new_bound_symbols.append(bsym.from_bsym())
763740

764741
new_computation_trace.bound_symbols = new_bound_symbols
765742

766-
new_computation_trace = (
767-
thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
768-
new_computation_trace, computation_proxy_map
769-
)
743+
new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
744+
new_computation_trace, computation_proxy_map
770745
)
771746

772-
new_computation_trace.set_provenance(
773-
thunder.core.trace.TraceProvenance("Dtype Convert")
774-
)
747+
new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("Dtype Convert"))
775748
return prologue_trace, new_computation_trace, epilogue_trace
776749

777750
class cast(nn.Module):
778751
def __init__(
779752
self,
780-
k_shape: Tuple[int, int, int, int],
781-
v_shape: Tuple[int, int, int, int],
782-
device: Optional[torch.device] = None,
783-
dtype: Optional[torch.dtype] = None,
753+
k_shape: tuple[int, int, int, int],
754+
v_shape: tuple[int, int, int, int],
755+
device: torch.device | None = None,
756+
dtype: torch.dtype | None = None,
784757
) -> None:
785758
super().__init__()
786759
self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
787760
self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
788761

789-
def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
762+
def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
790763
# move the buffer to the activation dtype for when AMP is used
791764
self.k = self.k.to(k.dtype)
792765
self.v = self.v.to(v.dtype)
@@ -795,23 +768,34 @@ def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch
795768

796769
# BUG: issue: 1637
797770
class ParentModule(nn.Module):
798-
def __init__(self, k_shape: Tuple[int, int, int, int], v_shape: Tuple[int, int, int, int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
771+
def __init__(
772+
self,
773+
k_shape: tuple[int, int, int, int],
774+
v_shape: tuple[int, int, int, int],
775+
device: torch.device | None = None,
776+
dtype: torch.dtype | None = None,
777+
):
799778
super().__init__()
800779
self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)
801780

802-
def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
781+
def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
803782
return self.cast_module(k, v)
804783

805784
with torch.device("cpu"):
806785
k_shape = (2, 3, 4, 5)
807786
v_shape = (2, 3, 4, 5)
808787
device = torch.device("cpu")
809788
dtype = torch.float32
810-
model = (ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False))
789+
model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
811790

812791
k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
813792
v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
814-
cast_jit = thunder.jit(model, transforms=[CastBuffers(),])
793+
cast_jit = thunder.jit(
794+
model,
795+
transforms=[
796+
CastBuffers(),
797+
],
798+
)
815799
output_k, output_v = cast_jit(k, v)
816800

817801
def check_dtypes(bsym):

0 commit comments

Comments
 (0)