From 999ed7d0c4ac2b7530e6154e13e417aeac85793d Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 16:15:39 -0600 Subject: [PATCH 1/6] Add SAM2 segmentation tools and model integration - Implemented SAM2 point segmentation tool allowing users to place positive and negative points to generate masks. - Added SAM2 bounding box segmentation tool for users to draw bounding boxes and generate masks. - Integrated SAM2 model loading and inference into the workflow generator. - Introduced new UI elements for mask padding configuration and clear mask functionality. - Added crosshair and rectangle images for cursor representation in the editor. - Enhanced image editor to handle mask application from generated images. --- .../ComfyUIBackend/ComfyUIBackendExtension.cs | 21 + .../ExtraNodes/Sam2BBoxNode/__init__.py | 232 ++++++++ .../ExtraNodes/Sam2BBoxNode/sam2.py | 498 ++++++++++++++++++ .../ComfyUIBackend/WorkflowGeneratorSteps.cs | 107 ++++ src/wwwroot/imgs/crosshair.png | Bin 0 -> 233 bytes src/wwwroot/imgs/rectangle.png | Bin 0 -> 236 bytes .../js/genpage/helpers/image_editor.js | 430 +++++++++++++++ 7 files changed, 1288 insertions(+) create mode 100644 src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py create mode 100644 src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py create mode 100644 src/wwwroot/imgs/crosshair.png create mode 100644 src/wwwroot/imgs/rectangle.png diff --git a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs index ecaa147c2..341f17e4d 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs @@ -651,9 +651,30 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) public static T2IParamGroup ComfyAdvancedGroup; + public static T2IRegisteredParam Sam2PointImage; + + public static T2IRegisteredParam Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox; + + public static T2IRegisteredParam Sam2MaskPadding; + /// public override void OnInit() { + Sam2PointImage = T2IParamTypes.Register(new("SAM2 Point Image", "Internal: Base image used for SAM2 point masking.", + null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2PointCoordsPositive = T2IParamTypes.Register(new("SAM2 Positive Points", "Internal: JSON list of positive point coordinates for SAM2 point masking.", + "[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2PointCoordsNegative = T2IParamTypes.Register(new("SAM2 Negative Points", "Internal: JSON list of negative point coordinates for SAM2 point masking.", + "[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2BBox = T2IParamTypes.Register(new("SAM2 BBox", "Internal: JSON bounding box [x1,y1,x2,y2] for SAM2 bbox masking.", + null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2MaskPadding = T2IParamTypes.Register(new("SAM2 Mask Padding", "Internal: Number of pixels to dilate/expand the SAM2 mask boundary.", + "0", IgnoreIf: "0", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); UseIPAdapterForRevision = T2IParamTypes.Register(new("Use IP-Adapter", $"Select an IP-Adapter model to use IP-Adapter for image-prompt input handling.\nModels will automatically be downloaded when you first use them.\nNote if you use a custom model, you must also set your CLIP-Vision Model under Advanced Model Addons, otherwise CLIP Vision G will be presumed.\nSee more docs here.", "None", IgnoreIf: "None", FeatureFlag: "ipadapter", GetValues: _ => IPAdapterModels, Group: T2IParamTypes.GroupImagePrompting, OrderPriority: 15, ChangeWeight: 1 )); diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py new file mode 100644 index 000000000..de629bfb6 --- /dev/null +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py @@ -0,0 +1,232 @@ +import json +import torch +import numpy as np +from PIL import Image +from io import BytesIO + +from .sam2 import ( + predict_mask_from_points, + predict_mask_from_bboxes, + crop_image_with_mask, + remove_background, +) + + +class Sam2BBoxFromJson: + """Converts a JSON bounding box string '[x1,y1,x2,y2]' into a BBOX type + that can be passed directly to Sam2Segmentation's bboxes input.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "bbox_json": ("STRING", {"forceInput": True}), + } + } + + RETURN_TYPES = ("BBOX",) + RETURN_NAMES = ("bboxes",) + FUNCTION = "convert" + CATEGORY = "SAM2" + + def convert(self, bbox_json): + coords = json.loads(bbox_json) + # coords = [x1, y1, x2, y2] + # BBOX type is a list of [x1, y1, x2, y2] lists (one per box) + return ([[float(coords[0]), float(coords[1]), float(coords[2]), float(coords[3])]],) + + +class Sam2PointSegmentation: + """SAM2 segmentation using point prompts with hole-filling""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "points_json": ("STRING", {"multiline": True}), + "model_name": (["sam2_b.pt", "sam2_l.pt", "sam2_s.pt", "sam2_t.pt"], {"default": "sam2_b.pt"}), + }, + "optional": { + "fill_holes": ("BOOLEAN", {"default": True}), + "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}), + "mask_padding": ("INT", {"default": 0, "min": 0, "max": 50}), + } + } + + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("mask",) + FUNCTION = "segment" + CATEGORY = "SAM2" + + def segment(self, image, points_json, model_name, fill_holes=True, hole_kernel_size=5, mask_padding=0): + """ + Generate segmentation mask from point prompts + + Points JSON format: [ + {"x": 100, "y": 200, "label": 1}, + {"x": 150, "y": 250, "label": 0} + ] + where label: 1 = foreground, 0 = background + """ + # Parse points from JSON + try: + points_data = json.loads(points_json) + if not isinstance(points_data, list): + raise ValueError("Points must be a JSON array") + + points = [(p["x"], p["y"]) for p in points_data] + labels = [p.get("label", 1) for p in points_data] + except Exception as e: + raise ValueError(f"Invalid points JSON: {e}") + + # Convert image to bytes + img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) + img_buffer = BytesIO() + img_pil.save(img_buffer, format="JPEG") + image_bytes = img_buffer.getvalue() + + # Generate mask using helper function + mask = predict_mask_from_points( + image_bytes, + points, + labels, + model_name=model_name, + fill_holes=fill_holes, + kernel_size=hole_kernel_size, + padding=mask_padding, + ) + + # Convert to tensor format (B, H, W) + mask_tensor = torch.from_numpy(mask).float() / 255.0 + return (mask_tensor.unsqueeze(0),) + + +class Sam2BBoxSegmentation: + """SAM2 segmentation using bounding box prompts with hole-filling""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "bbox_json": ("STRING", {"multiline": True}), + "model_name": (["sam2_b.pt", "sam2_l.pt", "sam2_s.pt", "sam2_t.pt"], {"default": "sam2_b.pt"}), + }, + "optional": { + "fill_holes": ("BOOLEAN", {"default": True}), + "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}), + "mask_padding": ("INT", {"default": 0, "min": 0, "max": 50}), + } + } + + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("mask",) + FUNCTION = "segment" + CATEGORY = "SAM2" + + def segment(self, image, bbox_json, model_name, fill_holes=True, hole_kernel_size=5, mask_padding=0): + """ + Generate segmentation mask from bounding box prompt + + BBox JSON format: [x1, y1, x2, y2] or [[x1, y1, x2, y2], ...] + """ + # Parse bbox from JSON + try: + bbox_data = json.loads(bbox_json) + if isinstance(bbox_data[0], (list, tuple)): + bboxes = bbox_data[0] # Multiple boxes + else: + bboxes = bbox_data # Single box + except Exception as e: + raise ValueError(f"Invalid bbox JSON: {e}") + + # Convert image to bytes + img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) + img_buffer = BytesIO() + img_pil.save(img_buffer, format="JPEG") + image_bytes = img_buffer.getvalue() + + # Generate mask using helper function + mask = predict_mask_from_bboxes( + image_bytes, + bboxes, + model_name=model_name, + fill_holes=fill_holes, + kernel_size=hole_kernel_size, + padding=mask_padding, + ) + + # Convert to tensor format (B, H, W) + mask_tensor = torch.from_numpy(mask).float() / 255.0 + return (mask_tensor.unsqueeze(0),) + + +class Sam2MaskPostProcess: + """Post-process SAM2 masks: crop to bounds, remove background, etc.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "mask": ("MASK",), + }, + "optional": { + "operation": (["crop_to_mask", "remove_background"], {"default": "crop_to_mask"}), + "padding_pixels": ("INT", {"default": 0, "min": 0, "max": 100}), + "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), + } + } + + RETURN_TYPES = ("IMAGE", "STRING") + RETURN_NAMES = ("result", "info") + FUNCTION = "process" + CATEGORY = "SAM2" + + def process(self, image, mask, operation, padding_pixels=0, aspect_ratio=1.0): + """Post-process masks: crop or remove background""" + # Convert image to bytes + img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) + img_buffer = BytesIO() + img_pil.save(img_buffer, format="JPEG") + image_bytes = img_buffer.getvalue() + + # Convert mask to numpy (0 or 255) + mask_np = (mask[0].cpu().numpy() * 255).astype(np.uint8) + + if operation == "crop_to_mask": + target_ratio = aspect_ratio if aspect_ratio != 1.0 else None + result_bytes, info_dict = crop_image_with_mask( + image_bytes, + mask_np, + padding_pixels=padding_pixels, + target_aspect_ratio=target_ratio, + ) + result_img = Image.open(BytesIO(result_bytes)) + else: # remove_background + result_bytes = remove_background(image_bytes, mask_np) + result_img = Image.open(BytesIO(result_bytes)).convert("RGB") + info_dict = {"operation": "background_removed"} + + # Convert back to tensor + result_tensor = torch.from_numpy(np.array(result_img)).float() / 255.0 + if result_tensor.ndim == 2: + result_tensor = result_tensor.unsqueeze(-1).repeat(1, 1, 3) + + return (result_tensor.unsqueeze(0), json.dumps(info_dict)) + + +NODE_CLASS_MAPPINGS = { + "Sam2BBoxFromJson": Sam2BBoxFromJson, + "Sam2PointSegmentation": Sam2PointSegmentation, + "Sam2BBoxSegmentation": Sam2BBoxSegmentation, + "Sam2MaskPostProcess": Sam2MaskPostProcess, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Sam2BBoxFromJson": "SAM2 BBox From JSON", + "Sam2PointSegmentation": "SAM2 Point Segmentation", + "Sam2BBoxSegmentation": "SAM2 BBox Segmentation", + "Sam2MaskPostProcess": "SAM2 Mask Post-Process", +} diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py new file mode 100644 index 000000000..27de71ab7 --- /dev/null +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py @@ -0,0 +1,498 @@ +""" +SAM 2 segmentation utilities +Provides model loading and inference for interactive segmentation +""" + +import numpy as np +from PIL import Image +from io import BytesIO +from typing import List, Tuple, Optional +from pathlib import Path +from ultralytics import SAM +import torch +import cv2 + +# Global model cache +_model_cache = {} + +# Models directory +MODELS_DIR = Path(__file__).parent.parent.parent / "models" +MODELS_DIR.mkdir(exist_ok=True) + +# Detect GPU availability +print(f"PyTorch version: {torch.__version__}") +print(f"CUDA available: {torch.cuda.is_available()}") +print( + f"CUDA version (compiled): {torch.version.cuda if torch.version.cuda else 'None'}" +) +print( + f"cuDNN version: {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 'None'}" +) + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +print(f"SAM 2 will use device: {DEVICE}") +if DEVICE == "cuda": + print(f"GPU: {torch.cuda.get_device_name(0)}") + print( + f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB" + ) +else: + print("⚠️ Running on CPU - segmentation will be slower") + print( + "To enable GPU: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121" + ) + + +def get_sam_model(model_name: str = "sam2_b.pt") -> SAM: + """ + Load and cache SAM 2 model with GPU support + + Args: + model_name: Model name (sam2_b.pt, sam2_l.pt, sam2_s.pt, sam2_t.pt) + + Returns: + SAM model instance + """ + if model_name not in _model_cache: + model_path = MODELS_DIR / model_name + + # If model doesn't exist locally, ultralytics will auto-download it + # Just use the model name without .pt extension for auto-download + if not model_path.exists(): + print(f"Model not found at {model_path}") + print(f"Ultralytics will auto-download {model_name} to cache...") + # Use just the model name (e.g., 'sam2_b.pt') - ultralytics handles download + model = SAM(model_name) + else: + print(f"Loading SAM model from: {model_path} on {DEVICE}") + model = SAM(str(model_path)) + + # Move model to GPU if available + if DEVICE == "cuda": + model.to(DEVICE) + _model_cache[model_name] = model + print(f"SAM model loaded: {model_name} on {DEVICE}") + + return _model_cache[model_name] + + +def clear_model_cache(): + """Clear the model cache to force reload""" + global _model_cache + _model_cache = {} + + +def fill_mask_holes(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray: + """ + Fill small holes in a binary mask using morphological operations + + Args: + mask: Binary mask as numpy array (H, W) or (H, W, C) with values 0 or 255 + kernel_size: Size of the morphological kernel (larger = fills bigger holes) + + Returns: + Cleaned mask with holes filled + """ + # Squeeze to remove any single-dimensional entries and ensure 2D + mask = np.squeeze(mask) + if mask.ndim == 0: + # Scalar case - shouldn't happen but handle it + return np.array([[255]], dtype=np.uint8) + if mask.ndim > 2: + # If still 3D+, take first channel + mask = mask[:, :, 0] + + # Ensure mask is uint8 with values 0 or 255 + if mask.dtype != np.uint8: + if mask.dtype == bool or (mask.max() <= 1 and mask.dtype in [np.float32, np.float64]): + mask = (mask * 255).astype(np.uint8) + else: + mask = mask.astype(np.uint8) + + # Create a kernel for morphological operations + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + + # Morphological closing: dilation followed by erosion + # This fills small holes while preserving the outer boundary + closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + + # Optional: Fill any remaining interior holes using flood fill + # This catches larger holes that closing might miss + filled_mask = closed_mask.copy() + h, w = filled_mask.shape + + # Create a slightly larger canvas for flood fill + canvas = np.zeros((h + 2, w + 2), dtype=np.uint8) + canvas[1:-1, 1:-1] = filled_mask + + # Flood fill from the border to mark the background + cv2.floodFill(canvas, None, (0, 0), 128) + + # Extract the filled region (anything not marked as background is foreground) + filled_mask = np.where(canvas[1:-1, 1:-1] == 128, 0, 255).astype(np.uint8) + + return filled_mask + + +def add_mask_padding(mask: np.ndarray, padding: int = 10) -> np.ndarray: + """ + Add padding to a mask by dilating it outward + + This expands the mask boundary, which is useful for inpainting to: + - Include border pixels for better blending + - Provide more context around the masked area + - Avoid hard edges at the mask boundary + + Args: + mask: Binary mask as numpy array (H, W) with values 0 or 255 + padding: Number of pixels to expand the mask (kernel size for dilation) + + Returns: + Dilated mask with padding added + """ + if padding <= 0: + return mask + + # Create a circular/elliptical kernel for smooth expansion + kernel = cv2.getStructuringElement( + cv2.MORPH_ELLIPSE, (padding * 2 + 1, padding * 2 + 1) + ) + + # Dilate the mask to expand it outward + padded_mask = cv2.dilate(mask, kernel, iterations=1) + + return padded_mask + + +def predict_mask_from_points( + image_bytes: bytes, + points: List[Tuple[int, int]], + labels: List[int], + model_name: str = "sam2_b.pt", + fill_holes: bool = True, + kernel_size: int = 5, + padding: int = 0, +) -> np.ndarray: + """ + Generate segmentation mask from point prompts + + Args: + image_bytes: Image data as bytes + points: List of (x, y) coordinates for prompts + labels: List of labels (1 for foreground, 0 for background) + model_name: SAM model to use + fill_holes: Whether to fill small holes in the mask + kernel_size: Size of morphological kernel for hole filling + padding: Number of pixels to expand the mask boundary (0 = no padding) + + Returns: + Binary mask as numpy array (H, W) with values 0 or 255 + """ + # Load image + img = Image.open(BytesIO(image_bytes)) + img_array = np.array(img) + + # Get cached model (now that wrapping is fixed, caching works properly) + model = get_sam_model(model_name) + + # Convert points and labels to numpy arrays + points_array = np.array(points, dtype=np.float32) + labels_array = np.array(labels, dtype=np.int32) + + print(f"Predicting with {len(points)} points: {points_array}") + print(f"Labels: {labels_array}") + + # IMPORTANT: Wrap points and labels in an extra list dimension + # This tells SAM that all points belong to ONE object + # points=[[[x1,y1], [x2,y2]]] instead of [[x1,y1], [x2,y2]] + points_wrapped = [points_array.tolist()] + labels_wrapped = [labels_array.tolist()] + + print(f"Wrapped points: {points_wrapped}") + print(f"Wrapped labels: {labels_wrapped}") + + # Run prediction with point prompts + results = model( + img_array, points=points_wrapped, labels=labels_wrapped, verbose=False + ) + + # Extract mask from results + if ( + len(results) > 0 + and hasattr(results[0], "masks") + and results[0].masks is not None + ): + # Get the first mask + mask = results[0].masks.data[0].cpu().numpy() + # Convert to uint8 (0 or 255) + mask = (mask * 255).astype(np.uint8) + print( + f"Generated mask with shape: {mask.shape}, unique values: {np.unique(mask)}, sum: {mask.sum()}" + ) + + # Fill holes if requested + if fill_holes: + print(f"Filling holes in mask with kernel size {kernel_size}") + mask = fill_mask_holes(mask, kernel_size) + + # Add padding if requested + if padding > 0: + print(f"Adding {padding}px padding to mask") + mask = add_mask_padding(mask, padding) + + return mask + + # Return empty mask if no result + print("No mask generated!") + return np.zeros(img_array.shape[:2], dtype=np.uint8) + + +def predict_mask_from_bboxes( + image_bytes: bytes, + bboxes: List[float], + model_name: str = "sam2_b.pt", + fill_holes: bool = True, + kernel_size: int = 5, + padding: int = 0, +) -> np.ndarray: + """ + Generate segmentation mask from bounding box prompt + + Args: + image_bytes: Image data as bytes + bboxes: Bounding box as [x1, y1, x2, y2] + model_name: SAM model to use + fill_holes: Whether to fill small holes in the mask + kernel_size: Size of morphological kernel for hole filling + padding: Number of pixels to expand the mask boundary (0 = no padding) + + Returns: + Binary mask as numpy array (H, W) with values 0 or 255 + """ + # Load image + img = Image.open(BytesIO(image_bytes)) + img_array = np.array(img) + + # Get cached model + model = get_sam_model(model_name) + + print(f"Predicting with bounding box: {bboxes}") + + # Run prediction with bboxes prompt + # SAM expects bboxes as a list: [x1, y1, x2, y2] + results = model(img_array, bboxes=bboxes, verbose=False) + + # Extract mask from results + if ( + len(results) > 0 + and hasattr(results[0], "masks") + and results[0].masks is not None + ): + # Get the first mask + mask = results[0].masks.data[0].cpu().numpy() + # Convert to uint8 (0 or 255) + mask = (mask * 255).astype(np.uint8) + print( + f"Generated mask with shape: {mask.shape}, unique values: {np.unique(mask)}, sum: {mask.sum()}" + ) + + # Fill holes if requested + if fill_holes: + print(f"Filling holes in mask with kernel size {kernel_size}") + mask = fill_mask_holes(mask, kernel_size) + + # Add padding if requested + if padding > 0: + print(f"Adding {padding}px padding to mask") + mask = add_mask_padding(mask, padding) + + return mask + + # Return empty mask if no result + print("No mask generated!") + return np.zeros(img_array.shape[:2], dtype=np.uint8) + + +def crop_image_with_mask( + image_bytes: bytes, + mask: np.ndarray, + padding_pixels: int = 0, + target_aspect_ratio: Optional[float] = None, +) -> Tuple[bytes, dict]: + """ + Crop image to mask bounds with padding and aspect ratio adjustment + + The padding is applied around the mask bounds first. If a target aspect ratio + is specified and the padded mask doesn't fit, padding is reduced automatically + to fit within the aspect ratio constraint. + + Args: + image_bytes: Original image data as bytes + mask: Binary mask (0 or 255) + padding_pixels: Padding in pixels around mask bounds + target_aspect_ratio: Target width/height ratio (None = use mask bounds) + + Returns: + Tuple of (cropped_image_bytes, info_dict) + """ + # Load original image + img = Image.open(BytesIO(image_bytes)) + img_array = np.array(img) + + # Find mask bounds + rows = np.any(mask > 0, axis=1) + cols = np.any(mask > 0, axis=0) + + if not np.any(rows) or not np.any(cols): + # No mask, return original image + output_buffer = BytesIO() + img.save(output_buffer, format="JPEG", quality=95) + output_buffer.seek(0) + return output_buffer.getvalue(), { + "original_size": {"width": img.width, "height": img.height}, + "crop_bounds": None, + "message": "No mask found", + } + + row_min, row_max = np.where(rows)[0][[0, -1]] + col_min, col_max = np.where(cols)[0][[0, -1]] + + # Calculate mask dimensions + mask_width = col_max - col_min + 1 + mask_height = row_max - row_min + 1 + + # Calculate mask center + mask_center_x = (col_min + col_max) / 2 + mask_center_y = (row_min + row_max) / 2 + + if target_aspect_ratio is not None: + # Calculate dimensions needed for aspect ratio around mask + padding + # Start with mask dimensions plus requested padding + padded_mask_width = mask_width + 2 * padding_pixels + padded_mask_height = mask_height + 2 * padding_pixels + + # Determine which dimension constrains us for the aspect ratio + required_width_for_height = padded_mask_height * target_aspect_ratio + required_height_for_width = padded_mask_width / target_aspect_ratio + + # Choose the dimension that fits the aspect ratio + if required_width_for_height >= padded_mask_width: + # Height is the constraint - use padded_mask_height and calculate width + final_height = padded_mask_height + final_width = int(final_height * target_aspect_ratio) + else: + # Width is the constraint - use padded_mask_width and calculate height + final_width = padded_mask_width + final_height = int(final_width / target_aspect_ratio) + + # Center the crop around the mask center + crop_x1 = int(mask_center_x - final_width / 2) + crop_y1 = int(mask_center_y - final_height / 2) + crop_x2 = crop_x1 + final_width + crop_y2 = crop_y1 + final_height + + # Clamp to image bounds + if crop_x1 < 0: + crop_x2 = min(crop_x2 - crop_x1, img.width) + crop_x1 = 0 + if crop_y1 < 0: + crop_y2 = min(crop_y2 - crop_y1, img.height) + crop_y1 = 0 + if crop_x2 > img.width: + crop_x1 = max(0, crop_x1 - (crop_x2 - img.width)) + crop_x2 = img.width + if crop_y2 > img.height: + crop_y1 = max(0, crop_y1 - (crop_y2 - img.height)) + crop_y2 = img.height + + actual_padding_x = (crop_x2 - crop_x1 - mask_width) // 2 + actual_padding_y = (crop_y2 - crop_y1 - mask_height) // 2 + else: + # No aspect ratio constraint - just add padding around mask + crop_x1 = max(0, col_min - padding_pixels) + crop_y1 = max(0, row_min - padding_pixels) + crop_x2 = min(img.width, col_max + padding_pixels + 1) + crop_y2 = min(img.height, row_max + padding_pixels + 1) + + actual_padding_x = min(padding_pixels, col_min, img.width - col_max - 1) + actual_padding_y = min(padding_pixels, row_min, img.height - row_max - 1) + + # Crop to final bounds + cropped = img_array[crop_y1:crop_y2, crop_x1:crop_x2] + + # Convert back to PIL Image and save + cropped_img = Image.fromarray(cropped) + output_buffer = BytesIO() + cropped_img.save(output_buffer, format="JPEG", quality=95, optimize=True) + output_buffer.seek(0) + + info = { + "original_size": {"width": img.width, "height": img.height}, + "mask_bounds": { + "x": int(col_min), + "y": int(row_min), + "width": int(mask_width), + "height": int(mask_height), + }, + "crop_bounds": { + "x": int(crop_x1), + "y": int(crop_y1), + "width": int(crop_x2 - crop_x1), + "height": int(crop_y2 - crop_y1), + }, + "final_size": {"width": cropped_img.width, "height": cropped_img.height}, + "padding_applied": { + "horizontal": int(actual_padding_x), + "vertical": int(actual_padding_y), + }, + "requested_padding": padding_pixels, + } + + return output_buffer.getvalue(), info + + +def remove_background( + image_bytes: bytes, + mask: np.ndarray, +) -> bytes: + """ + Remove background using the provided mask + + Args: + image_bytes: Image data as bytes + mask: Binary mask (0 or 255) where 255 = keep, 0 = remove + + Returns: + PNG image bytes with transparent background + """ + # Load image + img = Image.open(BytesIO(image_bytes)) + + print(f"Removing background from image: {img.size}") + print(f"Mask shape: {mask.shape}, unique values: {np.unique(mask)}") + + # Convert image to RGBA if not already + if img.mode != "RGBA": + img = img.convert("RGBA") + + # Get image as array + img_array = np.array(img) + + # Ensure mask is the same size as image + if mask.shape != img_array.shape[:2]: + mask_img = Image.fromarray(mask) + mask_img = mask_img.resize((img.width, img.height), Image.LANCZOS) + mask = np.array(mask_img) + + # Apply mask as alpha channel (255 = keep, 0 = transparent) + img_array[:, :, 3] = mask + + # Convert back to PIL Image + result_img = Image.fromarray(img_array, "RGBA") + + # Save as PNG (supports transparency) + output_buffer = BytesIO() + result_img.save(output_buffer, format="PNG", optimize=True) + output_buffer.seek(0) + + return output_buffer.getvalue() diff --git a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs index eefbbe560..d7f0d8e99 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs @@ -1676,6 +1676,113 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner) RunSegmentationProcessing(g, isBeforeRefiner: false); }, 5); #endregion + #region SAM2 Masking + AddStep(g => + { + if (!g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointCoordsPositive, out string coords) || string.IsNullOrWhiteSpace(coords) || coords == "[]") + { + return; + } + string negCoords = null; + if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointCoordsNegative, out string negCoordsRaw) && !string.IsNullOrWhiteSpace(negCoordsRaw) && negCoordsRaw != "[]") + { + negCoords = negCoordsRaw; + } + JArray imageNodeActual = null; + if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointImage, out Image img)) + { + WGNodeData imageNode = g.LoadImage(img, "${sampointimage}", true); + imageNodeActual = imageNode.Path; + } + else if (g.BasicInputImage is not null) + { + imageNodeActual = g.BasicInputImage.Path; + } + if (imageNodeActual is null) + { + return; + } + string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", new JObject() + { + ["model"] = "sam2_hiera_base_plus.safetensors", + ["segmentor"] = "single_image", + ["device"] = "cuda", + ["precision"] = "bf16" + }); + JObject segInputs = new() + { + ["sam2_model"] = new JArray() { modelNode, 0 }, + ["image"] = imageNodeActual, + ["keep_model_loaded"] = true, + ["coordinates_positive"] = coords, + ["fill_holes"] = true, + ["hole_kernel_size"] = 9, + ["mask_padding"] = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int pointsPadding) ? pointsPadding : 0 + }; + Logs.Debug($"[SAM2-Points] mask_padding from UserInput = {g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0")}"); + if (negCoords is not null) + { + segInputs["coordinates_negative"] = negCoords; + } + string segNode = g.CreateNode("Sam2Segmentation", segInputs); + string maskNode = g.CreateNode("MaskToImage", new JObject() + { + ["mask"] = new JArray() { segNode, 0 } + }); + new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); + g.SkipFurtherSteps = true; + }, 8.9); + AddStep(g => + { + if (!g.UserInput.TryGet(ComfyUIBackendExtension.Sam2BBox, out string bboxJson) || string.IsNullOrWhiteSpace(bboxJson)) + { + return; + } + JArray imageNodeActual = null; + if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointImage, out Image img)) + { + WGNodeData imageNode = g.LoadImage(img, "${sampointimage}", true); + imageNodeActual = imageNode.Path; + } + else if (g.BasicInputImage is not null) + { + imageNodeActual = g.BasicInputImage.Path; + } + if (imageNodeActual is null) + { + return; + } + string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", new JObject() + { + ["model"] = "sam2_hiera_base_plus.safetensors", + ["segmentor"] = "single_image", + ["device"] = "cuda", + ["precision"] = "bf16" + }); + string bboxNode = g.CreateNode("Sam2BBoxFromJson", new JObject() + { + ["bbox_json"] = bboxJson + }); + JObject segInputs = new() + { + ["sam2_model"] = new JArray() { modelNode, 0 }, + ["image"] = imageNodeActual, + ["keep_model_loaded"] = true, + ["bboxes"] = new JArray() { bboxNode, 0 }, + ["fill_holes"] = true, + ["hole_kernel_size"] = 5, + ["mask_padding"] = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int bboxPadding) ? bboxPadding : 0 + }; + Logs.Debug($"[SAM2-BBox] mask_padding from UserInput = {g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0")}"); + string segNode = g.CreateNode("Sam2Segmentation", segInputs); + string maskNode = g.CreateNode("MaskToImage", new JObject() + { + ["mask"] = new JArray() { segNode, 0 } + }); + new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); + g.SkipFurtherSteps = true; + }, 8.85); + #endregion #region SaveImage AddStep(g => { diff --git a/src/wwwroot/imgs/crosshair.png b/src/wwwroot/imgs/crosshair.png new file mode 100644 index 0000000000000000000000000000000000000000..6e9495e045213409a46d1da1fbf95967080cbe57 GIT binary patch literal 233 zcmeAS@N?(olHy`uVBq!ia0vp^3LwnE1|*BCs=fdz#^NA%Cx&(BWL^R}Ea{HEjtmSN z`?>!lvI6;>1s;*b3=DjSL74G){)!Z!V4bInV@QPi+jEY54GIFzf$!@x?{+B4a|APb z`rHrlwmzO{_F&J)-|00C{5Q%Rq}C;tS@B)_aMon+#C!i0FMi7KTQ)`O=Aq&`=FdlL z4C5a#C(WCo{_OA+CRJ&W#D1ccE9H6XN8(0H&P>5>}GG mXcoylV9(r?&Tci>{u~E-?HavcMqShZ0000z literal 0 HcmV?d00001 diff --git a/src/wwwroot/js/genpage/helpers/image_editor.js b/src/wwwroot/js/genpage/helpers/image_editor.js index b583683e5..810981b1c 100644 --- a/src/wwwroot/js/genpage/helpers/image_editor.js +++ b/src/wwwroot/js/genpage/helpers/image_editor.js @@ -95,6 +95,10 @@ class ImageEditorTool { onGlobalMouseUp(e) { return false; } + + onContextMenu(e) { + return false; + } } /** @@ -1234,6 +1238,387 @@ class ImageEditorToolPicker extends ImageEditorTempTool { } } +/** + * The SAM2 Point Segmentation tool - click to place positive/negative points and auto-generate a mask. + */ +class ImageEditorToolSam2Points extends ImageEditorTool { + constructor(editor) { + super(editor, 'sam2points', 'crosshair', 'SAM2 Points', 'Left click to add positive points. Right click to add negative points.\nEach click regenerates the mask.\nRequires SAM2 to be installed.\nHotKey: Y', 'y'); + this.cursor = 'crosshair'; + this.positivePoints = []; + this.negativePoints = []; + this.requestSerial = 0; + this.activeRequestId = 0; + this.maskRequestInFlight = false; + this.pendingMaskUpdate = false; + this.modelWarmed = false; + this.isWarmingUp = false; + this.configDiv.innerHTML = ` +
+ + + + +
`; + this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { + // Clear points + this.positivePoints = []; + this.negativePoints = []; + // Clear mask layer + let maskLayer = this.editor.activeLayer && this.editor.activeLayer.isMask ? this.editor.activeLayer : this.editor.layers.find(layer => layer.isMask); + if (maskLayer) { + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + maskLayer.hasAnyContent = false; + } + this.activeRequestId = ++this.requestSerial; + this.maskRequestInFlight = false; + this.pendingMaskUpdate = false; + this.editor.redraw(); + }); + } + + drawPoint(ctx, x, y, fillColor, showX) { + let [cx, cy] = this.editor.imageCoordToCanvasCoord(x, y); + let radius = Math.max(3, Math.round(4 * this.editor.zoomLevel)); + ctx.save(); + ctx.lineWidth = Math.max(1, Math.round(2 * this.editor.zoomLevel)); + ctx.strokeStyle = '#000000'; + ctx.fillStyle = fillColor; + ctx.beginPath(); + ctx.arc(cx, cy, radius, 0, 2 * Math.PI); + ctx.fill(); + ctx.stroke(); + if (showX) { + let cross = Math.max(3, Math.round(radius * 0.9)); + ctx.beginPath(); + ctx.moveTo(cx - cross, cy - cross); + ctx.lineTo(cx + cross, cy + cross); + ctx.moveTo(cx - cross, cy + cross); + ctx.lineTo(cx + cross, cy - cross); + ctx.stroke(); + } + ctx.restore(); + } + + draw() { + let ctx = this.editor.ctx; + for (let point of this.positivePoints) { + this.drawPoint(ctx, point.x, point.y, '#33ff99', false); + } + for (let point of this.negativePoints) { + this.drawPoint(ctx, point.x, point.y, '#ff3355', true); + } + } + + onContextMenu(e) { + e.preventDefault(); + return true; + } + + setActive() { + super.setActive(); + if (!this.modelWarmed && !this.isWarmingUp && currentBackendFeatureSet.includes('sam2') && this.editor.getFinalImageData?.()) { + this.triggerWarmup(); + } + } + + triggerWarmup() { + this.isWarmingUp = true; + let statusElem = this.configDiv.querySelector('.id-sam2-status'); + if (statusElem) { statusElem.style.display = ''; } + try { + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let cx = Math.floor((this.editor.realWidth || 64) / 2); + let cy = Math.floor((this.editor.realHeight || 64) / 2); + genData['sampositivepoints'] = JSON.stringify([{ x: cx, y: cy }]); + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (data.image || data.error) { + this.modelWarmed = true; + this.isWarmingUp = false; + if (statusElem) { statusElem.style.display = 'none'; } + } + }); + } catch (e) { + this.modelWarmed = true; + this.isWarmingUp = false; + if (statusElem) { statusElem.style.display = 'none'; } + } + } + + onMouseDown(e) { + if (this.isWarmingUp) { return; } + if (e.button !== 0 && e.button !== 2) { + return; + } + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.round(mouseX); + mouseY = Math.round(mouseY); + if (mouseX < 0 || mouseY < 0 || mouseX >= this.editor.realWidth || mouseY >= this.editor.realHeight) { + return; + } + let point = { x: mouseX, y: mouseY }; + if (e.button === 2) { + e.preventDefault(); + this.negativePoints.push(point); + } + else { + this.positivePoints.push(point); + } + this.queueMaskUpdate(); + this.editor.redraw(); + } + + queueMaskUpdate() { + if (!currentBackendFeatureSet.includes('sam2')) { + $('#sam2_installer').modal('show'); + return; + } + if (this.positivePoints.length === 0) { + return; + } + if (this.maskRequestInFlight) { + this.pendingMaskUpdate = true; + return; + } + this.requestMaskUpdate(); + } + + finishMaskUpdate(requestId) { + if (requestId !== this.activeRequestId) { + return; + } + this.maskRequestInFlight = false; + if (this.pendingMaskUpdate) { + this.pendingMaskUpdate = false; + this.requestMaskUpdate(); + } + } + + requestMaskUpdate() { + this.maskRequestInFlight = true; + let requestId = ++this.requestSerial; + this.activeRequestId = requestId; + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + genData['sampositivepoints'] = JSON.stringify(this.positivePoints); + if (this.negativePoints.length > 0) { + genData['samnegativepoints'] = JSON.stringify(this.negativePoints); + } + let maskPadding = parseInt(this.maskPaddingInput.value) || 0; + if (maskPadding > 0) { + genData['sammaskpadding'] = `${maskPadding}`; + } + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (requestId !== this.activeRequestId) { + return; + } + if (!data.image) { + return; + } + let newImg = new Image(); + newImg.onload = () => { + if (requestId !== this.activeRequestId) { + return; + } + this.editor.applyMaskFromImage(newImg, true); + this.finishMaskUpdate(requestId); + }; + newImg.src = data.image; + }); + } +} + +/** + * The SAM2 Bounding Box segmentation tool - drag to define a box and auto-generate a mask. + */ +class ImageEditorToolSam2BBox extends ImageEditorTool { + constructor(editor) { + super(editor, 'sam2bbox', 'rectangle', 'SAM2 BBox', 'Click and drag to create a bounding box. Release to generate mask.\nRequires SAM2 to be installed.', null); + this.cursor = 'crosshair'; + this.bboxStartX = null; + this.bboxStartY = null; + this.bboxEndX = null; + this.bboxEndY = null; + this.isDrawing = false; + this.requestSerial = 0; + this.activeRequestId = 0; + this.maskRequestInFlight = false; + this.modelWarmed = false; + this.isWarmingUp = false; + this.configDiv.innerHTML = ` +
+ + + + +
`; + this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { + let maskLayer = this.editor.activeLayer && this.editor.activeLayer.isMask ? this.editor.activeLayer : this.editor.layers.find(layer => layer.isMask); + if (!maskLayer) { + return; + } + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + maskLayer.hasAnyContent = false; + this.editor.redraw(); + }); + } + + draw() { + if (this.isDrawing && this.bboxStartX !== null && this.bboxEndX !== null) { + let ctx = this.editor.ctx; + let [x1, y1] = this.editor.imageCoordToCanvasCoord(this.bboxStartX, this.bboxStartY); + let [x2, y2] = this.editor.imageCoordToCanvasCoord(this.bboxEndX, this.bboxEndY); + let minX = Math.min(x1, x2); + let minY = Math.min(y1, y2); + let maxX = Math.max(x1, x2); + let maxY = Math.max(y1, y2); + ctx.save(); + ctx.strokeStyle = '#33ff99'; + ctx.lineWidth = 2; + ctx.setLineDash([5, 5]); + ctx.strokeRect(minX, minY, maxX - minX, maxY - minY); + ctx.restore(); + } + } + + setActive() { + super.setActive(); + if (!this.modelWarmed && !this.isWarmingUp && currentBackendFeatureSet.includes('sam2') && this.editor.getFinalImageData?.()) { + this.triggerWarmup(); + } + } + + triggerWarmup() { + this.isWarmingUp = true; + let statusElem = this.configDiv.querySelector('.id-sam2-status'); + if (statusElem) { statusElem.style.display = ''; } + try { + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let cx = Math.floor((this.editor.realWidth || 64) / 2); + let cy = Math.floor((this.editor.realHeight || 64) / 2); + genData['sambbox'] = JSON.stringify([cx - 1, cy - 1, cx + 1, cy + 1]); + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (data.image || data.error) { + this.modelWarmed = true; + this.isWarmingUp = false; + if (statusElem) { statusElem.style.display = 'none'; } + } + }); + } catch (e) { + this.modelWarmed = true; + this.isWarmingUp = false; + if (statusElem) { statusElem.style.display = 'none'; } + } + } + + onMouseDown(e) { + if (this.isWarmingUp) { return; } + if (e.button !== 0) { + return; + } + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.round(mouseX); + mouseY = Math.round(mouseY); + if (mouseX < 0 || mouseY < 0 || mouseX >= this.editor.realWidth || mouseY >= this.editor.realHeight) { + return; + } + this.isDrawing = true; + this.bboxStartX = mouseX; + this.bboxStartY = mouseY; + this.bboxEndX = mouseX; + this.bboxEndY = mouseY; + } + + onMouseMove(e) { + if (this.isDrawing) { + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.max(0, Math.min(this.editor.realWidth - 1, Math.round(mouseX))); + mouseY = Math.max(0, Math.min(this.editor.realHeight - 1, Math.round(mouseY))); + this.bboxEndX = mouseX; + this.bboxEndY = mouseY; + this.editor.redraw(); + } + } + + onMouseUp(e) { + if (this.isWarmingUp) { return; } + if (this.isDrawing) { + this.isDrawing = false; + this.requestMaskUpdate(); + } + } + + requestMaskUpdate() { + if (!currentBackendFeatureSet.includes('sam2')) { + $('#sam2_installer').modal('show'); + return; + } + if (this.bboxStartX === null || this.bboxEndX === null) { + return; + } + this.maskRequestInFlight = true; + let requestId = ++this.requestSerial; + this.activeRequestId = requestId; + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let minX = Math.min(this.bboxStartX, this.bboxEndX); + let minY = Math.min(this.bboxStartY, this.bboxEndY); + let maxX = Math.max(this.bboxStartX, this.bboxEndX); + let maxY = Math.max(this.bboxStartY, this.bboxEndY); + genData['sambbox'] = JSON.stringify([minX, minY, maxX, maxY]); + let maskPadding = parseInt(this.maskPaddingInput.value) || 0; + if (maskPadding > 0) { + genData['sammaskpadding'] = `${maskPadding}`; + } + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (requestId !== this.activeRequestId) { + return; + } + if (!data.image) { + return; + } + let newImg = new Image(); + newImg.onload = () => { + if (requestId !== this.activeRequestId) { + return; + } + this.editor.applyMaskFromImage(newImg, true); + this.maskRequestInFlight = false; + }; + newImg.src = data.image; + }); + } +} + /** * A single layer within an image editing interface. * This can be real (user-controlled) OR sub-layers (sometimes user-controlled) OR temporary buffers. @@ -1595,6 +1980,8 @@ class ImageEditor { this.addTool(new ImageEditorToolShape(this)); this.pickerTool = new ImageEditorToolPicker(this, 'picker', 'paintbrush', 'Color Picker', 'Pick a color from the image.'); this.addTool(this.pickerTool); + this.addTool(new ImageEditorToolSam2Points(this)); + this.addTool(new ImageEditorToolSam2BBox(this)); this.activateTool('brush'); this.maxHistory = 15; } @@ -1681,6 +2068,11 @@ class ImageEditor { e.stopPropagation(); }); canvas.addEventListener('drop', (e) => this.handleCanvasImageDrop(e)); + canvas.addEventListener('contextmenu', (e) => { + if (this.activeTool && this.activeTool.onContextMenu(e)) { + e.preventDefault(); + } + }); this.ctx = canvas.getContext('2d'); canvas.style.cursor = 'none'; this.maskHelperCanvas = document.createElement('canvas'); @@ -2012,6 +2404,16 @@ class ImageEditor { this.addLayer(maskLayer); this.realWidth = img.naturalWidth; this.realHeight = img.naturalHeight; + if (this.tools['sam2points']) { + this.tools['sam2points'].positivePoints = []; + this.tools['sam2points'].negativePoints = []; + } + if (this.tools['sam2bbox']) { + this.tools['sam2bbox'].bboxStartX = null; + this.tools['sam2bbox'].bboxStartY = null; + this.tools['sam2bbox'].bboxEndX = null; + this.tools['sam2bbox'].bboxEndY = null; + } this.offsetX = 0 this.offsetY = 0; if (this.active) { @@ -2252,6 +2654,34 @@ class ImageEditor { return canvas.toDataURL(format); } + applyMaskFromImage(img, replaceExisting = true) { + let maskLayer = this.activeLayer && this.activeLayer.isMask ? this.activeLayer : this.layers.find(layer => layer.isMask); + if (!maskLayer) { + maskLayer = new ImageEditorLayer(this, img.naturalWidth || img.width, img.naturalHeight || img.height); + maskLayer.isMask = true; + this.addLayer(maskLayer); + } + if (replaceExisting) { + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + } + maskLayer.ctx.drawImage(img, 0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + // Convert black pixels to transparent so only the white mask region is visible + let imageData = maskLayer.ctx.getImageData(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + let data = imageData.data; + for (let i = 0; i < data.length; i += 4) { + let brightness = data[i] + data[i + 1] + data[i + 2]; + if (brightness < 128) { + data[i + 3] = 0; + } + } + maskLayer.ctx.putImageData(imageData, 0, 0); + maskLayer.hasAnyContent = true; + this.setActiveLayer(maskLayer); + this.sortLayers(); + this.redraw(); + } + getFinalMaskData(format = 'image/png') { let canvas = document.createElement('canvas'); canvas.width = this.realWidth; From 619d797480099bd204e693b43075ea941c811bb8 Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 16:25:38 -0600 Subject: [PATCH 2/6] Refactor SAM2 tool UI to improve warmup status display and control handling --- .../js/genpage/helpers/image_editor.js | 48 ++++++++++++++----- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/src/wwwroot/js/genpage/helpers/image_editor.js b/src/wwwroot/js/genpage/helpers/image_editor.js index 810981b1c..7da245216 100644 --- a/src/wwwroot/js/genpage/helpers/image_editor.js +++ b/src/wwwroot/js/genpage/helpers/image_editor.js @@ -1253,14 +1253,21 @@ class ImageEditorToolSam2Points extends ImageEditorTool { this.pendingMaskUpdate = false; this.modelWarmed = false; this.isWarmingUp = false; - this.configDiv.innerHTML = ` + this.controlsHTML = `
-
`; + this.warmupHTML = `
Warming up SAM2 model...
`; + this.showControls(); + } + + showControls() { + let prevPadding = this.maskPaddingInput ? this.maskPaddingInput.value : '0'; + this.configDiv.innerHTML = this.controlsHTML; this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.maskPaddingInput.value = prevPadding; this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { // Clear points this.positivePoints = []; @@ -1326,8 +1333,9 @@ class ImageEditorToolSam2Points extends ImageEditorTool { triggerWarmup() { this.isWarmingUp = true; - let statusElem = this.configDiv.querySelector('.id-sam2-status'); - if (statusElem) { statusElem.style.display = ''; } + this.cursor = 'wait'; + this.editor.canvas.style.cursor = 'wait'; + this.configDiv.innerHTML = this.warmupHTML; try { let img = this.editor.getFinalImageData(); let genData = getGenInput(); @@ -1343,13 +1351,17 @@ class ImageEditorToolSam2Points extends ImageEditorTool { if (data.image || data.error) { this.modelWarmed = true; this.isWarmingUp = false; - if (statusElem) { statusElem.style.display = 'none'; } + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); } }); } catch (e) { this.modelWarmed = true; this.isWarmingUp = false; - if (statusElem) { statusElem.style.display = 'none'; } + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); } } @@ -1459,14 +1471,21 @@ class ImageEditorToolSam2BBox extends ImageEditorTool { this.maskRequestInFlight = false; this.modelWarmed = false; this.isWarmingUp = false; - this.configDiv.innerHTML = ` + this.controlsHTML = `
-
`; + this.warmupHTML = `
Warming up SAM2 model...
`; + this.showControls(); + } + + showControls() { + let prevPadding = this.maskPaddingInput ? this.maskPaddingInput.value : '0'; + this.configDiv.innerHTML = this.controlsHTML; this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.maskPaddingInput.value = prevPadding; this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { let maskLayer = this.editor.activeLayer && this.editor.activeLayer.isMask ? this.editor.activeLayer : this.editor.layers.find(layer => layer.isMask); if (!maskLayer) { @@ -1506,8 +1525,9 @@ class ImageEditorToolSam2BBox extends ImageEditorTool { triggerWarmup() { this.isWarmingUp = true; - let statusElem = this.configDiv.querySelector('.id-sam2-status'); - if (statusElem) { statusElem.style.display = ''; } + this.cursor = 'wait'; + this.editor.canvas.style.cursor = 'wait'; + this.configDiv.innerHTML = this.warmupHTML; try { let img = this.editor.getFinalImageData(); let genData = getGenInput(); @@ -1523,13 +1543,17 @@ class ImageEditorToolSam2BBox extends ImageEditorTool { if (data.image || data.error) { this.modelWarmed = true; this.isWarmingUp = false; - if (statusElem) { statusElem.style.display = 'none'; } + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); } }); } catch (e) { this.modelWarmed = true; this.isWarmingUp = false; - if (statusElem) { statusElem.style.display = 'none'; } + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); } } From ddfac2396dd888aa307d960aa83eeef4387e4a7d Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 16:30:34 -0600 Subject: [PATCH 3/6] Consolidate SAM2 point coordinate parameters into a single declaration --- .../ComfyUIBackend/ComfyUIBackendExtension.cs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs index 341f17e4d..875c8f6c2 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs @@ -653,9 +653,7 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) public static T2IRegisteredParam Sam2PointImage; - public static T2IRegisteredParam Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox; - - public static T2IRegisteredParam Sam2MaskPadding; + public static T2IRegisteredParam Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox, Sam2MaskPadding; /// public override void OnInit() From cf8f4aec5fc2099e2739dd2d5c223768d603a937 Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 16:34:32 -0600 Subject: [PATCH 4/6] Remove unused SAM2 mask processing classes and functions --- .../ExtraNodes/Sam2BBoxNode/__init__.py | 59 ------ .../ExtraNodes/Sam2BBoxNode/sam2.py | 185 +----------------- 2 files changed, 1 insertion(+), 243 deletions(-) diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py index de629bfb6..7951b420c 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py @@ -7,8 +7,6 @@ from .sam2 import ( predict_mask_from_points, predict_mask_from_bboxes, - crop_image_with_mask, - remove_background, ) @@ -162,71 +160,14 @@ def segment(self, image, bbox_json, model_name, fill_holes=True, hole_kernel_siz return (mask_tensor.unsqueeze(0),) -class Sam2MaskPostProcess: - """Post-process SAM2 masks: crop to bounds, remove background, etc.""" - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "mask": ("MASK",), - }, - "optional": { - "operation": (["crop_to_mask", "remove_background"], {"default": "crop_to_mask"}), - "padding_pixels": ("INT", {"default": 0, "min": 0, "max": 100}), - "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 10.0, "step": 0.1}), - } - } - - RETURN_TYPES = ("IMAGE", "STRING") - RETURN_NAMES = ("result", "info") - FUNCTION = "process" - CATEGORY = "SAM2" - - def process(self, image, mask, operation, padding_pixels=0, aspect_ratio=1.0): - """Post-process masks: crop or remove background""" - # Convert image to bytes - img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) - img_buffer = BytesIO() - img_pil.save(img_buffer, format="JPEG") - image_bytes = img_buffer.getvalue() - - # Convert mask to numpy (0 or 255) - mask_np = (mask[0].cpu().numpy() * 255).astype(np.uint8) - - if operation == "crop_to_mask": - target_ratio = aspect_ratio if aspect_ratio != 1.0 else None - result_bytes, info_dict = crop_image_with_mask( - image_bytes, - mask_np, - padding_pixels=padding_pixels, - target_aspect_ratio=target_ratio, - ) - result_img = Image.open(BytesIO(result_bytes)) - else: # remove_background - result_bytes = remove_background(image_bytes, mask_np) - result_img = Image.open(BytesIO(result_bytes)).convert("RGB") - info_dict = {"operation": "background_removed"} - - # Convert back to tensor - result_tensor = torch.from_numpy(np.array(result_img)).float() / 255.0 - if result_tensor.ndim == 2: - result_tensor = result_tensor.unsqueeze(-1).repeat(1, 1, 3) - - return (result_tensor.unsqueeze(0), json.dumps(info_dict)) - - NODE_CLASS_MAPPINGS = { "Sam2BBoxFromJson": Sam2BBoxFromJson, "Sam2PointSegmentation": Sam2PointSegmentation, "Sam2BBoxSegmentation": Sam2BBoxSegmentation, - "Sam2MaskPostProcess": Sam2MaskPostProcess, } NODE_DISPLAY_NAME_MAPPINGS = { "Sam2BBoxFromJson": "SAM2 BBox From JSON", "Sam2PointSegmentation": "SAM2 Point Segmentation", "Sam2BBoxSegmentation": "SAM2 BBox Segmentation", - "Sam2MaskPostProcess": "SAM2 Mask Post-Process", } diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py index 27de71ab7..b9639b861 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image from io import BytesIO -from typing import List, Tuple, Optional +from typing import List, Tuple from pathlib import Path from ultralytics import SAM import torch @@ -313,186 +313,3 @@ def predict_mask_from_bboxes( return np.zeros(img_array.shape[:2], dtype=np.uint8) -def crop_image_with_mask( - image_bytes: bytes, - mask: np.ndarray, - padding_pixels: int = 0, - target_aspect_ratio: Optional[float] = None, -) -> Tuple[bytes, dict]: - """ - Crop image to mask bounds with padding and aspect ratio adjustment - - The padding is applied around the mask bounds first. If a target aspect ratio - is specified and the padded mask doesn't fit, padding is reduced automatically - to fit within the aspect ratio constraint. - - Args: - image_bytes: Original image data as bytes - mask: Binary mask (0 or 255) - padding_pixels: Padding in pixels around mask bounds - target_aspect_ratio: Target width/height ratio (None = use mask bounds) - - Returns: - Tuple of (cropped_image_bytes, info_dict) - """ - # Load original image - img = Image.open(BytesIO(image_bytes)) - img_array = np.array(img) - - # Find mask bounds - rows = np.any(mask > 0, axis=1) - cols = np.any(mask > 0, axis=0) - - if not np.any(rows) or not np.any(cols): - # No mask, return original image - output_buffer = BytesIO() - img.save(output_buffer, format="JPEG", quality=95) - output_buffer.seek(0) - return output_buffer.getvalue(), { - "original_size": {"width": img.width, "height": img.height}, - "crop_bounds": None, - "message": "No mask found", - } - - row_min, row_max = np.where(rows)[0][[0, -1]] - col_min, col_max = np.where(cols)[0][[0, -1]] - - # Calculate mask dimensions - mask_width = col_max - col_min + 1 - mask_height = row_max - row_min + 1 - - # Calculate mask center - mask_center_x = (col_min + col_max) / 2 - mask_center_y = (row_min + row_max) / 2 - - if target_aspect_ratio is not None: - # Calculate dimensions needed for aspect ratio around mask + padding - # Start with mask dimensions plus requested padding - padded_mask_width = mask_width + 2 * padding_pixels - padded_mask_height = mask_height + 2 * padding_pixels - - # Determine which dimension constrains us for the aspect ratio - required_width_for_height = padded_mask_height * target_aspect_ratio - required_height_for_width = padded_mask_width / target_aspect_ratio - - # Choose the dimension that fits the aspect ratio - if required_width_for_height >= padded_mask_width: - # Height is the constraint - use padded_mask_height and calculate width - final_height = padded_mask_height - final_width = int(final_height * target_aspect_ratio) - else: - # Width is the constraint - use padded_mask_width and calculate height - final_width = padded_mask_width - final_height = int(final_width / target_aspect_ratio) - - # Center the crop around the mask center - crop_x1 = int(mask_center_x - final_width / 2) - crop_y1 = int(mask_center_y - final_height / 2) - crop_x2 = crop_x1 + final_width - crop_y2 = crop_y1 + final_height - - # Clamp to image bounds - if crop_x1 < 0: - crop_x2 = min(crop_x2 - crop_x1, img.width) - crop_x1 = 0 - if crop_y1 < 0: - crop_y2 = min(crop_y2 - crop_y1, img.height) - crop_y1 = 0 - if crop_x2 > img.width: - crop_x1 = max(0, crop_x1 - (crop_x2 - img.width)) - crop_x2 = img.width - if crop_y2 > img.height: - crop_y1 = max(0, crop_y1 - (crop_y2 - img.height)) - crop_y2 = img.height - - actual_padding_x = (crop_x2 - crop_x1 - mask_width) // 2 - actual_padding_y = (crop_y2 - crop_y1 - mask_height) // 2 - else: - # No aspect ratio constraint - just add padding around mask - crop_x1 = max(0, col_min - padding_pixels) - crop_y1 = max(0, row_min - padding_pixels) - crop_x2 = min(img.width, col_max + padding_pixels + 1) - crop_y2 = min(img.height, row_max + padding_pixels + 1) - - actual_padding_x = min(padding_pixels, col_min, img.width - col_max - 1) - actual_padding_y = min(padding_pixels, row_min, img.height - row_max - 1) - - # Crop to final bounds - cropped = img_array[crop_y1:crop_y2, crop_x1:crop_x2] - - # Convert back to PIL Image and save - cropped_img = Image.fromarray(cropped) - output_buffer = BytesIO() - cropped_img.save(output_buffer, format="JPEG", quality=95, optimize=True) - output_buffer.seek(0) - - info = { - "original_size": {"width": img.width, "height": img.height}, - "mask_bounds": { - "x": int(col_min), - "y": int(row_min), - "width": int(mask_width), - "height": int(mask_height), - }, - "crop_bounds": { - "x": int(crop_x1), - "y": int(crop_y1), - "width": int(crop_x2 - crop_x1), - "height": int(crop_y2 - crop_y1), - }, - "final_size": {"width": cropped_img.width, "height": cropped_img.height}, - "padding_applied": { - "horizontal": int(actual_padding_x), - "vertical": int(actual_padding_y), - }, - "requested_padding": padding_pixels, - } - - return output_buffer.getvalue(), info - - -def remove_background( - image_bytes: bytes, - mask: np.ndarray, -) -> bytes: - """ - Remove background using the provided mask - - Args: - image_bytes: Image data as bytes - mask: Binary mask (0 or 255) where 255 = keep, 0 = remove - - Returns: - PNG image bytes with transparent background - """ - # Load image - img = Image.open(BytesIO(image_bytes)) - - print(f"Removing background from image: {img.size}") - print(f"Mask shape: {mask.shape}, unique values: {np.unique(mask)}") - - # Convert image to RGBA if not already - if img.mode != "RGBA": - img = img.convert("RGBA") - - # Get image as array - img_array = np.array(img) - - # Ensure mask is the same size as image - if mask.shape != img_array.shape[:2]: - mask_img = Image.fromarray(mask) - mask_img = mask_img.resize((img.width, img.height), Image.LANCZOS) - mask = np.array(mask_img) - - # Apply mask as alpha channel (255 = keep, 0 = transparent) - img_array[:, :, 3] = mask - - # Convert back to PIL Image - result_img = Image.fromarray(img_array, "RGBA") - - # Save as PNG (supports transparency) - output_buffer = BytesIO() - result_img.save(output_buffer, format="PNG", optimize=True) - output_buffer.seek(0) - - return output_buffer.getvalue() From 0ae5d51b38ce355238833cc22551b79987147f65 Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 19:22:11 -0600 Subject: [PATCH 5/6] Refactor SAM2 model input handling to use a dedicated download method for improved maintainability --- .../ComfyUIBackend/ComfyUIBackendExtension.cs | 20 ++++++++++++------- .../ComfyUIBackend/WorkflowGeneratorSteps.cs | 16 ++------------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs index 875c8f6c2..2c68a4b2b 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs @@ -556,13 +556,7 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) new JObject() { ["class_type"] = "DownloadAndLoadSAM2Model", - ["inputs"] = new JObject() - { - ["model"] = $"sam2_hiera_{size}.safetensors", - ["segmentor"] = "automaskgenerator", - ["device"] = "cuda", // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not - ["precision"] = "bf16" - } + ["inputs"] = Sam2ModelInputs(size, "automaskgenerator") }, new JObject() { @@ -655,6 +649,18 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) public static T2IRegisteredParam Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox, Sam2MaskPadding; + /// Creates the standard input set for a DownloadAndLoadSAM2Model node. + public static JObject Sam2ModelInputs(string size = "base_plus", string segmentor = "single_image") + { + return new JObject() + { + ["model"] = $"sam2_hiera_{size}.safetensors", + ["segmentor"] = segmentor, + ["device"] = "cuda", // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not + ["precision"] = "bf16" + }; + } + /// public override void OnInit() { diff --git a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs index d7f0d8e99..5fa83b3fb 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs @@ -1702,13 +1702,7 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner) { return; } - string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", new JObject() - { - ["model"] = "sam2_hiera_base_plus.safetensors", - ["segmentor"] = "single_image", - ["device"] = "cuda", - ["precision"] = "bf16" - }); + string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", ComfyUIBackendExtension.Sam2ModelInputs()); JObject segInputs = new() { ["sam2_model"] = new JArray() { modelNode, 0 }, @@ -1752,13 +1746,7 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner) { return; } - string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", new JObject() - { - ["model"] = "sam2_hiera_base_plus.safetensors", - ["segmentor"] = "single_image", - ["device"] = "cuda", - ["precision"] = "bf16" - }); + string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", ComfyUIBackendExtension.Sam2ModelInputs()); string bboxNode = g.CreateNode("Sam2BBoxFromJson", new JObject() { ["bbox_json"] = bboxJson From bbdc0aed1a4e287e93062ab166a08e27db4c381e Mon Sep 17 00:00:00 2001 From: Glen Date: Sun, 8 Mar 2026 19:57:30 -0600 Subject: [PATCH 6/6] remove old SAM2 script and replace with SAM2 postprocess custom node --- .../ExtraNodes/Sam2BBoxNode/__init__.py | 141 -------- .../ExtraNodes/Sam2BBoxNode/sam2.py | 315 ------------------ .../SwarmSam2MaskPostProcess/__init__.py | 77 +++++ .../ComfyUIBackend/WorkflowGeneratorSteps.cs | 32 +- 4 files changed, 97 insertions(+), 468 deletions(-) delete mode 100644 src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py create mode 100644 src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmSam2MaskPostProcess/__init__.py diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py index 7951b420c..b8801bca8 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/__init__.py @@ -1,13 +1,4 @@ import json -import torch -import numpy as np -from PIL import Image -from io import BytesIO - -from .sam2 import ( - predict_mask_from_points, - predict_mask_from_bboxes, -) class Sam2BBoxFromJson: @@ -29,145 +20,13 @@ def INPUT_TYPES(s): def convert(self, bbox_json): coords = json.loads(bbox_json) - # coords = [x1, y1, x2, y2] - # BBOX type is a list of [x1, y1, x2, y2] lists (one per box) return ([[float(coords[0]), float(coords[1]), float(coords[2]), float(coords[3])]],) -class Sam2PointSegmentation: - """SAM2 segmentation using point prompts with hole-filling""" - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "points_json": ("STRING", {"multiline": True}), - "model_name": (["sam2_b.pt", "sam2_l.pt", "sam2_s.pt", "sam2_t.pt"], {"default": "sam2_b.pt"}), - }, - "optional": { - "fill_holes": ("BOOLEAN", {"default": True}), - "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}), - "mask_padding": ("INT", {"default": 0, "min": 0, "max": 50}), - } - } - - RETURN_TYPES = ("MASK",) - RETURN_NAMES = ("mask",) - FUNCTION = "segment" - CATEGORY = "SAM2" - - def segment(self, image, points_json, model_name, fill_holes=True, hole_kernel_size=5, mask_padding=0): - """ - Generate segmentation mask from point prompts - - Points JSON format: [ - {"x": 100, "y": 200, "label": 1}, - {"x": 150, "y": 250, "label": 0} - ] - where label: 1 = foreground, 0 = background - """ - # Parse points from JSON - try: - points_data = json.loads(points_json) - if not isinstance(points_data, list): - raise ValueError("Points must be a JSON array") - - points = [(p["x"], p["y"]) for p in points_data] - labels = [p.get("label", 1) for p in points_data] - except Exception as e: - raise ValueError(f"Invalid points JSON: {e}") - - # Convert image to bytes - img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) - img_buffer = BytesIO() - img_pil.save(img_buffer, format="JPEG") - image_bytes = img_buffer.getvalue() - - # Generate mask using helper function - mask = predict_mask_from_points( - image_bytes, - points, - labels, - model_name=model_name, - fill_holes=fill_holes, - kernel_size=hole_kernel_size, - padding=mask_padding, - ) - - # Convert to tensor format (B, H, W) - mask_tensor = torch.from_numpy(mask).float() / 255.0 - return (mask_tensor.unsqueeze(0),) - - -class Sam2BBoxSegmentation: - """SAM2 segmentation using bounding box prompts with hole-filling""" - - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "bbox_json": ("STRING", {"multiline": True}), - "model_name": (["sam2_b.pt", "sam2_l.pt", "sam2_s.pt", "sam2_t.pt"], {"default": "sam2_b.pt"}), - }, - "optional": { - "fill_holes": ("BOOLEAN", {"default": True}), - "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}), - "mask_padding": ("INT", {"default": 0, "min": 0, "max": 50}), - } - } - - RETURN_TYPES = ("MASK",) - RETURN_NAMES = ("mask",) - FUNCTION = "segment" - CATEGORY = "SAM2" - - def segment(self, image, bbox_json, model_name, fill_holes=True, hole_kernel_size=5, mask_padding=0): - """ - Generate segmentation mask from bounding box prompt - - BBox JSON format: [x1, y1, x2, y2] or [[x1, y1, x2, y2], ...] - """ - # Parse bbox from JSON - try: - bbox_data = json.loads(bbox_json) - if isinstance(bbox_data[0], (list, tuple)): - bboxes = bbox_data[0] # Multiple boxes - else: - bboxes = bbox_data # Single box - except Exception as e: - raise ValueError(f"Invalid bbox JSON: {e}") - - # Convert image to bytes - img_pil = Image.fromarray((image[0].cpu().numpy() * 255).astype(np.uint8)) - img_buffer = BytesIO() - img_pil.save(img_buffer, format="JPEG") - image_bytes = img_buffer.getvalue() - - # Generate mask using helper function - mask = predict_mask_from_bboxes( - image_bytes, - bboxes, - model_name=model_name, - fill_holes=fill_holes, - kernel_size=hole_kernel_size, - padding=mask_padding, - ) - - # Convert to tensor format (B, H, W) - mask_tensor = torch.from_numpy(mask).float() / 255.0 - return (mask_tensor.unsqueeze(0),) - - NODE_CLASS_MAPPINGS = { "Sam2BBoxFromJson": Sam2BBoxFromJson, - "Sam2PointSegmentation": Sam2PointSegmentation, - "Sam2BBoxSegmentation": Sam2BBoxSegmentation, } NODE_DISPLAY_NAME_MAPPINGS = { "Sam2BBoxFromJson": "SAM2 BBox From JSON", - "Sam2PointSegmentation": "SAM2 Point Segmentation", - "Sam2BBoxSegmentation": "SAM2 BBox Segmentation", } diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py deleted file mode 100644 index b9639b861..000000000 --- a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/Sam2BBoxNode/sam2.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -SAM 2 segmentation utilities -Provides model loading and inference for interactive segmentation -""" - -import numpy as np -from PIL import Image -from io import BytesIO -from typing import List, Tuple -from pathlib import Path -from ultralytics import SAM -import torch -import cv2 - -# Global model cache -_model_cache = {} - -# Models directory -MODELS_DIR = Path(__file__).parent.parent.parent / "models" -MODELS_DIR.mkdir(exist_ok=True) - -# Detect GPU availability -print(f"PyTorch version: {torch.__version__}") -print(f"CUDA available: {torch.cuda.is_available()}") -print( - f"CUDA version (compiled): {torch.version.cuda if torch.version.cuda else 'None'}" -) -print( - f"cuDNN version: {torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 'None'}" -) - -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" -print(f"SAM 2 will use device: {DEVICE}") -if DEVICE == "cuda": - print(f"GPU: {torch.cuda.get_device_name(0)}") - print( - f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB" - ) -else: - print("⚠️ Running on CPU - segmentation will be slower") - print( - "To enable GPU: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121" - ) - - -def get_sam_model(model_name: str = "sam2_b.pt") -> SAM: - """ - Load and cache SAM 2 model with GPU support - - Args: - model_name: Model name (sam2_b.pt, sam2_l.pt, sam2_s.pt, sam2_t.pt) - - Returns: - SAM model instance - """ - if model_name not in _model_cache: - model_path = MODELS_DIR / model_name - - # If model doesn't exist locally, ultralytics will auto-download it - # Just use the model name without .pt extension for auto-download - if not model_path.exists(): - print(f"Model not found at {model_path}") - print(f"Ultralytics will auto-download {model_name} to cache...") - # Use just the model name (e.g., 'sam2_b.pt') - ultralytics handles download - model = SAM(model_name) - else: - print(f"Loading SAM model from: {model_path} on {DEVICE}") - model = SAM(str(model_path)) - - # Move model to GPU if available - if DEVICE == "cuda": - model.to(DEVICE) - _model_cache[model_name] = model - print(f"SAM model loaded: {model_name} on {DEVICE}") - - return _model_cache[model_name] - - -def clear_model_cache(): - """Clear the model cache to force reload""" - global _model_cache - _model_cache = {} - - -def fill_mask_holes(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray: - """ - Fill small holes in a binary mask using morphological operations - - Args: - mask: Binary mask as numpy array (H, W) or (H, W, C) with values 0 or 255 - kernel_size: Size of the morphological kernel (larger = fills bigger holes) - - Returns: - Cleaned mask with holes filled - """ - # Squeeze to remove any single-dimensional entries and ensure 2D - mask = np.squeeze(mask) - if mask.ndim == 0: - # Scalar case - shouldn't happen but handle it - return np.array([[255]], dtype=np.uint8) - if mask.ndim > 2: - # If still 3D+, take first channel - mask = mask[:, :, 0] - - # Ensure mask is uint8 with values 0 or 255 - if mask.dtype != np.uint8: - if mask.dtype == bool or (mask.max() <= 1 and mask.dtype in [np.float32, np.float64]): - mask = (mask * 255).astype(np.uint8) - else: - mask = mask.astype(np.uint8) - - # Create a kernel for morphological operations - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - - # Morphological closing: dilation followed by erosion - # This fills small holes while preserving the outer boundary - closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) - - # Optional: Fill any remaining interior holes using flood fill - # This catches larger holes that closing might miss - filled_mask = closed_mask.copy() - h, w = filled_mask.shape - - # Create a slightly larger canvas for flood fill - canvas = np.zeros((h + 2, w + 2), dtype=np.uint8) - canvas[1:-1, 1:-1] = filled_mask - - # Flood fill from the border to mark the background - cv2.floodFill(canvas, None, (0, 0), 128) - - # Extract the filled region (anything not marked as background is foreground) - filled_mask = np.where(canvas[1:-1, 1:-1] == 128, 0, 255).astype(np.uint8) - - return filled_mask - - -def add_mask_padding(mask: np.ndarray, padding: int = 10) -> np.ndarray: - """ - Add padding to a mask by dilating it outward - - This expands the mask boundary, which is useful for inpainting to: - - Include border pixels for better blending - - Provide more context around the masked area - - Avoid hard edges at the mask boundary - - Args: - mask: Binary mask as numpy array (H, W) with values 0 or 255 - padding: Number of pixels to expand the mask (kernel size for dilation) - - Returns: - Dilated mask with padding added - """ - if padding <= 0: - return mask - - # Create a circular/elliptical kernel for smooth expansion - kernel = cv2.getStructuringElement( - cv2.MORPH_ELLIPSE, (padding * 2 + 1, padding * 2 + 1) - ) - - # Dilate the mask to expand it outward - padded_mask = cv2.dilate(mask, kernel, iterations=1) - - return padded_mask - - -def predict_mask_from_points( - image_bytes: bytes, - points: List[Tuple[int, int]], - labels: List[int], - model_name: str = "sam2_b.pt", - fill_holes: bool = True, - kernel_size: int = 5, - padding: int = 0, -) -> np.ndarray: - """ - Generate segmentation mask from point prompts - - Args: - image_bytes: Image data as bytes - points: List of (x, y) coordinates for prompts - labels: List of labels (1 for foreground, 0 for background) - model_name: SAM model to use - fill_holes: Whether to fill small holes in the mask - kernel_size: Size of morphological kernel for hole filling - padding: Number of pixels to expand the mask boundary (0 = no padding) - - Returns: - Binary mask as numpy array (H, W) with values 0 or 255 - """ - # Load image - img = Image.open(BytesIO(image_bytes)) - img_array = np.array(img) - - # Get cached model (now that wrapping is fixed, caching works properly) - model = get_sam_model(model_name) - - # Convert points and labels to numpy arrays - points_array = np.array(points, dtype=np.float32) - labels_array = np.array(labels, dtype=np.int32) - - print(f"Predicting with {len(points)} points: {points_array}") - print(f"Labels: {labels_array}") - - # IMPORTANT: Wrap points and labels in an extra list dimension - # This tells SAM that all points belong to ONE object - # points=[[[x1,y1], [x2,y2]]] instead of [[x1,y1], [x2,y2]] - points_wrapped = [points_array.tolist()] - labels_wrapped = [labels_array.tolist()] - - print(f"Wrapped points: {points_wrapped}") - print(f"Wrapped labels: {labels_wrapped}") - - # Run prediction with point prompts - results = model( - img_array, points=points_wrapped, labels=labels_wrapped, verbose=False - ) - - # Extract mask from results - if ( - len(results) > 0 - and hasattr(results[0], "masks") - and results[0].masks is not None - ): - # Get the first mask - mask = results[0].masks.data[0].cpu().numpy() - # Convert to uint8 (0 or 255) - mask = (mask * 255).astype(np.uint8) - print( - f"Generated mask with shape: {mask.shape}, unique values: {np.unique(mask)}, sum: {mask.sum()}" - ) - - # Fill holes if requested - if fill_holes: - print(f"Filling holes in mask with kernel size {kernel_size}") - mask = fill_mask_holes(mask, kernel_size) - - # Add padding if requested - if padding > 0: - print(f"Adding {padding}px padding to mask") - mask = add_mask_padding(mask, padding) - - return mask - - # Return empty mask if no result - print("No mask generated!") - return np.zeros(img_array.shape[:2], dtype=np.uint8) - - -def predict_mask_from_bboxes( - image_bytes: bytes, - bboxes: List[float], - model_name: str = "sam2_b.pt", - fill_holes: bool = True, - kernel_size: int = 5, - padding: int = 0, -) -> np.ndarray: - """ - Generate segmentation mask from bounding box prompt - - Args: - image_bytes: Image data as bytes - bboxes: Bounding box as [x1, y1, x2, y2] - model_name: SAM model to use - fill_holes: Whether to fill small holes in the mask - kernel_size: Size of morphological kernel for hole filling - padding: Number of pixels to expand the mask boundary (0 = no padding) - - Returns: - Binary mask as numpy array (H, W) with values 0 or 255 - """ - # Load image - img = Image.open(BytesIO(image_bytes)) - img_array = np.array(img) - - # Get cached model - model = get_sam_model(model_name) - - print(f"Predicting with bounding box: {bboxes}") - - # Run prediction with bboxes prompt - # SAM expects bboxes as a list: [x1, y1, x2, y2] - results = model(img_array, bboxes=bboxes, verbose=False) - - # Extract mask from results - if ( - len(results) > 0 - and hasattr(results[0], "masks") - and results[0].masks is not None - ): - # Get the first mask - mask = results[0].masks.data[0].cpu().numpy() - # Convert to uint8 (0 or 255) - mask = (mask * 255).astype(np.uint8) - print( - f"Generated mask with shape: {mask.shape}, unique values: {np.unique(mask)}, sum: {mask.sum()}" - ) - - # Fill holes if requested - if fill_holes: - print(f"Filling holes in mask with kernel size {kernel_size}") - mask = fill_mask_holes(mask, kernel_size) - - # Add padding if requested - if padding > 0: - print(f"Adding {padding}px padding to mask") - mask = add_mask_padding(mask, padding) - - return mask - - # Return empty mask if no result - print("No mask generated!") - return np.zeros(img_array.shape[:2], dtype=np.uint8) - - diff --git a/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmSam2MaskPostProcess/__init__.py b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmSam2MaskPostProcess/__init__.py new file mode 100644 index 000000000..a3aa1f9df --- /dev/null +++ b/src/BuiltinExtensions/ComfyUIBackend/ExtraNodes/SwarmSam2MaskPostProcess/__init__.py @@ -0,0 +1,77 @@ +import torch +import numpy as np +import cv2 + + +def fill_mask_holes(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray: + """Fill small holes in a binary mask using morphological close + flood fill.""" + mask = np.squeeze(mask) + if mask.ndim == 0: + return np.array([[255]], dtype=np.uint8) + if mask.ndim > 2: + mask = mask[:, :, 0] + if mask.dtype != np.uint8: + if mask.dtype == bool or (mask.max() <= 1 and mask.dtype in [np.float32, np.float64]): + mask = (mask * 255).astype(np.uint8) + else: + mask = mask.astype(np.uint8) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) + filled_mask = closed_mask.copy() + h, w = filled_mask.shape + canvas = np.zeros((h + 2, w + 2), dtype=np.uint8) + canvas[1:-1, 1:-1] = filled_mask + cv2.floodFill(canvas, None, (0, 0), 128) + filled_mask = np.where(canvas[1:-1, 1:-1] == 128, 0, 255).astype(np.uint8) + return filled_mask + + +def add_mask_padding(mask: np.ndarray, padding: int = 0) -> np.ndarray: + """Expand a mask boundary by dilating outward.""" + if padding <= 0: + return mask + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (padding * 2 + 1, padding * 2 + 1)) + return cv2.dilate(mask, kernel, iterations=1) + + +class SwarmSam2MaskPostProcess: + """Post-processes a SAM2 segmentation mask with hole-filling and padding.""" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + }, + "optional": { + "fill_holes": ("BOOLEAN", {"default": True}), + "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}), + "mask_padding": ("INT", {"default": 0, "min": 0, "max": 256, "step": 1}), + }, + } + + RETURN_TYPES = ("MASK",) + RETURN_NAMES = ("mask",) + FUNCTION = "process" + CATEGORY = "SAM2" + + def process(self, mask, fill_holes=True, hole_kernel_size=5, mask_padding=0): + out_list = [] + for i in range(mask.shape[0]): + m = mask[i].cpu().numpy() + m_uint8 = (m * 255).astype(np.uint8) + if fill_holes: + m_uint8 = fill_mask_holes(m_uint8, kernel_size=hole_kernel_size) + if mask_padding > 0: + m_uint8 = add_mask_padding(m_uint8, padding=mask_padding) + out_list.append(torch.from_numpy(m_uint8.astype(np.float32) / 255.0)) + return (torch.stack(out_list, dim=0),) + + +NODE_CLASS_MAPPINGS = { + "SwarmSam2MaskPostProcess": SwarmSam2MaskPostProcess, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SwarmSam2MaskPostProcess": "SAM2 Mask Post-Process (Fill Holes + Padding)", +} diff --git a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs index 5fa83b3fb..0bddf3136 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/WorkflowGeneratorSteps.cs @@ -1708,20 +1708,24 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner) ["sam2_model"] = new JArray() { modelNode, 0 }, ["image"] = imageNodeActual, ["keep_model_loaded"] = true, - ["coordinates_positive"] = coords, - ["fill_holes"] = true, - ["hole_kernel_size"] = 9, - ["mask_padding"] = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int pointsPadding) ? pointsPadding : 0 + ["coordinates_positive"] = coords }; - Logs.Debug($"[SAM2-Points] mask_padding from UserInput = {g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0")}"); if (negCoords is not null) { segInputs["coordinates_negative"] = negCoords; } string segNode = g.CreateNode("Sam2Segmentation", segInputs); + int pointsPadding = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int pp) ? pp : 0; + string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject() + { + ["mask"] = new JArray() { segNode, 0 }, + ["fill_holes"] = true, + ["hole_kernel_size"] = 9, + ["mask_padding"] = pointsPadding + }); string maskNode = g.CreateNode("MaskToImage", new JObject() { - ["mask"] = new JArray() { segNode, 0 } + ["mask"] = new JArray() { postNode, 0 } }); new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); g.SkipFurtherSteps = true; @@ -1756,16 +1760,20 @@ void RunSegmentationProcessing(WorkflowGenerator g, bool isBeforeRefiner) ["sam2_model"] = new JArray() { modelNode, 0 }, ["image"] = imageNodeActual, ["keep_model_loaded"] = true, - ["bboxes"] = new JArray() { bboxNode, 0 }, - ["fill_holes"] = true, - ["hole_kernel_size"] = 5, - ["mask_padding"] = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int bboxPadding) ? bboxPadding : 0 + ["bboxes"] = new JArray() { bboxNode, 0 } }; - Logs.Debug($"[SAM2-BBox] mask_padding from UserInput = {g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0")}"); string segNode = g.CreateNode("Sam2Segmentation", segInputs); + int bboxPadding = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int bp) ? bp : 0; + string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject() + { + ["mask"] = new JArray() { segNode, 0 }, + ["fill_holes"] = true, + ["hole_kernel_size"] = 5, + ["mask_padding"] = bboxPadding + }); string maskNode = g.CreateNode("MaskToImage", new JObject() { - ["mask"] = new JArray() { segNode, 0 } + ["mask"] = new JArray() { postNode, 0 } }); new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); g.SkipFurtherSteps = true;