Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes cudnn a default executor #427

Merged
merged 19 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
# TODO Extend this
# TODO Add device aliases
# TODO Add executor aliases
"cudnn_executor",
"sdpa_executor",
"nvfuser_executor",
"pytorch_executor",
Expand Down Expand Up @@ -155,17 +156,22 @@ def __version__():
get_default_executors = extend.get_default_executors
get_always_executors = extend.get_always_executors

cudnn_executor: None | extend.Executor = extend.get_executor("cudnn")
sdpa_executor: None | extend.Executor = extend.get_executor("sdpa")
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
pytorch_executor: None | extend.Executor = extend.get_executor("torch")

# Default executor list is [sdpa -> nvfuser -> torch -> python]
# Default executor list is [cudnn -> sdpa -> nvfuser -> torch -> python]
wujingyue marked this conversation as resolved.
Show resolved Hide resolved
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
if nvfuser_executor:
add_default_executor(nvfuser_executor)

if sdpa_executor:
add_default_executor(sdpa_executor)

if cudnn_executor:
add_default_executor(cudnn_executor)

#
# Promoted debugging functions
#
Expand Down
46 changes: 37 additions & 9 deletions thunder/executors/cudnnex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,33 @@
import random

from lightning_utilities.core.imports import package_available
from looseversion import LooseVersion

#
# Functions for detecting cudnn and its version
#
def cudnn_version() -> LooseVersion | None:
try:
import cudnn

def cudnn_available() -> bool:
return package_available("cudnn")
if hasattr(cudnn, "__version__"):
return LooseVersion(cudnn.__version__)

# NOTE: This import of cudnn may or may not have version info
return LooseVersion("0.0.0")
except ImportError:
pass

# NOTE This occurs when cudnn couldn't be imported
return None

def cudnn_version() -> int:
if cudnn_available():
return cudnn.backend_version()
return 0
def required_cudnn_version() -> LooseVersion:
# Using 1.3.0 majorly because it works better with other libraries (e.g. torch) that also build on top of cudnn backend
return LooseVersion("1.3.0")

def cudnn_available() -> bool:
v = cudnn_version()
return v is not None and v >= required_cudnn_version()

cudnn: None | Any = None
cudnn_backend_version: None | Any = None
Expand Down Expand Up @@ -367,6 +383,21 @@ def _cudnn_sdpa_checker(
if d % 8 != 0 or d > 128:
return False

try:
# Build both forward and backward graphs
query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
_make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal)
_make_cudnn_sdpa_backward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal, query_4d.stride, key_4d.stride, value_4d.stride)

# If cudnn can't support the graph, return false
# Please turn on cudnn API logging for helpful messages that mention why the graph is not supported.
tfogal marked this conversation as resolved.
Show resolved Hide resolved
except cudnn.cudnnGraphNotSupportedError as ex:
return False
# Otherwise just raise the error.
# These errors can be due to internal cudnn bugs, or user error.
except Exception as e:
raise

return True


Expand All @@ -383,9 +414,6 @@ def _make_cudnn_sdpa_backward_graph(
b, h, s_q, _ = query.size
_, _, _, d_v = value.size

# cuDNN < 9.0.0 might produce nan gradients for sequence length < 64
vedaanta marked this conversation as resolved.
Show resolved Hide resolved
assert s_q >= 64, "CUDNN SDPA requires sequence length to be at least 64 for backward pass"

graph = cudnn.pygraph(
io_data_type=torch_to_cudnn_dtype(query.dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_cudnn_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

cudnn = pytest.importorskip("cudnn")
from thunder.executors.cudnn_layernormex import cudnn_layernorm_ex
from thunder.executors.cudnnex import cudnn_ex, cudnn_version
from thunder.executors.cudnnex import cudnn_ex


def _maybe_xfail() -> None:
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
return result


@pytest.mark.skipif(cudnn_version() < 8905, reason="cuDNN is required to be at least `8.9.5`")
@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < LooseVersion('8.9.5'), reason="cuDNN is required to be at least `8.9.5`")
@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
@pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
Expand Down
Loading