@@ -458,15 +458,10 @@ def get_computation_and_inputs(*args, **kwargs):
458
458
_vanilla_args ,
459
459
) = cache_entry
460
460
try :
461
- cs .last_prologue_execution_start = time .perf_counter_ns ()
462
461
inps , pro_to_epi = pro (* args , ** kwargs )
463
- cs .last_prologue_execution_stop = time .perf_counter_ns ()
464
462
except Exception as _ :
465
463
continue
466
464
467
- cs .last_trace_host_tracing_start = time .perf_counter_ns ()
468
- cs .last_trace_host_tracing_stop = time .perf_counter_ns ()
469
-
470
465
# Updates cache statistics
471
466
cs .cache_hits += 1
472
467
cs .last_traces = comp_traces
@@ -495,12 +490,7 @@ def get_computation_and_inputs(*args, **kwargs):
495
490
backward_traces ,
496
491
) = cache_entry
497
492
498
- cs .last_prologue_execution_start = time .perf_counter_ns ()
499
493
inps , pro_to_epi = pro (* args , ** kwargs )
500
- cs .last_prologue_execution_stop = time .perf_counter_ns ()
501
-
502
- cs .last_trace_host_tracing_start = time .perf_counter_ns ()
503
- cs .last_trace_host_tracing_stop = time .perf_counter_ns ()
504
494
505
495
# Updates cache statistics
506
496
cs .cache_hits += 1
@@ -622,6 +612,7 @@ def get_computation_and_inputs(*args, **kwargs):
622
612
)
623
613
prologue_trc = prologue_traces [- 1 ]
624
614
pro = prologue_trc .python_callable (include_decorators = False )
615
+ pro = prologue_execution_timer (pro )
625
616
626
617
if epilogue_trc is not None :
627
618
epilogue = epilogue_trc .python_callable ()
@@ -637,9 +628,7 @@ def get_computation_and_inputs(*args, **kwargs):
637
628
cs .last_interpreter_log = last_interpreter_log
638
629
cs .last_interpreted_instructions = (i for i in last_interpreter_log if isinstance (i , dis .Instruction ))
639
630
640
- cs .last_prologue_execution_start = time .perf_counter_ns ()
641
631
inps , pro_to_epi = pro (* args , ** kwargs )
642
- cs .last_prologue_execution_stop = time .perf_counter_ns ()
643
632
644
633
computation_trc = dce (computation_trc )
645
634
computation_traces .append (computation_trc )
@@ -729,23 +718,55 @@ def get_computation_and_inputs(*args, **kwargs):
729
718
730
719
return cache_entry , inps , pro_to_epi
731
720
732
- cd .get_computation_and_inputs = get_computation_and_inputs
721
def host_execution_timer(fn):
    """Decorator: bracket a host-side execution of ``fn`` with timestamps.

    Writes ``cs.last_trace_host_execution_start`` before the call and
    ``cs.last_trace_host_execution_stop`` after it (both nanosecond
    ``time.perf_counter_ns`` readings).  The stop timestamp is recorded
    in a ``finally`` block, so it is written even when ``fn`` raises.
    ``cs`` is the compile-stats object captured from the enclosing scope.
    """

    def timed(*args, **kwargs):
        cs.last_trace_host_execution_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            cs.last_trace_host_execution_stop = time.perf_counter_ns()

    return timed
739
730
740
- # Updats call statistics
741
- cs .last_trace_host_start = time .perf_counter_ns ()
742
- cs .calls += 1
731
def prologue_execution_timer(fn):
    """Decorator: bracket a prologue execution of ``fn`` with timestamps.

    Records ``cs.last_prologue_execution_start`` / ``..._stop`` (nanosecond
    ``time.perf_counter_ns`` readings) around the call; the stop timestamp
    is written in a ``finally`` block, so it survives exceptions raised by
    ``fn``.  ``cs`` is the compile-stats object from the enclosing scope.
    """

    def timed(*args, **kwargs):
        cs.last_prologue_execution_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            cs.last_prologue_execution_stop = time.perf_counter_ns()

    return timed
740
+
741
def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
    """Wrap a ``get_computation_and_inputs``-style callable so that each of
    ``decorators`` is applied to the computation function of the cache entry
    it returns (first decorator innermost, last outermost).

    The wrapped callable preserves the original's ``(cache_entry, inps,
    pro_to_epi)`` return shape; when ``decorators`` is empty the cache entry
    is returned untouched.
    """

    def wrapper(*args, **kwargs):
        cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
        if decorators:
            comp_fn = cache_entry.computation_fn
            for deco in decorators:
                comp_fn = deco(comp_fn)
            # NamedTuple-style cache entries are immutable; swap via _replace.
            cache_entry = cache_entry._replace(computation_fn=comp_fn)
        return cache_entry, inps, pro_to_epi

    return wrapper
752
+
753
# Time host-side executions of the computation function, then expose the
# decorated lookup on the compile data so other components reuse the same
# (timed) entry point.
get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer)
cd.get_computation_and_inputs = get_computation_and_inputs
755
+
756
def update_call_statistics(fn):
    """Decorator: update call statistics around an invocation of ``fn``.

    Increments ``cs.calls`` and records ``cs.last_trace_host_start`` /
    ``cs.last_trace_host_stop`` (nanosecond ``time.perf_counter_ns``
    readings); the stop timestamp is written in a ``finally`` block so it
    is recorded even when ``fn`` raises.  ``cs`` is the compile-stats
    object captured from the enclosing scope.
    """

    def counted(*args, **kwargs):
        cs.calls += 1
        cs.last_trace_host_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            cs.last_trace_host_stop = time.perf_counter_ns()

    return counted
766
+
767
+ def check_storage_aliases (cache_entry , args ):
747
768
if cache_entry .vanilla_tensor_args :
748
- if alias_tensor_indices_str := _alias_tensor_of_args_kwargs (* inps ):
769
+ if alias_tensor_indices_str := _alias_tensor_of_args_kwargs (* args ):
749
770
alias_tensor_indices = alias_tensor_indices_str
750
771
alias_tensor_indices = {int (i ) for i in alias_tensor_indices_str .split ("," )}
751
772
vanilla_tensor_args = cache_entry .vanilla_tensor_args
@@ -755,13 +776,12 @@ def fn_(*args, **kwargs) -> Any:
755
776
NotImplementedError ,
756
777
)
757
778
758
- result = cache_entry .computation_fn (* inps )
759
-
779
+ def maybe_connect_to_autograd (cache_entry , result ):
760
780
if cache_entry .backward_fn :
761
- # Run the compiled forward function
781
+ # If the backward function is available, we need to connect the
782
+ # resulting tensors to PyTorch's Autograd graph using the
783
+ # ThunderFunction (which is a torch.autograd.Function subclass)
762
784
data_for_autograd , (saved_tensors , saved_other ) = result
763
-
764
- # Connect produced tensors with PyTorch's autograd graph
765
785
ThunderFunction .apply (
766
786
cache_entry .return_none_instead_of_grads ,
767
787
cache_entry .backward_fn ,
@@ -772,17 +792,31 @@ def fn_(*args, **kwargs) -> Any:
772
792
)
773
793
result = data_for_autograd ["output" ]
774
794
795
+ return result
796
+
797
def maybe_call_epilogue(cache_entry, result, pro_to_epi):
    """Run the cache entry's epilogue, if it has one.

    When ``cache_entry.epilogue_fn`` is set, ``result`` is expected to be a
    ``(user_result, comp_to_epi)`` pair; the epilogue is called with the
    prologue outputs followed by the computation outputs, and the user
    result alone is returned.  Without an epilogue, ``result`` passes
    through unchanged.
    """
    if not cache_entry.epilogue_fn:
        return result
    result, comp_to_epi = result
    cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)
    return result
781
803
782
- cs .last_executed = cache_entry .computation_fn
783
- cs .last_trace_cache_stop = time .perf_counter_ns ()
784
- cs .last_trace_host_stop = time .perf_counter_ns ()
804
@wraps(fn)
@update_call_statistics
def fn_(*args, **kwargs) -> Any:
    """Entry point for the jitted callable: resolve the cached computation
    for these arguments, execute it, and post-process the result."""
    # Re-entering the jit while already tracing is not supported; warn and
    # fall back to the plain Python callable.
    if is_tracing():
        _recursive_jit_call_warning()
        return fn(*args, **kwargs)

    # Cache lookup (or compilation) of prologue + computation for this call.
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)

    check_storage_aliases(cache_entry, inps)

    result = cache_entry.computation_fn(*inps)
    result = maybe_connect_to_autograd(cache_entry, result)
    result = maybe_call_epilogue(cache_entry, result, pro_to_epi)

    # Remember the executed computation for introspection/statistics.
    cs.last_computation = cache_entry.computation_fn
    return result
787
821
788
822
if isinstance (fn , pytorch .nn .Module ):
0 commit comments