diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py
index a2d61bb626..9a9e72f3d3 100644
--- a/thunder/executors/cudnnex.py
+++ b/thunder/executors/cudnnex.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any
 
 import torch
 import numpy as np
@@ -28,6 +28,7 @@ def cudnn_available() -> bool:
 
 from thunder.core.langctxs import langctx
 import thunder.core.dtypes as dtypes
 from thunder.torch import TensorLike
+from thunder.core.compile_data import get_compile_option
 from thunder.core.proxies import Proxy, TensorProxy
 
@@ -207,6 +208,11 @@ def compute_NHWC_strides(shape):
 # And when registering for sdpa, cudnn assumes NHWC layout. (See _transform_sdpa_inputs())
 def _sdpa_enforce_input_tensor_contiguity(a: torch.Tensor) -> torch.Tensor:
     if a.stride(-1) == 1:
+        # TODO(vedaanta-nvidia): there's an inconsistency between
+        # _transform_sdpa_inputs and this function, leading to a potential bug.
+        # _transform_sdpa_inputs always creates contiguous strides, but code
+        # here creates a partially contiguous stride when the last dimension is
+        # contiguous but other dimensions are not.
         return a
     else:
         return a.contiguous()
@@ -344,7 +350,9 @@ def _cudnn_sdpa_checker(
 )
 
 
-def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_causal):
+def _make_cudnn_sdpa_backward_graph(
+    query, key, value, attn_mask, dropout_p, is_causal, grad_query_stride, grad_key_stride, grad_value_stride
+):
     b, h, s_q, _ = query.size
     _, _, _, d_v = value.size
 
@@ -414,9 +422,13 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_
         dropout=dropout_tuple,
     )
 
-    dQ.set_output(True).set_dim(query.size).set_stride(query.stride).set_data_type(torch_to_cudnn_dtype(query.dtype))
-    dK.set_output(True).set_dim(key.size).set_stride(key.stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
-    dV.set_output(True).set_dim(value.size).set_stride(value.stride).set_data_type(torch_to_cudnn_dtype(value.dtype))
+    dQ.set_output(True).set_dim(query.size).set_stride(grad_query_stride).set_data_type(
+        torch_to_cudnn_dtype(query.dtype)
+    )
+    dK.set_output(True).set_dim(key.size).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
+    dV.set_output(True).set_dim(value.size).set_stride(grad_value_stride).set_data_type(
+        torch_to_cudnn_dtype(value.dtype)
+    )
 
     # Validate the graph before querying the cache key
     # Validation makes sure all missing properties are inferred and filled, as they affect cache key.
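Why the backward graph now takes grad strides as parameters: with cat_grad_qkv
enabled (introduced below), dQ, dK and dV are slices of one catted buffer, so
their strides are not those of freshly allocated contiguous tensors. A minimal
sketch of the effect, assuming illustrative sizes b=2, h_q=h_k=h_v=8, s=128,
d=64:

    import torch

    b, s, d = 2, 128, 64
    h_q = h_k = h_v = 8
    # Mirrors _allocate_catted_grad_qkv: allocate [b,s,h_qkv,d], view as [b,h_qkv,s,d].
    grad_qkv = torch.empty(b, s, h_q + h_k + h_v, d).permute(0, 2, 1, 3)
    grad_query, _, _ = grad_qkv.split([h_q, h_k, h_v], dim=1)
    assert grad_query.stride() == (196608, 64, 1536, 1)  # a slice of the catted buffer
    assert torch.empty(b, h_q, s, d).stride() == (65536, 8192, 64, 1)  # the old assumption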
@@ -450,7 +462,11 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_
     return _cudnnex_cache[cache_key]
 
 
-def cudnn_sdpa_backward_meta(
+def _replace_dim_with(size: torch.Size, dim: int, dim_size: int) -> torch.Size:
+    return torch.Size(size[:dim] + (dim_size,) + size[dim + 1 :])
+
+
+def _cudnn_sdpa_bwd_meta(
     grad_out: TensorLike,
     query: TensorLike,
     key: TensorLike,
@@ -464,19 +480,53 @@ def cudnn_sdpa_backward_meta(
     philox_offset: TensorLike,
     *,
     scale: None | float = None,
-) -> (TensorProxy, TensorProxy, TensorProxy):
-    grad_query = TensorProxy(like=query)
-    grad_key = TensorProxy(like=key)
-    grad_value = TensorProxy(like=value)
+    cat_grad_qkv: bool,
+) -> tuple[TensorProxy, ...]:
+    if cat_grad_qkv:
+        grad_qkv = TensorProxy(
+            like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1))
+        )
+        grads = (grad_qkv,)
+    else:
+        grad_query = TensorProxy(like=query)
+        grad_key = TensorProxy(like=key)
+        grad_value = TensorProxy(like=value)
+        grads = (grad_query, grad_key, grad_value)
 
     if attn_mask is not None:
-        grad_attn_mask = TensorProxy(like=attn_mask, shape=attn_mask.shape)
-        return (grad_query, grad_key, grad_value, grad_attn_mask)
-    else:
-        return (grad_query, grad_key, grad_value)
+        grad_attn_mask = TensorProxy(like=attn_mask)
+        grads = grads + (grad_attn_mask,)
+
+    return grads
 
 
+def _same_size_except(*args, except_dim: int) -> bool:
+    shapes = [_replace_dim_with(shape, except_dim, 0) for shape in args]
+    return all(shape == shapes[0] for shape in shapes)
+
+
+# Allocates an empty tensor that will hold dQ, dK, and dV, concatenated.
+# `query`, `key` and `value` merely provide necessary metadata such as sizes
+# and dtypes. They don't have to be passed in as `torch.Tensor`s.
+def _allocate_catted_grad_qkv(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+) -> torch.Tensor:
+    assert _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
+    assert query.dtype == key.dtype == value.dtype
+    assert query.device == key.device == value.device
+
+    b, s, d = query.size(0), query.size(2), query.size(3)
+    h_q, h_k, h_v = query.size(1), key.size(1), value.size(1)
+    h_qkv = h_q + h_k + h_v
+    # Create grad_qkv as a tensor of size [b,h_qkv,s,d] and allocation order
+    # [0,2,1,3] from major to minor.
+    return torch.empty(b, s, h_qkv, d, dtype=query.dtype, device=query.device).permute(0, 2, 1, 3)
 
-def cudnn_sdpa_bwd_impl(
+
+def _cudnn_sdpa_bwd_impl(
     grad_out: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -490,8 +540,23 @@ def cudnn_sdpa_bwd_impl(
     philox_offset: torch.Tensor,
     *,
     scale: None | float = None,
-) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+    cat_grad_qkv: bool,
+) -> tuple[torch.Tensor, ...]:
     query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask)
+    query = _sdpa_enforce_input_tensor_contiguity(query)
+    key = _sdpa_enforce_input_tensor_contiguity(key)
+    value = _sdpa_enforce_input_tensor_contiguity(value)
+
+    # When cat_grad_qkv is on, allocate dQKV and make dQ, dK, and dV
+    # slices of that. Otherwise, allocate them individually.
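+    # (For example, with hypothetical GQA sizes h_q=32 and h_k=h_v=8, grad_qkv
+    # would hold 48 heads, and the split below would view heads [0,32),
+    # [32,40) and [40,48) as dQ, dK and dV.)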
+    grad_qkv: None | torch.Tensor = None
+    if cat_grad_qkv:
+        grad_qkv = _allocate_catted_grad_qkv(query, key, value)
+        grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
+    else:
+        grad_query = torch.empty_like(query)
+        grad_key = torch.empty_like(key)
+        grad_value = torch.empty_like(value)
 
     (
         Q,
@@ -516,16 +581,11 @@ def cudnn_sdpa_bwd_impl(
         attn_mask_4d,
         dropout_p,
         is_causal,
+        grad_query.stride(),
+        grad_key.stride(),
+        grad_value.stride(),
     )
 
-    query = _sdpa_enforce_input_tensor_contiguity(query)
-    key = _sdpa_enforce_input_tensor_contiguity(key)
-    value = _sdpa_enforce_input_tensor_contiguity(value)
-
-    grad_query = torch.empty_like(query)
-    grad_key = torch.empty_like(key)
-    grad_value = torch.empty_like(value)
-
     # Default value of scale, if not provided, in all torch versions
     if scale is None:
         scale = query.shape[-1] ** -0.5
@@ -560,21 +620,25 @@ def cudnn_sdpa_bwd_impl(
 
     graph.execute(cudnn_to_torch_tensor, workspace)
 
-    if attn_mask is None:
-        return grad_query, grad_key, grad_value
+    if cat_grad_qkv:
+        grads = (grad_qkv,)
     else:
-        return grad_query, grad_key, grad_value, grad_attn_mask
+        grads = (grad_query, grad_key, grad_value)
+
+    if attn_mask is not None:
+        grads = grads + (grad_attn_mask,)
+    return grads
 
 
 cudnn_sdpa_bwd = cudnn_ex.register_operator(
     "cudnn_sdpa_bwd",
-    meta=cudnn_sdpa_backward_meta,
-    fn=cudnn_sdpa_bwd_impl,
+    meta=_cudnn_sdpa_bwd_meta,
+    fn=_cudnn_sdpa_bwd_impl,
 )
 
 
 @langctx("torch")
-def _cudnn_sdpa_transform(
+def _cudnn_sdpa_fwd_wrapper(
     query: TensorProxy,
     key: TensorProxy,
     value: TensorProxy,
@@ -590,7 +654,7 @@ def _cudnn_sdpa_transform(
 
 
 @langctx("torch")
-def _cudnn_sdpa_grad(
+def _cudnn_sdpa_bwd_wrapper(
     query: TensorProxy,
     key: TensorProxy,
     value: TensorProxy,
@@ -604,9 +668,20 @@ def _cudnn_sdpa_grad(
         query, key, value, attn_mask, dropout_p, is_causal, scale=scale
     )
 
-    g = get_grad(primal)
+    description = """\
+This flag is for enabling nvFuser's zipping optimization that seeks to avoid
+expensive concatenation. https://github.com/NVIDIA/Fuser/issues/1768 has more
+details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV as one
+tensor and return them as slices of that tensor.
+"""
+    may_cat_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description)
+    if may_cat_grad_qkv is None:
+        may_cat_grad_qkv = False
+    assert isinstance(may_cat_grad_qkv, bool)
+    cat_grad_qkv = may_cat_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
+
     grads = cudnn_sdpa_bwd(
-        g,
+        get_grad(primal),
         query,
         key,
         value,
@@ -618,16 +693,23 @@ def _cudnn_sdpa_grad(
         seed,
         offset,
         scale=scale,
+        cat_grad_qkv=cat_grad_qkv,
     )
 
-    if attn_mask is None:
-        grad_query, grad_key, grad_val = grads
-    else:
-        grad_query, grad_key, grad_val, grad_attn_mask = grads
-
-    put_grads((query, key, value), (grad_query, grad_key, grad_val))
     if attn_mask is not None:
+        grad_attn_mask = grads[-1]
+        grads = grads[:-1]
         put_grad(attn_mask, grad_attn_mask)
 
+    if cat_grad_qkv:
+        # The `split` is done outside `cudnn_sdpa_bwd` so it can be picked up
+        # by nvfuserex.
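+        # (`split` only creates views into `grad_qkv`, so consumers see slices
+        # of one catted tensor instead of a materialized concatenation.)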
+        (grad_qkv,) = grads
+        grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1)
+    else:
+        grad_query, grad_key, grad_value = grads
+    put_grads((query, key, value), (grad_query, grad_key, grad_value))
+
     return primal
 
 
@@ -635,6 +717,6 @@ def _cudnn_sdpa_grad(
 cudnn_ex.register_implementation(
     ltorch.scaled_dot_product_attention,
     checker=_cudnn_sdpa_checker,
-    execution_transform=_cudnn_sdpa_transform,
-    grad_transform=_cudnn_sdpa_grad,
+    execution_transform=_cudnn_sdpa_fwd_wrapper,
+    grad_transform=_cudnn_sdpa_bwd_wrapper,
 )
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 7a5ebfe5ae..38de3c122a 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -787,7 +787,7 @@ def _can_fuse_node(n: Node):
         region = Region(producers, consumers, bsyms)
 
         # Acquires the nv_enable_bookend compile option, which defaults to True
-        bookend_help = """
+        bookend_help = """\
 nvFuser's 'bookending' heuristic tries to gather metadata operations---such as
 transpose, reshape, or view---into the beginning and ends of blocks that utilize
 nvFuser. By pushing these ops to the edges, they will get dropped by the nvFuser
diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py
index ab985aecbf..37e2f6606e 100644
--- a/thunder/tests/test_cudnn_executor.py
+++ b/thunder/tests/test_cudnn_executor.py
@@ -43,6 +43,9 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
     q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev)
     yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True)
 
+    # Same sequence length and embedding size for Q, K and V, a common use case.
+    yield SampleInput(make(N, n_head, L, E), make(N, n_head, L, E), make(N, n_head, L, E), None, is_causal=True)
+
     # Non-contiguous input tensor case
     nq = make(N, n_head, E, L).permute(0, 1, 3, 2)
     nk = make(N, n_head, E, S).permute(0, 1, 3, 2)
@@ -72,9 +75,8 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req
     sample_input_generator=None,
     reference_input_generator=grad_scaled_dot_product_attention_reference_generator,
     torch_reference=torch.nn.functional.scaled_dot_product_attention,
-    # RuntimeError: Only fp32, half & bf16 supported at the moment
+    # RuntimeError: Only half & bf16 supported at the moment
     dtypes=(
-        thunder.dtypes.float32,
         thunder.dtypes.float16,
         thunder.dtypes.bfloat16,
     ),
@@ -186,15 +188,10 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_):
     return result
 
 
-# NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes
-# RuntimeError: Only fp32, half & bf16 supported at the moment
-@ops(
-    (grad_sdpa_cudnn_opinfo,),
-    supported_dtypes=(dtypes.float16, dtypes.bfloat16),
-    supported_devicetypes=(devices.DeviceType.CUDA,),
-)
-def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp):
-    for sample in op.reference_inputs(device, dtype, requires_grad=True):
+@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv"))
+@pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes())))
+def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv):
+    for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True):
         # Enforce tensor arguments are contiguous for torch reference
         contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args))
@@ -209,25 +206,25 @@ def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp):
             continue
 
         # Compute vjp result using PyTorch
-        expect_out = op.torch_reference(*contiguous_args, **sample.kwargs)
+        expect_out = grad_sdpa_cudnn_opinfo.torch_reference(*contiguous_args, **sample.kwargs)
         v = make_tensor_like(expect_out)
         expected_grad = torch.autograd.grad(expect_out, grad_inputs, v)
 
         # Compute vjp result using Thunder
-        flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs)
+        flat_op, flat_args, spec = flatten_func(grad_sdpa_cudnn_opinfo.op, sample.args, sample.kwargs)
         filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
 
         cfoo = thunder.compile(
             vjp(filtered_op),
             disable_torch_autograd_support=True,
             disable_preprocessing=True,
-            executors_list=executor.executors_list() + [cudnn_ex],
+            executors_list=[cudnn_ex],
+            cudnn_sdpa_bwd_may_cat_grad_qkv=may_cat_grad_qkv,
         )
 
         actual_out, actual_grad = cfoo(filtered_args, (v,))
 
-        comp(actual_out, expect_out, atol=1e-2, rtol=1e-2)
-
+        torch.testing.assert_close(actual_out, expect_out, atol=1e-2, rtol=1e-2)
         # compare gradients of query, key, value, and attn_mask
         for eg, ag in zip(expected_grad, actual_grad):
-            comp(eg, ag, atol=2e-1, rtol=2e-2)
+            torch.testing.assert_close(eg, ag, atol=2e-1, rtol=2e-2)
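For reference, a minimal sketch of how the new knob threads through
compilation, mirroring the test above (the compiled function is illustrative;
the option defaults to False when unset, since the wrapper treats None as
False):

    import torch
    import thunder
    from thunder.executors.cudnnex import cudnn_ex

    def fn(q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

    cfn = thunder.compile(
        fn,
        executors_list=[cudnn_ex],
        # Read via get_compile_option in _cudnn_sdpa_bwd_wrapper.
        cudnn_sdpa_bwd_may_cat_grad_qkv=True,
    )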