Forked from comfyanonymous/ComfyUI
Nodes to properly use the SDV img2vid checkpoint.
The img2vid model is conditioned only on CLIP vision output, which means there is no CLIP text model; that is why I added an ImageOnlyCheckpointLoader to load it. Note that the unCLIPCheckpointLoader can also load it because it also has a CLIP_VISION output. SDV_img2vid_Conditioning is the node used to pass the right conditioning to the img2vid model. VideoLinearCFGGuidance varies the CFG scale linearly across the video frames, from min_cfg on the first frame to the cfg set in the sampler node on the last frame.
SDV_img2vid_Conditioning can be found under conditioning -> video_models.
ImageOnlyCheckpointLoader can be found under loaders -> video_models.
VideoLinearCFGGuidance can be found under sampling -> video_models.
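Roughly, the three nodes chain together like this when called directly from Python. This is a sketch only: it assumes a working ComfyUI environment, an SVD checkpoint saved in the checkpoints folder, and an init_image tensor already loaded; the file name "svd.safetensors" is illustrative.

# Node classes as defined in the diff below.
model, clip_vision, vae = ImageOnlyCheckpointLoader().load_checkpoint("svd.safetensors")
positive, negative, latent = SDV_img2vid_Conditioning().encode(
    clip_vision, init_image, vae,
    width=1024, height=576, video_frames=14,
    motion_bucket_id=127, fps=6, augmentation_level=0.0)
(patched_model,) = VideoLinearCFGGuidance().patch(model, min_cfg=1.0)
# patched_model, positive, negative and latent then feed a KSampler node, whose
# cfg value acts as the upper end of the per-frame CFG ramp.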
1 parent 871cc20 · commit 42dfae6
Showing 2 changed files with 90 additions and 0 deletions.
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths


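# Loads a checkpoint that ships a CLIP vision encoder instead of a text CLIP:
# load_checkpoint_guess_config is asked for the clipvision component
# (output_clipvision=True, output_clip=False) and the (MODEL, CLIP_VISION, VAE)
# triple is picked out of its return tuple.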
class ImageOnlyCheckpointLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
                             }}
    RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
    FUNCTION = "load_checkpoint"

    CATEGORY = "loaders/video_models"

    def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
        return (out[0], out[3], out[2])

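# Builds the img2vid conditioning: the CLIP vision embedding of the init image is
# used as the conditioning, the VAE-encoded (and optionally noised) init image is
# attached as "concat_latent_image", and an all-zero latent with one entry per
# video frame is returned for the sampler to denoise.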
class SDV_img2vid_Conditioning:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "clip_vision": ("CLIP_VISION",),
                              "init_image": ("IMAGE",),
                              "vae": ("VAE",),
                              "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
                              "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
                              "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
                              "fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
                              "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
                             }}
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")

    FUNCTION = "encode"

    CATEGORY = "conditioning/video_models"

    def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
        output = clip_vision.encode_image(init_image)
        pooled = output.image_embeds.unsqueeze(0)
        pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
        encode_pixels = pixels[:,:,:,:3]
        if augmentation_level > 0:
            encode_pixels += torch.randn_like(pixels) * augmentation_level
        t = vae.encode(encode_pixels)
        positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
        negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
        latent = torch.zeros([video_frames, 4, height // 8, width // 8])
        return (positive, negative, {"samples":latent})

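# Patches the model's sampler CFG function so the guidance scale varies per frame:
# cond/uncond arrive batched with one entry per frame, and
# torch.linspace(min_cfg, cond_scale, num_frames) gives the first frame min_cfg
# and the last frame the cfg set on the sampler, with a linear ramp in between.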
class VideoLinearCFGGuidance:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "sampling/video_models"

    def patch(self, model, min_cfg):
        def linear_cfg(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
            return uncond + scale * (cond - uncond)

        m = model.clone()
        m.set_model_sampler_cfg_function(linear_cfg)
        return (m, )

NODE_CLASS_MAPPINGS = {
    "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
    "SDV_img2vid_Conditioning": SDV_img2vid_Conditioning,
    "VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
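For intuition, this is roughly what the per-frame scale computed inside linear_cfg works out to for a 14-frame batch; min_cfg=1.0 and a sampler cfg of 3.0 are illustrative values, not recommendations.

import torch

# The linspace call from linear_cfg, evaluated standalone for 14 frames.
scale = torch.linspace(1.0, 3.0, 14)
print([round(v, 2) for v in scale.tolist()])
# approximately [1.0, 1.15, 1.31, 1.46, 1.62, 1.77, 1.92, 2.08, 2.23, 2.38, 2.54, 2.69, 2.85, 3.0]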