Makes cudnn a default executor (#427)
Co-authored-by: Vedaanta Agarwalla <vagarwalla@ipp2-1949.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Thomas Viehmann <tv@beamnet.de>
Co-authored-by: Luca Antiga <luca@lightning.ai>
6 people authored May 28, 2024
1 parent 2fa5cab commit 7bd637a
Showing 4 changed files with 78 additions and 14 deletions.
13 changes: 11 additions & 2 deletions README.md
@@ -73,14 +73,23 @@ The easiest way to get started with Thunder, requiring no extra installations or

## Install Thunder

To use Thunder on your local machine, first install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows:
To use Thunder on your local machine:

- install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows:

```bash
# install nvFuser which installs the matching nightly PyTorch
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
```

Then, install Thunder as follows:
- install [cudnn](https://gitlab-master.nvidia.com/cudnn/cudnn_frontend) as follows:

```bash
# install cudnn
pip install nvidia-cudnn-frontend
```

- Finally, install Thunder as follows:

```
# install thunder
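
Once both wheels are installed, a quick sanity check can confirm that the cudnn frontend imports and that Thunder registered it as a default executor. This is a minimal sketch, not part of this commit; it assumes the installs above succeeded and uses `thunder.get_default_executors` as exposed in `thunder/__init__.py` below:

```python
# Post-install sanity check -- a sketch, not from the README.
import cudnn  # provided by the nvidia-cudnn-frontend wheel
import thunder

# The frontend may or may not expose __version__, as noted in cudnnex.py below.
print(getattr(cudnn, "__version__", "no __version__ attribute"))

# With this commit, the cudnn executor should appear first when it is usable.
print(thunder.get_default_executors())
```
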
8 changes: 7 additions & 1 deletion thunder/__init__.py
@@ -110,6 +110,7 @@
    # TODO Extend this
    # TODO Add device aliases
    # TODO Add executor aliases
    "cudnn_executor",
    "sdpa_executor",
    "nvfuser_executor",
    "pytorch_executor",
@@ -163,17 +164,22 @@ def __version__():
get_default_executors = extend.get_default_executors
get_always_executors = extend.get_always_executors

cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
pytorch_executor: None | extend.Executor = extend.get_executor("torch")

# Default executor list is [sdpa -> nvfuser -> torch -> python]
# Default executor list is [cudnn -> sdpa -> nvfuser -> torch -> python]
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
if nvfuser_executor:
    add_default_executor(nvfuser_executor)

if sdpa_executor:
    add_default_executor(sdpa_executor)

if cudnn_executor:
    add_default_executor(cudnn_executor)

#
# Promoted debugging functions
#
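
Because `add_default_executor` prepends to the list, the registrations above run in reverse priority order. The following is an illustrative sketch of that prepend behavior; `default_executors` and `prepend` are hypothetical stand-ins, not Thunder internals:

```python
# Hypothetical stand-ins illustrating why the registrations run in reverse order.
default_executors: list[str] = ["python"]

def prepend(name: str) -> None:
    # Mirrors add_default_executor inserting at the start of the list.
    default_executors.insert(0, name)

for name in ("torch", "nvfuser", "sdpa", "cudnn"):  # reverse of the desired priority
    prepend(name)

print(default_executors)  # ['cudnn', 'sdpa', 'nvfuser', 'torch', 'python']
```
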
64 changes: 55 additions & 9 deletions thunder/executors/cudnnex.py
@@ -5,16 +5,36 @@
import random

from lightning_utilities.core.imports import package_available
from looseversion import LooseVersion


def cudnn_available() -> bool:
    return package_available("cudnn")
#
# Functions for detecting cudnn and its version
#
def cudnn_version() -> LooseVersion | None:
    try:
        import cudnn

        if hasattr(cudnn, "__version__"):
            return LooseVersion(cudnn.__version__)

        # NOTE: This import of cudnn may or may not have version info
        return LooseVersion("0.0.0")
    except ImportError:
        pass

    # NOTE This occurs when cudnn couldn't be imported
    return None


def cudnn_version() -> int:
    if cudnn_available():
        return cudnn.backend_version()
    return 0
def required_cudnn_version() -> LooseVersion:
    # Using 1.3.0 mainly because it works better with other libraries (e.g. torch) that also build on top of the cudnn backend
    return LooseVersion("1.3.0")


def cudnn_available() -> bool:
    v = cudnn_version()
    return v is not None and v >= required_cudnn_version()


cudnn: None | Any = None
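
The availability check now gates on a minimum frontend version instead of bare importability. A short illustration of the `LooseVersion` comparison used above (the version strings are made up):

```python
from looseversion import LooseVersion

required = LooseVersion("1.3.0")  # mirrors required_cudnn_version()

# Hypothetical installed versions and how cudnn_available() would treat them.
print(LooseVersion("1.2.1") >= required)  # False -> cudnn executor not registered
print(LooseVersion("1.3.0") >= required)  # True
print(LooseVersion("0.0.0") >= required)  # False -> a frontend without version info is rejected
```
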
@@ -354,6 +374,9 @@ def _cudnn_sdpa_checker(
    if cudnn is None:
        return False

    if query.device.type != "cuda" or key.device != query.device or value.device != query.device:
        return False

    if len(query.size()) != 4:
        return False
    _, _, _, d_q = query.size()
@@ -367,6 +390,32 @@
    if d % 8 != 0 or d > 128:
        return False

    try:
        # Build both forward and backward graphs
        query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
        _make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal)
        _make_cudnn_sdpa_backward_graph(
            query_4d,
            key_4d,
            value_4d,
            attn_mask_4d,
            dropout_p,
            is_causal,
            query_4d.stride,
            key_4d.stride,
            value_4d.stride,
        )
    # If cudnn can't support the graph, return False.
    # Please turn on cudnn API logging for helpful messages that mention why the graph is not supported.
    # For cudnn backend logging, refer to https://docs.nvidia.com/deeplearning/cudnn/latest/reference/troubleshooting.html
    # For cudnn frontend logging, refer to https://gitlab-master.nvidia.com/cudnn/cudnn_frontend#debugging
    except cudnn.cudnnGraphNotSupportedError as ex:
        return False
    # Otherwise just raise the error.
    # These errors can be due to internal cudnn bugs, or user error.
    except Exception as e:
        raise

    return True
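
With cudnn registered as a default executor, the checker above decides per call whether a scaled-dot-product-attention operation is claimed by cudnn or falls through to the next executor. A rough usage sketch follows; the shapes, dtype, and the `thunder.jit` entry point are illustrative assumptions rather than part of this commit, and it requires a CUDA device with a supported cudnn frontend:

```python
import torch
import thunder

def attention(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

jitted = thunder.jit(attention)  # cudnn is now tried first among the default executors

# Head dim 64 satisfies the `d % 8 == 0 and d <= 128` check above; all tensors live on CUDA.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = jitted(q, k, v)
```
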


Expand All @@ -383,9 +432,6 @@ def _make_cudnn_sdpa_backward_graph(
    b, h, s_q, _ = query.size
    _, _, _, d_v = value.size

    # cuDNN < 9.0.0 might produce nan gradients for sequence length < 64
    assert s_q >= 64, "CUDNN SDPA requires sequence length to be at least 64 for backward pass"

    graph = cudnn.pygraph(
        io_data_type=torch_to_cudnn_dtype(query.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
7 changes: 5 additions & 2 deletions thunder/tests/test_cudnn_executor.py
@@ -18,7 +18,7 @@

cudnn = pytest.importorskip("cudnn")
from thunder.executors.cudnn_layernormex import cudnn_layernorm_ex
from thunder.executors.cudnnex import cudnn_ex, cudnn_version
from thunder.executors.cudnnex import cudnn_ex


def _maybe_xfail() -> None:
@@ -196,7 +196,10 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
    return result


@pytest.mark.skipif(cudnn_version() < 8905, reason="cuDNN is required to be at least `8.9.5`")
@pytest.mark.skipif(
    LooseVersion(cudnn.backend_version_string()) < LooseVersion("8.9.5"),
    reason="cuDNN is required to be at least `8.9.5`",
)
@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
@pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
