Skip to content

Commit

Permalink
Add comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Oct 31, 2024
1 parent 9e6cfca commit e3e0451
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
12 changes: 8 additions & 4 deletions thunder/dynamo/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,13 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor
self.subgraph_infos.append(subgraph_info)
return split_module

def save_reproducer_to_folder(self, reproducer_folder_name: str | PathLike):
def save_reproducer_to_folder(self, reproducer_folder: str | PathLike):
"""Save the reproducer script for the GraphModule executed by Thunder to the specified `reproducer_folder`.
Each saved script is named as "g[graph_id]_thunder_[module_id]", where:
- `graph_id` indexes the graph generated by Dynamo, which is then passed to Thunder.
- `module_id` indexes the submodule split by the :func:`thunder.dynamo.utils._splitter`.
Both `graph_id` and `module_id` start from 1.
"""
if not self.subgraph_infos:
raise TypeError(f"{self} doesn't seem to have been called yet.")

Expand All @@ -97,6 +103,4 @@ def save_reproducer_to_folder(self, reproducer_folder_name: str | PathLike):
thunder_modules = subgraph_info.thunder_compiled_fns
example_inputs = subgraph_info.thunder_compiled_fns_example_inputs
for cur_module, example_input, cur_name in safe_zip(thunder_modules, example_inputs, thunder_module_names):
reproducer(
getattr(cur_module, "_model"), example_input, reproducer_folder_name, f"{graph_idx+1}_{cur_name}"
)
reproducer(getattr(cur_module, "_model"), example_input, reproducer_folder, f"{graph_idx+1}_{cur_name}")
6 changes: 5 additions & 1 deletion thunder/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,11 @@ def _create_random_tensor_from_tensor_metadata(t: ExampleInputMetaData) -> torch

def _get_example_inputs_from_placeholder(
node: torch.fx.Node, only_metadata=False
) -> tuple[torch.Tensor | ExampleInputMetaData]:
) -> tuple[torch.Tensor | ExampleInputMetaData] | torch.Tensor | ExampleInputMetaData:
"""Retrieves example input data for a given placeholder `torch.fx.Node`.
- When `only_metadata` is `False`: Generates and returns a random example tensor based on the node's expected shape and data type, etc.
- When `only_metadata` is `True`: Returns only the tensor's metadata (e.g., shape, data type) without generating an actual tensor.
"""
check(node.op == "placeholder", lambda: f"The node must be placeholder type", ValueError)
# Prefers to use actual example value in GraphArg if available
if "grapharg" in node.meta:
Expand Down

0 comments on commit e3e0451

Please sign in to comment.