From e6b151f836b7d0bdabd05c173066390dd5c5f160 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:56:35 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/__init__.py | 8 +++++--- thunder/tests/test_core.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 7aa3266191..9fe1b61554 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -473,6 +473,7 @@ def get_computation_and_inputs(*args, **kwargs): ) = cache_entry try: from thunder.core.profile import annotate_for_profile + inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs)) except Exception as _: continue @@ -594,11 +595,11 @@ def get_computation_and_inputs(*args, **kwargs): pro = thunder.core.profile.annotate_for_profile("prologue python_callable")( prologue_trc.python_callable(include_decorators=False) ) - #pro = thunder.core.profile.annotate_for_profile("prologue")(pro) + # pro = thunder.core.profile.annotate_for_profile("prologue")(pro) epilogue_trc = transform_to_torch_types(epilogue_trc) epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")( - epilogue_trc.python_callable() + epilogue_trc.python_callable() ) cs.last_prologue_traces = prologue_traces @@ -703,6 +704,7 @@ def host_execution_timer(fn): def wrapped(*args, **kwargs): with thunder.core.profile.annotate_for_profile("computation"): return fn(*args, **kwargs) + return wrapped def decorate_computation_function(get_computation_and_inputs_fn, *decorators): @@ -725,8 +727,8 @@ def update_call_statistics(fn): def wrapped(*args, **kwargs): cs.calls += 1 return fn(*args, **kwargs) - return wrapped + return wrapped @thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd") def maybe_connect_to_autograd(cache_entry, result): diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 7f9d96f08e..34a2c4ede4 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -270,7 +270,6 @@ def grad_second(a): assert_close(actual_second, expected) - @instantiate(dtypes=(thunder.float32,)) def test_optimizer_unpack(executor, device, dtype): class Optimizer(torch.optim.Optimizer):