Skip to content

Commit

Permalink
Rework timer infrastructure.
Browse files Browse the repository at this point in the history
Add profile annotations to appropriate places.

Remove the old timer tracking stuff so that we do not have two
methods. Use nsys for proper host profiling.
  • Loading branch information
tfogal committed Dec 18, 2024
1 parent 2d0199e commit 2a01e25
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 391 deletions.
84 changes: 36 additions & 48 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from looseversion import LooseVersion

from thunder.core.module import ThunderModule
import thunder.core.profile
from thunder.core.interpreter import InterpreterLogItem
from thunder.core.options import (
CACHE_OPTIONS,
SHARP_EDGES_OPTIONS,
DebugOptions,
)
from thunder.core.profile import annotate_for_profile
from thunder.core.trace import (
TraceResults,
TraceCtx,
Expand Down Expand Up @@ -277,6 +279,7 @@ def compile(fn: Callable, recipe: Recipe | None):

# This function will replace compile() (below) before RC1
# TODO RC1 Consider renaming compile_options to additional_compile_options
@thunder.core.profile.annotate_for_profile("jit")
def jit(
fn: Callable,
/,
Expand Down Expand Up @@ -378,6 +381,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
tensor_group_index_to_tensor_indices[tgi].append(idx)
return tensor_group_index_to_tensor_indices

@thunder.core.profile.annotate_for_profile("_alias_tensor_of_args_kwargs")
def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:
"""If no aliases found, empty string, otherwise, aliases are comma separated, groups are hyphen separated."""

Expand All @@ -392,6 +396,7 @@ def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str:

@langctxs.langctx(cd.langctx)
@_with_cache_info_ctx
@thunder.core.profile.annotate_for_profile("get_computation_and_inputs")
def get_computation_and_inputs(*args, **kwargs):
# set up a record of things in the current environment that impact caching / prologues
# this could be replaced by the respective querying in the prologues
Expand Down Expand Up @@ -441,7 +446,8 @@ def get_computation_and_inputs(*args, **kwargs):
# alaises wouldn't matter, thus it'd be better to nullify this entry in such cases.
# It however would require the functionalized computation trace to interact with `cache_info`,
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
with thunder.core.profile.annotate_for_profile("Alias_tensor_of_args_kwarsgs"):
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)

# Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform
# to treat certain Symbols as constant.
Expand All @@ -451,7 +457,6 @@ def get_computation_and_inputs(*args, **kwargs):
# TODO RC1 Add module and function checks to prologue (make it a compile option)

# Checks cache
cs.last_trace_cache_start = time.perf_counter_ns()
if (cd.cache_option is CACHE_OPTIONS.CONSTANT_VALUES) or (cd.cache_option is CACHE_OPTIONS.SYMBOLIC_VALUES):
for cache_entry in reversed(cs.interpreter_cache):
with compile_data_and_stats(cd, cs):
Expand All @@ -467,7 +472,8 @@ def get_computation_and_inputs(*args, **kwargs):
_return_none_instead_of_grads,
) = cache_entry
try:
inps, pro_to_epi = pro(*args, **kwargs)
from thunder.core.profile import annotate_for_profile
inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs))
except Exception as _:
continue

Expand All @@ -478,10 +484,6 @@ def get_computation_and_inputs(*args, **kwargs):
cs.last_interpreter_log = None
cs.last_prologue_traces = pro_traces
cs.last_prologue = pro
cs.last_prologue_transformation_start = 0
cs.last_prologue_transformation_stop = 0
cs.last_computation_transformation_start = 0
cs.last_computation_transformation_stop = 0

return cache_entry, inps, pro_to_epi

Expand All @@ -499,7 +501,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces,
) = cache_entry

inps, pro_to_epi = pro(*args, **kwargs)
inps, pro_to_epi = annotate_for_profile("Prologue")(pro(*args, **kwargs))

# Updates cache statistics
cs.cache_hits += 1
Expand All @@ -512,16 +514,13 @@ def get_computation_and_inputs(*args, **kwargs):
return cache_entry, inps, pro_to_epi

cs.cache_misses += 1
cs.last_trace_cache_stop = time.perf_counter_ns()

# Resets use of compile flags
cs.last_compile_reasons = defaultdict(list)

with compile_data_and_stats(cd, cs):
# Acquires the trace OR inlines the trace into an existing trace and
# returns the (proxied) result of the operation
cs.last_trace_tracing_start = time.perf_counter_ns()

prologue_trc: TraceCtx
computation_trc: TraceCtx
jit_results: TraceResults = thunder_general_jit(
Expand Down Expand Up @@ -563,11 +562,7 @@ def get_computation_and_inputs(*args, **kwargs):

epilogue_traces = [epilogue_trc]

cs.last_trace_tracing_stop = time.perf_counter_ns()

# Makes the prologue callable
cs.last_prologue_transformation_start = time.perf_counter_ns()

transform: Transform
for transform in transforms:
thunder.core.utils.check_type(transform, Transform)
Expand Down Expand Up @@ -596,13 +591,16 @@ def get_computation_and_inputs(*args, **kwargs):
use_del_last_used=False,
)
prologue_trc = prologue_traces[-1]
pro = prologue_trc.python_callable(include_decorators=False)
pro = prologue_execution_timer(pro)
pro = thunder.core.profile.annotate_for_profile("prologue python_callable")(
prologue_trc.python_callable(include_decorators=False)
)
#pro = thunder.core.profile.annotate_for_profile("prologue")(pro)

epilogue_trc = transform_to_torch_types(epilogue_trc)
epilogue = epilogue_trc.python_callable()
epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")(
epilogue_trc.python_callable()
)

cs.last_prologue_transformation_stop = time.perf_counter_ns()
cs.last_prologue_traces = prologue_traces
cs.last_prologue = pro
cs.last_traces = computation_traces
Expand Down Expand Up @@ -672,7 +670,9 @@ def get_computation_and_inputs(*args, **kwargs):
backward_traces.append(backward_trc)

if backward_trc is not None:
backward_fn = backward_trc.python_callable()
backward_fn = thunder.core.profile.annotate_for_profile("backward python_callable")(
backward_trc.python_callable()
)
else:
backward_fn = None
# We do not have to return auxiliary tensors, which will only be useful in backward pass
Expand Down Expand Up @@ -701,33 +701,20 @@ def get_computation_and_inputs(*args, **kwargs):

def host_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_trace_host_execution_start = time.perf_counter_ns()
try:
with thunder.core.profile.annotate_for_profile("computation"):
return fn(*args, **kwargs)
finally:
cs.last_trace_host_execution_stop = time.perf_counter_ns()

return wrapped

def prologue_execution_timer(fn):
def wrapped(*args, **kwargs):
cs.last_prologue_execution_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_prologue_execution_stop = time.perf_counter_ns()

return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
def wrapped(*args, **kwargs):
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
decorated_computation_fn = cache_entry.computation_fn
for decorator in decorators:
decorated_computation_fn = decorator(decorated_computation_fn)
if decorators:
cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
return cache_entry, inps, pro_to_epi
with thunder.core.profile.annotate_for_profile("get_computation_and_inputs"):
cache_entry, inps, pro_to_epi = get_computation_and_inputs_fn(*args, **kwargs)
decorated_computation_fn = cache_entry.computation_fn
for decorator in decorators:
decorated_computation_fn = decorator(decorated_computation_fn)
if decorators:
cache_entry = cache_entry._replace(computation_fn=decorated_computation_fn)
return cache_entry, inps, pro_to_epi

return wrapped

Expand All @@ -737,14 +724,11 @@ def wrapped(*args, **kwargs):
def update_call_statistics(fn):
def wrapped(*args, **kwargs):
cs.calls += 1
cs.last_trace_host_start = time.perf_counter_ns()
try:
return fn(*args, **kwargs)
finally:
cs.last_trace_host_stop = time.perf_counter_ns()

return fn(*args, **kwargs)
return wrapped


@thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd")
def maybe_connect_to_autograd(cache_entry, result):
if cache_entry.backward_fn:
# If the backward function is available, we need to connect the
Expand All @@ -769,6 +753,7 @@ def call_epilogue(cache_entry, comp_result, pro_to_epi):

@wraps(fn)
@update_call_statistics
@thunder.core.profile.annotate_for_profile("fn_")
def fn_(*args, **kwargs) -> Any:
if is_tracing():
_recursive_jit_call_warning()
Expand Down Expand Up @@ -990,6 +975,7 @@ def get_auto_registered_torch_op_names(fn: Callable, /) -> set[str] | None:


# TODO (mruberry) Update this
@thunder.core.profile.annotate_for_profile("_grad_transform")
def _grad_transform(trace):
grad_fwd_trace = from_trace(trace)
trace_tok = set_tracectx(grad_fwd_trace)
Expand Down Expand Up @@ -1043,10 +1029,12 @@ def _grad_transform(trace):

# TODO Test nesting of grad and grad and grad and grad
# TODO Test nesting of a regular function + calling grad
@thunder.core.profile.annotate_for_profile("grad")
def grad(fn):
cfn = compile(fn)

@wraps(cfn)
@thunder.core.profile.annotate_for_profile("_fn")
def _fn(*args, **kwargs):
original_result, original_trace = cfn(*args, **kwargs)
original_trace = last_traces(cfn)
Expand Down
65 changes: 1 addition & 64 deletions thunder/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,6 @@ class BenchmarkRunStatistics:
stop_time: int
host_stop_time: int
called_backward: bool
has_extended_stats: bool = False
last_trace_host_start: int = -1
last_trace_host_stop: int = -1
last_trace_cache_start: int = -1
last_trace_cache_stop: int = -1
last_trace_tracing_start: int = -1
last_trace_tracing_stop: int = -1
last_trace_host_execution_start: int = -1
last_trace_host_execution_stop: int = -1


# A timing helper
Expand Down Expand Up @@ -294,18 +285,6 @@ def _benchmark(
called_backward=called_backward,
)

# TODO Ensure the compile data statistics are always populated
if cs is not None and cs.last_trace_host_start > 0:
stat.has_extended_stats = True
stat.last_trace_host_start = cs.last_trace_host_start
stat.last_trace_host_stop = cs.last_trace_host_stop
stat.last_trace_cache_start = cs.last_trace_cache_start
stat.last_trace_cache_stop = cs.last_trace_cache_stop
stat.last_trace_tracing_start = cs.last_trace_tracing_start
stat.last_trace_tracing_stop = cs.last_trace_tracing_stop
stat.last_trace_host_execution_start = cs.last_trace_host_execution_start
stat.last_trace_host_execution_stop = cs.last_trace_host_execution_stop

stats.append(stat)

return stats
Expand Down Expand Up @@ -417,51 +396,9 @@ def _prettyprint_stats(
for rank, (memory_allocated, memory_reserved) in rank_mem_info.items():
short_printout += f"\n rank-{rank} - peak allocated memory {memory_allocated/1024/1024:.2f}MB, peak reserved: {memory_reserved/1024/1024:.2f}MB"
short_printout += "\n"
if median_benchmark_stat.has_extended_stats:
# NOTE At this point in the program extended statistics are available
trace_time_ns = median_benchmark_stat.last_trace_host_stop - median_benchmark_stat.last_trace_host_start
cache_time_ns = median_benchmark_stat.last_trace_cache_stop - median_benchmark_stat.last_trace_cache_start
tracing_time_ns = median_benchmark_stat.last_trace_tracing_stop - median_benchmark_stat.last_trace_tracing_start
trace_execution_time_ns = (
median_benchmark_stat.last_trace_host_execution_stop - median_benchmark_stat.last_trace_host_execution_start
)

trace_time_us: str = ns_to_us(trace_time_ns)
cache_time_us: str = ns_to_us(cache_time_ns)
tracing_time_us: str = ns_to_us(tracing_time_ns)
trace_execution_time_us: str = ns_to_us(trace_execution_time_ns)

trace_time_percentage: str = f"{round(trace_time_ns / median_benchmark_stat.total_time * 100)}%"
cache_time_percentage: str = f"{round(cache_time_ns / median_benchmark_stat.total_time * 100)}%"
tracing_time_percentage: str = f"{round(tracing_time_ns / median_benchmark_stat.total_time * 100)}%"
trace_execution_time_percentage: str = (
f"{round(trace_execution_time_ns / median_benchmark_stat.total_time * 100)}%"
)

before_trace_time_ns = median_benchmark_stat.last_trace_host_start - median_benchmark_stat.start_time
accelerator_wait_time_ns = median_benchmark_stat.stop_time - median_benchmark_stat.last_trace_host_stop

before_trace_time_us: str = ns_to_us(before_trace_time_ns)
accelerator_wait_time_us: str = ns_to_us(accelerator_wait_time_ns)

before_trace_time_percentage: str = f"{round(before_trace_time_ns / median_benchmark_stat.total_time * 100)}%"
accelerator_wait_time_percentage: str = (
f"{round(accelerator_wait_time_ns / median_benchmark_stat.total_time * 100)}%"
)

extension = f"""\
The median benchmark took {before_trace_time_us} to get into the tracing logic, {before_trace_time_percentage} of the total time.
The median benchmark took {accelerator_wait_time_us} waiting for the accelerator's computation to finish, {accelerator_wait_time_percentage} of the total time.
The median benchmark run's total time in tracing logic is {trace_time_us}, {trace_time_percentage} of the total time.
The median benchmark run's cache lookup time is {cache_time_us}, {cache_time_percentage} of the total time.
The median benchmark run's time spent tracing is {tracing_time_us}, {tracing_time_percentage} of the total time.
The median benchmark run's time to request the traced program be executed is {trace_execution_time_us}, {trace_execution_time_percentage} of the total time.
"""
else:
extension = ""

output = textwrap.dedent(preamble) + textwrap.indent(textwrap.dedent(extension), " " * 4)
print(output)
print(textwrap.dedent(preamble))


def print_rank_0(message):
Expand Down
Loading

0 comments on commit 2a01e25

Please sign in to comment.