Skip to content

Commit a1e1c69

Browse files
LoadImage now loads all the frames from animated images as a batch.
1 parent 5f54614 commit a1e1c69

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

nodes.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import time
1010
import random
1111

12-
from PIL import Image, ImageOps
12+
from PIL import Image, ImageOps, ImageSequence
1313
from PIL.PngImagePlugin import PngInfo
1414
import numpy as np
1515
import safetensors.torch
@@ -1410,17 +1410,30 @@ def INPUT_TYPES(s):
14101410
FUNCTION = "load_image"
14111411
def load_image(self, image):
14121412
image_path = folder_paths.get_annotated_filepath(image)
1413-
i = Image.open(image_path)
1414-
i = ImageOps.exif_transpose(i)
1415-
image = i.convert("RGB")
1416-
image = np.array(image).astype(np.float32) / 255.0
1417-
image = torch.from_numpy(image)[None,]
1418-
if 'A' in i.getbands():
1419-
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1420-
mask = 1. - torch.from_numpy(mask)
1413+
img = Image.open(image_path)
1414+
output_images = []
1415+
output_masks = []
1416+
for i in ImageSequence.Iterator(img):
1417+
i = ImageOps.exif_transpose(i)
1418+
image = i.convert("RGB")
1419+
image = np.array(image).astype(np.float32) / 255.0
1420+
image = torch.from_numpy(image)[None,]
1421+
if 'A' in i.getbands():
1422+
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
1423+
mask = 1. - torch.from_numpy(mask)
1424+
else:
1425+
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1426+
output_images.append(image)
1427+
output_masks.append(mask.unsqueeze(0))
1428+
1429+
if len(output_images) > 1:
1430+
output_image = torch.cat(output_images, dim=0)
1431+
output_mask = torch.cat(output_masks, dim=0)
14211432
else:
1422-
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
1423-
return (image, mask.unsqueeze(0))
1433+
output_image = output_images[0]
1434+
output_mask = output_masks[0]
1435+
1436+
return (output_image, output_mask)
14241437

14251438
@classmethod
14261439
def IS_CHANGED(s, image):

0 commit comments

Comments
 (0)