diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index ea076b3ec774..e31e388f506f 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -1,4 +1,4 @@ -from .attention import AttentionTesterMixin +from .attention import AttentionBackendTesterMixin, AttentionTesterMixin from .cache import ( CacheTesterMixin, FasterCacheConfigMixin, @@ -38,6 +38,7 @@ __all__ = [ + "AttentionBackendTesterMixin", "AttentionTesterMixin", "BaseModelTesterConfig", "BitsAndBytesCompileTesterMixin", diff --git a/tests/models/testing_utils/attention.py b/tests/models/testing_utils/attention.py index 134b3fa33bfe..5d61c433bbce 100644 --- a/tests/models/testing_utils/attention.py +++ b/tests/models/testing_utils/attention.py @@ -14,22 +14,105 @@ # limitations under the License. import gc +import logging import pytest import torch from diffusers.models.attention import AttentionModuleMixin -from diffusers.models.attention_processor import ( - AttnProcessor, +from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend +from diffusers.models.attention_processor import AttnProcessor +from diffusers.utils import is_kernels_available, is_torch_version + +from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Module-level backend parameter sets for AttentionBackendTesterMixin +# --------------------------------------------------------------------------- + +_CUDA_AVAILABLE = torch.cuda.is_available() +_KERNELS_AVAILABLE = is_kernels_available() + +_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native") + +_PARAM_NATIVE_CUDNN = pytest.param( + AttentionBackendName._NATIVE_CUDNN, + id="native_cudnn", + marks=pytest.mark.skipif( + not _CUDA_AVAILABLE, + reason="CUDA is required for _native_cudnn backend.", + ), +) + +_PARAM_FLASH_HUB = pytest.param( + AttentionBackendName.FLASH_HUB, + id="flash_hub", + marks=[ + pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."), + pytest.mark.skipif( + not _KERNELS_AVAILABLE, + reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.", + ), + ], ) -from ...testing_utils import ( - assert_tensors_close, - backend_empty_cache, - is_attention, - torch_device, +_PARAM_FLASH_3_HUB = pytest.param( + AttentionBackendName._FLASH_3_HUB, + id="flash_3_hub", + marks=[ + pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."), + pytest.mark.skipif( + not _KERNELS_AVAILABLE, + reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.", + ), + ], ) +# All backends under test. +_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB] + +# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them. +_BF16_REQUIRED_BACKENDS = { + AttentionBackendName._NATIVE_CUDNN, + AttentionBackendName.FLASH_HUB, + AttentionBackendName._FLASH_3_HUB, +} + +# Backends that perform non-deterministic operations and therefore cannot run when +# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()). +_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN} + + +def _maybe_cast_to_bf16(backend, model, inputs_dict): + """Cast model and floating-point inputs to bfloat16 when the backend requires it.""" + if backend not in _BF16_REQUIRED_BACKENDS: + return model, inputs_dict + model = model.to(dtype=torch.bfloat16) + inputs_dict = { + k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs_dict.items() + } + return model, inputs_dict + + +def _skip_if_backend_requires_nondeterminism(backend): + """Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend. + + This check is intentionally deferred to test execution time because + enable_full_determinism() is typically called at module level in test files *after* + the module-level pytest.param() objects in this file have already been evaluated, + making it impossible to catch via a collection-time skipif condition. + """ + if backend in _NON_DETERMINISTIC_BACKENDS and torch.are_deterministic_algorithms_enabled(): + pytest.skip( + f"Backend '{backend.value}' performs non-deterministic operations and cannot run " + f"while `torch.use_deterministic_algorithms(True)` is active." + ) + @is_attention class AttentionTesterMixin: @@ -39,7 +122,6 @@ class AttentionTesterMixin: Tests functionality from AttentionModuleMixin including: - Attention processor management (set/get) - QKV projection fusion/unfusion - - Attention backends (XFormers, NPU, etc.) Expected from config mixin: - model_class: The model class to test @@ -179,3 +261,208 @@ def test_attention_processor_count_mismatch_raises_error(self): model.set_attn_processor(wrong_processors) assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch" + + +@is_attention +class AttentionBackendTesterMixin: + """ + Mixin class for testing attention backends on models. Following things are tested: + + 1. Backends can be set with the `attention_backend` context manager and with + `set_attention_backend()` method. + 2. SDPA outputs don't deviate too much from backend outputs. + 3. Backend works with (regional) compilation. + 4. Backends can be restored. + + Tests the backends using the model provided by the host test class. The backends to test + are defined in `_ALL_BACKEND_PARAMS`. + + Expected from the host test class: + - model_class: The model class to instantiate. + + Expected methods from the host test class: + - get_init_dict(): Returns dict of kwargs to construct the model. + - get_dummy_inputs(): Returns dict of inputs for the model's forward pass. + + Pytest mark: attention + Use `pytest -m "not attention"` to skip these tests. + """ + + # ----------------------------------------------------------------------- + # Tolerance attributes — override in host class to loosen/tighten checks. + # ----------------------------------------------------------------------- + + # test_output_close_to_native: alternate backends (flash, cuDNN) may + # accumulate small numerical errors vs the reference PyTorch SDPA kernel. + backend_vs_native_atol: float = 1e-2 + backend_vs_native_rtol: float = 1e-2 + + # test_compile: regional compilation introduces the same kind of numerical + # error as the non-compiled backend path, so the same loose tolerance applies. + compile_vs_native_atol: float = 1e-2 + compile_vs_native_rtol: float = 1e-2 + + def setup_method(self): + gc.collect() + backend_empty_cache(torch_device) + + def teardown_method(self): + gc.collect() + backend_empty_cache(torch_device) + + @torch.no_grad() + @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS) + def test_set_attention_backend_matches_context_manager(self, backend): + """set_attention_backend() and the attention_backend() context manager must yield identical outputs.""" + _skip_if_backend_requires_nondeterminism(backend) + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict) + + with attention_backend(backend): + ctx_output = model(**inputs_dict, return_dict=False)[0] + + initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend() + + try: + model.set_attention_backend(backend.value) + except Exception as e: + logger.warning("Skipping test for backend '%s': %s", backend.value, e) + pytest.skip(str(e)) + + try: + set_output = model(**inputs_dict, return_dict=False)[0] + finally: + model.reset_attention_backend() + _AttentionBackendRegistry.set_active_backend(initial_registry_backend) + + assert_tensors_close( + set_output, + ctx_output, + atol=0, + rtol=0, + msg=( + f"Output from model.set_attention_backend('{backend.value}') should be identical " + f"to the output from `with attention_backend('{backend.value}'):`." + ), + ) + + @torch.no_grad() + @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS) + def test_output_close_to_native(self, backend): + """All backends should produce model output numerically close to the native SDPA reference.""" + _skip_if_backend_requires_nondeterminism(backend) + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict) + + with attention_backend(AttentionBackendName.NATIVE): + native_output = model(**inputs_dict, return_dict=False)[0] + + initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend() + + try: + model.set_attention_backend(backend.value) + except Exception as e: + logger.warning("Skipping test for backend '%s': %s", backend.value, e) + pytest.skip(str(e)) + + try: + backend_output = model(**inputs_dict, return_dict=False)[0] + finally: + model.reset_attention_backend() + _AttentionBackendRegistry.set_active_backend(initial_registry_backend) + + assert_tensors_close( + backend_output, + native_output, + atol=self.backend_vs_native_atol, + rtol=self.backend_vs_native_rtol, + msg=f"Output from {backend} should be numerically close to native SDPA.", + ) + + @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS) + def test_context_manager_switches_and_restores_backend(self, backend): + """attention_backend() should activate the requested backend and restore the previous one on exit.""" + initial_backend, _ = _AttentionBackendRegistry.get_active_backend() + + with attention_backend(backend): + active_backend, _ = _AttentionBackendRegistry.get_active_backend() + assert active_backend == backend, ( + f"Backend should be {backend} inside the context manager, got {active_backend}." + ) + + restored_backend, _ = _AttentionBackendRegistry.get_active_backend() + assert restored_backend == initial_backend, ( + f"Backend should be restored to {initial_backend} after exiting the context manager, " + f"got {restored_backend}." + ) + + @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS) + def test_compile(self, backend): + """ + `torch.compile` tests checking for recompilation, graph breaks, forward can run, etc. + For speed, we use regional compilation here (`model.compile_repeated_blocks()` + as opposed to `model.compile`). + """ + _skip_if_backend_requires_nondeterminism(backend) + if getattr(self.model_class, "_repeated_blocks", None) is None: + pytest.skip("Skipping tests as regional compilation is not supported.") + + if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"): + pytest.xfail( + "test_compile with the native backend requires torch >= 2.9.0 for stable " + "fullgraph compilation with error_on_recompile=True." + ) + + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict) + + with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE): + native_output = model(**inputs_dict, return_dict=False)[0] + + initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend() + + try: + model.set_attention_backend(backend.value) + except Exception as e: + logger.warning("Skipping test for backend '%s': %s", backend.value, e) + pytest.skip(str(e)) + + try: + model.compile_repeated_blocks(fullgraph=True) + torch.compiler.reset() + + with ( + torch._inductor.utils.fresh_inductor_cache(), + torch._dynamo.config.patch(error_on_recompile=True), + ): + with torch.no_grad(): + compile_output = model(**inputs_dict, return_dict=False)[0] + model(**inputs_dict, return_dict=False) + finally: + model.reset_attention_backend() + _AttentionBackendRegistry.set_active_backend(initial_registry_backend) + + assert_tensors_close( + compile_output, + native_output, + atol=self.compile_vs_native_atol, + rtol=self.compile_vs_native_rtol, + msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.", + ) diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 2d39dadfcad1..21ac89d22cc2 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -25,6 +25,7 @@ from ...testing_utils import enable_full_determinism, torch_device from ..testing_utils import ( + AttentionBackendTesterMixin, AttentionTesterMixin, BaseModelTesterConfig, BitsAndBytesCompileTesterMixin, @@ -224,6 +225,10 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM """Attention processor tests for Flux Transformer.""" +class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin): + """Attention backend tests for Flux Transformer.""" + + class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin): """Context Parallel inference tests for Flux Transformer""" diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py deleted file mode 100644 index 01f4521c5adc..000000000000 --- a/tests/others/test_attention_backends.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -This test suite exists for the maintainers currently. It's not run in our CI at the moment. - -Once attention backends become more mature, we can consider including this in our CI. - -To run this test suite: - -```bash -export RUN_ATTENTION_BACKEND_TESTS=yes - -pytest tests/others/test_attention_backends.py -``` - -Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in -"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128). - -Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X -with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and -aiter 0.1.5.post4.dev20+ga25e55e79. -""" - -import os - -import pytest -import torch - - -pytestmark = pytest.mark.skipif( - os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough." -) -from diffusers import FluxPipeline # noqa: E402 -from diffusers.utils import is_torch_version # noqa: E402 - - -# fmt: off -FORWARD_CASES = [ - ( - "flash_hub", - torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16) - ), - ( - "_flash_3_hub", - torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16), - ), - ( - "native", - torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16) - ), - ( - "_native_cudnn", - torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16), - ), - ( - "aiter", - torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16), - ) -] - -COMPILE_CASES = [ - ( - "flash_hub", - torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), - True - ), - ( - "_flash_3_hub", - torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16), - True, - ), - ( - "native", - torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16), - True, - ), - ( - "_native_cudnn", - torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16), - True, - ), - ( - "aiter", - torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16), - True, - ) -] -# fmt: on - -INFER_KW = { - "prompt": "dance doggo dance", - "height": 256, - "width": 256, - "num_inference_steps": 2, - "guidance_scale": 3.5, - "max_sequence_length": 128, - "output_type": "pt", -} - - -def _backend_is_probably_supported(pipe, name: str): - try: - pipe.transformer.set_attention_backend(name) - return pipe, True - except Exception: - return False - - -def _check_if_slices_match(output, expected_slice): - img = output.images.detach().cpu() - generated_slice = img.flatten() - generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) - assert torch.allclose(generated_slice, expected_slice, atol=1e-4) - - -@pytest.fixture(scope="session") -def device(): - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for these tests.") - return torch.device("cuda:0") - - -@pytest.fixture(scope="session") -def pipe(device): - repo_id = "black-forest-labs/FLUX.1-dev" - pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device) - pipe.set_progress_bar_config(disable=True) - return pipe - - -@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES]) -def test_forward(pipe, backend_name, expected_slice): - out = _backend_is_probably_supported(pipe, backend_name) - if isinstance(out, bool): - pytest.xfail(f"Backend '{backend_name}' not supported in this environment.") - - modified_pipe = out[0] - out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - _check_if_slices_match(out, expected_slice) - - -@pytest.mark.parametrize( - "backend_name,expected_slice,error_on_recompile", - COMPILE_CASES, - ids=[c[0] for c in COMPILE_CASES], -) -def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile): - if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"): - pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.") - - out = _backend_is_probably_supported(pipe, backend_name) - if isinstance(out, bool): - pytest.xfail(f"Backend '{backend_name}' not supported in this environment.") - - modified_pipe = out[0] - modified_pipe.transformer.compile(fullgraph=True) - - torch.compiler.reset() - with ( - torch._inductor.utils.fresh_inductor_cache(), - torch._dynamo.config.patch(error_on_recompile=error_on_recompile), - ): - out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0)) - - _check_if_slices_match(out, expected_slice) diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index 11acd2175e21..c9729e29ebc7 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -72,6 +72,7 @@ # Other testers ("SingleFileTesterMixin", "single_file"), ("IPAdapterTesterMixin", "ip_adapter"), + ("AttentionBackendTesterMixin", "attention_backends"), ] @@ -530,6 +531,7 @@ def main(): "faster_cache", "single_file", "ip_adapter", + "attention_backends", "all", ], help="Optional testers to include",