Merge branch 'master' into beta
jn-jairo committed Dec 12, 2023
2 parents c05e5dd + b0aab1e commit f30ff20
Showing 8 changed files with 67 additions and 28 deletions.
1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -58,6 +58,7 @@ def __call__(self, parser, namespace, values, option_string=None):

fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

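The new --fp16-unet flag joins bf16 and the two fp8 options in a mutually exclusive group, so argparse itself rejects conflicting dtype flags. A minimal, self-contained sketch of how such flags typically resolve to a torch dtype (the resolve_unet_dtype helper is hypothetical; the real priority order lives in unet_dtype() in comfy/model_management.py below, and the fp8 dtypes need PyTorch 2.1+):

import argparse

import torch

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--bf16-unet", action="store_true")
group.add_argument("--fp16-unet", action="store_true")
group.add_argument("--fp8_e4m3fn-unet", action="store_true")
group.add_argument("--fp8_e5m2-unet", action="store_true")

def resolve_unet_dtype(args):
    # hypothetical helper mirroring the priority order in unet_dtype()
    if args.bf16_unet:
        return torch.bfloat16
    if args.fp16_unet:
        return torch.float16
    if args.fp8_e4m3fn_unet:
        return torch.float8_e4m3fn
    if args.fp8_e5m2_unet:
        return torch.float8_e5m2
    return None  # fall through to automatic selection

args = parser.parse_args(["--fp16-unet"])
print(resolve_unet_dtype(args))  # torch.float16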
19 changes: 10 additions & 9 deletions comfy/model_base.py
@@ -4,6 +4,7 @@
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
import comfy.ops
from enum import Enum
import contextlib
from . import utils
@@ -41,9 +42,14 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None):
unet_config = model_config.unet_config
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype

if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device)
if self.manual_cast_dtype is not None:
operations = comfy.ops.manual_cast
else:
operations = comfy.ops
self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations)
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

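The important pattern in this hunk is dependency injection of the layer factory: UNetModel now receives an operations namespace and builds every layer through it, so a single flag swaps plain layers for cast-on-the-fly ones. A toy version of the idea, with simplified stand-ins for the comfy.ops classes (the real manual_cast classes also handle Conv and Norm layers):

import torch
import torch.nn as nn

class default_ops:
    Linear = nn.Linear  # weights already in the compute dtype

class manual_cast_ops:
    # stand-in for comfy.ops.manual_cast: keep low-precision storage,
    # cast weights to the input dtype on every forward pass
    class Linear(nn.Linear):
        def forward(self, input):
            weight = self.weight.to(input.dtype)
            bias = self.bias.to(input.dtype) if self.bias is not None else None
            return nn.functional.linear(input, weight, bias)

class TinyUNet(nn.Module):
    def __init__(self, operations=default_ops):
        super().__init__()
        self.proj = operations.Linear(8, 8)  # every layer uses the injected factory

manual_cast_dtype = torch.float16  # e.g. fp8 storage, fp16 compute
operations = manual_cast_ops if manual_cast_dtype is not None else default_ops
model = TinyUNet(operations=operations)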
@@ -63,11 +69,8 @@ def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, trans
context = c_crossattn
dtype = self.get_dtype()

if comfy.model_management.supports_dtype(xc.device, dtype):
precision_scope = lambda a: contextlib.nullcontext(a)
else:
precision_scope = torch.autocast
dtype = torch.float32
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
@@ -79,9 +82,7 @@
extra = extra.to(dtype)
extra_conds[o] = extra

with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()

model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)

def get_dtype(self):
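Net effect in apply_model: the torch.autocast fallback is gone. When manual_cast_dtype is set, the inputs and extra conds are cast to that dtype explicitly and the model runs without an autocast context. A condensed, illustrative comparison of the two strategies (not the actual ComfyUI code paths):

import contextlib

import torch

def forward_old(model, x, dtype, supports_dtype):
    # old: let autocast pick per-op precision when the dtype is unsupported
    if supports_dtype:
        scope = contextlib.nullcontext()
    else:
        scope = torch.autocast("cuda")
        x = x.float()
    with scope:
        return model(x)

def forward_new(model, x, dtype, manual_cast_dtype):
    # new: choose one compute dtype up front and cast the inputs once;
    # the manual_cast layers (comfy/ops.py) cast the weights to match
    if manual_cast_dtype is not None:
        dtype = manual_cast_dtype
    return model(x.to(dtype))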
18 changes: 17 additions & 1 deletion comfy/model_management.py
@@ -480,6 +480,8 @@ def unet_inital_load_device(parameters, dtype):
def unet_dtype(device=None, model_params=0):
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
return torch.float16
if args.fp8_e4m3fn_unet:
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
@@ -488,6 +490,20 @@ def unet_dtype(device=None, model_params=0):
return torch.float16
return torch.float32

# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device):
if weight_dtype == torch.float32:
return None

fp16_supported = comfy.model_management.should_use_fp16(inference_device, prioritize_performance=False)
if fp16_supported and weight_dtype == torch.float16:
return None

if fp16_supported:
return torch.float16
else:
return torch.float32

def text_encoder_offload_device():
if args.gpu_only:
return get_torch_device()
@@ -552,7 +568,7 @@ def get_autocast_device(dev):
def supports_dtype(device, dtype): #TODO
if dtype == torch.float32:
return True
if torch.device("cpu") == device:
if is_device_cpu(device):
return False
if dtype == torch.float16:
return True
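unet_manual_cast boils down to a small decision table: fp32 weights never need a cast, fp16 weights need one only on devices without usable fp16, and everything else (bf16, fp8) computes in fp16 where supported, otherwise fp32. A sketch with should_use_fp16 replaced by a plain flag (the fp8 dtype again assumes PyTorch 2.1+):

import torch

def unet_manual_cast_sketch(weight_dtype, fp16_supported):
    # mirrors the logic above, with should_use_fp16() stubbed out
    if weight_dtype == torch.float32:
        return None
    if fp16_supported and weight_dtype == torch.float16:
        return None
    return torch.float16 if fp16_supported else torch.float32

# fp8 weights on a card with fast fp16 -> compute in fp16
assert unet_manual_cast_sketch(torch.float8_e4m3fn, True) == torch.float16
# fp16 weights on a card without usable fp16 -> compute in fp32
assert unet_manual_cast_sketch(torch.float16, False) == torch.float32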
9 changes: 9 additions & 0 deletions comfy/ops.py
@@ -62,6 +62,15 @@ def forward(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
return s.Conv2d(*args, **kwargs)
elif dims == 3:
return s.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")

@contextmanager
def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
old_torch_nn_linear = torch.nn.Linear
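The conv_nd classmethod lets dimension-agnostic model code build 2D or 3D convolutions from the same namespace it was handed, so the cast-aware Conv variants come along for free. A self-contained toy version using plain nn layers in place of the cast-aware ones:

import torch.nn as nn

class ops:
    Conv2d = nn.Conv2d
    Conv3d = nn.Conv3d

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        # dispatch on dims; s.Conv2d resolves within whichever class is used
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")

conv = ops.conv_nd(2, 4, 8, kernel_size=3, padding=1)  # -> an nn.Conv2d instance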
12 changes: 10 additions & 2 deletions comfy/sd.py
@@ -464,11 +464,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

parameters = comfy.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

class WeightsLoader(torch.nn.Module):
pass

model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
if model_config is None:
    raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))

model_config.set_manual_cast(manual_cast_dtype)

@@ -501,7 +505,7 @@ class WeightsLoader(torch.nn.Module):
print("left over keys:", left_over)

if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
print("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@@ -512,6 +516,9 @@ class WeightsLoader(torch.nn.Module):
def load_unet_state_dict(sd): #load unet in diffusers format
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

if "input_blocks.0.0.weight" in sd: #ldm
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
if model_config is None:
@@ -532,13 +539,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
else:
print(diffusers_keys[k], k)
offload_device = model_management.unet_offload_device()
model_config.set_manual_cast(manual_cast_dtype)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
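Both loaders follow the same new sequence: compute the storage dtype, ask unet_manual_cast whether a separate compute dtype is needed for the load device, record it on the model config, and reuse load_device for the ModelPatcher instead of calling get_torch_device() twice. Compressed into a runnable sketch with stubbed dependencies (all stubs are hypothetical; the names follow the diff):

import torch

# --- stubs standing in for model_management / model_detection ---
def unet_dtype(model_params=0):
    return torch.float16

def get_torch_device():
    return torch.device("cpu")

def unet_manual_cast(weight_dtype, inference_device):
    return None if weight_dtype == torch.float32 else torch.float32

class ModelConfig:
    manual_cast_dtype = None
    def set_manual_cast(self, dtype):
        self.manual_cast_dtype = dtype

# --- the loader sequence, as in load_checkpoint_guess_config above ---
weight_dtype = unet_dtype(model_params=0)
load_device = get_torch_device()                 # computed once...
manual_cast_dtype = unet_manual_cast(weight_dtype, load_device)

model_config = ModelConfig()
model_config.set_manual_cast(manual_cast_dtype)  # carried into BaseModel.__init__
# ...and reused for the patcher instead of a second get_torch_device() call:
# ModelPatcher(model, load_device=load_device, offload_device=...)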
4 changes: 4 additions & 0 deletions comfy/supported_models_base.py
@@ -22,6 +22,8 @@ class BASE:
sampling_settings = {}
latent_format = latent_formats.LatentFormat

manual_cast_dtype = None

@classmethod
def matches(s, unet_config):
for k in s.unet_config:
@@ -71,3 +73,5 @@ def process_vae_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "first_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

def set_manual_cast(self, manual_cast_dtype):
self.manual_cast_dtype = manual_cast_dtype
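Note the pairing with the class-level default manual_cast_dtype = None added at the top of BASE: configs that never call set_manual_cast still expose a readable attribute, so downstream code (BaseModel.__init__ in model_base.py) can check it unconditionally. A small illustration of the pattern (BASE_sketch is a stand-in, not the real class):

class BASE_sketch:
    manual_cast_dtype = None  # class default: safe to read even if never set

    def set_manual_cast(self, manual_cast_dtype):
        self.manual_cast_dtype = manual_cast_dtype  # instance-level override

cfg = BASE_sketch()
assert cfg.manual_cast_dtype is None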
25 changes: 11 additions & 14 deletions web/extensions/core/undoRedo.js
@@ -71,24 +71,21 @@ function graphEqual(a, b, root = true) {
}

const undoRedo = async (e) => {
const updateState = async (source, target) => {
const prevState = source.pop();
if (prevState) {
target.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState, false);
activeState = prevState;
}
}
if (e.ctrlKey || e.metaKey) {
if (e.key === "y") {
const prevState = redo.pop();
if (prevState) {
undo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
updateState(redo, undo);
return true;
} else if (e.key === "z") {
const prevState = undo.pop();
if (prevState) {
redo.push(activeState);
isOurLoad = true;
await app.loadGraphData(prevState);
activeState = prevState;
}
updateState(undo, redo);
return true;
}
}
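The refactor deduplicates the two key handlers into updateState, and the new second argument to loadGraphData (false) keeps runtime state such as images intact across undo/redo. The underlying stack discipline, shown as a minimal Python sketch purely for illustration (the real implementation is the JavaScript above):

undo, redo = [], []
active = "state0"

def update_state(source, target):
    # pop the state to restore; push the current one onto the opposite stack
    global active
    if source:
        prev = source.pop()
        target.append(active)
        active = prev  # loadGraphData(prev, false) in the JS version

undo.append(active); active = "state1"   # record an edit
update_state(undo, redo)                 # Ctrl+Z: back to state0
assert active == "state0" and redo == ["state1"]
update_state(redo, undo)                 # Ctrl+Y: forward to state1
assert active == "state1"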
7 changes: 5 additions & 2 deletions web/scripts/app.js
@@ -1564,9 +1564,12 @@ export class ComfyApp {
/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
* @param { boolean } clean Whether the graph state, e.g. images, should be cleared before loading
*/
async loadGraphData(graphData) {
this.clean();
async loadGraphData(graphData, clean = true) {
if (clean !== false) {
this.clean();
}

let reset_invalid_values = false;
if (!graphData) {
