diff --git a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs index ecaa147c2..2c68a4b2b 100644 --- a/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs +++ b/src/BuiltinExtensions/ComfyUIBackend/ComfyUIBackendExtension.cs @@ -556,13 +556,7 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) new JObject() { ["class_type"] = "DownloadAndLoadSAM2Model", - ["inputs"] = new JObject() - { - ["model"] = $"sam2_hiera_{size}.safetensors", - ["segmentor"] = "automaskgenerator", - ["device"] = "cuda", // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not - ["precision"] = "bf16" - } + ["inputs"] = Sam2ModelInputs(size, "automaskgenerator") }, new JObject() { @@ -651,9 +645,40 @@ public static void AssignValuesFromRaw(JObject rawObjectInfo) public static T2IParamGroup ComfyAdvancedGroup; + public static T2IRegisteredParam Sam2PointImage; + + public static T2IRegisteredParam Sam2PointCoordsPositive, Sam2PointCoordsNegative, Sam2BBox, Sam2MaskPadding; + + /// Creates the standard input set for a DownloadAndLoadSAM2Model node. 
+ public static JObject Sam2ModelInputs(string size = "base_plus", string segmentor = "single_image") + { + return new JObject() + { + ["model"] = $"sam2_hiera_{size}.safetensors", + ["segmentor"] = segmentor, + ["device"] = "cuda", // TODO: This should really be decided by the python, not by swarm's workflow generator - the python knows what the GPU supports, swarm does not + ["precision"] = "bf16" + }; + } + /// public override void OnInit() { + Sam2PointImage = T2IParamTypes.Register(new("SAM2 Point Image", "Internal: Base image used for SAM2 point masking.", + null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2PointCoordsPositive = T2IParamTypes.Register(new("SAM2 Positive Points", "Internal: JSON list of positive point coordinates for SAM2 point masking.", + "[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2PointCoordsNegative = T2IParamTypes.Register(new("SAM2 Negative Points", "Internal: JSON list of negative point coordinates for SAM2 point masking.", + "[]", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2BBox = T2IParamTypes.Register(new("SAM2 BBox", "Internal: JSON bounding box [x1,y1,x2,y2] for SAM2 bbox masking.", + null, FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); + Sam2MaskPadding = T2IParamTypes.Register(new("SAM2 Mask Padding", "Internal: Number of pixels to dilate/expand the SAM2 mask boundary.", + "0", IgnoreIf: "0", FeatureFlag: "sam2", VisibleNormally: false, ExtraHidden: true, DoNotSave: true, DoNotPreview: true, AlwaysRetain: true + )); UseIPAdapterForRevision = T2IParamTypes.Register(new("Use IP-Adapter", $"Select an IP-Adapter model to use IP-Adapter for image-prompt input handling.\nModels will 
import json
import math


class Sam2BBoxFromJson:
    """Converts a JSON bounding box string '[x1,y1,x2,y2]' into a BBOX value
    that can be passed directly to Sam2Segmentation's bboxes input.

    The output is a one-element list holding a single [x1, y1, x2, y2] list of floats,
    with the corners normalized so that x1 <= x2 and y1 <= y2.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "bbox_json": ("STRING", {"forceInput": True}),
            }
        }

    RETURN_TYPES = ("BBOX",)
    RETURN_NAMES = ("bboxes",)
    FUNCTION = "convert"
    CATEGORY = "SAM2"

    def convert(self, bbox_json):
        """Parse ``bbox_json`` and return it as a 1-tuple containing ``[[x1, y1, x2, y2]]``.

        Only the first four entries of the JSON list are used.

        Raises:
            ValueError: if the input is not valid JSON, is not a list with at least
                four entries, or any of the first four entries is not a finite number.
        """
        try:
            coords = json.loads(bbox_json)
        except (TypeError, ValueError) as ex:
            # json.JSONDecodeError is a ValueError subclass; TypeError covers non-string input.
            raise ValueError(f"bbox_json must be a JSON list '[x1,y1,x2,y2]', got {bbox_json!r}") from ex
        if not isinstance(coords, list) or len(coords) < 4:
            raise ValueError(f"bbox_json must be a JSON list of 4 numbers '[x1,y1,x2,y2]', got {bbox_json!r}")
        try:
            x1, y1, x2, y2 = (float(c) for c in coords[:4])
        except (TypeError, ValueError) as ex:
            raise ValueError(f"bbox_json entries must be numbers, got {bbox_json!r}") from ex
        if not all(math.isfinite(v) for v in (x1, y1, x2, y2)):
            raise ValueError(f"bbox_json entries must be finite numbers, got {bbox_json!r}")
        # Normalize corner order so a box dragged right-to-left / bottom-to-top is still valid.
        return ([[min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]],)


NODE_CLASS_MAPPINGS = {
    "Sam2BBoxFromJson": Sam2BBoxFromJson,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Sam2BBoxFromJson": "SAM2 BBox From JSON",
}
import torch
import numpy as np
import cv2


def fill_mask_holes(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
    """Fill small holes in a binary mask using morphological close + flood fill.

    Args:
        mask: Mask array. Bool, float (0..1) and integer inputs are converted to uint8.
            Inputs that are not already 2D have singleton axes squeezed away; if more than
            two axes remain, the first channel of the last axis is used.
        kernel_size: Side length of the elliptical structuring element used for the close.

    Returns:
        A 2D uint8 array holding only 0 and 255: every pixel that is not connected to the
        outside background (after closing) is 255.
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        # Only squeeze when the input is not already a plane, so 1xN / Nx1 / 1x1 masks keep their shape.
        mask = np.squeeze(mask)
        if mask.ndim == 0:
            return np.array([[255]], dtype=np.uint8)
        if mask.ndim == 1:
            mask = mask[np.newaxis, :]
        elif mask.ndim > 2:
            mask = mask[:, :, 0]
    if mask.size == 0:
        return mask.astype(np.uint8)
    if mask.dtype != np.uint8:
        if mask.dtype == bool:
            mask = mask.astype(np.uint8) * 255
        elif np.issubdtype(mask.dtype, np.floating) and mask.max() <= 1:
            # Clip before scaling so slightly negative values cannot wrap around in uint8.
            mask = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        else:
            mask = np.clip(mask, 0, 255).astype(np.uint8)
    mask = np.ascontiguousarray(mask)
    kernel_size = max(1, int(kernel_size))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    closed_mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    h, w = closed_mask.shape
    # One-pixel zero border guarantees the flood fill from (0, 0) can reach all outside background.
    canvas = np.zeros((h + 2, w + 2), dtype=np.uint8)
    # Binarize first: a soft-mask pixel whose value happens to be 128 must not be mistaken
    # for the flood-fill marker below.
    canvas[1:-1, 1:-1] = np.where(closed_mask > 0, 255, 0)
    cv2.floodFill(canvas, None, (0, 0), 128)
    # Anything the flood fill reached is background; everything else (mask + enclosed holes) is foreground.
    return np.where(canvas[1:-1, 1:-1] == 128, 0, 255).astype(np.uint8)


def add_mask_padding(mask: np.ndarray, padding: int = 0) -> np.ndarray:
    """Expand a mask boundary by dilating outward.

    Args:
        mask: Mask array to dilate.
        padding: Dilation radius in pixels; values <= 0 return the mask unchanged.

    Returns:
        The dilated mask (or the input itself when no padding is requested).
    """
    padding = int(padding)
    if padding <= 0:
        return mask
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (padding * 2 + 1, padding * 2 + 1))
    return cv2.dilate(mask, kernel, iterations=1)


class SwarmSam2MaskPostProcess:
    """Post-processes a SAM2 segmentation mask with hole-filling and padding."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "fill_holes": ("BOOLEAN", {"default": True}),
                "hole_kernel_size": ("INT", {"default": 5, "min": 1, "max": 21, "step": 2}),
                "mask_padding": ("INT", {"default": 0, "min": 0, "max": 256, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    FUNCTION = "process"
    CATEGORY = "SAM2"

    def process(self, mask, fill_holes=True, hole_kernel_size=5, mask_padding=0):
        """Apply optional hole filling and dilation to every mask in the batch.

        Args:
            mask: Mask tensor, iterated along its first axis. A 2D tensor is treated as a
                single mask (a batch of one).
            fill_holes: Whether to run ``fill_mask_holes`` on each mask.
            hole_kernel_size: Kernel size forwarded to ``fill_mask_holes``.
            mask_padding: Dilation radius forwarded to ``add_mask_padding``.

        Returns:
            A 1-tuple holding a float32 tensor of the processed masks scaled to 0..1,
            stacked along the first axis and placed on the input's device.
        """
        if mask.dim() == 2:
            # An unbatched (H, W) mask would otherwise be iterated row by row.
            mask = mask.unsqueeze(0)
        if mask.shape[0] == 0:
            # torch.stack cannot handle an empty list; nothing to process anyway.
            return (mask,)
        out_list = []
        for single in mask:
            m = single.detach().float().cpu().numpy()
            # Clip so values outside 0..1 saturate instead of wrapping around in uint8.
            m_uint8 = (np.clip(m, 0.0, 1.0) * 255).astype(np.uint8)
            if fill_holes:
                m_uint8 = fill_mask_holes(m_uint8, kernel_size=hole_kernel_size)
            if mask_padding > 0:
                m_uint8 = add_mask_padding(m_uint8, padding=mask_padding)
            out_list.append(torch.from_numpy(m_uint8.astype(np.float32) / 255.0))
        return (torch.stack(out_list, dim=0).to(mask.device),)


NODE_CLASS_MAPPINGS = {
    "SwarmSam2MaskPostProcess": SwarmSam2MaskPostProcess,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "SwarmSam2MaskPostProcess": "SAM2 Mask Post-Process (Fill Holes + Padding)",
}
segNode = g.CreateNode("Sam2Segmentation", segInputs); + int pointsPadding = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int pp) ? pp : 0; + string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject() + { + ["mask"] = new JArray() { segNode, 0 }, + ["fill_holes"] = true, + ["hole_kernel_size"] = 9, + ["mask_padding"] = pointsPadding + }); + string maskNode = g.CreateNode("MaskToImage", new JObject() + { + ["mask"] = new JArray() { postNode, 0 } + }); + new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); + g.SkipFurtherSteps = true; + }, 8.9); + AddStep(g => + { + if (!g.UserInput.TryGet(ComfyUIBackendExtension.Sam2BBox, out string bboxJson) || string.IsNullOrWhiteSpace(bboxJson)) + { + return; + } + JArray imageNodeActual = null; + if (g.UserInput.TryGet(ComfyUIBackendExtension.Sam2PointImage, out Image img)) + { + WGNodeData imageNode = g.LoadImage(img, "${sampointimage}", true); + imageNodeActual = imageNode.Path; + } + else if (g.BasicInputImage is not null) + { + imageNodeActual = g.BasicInputImage.Path; + } + if (imageNodeActual is null) + { + return; + } + string modelNode = g.CreateNode("DownloadAndLoadSAM2Model", ComfyUIBackendExtension.Sam2ModelInputs()); + string bboxNode = g.CreateNode("Sam2BBoxFromJson", new JObject() + { + ["bbox_json"] = bboxJson + }); + JObject segInputs = new() + { + ["sam2_model"] = new JArray() { modelNode, 0 }, + ["image"] = imageNodeActual, + ["keep_model_loaded"] = true, + ["bboxes"] = new JArray() { bboxNode, 0 } + }; + string segNode = g.CreateNode("Sam2Segmentation", segInputs); + int bboxPadding = int.TryParse(g.UserInput.Get(ComfyUIBackendExtension.Sam2MaskPadding, "0"), out int bp) ? 
bp : 0; + string postNode = g.CreateNode("SwarmSam2MaskPostProcess", new JObject() + { + ["mask"] = new JArray() { segNode, 0 }, + ["fill_holes"] = true, + ["hole_kernel_size"] = 5, + ["mask_padding"] = bboxPadding + }); + string maskNode = g.CreateNode("MaskToImage", new JObject() + { + ["mask"] = new JArray() { postNode, 0 } + }); + new WGNodeData([maskNode, 0], g, WGNodeData.DT_IMAGE, g.CurrentCompat()).SaveOutput(null, null, "9"); + g.SkipFurtherSteps = true; + }, 8.85); + #endregion #region SaveImage AddStep(g => { diff --git a/src/wwwroot/imgs/crosshair.png b/src/wwwroot/imgs/crosshair.png new file mode 100644 index 000000000..6e9495e04 Binary files /dev/null and b/src/wwwroot/imgs/crosshair.png differ diff --git a/src/wwwroot/imgs/rectangle.png b/src/wwwroot/imgs/rectangle.png new file mode 100644 index 000000000..2813ee61f Binary files /dev/null and b/src/wwwroot/imgs/rectangle.png differ diff --git a/src/wwwroot/js/genpage/helpers/image_editor.js b/src/wwwroot/js/genpage/helpers/image_editor.js index b583683e5..7da245216 100644 --- a/src/wwwroot/js/genpage/helpers/image_editor.js +++ b/src/wwwroot/js/genpage/helpers/image_editor.js @@ -95,6 +95,10 @@ class ImageEditorTool { onGlobalMouseUp(e) { return false; } + + onContextMenu(e) { + return false; + } } /** @@ -1234,6 +1238,411 @@ class ImageEditorToolPicker extends ImageEditorTempTool { } } +/** + * The SAM2 Point Segmentation tool - click to place positive/negative points and auto-generate a mask. + */ +class ImageEditorToolSam2Points extends ImageEditorTool { + constructor(editor) { + super(editor, 'sam2points', 'crosshair', 'SAM2 Points', 'Left click to add positive points. 
Right click to add negative points.\nEach click regenerates the mask.\nRequires SAM2 to be installed.\nHotKey: Y', 'y'); + this.cursor = 'crosshair'; + this.positivePoints = []; + this.negativePoints = []; + this.requestSerial = 0; + this.activeRequestId = 0; + this.maskRequestInFlight = false; + this.pendingMaskUpdate = false; + this.modelWarmed = false; + this.isWarmingUp = false; + this.controlsHTML = ` +
+ + + +
`; + this.warmupHTML = `
Warming up SAM2 model...
`; + this.showControls(); + } + + showControls() { + let prevPadding = this.maskPaddingInput ? this.maskPaddingInput.value : '0'; + this.configDiv.innerHTML = this.controlsHTML; + this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.maskPaddingInput.value = prevPadding; + this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { + // Clear points + this.positivePoints = []; + this.negativePoints = []; + // Clear mask layer + let maskLayer = this.editor.activeLayer && this.editor.activeLayer.isMask ? this.editor.activeLayer : this.editor.layers.find(layer => layer.isMask); + if (maskLayer) { + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + maskLayer.hasAnyContent = false; + } + this.activeRequestId = ++this.requestSerial; + this.maskRequestInFlight = false; + this.pendingMaskUpdate = false; + this.editor.redraw(); + }); + } + + drawPoint(ctx, x, y, fillColor, showX) { + let [cx, cy] = this.editor.imageCoordToCanvasCoord(x, y); + let radius = Math.max(3, Math.round(4 * this.editor.zoomLevel)); + ctx.save(); + ctx.lineWidth = Math.max(1, Math.round(2 * this.editor.zoomLevel)); + ctx.strokeStyle = '#000000'; + ctx.fillStyle = fillColor; + ctx.beginPath(); + ctx.arc(cx, cy, radius, 0, 2 * Math.PI); + ctx.fill(); + ctx.stroke(); + if (showX) { + let cross = Math.max(3, Math.round(radius * 0.9)); + ctx.beginPath(); + ctx.moveTo(cx - cross, cy - cross); + ctx.lineTo(cx + cross, cy + cross); + ctx.moveTo(cx - cross, cy + cross); + ctx.lineTo(cx + cross, cy - cross); + ctx.stroke(); + } + ctx.restore(); + } + + draw() { + let ctx = this.editor.ctx; + for (let point of this.positivePoints) { + this.drawPoint(ctx, point.x, point.y, '#33ff99', false); + } + for (let point of this.negativePoints) { + this.drawPoint(ctx, point.x, point.y, '#ff3355', true); + } + } + + onContextMenu(e) { + e.preventDefault(); + return true; + } + + setActive() { + super.setActive(); 
+ if (!this.modelWarmed && !this.isWarmingUp && currentBackendFeatureSet.includes('sam2') && this.editor.getFinalImageData?.()) { + this.triggerWarmup(); + } + } + + triggerWarmup() { + this.isWarmingUp = true; + this.cursor = 'wait'; + this.editor.canvas.style.cursor = 'wait'; + this.configDiv.innerHTML = this.warmupHTML; + try { + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let cx = Math.floor((this.editor.realWidth || 64) / 2); + let cy = Math.floor((this.editor.realHeight || 64) / 2); + genData['sampositivepoints'] = JSON.stringify([{ x: cx, y: cy }]); + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (data.image || data.error) { + this.modelWarmed = true; + this.isWarmingUp = false; + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); + } + }); + } catch (e) { + this.modelWarmed = true; + this.isWarmingUp = false; + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); + } + } + + onMouseDown(e) { + if (this.isWarmingUp) { return; } + if (e.button !== 0 && e.button !== 2) { + return; + } + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.round(mouseX); + mouseY = Math.round(mouseY); + if (mouseX < 0 || mouseY < 0 || mouseX >= this.editor.realWidth || mouseY >= this.editor.realHeight) { + return; + } + let point = { x: mouseX, y: mouseY }; + if (e.button === 2) { + e.preventDefault(); + this.negativePoints.push(point); + } + else { + this.positivePoints.push(point); + } + this.queueMaskUpdate(); + this.editor.redraw(); + } + + queueMaskUpdate() { + if (!currentBackendFeatureSet.includes('sam2')) { + $('#sam2_installer').modal('show'); + return; + } + if 
(this.positivePoints.length === 0) { + return; + } + if (this.maskRequestInFlight) { + this.pendingMaskUpdate = true; + return; + } + this.requestMaskUpdate(); + } + + finishMaskUpdate(requestId) { + if (requestId !== this.activeRequestId) { + return; + } + this.maskRequestInFlight = false; + if (this.pendingMaskUpdate) { + this.pendingMaskUpdate = false; + this.requestMaskUpdate(); + } + } + + requestMaskUpdate() { + this.maskRequestInFlight = true; + let requestId = ++this.requestSerial; + this.activeRequestId = requestId; + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + genData['sampositivepoints'] = JSON.stringify(this.positivePoints); + if (this.negativePoints.length > 0) { + genData['samnegativepoints'] = JSON.stringify(this.negativePoints); + } + let maskPadding = parseInt(this.maskPaddingInput.value) || 0; + if (maskPadding > 0) { + genData['sammaskpadding'] = `${maskPadding}`; + } + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (requestId !== this.activeRequestId) { + return; + } + if (!data.image) { + return; + } + let newImg = new Image(); + newImg.onload = () => { + if (requestId !== this.activeRequestId) { + return; + } + this.editor.applyMaskFromImage(newImg, true); + this.finishMaskUpdate(requestId); + }; + newImg.src = data.image; + }); + } +} + +/** + * The SAM2 Bounding Box segmentation tool - drag to define a box and auto-generate a mask. + */ +class ImageEditorToolSam2BBox extends ImageEditorTool { + constructor(editor) { + super(editor, 'sam2bbox', 'rectangle', 'SAM2 BBox', 'Click and drag to create a bounding box. 
Release to generate mask.\nRequires SAM2 to be installed.', null); + this.cursor = 'crosshair'; + this.bboxStartX = null; + this.bboxStartY = null; + this.bboxEndX = null; + this.bboxEndY = null; + this.isDrawing = false; + this.requestSerial = 0; + this.activeRequestId = 0; + this.maskRequestInFlight = false; + this.modelWarmed = false; + this.isWarmingUp = false; + this.controlsHTML = ` +
+ + + +
`; + this.warmupHTML = `
Warming up SAM2 model...
`; + this.showControls(); + } + + showControls() { + let prevPadding = this.maskPaddingInput ? this.maskPaddingInput.value : '0'; + this.configDiv.innerHTML = this.controlsHTML; + this.maskPaddingInput = this.configDiv.querySelector('.id-mask-padding'); + this.maskPaddingInput.value = prevPadding; + this.configDiv.querySelector('.id-clear-mask').addEventListener('click', () => { + let maskLayer = this.editor.activeLayer && this.editor.activeLayer.isMask ? this.editor.activeLayer : this.editor.layers.find(layer => layer.isMask); + if (!maskLayer) { + return; + } + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + maskLayer.hasAnyContent = false; + this.editor.redraw(); + }); + } + + draw() { + if (this.isDrawing && this.bboxStartX !== null && this.bboxEndX !== null) { + let ctx = this.editor.ctx; + let [x1, y1] = this.editor.imageCoordToCanvasCoord(this.bboxStartX, this.bboxStartY); + let [x2, y2] = this.editor.imageCoordToCanvasCoord(this.bboxEndX, this.bboxEndY); + let minX = Math.min(x1, x2); + let minY = Math.min(y1, y2); + let maxX = Math.max(x1, x2); + let maxY = Math.max(y1, y2); + ctx.save(); + ctx.strokeStyle = '#33ff99'; + ctx.lineWidth = 2; + ctx.setLineDash([5, 5]); + ctx.strokeRect(minX, minY, maxX - minX, maxY - minY); + ctx.restore(); + } + } + + setActive() { + super.setActive(); + if (!this.modelWarmed && !this.isWarmingUp && currentBackendFeatureSet.includes('sam2') && this.editor.getFinalImageData?.()) { + this.triggerWarmup(); + } + } + + triggerWarmup() { + this.isWarmingUp = true; + this.cursor = 'wait'; + this.editor.canvas.style.cursor = 'wait'; + this.configDiv.innerHTML = this.warmupHTML; + try { + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let cx = Math.floor((this.editor.realWidth || 64) / 2); + let 
cy = Math.floor((this.editor.realHeight || 64) / 2); + genData['sambbox'] = JSON.stringify([cx - 1, cy - 1, cx + 1, cy + 1]); + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (data.image || data.error) { + this.modelWarmed = true; + this.isWarmingUp = false; + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); + } + }); + } catch (e) { + this.modelWarmed = true; + this.isWarmingUp = false; + this.cursor = 'crosshair'; + this.editor.canvas.style.cursor = 'crosshair'; + this.showControls(); + } + } + + onMouseDown(e) { + if (this.isWarmingUp) { return; } + if (e.button !== 0) { + return; + } + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.round(mouseX); + mouseY = Math.round(mouseY); + if (mouseX < 0 || mouseY < 0 || mouseX >= this.editor.realWidth || mouseY >= this.editor.realHeight) { + return; + } + this.isDrawing = true; + this.bboxStartX = mouseX; + this.bboxStartY = mouseY; + this.bboxEndX = mouseX; + this.bboxEndY = mouseY; + } + + onMouseMove(e) { + if (this.isDrawing) { + this.editor.updateMousePosFrom(e); + let [mouseX, mouseY] = this.editor.canvasCoordToImageCoord(this.editor.mouseX, this.editor.mouseY); + mouseX = Math.max(0, Math.min(this.editor.realWidth - 1, Math.round(mouseX))); + mouseY = Math.max(0, Math.min(this.editor.realHeight - 1, Math.round(mouseY))); + this.bboxEndX = mouseX; + this.bboxEndY = mouseY; + this.editor.redraw(); + } + } + + onMouseUp(e) { + if (this.isWarmingUp) { return; } + if (this.isDrawing) { + this.isDrawing = false; + this.requestMaskUpdate(); + } + } + + requestMaskUpdate() { + if (!currentBackendFeatureSet.includes('sam2')) { + $('#sam2_installer').modal('show'); + return; + } + if (this.bboxStartX === null || this.bboxEndX === null) { + return; + } + this.maskRequestInFlight = true; + let requestId = ++this.requestSerial; + this.activeRequestId = 
requestId; + let img = this.editor.getFinalImageData(); + let genData = getGenInput(); + genData['sampointimage'] = img; + genData['images'] = 1; + genData['prompt'] = ''; + delete genData['batchsize']; + genData['donotsave'] = true; + let minX = Math.min(this.bboxStartX, this.bboxEndX); + let minY = Math.min(this.bboxStartY, this.bboxEndY); + let maxX = Math.max(this.bboxStartX, this.bboxEndX); + let maxY = Math.max(this.bboxStartY, this.bboxEndY); + genData['sambbox'] = JSON.stringify([minX, minY, maxX, maxY]); + let maskPadding = parseInt(this.maskPaddingInput.value) || 0; + if (maskPadding > 0) { + genData['sammaskpadding'] = `${maskPadding}`; + } + makeWSRequestT2I('GenerateText2ImageWS', genData, data => { + if (requestId !== this.activeRequestId) { + return; + } + if (!data.image) { + return; + } + let newImg = new Image(); + newImg.onload = () => { + if (requestId !== this.activeRequestId) { + return; + } + this.editor.applyMaskFromImage(newImg, true); + this.maskRequestInFlight = false; + }; + newImg.src = data.image; + }); + } +} + /** * A single layer within an image editing interface. * This can be real (user-controlled) OR sub-layers (sometimes user-controlled) OR temporary buffers. 
@@ -1595,6 +2004,8 @@ class ImageEditor { this.addTool(new ImageEditorToolShape(this)); this.pickerTool = new ImageEditorToolPicker(this, 'picker', 'paintbrush', 'Color Picker', 'Pick a color from the image.'); this.addTool(this.pickerTool); + this.addTool(new ImageEditorToolSam2Points(this)); + this.addTool(new ImageEditorToolSam2BBox(this)); this.activateTool('brush'); this.maxHistory = 15; } @@ -1681,6 +2092,11 @@ class ImageEditor { e.stopPropagation(); }); canvas.addEventListener('drop', (e) => this.handleCanvasImageDrop(e)); + canvas.addEventListener('contextmenu', (e) => { + if (this.activeTool && this.activeTool.onContextMenu(e)) { + e.preventDefault(); + } + }); this.ctx = canvas.getContext('2d'); canvas.style.cursor = 'none'; this.maskHelperCanvas = document.createElement('canvas'); @@ -2012,6 +2428,16 @@ class ImageEditor { this.addLayer(maskLayer); this.realWidth = img.naturalWidth; this.realHeight = img.naturalHeight; + if (this.tools['sam2points']) { + this.tools['sam2points'].positivePoints = []; + this.tools['sam2points'].negativePoints = []; + } + if (this.tools['sam2bbox']) { + this.tools['sam2bbox'].bboxStartX = null; + this.tools['sam2bbox'].bboxStartY = null; + this.tools['sam2bbox'].bboxEndX = null; + this.tools['sam2bbox'].bboxEndY = null; + } this.offsetX = 0 this.offsetY = 0; if (this.active) { @@ -2252,6 +2678,34 @@ class ImageEditor { return canvas.toDataURL(format); } + applyMaskFromImage(img, replaceExisting = true) { + let maskLayer = this.activeLayer && this.activeLayer.isMask ? 
this.activeLayer : this.layers.find(layer => layer.isMask); + if (!maskLayer) { + maskLayer = new ImageEditorLayer(this, img.naturalWidth || img.width, img.naturalHeight || img.height); + maskLayer.isMask = true; + this.addLayer(maskLayer); + } + if (replaceExisting) { + maskLayer.saveBeforeEdit(); + maskLayer.ctx.clearRect(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + } + maskLayer.ctx.drawImage(img, 0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + // Convert black pixels to transparent so only the white mask region is visible + let imageData = maskLayer.ctx.getImageData(0, 0, maskLayer.canvas.width, maskLayer.canvas.height); + let data = imageData.data; + for (let i = 0; i < data.length; i += 4) { + let brightness = data[i] + data[i + 1] + data[i + 2]; + if (brightness < 128) { + data[i + 3] = 0; + } + } + maskLayer.ctx.putImageData(imageData, 0, 0); + maskLayer.hasAnyContent = true; + this.setActiveLayer(maskLayer); + this.sortLayers(); + this.redraw(); + } + getFinalMaskData(format = 'image/png') { let canvas = document.createElement('canvas'); canvas.width = this.realWidth;