Commit bbf4a40

Merge branch 'master' into vae-fallback-cpu

jn-jairo committed Nov 11, 2023
2 parents f48f2f8 + 248aa3e commit bbf4a40
Showing 64 changed files with 8,469 additions and 2,873 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/test-ui.yaml
@@ -0,0 +1,26 @@
name: Tests CI

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
    - uses: actions/checkout@v4
    - uses: actions/setup-node@v3
      with:
        node-version: 18
    - uses: actions/setup-python@v4
      with:
        python-version: '3.10'
    - name: Install requirements
      run: |
        python -m pip install --upgrade pip
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
        pip install -r requirements.txt
    - name: Run Tests
      run: |
        npm ci
        npm run test:generate
        npm test
      working-directory: ./tests-ui
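
The same suite can be run locally: install the Python requirements as above, then run npm ci, npm run test:generate, and npm test from the tests-ui directory.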
1 change: 1 addition & 0 deletions .gitignore
@@ -14,3 +14,4 @@ venv/
/web/extensions/*
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json
47 changes: 26 additions & 21 deletions comfy/cldm/cldm.py
@@ -27,7 +27,6 @@ def __init__(
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
@@ -52,6 +51,7 @@ def __init__(
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
):
@@ -79,29 +79,24 @@ def __init__(
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]

if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks

if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.")

self.attention_resolutions = attention_resolutions
transformer_depth = transformer_depth[:]

self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
@@ -180,11 +175,14 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
operations=operations
dtype=self.dtype,
device=device,
operations=operations,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_transformers = transformer_depth.pop(0)
if num_transformers > 0:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
@@ -201,9 +199,9 @@ def __init__(
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -223,11 +221,13 @@ def __init__(
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
dtype=self.dtype,
device=device,
operations=operations
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
)
)
)
@@ -245,20 +245,23 @@ def __init__(
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
mid_block = [
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
SpatialTransformer( # always uses a self-attn
)]
if transformer_depth_middle >= 0:
mid_block += [SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint, operations=operations
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
),
ResBlock(
ch,
@@ -267,9 +270,11 @@ def __init__(
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
dtype=self.dtype,
device=device,
operations=operations
),
)
)]
self.middle_block = TimestepEmbedSequential(*mid_block)
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
self._feature_size += ch

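The net effect of this cldm.py refactor: attention placement is no longer driven by attention_resolutions. Each input resblock instead pops the next entry from a per-block transformer_depth list, a zero entry skips the SpatialTransformer at that block, and a negative transformer_depth_middle now omits the middle transformer entirely. A minimal sketch of that bookkeeping, with hypothetical channel and depth values:

transformer_depth = [1, 1, 2, 2, 10, 10]  # hypothetical: one entry per input resblock, consumed in order
blocks = []
for mult in (1, 2, 4):
    for _ in range(2):  # two resblocks per level in this sketch
        blocks.append(f"ResBlock(ch={320 * mult})")
        num_transformers = transformer_depth.pop(0)
        if num_transformers > 0:  # a zero entry disables attention at this block
            blocks.append(f"SpatialTransformer(depth={num_transformers})")
print(blocks)
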
2 changes: 2 additions & 0 deletions comfy/cli_args.py
@@ -36,6 +36,8 @@ def __call__(self, parser, namespace, values, option_string=None):
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")

parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
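The new --max-upload-size flag takes a size in MB; launching the server with, for example, --max-upload-size 500 raises the upload cap from its 100 MB default.
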
36 changes: 19 additions & 17 deletions comfy/clip_vision.py
@@ -1,12 +1,24 @@
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
from .utils import load_torch_file, transformers_convert
from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
from .utils import load_torch_file, transformers_convert, common_upscale
import os
import torch
import contextlib

import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.utils

def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
scale = (size / min(image.shape[1], image.shape[2]))
image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
h = (image.shape[2] - size)//2
w = (image.shape[3] - size)//2
image = image[:,:,h:h+size,w:w+size]
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])

class ClipVisionModel():
def __init__(self, json_config):
@@ -23,25 +35,12 @@ def __init__(self, json_config):
self.model.to(self.dtype)

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.processor = CLIPImageProcessor(crop_size=224,
do_center_crop=True,
do_convert_rgb=True,
do_normalize=True,
do_resize=True,
image_mean=[ 0.48145466,0.4578275,0.40821073],
image_std=[0.26862954,0.26130258,0.27577711],
resample=3, #bicubic
size=224)

def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False)

def encode_image(self, image):
img = torch.clip((255. * image), 0, 255).round().int()
img = list(map(lambda a: a, img))
inputs = self.processor(images=img, return_tensors="pt")
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = inputs['pixel_values'].to(self.load_device)
pixel_values = clip_preprocess(image.to(self.load_device))

if self.dtype != torch.float32:
precision_scope = torch.autocast
@@ -92,8 +91,11 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
else:
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
else:
return None

clip = ClipVisionModel(json_config)
m, u = clip.load_sd(sd)
if len(m) > 0:
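The new clip_preprocess replaces the Hugging Face CLIPImageProcessor with a pure-torch equivalent: bicubic resize so the short side is 224, center crop, re-quantize to 8-bit levels (matching the old PIL-based pipeline), then normalize with the CLIP mean and std. A minimal sketch of its behavior on a ComfyUI-style [batch, height, width, channel] image tensor, assuming a ComfyUI checkout on the Python path; shapes are illustrative:

import torch
from comfy.clip_vision import clip_preprocess

image = torch.rand(1, 512, 768, 3)      # ComfyUI images are [B, H, W, C] floats in 0..1
pixel_values = clip_preprocess(image)   # short side scaled to 224, center-cropped, normalized
print(pixel_values.shape)               # torch.Size([1, 3, 224, 224])
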
79 changes: 79 additions & 0 deletions comfy/conds.py
@@ -0,0 +1,79 @@
import enum
import torch
import math
import comfy.utils


def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
return abs(a*b) // math.gcd(a, b)

class CONDRegular:
def __init__(self, cond):
self.cond = cond

def _copy_with(self, cond):
return self.__class__(cond)

def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))

def can_concat(self, other):
if self.cond.shape != other.cond.shape:
return False
return True

def concat(self, others):
conds = [self.cond]
for x in others:
conds.append(x.cond)
return torch.cat(conds)

class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))


class CONDCrossAttn(CONDRegular):
def can_concat(self, other):
s1 = self.cond.shape
s2 = other.cond.shape
if s1 != s2:
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False

mult_min = lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
return True

def concat(self, others):
conds = [self.cond]
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
conds.append(c)

out = []
for c in conds:
if c.shape[1] < crossattn_max_len:
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)

class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond

def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)

def can_concat(self, other):
if self.cond != other.cond:
return False
return True

def concat(self, others):
return self.cond
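
CONDCrossAttn exists so prompts of different token lengths can share one batch: as the comment above notes, repeating a cross-attention conditioning along the sequence axis doesn't change the attention result, so each tensor is repeated up to the least common multiple of the lengths, with the repeat factor capped at 4. A small sketch, assuming a ComfyUI checkout on the Python path and illustrative shapes:

import torch
from comfy.conds import CONDCrossAttn

a = CONDCrossAttn(torch.randn(1, 77, 768))   # one 77-token prompt
b = CONDCrossAttn(torch.randn(1, 154, 768))  # a longer, 154-token prompt
print(a.can_concat(b))      # True: lcm(77, 154) = 154, repeat factor 2 <= 4
print(a.concat([b]).shape)  # torch.Size([2, 154, 768]); a was repeated to 154 tokens
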
18 changes: 15 additions & 3 deletions comfy/controlnet.py
@@ -132,6 +132,7 @@ def __init__(self, control_model, global_average_pooling=False, device=None):
self.control_model = control_model
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling
self.model_sampling_current = None

def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -156,10 +157,13 @@ def get_control(self, x_noisy, t, cond, batched_number):


context = cond['c_crossattn']
y = cond.get('c_adm', None)
y = cond.get('y', None)
if y is not None:
y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype)

def copy(self):
Expand All @@ -172,6 +176,14 @@ def get_models(self):
out.append(self.control_model_wrapped)
return out

def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling

def cleanup(self):
self.model_sampling_current = None
super().cleanup()

class ControlLoraOps:
class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -416,7 +428,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
if control_prev is not None:
return control_prev
else:
return {}
return None

if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
if self.cond_hint is not None:
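The controlnet.py change accounts for t now being a sigma rather than a discrete timestep under the refactored sampling code: pre_run stores the wrapped model's model_sampling object, and get_control uses it to recover the timestep and apply the base model's input scaling before the ControlNet runs (cleanup then drops the reference). A self-contained sketch with a hypothetical stand-in sampling object; the formulas are illustrative of eps-prediction models, not taken from this diff:

import torch

class DummyModelSampling:  # stand-in for model.model_sampling
    sigmas = torch.linspace(0.03, 14.6, 1000)
    def timestep(self, sigma):
        return (self.sigmas - sigma).abs().argmin()   # nearest discrete timestep
    def calculate_input(self, sigma, noisy):
        return noisy / (sigma * sigma + 1.0) ** 0.5   # eps-style input scaling

ms = DummyModelSampling()
t = torch.tensor(7.0)                # a sigma, not a timestep
x_noisy = torch.randn(1, 4, 64, 64)
print(ms.timestep(t))                # e.g. tensor(478): what the control model now receives
print(ms.calculate_input(t, x_noisy).std())  # pre-scaled latent
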
(Diff truncated: the remaining changed files are not shown.)