Skip to content

Commit

Permalink
Minor cleanups.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Mar 22, 2024
1 parent 3e7b7aa commit 74bf361
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions thunder/executors/cudnnex.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,7 @@ def _cudnn_sdpa_fwd_impl(
return O_actual, softmax_stats_actual, seed_tensor, offset_tensor


# NOTE Uses the torch language context to resolve .size calls
@langctx("torch")
def _cudnn_sdpa_forward_checker(
def _cudnn_sdpa_checker(
query: TensorLike,
key: TensorLike,
value: TensorLike,
Expand All @@ -319,6 +317,7 @@ def _cudnn_sdpa_forward_checker(
*,
scale: float | None = None,
) -> bool:
# TODO(#58): make the checker more conservative.
if cudnn is None:
return False

Expand Down Expand Up @@ -653,7 +652,7 @@ def _cudnn_sdpa_grad(
# Registers the implementation for torch.nn.functional.scaled_dot_product_attention
cudnn_ex.register_implementation(
ltorch.scaled_dot_product_attention,
checker=_cudnn_sdpa_forward_checker,
checker=_cudnn_sdpa_checker,
execution_transform=_cudnn_sdpa_transform,
grad_transform=_cudnn_sdpa_grad,
)

0 comments on commit 74bf361

Please sign in to comment.