Skip to content

Commit

Permalink
Feat: Support for Flux (#325)
Browse files Browse the repository at this point in the history
* fix: ignore models/ dir

* feat: Flux support + img2img + batches

* fix: standardized baseline

* comment out dev tests

* fix/refactor: download/validate schnell during CI runs

Also adjusts the model name fixture logic to support the dynamic downloading/validating of future models with surrounding comments/warnings to extend the list downloading/validating relies on.

* fix: use latest compat. `horde_model_reference`

* fix: support for flux loras

* missing image

---------

Co-authored-by: tazlin <tazlin.on.github@gmail.com>
  • Loading branch information
db0 and tazlin authored Sep 14, 2024
1 parent 85973b2 commit efa3e0f
Show file tree
Hide file tree
Showing 15 changed files with 1,259 additions and 47 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ hordelib/model_database/stable_diffusion.json
hordelib/model_database/lora.json
ComfyUI
model.ckpt
models/
coverage.lcov
profiles/
longprompts.zip
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ repos:
types-tabulate,
types-tqdm,
types-urllib3,
horde_sdk==0.14.0,
horde_model_reference==0.8.1,
horde_sdk==0.14.3,
horde_model_reference==0.9.0,
]
30 changes: 24 additions & 6 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ class HordeLib:
"upscale_sampler.sampler_name": "sampler_name",
"controlnet_apply.strength": "control_strength",
"controlnet_model_loader.control_net_name": "control_type",
# Flux
"cfg_guider.cfg": "cfg_scale",
"random_noise.noise_seed": "seed",
"k_sampler_select.sampler_name": "sampler_name",
"basic_scheduler.denoise": "denoising_strength",
"basic_scheduler.steps": "ddim_steps",
# Stable Cascade
"stable_cascade_empty_latent_image.width": "width",
"stable_cascade_empty_latent_image.height": "height",
Expand Down Expand Up @@ -856,10 +862,15 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
)

# The last LORA always connects to the sampler and clip text encoders (via the clip_skip)
if lora_index == len(payload.get("loras")) - 1:
self.generator.reconnect_input(pipeline_data, "sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "upscale_sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "clip_skip.clip", f"lora_{lora_index}")
if lora_index == len(payload.get("loras")) - 1 and SharedModelManager.manager.compvis:
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
if model_details is not None and model_details["baseline"] == "flux_1":
self.generator.reconnect_input(pipeline_data, "cfg_guider.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "basic_scheduler.model", f"lora_{lora_index}")
else:
self.generator.reconnect_input(pipeline_data, "sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "upscale_sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "clip_skip.clip", f"lora_{lora_index}")

# Translate the payload parameters into pipeline parameters
pipeline_params = {}
Expand All @@ -885,7 +896,7 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis

# We inject these parameters to ensure the HordeCheckpointLoader knows what file to load, if necessary
# We don't want to hardcode this into the pipeline.json as we export this directly from ComfyUI
# and don't want to have to rememebr to re-add those keys
# and don't want to have to rememeber to re-add those keys
if "model_loader_stage_c.ckpt_name" in pipeline_params:
pipeline_params["model_loader_stage_c.file_type"] = "stable_cascade_stage_c"
if "model_loader_stage_b.ckpt_name" in pipeline_params:
Expand Down Expand Up @@ -990,7 +1001,12 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
# We do this by reconnecting the nodes in the pipeline to make the input to the vae encoder
# the source image instead of the latent noise generator
if pipeline_params.get("image_loader.image"):
self.generator.reconnect_input(pipeline_data, "sampler.latent_image", "vae_encode")
if SharedModelManager.manager.compvis:
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
if isinstance(model_details, dict) and model_details.get("baseline") == "flux_1":
self.generator.reconnect_input(pipeline_data, "sampler_custom_advanced.latent_image", "vae_encode")
else:
self.generator.reconnect_input(pipeline_data, "sampler.latent_image", "vae_encode")
if pipeline_params.get("sc_image_loader.image"):
self.generator.reconnect_input(
pipeline_data,
Expand Down Expand Up @@ -1181,6 +1197,8 @@ def _get_appropriate_pipeline(self, params):
if params.get("hires_fix", False):
return "stable_cascade_2pass"
return "stable_cascade"
if model_details.get("baseline") == "flux_1":
return "flux"
if params.get("control_type"):
if params.get("return_control_map", False):
return "controlnet_annotator"
Expand Down
6 changes: 3 additions & 3 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,13 @@ def _parse_civitai_lora_data(self, item, adhoc=False):
logger.debug(f"Rejecting LoRa {lora.get('name')} because it doesn't have a url")
return None
# We don't want to start downloading GBs of a single LoRa.
# We just ignore anything over 150Mb. Them's the breaks...
# We just ignore anything over 400Mb. Them's the breaks...
if (
lora["versions"][lora_version]["adhoc"]
and lora["versions"][lora_version]["size_mb"] > 220
and lora["versions"][lora_version]["size_mb"] > 400
and lora["id"] not in self._default_lora_ids
):
logger.debug(f"Rejecting LoRa {lora.get('name')} version {lora_version} because its size is over 220Mb.")
logger.debug(f"Rejecting LoRa {lora.get('name')} version {lora_version} because its size is over 400Mb.")
return None
if lora["versions"][lora_version]["adhoc"] and lora["nsfw"] and not self.nsfw:
logger.debug(f"Rejecting LoRa {lora.get('name')} because worker is SFW.")
Expand Down
Loading

0 comments on commit efa3e0f

Please sign in to comment.