diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 2a01959f4e..5011ff0bc6 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -14,11 +14,9 @@
 
 
 # TODO This test currently ignores the "should_autocast" argument enumerated in it
-@instantiate(
-    dtypes=dtypes.float_math_dtypes,
-)
+@instantiate(dtypes=dtypes.float_math_dtypes)
 def test_thunder_autocast_transform(executor, device, dtype):
-    from thunder.transforms.autocast import autocast
+    from thunder.transforms.autocast import AutocastTransform
 
     # TODO: Consider adding support for device specific dtypes in the test
     # instantiator.
@@ -49,13 +47,17 @@ def h(a, b, c):
         )
     else:
         pytest.fail(f"Invalid combination of parameters: {executor=}, {device=}, {dtype=}")
+
     for (func, should_autocast), autocast_dtype in itertools.product(
         ((f, True), (g, False), (h, True)), autocast_dtypes
     ):
         autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype)
         x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3))
-        initial_trace = thunder.trace()(autocast(func, dtype=autocast_dtype), x, y, z)
-        compiled = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)
+
+        # Use AutocastTransform instead of autocast function
+        compiled = thunder.jit(
+            func, transforms=[AutocastTransform(dtype=autocast_dtype)], executors=executor.executors_list()
+        )
         out = compiled(x, y, z)
 
         devicetype = torch.device(device).type
diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py
index fbe4f622f1..aedf782735 100644
--- a/thunder/transforms/autocast.py
+++ b/thunder/transforms/autocast.py
@@ -8,12 +8,15 @@
 
 from thunder.core.proxies import TensorProxy
 from thunder.core.symbol import BoundSymbolInterface, Symbol
 from thunder.core.proxies import TensorProxy
-from thunder.core.transforms import construct_trace, eval_trace
+from thunder.core.trace import TraceCtx, TraceProvenance, from_trace, set_tracectx, get_tracectx, reset_tracectx
+from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list
+from thunder.core.transforms import construct_trace, eval_trace, Transform
 from thunder.clang import (
     maybe_convert_to_dtype,
 )
 import thunder.torch as ltorch
-
+import warnings
+from contextlib import contextmanager
 
 autocast_impls: dict[prims.PrimIDs, Callable] = {}
@@ -227,6 +230,11 @@ def autocast(func: Callable, dtype: dtypes.dtype):
 
     Returns:
         Callable: The transformed function
     """
+    warnings.warn(
+        "autocast() is deprecated. Use thunder.jit(func, transforms=[AutocastTransform(dtype)]) instead",
+        DeprecationWarning,
+        stacklevel=2,
+    )
     if not isinstance(dtype, dtypes.dtype):
         raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")
@@ -309,3 +317,108 @@ def is_cpu_tensor(p):
         return wrapper
 
     return None
+
+
+class AutocastTransform(Transform):
+    """Transform that rewrites a computation trace so that operations with an
+    autocast rule run in a lower-precision dtype (float16 or bfloat16)."""
+
+    def __init__(self, dtype: dtypes.dtype):
+        """Args:
+        dtype: target low-precision dtype; must be thunder.float16 or thunder.bfloat16.
+        """
+        if not isinstance(dtype, dtypes.dtype):
+            raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")
+        _check_valid_autocast_dtype(dtype)
+        self.dtype = dtype
+        # Maps proxy names to the (possibly downcast) tensor proxies produced so far.
+        self.env = {}
+
+    @contextmanager
+    def trace_context(self, trace: TraceCtx):
+        """Context manager that installs `trace` as the active trace context."""
+        token = set_tracectx(trace)
+        try:
+            yield
+        finally:
+            if token is not None:
+                reset_tracectx(token)
+
+    def transform_traces_pre_prologue(
+        self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs
+    ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]:
+        """Rebuild the computation trace, applying autocast rules to symbols that have one.
+
+        Returns the (unchanged) prologue and epilogue traces together with the
+        rewritten computation trace.
+        """
+        if computation_trace is None:
+            return prologue_trace, computation_trace, epilogue_trace
+
+        new_computation_trace = from_trace(computation_trace)
+
+        # Seed the environment with the computation trace inputs.
+        for arg in computation_trace.args:
+            if isinstance(arg, TensorProxy):
+                self.env[arg.name] = arg
+
+        # All symbol re-execution happens with the new trace as the active context.
+        with self.trace_context(new_computation_trace):
+            for bsym in computation_trace.bound_symbols:
+                # Bookkeeping symbols (unpacking, returns, ...) are copied verbatim.
+                if bsym.sym.id in trace_interpreter_skip_list:
+                    new_bsym = bsym.from_bsym()
+                    new_computation_trace.bound_symbols.append(new_bsym)
+                    self._update_env(new_bsym)
+                    continue
+
+                autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)
+
+                if autocast_impl is not None:
+                    processed_args = [self._maybe_downcast_value(a) for a in bsym.args]
+                    processed_kwargs = {k: self._maybe_downcast_value(v) for k, v in bsym.kwargs.items()}
+
+                    # Run the autocast rule with autocast disabled so it does not recurse.
+                    with disable_autocast():
+                        result = autocast_impl(*processed_args, **processed_kwargs, dtype=self.dtype)
+
+                    # Keep the rule's output in the target dtype as well.
+                    if isinstance(result, TensorProxy) and result.dtype in (dtypes.float32, dtypes.float64):
+                        result = maybe_downcast_to(self.dtype, result)
+
+                    new_bsym = bsym.from_bsym(args=processed_args, kwargs=processed_kwargs, output=result)
+                else:
+                    new_bsym = bsym.from_bsym()
+                    # Symbols without an autocast rule still get their output downcast.
+                    if isinstance(new_bsym.output, TensorProxy) and new_bsym.output.dtype in (
+                        dtypes.float32,
+                        dtypes.float64,
+                    ):
+                        new_bsym.output = maybe_downcast_to(self.dtype, new_bsym.output)
+
+                new_computation_trace.bound_symbols.append(new_bsym)
+                self._update_env(new_bsym)
+
+        new_computation_trace.set_provenance(TraceProvenance("constructed by autocast"))
+        return prologue_trace, new_computation_trace, epilogue_trace
+
+    def _maybe_downcast_value(self, value):
+        """Return `value` downcast to the autocast dtype when it is a known
+        full-precision tensor; otherwise return it unchanged."""
+        if not isinstance(value, TensorProxy):
+            return value
+        tensor = self.env.get(value.name)
+        if tensor is None:
+            return value
+        if tensor.dtype in (dtypes.float32, dtypes.float64):
+            return maybe_downcast_to(self.dtype, tensor)
+        return tensor
+
+    def _update_env(self, bsym):
+        """Record the tensor outputs of `bsym` in the name -> proxy environment."""
+        if isinstance(bsym.output, (tuple, list)):
+            for out in bsym.output:
+                if isinstance(out, TensorProxy):
+                    self.env[out.name] = out
+        elif isinstance(bsym.output, TensorProxy):
+            self.env[bsym.output.name] = bsym.output