Skip to content

Commit

Permalink
Enable External Event Loop Integration for ComfyUI [refactor] (#6114)
Browse files Browse the repository at this point in the history
* Refactor main.py to support external event loop integration

* added optional "asyncio_loop" argument to allow using existing event loop

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
  • Loading branch information
bigcat88 authored Dec 24, 2024
1 parent bc6dac4 commit 26e0ba8
Showing 1 changed file with 39 additions and 21 deletions.
60 changes: 39 additions & 21 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def cuda_malloc_warning():
if cuda_malloc_warning:
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

def prompt_worker(q, server):

def prompt_worker(q, server_instance):
current_time: float = 0.0
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
Expand All @@ -167,7 +168,7 @@ def prompt_worker(q, server):
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
server.last_prompt_id = prompt_id
server_instance.last_prompt_id = prompt_id

e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
Expand All @@ -177,8 +178,8 @@ def prompt_worker(q, server):
status_str='success' if e.success else 'error',
completed=e.success,
messages=e.status_messages))
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

current_time = time.perf_counter()
execution_time = current_time - execution_start_time
Expand All @@ -205,21 +206,23 @@ def prompt_worker(q, server):
last_gc_collect = current_time
need_gc = False

async def run(server, address='', port=8188, verbose=True, call_on_start=None):

async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())


def hijack_progress(server):
def hijack_progress(server_instance):
def hook(value, total, preview_image):
comfy.model_management.throw_exception_if_processing_interrupted()
progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}

server.send_sync("progress", progress, server.client_id)
server_instance.send_sync("progress", progress, server_instance.client_id)
if preview_image is not None:
server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)

comfy.utils.set_progress_bar_global_hook(hook)


Expand All @@ -229,7 +232,11 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)


if __name__ == "__main__":
def start_comfyui(asyncio_loop=None):
"""
Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
Returns the event loop, server instance, and a function to start the server asynchronously.
"""
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.info(f"Setting temp directory to: {temp_dir}")
Expand All @@ -243,19 +250,20 @@ def cleanup_temp():
except:
pass

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
if not asyncio_loop:
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)

nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

cuda_malloc_warning()

server.add_routes()
hijack_progress(server)
prompt_server.add_routes()
hijack_progress(prompt_server)

threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()

if args.quick_test_for_ci:
exit(0)
Expand All @@ -272,9 +280,19 @@ def startup_server(scheme, address, port):
webbrowser.open(f"{scheme}://{address}:{port}")
call_on_start = startup_server

async def start_all():
await prompt_server.setup()
await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)

# Returning these so that other code can integrate with the ComfyUI loop and server
return asyncio_loop, prompt_server, start_all


if __name__ == "__main__":
# Running directly, just start ComfyUI.
event_loop, _, start_all_func = start_comfyui()
try:
loop.run_until_complete(server.setup())
loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
event_loop.run_until_complete(start_all_func())
except KeyboardInterrupt:
logging.info("\nStopped server")

Expand Down

0 comments on commit 26e0ba8

Please sign in to comment.