Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] authored and rittik9 committed Dec 9, 2024
1 parent bfb95af commit e3475da
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions thunder/transforms/autocast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbolInterface, Symbol
from thunder.core.proxies import TensorProxy
from thunder.core.trace import TraceCtx,tracectx
from thunder.core.trace import TraceCtx, tracectx
from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list
from thunder.core.transforms import construct_trace, eval_trace, Transform
from thunder.clang import (
Expand Down Expand Up @@ -317,38 +317,35 @@ def is_cpu_tensor(p):

return None


class AutocastTransform(Transform):
"""Transform that applies automatic mixed precision (autocast) to eligible operations."""

def __init__(self, dtype: dtypes.dtype):
"""Initialize the autocast transform.
Args:
dtype: The target dtype to cast eligible operations to (float16 or bfloat16)
"""
if not isinstance(dtype, dtypes.dtype):
raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}")

if dtype not in _allowed_downcast_types:
raise ValueError(
f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}"
)
self.dtype = dtype

def transform_traces_pre_prologue(
self,
prologue_trace: TraceCtx,
computation_trace: TraceCtx,
epilogue_trace: TraceCtx | None,
**kwargs
self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs
) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]:
"""Transform the computation trace to apply autocast rules."""

class AutocastProcessor(TraceSubstitutionProcessor):
def __init__(self, trace, dtype, *args, **kwargs):
super().__init__(trace, *args, **kwargs)
self.dtype = dtype

def process_bsym(self, bsym):
# Skip special symbols that shouldn't be processed
if bsym.sym.id in trace_interpreter_skip_list:
Expand All @@ -357,24 +354,24 @@ def process_bsym(self, bsym):

# Check if symbol has an autocast implementation
autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym)

if autocast_impl is not None:
# Read the arguments with potential autocast conversion
args = tree_map(self.read, bsym.args)
kwargs = tree_map(self.read, bsym.kwargs)

# Apply the autocast implementation
with disable_autocast():
result = autocast_impl(*args, **kwargs, dtype=self.dtype)

self.set_result(result)
else:
# No autocast rule, process normally
args = tree_map(self.read, bsym.args)
args = tree_map(self.read, bsym.args)
kwargs = tree_map(self.read, bsym.kwargs)
result = bsym.sym(*args, **kwargs)
self.set_result(result)

# Add the bound symbol to new trace
new_bsym = bsym.from_bsym()
new_bsym.args = args
Expand All @@ -384,21 +381,21 @@ def process_bsym(self, bsym):
# Process the computation trace
if computation_trace is not None:
processor = AutocastProcessor(computation_trace, self.dtype)

# Get the actual args and kwargs from the kwargs dict
args = kwargs.get('args', ())
kw = kwargs.get('kwargs', {})
args = kwargs.get("args", ())
kw = kwargs.get("kwargs", {})

with tracectx(processor.new_trace):
# Initialize the processor's environment with input arguments
for trace_arg, arg in zip(computation_trace.args, args):
processor.env[trace_arg.name] = arg

# Initialize kwargs if any
for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()):
processor.env[trace_kwarg.name] = kwarg

new_trace, _ = processor()
computation_trace = new_trace

return prologue_trace, computation_trace, epilogue_trace
return prologue_trace, computation_trace, epilogue_trace

0 comments on commit e3475da

Please sign in to comment.