From 4d15278006b06d6658ca53d0a0baf4876d2c988e Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:12:11 +0300 Subject: [PATCH 01/15] Move last_trace_host and calls recording to a decorator --- thunder/__init__.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index e78839d68e..9c4e6a96ff 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -723,16 +723,26 @@ def get_computation_and_inputs(*args, **kwargs): cd.get_computation_and_inputs = get_computation_and_inputs + + def update_call_statistics(fn): + def wrapped(*args, **kwargs): + cs.calls += 1 + cs.last_trace_host_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_trace_host_stop = time.perf_counter_ns() + + return wrapped + + @wraps(fn) + @update_call_statistics def fn_(*args, **kwargs) -> Any: if is_tracing(): _recursive_jit_call_warning() return fn(*args, **kwargs) - # Updats call statistics - cs.last_trace_host_start = time.perf_counter_ns() - cs.calls += 1 - cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) cs.last_trace_host_execution_start = time.perf_counter_ns() @@ -773,7 +783,6 @@ def fn_(*args, **kwargs) -> Any: cs.last_executed = cache_entry.computation_fn cs.last_trace_cache_stop = time.perf_counter_ns() - cs.last_trace_host_stop = time.perf_counter_ns() return result From a36f14b833f11421e6ec77ed43b98ead68029b5b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:14:32 +0300 Subject: [PATCH 02/15] Remove last_trace_cache_stop there's no corresponding start --- thunder/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 9c4e6a96ff..aa964c13c3 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -782,7 +782,6 @@ def fn_(*args, **kwargs) -> Any: cs.last_computation_execution_stop = cs.last_trace_host_execution_stop cs.last_executed = cache_entry.computation_fn - cs.last_trace_cache_stop = time.perf_counter_ns() return result From 09034c1ce011ee49d20e2607fc7b9993c62a66fe Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:16:06 +0300 Subject: [PATCH 03/15] Rename last_executed -> last_computation --- thunder/__init__.py | 2 +- thunder/common.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index aa964c13c3..509923638a 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -781,7 +781,7 @@ def fn_(*args, **kwargs) -> Any: cs.last_trace_host_execution_stop = time.perf_counter_ns() cs.last_computation_execution_stop = cs.last_trace_host_execution_stop - cs.last_executed = cache_entry.computation_fn + cs.last_computation = cache_entry.computation_fn return result diff --git a/thunder/common.py b/thunder/common.py index 9f0d0bd0a1..5fb8b2f2ce 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -64,7 +64,6 @@ # Holds statistics and caches for a compiled function -# TODO RC1 Update last_executed to last_computation # TODO RC1 Review how autograd traces are presented class CompileStats: """A class holding statistics and caches for a compiled function. @@ -76,7 +75,7 @@ class CompileStats: See :mod:`thunder` for more of such utility functions. Attributes: - last_executed: + last_computation (Callable): last_traces (Sequence[TraceCtx]): last_prologue (TraceCtx): last_prologue_traces (Sequence[TraceCtx]): @@ -107,7 +106,7 @@ class CompileStats: def __init__(self): # Callables and traces - self.last_executed = None + self.last_computation = None self.last_traces = None self.last_prologue = None self.last_prologue_traces = None From 288e840b179de950abc3956c21666cf0f5b968ae Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:18:21 +0300 Subject: [PATCH 04/15] Remove last_computation_execution_stop there's no corresponding start --- thunder/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 509923638a..3c822c18ca 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -779,7 +779,6 @@ def fn_(*args, **kwargs) -> Any: cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) cs.last_trace_host_execution_stop = time.perf_counter_ns() - cs.last_computation_execution_stop = cs.last_trace_host_execution_stop cs.last_computation = cache_entry.computation_fn From 3495bfee862b6b45b2542d4851448a14b6ff0af0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:24:17 +0300 Subject: [PATCH 05/15] Move storage check for aliases to a separate function --- thunder/__init__.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 3c822c18ca..a7fa68e6c8 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -736,6 +736,19 @@ def wrapped(*args, **kwargs): return wrapped + def check_storage_aliases(cache_entry, args): + if cache_entry.vanilla_tensor_args: + if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args): + alias_tensor_indices = alias_tensor_indices_str + alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} + vanilla_tensor_args = cache_entry.vanilla_tensor_args + check( + not vanilla_tensor_args & alias_tensor_indices, + lambda: f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share their storage and some of them are modified in-place", + NotImplementedError, + ) + + @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -746,16 +759,7 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) cs.last_trace_host_execution_start = time.perf_counter_ns() - if cache_entry.vanilla_tensor_args: - if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps): - alias_tensor_indices = alias_tensor_indices_str - alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} - vanilla_tensor_args = cache_entry.vanilla_tensor_args - check( - not vanilla_tensor_args & alias_tensor_indices, - lambda: f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share their storage and some of them are modified in-place", - NotImplementedError, - ) + check_storage_aliases(cache_entry, inps) result = cache_entry.computation_fn(*inps) From 2b089d9210b59baa267fa725f41acad1048f7919 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:34:41 +0300 Subject: [PATCH 06/15] Move ThunderFunction.apply to a separate function --- thunder/__init__.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index a7fa68e6c8..653b6189fd 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -749,6 +749,25 @@ def check_storage_aliases(cache_entry, args): ) + def maybe_connect_to_autograd(cache_entry, result): + if cache_entry.backward_fn: + # If the backward function is available, we need to connect the + # resulting tensors to PyTorch's Autograd graph using the + # ThunderFunction (which is a torch.autograd.Function subclass) + data_for_autograd, (saved_tensors, saved_other) = result + ThunderFunction.apply( + cache_entry.return_none_instead_of_grads, + cache_entry.backward_fn, + saved_tensors, + saved_other, + data_for_autograd["flat_output"], + *data_for_autograd["flat_args"], + ) + result = data_for_autograd["output"] + + return result + + @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -763,20 +782,7 @@ def fn_(*args, **kwargs) -> Any: result = cache_entry.computation_fn(*inps) - if cache_entry.backward_fn: - # Run the compiled forward function - data_for_autograd, (saved_tensors, saved_other) = result - - # Connect produced tensors with PyTorch's autograd graph - ThunderFunction.apply( - cache_entry.return_none_instead_of_grads, - cache_entry.backward_fn, - saved_tensors, - saved_other, - data_for_autograd["flat_output"], - *data_for_autograd["flat_args"], - ) - result = data_for_autograd["output"] + result = maybe_connect_to_autograd(cache_entry, result) if cache_entry.epilogue_fn: result, comp_to_epi = result From 90f9f2a560cf1578f3e3e6ead5423f921da1906c Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:41:26 +0300 Subject: [PATCH 07/15] Move epilogue_fn invocation to a separate function --- thunder/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 653b6189fd..65a60d948f 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -768,6 +768,12 @@ def maybe_connect_to_autograd(cache_entry, result): return result + def maybe_call_epilogue(cache_entry, result, pro_to_epi): + if cache_entry.epilogue_fn: + result, comp_to_epi = result + cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) + + @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -784,9 +790,7 @@ def fn_(*args, **kwargs) -> Any: result = maybe_connect_to_autograd(cache_entry, result) - if cache_entry.epilogue_fn: - result, comp_to_epi = result - cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) + maybe_call_epilogue(cache_entry, result, pro_to_epi) cs.last_trace_host_execution_stop = time.perf_counter_ns() From 63a606431d9eff0a31510d75504fcec26e52164a Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:45:08 +0300 Subject: [PATCH 08/15] Remove unused last_trace_host_tracing measurements --- thunder/__init__.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 65a60d948f..9eea8ac4b2 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -456,9 +456,6 @@ def get_computation_and_inputs(*args, **kwargs): except Exception as _: continue - cs.last_trace_host_tracing_start = time.perf_counter_ns() - cs.last_trace_host_tracing_stop = time.perf_counter_ns() - # Updates cache statistics cs.cache_hits += 1 cs.last_traces = comp_traces @@ -491,9 +488,6 @@ def get_computation_and_inputs(*args, **kwargs): inps, pro_to_epi = pro(*args, **kwargs) cs.last_prologue_execution_stop = time.perf_counter_ns() - cs.last_trace_host_tracing_start = time.perf_counter_ns() - cs.last_trace_host_tracing_stop = time.perf_counter_ns() - # Updates cache statistics cs.cache_hits += 1 cs.last_traces = comp_traces From fabfc417912d9dcbbe08d979e4812116fe5904f4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:58:49 +0300 Subject: [PATCH 09/15] Move host_execution timer to a decorator to be applied on computation_fn --- thunder/__init__.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 9eea8ac4b2..e84149a8a6 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -715,6 +715,28 @@ def get_computation_and_inputs(*args, **kwargs): return cache_entry, inps, pro_to_epi + def host_execution_timer(fn): + def wrapped(*args, **kwargs): + cs.last_trace_host_execution_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_trace_host_execution_stop = time.perf_counter_ns() + + return wrapped + + + def decorate_computation_functions(get_computation_and_inputs_fn, *decorators): + def wrapped(*args, **kwargs): + cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) + for decorator in decorators: + cache_entry.computation_fn = decorator(cache_entry.computation_fn) + return cache_entry, inps, pro_to_epi + + return wrapped + + + get_computation_and_inputs = decorate_computation_functions(get_computation_and_inputs, host_execution_timer) cd.get_computation_and_inputs = get_computation_and_inputs @@ -776,7 +798,6 @@ def fn_(*args, **kwargs) -> Any: return fn(*args, **kwargs) cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) - cs.last_trace_host_execution_start = time.perf_counter_ns() check_storage_aliases(cache_entry, inps) @@ -786,8 +807,6 @@ def fn_(*args, **kwargs) -> Any: maybe_call_epilogue(cache_entry, result, pro_to_epi) - cs.last_trace_host_execution_stop = time.perf_counter_ns() - cs.last_computation = cache_entry.computation_fn return result From 70ac9a774daa98d756d4d39f591eb7362e754b52 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 13:59:29 +0300 Subject: [PATCH 10/15] Remove a couple of empty lines --- thunder/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index e84149a8a6..15aba8822f 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -802,13 +802,11 @@ def fn_(*args, **kwargs) -> Any: check_storage_aliases(cache_entry, inps) result = cache_entry.computation_fn(*inps) - result = maybe_connect_to_autograd(cache_entry, result) maybe_call_epilogue(cache_entry, result, pro_to_epi) cs.last_computation = cache_entry.computation_fn - return result if isinstance(fn, pytorch.nn.Module): From d89f51d3f534b981963ab198b8f3f5c690ed6932 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 14:06:24 +0300 Subject: [PATCH 11/15] Add prologue_execution_timer decorator --- thunder/__init__.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 15aba8822f..0865226e47 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -450,9 +450,7 @@ def get_computation_and_inputs(*args, **kwargs): _vanilla_args, ) = cache_entry try: - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() except Exception as _: continue @@ -484,9 +482,7 @@ def get_computation_and_inputs(*args, **kwargs): backward_traces, ) = cache_entry - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() # Updates cache statistics cs.cache_hits += 1 @@ -608,6 +604,7 @@ def get_computation_and_inputs(*args, **kwargs): ) prologue_trc = prologue_traces[-1] pro = prologue_trc.python_callable(include_decorators=False) + pro = prologue_execution_timer(pro) if epilogue_trc is not None: epilogue = epilogue_trc.python_callable() @@ -623,9 +620,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_interpreter_log = last_interpreter_log cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction)) - cs.last_prologue_execution_start = time.perf_counter_ns() inps, pro_to_epi = pro(*args, **kwargs) - cs.last_prologue_execution_stop = time.perf_counter_ns() computation_trc = dce(computation_trc) computation_traces.append(computation_trc) @@ -726,6 +721,17 @@ def wrapped(*args, **kwargs): return wrapped + def prologue_execution_timer(fn): + def wrapped(*args, **kwargs): + cs.last_prologue_execution_start = time.perf_counter_ns() + try: + return fn(*args, **kwargs) + finally: + cs.last_prologue_execution_stop = time.perf_counter_ns() + + return wrapped + + def decorate_computation_functions(get_computation_and_inputs_fn, *decorators): def wrapped(*args, **kwargs): cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) From 77c7d951284509b8749602d42049e04f3dcd8a77 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 14:21:15 +0300 Subject: [PATCH 12/15] decorate_computation_functions -> decorate_computation_function --- thunder/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 0865226e47..6d959b6dd7 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -732,7 +732,7 @@ def wrapped(*args, **kwargs): return wrapped - def decorate_computation_functions(get_computation_and_inputs_fn, *decorators): + def decorate_computation_function(get_computation_and_inputs_fn, *decorators): def wrapped(*args, **kwargs): cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) for decorator in decorators: @@ -742,7 +742,7 @@ def wrapped(*args, **kwargs): return wrapped - get_computation_and_inputs = decorate_computation_functions(get_computation_and_inputs, host_execution_timer) + get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer) cd.get_computation_and_inputs = get_computation_and_inputs From a8c4d7a718baddf61ac4344226d326608d89491f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:36:41 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 6d959b6dd7..16379c9d29 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -720,7 +720,6 @@ def wrapped(*args, **kwargs): return wrapped - def prologue_execution_timer(fn): def wrapped(*args, **kwargs): cs.last_prologue_execution_start = time.perf_counter_ns() @@ -731,7 +730,6 @@ def wrapped(*args, **kwargs): return wrapped - def decorate_computation_function(get_computation_and_inputs_fn, *decorators): def wrapped(*args, **kwargs): cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) @@ -741,11 +739,9 @@ def wrapped(*args, **kwargs): return wrapped - get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer) cd.get_computation_and_inputs = get_computation_and_inputs - def update_call_statistics(fn): def wrapped(*args, **kwargs): cs.calls += 1 @@ -757,7 +753,6 @@ def wrapped(*args, **kwargs): return wrapped - def check_storage_aliases(cache_entry, args): if cache_entry.vanilla_tensor_args: if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args): @@ -770,7 +765,6 @@ def check_storage_aliases(cache_entry, args): NotImplementedError, ) - def maybe_connect_to_autograd(cache_entry, result): if cache_entry.backward_fn: # If the backward function is available, we need to connect the @@ -789,13 +783,11 @@ def maybe_connect_to_autograd(cache_entry, result): return result - def maybe_call_epilogue(cache_entry, result, pro_to_epi): if cache_entry.epilogue_fn: result, comp_to_epi = result cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) - @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: From a9fd993742423c2e5a5ec0a994a44018fbb0d674 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 18:29:34 +0300 Subject: [PATCH 14/15] CacheEntry is a named tuple which doesn't support attribute assignment --- thunder/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 16379c9d29..65d6a8deaf 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -733,8 +733,11 @@ def wrapped(*args, **kwargs): def decorate_computation_function(get_computation_and_inputs_fn, *decorators): def wrapped(*args, **kwargs): cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs) + decorated_computation_fn = cache_entry.computation_fn for decorator in decorators: - cache_entry.computation_fn = decorator(cache_entry.computation_fn) + decorated_computation_fn = decorator(decorated_computation_fn) + if decorators: + cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn) return cache_entry, inps, pro_to_epi return wrapped From 9923a429c24a5f641e624fc6c6229b372abd5a92 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 21 Oct 2024 19:58:31 +0300 Subject: [PATCH 15/15] maybe_call_epilogue should return result --- thunder/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 65d6a8deaf..44fe0567b0 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -791,6 +791,8 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi): result, comp_to_epi = result cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi) + return result + @wraps(fn) @update_call_statistics def fn_(*args, **kwargs) -> Any: @@ -804,8 +806,7 @@ def fn_(*args, **kwargs) -> Any: result = cache_entry.computation_fn(*inps) result = maybe_connect_to_autograd(cache_entry, result) - - maybe_call_epilogue(cache_entry, result, pro_to_epi) + result = maybe_call_epilogue(cache_entry, result, pro_to_epi) cs.last_computation = cache_entry.computation_fn return result