Commit

Merge branch 'master' into beta
jn-jairo committed Mar 11, 2024
2 parents cdd69ef + 65397ce commit 7b236ae
Showing 13 changed files with 138 additions and 65 deletions.
10 changes: 10 additions & 0 deletions comfy/cli_args.py
@@ -124,6 +124,9 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")


if comfy.options.args_parsing:
args = parser.parse_args()
else:
@@ -134,3 +137,10 @@ class LatentPreviewMethod(enum.Enum):

if args.disable_auto_launch:
args.auto_launch = False

import logging
logging_level = logging.WARNING
if args.verbose:
logging_level = logging.DEBUG

logging.basicConfig(format="%(message)s", level=logging_level)
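
The hunk above wires the new --verbose flag into Python's root logger: the level defaults to WARNING and drops to DEBUG when the flag is passed, and the "%(message)s" format keeps the output looking like the old print calls. A minimal standalone sketch of the same gating, using the same flag but illustrative messages:

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
args = parser.parse_args()

# WARNING by default, DEBUG when --verbose is given
logging.basicConfig(format="%(message)s",
                    level=logging.DEBUG if args.verbose else logging.WARNING)

logging.debug("only visible with --verbose")   # hidden at the default WARNING level
logging.warning("always visible")              # printed in both modes

Run without arguments, only the second line prints; with --verbose, both do.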
3 changes: 2 additions & 1 deletion comfy/clip_vision.py
@@ -2,6 +2,7 @@
import os
import torch
import json
import logging

import comfy.ops
import comfy.model_patcher
@@ -99,7 +100,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
print("missing clip vision:", m)
logging.warning("missing clip vision: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
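
The replacements in this and the following files build each message eagerly with str.format() or f-strings before handing it to the logger. As an aside, and not something this commit does, the logging module also accepts lazy %-style arguments, which defers formatting until a handler actually emits the record; a small sketch:

import logging
logging.basicConfig(format="%(message)s", level=logging.WARNING)

m = ["visual_projection.weight"]                        # hypothetical missing keys
logging.warning("missing clip vision: {}".format(m))    # string built before the call, regardless of level
logging.warning("missing clip vision: %s", m)           # string built only if the record is emitted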
18 changes: 12 additions & 6 deletions comfy/controlnet.py
@@ -1,6 +1,7 @@
import torch
import math
import os
import logging
import comfy.utils
import comfy.model_management
import comfy.model_detection
@@ -367,7 +368,7 @@ def load_controlnet(ckpt_path, model=None):

leftover_keys = controlnet_data.keys()
if len(leftover_keys) > 0:
print("leftover keys:", leftover_keys)
logging.warning("leftover keys: {}".format(leftover_keys))
controlnet_data = new_sd

pth_key = 'control_model.zero_convs.0.0.weight'
@@ -382,7 +383,7 @@ def load_controlnet(ckpt_path, model=None):
else:
net = load_t2i_adapter(controlnet_data)
if net is None:
print("error checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return net

if controlnet_config is None:
Expand Down Expand Up @@ -417,7 +418,7 @@ def load_controlnet(ckpt_path, model=None):
cd = controlnet_data[x]
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
else:
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

class WeightsLoader(torch.nn.Module):
pass
@@ -426,7 +427,12 @@ class WeightsLoader(torch.nn.Module):
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
else:
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)

if len(missing) > 0:
logging.warning("missing controlnet keys: {}".format(missing))

if len(unexpected) > 0:
logging.info("unexpected controlnet keys: {}".format(unexpected))

global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
@@ -536,9 +542,9 @@ def load_t2i_adapter(t2i_data):

missing, unexpected = model_ad.load_state_dict(t2i_data)
if len(missing) > 0:
print("t2i missing", missing)
logging.warning("t2i missing {}".format(missing))

if len(unexpected) > 0:
print("t2i unexpected", unexpected)
logging.info("t2i unexpected {}".format(unexpected))

return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
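
The controlnet and t2i hunks settle on a convention for non-strict load_state_dict results: missing keys (weights the model expected but the checkpoint lacks) are logged as warnings, while unexpected keys are logged as info since they are usually harmless. A small helper capturing that convention; the function name is made up for illustration:

import logging
import torch

def report_state_dict_mismatch(model: torch.nn.Module, state_dict: dict, label: str) -> None:
    # Load non-strictly and log mismatches the way the hunks above do.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        # keys the model expected but the checkpoint did not provide: likely a real problem
        logging.warning("missing {} keys: {}".format(label, missing))
    if len(unexpected) > 0:
        # extra keys in the checkpoint: usually ignorable
        logging.info("unexpected {} keys: {}".format(label, unexpected))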
3 changes: 2 additions & 1 deletion comfy/diffusers_convert.py
@@ -1,5 +1,6 @@
import re
import torch
import logging

# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py

@@ -177,7 +178,7 @@ def convert_vae_state_dict(vae_state_dict):
for k, v in new_state_dict.items():
for weight_name in weights_to_convert:
if f"mid.attn_1.{weight_name}.weight" in k:
print(f"Reshaping {k} for SD format")
logging.info(f"Reshaping {k} for SD format")
new_state_dict[k] = reshape_weight_for_sd(v)
return new_state_dict

3 changes: 2 additions & 1 deletion comfy/lora.py
@@ -1,4 +1,5 @@
import comfy.utils
import logging

LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@@ -156,7 +157,7 @@ def load_lora(lora, to_load):

for x in lora.keys():
if x not in loaded_keys:
print("lora key not loaded", x)
logging.warning("lora key not loaded: {}".format(x))
return patch_dict

def model_lora_keys_clip(model, key_map={}):
9 changes: 5 additions & 4 deletions comfy/model_base.py
@@ -1,4 +1,5 @@
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@@ -66,8 +67,8 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
if self.adm_channels is None:
self.adm_channels = 0
self.inpaint_model = False
print("model_type", model_type.name)
print("adm", self.adm_channels)
logging.warning("model_type {}".format(model_type.name))
logging.info("adm {}".format(self.adm_channels))

def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
@@ -183,10 +184,10 @@ def load_model_weights(self, sd, unet_prefix=""):
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)
logging.warning("unet missing: {}".format(m))

if len(u) > 0:
print("unet unexpected:", u)
logging.warning("unet unexpected: {}".format(u))
del to_load
return self

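
These calls, like the rest of the commit, go through the root logger, so the only control is the global level set in comfy/cli_args.py. A per-module logger is a common alternative that would let downstream code silence or promote individual subsystems; shown here only as a sketch with an assumed module name, not something this commit does:

import logging
logging.basicConfig(format="%(message)s", level=logging.INFO)

logger = logging.getLogger("comfy.model_base")    # hypothetical per-module logger instead of the root logger
logger.info("adm 0")                              # emitted, routed through "comfy.model_base"

logging.getLogger("comfy.model_base").setLevel(logging.ERROR)   # a downstream script could silence just this module
logger.info("adm 0")                              # now suppressed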
3 changes: 2 additions & 1 deletion comfy/model_detection.py
@@ -1,5 +1,6 @@
import comfy.supported_models
import comfy.supported_models_base
import logging

def count_blocks(state_dict_keys, prefix_string):
count = 0
@@ -186,7 +187,7 @@ def model_config_from_unet_config(unet_config):
if model_config.matches(unet_config):
return model_config(unet_config)

print("no match", unet_config)
logging.error("no match {}".format(unet_config))
return None

def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
41 changes: 20 additions & 21 deletions comfy/model_management.py
@@ -1,4 +1,5 @@
import psutil
import logging
from enum import Enum
from comfy.cli_args import args
import comfy.utils
@@ -29,7 +30,7 @@ class CPUState(Enum):
xpu_available = False

if args.deterministic:
print("Using deterministic algorithms for pytorch")
logging.warning("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)

directml_enabled = False
@@ -41,7 +42,7 @@ class CPUState(Enum):
directml_device = torch_directml.device()
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

@@ -131,10 +132,10 @@ def get_total_memory(dev=None, torch_total_too=False):

total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
if not args.normalvram and not args.cpu:
if lowvram_available and total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = VRAMState.LOW_VRAM

try:
@@ -157,12 +158,10 @@ def get_total_memory(dev=None, torch_total_too=False):
pass
try:
XFORMERS_VERSION = xformers.version.__version__
print("xformers version:", XFORMERS_VERSION)
logging.warning("xformers version: {}".format(XFORMERS_VERSION))
if XFORMERS_VERSION.startswith("0.0.18"):
print()
print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
print("Please downgrade or upgrade xformers to a different version.")
print()
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
XFORMERS_ENABLED_VAE = False
except:
pass
@@ -227,11 +226,11 @@ def is_nvidia():
FORCE_FP32 = False
FORCE_FP16 = False
if args.force_fp32:
print("Forcing FP32, if this improves things please report it.")
logging.warning("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True

if args.force_fp16:
print("Forcing FP16.")
logging.warning("Forcing FP16.")
FORCE_FP16 = True

if lowvram_available:
@@ -245,12 +244,12 @@ def is_nvidia():
if cpu_state == CPUState.MPS:
vram_state = VRAMState.SHARED

print(f"Set vram state to: {vram_state.name}")
logging.warning(f"Set vram state to: {vram_state.name}")

DISABLE_SMART_MEMORY = args.disable_smart_memory

if DISABLE_SMART_MEMORY:
print("Disabling smart memory management")
logging.warning("Disabling smart memory management")

def get_torch_device_name(device):
if hasattr(device, 'type'):
@@ -268,11 +267,11 @@ def get_torch_device_name(device):
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))

try:
print("Device:", get_torch_device_name(get_torch_device()))
logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
except:
print("Could not pick default device.")
logging.warning("Could not pick default device.")

print("VAE dtype:", VAE_DTYPE)
logging.warning("VAE dtype: {}".format(VAE_DTYPE))

current_loaded_models = []

@@ -315,7 +314,7 @@ def model_load(self, lowvram_model_memory=0):
raise e

if lowvram_model_memory > 0:
print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
mem_counter = 0
for m in self.real_model.modules():
if hasattr(m, "comfy_cast_weights"):
@@ -328,7 +327,7 @@ def model_load(self, lowvram_model_memory=0):
elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
m.to(self.device)
mem_counter += module_size(m)
print("lowvram: loaded module regularly", m)
logging.warning("lowvram: loaded module regularly {}".format(m))

self.model_accelerated = True

@@ -362,7 +361,7 @@ def unload_model_clones(model):
to_unload = [i] + to_unload

for i in to_unload:
print("unload clone", i)
logging.warning("unload clone {}".format(i))
current_loaded_models.pop(i).model_unload()

def free_memory(memory_required, device, keep_loaded=[]):
@@ -404,7 +403,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
logging.warning(f"Requested to load {x.model.__class__.__name__}")
models_to_load.append(loaded_model)

if len(models_to_load) == 0:
@@ -414,7 +413,7 @@ def load_models_gpu(models, memory_required=0):
free_memory(extra_mem, d, models_already_loaded)
return

print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")

total_memory_required = {}
for loaded_model in models_to_load:
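
Earlier in this file, the startup banner converts the raw byte counts from get_total_memory() and psutil into MB by dividing by 1024 * 1024, and the lowvram heuristic triggers when total VRAM is at most 4096 MB (4 GB) and neither --normalvram nor --cpu is set. A worked sketch with an assumed 4 GB device:

import logging
logging.basicConfig(format="%(message)s", level=logging.WARNING)

total_vram_bytes = 4 * 1024 ** 3                  # assume the device reports 4 GB
total_vram = total_vram_bytes / (1024 * 1024)     # 4096.0 MB, the unit used in the banner above
logging.warning("Total VRAM {:0.0f} MB".format(total_vram))

low_vram = total_vram <= 4096                     # True here -> lowvram mode unless --normalvram or --cpu is passed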
18 changes: 11 additions & 7 deletions comfy/model_patcher.py
@@ -1,6 +1,7 @@
import torch
import copy
import inspect
import logging

import comfy.utils
import comfy.model_management
@@ -187,7 +188,7 @@ def patch_model(self, device_to=None, patch_weights=True):
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue

weight = model_sd[key]
@@ -236,7 +237,7 @@ def calculate_weight(self, patches, weight, key):
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
@@ -252,7 +253,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
@@ -291,7 +292,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
@@ -320,7 +321,7 @@ def calculate_weight(self, patches, weight, key):
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
@@ -330,9 +331,12 @@ def calculate_weight(self, patches, weight, key):
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
print("patch type not recognized", patch_type, key)
logging.warning("patch type not recognized {} {}".format(patch_type, key))

return weight

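
The new try/except around the glora merge mirrors the existing lora, lokr and loha branches: a failed merge is reported via logging.error with the patch type, key and exception message instead of aborting the whole patch. As a hedged aside, not something the commit does, logging.exception (or exc_info=True) would additionally capture the traceback:

import logging
logging.basicConfig(format="%(message)s", level=logging.WARNING)

patch_type, key = "glora", "example.weight"       # illustrative values only
try:
    raise RuntimeError("shape mismatch")          # stand-in for a failed weight merge
except Exception as e:
    logging.error("ERROR {} {} {}".format(patch_type, key, e))   # what the hunk above logs
    logging.exception("ERROR %s %s", patch_type, key)            # alternative: same severity plus the traceback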