diff --git a/.gitignore b/.gitignore
index 2bc819860..4d0ab36c6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -44,6 +44,9 @@ tunableop_results*.csv
!webui.sh
!package.json
+# dynamically generated
+/repositories/ip-instruct/
+
# all dynamic stuff
/extensions/**/*
/outputs/**/*
@@ -59,7 +62,6 @@ tunableop_results*.csv
.vscode/
.idea/
/localizations
-
.*/
# force included
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6cf7b67d8..39892eb0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,16 +1,69 @@
# Change Log for SD.Next
-## Update for 2024-10-24
+## Update for 2024-10-29
-Improvements:
-- SD3 loader enhancements
+### Highlights for 2024-10-29
+
+- Support for **all SD3.x variants**
+  *SD3.0-Medium, SD3.5-Medium, SD3.5-Large, SD3.5-Large-Turbo*
+- Allow quantization using `bitsandbytes` on-the-fly during model load
+ Load any variant of SD3.x or FLUX.1 and apply quantization during load without the need for pre-quantized models
+- Allow for custom model URL in standard model selector
+ Can be used to specify any model from *HuggingFace* or *CivitAI*
+- Full support for `torch==2.5.1`
+- New wiki articles: [Gated Access](https://github.com/vladmandic/automatic/wiki/Gated), [Quantization](https://github.com/vladmandic/automatic/wiki/Quantization), [Offloading](https://github.com/vladmandic/automatic/wiki/Offload)
+
+Plus tons of smaller improvements and cumulative fixes reported since last release
+
+[README](https://github.com/vladmandic/automatic/blob/master/README.md) | [CHANGELOG](https://github.com/vladmandic/automatic/blob/master/CHANGELOG.md) | [WiKi](https://github.com/vladmandic/automatic/wiki) | [Discord](https://discord.com/invite/sd-next-federal-batch-inspectors-1101998836328697867)
+
+### Details for 2024-10-29
+
+- model selector:
+ - change-in-behavior
+    - when typing, it will auto-load the model as soon as exactly one match is found
+    - allows entering models that are not on the list, which triggers a huggingface search
+ e.g. `stabilityai/stable-diffusion-xl-base-1.0`
+ partial search hits are displayed in the log
+      if an exact match is found, it is auto-downloaded and loaded
+    - allows entering a civitai direct download link, which triggers a model download
+ e.g. `https://civitai.com/api/download/models/72396?type=Model&format=SafeTensor&size=full&fp=fp16`
+ - auto-search-and-download can be disabled in settings -> models -> auto-download
+      this also disables reference models, since they are auto-downloaded on first use as well
+- sd3 enhancements:
+ - allow on-the-fly bnb quantization during load
- report when loading incomplete model
- - handle missing model components
+ - handle missing model components during load
- handle component preloading
-- OpenVINO: add accuracy option
-- ZLUDA: guess GPU arch
-
-Fixes:
+ - native lora handler
+ - support for all sd35 variants: *medium/large/large-turbo*
+ - gguf transformer loader (prototype)
+- flux.1 enhancements:
+ - allow on-the-fly bnb quantization during load
+- samplers:
+ - support for original k-diffusion samplers
+ select via *scripts -> k-diffusion -> sampler*
+- ipadapter:
+ - list available adapters based on loaded model type
+  - add adapter `ostris composition` for sd15/sdxl
+- detailer:
+  - add `[prompt]` to refiner/detailer prompts as a placeholder referencing the original prompt
+- torch
+ - use `torch==2.5.1` by default on supported platforms
+  - CUDA: set device memory limit
+ in *settings -> compute settings -> torch memory limit*
+    default=0 means no limit; if set, torch limits memory usage to the specified fraction of total GPU memory
+    *note*: this is not a hard limit; torch will try to stay under this value
+- compute backends:
+ - OpenVINO: add accuracy option
+ - ZLUDA: guess GPU arch
+- major model load refactor
+- wiki: new articles
+ - [Gated Access Wiki](https://github.com/vladmandic/automatic/wiki/Gated)
+ - [Quantization Wiki](https://github.com/vladmandic/automatic/wiki/Quantization)
+ - [Offloading Wiki](https://github.com/vladmandic/automatic/wiki/Offload)
+
+fixes:
- fix send-to-control
- fix k-diffusion
- fix sd3 img2img and hires
diff --git a/extensions-builtin/Lora/network_overrides.py b/extensions-builtin/Lora/network_overrides.py
index 24afb0c28..5334f3c1b 100644
--- a/extensions-builtin/Lora/network_overrides.py
+++ b/extensions-builtin/Lora/network_overrides.py
@@ -26,7 +26,7 @@
force_models = [ # forced always
'sc',
- 'sd3',
+ # 'sd3',
'kandinsky',
'hunyuandit',
'auraflow',
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index c0a8555e1..83aa6b40b 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -127,6 +127,8 @@ def load_diffusers(name, network_on_disk, lora_scale=shared.opts.extra_networks_
def load_network(name, network_on_disk) -> network.Network:
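+    # no point loading networks before a base model is loaded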
+ if not shared.sd_loaded:
+ return
t0 = time.time()
cached = lora_cache.get(name, None)
if debug:
diff --git a/extensions-builtin/sdnext-modernui b/extensions-builtin/sdnext-modernui
index 906bd2a98..9b721248d 160000
--- a/extensions-builtin/sdnext-modernui
+++ b/extensions-builtin/sdnext-modernui
@@ -1 +1 @@
-Subproject commit 906bd2a98ba0736c235925f2beaea787a050aeed
+Subproject commit 9b721248d55021cbb3e2976ccfa984a8e5b96f39
diff --git a/html/reference.json b/html/reference.json
index 8d26433e7..4a549586f 100644
--- a/html/reference.json
+++ b/html/reference.json
@@ -119,11 +119,19 @@
"preview": "stabilityai--stable-diffusion-3.jpg",
"extras": "sampler: Default, cfg_scale: 7.0"
},
+ "StabilityAI Stable Diffusion 3.5 Medium": {
+ "path": "stabilityai/stable-diffusion-3.5-medium",
+ "skip": true,
+ "variant": "fp16",
+ "desc": "Stable Diffusion 3.5 Medium is a Multimodal Diffusion Transformer with improvements (MMDiT-X) text-to-image model that features improved performance in image quality, typography, complex prompt understanding, and resource-efficiency.",
+ "preview": "stabilityai--stable-diffusion-3_5.jpg",
+ "extras": "sampler: Default, cfg_scale: 7.0"
+ },
"StabilityAI Stable Diffusion 3.5 Large": {
"path": "stabilityai/stable-diffusion-3.5-large",
"skip": true,
"variant": "fp16",
- "desc": "Stable Diffusion 3 Medium is a Multimodal Diffusion Transformer (MMDiT) text-to-image model that features greatly improved performance in image quality, typography, complex prompt understanding, and resource-efficiency",
+ "desc": "Stable Diffusion 3.5 Large is a Multimodal Diffusion Transformer (MMDiT) text-to-image model that features improved performance in image quality, typography, complex prompt understanding, and resource-efficiency.",
"preview": "stabilityai--stable-diffusion-3_5.jpg",
"extras": "sampler: Default, cfg_scale: 7.0"
},
@@ -131,7 +139,7 @@
"path": "stabilityai/stable-diffusion-3.5-large-turbo",
"skip": true,
"variant": "fp16",
- "desc": "Stable Diffusion 3 Medium is a Multimodal Diffusion Transformer (MMDiT) text-to-image model that features greatly improved performance in image quality, typography, complex prompt understanding, and resource-efficiency",
+ "desc": "Stable Diffusion 3.5 Large Turbo is a Multimodal Diffusion Transformer (MMDiT) text-to-image model with Adversarial Diffusion Distillation (ADD) that features improved performance in image quality, typography, complex prompt understanding, and resource-efficiency, with a focus on fewer inference steps.",
"preview": "stabilityai--stable-diffusion-3_5.jpg",
"extras": "sampler: Default, cfg_scale: 7.0"
},
diff --git a/installer.py b/installer.py
index ff5c9f70e..a17efa83a 100644
--- a/installer.py
+++ b/installer.py
@@ -227,9 +227,9 @@ def installed(package, friendly: str = None, reload = False, quiet = False):
exact = pkg_version == p[1]
if not exact and not quiet:
if args.experimental:
- log.warning(f"Package: {p[0]} {pkg_version} required {p[1]} allowing experimental")
+ log.warning(f"Package: {p[0]} installed={pkg_version} required={p[1]} allowing experimental")
else:
- log.warning(f"Package: {p[0]} {pkg_version} required {p[1]} version mismatch")
+ log.warning(f"Package: {p[0]} installed={pkg_version} required={p[1]} version mismatch")
ok = ok and (exact or args.experimental)
else:
if not quiet:
@@ -254,11 +254,12 @@ def uninstall(package, quiet = False):
@lru_cache()
def pip(arg: str, ignore: bool = False, quiet: bool = False, uv = True):
originalArg = arg
- uv = uv and args.uv
- pipCmd = "uv pip" if uv else "pip"
arg = arg.replace('>=', '==')
+    package = arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()
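+    # uv is skipped for direct git+ installs and the command falls back to regular pip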
+ uv = uv and args.uv and not package.startswith('git+')
+ pipCmd = "uv pip" if uv else "pip"
if not quiet and '-r ' not in arg:
-        log.info(f'Install: package="{arg.replace("install", "").replace("--upgrade", "").replace("--no-deps", "").replace("--force", "").replace("  ", " ").strip()}" mode={"uv" if uv else "pip"}')
+ log.info(f'Install: package="{package}" mode={"uv" if uv else "pip"}')
env_args = os.environ.get("PIP_EXTRA_ARGS", "")
all_args = f'{pip_log}{arg} {env_args}'.strip()
if not quiet:
@@ -454,7 +455,7 @@ def check_python(supported_minors=[9, 10, 11, 12], reason=None):
# check diffusers version
def check_diffusers():
- sha = 'e45c25d03aeb0a967d8aaa0f6a79f280f6838e1f'
+ sha = '0d1d267b12e47b40b0e8f265339c76e0f45f8c49'
pkg = pkg_resources.working_set.by_key.get('diffusers', None)
minor = int(pkg.version.split('.')[1] if pkg is not None else 0)
cur = opts.get('diffusers_version', '') if minor > 0 else ''
@@ -489,7 +490,7 @@ def install_cuda():
log.info('CUDA: nVidia toolkit detected')
install('onnxruntime-gpu', 'onnxruntime-gpu', ignore=True, quiet=True)
# return os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/cu124')
- return os.environ.get('TORCH_COMMAND', 'torch==2.4.1+cu124 torchvision==0.19.1+cu124 --index-url https://download.pytorch.org/whl/cu124')
+ return os.environ.get('TORCH_COMMAND', 'torch==2.5.1+cu124 torchvision==0.20.1+cu124 --index-url https://download.pytorch.org/whl/cu124')
def install_rocm_zluda():
@@ -549,6 +550,7 @@ def install_rocm_zluda():
log.warning("ZLUDA support: experimental")
error = None
from modules import zluda_installer
+ zluda_installer.set_default_agent(device)
try:
if args.reinstall_zluda:
zluda_installer.uninstall()
@@ -570,8 +572,10 @@ def install_rocm_zluda():
log.info('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
else:
- if rocm.version is None or float(rocm.version) >= 6.1: # assume the latest if version check fails
- #torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/rocm6.1')
+ if rocm.version is None or float(rocm.version) > 6.1: # assume the latest if version check fails
+ # torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.5.1+rocm6.2 torchvision==0.20.1+rocm6.2 --index-url https://download.pytorch.org/whl/rocm6.2')
+ torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.4.1+rocm6.1 torchvision==0.19.1+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1')
+ elif rocm.version == "6.1": # lock to 2.4.1, older rocm (5.7) uses torch 2.3
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.4.1+rocm6.1 torchvision==0.19.1+rocm6.1 --index-url https://download.pytorch.org/whl/rocm6.1')
elif rocm.version == "6.0": # lock to 2.4.1, older rocm (5.7) uses torch 2.3
torch_command = os.environ.get('TORCH_COMMAND', 'torch==2.4.1+rocm6.0 torchvision==0.19.1+rocm6.0 --index-url https://download.pytorch.org/whl/rocm6.0')
@@ -730,7 +734,7 @@ def check_torch():
else:
if args.use_zluda:
log.warning("ZLUDA failed to initialize: no HIP SDK found")
- log.info('Using CPU-only Torch')
+ log.warning('Torch: CPU-only version installed')
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision')
if 'torch' in torch_command and not args.version:
install(torch_command, 'torch torchvision', quiet=True)
@@ -817,6 +821,7 @@ def install_packages():
log.info('Verifying packages')
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git")
install(clip_package, 'clip', quiet=True)
+ install('open-clip-torch', no_deps=True, quiet=True)
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', 'tensorflow==2.13.0')
# tensorflow_package = os.environ.get('TENSORFLOW_PACKAGE', None)
# if tensorflow_package is not None:
diff --git a/javascript/sdnext.css b/javascript/sdnext.css
index 3c81e7d8e..cbd3cac83 100644
--- a/javascript/sdnext.css
+++ b/javascript/sdnext.css
@@ -38,7 +38,7 @@ td > div > span { overflow-y: auto; max-height: 3em; overflow-x: hidden; }
.gradio-button.secondary-down, .gradio-button.secondary-down:hover { box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset; }
.gradio-button.secondary-down:hover { background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); }
.gradio-button.tool { max-width: min-content; min-width: min-content !important; font-size: 20px !important; color: var(--body-text-color) !important; align-self: end; margin-bottom: 4px; }
-.gradio-checkbox { margin: 0.75em 1.5em 0 0; align-self: center; }
+.gradio-checkbox { margin-right: 1em !important; align-self: center; }
.gradio-column { min-width: min(160px, 100%) !important; }
.gradio-container { max-width: unset !important; padding: var(--block-label-padding) !important; }
.gradio-container .prose a, .gradio-container .prose a:visited{ color: unset; text-decoration: none; }
diff --git a/modules/devices.py b/modules/devices.py
index 490d2a54d..56ac50091 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -4,6 +4,7 @@
import contextlib
from functools import wraps
import torch
+from modules import rocm
from modules.errors import log, display, install as install_traceback
from installer import install
@@ -50,8 +51,8 @@ def has_zluda() -> bool:
if not cuda_ok:
return False
try:
- device = torch.device("cuda")
- return torch.cuda.get_device_name(device).endswith("[ZLUDA]")
+ dev = torch.device("cuda")
+ return torch.cuda.get_device_name(dev).endswith("[ZLUDA]")
except Exception:
return False
@@ -206,7 +207,7 @@ def torch_gc(force=False, fast=False):
force = True
if oom > previous_oom:
previous_oom = oom
- log.warning(f'GPU out-of-memory error: {mem}')
+ log.warning(f'Torch GPU out-of-memory error: {mem}')
force = True
if force:
# actual gc
@@ -246,13 +247,26 @@ def set_cuda_sync_mode(mode):
return
try:
import ctypes
- log.info(f'Set cuda sync: mode={mode}')
+ log.info(f'Torch CUDA sync: mode={mode}')
torch.cuda.set_device(torch.device(get_optimal_device_name()))
ctypes.CDLL('libcudart.so').cudaSetDeviceFlags({'auto': 0, 'spin': 1, 'yield': 2, 'block': 4}[mode])
except Exception:
pass
+def set_cuda_memory_limit():
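+    # soft-limit per-process CUDA allocations to a fraction of total device memory (settings -> compute settings -> torch memory limit)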
+ if not cuda_ok or opts.cuda_mem_fraction == 0:
+ return
+ from modules.shared import cmd_opts
+ try:
+ torch_gc(force=True)
+ mem = torch.cuda.get_device_properties(device).total_memory
+ torch.cuda.set_per_process_memory_fraction(float(opts.cuda_mem_fraction), cmd_opts.device_id if cmd_opts.device_id is not None else 0)
+ log.info(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} limit={round(opts.cuda_mem_fraction * mem / 1024 / 1024)} total={round(mem / 1024 / 1024)}')
+ except Exception as e:
+ log.warning(f'Torch CUDA memory limit: fraction={opts.cuda_mem_fraction:.2f} {e}')
+
+
def test_fp16():
global fp16_ok # pylint: disable=global-statement
if fp16_ok is not None:
@@ -283,16 +297,14 @@ def test_bf16():
if sys.platform == "darwin" or backend == 'openvino' or backend == 'directml': # override
bf16_ok = False
return bf16_ok
- elif backend == 'zluda':
- device_name = torch.cuda.get_device_name(device)
- if device_name.startswith("AMD Radeon RX "): # only force AMD
- device_name = device_name.replace("AMD Radeon RX ", "").split(" ", maxsplit=1)[0]
- if len(device_name) == 4 and device_name[0] in {"5", "6"}: # RDNA 1 and 2
- bf16_ok = False
- return bf16_ok
- elif backend == 'rocm':
- gcn_arch = getattr(torch.cuda.get_device_properties(device), "gcnArchName", "gfx0000")[3:7]
- if len(gcn_arch) == 4 and gcn_arch[0:2] == "10": # RDNA 1 and 2
+ elif backend == 'rocm' or backend == 'zluda':
+ agent = None
+ if backend == 'rocm':
+ agent = rocm.Agent(getattr(torch.cuda.get_device_properties(device), "gcnArchName", "gfx0000"))
+ else:
+ from modules.zluda_installer import default_agent
+ agent = default_agent
+ if agent is not None and agent.gfx_version < 0x1100 and agent.arch != rocm.MicroArchitecture.CDNA: # all cards before RDNA 3 except for CDNA cards
bf16_ok = False
return bf16_ok
try:
@@ -450,6 +462,7 @@ def set_dtype():
def set_cuda_params():
override_ipex_math()
+ set_cuda_memory_limit()
set_cudnn_params()
set_sdpa_params()
set_dtype()
diff --git a/modules/extras.py b/modules/extras.py
index e22360f8a..162491580 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -188,7 +188,7 @@ def add_model_metadata(checkpoint_info):
_, extension = os.path.splitext(output_modelname)
if os.path.exists(output_modelname) and not kwargs.get("overwrite", False):
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model alredy exists: {output_modelname}"]
+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model already exists: {output_modelname}"]
if extension.lower() == ".safetensors":
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
else:
@@ -202,7 +202,7 @@ def add_model_metadata(checkpoint_info):
created_model.calculate_shorthash()
devices.torch_gc(force=True)
shared.state.end()
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Model saved to {output_modelname}"]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Model saved to {output_modelname}"]
def run_modelconvert(model, checkpoint_formats, precision, conv_type, custom_name, unet_conv, text_encoder_conv,
diff --git a/modules/face/faceid.py b/modules/face/faceid.py
index 754ce59a3..e2d5efccb 100644
--- a/modules/face/faceid.py
+++ b/modules/face/faceid.py
@@ -6,9 +6,10 @@
import diffusers
import huggingface_hub as hf
from PIL import Image
-from modules import processing, shared, devices, extra_networks, sd_models, sd_hijack_freeu, script_callbacks, ipadapter
+from modules import processing, shared, devices, extra_networks, sd_hijack_freeu, script_callbacks, ipadapter, token_merge
from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet
+
FACEID_MODELS = {
"FaceID Base": "h94/IP-Adapter-FaceID/ip-adapter-faceid_sd15.bin",
"FaceID Plus v1": "h94/IP-Adapter-FaceID/ip-adapter-faceid-plus_sd15.bin",
@@ -69,7 +70,7 @@ def face_id(
shared.prompt_styles.apply_styles_to_extra(p)
if shared.opts.cuda_compile_backend == 'none':
- sd_models.apply_token_merging(p.sd_model)
+ token_merge.apply_token_merging(p.sd_model)
sd_hijack_freeu.apply_freeu(p, not shared.native)
script_callbacks.before_process_callback(p)
@@ -246,7 +247,7 @@ def face_id(
if faceid_model is not None and original_load_ip_adapter is not None:
faceid_model.__class__.load_ip_adapter = original_load_ip_adapter
if shared.opts.cuda_compile_backend == 'none':
- sd_models.remove_token_merging(p.sd_model)
+ token_merge.remove_token_merging(p.sd_model)
script_callbacks.after_process_callback(p)
return processed_images
diff --git a/modules/intel/ipex/attention.py b/modules/intel/ipex/attention.py
index 22c74a78b..1618045b6 100644
--- a/modules/intel/ipex/attention.py
+++ b/modules/intel/ipex/attention.py
@@ -136,11 +136,11 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
- if attn_mask is not None and attn_mask.shape != query.shape:
+ if attn_mask is not None and attn_mask.shape[:-1] != query.shape[:-1]:
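+            # expand creates a broadcast view of the mask instead of materializing repeated copies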
if len(query.shape) == 4:
- attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
+ attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
else:
- attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
+ attn_mask = attn_mask.expand((query.shape[0], query.shape[1], key.shape[-2]))
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
diff --git a/modules/ipadapter.py b/modules/ipadapter.py
index 0ab27f03c..0a010c41c 100644
--- a/modules/ipadapter.py
+++ b/modules/ipadapter.py
@@ -3,8 +3,6 @@
 - Downloads image_encoder on first usage (2.5GB)
- Introduced via: https://github.com/huggingface/diffusers/pull/5713
- IP adapters: https://huggingface.co/h94/IP-Adapter
-TODO ipadapter items:
-- SD/SDXL autodetect
"""
import os
@@ -14,21 +12,41 @@
from modules import processing, shared, devices, sd_models
-base_repo = "h94/IP-Adapter"
+clip_repo = "h94/IP-Adapter"
clip_loaded = None
-ADAPTERS = {
- 'None': 'none',
- 'Base': 'ip-adapter_sd15.safetensors',
- 'Base ViT-G': 'ip-adapter_sd15_vit-G.safetensors',
- 'Light': 'ip-adapter_sd15_light.safetensors',
- 'Plus': 'ip-adapter-plus_sd15.safetensors',
- 'Plus Face': 'ip-adapter-plus-face_sd15.safetensors',
- 'Full Face': 'ip-adapter-full-face_sd15.safetensors',
- 'Base SDXL': 'ip-adapter_sdxl.safetensors',
- 'Base ViT-H SDXL': 'ip-adapter_sdxl_vit-h.safetensors',
- 'Plus ViT-H SDXL': 'ip-adapter-plus_sdxl_vit-h.safetensors',
- 'Plus Face ViT-H SDXL': 'ip-adapter-plus-face_sdxl_vit-h.safetensors',
+ADAPTERS_NONE = {
+ 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
}
+ADAPTERS_SD15 = {
+ 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
+ 'Base': { 'name': 'ip-adapter_sd15.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Base ViT-G': { 'name': 'ip-adapter_sd15_vit-G.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Light': { 'name': 'ip-adapter_sd15_light.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Plus': { 'name': 'ip-adapter-plus_sd15.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Plus Face': { 'name': 'ip-adapter-plus-face_sd15.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Full Face': { 'name': 'ip-adapter-full-face_sd15.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'models' },
+ 'Ostris Composition ViT-H': { 'name': 'ip_plus_composition_sd15.safetensors', 'repo': 'ostris/ip-composition-adapter', 'subfolder': '' },
+}
+ADAPTERS_SDXL = {
+ 'None': { 'name': 'none', 'repo': 'none', 'subfolder': 'none' },
+ 'Base SDXL': { 'name': 'ip-adapter_sdxl.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'sdxl_models' },
+ 'Base ViT-H SDXL': { 'name': 'ip-adapter_sdxl_vit-h.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'sdxl_models' },
+ 'Plus ViT-H SDXL': { 'name': 'ip-adapter-plus_sdxl_vit-h.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'sdxl_models' },
+ 'Plus Face ViT-H SDXL': { 'name': 'ip-adapter-plus-face_sdxl_vit-h.safetensors', 'repo': 'h94/IP-Adapter', 'subfolder': 'sdxl_models' },
+ 'Ostris Composition ViT-H SDXL': { 'name': 'ip_plus_composition_sdxl.safetensors', 'repo': 'ostris/ip-composition-adapter', 'subfolder': '' },
+}
+ADAPTERS = { **ADAPTERS_SD15, **ADAPTERS_SDXL }
+
+
+def get_adapters():
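+    # narrow the adapter list to those compatible with the currently loaded model type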
+ global ADAPTERS # pylint: disable=global-statement
+ if shared.sd_model_type == 'sd':
+ ADAPTERS = ADAPTERS_SD15
+ elif shared.sd_model_type == 'sdxl':
+ ADAPTERS = ADAPTERS_SDXL
+ else:
+ ADAPTERS = ADAPTERS_NONE
+ return list(ADAPTERS)
def get_images(input_images):
@@ -117,13 +135,13 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
if hasattr(p, 'ip_adapter_names'):
if isinstance(p.ip_adapter_names, str):
p.ip_adapter_names = [p.ip_adapter_names]
- adapters = [ADAPTERS.get(adapter, None) for adapter in p.ip_adapter_names if adapter is not None and adapter.lower() != 'none']
+ adapters = [ADAPTERS.get(adapter_name, None) for adapter_name in p.ip_adapter_names if adapter_name is not None and adapter_name.lower() != 'none']
adapter_names = p.ip_adapter_names
else:
if isinstance(adapter_names, str):
adapter_names = [adapter_names]
adapters = [ADAPTERS.get(adapter, None) for adapter in adapter_names]
- adapters = [adapter for adapter in adapters if adapter is not None and adapter.lower() != 'none']
+ adapters = [adapter for adapter in adapters if adapter is not None and adapter['name'].lower() != 'none']
if len(adapters) == 0:
unapply(pipe)
if hasattr(p, 'ip_adapter_images'):
@@ -189,41 +207,48 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
for adapter_name in adapter_names:
# which clip to use
- if 'ViT' not in adapter_name:
- clip_repo = base_repo
- clip_subfolder = 'models/image_encoder' if shared.sd_model_type == 'sd' else 'sdxl_models/image_encoder' # defaults per model
+ if 'ViT' not in adapter_name: # defaults per model
+ if shared.sd_model_type == 'sd':
+ clip_subfolder = 'models/image_encoder'
+ else:
+ clip_subfolder = 'sdxl_models/image_encoder'
elif 'ViT-H' in adapter_name:
- clip_repo = base_repo
clip_subfolder = 'models/image_encoder' # this is vit-h
elif 'ViT-G' in adapter_name:
- clip_repo = base_repo
clip_subfolder = 'sdxl_models/image_encoder' # this is vit-g
else:
shared.log.error(f'IP adapter: unknown model type: {adapter_name}')
return False
- # load feature extractor used by ip adapter
- if pipe.feature_extractor is None:
+ # load feature extractor used by ip adapter
+ if pipe.feature_extractor is None:
+ try:
from transformers import CLIPImageProcessor
shared.log.debug('IP adapter load: feature extractor')
pipe.feature_extractor = CLIPImageProcessor()
- # load image encoder used by ip adapter
- if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}':
- try:
- from transformers import CLIPVisionModelWithProjection
- shared.log.debug(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}"')
- pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True)
- clip_loaded = f'{clip_repo}/{clip_subfolder}'
- except Exception as e:
- shared.log.error(f'IP adapter: failed to load image encoder: {e}')
- return False
- sd_models.move_model(pipe.image_encoder, devices.device)
+ except Exception as e:
+ shared.log.error(f'IP adapter load: feature extractor {e}')
+ return False
+
+ # load image encoder used by ip adapter
+ if pipe.image_encoder is None or clip_loaded != f'{clip_repo}/{clip_subfolder}':
+ try:
+ from transformers import CLIPVisionModelWithProjection
+ shared.log.debug(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}"')
+ pipe.image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_repo, subfolder=clip_subfolder, torch_dtype=devices.dtype, cache_dir=shared.opts.diffusers_dir, use_safetensors=True)
+ clip_loaded = f'{clip_repo}/{clip_subfolder}'
+ except Exception as e:
+ shared.log.error(f'IP adapter load: image encoder="{clip_repo}/{clip_subfolder}" {e}')
+ return False
+ sd_models.move_model(pipe.image_encoder, devices.device)
# main code
- t0 = time.time()
- ip_subfolder = 'models' if shared.sd_model_type == 'sd' else 'sdxl_models'
try:
- pipe.load_ip_adapter([base_repo], subfolder=[ip_subfolder], weight_name=adapters)
+ t0 = time.time()
+ repos = [adapter['repo'] for adapter in adapters]
+ subfolders = [adapter['subfolder'] for adapter in adapters]
+ names = [adapter['name'] for adapter in adapters]
+ pipe.load_ip_adapter(repos, subfolder=subfolders, weight_name=names)
if hasattr(p, 'ip_adapter_layers'):
pipe.set_ip_adapter_scale(p.ip_adapter_layers)
ip_str = ';'.join(adapter_names) + ':' + json.dumps(p.ip_adapter_layers)
@@ -240,5 +265,5 @@ def apply(pipe, p: processing.StableDiffusionProcessing, adapter_names=[], adapt
t1 = time.time()
shared.log.info(f'IP adapter: {ip_str} image={adapter_images} mask={adapter_masks is not None} time={t1-t0:.2f}')
except Exception as e:
- shared.log.error(f'IP adapter failed to load: repo="{base_repo}" folder="{ip_subfolder}" weights={adapters} names={adapter_names} {e}')
+ shared.log.error(f'IP adapter load: adapters={adapter_names} repo={repos} folders={subfolders} names={names} {e}')
return True
diff --git a/modules/loader.py b/modules/loader.py
index a2970abfd..0711c2906 100644
--- a/modules/loader.py
+++ b/modules/loader.py
@@ -44,6 +44,8 @@
timer.startup.record("torch")
import transformers # pylint: disable=W0611,C0411
+from transformers import logging as transformers_logging # pylint: disable=W0611,C0411
+transformers_logging.set_verbosity_error()
timer.startup.record("transformers")
import accelerate # pylint: disable=W0611,C0411
@@ -61,6 +63,9 @@
import pydantic # pylint: disable=W0611,C0411
timer.startup.record("pydantic")
+import diffusers.utils.import_utils # pylint: disable=W0611,C0411
+diffusers.utils.import_utils._k_diffusion_available = True # pylint: disable=protected-access # monkey-patch since we use k-diffusion from git
+diffusers.utils.import_utils._k_diffusion_version = '0.0.12' # pylint: disable=protected-access
import diffusers # pylint: disable=W0611,C0411
import diffusers.loaders.single_file # pylint: disable=W0611,C0411
import huggingface_hub # pylint: disable=W0611,C0411
diff --git a/modules/model_flux.py b/modules/model_flux.py
index 38207f73b..c605702c8 100644
--- a/modules/model_flux.py
+++ b/modules/model_flux.py
@@ -122,10 +122,12 @@ def quant_flux_bnb(checkpoint_info, transformer, text_encoder_2):
bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
bnb_4bit_compute_dtype=devices.dtype
)
- if 'Model' in shared.opts.bnb_quantization and transformer is None:
+ if ('Model' in shared.opts.bnb_quantization) and (transformer is None):
transformer = diffusers.FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=transformer type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
- if 'Text Encoder' in shared.opts.bnb_quantization and text_encoder_2 is None:
+ if ('Text Encoder' in shared.opts.bnb_quantization) and (text_encoder_2 is None):
+ if repo_id == 'sayakpaul/flux.1-dev-nf4':
+ repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
text_encoder_2 = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2", cache_dir=cache_dir, quantization_config=bnb_config, torch_dtype=devices.dtype)
shared.log.debug(f'Quantization: module=t5 type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
except Exception as e:
@@ -285,25 +287,26 @@ def load_flux(checkpoint_info, diffusers_load_config): # triggered by opts.sd_ch
errors.display(e, 'FLUX Quanto:')
# initialize pipeline with pre-loaded components
- components = {}
- transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2)
+ kwargs = {}
+ # transformer, text_encoder_2 = quant_flux_bnb(checkpoint_info, transformer, text_encoder_2)
if transformer is not None:
- components['transformer'] = transformer
+ kwargs['transformer'] = transformer
sd_unet.loaded_unet = shared.opts.sd_unet
if text_encoder_1 is not None:
- components['text_encoder'] = text_encoder_1
+ kwargs['text_encoder'] = text_encoder_1
model_te.loaded_te = shared.opts.sd_text_encoder
if text_encoder_2 is not None:
- components['text_encoder_2'] = text_encoder_2
+ kwargs['text_encoder_2'] = text_encoder_2
model_te.loaded_te = shared.opts.sd_text_encoder
if vae is not None:
- components['vae'] = vae
- shared.log.debug(f'Load model: type=FLUX preloaded={list(components)}')
+ kwargs['vae'] = vae
+ shared.log.debug(f'Load model: type=FLUX preloaded={list(kwargs)}')
if repo_id == 'sayakpaul/flux.1-dev-nf4':
repo_id = 'black-forest-labs/FLUX.1-dev' # workaround since sayakpaul model is missing model_index.json
- for c in components:
- if components[c].dtype == torch.float32 and devices.dtype != torch.float32:
- shared.log.warning(f'Load model: type=FLUX component={c} dtype={components[c].dtype} cast dtype={devices.dtype}')
- components[c] = components[c].to(dtype=devices.dtype)
- pipe = diffusers.FluxPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **components, **diffusers_load_config)
+ for c in kwargs:
+ if kwargs[c].dtype == torch.float32 and devices.dtype != torch.float32:
+ shared.log.warning(f'Load model: type=FLUX component={c} dtype={kwargs[c].dtype} cast dtype={devices.dtype}')
+ kwargs[c] = kwargs[c].to(dtype=devices.dtype)
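+    # apply on-the-fly bnb quantization to components that were not pre-loaded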
+ kwargs = model_quant.create_bnb_config(kwargs)
+ pipe = diffusers.FluxPipeline.from_pretrained(repo_id, cache_dir=shared.opts.diffusers_dir, **kwargs, **diffusers_load_config)
return pipe
diff --git a/modules/model_quant.py b/modules/model_quant.py
index d54d6ff6d..fab3a93a0 100644
--- a/modules/model_quant.py
+++ b/modules/model_quant.py
@@ -1,4 +1,5 @@
import sys
+import diffusers
from installer import install, log
@@ -6,6 +7,23 @@
quanto = None
+def create_bnb_config(kwargs):
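+    # build a pipeline-level BitsAndBytesConfig when on-the-fly quantization is enabled and no pre-quantized transformer was supplied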
+ from modules import shared, devices
+ if len(shared.opts.bnb_quantization) > 0:
+ if 'Model' in shared.opts.bnb_quantization and 'transformer' not in kwargs:
+ load_bnb()
+ bnb_config = diffusers.BitsAndBytesConfig(
+ load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
+ load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
+ bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
+ bnb_4bit_quant_type=shared.opts.bnb_quantization_type,
+ bnb_4bit_compute_dtype=devices.dtype
+ )
+ kwargs['quantization_config'] = bnb_config
+ shared.log.debug(f'Quantization: module=all type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
+ return kwargs
+
+
def load_bnb(msg='', silent=False):
global bnb # pylint: disable=global-statement
if bnb is not None:
@@ -16,6 +34,8 @@ def load_bnb(msg='', silent=False):
try:
import bitsandbytes
bnb = bitsandbytes
+ diffusers.utils.import_utils._bitsandbytes_available = True # pylint: disable=protected-access
+ diffusers.utils.import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
return bnb
except Exception as e:
if len(msg) > 0:
@@ -23,6 +43,7 @@ def load_bnb(msg='', silent=False):
bnb = None
if not silent:
raise
+ return None
def load_quanto(msg='', silent=False):
@@ -42,6 +63,7 @@ def load_quanto(msg='', silent=False):
quanto = None
if not silent:
raise
+ return None
def get_quant(name):
diff --git a/modules/model_sd3.py b/modules/model_sd3.py
index 4c8d80c96..e369b197a 100644
--- a/modules/model_sd3.py
+++ b/modules/model_sd3.py
@@ -1,7 +1,7 @@
import os
import diffusers
import transformers
-from modules import shared, devices, sd_models, sd_unet, model_te
+from modules import shared, devices, sd_models, sd_unet, model_te, model_quant
def load_overrides(kwargs, cache_dir):
@@ -51,8 +51,7 @@ def load_overrides(kwargs, cache_dir):
def load_quants(kwargs, repo_id, cache_dir):
if len(shared.opts.bnb_quantization) > 0:
- from modules.model_quant import load_bnb
- load_bnb('Load model: type=SD3')
+ model_quant.load_bnb('Load model: type=SD3')
bnb_config = diffusers.BitsAndBytesConfig(
load_in_8bit=shared.opts.bnb_quantization_type in ['fp8'],
load_in_4bit=shared.opts.bnb_quantization_type in ['nf4', 'fp4'],
@@ -75,7 +74,7 @@ def load_missing(kwargs, fn, cache_dir):
if size > 15000:
repo_id = 'stabilityai/stable-diffusion-3.5-large'
else:
- repo_id = 'stabilityai/stable-diffusion-3-medium'
+ repo_id = 'stabilityai/stable-diffusion-3-medium-diffusers'
if 'text_encoder' not in kwargs and 'text_encoder' not in keys:
kwargs['text_encoder'] = transformers.CLIPTextModelWithProjection.from_pretrained(repo_id, subfolder='text_encoder', cache_dir=cache_dir, torch_dtype=devices.dtype)
shared.log.debug(f'Load model: type=SD3 missing=te1 repo="{repo_id}"')
@@ -85,6 +84,9 @@ def load_missing(kwargs, fn, cache_dir):
if 'text_encoder_3' not in kwargs and 'text_encoder_3' not in keys:
kwargs['text_encoder_3'] = transformers.T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_3", variant='fp16', cache_dir=cache_dir, torch_dtype=devices.dtype)
shared.log.debug(f'Load model: type=SD3 missing=te3 repo="{repo_id}"')
+ if 'vae' not in kwargs and 'vae' not in keys:
+ kwargs['vae'] = diffusers.AutoencoderKL.from_pretrained(repo_id, subfolder='vae', cache_dir=cache_dir, torch_dtype=devices.dtype)
+ shared.log.debug(f'Load model: type=SD3 missing=vae repo="{repo_id}"')
# if 'transformer' not in kwargs and 'transformer' not in keys:
# kwargs['transformer'] = diffusers.SD3Transformer2DModel.from_pretrained(default_repo_id, subfolder="transformer", cache_dir=cache_dir, torch_dtype=devices.dtype)
return kwargs
@@ -120,10 +122,11 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
kwargs = {}
kwargs = load_overrides(kwargs, cache_dir)
- kwargs = load_quants(kwargs, repo_id, cache_dir)
+ if fn is None or not os.path.exists(fn):
+ kwargs = load_quants(kwargs, repo_id, cache_dir)
loader = diffusers.StableDiffusion3Pipeline.from_pretrained
- if fn is not None and os.path.exists(fn):
+ if fn is not None and os.path.exists(fn) and os.path.isfile(fn):
if fn.endswith('.safetensors'):
loader = diffusers.StableDiffusion3Pipeline.from_single_file
kwargs = load_missing(kwargs, fn, cache_dir)
@@ -135,8 +138,9 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
else:
kwargs['variant'] = 'fp16'
- shared.log.debug(f'Load model: type=SD3 preloaded={list(kwargs)}')
+ shared.log.debug(f'Load model: type=SD3 kwargs={list(kwargs)}')
+ kwargs = model_quant.create_bnb_config(kwargs)
pipe = loader(
repo_id,
torch_dtype=devices.dtype,
@@ -144,5 +148,5 @@ def load_sd3(checkpoint_info, cache_dir=None, config=None):
config=config,
**kwargs,
)
- devices.torch_gc()
+ devices.torch_gc(force=True)
return pipe
diff --git a/modules/modelloader.py b/modules/modelloader.py
index f9b7a4497..8eab91597 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -99,12 +99,26 @@ def download_civit_preview(model_path: str, preview_url: str):
download_pbar = None
-def download_civit_model_thread(model_name, model_url, model_path, model_type, token):
+def download_civit_model_thread(model_name: str, model_url: str, model_path: str = "", model_type: str = "Model", token: str = None):
import hashlib
sha256 = hashlib.sha256()
- sha256.update(model_name.encode('utf-8'))
+ sha256.update(model_url.encode('utf-8'))
temp_file = sha256.hexdigest()[:8] + '.tmp'
+ headers = {}
+ starting_pos = 0
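+    # resume a partial download by continuing from the current size of the temp file via an http range request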
+ if os.path.isfile(temp_file):
+ starting_pos = os.path.getsize(temp_file)
+ headers['Range'] = f'bytes={starting_pos}-'
+ if token is not None and len(token) > 0:
+ headers['Authorization'] = f'Bearer {token}'
+
+ r = shared.req(model_url, headers=headers, stream=True)
+ total_size = int(r.headers.get('content-length', 0))
+ if model_name is None or len(model_name) == 0:
+ cn = r.headers.get('content-disposition', '')
+ model_name = cn.split('filename=')[-1].strip('"')
+
if model_type == 'LoRA':
model_file = os.path.join(shared.opts.lora_dir, model_path, model_name)
temp_file = os.path.join(shared.opts.lora_dir, model_path, temp_file)
@@ -124,17 +138,6 @@ def download_civit_model_thread(model_name, model_url, model_path, model_type, t
shared.log.warning(res)
return res
- headers = {}
- starting_pos = 0
- if os.path.isfile(temp_file):
- starting_pos = os.path.getsize(temp_file)
- res += f' resume={round(starting_pos/1024/1024)}Mb'
- headers['Range'] = f'bytes={starting_pos}-'
- if token is not None and len(token) > 0:
- headers['Authorization'] = f'Bearer {token}'
-
- r = shared.req(model_url, headers=headers, stream=True)
- total_size = int(r.headers.get('content-length', 0))
res += f' size={round((starting_pos + total_size)/1024/1024, 2)}Mb'
shared.log.info(res)
shared.state.begin('CivitAI')
@@ -177,7 +180,10 @@ def download_civit_model_thread(model_name, model_url, model_path, model_type, t
shared.log.debug(f'Model download complete: temp="{temp_file}" path="{model_file}"')
os.rename(temp_file, model_file)
shared.state.end()
- return res
+ if os.path.exists(model_file):
+ return model_file
+ else:
+ return None
def download_civit_model(model_url: str, model_name: str, model_path: str, model_type: str, token: str = None):
@@ -309,15 +315,18 @@ def load_diffusers_models(clear=True):
return diffuser_repos
-def find_diffuser(name: str):
+def find_diffuser(name: str, full=False):
repo = [r for r in diffuser_repos if name == r['name'] or name == r['friendly'] or name == r['path']]
if len(repo) > 0:
- return repo['name']
+ return [repo[0]['name']]
hf_api = hf.HfApi()
models = list(hf_api.list_models(model_name=name, library=['diffusers'], full=True, limit=20, sort="downloads", direction=-1))
shared.log.debug(f'Searching diffusers models: {name} {len(models) > 0}')
if len(models) > 0:
- return models[0].id
+ if not full:
+ return models[0].id
+ else:
+ return [m.id for m in models]
return None
diff --git a/modules/onnx_impl/__init__.py b/modules/onnx_impl/__init__.py
index 013387d8f..42f22ea3f 100644
--- a/modules/onnx_impl/__init__.py
+++ b/modules/onnx_impl/__init__.py
@@ -184,8 +184,8 @@ def preprocess_pipeline(p):
return shared.sd_model
-def ORTDiffusionModelPart_to(self, *args, **kwargs):
- self.parent_model = self.parent_model.to(*args, **kwargs)
+def ORTPipelinePart_to(self, *args, **kwargs):
+ self.parent_pipeline = self.parent_pipeline.to(*args, **kwargs)
return self
@@ -241,9 +241,9 @@ def initialize_onnx():
diffusers.ORTStableDiffusionXLPipeline = diffusers.OnnxStableDiffusionXLPipeline # Huggingface model compatibility
diffusers.ORTStableDiffusionXLImg2ImgPipeline = diffusers.OnnxStableDiffusionXLImg2ImgPipeline
- optimum.onnxruntime.modeling_diffusion._ORTDiffusionModelPart.to = ORTDiffusionModelPart_to # pylint: disable=protected-access
- except Exception:
- pass
+ optimum.onnxruntime.modeling_diffusion.ORTPipelinePart.to = ORTPipelinePart_to # pylint: disable=protected-access
+ except Exception as e:
+ log.debug(f'ONNX failed to initialize XL pipelines: {e}')
initialized = True
diff --git a/modules/onnx_impl/ui.py b/modules/onnx_impl/ui.py
index f73e477c4..49af8d98b 100644
--- a/modules/onnx_impl/ui.py
+++ b/modules/onnx_impl/ui.py
@@ -15,7 +15,7 @@ def create_ui():
from modules.ui_common import create_refresh_button
from modules.ui_components import DropdownMulti
from modules.shared import log, opts, cmd_opts, refresh_checkpoints
- from modules.sd_models import checkpoint_tiles, get_closet_checkpoint_match
+ from modules.sd_models import checkpoint_titles, get_closet_checkpoint_match
from modules.paths import sd_configs_path
from .execution_providers import ExecutionProvider, install_execution_provider
from .utils import check_diffusers_cache
@@ -46,7 +46,7 @@ def create_ui():
with gr.TabItem("Manage cache", id="manage_cache"):
cache_state_dirname = gr.Textbox(value=None, visible=False)
with gr.Row():
- model_dropdown = gr.Dropdown(label="Model", value="Please select model", choices=checkpoint_tiles())
+ model_dropdown = gr.Dropdown(label="Model", value="Please select model", choices=checkpoint_titles())
create_refresh_button(model_dropdown, refresh_checkpoints, {}, "onnx_cache_refresh_diffusers_model")
with gr.Row():
def remove_cache_onnx_converted(dirname: str):
diff --git a/modules/postprocess/yolo.py b/modules/postprocess/yolo.py
index b162240e0..44f72237e 100644
--- a/modules/postprocess/yolo.py
+++ b/modules/postprocess/yolo.py
@@ -191,12 +191,20 @@ def restore(self, np_image, p: processing.StableDiffusionProcessing = None):
pp = None
shared.opts.data['mask_apply_overlay'] = True
resolution = 512 if shared.sd_model_type in ['none', 'sd', 'lcm', 'unknown'] else 1024
+ orig_prompt: str = orig_p.get('all_prompts', [''])[0]
+ orig_negative: str = orig_p.get('all_negative_prompts', [''])[0]
prompt: str = orig_p.get('refiner_prompt', '')
negative: str = orig_p.get('refiner_negative', '')
if len(prompt) == 0:
- prompt = orig_p.get('all_prompts', [''])[0]
+ prompt = orig_prompt
+ else:
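+            # substitute the original prompt for the [prompt]/[PROMPT] placeholder in the refiner/detailer prompt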
+ prompt = prompt.replace('[PROMPT]', orig_prompt)
+ prompt = prompt.replace('[prompt]', orig_prompt)
if len(negative) == 0:
- negative = orig_p.get('all_negative_prompts', [''])[0]
+ negative = orig_negative
+ else:
+ negative = negative.replace('[PROMPT]', orig_negative)
+ negative = negative.replace('[prompt]', orig_negative)
prompt_lines = prompt.split('\n')
negative_lines = negative.split('\n')
prompt = prompt_lines[i % len(prompt_lines)]
diff --git a/modules/processing.py b/modules/processing.py
index 04350ee39..99d0cb351 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -4,7 +4,7 @@
from contextlib import nullcontext
import numpy as np
from PIL import Image, ImageOps
-from modules import shared, devices, errors, images, scripts, memstats, lowvram, script_callbacks, extra_networks, detailer, sd_hijack_freeu, sd_models, sd_vae, processing_helpers, timer, face_restoration
+from modules import shared, devices, errors, images, scripts, memstats, lowvram, script_callbacks, extra_networks, detailer, sd_hijack_freeu, sd_models, sd_checkpoint, sd_vae, processing_helpers, timer, face_restoration, token_merge
from modules.sd_hijack_hypertile import context_hypertile_vae, context_hypertile_unet
from modules.processing_class import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, StableDiffusionProcessingControl # pylint: disable=unused-import
from modules.processing_info import create_infotext
@@ -46,7 +46,8 @@ def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="",
self.width = p.width if hasattr(p, 'width') else (self.images[0].width if len(self.images) > 0 else 0)
self.height = p.height if hasattr(p, 'height') else (self.images[0].height if len(self.images) > 0 else 0)
self.sampler_name = p.sampler_name or ''
- self.cfg_scale = p.cfg_scale or 0
+ self.cfg_scale = p.cfg_scale if p.cfg_scale > 1 else None
+        self.cfg_end = p.cfg_end if p.cfg_end < 1.0 else None
self.image_cfg_scale = p.image_cfg_scale or 0
self.steps = p.steps or 0
self.batch_size = max(1, p.batch_size)
@@ -96,6 +97,7 @@ def js(self):
"height": self.height,
"sampler_name": self.sampler_name,
"cfg_scale": self.cfg_scale,
+ "cfg_end": self.cfg_end,
"steps": self.steps,
"batch_size": self.batch_size,
"detailer": self.detailer,
@@ -136,11 +138,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
processed = None
try:
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
- if p.override_settings.get('sd_model_checkpoint', None) is not None and sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
+ if p.override_settings.get('sd_model_checkpoint', None) is not None and sd_checkpoint.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
shared.log.warning(f"Override not found: checkpoint={p.override_settings.get('sd_model_checkpoint', None)}")
p.override_settings.pop('sd_model_checkpoint', None)
sd_models.reload_model_weights()
- if p.override_settings.get('sd_model_refiner', None) is not None and sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_refiner')) is None:
+ if p.override_settings.get('sd_model_refiner', None) is not None and sd_checkpoint.checkpoint_aliases.get(p.override_settings.get('sd_model_refiner')) is None:
shared.log.warning(f"Override not found: refiner={p.override_settings.get('sd_model_refiner', None)}")
p.override_settings.pop('sd_model_refiner', None)
sd_models.reload_model_weights()
@@ -162,7 +164,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles_to_extra(p)
shared.prompt_styles.extract_comments(p)
if shared.opts.cuda_compile_backend == 'none':
- sd_models.apply_token_merging(p.sd_model)
+ token_merge.apply_token_merging(p.sd_model)
sd_hijack_freeu.apply_freeu(p, not shared.native)
if p.width is not None:
@@ -205,7 +207,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
finally:
pag.unapply()
if shared.opts.cuda_compile_backend == 'none':
- sd_models.remove_token_merging(p.sd_model)
+ token_merge.remove_token_merging(p.sd_model)
script_callbacks.after_process_callback(p)
diff --git a/modules/processing_correction.py b/modules/processing_correction.py
index c52f30ab3..e715d8c49 100644
--- a/modules/processing_correction.py
+++ b/modules/processing_correction.py
@@ -7,7 +7,8 @@
import torch
from modules import shared, sd_vae_taesd, devices
-debug = shared.log.trace if os.environ.get('SD_HDR_DEBUG', None) is not None else lambda *args, **kwargs: None
+debug_enabled = os.environ.get('SD_HDR_DEBUG', None) is not None
+debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
debug('Trace: HDR')
@@ -119,16 +120,18 @@ def correction_callback(p, timestep, kwargs):
if not any([p.hdr_clamp, p.hdr_mode, p.hdr_maximize, p.hdr_sharpen, p.hdr_color, p.hdr_brightness, p.hdr_tint_ratio]):
return kwargs
latents = kwargs["latents"]
- debug('')
- debug(f' Timestep: {timestep}')
+ if debug_enabled:
+ debug('')
+ debug(f' Timestep: {timestep}')
# debug(f'HDR correction: latents={latents.shape}')
if len(latents.shape) == 4: # standard batched latent
for i in range(latents.shape[0]):
latents[i] = correction(p, timestep, latents[i])
- debug(f"Full Mean: {latents[i].mean().item()}")
- debug(f"Channel Means: {latents[i].mean(dim=(-1, -2), keepdim=True).flatten().float().cpu().numpy()}")
- debug(f"Channel Mins: {latents[i].min(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
- debug(f"Channel Maxes: {latents[i].max(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
+ if debug_enabled:
+ debug(f"Full Mean: {latents[i].mean().item()}")
+ debug(f"Channel Means: {latents[i].mean(dim=(-1, -2), keepdim=True).flatten().float().cpu().numpy()}")
+ debug(f"Channel Mins: {latents[i].min(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
+ debug(f"Channel Maxes: {latents[i].max(-1, keepdim=True)[0].min(-2, keepdim=True)[0].flatten().float().cpu().numpy()}")
elif len(latents.shape) == 5 and latents.shape[0] == 1: # probably animatediff
latents = latents.squeeze(0).permute(1, 0, 2, 3)
for i in range(latents.shape[0]):
diff --git a/modules/processing_info.py b/modules/processing_info.py
index 29513167d..e798211b1 100644
--- a/modules/processing_info.py
+++ b/modules/processing_info.py
@@ -41,11 +41,12 @@ def create_infotext(p: StableDiffusionProcessing, all_prompts=None, all_seeds=No
# basic
"Steps": p.steps,
"Seed": all_seeds[index],
- "Sampler": p.sampler_name,
- "CFG scale": p.cfg_scale,
+ "Sampler": p.sampler_name if p.sampler_name != 'Default' else None,
+ "CFG scale": p.cfg_scale if p.cfg_scale > 1.0 else None,
+ "CFG end": p.cfg_end if p.cfg_end < 1.0 else None,
"Size": f"{p.width}x{p.height}" if hasattr(p, 'width') and hasattr(p, 'height') else None,
"Batch": f'{p.n_iter}x{p.batch_size}' if p.n_iter > 1 or p.batch_size > 1 else None,
- "Parser": shared.opts.prompt_attention,
+ "Parser": shared.opts.prompt_attention.split()[0],
"Model": None if (not shared.opts.add_model_name_to_info) or (not shared.sd_model.sd_checkpoint_info.model_name) else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', ''),
"Model hash": getattr(p, 'sd_model_hash', None if (not shared.opts.add_model_hash_to_info) or (not shared.sd_model.sd_model_hash) else shared.sd_model.sd_model_hash),
"VAE": (None if not shared.opts.add_model_name_to_info or sd_vae.loaded_vae_file is None else os.path.splitext(os.path.basename(sd_vae.loaded_vae_file))[0]) if p.full_quality else 'TAESD',
diff --git a/modules/processing_original.py b/modules/processing_original.py
index 852eb9a37..649023aae 100644
--- a/modules/processing_original.py
+++ b/modules/processing_original.py
@@ -1,7 +1,7 @@
import torch
import numpy as np
from PIL import Image
-from modules import shared, devices, processing, images, sd_models, sd_vae, sd_samplers, processing_helpers, prompt_parser
+from modules import shared, devices, processing, images, sd_vae, sd_samplers, processing_helpers, prompt_parser, token_merge
from modules.sd_hijack_hypertile import hypertile_set
@@ -135,10 +135,10 @@ def sample_txt2img(p: processing.StableDiffusionProcessingTxt2Img, conditioning,
p.sampler.initialize(p)
samples = samples[:, :, p.truncate_y//2:samples.shape[2]-(p.truncate_y+1)//2, p.truncate_x//2:samples.shape[3]-(p.truncate_x+1)//2]
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=p)
- sd_models.apply_token_merging(p.sd_model)
+ token_merge.apply_token_merging(p.sd_model)
hypertile_set(p, hr=True)
samples = p.sampler.sample_img2img(p, samples, noise, conditioning, unconditional_conditioning, steps=p.hr_second_pass_steps or p.steps, image_conditioning=image_conditioning)
- sd_models.apply_token_merging(p.sd_model)
+ token_merge.apply_token_merging(p.sd_model)
else:
p.ops.append('upscale')
x = None
diff --git a/modules/rocm.py b/modules/rocm.py
index 831932199..ef76a1cfa 100644
--- a/modules/rocm.py
+++ b/modules/rocm.py
@@ -52,37 +52,49 @@ class MicroArchitecture(Enum):
class Agent:
name: str
+ gfx_version: int
arch: MicroArchitecture
is_apu: bool
if sys.platform != "win32":
blaslt_supported: bool
+ @staticmethod
+ def parse_gfx_version(name: str) -> int:
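+        # parse the hex digits following the 'gfx' prefix, e.g. 'gfx1100' -> 0x1100, 'gfx90a' -> 0x90a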
+ result = 0
+ for i in range(3, len(name)):
+ if name[i].isdigit():
+ result *= 0x10
+ result += ord(name[i]) - 48
+ continue
+ if name[i] in "abcdef":
+ result *= 0x10
+ result += ord(name[i]) - 87
+ continue
+ break
+ return result
+
def __init__(self, name: str):
self.name = name
- gfx = name[3:7]
- if len(gfx) == 4:
+ self.gfx_version = Agent.parse_gfx_version(name)
+ if self.gfx_version > 0x1000:
self.arch = MicroArchitecture.RDNA
- elif gfx in ("908", "90a", "942",):
+ elif self.gfx_version in (0x908, 0x90a, 0x942,):
self.arch = MicroArchitecture.CDNA
else:
self.arch = MicroArchitecture.GCN
- self.is_apu = gfx.startswith("115") or gfx in ("801", "902", "90c", "1013", "1033", "1035", "1036", "1103",)
+ self.is_apu = (self.gfx_version & 0xFFF0 == 0x1150) or self.gfx_version in (0x801, 0x902, 0x90c, 0x1013, 0x1033, 0x1035, 0x1036, 0x1103,)
if sys.platform != "win32":
self.blaslt_supported = os.path.exists(os.path.join(HIPBLASLT_TENSILE_LIBPATH, f"extop_{name}.co"))
def get_gfx_version(self) -> Union[str, None]:
- if self.name.startswith("gfx12"):
+ if self.gfx_version >= 0x1200:
return "12.0.0"
- elif self.name.startswith("gfx11"):
+ elif self.gfx_version >= 0x1100:
return "11.0.0"
- elif self.name.startswith("gfx103"):
+ elif self.gfx_version >= 0x1000:
+ # gfx1010 users had to override gfx version to 10.3.0 in Linux
+ # it is unknown whether overriding is needed in ZLUDA
return "10.3.0"
- elif self.name.startswith("gfx102"):
- return "10.2.0"
- elif self.name.startswith("gfx101"):
- return "10.1.0"
- elif self.name.startswith("gfx100"):
- return "10.0.0"
return None
@@ -198,7 +210,7 @@ def get_flash_attention_command(agent: Agent):
if os.environ.get("FLASH_ATTENTION_USE_TRITON_ROCM", "FALSE") == "TRUE":
return "pytest git+https://github.com/ROCm/flash-attention@micmelesse/upstream_pr"
default = "git+https://github.com/ROCm/flash-attention"
- if agent.arch == MicroArchitecture.RDNA:
+ if agent.gfx_version >= 0x1100:
default = "git+https://github.com/ROCm/flash-attention@howiejay/navi_support"
return os.environ.get("FLASH_ATTENTION_PACKAGE", default)
diff --git a/modules/scripts.py b/modules/scripts.py
index 15da9c070..8a67d0a50 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,6 +3,7 @@
import sys
import time
from collections import namedtuple
+from dataclasses import dataclass
import gradio as gr
from modules import paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer
@@ -23,6 +24,11 @@ def __init__(self, images):
self.images = images
+@dataclass
+class OnComponent:
+ component: gr.blocks.Block
+
+
class Script:
parent = None
name = None
diff --git a/modules/sd_checkpoint.py b/modules/sd_checkpoint.py
new file mode 100644
index 000000000..e1787246f
--- /dev/null
+++ b/modules/sd_checkpoint.py
@@ -0,0 +1,385 @@
+import os
+import re
+import time
+import json
+import collections
+from modules import shared, paths, modelloader, hashes, sd_hijack_accelerate
+
+
+checkpoints_list = {}
+checkpoint_aliases = {}
+checkpoints_loaded = collections.OrderedDict()
+model_dir = "Stable-diffusion"
+model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
+sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
+sd_metadata = None
+sd_metadata_pending = 0
+sd_metadata_timer = 0
+
+
+class CheckpointInfo:
+ def __init__(self, filename, sha=None):
+ self.name = None
+ self.hash = sha
+ self.filename = filename
+ self.type = ''
+ relname = filename
+ app_path = os.path.abspath(paths.script_path)
+
+ def rel(fn, path):
+ try:
+ return os.path.relpath(fn, path)
+ except Exception:
+ return fn
+
+ if relname.startswith('..'):
+ relname = os.path.abspath(relname)
+ if relname.startswith(shared.opts.ckpt_dir):
+ relname = rel(filename, shared.opts.ckpt_dir)
+ elif relname.startswith(shared.opts.diffusers_dir):
+ relname = rel(filename, shared.opts.diffusers_dir)
+ elif relname.startswith(model_path):
+ relname = rel(filename, model_path)
+ elif relname.startswith(paths.script_path):
+ relname = rel(filename, paths.script_path)
+ elif relname.startswith(app_path):
+ relname = rel(filename, app_path)
+ else:
+ relname = os.path.abspath(relname)
+ relname, ext = os.path.splitext(relname)
+ ext = ext.lower()[1:]
+
+ if os.path.isfile(filename): # ckpt or safetensor
+ self.name = relname
+ self.filename = filename
+ self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
+ self.type = ext
+ if 'nf4' in filename:
+ self.type = 'transformer'
+ else: # maybe a diffuser
+ if self.hash is None:
+ repo = [r for r in modelloader.diffuser_repos if self.filename == r['name']]
+ else:
+ repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
+ if len(repo) == 0:
+ self.name = filename
+ self.filename = filename
+ self.sha256 = None
+ self.type = 'unknown'
+ else:
+ self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
+ self.filename = repo[0]['path']
+ self.sha256 = repo[0]['hash']
+ self.type = 'diffusers'
+
+ self.shorthash = self.sha256[0:10] if self.sha256 else None
+ self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
+ self.path = self.filename
+ self.model_name = os.path.basename(self.name)
+ self.metadata = read_metadata_from_safetensors(filename)
+ # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')
+
+ def register(self):
+ checkpoints_list[self.title] = self
+ for i in [self.name, self.filename, self.shorthash, self.title]:
+ if i is not None:
+ checkpoint_aliases[i] = self
+
+ def calculate_shorthash(self):
+ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
+ if self.sha256 is None:
+ return None
+ self.shorthash = self.sha256[0:10]
+ if self.title in checkpoints_list:
+ checkpoints_list.pop(self.title)
+ self.title = f'{self.name} [{self.shorthash}]'
+ self.register()
+ return self.shorthash
+
+ def __str__(self):
+ return f'checkpoint: type={self.type} title="{self.title}" path="{self.path}"'
+
+
+def setup_model():
+ list_models()
+ sd_hijack_accelerate.hijack_hfhub()
+ # sd_hijack_accelerate.hijack_torch_conv()
+ if not shared.native:
+ enable_midas_autodownload()
+
+
+def checkpoint_titles(use_short=False): # pylint: disable=unused-argument
+ def convert(name):
+ return int(name) if name.isdigit() else name.lower()
+ def alphanumeric_key(key):
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+ return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+
+
+def list_models():
+ t0 = time.time()
+ global checkpoints_list # pylint: disable=global-statement
+ checkpoints_list.clear()
+ checkpoint_aliases.clear()
+ ext_filter = [".safetensors"] if shared.opts.sd_disable_ckpt or shared.native else [".ckpt", ".safetensors"]
+ model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"]))
+ for filename in sorted(model_list, key=str.lower):
+ checkpoint_info = CheckpointInfo(filename)
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ if shared.native:
+ for repo in modelloader.load_diffusers_models(clear=True):
+ checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ if shared.cmd_opts.ckpt is not None:
+ if not os.path.exists(shared.cmd_opts.ckpt) and not shared.native:
+ if shared.cmd_opts.ckpt.lower() != "none":
+ shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
+ else:
+ checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
+ if checkpoint_info.name is not None:
+ checkpoint_info.register()
+ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
+ elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
+ shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
+ shared.log.info(f'Available Models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}')
+ checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
+
+def update_model_hashes():
+ txt = []
+ lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
+ # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
+ for ckpt in lst:
+ ckpt.hash = model_hash(ckpt.filename)
+ # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}')
+ # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models')
+ lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
+ shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
+ for ckpt in lst:
+ ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
+ ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
+ if ckpt.sha256 is not None:
+ txt.append(f'Hash: {ckpt.title} {ckpt.shorthash}')
+ txt.append(f'Updated hashes for {len(lst)} out of {len(checkpoints_list)} models')
+    txt = '<br>'.join(txt)
+ return txt
+
+
+def get_closet_checkpoint_match(s: str):
+ if s.startswith('https://huggingface.co/'):
+ s = s.replace('https://huggingface.co/', '')
+ if s.startswith('huggingface/'):
+ model_name = s.replace('huggingface/', '')
+        checkpoint_info = CheckpointInfo(model_name) # create a virtual model info
+ checkpoint_info.type = 'huggingface'
+ return checkpoint_info
+
+ # alias search
+ checkpoint_info = checkpoint_aliases.get(s, None)
+ if checkpoint_info is not None:
+ return checkpoint_info
+
+ # models search
+ found = sorted([info for info in checkpoints_list.values() if os.path.basename(info.title).lower().startswith(s.lower())], key=lambda x: len(x.title))
+ if found and len(found) == 1:
+ return found[0]
+
+ # reference search
+ """
+ found = sorted([info for info in shared.reference_models.values() if os.path.basename(info['path']).lower().startswith(s.lower())], key=lambda x: len(x['path']))
+ if found and len(found) == 1:
+        checkpoint_info = CheckpointInfo(found[0]['path']) # create a virtual model info
+ checkpoint_info.type = 'huggingface'
+ return checkpoint_info
+ """
+
+ # huggingface search
+ if shared.opts.sd_checkpoint_autodownload and s.count('/') == 1:
+ modelloader.hf_login()
+ found = modelloader.find_diffuser(s, full=True)
+ shared.log.info(f'HF search: model="{s}" results={found}')
+ if found is not None and len(found) == 1 and found[0] == s:
+ checkpoint_info = CheckpointInfo(s)
+ checkpoint_info.type = 'huggingface'
+ return checkpoint_info
+
+ # civitai search
+ if shared.opts.sd_checkpoint_autodownload and s.startswith("https://civitai.com/api/download/models"):
+ fn = modelloader.download_civit_model_thread(model_name=None, model_url=s, model_path='', model_type='Model', token=None)
+ if fn is not None:
+ checkpoint_info = CheckpointInfo(fn)
+ return checkpoint_info
+
+ return None
+
+
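# Annotation, not part of the patch: hedged usage sketch of the resolver above,
# assuming a configured SD.Next environment (options and model list populated).
# It accepts a registered title/alias, a "user/repo" id (HuggingFace search when
# auto-download is enabled), or a CivitAI download URL; unresolved strings return None.
from modules.sd_checkpoint import get_closet_checkpoint_match

info = get_closet_checkpoint_match('stabilityai/stable-diffusion-xl-base-1.0')            # HF repo id
info = get_closet_checkpoint_match('https://civitai.com/api/download/models/<model-id>')  # CivitAI URL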
+def model_hash(filename):
+ """old hash that only looks at a small part of the file and is prone to collisions"""
+ try:
+ with open(filename, "rb") as file:
+ import hashlib
+ # t0 = time.time()
+ m = hashlib.sha256()
+ file.seek(0x100000)
+ m.update(file.read(0x10000))
+ shorthash = m.hexdigest()[0:8]
+ # t1 = time.time()
+ # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
+ return shorthash
+ except FileNotFoundError:
+ return 'NOFILE'
+ except Exception:
+ return 'NOHASH'
+
+
+def select_checkpoint(op='model'):
+ if op == 'dict':
+ model_checkpoint = shared.opts.sd_model_dict
+ elif op == 'refiner':
+ model_checkpoint = shared.opts.data.get('sd_model_refiner', None)
+ else:
+ model_checkpoint = shared.opts.sd_model_checkpoint
+ if model_checkpoint is None or model_checkpoint == 'None':
+ return None
+ checkpoint_info = get_closet_checkpoint_match(model_checkpoint)
+ if checkpoint_info is not None:
+ shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
+ return checkpoint_info
+ if len(checkpoints_list) == 0:
+ shared.log.warning("Cannot generate without a checkpoint")
+ shared.log.info("Set system paths to use existing folders")
+ shared.log.info(" or use --models-dir to specify base folder with all models")
+ shared.log.info(" or use --ckpt-dir to specify folder with sd models")
+ shared.log.info(" or use --ckpt to force using specific model")
+ return None
+ # checkpoint_info = next(iter(checkpoints_list.values()))
+ if model_checkpoint is not None:
+ if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
+ shared.log.info(f'Load {op}: search="{model_checkpoint}" not found')
+ else:
+ shared.log.info("Selecting first available checkpoint")
+ # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
+ # shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
+ else:
+ shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
+ return checkpoint_info
+
+
+def read_metadata_from_safetensors(filename):
+ global sd_metadata # pylint: disable=global-statement
+ if sd_metadata is None:
+ sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
+ res = sd_metadata.get(filename, None)
+ if res is not None:
+ return res
+ if not filename.endswith(".safetensors"):
+ return {}
+ if shared.cmd_opts.no_metadata:
+ return {}
+ res = {}
+ # try:
+ t0 = time.time()
+ with open(filename, mode="rb") as file:
+ try:
+ metadata_len = file.read(8)
+ metadata_len = int.from_bytes(metadata_len, "little")
+ json_start = file.read(2)
+ if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
+ shared.log.error(f'Model metadata invalid: file="{filename}"')
+ json_data = json_start + file.read(metadata_len-2)
+ json_obj = json.loads(json_data)
+ for k, v in json_obj.get("__metadata__", {}).items():
+ if v.startswith("data:"):
+ v = 'data'
+ if k == 'format' and v == 'pt':
+ continue
+ large = True if len(v) > 2048 else False
+ if large and k == 'ss_datasets':
+ continue
+ if large and k == 'workflow':
+ continue
+ if large and k == 'prompt':
+ continue
+ if large and k == 'ss_bucket_info':
+ continue
+ if v[0:1] == '{':
+ try:
+ v = json.loads(v)
+ if large and k == 'ss_tag_frequency':
+ v = { i: len(j) for i, j in v.items() }
+ if large and k == 'sd_merge_models':
+ scrub_dict(v, ['sd_merge_recipe'])
+ except Exception:
+ pass
+ res[k] = v
+ except Exception as e:
+ shared.log.error(f'Model metadata: file="{filename}" {e}')
+ sd_metadata[filename] = res
+ global sd_metadata_pending # pylint: disable=global-statement
+ sd_metadata_pending += 1
+ t1 = time.time()
+ global sd_metadata_timer # pylint: disable=global-statement
+ sd_metadata_timer += (t1 - t0)
+ # except Exception as e:
+ # shared.log.error(f"Error reading metadata from: {filename} {e}")
+ return res
+
+
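# Annotation, not part of the patch: the reader above relies on the safetensors
# layout, an 8-byte little-endian length N followed by N bytes of JSON whose
# "__metadata__" entry carries the training/merge metadata. Minimal sketch:
import json

def read_safetensors_header_metadata(path: str) -> dict:
    with open(path, "rb") as f:
        header_len = int.from_bytes(f.read(8), "little")  # JSON header size
        header = json.loads(f.read(header_len))           # tensor index + "__metadata__"
    return header.get("__metadata__", {})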
+def enable_midas_autodownload():
+ """
+ Gives the ldm.modules.midas.api.load_model function automatic downloading.
+
+ When the 512-depth-ema model, and other future models like it, is loaded,
+ it calls midas.api.load_model to load the associated midas depth model.
+ This function applies a wrapper to download the model to the correct
+ location automatically.
+ """
+ from urllib import request
+ import ldm.modules.midas.api
+ midas_path = os.path.join(paths.models_path, 'midas')
+ for k, v in ldm.modules.midas.api.ISL_PATHS.items():
+ file_name = os.path.basename(v)
+ ldm.modules.midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
+ midas_urls = {
+ "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
+ "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
+ }
+ ldm.modules.midas.api.load_model_inner = ldm.modules.midas.api.load_model
+
+ def load_model_wrapper(model_type):
+ path = ldm.modules.midas.api.ISL_PATHS[model_type]
+ if not os.path.exists(path):
+ if not os.path.exists(midas_path):
+ os.mkdir(midas_path)
+ shared.log.info(f"Downloading midas model weights for {model_type} to {path}")
+ request.urlretrieve(midas_urls[model_type], path)
+ shared.log.info(f"{model_type} downloaded")
+ return ldm.modules.midas.api.load_model_inner(model_type)
+
+ ldm.modules.midas.api.load_model = load_model_wrapper
+
+
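# Annotation, not part of the patch: the same wrap-and-delegate pattern in
# isolation, replacing a loader with a closure that fetches missing files first
# and then calls the original (all names below are illustrative only).
import os
from urllib import request

def with_autodownload(load_fn, urls: dict, cache_dir: str):
    def wrapper(name):
        path = os.path.join(cache_dir, os.path.basename(urls[name]))
        if not os.path.exists(path):
            os.makedirs(cache_dir, exist_ok=True)
            request.urlretrieve(urls[name], path)  # download on first use
        return load_fn(name)
    return wrapper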
+def scrub_dict(dict_obj, keys):
+ for key in list(dict_obj.keys()):
+ if not isinstance(dict_obj, dict):
+ continue
+ if key in keys:
+ dict_obj.pop(key, None)
+ elif isinstance(dict_obj[key], dict):
+ scrub_dict(dict_obj[key], keys)
+ elif isinstance(dict_obj[key], list):
+ for item in dict_obj[key]:
+ scrub_dict(item, keys)
+
+
+def write_metadata():
+ global sd_metadata_pending # pylint: disable=global-statement
+ if sd_metadata_pending == 0:
+ shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
+ return
+ shared.writefile(sd_metadata, sd_metadata_file)
+ shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
+ sd_metadata_pending = 0
diff --git a/modules/sd_detect.py b/modules/sd_detect.py
new file mode 100644
index 000000000..7144a7be7
--- /dev/null
+++ b/modules/sd_detect.py
@@ -0,0 +1,150 @@
+import os
+import torch
+import diffusers
+from modules import shared, shared_items, devices, errors
+
+
+debug_load = os.environ.get('SD_LOAD_DEBUG', None)
+
+
+def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
+ guess = shared.opts.diffusers_pipeline
+ warn = shared.log.warning if warning else lambda *args, **kwargs: None
+ size = 0
+ pipeline = None
+ if guess == 'Autodetect':
+ try:
+ guess = 'Stable Diffusion XL' if 'XL' in f.upper() else 'Stable Diffusion'
+ # guess by size
+ if os.path.isfile(f) and f.endswith('.safetensors'):
+ size = round(os.path.getsize(f) / 1024 / 1024)
+ if (size > 0 and size < 128):
+ warn(f'Model size smaller than expected: {f} size={size} MB')
+ elif (size >= 316 and size <= 324) or (size >= 156 and size <= 164): # 320 or 160
+ warn(f'Model detected as VAE model, but attempting to load as model: {op}={f} size={size} MB')
+ guess = 'VAE'
+ elif (size >= 4970 and size <= 4976): # 4973
+ guess = 'Stable Diffusion 2' # SD v2 but could be eps or v-prediction
+ # elif size < 0: # unknown
+ # guess = 'Stable Diffusion 2B'
+ elif (size >= 5791 and size <= 5799): # 5795
+ if op == 'model':
+ warn(f'Model detected as SD-XL refiner model, but attempting to load a base model: {op}={f} size={size} MB')
+ guess = 'Stable Diffusion XL Refiner'
+ elif (size >= 6611 and size <= 7220): # 6617, HassakuXL is 6776, monkrenRealisticINT_v10 is 7217
+ guess = 'Stable Diffusion XL'
+ elif (size >= 3361 and size <= 3369): # 3368
+ guess = 'Stable Diffusion Upscale'
+ elif (size >= 4891 and size <= 4899): # 4897
+ guess = 'Stable Diffusion XL Inpaint'
+ elif (size >= 9791 and size <= 9799): # 9794
+ guess = 'Stable Diffusion XL Instruct'
+ elif (size > 3138 and size < 3142): #3140
+ guess = 'Stable Diffusion XL'
+ elif (size > 5692 and size < 5698) or (size > 4134 and size < 4138) or (size > 10362 and size < 10366) or (size > 15028 and size < 15228):
+ guess = 'Stable Diffusion 3'
+ elif (size > 18414 and size < 18420): # sd35-large aio
+ guess = 'Stable Diffusion 3'
+ elif (size > 20000 and size < 40000):
+ guess = 'FLUX'
+ # guess by name
+ """
+ if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper():
+ if shared.backend == shared.Backend.ORIGINAL:
+ warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB')
+ guess = 'Latent Consistency Model'
+ """
+ if 'instaflow' in f.lower():
+ guess = 'InstaFlow'
+ if 'segmoe' in f.lower():
+ guess = 'SegMoE'
+ if 'hunyuandit' in f.lower():
+ guess = 'HunyuanDiT'
+ if 'pixart-xl' in f.lower():
+ guess = 'PixArt-Alpha'
+ if 'stable-diffusion-3' in f.lower():
+ guess = 'Stable Diffusion 3'
+ if 'stable-cascade' in f.lower() or 'stablecascade' in f.lower() or 'wuerstchen3' in f.lower() or ('sotediffusion' in f.lower() and "v2" in f.lower()):
+ if devices.dtype == torch.float16:
+ warn('Stable Cascade does not support Float16')
+ guess = 'Stable Cascade'
+ if 'pixart-sigma' in f.lower():
+ guess = 'PixArt-Sigma'
+ if 'lumina-next' in f.lower():
+ guess = 'Lumina-Next'
+ if 'kolors' in f.lower():
+ guess = 'Kolors'
+ if 'auraflow' in f.lower():
+ guess = 'AuraFlow'
+ if 'cogview' in f.lower():
+ guess = 'CogView'
+ if 'meissonic' in f.lower():
+ guess = 'Meissonic'
+ pipeline = 'custom'
+ if 'omnigen' in f.lower():
+ guess = 'OmniGen'
+ pipeline = 'custom'
+ if 'flux' in f.lower():
+ guess = 'FLUX'
+ if size > 11000 and size < 20000:
+ warn(f'Model detected as FLUX UNET model, but attempting to load a base model: {op}={f} size={size} MB')
+ # switch for specific variant
+ if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
+ guess = 'Stable Diffusion Inpaint'
+ elif guess == 'Stable Diffusion' and 'instruct' in f.lower():
+ guess = 'Stable Diffusion Instruct'
+ if guess == 'Stable Diffusion XL' and 'inpaint' in f.lower():
+ guess = 'Stable Diffusion XL Inpaint'
+ elif guess == 'Stable Diffusion XL' and 'instruct' in f.lower():
+ guess = 'Stable Diffusion XL Instruct'
+ # get actual pipeline
+ pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
+ if not quiet:
+ shared.log.info(f'Autodetect {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
+ except Exception as e:
+ shared.log.error(f'Autodetect {op}: file="{f}" {e}')
+ if debug_load:
+ errors.display(e, f'Load {op}: {f}')
+ return None, None
+ else:
+ try:
+ size = round(os.path.getsize(f) / 1024 / 1024)
+ pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
+ if not quiet:
+ shared.log.info(f'Load {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
+ except Exception as e:
+ shared.log.error(f'Load {op}: detect="{guess}" file="{f}" {e}')
+
+ if pipeline is None:
+ shared.log.warning(f'Load {op}: detect="{guess}" file="{f}" size={size} not recognized')
+ pipeline = diffusers.StableDiffusionPipeline
+ return pipeline, guess
+
+
+def get_load_config(model_file, model_type, config_type='yaml'):
+ if config_type == 'yaml':
+ yaml = os.path.splitext(model_file)[0] + '.yaml'
+ if os.path.exists(yaml):
+ return yaml
+ if model_type == 'Stable Diffusion':
+ return 'configs/v1-inference.yaml'
+ if model_type == 'Stable Diffusion XL':
+ return 'configs/sd_xl_base.yaml'
+ if model_type == 'Stable Diffusion XL Refiner':
+ return 'configs/sd_xl_refiner.yaml'
+ if model_type == 'Stable Diffusion 2':
+        return None # don't know if it's eps or v-prediction, so let diffusers sort it out
+ # return 'configs/v2-inference-512-base.yaml'
+ # return 'configs/v2-inference-768-v.yaml'
+ elif config_type == 'json':
+ if not shared.opts.diffuser_cache_config:
+ return None
+ if model_type == 'Stable Diffusion':
+ return 'configs/sd15'
+ if model_type == 'Stable Diffusion XL':
+ return 'configs/sdxl'
+ if model_type == 'Stable Diffusion 3':
+ return 'configs/sd3'
+ if model_type == 'FLUX':
+ return 'configs/flux'
+ return None
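The autodetect path above guesses the pipeline from the safetensors file size in megabytes before falling back to filename hints. A simplified sketch of that idea using a couple of the ranges from the new module (illustrative only, not the full table):

    import os

    def guess_by_size(path: str) -> str:
        size_mb = round(os.path.getsize(path) / 1024 / 1024)
        if 6611 <= size_mb <= 7220:
            return 'Stable Diffusion XL'
        if 20000 < size_mb < 40000:
            return 'FLUX'
        return 'Stable Diffusion'  # default guess when no range matches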
diff --git a/modules/sd_hijack_dynamic_atten.py b/modules/sd_hijack_dynamic_atten.py
index cb64482a5..de39a966f 100644
--- a/modules/sd_hijack_dynamic_atten.py
+++ b/modules/sd_hijack_dynamic_atten.py
@@ -57,11 +57,11 @@ def sliced_scaled_dot_product_attention(query, key, value, attn_mask=None, dropo
if do_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
- if attn_mask is not None and attn_mask.shape != query.shape:
+ if attn_mask is not None and attn_mask.shape[:-1] != query.shape[:-1]:
if len(query.shape) == 4:
- attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2], 1))
+ attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
else:
- attn_mask = attn_mask.repeat((batch_size_attention // attn_mask.shape[0], query_tokens // attn_mask.shape[1], shape_three // attn_mask.shape[2]))
+ attn_mask = attn_mask.expand((query.shape[0], query.shape[1], key.shape[-2]))
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
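The attention-mask change swaps `repeat` for `expand`: `expand` broadcasts size-1 dimensions as a view without allocating new memory, while `repeat` materializes the copies. A small hedged illustration of the difference:

    import torch

    mask = torch.zeros(1, 1, 77)          # e.g. a single attention mask
    viewed = mask.expand(8, 4096, 77)     # broadcast view, no extra allocation
    copied = mask.repeat(8, 4096, 1)      # real copy: 8*4096*77 elements
    assert viewed.shape == copied.shape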
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 71a389c8a..734f4ab57 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,16 +1,12 @@
-import re
import io
import sys
-import json
import time
+import json
import copy
import inspect
import logging
import contextlib
-import collections
import os.path
-from os import mkdir
-from urllib import request
from enum import Enum
import diffusers
import diffusers.loaders.single_file_utils
@@ -18,20 +14,16 @@
import torch
import safetensors.torch
from omegaconf import OmegaConf
-from transformers import logging as transformers_logging
from ldm.util import instantiate_from_config
-from modules import paths, shared, shared_items, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, hashes, sd_models_config, sd_models_compile, sd_hijack_accelerate
+from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
from modules.timer import Timer
from modules.memstats import memory_stats
from modules.modeldata import model_data
+from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import
-transformers_logging.set_verbosity_error()
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
-checkpoints_list = {}
-checkpoint_aliases = {}
-checkpoints_loaded = collections.OrderedDict()
sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
sd_metadata = None
sd_metadata_pending = 0
@@ -40,85 +32,7 @@
debug_load = os.environ.get('SD_LOAD_DEBUG', None)
debug_process = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None
diffusers_version = int(diffusers.__version__.split('.')[1])
-
-
-class CheckpointInfo:
- def __init__(self, filename, sha=None):
- self.name = None
- self.hash = sha
- self.filename = filename
- self.type = ''
- relname = filename
- app_path = os.path.abspath(paths.script_path)
-
- def rel(fn, path):
- try:
- return os.path.relpath(fn, path)
- except Exception:
- return fn
-
- if relname.startswith('..'):
- relname = os.path.abspath(relname)
- if relname.startswith(shared.opts.ckpt_dir):
- relname = rel(filename, shared.opts.ckpt_dir)
- elif relname.startswith(shared.opts.diffusers_dir):
- relname = rel(filename, shared.opts.diffusers_dir)
- elif relname.startswith(model_path):
- relname = rel(filename, model_path)
- elif relname.startswith(paths.script_path):
- relname = rel(filename, paths.script_path)
- elif relname.startswith(app_path):
- relname = rel(filename, app_path)
- else:
- relname = os.path.abspath(relname)
- relname, ext = os.path.splitext(relname)
- ext = ext.lower()[1:]
-
- if os.path.isfile(filename): # ckpt or safetensor
- self.name = relname
- self.filename = filename
- self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
- self.type = ext
- if 'nf4' in filename:
- self.type = 'transformer'
- else: # maybe a diffuser
- if self.hash is None:
- repo = [r for r in modelloader.diffuser_repos if self.filename == r['name']]
- else:
- repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
- if len(repo) == 0:
- self.name = filename
- self.filename = filename
- self.sha256 = None
- self.type = 'unknown'
- else:
- self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
- self.filename = repo[0]['path']
- self.sha256 = repo[0]['hash']
- self.type = 'diffusers'
-
- self.shorthash = self.sha256[0:10] if self.sha256 else None
- self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
- self.path = self.filename
- self.model_name = os.path.basename(self.name)
- self.metadata = read_metadata_from_safetensors(filename)
- # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')
-
- def register(self):
- checkpoints_list[self.title] = self
- for i in [self.name, self.filename, self.shorthash, self.title]:
- if i is not None:
- checkpoint_aliases[i] = self
-
- def calculate_shorthash(self):
- self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
- if self.sha256 is None:
- return None
- self.shorthash = self.sha256[0:10]
- checkpoints_list.pop(self.title)
- self.title = f'{self.name} [{self.shorthash}]'
- self.register()
- return self.shorthash
+checkpoint_tiles = checkpoint_titles # legacy compatibility
class NoWatermark:
@@ -126,262 +40,6 @@ def apply_watermark(self, img):
return img
-def setup_model():
- list_models()
- sd_hijack_accelerate.hijack_hfhub()
- # sd_hijack_accelerate.hijack_torch_conv()
- if not shared.native:
- enable_midas_autodownload()
-
-
-def checkpoint_tiles(use_short=False): # pylint: disable=unused-argument
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
-
-
-def list_models():
- t0 = time.time()
- global checkpoints_list # pylint: disable=global-statement
- checkpoints_list.clear()
- checkpoint_aliases.clear()
- ext_filter = [".safetensors"] if shared.opts.sd_disable_ckpt or shared.native else [".ckpt", ".safetensors"]
- model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"]))
- for filename in sorted(model_list, key=str.lower):
- checkpoint_info = CheckpointInfo(filename)
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- if shared.native:
- for repo in modelloader.load_diffusers_models(clear=True):
- checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- if shared.cmd_opts.ckpt is not None:
- if not os.path.exists(shared.cmd_opts.ckpt) and not shared.native:
- if shared.cmd_opts.ckpt.lower() != "none":
- shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
- else:
- checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
- if checkpoint_info.name is not None:
- checkpoint_info.register()
- shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
- shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
- shared.log.info(f'Available Models: path="{shared.opts.ckpt_dir}" items={len(checkpoints_list)} time={time.time()-t0:.2f}')
- checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
-
-
-def update_model_hashes():
- txt = []
- lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
- # shared.log.info(f'Models list: short hash missing for {len(lst)} out of {len(checkpoints_list)} models')
- for ckpt in lst:
- ckpt.hash = model_hash(ckpt.filename)
- # txt.append(f'Calculated short hash: {ckpt.title} {ckpt.hash}')
- # txt.append(f'Updated short hashes for {len(lst)} out of {len(checkpoints_list)} models')
- lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
- shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
- for ckpt in lst:
- ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
- ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
- if ckpt.sha256 is not None:
- txt.append(f'Calculated full hash: {ckpt.title} {ckpt.shorthash}')
- else:
- txt.append(f'Skipped hash calculation: {ckpt.title}')
- txt.append(f'Updated hashes for {len(lst)} out of {len(checkpoints_list)} models')
-    txt = '<br>'.join(txt)
- return txt
-
-
-def get_closet_checkpoint_match(search_string):
- if search_string.startswith('huggingface/'):
- model_name = search_string.replace('huggingface/', '')
- checkpoint_info = CheckpointInfo(model_name) # create a virutal model info
- checkpoint_info.type = 'huggingface'
- return checkpoint_info
- checkpoint_info = checkpoint_aliases.get(search_string, None)
- if checkpoint_info is not None:
- return checkpoint_info
- found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
- if found and len(found) > 0:
- return found[0]
- found = sorted([info for info in checkpoints_list.values() if search_string.split(' ')[0] in info.title], key=lambda x: len(x.title))
- if found and len(found) > 0:
- return found[0]
- for v in shared.reference_models.values():
- pth = v['path'].split('@')[-1]
- if search_string in pth or os.path.basename(search_string) in pth:
- model_name = search_string.replace('huggingface/', '')
- checkpoint_info = CheckpointInfo(v['path']) # create a virutal model info
- checkpoint_info.type = 'huggingface'
- return checkpoint_info
- return None
-
-
-def model_hash(filename):
- """old hash that only looks at a small part of the file and is prone to collisions"""
- try:
- with open(filename, "rb") as file:
- import hashlib
- # t0 = time.time()
- m = hashlib.sha256()
- file.seek(0x100000)
- m.update(file.read(0x10000))
- shorthash = m.hexdigest()[0:8]
- # t1 = time.time()
- # shared.log.debug(f'Calculating short hash: {filename} hash={shorthash} time={(t1-t0):.2f}')
- return shorthash
- except FileNotFoundError:
- return 'NOFILE'
- except Exception:
- return 'NOHASH'
-
-
-def select_checkpoint(op='model'):
- if op == 'dict':
- model_checkpoint = shared.opts.sd_model_dict
- elif op == 'refiner':
- model_checkpoint = shared.opts.data.get('sd_model_refiner', None)
- else:
- model_checkpoint = shared.opts.sd_model_checkpoint
- if model_checkpoint is None or model_checkpoint == 'None':
- return None
- checkpoint_info = get_closet_checkpoint_match(model_checkpoint)
- if checkpoint_info is not None:
- shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
- return checkpoint_info
- if len(checkpoints_list) == 0:
- shared.log.warning("Cannot generate without a checkpoint")
- shared.log.info("Set system paths to use existing folders")
- shared.log.info(" or use --models-dir to specify base folder with all models")
- shared.log.info(" or use --ckpt-dir to specify folder with sd models")
- shared.log.info(" or use --ckpt to force using specific model")
- return None
- # checkpoint_info = next(iter(checkpoints_list.values()))
- if model_checkpoint is not None:
- if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
- shared.log.warning(f'Load {op}: select="{model_checkpoint}" not found')
- else:
- shared.log.info("Selecting first available checkpoint")
- # shared.log.warning(f"Loading fallback checkpoint: {checkpoint_info.title}")
- # shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
- else:
- shared.log.info(f'Load {op}: select="{checkpoint_info.title if checkpoint_info is not None else None}"')
- return checkpoint_info
-
-
-checkpoint_dict_replacements = {
- 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
- 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
- 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
-}
-
-
-def transform_checkpoint_dict_key(k):
- for text, replacement in checkpoint_dict_replacements.items():
- if k.startswith(text):
- k = replacement + k[len(text):]
- return k
-
-
-def get_state_dict_from_checkpoint(pl_sd):
- pl_sd = pl_sd.pop("state_dict", pl_sd)
- pl_sd.pop("state_dict", None)
- sd = {}
- for k, v in pl_sd.items():
- new_key = transform_checkpoint_dict_key(k)
- if new_key is not None:
- sd[new_key] = v
- pl_sd.clear()
- pl_sd.update(sd)
- return pl_sd
-
-
-def write_metadata():
- global sd_metadata_pending # pylint: disable=global-statement
- if sd_metadata_pending == 0:
- shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
- return
- shared.writefile(sd_metadata, sd_metadata_file)
- shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
- sd_metadata_pending = 0
-
-
-def scrub_dict(dict_obj, keys):
- for key in list(dict_obj.keys()):
- if not isinstance(dict_obj, dict):
- continue
- if key in keys:
- dict_obj.pop(key, None)
- elif isinstance(dict_obj[key], dict):
- scrub_dict(dict_obj[key], keys)
- elif isinstance(dict_obj[key], list):
- for item in dict_obj[key]:
- scrub_dict(item, keys)
-
-
-def read_metadata_from_safetensors(filename):
- global sd_metadata # pylint: disable=global-statement
- if sd_metadata is None:
- sd_metadata = shared.readfile(sd_metadata_file, lock=True) if os.path.isfile(sd_metadata_file) else {}
- res = sd_metadata.get(filename, None)
- if res is not None:
- return res
- if not filename.endswith(".safetensors"):
- return {}
- if shared.cmd_opts.no_metadata:
- return {}
- res = {}
- # try:
- t0 = time.time()
- with open(filename, mode="rb") as file:
- try:
- metadata_len = file.read(8)
- metadata_len = int.from_bytes(metadata_len, "little")
- json_start = file.read(2)
- if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
- shared.log.error(f'Model metadata invalid: file="{filename}"')
- json_data = json_start + file.read(metadata_len-2)
- json_obj = json.loads(json_data)
- for k, v in json_obj.get("__metadata__", {}).items():
- if v.startswith("data:"):
- v = 'data'
- if k == 'format' and v == 'pt':
- continue
- large = True if len(v) > 2048 else False
- if large and k == 'ss_datasets':
- continue
- if large and k == 'workflow':
- continue
- if large and k == 'prompt':
- continue
- if large and k == 'ss_bucket_info':
- continue
- if v[0:1] == '{':
- try:
- v = json.loads(v)
- if large and k == 'ss_tag_frequency':
- v = { i: len(j) for i, j in v.items() }
- if large and k == 'sd_merge_models':
- scrub_dict(v, ['sd_merge_recipe'])
- except Exception:
- pass
- res[k] = v
- except Exception as e:
- shared.log.error(f'Model metadata: file="{filename}" {e}')
- sd_metadata[filename] = res
- global sd_metadata_pending # pylint: disable=global-statement
- sd_metadata_pending += 1
- t1 = time.time()
- global sd_metadata_timer # pylint: disable=global-statement
- sd_metadata_timer += (t1 - t0)
- # except Exception as e:
- # shared.log.error(f"Error reading metadata from: {filename} {e}")
- return res
-
-
def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument
if not os.path.isfile(checkpoint_file):
shared.log.error(f'Load dict: path="{checkpoint_file}" not a file')
@@ -427,26 +85,55 @@ def get_safetensor_keys(filename):
return keys
+def get_state_dict_from_checkpoint(pl_sd):
+ checkpoint_dict_replacements = {
+ 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+ 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+ 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+ }
+
+ def transform_checkpoint_dict_key(k):
+ for text, replacement in checkpoint_dict_replacements.items():
+ if k.startswith(text):
+ k = replacement + k[len(text):]
+ return k
+
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
+ pl_sd.pop("state_dict", None)
+ sd = {}
+ for k, v in pl_sd.items():
+ new_key = transform_checkpoint_dict_key(k)
+ if new_key is not None:
+ sd[new_key] = v
+ pl_sd.clear()
+ pl_sd.update(sd)
+ return pl_sd
+
+
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
if not os.path.isfile(checkpoint_info.filename):
return None
+ """
if checkpoint_info in checkpoints_loaded:
shared.log.info("Load model: cache")
checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache
return checkpoints_loaded[checkpoint_info]
+ """
res = read_state_dict(checkpoint_info.filename, what='model')
+ """
if shared.opts.sd_checkpoint_cache > 0 and not shared.native:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = res
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
+ """
timer.record("load")
return res
def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer):
- _pipeline, _model_type = detect_pipeline(checkpoint_info.path, 'model')
+ _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model')
shared.log.debug(f'Load model: memory={memory_stats()}')
timer.record("hash")
if model_data.sd_dict == 'None':
@@ -498,41 +185,6 @@ def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo,
return True
-def enable_midas_autodownload():
- """
- Gives the ldm.modules.midas.api.load_model function automatic downloading.
-
- When the 512-depth-ema model, and other future models like it, is loaded,
- it calls midas.api.load_model to load the associated midas depth model.
- This function applies a wrapper to download the model to the correct
- location automatically.
- """
- import ldm.modules.midas.api
- midas_path = os.path.join(paths.models_path, 'midas')
- for k, v in ldm.modules.midas.api.ISL_PATHS.items():
- file_name = os.path.basename(v)
- ldm.modules.midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)
- midas_urls = {
- "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
- "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
- "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
- "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
- }
- ldm.modules.midas.api.load_model_inner = ldm.modules.midas.api.load_model
-
- def load_model_wrapper(model_type):
- path = ldm.modules.midas.api.ISL_PATHS[model_type]
- if not os.path.exists(path):
- if not os.path.exists(midas_path):
- mkdir(midas_path)
- shared.log.info(f"Downloading midas model weights for {model_type} to {path}")
- request.urlretrieve(midas_urls[model_type], path)
- shared.log.info(f"{model_type} downloaded")
- return ldm.modules.midas.api.load_model_inner(model_type)
-
- ldm.modules.midas.api.load_model = load_model_wrapper
-
-
def repair_config(sd_config):
if "use_ema" not in sd_config.model.params:
sd_config.model.params.use_ema = False
@@ -558,7 +210,6 @@ def change_backend():
unload_model_weights()
shared.backend = shared.Backend.ORIGINAL if shared.opts.sd_backend == 'original' else shared.Backend.DIFFUSERS
shared.native = shared.backend == shared.Backend.DIFFUSERS
- checkpoints_loaded.clear()
from modules.sd_samplers import list_samplers
list_samplers()
list_models()
@@ -566,118 +217,6 @@ def change_backend():
refresh_vae_list()
-def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False):
- guess = shared.opts.diffusers_pipeline
- warn = shared.log.warning if warning else lambda *args, **kwargs: None
- size = 0
- pipeline = None
- if guess == 'Autodetect':
- try:
- guess = 'Stable Diffusion XL' if 'XL' in f.upper() else 'Stable Diffusion'
- # guess by size
- if os.path.isfile(f) and f.endswith('.safetensors'):
- size = round(os.path.getsize(f) / 1024 / 1024)
- if (size > 0 and size < 128):
- warn(f'Model size smaller than expected: {f} size={size} MB')
- elif (size >= 316 and size <= 324) or (size >= 156 and size <= 164): # 320 or 160
- warn(f'Model detected as VAE model, but attempting to load as model: {op}={f} size={size} MB')
- guess = 'VAE'
- elif (size >= 4970 and size <= 4976): # 4973
- guess = 'Stable Diffusion 2' # SD v2 but could be eps or v-prediction
- # elif size < 0: # unknown
- # guess = 'Stable Diffusion 2B'
- elif (size >= 5791 and size <= 5799): # 5795
- if op == 'model':
- warn(f'Model detected as SD-XL refiner model, but attempting to load a base model: {op}={f} size={size} MB')
- guess = 'Stable Diffusion XL Refiner'
- elif (size >= 6611 and size <= 7220): # 6617, HassakuXL is 6776, monkrenRealisticINT_v10 is 7217
- guess = 'Stable Diffusion XL'
- elif (size >= 3361 and size <= 3369): # 3368
- guess = 'Stable Diffusion Upscale'
- elif (size >= 4891 and size <= 4899): # 4897
- guess = 'Stable Diffusion XL Inpaint'
- elif (size >= 9791 and size <= 9799): # 9794
- guess = 'Stable Diffusion XL Instruct'
- elif (size > 3138 and size < 3142): #3140
- guess = 'Stable Diffusion XL'
- elif (size > 5692 and size < 5698) or (size > 4134 and size < 4138) or (size > 10362 and size < 10366) or (size > 15028 and size < 15228):
- guess = 'Stable Diffusion 3'
- elif (size > 20000 and size < 40000):
- guess = 'FLUX'
- # guess by name
- """
- if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper():
- if shared.backend == shared.Backend.ORIGINAL:
- warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB')
- guess = 'Latent Consistency Model'
- """
- if 'instaflow' in f.lower():
- guess = 'InstaFlow'
- if 'segmoe' in f.lower():
- guess = 'SegMoE'
- if 'hunyuandit' in f.lower():
- guess = 'HunyuanDiT'
- if 'pixart-xl' in f.lower():
- guess = 'PixArt-Alpha'
- if 'stable-diffusion-3' in f.lower():
- guess = 'Stable Diffusion 3'
- if 'stable-cascade' in f.lower() or 'stablecascade' in f.lower() or 'wuerstchen3' in f.lower() or ('sotediffusion' in f.lower() and "v2" in f.lower()):
- if devices.dtype == torch.float16:
- warn('Stable Cascade does not support Float16')
- guess = 'Stable Cascade'
- if 'pixart-sigma' in f.lower():
- guess = 'PixArt-Sigma'
- if 'lumina-next' in f.lower():
- guess = 'Lumina-Next'
- if 'kolors' in f.lower():
- guess = 'Kolors'
- if 'auraflow' in f.lower():
- guess = 'AuraFlow'
- if 'cogview' in f.lower():
- guess = 'CogView'
- if 'meissonic' in f.lower():
- guess = 'Meissonic'
- pipeline = 'custom'
- if 'omnigen' in f.lower():
- guess = 'OmniGen'
- pipeline = 'custom'
- if 'flux' in f.lower():
- guess = 'FLUX'
- if size > 11000 and size < 20000:
- warn(f'Model detected as FLUX UNET model, but attempting to load a base model: {op}={f} size={size} MB')
- # switch for specific variant
- if guess == 'Stable Diffusion' and 'inpaint' in f.lower():
- guess = 'Stable Diffusion Inpaint'
- elif guess == 'Stable Diffusion' and 'instruct' in f.lower():
- guess = 'Stable Diffusion Instruct'
- if guess == 'Stable Diffusion XL' and 'inpaint' in f.lower():
- guess = 'Stable Diffusion XL Inpaint'
- elif guess == 'Stable Diffusion XL' and 'instruct' in f.lower():
- guess = 'Stable Diffusion XL Instruct'
- # get actual pipeline
- pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
- if not quiet:
- shared.log.info(f'Autodetect {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
- except Exception as e:
- shared.log.error(f'Autodetect {op}: file="{f}" {e}')
- if debug_load:
- errors.display(e, f'Load {op}: {f}')
- return None, None
- else:
- try:
- size = round(os.path.getsize(f) / 1024 / 1024)
- pipeline = shared_items.get_pipelines().get(guess, None) if pipeline is None else pipeline
- if not quiet:
- shared.log.info(f'Load {op}: detect="{guess}" class={getattr(pipeline, "__name__", None)} file="{f}" size={size}MB')
- except Exception as e:
- shared.log.error(f'Load {op}: detect="{guess}" file="{f}" {e}')
-
- if pipeline is None:
- shared.log.warning(f'Load {op}: detect="{guess}" file="{f}" size={size} not recognized')
- pipeline = diffusers.StableDiffusionPipeline
- return pipeline, guess
-
-
def copy_diffuser_options(new_pipe, orig_pipe):
new_pipe.sd_checkpoint_info = getattr(orig_pipe, 'sd_checkpoint_info', None)
new_pipe.sd_model_checkpoint = getattr(orig_pipe, 'sd_model_checkpoint', None)
@@ -836,6 +375,9 @@ def set_diffuser_offload(sd_model, op: str = 'model'):
def apply_balanced_offload(sd_model):
from accelerate import infer_auto_device_map, dispatch_model
from accelerate.hooks import add_hook_to_module, remove_hook_from_module, ModelHook
+ excluded = ['OmniGenPipeline']
+ if sd_model.__class__.__name__ in excluded:
+ return sd_model
class dispatch_from_cpu_hook(ModelHook):
def init_hook(self, module):
@@ -915,7 +457,6 @@ def move_model(model, device=None, force=False):
if hasattr(model.vae, '_hf_hook'):
debug_move(f'Model move: to={device} class={model.vae.__class__} fn={fn}') # pylint: disable=protected-access
model.vae._hf_hook.execution_device = device # pylint: disable=protected-access
- debug_move(f'Model move: device={device} class={model.__class__} accelerate={getattr(model, "has_accelerate", False)} fn={fn}') # pylint: disable=protected-access
if hasattr(model, "components"): # accelerate patch
for name, m in model.components.items():
if not hasattr(m, "_hf_hook"): # not accelerate hook
@@ -934,8 +475,9 @@ def move_model(model, device=None, force=False):
if hasattr(model, "device") and devices.normalize_device(model.device) == devices.normalize_device(device):
return
try:
+ t0 = time.time()
try:
- model.to(device)
+ model.to(device, non_blocking=True)
if hasattr(model, "prior_pipe"):
model.prior_pipe.to(device)
except Exception as e0:
@@ -955,8 +497,12 @@ def move_model(model, device=None, force=False):
pass # ignore model move if sequential offload is enabled
else:
raise e0
+ t1 = time.time()
except Exception as e1:
+ t1 = time.time()
shared.log.error(f'Model move: device={device} {e1}')
+ if os.environ.get('SD_MOVE_DEBUG', None) or (t1-t0) > 0.1:
+ shared.log.debug(f'Model move: device={device} class={model.__class__.__name__} accelerate={getattr(model, "has_accelerate", False)} fn={fn} time={t1-t0:.2f}') # pylint: disable=protected-access
devices.torch_gc()
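The `move_model` change adds `non_blocking=True` and only logs moves that exceed 0.1s (or when `SD_MOVE_DEBUG` is set). A minimal sketch of that timing pattern outside SD.Next (function name is illustrative):

    import time
    import torch

    def timed_move(model: torch.nn.Module, device, threshold: float = 0.1):
        t0 = time.time()
        model.to(device, non_blocking=True)  # asynchronous copy where the backend allows it
        elapsed = time.time() - t0
        if elapsed > threshold:
            print(f'slow model move: device={device} time={elapsed:.2f}s')
        return model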
@@ -975,35 +521,6 @@ def move_base(model, device):
return R
-def get_load_config(model_file, model_type, config_type='yaml'):
- if config_type == 'yaml':
- yaml = os.path.splitext(model_file)[0] + '.yaml'
- if os.path.exists(yaml):
- return yaml
- if model_type == 'Stable Diffusion':
- return 'configs/v1-inference.yaml'
- if model_type == 'Stable Diffusion XL':
- return 'configs/sd_xl_base.yaml'
- if model_type == 'Stable Diffusion XL Refiner':
- return 'configs/sd_xl_refiner.yaml'
- if model_type == 'Stable Diffusion 2':
- return None # dont know if its eps or v so let diffusers sort it out
- # return 'configs/v2-inference-512-base.yaml'
- # return 'configs/v2-inference-768-v.yaml'
- elif config_type == 'json':
- if not shared.opts.diffuser_cache_config:
- return None
- if model_type == 'Stable Diffusion':
- return 'configs/sd15'
- if model_type == 'Stable Diffusion XL':
- return 'configs/sdxl'
- if model_type == 'Stable Diffusion 3':
- return 'configs/sd3'
- if model_type == 'FLUX':
- return 'configs/flux'
- return None
-
-
def patch_diffuser_config(sd_model, model_file):
def load_config(fn, k):
model_file = os.path.splitext(fn)[0]
@@ -1120,65 +637,65 @@ def load_diffuser_folder(model_type, pipeline, checkpoint_info, diffusers_load_c
files = shared.walk_files(checkpoint_info.path, ['.safetensors', '.bin', '.ckpt'])
if 'variant' not in diffusers_load_config and any('diffusion_pytorch_model.fp16' in f for f in files): # deal with diffusers lack of variant fallback when loading
diffusers_load_config['variant'] = 'fp16'
- if model_type is not None and pipeline is not None and 'ONNX' in model_type: # forced pipeline
- try:
- sd_model = pipeline.from_pretrained(checkpoint_info.path)
- except Exception as e:
- shared.log.error(f'Load {op}: type=ONNX path="{checkpoint_info.path}" {e}')
- if debug_load:
- errors.display(e, 'Load')
- return None
- else:
- err1, err2, err3 = None, None, None
- if os.path.exists(checkpoint_info.path) and os.path.isdir(checkpoint_info.path):
- if os.path.exists(os.path.join(checkpoint_info.path, 'unet', 'diffusion_pytorch_model.bin')):
- shared.log.debug(f'Load {op}: type=pickle')
- diffusers_load_config['use_safetensors'] = False
+ if model_type is not None and pipeline is not None and 'ONNX' in model_type: # forced pipeline
+ try:
+ sd_model = pipeline.from_pretrained(checkpoint_info.path)
+ except Exception as e:
+ shared.log.error(f'Load {op}: type=ONNX path="{checkpoint_info.path}" {e}')
if debug_load:
- shared.log.debug(f'Load {op}: args={diffusers_load_config}')
- try: # 1 - autopipeline, best choice but not all pipelines are available
- try:
+ errors.display(e, 'Load')
+ return None
+ else:
+ err1, err2, err3 = None, None, None
+ if os.path.exists(checkpoint_info.path) and os.path.isdir(checkpoint_info.path):
+ if os.path.exists(os.path.join(checkpoint_info.path, 'unet', 'diffusion_pytorch_model.bin')):
+ shared.log.debug(f'Load {op}: type=pickle')
+ diffusers_load_config['use_safetensors'] = False
+ if debug_load:
+ shared.log.debug(f'Load {op}: args={diffusers_load_config}')
+ try: # 1 - autopipeline, best choice but not all pipelines are available
+ try:
+ sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except ValueError as e:
+ if 'no variant default' in str(e):
+ shared.log.warning(f'Load {op}: variant={diffusers_load_config["variant"]} model="{checkpoint_info.path}" using default variant')
+ diffusers_load_config.pop('variant', None)
sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
- except ValueError as e:
- if 'no variant default' in str(e):
- shared.log.warning(f'Load {op}: variant={diffusers_load_config["variant"]} model="{checkpoint_info.path}" using default variant')
- diffusers_load_config.pop('variant', None)
- sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- elif 'safetensors found in directory' in str(err1):
- shared.log.warning(f'Load {op}: type=pickle')
- diffusers_load_config['use_safetensors'] = False
- sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- else:
- raise ValueError from e # reraise
- except Exception as e:
- err1 = e
- if debug_load:
- errors.display(e, 'Load AutoPipeline')
- # shared.log.error(f'AutoPipeline: {e}')
- try: # 2 - diffusion pipeline, works for most non-linked pipelines
- if err1 is not None:
- sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
- sd_model.model_type = sd_model.__class__.__name__
- except Exception as e:
- err2 = e
- if debug_load:
- errors.display(e, "Load DiffusionPipeline")
- # shared.log.error(f'DiffusionPipeline: {e}')
- try: # 3 - try basic pipeline just in case
- if err2 is not None:
- sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ elif 'safetensors found in directory' in str(err1):
+ shared.log.warning(f'Load {op}: type=pickle')
+ diffusers_load_config['use_safetensors'] = False
+ sd_model = diffusers.AutoPipelineForText2Image.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
sd_model.model_type = sd_model.__class__.__name__
- except Exception as e:
- err3 = e # ignore last error
- shared.log.error(f"StableDiffusionPipeline: {e}")
- if debug_load:
- errors.display(e, "Load StableDiffusionPipeline")
- if err3 is not None:
- shared.log.error(f'Load {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
- return None
+ else:
+ raise ValueError from e # reraise
+ except Exception as e:
+ err1 = e
+ if debug_load:
+ errors.display(e, 'Load AutoPipeline')
+ # shared.log.error(f'AutoPipeline: {e}')
+ try: # 2 - diffusion pipeline, works for most non-linked pipelines
+ if err1 is not None:
+ sd_model = diffusers.DiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except Exception as e:
+ err2 = e
+ if debug_load:
+ errors.display(e, "Load DiffusionPipeline")
+ # shared.log.error(f'DiffusionPipeline: {e}')
+ try: # 3 - try basic pipeline just in case
+ if err2 is not None:
+ sd_model = diffusers.StableDiffusionPipeline.from_pretrained(checkpoint_info.path, cache_dir=shared.opts.diffusers_dir, **diffusers_load_config)
+ sd_model.model_type = sd_model.__class__.__name__
+ except Exception as e:
+ err3 = e # ignore last error
+ shared.log.error(f"StableDiffusionPipeline: {e}")
+ if debug_load:
+ errors.display(e, "Load StableDiffusionPipeline")
+ if err3 is not None:
+ shared.log.error(f'Load {op}: {checkpoint_info.path} auto={err1} diffusion={err2}')
+ return None
return sd_model
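The reindented loader keeps the same three-stage fallback: try `AutoPipelineForText2Image`, then the generic `DiffusionPipeline`, then `StableDiffusionPipeline` as a last resort. A stripped-down sketch of that order (error reporting and variant handling omitted):

    import diffusers

    def load_with_fallback(path: str, **load_config):
        for cls in (diffusers.AutoPipelineForText2Image,
                    diffusers.DiffusionPipeline,
                    diffusers.StableDiffusionPipeline):
            try:
                return cls.from_pretrained(path, **load_config)
            except Exception:
                continue
        return None  # caller logs the failure, as above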
@@ -1194,7 +711,7 @@ def load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_con
if shared.opts.diffusers_force_zeros:
diffusers_load_config['force_zeros_for_empty_prompt '] = shared.opts.diffusers_force_zeros
else:
- model_config = get_load_config(checkpoint_info.path, model_type, config_type='json')
+ model_config = sd_detect.get_load_config(checkpoint_info.path, model_type, config_type='json')
if model_config is not None:
if debug_load:
shared.log.debug(f'Load {op}: config="{model_config}"')
@@ -1285,7 +802,7 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
return
# detect pipeline
- pipeline, model_type = detect_pipeline(checkpoint_info.path, op)
+ pipeline, model_type = sd_detect.detect_pipeline(checkpoint_info.path, op)
# preload vae so it can be used as param
vae = None
@@ -1312,7 +829,7 @@ def load_diffuser(checkpoint_info=None, already_loaded_state_dict=None, timer=No
sd_model = load_diffuser_file(model_type, pipeline, checkpoint_info, diffusers_load_config, op)
if sd_model is None:
- shared.log.error('Load {op}: no model loaded')
+ shared.log.error(f'Load {op}: name="{checkpoint_info.name if checkpoint_info is not None else None}" not loaded')
return
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() # pylint: disable=attribute-defined-outside-init
@@ -1760,7 +1277,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None,
shared.log.info(f"Model loaded in {timer.summary()}")
current_checkpoint_info = None
devices.torch_gc(force=True)
- shared.log.info(f'Model load finished: {memory_stats()} cached={len(checkpoints_loaded.keys())}')
+ shared.log.info(f'Model load finished: {memory_stats()}')
def reload_text_encoder(initial=False):
@@ -1820,7 +1337,7 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model',
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) if not shared.native else None
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
timer.record("config")
- if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None):
+ if sd_model is None or checkpoint_config != getattr(sd_model, 'used_config', None) or force:
sd_model = None
if not shared.native:
load_model(checkpoint_info, already_loaded_state_dict=state_dict, timer=timer, op=op)
@@ -1877,9 +1394,10 @@ def disable_offload(sd_model):
from accelerate.hooks import remove_hook_from_module
if not getattr(sd_model, 'has_accelerate', False):
return
- for _name, model in sd_model.components.items():
- if isinstance(model, torch.nn.Module):
- remove_hook_from_module(model, recurse=True)
+ if hasattr(sd_model, 'components'):
+ for _name, model in sd_model.components.items():
+ if isinstance(model, torch.nn.Module):
+ remove_hook_from_module(model, recurse=True)
sd_model.has_accelerate = False
@@ -1914,82 +1432,6 @@ def unload_model_weights(op='model'):
shared.log.debug(f'Unload weights {op}: {memory_stats()}')
-def apply_token_merging(sd_model):
- current_tome = getattr(sd_model, 'applied_tome', 0)
- current_todo = getattr(sd_model, 'applied_todo', 0)
-
- if shared.opts.token_merging_method == 'ToMe' and shared.opts.tome_ratio > 0:
- if current_tome == shared.opts.tome_ratio:
- return
- if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
- shared.log.warning('Token merging not supported with HyperTile for UNet')
- return
- try:
- import installer
- installer.install('tomesd', 'tomesd', ignore=False)
- import tomesd
- tomesd.apply_patch(
- sd_model,
- ratio=shared.opts.tome_ratio,
- use_rand=False, # can cause issues with some samplers
- merge_attn=True,
- merge_crossattn=False,
- merge_mlp=False
- )
- shared.log.info(f'Applying ToMe: ratio={shared.opts.tome_ratio}')
- sd_model.applied_tome = shared.opts.tome_ratio
- except Exception:
- shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
- else:
- sd_model.applied_tome = 0
-
- if shared.opts.token_merging_method == 'ToDo' and shared.opts.todo_ratio > 0:
- if current_todo == shared.opts.todo_ratio:
- return
- if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
- shared.log.warning('Token merging not supported with HyperTile for UNet')
- return
- try:
- from modules.todo.todo_utils import patch_attention_proc
- token_merge_args = {
- "ratio": shared.opts.todo_ratio,
- "merge_tokens": "keys/values",
- "merge_method": "downsample",
- "downsample_method": "nearest",
- "downsample_factor": 2,
- "timestep_threshold_switch": 0.0,
- "timestep_threshold_stop": 0.0,
- "downsample_factor_level_2": 1,
- "ratio_level_2": 0.0,
- }
- patch_attention_proc(sd_model.unet, token_merge_args=token_merge_args)
- shared.log.info(f'Applying ToDo: ratio={shared.opts.todo_ratio}')
- sd_model.applied_todo = shared.opts.todo_ratio
- except Exception:
- shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
- else:
- sd_model.applied_todo = 0
-
-
-def remove_token_merging(sd_model):
- current_tome = getattr(sd_model, 'applied_tome', 0)
- current_todo = getattr(sd_model, 'applied_todo', 0)
- try:
- if current_tome > 0:
- import tomesd
- tomesd.remove_patch(sd_model)
- sd_model.applied_tome = 0
- except Exception:
- pass
- try:
- if current_todo > 0:
- from modules.todo.todo_utils import remove_patch
- remove_patch(sd_model)
- sd_model.applied_todo = 0
- except Exception:
- pass
-
-
def path_to_repo(fn: str = ''):
if isinstance(fn, CheckpointInfo):
fn = fn.name
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 82171e0b7..bbc2f360b 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -84,6 +84,8 @@ def create_sampler(name, model):
if 'AuraFlow' in model.__class__.__name__:
shared.log.warning(f'AuraFlow: sampler="{name}" unsupported')
return None
+ if 'KDiffusion' in model.__class__.__name__:
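+            # KDiffusion pipelines carry their own sampler, so no scheduler is assigned here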
+ return None
if not hasattr(model, 'scheduler_config'):
model.scheduler_config = sampler.sampler.config.copy() if hasattr(sampler.sampler, 'config') else {}
model.scheduler = sampler.sampler
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 1b1cd189a..a487fe9b7 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -44,7 +44,7 @@ def single_sample_to_image(sample, approximation=None):
if sample.dtype == torch.bfloat16 and (approximation == 0 or approximation == 1):
sample = sample.to(torch.float16)
except Exception as e:
- warn_once(f'live preview: {e}')
+ warn_once(f'Preview: {e}')
if len(sample.shape) > 4: # likely unknown video latent (e.g. svd)
return Image.new(mode="RGB", size=(512, 512))
@@ -82,7 +82,7 @@ def single_sample_to_image(sample, approximation=None):
transform = T.ToPILImage()
image = transform(x_sample)
except Exception as e:
- warn_once(f'live preview: {e}')
+ warn_once(f'Preview: {e}')
image = Image.new(mode="RGB", size=(512, 512))
return image
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 52ba77bba..f266f8c38 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -2,7 +2,7 @@
import glob
from copy import deepcopy
import torch
-from modules import shared, errors, paths, devices, script_callbacks, sd_models
+from modules import shared, errors, paths, devices, script_callbacks, sd_models, sd_detect
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
@@ -206,8 +206,8 @@ def load_vae_diffusers(model_file, vae_file=None, vae_source="unknown-source"):
diffusers_load_config['variant'] = shared.opts.diffusers_vae_load_variant
if shared.opts.diffusers_vae_upcast != 'default':
diffusers_load_config['force_upcast'] = True if shared.opts.diffusers_vae_upcast == 'true' else False
- _pipeline, model_type = sd_models.detect_pipeline(model_file, 'vae')
- vae_config = sd_models.get_load_config(model_file, model_type, config_type='json')
+ _pipeline, model_type = sd_detect.detect_pipeline(model_file, 'vae')
+ vae_config = sd_detect.get_load_config(model_file, model_type, config_type='json')
if vae_config is not None:
diffusers_load_config['config'] = os.path.join(vae_config, 'vae')
shared.log.info(f'Load module: type=VAE model="{vae_file}" source={vae_source} config={diffusers_load_config}')
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 2b4399edb..78fe8f08b 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -46,7 +46,7 @@ def nn_approximation(sample): # Approximate NN
sd_vae_approx_model.load_state_dict(approx_weights)
sd_vae_approx_model.eval()
sd_vae_approx_model.to(device, dtype)
- shared.log.debug(f'VAE load: type=approximate model={model_path}')
+ shared.log.debug(f'VAE load: type=approximate model="{model_path}"')
try:
in_sample = sample.to(device, dtype).unsqueeze(0)
sd_vae_approx_model.to(device, dtype)
diff --git a/modules/sd_vae_ostris.py b/modules/sd_vae_ostris.py
new file mode 100644
index 000000000..70542e9f5
--- /dev/null
+++ b/modules/sd_vae_ostris.py
@@ -0,0 +1,41 @@
+import time
+import torch
+import diffusers
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+from modules import shared, devices
+
+
+decoder_id = "ostris/vae-kl-f8-d16"
+adapter_id = "ostris/16ch-VAE-Adapters"
+
+
+def load_vae(pipe):
+ if shared.sd_model_type == 'sd':
+ adapter_file = "16ch-VAE-Adapter-SD15-alpha.safetensors"
+ elif shared.sd_model_type == 'sdxl':
+ adapter_file = "16ch-VAE-Adapter-SDXL-alpha_v02.safetensors"
+ else:
+        shared.log.error('VAE: type=ostris unsupported model type')
+ return
+ t0 = time.time()
+ ckpt_file = hf_hub_download(adapter_id, adapter_file, cache_dir=shared.opts.hfcache_dir)
+ ckpt = load_file(ckpt_file)
+ lora_state_dict = {k: v for k, v in ckpt.items() if "lora" in k}
+ unet_state_dict = {k.replace("unet_", ""): v for k, v in ckpt.items() if "unet_" in k}
+
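+    # patch the UNet input/output convolutions from 4 to 16 latent channels to match the 16ch VAE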
+ pipe.unet.conv_in = torch.nn.Conv2d(16, 320, 3, 1, 1)
+ pipe.unet.conv_out = torch.nn.Conv2d(320, 16, 3, 1, 1)
+ pipe.unet.load_state_dict(unet_state_dict, strict=False)
+ pipe.unet.conv_in.to(devices.dtype)
+ pipe.unet.conv_out.to(devices.dtype)
+ pipe.unet.config.in_channels = 16
+ pipe.unet.config.out_channels = 16
+
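+    # load the adapter LoRA and fuse it into the UNet weights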
+ pipe.load_lora_weights(lora_state_dict, adapter_name=adapter_id)
+ # pipe.set_adapters(adapter_names=[adapter_id], adapter_weights=[0.8])
+ pipe.fuse_lora(adapter_names=[adapter_id], lora_scale=0.8, fuse_unet=True)
+
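+    # replace the pipeline VAE with the 16-channel ostris decoder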
+ pipe.vae = diffusers.AutoencoderKL.from_pretrained(decoder_id, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
+ t1 = time.time()
+    shared.log.info(f'VAE load: type=ostris decoder="{decoder_id}" adapter="{adapter_id}" time={t1-t0:.2f}s')
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5cd7fab7c..4d213ad48 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -160,11 +160,11 @@ def decode(latents):
download_model(model_path)
if os.path.exists(model_path):
taesd_models[f'{model_class}-decoder'] = TAESD(decoder_path=model_path, encoder_path=None)
- shared.log.debug(f'VAE load: type=taesd model={model_path}')
+ shared.log.debug(f'VAE load: type=taesd model="{model_path}"')
vae = taesd_models[f'{model_class}-decoder']
vae.decoder.to(devices.device, dtype)
else:
- shared.log.error(f'VAE load: type=taesd model={model_path} not found')
+ shared.log.error(f'VAE load: type=taesd model="{model_path}" not found')
return latents
if vae is None:
return latents
@@ -181,10 +181,14 @@ def decode(latents):
image = 2.0 * image - 1.0 # typical normalized range except for preview which runs denormalization
return image
else:
- shared.log.error(f'TAESD decode unsupported latent type: {latents.shape}')
+ if not previous_warnings:
+ shared.log.error(f'TAESD decode unsupported latent type: {latents.shape}')
+ previous_warnings = True
return latents
except Exception as e:
- shared.log.error(f'VAE decode taesd: {e}')
+ if not previous_warnings:
+ shared.log.error(f'VAE decode taesd: {e}')
+ previous_warnings = True
return latents
@@ -204,7 +208,7 @@ def encode(image):
model_path = os.path.join(paths.models_path, "TAESD", f"tae{model_class}_encoder.pth")
download_model(model_path)
if os.path.exists(model_path):
- shared.log.debug(f'VAE load: type=taesd model={model_path}')
+ shared.log.debug(f'VAE load: type=taesd model="{model_path}"')
taesd_models[f'{model_class}-encoder'] = TAESD(encoder_path=model_path, decoder_path=None)
vae = taesd_models[f'{model_class}-encoder']
vae.encoder.to(devices.device, devices.dtype_vae)
diff --git a/modules/shared.py b/modules/shared.py
index e184e53ed..ae94f26cf 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -20,6 +20,7 @@
from modules.dml import memory_providers, default_memory_provider, directml_do_hijack
from modules.onnx_impl import initialize_onnx, execution_providers
from modules.memstats import memory_stats
+from modules.ui_components import DropdownEditable
import modules.interrogate
import modules.memmon
import modules.styles
@@ -279,12 +280,13 @@ def options_section(section_identifier, options_dict):
return options_dict
-def list_checkpoint_tiles():
+def list_checkpoint_titles():
import modules.sd_models # pylint: disable=W0621
- return modules.sd_models.checkpoint_tiles()
+ return modules.sd_models.checkpoint_titles()
-default_checkpoint = list_checkpoint_tiles()[0] if len(list_checkpoint_tiles()) > 0 else "model.safetensors"
+list_checkpoint_tiles = list_checkpoint_titles # alias for legacy typo
+default_checkpoint = list_checkpoint_titles()[0] if len(list_checkpoint_titles()) > 0 else "model.safetensors"
def is_url(string):
@@ -392,7 +394,7 @@ def get_default_modes():
elif gpu_memory <= 8:
cmd_opts.medvram = True
default_offload_mode = "model"
- log.info(f"Device detect: memory={gpu_memory:.1f} ptimization=medvram")
+ log.info(f"Device detect: memory={gpu_memory:.1f} optimization=medvram")
else:
default_offload_mode = "none"
log.info(f"Device detect: memory={gpu_memory:.1f} optimization=none")
@@ -426,13 +428,14 @@ def get_default_modes():
options_templates.update(options_section(('sd', "Execution & Models"), {
"sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }),
- "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+ "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints),
+ "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
"sd_vae": OptionInfo("Automatic", "VAE model", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list),
"sd_unet": OptionInfo("None", "UNET model", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list),
"sd_text_encoder": OptionInfo('None', "Text encoder model", gr.Dropdown, lambda: {"choices": shared_items.sd_te_items()}, refresh=shared_items.refresh_te_list),
- "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+ "sd_model_dict": OptionInfo('None', "Use separate base dict", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints),
"sd_checkpoint_autoload": OptionInfo(True, "Model autoload on start"),
+ "sd_checkpoint_autodownload": OptionInfo(True, "Model auto-download on demand"),
"sd_textencoder_cache": OptionInfo(True, "Cache text encoder results"),
"stream_load": OptionInfo(False, "Load models using stream loading method", gr.Checkbox, {"visible": not native }),
"model_reuse_dict": OptionInfo(False, "Reuse loaded model dictionary", gr.Checkbox, {"visible": False}),
@@ -477,6 +480,7 @@ def get_default_modes():
"cudnn_benchmark": OptionInfo(False, "Full-depth cuDNN benchmark feature"),
"diffusers_fuse_projections": OptionInfo(False, "Fused projections"),
"torch_expandable_segments": OptionInfo(False, "Torch expandable segments"),
+ "cuda_mem_fraction": OptionInfo(0.0, "Torch memory limit", gr.Slider, {"minimum": 0, "maximum": 2.0, "step": 0.05}),
"torch_gc_threshold": OptionInfo(80, "Torch memory threshold for GC", gr.Slider, {"minimum": 0, "maximum": 100, "step": 1}),
"torch_malloc": OptionInfo("native", "Torch memory allocator", gr.Radio, {"choices": ['native', 'cudaMallocAsync'] }),
@@ -969,7 +973,7 @@ def set(self, key, value):
self.data_labels[key].onchange()
except Exception as err:
log.error(f'Error in onchange callback: {key} {value} {err}')
- errors.display(e, 'Error in onchange callback')
+ errors.display(err, 'Error in onchange callback')
setattr(self, key, oldval)
return False
return True
@@ -1247,7 +1251,7 @@ def req(url_addr, headers = None, **kwargs):
try:
res = requests.get(url_addr, timeout=30, headers=headers, verify=False, allow_redirects=True, **kwargs)
except Exception as err:
- log.error(f'HTTP request error: url={url_addr} {e}')
+ log.error(f'HTTP request error: url={url_addr} {err}')
res = { 'status_code': 500, 'text': f'HTTP request error: url={url_addr} {err}' }
res = SimpleNamespace(**res)
return res
diff --git a/modules/token_merge.py b/modules/token_merge.py
new file mode 100644
index 000000000..f97c1fc8e
--- /dev/null
+++ b/modules/token_merge.py
@@ -0,0 +1,77 @@
+from modules import shared
+
+
+def apply_token_merging(sd_model):
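+    # apply ToMe or ToDo token merging based on settings; returns early if the requested ratio is already applied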
+ current_tome = getattr(sd_model, 'applied_tome', 0)
+ current_todo = getattr(sd_model, 'applied_todo', 0)
+
+ if shared.opts.token_merging_method == 'ToMe' and shared.opts.tome_ratio > 0:
+ if current_tome == shared.opts.tome_ratio:
+ return
+ if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
+ shared.log.warning('Token merging not supported with HyperTile for UNet')
+ return
+ try:
+ import installer
+ installer.install('tomesd', 'tomesd', ignore=False)
+ import tomesd
+ tomesd.apply_patch(
+ sd_model,
+ ratio=shared.opts.tome_ratio,
+ use_rand=False, # can cause issues with some samplers
+ merge_attn=True,
+ merge_crossattn=False,
+ merge_mlp=False
+ )
+ shared.log.info(f'Applying ToMe: ratio={shared.opts.tome_ratio}')
+ sd_model.applied_tome = shared.opts.tome_ratio
+ except Exception:
+ shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
+ else:
+ sd_model.applied_tome = 0
+
+ if shared.opts.token_merging_method == 'ToDo' and shared.opts.todo_ratio > 0:
+ if current_todo == shared.opts.todo_ratio:
+ return
+ if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
+ shared.log.warning('Token merging not supported with HyperTile for UNet')
+ return
+ try:
+ from modules.todo.todo_utils import patch_attention_proc
+ token_merge_args = {
+ "ratio": shared.opts.todo_ratio,
+ "merge_tokens": "keys/values",
+ "merge_method": "downsample",
+ "downsample_method": "nearest",
+ "downsample_factor": 2,
+ "timestep_threshold_switch": 0.0,
+ "timestep_threshold_stop": 0.0,
+ "downsample_factor_level_2": 1,
+ "ratio_level_2": 0.0,
+ }
+ patch_attention_proc(sd_model.unet, token_merge_args=token_merge_args)
+ shared.log.info(f'Applying ToDo: ratio={shared.opts.todo_ratio}')
+ sd_model.applied_todo = shared.opts.todo_ratio
+ except Exception:
+ shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
+ else:
+ sd_model.applied_todo = 0
+
+
+def remove_token_merging(sd_model):
+ current_tome = getattr(sd_model, 'applied_tome', 0)
+ current_todo = getattr(sd_model, 'applied_todo', 0)
+ try:
+ if current_tome > 0:
+ import tomesd
+ tomesd.remove_patch(sd_model)
+ sd_model.applied_tome = 0
+ except Exception:
+ pass
+ try:
+ if current_todo > 0:
+ from modules.todo.todo_utils import remove_patch
+ remove_patch(sd_model)
+ sd_model.applied_todo = 0
+ except Exception:
+ pass
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 5e873355b..9ad87c17f 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -319,13 +319,18 @@ def create_output_panel(tabname, preview=True, prompt=None, height=None):
return result_gallery, generation_info, html_info, html_info_formatted, html_log
-def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id, visible: bool = True):
+def create_refresh_button(refresh_component, refresh_method, refreshed_args = None, elem_id = None, visible: bool = True):
def refresh():
refresh_method()
- args = refreshed_args() if callable(refreshed_args) else refreshed_args
+ if refreshed_args is None:
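+            # no explicit args provided: assume refresh_method itself returns the updated choices list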
+ args = {"choices": refresh_method()} # pylint: disable=unnecessary-lambda-assignment
+ elif callable(refreshed_args):
+ args = refreshed_args()
+ else:
+ args = refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
- return gr.update(**(args or {}))
+ return gr.update(**args)
refresh_button = ui_components.ToolButton(value=ui_symbols.refresh, elem_id=elem_id, visible=visible)
refresh_button.click(fn=refresh, inputs=[], outputs=[refresh_component])
diff --git a/modules/ui_control.py b/modules/ui_control.py
index 388e8ede9..4d7c59bee 100644
--- a/modules/ui_control.py
+++ b/modules/ui_control.py
@@ -612,6 +612,7 @@ def create_ui(_blocks: gr.Blocks=None):
(mask_controls[6], "Mask auto"),
# advanced
(cfg_scale, "CFG scale"),
+ (cfg_end, "CFG end"),
(clip_skip, "Clip skip"),
(image_cfg_scale, "Image CFG scale"),
(diffusers_guidance_rescale, "CFG rescale"),
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index e59e51e1a..9b66b754e 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -15,6 +15,8 @@ def refresh(self):
shared.refresh_checkpoints()
def list_reference(self): # pylint: disable=inconsistent-return-statements
+ if not shared.opts.sd_checkpoint_autodownload:
+ return []
for k, v in shared.reference_models.items():
if not shared.native:
if not v.get('original', False):
diff --git a/modules/ui_img2img.py b/modules/ui_img2img.py
index d46ea4dd3..44f48c3c6 100644
--- a/modules/ui_img2img.py
+++ b/modules/ui_img2img.py
@@ -263,6 +263,7 @@ def select_img2img_tab(tab):
(refiner_start, "Refiner start"),
# advanced
(cfg_scale, "CFG scale"),
+ (cfg_end, "CFG end"),
(image_cfg_scale, "Image CFG scale"),
(clip_skip, "Clip skip"),
(diffusers_guidance_rescale, "CFG rescale"),
diff --git a/modules/ui_models.py b/modules/ui_models.py
index 051ca39a7..e9be428b4 100644
--- a/modules/ui_models.py
+++ b/modules/ui_models.py
@@ -59,8 +59,8 @@ def analyze():
with gr.Tab(label="Convert"):
with gr.Row():
- model_name = gr.Dropdown(sd_models.checkpoint_tiles(), label="Original model")
- create_refresh_button(model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_Z")
+ model_name = gr.Dropdown(sd_models.checkpoint_titles(), label="Original model")
+ create_refresh_button(model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_titles()}, "refresh_checkpoint_Z")
with gr.Row():
custom_name = gr.Textbox(label="Output model name")
with gr.Row():
@@ -98,7 +98,7 @@ def analyze():
with gr.Tab(label="Merge"):
def sd_model_choices():
- return ['None'] + sd_models.checkpoint_tiles()
+ return ['None'] + sd_models.checkpoint_titles()
with gr.Row(equal_height=False):
with gr.Column(variant='compact'):
@@ -213,10 +213,10 @@ def modelmerger(dummy_component, # dummy function just to get argspec later
del kwargs['dummy_component']
if kwargs.get("custom_name", None) is None:
log.error('Merge: no output model specified')
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "No output model specified"]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], "No output model specified"]
elif kwargs.get("primary_model_name", None) is None or kwargs.get("secondary_model_name", None) is None:
log.error('Merge: no models selected')
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "No models selected"]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], "No models selected"]
else:
log.debug(f'Merge start: {kwargs}')
try:
@@ -224,7 +224,7 @@ def modelmerger(dummy_component, # dummy function just to get argspec later
except Exception as e:
modules.errors.display(e, 'Merge')
sd_models.list_models() # to remove the potentially missing models from the list
- return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+ return [*[gr.Dropdown.update(choices=sd_models.checkpoint_titles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
def tertiary(mode):
diff --git a/modules/ui_txt2img.py b/modules/ui_txt2img.py
index 1ae3a8dad..ece8829c5 100644
--- a/modules/ui_txt2img.py
+++ b/modules/ui_txt2img.py
@@ -116,6 +116,7 @@ def create_ui():
(subseed_strength, "Variation strength"),
# advanced
(cfg_scale, "CFG scale"),
+ (cfg_end, "CFG end"),
(clip_skip, "Clip skip"),
(image_cfg_scale, "Image CFG scale"),
(diffusers_guidance_rescale, "CFG rescale"),
diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py
index 506652edf..84c130e8d 100644
--- a/modules/zluda_installer.py
+++ b/modules/zluda_installer.py
@@ -4,7 +4,7 @@
import shutil
import zipfile
import urllib.request
-from typing import Optional
+from typing import Optional, Union
from modules import rocm
@@ -15,12 +15,18 @@
}
HIPSDK_TARGETS = ['rocblas.dll', 'rocsolver.dll', f'hiprtc{"".join([v.zfill(2) for v in rocm.version.split(".")])}.dll']
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',)
+default_agent: Union[rocm.Agent, None] = None
def get_path() -> str:
return os.path.abspath(os.environ.get('ZLUDA', '.zluda'))
+def set_default_agent(agent: rocm.Agent):
+ global default_agent # pylint: disable=global-statement
+ default_agent = agent
+
+
def install(zluda_path: os.PathLike) -> None:
if os.path.exists(zluda_path):
return
diff --git a/requirements.txt b/requirements.txt
index 801475332..208cfba3a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -36,11 +36,11 @@ torchsde==0.2.6
antlr4-python3-runtime==4.9.3
requests==2.32.3
tqdm==4.66.5
-accelerate==1.0.0
+accelerate==1.0.1
opencv-contrib-python-headless==4.9.0.80
einops==0.4.1
gradio==3.43.2
-huggingface_hub==0.25.2
+huggingface_hub==0.26.2
numexpr==2.8.8
numpy==1.26.4
numba==0.59.1
@@ -49,8 +49,8 @@ scipy
pandas
protobuf==4.25.3
pytorch_lightning==1.9.4
-tokenizers==0.20.0
-transformers==4.45.2
+tokenizers==0.20.1
+transformers==4.46.0
urllib3==1.26.19
Pillow==10.4.0
timm==0.9.16
@@ -61,8 +61,7 @@ torchdiffeq
dctorch
scikit-image
seam-carving
-open-clip-torch
-# TODO temporary block for torch==2.5.0
-torchvision!=0.20.0
+# block
torch!=2.5.0
+torchvision!=0.20.0
diff --git a/scripts/apg.py b/scripts/apg.py
index 3325d9333..c7e60c982 100644
--- a/scripts/apg.py
+++ b/scripts/apg.py
@@ -62,10 +62,13 @@ def run(self, p: processing.StableDiffusionProcessing, eta = 0.0, momentum = 0.0
def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, eta, momentum, threshold): # pylint: disable=arguments-differ, unused-argument
from modules import apg
+ if self.orig_pipe is None:
+ return processed
# restore pipeline
- if shared.sd_model_type == "sdxl":
+ if shared.sd_model_type == "sdxl" or shared.sd_model_type == "sd":
shared.sd_model = self.orig_pipe
elif shared.sd_model_type == "sc":
shared.sd_model.prior_pipe = self.orig_pipe
apg.buffer = None
+ self.orig_pipe = None
return processed
diff --git a/scripts/ipadapter.py b/scripts/ipadapter.py
index c7fcb3053..60c70b9dc 100644
--- a/scripts/ipadapter.py
+++ b/scripts/ipadapter.py
@@ -1,7 +1,7 @@
import json
from PIL import Image
import gradio as gr
-from modules import scripts, processing, shared, ipadapter
+from modules import scripts, processing, shared, ipadapter, ui_common
MAX_ADAPTERS = 4
@@ -60,9 +60,12 @@ def ui(self, _is_img2img):
for i in range(MAX_ADAPTERS):
with gr.Accordion(f'Adapter {i+1}', visible=i==0) as unit:
with gr.Row():
- adapters.append(gr.Dropdown(label='Adapter', choices=list(ipadapter.ADAPTERS), value='None'))
- scales.append(gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=0.5))
- crops.append(gr.Checkbox(label='Crop', default=False, interactive=True))
+ adapter = gr.Dropdown(label='Adapter', choices=list(ipadapter.get_adapters()), value='None')
+ adapters.append(adapter)
+ ui_common.create_refresh_button(adapter, ipadapter.get_adapters)
+ with gr.Row():
+ scales.append(gr.Slider(label='Strength', minimum=0.0, maximum=1.0, step=0.01, value=0.5))
+ crops.append(gr.Checkbox(label='Crop to portrait', default=False, interactive=True))
with gr.Row():
starts.append(gr.Slider(label='Start', minimum=0.0, maximum=1.0, step=0.1, value=0))
ends.append(gr.Slider(label='End', minimum=0.0, maximum=1.0, step=0.1, value=1))
diff --git a/scripts/ipinstruct.py b/scripts/ipinstruct.py
new file mode 100644
index 000000000..4a94197b1
--- /dev/null
+++ b/scripts/ipinstruct.py
@@ -0,0 +1,111 @@
+"""
+Repo:
+Models:
+adapter: `sd15`=0.35GB `sdxl`=2.12GB `sd3`=1.56GB
+encoder: `laion/CLIP-ViT-H-14-laion2B-s32B-b79K`=3.94GB
+"""
+import os
+import importlib
+import gradio as gr
+from modules import scripts, processing, shared, sd_models, devices
+
+
+repo = 'https://github.com/vladmandic/IP-Instruct'
+repo_id = 'CiaraRowles/IP-Adapter-Instruct'
+encoder = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+folder = os.path.join('repositories', 'ip_instruct')
+
+
+class Script(scripts.Script):
+ def __init__(self):
+ super().__init__()
+ self.orig_pipe = None
+ self.lib = None
+
+ def title(self):
+ return 'IP Instruct'
+
+ def show(self, is_img2img):
+ if shared.cmd_opts.experimental:
+ return not is_img2img if shared.native else False
+ else:
+ return False
+
+ def install(self):
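+        # clone the IP-Instruct repository on first use and import its ip_adapter module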
+ if not os.path.exists(folder):
+ from installer import clone
+ clone(repo, folder)
+ if self.lib is None:
+ self.lib = importlib.import_module('ip_instruct.ip_adapter')
+
+
+ def ui(self, _is_img2img): # ui elements
+ with gr.Row():
+ gr.HTML('  IP Adapter Instruct
')
+ with gr.Row():
+ query = gr.Textbox(lines=1, label='Query', placeholder='use the composition from the image')
+ with gr.Row():
+ image = gr.Image(value=None, label='Image', type='pil', source='upload', width=256, height=256)
+ with gr.Row():
+ strength = gr.Slider(label="Strength", value=1.0, minimum=0, maximum=2.0, step=0.05)
+ tokens = gr.Slider(label="Tokens", value=4, minimum=1, maximum=32, step=1)
+ with gr.Row():
+            instruct_guidance = gr.Slider(label="Instruct guidance", value=6.0, minimum=1.0, maximum=15.0, step=0.05)
+            image_guidance = gr.Slider(label="Image guidance", value=0.5, minimum=0, maximum=1.0, step=0.05)
+ return [query, image, strength, tokens, instruct_guidance, image_guidance]
+
+ def run(self, p: processing.StableDiffusionProcessing, query, image, strength, tokens, instruct_guidance, image_guidance): # pylint: disable=arguments-differ
+ supported_model_list = ['sd', 'sdxl', 'sd3']
+ if shared.sd_model_type not in supported_model_list:
+ shared.log.warning(f'IP-Instruct: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
+ return None
+ self.install()
+ if self.lib is None:
+ shared.log.error('IP-Instruct: failed to import library')
+ return None
+ self.orig_pipe = shared.sd_model
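+        # pick the pipeline wrapper, adapter class and checkpoint matching the loaded model type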
+ if shared.sd_model_type == 'sdxl':
+ pipe = self.lib.StableDiffusionXLPipelineExtraCFG
+ cls = self.lib.IPAdapterInstructSDXL
+ ckpt = "ip-adapter-instruct-sdxl.bin"
+ elif shared.sd_model_type == 'sd3':
+ pipe = self.lib.StableDiffusion3PipelineExtraCFG
+ cls = self.lib.IPAdapter_sd3_Instruct
+ ckpt = "ip-adapter-instruct-sd3.bin"
+ else:
+ pipe = self.lib.StableDiffusionPipelineCFG
+ cls = self.lib.IPAdapterInstruct
+ ckpt = "ip-adapter-instruct-sd15.bin"
+
+ shared.sd_model = sd_models.switch_pipe(pipe, shared.sd_model)
+
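+        # download the instruct adapter checkpoint and wrap the pipeline with the IP-Adapter-Instruct model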
+ import huggingface_hub as hf
+ ip_ckpt = hf.hf_hub_download(repo_id=repo_id, filename=ckpt, cache_dir=shared.opts.hfcache_dir)
+ ip_model = cls(shared.sd_model, encoder, ip_ckpt, device=devices.device, dtypein=devices.dtype, num_tokens=tokens)
+ processing.fix_seed(p)
+ shared.log.debug(f'IP-Instruct: class={shared.sd_model.__class__.__name__} wrapper={ip_model.__class__.__name__} encoder={encoder} adapter={ckpt}')
+ shared.log.info(f'IP-Instruct: image={image} query="{query}" strength={strength} tokens={tokens} instruct_guidance={instruct_guidance} image_guidance={image_guidance}')
+
+ image_list = ip_model.generate(
+ query = query,
+ scale = strength,
+ instruct_guidance_scale = instruct_guidance,
+ image_guidance_scale = image_guidance,
+
+ prompt = p.prompt,
+ pil_image = image,
+ num_samples = 1,
+ num_inference_steps = p.steps,
+ seed = p.seed,
+ guidance_scale = p.cfg_scale,
+ auto_scale = False,
+ simple_cfg_mode = False,
+ )
+ processed = processing.Processed(p, images_list=image_list, seed=p.seed, subseed=p.subseed, index_of_first_image=0) # manually created processed object
+ # p.extra_generation_params["IPInstruct"] = f''
+ return processed
+
+ def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, **kwargs): # pylint: disable=unused-argument
+ if self.orig_pipe is not None:
+ shared.sd_model = self.orig_pipe
+ return processed
diff --git a/scripts/k_diff.py b/scripts/k_diff.py
new file mode 100644
index 000000000..711115aa1
--- /dev/null
+++ b/scripts/k_diff.py
@@ -0,0 +1,77 @@
+import inspect
+import importlib
+import gradio as gr
+import diffusers
+from modules import scripts, processing, shared, sd_models
+
+
+class Script(scripts.Script):
+ supported_models = ['sd', 'sdxl']
+ orig_pipe = None
+ try:
+ library = importlib.import_module('k_diffusion')
+ except Exception:
+ library = None
+
+ def title(self):
+ return 'K-Diffusion'
+
+ def show(self, is_img2img):
+ return not is_img2img if shared.native else False
+
+ def ui(self, _is_img2img): # ui elements
+ with gr.Row():
+ gr.HTML('  K-Diffusion samplers
')
+ with gr.Row():
+ sampler = gr.Dropdown(label="Sampler", choices=self.samplers())
+ return [sampler]
+
+ def samplers(self):
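+        # enumerate sample_* functions exposed by k_diffusion.sampling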
+ samplers = []
+ sampling = getattr(self.library, 'sampling', None)
+ if sampling is None:
+ return samplers
+ for s in dir(sampling):
+ if s.startswith('sample_'):
+ samplers.append(s.replace('sample_', ''))
+ return samplers
+
+ def callback(self, d):
+ _step = d['i']
+
+ def run(self, p: processing.StableDiffusionProcessing, sampler: str): # pylint: disable=arguments-differ
+ if shared.sd_model_type not in self.supported_models:
+ shared.log.warning(f'K-Diffusion: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={self.supported_models}')
+ return None
+ if self.library is None:
+ return
+ cls = None
+ if shared.sd_model_type == "sd":
+ cls = diffusers.pipelines.StableDiffusionKDiffusionPipeline
+ if shared.sd_model_type == "sdxl":
+ cls = diffusers.pipelines.StableDiffusionXLKDiffusionPipeline
+ if cls is None:
+ return None
+ self.orig_pipe = shared.sd_model
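+        # switch to the k-diffusion pipeline class and attach the selected sampler function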
+ shared.sd_model = sd_models.switch_pipe(cls, shared.sd_model)
+ sampler = 'sample_' + sampler
+
+ sampling = getattr(self.library, "sampling", None)
+ shared.sd_model.sampler = getattr(sampling, sampler)
+
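+        # collect the sampler's default keyword arguments; currently only used for logging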
+ params = inspect.signature(shared.sd_model.sampler).parameters.values()
+ params = {param.name: param.default for param in params if param.default != inspect.Parameter.empty}
+ # if 'callback' in list(params):
+ # params['callback'] = self.callback
+ # if 'disable' in list(params):
+ # params['disable'] = False
+ shared.log.info(f'K-diffusion apply: class={shared.sd_model.__class__.__name__} sampler={sampler} params={params}')
+ p.extra_generation_params["Sampler"] = sampler
+
+ def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, sampler): # pylint: disable=arguments-differ, unused-argument
+ if self.orig_pipe is None:
+ return processed
+ if shared.sd_model_type == "sdxl" or shared.sd_model_type == "sd":
+ shared.sd_model = self.orig_pipe
+ self.orig_pipe = None
+ return processed
diff --git a/scripts/x_adapter.py b/scripts/x_adapter.py
index 7c1341701..553a20d30 100644
--- a/scripts/x_adapter.py
+++ b/scripts/x_adapter.py
@@ -22,7 +22,7 @@ def ui(self, _is_img2img):
with gr.Row():
gr.HTML('  X-Adapter
')
with gr.Row():
- model = gr.Dropdown(label='Adapter model', choices=['None'] + sd_models.checkpoint_tiles(), value='None')
+ model = gr.Dropdown(label='Adapter model', choices=['None'] + sd_models.checkpoint_titles(), value='None')
sampler = gr.Dropdown(label='Adapter sampler', choices=[s.name for s in sd_samplers.samplers], value='Default')
with gr.Row():
width = gr.Slider(label='Adapter width', minimum=64, maximum=2048, step=8, value=1024)
@@ -34,7 +34,7 @@ def ui(self, _is_img2img):
lora = gr.Textbox('', label='Adapter LoRA', default='')
return model, sampler, width, height, start, scale, lora
- def run(self, p: processing.StableDiffusionProcessing, model, sampler, width, height, start, scale, lora): # pylint: disable=arguments-differ
+ def run(self, p: processing.StableDiffusionProcessing, model, sampler, width, height, start, scale, lora): # pylint: disable=arguments-differ, unused-argument
from modules.xadapter.xadapter_hijacks import PositionNet
diffusers.models.embeddings.PositionNet = PositionNet # patch diffusers==0.26 from diffusers==0.20
from modules.xadapter.adapter import Adapter_XL
diff --git a/scripts/xyz_grid_classes.py b/scripts/xyz_grid_classes.py
index 8a78d2c40..335d66186 100644
--- a/scripts/xyz_grid_classes.py
+++ b/scripts/xyz_grid_classes.py
@@ -99,7 +99,7 @@ def __exit__(self, exc_type, exc_value, tb):
AxisOption("[Param] Height", int, apply_field("height")),
AxisOption("[Param] Seed", int, apply_seed),
AxisOption("[Param] Steps", int, apply_field("steps")),
- AxisOption("[Param] CFG scale", float, apply_field("cfg_scale")),
+ AxisOption("[Param] Guidance scale", float, apply_field("cfg_scale")),
AxisOption("[Param] Guidance end", float, apply_field("cfg_end")),
AxisOption("[Param] Variation seed", int, apply_field("subseed")),
AxisOption("[Param] Variation strength", float, apply_field("subseed_strength")),
@@ -125,7 +125,7 @@ def __exit__(self, exc_type, exc_value, tb):
AxisOption("[Refine] Sampler", str, apply_hr_sampler_name, fmt=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
AxisOption("[Refine] Denoising strength", float, apply_field("denoising_strength")),
AxisOption("[Refine] Hires steps", int, apply_field("hr_second_pass_steps")),
- AxisOption("[Refine] CFG scale", float, apply_field("image_cfg_scale")),
+ AxisOption("[Refine] Guidance scale", float, apply_field("image_cfg_scale")),
AxisOption("[Refine] Guidance rescale", float, apply_field("diffusers_guidance_rescale")),
AxisOption("[Refine] Refiner start", float, apply_field("refiner_start")),
AxisOption("[Refine] Refiner steps", float, apply_field("refiner_steps")),
diff --git a/wiki b/wiki
index 53def8203..4360bc7fc 160000
--- a/wiki
+++ b/wiki
@@ -1 +1 @@
-Subproject commit 53def8203b6799cbd659327c1af6aa5af9cb9a70
+Subproject commit 4360bc7fcfd5dc301b825a6c5dcf4274c8eba983