Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/models/testing_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .attention import AttentionTesterMixin
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
Expand Down Expand Up @@ -38,6 +38,7 @@


__all__ = [
"AttentionBackendTesterMixin",
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
Expand Down
303 changes: 295 additions & 8 deletions tests/models/testing_utils/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,105 @@
# limitations under the License.

import gc
import logging

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import is_kernels_available, is_torch_version

from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device


logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------

_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()

_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")

_PARAM_NATIVE_CUDNN = pytest.param(
AttentionBackendName._NATIVE_CUDNN,
id="native_cudnn",
marks=pytest.mark.skipif(
not _CUDA_AVAILABLE,
reason="CUDA is required for _native_cudnn backend.",
),
)

_PARAM_FLASH_HUB = pytest.param(
AttentionBackendName.FLASH_HUB,
id="flash_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
),
],
)

from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
_PARAM_FLASH_3_HUB = pytest.param(
AttentionBackendName._FLASH_3_HUB,
id="flash_3_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
),
],
)

# All backends under test.
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]

# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}

# Backends that perform non-deterministic operations and therefore cannot run when
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}


def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict


def _skip_if_backend_requires_nondeterminism(backend):
"""Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend.

This check is intentionally deferred to test execution time because
enable_full_determinism() is typically called at module level in test files *after*
the module-level pytest.param() objects in this file have already been evaluated,
making it impossible to catch via a collection-time skipif condition.
"""
if backend in _NON_DETERMINISTIC_BACKENDS and torch.are_deterministic_algorithms_enabled():
pytest.skip(
f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
f"while `torch.use_deterministic_algorithms(True)` is active."
)


@is_attention
class AttentionTesterMixin:
Expand All @@ -39,7 +122,6 @@ class AttentionTesterMixin:
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)

Expected from config mixin:
- model_class: The model class to test
Expand Down Expand Up @@ -179,3 +261,208 @@ def test_attention_processor_count_mismatch_raises_error(self):
model.set_attn_processor(wrong_processors)

assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"


@is_attention
class AttentionBackendTesterMixin:
"""
Mixin class for testing attention backends on models. Following things are tested:

1. Backends can be set with the `attention_backend` context manager and with
`set_attention_backend()` method.
2. SDPA outputs don't deviate too much from backend outputs.
3. Backend works with (regional) compilation.
4. Backends can be restored.

Tests the backends using the model provided by the host test class. The backends to test
are defined in `_ALL_BACKEND_PARAMS`.

Expected from the host test class:
- model_class: The model class to instantiate.

Expected methods from the host test class:
- get_init_dict(): Returns dict of kwargs to construct the model.
- get_dummy_inputs(): Returns dict of inputs for the model's forward pass.

Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests.
"""

# -----------------------------------------------------------------------
# Tolerance attributes — override in host class to loosen/tighten checks.
# -----------------------------------------------------------------------

# test_output_close_to_native: alternate backends (flash, cuDNN) may
# accumulate small numerical errors vs the reference PyTorch SDPA kernel.
backend_vs_native_atol: float = 1e-2
backend_vs_native_rtol: float = 1e-2

# test_compile: regional compilation introduces the same kind of numerical
# error as the non-compiled backend path, so the same loose tolerance applies.
compile_vs_native_atol: float = 1e-2
compile_vs_native_rtol: float = 1e-2

def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)

def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)

@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_set_attention_backend_matches_context_manager(self, backend):
"""set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
_skip_if_backend_requires_nondeterminism(backend)

init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

with attention_backend(backend):
ctx_output = model(**inputs_dict, return_dict=False)[0]

initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))

try:
set_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)

assert_tensors_close(
set_output,
ctx_output,
atol=0,
rtol=0,
msg=(
f"Output from model.set_attention_backend('{backend.value}') should be identical "
f"to the output from `with attention_backend('{backend.value}'):`."
),
)

@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_output_close_to_native(self, backend):
"""All backends should produce model output numerically close to the native SDPA reference."""
_skip_if_backend_requires_nondeterminism(backend)

init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

with attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]

initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))

try:
backend_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)

assert_tensors_close(
backend_output,
native_output,
atol=self.backend_vs_native_atol,
rtol=self.backend_vs_native_rtol,
msg=f"Output from {backend} should be numerically close to native SDPA.",
)

@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_context_manager_switches_and_restores_backend(self, backend):
"""attention_backend() should activate the requested backend and restore the previous one on exit."""
initial_backend, _ = _AttentionBackendRegistry.get_active_backend()

with attention_backend(backend):
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert active_backend == backend, (
f"Backend should be {backend} inside the context manager, got {active_backend}."
)

restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert restored_backend == initial_backend, (
f"Backend should be restored to {initial_backend} after exiting the context manager, "
f"got {restored_backend}."
)

@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_compile(self, backend):
"""
`torch.compile` tests checking for recompilation, graph breaks, forward can run, etc.
For speed, we use regional compilation here (`model.compile_repeated_blocks()`
as opposed to `model.compile`).
"""
_skip_if_backend_requires_nondeterminism(backend)
if getattr(self.model_class, "_repeated_blocks", None) is None:
pytest.skip("Skipping tests as regional compilation is not supported.")

if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
pytest.xfail(
"test_compile with the native backend requires torch >= 2.9.0 for stable "
"fullgraph compilation with error_on_recompile=True."
)

init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)

with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]

initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))

try:
model.compile_repeated_blocks(fullgraph=True)
torch.compiler.reset()

with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
with torch.no_grad():
compile_output = model(**inputs_dict, return_dict=False)[0]
model(**inputs_dict, return_dict=False)
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)

assert_tensors_close(
compile_output,
native_output,
atol=self.compile_vs_native_atol,
rtol=self.compile_vs_native_rtol,
msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
)
5 changes: 5 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionBackendTesterMixin,
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
Expand Down Expand Up @@ -224,6 +225,10 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
"""Attention processor tests for Flux Transformer."""


class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
"""Attention backend tests for Flux Transformer."""


class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""

Expand Down
Loading
Loading