diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 9c53088c9d..a2d61bb626 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -334,25 +334,7 @@ def _cudnn_sdpa_checker( if d % 8 != 0 or d > 128: return False - is_backward_supported = _cudnn_sdpa_backward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - - return True and is_backward_supported - - -@langctx("torch") -def _cudnn_sdpa_backward_checker( - query: TensorLike, - key: TensorLike, - value: TensorLike, - attn_mask: TensorLike | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -) -> bool: - return cudnn is not None + return True cudnn_sdpa_fwd = cudnn_ex.register_operator(