Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 18, 2024
1 parent 14418ed commit 9ffb253
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 28 deletions.
8 changes: 5 additions & 3 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ def get_computation_and_inputs(*args, **kwargs):
) = cache_entry
try:
from thunder.core.profile import annotate_for_profile

inps, pro_to_epi = annotate_for_profile("prologue")(pro(*args, **kwargs))
except Exception as _:
continue
Expand Down Expand Up @@ -594,11 +595,11 @@ def get_computation_and_inputs(*args, **kwargs):
pro = thunder.core.profile.annotate_for_profile("prologue python_callable")(
prologue_trc.python_callable(include_decorators=False)
)
#pro = thunder.core.profile.annotate_for_profile("prologue")(pro)
# pro = thunder.core.profile.annotate_for_profile("prologue")(pro)

epilogue_trc = transform_to_torch_types(epilogue_trc)
epilogue = thunder.core.profile.annotate_for_profile("epilogue python_callable")(
epilogue_trc.python_callable()
epilogue_trc.python_callable()
)

cs.last_prologue_traces = prologue_traces
Expand Down Expand Up @@ -703,6 +704,7 @@ def host_execution_timer(fn):
def wrapped(*args, **kwargs):
with thunder.core.profile.annotate_for_profile("computation"):
return fn(*args, **kwargs)

return wrapped

def decorate_computation_function(get_computation_and_inputs_fn, *decorators):
Expand All @@ -725,8 +727,8 @@ def update_call_statistics(fn):
def wrapped(*args, **kwargs):
cs.calls += 1
return fn(*args, **kwargs)
return wrapped

return wrapped

@thunder.core.profile.annotate_for_profile("maybe_connect_to_autograd")
def maybe_connect_to_autograd(cache_entry, result):
Expand Down
4 changes: 1 addition & 3 deletions thunder/core/rematerialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,6 @@ def visit_(bsym: BoundSymbolInterface) -> VISIT_TYPE:
bound_symbols.append(nbsym)

new_trace.bound_symbols = bound_symbols
new_trace.set_provenance(
TraceProvenance("Transform for replace uniform")
)
new_trace.set_provenance(TraceProvenance("Transform for replace uniform"))

return new_trace
4 changes: 1 addition & 3 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,7 @@ def thunder_140410131706304(x, y):
new_bsyms = replace_redundant_inputs(redundant_map, cse_trace_bound_symbols)
cse_trace.bound_symbols = new_bsyms

cse_trace.set_provenance(
TraceProvenance("Common Subexpression Elimination")
)
cse_trace.set_provenance(TraceProvenance("Common Subexpression Elimination"))
return cse_trace


Expand Down
8 changes: 2 additions & 6 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,11 +3234,7 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr

new_bwd_trace.args = [(new_saved_for_backward, fwd_trace.output[1][1]), *bwd_trace.args[1:]]

new_bwd_trace.set_provenance(
TraceProvenance("Saved for backward remat trace")
)
new_fwd_trace.set_provenance(
TraceProvenance("Saved for backward remat trace")
)
new_bwd_trace.set_provenance(TraceProvenance("Saved for backward remat trace"))
new_fwd_trace.set_provenance(TraceProvenance("Saved for backward remat trace"))

return new_fwd_trace, new_bwd_trace
8 changes: 2 additions & 6 deletions thunder/dev_utils/nvtx_profile_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,9 @@ def transform_trace_post_optimization(self, trace: Trace, **kwargs) -> Trace:
continue

# Add nvtx range for the symbol.
profile_trace.bound_symbols.append(
nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None)
)
profile_trace.bound_symbols.append(nvtx_push.bind(f"{''.join(bound_symbol.python(indent=0))}", output=None))
profile_trace.bound_symbols.append(bound_symbol)
profile_trace.bound_symbols.append(nvtx_pop.bind(output=None))

profile_trace.set_provenance(
TraceProvenance("NVTX Profile Transform")
)
profile_trace.set_provenance(TraceProvenance("NVTX Profile Transform"))
return profile_trace
4 changes: 1 addition & 3 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,9 +825,7 @@ def map_redundant(x: Any) -> Any:
trace_output = tree_map(map_redundant, return_bsym.args)
cse_trace.bound_symbols[-1] = prims.python_return.bind(*trace_output, output=None)

cse_trace.set_provenance(
TraceProvenance("Nvfuser Common Subexpression Elimination")
)
cse_trace.set_provenance(TraceProvenance("Nvfuser Common Subexpression Elimination"))
return cse_trace

# TODO Restore fusion logic here -- this just replaces supported operations in isolation at the moment
Expand Down
4 changes: 1 addition & 3 deletions thunder/executors/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ def process_bsym(self, bsym):

extrace, _ = OpExProcessor(trace)()

extrace.set_provenance(
TraceProvenance("Transform for operator executor execution")
)
extrace.set_provenance(TraceProvenance("Transform for operator executor execution"))

return extrace

Expand Down
1 change: 0 additions & 1 deletion thunder/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,6 @@ def grad_second(a):
assert_close(actual_second, expected)



@instantiate(dtypes=(thunder.float32,))
def test_optimizer_unpack(executor, device, dtype):
class Optimizer(torch.optim.Optimizer):
Expand Down

0 comments on commit 9ffb253

Please sign in to comment.