diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml index a25c979b13..0470adf19f 100644 --- a/.azure/gpu-tests.yml +++ b/.azure/gpu-tests.yml @@ -84,7 +84,6 @@ jobs: pytest thunder/tests/ \ -m "not standalone" \ -v --datefmt="%Y%m%d-%H:%M:%S.%f" \ - --timeout=240 \ --random-order-seed=42 \ --durations=250 \ --timeout=240 \ @@ -97,7 +96,7 @@ jobs: ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ --flags=gpu,pytest,regular --name="GPU-coverage" --env=linux,azure condition: ne(variables['testing'], 'distributed') - timeoutInMinutes: "30" + timeoutInMinutes: "40" displayName: "Testing: regular" - bash: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be498f6cb9..73bdeaade1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -23,7 +23,7 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v3.17.0 + rev: v3.18.0 hooks: - id: pyupgrade args: ["--py310-plus"] @@ -38,14 +38,14 @@ repos: #args: ["--write-changes"] # uncomment if you want to get automatic fixing - repo: https://github.com/psf/black - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black name: Black code exclude: "examples" - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.18 hooks: - id: mdformat additional_dependencies: @@ -55,7 +55,7 @@ repos: exclude: "examples" - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint diff --git a/README.md b/README.md index bcac054cba..a18296adac 100644 --- a/README.md +++ b/README.md @@ -73,12 +73,18 @@ The easiest way to get started with Thunder, requiring no extra installations or ## Install Thunder -To use Thunder on your local machine: +Thunder is in alpha and the latest development is happening on the `main` branch. You can install the latest version of Thunder from the `main` branch as follows: -- install [nvFuser](https://github.com/NVIDIA/Fuser) and PyTorch stable together as follows: +```bash +pip install git+https://github.com/Lightning-AI/lightning-thunder.git@main +``` + +To achieve the best performance, you can install Thunder with the following additional dependencies: + +- install nightly [nvFuser](https://github.com/NVIDIA/Fuser) built for PyTorch 2.4 as follows: ```bash -# install nvFuser which installs the matching stable PyTorch +# install nvFuser built for the matching stable PyTorch pip install --pre nvfuser-cu121-torch24 ``` @@ -89,35 +95,12 @@ pip install --pre nvfuser-cu121-torch24 pip install nvidia-cudnn-frontend ``` -- Finally, install Thunder as follows: - -``` -# install thunder -pip install lightning-thunder -``` -
Advanced install options   -### Install from main - -Alternatively, you can install the latest version of Thunder directly from this GitHub repository as follows: - -``` -# 1) Install nvFuser and PyTorch dependencies: -pip install --pre nvfuser-cu121-torch24 -``` - -```bash -# 2) Install Thunder itself -pip install git+https://github.com/Lightning-AI/lightning-thunder.git -``` - -  - ### Install to tinker and contribute If you are interested in tinkering with and contributing to Thunder, we recommend cloning the Thunder repository and installing it in pip's editable mode: diff --git a/requirements/test.txt b/requirements/test.txt index f734c91ab6..d3fa09f645 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -23,3 +23,6 @@ transformers==4.43.3 # for test_networks.py # Installs JAX on Linux and MacOS jaxlib; sys_platform == 'linux' or sys_platform == 'darwin' # required for jax, see https://github.com/google/jax#installation jax; sys_platform == 'linux' or sys_platform == 'darwin' # for test_ops.py + +asvdb @ git+https://github.com/rapidsai/asvdb.git +asv >=0.6.4 diff --git a/thunder/benchmarks/conftest.py b/thunder/benchmarks/conftest.py new file mode 100644 index 0000000000..6251b2ae26 --- /dev/null +++ b/thunder/benchmarks/conftest.py @@ -0,0 +1,82 @@ +import os +import platform +import psutil +from typing import Any +import warnings +import importlib.util + + +def pytest_addoption(parser): + # CLI option to specify where to store the benchmark results in asv format. + # If not set or None, results won't be saved in asv. + parser.addoption("--asv_bench_dir", action="store", default=os.getenv("THUNDER_BENCH_DIR")) + + +def pytest_sessionfinish(session, exitstatus): + # Save result only if the pytest session was a benchmark. + if hasattr(session.config, "_benchmarksession"): + save_benchmark_results_asv(session.config) + + +def sanitize_params(benchmark_params: list[tuple[str, Any]]) -> list[tuple[str, Any]]: + """Util function that takes a list of params and removes serialization information. E.g. given '' returns 'torch_executor'.""" + sane_params = [] + for k, v in benchmark_params: + if k == "executor": + sane_params += [(k, str(v).split()[1])] + else: + sane_params += [(k, v)] + return sane_params + + +def save_benchmark_results_asv(config): + """Save the benchmark results after a pytest session in the asv format. + User must specify the --asv_bench_dir flag to store the results. + """ + + bench_dir = config.option.asv_bench_dir + + if not importlib.util.find_spec("asv"): + warnings.warn("asvdb is not available. Results won't be saved in asv format.") + return + + if not bench_dir: + warnings.warn("asv_bench_dir' is not set. Results won't be saved in asv format.") + return + + from asvdb import utils, ASVDb, BenchmarkResult, BenchmarkInfo + + benchmarks = config._benchmarksession.benchmarks + + # Get system information to store alongside the results. + uname = platform.uname() + commit_hash, commit_time = utils.getCommitInfo() + repo_name, current_branch = utils.getRepoInfo() + python_version = platform.python_version() + memory_size = str(psutil.virtual_memory().total) + + bench_info = BenchmarkInfo( + machineName=uname.machine, + osType=f"{uname.system} {uname.release}", + pythonVer=python_version, + commitHash=commit_hash, + commitTime=commit_time, + cpuType=uname.processor, + arch=uname.machine, + ram=memory_size, + ) + + # Create the asv result database. 
+ db = ASVDb(dbDir=bench_dir, repo=repo_name, branches=[current_branch]) + + # Add all the benchmarks to the database. + for bench in benchmarks: + name = bench.name.split("[")[0] + params_pairs = sanitize_params(bench.params.items()) + result = BenchmarkResult( + funcName=name, + argNameValuePairs=params_pairs, + result=bench.stats.median * 1e6, + unit="µseconds", + ) + db.addResult(bench_info, result) diff --git a/thunder/common.py b/thunder/common.py index 00292fea41..1c58e33445 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -303,7 +303,7 @@ def translate(x: Any, *, name: str | None = None) -> Any: if isinstance(x, Proxy): # register proxy name used by NumberProxies in TensorProxy.shape if isinstance(x, TensorProxy): - for s_p in filter(lambda s: isinstance(s, Proxy), x.shape): + for s_p in filter(lambda s: isinstance(s, Proxy), x._shape): # TODO need to avoid name conflict here, since s_p.name # could have conflicted with something defined earlier in # the trace. diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 5a6efa434e..4d0c239356 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -52,6 +52,7 @@ ) import torch +import torch.utils.checkpoint from thunder.core.proxies import ( DistParallelType, proxy, @@ -607,7 +608,6 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar So far, non-tensor ``ctx`` attributes seem to be folded into a trace. """ from thunder.core.baseutils import check, sequencify - from thunder.core.transforms import augmented_forward_impls, backward_impls, VJPDual jit_ctx: JitCtx = get_jit_ctx() @@ -629,14 +629,14 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar unwrapped_custom_forward_result = unwrap(custom_forward_result) # autograd.Function produces views of the tensor to attache the autograd node to unwrapped_custom_forward_result = tree_map( - lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x, unwrapped_custom_forward_result + lambda x: prims.shallow_copy(x) if isinstance(x, TensorProxy) else x, + unwrapped_custom_forward_result, ) custom_fwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope() # not augmented for when we don't need grad trace_of_fwd = TraceCtx() - for bsym in custom_fwd_bsyms: - trace_of_fwd.add_bound_symbol(bsym) + trace_of_fwd.bound_symbols.extend(custom_fwd_bsyms) with tracectx(trace_of_fwd): prims.python_return(unwrapped_custom_forward_result) @@ -647,20 +647,9 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar def core_of_forward(*args, **kwargs): return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs) - def custom_forward_meta(*args, **kwargs): - trace = thunder.core.trace.get_tracectx() - trace.push_scope([]) # don't record symbol calls - res = core_of_forward(*args, **kwargs) - trace.pop_scope() - return res - - def bind_postprocess(bsym): - bsym._call_ctx = {} - custom_fwd_sym = jit_ctx.ad_hoc_executor.register_operator( symbol_name, like=core_of_forward, - bind_postprocess=bind_postprocess, ) unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args) @@ -669,49 +658,28 @@ def bind_postprocess(bsym): provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, custom_forward_result.provenance]), ) - jit_ctx.ad_hoc_executor.register_implementation(custom_fwd_sym, execution_transform=core_of_forward) - augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = ( 
tuple(sequencify(unwrapped_custom_forward_result)), ctx_proxy.saved_tensors, ) trace_of_augmented_fwd = TraceCtx() - for bsym in custom_fwd_bsyms: - trace_of_augmented_fwd.add_bound_symbol(bsym) + trace_of_augmented_fwd.bound_symbols.extend(custom_fwd_bsyms) with tracectx(trace_of_augmented_fwd): prims.python_return(augmented_bsym_output) trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args) trace_of_augmented_fwd.args = unwrapped_custom_forward_args - @wraps(trace_of_augmented_fwd.python_callable()) - def core_of_augmented_forward(*args, **kwargs): - return thunder.core.trace_interpreter.interpret_trace(trace_of_augmented_fwd, *args, **kwargs) - - @wraps(core_of_augmented_forward) - def augmented_custom_forward_rule(*args, **kwargs): - primal, residulas = core_of_augmented_forward(*args, **kwargs) - check(len(primal) == 1, lambda f: "{primal=} has {len(primal)} proxies but expected 1") - return VJPDual(primal=primal[0], residuals=residulas) - - # TODO: build and register gradient_transform instead - augmented_forward_impls[custom_fwd_sym.name] = augmented_custom_forward_rule - # Backward definition custom_backward = custom_autograd_function_cls.backward - grads = tree_map( lambda a: a.replace_name(f"grad_{a.name}"), sequencify(unwrapped_custom_forward_result), ) trace_of_backward = TraceCtx() - bwd_si = SigInfo(f"{custom_fwd_sym.name}_backward") - for a in ctx_proxy.saved_tensors + grads: - if isinstance(a, Proxy): - bwd_si.args.append((a.name, None)) - else: - pa = proxy(a) - bwd_si.args.append((pa.name, None)) - trace_of_backward._siginfo = bwd_si + trace_of_backward._siginfo = SigInfo.from_name_and_args( + f"{custom_fwd_sym.name}_backward", + ctx_proxy.saved_tensors + grads, + ) trace_of_backward.args = tuple(ctx_proxy.saved_tensors + grads) jit_ctx.computation_trace.push_scope([]) @@ -720,21 +688,13 @@ def augmented_custom_forward_rule(*args, **kwargs): if custom_backward_result is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return custom_backward_result - custom_bwd_bsyms: list[BoundSymbol] = jit_ctx.computation_trace.pop_scope() - - for bsym in custom_bwd_bsyms: + for bsym in jit_ctx.computation_trace.pop_scope(): trace_of_backward.add_bound_symbol(bsym) with tracectx(trace_of_backward): - prims.python_return.bind(*unwrap(custom_backward_result), output=None) - - @wraps(trace_of_backward.python_callable()) - def bwd_trace_callable_interface(*args, **kwargs): - return thunder.core.trace_interpreter.interpret_trace(trace_of_backward, *args, **kwargs) + prims.python_return(unwrap(custom_backward_result)) bwd_trace_impl = TraceCtx() - for bsym in custom_bwd_bsyms: - bwd_trace_impl.add_bound_symbol(bsym) - bwd_trace_impl.add_bound_symbol(prims.python_return.bind(*sequencify(unwrap(custom_backward_result)), output=None)) + bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols) bwd_trace_impl._siginfo = SigInfo.from_name_and_args( "backward_impl", ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads, @@ -745,13 +705,26 @@ def bwd_trace_callable_interface(*args, **kwargs): def bwd_impl_callable(*args, **kwargs): return thunder.core.trace_interpreter.interpret_trace(bwd_trace_impl, *args, **kwargs) - @wraps(bwd_trace_callable_interface) - def backward_impl(*args, **kwargs): - check(not kwargs, lambda: f"{kwargs} expected to be empty") - new_args = ctx_proxy.saved_consts + args - return bwd_impl_callable(*new_args) - - backward_impls[custom_fwd_sym.name] = backward_impl + @wraps(core_of_forward) + def grad_transform(*args, 
**kwargs): + from thunder.core.transforms import get_grad + from thunder.core.transforms import put_grads + from thunder.core.trace_interpreter import interpret_trace + + check(not kwargs, lambda: f"{kwargs=} should be empty") + primal, residuals = interpret_trace(trace_of_augmented_fwd, *args, **kwargs) + check(len(primal) == 1, lambda: f"{primal=} has {len(primal)} proxies but expected 1") + grads = (get_grad(primal[0]),) + bwd_args = ctx_proxy.saved_consts + residuals + grads + result = bwd_impl_callable(*bwd_args) + put_grads(args, result) + return primal + + jit_ctx.ad_hoc_executor.register_implementation( + custom_fwd_sym, + execution_transform=core_of_forward, + grad_transform=grad_transform, + ) return forward_result @@ -763,6 +736,41 @@ def _general_jit_torch_finfo_lookaside(dtype: thunder.dtypes.dtype): return res +@register_general_jit_lookaside(torch.utils.checkpoint.checkpoint) +def _general_jit_torch_checkpoint_lookaside( + function: Callable, + *args, + **kwargs: Any, +): + """ + This function does preprocessing of the `function` argument before + dispatching the call to `thunder.torch.checkpoint`. This is necessary + because the `function` is potentially calling into PyTorch functions that + are not yet translated to Thunder. `thunder.torch.checkpoint` is a Thunder + function that can handle only Thunder functions as input. + + Args: + function: The function to be checkpointed. + args: Arguments to the function. + kwargs: Keyword arguments to the function. + + Returns: + The result of calling `thunder.torch.checkpoint` with the preprocessed + `function` and its arguments. + """ + from thunder.torch import checkpoint + + # It should be possible to call the general_thunder_jit here to handle the + # conversion from torch to thunder but it doesn't work now + # See https://github.com/Lightning-AI/lightning-thunder/issues/1126 + # TODO: Convert the function to a Thunder function + def thunder_function(*args, **kwargs): + return unwrap(function)(*args, **kwargs) + + wrapped_thunder_function = wrap_const(thunder_function) + return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) + + # Adds proxy methods # NOTE These methods map to themselves, which prevents the interpreter from looking into them # This is OK because these methods are written in a tracing-safe manner, and trying to diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index cd0f552453..2f2eb1c665 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -16,7 +16,12 @@ from thunder.core.compile_data import using_symbolic_values, using_jit from thunder.core.interpreter import is_jitting, ProvenanceRecord, PseudoInst -from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx +from thunder.core.trace import ( + VariableInterface, + get_tracectx, + is_tracing, + TraceCtx, +) from thunder.core.baseutils import ( ProxyInterface, NumberProxyInterface, @@ -1242,8 +1247,7 @@ def _infer_tensor_properties( thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size ) - # dynamic shape not yet enabled, otherwise, the bake in should be guarded with if not using_symbolic_values(): - # dynamic shape support is currently block by #471 https://github.com/Lightning-AI/lightning-thunder/issues/471 + baseutils.check(_shape is not None, lambda: f"_shape cannot be None when creating TensorProxy") if not using_symbolic_values(): _shape = tuple(pyval(x) for x in _shape) # Computes derived properties @@ -1251,7 +1255,7 @@ def 
_infer_tensor_properties( else: # deferred computation of numel # TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency - _numel = lambda tp: reduce(operator.mul, tp.shape, 1) + _numel = lambda *args: reduce(operator.mul, _shape, 1) # TODO Alias rank to ndim? _ndim = len(_shape) @@ -1459,7 +1463,12 @@ def __init__( # outside of a trace or language context @property def shape(self): - return self._shape + if not using_symbolic_values() or not is_tracing(): + return self._shape + else: + from thunder.core.prims import shape + + return shape(self) @property def ndim(self): @@ -1548,10 +1557,10 @@ def replace(self, **changes): ) def __repr__(self): - return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>' + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape})>' def type_string(self): - return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}" + return f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}" # NOTE __getattr__ is overridden to support language-specific methods def __getattr__(self, attr: str, /): diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 97248c75bf..6a3d3d3a76 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -7,6 +7,7 @@ import thunder.core.dtypes as dtypes import thunder.core.devices as devices from thunder.core.baseutils import ProxyInterface +from types import FunctionType OPTREE_NAMESPACE = "thunder" diff --git a/thunder/core/rematerialization.py b/thunder/core/rematerialization.py index 66640ca6a6..f24ee1fad2 100644 --- a/thunder/core/rematerialization.py +++ b/thunder/core/rematerialization.py @@ -374,7 +374,13 @@ def add_edges(var): g = nx.DiGraph() g.add_edges_from(edges) - _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") + try: + _, (reachable, non_reachable) = nx.minimum_cut(g, "source", "sink") + except Exception: + raise RuntimeError( + "Failed to compute the min-cut on the graph due to a path with infinite capacity." + "Please report this error along with your function and relevant details at: https://github.com/Lightning-AI/lightning-thunder/issues/new" + ) cut_edges = set() for u, nbrs in ((n, g[n]) for n in reachable): diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index a7d7e598e2..ed0e110191 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -122,9 +122,9 @@ def add_to_swap_map(old, new): # (FSDP, tensor parallel) transforms do "break" shape metadata new_trace.names.remove(old.name) # taken by the .replace proxy if isinstance(new, VJPDual): - old = old.replace(shape=new.primal.shape) + old = old.replace(shape=new.primal._shape) else: - old = old.replace(shape=new.shape) + old = old.replace(shape=new._shape) if isinstance(new, VJPDual): swap_map[variableify(new.primal)] = old diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 4e81bf3261..13437d488e 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2811,7 +2811,7 @@ def _vjp(primals, cotangents, **kwargs): # If the argument is a CPU scalar tensor, its gradient needs to be summed into a scalar tensor. 
vjp_result = tuple( ( - sum_to(grad, arg.shape) + sum_to(grad, arg._shape) if (grad is not None and isinstance(arg, TensorProxy) and arg.device.type == "cpu") else grad ) diff --git a/thunder/core/utils.py b/thunder/core/utils.py index a5056071d5..c521983316 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1001,8 +1001,9 @@ def __repr__(self) -> str: return str(self._dict) -# NOTE That this pass does not assume that the bound symbols are in a reasonable order, -# but it does assume that each proxy is uniquely constructed once +# NOTE That this pass does not assume that the bound symbols are in a reasonable order. +# For bound symbols with multiple producers, this pass returns the first producer of +# in order of the presented bound symbols # Returns a proxy -> producer mapping # If _map_to_numbers is True then producers are represented by their position in the trace (their "line number") def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_numbers: bool = False) -> ProxyDict: @@ -1022,6 +1023,10 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_ continue for out in bsym.flat_proxy_outs: + # if a producer has already been traversed, skip + if producers.get(out, None) != None: + continue + vout = variableify(out) # Checks if the proxy was also an input (in which case this is not its producers) diff --git a/thunder/dynamo/compiler_graph_benchmark.py b/thunder/dynamo/compiler_graph_benchmark.py index 7f7fbccf6f..eafd30ce0e 100644 --- a/thunder/dynamo/compiler_graph_benchmark.py +++ b/thunder/dynamo/compiler_graph_benchmark.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS = ("GraphID", "SplitModuleName", "executor") @@ -25,8 +25,8 @@ class ThunderCompilerGraphBenchmarking(ThunderCompiler): def __init__( self, bench: BenchmarkFixture, - executors: Sequence[str], - **thunder_options, + executors: dict[str, Callable], + **debug_options, ): """ This class acts as a backend for the :func:`torch.compile` function, facilitating the benchmarking of each :class:`torch.fx.GraphModule` produced by Thunder dynamo splitter. @@ -34,16 +34,17 @@ def __init__( Args: bench: the BenchmarkFixture created by ``pytest_benchmark`` - executors: list of executors to compare. Supported executors include: 'eager', 'inductor', and 'thunder'. If None, defaults to all available executors. - **thunder_options: a dictionary of options to pass to :func:`thunder.jit`. Besides all the arguments to :func:`thunder.jit`, - it accepts `torch_inductor_options` which are passed to :func:`torch.compile` if part of the graph - is not supported by thunder. + executors: A dictionary of functors to compare. + - Key: The name of the executor to be displayed in the test name. + - Value: A callable representing the compile function to be applied to the GraphModule. + If the value is None, no compilation is performed, and the GraphModule runs in Torch eager mode. Example: .. 
code-block:: python # script.py import torch + import thunder from thunder.dynamo.compiler_graph_benchmark import ThunderCompilerGraphBenchmarking def func(x): @@ -54,7 +55,7 @@ def func(x): return x - 1 def test_func(benchmark): - backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["eager", "thunder"]) + backend = ThunderCompilerGraphBenchmarking(benchmark, executors={"eager": None, "thunder": thunder.jit}) compiled = torch.compile(backend=backend)(func) x = torch.ones(2, requires_grad=True).cuda() compiled(x) @@ -72,41 +73,41 @@ def test_func(benchmark): With `--benchmark-group-by='graph-by-graph:param:GraphID,param:SplitModuleName'`, the test cases are grouped based on GraphID and SplitModuleName, allowing for performance comparison between different executors (e.g., 'eager' vs. 'thunder'). """ - super().__init__(**thunder_options) + super().__init__() self.bench = bench - if not executors: - self.executors = ThunderCompilerGraphBenchmarking._executors - else: - check( - all(ex in ThunderCompilerGraphBenchmarking._executors for ex in executors), - lambda: f"ThunderCompilerGraphBenchmarking only supports the following executor names: {ThunderCompilerGraphBenchmarking._executors} ", - ) - self.executors = executors + check(isinstance(executors, dict) and executors, lambda: f"'executors' must be a non-empty dictionary.") + check( + not any("-" in k for k in executors.keys()), + lambda: f"Executor names cannot contain '-' as it conflicts with the 'benchmark-group-by' function. Please rename it using a different character.", + ) + self.executors = executors + self._get_debug_options(**debug_options) + self.graph_idx = 0 + def _get_debug_options(self, **debug_options): + self.post_graph = debug_options.get("post_graph", False) + def run_bench(self, gm: torch.fx.GraphModule, name: str, *sample_args): from thunder.benchmarks.targets import record_peak_allocated_memory, MAX_ALLOCATED_MEMORY_KEYWORD - for ex in self.executors: - # Uses the already compiled module if it is compiled with the expected executor - if name.startswith(ex): - fn = self.subgraph_infos[self.graph_idx].submodule_to_compiled_functions[gm].compiled_fn + for ex_name, ex in self.executors.items(): + if ex is None: + compiled_fn = gm else: - if ex == "thunder": - # The subgraph whose name starts with "inductor" is not supported by the Thunder backend. 
- if name.startswith("inductor"): - continue - fn = self._thunder_jit(gm) - elif ex == "inductor": - fn = self._torch_compile(gm) - else: - fn = gm + try: + compiled_fn = ex(gm) + except Exception as e: + raise RuntimeError(f"The input executor {ex_name} failed to compile {gm}") from e + if self.post_graph: + compiled_fn = self.post_graph(compiled_fn, sample_args) + with record_peak_allocated_memory(self.bench): - self.bench(fn, *sample_args) + self.bench(compiled_fn, *sample_args) # BenchmarkFixture.stats is created each time bench is called (ref: https://github.com/pybenchmark/pytest-benchmark/blob/8c9a5faa1dd178b53ab7b2a66f5364a77e903d74/src/pytest_benchmark/fixture.py#L150) # Adds the graph number, split module name and executor suffix to the name string gid_key, module_name_key, ex_key = GRAPH_BY_GRAPH_BENCHMARK_PARAMS_KEYS - self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex}]" + self.bench.stats.name += f"-{gid_key}[{self.graph_idx+1}]-{module_name_key}[{name}]-{ex_key}[{ex_name}]" assert MAX_ALLOCATED_MEMORY_KEYWORD in self.bench.extra_info assert f"{self.bench.stats.name}_{MAX_ALLOCATED_MEMORY_KEYWORD}" not in self.bench.extra_info # NOTE: A benchmark can include multiple stats, but only one extra_info field is allowed per benchmark. @@ -128,8 +129,8 @@ def __call__(self, gm: torch.fx.GraphModule, sample_args: list[torch.SymInt, tor } for node in split_module.graph.nodes: target = node.target - # Benchmarks the modules produced by the splitter. - if isinstance(target, str) and target.startswith(("thunder_", "inductor_")): + # Benchmarks the modules produced by the splitter and are supported by Thunder. + if isinstance(target, str) and target.startswith("thunder_"): check( hasattr(split_module, target), lambda: f"the submodule {target} does not exist in {split_module}", diff --git a/thunder/executors/pythonex.py b/thunder/executors/pythonex.py index 942e9c9b0b..de2bae4d0b 100644 --- a/thunder/executors/pythonex.py +++ b/thunder/executors/pythonex.py @@ -344,6 +344,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten pythonex_pow = ex.register_operator("pow", like=prims.pow, module=operator) sub = ex.register_operator("sub", like=prims.sub, module=operator) div = ex.register_operator("div", like=prims.div, fn=_div_prim_impl) +shape = ex.register_operator("shape", like=prims.shape, fn=lambda x: x.shape) # TODO: Restore truediv once we find it... 
# truediv = ex.register_operator("truediv", like=prims.truediv, module=operator) @@ -367,6 +368,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten ex.register_implementation(prims.pow, pythonex_pow, checker=_elementwise_binary_checker) ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker) ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker) +ex.register_implementation(prims.shape, shape, checker=_always_executable) def _sink(*args, **kwargs): diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index c292073482..038997fd57 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2246,7 +2246,11 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # NOTE: PyTorch fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' occasionally when the exponent is CPU scalar tensor # e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y) - DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int32, datatypes.int64)), + DecorateInfo( + pytest.mark.xfail, + "test_core_vs_torch_consistency", + dtypes=(datatypes.int8, datatypes.int16, datatypes.int32, datatypes.int64), + ), ), ) elementwise_binary_ops.append(pow_opinfo) diff --git a/thunder/tests/test_autocast.py b/thunder/tests/test_autocast.py index 5e0c345ee7..2d02ca92d4 100644 --- a/thunder/tests/test_autocast.py +++ b/thunder/tests/test_autocast.py @@ -1,4 +1,5 @@ import itertools +import platform import pytest import torch @@ -140,7 +141,9 @@ def func(a, b): assert output.dtype == (torch.float16 if torch_device.type == "cuda" else torch.bfloat16) -@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") +# Disabling on windows temporarily, until our windows runners source the +# appropriate visual studio config. +@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported") def test_torch_compile_autocast(): """Checks if our autocast decorator plays well with ``torch.compile``""" diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py index cae17adf65..869422359a 100644 --- a/thunder/tests/test_dynamo.py +++ b/thunder/tests/test_dynamo.py @@ -475,7 +475,11 @@ def func(x): # It must be located in the same folder as the test file to ensure the configuration. 
@requiresCUDA def test_ThunderCompilerGraphBenchmarking_LlamaMLPBenchmark(benchmark): - backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"]) + import thunder + + backend = ThunderCompilerGraphBenchmarking( + benchmark, executors={"thunder": thunder.jit, "inductor": torch.compile, "eager": None} + ) from thunder.benchmarks import LlamaMLPBenchmark, Benchmark bench: Benchmark = LlamaMLPBenchmark( @@ -505,8 +509,29 @@ def f(x, y): x = torch.sinc(y) + torch.cos(x) return x - 1 - backend = ThunderCompilerGraphBenchmarking(benchmark, executors=["thunder", "inductor", "eager"]) + import thunder + + backend = ThunderCompilerGraphBenchmarking( + benchmark, executors={"thunder": thunder.jit, "inductor": torch.compile, "eager": None} + ) compiled = torch.compile(backend=backend)(f) x = torch.ones(2, requires_grad=True).cuda() y = torch.ones(2, requires_grad=True).cuda() compiled(x, y) + + +@requiresCUDA +def test_ThunderCompilerGraphBenchmarking_post_graph(benchmark): + def f(x): + return torch.sin(x) + + import thunder + from functools import partial + + x = torch.randn((2, 2), device="cuda").requires_grad_() + post_gp = partial(torch.cuda.make_graphed_callables, num_warmup_iters=1, allow_unused_input=True) + backend = ThunderCompilerGraphBenchmarking( + benchmark, executors={"inductor": torch.compile, "thunder": thunder.jit}, post_graph=post_gp + ) + compiled = torch.compile(backend=backend)(f) + compiled(x) diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py index 3f8fe0c50b..33b4e4df2c 100644 --- a/thunder/tests/test_grad.py +++ b/thunder/tests/test_grad.py @@ -1700,6 +1700,42 @@ def func(a, b): get_saved_for_backward_tensors(execution_trace) +def test_torch_checkpoint(): + import torch.utils.checkpoint + import torch._higher_order_ops.wrap + + def fn_to_checkpoint(x): + return x.sin().cos().exp() + + checkpoint_fns = ( + thunder.torch.checkpoint, + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False), + torch.ops.higher_order.tag_activation_checkpoint, + ) + + for checkpoint_fn in checkpoint_fns: + + def f(x): + return checkpoint_fn(fn_to_checkpoint, x) + + x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True) + jf = thunder.jit(f) + out = jf(x) + + # With activation checkpointing, we are saving only the original input. + # The intermediate values are recomputed during backward pass. 
+ assert len(out.grad_fn.saved_tensors) == 1 + assert out.grad_fn.saved_tensors[0] is x + + g = torch.ones_like(out) + out.backward(g) + + x_ref = x.detach().requires_grad_() + out_ref = fn_to_checkpoint(x_ref) + out_ref.backward(g) + torch.testing.assert_close(x.grad, x_ref.grad) + + def test_inconsistent_output_length_grad_transform(): from thunder.extend import OperatorExecutor from thunder.core.proxies import AnyProxy, TensorProxy diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index dd2fd878f7..aa0fe8728c 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1123,7 +1123,7 @@ def forward(self, x): ("cpu", "cuda"), ) def test_cache_symbolic_values_reshape(device): - if not torch.cuda.is_available(): + if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") a = torch.randn((4, 8, 6), device=device) @@ -1451,3 +1451,33 @@ def foo(a): assert_close(actual, expected) assert thunder.cache_misses(jfoo) == 1 assert thunder.cache_hits(jfoo) == 1 + + +def test_cache_symbolic_values_reshape_numel(): + def foo(a): + a = torch.reshape(a, [a.numel()]) + return a.relu() + + jfoo = thunder.jit(foo, cache="symbolic values") + + a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") + + actual = jfoo(a) + expected = foo(a) + + assert_close(actual, expected) + + +def test_cache_symbolic_values_slice(): + def foo(a): + a = a[..., : a.shape[-1]] + return a.relu() + + jfoo = thunder.jit(foo, cache="symbolic values") + + a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") + + actual = jfoo(a) + expected = foo(a) + + assert_close(actual, expected) diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 52507be74b..c42c183b7e 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -997,6 +997,10 @@ def make_integer_tensor(): nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.10"), reason="Requires nvFuser version 0.2.10 or later", ), + pytest.mark.skipif( + torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] < 9, + reason="Requires CUDA compute capability >= 9.0", + ), pytest.mark.parametrize("dropout_p", [0.0, 0.2]), pytest.mark.parametrize("is_causal", [False, True]), pytest.mark.parametrize("scale", [None, 1e-3]), diff --git a/thunder/tests/test_torch_compile_executor.py b/thunder/tests/test_torch_compile_executor.py index eac99957e3..c0bd1b351d 100644 --- a/thunder/tests/test_torch_compile_executor.py +++ b/thunder/tests/test_torch_compile_executor.py @@ -1,3 +1,4 @@ +import platform import pytest import torch from torch._dynamo import is_inductor_supported @@ -15,7 +16,9 @@ def test_supported_ops_are_in_pytorch_executor(): assert supported_ops - pytorch_ex.implmap.keys() == set() -@pytest.mark.skipif(not is_inductor_supported(), reason="inductor unsupported") +# Disabling on windows temporarily, until our windows runners source the +# appropriate visual studio config. 
+@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported") def test_torch_compile_litgpt(): from litgpt.model import GPT diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py index 09681e5261..750894ab31 100644 --- a/thunder/tests/test_transforms.py +++ b/thunder/tests/test_transforms.py @@ -596,3 +596,18 @@ def forward(x): break else: raise RuntimeError("Failed to find `add` symbol in trace") + + +@requiresCUDA +def test_cudagraph_empty_inputs(): + def fn(): + a = torch.ones(5, 5, device="cuda") + b = a * 2 + return b + + from thunder.transforms.cudagraph import CUDAGraphTransform + + jfn = thunder.jit(fn, transforms=(CUDAGraphTransform(),), executors=()) + assert_close(jfn(), fn()) + + assert any(("CUDAGraph" in bsym.sym.name) for bsym in thunder.last_traces(jfn)[-1].bound_symbols) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 2e92c6bbce..ed69b4f096 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -61,6 +61,8 @@ # NOTE torch is a requirement import torch +import torch.utils.checkpoint +import torch._higher_order_ops.wrap import warnings @@ -5199,6 +5201,71 @@ def _unwrap_if_dead(tensor): register_function(torch._C._functorch.unwrap_if_dead, _unwrap_if_dead) +@torchsymbol( + torch.utils.checkpoint.checkpoint, + torch.ops.higher_order.tag_activation_checkpoint, + id="activation_checkpoint", +) +def checkpoint( + function: Callable[..., TensorLike], + *args: TensorLike, + context_fn: None | Callable[..., Any] = None, + debug: None | bool = None, + determinism_check: None | str = None, + preserve_rng_state: None | bool = None, + use_reentrant: bool = False, + **kwargs: Any, +) -> TensorLike: + utils.check( + not use_reentrant, + lambda: "torch.checkpoint: use_reentrant=True is not supported in Thunder", + ) + # NOTE: Thunder currently ignores the context_fn, debug, determinism_check, preserve_rng_state arguments + # Let's raise a warning if any of these arguments are passed + if context_fn is not None: + warnings.warn("torch.checkpoint: context_fn is not supported in Thunder and will be ignored") + if debug is not None: + warnings.warn("torch.checkpoint: debug is not supported in Thunder and will be ignored") + if determinism_check is not None: + warnings.warn("torch.checkpoint: determinism_check is not supported in Thunder and will be ignored") + if preserve_rng_state is not None: + warnings.warn("torch.checkpoint: preserve_rng_state is not supported in Thunder and will be ignored") + return function(*args, **kwargs) + + +@register_augmented_forward( + "activation_checkpoint", +) +def _augmented_forward_checkpoint( + function: Callable[..., TensorLike], + *args: TensorLike, + context_fn: None | Callable[..., Any] = None, + debug: None | bool = None, + determinism_check: None | str = None, + preserve_rng_state: None | bool = None, + use_reentrant: bool = False, + **kwargs: Any, +) -> TensorLike: + result = function(*args, **kwargs) + saved_for_backward = (function, args, kwargs) + return result, saved_for_backward + + +@register_backward( + "activation_checkpoint", +) +def _backward_checkpoint( + function, + args, + kwargs, + *grad_outputs, +) -> tuple[None | TensorLike, ...]: + from thunder.core.transforms import vjp + + result = vjp(function)(args, grad_outputs, **kwargs) + return result + + # # Distributed operations # diff --git a/thunder/transforms/constant_folding.py b/thunder/transforms/constant_folding.py index 6f5273c4a1..24085793bc 100644 --- 
a/thunder/transforms/constant_folding.py +++ b/thunder/transforms/constant_folding.py @@ -5,8 +5,8 @@ import torch import thunder -from thunder.core.trace import from_trace, tracectx -from thunder.core.proxies import variableify, Variable, TensorProxy, NumberProxy, proxy +from thunder.core.trace import from_trace +from thunder.core.proxies import variableify, Variable, TensorProxy, NumberProxy from thunder.core.symbol import BoundSymbol from thunder.core.dtypes import to_dtype from thunder.core.devices import to_device @@ -71,7 +71,6 @@ def materialize_args(a): class ConstantFolding(thunder.Transform): def transform_traces_pre_prologue(self, prologue_trc, computation_trc, epilogue_trc, **kwargs): - # print(computation_trc) # Create a new trace const_folded_trace = from_trace(computation_trc) const_folded_trace.bound_symbols = computation_trc.bound_symbols diff --git a/thunder/transforms/cudagraph.py b/thunder/transforms/cudagraph.py index e66074d117..25d907459b 100644 --- a/thunder/transforms/cudagraph.py +++ b/thunder/transforms/cudagraph.py @@ -33,7 +33,10 @@ def extract_descriptor(arg): else: return type(arg), None, None, arg - dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args)) + if args: + dtypes, sizes, strides, non_tensor_args = zip(*map(extract_descriptor, args)) + else: + dtypes = sizes = strides = non_tensor_args = None return ArgsDescriptor(dtypes, sizes, strides, non_tensor_args)
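
For reference, a minimal sketch of the activation-checkpointing support introduced by this patch, modeled directly on the `test_torch_checkpoint` test added above (the tensor shape and helper function names are illustrative, not part of the patch):

```python
# Sketch based on test_torch_checkpoint: thunder.jit now traces through
# torch.utils.checkpoint.checkpoint (use_reentrant=False), saving only the
# checkpointed inputs and recomputing intermediates during the backward pass.
import torch
import torch.utils.checkpoint
import thunder


def fn_to_checkpoint(x):
    return x.sin().cos().exp()


def f(x):
    return torch.utils.checkpoint.checkpoint(fn_to_checkpoint, x, use_reentrant=False)


x = torch.randn(2, 2, requires_grad=True)
jf = thunder.jit(f)
out = jf(x)

# Only the original input is saved for backward; intermediates are recomputed.
assert len(out.grad_fn.saved_tensors) == 1

out.backward(torch.ones_like(out))
```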