Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor autocast to AutocastTransform for Improved Composability and Transform Integration #1516

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions thunder/tests/test_autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@


# TODO This test currently ignores the "should_autocast" argument enumerated in it
@instantiate(
dtypes=dtypes.float_math_dtypes,
)
@instantiate(dtypes=dtypes.float_math_dtypes)
def test_thunder_autocast_transform(executor, device, dtype):
from thunder.transforms.autocast import autocast
from thunder.transforms.autocast import AutocastTransform

# TODO: Consider adding support for device specific dtypes in the test
# instantiator.
Expand Down Expand Up @@ -49,13 +47,17 @@ def h(a, b, c):
)
else:
pytest.fail(f"Invalid combination of parameters: {executor=}, {device=}, {dtype=}")

for (func, should_autocast), autocast_dtype in itertools.product(
((f, True), (g, False), (h, True)), autocast_dtypes
):
autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype)
x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3))
initial_trace = thunder.trace()(autocast(func, dtype=autocast_dtype), x, y, z)
compiled = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True)

# Use AutocastTransform instead of autocast function
compiled = thunder.jit(
func, transforms=[AutocastTransform(dtype=autocast_dtype)], executors=executor.executors_list()
)
out = compiled(x, y, z)

devicetype = torch.device(device).type
Expand Down
194 changes: 192 additions & 2 deletions thunder/transforms/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbolInterface, Symbol
from thunder.core.proxies import TensorProxy
from thunder.core.transforms import construct_trace, eval_trace
from thunder.core.trace import TraceCtx, TraceProvenance, from_trace, set_tracectx, get_tracectx, reset_tracectx
from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list
from thunder.core.transforms import construct_trace, eval_trace, Transform
from thunder.clang import (
maybe_convert_to_dtype,
)
import thunder.torch as ltorch

import warnings
from contextlib import contextmanager

autocast_impls: dict[prims.PrimIDs, Callable] = {}

Expand Down Expand Up @@ -227,6 +230,11 @@ def autocast(func: Callable, dtype: dtypes.dtype):
Returns:
Callable: The transformed function
"""
warnings.warn(
"autocast() is deprecated. Use thunder.jit(func, transforms=[AutocastTransform(dtype)]) instead",
DeprecationWarning,
stacklevel=2,
)

if not isinstance(dtype, dtypes.dtype):
raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")
Expand Down Expand Up @@ -309,3 +317,185 @@ def is_cpu_tensor(p):
return wrapper

return None


# NOTE(review): dead earlier draft kept in comments — superseded by the live
# AutocastTransform class defined below; consider deleting before merge.
# class AutocastTransform(Transform):
#     """Transform that applies automatic mixed precision (autocast) to eligible operations."""

# def __init__(self, dtype: dtypes.dtype):
# """Initialize the autocast transform.

# Args:
# dtype: The target dtype to cast eligible operations to (float16 or bfloat16)
# """
# if not isinstance(dtype, dtypes.dtype):
# raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")

# if dtype not in _allowed_downcast_types:
# raise ValueError(
# f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}"
# )
# self.dtype = dtype

# def transform_traces_pre_prologue(
# self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs
# ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]:
# """Transform the computation trace to apply autocast rules."""

# class AutocastProcessor(TraceSubstitutionProcessor):
# def __init__(self, trace, dtype, *args, **kwargs):
# super().__init__(trace, *args, **kwargs)
# self.dtype = dtype

# def process_bsym(self, bsym):
# if bsym.sym.id in trace_interpreter_skip_list:
# self.new_trace.bound_symbols.append(bsym.from_bsym())
# return

# autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)

# if autocast_impl is not None:
# args = tree_map(self.read, bsym.args)
# kwargs = tree_map(self.read, bsym.kwargs)

# with disable_autocast():
# result = autocast_impl(*args, **kwargs, dtype=self.dtype)

# self.set_result(result)
# else:
# args = tree_map(self.read, bsym.args)
# kwargs = tree_map(self.read, bsym.kwargs)
# result = bsym.sym(*args, **kwargs)
# self.set_result(result)

# new_bsym = bsym.from_bsym()
# new_bsym.args = args
# new_bsym.kwargs = kwargs
# self.add_processed_bsyms([new_bsym])

# if computation_trace is not None:
# processor = AutocastProcessor(computation_trace, self.dtype)

# args = kwargs.get("args", ())
# kw = kwargs.get("kwargs", {})

# processor.process_args(*args, **kw)

# new_trace, outputs = processor()
# computation_trace = new_trace

# return prologue_trace, computation_trace, epilogue_trace


class AutocastTransform(Transform):
    """Trace transform that applies automatic mixed precision (autocast).

    Rewrites the computation trace so that bound symbols which have a
    registered autocast rule (``_maybe_get_autocast_rule_for_symbol``)
    receive float32/float64 tensor inputs downcast to ``dtype`` via
    ``maybe_downcast_to``. Intended replacement for the deprecated
    ``autocast()`` function wrapper.
    """

    def __init__(self, dtype: dtypes.dtype):
        # Validate eagerly so a bad dtype fails at construction time,
        # not later during jit compilation.
        if not isinstance(dtype, dtypes.dtype):
            raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")
        _check_valid_autocast_dtype(dtype)
        self.dtype = dtype
        # Maps proxy name -> TensorProxy for trace inputs and every bound
        # symbol output seen so far; used to resolve the current proxy for
        # a symbol's arguments.
        # NOTE(review): populated in transform_traces_pre_prologue but never
        # cleared, so reusing one AutocastTransform instance across multiple
        # compilations carries stale entries over — confirm this is intended.
        self.env: dict = {}

    @contextmanager
    def trace_context(self, trace: TraceCtx):
        """Make ``trace`` the active trace context for the duration of the
        ``with`` block, restoring the previous context on exit."""
        token = set_tracectx(trace)
        try:
            yield
        finally:
            if token is not None:
                reset_tracectx(token)


    def transform_traces_pre_prologue(
        self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs
    ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]:
        """Return (prologue, autocast-transformed computation, epilogue) traces.

        The prologue and epilogue traces pass through unchanged; only the
        computation trace is rewritten.
        """
        # Nothing to rewrite without a computation trace.
        if computation_trace is None:
            return prologue_trace, computation_trace, epilogue_trace

        # Fresh trace that mirrors the original's metadata; bound symbols
        # are re-appended below, with autocast rules applied where available.
        new_computation_trace = from_trace(computation_trace)

        # Seed the environment with the trace's tensor inputs.
        for arg in computation_trace.args:
            if isinstance(arg, TensorProxy):
                self.env[arg.name] = arg

        # Process bound symbols within tracing context
        with self.trace_context(new_computation_trace):
            for bsym in computation_trace.bound_symbols:
                # Symbols on the interpreter skip list are copied verbatim.
                if bsym.sym.id in trace_interpreter_skip_list:
                    new_bsym = bsym.from_bsym()
                    new_computation_trace.bound_symbols.append(new_bsym)
                    self._update_env(new_bsym)
                    continue

                autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)

                if autocast_impl is not None:
                    # Downcast eligible tensor args: only float32/float64
                    # tensors known to the env are converted to self.dtype;
                    # everything else passes through untouched.
                    # NOTE(review): the inner trace_context below re-enters
                    # the context for the already-active trace — looks
                    # redundant; confirm set_tracectx tolerates nesting.
                    processed_args = []
                    for arg in bsym.args:
                        if isinstance(arg, TensorProxy):
                            if arg.name in self.env:
                                tensor = self.env[arg.name]
                                if tensor.dtype in (dtypes.float32, dtypes.float64):
                                    with self.trace_context(new_computation_trace):
                                        processed_arg = maybe_downcast_to(self.dtype, tensor)
                                else:
                                    processed_arg = tensor
                            else:
                                processed_arg = arg
                        else:
                            processed_arg = arg
                        processed_args.append(processed_arg)

                    # Same downcast policy for keyword arguments.
                    processed_kwargs = {}
                    for k, v in bsym.kwargs.items():
                        if isinstance(v, TensorProxy):
                            if v.name in self.env:
                                tensor = self.env[v.name]
                                if tensor.dtype in (dtypes.float32, dtypes.float64):
                                    with self.trace_context(new_computation_trace):
                                        processed_kwargs[k] = maybe_downcast_to(self.dtype, tensor)
                                else:
                                    processed_kwargs[k] = tensor
                            else:
                                processed_kwargs[k] = v
                        else:
                            processed_kwargs[k] = v

                    # Run the autocast rule under the active trace, with
                    # recursive autocasting disabled so the rule's internal
                    # ops are not themselves rewritten.
                    # NOTE(review): autocast_impl executes under the active
                    # trace context and may record bound symbols itself;
                    # appending new_bsym afterwards could duplicate work —
                    # verify against TraceSubstitutionProcessor semantics.
                    with self.trace_context(new_computation_trace), disable_autocast():
                        result = autocast_impl(*processed_args, **processed_kwargs, dtype=self.dtype)

                    # Keep the rule's output in the target dtype as well.
                    if isinstance(result, TensorProxy) and result.dtype in (dtypes.float32, dtypes.float64):
                        with self.trace_context(new_computation_trace):
                            result = maybe_downcast_to(self.dtype, result)

                    new_bsym = bsym.from_bsym(args=processed_args, kwargs=processed_kwargs, output=result)
                else:
                    new_bsym = bsym.from_bsym()
                    # NOTE(review): despite the original "final operation"
                    # wording, this branch coerces the output of EVERY symbol
                    # lacking an autocast rule, and it mutates new_bsym.output
                    # in place rather than tracing a cast of the produced
                    # value — confirm both are intended.
                    if isinstance(new_bsym.output, TensorProxy) and new_bsym.output.dtype in (
                        dtypes.float32,
                        dtypes.float64,
                    ):
                        with self.trace_context(new_computation_trace):
                            new_bsym.output = maybe_downcast_to(self.dtype, new_bsym.output)

                new_computation_trace.bound_symbols.append(new_bsym)
                self._update_env(new_bsym)

        new_computation_trace.set_provenance(TraceProvenance("constructed by autocast"))
        return prologue_trace, new_computation_trace, epilogue_trace

    def _update_env(self, bsym) -> None:
        """Record ``bsym``'s TensorProxy output(s) in ``self.env`` keyed by
        proxy name, so later symbols resolve to the latest proxies."""
        if isinstance(bsym.output, (tuple, list)):
            for out in bsym.output:
                if isinstance(out, TensorProxy):
                    self.env[out.name] = out
        elif isinstance(bsym.output, TensorProxy):
            self.env[bsym.output.name] = bsym.output
Loading