From 7869f83c29f3467a16275d697b3ad348cccc50ff Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 25 Nov 2024 14:41:01 +0100 Subject: [PATCH 1/7] use transform for execution to get torch_compile executable --- thunder/core/symbol.py | 4 +- thunder/core/trace_interpreter.py | 158 +++++++++++++++++++ thunder/executors/passes.py | 115 +++++++------- thunder/executors/torch_compile.py | 64 ++++---- thunder/tests/test_networks.py | 2 +- thunder/tests/test_torch_compile_executor.py | 11 ++ thunder/torch/__init__.py | 1 + 7 files changed, 261 insertions(+), 94 deletions(-) diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index da6eca6ddd..50bf38c001 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -231,7 +231,9 @@ def __reduce__(self): # for pickling raise ValueError("Cannot serialize a symbol without a module and executor.") if self.executor is None: - assert getattr(sys.modules[self.module.__name__], self.name, None) is self + assert ( + getattr(sys.modules[self.module.__name__], self.name, None) is self + ), f"{self.module.__name__}.{self.name} is not {self}" else: assert thunder.get_executor(self.executor.name).opmap.get(self.name) is self diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index ed0e110191..37694ade6b 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -203,3 +203,161 @@ def do_swap(v): return new_trace, tree_map(read, trace.output), env return new_trace, tree_map(read, trace.output) + + +class TraceSubstitutionProcessor: + NULL = object() + + def __init__(self, trace, *args, **kwargs): + self.env = {} + self.trace = trace + self.new_trace = from_trace(self.trace) + self.have_processed_args = False + print(self.trace) + + def read(self, x: VariableInterface | Any) -> Any: + if isinstance(x, VariableInterface): + return self.env[x.name] + else: + return x + + def write(self, v: VariableInterface | Any, val: Any) -> None: + if not isinstance(v, VariableInterface): + return + # Duplicates are allowed and overwritten + if v.name in self.env: + raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") + self.env[v.name] = val + + def add_to_swap_map(self, old, new): + if old is new: + return + if isinstance(old, ProxyInterface): + if isinstance(new, ProxyInterface) and variableify(new) in self.env: + # the new isn't new, but something returned the input + # this means we need to map the old to the new + old, new = new, old + elif isinstance(old, TensorProxyInterface): + # should we have a fix shapes pass? 
the sharding + # (FSDP, tensor parallel) transforms do "break" shape metadata + self.new_trace.names.remove(old.name) # taken by the .replace proxy + if isinstance(new, VJPDual): + old = old.replace(shape=new.primal._shape) + else: + old = old.replace(shape=new._shape) + + if isinstance(new, VJPDual): + self.swap_map[variableify(new.primal)] = old + new.primal = old + else: + assert isinstance(new, ProxyInterface), (old, new) + self.swap_map[variableify(new)] = old + + def do_swap(self, v): + if isinstance(v, VJPDual): + v.primal = tree_map(self.do_swap, v.primal) + v.residuals = tree_map(self.do_swap, v.residuals) + return v + if not isinstance(v, ProxyInterface): + return v + return self.swap_map.get(variableify(v), v) + + def add_unprocessed_bsyms(self, bsyms): + self.unprocessed_bsyms[:0] = bsyms + + def bsyms_from_function(self, fn, /, *args, **kwargs): + self.new_trace.push_scope([]) + result = fn(*args, **kwargs) + self.new_bsyms += self.new_trace.pop_scope() + self.set_result(result) + return result + + def add_processed_bsyms(self, bsyms): + + ### replacements of inputs! + self.new_bsyms += bsyms + + def set_result(self, result): + self.replacement_result = result + + def process_bsym(self, bsym): + raise NotImplementedError("This needs to be implemented in subclasses") + + def process_args(self, *args, **kwargs): + self.have_processed_args = True + with tracectx(self.new_trace): + self.swap_map = {} + + safe_map_flat(self.add_to_swap_map, list(self.trace.args), list(args)) + safe_map_flat(self.add_to_swap_map, list(self.trace.kwargs.values()), list(kwargs.values())) + args, kwargs = tree_map(self.do_swap, (args, kwargs)) + + safe_map_flat(self.write, list(self.trace.args), list(args)) + safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values())) + + def __call__(self): + # if not self.have_processed_args and self.trace.args is not None: + # self.process_args(*self.args, **self.kwargs) + with tracectx(self.new_trace): + self.unprocessed_bsyms = self.trace.bound_symbols[:] + + while self.unprocessed_bsyms: + bsym = self.unprocessed_bsyms.pop(0) + + if self.have_processed_args and bsym.sym.id in trace_interpreter_skip_list: + self.new_trace.bound_symbols.append(bsym.from_bsym()) + continue + + args = tree_map(self.read, bsym.args) + kwargs = tree_map(self.read, bsym.kwargs) + + # this should be prettier + self.replacement_result = self.NULL + self.new_bsyms = [] + + self.process_bsym(bsym) + + if self.new_bsyms: + assert self.replacement_result is not self.NULL, "Need to call set_result if producing new bsyms" + + if self.replacement_result is not self.NULL: + self.swap_map = {} + + # TODO: if inputs are returned, the old outputs should be mapped on the new ones (= the inputs) instead of the other way round + if not self.new_bsyms: + # empty result means we want to swap references to the old + # result to the new result (which will be one of the args) + safe_map_flat( + self.add_to_swap_map, + list(sequencify(self.replacement_result)), + list(sequencify(bsym.output)), + ) + else: + safe_map_flat( + self.add_to_swap_map, + list(sequencify(bsym.output)), + list(sequencify(self.replacement_result)), + ) + + ### replace bsyms + + for new_bsym in self.new_bsyms: + # TODO: what to do with bsym header? Maybe have a combined from_bsym_swap_proxies and from_bsym? 
+ self.new_trace.bound_symbols.append( + new_bsym.from_bsym_swap_proxies(self.swap_map).from_bsym( + source_filename=bsym.source_filename, source_positions=bsym.source_positions + ) + ) + + result = tree_map(self.do_swap, self.replacement_result) + + try: + safe_map_flat(self.write, list(sequencify(bsym.output)), list(sequencify(result))) + except AssertionError as e: + raise RuntimeError( + f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {bsym}." + " This is likely due to a mismatch in the number of outputs." + f" The original symbol has {len(bsym.output)} outputs and the dispatched function has {len(sequencify(result))} outputs." + ) from e + + return self.new_trace, tree_map(self.read, self.trace.output) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 77c6a3d8f4..9f162ac86f 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -18,6 +18,7 @@ import thunder.core.transforms as transforms from thunder.core.transform_common import dce from thunder.core.trace import get_tracectx +from thunder.core.trace_interpreter import interpret_trace, interpret_trace_to_trace, TraceSubstitutionProcessor from thunder.executors.pythonex import clear_mutable_collection from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor @@ -61,68 +62,63 @@ def preserve_bsym(bsym: BoundSymbol) -> Any: # If the BoundSymbol already has an executor then None is returned # If the executor has an execution transform, it's called and True is returned # If no executor can execute the BoundSymbol, False is returned - def visit_helper_(bsym: BoundSymbol) -> None | bool: - if bsym.sym.python_impl is not None: - return None - - ex: Executor - for ex in executors_list: - # TODO Consider allowing operator executors to claim portions of operations - # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? 
- if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( - isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) - ): - execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) - out: Any - if execution_transform is not None: - out = execution_transform(*bsym.args, **bsym.kwargs) - elif isinstance(ex, OperatorExecutor): - # NOTE execution_transform is None and the executor is an operator executor - # Calls the operator executor's operation - # TODO Instead of directly acquiring the symbol from the implmap, we probably - # want to hide this behind a function - op = ex.implmap[bsym.sym.id].symbol - out = op(*bsym.args, **bsym.kwargs) - elif isinstance(ex, FusionExecutor): - # NOTE execution_transform is None and the executor is a fusion executor - # Preserves the symbol as is (it will be handled in the fusion pass) - out = preserve_bsym(bsym) - else: - raise AssertionError("Unknown executor") - - safe_map_flat(update_swapmap, bsym.output, out) - return True - - if bsym.sym.executor is not None: - return None - - return False - - def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: - result: None | bool = visit_helper_(bsym) - - if result is None: - return transforms.VISIT_TYPE.NO_OP - - if result is True: - return transforms.VISIT_TYPE.REPLACE - - # NOTE result is False (which means no executor was found for the symbol) - cutils.check(not bsym.sym.is_prim, lambda: f"Failed to find an executor for bound symbol {bsym=}") - for sbsym in bsym.subsymbols: - visit_(sbsym) - - return transforms.VISIT_TYPE.REPLACE - - extrace = transforms.visitor_transform(trace, visit_) + class OpExProcessor(TraceSubstitutionProcessor): + def process_bsym(self, bsym): + if bsym.sym.python_impl is not None: + self.add_processed_bsyms([bsym]) + self.set_result(bsym.output) + return + + ex: Executor + for ex in executors_list: + # TODO Consider allowing operator executors to claim portions of operations + # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? 
+ if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( + isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) + ): + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + out: Any + if execution_transform is not None: + self.bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs) + return + elif isinstance(ex, OperatorExecutor): + # NOTE execution_transform is None and the executor is an operator executor + # Calls the operator executor's operation + # TODO Instead of directly acquiring the symbol from the implmap, we probably + # want to hide this behind a function + op = ex.implmap[bsym.sym.id].symbol + self.bsyms_from_function(op, *bsym.args, **bsym.kwargs) + return + elif isinstance(ex, FusionExecutor): + # NOTE execution_transform is None and the executor is a fusion executor + # Preserves the symbol as is (it will be handled in the fusion pass) + self.add_processed_bsyms([bsym]) + self.set_result(bsym.output) + return + else: + raise AssertionError("Unknown executor") + + if bsym.sym.executor is not None: + self.add_processed_bsyms([bsym]) + self.set_result(bsym.output) + return + + # No executor found, need to descend + cutils.check(not bsym.sym.is_prim, lambda: f"Failed to find an executor for bound symbol {bsym=}") + ### OUTPUTS to map + self.add_unprocessed_bsyms(bsym.subsymbols[:]) + + extrace, _ = OpExProcessor(trace)() + + # extrace, _ = interpret_trace_to_trace(trace, *trace.args, symbol_mapper=symbol_mapper, **trace.kwargs) # Restores original variables - bound_symbols: list[BoundSymbol] = [] - for bsym in extrace.bound_symbols: - nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) - bound_symbols.append(nbsym) + # bound_symbols: list[BoundSymbol] = [] + # for bsym in extrace.bound_symbols: + # nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) + # bound_symbols.append(nbsym) - extrace.bound_symbols = bound_symbols + # extrace.bound_symbols = bound_symbols end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns @@ -151,7 +147,6 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) # extrace = _transform_for_operator_executor_execution(trace, executors_list) extrace = dce(extrace) - # # Step 2 Fusion executors can transform the trace # diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 3e7e7ed419..c688e134b6 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -10,13 +10,19 @@ from thunder.core.proxies import Proxy, TensorProxy, unvariableify, Variable from thunder.core.rematerialization import rematerialize from thunder.core.symbol import BoundSymbol, Symbol -from thunder.core.trace import from_trace, TraceCtx, TraceProvenance +from thunder.core.trace import from_trace, tracectx, TraceCtx, TraceProvenance from thunder.core.transform_common import dce from thunder.core.pytree import tree_flatten -from thunder.executors.passes import update_fusion_call_ctx +from thunder.executors.passes import ( + update_fusion_call_ctx, + _transform_for_operator_executor_execution, + transform_for_execution, +) from thunder.executors.utils import Region from thunder.extend import FusionExecutor, register_executor, ImplInfo from thunder.core.compile_data import get_compile_option +from thunder.executors.torchex import ex as pytorch_ex + _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True) @@ -31,10 +37,9 @@ def to_torch_translator(bsym: BoundSymbol) -> 
Callable: Returns: A callable that can be executed by PyTorch after being traced by Thunder. """ - from thunder.executors.torchex import ex as torchex def _to_torch(*args, **kwargs) -> Any: - impl_info = torchex.implmap.get(bsym.sym.id) + impl_info = pytorch_ex.implmap.get(bsym.sym.id) torch_op = None if impl_info is not None: torch_op = impl_info.symbol @@ -42,12 +47,12 @@ def _to_torch(*args, **kwargs) -> Any: return impl_info.execution_transform(*args, **kwargs) if torch_op is None: - torch_op = torchex.opmap.get(bsym.sym.name) + torch_op = pytorch_ex.opmap.get(bsym.sym.name) # this should be really rare, but type_as has this, # ideally we would be also handling more subsymbols here if torch_op is None and len(bsym.subsymbols) == 1: - torch_op = torchex.opmap.get(bsym.subsymbols[0].sym.name) + torch_op = pytorch_ex.opmap.get(bsym.subsymbols[0].sym.name) if torch_op is None: raise RuntimeError("op not found for {bsym.sym.name}") @@ -63,36 +68,34 @@ def make_compiled( from thunder import trace from thunder.core.transforms import eval_trace from thunder.executors.torchex import no_autocast + from thunder.executors.pythonex import ex as pythonex + from thunder.core.codeutils import SigInfo # Here we construct a trace that will be used to compile the function + # TODO: maybe we should have a utility that does this properly region_trace = TraceCtx(None) - region_trace.bound_symbols = list(bsyms) region_trace.args = sorted_unique_inputs region_trace.kwargs = {} + with tracectx(region_trace): + for a in sorted_unique_inputs: + prims.unpack_trivial(a, name=a.name) + + region_trace.bound_symbols += list(bsyms) region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None)) + # for a in region_trace.args: + # region_trace.add_name(a.name) + for bsym in region_trace.bound_symbols: + for o in bsym.flat_outs: + if o is not None: # TODO: investigate + region_trace.add_name(o.name) + + # maybe make this the default if no sig info is present? + region_trace._siginfo = SigInfo("to_be_compiled") + region_trace._siginfo.args = [(a.name, None) for a in region_trace.args] + + torchex_trace = transform_for_execution(region_trace, executors_list=(pytorch_ex,)) + trace_callable = torchex_trace.python_callable(include_decorators=False) - def torch_interpreted_func(*args): - return eval_trace(region_trace, *args, symbol_mapper=to_torch_translator) - - # Here instead of using thunder.trace we could use torch_trace = - # passes._transform_for_operator_executor_execution(region_trace, [torchex]) - # but then we would need to handle unpacking of the args explicitly For - # example with: - # try: - # token = set_tracectx(region_trace) - # col = CollectionProxy(region_trace.args, name="args") - # _ = prims.unpack_sequence(col, len(region_trace.args)) - # finally: - # reset_tracectx(token) - # region_trace.bound_symbols.extend(bsyms) - # But there are some issues with the - # _transform_for_operator_executor_execution implementation that need to be - # fixed first. One issue is that it doesn't maintain the ssa form of the - # trace, which is needed for all the passes to work correctly. - # TODO: issue "Try using _transform_for_operator_executor_execution for - # torch.compile executor" - torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs) - trace_callable = torch_trace.python_callable(include_decorators=False) torch_compile_fullgraph: None | bool = get_compile_option( "torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. 
Defaults to `True`." ) @@ -202,9 +205,6 @@ def _can_fuse_node(n: Node): return fusedtrace -from thunder.executors.torchex import ex as pytorch_ex - - def cuda_device_checker(*args, **kwargs): # We only want to compile if all the TensorProxy arguments are on the GPU flat_args, _ = tree_flatten((args, kwargs)) diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 223d898ec4..0cbef7299e 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -62,7 +62,7 @@ def test_nanogpt_complete(executor, device, dtype, recwarn): # TODO: Add float16 and bfloat16 comparison tests here and to all other tests in # this file. # See issue "Add half precision dtype tests to test_networks.py" -@instantiate(dtypes=(thunder.float32,), executors=all_test_executors_and_dynamo) +@instantiate(dtypes=(thunder.float32,)) # ), executors=all_test_executors_and_dynamo) def test_nanogpt_complete_autograd(executor, device, dtype): tdtype = ttorch.to_torch_dtype(dtype) diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index ea02dc8b5d..8b7c07b052 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -9,6 +9,7 @@ from thunder.executors.nvfuserex import nvfuserex from thunder.tests.bf16 import device_supports_bf16 from thunder.tests.framework import requiresCUDA +from torch.testing import assert_close def test_supported_ops_are_in_pytorch_executor(): @@ -79,3 +80,13 @@ def test_torch_compile_cat_rope_single_fusion(): backward_execution_trace = thunder.last_backward_traces(jfn)[-1] assert len(get_fusions(backward_execution_trace)) == 1 assert len(backward_execution_trace.bound_symbols) == 14 + + +@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") +def test_transform_for_execution_for_callable(): + def fn(a): + return a.type("torch.DoubleTensor") + + a = torch.randn(3) + jfn = thunder.jit(fn, executors=(thunder.executors.torch_compile.torch_compile_ex,)) + assert_close(jfn(a), fn(a)) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 6327b4d05c..e741fd3579 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -5754,6 +5754,7 @@ def register_default_torch_op(torchfn: Callable, torch_module): name=torchfn_name, meta=_fn, id=f"{torch_module.__name__}.{torchfn_name}", + is_prim=True, tags=(prims.OpTags.AUTO_REGISTERED,), ) From 24d6a8d149d92fcb0d7280e2e61c06b9c60454b0 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 3 Dec 2024 10:46:36 +0100 Subject: [PATCH 2/7] deduplicate saved_for_backward --- thunder/core/trace_interpreter.py | 1 - thunder/core/vjp_utils.py | 21 ++++++++++++++++++++- thunder/executors/torch_autograd.py | 19 ++++++++++++++++++- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index 37694ade6b..b661eca3db 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -213,7 +213,6 @@ def __init__(self, trace, *args, **kwargs): self.trace = trace self.new_trace = from_trace(self.trace) self.have_processed_args = False - print(self.trace) def read(self, x: VariableInterface | Any) -> Any: if isinstance(x, VariableInterface): diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 0921bf7c6e..b6f3712e5e 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -1,5 +1,5 @@ import inspect -from collections.abc import Callable +from 
collections.abc import Callable, Sequence
 from functools import wraps
 from inspect import Parameter, Signature
 from itertools import chain
@@ -229,3 +229,22 @@ def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]:
         lambda: "All saved tensors must be TensorProxy or None",
     )
     return tuple(saved_tensors)
+
+
+def set_saved_for_backward_tensors(trace: TraceCtx, saved_tensors: Sequence[TensorProxy]):
+    """
+    Given a trace, set the tensors that are saved for backward in the trace.
+
+    Args:
+        trace: The trace to set saved tensors for.
+        saved_tensors: proxies for the tensors to save.
+    """
+    utils.check(
+        all(isinstance(t, TensorProxy) or t is None for t in saved_tensors),
+        lambda: "All saved tensors must be TensorProxy or None",
+    )
+    ret_node = trace.bound_symbols.pop(-1)
+    assert ret_node.sym == prims.python_return
+    output = ret_node.args
+    output = (output[0], (tuple(saved_tensors), *output[1][1:]), *output[2:])
+    trace.bound_symbols.append(ret_node.from_bsym(args=output))
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index 5374b23afe..ce9497125b 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -10,7 +10,7 @@
 from thunder.core.symbol import BoundSymbol
 from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx
 from thunder.core.transform_common import replace_redundant_inputs
-from thunder.core.vjp_utils import get_saved_for_backward_tensors
+from thunder.core.vjp_utils import get_saved_for_backward_tensors, set_saved_for_backward_tensors
 
 if TYPE_CHECKING:
     from thunder.core.trace import VariableInterface
@@ -240,6 +240,23 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
         skip_output=False,
         skip_subsymbols=False,
     )
+
+    # remove duplicates
+    # The NVFuser fusion pass (and possibly others) applied on the forward trace has a
+    # CSE pass that may lead to duplicate symbols saved for backward. This causes trouble
+    # because we see duplicates in the unpacking. But the passes are unaware of the backward,
+    # so they cannot handle it themselves, so we clean this up here.
+ seen = set() + new_fw_out = [] + new_bw_inp = [] + for p_fw, p_bw in zip(get_saved_for_backward_tensors(fw_extrace), new_bsyms[4].output, strict=True): + if p_fw.name not in seen: + seen.add(p_fw.name) + new_fw_out.append(p_fw) + new_bw_inp.append(p_bw) + new_bsyms[4] = new_bsyms[4].from_bsym(output=tuple(new_bw_inp)) + set_saved_for_backward_tensors(fw_extrace, new_fw_out) + bw_trace.bound_symbols = new_bsyms if getattr(compile_data.fn, "use_fsdp", False): From e64fab1b84cbcb157ec214d87b3a16b6484f2c5b Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 3 Dec 2024 11:59:02 +0100 Subject: [PATCH 3/7] dce duplicate creation of number proxies --- thunder/core/transform_common.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bfe2cb376f..dc8d053075 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -9,7 +9,7 @@ import thunder import thunder.core.prims as prims -from thunder.core.baseutils import BoundSymbolInterface +from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify from thunder.core.pytree import tree_flatten, tree_iter, tree_map, tree_unflatten from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags @@ -111,6 +111,24 @@ def check(inp, log_str): check(copy_to_out, "output") +def remove_duplicate_number_proxies(bsyms): + seen = set() + + def keep_or_swap(p): + if not isinstance(p, NumberProxyInterface): + return p + if p.name in seen: + return p.value # don't make it a duplicate + seen.add(p.name) + return p + + new_bsyms = [] + for bsym in bsyms: + output = tree_map(keep_or_swap, bsym.output) + new_bsyms.append(bsym.from_bsym(output=output)) + return new_bsyms + + # TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly # improve performance # Runs a Dead Code Elimination (DCE) pass @@ -174,7 +192,9 @@ def _helper(x): needed_proxies.add(variableify(x)) dcetrace = from_trace(trace) - dcetrace.bound_symbols = list(reversed(dced)) + dced_bound_symbols = list(reversed(dced)) + dced_bound_symbols = dced_bound_symbols + dcetrace.bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns From ac6bf589256ed2426b04166ff0b4c69e77e6296c Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 3 Dec 2024 12:45:11 +0100 Subject: [PATCH 4/7] allow duplicates to accomodate pre-dce symbol lists --- thunder/core/trace_interpreter.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index b661eca3db..48a0084a7c 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Any from thunder.core import prims @@ -45,9 +46,9 @@ def read(x: VariableInterface | Any) -> Any: def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None: if not isinstance(v, VariableInterface): return - # Duplicates are allowed and overwritten if v.name in env: if allow_duplicates: + # Duplicates are allowed and not overwritten return raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") env[v.name] = val @@ -104,9 +105,9 @@ def read(x: VariableInterface | Any) 
-> Any: def write(v: VariableInterface | Any, val: Any, allow_duplicates=False) -> None: if not isinstance(v, VariableInterface): return - # Duplicates are allowed and overwritten if v.name in env: if allow_duplicates: + # Duplicates are allowed and not overwritten return raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") env[v.name] = val @@ -220,11 +221,13 @@ def read(self, x: VariableInterface | Any) -> Any: else: return x - def write(self, v: VariableInterface | Any, val: Any) -> None: + def write(self, v: VariableInterface | Any, val: Any, allow_duplicates=True) -> None: if not isinstance(v, VariableInterface): return - # Duplicates are allowed and overwritten if v.name in self.env: + if allow_duplicates: + # Duplicates are allowed and not overwritten + return raise ValueError(f"Variable {v.name} is being overwritten this is not allowed") self.env[v.name] = val @@ -350,8 +353,13 @@ def __call__(self): result = tree_map(self.do_swap, self.replacement_result) + # we need to allow duplicates here because the re-interpretation is not necessairly DCEed when subsymbols symbols are flattened into the trace after re-execution. try: - safe_map_flat(self.write, list(sequencify(bsym.output)), list(sequencify(result))) + safe_map_flat( + partial(self.write, allow_duplicates=True), + list(sequencify(bsym.output)), + list(sequencify(result)), + ) except AssertionError as e: raise RuntimeError( f"Error while assigning the result of dispatched function {prim_func} to the output of the original symbol {bsym}." From 705bff9d0e8418190db8022f7f961cf25a649f7d Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 3 Dec 2024 12:51:48 +0100 Subject: [PATCH 5/7] disable inductor-using test on windows --- thunder/tests/test_torch_compile_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index 8b7c07b052..6560dddbc8 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -82,7 +82,7 @@ def test_torch_compile_cat_rope_single_fusion(): assert len(backward_execution_trace.bound_symbols) == 14 -@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") +@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported") def test_transform_for_execution_for_callable(): def fn(a): return a.type("torch.DoubleTensor") From 26542296a52da87cc3eadc347af289dda1053cb0 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 3 Dec 2024 12:58:32 +0100 Subject: [PATCH 6/7] cleanup --- thunder/executors/passes.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 9f162ac86f..6060e0a50a 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -111,15 +111,6 @@ def process_bsym(self, bsym): extrace, _ = OpExProcessor(trace)() - # extrace, _ = interpret_trace_to_trace(trace, *trace.args, symbol_mapper=symbol_mapper, **trace.kwargs) - # Restores original variables - # bound_symbols: list[BoundSymbol] = [] - # for bsym in extrace.bound_symbols: - # nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) - # bound_symbols.append(nbsym) - - # extrace.bound_symbols = bound_symbols - end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 From 4f798aabf421391abf17d58b2ffa82bfd0ec5658 Mon Sep 17 
00:00:00 2001
From: Thomas Viehmann
Date: Tue, 3 Dec 2024 20:19:28 +0100
Subject: [PATCH 7/7] address review comments, thank you Mike!

---
 thunder/core/trace_interpreter.py  | 22 +++++++++++++++-----
 thunder/core/transform_common.py   | 13 +++++++++---
 thunder/executors/passes.py        | 33 +++++++++++++-----------------
 thunder/executors/torch_compile.py |  4 +---
 4 files changed, 42 insertions(+), 30 deletions(-)

diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index 48a0084a7c..edff5937a9 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -207,6 +207,22 @@ def do_swap(v):
 
 
 class TraceSubstitutionProcessor:
+    """This processes a trace in an interpretation-style way by looping over the bound symbols.
+    This processing aims to preserve as much information on the proxies as possible.
+
+    Args:
+        trace: trace to process
+        *args: arguments to process the trace with
+        **kwargs: keyword arguments to process the trace with
+
+    The user is expected to subclass this class and implement process_bsym with the help of add_unprocessed_bsyms (useful e.g. for using subsymbols to compute a symbol), add_processed_bsyms, and add_bsyms_from_function.
+
+    Calling the instantiated object initiates the processing and returns
+    the new trace and a mapping of the outputs.
+
+    See the OpExProcessor in thunder.executors.passes._transform_for_operator_executor_execution for an example of subclassing.
+    """
+
     NULL = object()
 
     def __init__(self, trace, *args, **kwargs):
@@ -267,7 +283,7 @@ def do_swap(self, v):
     def add_unprocessed_bsyms(self, bsyms):
         self.unprocessed_bsyms[:0] = bsyms
 
-    def bsyms_from_function(self, fn, /, *args, **kwargs):
+    def add_bsyms_from_function(self, fn, /, *args, **kwargs):
         self.new_trace.push_scope([])
         result = fn(*args, **kwargs)
         self.new_bsyms += self.new_trace.pop_scope()
@@ -275,8 +291,6 @@ def bsyms_from_function(self, fn, /, *args, **kwargs):
         return result
 
     def add_processed_bsyms(self, bsyms):
-
-        ### replacements of inputs!
         self.new_bsyms += bsyms
 
     def set_result(self, result):
@@ -298,8 +312,6 @@ def process_args(self, *args, **kwargs):
         safe_map_flat(self.write, list(self.trace.kwargs.values()), list(kwargs.values()))
 
     def __call__(self):
-        # if not self.have_processed_args and self.trace.args is not None:
-        #     self.process_args(*self.args, **self.kwargs)
         with tracectx(self.new_trace):
             self.unprocessed_bsyms = self.trace.bound_symbols[:]
 
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index dc8d053075..bfe4123dc6 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -111,7 +111,12 @@ def check(inp, log_str):
         check(copy_to_out, "output")
 
 
-def remove_duplicate_number_proxies(bsyms):
+def remove_duplicate_number_proxies(bsyms: Sequence[BoundSymbol]) -> list[BoundSymbol]:
+    """This removes duplicate number proxies when they are returned multiple times.
+    The remaining DCE pass does not see them (because they often are in a tuple?).
+    In particular, proxies may be extracted multiple times when using thunder.jit's
+    symbolic constraints mode.
+    """
     seen = set()
 
     def keep_or_swap(p):
@@ -193,8 +198,10 @@ def _helper(x):
 
     dcetrace = from_trace(trace)
     dced_bound_symbols = list(reversed(dced))
-    dced_bound_symbols = dced_bound_symbols
-    dcetrace.bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols)
+    # duplicate number proxies happen with the symbolic shapes and are
+    # not covered by the above (due to being in tuples?).
+ dced_bound_symbols = remove_duplicate_number_proxies(dced_bound_symbols) + dcetrace.bound_symbols = dced_bound_symbols end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 6060e0a50a..f7bbe20ca5 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -28,10 +28,6 @@ # Transforms a trace by determining which execution transforms to call given the list of executors in priority order # This pass tries to preserve the original trace and proxies. -# Implementation Steps - -# 1. The trace is updated with `visitor_transform` with `visit_helper_` (where executors try to claim the symbols). Note that this replaces the output proxies in the trace. -# 2. `visit_helper_` also creates a swapmap from the new symbols back to old one. -# 3. After the `visitor_transform`, it iterates over the updated trace and puts back the old proxies. def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: start_time_ns = time.perf_counter_ns() @@ -50,22 +46,21 @@ def update_swapmap(o: Any, no: Any) -> None: return swapmap[vno] = o - def preserve_bsym(bsym: BoundSymbol) -> Any: - trace = get_tracectx() - trace.scopes[-1].append(bsym) - for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): - trace.names.add(p.name) - return bsym.output - - # TODO Consider using an enum for this function's return values - # Tries to find an executor for the BoundSymbol - # If the BoundSymbol already has an executor then None is returned - # If the executor has an execution transform, it's called and True is returned - # If no executor can execute the BoundSymbol, False is returned - + # This processes the bsyms to map symbols to operator executors: + # - if a bsym has a python impl, that will be called, so we can keep it. 
+ # - in the order of the executor list + # - if the executor defines an execution transform, call that to + # create symbols for the trace, + # - for operator executors, if we have an implmap entry for the symbol, + # execute that + # - for fusion executors, check if the symbol can be fused (done later) + # - if none of these apply, and the symbol is not a prim, replace the symbol + # with its subsymbols (which will then be processed using the above), + # - if none of the above apply and we have a prim, raise an error class OpExProcessor(TraceSubstitutionProcessor): def process_bsym(self, bsym): if bsym.sym.python_impl is not None: + # keep the bound symbol and use the python impl self.add_processed_bsyms([bsym]) self.set_result(bsym.output) return @@ -80,7 +75,7 @@ def process_bsym(self, bsym): execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) out: Any if execution_transform is not None: - self.bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs) + self.add_bsyms_from_function(execution_transform, *bsym.args, **bsym.kwargs) return elif isinstance(ex, OperatorExecutor): # NOTE execution_transform is None and the executor is an operator executor @@ -88,7 +83,7 @@ def process_bsym(self, bsym): # TODO Instead of directly acquiring the symbol from the implmap, we probably # want to hide this behind a function op = ex.implmap[bsym.sym.id].symbol - self.bsyms_from_function(op, *bsym.args, **bsym.kwargs) + self.add_bsyms_from_function(op, *bsym.args, **bsym.kwargs) return elif isinstance(ex, FusionExecutor): # NOTE execution_transform is None and the executor is a fusion executor diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index c688e134b6..345597f669 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -82,11 +82,9 @@ def make_compiled( region_trace.bound_symbols += list(bsyms) region_trace.bound_symbols.append(prims.python_return.bind(sorted_unique_outputs, output=None)) - # for a in region_trace.args: - # region_trace.add_name(a.name) for bsym in region_trace.bound_symbols: for o in bsym.flat_outs: - if o is not None: # TODO: investigate + if o is not None: region_trace.add_name(o.name) # maybe make this the default if no sig info is present?