Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify function implementation returned by thunder.jit for easier instrumentation of different stages #1333

Merged
merged 17 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 68 additions & 34 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,10 @@ def get_computation_and_inputs(*args, **kwargs):
_vanilla_args,
) = cache_entry
try:
cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()
except Exception as _:
continue

cs.last_trace_host_tracing_start = time.perf_counter_ns()
cs.last_trace_host_tracing_stop = time.perf_counter_ns()

# Updates cache statistics
cs.cache_hits += 1
cs.last_traces = comp_traces
Expand Down Expand Up @@ -487,12 +482,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces,
) = cache_entry

cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()

cs.last_trace_host_tracing_start = time.perf_counter_ns()
cs.last_trace_host_tracing_stop = time.perf_counter_ns()

# Updates cache statistics
cs.cache_hits += 1
Expand Down Expand Up @@ -614,6 +604,7 @@ def get_computation_and_inputs(*args, **kwargs):
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
pro = prologue_execution_timer(pro)

if epilogue_trc is not None:
epilogue = epilogue_trc.python_callable()
Expand All @@ -629,9 +620,7 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_interpreter_log = last_interpreter_log
cs.last_interpreted_instructions = (i for i in last_interpreter_log if isinstance(i, dis.Instruction))

cs.last_prologue_execution_start = time.perf_counter_ns()
inps, pro_to_epi = pro(*args, **kwargs)
cs.last_prologue_execution_stop = time.perf_counter_ns()

computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)
Expand Down Expand Up @@ -721,23 +710,55 @@ def get_computation_and_inputs(*args, **kwargs):

return cache_entry, inps, pro_to_epi

cd.get_computation_and_inputs = get_computation_and_inputs
def host_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_trace_host_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_trace_host_execution_stop = time.perf_counter_ns()

@wraps(fn)
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
return fn(*args, **kwargs)
return wrapped

# Updates call statistics
cs.last_trace_host_start = time.perf_counter_ns()
cs.calls += 1
def prologue_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_prologue_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_prologue_execution_stop = time.perf_counter_ns()

cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
cs.last_trace_host_execution_start = time.perf_counter_ns()
return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
    """Wrap a lookup function so the computation function it returns is decorated.

    Args:
        get_computation_and_inputs_fn: Callable returning a
            ``(cache_entry, inps, pro_to_epi)`` triple. ``cache_entry`` must
            expose a ``computation_fn`` attribute and a namedtuple-style
            ``_replace`` method.
        *decorators: Callables taking a function and returning a function.
            They are applied to ``cache_entry.computation_fn`` in the order
            given, so the last decorator ends up outermost.

    Returns:
        A callable with the same calling convention as
        ``get_computation_and_inputs_fn`` that returns the same triple, with
        ``cache_entry.computation_fn`` replaced by its decorated version.
        The entry is returned untouched when no decorators are supplied.
    """

    @wraps(get_computation_and_inputs_fn)
    def wrapped(*args, **kwargs):
        cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)

        # Nothing to apply: hand back the entry exactly as produced.
        if not decorators:
            return cache_entry, inps, pro_to_epi

        decorated_computation_fn = cache_entry.computation_fn
        for decorator in decorators:
            decorated_computation_fn = decorator(decorated_computation_fn)

        # `_replace` builds a new entry, so the object returned by the
        # underlying lookup is not mutated.
        cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
        return cache_entry, inps, pro_to_epi

    return wrapped

get_computation_and_inputs = decorate_computation_function(get_computation_and_inputs, host_execution_timer)
cd.get_computation_and_inputs = get_computation_and_inputs

def update_call_statistics(fn):
    """Decorate ``fn`` so every call updates the call statistics held on ``cs``.

    ``cs`` is taken from the enclosing scope (not defined in this block). On
    each call the wrapper increments ``cs.calls``, stores the start timestamp
    in ``cs.last_trace_host_start`` and, once ``fn`` finishes, stores the stop
    timestamp in ``cs.last_trace_host_stop``. Timestamps come from
    ``time.perf_counter_ns()``.

    Args:
        fn: The callable to instrument.

    Returns:
        A wrapper that forwards all arguments to ``fn`` and returns its result.
    """

    @wraps(fn)
    def wrapped(*args, **kwargs):
        cs.calls += 1
        cs.last_trace_host_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            # Set in `finally` so the stop stamp is recorded whether `fn`
            # returns or raises.
            cs.last_trace_host_stop = time.perf_counter_ns()

    return wrapped

def check_storage_aliases(cache_entry, args):
if cache_entry.vanilla_tensor_args:
if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args):
alias_tensor_indices = alias_tensor_indices_str
alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
vanilla_tensor_args = cache_entry.vanilla_tensor_args
Expand All @@ -747,13 +768,12 @@ def fn_(*args, **kwargs) -> Any:
NotImplementedError,
)

result = cache_entry.computation_fn(*inps)

def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# Run the compiled forward function
# If the backward function is available, we need to connect the
# resulting tensors to PyTorch's Autograd graph using the
# ThunderFunction (which is a torch.autograd.Function subclass)
data_for_autograd, (saved_tensors, saved_other) = result

# Connect produced tensors with PyTorch's autograd graph
ThunderFunction.apply(
cache_entry.return_none_instead_of_grads,
cache_entry.backward_fn,
Expand All @@ -764,17 +784,31 @@ def fn_(*args, **kwargs) -> Any:
)
result = data_for_autograd["output"]

return result

def maybe_call_epilogue(cache_entry, result, pro_to_epi):
if cache_entry.epilogue_fn:
result, comp_to_epi = result
cache_entry.epilogue_fn(*pro_to_epi, *comp_to_epi)

cs.last_trace_host_execution_stop = time.perf_counter_ns()
cs.last_computation_execution_stop = cs.last_trace_host_execution_stop
return result

cs.last_executed = cache_entry.computation_fn
cs.last_trace_cache_stop = time.perf_counter_ns()
cs.last_trace_host_stop = time.perf_counter_ns()
@wraps(fn)
@update_call_statistics
def fn_(*args, **kwargs) -> Any:
    """Entry point that runs the compiled computation for the given arguments.

    If ``is_tracing()`` is true, a recursive-call warning is issued and the
    original ``fn`` is called directly. Otherwise the cache entry and
    computation inputs are obtained, the inputs are checked for storage
    aliases, and the computation function is run. Its result is then passed
    through the autograd-connection and epilogue steps before being returned.
    """
    if is_tracing():
        _recursive_jit_call_warning()
        return fn(*args, **kwargs)

    entry, computation_inputs, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    check_storage_aliases(entry, computation_inputs)

    # Each stage takes the previous stage's output and yields the next one.
    output = entry.computation_fn(*computation_inputs)
    output = maybe_connect_to_autograd(entry, output)
    output = maybe_call_epilogue(entry, output, pro_to_epi)

    # Record which computation function was used for this call.
    cs.last_computation = entry.computation_fn
    return output

if isinstance(fn, pytorch.nn.Module):
Expand Down
5 changes: 2 additions & 3 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@


# Holds statistics and caches for a compiled function
# TODO RC1 Update last_executed to last_computation
# TODO RC1 Review how autograd traces are presented
class CompileStats:
"""A class holding statistics and caches for a compiled function.
Expand All @@ -76,7 +75,7 @@ class CompileStats:
See :mod:`thunder` for more of such utility functions.

Attributes:
last_executed:
last_computation (Callable):
last_traces (Sequence[TraceCtx]):
last_prologue (TraceCtx):
last_prologue_traces (Sequence[TraceCtx]):
Expand Down Expand Up @@ -107,7 +106,7 @@ class CompileStats:

def __init__(self):
# Callables and traces
self.last_executed = None
self.last_computation = None
self.last_traces = None
self.last_prologue = None
self.last_prologue_traces = None
Expand Down
Loading