From f6b5af34bc6626e0071b00e601e9ec7fd80af3f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Fri, 28 Nov 2025 16:18:13 +0800 Subject: [PATCH 1/9] add ovis_image --- scripts/convert_ovis_image_to_diffusers.py | 286 +++++++ src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_ovis_image.py | 542 +++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/ovis_image/__init__.py | 50 ++ .../pipelines/ovis_image/pipeline_output.py | 35 + .../ovis_image/pipeline_ovis_image.py | 737 ++++++++++++++++++ .../dummy_torch_and_transformers_objects.py | 15 + tests/pipelines/ovis_image/__init__.py | 0 11 files changed, 1674 insertions(+) create mode 100644 scripts/convert_ovis_image_to_diffusers.py create mode 100644 src/diffusers/models/transformers/transformer_ovis_image.py create mode 100644 src/diffusers/pipelines/ovis_image/__init__.py create mode 100644 src/diffusers/pipelines/ovis_image/pipeline_output.py create mode 100644 src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py create mode 100644 tests/pipelines/ovis_image/__init__.py diff --git a/scripts/convert_ovis_image_to_diffusers.py b/scripts/convert_ovis_image_to_diffusers.py new file mode 100644 index 000000000000..36edf65ea4fb --- /dev/null +++ b/scripts/convert_ovis_image_to_diffusers.py @@ -0,0 +1,286 @@ +import argparse +from contextlib import nullcontext + +import safetensors.torch +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download + +from diffusers import AutoencoderKL, OvisImageTransformer2DModel +from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers.utils.import_utils import is_accelerate_available + + +""" +# Transformer + +python scripts/convert_ovis_image_to_diffusers.py \ +--original_state_dict_repo_id "AIDC-AI/Ovis-Image-7B" \ +--filename "ovis_image.safetensors" +--output_path "ovis-image" \ +--transformer +""" + + +CTX = init_empty_weights if is_accelerate_available() else nullcontext + +parser = argparse.ArgumentParser() +parser.add_argument("--original_state_dict_repo_id", default=None, type=str) +parser.add_argument("--filename", default="ovis_image.safetensors", type=str) +parser.add_argument("--checkpoint_path", default=None, type=str) +parser.add_argument("--in_channels", type=int, default=64) +parser.add_argument("--out_channels", type=int, default=None) +parser.add_argument("--transformer", action="store_true") +parser.add_argument("--output_path", type=str) +parser.add_argument("--dtype", type=str, default="bf16") + +args = parser.parse_args() +dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32 + + +def load_original_checkpoint(args): + if args.original_state_dict_repo_id is not None: + ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename) + elif args.checkpoint_path is not None: + ckpt_path = args.checkpoint_path + else: + raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`") + + original_state_dict = safetensors.torch.load_file(ckpt_path) + return original_state_dict + + +# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; +# while in diffusers it split into scale, shift. 
Here we swap the linear projection weights in order to be able to use diffusers implementation +def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + +def convert_ovis_image_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, num_single_layers, inner_dim, mlp_ratio=4.0 +): + converted_state_dict = {} + + ## time_text_embed.timestep_embedder <- time_in + converted_state_dict["timestep_embedder.linear_1.weight"] = original_state_dict.pop( + "time_in.in_layer.weight" + ) + converted_state_dict["timestep_embedder.linear_1.bias"] = original_state_dict.pop( + "time_in.in_layer.bias" + ) + converted_state_dict["timestep_embedder.linear_2.weight"] = original_state_dict.pop( + "time_in.out_layer.weight" + ) + converted_state_dict["timestep_embedder.linear_2.bias"] = original_state_dict.pop( + "time_in.out_layer.bias" + ) + + # context_embedder + converted_state_dict["context_embedder_norm.weight"] = original_state_dict.pop( + "semantic_txt_norm.weight" + ) + converted_state_dict["context_embedder.weight"] = original_state_dict.pop( + "semantic_txt_in.weight" + ) + converted_state_dict["context_embedder.bias"] = original_state_dict.pop( + "semantic_txt_in.bias" + ) + + # x_embedder + converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = original_state_dict.pop("img_in.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." + # norms. + ## norm1 + converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mod.lin.bias" + ) + ## norm1_context + converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mod.lin.bias" + ) + # Q, K, V + sample_q, sample_k, sample_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0 + ) + context_q, context_k, context_v = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = 
torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.weight" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = torch.cat( + [ + original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.weight"), + original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.weight") + ], + dim=0, + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = torch.cat( + [ + original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.bias"), + original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.bias") + ], + dim=0, + ) + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.down_proj.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_mlp.down_proj.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = torch.cat( + [ + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.weight"), + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.weight") + ], + dim=0, + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = torch.cat( + [ + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.bias"), + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.bias") + ], + dim=0, + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.down_proj.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_mlp.down_proj.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transformer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." 
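+        # Each original single block keeps its modulation parameters in `modulation.lin` and packs the
+        # attention Q, K, V projections together with the MLP input projection into `linear1`; `linear1`
+        # is split below with sizes (inner_dim, inner_dim, inner_dim, mlp_hidden_dim * 2).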
+ # norm.linear <- single_blocks.0.modulation.lin + converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.weight" + ) + converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop( + f"single_blocks.{i}.modulation.lin.bias" + ) + # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim * 2) + q, k, v, mlp = torch.split(original_state_dict.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + original_state_dict.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.query_norm.weight" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop( + f"single_blocks.{i}.norm.key_norm.weight" + ) + # output projections. + converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.weight" + ) + converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop( + f"single_blocks.{i}.linear2.bias" + ) + + converted_state_dict["proj_out.weight"] = original_state_dict.pop( + "final_layer.linear.weight" + ) + converted_state_dict["proj_out.bias"] = original_state_dict.pop( + "final_layer.linear.bias" + ) + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.weight") + ) + converted_state_dict["norm_out.linear.bias"] = swap_scale_shift( + original_state_dict.pop("final_layer.adaLN_modulation.1.bias") + ) + + return converted_state_dict + + +def main(args): + original_ckpt = load_original_checkpoint(args) + + if args.transformer: + num_layers = 6 + num_single_layers = 27 + inner_dim = 3072 + mlp_ratio = 4.0 + + converted_transformer_state_dict = convert_ovis_image_transformer_checkpoint_to_diffusers( + original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio + ) + transformer = OvisImageTransformer2DModel( + in_channels=args.in_channels, out_channels=args.out_channels + ) + transformer.load_state_dict(converted_transformer_state_dict, strict=True) + + print( + f"Saving Ovis-Image Transformer in Diffusers format." 
+ ) + transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") + + +if __name__ == "__main__": + main(args) \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a81beca9748..f83232136fe8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -240,6 +240,7 @@ "MultiAdapter", "MultiControlNetModel", "OmniGenTransformer2DModel", + "OvisImageTransformer2DModel", "ParallelConfig", "PixArtTransformer2DModel", "PriorTransformer", @@ -533,6 +534,7 @@ "MochiPipeline", "MusicLDMPipeline", "OmniGenPipeline", + "OvisImagePipeline", "PaintByExamplePipeline", "PIAPipeline", "PixArtAlphaPipeline", @@ -959,6 +961,7 @@ MultiAdapter, MultiControlNetModel, OmniGenTransformer2DModel, + OvisImageTransformer2DModel, ParallelConfig, PixArtTransformer2DModel, PriorTransformer, @@ -1221,6 +1224,7 @@ MochiPipeline, MusicLDMPipeline, OmniGenPipeline, + OvisImagePipeline, PaintByExamplePipeline, PIAPipeline, PixArtAlphaPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 09b2b731b5c4..1e3d0cd643df 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -103,6 +103,7 @@ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] + _import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"] _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"] @@ -208,6 +209,7 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, + OvisImageTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, PRXTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index d0794dc321a8..442701123cd8 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -36,6 +36,7 @@ from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel + from .transformer_ovis_image import OvisImageTransformer2DModel from .transformer_prx import PRXTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sana_video import SanaVideoTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py new file mode 100644 index 000000000000..4118f1106ce7 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -0,0 +1,542 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import is_torch_npu_available, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import ( + Timesteps, + TimestepEmbedding, + apply_rotary_emb, + get_1d_rotary_pos_embed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None): + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query = attn.add_q_proj(encoder_hidden_states) + encoder_key = attn.add_k_proj(encoder_hidden_states) + encoder_value = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = (None,) + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections(attn: "OvisImageAttention", hidden_states, encoder_hidden_states=None): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class OvisImageAttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + def __call__( + self, + attn: "OvisImageAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( + attn, hidden_states, encoder_hidden_states + ) + + query = query.unflatten(-1, (attn.heads, -1)) + key = key.unflatten(-1, (attn.heads, -1)) + value = value.unflatten(-1, (attn.heads, -1)) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if attn.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) + + encoder_query = attn.norm_added_q(encoder_query) + encoder_key = attn.norm_added_k(encoder_key) + + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) + key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + +class OvisImageAttention(torch.nn.Module, AttentionModuleMixin): + _default_processor_cls = OvisImageAttnProcessor + _available_processors = [ + OvisImageAttnProcessor, + ] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + elementwise_affine: bool = True, + processor=None, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if out_dim is not None else heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + 
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) + self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) + + +@maybe_allow_in_graph +class OvisImageSingleTransformerBlock(nn.Module): + def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim * 2) + self.act_mlp = nn.SiLU() + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + + self.attn = OvisImageAttention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=OvisImageAttnProcessor(), + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states, mlp_hidden_gate = torch.split( + self.proj_mlp(norm_hidden_states), + [self.mlp_hidden_dim, self.mlp_hidden_dim], + dim=-1 + ) + mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * mlp_hidden_states + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + hidden_states = gate * self.proj_out(hidden_states) + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, 
text_seq_len:] + return encoder_hidden_states, hidden_states + + +@maybe_allow_in_graph +class OvisImageTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = OvisImageAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + processor=OvisImageAttnProcessor(), + eps=eps, + ) + + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( + encoder_hidden_states, emb=temb + ) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
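+        # The context (text) stream below is updated symmetrically to the image stream above, using the
+        # gate/shift/scale modulation produced by `norm1_context`.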
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class OvisImagePosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + n_axes = ids.shape[-1] + cos_out = [] + sin_out = [] + pos = ids.float() + is_mps = ids.device.type == "mps" + is_npu = ids.device.type == "npu" + freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 + for i in range(n_axes): + cos, sin = get_1d_rotary_pos_embed( + self.axes_dim[i], + pos[:, i], + theta=self.theta, + repeat_interleave_real=True, + use_real=True, + freqs_dtype=freqs_dtype, + ) + cos_out.append(cos) + sin_out.append(sin) + freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) + freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) + return freqs_cos, freqs_sin + + +class OvisImageTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, +): + _supports_gradient_checkpointing = True + _no_split_modules = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + + @register_to_config + def __init__( + self, + patch_size: int = 1, + in_channels: int = 64, + out_channels: Optional[int] = 64, + num_layers: int = 6, + num_single_layers: int = 27, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 2048, + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = OvisImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + + self.context_embedder_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + OvisImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + OvisImageSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False 
+ + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + timesteps_proj = self.time_proj(timestep) + temb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + encoder_hidden_states = self.context_embedder_norm(encoder_hidden_states) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + if is_torch_npu_available(): + freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) + image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) + else: + image_rotary_emb = self.pos_embed(ids) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b4043cd146b4..2612306cc412 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -300,6 +300,7 @@ _import_structure["mochi"] = ["MochiPipeline"] _import_structure["musicldm"] = ["MusicLDMPipeline"] _import_structure["omnigen"] = ["OmniGenPipeline"] + _import_structure["ovis_image"] = ["OvisImagePipeline"] _import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"] _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] @@ -717,6 +718,7 @@ from .mochi import MochiPipeline from .musicldm import MusicLDMPipeline from .omnigen import OmniGenPipeline + from .ovis_image import OvisImagePipeline from .pag import ( AnimateDiffPAGPipeline, HunyuanDiTPAGPipeline, diff --git a/src/diffusers/pipelines/ovis_image/__init__.py b/src/diffusers/pipelines/ovis_image/__init__.py new file mode 100644 index 000000000000..69feeefa51d1 --- /dev/null +++ b/src/diffusers/pipelines/ovis_image/__init__.py @@ -0,0 
+1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa: F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_output"] = ["OvisImagePipelineOutput"] + _import_structure["pipeline_ovis_image"] = ["OvisImagePipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_output import OvisImagePipelineOutput + from .pipeline_ovis_image import OvisImagePipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/pipelines/ovis_image/pipeline_output.py b/src/diffusers/pipelines/ovis_image/pipeline_output.py new file mode 100644 index 000000000000..fe45b5ea9f98 --- /dev/null +++ b/src/diffusers/pipelines/ovis_image/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from diffusers.utils import BaseOutput + + +@dataclass +class OvisImagePipelineOutput(BaseOutput): + """ + Output class for Ovis-Image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py new file mode 100644 index 000000000000..b103f123ce8f --- /dev/null +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -0,0 +1,737 @@ +# Copyright 2025 Alibaba Ovis-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import ( + Qwen3Model, + Qwen2TokenizerFast, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import AutoencoderKL, OvisImageTransformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + deprecate, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import OvisImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import OvisImagePipeline + + >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A creative 3D artistic render where the text \"OVIS-IMAGE\" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist's canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail." + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, true_cfg_scale=5.0).images[0] + >>> image.save("ovis_image.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class OvisImagePipeline( + DiffusionPipeline, +): + r""" + The Ovis-Image pipeline for text-to-image generation. + + Reference: https://github.com/AIDC-AI/Ovis-Image + + Args: + transformer ([`OvisImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3Model`]): + Text encoder of class [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model). + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: Qwen3Model, + tokenizer: Qwen2TokenizerFast, + transformer: OvisImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Ovis-Image latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.system_prompt = "Describe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: " + self.user_prompt_begin_id = 28 + self.tokenizer_max_length = 256 + self.user_prompt_begin_id + self.default_sample_size = 128 + + def _get_messages( + self, + prompt: Union[str, List[str]] = None, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + messages = [] + for each_prompt in prompt: + message = [{ + "role": "user", + "content": self.system_prompt + each_prompt, + }] + message = self.tokenizer.apply_chat_template( + message, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + messages.append(message) + return messages + + def _get_ovis_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + messages = self._get_messages(prompt) + batch_size = len(messages) + + tokens = self.tokenizer( + messages, + padding="max_length", + truncation=True, + max_length=self.tokenizer_max_length, + return_tensors="pt", + add_special_tokens=False, + ) + input_ids = tokens.input_ids.to(device) + attention_mask = tokens.attention_mask.to(device) + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = outputs.last_hidden_state + prompt_embeds = prompt_embeds * attention_mask[..., None] + prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id:, :] + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + + Args: + prompt (`str`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
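+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: the prompt embeddings of shape
+            `(batch_size * num_images_per_prompt, seq_len, hidden_dim)` and the `text_ids` position tensor of
+            shape `(seq_len, 3)` that is passed to the transformer's rotary embedding.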
+ """ + device = device or self._execution_device + + if prompt_embeds is None: + prompt_embeds = self._get_ovis_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3) + text_ids[..., 1] = text_ids[..., 1] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids[..., 2] = text_ids[..., 2] + torch.arange(prompt_embeds.shape[1])[None, :] + text_ids = text_ids.to(device=device, dtype=dtype) + return prompt_embeds, text_ids + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 256: + raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 5.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 256, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). 
+ true_cfg_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + `negative_prompt` is provided. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. 
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return OvisImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e6cf26a12544..13f86ad2ef5b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1922,6 +1922,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class OvisImagePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class PaintByExamplePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] 
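For readers following the packing logic in `pipeline_ovis_image.py`, the sketch below round-trips a toy latent through the same 2×2 patch packing that `_pack_latents` and `_unpack_latents` implement, and applies the true-CFG combination used in the denoising loop. It is a standalone illustration, not part of the patch; the concrete sizes (1 image, 16 latent channels, a 1024×1024 output, `vae_scale_factor=8`) are assumptions chosen to line up with the conversion script's `in_channels=64` default, not values read from the released checkpoint.

```python
import torch

# Illustrative sizes (assumed): 16 latent channels matches in_channels // 4 = 64 // 4.
batch_size, num_channels_latents = 1, 16
height, width, vae_scale_factor = 1024, 1024, 8

# Latent grid after 8x VAE compression, rounded so it is divisible by 2 for packing.
lat_h = 2 * (height // (vae_scale_factor * 2))
lat_w = 2 * (width // (vae_scale_factor * 2))
latents = torch.randn(batch_size, num_channels_latents, lat_h, lat_w)

# Pack: fold every 2x2 spatial patch into the channel dimension -> (B, (H/2)*(W/2), C*4).
packed = latents.view(batch_size, num_channels_latents, lat_h // 2, 2, lat_w // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5)
packed = packed.reshape(batch_size, (lat_h // 2) * (lat_w // 2), num_channels_latents * 4)
assert packed.shape == (1, 4096, 64)  # 64 is the transformer's packed in_channels

# Unpack: invert the permutation to recover the (B, C, H, W) grid the VAE decodes.
unpacked = packed.view(batch_size, lat_h // 2, lat_w // 2, num_channels_latents, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
unpacked = unpacked.reshape(batch_size, num_channels_latents, lat_h, lat_w)
assert torch.equal(unpacked, latents)

# True classifier-free guidance as in the loop: extrapolate away from the negative branch.
noise_pred = torch.randn_like(packed)
neg_noise_pred = torch.randn_like(packed)
true_cfg_scale = 5.0
guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
```

Because packing is a pure reshape/permute, it is exactly invertible, which is why the pipeline can defer unpacking until just before `vae.decode`.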
diff --git a/tests/pipelines/ovis_image/__init__.py b/tests/pipelines/ovis_image/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From e5d7e93f85055159f54c65398c18c3dd0554fcbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 09:59:22 +0800 Subject: [PATCH 2/9] fix code quality --- scripts/convert_ovis_image_to_diffusers.py | 7 +- .../transformers/transformer_ovis_image.py | 92 ++++++++++++------- .../pipelines/ovis_image/__init__.py | 2 +- .../pipelines/ovis_image/pipeline_output.py | 2 +- .../ovis_image/pipeline_ovis_image.py | 10 +- 5 files changed, 69 insertions(+), 44 deletions(-) diff --git a/scripts/convert_ovis_image_to_diffusers.py b/scripts/convert_ovis_image_to_diffusers.py index 36edf65ea4fb..66abfd945421 100644 --- a/scripts/convert_ovis_image_to_diffusers.py +++ b/scripts/convert_ovis_image_to_diffusers.py @@ -6,8 +6,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download -from diffusers import AutoencoderKL, OvisImageTransformer2DModel -from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint +from diffusers import OvisImageTransformer2DModel from diffusers.utils.import_utils import is_accelerate_available @@ -277,10 +276,10 @@ def main(args): transformer.load_state_dict(converted_transformer_state_dict, strict=True) print( - f"Saving Ovis-Image Transformer in Diffusers format." + "Saving Ovis-Image Transformer in Diffusers format." ) transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") if __name__ == "__main__": - main(args) \ No newline at end of file + main(args) diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 4118f1106ce7..55afabb4b3c6 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -60,7 +59,8 @@ def _get_fused_projections(attn: "OvisImageAttention", hidden_states, encoder_hi encoder_query = encoder_key = encoder_value = (None,) if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) + encoder_query, encoder_key, encoder_value = attn.to_added_qkv( + encoder_hidden_states).chunk(3, dim=-1) return query, key, value, encoder_query, encoder_key, encoder_value @@ -77,7 +77,8 @@ class OvisImageAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + raise ImportError( + f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") def __call__( self, @@ -138,7 +139,6 @@ def __call__( return hidden_states - class OvisImageAttention(torch.nn.Module, AttentionModuleMixin): _default_processor_cls = OvisImageAttnProcessor _available_processors = [ @@ -176,24 +176,31 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias - self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_q = torch.nn.RMSNorm( + dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm( + dim_head, eps=eps, elementwise_affine=elementwise_affine) self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.pre_only: self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Linear( + self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(torch.nn.Dropout(dropout)) if added_kv_proj_dim is not None: self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) - self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) + self.add_q_proj = torch.nn.Linear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear( + self.inner_dim, query_dim, bias=out_bias) if processor is None: processor = self._default_processor_cls() @@ -207,9 +214,11 @@ def forward( image_rotary_emb: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + attn_parameters = set(inspect.signature( + self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} - unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + unused_kwargs = [k for k, _ in kwargs.items( + ) if k not in attn_parameters and k not in quiet_attn_parameters] if len(unused_kwargs) > 0: logger.warning( f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
@@ -249,13 +258,14 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: text_seq_len = encoder_hidden_states.shape[1] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = torch.cat( + [encoder_hidden_states, hidden_states], dim=1) residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states, mlp_hidden_gate = torch.split( - self.proj_mlp(norm_hidden_states), - [self.mlp_hidden_dim, self.mlp_hidden_dim], + self.proj_mlp(norm_hidden_states), + [self.mlp_hidden_dim, self.mlp_hidden_dim], dim=-1 ) mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * mlp_hidden_states @@ -273,7 +283,8 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + encoder_hidden_states, hidden_states = hidden_states[:, + :text_seq_len], hidden_states[:, text_seq_len:] return encoder_hidden_states, hidden_states @@ -302,8 +313,10 @@ def __init__( self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") + self.norm2_context = nn.LayerNorm( + dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward( + dim=dim, dim_out=dim, activation_fn="swiglu") def forward( self, @@ -313,7 +326,8 @@ def forward( image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb @@ -338,7 +352,8 @@ def forward( hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = norm_hidden_states * \ + (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output @@ -352,10 +367,12 @@ def forward( encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * \ + (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = encoder_hidden_states + \ + c_gate_mlp.unsqueeze(1) * context_ff_output if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) @@ -400,9 +417,11 @@ class OvisImageTransformer2DModel( CacheMixin, ): _supports_gradient_checkpointing = True - _no_split_modules = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + _no_split_modules = ["OvisImageTransformerBlock", + 
"OvisImageSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] - _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] + _repeated_blocks = ["OvisImageTransformerBlock", + "OvisImageSingleTransformerBlock"] @register_to_config def __init__( @@ -421,10 +440,13 @@ def __init__( self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = OvisImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) + self.pos_embed = OvisImagePosEmbed( + theta=10000, axes_dim=axes_dims_rope) - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=self.inner_dim) self.context_embedder_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) @@ -452,8 +474,10 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False @@ -466,15 +490,17 @@ def forward( txt_ids: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 timesteps_proj = self.time_proj(timestep) - temb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + temb = self.timestep_embedder( + timesteps_proj.to(dtype=hidden_states.dtype)) - encoder_hidden_states = self.context_embedder_norm(encoder_hidden_states) + encoder_hidden_states = self.context_embedder_norm( + encoder_hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: diff --git a/src/diffusers/pipelines/ovis_image/__init__.py b/src/diffusers/pipelines/ovis_image/__init__.py index 69feeefa51d1..275061b1f626 100644 --- a/src/diffusers/pipelines/ovis_image/__init__.py +++ b/src/diffusers/pipelines/ovis_image/__init__.py @@ -47,4 +47,4 @@ ) for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) \ No newline at end of file + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ovis_image/pipeline_output.py b/src/diffusers/pipelines/ovis_image/pipeline_output.py index fe45b5ea9f98..160c5b73a917 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_output.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_output.py @@ -32,4 +32,4 @@ class OvisImagePipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. 
""" - images: Union[List[PIL.Image.Image], np.ndarray] \ No newline at end of file + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py index b103f123ce8f..89e42a90cc36 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -22,7 +22,7 @@ Qwen2TokenizerFast, ) -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, OvisImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( @@ -220,11 +220,11 @@ def _get_ovis_prompt_embeds( batch_size = len(messages) tokens = self.tokenizer( - messages, - padding="max_length", + messages, + padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, - return_tensors="pt", + max_length=self.tokenizer_max_length, + return_tensors="pt", add_special_tokens=False, ) input_ids = tokens.input_ids.to(device) From 04c89da6d7be3c9a3c96c9420eed528fe430542f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 10:10:31 +0800 Subject: [PATCH 3/9] optimize pipeline_ovis_image.py according to the feedbacks --- .../ovis_image/pipeline_ovis_image.py | 81 +++---------------- 1 file changed, 10 insertions(+), 71 deletions(-) diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py index 89e42a90cc36..6393ee8d0e99 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -55,7 +55,7 @@ >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") >>> prompt = "A creative 3D artistic render where the text \"OVIS-IMAGE\" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist's canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail." - >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, true_cfg_scale=5.0).images[0] + >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, guidance_scale=5.0).images[0] >>> image.save("ovis_image.png") ``` """ @@ -362,59 +362,6 @@ def _unpack_latents(latents, height, width, vae_scale_factor): return latents - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. 
- """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." - deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - def prepare_latents( self, batch_size, @@ -475,8 +422,8 @@ def interrupt(self): def __call__( self, prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - true_cfg_scale: float = 5.0, + negative_prompt: Union[str, List[str]] = "", + guidance_scale: float = 5.0, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -502,10 +449,10 @@ def __call__( instead. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is not greater than `1`). - true_cfg_scale (`float`, *optional*, defaults to 1.0): - True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and + guidance_scale (`float`, *optional*, defaults to 1.0): + True classifier-free guidance (guidance scale) is enabled when `guidance_scale` > 1 and `negative_prompt` is provided. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. @@ -591,10 +538,7 @@ def __call__( device = self._execution_device - has_neg_prompt = negative_prompt is not None or ( - negative_prompt_embeds is not None - ) - do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + do_classifier_free_guidance = guidance_scale > 1 ( prompt_embeds, text_ids, @@ -604,7 +548,7 @@ def __call__( device=device, num_images_per_prompt=num_images_per_prompt, ) - if do_true_cfg: + if do_classifier_free_guidance: ( negative_prompt_embeds, negative_text_ids, @@ -653,9 +597,6 @@ def __call__( if self.joint_attention_kwargs is None: self._joint_attention_kwargs = {} - image_embeds = None - negative_image_embeds = None - # 6. Denoising loop # We set the index here to remove DtoH sync, helpful especially during compilation. 
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696 @@ -666,8 +607,6 @@ def __call__( continue self._current_timestep = t - if image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -681,7 +620,7 @@ def __call__( return_dict=False, )[0] - if do_true_cfg: + if do_classifier_free_guidance: with self.transformer.cache_context("uncond"): neg_noise_pred = self.transformer( hidden_states=latents, @@ -691,7 +630,7 @@ def __call__( img_ids=latent_image_ids, return_dict=False, )[0] - noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From cc1693e5438998a28f37d5cc252f95ed2500880a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 16:01:34 +0800 Subject: [PATCH 4/9] optimize imports --- .../models/transformers/transformer_ovis_image.py | 12 ++++-------- .../pipelines/ovis_image/pipeline_ovis_image.py | 13 ++----------- 2 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 55afabb4b3c6..4b300816534a 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -26,16 +26,12 @@ from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import ( - Timesteps, - TimestepEmbedding, - apply_rotary_emb, - get_1d_rotary_pos_embed, -) +from ..embeddings import (TimestepEmbedding, Timesteps, apply_rotary_emb, + get_1d_rotary_pos_embed) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle - +from ..normalization import (AdaLayerNormContinuous, AdaLayerNormZero, + AdaLayerNormZeroSingle) logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py index 6393ee8d0e99..1983ebe2fa17 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -17,25 +17,16 @@ import numpy as np import torch -from transformers import ( - Qwen3Model, - Qwen2TokenizerFast, -) +from transformers import Qwen2TokenizerFast, Qwen3Model from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, OvisImageTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import ( - deprecate, - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import OvisImagePipelineOutput - if is_torch_xla_available(): import torch_xla.core.xla_model as xm From 311743206a0b8c5e1abb1d29812224675864670b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 17:26:55 +0800 Subject: 
[PATCH 5/9] add docs --- .../en/api/models/ovisimage_transformer2d.md | 24 +++++++++ docs/source/en/api/pipelines/ovis_image.md | 50 +++++++++++++++++++ .../transformers/transformer_ovis_image.py | 50 ++++++++++++++++++- 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/models/ovisimage_transformer2d.md create mode 100644 docs/source/en/api/pipelines/ovis_image.md diff --git a/docs/source/en/api/models/ovisimage_transformer2d.md b/docs/source/en/api/models/ovisimage_transformer2d.md new file mode 100644 index 000000000000..484652404af3 --- /dev/null +++ b/docs/source/en/api/models/ovisimage_transformer2d.md @@ -0,0 +1,24 @@ + + +# OvisImageTransformer2DModel + +The model can be loaded with the following code snippet. + +```python +from diffusers import OvisImageTransformer2DModel + +transformer = OvisImageTransformer2DModel.from_pretrained("AIDC-AI/Ovis-Image-7B", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## OvisImageTransformer2DModel + +[[autodoc]] OvisImageTransformer2DModel diff --git a/docs/source/en/api/pipelines/ovis_image.md b/docs/source/en/api/pipelines/ovis_image.md new file mode 100644 index 000000000000..e03889b0a020 --- /dev/null +++ b/docs/source/en/api/pipelines/ovis_image.md @@ -0,0 +1,50 @@ + + +# Ovis-Image + +![concepts](https://github.com/AIDC-AI/Ovis-Image/blob/main/docs/imgs/ovis_image_case.png) + +Ovis-Image is a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints. + +[Ovis-Image Technical Report](https://arxiv.org/abs/2511.22982) from Alibaba Group, by Guo-Hua Wang, Liangfu Cao, Tianyu Cui, Minghao Fu, Xiaohao Chen, Pengxin Zhan, Jianshan Zhao, Lan Li, Bowen Fu, Jiaqi Liu, Qing-Guo Chen. + +The abstract from the paper is: + +*We introduce Ovis-Image, a 7B text-to-image model specifically optimized for high-quality text rendering, designed to operate efficiently under stringent computational constraints. Built upon our previous Ovis-U1 framework, Ovis-Image integrates a diffusion-based visual decoder with the stronger Ovis 2.5 multimodal backbone, leveraging a text-centric training pipeline that combines large-scale pre-training with carefully tailored post-training refinements. Despite its compact architecture, Ovis-Image achieves text rendering performance on par with significantly larger open models such as Qwen-Image and approaches closed-source systems like Seedream and GPT4o. Crucially, the model remains deployable on a single high-end GPU with moderate memory, narrowing the gap between frontier-level text rendering and practical deployment. Our results indicate that combining a strong multimodal backbone with a carefully designed, text-focused training recipe is sufficient to achieve reliable bilingual text rendering without resorting to oversized or proprietary models.* + +**Highlights**: + +* **Strong text rendering at a compact 7B scale**: Ovis-Image is a 7B text-to-image model that delivers text rendering quality comparable to much larger 20B-class systems such as Qwen-Image and competitive with leading closed-source models like GPT4o in text-centric scenarios, while remaining small enough to run on widely accessible hardware. 
+* **High fidelity on text-heavy, layout-sensitive prompts**: The model excels on prompts that demand tight alignment between linguistic content and rendered typography (e.g., posters, banners, logos, UI mockups, infographics), producing legible, correctly spelled, and semantically consistent text across diverse fonts, sizes, and aspect ratios without compromising overall visual quality. +* **Efficiency and deployability**: With its 7B parameter budget and streamlined architecture, Ovis-Image fits on a single high-end GPU with moderate memory, supports low-latency interactive use, and scales to batch production serving, bringing near–frontier text rendering to applications where tens-of-billions–parameter models are impractical. + + +This pipeline was contributed by Ovis-Image Team. The original codebase can be found [here](https://github.com/AIDC-AI/Ovis-Image). + +Available models: + +| Model | Recommended dtype | +|:-----:|:-----------------:| +| [`AIDC-AI/Ovis-Image-7B`](https://huggingface.co/AIDC-AI/Ovis-Image-7B) | `torch.bfloat16` | + +Refer to [this](https://huggingface.co/collections/AIDC-AI/ovis-image) collection for more information. + +## OvisImagePipeline + +[[autodoc]] OvisImagePipeline + - all + - __call__ + +## OvisImagePipelineOutput + +[[autodoc]] pipelines.ovis_image.pipeline_output.OvisImagePipelineOutput diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 4b300816534a..336cf4471434 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -412,6 +412,33 @@ class OvisImageTransformer2DModel( FromOriginalModelMixin, CacheMixin, ): + """ + The Transformer model introduced in Ovis-Image. + + Reference: https://github.com/AIDC-AI/Ovis-Image + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `6`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `27`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `2048`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). + axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + _supports_gradient_checkpointing = True _no_split_modules = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] @@ -486,7 +513,28 @@ def forward( txt_ids: torch.Tensor = None, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: - + """ + The [`OvisImageTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 
+ timestep (`torch.LongTensor`): + Used to indicate denoising step. + img_ids: (`torch.Tensor`): + The position ids for image tokens. + txt_ids (`torch.Tensor`): + The position ids for text tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 From da6db7b7e232cf8286d83f2158d1214799e2ab87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 17:33:38 +0800 Subject: [PATCH 6/9] make style --- scripts/convert_ovis_image_to_diffusers.py | 52 +++------- .../transformers/transformer_ovis_image.py | 95 +++++++------------ .../ovis_image/pipeline_ovis_image.py | 18 ++-- src/diffusers/utils/dummy_pt_objects.py | 15 +++ 4 files changed, 71 insertions(+), 109 deletions(-) diff --git a/scripts/convert_ovis_image_to_diffusers.py b/scripts/convert_ovis_image_to_diffusers.py index 66abfd945421..0d3d9cd44bf6 100644 --- a/scripts/convert_ovis_image_to_diffusers.py +++ b/scripts/convert_ovis_image_to_diffusers.py @@ -63,29 +63,15 @@ def convert_ovis_image_transformer_checkpoint_to_diffusers( converted_state_dict = {} ## time_text_embed.timestep_embedder <- time_in - converted_state_dict["timestep_embedder.linear_1.weight"] = original_state_dict.pop( - "time_in.in_layer.weight" - ) - converted_state_dict["timestep_embedder.linear_1.bias"] = original_state_dict.pop( - "time_in.in_layer.bias" - ) - converted_state_dict["timestep_embedder.linear_2.weight"] = original_state_dict.pop( - "time_in.out_layer.weight" - ) - converted_state_dict["timestep_embedder.linear_2.bias"] = original_state_dict.pop( - "time_in.out_layer.bias" - ) + converted_state_dict["timestep_embedder.linear_1.weight"] = original_state_dict.pop("time_in.in_layer.weight") + converted_state_dict["timestep_embedder.linear_1.bias"] = original_state_dict.pop("time_in.in_layer.bias") + converted_state_dict["timestep_embedder.linear_2.weight"] = original_state_dict.pop("time_in.out_layer.weight") + converted_state_dict["timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.out_layer.bias") # context_embedder - converted_state_dict["context_embedder_norm.weight"] = original_state_dict.pop( - "semantic_txt_norm.weight" - ) - converted_state_dict["context_embedder.weight"] = original_state_dict.pop( - "semantic_txt_in.weight" - ) - converted_state_dict["context_embedder.bias"] = original_state_dict.pop( - "semantic_txt_in.bias" - ) + converted_state_dict["context_embedder_norm.weight"] = original_state_dict.pop("semantic_txt_norm.weight") + converted_state_dict["context_embedder.weight"] = original_state_dict.pop("semantic_txt_in.weight") + converted_state_dict["context_embedder.bias"] = original_state_dict.pop("semantic_txt_in.bias") # x_embedder converted_state_dict["x_embedder.weight"] = original_state_dict.pop("img_in.weight") @@ -151,14 +137,14 @@ def convert_ovis_image_transformer_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = torch.cat( [ original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.weight"), - original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.weight") + original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.weight"), ], dim=0, ) 
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = torch.cat( [ original_state_dict.pop(f"double_blocks.{i}.img_mlp.up_proj.bias"), - original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.bias") + original_state_dict.pop(f"double_blocks.{i}.img_mlp.gate_proj.bias"), ], dim=0, ) @@ -171,14 +157,14 @@ def convert_ovis_image_transformer_checkpoint_to_diffusers( converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = torch.cat( [ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.weight"), - original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.weight") + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.weight"), ], dim=0, ) converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = torch.cat( [ original_state_dict.pop(f"double_blocks.{i}.txt_mlp.up_proj.bias"), - original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.bias") + original_state_dict.pop(f"double_blocks.{i}.txt_mlp.gate_proj.bias"), ], dim=0, ) @@ -242,12 +228,8 @@ def convert_ovis_image_transformer_checkpoint_to_diffusers( f"single_blocks.{i}.linear2.bias" ) - converted_state_dict["proj_out.weight"] = original_state_dict.pop( - "final_layer.linear.weight" - ) - converted_state_dict["proj_out.bias"] = original_state_dict.pop( - "final_layer.linear.bias" - ) + converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias") converted_state_dict["norm_out.linear.weight"] = swap_scale_shift( original_state_dict.pop("final_layer.adaLN_modulation.1.weight") ) @@ -270,14 +252,10 @@ def main(args): converted_transformer_state_dict = convert_ovis_image_transformer_checkpoint_to_diffusers( original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio ) - transformer = OvisImageTransformer2DModel( - in_channels=args.in_channels, out_channels=args.out_channels - ) + transformer = OvisImageTransformer2DModel(in_channels=args.in_channels, out_channels=args.out_channels) transformer.load_state_dict(converted_transformer_state_dict, strict=True) - print( - "Saving Ovis-Image Transformer in Diffusers format." 
- ) + print("Saving Ovis-Image Transformer in Diffusers format.") transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer") diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 336cf4471434..0a09aa720b3f 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -26,12 +26,11 @@ from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import (TimestepEmbedding, Timesteps, apply_rotary_emb, - get_1d_rotary_pos_embed) +from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import (AdaLayerNormContinuous, AdaLayerNormZero, - AdaLayerNormZeroSingle) +from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -55,8 +54,7 @@ def _get_fused_projections(attn: "OvisImageAttention", hidden_states, encoder_hi encoder_query = encoder_key = encoder_value = (None,) if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): - encoder_query, encoder_key, encoder_value = attn.to_added_qkv( - encoder_hidden_states).chunk(3, dim=-1) + encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) return query, key, value, encoder_query, encoder_key, encoder_value @@ -73,8 +71,7 @@ class OvisImageAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") def __call__( self, @@ -172,31 +169,24 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim self.added_proj_bias = added_proj_bias - self.norm_q = torch.nn.RMSNorm( - dim_head, eps=eps, elementwise_affine=elementwise_affine) - self.norm_k = torch.nn.RMSNorm( - dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) + self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) if not self.pre_only: self.to_out = torch.nn.ModuleList([]) - self.to_out.append(torch.nn.Linear( - self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) self.to_out.append(torch.nn.Dropout(dropout)) if added_kv_proj_dim is not None: self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) - self.add_q_proj = torch.nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_k_proj = torch.nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.add_v_proj = torch.nn.Linear( - added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - self.to_add_out = torch.nn.Linear( - self.inner_dim, query_dim, bias=out_bias) + self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) if processor is None: processor = self._default_processor_cls() @@ -210,11 +200,9 @@ def forward( image_rotary_emb: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: - attn_parameters = set(inspect.signature( - self.processor.__call__).parameters.keys()) + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} - unused_kwargs = [k for k, _ in kwargs.items( - ) if k not in attn_parameters and k not in quiet_attn_parameters] + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] if len(unused_kwargs) > 0: logger.warning( f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
@@ -254,15 +242,12 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: text_seq_len = encoder_hidden_states.shape[1] - hidden_states = torch.cat( - [encoder_hidden_states, hidden_states], dim=1) + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states, mlp_hidden_gate = torch.split( - self.proj_mlp(norm_hidden_states), - [self.mlp_hidden_dim, self.mlp_hidden_dim], - dim=-1 + self.proj_mlp(norm_hidden_states), [self.mlp_hidden_dim, self.mlp_hidden_dim], dim=-1 ) mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * mlp_hidden_states joint_attention_kwargs = joint_attention_kwargs or {} @@ -279,8 +264,7 @@ def forward( if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - encoder_hidden_states, hidden_states = hidden_states[:, - :text_seq_len], hidden_states[:, text_seq_len:] + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] return encoder_hidden_states, hidden_states @@ -309,10 +293,8 @@ def __init__( self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") - self.norm2_context = nn.LayerNorm( - dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward( - dim=dim, dim_out=dim, activation_fn="swiglu") + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) + self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="swiglu") def forward( self, @@ -322,8 +304,7 @@ def forward( image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, emb=temb) + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( encoder_hidden_states, emb=temb @@ -348,8 +329,7 @@ def forward( hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * \ - (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output @@ -363,12 +343,10 @@ def forward( encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * \ - (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + \ - c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) @@ -440,11 +418,9 @@ class OvisImageTransformer2DModel( """ _supports_gradient_checkpointing = True - _no_split_modules = ["OvisImageTransformerBlock", - "OvisImageSingleTransformerBlock"] + 
_no_split_modules = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] _skip_layerwise_casting_patterns = ["pos_embed", "norm"] - _repeated_blocks = ["OvisImageTransformerBlock", - "OvisImageSingleTransformerBlock"] + _repeated_blocks = ["OvisImageTransformerBlock", "OvisImageSingleTransformerBlock"] @register_to_config def __init__( @@ -463,13 +439,10 @@ def __init__( self.out_channels = out_channels or in_channels self.inner_dim = num_attention_heads * attention_head_dim - self.pos_embed = OvisImagePosEmbed( - theta=10000, axes_dim=axes_dims_rope) + self.pos_embed = OvisImagePosEmbed(theta=10000, axes_dim=axes_dims_rope) - self.time_proj = Timesteps( - num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding( - in_channels=256, time_embed_dim=self.inner_dim) + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim) self.context_embedder_norm = nn.RMSNorm(joint_attention_dim, eps=1e-6) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) @@ -497,10 +470,8 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous( - self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear( - self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False @@ -540,11 +511,9 @@ def forward( timestep = timestep.to(hidden_states.dtype) * 1000 timesteps_proj = self.time_proj(timestep) - temb = self.timestep_embedder( - timesteps_proj.to(dtype=hidden_states.dtype)) + temb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) - encoder_hidden_states = self.context_embedder_norm( - encoder_hidden_states) + encoder_hidden_states = self.context_embedder_norm(encoder_hidden_states) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py index 1983ebe2fa17..c11811951fea 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -27,6 +27,7 @@ from ..pipeline_utils import DiffusionPipeline from .pipeline_output import OvisImagePipelineOutput + if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -184,15 +185,14 @@ def _get_messages( prompt = [prompt] if isinstance(prompt, str) else prompt messages = [] for each_prompt in prompt: - message = [{ - "role": "user", - "content": self.system_prompt + each_prompt, - }] + message = [ + { + "role": "user", + "content": self.system_prompt + each_prompt, + } + ] message = self.tokenizer.apply_chat_template( - message, - tokenize=False, - add_generation_prompt=True, - enable_thinking=False + message, tokenize=False, add_generation_prompt=True, enable_thinking=False ) messages.append(message) return messages @@ -226,7 +226,7 @@ def _get_ovis_prompt_embeds( ) prompt_embeds = outputs.last_hidden_state prompt_embeds = prompt_embeds * attention_mask[..., None] - prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id:, :] + prompt_embeds = prompt_embeds[:, self.user_prompt_begin_id :, :] _, seq_len, _ = 
prompt_embeds.shape diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index fe9a4b30f0c1..6be7618fcd5e 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1248,6 +1248,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class OvisImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ParallelConfig(metaclass=DummyObject): _backends = ["torch"] From f5c79c580ad284913f72c44b78ba24111b22bace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=82=9F=E7=94=BB?= Date: Tue, 2 Dec 2025 18:04:34 +0800 Subject: [PATCH 7/9] make style --- .../pipelines/ovis_image/pipeline_ovis_image.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py index c11811951fea..94d6cee93d7e 100644 --- a/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py +++ b/src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py @@ -46,7 +46,7 @@ >>> pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16) >>> pipe.to("cuda") - >>> prompt = "A creative 3D artistic render where the text \"OVIS-IMAGE\" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist's canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail." + >>> prompt = 'A creative 3D artistic render where the text "OVIS-IMAGE" is written in a bold, expressive handwritten brush style using thick, wet oil paint. The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste or impasto art. You can see the ridges of the brush bristles and the glossy, wet texture of the paint. The background is a clean artist\'s canvas. Dynamic lighting creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile texture, 4k detail.' >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50, guidance_scale=5.0).images[0] >>> image.save("ovis_image.png") ``` @@ -142,7 +142,8 @@ class OvisImagePipeline( vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`Qwen3Model`]): - Text encoder of class [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model). + Text encoder of class + [Qwen3Model](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3Model). tokenizer (`Qwen2TokenizerFast`): Tokenizer of class [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). @@ -495,9 +496,9 @@ def __call__( Examples: Returns: - [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` - is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated - images. + [`~pipelines.ovis_image.OvisImagePipelineOutput`] or `tuple`: + [`~pipelines.ovis_image.OvisImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. """ height = height or self.default_sample_size * self.vae_scale_factor From 11a5aca7a7942cd6ca78292a22a53bf40a671386 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 2 Dec 2025 22:18:53 +0100 Subject: [PATCH 8/9] add ovis to toctree --- docs/source/en/_toctree.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d2b4a0de915b..f06ba3251142 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -375,6 +375,8 @@ title: MochiTransformer3DModel - local: api/models/omnigen_transformer title: OmniGenTransformer2DModel + - local: api/models/ovisimage_transformer_2d + title: OvisImageTransformer2DModel - local: api/models/pixart_transformer2d title: PixArtTransformer2DModel - local: api/models/prior_transformer @@ -567,6 +569,8 @@ title: MultiDiffusion - local: api/pipelines/omnigen title: OmniGen + - local: api/pipelines/ovis_image + title: Ovis-Image - local: api/pipelines/pag title: PAG - local: api/pipelines/paint_by_example From 2a361252a293d1d4ecebf0e0832b9000985354b8 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 2 Dec 2025 22:26:28 +0100 Subject: [PATCH 9/9] oops --- docs/source/en/_toctree.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f06ba3251142..79299d11dae9 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -375,7 +375,7 @@ title: MochiTransformer3DModel - local: api/models/omnigen_transformer title: OmniGenTransformer2DModel - - local: api/models/ovisimage_transformer_2d + - local: api/models/ovisimage_transformer2d title: OvisImageTransformer2DModel - local: api/models/pixart_transformer2d title: PixArtTransformer2DModel
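For anyone trying this branch locally, below is a minimal usage sketch distilled from the docstring example touched in PATCH 7/9. It is not part of the patch itself: it assumes the series is installed, that the `AIDC-AI/Ovis-Image-7B` repo follows the standard diffusers layout expected by `OvisImagePipeline.from_pretrained`, and that the pipeline accepts the usual `generator` argument — the checkpoint id and the other call arguments mirror the example in `pipeline_ovis_image.py`, while the seed handling is an assumption based on the common diffusers convention.

```python
# Minimal sketch, assuming this patch series is installed.
# Checkpoint id, negative_prompt, num_inference_steps, and guidance_scale
# follow the docstring example in pipeline_ovis_image.py; the `generator`
# argument is assumed from the usual diffusers pipeline convention and is
# not verified against the final signature.
import torch

from diffusers import OvisImagePipeline

pipe = OvisImagePipeline.from_pretrained("AIDC-AI/Ovis-Image-7B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = 'A glossy 3D render of the text "OVIS-IMAGE" painted in thick rainbow oil paint on a clean canvas.'
image = pipe(
    prompt,
    negative_prompt="",
    num_inference_steps=50,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(0),  # assumed: standard seeded generation
).images[0]
image.save("ovis_image.png")
```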