Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 18, 2024
1 parent 2a01e25 commit ab7ffe6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 5 additions & 3 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def get_computation_and_inputs(*args, **kwargs):
) = cache_entry
try:
from thunder.core.profile import annotate_for_profile

inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs))
except Exception as _:
continue
Expand Down Expand Up @@ -594,11 +595,11 @@ def get_computation_and_inputs(*args, **kwargs):
pro = thunder.core.profile.annotate_for_profile("prologue python_callable")(
prologue_trc.python_callable(include_decorators=False)
)
#pro = thunder.core.profile.annotate_for_profile("prologue")(pro)
# pro = thunder.core.profile.annotate_for_profile("prologue")(pro)

epilogue_trc = transform_to_torch_types(epilogue_trc)
epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")(
epilogue_trc.python_callable()
epilogue_trc.python_callable()
)

cs.last_prologue_traces = prologue_traces
Expand Down Expand Up @@ -703,6 +704,7 @@ def host_execution_timer(fn):
def wrapped(*args, **kwargs):
with thunder.core.profile.annotate_for_profile("computation"):
return fn(*args, **kwargs)

return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
Expand All @@ -725,8 +727,8 @@ def update_call_statistics(fn):
def wrapped(*args, **kwargs):
cs.calls += 1
return fn(*args, **kwargs)
return wrapped

return wrapped

@thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd")
def maybe_connect_to_autograd(cache_entry, result):
Expand Down
1 change: 0 additions & 1 deletion thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def grad_second(a):
assert_close(actual_second, expected)



@instantiate(dtypes=(thunder.float32,))
def test_optimizer_unpack(executor, device, dtype):
class Optimizer(torch.optim.Optimizer):
Expand Down

0 comments on commit ab7ffe6

Please sign in to comment.