From 4780620ffa87603c5cf81de065993203f93fbf16 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 12 Nov 2024 17:12:28 +0100 Subject: [PATCH] fix test --- thunder/dynamo/compiler.py | 3 ++- thunder/dynamo/utils.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index 9ca264a4e4..0e26b40640 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -85,7 +85,8 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor return split_module def save_reproducer_to_folder(self, reproducer_folder: str | PathLike): - """Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`. + """ + Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`. Each saved script is named as "g[graph_id]_thunder_[module_id]", where: - `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder. - `module_id` indexes the submodule split by the :func:`thunder.dynamo.utils._splitter`. diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index f27472a5cb..16530bbf38 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -495,7 +495,7 @@ def _get_example_inputs_from_placeholder( return _get_example_input_tensor_metadata(ev) return ev.detach().clone().requires_grad_(ev.requires_grad) - check("example_value" in node.meta, lambda: "example_value does not exist in the meta of {node}", ValueError) + check("example_value" in node.meta, lambda: f"example_value does not exist in the meta of {node}", ValueError) example_value = node.meta["example_value"] if isinstance(example_value, torch.Tensor): @@ -503,6 +503,11 @@ def _get_example_inputs_from_placeholder( if only_metadata: return ev_metadata return _create_random_tensor_from_tensor_metadata(ev_metadata) + elif isinstance(example_value, tuple): + ev_metadatas = tuple(_get_example_input_tensor_metadata(e_v) for e_v in example_value) + if only_metadata: + return ev_metadatas + return tuple(_create_random_tensor_from_tensor_metadata(ev_metadata) for ev_metadata in ev_metadatas) elif isinstance(example_value, torch.types.py_sym_types): return example_value.node.hint else: