97 changes: 39 additions & 58 deletions mindone/diffusers/models/transformers/transformer_wan_vace.py
@@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# This code is adapted from https://github.com/huggingface/diffusers
# with modifications to run diffusers on mindspore.
@@ -25,12 +25,18 @@
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import logging
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..layers_compat import unflatten
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
from .transformer_wan import (
WanAttention,
WanAttnProcessor,
WanRotaryPosEmbed,
WanTimeTextImageEmbedding,
WanTransformerBlock,
)

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

@@ -57,33 +63,22 @@ def __init__(

# 2. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
self.attn1 = WanAttention(
dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=WanAttnProcessor2_0(),
processor=WanAttnProcessor(),
)

# 3. Cross-attention
self.attn2 = Attention(
query_dim=dim,
self.attn2 = WanAttention(
dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=WanAttnProcessor2_0(),
processor=WanAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else mint.nn.Identity()

@@ -118,12 +113,12 @@ def construct(
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
control_hidden_states
)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

# 2. Cross-attention
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
control_hidden_states = control_hidden_states + attn_output

# 3. Feed-forward
@@ -142,7 +137,7 @@ def construct(
return conditioning_states, control_hidden_states


class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan model.

@@ -219,11 +214,9 @@ def __init__(

# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = mint.nn.Conv3d(
in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size)
)
self.patch_embedding = mint.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.vace_patch_embedding = mint.nn.Conv3d(
vace_in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size)
vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size
)

# 2. Condition embeddings
@@ -272,8 +265,6 @@ def __init__(

self.gradient_checkpointing = False

self.p_t, self.p_h, self.p_w = self.config.patch_size

def construct(
self,
hidden_states: ms.Tensor,
@@ -285,11 +276,23 @@
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[ms.Tensor, Dict[str, ms.Tensor]]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if attention_kwargs is not None and "scale" in attention_kwargs:
# weight the lora layers by setting `lora_scale` for each PEFT layer here
# and remove `lora_scale` from each PEFT layer at the end.
# scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode
raise RuntimeError(
f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. "
f"However it's not allowed in on-the-fly model forwarding. "
f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and "
f"`unscale_lora_layers(model, lora_scale)` after model forwarding. "
f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`."
)

post_patch_num_frames = num_frames // self.p_t
post_patch_height = height // self.p_h
post_patch_width = width // self.p_w
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config["patch_size"]
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w

if control_hidden_states_scale is None:
control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
@@ -305,10 +308,10 @@

# 2. Patch embedding
hidden_states = self.patch_embedding(hidden_states)
hidden_states = mint.transpose(hidden_states.flatten(2), 1, 2)
hidden_states = hidden_states.flatten(2).swapaxes(1, 2)

control_hidden_states = self.vace_patch_embedding(control_hidden_states)
control_hidden_states = mint.transpose(control_hidden_states.flatten(2), 1, 2)
control_hidden_states = control_hidden_states.flatten(2).swapaxes(1, 2)
control_hidden_states_padding = control_hidden_states.new_zeros(
(batch_size, hidden_states.shape[1] - control_hidden_states.shape[1], control_hidden_states.shape[2])
)
@@ -318,35 +321,13 @@
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
# timestep_proj = self.unflatten(timestep_proj)
timestep_proj = unflatten(timestep_proj, 1, (6, -1))

# 4. Image embedding
if encoder_hidden_states_image is not None:
encoder_hidden_states = mint.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

# 5. Transformer blocks
# mindspore not support gradient_checkpointing
# if self.gradient_checkpointing:
# # Prepare VACE hints
# control_hidden_states_list = []
# for i, block in enumerate(self.vace_blocks):
# conditioning_states, control_hidden_states = self._gradient_checkpointing_func(block, hidden_states,
# encoder_hidden_states,
# control_hidden_states,
# timestep_proj,
# rotary_emb)
# control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
# control_hidden_states_list = control_hidden_states_list[::-1]
#
# for i, block in enumerate(self.blocks):
# hidden_states = self._gradient_checkpointing_func(block, hidden_states, encoder_hidden_states,
# timestep_proj, rotary_emb)
# if i in self.config.vace_layers:
# control_hint, scale = control_hidden_states_list[-1]
# hidden_states = hidden_states + control_hint * scale
# control_hidden_states_list = control_hidden_states_list[:-1]
# else:
# Prepare VACE hints
control_hidden_states_list = []
for i, block in enumerate(self.vace_blocks):
@@ -370,7 +351,7 @@
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, self.p_t, self.p_h, self.p_w, -1
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
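The hunk above replaces silent handling of `attention_kwargs["scale"]` with a hard error: the graph-mode forward cannot rescale PEFT layers on the fly, so the message points callers at scaling the LoRA layers around the call instead. A minimal sketch of that pattern, assuming `scale_lora_layers`/`unscale_lora_layers` are re-exported from `mindone.diffusers.utils` as in upstream diffusers; `transformer`, `latents`, `prompt_embeds`, `control_latents`, and `timestep` are illustrative placeholders, not names from this PR:

```python
from mindone.diffusers.utils import scale_lora_layers, unscale_lora_layers  # assumed re-export, mirroring diffusers.utils

lora_scale = 0.8  # illustrative LoRA weighting

# Scale every PEFT/LoRA layer once, outside the graph-compiled forward ...
scale_lora_layers(transformer, lora_scale)
sample = transformer(
    hidden_states=latents,                  # placeholder tensors produced by the pipeline
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    control_hidden_states=control_latents,
    return_dict=False,
)[0]
# ... and undo the scaling after the call, as the new error message suggests.
unscale_lora_layers(transformer, lora_scale)
```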
@@ -79,10 +79,10 @@ def postprocess(self, frames: np.ndarray) -> np.ndarray:
class GuardrailRunner:
def __init__(
self,
safety_models: list[ContentSafetyGuardrail] | None = None,
safety_models: Union[list[ContentSafetyGuardrail], None] = None,
generic_block_msg: str = "",
generic_safe_msg: str = "",
postprocessors: list[PostprocessingGuardrail] | None = None,
postprocessors: Union[list[PostprocessingGuardrail], None] = None,
):
self.safety_models = safety_models
self.generic_block_msg = generic_block_msg
2 changes: 1 addition & 1 deletion mindone/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -209,7 +209,7 @@ def _get_gemma_prompt_embeds(
prompt_attention_mask = ms.tensor(text_inputs.attention_mask)
text_input_ids = ms.tensor(text_input_ids)
prompt_embeds = self.text_encoder(
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True, return_dict=True
)
prompt_embeds = prompt_embeds.hidden_states[-2]

@@ -187,7 +187,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor):
def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
dtype: Optional[ms.dtype] = None,
dtype: Optional[ms.Type] = None,
Contributor review comment (severity: medium): The type hint for dtype has been corrected from Optional[ms.dtype] to Optional[ms.Type]. This is a good fix, as ms.Type is the correct way to hint a type class in MindSpore, whereas ms.dtype refers to an instance of a dtype; the change improves type correctness and clarity. (A minimal sketch of the corrected hint follows this file's diff.)

):
dtype = dtype or self.text_encoder.dtype

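As the review comment above notes, the recurring `Optional[ms.dtype]` → `Optional[ms.Type]` change annotates the parameter with the class of MindSpore dtype objects. A minimal sketch of the corrected hint; `cast_embeds` is a hypothetical helper for illustration, not part of this PR:

```python
from typing import Optional

import mindspore as ms


def cast_embeds(embeds: ms.Tensor, dtype: Optional[ms.Type] = None) -> ms.Tensor:
    # ms.Type annotates dtype objects such as ms.float16 or ms.bfloat16,
    # matching the `dtype = dtype or self.text_encoder.dtype` pattern in the diff.
    dtype = dtype or ms.float32
    return embeds.astype(dtype)
```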
@@ -20,11 +20,13 @@
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
from transformers import Qwen2Tokenizer

import mindspore as ms
from mindspore import mint

from mindone.transformers import Qwen2VLProcessor

from ....transformers import Qwen2_5_VLForConditionalGeneration
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
@@ -228,7 +230,7 @@ def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
image: Optional[ms.tensor] = None,
dtype: Optional[ms.dtype] = None,
dtype: Optional[ms.Type] = None,
):
dtype = dtype or self.text_encoder.dtype

@@ -21,11 +21,13 @@

import numpy as np
import PIL.Image
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
from transformers import Qwen2Tokenizer

import mindspore as ms
from mindspore import mint

from mindone.transformers import Qwen2VLProcessor

from ....transformers import Qwen2_5_VLForConditionalGeneration
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import QwenImageLoraLoaderMixin
@@ -238,7 +240,7 @@ def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
image: Optional[ms.tensor] = None,
dtype: Optional[ms.dtype] = None,
dtype: Optional[ms.Type] = None,
):
dtype = dtype or self.text_encoder.dtype

@@ -194,7 +194,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor):
def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
dtype: Optional[ms.dtype] = None,
dtype: Optional[ms.Type] = None,
):
dtype = dtype or self.text_encoder.dtype

@@ -204,7 +204,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor):
def _get_qwen_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
dtype: Optional[ms.dtype] = None,
dtype: Optional[ms.Type] = None,
):
dtype = dtype or self.text_encoder.dtype
