From 74bf361b1497e2af73ef647801a0f9a96ec4d2d3 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 22 Mar 2024 23:57:02 +0000 Subject: [PATCH] Minor cleanups. --- thunder/executors/cudnnex.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 39593d59f8..9c53088c9d 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -307,9 +307,7 @@ def _cudnn_sdpa_fwd_impl( return O_actual, softmax_stats_actual, seed_tensor, offset_tensor -# NOTE Uses the torch language context to resolve .size calls -@langctx("torch") -def _cudnn_sdpa_forward_checker( +def _cudnn_sdpa_checker( query: TensorLike, key: TensorLike, value: TensorLike, @@ -319,6 +317,7 @@ def _cudnn_sdpa_forward_checker( *, scale: float | None = None, ) -> bool: + # TODO(#58): make the checker more conservative. if cudnn is None: return False @@ -653,7 +652,7 @@ def _cudnn_sdpa_grad( # Registers the implementation for torch.nn.functional.scaled_dot_product_attention cudnn_ex.register_implementation( ltorch.scaled_dot_product_attention, - checker=_cudnn_sdpa_forward_checker, + checker=_cudnn_sdpa_checker, execution_transform=_cudnn_sdpa_transform, grad_transform=_cudnn_sdpa_grad, )