I have created the demonstration script `examples/wanvideo/model_infe… #867
`diffsynth/classifiers/open_logo.py` (new file, @@ -0,0 +1,71 @@):
```python
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_vgg16_bn
from torchvision.ops import nms


class OpenLogoClassifier(nn.Module):
    def __init__(self, model_path, device="cuda"):
        super().__init__()
        self.device = device
        # Load the pretrained Faster R-CNN model with VGG16 backbone
        self.model = fasterrcnn_vgg16_bn(pretrained=False)  # We will load our own weights

        # Load the state dict from the provided .pt file
        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image):
        # Preprocess the image for the Faster R-CNN model
        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1]
        image = image.to(self.device)
        return image

    def compute_gradient(self, latents, pipeline, timestep, class_id, mask=None, guidance_scale=1.0):
        with torch.enable_grad():
            latents = latents.detach().requires_grad_(True)

            # Decode latents to video
            video = pipeline.vae.decode(latents)

            # The output of the VAE is in the range [-1, 1], we need to normalize it to [0, 1]
            video = (video / 2 + 0.5).clamp(0, 1)

            log_prob_sum = 0

            num_frames = video.shape[2]
            for i in range(num_frames):
                image = video[:, :, i, :, :]

                # Preprocess the image for the classifier
                image_processed = self.preprocess_image(image)

                # Get predictions from the classifier
                self.model.train()

                image_for_model = image_processed
                features = self.model.backbone(image_for_model)
                proposals, proposal_losses = self.model.rpn(image_for_model, features, None)
                box_features = self.model.roi_heads.box_roi_pool(features, proposals, [image_for_model.shape[2:]])
                box_features = self.model.roi_heads.box_head(box_features)
                class_logits, box_regression = self.model.roi_heads.box_predictor(box_features)

                log_probs = torch.nn.functional.log_softmax(class_logits, dim=-1)
                log_prob_target = log_probs[:, class_id].mean()
                log_prob_sum = log_prob_sum + log_prob_target

            grad = torch.autograd.grad(log_prob_sum, latents, grad_outputs=torch.ones_like(log_prob_sum))[0]

            # Apply the soft logo mask if provided
            if mask is not None:
                grad = grad * mask

            # Clamp the gradient to avoid artifacts
            grad = torch.clamp(grad, -0.1, 0.1)

            self.model.eval()

        return grad
```

Review comment on lines +47 to +69 (the block from `# Get predictions from the classifier` through the gradient clamp): "Calling …"
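The core pattern in `compute_gradient` — decode the latents with gradients enabled, score each decoded frame with a classifier, and differentiate the averaged log-probability of the target class back to the latents — can be illustrated with a small self-contained sketch. The toy decoder, toy classifier, and the helper name `logo_log_prob_gradient` below are assumptions for illustration only; they stand in for the Wan VAE and the Faster R-CNN head and are not part of this PR.

```python
import torch
import torch.nn as nn

# Toy stand-ins (assumptions): a "decoder" from latents to frames and a per-frame classifier.
decoder = nn.Sequential(nn.Conv2d(4, 3, 3, padding=1), nn.Tanh())
classifier = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5)
)

def logo_log_prob_gradient(latents, class_id, mask=None):
    # Same pattern as OpenLogoClassifier.compute_gradient: decode with grad enabled,
    # average the log-probability of the target class over frames, then differentiate.
    with torch.enable_grad():
        latents = latents.detach().requires_grad_(True)
        frames = decoder(latents)                     # (T, 3, H, W), roughly in [-1, 1]
        frames = (frames / 2 + 0.5).clamp(0, 1)       # normalize to [0, 1]
        logits = classifier(frames)                   # (T, num_classes)
        log_probs = torch.log_softmax(logits, dim=-1)
        log_prob_sum = log_probs[:, class_id].mean()
        grad = torch.autograd.grad(log_prob_sum, latents)[0]
    if mask is not None:
        grad = grad * mask                            # restrict guidance to the logo region
    return torch.clamp(grad, -0.1, 0.1)               # clamp to avoid artifacts, as in the PR

latents = torch.randn(8, 4, 32, 32)                   # 8 "frames" of 4-channel latents
print(logo_log_prob_gradient(latents, class_id=2).shape)  # torch.Size([8, 4, 32, 32])
```

The mask multiplication and the final clamp mirror what the PR code does to keep the guidance localized to the detected logo region and to avoid large gradient spikes.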
`diffsynth/models/logo_detector.py` (new file, @@ -0,0 +1,57 @@):
```python
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import GaussianBlur


class LogoDetector(nn.Module):
    def __init__(self, model_path=None, device="cuda"):
        super().__init__()
        self.device = device
        if model_path:
            # Load a custom model if a path is provided
            # For now, we assume the model is a Faster R-CNN model
            self.model = torch.load(model_path, map_location=self.device)
```
Review comment on `self.model = torch.load(model_path, map_location=self.device)`: "Using …" (a suggested change was attached).
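A common concern with `torch.load(model_path)` returning a whole pickled model object is that it executes arbitrary pickle code and ties the checkpoint to an exact class layout. A minimal sketch of the state-dict alternative follows; it is an assumption about the intended fix, not necessarily the reviewer's exact suggestion, and `load_custom_detector` is a hypothetical helper name.

```python
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

def load_custom_detector(model_path, device="cuda"):
    # Hypothetical helper (not part of the PR): build the architecture locally,
    # then load only tensor weights from the checkpoint.
    model = fasterrcnn_resnet50_fpn(weights=None)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
```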
```python
        else:
            # Load a pretrained Faster R-CNN model from torchvision
            self.model = fasterrcnn_resnet50_fpn(pretrained=True)
```
Review comment on `fasterrcnn_resnet50_fpn(pretrained=True)`: "The …" (a suggested change was attached).
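If the concern is that `pretrained=True` is deprecated in newer torchvision releases (an assumption), the equivalent call with the weights enum API is:

```python
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights

# Equivalent of pretrained=True under the torchvision >= 0.13 weights API
model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
```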
```python
        self.model.to(self.device)
        self.model.eval()

    def preprocess_image(self, image):
        # The image is expected to be a tensor of shape [C, H, W] with values in [0, 1]
        image = image.to(self.device)
        return image

    def get_logo_mask(self, image, threshold=0.5, blur_kernel_size=21, blur_sigma=5):
        # The image is expected to be a tensor of shape (B, C, H, W)
        if image.dim() == 3:
            image = image.unsqueeze(0)

        image_processed = self.preprocess_image(image)

        with torch.no_grad():
            predictions = self.model(image_processed)

        mask = torch.zeros(image.shape[2:], device=self.device)

        for i in range(len(predictions)):
            boxes = predictions[i]['boxes']
            scores = predictions[i]['scores']

            for box, score in zip(boxes, scores):
                if score > threshold:
                    x1, y1, x2, y2 = box.int()
                    mask[y1:y2, x1:x2] = 1

        # Create a soft mask by applying a Gaussian blur
        if blur_kernel_size > 0 and blur_sigma > 0:
            mask = F.gaussian_blur(mask.unsqueeze(0).unsqueeze(0), kernel_size=blur_kernel_size, sigma=blur_sigma)
            mask = mask.squeeze(0).squeeze(0)

        # Normalize the mask to be in the range [0, 1]
        if mask.max() > 0:
            mask = mask / mask.max()

        return mask
```
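A minimal usage sketch for `LogoDetector.get_logo_mask`, relying on the pretrained torchvision fallback (no custom checkpoint); the image path, device, and threshold are placeholders, and the module path is assumed from the relative import used in the pipeline (`..models.logo_detector`).

```python
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

from diffsynth.models.logo_detector import LogoDetector  # module path assumed

detector = LogoDetector(model_path=None, device="cuda")    # falls back to the pretrained torchvision detector
image = to_tensor(Image.open("input.jpg").convert("RGB"))  # (C, H, W), values in [0, 1]

mask = detector.get_logo_mask(image, threshold=0.5)        # (H, W) soft mask in [0, 1]
print(mask.shape, float(mask.max()))
```

The returned mask can be passed to the pipeline as `logo_mask`, or recomputed internally when `logo_detector_path` is given, as the pipeline change below does.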
`diffsynth/pipelines/wan_video_new.py` (modified):
@@ -24,6 +24,8 @@
```python
from ..prompters import WanPrompter
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
from ..lora import GeneralLoRALoader
from ..classifiers.open_logo import OpenLogoClassifier
from ..models.logo_detector import LogoDetector
```
@@ -43,6 +45,7 @@ def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=Non
```python
        self.vae: WanVideoVAE = None
        self.motion_controller: WanMotionControllerModel = None
        self.vace: VaceWanModel = None
        self.classifier: OpenLogoClassifier = None
        self.in_iteration_models = ("dit", "motion_controller", "vace")
        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace")
        self.unit_runner = PipelineUnitRunner()
```
@@ -72,6 +75,10 @@ def load_lora(self, module, path, alpha=1):
```python
        loader.load(module, lora, alpha=alpha)

    def load_classifier(self, model_path, **kwargs):
        self.classifier = OpenLogoClassifier(model_path, device=self.device, **kwargs)

    def training_loss(self, **inputs):
        max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
        min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * self.scheduler.num_train_timesteps)
```
@@ -348,6 +355,19 @@ def from_pretrained(
```python
        return pipe

    def get_guidance_strength(self, progress_id, num_inference_steps, schedule_name="late_step_emphasis"):
        if schedule_name == "late_step_emphasis":
            # Start guidance after 60% of the steps
            start_step = int(num_inference_steps * 0.6)
            if progress_id < start_step:
                return 0.0
            else:
                # Ramp up from 0 to 1
                return (progress_id - start_step) / (num_inference_steps - start_step)
        else:
            return 1.0

    @torch.no_grad()
    def __call__(
        self,
```
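To make the `late_step_emphasis` schedule concrete: with `num_inference_steps=50`, `start_step` is 30, so the guidance weight stays at 0 through step 30 and then ramps linearly, reaching 0.95 at the final step. A standalone recreation of the schedule (it mirrors the method above and is shown here only to print a few values):

```python
def guidance_strength(progress_id, num_inference_steps, schedule_name="late_step_emphasis"):
    if schedule_name == "late_step_emphasis":
        start_step = int(num_inference_steps * 0.6)  # guidance kicks in after 60% of the steps
        if progress_id < start_step:
            return 0.0
        return (progress_id - start_step) / (num_inference_steps - start_step)
    return 1.0

print([round(guidance_strength(i, 50), 2) for i in (0, 29, 30, 40, 49)])
# [0.0, 0.0, 0.0, 0.5, 0.95]
```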
@@ -383,6 +403,12 @@ def __call__(
```python
        # Classifier-free guidance
        cfg_scale: Optional[float] = 5.0,
        cfg_merge: Optional[bool] = False,
        # Classifier guidance
        classifier_guidance_scale: float = 0.0,
        logo_mask: Optional[torch.Tensor] = None,
        classifier_class_id: int = 0,
        guidance_schedule: str = "late_step_emphasis",
        logo_detector_path: Optional[str] = None,
        # Boundary
        switch_DiT_boundary: Optional[float] = 0.875,
        # Scheduler
```
@@ -433,6 +459,14 @@ def __call__(
```python
        for unit in self.units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

        # Logo detection
        if logo_detector_path is not None:
            logo_detector = LogoDetector(logo_detector_path, device=self.device)
            if input_image is not None:
                input_image_tensor = self.preprocess_image(input_image.resize((width, height)))
                logo_mask = logo_detector.get_logo_mask(input_image_tensor)
                logo_mask = logo_mask.to(dtype=self.torch_dtype, device=self.device)

        # Denoise
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
```

Review comment on `logo_detector = LogoDetector(logo_detector_path, device=self.device)`: "The …" (a suggested change was attached).
@@ -456,6 +490,25 @@ def __call__(
```python
            else:
                noise_pred = noise_pred_posi

            # Classifier guidance
            if self.classifier is not None and classifier_guidance_scale > 0.0:
                g_t = self.get_guidance_strength(progress_id, num_inference_steps, guidance_schedule)
                if g_t > 0:
                    latents_for_grad = inputs_shared["latents"].detach().requires_grad_(True)
                    grad = self.classifier.compute_gradient(
                        latents=latents_for_grad,
                        pipeline=self,
                        timestep=timestep,
                        class_id=classifier_class_id,
                        mask=logo_mask,
                    )

                    alpha_t = self.scheduler.get_alpha_t(timestep)
                    sigma_t = self.scheduler.get_sigma_t(timestep)

                    guidance = -alpha_t * g_t * classifier_guidance_scale * sigma_t * grad
                    noise_pred = noise_pred + guidance

            # Scheduler
            inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
            if "first_frame_latents" in inputs_shared:
```
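For reference, this update is the standard classifier-guidance shift of the noise prediction, with an extra per-step weight: the predicted noise is moved against the gradient of the target-class log-likelihood,

$$\hat{\epsilon}_t = \epsilon_\theta(x_t) - s \, \sigma_t \, \nabla_{x_t} \log p(c \mid x_t),$$

where $s$ is `classifier_guidance_scale`. The code additionally multiplies by the schedule weight `g_t` from `get_guidance_strength` and by `alpha_t`, and assumes the scheduler exposes `get_alpha_t` and `get_sigma_t`; those scheduler methods are not part of this diff.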
Demonstration script `examples/wanvideo/model_infe…` (new file, @@ -0,0 +1,45 @@):
```python
import torch
from PIL import Image
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download

# This is a placeholder path. Please replace it with the actual path to your classifier model.
LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
```
Review comment on lines +8 to +24: This example script uses placeholder paths for the logo classifier and detector models. If a user runs this script without replacing the placeholders with valid paths, it will crash with a `FileNotFoundError` when the models are loaded. Suggested change:

```python
import os

LOGO_CLASSIFIER_PATH = "path/to/your/logo_classifier.pt"
# This is a placeholder path. Please replace it with the actual path to your logo detector model.
LOGO_DETECTOR_PATH = "path/to/your/logo_detector.pt"

if not os.path.exists(LOGO_CLASSIFIER_PATH) or not os.path.exists(LOGO_DETECTOR_PATH):
    print("Please replace the placeholder paths for LOGO_CLASSIFIER_PATH and LOGO_DETECTOR_PATH")
    print("in the script with actual paths to your models.")
    exit()

pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-TI2V-5B", origin_file_pattern="Wan2.2_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Load the classifier
pipe.load_classifier(LOGO_CLASSIFIER_PATH)
```
```python
# Image-to-video with logo guidance
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/cat_fightning.jpg"]
)
input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 704))

video = pipe(
    prompt="A video of a cat fighting, with a brand logo in the corner.",
    negative_prompt="worst quality, low quality",
    seed=0, tiled=True,
    height=704, width=1248,
    input_image=input_image,
    num_frames=121,
    classifier_guidance_scale=5.0,
    classifier_class_id=1,  # Replace with the class id of your logo
    logo_detector_path=LOGO_DETECTOR_PATH,
)
save_video(video, "video_with_logo_guidance.mp4", fps=15, quality=5)
```
Review comment: The imports `torchvision.transforms.functional as F` and `from torchvision.ops import nms` are not used in this file (`open_logo.py`); they should be removed to keep the code clean.