diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index d87e3bf5a5..4c3ff96e7f 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -787,7 +787,22 @@ def _generate_random_str_id() -> str:
     # note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
     # non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
     length_of_tensor_args = sum(args_tensor_mask)
-    new_fwd_args = (wrap_const(None),) + fwd_args[:length_of_tensor_args]
+
+    # N.B.(crcrpar) With `torch.compile(..., dynamic=True)`, a GraphModule's forward
+    # may take `SymInt` and other non-tensor values as its arguments.
+    # Unfortunately, that information does not seem to be indicated in
+    # ``args_tensor_mask`` nor ``non_differentiable_idx``.
+    # Thus we optimistically iterate over ``fwd_args`` and keep the non-tensor values whose index is >= ``length_of_tensor_args``.
+    new_fwd_args = []
+    for i, v in enumerate(fwd_args):
+        if i < length_of_tensor_args:
+            new_fwd_args.append(v)
+        else:
+            # note(crcrpar): we might want to also accept `FutureTensorProxy` and
+            # proxies of tensor subclasses in the near future.
+            if not isinstance(unwrap(v), TensorProxy):
+                new_fwd_args.append(v)
+    new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)
 
     aug_fwd_trace, aug_fwd_provenance = _convert_pytorchfunc_to_thundertrace(fwd, False, *new_fwd_args)
     if aug_fwd_trace is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py
index 587729dde3..84fdd9f8c7 100644
--- a/thunder/dynamo/splitter.py
+++ b/thunder/dynamo/splitter.py
@@ -134,6 +134,14 @@ def callback(node) -> int:
         supported_partitions.add(partition_cnt)
         return partition_cnt
 
+    # Remove unused torch.autograd.function.FunctionCtx nodes
+    functionctx_nodes_to_del = (
+        n for n in gm.graph.find_nodes(op="call_function", target=torch.autograd.function.FunctionCtx) if not n.users
+    )
+    for n in functionctx_nodes_to_del:
+        gm.graph.erase_node(n)
+    gm.recompile()
+
     # `split_module` iterates over nodes and determines the partition to place them based on the callback.
     original_split_gm: torch.fx.GraphModule = split_module(
         gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True
diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py
index 24bcc4c816..3b8915f19a 100644
--- a/thunder/dynamo/utils.py
+++ b/thunder/dynamo/utils.py
@@ -160,10 +160,15 @@ def get_proxy_inputs_from_node(node: torch.fx.Node) -> tuple[tuple, dict]:
     # We need to be under trace context to generate proxies.
     with thunder.core.trace.tracectx(TraceCtx()):
 
-        def make_tensor_proxy(arg_node):
+        def make_input_proxy(arg_node):
             # This is a Node in the graph representing a Tensor or tuple of Tensors or
             # a PyTorch object like one representing torch.autocast.
             if isinstance(arg_node, torch.fx.Node):
+                # Higher-order operator nodes take get_attr nodes as input to get the called module
+                if arg_node.op == "get_attr":
+                    attr = getattr(arg_node.graph.owning_module, arg_node.target)
+                    if isinstance(attr, torch.nn.Module):
+                        return attr
                 if "example_value" not in arg_node.meta:
                     # This is a non tensor object like `torch.autocast` ctx manager object.
                     return arg_node
@@ -185,15 +190,15 @@ def make_tensor_proxy(arg_node):
                     for e_v in example_value
                 )
             else:
-                # NOTE - This will be caught will be caught and be part of the SplitReason.
- raise TypeError(f"Received `make_tensor_proxy` received example_value which wasn't Tensor or Tuple") + # NOTE - This will be caught and be part of the SplitReason. + raise TypeError(f"`make_input_proxy` received example_value which wasn't Tensor or Tuple") return proxy(example_value) # This is int, float, etc. return arg_node - proxy_args = torch.fx.map_arg(node.args, make_tensor_proxy) - proxy_kwargs = {k: torch.fx.map_arg(v, make_tensor_proxy) for k, v in node.kwargs.items()} + proxy_args = torch.fx.map_arg(node.args, make_input_proxy) + proxy_kwargs = {k: torch.fx.map_arg(v, make_input_proxy) for k, v in node.kwargs.items()} return proxy_args, proxy_kwargs @@ -354,14 +359,16 @@ def is_node_supported_by_thunder(node: torch.fx.Node) -> tuple[bool, SplitReason ) return False, split_reason - # The checkpointed function must be fully supported by Thunder - if target is torch.ops.higher_order.tag_activation_checkpoint: + # The higher order function must be fully supported by Thunder + if target in (torch.ops.higher_order.tag_activation_checkpoint, torch.ops.higher_order.autograd_function_apply): m = node.graph.owning_module - get_attr_node = node.args[0] - assert get_attr_node.op == "get_attr" - checkpointed_fn = getattr(m, get_attr_node.target) - is_module_supported, split_reason = is_graphmodule_supported_by_thunder(checkpointed_fn) - return is_module_supported, split_reason + for arg_node in node.args: + if arg_node.op == "get_attr": + called_module = getattr(m, arg_node.target) + is_module_supported, split_reason = is_graphmodule_supported_by_thunder(called_module) + if not is_module_supported: + return is_module_supported, split_reason + return True, None # If thunder has a mapping for this operation, try executing the meta function and see. # We have a symbol for `torch.where`, but we don't support one overload of it. 
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index bac1c639c7..87a356c817 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -263,7 +263,10 @@ def func(x):
     ),
 )
 def test_splitter_autograd_function(executor, device: str, dtype: dtypes.dtype, dynamic: bool | None):
-    x = torch.ones(2, device=device, dtype=dtype, requires_grad=True)
+    # Workaround for "RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered"
+    # https://github.com/pytorch/pytorch/issues/124565
+    if device != "cpu":
+        torch.empty(1, device="cuda", requires_grad=True).backward()
 
     class Sin(torch.autograd.Function):
         @staticmethod
@@ -274,21 +277,29 @@ def forward(ctx, x):
 
         @staticmethod
         def backward(ctx, g):
             (x,) = ctx.saved_tensors
-            return g * torch.cos(x)
+            return g * torch.cos(x) * 100
 
     def func(x):
         y = torch.cos(x) + Sin.apply(x)
         return torch.matmul(x, y)
 
+    x = torch.ones(2, device=device, dtype=dtype, requires_grad=True)
     expected = torch.compile(func, dynamic=dynamic)(x)
     cfunc = thunderfx(func, dynamic=dynamic)
     actual = cfunc(x)
 
     backend = cfunc._backend
-    targets = (node.target for node in backend.subgraph_infos[0].split_graph_module.graph.nodes)
-    assert any(target.startswith("thunder_") for target in targets)
-    assert any(target.startswith("inductor_") for target in targets)
+    assert len(backend.subgraph_infos) == 1  # no graph break in dynamo
+    subgraph_info = backend.subgraph_infos[0]
+    assert len(subgraph_info.split_reasons) == 0  # no split
+    assert len(subgraph_info.thunder_compiled_fns) == 1
+    jfunc = subgraph_info.thunder_compiled_fns[0]
+    trc = last_traces(jfunc)[0]
+    assert any(
+        isinstance(bsym.sym.id, str) and bsym.sym.id.startswith("higher_order_autograd_function_apply")
+        for bsym in trc.bound_symbols
+    )
 
     # Verify forward pass
     torch.testing.assert_close(actual, expected)
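
For context (not part of the patch): a minimal sketch of the behavior the splitter.py/utils.py changes enable, mirroring the updated test above. It assumes `thunderfx` is importable from `thunder.dynamo` as in the tests; `Sin` and `fn` are illustrative names. A custom `torch.autograd.Function` used under dynamo should now reach Thunder through `torch.ops.higher_order.autograd_function_apply` instead of forcing a split to the inductor backend.

# Illustrative sketch only; mirrors the test above rather than documenting new API.
import torch

from thunder.dynamo import thunderfx  # assumed import path, as used in thunder/tests/test_dynamo.py


class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        return g * torch.cos(x)


def fn(x):
    # Sin.apply appears in the FX graph as torch.ops.higher_order.autograd_function_apply,
    # which is_node_supported_by_thunder now accepts when its get_attr submodules are supported.
    y = torch.cos(x) + Sin.apply(x)
    return torch.matmul(x, y)


x = torch.ones(2, requires_grad=True)
out = thunderfx(fn)(x)                  # whole graph expected to go to Thunder, no split
torch.testing.assert_close(out, fn(x))  # forward matches eager
out.backward()                          # backward routes through Sin.backward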