fix: support changes to 'prompt' caching mechanism
In ComfyUI internals, `prompt` actually refers to an entire set of nodes and settings. Previously, we were hijacking the `recursive_output_delete_if_changed` call, but a recent ComfyUI change has moved this behavior into the `IsChangedCache` class. The previous hook existed mainly to detect caching bugs, so I've implemented a similar hijack for `IsChangedCache`.
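For illustration, a minimal sketch of the method-hijack pattern this commit applies (names mirror the diff below; `execution` is ComfyUI's module, and the exact signature of `IsChangedCache.get` may vary between ComfyUI versions):

import execution  # ComfyUI's execution module

# Keep a reference to the original method, then install a wrapper in its place.
_original_is_changed_cache_get = execution.IsChangedCache.get

def _is_changed_cache_get_hijack(self, *args, **kwargs):
    result = _original_is_changed_cache_get(self, *args, **kwargs)
    # Inspection/logging of caching behavior goes here.
    return result

execution.IsChangedCache.get = _is_changed_cache_get_hijack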
tazlin committed Aug 23, 2024
1 parent 987aabc commit 67f44c2
Showing 2 changed files with 47 additions and 50 deletions.
84 changes: 37 additions & 47 deletions hordelib/comfy_horde.py
@@ -84,7 +84,7 @@
_comfy_cleanup_models: types.FunctionType
_comfy_soft_empty_cache: types.FunctionType

_comfy_recursive_output_delete_if_changed: types.FunctionType
_comfy_is_changed_cache_get: types.FunctionType

_canny: types.ModuleType
_hed: types.ModuleType
@@ -138,7 +138,6 @@ def do_comfy_import(
global _comfy_current_loaded_models
global _comfy_load_models_gpu
global _comfy_nodes, _comfy_PromptExecutor, _comfy_validate_prompt
global _comfy_recursive_output_delete_if_changed
global _comfy_folder_names_and_paths, _comfy_supported_pt_extensions
global _comfy_load_checkpoint_guess_config
global _comfy_get_torch_device, _comfy_get_free_memory, _comfy_get_total_memory
@@ -169,10 +168,15 @@
from execution import nodes as _comfy_nodes
from execution import PromptExecutor as _comfy_PromptExecutor
from execution import validate_prompt as _comfy_validate_prompt
from execution import recursive_output_delete_if_changed

_comfy_recursive_output_delete_if_changed = recursive_output_delete_if_changed # type: ignore
execution.recursive_output_delete_if_changed = recursive_output_delete_if_changed_hijack
# from execution import recursive_output_delete_if_changed
from execution import IsChangedCache

global _comfy_is_changed_cache_get
_comfy_is_changed_cache_get = IsChangedCache.get # type: ignore

IsChangedCache.get = IsChangedCache_get_hijack # type: ignore

from folder_paths import folder_names_and_paths as _comfy_folder_names_and_paths # type: ignore
from folder_paths import supported_pt_extensions as _comfy_supported_pt_extensions # type: ignore
from comfy.sd import load_checkpoint_guess_config as _comfy_load_checkpoint_guess_config
@@ -197,22 +201,8 @@ def do_comfy_import(
uniformer as _uniformer,
)

import comfy.model_management

# comfy.model_management.vram_state = comfy.model_management.VRAMState.HIGH_VRAM
# comfy.model_management.set_vram_to = comfy.model_management.VRAMState.HIGH_VRAM

logger.info("Comfy_Horde initialised")

# def always_cpu(parameters, dtype):
# return torch.device("cpu")

# comfy.model_management.unet_inital_load_device = always_cpu
# comfy.model_management.DISABLE_SMART_MEMORY = True
# comfy.model_management.lowvram_available = True

# comfy.model_management.unet_offload_device = _unet_offload_device_hijack

log_free_ram()
output_collector.replay()

@@ -221,39 +211,39 @@

_last_pipeline_settings_hash = ""

import PIL.Image


def default_json_serializer_pil_image(obj):
if isinstance(obj, PIL.Image.Image):
return str(hash(obj.__str__()))
return obj


def IsChangedCache_get_hijack(self, *args, **kwargs):
global _comfy_is_changed_cache_get
result = _comfy_is_changed_cache_get(self, *args, **kwargs)

def recursive_output_delete_if_changed_hijack(prompt: dict, old_prompt, outputs, current_item):
global _last_pipeline_settings_hash
if current_item == "prompt":
try:
pipeline_settings_hash = hashlib.md5(json.dumps(prompt).encode("utf-8")).hexdigest()
logger.debug(f"pipeline_settings_hash: {pipeline_settings_hash}")

if pipeline_settings_hash != _last_pipeline_settings_hash:
_last_pipeline_settings_hash = pipeline_settings_hash
logger.debug("Pipeline settings changed")

if old_prompt:
old_pipeline_settings_hash = hashlib.md5(json.dumps(old_prompt).encode("utf-8")).hexdigest()
logger.debug(f"old_pipeline_settings_hash: {old_pipeline_settings_hash}")
if pipeline_settings_hash != old_pipeline_settings_hash:
logger.debug("Pipeline settings changed from old_prompt")
except TypeError:
logger.debug("could not print hash due to source image in payload")
if current_item == "prompt" or current_item == "negative_prompt":
try:
prompt_text = prompt[current_item]["inputs"]["text"]
prompt_hash = hashlib.md5(prompt_text.encode("utf-8")).hexdigest()
logger.debug(f"{current_item} hash: {prompt_hash}")
except KeyError:
pass

global _comfy_recursive_output_delete_if_changed
return _comfy_recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item)
prompt = self.dynprompt.original_prompt

pipeline_settings_hash = hashlib.md5(
json.dumps(prompt, default=default_json_serializer_pil_image).encode(),
).hexdigest()

if pipeline_settings_hash != _last_pipeline_settings_hash:
_last_pipeline_settings_hash = pipeline_settings_hash
logger.debug(f"Pipeline settings changed: {pipeline_settings_hash}")
logger.debug(f"Cache length: {len(self.outputs_cache.cache)}")
logger.debug(f"Subcache length: {len(self.outputs_cache.subcaches)}")

logger.debug(f"IsChangedCache.dynprompt.all_node_ids: {self.dynprompt.all_node_ids()}")

if result:
logger.debug(f"IsChangedCache.get: {result}")

# def cleanup():
# _comfy_soft_empty_cache()
return result


def unload_all_models_vram():
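As an aside, a small self-contained sketch of why the `default` serializer above is needed: `json.dumps` raises `TypeError` on a PIL image (which is why the old hook caught `TypeError`), so a prompt with an embedded source image can only be hashed once images are reduced to a string. This assumes only that Pillow is installed; the payload shape is illustrative.

import hashlib
import json

import PIL.Image

def default_json_serializer_pil_image(obj):
    # json.dumps cannot serialize PIL images; fall back to a string hash.
    if isinstance(obj, PIL.Image.Image):
        return str(hash(obj.__str__()))
    return obj

# A prompt payload with an embedded source image, as in img2img pipelines.
prompt = {"seed": 1234, "source_image": PIL.Image.new("RGB", (8, 8))}

digest = hashlib.md5(
    json.dumps(prompt, default=default_json_serializer_pil_image).encode(),
).hexdigest()
print(digest)  # hashes cleanly instead of raising TypeError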
13 changes: 10 additions & 3 deletions hordelib/nodes/node_model_loader.py
@@ -41,8 +41,8 @@ def load_checkpoint(
horde_model_name: str,
ckpt_name: str | None = None,
file_type: str | None = None,
output_vae=True,
output_clip=True,
output_vae=True, # this arg is required by comfyui internals
output_clip=True, # this arg is required by comfyui internals
preloading=False,
):
log_free_ram()
@@ -115,8 +115,8 @@ def load_checkpoint(

if ckpt_name is not None and Path(ckpt_name).is_absolute():
ckpt_path = ckpt_name
elif ckpt_name is not None:
full_path = folder_paths.get_full_path("checkpoints", ckpt_name)

if full_path is None:
raise ValueError(f"Checkpoint {ckpt_name} not found.")

ckpt_path = full_path
else:
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
raise ValueError("No checkpoint name provided.")

with torch.no_grad():
result = comfy.sd.load_checkpoint_guess_config(
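For clarity, the new checkpoint-resolution logic in isolation, as a sketch: `folder_paths` is ComfyUI's model-path registry, and its `get_full_path` returns `None` for unknown names (which is what the added check guards against). The helper name is hypothetical, not part of the commit.

from pathlib import Path

import folder_paths  # ComfyUI's model-path registry

def resolve_checkpoint_path(ckpt_name: str | None) -> str:
    # Absolute paths pass through untouched.
    if ckpt_name is not None and Path(ckpt_name).is_absolute():
        return ckpt_name
    # Bare names are resolved against the configured checkpoint folders.
    if ckpt_name is not None:
        full_path = folder_paths.get_full_path("checkpoints", ckpt_name)
        if full_path is None:
            raise ValueError(f"Checkpoint {ckpt_name} not found.")
        return full_path
    # Failing loudly beats the old behavior of resolving against None.
    raise ValueError("No checkpoint name provided.")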
