2 changes: 2 additions & 0 deletions comfy/cli_args.py
@@ -145,6 +145,8 @@ class LatentPreviewMethod(enum.Enum):
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--use-subprocess-workers", action="store_true", help="Execute each prompt in an isolated subprocess with complete GPU/ROCm context reset. Ensures clean state between jobs but adds startup overhead.")
parser.add_argument("--subprocess-timeout", type=int, default=600, help="Timeout in seconds for subprocess execution (default: 600, only used with --use-subprocess-workers).")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")

class PerformanceFeature(enum.Enum):
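A quick sanity check of the two new flags, as a hedged sketch (it assumes the module-level parser in comfy/cli_args.py stays importable as it is today; the command line below is hypothetical):

    from comfy.cli_args import parser

    # Parse a hypothetical command line exercising the new subprocess-worker flags.
    args = parser.parse_args(["--use-subprocess-workers", "--subprocess-timeout", "900"])
    assert args.use_subprocess_workers is True
    assert args.subprocess_timeout == 900  # overrides the 600 s default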
145 changes: 145 additions & 0 deletions comfy/execution_core.py
@@ -0,0 +1,145 @@
"""Core execution logic shared between normal and subprocess execution modes."""

import logging
import time

_active_worker = None


def create_worker(server_instance):
"""Create worker backend. Returns NativeWorker or SubprocessWorker."""
global _active_worker
from comfy.cli_args import args

server = WorkerServer(server_instance)

if args.use_subprocess_workers:
from comfy.worker_process import SubprocessWorker
worker = SubprocessWorker(server, timeout=args.subprocess_timeout)
else:
from comfy.worker_native import NativeWorker
worker = NativeWorker(server)

_active_worker = worker
return worker
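A sketch of how a server entry point might consume create_worker (start_backend is hypothetical wiring; SubprocessWorker is assumed to expose the same initialize() coroutine that NativeWorker defines in comfy/worker_native.py below):

    import logging
    from comfy.execution_core import create_worker

    async def start_backend(server_instance):
        # Picks NativeWorker or SubprocessWorker based on --use-subprocess-workers.
        worker = create_worker(server_instance)
        node_count = await worker.initialize()  # loads built-in, custom, and API nodes
        logging.info("Loaded %d node types", node_count)
        return worker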


async def init_execution_environment():
"""Load nodes and custom nodes. Returns number of node types loaded."""
import nodes
from comfy.cli_args import args

await nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
init_api_nodes=not args.disable_api_nodes
)
return len(nodes.NODE_CLASS_MAPPINGS)


def setup_progress_hook(server_instance, interrupt_checker):
"""Set up global progress hook. interrupt_checker must raise on interrupt."""
import comfy.utils
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context

def hook(value, total, preview_image, prompt_id=None, node_id=None):
ctx = get_executing_context()
if ctx:
prompt_id = prompt_id or ctx.prompt_id
node_id = node_id or ctx.node_id

interrupt_checker()

prompt_id = prompt_id or server_instance.last_prompt_id
node_id = node_id or server_instance.last_node_id

get_progress_state().update_progress(node_id, value, total, preview_image)
server_instance.send_sync("progress", {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}, server_instance.client_id)

comfy.utils.set_progress_bar_global_hook(hook)
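For example, a caller could install the hook like this (install_hook is illustrative; throw_exception_if_processing_interrupted is the existing default checker that NativeWorker uses below):

    import comfy.model_management as mm
    from comfy.execution_core import setup_progress_hook

    def install_hook(server_instance):
        # The checker must raise to abort the running prompt; returning normally means "keep going".
        setup_progress_hook(server_instance, interrupt_checker=mm.throw_exception_if_processing_interrupted)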


class WorkerServer:
"""Protocol boundary: client_id, last_node_id, last_prompt_id, sockets_metadata, send_sync(), queue_updated()"""

_WRITABLE = {'client_id', 'last_node_id', 'last_prompt_id'}

def __init__(self, server):
object.__setattr__(self, '_server', server)

def __setattr__(self, name, value):
if name in self._WRITABLE:
setattr(self._server, name, value)
else:
raise AttributeError(f"WorkerServer does not accept attribute '{name}'")

@property
def client_id(self):
return self._server.client_id

@property
def last_node_id(self):
return self._server.last_node_id

@property
def last_prompt_id(self):
return self._server.last_prompt_id

@property
def sockets_metadata(self):
return self._server.sockets_metadata

def send_sync(self, event, data, sid=None):
self._server.send_sync(event, data, sid or self.client_id)

def queue_updated(self):
self._server.queue_updated()
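An illustration of the boundary the whitelist enforces (real_server stands in for the actual PromptServer instance and is assumed here):

    from comfy.execution_core import WorkerServer

    ws = WorkerServer(real_server)
    ws.last_prompt_id = "abc123"          # forwarded: the name is in _WRITABLE
    ws.send_sync("status", {"ok": True})  # sid defaults to the current client_id
    try:
        ws.loop = None                    # any attribute outside the whitelist is rejected
    except AttributeError:
        pass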

def interrupt_processing(value=True):
"""Ask the active worker backend to interrupt the currently running prompt."""
_active_worker.interrupt(value)


def _strip_sensitive(prompt):
# Drop index 5 of the queue item tuple (the sensitive portion) before the item
# is recorded in history via q.task_done(..., process_item=_strip_sensitive).
return prompt[:5] + prompt[6:]


def prompt_worker(q, worker):
"""Main prompt execution loop."""
import execution

server = worker.server_instance

while True:
queue_item = q.get(timeout=worker.get_gc_timeout())
if queue_item is not None:
item, item_id = queue_item
start_time = time.perf_counter()
prompt_id = item[1]
server.last_prompt_id = prompt_id

extra_data = {**item[3], **item[5]}

result = worker.execute_prompt(item[2], prompt_id, extra_data, item[4], server=server)
worker.mark_needs_gc()

q.task_done(
item_id,
result['history_result'],
status=execution.PromptQueue.ExecutionStatus(
status_str='success' if result['success'] else 'error',
completed=result['success'],
messages=result['status_messages']
),
process_item=_strip_sensitive
)

if server.client_id is not None:
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)

elapsed = time.perf_counter() - start_time
if elapsed > 600:
logging.info(f"Prompt executed in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
else:
logging.info(f"Prompt executed in {elapsed:.2f} seconds")

worker.handle_flags(q.get_flags())
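In a full server this loop would typically run on a background thread, roughly as follows (run_worker_thread is hypothetical wiring; PromptQueue comes from the existing execution module, and worker.initialize() is assumed to have been awaited before prompts are queued):

    import threading
    import execution
    from comfy.execution_core import create_worker, prompt_worker

    def run_worker_thread(server_instance):
        q = execution.PromptQueue(server_instance)
        worker = create_worker(server_instance)
        t = threading.Thread(target=prompt_worker, args=(q, worker), daemon=True)
        t.start()
        return q, worker, t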
95 changes: 95 additions & 0 deletions comfy/worker_native.py
@@ -0,0 +1,95 @@
"""Native (in-process) worker for prompt execution."""

import time
import gc


class NativeWorker:
"""Executes prompts in the same process as the server."""

def __init__(self, server_instance, interrupt_checker=None):
self.server_instance = server_instance
self.interrupt_checker = interrupt_checker
self.executor = None
self.last_gc_collect = 0
self.need_gc = False
self.gc_collect_interval = 10.0

async def initialize(self):
"""Load nodes and set up executor. Returns node count."""
from execution import PromptExecutor, CacheType
from comfy.cli_args import args
from comfy.execution_core import init_execution_environment, setup_progress_hook
import comfy.model_management as mm
import hook_breaker_ac10a0

hook_breaker_ac10a0.save_functions()
try:
node_count = await init_execution_environment()
finally:
hook_breaker_ac10a0.restore_functions()

interrupt_checker = self.interrupt_checker or mm.throw_exception_if_processing_interrupted
setup_progress_hook(self.server_instance, interrupt_checker=interrupt_checker)

cache_type = CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = CacheType.LRU
elif args.cache_ram > 0:
cache_type = CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = CacheType.NONE

self.executor = PromptExecutor(
self.server_instance,
cache_type=cache_type,
cache_args={"lru": args.cache_lru, "ram": args.cache_ram}
)
return node_count

def execute_prompt(self, prompt, prompt_id, extra_data, execute_outputs, server=None):
self.executor.execute(prompt, prompt_id, extra_data, execute_outputs)
return {
'success': self.executor.success,
'history_result': self.executor.history_result,
'status_messages': self.executor.status_messages,
'prompt_id': prompt_id
}

def handle_flags(self, flags):
import comfy.model_management as mm
import hook_breaker_ac10a0

free_memory = flags.get("free_memory", False)

if flags.get("unload_models", free_memory):
mm.unload_all_models()
self.need_gc = True
self.last_gc_collect = 0

if free_memory:
if self.executor:
self.executor.reset()
self.need_gc = True
self.last_gc_collect = 0

if self.need_gc:
current_time = time.perf_counter()
if (current_time - self.last_gc_collect) > self.gc_collect_interval:
gc.collect()
mm.soft_empty_cache()
self.last_gc_collect = current_time
self.need_gc = False
hook_breaker_ac10a0.restore_functions()

def interrupt(self, value=True):
import comfy.model_management
comfy.model_management.interrupt_current_processing(value)

def mark_needs_gc(self):
self.need_gc = True

def get_gc_timeout(self):
if self.need_gc:
return max(self.gc_collect_interval - (time.perf_counter() - self.last_gc_collect), 0.0)
return 1000.0
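The timeout plumbing pairs with prompt_worker in comfy/execution_core.py: while GC is pending, q.get() wakes within gc_collect_interval, and the collection then happens inside handle_flags() on an idle pass. A sketch of that lifecycle (server_instance is assumed):

    from comfy.worker_native import NativeWorker

    worker = NativeWorker(server_instance)
    worker.mark_needs_gc()
    assert worker.get_gc_timeout() <= 10.0    # wait is shortened while GC is pending
    worker.handle_flags({})                   # idle pass: runs gc.collect() and soft_empty_cache()
    assert worker.get_gc_timeout() == 1000.0  # back to the long idle timeout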