From 84fc424044e2814eadbb59d6ec032ff0e2765cd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 3 Apr 2024 18:36:31 +0200 Subject: [PATCH 1/2] Remove `CompileData.processed_function` --- thunder/common.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index 54dbe13c34..116aa27c70 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -233,15 +233,7 @@ def __init__( self.additional_return_names = None self.num_constant_args = 0 - self._processed_function: Callable - - assert disable_preprocessing, "please use thunder.jit if you need preprocessing" - self._processed_function = fn - - # Disallows overwriting processed_function - @property - def processed_function(self): - return self._processed_function + assert disable_preprocessing, "please use thunder.compile if you need preprocessing" # Common UX functions @@ -665,7 +657,7 @@ def _create_callable( post_optimization_transforms: list[Callable] = [], _using_grad_transform: bool = False, ) -> Callable: - @wraps(cd.processed_function) + @wraps(cd.fn) def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_start = time.time_ns() cs.calls += 1 @@ -729,9 +721,9 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: # Applies the autocast transform if PyTorch's autocast behavior is enabled processed_function = ( - cd.processed_function + cd.fn if not is_autocast_enabled - else autocast(cd.processed_function, dtype=autocast_thunder_dtype) + else autocast(cd.fn, dtype=autocast_thunder_dtype) ) # Resets use of compile flags @@ -839,7 +831,6 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: return result # NOTE is_module is False - _fn._pfn = cd.processed_function _fn._lc_cd = cd _fn._lc_cs = cs _fn._lc_transforms = transforms From 6114d2705fc477428e09adb5ec90a5f3ebfe1c17 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Apr 2024 16:37:54 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/common.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index 116aa27c70..ec0ed1d262 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -720,11 +720,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_cache_stop = time.time_ns() # Applies the autocast transform if PyTorch's autocast behavior is enabled - processed_function = ( - cd.fn - if not is_autocast_enabled - else autocast(cd.fn, dtype=autocast_thunder_dtype) - ) + processed_function = cd.fn if not is_autocast_enabled else autocast(cd.fn, dtype=autocast_thunder_dtype) # Resets use of compile flags cs.last_compile_reasons = defaultdict(list)