
Commit 166f8d6

Merge branch 'master' into beta
2 parents: 694d3bf + 248d912


15 files changed: +802 -278 lines


comfy/clip_model.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+import torch
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+class CLIPAttention(torch.nn.Module):
+    def __init__(self, embed_dim, heads, dtype, device, operations):
+        super().__init__()
+
+        self.heads = heads
+        self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+        self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+        self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+
+        out = optimized_attention(q, k, v, self.heads, mask)
+        return self.out_proj(out)
+
+ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
+               "gelu": torch.nn.functional.gelu,
+               }
+
+class CLIPMLP(torch.nn.Module):
+    def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
+        super().__init__()
+        self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
+        self.activation = ACTIVATIONS[activation]
+        self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.activation(x)
+        x = self.fc2(x)
+        return x
+
+class CLIPLayer(torch.nn.Module):
+    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
+        self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+        self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
+
+    def forward(self, x, mask=None, optimized_attention=None):
+        x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
+        x += self.mlp(self.layer_norm2(x))
+        return x
+
+
+class CLIPEncoder(torch.nn.Module):
+    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
+        super().__init__()
+        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
+
+    def forward(self, x, mask=None, intermediate_output=None):
+        optimized_attention = optimized_attention_for_device(x.device, mask=True)
+        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        if mask is not None:
+            mask += causal_mask
+        else:
+            mask = causal_mask
+
+        if intermediate_output is not None:
+            if intermediate_output < 0:
+                intermediate_output = len(self.layers) + intermediate_output
+
+        intermediate = None
+        for i, l in enumerate(self.layers):
+            x = l(x, mask, optimized_attention)
+            if i == intermediate_output:
+                intermediate = x.clone()
+        return x, intermediate
+
+class CLIPEmbeddings(torch.nn.Module):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
+        super().__init__()
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens):
+        return self.token_embedding(input_tokens) + self.position_embedding.weight
+
+
+class CLIPTextModel_(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        num_layers = config_dict["num_hidden_layers"]
+        embed_dim = config_dict["hidden_size"]
+        heads = config_dict["num_attention_heads"]
+        intermediate_size = config_dict["intermediate_size"]
+        intermediate_activation = config_dict["hidden_act"]
+
+        super().__init__()
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
+        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
+        self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
+
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+        x = self.embeddings(input_tokens)
+        #TODO: attention_mask
+        x, i = self.encoder(x, intermediate_output=intermediate_output)
+        x = self.final_layer_norm(x)
+        if i is not None and final_layer_norm_intermediate:
+            i = self.final_layer_norm(i)
+
+        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
+        return x, i, pooled_output
+
+class CLIPTextModel(torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        self.num_layers = config_dict["num_hidden_layers"]
+        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
+        self.dtype = dtype
+
+    def get_input_embeddings(self):
+        return self.text_model.embeddings.token_embedding
+
+    def set_input_embeddings(self, embeddings):
+        self.text_model.embeddings.token_embedding = embeddings
+
+    def forward(self, *args, **kwargs):
+        return self.text_model(*args, **kwargs)
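
Note: this new in-repo text encoder is built from a plain config dict plus dtype, device, and an operations module, and its forward returns a (last_hidden_state, intermediate, pooled_output) tuple. A minimal usage sketch, assuming the script runs from the repository root where comfy/sd1_clip_config.json lives and that comfy.ops is passed as the operations module, as sd1_clip.py does below:

import json
import torch
import comfy.ops
from comfy.clip_model import CLIPTextModel

# Load the SD1 text-encoder config (field names match the keys read above:
# num_hidden_layers, hidden_size, num_attention_heads, intermediate_size, hidden_act).
with open("comfy/sd1_clip_config.json") as f:
    config = json.load(f)

# Weights are left uninitialized by comfy.ops; in practice a checkpoint
# state dict is loaded into the model before use.
model = CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=comfy.ops)

tokens = torch.randint(0, 49407, (1, 77))   # a batch of 77 CLIP token ids
x, intermediate, pooled = model(tokens, intermediate_output=-2)
# x: final hidden states, intermediate: penultimate-layer output, pooled: EOS-token embedding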

comfy/ldm/modules/attention.py

Lines changed: 20 additions & 5 deletions
@@ -112,10 +112,13 @@ def attention_basic(q, k, v, heads, mask=None):
     del q, k

     if exists(mask):
-        mask = rearrange(mask, 'b ... -> b (...)')
-        max_neg_value = -torch.finfo(sim.dtype).max
-        mask = repeat(mask, 'b j -> (b h) () j', h=h)
-        sim.masked_fill_(~mask, max_neg_value)
+        if mask.dtype == torch.bool:
+            mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = repeat(mask, 'b j -> (b h) () j', h=h)
+            sim.masked_fill_(~mask, max_neg_value)
+        else:
+            sim += mask

     # attention, what we cannot get enough of
     sim = sim.softmax(dim=-1)
@@ -347,6 +350,18 @@ def attention_pytorch(q, k, v, heads, mask=None):
 if model_management.pytorch_attention_enabled():
     optimized_attention_masked = attention_pytorch

+def optimized_attention_for_device(device, mask=False):
+    if device == torch.device("cpu"): #TODO
+        if model_management.pytorch_attention_enabled():
+            return attention_pytorch
+        else:
+            return attention_basic
+    if mask:
+        return optimized_attention_masked
+
+    return optimized_attention
+
+
 class CrossAttention(nn.Module):
     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
         super().__init__()
@@ -391,7 +406,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=
         self.is_res = inner_dim == dim

         if self.ff_in:
-            self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device)
+            self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
             self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)

         self.disable_self_attn = disable_self_attn
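
Note: optimized_attention_for_device resolves an attention kernel once per device, and attention_basic now also accepts an additive float mask (such as the causal mask built in clip_model.py) alongside boolean masks. A minimal sketch of the intended call pattern, assuming CPU and self-attention-shaped q/k/v:

import torch
from comfy.ldm.modules.attention import optimized_attention_for_device

device = torch.device("cpu")
attn = optimized_attention_for_device(device, mask=True)   # resolved once, reused per call

heads = 8
q = k = v = torch.randn(2, 77, 512, device=device)

# Additive float mask: -inf above the diagonal gives causal attention.
causal_mask = torch.empty(77, 77, device=device).fill_(float("-inf")).triu_(1)

out = attn(q, k, v, heads, causal_mask)   # same signature for every backend
print(out.shape)                          # torch.Size([2, 77, 512])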

comfy/model_management.py

Lines changed: 4 additions & 4 deletions
@@ -564,12 +564,12 @@ def cast_to_device(tensor, device, dtype, copy=False):
     if device_supports_cast:
         if copy:
             if tensor.device == device:
-                return tensor.to(dtype, copy=copy)
-            return tensor.to(device, copy=copy).to(dtype)
+                return tensor.to(dtype, copy=copy, non_blocking=True)
+            return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
         else:
-            return tensor.to(device).to(dtype)
+            return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
     else:
-        return tensor.to(dtype).to(device, copy=copy)
+        return tensor.to(device, dtype, copy=copy, non_blocking=True)

 def xformers_enabled():
     global directml_enabled
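
Note: the only change here is that every .to() in cast_to_device now passes non_blocking=True, which lets host-to-device copies overlap with compute when the source tensor is in pinned memory. A standalone sketch of the same pattern (not the ComfyUI function itself):

import torch

def cast_sketch(tensor, device, dtype):
    # Asynchronous copy: returns immediately when moving pinned CPU memory to CUDA,
    # then casts the dtype on the target device.
    return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)

if torch.cuda.is_available():
    x = torch.randn(1024, 1024).pin_memory()
    y = cast_sketch(x, torch.device("cuda"), torch.float16)
    torch.cuda.synchronize()   # wait for the async transfer before relying on y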

comfy/sd1_clip.py

Lines changed: 22 additions & 40 deletions
@@ -1,12 +1,14 @@
 import os

-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig, modeling_utils
+from transformers import CLIPTokenizer
 import comfy.ops
 import torch
 import traceback
 import zipfile
 from . import model_management
 import contextlib
+import comfy.clip_model
+import json

 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -65,35 +67,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
-                 special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
-                 model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True): # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
-        self.num_layers = 12
-        if textmodel_path is not None:
-            self.transformer = model_class.from_pretrained(textmodel_path)
-        else:
-            if textmodel_json_config is None:
-                textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = config_class.from_json_file(textmodel_json_config)
-            self.num_layers = config.num_hidden_layers
-            with comfy.ops.use_comfy_ops(device, dtype):
-                with modeling_utils.no_init_weights():
-                    self.transformer = model_class(config)
-
-        self.inner_name = inner_name
-        if dtype is not None:
-            inner_model = getattr(self.transformer, self.inner_name)
-            if hasattr(inner_model, "embeddings"):
-                embeddings_bak = inner_model.embeddings.to(torch.float32)
-                inner_model.embeddings = None
-                self.transformer.to(dtype)
-                inner_model.embeddings = embeddings_bak
-            else:
-                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
-                self.transformer.to(dtype)
-                self.transformer.set_input_embeddings(previous_inputs)
+
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+
+        with open(textmodel_json_config) as f:
+            config = json.load(f)
+
+        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.num_layers = self.transformer.num_layers

         self.max_length = max_length
         if freeze:
@@ -108,7 +94,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
         self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
-            assert abs(layer_idx) <= self.num_layers
+            assert abs(layer_idx) < self.num_layers
             self.clip_layer(layer_idx)
         self.layer_default = (self.layer, self.layer_idx)

@@ -119,7 +105,7 @@ def freeze(self):
             param.requires_grad = False

     def clip_layer(self, layer_idx):
-        if abs(layer_idx) >= self.num_layers:
+        if abs(layer_idx) > self.num_layers:
             self.layer = "last"
         else:
             self.layer = "hidden"
@@ -174,7 +160,7 @@ def forward(self, tokens):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)

-        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
+        if self.transformer.dtype != torch.float32:
            precision_scope = torch.autocast
         else:
             precision_scope = lambda a, dtype: contextlib.nullcontext(a)
@@ -190,20 +176,16 @@ def forward(self, tokens):
                     if tokens[x, y] == max_token:
                         break

-            outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask, output_hidden_states=self.layer=="hidden")
+            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
             self.transformer.set_input_embeddings(backup_embeds)

             if self.layer == "last":
-                z = outputs.last_hidden_state
-            elif self.layer == "pooled":
-                z = outputs.pooler_output[:, None, :]
+                z = outputs[0]
             else:
-                z = outputs.hidden_states[self.layer_idx]
-                if self.layer_norm_hidden_state:
-                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
+                z = outputs[1]

-            if hasattr(outputs, "pooler_output"):
-                pooled_output = outputs.pooler_output.float()
+            if outputs[2] is not None:
+                pooled_output = outputs[2].float()
             else:
                 pooled_output = None
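
Note: SDClipModel now indexes the (last_hidden_state, intermediate, pooled_output) tuple returned by comfy.clip_model, and layer_idx is passed straight through as intermediate_output, with negative values resolved inside CLIPEncoder. A small sketch of that index arithmetic, assuming the 12-layer SD1 config:

# Mirrors the negative-index handling in CLIPEncoder.forward (comfy/clip_model.py).
num_layers = 12                      # SD1 text encoder depth (assumed from sd1_clip_config.json)
layer_idx = -2                       # "penultimate", as SD2/SDXL request below

assert abs(layer_idx) < num_layers   # the tightened bound enforced in SDClipModel.__init__

intermediate_output = layer_idx
if intermediate_output < 0:
    intermediate_output = num_layers + intermediate_output

print(intermediate_output)           # 10: the hidden state captured after layer index 10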

comfy/sd2_clip.py

Lines changed: 3 additions & 3 deletions
@@ -3,13 +3,13 @@
 import os

 class SD2ClipHModel(sd1_clip.SDClipModel):
-    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
-            layer_idx=23
+            layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):

comfy/sdxl_clip.py

Lines changed: 3 additions & 3 deletions
@@ -3,13 +3,13 @@
 import os

 class SDXLClipG(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
+    def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None):
         if layer == "penultimate":
             layer="hidden"
             layer_idx=-2

         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
                          special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)

     def load_sd(self, sd):
@@ -37,7 +37,7 @@ def untokenize(self, token_weight_pair):
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
         self.clip_g = SDXLClipG(device=device, dtype=dtype)

     def clip_layer(self, layer_idx):

comfy_extras/nodes_hypertile.py

Lines changed: 6 additions & 9 deletions
@@ -2,18 +2,18 @@

 import math
 from einops import rearrange
-import random
+# Use torch rng for consistency across generations
+from torch import randint

-def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
+def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
     min_value = min(min_value, value)

     # All big divisors of value (inclusive)
     divisors = [i for i in range(min_value, value + 1) if value % i == 0]

     ns = [value // i for i in divisors[:max_options]] # has at least 1 element

-    random.seed(counter)
-    idx = random.randint(0, len(ns) - 1)
+    idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()

     return ns[idx]

@@ -42,7 +42,6 @@ def patch(self, model, tile_size, swap_size, max_depth, scale_depth):

         latent_tile_size = max(32, tile_size) // 8
         self.temp = None
-        self.counter = 1

         def hypertile_in(q, k, v, extra_options):
             if q.shape[-1] in apply_to:
@@ -53,10 +52,8 @@ def hypertile_in(q, k, v, extra_options):
                 h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

                 factor = 2**((q.shape[-1] // model_channels) - 1) if scale_depth else 1
-                nh = random_divisor(h, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
-                nw = random_divisor(w, latent_tile_size * factor, swap_size, self.counter)
-                self.counter += 1
+                nh = random_divisor(h, latent_tile_size * factor, swap_size)
+                nw = random_divisor(w, latent_tile_size * factor, swap_size)

                 if nh * nw > 1:
                     q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
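
Note: the divisor choice now draws from torch's global RNG instead of a re-seeded Python random with a per-call counter, so it follows torch.manual_seed like the rest of a generation. A small sketch of why the switch matters, using a hypothetical list of divisor options:

import torch
from torch import randint

# torch.randint draws from torch's global generator, so seeding once makes the
# hypertile divisor choice reproducible for a whole generation.
torch.manual_seed(0)
ns = [3, 2, 1]                                           # example divisor options
idx = randint(low=0, high=len(ns) - 1, size=(1,)).item()
print(ns[idx])                                           # same value every run for seed 0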
