Merge branch 'main' into fix_torchspecial
kiya00 authored Aug 21, 2024
2 parents 126bda6 + 3548ba8 commit 2a816a3
Showing 8 changed files with 570 additions and 241 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -75,11 +75,11 @@ The easiest way to get started with Thunder, requiring no extra installations or

To use Thunder on your local machine:

- install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows:
- install [nvFuser](https://github.com/NVIDIA/Fuser) and PyTorch stable together as follows:

```bash
# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
# install nvFuser which installs the matching stable PyTorch
pip install --pre nvfuser-cu121-torch24
```

- install [cudnn](https://gitlab-master.nvidia.com/cudnn/cudnn_frontend) as follows:
@@ -107,8 +107,8 @@ pip install lightning-thunder
Alternatively, you can install the latest version of Thunder directly from this GitHub repository as follows:

```
# 1) Install nvFuser and PyTorch nightly dependencies:
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
# 1) Install nvFuser and PyTorch dependencies:
pip install --pre nvfuser-cu121-torch24
```

```bash
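After the installation steps above, a quick smoke test can confirm that the environment works end to end. This is a minimal sketch, assuming a CUDA-capable machine and using `thunder.jit` as the compilation entry point:

```python
import torch
import thunder

def foo(a, b):
    return a + b

# thunder.jit compiles the callable; with the packages installed above,
# fusions on CUDA tensors are executed through nvFuser.
jfoo = thunder.jit(foo)

a = torch.randn(2, 2, device="cuda")
b = torch.randn(2, 2, device="cuda")
print(jfoo(a, b))
```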
7 changes: 4 additions & 3 deletions docs/source/fundamentals/installation.rst
@@ -6,11 +6,12 @@ Minimal dependencies

Follow these instructions to install PyTorch, nvFuser, and finally Thunder.

Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1)::
Install PyTorch and nvFuser with pip (command shown is for CUDA 12.1 and PyTorch 2.4.x)::

pip install --pre "nvfuser-cu121[torch]" --extra-index-url https://pypi.nvidia.com
pip install --pre nvfuser-cu121-torch24

cu121 can be replaced with cu118 depending on your CUDA version.
cu121 can be replaced with cu118 depending on your CUDA version. nvFuser builds typically support the latest point release of each stable PyTorch version.
For PyTorch 2.4, cu124 is also supported. For nightly versions and more detailed instructions, please see https://github.com/NVIDIA/Fuser/#installation
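For example, assuming the same package naming scheme as above, the CUDA 12.4 build for PyTorch 2.4 would be installed with::

    pip install --pre nvfuser-cu124-torch24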

You're all set with minimal dependencies, so you can follow `Install Thunder`_.

2 changes: 1 addition & 1 deletion thunder/benchmarks/benchmark_litgpt.py
@@ -392,7 +392,7 @@ def setup_compile(self, model):
dynamo_config.cache_size_limit = 64
model = torch.compile(model)
elif "thunder" in self.compile:
executors = [thunder.nvfuser_executor, thunder.pytorch_executor]
executors = thunder.get_default_executors()
if "inductor_cat" in self.compile:
from thunder.executors.torch_compile import torch_compile_cat_ex as torch_compile_ex

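For context, a minimal sketch of how such an executor list is consumed, assuming `thunder.jit` accepts an `executors` argument (the nvFuser executor is part of the defaults only when nvFuser is importable):

```python
import torch
import thunder

model = torch.nn.Linear(8, 8)

# get_default_executors() returns Thunder's registered default executors,
# e.g. the nvFuser executor (when available) followed by the PyTorch executor.
executors = thunder.get_default_executors()

jmodel = thunder.jit(model, executors=executors)
out = jmodel(torch.randn(4, 8))
```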
250 changes: 248 additions & 2 deletions thunder/executors/nvfuserex_impl.py
@@ -16,6 +16,8 @@

import thunder.core.dtypes as dtypes
import thunder.torch as ltorch
from thunder.torch import TensorLike

from thunder.core import prims, utils
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.prims import PrimIDs
@@ -34,16 +36,29 @@
from thunder.core.utils import OrderedSet, check, check_same_dtype
from thunder.core.trace import TraceCtx, from_trace, TraceProvenance
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, Symbol, has_tags
from thunder.core.devices import Device, DeviceType
from thunder.core.devices import Device, DeviceType, cpu
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
from thunder.core.profile import add_markers
from thunder.core.compile_data import get_compile_option

from thunder.executors.utils import Region
from thunder.core.transforms import (
get_grad,
put_grads,
)

from thunder.executors.utils import (
Region,
_input_dtype_check_fused_scaled_dot_product_attention,
_input_shape_check_fused_scaled_dot_product_attention,
_fused_sdp_choice,
SpdaBackend,
)

from thunder.executors.passes import update_fusion_call_ctx
from thunder.extend import FUEL_LEVEL, FusionExecutor, register_executor, add_default_executor
from thunder.executors.nvfuserex import nvfuser_version

# NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
# by nvfuserex.py when nvFuser is available.
@@ -2208,3 +2223,234 @@ def matmul(


register_supported(PrimIDs.MATMUL, matmul, _matmul_check)


# Registering SDPA operators for nvFuser
# SDPA requires an execution transform and a grad transform, since the forward and backward passes are called through different implementations.
# For both transforms, a new operator is registered with nvfuserex (ex.register_operator) and then added to the translation map (register_supported).
# The operators are tagged with OpTags.RANDOM_OP to prevent rematerialization in the backward pass.
# Finally, the complete rule is registered through ex.register_supported, with the execution and grad transforms wrapping these operators.


# SDPA Forward
def _scaled_dot_product_flash_attention_forward_meta(
query: TensorLike,
key: TensorLike,
value: TensorLike,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
) -> tuple[TensorProxy, TensorProxy, TensorProxy, TensorProxy]:
# Reference metadata:
# * query (batch_size, num_heads, query_seq_len, E)
# * key (batch_size, num_heads, key_seq_len, E)
# * value (batch_size, num_heads, key_seq_len, Ev)
# * output (batch_size, num_heads, query_seq_len, Ev)

# at::_scaled_dot_product_flash_attention returns {output, log_sumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask}.
# In nvFuser, we only save {output, log_sumexp, philox_seed/offset} for backward since the other variables are not required for non-nested input tensors.
    # For non-nested tensors, cum_seq_q/k are undefined, max_q/k can be inferred from the input sizes, and we set `return_debug_mask=False`, so `debug_attn_mask` is a 1D zero tensor.

batch_size, num_heads, query_seq_len, E = query.shape
key_seq_len = key.shape[2]

return (
output := TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E)),
log_sumexp := TensorProxy(
shape=(batch_size, num_heads, query_seq_len), dtype=dtypes.float32, device=query.device, requires_grad=False
),
philox_seed := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False),
philox_offset := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False),
)


def _scaled_dot_product_flash_attention_forward(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:

inputs = [query, key, value, dropout_p, is_causal, scale]
nv_inputs = []
for inp in inputs:
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
nv_inputs.append(nv_inp)

return fd.ops.sdpfa_fwd(*nv_inputs)


nv_sdpfa_fwd = ex.register_operator(
"nv_sdpfa_fwd",
meta=_scaled_dot_product_flash_attention_forward_meta,
fn=_scaled_dot_product_flash_attention_forward,
tags=[prims.OpTags.RANDOM_OP],
)

register_supported(nv_sdpfa_fwd.id, _scaled_dot_product_flash_attention_forward, None)


# SDPA Backward
def _scaled_dot_product_flash_attention_backward_meta(
grad_out: TensorLike,
query: TensorLike,
key: TensorLike,
value: TensorLike,
out: TensorLike,
logsumexp: TensorLike,
dropout_p: float,
is_causal: bool,
philox_seed: TensorLike,
philox_offset: TensorLike,
*,
scale: None | float = None,
) -> tuple[TensorProxy, TensorProxy, TensorProxy]:

batch_size, num_heads, query_seq_len, E = query.shape
key_seq_len = key.shape[2]

# Reference metadata:
# https://github.com/pytorch/pytorch/blob/f57b00704e498a676854a02974ca9e0c42188b23/torch/_meta_registrations.py#L5043-L5063
grad_query = TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E))
grad_key = TensorProxy(like=key, shape=(batch_size, num_heads, key_seq_len, E))
grad_value = TensorProxy(like=value, shape=(batch_size, num_heads, key_seq_len, E))
return (grad_query, grad_key, grad_value)


def _scaled_dot_product_flash_attention_backward(
grad_out: TensorProxy,
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
out: TensorProxy,
logsumexp: TensorProxy,
dropout_p: float,
is_causal: bool,
philox_seed: TensorProxy,
philox_offset: TensorProxy,
*,
scale: None | float = None,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

inputs = [grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, philox_seed, philox_offset, scale]
nv_inputs = []
for inp in inputs:
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
nv_inputs.append(nv_inp)

return fd.ops.sdpfa_bwd(*nv_inputs)


nv_sdpfa_bwd = ex.register_operator(
"nv_sdpfa_bwd",
meta=_scaled_dot_product_flash_attention_backward_meta,
fn=_scaled_dot_product_flash_attention_backward,
tags=[prims.OpTags.RANDOM_OP],
)

register_supported(nv_sdpfa_bwd.id, _scaled_dot_product_flash_attention_backward, None)


# Checker for SDPA
def _scaled_dot_product_flash_attention_check(
query: Proxy,
key: Proxy,
value: Proxy,
attn_mask: Proxy | None,
dropout_p: float,
is_causal: bool,
*,
scale: None | float = None,
) -> bool:

    # fd.ops.sdpfa_fwd and fd.ops.sdpfa_bwd were added in nvFuser versions 0.2.9 and 0.2.10, respectively.
if nvfuser_version() < LooseVersion("0.2.10"):
return False

enable_sdpa: None | bool = get_compile_option("nv_enable_sdpa", "Enable nvFuser flash attention SDPA.")

if not enable_sdpa:
return False

# Flash attn does not support attn_mask currently.
if attn_mask is not None:
return False

if not are_supported_tensors(query, key, value):
return False

# FP64 is not supported by flash attention
supported_dtypes = (dtypes.float16, dtypes.bfloat16)
    _input_dtype_check_fused_scaled_dot_product_attention(query, key, value, attn_mask, supported_dtypes)
    _input_shape_check_fused_scaled_dot_product_attention(query, key, value, attn_mask)

# nvFuser only implements flash attention currently.
backend = _fused_sdp_choice(query, key, value, None, dropout_p, is_causal, scale)
return backend == SpdaBackend.FLASH_ATTENTION


# SDPA execution_transform -- calls nv_sdpfa_fwd operator registered above
def scaled_dot_product_flash_attention(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
attn_mask: TensorProxy = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
):
(attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd(
query, key, value, dropout_p, is_causal, scale=scale
)
return attn_output


# SDPA grad_transform -- calls nv_sdpfa_fwd and nv_sdpfa_bwd registered above
def scaled_dot_product_flash_attention_grad(
query: Proxy,
key: Proxy,
value: Proxy,
attn_mask: None | Proxy,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
):

(attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd(
query, key, value, dropout_p, is_causal, scale=scale
)
grad_out = get_grad(attn_output)
grad_query, grad_key, grad_val = nv_sdpfa_bwd(
grad_out,
query,
key,
value,
attn_output,
logsumexp,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
put_grads((query, key, value), (grad_query, grad_key, grad_val))
return attn_output


# Register the complete rule for SDPA in the nvFuser executor
ex.register_supported(
ltorch.scaled_dot_product_attention,
checker=_scaled_dot_product_flash_attention_check,
execution_transform=scaled_dot_product_flash_attention,
grad_transform=scaled_dot_product_flash_attention_grad,
)
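A usage sketch for the rule registered above, assuming nvFuser >= 0.2.10 and that compile options such as `nv_enable_sdpa` are forwarded to `get_compile_option` via keyword arguments to `thunder.jit`:

```python
import torch
import thunder

def attn(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# The checker above requires fp16/bf16 CUDA tensors, no attn_mask, and the
# flash-attention backend; nv_enable_sdpa opts in to the nvFuser SDPA rule.
jattn = thunder.jit(attn, nv_enable_sdpa=True)

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = jattn(q, k, v)
```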