Skip to content

Commit

Permalink
fix: support function injection w/o magic strings
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jul 7, 2024
1 parent 2dbb175 commit 2bb99f2
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections.abc import Callable
from copy import deepcopy
from enum import Enum, auto
from types import FunctionType

from horde_sdk.ai_horde_api.apimodels import ImageGenerateJobPopResponse
from horde_sdk.ai_horde_api.apimodels.base import (
Expand Down Expand Up @@ -77,6 +78,17 @@ def __init__(
self.faults = faults


def _calc_upscale_sampler_steps(payload):
"""Calculates the amount of hires_fix upscaler steps based on the denoising used and the steps used for the
primary image"""
upscale_steps = round(payload["ddim_steps"] * (0.9 - payload["hires_fix_denoising_strength"]))
if upscale_steps < 3:
upscale_steps = 3

logger.debug(f"Upscale steps calculated as {upscale_steps}")
return upscale_steps


class HordeLib:
_instance: HordeLib | None = None
_initialised = False
Expand Down Expand Up @@ -227,7 +239,7 @@ class HordeLib:
"upscale_sampler.denoise": "hires_fix_denoising_strength",
"upscale_sampler.seed": "seed",
"upscale_sampler.cfg": "cfg_scale",
"upscale_sampler.steps": "func._calc_upscale_sampler_steps",
"upscale_sampler.steps": _calc_upscale_sampler_steps,
"upscale_sampler.sampler_name": "sampler_name",
"controlnet_apply.strength": "control_strength",
"controlnet_model_loader.control_net_name": "control_type",
Expand Down Expand Up @@ -563,14 +575,6 @@ def _apply_aihorde_compatibility_hacks(self, payload: dict) -> tuple[dict, list[
# del payload["denoising_strength"]
return payload, faults

def _calc_upscale_sampler_steps(self, payload):
"""Calculates the amount of hires_fix upscaler steps
Based on the denoising used and the steps used for the primary image"""
upscale_steps = round(payload["ddim_steps"] * (0.9 - payload["hires_fix_denoising_strength"]))
if upscale_steps < 3:
upscale_steps = 3
return upscale_steps

def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, list[GenMetadataEntry]]:
payload = deepcopy(payload)
faults: list[GenMetadataEntry] = []
Expand Down Expand Up @@ -814,12 +818,10 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
multiplier = None
# We allow a multiplier in the param, so that I can adjust easily the
# values for steps on things like stable cascade
if "*" in key:
if isinstance(key, FunctionType):
pipeline_params[newkey] = key(payload)
elif "*" in key:
key, multiplier = key.split("*", 1)
if key.startswith("func."):
key, func_name = key.split(".", 1)
parsing_function = getattr(self, func_name)
pipeline_params[newkey] = parsing_function(payload)
elif key in payload:
if multiplier:
pipeline_params[newkey] = round(payload.get(key) * float(multiplier))
Expand Down

0 comments on commit 2bb99f2

Please sign in to comment.