From 3548ba8a7c862c2bc730a1cd2b172c8b9e3f9406 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:31:24 -0700 Subject: [PATCH] SDPA integration for nvFuser (#951) --- thunder/executors/nvfuserex_impl.py | 250 +++++++++++++++++++++++++++- thunder/executors/sdpaex.py | 176 +------------------- thunder/executors/utils.py | 170 +++++++++++++++++++ thunder/extend/__init__.py | 121 +++++++------- thunder/tests/test_nvfuser.py | 75 +++++++++ 5 files changed, 560 insertions(+), 232 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 1649c2161d..ea21e996eb 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -16,6 +16,8 @@ import thunder.core.dtypes as dtypes import thunder.torch as ltorch +from thunder.torch import TensorLike + from thunder.core import prims, utils from thunder.core.baseutils import BoundSymbolInterface from thunder.core.prims import PrimIDs @@ -34,16 +36,29 @@ from thunder.core.utils import OrderedSet, check, check_same_dtype from thunder.core.trace import TraceCtx, from_trace, TraceProvenance from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, Symbol, has_tags -from thunder.core.devices import Device, DeviceType +from thunder.core.devices import Device, DeviceType, cpu import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS from thunder.core.profile import add_markers from thunder.core.compile_data import get_compile_option -from thunder.executors.utils import Region +from thunder.core.transforms import ( + get_grad, + put_grads, +) + +from thunder.executors.utils import ( + Region, + _input_dtype_check_fused_scaled_dot_product_attention, + _input_shape_check_fused_scaled_dot_product_attention, + _fused_sdp_choice, + SpdaBackend, +) + from thunder.executors.passes import update_fusion_call_ctx from thunder.extend import FUEL_LEVEL, FusionExecutor, register_executor, add_default_executor +from thunder.executors.nvfuserex import nvfuser_version # NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally # by nvfuserex.py when nvFuser is available. @@ -2208,3 +2223,234 @@ def matmul( register_supported(PrimIDs.MATMUL, matmul, _matmul_check) + + +# Registering SDPA operators for nvFuser +# SDPA requires an execution and grad transform since the forward and backward passes are called through different implementations. +# For both execution and grad transform, a new operator is registered with nvfuserex (ex.register_operator) and then added to the translation map (register_supported). +# The operators are tagged with OpTag.RANDOM_OP to prevent rematerialization in backward pass. +# Finally, the complete rule is registered through ex.register_supported, with the execution and grad transform wrapping around these operators. + + +# SDPA Forward +def _scaled_dot_product_flash_attention_forward_meta( + query: TensorLike, + key: TensorLike, + value: TensorLike, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: None | float = None, +) -> tuple[TensorProxy, TensorProxy, int, int]: + # Reference metadata: + # * query (batch_size, num_heads, query_seq_len, E) + # * key (batch_size, num_heads, key_seq_len, E) + # * value (batch_size, num_heads, key_seq_len, Ev) + # * output (batch_size, num_heads, query_seq_len, Ev) + + # at::_scaled_dot_product_flash_attention returns {output, log_sumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask}. + # In nvFuser, we only save {output, log_sumexp, philox_seed/offset} for backward since the other variables are not required for non-nested input tensors. + # For non-nested tensor, cum_seq_q/k is undefined, max_q/k can be inferred from input size, and we set `return_debug_mask=False`, so `debug_attn_mask` is a 1D zero tensor. + + batch_size, num_heads, query_seq_len, E = query.shape + key_seq_len = key.shape[2] + + return ( + output := TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E)), + log_sumexp := TensorProxy( + shape=(batch_size, num_heads, query_seq_len), dtype=dtypes.float32, device=query.device, requires_grad=False + ), + philox_seed := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False), + philox_offset := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False), + ) + + +def _scaled_dot_product_flash_attention_forward( + query: TensorProxy, + key: TensorProxy, + value: TensorProxy, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: None | float = None, + fd: FusionDefinition, + lc_to_nv_map: dict, +) -> Any: + + inputs = [query, key, value, dropout_p, is_causal, scale] + nv_inputs = [] + for inp in inputs: + nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None + nv_inputs.append(nv_inp) + + return fd.ops.sdpfa_fwd(*nv_inputs) + + +nv_sdpfa_fwd = ex.register_operator( + "nv_sdpfa_fwd", + meta=_scaled_dot_product_flash_attention_forward_meta, + fn=_scaled_dot_product_flash_attention_forward, + tags=[prims.OpTags.RANDOM_OP], +) + +register_supported(nv_sdpfa_fwd.id, _scaled_dot_product_flash_attention_forward, None) + + +# SDPA Backward +def _scaled_dot_product_flash_attention_backward_meta( + grad_out: TensorLike, + query: TensorLike, + key: TensorLike, + value: TensorLike, + out: TensorLike, + logsumexp: TensorLike, + dropout_p: float, + is_causal: bool, + philox_seed: TensorLike, + philox_offset: TensorLike, + *, + scale: None | float = None, +) -> tuple[TensorProxy, TensorProxy, TensorProxy]: + + batch_size, num_heads, query_seq_len, E = query.shape + key_seq_len = key.shape[2] + + # Reference metadata: + # https://github.com/pytorch/pytorch/blob/f57b00704e498a676854a02974ca9e0c42188b23/torch/_meta_registrations.py#L5043-L5063 + grad_query = TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E)) + grad_key = TensorProxy(like=key, shape=(batch_size, num_heads, key_seq_len, E)) + grad_value = TensorProxy(like=value, shape=(batch_size, num_heads, key_seq_len, E)) + return (grad_query, grad_key, grad_value) + + +def _scaled_dot_product_flash_attention_backward( + grad_out: TensorProxy, + query: TensorProxy, + key: TensorProxy, + value: TensorProxy, + out: TensorProxy, + logsumexp: TensorProxy, + dropout_p: float, + is_causal: bool, + philox_seed: TensorProxy, + philox_offset: TensorProxy, + *, + scale: None | float = None, + fd: FusionDefinition, + lc_to_nv_map: dict, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + inputs = [grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, philox_seed, philox_offset, scale] + nv_inputs = [] + for inp in inputs: + nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None + nv_inputs.append(nv_inp) + + return fd.ops.sdpfa_bwd(*nv_inputs) + + +nv_sdpfa_bwd = ex.register_operator( + "nv_sdpfa_bwd", + meta=_scaled_dot_product_flash_attention_backward_meta, + fn=_scaled_dot_product_flash_attention_backward, + tags=[prims.OpTags.RANDOM_OP], +) + +register_supported(nv_sdpfa_bwd.id, _scaled_dot_product_flash_attention_backward, None) + + +# Checker for SDPA +def _scaled_dot_product_flash_attention_check( + query: Proxy, + key: Proxy, + value: Proxy, + attn_mask: Proxy | None, + dropout_p: float, + is_causal: bool, + *, + scale: None | float = None, +) -> bool: + + # fd.ops.sdpfa_fwd and fd.ops.sdpfa_bwd are adding in versions 0.2.9 and 0.2.10 respectively. + if nvfuser_version() < LooseVersion("0.2.10"): + return False + + enable_sdpa: None | bool = get_compile_option("nv_enable_sdpa", "Enable nvFuser flash attention SDPA.") + + if not enable_sdpa: + return False + + # Flash attn does not support attn_mask currently. + if attn_mask is not None: + return False + + if not are_supported_tensors(query, key, value): + return False + + # FP64 is not supported by flash attention + supported_dtypes = (dtypes.float16, dtypes.bfloat16) + _input_dtype_check_fused_scaled_dot_product_attention(query, key, value, attn_mask := None, supported_dtypes) + _input_shape_check_fused_scaled_dot_product_attention(query, key, value, attn_mask := None) + + # nvFuser only implements flash attention currently. + backend = _fused_sdp_choice(query, key, value, None, dropout_p, is_causal, scale) + return backend == SpdaBackend.FLASH_ATTENTION + + +# SDPA execution_transform -- calls nv_sdpfa_fwd operator registered above +def scaled_dot_product_flash_attention( + query: TensorProxy, + key: TensorProxy, + value: TensorProxy, + attn_mask: TensorProxy = None, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: None | float = None, +): + (attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd( + query, key, value, dropout_p, is_causal, scale=scale + ) + return attn_output + + +# SDPA grad_transform -- calls nv_sdpfa_fwd and nv_sdpfa_bwd registered above +def scaled_dot_product_flash_attention_grad( + query: Proxy, + key: Proxy, + value: Proxy, + attn_mask: None | Proxy, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: None | float = None, +): + + (attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd( + query, key, value, dropout_p, is_causal, scale=scale + ) + grad_out = get_grad(attn_output) + grad_query, grad_key, grad_val = nv_sdpfa_bwd( + grad_out, + query, + key, + value, + attn_output, + logsumexp, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale=scale, + ) + put_grads((query, key, value), (grad_query, grad_key, grad_val)) + return attn_output + + +# Register the complete rule for SDPA in nvfuser executor +ex.register_supported( + ltorch.scaled_dot_product_attention, + checker=_scaled_dot_product_flash_attention_check, + execution_transform=scaled_dot_product_flash_attention, + grad_transform=scaled_dot_product_flash_attention_grad, +) diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index ddd82ca915..b912acaed9 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -2,13 +2,11 @@ from looseversion import LooseVersion import torch -from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode import thunder.core.dtypes as dtypes from thunder.core.proxies import Proxy, TensorProxy import thunder.core.utils as utils import thunder.core.devices as devices -from thunder.core.compile_data import get_compile_option import thunder.torch as ltorch from thunder.torch import TensorLike @@ -22,17 +20,17 @@ from enum import Enum +from thunder.executors.utils import ( + _input_dtype_check_fused_scaled_dot_product_attention, + _input_shape_check_fused_scaled_dot_product_attention, + _fused_sdp_choice, + SpdaBackend, +) + sdpa_ex: OperatorExecutor = OperatorExecutor("sdpa", version="0.1") register_executor(sdpa_ex) -class SpdaBackend(Enum): - ERROR = -1 - MATH = 0 - FLASH_ATTENTION = 1 - MEMORY_EFFICIENT = 2 - - # Both flash attention and memory efficient sdpa require that the last stride be one. def _sdpa_enforce_input_tensor_contiguity(a: torch.Tensor) -> torch.Tensor: if a is None or a.stride(-1) == 1: @@ -109,73 +107,6 @@ def _attention_mask_memory_efficient_helper(attn_mask: None | torch.Tensor, quer return expanded_attn_mask.contiguous() -# TODO These checks should be converted to compile-time checks using a checker function -# This helper function checks that the shape of input tensors are supported by fused sdpa implementation. -def _input_shape_check_fused_scaled_dot_product_attention( - query: TensorLike, key: TensorLike, value: TensorLike, attn_mask: None | TensorLike -): - # Restrict input tensors to 4 dimension - utils.check( - query.ndim == 4, - lambda: f"grad_forward_sdpa: Expected query tensor to have 4 dimension, but it has {query.ndim}.", - ) - utils.check( - key.ndim == 4, - lambda: f"grad_forward_sdpa: Expected key tensor to have 4 dimension, but it has {key.ndim}.", - ) - utils.check( - value.ndim == 4, - lambda: f"grad_forward_sdpa: Expected value tensor to have 4 dimension, but it has {value.ndim}.", - ) - utils.check( - attn_mask is None or attn_mask.ndim == 4, - lambda: f"grad_forward_sdpa: Expected attn_mask tensor to have 4 dimension, but it has {attn_mask.ndim}.", - ) - - # query (batch_size, num_heads, query_seq_len, E) - # key (batch_size, num_heads, key_seq_len, E) - # value (batch_size, num_heads, key_seq_len, Ev) - # attn_mask (batch_size, num_heads, query_seq_len, key_seq_len) - inputs = [query, key, value] - if attn_mask is not None: - inputs.append(attn_mask) - - # NOTE aten::scaled_dot_product_efficient_attention does not support broadcastable batch size. - utils.check( - all(a.shape[0] == inputs[0].shape[0] for a in inputs), - lambda: "grad_forward_sdpa: Expected all inputs to have same batch_size.", - ) - - # Check for the same number of heads - utils.check( - all(a.shape[1] == 1 or a.shape[1] == inputs[0].shape[1] for a in inputs), - lambda: "grad_forward_sdpa: Expected all inputs to have same number of attention heads or a broadcastable dimension.", - ) - - -# TODO These checks should be converted to compile-time checks using a checker function -# This helper function checks that the dtypes of input tensors are supported by fused sdpa implementation. -def _input_dtype_check_fused_scaled_dot_product_attention( - query: TensorLike, - key: TensorLike, - value: TensorLike, - attn_mask: None | TensorLike, - supported_dtypes: tuple[dtypes.dtype, ...], -): - utils.check( - query.dtype in supported_dtypes, - lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but query has {query.dtype}.", - ) - utils.check( - key.dtype in supported_dtypes, - lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but key has {key.dtype}.", - ) - utils.check( - value.dtype in supported_dtypes, - lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but value has {value.dtype}.", - ) - - # This helper function maps to aten::_scaled_dot_product_efficient_attention function. def _grad_forward_scaled_dot_product_efficient_attention_meta( query: TensorLike, @@ -590,99 +521,6 @@ def _scaled_dot_product_attention_grad( return primal -# This helper function converts Thunder Proxy to PyTorch Meta Tensor -def _convert_to_meta_tensor(a: None | TensorProxy) -> None | torch.Tensor: - from thunder.torch import _thunder_to_torch_dtype_map - - if a is None: - return None - return torch.empty( - a.shape, - dtype=_thunder_to_torch_dtype_map[a.dtype], - requires_grad=a.requires_grad, - device="meta", - ) - - -# This helper function converts PyTorch meta tensor to FakeTensor, which -# models stride order for contiguity checks. -def _convert_to_fake_tensor(mode: FakeTensorMode, a: None | torch.Tensor) -> None | FakeTensor: - if a is None: - return None - return FakeTensor(mode, a, device="cuda") - - -# Convert input tensors represented as Thunder Proxy to PyTorch FakeTensor. -# Determine which fused sdpa kernel. -def _fused_sdp_choice( - query: Proxy, - key: Proxy, - value: Proxy, - attn_mask: None | Proxy, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: None | float = None, -) -> int: - input_tensors = (query, key, value, attn_mask) - meta_input_tensors = list(map(_convert_to_meta_tensor, input_tensors)) - with FakeTensorMode() as mode: - fake_query, fake_key, fake_value, fake_attn_mask = list( - map(lambda a: _convert_to_fake_tensor(mode, a), meta_input_tensors) - ) - - import thunder - - if isinstance(is_causal, thunder.core.proxies.IntegerProxy): - is_causal = is_causal.value - - if LooseVersion(torch.__version__) < LooseVersion("2.2.0"): - # Figure out which SDPA to use. There are performance cliffs to the - # various implementations, and this makes the decision cognizant of - # those cliffs. - backend = torch._fused_sdp_choice( - fake_query, - fake_key, - fake_value, - fake_attn_mask, - dropout_p, - is_causal, - scale=scale, - ) - return SpdaBackend(backend) - else: - from torch.backends.cuda import ( - SDPAParams, - can_use_efficient_attention, - can_use_flash_attention, - flash_sdp_enabled, - math_sdp_enabled, - mem_efficient_sdp_enabled, - ) - - args = [] - if hasattr(SDPAParams, "enable_gqa"): - args.append(False) - - sdp_params = SDPAParams(fake_query, fake_key, fake_value, fake_attn_mask, dropout_p, is_causal, *args) - - enable_debug: None | bool = get_compile_option( - "sdpa_debug", "Enables sdpa backend warning messages when a specific kernel is unavailable." - ) - # Set default value. - if enable_debug is None: - enable_debug = False - assert isinstance(enable_debug, bool) - - if flash_sdp_enabled() and can_use_flash_attention(sdp_params, enable_debug): - return SpdaBackend.FLASH_ATTENTION - elif mem_efficient_sdp_enabled() and can_use_efficient_attention(sdp_params, enable_debug): - return SpdaBackend.MEMORY_EFFICIENT - elif math_sdp_enabled(): - return SpdaBackend.MATH - else: - return SpdaBackend.ERROR - - def _scaled_dot_product_attention_checker( query: Proxy, key: Proxy, diff --git a/thunder/executors/utils.py b/thunder/executors/utils.py index b49b9cb5c5..a17acf1613 100644 --- a/thunder/executors/utils.py +++ b/thunder/executors/utils.py @@ -17,6 +17,9 @@ from thunder.core.proxies import Variable, variableify, Proxy, unvariableify from thunder.core.prims import PrimIDs from thunder.core.transform_common import order_proxies +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from thunder.core.compile_data import get_compile_option + # TODO Make these tags comment_symbols = { @@ -110,3 +113,170 @@ def set_saved_tensors(ctx, saved_tensors): yield finally: del ctx.saved_tensors + + +# TODO These checks should be converted to compile-time checks using a checker function +# This helper function checks that the shape of input tensors are supported by fused sdpa implementation. +def _input_shape_check_fused_scaled_dot_product_attention( + query: TensorLike, key: TensorLike, value: TensorLike, attn_mask: None | TensorLike +): + # Restrict input tensors to 4 dimension + utils.check( + query.ndim == 4, + lambda: f"grad_forward_sdpa: Expected query tensor to have 4 dimension, but it has {query.ndim}.", + ) + utils.check( + key.ndim == 4, + lambda: f"grad_forward_sdpa: Expected key tensor to have 4 dimension, but it has {key.ndim}.", + ) + utils.check( + value.ndim == 4, + lambda: f"grad_forward_sdpa: Expected value tensor to have 4 dimension, but it has {value.ndim}.", + ) + utils.check( + attn_mask is None or attn_mask.ndim == 4, + lambda: f"grad_forward_sdpa: Expected attn_mask tensor to have 4 dimension, but it has {attn_mask.ndim}.", + ) + + # query (batch_size, num_heads, query_seq_len, E) + # key (batch_size, num_heads, key_seq_len, E) + # value (batch_size, num_heads, key_seq_len, Ev) + # attn_mask (batch_size, num_heads, query_seq_len, key_seq_len) + inputs = [query, key, value] + if attn_mask is not None: + inputs.append(attn_mask) + + # NOTE aten::scaled_dot_product_efficient_attention does not support broadcastable batch size. + utils.check( + all(a.shape[0] == inputs[0].shape[0] for a in inputs), + lambda: "grad_forward_sdpa: Expected all inputs to have same batch_size.", + ) + + # Check for the same number of heads + utils.check( + all(a.shape[1] == 1 or a.shape[1] == inputs[0].shape[1] for a in inputs), + lambda: "grad_forward_sdpa: Expected all inputs to have same number of attention heads or a broadcastable dimension.", + ) + + +# TODO These checks should be converted to compile-time checks using a checker function +# This helper function checks that the dtypes of input tensors are supported by fused sdpa implementation. +def _input_dtype_check_fused_scaled_dot_product_attention( + query: TensorLike, + key: TensorLike, + value: TensorLike, + attn_mask: None | TensorLike, + supported_dtypes: tuple[dtypes.dtype, ...], +): + utils.check( + query.dtype in supported_dtypes, + lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but query has {query.dtype}.", + ) + utils.check( + key.dtype in supported_dtypes, + lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but key has {key.dtype}.", + ) + utils.check( + value.dtype in supported_dtypes, + lambda: f"grad_forward_sdpa: Only {supported_dtypes} dtypes are supported, but value has {value.dtype}.", + ) + + +# This helper function converts Thunder Proxy to PyTorch Meta Tensor +def _convert_to_meta_tensor(a: None | TensorProxy) -> None | torch.Tensor: + from thunder.torch import _thunder_to_torch_dtype_map + + if a is None: + return None + return torch.empty( + a.shape, + dtype=_thunder_to_torch_dtype_map[a.dtype], + requires_grad=a.requires_grad, + device="meta", + ) + + +# This helper function converts PyTorch meta tensor to FakeTensor, which +# models stride order for contiguity checks. +def _convert_to_fake_tensor(mode: FakeTensorMode, a: None | torch.Tensor) -> None | FakeTensor: + if a is None: + return None + return FakeTensor(mode, a, device="cuda") + + +class SpdaBackend(Enum): + ERROR = -1 + MATH = 0 + FLASH_ATTENTION = 1 + MEMORY_EFFICIENT = 2 + + +# Convert input tensors represented as Thunder Proxy to PyTorch FakeTensor. +# Determine which fused sdpa kernel. +def _fused_sdp_choice( + query: Proxy, + key: Proxy, + value: Proxy, + attn_mask: None | Proxy, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: None | float = None, +) -> int: + input_tensors = (query, key, value, attn_mask) + meta_input_tensors = list(map(_convert_to_meta_tensor, input_tensors)) + with FakeTensorMode() as mode: + fake_query, fake_key, fake_value, fake_attn_mask = list( + map(lambda a: _convert_to_fake_tensor(mode, a), meta_input_tensors) + ) + + import thunder + + if isinstance(is_causal, thunder.core.proxies.IntegerProxy): + is_causal = is_causal.value + + if LooseVersion(torch.__version__) < LooseVersion("2.2.0"): + # Figure out which SDPA to use. There are performance cliffs to the + # various implementations, and this makes the decision cognizant of + # those cliffs. + backend = torch._fused_sdp_choice( + fake_query, + fake_key, + fake_value, + fake_attn_mask, + dropout_p, + is_causal, + scale=scale, + ) + return SpdaBackend(backend) + else: + from torch.backends.cuda import ( + SDPAParams, + can_use_efficient_attention, + can_use_flash_attention, + flash_sdp_enabled, + math_sdp_enabled, + mem_efficient_sdp_enabled, + ) + + args = [] + if hasattr(SDPAParams, "enable_gqa"): + args.append(False) + + sdp_params = SDPAParams(fake_query, fake_key, fake_value, fake_attn_mask, dropout_p, is_causal, *args) + + enable_debug: None | bool = get_compile_option( + "sdpa_debug", "Enables sdpa backend warning messages when a specific kernel is unavailable." + ) + # Set default value. + if enable_debug is None: + enable_debug = False + assert isinstance(enable_debug, bool) + + if flash_sdp_enabled() and can_use_flash_attention(sdp_params, enable_debug): + return SpdaBackend.FLASH_ATTENTION + elif mem_efficient_sdp_enabled() and can_use_efficient_attention(sdp_params, enable_debug): + return SpdaBackend.MEMORY_EFFICIENT + elif math_sdp_enabled(): + return SpdaBackend.MATH + else: + return SpdaBackend.ERROR diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 025238f443..de7d60af6b 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -60,6 +60,7 @@ def __init__(self, name: Hashable, *, version: None | Any = None): self._implmap: dict[Hashable, ImplInfo] = {} self._lookasides: dict[Callable, Callable] = {} + self._opmap: dict[str, Symbol] = {} @property def name(self) -> Hashable: @@ -73,6 +74,10 @@ def version(self) -> Any: def implmap(self) -> dict[Hashable, ImplInfo]: return self._implmap + @property + def opmap(self) -> dict[str, Symbol]: + return self._opmap + def __repr__(self) -> str: return f"thunder.extend.OperatorExecutor('{str(self.name)}')" @@ -133,6 +138,61 @@ def get_grad_transform(self, sym: Symbol) -> None | Callable: return impl.grad_transform + # TODO Document this operation + # TODO Wrap meta in prim context? + # TODO Document how to avoid name collisions + def register_operator( + self, + name: str, + *, + like: None | Symbol = None, + meta: None | Callable = None, + tags: None | list[Any] = None, + module: None | type | ModuleType = None, + fn: None | Callable = None, + bind_postprocess: None | Callable = None, + replaces: None | Callable = None, + python_printer: Callable = default_python_printer, + ) -> Symbol: + ln = like is None + mn = meta is None + assert ( + ln ^ mn + ), f"Expected one and only one of 'like' and 'meta' to be specified. {'Neither' if ln and mn else 'Both'} were specified." + assert (module is not None) + ( + fn is not None + ) <= 2, f"Expected one and only one of 'module' or 'fn' to be specified. Module: {module}, Fn: {fn}" + + # NOTE Directly specifying a meta function makes the operation a prim + is_prim = meta is not None + # Set tags to be the same as 'like' if 'tags' is not specified + tags = like.tags if (tags is None and like is not None and hasattr(like, "tags")) else tags + meta = meta if meta is not None else like + call_ctx: None | dict[str, Callable] = None if fn is None else {name: fn} + + def _bind_postprocess(bsym: BoundSymbol) -> None: + bsym._call_ctx = call_ctx + if bind_postprocess is not None: + bind_postprocess(bsym) + + sym = Symbol( + name=name, + id=name, + meta=meta, + is_prim=is_prim, + _module=module, + executor=self, + _bind_postprocess=_bind_postprocess, + python_printer=python_printer, + tags=tags, + ) + self.opmap[name] = sym + + if replaces is not None: + self._lookasides[replaces] = sym + + return sym + class FUEL_LEVEL(enum.Enum): UNLIMITED = enum.auto() @@ -198,67 +258,6 @@ class OperatorExecutor(Executor): def __init__(self, name: Hashable, *, version: None | Any = None): super().__init__(name, version=version) - self._opmap: dict[str, Symbol] = {} - - @property - def opmap(self) -> dict[str, Symbol]: - return self._opmap - - # TODO Document this operation - # TODO Wrap meta in prim context? - # TODO Document how to avoid name collisions - def register_operator( - self, - name: str, - *, - like: None | Symbol = None, - meta: None | Callable = None, - tags: None | list[Any] = None, - module: None | type | ModuleType = None, - fn: None | Callable = None, - bind_postprocess: None | Callable = None, - replaces: None | Callable = None, - python_printer: Callable = default_python_printer, - ) -> Symbol: - ln = like is None - mn = meta is None - assert ( - ln ^ mn - ), f"Expected one and only one of 'like' and 'meta' to be specified. {'Neither' if ln and mn else 'Both'} were specified." - assert (module is not None) + ( - fn is not None - ) <= 2, f"Expected one and only one of 'module' or 'fn' to be specified. Module: {module}, Fn: {fn}" - - # NOTE Directly specifying a meta function makes the operation a prim - is_prim = meta is not None - # Set tags to be the same as 'like' if 'tags' is not specified - tags = like.tags if (tags is None and like is not None and hasattr(like, "tags")) else tags - meta = meta if meta is not None else like - call_ctx: None | dict[str, Callable] = None if fn is None else {name: fn} - - def _bind_postprocess(bsym: BoundSymbol) -> None: - bsym._call_ctx = call_ctx - if bind_postprocess is not None: - bind_postprocess(bsym) - - sym = Symbol( - name=name, - id=name, - meta=meta, - is_prim=is_prim, - _module=module, - executor=self, - _bind_postprocess=_bind_postprocess, - python_printer=python_printer, - tags=tags, - ) - self.opmap[name] = sym - - if replaces is not None: - self._lookasides[replaces] = sym - - return sym - def register_implementation( self, sym_or_id: Symbol | Hashable, diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py index 40a661ffd0..3a695a32d0 100644 --- a/thunder/tests/test_nvfuser.py +++ b/thunder/tests/test_nvfuser.py @@ -991,3 +991,78 @@ def make_integer_tensor(): rout = f(x.cpu(), y.cpu()).to(device) jout = f(x, y) assert rout.equal(jout) + + +@instantiate( + dtypes=(thunder.float16, thunder.bfloat16), + devicetypes=(devices.DeviceType.CUDA,), + executors=(nvFuserExecutor,), + decorators=( + pytest.mark.skipif( + nvfuser_version() is None or nvfuser_version() < LooseVersion("0.2.10"), + reason="Requires nvFuser version 0.2.10 or later", + ), + pytest.mark.parametrize("dropout_p", [0.0, 0.2]), + pytest.mark.parametrize("is_causal", [False, True]), + pytest.mark.parametrize("scale", [None, 1e-3]), + ), +) +def test_sdpa( + executor, + device: str, + thunder_dtype: dtypes.dtype, + dropout_p: None | float, + is_causal: None | bool, + scale: None | float, +): + + def sdpa_fn(q, k, v, dropout_p, is_causal, scale): + return torch.nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=dropout_p, is_causal=is_causal, scale=scale + ) + + torch.manual_seed(0) + dtype = ltorch.to_torch_dtype(thunder_dtype) + + N, H, L, S, E = 4, 8, 16, 16, 8 + q = make_tensor((N, H, L, E), device=device, dtype=dtype, requires_grad=True) + k = make_tensor((N, H, S, E), device=device, dtype=dtype, requires_grad=True) + v = make_tensor((N, H, S, E), device=device, dtype=dtype, requires_grad=True) + grad_out = make_tensor((N, H, L, E), device=device, dtype=dtype) + + tensor_inputs = [q, k, v] + scalar_inputs = [dropout_p, is_causal, scale] + + compiled_func = thunder.jit(sdpa_fn, executors_list=executor.executors_list(), nv_enable_sdpa=True) + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + attn_out = compiled_func(*tensor_inputs, *scalar_inputs) + attn_out.backward(grad_out) + fwd_trace = thunder.last_traces(compiled_func)[-1] + bwd_trace = thunder.last_backward_traces(compiled_func)[-1] + fwd_fusion = examine.get_fusions(fwd_trace) + bwd_fusion = examine.get_fusions(bwd_trace) + + assert len(fwd_fusion) == 1 + assert len(bwd_fusion) == 1 + assert "nv_sdpfa_fwd" in fwd_fusion[-1][-1].name + + # Check nv_sdpfa_fwd is not in bwd_fusion -> that would indicate rematerialization + assert "nv_sdpfa_bwd" in bwd_fusion[-1][-1].name and "nv_sdpfa_fwd" not in bwd_fusion[-1][-1].name + + # Torch reference computation + # Clone the inputs to verify gradients with torch reference + ref_tensor_inputs = [] + for inp in tensor_inputs: + ref_inp = inp.clone().detach() + ref_inp.requires_grad = True + ref_tensor_inputs.append(ref_inp) + + with torch.random.fork_rng(devices=[torch.cuda.current_device()]): + ref_attn_out = sdpa_fn(*ref_tensor_inputs, *scalar_inputs) + ref_attn_out.backward(grad_out) + + nv_outputs = (attn_out, q.grad, k.grad, v.grad) + ref_outputs = (ref_attn_out, *(inp.grad for inp in ref_tensor_inputs)) + + for nv_out, ref_out in zip(nv_outputs, ref_outputs): + torch.testing.assert_close(nv_out, ref_out)