From 3f36c6d4f46059864db55453a315ab8bce112f08 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 15:13:08 +0530 Subject: [PATCH 1/6] tests: add cp backend and attention backend tests. --- tests/models/testing_utils/parallelism.py | 87 +++++++++++++++++++++-- 1 file changed, 82 insertions(+), 5 deletions(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index e05b36799e66..7ef71ea5b1df 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -23,10 +23,7 @@ from diffusers.models._modeling_parallel import ContextParallelConfig -from ...testing_utils import ( - is_context_parallel, - require_torch_multi_accelerator, -) +from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator def _find_free_port(): @@ -38,7 +35,9 @@ def _find_free_port(): return port -def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict): +def _context_parallel_worker( + rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None +): """Worker function for context parallel testing.""" try: # Set up distributed environment @@ -67,6 +66,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di else: inputs_on_device[key] = value + # Enable attention backend + if attention_backend: + try: + model.set_attention_backend(attention_backend) + except Exception as e: + pytest.skip(f"Skipping test because of exception: {e}.") + # Enable context parallelism cp_config = ContextParallelConfig(**cp_dict) model.enable_parallelism(config=cp_config) @@ -126,3 +132,74 @@ def test_context_parallel_inference(self, cp_type): assert return_dict.get("status") == "success", ( f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" ) + + +@is_context_parallel +@require_torch_multi_accelerator 
+class ContextParallelAttentionBackendsTesterMixin: + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"]) + @pytest.mark.parametrize( + "attentiion_backend", + [ + "native", + pytest.param( + "flash_hub", + marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), + ), + pytest.param( + "_flash_3_hub", + marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."), + ), + ], + ) + @pytest.mark.parametrize("ulysses_anything", [True, False]) + def test_context_parallel_attn_backend_inference(self, cp_type, attentiion_backend, ulysses_anything): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if getattr(self.model_class, "_cp_plan", None) is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + if cp_type == "ulysses_degree" and attentiion_backend == "native": + pytest.skip("Skipping test because ulysses isn't supported with native attention backend.") + + if ulysses_anything and "ulysses" not in cp_type: + pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() + + # Move all tensors to CPU for multiprocessing + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + cp_dict = {cp_type: world_size} + if ulysses_anything: + cp_dict.update({"ulysses_anything": ulysses_anything}) + + # Find a free port for distributed communication + master_port = _find_free_port() + + # Use multiprocessing manager for cross-process communication + manager = mp.Manager() + return_dict = manager.dict() + + # Spawn worker processes + mp.spawn( + _context_parallel_worker, + args=( + world_size, + master_port, + self.model_class, + init_dict, + cp_dict, + inputs_dict, + return_dict, + attentiion_backend, + ), + nprocs=world_size, + join=True, + ) + + assert 
return_dict.get("status") == "success", ( + f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}" + ) From e7317067ab7bfe980c4edde148ae1929229ecbe7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 15:25:35 +0530 Subject: [PATCH 2/6] tests: export ContextParallelAttentionBackendsTesterMixin and add Flux test class --- tests/models/testing_utils/__init__.py | 3 ++- tests/models/transformers/test_models_transformer_flux.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/models/testing_utils/__init__.py b/tests/models/testing_utils/__init__.py index ea076b3ec774..d012114da85e 100644 --- a/tests/models/testing_utils/__init__.py +++ b/tests/models/testing_utils/__init__.py @@ -13,7 +13,7 @@ from .ip_adapter import IPAdapterTesterMixin from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin -from .parallelism import ContextParallelTesterMixin +from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin from .quantization import ( BitsAndBytesCompileTesterMixin, BitsAndBytesConfigMixin, @@ -45,6 +45,7 @@ "BitsAndBytesTesterMixin", "CacheTesterMixin", "ContextParallelTesterMixin", + "ContextParallelAttentionBackendsTesterMixin", "CPUOffloadTesterMixin", "FasterCacheConfigMixin", "FasterCacheTesterMixin", diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 2d39dadfcad1..c8b68f36307a 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -29,6 +29,7 @@ BaseModelTesterConfig, BitsAndBytesCompileTesterMixin, BitsAndBytesTesterMixin, + ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, @@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, 
ContextPar """Context Parallel inference tests for Flux Transformer""" +class TestFluxTransformerContextParallelAttnBackends( + FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin +): + """Context Parallel inference x attention backends tests for Flux Transformer""" + + class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin): """IP Adapter tests for Flux Transformer.""" From 1d12bd215f3f9c432bc172733d5b42904c5569f6 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 15:30:11 +0530 Subject: [PATCH 3/6] tests: skip the native attention backend for ring_degree instead of ulysses_degree --- tests/models/testing_utils/parallelism.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 7ef71ea5b1df..b24352e49d3b 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -160,7 +160,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attentiion_backe if getattr(self.model_class, "_cp_plan", None) is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") - if cp_type == "ulysses_degree" and attentiion_backend == "native": + if cp_type == "ring_degree" and attentiion_backend == "native": pytest.skip("Skipping test because ulysses isn't supported with native attention backend.") if ulysses_anything and "ulysses" not in cp_type: From 547f3df0a08227be073fc7aed3032d09eabf346e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 15:48:49 +0530 Subject: [PATCH 4/6] tests: cast model and inputs to bf16 for attention backends that require it --- tests/models/testing_utils/parallelism.py | 5 +++++ tests/models/testing_utils/utils.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 tests/models/testing_utils/utils.py diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index b24352e49d3b..e50b68adb766 100644 --- a/tests/models/testing_utils/parallelism.py +++ 
b/tests/models/testing_utils/parallelism.py @@ -24,6 +24,7 @@ from diffusers.models._modeling_parallel import ContextParallelConfig from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator +from .utils import _maybe_cast_to_bf16 def _find_free_port(): @@ -58,6 +59,9 @@ def _context_parallel_worker( model.to(device) model.eval() + # Cast as needed. + model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict) + # Move inputs to device inputs_on_device = {} for key, value in inputs_dict.items(): @@ -153,6 +157,7 @@ class ContextParallelAttentionBackendsTesterMixin: ], ) @pytest.mark.parametrize("ulysses_anything", [True, False]) + @torch.no_grad() def test_context_parallel_attn_backend_inference(self, cp_type, attentiion_backend, ulysses_anything): if not torch.distributed.is_available(): pytest.skip("torch.distributed is not available.") diff --git a/tests/models/testing_utils/utils.py b/tests/models/testing_utils/utils.py new file mode 100644 index 000000000000..7bec37db2496 --- /dev/null +++ b/tests/models/testing_utils/utils.py @@ -0,0 +1,22 @@ +import torch + +from diffusers.models.attention_dispatch import AttentionBackendName + + +_BF16_REQUIRED_BACKENDS = { + AttentionBackendName._NATIVE_CUDNN, + AttentionBackendName.FLASH_HUB, + AttentionBackendName._FLASH_3_HUB, +} + + +def _maybe_cast_to_bf16(backend, model, inputs_dict): + """Cast model and floating-point inputs to bfloat16 when the backend requires it.""" + if not backend or backend not in _BF16_REQUIRED_BACKENDS: + return model, inputs_dict + model = model.to(dtype=torch.bfloat16) + inputs_dict = { + k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v + for k, v in inputs_dict.items() + } + return model, inputs_dict From acfa871347ca3bab4a3e813906e81f91dce14df2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 17:07:00 +0530 Subject: [PATCH 5/6] fix ring for flash and flash_3 --- 
src/diffusers/models/attention_dispatch.py | 18 +++++++++++------- tests/models/testing_utils/parallelism.py | 11 ++++++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 834bce942e43..066f93a2a126 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1865,9 +1865,12 @@ def forward( out = out.to(torch.float32) lse = lse.to(torch.float32) - # Refer to: - # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 - if is_torch_version("<", "2.9.0"): + # lse must be 4-D to broadcast with out (B, S, H, D). + # Some backends (e.g. cuDNN on torch>=2.9) already return a + # trailing-1 dim; others (e.g. flash-hub / native-flash) always + # return 3-D lse, so we add the dim here when needed. + # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if lse.ndim == 3: lse = lse.unsqueeze(-1) if prev_out is not None: out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) @@ -2154,10 +2157,11 @@ def _templated_unified_attention( scatter_idx, ) if return_lse: - # lse is of shape (B, S, H_LOCAL, 1) - # Refer to: - # https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 - if is_torch_version("<", "2.9.0"): + # lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its + # final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add + # the trailing dim here and remove it after the collective. 
+ # See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544 + if lse.ndim == 3: lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) lse = lse.squeeze(-1) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index e50b68adb766..b2e7b92d8231 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -143,7 +143,7 @@ def test_context_parallel_inference(self, cp_type): class ContextParallelAttentionBackendsTesterMixin: @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"]) @pytest.mark.parametrize( - "attentiion_backend", + "attention_backend", [ "native", pytest.param( @@ -158,15 +158,16 @@ class ContextParallelAttentionBackendsTesterMixin: ) @pytest.mark.parametrize("ulysses_anything", [True, False]) @torch.no_grad() - def test_context_parallel_attn_backend_inference(self, cp_type, attentiion_backend, ulysses_anything): + def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything): if not torch.distributed.is_available(): pytest.skip("torch.distributed is not available.") if getattr(self.model_class, "_cp_plan", None) is None: pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") - if cp_type == "ring_degree" and attentiion_backend == "native": - pytest.skip("Skipping test because ulysses isn't supported with native attention backend.") + if cp_type == "ring_degree": + if attention_backend == "native": + pytest.skip("Skipping test because ulysses isn't supported with native attention backend.") if ulysses_anything and "ulysses" not in cp_type: pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.") @@ -199,7 +200,7 @@ def test_context_parallel_attn_backend_inference(self, cp_type, attentiion_backe cp_dict, inputs_dict, return_dict, - attentiion_backend, + attention_backend, ), 
nprocs=world_size, join=True, From ad9ac8dba6374d018a6882b6343369d314e3f36e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 24 Feb 2026 17:14:45 +0530 Subject: [PATCH 6/6] generate. --- utils/generate_model_tests.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/utils/generate_model_tests.py b/utils/generate_model_tests.py index 11acd2175e21..d27ced15afba 100644 --- a/utils/generate_model_tests.py +++ b/utils/generate_model_tests.py @@ -72,6 +72,7 @@ # Other testers ("SingleFileTesterMixin", "single_file"), ("IPAdapterTesterMixin", "ip_adapter"), + ("ContextParallelAttentionBackendsTesterMixin", "cp_attn"), ] @@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se for tester, flag in OPTIONAL_TESTERS: if flag in include_optional: - if tester not in testers: + if tester == "ContextParallelAttentionBackendsTesterMixin": + if ( + "cp_attn" in include_optional + and "_cp_plan" in model_info["attributes"] + and model_info["attributes"]["_cp_plan"] is not None + ): + testers.append(tester) + elif tester not in testers: testers.append(tester) return testers @@ -530,6 +538,7 @@ def main(): "faster_cache", "single_file", "ip_adapter", + "cp_attn", "all", ], help="Optional testers to include",