diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index c0a03b7866..bfae5aa085 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -72,22 +72,21 @@ def make_compiled( region_trace = TraceCtx(None) region_trace.args = sorted_unique_inputs region_trace.kwargs = {} - for a in sorted_unique_inputs: - region_trace.bound_symbols.append(prims.unpack_trivial.bind(a, name=a.name, output=a)) + region_trace.names = set([a.name for a in region_trace.args]) + with tracectx(region_trace): + for a in sorted_unique_inputs: + prims.unpack_trivial(a, name=a.name) region_trace.bound_symbols += list(bsyms) region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None)) for bsym in region_trace.bound_symbols: + if bsym.sym == prims.unpack_trivial: + continue for o in bsym.flat_outs: if o is not None: region_trace.add_name(o.name) for sbsym in bsym.subsymbols: - for o in sbsym.flat_outs: - if o is not None and o.name not in region_trace.names: - region_trace.add_name(o.name) - for arg in region_trace.args: - if arg.name not in region_trace.names: - region_trace.add_name(arg.name) + list(map(lambda o: region_trace.add_name(o.name), filter(lambda o: o is not None and o.name not in region_trace.names, sbsym.flat_outs))) # maybe make this the default if no sig info is present? region_trace._siginfo = SigInfo("to_be_compiled")