Skip to content

Commit

Permalink
[WIP DO NOT MERGE]: + some nvtx, - some timers
Browse files Browse the repository at this point in the history
start ripping out the timers.

this isn't ready for consideration but it's usable for analyzing
performance on real models.
  • Loading branch information
tfogal committed Nov 7, 2024
1 parent dcf3553 commit a82c33d
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 64 deletions.
72 changes: 37 additions & 35 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
from looseversion import LooseVersion

from thunder.core.module import ThunderModule
import thunder.core.profile
from thunder.core.interpreter import InterpreterLogItem
from thunder.core.options import (
CACHE_OPTIONS,
SHARP_EDGES_OPTIONS,
)
from thunder.core.profile import annotate_for_profile
from thunder.core.trace import (
TraceResults,
TraceCtx,
Expand Down Expand Up @@ -277,6 +279,7 @@ def compile(fn: Callable, recipe: Recipe | None):
# This function will replace compile() (below) before RC1
# TODO RC1 Consider adding a debug_log parameter to control debug printing
# TODO RC1 Consider renaming compile_options to additional_compile_options
@thunder.core.profile.annotate_for_profile("jit")
def jit(
fn: Callable,
/,
Expand Down Expand Up @@ -377,6 +380,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
tensor_group_index_to_tensor_indices[tgi].append(idx)
return tensor_group_index_to_tensor_indices

@thunder.core.profile.annotate_for_profile("_alias_tensor_of_args_kwargs")
def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:
"""If no aliases found, empty string, otherwise, aliases are comma separated, groups are hyphen separated."""

Expand All @@ -391,6 +395,7 @@ def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:

@langctxs.langctx(cd.langctx)
@_with_cache_info_ctx
@thunder.core.profile.annotate_for_profile("get_computation_and_inputs")
def get_computation_and_inputs(*args, **kwargs):
# set up a record of things in the current environment that impact caching / prologues
# this could be replaced by the respective querying in the prologues
Expand Down Expand Up @@ -440,7 +445,8 @@ def get_computation_and_inputs(*args, **kwargs):
# aliases wouldn't matter, thus it'd be better to nullify this entry in such cases.
# It however would require the functionalized computation trace to interact with `cache_info`,
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
with thunder.core.profile.annotate_for_profile("Alias_tensor_of_args_kwarsgs"):
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# TODO RC1 Add module and function checks to prologue (make it a compile option)

Expand All @@ -462,7 +468,8 @@ def get_computation_and_inputs(*args, **kwargs):
_vanilla_args,
) = cache_entry
try:
inps, pro_to_epi = pro(*args, **kwargs)
from thunder.core.profile import annotate_for_profile
inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs))
except Exception as _:
continue

Expand Down Expand Up @@ -494,7 +501,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces,
) = cache_entry

inps, pro_to_epi = pro(*args, **kwargs)
inps, pro_to_epi = annotate_for_profile("Prologue")(pro(*args, **kwargs))

# Updates cache statistics
cs.cache_hits += 1
Expand Down Expand Up @@ -618,11 +625,16 @@ def get_computation_and_inputs(*args, **kwargs):
use_del_last_used=False,
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
pro = prologue_execution_timer(pro)
pro = thunder.core.profile.annotate_for_profile("prologue python_callable")(
prologue_trc.python_callable(include_decorators=False)
)
#pro = thunder.core.profile.annotate_for_profile("prologue")(pro)

if epilogue_trc is not None:
epilogue = epilogue_trc.python_callable()
epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")(
epilogue_trc.python_callable()
)
epilogue = thunder.core.profile.annotate_for_profile("epilogue")(epilogue)
else:
epilogue = None

Expand Down Expand Up @@ -697,7 +709,9 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces.append(backward_trc)

if backward_trc is not None:
backward_fn = backward_trc.python_callable()
backward_fn = thunder.core.profile.annotate_for_profile("backward python_callable")(
backward_trc.python_callable()
)
else:
backward_fn = None
# We do not have to return auxiliary tensors, which will only be useful in backward pass
Expand Down Expand Up @@ -727,33 +741,20 @@ def get_computation_and_inputs(*args, **kwargs):

def host_execution_timer(fn):
    """Wrap *fn* so its execution appears as a "computation" NVTX/profile range.

    Replaces the old ``cs.last_trace_host_execution_start/stop`` perf_counter
    bookkeeping: this scraped diff left both the removed timer statements and
    the added ``with`` block interleaved; this is the resolved (new) version.

    Args:
        fn: the computation callable to instrument.

    Returns:
        A wrapper with the same call signature as ``fn`` that runs ``fn``
        inside an ``annotate_for_profile("computation")`` range.
    """

    def wrapped(*args, **kwargs):
        # The profile annotation context manager opens/closes the range even
        # if fn raises, mirroring the try/finally of the old timer code.
        with thunder.core.profile.annotate_for_profile("computation"):
            return fn(*args, **kwargs)

    return wrapped

def prologue_execution_timer(fn):
    """Record wall-clock bounds of a prologue execution on the compile stats.

    Returns a wrapper around *fn* that stamps
    ``cs.last_prologue_execution_start`` before the call and
    ``cs.last_prologue_execution_stop`` after it, even when *fn* raises.
    ``cs`` is the enclosing scope's CompileStats instance.
    """

    def timed(*args, **kwargs):
        cs.last_prologue_execution_start = time.perf_counter_ns()
        try:
            return fn(*args, **kwargs)
        finally:
            # finally ensures the stop stamp is written on the error path too
            cs.last_prologue_execution_stop = time.perf_counter_ns()

    return timed

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
    """Apply *decorators* to the computation function of each cache entry.

    The scraped diff left two copies of the body (the second unreachable after
    the first ``return``); this is the resolved (new) version, which wraps the
    whole lookup in a profile range.

    Args:
        get_computation_and_inputs_fn: callable returning
            ``(cache_entry, inps, pro_to_epi)``.
        *decorators: applied innermost-first to ``cache_entry.computation_fn``.

    Returns:
        A wrapper with the same signature as ``get_computation_and_inputs_fn``.
    """

    def wrapped(*args, **kwargs):
        with thunder.core.profile.annotate_for_profile("get_computation_and_inputs"):
            cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
            decorated_computation_fn = cache_entry.computation_fn
            for decorator in decorators:
                decorated_computation_fn = decorator(decorated_computation_fn)
            # Only pay for the namedtuple copy when something actually changed.
            if decorators:
                cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
            return cache_entry, inps, pro_to_epi

    return wrapped

Expand All @@ -763,14 +764,10 @@ def wrapped(*args, **kwargs):
def update_call_statistics(fn):
    """Count invocations of *fn* on the compile stats.

    The scraped diff interleaved the removed ``last_trace_host_start/stop``
    perf_counter try/finally with the new bare call (leaving a duplicated,
    unreachable ``return``); this is the resolved (new) version. Host timing
    is now handled by NVTX ranges instead of these counters.
    """

    def wrapped(*args, **kwargs):
        # ``cs`` is the enclosing scope's CompileStats instance.
        cs.calls += 1
        return fn(*args, **kwargs)

    return wrapped

@thunder.core.profile.annotate_for_profile("check_storage_aliases")
def check_storage_aliases(cache_entry, args):
if cache_entry.vanilla_tensor_args:
if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*args):
Expand All @@ -783,6 +780,7 @@ def check_storage_aliases(cache_entry, args):
NotImplementedError,
)

@thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd")
def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# If the backward function is available, we need to connect the
Expand Down Expand Up @@ -810,6 +808,7 @@ def maybe_call_epilogue(cache_entry, result, pro_to_epi):

@wraps(fn)
@update_call_statistics
@thunder.core.profile.annotate_for_profile("fn_")
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
Expand Down Expand Up @@ -1033,6 +1032,7 @@ def get_auto_registered_torch_op_names(fn: Callable, /) -> set[str] | None:


# TODO (mruberry) Update this
@thunder.core.profile.annotate_for_profile("_grad_transform")
def _grad_transform(trace):
grad_fwd_trace = from_trace(trace)
trace_tok = set_tracectx(grad_fwd_trace)
Expand Down Expand Up @@ -1086,10 +1086,12 @@ def _grad_transform(trace):

# TODO Test nesting of grad and grad and grad and grad
# TODO Test nesting of a regular function + calling grad
@thunder.core.profile.annotate_for_profile("grad")
def grad(fn):
cfn = compile(fn)

@wraps(cfn)
@thunder.core.profile.annotate_for_profile("_fn")
def _fn(*args, **kwargs):
original_result, original_trace = cfn(*args, **kwargs)
original_trace = last_traces(cfn)
Expand Down
56 changes: 27 additions & 29 deletions thunder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,20 @@ class CompileStats:
last_interpreted_instructions (Generator[dist.Instruction, None, None] | None):
last_interpreter_log (list[InterpreterLogItem] | None):
last_backward_traces (Sequence[TraceCtx]):
last_trace_host_start (int):
last_trace_host_stop (int):
last_trace_cache_start (int):
last_trace_cache_stop (int):
last_trace_tracing_start (int):
last_trace_tracing_stop (int):
last_trace_host_execution_start (int):
last_trace_host_execution_stop (int):
last_prologue_transformation_start (int):
last_prologue_transformation_stop (int):
last_prologue_execution_start (int):
last_prologue_execution_stop (int):
last_computation_execution_start (int):
last_computation_execution_stop (int):
last_trace_host_start (int): deprecated
last_trace_host_stop (int): deprecated
last_trace_cache_start (int): deprecated
last_trace_cache_stop (int): deprecated
last_trace_tracing_start (int): deprecated
last_trace_tracing_stop (int): deprecated
last_trace_host_execution_start (int): deprecated
last_trace_host_execution_stop (int): deprecated
last_prologue_transformation_start (int): deprecated
last_prologue_transformation_stop (int): deprecated
last_prologue_execution_start (int): deprecated
last_prologue_execution_stop (int): deprecated
last_computation_execution_start (int): deprecated
last_computation_execution_stop (int): deprecated
cache (dict):
interpreter_cache (list):
calls (int):
Expand Down Expand Up @@ -132,6 +132,8 @@ def __init__(self):
self.last_prologue_execution_stop: int = -1
self.last_computation_execution_start: int = -1
self.last_computation_execution_stop: int = -1
self.last_times_msg: str = "Timers have been replaced by nvtx. "
self.last_times_msg += "Please use Nsight systems to see host timers."

# Cache stats
self.cache = {}
Expand All @@ -151,29 +153,24 @@ def _time_template(self, start: int, stop: int, desc: str, /) -> int:
return stop - start

def last_cache_lookup_time(self, /) -> int:
    """Deprecated: host timers were replaced by NVTX ranges.

    Warns with ``self.last_times_msg`` and returns -1 (the diff residue here
    left the old ``_time_template`` lines unreachable above the new body).
    Use Nsight Systems to inspect host timings instead.
    """
    warnings.warn(self.last_times_msg)
    return -1

def last_trace_construction_time(self, /) -> int:
    """Deprecated: host timers were replaced by NVTX ranges.

    Warns with ``self.last_times_msg`` and returns -1 (the diff residue here
    left the old ``_time_template`` lines unreachable above the new body).
    Use Nsight Systems to inspect host timings instead.
    """
    warnings.warn(self.last_times_msg)
    return -1

def last_prologue_transformation_time(self, /) -> int:
    """Deprecated: host timers were replaced by NVTX ranges.

    Warns with ``self.last_times_msg`` and returns -1 (the diff residue here
    left the old ``_time_template`` lines unreachable above the new body).
    Use Nsight Systems to inspect host timings instead.
    """
    warnings.warn(self.last_times_msg)
    return -1

def last_prologue_execution_time(self, /) -> int:
    """Deprecated: host timers were replaced by NVTX ranges.

    Warns with ``self.last_times_msg`` and returns -1 (the diff residue here
    left the old ``_time_template`` lines unreachable above the new body).
    Use Nsight Systems to inspect host timings instead.
    """
    warnings.warn(self.last_times_msg)
    return -1

def last_computation_execution_time(self, /) -> int:
    """Deprecated: host timers were replaced by NVTX ranges.

    Warns with ``self.last_times_msg`` and returns -1 (the diff residue here
    left the old ``_time_template`` lines unreachable above the new body).
    Use Nsight Systems to inspect host timings instead.
    """
    warnings.warn(self.last_times_msg)
    return -1


# A class that holds data about the compiled object, including statistics about how it's been called
Expand Down Expand Up @@ -631,6 +628,7 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy:
# TODO Consider making this faster by reusing more data
# TODO Create a general mechanism for running traces that produces reproducible provenance and the
# appropriate error checks
@thunder.core.profile.annotate_for_profile("transform_for_execution")
def transform_for_execution(
trace: TraceCtx,
executors_list: Sequence[Executor],
Expand Down
3 changes: 3 additions & 0 deletions thunder/core/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING

import thunder.core.prims as prims
import thunder.core.profile
from thunder.core.proxies import variableify, TensorProxy, unvariableify, ProxyInterface
from thunder.core.pytree import tree_flatten, tree_unflatten
from thunder.core.symbol import BoundSymbol
Expand Down Expand Up @@ -39,6 +40,7 @@ def bsym_of_to_return_self(bsym: BoundSymbol):
return result_is_self


@thunder.core.profile.annotate_for_profile("check_inplace_to_views")
def check_inplace_to_views(computation_trace: Trace) -> dict[VariableInterface, TensorProxy]:
"""Error out if in-place op that outputs of different number of elements from the input and the input has other consumers."""
from thunder.core import utils
Expand Down Expand Up @@ -614,6 +616,7 @@ def _reshape_bsym_ctor(src: TensorProxy, dst: TensorProxy, trace: Trace) -> tupl
return functionalized_computation_trace


@thunder.core.profile.annotate_for_profile("functionalize_inplace_ops")
def functionalize_inplace_ops(
computation_trace: Trace,
orig_to_view_swap_map: dict[VariableInterface, TensorProxy],
Expand Down
1 change: 1 addition & 0 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,7 @@ def update_tags(proxy_swapmap: dict[Variable, Proxy]) -> None:
new.tags.update(unvariableify(old).tags)


@thunder.core.profile.annotate_for_profile("thunder_general_jit")
def thunder_general_jit(
fn: Callable,
args: tuple[Any, ...],
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import thunder
import thunder.core.codeutils as codeutils
import thunder.core.baseutils as baseutils
import thunder.core.profile
from thunder.core.baseutils import ProxyInterface, BoundSymbolInterface, TagBase
import thunder.core.devices as devices
from thunder.core.pytree import tree_flatten, tree_unflatten
Expand Down Expand Up @@ -479,6 +480,7 @@ def keyfn(class_or_module: type | ModuleType) -> str:
# Returns a Python callable that executes the trace
# TODO issue "Create a mechanism for freezing TraceCtx objects"
# Create a mechanism for freezing traces and cache the compilation
@thunder.core.profile.annotate_for_profile("TraceCtx.python_callable")
def python_callable(self, *, global_dicts: None | dict = None, **kwargs: Any) -> Callable:
python_str: str

Expand Down
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def check(inp, log_str):
# that only produce non-proxy objects
# NOTE needed_proxies is an in/out argument, it takes an initial set of Variables you want to keep, and return
# all the needed proxies of the input trace
@thunder.core.profile.annotate_for_profile("dce")
def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
start_time_ns = time.perf_counter_ns()

Expand Down Expand Up @@ -456,6 +457,7 @@ def process_bound_symbols(src_bound_symbols, target_bound_symbols):
return output


@thunder.core.profile.annotate_for_profile("wrap_return-value_together_with_arguments")
def wrap_return_value_together_with_argments(trace: Trace) -> Trace:
last = trace.bound_symbols[-1]
assert last.sym.id == prims.PrimIDs.RETURN
Expand All @@ -480,6 +482,7 @@ def unwrap_return_value(trace: Trace) -> Trace:
return new_trace


@thunder.core.profile.annotate_for_profile("remove_context_manager_prims_from_trace")
def remove_context_manager_prims_from_trace(trace: Trace) -> Trace:
def is_context_manager_prim(bsym):
# context manager prims would/should be explicitly tagged.
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
interpret_trace_to_trace,
trace_interpreter_skip_list,
)
from thunder.core.profile import annotate_for_profile
from thunder.core.proxies import (
CollectionProxy,
NumberProxy,
Expand Down Expand Up @@ -1485,6 +1486,7 @@ def transform_traces_pre_prologue(
computation_trc = dce(computation_trc)

@wraps(computation_trc.python_callable())
@annotate_for_profile("_GradTransform.python_callable")
def python_callable(*args, **kwargs):
return eval_trace(computation_trc, *args, **kwargs)["output"]

Expand Down
Loading

0 comments on commit a82c33d

Please sign in to comment.