Skip to content

Commit

Permalink
Merge branch 'master' into beta
Browse files Browse the repository at this point in the history
  • Loading branch information
jn-jairo committed Dec 21, 2023
2 parents 351a042 + a1e1c69 commit 05e3554
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 14 deletions.
6 changes: 4 additions & 2 deletions comfy_extras/nodes_custom_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,17 @@ def INPUT_TYPES(s):
return {"required":
{"model": ("MODEL",),
"steps": ("INT", {"default": 1, "min": 1, "max": 10}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"

FUNCTION = "get_sigmas"

def get_sigmas(self, model, steps):
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps]
def get_sigmas(self, model, steps, denoise):
start_step = 10 - int(10 * denoise)
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
sigmas = model.model.model_sampling.sigma(timesteps)
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
return (sigmas, )
Expand Down
32 changes: 31 additions & 1 deletion comfy_extras/nodes_rebatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,40 @@ def rebatch(self, latents, batch_size):

return (output_list,)

class ImageRebatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "images": ("IMAGE",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
INPUT_IS_LIST = True
OUTPUT_IS_LIST = (True, )

FUNCTION = "rebatch"

CATEGORY = "image/batch"

def rebatch(self, images, batch_size):
batch_size = batch_size[0]

output_list = []
all_images = []
for img in images:
for i in range(img.shape[0]):
all_images.append(img[i:i+1])

for i in range(0, len(all_images), batch_size):
output_list.append(torch.cat(all_images[i:i+batch_size], dim=0))

return (output_list,)

NODE_CLASS_MAPPINGS = {
"RebatchLatents": LatentRebatch,
"RebatchImages": ImageRebatch,
}

NODE_DISPLAY_NAME_MAPPINGS = {
"RebatchLatents": "Rebatch Latents",
}
"RebatchImages": "Rebatch Images",
}
35 changes: 24 additions & 11 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
import random

from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageSequence
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
Expand Down Expand Up @@ -1410,17 +1410,30 @@ def INPUT_TYPES(s):
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
i = Image.open(image_path)
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
img = Image.open(image_path)
output_images = []
output_masks = []
for i in ImageSequence.Iterator(img):
i = ImageOps.exif_transpose(i)
image = i.convert("RGB")
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))

if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (image, mask.unsqueeze(0))
output_image = output_images[0]
output_mask = output_masks[0]

return (output_image, output_mask)

@classmethod
def IS_CHANGED(s, image):
Expand Down

0 comments on commit 05e3554

Please sign in to comment.