From 38992cd17393961bc99a73b696aa51c9730cbedf Mon Sep 17 00:00:00 2001
From: rittik9
Date: Wed, 4 Dec 2024 10:06:45 +0000
Subject: [PATCH 01/19] Refactor: autocast into AutocastTransform for better composability

---
 thunder/__init__.py            |  5 +++++
 thunder/tests/test_autocast.py | 20 ++++++++++++++++++++
 thunder/transforms/autocast.py |  9 ++++++++-
 3 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index e3aa5ac2aa..feb61ec924 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -84,6 +84,7 @@
 import thunder.executors.pythonex
 import thunder.executors.torchex
 import thunder.executors.nvfuserex
+from thunder.transforms.autocast import AutocastTransform

 pythonex = extend.get_executor("python")
 assert pythonex is not None
@@ -331,6 +332,10 @@ def jit(

     if transforms is None:
         transforms = []
+
+    if transforms:
+        for transform in transforms:
+            fn = transform(fn)

     # Resolve names of executors
     executors = resolve_executors(executors)
diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 2a01959f4e..97f31bfcde 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -3,6 +3,8 @@
 import pytest
 import torch
+from thunder.transforms.autocast import AutocastTransform
+from thunder import jit
 from torch._dynamo.eval_frame import is_inductor_supported

 import thunder
@@ -311,3 +313,21 @@ def foo(a, b, c, d):

     for eg, jg in zip(eager_grads, jit_grads):
         torch.testing.assert_close(eg, jg, rtol=5e-3, atol=5e-3)
+
+
+def simple_addition(x, y):
+    return x + y
+
+def test_autocast_transform():
+    autocast_transform = AutocastTransform(dtype=torch.bfloat16)
+    jitted_fn = jit(simple_addition, transforms=[autocast_transform])
+
+    x = torch.randn(2, 2, dtype=torch.float32)
+    y = torch.randn(2, 2, dtype=torch.float32)
+
+    result = jitted_fn(x, y)
+
+    assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}"
+
+    expected_result = simple_addition(x, y).to(torch.bfloat16)
+    assert torch.allclose(result, expected_result), "The output values do not match the expected results."
\ No newline at end of file
diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py
index fbe4f622f1..b9cf56a1fd 100644
--- a/thunder/transforms/autocast.py
+++ b/thunder/transforms/autocast.py
@@ -8,7 +8,7 @@
 from thunder.core.proxies import TensorProxy
 from thunder.core.symbol import BoundSymbolInterface, Symbol
 from thunder.core.proxies import TensorProxy
-from thunder.core.transforms import construct_trace, eval_trace
+from thunder.core.transforms import construct_trace, eval_trace,Transform
 from thunder.clang import (
     maybe_convert_to_dtype,
 )
@@ -309,3 +309,10 @@ def is_cpu_tensor(p):
         return wrapper

     return None
+
+class AutocastTransform(Transform):
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    def __call__(self, fn):
+        return autocast(fn, dtype=self.dtype)
\ No newline at end of file

From 34f4c6ec07ab09d6eadeaa6f9956b59c482e306d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 4 Dec 2024 10:11:26 +0000
Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 thunder/__init__.py            | 2 +-
 thunder/tests/test_autocast.py | 3 ++-
 thunder/transforms/autocast.py | 5 +++--
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index feb61ec924..01252201b5 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -332,7 +332,7 @@ def jit(
     if transforms is None:
         transforms = []
-
+
     if transforms:
         for transform in transforms:
             fn = transform(fn)
diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py
index 97f31bfcde..f42c59afc4 100644
--- a/thunder/tests/test_autocast.py
+++ b/thunder/tests/test_autocast.py
@@ -318,6 +318,7 @@ def foo(a, b, c, d):
 def simple_addition(x, y):
     return x + y

+
 def test_autocast_transform():
     autocast_transform = AutocastTransform(dtype=torch.bfloat16)
     jitted_fn = jit(simple_addition, transforms=[autocast_transform])

     x = torch.randn(2, 2, dtype=torch.float32)
     y = torch.randn(2, 2, dtype=torch.float32)

     result = jitted_fn(x, y)

     assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}"

     expected_result = simple_addition(x, y).to(torch.bfloat16)
-    assert torch.allclose(result, expected_result), "The output values do not match the expected results."
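A minimal usage sketch of the interface these first two patches set up. It is not part of the patches; it assumes the constraints added later in the series (the dtype must be a thunder dtype such as thunder.bfloat16, not torch.bfloat16) and uses a matmul, since only operations with a registered autocast rule are downcast:

import torch
import thunder
from thunder.transforms.autocast import AutocastTransform

def f(a, b, c):
    # matmul has an autocast rule; purely elementwise ops are left in float32
    return a @ (b + c)

jfn = thunder.jit(f, transforms=[AutocastTransform(dtype=thunder.bfloat16)])
a, b, c = (torch.randn(2, 2, dtype=torch.float32) for _ in range(3))
out = jfn(a, b, c)  # expected to be bfloat16 on CPU, mirroring the assertions in the tests of this series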
diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index b9cf56a1fd..bee9bb45c9 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -8,7 +8,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy -from thunder.core.transforms import construct_trace, eval_trace,Transform +from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( maybe_convert_to_dtype, ) @@ -310,9 +310,10 @@ def is_cpu_tensor(p): return None + class AutocastTransform(Transform): def __init__(self, dtype): self.dtype = dtype def __call__(self, fn): - return autocast(fn, dtype=self.dtype) \ No newline at end of file + return autocast(fn, dtype=self.dtype) From 0339275717ea0c8065be3b88789714e2ec79dd19 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 4 Dec 2024 23:27:43 +0530 Subject: [PATCH 03/19] [wip]: update __init__.py(for testing purpose) --- thunder/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 01252201b5..84d4e5bac8 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -335,7 +335,8 @@ def jit( if transforms: for transform in transforms: - fn = transform(fn) + if isinstance(transform, AutocastTransform): + fn = transform(fn) # Resolve names of executors executors = resolve_executors(executors) From 3982fda39d05289ff36189e0b84dcbb327352ccc Mon Sep 17 00:00:00 2001 From: rittik9 Date: Fri, 6 Dec 2024 20:41:20 +0000 Subject: [PATCH 04/19] fix: apply suggestions from review --- thunder/__init__.py | 6 --- thunder/tests/test_autocast.py | 71 +++++++++++++--------------------- thunder/transforms/autocast.py | 34 +++++++++++++--- 3 files changed, 56 insertions(+), 55 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 84d4e5bac8..e3aa5ac2aa 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -84,7 +84,6 @@ import thunder.executors.pythonex import thunder.executors.torchex import thunder.executors.nvfuserex -from thunder.transforms.autocast import AutocastTransform pythonex = extend.get_executor("python") assert pythonex is not None @@ -333,11 +332,6 @@ def jit( if transforms is None: transforms = [] - if transforms: - for transform in transforms: - if isinstance(transform, AutocastTransform): - fn = transform(fn) - # Resolve names of executors executors = resolve_executors(executors) ad_hoc_executor = extend.AdHocExecutor() diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index f42c59afc4..9b0c130dd0 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -3,8 +3,6 @@ import pytest import torch -from thunder.transforms.autocast import AutocastTransform -from thunder import jit from torch._dynamo.eval_frame import is_inductor_supported import thunder @@ -16,19 +14,9 @@ # TODO This test currently ignores the "should_autocast" argument enumerated in it -@instantiate( - dtypes=dtypes.float_math_dtypes, -) +@instantiate(dtypes=dtypes.float_math_dtypes) def test_thunder_autocast_transform(executor, device, dtype): - from thunder.transforms.autocast import autocast - - # TODO: Consider adding support for device specific dtypes in the test - # instantiator. 
- torch_device = torch.device(device) - if torch_device.type == "cpu" and dtype == dtypes.float16: - pytest.skip("float16 matmul is not supported on CPU.") - if torch_device.type == "cuda" and dtype == dtypes.bfloat16 and not thunder.tests.bf16.device_supports_bf16(device): - pytest.skip(f"bfloat16 is not supported on {torch.cuda.get_device_name()}") + from thunder.transforms.autocast import AutocastTransform def f(a, b, c): return a @ (b + c) @@ -40,29 +28,24 @@ def g(a, b, c): def h(a, b, c): return (a @ b) + c - torch_dtype = ltorch.to_torch_dtype(dtype) - if torch_device.type == "cpu": - autocast_dtypes = (thunder.bfloat16,) - elif torch_device.type == "cuda": - autocast_dtypes = ( - (thunder.bfloat16, thunder.float16) - if thunder.tests.bf16.device_supports_bf16(device) - else (thunder.float16,) + for func, should_autocast in ((f, True), (g, False), (h, False)): + dtype = thunder.bfloat16 if device == "cpu" else thunder.float16 + torch_dtype = ltorch.to_torch_dtype(dtype) + x, y, z = (torch.randn((2, 2), device=device, dtype=torch.float32) for _ in range(3)) + + # Use the new transform class + compiled = thunder.jit( + func, + transforms=[AutocastTransform(dtype=dtype)], + executors=executor.executors_list() ) - else: - pytest.fail(f"Invalid combination of parameters: {executor=}, {device=}, {dtype=}") - for (func, should_autocast), autocast_dtype in itertools.product( - ((f, True), (g, False), (h, True)), autocast_dtypes - ): - autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype) - x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3)) - initial_trace = thunder.trace()(autocast(func, dtype=autocast_dtype), x, y, z) - compiled = executor.make_callable(initial_trace.python_callable(), disable_torch_autograd=True) out = compiled(x, y, z) + traces = thunder.last_traces(compiled) + assert out.dtype == (torch_dtype if should_autocast else torch.float32), traces[-1] + # Compare with PyTorch autocast devicetype = torch.device(device).type - # note(crcrpar): This test could be broken in the future as thunder autocast develops. - with torch.autocast(device_type=devicetype, dtype=autocast_torch_dtype): + with torch.autocast(device_type=devicetype, dtype=torch_dtype): torch_output = func(x, y, z) assert out.dtype == torch_output.dtype @@ -315,20 +298,20 @@ def foo(a, b, c, d): torch.testing.assert_close(eg, jg, rtol=5e-3, atol=5e-3) -def simple_addition(x, y): - return x + y +# def simple_addition(x, y): +# return x + y -def test_autocast_transform(): - autocast_transform = AutocastTransform(dtype=torch.bfloat16) - jitted_fn = jit(simple_addition, transforms=[autocast_transform]) +# def test_autocast_transform(): +# autocast_transform = AutocastTransform(dtype=torch.bfloat16) +# jitted_fn = jit(simple_addition, transforms=[autocast_transform]) - x = torch.randn(2, 2, dtype=torch.float32) - y = torch.randn(2, 2, dtype=torch.float32) +# x = torch.randn(2, 2, dtype=torch.float32) +# y = torch.randn(2, 2, dtype=torch.float32) - result = jitted_fn(x, y) +# result = jitted_fn(x, y) - assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}" +# assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}" - expected_result = simple_addition(x, y).to(torch.bfloat16) - assert torch.allclose(result, expected_result), "The output values do not match the expected results." 
+# expected_result = simple_addition(x, y).to(torch.bfloat16) +# assert torch.allclose(result, expected_result), "The output values do not match the expected results." diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index bee9bb45c9..a35610b8d8 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -4,16 +4,19 @@ from collections.abc import Sequence from thunder.core import dtypes, prims, devices +from thunder.core.dtypes import dtype from thunder.core.pytree import tree_map, tree_flatten from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy +from thunder.core.trace import TraceCtx +from thunder.core.trace_interpreter import TraceSubstitutionProcessor from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( maybe_convert_to_dtype, ) import thunder.torch as ltorch - +import warnings autocast_impls: dict[prims.PrimIDs, Callable] = {} @@ -310,10 +313,31 @@ def is_cpu_tensor(p): return None - class AutocastTransform(Transform): - def __init__(self, dtype): + """Transform that enables autocasting operations to a specified dtype. + + Args: + dtype: The data type to which arguments could get cast if they are float32. + """ + def __init__(self, dtype: dtype): + super().__init__() + if not isinstance(dtype, dtype): + raise ValueError(f"`dtype` expected to be `thunder.dtype.dtype` but got {type(dtype)}") + _check_valid_autocast_dtype(dtype) self.dtype = dtype - def __call__(self, fn): - return autocast(fn, dtype=self.dtype) + def transform_traces_pre_prologue( + self, + prologue_trace: TraceCtx, + computation_trace: TraceCtx, + epilogue_trace: TraceCtx, + **kwargs + ) -> tuple[TraceCtx, TraceCtx, TraceCtx]: + processor = TraceSubstitutionProcessor( + computation_trace, + symbol_mapper=partial(autocast_symbol_mapper, dtype=self.dtype) + ) + new_computation_trace = processor.run() + new_computation_trace.set_provenance("Autocast Transform") + + return prologue_trace, new_computation_trace, epilogue_trace \ No newline at end of file From a0b6cfd57d170682237b1f8e1aa8ff594996264b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 20:42:49 +0000 Subject: [PATCH 05/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_autocast.py | 8 ++------ thunder/transforms/autocast.py | 17 +++++++---------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 9b0c130dd0..9a4cc11d87 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -32,13 +32,9 @@ def h(a, b, c): dtype = thunder.bfloat16 if device == "cpu" else thunder.float16 torch_dtype = ltorch.to_torch_dtype(dtype) x, y, z = (torch.randn((2, 2), device=device, dtype=torch.float32) for _ in range(3)) - + # Use the new transform class - compiled = thunder.jit( - func, - transforms=[AutocastTransform(dtype=dtype)], - executors=executor.executors_list() - ) + compiled = thunder.jit(func, transforms=[AutocastTransform(dtype=dtype)], executors=executor.executors_list()) out = compiled(x, y, z) traces = thunder.last_traces(compiled) assert out.dtype == (torch_dtype if should_autocast else torch.float32), traces[-1] diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 
a35610b8d8..f5d809ce4c 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -313,12 +313,14 @@ def is_cpu_tensor(p): return None + class AutocastTransform(Transform): """Transform that enables autocasting operations to a specified dtype. - + Args: dtype: The data type to which arguments could get cast if they are float32. """ + def __init__(self, dtype: dtype): super().__init__() if not isinstance(dtype, dtype): @@ -327,17 +329,12 @@ def __init__(self, dtype: dtype): self.dtype = dtype def transform_traces_pre_prologue( - self, - prologue_trace: TraceCtx, - computation_trace: TraceCtx, - epilogue_trace: TraceCtx, - **kwargs + self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx, **kwargs ) -> tuple[TraceCtx, TraceCtx, TraceCtx]: processor = TraceSubstitutionProcessor( - computation_trace, - symbol_mapper=partial(autocast_symbol_mapper, dtype=self.dtype) + computation_trace, symbol_mapper=partial(autocast_symbol_mapper, dtype=self.dtype) ) new_computation_trace = processor.run() new_computation_trace.set_provenance("Autocast Transform") - - return prologue_trace, new_computation_trace, epilogue_trace \ No newline at end of file + + return prologue_trace, new_computation_trace, epilogue_trace From fe97f4abbd38df2b28833cc3865689ba2368696d Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 7 Dec 2024 03:33:23 +0530 Subject: [PATCH 06/19] Update autocast.py --- thunder/transforms/autocast.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index f5d809ce4c..42fe5f02f5 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -4,7 +4,6 @@ from collections.abc import Sequence from thunder.core import dtypes, prims, devices -from thunder.core.dtypes import dtype from thunder.core.pytree import tree_map, tree_flatten from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol @@ -230,6 +229,11 @@ def autocast(func: Callable, dtype: dtypes.dtype): Returns: Callable: The transformed function """ + warnings.warn( + "autocast() is deprecated. Use thunder.jit(func, transforms=[AutocastTransform(dtype)]) instead", + DeprecationWarning, + stacklevel=2 + ) if not isinstance(dtype, dtypes.dtype): raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") @@ -321,9 +325,9 @@ class AutocastTransform(Transform): dtype: The data type to which arguments could get cast if they are float32. 
""" - def __init__(self, dtype: dtype): + def __init__(self, dtype: dtypes.dtype): super().__init__() - if not isinstance(dtype, dtype): + if not isinstance(dtype, dtypes.dtype): raise ValueError(f"`dtype` expected to be `thunder.dtype.dtype` but got {type(dtype)}") _check_valid_autocast_dtype(dtype) self.dtype = dtype From f1c03fa1c4b053c7a0228fa0ca38c609a4557064 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 22:04:20 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/autocast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 42fe5f02f5..4df86aea39 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -232,7 +232,7 @@ def autocast(func: Callable, dtype: dtypes.dtype): warnings.warn( "autocast() is deprecated. Use thunder.jit(func, transforms=[AutocastTransform(dtype)]) instead", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) if not isinstance(dtype, dtypes.dtype): From 7b867a9b2ba07f96c49374f05400ca17d7909d07 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 7 Dec 2024 15:30:58 +0530 Subject: [PATCH 08/19] fix autocast.py --- thunder/transforms/autocast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 4df86aea39..a23bc5bd6d 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -338,7 +338,7 @@ def transform_traces_pre_prologue( processor = TraceSubstitutionProcessor( computation_trace, symbol_mapper=partial(autocast_symbol_mapper, dtype=self.dtype) ) - new_computation_trace = processor.run() + new_computation_trace, _ = processor() new_computation_trace.set_provenance("Autocast Transform") return prologue_trace, new_computation_trace, epilogue_trace From 6717472d1e2495dcf8160f409e1a242897a42a51 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 7 Dec 2024 16:20:53 +0530 Subject: [PATCH 09/19] Update autocast.py --- thunder/transforms/autocast.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index a23bc5bd6d..4ec4b10ac3 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -317,6 +317,32 @@ def is_cpu_tensor(p): return None +class AutocastTraceSubstitutionProcessor(TraceSubstitutionProcessor): + def __init__(self, trace, dtype): + super().__init__(trace) + self.dtype = dtype + + def process_bsym(self, bsym): + """Process a bound symbol for autocast transformation. + + This method is called by TraceSubstitutionProcessor.__call__ for each bound symbol. 
+ """ + # Get the autocast implementation for this symbol + autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) + + if autocast_impl is None: + # If no autocast rule exists, use the original symbol + args = tree_map(self.read, bsym.args) + kwargs = tree_map(self.read, bsym.kwargs) + result = bsym.sym(*args, **kwargs) + self.set_result(result) + return + + # Apply the autocast implementation + args = tree_map(self.read, bsym.args) + kwargs = tree_map(self.read, bsym.kwargs) + result = autocast_impl(*args, dtype=self.dtype, **kwargs) + self.set_result(result) class AutocastTransform(Transform): """Transform that enables autocasting operations to a specified dtype. @@ -335,10 +361,7 @@ def __init__(self, dtype: dtypes.dtype): def transform_traces_pre_prologue( self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx, **kwargs ) -> tuple[TraceCtx, TraceCtx, TraceCtx]: - processor = TraceSubstitutionProcessor( - computation_trace, symbol_mapper=partial(autocast_symbol_mapper, dtype=self.dtype) - ) - new_computation_trace, _ = processor() + processor = AutocastTraceSubstitutionProcessor(computation_trace, self.dtype) + new_computation_trace, outputs = processor() new_computation_trace.set_provenance("Autocast Transform") - return prologue_trace, new_computation_trace, epilogue_trace From 39fc0096cefa0a299644220cf66e48dd4db81b7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 10:51:31 +0000 Subject: [PATCH 10/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/autocast.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 4ec4b10ac3..1f34655ec2 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -317,6 +317,7 @@ def is_cpu_tensor(p): return None + class AutocastTraceSubstitutionProcessor(TraceSubstitutionProcessor): def __init__(self, trace, dtype): super().__init__(trace) @@ -324,12 +325,12 @@ def __init__(self, trace, dtype): def process_bsym(self, bsym): """Process a bound symbol for autocast transformation. - + This method is called by TraceSubstitutionProcessor.__call__ for each bound symbol. """ # Get the autocast implementation for this symbol autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) - + if autocast_impl is None: # If no autocast rule exists, use the original symbol args = tree_map(self.read, bsym.args) @@ -344,6 +345,7 @@ def process_bsym(self, bsym): result = autocast_impl(*args, dtype=self.dtype, **kwargs) self.set_result(result) + class AutocastTransform(Transform): """Transform that enables autocasting operations to a specified dtype. 
From b5b53b5f82ef76e4cfd288d28c478f0e78a37b82 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sat, 7 Dec 2024 16:23:13 +0530 Subject: [PATCH 11/19] Update test_autocast.py --- thunder/tests/test_autocast.py | 44 ++++++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 9a4cc11d87..e6ab951f28 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -17,35 +17,53 @@ @instantiate(dtypes=dtypes.float_math_dtypes) def test_thunder_autocast_transform(executor, device, dtype): from thunder.transforms.autocast import AutocastTransform + + torch_device = torch.device(device) + if torch_device.type == "cpu" and dtype == dtypes.float16: + pytest.skip("float16 matmul is not supported on CPU.") + if torch_device.type == "cuda" and dtype == dtypes.bfloat16 and not thunder.tests.bf16.device_supports_bf16(device): + pytest.skip(f"bfloat16 is not supported on {torch.cuda.get_device_name()}") def f(a, b, c): return a @ (b + c) - # The following functions needs to be updated as autocast_impls grows. def g(a, b, c): return a + b - c def h(a, b, c): return (a @ b) + c - for func, should_autocast in ((f, True), (g, False), (h, False)): - dtype = thunder.bfloat16 if device == "cpu" else thunder.float16 - torch_dtype = ltorch.to_torch_dtype(dtype) - x, y, z = (torch.randn((2, 2), device=device, dtype=torch.float32) for _ in range(3)) - - # Use the new transform class - compiled = thunder.jit(func, transforms=[AutocastTransform(dtype=dtype)], executors=executor.executors_list()) + torch_dtype = ltorch.to_torch_dtype(dtype) + if torch_device.type == "cpu": + autocast_dtypes = (thunder.bfloat16,) + elif torch_device.type == "cuda": + autocast_dtypes = ( + (thunder.bfloat16, thunder.float16) + if thunder.tests.bf16.device_supports_bf16(device) + else (thunder.float16,) + ) + else: + pytest.fail(f"Invalid combination of parameters: {executor=}, {device=}, {dtype=}") + + for (func, should_autocast), autocast_dtype in itertools.product( + ((f, True), (g, False), (h, True)), autocast_dtypes + ): + autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype) + x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3)) + + # Use AutocastTransform instead of autocast function + compiled = thunder.jit( + func, + transforms=[AutocastTransform(dtype=autocast_dtype)], + executors=executor.executors_list() + ) out = compiled(x, y, z) - traces = thunder.last_traces(compiled) - assert out.dtype == (torch_dtype if should_autocast else torch.float32), traces[-1] - # Compare with PyTorch autocast devicetype = torch.device(device).type - with torch.autocast(device_type=devicetype, dtype=torch_dtype): + with torch.autocast(device_type=devicetype, dtype=autocast_torch_dtype): torch_output = func(x, y, z) assert out.dtype == torch_output.dtype - @instantiate( executors=[TorchExecutor], dtypes=dtypes.float_math_dtypes, From 13b79736bd8169f051f3d9edca56df45db49e886 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 10:53:51 +0000 Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_autocast.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index e6ab951f28..1d80bf583c 100644 --- a/thunder/tests/test_autocast.py +++ 
b/thunder/tests/test_autocast.py @@ -17,7 +17,7 @@ @instantiate(dtypes=dtypes.float_math_dtypes) def test_thunder_autocast_transform(executor, device, dtype): from thunder.transforms.autocast import AutocastTransform - + torch_device = torch.device(device) if torch_device.type == "cpu" and dtype == dtypes.float16: pytest.skip("float16 matmul is not supported on CPU.") @@ -50,12 +50,10 @@ def h(a, b, c): ): autocast_torch_dtype = ltorch.to_torch_dtype(autocast_dtype) x, y, z = (torch.randn((2, 2), device=device, dtype=torch_dtype) for _ in range(3)) - + # Use AutocastTransform instead of autocast function compiled = thunder.jit( - func, - transforms=[AutocastTransform(dtype=autocast_dtype)], - executors=executor.executors_list() + func, transforms=[AutocastTransform(dtype=autocast_dtype)], executors=executor.executors_list() ) out = compiled(x, y, z) @@ -64,6 +62,7 @@ def h(a, b, c): torch_output = func(x, y, z) assert out.dtype == torch_output.dtype + @instantiate( executors=[TorchExecutor], dtypes=dtypes.float_math_dtypes, From bce7372ffad663ce51855699359bda50880d9991 Mon Sep 17 00:00:00 2001 From: rittik9 Date: Sat, 7 Dec 2024 19:00:05 +0000 Subject: [PATCH 13/19] update autocast.py --- thunder/transforms/autocast.py | 127 +++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 46 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 1f34655ec2..b72429aed7 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -8,8 +8,8 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy -from thunder.core.trace import TraceCtx -from thunder.core.trace_interpreter import TraceSubstitutionProcessor +from thunder.core.trace import TraceCtx,tracectx +from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( maybe_convert_to_dtype, @@ -317,53 +317,88 @@ def is_cpu_tensor(p): return None - -class AutocastTraceSubstitutionProcessor(TraceSubstitutionProcessor): - def __init__(self, trace, dtype): - super().__init__(trace) - self.dtype = dtype - - def process_bsym(self, bsym): - """Process a bound symbol for autocast transformation. - - This method is called by TraceSubstitutionProcessor.__call__ for each bound symbol. - """ - # Get the autocast implementation for this symbol - autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) - - if autocast_impl is None: - # If no autocast rule exists, use the original symbol - args = tree_map(self.read, bsym.args) - kwargs = tree_map(self.read, bsym.kwargs) - result = bsym.sym(*args, **kwargs) - self.set_result(result) - return - - # Apply the autocast implementation - args = tree_map(self.read, bsym.args) - kwargs = tree_map(self.read, bsym.kwargs) - result = autocast_impl(*args, dtype=self.dtype, **kwargs) - self.set_result(result) - - class AutocastTransform(Transform): - """Transform that enables autocasting operations to a specified dtype. - - Args: - dtype: The data type to which arguments could get cast if they are float32. - """ - + """Transform that applies automatic mixed precision (autocast) to eligible operations.""" + def __init__(self, dtype: dtypes.dtype): - super().__init__() + """Initialize the autocast transform. 
+ + Args: + dtype: The target dtype to cast eligible operations to (float16 or bfloat16) + """ if not isinstance(dtype, dtypes.dtype): - raise ValueError(f"`dtype` expected to be `thunder.dtype.dtype` but got {type(dtype)}") - _check_valid_autocast_dtype(dtype) + raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") + + if dtype not in _allowed_downcast_types: + raise ValueError( + f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" + ) self.dtype = dtype def transform_traces_pre_prologue( - self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx, **kwargs - ) -> tuple[TraceCtx, TraceCtx, TraceCtx]: - processor = AutocastTraceSubstitutionProcessor(computation_trace, self.dtype) - new_computation_trace, outputs = processor() - new_computation_trace.set_provenance("Autocast Transform") - return prologue_trace, new_computation_trace, epilogue_trace + self, + prologue_trace: TraceCtx, + computation_trace: TraceCtx, + epilogue_trace: TraceCtx | None, + **kwargs + ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: + """Transform the computation trace to apply autocast rules.""" + + class AutocastProcessor(TraceSubstitutionProcessor): + def __init__(self, trace, dtype, *args, **kwargs): + super().__init__(trace, *args, **kwargs) + self.dtype = dtype + + def process_bsym(self, bsym): + # Skip special symbols that shouldn't be processed + if bsym.sym.id in trace_interpreter_skip_list: + self.new_trace.bound_symbols.append(bsym.from_bsym()) + return + + # Check if symbol has an autocast implementation + autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) + + if autocast_impl is not None: + # Read the arguments with potential autocast conversion + args = tree_map(self.read, bsym.args) + kwargs = tree_map(self.read, bsym.kwargs) + + # Apply the autocast implementation + with disable_autocast(): + result = autocast_impl(*args, **kwargs, dtype=self.dtype) + + self.set_result(result) + else: + # No autocast rule, process normally + args = tree_map(self.read, bsym.args) + kwargs = tree_map(self.read, bsym.kwargs) + result = bsym.sym(*args, **kwargs) + self.set_result(result) + + # Add the bound symbol to new trace + new_bsym = bsym.from_bsym() + new_bsym.args = args + new_bsym.kwargs = kwargs + self.add_processed_bsyms([new_bsym]) + + # Process the computation trace + if computation_trace is not None: + processor = AutocastProcessor(computation_trace, self.dtype) + + # Get the actual args and kwargs from the kwargs dict + args = kwargs.get('args', ()) + kw = kwargs.get('kwargs', {}) + + with tracectx(processor.new_trace): + # Initialize the processor's environment with input arguments + for trace_arg, arg in zip(computation_trace.args, args): + processor.env[trace_arg.name] = arg + + # Initialize kwargs if any + for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()): + processor.env[trace_kwarg.name] = kwarg + + new_trace, _ = processor() + computation_trace = new_trace + + return prologue_trace, computation_trace, epilogue_trace \ No newline at end of file From e399bad5f392a9d0e43498db9f07e1e9772459fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 7 Dec 2024 19:01:11 +0000 Subject: [PATCH 14/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/autocast.py | 43 ++++++++++++++++------------------ 1 file changed, 20 
insertions(+), 23 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index b72429aed7..0d860cabca 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -8,7 +8,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy -from thunder.core.trace import TraceCtx,tracectx +from thunder.core.trace import TraceCtx, tracectx from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( @@ -317,18 +317,19 @@ def is_cpu_tensor(p): return None + class AutocastTransform(Transform): """Transform that applies automatic mixed precision (autocast) to eligible operations.""" - + def __init__(self, dtype: dtypes.dtype): """Initialize the autocast transform. - + Args: dtype: The target dtype to cast eligible operations to (float16 or bfloat16) """ if not isinstance(dtype, dtypes.dtype): raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") - + if dtype not in _allowed_downcast_types: raise ValueError( f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" @@ -336,19 +337,15 @@ def __init__(self, dtype: dtypes.dtype): self.dtype = dtype def transform_traces_pre_prologue( - self, - prologue_trace: TraceCtx, - computation_trace: TraceCtx, - epilogue_trace: TraceCtx | None, - **kwargs + self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: """Transform the computation trace to apply autocast rules.""" - + class AutocastProcessor(TraceSubstitutionProcessor): def __init__(self, trace, dtype, *args, **kwargs): super().__init__(trace, *args, **kwargs) self.dtype = dtype - + def process_bsym(self, bsym): # Skip special symbols that shouldn't be processed if bsym.sym.id in trace_interpreter_skip_list: @@ -357,24 +354,24 @@ def process_bsym(self, bsym): # Check if symbol has an autocast implementation autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) - + if autocast_impl is not None: # Read the arguments with potential autocast conversion args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) - + # Apply the autocast implementation with disable_autocast(): result = autocast_impl(*args, **kwargs, dtype=self.dtype) - + self.set_result(result) else: # No autocast rule, process normally - args = tree_map(self.read, bsym.args) + args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) result = bsym.sym(*args, **kwargs) self.set_result(result) - + # Add the bound symbol to new trace new_bsym = bsym.from_bsym() new_bsym.args = args @@ -384,21 +381,21 @@ def process_bsym(self, bsym): # Process the computation trace if computation_trace is not None: processor = AutocastProcessor(computation_trace, self.dtype) - + # Get the actual args and kwargs from the kwargs dict - args = kwargs.get('args', ()) - kw = kwargs.get('kwargs', {}) - + args = kwargs.get("args", ()) + kw = kwargs.get("kwargs", {}) + with tracectx(processor.new_trace): # Initialize the processor's environment with input arguments for trace_arg, arg in zip(computation_trace.args, args): processor.env[trace_arg.name] = arg - + # Initialize kwargs if any for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()): 
processor.env[trace_kwarg.name] = kwarg - + new_trace, _ = processor() computation_trace = new_trace - return prologue_trace, computation_trace, epilogue_trace \ No newline at end of file + return prologue_trace, computation_trace, epilogue_trace From 7240b966134c1488f5da650eeb9b595afaa819eb Mon Sep 17 00:00:00 2001 From: rittik9 Date: Mon, 9 Dec 2024 12:38:36 +0000 Subject: [PATCH 15/19] refactor: autocast.py --- thunder/tests/test_autocast.py | 21 ++------------------- thunder/transforms/autocast.py | 21 +++------------------ 2 files changed, 5 insertions(+), 37 deletions(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 1d80bf583c..7c1b647a5f 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -27,6 +27,7 @@ def test_thunder_autocast_transform(executor, device, dtype): def f(a, b, c): return a @ (b + c) + # The following functions needs to be updated as autocast_impls grows. def g(a, b, c): return a + b - c @@ -58,6 +59,7 @@ def h(a, b, c): out = compiled(x, y, z) devicetype = torch.device(device).type + # note(crcrpar): This test could be broken in the future as thunder autocast develops. with torch.autocast(device_type=devicetype, dtype=autocast_torch_dtype): torch_output = func(x, y, z) assert out.dtype == torch_output.dtype @@ -309,22 +311,3 @@ def foo(a, b, c, d): for eg, jg in zip(eager_grads, jit_grads): torch.testing.assert_close(eg, jg, rtol=5e-3, atol=5e-3) - - -# def simple_addition(x, y): -# return x + y - - -# def test_autocast_transform(): -# autocast_transform = AutocastTransform(dtype=torch.bfloat16) -# jitted_fn = jit(simple_addition, transforms=[autocast_transform]) - -# x = torch.randn(2, 2, dtype=torch.float32) -# y = torch.randn(2, 2, dtype=torch.float32) - -# result = jitted_fn(x, y) - -# assert result.dtype == torch.bfloat16, f"Expected dtype: bfloat16, but got: {result.dtype}" - -# expected_result = simple_addition(x, y).to(torch.bfloat16) -# assert torch.allclose(result, expected_result), "The output values do not match the expected results." 
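The test restored above encodes which patterns are eligible: outputs that flow through a matmul are downcast, while purely elementwise chains are not. A standalone sketch of that expectation, assuming CPU float32 inputs and thunder.bfloat16 as the target dtype (the expected dtypes follow the test's should_autocast flags):

import torch
import thunder
from thunder.transforms.autocast import AutocastTransform

def with_matmul(a, b, c):
    return a @ (b + c)

def elementwise_only(a, b, c):
    return a + b - c

a, b, c = (torch.randn(2, 2, dtype=torch.float32) for _ in range(3))
for fn, expected in ((with_matmul, torch.bfloat16), (elementwise_only, torch.float32)):
    out = thunder.jit(fn, transforms=[AutocastTransform(dtype=thunder.bfloat16)])(a, b, c)
    print(fn.__name__, out.dtype, "expected", expected)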
diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 0d860cabca..545f731e26 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -347,55 +347,40 @@ def __init__(self, trace, dtype, *args, **kwargs): self.dtype = dtype def process_bsym(self, bsym): - # Skip special symbols that shouldn't be processed if bsym.sym.id in trace_interpreter_skip_list: self.new_trace.bound_symbols.append(bsym.from_bsym()) return - # Check if symbol has an autocast implementation autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) if autocast_impl is not None: - # Read the arguments with potential autocast conversion args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) - # Apply the autocast implementation with disable_autocast(): result = autocast_impl(*args, **kwargs, dtype=self.dtype) self.set_result(result) else: - # No autocast rule, process normally args = tree_map(self.read, bsym.args) kwargs = tree_map(self.read, bsym.kwargs) result = bsym.sym(*args, **kwargs) self.set_result(result) - # Add the bound symbol to new trace new_bsym = bsym.from_bsym() new_bsym.args = args new_bsym.kwargs = kwargs self.add_processed_bsyms([new_bsym]) - # Process the computation trace if computation_trace is not None: processor = AutocastProcessor(computation_trace, self.dtype) - # Get the actual args and kwargs from the kwargs dict args = kwargs.get("args", ()) kw = kwargs.get("kwargs", {}) - with tracectx(processor.new_trace): - # Initialize the processor's environment with input arguments - for trace_arg, arg in zip(computation_trace.args, args): - processor.env[trace_arg.name] = arg + processor.process_args(*args, **kw) - # Initialize kwargs if any - for trace_kwarg, kwarg in zip(computation_trace.kwargs.values(), kw.values()): - processor.env[trace_kwarg.name] = kwarg - - new_trace, _ = processor() - computation_trace = new_trace + new_trace, outputs = processor() + computation_trace = new_trace return prologue_trace, computation_trace, epilogue_trace From 80b15d5cdc1f82f2375efabb2530c76daa137f21 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 11 Dec 2024 16:33:43 +0530 Subject: [PATCH 16/19] Update autocast.py --- thunder/transforms/autocast.py | 216 +++++++++++++++++++++++++-------- 1 file changed, 165 insertions(+), 51 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index 545f731e26..dc99ae085a 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -8,7 +8,7 @@ from thunder.core.proxies import TensorProxy from thunder.core.symbol import BoundSymbolInterface, Symbol from thunder.core.proxies import TensorProxy -from thunder.core.trace import TraceCtx, tracectx +from thunder.core.trace import TraceCtx, TraceProvenance, from_trace, set_tracectx, get_tracectx, reset_tracectx from thunder.core.trace_interpreter import TraceSubstitutionProcessor, trace_interpreter_skip_list from thunder.core.transforms import construct_trace, eval_trace, Transform from thunder.clang import ( @@ -16,10 +16,10 @@ ) import thunder.torch as ltorch import warnings +from contextlib import contextmanager autocast_impls: dict[prims.PrimIDs, Callable] = {} - # NOTE: Rules which are registered ltorch symbols should match the type signature # of those symbols as we use this rule for translating from `torch` -> `thunder.torch` # if autocast is enabled while jitting. See also `NOTE: torch.autocast support`. 
@@ -318,69 +318,183 @@ def is_cpu_tensor(p): return None -class AutocastTransform(Transform): - """Transform that applies automatic mixed precision (autocast) to eligible operations.""" +# class AutocastTransform(Transform): +# """Transform that applies automatic mixed precision (autocast) to eligible operations.""" - def __init__(self, dtype: dtypes.dtype): - """Initialize the autocast transform. +# def __init__(self, dtype: dtypes.dtype): +# """Initialize the autocast transform. - Args: - dtype: The target dtype to cast eligible operations to (float16 or bfloat16) - """ - if not isinstance(dtype, dtypes.dtype): - raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") +# Args: +# dtype: The target dtype to cast eligible operations to (float16 or bfloat16) +# """ +# if not isinstance(dtype, dtypes.dtype): +# raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") - if dtype not in _allowed_downcast_types: - raise ValueError( - f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" - ) - self.dtype = dtype +# if dtype not in _allowed_downcast_types: +# raise ValueError( +# f"autocast: `dtype` is expected to be either `thunder.float16` or `thunder.bfloat16`, but {dtype}" +# ) +# self.dtype = dtype + +# def transform_traces_pre_prologue( +# self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs +# ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: +# """Transform the computation trace to apply autocast rules.""" - def transform_traces_pre_prologue( - self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs - ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: - """Transform the computation trace to apply autocast rules.""" +# class AutocastProcessor(TraceSubstitutionProcessor): +# def __init__(self, trace, dtype, *args, **kwargs): +# super().__init__(trace, *args, **kwargs) +# self.dtype = dtype - class AutocastProcessor(TraceSubstitutionProcessor): - def __init__(self, trace, dtype, *args, **kwargs): - super().__init__(trace, *args, **kwargs) - self.dtype = dtype +# def process_bsym(self, bsym): +# if bsym.sym.id in trace_interpreter_skip_list: +# self.new_trace.bound_symbols.append(bsym.from_bsym()) +# return - def process_bsym(self, bsym): - if bsym.sym.id in trace_interpreter_skip_list: - self.new_trace.bound_symbols.append(bsym.from_bsym()) - return +# autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) - autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) +# if autocast_impl is not None: +# args = tree_map(self.read, bsym.args) +# kwargs = tree_map(self.read, bsym.kwargs) - if autocast_impl is not None: - args = tree_map(self.read, bsym.args) - kwargs = tree_map(self.read, bsym.kwargs) +# with disable_autocast(): +# result = autocast_impl(*args, **kwargs, dtype=self.dtype) - with disable_autocast(): - result = autocast_impl(*args, **kwargs, dtype=self.dtype) +# self.set_result(result) +# else: +# args = tree_map(self.read, bsym.args) +# kwargs = tree_map(self.read, bsym.kwargs) +# result = bsym.sym(*args, **kwargs) +# self.set_result(result) - self.set_result(result) - else: - args = tree_map(self.read, bsym.args) - kwargs = tree_map(self.read, bsym.kwargs) - result = bsym.sym(*args, **kwargs) - self.set_result(result) +# new_bsym = bsym.from_bsym() +# new_bsym.args = args +# new_bsym.kwargs = kwargs +# self.add_processed_bsyms([new_bsym]) - new_bsym = bsym.from_bsym() - 
new_bsym.args = args - new_bsym.kwargs = kwargs - self.add_processed_bsyms([new_bsym]) +# if computation_trace is not None: +# processor = AutocastProcessor(computation_trace, self.dtype) - if computation_trace is not None: - processor = AutocastProcessor(computation_trace, self.dtype) +# args = kwargs.get("args", ()) +# kw = kwargs.get("kwargs", {}) - args = kwargs.get("args", ()) - kw = kwargs.get("kwargs", {}) +# processor.process_args(*args, **kw) - processor.process_args(*args, **kw) +# new_trace, outputs = processor() +# computation_trace = new_trace - new_trace, outputs = processor() - computation_trace = new_trace +# return prologue_trace, computation_trace, epilogue_trace + +class AutocastTransform(Transform): + def __init__(self, dtype: dtypes.dtype): + if not isinstance(dtype, dtypes.dtype): + raise ValueError(f"`dtype` is expected to be `thunder.dtype.dtype` but {type(dtype)}") + _check_valid_autocast_dtype(dtype) + self.dtype = dtype + self.env = {} + + @contextmanager + def trace_context(self, trace: TraceCtx): + """Context manager for handling trace context""" + token = set_tracectx(trace) + try: + yield + finally: + if token is not None: + reset_tracectx(token) + +def transform_traces_pre_prologue( + self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs +) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: + if computation_trace is None: return prologue_trace, computation_trace, epilogue_trace + + # Create new computation trace + new_computation_trace = from_trace(computation_trace) + + # Initialize environment with input tensors + for arg in computation_trace.args: + if isinstance(arg, TensorProxy): + self.env[arg.name] = arg + + # Process bound symbols within tracing context + with self.trace_context(new_computation_trace): + for bsym in computation_trace.bound_symbols: + if bsym.sym.id in trace_interpreter_skip_list: + new_bsym = bsym.from_bsym() + new_computation_trace.bound_symbols.append(new_bsym) + self._update_env(new_bsym) + continue + + autocast_impl = _maybe_get_autocast_rule_for_symbol(bsym.sym) + + if autocast_impl is not None: + # Process args with autocast + processed_args = [] + for arg in bsym.args: + if isinstance(arg, TensorProxy): + if arg.name in self.env: + tensor = self.env[arg.name] + if tensor.dtype in (dtypes.float32, dtypes.float64): + with self.trace_context(new_computation_trace): + processed_arg = maybe_downcast_to(self.dtype, tensor) + else: + processed_arg = tensor + else: + processed_arg = arg + else: + processed_arg = arg + processed_args.append(processed_arg) + + # Process kwargs with autocast + processed_kwargs = {} + for k, v in bsym.kwargs.items(): + if isinstance(v, TensorProxy): + if v.name in self.env: + tensor = self.env[v.name] + if tensor.dtype in (dtypes.float32, dtypes.float64): + with self.trace_context(new_computation_trace): + processed_kwargs[k] = maybe_downcast_to(self.dtype, tensor) + else: + processed_kwargs[k] = tensor + else: + processed_kwargs[k] = v + else: + processed_kwargs[k] = v + + # Apply autocast implementation within tracing context + with self.trace_context(new_computation_trace), disable_autocast(): + result = autocast_impl(*processed_args, **processed_kwargs, dtype=self.dtype) + + # Ensure the result is in the target dtype + if isinstance(result, TensorProxy) and result.dtype in (dtypes.float32, dtypes.float64): + with self.trace_context(new_computation_trace): + result = maybe_downcast_to(self.dtype, result) + + new_bsym = bsym.from_bsym( + args=processed_args, 
+ kwargs=processed_kwargs, + output=result + ) + else: + new_bsym = bsym.from_bsym() + # If this is the final operation, ensure output is in target dtype + if isinstance(new_bsym.output, TensorProxy) and new_bsym.output.dtype in (dtypes.float32, dtypes.float64): + with self.trace_context(new_computation_trace): + new_bsym.output = maybe_downcast_to(self.dtype, new_bsym.output) + + new_computation_trace.bound_symbols.append(new_bsym) + self._update_env(new_bsym) + + new_computation_trace.set_provenance(TraceProvenance("constructed by autocast")) + return prologue_trace, new_computation_trace, epilogue_trace + + def _update_env(self, bsym): + """Update environment mapping with bound symbol outputs""" + if isinstance(bsym.output, (tuple, list)): + for out in bsym.output: + if isinstance(out, TensorProxy): + self.env[out.name] = out + elif isinstance(bsym.output, TensorProxy): + self.env[bsym.output.name] = bsym.output From 5590c36767634e8d9518cd3b8d037c79c0c0b2c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:05:15 +0000 Subject: [PATCH 17/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/transforms/autocast.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/thunder/transforms/autocast.py b/thunder/transforms/autocast.py index dc99ae085a..aedf782735 100644 --- a/thunder/transforms/autocast.py +++ b/thunder/transforms/autocast.py @@ -20,6 +20,7 @@ autocast_impls: dict[prims.PrimIDs, Callable] = {} + # NOTE: Rules which are registered ltorch symbols should match the type signature # of those symbols as we use this rule for translating from `torch` -> `thunder.torch` # if autocast is enabled while jitting. See also `NOTE: torch.autocast support`. 
@@ -404,6 +405,7 @@ def trace_context(self, trace: TraceCtx): if token is not None: reset_tracectx(token) + def transform_traces_pre_prologue( self, prologue_trace: TraceCtx, computation_trace: TraceCtx, epilogue_trace: TraceCtx | None, **kwargs ) -> tuple[TraceCtx, TraceCtx, TraceCtx | None]: @@ -412,7 +414,7 @@ def transform_traces_pre_prologue( # Create new computation trace new_computation_trace = from_trace(computation_trace) - + # Initialize environment with input tensors for arg in computation_trace.args: if isinstance(arg, TensorProxy): @@ -472,15 +474,14 @@ def transform_traces_pre_prologue( with self.trace_context(new_computation_trace): result = maybe_downcast_to(self.dtype, result) - new_bsym = bsym.from_bsym( - args=processed_args, - kwargs=processed_kwargs, - output=result - ) + new_bsym = bsym.from_bsym(args=processed_args, kwargs=processed_kwargs, output=result) else: new_bsym = bsym.from_bsym() # If this is the final operation, ensure output is in target dtype - if isinstance(new_bsym.output, TensorProxy) and new_bsym.output.dtype in (dtypes.float32, dtypes.float64): + if isinstance(new_bsym.output, TensorProxy) and new_bsym.output.dtype in ( + dtypes.float32, + dtypes.float64, + ): with self.trace_context(new_computation_trace): new_bsym.output = maybe_downcast_to(self.dtype, new_bsym.output) From 4ecb15b3c8744c3ae67f65ff79abe37324fb2e08 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 11 Dec 2024 16:35:46 +0530 Subject: [PATCH 18/19] Update test_autocast.py --- thunder/tests/test_autocast.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 7c1b647a5f..eb50d7caa8 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -17,7 +17,9 @@ @instantiate(dtypes=dtypes.float_math_dtypes) def test_thunder_autocast_transform(executor, device, dtype): from thunder.transforms.autocast import AutocastTransform - + + # TODO: Consider adding support for device specific dtypes in the test + # instantiator. torch_device = torch.device(device) if torch_device.type == "cpu" and dtype == dtypes.float16: pytest.skip("float16 matmul is not supported on CPU.") From 7c8635e8a49571c4e4aa93f678a34c1b0e0a37e5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:06:45 +0000 Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_autocast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index eb50d7caa8..5011ff0bc6 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -17,7 +17,7 @@ @instantiate(dtypes=dtypes.float_math_dtypes) def test_thunder_autocast_transform(executor, device, dtype): from thunder.transforms.autocast import AutocastTransform - + # TODO: Consider adding support for device specific dtypes in the test # instantiator. torch_device = torch.device(device)