diff --git a/README.md b/README.md index f64906f..7d0c42f 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,16 @@ This is an Extension for [ComfyUI](https://github.com/comfyanonymous/ComfyUI), w > **ie.** This is not just a postprocessing filter ## How to Use -- Attach the **ReSharpen** node between `Empty Latent` and `KSampler` nodes +- Attach the **ReSharpen** node between `Empty Latent` and your `Sampler` node of choice - Adjust the **details** slider: - **Positive** values cause the images to be noisy - **Negative** values cause the images to be blurry -> Don't use values too close to `1` or `-1`, as it will become distorted +> Values too large or small may cause the result to become distorted! ### Important: - `Ancestral` samplers *(**eg.** `Euler a`)* do **not** work. -- The **enable** is "global." If you want to disable it during later part of the workflow *(**eg.** during `Hires. Fix`)*, you have to add another **ReSharpen** node and set it to disable. +- The effect is "global," meaning if you want to disable it during other parts of the workflow *(**eg.** during `Hires. Fix`)*, you need to add another **ReSharpen** node and set the `details` to `0` again. ## Examples diff --git a/__init__.py b/__init__.py index 090fbd2..fe98827 100644 --- a/__init__.py +++ b/__init__.py @@ -1,28 +1,26 @@ -from .resharpen import ReSharpen, disable +from .resharpen import ReSharpen, disable_resharpen +from functools import wraps +from typing import Callable import execution NODE_CLASS_MAPPINGS = {"Resharpen": ReSharpen} - NODE_DISPLAY_NAME_MAPPINGS = {"Resharpen": "ReSharpen"} def find_node(prompt: dict) -> bool: """Find any ReSharpen Node""" - for k, v in prompt.items(): - if v["class_type"] == "Resharpen": - return True - - return False + return any(v.get("class_type") == "Resharpen" for v in prompt.values()) -original_validate = execution.validate_prompt +original_validate: Callable = execution.validate_prompt -def hijack_validate(prompt): +@wraps(original_validate) +def hijack_validate(prompt: dict) -> Callable: if not find_node(prompt): - disable() + disable_resharpen() return original_validate(prompt) diff --git a/pyproject.toml b/pyproject.toml index f9c4933..cd055bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-resharpen" description = "Manipulate the details of generations" -version = "1.0.1" +version = "1.1.0" license = { text = "MIT License" } dependencies = [] diff --git a/resharpen.py b/resharpen.py index 212768c..bbd0537 100644 --- a/resharpen.py +++ b/resharpen.py @@ -1,46 +1,51 @@ -import comfy +from functools import wraps +from typing import Callable +import latent_preview +import torch -isEnabled = False -traj_cache = None -strength = 0.0 -ORIGINAL_SAMPLE = comfy.sample.sample -ORIGINAL_SAMPLE_CUSTOM = comfy.sample.sample_custom +ORIGINAL_PREP: Callable = latent_preview.prepare_callback +RESHARPEN_STRENGTH: float = 0.0 +LATENT_CACHE: torch.Tensor = None -def disable(): - global isEnabled - isEnabled = False +def disable_resharpen(): + """Reset the ReSharpen Strength""" + global RESHARPEN_STRENGTH + RESHARPEN_STRENGTH = 0.0 -def hijack(SAMPLE): - def sample_center(*args, **kwargs): - original_callback = kwargs["callback"] +def hijack(PREP) -> Callable: - def hijack_callback(step, x0, x, total_steps): - global isEnabled - global traj_cache - global strength + @wraps(PREP) + def prep_callback(*args, **kwargs): + original_callback: Callable = PREP(*args, **kwargs) + if not RESHARPEN_STRENGTH: + return original_callback + + print("[ReSharpen] Enabled~") - if not isEnabled: + @torch.inference_mode() + @wraps(original_callback) + def hijack_callback(step, x0, x, total_steps): + if not RESHARPEN_STRENGTH: return original_callback(step, x0, x, total_steps) - if traj_cache is not None: - delta = x.detach().clone() - traj_cache - x += delta * strength + global LATENT_CACHE + if LATENT_CACHE is not None: + delta = x.detach().clone() - LATENT_CACHE + x += delta * RESHARPEN_STRENGTH - traj_cache = x.detach().clone() + LATENT_CACHE = x.detach().clone() return original_callback(step, x0, x, total_steps) - kwargs["callback"] = hijack_callback - return SAMPLE(*args, **kwargs) + return hijack_callback - return sample_center + return prep_callback -comfy.sample.sample = hijack(ORIGINAL_SAMPLE) -comfy.sample.sample_custom = hijack(ORIGINAL_SAMPLE_CUSTOM) +latent_preview.prepare_callback = hijack(ORIGINAL_PREP) class ReSharpen: @@ -49,10 +54,9 @@ def INPUT_TYPES(s): return { "required": { "latent": ("LATENT",), - "enable": ("BOOLEAN", {"default": False}), "details": ( "FLOAT", - {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.1}, + {"default": 0.0, "min": -2.0, "max": 2.0, "step": 0.1}, ), } } @@ -61,14 +65,12 @@ def INPUT_TYPES(s): FUNCTION = "hook" CATEGORY = "latent" - def hook(self, latent, enable, details): - global isEnabled - isEnabled = enable + def hook(self, latent, details: float): + + global RESHARPEN_STRENGTH + RESHARPEN_STRENGTH = details / -10.0 - if isEnabled: - global traj_cache - traj_cache = None - global strength - strength = details / -10.0 + global LATENT_CACHE + LATENT_CACHE = None return (latent,)