diff --git a/README.md b/README.md
index 0fbfa636c3..8b97b1cd62 100644
--- a/README.md
+++ b/README.md
@@ -1,57 +1,94 @@
-![](docs/source/_static/images/lightning_thunder_lightmode_nobyline.png)
+
+
+
+
+
+**Make PyTorch models Lightning fast.**
+
+______________________________________________________________________
+
+
+ Lightning.ai •
+ Performance •
+ Get started •
+ Install •
+ Examples •
+ Features •
+ Documentation •
+
+
+[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/Lightning-AI/lightning-thunder/blob/main/LICENSE)
+[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
+[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
+[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
+[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)
+
+
# Welcome to ⚡ Lightning Thunder
-Lightning Thunder is a source-to-source compiler for PyTorch.
+**Thunder makes PyTorch models Lightning fast.**
-It makes PyTorch programs faster both on single accelerators or in distributed settings.
+Thunder is a source-to-source compiler for PyTorch. It makes PyTorch programs faster by combining and using different hardware executors at once (ie: nvFuser, torch.compile, cuDNN, and TransformerEngine FP8).
+Works on single accelerators and in multi-GPU settings.
Thunder aims to be usable, understandable, and extensible.
## Performance
Thunder can achieve significant speedups over standard PyTorch eager code, through the compounding effects of optimizations and the use of best-in-class executors. Here is an example of the pretraining throughput for Llama 2 7B as implemented in [LitGPT](https://github.com/Lightning-AI/litgpt).
-![](docs/source/_static/images/training_throughput_single.png)
+
+
+
-We achieve a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
+Thunder achieves a 40% speedup in training throughput compared to eager code on H100 using a combination of executors including nvFuser, torch.compile, cuDNN, and TransformerEngine FP8.
Thunder supports distributed strategies like DDP and FSDP (ZeRO2 and ZeRO3). Here is the normalized throughput measured for Llama 2 7B (this time without FP8 mixed precision, support for FSDP is underway).
-![](docs/source/_static/images/normalized_training_throughput_zero2.png)
+
+
+
**NOTE: Lightning Thunder is alpha.** Feel free to get involved, expect a few bumps along the way.
-## Start with Thunder
+## Get started
Try Thunder without installing by using our [Zero to Thunder Tutorial Studio](https://lightning.ai/lightning-ai/studios/zero-to-thunder-tutorial).
## Install Thunder
-Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, which will also install the matching PyTorch nightly:
+Install [nvFuser](https://github.com/NVIDIA/Fuser) nightly, and Thunder together
```bash
+# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
-```
-Install Thunder:
-
-```bash
+# install thunder
pip install lightning-thunder
```
-It's actually not a bad idea to install directly from `main`:
+
+ Advanced install options
+
+
+### Install from main
```bash
pip install git+https://github.com/Lightning-AI/lightning-thunder.git
```
-or from the local repo if you want to tinker with the internals:
+### Install to tinker and contribute
+
+Install this way to tinker with the internals and contribute:
```bash
pip install -e .
```
+
+
+
## Hello World
Here is a simple example of how Thunder lets you compile and run PyTorch code:
@@ -102,7 +139,7 @@ python examples/lit-gpt/train_fsdp.py
See [README.md](examples/lit-gpt/README.md) for details on running LitGPT with Thunder.
-## What's in the box
+## Features
Given a Python callable or PyTorch module, Thunder can generate an optimized program that:
@@ -132,7 +169,7 @@ Thunder doesn't generate code for accelerators directly. It acquires and transfo
Modules and functions compiled with Thunder fully interoperate with vanilla PyTorch and support PyTorch's autograd. Also, Thunder works alongside torch.compile to leverage its state-of-the-art optimizations.
-## Build the documentation
+## Documentation
Docs are currently not hosted publicly. However you can build them locally really quickly:
@@ -168,8 +205,3 @@ Thunder is very thoroughly tested, so expect this to take a while.
Lightning Thunder is released under the [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0) license.
See the [LICENSE](LICENSE) file for details.
-
-[![CI testing](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-testing.yml)
-[![General checks](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml/badge.svg?event=push)](https://github.com/Lightning-AI/lightning-thunder/actions/workflows/ci-checks.yml)
-[![Documentation Status](https://readthedocs.org/projects/lightning-thunder/badge/?version=latest)](https://lightning-thunder.readthedocs.io/en/latest/?badge=latest)
-[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Lightning-AI/lightning-thunder/main.svg)](https://results.pre-commit.ci/latest/github/Lightning-AI/lightning-thunder/main)
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index c24e44a3ce..772e65a84d 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1237,11 +1237,6 @@ def _embedding_prim_grad(
def _get_gradfn(bsym: BoundSymbol, *, executors_list: Sequence[Any] = tuple()) -> None | Callable:
- # If executor specific `aug_fwd_rule` exists then we will use that,
- # so we return `None` here.
- if get_executor_specific_aug_fwd_rule(bsym):
- return None
-
cd = get_compile_data()
executors_list = cd.executors_list if cd is not None else executors_list
# Checks if the executor which has priority for this operation has a specific grad transform for it
@@ -2484,15 +2479,6 @@ def zeros_like(x):
}
-@dataclass(**default_dataclass_params)
-class RuleInfo:
- checker: Callable
- rule: Callable
- fw_fallback: Callable
- bw_fallback: Callable
- executor: Executor
-
-
def register_augmented_forward(op):
"""Decorator to register an augmented forward implementation for a symbol.
@@ -2510,40 +2496,6 @@ def decorator(func):
return decorator
-def register_augmented_forward_with_checker(executor, op, checker, rule):
- """Decorator to register an augmented forward implementation for a symbol.
-
- Args:
- executor (Executor): Executor to which the rule applies.
- op (Ops): Symbol for which to register the augmented forward implementation.
- checker (Callable): Function that checks if the rule should be applied.
- rule (Callable): Function that applies the rule.
- """
- fw_fallback = augmented_forward_impls.get(op, None)
- bw_fallback = backward_impls.get(op, None)
- augmented_forward_impls[executor, op] = RuleInfo(checker, rule, fw_fallback, bw_fallback, executor)
-
-
-def deregister_augmented_forward_and_backward(op):
- """Deregisters an augmented forward implementation and a backward
- implementation for a symbol.
-
- Args:
- op (Ops): Symbol for which to deregister the augmented forward
- implementation and the backward implementation.
-
- Returns:
- None
- """
- # Restore the fallback implementation if it exists
- if isinstance(augmented_forward_impls[op], RuleInfo):
- backward_impls[op] = augmented_forward_impls[op].bw_fallback
- augmented_forward_impls[op] = augmented_forward_impls[op].fw_fallback
- else:
- del augmented_forward_impls[op]
- del backward_impls[op]
-
-
def register_backward(op):
"""Decorator to register a backward implementation for a symbol.
@@ -3320,31 +3272,6 @@ def uniform_backward(primal, minval, maxval, g):
nondifferentiable_vjp_symbols = (prims.PrimIDs.BITWISE_AND, prims.PrimIDs.SIGNBIT, prims.PrimIDs.FULL)
-def get_executor_specific_aug_fwd_rule(symbol: BoundSymbol) -> RuleInfo | None:
- """Get executor specific augmented forward rule.
-
- Args:
- symbol (BoundSymbol): BoundSymbol to get the rule for.
-
- Returns:
- RuleInfo: Rule info for the symbol.
- """
- cd = get_compile_data()
- if cd is None:
- return None
-
- # Search for the executor specific rules. When there are multiple rules
- # for the same symbol, we use the left-most executor in the list (i.e.
- # the one with the highest priority) and we fallback to the next one if
- # the checker returns False.
- for executor in cd.executors_list:
- candidate = augmented_forward_impls.get((executor, symbol.sym.id))
- if isinstance(candidate, RuleInfo) and candidate.checker(*symbol.args, **symbol.kwargs):
- return candidate
-
- return None
-
-
def is_constant_for_vjp(symbol: prims.Symbol) -> bool:
"""Check if a symbol is constant for the VJP transform.
@@ -3387,19 +3314,10 @@ def vjp_impl_const(symbol, *args, **kwargs):
# Normal case, we have a proxy tangent
vjp_impl = augmented_forward_impls.get(symbol.sym.id)
- vjp_impl = get_executor_specific_aug_fwd_rule(symbol) or vjp_impl
if _get_gradfn(symbol) is not None:
vjp_impl, backward_fn = make_aug_forward_and_backward(symbol)
- if isinstance(vjp_impl, RuleInfo):
- # We should use this rule only if checker returns True for the current
- # symbol's arguments
- if vjp_impl.checker(*symbol.args, **symbol.kwargs):
- vjp_impl = vjp_impl.rule
- else:
- vjp_impl = vjp_impl.fw_fallback
-
if vjp_impl is None:
# We could not find a VJP for this symbol, so we try to decompose it
if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
@@ -3567,14 +3485,10 @@ def put_grad(v: Variable, val: Any) -> None:
backward = backward_impls.get(symbol.sym.id)
aug_forward = augmented_forward_impls.get(symbol.sym.id)
- aug_forward = get_executor_specific_aug_fwd_rule(symbol) or aug_forward
if _get_gradfn(symbol) is not None:
aug_forward, backward = make_aug_forward_and_backward(symbol)
- if isinstance(aug_forward, RuleInfo):
- backward = backward_impls[aug_forward.executor, symbol.sym.id]
-
if backward is None:
if len(symbol.subsymbols) > 0 and not isinstance(symbol.sym.id, prims.PrimIDs):
# We could not find a backward for this symbol, so we try to decompose it
diff --git a/thunder/executors/apex_entropyex.py b/thunder/executors/apex_entropyex.py
index 8a82e04e20..818199ad5b 100644
--- a/thunder/executors/apex_entropyex.py
+++ b/thunder/executors/apex_entropyex.py
@@ -11,10 +11,6 @@
from thunder.core.symbol import Symbol
from thunder.core.utils import check, same_shape
from thunder.core.transforms import get_grad, put_grad, put_grads, mean_backward, restore_reduced_dims
-from thunder.core.transforms import (
- register_augmented_forward_with_checker,
- register_backward,
-)
from thunder.extend import OperatorExecutor, register_executor
@@ -197,76 +193,6 @@ def _cross_entropy_checker(
return True
-# Check out the 'add vjp rule' dev tutorial on how to add a VJP rule for any
-# Symbol. We use our new primitives to register a VJP rule for
-# torch.nn.functional.cross_entropy. This function is registered as the
-# augmented forward rule for torch.nn.functional.cross_entropy below
-def apex_cross_entropy_forward_rule(
- a,
- target,
- weight=None,
- size_average=None,
- ignore_index=-100,
- reduce=None,
- reduction="mean",
- label_smoothing=0.0,
-):
- loss, max_log_sum_exp = apex_xentropy(
- a,
- target=target,
- reduction=reduction,
- label_smoothing=label_smoothing,
- )
- primal = loss
- saved_for_backward = (a, target, max_log_sum_exp, reduction, label_smoothing)
- return primal, saved_for_backward
-
-
-register_augmented_forward_with_checker(
- apex_ex,
- ltorch.cross_entropy.id,
- _cross_entropy_checker,
- apex_cross_entropy_forward_rule,
-)
-
-
-# This function is the backward rule for torch.nn.functional.cross_entropy. It
-# accepts the primal output and saved_for_backward from the forward pass and
-# returns the backward output. The backward output is a tuple of the backward
-# output for each differentiable Tensor input to the forward pass. In this case,
-# the forward pass has 1 such input, so the backward output is a single Tensor.
-# This function is registered as the backward rule for
-# torch.nn.functional.cross_entropy
-@register_backward((apex_ex, ltorch.cross_entropy.id))
-def apex_cross_entropy_backward_rule(
- logits,
- labels,
- max_log_sum_exp,
- reduction,
- smoothing,
- grad,
-):
- from thunder.core.transforms import mean_backward, sum_backward
-
- if reduction == "mean":
- grad = mean_backward(max_log_sum_exp.ndim, max_log_sum_exp.shape, (0,), grad)
- elif reduction == "sum":
- grad = sum_backward(max_log_sum_exp.shape, (0,), grad)
- elif reduction == "none":
- pass
- else:
- raise ValueError(f"Invalid reduction: {reduction}")
-
- grad_logits = apex_xentropy_bwd(
- grad,
- logits,
- target=labels,
- max_log_sum_exp=max_log_sum_exp,
- label_smoothing=smoothing,
- )
- return grad_logits
-
-
# Translate calls from torch.nn.functional.cross_entropy to apex_xentropy (when the checker above returns True)
def _cross_entropy_transform(
a: TensorProxy,
diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py
index 9fb6c50a48..75494cff5a 100644
--- a/thunder/executors/cudnnex.py
+++ b/thunder/executors/cudnnex.py
@@ -35,8 +35,6 @@ def cudnn_available() -> bool:
get_grad,
put_grad,
put_grads,
- register_augmented_forward_with_checker,
- register_backward,
)
from thunder.extend import OperatorExecutor, register_executor
import thunder.torch as ltorch
@@ -338,7 +336,11 @@ def _cudnn_sdpa_forward_checker(
if d % 8 != 0 or d > 128:
return False
- return True
+ is_backward_supported = _cudnn_sdpa_backward_checker(
+ query, key, value, attn_mask, dropout_p, is_causal, scale=scale
+ )
+
+ return True and is_backward_supported
@langctx("torch")
@@ -601,99 +603,6 @@ def cudnn_sdpa_bwd_impl(
)
-@langctx("torch")
-def cudnn_sdpa_aug_fw_rule_checker(
- query: TensorProxy,
- key: TensorProxy,
- value: TensorProxy,
- attn_mask: None | TensorProxy,
- dropout_p: float,
- is_causal: bool,
- *,
- scale: None | float,
-) -> bool:
- from thunder.core.compile_data import get_compile_data
-
- cd = get_compile_data()
- if cudnn_ex in cd.executors_list:
- is_forward_supported = _cudnn_sdpa_forward_checker(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- is_backward_supported = _cudnn_sdpa_backward_checker(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- return is_forward_supported and is_backward_supported
- return False
-
-
-def cudnn_sdpa_aug_fw_rule(
- query,
- key,
- value,
- attn_mask=None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- *,
- scale: float | None = None,
-):
- output, softmax_stats, seed, offset = cudnn_sdpa_fwd(
- query, key, value, attn_mask, dropout_p, is_causal, scale=scale
- )
- saved_for_backward = (
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- output,
- softmax_stats,
- seed,
- offset,
- )
- return output, saved_for_backward
-
-
-register_augmented_forward_with_checker(
- cudnn_ex,
- "torch.nn.functional.scaled_dot_product_attention",
- cudnn_sdpa_aug_fw_rule_checker,
- cudnn_sdpa_aug_fw_rule,
-)
-
-
-@register_backward((cudnn_ex, "torch.nn.functional.scaled_dot_product_attention"))
-def cudnn_sdpa_backward_rule(
- query: Proxy,
- key: Proxy,
- value: Proxy,
- attn_mask: None | Proxy,
- dropout_p: float,
- is_causal: bool,
- scale: None | float,
- out: Proxy,
- softmax_stats: Proxy,
- seed: Proxy,
- offset: Proxy,
- grad_out: Proxy,
-):
- return cudnn_sdpa_bwd(
- grad_out,
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- out,
- softmax_stats,
- seed,
- offset,
- scale=scale,
- )
-
-
@langctx("torch")
def _cudnn_sdpa_transform(
query: TensorProxy,
@@ -726,7 +635,7 @@ def _cudnn_sdpa_grad(
)
g = get_grad(primal)
- grad_query, grad_key, grad_val, grad_attn_mask = cudnn_sdpa_bwd(
+ grads = cudnn_sdpa_bwd(
g,
query,
key,
@@ -740,6 +649,11 @@ def _cudnn_sdpa_grad(
offset,
scale=scale,
)
+ if attn_mask is None:
+ grad_query, grad_key, grad_val = grads
+ else:
+ grad_query, grad_key, grad_val, grad_attn_mask = grads
+
put_grads((query, key, value), (grad_query, grad_key, grad_val))
if attn_mask is not None:
put_grad(attn_mask, grad_attn_mask)
diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py
index 91d32e7a88..ced5d8fdb1 100644
--- a/thunder/executors/transformer_engineex.py
+++ b/thunder/executors/transformer_engineex.py
@@ -19,10 +19,6 @@
import thunder.core.prims as prims
from thunder.core.proxies import TensorProxy, CollectionProxy
from thunder.core.symbol import Symbol
-from thunder.core.transforms import (
- register_augmented_forward_with_checker,
- register_backward,
-)
from thunder.extend import OperatorExecutor, register_executor
__all__ = [
@@ -411,15 +407,6 @@ def linear_forward_rule_checker(a: TensorProxy, w: TensorProxy, bias: None | Ten
return False
-register_augmented_forward_with_checker(
- transformer_engine_ex,
- prims.linear.id,
- linear_forward_rule_checker,
- linear_forwad_rule,
-)
-
-
-@register_backward((transformer_engine_ex, prims.linear.id))
def linear_backward_rule(a_shape, w_shape, b_shape, ctx_idx, grad):
return te_functional_linear_backward(grad, a_shape, w_shape, b_shape, ctx_idx)
@@ -429,9 +416,21 @@ def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.T
return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False)
+def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy:
+ out, saved_for_backward = linear_forwad_rule(a, w, b)
+ g = prims.get_grad(out)
+ ga, gw, gb = linear_backward_rule(*saved_for_backward, g)
+ prims.put_grad(a, ga)
+ prims.put_grad(w, gw)
+ if b is not None:
+ prims.put_grad(b, gb)
+ return out
+
+
# Registers the implementation for torch.nn.functional.linear
transformer_engine_ex.register_implementation(
prims.linear,
checker=_linear_checker,
execution_transform=_linear_transform,
+ grad_transform=_linear_grad,
)
diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py
index c9bd6277ec..4128e02914 100644
--- a/thunder/tests/test_cudnn_executor.py
+++ b/thunder/tests/test_cudnn_executor.py
@@ -110,10 +110,8 @@ def test_cudnn_sdpa():
query = 1 * (torch.randn(shape_Q, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
key = 2 * (torch.randn(shape_K, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
value = 3 * (torch.randn(shape_V, dtype=thunder.torch.to_torch_dtype(dtype), device="cuda") - 0.5)
- is_causal = False
- attn_mask = torch.randn(
- s_q, s_kv, requires_grad=False, device="cuda", dtype=thunder.torch.to_torch_dtype(dtype)
- )
+ is_causal = True
+ attn_mask = None
expected = torch.nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=is_causal, attn_mask=attn_mask