diff --git a/mindone/diffusers/models/transformers/transformer_wan_vace.py b/mindone/diffusers/models/transformers/transformer_wan_vace.py index ec6eb2869d..87d4393101 100644 --- a/mindone/diffusers/models/transformers/transformer_wan_vace.py +++ b/mindone/diffusers/models/transformers/transformer_wan_vace.py @@ -1,4 +1,4 @@ -# Copyright 2025 The HuggingFace Team. All rights reserved. +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. # # This code is adapted from https://github.com/huggingface/diffusers # with modifications to run diffusers on mindspore. @@ -25,12 +25,18 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging from ..attention import FeedForward -from ..attention_processor import Attention +from ..cache_utils import CacheMixin from ..layers_compat import unflatten from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock +from .transformer_wan import ( + WanAttention, + WanAttnProcessor, + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -57,33 +63,22 @@ def __init__( # 2. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.attn1 = Attention( - query_dim=dim, + self.attn1 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) # 3. Cross-attention - self.attn2 = Attention( - query_dim=dim, + self.attn2 = WanAttention( + dim=dim, heads=num_heads, - kv_heads=num_heads, dim_head=dim // num_heads, - qk_norm=qk_norm, eps=eps, - bias=True, - cross_attention_dim=None, - out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - processor=WanAttnProcessor2_0(), + processor=WanAttnProcessor(), ) self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else mint.nn.Identity() @@ -118,12 +113,12 @@ def construct( norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as( control_hidden_states ) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states) - attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) control_hidden_states = control_hidden_states + attn_output # 3. Feed-forward @@ -142,7 +137,7 @@ def construct( return conditioning_states, control_hidden_states -class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in the Wan model. @@ -219,11 +214,9 @@ def __init__( # 1. 
Patch & position embedding self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) - self.patch_embedding = mint.nn.Conv3d( - in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size) - ) + self.patch_embedding = mint.nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) self.vace_patch_embedding = mint.nn.Conv3d( - vace_in_channels, inner_dim, kernel_size=tuple(patch_size), stride=tuple(patch_size) + vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size ) # 2. Condition embeddings @@ -272,8 +265,6 @@ def __init__( self.gradient_checkpointing = False - self.p_t, self.p_h, self.p_w = self.config.patch_size - def construct( self, hidden_states: ms.Tensor, @@ -285,11 +276,23 @@ def construct( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ms.Tensor, Dict[str, ms.Tensor]]: - batch_size, num_channels, num_frames, height, width = hidden_states.shape + if attention_kwargs is not None and "scale" in attention_kwargs: + # weight the lora layers by setting `lora_scale` for each PEFT layer here + # and remove `lora_scale` from each PEFT layer at the end. + # scale_lora_layers & unscale_lora_layers maybe contains some operation forbidden in graph mode + raise RuntimeError( + f"You are trying to set scaling of lora layer by passing {attention_kwargs['scale']=}. " + f"However it's not allowed in on-the-fly model forwarding. " + f"Please manually call `scale_lora_layers(model, lora_scale)` before model forwarding and " + f"`unscale_lora_layers(model, lora_scale)` after model forwarding. " + f"For example, it can be done in a pipeline call like `StableDiffusionPipeline.__call__`." + ) - post_patch_num_frames = num_frames // self.p_t - post_patch_height = height // self.p_h - post_patch_width = width // self.p_w + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config["patch_size"] + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w if control_hidden_states_scale is None: control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers)) @@ -305,10 +308,10 @@ def construct( # 2. Patch embedding hidden_states = self.patch_embedding(hidden_states) - hidden_states = mint.transpose(hidden_states.flatten(2), 1, 2) + hidden_states = hidden_states.flatten(2).swapaxes(1, 2) control_hidden_states = self.vace_patch_embedding(control_hidden_states) - control_hidden_states = mint.transpose(control_hidden_states.flatten(2), 1, 2) + control_hidden_states = control_hidden_states.flatten(2).swapaxes(1, 2) control_hidden_states_padding = control_hidden_states.new_zeros( (batch_size, hidden_states.shape[1] - control_hidden_states.shape[1], control_hidden_states.shape[2]) ) @@ -318,7 +321,6 @@ def construct( temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - # timestep_proj = self.unflatten(timestep_proj) timestep_proj = unflatten(timestep_proj, 1, (6, -1)) # 4. Image embedding @@ -326,27 +328,6 @@ def construct( encoder_hidden_states = mint.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) # 5. 
Transformer blocks - # mindspore not support gradient_checkpointing - # if self.gradient_checkpointing: - # # Prepare VACE hints - # control_hidden_states_list = [] - # for i, block in enumerate(self.vace_blocks): - # conditioning_states, control_hidden_states = self._gradient_checkpointing_func(block, hidden_states, - # encoder_hidden_states, - # control_hidden_states, - # timestep_proj, - # rotary_emb) - # control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) - # control_hidden_states_list = control_hidden_states_list[::-1] - # - # for i, block in enumerate(self.blocks): - # hidden_states = self._gradient_checkpointing_func(block, hidden_states, encoder_hidden_states, - # timestep_proj, rotary_emb) - # if i in self.config.vace_layers: - # control_hint, scale = control_hidden_states_list[-1] - # hidden_states = hidden_states + control_hint * scale - # control_hidden_states_list = control_hidden_states_list[:-1] - # else: # Prepare VACE hints control_hidden_states_list = [] for i, block in enumerate(self.vace_blocks): @@ -370,7 +351,7 @@ def construct( hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, self.p_t, self.p_h, self.p_w, -1 + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) diff --git a/mindone/diffusers/pipelines/cosmos/_cosmos_guardrail/cosmos_guardrail.py b/mindone/diffusers/pipelines/cosmos/_cosmos_guardrail/cosmos_guardrail.py index 8c36f9b198..0e2e6831de 100644 --- a/mindone/diffusers/pipelines/cosmos/_cosmos_guardrail/cosmos_guardrail.py +++ b/mindone/diffusers/pipelines/cosmos/_cosmos_guardrail/cosmos_guardrail.py @@ -79,10 +79,10 @@ def postprocess(self, frames: np.ndarray) -> np.ndarray: class GuardrailRunner: def __init__( self, - safety_models: list[ContentSafetyGuardrail] | None = None, + safety_models: Union[list[ContentSafetyGuardrail], None] = None, generic_block_msg: str = "", generic_safe_msg: str = "", - postprocessors: list[PostprocessingGuardrail] | None = None, + postprocessors: Union[list[PostprocessingGuardrail], None] = None, ): self.safety_models = safety_models self.generic_block_msg = generic_block_msg diff --git a/mindone/diffusers/pipelines/lumina2/pipeline_lumina2.py b/mindone/diffusers/pipelines/lumina2/pipeline_lumina2.py index 6e3b7a4207..56de5e45a1 100644 --- a/mindone/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/mindone/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -209,7 +209,7 @@ def _get_gemma_prompt_embeds( prompt_attention_mask = ms.tensor(text_inputs.attention_mask) text_input_ids = ms.tensor(text_input_ids) prompt_embeds = self.text_encoder( - text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True, return_dict=True ) prompt_embeds = prompt_embeds.hidden_states[-2] diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index 64d640c031..ebc582d349 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -187,7 +187,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor): def _get_qwen_prompt_embeds( self, prompt: Union[str, 
List[str]] = None, - dtype: Optional[ms.dtype] = None, + dtype: Optional[ms.Type] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index 7fe3fbbe2d..e20099cca0 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -20,11 +20,13 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -from transformers import Qwen2Tokenizer, Qwen2VLProcessor +from transformers import Qwen2Tokenizer import mindspore as ms from mindspore import mint +from mindone.transformers import Qwen2VLProcessor + from ....transformers import Qwen2_5_VLForConditionalGeneration from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import QwenImageLoraLoaderMixin @@ -228,7 +230,7 @@ def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, image: Optional[ms.tensor] = None, - dtype: Optional[ms.dtype] = None, + dtype: Optional[ms.Type] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 320b43eb4d..e3d0600c5e 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -21,11 +21,13 @@ import numpy as np import PIL.Image -from transformers import Qwen2Tokenizer, Qwen2VLProcessor +from transformers import Qwen2Tokenizer import mindspore as ms from mindspore import mint +from mindone.transformers import Qwen2VLProcessor + from ....transformers import Qwen2_5_VLForConditionalGeneration from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import QwenImageLoraLoaderMixin @@ -238,7 +240,7 @@ def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, image: Optional[ms.tensor] = None, - dtype: Optional[ms.dtype] = None, + dtype: Optional[ms.Type] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index 1d3f036ce8..a731e68b7d 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -194,7 +194,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor): def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - dtype: Optional[ms.dtype] = None, + dtype: Optional[ms.Type] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 9af4f5b083..341389e46a 100644 --- a/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/mindone/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -204,7 +204,7 @@ def _extract_masked_hidden(self, hidden_states: ms.tensor, mask: ms.tensor): def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]] = None, - dtype: Optional[ms.dtype] = None, + dtype: Optional[ms.Type] = None, ): dtype = dtype or self.text_encoder.dtype diff --git a/mindone/diffusers/pipelines/wan/pipeline_wan_vace.py 
b/mindone/diffusers/pipelines/wan/pipeline_wan_vace.py
index 4496908ddd..0c03e21fb2 100644
--- a/mindone/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/mindone/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
 #
 # This code is adapted from https://github.com/huggingface/diffusers
 # with modifications to run diffusers on mindspore.
@@ -26,13 +26,12 @@
 import mindspore as ms
 from mindspore import mint

-from mindone.diffusers.models.transformers.transformer_wan_vace import WanVACETransformer3DModel
 from mindone.transformers import UMT5EncoderModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
-from ...models import AutoencoderKLWan
+from ...models import AutoencoderKLWan, WanVACETransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_ftfy_available, logging
 from ...utils.mindspore_utils import pynative_context, randn_tensor
@@ -46,6 +45,72 @@
     import ftfy


+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```python
+        >>> import mindspore as ms
+        >>> import PIL.Image
+        >>> from mindone.diffusers import AutoencoderKLWan, WanVACEPipeline
+        >>> from mindone.diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
+        >>> from mindone.diffusers.utils import export_to_video, load_image
+        >>> def prepare_video_and_mask(first_img: PIL.Image.Image, last_img: PIL.Image.Image, height: int, width: int, num_frames: int):
+        ...     first_img = first_img.resize((width, height))
+        ...     last_img = last_img.resize((width, height))
+        ...     frames = []
+        ...     frames.append(first_img)
+        ...     # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
+        ...     # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
+        ...     # match the original code.
+        ...     frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
+        ...     frames.append(last_img)
+        ...     mask_black = PIL.Image.new("L", (width, height), 0)
+        ...     mask_white = PIL.Image.new("L", (width, height), 255)
+        ...     mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
+        ...     return frames, mask
+
+        >>> # Available checkpoints: Wan-AI/Wan2.1-VACE-1.3B-diffusers, Wan-AI/Wan2.1-VACE-14B-diffusers
+        >>> model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
+        >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", mindspore_dtype=ms.float32)
+        >>> pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, mindspore_dtype=ms.bfloat16)
+        >>> flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
+        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+
+        >>> prompt = "CG animation style, a small blue bird takes off from the ground, " \
+        ...     "flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. " \
+        ...     "The background shows a blue sky with white clouds under bright sunshine. " \
+        ...     "The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+        >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, " \
+        ...     "worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, " \
+        ...     "poorly drawn faces, deformed, disfigured, " \
+        ...     "misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+        >>> first_frame = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
+        ... )
+        >>> last_frame = load_image(
+        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
+        ... )
+
+        >>> height = 512
+        >>> width = 512
+        >>> num_frames = 81
+        >>> video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
+
+        >>> output = pipe(
+        ...     video=video,
+        ...     mask=mask,
+        ...     prompt=prompt,
+        ...     negative_prompt=negative_prompt,
+        ...     height=height,
+        ...     width=width,
+        ...     num_frames=num_frames,
+        ...     num_inference_steps=30,
+        ...     guidance_scale=5.0,
+        ... ).frames[0]
+        >>> export_to_video(output, "output.mp4", fps=16)
+        ```
+"""
+
+
 def basic_clean(text):
     text = ftfy.fix_text(text)
     text = html.unescape(html.unescape(text))
@@ -83,7 +148,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
     Pipeline for controllable generation using Wan.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
-    implemented for all pipelines (downloading, saving, etc.).
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

     Args:
         tokenizer ([`T5Tokenizer`]):
@@ -200,7 +265,7 @@ def encode_prompt(
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
             dtype: (`ms.dtype`, *optional*):
-                ms dtype
+                mindspore dtype
         """
         prompt = [prompt] if isinstance(prompt, str) else prompt
         if prompt is not None:
@@ -457,8 +522,7 @@ def prepare_video_latents(
                 latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
             latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
-            mask = mask.to(dtype=vae_dtype)
-            mask = mint.where(mask > 0.5, 1.0, 0.0)
+            mask = mint.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
             inactive = video * (1 - mask)
             reactive = video * mask
             with pynative_context():
@@ -539,7 +603,7 @@ def prepare_latents(
         width: int = 832,
         num_frames: int = 81,
         dtype: Optional[ms.Type] = None,
-        generator: Optional[Union[ms.Generator, List[ms.Generator]]] = None,
+        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
         latents: Optional[ms.Tensor] = None,
     ) -> ms.Tensor:
         if latents is not None:
@@ -600,7 +664,7 @@ def __call__(
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
         num_videos_per_prompt: Optional[int] = 1,
-        generator: np.random.Generator = None,
+        generator: Optional[Union[np.random.Generator, List[np.random.Generator]]] = None,
         latents: Optional[ms.Tensor] = None,
         prompt_embeds: Optional[ms.Tensor] = None,
         negative_prompt_embeds: Optional[ms.Tensor] = None,
@@ -662,8 +726,8 @@ def __call__(
                 usually at the expense of lower image quality.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
- generator (`np.random.Generator`): - A [`np.random.Generator`] to make + generator (`np.random.Generator` or `List[np.random.Generator]`, *optional*): + A [`np.random.Generator`](https://numpy.org/doc/stable/reference/random/generator.html) to make generation deterministic. latents (`ms.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image @@ -801,7 +865,7 @@ def __call__( num_reference_images = len(reference_images[0]) conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator) - mask = self.prepare_masks(mask, reference_images, generator).to(conditioning_latents.dtype) + mask = self.prepare_masks(mask, reference_images, generator) conditioning_latents = mint.cat([conditioning_latents, mask], dim=1) conditioning_latents = conditioning_latents.to(transformer_dtype) @@ -883,8 +947,8 @@ def __call__( latents_mean = ( ms.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(latents.dtype) ) - latents_std = 1.0 / ms.tensor(self.vae.config.latents_std, dtype=ms.float32).view( - 1, self.vae.config.z_dim, 1, 1, 1 + latents_std = 1.0 / ms.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.dtype ) latents = latents / latents_std + latents_mean with pynative_context(): diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 1b8cfd7c9e..f6d255baf6 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -1124,6 +1124,7 @@ Qwen2VLImageProcessorFast, Qwen2VLModel, Qwen2VLPreTrainedModel, + Qwen2VLProcessor, Qwen2VLVideoProcessor, ) from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..7881576ee4 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -213,6 +213,7 @@ ("qwen2_audio", "Qwen2AudioConfig"), ("qwen2_audio_encoder", "Qwen2AudioEncoderConfig"), ("qwen2_vl", "Qwen2VLConfig"), + ("qwen2_vl_text", "Qwen2VLTextConfig"), ("rag", "RagConfig"), ("recurrent_gemma", "RecurrentGemmaConfig"), ("regnet", "RegNetConfig"), @@ -483,6 +484,7 @@ ("qwen2_audio", "Qwen2Audio"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_vl", "Qwen2VL"), + ("qwen2_vl_text", "Qwen2VL"), ("rag", "RAG"), ("recurrent_gemma", "RecurrentGemma"), ("regnet", "RegNet"), @@ -613,6 +615,7 @@ ("maskformer-swin", "maskformer"), ("openai-gpt", "openai"), ("qwen2_audio_encoder", "qwen2_audio"), + ("qwen2_vl_text", "qwen2_vl"), ("rt_detr_resnet", "rt_detr"), ("siglip_vision_model", "siglip"), ("smolvlm_vision", "smolvlm"), diff --git a/mindone/transformers/models/auto/image_processing_auto.py b/mindone/transformers/models/auto/image_processing_auto.py index 2fb5c3d7e9..d248fb446b 100644 --- a/mindone/transformers/models/auto/image_processing_auto.py +++ b/mindone/transformers/models/auto/image_processing_auto.py @@ -73,6 +73,7 @@ ("owlv2", ("Owlv2ImageProcessor",)), ("owlvit", ("OwlViTImageProcessor",)), ("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), + ("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")), ("sam", ("SamImageProcessor",)), ("segformer", ("SegformerImageProcessor",)), ("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")), diff --git a/mindone/transformers/models/auto/modeling_auto.py 
b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..2a35ee2cee 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -194,6 +194,7 @@ ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), + ("qwen2_vl_text", "Qwen2VLTextModel"), ("recurrent_gemma", "RecurrentGemmaModel"), ("regnet", "RegNetModel"), ("roberta", "RobertaModel"), diff --git a/mindone/transformers/models/auto/processing_auto.py b/mindone/transformers/models/auto/processing_auto.py index 88df3d8aee..f8d5098dc1 100644 --- a/mindone/transformers/models/auto/processing_auto.py +++ b/mindone/transformers/models/auto/processing_auto.py @@ -67,6 +67,7 @@ ("owlvit", "OwlViTProcessor"), ("pop2piano", "Pop2PianoProcessor"), ("qwen2_5_vl", "Qwen2_5_VLProcessor"), + ("qwen2_vl", "Qwen2VLProcessor"), ("sam", "SamProcessor"), ("seamless_m4t", "SeamlessM4TProcessor"), ("siglip", "SiglipProcessor"), diff --git a/mindone/transformers/models/auto/video_processing_auto.py b/mindone/transformers/models/auto/video_processing_auto.py index 9959be4480..9d33a486b8 100644 --- a/mindone/transformers/models/auto/video_processing_auto.py +++ b/mindone/transformers/models/auto/video_processing_auto.py @@ -45,6 +45,7 @@ VIDEO_PROCESSOR_MAPPING_NAMES = OrderedDict( [ ("qwen2_5_vl", "Qwen2VLVideoProcessor"), + ("qwen2_vl", "Qwen2VLVideoProcessor"), ] ) diff --git a/mindone/transformers/models/qwen2_vl/__init__.py b/mindone/transformers/models/qwen2_vl/__init__.py index 3a1053517d..155a8aa4df 100644 --- a/mindone/transformers/models/qwen2_vl/__init__.py +++ b/mindone/transformers/models/qwen2_vl/__init__.py @@ -17,4 +17,5 @@ from .image_processing_qwen2_vl import * from .image_processing_qwen2_vl_fast import * from .modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel, Qwen2VLPreTrainedModel +from .processing_qwen2_vl import * from .video_processing_qwen2_vl import * diff --git a/mindone/transformers/models/qwen2_vl/processing_qwen2_vl.py b/mindone/transformers/models/qwen2_vl/processing_qwen2_vl.py new file mode 100644 index 0000000000..68720b6430 --- /dev/null +++ b/mindone/transformers/models/qwen2_vl/processing_qwen2_vl.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for Qwen2-VL. 
+""" + +from typing import Optional, Union + +import numpy as np +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +from ...feature_extraction_utils import BatchFeature +from ...image_utils import ImageInput +from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack +from ...utils import logging +from ...video_utils import VideoInput + +logger = logging.get_logger(__name__) + + +class Qwen2VLImagesKwargs(ImagesKwargs): + min_pixels: Optional[int] + max_pixels: Optional[int] + patch_size: Optional[int] + temporal_patch_size: Optional[int] + merge_size: Optional[int] + + +class Qwen2VLProcessorKwargs(ProcessingKwargs, total=False): + images_kwargs: Qwen2VLImagesKwargs + _defaults = { + "text_kwargs": { + "padding": False, + "return_mm_token_type_ids": False, + }, + } + + +class Qwen2VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2-VL processor which wraps a Qwen2-VL image processor and a Qwen2 tokenizer into a single processor. + [`Qwen2VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2VLProcessor.__call__`] and [`~Qwen2VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + video_processor ([`Qwen2VLVideoProcessor`], *optional*): + The video processor is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer", "video_processor"] + image_processor_class = "AutoImageProcessor" + video_processor_class = "AutoVideoProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + self.image_token_id = ( + tokenizer.image_token_id + if getattr(tokenizer, "image_token_id", None) + else tokenizer.convert_tokens_to_ids(self.image_token) + ) + self.video_token_id = ( + tokenizer.video_token_id + if getattr(tokenizer, "video_token_id", None) + else tokenizer.convert_tokens_to_ids(self.video_token) + ) + super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None, + videos: Optional[VideoInput] = None, + **kwargs: Unpack[Qwen2VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of images to be prepared. 
Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `list[str]`, `list[list[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. + - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. 
+ """ + output_kwargs = self._merge_kwargs( + Qwen2VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + image_inputs = videos_inputs = {} + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + + if videos is not None: + videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + if not isinstance(text, list): + text = [text] + + text = text.copy() # below lines change text in-place + + if images is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + num_image_tokens = image_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if videos is not None: + merge_length = self.video_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + num_video_tokens = video_grid_thw[index].prod() // merge_length + text[i] = text[i].replace(self.video_token, "<|placeholder|>" * num_video_tokens, 1) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None) + return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"], return_tensors=None) + self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"]) + + if return_mm_token_type_ids: + array_ids = np.array(text_inputs["input_ids"]) + mm_token_type_ids = np.zeros_like(text_inputs["input_ids"]) + mm_token_type_ids[array_ids == self.image_token_id] = 1 + text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist() + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors) + + def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs): + """ + Computes the number of placeholder tokens needed for multimodal inputs with the given sizes. + Args: + image_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (height, width) per each image. + video_sizes (`list[list[int]]`, *optional*): + The input sizes formatted as (num_frames, height, width) per each video. + Returns: + `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided + input modalities, along with other useful data. 
+ """ + + vision_data = {} + if image_sizes is not None: + images_kwargs = Qwen2VLProcessorKwargs._defaults.get("images_kwargs", {}) + images_kwargs.update(kwargs) + merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size + + num_image_patches = [ + self.image_processor.get_number_of_image_patches(*image_size, images_kwargs) + for image_size in image_sizes + ] + num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches] + vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches}) + + if video_sizes is not None: + videos_kwargs = Qwen2VLProcessorKwargs._defaults.get("videos_kwargs", {}) + videos_kwargs.update(kwargs) + num_video_patches = [ + self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs) + for video_size in video_sizes + ] + num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches] + vision_data["num_video_tokens"] = num_video_tokens + + return MultiModalData(**vision_data) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `list[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + +__all__ = ["Qwen2VLProcessor"] diff --git a/tests/diffusers_tests/modules/modules_test_cases.py b/tests/diffusers_tests/modules/modules_test_cases.py index 0b74d15ab3..7809d3cfed 100644 --- a/tests/diffusers_tests/modules/modules_test_cases.py +++ b/tests/diffusers_tests/modules/modules_test_cases.py @@ -1481,6 +1481,10 @@ "enable_temporal_attentions": True, }, ], +] + + +WANVACE_TRANSFORMER3D_CASES = [ [ "WanVACETransformer3DModel", "diffusers.models.transformers.transformer_wan_vace.WanVACETransformer3DModel", @@ -1642,6 +1646,7 @@ + TRANSFORMER2D_CASES + FLUX_TRANSFORMER2D_CASES + LATTE_TRANSORMER3D_CASES + + WANVACE_TRANSFORMER3D_CASES + LUMINA_NEXTDIT2D_CASES + LUMINA2_TRANSFORMER2D_CASES + CHROMA_TRANSFORMER2D_CASES diff --git a/tests/diffusers_tests/pipelines/audioldm/test_audioldm.py b/tests/diffusers_tests/pipelines/audioldm/test_audioldm.py index 7325f6fa82..9837fdcd47 100644 --- a/tests/diffusers_tests/pipelines/audioldm/test_audioldm.py +++ b/tests/diffusers_tests/pipelines/audioldm/test_audioldm.py @@ -19,9 +19,12 @@ import unittest +import diffusers import numpy as np +import pytest import torch from ddt import data, ddt, unpack +from packaging.version import Version from transformers import ClapTextConfig, SpeechT5HifiGanConfig import mindspore as ms @@ -163,6 +166,11 @@ def get_dummy_inputs(self): @data(*test_cases) @unpack def test_audioldm(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"AudioLDMPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() @@ -196,6 +204,11 @@ class AudioLDMPipelineNightlyTests(PipelineTesterMixin, unittest.TestCase): @data(*test_cases) @unpack def test_audioldm(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"AudioLDMPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) diff --git a/tests/diffusers_tests/pipelines/audioldm2/test_audioldm2.py b/tests/diffusers_tests/pipelines/audioldm2/test_audioldm2.py index 0de893bee0..fd4ff5cec6 100644 --- a/tests/diffusers_tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/diffusers_tests/pipelines/audioldm2/test_audioldm2.py @@ -20,8 +20,11 @@ import unittest import numpy as np +import pytest import torch +import transformers from ddt import data, ddt, unpack +from packaging.version import Version from transformers import ClapAudioConfig, ClapConfig, ClapTextConfig, GPT2Config, SpeechT5HifiGanConfig, T5Config import mindspore as ms @@ -240,6 +243,11 @@ def get_dummy_inputs(self): @data(*test_cases) @unpack def test_audioldm2(self, mode, dtype): + last_supported_version = Version("4.50.0") + current_version = Version(transformers.__version__) + if current_version > last_supported_version: + pytest.skip(f"AudioLDM2Pipeline is not supported in transformers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() diff --git a/tests/diffusers_tests/pipelines/blipdiffusion/test_blipdiffusion.py 
b/tests/diffusers_tests/pipelines/blipdiffusion/test_blipdiffusion.py index fde0180587..d830e5a4a8 100644 --- a/tests/diffusers_tests/pipelines/blipdiffusion/test_blipdiffusion.py +++ b/tests/diffusers_tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -18,9 +18,12 @@ import unittest +import diffusers import numpy as np +import pytest import torch from ddt import data, ddt, unpack +from packaging.version import Version from PIL import Image from transformers import Blip2Config, CLIPTextConfig @@ -197,6 +200,11 @@ def get_dummy_inputs(self, seed=0): @data(*test_cases) @unpack def test_blipdiffusion(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"BlipDiffusionPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() @@ -233,6 +241,11 @@ class BlipDiffusionPipelineIntegrationTests(PipelineTesterMixin, unittest.TestCa @data(*test_cases) @unpack def test_blipdiffusion(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"BlipDiffusionPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) diff --git a/tests/diffusers_tests/pipelines/bria/test_pipeline_bria.py b/tests/diffusers_tests/pipelines/bria/test_pipeline_bria.py index 8ac6d08cf6..80f054c244 100644 --- a/tests/diffusers_tests/pipelines/bria/test_pipeline_bria.py +++ b/tests/diffusers_tests/pipelines/bria/test_pipeline_bria.py @@ -17,9 +17,12 @@ import unittest +import diffusers import numpy as np +import pytest import torch from ddt import data, ddt, unpack +from packaging.version import Version import mindspore as ms @@ -137,6 +140,11 @@ def get_dummy_inputs(self): @data(*test_cases) @unpack def test_bria(self, mode, dtype): + required_version = Version("0.35.2") + current_version = Version(diffusers.__version__) + if current_version <= required_version: + pytest.skip(f"BriaPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() diff --git a/tests/diffusers_tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/diffusers_tests/pipelines/controlnet/test_controlnet_blip_diffusion.py index cf5736616b..7a6f54aa33 100644 --- a/tests/diffusers_tests/pipelines/controlnet/test_controlnet_blip_diffusion.py +++ b/tests/diffusers_tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -18,9 +18,12 @@ import unittest +import diffusers import numpy as np +import pytest import torch from ddt import data, ddt, unpack +from packaging.version import Version from PIL import Image from transformers import Blip2Config, CLIPTextConfig @@ -215,6 +218,11 @@ def get_dummy_inputs(self, seed=0): @data(*test_cases) @unpack def test_blipdiffusion_controlnet(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"BlipDiffusionControlNetPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() @@ -251,6 +259,11 @@ class BlipDiffusionControlNetPipelineIntegrationTests(PipelineTesterMixin, unitt @data(*test_cases) 
@unpack def test_blipdiffusion_controlnet(self, mode, dtype): + last_supported_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > last_supported_version: + pytest.skip(f"BlipDiffusionControlNetPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) ms_dtype = getattr(ms, dtype) diff --git a/tests/diffusers_tests/pipelines/musicldm/test_musicldm.py b/tests/diffusers_tests/pipelines/musicldm/test_musicldm.py index a9abded83b..a4bd21e7e9 100644 --- a/tests/diffusers_tests/pipelines/musicldm/test_musicldm.py +++ b/tests/diffusers_tests/pipelines/musicldm/test_musicldm.py @@ -2,9 +2,12 @@ import unittest +import diffusers import numpy as np +import pytest import torch from ddt import data, ddt, unpack +from packaging.version import Version from transformers.models.clap.configuration_clap import ClapAudioConfig, ClapConfig, ClapTextConfig from transformers.models.speecht5.configuration_speecht5 import SpeechT5HifiGanConfig @@ -174,6 +177,11 @@ def get_dummy_inputs(self, seed=0): @data(*test_cases) @unpack def test_inference(self, mode, dtype): + required_version = Version("0.33.1") + current_version = Version(diffusers.__version__) + if current_version > required_version: + pytest.skip(f"MusicLDMPipeline is not supported in diffusers version {current_version}") + ms.set_context(mode=mode) pt_components, ms_components = self.get_dummy_components() diff --git a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py index 21e17fbcf9..b7e6aa3749 100644 --- a/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py +++ b/tests/diffusers_tests/pipelines/qwenimage/test_qwenimage_edit.py @@ -150,7 +150,7 @@ class QwenImageEditPipelineFastTests(PipelineTesterMixin, unittest.TestCase): [ "processor", "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", - "transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", + "mindone.transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor", dict( pretrained_model_name_or_path="hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration", trust_remote_code=True,