From ab93abd4b2eaf99d4a52f9a036600d9d46355d92 Mon Sep 17 00:00:00 2001
From: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
Date: Mon, 11 Dec 2023 17:33:35 +0000
Subject: [PATCH 1/3] Prevent cleaning graph state on undo/redo (#2255)

* Prevent cleaning graph state on undo/redo

* Remove pause rendering due to LG bug

---
 web/extensions/core/undoRedo.js | 25 +++++++++++--------------
 web/scripts/app.js              |  7 +++++--
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js
index c6613b0f02d..3cb137520f4 100644
--- a/web/extensions/core/undoRedo.js
+++ b/web/extensions/core/undoRedo.js
@@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
 }
 
 const undoRedo = async (e) => {
+	const updateState = async (source, target) => {
+		const prevState = source.pop();
+		if (prevState) {
+			target.push(activeState);
+			isOurLoad = true;
+			await app.loadGraphData(prevState, false);
+			activeState = prevState;
+		}
+	}
 	if (e.ctrlKey || e.metaKey) {
 		if (e.key === "y") {
-			const prevState = redo.pop();
-			if (prevState) {
-				undo.push(activeState);
-				isOurLoad = true;
-				await app.loadGraphData(prevState);
-				activeState = prevState;
-			}
+			updateState(redo, undo);
 			return true;
 		} else if (e.key === "z") {
-			const prevState = undo.pop();
-			if (prevState) {
-				redo.push(activeState);
-				isOurLoad = true;
-				await app.loadGraphData(prevState);
-				activeState = prevState;
-			}
+			updateState(undo, redo);
 			return true;
 		}
 	}
diff --git a/web/scripts/app.js b/web/scripts/app.js
index 5faf41fb36b..d2a6f4de425 100644
--- a/web/scripts/app.js
+++ b/web/scripts/app.js
@@ -1559,9 +1559,12 @@ export class ComfyApp {
 	/**
 	 * Populates the graph with the specified workflow data
 	 * @param {*} graphData A serialized graph object
+	 * @param { boolean } clean If the graph state, e.g. images, should be cleared
 	 */
-	async loadGraphData(graphData) {
-		this.clean();
+	async loadGraphData(graphData, clean = true) {
+		if (clean !== false) {
+			this.clean();
+		}
 
 		let reset_invalid_values = false;
 		if (!graphData) {

From ba07cb748e4793a6393288d621aa8e2f0f282595 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 11 Dec 2023 18:24:44 -0500
Subject: [PATCH 2/3] Use faster manual cast for fp8 in unet.

---
 comfy/model_base.py            | 19 ++++++++++---------
 comfy/model_management.py      | 16 +++++++++++++++-
 comfy/ops.py                   |  9 +++++++++
 comfy/sd.py                    | 12 ++++++++++--
 comfy/supported_models_base.py |  4 ++++
 5 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 5bfcc391ded..bab7b9b340d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -4,6 +4,7 @@
 from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import comfy.model_management
 import comfy.conds
+import comfy.ops
 from enum import Enum
 import contextlib
 from . import utils
@@ -41,9 +42,14 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
+        self.manual_cast_dtype = model_config.manual_cast_dtype
 
         if not unet_config.get("disable_unet_model_creation", False):
-            self.diffusion_model = UNetModel(**unet_config, device=device)
+            if self.manual_cast_dtype is not None:
+                operations = comfy.ops.manual_cast
+            else:
+                operations = comfy.ops
+            self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
 
@@ -63,11 +69,8 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans
 
         context = c_crossattn
         dtype = self.get_dtype()
-        if comfy.model_management.supports_dtype(xc.device, dtype):
-            precision_scope = lambda a: contextlib.nullcontext(a)
-        else:
-            precision_scope = torch.autocast
-            dtype = torch.float32
+        if self.manual_cast_dtype is not None:
+            dtype = self.manual_cast_dtype
 
         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
@@ -79,9 +82,7 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans
                 extra = extra.to(dtype)
             extra_conds[o] = extra
 
-        with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
-            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
-
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
         return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
     def get_dtype(self):
diff --git a/comfy/model_management.py b/comfy/model_management.py
index a6c8fb352b2..fe0374a8b2f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -474,6 +474,20 @@ def unet_dtype(device=None, model_params=0):
         return torch.float16
     return torch.float32
 
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    if fp16_supported:
+        return torch.float16
+    else:
+        return torch.float32
+
 def text_encoder_offload_device():
     if args.gpu_only:
         return get_torch_device()
@@ -538,7 +552,7 @@ def get_autocast_device(dev):
 def supports_dtype(device, dtype): #TODO
     if dtype == torch.float32:
         return True
-    if torch.device("cpu") == device:
+    if is_device_cpu(device):
         return False
     if dtype == torch.float16:
         return True
diff --git a/comfy/ops.py b/comfy/ops.py
index e48568409a1..a67bc809fd2 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -62,6 +62,15 @@ def forward(self, input):
             weight, bias = cast_bias_weight(self, input)
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
 
+    @classmethod
+    def conv_nd(s, dims, *args, **kwargs):
+        if dims == 2:
+            return s.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return s.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
diff --git a/comfy/sd.py b/comfy/sd.py
index 43e201d363b..8c056e4ea2f 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -433,11 +433,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
 
     class WeightsLoader(torch.nn.Module):
         pass
 
     model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+    model_config.set_manual_cast(manual_cast_dtype)
+
     if model_config is None:
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
 
@@ -470,7 +474,7 @@ class WeightsLoader(torch.nn.Module):
             print("left over keys:", left_over)
 
     if output_model:
-        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
+        model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
         if inital_load_device != torch.device("cpu"):
             print("loaded straight to GPU")
             model_management.load_model_gpu(model_patcher)
@@ -481,6 +485,9 @@ class WeightsLoader(torch.nn.Module):
 def load_unet_state_dict(sd): #load unet in diffusers format
     parameters = comfy.utils.calculate_parameters(sd)
     unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
     if "input_blocks.0.0.weight" in sd: #ldm
         model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
         if model_config is None:
@@ -501,13 +508,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
             else:
                 print(diffusers_keys[k], k)
     offload_device = model_management.unet_offload_device()
+    model_config.set_manual_cast(manual_cast_dtype)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
     left_over = sd.keys()
    if len(left_over) > 0:
         print("left over keys in unet:", left_over)
-    return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
+    return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
 def load_unet(unet_path):
     sd = comfy.utils.load_torch_file(unet_path)
diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py
index 3412cfea030..49087d23e5d 100644
--- a/comfy/supported_models_base.py
+++ b/comfy/supported_models_base.py
@@ -22,6 +22,8 @@ class BASE:
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
 
+    manual_cast_dtype = None
+
     @classmethod
     def matches(s, unet_config):
         for k in s.unet_config:
@@ -71,3 +73,5 @@ def process_vae_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "first_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
 
+    def set_manual_cast(self, manual_cast_dtype):
+        self.manual_cast_dtype = manual_cast_dtype

From b0aab1e4ea3dfefe09c4f07de0e5237558097e22 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Mon, 11 Dec 2023 18:36:29 -0500
Subject: [PATCH 3/3] Add an option --fp16-unet to force using fp16 for the unet.

---
 comfy/cli_args.py         | 1 +
 comfy/model_management.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 58d0348028f..d9c8668f470 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -57,6 +57,7 @@ def __call__(self, parser, namespace, values, option_string=None):
 
 fpunet_group = parser.add_mutually_exclusive_group()
 fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index fe0374a8b2f..b6a9471bfa1 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -466,6 +466,8 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp16_unet:
+        return torch.float16
     if args.fp8_e4m3fn_unet:
         return torch.float8_e4m3fn
     if args.fp8_e5m2_unet: