@@ -670,9 +670,7 @@ def transform_module(self, model: thunder.ThunderModule):
                 }
                 model._overrides_buffers[n] = qb
 
-        def transform_traces_pre_prologue(
-            self, prologue_trace, computation_trace, epilogue_trace, **kwargs
-        ):
+        def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
             tm = self.thunder_module
             from thunder.core.trace import tracectx
 
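For readers skimming the diff: `transform_traces_pre_prologue` is the hook a Thunder transform implements to rewrite the prologue, computation, and epilogue traces after tracing. A minimal no-op sketch of the hook's shape, assuming the `Transform` base class lives at `thunder.core.transform_common` (true for recent lightning-thunder versions; this sketch is not part of the commit):

```python
# Minimal no-op sketch, not this commit's code. Assumes
# thunder.core.transform_common.Transform is available.
from thunder.core.transform_common import Transform


class NoopTransform(Transform):
    def transform_module(self, model):
        # called once with the ThunderModule; a real transform would swap
        # parameters/buffers here (cf. model._overrides_buffers above)
        pass

    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        # must hand back all three traces; a real transform returns rewritten copies
        return prologue_trace, computation_trace, epilogue_trace
```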
@@ -681,19 +679,15 @@ def transform_traces_pre_prologue(
             prologue_proxy_map = {
                 get_param_bsym.output.name: dict(
                     shape=self.cast_states[model_weight_name]["qb.shape"],
-                    dtype=thunder.dtypes.to_dtype(
-                        self.cast_states[model_weight_name]["qb.dtype"]
-                    ),
+                    dtype=thunder.dtypes.to_dtype(self.cast_states[model_weight_name]["qb.dtype"]),
                 )
                 for model_weight_name, (check_bsym, get_param_bsym) in checks.items()
                 if model_weight_name in self.cast_states
             }
 
             # here we switch the prologue_trace to a copy with new metadata
-            prologue_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    prologue_trace, prologue_proxy_map
-                )
+            prologue_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                prologue_trace, prologue_proxy_map
             )
 
             checks = thunder.transforms.utils.get_checks(prologue_trace)
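The `prologue_proxy_map` built above keys on proxy names and supplies replacement `shape`/`dtype` metadata; `thunder.dtypes.to_dtype` bridges a `torch.dtype` into Thunder's own dtype object so the metadata has the form the trace expects. A hedged one-liner illustrating the bridging (assumes lightning-thunder is installed; not part of the commit):

```python
# Illustration only: thunder.dtypes.to_dtype maps torch dtypes to Thunder dtypes.
import torch
import thunder

assert thunder.dtypes.to_dtype(torch.bfloat16) == thunder.dtypes.bfloat16
```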
@@ -714,79 +708,58 @@ def transform_traces_pre_prologue(
                     shape=psym.shape,
                     dtype=psym.dtype,
                 )
-                for psym, csym in zip(
-                    prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args
-                )
+                for psym, csym in zip(prologue_trace.bound_symbols[-1].args[0][0], computation_trace.args)
                 if psym.shape != csym.shape or psym.dtype != csym.dtype
             }
 
-            new_computation_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    computation_trace, computation_proxy_map
-                )
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                computation_trace, computation_proxy_map
             )
 
-            producers, consumers = thunder.core.utils.producers_and_consumers(
-                new_computation_trace
-            )
+            producers, consumers = thunder.core.utils.producers_and_consumers(new_computation_trace)
 
             bound_symbols = new_computation_trace.bound_symbols
             new_computation_trace.bound_symbols = []
 
-            new_computation_trace._siginfo.args = [
-                (a.name, None) for a in new_computation_trace.args
-            ]
+            new_computation_trace._siginfo.args = [(a.name, None) for a in new_computation_trace.args]
 
             computation_proxy_map = {}
             new_bound_symbols = []
             for bsym in bound_symbols:
-                if (
-                    bsym.sym == thunder.torch.to
-                    and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial
-                ):
+                if bsym.sym == thunder.torch.to and producers[bsym.args[0]].sym == thunder.core.prims.unpack_trivial:
                     inp = bsym.args[0]
                     args = (inp, inp.dtype, *bsym.args[2:])
-                    computation_proxy_map[bsym.output.name] = dict(
-                        shape=inp.shape, dtype=inp.dtype
-                    )
+                    computation_proxy_map[bsym.output.name] = dict(shape=inp.shape, dtype=inp.dtype)
                     assert (
-                        len(bsym.subsymbols) == 1
-                        and bsym.subsymbols[0].sym
-                        == thunder.core.prims.convert_element_type
+                        len(bsym.subsymbols) == 1 and bsym.subsymbols[0].sym == thunder.core.prims.convert_element_type
                     )
                     subsymbols = [bsym.subsymbols[0].from_bsym(args=(inp, inp.dtype))]
-                    new_bound_symbols.append(
-                        bsym.from_bsym(args=args, subsymbols=subsymbols)
-                    )
+                    new_bound_symbols.append(bsym.from_bsym(args=args, subsymbols=subsymbols))
                 else:
                     new_bound_symbols.append(bsym.from_bsym())
 
             new_computation_trace.bound_symbols = new_bound_symbols
 
-            new_computation_trace = (
-                thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
-                    new_computation_trace, computation_proxy_map
-                )
+            new_computation_trace = thunder.transforms.quantization.trace_with_replaced_proxy_metadata(
+                new_computation_trace, computation_proxy_map
             )
 
-            new_computation_trace.set_provenance(
-                thunder.core.trace.TraceProvenance("Dtype Convert")
-            )
+            new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("Dtype Convert"))
             return prologue_trace, new_computation_trace, epilogue_trace
 
     class cast(nn.Module):
         def __init__(
             self,
-            k_shape: Tuple[int, int, int, int],
-            v_shape: Tuple[int, int, int, int],
-            device: Optional[torch.device] = None,
-            dtype: Optional[torch.dtype] = None,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
         ) -> None:
             super().__init__()
             self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False)
             self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False)
 
-        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             # move the buffer to the activation dtype for when AMP is used
             self.k = self.k.to(k.dtype)
             self.v = self.v.to(v.dtype)
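The `cast` module above reassigns its non-persistent buffers to the incoming activation dtype, so the module follows the autocast dtype under AMP. That behavior is plain PyTorch and can be checked without Thunder; a self-contained sketch with an illustrative name (`BufferCast` is not from this commit):

```python
# Standalone PyTorch check of the buffer-casting idea; illustrative, not this
# commit's code. The non-persistent buffer starts in float32 and follows the
# activation dtype after the first forward call.
import torch
import torch.nn as nn


class BufferCast(nn.Module):
    def __init__(self, shape: tuple[int, ...]):
        super().__init__()
        self.register_buffer("k", torch.zeros(shape), persistent=False)

    def forward(self, k: torch.Tensor) -> torch.Tensor:
        self.k = self.k.to(k.dtype)  # buffer follows the activation dtype
        return self.k + k


m = BufferCast((2, 3))
out = m(torch.randn(2, 3, dtype=torch.half))
assert m.k.dtype == torch.half and out.dtype == torch.half
```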
@@ -795,23 +768,34 @@ def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # BUG: issue: 1637
     class ParentModule(nn.Module):
-        def __init__(self, k_shape: Tuple[int, int, int, int], v_shape: Tuple[int, int, int, int], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        def __init__(
+            self,
+            k_shape: tuple[int, int, int, int],
+            v_shape: tuple[int, int, int, int],
+            device: torch.device | None = None,
+            dtype: torch.dtype | None = None,
+        ):
             super().__init__()
             self.cast_module = cast(k_shape, v_shape, device=device, dtype=dtype)
 
-        def forward(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        def forward(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
             return self.cast_module(k, v)
 
     with torch.device("cpu"):
         k_shape = (2, 3, 4, 5)
         v_shape = (2, 3, 4, 5)
         device = torch.device("cpu")
         dtype = torch.float32
-        model = (ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False))
+        model = ParentModule(k_shape, v_shape, device=device, dtype=dtype).eval().requires_grad_(False)
 
         k = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
         v = torch.randn(2, 3, 4, 5, device=device, dtype=torch.half)
-        cast_jit = thunder.jit(model, transforms=[CastBuffers(),])
+        cast_jit = thunder.jit(
+            model,
+            transforms=[
+                CastBuffers(),
+            ],
+        )
         output_k, output_v = cast_jit(k, v)
 
         def check_dtypes(bsym):
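The diff cuts off at `check_dtypes`; its body is not shown, and the truncation is left as-is. Purely as a hypothetical reconstruction (not this commit's code), a checker of that shape would typically recurse through each bound symbol of the final computation trace and assert that every tensor proxy carries the activation dtype:

```python
# Hypothetical reconstruction only -- the commit's check_dtypes body is not
# visible in the diff. Assumes thunder.core.proxies.TensorProxy and
# thunder.last_traces, both present in recent lightning-thunder releases.
from thunder.core.proxies import TensorProxy


def check_dtypes(bsym):
    # every tensor the symbol touches should be in the activation dtype (fp16 here)
    for a in bsym.flat_args:
        if isinstance(a, TensorProxy):
            assert a.dtype == thunder.dtypes.float16
    # recurse into the symbol's decomposition
    for sub in bsym.subsymbols:
        check_dtypes(sub)


for bsym in thunder.last_traces(cast_jit)[-1].bound_symbols:
    check_dtypes(bsym)
```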