Nodes to properly use the SDV img2vid checkpoint.
The img2vid model is conditioned only on the CLIP vision output, which means
there's no CLIP text model, so I added an ImageOnlyCheckpointLoader to
load it. Note that the unCLIPCheckpointLoader can also load it, because it
also has a CLIP_VISION output.
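
For example, a minimal sketch of calling the loader directly inside a
ComfyUI environment (the checkpoint filename is hypothetical):

    loader = ImageOnlyCheckpointLoader()
    model, clip_vision, vae = loader.load_checkpoint("svd_img2vid.safetensors")
    # The tuple mirrors RETURN_TYPES: MODEL, CLIP_VISION, VAE; there is no
    # CLIP text encoder, since the checkpoint does not contain one.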

SDV_img2vid_Conditioning is the node used to pass the right conditioning
to the img2vid model.
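
Concretely, the node returns two ComfyUI conditioning lists plus an empty
latent batch; a sketch of the layout (mirroring the encode method in the
diff below, with illustrative default values):

    positive = [[pooled_image_embeds, {"motion_bucket_id": 127, "fps": 6,
                                       "augmentation_level": 0.0,
                                       "concat_latent_image": encoded_init_image}]]
    # negative has the same layout with zeroed embeds and a zeroed concat
    # latent; the LATENT output is an empty
    # [video_frames, 4, height // 8, width // 8] batch.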

VideoLinearCFGGuidance applies a CFG scale that varies linearly across the
video frames, from min_cfg on the first frame up to the cfg set in the
sampler node on the last frame.
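
For example, a quick sketch of the per-frame schedule the node computes
(values are illustrative; at sampling time the node derives the frame count
from cond.shape[0]):

    import torch
    min_cfg, sampler_cfg, frames = 1.0, 2.5, 14
    scale = torch.linspace(min_cfg, sampler_cfg, frames)
    # frame 0 is guided at 1.0, frame 13 at 2.5, with an even step of
    # (2.5 - 1.0) / 13 ≈ 0.115 between consecutive frames.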

SDV_img2vid_Conditioning can be found in conditioning->video_models
ImageOnlyCheckpointLoader can be found in loaders->video_models
VideoLinearCFGGuidance can be found in sampling->video_models
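
Put together, the three nodes chain like this; a minimal sketch assuming
direct Python calls inside a ComfyUI environment rather than UI wiring
(the filename, init_image, and settings are illustrative):

    model, clip_vision, vae = ImageOnlyCheckpointLoader().load_checkpoint("svd_img2vid.safetensors")
    (model,) = VideoLinearCFGGuidance().patch(model, min_cfg=1.0)
    positive, negative, latent = SDV_img2vid_Conditioning().encode(
        clip_vision, init_image, vae, width=1024, height=576,
        video_frames=14, motion_bucket_id=127, fps=6, augmentation_level=0.0)
    # init_image is an IMAGE tensor [B, H, W, C] loaded elsewhere (assumed);
    # model, positive, negative, and latent then feed a regular KSampler.
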
comfyanonymous committed Nov 24, 2023
1 parent 871cc20 commit 42dfae6
Showing 2 changed files with 90 additions and 0 deletions.
89 changes: 89 additions & 0 deletions comfy_extras/nodes_video_model.py
@@ -0,0 +1,89 @@
import nodes
import torch
import comfy.utils
import comfy.sd
import folder_paths


class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"

CATEGORY = "loaders/video_models"

def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])


class SDV_img2vid_Conditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_vision": ("CLIP_VISION",),
"init_image": ("IMAGE",),
"vae": ("VAE",),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}),
"motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}),
"fps": ("INT", {"default": 6, "min": 1, "max": 1024}),
"augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01})
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")

FUNCTION = "encode"

CATEGORY = "conditioning/video_models"

def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level):
output = clip_vision.encode_image(init_image)
pooled = output.image_embeds.unsqueeze(0)
pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1)
encode_pixels = pixels[:,:,:,:3]
        if augmentation_level > 0:
            # Noise only the 3 RGB channels; randn_like(pixels) would mismatch
            # shapes for RGBA inputs, and avoiding += keeps the upscaled
            # pixels tensor unmodified.
            encode_pixels = encode_pixels + torch.randn_like(encode_pixels) * augmentation_level
t = vae.encode(encode_pixels)
positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]]
negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]]
latent = torch.zeros([video_frames, 4, height // 8, width // 8])
return (positive, negative, {"samples":latent})

class VideoLinearCFGGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "sampling/video_models"

def patch(self, model, min_cfg):
def linear_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]

scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1))
return uncond + scale * (cond - uncond)

m = model.clone()
m.set_model_sampler_cfg_function(linear_cfg)
return (m, )

NODE_CLASS_MAPPINGS = {
"ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader,
"SDV_img2vid_Conditioning": SDV_img2vid_Conditioning,
"VideoLinearCFGGuidance": VideoLinearCFGGuidance,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)",
}
1 change: 1 addition & 0 deletions nodes.py
@@ -1850,6 +1850,7 @@ def init_custom_nodes():
"nodes_model_advanced.py",
"nodes_model_downscale.py",
"nodes_images.py",
"nodes_video_model.py",
]

for node_file in extras_files:
