Skip to content

Commit

Permalink
Merge branch 'master' into beta
Browse files Browse the repository at this point in the history
  • Loading branch information
jn-jairo committed Dec 12, 2023
2 parents f30ff20 + 824e493 commit c186944
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 186 deletions.
30 changes: 15 additions & 15 deletions comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
operations=comfy.ops.disable_weight_init,
**kwargs,
):
super().__init__()
Expand Down Expand Up @@ -141,24 +141,24 @@ def __init__(
)
]
)
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations)])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])

self.input_hint_block = TimestepEmbedSequential(
operations.conv_nd(dims, hint_channels, 16, 3, padding=1),
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 16, 3, padding=1),
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2),
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 32, 3, padding=1),
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2),
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 96, 3, padding=1),
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2),
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
nn.SiLU(),
zero_module(operations.conv_nd(dims, 256, model_channels, 3, padding=1))
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
)

self._feature_size = model_channels
Expand Down Expand Up @@ -206,7 +206,7 @@ def __init__(
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
Expand Down Expand Up @@ -234,7 +234,7 @@ def __init__(
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch, operations=operations))
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
ds *= 2
self._feature_size += ch

Expand Down Expand Up @@ -276,11 +276,11 @@ def __init__(
operations=operations
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
self._feature_size += ch

def make_zero_conv(self, channels, operations=None):
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
Expand Down
2 changes: 1 addition & 1 deletion comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, json_config):
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16

self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init)

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
Expand Down
69 changes: 33 additions & 36 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def __init__(self, device=None):
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (0.0, 1.0)
self.global_average_pooling = False
self.timestep_range = None

if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.global_average_pooling = False

def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
Expand Down Expand Up @@ -77,6 +77,7 @@ def copy_to(self, c):
c.cond_hint_original = self.cond_hint_original
c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range
c.global_average_pooling = self.global_average_pooling

def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
Expand Down Expand Up @@ -129,12 +130,14 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp
return out

class ControlNet(ControlBase):
def __init__(self, control_model, global_average_pooling=False, device=None):
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.load_device = load_device
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
self.manual_cast_dtype = manual_cast_dtype

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
Expand All @@ -149,11 +152,8 @@ def get_control(self, x_noisy, t, cond, batched_number):
return None

dtype = self.control_model.dtype
if comfy.model_management.supports_dtype(self.device, dtype):
precision_scope = lambda a: contextlib.nullcontext(a)
else:
precision_scope = torch.autocast
dtype = torch.float32
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype

output_dtype = x_noisy.dtype
if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
Expand All @@ -171,12 +171,11 @@ def get_control(self, x_noisy, t, cond, batched_number):
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

with precision_scope(comfy.model_management.get_autocast_device(self.device)):
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)

def copy(self):
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
self.copy_to(c)
return c

Expand Down Expand Up @@ -207,10 +206,11 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.bias = None

def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
return torch.nn.functional.linear(input, weight, bias)

class Conv2d(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -246,25 +246,11 @@ def __init__(


def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)

def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")

class Conv3d(comfy.ops.Conv3d):
pass

class GroupNorm(comfy.ops.GroupNorm):
pass

class LayerNorm(comfy.ops.LayerNorm):
pass
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)


class ControlLora(ControlNet):
Expand All @@ -278,10 +264,19 @@ def pre_run(self, model, percent_to_timestep_function):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.manual_cast_dtype = model.manual_cast_dtype
dtype = model.get_dtype()
self.control_model.to(dtype)
if self.manual_cast_dtype is None:
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
pass
else:
class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
pass
dtype = self.manual_cast_dtype

controlnet_config["operations"] = control_lora_ops
controlnet_config["dtype"] = dtype
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
Expand Down Expand Up @@ -385,6 +380,10 @@ def load_controlnet(ckpt_path, model=None):
if controlnet_config is None:
unet_dtype = comfy.model_management.unet_dtype()
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
Expand Down Expand Up @@ -413,14 +412,12 @@ class WeightsLoader(torch.nn.Module):
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
print(missing, unexpected)

control_model = control_model.to(unet_dtype)

global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
global_average_pooling = True

control = ControlNet(control_model, global_average_pooling=global_average_pooling)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control

class T2IAdapter(ControlBase):
Expand Down
13 changes: 7 additions & 6 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init

# CrossAttn precision handling
if args.dont_upcast_attention:
Expand Down Expand Up @@ -55,7 +56,7 @@ def init_(tensor):

# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

Expand All @@ -65,7 +66,7 @@ def forward(self, x):


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
Expand Down Expand Up @@ -363,7 +364,7 @@ def optimized_attention_for_device(device, mask=False):


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
Expand Down Expand Up @@ -396,7 +397,7 @@ def forward(self, x, context=None, value=None, mask=None):

class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
super().__init__()

self.ff_in = ff_in or inner_dim is not None
Expand Down Expand Up @@ -565,7 +566,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
use_checkpoint=True, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
Expand Down Expand Up @@ -639,7 +640,7 @@ def __init__(
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,
Expand Down
Loading

0 comments on commit c186944

Please sign in to comment.