Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -556,13 +556,7 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo)
new JObject()
{
["class_type"] = "DownloadAndLoadSAM2Model",
["inputs"] = new JObject()
{
["model"] = $"sam2_hiera_{size}.safetensors",
["segmentor"] = "automaskgenerator",
["device"] = "cuda", // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not
["precision"] = "bf16"
}
["inputs"] = Sam2ModelInputs(size, "automaskgenerator")
},
new JObject()
{
Expand Down Expand Up @@ -651,9 +645,40 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo)

public static T2IParamGroup ComfyAdvancedGroup;

/// <summary>Internal param: base image that SAM2 point/bbox masking operates on (falls back to the main init image when absent).</summary>
public static T2IRegisteredParam<Image> Sam2PointImage;

/// <summary>Internal params for SAM2 masking: JSON lists of positive/negative point coordinates, a JSON bounding box "[x1,y1,x2,y2]", and a pixel count to dilate the resulting mask by.</summary>
public static T2IRegisteredParam<string> Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox, Sam2MaskPadding;

/// <summary>Creates the standard input set for a DownloadAndLoadSAM2Model node.</summary>
/// <param name="size">SAM2 model size variant (eg "base_plus"), used to build the "sam2_hiera_{size}.safetensors" filename.</param>
/// <param name="segmentor">Segmentor mode for the node (eg "single_image" or "automaskgenerator").</param>
/// <param name="device">Compute device the node should run on.</param>
/// <param name="precision">Numeric precision for the model weights.</param>
public static JObject Sam2ModelInputs(string size = "base_plus", string segmentor = "single_image", string device = "cuda", string precision = "bf16")
{
    return new JObject()
    {
        ["model"] = $"sam2_hiera_{size}.safetensors",
        ["segmentor"] = segmentor,
        ["device"] = device, // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not
        ["precision"] = precision
    };
}

/// <inheritdoc/>
public override void OnInit()
{
Sam2PointImage = T2IParamTypes.Register<Image>(new("SAM2 Point Image", "Internal: Base image used for SAM2 point masking.",
null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true
));
Sam2PointCoordsPositive = T2IParamTypes.Register<string>(new("SAM2 Positive Points", "Internal: JSON list of positive point coordinates for SAM2 point masking.",
"[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true
));
Sam2PointCoordsNegative = T2IParamTypes.Register<string>(new("SAM2 Negative Points", "Internal: JSON list of negative point coordinates for SAM2 point masking.",
"[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true
));
Sam2BBox = T2IParamTypes.Register<string>(new("SAM2 BBox", "Internal: JSON bounding box [x1,y1,x2,y2] for SAM2 bbox masking.",
null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true
));
Sam2MaskPadding = T2IParamTypes.Register<string>(new("SAM2 Mask Padding", "Internal: Number of pixels to dilate/expand the SAM2 mask boundary.",
"0", IgnoreIf: "0", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true
));
UseIPAdapterForRevision = T2IParamTypes.Register<string>(new("Use IP-Adapter", $"Select an IP-Adapter model to use IP-Adapter for image-prompt input handling.\nModels will automatically be downloaded when you first use them.\nNote if you use a custom model, you must also set your CLIP-Vision Model under Advanced Model Addons, otherwise CLIP Vision G will be presumed.\n<a target=\"_blank\" href=\"{Utilities.RepoDocsRoot}/Features/ImagePrompting.md\">See more docs here.</a>",
"None", IgnoreIf: "None", FeatureFlag: "ipadapter", GetValues: _ => IPAdapterModels, Group: T2IParamTypes.GroupImagePrompting, OrderPriority: 15, ChangeWeight: 1
));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json


class Sam2BBoxFromJson:
    """Converts a JSON bounding box string into a BBOX value that can be passed
    directly to Sam2Segmentation's bboxes input.

    Accepts either a single flat box '[x1,y1,x2,y2]' or a list of boxes
    '[[x1,y1,x2,y2], ...]'; always returns a list-of-boxes BBOX tuple.
    """

    @classmethod
    def INPUT_TYPES(s):
        # forceInput makes this a linkable socket rather than a text widget in the graph.
        return {
            "required": {
                "bbox_json": ("STRING", {"forceInput": True}),
            }
        }

    RETURN_TYPES = ("BBOX",)
    RETURN_NAMES = ("bboxes",)
    FUNCTION = "convert"
    CATEGORY = "SAM2"

    def convert(self, bbox_json):
        """Parse bbox_json and return a one-element tuple holding a list of [x1,y1,x2,y2] float boxes.

        Raises ValueError if any box does not contain exactly 4 coordinates.
        """
        coords = json.loads(bbox_json)
        # Generalized: wrap a flat [x1,y1,x2,y2] so single-box and multi-box inputs both work.
        boxes = coords if coords and isinstance(coords[0], (list, tuple)) else [coords]
        result = []
        for box in boxes:
            if len(box) != 4:
                raise ValueError(f"Expected 4 bbox coordinates [x1,y1,x2,y2], got {len(box)}")
            result.append([float(box[0]), float(box[1]), float(box[2]), float(box[3])])
        return (result,)


# Registration tables read by ComfyUI to discover this node and its display name.
NODE_CLASS_MAPPINGS = {
    "Sam2BBoxFromJson": Sam2BBoxFromJson,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Sam2BBoxFromJson": "SAM2 BBox From JSON",
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
import numpy as np
import cv2


def fill_mask_holes(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Fill small holes in a binary mask using morphological close + flood fill.

    Args:
        mask: Mask array; squeezed to 2D (extra channels dropped). Bool masks and
            [0,1]-ranged float masks are scaled to 0-255; other dtypes are cast directly.
        kernel_size: Side length of the elliptical structuring element for the close op.

    Returns:
        A 2D uint8 mask with values 0 or 255 and all enclosed holes filled.
    """
    mask = np.squeeze(mask)
    if mask.ndim == 0:
        # Degenerate scalar input: return a single filled pixel.
        return np.array([[255]], dtype=np.uint8)
    if mask.ndim > 2:
        mask = mask[:, :, 0]
    if mask.dtype != np.uint8:
        # Scale bool and any [0,1]-ranged floating dtype up to 0-255. The original
        # only special-cased float32/float64, so a float16 mask in [0,1] was floored
        # to 0/1 by the raw cast below, yielding a near-black (all background) mask.
        if mask.dtype == bool or (np.issubdtype(mask.dtype, np.floating) and mask.max() <= 1):
            mask = (mask * 255).astype(np.uint8)
        else:
            mask = mask.astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    # Flood fill the background from a zero-padded border; any region the fill
    # cannot reach is an enclosed hole, which becomes foreground in the result.
    filled_mask = closed_mask.copy()
    h, w = filled_mask.shape
    canvas = np.zeros((h + 2, w + 2), dtype=np.uint8)
    canvas[1:-1, 1:-1] = filled_mask
    cv2.floodFill(canvas, None, (0, 0), 128)
    filled_mask = np.where(canvas[1:-1, 1:-1] == 128, 0, 255).astype(np.uint8)
    return filled_mask


def add_mask_padding(mask: np.ndarray, padding: int = 0) -> np.ndarray:
    """Expand a mask's boundary outward by `padding` pixels via elliptical dilation.

    Returns the mask unchanged when padding is zero or negative.
    """
    if padding <= 0:
        return mask
    diameter = padding * 2 + 1
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))
    return cv2.dilate(mask, ellipse, iterations=1)


class SwarmSam2MaskPostProcess:
    """Post-processes a SAM2 segmentation mask with hole-filling and padding."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "fill_holes": ("BOOLEAN", {"default": True}),
                "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}),
                "mask_padding": ("INT", {"default": 0, "min": 0, "max": 256, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "process"
    CATEGORY = "SAM2"

    def process(self, mask, fill_holes=True, hole_kernel_size=5, mask_padding=0):
        """Apply hole-filling and padding to each mask in the batch, returning a stacked MASK tensor."""
        def refine(single):
            # Work in 0-255 uint8 space for the OpenCV-based helpers, then scale back to [0,1].
            frame = (single.cpu().numpy() * 255).astype(np.uint8)
            if fill_holes:
                frame = fill_mask_holes(frame, kernel_size=hole_kernel_size)
            if mask_padding > 0:
                frame = add_mask_padding(frame, padding=mask_padding)
            return torch.from_numpy(frame.astype(np.float32) / 255.0)

        refined = [refine(mask[i]) for i in range(mask.shape[0])]
        return (torch.stack(refined, dim=0),)


# Registration tables read by ComfyUI to discover this node and its display name.
NODE_CLASS_MAPPINGS = {
    "SwarmSam2MaskPostProcess": SwarmSam2MaskPostProcess,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SwarmSam2MaskPostProcess": "SAM2 Mask Post-Process (Fill Holes + Padding)",
}
103 changes: 103 additions & 0 deletions src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1676,6 +1676,109 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner)
RunSegmentationProcessing(g, isBeforeRefiner: false);
}, 5);
#endregion
#region SAM2 Masking
// SAM2 point-based masking: when positive point coordinates are supplied, build a
// SAM2 segmentation graph whose output is the resulting mask image, then halt the
// normal generation steps.
AddStep(g =>
{
    if (!g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointCoordsPositive, out string coords) || string.IsNullOrWhiteSpace(coords) || coords == "[]")
    {
        return;
    }
    string negCoords = null;
    if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointCoordsNegative, out string negCoordsRaw) && !string.IsNullOrWhiteSpace(negCoordsRaw) && negCoordsRaw != "[]")
    {
        negCoords = negCoordsRaw;
    }
    // Prefer the dedicated SAM2 point image; fall back to the base input image when present.
    JArray imageNodeActual = null;
    if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointImage, out Image img))
    {
        WGNodeData imageNode = g.LoadImage(img, "${sampointimage}", true);
        imageNodeActual = imageNode.Path;
    }
    else if (g.BasicInputImage is not null)
    {
        imageNodeActual = g.BasicInputImage.Path;
    }
    if (imageNodeActual is null)
    {
        return;
    }
    string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", ComfyUIBackendExtension.Sam2ModelInputs());
    JObject segInputs = new()
    {
        ["sam2_model"] = new JArray() { modelNode, 0 },
        ["image"] = imageNodeActual,
        ["keep_model_loaded"] = true,
        ["coordinates_positive"] = coords
    };
    if (negCoords is not null)
    {
        segInputs["coordinates_negative"] = negCoords;
    }
    string segNode = g.CreateNode("Sam2Segmentation", segInputs);
    // int.TryParse leaves the out value at 0 on failure, so no fallback ternary is needed.
    int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int pointsPadding);
    string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject()
    {
        ["mask"] = new JArray() { segNode, 0 },
        ["fill_holes"] = true,
        ["hole_kernel_size"] = 9,
        ["mask_padding"] = pointsPadding
    });
    string maskNode = g.CreateNode("MaskToImage", new JObject()
    {
        ["mask"] = new JArray() { postNode, 0 }
    });
    // Save the mask image as the workflow output (presumably "9" is the standard output node id
    // elsewhere in this generator - confirm against other SaveOutput calls) and skip remaining steps.
    new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9");
    g.SkipFurtherSteps = true;
}, 8.9);
// SAM2 bbox-based masking: when a JSON bounding box is supplied, build a SAM2
// segmentation graph whose output is the resulting mask image, then halt the
// normal generation steps. Runs at 8.85 so it takes priority over the point step.
AddStep(g =>
{
    // Also reject an empty "[]" box (matching the point step's filter) - it would
    // crash the python-side Sam2BBoxFromJson node with no coordinates to read.
    if (!g.UserInput.TryGet(ComfyUIBackendExtension.Sam2BBox, out string bboxJson) || string.IsNullOrWhiteSpace(bboxJson) || bboxJson == "[]")
    {
        return;
    }
    // Prefer the dedicated SAM2 point image; fall back to the base input image when present.
    JArray imageNodeActual = null;
    if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointImage, out Image img))
    {
        WGNodeData imageNode = g.LoadImage(img, "${sampointimage}", true);
        imageNodeActual = imageNode.Path;
    }
    else if (g.BasicInputImage is not null)
    {
        imageNodeActual = g.BasicInputImage.Path;
    }
    if (imageNodeActual is null)
    {
        return;
    }
    string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", ComfyUIBackendExtension.Sam2ModelInputs());
    // Custom swarm node that converts the JSON string to the BBOX type Sam2Segmentation expects.
    string bboxNode = g.CreateNode("Sam2BBoxFromJson", new JObject()
    {
        ["bbox_json"] = bboxJson
    });
    JObject segInputs = new()
    {
        ["sam2_model"] = new JArray() { modelNode, 0 },
        ["image"] = imageNodeActual,
        ["keep_model_loaded"] = true,
        ["bboxes"] = new JArray() { bboxNode, 0 }
    };
    string segNode = g.CreateNode("Sam2Segmentation", segInputs);
    // int.TryParse leaves the out value at 0 on failure, so no fallback ternary is needed.
    int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int bboxPadding);
    string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject()
    {
        ["mask"] = new JArray() { segNode, 0 },
        ["fill_holes"] = true,
        ["hole_kernel_size"] = 5,
        ["mask_padding"] = bboxPadding
    });
    string maskNode = g.CreateNode("MaskToImage", new JObject()
    {
        ["mask"] = new JArray() { postNode, 0 }
    });
    // Save the mask image as the workflow output (presumably "9" is the standard output node id
    // elsewhere in this generator - confirm against other SaveOutput calls) and skip remaining steps.
    new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9");
    g.SkipFurtherSteps = true;
}, 8.85);
#endregion
#region SaveImage
AddStep(g =>
{
Expand Down
Binary file added src/wwwroot/imgs/crosshair.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added src/wwwroot/imgs/rectangle.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading