diff --git a/diffsynth/classifiers/__init__.py b/diffsynth/classifiers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/diffsynth/classifiers/open_logo.py b/diffsynth/classifiers/open_logo.py
new file mode 100644
index 00000000..055c063c
--- /dev/null
+++ b/diffsynth/classifiers/open_logo.py
@@ -0,0 +1,71 @@
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from torchvision.models.detection import fasterrcnn_resnet50_fpn
+from torchvision.models.detection.image_list import ImageList
+from torchvision.ops import nms
+
+class OpenLogoClassifier(nn.Module):
+    def __init__(self, model_path, device="cuda"):
+        super().__init__()
+        self.device = device
+        # Build the Faster R-CNN detector (ResNet-50 FPN; torchvision ships no VGG16 Faster R-CNN); the checkpoint must match this architecture
+        self.model = fasterrcnn_resnet50_fpn(pretrained=False)  # We will load our own weights
+
+        # Load the state dict from the provided .pt file
+        state_dict = torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(state_dict)
+
+        self.model.to(self.device)
+        self.model.eval()
+
+    def preprocess_image(self, image):
+        # Preprocess the image for the Faster R-CNN model.
+        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1].
+        image = image.to(self.device)
+        return image
+
+    def compute_gradient(self, latents, pipeline, timestep, class_id, mask=None, guidance_scale=1.0):
+        with torch.enable_grad():
+            latents = latents.detach().requires_grad_(True)
+
+            # Decode latents to video
+            video = pipeline.vae.decode(latents)
+
+            # The output of the VAE is in the range [-1, 1]; normalize it to [0, 1]
+            video = (video / 2 + 0.5).clamp(0, 1)
+
+            log_prob_sum = 0
+
+            num_frames = video.shape[2]
+            for i in range(num_frames):
+                image = video[:, :, i, :, :]
+
+                # Preprocess the frame for the classifier
+                image_for_model = self.preprocess_image(image)
+
+                # Run the detection heads manually (instead of self.model.forward) so that the
+                # class logits stay differentiable with respect to the latents
+                image_sizes = [image_for_model.shape[-2:]] * image_for_model.shape[0]
+                image_list = ImageList(image_for_model, image_sizes)
+                features = self.model.backbone(image_for_model)
+                proposals, proposal_losses = self.model.rpn(image_list, features, None)
+                box_features = self.model.roi_heads.box_roi_pool(features, proposals, image_sizes)
+                box_features = self.model.roi_heads.box_head(box_features)
+                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)
+
+                # Average log-probability of the target class over all proposals
+                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
+                log_prob_target = log_probs[:, class_id].mean()
+                log_prob_sum = log_prob_sum + log_prob_target
+
+            grad = torch.autograd.grad(log_prob_sum, latents)[0]
+
+        # Apply the soft logo mask if provided
+        if mask is not None:
+            grad = grad * mask
+
+        # Clamp the gradient to avoid artifacts
+        grad = torch.clamp(grad, -0.1, 0.1)
+
+        return grad
diff --git a/diffsynth/models/logo_detector.py b/diffsynth/models/logo_detector.py
new file mode 100644
index 00000000..0e13faec
--- /dev/null
+++ b/diffsynth/models/logo_detector.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as F
+from torchvision.models.detection import fasterrcnn_resnet50_fpn
+from torchvision.transforms import GaussianBlur
+
+class LogoDetector(nn.Module):
+    def __init__(self, model_path=None, device="cuda"):
+        super().__init__()
+        self.device = device
+        if model_path:
+            # Load a custom model if a path is provided
+            # For now, we assume the model is a Faster R-CNN model
+            self.model = torch.load(model_path, map_location=self.device)
+        else:
+            # Load a pretrained Faster R-CNN model from torchvision
+            self.model = fasterrcnn_resnet50_fpn(pretrained=True)
+
+        self.model.to(self.device)
+        self.model.eval()
+
+    def preprocess_image(self, image):
+        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1]
+        image = image.to(self.device)
+        return image
+
+    def get_logo_mask(self, image, threshold=0.5, blur_kernel_size=21, blur_sigma=5):
+        # The image is expected to be a tensor of shape (B, C, H, W)
+        if image.dim() == 3:
+            image = image.unsqueeze(0)
+
+        image_processed = self.preprocess_image(image)
+
+        with torch.no_grad():
+            predictions = self.model(image_processed)
+
+        mask = torch.zeros(image.shape[2:], device=self.device)
+
+        for i in range(len(predictions)):
+            boxes = predictions[i]['boxes']
+            scores = predictions[i]['scores']
+
+            for box, score in zip(boxes, scores):
+                if score > threshold:
+                    x1, y1, x2, y2 = box.int()
+                    mask[y1:y2, x1:x2] = 1
+
+        # Create a soft mask by applying a Gaussian blur
+        if blur_kernel_size > 0 and blur_sigma > 0:
+            mask = F.gaussian_blur(mask.unsqueeze(0).unsqueeze(0), kernel_size=blur_kernel_size, sigma=blur_sigma)
+            mask = mask.squeeze(0).squeeze(0)
+
+        # Normalize the mask to be in the range [0, 1]
+        if mask.max() > 0:
+            mask = mask / mask.max()
+
+        return mask
diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py
index 53df7d94..9eb833f1 100644
--- a/diffsynth/pipelines/wan_video_new.py
+++ b/diffsynth/pipelines/wan_video_new.py
@@ -24,6 +24,8 @@
 from ..prompters import WanPrompter
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
 from ..lora import GeneralLoRALoader
+from ..classifiers.open_logo import OpenLogoClassifier
+from ..models.logo_detector import LogoDetector
 
 
@@ -43,6 +45,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non
         self.vae: WanVideoVAE = None
         self.motion_controller: WanMotionControllerModel = None
         self.vace: VaceWanModel = None
+        self.classifier: OpenLogoClassifier = None
         self.in_iteration_models = ("dit", "motion_controller", "vace")
         self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
         self.unit_runner = PipelineUnitRunner()
@@ -72,6 +75,10 @@ def load_lora(self, module, path, alpha=1):
         loader.load(module, lora, alpha=alpha)
 
+    def load_classifier(self, model_path, **kwargs):
+        self.classifier = OpenLogoClassifier(model_path, device=self.device, **kwargs)
+
+
     def training_loss(self, **inputs):
         max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
         min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
@@ -348,6 +355,19 @@ def from_pretrained(
         return pipe
 
+    def get_guidance_strength(self, progress_id, num_inference_steps, schedule_name="late_step_emphasis"):
+        if schedule_name == "late_step_emphasis":
+            # Start guidance after 60% of the steps
+            start_step = int(num_inference_steps * 0.6)
+            if progress_id < start_step:
+                return 0.0
+            else:
+                # Ramp up from 0 to 1
+                return (progress_id - start_step) / (num_inference_steps - start_step)
+        else:
+            return 1.0
+
+
     @torch.no_grad()
     def __call__(
         self,
@@ -383,6 +403,12 @@
         # Classifier-free guidance
         cfg_scale: Optional[float] = 5.0,
         cfg_merge: Optional[bool] = False,
+        # Classifier guidance
+        classifier_guidance_scale: float = 0.0,
+        logo_mask: Optional[torch.Tensor] = None,
+        classifier_class_id: int = 0,
+        guidance_schedule: str = "late_step_emphasis",
+        logo_detector_path: Optional[str] = None,
         # Boundary
         switch_DiT_boundary: Optional[float] = 0.875,
         # Scheduler
@@ -433,6 +459,14 @@ def __call__(
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
 
+        # Logo detection
+        if logo_detector_path is not None:
+            logo_detector = LogoDetector(logo_detector_path, device=self.device)
+            if input_image is not None:
+                input_image_tensor = self.preprocess_image(input_image.resize((width, height)))
+                logo_mask = logo_detector.get_logo_mask(input_image_tensor)
+                logo_mask = logo_mask.to(dtype=self.torch_dtype, device=self.device)
+
         # Denoise
         self.load_models_to_device(self.in_iteration_models)
         models = {name: getattr(self, name) for name in self.in_iteration_models}
@@ -456,6 +490,25 @@ def __call__(
             else:
                 noise_pred = noise_pred_posi
 
+            # Classifier guidance
+            if self.classifier is not None and classifier_guidance_scale > 0.0:
+                g_t = self.get_guidance_strength(progress_id, num_inference_steps, guidance_schedule)
+                if g_t > 0:
+                    latents_for_grad = inputs_shared["latents"].detach().requires_grad_(True)
+                    grad = self.classifier.compute_gradient(
+                        latents=latents_for_grad,
+                        pipeline=self,
+                        timestep=timestep,
+                        class_id=classifier_class_id,
+                        mask=logo_mask,
+                    )
+
+                    alpha_t = self.scheduler.get_alpha_t(timestep)
+                    sigma_t = self.scheduler.get_sigma_t(timestep)
+
+                    guidance = -alpha_t * g_t * classifier_guidance_scale * sigma_t * grad
+                    noise_pred = noise_pred + guidance
+
             # Scheduler
             inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
             if "first_frame_latents" in inputs_shared:
diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py
index 6a8e2351..4b178e66 100644
--- a/diffsynth/schedulers/flow_match.py
+++ b/diffsynth/schedulers/flow_match.py
@@ -104,6 +104,18 @@ def training_weight(self, timestep):
         timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
         weights = self.linear_timesteps_weights[timestep_id]
         return weights
+
+
+    def get_sigma_t(self, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
+        sigma = self.sigmas[timestep_id]
+        return sigma
+
+
+    def get_alpha_t(self, timestep):
+        return 1.0
 
 
 def calculate_shift(
diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py
new file mode 100644
index 00000000..89787302
--- /dev/null
+++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B-logo-guidance.py
@@ -0,0 +1,45 @@
+import torch
+from PIL import Image
+from diffsynth import save_video
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+# This is a placeholder path. Please replace it with the actual path to your classifier model.
+LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
+# This is a placeholder path. Please replace it with the actual path to your logo detector model.
+LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.enable_vram_management()
+
+# Load the classifier
+pipe.load_classifier(LOGO_CLASSIFIER_PATH)
+
+# Image-to-video with logo guidance
+dataset_snapshot_download(
+    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
+    local_dir="./",
+    allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
+)
+input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))
+
+video = pipe(
+    prompt="A video of a cat fighting, with a brand logo in the corner.",
+    negative_prompt="worst quality, low quality",
+    seed=0, tiled=True,
+    height=704, width=1248,
+    input_image=input_image,
+    num_frames=121,
+    classifier_guidance_scale=5.0,
+    classifier_class_id=1,  # Replace with the class id of your logo
+    logo_detector_path=LOGO_DETECTOR_PATH,
+)
+save_video(video, "video_with_logo_guidance.mp4", fps=15, quality=5)
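Note (not part of the patch): a minimal sketch, assuming the same example image as above, of how the soft logo mask from the new LogoDetector can be built and inspected offline; the preview filename is hypothetical.

import torch
from PIL import Image
from torchvision.transforms.functional import to_pil_image, to_tensor
from diffsynth.models.logo_detector import LogoDetector

# No model_path: falls back to the torchvision-pretrained Faster R-CNN.
detector = LogoDetector(device="cuda")
# [C, H, W] tensor in [0, 1], as expected by LogoDetector.get_logo_mask.
image = to_tensor(Image.open("data/examples/wan/cat_fightning.jpg").convert("RGB"))
# Soft [H, W] mask in [0, 1]: boxes above the score threshold, Gaussian-blurred.
mask = detector.get_logo_mask(image, threshold=0.5, blur_kernel_size=21, blur_sigma=5)
# Save a preview for a quick visual sanity check (hypothetical output filename).
to_pil_image((mask * 255).to(torch.uint8).cpu()).save("logo_mask_preview.png")

A mask precomputed this way could also be passed to the pipeline via the logo_mask argument instead of logo_detector_path, which avoids constructing the detector inside every pipeline call.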