diff --git a/README.md b/README.md index e97b97fb48..5557c6f8d1 100644 --- a/README.md +++ b/README.md @@ -73,14 +73,23 @@ The easiest way to get started with Thunder, requiring no extra installations or ## Install Thunder -To use Thunder on your local machine, first install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows: +To use Thunder on your local machine: + +- install [nvFuser](https://github.com/NVIDIA/Fuser) nightly and PyTorch nightly together as follows: ```bash # install nvFuser which installs the matching nightly PyTorch pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com ``` -Then, install Thunder as follows: +- install [cudnn](https://github.com/NVIDIA/cudnn-frontend) as follows: + +```bash +# install cudnn +pip install nvidia-cudnn-frontend +``` + +- Finally, install Thunder as follows: ``` # install thunder diff --git a/thunder/__init__.py b/thunder/__init__.py index 51c8c98778..5eb0441b84 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -110,6 +110,7 @@ # TODO Extend this # TODO Add device aliases # TODO Add executor aliases + "cudnn_executor", "sdpa_executor", "nvfuser_executor", "pytorch_executor", @@ -163,17 +164,22 @@ def __version__(): get_default_executors = extend.get_default_executors get_always_executors = extend.get_always_executors +cudnn_executor: None | extend.Executor = extend.get_executor("cudnn") sdpa_executor: None | extend.Executor = extend.get_executor("sdpa") nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser") pytorch_executor: None | extend.Executor = extend.get_executor("torch") -# Default executor list is [sdpa -> nvfuser -> torch -> python] +# Default executor list is [cudnn -> sdpa -> nvfuser -> torch -> python] +# Note that add_default_executor inserts executor at start of list, hence the reverse order below. 
if nvfuser_executor: add_default_executor(nvfuser_executor) if sdpa_executor: add_default_executor(sdpa_executor) +if cudnn_executor: + add_default_executor(cudnn_executor) + # # Promoted debugging functions # diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 9b92cf0952..37139fdcc4 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -5,16 +5,36 @@ import random from lightning_utilities.core.imports import package_available +from looseversion import LooseVersion -def cudnn_available() -> bool: - return package_available("cudnn") +# +# Functions for detecting cudnn and its version +# +def cudnn_version() -> LooseVersion | None: + try: + import cudnn + + if hasattr(cudnn, "__version__"): + return LooseVersion(cudnn.__version__) + + # NOTE: This import of cudnn may or may not have version info + return LooseVersion("0.0.0") + except ImportError: + pass + + # NOTE This occurs when cudnn couldn't be imported + return None -def cudnn_version() -> int: - if cudnn_available(): - return cudnn.backend_version() - return 0 +def required_cudnn_version() -> LooseVersion: + # Using 1.3.0 mainly because it works better with other libraries (e.g. 
torch) that also build on top of cudnn backend + return LooseVersion("1.3.0") + + +def cudnn_available() -> bool: + v = cudnn_version() + return v is not None and v >= required_cudnn_version() cudnn: None | Any = None @@ -354,6 +374,9 @@ def _cudnn_sdpa_checker( if cudnn is None: return False + if query.device.type != "cuda" or key.device != query.device or value.device != query.device: + return False + if len(query.size()) != 4: return False _, _, _, d_q = query.size() @@ -367,6 +390,32 @@ if d % 8 != 0 or d > 128: return False + try: + # Build both forward and backward graphs + query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) + _make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal) + _make_cudnn_sdpa_backward_graph( + query_4d, + key_4d, + value_4d, + attn_mask_4d, + dropout_p, + is_causal, + query_4d.stride, + key_4d.stride, + value_4d.stride, + ) + # If cudnn can't support the graph, return false + # Please turn on cudnn API logging for helpful messages that mention why the graph is not supported. + # For cudnn backend logging, refer https://docs.nvidia.com/deeplearning/cudnn/latest/reference/troubleshooting.html + # For cudnn frontend logging, refer https://github.com/NVIDIA/cudnn-frontend#debugging + except cudnn.cudnnGraphNotSupportedError as ex: + return False + # Otherwise just raise the error. + # These errors can be due to internal cudnn bugs, or user error. 
+ except Exception as e: + raise + return True @@ -383,9 +432,6 @@ def _make_cudnn_sdpa_backward_graph( b, h, s_q, _ = query.size _, _, _, d_v = value.size - # cuDNN < 9.0.0 might produce nan gradients for sequence length < 64 - assert s_q >= 64, "CUDNN SDPA requires sequence length to be at least 64 for backward pass" - graph = cudnn.pygraph( io_data_type=torch_to_cudnn_dtype(query.dtype), intermediate_data_type=cudnn.data_type.FLOAT, diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 0f8b51d543..55d3ec5769 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -18,7 +18,7 @@ cudnn = pytest.importorskip("cudnn") from thunder.executors.cudnn_layernormex import cudnn_layernorm_ex -from thunder.executors.cudnnex import cudnn_ex, cudnn_version +from thunder.executors.cudnnex import cudnn_ex def _maybe_xfail() -> None: @@ -196,7 +196,10 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): return result -@pytest.mark.skipif(cudnn_version() < 8905, reason="cuDNN is required to be at least `8.9.5`") +@pytest.mark.skipif( + LooseVersion(cudnn.backend_version_string()) < LooseVersion("8.9.5"), + reason="cuDNN is required to be at least `8.9.5`", +) @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv")) @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes()))) def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):