From b93ff1d400363b12143cdce9267f01bb6ee71a70 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 14 Nov 2024 19:14:11 +0200 Subject: [PATCH 01/14] =?UTF-8?q?Improve=20no=5Fautocast=20overhead=20from?= =?UTF-8?q?=203.1=20=C2=B5s=20to=200.5=20=C2=B5s=20(6x=20improvement)=20(#?= =?UTF-8?q?1271)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- thunder/executors/torchex.py | 37 +++++++++++++++++++++++++++++++--- thunder/tests/test_autocast.py | 33 +++++++----------------------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 7b363606e8..f354e9abf0 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -148,9 +148,40 @@ def _device_put_transform(a: TensorProxy, device: devices.Device) -> TensorProxy def no_autocast(fn): - fn = torch.autocast(device_type="cpu", enabled=False, cache_enabled=False)(fn) - fn = torch.autocast(device_type="cuda", enabled=False, cache_enabled=False)(fn) - return fn + """ + A decorator that disables torch.autocast for the duration of the decorated + function. + + In Thunder this is useful when you want to ensure that the generated + function is not run with PyTorch's autocast enabled to execute exactly as + generated. + + Args: + fn: The function to decorate. + + Returns: + The decorated function. + """ + # This decorator intentionally does not use the torch.autocast decorator + # because it is much slower than the implementation here. This is because + # the torch.autocast decorator has a lot more overhead to support various + # features that are not needed in Thunder. + from torch import set_autocast_enabled + + prev_cpu = torch.is_autocast_cpu_enabled() + prev = torch.is_autocast_enabled() + + @wraps(fn) + def no_autocast_fn(*args, **kwargs): + try: + set_autocast_enabled("cpu", False) + set_autocast_enabled("cuda", False) + return fn(*args, **kwargs) + finally: + set_autocast_enabled("cpu", prev_cpu) + set_autocast_enabled("cuda", prev) + + return no_autocast_fn # diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 2abab28d34..2a01959f4e 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -85,11 +85,10 @@ def func(): trace = thunder.trace()(func) python_callable = trace.python_callable() - # 3 unwraps for: + # 2 unwraps for: # @no_grad() - # @autocast(device_type="cpu", ...) - # @autocast(device_type="cuda", ...) - cfunc = python_callable.__wrapped__.__wrapped__.__wrapped__ + # @no_autocast + cfunc = python_callable.__wrapped__.__wrapped__ b1, b2 = python_callable() assert b1 is False assert b2 is False @@ -107,10 +106,10 @@ def func(): with torch.autocast(device_type=devicetype, dtype=test_dtype): b1, b2 = python_callable() b3, b4 = cfunc() - assert b1 is False - assert b2 is False - assert b3 is (True if torch_device.type == "cuda" else False) - assert b4 is (True if torch_device.type == "cpu" else False) + assert not b1 + assert not b2 + assert not b3 + assert not b4 @instantiate( @@ -141,24 +140,6 @@ def func(a, b): assert output.dtype == (torch.float16 if torch_device.type == "cuda" else torch.bfloat16) -# Disabling on windows temporarily, until our windows runners source the -# appropriate visual studio config. -@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported") -def test_torch_compile_autocast(): - """Checks if our autocast decorator plays well with ``torch.compile``""" - - @no_autocast - def fn(x, y): - return x + y - - a = torch.randn(2, 2) - b = torch.randn(2, 2) - cfn = torch.compile(fn, fullgraph=True) - actual = cfn(a, b) - expected = a + b - torch.testing.assert_close(actual, expected) - - def test_autocast_mixed_dtype_inputs(): def foo(x, w): return torch.nn.functional.linear(x, w) From 4914141fa1b077f4fa230670b8699ee052c18c67 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Fri, 15 Nov 2024 05:02:45 +0900 Subject: [PATCH 02/14] remove `langctx` from `__all__` (#1443) --- thunder/core/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 699750b3ea..635d0e591e 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -71,8 +71,6 @@ # Helpful classes "OrderedSet", "FrozenDict", - # Context-related functions and decorators - "langctx", ] T = TypeVar("T") From 6f50e52950ff1c9c14bf0758fbe85427151dca69 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 18 Nov 2024 18:21:15 +0900 Subject: [PATCH 03/14] Import updates of `utils`, `baseutils`, and `codeutils`. (#1444) --- thunder/core/baseutils.py | 59 ++++++++++++++++++++++++++++++++------- thunder/core/codeutils.py | 46 ++++++++++++++++++++---------- thunder/core/utils.py | 33 ++++++++++++---------- 3 files changed, 98 insertions(+), 40 deletions(-) diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py index 55284b3b12..42c63aafd6 100644 --- a/thunder/core/baseutils.py +++ b/thunder/core/baseutils.py @@ -2,24 +2,63 @@ # This feature is available in Python 3.7 and later. # This import (like all __future__ imports) must be at the beginning of the file. from __future__ import annotations +from collections.abc import Sequence from enum import Enum +from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType +from typing import TYPE_CHECKING +import collections.abc +import dis import functools +import inspect import os -import dis - -import sys -import collections.abc -from numbers import Number -from typing import Any, Type, Union, Optional, Tuple, List -from collections.abc import Callable -from collections.abc import Sequence -from types import MappingProxyType, ModuleType, CodeType, EllipsisType, FunctionType, MethodType import re -import inspect +import sys import torch import numpy as np +if TYPE_CHECKING: + from collections.abc import Callable + from numbers import Number + from typing import Any + + +__all__ = [ + "BoundSymbolInterface", + "NumberProxyInterface", + "ProxyInterface", + "SymbolInterface", + "TagBase", + "TensorProxyInterface", + "TermColors", + "TorchAutogradFunctionCtxProxyInterface", + "build_callable", + "check", + "check_type", + "check_types", + "check_valid_length", + "check_valid_shape", + "default_dataclass_params", + "extract_callable_name", + "fnprint", + "get_module", + "indent", + "init_colors", + "init_windows_terminal", + "is_base_printable", + "is_base_printable_literal", + "is_base_printable_type", + "is_base_printable_value", + "is_collection", + "print_base_printable", + "print_base_type", + "print_number", + "print_type", + "run_once", + "sequencify", + "warn_term_variable_once", +] + # # Common utilities importable by any other file diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 3834bbc66f..68e9892d8e 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -1,26 +1,40 @@ -from types import CodeType, FunctionType, MethodType, EllipsisType -from typing import List, Dict, Tuple, Set, Deque, Any, NamedTuple, Optional -from numbers import Number -from collections import deque -from collections.abc import Mapping, Sequence, Iterable, Callable -import inspect -from inspect import Parameter -import string -import functools +from __future__ import annotations from functools import partial +from inspect import Parameter +from typing import TYPE_CHECKING, NamedTuple +import dataclasses import dis +import functools +import inspect import linecache -import dataclasses import sys -import torch - import thunder.core.baseutils as baseutils from thunder.core.baseutils import ProxyInterface, check import thunder.core.dtypes as dtypes import thunder.core.devices as devices from thunder.core.pytree import tree_flatten, tree_unflatten +if TYPE_CHECKING: + from typing import Any + from collections.abc import Callable, Sequence + from thunder.core.trace import TraceCtx + + +__all__ = [ + "ContextObject", + "SigInfo", + "get_siginfo", + "get_source_line", + "indent_string", + "is_literal", + "is_printable", + "is_simple_printable_collection", + "module_shortname", + "prettyprint", + "to_printable", +] + # # Functions related to analyzing and printing functions and arguments # @@ -106,7 +120,7 @@ def is_literal(x: Any) -> bool: return True -def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | None]: +def _to_printable(tracectx: TraceCtx | None, x: Any) -> tuple[Any, tuple[str, Any] | None]: can_print, module_info = is_printable(x) if can_print: return x, module_info @@ -123,7 +137,7 @@ def _to_printable(tracectx: Optional, x: Any) -> tuple[Any, tuple[str, Any] | No # TODO Improve type annotations def to_printable( - trace: Optional, + trace: TraceCtx | None, x: Any, *, import_ctx: dict | None = None, @@ -302,7 +316,9 @@ def __repr__(self): # TODO Print the original signature's type annotations # TODO Maybe be clear about what inputs are const and what aren't? # TODO Improve this signature's type annotations - def prettyprint(self, *, trace: Optional = None, import_ctx: Optional = None, object_ctx=None) -> str: + def prettyprint( + self, *, trace: TraceCtx | None = None, import_ctx: Any | None = None, object_ctx: Any | None = None + ) -> str: def _arg_printer(name: str, has_default: bool, default: Any = None) -> str: # NOTE In this case the argument has a default value, like 'a' in foo(a=5) if has_default: diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 635d0e591e..8eab3c8f3f 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1,16 +1,16 @@ -import sys -import os +from __future__ import annotations +from collections import defaultdict, deque, UserDict +from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence, Mapping from enum import Enum -from functools import reduce, wraps -import itertools -from itertools import chain +from functools import reduce from numbers import Number -from typing import Any, overload, Generic, Optional, TypeVar, TYPE_CHECKING -from collections.abc import Callable -from collections.abc import Hashable, Iterable, Iterator, Sequence -from collections import defaultdict +from types import MappingProxyType +from typing import overload, Generic, TypeVar, TYPE_CHECKING +import itertools +import os from typing_extensions import Self +import torch import thunder.core.dtypes as dtypes from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map @@ -20,6 +20,9 @@ from thunder.core.trace import TraceCtx import thunder.core.prims as prims +if TYPE_CHECKING: + from typing import Any + # This file defines utilities that can be used when defining primitive operations. # This file depends on proxies.py and the dtypes submodule. @@ -729,17 +732,17 @@ def __len__(self) -> int: return len(self.d) # - - def __sub__(self, other: "_OrderedSet") -> Self: + def __sub__(self, other: _OrderedSet) -> Self: return self.__class__(k for k in self if k not in other) - def __and__(self, other: "_OrderedSet") -> Self: + def __and__(self, other: _OrderedSet) -> Self: return self.__class__(k for k in self if k in other) - def __or__(self, other: "_OrderedSet") -> Self: + def __or__(self, other: _OrderedSet) -> Self: return self.__class__(itertools.chain(self, other)) # NOTE: actual set signature is (self, *others) - def difference(self, other: "_OrderedSet") -> Self: + def difference(self, other: _OrderedSet) -> Self: return self - other def add(self, x: T | T1): @@ -753,7 +756,7 @@ def discard(self, x: T | T1): def issubset(self, other): return all((e in other) for e in self) - def union(self, *others: "Sequence[_OrderedSet]") -> Self: + def union(self, *others: Sequence[_OrderedSet]) -> Self: return self.__class__(itertools.chain(self, *others)) def update(self, x: Iterable[T | T1]) -> None: @@ -791,7 +794,7 @@ def __missing__(self, key: T) -> T1: if TYPE_CHECKING: _UserDictT = dict else: - _UserDictT = collections.UserDict + _UserDictT = UserDict class FrozenDict(_UserDictT[T, T1], Mapping[T, T1]): From b21378c2d93a7dd46ad11057796e9f418501eb1e Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 18 Nov 2024 11:23:47 +0200 Subject: [PATCH 04/14] Test HF's implementation of Phi 3 model (#1439) --- requirements/test.txt | 2 +- thunder/tests/test_networks.py | 31 +++++++++++++++++-------------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/requirements/test.txt b/requirements/test.txt index c9ccf66f84..9ea8bc8ab6 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -18,7 +18,7 @@ pandas # thunder/benchmarks/test_benchmark_litgpt.py xlsxwriter # thunder/benchmarks/test_benchmark_litgpt.py jsonargparse # thunder/benchmarks/benchmark_litgpt.py bitsandbytes==0.42.0 # fixed version! -transformers==4.43.3 # for test_networks.py +transformers==4.46.2 # for test_networks.py # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation diff --git a/thunder/tests/test_networks.py b/thunder/tests/test_networks.py index 36ba5c3cd3..a1a3f0ca3f 100644 --- a/thunder/tests/test_networks.py +++ b/thunder/tests/test_networks.py @@ -401,29 +401,31 @@ def test_thunderfx_mistral_nemo_small(): @thunder.tests.framework.requiresCUDA -def test_hf_qwen2(): +@pytest.mark.parametrize("model_id", ["Qwen/Qwen2.5-7B-Instruct", "microsoft/Phi-3-mini-128k-instruct"]) +def test_hf_for_nemo(model_id): from thunder.dynamo import ThunderCompiler - from transformers import Qwen2Config, Qwen2ForCausalLM - - # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json - configuration = Qwen2Config( - # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default - # config uses Multi-Head Attention - num_attention_heads=28, - num_key_value_heads=4, + from transformers import AutoConfig, AutoModelForCausalLM + + configuration = AutoConfig.from_pretrained( + model_id, # Scaled down for testing - hidden_size=56, vocab_size=16, + pad_token_id=15, max_position_embeddings=32, + num_hidden_layers=1, ) - configuration.num_hidden_layers = 1 + configuration.hidden_size = configuration.num_attention_heads with torch.device("cuda"): - model = Qwen2ForCausalLM(configuration).to(torch.bfloat16) + model = AutoModelForCausalLM.from_config(configuration).to(torch.bfloat16) # thunder.jit doesn't work with Qwen2, so we use torch.compile # https://github.com/Lightning-AI/lightning-thunder/issues/1405 + + # fullgraph=True used to work with transformers 4.45.2, but it doesn't work + # with 4.46.2 because of re.findall usage in the loss function + fullgraph = False backend = ThunderCompiler() - compiled_model = torch.compile(model, backend=backend, fullgraph=True) + compiled_model = torch.compile(model, backend=backend, fullgraph=fullgraph) input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda") ref_output = model(input_ids=input_ids, labels=input_ids) @@ -437,7 +439,8 @@ def test_hf_qwen2(): # https://github.com/Lightning-AI/lightning-thunder/issues/1407 torch.testing.assert_close(compiled_loss, ref_loss, rtol=1e-4, atol=1e-4) - assert len(backend.subgraph_infos) == 1, "Should have exactly 1 subgraph because of fullgraph=True" + if fullgraph: + assert len(backend.subgraph_infos) == 1, "Should have exactly 1 subgraph because of fullgraph=True" loss_grad = torch.randn_like(compiled_loss) grads_ref = torch.autograd.grad(ref_loss, model.parameters(), grad_outputs=loss_grad) From d60f85cf009793894e9329eca1ee543037237057 Mon Sep 17 00:00:00 2001 From: beverlylytle <57254617+beverlylytle@users.noreply.github.com> Date: Mon, 18 Nov 2024 10:25:41 +0100 Subject: [PATCH 05/14] Conditionally use prologue in test_vjp_correctness (#1438) --- thunder/tests/test_grad.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 6e683b09df..2f389a480f 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -266,7 +266,7 @@ def _dot(x, y): return sum([_tensor_dot(a, b) for a, b in zip(x, y)]) -def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False): +def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = False, prologue_required: bool = False): """Check that the vector-Jacobian product of a function is correct. Args: @@ -296,7 +296,13 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals u = tree_map(make, primals) - comp_f = thunder.jit(f, disable_torch_autograd=True) + # dirty little trick for speed: skip the prologue, however, the prologue is required when + # there are non-differentiable kwargs + jf = executor.make_callable(f, disable_torch_autograd=True) + if prologue_required: + comp_f = thunder.jit(f, disable_torch_autograd=True) + else: + comp_f = thunder.compile_data(jf).get_computation_and_inputs(*primals)[0].computation_fn outs_p, J_u = numerical_jvp(comp_f)(primals, u) @@ -304,7 +310,7 @@ def check_vjp(f, *primals, comp, executor="torch", set_compile_data: bool = Fals v = tree_map(make, outs_p) if set_compile_data: - with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(comp_f), None): + with thunder.core.compile_data.compile_data_and_stats(thunder.compile_data(jf), None): initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v) else: initial_trace_vjp_f = thunder.trace()(vjp(f), primals, v) @@ -364,8 +370,15 @@ def wrapper(*differentiable_args): return wrapper, filtered_args -def snippet_vjp_correctness(func, args, comp, executor, set_compile_data): - check_vjp(func, *args, comp=comp, executor=executor, set_compile_data=set_compile_data) +def snippet_vjp_correctness(func, args, comp, executor, set_compile_data, prologue_required): + check_vjp( + func, + *args, + comp=comp, + executor=executor, + set_compile_data=set_compile_data, + prologue_required=prologue_required, + ) # TODO Use the given comparator @@ -408,6 +421,7 @@ def test_vjp_correctness(op, device, dtype, executor, comp): comp, executor, "adaptive_avg_pool2d" in op.name, + len(sample.kwargs) != 0, ) if result is not None: return result From 11a32a4b9575acc6de670450e1c0d66cab541305 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Mon, 18 Nov 2024 10:43:40 +0100 Subject: [PATCH 06/14] support no_grad in thunder.jit (#1423) --- thunder/__init__.py | 5 +++ thunder/common.py | 4 +++ thunder/core/symbol.py | 17 ++++++++- thunder/core/transforms.py | 8 +++++ thunder/tests/test_core.py | 72 +++++++++++++++++++++++++++++++++----- thunder/torch/__init__.py | 20 +++++++++-- 6 files changed, 113 insertions(+), 13 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 5f9bd9f521..5f4b75d6e3 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -442,6 +442,11 @@ def get_computation_and_inputs(*args, **kwargs): # which seems to break the consistency of cache_info, leading to a failure in cache_info check. cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) + # Store the `is_grad_enabled` state of PyTorch. This is used by vjp transform + # to treat certain Symbols as constant. + cache_info["is_grad_enabled"] = pytorch.is_grad_enabled() + cd.is_grad_enabled = pytorch.is_grad_enabled() + # TODO RC1 Add module and function checks to prologue (make it a compile option) # Checks cache diff --git a/thunder/common.py b/thunder/common.py index bc5f370156..674cab65d8 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -221,6 +221,10 @@ def __init__( # State for pytorch autocast context managers. self.autocast_stack: AutocastStack = AutocastStack() + # State to query whether grad is enabled or disabled using + # torch.no_grad/torch.enable_grad/torch._C._set_grad_enabled + self.is_grad_enabled: bool = True + # # Gathers additional metadata # diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py index c34071c1cb..da6eca6ddd 100644 --- a/thunder/core/symbol.py +++ b/thunder/core/symbol.py @@ -21,7 +21,8 @@ from thunder.core.pytree import tree_flatten_with_dataclass, tree_unflatten, tree_map import thunder.core.dtypes as dtypes import thunder.core.devices as devices -from thunder.core.proxies import Proxy, NumberProxy, variableify, CollectionProxy +from thunder.core.proxies import Proxy, TensorProxy, NumberProxy, variableify, CollectionProxy, ProxyTag +from thunder.core.compile_data import get_compile_data from thunder.core.trace import ( get_tracectx, @@ -320,6 +321,20 @@ def __call__(self, *args, **kwargs): result = self.meta(*args, **kwargs) trace.pop_scope() + cd = get_compile_data() + if cd is not None and not cd.is_grad_enabled: + # If grad is disabled using `torch.no_grad` or `torch._C._set_grad_enabled(False)`, + # tag the results with `DETACHED_AUTOGRAD_GRAPH` which makes this Symbol a constant for + # vjp transform (applied later). + def tag_tensorproxy_output_as_detached(proxy): + if isinstance(proxy, TensorProxy): + # We need to remove name from trace, otherwise replace will return a proxy with new name. + trace.names.remove(proxy.name) + return proxy.replace(tags=(ProxyTag.DETACHED_AUTOGRAD_GRAPH,)) + return proxy + + result = tree_map(tag_tensorproxy_output_as_detached, result) + bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols) symbols_list = trace.peek_scope() diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 2d9d88cddf..7b09ef26b2 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -33,6 +33,7 @@ variableify, unvariableify, FutureTensorProxy, + ProxyTag, ) from thunder.core.compile_data import get_compile_data, get_compile_option from thunder.core.langctxs import langctx, Languages @@ -2485,10 +2486,17 @@ def is_constant_for_vjp(symbol: prims.Symbol) -> bool: bool: True if the symbol is constant, False otherwise. """ are_all_args_non_differentiable = not any(isinstance(arg, (FloatProxy, TensorProxy)) for arg in symbol.flat_args) + # Symbol's tag their output in `torch.no_grad` regions with `DETACHED_AUTOGRAD_GRAPH`. + # These are treated as constant for VJP. + # NOTE - `any(()) is False` + output_disconnected_from_graph = any( + ProxyTag.DETACHED_AUTOGRAD_GRAPH in o.tags for o in symbol.flat_outs if isinstance(o, TensorProxy) + ) return ( are_all_args_non_differentiable or symbol.are_all_args_constant or symbol.sym.id in nondifferentiable_vjp_symbols + or output_disconnected_from_graph ) diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 4330b5236a..29c491be08 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -2099,7 +2099,7 @@ def func(x): compiled = executor.make_callable(func) out = compiled(x) assert out is x - initial_trace_with_dce = thunder.last_traces(compiled)[3] + initial_trace_with_dce = thunder.last_traces(compiled)[4] assert "Constructed by Dead Code Elimination" in str(initial_trace_with_dce) assert len(initial_trace_with_dce.bound_symbols) == 2 assert initial_trace_with_dce.bound_symbols[0].sym.id == prims.PrimIDs.UNPACK_TRIVIAL @@ -2480,27 +2480,81 @@ def foo_error(args): def test_grad_ctx(): + # NOTE - This test would start failing if tags on Proxies are dropped + # as the computation under `no_grad` won't be treated as constant + # and grad won't match with PyTorch eager. + + # Test `enable_grad` on a function works correctly @torch.enable_grad() def foo1(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) - with pytest.warns(UserWarning, match="have no effect under thunder.jit"): - thunder.jit(foo1)(x).sum().backward() - + thunder.jit(foo1)(x).sum().backward() assert x.grad is not None + # Test `no_grad` on a function works correctly @torch.no_grad() def foo2(x): return x + 1 x = torch.randn(3, 3, requires_grad=True) - with pytest.warns(UserWarning, match="have no effect under thunder.jit"): - thunder.jit(foo2)(x).sum().backward() + thunder.jit(foo2)(x).sum().backward() + assert x.grad is None - # `torch.no_grad` has no effect on thunder's autodiff which determines whether to compute grad based on `requires_grad=True`. - # Thus when backward is called it computes grad for the input. - assert x.grad is not None + # Test `no_grad` ctx correctly disable gradient computation + def foo3(x): + with torch.no_grad(): + y = x * 3 + return x * 2 + y + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + foo3(x_ref).sum().backward() + thunder.jit(foo3)(x).sum().backward() + # Verify the gradients match + torch.testing.assert_close(x.grad, x_ref.grad) + + # Test nested `no_grad` and `enable_grad` + def foo4(x): + with torch.enable_grad(): + with torch.no_grad(): + y = x * 3 + z = x * 4 + return x * 2 + y + z + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + foo4(x_ref).sum().backward() + thunder.jit(foo4)(x).sum().backward() + # Verify the gradients match + torch.testing.assert_close(x.grad, x_ref.grad) + + def foo5(x): + return x * 2 + + x = torch.randn(3, 3, requires_grad=True) + with torch.no_grad(): + x_ref = x.clone() + x_ref.requires_grad_(True) + + jfoo = thunder.jit(foo5) + with torch.no_grad(): + o = jfoo(x) + assert o.grad_fn is None + assert thunder.cache_misses(jfoo) == 1 # First compilation + + # Running it out of `torch.no_grad`, should lead to recompile. + foo5(x_ref).sum().backward() + jfoo(x).sum().backward() + torch.testing.assert_close(x.grad, x_ref.grad) + assert thunder.cache_misses(jfoo) == 2 def test_serialize_trace(): diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index b216d2a684..a94ada1cc8 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -47,6 +47,7 @@ ListProxy, DictProxy, numberproxy, + ProxyTag, ) from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten from thunder.core.symbol import Symbol @@ -5238,11 +5239,24 @@ def torch_device(type: DeviceLike, index: int | None = None) -> devices.Device: register_function(torch.device, torch_device) -def _set_grad_enabled_with_warning(enabled: bool) -> None: - warnings.warn("torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect under thunder.jit") +# Tag to use on Proxies created in `no_grad` regions. +# VJP transform will treat BoundSymbol's whose output has these tags +# as constant. +ProxyTag.register_tag("DETACHED_AUTOGRAD_GRAPH") -register_function(torch._C._set_grad_enabled, _set_grad_enabled_with_warning) +# This is just a marker Symbol. `tag_no_grad_symbols_pass` pass uses these symbols +# to find the `no_grad` regions and mark the BoundSymbols within them as constant +# for VJP using the `DETACHED_AUTOGRAD_GRAPH` tag. +@torchsymbol(torch._C._set_grad_enabled, id="set_grad_enabled", tags=(prims.OpTags.CTX_MANAGER_ENTER_EXIT_OP,)) +def _set_grad_enabled_with_warning(enabled: bool) -> None: + cd = get_compile_data() + if cd is None: + warnings.warn( + "torch.enable_grad/torch.no_grad/torch._C._set_grad_enabled have no effect, use thunder.jit for correct behaviour" + ) + return + get_compile_data().is_grad_enabled = enabled def _unwrap_if_dead(tensor): From a5c523d3f0ba2f1131cac65419cb955e2110175a Mon Sep 17 00:00:00 2001 From: Ali Alshaarawy <45029495+ali-alshaar7@users.noreply.github.com> Date: Mon, 18 Nov 2024 08:33:07 -0500 Subject: [PATCH 07/14] `tree_flatten` supports named tuples (#1446) --- thunder/core/baseutils.py | 11 +++++++++++ thunder/core/interpreter.py | 12 +----------- thunder/core/pytree.py | 3 ++- thunder/tests/test_core.py | 11 +++++++++++ 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/thunder/core/baseutils.py b/thunder/core/baseutils.py index 42c63aafd6..08cd4fd159 100644 --- a/thunder/core/baseutils.py +++ b/thunder/core/baseutils.py @@ -210,6 +210,17 @@ def get_module(name: str) -> Any: return sys.modules[name] +def is_likely_from_collections_namedtuple(tuple_type): + from collections import namedtuple + + # Check if tuple_type code object is coming from namedtuple + return ( + hasattr(tuple_type, "__repr__") + and hasattr(tuple_type.__repr__, "__code__") + and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts + ) + + # # Functions related to printing and debugging # diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index bc40ce12b7..992d1e27ec 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -42,7 +42,7 @@ TracebackType, ) -from thunder.core.baseutils import Singleton, init_colors, extract_callable_name +from thunder.core.baseutils import Singleton, init_colors, extract_callable_name, is_likely_from_collections_namedtuple from thunder.core.codeutils import Positions @@ -2848,16 +2848,6 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): else: item_wrappers.append(wv) - def is_likely_from_collections_namedtuple(tuple_type): - from collections import namedtuple - - # Check if tuple_type code object is coming from namedtuple - return ( - hasattr(tuple_type, "__repr__") - and hasattr(tuple_type.__repr__, "__code__") - and tuple_type.__repr__.__code__ in namedtuple.__code__.co_consts - ) - # Construction of namedtuples may raise try: ures = tuple(w.value for w in item_wrappers) diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 6a3d3d3a76..8c92a38555 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -6,7 +6,7 @@ import torch import thunder.core.dtypes as dtypes import thunder.core.devices as devices -from thunder.core.baseutils import ProxyInterface +from thunder.core.baseutils import ProxyInterface, is_likely_from_collections_namedtuple from types import FunctionType OPTREE_NAMESPACE = "thunder" @@ -61,6 +61,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): torch.autograd.function.FunctionCtx, } and not isinstance(args, (ProxyInterface)) + and not is_likely_from_collections_namedtuple(args) and not dataclasses.is_dataclass(args) and not type(args).__module__.startswith("torch.return_types") ): diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 29c491be08..9acde1615b 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -662,6 +662,17 @@ def test_to_printable_not_collection(): assert inp is out +def test_to_printable_collection(): + from collections import namedtuple + + MyTuple = namedtuple("MyTuple", ["x", "y"]) + + inps = (MyTuple("abc", "def"),) + for inp in inps: + out = codeutils.to_printable(None, inp) + assert inp == out + + # # Type promotion tests # From 0205c737fbf039114c3ef816d7febd8a3fad03d2 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Mon, 18 Nov 2024 14:59:07 +0100 Subject: [PATCH 08/14] update codeowners (#1448) --- .github/CODEOWNERS | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index b75ae0a931..0273df53e2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,11 +2,15 @@ # These owners will be the default owners for everything in the repo. Unless a later match takes precedence, # @global-owner1, @global-owner2, and @global-owner3 will be requested for review when someone opens a pull request. -* @mruberry @lantiga @t-vi @carmocca + +# Thank you, our previous code owners for their service: +# @carmocca + +* @mruberry @lantiga @t-vi # CI/CD and configs -/.azure/ @borda @lantiga @t-vi @carmocca -/.github/ @borda @lantiga @t-vi @carmocca -/dockers/ @borda @lantiga @t-vi @carmocca -Makefile @borda @lantiga @t-vi @carmocca -*.yml @borda @lantiga @t-vi @carmocca +/.azure/ @borda @lantiga @t-vi +/.github/ @borda @lantiga @t-vi +/dockers/ @borda @lantiga @t-vi +Makefile @borda @lantiga @t-vi +*.yml @borda @lantiga @t-vi From 8d1637f0d16908cc08e62c92ae5a92ca4ab59552 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 19 Nov 2024 11:14:52 +0100 Subject: [PATCH 09/14] fix distributed tests with pt main (#1452) --- thunder/tests/distributed/helper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py index 03c711daeb..55a3e3cc51 100644 --- a/thunder/tests/distributed/helper.py +++ b/thunder/tests/distributed/helper.py @@ -2,6 +2,7 @@ from functools import partial from functools import wraps from typing import ClassVar, TYPE_CHECKING +import inspect import math import os import sys @@ -129,10 +130,14 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False): local_rank = self.rank % torch.cuda.device_count() torch.cuda.set_device(local_rank) os.environ["LOCAL_RANK"] = str(local_rank) + if "destroy_process_group" in inspect.signature(self.run_test).parameters: + run_test_kwargs = {"destroy_process_group": False} + else: + run_test_kwargs = {} torch.distributed.barrier() try: - self.run_test(test_name, pipe) + self.run_test(test_name, pipe, **run_test_kwargs) except Exception: raise finally: From c9bbc5e0dae375cac2f242b60163c7bfaa648802 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:51:30 +0100 Subject: [PATCH 10/14] [pre-commit.ci] pre-commit suggestions (#1449) --- .pre-commit-config.yaml | 4 ++-- README.md | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b235526420..0e0076115d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,7 +45,7 @@ repos: exclude: "examples" - repo: https://github.com/executablebooks/mdformat - rev: 0.7.18 + rev: 0.7.19 hooks: - id: mdformat additional_dependencies: @@ -71,7 +71,7 @@ repos: # args: ["--fix"] - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.1.0 + rev: v4.0.0-alpha.8 hooks: - id: prettier # https://prettier.io/docs/en/options.html#print-width diff --git a/README.md b/README.md index 33306f505e..fba6ab1c87 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ Thunder aims to be usable, understandable, and extensible.   -> \[!Note\] +> [!Note] > Lightning Thunder is in alpha. Feel free to get involved, but expect a few bumps along the way.   From 052bac347e3d77856cfe2185081fe568310e1021 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 19 Nov 2024 11:57:01 +0100 Subject: [PATCH 11/14] add extensible DebugOptions class (#1447) --- docs/source/reference/thunder.rst | 1 + thunder/__init__.py | 13 ++++--- thunder/common.py | 3 ++ thunder/core/jit_ext.py | 10 +++-- thunder/core/options.py | 65 +++++++++++++++++++++++++++++++ thunder/tests/test_core.py | 27 +++++++++++++ 6 files changed, 110 insertions(+), 9 deletions(-) diff --git a/docs/source/reference/thunder.rst b/docs/source/reference/thunder.rst index a041cdf787..162d09c5b7 100644 --- a/docs/source/reference/thunder.rst +++ b/docs/source/reference/thunder.rst @@ -21,6 +21,7 @@ Querying information on compiled functions and modules .. autosummary:: :toctree: generated/ + DebugOptions compile_data compile_stats last_traces diff --git a/thunder/__init__.py b/thunder/__init__.py index 5f4b75d6e3..54c94855dc 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -16,6 +16,7 @@ from thunder.core.options import ( CACHE_OPTIONS, SHARP_EDGES_OPTIONS, + DebugOptions, ) from thunder.core.trace import ( TraceResults, @@ -124,6 +125,7 @@ "nvfuser_executor", "pytorch_executor", # debugging functions + "DebugOptions", "set_execution_callback_file", "jit", "resolve_executors", @@ -275,7 +277,6 @@ def compile(fn: Callable, recipe: Recipe | None): # This function will replace compile() (below) before RC1 -# TODO RC1 Consider adding a debug_log parameter to control debug printing # TODO RC1 Consider renaming compile_options to additional_compile_options def jit( fn: Callable, @@ -287,7 +288,7 @@ def jit( cache: None | CACHE_OPTIONS | str = None, disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1 transforms: list[Transform] | None = None, - record_history: bool = False, + debug_options: DebugOptions | None = None, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -313,7 +314,9 @@ def jit( - ``"constant values"`` - require Tensors to be of the same shape, device, dtype etc., and integers and strings to match exactly, - ``"same input"`` - don't check, but just assume that a cached function works if it exists. - transforms: List of transforms to be applied. It should be an instance :class:`thunder.core.transforms.Transform`. Default: ``None`` + transforms: optional list of transforms to be applied. It should be a list of instances of :class:`thunder.core.transforms.Transform`. Default: ``None`` + + debug_options: optional :class:`thunder.DebugOptions` instance. See the doc string of :class:`DebugOptions` for supported debug options. Default: ``None`` """ if "executors_list" in compile_options: @@ -345,8 +348,6 @@ def jit( # TODO: sharp edge if lookasides are shadowed? executor_lookasides.update(ex._lookasides) - assert type(record_history) is bool - # TODO RC1 Refine the compile data option to remove unused options # TODO: refine options cd = CompileData( @@ -361,6 +362,7 @@ def jit( disable_preprocessing=True, compile_options=compile_options, executor_lookasides=executor_lookasides, + debug_options=debug_options, ) cs = CompileStats() @@ -529,7 +531,6 @@ def get_computation_and_inputs(*args, **kwargs): args, kwargs, ad_hoc_executor=ad_hoc_executor, - record_history=record_history, sharp_edges=cd.sharp_edges, ) prologue_trc = jit_results.prologue_trace diff --git a/thunder/common.py b/thunder/common.py index 674cab65d8..c5101b8c27 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -12,6 +12,7 @@ resolve_cache_option, SHARP_EDGES_OPTIONS, resolve_sharp_edges_option, + DebugOptions, ) from thunder.core.utils import check, is_collection, AutocastStack from thunder.core.pytree import tree_flatten, tree_map @@ -202,6 +203,7 @@ def __init__( compile_options: dict[str, Any] = {}, get_computation_and_inputs: Callable | None = None, executor_lookasides: dict[Callable, Callable] | None = None, + debug_options: DebugOptions | None = None, ): # Records whether we're using the thunder.jit() entrypoint or not # The thunder.jit() entrypoint introduces important architectural updates, @@ -261,6 +263,7 @@ def __init__( self.disable_preprocessing = disable_preprocessing self.disable_torch_autograd_support = disable_torch_autograd_support self.debug_log = debug_log + self.debug_options = DebugOptions() if debug_options is None else debug_options # TODO Consider validating that this dict has exclusively string keys self.compile_options = compile_options diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index c7ec43ad45..cf186b1ead 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -101,7 +101,7 @@ from thunder.core.codeutils import get_siginfo, SigInfo import thunder.core.prims as prims from thunder.common import transform_for_execution -from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS +from thunder.core.options import CACHE_OPTIONS, SHARP_EDGES_OPTIONS, DebugOptions from thunder.core.symbol import Symbol, BoundSymbol, is_traceable from thunder.extend import Executor @@ -1685,13 +1685,17 @@ def update_tags(proxy_swapmap: dict[Variable, Proxy]) -> None: new.tags.update(unvariableify(old).tags) +DebugOptions.register_option( + "record_interpreter_history", bool, False, "record interpreter history (use thunder.last_interpreter_log to access)" +) + + def thunder_general_jit( fn: Callable, args: tuple[Any, ...], kwargs: dict[str, Any], /, *, - record_history: bool = False, sharp_edges: SHARP_EDGES_OPTIONS, ad_hoc_executor, ) -> TraceResults: @@ -1732,7 +1736,7 @@ def thunder_general_jit( callbacks=general_jit_callbacks, with_provenance_tracking=True, uncacheable_classes=(torch.Tensor, int, float, str, NoneType), - record_history=record_history, + record_history=compile_data.debug_options.record_interpreter_history, ) with jit_ctx(ctx): diff --git a/thunder/core/options.py b/thunder/core/options.py index a9164bd3d6..efed1f9eed 100644 --- a/thunder/core/options.py +++ b/thunder/core/options.py @@ -133,3 +133,68 @@ def resolve_sharp_edges_option(x: Any, /) -> SHARP_EDGES_OPTIONS: ) return seo + + +class DebugOptions: + _defaults = {} + _docs = {} + + def __init__(self, **kwargs): + cls = self.__class__ + for k, default in self._defaults.items(): + v = kwargs.pop(k, default) + typ = cls.__annotations__[k] + if not isinstance(v, typ): + raise TypeError(f"{cls.__name__}.{k} needs to be of type {typ.__name__}") + setattr(self, k, v) + if kwargs: + unknown_args = ", ".join(f"{k}" for k in kwargs) + raise TypeError(f"unknown argument(s) for {cls.__name__}: {unknown_args}") + + @classmethod + def register_option(cls, name, typ, default, doc=""): + if hasattr(cls, name): + raise AttributeError(f"{cls.__name__}.{name} is already registered") + + assert isinstance(default, typ) + cls._defaults[name] = default + cls.__annotations__[name] = typ + cls._docs[name] = doc + setattr(cls, name, default) + cls._set_docstring() + + @classmethod + def _set_docstring(cls): + cls.__doc__ = f"""{cls.__name__}(**options) + options can be dynamically registered, currently registered ones are below + + Keyword Args: + {cls.list_options(docstr=True)} + """ + + @classmethod + def list_options(cls, docstr=False): + lines = [] + cls.__annotations__ # initialize annotations in cls.__dict__ + for name, default in sorted(cls._defaults.items()): + typ = cls.__annotations__[name] + doc = cls._docs[name] + lines.append(f"{name}: {typ.__name__}={default} {doc}") + + sep = "\n" if not docstr else "\n\n " + return sep.join(lines) + + def __repr__(self): + cls = self.__class__ + repr = [f"{cls.__name__}("] + for k, default in cls._defaults.items(): + v = getattr(self, k, default) + if v != default: + repr.append(f" {k}={v},") + repr.append(")") + if len(repr) <= 3: + return "".join(r.lstrip().rstrip(",") for r in repr) + return "\n".join(repr) + + +DebugOptions._set_docstring() diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py index 9acde1615b..bf6e9bd7d3 100644 --- a/thunder/tests/test_core.py +++ b/thunder/tests/test_core.py @@ -3086,3 +3086,30 @@ def fn(idx, weight): for ref in prev_iter_refs: assert ref() is None + + +def test_debug_options(): + from thunder import DebugOptions + import dill + + initial_state = dill.dumps(dict(DebugOptions.__dict__)) + print(DebugOptions.__dict__) + DebugOptions.register_option("test_option", bool, False, "Test Option") + + assert "Test Option" in DebugOptions.__doc__ + + do = DebugOptions(test_option=True) + assert do.test_option is True + + with pytest.raises(TypeError, match="test_option"): + do = DebugOptions(test_option=5) + + del DebugOptions._docs["test_option"] + del DebugOptions._defaults["test_option"] + del DebugOptions.__annotations__["test_option"] + del DebugOptions.test_option + + DebugOptions._set_docstring() + + print(DebugOptions.__dict__) + assert dill.dumps(dict(DebugOptions.__dict__)) == initial_state From a6175034226fda30bb125a3f74f2fe6edf9b4a6f Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 19 Nov 2024 12:05:54 +0100 Subject: [PATCH 12/14] thunderFX : pass to remove empty autocast regions (#1400) --- thunder/dynamo/compiler.py | 4 ++- thunder/dynamo/utils.py | 41 +++++++++++++++++++++++++++ thunder/tests/test_dynamo.py | 55 ++++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+), 1 deletion(-) diff --git a/thunder/dynamo/compiler.py b/thunder/dynamo/compiler.py index f55ee29657..d61aa10812 100644 --- a/thunder/dynamo/compiler.py +++ b/thunder/dynamo/compiler.py @@ -7,7 +7,7 @@ import torch from thunder.core.baseutils import run_once -from thunder.dynamo.utils import recompile_graph +from thunder.dynamo.utils import recompile_graph, remove_empty_autocast from thunder.dynamo.splitter import _splitter if TYPE_CHECKING: @@ -72,6 +72,8 @@ def __init__(self, **thunder_options): self._torch_compile = partial(torch.compile, **torch_inductor_options) def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): + gm = remove_empty_autocast(gm) + # Dynamo uses lazy generation of the underlying Python code, so we need to # force recompilation of the GraphModule before passing it to Thunder. recompile_graph(gm) diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index 8b4c690c0a..d434d02342 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -512,3 +512,44 @@ def checkpoint_converter(gm: torch.fx.GraphModule, sub_gm: torch.fx.GraphModule) else: function_module = getattr(gm, n.args[0].name) _checkpoint_function_converter(function_module) + + +def remove_empty_autocast(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Function to remove empty autocast regions from GraphModule. + + Dynamo can provide empty autocast regions in which case, it is more performant to remove them + from the graph than to compile them and pay the cost of calling a wrapped optimized function + which does nothing. + + Args: + graph_module: Graph module to which this pass is applied. + + """ + + empty_autocast_removed_graph_module = copy.deepcopy(graph_module) + + # Dummy init node. + prev_node = torch.fx.node.Node(graph_module.graph, "start_node", "call_function", lambda: None, None, None) + nodes_to_erase = [] + for node in empty_autocast_removed_graph_module.graph.nodes: + # As _enter_autocast and _exit_autocast functions map the regions created by context manager, + # previous `_enter_autocast` will always correspond with current `_exit_autocast`. + if ( + prev_node.target == torch.amp.autocast_mode._enter_autocast + and node.target == torch.amp.autocast_mode._exit_autocast + ): + # NOTE: Order of node being appended matters. + # The node to be erased has to have zero users. + # So, we remove `_exit_autocast` first (which consumes output from `_enter_autocast`) + # and then we can remove the corresponding `_enter_autocast`. + nodes_to_erase.append(node) + nodes_to_erase.append(prev_node) + + prev_node = node + + # Erase the marked nodes. + for node in nodes_to_erase: + empty_autocast_removed_graph_module.graph.erase_node(node) + + return empty_autocast_removed_graph_module diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index 2f9bb0d124..da9129dcbe 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -1,5 +1,6 @@ import pytest import warnings +import itertools import torch import torch.fx import torch.nn as nn @@ -515,6 +516,60 @@ def func(x): torch.testing.assert_close(actual_grad, expected_grad) +def test_empty_autocast(): + autocast_ops = (torch.amp.autocast_mode._enter_autocast, torch.amp.autocast_mode._exit_autocast) + + def _call_thunder_backend(fn, args): + backend = ThunderCompiler() + jf = torch.compile(backend=backend)(f) + jf(*args) + return backend + + # autocast region is removed + def f(): + with torch.autocast(dtype=torch.bfloat16, device_type="cpu"): + pass + return + + backend = _call_thunder_backend(f, ()) + assert all(node.target not in autocast_ops for node in backend.subgraph_infos[0].split_graph_module.graph.nodes) + + # Both autocast regions are removed + def f(x): + with torch.autocast(dtype=torch.bfloat16, device_type="cpu"): + pass + y = x @ x + with torch.autocast(dtype=torch.bfloat16, device_type="cpu"): + pass + return y + + x = torch.randn(3, 3) + backend = _call_thunder_backend(f, (x,)) + + all_nodes = itertools.chain( + backend.subgraph_infos[0].split_graph_module.graph.nodes, + backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes, + ) + assert all(node.target not in autocast_ops for node in all_nodes) + + # First autocast region is removed and second isn't + def f(x): + with torch.autocast(dtype=torch.bfloat16, device_type="cpu"): + pass + y = x @ x + with torch.autocast(dtype=torch.bfloat16, device_type="cpu"): + y = y @ y + return y + + x = torch.randn(3, 3) + backend = _call_thunder_backend(f, (x,)) + all_nodes = itertools.chain( + backend.subgraph_infos[0].split_graph_module.graph.nodes, + backend.subgraph_infos[0].split_graph_module.thunder_1.graph.nodes, + ) + assert sum(node.target in autocast_ops for node in all_nodes) == 2 + + # Sample command to run the benchmark using ThunderCompilerGraphBenchmarking # pytest thunder/tests/test_dynamo.py -k test_ThunderCompilerGraphBenchmarking_groupby --benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName' # For more details, see :class:`thunder.dynamo.compiler_graph_benchmark.ThunderCompilerGraphBenchmarking` From f206afa9ad76d8c862393325ab8832cf0cadb113 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Tue, 19 Nov 2024 12:24:30 +0100 Subject: [PATCH 13/14] fix: using te and fsdp leads to multiple device found error (#1453) --- thunder/benchmarks/benchmark_litgpt.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 2aa78cbe13..8bcaf575eb 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -114,11 +114,13 @@ def _resursively_swap_linear_layers_for_te(module: torch.nn.Module) -> None: if isinstance(m, torch.nn.Linear): has_bias = m.bias is not None - new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=device) + # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device) + new_linear = te.Linear(m.in_features, m.out_features, bias=has_bias, device=str(device)) setattr(module, n, new_linear) if swap_layernorm and isinstance(m, torch.nn.LayerNorm): - new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=device) + # Pass device as str (as there is a bug in TransformerEngine's handling of torch.device) + new_layernorm = te.LayerNorm(m.normalized_shape[0], eps=m.eps, device=str(device)) setattr(module, n, new_layernorm) initial_params_cnt = parameters_cnt(model) @@ -366,11 +368,6 @@ def __init__( self.model = self.init_model() print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.") - if self.use_te_fp8_autocast: - is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm" - swap_linear_layers_for_te(self.model, device, swap_layernorm=not is_wo_layernorm) - self.model.to(torch.bfloat16) - # Setup the distributed algorithm choices if distributed_first := (self.compile in ("eager", "inductor") or "dynamo" in self.compile): self.model = self.setup_distributed(self.model) @@ -407,8 +404,14 @@ def init_model(self): init_device = torch.device("meta") if self.distributed_mode in FSDP_MODES else self.device with init_device: model = GPT(self.config) - model.to(dtype=torch.bfloat16) + + # Handle fp8 related Linear layer swapping (for torchao or TransformerEngine) model = self._torchao_fp8_handler.convert_model_to_fp8(model) + if self.use_te_fp8_autocast: + is_wo_layernorm = self.low_precision_mode == "fp8-delayed-te-wo_layernorm" + swap_linear_layers_for_te(model, init_device, swap_layernorm=not is_wo_layernorm) + + model.to(dtype=torch.bfloat16) return model def setup_distributed(self, model): From 60f3ee1ec536ee8d6fdef503af54525e0a3978a4 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 19 Nov 2024 13:44:29 +0100 Subject: [PATCH 14/14] support graph-by-graph benchmarking for PyTorch native checkpointing (#1437) --- thunder/dynamo/compiler_graph_benchmark.py | 18 ++++++++++++ thunder/dynamo/splitter.py | 14 +++++++-- thunder/dynamo/utils.py | 20 +++++++------ thunder/tests/test_dynamo.py | 34 ++++++++++++++++++++++ 4 files changed, 74 insertions(+), 12 deletions(-) diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py index eafd30ce0e..ddb7f80e53 100644 --- a/thunder/dynamo/compiler_graph_benchmark.py +++ b/thunder/dynamo/compiler_graph_benchmark.py @@ -2,6 +2,7 @@ from itertools import chain from pytest_benchmark.fixture import BenchmarkFixture from typing import TYPE_CHECKING +from looseversion import LooseVersion import torch from thunder.dynamo import ThunderCompiler @@ -124,6 +125,23 @@ def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args): def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, torch.Tensor]): split_module = super().__call__(gm, sample_args) + + def has_checkpoint_node(g): + if g.find_nodes(op="call_function", target=torch.ops.higher_order.tag_activation_checkpoint): + return True + for n in g.nodes: + if n.op == "call_module" and has_checkpoint_node(getattr(g.owning_module, n.target).graph): + return True + return False + + if LooseVersion(torch.__version__) < LooseVersion("2.6.0"): + # NOTE: PyTorch 2.6 changes the structure of GraphModule when using activation checkpointing. + # It's hard to retrieve the example input tensor for the GraphModule contains checkpoint operator before PyTorch 2.6 + if has_checkpoint_node(split_module.graph): + raise RuntimeError( + "The benchmarking of the Torch activation checkpointing is only supported with PyTorch version 2.6 or later." + ) + compiled_functions_to_submodule = { v.compiled_fn: k for k, v in self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions.items() } diff --git a/thunder/dynamo/splitter.py b/thunder/dynamo/splitter.py index b123400ec7..b128357b97 100644 --- a/thunder/dynamo/splitter.py +++ b/thunder/dynamo/splitter.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING +import copy import torch from torch.fx.passes.split_module import split_module @@ -131,9 +132,10 @@ def callback(node) -> int: return partition_cnt # `split_module` iterates over nodes and determines the partition to place them based on the callback. - split_gm: torch.fx.GraphModule = split_module( + original_split_gm: torch.fx.GraphModule = split_module( gm, root_m=None, split_callback=callback, keep_original_order=True, keep_original_node_name=True ) + split_gm = copy.deepcopy(original_split_gm) def is_thunder_supported_partition(node: torch.fx.Node) -> bool: return node.name.startswith("submod") and int(node.name.replace("submod_", "")) in supported_partitions @@ -142,6 +144,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: thunder_compiled_fns = [] submodule_to_compiled_fns = {} for node in split_gm.graph.nodes: + node_name = node.name if is_thunder_supported_partition(node): graph_module = getattr(split_gm, node.name) # Replace PyTorch operators within the checkpointed function with the corresponding Thunder operators @@ -150,13 +153,17 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: # Update the node name from "submod_*" to "thunder_*" for more user-friendly names update_node_and_submodule(split_gm, node, node.name.replace("submod", "thunder"), jit_fn) thunder_compiled_fns.append(jit_fn) - submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.THUNDER) + submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction( + jit_fn, CompilerType.THUNDER + ) elif node.name.startswith("submod"): # For inductor graph_module = getattr(split_gm, node.name) jit_fn = torch_inductor(graph_module) # Update the node name from "submod_*" to "inductor_*" for more user-friendly names update_node_and_submodule(split_gm, node, node.name.replace("submod", "inductor"), jit_fn) - submodule_to_compiled_fns[graph_module] = CompiledFunction(jit_fn, CompilerType.TORCH_INDUCTOR) + submodule_to_compiled_fns[getattr(original_split_gm, node_name)] = CompiledFunction( + jit_fn, CompilerType.TORCH_INDUCTOR + ) else: # Everything else is a glue code to call and pass outputs between the other partitions. pass @@ -166,6 +173,7 @@ def is_thunder_supported_partition(node: torch.fx.Node) -> bool: return split_gm, SubgraphInfo( gm, + original_split_gm, split_gm, thunder_compiled_fns, submodule_to_compiled_fns, diff --git a/thunder/dynamo/utils.py b/thunder/dynamo/utils.py index d434d02342..668f2ef0bc 100644 --- a/thunder/dynamo/utils.py +++ b/thunder/dynamo/utils.py @@ -80,17 +80,21 @@ class SubgraphInfo: Attributes: original_graph_module: The original graph module. - split_graph_module: The graph module for the split subgraph. + original_split_graph_module: The original split graph module before any transformations are applied. + Specifically, before the :func:`checkpoint_converter` replaces the Torch operators with Thunder symbols, + and before any submodules are compiled by Thunder. + split_graph_module: The graph module for the split subgraph. It contains the compiled thunder/inductor modules. thunder_compiled_fns: List of thunder optimized callables. This could be :obj:`None` if there the graph module was not supported by thunder. Look at the :attr:`split_reasons` for further information. - submodule_to_compiled_functions: Dict from subgraph to compiled function. + submodule_to_compiled_functions: Dict from subgraph in :attr:`original_split_graph_module` to compiled function. This will be a dict with one pair in case the graph was not split. split_reasons: List of reasons explaining why the subgraph was split. Present only if there are was a split. """ original_graph_module: torch.fx.GraphModule + original_split_graph_module: torch.fx.GraphModule | None split_graph_module: torch.fx.GraphModule | None thunder_compiled_fns: list[Callable] | None submodule_to_compiled_functions: dict[torch.fx.GraphModule, CompiledFunction] @@ -466,8 +470,7 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule): Args: gm (torch.fx.GraphModule): The GraphModule of the checkpointed function, which is modified inplace. """ - new_graph = copy.deepcopy(gm.graph) - for n in new_graph.nodes: + for n in gm.graph.nodes: # replace the torch operator in "call_function" node if n.op == "call_function": assert isinstance(n.target, Callable) @@ -476,19 +479,18 @@ def _checkpoint_function_converter(gm: torch.fx.GraphModule): check( n.target in _torch_to_thunder_function_map, lambda: f"Unexpected {n.target}, not registered in Thunder" ) - with new_graph.inserting_before(n): - thunder_node = new_graph.call_function( + with gm.graph.inserting_before(n): + thunder_node = gm.graph.call_function( _torch_to_thunder_function_map[n.target], args=n.args, kwargs=n.kwargs ) n.replace_all_uses_with(thunder_node) - new_graph.erase_node(n) + gm.graph.erase_node(n) else: if n.op == "call_module": raise RuntimeError( "Unexpected call_module detected inside a checkpoint. This should have been inlined in dynamo graphs" ) - new_graph.lint() - gm.graph = new_graph + gm.graph.lint() recompile_graph(gm) diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index da9129dcbe..42299c149c 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -5,6 +5,7 @@ import torch.fx import torch.nn as nn import torch.nn.functional as F +from looseversion import LooseVersion from thunder import dtypes from thunder.dynamo import ThunderCompiler @@ -445,6 +446,10 @@ def func(x): IS_WINDOWS, reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094", ), + pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("2.6.0"), + reason="Skip until the Torch bug is fixed - https://github.com/pytorch/pytorch/pull/139275", + ), ), ) @requiresCUDA @@ -639,6 +644,35 @@ def f(x): compiled(x) +@pytest.mark.skipif( + LooseVersion(torch.__version__) < LooseVersion("2.6.0"), + reason="The checkpoint function becomes a submodule of the module containing `tag_activation_checkpoint` in PyTorch 2.6.0.", +) +@requiresCUDA +def test_ThunderCompilerGraphBenchmarking_checkpoint(benchmark): + class SimpleModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(10, 20) + + def forward(self, x): + x = torch.utils.checkpoint.checkpoint(self.layer1, x) + x = F.relu(x) + return x + + x = torch.randn(5, 10).cuda().requires_grad_() + model = SimpleModel().cuda().train() + + exe_backend = ThunderCompiler() + backend = ThunderCompilerGraphBenchmarking( + benchmark, executors={"inductor": torch.compile, "thunderfx": torch.compile(backend=exe_backend)} + ) + # Using torch.compile here fails with "TypeError: cannot pickle '_io.TextIOWrapper' object" in + # https://github.com/Lightning-AI/pytorch-lightning/blob/828fd998961f6a60f92c35254bb94d6e049ad069/src/lightning/fabric/wrappers.py#L421 + jf = torch._dynamo.optimize(backend=backend)(model) + out = jf(x) + + @requiresCUDA @pytest.mark.filterwarnings(r"ignore:`torch\.cpu\.amp\.autocast\((.*?)\)` is deprecated.*:FutureWarning") def test_checkpoint_converter():