Skip to content

Commit

Permalink
Merge branch 'custom-workflow'
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 13, 2024
2 parents fb03e15 + 089dc66 commit 5615255
Show file tree
Hide file tree
Showing 41 changed files with 4,304 additions and 217 deletions.
8 changes: 8 additions & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class WorkflowKind(Enum):
upscale_simple = 4
upscale_tiled = 5
control_image = 6
custom = 7


@dataclass
Expand Down Expand Up @@ -144,6 +145,12 @@ def clamped(self):
return params


@dataclass
class CustomWorkflowInput:
workflow: dict
params: dict[str, Any]


@dataclass
class WorkflowInput:
kind: WorkflowKind
Expand All @@ -157,6 +164,7 @@ class WorkflowInput:
control_mode: ControlMode = ControlMode.reference
batch_count: int = 1
nsfw_filter: float = 0.0
custom_workflow: CustomWorkflowInput | None = None

@property
def extent(self):
Expand Down
8 changes: 7 additions & 1 deletion ai_diffusion/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ class ClientEvent(Enum):
disconnected = 5
queued = 6
upload = 7
published = 8


class ClientMessage(NamedTuple):
event: ClientEvent
job_id: str = ""
progress: float = 0
images: ImageCollection | None = None
result: dict | None = None
result: dict | SharedWorkflow | None = None
error: str | None = None


Expand Down Expand Up @@ -68,6 +69,11 @@ def parse(data: dict):
return DeviceInfo("cpu", "unknown", 0)


class SharedWorkflow(NamedTuple):
publisher: str
workflow: dict


class CheckpointInfo(NamedTuple):
filename: str
arch: Arch
Expand Down
90 changes: 54 additions & 36 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .api import WorkflowInput
from .client import Client, CheckpointInfo, ClientMessage, ClientEvent, DeviceInfo, ClientModels
from .client import TranslationPackage, filter_supported_styles, loras_to_upload
from .client import SharedWorkflow, TranslationPackage, filter_supported_styles, loras_to_upload
from .files import FileFormat
from .image import Image, ImageCollection
from .network import RequestManager, NetworkError
Expand Down Expand Up @@ -87,16 +87,18 @@ class ComfyClient(Client):

default_url = "http://127.0.0.1:8188"

_requests = RequestManager()
_id: str
_messages: asyncio.Queue[ClientMessage]
_queue: asyncio.Queue[JobInfo]
_jobs: deque[JobInfo]
_active: Optional[JobInfo] = None
_job_runner: asyncio.Task
_websocket_listener: asyncio.Task
_supported_archs: list[Arch]
_supported_languages: list[TranslationPackage]
def __init__(self, url):
self.url = url
self.models = ClientModels()
self._requests = RequestManager()
self._id = str(uuid.uuid4())
self._active: Optional[JobInfo] = None
self._supported_archs: list[Arch] = []
self._supported_languages: list[TranslationPackage] = []
self._messages = asyncio.Queue()
self._queue = asyncio.Queue()
self._jobs = deque()
self._is_connected = False

@staticmethod
async def connect(url=default_url, access_token=""):
Expand Down Expand Up @@ -124,7 +126,7 @@ async def connect(url=default_url, access_token=""):

# Check for required and optional model resources
models = client.models
models.node_inputs = {name: nodes[name]["input"].get("required", None) for name in nodes}
models.node_inputs = {name: nodes[name]["input"] for name in nodes}
available_resources = client.models.resources = {}

clip_models = nodes["DualCLIPLoader"]["input"]["required"]["clip_name1"][0]
Expand Down Expand Up @@ -175,15 +177,6 @@ async def connect(url=default_url, access_token=""):
_ensure_supported_style(client)
return client

def __init__(self, url):
self.url = url
self.models = ClientModels()
self._id = str(uuid.uuid4())
self._messages = asyncio.Queue()
self._queue = asyncio.Queue()
self._jobs = deque()
self._is_connected = False

async def _get(self, op: str):
return await self._requests.get(f"{self.url}/{op}")

Expand Down Expand Up @@ -237,6 +230,7 @@ async def _listen(self):
f"{url}/ws?clientId={self._id}", max_size=2**30, read_limit=2**30, ping_timeout=60
):
try:
await self._subscribe_workflows()
await self._listen_websocket(websocket)
except websockets_exceptions.ConnectionClosedError as e:
log.warning(f"Websocket connection closed: {str(e)}")
Expand Down Expand Up @@ -287,12 +281,20 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr

if msg["type"] == "executing" and msg["data"]["node"] is None:
job_id = msg["data"]["prompt_id"]
if self._clear_job(job_id):
# Usually we don't get here because finished, interrupted or error is sent first.
# But it may happen if the entire execution is cached and no images are sent.
if local_id := self._clear_job(job_id):
if len(images) == 0:
# It may happen if the entire execution is cached and no images are sent.
images = last_images
await self._report(ClientEvent.finished, job_id, 1, images=images)
if len(images) == 0:
# Still no images. Potential scenario: execution cached, but previous
# generation happened before the client was connected.
err = "No new images were generated because the inputs did not change."
await self._report(ClientEvent.error, local_id, error=err)
else:
last_images = images
await self._report(
ClientEvent.finished, local_id, 1, images=images, result=result
)

elif msg["type"] in ("execution_cached", "executing", "progress"):
if self._active is not None and progress is not None:
Expand All @@ -308,12 +310,6 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
pose_json = _extract_pose_json(msg)
if job and pose_json:
result = pose_json
elif job and _validate_executed_node(msg, len(images)):
self._clear_job(job.remote_id)
last_images = images
await self._report(
ClientEvent.finished, job.local_id, 1, images=images, result=result
)

if msg["type"] == "execution_error":
job = self._get_active_job(msg["data"]["prompt_id"])
Expand All @@ -322,7 +318,12 @@ async def _listen_websocket(self, websocket: websockets_client.WebSocketClientPr
traceback = msg["data"].get("traceback", "no traceback")
log.error(f"Job {job} failed: {error}\n{traceback}")
self._clear_job(job.remote_id)
await self._report(ClientEvent.error, job.local_id, 0, error=error)
await self._report(ClientEvent.error, job.local_id, error=error)

if msg["type"] == "etn_workflow_published":
name = f"{msg['data']['publisher']['name']} ({msg['data']['publisher']['id']})"
workflow = SharedWorkflow(name, msg["data"]["workflow"])
await self._report(ClientEvent.published, "", result=workflow)

async def listen(self):
self._is_connected = True
Expand Down Expand Up @@ -358,6 +359,7 @@ async def disconnect(self):
self._job_runner,
self._websocket_listener,
self._report(ClientEvent.disconnected, ""),
self._unsubscribe_workflows(),
)

async def try_inspect(self, folder_name: str):
Expand Down Expand Up @@ -431,6 +433,18 @@ async def translate(self, text: str, lang: str):
log.error(f"Could not translate text: {str(e)}")
return text

async def _subscribe_workflows(self):
try:
await self._post("api/etn/workflow/subscribe", {"client_id": self._id})
except Exception as e:
log.error(f"Couldn't subscribe to shared workflows: {str(e)}")

async def _unsubscribe_workflows(self):
try:
await self._post("api/etn/workflow/unsubscribe", {"client_id": self._id})
except Exception as e:
log.error(f"Couldn't unsubscribe from shared workflows: {str(e)}")

def supports_arch(self, arch: Arch):
return arch in self._supported_archs

Expand Down Expand Up @@ -501,9 +515,10 @@ async def _start_job(self, remote_id: str):

def _clear_job(self, job_remote_id: str | asyncio.Future | None):
if self._active is not None and self._active.remote_id == job_remote_id:
result = self._active.local_id
self._active = None
return True
return False
return result
return None

def _check_workload(self, sdver: Arch) -> list[MissingResource]:
models = self.models
Expand Down Expand Up @@ -719,8 +734,11 @@ def _validate_executed_node(msg: dict, image_count: int):
images = output["images"]
if len(images) != image_count: # not critical
log.warning(f"Received number of images does not match: {len(images)} != {image_count}")
if len(images) > 0 and "source" in images[0] and images[0]["type"] == "output":
if image_count == 0 or len(images) == 0:
log.warning(f"Received no images (execution cached?)")
return False
if "source" in images[0] and images[0]["type"] == "output":
return True
except Exception as e:
log.warning(f"Error processing message, error={str(e)}, msg={msg}")
return False
return False
Loading

0 comments on commit 5615255

Please sign in to comment.