[TEST][CUDNN] xfail SDPA on pre-Ampere (#217)
Co-authored-by: Vedaanta Agarwalla <142048820+vedaanta-nvidia@users.noreply.github.com>
Aidyn-A and vedaanta authored Apr 25, 2024
1 parent 279380e commit d0ba323
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions thunder/tests/test_cudnn_executor.py
@@ -21,6 +21,14 @@
 from thunder.executors.cudnnex import cudnn_ex, cudnn_version
 
 
+def _maybe_xfail() -> None:
+    dev: torch.device = thunder.core.devices.to_torch_device("cuda:0")
+    cuda_major: int
+    cuda_major, _ = torch.cuda.get_device_capability(dev)
+    if cuda_major < 8:
+        pytest.xfail("cuDNN SDPA uses flash attention, which requires Ampere+")
+
+
 # These reference inputs are currently used by cudnnex
 def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, requires_grad, **kwargs):
     """https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html"""
@@ -86,17 +94,12 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
 
 @requiresCUDA
 def test_cudnn_sdpa():
+    _maybe_xfail()
+
     # expect sdpa to fail for 8.9.2 and below
     if cudnn.backend_version() <= 8902:
         pytest.xfail("Only interleaved layout is supported pre 8.9.2.")
 
-    dev: torch.device = thunder.core.devices.to_torch_device("cuda:0")
-    cuda_major: int
-    cuda_minor: int
-    cuda_major, cuda_minor = torch.cuda.get_device_capability(dev)
-    if cuda_major < 8:
-        pytest.xfail("cuDNN SDPA uses flash attention, which requires Ampere+")
-
     for dtype in (thunder.float16, thunder.bfloat16):
         b, h, s_q, s_kv, d_q, d_v = 8, 8, 256, 256, 64, 64
         shape_Q = (b, h, s_q, d_q)
@@ -162,6 +165,8 @@ def snippet_torch_consistency(op, torch_op, sample):
     supported_executors=(TorchExecutor,),
 )
 def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
+    _maybe_xfail()
+
     if cudnn.backend_version() < 8905:  # todo: could be more specific, just for some cases?
         pytest.xfail("s_kv not a multiple of 64 required cudnn version atleast 8.9.5")
 
@@ -195,6 +200,8 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
 @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
 @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
 def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
+    _maybe_xfail()
+
     for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True):
         # Enforce tensor arguments are contiguous for torch reference
         contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args))
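For context, a minimal sketch (not part of this commit) of the capability gate the new _maybe_xfail helper applies: torch.cuda.get_device_capability returns a (major, minor) compute-capability pair, and flash attention requires Ampere (sm_80) or newer, so any major version below 8 (Volta sm_70, Turing sm_75) is expected to fail. The predicate name below is hypothetical.

import torch

def supports_flash_attention(device_index: int = 0) -> bool:
    # (major, minor) compute capability, e.g. (8, 0) on A100, (7, 0) on V100.
    major, _minor = torch.cuda.get_device_capability(device_index)
    # cuDNN SDPA's flash-attention path requires Ampere (sm_80) or newer.
    return major >= 8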
