71 changes: 71 additions & 0 deletions diffsynth/classifiers/open_logo.py
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms
Comment on lines +3 to +5
medium

The imports torchvision.transforms.functional as F and from torchvision.ops import nms are not used in this file. They should be removed to keep the code clean.

Suggested change
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms
from torchvision.models.detection import fasterrcnn_vgg16_bn


class OpenLogoClassifier(nn.Module):
    def __init__(self, model_path, device="cuda"):
        super().__init__()
        self.device = device
        # Load the pretrained Faster R-CNN model with VGG16 backbone
        self.model = fasterrcnn_vgg16_bn(pretrained=False)  # We will load our own weights

        # Load the state dict from the provided .pt file
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image):
        # Preprocess the image for the Faster R-CNN model
        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1]
        image = image.to(self.device)
        return image

    def compute_gradient(self, latents, pipeline, timestep, class_id, mask=None, guidance_scale=1.0):
        with torch.enable_grad():
            latents = latents.detach().requires_grad_(True)

            # Decode latents to video
            video = pipeline.vae.decode(latents)

            # The output of the VAE is in the range [-1, 1], we need to normalize it to [0, 1]
            video = (video / 2 + 0.5).clamp(0, 1)

            log_prob_sum = 0

            num_frames = video.shape[2]
            for i in range(num_frames):
                image = video[:, :, i, :, :]

                # Preprocess the image for the classifier
                image_processed = self.preprocess_image(image)

                # Get predictions from the classifier
                self.model.train()

                image_for_model = image_processed
                features = self.model.backbone(image_for_model)
                proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
                box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
                box_features = self.model.roi_heads.box_head(box_features)
                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
                log_prob_target = log_probs[:, class_id].mean()
                log_prob_sum = log_prob_sum + log_prob_target

            grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

            # Apply the soft logo mask if provided
            if mask is not None:
                grad = grad * mask

            # Clamp the gradient to avoid artifacts
            grad = torch.clamp(grad, -0.1, 0.1)

            self.model.eval()
Comment on lines +47 to +69

critical

Calling self.model.train() inside the compute_gradient method is incorrect. This will put the model, including its BatchNorm layers, into training mode. During guidance, the model should behave as it does during inference, which means BatchNorm layers should use their learned running statistics, not the statistics of the current mini-batch. Calling .train() will cause them to use batch statistics and update their running statistics, which is not desired and will lead to incorrect gradients and potentially artifacts in the generated video. The model is already set to eval() mode in the __init__ method, so these calls are not only incorrect but also unnecessary.

                # Get predictions from the classifier

                image_for_model = image_processed
                features = self.model.backbone(image_for_model)
                proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
                box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
                box_features = self.model.roi_heads.box_head(box_features)
                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
                log_prob_target = log_probs[:, class_id].mean()
                log_prob_sum = log_prob_sum + log_prob_target

            grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

            # Apply the soft logo mask if provided
            if mask is not None:
                grad = grad * mask

            # Clamp the gradient to avoid artifacts
            grad = torch.clamp(grad, -0.1, 0.1)


        return grad
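A self-contained sketch of why the mode matters here: in train mode BatchNorm normalizes with the current batch's statistics and updates its running estimates as a side effect, while eval mode uses the stored running statistics. Nothing below is from this PR; it only demonstrates the behavior the comment describes.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(2, 3, 8, 8)

bn.train()
y_train = bn(x)  # uses this batch's mean/var; running_mean and running_var get updated

bn.eval()
y_eval = bn(x)   # uses the stored running statistics

print(torch.allclose(y_train, y_eval))  # False in general: the two modes disagree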
57 changes: 57 additions & 0 deletions diffsynth/models/logo_detector.py
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import GaussianBlur

medium

The import from torchvision.transforms import GaussianBlur is not used. F.gaussian_blur is used instead, which comes from torchvision.transforms.functional. This unused import should be removed.


class LogoDetector(nn.Module):
    def __init__(self, model_path=None, device="cuda"):
        super().__init__()
        self.device = device
        if model_path:
            # Load a custom model if a path is provided
            # For now, we assume the model is a Faster R-CNN model
            self.model = torch.load(model_path, map_location=self.device)

medium

Using torch.load(model_path) to load an entire model is not recommended because the serialized data is bound to the specific classes and directory structure used when the model was saved. The recommended approach is to save and load only the model's state_dict. This makes the code more robust to refactoring.

Suggested change
self.model = torch.load(model_path, map_location=self.device)
self.model = fasterrcnn_resnet50_fpn(pretrained=False) # Or your custom model class
state_dict = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(state_dict)
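As a companion to this suggestion, a hedged sketch of the matching save side; the filename and the model instance are illustrative, not part of this PR:

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None)       # stand-in for a fine-tuned detector
torch.save(model.state_dict(), "logo_detector.pt")  # tensors only, no pickled class definitions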

        else:
            # Load a pretrained Faster R-CNN model from torchvision
            self.model = fasterrcnn_resnet50_fpn(pretrained=True)

medium

The pretrained argument for torchvision models is deprecated and will be removed in a future version. You should use the weights argument instead for better future compatibility.

Suggested change
self.model = fasterrcnn_resnet50_fpn(pretrained=True)
self.model = fasterrcnn_resnet50_fpn(weights="FasterRCNN_ResNet50_FPN_Weights.DEFAULT")
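Note the non-V2 enum: FasterRCNN_ResNet50_FPN_V2_Weights belongs to fasterrcnn_resnet50_fpn_v2 and would not resolve for this builder. An equivalent form with the weights enum imported explicitly, which avoids typos in the string lookup:

from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    FasterRCNN_ResNet50_FPN_Weights,
)

model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)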


        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image):
        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1]
        image = image.to(self.device)
        return image

    def get_logo_mask(self, image, threshold=0.5, blur_kernel_size=21, blur_sigma=5):
        # The image is expected to be a tensor of shape (B, C, H, W)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image_processed = self.preprocess_image(image)

        with torch.no_grad():
            predictions = self.model(image_processed)

        mask = torch.zeros(image.shape[2:], device=self.device)

        for i in range(len(predictions)):
            boxes = predictions[i]['boxes']
            scores = predictions[i]['scores']

            for box, score in zip(boxes, scores):
                if score > threshold:
                    x1, y1, x2, y2 = box.int()
                    mask[y1:y2, x1:x2] = 1

        # Create a soft mask by applying a Gaussian blur
        if blur_kernel_size > 0 and blur_sigma > 0:
            mask = F.gaussian_blur(mask.unsqueeze(0).unsqueeze(0), kernel_size=blur_kernel_size, sigma=blur_sigma)
            mask = mask.squeeze(0).squeeze(0)

        # Normalize the mask to be in the range [0, 1]
        if mask.max() > 0:
            mask = mask / mask.max()

medium

The line mask = mask / mask.max() would divide by zero if mask.max() were 0. The surrounding if mask.max() > 0 check already prevents this, but adding a small epsilon to the denominator is a simpler and more robust guard that also lets the conditional be dropped.

Suggested change
mask = mask / mask.max()
mask = mask / (mask.max() + 1e-8)


        return mask
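A hedged usage sketch of this detector; it is not from the PR, and with no model_path it falls back to torchvision's COCO-pretrained weights as a stand-in, so the "logo" detections are only illustrative:

import torch
from diffsynth.models.logo_detector import LogoDetector

detector = LogoDetector(device="cpu")  # no model_path: COCO-pretrained fallback
frame = torch.rand(3, 256, 256)        # stand-in frame, [C, H, W] in [0, 1]
mask = detector.get_logo_mask(frame, threshold=0.5)
print(mask.shape)                      # torch.Size([256, 256]), values in [0, 1]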
53 changes: 53 additions & 0 deletions diffsynth/pipelines/wan_video_new.py
@@ -24,6 +24,8 @@
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..lora import GeneralLoRALoader
from ..classifiers.open_logo import OpenLogoClassifier
from ..models.logo_detector import LogoDetector



@@ -43,6 +45,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        self.classifier: OpenLogoClassifier = None
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
@@ -72,6 +75,10 @@ def load_lora(self, module, path, alpha=1):
        loader.load(module, lora, alpha=alpha)


    def load_classifier(self, model_path, **kwargs):
        self.classifier = OpenLogoClassifier(model_path, device=self.device, **kwargs)


    def training_loss(self, **inputs):
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
@@ -348,6 +355,19 @@ def from_pretrained(
        return pipe


    def get_guidance_strength(self, progress_id, num_inference_steps, schedule_name="late_step_emphasis"):
        if schedule_name == "late_step_emphasis":
            # Start guidance after 60% of the steps
            start_step = int(num_inference_steps * 0.6)
            if progress_id < start_step:
                return 0.0
            else:
                # Ramp up from 0 to 1
                return (progress_id - start_step) / (num_inference_steps - start_step)
        else:
            return 1.0


    @torch.no_grad()
    def __call__(
        self,
@@ -383,6 +403,12 @@ def __call__(
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Classifier guidance
        classifier_guidance_scale: float = 0.0,
        logo_mask: Optional[torch.Tensor] = None,
        classifier_class_id: int = 0,
        guidance_schedule: str = "late_step_emphasis",
        logo_detector_path: Optional[str] = None,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
@@ -433,6 +459,14 @@
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Logo detection
        if logo_detector_path is not None:
            logo_detector = LogoDetector(logo_detector_path, device=self.device)

high

The LogoDetector is instantiated inside the __call__ method. This means the logo detection model will be loaded from disk every time the pipeline is called, which is very inefficient and will significantly slow down repeated inferences. The detector should be loaded once and stored as a pipeline attribute, similar to how self.classifier is handled. You should add a load_logo_detector method to the pipeline and use a self.logo_detector attribute.

Suggested change
logo_detector = LogoDetector(logo_detector_path, device=self.device)
logo_detector = self.logo_detector
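For completeness, a minimal sketch of the proposed loader, mirroring load_classifier above; the method and attribute names are suggestions, not existing API:

def load_logo_detector(self, model_path, **kwargs):
    # Load once and cache on the pipeline; __call__ then reuses self.logo_detector.
    self.logo_detector = LogoDetector(model_path, device=self.device, **kwargs)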

            if input_image is not None:
                input_image_tensor = self.preprocess_image(input_image.resize((width, height)))
                logo_mask = logo_detector.get_logo_mask(input_image_tensor)
                logo_mask = logo_mask.to(dtype=self.torch_dtype, device=self.device)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
@@ -456,6 +490,25 @@
            else:
                noise_pred = noise_pred_posi

            # Classifier guidance
            if self.classifier is not None and classifier_guidance_scale > 0.0:
                g_t = self.get_guidance_strength(progress_id, num_inference_steps, guidance_schedule)
                if g_t > 0:
                    latents_for_grad = inputs_shared["latents"].detach().requires_grad_(True)
                    grad = self.classifier.compute_gradient(
                        latents=latents_for_grad,
                        pipeline=self,
                        timestep=timestep,
                        class_id=classifier_class_id,
                        mask=logo_mask,
                    )

                    alpha_t = self.scheduler.get_alpha_t(timestep)
                    sigma_t = self.scheduler.get_sigma_t(timestep)

                    guidance = -alpha_t * g_t * classifier_guidance_scale * sigma_t * grad
                    noise_pred = noise_pred + guidance

            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
            if "first_frame_latents" in inputs_shared:
12 changes: 12 additions & 0 deletions diffsynth/schedulers/flow_match.py
@@ -104,6 +104,18 @@ def training_weight(self, timestep):
        timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
        weights = self.linear_timesteps_weights[timestep_id]
        return weights


    def get_sigma_t(self, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.cpu()
        timestep_id = torch.argmin((self.timesteps - timestep).abs())
        sigma = self.sigmas[timestep_id]
        return sigma


    def get_alpha_t(self, timestep):
        # The signal coefficient is treated as constant in this flow-matching
        # parameterization, so the classifier-guidance term in the pipeline
        # effectively scales with sigma_t alone.
        return 1.0


def calculate_shift(
45 changes: 45 additions & 0 deletions examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py
@@ -0,0 +1,45 @@
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

# This is a placeholder path. Please replace it with the actual path to your classifier model.
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
Comment on lines +8 to +24

medium

This example script uses placeholder paths for the logo classifier and detector models. If a user runs this script without replacing the placeholders with valid paths, it will crash with a FileNotFoundError. For a better user experience, the script should check if these files exist before attempting to load them and print a helpful message if they are not found.

import os

# This is a placeholder path. Please replace it with the actual path to your classifier model.
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

if not os.path.exists(LOGO_CLASSIFIER_PATH) or not os.path.exists(LOGO_DETECTOR_PATH):
    print("Please replace the placeholder paths for LOGO_CLASSIFIER_PATH and LOGO_DETECTOR_PATH")
    print("in the script with actual paths to your models.")
    exit()

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)


# Image-to-video with logo guidance
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))

video = pipe(
    prompt="A video of a cat fighting, with a brand logo in the corner.",
    negative_prompt="worst quality, low quality",
    seed=0, tiled=True,
    height=704, width=1248,
    input_image=input_image,
    num_frames=121,
    classifier_guidance_scale=5.0,
    classifier_class_id=1,  # Replace with the class id of your logo
    logo_detector_path=LOGO_DETECTOR_PATH,
)
save_video(video, "video_with_logo_guidance.mp4", fps=15, quality=5)