diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py
index 56ad6c17bb..c473fb24c1 100644
--- a/ai_diffusion/api.py
+++ b/ai_diffusion/api.py
@@ -18,6 +18,7 @@ class WorkflowKind(Enum):
upscale_simple = 4
upscale_tiled = 5
control_image = 6
+ custom = 7
@dataclass
@@ -144,6 +145,12 @@ def clamped(self):
return params
+@dataclass
+class CustomWorkflowInput:
+ workflow: dict
+ params: dict[str, Any]
+
+
@dataclass
class WorkflowInput:
kind: WorkflowKind
@@ -157,6 +164,7 @@ class WorkflowInput:
control_mode: ControlMode = ControlMode.reference
batch_count: int = 1
nsfw_filter: float = 0.0
+ custom_workflow: CustomWorkflowInput | None = None
@property
def extent(self):
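# Usage sketch for the new custom kind: the request carries the raw ComfyUI graph
# plus collected parameter values. This mirrors the construction in model.py further
# down in this diff; the graph contents and parameter values here are hypothetical.
from ai_diffusion.api import (
    CustomWorkflowInput, ImageInput, SamplingInput, WorkflowInput, WorkflowKind,
)
from ai_diffusion.image import Extent

graph = {"1": {"class_type": "ETN_SendImageWebSocket", "inputs": {"images": ["2", 0]}}}
job = WorkflowInput(
    WorkflowKind.custom,
    ImageInput.from_extent(Extent(512, 512)),
    sampling=SamplingInput("custom", "custom", 1, 1000, seed=42),
    custom_workflow=CustomWorkflowInput(workflow=graph, params={"Strength": 0.7}),
)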
diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py
index 76568558be..a3158194bf 100644
--- a/ai_diffusion/client.py
+++ b/ai_diffusion/client.py
@@ -25,6 +25,7 @@ class ClientEvent(Enum):
disconnected = 5
queued = 6
upload = 7
+ published = 8
class ClientMessage(NamedTuple):
@@ -32,7 +33,7 @@ class ClientMessage(NamedTuple):
job_id: str = ""
progress: float = 0
images: ImageCollection | None = None
- result: dict | None = None
+ result: dict | SharedWorkflow | None = None
error: str | None = None
@@ -68,6 +69,11 @@ def parse(data: dict):
return DeviceInfo("cpu", "unknown", 0)
+class SharedWorkflow(NamedTuple):
+ publisher: str
+ workflow: dict
+
+
class CheckpointInfo(NamedTuple):
filename: str
arch: Arch
diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py
index 8f61590d7e..645511ff26 100644
--- a/ai_diffusion/comfy_client.py
+++ b/ai_diffusion/comfy_client.py
@@ -11,7 +11,7 @@
from .api import WorkflowInput
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
-from .client import TranslationPackage, filter_supported_styles, loras_to_upload
+from .client import SharedWorkflow, TranslationPackage, filter_supported_styles, loras_to_upload
from .files import FileFormat
from .image import Image, ImageCollection
from .network import RequestManager, NetworkError
@@ -87,16 +87,18 @@ class ComfyClient(Client):
default_url = "http://127.0.0.1:8188"
- _requests = RequestManager()
- _id: str
- _messages: asyncio.Queue[ClientMessage]
- _queue: asyncio.Queue[JobInfo]
- _jobs: deque[JobInfo]
- _active: Optional[JobInfo] = None
- _job_runner: asyncio.Task
- _websocket_listener: asyncio.Task
- _supported_archs: list[Arch]
- _supported_languages: list[TranslationPackage]
+ def __init__(self, url):
+ self.url = url
+ self.models = ClientModels()
+ self._requests = RequestManager()
+ self._id = str(uuid.uuid4())
+ self._active: Optional[JobInfo] = None
+ self._supported_archs: list[Arch] = []
+ self._supported_languages: list[TranslationPackage] = []
+ self._messages = asyncio.Queue()
+ self._queue = asyncio.Queue()
+ self._jobs = deque()
+ self._is_connected = False
@staticmethod
async def connect(url=default_url, access_token=""):
@@ -124,7 +126,7 @@ async def connect(url=default_url, access_token=""):
# Check for required and optional model resources
models = client.models
- models.node_inputs = {name: nodes[name]["input"].get("required", None) for name in nodes}
+ models.node_inputs = {name: nodes[name]["input"] for name in nodes}
available_resources = client.models.resources = {}
clip_models = nodes["DualCLIPLoader"]["input"]["required"]["clip_name1"][0]
@@ -175,15 +177,6 @@ async def connect(url=default_url, access_token=""):
_ensure_supported_style(client)
return client
- def __init__(self, url):
- self.url = url
- self.models = ClientModels()
- self._id = str(uuid.uuid4())
- self._messages = asyncio.Queue()
- self._queue = asyncio.Queue()
- self._jobs = deque()
- self._is_connected = False
-
async def _get(self, op: str):
return await self._requests.get(f"{self.url}/{op}")
@@ -237,6 +230,7 @@ async def _listen(self):
f"{url}/ws?clientId={self._id}", max_size=2**30, read_limit=2**30, ping_timeout=60
):
try:
+ await self._subscribe_workflows()
await self._listen_websocket(websocket)
except websockets_exceptions.ConnectionClosedError as e:
log.warning(f"Websocket connection closed: {str(e)}")
@@ -287,12 +281,20 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
if msg["type"] == "executing" and msg["data"]["node"] is None:
job_id = msg["data"]["prompt_id"]
- if self._clear_job(job_id):
- # Usually we don't get here because finished, interrupted or error is sent first.
- # But it may happen if the entire execution is cached and no images are sent.
+ if local_id := self._clear_job(job_id):
if len(images) == 0:
+ # It may happen if the entire execution is cached and no images are sent.
images = last_images
- await self._report(ClientEvent.finished, job_id, 1, images=images)
+ if len(images) == 0:
+ # Still no images. Potential scenario: execution cached, but previous
+ # generation happened before the client was connected.
+ err = "No new images were generated because the inputs did not change."
+ await self._report(ClientEvent.error, local_id, error=err)
+ else:
+ last_images = images
+ await self._report(
+ ClientEvent.finished, local_id, 1, images=images, result=result
+ )
elif msg["type"] in ("execution_cached", "executing", "progress"):
if self._active is not None and progress is not None:
@@ -308,12 +310,6 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
pose_json = _extract_pose_json(msg)
if job and pose_json:
result = pose_json
- elif job and _validate_executed_node(msg, len(images)):
- self._clear_job(job.remote_id)
- last_images = images
- await self._report(
- ClientEvent.finished, job.local_id, 1, images=images, result=result
- )
if msg["type"] == "execution_error":
job = self._get_active_job(msg["data"]["prompt_id"])
@@ -322,7 +318,12 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
traceback = msg["data"].get("traceback", "no traceback")
log.error(f"Job {job} failed: {error}\n{traceback}")
self._clear_job(job.remote_id)
- await self._report(ClientEvent.error, job.local_id, 0, error=error)
+ await self._report(ClientEvent.error, job.local_id, error=error)
+
+ if msg["type"] == "etn_workflow_published":
+ name = f"{msg['data']['publisher']['name']} ({msg['data']['publisher']['id']})"
+ workflow = SharedWorkflow(name, msg["data"]["workflow"])
+ await self._report(ClientEvent.published, "", result=workflow)
async def listen(self):
self._is_connected = True
@@ -358,6 +359,7 @@ async def disconnect(self):
self._job_runner,
self._websocket_listener,
self._report(ClientEvent.disconnected, ""),
+ self._unsubscribe_workflows(),
)
async def try_inspect(self, folder_name: str):
@@ -431,6 +433,18 @@ async def translate(self, text: str, lang: str):
log.error(f"Could not translate text: {str(e)}")
return text
+ async def _subscribe_workflows(self):
+ try:
+ await self._post("api/etn/workflow/subscribe", {"client_id": self._id})
+ except Exception as e:
+ log.error(f"Couldn't subscribe to shared workflows: {str(e)}")
+
+ async def _unsubscribe_workflows(self):
+ try:
+ await self._post("api/etn/workflow/unsubscribe", {"client_id": self._id})
+ except Exception as e:
+ log.error(f"Couldn't unsubscribe from shared workflows: {str(e)}")
+
def supports_arch(self, arch: Arch):
return arch in self._supported_archs
@@ -501,9 +515,10 @@ async def _start_job(self, remote_id: str):
def _clear_job(self, job_remote_id: str | asyncio.Future | None):
if self._active is not None and self._active.remote_id == job_remote_id:
+ result = self._active.local_id
self._active = None
- return True
- return False
+ return result
+ return None
def _check_workload(self, sdver: Arch) -> list[MissingResource]:
models = self.models
@@ -719,8 +734,11 @@ def _validate_executed_node(msg: dict, image_count: int):
images = output["images"]
if len(images) != image_count: # not critical
log.warning(f"Received number of images does not match: {len(images)} != {image_count}")
- if len(images) > 0 and "source" in images[0] and images[0]["type"] == "output":
+ if image_count == 0 or len(images) == 0:
+            log.warning("Received no images (execution cached?)")
+ return False
+ if "source" in images[0] and images[0]["type"] == "output":
return True
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
- return False
+ return False
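# Sketch of how the new etn_workflow_published message surfaces to a consumer:
# the listener above reports a ClientEvent.published message whose result is a
# SharedWorkflow tuple. Assumes an already-connected ComfyClient instance.
from ai_diffusion.client import ClientEvent, SharedWorkflow

async def watch_shared_workflows(client):
    async for msg in client.listen():
        if msg.event is ClientEvent.published and isinstance(msg.result, SharedWorkflow):
            publisher, workflow = msg.result
            print(f"{publisher} published a workflow with {len(workflow)} nodes")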
diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py
index a374aed535..f88ef0b914 100644
--- a/ai_diffusion/comfy_workflow.py
+++ b/ai_diffusion/comfy_workflow.py
@@ -1,7 +1,8 @@
from __future__ import annotations
+from copy import deepcopy
from enum import Enum
from pathlib import Path
-from typing import NamedTuple, Tuple, Literal, overload, Any
+from typing import NamedTuple, Tuple, Literal, TypeVar, overload, Any
from uuid import uuid4
import json
@@ -19,32 +20,87 @@ class Output(NamedTuple):
output: int
+T = TypeVar("T")
Output2 = Tuple[Output, Output]
Output3 = Tuple[Output, Output, Output]
Output4 = Tuple[Output, Output, Output, Output]
+Input = int | float | bool | str | Output
-class ComfyWorkflow:
- """Builder for workflows which can be sent to the ComfyUI prompt API."""
+class ComfyNode(NamedTuple):
+ id: int
+ type: str
+ inputs: dict[str, Input]
+
+ @overload
+ def input(self, key: str, default: T) -> T: ...
+
+ @overload
+ def input(self, key: str, default: None = None) -> Input | None: ...
+
+ def input(self, key: str, default: T | None = None) -> T | Input | None:
+ result = self.inputs.get(key, default)
+ assert (
+ default is None
+ or type(result) == type(default)
+ or (isnumber(result) and isnumber(default))
+ )
+ return result
+
+ def output(self, index=0) -> Output:
+ return Output(int(self.id), index)
- root: dict[str, dict]
- images: dict[str, Image]
- node_count = 0
- sample_count = 0
- _cache: dict[str, Output | Output2 | Output3 | Output4]
- _nodes_required_inputs: dict[str, dict[str, Any]]
- _run_mode: ComfyRunMode
+def isnumber(x):
+ return isinstance(x, (int, float))
+
+
+class ComfyWorkflow:
+ """Builder for workflows which can be sent to the ComfyUI prompt API."""
def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server):
- self.root = {}
- self.images = {}
- self._cache = {}
- self._nodes_required_inputs = node_inputs or {}
- self._run_mode = run_mode
+ self.root: dict[str, dict] = {}
+ self.images: dict[str, Image] = {}
+ self.node_count = 0
+ self.sample_count = 0
+ self._cache: dict[str, Output | Output2 | Output3 | Output4] = {}
+ self._nodes_inputs: dict[str, dict[str, Any]] = node_inputs or {}
+ self._run_mode: ComfyRunMode = run_mode
+
+ @staticmethod
+ def import_graph(existing: dict, node_inputs: dict):
+ w = ComfyWorkflow(node_inputs)
+ existing = _convert_ui_workflow(existing, node_inputs)
+ node_map: dict[str, str] = {}
+ queue = list(existing.keys())
+ while queue:
+ id = queue.pop(0)
+ node = deepcopy(existing[id])
+ if node_inputs and node["class_type"] not in node_inputs:
+ raise ValueError(
+ f"Workflow contains a node of type {node['class_type']} which is not installed on the ComfyUI server."
+ )
+ edges = [e for e in node["inputs"].values() if isinstance(e, list)]
+ if any(e[0] not in node_map for e in edges):
+ queue.append(id) # requeue node if an input is not yet mapped
+ continue
+
+ for e in edges:
+ e[0] = node_map[e[0]]
+ node_map[id] = str(w.node_count)
+ w.root[str(w.node_count)] = node
+ w.node_count += 1
+ return w
+
+ @staticmethod
+ def from_dict(existing: dict):
+ w = ComfyWorkflow()
+ w.root = existing
+ w.node_count = len(w.root)
+ return w
def add_default_values(self, node_name: str, args: dict):
- if node_inputs := self._nodes_required_inputs.get(node_name, None):
+ if node_inputs := _inputs_for_node(self._nodes_inputs, node_name, "required"):
for k, v in node_inputs.items():
if k not in args:
if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0:
@@ -57,6 +113,11 @@ def add_default_values(self, node_name: str, args: dict):
args[k] = default
return args
+ def input_type(self, class_type: str, input_name: str) -> tuple | None:
+ if inputs := _inputs_for_node(self._nodes_inputs, class_type):
+ return inputs.get(input_name)
+ return None
+
def dump(self, filepath: str | Path):
filepath = Path(filepath)
if filepath.suffix != ".json":
@@ -102,11 +163,50 @@ def add_cached(self, class_type: str, output_count: Literal[1] | Literal[3], **i
self._cache[key] = result
return result
+ def remove(self, node_id: int):
+ del self.root[str(node_id)]
+
+ def node(self, node_id: int):
+ inputs = self.root[str(node_id)]["inputs"]
+ inputs = {
+ k: Output(int(v[0]), v[1]) if isinstance(v, list) else v for k, v in inputs.items()
+ }
+ return ComfyNode(node_id, self.root[str(node_id)]["class_type"], inputs)
+
+ def copy(self, node: ComfyNode):
+ return self.add(node.type, 1, **node.inputs)
+
+ def find(self, type: str):
+ return (self.node(int(k)) for k, v in self.root.items() if v["class_type"] == type)
+
+ def find_connected(self, output: Output):
+ for node in self:
+ for input_name, input_value in node.inputs.items():
+ if input_value == output:
+ yield node, input_name
+
+ def guess_sample_count(self):
+ self.sample_count = sum(
+ int(value)
+ for node in self
+ for name, value in node.inputs.items()
+ if name == "steps" and isinstance(value, (int, float))
+ )
+ return self.sample_count
+
+ def __iter__(self):
+ return iter(self.node(int(k)) for k in self.root.keys())
+
+ def __contains__(self, node: ComfyNode):
+ return any(n == node for n in self)
+
def _add_image(self, image: Image):
id = str(uuid4())
self.images[id] = image
return id
+ # Nodes
+
def ksampler(
self,
model: Output,
@@ -726,7 +826,7 @@ def load_mask(self, mask: Image):
def send_image(self, image: Output):
if self._run_mode is ComfyRunMode.runtime:
return self.add("ETN_ReturnImage", 1, images=image)
- return self.add("ETN_SendImageWebSocket", 1, images=image)
+ return self.add("ETN_SendImageWebSocket", 1, images=image, format="PNG")
def save_image(self, image: Output, prefix: str):
return self.add("SaveImage", 1, images=image, filename_prefix=prefix)
@@ -760,3 +860,62 @@ def estimate_pose(self, image: Output, resolution: int):
# use smaller model, but it requires onnxruntime, see #630
mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)
+
+
+def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""):
+ inputs = node_inputs.get(node_name)
+ if inputs is None:
+ return None
+ if filter:
+ return inputs.get(filter)
+ result = inputs.get("required", {})
+ result.update(inputs.get("optional", {}))
+ return result
+
+
+def _convert_ui_workflow(w: dict, node_inputs: dict):
+ version = w.get("version")
+ nodes = w.get("nodes")
+ links = w.get("links")
+ if not (version and nodes and links):
+ return w
+
+ if not node_inputs:
+ raise ValueError("An active ComfyUI connection is required to convert a UI workflow file.")
+
+ primitives = {}
+ for node in nodes:
+ if node["type"] == "PrimitiveNode":
+ primitives[node["id"]] = node["widgets_values"][0]
+
+ r = {}
+ for node in nodes:
+ id = node["id"]
+ type = node["type"]
+ if type == "PrimitiveNode":
+ continue
+
+ inputs = {}
+ fields = _inputs_for_node(node_inputs, type)
+ if fields is None:
+ raise ValueError(
+ f"Workflow uses node type {type}, but it is not installed on the ComfyUI server."
+ )
+ widget_count = 0
+ for field_name, field in fields.items():
+ field_type = field[0]
+ if field_type in ["INT", "FLOAT", "BOOL", "STRING"] or isinstance(field_type, list):
+ inputs[field_name] = node["widgets_values"][widget_count]
+ widget_count += 1
+ for connection in node["inputs"]:
+ if connection["name"] == field_name and connection["link"] is not None:
+ link = next(l for l in links if l[0] == connection["link"])
+ prim = primitives.get(link[1])
+ if prim is not None:
+ inputs[field_name] = prim
+ else:
+ inputs[field_name] = [link[1], link[2]]
+ break
+ r[id] = {"class_type": type, "inputs": inputs}
+
+ return r
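# Usage sketch for the new graph import and inspection helpers. Assumes `graph`
# is a ComfyUI API-format dict and `client` is a connected ComfyClient providing
# the node input schema (client.models.node_inputs).
from ai_diffusion.comfy_workflow import ComfyWorkflow

w = ComfyWorkflow.import_graph(graph, client.models.node_inputs)
for node in w:                          # iterate ComfyNode entries
    print(node.id, node.type, node.inputs)
for sampler in w.find("KSampler"):      # look up nodes by class type
    steps = sampler.input("steps", 20)  # literal value from the graph, or default
print(w.guess_sample_count())           # sums literal "steps" inputs over all nodes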
diff --git a/ai_diffusion/connection.py b/ai_diffusion/connection.py
index 302b8d06b8..8c345d6a5e 100644
--- a/ai_diffusion/connection.py
+++ b/ai_diffusion/connection.py
@@ -4,7 +4,7 @@
from PyQt5.QtGui import QDesktopServices
import asyncio
-from .client import Client, ClientMessage, ClientEvent, DeviceInfo
+from .client import Client, ClientMessage, ClientEvent, DeviceInfo, SharedWorkflow
from .comfy_client import ComfyClient
from .cloud_client import CloudClient
from .network import NetworkError
@@ -36,12 +36,16 @@ class Connection(QObject, ObservableProperties):
error_changed = pyqtSignal(str)
models_changed = pyqtSignal()
message_received = pyqtSignal(ClientMessage)
-
- _client: Client | None = None
- _task: asyncio.Task | None = None
+ workflow_published = pyqtSignal(str)
def __init__(self):
super().__init__()
+
+ self._client: Client | None = None
+ self._task: asyncio.Task | None = None
+ self._workflows: dict[str, dict] = {}
+ self._temporary_disconnect = False
+
settings.changed.connect(self._handle_settings_changed)
self._update_state()
@@ -151,26 +155,20 @@ def user(self):
if client := self.client_if_connected:
return client.user
+ @property
+ def workflows(self):
+ return self._workflows
+
async def _handle_messages(self):
client = self._client
- temporary_disconnect = False
+ self._temporary_disconnect = False
assert client is not None
try:
async with client:
async for msg in client.listen():
try:
- if msg.event is ClientEvent.error and not msg.job_id:
- self.error = _("Error communicating with server: ") + str(msg.error)
- elif msg.event is ClientEvent.disconnected:
- temporary_disconnect = True
- self.error = _("Disconnected from server, trying to reconnect...")
- elif msg.event is ClientEvent.connected:
- if temporary_disconnect:
- temporary_disconnect = False
- self.error = ""
- else:
- self.message_received.emit(msg)
+ self._handle_message(msg)
except asyncio.CancelledError:
break
except Exception as e:
@@ -179,6 +177,24 @@ async def _handle_messages(self):
except asyncio.CancelledError:
pass # shutdown
+ def _handle_message(self, msg: ClientMessage):
+ match msg:
+ case (ClientEvent.error, "", *_):
+ self.error = _("Error communicating with server: ") + str(msg.error)
+ case (ClientEvent.disconnected, *_):
+ self._temporary_disconnect = True
+ self.error = _("Disconnected from server, trying to reconnect...")
+ case (ClientEvent.connected, *_):
+ if self._temporary_disconnect:
+ self._temporary_disconnect = False
+ self.error = ""
+ case (ClientEvent.published, *_):
+ assert isinstance(msg.result, SharedWorkflow)
+ self._workflows[msg.result.publisher] = msg.result.workflow
+ self.workflow_published.emit(msg.result.publisher)
+ case _:
+ self.message_received.emit(msg)
+
def _update_state(self):
if (
self.state in [ConnectionState.disconnected, ConnectionState.error]
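# The match statement in _handle_message relies on ClientMessage being a NamedTuple:
# sequence patterns destructure it positionally (event first, then job_id). A small
# illustration with hypothetical messages:
from ai_diffusion.client import ClientEvent, ClientMessage

def describe(msg: ClientMessage) -> str:
    match msg:
        case (ClientEvent.error, "", *_):      # error without a job id
            return f"server error: {msg.error}"
        case (ClientEvent.error, job_id, *_):  # error tied to a specific job
            return f"job {job_id} failed: {msg.error}"
        case _:
            return msg.event.name

print(describe(ClientMessage(ClientEvent.error, "", error="timeout")))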
diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py
new file mode 100644
index 0000000000..804f15a9c1
--- /dev/null
+++ b/ai_diffusion/custom_workflow.py
@@ -0,0 +1,493 @@
+import asyncio
+import json
+
+from enum import Enum
+from copy import copy
+from dataclasses import dataclass
+from typing import Any, Awaitable, Callable, NamedTuple, Literal, TYPE_CHECKING
+from pathlib import Path
+from PyQt5.QtCore import Qt, QObject, QUuid, QAbstractListModel, QSortFilterProxyModel, QModelIndex
+from PyQt5.QtCore import pyqtSignal
+
+from .api import WorkflowInput
+from .comfy_workflow import ComfyWorkflow, ComfyNode
+from .connection import Connection, ConnectionState
+from .image import Bounds, Image
+from .jobs import Job, JobParams, JobQueue, JobKind
+from .properties import Property, ObservableProperties
+from .style import Styles
+from .util import user_data_dir, client_logger as log
+from .ui import theme
+from . import eventloop
+
+if TYPE_CHECKING:
+ from .layer import LayerManager
+
+
+class WorkflowSource(Enum):
+ document = 0
+ remote = 1
+ local = 2
+
+
+@dataclass
+class CustomWorkflow:
+ id: str
+ source: WorkflowSource
+ workflow: ComfyWorkflow
+ path: Path | None = None
+
+ @property
+ def name(self):
+ return self.id.removesuffix(".json")
+
+
+class WorkflowCollection(QAbstractListModel):
+
+ _icon_local = theme.icon("file-json")
+ _icon_remote = theme.icon("web-connection")
+ _icon_document = theme.icon("file-kra")
+
+ loaded = pyqtSignal()
+
+ def __init__(self, connection: Connection, folder: Path | None = None):
+ super().__init__()
+ self._connection = connection
+ self._folder = folder or user_data_dir / "workflows"
+ self._workflows: list[CustomWorkflow] = []
+ self._pending_workflows: list[tuple[str, WorkflowSource, dict]] = []
+
+ self._connection.state_changed.connect(self._handle_connection)
+ self._connection.workflow_published.connect(self._process_remote_workflow)
+ self._handle_connection(self._connection.state)
+
+ def _handle_connection(self, state: ConnectionState):
+ if state in (ConnectionState.connected, ConnectionState.disconnected):
+ self.clear()
+
+ if state is ConnectionState.connected:
+ for id, source, graph in self._pending_workflows:
+ self._process_workflow(id, source, graph)
+ self._pending_workflows.clear()
+
+ for file in self._folder.glob("*.json"):
+ try:
+ self._process_file(file)
+ except Exception as e:
+ log.exception(f"Error loading workflow from {file}: {e}")
+
+ for wf in self._connection.workflows.keys():
+ self._process_remote_workflow(wf)
+
+ self.loaded.emit()
+
+ def _node_inputs(self):
+ return self._connection.client.models.node_inputs
+
+ def _process_workflow(
+ self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
+ ):
+ if self._connection.state is not ConnectionState.connected:
+ self._pending_workflows.append((id, source, graph))
+ return
+
+ comfy_flow = ComfyWorkflow.import_graph(graph, self._node_inputs())
+ workflow = CustomWorkflow(id, source, comfy_flow, path)
+ idx = self.find_index(workflow.id)
+ if idx.isValid():
+ self._workflows[idx.row()] = workflow
+ self.dataChanged.emit(idx, idx)
+ else:
+ self.append(workflow)
+ return idx
+
+ def _process_remote_workflow(self, id: str):
+ graph = self._connection.workflows[id]
+ self._process_workflow(id, WorkflowSource.remote, graph)
+
+ def _process_file(self, file: Path):
+ with file.open("r") as f:
+ graph = json.load(f)
+ self._process_workflow(file.stem, WorkflowSource.local, graph, file)
+
+ def rowCount(self, parent=QModelIndex()):
+ return len(self._workflows)
+
+ def data(self, index: QModelIndex, role: int = 0):
+ if role == Qt.ItemDataRole.DisplayRole:
+ return self._workflows[index.row()].name
+ if role == Qt.ItemDataRole.UserRole:
+ return self._workflows[index.row()].id
+ if role == Qt.ItemDataRole.DecorationRole:
+ source = self._workflows[index.row()].source
+ if source is WorkflowSource.document:
+ return self._icon_document
+ if source is WorkflowSource.remote:
+ return self._icon_remote
+ return self._icon_local
+
+ def append(self, item: CustomWorkflow):
+ end = len(self._workflows)
+ self.beginInsertRows(QModelIndex(), end, end)
+ self._workflows.append(item)
+ self.endInsertRows()
+
+ def add_from_document(self, id: str, graph: dict):
+ self._process_workflow(id, WorkflowSource.document, graph)
+
+ def remove(self, id: str):
+ idx = self.find_index(id)
+ if idx.isValid():
+ wf = self._workflows[idx.row()]
+ if wf.source is WorkflowSource.local and wf.path is not None:
+ wf.path.unlink()
+ self.beginRemoveRows(QModelIndex(), idx.row(), idx.row())
+ self._workflows.pop(idx.row())
+ self.endRemoveRows()
+
+ def clear(self):
+ if len(self._workflows) > 0:
+ self.beginResetModel()
+ self._workflows.clear()
+ self.endResetModel()
+
+ def set_graph(self, index: QModelIndex, graph: dict):
+ wf = self._workflows[index.row()]
+ wf.workflow = ComfyWorkflow.import_graph(graph, self._node_inputs())
+ self.dataChanged.emit(index, index)
+
+ def save_as(self, id: str, graph: dict):
+ if self.find(id) is not None:
+ suffix = 1
+ while self.find(f"{id} ({suffix})"):
+ suffix += 1
+ id = f"{id} ({suffix})"
+
+ self._folder.mkdir(exist_ok=True)
+ path = self._folder / f"{id}.json"
+ path.write_text(json.dumps(graph, indent=2))
+ self._process_workflow(id, WorkflowSource.local, graph, path)
+ return id
+
+ def import_file(self, filepath: Path):
+ try:
+ with filepath.open("r") as f:
+ graph = json.load(f)
+ try:
+ ComfyWorkflow.import_graph(graph, self._node_inputs())
+ except Exception as e:
+ raise RuntimeError(f"This is not a supported workflow file ({e})")
+ return self.save_as(filepath.stem, graph)
+ except Exception as e:
+ raise RuntimeError(f"Error importing workflow from {filepath}: {e}")
+
+ def find_index(self, id: str):
+ for i, wf in enumerate(self._workflows):
+ if wf.id == id:
+ return self.index(i)
+ return QModelIndex()
+
+ def find(self, id: str):
+ idx = self.find_index(id)
+ if idx.isValid():
+ return self._workflows[idx.row()]
+ return None
+
+ def get(self, id: str):
+ result = self.find(id)
+ if result is None:
+ raise KeyError(f"Workflow {id} not found")
+ return result
+
+ def __getitem__(self, index: int):
+ return self._workflows[index]
+
+ def __len__(self):
+ return len(self._workflows)
+
+
+class SortedWorkflows(QSortFilterProxyModel):
+ def __init__(self, workflows: WorkflowCollection):
+ super().__init__()
+ self._workflows = workflows
+ self.setSourceModel(workflows)
+ self.setSortCaseSensitivity(Qt.CaseSensitivity.CaseInsensitive)
+ self.sort(0)
+
+ def lessThan(self, left: QModelIndex, right: QModelIndex):
+ l = self._workflows[left.row()]
+ r = self._workflows[right.row()]
+ if l.source is r.source:
+ return l.name < r.name
+ return l.source.value < r.source.value
+
+ def __getitem__(self, index: int):
+ idx = self.mapToSource(self.index(index, 0)).row()
+ return self._workflows[idx]
+
+
+class ParamKind(Enum):
+ image_layer = 0
+ mask_layer = 1
+ number_int = 2
+ number_float = 3
+ toggle = 4
+ text = 5
+ prompt_positive = 6
+ prompt_negative = 7
+ choice = 8
+ style = 9
+
+
+class CustomParam(NamedTuple):
+ kind: ParamKind
+ name: str
+ default: Any | None = None
+ min: int | float | None = None
+ max: int | float | None = None
+ choices: list[str] | None = None
+
+
+def workflow_parameters(w: ComfyWorkflow):
+ text_types = ("text", "prompt (positive)", "prompt (negative)")
+ for node in w:
+ match (node.type, node.input("type", "")):
+ case ("ETN_KritaStyle", _):
+ name = node.input("name", "Style")
+ yield CustomParam(ParamKind.style, name, node.input("sampler_preset", "auto"))
+ case ("ETN_KritaImageLayer", _):
+ name = node.input("name", "Image")
+ yield CustomParam(ParamKind.image_layer, name)
+ case ("ETN_KritaMaskLayer", _):
+ name = node.input("name", "Mask")
+ yield CustomParam(ParamKind.mask_layer, name)
+ case ("ETN_Parameter", "number (integer)"):
+ name = node.input("name", "Parameter")
+ default = node.input("default", 0)
+ min = node.input("min", -(2**31))
+ max = node.input("max", 2**31)
+ yield CustomParam(ParamKind.number_int, name, default=default, min=min, max=max)
+ case ("ETN_Parameter", "number"):
+ name = node.input("name", "Parameter")
+ default = node.input("default", 0.0)
+ min = node.input("min", 0.0)
+ max = node.input("max", 1.0)
+ yield CustomParam(ParamKind.number_float, name, default=default, min=min, max=max)
+ case ("ETN_Parameter", "toggle"):
+ name = node.input("name", "Parameter")
+ default = node.input("default", False)
+ yield CustomParam(ParamKind.toggle, name, default=default)
+ case ("ETN_Parameter", type) if type in text_types:
+ name = node.input("name", "Parameter")
+ default = node.input("default", "")
+ match type:
+ case "text":
+ yield CustomParam(ParamKind.text, name, default=default)
+ case "prompt (positive)":
+ yield CustomParam(ParamKind.prompt_positive, name, default=default)
+ case "prompt (negative)":
+ yield CustomParam(ParamKind.prompt_negative, name, default=default)
+ case ("ETN_Parameter", "choice"):
+ name = node.input("name", "Parameter")
+ default = node.input("default", "")
+ if choices := _get_choices(w, node):
+ yield CustomParam(ParamKind.choice, name, choices=choices, default=default)
+ else:
+ yield CustomParam(ParamKind.text, name, default=default)
+ case ("ETN_Parameter", unknown_type) if unknown_type != "auto":
+ unknown = node.input("name", "?") + ": " + unknown_type
+ log.warning(f"Custom workflow has an unsupported parameter type {unknown}")
+
+
+def _get_choices(w: ComfyWorkflow, node: ComfyNode):
+ connected, input_name = next(w.find_connected(node.output()), (None, ""))
+ if connected:
+ if input_type := w.input_type(connected.type, input_name):
+ if isinstance(input_type[0], list):
+ return input_type[0]
+ return None
+
+
+ImageGenerator = Callable[[WorkflowInput | None], Awaitable[None | Literal[False] | WorkflowInput]]
+
+
+class CustomGenerationMode(Enum):
+ regular = 0
+ live = 1
+
+
+class CustomWorkspace(QObject, ObservableProperties):
+
+ workflow_id = Property("", setter="_set_workflow_id")
+ params = Property({}, persist=True)
+ mode = Property(CustomGenerationMode.regular, setter="_set_mode")
+ is_live = Property(False, setter="toggle_live")
+ has_result = Property(False)
+
+ workflow_id_changed = pyqtSignal(str)
+ graph_changed = pyqtSignal()
+ params_changed = pyqtSignal(dict)
+ mode_changed = pyqtSignal(CustomGenerationMode)
+ is_live_changed = pyqtSignal(bool)
+ result_available = pyqtSignal(Image)
+ has_result_changed = pyqtSignal(bool)
+ modified = pyqtSignal(QObject, str)
+
+ _live_poll_rate = 0.1
+
+ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, jobs: JobQueue):
+ super().__init__()
+ self._workflows = workflows
+ self._generator = generator
+ self._workflow: CustomWorkflow | None = None
+ self._graph: ComfyWorkflow | None = None
+ self._metadata: list[CustomParam] = []
+ self._last_input: WorkflowInput | None = None
+ self._last_result: Image | None = None
+ self._last_job: JobParams | None = None
+
+ jobs.job_finished.connect(self._handle_job_finished)
+ workflows.dataChanged.connect(self._update_workflow)
+ workflows.loaded.connect(self._set_default_workflow)
+ self._set_default_workflow()
+
+ def _set_default_workflow(self):
+ if not self.workflow_id and len(self._workflows) > 0:
+ self.workflow_id = self._workflows[0].id
+ else:
+ current_index = self._workflows.find_index(self.workflow_id)
+ if current_index.isValid():
+ self._update_workflow(current_index, QModelIndex())
+
+ def _update_workflow(self, idx: QModelIndex, _: QModelIndex):
+ wf = self._workflows[idx.row()]
+ if wf.id == self._workflow_id:
+ self._workflow = wf
+ self._graph = self._workflow.workflow
+ self._metadata = list(workflow_parameters(self._graph))
+ self.params = _coerce(self.params, self._metadata)
+ self.graph_changed.emit()
+
+ def _set_workflow_id(self, id: str):
+ if self._workflow_id == id:
+ return
+ self._workflow_id = id
+ self.workflow_id_changed.emit(id)
+ self.modified.emit(self, "workflow_id")
+ index = self._workflows.find_index(id)
+ if index.isValid(): # might be invalid when loading document before connecting
+ self._update_workflow(index, QModelIndex())
+
+ def set_graph(self, id: str, graph: dict):
+ if self._workflows.find(id) is None:
+ id = "Document Workflow (embedded)"
+ self._workflows.add_from_document(id, graph)
+ self.workflow_id = id
+
+ def import_file(self, filepath: Path):
+ self.workflow_id = self._workflows.import_file(filepath)
+
+ def save_as(self, id: str):
+ assert self._graph, "Save as: no workflow selected"
+ self.workflow_id = self._workflows.save_as(id, self._graph.root)
+
+ def remove_workflow(self):
+ if id := self.workflow_id:
+ self._workflow_id = ""
+ self._workflow = None
+ self._graph = None
+ self._metadata = []
+ self._workflows.remove(id)
+
+ def generate(self):
+ eventloop.run(self._generator(None))
+
+ def toggle_live(self, active: bool):
+ if self._is_live != active:
+ self._is_live = active
+ self.is_live_changed.emit(active)
+ if active:
+ eventloop.run(self._continue_generating())
+
+ def _set_mode(self, value: CustomGenerationMode):
+ if self._mode != value:
+ self._mode = value
+ self.mode_changed.emit(value)
+ self.is_live = False
+
+ @property
+ def workflow(self):
+ return self._workflow
+
+ @property
+ def graph(self):
+ return self._graph
+
+ @property
+ def metadata(self):
+ return self._metadata
+
+ @property
+ def job_name(self):
+ for param in self.metadata:
+ if param.kind is ParamKind.prompt_positive:
+ return str(self.params[param.name])
+ return self.workflow_id or "Custom Workflow"
+
+ def collect_parameters(self, layers: "LayerManager", bounds: Bounds):
+ params = copy(self.params)
+ for md in self.metadata:
+ param = params.get(md.name)
+ assert param is not None, f"Parameter {md.name} not found"
+
+ if md.kind is ParamKind.image_layer:
+ layer = layers.find(QUuid(param))
+ if layer is None:
+ raise ValueError(f"Input layer for parameter {md.name} not found")
+ params[md.name] = layer.get_pixels(bounds)
+ elif md.kind is ParamKind.mask_layer:
+ layer = layers.find(QUuid(param))
+ if layer is None:
+ raise ValueError(f"Input layer for parameter {md.name} not found")
+ params[md.name] = layer.get_mask(bounds)
+ elif md.kind is ParamKind.style:
+ style = Styles.list().find(str(param))
+ if style is None:
+ raise ValueError(f"Style {param} not found")
+ params[md.name] = style
+ return params
+
+ def _handle_job_finished(self, job: Job):
+ if job.kind is JobKind.live_preview:
+ if len(job.results) > 0:
+ self._last_result = job.results[0]
+ self._last_job = job.params
+ self.result_available.emit(self._last_result)
+ self.has_result = True
+ eventloop.run(self._continue_generating())
+
+ async def _continue_generating(self):
+ while self.is_live:
+ new_input = await self._generator(self._last_input)
+ if new_input is False: # abort live generation
+ self.is_live = False
+ return
+ elif new_input is None: # no changes in input data
+ await asyncio.sleep(self._live_poll_rate)
+ else: # frame was scheduled
+ self._last_input = new_input
+ return
+
+ @property
+ def live_result(self):
+ assert self._last_result and self._last_job, "No live result available"
+ return self._last_result, self._last_job
+
+
+def _coerce(params: dict[str, Any], types: list[CustomParam]):
+ def use(value, default):
+ if value is None or not type(value) == type(default):
+ return default
+ return value
+
+ return {t.name: use(params.get(t.name), t.default) for t in types}
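# Sketch of parameter discovery: an ETN_Parameter node in the graph becomes a
# CustomParam entry. The graph below is a hypothetical minimal example.
from ai_diffusion.comfy_workflow import ComfyWorkflow
from ai_diffusion.custom_workflow import workflow_parameters

graph = {
    "1": {
        "class_type": "ETN_Parameter",
        "inputs": {"name": "Strength", "type": "number", "default": 0.5, "min": 0.0, "max": 1.0},
    }
}
w = ComfyWorkflow.from_dict(graph)
params = list(workflow_parameters(w))
# -> [CustomParam(kind=ParamKind.number_float, name='Strength', default=0.5, min=0.0, max=1.0)]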
diff --git a/ai_diffusion/icons/comfyui-dark.svg b/ai_diffusion/icons/comfyui-dark.svg
new file mode 100644
index 0000000000..c8c916dc29
--- /dev/null
+++ b/ai_diffusion/icons/comfyui-dark.svg
@@ -0,0 +1,96 @@
+
+
diff --git a/ai_diffusion/icons/comfyui-light.svg b/ai_diffusion/icons/comfyui-light.svg
new file mode 100644
index 0000000000..03d88d1696
--- /dev/null
+++ b/ai_diffusion/icons/comfyui-light.svg
@@ -0,0 +1,96 @@
+
+
diff --git a/ai_diffusion/icons/file-json-dark.svg b/ai_diffusion/icons/file-json-dark.svg
new file mode 100644
index 0000000000..20143f5e32
--- /dev/null
+++ b/ai_diffusion/icons/file-json-dark.svg
@@ -0,0 +1,96 @@
+
+
diff --git a/ai_diffusion/icons/file-json-light.svg b/ai_diffusion/icons/file-json-light.svg
new file mode 100644
index 0000000000..7d726a8249
--- /dev/null
+++ b/ai_diffusion/icons/file-json-light.svg
@@ -0,0 +1,96 @@
+
+
diff --git a/ai_diffusion/icons/file-kra-dark.svg b/ai_diffusion/icons/file-kra-dark.svg
new file mode 100644
index 0000000000..f2fca41442
--- /dev/null
+++ b/ai_diffusion/icons/file-kra-dark.svg
@@ -0,0 +1,92 @@
+
+
diff --git a/ai_diffusion/icons/file-kra-light.svg b/ai_diffusion/icons/file-kra-light.svg
new file mode 100644
index 0000000000..99ff41d617
--- /dev/null
+++ b/ai_diffusion/icons/file-kra-light.svg
@@ -0,0 +1,92 @@
+
+
diff --git a/ai_diffusion/icons/import-dark.svg b/ai_diffusion/icons/import-dark.svg
new file mode 100644
index 0000000000..5dd0ce7eac
--- /dev/null
+++ b/ai_diffusion/icons/import-dark.svg
@@ -0,0 +1,66 @@
+
+
diff --git a/ai_diffusion/icons/import-light.svg b/ai_diffusion/icons/import-light.svg
new file mode 100644
index 0000000000..4a19b201df
--- /dev/null
+++ b/ai_diffusion/icons/import-light.svg
@@ -0,0 +1,66 @@
+
+
diff --git a/ai_diffusion/icons/save-dark.svg b/ai_diffusion/icons/save-dark.svg
new file mode 100644
index 0000000000..11db345bce
--- /dev/null
+++ b/ai_diffusion/icons/save-dark.svg
@@ -0,0 +1,90 @@
+
+
diff --git a/ai_diffusion/icons/save-light.svg b/ai_diffusion/icons/save-light.svg
new file mode 100644
index 0000000000..b8bc1e32a6
--- /dev/null
+++ b/ai_diffusion/icons/save-light.svg
@@ -0,0 +1,90 @@
+
+
diff --git a/ai_diffusion/icons/web-connection-dark.svg b/ai_diffusion/icons/web-connection-dark.svg
new file mode 100644
index 0000000000..dcd351f355
--- /dev/null
+++ b/ai_diffusion/icons/web-connection-dark.svg
@@ -0,0 +1,87 @@
+
+
diff --git a/ai_diffusion/icons/web-connection-light.svg b/ai_diffusion/icons/web-connection-light.svg
new file mode 100644
index 0000000000..fa1338e6e5
--- /dev/null
+++ b/ai_diffusion/icons/web-connection-light.svg
@@ -0,0 +1,87 @@
+
+
diff --git a/ai_diffusion/icons/workspace-custom-dark.svg b/ai_diffusion/icons/workspace-custom-dark.svg
new file mode 100644
index 0000000000..4eb2d3526f
--- /dev/null
+++ b/ai_diffusion/icons/workspace-custom-dark.svg
@@ -0,0 +1,145 @@
+
+
diff --git a/ai_diffusion/icons/workspace-custom-light.svg b/ai_diffusion/icons/workspace-custom-light.svg
new file mode 100644
index 0000000000..e56c8e14c4
--- /dev/null
+++ b/ai_diffusion/icons/workspace-custom-light.svg
@@ -0,0 +1,145 @@
+
+
diff --git a/ai_diffusion/image.py b/ai_diffusion/image.py
index fafe527fe7..d06768bc2b 100644
--- a/ai_diffusion/image.py
+++ b/ai_diffusion/image.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from enum import Enum
from math import ceil, sqrt
-from PyQt5.QtGui import QImage, QImageWriter, QPixmap, QIcon, QPainter, QColorSpace
+from PyQt5.QtGui import QImage, QImageWriter, QImageReader, QPixmap, QIcon, QPainter, QColorSpace
from PyQt5.QtGui import qRgba, qRed, qGreen, qBlue, qAlpha, qGray
from PyQt5.QtCore import Qt, QByteArray, QBuffer, QRect, QSize, QFile, QIODevice
from typing import Callable, Iterable, SupportsIndex, Tuple, NamedTuple, Union, Optional
@@ -642,10 +642,11 @@ def from_bytes(data: QByteArray | bytes, offsets: list[int]):
for i, offset in enumerate(offsets):
buffer.seek(offset)
img = QImage()
- if img.load(buffer, None):
+ loader = QImageReader(buffer)
+ if loader.read(img):
images.append(Image(img))
else:
- raise Exception(f"Failed to load image {i} from buffer")
+ raise Exception(f"Failed to load image {i} from buffer: {loader.errorString()}")
buffer.close()
return images
@@ -690,6 +691,10 @@ def __init__(self, bounds: Bounds, data: Union[QImage, QByteArray]):
)
assert not self.image.isNull()
+ @staticmethod
+ def transparent(bounds: Bounds):
+ return Mask(bounds, QByteArray(bytes(bounds.width * bounds.height)))
+
@staticmethod
def rectangle(bounds: Bounds, feather=0):
# Note: for testing only, where Krita selection is not available
diff --git a/ai_diffusion/jobs.py b/ai_diffusion/jobs.py
index 962145be1b..b2f16cbf2d 100644
--- a/ai_diffusion/jobs.py
+++ b/ai_diffusion/jobs.py
@@ -3,14 +3,16 @@
from dataclasses import dataclass, fields, field
from datetime import datetime
from enum import Enum, Flag
-from typing import Any, Deque, NamedTuple
+from typing import Any, Deque, NamedTuple, TYPE_CHECKING
from PyQt5.QtCore import QObject, pyqtSignal
from .image import Bounds, ImageCollection
from .settings import settings
from .style import Style
from .util import ensure
-from . import control
+
+if TYPE_CHECKING:
+ from . import control
class JobState(Flag):
@@ -45,14 +47,10 @@ def from_dict(data: dict[str, Any]):
@dataclass
class JobParams:
bounds: Bounds
- prompt: str
- negative_prompt: str = ""
+    name: str  # used e.g. as the name for new layers created from this job
regions: list[JobRegion] = field(default_factory=list)
- strength: float = 1.0
+ metadata: dict[str, Any] = field(default_factory=dict)
seed: int = 0
- style: str = ""
- checkpoint: str = ""
- sampler: str = ""
has_mask: bool = False
frame: tuple[int, int, int] = (0, 0, 0)
animation_id: str = ""
@@ -71,9 +69,22 @@ def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None):
return all(getattr(a, name) == getattr(b, name) for name in field_names)
def set_style(self, style: Style):
- self.style = style.filename
- self.checkpoint = style.sd_checkpoint
- self.sampler = f"{style.sampler} ({style.sampler_steps} / {style.cfg_scale})"
+ self.metadata["style"] = style.filename
+ self.metadata["checkpoint"] = style.sd_checkpoint
+ self.metadata["loras"] = style.loras
+ self.metadata["sampler"] = f"{style.sampler} ({style.sampler_steps} / {style.cfg_scale})"
+
+ @property
+ def prompt(self):
+ return self.metadata.get("prompt", "")
+
+ @property
+ def style(self):
+ return self.metadata.get("style", "")
+
+ @property
+ def strength(self):
+ return self.metadata.get("strength", 1.0)
class Job:
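# Sketch of the slimmed-down JobParams: prompt, style and strength now live in
# the metadata dict and are exposed via read-only properties. Values here are
# hypothetical.
from ai_diffusion.image import Bounds
from ai_diffusion.jobs import JobParams

params = JobParams(Bounds(0, 0, 512, 512), "a castle on a hill")
params.metadata["prompt"] = "a castle on a hill"
params.metadata["strength"] = 0.7
print(params.prompt, params.strength)  # reads back through the new properties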
diff --git a/ai_diffusion/layer.py b/ai_diffusion/layer.py
index d670e8b4bb..a01771825d 100644
--- a/ai_diffusion/layer.py
+++ b/ai_diffusion/layer.py
@@ -528,19 +528,19 @@ def update_layer_image(self, layer: Layer, image: Image, bounds: Bounds, keep_al
_mask_types = [t.value for t in LayerType if t.is_mask]
@property
- def all(self):
+ def all(self) -> list[Layer]:
if self._doc is None:
return []
return [self.wrap(n) for n in traverse_layers(self._doc.rootNode())]
@property
- def images(self):
+ def images(self) -> list[Layer]:
if self._doc is None:
return []
return [self.wrap(n) for n in traverse_layers(self._doc.rootNode(), self._image_types)]
@property
- def masks(self):
+ def masks(self) -> list[Layer]:
if self._doc is None:
return []
return [self.wrap(n) for n in traverse_layers(self._doc.rootNode(), self._mask_types)]
diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py
index 27f85781ab..941e6aed47 100644
--- a/ai_diffusion/model.py
+++ b/ai_diffusion/model.py
@@ -10,14 +10,16 @@
import uuid
from . import eventloop, workflow, util
-from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput
-from .api import InpaintMode, InpaintParams, FillMode
+from .api import ConditioningInput, ControlInput, WorkflowKind, WorkflowInput, SamplingInput
+from .api import InpaintMode, InpaintParams, FillMode, ImageInput, CustomWorkflowInput
from .localization import translate as _
from .util import clamp, ensure, trim_text, client_logger as log
from .settings import ApplyBehavior, settings
from .network import NetworkError
from .image import Extent, Image, Mask, Bounds, DummyImage
-from .client import ClientMessage, ClientEvent, filter_supported_styles, resolve_arch
+from .client import ClientMessage, ClientEvent, SharedWorkflow
+from .client import filter_supported_styles, resolve_arch
+from .custom_workflow import CustomWorkspace, WorkflowCollection, CustomGenerationMode
from .document import Document, KritaDocument
from .layer import Layer, LayerType, RestoreActiveLayer
from .pose import Pose
@@ -37,6 +39,7 @@ class Workspace(Enum):
upscaling = 1
live = 2
animation = 3
+ custom = 4
class ProgressKind(Enum):
@@ -50,10 +53,6 @@ class Model(QObject, ObservableProperties):
list of finished, currently running and enqueued jobs.
"""
- _doc: Document
- _connection: Connection
- _layer: Layer | None = None
-
workspace = Property(Workspace.generation, setter="set_workspace", persist=True)
regions: "RootRegion"
style = Property(Styles.list().default, setter="set_style", persist=True)
@@ -64,13 +63,8 @@ class Model(QObject, ObservableProperties):
fixed_seed = Property(False, persist=True)
queue_front = Property(False, persist=True)
translation_enabled = Property(True, persist=True)
- inpaint: CustomInpaint
- upscale: "UpscaleWorkspace"
- live: "LiveWorkspace"
- animation: "AnimationWorkspace"
progress_kind = Property(ProgressKind.generation)
progress = Property(0.0)
- jobs: JobQueue
error = Property("")
workspace_changed = pyqtSignal(Workspace)
@@ -88,10 +82,11 @@ class Model(QObject, ObservableProperties):
has_error_changed = pyqtSignal(bool)
modified = pyqtSignal(QObject, str)
- def __init__(self, document: Document, connection: Connection):
+ def __init__(self, document: Document, connection: Connection, workflows: WorkflowCollection):
super().__init__()
self._doc = document
self._connection = connection
+ self._layer: Layer | None = None
self.generate_seed()
self.jobs = JobQueue()
self.regions = RootRegion(self)
@@ -99,6 +94,7 @@ def __init__(self, document: Document, connection: Connection):
self.upscale = UpscaleWorkspace(self)
self.live = LiveWorkspace(self)
self.animation = AnimationWorkspace(self)
+ self.custom = CustomWorkspace(workflows, self._generate_custom, self.jobs)
self.jobs.selection_changed.connect(self.update_preview)
self.error_changed.connect(lambda: self.has_error_changed.emit(self.has_error))
@@ -198,23 +194,23 @@ def _prepare_workflow(self, dryrun=False):
)
job_params = JobParams(bounds, prompt, regions=job_regions)
job_params.set_style(self.style)
+ job_params.metadata["prompt"] = prompt
+ job_params.metadata["negative_prompt"] = self.regions.negative
+ job_params.metadata["strength"] = self.strength
+ if len(job_regions) == 1:
+ job_params.metadata["prompt"] = job_params.name = job_regions[0].prompt
return input, job_params
async def enqueue_jobs(
self, input: WorkflowInput, kind: JobKind, params: JobParams, count: int = 1
):
sampling = ensure(input.sampling)
- params.negative_prompt = self.regions.negative
- params.strength = sampling.denoise_strength
params.has_mask = input.images is not None and input.images.hires_mask is not None
- if len(params.regions) == 1:
- params.prompt = params.regions[0].prompt
for i in range(count):
- input = replace(
- input, sampling=replace(sampling, seed=sampling.seed + i * settings.batch_size)
- )
- params.seed = ensure(input.sampling).seed
+ next_seed = sampling.seed + i * settings.batch_size
+ input = replace(input, sampling=replace(sampling, seed=next_seed))
+ params.seed = next_seed
job = self.jobs.add(kind, copy(params))
await self._enqueue_job(job, input)
@@ -353,6 +349,46 @@ async def _generate_live(self, last_input: WorkflowInput | None = None):
return None
+ async def _generate_custom(self, previous_input: WorkflowInput | None):
+ if self.workspace is not Workspace.custom or not self.document.is_active:
+ return False
+
+ try:
+ wf = ensure(self.custom.graph)
+ bounds = Bounds(0, 0, *self._doc.extent)
+ img_input = ImageInput.from_extent(bounds.extent)
+ img_input.initial_image = self._get_current_image(bounds)
+ is_live = self.custom.mode is CustomGenerationMode.live
+ seed = self.seed if is_live or self.fixed_seed else workflow.generate_seed()
+
+ if next(wf.find(type="ETN_KritaSelection"), None):
+ mask, _ = self._doc.create_mask_from_selection()
+ if mask:
+ img_input.hires_mask = mask.to_image(bounds.extent)
+ else:
+ img_input.hires_mask = Mask.transparent(bounds).to_image()
+
+ params = self.custom.collect_parameters(self.layers, bounds)
+ input = WorkflowInput(
+ WorkflowKind.custom,
+ img_input,
+ sampling=SamplingInput("custom", "custom", 1, 1000, seed=seed),
+ custom_workflow=CustomWorkflowInput(wf.root, params),
+ )
+ job_params = JobParams(bounds, self.custom.job_name, metadata=self.custom.params)
+ job_kind = JobKind.live_preview if is_live else JobKind.diffusion
+
+ if input == previous_input:
+ return None
+
+ self.clear_error()
+ await self.enqueue_jobs(input, job_kind, job_params, self.batch_count)
+ return input
+
+ except Exception as e:
+ self.report_error(util.log_error(e))
+ return False
+
def _get_current_image(self, bounds: Bounds):
exclude = None
if self.workspace is not Workspace.live:
@@ -450,7 +486,7 @@ def update_preview(self):
def show_preview(self, job_id: str, index: int, name_prefix="Preview"):
job = self.jobs.find(job_id)
assert job is not None, "Cannot show preview, invalid job id"
- name = f"[{name_prefix}] {trim_text(job.params.prompt, 77)}"
+ name = f"[{name_prefix}] {trim_text(job.params.name, 77)}"
if self._layer and self._layer.was_removed:
self._layer = None # layer was removed by user
if self._layer is not None:
@@ -472,7 +508,7 @@ def apply_result(self, image: Image, params: JobParams, behavior: ApplyBehavior,
if behavior is ApplyBehavior.replace:
self.layers.update_layer_image(self.layers.active, image, params.bounds)
else:
- name = f"{prefix}{trim_text(params.prompt, 200)} ({params.seed})"
+ name = f"{prefix}{trim_text(params.name, 200)} ({params.seed})"
self.layers.create(name, image, params.bounds)
else: # apply to regions
with RestoreActiveLayer(self.layers) as restore:
@@ -492,7 +528,7 @@ def create_result_layer(
):
name = f"{prefix}{job_region.prompt} ({params.seed})"
region_layer = self.layers.find(QUuid(job_region.layer_id)) or self.layers.root
- # a previous apply from the same batch my have already created groups and re-linked
+ # a previous apply from the same batch may have already created groups and re-linked
region_layer = Region.link_target(region_layer)
# Replace content if requested and not a group layer
@@ -559,14 +595,14 @@ def apply_generated_result(self, job_id: str, index: int):
self.jobs.selection = None
self.jobs.notify_used(job_id, index)
- def add_control_layer(self, job: Job, result: dict | None):
+ def add_control_layer(self, job: Job, result: dict | SharedWorkflow | None):
assert job.kind is JobKind.control_layer and job.control
- if job.control.mode is ControlMode.pose and result is not None:
+ if job.control.mode is ControlMode.pose and isinstance(result, dict):
pose = Pose.from_open_pose_json(result)
pose.scale(job.params.bounds.extent)
- return self.layers.create_vector(job.params.prompt, pose.to_svg())
+ return self.layers.create_vector(job.params.name, pose.to_svg())
elif len(job.results) > 0:
- return self.layers.create(job.params.prompt, job.results[0], job.params.bounds)
+ return self.layers.create(job.params.name, job.results[0], job.params.bounds)
return self.layers.active # Execution was cached and no image was produced
def add_upscale_layer(self, job: Job):
@@ -1037,7 +1073,7 @@ def _import_animation(self, job: Job):
keyframes = self._keyframes.pop(job.params.animation_id)
_, start, end = job.params.frame
doc.import_animation(keyframes, start)
- eventloop.run(self._update_layer_name(f"[Generated] {start}-{end}: {job.params.prompt}"))
+ eventloop.run(self._update_layer_name(f"[Generated] {start}-{end}: {job.params.name}"))
async def _update_layer_name(self, name: str):
doc = self._model.document
@@ -1102,7 +1138,7 @@ def _save_job_result(model: Model, job: Job | None, index: int):
assert len(job.results) > index, "Cannot save result, invalid result index"
assert model.document.filename, "Cannot save result, document is not saved"
timestamp = job.timestamp.strftime("%Y%m%d-%H%M%S")
- prompt = util.sanitize_prompt(job.params.prompt)
+ prompt = util.sanitize_prompt(job.params.name)
path = Path(model.document.filename)
path = path.parent / f"{path.stem}-generated-{timestamp}-{index}-{prompt}.png"
path = util.find_unused_path(path)
diff --git a/ai_diffusion/persistence.py b/ai_diffusion/persistence.py
index ad78d02076..348e62cfbc 100644
--- a/ai_diffusion/persistence.py
+++ b/ai_diffusion/persistence.py
@@ -8,8 +8,9 @@
from PyQt5.QtWidgets import QMessageBox
from .api import InpaintMode, FillMode
-from .image import Bounds, Image, ImageCollection, ImageFileFormat
+from .image import ImageCollection
from .model import Model, InpaintContext
+from .custom_workflow import CustomWorkspace
from .control import ControlLayer, ControlLayerList
from .region import RootRegion, Region
from .jobs import Job, JobKind, JobParams, JobQueue
@@ -132,6 +133,7 @@ def _save(self):
state["upscale"] = _serialize(model.upscale)
state["live"] = _serialize(model.live)
state["animation"] = _serialize(model.animation)
+ state["custom"] = _serialize_custom(model.custom)
state["history"] = [asdict(h) for h in self._history]
state["root"] = _serialize(model.regions)
state["control"] = [_serialize(c) for c in model.regions.control]
@@ -150,6 +152,7 @@ def _load(self, model: Model, state_bytes: bytes):
_deserialize(model.upscale, state.get("upscale", {}))
_deserialize(model.live, state.get("live", {}))
_deserialize(model.animation, state.get("animation", {}))
+ _deserialize_custom(model.custom, state.get("custom", {}))
_deserialize(model.regions, state.get("root", {}))
for control_state in state.get("control", []):
_deserialize(model.regions.control.emplace(), control_state)
@@ -176,6 +179,7 @@ def _track(self, model: Model):
model.upscale.modified.connect(self._save)
model.live.modified.connect(self._save)
model.animation.modified.connect(self._save)
+ model.custom.modified.connect(self._save)
model.jobs.job_finished.connect(self._save_results)
model.jobs.job_discarded.connect(self._remove_results)
model.jobs.result_discarded.connect(self._remove_image)
@@ -261,6 +265,21 @@ def converter(type, value):
return deserialize(obj, data, converter)
+def _serialize_custom(custom: CustomWorkspace):
+ result = _serialize(custom)
+ result["workflow_id"] = custom.workflow_id
+ result["graph"] = custom.graph.root if custom.graph else None
+ return result
+
+
+def _deserialize_custom(custom: CustomWorkspace, data: dict[str, Any]):
+ _deserialize(custom, data)
+ workflow_id = data.get("workflow_id", "")
+ graph = data.get("graph", None)
+ if workflow_id and graph:
+ custom.set_graph(workflow_id, graph)
+
+
def _find_annotation(document, name: str):
if result := document.find_annotation(name):
return result
diff --git a/ai_diffusion/properties.py b/ai_diffusion/properties.py
index a4cd94086e..00902e9e55 100644
--- a/ai_diffusion/properties.py
+++ b/ai_diffusion/properties.py
@@ -1,3 +1,4 @@
+from copy import copy
from enum import Enum
from typing import Any, NamedTuple, Sequence, TypeVar, Generic
@@ -18,7 +19,7 @@ def __init_subclass__(cls, **kwargs):
name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Property)
}
for name, property in properties.items():
- setattr(cls, f"_{name}", property.default_value)
+ setattr(cls, f"_{name}", _copy_reference_types(property.default_value))
getter, setter = None, None
if property.getter is not None:
getter = getattr(cls, property.getter)
@@ -198,3 +199,9 @@ def deserialize(obj: QObject, data: dict[str, Any], converter=_default_deseriali
if not isinstance(value, type(current)):
raise TypeError(f"{name} was '{value}', but expected {type(current)}")
setattr(obj, name, value)
+
+
+def _copy_reference_types(object):
+ if isinstance(object, (list, dict)):
+ return copy(object)
+ return object
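# Why defaults are copied: Property default values become class attributes, so a
# shared list/dict default would be mutated through every instance. The hypothetical
# snippet below shows the aliasing problem the copy avoids.
class SharedDefault:
    params = {}  # one dict object shared by all instances

a, b = SharedDefault(), SharedDefault()
a.params["seed"] = 4
assert b.params == {"seed": 4}  # b sees a's mutation, hence the per-subclass copy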
diff --git a/ai_diffusion/resources.py b/ai_diffusion/resources.py
index 624198b3bb..84c1dcf286 100644
--- a/ai_diffusion/resources.py
+++ b/ai_diffusion/resources.py
@@ -6,10 +6,10 @@
# Version identifier for all the resources defined here. This is used as the server version.
# It usually follows the plugin version, but not all new plugin versions also require a server update.
-version = "1.25.0"
+version = "1.26.0"
comfy_url = "https://github.com/comfyanonymous/ComfyUI"
-comfy_version = "dc96a1ae19b1d714a791f1fcb21578389955bbfd"
+comfy_version = "1b8089528502a881d0ed2918b2abd54441743dd0"
class CustomNode(NamedTuple):
@@ -39,7 +39,7 @@ class CustomNode(NamedTuple):
"External Tooling Nodes",
"comfyui-tooling-nodes",
"https://github.com/Acly/comfyui-tooling-nodes",
- "1a24975f99b95fa9a17a917de05fa248d8dc9569",
+ "9a9cbe78a5851a49da0b38bc9b17504837476e8f",
["ETN_LoadImageBase64", "ETN_LoadMaskBase64", "ETN_SendImageWebSocket", "ETN_Translate"],
),
CustomNode(
diff --git a/ai_diffusion/root.py b/ai_diffusion/root.py
index ddecaf2e0d..f001ae48ea 100644
--- a/ai_diffusion/root.py
+++ b/ai_diffusion/root.py
@@ -4,10 +4,11 @@
from .connection import Connection, ConnectionState
from .client import ClientMessage
+from .custom_workflow import WorkflowCollection
from .server import Server, ServerState
from .document import Document, KritaDocument
from .model import Model
-from .files import FileLibrary, File, FileSource
+from .files import FileFormat, FileLibrary, File, FileSource
from .persistence import ModelSync, RecentlyUsedSync, import_prompt_from_file
from .updates import AutoUpdate
from .ui.theme import checkpoint_icon
@@ -37,7 +38,9 @@ def init(self):
self._server = Server(settings.server_path)
self._connection = Connection()
self._files = FileLibrary.load()
+ self._workflows = WorkflowCollection(self._connection)
self._models = []
+ self._null_model = Model(Document(), self._connection, self._workflows)
self._recent = RecentlyUsedSync.from_settings()
self._auto_update = AutoUpdate()
if settings.auto_update:
@@ -50,7 +53,7 @@ def prune_models(self):
self._models = [m for m in self._models if m.model.document.is_valid]
def create_model(self, doc: KritaDocument):
- model = Model(doc, self._connection)
+ model = Model(doc, self._connection, self._workflows)
self._recent.track(model)
persistence_sync = ModelSync(model)
import_prompt_from_file(model)
@@ -82,6 +85,10 @@ def server(self):
def files(self) -> FileLibrary:
return self._files
+ @property
+ def workflows(self) -> WorkflowCollection:
+ return self._workflows
+
@property
def auto_update(self) -> AutoUpdate:
return self._auto_update
@@ -90,7 +97,7 @@ def auto_update(self) -> AutoUpdate:
def active_model(self):
if model := self.model_for_active_document():
return model
- return Model(Document(), self._connection)
+ return self._null_model
async def autostart(self, signal_server_change: Callable):
connection = self._connection
@@ -148,7 +155,7 @@ def _update_files(self):
]
self._files.checkpoints.update(checkpoints, FileSource.remote)
- loras = [File.remote(lora) for lora in client.models.loras]
+ loras = [File.remote(lora, FileFormat.lora) for lora in client.models.loras]
self._files.loras.update(loras, FileSource.remote)
diff --git a/ai_diffusion/ui/control.py b/ai_diffusion/ui/control.py
index 1e4af8a463..d86c0d3e63 100644
--- a/ai_diffusion/ui/control.py
+++ b/ai_diffusion/ui/control.py
@@ -16,16 +16,14 @@
class ControlWidget(QWidget):
- _control_list: ControlLayerList
- _control: ControlLayer
- _connections: list[QMetaObject.Connection | Binding]
def __init__(
- self, control_list: ControlLayerList, control: ControlLayer, parent: ControlListWidget
+ self, control_list: ControlLayerList | None, control: ControlLayer, parent: QWidget
):
super().__init__(parent)
self._control_list = control_list
self._control = control
+ self._connections: list[QMetaObject.Connection | Binding] = []
layout = QVBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
@@ -80,13 +78,6 @@ def __init__(
self.expand_button.setChecked(False)
self.expand_button.clicked.connect(self._toggle_extended)
- self.remove_button = QToolButton(self)
- self.remove_button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
- self.remove_button.setIcon(theme.icon("remove"))
- self.remove_button.setToolTip(_("Remove control layer"))
- self.remove_button.setAutoRaise(True)
- self.remove_button.clicked.connect(self.remove)
-
bar_layout = QHBoxLayout()
bar_layout.addWidget(self.mode_select)
bar_layout.addWidget(self.layer_select, 3)
@@ -95,9 +86,17 @@ def __init__(
bar_layout.addWidget(self.preset_slider, 1)
bar_layout.addWidget(self.error_text, 3)
bar_layout.addWidget(self.expand_button)
- bar_layout.addWidget(self.remove_button)
layout.addLayout(bar_layout)
+ if self._control_list is not None:
+ self.remove_button = QToolButton(self)
+ self.remove_button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
+ self.remove_button.setIcon(theme.icon("remove"))
+ self.remove_button.setToolTip(_("Remove control layer"))
+ self.remove_button.setAutoRaise(True)
+ self.remove_button.clicked.connect(self.remove)
+ bar_layout.addWidget(self.remove_button)
+
line = QFrame(self)
line.setObjectName("LeftIndent")
line.setStyleSheet(f"#LeftIndent {{ color: {theme.line}; }}")
@@ -198,12 +197,13 @@ def _update_layers(self):
self.layer_select.addItem(layer.name, layer.id)
if layer.id == self._control.layer_id:
index = self.layer_select.count() - 1
- if index == -1 and self._control in self._control_list:
+ if index == -1 and self._control_list and self._control in self._control_list:
self.remove()
- else:
+ elif index >= 0:
self.layer_select.setCurrentIndex(index)
def remove(self):
+ assert self._control_list is not None
self._control_list.remove(self._control)
def resizeEvent(self, a0: QResizeEvent | None):
diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py
new file mode 100644
index 0000000000..85def2632b
--- /dev/null
+++ b/ai_diffusion/ui/custom_workflow.py
@@ -0,0 +1,706 @@
+from functools import wraps
+from pathlib import Path
+from typing import Any, Callable
+
+from PyQt5.QtCore import Qt, pyqtSignal, QMetaObject, QUuid, QUrl, QPoint
+from PyQt5.QtGui import QFontMetrics, QIcon, QDesktopServices
+from PyQt5.QtWidgets import QComboBox, QFileDialog, QFrame, QGridLayout, QHBoxLayout, QMenu
+from PyQt5.QtWidgets import QLabel, QLineEdit, QListWidgetItem, QMessageBox, QSpinBox, QAction
+from PyQt5.QtWidgets import QToolButton, QVBoxLayout, QWidget, QSlider, QDoubleSpinBox
+
+from ..custom_workflow import CustomParam, ParamKind, SortedWorkflows, WorkflowSource
+from ..custom_workflow import CustomGenerationMode
+from ..jobs import JobKind
+from ..model import Model, ApplyBehavior
+from ..properties import Binding, Bind, bind, bind_combo
+from ..style import Styles
+from ..root import root
+from ..settings import settings
+from ..localization import translate as _
+from ..util import ensure, clamp
+from .generation import GenerateButton, ProgressBar, QueueButton, HistoryWidget, create_error_label
+from .live import LivePreviewArea
+from .switch import SwitchWidget
+from .widget import TextPromptWidget, WorkspaceSelectWidget, StyleSelectWidget
+from . import theme
+
+
+class LayerSelect(QComboBox):
+ value_changed = pyqtSignal()
+
+ def __init__(self, filter: str | None = None, parent: QWidget | None = None):
+ super().__init__(parent)
+ self.setContentsMargins(0, 0, 0, 0)
+ self.filter = filter
+ self.currentIndexChanged.connect(lambda _: self.value_changed.emit())
+
+ self._update()
+ root.active_model.layers.changed.connect(self._update)
+
+ def _update(self):
+ if self.filter is None:
+ layers = root.active_model.layers.all
+ elif self.filter == "image":
+ layers = root.active_model.layers.images
+ elif self.filter == "mask":
+ layers = root.active_model.layers.masks
+ else:
+ assert False, f"Unknown filter: {self.filter}"
+
+ for l in layers:
+ if self.findData(l.id) == -1:
+ self.addItem(l.name, l.id)
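+        # Drop combo entries whose layer was removed from the document.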
+ i = 0
+ while i < self.count():
+ if self.itemData(i) not in (l.id for l in layers):
+ self.removeItem(i)
+ else:
+ i += 1
+
+ @property
+ def value(self) -> str:
+ if self.currentIndex() == -1:
+ return ""
+ return self.currentData().toString()
+
+ @value.setter
+ def value(self, value: str):
+ i = self.findData(QUuid(value))
+ if i != -1 and i != self.currentIndex():
+ self.setCurrentIndex(i)
+
+
+class IntParamWidget(QWidget):
+ value_changed = pyqtSignal()
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ layout = QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+
+ assert param.min is not None and param.max is not None and param.default is not None
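+        # Small integer ranges get a slider; wide ranges fall back to a spin box for precise input.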
+ if param.max - param.min <= 200:
+ self._widget = QSlider(Qt.Orientation.Horizontal, parent)
+ self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
+ self._widget.valueChanged.connect(self._notify)
+ self._label = QLabel(self)
+ self._label.setFixedWidth(32)
+ self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
+ layout.addWidget(self._widget)
+ layout.addWidget(self._label)
+ else:
+ self._widget = QSpinBox(parent)
+ self._widget.valueChanged.connect(self._notify)
+ self._label = None
+ layout.addWidget(self._widget)
+
+ min_range = clamp(int(param.min), -(2**31), 2**31 - 1)
+ max_range = clamp(int(param.max), -(2**31), 2**31 - 1)
+ self._widget.setRange(min_range, max_range)
+
+ self.value = param.default
+
+ def _notify(self):
+ if self._label:
+ self._label.setText(str(self._widget.value()))
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ return self._widget.value()
+
+ @value.setter
+ def value(self, value: int):
+ self._widget.setValue(value)
+
+
+class FloatParamWidget(QWidget):
+ value_changed = pyqtSignal()
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ layout = QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+
+ assert param.min is not None and param.max is not None and param.default is not None
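+        # The slider works in 1/100 steps, so it is only used for small ranges; otherwise use a double spin box.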
+ if param.max - param.min <= 20:
+ self._widget = QSlider(Qt.Orientation.Horizontal, parent)
+ self._widget.setRange(round(param.min * 100), round(param.max * 100))
+ self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
+ self._widget.valueChanged.connect(self._notify)
+ self._label = QLabel(self)
+ self._label.setFixedWidth(32)
+ self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
+ layout.addWidget(self._widget)
+ layout.addWidget(self._label)
+ else:
+ self._widget = QDoubleSpinBox(parent)
+ self._widget.setRange(param.min, param.max)
+ self._widget.valueChanged.connect(self._notify)
+ self._label = None
+ layout.addWidget(self._widget)
+
+ self.value = param.default
+
+ def _notify(self):
+ if self._label:
+ self._label.setText(f"{self.value:.2f}")
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ if isinstance(self._widget, QSlider):
+ return self._widget.value() / 100
+ else:
+ return self._widget.value()
+
+ @value.setter
+ def value(self, value: float):
+ if isinstance(self._widget, QSlider):
+ self._widget.setValue(round(value * 100))
+ else:
+ self._widget.setValue(value)
+
+
+class BoolParamWidget(QWidget):
+ value_changed = pyqtSignal()
+
+ _true_text = _("On")
+ _false_text = _("Off")
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(parent)
+ self.setContentsMargins(0, 0, 0, 0)
+
+ layout = QHBoxLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ self.setLayout(layout)
+
+ fm = QFontMetrics(self.font())
+ self._label = QLabel(self)
+ self._label.setMinimumWidth(max(fm.width(self._true_text), fm.width(self._false_text)) + 4)
+ self._widget = SwitchWidget(parent)
+ self._widget.toggled.connect(self._notify)
+ layout.addWidget(self._widget)
+ layout.addWidget(self._label)
+
+ assert isinstance(param.default, bool)
+ self.value = param.default
+
+ def _notify(self):
+ self._label.setText(self._true_text if self.value else self._false_text)
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ return self._widget.isChecked()
+
+ @value.setter
+ def value(self, value: bool):
+ self._widget.setChecked(value)
+
+
+class TextParamWidget(QLineEdit):
+ value_changed = pyqtSignal()
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(parent)
+ assert isinstance(param.default, str)
+
+ self.value = param.default
+ self.textChanged.connect(self._notify)
+
+ def _notify(self):
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ return self.text()
+
+ @value.setter
+ def value(self, value: str):
+ self.setText(value)
+
+
+class PromptParamWidget(TextPromptWidget):
+ value_changed = pyqtSignal()
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(is_negative=param.kind is ParamKind.prompt_negative, parent=parent)
+ assert isinstance(param.default, str)
+
+ self.setObjectName("PromptParam")
+ self.setFrameStyle(QFrame.Shape.StyledPanel)
+ self.setStyleSheet(
+ f"QFrame#PromptParam {{ background-color: {theme.base}; border: 1px solid {theme.line_base}; }}"
+ )
+ self.text = param.default
+ self.text_changed.connect(self.value_changed)
+
+ @property
+ def value(self):
+ return self.text
+
+ @value.setter
+ def value(self, value: str):
+ self.text = value
+
+
+class ChoiceParamWidget(QComboBox):
+ value_changed = pyqtSignal()
+
+ def __init__(self, param: CustomParam, parent: QWidget | None = None):
+ super().__init__(parent)
+ if param.choices:
+ self.addItems(param.choices)
+ self.currentIndexChanged.connect(lambda _: self.value_changed.emit())
+
+ if param.default is not None:
+ self.value = param.default
+
+ @property
+ def value(self) -> str:
+ if self.currentIndex() == -1:
+ return ""
+ return self.currentText()
+
+ @value.setter
+ def value(self, value: str):
+ i = self.findText(value)
+ if i != -1 and i != self.currentIndex():
+ self.setCurrentIndex(i)
+
+
+class StyleParamWidget(QWidget):
+ value_changed = pyqtSignal()
+
+ def __init__(self, parent: QWidget):
+ super().__init__(parent)
+ self._style_select = StyleSelectWidget(self)
+ self._style_select.value_changed.connect(self._notify)
+ layout = QHBoxLayout()
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.addWidget(self._style_select)
+ self.setLayout(layout)
+
+ def _notify(self):
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ return self._style_select.value.filename
+
+ @value.setter
+ def value(self, value: str):
+ if style := Styles.list().find(value):
+ self._style_select.value = style
+
+
+CustomParamWidget = (
+ LayerSelect
+ | IntParamWidget
+ | FloatParamWidget
+ | BoolParamWidget
+ | TextParamWidget
+ | PromptParamWidget
+ | ChoiceParamWidget
+ | StyleParamWidget
+)
+
+
+def _create_param_widget(param: CustomParam, parent: QWidget) -> CustomParamWidget:
+ if param.kind is ParamKind.image_layer:
+ return LayerSelect("image", parent)
+ if param.kind is ParamKind.mask_layer:
+ return LayerSelect("mask", parent)
+ if param.kind is ParamKind.number_int:
+ return IntParamWidget(param, parent)
+ if param.kind is ParamKind.number_float:
+ return FloatParamWidget(param, parent)
+ if param.kind is ParamKind.toggle:
+ return BoolParamWidget(param, parent)
+ if param.kind is ParamKind.text:
+ return TextParamWidget(param, parent)
+ if param.kind in [ParamKind.prompt_positive, ParamKind.prompt_negative]:
+ return PromptParamWidget(param, parent)
+ if param.kind is ParamKind.choice:
+ return ChoiceParamWidget(param, parent)
+ if param.kind is ParamKind.style:
+ return StyleParamWidget(parent)
+ assert False, f"Unknown param kind: {param.kind}"
+
+
+class WorkflowParamsWidget(QWidget):
+ value_changed = pyqtSignal()
+
+ def __init__(self, params: list[CustomParam], parent: QWidget):
+ super().__init__(parent)
+ self._widgets: dict[str, CustomParamWidget] = {}
+
+ layout = QGridLayout(self)
+ layout.setContentsMargins(0, 0, 0, 0)
+ layout.setColumnMinimumWidth(1, 10)
+ self.setLayout(layout)
+
+ for p in params:
+ label = QLabel(p.name, self)
+ widget = _create_param_widget(p, self)
+ widget.value_changed.connect(self._notify)
+ row = len(self._widgets)
+ layout.addWidget(label, row, 0, Qt.AlignmentFlag.AlignBaseline)
+ layout.addWidget(widget, row, 2)
+ self._widgets[p.name] = widget
+
+ def _notify(self):
+ self.value_changed.emit()
+
+ @property
+ def value(self):
+ return {name: widget.value for name, widget in self._widgets.items()}
+
+ @value.setter
+ def value(self, values: dict[str, Any]):
+ for name, value in values.items():
+ if widget := self._widgets.get(name):
+ if type(widget.value) == type(value):
+ widget.value = value
+
+
+def popup_on_error(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except Exception as e:
+ QMessageBox.critical(self, _("Error"), str(e))
+
+ return wrapper
+
+
+def _create_tool_button(parent: QWidget, icon: QIcon, tooltip: str, handler: Callable[..., None]):
+ button = QToolButton(parent)
+ button.setIcon(icon)
+ button.setToolTip(tooltip)
+ button.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
+ button.setAutoRaise(True)
+ button.clicked.connect(handler)
+ return button
+
+
+class CustomWorkflowWidget(QWidget):
+ def __init__(self):
+ super().__init__()
+
+ self._model = root.active_model
+ self._model_bindings: list[QMetaObject.Connection | Binding] = []
+
+ self._workspace_select = WorkspaceSelectWidget(self)
+
+ self._workflow_select_widgets = QWidget(self)
+
+ self._workflow_select = QComboBox(self._workflow_select_widgets)
+ self._workflow_select.setModel(SortedWorkflows(root.workflows))
+ self._workflow_select.currentIndexChanged.connect(self._change_workflow)
+
+ self._import_workflow_button = _create_tool_button(
+ self._workflow_select_widgets,
+ theme.icon("import"),
+ _("Import workflow from file"),
+ self._import_workflow,
+ )
+ self._save_workflow_button = _create_tool_button(
+ self._workflow_select_widgets,
+ theme.icon("save"),
+ _("Save workflow to file"),
+ self._save_workflow,
+ )
+ self._delete_workflow_button = _create_tool_button(
+ self._workflow_select_widgets,
+ theme.icon("discard"),
+ _("Delete the currently selected workflow"),
+ self._delete_workflow,
+ )
+ self._open_webui_button = _create_tool_button(
+ self._workflow_select_widgets,
+ theme.icon("comfyui"),
+ _("Open Web UI to create custom workflows"),
+ self._open_webui,
+ )
+
+ self._workflow_edit_widgets = QWidget(self)
+ self._workflow_edit_widgets.setVisible(False)
+
+ self._workflow_name_edit = QLineEdit(self._workflow_edit_widgets)
+ self._workflow_name_edit.textEdited.connect(self._edit_name)
+ self._workflow_name_edit.returnPressed.connect(self._accept_name)
+
+ self._accept_name_button = _create_tool_button(
+ self._workflow_edit_widgets, theme.icon("apply"), _("Apply"), self._accept_name
+ )
+ self._cancel_name_button = _create_tool_button(
+ self._workflow_edit_widgets, theme.icon("cancel"), _("Cancel"), self._cancel_name
+ )
+
+ self._params_widget = WorkflowParamsWidget([], self)
+
+ self._generate_button = GenerateButton(JobKind.diffusion, self)
+ self._generate_button.clicked.connect(self._generate)
+
+ self._apply_button = QToolButton(self)
+ self._apply_button.setIcon(theme.icon("apply"))
+ self._apply_button.setFixedHeight(self._generate_button.height() - 2)
+ self._apply_button.setToolTip(_("Create a new layer with the current result"))
+ self._apply_button.clicked.connect(self.apply_live_result)
+
+ self._mode_button = QToolButton(self)
+ self._mode_button.setArrowType(Qt.ArrowType.DownArrow)
+ self._mode_button.setFixedHeight(self._generate_button.height() - 2)
+ self._mode_button.clicked.connect(self._show_generate_menu)
+ menu = QMenu(self)
+ menu.addAction(self._mk_action(CustomGenerationMode.regular, _("Generate"), "generate"))
+ menu.addAction(
+ self._mk_action(CustomGenerationMode.live, _("Generate Live"), "workspace-live")
+ )
+ self._generate_menu = menu
+
+ self._queue_button = QueueButton(parent=self)
+ self._queue_button.setFixedHeight(self._generate_button.height() - 2)
+
+ self._progress_bar = ProgressBar(self)
+ self._error_text = create_error_label(self)
+
+ self._history = HistoryWidget(self)
+ self._history.item_activated.connect(self.apply_result)
+
+ self._live_preview = LivePreviewArea(self)
+
+ self._layout = QVBoxLayout()
+ select_layout = QHBoxLayout()
+ select_layout.setContentsMargins(0, 0, 0, 0)
+ select_layout.setSpacing(2)
+ select_layout.addWidget(self._workflow_select)
+ select_layout.addWidget(self._import_workflow_button)
+ select_layout.addWidget(self._save_workflow_button)
+ select_layout.addWidget(self._delete_workflow_button)
+ select_layout.addWidget(self._open_webui_button)
+ self._workflow_select_widgets.setLayout(select_layout)
+ edit_layout = QHBoxLayout()
+ edit_layout.setContentsMargins(0, 0, 0, 0)
+ edit_layout.setSpacing(2)
+ edit_layout.addWidget(self._workflow_name_edit)
+ edit_layout.addWidget(self._accept_name_button)
+ edit_layout.addWidget(self._cancel_name_button)
+ self._workflow_edit_widgets.setLayout(edit_layout)
+ header_layout = QHBoxLayout()
+ header_layout.addWidget(self._workspace_select)
+ header_layout.addWidget(self._workflow_select_widgets)
+ header_layout.addWidget(self._workflow_edit_widgets)
+ self._layout.addLayout(header_layout)
+ self._layout.addWidget(self._params_widget)
+ actions_layout = QHBoxLayout()
+ actions_layout.setSpacing(0)
+ actions_layout.addWidget(self._generate_button)
+ actions_layout.addWidget(self._apply_button)
+ actions_layout.addWidget(self._mode_button)
+ actions_layout.addSpacing(4)
+ actions_layout.addWidget(self._queue_button)
+ self._layout.addLayout(actions_layout)
+ self._layout.addWidget(self._progress_bar)
+ self._layout.addWidget(self._error_text)
+ self._layout.addWidget(self._history)
+ self._layout.addWidget(self._live_preview)
+ self.setLayout(self._layout)
+
+ self._update_ui()
+
+ @property
+ def model(self):
+ return self._model
+
+ @model.setter
+ def model(self, model: Model):
+ if self._model != model:
+ Binding.disconnect_all(self._model_bindings)
+ self._model = model
+ self._model_bindings = [
+ bind(model, "workspace", self._workspace_select, "value", Bind.one_way),
+ bind_combo(model.custom, "workflow_id", self._workflow_select, Bind.one_way),
+ model.workspace_changed.connect(self._cancel_name),
+ model.custom.graph_changed.connect(self._update_current_workflow),
+ model.error_changed.connect(self._error_text.setText),
+ model.has_error_changed.connect(self._error_text.setVisible),
+ model.custom.mode_changed.connect(self._update_ui),
+ model.custom.is_live_changed.connect(self._update_ui),
+ model.custom.result_available.connect(self._live_preview.show_image),
+ model.custom.has_result_changed.connect(self._apply_button.setEnabled),
+ ]
+ self._queue_button.model = model
+ self._progress_bar.model = model
+ self._history.model_ = model
+ self._update_current_workflow()
+ self._update_ui()
+
+ def _mk_action(self, mode: CustomGenerationMode, text: str, icon: str):
+ action = QAction(text, self)
+ action.setIcon(theme.icon(icon))
+ action.setIconVisibleInMenu(True)
+ action.triggered.connect(lambda: self._change_mode(mode))
+ return action
+
+ def _change_mode(self, mode: CustomGenerationMode):
+ self.model.custom.mode = mode
+
+ def _show_generate_menu(self):
+ width = self._generate_button.width() + self._mode_button.width()
+ pos = QPoint(0, self._generate_button.height())
+ self._generate_menu.setFixedWidth(width)
+ self._generate_menu.exec_(self._generate_button.mapToGlobal(pos))
+
+ def _update_ui(self):
+ is_live_mode = self.model.custom.mode is CustomGenerationMode.live
+ self._history.setVisible(not is_live_mode)
+ self._live_preview.setVisible(is_live_mode)
+ self._apply_button.setVisible(is_live_mode)
+ self._apply_button.setEnabled(self.model.custom.has_result)
+
+ if not is_live_mode:
+ text = _("Generate")
+ icon = "generate"
+ elif not self.model.custom.is_live:
+ text = _("Start Generating")
+ icon = "play"
+ else:
+ text = _("Stop Generating")
+ icon = "pause"
+ self._generate_button.operation = text
+ self._generate_button.setIcon(theme.icon(icon))
+
+ def _generate(self):
+ if self.model.custom.mode is CustomGenerationMode.regular:
+ self.model.custom.generate()
+ else:
+ self.model.custom.is_live = not self.model.custom.is_live
+
+ def _update_current_workflow(self):
+ if not self.model.custom.workflow:
+ self._save_workflow_button.setEnabled(False)
+ self._delete_workflow_button.setEnabled(False)
+ return
+ self._save_workflow_button.setEnabled(True)
+ self._delete_workflow_button.setEnabled(
+ self.model.custom.workflow.source is WorkflowSource.local
+ )
+
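+        # Rebuild the parameter widget so it reflects the parameters exposed by the newly selected workflow graph.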
+ self._params_widget.deleteLater()
+ self._params_widget = WorkflowParamsWidget(self.model.custom.metadata, self)
+ self._params_widget.value = self.model.custom.params
+ self._layout.insertWidget(1, self._params_widget)
+ self._params_widget.value_changed.connect(self._change_params)
+
+ def _change_workflow(self):
+ self.model.custom.workflow_id = self._workflow_select.currentData()
+
+ def _change_params(self):
+ self.model.custom.params = self._params_widget.value
+
+ def apply_result(self, item: QListWidgetItem):
+ job_id, index = self._history.item_info(item)
+ self.model.apply_generated_result(job_id, index)
+
+ def apply_live_result(self):
+ image, params = self.model.custom.live_result
+ self.model.apply_result(image, params, ApplyBehavior.layer)
+ if settings.new_seed_after_apply:
+ self.model.generate_seed()
+
+ @popup_on_error
+ def _import_workflow(self, *args):
+ filename, __ = QFileDialog.getOpenFileName(
+ self,
+ _("Import Workflow"),
+ str(Path.home()),
+ "Workflow Files (*.json);;All Files (*)",
+ )
+ if filename:
+ self.model.custom.import_file(Path(filename))
+
+ def _save_workflow(self):
+ self.is_edit_mode = True
+
+ def _delete_workflow(self):
+ filepath = ensure(self.model.custom.workflow).path
+ q = QMessageBox.question(
+ self,
+ _("Delete Workflow"),
+ _("Are you sure you want to delete the current workflow?") + f"\n{filepath}",
+ QMessageBox.Yes | QMessageBox.No,
+ QMessageBox.StandardButton.No,
+ )
+ if q == QMessageBox.StandardButton.Yes:
+ self.model.custom.remove_workflow()
+
+ def _open_webui(self):
+ if client := root.connection.client_if_connected:
+ QDesktopServices.openUrl(QUrl(client.url))
+
+ @property
+ def is_edit_mode(self):
+ return self._workflow_edit_widgets.isVisible()
+
+ @is_edit_mode.setter
+ def is_edit_mode(self, value: bool):
+ if value == self.is_edit_mode:
+ return
+ self._workflow_select_widgets.setVisible(not value)
+ self._workflow_edit_widgets.setVisible(value)
+ if value:
+ self._workflow_name_edit.setText(self.model.custom.workflow_id)
+ self._workflow_name_edit.selectAll()
+ self._workflow_name_edit.setFocus()
+
+ def _edit_name(self):
+ self._accept_name_button.setEnabled(self._workflow_name_edit.text().strip() != "")
+
+ @popup_on_error
+ def _accept_name(self, *args):
+ self.model.custom.save_as(self._workflow_name_edit.text())
+ self.is_edit_mode = False
+
+ def _cancel_name(self):
+ self.is_edit_mode = False
+
+
+class CustomWorkflowPlaceholder(QWidget):
+ def __init__(self):
+ super().__init__()
+ self._model = root.active_model
+ self._connections = []
+
+ self._workspace_select = WorkspaceSelectWidget(self)
+        note = QLabel("<i>" + _("Custom workflows are not available on Cloud.") + "</i>", self)
+
+ self._layout = QVBoxLayout()
+ self._layout.addWidget(self._workspace_select)
+ self._layout.addSpacing(50)
+ self._layout.addWidget(note, 0, Qt.AlignmentFlag.AlignCenter)
+ self._layout.addStretch()
+ self.setLayout(self._layout)
+
+ @property
+ def model(self):
+ return self._model
+
+ @model.setter
+ def model(self, model: Model):
+ if self._model != model:
+ Binding.disconnect_all(self._connections)
+ self._model = model
+ self._connections = [
+ bind(model, "workspace", self._workspace_select, "value", Bind.one_way)
+ ]
diff --git a/ai_diffusion/ui/diffusion.py b/ai_diffusion/ui/diffusion.py
index 1f218dfb54..3a20386fcd 100644
--- a/ai_diffusion/ui/diffusion.py
+++ b/ai_diffusion/ui/diffusion.py
@@ -14,6 +14,7 @@
from ..localization import translate as _
from . import theme
from .generation import GenerationWidget
+from .custom_workflow import CustomWorkflowWidget, CustomWorkflowPlaceholder
from .upscale import UpscaleWidget
from .live import LiveWidget
from .animation import AnimationWidget
@@ -211,12 +212,16 @@ def __init__(self):
self._upscaling = UpscaleWidget()
self._animation = AnimationWidget()
self._live = LiveWidget()
+ self._custom = CustomWorkflowWidget()
+ self._custom_placeholder = CustomWorkflowPlaceholder()
self._frame = QStackedWidget(self)
self._frame.addWidget(self._welcome)
self._frame.addWidget(self._generation)
self._frame.addWidget(self._upscaling)
self._frame.addWidget(self._live)
self._frame.addWidget(self._animation)
+ self._frame.addWidget(self._custom)
+ self._frame.addWidget(self._custom_placeholder)
self.setWidget(self._frame)
root.connection.state_changed.connect(self.update_content)
@@ -235,6 +240,7 @@ def update_content(self):
model = root.model_for_active_document()
connection = root.connection
requires_update = self._welcome.requires_update
+ is_cloud = settings.server_mode is ServerMode.cloud
if model is None or connection.state is not ConnectionState.connected or requires_update:
self._frame.setCurrentWidget(self._welcome)
elif model.workspace is Workspace.generation:
@@ -249,3 +255,9 @@ def update_content(self):
elif model.workspace is Workspace.animation:
self._animation.model = model
self._frame.setCurrentWidget(self._animation)
+ elif model.workspace is Workspace.custom and is_cloud:
+ self._custom_placeholder.model = model
+ self._frame.setCurrentWidget(self._custom_placeholder)
+ elif model.workspace is Workspace.custom:
+ self._custom.model = model
+ self._frame.setCurrentWidget(self._custom)
diff --git a/ai_diffusion/ui/generation.py b/ai_diffusion/ui/generation.py
index f80a1edb47..723f3ee416 100644
--- a/ai_diffusion/ui/generation.py
+++ b/ai_diffusion/ui/generation.py
@@ -2,25 +2,9 @@
from textwrap import wrap as wrap_text
from PyQt5.QtCore import Qt, QMetaObject, QSize, QPoint, QUuid, pyqtSignal
from PyQt5.QtGui import QGuiApplication, QMouseEvent, QPalette, QColor
-from PyQt5.QtWidgets import (
- QAction,
- QWidget,
- QVBoxLayout,
- QHBoxLayout,
- QPushButton,
- QProgressBar,
- QLabel,
- QListWidget,
- QListWidgetItem,
- QListView,
- QSizePolicy,
- QToolButton,
- QComboBox,
- QCheckBox,
- QMenu,
- QShortcut,
- QMessageBox,
-)
+from PyQt5.QtWidgets import QAction, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QProgressBar
+from PyQt5.QtWidgets import QLabel, QListWidget, QListWidgetItem, QListView, QSizePolicy
+from PyQt5.QtWidgets import QComboBox, QCheckBox, QMenu, QShortcut, QMessageBox, QToolButton
from ..properties import Binding, Bind, bind, bind_combo, bind_toggle
from ..image import Bounds, Extent, Image
@@ -135,8 +119,9 @@ def add(self, job: Job):
if not JobParams.equal_ignore_seed(self._last_job_params, job.params):
self._last_job_params = job.params
- prompt = job.params.prompt if job.params.prompt != "" else ""
- strength = f"{job.params.strength*100:.0f}% - " if job.params.strength != 1.0 else ""
+ prompt = job.params.name if job.params.name != "" else ""
+ strength = job.params.metadata.get("strength", 1.0)
+ strength = f"{strength*100:.0f}% - " if strength != 1.0 else ""
header = QListWidgetItem(f"{job.timestamp:%H:%M} - {strength}{prompt}")
header.setFlags(Qt.ItemFlag.NoItemFlags)
@@ -156,26 +141,39 @@ def add(self, job: Job):
if scroll_to_bottom:
self.scrollToBottom()
+ _job_info_translations = {
+ "prompt": _("Prompt"),
+ "negative_prompt": _("Negative Prompt"),
+ "style": _("Style"),
+ "strength": _("Strength"),
+ "checkpoint": _("Model"),
+ "loras": _("LoRA"),
+ "sampler": _("Sampler"),
+ "seed": _("Seed"),
+ }
+
def _job_info(self, params: JobParams):
- prompt = params.prompt if params.prompt != "" else ""
- if len(prompt) > 70:
- prompt = prompt[:66] + "..."
+ title = params.name if params.name != "" else ""
+ if len(title) > 70:
+ title = title[:66] + "..."
+ if params.strength != 1.0:
+ title = f"{title} @ {params.strength*100:.0f}%"
style = Styles.list().find(params.style)
- positive = _("Prompt") + f": {params.prompt or '-'}"
- negative = _("Negative Prompt") + f": {params.negative_prompt or '-'}"
- strings = [
- f"{prompt} @ {params.strength*100:.0f}%\n",
+ strings: list[str | list[str]] = [
+ title + "\n",
_("Click to toggle preview, double-click to apply."),
"",
- _("Style") + f": {style.name if style else params.style}",
- wrap_text(positive, 80, subsequent_indent=" "),
- wrap_text(negative, 80, subsequent_indent=" "),
- _("Strength") + f": {params.strength*100:.0f}%",
- _("Model") + f": {params.checkpoint}",
- _("Sampler") + f": {params.sampler}",
- _("Seed") + f": {params.seed}",
- f"{params.bounds}",
]
+ for key, value in params.metadata.items():
+ if key == "style" and style:
+ value = style.name
+ if isinstance(value, list) and len(value) == 0:
+ continue
+ if isinstance(value, list) and isinstance(value[0], dict):
+ value = "\n ".join((f"{v.get('name')} ({v.get('strength')})" for v in value))
+ s = f"{self._job_info_translations.get(key, key)}: {value}"
+ strings.append(wrap_text(s, 80, subsequent_indent=" "))
+ strings.append(_("Seed") + f": {params.seed}")
return "\n".join(flatten(strings))
def remove(self, job: Job):
@@ -217,6 +215,7 @@ def update_selection(self):
elif selection:
item = self._find(selection)
if item is not None and not item.isSelected():
+ self.clearSelection()
item.setSelected(True)
self.update_apply_button()
@@ -366,7 +365,7 @@ def _copy_prompt(self):
active = self._model.regions.active_or_root
active.positive = job.params.prompt
if isinstance(active, RootRegion):
- active.negative = job.params.negative_prompt
+ active.negative = job.params.metadata.get("negative_prompt", "")
def _copy_strength(self):
if job := self.selected_job:
@@ -519,6 +518,55 @@ def set_context(self):
self._model.inpaint.context = data
+class ProgressBar(QProgressBar):
+ def __init__(self, parent: QWidget):
+ super().__init__(parent)
+ self._model = root.active_model
+ self._model_bindings: list[QMetaObject.Connection] = []
+ self._palette = self.palette()
+ self.setMinimum(0)
+ self.setMaximum(1000)
+ self.setTextVisible(False)
+ self.setFixedHeight(6)
+
+ @property
+ def model(self):
+ return self._model
+
+ @model.setter
+ def model(self, model: Model):
+ if self._model != model:
+ Binding.disconnect_all(self._model_bindings)
+ self._model = model
+ self._model_bindings = [
+ self._model.progress_changed.connect(self._update_progress),
+ self._model.progress_kind_changed.connect(self._update_progress_kind),
+ ]
+
+ def _update_progress_kind(self):
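+        # Switch to the alternate highlight colour while uploading; restore the default palette otherwise.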
+ palette = self._palette
+ if self._model.progress_kind is ProgressKind.upload:
+ palette = self.palette()
+ palette.setColor(QPalette.ColorRole.Highlight, QColor(theme.progress_alt))
+ self.setPalette(palette)
+
+ def _update_progress(self):
+ if self._model.progress >= 0:
+ self.setValue(int(self._model.progress * 1000))
+ else:
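+            # Negative progress means "unknown": reset a leftover full bar and creep towards 99% to indicate activity.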
+ if self.value() >= 100:
+ self.reset()
+ self.setValue(min(99, self.value() + 2))
+
+
+def create_error_label(parent: QWidget):
+ label = QLabel(parent)
+ label.setStyleSheet("font-weight: bold; color: red;")
+ label.setWordWrap(True)
+ label.setVisible(False)
+ return label
+
+
class GenerationWidget(QWidget):
_model: Model
_model_bindings: list[QMetaObject.Connection | Binding]
@@ -590,17 +638,10 @@ def __init__(self):
actions_layout.addWidget(self.queue_button)
layout.addLayout(actions_layout)
- self.progress_bar = QProgressBar(self)
- self.progress_bar.setMinimum(0)
- self.progress_bar.setMaximum(1000)
- self.progress_bar.setTextVisible(False)
- self.progress_bar.setFixedHeight(6)
+ self.progress_bar = ProgressBar(self)
layout.addWidget(self.progress_bar)
- self.error_text = QLabel(self)
- self.error_text.setStyleSheet("font-weight: bold; color: red;")
- self.error_text.setWordWrap(True)
- self.error_text.setVisible(False)
+ self.error_text = create_error_label(self)
layout.addWidget(self.error_text)
self.history = HistoryWidget(self)
@@ -629,8 +670,6 @@ def model(self, model: Model):
model.document.layers.active_changed.connect(self.update_generate_button),
model.regions.active_changed.connect(self.update_generate_button),
model.region_only_changed.connect(self.update_generate_button),
- model.progress_changed.connect(self.update_progress),
- model.progress_kind_changed.connect(self.update_progress_kind),
model.error_changed.connect(self.error_text.setText),
model.has_error_changed.connect(self.error_text.setVisible),
self.add_control_button.clicked.connect(model.regions.add_control),
@@ -642,24 +681,11 @@ def model(self, model: Model):
self.custom_inpaint.model = model
self.generate_button.model = model
self.queue_button.model = model
+ self.progress_bar.model = model
self.strength_slider.model = model
self.history.model_ = model
self.update_generate_button()
- def update_progress_kind(self):
- palette = self.palette()
- if self.model.progress_kind is ProgressKind.upload:
- palette.setColor(QPalette.ColorRole.Highlight, QColor(theme.progress_alt))
- self.progress_bar.setPalette(palette)
-
- def update_progress(self):
- if self.model.progress >= 0:
- self.progress_bar.setValue(int(self.model.progress * 1000))
- else:
- if self.progress_bar.value() >= 100:
- self.progress_bar.reset()
- self.progress_bar.setValue(min(99, self.progress_bar.value() + 2))
-
def apply_result(self, item: QListWidgetItem):
job_id, index = self.history.item_info(item)
self.model.apply_generated_result(job_id, index)
diff --git a/ai_diffusion/ui/live.py b/ai_diffusion/ui/live.py
index dce222bace..ad39a9a40b 100644
--- a/ai_diffusion/ui/live.py
+++ b/ai_diffusion/ui/live.py
@@ -23,6 +23,19 @@
from . import theme
+class LivePreviewArea(QLabel):
+ def __init__(self, parent: QWidget):
+ super().__init__(parent)
+ self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ self.setAlignment(Qt.AlignmentFlag(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft))
+
+ def show_image(self, image: Image):
+ target = Extent.from_qsize(self.size())
+ img = Image.scale_to_fit(image, target)
+ self.setPixmap(img.to_pixmap())
+ self.setMinimumSize(256, 256)
+
+
class LiveWidget(QWidget):
_play_icon = theme.icon("play")
_pause_icon = theme.icon("pause")
@@ -163,11 +176,7 @@ def __init__(self):
self.progress_bar.setVisible(False)
layout.addWidget(self.progress_bar)
- self.preview_area = QLabel(self)
- self.preview_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
- self.preview_area.setAlignment(
- Qt.AlignmentFlag(Qt.AlignmentFlag.AlignTop | Qt.AlignmentFlag.AlignLeft)
- )
+ self.preview_area = LivePreviewArea(self)
layout.addWidget(self.preview_area)
@property
@@ -252,10 +261,7 @@ def update_progress(self):
def show_result(self, image: Image):
self.progress_bar.setVisible(False)
- target = Extent.from_qsize(self.preview_area.size())
- img = Image.scale_to_fit(image, target)
- self.preview_area.setPixmap(img.to_pixmap())
- self.preview_area.setMinimumSize(256, 256)
+ self.preview_area.show_image(image)
def apply_result(self):
self.model.live.apply_result()
diff --git a/ai_diffusion/ui/style.py b/ai_diffusion/ui/style.py
index 715ac9b8b9..bd25649cf4 100644
--- a/ai_diffusion/ui/style.py
+++ b/ai_diffusion/ui/style.py
@@ -389,7 +389,7 @@ def _upload_lora(self):
if client.max_upload_size and path.stat().st_size > client.max_upload_size:
_show_file_too_large_warning(client.max_upload_size, self)
return
- file = File.local(path, compute_hash=True)
+ file = File.local(path, FileFormat.lora, compute_hash=True)
root.files.loras.add(file)
self._add_item(file)
diff --git a/ai_diffusion/ui/widget.py b/ai_diffusion/ui/widget.py
index e674a6e106..7e96311946 100644
--- a/ai_diffusion/ui/widget.py
+++ b/ai_diffusion/ui/widget.py
@@ -691,6 +691,7 @@ class WorkspaceSelectWidget(QToolButton):
Workspace.upscaling: theme.icon("workspace-upscaling"),
Workspace.live: theme.icon("workspace-live"),
Workspace.animation: theme.icon("workspace-animation"),
+ Workspace.custom: theme.icon("workspace-custom"),
}
_value = Workspace.generation
@@ -703,6 +704,7 @@ def __init__(self, parent):
menu.addAction(self._create_action(_("Upscale"), Workspace.upscaling))
menu.addAction(self._create_action(_("Live"), Workspace.live))
menu.addAction(self._create_action(_("Animation"), Workspace.animation))
+ menu.addAction(self._create_action(_("Graph"), Workspace.custom))
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
self.setMenu(menu)
diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py
index 36f153d170..98dfd93dad 100644
--- a/ai_diffusion/workflow.py
+++ b/ai_diffusion/workflow.py
@@ -9,7 +9,7 @@
from . import resolution, resources
from .api import ControlInput, ImageInput, CheckpointInput, SamplingInput, WorkflowInput, LoraInput
from .api import ExtentInput, InpaintMode, InpaintParams, FillMode, ConditioningInput, WorkflowKind
-from .api import RegionInput
+from .api import RegionInput, CustomWorkflowInput
from .image import Bounds, Extent, Image, Mask, Point, multiple_of
from .client import ClientModels, ModelDict
from .files import FileLibrary, FileFormat
@@ -18,7 +18,7 @@
from .resources import ControlMode, Arch, UpscalerName, ResourceKind, ResourceId
from .settings import PerformanceSettings
from .text import merge_prompt, extract_loras
-from .comfy_workflow import ComfyWorkflow, ComfyRunMode, Output
+from .comfy_workflow import ComfyWorkflow, ComfyRunMode, Input, Output, ComfyNode
from .localization import translate as _
from .settings import settings
from .util import ensure, median_or_zero, unique, client_logger as log
@@ -1043,6 +1043,75 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
return w
+def expand_custom(
+ w: ComfyWorkflow,
+ input: CustomWorkflowInput,
+ images: ImageInput,
+ seed: int,
+ models: ClientModels,
+):
+ custom = ComfyWorkflow.from_dict(input.workflow)
+ nodes: dict[int, int] = {} # map old node IDs to new node IDs
+ outputs: dict[Output, Input] = {}
+
+ def map_input(input: Input):
+ if isinstance(input, Output):
+ mapped = outputs.get(input)
+ if mapped is not None:
+ return mapped
+ else:
+ return Output(nodes[input.node], input.output)
+ return input
+
+ def get_param(node: ComfyNode, expected_type: type | None = None):
+ name = node.input("name", "")
+ value = input.params.get(name)
+ if value is None:
+ raise Exception(f"Missing required parameter '{name}' for custom workflow")
+ if expected_type and not isinstance(value, expected_type):
+ raise Exception(f"Parameter '{name}' must be of type {expected_type}")
+ return value
+
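+    # Replace the Krita placeholder nodes with concrete canvas/selection/parameter values; all other
+    # nodes are copied verbatim with their connections remapped to the new node IDs.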
+ for node in custom:
+ match node.type:
+ case "ETN_KritaCanvas":
+ image = ensure(images.initial_image)
+ outputs[node.output(0)] = w.load_image(image)
+ outputs[node.output(1)] = image.width
+ outputs[node.output(2)] = image.height
+ outputs[node.output(3)] = seed
+ case "ETN_KritaSelection":
+ outputs[node.output(0)] = w.load_mask(ensure(images.hires_mask))
+ case "ETN_Parameter":
+ outputs[node.output(0)] = get_param(node)
+ case "ETN_KritaImageLayer":
+ outputs[node.output(0)] = w.load_image(get_param(node, Image))
+ case "ETN_KritaMaskLayer":
+ outputs[node.output(0)] = w.load_mask(get_param(node, Image))
+ case "ETN_KritaStyle":
+ style: Style = get_param(node, Style)
+ is_live = node.input("sampler_preset", "auto") == "live"
+ checkpoint_input = style.get_models()
+ sampling = _sampling_from_style(style, 1.0, is_live)
+ model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
+ outputs[node.output(0)] = model
+ outputs[node.output(1)] = clip
+ outputs[node.output(2)] = vae
+ outputs[node.output(3)] = style.style_prompt
+ outputs[node.output(4)] = style.negative_prompt
+ outputs[node.output(5)] = sampling.sampler
+ outputs[node.output(6)] = sampling.scheduler
+ outputs[node.output(7)] = sampling.total_steps
+ outputs[node.output(8)] = sampling.cfg_scale
+ case _:
+ mapped_inputs = {k: map_input(v) for k, v in node.inputs.items()}
+ mapped = ComfyNode(node.id, node.type, mapped_inputs)
+ nodes[node.id] = w.copy(mapped).node
+
+ w.guess_sample_count()
+ return w
+
+
###################################################################################################
@@ -1258,6 +1327,9 @@ def create(i: WorkflowInput, models: ClientModels, comfy_mode=ComfyRunMode.serve
bounds=i.inpaint.target_bounds if i.inpaint else None,
seed=i.sampling.seed if i.sampling else -1,
)
+ elif i.kind is WorkflowKind.custom:
+ seed = ensure(i.sampling).seed
+ return expand_custom(workflow, ensure(i.custom_workflow), ensure(i.images), seed, models)
else:
raise ValueError(f"Unsupported workflow kind: {i.kind}")
diff --git a/tests/data/object_info.json b/tests/data/object_info.json
new file mode 100644
index 0000000000..8ae9803922
--- /dev/null
+++ b/tests/data/object_info.json
@@ -0,0 +1,241 @@
+{
+ "GrowMask": {
+ "input": {
+ "required": {
+ "mask": [
+ "MASK"
+ ],
+ "expand": [
+ "INT",
+ {
+ "default": 0,
+ "min": -16384,
+ "max": 16384,
+ "step": 1
+ }
+ ],
+ "tapered_corners": [
+ "BOOLEAN",
+ {
+ "default": true
+ }
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "mask",
+ "expand",
+ "tapered_corners"
+ ]
+ },
+ "output": [
+ "MASK"
+ ],
+ "output_is_list": [
+ false
+ ],
+ "output_name": [
+ "MASK"
+ ],
+ "name": "GrowMask",
+ "display_name": "GrowMask",
+ "description": "",
+ "python_module": "comfy_extras.nodes_mask",
+ "category": "mask",
+ "output_node": false
+ },
+ "ImageUpscaleWithModel": {
+ "input": {
+ "required": {
+ "upscale_model": [
+ "UPSCALE_MODEL"
+ ],
+ "image": [
+ "IMAGE"
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "upscale_model",
+ "image"
+ ]
+ },
+ "output": [
+ "IMAGE"
+ ],
+ "output_is_list": [
+ false
+ ],
+ "output_name": [
+ "IMAGE"
+ ],
+ "name": "ImageUpscaleWithModel",
+ "display_name": "Upscale Image (using Model)",
+ "description": "",
+ "python_module": "comfy_extras.nodes_upscale_model",
+ "category": "image/upscaling",
+ "output_node": false
+ },
+ "ETN_ApplyMaskToImage": {
+ "input": {
+ "required": {
+ "image": [
+ "IMAGE"
+ ],
+ "mask": [
+ "MASK"
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "image",
+ "mask"
+ ]
+ },
+ "output": [
+ "IMAGE"
+ ],
+ "output_is_list": [
+ false
+ ],
+ "output_name": [
+ "IMAGE"
+ ],
+ "name": "ETN_ApplyMaskToImage",
+ "display_name": "Apply Mask to Image",
+ "description": "",
+ "python_module": "custom_nodes.comfyui-tooling-nodes",
+ "category": "external_tooling",
+ "output_node": false
+ },
+ "UpscaleModelLoader": {
+ "input": {
+ "required": {
+ "model_name": [
+ [
+ "4x_NMKD-Superscale-SP_178000_G.pth",
+ "OmniSR_X2_DIV2K.safetensors",
+ "OmniSR_X3_DIV2K.safetensors",
+ "OmniSR_X4_DIV2K.safetensors"
+ ]
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "model_name"
+ ]
+ },
+ "output": [
+ "UPSCALE_MODEL"
+ ],
+ "output_is_list": [
+ false
+ ],
+ "output_name": [
+ "UPSCALE_MODEL"
+ ],
+ "name": "UpscaleModelLoader",
+ "display_name": "Load Upscale Model",
+ "description": "",
+ "python_module": "comfy_extras.nodes_upscale_model",
+ "category": "loaders",
+ "output_node": false
+ },
+ "ETN_KritaCanvas": {
+ "input": {},
+ "input_order": {},
+ "output": [
+ "IMAGE",
+ "INT",
+ "INT",
+ "INT"
+ ],
+ "output_is_list": [
+ false,
+ false,
+ false,
+ false
+ ],
+ "output_name": [
+ "image",
+ "width",
+ "height",
+ "seed"
+ ],
+ "name": "ETN_KritaCanvas",
+ "display_name": "Krita Canvas",
+ "description": "",
+ "python_module": "custom_nodes.comfyui-tooling-nodes",
+ "category": "krita",
+ "output_node": false
+ },
+ "ETN_KritaOutput": {
+ "input": {
+ "required": {
+ "images": [
+ "IMAGE"
+ ],
+ "format": [
+ [
+ "PNG",
+ "JPEG"
+ ],
+ {
+ "default": "PNG"
+ }
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "images",
+ "format"
+ ]
+ },
+ "output": [],
+ "output_is_list": [],
+ "output_name": [],
+ "name": "ETN_KritaOutput",
+ "display_name": "Krita Output",
+ "description": "",
+ "python_module": "custom_nodes.comfyui-tooling-nodes",
+ "category": "krita",
+ "output_node": true
+ },
+ "ETN_KritaMaskLayer": {
+ "input": {
+ "required": {
+ "name": [
+ "STRING",
+ {
+ "default": "Mask"
+ }
+ ]
+ }
+ },
+ "input_order": {
+ "required": [
+ "name"
+ ]
+ },
+ "output": [
+ "MASK"
+ ],
+ "output_is_list": [
+ false
+ ],
+ "output_name": [
+ "mask"
+ ],
+ "name": "ETN_KritaMaskLayer",
+ "display_name": "Krita Mask Layer",
+ "description": "",
+ "python_module": "custom_nodes.comfyui-tooling-nodes",
+ "category": "krita",
+ "output_node": false
+ }
+}
\ No newline at end of file
diff --git a/tests/data/workflow-api.json b/tests/data/workflow-api.json
new file mode 100644
index 0000000000..8dfba8fc14
--- /dev/null
+++ b/tests/data/workflow-api.json
@@ -0,0 +1,64 @@
+{
+ "0": {
+ "class_type": "UpscaleModelLoader",
+ "inputs": {
+ "model_name": "4x_NMKD-Superscale-SP_178000_G.pth"
+ }
+ },
+ "1": {
+ "class_type": "ETN_KritaCanvas",
+ "inputs": {}
+ },
+ "2": {
+ "class_type": "ETN_KritaMaskLayer",
+ "inputs": {
+ "name": "Zauber"
+ }
+ },
+ "3": {
+ "class_type": "GrowMask",
+ "inputs": {
+ "mask": [
+ "2",
+ 0
+ ],
+ "expand": 4
+ }
+ },
+ "4": {
+ "class_type": "ImageUpscaleWithModel",
+ "inputs": {
+ "upscale_model": [
+ "0",
+ 0
+ ],
+ "image": [
+ "1",
+ 0
+ ]
+ }
+ },
+ "5": {
+ "class_type": "ETN_ApplyMaskToImage",
+ "inputs": {
+ "image": [
+ "4",
+ 0
+ ],
+ "mask": [
+ "3",
+ 0
+ ]
+ }
+ },
+ "6": {
+ "class_type": "ETN_KritaOutput",
+ "inputs": {
+ "images": [
+ "5",
+ 0
+ ],
+ "format": "PNG"
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/data/workflow-ui.json b/tests/data/workflow-ui.json
new file mode 100644
index 0000000000..7ade8ed334
--- /dev/null
+++ b/tests/data/workflow-ui.json
@@ -0,0 +1,438 @@
+{
+ "last_node_id": 10,
+ "last_link_id": 10,
+ "nodes": [
+ {
+ "id": 3,
+ "type": "GrowMask",
+ "pos": {
+ "0": 448,
+ "1": 57
+ },
+ "size": [
+ 315,
+ 82
+ ],
+ "flags": {},
+ "order": 5,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "mask",
+ "type": "MASK",
+ "link": 3
+ },
+ {
+ "name": "expand",
+ "type": "INT",
+ "link": 2,
+ "widget": {
+ "name": "expand"
+ }
+ }
+ ],
+ "outputs": [
+ {
+ "name": "MASK",
+ "type": "MASK",
+ "links": [
+ 10
+ ],
+ "shape": 3
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "GrowMask"
+ },
+ "widgets_values": [
+ 4,
+ true
+ ]
+ },
+ {
+ "id": 4,
+ "type": "PrimitiveNode",
+ "pos": {
+ "0": 138,
+ "1": 156
+ },
+ "size": [
+ 210,
+ 82
+ ],
+ "flags": {},
+ "order": 0,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "INT",
+ "type": "INT",
+ "links": [
+ 2
+ ],
+ "slot_index": 0,
+ "widget": {
+ "name": "expand"
+ }
+ }
+ ],
+ "properties": {
+ "Run widget replace on values": false
+ },
+ "widgets_values": [
+ 4,
+ "fixed"
+ ]
+ },
+ {
+ "id": 8,
+ "type": "PrimitiveNode",
+ "pos": {
+ "0": -280,
+ "1": -370
+ },
+ "size": [
+ 364.38086885107873,
+ 106
+ ],
+ "flags": {},
+ "order": 1,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "COMBO",
+ "type": "COMBO",
+ "links": [
+ 4
+ ],
+ "slot_index": 0,
+ "widget": {
+ "name": "model_name"
+ }
+ }
+ ],
+ "properties": {
+ "Run widget replace on values": false
+ },
+ "widgets_values": [
+ "4x_NMKD-Superscale-SP_178000_G.pth",
+ "fixed",
+ ""
+ ]
+ },
+ {
+ "id": 9,
+ "type": "ImageUpscaleWithModel",
+ "pos": {
+ "0": 470,
+ "1": -370
+ },
+ "size": {
+ "0": 340.20001220703125,
+ "1": 46
+ },
+ "flags": {},
+ "order": 6,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "upscale_model",
+ "type": "UPSCALE_MODEL",
+ "link": 5
+ },
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 6
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 9
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ImageUpscaleWithModel"
+ }
+ },
+ {
+ "id": 10,
+ "type": "ETN_ApplyMaskToImage",
+ "pos": {
+ "0": 870,
+ "1": -190
+ },
+ "size": {
+ "0": 239.40000915527344,
+ "1": 46
+ },
+ "flags": {},
+ "order": 7,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "link": 9
+ },
+ {
+ "name": "mask",
+ "type": "MASK",
+ "link": 10
+ }
+ ],
+ "outputs": [
+ {
+ "name": "IMAGE",
+ "type": "IMAGE",
+ "links": [
+ 8
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ETN_ApplyMaskToImage"
+ }
+ },
+ {
+ "id": 6,
+ "type": "UpscaleModelLoader",
+ "pos": {
+ "0": 120,
+ "1": -370
+ },
+ "size": [
+ 315,
+ 58
+ ],
+ "flags": {},
+ "order": 4,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "model_name",
+ "type": "COMBO",
+ "link": 4,
+ "widget": {
+ "name": "model_name"
+ }
+ }
+ ],
+ "outputs": [
+ {
+ "name": "UPSCALE_MODEL",
+ "type": "UPSCALE_MODEL",
+ "links": [
+ 5
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "UpscaleModelLoader"
+ },
+ "widgets_values": [
+ "4x_NMKD-Superscale-SP_178000_G.pth"
+ ]
+ },
+ {
+ "id": 1,
+ "type": "ETN_KritaCanvas",
+ "pos": {
+ "0": 205,
+ "1": -247
+ },
+ "size": [
+ 200,
+ 100
+ ],
+ "flags": {},
+ "order": 2,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "image",
+ "type": "IMAGE",
+ "links": [
+ 6
+ ],
+ "shape": 3,
+ "slot_index": 0
+ },
+ {
+ "name": "width",
+ "type": "INT",
+ "links": null,
+ "shape": 3
+ },
+ {
+ "name": "height",
+ "type": "INT",
+ "links": null,
+ "shape": 3
+ },
+ {
+ "name": "seed",
+ "type": "INT",
+ "links": null,
+ "shape": 3
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ETN_KritaCanvas"
+ }
+ },
+ {
+ "id": 2,
+ "type": "ETN_KritaOutput",
+ "pos": {
+ "0": 1140,
+ "1": -190
+ },
+ "size": [
+ 200,
+ 120
+ ],
+ "flags": {},
+ "order": 8,
+ "mode": 0,
+ "inputs": [
+ {
+ "name": "images",
+ "type": "IMAGE",
+ "link": 8
+ }
+ ],
+ "outputs": [],
+ "properties": {
+ "Node name for S&R": "ETN_KritaOutput"
+ },
+ "widgets_values": [
+ "PNG"
+ ]
+ },
+ {
+ "id": 5,
+ "type": "ETN_KritaMaskLayer",
+ "pos": {
+ "0": 41,
+ "1": 11
+ },
+ "size": {
+ "0": 315,
+ "1": 58
+ },
+ "flags": {},
+ "order": 3,
+ "mode": 0,
+ "inputs": [],
+ "outputs": [
+ {
+ "name": "mask",
+ "type": "MASK",
+ "links": [
+ 3
+ ],
+ "shape": 3,
+ "slot_index": 0
+ }
+ ],
+ "properties": {
+ "Node name for S&R": "ETN_KritaMaskLayer"
+ },
+ "widgets_values": [
+ "Zauber"
+ ]
+ }
+ ],
+ "links": [
+ [
+ 2,
+ 4,
+ 0,
+ 3,
+ 1,
+ "INT"
+ ],
+ [
+ 3,
+ 5,
+ 0,
+ 3,
+ 0,
+ "MASK"
+ ],
+ [
+ 4,
+ 8,
+ 0,
+ 6,
+ 0,
+ "COMBO"
+ ],
+ [
+ 5,
+ 6,
+ 0,
+ 9,
+ 0,
+ "UPSCALE_MODEL"
+ ],
+ [
+ 6,
+ 1,
+ 0,
+ 9,
+ 1,
+ "IMAGE"
+ ],
+ [
+ 8,
+ 10,
+ 0,
+ 2,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 9,
+ 9,
+ 0,
+ 10,
+ 0,
+ "IMAGE"
+ ],
+ [
+ 10,
+ 3,
+ 0,
+ 10,
+ 1,
+ "MASK"
+ ]
+ ],
+ "groups": [],
+ "config": {},
+ "extra": {
+ "ds": {
+ "scale": 1.0834705943388394,
+ "offset": [
+ 311.97951538052513,
+ 487.01874472527845
+ ]
+ }
+ },
+ "version": 0.4
+}
\ No newline at end of file
diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py
new file mode 100644
index 0000000000..ebf1693d81
--- /dev/null
+++ b/tests/test_custom_workflow.py
@@ -0,0 +1,387 @@
+import json
+import pytest
+from pathlib import Path
+from PyQt5.QtCore import Qt
+
+from ai_diffusion.api import CustomWorkflowInput, ImageInput, WorkflowInput
+from ai_diffusion.client import Client, ClientModels, CheckpointInfo
+from ai_diffusion.connection import Connection, ConnectionState
+from ai_diffusion.comfy_workflow import ComfyNode, ComfyWorkflow, Output
+from ai_diffusion.custom_workflow import CustomWorkflow, WorkflowSource, WorkflowCollection
+from ai_diffusion.custom_workflow import SortedWorkflows, CustomWorkspace
+from ai_diffusion.custom_workflow import CustomParam, ParamKind, workflow_parameters
+from ai_diffusion.image import Image, Extent
+from ai_diffusion.jobs import JobQueue
+from ai_diffusion.style import Style
+from ai_diffusion.resources import Arch
+from ai_diffusion import workflow
+
+from .config import test_dir
+
+
+class MockClient(Client):
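+    # Minimal stub so WorkflowCollection and CustomWorkspace can be exercised without a running ComfyUI server.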
+ def __init__(self, node_inputs: dict[str, dict]):
+ self.models = ClientModels()
+ self.models.node_inputs = node_inputs
+
+ @staticmethod
+ async def connect(url: str, access_token: str = "") -> Client:
+ return MockClient({})
+
+ async def enqueue(self, work: WorkflowInput, front: bool = False) -> str:
+ return ""
+
+ async def listen(self): # type: ignore
+ return
+
+ async def interrupt(self):
+ pass
+
+ async def clear_queue(self):
+ pass
+
+
+def create_mock_connection(
+ initial_workflows: dict[str, dict],
+ node_inputs: dict[str, dict] | None = None,
+ state: ConnectionState = ConnectionState.connected,
+):
+ connection = Connection()
+ connection._client = MockClient(node_inputs or {})
+ connection._workflows = initial_workflows
+ connection.state = state
+ return connection
+
+
+def _assert_has_workflow(
+ collection: WorkflowCollection,
+ name: str,
+ source: WorkflowSource,
+ graph: dict,
+ file: Path | None = None,
+):
+ workflow = collection.find(name)
+ assert (
+ workflow is not None
+ and workflow.source == source
+ and workflow.workflow.root == graph
+ and workflow.path == file
+ )
+
+
+def test_collection(tmp_path: Path):
+ file1 = tmp_path / "file1.json"
+ file1_graph = {"0": {"class_type": "F1", "inputs": {}}}
+ file1.write_text(json.dumps(file1_graph))
+
+ file2 = tmp_path / "file2.json"
+ file2_graph = {"0": {"class_type": "F2", "inputs": {}}}
+ file2.write_text(json.dumps(file2_graph))
+
+ connection_graph = {"0": {"class_type": "C1", "inputs": {}}}
+ connection_workflows = {"connection1": connection_graph}
+ connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected)
+
+ collection = WorkflowCollection(connection, tmp_path)
+ events = []
+
+ assert len(collection) == 0
+
+ def on_loaded():
+ events.append("loaded")
+
+ collection.loaded.connect(on_loaded)
+ doc_graph = {"0": {"class_type": "D1", "inputs": {}}}
+ collection.add_from_document("doc1", doc_graph)
+
+ connection.state = ConnectionState.connected
+ assert len(collection) == 4
+ assert events == ["loaded"]
+ _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1)
+ _assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2)
+ _assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph)
+ _assert_has_workflow(collection, "doc1", WorkflowSource.document, doc_graph)
+
+ def on_begin_insert(index, first, last):
+ events.append(("begin_insert", first))
+
+ def on_end_insert():
+ events.append("end_insert")
+
+ def on_data_changed(start, end):
+ events.append(("data_changed", start.row()))
+
+ collection.rowsAboutToBeInserted.connect(on_begin_insert)
+ collection.rowsInserted.connect(on_end_insert)
+ collection.dataChanged.connect(on_data_changed)
+
+ connection2_graph = {"0": {"class_type": "C2", "inputs": {}}}
+ connection_workflows["connection2"] = connection2_graph
+ connection.workflow_published.emit("connection2")
+
+ assert len(collection) == 5
+ _assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph)
+
+ file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}}
+ collection.set_graph(collection.find_index("file1"), file1_graph_changed)
+ _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1)
+ assert events == ["loaded", ("begin_insert", 4), "end_insert", ("data_changed", 1)]
+
+    sorted_workflows = SortedWorkflows(collection)
+    assert sorted_workflows[0].source is WorkflowSource.document
+    assert sorted_workflows[1].source is WorkflowSource.remote
+    assert sorted_workflows[2].source is WorkflowSource.remote
+    assert sorted_workflows[3].name == "file1"
+    assert sorted_workflows[4].name == "file2"
+
+
+def make_dummy_graph(n: int = 42):
+ return {
+ "1": {
+ "class_type": "ETN_Parameter",
+ "inputs": {
+ "name": "param1",
+ "type": "number (integer)",
+ "default": n,
+ "min": 5,
+ "max": 95,
+ },
+ }
+ }
+
+
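+# import_file copies workflows into the collection folder and de-duplicates
+# names ("file1 (1)", ...); remove deletes the file again, and invalid JSON is
+# rejected with a RuntimeError.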
+def test_files(tmp_path: Path):
+ collection_folder = tmp_path / "workflows"
+
+ collection = WorkflowCollection(create_mock_connection({}, {}), collection_folder)
+ assert len(collection) == 0
+
+ file1 = tmp_path / "file1.json"
+ file1.write_text(json.dumps(make_dummy_graph()))
+
+ collection.import_file(file1)
+ assert collection.find("file1") is not None
+
+ collection.import_file(file1)
+ assert collection.find("file1 (1)") is not None
+
+ collection.save_as("file1", make_dummy_graph(77))
+ assert collection.find("file1 (2)") is not None
+
+ files = [
+ collection_folder / "file1.json",
+ collection_folder / "file1 (1).json",
+ collection_folder / "file1 (2).json",
+ ]
+ assert all(f.exists() for f in files)
+
+ collection.remove("file1 (1)")
+ assert collection.find("file1 (1)") is None
+ assert not (collection_folder / "file1 (1).json").exists()
+
+ bad_file = tmp_path / "bad.json"
+ bad_file.write_text("bad json")
+ with pytest.raises(RuntimeError):
+ collection.import_file(bad_file)
+
+
+async def dummy_generate(workflow_input):
+ return None
+
+
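+# The workspace mirrors the active workflow: parameter metadata and default
+# values are derived from its graph, and both are refreshed when the document
+# graph is replaced or edited in place.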
+def test_workspace():
+ connection_workflows = {"connection1": make_dummy_graph(42)}
+ connection = create_mock_connection(connection_workflows, {})
+ workflows = WorkflowCollection(connection)
+
+ jobs = JobQueue()
+ workspace = CustomWorkspace(workflows, dummy_generate, jobs)
+ assert workspace.workflow_id == "connection1"
+ assert workspace.workflow and workspace.workflow.id == "connection1"
+ assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"
+ assert workspace.metadata[0].name == "param1"
+ assert workspace.params == {"param1": 42}
+
+ doc_graph = {
+ "1": {
+ "class_type": "ETN_Parameter",
+ "inputs": {
+ "name": "param2",
+ "type": "number (integer)",
+ "default": 23,
+ "min": 5,
+ "max": 95,
+ },
+ }
+ }
+ workspace.set_graph("doc1", doc_graph)
+ assert workspace.workflow_id == "Document Workflow (embedded)"
+ assert workspace.workflow and workspace.workflow.source is WorkflowSource.document
+ assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"
+ assert workspace.metadata[0].name == "param2"
+ assert workspace.params == {"param2": 23}
+
+ doc_graph["1"]["inputs"]["default"] = 24
+ doc_graph["2"] = {
+ "class_type": "ETN_Parameter",
+ "inputs": {"name": "param3", "type": "number (integer)", "default": 7, "min": 0, "max": 10},
+ }
+ workflows.set_graph(workflows.index(1), doc_graph)
+ assert workspace.metadata[0].default == 24
+ assert workspace.metadata[1].name == "param3"
+ assert workspace.params == {"param2": 23, "param3": 7}
+
+
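+# import_graph assigns sequential integer ids in dependency order and turns
+# ["node_id", output_index] references into Output links.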
+def test_import():
+ graph = {
+ "4": {"class_type": "A", "inputs": {"steps": 4, "float": 1.2, "string": "mouse"}},
+ "zak": {"class_type": "C", "inputs": {"in": ["9", 1]}},
+ "9": {"class_type": "B", "inputs": {"in": ["4", 0]}},
+ }
+ w = ComfyWorkflow.import_graph(graph, {})
+ assert w.node(0) == ComfyNode(0, "A", {"steps": 4, "float": 1.2, "string": "mouse"})
+ assert w.node(1) == ComfyNode(1, "B", {"in": Output(0, 0)})
+ assert w.node(2) == ComfyNode(2, "C", {"in": Output(1, 1)})
+ assert w.node_count == 3
+ assert w.guess_sample_count() == 4
+
+
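+# A workflow exported from the ComfyUI web UI should import to the same graph
+# as its API-format export.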
+def test_import_ui_workflow():
+ graph = json.loads((test_dir / "data" / "workflow-ui.json").read_text())
+ object_info = json.loads((test_dir / "data" / "object_info.json").read_text())
+ node_inputs = {k: v.get("input") for k, v in object_info.items()}
+ result = ComfyWorkflow.import_graph(graph, node_inputs)
+
+ expected_graph = json.loads((test_dir / "data" / "workflow-api.json").read_text())
+ expected = ComfyWorkflow.import_graph(expected_graph, {})
+ assert result.root == expected.root
+
+
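+# workflow_parameters maps ETN_Parameter / ETN_Krita* nodes to CustomParam
+# entries; a choice parameter only gets a choice list when it is connected to
+# an input that declares one in node_inputs, otherwise it falls back to text.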
+def test_parameters():
+ node_inputs = {"ChoiceNode": {"required": {"choice_param": (["a", "b", "c"],)}}}
+
+ w = ComfyWorkflow(node_inputs=node_inputs)
+ w.add("ETN_Parameter", 1, name="int", type="number (integer)", default=4, min=0, max=10)
+ w.add("ETN_Parameter", 1, name="bool", type="toggle", default=True)
+ w.add("ETN_Parameter", 1, name="number", type="number", default=1.2, min=0.0, max=10.0)
+ w.add("ETN_Parameter", 1, name="text", type="text", default="mouse")
+ w.add("ETN_Parameter", 1, name="positive", type="prompt (positive)", default="p")
+ w.add("ETN_Parameter", 1, name="negative", type="prompt (negative)", default="n")
+ w.add("ETN_Parameter", 1, name="choice_unconnected", type="choice", default="z")
+ choice_param = w.add("ETN_Parameter", 1, name="choice", type="choice", default="c")
+ w.add("ChoiceNode", 1, choice_param=choice_param)
+ w.add("ETN_KritaImageLayer", 1, name="image")
+ w.add("ETN_KritaMaskLayer", 1, name="mask")
+ w.add("ETN_KritaStyle", 9, name="style", sampler_preset="live") # type: ignore
+
+ assert list(workflow_parameters(w)) == [
+ CustomParam(ParamKind.number_int, "int", 4, 0, 10),
+ CustomParam(ParamKind.toggle, "bool", True),
+ CustomParam(ParamKind.number_float, "number", 1.2, 0.0, 10.0),
+ CustomParam(ParamKind.text, "text", "mouse"),
+ CustomParam(ParamKind.prompt_positive, "positive", "p"),
+ CustomParam(ParamKind.prompt_negative, "negative", "n"),
+ CustomParam(ParamKind.text, "choice_unconnected", "z"),
+ CustomParam(ParamKind.choice, "choice", "c", choices=["a", "b", "c"]),
+ CustomParam(ParamKind.image_layer, "image"),
+ CustomParam(ParamKind.mask_layer, "mask"),
+ CustomParam(ParamKind.style, "style", "live"),
+ ]
+
+
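+# expand_custom inlines the custom graph into a fresh workflow: the Krita
+# canvas and layer/mask parameters become base64 load nodes, the seed is
+# filled in, and ETN_KritaStyle expands to a checkpoint loader plus prompts
+# and sampler settings taken from the Style.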
+def test_expand():
+ ext = ComfyWorkflow()
+ in_img, width, height, seed = ext.add("ETN_KritaCanvas", 4)
+ scaled = ext.add("ImageScale", 1, image=in_img, width=width, height=height)
+ ext.add("ETN_KritaOutput", 1, images=scaled)
+ inty = ext.add(
+ "ETN_Parameter", 1, name="inty", type="number (integer)", default=4, min=0, max=10
+ )
+ numby = ext.add("ETN_Parameter", 1, name="numby", type="number", default=1.2, min=0.0, max=10.0)
+ texty = ext.add("ETN_Parameter", 1, name="texty", type="text", default="mouse")
+ booly = ext.add("ETN_Parameter", 1, name="booly", type="toggle", default=True)
+ choicy = ext.add("ETN_Parameter", 1, name="choicy", type="choice", default="c")
+ layer_img = ext.add("ETN_KritaImageLayer", 1, name="layer_img")
+ layer_mask = ext.add("ETN_KritaMaskLayer", 1, name="layer_mask")
+ stylie = ext.add("ETN_KritaStyle", 9, name="style", sampler_preset="live") # type: ignore
+ ext.add(
+ "Sink",
+ 1,
+ seed=seed,
+ inty=inty,
+ numby=numby,
+ texty=texty,
+ booly=booly,
+ choicy=choicy,
+ layer_img=layer_img,
+ layer_mask=layer_mask,
+ model=stylie[0],
+ clip=stylie[1],
+ vae=stylie[2],
+ positive=stylie[3],
+ negative=stylie[4],
+ sampler=stylie[5],
+ scheduler=stylie[6],
+ steps=stylie[7],
+ guidance=stylie[8],
+ )
+
+ style = Style(Path("default.json"))
+ style.sd_checkpoint = "checkpoint.safetensors"
+ style.style_prompt = "bee hive"
+ style.negative_prompt = "pigoon"
+ params = {
+ "inty": 7,
+ "numby": 3.4,
+ "texty": "cat",
+ "booly": False,
+ "choicy": "b",
+ "layer_img": Image.create(Extent(4, 4), Qt.GlobalColor.black),
+ "layer_mask": Image.create(Extent(4, 4), Qt.GlobalColor.white),
+ "style": style,
+ }
+
+    custom_input = CustomWorkflowInput(workflow=ext.root, params=params)
+ images = ImageInput.from_extent(Extent(4, 4))
+ images.initial_image = Image.create(Extent(4, 4), Qt.GlobalColor.white)
+
+ models = ClientModels()
+ models.checkpoints = {
+ "checkpoint.safetensors": CheckpointInfo("checkpoint.safetensors", Arch.sd15)
+ }
+
+ w = ComfyWorkflow()
+    w = workflow.expand_custom(w, custom_input, images, 123, models)
+ expected = [
+ ComfyNode(1, "ETN_LoadImageBase64", {"image": images.initial_image.to_base64()}),
+ ComfyNode(2, "ImageScale", {"image": Output(1, 0), "width": 4, "height": 4}),
+ ComfyNode(3, "ETN_KritaOutput", {"images": Output(2, 0)}),
+ ComfyNode(4, "ETN_LoadImageBase64", {"image": params["layer_img"].to_base64()}),
+ ComfyNode(5, "ETN_LoadMaskBase64", {"mask": params["layer_mask"].to_base64()}),
+ ComfyNode(6, "CheckpointLoaderSimple", {"ckpt_name": "checkpoint.safetensors"}),
+ ComfyNode(
+ 7,
+ "Sink",
+ {
+ "seed": 123,
+ "inty": 7,
+ "numby": 3.4,
+ "texty": "cat",
+ "booly": False,
+ "choicy": "b",
+ "layer_img": Output(4, 0),
+ "layer_mask": Output(5, 0),
+ "model": Output(6, 0),
+ "clip": Output(6, 1),
+ "vae": Output(6, 2),
+ "positive": "bee hive",
+ "negative": "pigoon",
+ "sampler": "euler",
+ "scheduler": "sgm_uniform",
+ "steps": 6,
+ "guidance": 1.8,
+ },
+ ),
+ ]
+ for node in expected:
+ assert node in w, f"Node {node} not found in\n{json.dumps(w.root, indent=2)}"
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index eef2c5755b..4820465beb 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -1,17 +1,21 @@
import itertools
import pytest
import dotenv
+import json
import os
from datetime import datetime
from pathlib import Path
from typing import Any
+from PyQt5.QtCore import Qt
from ai_diffusion import workflow
-from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput
-from ai_diffusion.api import InpaintMode, FillMode, ConditioningInput, RegionInput
+from ai_diffusion.api import LoraInput, WorkflowKind, WorkflowInput, ControlInput, RegionInput
+from ai_diffusion.api import InpaintMode, FillMode, ConditioningInput, CustomWorkflowInput
+from ai_diffusion.api import SamplingInput, ImageInput
from ai_diffusion.client import ClientModels, CheckpointInfo
from ai_diffusion.comfy_client import ComfyClient
from ai_diffusion.cloud_client import CloudClient
+from ai_diffusion.comfy_workflow import ComfyWorkflow, ComfyNode, Output
from ai_diffusion.files import FileLibrary, FileCollection, File, FileSource
from ai_diffusion.resources import ControlMode
from ai_diffusion.settings import PerformanceSettings
@@ -538,7 +542,7 @@ async def main():
if not job_id:
job_id = await client.enqueue(job)
if msg.event is ClientEvent.finished and msg.job_id == job_id:
- assert msg.result is not None
+ assert isinstance(msg.result, dict)
result = Pose.from_open_pose_json(msg.result).to_svg()
(result_dir / image_name).write_text(result)
return