Handles non-contiguous input strides (#622)
vedaanta authored Jun 26, 2024
1 parent 575c0bd commit 754af86
Showing 4 changed files with 127 additions and 90 deletions.
10 changes: 9 additions & 1 deletion thunder/executors/cudnn_layernormex.py
@@ -21,7 +21,15 @@
import thunder.core.dtypes as dtypes
from thunder.core.proxies import TensorProxy

from thunder.executors.cudnnex import CudnnTensorAttributes, torch_to_cudnn_dtype
from thunder.executors.cudnnex import torch_to_cudnn_dtype


@dataclass(frozen=True)
class CudnnTensorAttributes:
size: tuple
stride: tuple
dtype: torch.dtype
device_index: int
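
For reference, a minimal sketch (not part of this commit) of how a runtime tensor could be lowered into this dataclass so it can serve as a hashable cache key for cudnn graph construction; the helper name is hypothetical and it reuses this file's torch import:

def to_cudnn_attributes(t: torch.Tensor) -> CudnnTensorAttributes:
    # Capture only the layout metadata (shape, strides, dtype, device) that graph building depends on.
    return CudnnTensorAttributes(tuple(t.shape), t.stride(), t.dtype, t.device.index)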


def make_cacheable_cudnn_graph_inputs(func):
193 changes: 104 additions & 89 deletions thunder/executors/cudnnex.py
@@ -94,14 +94,6 @@ def _get_cudnn_handle(query_device):
register_executor(cudnn_ex)


@dataclass(frozen=True)
class CudnnTensorAttributes:
size: tuple
stride: tuple
dtype: torch.dtype
device_index: int


from collections import OrderedDict


@@ -128,23 +120,26 @@ def __setitem__(self, key, value):
_cudnnex_cache = CudnnexLRUCache(maxlen=1024)


def _make_cudnn_sdpa_forward_graph(query, key, value, attn_mask, dropout_p, is_causal):
def _make_cudnn_sdpa_forward_graph(
query, key, value, attn_mask, dropout_p, is_causal, query_stride, key_stride, value_stride
):
graph = cudnn.pygraph(
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
handle=_get_cudnn_handle(query.device_index),
handle=_get_cudnn_handle(query.device.index),
)

Q = graph.tensor(name="Q", dim=query.size, stride=query.stride, data_type=torch_to_cudnn_dtype(query.dtype))
K = graph.tensor(name="K", dim=key.size, stride=key.stride, data_type=torch_to_cudnn_dtype(key.dtype))
V = graph.tensor(name="V", dim=value.size, stride=value.stride, data_type=torch_to_cudnn_dtype(value.dtype))
Q = graph.tensor(name="Q", dim=query.shape, stride=query_stride, data_type=torch_to_cudnn_dtype(query.dtype))
K = graph.tensor(name="K", dim=key.shape, stride=key_stride, data_type=torch_to_cudnn_dtype(key.dtype))
V = graph.tensor(name="V", dim=value.shape, stride=value_stride, data_type=torch_to_cudnn_dtype(value.dtype))
Bias = None
if attn_mask is not None:
attn_mask_stride = _compute_row_major_strides(attn_mask.shape)
Bias = graph.tensor(
name="bias", dim=attn_mask.size, stride=attn_mask.stride, data_type=torch_to_cudnn_dtype(attn_mask.dtype)
name="bias", dim=attn_mask.shape, stride=attn_mask_stride, data_type=torch_to_cudnn_dtype(attn_mask.dtype)
)

scalar_dim_stride = tuple([1] * len(query.size))
scalar_dim_stride = tuple([1] * len(query.shape))
dropout_tuple = None
Seed = None
Offset = None
@@ -178,8 +173,8 @@ def _make_cudnn_sdpa_forward_graph(query, key, value, attn_mask, dropout_p, is_c
)

# TODO: update to do tensor.stride_order when available from FE
b, h, s_q, _ = query.size
_, _, _, d_v = value.size
b, h, s_q, _ = query.shape
_, _, _, d_v = value.shape

dim_o = (b, h, s_q, d_v)
stride_o = (h * s_q * d_v, s_q * d_v, d_v, 1)
@@ -223,33 +218,11 @@ def torch_to_cudnn_dtype(lc_dtype: dtypes.dtype):
return _torch_to_cudnn_dtype_map[lc_dtype]


def _transform_sdpa_inputs(query, key, value, attn_mask):
def compute_NHWC_strides(shape):
strides = [1] * len(shape)
stride = 1
for i in reversed(range(len(shape))):
strides[i] = stride
stride *= shape[i]
return tuple(strides)

query_4d = CudnnTensorAttributes(query.shape, compute_NHWC_strides(query.shape), query.dtype, query.device.index)

key_4d = CudnnTensorAttributes(key.shape, compute_NHWC_strides(key.shape), key.dtype, key.device.index)

value_4d = CudnnTensorAttributes(value.shape, compute_NHWC_strides(value.shape), value.dtype, value.device.index)

attn_mask_4d = None
if attn_mask is not None:
# Make attn_mask to be of the same dimensionality as other input tensors
attn_mask_shape = (1,) * (query.ndim - attn_mask.ndim) + attn_mask.shape

# cudnn does not support boolean attn_mask, so make one with -inf
attn_mask_dtype = query.dtype if attn_mask.dtype in [torch.bool, dtypes.bool8] else attn_mask.dtype
attn_mask_4d = CudnnTensorAttributes(
attn_mask_shape, compute_NHWC_strides(attn_mask_shape), attn_mask_dtype, attn_mask.device.index
)

return query_4d, key_4d, value_4d, attn_mask_4d
def _compute_row_major_strides(shape):
strides = [1]
    for dim in reversed(shape[1:]):
strides.append(strides[-1] * dim)
return tuple(reversed(strides))
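
A quick illustration (not part of the diff) of what this helper assumes: on a (batch, heads, seq, head_dim) shape it reproduces the strides of a contiguous, row-major tensor.

# torch.empty(2, 3, 4, 5).stride() == (60, 20, 5, 1)
assert _compute_row_major_strides((2, 3, 4, 5)) == (60, 20, 5, 1)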


# sdpa requires that the embedding dim stride be one.
@@ -303,7 +276,18 @@ def _cudnn_sdpa_fwd_impl(
*,
scale: float | None = None,
) -> tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)

query = _sdpa_enforce_input_tensor_contiguity(query)
key = _sdpa_enforce_input_tensor_contiguity(key)
value = _sdpa_enforce_input_tensor_contiguity(value)

if attn_mask is not None:
        attn_mask = attn_mask.view((1,) * (query.ndim - attn_mask.ndim) + attn_mask.shape)
        # As cudnn does not support a boolean attn_mask, convert it to an additive mask with -inf
if attn_mask.dtype == torch.bool:
attn_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_bias

(
Q,
@@ -316,12 +300,14 @@
O,
softmax_stats,
graph,
) = _make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal)
) = _make_cudnn_sdpa_forward_graph(
query, key, value, attn_mask, dropout_p, is_causal, query.stride(), key.stride(), value.stride()
)

b, h, s_q, d_q = query.size()
b, h_q, s_q, d_q = query.size()
_, _, _, d_v = value.size()
O_actual = torch.empty(b, h, s_q, d_v, dtype=value.dtype, device=query.device)
softmax_stats_actual = torch.empty(b, h, s_q, 1, dtype=torch.float32, device=query.device)
O_actual = torch.empty(b, h_q, s_q, d_v, dtype=value.dtype, device=query.device)
softmax_stats_actual = torch.empty(b, h_q, s_q, 1, dtype=torch.float32, device=query.device)
workspace = torch.empty(graph.get_workspace_size(), device=query.device, dtype=torch.uint8)

seed_tensor = (
@@ -338,15 +324,10 @@
scale = 1 / d_q**0.5
Attn_scale_cpu = torch.full((1, 1, 1, 1), scale, dtype=torch.float32, device="cpu")

if attn_mask is not None and attn_mask.dtype == torch.bool:
attn_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_bias

cudnn_to_torch_tensor = {
Q: _sdpa_enforce_input_tensor_contiguity(query).detach(),
K: _sdpa_enforce_input_tensor_contiguity(key).detach(),
V: _sdpa_enforce_input_tensor_contiguity(value).detach(),
Q: query.detach(),
K: key.detach(),
V: value.detach(),
Attn_scale: Attn_scale_cpu,
Seed: seed_tensor,
Offset: offset_tensor,
@@ -395,19 +376,37 @@ def _cudnn_sdpa_checker(
return False

try:
        # TensorProxy does not contain stride information, but the cudnn graph requires it.
# Assume row major layout for now. If the strides during execution are different, a new graph will be built.
query_stride = _compute_row_major_strides(query.size())
key_stride = _compute_row_major_strides(key.size())
value_stride = _compute_row_major_strides(value.size())

if attn_mask is not None:
            # Make attn_mask the same dimensionality as the other input tensors
attn_mask_shape = (1,) * (query.ndim - attn_mask.ndim) + attn_mask.shape
            # cudnn does not support a boolean attn_mask, so make it an additive mask instead.
            # During execution a similar change is made to the attn_mask buffer, where every False value is replaced with -inf.
attn_mask_dtype = query.dtype if attn_mask.dtype in [torch.bool, dtypes.bool8] else attn_mask.dtype
attn_mask = TensorProxy(like=attn_mask, shape=attn_mask_shape, dtype=attn_mask_dtype)

# Build both forward and backward graphs
query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
_make_cudnn_sdpa_forward_graph(query_4d, key_4d, value_4d, attn_mask_4d, dropout_p, is_causal)
_make_cudnn_sdpa_forward_graph(
query, key, value, attn_mask, dropout_p, is_causal, query_stride, key_stride, value_stride
)
_make_cudnn_sdpa_backward_graph(
query_4d,
key_4d,
value_4d,
attn_mask_4d,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
query_4d.stride,
key_4d.stride,
value_4d.stride,
query_stride,
key_stride,
value_stride,
query_stride,
key_stride,
value_stride, # Use the same strides as inputs for their respective grads
)
# If cudnn can't support the graph, return false
# Please turn on cudnn API logging for helpful messages that mention why the graph is not supported.
@@ -431,21 +430,32 @@


def _make_cudnn_sdpa_backward_graph(
query, key, value, attn_mask, dropout_p, is_causal, grad_query_stride, grad_key_stride, grad_value_stride
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
query_stride,
key_stride,
value_stride,
grad_query_stride,
grad_key_stride,
grad_value_stride,
):
b, h, s_q, _ = query.size
_, _, _, d_v = value.size
b, h, s_q, _ = query.shape
_, _, _, d_v = value.shape

graph = cudnn.pygraph(
io_data_type=torch_to_cudnn_dtype(query.dtype),
intermediate_data_type=cudnn.data_type.FLOAT,
compute_data_type=cudnn.data_type.FLOAT,
handle=_get_cudnn_handle(query.device_index),
handle=_get_cudnn_handle(query.device.index),
)

Q = graph.tensor(name="Q", dim=query.size, stride=query.stride, data_type=torch_to_cudnn_dtype(query.dtype))
K = graph.tensor(name="K", dim=key.size, stride=key.stride, data_type=torch_to_cudnn_dtype(key.dtype))
V = graph.tensor(name="V", dim=value.size, stride=value.stride, data_type=torch_to_cudnn_dtype(value.dtype))
Q = graph.tensor(name="Q", dim=query.shape, stride=query_stride, data_type=torch_to_cudnn_dtype(query.dtype))
K = graph.tensor(name="K", dim=key.shape, stride=key_stride, data_type=torch_to_cudnn_dtype(key.dtype))
V = graph.tensor(name="V", dim=value.shape, stride=value_stride, data_type=torch_to_cudnn_dtype(value.dtype))

dim_o = (b, h, s_q, d_v)
stride_o = (h * s_q * d_v, s_q * d_v, d_v, 1)
@@ -459,12 +469,13 @@ def _make_cudnn_sdpa_backward_graph(
Bias = None
dBias = None
if attn_mask is not None:
attn_mask_stride = _compute_row_major_strides(attn_mask.shape)
Bias = graph.tensor(
name="bias", dim=attn_mask.size, stride=attn_mask.stride, data_type=torch_to_cudnn_dtype(attn_mask.dtype)
name="bias", dim=attn_mask.shape, stride=attn_mask_stride, data_type=torch_to_cudnn_dtype(attn_mask.dtype)
)
dBias = graph.tensor_like(Bias)

scalar_dim_stride = tuple([1] * len(query.size))
scalar_dim_stride = tuple([1] * len(query.shape))
dropout_tuple = None
Seed = None
Offset = None
Expand Down Expand Up @@ -499,11 +510,11 @@ def _make_cudnn_sdpa_backward_graph(
dropout=dropout_tuple,
)

dQ.set_output(True).set_dim(query.size).set_stride(grad_query_stride).set_data_type(
dQ.set_output(True).set_dim(query.shape).set_stride(grad_query_stride).set_data_type(
torch_to_cudnn_dtype(query.dtype)
)
dK.set_output(True).set_dim(key.size).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
dV.set_output(True).set_dim(value.size).set_stride(grad_value_stride).set_data_type(
dK.set_output(True).set_dim(key.shape).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
dV.set_output(True).set_dim(value.shape).set_stride(grad_value_stride).set_data_type(
torch_to_cudnn_dtype(value.dtype)
)

@@ -612,7 +623,6 @@ def _cudnn_sdpa_bwd_impl(
scale: None | float = None,
cat_grad_qkv: bool,
) -> tuple[torch.Tensor, ...]:
query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
query = _sdpa_enforce_input_tensor_contiguity(query)
key = _sdpa_enforce_input_tensor_contiguity(key)
value = _sdpa_enforce_input_tensor_contiguity(value)
@@ -628,6 +638,13 @@
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)

if attn_mask is not None:
        attn_mask = attn_mask.view((1,) * (query.ndim - attn_mask.ndim) + attn_mask.shape)
if attn_mask.dtype == torch.bool:
attn_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_bias

(
Q,
K,
@@ -645,12 +662,15 @@
dBias,
graph,
) = _make_cudnn_sdpa_backward_graph(
query_4d,
key_4d,
value_4d,
attn_mask_4d,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
query.stride(),
key.stride(),
value.stride(),
grad_query.stride(),
grad_key.stride(),
grad_value.stride(),
@@ -676,12 +696,7 @@
dV: grad_value,
}
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_bias = torch.zeros_like(attn_mask, dtype=query.dtype)
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
attn_mask = attn_bias

grad_attn_mask = torch.empty_like(attn_mask) if attn_mask is not None else None
grad_attn_mask = torch.empty_like(attn_mask)

cudnn_to_torch_tensor[Bias] = attn_mask.detach()
cudnn_to_torch_tensor[dBias] = grad_attn_mask
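
The boolean-mask handling used in both the forward and backward implementations above can be shown in isolation; a minimal sketch (not part of the diff), assuming only torch:

import torch

# cudnn consumes an additive bias rather than a boolean mask, so False entries become -inf
# and True entries become 0 before the bias is added to the attention scores.
bool_mask = torch.tensor([[True, False], [True, True]])
attn_bias = torch.zeros_like(bool_mask, dtype=torch.float32)
attn_bias.masked_fill_(bool_mask.logical_not(), float("-inf"))
# attn_bias is now [[0., -inf], [0., 0.]]
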
7 changes: 7 additions & 0 deletions thunder/tests/opinfos.py
@@ -7672,6 +7672,13 @@ def scaled_dot_product_attention_reference_generator(op, device, dtype, requires
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
yield SampleInput(q, k, v, None, 0.0, True)

# non-contiguous with stride 0 cases
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
q_broadcast = torch.as_strided(q, size=q.shape, stride=(0, 0, E, 1))
k_broadcast = torch.as_strided(k, size=k.shape, stride=(0, 0, E, 1))
v_broadcast = torch.as_strided(v, size=v.shape, stride=(0, 0, Ev, 1))
yield SampleInput(q_broadcast, k_broadcast, v_broadcast, None, 0.0, True)
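
These samples lean on torch.as_strided with zero strides in the leading dimensions, which makes every batch/head index alias the same memory and yields a non-contiguous input; a minimal standalone sketch (not part of the diff):

x = torch.randn(2, 4, 8, 16)
x_broadcast = torch.as_strided(x, size=x.shape, stride=(0, 0, 16, 1))
assert not x_broadcast.is_contiguous()
assert torch.equal(x_broadcast[0, 0], x_broadcast[1, 3])  # every (batch, head) slice reads the same storage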


def scaled_dot_product_attention_sample_generator(op, device, dtype, requires_grad, **kwargs):
"""https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html"""
7 changes: 7 additions & 0 deletions thunder/tests/test_cudnn_executor.py
@@ -76,6 +76,13 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
bool_attn_mask = make((1, n_head, L, S), dtype=torch.bool, low=1, high=1, requires_grad=False).tril()
yield SampleInput(q, k, v, bool_attn_mask, is_causal=False)

# non-contiguous with stride 0 cases
q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
q_broadcast = torch.as_strided(q, size=q.shape, stride=(0, 0, E, 1))
k_broadcast = torch.as_strided(k, size=k.shape, stride=(0, 0, E, 1))
v_broadcast = torch.as_strided(v, size=v.shape, stride=(0, 0, Ev, 1))
yield SampleInput(q_broadcast, k_broadcast, v_broadcast, None, dropout_p=0.0, is_causal=True)


grad_sdpa_cudnn_opinfo = OpInfo(
thunder.torch.scaled_dot_product_attention,