-
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathyolo_seg.py
117 lines (94 loc) · 4.08 KB
/
yolo_seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
@author: jags111
@title: Jags_VectorMagic
@nickname: Jags_VectorMagic
@description: This extension offers various vector manipulation and generation tools
"""
import folder_paths
from PIL import Image
import numpy as np
from ultralytics import YOLO
import torch
import os
import nodes
from typing import Optional
import comfy
# Register a "yolov8" model category so that folder_paths.get_filename_list("yolov8")
# lists weight files found under <models_dir>/yolov8 whose extension is in
# folder_paths.supported_pt_extensions. Plain assignment: any "yolov8" entry
# registered earlier is replaced, not extended.
folder_paths.folder_names_and_paths["yolov8"] = (
    [os.path.join(folder_paths.models_dir, "yolov8")],
    folder_paths.supported_pt_extensions,
)
class YoloSEGdetectionNode:
    """ComfyUI node that runs a YOLOv8 model and returns the rendered predictions.

    The selected weights file is loaded from ``<models_dir>/yolov8`` and
    inference is run on every image of the incoming batch. The result is the
    Ultralytics ``plot()`` overlay (boxes, labels and, for segmentation
    weights, masks) for each image, returned as a new IMAGE batch.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "model_name": (folder_paths.get_filename_list("yolov8"), ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("SEG_IMAGE",)
    FUNCTION = "detect"
    CATEGORY = "Jags_vector/yoloSEG"

    @staticmethod
    def _frame_to_uint8(frame: torch.Tensor) -> np.ndarray:
        """Convert one (H, W, C) float frame in [0, 1] to a uint8 numpy array.

        Values are clipped to [0, 255] before the cast so out-of-range floats
        saturate instead of overflowing the uint8 conversion.
        """
        return np.clip(frame.cpu().numpy() * 255.0, 0, 255).astype(np.uint8)

    @staticmethod
    def _plot_to_tensor(plot_bgr: np.ndarray) -> torch.Tensor:
        """Convert a BGR uint8 array from ``Results.plot()`` to an (H, W, 3) RGB float tensor in [0, 1]."""
        # Reversing the channel axis yields negative strides, which torch rejects,
        # so materialise a contiguous copy first.
        rgb = np.ascontiguousarray(plot_bgr[..., ::-1])
        return torch.from_numpy(rgb.astype(np.float32) / 255.0)

    def detect(self, image, model_name):
        """Run the selected YOLOv8 weights on each image and return the annotated batch.

        Args:
            image: ComfyUI IMAGE batch, iterated over its first dimension. Each
                item is treated as an (H, W, C) float frame in [0, 1].
            model_name: weights file name inside the ``yolov8`` models folder.

        Returns:
            A 1-tuple with a (B, H, W, 3) float tensor of annotated images.
        """
        model_path = os.path.join(folder_paths.models_dir, "yolov8", model_name)
        print(f"model_path: {model_path}")
        model = YOLO(model_path)  # load the custom weights once per call

        rendered = []
        for frame in image:  # one inference per batch item
            # Passed as a PIL image: Ultralytics treats PIL sources as RGB,
            # whereas raw numpy arrays are interpreted as BGR.
            pil_image = Image.fromarray(self._frame_to_uint8(frame))
            results = model(pil_image)
            rendered.append(self._plot_to_tensor(results[0].plot()))
        return (torch.stack(rendered),)
class YoloSegNode:
    """ComfyUI node returning a YOLOv8 segmentation overlay plus a mask for one class.

    For each image of the batch, the selected segmentation weights are run.
    The node returns:

    * ``SEG_IMAGE``: the Ultralytics ``plot()`` overlay.
    * ``SEG_MASK``: the union of every instance mask whose predicted class
      equals ``class_id``, as float 0/1 values.
    """

    def __init__(self) -> None:
        ...

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "model_name": (folder_paths.get_filename_list("yolov8"), ),
                "class_id": ("INT", {"default": 0})
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK",)
    RETURN_NAMES = ("SEG_IMAGE", "SEG_MASK",)
    FUNCTION = "seg"
    CATEGORY = "Jags_vector/yoloSEG"

    @staticmethod
    def _frame_to_uint8(frame: torch.Tensor) -> np.ndarray:
        """Convert one (H, W, C) float frame in [0, 1] to a uint8 numpy array.

        Values are clipped to [0, 255] before the cast so out-of-range floats
        saturate instead of overflowing the uint8 conversion.
        """
        return np.clip(frame.cpu().numpy() * 255.0, 0, 255).astype(np.uint8)

    @staticmethod
    def _plot_to_tensor(plot_bgr: np.ndarray) -> torch.Tensor:
        """Convert a BGR uint8 array from ``Results.plot()`` to an (H, W, 3) RGB float tensor in [0, 1]."""
        # Reversing the channel axis yields negative strides, which torch rejects,
        # so materialise a contiguous copy first.
        rgb = np.ascontiguousarray(plot_bgr[..., ::-1])
        return torch.from_numpy(rgb.astype(np.float32) / 255.0)

    @staticmethod
    def _class_mask(result, class_id: int, height: int, width: int) -> torch.Tensor:
        """Merge all instance masks predicted as ``class_id`` into one (h, w) float32 CPU mask of 0/1.

        ``result.masks`` is ``None`` when the model detected nothing. In that
        case an all-zero (height, width) mask is returned instead of raising.
        """
        if result.masks is None:
            return torch.zeros((height, width), dtype=torch.float32)
        instance_masks = result.masks.data  # one (h, w) mask per detection
        selected = instance_masks[result.boxes.cls == class_id]
        # any() over an empty selection yields an all-False (h, w) map, so
        # "no instance of this class" needs no special case.
        return (selected > 0.5).any(dim=0).to(torch.float32).cpu()

    def seg(self, image, model_name, class_id):
        """Segment each image and build a mask of every instance of ``class_id``.

        Args:
            image: ComfyUI IMAGE batch, iterated over its first dimension. Each
                item is treated as an (H, W, C) float frame in [0, 1].
            model_name: segmentation weights file inside the ``yolov8`` models folder.
            class_id: class index in the model's label list (0 is "person" for
                COCO-trained weights).

        Returns:
            ``(overlays, masks)``:

            * ``overlays``: a (B, H, W, 3) float tensor.
            * ``masks``: a (B, H, W) float tensor with values in {0, 1}.
        """
        model_path = os.path.join(folder_paths.models_dir, "yolov8", model_name)
        print(f"model_path: {model_path}")
        model = YOLO(model_path)  # load the custom weights once per call

        overlays, masks = [], []
        for frame in image:  # one inference per batch item
            height, width = frame.shape[0], frame.shape[1]
            # PIL input is treated as RGB by Ultralytics (numpy input is BGR).
            pil_image = Image.fromarray(self._frame_to_uint8(frame))
            # retina_masks=True returns masks at the source image resolution
            # rather than the letterboxed inference size, so SEG_MASK lines up
            # pixel-for-pixel with SEG_IMAGE.
            result = model(pil_image, retina_masks=True)[0]
            overlays.append(self._plot_to_tensor(result.plot()))
            masks.append(self._class_mask(result, class_id, height, width))
        return (torch.stack(overlays), torch.stack(masks))
# Module-level registries exported to ComfyUI.
# NODE_CLASS_MAPPINGS maps each node's unique type name to its implementing class.
NODE_CLASS_MAPPINGS = {
    "YoloSEGdetectionNode": YoloSEGdetectionNode,
    "YoloSegNode": YoloSegNode,
}

# NODE_DISPLAY_NAME_MAPPINGS maps the same type names to the labels shown in the UI.
NODE_DISPLAY_NAME_MAPPINGS = {
    "YoloSEGdetectionNode": 'Jags-YoloSEGdetectionNode',
    "YoloSegNode": 'Jags-YoloSegNode',
}