From 92b0285f8597f235dc2761b7cc2310f41adde0c5 Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 24 Sep 2024 09:54:23 +0200 Subject: [PATCH 01/28] Add format to send image node --- ai_diffusion/comfy_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index a374aed535..d6f1664536 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -726,7 +726,7 @@ def load_mask(self, mask: Image): def send_image(self, image: Output): if self._run_mode is ComfyRunMode.runtime: return self.add("ETN_ReturnImage", 1, images=image) - return self.add("ETN_SendImageWebSocket", 1, images=image) + return self.add("ETN_SendImageWebSocket", 1, images=image, format="PNG") def save_image(self, image: Output, prefix: str): return self.add("SaveImage", 1, images=image, filename_prefix=prefix) From febe5e106d56772799fe94028d93dabbc4f0c64f Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 25 Sep 2024 12:17:54 +0200 Subject: [PATCH 02/28] Track and import current workflow from ComfyUI web instances and allow executing them from new custom graph workspace --- ai_diffusion/api.py | 8 ++++ ai_diffusion/client.py | 8 +++- ai_diffusion/comfy_client.py | 54 ++++++++++++++++---------- ai_diffusion/comfy_workflow.py | 66 +++++++++++++++++++++++++------- ai_diffusion/connection.py | 48 +++++++++++++++-------- ai_diffusion/model.py | 35 ++++++++++++++--- ai_diffusion/ui/diffusion.py | 7 +++- ai_diffusion/ui/generation.py | 70 ++++++++++++++++++++++++++++++++-- ai_diffusion/ui/widget.py | 2 + ai_diffusion/workflow.py | 33 +++++++++++++++- tests/test_workflow.py | 42 ++++++++++++++++++-- 11 files changed, 306 insertions(+), 67 deletions(-) diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py index 56ad6c17bb..c473fb24c1 100644 --- a/ai_diffusion/api.py +++ b/ai_diffusion/api.py @@ -18,6 +18,7 @@ class WorkflowKind(Enum): upscale_simple = 4 upscale_tiled = 5 control_image = 6 + custom = 7 @dataclass @@ -144,6 +145,12 @@ def clamped(self): return params +@dataclass +class CustomWorkflowInput: + workflow: dict + params: dict[str, Any] + + @dataclass class WorkflowInput: kind: WorkflowKind @@ -157,6 +164,7 @@ class WorkflowInput: control_mode: ControlMode = ControlMode.reference batch_count: int = 1 nsfw_filter: float = 0.0 + custom_workflow: CustomWorkflowInput | None = None @property def extent(self): diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py index 76568558be..a3158194bf 100644 --- a/ai_diffusion/client.py +++ b/ai_diffusion/client.py @@ -25,6 +25,7 @@ class ClientEvent(Enum): disconnected = 5 queued = 6 upload = 7 + published = 8 class ClientMessage(NamedTuple): @@ -32,7 +33,7 @@ class ClientMessage(NamedTuple): job_id: str = "" progress: float = 0 images: ImageCollection | None = None - result: dict | None = None + result: dict | SharedWorkflow | None = None error: str | None = None @@ -68,6 +69,11 @@ def parse(data: dict): return DeviceInfo("cpu", "unknown", 0) +class SharedWorkflow(NamedTuple): + publisher: str + workflow: dict + + class CheckpointInfo(NamedTuple): filename: str arch: Arch diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 8f61590d7e..1c4cc09b33 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -11,7 +11,7 @@ from .api import WorkflowInput from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels -from .client import TranslationPackage, filter_supported_styles, loras_to_upload +from .client import SharedWorkflow, TranslationPackage, filter_supported_styles, loras_to_upload from .files import FileFormat from .image import Image, ImageCollection from .network import RequestManager, NetworkError @@ -87,16 +87,18 @@ class ComfyClient(Client): default_url = "http://127.0.0.1:8188" - _requests = RequestManager() - _id: str - _messages: asyncio.Queue[ClientMessage] - _queue: asyncio.Queue[JobInfo] - _jobs: deque[JobInfo] - _active: Optional[JobInfo] = None - _job_runner: asyncio.Task - _websocket_listener: asyncio.Task - _supported_archs: list[Arch] - _supported_languages: list[TranslationPackage] + def __init__(self, url): + self.url = url + self.models = ClientModels() + self._requests = RequestManager() + self._id = str(uuid.uuid4()) + self._active: Optional[JobInfo] = None + self._supported_archs: list[Arch] = [] + self._supported_languages: list[TranslationPackage] = [] + self._messages = asyncio.Queue() + self._queue = asyncio.Queue() + self._jobs = deque() + self._is_connected = False @staticmethod async def connect(url=default_url, access_token=""): @@ -175,15 +177,6 @@ async def connect(url=default_url, access_token=""): _ensure_supported_style(client) return client - def __init__(self, url): - self.url = url - self.models = ClientModels() - self._id = str(uuid.uuid4()) - self._messages = asyncio.Queue() - self._queue = asyncio.Queue() - self._jobs = deque() - self._is_connected = False - async def _get(self, op: str): return await self._requests.get(f"{self.url}/{op}") @@ -237,6 +230,7 @@ async def _listen(self): f"{url}/ws?clientId={self._id}", max_size=2**30, read_limit=2**30, ping_timeout=60 ): try: + await self._subscribe_workflows() await self._listen_websocket(websocket) except websockets_exceptions.ConnectionClosedError as e: log.warning(f"Websocket connection closed: {str(e)}") @@ -322,7 +316,12 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr traceback = msg["data"].get("traceback", "no traceback") log.error(f"Job {job} failed: {error}\n{traceback}") self._clear_job(job.remote_id) - await self._report(ClientEvent.error, job.local_id, 0, error=error) + await self._report(ClientEvent.error, job.local_id, error=error) + + if msg["type"] == "etn_workflow_published": + name = f"{msg['data']['publisher']['name']} ({msg['data']['publisher']['id']})" + workflow = SharedWorkflow(name, msg["data"]["workflow"]) + await self._report(ClientEvent.published, "", result=workflow) async def listen(self): self._is_connected = True @@ -358,6 +357,7 @@ async def disconnect(self): self._job_runner, self._websocket_listener, self._report(ClientEvent.disconnected, ""), + self._unsubscribe_workflows(), ) async def try_inspect(self, folder_name: str): @@ -431,6 +431,18 @@ async def translate(self, text: str, lang: str): log.error(f"Could not translate text: {str(e)}") return text + async def _subscribe_workflows(self): + try: + await self._post("api/etn/workflow/subscribe", {"client_id": self._id}) + except Exception as e: + log.error(f"Couldn't subscribe to shared workflows: {str(e)}") + + async def _unsubscribe_workflows(self): + try: + await self._post("api/etn/workflow/unsubscribe", {"client_id": self._id}) + except Exception as e: + log.error(f"Couldn't unsubscribe from shared workflows: {str(e)}") + def supports_arch(self, arch: Arch): return arch in self._supported_archs diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index d6f1664536..3fbe2f1d38 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -22,26 +22,43 @@ class Output(NamedTuple): Output2 = Tuple[Output, Output] Output3 = Tuple[Output, Output, Output] Output4 = Tuple[Output, Output, Output, Output] +Input = int | float | bool | str | Output -class ComfyWorkflow: - """Builder for workflows which can be sent to the ComfyUI prompt API.""" +class ComfyNode(NamedTuple): + id: int + type: str + inputs: dict[str, Input] + + def output(self, index=0) -> Output: + return Output(int(self.id), index) - root: dict[str, dict] - images: dict[str, Image] - node_count = 0 - sample_count = 0 - _cache: dict[str, Output | Output2 | Output3 | Output4] - _nodes_required_inputs: dict[str, dict[str, Any]] - _run_mode: ComfyRunMode +class ComfyWorkflow: + """Builder for workflows which can be sent to the ComfyUI prompt API.""" def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server): - self.root = {} - self.images = {} - self._cache = {} - self._nodes_required_inputs = node_inputs or {} - self._run_mode = run_mode + self.root: dict[str, dict] = {} + self.images: dict[str, Image] = {} + self.node_count = 0 + self.sample_count = 0 + self._cache: dict[str, Output | Output2 | Output3 | Output4] = {} + self._nodes_required_inputs: dict[str, dict[str, Any]] = node_inputs or {} + self._run_mode: ComfyRunMode = run_mode + + @staticmethod + def from_dict(existing: dict): + w = ComfyWorkflow() + node_map: dict[str, str] = {} + for k, v in existing.items(): + node_map[k] = str(w.node_count) + w.root[str(w.node_count)] = v + w.node_count += 1 + for node in w.root.values(): + for e in node["inputs"].values(): + if isinstance(e, list): + e[0] = node_map.get(e[0], e[0]) + return w def add_default_values(self, node_name: str, args: dict): if node_inputs := self._nodes_required_inputs.get(node_name, None): @@ -102,11 +119,32 @@ def add_cached(self, class_type: str, output_count: Literal[1] | Literal[3], **i self._cache[key] = result return result + def remove(self, node_id: int): + del self.root[str(node_id)] + + def node(self, node_id: int): + inputs = self.root[str(node_id)]["inputs"] + inputs = { + k: Output(int(v[0]), v[1]) if isinstance(v, list) else v for k, v in inputs.items() + } + return ComfyNode(node_id, self.root[str(node_id)]["class_type"], inputs) + + def copy(self, node: ComfyNode): + return self.add(node.type, 1, **node.inputs) + + def __iter__(self): + return iter(self.node(int(k)) for k in self.root.keys()) + + def __contains__(self, node: ComfyNode): + return any(n == node for n in self) + def _add_image(self, image: Image): id = str(uuid4()) self.images[id] = image return id + # Nodes + def ksampler( self, model: Output, diff --git a/ai_diffusion/connection.py b/ai_diffusion/connection.py index 302b8d06b8..8c345d6a5e 100644 --- a/ai_diffusion/connection.py +++ b/ai_diffusion/connection.py @@ -4,7 +4,7 @@ from PyQt5.QtGui import QDesktopServices import asyncio -from .client import Client, ClientMessage, ClientEvent, DeviceInfo +from .client import Client, ClientMessage, ClientEvent, DeviceInfo, SharedWorkflow from .comfy_client import ComfyClient from .cloud_client import CloudClient from .network import NetworkError @@ -36,12 +36,16 @@ class Connection(QObject, ObservableProperties): error_changed = pyqtSignal(str) models_changed = pyqtSignal() message_received = pyqtSignal(ClientMessage) - - _client: Client | None = None - _task: asyncio.Task | None = None + workflow_published = pyqtSignal(str) def __init__(self): super().__init__() + + self._client: Client | None = None + self._task: asyncio.Task | None = None + self._workflows: dict[str, dict] = {} + self._temporary_disconnect = False + settings.changed.connect(self._handle_settings_changed) self._update_state() @@ -151,26 +155,20 @@ def user(self): if client := self.client_if_connected: return client.user + @property + def workflows(self): + return self._workflows + async def _handle_messages(self): client = self._client - temporary_disconnect = False + self._temporary_disconnect = False assert client is not None try: async with client: async for msg in client.listen(): try: - if msg.event is ClientEvent.error and not msg.job_id: - self.error = _("Error communicating with server: ") + str(msg.error) - elif msg.event is ClientEvent.disconnected: - temporary_disconnect = True - self.error = _("Disconnected from server, trying to reconnect...") - elif msg.event is ClientEvent.connected: - if temporary_disconnect: - temporary_disconnect = False - self.error = "" - else: - self.message_received.emit(msg) + self._handle_message(msg) except asyncio.CancelledError: break except Exception as e: @@ -179,6 +177,24 @@ async def _handle_messages(self): except asyncio.CancelledError: pass # shutdown + def _handle_message(self, msg: ClientMessage): + match msg: + case (ClientEvent.error, "", *_): + self.error = _("Error communicating with server: ") + str(msg.error) + case (ClientEvent.disconnected, *_): + self._temporary_disconnect = True + self.error = _("Disconnected from server, trying to reconnect...") + case (ClientEvent.connected, *_): + if self._temporary_disconnect: + self._temporary_disconnect = False + self.error = "" + case (ClientEvent.published, *_): + assert isinstance(msg.result, SharedWorkflow) + self._workflows[msg.result.publisher] = msg.result.workflow + self.workflow_published.emit(msg.result.publisher) + case _: + self.message_received.emit(msg) + def _update_state(self): if ( self.state in [ConnectionState.disconnected, ConnectionState.error] diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 27f85781ab..0fe8a1105e 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -10,14 +10,15 @@ import uuid from . import eventloop, workflow, util -from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput -from .api import InpaintMode, InpaintParams, FillMode +from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput, SamplingInput +from .api import InpaintMode, InpaintParams, FillMode, ImageInput, CustomWorkflowInput from .localization import translate as _ from .util import clamp, ensure, trim_text, client_logger as log from .settings import ApplyBehavior, settings from .network import NetworkError from .image import Extent, Image, Mask, Bounds, DummyImage -from .client import ClientMessage, ClientEvent, filter_supported_styles, resolve_arch +from .client import ClientMessage, ClientEvent, SharedWorkflow +from .client import filter_supported_styles, resolve_arch from .document import Document, KritaDocument from .layer import Layer, LayerType, RestoreActiveLayer from .pose import Pose @@ -37,6 +38,7 @@ class Workspace(Enum): upscaling = 1 live = 2 animation = 3 + custom = 4 class ProgressKind(Enum): @@ -64,6 +66,7 @@ class Model(QObject, ObservableProperties): fixed_seed = Property(False, persist=True) queue_front = Property(False, persist=True) translation_enabled = Property(True, persist=True) + custom_workflow = Property("", persist=True) inpaint: CustomInpaint upscale: "UpscaleWorkspace" live: "LiveWorkspace" @@ -82,6 +85,7 @@ class Model(QObject, ObservableProperties): fixed_seed_changed = pyqtSignal(bool) queue_front_changed = pyqtSignal(bool) translation_enabled_changed = pyqtSignal(bool) + custom_workflow_changed = pyqtSignal(str) progress_kind_changed = pyqtSignal(ProgressKind) progress_changed = pyqtSignal(float) error_changed = pyqtSignal(str) @@ -353,6 +357,27 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): return None + def generate_custom(self): + try: + bounds = Bounds(0, 0, *self._doc.extent) + workflow = self._connection.workflows[self.custom_workflow] + img_input = ImageInput.from_extent(bounds.extent) + img_input.initial_image = self._get_current_image(bounds) + input = WorkflowInput( + WorkflowKind.custom, + img_input, + sampling=SamplingInput("custom", "custom", 1, 1000, seed=self.seed), + custom_workflow=CustomWorkflowInput(workflow, {}), + ) + job_params = JobParams(bounds, self.custom_workflow) + except Exception as e: + self.report_error(util.log_error(e)) + return + + self.clear_error() + jobs = self.enqueue_jobs(input, JobKind.diffusion, job_params, self.batch_count) + eventloop.run(_report_errors(self, jobs)) + def _get_current_image(self, bounds: Bounds): exclude = None if self.workspace is not Workspace.live: @@ -559,9 +584,9 @@ def apply_generated_result(self, job_id: str, index: int): self.jobs.selection = None self.jobs.notify_used(job_id, index) - def add_control_layer(self, job: Job, result: dict | None): + def add_control_layer(self, job: Job, result: dict | SharedWorkflow | None): assert job.kind is JobKind.control_layer and job.control - if job.control.mode is ControlMode.pose and result is not None: + if job.control.mode is ControlMode.pose and isinstance(result, dict): pose = Pose.from_open_pose_json(result) pose.scale(job.params.bounds.extent) return self.layers.create_vector(job.params.prompt, pose.to_svg()) diff --git a/ai_diffusion/ui/diffusion.py b/ai_diffusion/ui/diffusion.py index 1f218dfb54..79e4a3e47c 100644 --- a/ai_diffusion/ui/diffusion.py +++ b/ai_diffusion/ui/diffusion.py @@ -13,7 +13,7 @@ from ..root import root from ..localization import translate as _ from . import theme -from .generation import GenerationWidget +from .generation import GenerationWidget, CustomWorkflowWidget from .upscale import UpscaleWidget from .live import LiveWidget from .animation import AnimationWidget @@ -211,12 +211,14 @@ def __init__(self): self._upscaling = UpscaleWidget() self._animation = AnimationWidget() self._live = LiveWidget() + self._custom = CustomWorkflowWidget() self._frame = QStackedWidget(self) self._frame.addWidget(self._welcome) self._frame.addWidget(self._generation) self._frame.addWidget(self._upscaling) self._frame.addWidget(self._live) self._frame.addWidget(self._animation) + self._frame.addWidget(self._custom) self.setWidget(self._frame) root.connection.state_changed.connect(self.update_content) @@ -249,3 +251,6 @@ def update_content(self): elif model.workspace is Workspace.animation: self._animation.model = model self._frame.setCurrentWidget(self._animation) + elif model.workspace is Workspace.custom: + self._custom.model = model + self._frame.setCurrentWidget(self._custom) diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index f80a1edb47..bf12b896a0 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -519,6 +519,14 @@ def set_context(self): self._model.inpaint.context = data +def _create_error_label(parent: QWidget) -> QLabel: + label = QLabel(parent) + label.setStyleSheet("font-weight: bold; color: red;") + label.setWordWrap(True) + label.setVisible(False) + return label + + class GenerationWidget(QWidget): _model: Model _model_bindings: list[QMetaObject.Connection | Binding] @@ -597,10 +605,7 @@ def __init__(self): self.progress_bar.setFixedHeight(6) layout.addWidget(self.progress_bar) - self.error_text = QLabel(self) - self.error_text.setStyleSheet("font-weight: bold; color: red;") - self.error_text.setWordWrap(True) - self.error_text.setVisible(False) + self.error_text = _create_error_label(self) layout.addWidget(self.error_text) self.history = HistoryWidget(self) @@ -788,3 +793,60 @@ def update_generate_button(self): True: theme.icon("region-alpha-active"), False: theme.icon("region-alpha"), } + + +class CustomWorkflowWidget(QWidget): + def __init__(self): + super().__init__() + self._model = root.active_model + self._model_bindings: list[QMetaObject.Connection | Binding] = [] + + self._workspace_select = WorkspaceSelectWidget(self) + self._workflow_select = QComboBox(self) + + self._generate_button = GenerateButton(JobKind.diffusion, self) + self._error_text = _create_error_label(self) + + layout = QVBoxLayout() + header_layout = QHBoxLayout() + header_layout.addWidget(self._workspace_select) + header_layout.addWidget(self._workflow_select) + layout.addLayout(header_layout) + layout.addWidget(self._generate_button) + layout.addWidget(self._error_text) + layout.addStretch() + self.setLayout(layout) + + def _update_workflows(self): + workflows = root.connection.workflows + current_items = [ + self._workflow_select.itemText(i) for i in range(self._workflow_select.count()) + ] + for key in workflows: + if key not in current_items: + self._workflow_select.addItem(key, key) + for item in current_items: + if item not in workflows: + self._workflow_select.removeItem(self._workflow_select.findText(item)) + + def _update_current_workflow(self): + pass + + @property + def model(self): + return self._model + + @model.setter + def model(self, model: Model): + if self._model != model: + Binding.disconnect_all(self._model_bindings) + self._model = model + self._model_bindings = [ + bind(model, "workspace", self._workspace_select, "value", Bind.one_way), + bind_combo(model, "custom_workflow", self._workflow_select), + root.connection.workflow_published.connect(self._update_workflows), + model.custom_workflow_changed.connect(self._update_current_workflow), + model.error_changed.connect(self._error_text.setText), + model.has_error_changed.connect(self._error_text.setVisible), + self._generate_button.clicked.connect(model.generate_custom), + ] diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py index e674a6e106..7e96311946 100644 --- a/ai_diffusion/ui/widget.py +++ b/ai_diffusion/ui/widget.py @@ -691,6 +691,7 @@ class WorkspaceSelectWidget(QToolButton): Workspace.upscaling: theme.icon("workspace-upscaling"), Workspace.live: theme.icon("workspace-live"), Workspace.animation: theme.icon("workspace-animation"), + Workspace.custom: theme.icon("workspace-custom"), } _value = Workspace.generation @@ -703,6 +704,7 @@ def __init__(self, parent): menu.addAction(self._create_action(_("Upscale"), Workspace.upscaling)) menu.addAction(self._create_action(_("Live"), Workspace.live)) menu.addAction(self._create_action(_("Animation"), Workspace.animation)) + menu.addAction(self._create_action(_("Graph"), Workspace.custom)) self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) self.setMenu(menu) diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 36f153d170..b5d1cfda96 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -9,7 +9,7 @@ from . import resolution, resources from .api import ControlInput, ImageInput, CheckpointInput, SamplingInput, WorkflowInput, LoraInput from .api import ExtentInput, InpaintMode, InpaintParams, FillMode, ConditioningInput, WorkflowKind -from .api import RegionInput +from .api import RegionInput, CustomWorkflowInput from .image import Bounds, Extent, Image, Mask, Point, multiple_of from .client import ClientModels, ModelDict from .files import FileLibrary, FileFormat @@ -18,7 +18,7 @@ from .resources import ControlMode, Arch, UpscalerName, ResourceKind, ResourceId from .settings import PerformanceSettings from .text import merge_prompt, extract_loras -from .comfy_workflow import ComfyWorkflow, ComfyRunMode, Output +from .comfy_workflow import ComfyWorkflow, ComfyRunMode, Input, Output, ComfyNode from .localization import translate as _ from .settings import settings from .util import ensure, median_or_zero, unique, client_logger as log @@ -1043,6 +1043,33 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds): return w +def expand_custom(w: ComfyWorkflow, input: CustomWorkflowInput, image: Image): + custom = ComfyWorkflow.from_dict(input.workflow) + nodes: dict[int, int] = {} # map old node IDs to new node IDs + outputs: dict[Output, Input] = {} + + def map_input(input): + if isinstance(input, Output): + if mapped := outputs.get(input): + return mapped + else: + return Output(nodes[input.node], input.output) + return input + + for node in custom: + match node.type: + case "ETN_KritaCanvas": + outputs[node.output(0)] = w.load_image(image) + outputs[node.output(1)] = image.width + outputs[node.output(2)] = image.height + case _: + mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()} + mapped = ComfyNode(node.id, node.type, mapped_inputs) + nodes[node.id] = w.copy(mapped).node + + return w + + ################################################################################################### @@ -1258,6 +1285,8 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve bounds=i.inpaint.target_bounds if i.inpaint else None, seed=i.sampling.seed if i.sampling else -1, ) + elif i.kind is WorkflowKind.custom: + return expand_custom(workflow, ensure(i.custom_workflow), i.image) else: raise ValueError(f"Unsupported workflow kind: {i.kind}") diff --git a/tests/test_workflow.py b/tests/test_workflow.py index eef2c5755b..a884f71405 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,17 +1,20 @@ import itertools import pytest import dotenv +import json import os from datetime import datetime from pathlib import Path from typing import Any +from PyQt5.QtCore import Qt from ai_diffusion import workflow -from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput -from ai_diffusion.api import InpaintMode, FillMode, ConditioningInput, RegionInput +from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput, RegionInput +from ai_diffusion.api import InpaintMode, FillMode, ConditioningInput, CustomWorkflowInput from ai_diffusion.client import ClientModels, CheckpointInfo from ai_diffusion.comfy_client import ComfyClient from ai_diffusion.cloud_client import CloudClient +from ai_diffusion.comfy_workflow import ComfyWorkflow, ComfyNode, Output from ai_diffusion.files import FileLibrary, FileCollection, File, FileSource from ai_diffusion.resources import ControlMode from ai_diffusion.settings import PerformanceSettings @@ -538,7 +541,7 @@ async def main(): if not job_id: job_id = await client.enqueue(job) if msg.event is ClientEvent.finished and msg.job_id == job_id: - assert msg.result is not None + assert isinstance(msg.result, dict) result = Pose.from_open_pose_json(msg.result).to_svg() (result_dir / image_name).write_text(result) return @@ -749,6 +752,39 @@ def test_translation(qtapp, client): run_and_save(qtapp, client, job, "test_translation") +def test_initialize_workflow(): + w = ComfyWorkflow.from_dict( + { + "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, + "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, + "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, + } + ) + assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) + assert w.node(1) == ComfyNode(1, "C", {"in": Output(2, 1)}) + assert w.node(2) == ComfyNode(2, "B", {"in": Output(0, 0)}) + + +def test_expand_workflow(): + ext = ComfyWorkflow() + in_img, width, height = ext.add("ETN_KritaCanvas", 3) + scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) + ext.add("ETN_KritaOutput", 1, images=scaled) + + input = CustomWorkflowInput(workflow=ext.root, params={}) + image = Image.create(Extent(4, 4), Qt.GlobalColor.white) + + w = ComfyWorkflow() + w = workflow.expand_custom(w, input, image) + expected = [ + ComfyNode(1, "ETN_LoadImageBase64", {"image": image.to_base64()}), + ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), + ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), + ] + for node in expected: + assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}" + + inpaint_benchmark = { "tori": (InpaintMode.fill, "photo of tori, japanese garden", None), "bruges": (InpaintMode.fill, "photo of a canal in bruges, belgium", None), From 99b13c7313341fe2edcb42b8c06446b3a5d8d92e Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 25 Sep 2024 15:04:48 +0200 Subject: [PATCH 03/28] Common controls for custom workflow UI Workflow import now reorders nodes so dependencies come before they are used --- ai_diffusion/comfy_client.py | 2 +- ai_diffusion/comfy_workflow.py | 20 ++- ai_diffusion/icons/workspace-custom-dark.svg | 145 ++++++++++++++++++ ai_diffusion/icons/workspace-custom-light.svg | 145 ++++++++++++++++++ ai_diffusion/ui/generation.py | 90 ++++++++--- tests/test_workflow.py | 4 +- 6 files changed, 371 insertions(+), 35 deletions(-) create mode 100644 ai_diffusion/icons/workspace-custom-dark.svg create mode 100644 ai_diffusion/icons/workspace-custom-light.svg diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 1c4cc09b33..fc72232ba4 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -230,7 +230,6 @@ async def _listen(self): f"{url}/ws?clientId={self._id}", max_size=2**30, read_limit=2**30, ping_timeout=60 ): try: - await self._subscribe_workflows() await self._listen_websocket(websocket) except websockets_exceptions.ConnectionClosedError as e: log.warning(f"Websocket connection closed: {str(e)}") @@ -264,6 +263,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "status": await self._report(ClientEvent.connected, "") + await self._subscribe_workflows() if msg["type"] == "execution_start": id = msg["data"]["prompt_id"] diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 3fbe2f1d38..5ae0aba2f2 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -50,14 +50,20 @@ def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server def from_dict(existing: dict): w = ComfyWorkflow() node_map: dict[str, str] = {} - for k, v in existing.items(): - node_map[k] = str(w.node_count) - w.root[str(w.node_count)] = v + queue = list(existing.keys()) + while queue: + id = queue.pop(0) + node = existing[id] + edges = [e for e in node["inputs"].values() if isinstance(e, list)] + if any(e[0] not in node_map for e in edges): + queue.append(id) # requeue node if an input is not yet mapped + continue + + for e in edges: + e[0] = node_map[e[0]] + node_map[id] = str(w.node_count) + w.root[str(w.node_count)] = node w.node_count += 1 - for node in w.root.values(): - for e in node["inputs"].values(): - if isinstance(e, list): - e[0] = node_map.get(e[0], e[0]) return w def add_default_values(self, node_name: str, args: dict): diff --git a/ai_diffusion/icons/workspace-custom-dark.svg b/ai_diffusion/icons/workspace-custom-dark.svg new file mode 100644 index 0000000000..4eb2d3526f --- /dev/null +++ b/ai_diffusion/icons/workspace-custom-dark.svg @@ -0,0 +1,145 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/workspace-custom-light.svg b/ai_diffusion/icons/workspace-custom-light.svg new file mode 100644 index 0000000000..e56c8e14c4 --- /dev/null +++ b/ai_diffusion/icons/workspace-custom-light.svg @@ -0,0 +1,145 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index bf12b896a0..fa88054a86 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -519,7 +519,46 @@ def set_context(self): self._model.inpaint.context = data -def _create_error_label(parent: QWidget) -> QLabel: +class ProgressBar(QProgressBar): + def __init__(self, parent: QWidget): + super().__init__(parent) + self._model = root.active_model + self._model_bindings: list[QMetaObject.Connection] = [] + self.setMinimum(0) + self.setMaximum(1000) + self.setTextVisible(False) + self.setFixedHeight(6) + + @property + def model(self): + return self._model + + @model.setter + def model(self, model: Model): + if self._model != model: + Binding.disconnect_all(self._model_bindings) + self._model = model + self._model_bindings = [ + self._model.progress_changed.connect(self._update_progress), + self._model.progress_kind_changed.connect(self._update_progress_kind), + ] + + def _update_progress_kind(self): + palette = self.palette() + if self._model.progress_kind is ProgressKind.upload: + palette.setColor(QPalette.ColorRole.Highlight, QColor(theme.progress_alt)) + self.setPalette(palette) + + def _update_progress(self): + if self._model.progress >= 0: + self.setValue(int(self._model.progress * 1000)) + else: + if self.value() >= 100: + self.reset() + self.setValue(min(99, self.value() + 2)) + + +def _create_error_label(parent: QWidget): label = QLabel(parent) label.setStyleSheet("font-weight: bold; color: red;") label.setWordWrap(True) @@ -598,11 +637,7 @@ def __init__(self): actions_layout.addWidget(self.queue_button) layout.addLayout(actions_layout) - self.progress_bar = QProgressBar(self) - self.progress_bar.setMinimum(0) - self.progress_bar.setMaximum(1000) - self.progress_bar.setTextVisible(False) - self.progress_bar.setFixedHeight(6) + self.progress_bar = ProgressBar(self) layout.addWidget(self.progress_bar) self.error_text = _create_error_label(self) @@ -634,8 +669,6 @@ def model(self, model: Model): model.document.layers.active_changed.connect(self.update_generate_button), model.regions.active_changed.connect(self.update_generate_button), model.region_only_changed.connect(self.update_generate_button), - model.progress_changed.connect(self.update_progress), - model.progress_kind_changed.connect(self.update_progress_kind), model.error_changed.connect(self.error_text.setText), model.has_error_changed.connect(self.error_text.setVisible), self.add_control_button.clicked.connect(model.regions.add_control), @@ -647,24 +680,11 @@ def model(self, model: Model): self.custom_inpaint.model = model self.generate_button.model = model self.queue_button.model = model + self.progress_bar.model = model self.strength_slider.model = model self.history.model_ = model self.update_generate_button() - def update_progress_kind(self): - palette = self.palette() - if self.model.progress_kind is ProgressKind.upload: - palette.setColor(QPalette.ColorRole.Highlight, QColor(theme.progress_alt)) - self.progress_bar.setPalette(palette) - - def update_progress(self): - if self.model.progress >= 0: - self.progress_bar.setValue(int(self.model.progress * 1000)) - else: - if self.progress_bar.value() >= 100: - self.progress_bar.reset() - self.progress_bar.setValue(min(99, self.progress_bar.value() + 2)) - def apply_result(self, item: QListWidgetItem): job_id, index = self.history.item_info(item) self.model.apply_generated_result(job_id, index) @@ -805,18 +825,32 @@ def __init__(self): self._workflow_select = QComboBox(self) self._generate_button = GenerateButton(JobKind.diffusion, self) + self._queue_button = QueueButton(parent=self) + self._queue_button.setFixedHeight(self._generate_button.height() - 2) + self._progress_bar = ProgressBar(self) self._error_text = _create_error_label(self) + self._history = HistoryWidget(self) + self._history.item_activated.connect(self.apply_result) + layout = QVBoxLayout() header_layout = QHBoxLayout() header_layout.addWidget(self._workspace_select) header_layout.addWidget(self._workflow_select) layout.addLayout(header_layout) - layout.addWidget(self._generate_button) + actions_layout = QHBoxLayout() + actions_layout.addWidget(self._generate_button) + actions_layout.addWidget(self._queue_button) + layout.addLayout(actions_layout) + layout.addWidget(self._progress_bar) layout.addWidget(self._error_text) - layout.addStretch() + layout.addWidget(self._history) self.setLayout(layout) + self._update_workflows() + self._update_current_workflow() + root.connection.workflow_published.connect(self._update_workflows) + def _update_workflows(self): workflows = root.connection.workflows current_items = [ @@ -844,9 +878,15 @@ def model(self, model: Model): self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), bind_combo(model, "custom_workflow", self._workflow_select), - root.connection.workflow_published.connect(self._update_workflows), model.custom_workflow_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), model.has_error_changed.connect(self._error_text.setVisible), self._generate_button.clicked.connect(model.generate_custom), ] + self._queue_button.model = model + self._progress_bar.model = model + self._history.model_ = model + + def apply_result(self, item: QListWidgetItem): + job_id, index = self._history.item_info(item) + self.model.apply_generated_result(job_id, index) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index a884f71405..3f08b6cea8 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -761,8 +761,8 @@ def test_initialize_workflow(): } ) assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) - assert w.node(1) == ComfyNode(1, "C", {"in": Output(2, 1)}) - assert w.node(2) == ComfyNode(2, "B", {"in": Output(0, 0)}) + assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) + assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) def test_expand_workflow(): From e1187a11e288cb8db6bf722f1c45f292bf7d76ed Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 25 Sep 2024 16:05:15 +0200 Subject: [PATCH 04/28] Fix handling of jobs which finish without generating any images #1182 --- ai_diffusion/comfy_client.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index fc72232ba4..5c26d1b1cc 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -230,6 +230,7 @@ async def _listen(self): f"{url}/ws?clientId={self._id}", max_size=2**30, read_limit=2**30, ping_timeout=60 ): try: + await self._subscribe_workflows() await self._listen_websocket(websocket) except websockets_exceptions.ConnectionClosedError as e: log.warning(f"Websocket connection closed: {str(e)}") @@ -263,7 +264,6 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "status": await self._report(ClientEvent.connected, "") - await self._subscribe_workflows() if msg["type"] == "execution_start": id = msg["data"]["prompt_id"] @@ -281,12 +281,18 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "executing" and msg["data"]["node"] is None: job_id = msg["data"]["prompt_id"] - if self._clear_job(job_id): + if local_id := self._clear_job(job_id): # Usually we don't get here because finished, interrupted or error is sent first. # But it may happen if the entire execution is cached and no images are sent. if len(images) == 0: images = last_images - await self._report(ClientEvent.finished, job_id, 1, images=images) + if len(images) == 0: + # Still no images. Potential scenario: execution cached, but previous + # generation happened before the client was connected. + err = "No new images were generated because the inputs did not change." + await self._report(ClientEvent.error, local_id, error=err) + else: + await self._report(ClientEvent.finished, local_id, 1, images=images) elif msg["type"] in ("execution_cached", "executing", "progress"): if self._active is not None and progress is not None: @@ -302,7 +308,7 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr pose_json = _extract_pose_json(msg) if job and pose_json: result = pose_json - elif job and _validate_executed_node(msg, len(images)): + elif job is not None and _validate_executed_node(msg, len(images)): self._clear_job(job.remote_id) last_images = images await self._report( @@ -513,9 +519,10 @@ async def _start_job(self, remote_id: str): def _clear_job(self, job_remote_id: str | asyncio.Future | None): if self._active is not None and self._active.remote_id == job_remote_id: + result = self._active.local_id self._active = None - return True - return False + return result + return None def _check_workload(self, sdver: Arch) -> list[MissingResource]: models = self.models @@ -731,8 +738,11 @@ def _validate_executed_node(msg: dict, image_count: int): images = output["images"] if len(images) != image_count: # not critical log.warning(f"Received number of images does not match: {len(images)} != {image_count}") - if len(images) > 0 and "source" in images[0] and images[0]["type"] == "output": + if image_count == 0 or len(images) == 0: + log.warning(f"Received no images (execution cached?)") + return False + if "source" in images[0] and images[0]["type"] == "output": return True except Exception as e: log.warning(f"Error processing message, error={str(e)}, msg={msg}") - return False + return False From c8bce670ab9fcc8ea946b841121bcaacb885a85a Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 26 Sep 2024 11:43:14 +0200 Subject: [PATCH 05/28] Custom graph: support selections as mask --- ai_diffusion/comfy_workflow.py | 28 +++++++++++++++++++++++++--- ai_diffusion/image.py | 4 ++++ ai_diffusion/model.py | 17 ++++++++++++++--- ai_diffusion/ui/generation.py | 3 +++ ai_diffusion/workflow.py | 12 ++++++++++-- tests/test_workflow.py | 17 +++++++++++------ 6 files changed, 67 insertions(+), 14 deletions(-) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 5ae0aba2f2..21917bee6e 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -1,7 +1,8 @@ from __future__ import annotations +from copy import deepcopy from enum import Enum from pathlib import Path -from typing import NamedTuple, Tuple, Literal, overload, Any +from typing import NamedTuple, Tuple, Literal, TypeVar, overload, Any from uuid import uuid4 import json @@ -19,6 +20,7 @@ class Output(NamedTuple): output: int +T = TypeVar("T") Output2 = Tuple[Output, Output] Output3 = Tuple[Output, Output, Output] Output4 = Tuple[Output, Output, Output, Output] @@ -30,6 +32,17 @@ class ComfyNode(NamedTuple): type: str inputs: dict[str, Input] + @overload + def input(self, key: str, default: T) -> T: ... + + @overload + def input(self, key: str, default: None = None) -> Input: ... + + def input(self, key: str, default: T | None = None) -> T | Input: + result = self.inputs[key] + assert default is None or type(result) == type(default) + return result + def output(self, index=0) -> Output: return Output(int(self.id), index) @@ -47,13 +60,13 @@ def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server self._run_mode: ComfyRunMode = run_mode @staticmethod - def from_dict(existing: dict): + def import_graph(existing: dict): w = ComfyWorkflow() node_map: dict[str, str] = {} queue = list(existing.keys()) while queue: id = queue.pop(0) - node = existing[id] + node = deepcopy(existing[id]) edges = [e for e in node["inputs"].values() if isinstance(e, list)] if any(e[0] not in node_map for e in edges): queue.append(id) # requeue node if an input is not yet mapped @@ -66,6 +79,12 @@ def from_dict(existing: dict): w.node_count += 1 return w + @staticmethod + def from_dict(existing: dict): + w = ComfyWorkflow() + w.root = existing + return w + def add_default_values(self, node_name: str, args: dict): if node_inputs := self._nodes_required_inputs.get(node_name, None): for k, v in node_inputs.items(): @@ -138,6 +157,9 @@ def node(self, node_id: int): def copy(self, node: ComfyNode): return self.add(node.type, 1, **node.inputs) + def find(self, type: str): + return (self.node(int(k)) for k, v in self.root.items() if v["class_type"] == type) + def __iter__(self): return iter(self.node(int(k)) for k in self.root.keys()) diff --git a/ai_diffusion/image.py b/ai_diffusion/image.py index fafe527fe7..82e35053df 100644 --- a/ai_diffusion/image.py +++ b/ai_diffusion/image.py @@ -690,6 +690,10 @@ def __init__(self, bounds: Bounds, data: Union[QImage, QByteArray]): ) assert not self.image.isNull() + @staticmethod + def transparent(bounds: Bounds): + return Mask(bounds, QByteArray(bytes(bounds.width * bounds.height))) + @staticmethod def rectangle(bounds: Bounds, feather=0): # Note: for testing only, where Krita selection is not available diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 0fe8a1105e..44f5a1e0ba 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -31,6 +31,7 @@ from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask from .resources import ControlMode from .resolution import compute_bounds, compute_relative_bounds +from .comfy_workflow import ComfyWorkflow class Workspace(Enum): @@ -359,15 +360,25 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): def generate_custom(self): try: + workflow_raw = self._connection.workflows[self.custom_workflow] + wf = ComfyWorkflow.import_graph(workflow_raw) bounds = Bounds(0, 0, *self._doc.extent) - workflow = self._connection.workflows[self.custom_workflow] img_input = ImageInput.from_extent(bounds.extent) img_input.initial_image = self._get_current_image(bounds) + seed = self.seed if self.fixed_seed else workflow.generate_seed() + + if next(wf.find(type="ETN_KritaSelection"), None): + mask, _ = self._doc.create_mask_from_selection() + if mask: + img_input.hires_mask = mask.to_image(bounds.extent) + else: + img_input.hires_mask = Mask.transparent(bounds).to_image() + input = WorkflowInput( WorkflowKind.custom, img_input, - sampling=SamplingInput("custom", "custom", 1, 1000, seed=self.seed), - custom_workflow=CustomWorkflowInput(workflow, {}), + sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), + custom_workflow=CustomWorkflowInput(wf.root, {}), ) job_params = JobParams(bounds, self.custom_workflow) except Exception as e: diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index fa88054a86..307e7464c4 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -217,6 +217,7 @@ def update_selection(self): elif selection: item = self._find(selection) if item is not None and not item.isSelected(): + self.clearSelection() item.setSelected(True) self.update_apply_button() @@ -875,6 +876,8 @@ def model(self, model: Model): if self._model != model: Binding.disconnect_all(self._model_bindings) self._model = model + if not model.custom_workflow: + model.custom_workflow = self._workflow_select.currentData() self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), bind_combo(model, "custom_workflow", self._workflow_select), diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index b5d1cfda96..7a744e1d22 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1043,7 +1043,9 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds): return w -def expand_custom(w: ComfyWorkflow, input: CustomWorkflowInput, image: Image): +def expand_custom( + w: ComfyWorkflow, input: CustomWorkflowInput, images: ImageInput, sampling: SamplingInput +): custom = ComfyWorkflow.from_dict(input.workflow) nodes: dict[int, int] = {} # map old node IDs to new node IDs outputs: dict[Output, Input] = {} @@ -1059,9 +1061,13 @@ def map_input(input): for node in custom: match node.type: case "ETN_KritaCanvas": + image = ensure(images.initial_image) outputs[node.output(0)] = w.load_image(image) outputs[node.output(1)] = image.width outputs[node.output(2)] = image.height + outputs[node.output(3)] = sampling.seed + case "ETN_KritaSelection": + outputs[node.output(0)] = w.load_mask(ensure(images.hires_mask)) case _: mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()} mapped = ComfyNode(node.id, node.type, mapped_inputs) @@ -1286,7 +1292,9 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve seed=i.sampling.seed if i.sampling else -1, ) elif i.kind is WorkflowKind.custom: - return expand_custom(workflow, ensure(i.custom_workflow), i.image) + return expand_custom( + workflow, ensure(i.custom_workflow), ensure(i.images), ensure(i.sampling) + ) else: raise ValueError(f"Unsupported workflow kind: {i.kind}") diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 3f08b6cea8..8cb99b22e2 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -11,6 +11,7 @@ from ai_diffusion import workflow from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput, RegionInput from ai_diffusion.api import InpaintMode, FillMode, ConditioningInput, CustomWorkflowInput +from ai_diffusion.api import SamplingInput, ImageInput from ai_diffusion.client import ClientModels, CheckpointInfo from ai_diffusion.comfy_client import ComfyClient from ai_diffusion.cloud_client import CloudClient @@ -752,8 +753,8 @@ def test_translation(qtapp, client): run_and_save(qtapp, client, job, "test_translation") -def test_initialize_workflow(): - w = ComfyWorkflow.from_dict( +def test_import_workflow(): + w = ComfyWorkflow.import_graph( { "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, @@ -767,19 +768,23 @@ def test_initialize_workflow(): def test_expand_workflow(): ext = ComfyWorkflow() - in_img, width, height = ext.add("ETN_KritaCanvas", 3) + in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) ext.add("ETN_KritaOutput", 1, images=scaled) + ext.add("SeedEater", 1, seed=seed) input = CustomWorkflowInput(workflow=ext.root, params={}) - image = Image.create(Extent(4, 4), Qt.GlobalColor.white) + images = ImageInput.from_extent(Extent(4, 4)) + images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) + sampling = SamplingInput("", "", 1.0, 1000, seed=123) w = ComfyWorkflow() - w = workflow.expand_custom(w, input, image) + w = workflow.expand_custom(w, input, images, sampling) expected = [ - ComfyNode(1, "ETN_LoadImageBase64", {"image": image.to_base64()}), + ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), + ComfyNode(4, "SeedEater", {"seed": 123}), ] for node in expected: assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}" From 0ac1f0ee3187520e010e20e5ffa042372f420205 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 27 Sep 2024 19:18:38 +0200 Subject: [PATCH 06/28] Persist graphs in the model, store exposed graph parameters separately, and generate a UI for those parameters --- ai_diffusion/custom_workflow.py | 91 ++++++++++++++ ai_diffusion/layer.py | 6 +- ai_diffusion/model.py | 19 +-- ai_diffusion/persistence.py | 3 + ai_diffusion/properties.py | 9 +- ai_diffusion/ui/generation.py | 208 ++++++++++++++++++++++++++++---- 6 files changed, 293 insertions(+), 43 deletions(-) create mode 100644 ai_diffusion/custom_workflow.py diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py new file mode 100644 index 0000000000..6c162643b8 --- /dev/null +++ b/ai_diffusion/custom_workflow.py @@ -0,0 +1,91 @@ +from enum import Enum +from typing import Any, NamedTuple +from PyQt5.QtCore import pyqtSignal, QObject + +from .comfy_workflow import ComfyWorkflow +from .connection import Connection +from .properties import Property, ObservableProperties + + +class ParamKind(Enum): + image_layer = 0 + mask_layer = 1 + number_int = 2 + + +class CustomParam(NamedTuple): + kind: ParamKind + name: str + default: Any | None = None + min: int | None = None + max: int | None = None + + +def _gather_params(w: ComfyWorkflow): + for node in w: + match node.type: + case "ETN_KritaImageLayer": + name = node.input("name", "Image") + yield CustomParam(ParamKind.image_layer, name) + case "ETN_KritaMaskLayer": + name = node.input("name", "Mask") + yield CustomParam(ParamKind.mask_layer, name) + case "ETN_IntParameter": + name = node.input("name", "Parameter") + default = node.input("default", 0) + min = node.input("min", -(2**31)) + max = node.input("max", 2**31) + yield CustomParam(ParamKind.number_int, name, default=default, min=min, max=max) + + +class CustomWorkspace(QObject, ObservableProperties): + + graph_id = Property("", persist=True) + graph = Property({}, persist=True, setter="_set_graph") + params = Property({}, persist=True) + + graph_id_changed = pyqtSignal(str) + graph_changed = pyqtSignal(dict) + params_changed = pyqtSignal(dict) + modified = pyqtSignal(QObject, str) + + def __init__(self, connection: Connection): + super().__init__() + self._workflow: ComfyWorkflow | None = None + self._metadata: list[CustomParam] = [] + self._connection = connection + self._connection.workflow_published.connect(self._update_workflow) + + if len(connection.workflows) > 0: + self._update_workflow(next(iter(connection.workflows.keys()))) + + def _update_workflow(self, id: str): + wf = self._connection.workflows[id] + if not self.graph_id: + self.graph_id = id + if self.graph_id == id: + self.graph = wf + + def _set_graph(self, graph: dict): + self._workflow = ComfyWorkflow.import_graph(graph) + self._metadata = list(_gather_params(self._workflow)) + self.params = _coerce(self.params, self._metadata) + self._graph = self._workflow.root + self.graph_changed.emit(self._graph) + + @property + def workflow(self): + return self._workflow + + @property + def metadata(self): + return self._metadata + + +def _coerce(params: dict[str, Any], types: list[CustomParam]): + def use(value, default): + if value is None or not type(value) == type(default): + return default + return value + + return {t.name: use(params.get(t.name), t.default) for t in types} diff --git a/ai_diffusion/layer.py b/ai_diffusion/layer.py index d670e8b4bb..a01771825d 100644 --- a/ai_diffusion/layer.py +++ b/ai_diffusion/layer.py @@ -528,19 +528,19 @@ def update_layer_image(self, layer: Layer, image: Image, bounds: Bounds, keep_al _mask_types = [t.value for t in LayerType if t.is_mask] @property - def all(self): + def all(self) -> list[Layer]: if self._doc is None: return [] return [self.wrap(n) for n in traverse_layers(self._doc.rootNode())] @property - def images(self): + def images(self) -> list[Layer]: if self._doc is None: return [] return [self.wrap(n) for n in traverse_layers(self._doc.rootNode(), self._image_types)] @property - def masks(self): + def masks(self) -> list[Layer]: if self._doc is None: return [] return [self.wrap(n) for n in traverse_layers(self._doc.rootNode(), self._mask_types)] diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 44f5a1e0ba..96cc5dd42f 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -19,6 +19,7 @@ from .image import Extent, Image, Mask, Bounds, DummyImage from .client import ClientMessage, ClientEvent, SharedWorkflow from .client import filter_supported_styles, resolve_arch +from .custom_workflow import CustomWorkspace from .document import Document, KritaDocument from .layer import Layer, LayerType, RestoreActiveLayer from .pose import Pose @@ -53,10 +54,6 @@ class Model(QObject, ObservableProperties): list of finished, currently running and enqueued jobs. """ - _doc: Document - _connection: Connection - _layer: Layer | None = None - workspace = Property(Workspace.generation, setter="set_workspace", persist=True) regions: "RootRegion" style = Property(Styles.list().default, setter="set_style", persist=True) @@ -67,14 +64,8 @@ class Model(QObject, ObservableProperties): fixed_seed = Property(False, persist=True) queue_front = Property(False, persist=True) translation_enabled = Property(True, persist=True) - custom_workflow = Property("", persist=True) - inpaint: CustomInpaint - upscale: "UpscaleWorkspace" - live: "LiveWorkspace" - animation: "AnimationWorkspace" progress_kind = Property(ProgressKind.generation) progress = Property(0.0) - jobs: JobQueue error = Property("") workspace_changed = pyqtSignal(Workspace) @@ -86,7 +77,6 @@ class Model(QObject, ObservableProperties): fixed_seed_changed = pyqtSignal(bool) queue_front_changed = pyqtSignal(bool) translation_enabled_changed = pyqtSignal(bool) - custom_workflow_changed = pyqtSignal(str) progress_kind_changed = pyqtSignal(ProgressKind) progress_changed = pyqtSignal(float) error_changed = pyqtSignal(str) @@ -97,6 +87,7 @@ def __init__(self, document: Document, connection: Connection): super().__init__() self._doc = document self._connection = connection + self._layer: Layer | None = None self.generate_seed() self.jobs = JobQueue() self.regions = RootRegion(self) @@ -104,6 +95,7 @@ def __init__(self, document: Document, connection: Connection): self.upscale = UpscaleWorkspace(self) self.live = LiveWorkspace(self) self.animation = AnimationWorkspace(self) + self.custom = CustomWorkspace(connection) self.jobs.selection_changed.connect(self.update_preview) self.error_changed.connect(lambda: self.has_error_changed.emit(self.has_error)) @@ -360,8 +352,7 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): def generate_custom(self): try: - workflow_raw = self._connection.workflows[self.custom_workflow] - wf = ComfyWorkflow.import_graph(workflow_raw) + wf = ensure(self.custom.workflow) bounds = Bounds(0, 0, *self._doc.extent) img_input = ImageInput.from_extent(bounds.extent) img_input.initial_image = self._get_current_image(bounds) @@ -380,7 +371,7 @@ def generate_custom(self): sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), custom_workflow=CustomWorkflowInput(wf.root, {}), ) - job_params = JobParams(bounds, self.custom_workflow) + job_params = JobParams(bounds, self.custom.graph_id) except Exception as e: self.report_error(util.log_error(e)) return diff --git a/ai_diffusion/persistence.py b/ai_diffusion/persistence.py index ad78d02076..b01dec15e8 100644 --- a/ai_diffusion/persistence.py +++ b/ai_diffusion/persistence.py @@ -132,6 +132,7 @@ def _save(self): state["upscale"] = _serialize(model.upscale) state["live"] = _serialize(model.live) state["animation"] = _serialize(model.animation) + state["custom"] = _serialize(model.custom) state["history"] = [asdict(h) for h in self._history] state["root"] = _serialize(model.regions) state["control"] = [_serialize(c) for c in model.regions.control] @@ -150,6 +151,7 @@ def _load(self, model: Model, state_bytes: bytes): _deserialize(model.upscale, state.get("upscale", {})) _deserialize(model.live, state.get("live", {})) _deserialize(model.animation, state.get("animation", {})) + _deserialize(model.custom, state.get("custom", {})) _deserialize(model.regions, state.get("root", {})) for control_state in state.get("control", []): _deserialize(model.regions.control.emplace(), control_state) @@ -176,6 +178,7 @@ def _track(self, model: Model): model.upscale.modified.connect(self._save) model.live.modified.connect(self._save) model.animation.modified.connect(self._save) + model.custom.modified.connect(self._save) model.jobs.job_finished.connect(self._save_results) model.jobs.job_discarded.connect(self._remove_results) model.jobs.result_discarded.connect(self._remove_image) diff --git a/ai_diffusion/properties.py b/ai_diffusion/properties.py index a4cd94086e..00902e9e55 100644 --- a/ai_diffusion/properties.py +++ b/ai_diffusion/properties.py @@ -1,3 +1,4 @@ +from copy import copy from enum import Enum from typing import Any, NamedTuple, Sequence, TypeVar, Generic @@ -18,7 +19,7 @@ def __init_subclass__(cls, **kwargs): name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Property) } for name, property in properties.items(): - setattr(cls, f"_{name}", property.default_value) + setattr(cls, f"_{name}", _copy_reference_types(property.default_value)) getter, setter = None, None if property.getter is not None: getter = getattr(cls, property.getter) @@ -198,3 +199,9 @@ def deserialize(obj: QObject, data: dict[str, Any], converter=_default_deseriali if not isinstance(value, type(current)): raise TypeError(f"{name} was '{value}', but expected {type(current)}") setattr(obj, name, value) + + +def _copy_reference_types(object): + if isinstance(object, (list, dict)): + return copy(object) + return object diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index 307e7464c4..8bb19dd701 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -1,26 +1,21 @@ from __future__ import annotations +from enum import Enum from textwrap import wrap as wrap_text +from typing import Any, NamedTuple from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor +from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar from PyQt5.QtWidgets import ( - QAction, - QWidget, - QVBoxLayout, - QHBoxLayout, - QPushButton, - QProgressBar, QLabel, QListWidget, QListWidgetItem, QListView, QSizePolicy, QToolButton, - QComboBox, - QCheckBox, - QMenu, - QShortcut, - QMessageBox, + QSlider, + QSpinBox, ) +from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QGridLayout from ..properties import Binding, Bind, bind, bind_combo, bind_toggle from ..image import Bounds, Extent, Image @@ -29,6 +24,7 @@ from ..style import Styles from ..root import root from ..workflow import InpaintMode, FillMode +from ..comfy_workflow import ComfyWorkflow, ComfyNode, Input, Output from ..localization import translate as _ from ..util import ensure, flatten from .widget import WorkspaceSelectWidget, StyleSelectWidget, StrengthWidget, QueueButton @@ -36,6 +32,8 @@ from .region import RegionPromptWidget from . import theme +from ..custom_workflow import CustomParam, ParamKind + class HistoryWidget(QListWidget): _model: Model @@ -816,6 +814,146 @@ def update_generate_button(self): } +class LayerSelect(QComboBox): + value_changed = pyqtSignal() + + def __init__(self, filter: str | None = None, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + self.filter = filter + self.currentIndexChanged.connect(lambda _: self.value_changed.emit()) + + self._update() + root.active_model.layers.changed.connect(self._update) + + def _update(self): + if self.filter is None: + layers = root.active_model.layers.all + elif self.filter == "image": + layers = root.active_model.layers.images + elif self.filter == "mask": + layers = root.active_model.layers.masks + else: + assert False, f"Unknown filter: {self.filter}" + + for l in layers: + if self.findData(l.id) == -1: + self.addItem(l.name, l.id) + i = 0 + while i < self.count(): + if self.itemData(i) not in (l.id for l in layers): + self.removeItem(i) + else: + i += 1 + + @property + def value(self): + return self.currentData() + + @value.setter + def value(self, value: str): + i = self.findData(value) + if i != -1 and i != self.currentIndex(): + self.setCurrentIndex(i) + + +class IntParamWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + assert param.min is not None and param.max is not None and param.default is not None + if param.max - param.min <= 200: + self._widget = QSlider(Qt.Orientation.Horizontal, parent) + self._widget.setRange(param.min, param.max) + self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) + self._widget.valueChanged.connect(self._notify) + self._label = QLabel(self) + self._label.setFixedWidth(40) + self._label.setAlignment(Qt.AlignmentFlag.AlignRight) + layout.addWidget(self._widget) + layout.addWidget(self._label) + self.setLayout(layout) + else: + self._widget = QSpinBox(parent) + self._widget.setRange(param.min, param.max) + self._widget.valueChanged.connect(self._notify) + self._label = None + layout = QHBoxLayout(self) + layout.addWidget(self._widget) + self.setLayout(layout) + + self.value = param.default + + def _notify(self): + if self._label: + self._label.setText(str(self._widget.value())) + self.value_changed.emit() + + @property + def value(self): + return self._widget.value() + + @value.setter + def value(self, value: int): + self._widget.setValue(value) + + +CustomParamWidget = LayerSelect | IntParamWidget + + +def _create_param_widget(param: CustomParam, parent: QWidget): + if param.kind is ParamKind.image_layer: + return LayerSelect("image", parent) + if param.kind is ParamKind.mask_layer: + return LayerSelect("mask", parent) + if param.kind is ParamKind.number_int: + return IntParamWidget(param, parent) + assert False, f"Unknown param kind: {param.kind}" + + +class WorkflowParamsWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, params: list[CustomParam], parent: QWidget | None = None): + super().__init__(parent) + self._widgets: dict[str, CustomParamWidget] = {} + + layout = QGridLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setColumnMinimumWidth(1, 10) + self.setLayout(layout) + + for p in params: + label = QLabel(p.name, self) + widget = _create_param_widget(p, self) + widget.value_changed.connect(self._notify) + row = len(self._widgets) + layout.addWidget(label, row, 0) + layout.addWidget(widget, row, 2) + self._widgets[p.name] = widget + + def _notify(self): + self.value_changed.emit() + + @property + def value(self): + return {name: widget.value for name, widget in self._widgets.items()} + + @value.setter + def value(self, values: dict[str, Any]): + for name, value in values.items(): + if widget := self._widgets.get(name): + if type(widget.value) == type(value): + widget.value = value + + class CustomWorkflowWidget(QWidget): def __init__(self): super().__init__() @@ -824,6 +962,8 @@ def __init__(self): self._workspace_select = WorkspaceSelectWidget(self) self._workflow_select = QComboBox(self) + self._workflow_select.currentIndexChanged.connect(self._change_workflow) + self._params_widget = WorkflowParamsWidget([], self) self._generate_button = GenerateButton(JobKind.diffusion, self) self._queue_button = QueueButton(parent=self) @@ -834,26 +974,28 @@ def __init__(self): self._history = HistoryWidget(self) self._history.item_activated.connect(self.apply_result) - layout = QVBoxLayout() + self._layout = QVBoxLayout() header_layout = QHBoxLayout() header_layout.addWidget(self._workspace_select) header_layout.addWidget(self._workflow_select) - layout.addLayout(header_layout) + self._layout.addLayout(header_layout) + self._layout.addWidget(self._params_widget) actions_layout = QHBoxLayout() actions_layout.addWidget(self._generate_button) actions_layout.addWidget(self._queue_button) - layout.addLayout(actions_layout) - layout.addWidget(self._progress_bar) - layout.addWidget(self._error_text) - layout.addWidget(self._history) - self.setLayout(layout) + self._layout.addLayout(actions_layout) + self._layout.addWidget(self._progress_bar) + self._layout.addWidget(self._error_text) + self._layout.addWidget(self._history) + self.setLayout(self._layout) self._update_workflows() - self._update_current_workflow() root.connection.workflow_published.connect(self._update_workflows) def _update_workflows(self): - workflows = root.connection.workflows + workflows = set(root.connection.workflows) + if self.model.custom.graph_id: + workflows.add(self.model.custom.graph_id) current_items = [ self._workflow_select.itemText(i) for i in range(self._workflow_select.count()) ] @@ -864,8 +1006,25 @@ def _update_workflows(self): if item not in workflows: self._workflow_select.removeItem(self._workflow_select.findText(item)) + self._update_current_workflow() + def _update_current_workflow(self): - pass + if not self.model.custom.workflow: + return + self._params_widget.deleteLater() + self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) + self._params_widget.value = self.model.custom.params + self._layout.insertWidget(1, self._params_widget) + self._params_widget.value_changed.connect(self._change_params) + + def _change_workflow(self): + id = self._workflow_select.currentData() + if graph := root.connection.workflows.get(id): + self.model.custom.graph = graph + self.model.custom.graph_id = id + + def _change_params(self): + self.model.custom.params = self._params_widget.value @property def model(self): @@ -876,12 +1035,10 @@ def model(self, model: Model): if self._model != model: Binding.disconnect_all(self._model_bindings) self._model = model - if not model.custom_workflow: - model.custom_workflow = self._workflow_select.currentData() self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), - bind_combo(model, "custom_workflow", self._workflow_select), - model.custom_workflow_changed.connect(self._update_current_workflow), + bind_combo(model.custom, "graph_id", self._workflow_select, Bind.one_way), + model.custom.graph_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), model.has_error_changed.connect(self._error_text.setVisible), self._generate_button.clicked.connect(model.generate_custom), @@ -889,6 +1046,7 @@ def model(self, model: Model): self._queue_button.model = model self._progress_bar.model = model self._history.model_ = model + self._update_workflows() def apply_result(self, item: QListWidgetItem): job_id, index = self._history.item_info(item) From e4af1409ee4125bc12af6682a47680771b7093c0 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 28 Sep 2024 10:47:44 +0200 Subject: [PATCH 07/28] Allow custom workflows from different sources (file, document, web ui instance) to co-exist --- ai_diffusion/custom_workflow.py | 189 +++++++++++++++++++++++++++----- ai_diffusion/model.py | 9 +- ai_diffusion/root.py | 14 ++- ai_diffusion/ui/generation.py | 32 +----- ai_diffusion/ui/style.py | 2 +- 5 files changed, 185 insertions(+), 61 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 6c162643b8..9fc484058b 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -1,10 +1,133 @@ +import json + from enum import Enum +from dataclasses import dataclass from typing import Any, NamedTuple -from PyQt5.QtCore import pyqtSignal, QObject +from pathlib import Path +from PyQt5.QtCore import Qt, QObject, QAbstractListModel, QSortFilterProxyModel, QModelIndex +from PyQt5.QtCore import pyqtSignal from .comfy_workflow import ComfyWorkflow from .connection import Connection from .properties import Property, ObservableProperties +from .util import user_data_dir, client_logger as log + + +class WorkflowSource(Enum): + document = 0 + remote = 1 + local = 2 + + +@dataclass +class CustomWorkflow: + id: str + source: WorkflowSource + graph: dict + path: Path | None = None + + @property + def name(self): + return self.id.removesuffix(".json") + + @property + def workflow(self): + return ComfyWorkflow.import_graph(self.graph) + + +class WorkflowCollection(QAbstractListModel): + def __init__(self, connection: Connection, folder: Path | None = None): + super().__init__() + self._workflows: list[CustomWorkflow] = [] + + self._folder = folder or user_data_dir / "workflows" + for file in self._folder.glob("*.json"): + try: + self._process_file(file) + except Exception as e: + log.exception(f"Error loading workflow from {file}: {e}") + + self._connection = connection + self._connection.workflow_published.connect(self._process_remote_workflow) + for wf in self._connection.workflows.keys(): + self._process_remote_workflow(wf) + + def _process_remote_workflow(self, id: str): + self._process(CustomWorkflow(id, WorkflowSource.remote, self._connection.workflows[id])) + + def _process_file(self, file: Path): + with file.open("r") as f: + self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f))) + + def _process(self, workflow: CustomWorkflow): + idx = self.find_index(workflow.id) + if idx.isValid(): + self.set_graph(idx, workflow.graph) + else: + self.append(workflow) + + def rowCount(self, parent=QModelIndex()): + return len(self._workflows) + + def data(self, index: QModelIndex, role: int = 0): + if role == Qt.ItemDataRole.DisplayRole: + return self._workflows[index.row()].name + if role == Qt.ItemDataRole.UserRole: + return self._workflows[index.row()].id + + def append(self, item: CustomWorkflow): + end = len(self._workflows) + self.beginInsertRows(QModelIndex(), end, end) + self._workflows.append(item) + self.endInsertRows() + + def set_graph(self, index: QModelIndex, graph: dict): + self._workflows[index.row()].graph = graph + self.dataChanged.emit(index, index) + + def find_index(self, id: str): + for i, wf in enumerate(self._workflows): + if wf.id == id: + return self.index(i) + return QModelIndex() + + def find(self, id: str): + idx = self.find_index(id) + if idx.isValid(): + return self._workflows[idx.row()] + return None + + def get(self, id: str): + result = self.find(id) + if result is None: + raise KeyError(f"Workflow {id} not found") + return result + + def __getitem__(self, index: int): + return self._workflows[index] + + def __len__(self): + return len(self._workflows) + + +class SortedWorkflows(QSortFilterProxyModel): + def __init__(self, workflows: WorkflowCollection): + super().__init__() + self._workflows = workflows + self.setSourceModel(workflows) + self.setSortCaseSensitivity(Qt.CaseSensitivity.CaseInsensitive) + self.sort(0) + + def lessThan(self, left: QModelIndex, right: QModelIndex): + l = self._workflows[left.row()] + r = self._workflows[right.row()] + if l.source is r.source: + return l.name < r.name + return l.source.value < r.source.value + + def __getitem__(self, index: int): + idx = self.mapToSource(self.index(index, 0)).row() + return self._workflows[idx] class ParamKind(Enum): @@ -40,43 +163,59 @@ def _gather_params(w: ComfyWorkflow): class CustomWorkspace(QObject, ObservableProperties): - graph_id = Property("", persist=True) - graph = Property({}, persist=True, setter="_set_graph") + workflow_id = Property("", setter="_set_workflow_id") params = Property({}, persist=True) - graph_id_changed = pyqtSignal(str) - graph_changed = pyqtSignal(dict) + workflow_id_changed = pyqtSignal(str) + graph_changed = pyqtSignal() params_changed = pyqtSignal(dict) modified = pyqtSignal(QObject, str) - def __init__(self, connection: Connection): + def __init__(self, workflows: WorkflowCollection): super().__init__() - self._workflow: ComfyWorkflow | None = None + self._workflows = workflows + self._workflow: CustomWorkflow | None = None + self._graph: ComfyWorkflow | None = None self._metadata: list[CustomParam] = [] - self._connection = connection - self._connection.workflow_published.connect(self._update_workflow) - - if len(connection.workflows) > 0: - self._update_workflow(next(iter(connection.workflows.keys()))) - def _update_workflow(self, id: str): - wf = self._connection.workflows[id] - if not self.graph_id: - self.graph_id = id - if self.graph_id == id: - self.graph = wf - - def _set_graph(self, graph: dict): - self._workflow = ComfyWorkflow.import_graph(graph) - self._metadata = list(_gather_params(self._workflow)) - self.params = _coerce(self.params, self._metadata) - self._graph = self._workflow.root - self.graph_changed.emit(self._graph) + workflows.dataChanged.connect(self._update_workflow) + workflows.rowsInserted.connect(self._set_default_workflow) + self._set_default_workflow() + + def _set_default_workflow(self): + if not self.workflow_id and len(self._workflows) > 0: + self.workflow_id = self._workflows[0].id + + def _update_workflow(self, idx: QModelIndex, _: QModelIndex): + wf = self._workflows[idx.row()] + if wf.id == self._workflow_id: + self._workflow = wf + self._graph = self._workflow.workflow + self._metadata = list(_gather_params(self._graph)) + self.params = _coerce(self.params, self._metadata) + self.graph_changed.emit() + + def _set_workflow_id(self, id: str): + if self._workflow_id == id: + return + self._workflow_id = id + self.workflow_id_changed.emit(id) + self.modified.emit(self, "workflow_id") + self._update_workflow(self._workflows.find_index(id), QModelIndex()) + + def set_graph(self, id: str, graph: dict): + if self._workflows.find(id) is None: + self._workflows.append(CustomWorkflow(id, WorkflowSource.document, graph)) + self.workflow_id = id @property def workflow(self): return self._workflow + @property + def graph(self): + return self._graph + @property def metadata(self): return self._metadata diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 96cc5dd42f..4ecec0a6a7 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -19,7 +19,7 @@ from .image import Extent, Image, Mask, Bounds, DummyImage from .client import ClientMessage, ClientEvent, SharedWorkflow from .client import filter_supported_styles, resolve_arch -from .custom_workflow import CustomWorkspace +from .custom_workflow import CustomWorkspace, WorkflowCollection from .document import Document, KritaDocument from .layer import Layer, LayerType, RestoreActiveLayer from .pose import Pose @@ -32,7 +32,6 @@ from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask from .resources import ControlMode from .resolution import compute_bounds, compute_relative_bounds -from .comfy_workflow import ComfyWorkflow class Workspace(Enum): @@ -83,7 +82,7 @@ class Model(QObject, ObservableProperties): has_error_changed = pyqtSignal(bool) modified = pyqtSignal(QObject, str) - def __init__(self, document: Document, connection: Connection): + def __init__(self, document: Document, connection: Connection, workflows: WorkflowCollection): super().__init__() self._doc = document self._connection = connection @@ -95,7 +94,7 @@ def __init__(self, document: Document, connection: Connection): self.upscale = UpscaleWorkspace(self) self.live = LiveWorkspace(self) self.animation = AnimationWorkspace(self) - self.custom = CustomWorkspace(connection) + self.custom = CustomWorkspace(workflows) self.jobs.selection_changed.connect(self.update_preview) self.error_changed.connect(lambda: self.has_error_changed.emit(self.has_error)) @@ -352,7 +351,7 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): def generate_custom(self): try: - wf = ensure(self.custom.workflow) + wf = ensure(self.custom.graph) bounds = Bounds(0, 0, *self._doc.extent) img_input = ImageInput.from_extent(bounds.extent) img_input.initial_image = self._get_current_image(bounds) diff --git a/ai_diffusion/root.py b/ai_diffusion/root.py index ddecaf2e0d..da069a6ea6 100644 --- a/ai_diffusion/root.py +++ b/ai_diffusion/root.py @@ -4,10 +4,11 @@ from .connection import Connection, ConnectionState from .client import ClientMessage +from .custom_workflow import WorkflowCollection from .server import Server, ServerState from .document import Document, KritaDocument from .model import Model -from .files import FileLibrary, File, FileSource +from .files import FileFormat, FileLibrary, File, FileSource from .persistence import ModelSync, RecentlyUsedSync, import_prompt_from_file from .updates import AutoUpdate from .ui.theme import checkpoint_icon @@ -37,6 +38,7 @@ def init(self): self._server = Server(settings.server_path) self._connection = Connection() self._files = FileLibrary.load() + self._workflows = WorkflowCollection(self._connection) self._models = [] self._recent = RecentlyUsedSync.from_settings() self._auto_update = AutoUpdate() @@ -50,7 +52,7 @@ def prune_models(self): self._models = [m for m in self._models if m.model.document.is_valid] def create_model(self, doc: KritaDocument): - model = Model(doc, self._connection) + model = Model(doc, self._connection, self._workflows) self._recent.track(model) persistence_sync = ModelSync(model) import_prompt_from_file(model) @@ -82,6 +84,10 @@ def server(self): def files(self) -> FileLibrary: return self._files + @property + def workflows(self) -> WorkflowCollection: + return self._workflows + @property def auto_update(self) -> AutoUpdate: return self._auto_update @@ -90,7 +96,7 @@ def auto_update(self) -> AutoUpdate: def active_model(self): if model := self.model_for_active_document(): return model - return Model(Document(), self._connection) + return Model(Document(), self._connection, self._workflows) async def autostart(self, signal_server_change: Callable): connection = self._connection @@ -148,7 +154,7 @@ def _update_files(self): ] self._files.checkpoints.update(checkpoints, FileSource.remote) - loras = [File.remote(lora) for lora in client.models.loras] + loras = [File.remote(lora, FileFormat.lora) for lora in client.models.loras] self._files.loras.update(loras, FileSource.remote) diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index 8bb19dd701..240e9736e2 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -32,7 +32,7 @@ from .region import RegionPromptWidget from . import theme -from ..custom_workflow import CustomParam, ParamKind +from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows class HistoryWidget(QListWidget): @@ -962,7 +962,9 @@ def __init__(self): self._workspace_select = WorkspaceSelectWidget(self) self._workflow_select = QComboBox(self) + self._workflow_select.setModel(SortedWorkflows(root.workflows)) self._workflow_select.currentIndexChanged.connect(self._change_workflow) + self._params_widget = WorkflowParamsWidget([], self) self._generate_button = GenerateButton(JobKind.diffusion, self) @@ -989,25 +991,6 @@ def __init__(self): self._layout.addWidget(self._history) self.setLayout(self._layout) - self._update_workflows() - root.connection.workflow_published.connect(self._update_workflows) - - def _update_workflows(self): - workflows = set(root.connection.workflows) - if self.model.custom.graph_id: - workflows.add(self.model.custom.graph_id) - current_items = [ - self._workflow_select.itemText(i) for i in range(self._workflow_select.count()) - ] - for key in workflows: - if key not in current_items: - self._workflow_select.addItem(key, key) - for item in current_items: - if item not in workflows: - self._workflow_select.removeItem(self._workflow_select.findText(item)) - - self._update_current_workflow() - def _update_current_workflow(self): if not self.model.custom.workflow: return @@ -1018,10 +1001,7 @@ def _update_current_workflow(self): self._params_widget.value_changed.connect(self._change_params) def _change_workflow(self): - id = self._workflow_select.currentData() - if graph := root.connection.workflows.get(id): - self.model.custom.graph = graph - self.model.custom.graph_id = id + self.model.custom.workflow_id = self._workflow_select.currentData() def _change_params(self): self.model.custom.params = self._params_widget.value @@ -1037,7 +1017,7 @@ def model(self, model: Model): self._model = model self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), - bind_combo(model.custom, "graph_id", self._workflow_select, Bind.one_way), + bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way), model.custom.graph_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), model.has_error_changed.connect(self._error_text.setVisible), @@ -1046,7 +1026,7 @@ def model(self, model: Model): self._queue_button.model = model self._progress_bar.model = model self._history.model_ = model - self._update_workflows() + self._update_current_workflow() def apply_result(self, item: QListWidgetItem): job_id, index = self._history.item_info(item) diff --git a/ai_diffusion/ui/style.py b/ai_diffusion/ui/style.py index 715ac9b8b9..bd25649cf4 100644 --- a/ai_diffusion/ui/style.py +++ b/ai_diffusion/ui/style.py @@ -389,7 +389,7 @@ def _upload_lora(self): if client.max_upload_size and path.stat().st_size > client.max_upload_size: _show_file_too_large_warning(client.max_upload_size, self) return - file = File.local(path, compute_hash=True) + file = File.local(path, FileFormat.lora, compute_hash=True) root.files.loras.add(file) self._add_item(file) From 06631f7cb30d037f7798cb7810d26e68a6c3319e Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 30 Sep 2024 20:07:33 +0200 Subject: [PATCH 08/28] Move & add some tests for custom workflows --- tests/test_custom_workflow.py | 150 ++++++++++++++++++++++++++++++++++ tests/test_workflow.py | 37 --------- 2 files changed, 150 insertions(+), 37 deletions(-) create mode 100644 tests/test_custom_workflow.py diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py new file mode 100644 index 0000000000..95280ca654 --- /dev/null +++ b/tests/test_custom_workflow.py @@ -0,0 +1,150 @@ +import json +from pathlib import Path +from PyQt5.QtCore import Qt + +from ai_diffusion.api import CustomWorkflowInput, ImageInput, SamplingInput +from ai_diffusion.connection import Connection +from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output +from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection +from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace, CustomParam +from ai_diffusion.image import Image, Extent +from ai_diffusion import workflow + + +def test_workflow_collection(tmp_path: Path): + file1 = tmp_path / "file1.json" + file1.write_text('{"file": 1}') + file2 = tmp_path / "file2.json" + file2.write_text('{"file": 2}') + + connection = Connection() + connection_workflows = { + "connection1": {"connection": 1}, + } + connection._workflows = connection_workflows + + collection = WorkflowCollection(connection, tmp_path) + assert len(collection) == 3 + assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 1}) + assert collection.find("file2") == CustomWorkflow("file2", WorkflowSource.local, {"file": 2}) + assert collection.find("connection1") == CustomWorkflow( + "connection1", WorkflowSource.remote, {"connection": 1} + ) + + events = [] + + def on_begin_insert(index, first, last): + events.append(("begin_insert", first)) + + def on_end_insert(): + events.append("end_insert") + + def on_data_changed(start, end): + events.append(("data_changed", start.row())) + + collection.rowsAboutToBeInserted.connect(on_begin_insert) + collection.rowsInserted.connect(on_end_insert) + collection.dataChanged.connect(on_data_changed) + + connection_workflows["connection2"] = {"connection": 2} + connection.workflow_published.emit("connection2") + + assert len(collection) == 4 + assert collection.find("connection2") == CustomWorkflow( + "connection2", WorkflowSource.remote, {"connection": 2} + ) + + collection.set_graph(collection.index(0), {"file": 3}) + assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 3}) + + assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] + + collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1})) + + sorted = SortedWorkflows(collection) + assert sorted[0].source is WorkflowSource.document + assert sorted[1].source is WorkflowSource.remote + assert sorted[2].source is WorkflowSource.remote + assert sorted[3].name == "file1" + assert sorted[4].name == "file2" + + +def test_workspace(): + connection = Connection() + connection_workflows = { + "connection1": { + "1": { + "class_type": "ETN_IntParameter", + "inputs": {"name": "param1", "default": 42, "min": 5, "max": 95}, + } + } + } + connection._workflows = connection_workflows + workflows = WorkflowCollection(connection) + + workspace = CustomWorkspace(workflows) + assert workspace.workflow_id == "connection1" + assert workspace.workflow and workspace.workflow.id == "connection1" + assert workspace.graph and workspace.graph.node(0).type == "ETN_IntParameter" + assert workspace.metadata[0].name == "param1" + assert workspace.params == {"param1": 42} + + doc_graph = { + "1": { + "class_type": "ETN_IntParameter", + "inputs": {"name": "param2", "default": 23, "min": 9, "max": 35}, + } + } + workspace.set_graph("doc1", doc_graph) + assert workspace.workflow_id == "doc1" + assert workspace.workflow and workspace.workflow.source is WorkflowSource.document + assert workspace.graph and workspace.graph.node(0).type == "ETN_IntParameter" + assert workspace.metadata[0].name == "param2" + assert workspace.params == {"param2": 23} + + doc_graph["1"]["inputs"]["default"] = 24 + doc_graph["2"] = { + "class_type": "ETN_IntParameter", + "inputs": {"name": "param3", "default": 7, "min": 0, "max": 10}, + } + workflows.set_graph(workflows.index(1), doc_graph) + assert workspace.metadata[0].default == 24 + assert workspace.metadata[1].name == "param3" + assert workspace.params == {"param2": 23, "param3": 7} + + +def test_import_workflow(): + w = ComfyWorkflow.import_graph( + { + "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, + "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, + "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, + } + ) + assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) + assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) + assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) + + +def test_expand_workflow(): + ext = ComfyWorkflow() + in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) + scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) + ext.add("ETN_KritaOutput", 1, images=scaled) + ext.add("SeedEater", 1, seed=seed) + + input = CustomWorkflowInput(workflow=ext.root, params={}) + images = ImageInput.from_extent(Extent(4, 4)) + images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) + sampling = SamplingInput("", "", 1.0, 1000, seed=123) + + w = ComfyWorkflow() + w = workflow.expand_custom(w, input, images, sampling) + expected = [ + ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), + ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), + ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), + ComfyNode(4, "SeedEater", {"seed": 123}), + ] + for node in expected: + assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}" diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 8cb99b22e2..4820465beb 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -753,43 +753,6 @@ def test_translation(qtapp, client): run_and_save(qtapp, client, job, "test_translation") -def test_import_workflow(): - w = ComfyWorkflow.import_graph( - { - "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, - "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, - "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, - } - ) - assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) - assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) - assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) - - -def test_expand_workflow(): - ext = ComfyWorkflow() - in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) - scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) - ext.add("ETN_KritaOutput", 1, images=scaled) - ext.add("SeedEater", 1, seed=seed) - - input = CustomWorkflowInput(workflow=ext.root, params={}) - images = ImageInput.from_extent(Extent(4, 4)) - images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) - sampling = SamplingInput("", "", 1.0, 1000, seed=123) - - w = ComfyWorkflow() - w = workflow.expand_custom(w, input, images, sampling) - expected = [ - ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), - ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), - ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), - ComfyNode(4, "SeedEater", {"seed": 123}), - ] - for node in expected: - assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}" - - inpaint_benchmark = { "tori": (InpaintMode.fill, "photo of tori, japanese garden", None), "bruges": (InpaintMode.fill, "photo of a canal in bruges, belgium", None), From 0a0f465abb3187a85e0e66c5b66219f80fa69791 Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 1 Oct 2024 13:43:43 +0200 Subject: [PATCH 09/28] Save & import custom workflows with files --- ai_diffusion/custom_workflow.py | 32 +++++++++ ai_diffusion/ui/generation.py | 119 ++++++++++++++++++++++++++++++-- tests/test_custom_workflow.py | 58 ++++++++++++---- 3 files changed, 193 insertions(+), 16 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 9fc484058b..963ff0d44b 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -85,6 +85,31 @@ def set_graph(self, index: QModelIndex, graph: dict): self._workflows[index.row()].graph = graph self.dataChanged.emit(index, index) + def save_as(self, id: str, graph: dict): + if self.find(id) is not None: + suffix = 1 + while self.find(f"{id} ({suffix})"): + suffix += 1 + id = f"{id} ({suffix})" + + self._folder.mkdir(exist_ok=True) + path = self._folder / f"{id}.json" + path.write_text(json.dumps(graph, indent=2)) + self.append(CustomWorkflow(id, WorkflowSource.local, graph, path)) + return id + + def import_file(self, filepath: Path): + try: + with filepath.open("r") as f: + graph = json.load(f) + try: + ComfyWorkflow.import_graph(graph) + except Exception as e: + raise RuntimeError(f"This is not a supported workflow file ({e})") + return self.save_as(filepath.stem, graph) + except Exception as e: + raise RuntimeError(f"Error importing workflow from {filepath}: {e}") + def find_index(self, id: str): for i, wf in enumerate(self._workflows): if wf.id == id: @@ -208,6 +233,13 @@ def set_graph(self, id: str, graph: dict): self._workflows.append(CustomWorkflow(id, WorkflowSource.document, graph)) self.workflow_id = id + def import_file(self, filepath: Path): + self.workflow_id = self._workflows.import_file(filepath) + + def save_as(self, id: str): + assert self._graph, "Save as: no workflow selected" + self.workflow_id = self._workflows.save_as(id, self._graph.root) + @property def workflow(self): return self._workflow diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index 240e9736e2..e7011430b1 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -1,9 +1,11 @@ from __future__ import annotations from enum import Enum +from functools import wraps +from pathlib import Path from textwrap import wrap as wrap_text from typing import Any, NamedTuple -from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal -from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor +from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal, QUrl +from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor, QDesktopServices from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar from PyQt5.QtWidgets import ( QLabel, @@ -14,6 +16,8 @@ QToolButton, QSlider, QSpinBox, + QFileDialog, + QLineEdit, ) from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QGridLayout @@ -954,17 +958,62 @@ def value(self, values: dict[str, Any]): widget.value = value +def popup_on_error(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except Exception as e: + QMessageBox.critical(self, _("Error"), str(e)) + + return wrapper + + class CustomWorkflowWidget(QWidget): def __init__(self): super().__init__() + self._model = root.active_model self._model_bindings: list[QMetaObject.Connection | Binding] = [] self._workspace_select = WorkspaceSelectWidget(self) - self._workflow_select = QComboBox(self) + + self._workflow_select_widgets = QWidget(self) + + self._workflow_select = QComboBox(self._workflow_select_widgets) self._workflow_select.setModel(SortedWorkflows(root.workflows)) self._workflow_select.currentIndexChanged.connect(self._change_workflow) + self._import_workflow_button = QToolButton(self._workflow_select_widgets) + self._import_workflow_button.setText("I") + self._import_workflow_button.setToolTip(_("Import workflow from file")) + self._import_workflow_button.clicked.connect(self._import_workflow) + + self._save_workflow_button = QToolButton(self._workflow_select_widgets) + self._save_workflow_button.setText("S") + self._save_workflow_button.setToolTip(_("Save workflow to file")) + self._save_workflow_button.clicked.connect(self._save_workflow) + + self._open_webui_button = QToolButton(self._workflow_select_widgets) + self._open_webui_button.setText("W") + self._open_webui_button.setToolTip(_("Open Web UI to create custom workflows")) + self._open_webui_button.clicked.connect(self._open_webui) + + self._workflow_edit_widgets = QWidget(self) + self._workflow_edit_widgets.setVisible(False) + + self._workflow_name_edit = QLineEdit(self._workflow_edit_widgets) + self._workflow_name_edit.textEdited.connect(self._edit_name) + self._workflow_name_edit.returnPressed.connect(self._accept_name) + + self._accept_name_button = QToolButton(self._workflow_edit_widgets) + self._accept_name_button.setText("✔") + self._accept_name_button.clicked.connect(self._accept_name) + + self._cancel_name_button = QToolButton(self._workflow_edit_widgets) + self._cancel_name_button.setText("✘") + self._cancel_name_button.clicked.connect(self._cancel_name) + self._params_widget = WorkflowParamsWidget([], self) self._generate_button = GenerateButton(JobKind.diffusion, self) @@ -977,9 +1026,23 @@ def __init__(self): self._history.item_activated.connect(self.apply_result) self._layout = QVBoxLayout() + select_layout = QHBoxLayout() + select_layout.setContentsMargins(0, 0, 0, 0) + select_layout.addWidget(self._workflow_select) + select_layout.addWidget(self._import_workflow_button) + select_layout.addWidget(self._save_workflow_button) + select_layout.addWidget(self._open_webui_button) + self._workflow_select_widgets.setLayout(select_layout) + edit_layout = QHBoxLayout() + edit_layout.setContentsMargins(0, 0, 0, 0) + edit_layout.addWidget(self._workflow_name_edit) + edit_layout.addWidget(self._accept_name_button) + edit_layout.addWidget(self._cancel_name_button) + self._workflow_edit_widgets.setLayout(edit_layout) header_layout = QHBoxLayout() header_layout.addWidget(self._workspace_select) - header_layout.addWidget(self._workflow_select) + header_layout.addWidget(self._workflow_select_widgets) + header_layout.addWidget(self._workflow_edit_widgets) self._layout.addLayout(header_layout) self._layout.addWidget(self._params_widget) actions_layout = QHBoxLayout() @@ -993,7 +1056,10 @@ def __init__(self): def _update_current_workflow(self): if not self.model.custom.workflow: + self._save_workflow_button.setEnabled(False) return + self._save_workflow_button.setEnabled(True) + self._params_widget.deleteLater() self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) self._params_widget.value = self.model.custom.params @@ -1018,6 +1084,7 @@ def model(self, model: Model): self._model_bindings = [ bind(model, "workspace", self._workspace_select, "value", Bind.one_way), bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way), + model.workspace_changed.connect(self._cancel_name), model.custom.graph_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), model.has_error_changed.connect(self._error_text.setVisible), @@ -1031,3 +1098,47 @@ def model(self, model: Model): def apply_result(self, item: QListWidgetItem): job_id, index = self._history.item_info(item) self.model.apply_generated_result(job_id, index) + + @popup_on_error + def _import_workflow(self, *args): + filename, __ = QFileDialog.getOpenFileName( + self, + _("Import Workflow"), + str(Path.home()), + "Workflow Files (*.json);;All Files (*)", + ) + if filename: + self.model.custom.import_file(Path(filename)) + + def _save_workflow(self): + self.is_edit_mode = True + + def _open_webui(self): + if client := root.connection.client_if_connected: + QDesktopServices.openUrl(QUrl(client.url)) + + @property + def is_edit_mode(self): + return self._workflow_edit_widgets.isVisible() + + @is_edit_mode.setter + def is_edit_mode(self, value: bool): + if value == self.is_edit_mode: + return + self._workflow_select_widgets.setVisible(not value) + self._workflow_edit_widgets.setVisible(value) + if value: + self._workflow_name_edit.setText(self.model.custom.workflow_id) + self._workflow_name_edit.selectAll() + self._workflow_name_edit.setFocus() + + def _edit_name(self): + self._accept_name_button.setEnabled(self._workflow_name_edit.text().strip() != "") + + @popup_on_error + def _accept_name(self, *args): + self.model.custom.save_as(self._workflow_name_edit.text()) + self.is_edit_mode = False + + def _cancel_name(self): + self.is_edit_mode = False diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 95280ca654..689aaa8386 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -1,4 +1,5 @@ import json +import pytest from pathlib import Path from PyQt5.QtCore import Qt @@ -11,7 +12,7 @@ from ai_diffusion import workflow -def test_workflow_collection(tmp_path: Path): +def test_collection(tmp_path: Path): file1 = tmp_path / "file1.json" file1.write_text('{"file": 1}') file2 = tmp_path / "file2.json" @@ -69,16 +70,49 @@ def on_data_changed(start, end): assert sorted[4].name == "file2" -def test_workspace(): - connection = Connection() - connection_workflows = { - "connection1": { - "1": { - "class_type": "ETN_IntParameter", - "inputs": {"name": "param1", "default": 42, "min": 5, "max": 95}, - } +def make_dummy_graph(n: int = 42): + return { + "1": { + "class_type": "ETN_IntParameter", + "inputs": {"name": "param1", "default": n, "min": 5, "max": 95}, } } + + +def test_files(tmp_path: Path): + collection_folder = tmp_path / "workflows" + + collection = WorkflowCollection(Connection(), collection_folder) + assert len(collection) == 0 + + file1 = tmp_path / "file1.json" + file1.write_text(json.dumps(make_dummy_graph())) + + collection.import_file(file1) + assert collection.find("file1") is not None + + collection.import_file(file1) + assert collection.find("file1 (1)") is not None + + collection.save_as("file1", {"file": 2}) + assert collection.find("file1 (2)") is not None + + files = [ + collection_folder / "file1.json", + collection_folder / "file1 (1).json", + collection_folder / "file1 (2).json", + ] + assert all(f.exists() for f in files) + + bad_file = tmp_path / "bad.json" + bad_file.write_text("bad json") + with pytest.raises(RuntimeError): + collection.import_file(bad_file) + + +def test_workspace(): + connection = Connection() + connection_workflows = {"connection1": make_dummy_graph(42)} connection._workflows = connection_workflows workflows = WorkflowCollection(connection) @@ -92,7 +126,7 @@ def test_workspace(): doc_graph = { "1": { "class_type": "ETN_IntParameter", - "inputs": {"name": "param2", "default": 23, "min": 9, "max": 35}, + "inputs": {"name": "param2", "default": 23, "min": 5, "max": 95}, } } workspace.set_graph("doc1", doc_graph) @@ -113,7 +147,7 @@ def test_workspace(): assert workspace.params == {"param2": 23, "param3": 7} -def test_import_workflow(): +def test_import(): w = ComfyWorkflow.import_graph( { "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, @@ -126,7 +160,7 @@ def test_import_workflow(): assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) -def test_expand_workflow(): +def test_expand(): ext = ComfyWorkflow() in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) From 7df76e0d437e7c257c7e501bc07a7bf5dd53665c Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 2 Oct 2024 00:30:19 +0200 Subject: [PATCH 10/28] Custom workflow icons --- ai_diffusion/custom_workflow.py | 13 +++ ai_diffusion/icons/comfyui-dark.svg | 96 +++++++++++++++++++++ ai_diffusion/icons/comfyui-light.svg | 96 +++++++++++++++++++++ ai_diffusion/icons/file-json-dark.svg | 96 +++++++++++++++++++++ ai_diffusion/icons/file-json-light.svg | 96 +++++++++++++++++++++ ai_diffusion/icons/file-kra-dark.svg | 92 ++++++++++++++++++++ ai_diffusion/icons/file-kra-light.svg | 92 ++++++++++++++++++++ ai_diffusion/icons/import-dark.svg | 66 ++++++++++++++ ai_diffusion/icons/import-light.svg | 66 ++++++++++++++ ai_diffusion/icons/save-dark.svg | 90 +++++++++++++++++++ ai_diffusion/icons/save-light.svg | 90 +++++++++++++++++++ ai_diffusion/icons/web-connection-dark.svg | 87 +++++++++++++++++++ ai_diffusion/icons/web-connection-light.svg | 87 +++++++++++++++++++ ai_diffusion/ui/generation.py | 64 ++++++++------ 14 files changed, 1106 insertions(+), 25 deletions(-) create mode 100644 ai_diffusion/icons/comfyui-dark.svg create mode 100644 ai_diffusion/icons/comfyui-light.svg create mode 100644 ai_diffusion/icons/file-json-dark.svg create mode 100644 ai_diffusion/icons/file-json-light.svg create mode 100644 ai_diffusion/icons/file-kra-dark.svg create mode 100644 ai_diffusion/icons/file-kra-light.svg create mode 100644 ai_diffusion/icons/import-dark.svg create mode 100644 ai_diffusion/icons/import-light.svg create mode 100644 ai_diffusion/icons/save-dark.svg create mode 100644 ai_diffusion/icons/save-light.svg create mode 100644 ai_diffusion/icons/web-connection-dark.svg create mode 100644 ai_diffusion/icons/web-connection-light.svg diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 963ff0d44b..f090c05a73 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -11,6 +11,7 @@ from .connection import Connection from .properties import Property, ObservableProperties from .util import user_data_dir, client_logger as log +from .ui import theme class WorkflowSource(Enum): @@ -36,6 +37,11 @@ def workflow(self): class WorkflowCollection(QAbstractListModel): + + _icon_local = theme.icon("file-json") + _icon_remote = theme.icon("web-connection") + _icon_document = theme.icon("file-kra") + def __init__(self, connection: Connection, folder: Path | None = None): super().__init__() self._workflows: list[CustomWorkflow] = [] @@ -74,6 +80,13 @@ def data(self, index: QModelIndex, role: int = 0): return self._workflows[index.row()].name if role == Qt.ItemDataRole.UserRole: return self._workflows[index.row()].id + if role == Qt.ItemDataRole.DecorationRole: + source = self._workflows[index.row()].source + if source is WorkflowSource.document: + return self._icon_document + if source is WorkflowSource.remote: + return self._icon_remote + return self._icon_local def append(self, item: CustomWorkflow): end = len(self._workflows) diff --git a/ai_diffusion/icons/comfyui-dark.svg b/ai_diffusion/icons/comfyui-dark.svg new file mode 100644 index 0000000000..c8c916dc29 --- /dev/null +++ b/ai_diffusion/icons/comfyui-dark.svg @@ -0,0 +1,96 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/ai_diffusion/icons/comfyui-light.svg b/ai_diffusion/icons/comfyui-light.svg new file mode 100644 index 0000000000..03d88d1696 --- /dev/null +++ b/ai_diffusion/icons/comfyui-light.svg @@ -0,0 +1,96 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + + + + diff --git a/ai_diffusion/icons/file-json-dark.svg b/ai_diffusion/icons/file-json-dark.svg new file mode 100644 index 0000000000..20143f5e32 --- /dev/null +++ b/ai_diffusion/icons/file-json-dark.svg @@ -0,0 +1,96 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/file-json-light.svg b/ai_diffusion/icons/file-json-light.svg new file mode 100644 index 0000000000..7d726a8249 --- /dev/null +++ b/ai_diffusion/icons/file-json-light.svg @@ -0,0 +1,96 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/file-kra-dark.svg b/ai_diffusion/icons/file-kra-dark.svg new file mode 100644 index 0000000000..f2fca41442 --- /dev/null +++ b/ai_diffusion/icons/file-kra-dark.svg @@ -0,0 +1,92 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/file-kra-light.svg b/ai_diffusion/icons/file-kra-light.svg new file mode 100644 index 0000000000..99ff41d617 --- /dev/null +++ b/ai_diffusion/icons/file-kra-light.svg @@ -0,0 +1,92 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/import-dark.svg b/ai_diffusion/icons/import-dark.svg new file mode 100644 index 0000000000..5dd0ce7eac --- /dev/null +++ b/ai_diffusion/icons/import-dark.svg @@ -0,0 +1,66 @@ + + + + + + + + image/svg+xml + + + + + + + + + + diff --git a/ai_diffusion/icons/import-light.svg b/ai_diffusion/icons/import-light.svg new file mode 100644 index 0000000000..4a19b201df --- /dev/null +++ b/ai_diffusion/icons/import-light.svg @@ -0,0 +1,66 @@ + + + + + + + + image/svg+xml + + + + + + + + + + diff --git a/ai_diffusion/icons/save-dark.svg b/ai_diffusion/icons/save-dark.svg new file mode 100644 index 0000000000..11db345bce --- /dev/null +++ b/ai_diffusion/icons/save-dark.svg @@ -0,0 +1,90 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + + diff --git a/ai_diffusion/icons/save-light.svg b/ai_diffusion/icons/save-light.svg new file mode 100644 index 0000000000..b8bc1e32a6 --- /dev/null +++ b/ai_diffusion/icons/save-light.svg @@ -0,0 +1,90 @@ + + + + + + + + + + + + image/svg+xml + + + + + + + + + + diff --git a/ai_diffusion/icons/web-connection-dark.svg b/ai_diffusion/icons/web-connection-dark.svg new file mode 100644 index 0000000000..dcd351f355 --- /dev/null +++ b/ai_diffusion/icons/web-connection-dark.svg @@ -0,0 +1,87 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/icons/web-connection-light.svg b/ai_diffusion/icons/web-connection-light.svg new file mode 100644 index 0000000000..fa1338e6e5 --- /dev/null +++ b/ai_diffusion/icons/web-connection-light.svg @@ -0,0 +1,87 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index e7011430b1..a54baf5f3a 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -3,9 +3,9 @@ from functools import wraps from pathlib import Path from textwrap import wrap as wrap_text -from typing import Any, NamedTuple +from typing import Any, Callable, NamedTuple from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal, QUrl -from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor, QDesktopServices +from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor, QDesktopServices, QIcon from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar from PyQt5.QtWidgets import ( QLabel, @@ -25,6 +25,7 @@ from ..image import Bounds, Extent, Image from ..jobs import Job, JobQueue, JobState, JobKind, JobParams from ..model import Model, InpaintContext, RootRegion, ProgressKind +from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows from ..style import Styles from ..root import root from ..workflow import InpaintMode, FillMode @@ -36,8 +37,6 @@ from .region import RegionPromptWidget from . import theme -from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows - class HistoryWidget(QListWidget): _model: Model @@ -969,6 +968,16 @@ def wrapper(self, *args, **kwargs): return wrapper +def _create_tool_button(parent: QWidget, icon: QIcon, tooltip: str, handler: Callable[..., None]): + button = QToolButton(parent) + button.setIcon(icon) + button.setToolTip(tooltip) + button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) + button.setAutoRaise(True) + button.clicked.connect(handler) + return button + + class CustomWorkflowWidget(QWidget): def __init__(self): super().__init__() @@ -984,20 +993,24 @@ def __init__(self): self._workflow_select.setModel(SortedWorkflows(root.workflows)) self._workflow_select.currentIndexChanged.connect(self._change_workflow) - self._import_workflow_button = QToolButton(self._workflow_select_widgets) - self._import_workflow_button.setText("I") - self._import_workflow_button.setToolTip(_("Import workflow from file")) - self._import_workflow_button.clicked.connect(self._import_workflow) - - self._save_workflow_button = QToolButton(self._workflow_select_widgets) - self._save_workflow_button.setText("S") - self._save_workflow_button.setToolTip(_("Save workflow to file")) - self._save_workflow_button.clicked.connect(self._save_workflow) - - self._open_webui_button = QToolButton(self._workflow_select_widgets) - self._open_webui_button.setText("W") - self._open_webui_button.setToolTip(_("Open Web UI to create custom workflows")) - self._open_webui_button.clicked.connect(self._open_webui) + self._import_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("import"), + _("Import workflow from file"), + self._import_workflow, + ) + self._save_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("save"), + _("Save workflow to file"), + self._save_workflow, + ) + self._open_webui_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("comfyui"), + _("Open Web UI to create custom workflows"), + self._open_webui, + ) self._workflow_edit_widgets = QWidget(self) self._workflow_edit_widgets.setVisible(False) @@ -1006,13 +1019,12 @@ def __init__(self): self._workflow_name_edit.textEdited.connect(self._edit_name) self._workflow_name_edit.returnPressed.connect(self._accept_name) - self._accept_name_button = QToolButton(self._workflow_edit_widgets) - self._accept_name_button.setText("✔") - self._accept_name_button.clicked.connect(self._accept_name) - - self._cancel_name_button = QToolButton(self._workflow_edit_widgets) - self._cancel_name_button.setText("✘") - self._cancel_name_button.clicked.connect(self._cancel_name) + self._accept_name_button = _create_tool_button( + self._workflow_edit_widgets, theme.icon("apply"), _("Apply"), self._accept_name + ) + self._cancel_name_button = _create_tool_button( + self._workflow_edit_widgets, theme.icon("cancel"), _("Cancel"), self._cancel_name + ) self._params_widget = WorkflowParamsWidget([], self) @@ -1028,6 +1040,7 @@ def __init__(self): self._layout = QVBoxLayout() select_layout = QHBoxLayout() select_layout.setContentsMargins(0, 0, 0, 0) + select_layout.setSpacing(2) select_layout.addWidget(self._workflow_select) select_layout.addWidget(self._import_workflow_button) select_layout.addWidget(self._save_workflow_button) @@ -1035,6 +1048,7 @@ def __init__(self): self._workflow_select_widgets.setLayout(select_layout) edit_layout = QHBoxLayout() edit_layout.setContentsMargins(0, 0, 0, 0) + edit_layout.setSpacing(2) edit_layout.addWidget(self._workflow_name_edit) edit_layout.addWidget(self._accept_name_button) edit_layout.addWidget(self._cancel_name_button) From a7e3408a8ae69b08414e4806579b8ec82e2756c9 Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 2 Oct 2024 17:03:16 +0200 Subject: [PATCH 11/28] Button to delete local (file based) custom workflows --- ai_diffusion/custom_workflow.py | 20 +++++++++++++++++++- ai_diffusion/model.py | 2 +- ai_diffusion/ui/generation.py | 25 ++++++++++++++++++++++++- tests/test_custom_workflow.py | 17 +++++++++++++---- 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index f090c05a73..5f09651671 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -63,7 +63,7 @@ def _process_remote_workflow(self, id: str): def _process_file(self, file: Path): with file.open("r") as f: - self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f))) + self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file)) def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) @@ -94,6 +94,16 @@ def append(self, item: CustomWorkflow): self._workflows.append(item) self.endInsertRows() + def remove(self, id: str): + idx = self.find_index(id) + if idx.isValid(): + wf = self._workflows[idx.row()] + if wf.source is WorkflowSource.local and wf.path is not None: + wf.path.unlink() + self.beginRemoveRows(QModelIndex(), idx.row(), idx.row()) + self._workflows.pop(idx.row()) + self.endRemoveRows() + def set_graph(self, index: QModelIndex, graph: dict): self._workflows[index.row()].graph = graph self.dataChanged.emit(index, index) @@ -253,6 +263,14 @@ def save_as(self, id: str): assert self._graph, "Save as: no workflow selected" self.workflow_id = self._workflows.save_as(id, self._graph.root) + def remove_workflow(self): + if id := self.workflow_id: + self._workflow_id = "" + self._workflow = None + self._graph = None + self._metadata = [] + self._workflows.remove(id) + @property def workflow(self): return self._workflow diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 4ecec0a6a7..d0206f9774 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -370,7 +370,7 @@ def generate_custom(self): sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), custom_workflow=CustomWorkflowInput(wf.root, {}), ) - job_params = JobParams(bounds, self.custom.graph_id) + job_params = JobParams(bounds, self.custom.workflow_id) except Exception as e: self.report_error(util.log_error(e)) return diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index a54baf5f3a..f6a92c0174 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -25,7 +25,7 @@ from ..image import Bounds, Extent, Image from ..jobs import Job, JobQueue, JobState, JobKind, JobParams from ..model import Model, InpaintContext, RootRegion, ProgressKind -from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows +from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource from ..style import Styles from ..root import root from ..workflow import InpaintMode, FillMode @@ -1005,6 +1005,12 @@ def __init__(self): _("Save workflow to file"), self._save_workflow, ) + self._delete_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("discard"), + _("Delete the currently selected workflow"), + self._delete_workflow, + ) self._open_webui_button = _create_tool_button( self._workflow_select_widgets, theme.icon("comfyui"), @@ -1044,6 +1050,7 @@ def __init__(self): select_layout.addWidget(self._workflow_select) select_layout.addWidget(self._import_workflow_button) select_layout.addWidget(self._save_workflow_button) + select_layout.addWidget(self._delete_workflow_button) select_layout.addWidget(self._open_webui_button) self._workflow_select_widgets.setLayout(select_layout) edit_layout = QHBoxLayout() @@ -1071,8 +1078,12 @@ def __init__(self): def _update_current_workflow(self): if not self.model.custom.workflow: self._save_workflow_button.setEnabled(False) + self._delete_workflow_button.setEnabled(False) return self._save_workflow_button.setEnabled(True) + self._delete_workflow_button.setEnabled( + self.model.custom.workflow.source is WorkflowSource.local + ) self._params_widget.deleteLater() self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) @@ -1127,6 +1138,18 @@ def _import_workflow(self, *args): def _save_workflow(self): self.is_edit_mode = True + def _delete_workflow(self): + filepath = ensure(self.model.custom.workflow).path + q = QMessageBox.question( + self, + _("Delete Workflow"), + _("Are you sure you want to delete the current workflow?") + f"\n{filepath}", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.StandardButton.No, + ) + if q == QMessageBox.StandardButton.Yes: + self.model.custom.remove_workflow() + def _open_webui(self): if client := root.connection.client_if_connected: QDesktopServices.openUrl(QUrl(client.url)) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 689aaa8386..844e3176b4 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -26,8 +26,12 @@ def test_collection(tmp_path: Path): collection = WorkflowCollection(connection, tmp_path) assert len(collection) == 3 - assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 1}) - assert collection.find("file2") == CustomWorkflow("file2", WorkflowSource.local, {"file": 2}) + assert collection.find("file1") == CustomWorkflow( + "file1", WorkflowSource.local, {"file": 1}, file1 + ) + assert collection.find("file2") == CustomWorkflow( + "file2", WorkflowSource.local, {"file": 2}, file2 + ) assert collection.find("connection1") == CustomWorkflow( "connection1", WorkflowSource.remote, {"connection": 1} ) @@ -56,8 +60,9 @@ def on_data_changed(start, end): ) collection.set_graph(collection.index(0), {"file": 3}) - assert collection.find("file1") == CustomWorkflow("file1", WorkflowSource.local, {"file": 3}) - + assert collection.find("file1") == CustomWorkflow( + "file1", WorkflowSource.local, {"file": 3}, file1 + ) assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1})) @@ -104,6 +109,10 @@ def test_files(tmp_path: Path): ] assert all(f.exists() for f in files) + collection.remove("file1 (1)") + assert collection.find("file1 (1)") is None + assert not (collection_folder / "file1 (1).json").exists() + bad_file = tmp_path / "bad.json" bad_file.write_text("bad json") with pytest.raises(RuntimeError): From 865f283ce6ccfb5e74a3d748f5bc015d4190bb81 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 3 Oct 2024 17:47:21 +0200 Subject: [PATCH 12/28] More parameter nodes: number, bool, text (prompt) --- ai_diffusion/comfy_workflow.py | 10 +- ai_diffusion/custom_workflow.py | 34 ++++++- ai_diffusion/ui/generation.py | 165 ++++++++++++++++++++++++++++++-- ai_diffusion/workflow.py | 23 ++++- tests/test_custom_workflow.py | 71 +++++++++++++- 5 files changed, 287 insertions(+), 16 deletions(-) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 21917bee6e..a986b1bcaa 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -40,13 +40,21 @@ def input(self, key: str, default: None = None) -> Input: ... def input(self, key: str, default: T | None = None) -> T | Input: result = self.inputs[key] - assert default is None or type(result) == type(default) + assert ( + default is None + or type(result) == type(default) + or (isnumber(result) and isnumber(default)) + ) return result def output(self, index=0) -> Output: return Output(int(self.id), index) +def isnumber(x): + return isinstance(x, (int, float)) + + class ComfyWorkflow: """Builder for workflows which can be sent to the ComfyUI prompt API.""" diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 5f09651671..a6fdd3c06d 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -182,17 +182,22 @@ class ParamKind(Enum): image_layer = 0 mask_layer = 1 number_int = 2 + number_float = 3 + boolean = 4 + text = 5 + prompt_positive = 6 + prompt_negative = 7 class CustomParam(NamedTuple): kind: ParamKind name: str default: Any | None = None - min: int | None = None - max: int | None = None + min: int | float | None = None + max: int | float | None = None -def _gather_params(w: ComfyWorkflow): +def workflow_parameters(w: ComfyWorkflow): for node in w: match node.type: case "ETN_KritaImageLayer": @@ -207,6 +212,27 @@ def _gather_params(w: ComfyWorkflow): min = node.input("min", -(2**31)) max = node.input("max", 2**31) yield CustomParam(ParamKind.number_int, name, default=default, min=min, max=max) + case "ETN_NumberParameter": + name = node.input("name", "Parameter") + default = node.input("default", 0.0) + min = node.input("min", 0.0) + max = node.input("max", 1.0) + yield CustomParam(ParamKind.number_float, name, default=default, min=min, max=max) + case "ETN_BoolParameter": + name = node.input("name", "Parameter") + default = node.input("default", False) + yield CustomParam(ParamKind.boolean, name, default=default) + case "ETN_TextParameter": + name = node.input("name", "Parameter") + default = node.input("default", "") + type = node.input("type", "general") + match type: + case "general": + yield CustomParam(ParamKind.text, name, default=default) + case "prompt (positive)": + yield CustomParam(ParamKind.prompt_positive, name, default=default) + case "prompt (negative)": + yield CustomParam(ParamKind.prompt_negative, name, default=default) class CustomWorkspace(QObject, ObservableProperties): @@ -239,7 +265,7 @@ def _update_workflow(self, idx: QModelIndex, _: QModelIndex): if wf.id == self._workflow_id: self._workflow = wf self._graph = self._workflow.workflow - self._metadata = list(_gather_params(self._graph)) + self._metadata = list(workflow_parameters(self._graph)) self.params = _coerce(self.params, self._metadata) self.graph_changed.emit() diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index f6a92c0174..e3528f589c 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -6,6 +6,7 @@ from typing import Any, Callable, NamedTuple from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal, QUrl from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor, QDesktopServices, QIcon +from PyQt5.QtGui import QFontMetrics from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar from PyQt5.QtWidgets import ( QLabel, @@ -18,6 +19,8 @@ QSpinBox, QFileDialog, QLineEdit, + QDoubleSpinBox, + QFrame, ) from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QGridLayout @@ -33,8 +36,9 @@ from ..localization import translate as _ from ..util import ensure, flatten from .widget import WorkspaceSelectWidget, StyleSelectWidget, StrengthWidget, QueueButton -from .widget import GenerateButton, create_wide_tool_button +from .widget import GenerateButton, TextPromptWidget, create_wide_tool_button from .region import RegionPromptWidget +from .switch import SwitchWidget from . import theme @@ -874,7 +878,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): assert param.min is not None and param.max is not None and param.default is not None if param.max - param.min <= 200: self._widget = QSlider(Qt.Orientation.Horizontal, parent) - self._widget.setRange(param.min, param.max) + self._widget.setRange(int(param.min), int(param.max)) self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) @@ -882,15 +886,13 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(self._widget) layout.addWidget(self._label) - self.setLayout(layout) else: self._widget = QSpinBox(parent) - self._widget.setRange(param.min, param.max) + self._widget.setRange(int(param.min), int(param.max)) self._widget.valueChanged.connect(self._notify) self._label = None layout = QHBoxLayout(self) layout.addWidget(self._widget) - self.setLayout(layout) self.value = param.default @@ -908,7 +910,150 @@ def value(self, value: int): self._widget.setValue(value) -CustomParamWidget = LayerSelect | IntParamWidget +class FloatParamWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + assert param.min is not None and param.max is not None and param.default is not None + if param.max - param.min <= 20: + self._widget = QSlider(Qt.Orientation.Horizontal, parent) + self._widget.setRange(round(param.min * 100), round(param.max * 100)) + self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) + self._widget.valueChanged.connect(self._notify) + self._label = QLabel(self) + self._label.setFixedWidth(40) + self._label.setAlignment(Qt.AlignmentFlag.AlignRight) + layout.addWidget(self._widget) + layout.addWidget(self._label) + else: + self._widget = QDoubleSpinBox(parent) + self._widget.setRange(param.min, param.max) + self._widget.valueChanged.connect(self._notify) + self._label = None + layout = QHBoxLayout(self) + layout.addWidget(self._widget) + + self.value = param.default + + def _notify(self): + if self._label: + self._label.setText(f"{self.value:.2f}") + self.value_changed.emit() + + @property + def value(self): + if isinstance(self._widget, QSlider): + return self._widget.value() / 100 + else: + return self._widget.value() + + @value.setter + def value(self, value: float): + if isinstance(self._widget, QSlider): + self._widget.setValue(round(value * 100)) + else: + self._widget.setValue(value) + + +class BoolParamWidget(QWidget): + value_changed = pyqtSignal() + + _true_text = _("On") + _false_text = _("Off") + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + fm = QFontMetrics(self.font()) + self._label = QLabel(self) + self._label.setMinimumWidth(max(fm.width(self._true_text), fm.width(self._false_text)) + 4) + self._widget = SwitchWidget(parent) + self._widget.toggled.connect(self._notify) + layout.addWidget(self._widget) + layout.addWidget(self._label) + + assert isinstance(param.default, bool) + self.value = param.default + + def _notify(self): + self._label.setText(self._true_text if self.value else self._false_text) + self.value_changed.emit() + + @property + def value(self): + return self._widget.isChecked() + + @value.setter + def value(self, value: bool): + self._widget.setChecked(value) + + +class TextParamWidget(QLineEdit): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + assert isinstance(param.default, str) + + self.value = param.default + self.textChanged.connect(self._notify) + + def _notify(self): + self.value_changed.emit() + + @property + def value(self): + return self.text() + + @value.setter + def value(self, value: str): + self.setText(value) + + +class PromptParamWidget(TextPromptWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(is_negative=param.kind is ParamKind.prompt_negative, parent=parent) + assert isinstance(param.default, str) + + self.setObjectName("PromptParam") + self.setFrameStyle(QFrame.Shape.StyledPanel) + self.setStyleSheet( + f"QFrame#PromptParam {{ background-color: {theme.base}; border: 1px solid {theme.line_base}; }}" + ) + self.text = param.default + self.text_changed.connect(self.value_changed) + + @property + def value(self): + return self.text + + @value.setter + def value(self, value: str): + self.text = value + + +CustomParamWidget = ( + LayerSelect + | IntParamWidget + | FloatParamWidget + | BoolParamWidget + | TextParamWidget + | PromptParamWidget +) def _create_param_widget(param: CustomParam, parent: QWidget): @@ -918,6 +1063,14 @@ def _create_param_widget(param: CustomParam, parent: QWidget): return LayerSelect("mask", parent) if param.kind is ParamKind.number_int: return IntParamWidget(param, parent) + if param.kind is ParamKind.number_float: + return FloatParamWidget(param, parent) + if param.kind is ParamKind.boolean: + return BoolParamWidget(param, parent) + if param.kind is ParamKind.text: + return TextParamWidget(param, parent) + if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]: + return PromptParamWidget(param, parent) assert False, f"Unknown param kind: {param.kind}" diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 7a744e1d22..12cfe13a17 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1052,12 +1052,22 @@ def expand_custom( def map_input(input): if isinstance(input, Output): - if mapped := outputs.get(input): + mapped = outputs.get(input) + if mapped is not None: return mapped else: return Output(nodes[input.node], input.output) return input + def get_param(node: ComfyNode, expected_type: type | None = None): + name = node.input("name", "") + value = input.params.get(name) + if value is None: + raise Exception(f"Missing required parameter '{name}' for custom workflow") + if expected_type and not isinstance(value, expected_type): + raise Exception(f"Parameter '{name}' must be of type {expected_type}") + return value + for node in custom: match node.type: case "ETN_KritaCanvas": @@ -1068,6 +1078,17 @@ def map_input(input): outputs[node.output(3)] = sampling.seed case "ETN_KritaSelection": outputs[node.output(0)] = w.load_mask(ensure(images.hires_mask)) + case "ETN_KritaImageLayer": + outputs[node.output(0)] = w.load_image(get_param(node, Image)) + case "ETN_KritaMaskLayer": + outputs[node.output(0)] = w.load_mask(get_param(node, Image)) + case ( + "ETN_IntParameter" + | "ETN_NumberParameter" + | "ETN_BoolParameter" + | "ETN_TextParameter" + ): + outputs[node.output(0)] = get_param(node) case _: mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()} mapped = ComfyNode(node.id, node.type, mapped_inputs) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 844e3176b4..6ce33a9c91 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -7,7 +7,8 @@ from ai_diffusion.connection import Connection from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection -from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace, CustomParam +from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace +from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters from ai_diffusion.image import Image, Extent from ai_diffusion import workflow @@ -169,14 +170,62 @@ def test_import(): assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) +def test_parameters(): + w = ComfyWorkflow() + w.add("ETN_IntParameter", 1, name="int", default=4, min=0, max=10) + w.add("ETN_BoolParameter", 1, name="bool", default=True) + w.add("ETN_NumberParameter", 1, name="number", default=1.2, min=0.0, max=10.0) + w.add("ETN_TextParameter", 1, name="text", type="general", default="mouse") + w.add("ETN_TextParameter", 1, name="positive", type="prompt (positive)", default="p") + w.add("ETN_TextParameter", 1, name="negative", type="prompt (negative)", default="n") + w.add("ETN_KritaImageLayer", 1, name="image") + w.add("ETN_KritaMaskLayer", 1, name="mask") + + assert list(workflow_parameters(w)) == [ + CustomParam(ParamKind.number_int, "int", 4, 0, 10), + CustomParam(ParamKind.boolean, "bool", True), + CustomParam(ParamKind.number_float, "number", 1.2, 0.0, 10.0), + CustomParam(ParamKind.text, "text", "mouse"), + CustomParam(ParamKind.prompt_positive, "positive", "p"), + CustomParam(ParamKind.prompt_negative, "negative", "n"), + CustomParam(ParamKind.image_layer, "image"), + CustomParam(ParamKind.mask_layer, "mask"), + ] + + def test_expand(): ext = ComfyWorkflow() in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) ext.add("ETN_KritaOutput", 1, images=scaled) - ext.add("SeedEater", 1, seed=seed) + inty = ext.add("ETN_IntParameter", 1, name="inty", default=4, min=0, max=10) + numby = ext.add("ETN_NumberParameter", 1, name="numby", default=1.2, min=0.0, max=10.0) + texty = ext.add("ETN_TextParameter", 1, name="texty", type="general", default="mouse") + booly = ext.add("ETN_BoolParameter", 1, name="booly", default=True) + layer_img = ext.add("ETN_KritaImageLayer", 1, name="layer_img") + layer_mask = ext.add("ETN_KritaMaskLayer", 1, name="layer_mask") + ext.add( + "Sink", + 1, + seed=seed, + inty=inty, + numby=numby, + texty=texty, + booly=booly, + layer_img=layer_img, + layer_mask=layer_mask, + ) + + params = { + "inty": 7, + "numby": 3.4, + "texty": "cat", + "booly": False, + "layer_img": Image.create(Extent(4, 4), Qt.GlobalColor.black), + "layer_mask": Image.create(Extent(4, 4), Qt.GlobalColor.white), + } - input = CustomWorkflowInput(workflow=ext.root, params={}) + input = CustomWorkflowInput(workflow=ext.root, params=params) images = ImageInput.from_extent(Extent(4, 4)) images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) sampling = SamplingInput("", "", 1.0, 1000, seed=123) @@ -187,7 +236,21 @@ def test_expand(): ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), - ComfyNode(4, "SeedEater", {"seed": 123}), + ComfyNode(4, "ETN_LoadImageBase64", {"image": params["layer_img"].to_base64()}), + ComfyNode(5, "ETN_LoadMaskBase64", {"mask": params["layer_mask"].to_base64()}), + ComfyNode( + 6, + "Sink", + { + "seed": 123, + "inty": 7, + "numby": 3.4, + "texty": "cat", + "booly": False, + "layer_img": Output(4, 0), + "layer_mask": Output(5, 0), + }, + ), ] for node in expected: assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}" From 4699f72f286c429869da6f8371b6c354aaab4220 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 3 Oct 2024 19:49:33 +0200 Subject: [PATCH 13/28] Make image/mask layer params work --- ai_diffusion/model.py | 13 +++++++++++-- ai_diffusion/ui/generation.py | 6 +++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index d0206f9774..36c3bd6282 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -19,7 +19,7 @@ from .image import Extent, Image, Mask, Bounds, DummyImage from .client import ClientMessage, ClientEvent, SharedWorkflow from .client import filter_supported_styles, resolve_arch -from .custom_workflow import CustomWorkspace, WorkflowCollection +from .custom_workflow import CustomWorkspace, WorkflowCollection, ParamKind from .document import Document, KritaDocument from .layer import Layer, LayerType, RestoreActiveLayer from .pose import Pose @@ -364,11 +364,20 @@ def generate_custom(self): else: img_input.hires_mask = Mask.transparent(bounds).to_image() + params = copy(self.custom.params) + for md in self.custom.metadata: + if md.kind is ParamKind.image_layer: + layer = self.layers.find(QUuid(params[md.name])) + params[md.name] = ensure(layer).get_pixels(bounds) + elif md.kind is ParamKind.mask_layer: + layer = self.layers.find(QUuid(params[md.name])) + params[md.name] = ensure(layer).get_mask(bounds) + input = WorkflowInput( WorkflowKind.custom, img_input, sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), - custom_workflow=CustomWorkflowInput(wf.root, {}), + custom_workflow=CustomWorkflowInput(wf.root, params), ) job_params = JobParams(bounds, self.custom.workflow_id) except Exception as e: diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index e3528f589c..0d149d4a23 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -854,12 +854,12 @@ def _update(self): i += 1 @property - def value(self): - return self.currentData() + def value(self) -> str: + return self.currentData().toString() @value.setter def value(self, value: str): - i = self.findData(value) + i = self.findData(QUuid(value)) if i != -1 and i != self.currentIndex(): self.setCurrentIndex(i) From 5356b9dbc80b401e7f261d57b0ee8f3bd69e4639 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 3 Oct 2024 20:13:10 +0200 Subject: [PATCH 14/28] Move custom workflow ui to its own file --- ai_diffusion/model.py | 8 +- ai_diffusion/ui/custom_workflow.py | 536 ++++++++++++++++++++++++++++ ai_diffusion/ui/diffusion.py | 3 +- ai_diffusion/ui/generation.py | 548 +---------------------------- 4 files changed, 551 insertions(+), 544 deletions(-) create mode 100644 ai_diffusion/ui/custom_workflow.py diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 36c3bd6282..1b8dd11115 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -368,10 +368,14 @@ def generate_custom(self): for md in self.custom.metadata: if md.kind is ParamKind.image_layer: layer = self.layers.find(QUuid(params[md.name])) - params[md.name] = ensure(layer).get_pixels(bounds) + if layer is None: + raise ValueError(f"Input layer for parameter {md.name} not found") + params[md.name] = layer.get_pixels(bounds) elif md.kind is ParamKind.mask_layer: layer = self.layers.find(QUuid(params[md.name])) - params[md.name] = ensure(layer).get_mask(bounds) + if layer is None: + raise ValueError(f"Input layer for parameter {md.name} not found") + params[md.name] = layer.get_mask(bounds) input = WorkflowInput( WorkflowKind.custom, diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py new file mode 100644 index 0000000000..2a86efa531 --- /dev/null +++ b/ai_diffusion/ui/custom_workflow.py @@ -0,0 +1,536 @@ +from functools import wraps +from pathlib import Path +from typing import Any, Callable + +from PyQt5.QtCore import Qt, pyqtSignal, QMetaObject, QUuid, QUrl +from PyQt5.QtGui import QFontMetrics, QIcon, QDesktopServices +from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout +from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox +from PyQt5.QtWidgets import QToolButton, QVBoxLayout, QWidget, QSlider, QDoubleSpinBox + +from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource +from ..jobs import JobKind +from ..model import Model +from ..properties import Binding, Bind, bind, bind_combo +from ..root import root +from ..localization import translate as _ +from ..util import ensure +from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label +from .switch import SwitchWidget +from .widget import TextPromptWidget, WorkspaceSelectWidget +from . import theme + + +class LayerSelect(QComboBox): + value_changed = pyqtSignal() + + def __init__(self, filter: str | None = None, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + self.filter = filter + self.currentIndexChanged.connect(lambda _: self.value_changed.emit()) + + self._update() + root.active_model.layers.changed.connect(self._update) + + def _update(self): + if self.filter is None: + layers = root.active_model.layers.all + elif self.filter == "image": + layers = root.active_model.layers.images + elif self.filter == "mask": + layers = root.active_model.layers.masks + else: + assert False, f"Unknown filter: {self.filter}" + + for l in layers: + if self.findData(l.id) == -1: + self.addItem(l.name, l.id) + i = 0 + while i < self.count(): + if self.itemData(i) not in (l.id for l in layers): + self.removeItem(i) + else: + i += 1 + + @property + def value(self) -> str: + if self.currentIndex() == -1: + return "" + return self.currentData().toString() + + @value.setter + def value(self, value: str): + i = self.findData(QUuid(value)) + if i != -1 and i != self.currentIndex(): + self.setCurrentIndex(i) + + +class IntParamWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + assert param.min is not None and param.max is not None and param.default is not None + if param.max - param.min <= 200: + self._widget = QSlider(Qt.Orientation.Horizontal, parent) + self._widget.setRange(int(param.min), int(param.max)) + self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) + self._widget.valueChanged.connect(self._notify) + self._label = QLabel(self) + self._label.setFixedWidth(40) + self._label.setAlignment(Qt.AlignmentFlag.AlignRight) + layout.addWidget(self._widget) + layout.addWidget(self._label) + else: + self._widget = QSpinBox(parent) + self._widget.setRange(int(param.min), int(param.max)) + self._widget.valueChanged.connect(self._notify) + self._label = None + layout = QHBoxLayout(self) + layout.addWidget(self._widget) + + self.value = param.default + + def _notify(self): + if self._label: + self._label.setText(str(self._widget.value())) + self.value_changed.emit() + + @property + def value(self): + return self._widget.value() + + @value.setter + def value(self, value: int): + self._widget.setValue(value) + + +class FloatParamWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + assert param.min is not None and param.max is not None and param.default is not None + if param.max - param.min <= 20: + self._widget = QSlider(Qt.Orientation.Horizontal, parent) + self._widget.setRange(round(param.min * 100), round(param.max * 100)) + self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) + self._widget.valueChanged.connect(self._notify) + self._label = QLabel(self) + self._label.setFixedWidth(40) + self._label.setAlignment(Qt.AlignmentFlag.AlignRight) + layout.addWidget(self._widget) + layout.addWidget(self._label) + else: + self._widget = QDoubleSpinBox(parent) + self._widget.setRange(param.min, param.max) + self._widget.valueChanged.connect(self._notify) + self._label = None + layout = QHBoxLayout(self) + layout.addWidget(self._widget) + + self.value = param.default + + def _notify(self): + if self._label: + self._label.setText(f"{self.value:.2f}") + self.value_changed.emit() + + @property + def value(self): + if isinstance(self._widget, QSlider): + return self._widget.value() / 100 + else: + return self._widget.value() + + @value.setter + def value(self, value: float): + if isinstance(self._widget, QSlider): + self._widget.setValue(round(value * 100)) + else: + self._widget.setValue(value) + + +class BoolParamWidget(QWidget): + value_changed = pyqtSignal() + + _true_text = _("On") + _false_text = _("Off") + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + self.setContentsMargins(0, 0, 0, 0) + + layout = QHBoxLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + self.setLayout(layout) + + fm = QFontMetrics(self.font()) + self._label = QLabel(self) + self._label.setMinimumWidth(max(fm.width(self._true_text), fm.width(self._false_text)) + 4) + self._widget = SwitchWidget(parent) + self._widget.toggled.connect(self._notify) + layout.addWidget(self._widget) + layout.addWidget(self._label) + + assert isinstance(param.default, bool) + self.value = param.default + + def _notify(self): + self._label.setText(self._true_text if self.value else self._false_text) + self.value_changed.emit() + + @property + def value(self): + return self._widget.isChecked() + + @value.setter + def value(self, value: bool): + self._widget.setChecked(value) + + +class TextParamWidget(QLineEdit): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + assert isinstance(param.default, str) + + self.value = param.default + self.textChanged.connect(self._notify) + + def _notify(self): + self.value_changed.emit() + + @property + def value(self): + return self.text() + + @value.setter + def value(self, value: str): + self.setText(value) + + +class PromptParamWidget(TextPromptWidget): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(is_negative=param.kind is ParamKind.prompt_negative, parent=parent) + assert isinstance(param.default, str) + + self.setObjectName("PromptParam") + self.setFrameStyle(QFrame.Shape.StyledPanel) + self.setStyleSheet( + f"QFrame#PromptParam {{ background-color: {theme.base}; border: 1px solid {theme.line_base}; }}" + ) + self.text = param.default + self.text_changed.connect(self.value_changed) + + @property + def value(self): + return self.text + + @value.setter + def value(self, value: str): + self.text = value + + +CustomParamWidget = ( + LayerSelect + | IntParamWidget + | FloatParamWidget + | BoolParamWidget + | TextParamWidget + | PromptParamWidget +) + + +def _create_param_widget(param: CustomParam, parent: QWidget): + if param.kind is ParamKind.image_layer: + return LayerSelect("image", parent) + if param.kind is ParamKind.mask_layer: + return LayerSelect("mask", parent) + if param.kind is ParamKind.number_int: + return IntParamWidget(param, parent) + if param.kind is ParamKind.number_float: + return FloatParamWidget(param, parent) + if param.kind is ParamKind.boolean: + return BoolParamWidget(param, parent) + if param.kind is ParamKind.text: + return TextParamWidget(param, parent) + if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]: + return PromptParamWidget(param, parent) + assert False, f"Unknown param kind: {param.kind}" + + +class WorkflowParamsWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, params: list[CustomParam], parent: QWidget | None = None): + super().__init__(parent) + self._widgets: dict[str, CustomParamWidget] = {} + + layout = QGridLayout(self) + layout.setContentsMargins(0, 0, 0, 0) + layout.setColumnMinimumWidth(1, 10) + self.setLayout(layout) + + for p in params: + label = QLabel(p.name, self) + widget = _create_param_widget(p, self) + widget.value_changed.connect(self._notify) + row = len(self._widgets) + layout.addWidget(label, row, 0) + layout.addWidget(widget, row, 2) + self._widgets[p.name] = widget + + def _notify(self): + self.value_changed.emit() + + @property + def value(self): + return {name: widget.value for name, widget in self._widgets.items()} + + @value.setter + def value(self, values: dict[str, Any]): + for name, value in values.items(): + if widget := self._widgets.get(name): + if type(widget.value) == type(value): + widget.value = value + + +def popup_on_error(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + return func(self, *args, **kwargs) + except Exception as e: + QMessageBox.critical(self, _("Error"), str(e)) + + return wrapper + + +def _create_tool_button(parent: QWidget, icon: QIcon, tooltip: str, handler: Callable[..., None]): + button = QToolButton(parent) + button.setIcon(icon) + button.setToolTip(tooltip) + button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) + button.setAutoRaise(True) + button.clicked.connect(handler) + return button + + +class CustomWorkflowWidget(QWidget): + def __init__(self): + super().__init__() + + self._model = root.active_model + self._model_bindings: list[QMetaObject.Connection | Binding] = [] + + self._workspace_select = WorkspaceSelectWidget(self) + + self._workflow_select_widgets = QWidget(self) + + self._workflow_select = QComboBox(self._workflow_select_widgets) + self._workflow_select.setModel(SortedWorkflows(root.workflows)) + self._workflow_select.currentIndexChanged.connect(self._change_workflow) + + self._import_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("import"), + _("Import workflow from file"), + self._import_workflow, + ) + self._save_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("save"), + _("Save workflow to file"), + self._save_workflow, + ) + self._delete_workflow_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("discard"), + _("Delete the currently selected workflow"), + self._delete_workflow, + ) + self._open_webui_button = _create_tool_button( + self._workflow_select_widgets, + theme.icon("comfyui"), + _("Open Web UI to create custom workflows"), + self._open_webui, + ) + + self._workflow_edit_widgets = QWidget(self) + self._workflow_edit_widgets.setVisible(False) + + self._workflow_name_edit = QLineEdit(self._workflow_edit_widgets) + self._workflow_name_edit.textEdited.connect(self._edit_name) + self._workflow_name_edit.returnPressed.connect(self._accept_name) + + self._accept_name_button = _create_tool_button( + self._workflow_edit_widgets, theme.icon("apply"), _("Apply"), self._accept_name + ) + self._cancel_name_button = _create_tool_button( + self._workflow_edit_widgets, theme.icon("cancel"), _("Cancel"), self._cancel_name + ) + + self._params_widget = WorkflowParamsWidget([], self) + + self._generate_button = GenerateButton(JobKind.diffusion, self) + self._queue_button = QueueButton(parent=self) + self._queue_button.setFixedHeight(self._generate_button.height() - 2) + self._progress_bar = ProgressBar(self) + self._error_text = create_error_label(self) + + self._history = HistoryWidget(self) + self._history.item_activated.connect(self.apply_result) + + self._layout = QVBoxLayout() + select_layout = QHBoxLayout() + select_layout.setContentsMargins(0, 0, 0, 0) + select_layout.setSpacing(2) + select_layout.addWidget(self._workflow_select) + select_layout.addWidget(self._import_workflow_button) + select_layout.addWidget(self._save_workflow_button) + select_layout.addWidget(self._delete_workflow_button) + select_layout.addWidget(self._open_webui_button) + self._workflow_select_widgets.setLayout(select_layout) + edit_layout = QHBoxLayout() + edit_layout.setContentsMargins(0, 0, 0, 0) + edit_layout.setSpacing(2) + edit_layout.addWidget(self._workflow_name_edit) + edit_layout.addWidget(self._accept_name_button) + edit_layout.addWidget(self._cancel_name_button) + self._workflow_edit_widgets.setLayout(edit_layout) + header_layout = QHBoxLayout() + header_layout.addWidget(self._workspace_select) + header_layout.addWidget(self._workflow_select_widgets) + header_layout.addWidget(self._workflow_edit_widgets) + self._layout.addLayout(header_layout) + self._layout.addWidget(self._params_widget) + actions_layout = QHBoxLayout() + actions_layout.addWidget(self._generate_button) + actions_layout.addWidget(self._queue_button) + self._layout.addLayout(actions_layout) + self._layout.addWidget(self._progress_bar) + self._layout.addWidget(self._error_text) + self._layout.addWidget(self._history) + self.setLayout(self._layout) + + def _update_current_workflow(self): + if not self.model.custom.workflow: + self._save_workflow_button.setEnabled(False) + self._delete_workflow_button.setEnabled(False) + return + self._save_workflow_button.setEnabled(True) + self._delete_workflow_button.setEnabled( + self.model.custom.workflow.source is WorkflowSource.local + ) + + self._params_widget.deleteLater() + self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) + self._params_widget.value = self.model.custom.params + self._layout.insertWidget(1, self._params_widget) + self._params_widget.value_changed.connect(self._change_params) + + def _change_workflow(self): + self.model.custom.workflow_id = self._workflow_select.currentData() + + def _change_params(self): + self.model.custom.params = self._params_widget.value + + @property + def model(self): + return self._model + + @model.setter + def model(self, model: Model): + if self._model != model: + Binding.disconnect_all(self._model_bindings) + self._model = model + self._model_bindings = [ + bind(model, "workspace", self._workspace_select, "value", Bind.one_way), + bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way), + model.workspace_changed.connect(self._cancel_name), + model.custom.graph_changed.connect(self._update_current_workflow), + model.error_changed.connect(self._error_text.setText), + model.has_error_changed.connect(self._error_text.setVisible), + self._generate_button.clicked.connect(model.generate_custom), + ] + self._queue_button.model = model + self._progress_bar.model = model + self._history.model_ = model + self._update_current_workflow() + + def apply_result(self, item: QListWidgetItem): + job_id, index = self._history.item_info(item) + self.model.apply_generated_result(job_id, index) + + @popup_on_error + def _import_workflow(self, *args): + filename, __ = QFileDialog.getOpenFileName( + self, + _("Import Workflow"), + str(Path.home()), + "Workflow Files (*.json);;All Files (*)", + ) + if filename: + self.model.custom.import_file(Path(filename)) + + def _save_workflow(self): + self.is_edit_mode = True + + def _delete_workflow(self): + filepath = ensure(self.model.custom.workflow).path + q = QMessageBox.question( + self, + _("Delete Workflow"), + _("Are you sure you want to delete the current workflow?") + f"\n{filepath}", + QMessageBox.Yes | QMessageBox.No, + QMessageBox.StandardButton.No, + ) + if q == QMessageBox.StandardButton.Yes: + self.model.custom.remove_workflow() + + def _open_webui(self): + if client := root.connection.client_if_connected: + QDesktopServices.openUrl(QUrl(client.url)) + + @property + def is_edit_mode(self): + return self._workflow_edit_widgets.isVisible() + + @is_edit_mode.setter + def is_edit_mode(self, value: bool): + if value == self.is_edit_mode: + return + self._workflow_select_widgets.setVisible(not value) + self._workflow_edit_widgets.setVisible(value) + if value: + self._workflow_name_edit.setText(self.model.custom.workflow_id) + self._workflow_name_edit.selectAll() + self._workflow_name_edit.setFocus() + + def _edit_name(self): + self._accept_name_button.setEnabled(self._workflow_name_edit.text().strip() != "") + + @popup_on_error + def _accept_name(self, *args): + self.model.custom.save_as(self._workflow_name_edit.text()) + self.is_edit_mode = False + + def _cancel_name(self): + self.is_edit_mode = False diff --git a/ai_diffusion/ui/diffusion.py b/ai_diffusion/ui/diffusion.py index 79e4a3e47c..f30d741900 100644 --- a/ai_diffusion/ui/diffusion.py +++ b/ai_diffusion/ui/diffusion.py @@ -13,7 +13,8 @@ from ..root import root from ..localization import translate as _ from . import theme -from .generation import GenerationWidget, CustomWorkflowWidget +from .generation import GenerationWidget +from .custom_workflow import CustomWorkflowWidget from .upscale import UpscaleWidget from .live import LiveWidget from .animation import AnimationWidget diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index 0d149d4a23..67541227de 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -1,44 +1,23 @@ from __future__ import annotations -from enum import Enum -from functools import wraps -from pathlib import Path from textwrap import wrap as wrap_text -from typing import Any, Callable, NamedTuple -from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal, QUrl -from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor, QDesktopServices, QIcon -from PyQt5.QtGui import QFontMetrics +from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal +from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar -from PyQt5.QtWidgets import ( - QLabel, - QListWidget, - QListWidgetItem, - QListView, - QSizePolicy, - QToolButton, - QSlider, - QSpinBox, - QFileDialog, - QLineEdit, - QDoubleSpinBox, - QFrame, -) -from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QGridLayout +from PyQt5.QtWidgets import QLabel, QListWidget, QListWidgetItem, QListView, QSizePolicy +from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QToolButton from ..properties import Binding, Bind, bind, bind_combo, bind_toggle from ..image import Bounds, Extent, Image from ..jobs import Job, JobQueue, JobState, JobKind, JobParams from ..model import Model, InpaintContext, RootRegion, ProgressKind -from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource from ..style import Styles from ..root import root from ..workflow import InpaintMode, FillMode -from ..comfy_workflow import ComfyWorkflow, ComfyNode, Input, Output from ..localization import translate as _ from ..util import ensure, flatten from .widget import WorkspaceSelectWidget, StyleSelectWidget, StrengthWidget, QueueButton -from .widget import GenerateButton, TextPromptWidget, create_wide_tool_button +from .widget import GenerateButton, create_wide_tool_button from .region import RegionPromptWidget -from .switch import SwitchWidget from . import theme @@ -564,7 +543,7 @@ def _update_progress(self): self.setValue(min(99, self.value() + 2)) -def _create_error_label(parent: QWidget): +def create_error_label(parent: QWidget): label = QLabel(parent) label.setStyleSheet("font-weight: bold; color: red;") label.setWordWrap(True) @@ -646,7 +625,7 @@ def __init__(self): self.progress_bar = ProgressBar(self) layout.addWidget(self.progress_bar) - self.error_text = _create_error_label(self) + self.error_text = create_error_label(self) layout.addWidget(self.error_text) self.history = HistoryWidget(self) @@ -819,516 +798,3 @@ def update_generate_button(self): True: theme.icon("region-alpha-active"), False: theme.icon("region-alpha"), } - - -class LayerSelect(QComboBox): - value_changed = pyqtSignal() - - def __init__(self, filter: str | None = None, parent: QWidget | None = None): - super().__init__(parent) - self.setContentsMargins(0, 0, 0, 0) - self.filter = filter - self.currentIndexChanged.connect(lambda _: self.value_changed.emit()) - - self._update() - root.active_model.layers.changed.connect(self._update) - - def _update(self): - if self.filter is None: - layers = root.active_model.layers.all - elif self.filter == "image": - layers = root.active_model.layers.images - elif self.filter == "mask": - layers = root.active_model.layers.masks - else: - assert False, f"Unknown filter: {self.filter}" - - for l in layers: - if self.findData(l.id) == -1: - self.addItem(l.name, l.id) - i = 0 - while i < self.count(): - if self.itemData(i) not in (l.id for l in layers): - self.removeItem(i) - else: - i += 1 - - @property - def value(self) -> str: - return self.currentData().toString() - - @value.setter - def value(self, value: str): - i = self.findData(QUuid(value)) - if i != -1 and i != self.currentIndex(): - self.setCurrentIndex(i) - - -class IntParamWidget(QWidget): - value_changed = pyqtSignal() - - def __init__(self, param: CustomParam, parent: QWidget | None = None): - super().__init__(parent) - self.setContentsMargins(0, 0, 0, 0) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - - assert param.min is not None and param.max is not None and param.default is not None - if param.max - param.min <= 200: - self._widget = QSlider(Qt.Orientation.Horizontal, parent) - self._widget.setRange(int(param.min), int(param.max)) - self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) - self._widget.valueChanged.connect(self._notify) - self._label = QLabel(self) - self._label.setFixedWidth(40) - self._label.setAlignment(Qt.AlignmentFlag.AlignRight) - layout.addWidget(self._widget) - layout.addWidget(self._label) - else: - self._widget = QSpinBox(parent) - self._widget.setRange(int(param.min), int(param.max)) - self._widget.valueChanged.connect(self._notify) - self._label = None - layout = QHBoxLayout(self) - layout.addWidget(self._widget) - - self.value = param.default - - def _notify(self): - if self._label: - self._label.setText(str(self._widget.value())) - self.value_changed.emit() - - @property - def value(self): - return self._widget.value() - - @value.setter - def value(self, value: int): - self._widget.setValue(value) - - -class FloatParamWidget(QWidget): - value_changed = pyqtSignal() - - def __init__(self, param: CustomParam, parent: QWidget | None = None): - super().__init__(parent) - self.setContentsMargins(0, 0, 0, 0) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - - assert param.min is not None and param.max is not None and param.default is not None - if param.max - param.min <= 20: - self._widget = QSlider(Qt.Orientation.Horizontal, parent) - self._widget.setRange(round(param.min * 100), round(param.max * 100)) - self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) - self._widget.valueChanged.connect(self._notify) - self._label = QLabel(self) - self._label.setFixedWidth(40) - self._label.setAlignment(Qt.AlignmentFlag.AlignRight) - layout.addWidget(self._widget) - layout.addWidget(self._label) - else: - self._widget = QDoubleSpinBox(parent) - self._widget.setRange(param.min, param.max) - self._widget.valueChanged.connect(self._notify) - self._label = None - layout = QHBoxLayout(self) - layout.addWidget(self._widget) - - self.value = param.default - - def _notify(self): - if self._label: - self._label.setText(f"{self.value:.2f}") - self.value_changed.emit() - - @property - def value(self): - if isinstance(self._widget, QSlider): - return self._widget.value() / 100 - else: - return self._widget.value() - - @value.setter - def value(self, value: float): - if isinstance(self._widget, QSlider): - self._widget.setValue(round(value * 100)) - else: - self._widget.setValue(value) - - -class BoolParamWidget(QWidget): - value_changed = pyqtSignal() - - _true_text = _("On") - _false_text = _("Off") - - def __init__(self, param: CustomParam, parent: QWidget | None = None): - super().__init__(parent) - self.setContentsMargins(0, 0, 0, 0) - - layout = QHBoxLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - self.setLayout(layout) - - fm = QFontMetrics(self.font()) - self._label = QLabel(self) - self._label.setMinimumWidth(max(fm.width(self._true_text), fm.width(self._false_text)) + 4) - self._widget = SwitchWidget(parent) - self._widget.toggled.connect(self._notify) - layout.addWidget(self._widget) - layout.addWidget(self._label) - - assert isinstance(param.default, bool) - self.value = param.default - - def _notify(self): - self._label.setText(self._true_text if self.value else self._false_text) - self.value_changed.emit() - - @property - def value(self): - return self._widget.isChecked() - - @value.setter - def value(self, value: bool): - self._widget.setChecked(value) - - -class TextParamWidget(QLineEdit): - value_changed = pyqtSignal() - - def __init__(self, param: CustomParam, parent: QWidget | None = None): - super().__init__(parent) - assert isinstance(param.default, str) - - self.value = param.default - self.textChanged.connect(self._notify) - - def _notify(self): - self.value_changed.emit() - - @property - def value(self): - return self.text() - - @value.setter - def value(self, value: str): - self.setText(value) - - -class PromptParamWidget(TextPromptWidget): - value_changed = pyqtSignal() - - def __init__(self, param: CustomParam, parent: QWidget | None = None): - super().__init__(is_negative=param.kind is ParamKind.prompt_negative, parent=parent) - assert isinstance(param.default, str) - - self.setObjectName("PromptParam") - self.setFrameStyle(QFrame.Shape.StyledPanel) - self.setStyleSheet( - f"QFrame#PromptParam {{ background-color: {theme.base}; border: 1px solid {theme.line_base}; }}" - ) - self.text = param.default - self.text_changed.connect(self.value_changed) - - @property - def value(self): - return self.text - - @value.setter - def value(self, value: str): - self.text = value - - -CustomParamWidget = ( - LayerSelect - | IntParamWidget - | FloatParamWidget - | BoolParamWidget - | TextParamWidget - | PromptParamWidget -) - - -def _create_param_widget(param: CustomParam, parent: QWidget): - if param.kind is ParamKind.image_layer: - return LayerSelect("image", parent) - if param.kind is ParamKind.mask_layer: - return LayerSelect("mask", parent) - if param.kind is ParamKind.number_int: - return IntParamWidget(param, parent) - if param.kind is ParamKind.number_float: - return FloatParamWidget(param, parent) - if param.kind is ParamKind.boolean: - return BoolParamWidget(param, parent) - if param.kind is ParamKind.text: - return TextParamWidget(param, parent) - if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]: - return PromptParamWidget(param, parent) - assert False, f"Unknown param kind: {param.kind}" - - -class WorkflowParamsWidget(QWidget): - value_changed = pyqtSignal() - - def __init__(self, params: list[CustomParam], parent: QWidget | None = None): - super().__init__(parent) - self._widgets: dict[str, CustomParamWidget] = {} - - layout = QGridLayout(self) - layout.setContentsMargins(0, 0, 0, 0) - layout.setColumnMinimumWidth(1, 10) - self.setLayout(layout) - - for p in params: - label = QLabel(p.name, self) - widget = _create_param_widget(p, self) - widget.value_changed.connect(self._notify) - row = len(self._widgets) - layout.addWidget(label, row, 0) - layout.addWidget(widget, row, 2) - self._widgets[p.name] = widget - - def _notify(self): - self.value_changed.emit() - - @property - def value(self): - return {name: widget.value for name, widget in self._widgets.items()} - - @value.setter - def value(self, values: dict[str, Any]): - for name, value in values.items(): - if widget := self._widgets.get(name): - if type(widget.value) == type(value): - widget.value = value - - -def popup_on_error(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - try: - return func(self, *args, **kwargs) - except Exception as e: - QMessageBox.critical(self, _("Error"), str(e)) - - return wrapper - - -def _create_tool_button(parent: QWidget, icon: QIcon, tooltip: str, handler: Callable[..., None]): - button = QToolButton(parent) - button.setIcon(icon) - button.setToolTip(tooltip) - button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) - button.setAutoRaise(True) - button.clicked.connect(handler) - return button - - -class CustomWorkflowWidget(QWidget): - def __init__(self): - super().__init__() - - self._model = root.active_model - self._model_bindings: list[QMetaObject.Connection | Binding] = [] - - self._workspace_select = WorkspaceSelectWidget(self) - - self._workflow_select_widgets = QWidget(self) - - self._workflow_select = QComboBox(self._workflow_select_widgets) - self._workflow_select.setModel(SortedWorkflows(root.workflows)) - self._workflow_select.currentIndexChanged.connect(self._change_workflow) - - self._import_workflow_button = _create_tool_button( - self._workflow_select_widgets, - theme.icon("import"), - _("Import workflow from file"), - self._import_workflow, - ) - self._save_workflow_button = _create_tool_button( - self._workflow_select_widgets, - theme.icon("save"), - _("Save workflow to file"), - self._save_workflow, - ) - self._delete_workflow_button = _create_tool_button( - self._workflow_select_widgets, - theme.icon("discard"), - _("Delete the currently selected workflow"), - self._delete_workflow, - ) - self._open_webui_button = _create_tool_button( - self._workflow_select_widgets, - theme.icon("comfyui"), - _("Open Web UI to create custom workflows"), - self._open_webui, - ) - - self._workflow_edit_widgets = QWidget(self) - self._workflow_edit_widgets.setVisible(False) - - self._workflow_name_edit = QLineEdit(self._workflow_edit_widgets) - self._workflow_name_edit.textEdited.connect(self._edit_name) - self._workflow_name_edit.returnPressed.connect(self._accept_name) - - self._accept_name_button = _create_tool_button( - self._workflow_edit_widgets, theme.icon("apply"), _("Apply"), self._accept_name - ) - self._cancel_name_button = _create_tool_button( - self._workflow_edit_widgets, theme.icon("cancel"), _("Cancel"), self._cancel_name - ) - - self._params_widget = WorkflowParamsWidget([], self) - - self._generate_button = GenerateButton(JobKind.diffusion, self) - self._queue_button = QueueButton(parent=self) - self._queue_button.setFixedHeight(self._generate_button.height() - 2) - self._progress_bar = ProgressBar(self) - self._error_text = _create_error_label(self) - - self._history = HistoryWidget(self) - self._history.item_activated.connect(self.apply_result) - - self._layout = QVBoxLayout() - select_layout = QHBoxLayout() - select_layout.setContentsMargins(0, 0, 0, 0) - select_layout.setSpacing(2) - select_layout.addWidget(self._workflow_select) - select_layout.addWidget(self._import_workflow_button) - select_layout.addWidget(self._save_workflow_button) - select_layout.addWidget(self._delete_workflow_button) - select_layout.addWidget(self._open_webui_button) - self._workflow_select_widgets.setLayout(select_layout) - edit_layout = QHBoxLayout() - edit_layout.setContentsMargins(0, 0, 0, 0) - edit_layout.setSpacing(2) - edit_layout.addWidget(self._workflow_name_edit) - edit_layout.addWidget(self._accept_name_button) - edit_layout.addWidget(self._cancel_name_button) - self._workflow_edit_widgets.setLayout(edit_layout) - header_layout = QHBoxLayout() - header_layout.addWidget(self._workspace_select) - header_layout.addWidget(self._workflow_select_widgets) - header_layout.addWidget(self._workflow_edit_widgets) - self._layout.addLayout(header_layout) - self._layout.addWidget(self._params_widget) - actions_layout = QHBoxLayout() - actions_layout.addWidget(self._generate_button) - actions_layout.addWidget(self._queue_button) - self._layout.addLayout(actions_layout) - self._layout.addWidget(self._progress_bar) - self._layout.addWidget(self._error_text) - self._layout.addWidget(self._history) - self.setLayout(self._layout) - - def _update_current_workflow(self): - if not self.model.custom.workflow: - self._save_workflow_button.setEnabled(False) - self._delete_workflow_button.setEnabled(False) - return - self._save_workflow_button.setEnabled(True) - self._delete_workflow_button.setEnabled( - self.model.custom.workflow.source is WorkflowSource.local - ) - - self._params_widget.deleteLater() - self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) - self._params_widget.value = self.model.custom.params - self._layout.insertWidget(1, self._params_widget) - self._params_widget.value_changed.connect(self._change_params) - - def _change_workflow(self): - self.model.custom.workflow_id = self._workflow_select.currentData() - - def _change_params(self): - self.model.custom.params = self._params_widget.value - - @property - def model(self): - return self._model - - @model.setter - def model(self, model: Model): - if self._model != model: - Binding.disconnect_all(self._model_bindings) - self._model = model - self._model_bindings = [ - bind(model, "workspace", self._workspace_select, "value", Bind.one_way), - bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way), - model.workspace_changed.connect(self._cancel_name), - model.custom.graph_changed.connect(self._update_current_workflow), - model.error_changed.connect(self._error_text.setText), - model.has_error_changed.connect(self._error_text.setVisible), - self._generate_button.clicked.connect(model.generate_custom), - ] - self._queue_button.model = model - self._progress_bar.model = model - self._history.model_ = model - self._update_current_workflow() - - def apply_result(self, item: QListWidgetItem): - job_id, index = self._history.item_info(item) - self.model.apply_generated_result(job_id, index) - - @popup_on_error - def _import_workflow(self, *args): - filename, __ = QFileDialog.getOpenFileName( - self, - _("Import Workflow"), - str(Path.home()), - "Workflow Files (*.json);;All Files (*)", - ) - if filename: - self.model.custom.import_file(Path(filename)) - - def _save_workflow(self): - self.is_edit_mode = True - - def _delete_workflow(self): - filepath = ensure(self.model.custom.workflow).path - q = QMessageBox.question( - self, - _("Delete Workflow"), - _("Are you sure you want to delete the current workflow?") + f"\n{filepath}", - QMessageBox.Yes | QMessageBox.No, - QMessageBox.StandardButton.No, - ) - if q == QMessageBox.StandardButton.Yes: - self.model.custom.remove_workflow() - - def _open_webui(self): - if client := root.connection.client_if_connected: - QDesktopServices.openUrl(QUrl(client.url)) - - @property - def is_edit_mode(self): - return self._workflow_edit_widgets.isVisible() - - @is_edit_mode.setter - def is_edit_mode(self, value: bool): - if value == self.is_edit_mode: - return - self._workflow_select_widgets.setVisible(not value) - self._workflow_edit_widgets.setVisible(value) - if value: - self._workflow_name_edit.setText(self.model.custom.workflow_id) - self._workflow_name_edit.selectAll() - self._workflow_name_edit.setFocus() - - def _edit_name(self): - self._accept_name_button.setEnabled(self._workflow_name_edit.text().strip() != "") - - @popup_on_error - def _accept_name(self, *args): - self.model.custom.save_as(self._workflow_name_edit.text()) - self.is_edit_mode = False - - def _cancel_name(self): - self.is_edit_mode = False From b57cd9426dc9266c009391482315075a7fd11487 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 5 Oct 2024 20:25:46 +0200 Subject: [PATCH 15/28] Import for basic comfy UI workflow files (might not work when using bypass and similar) --- ai_diffusion/comfy_client.py | 2 +- ai_diffusion/comfy_workflow.py | 69 ++++- ai_diffusion/custom_workflow.py | 31 ++- tests/data/object_info.json | 241 ++++++++++++++++++ tests/data/workflow-api.json | 64 +++++ tests/data/workflow-ui.json | 438 ++++++++++++++++++++++++++++++++ tests/test_custom_workflow.py | 38 ++- 7 files changed, 856 insertions(+), 27 deletions(-) create mode 100644 tests/data/object_info.json create mode 100644 tests/data/workflow-api.json create mode 100644 tests/data/workflow-ui.json diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 5c26d1b1cc..38bee0ef2b 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -126,7 +126,7 @@ async def connect(url=default_url, access_token=""): # Check for required and optional model resources models = client.models - models.node_inputs = {name: nodes[name]["input"].get("required", None) for name in nodes} + models.node_inputs = {name: nodes[name]["input"] for name in nodes} available_resources = client.models.resources = {} clip_models = nodes["DualCLIPLoader"]["input"]["required"]["clip_name1"][0] diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index a986b1bcaa..09eb729986 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -64,17 +64,22 @@ def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server self.node_count = 0 self.sample_count = 0 self._cache: dict[str, Output | Output2 | Output3 | Output4] = {} - self._nodes_required_inputs: dict[str, dict[str, Any]] = node_inputs or {} + self._nodes_inputs: dict[str, dict[str, Any]] = node_inputs or {} self._run_mode: ComfyRunMode = run_mode @staticmethod - def import_graph(existing: dict): - w = ComfyWorkflow() + def import_graph(existing: dict, node_inputs: dict): + w = ComfyWorkflow(node_inputs) + existing = _convert_ui_workflow(existing, node_inputs) node_map: dict[str, str] = {} queue = list(existing.keys()) while queue: id = queue.pop(0) node = deepcopy(existing[id]) + if node_inputs and node["class_type"] not in node_inputs: + raise ValueError( + f"Workflow contains a node of type {node['class_type']} which is not installed on the ComfyUI server." + ) edges = [e for e in node["inputs"].values() if isinstance(e, list)] if any(e[0] not in node_map for e in edges): queue.append(id) # requeue node if an input is not yet mapped @@ -94,7 +99,7 @@ def from_dict(existing: dict): return w def add_default_values(self, node_name: str, args: dict): - if node_inputs := self._nodes_required_inputs.get(node_name, None): + if node_inputs := _inputs_for_node(self._nodes_inputs, node_name, "required"): for k, v in node_inputs.items(): if k not in args: if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0: @@ -834,3 +839,59 @@ def estimate_pose(self, image: Output, resolution: int): # use smaller model, but it requires onnxruntime, see #630 mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx" return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls) + + +def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""): + inputs = node_inputs.get(node_name) + if inputs is None: + return None + if filter: + return inputs.get(filter) + result = inputs.get("required", {}) + result.update(inputs.get("optional", {})) + return result + + +def _convert_ui_workflow(w: dict, node_inputs: dict): + version = w.get("version") + nodes = w.get("nodes") + links = w.get("links") + if not (version and nodes and links): + return w + + primitives = {} + for node in nodes: + if node["type"] == "PrimitiveNode": + primitives[node["id"]] = node["widgets_values"][0] + + r = {} + for node in nodes: + id = node["id"] + type = node["type"] + if type == "PrimitiveNode": + continue + + inputs = {} + fields = _inputs_for_node(node_inputs, type) + if fields is None: + raise ValueError( + f"Workflow uses node type {type}, but it is not installed on the ComfyUI server." + ) + widget_count = 0 + for field_name, field in fields.items(): + field_type = field[0] + if field_type in ["INT", "FLOAT", "BOOL", "STRING"] or isinstance(field_type, list): + inputs[field_name] = node["widgets_values"][widget_count] + widget_count += 1 + for connection in node["inputs"]: + if connection["name"] == field_name and connection["link"] is not None: + link = next(l for l in links if l[0] == connection["link"]) + prim = primitives.get(link[1]) + if prim is not None: + inputs[field_name] = prim + else: + inputs[field_name] = [link[1], link[2]] + break + r[id] = {"class_type": type, "inputs": inputs} + + return r diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index a6fdd3c06d..31c0c74b00 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -25,16 +25,18 @@ class CustomWorkflow: id: str source: WorkflowSource graph: dict + workflow: ComfyWorkflow path: Path | None = None + @staticmethod + def from_api(id: str, source: WorkflowSource, graph: dict, path: Path | None = None): + # doesn't work for UI workflow export (API workflow only) + return CustomWorkflow(id, source, graph, ComfyWorkflow.import_graph(graph, {}), path) + @property def name(self): return self.id.removesuffix(".json") - @property - def workflow(self): - return ComfyWorkflow.import_graph(self.graph) - class WorkflowCollection(QAbstractListModel): @@ -58,12 +60,20 @@ def __init__(self, connection: Connection, folder: Path | None = None): for wf in self._connection.workflows.keys(): self._process_remote_workflow(wf) + def _create_workflow( + self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None + ): + wf = ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) + return CustomWorkflow(id, source, graph, wf, path) + def _process_remote_workflow(self, id: str): - self._process(CustomWorkflow(id, WorkflowSource.remote, self._connection.workflows[id])) + graph = self._connection.workflows[id] + self._process(self._create_workflow(id, WorkflowSource.remote, graph)) def _process_file(self, file: Path): with file.open("r") as f: - self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file)) + graph = json.load(f) + self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file)) def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) @@ -94,6 +104,9 @@ def append(self, item: CustomWorkflow): self._workflows.append(item) self.endInsertRows() + def add_from_document(self, id: str, graph: dict): + self.append(self._create_workflow(id, WorkflowSource.document, graph)) + def remove(self, id: str): idx = self.find_index(id) if idx.isValid(): @@ -118,7 +131,7 @@ def save_as(self, id: str, graph: dict): self._folder.mkdir(exist_ok=True) path = self._folder / f"{id}.json" path.write_text(json.dumps(graph, indent=2)) - self.append(CustomWorkflow(id, WorkflowSource.local, graph, path)) + self.append(self._create_workflow(id, WorkflowSource.local, graph, path)) return id def import_file(self, filepath: Path): @@ -126,7 +139,7 @@ def import_file(self, filepath: Path): with filepath.open("r") as f: graph = json.load(f) try: - ComfyWorkflow.import_graph(graph) + ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) except Exception as e: raise RuntimeError(f"This is not a supported workflow file ({e})") return self.save_as(filepath.stem, graph) @@ -279,7 +292,7 @@ def _set_workflow_id(self, id: str): def set_graph(self, id: str, graph: dict): if self._workflows.find(id) is None: - self._workflows.append(CustomWorkflow(id, WorkflowSource.document, graph)) + self._workflows.add_from_document(id, graph) self.workflow_id = id def import_file(self, filepath: Path): diff --git a/tests/data/object_info.json b/tests/data/object_info.json new file mode 100644 index 0000000000..8ae9803922 --- /dev/null +++ b/tests/data/object_info.json @@ -0,0 +1,241 @@ +{ + "GrowMask": { + "input": { + "required": { + "mask": [ + "MASK" + ], + "expand": [ + "INT", + { + "default": 0, + "min": -16384, + "max": 16384, + "step": 1 + } + ], + "tapered_corners": [ + "BOOLEAN", + { + "default": true + } + ] + } + }, + "input_order": { + "required": [ + "mask", + "expand", + "tapered_corners" + ] + }, + "output": [ + "MASK" + ], + "output_is_list": [ + false + ], + "output_name": [ + "MASK" + ], + "name": "GrowMask", + "display_name": "GrowMask", + "description": "", + "python_module": "comfy_extras.nodes_mask", + "category": "mask", + "output_node": false + }, + "ImageUpscaleWithModel": { + "input": { + "required": { + "upscale_model": [ + "UPSCALE_MODEL" + ], + "image": [ + "IMAGE" + ] + } + }, + "input_order": { + "required": [ + "upscale_model", + "image" + ] + }, + "output": [ + "IMAGE" + ], + "output_is_list": [ + false + ], + "output_name": [ + "IMAGE" + ], + "name": "ImageUpscaleWithModel", + "display_name": "Upscale Image (using Model)", + "description": "", + "python_module": "comfy_extras.nodes_upscale_model", + "category": "image/upscaling", + "output_node": false + }, + "ETN_ApplyMaskToImage": { + "input": { + "required": { + "image": [ + "IMAGE" + ], + "mask": [ + "MASK" + ] + } + }, + "input_order": { + "required": [ + "image", + "mask" + ] + }, + "output": [ + "IMAGE" + ], + "output_is_list": [ + false + ], + "output_name": [ + "IMAGE" + ], + "name": "ETN_ApplyMaskToImage", + "display_name": "Apply Mask to Image", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "external_tooling", + "output_node": false + }, + "UpscaleModelLoader": { + "input": { + "required": { + "model_name": [ + [ + "4x_NMKD-Superscale-SP_178000_G.pth", + "OmniSR_X2_DIV2K.safetensors", + "OmniSR_X3_DIV2K.safetensors", + "OmniSR_X4_DIV2K.safetensors" + ] + ] + } + }, + "input_order": { + "required": [ + "model_name" + ] + }, + "output": [ + "UPSCALE_MODEL" + ], + "output_is_list": [ + false + ], + "output_name": [ + "UPSCALE_MODEL" + ], + "name": "UpscaleModelLoader", + "display_name": "Load Upscale Model", + "description": "", + "python_module": "comfy_extras.nodes_upscale_model", + "category": "loaders", + "output_node": false + }, + "ETN_KritaCanvas": { + "input": {}, + "input_order": {}, + "output": [ + "IMAGE", + "INT", + "INT", + "INT" + ], + "output_is_list": [ + false, + false, + false, + false + ], + "output_name": [ + "image", + "width", + "height", + "seed" + ], + "name": "ETN_KritaCanvas", + "display_name": "Krita Canvas", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": false + }, + "ETN_KritaOutput": { + "input": { + "required": { + "images": [ + "IMAGE" + ], + "format": [ + [ + "PNG", + "JPEG" + ], + { + "default": "PNG" + } + ] + } + }, + "input_order": { + "required": [ + "images", + "format" + ] + }, + "output": [], + "output_is_list": [], + "output_name": [], + "name": "ETN_KritaOutput", + "display_name": "Krita Output", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": true + }, + "ETN_KritaMaskLayer": { + "input": { + "required": { + "name": [ + "STRING", + { + "default": "Mask" + } + ] + } + }, + "input_order": { + "required": [ + "name" + ] + }, + "output": [ + "MASK" + ], + "output_is_list": [ + false + ], + "output_name": [ + "mask" + ], + "name": "ETN_KritaMaskLayer", + "display_name": "Krita Mask Layer", + "description": "", + "python_module": "custom_nodes.comfyui-tooling-nodes", + "category": "krita", + "output_node": false + } +} \ No newline at end of file diff --git a/tests/data/workflow-api.json b/tests/data/workflow-api.json new file mode 100644 index 0000000000..8dfba8fc14 --- /dev/null +++ b/tests/data/workflow-api.json @@ -0,0 +1,64 @@ +{ + "0": { + "class_type": "UpscaleModelLoader", + "inputs": { + "model_name": "4x_NMKD-Superscale-SP_178000_G.pth" + } + }, + "1": { + "class_type": "ETN_KritaCanvas", + "inputs": {} + }, + "2": { + "class_type": "ETN_KritaMaskLayer", + "inputs": { + "name": "Zauber" + } + }, + "3": { + "class_type": "GrowMask", + "inputs": { + "mask": [ + "2", + 0 + ], + "expand": 4 + } + }, + "4": { + "class_type": "ImageUpscaleWithModel", + "inputs": { + "upscale_model": [ + "0", + 0 + ], + "image": [ + "1", + 0 + ] + } + }, + "5": { + "class_type": "ETN_ApplyMaskToImage", + "inputs": { + "image": [ + "4", + 0 + ], + "mask": [ + "3", + 0 + ] + } + }, + "6": { + "class_type": "ETN_KritaOutput", + "inputs": { + "images": [ + "5", + 0 + ], + "format": "PNG" + } + } +} \ No newline at end of file diff --git a/tests/data/workflow-ui.json b/tests/data/workflow-ui.json new file mode 100644 index 0000000000..7ade8ed334 --- /dev/null +++ b/tests/data/workflow-ui.json @@ -0,0 +1,438 @@ +{ + "last_node_id": 10, + "last_link_id": 10, + "nodes": [ + { + "id": 3, + "type": "GrowMask", + "pos": { + "0": 448, + "1": 57 + }, + "size": [ + 315, + 82 + ], + "flags": {}, + "order": 5, + "mode": 0, + "inputs": [ + { + "name": "mask", + "type": "MASK", + "link": 3 + }, + { + "name": "expand", + "type": "INT", + "link": 2, + "widget": { + "name": "expand" + } + } + ], + "outputs": [ + { + "name": "MASK", + "type": "MASK", + "links": [ + 10 + ], + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "GrowMask" + }, + "widgets_values": [ + 4, + true + ] + }, + { + "id": 4, + "type": "PrimitiveNode", + "pos": { + "0": 138, + "1": 156 + }, + "size": [ + 210, + 82 + ], + "flags": {}, + "order": 0, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "INT", + "type": "INT", + "links": [ + 2 + ], + "slot_index": 0, + "widget": { + "name": "expand" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + 4, + "fixed" + ] + }, + { + "id": 8, + "type": "PrimitiveNode", + "pos": { + "0": -280, + "1": -370 + }, + "size": [ + 364.38086885107873, + 106 + ], + "flags": {}, + "order": 1, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "COMBO", + "type": "COMBO", + "links": [ + 4 + ], + "slot_index": 0, + "widget": { + "name": "model_name" + } + } + ], + "properties": { + "Run widget replace on values": false + }, + "widgets_values": [ + "4x_NMKD-Superscale-SP_178000_G.pth", + "fixed", + "" + ] + }, + { + "id": 9, + "type": "ImageUpscaleWithModel", + "pos": { + "0": 470, + "1": -370 + }, + "size": { + "0": 340.20001220703125, + "1": 46 + }, + "flags": {}, + "order": 6, + "mode": 0, + "inputs": [ + { + "name": "upscale_model", + "type": "UPSCALE_MODEL", + "link": 5 + }, + { + "name": "image", + "type": "IMAGE", + "link": 6 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 9 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ImageUpscaleWithModel" + } + }, + { + "id": 10, + "type": "ETN_ApplyMaskToImage", + "pos": { + "0": 870, + "1": -190 + }, + "size": { + "0": 239.40000915527344, + "1": 46 + }, + "flags": {}, + "order": 7, + "mode": 0, + "inputs": [ + { + "name": "image", + "type": "IMAGE", + "link": 9 + }, + { + "name": "mask", + "type": "MASK", + "link": 10 + } + ], + "outputs": [ + { + "name": "IMAGE", + "type": "IMAGE", + "links": [ + 8 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ETN_ApplyMaskToImage" + } + }, + { + "id": 6, + "type": "UpscaleModelLoader", + "pos": { + "0": 120, + "1": -370 + }, + "size": [ + 315, + 58 + ], + "flags": {}, + "order": 4, + "mode": 0, + "inputs": [ + { + "name": "model_name", + "type": "COMBO", + "link": 4, + "widget": { + "name": "model_name" + } + } + ], + "outputs": [ + { + "name": "UPSCALE_MODEL", + "type": "UPSCALE_MODEL", + "links": [ + 5 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "UpscaleModelLoader" + }, + "widgets_values": [ + "4x_NMKD-Superscale-SP_178000_G.pth" + ] + }, + { + "id": 1, + "type": "ETN_KritaCanvas", + "pos": { + "0": 205, + "1": -247 + }, + "size": [ + 200, + 100 + ], + "flags": {}, + "order": 2, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "image", + "type": "IMAGE", + "links": [ + 6 + ], + "shape": 3, + "slot_index": 0 + }, + { + "name": "width", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "height", + "type": "INT", + "links": null, + "shape": 3 + }, + { + "name": "seed", + "type": "INT", + "links": null, + "shape": 3 + } + ], + "properties": { + "Node name for S&R": "ETN_KritaCanvas" + } + }, + { + "id": 2, + "type": "ETN_KritaOutput", + "pos": { + "0": 1140, + "1": -190 + }, + "size": [ + 200, + 120 + ], + "flags": {}, + "order": 8, + "mode": 0, + "inputs": [ + { + "name": "images", + "type": "IMAGE", + "link": 8 + } + ], + "outputs": [], + "properties": { + "Node name for S&R": "ETN_KritaOutput" + }, + "widgets_values": [ + "PNG" + ] + }, + { + "id": 5, + "type": "ETN_KritaMaskLayer", + "pos": { + "0": 41, + "1": 11 + }, + "size": { + "0": 315, + "1": 58 + }, + "flags": {}, + "order": 3, + "mode": 0, + "inputs": [], + "outputs": [ + { + "name": "mask", + "type": "MASK", + "links": [ + 3 + ], + "shape": 3, + "slot_index": 0 + } + ], + "properties": { + "Node name for S&R": "ETN_KritaMaskLayer" + }, + "widgets_values": [ + "Zauber" + ] + } + ], + "links": [ + [ + 2, + 4, + 0, + 3, + 1, + "INT" + ], + [ + 3, + 5, + 0, + 3, + 0, + "MASK" + ], + [ + 4, + 8, + 0, + 6, + 0, + "COMBO" + ], + [ + 5, + 6, + 0, + 9, + 0, + "UPSCALE_MODEL" + ], + [ + 6, + 1, + 0, + 9, + 1, + "IMAGE" + ], + [ + 8, + 10, + 0, + 2, + 0, + "IMAGE" + ], + [ + 9, + 9, + 0, + 10, + 0, + "IMAGE" + ], + [ + 10, + 3, + 0, + 10, + 1, + "MASK" + ] + ], + "groups": [], + "config": {}, + "extra": { + "ds": { + "scale": 1.0834705943388394, + "offset": [ + 311.97951538052513, + 487.01874472527845 + ] + } + }, + "version": 0.4 +} \ No newline at end of file diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 6ce33a9c91..46b86b2c60 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -12,6 +12,8 @@ from ai_diffusion.image import Image, Extent from ai_diffusion import workflow +from .config import test_dir + def test_collection(tmp_path: Path): file1 = tmp_path / "file1.json" @@ -27,13 +29,13 @@ def test_collection(tmp_path: Path): collection = WorkflowCollection(connection, tmp_path) assert len(collection) == 3 - assert collection.find("file1") == CustomWorkflow( + assert collection.find("file1") == CustomWorkflow.from_api( "file1", WorkflowSource.local, {"file": 1}, file1 ) - assert collection.find("file2") == CustomWorkflow( + assert collection.find("file2") == CustomWorkflow.from_api( "file2", WorkflowSource.local, {"file": 2}, file2 ) - assert collection.find("connection1") == CustomWorkflow( + assert collection.find("connection1") == CustomWorkflow.from_api( "connection1", WorkflowSource.remote, {"connection": 1} ) @@ -56,17 +58,17 @@ def on_data_changed(start, end): connection.workflow_published.emit("connection2") assert len(collection) == 4 - assert collection.find("connection2") == CustomWorkflow( + assert collection.find("connection2") == CustomWorkflow.from_api( "connection2", WorkflowSource.remote, {"connection": 2} ) collection.set_graph(collection.index(0), {"file": 3}) - assert collection.find("file1") == CustomWorkflow( + assert collection.find("file1") == CustomWorkflow.from_api( "file1", WorkflowSource.local, {"file": 3}, file1 ) assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] - collection.append(CustomWorkflow("doc1", WorkflowSource.document, {"doc": 1})) + collection.add_from_document("doc1", {"doc": 1}) sorted = SortedWorkflows(collection) assert sorted[0].source is WorkflowSource.document @@ -158,18 +160,28 @@ def test_workspace(): def test_import(): - w = ComfyWorkflow.import_graph( - { - "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, - "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, - "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, - } - ) + graph = { + "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, + "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, + "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, + } + w = ComfyWorkflow.import_graph(graph, {}) assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) +def test_import_ui_workflow(): + graph = json.loads((test_dir / "data" / "workflow-ui.json").read_text()) + object_info = json.loads((test_dir / "data" / "object_info.json").read_text()) + node_inputs = {k: v.get("input") for k, v in object_info.items()} + result = ComfyWorkflow.import_graph(graph, node_inputs) + + expected_graph = json.loads((test_dir / "data" / "workflow-api.json").read_text()) + expected = ComfyWorkflow.import_graph(expected_graph, {}) + assert result.root == expected.root + + def test_parameters(): w = ComfyWorkflow() w.add("ETN_IntParameter", 1, name="int", default=4, min=0, max=10) From 5836469d980f387482632ea14a2bf4595cd8a983 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 5 Oct 2024 22:28:03 +0200 Subject: [PATCH 16/28] Don't store graph as dict and fix tests --- ai_diffusion/comfy_workflow.py | 3 ++ ai_diffusion/custom_workflow.py | 24 +++++++------- tests/test_custom_workflow.py | 58 +++++++++++++++++++-------------- 3 files changed, 49 insertions(+), 36 deletions(-) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 09eb729986..96f82b4cf9 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -859,6 +859,9 @@ def _convert_ui_workflow(w: dict, node_inputs: dict): if not (version and nodes and links): return w + if not node_inputs: + raise ValueError("An active ComfyUI connection is required to convert a UI workflow file.") + primitives = {} for node in nodes: if node["type"] == "PrimitiveNode": diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 31c0c74b00..5d127bc252 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -24,15 +24,9 @@ class WorkflowSource(Enum): class CustomWorkflow: id: str source: WorkflowSource - graph: dict workflow: ComfyWorkflow path: Path | None = None - @staticmethod - def from_api(id: str, source: WorkflowSource, graph: dict, path: Path | None = None): - # doesn't work for UI workflow export (API workflow only) - return CustomWorkflow(id, source, graph, ComfyWorkflow.import_graph(graph, {}), path) - @property def name(self): return self.id.removesuffix(".json") @@ -46,6 +40,7 @@ class WorkflowCollection(QAbstractListModel): def __init__(self, connection: Connection, folder: Path | None = None): super().__init__() + self._connection = connection self._workflows: list[CustomWorkflow] = [] self._folder = folder or user_data_dir / "workflows" @@ -55,16 +50,20 @@ def __init__(self, connection: Connection, folder: Path | None = None): except Exception as e: log.exception(f"Error loading workflow from {file}: {e}") - self._connection = connection self._connection.workflow_published.connect(self._process_remote_workflow) for wf in self._connection.workflows.keys(): self._process_remote_workflow(wf) + def _node_inputs(self): + if client := self._connection.client_if_connected: + return client.models.node_inputs + return {} + def _create_workflow( self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None ): - wf = ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) - return CustomWorkflow(id, source, graph, wf, path) + wf = ComfyWorkflow.import_graph(graph, self._node_inputs()) + return CustomWorkflow(id, source, wf, path) def _process_remote_workflow(self, id: str): graph = self._connection.workflows[id] @@ -78,7 +77,7 @@ def _process_file(self, file: Path): def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) if idx.isValid(): - self.set_graph(idx, workflow.graph) + self.dataChanged.emit(idx, idx) else: self.append(workflow) @@ -118,7 +117,8 @@ def remove(self, id: str): self.endRemoveRows() def set_graph(self, index: QModelIndex, graph: dict): - self._workflows[index.row()].graph = graph + wf = self._workflows[index.row()] + wf.workflow = ComfyWorkflow.import_graph(graph, self._node_inputs()) self.dataChanged.emit(index, index) def save_as(self, id: str, graph: dict): @@ -139,7 +139,7 @@ def import_file(self, filepath: Path): with filepath.open("r") as f: graph = json.load(f) try: - ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs) + ComfyWorkflow.import_graph(graph, self._node_inputs()) except Exception as e: raise RuntimeError(f"This is not a supported workflow file ({e})") return self.save_as(filepath.stem, graph) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 46b86b2c60..19cba9f0d8 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -15,29 +15,41 @@ from .config import test_dir +def _assert_has_workflow( + collection: WorkflowCollection, + name: str, + source: WorkflowSource, + graph: dict, + file: Path | None = None, +): + workflow = collection.find(name) + assert ( + workflow is not None + and workflow.source == source + and workflow.workflow.root == graph + and workflow.path == file + ) + + def test_collection(tmp_path: Path): file1 = tmp_path / "file1.json" - file1.write_text('{"file": 1}') + file1_graph = {"0": {"class_type": "F1", "inputs": {}}} + file1.write_text(json.dumps(file1_graph)) + file2 = tmp_path / "file2.json" - file2.write_text('{"file": 2}') + file2_graph = {"0": {"class_type": "F2", "inputs": {}}} + file2.write_text(json.dumps(file2_graph)) connection = Connection() - connection_workflows = { - "connection1": {"connection": 1}, - } + connection_graph = {"0": {"class_type": "C1", "inputs": {}}} + connection_workflows = {"connection1": connection_graph} connection._workflows = connection_workflows collection = WorkflowCollection(connection, tmp_path) assert len(collection) == 3 - assert collection.find("file1") == CustomWorkflow.from_api( - "file1", WorkflowSource.local, {"file": 1}, file1 - ) - assert collection.find("file2") == CustomWorkflow.from_api( - "file2", WorkflowSource.local, {"file": 2}, file2 - ) - assert collection.find("connection1") == CustomWorkflow.from_api( - "connection1", WorkflowSource.remote, {"connection": 1} - ) + _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1) + _assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2) + _assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph) events = [] @@ -54,21 +66,19 @@ def on_data_changed(start, end): collection.rowsInserted.connect(on_end_insert) collection.dataChanged.connect(on_data_changed) - connection_workflows["connection2"] = {"connection": 2} + connection2_graph = {"0": {"class_type": "C2", "inputs": {}}} + connection_workflows["connection2"] = connection2_graph connection.workflow_published.emit("connection2") assert len(collection) == 4 - assert collection.find("connection2") == CustomWorkflow.from_api( - "connection2", WorkflowSource.remote, {"connection": 2} - ) + _assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph) - collection.set_graph(collection.index(0), {"file": 3}) - assert collection.find("file1") == CustomWorkflow.from_api( - "file1", WorkflowSource.local, {"file": 3}, file1 - ) + file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}} + collection.set_graph(collection.index(0), file1_graph_changed) + _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1) assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] - collection.add_from_document("doc1", {"doc": 1}) + collection.add_from_document("doc1", {"0": {"class_type": "D1", "inputs": {}}}) sorted = SortedWorkflows(collection) assert sorted[0].source is WorkflowSource.document @@ -102,7 +112,7 @@ def test_files(tmp_path: Path): collection.import_file(file1) assert collection.find("file1 (1)") is not None - collection.save_as("file1", {"file": 2}) + collection.save_as("file1", make_dummy_graph(77)) assert collection.find("file1 (2)") is not None files = [ From 0ba901a5e9b35738906c8cf92816dcab46c76925 Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 7 Oct 2024 09:24:06 +0200 Subject: [PATCH 17/28] Better error message when failing to read transferred images --- ai_diffusion/image.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ai_diffusion/image.py b/ai_diffusion/image.py index 82e35053df..d06768bc2b 100644 --- a/ai_diffusion/image.py +++ b/ai_diffusion/image.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum from math import ceil, sqrt -from PyQt5.QtGui import QImage, QImageWriter, QPixmap, QIcon, QPainter, QColorSpace +from PyQt5.QtGui import QImage, QImageWriter, QImageReader, QPixmap, QIcon, QPainter, QColorSpace from PyQt5.QtGui import qRgba, qRed, qGreen, qBlue, qAlpha, qGray from PyQt5.QtCore import Qt, QByteArray, QBuffer, QRect, QSize, QFile, QIODevice from typing import Callable, Iterable, SupportsIndex, Tuple, NamedTuple, Union, Optional @@ -642,10 +642,11 @@ def from_bytes(data: QByteArray | bytes, offsets: list[int]): for i, offset in enumerate(offsets): buffer.seek(offset) img = QImage() - if img.load(buffer, None): + loader = QImageReader(buffer) + if loader.read(img): images.append(Image(img)) else: - raise Exception(f"Failed to load image {i} from buffer") + raise Exception(f"Failed to load image {i} from buffer: {loader.errorString()}") buffer.close() return images From 62e4c7420cdb51e95d6b2f60ea462adc8f4ac85e Mon Sep 17 00:00:00 2001 From: Acly Date: Mon, 7 Oct 2024 12:00:06 +0200 Subject: [PATCH 18/28] Support generalized custom parameter node and choice (combo box) parameters --- ai_diffusion/comfy_workflow.py | 17 ++++++-- ai_diffusion/custom_workflow.py | 40 +++++++++++++----- ai_diffusion/ui/custom_workflow.py | 40 +++++++++++++++--- ai_diffusion/workflow.py | 7 +--- tests/test_custom_workflow.py | 65 +++++++++++++++++++++--------- 5 files changed, 123 insertions(+), 46 deletions(-) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 96f82b4cf9..4faadd9b06 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -36,10 +36,10 @@ class ComfyNode(NamedTuple): def input(self, key: str, default: T) -> T: ... @overload - def input(self, key: str, default: None = None) -> Input: ... + def input(self, key: str, default: None = None) -> Input | None: ... - def input(self, key: str, default: T | None = None) -> T | Input: - result = self.inputs[key] + def input(self, key: str, default: T | None = None) -> T | Input | None: + result = self.inputs.get(key, default) assert ( default is None or type(result) == type(default) @@ -112,6 +112,11 @@ def add_default_values(self, node_name: str, args: dict): args[k] = default return args + def input_type(self, class_type: str, input_name: str) -> tuple | None: + if inputs := _inputs_for_node(self._nodes_inputs, class_type): + return inputs.get(input_name) + return None + def dump(self, filepath: str | Path): filepath = Path(filepath) if filepath.suffix != ".json": @@ -173,6 +178,12 @@ def copy(self, node: ComfyNode): def find(self, type: str): return (self.node(int(k)) for k, v in self.root.items() if v["class_type"] == type) + def find_connected(self, output: Output): + for node in self: + for input_name, input_value in node.inputs.items(): + if input_value == output: + yield node, input_name + def __iter__(self): return iter(self.node(int(k)) for k in self.root.keys()) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 5d127bc252..3b6f5490b2 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -77,6 +77,7 @@ def _process_file(self, file: Path): def _process(self, workflow: CustomWorkflow): idx = self.find_index(workflow.id) if idx.isValid(): + self._workflows[idx.row()] = workflow self.dataChanged.emit(idx, idx) else: self.append(workflow) @@ -196,10 +197,11 @@ class ParamKind(Enum): mask_layer = 1 number_int = 2 number_float = 3 - boolean = 4 + toggle = 4 text = 5 prompt_positive = 6 prompt_negative = 7 + choice = 8 class CustomParam(NamedTuple): @@ -208,44 +210,60 @@ class CustomParam(NamedTuple): default: Any | None = None min: int | float | None = None max: int | float | None = None + choices: list[str] | None = None def workflow_parameters(w: ComfyWorkflow): + text_types = ("text", "prompt (positive)", "prompt (negative)") for node in w: - match node.type: - case "ETN_KritaImageLayer": + match (node.type, node.input("type", "")): + case ("ETN_KritaImageLayer", _): name = node.input("name", "Image") yield CustomParam(ParamKind.image_layer, name) - case "ETN_KritaMaskLayer": + case ("ETN_KritaMaskLayer", _): name = node.input("name", "Mask") yield CustomParam(ParamKind.mask_layer, name) - case "ETN_IntParameter": + case ("ETN_Parameter", "number (integer)"): name = node.input("name", "Parameter") default = node.input("default", 0) min = node.input("min", -(2**31)) max = node.input("max", 2**31) yield CustomParam(ParamKind.number_int, name, default=default, min=min, max=max) - case "ETN_NumberParameter": + case ("ETN_Parameter", "number"): name = node.input("name", "Parameter") default = node.input("default", 0.0) min = node.input("min", 0.0) max = node.input("max", 1.0) yield CustomParam(ParamKind.number_float, name, default=default, min=min, max=max) - case "ETN_BoolParameter": + case ("ETN_Parameter", "toggle"): name = node.input("name", "Parameter") default = node.input("default", False) - yield CustomParam(ParamKind.boolean, name, default=default) - case "ETN_TextParameter": + yield CustomParam(ParamKind.toggle, name, default=default) + case ("ETN_Parameter", type) if type in text_types: name = node.input("name", "Parameter") default = node.input("default", "") - type = node.input("type", "general") match type: - case "general": + case "text": yield CustomParam(ParamKind.text, name, default=default) case "prompt (positive)": yield CustomParam(ParamKind.prompt_positive, name, default=default) case "prompt (negative)": yield CustomParam(ParamKind.prompt_negative, name, default=default) + case ("ETN_Parameter", "choice"): + name = node.input("name", "Parameter") + default = node.input("default", "") + connected, input_name = next(w.find_connected(node.output()), (None, "")) + if connected: + if input_type := w.input_type(connected.type, input_name): + if isinstance(input_type[0], list): + yield CustomParam( + ParamKind.choice, name, choices=input_type[0], default=default + ) + else: + yield CustomParam(ParamKind.text, name, default=default) + case ("ETN_Parameter", unknown_type): + unknown = node.input("name", "?") + ": " + unknown_type + log.warning(f"Custom workflow has an unsupported parameter type {unknown}") class CustomWorkspace(QObject, ObservableProperties): diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index 2a86efa531..b12437139d 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -14,7 +14,7 @@ from ..properties import Binding, Bind, bind, bind_combo from ..root import root from ..localization import translate as _ -from ..util import ensure +from ..util import ensure, clamp from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label from .switch import SwitchWidget from .widget import TextPromptWidget, WorkspaceSelectWidget @@ -80,7 +80,6 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): assert param.min is not None and param.max is not None and param.default is not None if param.max - param.min <= 200: self._widget = QSlider(Qt.Orientation.Horizontal, parent) - self._widget.setRange(int(param.min), int(param.max)) self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) @@ -90,12 +89,14 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): layout.addWidget(self._label) else: self._widget = QSpinBox(parent) - self._widget.setRange(int(param.min), int(param.max)) self._widget.valueChanged.connect(self._notify) self._label = None - layout = QHBoxLayout(self) layout.addWidget(self._widget) + min_range = clamp(int(param.min), -(2**31), 2**31 - 1) + max_range = clamp(int(param.max), -(2**31), 2**31 - 1) + self._widget.setRange(min_range, max_range) + self.value = param.default def _notify(self): @@ -139,7 +140,6 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._widget.setRange(param.min, param.max) self._widget.valueChanged.connect(self._notify) self._label = None - layout = QHBoxLayout(self) layout.addWidget(self._widget) self.value = param.default @@ -248,6 +248,31 @@ def value(self, value: str): self.text = value +class ChoiceParamWidget(QComboBox): + value_changed = pyqtSignal() + + def __init__(self, param: CustomParam, parent: QWidget | None = None): + super().__init__(parent) + if param.choices: + self.addItems(param.choices) + self.currentIndexChanged.connect(lambda _: self.value_changed.emit()) + + if param.default is not None: + self.value = param.default + + @property + def value(self) -> str: + if self.currentIndex() == -1: + return "" + return self.currentText() + + @value.setter + def value(self, value: str): + i = self.findText(value) + if i != -1 and i != self.currentIndex(): + self.setCurrentIndex(i) + + CustomParamWidget = ( LayerSelect | IntParamWidget @@ -255,6 +280,7 @@ def value(self, value: str): | BoolParamWidget | TextParamWidget | PromptParamWidget + | ChoiceParamWidget ) @@ -267,12 +293,14 @@ def _create_param_widget(param: CustomParam, parent: QWidget): return IntParamWidget(param, parent) if param.kind is ParamKind.number_float: return FloatParamWidget(param, parent) - if param.kind is ParamKind.boolean: + if param.kind is ParamKind.toggle: return BoolParamWidget(param, parent) if param.kind is ParamKind.text: return TextParamWidget(param, parent) if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]: return PromptParamWidget(param, parent) + if param.kind is ParamKind.choice: + return ChoiceParamWidget(param, parent) assert False, f"Unknown param kind: {param.kind}" diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 12cfe13a17..3591225073 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1082,12 +1082,7 @@ def get_param(node: ComfyNode, expected_type: type | None = None): outputs[node.output(0)] = w.load_image(get_param(node, Image)) case "ETN_KritaMaskLayer": outputs[node.output(0)] = w.load_mask(get_param(node, Image)) - case ( - "ETN_IntParameter" - | "ETN_NumberParameter" - | "ETN_BoolParameter" - | "ETN_TextParameter" - ): + case "ETN_Parameter": outputs[node.output(0)] = get_param(node) case _: mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()} diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 19cba9f0d8..03312c2bc0 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -91,8 +91,14 @@ def on_data_changed(start, end): def make_dummy_graph(n: int = 42): return { "1": { - "class_type": "ETN_IntParameter", - "inputs": {"name": "param1", "default": n, "min": 5, "max": 95}, + "class_type": "ETN_Parameter", + "inputs": { + "name": "param1", + "type": "number (integer)", + "default": n, + "min": 5, + "max": 95, + }, } } @@ -141,27 +147,33 @@ def test_workspace(): workspace = CustomWorkspace(workflows) assert workspace.workflow_id == "connection1" assert workspace.workflow and workspace.workflow.id == "connection1" - assert workspace.graph and workspace.graph.node(0).type == "ETN_IntParameter" + assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter" assert workspace.metadata[0].name == "param1" assert workspace.params == {"param1": 42} doc_graph = { "1": { - "class_type": "ETN_IntParameter", - "inputs": {"name": "param2", "default": 23, "min": 5, "max": 95}, + "class_type": "ETN_Parameter", + "inputs": { + "name": "param2", + "type": "number (integer)", + "default": 23, + "min": 5, + "max": 95, + }, } } workspace.set_graph("doc1", doc_graph) assert workspace.workflow_id == "doc1" assert workspace.workflow and workspace.workflow.source is WorkflowSource.document - assert workspace.graph and workspace.graph.node(0).type == "ETN_IntParameter" + assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter" assert workspace.metadata[0].name == "param2" assert workspace.params == {"param2": 23} doc_graph["1"]["inputs"]["default"] = 24 doc_graph["2"] = { - "class_type": "ETN_IntParameter", - "inputs": {"name": "param3", "default": 7, "min": 0, "max": 10}, + "class_type": "ETN_Parameter", + "inputs": {"name": "param3", "type": "number (integer)", "default": 7, "min": 0, "max": 10}, } workflows.set_graph(workflows.index(1), doc_graph) assert workspace.metadata[0].default == 24 @@ -193,23 +205,30 @@ def test_import_ui_workflow(): def test_parameters(): - w = ComfyWorkflow() - w.add("ETN_IntParameter", 1, name="int", default=4, min=0, max=10) - w.add("ETN_BoolParameter", 1, name="bool", default=True) - w.add("ETN_NumberParameter", 1, name="number", default=1.2, min=0.0, max=10.0) - w.add("ETN_TextParameter", 1, name="text", type="general", default="mouse") - w.add("ETN_TextParameter", 1, name="positive", type="prompt (positive)", default="p") - w.add("ETN_TextParameter", 1, name="negative", type="prompt (negative)", default="n") + node_inputs = {"ChoiceNode": {"required": {"choice_param": (["a", "b", "c"],)}}} + + w = ComfyWorkflow(node_inputs=node_inputs) + w.add("ETN_Parameter", 1, name="int", type="number (integer)", default=4, min=0, max=10) + w.add("ETN_Parameter", 1, name="bool", type="toggle", default=True) + w.add("ETN_Parameter", 1, name="number", type="number", default=1.2, min=0.0, max=10.0) + w.add("ETN_Parameter", 1, name="text", type="text", default="mouse") + w.add("ETN_Parameter", 1, name="positive", type="prompt (positive)", default="p") + w.add("ETN_Parameter", 1, name="negative", type="prompt (negative)", default="n") + w.add("ETN_Parameter", 1, name="choice_unconnected", type="choice", default="z") + choice_param = w.add("ETN_Parameter", 1, name="choice", type="choice", default="c") + w.add("ChoiceNode", 1, choice_param=choice_param) w.add("ETN_KritaImageLayer", 1, name="image") w.add("ETN_KritaMaskLayer", 1, name="mask") assert list(workflow_parameters(w)) == [ CustomParam(ParamKind.number_int, "int", 4, 0, 10), - CustomParam(ParamKind.boolean, "bool", True), + CustomParam(ParamKind.toggle, "bool", True), CustomParam(ParamKind.number_float, "number", 1.2, 0.0, 10.0), CustomParam(ParamKind.text, "text", "mouse"), CustomParam(ParamKind.prompt_positive, "positive", "p"), CustomParam(ParamKind.prompt_negative, "negative", "n"), + CustomParam(ParamKind.text, "choice_unconnected", "z"), + CustomParam(ParamKind.choice, "choice", "c", choices=["a", "b", "c"]), CustomParam(ParamKind.image_layer, "image"), CustomParam(ParamKind.mask_layer, "mask"), ] @@ -220,10 +239,13 @@ def test_expand(): in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4) scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height) ext.add("ETN_KritaOutput", 1, images=scaled) - inty = ext.add("ETN_IntParameter", 1, name="inty", default=4, min=0, max=10) - numby = ext.add("ETN_NumberParameter", 1, name="numby", default=1.2, min=0.0, max=10.0) - texty = ext.add("ETN_TextParameter", 1, name="texty", type="general", default="mouse") - booly = ext.add("ETN_BoolParameter", 1, name="booly", default=True) + inty = ext.add( + "ETN_Parameter", 1, name="inty", type="number (integer)", default=4, min=0, max=10 + ) + numby = ext.add("ETN_Parameter", 1, name="numby", type="number", default=1.2, min=0.0, max=10.0) + texty = ext.add("ETN_Parameter", 1, name="texty", type="text", default="mouse") + booly = ext.add("ETN_Parameter", 1, name="booly", type="toggle", default=True) + choicy = ext.add("ETN_Parameter", 1, name="choicy", type="choice", default="c") layer_img = ext.add("ETN_KritaImageLayer", 1, name="layer_img") layer_mask = ext.add("ETN_KritaMaskLayer", 1, name="layer_mask") ext.add( @@ -234,6 +256,7 @@ def test_expand(): numby=numby, texty=texty, booly=booly, + choicy=choicy, layer_img=layer_img, layer_mask=layer_mask, ) @@ -243,6 +266,7 @@ def test_expand(): "numby": 3.4, "texty": "cat", "booly": False, + "choicy": "b", "layer_img": Image.create(Extent(4, 4), Qt.GlobalColor.black), "layer_mask": Image.create(Extent(4, 4), Qt.GlobalColor.white), } @@ -269,6 +293,7 @@ def test_expand(): "numby": 3.4, "texty": "cat", "booly": False, + "choicy": "b", "layer_img": Output(4, 0), "layer_mask": Output(5, 0), }, From 1428088e0f01cc319ba95e97ac2236a7c2b06f49 Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 8 Oct 2024 16:45:57 +0200 Subject: [PATCH 19/28] Support integrating styles into custom workflows --- ai_diffusion/custom_workflow.py | 4 +++ ai_diffusion/model.py | 12 +++++++-- ai_diffusion/ui/control.py | 28 ++++++++++----------- ai_diffusion/ui/custom_workflow.py | 37 +++++++++++++++++++++++++--- ai_diffusion/workflow.py | 29 ++++++++++++++++++---- tests/test_custom_workflow.py | 39 ++++++++++++++++++++++++++++-- 6 files changed, 122 insertions(+), 27 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 3b6f5490b2..dd7f3e66ea 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -202,6 +202,7 @@ class ParamKind(Enum): prompt_positive = 6 prompt_negative = 7 choice = 8 + style = 9 class CustomParam(NamedTuple): @@ -217,6 +218,9 @@ def workflow_parameters(w: ComfyWorkflow): text_types = ("text", "prompt (positive)", "prompt (negative)") for node in w: match (node.type, node.input("type", "")): + case ("ETN_KritaStyle", _): + name = node.input("name", "Style") + yield CustomParam(ParamKind.style, name, node.input("sampler_preset", "auto")) case ("ETN_KritaImageLayer", _): name = node.input("name", "Image") yield CustomParam(ParamKind.image_layer, name) diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 1b8dd11115..9bb75229b9 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -366,16 +366,24 @@ def generate_custom(self): params = copy(self.custom.params) for md in self.custom.metadata: + param = params.get(md.name) + assert param is not None, f"Parameter {md.name} not found" + if md.kind is ParamKind.image_layer: - layer = self.layers.find(QUuid(params[md.name])) + layer = self.layers.find(QUuid(param)) if layer is None: raise ValueError(f"Input layer for parameter {md.name} not found") params[md.name] = layer.get_pixels(bounds) elif md.kind is ParamKind.mask_layer: - layer = self.layers.find(QUuid(params[md.name])) + layer = self.layers.find(QUuid(param)) if layer is None: raise ValueError(f"Input layer for parameter {md.name} not found") params[md.name] = layer.get_mask(bounds) + elif md.kind is ParamKind.style: + style = Styles.list().find(str(param)) + if style is None: + raise ValueError(f"Style {param} not found") + params[md.name] = style input = WorkflowInput( WorkflowKind.custom, diff --git a/ai_diffusion/ui/control.py b/ai_diffusion/ui/control.py index 1e4af8a463..d86c0d3e63 100644 --- a/ai_diffusion/ui/control.py +++ b/ai_diffusion/ui/control.py @@ -16,16 +16,14 @@ class ControlWidget(QWidget): - _control_list: ControlLayerList - _control: ControlLayer - _connections: list[QMetaObject.Connection | Binding] def __init__( - self, control_list: ControlLayerList, control: ControlLayer, parent: ControlListWidget + self, control_list: ControlLayerList | None, control: ControlLayer, parent: QWidget ): super().__init__(parent) self._control_list = control_list self._control = control + self._connections: list[QMetaObject.Connection | Binding] = [] layout = QVBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -80,13 +78,6 @@ def __init__( self.expand_button.setChecked(False) self.expand_button.clicked.connect(self._toggle_extended) - self.remove_button = QToolButton(self) - self.remove_button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) - self.remove_button.setIcon(theme.icon("remove")) - self.remove_button.setToolTip(_("Remove control layer")) - self.remove_button.setAutoRaise(True) - self.remove_button.clicked.connect(self.remove) - bar_layout = QHBoxLayout() bar_layout.addWidget(self.mode_select) bar_layout.addWidget(self.layer_select, 3) @@ -95,9 +86,17 @@ def __init__( bar_layout.addWidget(self.preset_slider, 1) bar_layout.addWidget(self.error_text, 3) bar_layout.addWidget(self.expand_button) - bar_layout.addWidget(self.remove_button) layout.addLayout(bar_layout) + if self._control_list is not None: + self.remove_button = QToolButton(self) + self.remove_button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly) + self.remove_button.setIcon(theme.icon("remove")) + self.remove_button.setToolTip(_("Remove control layer")) + self.remove_button.setAutoRaise(True) + self.remove_button.clicked.connect(self.remove) + bar_layout.addWidget(self.remove_button) + line = QFrame(self) line.setObjectName("LeftIndent") line.setStyleSheet(f"#LeftIndent {{ color: {theme.line}; }}") @@ -198,12 +197,13 @@ def _update_layers(self): self.layer_select.addItem(layer.name, layer.id) if layer.id == self._control.layer_id: index = self.layer_select.count() - 1 - if index == -1 and self._control in self._control_list: + if index == -1 and self._control_list and self._control in self._control_list: self.remove() - else: + elif index >= 0: self.layer_select.setCurrentIndex(index) def remove(self): + assert self._control_list is not None self._control_list.remove(self._control) def resizeEvent(self, a0: QResizeEvent | None): diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index b12437139d..720acc42a0 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -12,12 +12,13 @@ from ..jobs import JobKind from ..model import Model from ..properties import Binding, Bind, bind, bind_combo +from ..style import Styles from ..root import root from ..localization import translate as _ from ..util import ensure, clamp from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label from .switch import SwitchWidget -from .widget import TextPromptWidget, WorkspaceSelectWidget +from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget from . import theme @@ -273,6 +274,31 @@ def value(self, value: str): self.setCurrentIndex(i) +class StyleParamWidget(QWidget): + value_changed = pyqtSignal() + + def __init__(self, parent: QWidget): + super().__init__(parent) + self._style_select = StyleSelectWidget(self) + self._style_select.value_changed.connect(self._notify) + layout = QHBoxLayout() + layout.setContentsMargins(0, 0, 0, 0) + layout.addWidget(self._style_select) + self.setLayout(layout) + + def _notify(self): + self.value_changed.emit() + + @property + def value(self): + return self._style_select.value.filename + + @value.setter + def value(self, value: str): + if style := Styles.list().find(value): + self._style_select.value = style + + CustomParamWidget = ( LayerSelect | IntParamWidget @@ -281,10 +307,11 @@ def value(self, value: str): | TextParamWidget | PromptParamWidget | ChoiceParamWidget + | StyleParamWidget ) -def _create_param_widget(param: CustomParam, parent: QWidget): +def _create_param_widget(param: CustomParam, parent: QWidget) -> CustomParamWidget: if param.kind is ParamKind.image_layer: return LayerSelect("image", parent) if param.kind is ParamKind.mask_layer: @@ -301,13 +328,15 @@ def _create_param_widget(param: CustomParam, parent: QWidget): return PromptParamWidget(param, parent) if param.kind is ParamKind.choice: return ChoiceParamWidget(param, parent) + if param.kind is ParamKind.style: + return StyleParamWidget(parent) assert False, f"Unknown param kind: {param.kind}" class WorkflowParamsWidget(QWidget): value_changed = pyqtSignal() - def __init__(self, params: list[CustomParam], parent: QWidget | None = None): + def __init__(self, params: list[CustomParam], parent: QWidget): super().__init__(parent) self._widgets: dict[str, CustomParamWidget] = {} @@ -321,7 +350,7 @@ def __init__(self, params: list[CustomParam], parent: QWidget | None = None): widget = _create_param_widget(p, self) widget.value_changed.connect(self._notify) row = len(self._widgets) - layout.addWidget(label, row, 0) + layout.addWidget(label, row, 0, Qt.AlignmentFlag.AlignBaseline) layout.addWidget(widget, row, 2) self._widgets[p.name] = widget diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index 3591225073..f197b23a65 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1044,13 +1044,17 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds): def expand_custom( - w: ComfyWorkflow, input: CustomWorkflowInput, images: ImageInput, sampling: SamplingInput + w: ComfyWorkflow, + input: CustomWorkflowInput, + images: ImageInput, + sampling: SamplingInput, + models: ClientModels, ): custom = ComfyWorkflow.from_dict(input.workflow) nodes: dict[int, int] = {} # map old node IDs to new node IDs outputs: dict[Output, Input] = {} - def map_input(input): + def map_input(input: Input): if isinstance(input, Output): mapped = outputs.get(input) if mapped is not None: @@ -1078,12 +1082,27 @@ def get_param(node: ComfyNode, expected_type: type | None = None): outputs[node.output(3)] = sampling.seed case "ETN_KritaSelection": outputs[node.output(0)] = w.load_mask(ensure(images.hires_mask)) + case "ETN_Parameter": + outputs[node.output(0)] = get_param(node) case "ETN_KritaImageLayer": outputs[node.output(0)] = w.load_image(get_param(node, Image)) case "ETN_KritaMaskLayer": outputs[node.output(0)] = w.load_mask(get_param(node, Image)) - case "ETN_Parameter": - outputs[node.output(0)] = get_param(node) + case "ETN_KritaStyle": + style: Style = get_param(node, Style) + is_live = node.input("sampler_preset", "auto") == "live" + checkpoint_input = style.get_models() + sampling = _sampling_from_style(style, 1.0, is_live) + model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models) + outputs[node.output(0)] = model + outputs[node.output(1)] = clip + outputs[node.output(2)] = vae + outputs[node.output(3)] = style.style_prompt + outputs[node.output(4)] = style.negative_prompt + outputs[node.output(5)] = sampling.sampler + outputs[node.output(6)] = sampling.scheduler + outputs[node.output(7)] = sampling.total_steps + outputs[node.output(8)] = sampling.cfg_scale case _: mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()} mapped = ComfyNode(node.id, node.type, mapped_inputs) @@ -1309,7 +1328,7 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve ) elif i.kind is WorkflowKind.custom: return expand_custom( - workflow, ensure(i.custom_workflow), ensure(i.images), ensure(i.sampling) + workflow, ensure(i.custom_workflow), ensure(i.images), ensure(i.sampling), models ) else: raise ValueError(f"Unsupported workflow kind: {i.kind}") diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 03312c2bc0..ae986df1c9 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -4,12 +4,15 @@ from PyQt5.QtCore import Qt from ai_diffusion.api import CustomWorkflowInput, ImageInput, SamplingInput +from ai_diffusion.client import ClientModels, CheckpointInfo from ai_diffusion.connection import Connection from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters from ai_diffusion.image import Image, Extent +from ai_diffusion.style import Style +from ai_diffusion.resources import Arch from ai_diffusion import workflow from .config import test_dir @@ -219,6 +222,7 @@ def test_parameters(): w.add("ChoiceNode", 1, choice_param=choice_param) w.add("ETN_KritaImageLayer", 1, name="image") w.add("ETN_KritaMaskLayer", 1, name="mask") + w.add("ETN_KritaStyle", 9, name="style", sampler_preset="live") # type: ignore assert list(workflow_parameters(w)) == [ CustomParam(ParamKind.number_int, "int", 4, 0, 10), @@ -231,6 +235,7 @@ def test_parameters(): CustomParam(ParamKind.choice, "choice", "c", choices=["a", "b", "c"]), CustomParam(ParamKind.image_layer, "image"), CustomParam(ParamKind.mask_layer, "mask"), + CustomParam(ParamKind.style, "style", "live"), ] @@ -248,6 +253,7 @@ def test_expand(): choicy = ext.add("ETN_Parameter", 1, name="choicy", type="choice", default="c") layer_img = ext.add("ETN_KritaImageLayer", 1, name="layer_img") layer_mask = ext.add("ETN_KritaMaskLayer", 1, name="layer_mask") + stylie = ext.add("ETN_KritaStyle", 9, name="style", sampler_preset="live") # type: ignore ext.add( "Sink", 1, @@ -259,8 +265,21 @@ def test_expand(): choicy=choicy, layer_img=layer_img, layer_mask=layer_mask, + model=stylie[0], + clip=stylie[1], + vae=stylie[2], + positive=stylie[3], + negative=stylie[4], + sampler=stylie[5], + scheduler=stylie[6], + steps=stylie[7], + guidance=stylie[8], ) + style = Style(Path("default.json")) + style.sd_checkpoint = "checkpoint.safetensors" + style.style_prompt = "bee hive" + style.negative_prompt = "pigoon" params = { "inty": 7, "numby": 3.4, @@ -269,6 +288,7 @@ def test_expand(): "choicy": "b", "layer_img": Image.create(Extent(4, 4), Qt.GlobalColor.black), "layer_mask": Image.create(Extent(4, 4), Qt.GlobalColor.white), + "style": style, } input = CustomWorkflowInput(workflow=ext.root, params=params) @@ -276,16 +296,22 @@ def test_expand(): images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) sampling = SamplingInput("", "", 1.0, 1000, seed=123) + models = ClientModels() + models.checkpoints = { + "checkpoint.safetensors": CheckpointInfo("checkpoint.safetensors", Arch.sd15) + } + w = ComfyWorkflow() - w = workflow.expand_custom(w, input, images, sampling) + w = workflow.expand_custom(w, input, images, sampling, models) expected = [ ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}), ComfyNode(4, "ETN_LoadImageBase64", {"image": params["layer_img"].to_base64()}), ComfyNode(5, "ETN_LoadMaskBase64", {"mask": params["layer_mask"].to_base64()}), + ComfyNode(6, "CheckpointLoaderSimple", {"ckpt_name": "checkpoint.safetensors"}), ComfyNode( - 6, + 7, "Sink", { "seed": 123, @@ -296,6 +322,15 @@ def test_expand(): "choicy": "b", "layer_img": Output(4, 0), "layer_mask": Output(5, 0), + "model": Output(6, 0), + "clip": Output(6, 1), + "vae": Output(6, 2), + "positive": "bee hive", + "negative": "pigoon", + "sampler": "euler", + "scheduler": "sgm_uniform", + "steps": 6, + "guidance": 1.8, }, ), ] From 917b711953793ae823cfbac16326db56f54c2e8e Mon Sep 17 00:00:00 2001 From: Acly Date: Tue, 8 Oct 2024 17:50:02 +0200 Subject: [PATCH 20/28] Fix seed being overwritten --- ai_diffusion/workflow.py | 9 ++++----- tests/test_custom_workflow.py | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index f197b23a65..b0e94e4bf3 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1047,7 +1047,7 @@ def expand_custom( w: ComfyWorkflow, input: CustomWorkflowInput, images: ImageInput, - sampling: SamplingInput, + seed: int, models: ClientModels, ): custom = ComfyWorkflow.from_dict(input.workflow) @@ -1079,7 +1079,7 @@ def get_param(node: ComfyNode, expected_type: type | None = None): outputs[node.output(0)] = w.load_image(image) outputs[node.output(1)] = image.width outputs[node.output(2)] = image.height - outputs[node.output(3)] = sampling.seed + outputs[node.output(3)] = seed case "ETN_KritaSelection": outputs[node.output(0)] = w.load_mask(ensure(images.hires_mask)) case "ETN_Parameter": @@ -1327,9 +1327,8 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve seed=i.sampling.seed if i.sampling else -1, ) elif i.kind is WorkflowKind.custom: - return expand_custom( - workflow, ensure(i.custom_workflow), ensure(i.images), ensure(i.sampling), models - ) + seed = ensure(i.sampling).seed + return expand_custom(workflow, ensure(i.custom_workflow), ensure(i.images), seed, models) else: raise ValueError(f"Unsupported workflow kind: {i.kind}") diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index ae986df1c9..e3fd9b1771 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -294,7 +294,6 @@ def test_expand(): input = CustomWorkflowInput(workflow=ext.root, params=params) images = ImageInput.from_extent(Extent(4, 4)) images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white) - sampling = SamplingInput("", "", 1.0, 1000, seed=123) models = ClientModels() models.checkpoints = { @@ -302,7 +301,7 @@ def test_expand(): } w = ComfyWorkflow() - w = workflow.expand_custom(w, input, images, sampling, models) + w = workflow.expand_custom(w, input, images, 123, models) expected = [ ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}), ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}), From acfedba4cf6c9ea1e940a91ab717ee930efbc2e9 Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 9 Oct 2024 19:48:49 +0200 Subject: [PATCH 21/28] Move some code out of model.py --- ai_diffusion/custom_workflow.py | 31 +++++++++++++++++++++++++++++-- ai_diffusion/model.py | 24 ++---------------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index dd7f3e66ea..4e9276e7ce 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -1,15 +1,19 @@ import json from enum import Enum +from copy import copy from dataclasses import dataclass from typing import Any, NamedTuple from pathlib import Path -from PyQt5.QtCore import Qt, QObject, QAbstractListModel, QSortFilterProxyModel, QModelIndex +from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex from PyQt5.QtCore import pyqtSignal from .comfy_workflow import ComfyWorkflow from .connection import Connection +from .image import Bounds +from .layer import LayerManager from .properties import Property, ObservableProperties +from .style import Styles from .util import user_data_dir, client_logger as log from .ui import theme @@ -265,7 +269,7 @@ def workflow_parameters(w: ComfyWorkflow): ) else: yield CustomParam(ParamKind.text, name, default=default) - case ("ETN_Parameter", unknown_type): + case ("ETN_Parameter", unknown_type) if unknown_type != "auto": unknown = node.input("name", "?") + ": " + unknown_type log.warning(f"Custom workflow has an unsupported parameter type {unknown}") @@ -344,6 +348,29 @@ def graph(self): def metadata(self): return self._metadata + def collect_parameters(self, layers: LayerManager, bounds: Bounds): + params = copy(self.params) + for md in self.metadata: + param = params.get(md.name) + assert param is not None, f"Parameter {md.name} not found" + + if md.kind is ParamKind.image_layer: + layer = layers.find(QUuid(param)) + if layer is None: + raise ValueError(f"Input layer for parameter {md.name} not found") + params[md.name] = layer.get_pixels(bounds) + elif md.kind is ParamKind.mask_layer: + layer = layers.find(QUuid(param)) + if layer is None: + raise ValueError(f"Input layer for parameter {md.name} not found") + params[md.name] = layer.get_mask(bounds) + elif md.kind is ParamKind.style: + style = Styles.list().find(str(param)) + if style is None: + raise ValueError(f"Style {param} not found") + params[md.name] = style + return params + def _coerce(params: dict[str, Any], types: list[CustomParam]): def use(value, default): diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 9bb75229b9..0277acab0f 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -364,27 +364,7 @@ def generate_custom(self): else: img_input.hires_mask = Mask.transparent(bounds).to_image() - params = copy(self.custom.params) - for md in self.custom.metadata: - param = params.get(md.name) - assert param is not None, f"Parameter {md.name} not found" - - if md.kind is ParamKind.image_layer: - layer = self.layers.find(QUuid(param)) - if layer is None: - raise ValueError(f"Input layer for parameter {md.name} not found") - params[md.name] = layer.get_pixels(bounds) - elif md.kind is ParamKind.mask_layer: - layer = self.layers.find(QUuid(param)) - if layer is None: - raise ValueError(f"Input layer for parameter {md.name} not found") - params[md.name] = layer.get_mask(bounds) - elif md.kind is ParamKind.style: - style = Styles.list().find(str(param)) - if style is None: - raise ValueError(f"Style {param} not found") - params[md.name] = style - + params = self.custom.collect_parameters(self.layers, bounds) input = WorkflowInput( WorkflowKind.custom, img_input, @@ -539,7 +519,7 @@ def create_result_layer( ): name = f"{prefix}{job_region.prompt} ({params.seed})" region_layer = self.layers.find(QUuid(job_region.layer_id)) or self.layers.root - # a previous apply from the same batch my have already created groups and re-linked + # a previous apply from the same batch may have already created groups and re-linked region_layer = Region.link_target(region_layer) # Replace content if requested and not a group layer From 0e9d24e9a8883a2a5e0fc459af81db1837da2922 Mon Sep 17 00:00:00 2001 From: Acly Date: Wed, 9 Oct 2024 20:20:35 +0200 Subject: [PATCH 22/28] Live generation mode for custom workflows --- ai_diffusion/custom_workflow.py | 74 +++++++++++++++- ai_diffusion/model.py | 27 ++++-- ai_diffusion/ui/custom_workflow.py | 135 +++++++++++++++++++++++------ ai_diffusion/ui/live.py | 24 +++-- 4 files changed, 213 insertions(+), 47 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 4e9276e7ce..40eb76849d 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -1,21 +1,25 @@ +import asyncio import json from enum import Enum from copy import copy from dataclasses import dataclass -from typing import Any, NamedTuple +from typing import Any, Awaitable, Callable, NamedTuple, Literal from pathlib import Path from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex from PyQt5.QtCore import pyqtSignal +from .api import WorkflowInput from .comfy_workflow import ComfyWorkflow from .connection import Connection -from .image import Bounds +from .image import Bounds, Image from .layer import LayerManager +from .jobs import Job, JobParams, JobQueue, JobKind from .properties import Property, ObservableProperties from .style import Styles from .util import user_data_dir, client_logger as log from .ui import theme +from . import eventloop class WorkflowSource(Enum): @@ -274,23 +278,45 @@ def workflow_parameters(w: ComfyWorkflow): log.warning(f"Custom workflow has an unsupported parameter type {unknown}") +ImageGenerator = Callable[[WorkflowInput | None], Awaitable[None | Literal[False] | WorkflowInput]] + + +class CustomGenerationMode(Enum): + regular = 0 + live = 1 + + class CustomWorkspace(QObject, ObservableProperties): workflow_id = Property("", setter="_set_workflow_id") params = Property({}, persist=True) + mode = Property(CustomGenerationMode.regular, setter="_set_mode") + is_live = Property(False, setter="toggle_live") + has_result = Property(False) workflow_id_changed = pyqtSignal(str) graph_changed = pyqtSignal() params_changed = pyqtSignal(dict) + mode_changed = pyqtSignal(CustomGenerationMode) + is_live_changed = pyqtSignal(bool) + result_available = pyqtSignal(Image) + has_result_changed = pyqtSignal(bool) modified = pyqtSignal(QObject, str) - def __init__(self, workflows: WorkflowCollection): + _live_poll_rate = 0.1 + + def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, jobs: JobQueue): super().__init__() self._workflows = workflows + self._generator = generator self._workflow: CustomWorkflow | None = None self._graph: ComfyWorkflow | None = None self._metadata: list[CustomParam] = [] + self._last_input: WorkflowInput | None = None + self._last_result: Image | None = None + self._last_job: JobParams | None = None + jobs.job_finished.connect(self._handle_job_finished) workflows.dataChanged.connect(self._update_workflow) workflows.rowsInserted.connect(self._set_default_workflow) self._set_default_workflow() @@ -336,6 +362,22 @@ def remove_workflow(self): self._metadata = [] self._workflows.remove(id) + def generate(self): + eventloop.run(self._generator(None)) + + def toggle_live(self, active: bool): + if self._is_live != active: + self._is_live = active + self.is_live_changed.emit(active) + if active: + eventloop.run(self._continue_generating()) + + def _set_mode(self, value: CustomGenerationMode): + if self._mode != value: + self._mode = value + self.mode_changed.emit(value) + self.is_live = False + @property def workflow(self): return self._workflow @@ -371,6 +413,32 @@ def collect_parameters(self, layers: LayerManager, bounds: Bounds): params[md.name] = style return params + def _handle_job_finished(self, job: Job): + if job.kind is JobKind.live_preview: + if len(job.results) > 0: + self._last_result = job.results[0] + self._last_job = job.params + self.result_available.emit(self._last_result) + self.has_result = True + eventloop.run(self._continue_generating()) + + async def _continue_generating(self): + while self.is_live: + new_input = await self._generator(self._last_input) + if new_input is False: # abort live generation + self.is_live = False + return + elif new_input is None: # no changes in input data + await asyncio.sleep(self._live_poll_rate) + else: # frame was scheduled + self._last_input = new_input + return + + @property + def live_result(self): + assert self._last_result and self._last_job, "No live result available" + return self._last_result, self._last_job + def _coerce(params: dict[str, Any], types: list[CustomParam]): def use(value, default): diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 0277acab0f..20e124b9c1 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -19,7 +19,7 @@ from .image import Extent, Image, Mask, Bounds, DummyImage from .client import ClientMessage, ClientEvent, SharedWorkflow from .client import filter_supported_styles, resolve_arch -from .custom_workflow import CustomWorkspace, WorkflowCollection, ParamKind +from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode from .document import Document, KritaDocument from .layer import Layer, LayerType, RestoreActiveLayer from .pose import Pose @@ -94,7 +94,7 @@ def __init__(self, document: Document, connection: Connection, workflows: Workfl self.upscale = UpscaleWorkspace(self) self.live = LiveWorkspace(self) self.animation = AnimationWorkspace(self) - self.custom = CustomWorkspace(workflows) + self.custom = CustomWorkspace(workflows, self._generate_custom, self.jobs) self.jobs.selection_changed.connect(self.update_preview) self.error_changed.connect(lambda: self.has_error_changed.emit(self.has_error)) @@ -349,13 +349,17 @@ async def _generate_live(self, last_input: WorkflowInput | None = None): return None - def generate_custom(self): + async def _generate_custom(self, previous_input: WorkflowInput | None): + if self.workspace is not Workspace.custom or not self.document.is_active: + return False + try: wf = ensure(self.custom.graph) bounds = Bounds(0, 0, *self._doc.extent) img_input = ImageInput.from_extent(bounds.extent) img_input.initial_image = self._get_current_image(bounds) - seed = self.seed if self.fixed_seed else workflow.generate_seed() + is_live = self.custom.mode is CustomGenerationMode.live + seed = self.seed if is_live or self.fixed_seed else workflow.generate_seed() if next(wf.find(type="ETN_KritaSelection"), None): mask, _ = self._doc.create_mask_from_selection() @@ -372,13 +376,18 @@ def generate_custom(self): custom_workflow=CustomWorkflowInput(wf.root, params), ) job_params = JobParams(bounds, self.custom.workflow_id) + job_kind = JobKind.live_preview if is_live else JobKind.diffusion + + if input == previous_input: + return None + + self.clear_error() + await self.enqueue_jobs(input, job_kind, job_params, self.batch_count) + return input + except Exception as e: self.report_error(util.log_error(e)) - return - - self.clear_error() - jobs = self.enqueue_jobs(input, JobKind.diffusion, job_params, self.batch_count) - eventloop.run(_report_errors(self, jobs)) + return False def _get_current_image(self, bounds: Bounds): exclude = None diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index 720acc42a0..d8f89a1e13 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -2,21 +2,24 @@ from pathlib import Path from typing import Any, Callable -from PyQt5.QtCore import Qt, pyqtSignal, QMetaObject, QUuid, QUrl +from PyQt5.QtCore import Qt, pyqtSignal, QMetaObject, QUuid, QUrl, QPoint from PyQt5.QtGui import QFontMetrics, QIcon, QDesktopServices -from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout -from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox +from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout, QMenu +from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox, QAction from PyQt5.QtWidgets import QToolButton, QVBoxLayout, QWidget, QSlider, QDoubleSpinBox from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource +from ..custom_workflow import CustomGenerationMode from ..jobs import JobKind -from ..model import Model +from ..model import Model, ApplyBehavior from ..properties import Binding, Bind, bind, bind_combo from ..style import Styles from ..root import root +from ..settings import settings from ..localization import translate as _ from ..util import ensure, clamp from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label +from .live import LivePreviewArea from .switch import SwitchWidget from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget from . import theme @@ -447,14 +450,36 @@ def __init__(self): self._params_widget = WorkflowParamsWidget([], self) self._generate_button = GenerateButton(JobKind.diffusion, self) + self._generate_button.clicked.connect(self._generate) + + self._apply_button = QToolButton(self) + self._apply_button.setIcon(theme.icon("apply")) + self._apply_button.setFixedHeight(self._generate_button.height() - 2) + self._apply_button.setToolTip(_("Create a new layer with the current result")) + self._apply_button.clicked.connect(self.apply_live_result) + + self._mode_button = QToolButton(self) + self._mode_button.setArrowType(Qt.ArrowType.DownArrow) + self._mode_button.setFixedHeight(self._generate_button.height() - 2) + self._mode_button.clicked.connect(self._show_generate_menu) + menu = QMenu(self) + menu.addAction(self._mk_action(CustomGenerationMode.regular, _("Generate"), "generate")) + menu.addAction( + self._mk_action(CustomGenerationMode.live, _("Generate Live"), "workspace-live") + ) + self._generate_menu = menu + self._queue_button = QueueButton(parent=self) self._queue_button.setFixedHeight(self._generate_button.height() - 2) + self._progress_bar = ProgressBar(self) self._error_text = create_error_label(self) self._history = HistoryWidget(self) self._history.item_activated.connect(self.apply_result) + self._live_preview = LivePreviewArea(self) + self._layout = QVBoxLayout() select_layout = QHBoxLayout() select_layout.setContentsMargins(0, 0, 0, 0) @@ -479,35 +504,20 @@ def __init__(self): self._layout.addLayout(header_layout) self._layout.addWidget(self._params_widget) actions_layout = QHBoxLayout() + actions_layout.setSpacing(0) actions_layout.addWidget(self._generate_button) + actions_layout.addWidget(self._apply_button) + actions_layout.addWidget(self._mode_button) + actions_layout.addSpacing(4) actions_layout.addWidget(self._queue_button) self._layout.addLayout(actions_layout) self._layout.addWidget(self._progress_bar) self._layout.addWidget(self._error_text) self._layout.addWidget(self._history) + self._layout.addWidget(self._live_preview) self.setLayout(self._layout) - def _update_current_workflow(self): - if not self.model.custom.workflow: - self._save_workflow_button.setEnabled(False) - self._delete_workflow_button.setEnabled(False) - return - self._save_workflow_button.setEnabled(True) - self._delete_workflow_button.setEnabled( - self.model.custom.workflow.source is WorkflowSource.local - ) - - self._params_widget.deleteLater() - self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) - self._params_widget.value = self.model.custom.params - self._layout.insertWidget(1, self._params_widget) - self._params_widget.value_changed.connect(self._change_params) - - def _change_workflow(self): - self.model.custom.workflow_id = self._workflow_select.currentData() - - def _change_params(self): - self.model.custom.params = self._params_widget.value + self._update_ui() @property def model(self): @@ -525,17 +535,90 @@ def model(self, model: Model): model.custom.graph_changed.connect(self._update_current_workflow), model.error_changed.connect(self._error_text.setText), model.has_error_changed.connect(self._error_text.setVisible), - self._generate_button.clicked.connect(model.generate_custom), + model.custom.mode_changed.connect(self._update_ui), + model.custom.is_live_changed.connect(self._update_ui), + model.custom.result_available.connect(self._live_preview.show_image), + model.custom.has_result_changed.connect(self._apply_button.setEnabled), ] self._queue_button.model = model self._progress_bar.model = model self._history.model_ = model self._update_current_workflow() + self._update_ui() + + def _mk_action(self, mode: CustomGenerationMode, text: str, icon: str): + action = QAction(text, self) + action.setIcon(theme.icon(icon)) + action.setIconVisibleInMenu(True) + action.triggered.connect(lambda: self._change_mode(mode)) + return action + + def _change_mode(self, mode: CustomGenerationMode): + self.model.custom.mode = mode + + def _show_generate_menu(self): + width = self._generate_button.width() + self._mode_button.width() + pos = QPoint(0, self._generate_button.height()) + self._generate_menu.setFixedWidth(width) + self._generate_menu.exec_(self._generate_button.mapToGlobal(pos)) + + def _update_ui(self): + is_live_mode = self.model.custom.mode is CustomGenerationMode.live + self._history.setVisible(not is_live_mode) + self._live_preview.setVisible(is_live_mode) + self._apply_button.setVisible(is_live_mode) + self._apply_button.setEnabled(self.model.custom.has_result) + + if not is_live_mode: + text = _("Generate") + icon = "generate" + elif not self.model.custom.is_live: + text = _("Start Generating") + icon = "play" + else: + text = _("Stop Generating") + icon = "pause" + self._generate_button.operation = text + self._generate_button.setIcon(theme.icon(icon)) + + def _generate(self): + if self.model.custom.mode is CustomGenerationMode.regular: + self.model.custom.generate() + else: + self.model.custom.is_live = not self.model.custom.is_live + + def _update_current_workflow(self): + if not self.model.custom.workflow: + self._save_workflow_button.setEnabled(False) + self._delete_workflow_button.setEnabled(False) + return + self._save_workflow_button.setEnabled(True) + self._delete_workflow_button.setEnabled( + self.model.custom.workflow.source is WorkflowSource.local + ) + + self._params_widget.deleteLater() + self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self) + self._params_widget.value = self.model.custom.params + self._layout.insertWidget(1, self._params_widget) + self._params_widget.value_changed.connect(self._change_params) + + def _change_workflow(self): + self.model.custom.workflow_id = self._workflow_select.currentData() + + def _change_params(self): + self.model.custom.params = self._params_widget.value def apply_result(self, item: QListWidgetItem): job_id, index = self._history.item_info(item) self.model.apply_generated_result(job_id, index) + def apply_live_result(self): + image, params = self.model.custom.live_result + self.model.apply_result(image, params, ApplyBehavior.layer) + if settings.new_seed_after_apply: + self.model.generate_seed() + @popup_on_error def _import_workflow(self, *args): filename, __ = QFileDialog.getOpenFileName( diff --git a/ai_diffusion/ui/live.py b/ai_diffusion/ui/live.py index dce222bace..ad39a9a40b 100644 --- a/ai_diffusion/ui/live.py +++ b/ai_diffusion/ui/live.py @@ -23,6 +23,19 @@ from . import theme +class LivePreviewArea(QLabel): + def __init__(self, parent: QWidget): + super().__init__(parent) + self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + self.setAlignment(Qt.AlignmentFlag(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft)) + + def show_image(self, image: Image): + target = Extent.from_qsize(self.size()) + img = Image.scale_to_fit(image, target) + self.setPixmap(img.to_pixmap()) + self.setMinimumSize(256, 256) + + class LiveWidget(QWidget): _play_icon = theme.icon("play") _pause_icon = theme.icon("pause") @@ -163,11 +176,7 @@ def __init__(self): self.progress_bar.setVisible(False) layout.addWidget(self.progress_bar) - self.preview_area = QLabel(self) - self.preview_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) - self.preview_area.setAlignment( - Qt.AlignmentFlag(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft) - ) + self.preview_area = LivePreviewArea(self) layout.addWidget(self.preview_area) @property @@ -252,10 +261,7 @@ def update_progress(self): def show_result(self, image: Image): self.progress_bar.setVisible(False) - target = Extent.from_qsize(self.preview_area.size()) - img = Image.scale_to_fit(image, target) - self.preview_area.setPixmap(img.to_pixmap()) - self.preview_area.setMinimumSize(256, 256) + self.preview_area.show_image(image) def apply_result(self): self.model.live.apply_result() From 0e5dafc62dd9a59cd14ba729f56b5d45ccd3fa8d Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 10 Oct 2024 12:12:28 +0200 Subject: [PATCH 23/28] Fix tests --- ai_diffusion/custom_workflow.py | 8 +++++--- ai_diffusion/jobs.py | 6 ++++-- tests/test_custom_workflow.py | 8 +++++++- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 40eb76849d..100b41230c 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -4,7 +4,7 @@ from enum import Enum from copy import copy from dataclasses import dataclass -from typing import Any, Awaitable, Callable, NamedTuple, Literal +from typing import Any, Awaitable, Callable, NamedTuple, Literal, TYPE_CHECKING from pathlib import Path from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex from PyQt5.QtCore import pyqtSignal @@ -13,7 +13,6 @@ from .comfy_workflow import ComfyWorkflow from .connection import Connection from .image import Bounds, Image -from .layer import LayerManager from .jobs import Job, JobParams, JobQueue, JobKind from .properties import Property, ObservableProperties from .style import Styles @@ -21,6 +20,9 @@ from .ui import theme from . import eventloop +if TYPE_CHECKING: + from .layer import LayerManager + class WorkflowSource(Enum): document = 0 @@ -390,7 +392,7 @@ def graph(self): def metadata(self): return self._metadata - def collect_parameters(self, layers: LayerManager, bounds: Bounds): + def collect_parameters(self, layers: "LayerManager", bounds: Bounds): params = copy(self.params) for md in self.metadata: param = params.get(md.name) diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py index 962145be1b..412d2d30d8 100644 --- a/ai_diffusion/jobs.py +++ b/ai_diffusion/jobs.py @@ -3,14 +3,16 @@ from dataclasses import dataclass, fields, field from datetime import datetime from enum import Enum, Flag -from typing import Any, Deque, NamedTuple +from typing import Any, Deque, NamedTuple, TYPE_CHECKING from PyQt5.QtCore import QObject, pyqtSignal from .image import Bounds, ImageCollection from .settings import settings from .style import Style from .util import ensure -from . import control + +if TYPE_CHECKING: + from . import control class JobState(Flag): diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index e3fd9b1771..38d52dbd0c 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -11,6 +11,7 @@ from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters from ai_diffusion.image import Image, Extent +from ai_diffusion.jobs import JobQueue from ai_diffusion.style import Style from ai_diffusion.resources import Arch from ai_diffusion import workflow @@ -141,13 +142,18 @@ def test_files(tmp_path: Path): collection.import_file(bad_file) +async def dummy_generate(workflow_input): + return None + + def test_workspace(): connection = Connection() connection_workflows = {"connection1": make_dummy_graph(42)} connection._workflows = connection_workflows workflows = WorkflowCollection(connection) - workspace = CustomWorkspace(workflows) + jobs = JobQueue() + workspace = CustomWorkspace(workflows, dummy_generate, jobs) assert workspace.workflow_id == "connection1" assert workspace.workflow and workspace.workflow.id == "connection1" assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter" From 8a4c833de5c1305f730790c045786e05b31d9c90 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 10 Oct 2024 14:04:12 +0200 Subject: [PATCH 24/28] Make job metadata more flexible and add style loras Fixed color of progress bar when there is a model upload --- ai_diffusion/custom_workflow.py | 7 +++++ ai_diffusion/jobs.py | 27 +++++++++++------ ai_diffusion/model.py | 30 +++++++++--------- ai_diffusion/ui/generation.py | 54 +++++++++++++++++++++------------ 4 files changed, 75 insertions(+), 43 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 100b41230c..cc1f9cae8f 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -392,6 +392,13 @@ def graph(self): def metadata(self): return self._metadata + @property + def job_name(self): + for param in self.metadata: + if param.kind is ParamKind.prompt_positive: + return str(self.params[param.name]) + return self.workflow_id or "Custom Workflow" + def collect_parameters(self, layers: "LayerManager", bounds: Bounds): params = copy(self.params) for md in self.metadata: diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py index 412d2d30d8..b2f16cbf2d 100644 --- a/ai_diffusion/jobs.py +++ b/ai_diffusion/jobs.py @@ -47,14 +47,10 @@ def from_dict(data: dict[str, Any]): @dataclass class JobParams: bounds: Bounds - prompt: str - negative_prompt: str = "" + name: str # used eg. as name for new layers created from this job regions: list[JobRegion] = field(default_factory=list) - strength: float = 1.0 + metadata: dict[str, Any] = field(default_factory=dict) seed: int = 0 - style: str = "" - checkpoint: str = "" - sampler: str = "" has_mask: bool = False frame: tuple[int, int, int] = (0, 0, 0) animation_id: str = "" @@ -73,9 +69,22 @@ def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None): return all(getattr(a, name) == getattr(b, name) for name in field_names) def set_style(self, style: Style): - self.style = style.filename - self.checkpoint = style.sd_checkpoint - self.sampler = f"{style.sampler} ({style.sampler_steps} / {style.cfg_scale})" + self.metadata["style"] = style.filename + self.metadata["checkpoint"] = style.sd_checkpoint + self.metadata["loras"] = style.loras + self.metadata["sampler"] = f"{style.sampler} ({style.sampler_steps} / {style.cfg_scale})" + + @property + def prompt(self): + return self.metadata.get("prompt", "") + + @property + def style(self): + return self.metadata.get("style", "") + + @property + def strength(self): + return self.metadata.get("strength", 1.0) class Job: diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 20e124b9c1..941e6aed47 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -194,23 +194,23 @@ def _prepare_workflow(self, dryrun=False): ) job_params = JobParams(bounds, prompt, regions=job_regions) job_params.set_style(self.style) + job_params.metadata["prompt"] = prompt + job_params.metadata["negative_prompt"] = self.regions.negative + job_params.metadata["strength"] = self.strength + if len(job_regions) == 1: + job_params.metadata["prompt"] = job_params.name = job_regions[0].prompt return input, job_params async def enqueue_jobs( self, input: WorkflowInput, kind: JobKind, params: JobParams, count: int = 1 ): sampling = ensure(input.sampling) - params.negative_prompt = self.regions.negative - params.strength = sampling.denoise_strength params.has_mask = input.images is not None and input.images.hires_mask is not None - if len(params.regions) == 1: - params.prompt = params.regions[0].prompt for i in range(count): - input = replace( - input, sampling=replace(sampling, seed=sampling.seed + i * settings.batch_size) - ) - params.seed = ensure(input.sampling).seed + next_seed = sampling.seed + i * settings.batch_size + input = replace(input, sampling=replace(sampling, seed=next_seed)) + params.seed = next_seed job = self.jobs.add(kind, copy(params)) await self._enqueue_job(job, input) @@ -375,7 +375,7 @@ async def _generate_custom(self, previous_input: WorkflowInput | None): sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed), custom_workflow=CustomWorkflowInput(wf.root, params), ) - job_params = JobParams(bounds, self.custom.workflow_id) + job_params = JobParams(bounds, self.custom.job_name, metadata=self.custom.params) job_kind = JobKind.live_preview if is_live else JobKind.diffusion if input == previous_input: @@ -486,7 +486,7 @@ def update_preview(self): def show_preview(self, job_id: str, index: int, name_prefix="Preview"): job = self.jobs.find(job_id) assert job is not None, "Cannot show preview, invalid job id" - name = f"[{name_prefix}] {trim_text(job.params.prompt, 77)}" + name = f"[{name_prefix}] {trim_text(job.params.name, 77)}" if self._layer and self._layer.was_removed: self._layer = None # layer was removed by user if self._layer is not None: @@ -508,7 +508,7 @@ def apply_result(self, image: Image, params: JobParams, behavior: ApplyBehavior, if behavior is ApplyBehavior.replace: self.layers.update_layer_image(self.layers.active, image, params.bounds) else: - name = f"{prefix}{trim_text(params.prompt, 200)} ({params.seed})" + name = f"{prefix}{trim_text(params.name, 200)} ({params.seed})" self.layers.create(name, image, params.bounds) else: # apply to regions with RestoreActiveLayer(self.layers) as restore: @@ -600,9 +600,9 @@ def add_control_layer(self, job: Job, result: dict | SharedWorkflow | None): if job.control.mode is ControlMode.pose and isinstance(result, dict): pose = Pose.from_open_pose_json(result) pose.scale(job.params.bounds.extent) - return self.layers.create_vector(job.params.prompt, pose.to_svg()) + return self.layers.create_vector(job.params.name, pose.to_svg()) elif len(job.results) > 0: - return self.layers.create(job.params.prompt, job.results[0], job.params.bounds) + return self.layers.create(job.params.name, job.results[0], job.params.bounds) return self.layers.active # Execution was cached and no image was produced def add_upscale_layer(self, job: Job): @@ -1073,7 +1073,7 @@ def _import_animation(self, job: Job): keyframes = self._keyframes.pop(job.params.animation_id) _, start, end = job.params.frame doc.import_animation(keyframes, start) - eventloop.run(self._update_layer_name(f"[Generated] {start}-{end}: {job.params.prompt}")) + eventloop.run(self._update_layer_name(f"[Generated] {start}-{end}: {job.params.name}")) async def _update_layer_name(self, name: str): doc = self._model.document @@ -1138,7 +1138,7 @@ def _save_job_result(model: Model, job: Job | None, index: int): assert len(job.results) > index, "Cannot save result, invalid result index" assert model.document.filename, "Cannot save result, document is not saved" timestamp = job.timestamp.strftime("%Y%m%d-%H%M%S") - prompt = util.sanitize_prompt(job.params.prompt) + prompt = util.sanitize_prompt(job.params.name) path = Path(model.document.filename) path = path.parent / f"{path.stem}-generated-{timestamp}-{index}-{prompt}.png" path = util.find_unused_path(path) diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py index 67541227de..723f3ee416 100644 --- a/ai_diffusion/ui/generation.py +++ b/ai_diffusion/ui/generation.py @@ -119,8 +119,9 @@ def add(self, job: Job): if not JobParams.equal_ignore_seed(self._last_job_params, job.params): self._last_job_params = job.params - prompt = job.params.prompt if job.params.prompt != "" else "" - strength = f"{job.params.strength*100:.0f}% - " if job.params.strength != 1.0 else "" + prompt = job.params.name if job.params.name != "" else "" + strength = job.params.metadata.get("strength", 1.0) + strength = f"{strength*100:.0f}% - " if strength != 1.0 else "" header = QListWidgetItem(f"{job.timestamp:%H:%M} - {strength}{prompt}") header.setFlags(Qt.ItemFlag.NoItemFlags) @@ -140,26 +141,39 @@ def add(self, job: Job): if scroll_to_bottom: self.scrollToBottom() + _job_info_translations = { + "prompt": _("Prompt"), + "negative_prompt": _("Negative Prompt"), + "style": _("Style"), + "strength": _("Strength"), + "checkpoint": _("Model"), + "loras": _("LoRA"), + "sampler": _("Sampler"), + "seed": _("Seed"), + } + def _job_info(self, params: JobParams): - prompt = params.prompt if params.prompt != "" else "" - if len(prompt) > 70: - prompt = prompt[:66] + "..." + title = params.name if params.name != "" else "" + if len(title) > 70: + title = title[:66] + "..." + if params.strength != 1.0: + title = f"{title} @ {params.strength*100:.0f}%" style = Styles.list().find(params.style) - positive = _("Prompt") + f": {params.prompt or '-'}" - negative = _("Negative Prompt") + f": {params.negative_prompt or '-'}" - strings = [ - f"{prompt} @ {params.strength*100:.0f}%\n", + strings: list[str | list[str]] = [ + title + "\n", _("Click to toggle preview, double-click to apply."), "", - _("Style") + f": {style.name if style else params.style}", - wrap_text(positive, 80, subsequent_indent=" "), - wrap_text(negative, 80, subsequent_indent=" "), - _("Strength") + f": {params.strength*100:.0f}%", - _("Model") + f": {params.checkpoint}", - _("Sampler") + f": {params.sampler}", - _("Seed") + f": {params.seed}", - f"{params.bounds}", ] + for key, value in params.metadata.items(): + if key == "style" and style: + value = style.name + if isinstance(value, list) and len(value) == 0: + continue + if isinstance(value, list) and isinstance(value[0], dict): + value = "\n ".join((f"{v.get('name')} ({v.get('strength')})" for v in value)) + s = f"{self._job_info_translations.get(key, key)}: {value}" + strings.append(wrap_text(s, 80, subsequent_indent=" ")) + strings.append(_("Seed") + f": {params.seed}") return "\n".join(flatten(strings)) def remove(self, job: Job): @@ -351,7 +365,7 @@ def _copy_prompt(self): active = self._model.regions.active_or_root active.positive = job.params.prompt if isinstance(active, RootRegion): - active.negative = job.params.negative_prompt + active.negative = job.params.metadata.get("negative_prompt", "") def _copy_strength(self): if job := self.selected_job: @@ -509,6 +523,7 @@ def __init__(self, parent: QWidget): super().__init__(parent) self._model = root.active_model self._model_bindings: list[QMetaObject.Connection] = [] + self._palette = self.palette() self.setMinimum(0) self.setMaximum(1000) self.setTextVisible(False) @@ -529,8 +544,9 @@ def model(self, model: Model): ] def _update_progress_kind(self): - palette = self.palette() + palette = self._palette if self._model.progress_kind is ProgressKind.upload: + palette = self.palette() palette.setColor(QPalette.ColorRole.Highlight, QColor(theme.progress_alt)) self.setPalette(palette) From 2aa7b167a9283854845c074b507d7b03bac46a17 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 11 Oct 2024 12:42:40 +0200 Subject: [PATCH 25/28] Placeholder message for custom workflows on cloud --- ai_diffusion/resources.py | 6 +++--- ai_diffusion/ui/custom_workflow.py | 30 ++++++++++++++++++++++++++++++ ai_diffusion/ui/diffusion.py | 8 +++++++- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/ai_diffusion/resources.py b/ai_diffusion/resources.py index 624198b3bb..84c1dcf286 100644 --- a/ai_diffusion/resources.py +++ b/ai_diffusion/resources.py @@ -6,10 +6,10 @@ # Version identifier for all the resources defined here. This is used as the server version. # It usually follows the plugin version, but not all new plugin versions also require a server update. -version = "1.25.0" +version = "1.26.0" comfy_url = "https://github.com/comfyanonymous/ComfyUI" -comfy_version = "dc96a1ae19b1d714a791f1fcb21578389955bbfd" +comfy_version = "1b8089528502a881d0ed2918b2abd54441743dd0" class CustomNode(NamedTuple): @@ -39,7 +39,7 @@ class CustomNode(NamedTuple): "External Tooling Nodes", "comfyui-tooling-nodes", "https://github.com/Acly/comfyui-tooling-nodes", - "1a24975f99b95fa9a17a917de05fa248d8dc9569", + "9a9cbe78a5851a49da0b38bc9b17504837476e8f", ["ETN_LoadImageBase64", "ETN_LoadMaskBase64", "ETN_SendImageWebSocket", "ETN_Translate"], ), CustomNode( diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index d8f89a1e13..5624dbca6b 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -674,3 +674,33 @@ def _accept_name(self, *args): def _cancel_name(self): self.is_edit_mode = False + + +class CustomWorkflowPlaceholder(QWidget): + def __init__(self): + super().__init__() + self._model = root.active_model + self._connections = [] + + self._workspace_select = WorkspaceSelectWidget(self) + note = QLabel("" + _("Custom workflows are not available on Cloud.") + "", self) + + self._layout = QVBoxLayout() + self._layout.addWidget(self._workspace_select) + self._layout.addSpacing(50) + self._layout.addWidget(note, 0, Qt.AlignmentFlag.AlignCenter) + self._layout.addStretch() + self.setLayout(self._layout) + + @property + def model(self): + return self._model + + @model.setter + def model(self, model: Model): + if self._model != model: + Binding.disconnect_all(self._connections) + self._model = model + self._connections = [ + bind(model, "workspace", self._workspace_select, "value", Bind.one_way) + ] diff --git a/ai_diffusion/ui/diffusion.py b/ai_diffusion/ui/diffusion.py index f30d741900..3a20386fcd 100644 --- a/ai_diffusion/ui/diffusion.py +++ b/ai_diffusion/ui/diffusion.py @@ -14,7 +14,7 @@ from ..localization import translate as _ from . import theme from .generation import GenerationWidget -from .custom_workflow import CustomWorkflowWidget +from .custom_workflow import CustomWorkflowWidget, CustomWorkflowPlaceholder from .upscale import UpscaleWidget from .live import LiveWidget from .animation import AnimationWidget @@ -213,6 +213,7 @@ def __init__(self): self._animation = AnimationWidget() self._live = LiveWidget() self._custom = CustomWorkflowWidget() + self._custom_placeholder = CustomWorkflowPlaceholder() self._frame = QStackedWidget(self) self._frame.addWidget(self._welcome) self._frame.addWidget(self._generation) @@ -220,6 +221,7 @@ def __init__(self): self._frame.addWidget(self._live) self._frame.addWidget(self._animation) self._frame.addWidget(self._custom) + self._frame.addWidget(self._custom_placeholder) self.setWidget(self._frame) root.connection.state_changed.connect(self.update_content) @@ -238,6 +240,7 @@ def update_content(self): model = root.model_for_active_document() connection = root.connection requires_update = self._welcome.requires_update + is_cloud = settings.server_mode is ServerMode.cloud if model is None or connection.state is not ConnectionState.connected or requires_update: self._frame.setCurrentWidget(self._welcome) elif model.workspace is Workspace.generation: @@ -252,6 +255,9 @@ def update_content(self): elif model.workspace is Workspace.animation: self._animation.model = model self._frame.setCurrentWidget(self._animation) + elif model.workspace is Workspace.custom and is_cloud: + self._custom_placeholder.model = model + self._frame.setCurrentWidget(self._custom_placeholder) elif model.workspace is Workspace.custom: self._custom.model = model self._frame.setCurrentWidget(self._custom) From 70dbe98bf2de4964b24496da4fb1300c366186dc Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 11 Oct 2024 17:23:34 +0200 Subject: [PATCH 26/28] Support multiple output nodes and guess workflow sample count --- ai_diffusion/comfy_client.py | 14 +++++--------- ai_diffusion/comfy_workflow.py | 10 ++++++++++ ai_diffusion/workflow.py | 1 + tests/test_custom_workflow.py | 6 ++++-- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py index 38bee0ef2b..645511ff26 100644 --- a/ai_diffusion/comfy_client.py +++ b/ai_diffusion/comfy_client.py @@ -282,9 +282,8 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr if msg["type"] == "executing" and msg["data"]["node"] is None: job_id = msg["data"]["prompt_id"] if local_id := self._clear_job(job_id): - # Usually we don't get here because finished, interrupted or error is sent first. - # But it may happen if the entire execution is cached and no images are sent. if len(images) == 0: + # It may happen if the entire execution is cached and no images are sent. images = last_images if len(images) == 0: # Still no images. Potential scenario: execution cached, but previous @@ -292,7 +291,10 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr err = "No new images were generated because the inputs did not change." await self._report(ClientEvent.error, local_id, error=err) else: - await self._report(ClientEvent.finished, local_id, 1, images=images) + last_images = images + await self._report( + ClientEvent.finished, local_id, 1, images=images, result=result + ) elif msg["type"] in ("execution_cached", "executing", "progress"): if self._active is not None and progress is not None: @@ -308,12 +310,6 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr pose_json = _extract_pose_json(msg) if job and pose_json: result = pose_json - elif job is not None and _validate_executed_node(msg, len(images)): - self._clear_job(job.remote_id) - last_images = images - await self._report( - ClientEvent.finished, job.local_id, 1, images=images, result=result - ) if msg["type"] == "execution_error": job = self._get_active_job(msg["data"]["prompt_id"]) diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py index 4faadd9b06..f88ef0b914 100644 --- a/ai_diffusion/comfy_workflow.py +++ b/ai_diffusion/comfy_workflow.py @@ -96,6 +96,7 @@ def import_graph(existing: dict, node_inputs: dict): def from_dict(existing: dict): w = ComfyWorkflow() w.root = existing + w.node_count = len(w.root) return w def add_default_values(self, node_name: str, args: dict): @@ -184,6 +185,15 @@ def find_connected(self, output: Output): if input_value == output: yield node, input_name + def guess_sample_count(self): + self.sample_count = sum( + int(value) + for node in self + for name, value in node.inputs.items() + if name == "steps" and isinstance(value, (int, float)) + ) + return self.sample_count + def __iter__(self): return iter(self.node(int(k)) for k in self.root.keys()) diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py index b0e94e4bf3..98dfd93dad 100644 --- a/ai_diffusion/workflow.py +++ b/ai_diffusion/workflow.py @@ -1108,6 +1108,7 @@ def get_param(node: ComfyNode, expected_type: type | None = None): mapped = ComfyNode(node.id, node.type, mapped_inputs) nodes[node.id] = w.copy(mapped).node + w.guess_sample_count() return w diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 38d52dbd0c..4e58620263 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -192,14 +192,16 @@ def test_workspace(): def test_import(): graph = { - "4": {"class_type": "A", "inputs": {"int": 4, "float": 1.2, "string": "mouse"}}, + "4": {"class_type": "A", "inputs": {"steps": 4, "float": 1.2, "string": "mouse"}}, "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}}, "9": {"class_type": "B", "inputs": {"in": ["4", 0]}}, } w = ComfyWorkflow.import_graph(graph, {}) - assert w.node(0) == ComfyNode(0, "A", {"int": 4, "float": 1.2, "string": "mouse"}) + assert w.node(0) == ComfyNode(0, "A", {"steps": 4, "float": 1.2, "string": "mouse"}) assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)}) assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)}) + assert w.node_count == 3 + assert w.guess_sample_count() == 4 def test_import_ui_workflow(): From aed73dcf345def9533ec65d55e9fe90efc3dc238 Mon Sep 17 00:00:00 2001 From: Acly Date: Fri, 11 Oct 2024 20:41:01 +0200 Subject: [PATCH 27/28] Only try to analyse workflows after connecting to get node info --- ai_diffusion/custom_workflow.py | 58 +++++++++++++++++++++------------ ai_diffusion/root.py | 3 +- tests/test_custom_workflow.py | 50 +++++++++++++++++++++++----- 3 files changed, 81 insertions(+), 30 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index cc1f9cae8f..1f37d32230 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -10,8 +10,8 @@ from PyQt5.QtCore import pyqtSignal from .api import WorkflowInput -from .comfy_workflow import ComfyWorkflow -from .connection import Connection +from .comfy_workflow import ComfyWorkflow, ComfyNode +from .connection import Connection, ConnectionState from .image import Bounds, Image from .jobs import Job, JobParams, JobQueue, JobKind from .properties import Property, ObservableProperties @@ -51,23 +51,29 @@ class WorkflowCollection(QAbstractListModel): def __init__(self, connection: Connection, folder: Path | None = None): super().__init__() self._connection = connection - self._workflows: list[CustomWorkflow] = [] - self._folder = folder or user_data_dir / "workflows" - for file in self._folder.glob("*.json"): - try: - self._process_file(file) - except Exception as e: - log.exception(f"Error loading workflow from {file}: {e}") + self._workflows: list[CustomWorkflow] = [] + self._connection.state_changed.connect(self._handle_connection) self._connection.workflow_published.connect(self._process_remote_workflow) - for wf in self._connection.workflows.keys(): - self._process_remote_workflow(wf) + self._handle_connection(self._connection.state) + + def _handle_connection(self, state: ConnectionState): + if state in (ConnectionState.connected, ConnectionState.disconnected): + self.clear() + + if state is ConnectionState.connected: + for file in self._folder.glob("*.json"): + try: + self._process_file(file) + except Exception as e: + log.exception(f"Error loading workflow from {file}: {e}") + + for wf in self._connection.workflows.keys(): + self._process_remote_workflow(wf) def _node_inputs(self): - if client := self._connection.client_if_connected: - return client.models.node_inputs - return {} + return self._connection.client.models.node_inputs def _create_workflow( self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None @@ -127,6 +133,12 @@ def remove(self, id: str): self._workflows.pop(idx.row()) self.endRemoveRows() + def clear(self): + if len(self._workflows) > 0: + self.beginResetModel() + self._workflows.clear() + self.endResetModel() + def set_graph(self, index: QModelIndex, graph: dict): wf = self._workflows[index.row()] wf.workflow = ComfyWorkflow.import_graph(graph, self._node_inputs()) @@ -266,13 +278,8 @@ def workflow_parameters(w: ComfyWorkflow): case ("ETN_Parameter", "choice"): name = node.input("name", "Parameter") default = node.input("default", "") - connected, input_name = next(w.find_connected(node.output()), (None, "")) - if connected: - if input_type := w.input_type(connected.type, input_name): - if isinstance(input_type[0], list): - yield CustomParam( - ParamKind.choice, name, choices=input_type[0], default=default - ) + if choices := _get_choices(w, node): + yield CustomParam(ParamKind.choice, name, choices=choices, default=default) else: yield CustomParam(ParamKind.text, name, default=default) case ("ETN_Parameter", unknown_type) if unknown_type != "auto": @@ -280,6 +287,15 @@ def workflow_parameters(w: ComfyWorkflow): log.warning(f"Custom workflow has an unsupported parameter type {unknown}") +def _get_choices(w: ComfyWorkflow, node: ComfyNode): + connected, input_name = next(w.find_connected(node.output()), (None, "")) + if connected: + if input_type := w.input_type(connected.type, input_name): + if isinstance(input_type[0], list): + return input_type[0] + return None + + ImageGenerator = Callable[[WorkflowInput | None], Awaitable[None | Literal[False] | WorkflowInput]] diff --git a/ai_diffusion/root.py b/ai_diffusion/root.py index da069a6ea6..f001ae48ea 100644 --- a/ai_diffusion/root.py +++ b/ai_diffusion/root.py @@ -40,6 +40,7 @@ def init(self): self._files = FileLibrary.load() self._workflows = WorkflowCollection(self._connection) self._models = [] + self._null_model = Model(Document(), self._connection, self._workflows) self._recent = RecentlyUsedSync.from_settings() self._auto_update = AutoUpdate() if settings.auto_update: @@ -96,7 +97,7 @@ def auto_update(self) -> AutoUpdate: def active_model(self): if model := self.model_for_active_document(): return model - return Model(Document(), self._connection, self._workflows) + return self._null_model async def autostart(self, signal_server_change: Callable): connection = self._connection diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 4e58620263..8679928d79 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -3,9 +3,9 @@ from pathlib import Path from PyQt5.QtCore import Qt -from ai_diffusion.api import CustomWorkflowInput, ImageInput, SamplingInput -from ai_diffusion.client import ClientModels, CheckpointInfo -from ai_diffusion.connection import Connection +from ai_diffusion.api import CustomWorkflowInput, ImageInput, WorkflowInput +from ai_diffusion.client import Client, ClientModels, CheckpointInfo +from ai_diffusion.connection import Connection, ConnectionState from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace @@ -19,6 +19,40 @@ from .config import test_dir +class MockClient(Client): + def __init__(self, node_inputs: dict[str, dict]): + self.models = ClientModels() + self.models.node_inputs = node_inputs + + @staticmethod + async def connect(url: str, access_token: str = "") -> Client: + return MockClient({}) + + async def enqueue(self, work: WorkflowInput, front: bool = False) -> str: + return "" + + async def listen(self): # type: ignore + return + + async def interrupt(self): + pass + + async def clear_queue(self): + pass + + +def create_mock_connection( + initial_workflows: dict[str, dict], + node_inputs: dict[str, dict] | None = None, + state: ConnectionState = ConnectionState.connected, +): + connection = Connection() + connection._client = MockClient(node_inputs or {}) + connection._workflows = initial_workflows + connection.state = state + return connection + + def _assert_has_workflow( collection: WorkflowCollection, name: str, @@ -44,12 +78,13 @@ def test_collection(tmp_path: Path): file2_graph = {"0": {"class_type": "F2", "inputs": {}}} file2.write_text(json.dumps(file2_graph)) - connection = Connection() connection_graph = {"0": {"class_type": "C1", "inputs": {}}} connection_workflows = {"connection1": connection_graph} - connection._workflows = connection_workflows + connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected) collection = WorkflowCollection(connection, tmp_path) + assert len(collection) == 0 + connection.state = ConnectionState.connected assert len(collection) == 3 _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1) _assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2) @@ -110,7 +145,7 @@ def make_dummy_graph(n: int = 42): def test_files(tmp_path: Path): collection_folder = tmp_path / "workflows" - collection = WorkflowCollection(Connection(), collection_folder) + collection = WorkflowCollection(create_mock_connection({}, {}), collection_folder) assert len(collection) == 0 file1 = tmp_path / "file1.json" @@ -147,9 +182,8 @@ async def dummy_generate(workflow_input): def test_workspace(): - connection = Connection() connection_workflows = {"connection1": make_dummy_graph(42)} - connection._workflows = connection_workflows + connection = create_mock_connection(connection_workflows, {}) workflows = WorkflowCollection(connection) jobs = JobQueue() From 089dc66d256a7e31f2764487b9d828120248bbe1 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 12 Oct 2024 20:10:07 +0200 Subject: [PATCH 28/28] Load previously used custom graph from document (somehow) --- ai_diffusion/custom_workflow.py | 53 ++++++++++++++++++++---------- ai_diffusion/persistence.py | 22 +++++++++++-- ai_diffusion/ui/custom_workflow.py | 4 +-- tests/test_custom_workflow.py | 26 ++++++++++----- 4 files changed, 74 insertions(+), 31 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 1f37d32230..804f15a9c1 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -48,11 +48,14 @@ class WorkflowCollection(QAbstractListModel): _icon_remote = theme.icon("web-connection") _icon_document = theme.icon("file-kra") + loaded = pyqtSignal() + def __init__(self, connection: Connection, folder: Path | None = None): super().__init__() self._connection = connection self._folder = folder or user_data_dir / "workflows" self._workflows: list[CustomWorkflow] = [] + self._pending_workflows: list[tuple[str, WorkflowSource, dict]] = [] self._connection.state_changed.connect(self._handle_connection) self._connection.workflow_published.connect(self._process_remote_workflow) @@ -63,6 +66,10 @@ def _handle_connection(self, state: ConnectionState): self.clear() if state is ConnectionState.connected: + for id, source, graph in self._pending_workflows: + self._process_workflow(id, source, graph) + self._pending_workflows.clear() + for file in self._folder.glob("*.json"): try: self._process_file(file) @@ -72,31 +79,36 @@ def _handle_connection(self, state: ConnectionState): for wf in self._connection.workflows.keys(): self._process_remote_workflow(wf) + self.loaded.emit() + def _node_inputs(self): return self._connection.client.models.node_inputs - def _create_workflow( + def _process_workflow( self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None ): - wf = ComfyWorkflow.import_graph(graph, self._node_inputs()) - return CustomWorkflow(id, source, wf, path) - - def _process_remote_workflow(self, id: str): - graph = self._connection.workflows[id] - self._process(self._create_workflow(id, WorkflowSource.remote, graph)) - - def _process_file(self, file: Path): - with file.open("r") as f: - graph = json.load(f) - self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file)) + if self._connection.state is not ConnectionState.connected: + self._pending_workflows.append((id, source, graph)) + return - def _process(self, workflow: CustomWorkflow): + comfy_flow = ComfyWorkflow.import_graph(graph, self._node_inputs()) + workflow = CustomWorkflow(id, source, comfy_flow, path) idx = self.find_index(workflow.id) if idx.isValid(): self._workflows[idx.row()] = workflow self.dataChanged.emit(idx, idx) else: self.append(workflow) + return idx + + def _process_remote_workflow(self, id: str): + graph = self._connection.workflows[id] + self._process_workflow(id, WorkflowSource.remote, graph) + + def _process_file(self, file: Path): + with file.open("r") as f: + graph = json.load(f) + self._process_workflow(file.stem, WorkflowSource.local, graph, file) def rowCount(self, parent=QModelIndex()): return len(self._workflows) @@ -121,7 +133,7 @@ def append(self, item: CustomWorkflow): self.endInsertRows() def add_from_document(self, id: str, graph: dict): - self.append(self._create_workflow(id, WorkflowSource.document, graph)) + self._process_workflow(id, WorkflowSource.document, graph) def remove(self, id: str): idx = self.find_index(id) @@ -154,7 +166,7 @@ def save_as(self, id: str, graph: dict): self._folder.mkdir(exist_ok=True) path = self._folder / f"{id}.json" path.write_text(json.dumps(graph, indent=2)) - self.append(self._create_workflow(id, WorkflowSource.local, graph, path)) + self._process_workflow(id, WorkflowSource.local, graph, path) return id def import_file(self, filepath: Path): @@ -336,12 +348,16 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job jobs.job_finished.connect(self._handle_job_finished) workflows.dataChanged.connect(self._update_workflow) - workflows.rowsInserted.connect(self._set_default_workflow) + workflows.loaded.connect(self._set_default_workflow) self._set_default_workflow() def _set_default_workflow(self): if not self.workflow_id and len(self._workflows) > 0: self.workflow_id = self._workflows[0].id + else: + current_index = self._workflows.find_index(self.workflow_id) + if current_index.isValid(): + self._update_workflow(current_index, QModelIndex()) def _update_workflow(self, idx: QModelIndex, _: QModelIndex): wf = self._workflows[idx.row()] @@ -358,10 +374,13 @@ def _set_workflow_id(self, id: str): self._workflow_id = id self.workflow_id_changed.emit(id) self.modified.emit(self, "workflow_id") - self._update_workflow(self._workflows.find_index(id), QModelIndex()) + index = self._workflows.find_index(id) + if index.isValid(): # might be invalid when loading document before connecting + self._update_workflow(index, QModelIndex()) def set_graph(self, id: str, graph: dict): if self._workflows.find(id) is None: + id = "Document Workflow (embedded)" self._workflows.add_from_document(id, graph) self.workflow_id = id diff --git a/ai_diffusion/persistence.py b/ai_diffusion/persistence.py index b01dec15e8..348e62cfbc 100644 --- a/ai_diffusion/persistence.py +++ b/ai_diffusion/persistence.py @@ -8,8 +8,9 @@ from PyQt5.QtWidgets import QMessageBox from .api import InpaintMode, FillMode -from .image import Bounds, Image, ImageCollection, ImageFileFormat +from .image import ImageCollection from .model import Model, InpaintContext +from .custom_workflow import CustomWorkspace from .control import ControlLayer, ControlLayerList from .region import RootRegion, Region from .jobs import Job, JobKind, JobParams, JobQueue @@ -132,7 +133,7 @@ def _save(self): state["upscale"] = _serialize(model.upscale) state["live"] = _serialize(model.live) state["animation"] = _serialize(model.animation) - state["custom"] = _serialize(model.custom) + state["custom"] = _serialize_custom(model.custom) state["history"] = [asdict(h) for h in self._history] state["root"] = _serialize(model.regions) state["control"] = [_serialize(c) for c in model.regions.control] @@ -151,7 +152,7 @@ def _load(self, model: Model, state_bytes: bytes): _deserialize(model.upscale, state.get("upscale", {})) _deserialize(model.live, state.get("live", {})) _deserialize(model.animation, state.get("animation", {})) - _deserialize(model.custom, state.get("custom", {})) + _deserialize_custom(model.custom, state.get("custom", {})) _deserialize(model.regions, state.get("root", {})) for control_state in state.get("control", []): _deserialize(model.regions.control.emplace(), control_state) @@ -264,6 +265,21 @@ def converter(type, value): return deserialize(obj, data, converter) +def _serialize_custom(custom: CustomWorkspace): + result = _serialize(custom) + result["workflow_id"] = custom.workflow_id + result["graph"] = custom.graph.root if custom.graph else None + return result + + +def _deserialize_custom(custom: CustomWorkspace, data: dict[str, Any]): + _deserialize(custom, data) + workflow_id = data.get("workflow_id", "") + graph = data.get("graph", None) + if workflow_id and graph: + custom.set_graph(workflow_id, graph) + + def _find_annotation(document, name: str): if result := document.find_annotation(name): return result diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index 5624dbca6b..85def2632b 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -87,7 +87,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) - self._label.setFixedWidth(40) + self._label.setFixedWidth(32) self._label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(self._widget) layout.addWidget(self._label) @@ -135,7 +135,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) - self._label.setFixedWidth(40) + self._label.setFixedWidth(32) self._label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(self._widget) layout.addWidget(self._label) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 8679928d79..ebf1693d81 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -83,14 +83,24 @@ def test_collection(tmp_path: Path): connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected) collection = WorkflowCollection(connection, tmp_path) + events = [] + assert len(collection) == 0 + + def on_loaded(): + events.append("loaded") + + collection.loaded.connect(on_loaded) + doc_graph = {"0": {"class_type": "D1", "inputs": {}}} + collection.add_from_document("doc1", doc_graph) + connection.state = ConnectionState.connected - assert len(collection) == 3 + assert len(collection) == 4 + assert events == ["loaded"] _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1) _assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2) _assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph) - - events = [] + _assert_has_workflow(collection, "doc1", WorkflowSource.document, doc_graph) def on_begin_insert(index, first, last): events.append(("begin_insert", first)) @@ -109,15 +119,13 @@ def on_data_changed(start, end): connection_workflows["connection2"] = connection2_graph connection.workflow_published.emit("connection2") - assert len(collection) == 4 + assert len(collection) == 5 _assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph) file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}} - collection.set_graph(collection.index(0), file1_graph_changed) + collection.set_graph(collection.find_index("file1"), file1_graph_changed) _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1) - assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] - - collection.add_from_document("doc1", {"0": {"class_type": "D1", "inputs": {}}}) + assert events == ["loaded", ("begin_insert", 4), "end_insert", ("data_changed", 1)] sorted = SortedWorkflows(collection) assert sorted[0].source is WorkflowSource.document @@ -207,7 +215,7 @@ def test_workspace(): } } workspace.set_graph("doc1", doc_graph) - assert workspace.workflow_id == "doc1" + assert workspace.workflow_id == "Document Workflow (embedded)" assert workspace.workflow and workspace.workflow.source is WorkflowSource.document assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter" assert workspace.metadata[0].name == "param2"