diff --git a/thunder/__init__.py b/thunder/__init__.py index 06ba26fcc2..4eb6d9b36a 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -473,6 +473,7 @@ def get_computation_and_inputs(*args, **kwargs): ) = cache_entry try: from thunder.core.profile import annotate_for_profile + inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs)) except Exception as _: continue @@ -594,11 +595,11 @@ def get_computation_and_inputs(*args, **kwargs): pro = thunder.core.profile.annotate_for_profile("prologue python_callable")( prologue_trc.python_callable(include_decorators=False) ) - #pro = thunder.core.profile.annotate_for_profile("prologue")(pro) + # pro = thunder.core.profile.annotate_for_profile("prologue")(pro) epilogue_trc = transform_to_torch_types(epilogue_trc) epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")( - epilogue_trc.python_callable() + epilogue_trc.python_callable() ) cs.last_prologue_traces = prologue_traces @@ -703,6 +704,7 @@ def host_execution_timer(fn): def wrapped(*args, **kwargs): with thunder.core.profile.annotate_for_profile("computation"): return fn(*args, **kwargs) + return wrapped def decorate_computation_function(get_computation_and_inputs_fn, *decorators): @@ -725,8 +727,8 @@ def update_call_statistics(fn): def wrapped(*args, **kwargs): cs.calls += 1 return fn(*args, **kwargs) - return wrapped + return wrapped @thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd") def maybe_connect_to_autograd(cache_entry, result): diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 8a184306b6..fff3707574 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -693,8 +693,6 @@ def visit_(bsym: BoundSymbolInterface) -> VISIT_TYPE: bound_symbols.append(nbsym) new_trace.bound_symbols = bound_symbols - new_trace.set_provenance( - TraceProvenance("Transform for replace uniform") - ) + new_trace.set_provenance(TraceProvenance("Transform for replace uniform")) return new_trace diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index d25503b85d..44ad485311 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -370,9 +370,7 @@ def thunder_140410131706304(x, y): new_bsyms = replace_redundant_inputs(redundant_map, cse_trace_bound_symbols) cse_trace.bound_symbols = new_bsyms - cse_trace.set_provenance( - TraceProvenance("Common Subexpression Elimination") - ) + cse_trace.set_provenance(TraceProvenance("Common Subexpression Elimination")) return cse_trace diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 4202f2f264..cf37d7f170 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3234,11 +3234,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]] - new_bwd_trace.set_provenance( - TraceProvenance("Saved for backward remat trace") - ) - new_fwd_trace.set_provenance( - TraceProvenance("Saved for backward remat trace") - ) + new_bwd_trace.set_provenance(TraceProvenance("Saved for backward remat trace")) + new_fwd_trace.set_provenance(TraceProvenance("Saved for backward remat trace")) return new_fwd_trace, new_bwd_trace diff --git a/thunder/dev_utils/nvtx_profile_transform.py b/thunder/dev_utils/nvtx_profile_transform.py index a887911ceb..ba6f2ce78b 100644 --- a/thunder/dev_utils/nvtx_profile_transform.py +++ b/thunder/dev_utils/nvtx_profile_transform.py @@ -51,13 +51,9 @@ def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace: continue # Add nvtx range for the symbol. - profile_trace.bound_symbols.append( - nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None) - ) + profile_trace.bound_symbols.append(nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None)) profile_trace.bound_symbols.append(bound_symbol) profile_trace.bound_symbols.append(nvtx_pop.bind(output=None)) - profile_trace.set_provenance( - TraceProvenance("NVTX Profile Transform") - ) + profile_trace.set_provenance(TraceProvenance("NVTX Profile Transform")) return profile_trace diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 31c62d258a..ca9e549fc4 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -825,9 +825,7 @@ def map_redundant(x: Any) -> Any: trace_output = tree_map(map_redundant, return_bsym.args) cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None) - cse_trace.set_provenance( - TraceProvenance("Nvfuser Common Subexpression Elimination") - ) + cse_trace.set_provenance(TraceProvenance("Nvfuser Common Subexpression Elimination")) return cse_trace # TODO Restore fusion logic here -- this just replaces supported operations in isolation at the moment diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 6839cc2580..1516427d8c 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -106,9 +106,7 @@ def process_bsym(self, bsym): extrace, _ = OpExProcessor(trace)() - extrace.set_provenance( - TraceProvenance("Transform for operator executor execution") - ) + extrace.set_provenance(TraceProvenance("Transform for operator executor execution")) return extrace diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 7f9d96f08e..34a2c4ede4 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -270,7 +270,6 @@ def grad_second(a): assert_close(actual_second, expected) - @instantiate(dtypes=(thunder.float32,)) def test_optimizer_unpack(executor, device, dtype): class Optimizer(torch.optim.Optimizer):