From 7e667f3d587bb29b36c175713998a0128e35cf5e Mon Sep 17 00:00:00 2001 From: Cui-yshoho Date: Wed, 19 Nov 2025 18:40:33 +0800 Subject: [PATCH 1/3] add cp --- docs/diffusers/_toctree.yml | 2 + docs/diffusers/api/parallel.md | 20 + mindone/diffusers/__init__.py | 4 + mindone/diffusers/hooks/__init__.py | 1 + mindone/diffusers/hooks/context_parallel.py | 310 +++++ mindone/diffusers/models/__init__.py | 2 + .../diffusers/models/_modeling_parallel.py | 238 ++++ mindone/diffusers/models/attention.py | 10 + .../diffusers/models/attention_dispatch.py | 1181 +++++++++++++++++ mindone/diffusers/models/layers_compat.py | 156 +++ mindone/diffusers/models/modeling_utils.py | 134 ++ .../models/transformers/transformer_bria.py | 96 +- .../models/transformers/transformer_flux.py | 37 +- .../models/transformers/transformer_ltx.py | 84 +- .../transformers/transformer_qwenimage.py | 36 +- .../transformers/transformer_skyreels_v2.py | 347 +++-- .../models/transformers/transformer_wan.py | 115 +- 17 files changed, 2472 insertions(+), 301 deletions(-) create mode 100644 docs/diffusers/api/parallel.md create mode 100644 mindone/diffusers/hooks/context_parallel.py create mode 100644 mindone/diffusers/models/_modeling_parallel.py create mode 100644 mindone/diffusers/models/attention_dispatch.py diff --git a/docs/diffusers/_toctree.yml b/docs/diffusers/_toctree.yml index 7f4c6b0feb..e54ad9e709 100644 --- a/docs/diffusers/_toctree.yml +++ b/docs/diffusers/_toctree.yml @@ -68,6 +68,8 @@ title: Accelerate inference - local: optimization/memory title: Reduce memory usage + - local: api/parallel + title: Parallel inference - title: Community optimizations sections: - local: optimization/xformers diff --git a/docs/diffusers/api/parallel.md b/docs/diffusers/api/parallel.md new file mode 100644 index 0000000000..c0707586e8 --- /dev/null +++ b/docs/diffusers/api/parallel.md @@ -0,0 +1,20 @@ + + +# Parallelism + +Parallelism strategies help speed up diffusion transformers by distributing computations across multiple devices, allowing for faster inference/training times. + +::: mindone.diffusers.ParallelConfig + +::: mindone.diffusers.ContextParallelConfig + +::: mindone.diffusers.hooks.apply_context_parallel diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 3eae32e857..a9fe5ad081 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -78,6 +78,7 @@ "CogView4Transformer2DModel", "ConsisIDTransformer3DModel", "ConsistencyDecoderVAE", + "ContextParallelConfig", "ControlNetModel", "ControlNetUnionModel", "ControlNetXSAdapter", @@ -106,6 +107,7 @@ "MultiAdapter", "MultiControlNetModel", "OmniGenTransformer2DModel", + "ParallelConfig", "PixArtTransformer2DModel", "PriorTransformer", "QwenImageTransformer2DModel", @@ -464,6 +466,7 @@ CogView4Transformer2DModel, ConsisIDTransformer3DModel, ConsistencyDecoderVAE, + ContextParallelConfig, ControlNetModel, ControlNetUnionModel, ControlNetXSAdapter, @@ -492,6 +495,7 @@ MultiAdapter, MultiControlNetModel, OmniGenTransformer2DModel, + ParallelConfig, PixArtTransformer2DModel, PriorTransformer, QwenImageTransformer2DModel, diff --git a/mindone/diffusers/hooks/__init__.py b/mindone/diffusers/hooks/__init__.py index 5b9096818a..3d1f1d6f92 100644 --- a/mindone/diffusers/hooks/__init__.py +++ b/mindone/diffusers/hooks/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
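The `docs/diffusers/api/parallel.md` page added above only lists the public entry points. A minimal, hedged sketch of how `ContextParallelConfig`, `ParallelConfig`, and `apply_context_parallel` could be wired together follows; the checkpoint id, the `"Ascend"` device string, and the `init_process_group` call are illustrative assumptions rather than part of this patch, and the transformer is assumed to define a `_cp_plan` attribute as described in `_modeling_parallel.py` below.

```python
# Hedged usage sketch (not part of the patch). Assumes Ulysses-style context parallelism
# across all launched devices and a model that ships a `_cp_plan` attribute.
from mindspore import mint

from mindone.diffusers import ContextParallelConfig, ParallelConfig, QwenImageTransformer2DModel
from mindone.diffusers.hooks import apply_context_parallel
from mindone.diffusers.models.layers_compat import DeviceMesh

mint.distributed.init_process_group()  # assumed standard mint.distributed initialization
rank = mint.distributed.get_rank()
world_size = mint.distributed.get_world_size()

cp_config = ContextParallelConfig(ulysses_degree=world_size)  # ring_degree defaults to 1
parallel_config = ParallelConfig(context_parallel_config=cp_config)

# 2D (ring, ulysses) mesh; the dim names must match what ContextParallelConfig.setup() looks up.
cp_mesh = DeviceMesh("Ascend", (1, world_size), ("ring", "ulysses"))
parallel_config.setup(rank, world_size, device="Ascend", cp_mesh=cp_mesh)

transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
apply_context_parallel(transformer, cp_config, transformer._cp_plan)
```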
+from .context_parallel import apply_context_parallel from .faster_cache import FasterCacheConfig, apply_faster_cache from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .hooks import HookRegistry, ModelHook diff --git a/mindone/diffusers/hooks/context_parallel.py b/mindone/diffusers/hooks/context_parallel.py new file mode 100644 index 0000000000..1eff865d4c --- /dev/null +++ b/mindone/diffusers/hooks/context_parallel.py @@ -0,0 +1,310 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Dict, List, Type, Union + +import mindspore as ms +from mindspore import mint + +from ..models._modeling_parallel import ( + ContextParallelConfig, + ContextParallelInput, + ContextParallelModelPlan, + ContextParallelOutput, +) +from ..utils import get_logger +from ..utils.mindspore_utils import unwrap_module +from .hooks import HookRegistry, ModelHook + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}" +_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" + + +# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata +@dataclass +class ModuleForwardMetadata: + cached_parameter_indices: Dict[str, int] = None + _cls: Type = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + + if identifier in kwargs: + return kwargs[identifier], True, None + + if self.cached_parameter_indices is not None: + index = self.cached_parameter_indices.get(identifier, None) + if index is None: + raise ValueError(f"Parameter '{identifier}' not found in cached indices.") + return args[index], False, index + + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + + parameters = list(inspect.signature(self._cls.construct).parameters.keys()) + parameters = parameters[1:] # skip `self` + self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + + if identifier not in self.cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + + index = self.cached_parameter_indices[identifier] + + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + + return args[index], False, index + + +def apply_context_parallel( + module: ms.nn.Cell, + parallel_config: ContextParallelConfig, + plan: Dict[str, ContextParallelModelPlan], +) -> None: + """Apply context parallel on a model.""" + logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}") + + for module_id, cp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules") + + for m in submodule: + if 
isinstance(cp_model_plan, dict): + hook = ContextParallelSplitHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): + if isinstance(cp_model_plan, ContextParallelOutput): + cp_model_plan = [cp_model_plan] + if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan): + raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}") + hook = ContextParallelGatherHook(cp_model_plan, parallel_config) + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") + registry = HookRegistry.check_if_exists_or_initialize(m) + registry.register_hook(hook, hook_name) + + +def remove_context_parallel(module: ms.nn.Cell, plan: Dict[str, ContextParallelModelPlan]) -> None: + for module_id, cp_model_plan in plan.items(): + submodule = _get_submodule_by_name(module, module_id) + if not isinstance(submodule, list): + submodule = [submodule] + + for m in submodule: + registry = HookRegistry.check_if_exists_or_initialize(m) + if isinstance(cp_model_plan, dict): + hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id) + elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)): + hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id) + else: + raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}") + registry.remove_hook(hook_name) + + +class ContextParallelSplitHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + self.module_forward_metadata = None + + def initialize_hook(self, module): + cls = unwrap_module(module).__class__ + self.module_forward_metadata = ModuleForwardMetadata(_cls=cls) + return module + + def pre_construct(self, module, *args, **kwargs): + args_list = list(args) + + for name, cpm in self.metadata.items(): + if isinstance(cpm, ContextParallelInput) and cpm.split_output: + continue + + # Maybe the parameter was passed as a keyword argument + input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs( + name, args_list, kwargs + ) + + if input_val is None: + continue + + # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard + # the output instead of input for a particular layer by setting split_output=True + if isinstance(input_val, ms.Tensor): + input_val = self._prepare_cp_input(input_val, cpm) + elif isinstance(input_val, (list, tuple)): + if len(input_val) != len(cpm): + raise ValueError( + f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}." + ) + sharded_input_val = [] + for i, x in enumerate(input_val): + if ms.is_tensor(x) and not cpm[i].split_output: + x = self._prepare_cp_input(x, cpm[i]) + sharded_input_val.append(x) + input_val = sharded_input_val + else: + raise ValueError(f"Unsupported input type: {type(input_val)}") + + if is_kwarg: + kwargs[name] = input_val + elif index is not None and index < len(args_list): + args_list[index] = input_val + else: + raise ValueError( + f"An unexpected error occurred while processing the input '{name}'. 
Please open an " + f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible " + f"example along with the full stack trace." + ) + + return tuple(args_list), kwargs + + def post_construct(self, module, output): + is_tensor = isinstance(output, ms.Tensor) + is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, ms.Tensor) for x in output) + + if not is_tensor and not is_tensor_list: + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = [output] if is_tensor else list(output) + for index, cpm in self.metadata.items(): + if not isinstance(cpm, ContextParallelInput) or not cpm.split_output: + continue + if index >= len(output): + raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") + current_output = output[index] + current_output = self._prepare_cp_input(current_output, cpm) + output[index] = current_output + + return output[0] if is_tensor else tuple(output) + + def _prepare_cp_input(self, x: ms.Tensor, cp_input: ContextParallelInput) -> ms.Tensor: + if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: + raise ValueError( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + ) + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + + +class ContextParallelGatherHook(ModelHook): + def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None: + super().__init__() + self.metadata = metadata + self.parallel_config = parallel_config + + def post_construct(self, module, output): + is_tensor = isinstance(output, ms.Tensor) + + if is_tensor: + output = [output] + elif not (isinstance(output, (list, tuple)) and all(isinstance(x, ms.Tensor) for x in output)): + raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.") + + output = list(output) + + if len(output) != len(self.metadata): + raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.") + + for i, cpm in enumerate(self.metadata): + if cpm is None: + continue + output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh) + + return output[0] if is_tensor else tuple(output) + + +class AllGatherFunction(ms.nn.Cell): + def __init__(self, dim, group): + super().__init__() + self.dim = dim + self.group = group + self.world_size = mint.distributed.get_world_size(group) + self.rank = mint.distributed.get_rank(group) + + def construct(self, tensor): + # return funcol.all_gather_tensor(tensor, dim, group=group) + # mint.distributed.all_gather_into_tensor only support dim=0 + tensor_t = tensor.transpose(self.dim, 0) if self.dim != 0 else tensor + + out_shape = list(tensor_t.shape) + out_shape[0] *= self.world_size + output = mint.zeros(out_shape, dtype=tensor_t.dtype) + + mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=self.group) + + if self.dim != 0: + output = output.transpose(0, self.dim) + + return output + + def bprop(self, tensor, out, dout): + grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim) + return (grad_chunks[self.rank],) + + +class EquipartitionSharder: + @classmethod + def shard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor: + # NOTE: the following assertion does not have to be true in general. 
We simply enforce it for now + # because the alternate case has not yet been tested/required for any model. + assert ( + tensor.shape[dim] % mint.distributed.get_world_size(mesh) == 0 + ), "Tensor size along dimension to be sharded must be divisible by mesh size" + + # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank) + # return tensor.chunk(mint.distributed.get_world_size(mesh), dim=dim)[mesh.get_rank()] + + return tensor.chunk(mint.distributed.get_world_size(mesh), dim=dim)[mint.distributed.get_rank(mesh)] + + @classmethod + def unshard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor: + tensor = tensor.contiguous() + tensor = AllGatherFunction(dim, mesh)(tensor) + return tensor + + +def _get_submodule_by_name(model: ms.nn.Cell, name: str) -> Union[ms.nn.Cell, List[ms.nn.Cell]]: + if name.count("*") > 1: + raise ValueError("Wildcard '*' can only be used once in the name") + return _find_submodule_by_name(model, name) + + +def _find_submodule_by_name(model: ms.nn.Cell, name: str) -> Union[ms.nn.Cell, List[ms.nn.Cell]]: + if name == "": + return model + first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "") + if first_atom == "*": + if not isinstance(model, ms.nn.CellList): + raise ValueError("Wildcard '*' can only be used with ModuleList") + submodules = [] + for submodule in model: + subsubmodules = _find_submodule_by_name(submodule, remaining_name) + if not isinstance(subsubmodules, list): + subsubmodules = [subsubmodules] + submodules.extend(subsubmodules) + return submodules + else: + if hasattr(model, first_atom): + submodule = getattr(model, first_atom) + return _find_submodule_by_name(submodule, remaining_name) + else: + raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'") diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index cbb81ccf22..f859f450c6 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -20,6 +20,7 @@ from ..utils import _LazyModule _import_structure = { + "_modeling_parallel": ["ContextParallelConfig", "ParallelConfig"], "adapter": ["MultiAdapter", "T2IAdapter"], "auto_model": ["AutoModel"], "autoencoders.autoencoder_asym_kl": ["AsymmetricAutoencoderKL"], @@ -104,6 +105,7 @@ } if TYPE_CHECKING: + from ._modeling_parallel import ContextParallelConfig, ParallelConfig from .adapter import MultiAdapter, T2IAdapter from .auto_model import AutoModel from .autoencoders import ( diff --git a/mindone/diffusers/models/_modeling_parallel.py b/mindone/diffusers/models/_modeling_parallel.py new file mode 100644 index 0000000000..a4edc00d26 --- /dev/null +++ b/mindone/diffusers/models/_modeling_parallel.py @@ -0,0 +1,238 @@ +# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨 +# Experimental changes are subject to change and APIs may break without warning. + +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
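`EquipartitionSharder` above is the only sharding strategy the hooks use: `shard` keeps the rank-local chunk along `dim`, and `unshard` all-gathers the chunks back in rank order. A single-process sketch of that contract (rank and world size are stand-ins; no collectives are issued):

```python
# Illustration only: mimics EquipartitionSharder.shard / unshard without a process group.
from mindspore import mint

world_size, rank, dim = 4, 1, 1
x = mint.arange(2 * 8 * 3).reshape(2, 8, 3)           # dim=1 has size 8, divisible by world_size

local = mint.chunk(x, world_size, dim=dim)[rank]      # what `shard` keeps on this rank
chunks = [mint.chunk(x, world_size, dim=dim)[r] for r in range(world_size)]
assert (mint.cat(chunks, dim=dim) == x).all()         # what the all-gather in `unshard` reassembles
```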
+ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union + +from mindspore import mint + +from ..utils import get_logger + +if TYPE_CHECKING: + pass + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +# TODO(aryan): add support for the following: +# - Unified Attention +# - More dispatcher attention backends +# - CFG/Data Parallel +# - Tensor Parallel + + +@dataclass +class ContextParallelConfig: + """ + Configuration for context parallelism. + + Args: + ring_degree (`int`, *optional*, defaults to `1`): + Number of devices to use for ring attention within a context parallel region. Must be a divisor of the + total number of devices in the context parallel mesh. + ulysses_degree (`int`, *optional*, defaults to `1`): + Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the + total number of devices in the context parallel mesh. + convert_to_fp32 (`bool`, *optional*, defaults to `True`): + Whether to convert output and LSE to float32 for ring attention numerical stability. + rotate_method (`str`, *optional*, defaults to `"allgather"`): + Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"` + is supported. + + """ + + ring_degree: Optional[int] = None + ulysses_degree: Optional[int] = None + convert_to_fp32: bool = True + # TODO: support alltoall + rotate_method: Literal["allgather", "alltoall"] = "allgather" + + _rank: int = None + _world_size: int = None + _device: str = None + _mesh: dict = None + _flattened_mesh: str = None + _ring_mesh: str = None + _ulysses_mesh: str = None + _ring_local_rank: int = None + _ulysses_local_rank: int = None + + def __post_init__(self): + if self.ring_degree is None: + self.ring_degree = 1 + if self.ulysses_degree is None: + self.ulysses_degree = 1 + + def setup(self, rank: int, world_size: int, device, mesh): + self._rank = rank + self._world_size = world_size + self._device = device + self._mesh = mesh + if self.ring_degree is None: + self.ring_degree = 1 + if self.ulysses_degree is None: + self.ulysses_degree = 1 + if self.rotate_method != "allgather": + raise NotImplementedError( + f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." + ) + if self._flattened_mesh is None: + self._flattened_mesh = self._mesh._flatten() + if self._ring_mesh is None: + self._ring_mesh = self._mesh["ring"] + if self._ulysses_mesh is None: + self._ulysses_mesh = self._mesh["ulysses"] + if self._ring_local_rank is None: + self._ring_local_rank = mint.distributed.get_rank(self._ring_mesh) + if self._ulysses_local_rank is None: + self._ulysses_local_rank = mint.distributed.get_rank(self._ulysses_mesh) + + +@dataclass +class ParallelConfig: + """ + Configuration for applying different parallelisms. + + Args: + context_parallel_config (`ContextParallelConfig`, *optional*): + Configuration for context parallelism. 
+ """ + + context_parallel_config: Optional[ContextParallelConfig] = None + + _rank: int = None + _world_size: int = None + _device: str = None + _cp_mesh: dict = None + + def setup( + self, + rank: int, + world_size: int, + device: str, + *, + cp_mesh: Optional[dict] = None, + ): + self._rank = rank + self._world_size = world_size + self._device = device + self._cp_mesh = cp_mesh + if self.context_parallel_config is not None: + self.context_parallel_config.setup(rank, world_size, device, cp_mesh) + + +@dataclass(frozen=True) +class ContextParallelInput: + """ + Configuration for splitting an input tensor across context parallel region. + + Args: + split_dim (`int`): + The dimension along which to split the tensor. + expected_dims (`int`, *optional*): + The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the + tensor has the expected number of dimensions before splitting. + split_output (`bool`, *optional*, defaults to `False`): + Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor. + This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex: + RoPE). + """ + + split_dim: int + expected_dims: Optional[int] = None + split_output: bool = False + + def __repr__(self): + return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})" + + +@dataclass(frozen=True) +class ContextParallelOutput: + """ + Configuration for gathering an output tensor across context parallel region. + + Args: + gather_dim (`int`): + The dimension along which to gather the tensor. + expected_dims (`int`, *optional*): + The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the + tensor has the expected number of dimensions before gathering. + """ + + gather_dim: int + expected_dims: Optional[int] = None + + def __repr__(self): + return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})" + + +# A dictionary where keys denote the input to be split across context parallel region, and the +# value denotes the sharding configuration. +# If the key is a string, it denotes the name of the parameter in the forward function. +# If the key is an integer, split_output must be set to True, and it denotes the index of the output +# to be split across context parallel region. +ContextParallelInputType = Dict[ + Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]] +] + +# A dictionary where keys denote the output to be gathered across context parallel region, and the +# value denotes the gathering configuration. +ContextParallelOutputType = Union[ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]] + +# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of +# the module should be split/gathered across context parallel region. 
+ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]] + + +# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel): +# +# Each model should define a _cp_plan attribute that contains information on how to shard/gather +# tensors at different stages of the forward: +# +# ```python +# _cp_plan = { +# "": { +# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), +# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), +# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), +# }, +# "pos_embed": { +# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), +# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), +# }, +# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), +# } +# ``` +# +# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be +# split/gathered according to this at the respective module level. Here, the following happens: +# - "": +# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before +# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs) +# - "pos_embed": +# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs), +# we can individually specify how they should be split +# - "proj_out": +# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear +# layer forward has run). +# +# ContextParallelInput: +# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to +# +# ContextParallelOutput: +# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to diff --git a/mindone/diffusers/models/attention.py b/mindone/diffusers/models/attention.py index 8f4a7e7bec..72916544a3 100644 --- a/mindone/diffusers/models/attention.py +++ b/mindone/diffusers/models/attention.py @@ -150,6 +150,16 @@ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProce if not return_deprecated_lora: return self.processor + def set_attention_backend(self, backend: str): + from .attention_dispatch import AttentionBackendName + + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend.lower()) + self.processor._attention_backend = backend + def fuse_projections(self): """ Fuse the query, key, and value projections into a single projection for efficiency. diff --git a/mindone/diffusers/models/attention_dispatch.py b/mindone/diffusers/models/attention_dispatch.py new file mode 100644 index 0000000000..b7a598b2e1 --- /dev/null +++ b/mindone/diffusers/models/attention_dispatch.py @@ -0,0 +1,1181 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union + +import mindspore as ms +from mindspore import mint, nn, ops + +from ..utils import get_logger +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS +from .layers_compat import scaled_dot_product_attention + +if TYPE_CHECKING: + from ._modeling_parallel import ParallelConfig + +_CAN_USE_FLASH_ATTN = True +_CAN_USE_FLASH_ATTN_3 = False +_CAN_USE_SAGE_ATTN = False +_CAN_USE_FLEX_ATTN = False +_CAN_USE_NPU_ATTN = False +_CAN_USE_XLA_ATTN = False +_CAN_USE_XFORMERS_ATTN = False + + +if _CAN_USE_FLASH_ATTN_3: + raise RuntimeError("Flash Attention 3 is not usable.") +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + flash_attn_3_func_hub = None + + +if _CAN_USE_SAGE_ATTN: + raise RuntimeError("Sage Attention is not usable.") +else: + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if _CAN_USE_FLEX_ATTN: + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. + raise RuntimeError("Flex Attention is not usable.") +else: + flex_attention = None + + +if _CAN_USE_NPU_ATTN: + raise RuntimeError("NPU Fusion Attention is not usable.") +else: + npu_fusion_attention = None + + +if _CAN_USE_XLA_ATTN: + raise RuntimeError("XLA Attention is not usable.") +else: + xla_flash_attention = None + + +if _CAN_USE_XFORMERS_ATTN: + raise RuntimeError("Xformers Attention is not usable.") +else: + xops = None + + +def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + + return wrap if fn is None else fn + + +def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + + return wrap if fn is None else fn + + +_custom_op = custom_op_no_op +_register_fake = register_fake_no_op + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +# TODO(aryan): Add support for the following: +# - Sage Attention++ +# - block sparse, radial and other attention methods +# - CP with sage attention, flex, xformers, other missing backends +# - Add support for normal and CP training with backends that don't support it yet + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionBackendName(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" + _FLASH_3_HUB = "_flash_3_hub" + _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet. 
+ + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionBackendRegistry: + _backends = {} + _constraints = {} + _supported_arg_names = {} + _supports_context_parallel = {} + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register( + cls, + backend: AttentionBackendName, + constraints: Optional[List[Callable]] = None, + supports_context_parallel: bool = False, + ): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") + + def decorator(func): + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + cls._supports_context_parallel[backend] = supports_context_parallel + return func + + return decorator + + @classmethod + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] + + @classmethod + def list_backends(cls): + return list(cls._backends.keys()) + + @classmethod + def _is_context_parallel_enabled( + cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] + ) -> bool: + supports_context_parallel = backend in cls._supports_context_parallel + is_degree_greater_than_1 = parallel_config is not None and ( + parallel_config.context_parallel_config.ring_degree > 1 + or parallel_config.context_parallel_config.ulysses_degree > 1 + ) + return supports_context_parallel and is_degree_greater_than_1 + + +@contextlib.contextmanager +def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE): + """ + Context manager to set the active attention backend. 
+ """ + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") + + backend = AttentionBackendName(backend) + _check_attention_backend_requirements(backend) + + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend + + try: + yield + finally: + _AttentionBackendRegistry._active_backend = old_backend + + +def dispatch_attention_fn( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, + parallel_config: Optional["ParallelConfig"] = None, +) -> ms.Tensor: + attention_kwargs = attention_kwargs or {} + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + + if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled( + backend_name, parallel_config + ): + raise ValueError( + f"Backend {backend_name} either does not support context parallelism or context parallelism " + f"was enabled with a world size of 1." + ) + + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + **attention_kwargs, + "_parallel_config": parallel_config, + } + kwargs["enable_gqa"] = enable_gqa + + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) + + +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. 
+ + +def _check_attn_mask_or_causal(attn_mask: Optional[ms.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_qkv_dtype_match(query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (ms.bfloat16, ms.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +# ===== Helper functions ===== + + +def _check_attention_backend_requirements(backend: AttentionBackendName) -> None: + if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]: + if not _CAN_USE_FLASH_ATTN: + raise RuntimeError(f"Flash Attention backend '{backend.value}' is not usable.") + + elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: + if not _CAN_USE_FLASH_ATTN_3: + raise RuntimeError(f"Flash Attention 3 backend '{backend.value}' is not usable.") + + # TODO: add support Hub variant of FA3 varlen later + elif backend in [AttentionBackendName._FLASH_3_HUB]: + raise RuntimeError(f"Flash Attention 3 Hub backend '{backend.value}' is not usable.") + + elif backend in [ + AttentionBackendName.SAGE, + AttentionBackendName.SAGE_VARLEN, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + ]: + if not _CAN_USE_SAGE_ATTN: + raise RuntimeError(f"Sage Attention backend '{backend.value}' is not usable.") + + elif backend == AttentionBackendName.FLEX: + if not _CAN_USE_FLEX_ATTN: + raise RuntimeError(f"Flex Attention backend '{backend.value}' is not usable.") + + elif backend == AttentionBackendName._NATIVE_NPU: + if not _CAN_USE_NPU_ATTN: + raise RuntimeError(f"NPU Attention backend '{backend.value}' is not usable.") + + elif backend == AttentionBackendName._NATIVE_XLA: + if not _CAN_USE_XLA_ATTN: + raise RuntimeError(f"XLA Attention backend '{backend.value}' is not usable.") + + elif backend == AttentionBackendName.XFORMERS: + if not _CAN_USE_XFORMERS_ATTN: + raise RuntimeError(f"Xformers Attention backend '{backend.value}' is not usable.") + + +@functools.lru_cache(maxsize=128) +def 
_prepare_for_flash_attn_or_sage_varlen_without_mask( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, +): + seqlens_q = mint.full((batch_size,), seq_len_q, dtype=ms.int32) + seqlens_k = mint.full((batch_size,), seq_len_kv, dtype=ms.int32) + cu_seqlens_q = mint.zeros(batch_size + 1, dtype=ms.int32) + cu_seqlens_k = mint.zeros(batch_size + 1, dtype=ms.int32) + cu_seqlens_q[1:] = mint.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = mint.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen_with_mask( + batch_size: int, + seq_len_q: int, + attn_mask: ms.Tensor, +): + seqlens_q = mint.full((batch_size,), seq_len_q, dtype=ms.int32) + seqlens_k = attn_mask.sum(dim=1, dtype=ms.int32) + cu_seqlens_q = mint.zeros(batch_size + 1, dtype=ms.int32) + cu_seqlens_k = mint.zeros(batch_size + 1, dtype=ms.int32) + cu_seqlens_q[1:] = mint.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = mint.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[ms.Tensor] = None, +) -> None: + if attn_mask is None: + return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv) + return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask) + + +def _normalize_attn_mask(attn_mask: ms.Tensor, batch_size: int, seq_len_k: int) -> ms.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. + """ + if attn_mask.dtype != ms.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand((batch_size, seq_len_k)) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.shape[0] not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand((batch_size, seq_len_k)) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. + if attn_mask.shape[0] not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." + ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand((batch_size, seq_len_k)) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.shape[0] not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." 
+ ) + attn_mask = attn_mask.expand((batch_size, -1, -1, seq_len_k)) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== Helper functions to use attention backends with templated CP autograd functions ===== + + +class NativeAttentionCell(nn.Cell): + def __init__(self): + super().__init__() + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, + ): + # Native attention does not return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + def bprop( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx, + _parallel_config, + out, + dout, + ): + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + def forward_fn(q, k, v): + out = scaled_dot_product_attention( + query=q, + key=k, + value=v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2))(query_t, key_t, value_t) + + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 +class FlashAttentionCell(nn.Cell): + def __init__(self): + super().__init__() + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, + ): + # Hardcoded for now + grad_enabled = any(x._requires_grad for x in (query, key, value)) + + if scale is None: + scale = query.shape[-1] ** (-0.5) + + if is_causal: + sparse_mode = 2 + else: + sparse_mode = 0 + + # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. 
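+        # The tiny 1e-30 dropout below presumably keeps keep_prob < 1 so that FlashAttentionScore
+        # also emits softmax_max/softmax_sum; the log-sum-exp needed by ring attention is then
+        # reconstructed further down as lse = softmax_max + log(softmax_sum).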
+ if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + input_layout = "BSND" + head_num = query.shape[2] + + softmax_max, softmax_sum, _, out = ops.operations.nn_ops.FlashAttentionScore( + head_num=head_num, + keep_prob=1 - dropout_p, + scale_value=scale, + input_layout=input_layout, + sparse_mode=sparse_mode, + )(query, key, value, None, None, None, attn_mask) + lse = softmax_max[..., 0] + mint.log(softmax_sum[..., 0]) + lse = lse.permute(0, 2, 1) + + return (out, lse) if return_lse else out + + def bprop( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx, + _parallel_config, + out, + dout, + ): + grad_query, grad_key, grad_value = mint.empty_like(query), mint.empty_like(key), mint.empty_like(value) + + # Head dimension may have been padded + grad_query = grad_query[..., : dout.shape[-1]] + grad_key = grad_key[..., : dout.shape[-1]] + grad_value = grad_value[..., : dout.shape[-1]] + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +# ===== Context parallel ===== + + +def _all_to_all_single(x: ms.Tensor, group) -> ms.Tensor: + shape = x.shape + # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization + # to benchmark triton codegen fails somewhere: + # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3') + # ValueError: Tensors must be contiguous + x = x.flatten() + # `all_to_all_single` writes the result into output in-place. + x_output = mint.zeros_like(x) + mint.distributed.all_to_all_single(x_output, x, group=group) + x_output = x_output.reshape(shape) + return x_output + + +def permute_tensor( + tensor: ms.Tensor, + src_dst: list[int], + group: None, +) -> ms.Tensor: + """ + Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should + be defined such that src_dst[m] == n means m sends to n. + + Group can be one of: + List[int]: ranks participating in the collective. + List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD. + ProcessGroup: Will perform a collective using the ranks and tag of the PG. 
+ DeviceMesh: Do a SPMD collective over all ranks of the mesh + (DeviceMesh, int): Do a MPMD collective over one + """ + rank = mint.distributed.get_rank(group) + world_size = mint.distributed.get_world_size(group) + + output_split_sizes = [0] * world_size + input_split_sizes = [0] * world_size + + dst = src_dst[rank] + input_split_sizes[dst] = tensor.size + + for m, n in enumerate(src_dst): + if n == rank: + output_split_sizes[m] = tensor.size + + output = mint.zeros_like(tensor) + + mint.distributed.all_to_all_single( + output, tensor, input_split_sizes=input_split_sizes, output_split_sizes=output_split_sizes, group=group + ) + + return output + + +class TemplatedRingAttention(nn.Cell): + def __init__(self): + super().__init__() + self.forward_op = None + self.backward_op = None + self.q_shape = None + self.kv_shape = None + self._parallel_config = None + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + rank = _parallel_config.context_parallel_config._ring_local_rank + world_size = _parallel_config.context_parallel_config.ring_degree + next_rank = (rank + 1) % world_size + prev_out = prev_lse = None + + self.forward_op = forward_op + self.backward_op = backward_op + self.q_shape = query.shape + self.kv_shape = key.shape + self._parallel_config = _parallel_config + + kv_buffer = mint.cat([key.flatten(), value.flatten()]).contiguous() + group_size = mint.distributed.get_world_size(ring_mesh) + kv_buffer_output = mint.cat([mint.zeros_like(kv_buffer) for _ in range(group_size)], dim=0) + # `all_gather_into_tensor` performs in-place all-gather into kv_buffer_output. 
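+        # Rather than rotating K/V between ring neighbours step by step, every rank's K/V is
+        # all-gathered once up front and the loop below walks the ring order locally; the partial
+        # outputs are then combined with the numerically stable log-sum-exp update (the
+        # sigmoid/logsigmoid expressions further down).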
+ _ = mint.distributed.all_gather_into_tensor(kv_buffer_output, kv_buffer, group=ring_mesh) + kv_buffer = kv_buffer_output.chunk(world_size) + + for i in range(world_size): + if i > 0: + kv = kv_buffer[next_rank] + key_numel = key.numel() + key = kv[:key_numel].reshape_as(key) + value = kv[key_numel:].reshape_as(value) + next_rank = (next_rank + 1) % world_size + + out, lse = forward_op( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + True, + ) + + if _parallel_config.context_parallel_config.convert_to_fp32: + out = out.to(ms.float32) + lse = lse.to(ms.float32) + + lse = lse.unsqueeze(-1) + if prev_out is not None: + out = prev_out - mint.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out) + lse = prev_lse - mint.nn.functional.logsigmoid(prev_lse - lse) + prev_out = out + prev_lse = lse + + out = out.to(query.dtype) + lse = lse.squeeze(-1) + + return (out, lse) if return_lse else out + + def bprop( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + out, + dout, + ): + ring_mesh = self._parallel_config.context_parallel_config._ring_mesh + rank = self._parallel_config.context_parallel_config._ring_local_rank + world_size = self._parallel_config.context_parallel_config.ring_degree + next_rank = (rank + 1) % world_size + next_ranks = list(range(1, world_size)) + [0] + + accum_dtype = ms.float32 if self._parallel_config.context_parallel_config.convert_to_fp32 else dout.dtype + grad_query = mint.zeros(self.q_shape, dtype=accum_dtype) + grad_key = mint.zeros(self.kv_shape, dtype=accum_dtype) + grad_value = mint.zeros(self.kv_shape, dtype=accum_dtype) + next_grad_kv = None + + kv_buffer = mint.cat([key.flatten(), value.flatten()]).contiguous() + group_size = mint.distributed.get_world_size(ring_mesh) + kv_buffer_output = mint.cat([mint.zeros_like(kv_buffer) for _ in range(group_size)], dim=0) + _ = mint.distributed.all_gather_into_tensor(kv_buffer_output, kv_buffer, group=ring_mesh) + kv_buffer = kv_buffer_output.chunk(world_size) + + for i in range(world_size): + if i > 0: + kv = kv_buffer[next_rank] + key_numel = key.numel() + key = kv[:key_numel].reshape_as(key) + value = kv[key_numel:].reshape_as(value) + next_rank = (next_rank + 1) % world_size + + grad_query_op, grad_key_op, grad_value_op, *_ = ( + ms.grad(self.forward_op, grad_position=(0, 1, 2))( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + True, + ), + ) + + if i > 0: + grad_kv_buffer = next_grad_kv + grad_key_numel = grad_key.numel() + grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key) + grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value) + + grad_query += grad_query_op + grad_key += grad_key_op + grad_value += grad_value_op + + if i < world_size - 1: + grad_kv_buffer = mint.cat([grad_key.flatten(), grad_value.flatten()]).contiguous() + next_grad_kv = permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh) + + grad_query, grad_key, grad_value = (x.to(dout.dtype) for x in (grad_query, grad_key, grad_value)) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +class TemplatedUlyssesAttention(nn.Cell): + def __init__(self): + super().__init__() + self.forward_op = None + self.backward_op = None + self.q_shape = None + self.kv_shape = None + self._parallel_config = None + + def construct( + self, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: 
Optional[ms.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + world_size = _parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh + + self.forward_op = forward_op + self.backward_op = backward_op + self._parallel_config = _parallel_config + + B, S_Q_LOCAL, H, D = query.shape + _, S_KV_LOCAL, _, _ = key.shape + H_LOCAL = H // world_size + query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + query, key, value = (_all_to_all_single(x, group) for x in (query, key, value)) + query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) + + out = forward_op( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=True, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse, *_ = out + + out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + out = _all_to_all_single(out, group) + out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous() + + if return_lse: + lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous() + lse = _all_to_all_single(lse, group) + lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous() + else: + lse = None + + return (out, lse) if return_lse else out + + def bprop( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + out, + dout, + ): + ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh + world_size = self._parallel_config.context_parallel_config.ulysses_degree + group = ulysses_mesh + + B, S_LOCAL, H, D = dout.shape + H_LOCAL = H // world_size + + grad_out = dout.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + grad_out = _all_to_all_single(grad_out, group) + grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous() + + # grad_query_op, grad_key_op, grad_value_op, *_ = self.backward_op(self, grad_out) + grad_query_op, grad_key_op, grad_value_op, *_ = ( + ms.grad(self.forward_op, grad_position=(0, 1, 2))( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + _save_ctx=True, + _parallel_config=_parallel_config, + ), + ) + + grad_query, grad_key, grad_value = ( + x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() + for x in (grad_query_op, grad_key_op, grad_value_op) + ) + grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value)) + grad_query, grad_key, grad_value = ( + x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value) + ) + + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + + +def _templated_context_parallel_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: 
bool = False, + *, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("Attention mask is not yet supported for templated attention.") + if is_causal: + raise ValueError("Causal attention is not yet supported for templated attention.") + if enable_gqa: + raise ValueError("GQA is not yet supported for templated attention.") + + # TODO: add support for unified attention with ring/ulysses degree both being > 1 + if _parallel_config.context_parallel_config.ring_degree > 1: + return TemplatedRingAttention()( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ulysses_degree > 1: + return TemplatedUlyssesAttention()( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + else: + raise ValueError("Reaching this branch of code is unexpected. Please report a bug.") + + +# ===== Attention backends ===== + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], + supports_context_parallel=True, +) +def _flash_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> ms.Tensor: + lse = None + if _parallel_config is None: + out = FlashAttentionCell().construct( + query=query, + key=key, + value=value, + dropout_p=dropout_p, + scale=scale, + is_causal=is_causal, + return_lse=return_lse, + ) + if return_lse: + out, lse = out + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + is_causal, + scale, + False, + return_lse, + forward_op=FlashAttentionCell().construct, + backward_op=FlashAttentionCell().bprop, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse = out + + return (out, lse) if return_lse else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, + constraints=[_check_device, _check_shape], + supports_context_parallel=True, +) +def _native_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> ms.Tensor: + if return_lse: + raise ValueError("Native attention backend does not support setting `return_lse=True`.") + if _parallel_config is None: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=NativeAttentionCell().construct, + backward_op=NativeAttentionCell().bprop, + _parallel_config=_parallel_config, + ) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, + constraints=[_check_device, 
_check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> ms.Tensor: + if return_lse: + raise ValueError("Native flash attention backend does not support setting `return_lse=True`.") + if enable_gqa: + raise ValueError("Native flash attention backend does not support setting `enable_gqa=True`.") + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, # not supported + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _parallel_config: Optional["ParallelConfig"] = None, +) -> ms.Tensor: + if return_lse: + raise ValueError("Native math attention backend does not support setting `return_lse=True`.") + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + backend="math", + ) + out = out.permute(0, 2, 1, 3) + return out diff --git a/mindone/diffusers/models/layers_compat.py b/mindone/diffusers/models/layers_compat.py index f41672aad4..bd0d30df28 100644 --- a/mindone/diffusers/models/layers_compat.py +++ b/mindone/diffusers/models/layers_compat.py @@ -24,6 +24,9 @@ - **unflatten**: Always custom due to framework limitations. [2025/10/22] - **RMSNorm**: Always custom due to framework limitations. + [2025/11/12] + - **scaled_dot_product_attention**: Always custom due to framework limitations. + - *DeviceMesh*: Always custom due to framework limitations. Example: Import this module and use the operators as you would with native MindSpore functions, with the assurance of cross-version compatibility. @@ -36,6 +39,7 @@ - ... """ +import math import numbers from typing import Optional, Union @@ -63,6 +67,8 @@ "view_as_complex", "unflatten", "RMSNorm", + "scaled_dot_product_attention", + "DeviceMesh", ] MINDSPORE_VERSION = parse(ms.__version__) @@ -615,3 +621,153 @@ def extra_repr(self) -> str: Return the extra representation of the module. 
""" return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) + + +# ================================================================================ +# DeviceMesh +# ================================================================================ +class DeviceMesh: + def __init__(self, device_type, mesh_shape, mesh_dim_names): + self.device_type = device_type + self.mesh_shape = mesh_shape + self.mesh_dim_names = mesh_dim_names + + dim0, dim1 = mesh_shape + + self.mesh = [[r + i * dim1 for r in range(dim1)] for i in range(dim0)] + self.groups = {} + + current_rank = mint.distributed.get_rank() + col_groups = [[self.mesh[r][c] for r in range(dim0)] for c in range(dim1)] + + self.groups[mesh_dim_names[0]] = next( + (mint.distributed.new_group(ranks=r) for r in col_groups if current_rank in r), None + ) + self.groups[mesh_dim_names[1]] = next( + (mint.distributed.new_group(ranks=r) for r in self.mesh if current_rank in r), None + ) + + def __getitem__(self, dim_name): + return self.groups[dim_name] + + def _flatten(self): + flat_ranks = [rank for row in self.mesh for rank in row] + return mint.distributed.new_group(ranks=flat_ranks) + + +# ================================================================================ +# scaled_dot_product_attention +# ================================================================================ +def scaled_dot_product_attention( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: Optional[bool] = False, + backend: Optional[str] = "flash", +): + head_dim = query.shape[-1] + + # Note: PyTorch's SDPA and MindSpore's FA handle `attention_mask` slightly differently. + # In PyTorch, if the mask is not boolean (e.g., float32 with 0/1 values), it is interpreted + # as an additive bias: `attn_bias = attn_mask + attn_bias`. + # This implicit branch may lead to issues if the pipeline mistakenly provides + # a 0/1 float mask instead of a boolean mask. + # While this behavior is consistent with HF Diffusers for now, + # it may still be a potential bug source worth validating. 
+ if ( + (attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask) + or head_dim > 512 + or backend == "math" + or enable_gqa + ): + out = math_attention_op(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa) + + if query.dtype in (ms.float16, ms.bfloat16): + out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale) + else: + out = flash_attention_op( + query.to(ms.float16), + key.to(ms.float16), + value.to(ms.float16), + attn_mask, + keep_prob=1 - dropout_p, + scale=scale, + ).to(query.dtype) + return out + + +def math_attention_op( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +): + L, S = query.shape[-2], key.shape[-2] + scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale + attn_bias = mint.zeros((L, S), dtype=query.dtype) + if is_causal: + if attn_mask is not None: + if attn_mask.dtype == ms.bool_: + attn_mask = mint.logical_and(attn_mask, mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0)) + else: + attn_mask = attn_mask + mint.triu(mint.full((L, S), float("-inf"), dtype=attn_mask.dtype), diagonal=1) + else: + temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias = attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == ms.bool_: + attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.shape[-3] // key.shape[-3], -3) + value = value.repeat_interleave(query.shape[-3] // value.shape[-3], -3) + + attn_weight = mint.matmul(query, key.swapaxes(-2, -1)) * scale_factor + attn_weight += attn_bias + attn_weight = mint.softmax(attn_weight, dim=-1) + attn_weight = ops.dropout(attn_weight, dropout_p, training=True) + return mint.matmul(attn_weight, value) + + +def flash_attention_op( + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + keep_prob: float = 1.0, + scale: Optional[float] = None, +): + # For most scenarios, qkv has been processed into a BNSD layout before sdp + input_layout = "BNSD" + head_num = query.shape[1] + if scale is None: + scale = query.shape[-1] ** (-0.5) + + # In case qkv is 3-dim after `head_to_batch_dim` + if query.ndim == 3: + input_layout = "BSH" + head_num = 1 + + # process `attn_mask` as logic is different between PyTorch and Mindspore + # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite + if attn_mask is not None: + attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() + attn_mask = mint.broadcast_to( + attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) + )[:, :1, :, :] + + return ops.operations.nn_ops.FlashAttentionScore( + head_num=head_num, keep_prob=keep_prob, scale_value=scale, input_layout=input_layout + )(query, key, value, None, None, None, attn_mask)[3] diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py index fabe98c449..0e64249a51 100644 --- a/mindone/diffusers/models/modeling_utils.py +++ b/mindone/diffusers/models/modeling_utils.py @@ -53,6 +53,8 @@ logging, ) from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card +from 
._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig +from .layers_compat import DeviceMesh from .model_loading_utils import ( _fetch_index_file, _fetch_index_file_legacy, @@ -181,6 +183,8 @@ class ModelMixin(nn.Cell, PushToHubMixin): _skip_layerwise_casting_patterns = None _supports_group_offloading = True _repeated_blocks = [] + _parallel_config = None + _cp_plan = None def __init__(self): super().__init__() @@ -436,6 +440,65 @@ def enable_group_offload( """ raise NotImplementedError("`enable_group_offload` is not yet supported.") + def set_attention_backend(self, backend: str) -> None: + """ + Set the attention backend for the model. + + Args: + backend (`str`): + The name of the backend to set. Must be one of the available backends defined in + `AttentionBackendName`. Available backends can be found in + `mindone.diffusers.attention_dispatch.AttentionBackendName`. Defaults to mindone native scaled dot product + attention as backend. + """ + from .attention import AttentionModuleMixin + from .attention_dispatch import ( # _maybe_download_kernel_for_backend, + AttentionBackendName, + _check_attention_backend_requirements, + ) + + # TODO: the following will not be required when everything is refactored to AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + + backend = backend.lower() + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend) + _check_attention_backend_requirements(backend) + # _maybe_download_kernel_for_backend(backend) + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for _, module in self.cells_and_names(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = backend + + def reset_attention_backend(self) -> None: + """ + Resets the attention backend for the model. Following calls to `forward` will use the environment default, if + set, or the mindone native scaled dot product attention. 
+ """ + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + logger.warning("Attention backends are an experimental feature and the API may be subject to change.") + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for _, module in self.cells_and_names(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = None + def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -678,6 +741,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_safetensors = kwargs.pop("use_safetensors", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) + parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None) is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING if is_parallel_loading_enabled: @@ -902,6 +966,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Set model in evaluation mode to deactivate DropOut modules by default model.set_train(False) + if parallel_config is not None: + model.enable_parallelism(config=parallel_config) + if output_loading_info: return model, loading_info @@ -955,6 +1022,73 @@ def compile_repeated_blocks(self, *args, **kwargs): f"Regional compilation failed because {repeated_blocks} classes are not found in the model. " ) + def enable_parallelism( + self, + *, + config: Union[ParallelConfig, ContextParallelConfig], + cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None, + ): + from ..hooks.context_parallel import apply_context_parallel + from .attention import AttentionModuleMixin + from .attention_processor import Attention, MochiAttention + + logger.warning( + "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning." # noqa + ) + + if isinstance(config, ContextParallelConfig): + config = ParallelConfig(context_parallel_config=config) + + if not mint.distributed.is_initialized(): + raise RuntimeError("mint.distributed must be initialized before calling `enable_parallelism`.") + + rank = mint.distributed.get_rank() + world_size = mint.distributed.get_world_size() + device_type = "Ascend" + rank = mint.distributed.get_rank() + device = ms.get_current_device() + + cp_mesh = None + if config.context_parallel_config is not None: + cp_config = config.context_parallel_config + if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1: + raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") + if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1: + raise ValueError( + "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + ) + if cp_config.ring_degree * cp_config.ulysses_degree > world_size: + raise ValueError( + f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})." 
# noqa + ) + cp_mesh = DeviceMesh( + device_type=device_type, + mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree), + mesh_dim_names=("ring", "ulysses"), + ) + + config.setup(rank, world_size, device, cp_mesh=cp_mesh) + + if cp_plan is None and self._cp_plan is None: + raise ValueError( + "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute." + ) + cp_plan = cp_plan if cp_plan is not None else self._cp_plan + + if config.context_parallel_config is not None: + apply_context_parallel(self, config.context_parallel_config, cp_plan) + + self._parallel_config = config + + attention_classes = (Attention, MochiAttention, AttentionModuleMixin) + for _, module in self.cells_and_names(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_parallel_config"): + continue + processor._parallel_config = config + @classmethod def _load_pretrained_model( cls, diff --git a/mindone/diffusers/models/transformers/transformer_bria.py b/mindone/diffusers/models/transformers/transformer_bria.py index 86393381a2..340237dc2b 100644 --- a/mindone/diffusers/models/transformers/transformer_bria.py +++ b/mindone/diffusers/models/transformers/transformer_bria.py @@ -1,16 +1,16 @@ -import math from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import mindspore as ms import mindspore.nn as nn -from mindspore import mint, ops +from mindspore import mint from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding from ..layers_compat import unflatten @@ -152,7 +152,14 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - hidden_states = attn.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -234,89 +241,6 @@ def __init__( processor = self._default_processor_cls() self.set_processor(processor) - def scaled_dot_product_attention( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - ): - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - # Note: PyTorch's SDPA and MindSpore's FA handle `attention_mask` slightly differently. - # In PyTorch, if the mask is not boolean (e.g., float32 with 0/1 values), it is interpreted - # as an additive bias: `attn_bias = attn_mask + attn_bias`. - # This implicit branch may lead to issues if the pipeline mistakenly provides - # a 0/1 float mask instead of a boolean mask. - # While this behavior is consistent with HF Diffusers for now, - # it may still be a potential bug source worth validating. 
- if attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask: - L, S = query.shape[-2], key.shape[-2] - scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale - attn_bias = mint.zeros((L, S), dtype=query.dtype) - if is_causal: - assert attn_mask is None - temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0) - attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf")) - attn_bias.to(query.dtype) - - if attn_mask is not None: - if attn_mask.dtype == ms.bool_: - attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf")) - else: - attn_bias = attn_mask + attn_bias - - attn_weight = mint.matmul(query, key.swapaxes(-2, -1)) * scale_factor - attn_weight += attn_bias - attn_weight = mint.softmax(attn_weight, dim=-1) - attn_weight = ops.dropout(attn_weight, dropout_p, training=True) - return mint.matmul(attn_weight, value).permute(0, 2, 1, 3) - - if query.dtype in (ms.float16, ms.bfloat16): - out = self.flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale) - else: - out = self.flash_attention_op( - query.to(ms.float16), - key.to(ms.float16), - value.to(ms.float16), - attn_mask, - keep_prob=1 - dropout_p, - scale=scale, - ).to(query.dtype) - return out.permute(0, 2, 1, 3) - - def flash_attention_op( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - keep_prob: float = 1.0, - scale: Optional[float] = None, - ): - # For most scenarios, qkv has been processed into a BNSD layout before sdp - input_layout = "BNSD" - head_num = query.shape[1] - - # In case qkv is 3-dim after `head_to_batch_dim` - if query.ndim == 3: - input_layout = "BSH" - head_num = 1 - - # process `attn_mask` as logic is different between PyTorch and Mindspore - # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite - if attn_mask is not None: - attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() - attn_mask = mint.broadcast_to( - attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) - )[:, :1, :, :] - - return ops.operations.nn_ops.FlashAttentionScore( - head_num=head_num, keep_prob=keep_prob, scale_value=scale or self.scale, input_layout=input_layout - )(query, key, value, None, None, None, attn_mask)[3] - def construct( self, hidden_states: ms.Tensor, diff --git a/mindone/diffusers/models/transformers/transformer_flux.py b/mindone/diffusers/models/transformers/transformer_flux.py index 7acbff94a1..e5221d6028 100644 --- a/mindone/diffusers/models/transformers/transformer_flux.py +++ b/mindone/diffusers/models/transformers/transformer_flux.py @@ -24,7 +24,9 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( CombinedTimestepGuidanceTextProjEmbeddings, @@ -70,8 +72,10 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st return _get_projections(attn, hidden_states, encoder_hidden_states) -@ms.jit_class class FluxAttnProcessor: + _attention_backend = None + _parallel_config = None + def __call__( self, attn: 
"FluxAttention", @@ -107,7 +111,15 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - hidden_states = attn.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + # hidden_states = attn.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -131,6 +143,9 @@ def __call__( class FluxIPAdapterAttnProcessor(nn.Cell): """Flux Attention processor for IP-Adapter.""" + _attention_backend = None + _parallel_config = None + def __init__(self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, dtype=None): super().__init__() @@ -193,14 +208,17 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - hidden_states = attn.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) + hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -228,13 +246,15 @@ def __call__( ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) - current_ip_hidden_states = attn.scaled_dot_product_attention( + current_ip_hidden_states = dispatch_attention_fn( ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim) current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) @@ -645,6 +665,15 @@ class FluxTransformer2DModel( _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/mindone/diffusers/models/transformers/transformer_ltx.py b/mindone/diffusers/models/transformers/transformer_ltx.py index 265df3cb4a..d1517d56ca 100644 --- a/mindone/diffusers/models/transformers/transformer_ltx.py +++ b/mindone/diffusers/models/transformers/transformer_ltx.py @@ -21,12 +21,14 @@ from typing import Any, Dict, Optional, Tuple, Union import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import mint, nn from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import deprecate, logging +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import 
AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection from ..layers_compat import unflatten @@ -51,6 +53,9 @@ class LTXVideoAttnProcessor: model. It applies a normalization layer and rotary embedding on the query and key vector. """ + _attention_backend = None + _parallel_config = None + def __call__( self, attn: "LTXAttention", @@ -85,8 +90,15 @@ def __call__( key = unflatten(key, 2, (attn.heads, -1)) value = unflatten(value, 2, (attn.heads, -1)) - hidden_states = attn.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.to(query.dtype) @@ -143,60 +155,6 @@ def __init__( processor = self._default_processor_cls() self.set_processor(processor) - def scaled_dot_product_attention( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - ): - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - if query.dtype in (ms.float16, ms.bfloat16): - out = self.flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale) - else: - out = self.flash_attention_op( - query.to(ms.float16), - key.to(ms.float16), - value.to(ms.float16), - attn_mask, - keep_prob=1 - dropout_p, - scale=scale, - ).to(query.dtype) - return out.permute(0, 2, 1, 3) - - def flash_attention_op( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - keep_prob: float = 1.0, - scale: Optional[float] = None, - ): - # For most scenarios, qkv has been processed into a BNSD layout before sdp - input_layout = "BNSD" - head_num = query.shape[1] - - # In case qkv is 3-dim after `head_to_batch_dim` - if query.ndim == 3: - input_layout = "BSH" - head_num = 1 - - # process `attn_mask` as logic is different between PyTorch and Mindspore - # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite - if attn_mask is not None: - attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() - attn_mask = mint.broadcast_to( - attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) - )[:, :1, :, :] - - return ops.operations.nn_ops.FlashAttentionScore( - head_num=head_num, keep_prob=keep_prob, scale_value=scale or self.scale, input_layout=input_layout - )(query, key, value, None, None, None, attn_mask)[3] - def construct( self, hidden_states: ms.Tensor, @@ -447,6 +405,18 @@ class LTXVideoTransformer3DModel( _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTXVideoTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + 1: 
ContextParallelInput(split_dim=1, expected_dims=3, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/mindone/diffusers/models/transformers/transformer_qwenimage.py b/mindone/diffusers/models/transformers/transformer_qwenimage.py index db0b47de8a..7c05c700cb 100644 --- a/mindone/diffusers/models/transformers/transformer_qwenimage.py +++ b/mindone/diffusers/models/transformers/transformer_qwenimage.py @@ -26,8 +26,11 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging -from ..attention import FeedForward +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..embeddings import TimestepEmbedding, Timesteps from ..layers_compat import unflatten, view_as_complex from ..modeling_outputs import Transformer2DModelOutput @@ -246,6 +249,7 @@ class QwenDoubleStreamAttnProcessor2_0: """ _attention_backend = None + _parallel_config = None def __call__( self, @@ -305,12 +309,16 @@ def __call__( joint_value = mint.cat([txt_value, img_value], dim=1) # Compute joint attention - # TODO: function dispatch_attention_fn.py - joint_query, joint_key, joint_value = (x.permute(0, 2, 1, 3) for x in (joint_query, joint_key, joint_value)) - joint_hidden_states = attn.scaled_dot_product_attention( - joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + joint_hidden_states = dispatch_attention_fn( + joint_query, + joint_key, + joint_value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3) # Reshape back joint_hidden_states = joint_hidden_states.flatten(2, 3) @@ -446,7 +454,9 @@ def construct( return encoder_hidden_states, hidden_states -class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class QwenImageTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): """ The Transformer model introduced in Qwen. 
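
The `_cp_plan` attributes added across these transformers are consumed by `ModelMixin.enable_parallelism`. A rough usage sketch, assuming a two-NPU job launched with a distributed launcher such as `msrun`; the checkpoint id and dtype are illustrative only:

```python
# Sketch: Ulysses context parallelism on a model that ships a `_cp_plan`.
import mindspore as ms
from mindspore import mint
from mindone.diffusers import ContextParallelConfig, QwenImageTransformer2DModel

mint.distributed.init_process_group()  # communication must be up before enable_parallelism

transformer = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", mindspore_dtype=ms.bfloat16
)
# Split the sequence dimension across 2 ranks via all-to-all (Ulysses) attention.
transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```

The same `ContextParallelConfig` can instead be passed as `parallel_config=` to `from_pretrained`, which calls `enable_parallelism` once loading finishes.
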
@@ -476,6 +486,18 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro _no_split_modules = ["QwenImageTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] _repeated_blocks = ["QwenImageTransformerBlock"] + _cp_plan = { + "": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + "pos_embed": { + 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + } @register_to_config def __init__( diff --git a/mindone/diffusers/models/transformers/transformer_skyreels_v2.py b/mindone/diffusers/models/transformers/transformer_skyreels_v2.py index 1b7c0b1aeb..2298fa4c1a 100644 --- a/mindone/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/mindone/diffusers/models/transformers/transformer_skyreels_v2.py @@ -19,13 +19,14 @@ from typing import Any, Dict, Optional, Tuple, Union import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import mint, nn +from mindspore.common.initializer import Zero from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import logging -from ..attention import FeedForward -from ..attention_processor import Attention +from ...utils import deprecate, logging +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import ( PixArtAlphaTextProjection, @@ -33,7 +34,7 @@ get_1d_rotary_pos_embed, get_1d_sincos_pos_embed_from_grid, ) -from ..layers_compat import unflatten, view_as_complex +from ..layers_compat import RMSNorm, unflatten from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -41,10 +42,42 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class SkyReelsV2AttnProcessor2_0: +def _get_qkv_projections(attn: "SkyReelsV2Attention", hidden_states: ms.Tensor, encoder_hidden_states: ms.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if attn.cross_attention_dim_head is None: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states_img: ms.Tensor): + if attn.fused_projections: + key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1) + else: + key_img = attn.add_k_proj(encoder_hidden_states_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + return key_img, value_img + + 
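+# Follows the same pattern as `WanAttnProcessor`: projections come from the fused/unfused
+# helpers above, and attention is routed through `dispatch_attention_fn`, which selects the
+# configured backend and applies context parallelism when `_parallel_config` is set.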
+class SkyReelsV2AttnProcessor: + _attention_backend = None + _parallel_config = None + def __call__( self, - attn: Attention, + attn: "SkyReelsV2Attention", hidden_states: ms.Tensor, encoder_hidden_states: Optional[ms.Tensor] = None, attention_mask: Optional[ms.Tensor] = None, @@ -56,28 +89,30 @@ def __call__( image_context_length = encoder_hidden_states.shape[1] - 512 encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) + query = attn.norm_q(query) + key = attn.norm_k(key) - query = unflatten(query, 2, (attn.heads, -1)).swapaxes(1, 2) - key = unflatten(key, 2, (attn.heads, -1)).swapaxes(1, 2) - value = unflatten(value, 2, (attn.heads, -1)).swapaxes(1, 2) + query = unflatten(query, 2, (attn.heads, -1)) + key = unflatten(key, 2, (attn.heads, -1)) + value = unflatten(value, 2, (attn.heads, -1)) if rotary_emb is not None: - def apply_rotary_emb(hidden_states: ms.Tensor, freqs: ms.Tensor): - x_rotated = view_as_complex(unflatten(hidden_states.to(ms.float32), 3, (-1, 2))) - x_out = ops.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) + def apply_rotary_emb( + hidden_states: ms.Tensor, + freqs_cos: ms.Tensor, + freqs_sin: ms.Tensor, + ): + x1, x2 = unflatten(hidden_states, -1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = mint.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) query = apply_rotary_emb(query, rotary_emb) key = apply_rotary_emb(key, rotary_emb) @@ -85,29 +120,37 @@ def apply_rotary_emb(hidden_states: ms.Tensor, freqs: ms.Tensor): # I2V task hidden_states_img = None if encoder_hidden_states_img is not None: - key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img) key_img = attn.norm_added_k(key_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) - - key_img = unflatten(key_img, 2, (attn.heads, -1)).swapaxes(1, 2) - value_img = unflatten(value_img, 2, (attn.heads, -1)).swapaxes(1, 2) - hidden_states_img = attn.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + key_img = unflatten(key_img, 2, (attn.heads, -1)) + value_img = unflatten(value_img, 2, (attn.heads, -1)) + + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - hidden_states_img = hidden_states_img.swapaxes(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = attn.scaled_dot_product_attention( + hidden_states = dispatch_attention_fn( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) - hidden_states = hidden_states.swapaxes(1, 2).flatten(2, 3) + hidden_states = hidden_states.flatten(2, 3) 
hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: @@ -118,7 +161,127 @@ def apply_rotary_emb(hidden_states: ms.Tensor, freqs: ms.Tensor): return hidden_states -# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding +class SkyReelsV2AttnProcessor2_0: + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "The SkyReelsV2AttnProcessor2_0 class is deprecated and will be removed in a future version. " + "Please use SkyReelsV2AttnProcessor instead. " + ) + deprecate("SkyReelsV2AttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + return SkyReelsV2AttnProcessor(*args, **kwargs) + + +class SkyReelsV2Attention(ms.nn.Cell, AttentionModuleMixin): + _default_processor_cls = SkyReelsV2AttnProcessor + _available_processors = [SkyReelsV2AttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = mint.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = mint.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = mint.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = ms.nn.CellList( + [ + mint.nn.Linear(self.inner_dim, dim, bias=True), + mint.nn.Dropout(dropout), + ] + ) + self.norm_q = RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = mint.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = mint.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if self.cross_attention_dim_head is None: + concatenated_weights = mint.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = mint.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + # with torch.device("meta"): + # self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv = mint.nn.Linear(in_features, out_features, bias=True, weight_init=Zero(), bias_init=Zero()) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = mint.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = mint.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + # with torch.device("meta"): + # self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv = mint.nn.Linear(in_features, out_features, bias=True, weight_init=Zero(), bias_init=Zero()) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, 
assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = mint.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = mint.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + # with torch.device("meta"): + # self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv = mint.nn.Linear( + in_features, out_features, bias=True, weight_init=Zero(), bias_init=Zero() + ) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @ms._no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + delattr(self, "to_added_kv") + + self.fused_projections = False + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + rotary_emb: Optional[Tuple[ms.Tensor, ms.Tensor]] = None, + **kwargs, + ) -> ms.Tensor: + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs) + + class SkyReelsV2ImageEmbedding(ms.nn.Cell): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -209,7 +372,11 @@ def construct( class SkyReelsV2RotaryPosEmbed(nn.Cell): def __init__( - self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, ): super().__init__() @@ -219,37 +386,52 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + freqs_dtype = ms.float64 + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + + freqs_cos = [] + freqs_sin = [] - freqs = [] for dim in [t_dim, h_dim, w_dim]: - freq = get_1d_rotary_pos_embed( - dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=ms.float32 + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, ) - freqs.append(freq) - self.freqs = mint.cat(freqs, dim=1) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", mint.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", mint.cat(freqs_sin, dim=1), persistent=False) def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - freqs = self.freqs - # freqs = freqs.split_with_sizes() - freqs = mint.split( - freqs, - [ - self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), - self.attention_head_dim // 6, - self.attention_head_dim // 6, - ], - dim=1, - ) + split_sizes = [self.t_dim, self.h_dim, self.w_dim] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) - freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).tile((1, pph, ppw, 1)) - freqs_h = freqs[1][:pph].view(1, pph, 1, -1).tile((ppf, 1, ppw, 1)) - freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).tile((ppf, pph, 1, 
1)) - freqs = mint.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) - return freqs + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1)) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1)) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1)) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1)) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1)) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1)) + + freqs_cos = mint.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = mint.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin class SkyReelsV2TransformerBlock(nn.Cell): @@ -267,33 +449,24 @@ def __init__( # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = SkyReelsV2Attention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=SkyReelsV2AttnProcessor2_0(), + cross_attention_dim_head=None, + processor=SkyReelsV2AttnProcessor(), ) # 2. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = SkyReelsV2Attention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=SkyReelsV2AttnProcessor2_0(), + cross_attention_dim_head=dim // num_heads, + processor=SkyReelsV2AttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else mint.nn.Identity() @@ -320,13 +493,15 @@ def construct( # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim) e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask) + attn_output = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -335,10 +510,13 @@ def construct( ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + return hidden_states -class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): +class SkyReelsV2Transformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): r""" A Transformer model for video-like data used in the Wan-based SkyReels-V2 model. 
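
The rotary-embedding rework above replaces the complex-valued frequency table (`view_as_complex`) with separate cos/sin tables applied to interleaved even/odd channels. The two formulations compute the same rotation; a small NumPy check of that identity, with illustrative values only:

```python
# Verify that the interleaved real rotation used in `apply_rotary_emb` matches
# complex multiplication by e^{i*theta} on (even, odd) channel pairs.
import numpy as np

rng = np.random.default_rng(0)
dim = 8                                  # hypothetical head-dim slice
x = rng.standard_normal(dim)
theta = rng.standard_normal(dim // 2)

# previous formulation: view channel pairs as complex numbers and rotate
ref = (x[0::2] + 1j * x[1::2]) * np.exp(1j * theta)

# new formulation: real cos/sin applied to even/odd channels
cos, sin = np.cos(theta), np.sin(theta)
out = np.empty_like(x)
out[0::2] = x[0::2] * cos - x[1::2] * sin
out[1::2] = x[0::2] * sin + x[1::2] * cos

assert np.allclose(out[0::2], ref.real) and np.allclose(out[1::2], ref.imag)
```
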
@@ -386,11 +564,12 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _no_split_modules = ["SkyReelsV2TransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["SkyReelsV2TransformerBlock"] @register_to_config def __init__( self, - patch_size: Tuple[int] = (1, 2, 2), + patch_size: Tuple[int, ...] = (1, 2, 2), num_attention_heads: int = 16, attention_head_dim: int = 128, in_channels: int = 16, @@ -463,11 +642,17 @@ def construct( return_dict: bool = False, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ms.Tensor, Dict[str, ms.Tensor]]: - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + # weight the lora layers by setting `lora_scale` for each PEFT layer here + # and remove `lora_scale` from each PEFT layer at the end. + # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) batch_size, num_channels, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config["patch_size"] diff --git a/mindone/diffusers/models/transformers/transformer_wan.py b/mindone/diffusers/models/transformers/transformer_wan.py index 24d1673c99..9aa9bad5e5 100644 --- a/mindone/diffusers/models/transformers/transformer_wan.py +++ b/mindone/diffusers/models/transformers/transformer_wan.py @@ -19,13 +19,15 @@ from typing import Any, Dict, Optional, Tuple, Union import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import mint, nn from mindspore.common.initializer import Zero from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import deprecate, logging +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..layers_compat import unflatten @@ -66,6 +68,9 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: m class WanAttnProcessor: + _attention_backend = None + _parallel_config = None + def __call__( self, attn: "WanAttention", @@ -120,14 +125,28 @@ def apply_rotary_emb( key_img = unflatten(key_img, 2, (attn.heads, -1)) value_img = unflatten(value_img, 2, (attn.heads, -1)) - hidden_states_img = attn.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) 
hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = attn.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + parallel_config=self._parallel_config, ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -173,7 +192,6 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.cross_attention_dim_head = cross_attention_dim_head self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads - self.scale = dim_head**-0.5 self.to_q = mint.nn.Linear(dim, self.inner_dim, bias=True) self.to_k = mint.nn.Linear(dim, self.kv_inner_dim, bias=True) @@ -197,60 +215,6 @@ def __init__( self.set_processor(processor) - def scaled_dot_product_attention( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - ): - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - if query.dtype in (ms.float16, ms.bfloat16): - out = self.flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale) - else: - out = self.flash_attention_op( - query.to(ms.float16), - key.to(ms.float16), - value.to(ms.float16), - attn_mask, - keep_prob=1 - dropout_p, - scale=scale, - ).to(query.dtype) - return out.permute(0, 2, 1, 3) - - def flash_attention_op( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - keep_prob: float = 1.0, - scale: Optional[float] = None, - ): - # For most scenarios, qkv has been processed into a BNSD layout before sdp - input_layout = "BNSD" - head_num = query.shape[1] - - # In case qkv is 3-dim after `head_to_batch_dim` - if query.ndim == 3: - input_layout = "BSH" - head_num = 1 - - # process `attn_mask` as logic is different between PyTorch and Mindspore - # In MindSpore, False indicates retention and True indicates discard, in PyTorch it is the opposite - if attn_mask is not None: - attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool() - attn_mask = mint.broadcast_to( - attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2]) - )[:, :1, :, :] - - return ops.operations.nn_ops.FlashAttentionScore( - head_num=head_num, keep_prob=keep_prob, scale_value=scale or self.scale, input_layout=input_layout - )(query, key, value, None, None, None, attn_mask)[3] - def fuse_projections(self): if getattr(self, "fused_projections", False): return @@ -291,6 +255,7 @@ def fuse_projections(self): self.fused_projections = True + @ms._no_grad() def unfuse_projections(self): if not getattr(self, "fused_projections", False): return @@ -403,10 +368,16 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + freqs_dtype = ms.float64 freqs_cos = [] freqs_sin = [] + for dim in [t_dim, h_dim, w_dim]: freq_cos, freq_sin = get_1d_rotary_pos_embed( dim, @@ -427,11 +398,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ - 
self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, - ] + split_sizes = [self.t_dim, self.h_dim, self.w_dim] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) @@ -591,11 +558,27 @@ class WanTransformer3DModel( _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] _repeated_blocks = ["WanTransformerBlock"] + _cp_plan = { + "rope": { + 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True), + }, + "blocks.0": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*": { + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), + "": { + "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), + }, + } @register_to_config def __init__( self, - patch_size: Tuple[int] = (1, 2, 2), + patch_size: Tuple[int, ...] = (1, 2, 2), num_attention_heads: int = 40, attention_head_dim: int = 128, in_channels: int = 16, From f3471dd099820b00f71ccd7453586054cb0a5a60 Mon Sep 17 00:00:00 2001 From: Cui-yshoho Date: Thu, 27 Nov 2025 15:31:54 +0800 Subject: [PATCH 2/3] modify bprop to _Function --- mindone/diffusers/hooks/context_parallel.py | 43 +- .../diffusers/models/attention_dispatch.py | 469 ++++++++---------- mindone/diffusers/models/layers_compat.py | 2 +- .../models/transformers/transformer_wan.py | 2 +- 4 files changed, 240 insertions(+), 276 deletions(-) diff --git a/mindone/diffusers/hooks/context_parallel.py b/mindone/diffusers/hooks/context_parallel.py index 1eff865d4c..a117fb2a86 100644 --- a/mindone/diffusers/hooks/context_parallel.py +++ b/mindone/diffusers/hooks/context_parallel.py @@ -199,10 +199,12 @@ def post_construct(self, module, output): def _prepare_cp_input(self, x: ms.Tensor, cp_input: ContextParallelInput) -> ms.Tensor: if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: - raise ValueError( - f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions." + logger.warning_once( + f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied." 
) - return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + return x + else: + return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) class ContextParallelGatherHook(ModelHook): @@ -232,33 +234,32 @@ def post_construct(self, module, output): return output[0] if is_tensor else tuple(output) -class AllGatherFunction(ms.nn.Cell): - def __init__(self, dim, group): - super().__init__() - self.dim = dim - self.group = group - self.world_size = mint.distributed.get_world_size(group) - self.rank = mint.distributed.get_rank(group) +class AllGatherFunction(ms.common._Function): + @staticmethod + def forward(ctx, tensor, dim, group): + ctx.dim = dim + ctx.group = group + ctx.world_size = mint.distributed.get_world_size(group) + ctx.rank = mint.distributed.get_rank(group) - def construct(self, tensor): - # return funcol.all_gather_tensor(tensor, dim, group=group) # mint.distributed.all_gather_into_tensor only support dim=0 - tensor_t = tensor.transpose(self.dim, 0) if self.dim != 0 else tensor + tensor_t = tensor.transpose(dim, 0) if dim != 0 else tensor out_shape = list(tensor_t.shape) - out_shape[0] *= self.world_size + out_shape[0] *= ctx.world_size output = mint.zeros(out_shape, dtype=tensor_t.dtype) - mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=self.group) + mint.distributed.all_gather_into_tensor(output, tensor_t.contiguous(), group=group) - if self.dim != 0: - output = output.transpose(0, self.dim) + if dim != 0: + output = output.transpose(0, dim) return output - def bprop(self, tensor, out, dout): - grad_chunks = mint.chunk(dout, self.world_size, dim=self.dim) - return (grad_chunks[self.rank],) + @staticmethod + def backward(ctx, grad_output): + grad_chunks = mint.chunk(grad_output, ctx.world_size, dim=ctx.dim) + return grad_chunks[ctx.rank], None, None class EquipartitionSharder: @@ -278,7 +279,7 @@ def shard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor: @classmethod def unshard(cls, tensor: ms.Tensor, dim: int, mesh) -> ms.Tensor: tensor = tensor.contiguous() - tensor = AllGatherFunction(dim, mesh)(tensor) + tensor = AllGatherFunction.apply(tensor, dim, mesh) return tensor diff --git a/mindone/diffusers/models/attention_dispatch.py b/mindone/diffusers/models/attention_dispatch.py index b7a598b2e1..6fcff4f04c 100644 --- a/mindone/diffusers/models/attention_dispatch.py +++ b/mindone/diffusers/models/attention_dispatch.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union import mindspore as ms -from mindspore import mint, nn, ops +from mindspore import mint, ops from ..utils import get_logger from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS @@ -463,157 +463,167 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): # ===== Helper functions to use attention backends with templated CP autograd functions ===== -class NativeAttentionCell(nn.Cell): - def __init__(self): - super().__init__() +def _native_attention_forward_op( + ctx, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # Native attention does not return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") 
- def construct( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, - return_lse: bool = False, - _save_ctx: bool = True, - _parallel_config: Optional["ParallelConfig"] = None, - ): - # Native attention does not return_lse - if return_lse: - raise ValueError("Native attention does not support return_lse=True") + # used for backward pass + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.attn_mask = attn_mask + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.enable_gqa = enable_gqa - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + +def _native_attention_backward_op( + ctx, + grad_out: ms.Tensor, + *args, + **kwargs, +): + query, key, value = ctx.saved_tensors + + query._requires_grad = True + key._requires_grad = True + value._requires_grad = True + + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + def forward_fn(q, k, v): out = scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, ) out = out.permute(0, 2, 1, 3) - return out - def bprop( - self, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - _save_ctx, - _parallel_config, - out, - dout, - ): - query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - - def forward_fn(q, k, v): - out = scaled_dot_product_attention( - query=q, - key=k, - value=v, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) - return out - - grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2))(query_t, key_t, value_t) + grad_out_t = grad_out.permute(0, 2, 1, 3) # noqa + grad_query_t, grad_key_t, grad_value_t = ms.grad(forward_fn, grad_position=(0, 1, 2))(query_t, key_t, value_t) - grad_query = grad_query_t.permute(0, 2, 1, 3) - grad_key = grad_key_t.permute(0, 2, 1, 3) - grad_value = grad_value_t.permute(0, 2, 1, 3) + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value # Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807 -class FlashAttentionCell(nn.Cell): - def __init__(self): - super().__init__() - - def construct( - self, - query: ms.Tensor, - key: ms.Tensor, - value: ms.Tensor, - attn_mask: Optional[ms.Tensor] = None, - dropout_p: float = 0.0, - is_causal: bool = False, - scale: Optional[float] = None, - enable_gqa: bool = False, - return_lse: bool = False, - _save_ctx: bool = 
True, - _parallel_config: Optional["ParallelConfig"] = None, - ): - # Hardcoded for now - grad_enabled = any(x._requires_grad for x in (query, key, value)) - - if scale is None: - scale = query.shape[-1] ** (-0.5) +def _flash_attention_forward_op( + ctx, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # if attn_mask is not None: + # raise ValueError("`attn_mask` is not yet supported for flash-attn.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn.") - if is_causal: - sparse_mode = 2 - else: - sparse_mode = 0 + # Hardcoded for now + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + grad_enabled = any(x._requires_grad for x in (query, key, value)) - # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. - if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): - dropout_p = dropout_p if dropout_p > 0 else 1e-30 + if scale is None: + scale = query.shape[-1] ** (-0.5) - input_layout = "BSND" - head_num = query.shape[2] + if is_causal: + sparse_mode = 2 + else: + sparse_mode = 0 + + # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround. + if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + input_layout = "BSND" + head_num = query.shape[2] + + softmax_max, softmax_sum, _, out = ops.operations.nn_ops.FlashAttentionScore( + head_num=head_num, + keep_prob=1 - dropout_p, + scale_value=scale, + input_layout=input_layout, + sparse_mode=sparse_mode, + )(query, key, value, None, None, None, attn_mask) + lse = softmax_max[..., 0] + mint.log(softmax_sum[..., 0]) + lse = lse.permute(0, 2, 1) + + if _save_ctx: + ctx.save_for_backward(query, key, value, out, lse) + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic - softmax_max, softmax_sum, _, out = ops.operations.nn_ops.FlashAttentionScore( - head_num=head_num, - keep_prob=1 - dropout_p, - scale_value=scale, - input_layout=input_layout, - sparse_mode=sparse_mode, - )(query, key, value, None, None, None, attn_mask) - lse = softmax_max[..., 0] + mint.log(softmax_sum[..., 0]) - lse = lse.permute(0, 2, 1) + return (out, lse) if return_lse else out - return (out, lse) if return_lse else out - def bprop( - self, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - _save_ctx, - _parallel_config, - out, - dout, - ): - grad_query, grad_key, grad_value = mint.empty_like(query), mint.empty_like(key), mint.empty_like(value) +def _flash_attention_backward_op( + ctx, + grad_out: ms.Tensor, + *args, + **kwargs, +): + query, key, value, out, lse = ctx.saved_tensors + grad_query, grad_key, grad_value = mint.empty_like(query), mint.empty_like(key), mint.empty_like(value) - # Head dimension may have been padded - grad_query = grad_query[..., : dout.shape[-1]] - grad_key = grad_key[..., : dout.shape[-1]] - grad_value = grad_value[..., : dout.shape[-1]] + # Head dimension may have been padded + grad_query = grad_query[..., 
: grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value # ===== Context parallel ===== @@ -671,17 +681,10 @@ def permute_tensor( return output -class TemplatedRingAttention(nn.Cell): - def __init__(self): - super().__init__() - self.forward_op = None - self.backward_op = None - self.q_shape = None - self.kv_shape = None - self._parallel_config = None - - def construct( - self, +class TemplatedRingAttention(ms.common._Function): + @staticmethod + def forward( + ctx, query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, @@ -701,11 +704,11 @@ def construct( next_rank = (rank + 1) % world_size prev_out = prev_lse = None - self.forward_op = forward_op - self.backward_op = backward_op - self.q_shape = query.shape - self.kv_shape = key.shape - self._parallel_config = _parallel_config + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx.q_shape = query.shape + ctx.kv_shape = key.shape + ctx._parallel_config = _parallel_config kv_buffer = mint.cat([key.flatten(), value.flatten()]).contiguous() group_size = mint.distributed.get_world_size(ring_mesh) @@ -723,6 +726,7 @@ def construct( next_rank = (next_rank + 1) % world_size out, lse = forward_op( + ctx, query, key, value, @@ -732,6 +736,8 @@ def construct( scale, enable_gqa, True, + _save_ctx=i == 0, + _parallel_config=_parallel_config, ) if _parallel_config.context_parallel_config.convert_to_fp32: @@ -750,35 +756,25 @@ def construct( return (out, lse) if return_lse else out - def bprop( - self, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - out, - dout, + @staticmethod + def backward( + ctx, + grad_out: ms.Tensor, + *args, ): - ring_mesh = self._parallel_config.context_parallel_config._ring_mesh - rank = self._parallel_config.context_parallel_config._ring_local_rank - world_size = self._parallel_config.context_parallel_config.ring_degree + ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh + rank = ctx._parallel_config.context_parallel_config._ring_local_rank + world_size = ctx._parallel_config.context_parallel_config.ring_degree next_rank = (rank + 1) % world_size next_ranks = list(range(1, world_size)) + [0] - accum_dtype = ms.float32 if self._parallel_config.context_parallel_config.convert_to_fp32 else dout.dtype - grad_query = mint.zeros(self.q_shape, dtype=accum_dtype) - grad_key = mint.zeros(self.kv_shape, dtype=accum_dtype) - grad_value = mint.zeros(self.kv_shape, dtype=accum_dtype) + accum_dtype = ms.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype + grad_query = mint.zeros(ctx.q_shape, dtype=accum_dtype) + grad_key = mint.zeros(ctx.kv_shape, dtype=accum_dtype) + grad_value = mint.zeros(ctx.kv_shape, dtype=accum_dtype) next_grad_kv = None + query, key, value, *_ = ctx.saved_tensors kv_buffer = mint.cat([key.flatten(), value.flatten()]).contiguous() group_size = mint.distributed.get_world_size(ring_mesh) kv_buffer_output = mint.cat([mint.zeros_like(kv_buffer) for _ in range(group_size)], dim=0) @@ -793,19 +789,7 @@ def bprop( value = kv[key_numel:].reshape_as(value) next_rank = (next_rank + 1) % world_size - grad_query_op, grad_key_op, grad_value_op, *_ = ( - ms.grad(self.forward_op, grad_position=(0, 1, 2))( - query, - key, - value, - attn_mask, - dropout_p, - 
is_causal, - scale, - enable_gqa, - True, - ), - ) + grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) if i > 0: grad_kv_buffer = next_grad_kv @@ -821,22 +805,15 @@ def bprop( grad_kv_buffer = mint.cat([grad_key.flatten(), grad_value.flatten()]).contiguous() next_grad_kv = permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh) - grad_query, grad_key, grad_value = (x.to(dout.dtype) for x in (grad_query, grad_key, grad_value)) + grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None -class TemplatedUlyssesAttention(nn.Cell): - def __init__(self): - super().__init__() - self.forward_op = None - self.backward_op = None - self.q_shape = None - self.kv_shape = None - self._parallel_config = None - - def construct( - self, +class TemplatedUlyssesAttention(ms.common._Function): + @staticmethod + def forward( + ctx, query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, @@ -854,9 +831,9 @@ def construct( world_size = _parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh - self.forward_op = forward_op - self.backward_op = backward_op - self._parallel_config = _parallel_config + ctx.forward_op = forward_op + ctx.backward_op = backward_op + ctx._parallel_config = _parallel_config B, S_Q_LOCAL, H, D = query.shape _, S_KV_LOCAL, _, _ = key.shape @@ -868,6 +845,7 @@ def construct( query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value)) out = forward_op( + ctx, query, key, value, @@ -896,50 +874,24 @@ def construct( return (out, lse) if return_lse else out - def bprop( - self, - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - out, - dout, + @staticmethod + def backward( + ctx, + grad_out: ms.Tensor, + *args, ): - ulysses_mesh = self._parallel_config.context_parallel_config._ulysses_mesh - world_size = self._parallel_config.context_parallel_config.ulysses_degree + ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh + world_size = ctx._parallel_config.context_parallel_config.ulysses_degree group = ulysses_mesh - B, S_LOCAL, H, D = dout.shape + B, S_LOCAL, H, D = grad_out.shape H_LOCAL = H // world_size - grad_out = dout.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() + grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() grad_out = _all_to_all_single(grad_out, group) grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous() - # grad_query_op, grad_key_op, grad_value_op, *_ = self.backward_op(self, grad_out) - grad_query_op, grad_key_op, grad_value_op, *_ = ( - ms.grad(self.forward_op, grad_position=(0, 1, 2))( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - _save_ctx=True, - _parallel_config=_parallel_config, - ), - ) + grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out) grad_query, grad_key, grad_value = ( x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous() @@ -977,7 +929,7 @@ def _templated_context_parallel_attention( # TODO: add support for unified attention with ring/ulysses degree both being > 1 if _parallel_config.context_parallel_config.ring_degree > 1: - return TemplatedRingAttention()( + return TemplatedRingAttention.apply( query, key, value, @@ -992,7 
+944,7 @@ def _templated_context_parallel_attention( _parallel_config, ) elif _parallel_config.context_parallel_config.ulysses_degree > 1: - return TemplatedUlyssesAttention()( + return TemplatedUlyssesAttention.apply( query, key, value, @@ -1022,6 +974,7 @@ def _flash_attention( query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1030,17 +983,26 @@ def _flash_attention( ) -> ms.Tensor: lse = None if _parallel_config is None: - out = FlashAttentionCell().construct( - query=query, - key=key, - value=value, - dropout_p=dropout_p, - scale=scale, - is_causal=is_causal, - return_lse=return_lse, - ) - if return_lse: - out, lse = out + if scale is None: + scale = query.shape[-1] ** (-0.5) + + if is_causal: + sparse_mode = 2 + else: + sparse_mode = 0 + + input_layout = "BSND" + head_num = query.shape[2] + + softmax_max, softmax_sum, _, out = ops.operations.nn_ops.FlashAttentionScore( + head_num=head_num, + keep_prob=1 - dropout_p, + scale_value=scale, + input_layout=input_layout, + sparse_mode=sparse_mode, + )(query, key, value) + lse = softmax_max[..., 0] + mint.log(softmax_sum[..., 0]) + lse = lse.permute(0, 2, 1) else: out = _templated_context_parallel_attention( query, @@ -1052,8 +1014,8 @@ def _flash_attention( scale, False, return_lse, - forward_op=FlashAttentionCell().construct, - backward_op=FlashAttentionCell().bprop, + forward_op=_flash_attention_forward_op, + backward_op=_flash_attention_backward_op, _parallel_config=_parallel_config, ) if return_lse: @@ -1105,8 +1067,8 @@ def _native_attention( scale, enable_gqa, return_lse, - forward_op=NativeAttentionCell().construct, - backward_op=NativeAttentionCell().bprop, + forward_op=_native_attention_forward_op, + backward_op=_native_attention_backward_op, _parallel_config=_parallel_config, ) @@ -1121,6 +1083,7 @@ def _native_flash_attention( query: ms.Tensor, key: ms.Tensor, value: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1137,7 +1100,7 @@ def _native_flash_attention( query=query, key=key, value=value, - attn_mask=None, # not supported + attn_mask=None, dropout_p=dropout_p, is_causal=is_causal, scale=scale, diff --git a/mindone/diffusers/models/layers_compat.py b/mindone/diffusers/models/layers_compat.py index bd0d30df28..ea04d40065 100644 --- a/mindone/diffusers/models/layers_compat.py +++ b/mindone/diffusers/models/layers_compat.py @@ -26,7 +26,7 @@ - **RMSNorm**: Always custom due to framework limitations. [2025/11/12] - **scaled_dot_product_attention**: Always custom due to framework limitations. - - *DeviceMesh*: Always custom due to framework limitations. + - **DeviceMesh**: Always custom due to framework limitations. Example: Import this module and use the operators as you would with native MindSpore functions, with the assurance of cross-version compatibility. 
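Note (not part of the patch): the conversions in this commit — `AllGatherFunction`, `TemplatedRingAttention`, `TemplatedUlyssesAttention`, and the `_*_forward_op`/`_*_backward_op` pairs — all rely on the same custom-gradient contract: a `ms.common._Function` subclass with static `forward(ctx, ...)` and `backward(ctx, grad_out)` methods, invoked through `.apply(...)`, where `backward` returns one gradient per forward input (`None` for non-tensor arguments). The toy sketch below only illustrates that pattern; `ScaleByConstant` is hypothetical, `_Function` is an internal MindSpore API, and whether gradients flow exactly like this may depend on the MindSpore version.

```python
import mindspore as ms
from mindspore import mint


class ScaleByConstant(ms.common._Function):
    """Toy custom-gradient op following the same contract as the classes above."""

    @staticmethod
    def forward(ctx, tensor, factor):
        # Non-tensor state is stashed directly on ctx, as AllGatherFunction does with
        # dim/group/world_size; tensors would go through ctx.save_for_backward instead.
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input; `factor` is not differentiable, hence None.
        return grad_output * ctx.factor, None


x = mint.ones((2, 3))
y = ScaleByConstant.apply(x, 2.0)  # callers go through .apply, never __init__/construct
dy_dx = ms.grad(lambda t: ScaleByConstant.apply(t, 2.0).sum())(x)
```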
diff --git a/mindone/diffusers/models/transformers/transformer_wan.py b/mindone/diffusers/models/transformers/transformer_wan.py index 9aa9bad5e5..94d06cd775 100644 --- a/mindone/diffusers/models/transformers/transformer_wan.py +++ b/mindone/diffusers/models/transformers/transformer_wan.py @@ -411,7 +411,7 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: # freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) # freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) # FIXME: we use tile since `tensor.broadcast_to` will thrown an issue (complex input is not supported) in graph - # mode + # mode freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).tile((1, pph, ppw, 1)) freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).tile((ppf, 1, ppw, 1)) freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).tile((ppf, pph, 1, 1)) From 86f7a1fd6a062c808fa49b8ec35e8ff86d6b35d4 Mon Sep 17 00:00:00 2001 From: Cui-yshoho Date: Tue, 2 Dec 2025 15:28:05 +0800 Subject: [PATCH 3/3] fix jit classmethod -> staticmethod --- .../diffusers/models/attention_dispatch.py | 33 ++++++++++--------- .../transformers/transformer_skyreels_v2.py | 16 ++++----- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/mindone/diffusers/models/attention_dispatch.py b/mindone/diffusers/models/attention_dispatch.py index 6fcff4f04c..d9e3a83d5b 100644 --- a/mindone/diffusers/models/attention_dispatch.py +++ b/mindone/diffusers/models/attention_dispatch.py @@ -158,37 +158,40 @@ class _AttentionBackendRegistry: _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) _checks_enabled = DIFFUSERS_ATTN_CHECKS - @classmethod + @staticmethod def register( - cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None, supports_context_parallel: bool = False, ): + Registry = _AttentionBackendRegistry logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") def decorator(func): - cls._backends[backend] = func - cls._constraints[backend] = constraints or [] - cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) - cls._supports_context_parallel[backend] = supports_context_parallel + Registry._backends[backend] = func + Registry._constraints[backend] = constraints or [] + Registry._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + Registry._supports_context_parallel[backend] = supports_context_parallel return func return decorator - @classmethod - def get_active_backend(cls): - return cls._active_backend, cls._backends[cls._active_backend] + @staticmethod + def get_active_backend(): + Registry = _AttentionBackendRegistry + return Registry._active_backend, Registry._backends[Registry._active_backend] - @classmethod - def list_backends(cls): - return list(cls._backends.keys()) + @staticmethod + def list_backends(): + Registry = _AttentionBackendRegistry + return list(Registry._backends.keys()) - @classmethod + @staticmethod def _is_context_parallel_enabled( - cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] + backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"] ) -> bool: - supports_context_parallel = backend in cls._supports_context_parallel + Registry = _AttentionBackendRegistry + supports_context_parallel = backend in Registry._supports_context_parallel is_degree_greater_than_1 = parallel_config is not None and ( parallel_config.context_parallel_config.ring_degree > 1 or 
parallel_config.context_parallel_config.ulysses_degree > 1 diff --git a/mindone/diffusers/models/transformers/transformer_skyreels_v2.py b/mindone/diffusers/models/transformers/transformer_skyreels_v2.py index 2298fa4c1a..1739de22ff 100644 --- a/mindone/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/mindone/diffusers/models/transformers/transformer_skyreels_v2.py @@ -114,8 +114,8 @@ def apply_rotary_emb( out[..., 1::2] = x1 * sin + x2 * cos return out.type_as(hidden_states) - query = apply_rotary_emb(query, rotary_emb) - key = apply_rotary_emb(key, rotary_emb) + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) # I2V task hidden_states_img = None @@ -420,13 +420,13 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) - freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1)) - freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1)) - freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1)) + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).broadcast_to((ppf, pph, ppw, -1)) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).broadcast_to((ppf, pph, ppw, -1)) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).broadcast_to((ppf, pph, ppw, -1)) - freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand((ppf, pph, ppw, -1)) - freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand((ppf, pph, ppw, -1)) - freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand((ppf, pph, ppw, -1)) + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).broadcast_to((ppf, pph, ppw, -1)) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).broadcast_to((ppf, pph, ppw, -1)) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).broadcast_to((ppf, pph, ppw, -1)) freqs_cos = mint.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) freqs_sin = mint.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
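
Note (not part of the patch): the final hunk makes two related changes — the per-axis rotary frequency tables are expanded over the (frames, height, width) grid with `Tensor.broadcast_to` instead of `.expand`, and the rotary embedding is handled as a `(cos, sin)` tuple that the attention processor unpacks with `*rotary_emb`. The sketch below is only an illustration of those two mechanics; the sizes and the `apply_rotary_emb` stand-in are hypothetical placeholders, not the model's real configuration or rotary math.

```python
import mindspore as ms
from mindspore import mint

ppf, pph, ppw = 2, 4, 4          # post-patchify frames / height / width (illustrative)
t_dim, h_dim, w_dim = 8, 4, 4    # per-axis rotary dims (illustrative)

freqs_cos = [mint.ones((n, d)) for n, d in ((ppf, t_dim), (pph, h_dim), (ppw, w_dim))]
freqs_sin = [mint.ones((n, d)) for n, d in ((ppf, t_dim), (pph, h_dim), (ppw, w_dim))]


def assemble(freqs):
    # Same broadcast pattern as the hunk above: each axis table is viewed with
    # singleton dims and broadcast over the full 3D grid, then concatenated.
    f = freqs[0].view(ppf, 1, 1, -1).broadcast_to((ppf, pph, ppw, -1))
    h = freqs[1].view(1, pph, 1, -1).broadcast_to((ppf, pph, ppw, -1))
    w = freqs[2].view(1, 1, ppw, -1).broadcast_to((ppf, pph, ppw, -1))
    return mint.cat([f, h, w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)


rotary_emb = (assemble(freqs_cos), assemble(freqs_sin))


def apply_rotary_emb(hidden_states, cos, sin):
    # Stand-in with the same (hidden_states, cos, sin) signature that the processor
    # now reaches via `apply_rotary_emb(query, *rotary_emb)`; placeholder math only.
    return hidden_states * cos + hidden_states * sin


query = mint.ones((1, ppf * pph * ppw, 1, t_dim + h_dim + w_dim))
query = apply_rotary_emb(query, *rotary_emb)
```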