Skip to content

Commit 0377f37

Browse files
committed
Wait cool down
1 parent b2e62c2 commit 0377f37

File tree

3 files changed

+127
-0
lines changed

3 files changed

+127
-0
lines changed

comfy/cli_args.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class LatentPreviewMethod(enum.Enum):
9393

9494
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
9595

96+
parser.add_argument("--max-temperature", type=int, default=0, help="Don't execute a node if the temperature is above it, but wait cool down to the safe temperature.")
97+
parser.add_argument("--safe-temperature", type=int, default=0, help="Safe temperature to wait cool down before executin a node.")
98+
parser.add_argument("--safe-progress-temperature", type=int, default=0, help="Safe temperature to wait cool down between progress.")
99+
parser.add_argument("--max-cool-down-seconds", type=int, default=0, help="Max seconds to wait the temperature cool down.")
100+
parser.add_argument("--each-cool-down-seconds", type=int, default=5, help="Seconds to wait the temperature cool down before each measurement.")
96101

97102
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
98103
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")

comfy/utils.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
import os
99
import re
1010
from PIL import Image
11+
import time
12+
import psutil
13+
import GPUtil
14+
import platform
15+
import subprocess
16+
from comfy.cli_args import args
1117

1218
def get_extension_calling():
1319
for frame in inspect.stack():
@@ -454,6 +460,119 @@ def update_absolute(self, value, total=None, preview=None):
454460
self.current = value
455461
if self.hook is not None:
456462
self.hook(self.current, self.total, preview)
463+
wait_cooldown(kind="progress")
457464

458465
def update(self, value):
459466
self.update_absolute(self.current + value)
467+
468+
def clear_line(n=1):
469+
LINE_UP = '\033[1A'
470+
LINE_CLEAR = '\x1b[2K'
471+
for i in range(n):
472+
print(LINE_UP, end=LINE_CLEAR)
473+
474+
def func_sleep(seconds, pbar=None):
475+
while seconds > 0:
476+
print(f"Sleeping {seconds} seconds")
477+
time.sleep(1)
478+
seconds -= 1
479+
clear_line()
480+
if pbar is not None:
481+
pbar.update(1)
482+
483+
def get_processor_name():
484+
if platform.system() == "Windows":
485+
return platform.processor()
486+
elif platform.system() == "Darwin":
487+
os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
488+
command ="sysctl -n machdep.cpu.brand_string"
489+
return subprocess.check_output(command).strip()
490+
elif platform.system() == "Linux":
491+
command = "cat /proc/cpuinfo"
492+
all_info = subprocess.check_output(command, shell=True).decode().strip()
493+
for line in all_info.split("\n"):
494+
if "model name" in line:
495+
return re.sub(".*model name.*:", "", line, 1).strip()
496+
return ""
497+
498+
def get_temperatures():
499+
temperatures = []
500+
501+
if platform.system() == "Linux":
502+
cpu_max_temp = 0
503+
504+
for k, v in psutil.sensors_temperatures(fahrenheit=False).items():
505+
for t in v:
506+
if t.current > cpu_max_temp:
507+
cpu_max_temp = t.current
508+
509+
temperatures.append({
510+
"label": get_processor_name(),
511+
"temperature": cpu_max_temp,
512+
"kind": "CPU",
513+
})
514+
515+
for gpu in GPUtil.getGPUs():
516+
temperatures.append({
517+
"label": gpu.name,
518+
"temperature": gpu.temperature,
519+
"kind": "GPU",
520+
})
521+
522+
return temperatures
523+
524+
waiting_cooldown = False
525+
526+
def _wait_cooldown(max_temperature=70, safe_temperature=60, seconds=2, max_seconds=0):
527+
global waiting_cooldown
528+
529+
if waiting_cooldown:
530+
return
531+
532+
waiting_cooldown = True
533+
534+
try:
535+
max_temperature, safe_temperature = max(max_temperature, safe_temperature), min(max_temperature, safe_temperature)
536+
537+
if max_temperature <= 0:
538+
return
539+
540+
if safe_temperature <= 0:
541+
safe_temperature = max_temperature
542+
543+
if max_seconds == 0:
544+
max_seconds = 0xffffffffffffffff
545+
546+
seconds = max(1, seconds)
547+
max_seconds = max(seconds, max_seconds)
548+
times = max_seconds // seconds
549+
550+
hot = True
551+
552+
# Start with the max temperature, so if not above it don't cool down.
553+
limit_temperature = max_temperature
554+
555+
while hot and times > 0:
556+
temperatures = [f"{t['kind']} {t['label']}: {t['temperature']}" for t in get_temperatures() if t["temperature"] > limit_temperature]
557+
hot = len(temperatures) > 0
558+
559+
if hot:
560+
# Switch to safe temperature to cool down to that temperature
561+
limit_temperature = safe_temperature
562+
print(f"Too hot! Limit temperature: [ {limit_temperature} ] Current temperature: [ " + " | ".join(temperatures) + " ]")
563+
pbar = ProgressBar(seconds)
564+
func_sleep(seconds, pbar)
565+
clear_line()
566+
times -= 1
567+
finally:
568+
waiting_cooldown = False
569+
570+
def wait_cooldown(kind="execution"):
571+
safe_temperature = args.safe_progress_temperature if kind == "progress" else args.safe_temperature
572+
if safe_temperature > 0:
573+
_wait_cooldown(
574+
max_temperature=args.max_temperature,
575+
safe_temperature=safe_temperature,
576+
seconds=args.each_cool_down_seconds,
577+
max_seconds=args.max_cool_down_seconds,
578+
)

execution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import nodes
1414

1515
import comfy.model_management
16+
from comfy.utils import wait_cooldown
1617

1718
def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
1819
valid_inputs = class_def.INPUT_TYPES()
@@ -165,6 +166,8 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
165166
server.last_node_id = unique_id
166167
server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id)
167168

169+
wait_cooldown(kind="execution")
170+
168171
obj = object_storage.get((unique_id, class_type), None)
169172
if obj is None:
170173
obj = class_def()

0 commit comments

Comments
 (0)