From 77755ab8dbc74f3f231aa817590401d7969f96a4 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 23:27:13 -0500 Subject: [PATCH 1/4] Refactor comfy.ops comfy.ops -> comfy.ops.disable_weight_init This should make it more clear what they actually do. Some unused code has also been removed. --- comfy/cldm/cldm.py | 2 +- comfy/clip_vision.py | 2 +- comfy/controlnet.py | 27 ++---- comfy/ldm/modules/attention.py | 13 +-- comfy/ldm/modules/diffusionmodules/model.py | 39 ++++----- .../modules/diffusionmodules/openaimodel.py | 14 +-- comfy/ldm/modules/diffusionmodules/util.py | 41 --------- comfy/ldm/modules/temporal_ae.py | 5 +- comfy/model_base.py | 2 +- comfy/ops.py | 85 +++++++------------ 10 files changed, 77 insertions(+), 153 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index bbe5891e691..00373a7903f 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -53,7 +53,7 @@ def __init__( transformer_depth_middle=None, transformer_depth_output=None, device=None, - operations=comfy.ops, + operations=comfy.ops.disable_weight_init, **kwargs, ): super().__init__() diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index ae87c75b4d4..ba8a3a8d569 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -38,7 +38,7 @@ def __init__(self, json_config): if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False): self.dtype = torch.float16 - self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops) + self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6d37aa74f69..3212ac8c4b9 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -208,9 +208,9 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True, def forward(self, input): if self.up is not None: - return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) else: - return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias) + return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias) class Conv2d(torch.nn.Module): def __init__( @@ -247,24 +247,9 @@ def __init__( def forward(self, input): if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, 
self.dilation, self.groups) - - def conv_nd(self, dims, *args, **kwargs): - if dims == 2: - return self.Conv2d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - - class Conv3d(comfy.ops.Conv3d): - pass - - class GroupNorm(comfy.ops.GroupNorm): - pass - - class LayerNorm(comfy.ops.LayerNorm): - pass + return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -278,7 +263,9 @@ def pre_run(self, model, percent_to_timestep_function): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - controlnet_config["operations"] = ControlLoraOps() + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + controlnet_config["operations"] = control_lora_ops self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) dtype = model.get_dtype() self.control_model.to(dtype) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 8299b1d94bb..8d86aa53d2e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -19,6 +19,7 @@ from comfy.cli_args import args import comfy.ops +ops = comfy.ops.disable_weight_init # CrossAttn precision handling if args.dont_upcast_attention: @@ -55,7 +56,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops): super().__init__() self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) @@ -65,7 +66,7 @@ def forward(self, x): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) @@ -356,7 +357,7 @@ def optimized_attention_for_device(device, mask=False): class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) @@ -389,7 +390,7 @@ def forward(self, x, context=None, value=None, mask=None): class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, - disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops): super().__init__() self.ff_in = ff_in or inner_dim is not None @@ -558,7 +559,7 @@ class SpatialTransformer(nn.Module): def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, - use_checkpoint=True, dtype=None, device=None, operations=comfy.ops): + use_checkpoint=True, dtype=None, device=None, 
operations=ops): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): context_dim = [context_dim] * depth @@ -632,7 +633,7 @@ def __init__( disable_self_attn=False, disable_temporal_crossattention=False, max_time_embed_period: int = 10000, - dtype=None, device=None, operations=comfy.ops + dtype=None, device=None, operations=ops ): super().__init__( in_channels, diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index f23417fd216..fce29cb85ec 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -8,6 +8,7 @@ from comfy import model_management import comfy.ops +ops = comfy.ops.disable_weight_init if model_management.xformers_enabled_vae(): import xformers @@ -48,7 +49,7 @@ def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, @@ -78,7 +79,7 @@ def __init__(self, in_channels, with_conv): self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves - self.conv = comfy.ops.Conv2d(in_channels, + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, @@ -105,30 +106,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, self.swish = torch.nn.SiLU(inplace=True) self.norm1 = Normalize(in_channels) - self.conv1 = comfy.ops.Conv2d(in_channels, + self.conv1 = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: - self.temb_proj = comfy.ops.Linear(temb_channels, + self.temb_proj = ops.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout, inplace=True) - self.conv2 = comfy.ops.Conv2d(out_channels, + self.conv2 = ops.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - self.conv_shortcut = comfy.ops.Conv2d(in_channels, + self.conv_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: - self.nin_shortcut = comfy.ops.Conv2d(in_channels, + self.nin_shortcut = ops.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, @@ -245,22 +246,22 @@ def __init__(self, in_channels): self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = comfy.ops.Conv2d(in_channels, + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.k = comfy.ops.Conv2d(in_channels, + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.v = comfy.ops.Conv2d(in_channels, + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) - self.proj_out = comfy.ops.Conv2d(in_channels, + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, @@ -312,14 +313,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ - comfy.ops.Linear(self.ch, + ops.Linear(self.ch, self.temb_ch), - comfy.ops.Linear(self.temb_ch, + ops.Linear(self.temb_ch, self.temb_ch), ]) # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -388,7 +389,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, # end self.norm_out = 
Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, out_ch, kernel_size=3, stride=1, @@ -461,7 +462,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.in_channels = in_channels # downsampling - self.conv_in = comfy.ops.Conv2d(in_channels, + self.conv_in = ops.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, @@ -506,7 +507,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, # end self.norm_out = Normalize(block_in) - self.conv_out = comfy.ops.Conv2d(block_in, + self.conv_out = ops.Conv2d(block_in, 2*z_channels if double_z else z_channels, kernel_size=3, stride=1, @@ -541,7 +542,7 @@ class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - conv_out_op=comfy.ops.Conv2d, + conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, **ignorekwargs): @@ -565,7 +566,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, self.z_shape, np.prod(self.z_shape))) # z to block_in - self.conv_in = comfy.ops.Conv2d(z_channels, + self.conv_in = ops.Conv2d(z_channels, block_in, kernel_size=3, stride=1, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 12efd833c51..057dd16b250 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -12,13 +12,13 @@ checkpoint, avg_pool_nd, zero_module, - normalization, timestep_embedding, AlphaBlender, ) from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops +ops = comfy.ops.disable_weight_init class TimestepBlock(nn.Module): """ @@ -70,7 +70,7 @@ class Upsample(nn.Module): upsampling occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -106,7 +106,7 @@ class Downsample(nn.Module): downsampling occurs in the inner-two dimensions. 
""" - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -159,7 +159,7 @@ def __init__( skip_t_emb=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__() self.channels = channels @@ -284,7 +284,7 @@ def __init__( down: bool = False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): super().__init__( channels, @@ -434,7 +434,7 @@ def __init__( disable_temporal_crossattention=False, max_ddpm_temb_period=10000, device=None, - operations=comfy.ops, + operations=ops, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" @@ -581,7 +581,7 @@ def get_resblock( up=False, dtype=None, device=None, - operations=comfy.ops + operations=ops ): if self.use_temporal_resblocks: return VideoResBlock( diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 704bbe57450..68175b62a58 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -16,7 +16,6 @@ from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config -import comfy.ops class AlphaBlender(nn.Module): strategies = ["learned", "fixed", "learned_with_images"] @@ -273,46 +272,6 @@ def mean_flat(tensor): return tensor.mean(dim=list(range(1, len(tensor.shape)))) -def normalization(channels, dtype=None): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels, dtype=dtype) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return comfy.ops.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return comfy.ops.Linear(*args, **kwargs) - - def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. 
diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 11ae049f3be..7ea68dc9e28 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -5,6 +5,7 @@ from einops import rearrange, repeat import comfy.ops +ops = comfy.ops.disable_weight_init from .diffusionmodules.model import ( AttnBlock, @@ -130,9 +131,9 @@ def __init__( time_embed_dim = self.in_channels * 4 self.video_time_embed = torch.nn.Sequential( - comfy.ops.Linear(self.in_channels, time_embed_dim), + ops.Linear(self.in_channels, time_embed_dim), torch.nn.SiLU(), - comfy.ops.Linear(time_embed_dim, self.in_channels), + ops.Linear(time_embed_dim, self.in_channels), ) self.merge_strategy = merge_strategy diff --git a/comfy/model_base.py b/comfy/model_base.py index bab7b9b340d..412c837925c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -48,7 +48,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None): if self.manual_cast_dtype is not None: operations = comfy.ops.manual_cast else: - operations = comfy.ops + operations = comfy.ops.disable_weight_init self.diffusion_model = UNetModel(**unet_config, device=device, operations=operations) self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) diff --git a/comfy/ops.py b/comfy/ops.py index a67bc809fd2..08c63384789 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1,33 +1,35 @@ import torch from contextlib import contextmanager -class Linear(torch.nn.Linear): - def reset_parameters(self): - return None +class disable_weight_init: + class Linear(torch.nn.Linear): + def reset_parameters(self): + return None -class Conv2d(torch.nn.Conv2d): - def reset_parameters(self): - return None + class Conv2d(torch.nn.Conv2d): + def reset_parameters(self): + return None -class Conv3d(torch.nn.Conv3d): - def reset_parameters(self): - return None + class Conv3d(torch.nn.Conv3d): + def reset_parameters(self): + return None -class GroupNorm(torch.nn.GroupNorm): - def reset_parameters(self): - return None + class GroupNorm(torch.nn.GroupNorm): + def reset_parameters(self): + return None -class LayerNorm(torch.nn.LayerNorm): - def reset_parameters(self): - return None + class LayerNorm(torch.nn.LayerNorm): + def reset_parameters(self): + return None -def conv_nd(dims, *args, **kwargs): - if dims == 2: - return Conv2d(*args, **kwargs) - elif dims == 3: - return Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") + @classmethod + def conv_nd(s, dims, *args, **kwargs): + if dims == 2: + return s.Conv2d(*args, **kwargs) + elif dims == 3: + return s.Conv3d(*args, **kwargs) + else: + raise ValueError(f"unsupported dimensions: {dims}") def cast_bias_weight(s, input): bias = None @@ -36,55 +38,28 @@ def cast_bias_weight(s, input): weight = s.weight.to(device=input.device, dtype=input.dtype) return weight, bias -class manual_cast: - class Linear(Linear): +class manual_cast(disable_weight_init): + class Linear(disable_weight_init.Linear): def forward(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.linear(input, weight, bias) - class Conv2d(Conv2d): + class Conv2d(disable_weight_init.Conv2d): def forward(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) - class Conv3d(Conv3d): + class Conv3d(disable_weight_init.Conv3d): def forward(self, input): weight, bias = cast_bias_weight(self, input) return self._conv_forward(input, weight, bias) - class GroupNorm(GroupNorm): + 
class GroupNorm(disable_weight_init.GroupNorm): def forward(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) - class LayerNorm(LayerNorm): + class LayerNorm(disable_weight_init.LayerNorm): def forward(self, input): weight, bias = cast_bias_weight(self, input) return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) - - @classmethod - def conv_nd(s, dims, *args, **kwargs): - if dims == 2: - return s.Conv2d(*args, **kwargs) - elif dims == 3: - return s.Conv3d(*args, **kwargs) - else: - raise ValueError(f"unsupported dimensions: {dims}") - -@contextmanager -def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way - old_torch_nn_linear = torch.nn.Linear - force_device = device - force_dtype = dtype - def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): - if force_device is not None: - device = force_device - if force_dtype is not None: - dtype = force_dtype - return Linear(in_features, out_features, bias=bias, device=device, dtype=dtype) - - torch.nn.Linear = linear_with_dtype - try: - yield - finally: - torch.nn.Linear = old_torch_nn_linear From 3152023fbc4f8ee6598a863314ca98d48ea9c2e6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 11 Dec 2023 23:50:38 -0500 Subject: [PATCH 2/4] Use inference dtype for unet memory usage estimation. --- comfy/model_base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 412c837925c..a7582b330d9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -177,9 +177,12 @@ def set_inpaint(self): def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + dtype = self.get_dtype() + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype #TODO: this needs to be tweaked area = input_shape[0] * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) + return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * input_shape[2] * input_shape[3] From 32b7e7e769c206a06bf6e10ad2ddb6af9a378f56 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 03:32:23 -0500 Subject: [PATCH 3/4] Add manual cast to controlnet. 
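For context, a rough sketch of how the ops classes introduced in patch 1/4 behave and how this patch chooses between them (illustrative only, not part of the patch; assumes a ComfyUI checkout where comfy.ops is importable):

    import torch
    import comfy.ops

    # disable_weight_init layers override reset_parameters() to do nothing,
    # so building a model skips the random init that loaded weights would
    # overwrite anyway.
    # manual_cast layers inherit that and, in forward(), cast weight/bias to
    # the input's dtype and device via cast_bias_weight().
    lin = comfy.ops.manual_cast.Linear(8, 8, dtype=torch.float16, device="cpu")
    x = torch.randn(1, 8)          # float32 input
    y = lin(x)                     # weight cast to float32 for this call only
                                   # (values are arbitrary: weights were never initialized)

    # ControlLora composes its ops the same way in this patch:
    # ControlLoraOps + comfy.ops.manual_cast when a manual cast dtype is set,
    # otherwise ControlLoraOps + comfy.ops.disable_weight_init.

Plain controlnets get the equivalent treatment further down in this patch: load_controlnet() switches controlnet_config["operations"] to comfy.ops.manual_cast whenever unet_manual_cast() returns a dtype for the load device.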
--- comfy/cldm/cldm.py | 28 +++++++++++------------ comfy/controlnet.py | 54 +++++++++++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 36 deletions(-) diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 00373a7903f..5eee5a51c95 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -141,24 +141,24 @@ def __init__( ) ] ) - self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)]) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)]) self.input_hint_block = TimestepEmbedSequential( - operations.conv_nd(dims, hint_channels, 16, 3, padding=1), + operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 16, 3, padding=1), + operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2), + operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 32, 3, padding=1), + operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2), + operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 96, 3, padding=1), + operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2), + operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device), nn.SiLU(), - zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1)) + operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device) ) self._feature_size = model_channels @@ -206,7 +206,7 @@ def __init__( ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: @@ -234,7 +234,7 @@ def __init__( ) ch = out_ch input_block_chans.append(ch) - self.zero_convs.append(self.make_zero_conv(ch, operations=operations)) + self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)) ds *= 2 self._feature_size += ch @@ -276,11 +276,11 @@ def __init__( operations=operations )] self.middle_block = TimestepEmbedSequential(*mid_block) - self.middle_block_out = self.make_zero_conv(ch, operations=operations) + self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device) self._feature_size += ch - def make_zero_conv(self, channels, operations=None): - return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0))) + def make_zero_conv(self, channels, operations=None, dtype=None, device=None): + return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device)) def forward(self, x, hint, timesteps, context, y=None, **kwargs): t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 3212ac8c4b9..110b5c7c290 100644 --- a/comfy/controlnet.py 
+++ b/comfy/controlnet.py @@ -36,13 +36,13 @@ def __init__(self, device=None): self.cond_hint = None self.strength = 1.0 self.timestep_percent_range = (0.0, 1.0) + self.global_average_pooling = False self.timestep_range = None if device is None: device = comfy.model_management.get_torch_device() self.device = device self.previous_controlnet = None - self.global_average_pooling = False def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint @@ -77,6 +77,7 @@ def copy_to(self, c): c.cond_hint_original = self.cond_hint_original c.strength = self.strength c.timestep_percent_range = self.timestep_percent_range + c.global_average_pooling = self.global_average_pooling def inference_memory_requirements(self, dtype): if self.previous_controlnet is not None: @@ -129,12 +130,14 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp return out class ControlNet(ControlBase): - def __init__(self, control_model, global_average_pooling=False, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) self.control_model = control_model - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + self.load_device = load_device + self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.global_average_pooling = global_average_pooling self.model_sampling_current = None + self.manual_cast_dtype = manual_cast_dtype def get_control(self, x_noisy, t, cond, batched_number): control_prev = None @@ -149,11 +152,8 @@ def get_control(self, x_noisy, t, cond, batched_number): return None dtype = self.control_model.dtype - if comfy.model_management.supports_dtype(self.device, dtype): - precision_scope = lambda a: contextlib.nullcontext(a) - else: - precision_scope = torch.autocast - dtype = torch.float32 + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype output_dtype = x_noisy.dtype if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: @@ -171,12 +171,11 @@ def get_control(self, x_noisy, t, cond, batched_number): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - with precision_scope(comfy.model_management.get_autocast_device(self.device)): - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) return self.control_merge(None, control, control_prev, output_dtype) def copy(self): - c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) self.copy_to(c) return c @@ -207,10 +206,11 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True, self.bias = None def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.linear(input, 
self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias) + return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias) + return torch.nn.functional.linear(input, weight, bias) class Conv2d(torch.nn.Module): def __init__( @@ -246,10 +246,11 @@ def __init__( def forward(self, input): + weight, bias = comfy.ops.cast_bias_weight(self, input) if self.up is not None: - return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups) + return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) class ControlLora(ControlNet): @@ -263,12 +264,19 @@ def pre_run(self, model, percent_to_timestep_function): controlnet_config = model.model_config.unet_config.copy() controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1] - class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): - pass + self.manual_cast_dtype = model.manual_cast_dtype + dtype = model.get_dtype() + if self.manual_cast_dtype is None: + class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init): + pass + else: + class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): + pass + dtype = self.manual_cast_dtype + controlnet_config["operations"] = control_lora_ops + controlnet_config["dtype"] = dtype self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) - dtype = model.get_dtype() - self.control_model.to(dtype) self.control_model.to(comfy.model_management.get_torch_device()) diffusion_model = model.diffusion_model sd = diffusion_model.state_dict() @@ -372,6 +380,10 @@ def load_controlnet(ckpt_path, model=None): if controlnet_config is None: unet_dtype = comfy.model_management.unet_dtype() controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config + load_device = comfy.model_management.get_torch_device() + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype is not None: + controlnet_config["operations"] = comfy.ops.manual_cast controlnet_config.pop("out_channels") controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1] control_model = comfy.cldm.cldm.ControlNet(**controlnet_config) @@ -400,14 +412,12 @@ class WeightsLoader(torch.nn.Module): missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False) print(missing, unexpected) - control_model = control_model.to(unet_dtype) - global_average_pooling = False filename = 
os.path.splitext(ckpt_path)[0] if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling global_average_pooling = True - control = ControlNet(control_model, global_average_pooling=global_average_pooling) + control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype) return control class T2IAdapter(ControlBase): From 824e4935f53fdbda8f4608f511b4c2e8daf79dfa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 12 Dec 2023 12:03:29 -0500 Subject: [PATCH 4/4] Add dtype parameter to VAE object. --- comfy/sd.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8c056e4ea2f..220637a05d7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -151,7 +151,7 @@ def get_key_patches(self): return self.patcher.get_key_patches() class VAE: - def __init__(self, sd=None, device=None, config=None): + def __init__(self, sd=None, device=None, config=None, dtype=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) @@ -188,7 +188,9 @@ def __init__(self, sd=None, device=None, config=None): device = model_management.vae_device() self.device = device offload_device = model_management.vae_offload_device() - self.vae_dtype = model_management.vae_dtype() + if dtype is None: + dtype = model_management.vae_dtype() + self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device()
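A minimal usage sketch for the new VAE dtype parameter from patch 4/4 (illustrative, not part of the series; the file path is a placeholder and loading via comfy.utils.load_torch_file is an assumption about the caller):

    import torch
    import comfy.sd
    import comfy.utils

    # Placeholder path for an existing VAE checkpoint.
    vae_state_dict = comfy.utils.load_torch_file("vae.safetensors")

    # Force fp32 regardless of what model_management.vae_dtype() would pick;
    # leaving dtype=None keeps the previous behaviour.
    vae = comfy.sd.VAE(sd=vae_state_dict, dtype=torch.float32)

When dtype is left as None, the VAE still falls back to model_management.vae_dtype(), so existing callers are unaffected.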