Rework timer infrastructure #1410

Open
wants to merge 1 commit into base: main
81 changes: 35 additions & 46 deletions thunder/__init__.py
@@ -12,12 +12,14 @@
from looseversion import LooseVersion

from thunder.core.module import ThunderModule
import thunder.core.profile
from thunder.core.interpreter import InterpreterLogItem
from thunder.core.options import (
CACHE_OPTIONS,
SHARP_EDGES_OPTIONS,
DebugOptions,
)
from thunder.core.profile import annotate_for_profile
from thunder.core.trace import (
TraceResults,
TraceCtx,
@@ -276,6 +278,7 @@ def compile(fn: Callable, recipe: Recipe | None):

# This function will replace compile() (below) before RC1
# TODO RC1 Consider renaming compile_options to additional_compile_options
@thunder.core.profile.annotate_for_profile("jit")
def jit(
fn: Callable,
/,
@@ -377,6 +380,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
tensor_group_index_to_tensor_indices[tgi].append(idx)
return tensor_group_index_to_tensor_indices

@thunder.core.profile.annotate_for_profile("_alias_tensor_of_args_kwargs")
def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:
"""If no aliases found, empty string, otherwise, aliases are comma separated, groups are hyphen separated."""

@@ -391,6 +395,7 @@ def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:

@langctxs.langctx(cd.langctx)
@_with_cache_info_ctx
@thunder.core.profile.annotate_for_profile("get_computation_and_inputs")
Collaborator:

Could we have a mode for annotate_for_profile that grabs fn.__qualname__, please?
Having to type the name of the function/method every time seems tedious.

Collaborator Author @tfogal, Jan 16, 2025:

Happy to do this, but not actually sure how to implement it. Maybe you could educate me?

The definition is actually just a partial: annotate_for_profile = functools.partial(nvtx.annotate, domain="thunder"). Did you mean something like inspect.stack()[0].function, i.e.:

  annotate_for_profile = functools.partial(nvtx.annotate, message=inspect.stack()[0].function, domain="thunder")

Collaborator:

No, the idea is not to execute anything at definition time.

The decorator is a function that takes the decorated function as an argument and returns the "new" function. So

@decorator
def fn(...):
    ...

is equivalent to

def fn(...):
    ...
fn = decorator(fn)

The trick to allowing both @decorator and @decorator("foo") is that the first receives a function and needs to return one, while the second receives the argument and needs to return a decorator (because, spelled out, the last line would become fn = decorator("foo")(fn)).
So something like

def annotate_for_profile(fn_or_name):
    if isinstance(fn_or_name, str):
        return functools.partial(nvtx.annotate, message=fn_or_name, domain="thunder")
    return nvtx.annotate(message=fn_or_name.__qualname__, domain="thunder")(fn_or_name)

or so should do the trick (didn't try).
But looking at the nvtx code, I'm wondering how they expect exception handling to work; maybe defining the wrapper ourselves and using the context-manager version of annotate is safer...
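
A minimal sketch of that hand-rolled, context-manager-based wrapper (untested, and assuming the nvtx package's annotate context manager; the _ProfileRegion helper name is hypothetical, not part of this PR):

import functools

import nvtx


class _ProfileRegion:
    """Usable both as a decorator and as a context manager for one named range."""

    def __init__(self, message):
        self._annotation = nvtx.annotate(message=message, domain="thunder")

    def __enter__(self):
        return self._annotation.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        return self._annotation.__exit__(exc_type, exc_value, traceback)

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # The with-statement guarantees the NVTX range is popped even if fn raises.
            with self._annotation:
                return fn(*args, **kwargs)

        return wrapped


def annotate_for_profile(fn_or_name):
    # Bare @annotate_for_profile: derive the range name from the function itself.
    if callable(fn_or_name):
        return _ProfileRegion(fn_or_name.__qualname__)(fn_or_name)
    # @annotate_for_profile("name") as a decorator, or `with annotate_for_profile("name"):`.
    return _ProfileRegion(fn_or_name)

With this shape, the string form still works in the with-statements this PR adds, the bare-decorator form picks up fn.__qualname__ as requested, and the hand-rolled wrapper sidesteps the exception-handling question around nvtx's own decorator path.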

def get_computation_and_inputs(*args, **kwargs):
# set up a record of things in the current environment that impact caching / prologues
# this could be replaced by the respective querying in the prologues
@@ -440,7 +445,8 @@ def get_computation_and_inputs(*args, **kwargs):
# aliases wouldn't matter, thus it'd be better to nullify this entry in such cases.
# It however would require the functionalized computation trace to interact with `cache_info`,
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
with thunder.core.profile.annotate_for_profile("alias_tensor_of_args_kwargs"):
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
@@ -450,7 +456,6 @@ def get_computation_and_inputs(*args, **kwargs):
# TODO RC1 Add module and function checks to prologue (make it a compile option)

# Checks cache
cs.last_trace_cache_start = time.perf_counter_ns()
if (cd.cache_option is CACHE_OPTIONS.CONSTANT_VALUES) or (cd.cache_option is CACHE_OPTIONS.SYMBOLIC_VALUES):
for cache_entry in reversed(cs.interpreter_cache):
with compile_data_and_stats(cd, cs):
@@ -466,7 +471,9 @@ def get_computation_and_inputs(*args, **kwargs):
_return_none_instead_of_grads,
) = cache_entry
try:
inps, pro_to_epi = pro(*args, **kwargs)
from thunder.core.profile import annotate_for_profile

inps, pro_to_epi = annotate_for_profile("prologue")(pro)(*args, **kwargs)
except Exception:
continue

@@ -477,10 +484,6 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
cs.last_prologue_transformation_start = 0
cs.last_prologue_transformation_stop = 0
cs.last_computation_transformation_start = 0
cs.last_computation_transformation_stop = 0

return cache_entry, inps, pro_to_epi

@@ -498,7 +501,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces,
) = cache_entry

inps, pro_to_epi = pro(*args, **kwargs)
inps, pro_to_epi = annotate_for_profile("prologue")(pro)(*args, **kwargs)

# Updates cache statistics
cs.cache_hits += 1
@@ -511,16 +514,13 @@ def get_computation_and_inputs(*args, **kwargs):
return cache_entry, inps, pro_to_epi

cs.cache_misses += 1
cs.last_trace_cache_stop = time.perf_counter_ns()

# Resets use of compile flags
cs.last_compile_reasons = defaultdict(list)

with compile_data_and_stats(cd, cs):
# Acquires the trace OR inlines the trace into an existing trace and
# returns the (proxied) result of the operation
cs.last_trace_tracing_start = time.perf_counter_ns()

prologue_trc: TraceCtx
computation_trc: TraceCtx
jit_results: TraceResults = thunder_general_jit(
@@ -562,11 +562,7 @@ def get_computation_and_inputs(*args, **kwargs):

epilogue_traces = [epilogue_trc]

cs.last_trace_tracing_stop = time.perf_counter_ns()

# Makes the prologue callable
cs.last_prologue_transformation_start = time.perf_counter_ns()

transform: Transform
for transform in transforms:
thunder.core.utils.check_type(transform, Transform)
@@ -595,13 +591,15 @@ def get_computation_and_inputs(*args, **kwargs):
use_del_last_used=False,
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
pro = prologue_execution_timer(pro)
pro = thunder.core.profile.annotate_for_profile("prologue python_callable")(
prologue_trc.python_callable(include_decorators=False)
)

epilogue_trc = transform_to_torch_types(epilogue_trc)
epilogue = epilogue_trc.python_callable()
epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")(
epilogue_trc.python_callable()
)

cs.last_prologue_transformation_stop = time.perf_counter_ns()
cs.last_prologue_traces = prologue_traces
cs.last_prologue = pro
cs.last_traces = computation_traces
@@ -671,7 +669,9 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces.append(backward_trc)

if backward_trc is not None:
backward_fn = backward_trc.python_callable()
backward_fn = thunder.core.profile.annotate_for_profile("backward python_callable")(
backward_trc.python_callable()
)
else:
backward_fn = None
# We do not have to return auxiliary tensors, which will only be useful in backward pass
@@ -700,33 +700,21 @@ def get_computation_and_inputs(*args, **kwargs):

def host_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_trace_host_execution_start = time.perf_counter_ns()
try:
with thunder.core.profile.annotate_for_profile("computation"):
return fn(*args, **kwargs)
finally:
cs.last_trace_host_execution_stop = time.perf_counter_ns()

return wrapped

def prologue_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_prologue_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_prologue_execution_stop = time.perf_counter_ns()

return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
def wrapped(*args, **kwargs):
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
decorated_computation_fn = cache_entry.computation_fn
for decorator in decorators:
decorated_computation_fn = decorator(decorated_computation_fn)
if decorators:
cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
return cache_entry, inps, pro_to_epi
with thunder.core.profile.annotate_for_profile("get_computation_and_inputs"):
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
decorated_computation_fn = cache_entry.computation_fn
for decorator in decorators:
decorated_computation_fn = decorator(decorated_computation_fn)
if decorators:
cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
return cache_entry, inps, pro_to_epi

return wrapped

@@ -736,14 +724,11 @@ def wrapped(*args, **kwargs):
def update_call_statistics(fn):
def wrapped(*args, **kwargs):
cs.calls += 1
cs.last_trace_host_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_trace_host_stop = time.perf_counter_ns()
return fn(*args, **kwargs)

return wrapped

@thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd")
def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# If the backward function is available, we need to connect the
@@ -769,6 +754,7 @@ def call_epilogue(cache_entry, comp_result, pro_to_epi):

@wraps(fn)
@update_call_statistics
@thunder.core.profile.annotate_for_profile("fn_")
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
@@ -990,6 +976,7 @@ def get_auto_registered_torch_op_names(fn: Callable, /) -> set[str] | None:


# TODO (mruberry) Update this
@thunder.core.profile.annotate_for_profile("_grad_transform")
def _grad_transform(trace):
grad_fwd_trace = from_trace(trace)
trace_tok = set_tracectx(grad_fwd_trace)
@@ -1043,10 +1030,12 @@ def _grad_transform(trace):

# TODO Test nesting of grad and grad and grad and grad
# TODO Test nesting of a regular function + calling grad
@thunder.core.profile.annotate_for_profile("grad")
def grad(fn):
cfn = compile(fn)

@wraps(cfn)
@thunder.core.profile.annotate_for_profile("_fn")
def _fn(*args, **kwargs):
original_result, original_trace = cfn(*args, **kwargs)
original_trace = last_traces(cfn)
65 changes: 1 addition & 64 deletions thunder/benchmarks/__init__.py
@@ -232,15 +232,6 @@ class BenchmarkRunStatistics:
stop_time: int
host_stop_time: int
called_backward: bool
has_extended_stats: bool = False
last_trace_host_start: int = -1
last_trace_host_stop: int = -1
last_trace_cache_start: int = -1
last_trace_cache_stop: int = -1
last_trace_tracing_start: int = -1
last_trace_tracing_stop: int = -1
last_trace_host_execution_start: int = -1
last_trace_host_execution_stop: int = -1


# A timing helper
@@ -294,18 +285,6 @@ def _benchmark(
called_backward=called_backward,
)

# TODO Ensure the compile data statistics are always populated
if cs is not None and cs.last_trace_host_start > 0:
stat.has_extended_stats = True
stat.last_trace_host_start = cs.last_trace_host_start
stat.last_trace_host_stop = cs.last_trace_host_stop
stat.last_trace_cache_start = cs.last_trace_cache_start
stat.last_trace_cache_stop = cs.last_trace_cache_stop
stat.last_trace_tracing_start = cs.last_trace_tracing_start
stat.last_trace_tracing_stop = cs.last_trace_tracing_stop
stat.last_trace_host_execution_start = cs.last_trace_host_execution_start
stat.last_trace_host_execution_stop = cs.last_trace_host_execution_stop

stats.append(stat)

return stats
@@ -417,51 +396,9 @@ def _prettyprint_stats(
for rank, (memory_allocated, memory_reserved) in rank_mem_info.items():
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated/1024/1024:.2f}MB, peak reserved: {memory_reserved/1024/1024:.2f}MB"
short_printout += "\n"
if median_benchmark_stat.has_extended_stats:
# NOTE At this point in the program extended statistics are available
trace_time_ns = median_benchmark_stat.last_trace_host_stop - median_benchmark_stat.last_trace_host_start
cache_time_ns = median_benchmark_stat.last_trace_cache_stop - median_benchmark_stat.last_trace_cache_start
tracing_time_ns = median_benchmark_stat.last_trace_tracing_stop - median_benchmark_stat.last_trace_tracing_start
trace_execution_time_ns = (
median_benchmark_stat.last_trace_host_execution_stop - median_benchmark_stat.last_trace_host_execution_start
)

trace_time_us: str = ns_to_us(trace_time_ns)
cache_time_us: str = ns_to_us(cache_time_ns)
tracing_time_us: str = ns_to_us(tracing_time_ns)
trace_execution_time_us: str = ns_to_us(trace_execution_time_ns)

trace_time_percentage: str = f"{round(trace_time_ns / median_benchmark_stat.total_time * 100)}%"
cache_time_percentage: str = f"{round(cache_time_ns / median_benchmark_stat.total_time * 100)}%"
tracing_time_percentage: str = f"{round(tracing_time_ns / median_benchmark_stat.total_time * 100)}%"
trace_execution_time_percentage: str = (
f"{round(trace_execution_time_ns / median_benchmark_stat.total_time * 100)}%"
)

before_trace_time_ns = median_benchmark_stat.last_trace_host_start - median_benchmark_stat.start_time
accelerator_wait_time_ns = median_benchmark_stat.stop_time - median_benchmark_stat.last_trace_host_stop

before_trace_time_us: str = ns_to_us(before_trace_time_ns)
accelerator_wait_time_us: str = ns_to_us(accelerator_wait_time_ns)

before_trace_time_percentage: str = f"{round(before_trace_time_ns / median_benchmark_stat.total_time * 100)}%"
accelerator_wait_time_percentage: str = (
f"{round(accelerator_wait_time_ns / median_benchmark_stat.total_time * 100)}%"
)

extension = f"""\
The median benchmark took {before_trace_time_us} to get into the tracing logic, {before_trace_time_percentage} of the total time.
The median benchmark took {accelerator_wait_time_us} waiting for the accelerator's computation to finish, {accelerator_wait_time_percentage} of the total time.
The median benchmark run's total time in tracing logic is {trace_time_us}, {trace_time_percentage} of the total time.
The median benchmark run's cache lookup time is {cache_time_us}, {cache_time_percentage} of the total time.
The median benchmark run's time spent tracing is {tracing_time_us}, {tracing_time_percentage} of the total time.
The median benchmark run's time to request the traced program be executed is {trace_execution_time_us}, {trace_execution_time_percentage} of the total time.
"""
else:
extension = ""

output = textwrap.dedent(preamble) + textwrap.indent(textwrap.dedent(extension), " " * 4)
print(output)
print(textwrap.dedent(preamble))


def print_rank_0(message):