Allocate dQ, dK, and dV as a catted tensor to save a downstream cat in nvFuser. #59

Merged · 15 commits · Mar 27, 2024
162 changes: 122 additions & 40 deletions thunder/executors/cudnnex.py
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any

import torch
import numpy as np
@@ -28,6 +28,7 @@ def cudnn_available() -> bool:
from thunder.core.langctxs import langctx
import thunder.core.dtypes as dtypes
from thunder.torch import TensorLike
from thunder.core.compile_data import get_compile_option
from thunder.core.proxies import Proxy, TensorProxy


@@ -207,6 +208,11 @@ def compute_NHWC_strides(shape):
# And when registering for sdpa, cudnn assumes NHWC layout. (See _transform_sdpa_inputs())
def _sdpa_enforce_input_tensor_contiguity(a: torch.Tensor) -> torch.Tensor:
if a.stride(-1) == 1:
# TODO(vedaanta-nvidia): there's an inconsistency between
# _transform_sdpa_inputs and this function, leading to a potential bug.
# _transform_sdpa_inputs always creates contiguous strides, but code
# here creates a partially contiguous stride when the last dimension is
# contiguous but other dimensions are not.
return a
else:
return a.contiguous()
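# For example (hypothetical shape): torch.empty(2, 128, 8, 64).permute(0, 2, 1, 3) has
# stride(-1) == 1 but is not fully contiguous, so it is returned unchanged here even
# though _transform_sdpa_inputs computes its strides as if it were contiguous.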
@@ -344,7 +350,9 @@ def _cudnn_sdpa_checker(
)


def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_causal):
def _make_cudnn_sdpa_backward_graph(
query, key, value, attn_mask, dropout_p, is_causal, grad_query_stride, grad_key_stride, grad_value_stride
):
b, h, s_q, _ = query.size
_, _, _, d_v = value.size

@@ -414,9 +422,13 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_
dropout=dropout_tuple,
)

dQ.set_output(True).set_dim(query.size).set_stride(query.stride).set_data_type(torch_to_cudnn_dtype(query.dtype))
dK.set_output(True).set_dim(key.size).set_stride(key.stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
dV.set_output(True).set_dim(value.size).set_stride(value.stride).set_data_type(torch_to_cudnn_dtype(value.dtype))
dQ.set_output(True).set_dim(query.size).set_stride(grad_query_stride).set_data_type(
torch_to_cudnn_dtype(query.dtype)
)
dK.set_output(True).set_dim(key.size).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
dV.set_output(True).set_dim(value.size).set_stride(grad_value_stride).set_data_type(
torch_to_cudnn_dtype(value.dtype)
)

# Validate the graph before querying the cache key
# Validation makes sure all missing properties are inferred and filled, as they affect cache key.
@@ -450,7 +462,11 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_
return _cudnnex_cache[cache_key]


def cudnn_sdpa_backward_meta(
def _replace_dim_with(size: torch.Size, dim: int, dim_size: int) -> torch.Size:
return torch.Size(size[:dim] + (dim_size,) + size[dim + 1 :])
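# For illustration (hypothetical values): _replace_dim_with(torch.Size([2, 8, 128, 64]), 1, 24)
# returns torch.Size([2, 24, 128, 64]); only dimension 1 is swapped out.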


def _cudnn_sdpa_bwd_meta(
grad_out: TensorLike,
query: TensorLike,
key: TensorLike,
@@ -464,19 +480,53 @@ def cudnn_sdpa_backward_meta(
philox_offset: TensorLike,
*,
scale: None | float = None,
) -> (TensorProxy, TensorProxy, TensorProxy):
grad_query = TensorProxy(like=query)
grad_key = TensorProxy(like=key)
grad_value = TensorProxy(like=value)
cat_grad_qkv: bool,
) -> tuple[TensorProxy, ...]:
if cat_grad_qkv:
grad_qkv = TensorProxy(
like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1))
)
grads = (grad_qkv,)
else:
grad_query = TensorProxy(like=query)
grad_key = TensorProxy(like=key)
grad_value = TensorProxy(like=value)
grads = (grad_query, grad_key, grad_value)

if attn_mask is not None:
grad_attn_mask = TensorProxy(like=attn_mask, shape=attn_mask.shape)
return (grad_query, grad_key, grad_value, grad_attn_mask)
else:
return (grad_query, grad_key, grad_value)
grad_attn_mask = TensorProxy(like=attn_mask)
grads = grads + (grad_attn_mask,)

return grads


def _same_size_except(*args, except_dim: int) -> bool:
shapes = [_replace_dim_with(shape, except_dim, 0) for shape in args]
return all(shape == shapes[0] for shape in shapes)
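# For example (hypothetical shapes): _same_size_except((2, 8, 128, 64), (2, 2, 128, 64),
# (2, 2, 128, 64), except_dim=1) is True because the shapes agree everywhere except dim 1.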


# Allocates an empty tensor that will hold dQ, dK, and dV, concatenated.
# `query`, `key` and `value` merely provide necessary metadata such as sizes
# and dtypes. They don't have to be passed in as `torch.Tensor`s.
def _allocate_catted_grad_qkv(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> torch.Tensor:
assert _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
assert query.dtype == key.dtype == value.dtype
assert query.device == key.device == value.device

b, s, d = query.size(0), query.size(2), query.size(3)
h_q, h_k, h_v = query.size(1), key.size(1), value.size(1)
h_qkv = h_q + h_k + h_v

# Create grad_qkv as a tensor of size [b,h_qkv,s,d] and allocation order
# [0,2,1,3] from major to minor.
return torch.empty(b, s, h_qkv, d, dtype=query.dtype, device=query.device).permute(0, 2, 1, 3)
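# A minimal sketch of the resulting layout (hypothetical sizes): with b=2, s=128, d=64 and
# h_q = h_k = h_v = 8 (so h_qkv = 24), the returned view has shape [2, 24, 128, 64] and
# strides (196608, 64, 1536, 1). Splitting it along dim 1 therefore yields dQ, dK, and dV
# views that all share one contiguous [b, s, h_qkv, d] buffer.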

def cudnn_sdpa_bwd_impl(

def _cudnn_sdpa_bwd_impl(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
@@ -490,8 +540,23 @@ def cudnn_sdpa_bwd_impl(
philox_offset: torch.Tensor,
*,
scale: None | float = None,
) -> (torch.Tensor, torch.Tensor, torch.Tensor):
cat_grad_qkv: bool,
) -> tuple[torch.Tensor, ...]:
query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
query = _sdpa_enforce_input_tensor_contiguity(query)
key = _sdpa_enforce_input_tensor_contiguity(key)
value = _sdpa_enforce_input_tensor_contiguity(value)

# When cat_grad_qkv is on, allocate dQKV and make dQ, dK, and dV
# slices of that. Otherwise, allocate them individually.
grad_qkv: None | torch.Tensor = None
if cat_grad_qkv:
grad_qkv = _allocate_catted_grad_qkv(query, key, value)
grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
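# Note: `split` returns views into grad_qkv, so the cuDNN graph below writes dQ,
# dK, and dV directly into the shared buffer; no copy or downstream cat is needed.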
else:
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)

(
Q,
@@ -516,16 +581,11 @@ def cudnn_sdpa_bwd_impl(
attn_mask_4d,
dropout_p,
is_causal,
grad_query.stride(),
grad_key.stride(),
grad_value.stride(),
)

query = _sdpa_enforce_input_tensor_contiguity(query)
key = _sdpa_enforce_input_tensor_contiguity(key)
value = _sdpa_enforce_input_tensor_contiguity(value)

grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)

# Default value of scale, if not provided, in all torch versions
if scale is None:
scale = query.shape[-1] ** -0.5
@@ -560,21 +620,25 @@ def cudnn_sdpa_bwd_impl(

graph.execute(cudnn_to_torch_tensor, workspace)

if attn_mask is None:
return grad_query, grad_key, grad_value
if cat_grad_qkv:
grads = (grad_qkv,)
else:
return grad_query, grad_key, grad_value, grad_attn_mask
grads = (grad_query, grad_key, grad_value)

if attn_mask is not None:
grads = grads + (grad_attn_mask,)
return grads


cudnn_sdpa_bwd = cudnn_ex.register_operator(
"cudnn_sdpa_bwd",
meta=cudnn_sdpa_backward_meta,
fn=cudnn_sdpa_bwd_impl,
meta=_cudnn_sdpa_bwd_meta,
fn=_cudnn_sdpa_bwd_impl,
)


@langctx("torch")
def _cudnn_sdpa_transform(
def _cudnn_sdpa_fwd_wrapper(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
@@ -590,7 +654,7 @@ def _cudnn_sdpa_transform(


@langctx("torch")
def _cudnn_sdpa_grad(
def _cudnn_sdpa_bwd_wrapper(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
@@ -604,9 +668,20 @@ def _cudnn_sdpa_grad(
query, key, value, attn_mask, dropout_p, is_causal, scale=scale
)

g = get_grad(primal)
description = """\
This flag enables nvFuser's zipping optimization, which seeks to avoid
expensive concatenation. https://github.com/NVIDIA/Fuser/issues/1768 has more
details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV as one
tensor and return them as slices of that tensor.
"""
may_cat_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description)
if may_cat_grad_qkv is None:
may_cat_grad_qkv = False
assert isinstance(may_cat_grad_qkv, bool)
cat_grad_qkv = may_cat_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
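# Note: _same_size_except above requires Q, K, and V to match in batch size, sequence
# length, and head size; only the number of heads (dim 1) may differ, as in grouped-query
# attention. Inputs with different K/V sequence lengths therefore keep separate dQ/dK/dV
# allocations even when the flag is set.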

grads = cudnn_sdpa_bwd(
g,
get_grad(primal),
query,
key,
value,
Expand All @@ -618,23 +693,30 @@ def _cudnn_sdpa_grad(
seed,
offset,
scale=scale,
cat_grad_qkv=cat_grad_qkv,
)
if attn_mask is None:
grad_query, grad_key, grad_val = grads
else:
grad_query, grad_key, grad_val, grad_attn_mask = grads

put_grads((query, key, value), (grad_query, grad_key, grad_val))
if attn_mask is not None:
grad_attn_mask = grads[-1]
grads = grads[:-1]
put_grad(attn_mask, grad_attn_mask)

if cat_grad_qkv:
# The `split` is done outside `cudnn_sdpa_bwd` so it can be picked up
# by nvfuserex.
(grad_qkv,) = grads
grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
else:
grad_query, grad_key, grad_value = grads
put_grads((query, key, value), (grad_query, grad_key, grad_value))

return primal
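# Usage sketch. This assumes the compile option is passed as a keyword argument to
# thunder.compile, mirroring the updated test below; `fn` stands for any function that
# calls torch.nn.functional.scaled_dot_product_attention:
#
#   import thunder
#   from thunder.executors.cudnnex import cudnn_ex
#
#   cfn = thunder.compile(
#       fn,
#       executors_list=[cudnn_ex],
#       cudnn_sdpa_bwd_may_cat_grad_qkv=True,  # opt in to the catted dQKV allocation
#   )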


# Registers the implementation for torch.nn.functional.scaled_dot_product_attention
cudnn_ex.register_implementation(
ltorch.scaled_dot_product_attention,
checker=_cudnn_sdpa_checker,
execution_transform=_cudnn_sdpa_transform,
grad_transform=_cudnn_sdpa_grad,
execution_transform=_cudnn_sdpa_fwd_wrapper,
grad_transform=_cudnn_sdpa_bwd_wrapper,
)
2 changes: 1 addition & 1 deletion thunder/executors/nvfuserex_impl.py
@@ -787,7 +787,7 @@ def _can_fuse_node(n: Node):
region = Region(producers, consumers, bsyms)

# Acquires the nv_enable_bookend compile option, which defaults to True
bookend_help = """
bookend_help = """\
nvFuser's 'bookending' heuristic tries to gather metadata operations---such as
transpose, reshape, or view---into the beginnings and ends of blocks that utilize
nvFuser. By pushing these ops to the edges, they will get dropped by the nvFuser
31 changes: 14 additions & 17 deletions thunder/tests/test_cudnn_executor.py
@@ -43,6 +43,9 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)

# Same sequence length and embedding size for Q, K and V, a common use case.
yield SampleInput(make(N, n_head, L, E), make(N, n_head, L, E), make(N, n_head, L, E), None, is_causal=True)

# Non-contiguous input tensor case
nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
@@ -72,9 +75,8 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
sample_input_generator=None,
reference_input_generator=grad_scaled_dot_product_attention_reference_generator,
torch_reference=torch.nn.functional.scaled_dot_product_attention,
# RuntimeError: Only fp32, half & bf16 supported at the moment
# RuntimeError: Only half & bf16 supported at the moment
dtypes=(
thunder.dtypes.float32,
thunder.dtypes.float16,
thunder.dtypes.bfloat16,
),
@@ -186,15 +188,10 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
return result


# NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes
# RuntimeError: Only fp32, half & bf16 supported at the moment
@ops(
(grad_sdpa_cudnn_opinfo,),
supported_dtypes=(dtypes.float16, dtypes.bfloat16),
supported_devicetypes=(devices.DeviceType.CUDA,),
)
def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp):
for sample in op.reference_inputs(device, dtype, requires_grad=True):
@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
@pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True):
# Enforce tensor arguments are contiguous for torch reference
contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args))

@@ -209,25 +206,25 @@ def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp):
continue

# Compute vjp result using PyTorch
expect_out = op.torch_reference(*contiguous_args, **sample.kwargs)
expect_out = grad_sdpa_cudnn_opinfo.torch_reference(*contiguous_args, **sample.kwargs)
v = make_tensor_like(expect_out)
expected_grad = torch.autograd.grad(expect_out, grad_inputs, v)

# Compute vjp result using Thunder
flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs)
flat_op, flat_args, spec = flatten_func(grad_sdpa_cudnn_opinfo.op, sample.args, sample.kwargs)
filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)

cfoo = thunder.compile(
vjp(filtered_op),
disable_torch_autograd_support=True,
disable_preprocessing=True,
executors_list=executor.executors_list() + [cudnn_ex],
executors_list=[cudnn_ex],
cudnn_sdpa_bwd_may_cat_grad_qkv=may_cat_grad_qkv,
)

actual_out, actual_grad = cfoo(filtered_args, (v,))

comp(actual_out, expect_out, atol=1e-2, rtol=1e-2)

torch.testing.assert_close(actual_out, expect_out, atol=1e-2, rtol=1e-2)
# compare gradients of query, key, value, and attn_mask
for eg, ag in zip(expected_grad, actual_grad):
comp(eg, ag, atol=2e-1, rtol=2e-2)
torch.testing.assert_close(eg, ag, atol=2e-1, rtol=2e-2)