Renaming.
wujingyue committed Mar 23, 2024
1 parent 357ee11 commit d9ad71a
Showing 2 changed files with 20 additions and 20 deletions.
34 changes: 17 additions & 17 deletions thunder/executors/cudnnex.py
@@ -484,9 +484,9 @@ def _cudnn_sdpa_bwd_meta(
     philox_offset: TensorLike,
     *,
     scale: None | float = None,
-    preallocate_grad_qkv: bool,
+    cat_grad_qkv: bool,
 ) -> tuple[TensorProxy, ...]:
-    if preallocate_grad_qkv:
+    if cat_grad_qkv:
         grad_qkv = TensorProxy(
             like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1))
         )
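The meta function computes the dQKV proxy shape with a _replace_dim_with helper that this diff never shows. As a reading aid, here is a minimal reconstruction inferred only from the call site above; it is an assumption, not the repository's actual definition:

def _replace_dim_with(shape, dim: int, new_size: int) -> list[int]:
    # Return a copy of `shape` with the size at index `dim` set to `new_size`.
    new_shape = list(shape)
    new_shape[dim] = new_size
    return new_shape

# With query.size() == [2, 8, 128, 64] and key/value contributing 2 heads each,
# the concatenated dQKV proxy would get shape [2, 12, 128, 64].
assert _replace_dim_with([2, 8, 128, 64], 1, 8 + 2 + 2) == [2, 12, 128, 64]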
@@ -509,7 +509,7 @@ def _same_size_except(*args, except_dim: int) -> bool:
     return all(shape == shapes[0] for shape in shapes)


-def _preallocate_grad_qkv(
+def _cat_grad_qkv(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -540,18 +540,18 @@ def _cudnn_sdpa_bwd_impl(
     philox_offset: torch.Tensor,
     *,
     scale: None | float = None,
-    preallocate_grad_qkv: bool,
+    cat_grad_qkv: bool,
 ) -> tuple[torch.Tensor, ...]:
     query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
     query = _sdpa_enforce_input_tensor_contiguity(query)
     key = _sdpa_enforce_input_tensor_contiguity(key)
     value = _sdpa_enforce_input_tensor_contiguity(value)

-    # When preallocate_grad_qkv is on, allocate dQKV and make dQ, dK, and dV
+    # When cat_grad_qkv is on, allocate dQKV and make dQ, dK, and dV
     # slices of that. Otherwise, allocate them individually.
     grad_qkv: None | torch.Tensor = None
-    if preallocate_grad_qkv:
-        grad_qkv = _preallocate_grad_qkv(query, key, value)
+    if cat_grad_qkv:
+        grad_qkv = _cat_grad_qkv(query, key, value)
         grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
     else:
         grad_query = torch.empty_like(query)
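Since the implementation hands the split results straight to cuDNN, the whole trick rests on torch.Tensor.split returning views into one contiguous buffer. A minimal, CPU-runnable sketch of that pattern with made-up shapes (assuming the [batch, heads, seq, head_dim] layout implied by the dim-1 split above; the real _cat_grad_qkv may differ in details):

import torch

q = torch.empty(2, 8, 128, 64)
k = torch.empty(2, 2, 128, 64)  # GQA-style: fewer KV heads, so only dim 1 differs
v = torch.empty(2, 2, 128, 64)

# One contiguous buffer holds dQ, dK, and dV back to back along dim 1.
grad_qkv = torch.empty((q.size(0), q.size(1) + k.size(1) + v.size(1), *q.shape[2:]), dtype=q.dtype)
grad_q, grad_k, grad_v = grad_qkv.split([q.size(1), k.size(1), v.size(1)], dim=1)

# split returns views, not copies, so a downstream consumer such as nvFuser
# can treat dQ/dK/dV as one tensor without an expensive torch.cat.
assert grad_q.data_ptr() == grad_qkv.data_ptr()
assert grad_q.shape == q.shape and grad_k.shape == k.shape and grad_v.shape == v.shape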
@@ -620,7 +620,7 @@ def _cudnn_sdpa_bwd_impl(

     graph.execute(cudnn_to_torch_tensor, workspace)

-    if preallocate_grad_qkv:
+    if cat_grad_qkv:
         grads = (grad_qkv,)
     else:
         grads = (grad_query, grad_key, grad_value)
@@ -672,14 +672,14 @@ def _cudnn_sdpa_bwd_wrapper(
 This flag is for enabling nvFuser's zipping optimization that seeks to avoid
 expensive concatenation.
 https://github.com/NVIDIA/Fuser/issues/1502#issuecomment-1870837878 has more
-details. When this flag is true, cudnn_sdpa_bwd may preallocate dQ, dK and dV
-in **one** tensor and return them as slices of that tensor.
+details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV
+as **one** tensor and return them as slices of that tensor.
 """
-    may_preallocate: None | bool = get_compile_option("cudnn_sdpa_bwd_may_preallocate", description)
-    if may_preallocate is None:
-        may_preallocate = False
-    assert isinstance(may_preallocate, bool)
-    preallocate = may_preallocate and _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
+    may_cat_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description)
+    if may_cat_grad_qkv is None:
+        may_cat_grad_qkv = False
+    assert isinstance(may_cat_grad_qkv, bool)
+    cat_grad_qkv = may_cat_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1)

     grads = cudnn_sdpa_bwd(
         get_grad(primal),
@@ … @@
         seed,
         offset,
         scale=scale,
-        preallocate_grad_qkv=preallocate,
+        cat_grad_qkv=cat_grad_qkv,
     )

     if attn_mask is not None:
         grad_attn_mask = grads[-1]
         grads = grads[:-1]
         put_grad(attn_mask, grad_attn_mask)

-    if preallocate:
+    if cat_grad_qkv:
         (grad_qkv,) = grads
         grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
     else:
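Two conditions gate the optimization in the wrapper above: the caller must opt in through the cudnn_sdpa_bwd_may_cat_grad_qkv compile option, and the three gradients must actually be concatenable along dim 1. The body of _same_size_except is only partially visible in this diff; the reconstruction below matches its signature and visible return statement but is an assumption rather than the repository's exact code:

def _same_size_except(*args, except_dim: int) -> bool:
    # True when all shapes agree on every dimension other than `except_dim`.
    shapes = [list(shape) for shape in args]
    for shape in shapes:
        shape[except_dim] = -1  # mask out the dimension allowed to differ
    return all(shape == shapes[0] for shape in shapes)

# GQA-style sizes qualify, since only the head count (dim 1) differs...
assert _same_size_except((2, 8, 128, 64), (2, 2, 128, 64), (2, 2, 128, 64), except_dim=1)
# ...but a mismatched sequence length rules the concatenation out.
assert not _same_size_except((2, 8, 128, 64), (2, 8, 256, 64), except_dim=1)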
6 changes: 3 additions & 3 deletions thunder/tests/test_cudnn_executor.py
@@ -188,9 +188,9 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
     return result


-@pytest.mark.parametrize("may_preallocate", (True, False), ids=("may-preallocate", "never-preallocate"))
+@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
 @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
-def test_vjp_correctness_cudnn_sdpa(dtype, may_preallocate):
+def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
     for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True):
         # Enforce tensor arguments are contiguous for torch reference
         contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args))
@@ -219,7 +219,7 @@ def test_vjp_correctness_cudnn_sdpa(dtype, may_preallocate):
             disable_torch_autograd_support=True,
             disable_preprocessing=True,
             executors_list=[cudnn_ex],
-            cudnn_sdpa_bwd_may_preallocate=may_preallocate,
+            cudnn_sdpa_bwd_may_cat_grad_qkv=may_cat_grad_qkv,
         )

         actual_out, actual_grad = cfoo(filtered_args, (v,))
