diff --git a/thunder/__init__.py b/thunder/__init__.py index d2f596a90b..72eb5c76a7 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -458,15 +458,10 @@ def get_computation_and_inputs(*args, **kwargs): _vanilla_args, ) = cache_entry try: - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() except Exception as _: continue - cs.last_trace_host_tracing_start = time.perf_counter_ns() - cs.last_trace_host_tracing_stop = time.perf_counter_ns() - # Updates cache statistics cs.cache_hits += 1 cs.last_traces = comp_traces @@ -495,12 +490,7 @@ def get_computation_and_inputs(*args, **kwargs): backward_traces, ) = cache_entry - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() - - cs.last_trace_host_tracing_start = time.perf_counter_ns() - cs.last_trace_host_tracing_stop = time.perf_counter_ns() # Updates cache statistics cs.cache_hits += 1 @@ -622,6 +612,7 @@ def get_computation_and_inputs(*args, **kwargs): ) prologue_trc = prologue_traces[-1] pro = prologue_trc.python_callable(include_decorators=False) + pro = prologue_execution_timer(pro) if epilogue_trc is not None: epilogue = epilogue_trc.python_callable() @@ -637,9 +628,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_interpreter_log = last_interpreter_log cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction)) - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -729,23 +718,55 @@ def get_computation_and_inputs(*args, **kwargs): return cache_entry, inps, pro_to_epi - cd.get_computation_and_inputs = get_computation_and_inputs + def host_execution_timer(fn): + def wrapped(*args, **kwargs): + cs.last_trace_host_execution_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_trace_host_execution_stop = time.perf_counter_ns() - @wraps(fn) - def fn_(*args, **kwargs) -> Any: - if is_tracing(): - _recursive_jit_call_warning() - return fn(*args, **kwargs) + return wrapped - # Updats call statistics - cs.last_trace_host_start = time.perf_counter_ns() - cs.calls += 1 + def prologue_execution_timer(fn): + def wrapped(*args, **kwargs): + cs.last_prologue_execution_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_prologue_execution_stop = time.perf_counter_ns() - cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) - cs.last_trace_host_execution_start = time.perf_counter_ns() + return wrapped + + def decorate_computation_function(get_computation_and_inputs_fn, *decorators): + def wrapped(*args, **kwargs): + cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) + decorated_computation_fn = cache_entry.computation_fn + for decorator in decorators: + decorated_computation_fn = decorator(decorated_computation_fn) + if decorators: + cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn) + return cache_entry, inps, pro_to_epi + + return wrapped + + get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer) + cd.get_computation_and_inputs = get_computation_and_inputs + + def update_call_statistics(fn): + def wrapped(*args, **kwargs): + cs.calls += 1 + cs.last_trace_host_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_trace_host_stop = time.perf_counter_ns() + return wrapped + + def check_storage_aliases(cache_entry, args): if cache_entry.vanilla_tensor_args: - if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps): + if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args): alias_tensor_indices = alias_tensor_indices_str alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} vanilla_tensor_args = cache_entry.vanilla_tensor_args @@ -755,13 +776,12 @@ def fn_(*args, **kwargs) -> Any: NotImplementedError, ) - result = cache_entry.computation_fn(*inps) - + def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: - # Run the compiled forward function + # If the backward function is available, we need to connect the + # resulting tensors to PyTorch's Autograd graph using the + # ThunderFunction (which is a torch.autograd.Function subclass) data_for_autograd, (saved_tensors, saved_other) = result - - # Connect produced tensors with PyTorch's autograd graph ThunderFunction.apply( cache_entry.return_none_instead_of_grads, cache_entry.backward_fn, @@ -772,17 +792,31 @@ def fn_(*args, **kwargs) -> Any: ) result = data_for_autograd["output"] + return result + + def maybe_call_epilogue(cache_entry, result, pro_to_epi): if cache_entry.epilogue_fn: result, comp_to_epi = result cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) - cs.last_trace_host_execution_stop = time.perf_counter_ns() - cs.last_computation_execution_stop = cs.last_trace_host_execution_stop + return result - cs.last_executed = cache_entry.computation_fn - cs.last_trace_cache_stop = time.perf_counter_ns() - cs.last_trace_host_stop = time.perf_counter_ns() + @wraps(fn) + @update_call_statistics + def fn_(*args, **kwargs) -> Any: + if is_tracing(): + _recursive_jit_call_warning() + return fn(*args, **kwargs) + + cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) + + check_storage_aliases(cache_entry, inps) + + result = cache_entry.computation_fn(*inps) + result = maybe_connect_to_autograd(cache_entry, result) + result = maybe_call_epilogue(cache_entry, result, pro_to_epi) + cs.last_computation = cache_entry.computation_fn return result if isinstance(fn, pytorch.nn.Module): diff --git a/thunder/common.py b/thunder/common.py index 9f0d0bd0a1..5fb8b2f2ce 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -64,7 +64,6 @@ # Holds statistics and caches for a compiled function -# TODO RC1 Update last_executed to last_computation # TODO RC1 Review how autograd traces are presented class CompileStats: """A class holding statistics and caches for a compiled function. @@ -76,7 +75,7 @@ class CompileStats: See :mod:`thunder` for more of such utility functions. Attributes: - last_executed: + last_computation (Callable): last_traces (Sequence[TraceCtx]): last_prologue (TraceCtx): last_prologue_traces (Sequence[TraceCtx]): @@ -107,7 +106,7 @@ class CompileStats: def __init__(self): # Callables and traces - self.last_executed = None + self.last_computation = None self.last_traces = None self.last_prologue = None self.last_prologue_traces = None