From 2a1ac60c06b18ea404cd08104acfedb2c51358db Mon Sep 17 00:00:00 2001
From: Kaeun Kim
Date: Wed, 21 Aug 2024 17:00:15 +0900
Subject: [PATCH] feat: add fullgraph argument for torch compile

---
 thunder/executors/torch_compile.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py
index 204bc2046d..d1b7680af6 100644
--- a/thunder/executors/torch_compile.py
+++ b/thunder/executors/torch_compile.py
@@ -16,6 +16,7 @@
 from thunder.executors.passes import update_fusion_call_ctx
 from thunder.executors.utils import Region
 from thunder.extend import FusionExecutor, register_executor, ImplInfo
+from thunder.core.compile_data import get_compile_option
 
 _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0", use_base_version=True)
 
@@ -84,7 +85,12 @@ def torch_interpreted_func(*args):
     # torch.compile executor"
     torch_trace = trace(inline_trace=False)(torch_interpreted_func, *sorted_unique_inputs)
     trace_callable = torch_trace.python_callable(include_decorators=False)
-    compiled_func = torch.compile(trace_callable, fullgraph=True)
+    torch_compile_fullgraph: None | bool = get_compile_option(
+        "torch_compile_fullgraph", "Whether to enable `fullgraph` from `torch.compile`. Defaults to `True`."
+    )
+    if torch_compile_fullgraph is None:
+        torch_compile_fullgraph = True
+    compiled_func = torch.compile(trace_callable, fullgraph=torch_compile_fullgraph)
 # For each of `@torch.no_grad(), and `torch.autocast(device_type="cpu"|"cuda")` torch.compile
 # create caches with a guard for the wrapped function. Since the torch.compile caches are per code object, not
 # frame, all the dynamic copies of these context managers share the same code cache.
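
For reference, a minimal usage sketch of the option this patch introduces, assuming compile options are forwarded to `get_compile_option` via keyword arguments to `thunder.jit` and that the module exposes an executor handle named `torch_compile_ex` (both names are assumptions, not confirmed by the patch itself):

```python
import torch
import thunder
# Assumption: the torch.compile fusion executor is exported under this name.
from thunder.executors.torch_compile import torch_compile_ex

model = torch.nn.Linear(4, 4)

# Assumption: extra keyword arguments to thunder.jit become compile options
# that executors read back via get_compile_option(...). Passing
# torch_compile_fullgraph=False would let the torch.compile executor tolerate
# graph breaks instead of requiring a single full graph (the previous
# hard-coded fullgraph=True behavior remains the default).
jitted = thunder.jit(model, executors=[torch_compile_ex], torch_compile_fullgraph=False)

out = jitted(torch.randn(2, 4))
```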