diff --git a/thunder/common.py b/thunder/common.py index 54dbe13c34..ec0ed1d262 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -233,15 +233,7 @@ def __init__( self.additional_return_names = None self.num_constant_args = 0 - self._processed_function: Callable - - assert disable_preprocessing, "please use thunder.jit if you need preprocessing" - self._processed_function = fn - - # Disallows overwriting processed_function - @property - def processed_function(self): - return self._processed_function + assert disable_preprocessing, "please use thunder.compile if you need preprocessing" # Common UX functions @@ -665,7 +657,7 @@ def _create_callable( post_optimization_transforms: list[Callable] = [], _using_grad_transform: bool = False, ) -> Callable: - @wraps(cd.processed_function) + @wraps(cd.fn) def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_host_start = time.time_ns() cs.calls += 1 @@ -728,11 +720,7 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: cs.last_trace_cache_stop = time.time_ns() # Applies the autocast transform if PyTorch's autocast behavior is enabled - processed_function = ( - cd.processed_function - if not is_autocast_enabled - else autocast(cd.processed_function, dtype=autocast_thunder_dtype) - ) + processed_function = cd.fn if not is_autocast_enabled else autocast(cd.fn, dtype=autocast_thunder_dtype) # Resets use of compile flags cs.last_compile_reasons = defaultdict(list) @@ -839,7 +827,6 @@ def _fn(*args, **kwargs) -> tuple[Any, list[TraceCtx]]: return result # NOTE is_module is False - _fn._pfn = cd.processed_function _fn._lc_cd = cd _fn._lc_cs = cs _fn._lc_transforms = transforms