From bfccc976cabe5979cc6bf6f1403d70bb27ad5cc9 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 23 Mar 2024 23:35:14 +0000 Subject: [PATCH 01/13] A follow-up cleanup on #57. --- thunder/executors/cudnnex.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 9c53088c9d..a2d61bb626 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -334,25 +334,7 @@ def _cudnn_sdpa_checker( if d % 8 != 0 or d > 128: return False - is_backward_supported = _cudnn_sdpa_backward_checker( - query, key, value, attn_mask, dropout_p, is_causal, scale=scale - ) - - return True and is_backward_supported - - -@langctx("torch") -def _cudnn_sdpa_backward_checker( - query: TensorLike, - key: TensorLike, - value: TensorLike, - attn_mask: TensorLike | None = None, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: float | None = None, -) -> bool: - return cudnn is not None + return True cudnn_sdpa_fwd = cudnn_ex.register_operator( From 04c24fdc512c3fdacef29cdae2bec1593a429aa5 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 20 Mar 2024 21:40:42 +0000 Subject: [PATCH 02/13] Preallocate dQ, dK, and dV as one tensor for efficient concatenation. --- thunder/executors/cudnnex.py | 147 ++++++++++++++++++++++++++++------- 1 file changed, 118 insertions(+), 29 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index a2d61bb626..3bd3c125a8 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -1,5 +1,6 @@ -from typing import Any, Optional +from typing import Any +import functools import torch import numpy as np import random @@ -344,7 +345,9 @@ def _cudnn_sdpa_checker( ) -def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_causal): +def _make_cudnn_sdpa_backward_graph( + query, key, value, attn_mask, dropout_p, is_causal, grad_qkv_stride: None | tuple[int, ...] +): b, h, s_q, _ = query.size _, _, _, d_v = value.size @@ -414,9 +417,15 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_ dropout=dropout_tuple, ) - dQ.set_output(True).set_dim(query.size).set_stride(query.stride).set_data_type(torch_to_cudnn_dtype(query.dtype)) - dK.set_output(True).set_dim(key.size).set_stride(key.stride).set_data_type(torch_to_cudnn_dtype(key.dtype)) - dV.set_output(True).set_dim(value.size).set_stride(value.stride).set_data_type(torch_to_cudnn_dtype(value.dtype)) + dQ.set_output(True).set_dim(query.size).set_stride(grad_qkv_stride or query.stride).set_data_type( + torch_to_cudnn_dtype(query.dtype) + ) + dK.set_output(True).set_dim(key.size).set_stride(grad_qkv_stride or key.stride).set_data_type( + torch_to_cudnn_dtype(key.dtype) + ) + dV.set_output(True).set_dim(value.size).set_stride(grad_qkv_stride or value.stride).set_data_type( + torch_to_cudnn_dtype(value.dtype) + ) # Validate the graph before querying the cache key # Validation makes sure all missing properties are inferred and filled, as they affect cache key. 
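When dQ, dK and dV are slices of one preallocated tensor, their strides are
those of the bigger tensor rather than the contiguous strides of query, key
and value, which is why the hunk above overrides the output strides with
grad_qkv_stride when it is provided. A small self-contained illustration of
that layout, with made-up sizes:

    import torch

    b, s, d, h_q, h_k, h_v = 2, 16, 8, 4, 4, 4
    # One buffer with logical shape [b, h_qkv, s, d] and allocation order
    # [b, s, h_qkv, d], as _preallocate_grad_qkv below constructs it.
    grad_qkv = torch.empty(b, s, h_q + h_k + h_v, d).permute(0, 2, 1, 3)
    dq, dk, dv = grad_qkv.split([h_q, h_k, h_v], dim=1)
    assert dq.stride() == grad_qkv.stride()  # slices keep the buffer's strides
    assert dq.stride() != torch.empty(b, h_q, s, d).stride()  # not contiguous
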
@@ -450,7 +459,7 @@ def _make_cudnn_sdpa_backward_graph(query, key, value, attn_mask, dropout_p, is_ return _cudnnex_cache[cache_key] -def cudnn_sdpa_backward_meta( +def cudnn_sdpa_bwd_meta( grad_out: TensorLike, query: TensorLike, key: TensorLike, @@ -470,12 +479,38 @@ def cudnn_sdpa_backward_meta( grad_value = TensorProxy(like=value) if attn_mask is not None: - grad_attn_mask = TensorProxy(like=attn_mask, shape=attn_mask.shape) + grad_attn_mask = TensorProxy(like=attn_mask) return (grad_query, grad_key, grad_value, grad_attn_mask) else: return (grad_query, grad_key, grad_value) +def _replace_dim_with(size: torch.Size, dim: int, dim_size: int) -> torch.Size: + return torch.Size(size[:dim] + (dim_size,) + size[dim + 1 :]) + + +def _same_size_except(*args, except_dim: int) -> bool: + shapes = [_replace_dim_with(shape, except_dim, 0) for shape in args] + return all(shape == shapes[0] for shape in shapes) + + +def _preallocate_grad_qkv( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> torch.Tensor: + assert _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + assert query.dtype == key.dtype == value.dtype + assert query.device == key.device == value.device + + b, s, d = query.size(0), query.size(2), query.size(3) + h_q, h_k, h_v = query.size(1), key.size(1), value.size(1) + h_qkv = h_q + h_k + h_v + + # Create grad_qkv as a tensor of size [b,h_qkv,s,d] and allocation order [0,2,1,3] from major to minor. + return torch.empty(b, s, h_qkv, d, dtype=query.dtype, device=query.device).permute(0, 2, 1, 3) + + def cudnn_sdpa_bwd_impl( grad_out: torch.Tensor, query: torch.Tensor, @@ -490,9 +525,13 @@ def cudnn_sdpa_bwd_impl( philox_offset: torch.Tensor, *, scale: None | float = None, -) -> (torch.Tensor, torch.Tensor, torch.Tensor): - query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) + preformat_grad_qkv: bool, +) -> tuple[torch.Tensor, ...]: + grad_qkv: None | torch.Tensor = None + if preformat_grad_qkv: + grad_qkv = _preallocate_grad_qkv(query, key, value) + query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) ( Q, K, @@ -516,15 +555,19 @@ def cudnn_sdpa_bwd_impl( attn_mask_4d, dropout_p, is_causal, + None if grad_qkv is None else grad_qkv.stride(), ) query = _sdpa_enforce_input_tensor_contiguity(query) key = _sdpa_enforce_input_tensor_contiguity(key) value = _sdpa_enforce_input_tensor_contiguity(value) - grad_query = torch.empty_like(query) - grad_key = torch.empty_like(key) - grad_value = torch.empty_like(value) + if preformat_grad_qkv: + grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) + else: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) # Default value of scale, if not provided, in all torch versions if scale is None: @@ -560,21 +603,59 @@ def cudnn_sdpa_bwd_impl( graph.execute(cudnn_to_torch_tensor, workspace) - if attn_mask is None: - return grad_query, grad_key, grad_value + if preformat_grad_qkv: + grads = (grad_qkv,) else: - return grad_query, grad_key, grad_value, grad_attn_mask + grads = (grad_query, grad_key, grad_value) + + if attn_mask is not None: + grads = grads + (grad_attn_mask,) + return grads +# TODO: can meta and fn be made private? 
cudnn_sdpa_bwd = cudnn_ex.register_operator( "cudnn_sdpa_bwd", - meta=cudnn_sdpa_backward_meta, - fn=cudnn_sdpa_bwd_impl, + meta=cudnn_sdpa_bwd_meta, + fn=functools.partial(cudnn_sdpa_bwd_impl, preformat_grad_qkv=False), +) + + +def cudnn_sdpa_bwd_preformatted_meta( + grad_out: TensorLike, + query: TensorLike, + key: TensorLike, + value: TensorLike, + attn_mask: None | TensorProxy, + dropout_p: float, + is_causal: bool, + out: TensorLike, + softmax_stats: TensorLike, + philox_seed: TensorLike, + philox_offset: TensorLike, + *, + scale: None | float = None, +) -> tuple[TensorProxy, ...]: + grad_qkv = TensorProxy( + like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1)) + ) + + if attn_mask is not None: + grad_attn_mask = TensorProxy(like=attn_mask) + return (grad_qkv, grad_attn_mask) + else: + return (grad_qkv,) + + +cudnn_sdpa_bwd_preformatted = cudnn_ex.register_operator( + "cudnn_sdpa_bwd_preformatted", + meta=cudnn_sdpa_bwd_preformatted_meta, + fn=functools.partial(cudnn_sdpa_bwd_impl, preformat_grad_qkv=True), ) @langctx("torch") -def _cudnn_sdpa_transform( +def _cudnn_sdpa_fwd_wrapper( query: TensorProxy, key: TensorProxy, value: TensorProxy, @@ -590,7 +671,7 @@ def _cudnn_sdpa_transform( @langctx("torch") -def _cudnn_sdpa_grad( +def _cudnn_sdpa_bwd_wrapper( query: TensorProxy, key: TensorProxy, value: TensorProxy, @@ -604,9 +685,13 @@ def _cudnn_sdpa_grad( query, key, value, attn_mask, dropout_p, is_causal, scale=scale ) - g = get_grad(primal) - grads = cudnn_sdpa_bwd( - g, + bwd_op = ( + cudnn_sdpa_bwd_preformatted + if _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + else cudnn_sdpa_bwd + ) + grads = bwd_op( + get_grad(primal), query, key, value, @@ -619,15 +704,19 @@ def _cudnn_sdpa_grad( offset, scale=scale, ) - if attn_mask is None: - grad_query, grad_key, grad_val = grads - else: - grad_query, grad_key, grad_val, grad_attn_mask = grads - put_grads((query, key, value), (grad_query, grad_key, grad_val)) if attn_mask is not None: + grad_attn_mask = grads[-1] + grads = grads[:-1] put_grad(attn_mask, grad_attn_mask) + if bwd_op == cudnn_sdpa_bwd: + grad_query, grad_key, grad_value = grads + else: + (grad_qkv,) = grads + grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) + put_grads((query, key, value), (grad_query, grad_key, grad_value)) + return primal @@ -635,6 +724,6 @@ def _cudnn_sdpa_grad( cudnn_ex.register_implementation( ltorch.scaled_dot_product_attention, checker=_cudnn_sdpa_checker, - execution_transform=_cudnn_sdpa_transform, - grad_transform=_cudnn_sdpa_grad, + execution_transform=_cudnn_sdpa_fwd_wrapper, + grad_transform=_cudnn_sdpa_bwd_wrapper, ) From 4d8361f6c86a95ba825dfdb1a51b6a723f650c68 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 20 Mar 2024 22:02:05 +0000 Subject: [PATCH 03/13] More tests. --- thunder/tests/test_cudnn_executor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index ab985aecbf..65bf685bc6 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -43,6 +43,9 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req q, k, v = make(N, n_head, L, E), make(N, n_head, S, E), make(N, n_head, S, Ev) yield SampleInput(q, k, v, None, dropout_p=0.0, is_causal=True) + # Same sequence length and embedding size for Q, K and V, a common use case. 
+ yield SampleInput(make(N, n_head, L, E), make(N, n_head, L, E), make(N, n_head, L, E), None, is_causal=True) + # Non-contiguous input tensor case nq = make(N, n_head, E, L).permute(0, 1, 3, 2) nk = make(N, n_head, E, S).permute(0, 1, 3, 2) From 12f3ea59a82004c2e9b5ce0058d5ef25f9e9ea6a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 05:34:30 +0000 Subject: [PATCH 04/13] Make some functions private to the module. --- thunder/executors/cudnnex.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 3bd3c125a8..c0304bc95a 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -459,7 +459,7 @@ def _make_cudnn_sdpa_backward_graph( return _cudnnex_cache[cache_key] -def cudnn_sdpa_bwd_meta( +def _cudnn_sdpa_bwd_meta( grad_out: TensorLike, query: TensorLike, key: TensorLike, @@ -511,7 +511,7 @@ def _preallocate_grad_qkv( return torch.empty(b, s, h_qkv, d, dtype=query.dtype, device=query.device).permute(0, 2, 1, 3) -def cudnn_sdpa_bwd_impl( +def _cudnn_sdpa_bwd_impl( grad_out: torch.Tensor, query: torch.Tensor, key: torch.Tensor, @@ -613,15 +613,14 @@ def cudnn_sdpa_bwd_impl( return grads -# TODO: can meta and fn be made private? cudnn_sdpa_bwd = cudnn_ex.register_operator( "cudnn_sdpa_bwd", - meta=cudnn_sdpa_bwd_meta, - fn=functools.partial(cudnn_sdpa_bwd_impl, preformat_grad_qkv=False), + meta=_cudnn_sdpa_bwd_meta, + fn=functools.partial(_cudnn_sdpa_bwd_impl, preformat_grad_qkv=False), ) -def cudnn_sdpa_bwd_preformatted_meta( +def _cudnn_sdpa_bwd_preformatted_meta( grad_out: TensorLike, query: TensorLike, key: TensorLike, @@ -649,8 +648,8 @@ def cudnn_sdpa_bwd_preformatted_meta( cudnn_sdpa_bwd_preformatted = cudnn_ex.register_operator( "cudnn_sdpa_bwd_preformatted", - meta=cudnn_sdpa_bwd_preformatted_meta, - fn=functools.partial(cudnn_sdpa_bwd_impl, preformat_grad_qkv=True), + meta=_cudnn_sdpa_bwd_preformatted_meta, + fn=functools.partial(_cudnn_sdpa_bwd_impl, preformat_grad_qkv=True), ) From 61ee1798866c677bff3ee9f9ede4b10bf74de783 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 05:42:15 +0000 Subject: [PATCH 05/13] Merge the two ops. 
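Instead of keeping cudnn_sdpa_bwd and cudnn_sdpa_bwd_preformatted as two
registered operators sharing one implementation, fold them into a single
operator with a keyword-only preformat_grad_qkv flag. The grad wrapper then
chooses the output format per call; condensed from the code below, with the
attn_mask handling elided:

    preformat_grad_qkv = _same_size_except(query.size(), key.size(), value.size(), except_dim=1)
    grads = cudnn_sdpa_bwd(..., scale=scale, preformat_grad_qkv=preformat_grad_qkv)
    if preformat_grad_qkv:
        (grad_qkv,) = grads
        grad_query, grad_key, grad_value = grad_qkv.split(
            [query.size(1), key.size(1), value.size(1)], dim=1
        )
    else:
        grad_query, grad_key, grad_value = grads
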
--- thunder/executors/cudnnex.py | 78 ++++++++++++------------------------ 1 file changed, 25 insertions(+), 53 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index c0304bc95a..78196f8b2c 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -459,6 +459,10 @@ def _make_cudnn_sdpa_backward_graph( return _cudnnex_cache[cache_key] +def _replace_dim_with(size: torch.Size, dim: int, dim_size: int) -> torch.Size: + return torch.Size(size[:dim] + (dim_size,) + size[dim + 1 :]) + + def _cudnn_sdpa_bwd_meta( grad_out: TensorLike, query: TensorLike, @@ -473,20 +477,24 @@ def _cudnn_sdpa_bwd_meta( philox_offset: TensorLike, *, scale: None | float = None, -) -> (TensorProxy, TensorProxy, TensorProxy): - grad_query = TensorProxy(like=query) - grad_key = TensorProxy(like=key) - grad_value = TensorProxy(like=value) + preformat_grad_qkv: bool, +) -> tuple[TensorProxy, ...]: + if preformat_grad_qkv: + grad_qkv = TensorProxy( + like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1)) + ) + grads = (grad_qkv,) + else: + grad_query = TensorProxy(like=query) + grad_key = TensorProxy(like=key) + grad_value = TensorProxy(like=value) + grads = (grad_query, grad_key, grad_value) if attn_mask is not None: grad_attn_mask = TensorProxy(like=attn_mask) - return (grad_query, grad_key, grad_value, grad_attn_mask) - else: - return (grad_query, grad_key, grad_value) - + grads = grads + (grad_attn_mask,) -def _replace_dim_with(size: torch.Size, dim: int, dim_size: int) -> torch.Size: - return torch.Size(size[:dim] + (dim_size,) + size[dim + 1 :]) + return grads def _same_size_except(*args, except_dim: int) -> bool: @@ -616,40 +624,7 @@ def _cudnn_sdpa_bwd_impl( cudnn_sdpa_bwd = cudnn_ex.register_operator( "cudnn_sdpa_bwd", meta=_cudnn_sdpa_bwd_meta, - fn=functools.partial(_cudnn_sdpa_bwd_impl, preformat_grad_qkv=False), -) - - -def _cudnn_sdpa_bwd_preformatted_meta( - grad_out: TensorLike, - query: TensorLike, - key: TensorLike, - value: TensorLike, - attn_mask: None | TensorProxy, - dropout_p: float, - is_causal: bool, - out: TensorLike, - softmax_stats: TensorLike, - philox_seed: TensorLike, - philox_offset: TensorLike, - *, - scale: None | float = None, -) -> tuple[TensorProxy, ...]: - grad_qkv = TensorProxy( - like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1)) - ) - - if attn_mask is not None: - grad_attn_mask = TensorProxy(like=attn_mask) - return (grad_qkv, grad_attn_mask) - else: - return (grad_qkv,) - - -cudnn_sdpa_bwd_preformatted = cudnn_ex.register_operator( - "cudnn_sdpa_bwd_preformatted", - meta=_cudnn_sdpa_bwd_preformatted_meta, - fn=functools.partial(_cudnn_sdpa_bwd_impl, preformat_grad_qkv=True), + fn=_cudnn_sdpa_bwd_impl, ) @@ -684,12 +659,8 @@ def _cudnn_sdpa_bwd_wrapper( query, key, value, attn_mask, dropout_p, is_causal, scale=scale ) - bwd_op = ( - cudnn_sdpa_bwd_preformatted - if _same_size_except(query.size(), key.size(), value.size(), except_dim=1) - else cudnn_sdpa_bwd - ) - grads = bwd_op( + preformat_grad_qkv = _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + grads = cudnn_sdpa_bwd( get_grad(primal), query, key, @@ -702,6 +673,7 @@ def _cudnn_sdpa_bwd_wrapper( seed, offset, scale=scale, + preformat_grad_qkv=preformat_grad_qkv, ) if attn_mask is not None: @@ -709,11 +681,11 @@ def _cudnn_sdpa_bwd_wrapper( grads = grads[:-1] put_grad(attn_mask, grad_attn_mask) - if bwd_op == cudnn_sdpa_bwd: - grad_query, grad_key, 
grad_value = grads - else: + if preformat_grad_qkv: (grad_qkv,) = grads grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) + else: + grad_query, grad_key, grad_value = grads put_grads((query, key, value), (grad_query, grad_key, grad_value)) return primal From c5d378201fe2fba4cf80cd1affb56faf98b25db6 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 05:43:06 +0000 Subject: [PATCH 06/13] Rename preformat to preallocate. --- thunder/executors/cudnnex.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 78196f8b2c..1740216d30 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -477,9 +477,9 @@ def _cudnn_sdpa_bwd_meta( philox_offset: TensorLike, *, scale: None | float = None, - preformat_grad_qkv: bool, + preallocate_grad_qkv: bool, ) -> tuple[TensorProxy, ...]: - if preformat_grad_qkv: + if preallocate_grad_qkv: grad_qkv = TensorProxy( like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1)) ) @@ -533,10 +533,10 @@ def _cudnn_sdpa_bwd_impl( philox_offset: torch.Tensor, *, scale: None | float = None, - preformat_grad_qkv: bool, + preallocate_grad_qkv: bool, ) -> tuple[torch.Tensor, ...]: grad_qkv: None | torch.Tensor = None - if preformat_grad_qkv: + if preallocate_grad_qkv: grad_qkv = _preallocate_grad_qkv(query, key, value) query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) @@ -570,7 +570,7 @@ def _cudnn_sdpa_bwd_impl( key = _sdpa_enforce_input_tensor_contiguity(key) value = _sdpa_enforce_input_tensor_contiguity(value) - if preformat_grad_qkv: + if preallocate_grad_qkv: grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: grad_query = torch.empty_like(query) @@ -611,7 +611,7 @@ def _cudnn_sdpa_bwd_impl( graph.execute(cudnn_to_torch_tensor, workspace) - if preformat_grad_qkv: + if preallocate_grad_qkv: grads = (grad_qkv,) else: grads = (grad_query, grad_key, grad_value) @@ -659,7 +659,7 @@ def _cudnn_sdpa_bwd_wrapper( query, key, value, attn_mask, dropout_p, is_causal, scale=scale ) - preformat_grad_qkv = _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + preallocate_grad_qkv = _same_size_except(query.size(), key.size(), value.size(), except_dim=1) grads = cudnn_sdpa_bwd( get_grad(primal), query, @@ -673,7 +673,7 @@ def _cudnn_sdpa_bwd_wrapper( seed, offset, scale=scale, - preformat_grad_qkv=preformat_grad_qkv, + preallocate_grad_qkv=preallocate_grad_qkv, ) if attn_mask is not None: @@ -681,7 +681,7 @@ def _cudnn_sdpa_bwd_wrapper( grads = grads[:-1] put_grad(attn_mask, grad_attn_mask) - if preformat_grad_qkv: + if preallocate_grad_qkv: (grad_qkv,) = grads grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: From dd882826d49276b05c32fd290301a8f29707c6af Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 05:46:22 +0000 Subject: [PATCH 07/13] Add a knob. 
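Gate the preallocation behind a compile option so the default behavior is
unchanged: the wrapper reads cudnn_sdpa_bwd_may_preallocate with
get_compile_option and treats an unset option as False, and even when the
option is set, preallocation only applies when Q, K and V differ at most in
their number of heads. A caller opts in per compilation, roughly as follows,
where fn stands for whatever function is being compiled:

    import thunder
    from thunder.executors.cudnnex import cudnn_ex

    jfn = thunder.compile(
        fn,
        executors_list=[cudnn_ex],
        cudnn_sdpa_bwd_may_preallocate=True,
    )
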
--- thunder/executors/cudnnex.py | 19 ++++++++++++++++--- thunder/executors/nvfuserex_impl.py | 2 +- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 1740216d30..5f607ec981 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -29,6 +29,7 @@ def cudnn_available() -> bool: from thunder.core.langctxs import langctx import thunder.core.dtypes as dtypes from thunder.torch import TensorLike +from thunder.core.compile_data import get_compile_option from thunder.core.proxies import Proxy, TensorProxy @@ -659,7 +660,19 @@ def _cudnn_sdpa_bwd_wrapper( query, key, value, attn_mask, dropout_p, is_causal, scale=scale ) - preallocate_grad_qkv = _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + description = """\ +This flag is for enabling nvFuser's zipping optimization that seeks to avoid +expensive concatenation. +https://github.com/NVIDIA/Fuser/issues/1502#issuecomment-1870837878 has more +details. When this flag is true, cudnn_sdpa_bwd may preallocate dQ, dK and dV +in **one** tensor and return them as slices of that tensor. +""" + may_preallocate: None | bool = get_compile_option("cudnn_sdpa_bwd_may_preallocate", description) + if may_preallocate is None: + may_preallocate = False + assert isinstance(may_preallocate, bool) + preallocate = may_preallocate and _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + grads = cudnn_sdpa_bwd( get_grad(primal), query, @@ -673,7 +686,7 @@ def _cudnn_sdpa_bwd_wrapper( seed, offset, scale=scale, - preallocate_grad_qkv=preallocate_grad_qkv, + preallocate_grad_qkv=preallocate, ) if attn_mask is not None: @@ -681,7 +694,7 @@ def _cudnn_sdpa_bwd_wrapper( grads = grads[:-1] put_grad(attn_mask, grad_attn_mask) - if preallocate_grad_qkv: + if preallocate: (grad_qkv,) = grads grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 5204ab1d24..bda5d2aba0 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -787,7 +787,7 @@ def _can_fuse_node(n: Node): region = Region(producers, consumers, bsyms) # Acquires the nv_enable_bookend compile option, which defaults to True - bookend_help = """ + bookend_help = """\ nvFuser's 'bookending' heuristic tries to gather metadata operations---such as transpose, reshape, or view---into the beginning and ends of blocks that utilize nvFuser. By pushing these ops to the edges, they will get dropped by the nvFuser From 6a49c820c07a640e4a4e174b2be38274b64001f3 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 16:31:26 +0000 Subject: [PATCH 08/13] Fix the test. 
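Parametrize the VJP test directly over the opinfo's dtypes and the new knob
instead of going through the @ops decorator, so both the preallocating and
the regular backward paths are exercised; float32 is also dropped from the
opinfo because only half and bfloat16 are supported by the backward here.
The two parametrize decorators multiply, schematically:

    import pytest

    @pytest.mark.parametrize("may_preallocate", (True, False))
    @pytest.mark.parametrize("dtype", ("float16", "bfloat16"))
    def test_example(dtype, may_preallocate):
        ...  # 2 dtypes x 2 knob settings = 4 test cases
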
--- thunder/tests/test_cudnn_executor.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 65bf685bc6..7369667dfc 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -75,9 +75,8 @@ def grad_scaled_dot_product_attention_reference_generator(op, device, dtype, req sample_input_generator=None, reference_input_generator=grad_scaled_dot_product_attention_reference_generator, torch_reference=torch.nn.functional.scaled_dot_product_attention, - # RuntimeError: Only fp32, half & bf16 supported at the moment + # RuntimeError: Only half & bf16 supported at the moment dtypes=( - thunder.dtypes.float32, thunder.dtypes.float16, thunder.dtypes.bfloat16, ), @@ -189,15 +188,10 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): return result -# NOTE Scaled_Dot_Product_Efficient_Attention_Backward does not support fp64 dtypes -# RuntimeError: Only fp32, half & bf16 supported at the moment -@ops( - (grad_sdpa_cudnn_opinfo,), - supported_dtypes=(dtypes.float16, dtypes.bfloat16), - supported_devicetypes=(devices.DeviceType.CUDA,), -) -def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): - for sample in op.reference_inputs(device, dtype, requires_grad=True): +@pytest.mark.parametrize("may_preallocate", (True, False), ids=("may-preallocate", "never-preallocate")) +@pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes()))) +def test_vjp_correctness_cudnn_sdpa(dtype, may_preallocate): + for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True): # Enforce tensor arguments are contiguous for torch reference contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args)) @@ -212,25 +206,25 @@ def test_vjp_correctness_sdpa_cudnnex_manual(op, device, dtype, executor, comp): continue # Compute vjp result using PyTorch - expect_out = op.torch_reference(*contiguous_args, **sample.kwargs) + expect_out = grad_sdpa_cudnn_opinfo.torch_reference(*contiguous_args, **sample.kwargs) v = make_tensor_like(expect_out) expected_grad = torch.autograd.grad(expect_out, grad_inputs, v) # Compute vjp result using Thunder - flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs) + flat_op, flat_args, spec = flatten_func(grad_sdpa_cudnn_opinfo.op, sample.args, sample.kwargs) filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args) cfoo = thunder.compile( vjp(filtered_op), disable_torch_autograd_support=True, disable_preprocessing=True, - executors_list=executor.executors_list() + [cudnn_ex], + executors_list=[cudnn_ex], + cudnn_sdpa_bwd_may_preallocate=may_preallocate, ) actual_out, actual_grad = cfoo(filtered_args, (v,)) - comp(actual_out, expect_out, atol=1e-2, rtol=1e-2) - + torch.testing.assert_close(actual_out, expect_out, atol=1e-2, rtol=1e-2) # compare gradients of query, key, value, and attn_mask for eg, ag in zip(expected_grad, actual_grad): - comp(eg, ag, atol=2e-1, rtol=2e-2) + torch.testing.assert_close(eg, ag, atol=2e-1, rtol=2e-2) From ee7c74c11d594d2fca83329eee8ed549e5278348 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 09:32:04 -0700 Subject: [PATCH 09/13] Update thunder/executors/cudnnex.py Co-authored-by: Masaki Kozuki --- thunder/executors/cudnnex.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 5f607ec981..460a0cc3c8 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -1,6 +1,5 @@ from typing import Any -import functools import torch import numpy as np import random From 92e3fb3435b5d5fd3c2fd622e2340c907c038aee Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 17:02:49 +0000 Subject: [PATCH 10/13] Clean up. --- thunder/executors/cudnnex.py | 43 ++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 460a0cc3c8..bba0ccc1f4 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -208,6 +208,11 @@ def compute_NHWC_strides(shape): # And when registering for sdpa, cudnn assumes NHWC layout. (See _transform_sdpa_inputs()) def _sdpa_enforce_input_tensor_contiguity(a: torch.Tensor) -> torch.Tensor: if a.stride(-1) == 1: + # TODO(vedaanta-nvidia): there's an inconsistency between + # _transform_sdpa_inputs and this function, leading to a potential bug. + # _transform_sdpa_inputs always creates contiguous strides, but code + # here creates a partially contiguous stride when the last dimension is + # contiguous but other dimensions are not. return a else: return a.contiguous() @@ -346,7 +351,7 @@ def _cudnn_sdpa_checker( def _make_cudnn_sdpa_backward_graph( - query, key, value, attn_mask, dropout_p, is_causal, grad_qkv_stride: None | tuple[int, ...] + query, key, value, attn_mask, dropout_p, is_causal, grad_query_stride, grad_key_stride, grad_value_stride ): b, h, s_q, _ = query.size _, _, _, d_v = value.size @@ -417,13 +422,11 @@ def _make_cudnn_sdpa_backward_graph( dropout=dropout_tuple, ) - dQ.set_output(True).set_dim(query.size).set_stride(grad_qkv_stride or query.stride).set_data_type( + dQ.set_output(True).set_dim(query.size).set_stride(grad_query_stride).set_data_type( torch_to_cudnn_dtype(query.dtype) ) - dK.set_output(True).set_dim(key.size).set_stride(grad_qkv_stride or key.stride).set_data_type( - torch_to_cudnn_dtype(key.dtype) - ) - dV.set_output(True).set_dim(value.size).set_stride(grad_qkv_stride or value.stride).set_data_type( + dK.set_output(True).set_dim(key.size).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype)) + dV.set_output(True).set_dim(value.size).set_stride(grad_value_stride).set_data_type( torch_to_cudnn_dtype(value.dtype) ) @@ -535,11 +538,22 @@ def _cudnn_sdpa_bwd_impl( scale: None | float = None, preallocate_grad_qkv: bool, ) -> tuple[torch.Tensor, ...]: + query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) + query = _sdpa_enforce_input_tensor_contiguity(query) + key = _sdpa_enforce_input_tensor_contiguity(key) + value = _sdpa_enforce_input_tensor_contiguity(value) + + # When preallocate_grad_qkv is on, allocate dQKV and make dQ, dK, and dV + # slices of that. Otherwise, allocate them individually. 
grad_qkv: None | torch.Tensor = None if preallocate_grad_qkv: grad_qkv = _preallocate_grad_qkv(query, key, value) + grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) + else: + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) - query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) ( Q, K, @@ -563,20 +577,11 @@ def _cudnn_sdpa_bwd_impl( attn_mask_4d, dropout_p, is_causal, - None if grad_qkv is None else grad_qkv.stride(), + grad_query.stride(), + grad_key.stride(), + grad_value.stride(), ) - query = _sdpa_enforce_input_tensor_contiguity(query) - key = _sdpa_enforce_input_tensor_contiguity(key) - value = _sdpa_enforce_input_tensor_contiguity(value) - - if preallocate_grad_qkv: - grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) - else: - grad_query = torch.empty_like(query) - grad_key = torch.empty_like(key) - grad_value = torch.empty_like(value) - # Default value of scale, if not provided, in all torch versions if scale is None: scale = query.shape[-1] ** -0.5 From f0c0e17d6aeb546934e80b4236da9920ac8be539 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Mar 2024 17:06:56 +0000 Subject: [PATCH 11/13] Renaming. --- thunder/executors/cudnnex.py | 34 ++++++++++++++-------------- thunder/tests/test_cudnn_executor.py | 6 ++--- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index bba0ccc1f4..a31e62efdb 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -480,9 +480,9 @@ def _cudnn_sdpa_bwd_meta( philox_offset: TensorLike, *, scale: None | float = None, - preallocate_grad_qkv: bool, + cat_grad_qkv: bool, ) -> tuple[TensorProxy, ...]: - if preallocate_grad_qkv: + if cat_grad_qkv: grad_qkv = TensorProxy( like=query, shape=_replace_dim_with(query.size(), 1, query.size(1) + key.size(1) + value.size(1)) ) @@ -505,7 +505,7 @@ def _same_size_except(*args, except_dim: int) -> bool: return all(shape == shapes[0] for shape in shapes) -def _preallocate_grad_qkv( +def _cat_grad_qkv( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -536,18 +536,18 @@ def _cudnn_sdpa_bwd_impl( philox_offset: torch.Tensor, *, scale: None | float = None, - preallocate_grad_qkv: bool, + cat_grad_qkv: bool, ) -> tuple[torch.Tensor, ...]: query_4d, key_4d, value_4d, attn_mask_4d = _transform_sdpa_inputs(query, key, value, attn_mask) query = _sdpa_enforce_input_tensor_contiguity(query) key = _sdpa_enforce_input_tensor_contiguity(key) value = _sdpa_enforce_input_tensor_contiguity(value) - # When preallocate_grad_qkv is on, allocate dQKV and make dQ, dK, and dV + # When cat_grad_qkv is on, allocate dQKV and make dQ, dK, and dV # slices of that. Otherwise, allocate them individually. 
grad_qkv: None | torch.Tensor = None - if preallocate_grad_qkv: - grad_qkv = _preallocate_grad_qkv(query, key, value) + if cat_grad_qkv: + grad_qkv = _cat_grad_qkv(query, key, value) grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: grad_query = torch.empty_like(query) @@ -616,7 +616,7 @@ def _cudnn_sdpa_bwd_impl( graph.execute(cudnn_to_torch_tensor, workspace) - if preallocate_grad_qkv: + if cat_grad_qkv: grads = (grad_qkv,) else: grads = (grad_query, grad_key, grad_value) @@ -668,14 +668,14 @@ def _cudnn_sdpa_bwd_wrapper( This flag is for enabling nvFuser's zipping optimization that seeks to avoid expensive concatenation. https://github.com/NVIDIA/Fuser/issues/1502#issuecomment-1870837878 has more -details. When this flag is true, cudnn_sdpa_bwd may preallocate dQ, dK and dV -in **one** tensor and return them as slices of that tensor. +details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV +as **one** tensor and return them as slices of that tensor. """ - may_preallocate: None | bool = get_compile_option("cudnn_sdpa_bwd_may_preallocate", description) - if may_preallocate is None: - may_preallocate = False - assert isinstance(may_preallocate, bool) - preallocate = may_preallocate and _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + may_cast_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description) + if may_cast_grad_qkv is None: + may_cast_grad_qkv = False + assert isinstance(may_cast_grad_qkv, bool) + cat_grad_qkv = may_cast_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1) grads = cudnn_sdpa_bwd( get_grad(primal), @@ -690,7 +690,7 @@ def _cudnn_sdpa_bwd_wrapper( seed, offset, scale=scale, - preallocate_grad_qkv=preallocate, + cat_grad_qkv=cat_grad_qkv, ) if attn_mask is not None: @@ -698,7 +698,7 @@ def _cudnn_sdpa_bwd_wrapper( grads = grads[:-1] put_grad(attn_mask, grad_attn_mask) - if preallocate: + if cat_grad_qkv: (grad_qkv,) = grads grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: diff --git a/thunder/tests/test_cudnn_executor.py b/thunder/tests/test_cudnn_executor.py index 7369667dfc..37e2f6606e 100644 --- a/thunder/tests/test_cudnn_executor.py +++ b/thunder/tests/test_cudnn_executor.py @@ -188,9 +188,9 @@ def test_cudnn_vs_torch_consistency(op, device, dtype, *_): return result -@pytest.mark.parametrize("may_preallocate", (True, False), ids=("may-preallocate", "never-preallocate")) +@pytest.mark.parametrize("may_cat_grad_qkv", (True, False), ids=("may-cat-grad-qkv", "never-cat-grad-qkv")) @pytest.mark.parametrize("dtype", grad_sdpa_cudnn_opinfo.dtypes(), ids=tuple(map(str, grad_sdpa_cudnn_opinfo.dtypes()))) -def test_vjp_correctness_cudnn_sdpa(dtype, may_preallocate): +def test_vjp_correctness_cudnn_sdpa(dtype, may_cat_grad_qkv): for sample in grad_sdpa_cudnn_opinfo.reference_inputs("cuda", dtype, requires_grad=True): # Enforce tensor arguments are contiguous for torch reference contiguous_args = list(map(lambda a: a.contiguous() if isinstance(a, torch.Tensor) else a, sample.args)) @@ -219,7 +219,7 @@ def test_vjp_correctness_cudnn_sdpa(dtype, may_preallocate): disable_torch_autograd_support=True, disable_preprocessing=True, executors_list=[cudnn_ex], - cudnn_sdpa_bwd_may_preallocate=may_preallocate, + cudnn_sdpa_bwd_may_cat_grad_qkv=may_cat_grad_qkv, ) actual_out, actual_grad = cfoo(filtered_args, (v,)) From 2e1661d6fa43c022e71b5170acd10a89ffcd1204 Mon Sep 
17 00:00:00 2001 From: Jingyue Wu Date: Sun, 24 Mar 2024 23:44:19 +0000 Subject: [PATCH 12/13] Comment. --- thunder/executors/cudnnex.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index a31e62efdb..938a73d420 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -666,10 +666,9 @@ def _cudnn_sdpa_bwd_wrapper( description = """\ This flag is for enabling nvFuser's zipping optimization that seeks to avoid -expensive concatenation. -https://github.com/NVIDIA/Fuser/issues/1502#issuecomment-1870837878 has more -details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV -as **one** tensor and return them as slices of that tensor. +expensive concatenation. https://github.com/NVIDIA/Fuser/issues/1768 has more +details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV as one +tensor and return them as slices of that tensor. """ may_cast_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description) if may_cast_grad_qkv is None: From 30cb61ed8f55ec7290a9e08aa7bab012916b2004 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 25 Mar 2024 05:55:55 +0000 Subject: [PATCH 13/13] Comments and renaming. --- thunder/executors/cudnnex.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/thunder/executors/cudnnex.py b/thunder/executors/cudnnex.py index 938a73d420..9a9e72f3d3 100644 --- a/thunder/executors/cudnnex.py +++ b/thunder/executors/cudnnex.py @@ -505,7 +505,10 @@ def _same_size_except(*args, except_dim: int) -> bool: return all(shape == shapes[0] for shape in shapes) -def _cat_grad_qkv( +# Allocates an empty tensor that will hold dQ, dK, and dV, concatenated. +# `query`, `key` and `value` merely provide necessary metadata such as sizes +# and dtypes. They don't have to be passed in as `torch.Tensor`s. +def _allocate_catted_grad_qkv( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -518,7 +521,8 @@ def _cat_grad_qkv( h_q, h_k, h_v = query.size(1), key.size(1), value.size(1) h_qkv = h_q + h_k + h_v - # Create grad_qkv as a tensor of size [b,h_qkv,s,d] and allocation order [0,2,1,3] from major to minor. + # Create grad_qkv as a tensor of size [b,h_qkv,s,d] and allocation order + # [0,2,1,3] from major to minor. return torch.empty(b, s, h_qkv, d, dtype=query.dtype, device=query.device).permute(0, 2, 1, 3) @@ -547,7 +551,7 @@ def _cudnn_sdpa_bwd_impl( # slices of that. Otherwise, allocate them individually. grad_qkv: None | torch.Tensor = None if cat_grad_qkv: - grad_qkv = _cat_grad_qkv(query, key, value) + grad_qkv = _allocate_catted_grad_qkv(query, key, value) grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: grad_query = torch.empty_like(query) @@ -670,11 +674,11 @@ def _cudnn_sdpa_bwd_wrapper( details. When this flag is true, cudnn_sdpa_bwd may cat dQ, dK and dV as one tensor and return them as slices of that tensor. 
""" - may_cast_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description) - if may_cast_grad_qkv is None: - may_cast_grad_qkv = False - assert isinstance(may_cast_grad_qkv, bool) - cat_grad_qkv = may_cast_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1) + may_cat_grad_qkv: None | bool = get_compile_option("cudnn_sdpa_bwd_may_cat_grad_qkv", description) + if may_cat_grad_qkv is None: + may_cat_grad_qkv = False + assert isinstance(may_cat_grad_qkv, bool) + cat_grad_qkv = may_cat_grad_qkv and _same_size_except(query.size(), key.size(), value.size(), except_dim=1) grads = cudnn_sdpa_bwd( get_grad(primal), @@ -698,6 +702,8 @@ def _cudnn_sdpa_bwd_wrapper( put_grad(attn_mask, grad_attn_mask) if cat_grad_qkv: + # The `split` is done outside `cudnn_sdpa_bwd` so it can be picked up + # by nvfuserex. (grad_qkv,) = grads grad_query, grad_key, grad_value = grad_qkv.split([query.size(1), key.size(1), value.size(1)], dim=1) else: