From 64ffc97eca5903d46b5635a7dcef648d48630f55 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 12 Dec 2024 20:57:57 +0100 Subject: [PATCH 1/8] ThunderFX: handles the callable input of fx.Node (#1539) --- thunder/dynamo/splitter.py | 1 + thunder/dynamo/utils.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 587729dde3..1474268657 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -135,6 +135,7 @@ def callback(node) -> int: return partition_cnt # `split_module` iterates over nodes and determines the partition to place them based on the callback. + gm.graph.eliminate_dead_code() original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 5f31d790a3..47a326b378 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -159,10 +159,12 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]: # We need to be under trace context to generate proxies. with thunder.core.trace.tracectx(TraceCtx()): - def make_tensor_proxy(arg_node): + def make_input_proxy(arg_node): # This is a Node in the graph representing a Tensor or tuple of Tensors or # a PyTorch object like one representing torch.autocast. if isinstance(arg_node, torch.fx.Node): + if arg_node.op == "get_attr": + return getattr(arg_node.graph.owning_module, arg_node.target) if "example_value" not in arg_node.meta: # This is a non tensor object like `torch.autocast` ctx manager object. return arg_node @@ -185,14 +187,14 @@ def make_tensor_proxy(arg_node): ) else: # NOTE - This will be caught will be caught and be part of the SplitReason. - raise TypeError(f"Received `make_tensor_proxy` received example_value which wasn't Tensor or Tuple") + raise TypeError(f"Received `make_input_proxy` received example_value which wasn't Tensor or Tuple") return proxy(example_value) # This is int, float, etc. return arg_node - proxy_args = torch.fx.map_arg(node.args, make_tensor_proxy) - proxy_kwargs = {k: torch.fx.map_arg(v, make_tensor_proxy) for k, v in node.kwargs.items()} + proxy_args = torch.fx.map_arg(node.args, make_input_proxy) + proxy_kwargs = {k: torch.fx.map_arg(v, make_input_proxy) for k, v in node.kwargs.items()} return proxy_args, proxy_kwargs From 0d47ae14132376c8db35c395a6466da5b8c8c3b4 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Fri, 13 Dec 2024 14:39:12 +0100 Subject: [PATCH 2/8] is_node_supported_by_thunder: check the submodule of autograd_function_apply --- thunder/dynamo/splitter.py | 3 ++- thunder/dynamo/utils.py | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index 1474268657..f16bad1c67 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -134,8 +134,9 @@ def callback(node) -> int: supported_partitions.add(partition_cnt) return partition_cnt - # `split_module` iterates over nodes and determines the partition to place them based on the callback. gm.graph.eliminate_dead_code() + gm.recompile() + # `split_module` iterates over nodes and determines the partition to place them based on the callback. 
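As a minimal, self-contained sketch of the callback-driven `split_module` pattern this hunk prepares for (an illustration, not part of the patch: the toy module, the choice of `torch.cos` as the "unsupported" op, and the trimmed-down argument list are all assumptions):

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.cos(torch.tanh(torch.sin(x)))

toy_gm = symbolic_trace(Toy())
toy_gm.graph.eliminate_dead_code()  # drop unused nodes before partitioning
toy_gm.recompile()                  # regenerate forward() from the mutated graph

def split_callback(node: torch.fx.Node) -> int:
    # Pretend torch.cos is unsupported: route it to partition 1, everything else to partition 0.
    return 1 if node.op == "call_function" and node.target is torch.cos else 0

toy_split = split_module(toy_gm, root_m=None, split_callback=split_callback, keep_original_order=True)
print(toy_split.graph)  # forward() now calls submod_0 / submod_1 in the original order
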
original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 47a326b378..496984a2ce 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -163,8 +163,11 @@ def make_input_proxy(arg_node): # This is a Node in the graph representing a Tensor or tuple of Tensors or # a PyTorch object like one representing torch.autocast. if isinstance(arg_node, torch.fx.Node): + # Higher-order operator nodes take get_attr nodes as input to get the called module if arg_node.op == "get_attr": - return getattr(arg_node.graph.owning_module, arg_node.target) + attr = getattr(arg_node.graph.owning_module, arg_node.target) + if isinstance(attr, torch.nn.Module): + return attr if "example_value" not in arg_node.meta: # This is a non tensor object like `torch.autocast` ctx manager object. return arg_node @@ -372,13 +375,15 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason return False, split_reason # The checkpointed function must be fully supported by Thunder - if target is torch.ops.higher_order.tag_activation_checkpoint: + if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply): m = node.graph.owning_module - get_attr_node = node.args[0] - assert get_attr_node.op == "get_attr" - checkpointed_fn = getattr(m, get_attr_node.target) - is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn) - return is_module_supported, split_reason + for arg_node in node.args: + if arg_node.op == "get_attr": + called_module = getattr(m, arg_node.target) + is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module) + if not is_module_supported: + return is_module_supported, split_reason + return True, None # If thunder has a mapping for this operation, try executing the meta function and see. # We have a symbol for `torch.where`, but we don't support one overload of it. From 788fda262aa174802f27f3d9de172d116e72f036 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 16 Dec 2024 15:40:45 +0100 Subject: [PATCH 3/8] rm unused torch.autograd.function.FunctionCtx --- thunder/dynamo/splitter.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index f16bad1c67..84fdd9f8c7 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -134,8 +134,14 @@ def callback(node) -> int: supported_partitions.add(partition_cnt) return partition_cnt - gm.graph.eliminate_dead_code() + # Removes the unused torch.autograd.function.FunctionCtx + functionctx_nodes_to_del = ( + n for n in gm.graph.find_nodes(op="call_function", target=torch.autograd.function.FunctionCtx) if not n.users + ) + for n in functionctx_nodes_to_del: + gm.graph.erase_node(n) gm.recompile() + # `split_module` iterates over nodes and determines the partition to place them based on the callback. 
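For reference, a minimal sketch of what resolving a `get_attr` node through `getattr(node.graph.owning_module, node.target)` returns, which is what `make_input_proxy` and `is_node_supported_by_thunder` rely on above (the toy module with a registered buffer is an assumption; in Dynamo-produced graphs the same lookup yields the `GraphModule` bodies of higher-order ops such as `autograd_function_apply`):

import torch
from torch.fx import symbolic_trace

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return x * self.scale

buf_gm = symbolic_trace(WithBuffer())
for node in buf_gm.graph.nodes:
    if node.op == "get_attr":
        attr = getattr(node.graph.owning_module, node.target)
        # For a buffer this is a plain Tensor; for a higher-order op's body it is a
        # torch.nn.Module, hence the isinstance(attr, torch.nn.Module) check above.
        print(node.target, type(attr))  # scale <class 'torch.Tensor'>
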
original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True From 0a6a39b66e31394cd89368eff33597eabbc78b9b Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Wed, 18 Dec 2024 21:46:09 +0900 Subject: [PATCH 4/8] Iterate over `fwd_args` for hopefully more precise `new_fwd_args` (#1565) Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..e9876fa40f 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -791,7 +791,22 @@ def _generate_random_str_id() -> str: # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087 # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx") length_of_tensor_args = sum(args_tensor_mask) - new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args] + + # N.B.(crcrpar) When `torch.compile(..., dynamic=True)`, + # GraphModules' forward seem to take `SymInt` and other values + # as its argument with some probability. Though that piece of information unfortunately + # does not seem to be indicated in ``args_tensor_`` nor ``non_differentiable_idx``. + # Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values to ``fwd_args``. + new_fwd_args = [] + for i, v in enumerate(fwd_args): + if i < length_of_tensor_args: + new_fwd_args.append(v) + else: + # note(crcrpar): we might want to include `FutureTensorProxy` and + # a proxy of tensor subclass in the near future. + if not isinstance(unwrap(v), TensorProxy): + new_fwd_args.append(v) + new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args) aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args) if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED: From 16bbe283c09b8b69b892ff6654648d31c09a58ee Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 19 Dec 2024 15:08:10 +0100 Subject: [PATCH 5/8] fix: test_splitter_autograd_function --- thunder/tests/test_dynamo.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 4bcd2333f4..1c90f5b0be 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -274,7 +274,7 @@ def forward(ctx, x): @staticmethod def backward(ctx, g): (x,) = ctx.saved_tensors - return g * torch.cos(x) + return g * torch.cos(x) * 100 def func(x): y = torch.cos(x) + Sin.apply(x) @@ -286,9 +286,16 @@ def func(x): actual = cfunc(x) backend = cfunc._backend - targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) - assert any(target.startswith("thunder_") for target in targets) - assert any(target.startswith("inductor_") for target in targets) + assert len(backend.subgraph_infos) == 1 # no graph break in dynamo + subgraph_info = backend.subgraph_infos[0] + assert len(subgraph_info.split_reasons) == 0 # no split + assert len(subgraph_info.thunder_compiled_fns) == 1 + jfunc = subgraph_info.thunder_compiled_fns[0] + trc = last_traces(jfunc)[0] + assert any( + isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply") + for bsym in trc.bound_symbols + ) # Verify forward pass torch.testing.assert_close(actual, expected) From 5669191835e2eacfcfe7d4343fbcb149fff0df5e Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 13 Jan 2025 10:57:26 +0100 Subject: [PATCH 6/8] 
fix --- thunder/tests/test_dynamo.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 80fef752b5..7baacdda7d 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -263,7 +263,10 @@ def func(x): ), ) def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None): - x = torch.ones(2, device=device, dtype=dtype, requires_grad=True) + # Workaround for "RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered" + # https://github.com/pytorch/pytorch/issues/124565 + if device != "cpu": + torch.empty(1, device="cuda", requires_grad=True).backward() class Sin(torch.autograd.Function): @staticmethod @@ -280,6 +283,7 @@ def func(x): y = torch.cos(x) + Sin.apply(x) return torch.matmul(x, y) + x = torch.ones(2, device=device, dtype=dtype, requires_grad=True) expected = torch.compile(func, dynamic=dynamic)(x) cfunc = thunderfx(func, dynamic=dynamic) From a196704d8a77991d4975360dba514d8dc48ab6f7 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 16 Jan 2025 14:09:24 +0100 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Masaki Kozuki Co-authored-by: Kshiteej K --- thunder/core/jit_ext.py | 4 ++-- thunder/dynamo/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 952fd311c6..c65ed2a1fd 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -795,8 +795,8 @@ def _generate_random_str_id() -> str: # N.B.(crcrpar) When `torch.compile(..., dynamic=True)`, # GraphModules' forward seem to take `SymInt` and other values # as its argument with some probability. Though that piece of information unfortunately - # does not seem to be indicated in ``args_tensor_`` nor ``non_differentiable_idx``. - # Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values to ``fwd_args``. + # does not seem to be indicated in ``args_tensor_mask`` nor ``non_differentiable_idx``. + # Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values whose index is >= `length_of_tensor_args` to ``fwd_args``. new_fwd_args = [] for i, v in enumerate(fwd_args): if i < length_of_tensor_args: diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 71a2c53d53..f0a3aacdf2 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -190,8 +190,8 @@ def make_input_proxy(arg_node): for e_v in example_value ) else: - # NOTE - This will be caught will be caught and be part of the SplitReason. - raise TypeError(f"Received `make_input_proxy` received example_value which wasn't Tensor or Tuple") + # NOTE - This will be caught and be part of the SplitReason. + raise TypeError(f"`make_input_proxy` received example_value which wasn't Tensor or Tuple") return proxy(example_value) # This is int, float, etc. 
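To make the filtering rule described in that comment concrete, here is a small sketch with thunder's wrapping/unwrapping (`wrap_const`, `unwrap`) stripped out; `TensorLike` stands in for `TensorProxy`, and the helper name and sample values are made up:

class TensorLike:
    """Stand-in for a tensor/TensorProxy in this sketch."""

def rebuild_fwd_args(fwd_args, args_tensor_mask):
    length_of_tensor_args = sum(args_tensor_mask)
    new_fwd_args = [None]  # placeholder for the FunctionCtx argument, as in jit_ext.py
    for i, v in enumerate(fwd_args):
        if i < length_of_tensor_args:
            new_fwd_args.append(v)  # masked-in (tensor) arguments are always kept
        elif not isinstance(v, TensorLike):
            new_fwd_args.append(v)  # trailing non-tensor values (e.g. a SymInt) are kept as well
    return tuple(new_fwd_args)

t = TensorLike()
print(rebuild_fwd_args((t, 3), args_tensor_mask=[True, False]))  # -> (None, <TensorLike ...>, 3)
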
From beb392883f286cc60b5902fe5849e635ca9ac8ee Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 16 Jan 2025 14:15:36 +0100 Subject: [PATCH 8/8] fix comments --- thunder/dynamo/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index f0a3aacdf2..99e0acc1bd 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -359,7 +359,7 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason ) return False, split_reason - # The checkpointed function must be fully supported by Thunder + # The higher order function must be fully supported by Thunder if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply): m = node.graph.owning_module for arg_node in node.args:
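
For completeness, a small repro of the graph shape this series targets: with a recent PyTorch, Dynamo traces a custom `torch.autograd.Function` into the `autograd_function_apply` higher-order op, whose forward/backward bodies arrive as `get_attr` arguments. The throwaway backend below only prints the captured graph and then runs it eagerly; the exact graph text depends on the PyTorch version.

import torch

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return g * torch.cos(x)

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # Expect a call to torch.ops.higher_order.autograd_function_apply whose first two
    # arguments are get_attr nodes referencing the fwd/bwd GraphModule bodies.
    print(gm.graph)
    return gm  # run the captured graph eagerly

def func(x):
    return torch.cos(x) + Sin.apply(x)

x = torch.ones(2, requires_grad=True)
torch.compile(func, backend=print_graph_backend)(x)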