reenable cudnn sdpa
t-vi committed Jun 24, 2024
1 parent 4ce822b commit 5f65e50
Showing 3 changed files with 0 additions and 22 deletions.
9 changes: 0 additions & 9 deletions thunder/tests/opinfos.py
@@ -7880,15 +7880,6 @@ def grad_scaled_dot_product_attention_sample_generator(op, device, dtype, requir
     # NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention' with arguments from the 'CPU' backend.
     # NOTE: NotImplementedError: Could not run 'aten::_scaled_dot_product_efficient_attention_backward' with arguments from the 'CPU' backend
     devicetypes=(devices.DeviceType.CUDA,),
-    test_directives=(
-        DecorateInfo(
-            pytest.mark.skip(reason="https://github.com/Lightning-AI/lightning-thunder/issues/567"),
-            "test_core_vs_torch_consistency",
-            dtypes=(datatypes.bfloat16, datatypes.float16, datatypes.float32),
-            devicetypes=(devices.DeviceType.CUDA,),
-            active_if=version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"),
-        ),
-    ),
 )
 nn_ops.append(grad_sdpa_opinfo)

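Note: the deleted test_directives entry used thunder's DecorateInfo machinery to skip test_core_vs_torch_consistency for this opinfo whenever torch.__version__ fell inside the 2.4.0 pre-release range. A self-contained sketch of a version-range predicate of the kind version_between expresses is shown below; it uses packaging.version for illustration and is not thunder's actual helper.

    from packaging.version import Version

    def in_version_range(version: str, *, min_ver: str, max_ver: str) -> bool:
        """Return True if min_ver <= version <= max_ver (inclusive on both ends)."""
        # Drop a local build suffix such as "+cu121" before parsing as PEP 440.
        v = Version(version.split("+")[0])
        return Version(min_ver) <= v <= Version(max_ver)

    # The removed DecorateInfo was active only for torch 2.4.0 pre-releases:
    print(in_version_range("2.4.0a0+git1234ab", min_ver="2.4.0a0", max_ver="2.4.0a99"))  # True
    print(in_version_range("2.3.1", min_ver="2.4.0a0", max_ver="2.4.0a99"))              # False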
4 changes: 0 additions & 4 deletions thunder/tests/test_cudnn_executor.py
@@ -200,10 +200,6 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
     LooseVersion(cudnn.backend_version_string()) < LooseVersion("8.9.5"),
     reason="cuDNN is required to be at least `8.9.5`",
 )
-@pytest.mark.skipif(
-    version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"),
-    reason="https://github.com/Lightning-AI/lightning-thunder/issues/567",
-)
 @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
 @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
 def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
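Note: the cuDNN minimum-version guard above the removed skipif stays in place; it follows the standard pytest pattern where the skip condition is evaluated once at collection time and applies to the whole parametrized matrix. A hypothetical, self-contained example of that shape is sketched below (the backend_version stub stands in for a real query such as cudnn.backend_version_string(), and packaging.version replaces LooseVersion only to keep the sketch dependency-light).

    import pytest
    from packaging.version import Version

    def backend_version() -> str:
        # Stand-in for querying the installed backend, e.g. cudnn.backend_version_string().
        return "8.9.7"

    @pytest.mark.skipif(
        Version(backend_version()) < Version("8.9.5"),
        reason="backend is required to be at least `8.9.5`",
    )
    @pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
    def test_example(may_cat_grad_qkv):
        # The skip condition is computed at collection time, so an old backend
        # skips every parametrization of this test together.
        assert may_cat_grad_qkv in (True, False)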
9 changes: 0 additions & 9 deletions thunder/tests/test_grad.py
@@ -541,21 +541,12 @@ def test_vjp_correctness_index_put_manual(op, device, dtype, executor, comp):
 
 # NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes
 # RuntimeError: Only fp32, half & bf16 supported at the moment
-@pytest.mark.skipif(
-    not version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"),
-    reason="https://github.com/Lightning-AI/lightning-thunder/issues/567",
-)
 @ops(
     (get_opinfo("grad_forward_scaled_dot_product_attention"),),
     supported_dtypes=(dtypes.float16, dtypes.bfloat16),
     supported_devicetypes=(devices.DeviceType.CUDA,),
 )
 def test_vjp_correctness_sdpa_manual(op, device, dtype, executor, comp):
-    if version_between(torch.__version__, min_ver="2.4.0a0", max_ver="2.4.0a99"):
-        raise pytest.skip(
-            "https://github.com/Lightning-AI/lightning-thunder/issues/567",
-        )
-
     for sample in op.sample_inputs(device, dtype, requires_grad=True):
         from thunder.executors.sdpaex import sdpa_ex
 
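Note: with the unconditional skip gone, test_vjp_correctness_sdpa_manual again exercises the backward of scaled_dot_product_attention against PyTorch autograd. A rough sketch of that kind of VJP check is given below, assuming a CUDA device and a thunder installation; it only illustrates the comparison pattern and is not the repository's test body.

    import torch
    import thunder

    def sdpa(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

    shape = (2, 4, 128, 64)  # batch, heads, sequence length, head dim
    q, k, v = (torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True) for _ in range(3))
    cotangent = torch.randn(shape, device="cuda", dtype=torch.float16)

    # Gradients through the thunder-compiled function.
    out = thunder.jit(sdpa)(q, k, v)
    grads_thunder = torch.autograd.grad(out, (q, k, v), grad_outputs=cotangent)

    # Reference gradients through eager PyTorch on fresh leaf tensors.
    q_ref, k_ref, v_ref = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
    out_ref = sdpa(q_ref, k_ref, v_ref)
    grads_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), grad_outputs=cotangent)

    for g, g_ref in zip(grads_thunder, grads_ref):
        torch.testing.assert_close(g, g_ref, rtol=1e-2, atol=1e-2)  # loose tolerances for fp16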
