Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali Alshaarawy committed Dec 12, 2024
1 parent 246e8a8 commit fe66455
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,21 @@ def make_compiled(
region_trace = TraceCtx(None)
region_trace.args = sorted_unique_inputs
region_trace.kwargs = {}
for a in sorted_unique_inputs:
region_trace.bound_symbols.append(prims.unpack_trivial.bind(a, name=a.name, output=a))
region_trace.names = set([a.name for a in region_trace.args])
with tracectx(region_trace):
for a in sorted_unique_inputs:
prims.unpack_trivial(a, name=a.name)

region_trace.bound_symbols += list(bsyms)
region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None))
for bsym in region_trace.bound_symbols:
if bsym.sym == prims.unpack_trivial:
continue
for o in bsym.flat_outs:
if o is not None:
region_trace.add_name(o.name)
for sbsym in bsym.subsymbols:
for o in sbsym.flat_outs:
if o is not None and o.name not in region_trace.names:
region_trace.add_name(o.name)
for arg in region_trace.args:
if arg.name not in region_trace.names:
region_trace.add_name(arg.name)
list(map(lambda o: region_trace.add_name(o.name), filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs)))

# maybe make this the default if no sig info is present?
region_trace._siginfo = SigInfo("to_be_compiled")
Expand Down

0 comments on commit fe66455

Please sign in to comment.