Commit

Merge branch 'master' into beta
jn-jairo committed Nov 8, 2023
2 parents d93cab1 + 2a23ba0 commit 0bee4d8
Showing 9 changed files with 168 additions and 46 deletions.
4 changes: 2 additions & 2 deletions comfy/ldm/modules/diffusionmodules/util.py
@@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
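The util.py change is purely a device-placement cleanup: the frequency tensor is now allocated directly on the timesteps' device instead of being created on the default device and copied over. A minimal sketch of the difference, with a made-up size rather than the real embedding dimension:

import torch

# Illustrative sketch only (not part of the commit): allocating a tensor
# directly on the target device avoids an extra allocation plus copy.
device = "cuda" if torch.cuda.is_available() else "cpu"

freqs_old = torch.arange(0, 8, dtype=torch.float32).to(device=device)   # old pattern
freqs_new = torch.arange(0, 8, dtype=torch.float32, device=device)      # new pattern
assert torch.equal(freqs_old, freqs_new)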
17 changes: 17 additions & 0 deletions comfy/model_patcher.py
@@ -11,6 +11,8 @@ def __init__(self, model, load_device, offload_device, size=0, current_device=No
self.model = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
@@ -91,6 +93,9 @@ def set_model_attn2_output_patch(self, patch):
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

def add_object_patch(self, name, obj):
self.object_patches[name] = obj

def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:
@@ -150,6 +155,12 @@ def model_state_dict(self, filter_prefix=None):
return sd

def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])

model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
@@ -290,3 +301,9 @@ def unpatch_model(self, device_to=None):
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to

keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])

self.object_patches_backup = {}
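Taken together, the new ModelPatcher hooks let callers swap out a whole attribute of the wrapped model and restore it later: add_object_patch registers the replacement, patch_model backs up the original attribute before installing it, and unpatch_model puts the backup back. A self-contained toy sketch of that lifecycle (hypothetical stand-in classes, not the real ModelPatcher):

# Toy illustration of the object-patch lifecycle added above.
class DummyModel:
    def __init__(self):
        self.model_sampling = "original"

class DummyPatcher:
    def __init__(self, model):
        self.model = model
        self.object_patches = {}
        self.object_patches_backup = {}

    def add_object_patch(self, name, obj):
        self.object_patches[name] = obj

    def patch_model(self):
        for k, obj in self.object_patches.items():
            if k not in self.object_patches_backup:
                self.object_patches_backup[k] = getattr(self.model, k)
            setattr(self.model, k, obj)

    def unpatch_model(self):
        for k, obj in self.object_patches_backup.items():
            setattr(self.model, k, obj)
        self.object_patches_backup = {}

model = DummyModel()
patcher = DummyPatcher(model)
patcher.add_object_patch("model_sampling", "replacement")
patcher.patch_model()
assert model.model_sampling == "replacement"
patcher.unpatch_model()
assert model.model_sampling == "original"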
2 changes: 2 additions & 0 deletions comfy/model_sampling.py
@@ -48,7 +48,9 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))

sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)

def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())

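Splitting the buffer registration out into set_sigmas means the sigma schedule can be replaced after construction, which is what the new ModelSamplingDiscrete node further down relies on. A toy sketch of why that works (re-registering a buffer under the same name simply overwrites it):

import torch

# Toy sketch (not ComfyUI code): registering the sigma buffers in a separate
# method lets callers swap in a different schedule after construction.
class ToySampling(torch.nn.Module):
    def __init__(self, sigmas):
        super().__init__()
        self.set_sigmas(sigmas)

    def set_sigmas(self, sigmas):
        self.register_buffer("sigmas", sigmas)
        self.register_buffer("log_sigmas", sigmas.log())

s = ToySampling(torch.linspace(0.1, 10.0, 5))
s.set_sigmas(torch.linspace(0.2, 20.0, 5))  # replace the schedule in place
print(s.sigmas, s.log_sigmas)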
118 changes: 82 additions & 36 deletions comfy/sd1_clip.py
@@ -8,32 +8,54 @@
from . import model_management
import contextlib

def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output

class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens)
to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens)

sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

out, pooled = self.encode(to_encode)
z_empty = out[0:1]
if pooled.shape[0] > 1:
first_pooled = pooled[1:2]
if pooled is not None:
first_pooled = pooled[0:1].cpu()
else:
first_pooled = pooled[0:1]
first_pooled = pooled

output = []
for k in range(1, out.shape[0]):
for k in range(0, sections):
z = out[k:k+1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k - 1][j][1]
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
if has_weights:
z_empty = out[-1]
for i in range(len(z)):
for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z)

if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled

class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
@@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden"
]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.num_layers = 12
if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
self.transformer = model_class.from_pretrained(textmodel_path)
else:
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with comfy.ops.use_comfy_ops(device, dtype):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
self.transformer = model_class(config)

self.inner_name = inner_name
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
inner_model = getattr(self.transformer, self.inner_name)
if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))

self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76]
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False

self.layer_norm_hidden_state = True
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers
@@ -117,7 +145,7 @@ def set_up_textual_embeddings(self, tokens, current_embeds):
else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]]
tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp]

n = token_dict_size
@@ -142,7 +170,7 @@ def forward(self, tokens):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)

if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast
else:
precision_scope = lambda a, b: contextlib.nullcontext(a)
@@ -168,12 +196,16 @@
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)

pooled_output = outputs.pooler_output
if self.text_projection is not None:
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None

if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float()
return z.float(), pooled_output

def encode(self, tokens):
return self(tokens)
@@ -343,17 +375,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out

class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.max_tokens_per_section = self.max_length - 2

empty = self.tokenizer('')["input_ids"]
self.start_token = empty[0]
self.end_token = empty[1]
if has_start_token:
self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length

vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory
@@ -414,11 +453,13 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
else:
continue
#parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])

#reshape token array to CLIP input size
batched_tokens = []
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch
@@ -435,16 +476,21 @@ def tokenize_with_weights(self, text:str, return_word_ids=False):
#add end token and pad
else:
batch.append((self.end_token, 1.0, 0))
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch
batch = [(self.start_token, 1.0, 0)]
batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch)
else:
batch.extend([(t,w,i+1) for t,w in t_group])
t_group = []

#fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))

if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
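The reworked ClipTokenWeightEncoder above only encodes the extra empty-prompt section when a non-default weight is present, then applies each weight by interpolating between the token's embedding and the corresponding empty-prompt embedding. A small sketch of that interpolation rule with made-up tensors (shapes are illustrative, not the real CLIP dimensions):

import torch

# Illustrative only: a weight w moves a token's embedding away from (w > 1)
# or toward (w < 1) the embedding produced for an empty prompt.
z = torch.randn(1, 77, 768)      # embeddings for one prompt section
z_empty = torch.randn(77, 768)   # embeddings for the empty-token section
weights = torch.ones(77)
weights[5] = 1.3                 # e.g. an "(emphasized:1.3)" token

for j in range(z.shape[1]):
    w = weights[j].item()
    if w != 1.0:
        z[0][j] = (z[0][j] - z_empty[j]) * w + z_empty[j]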
3 changes: 1 addition & 2 deletions comfy/sd2_clip.py
@@ -9,8 +9,7 @@ def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, la
layer_idx=23

textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None):
8 changes: 3 additions & 5 deletions comfy/sdxl_clip.py
@@ -9,9 +9,8 @@ def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate"
layer_idx=-2

textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
self.empty_tokens = [[49406] + [49407] + [0] * 75]
self.layer_norm_hidden_state = False
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)

def load_sd(self, sd):
return super().load_sd(sd)
@@ -38,8 +37,7 @@ def untokenize(self, token_weight_pair):
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
self.clip_l.layer_norm_hidden_state = False
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)

def clip_layer(self, layer_idx):
57 changes: 57 additions & 0 deletions comfy_extras/nodes_model_advanced.py
@@ -0,0 +1,57 @@
import folder_paths
import comfy.sd
import comfy.model_sampling


def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)

# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5

class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, sampling, zsnr):
m = model.clone()

if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION

class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
pass

model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )

NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
}
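rescale_zero_terminal_snr_sigmas works in the sqrt(alphas_cumprod) domain, using the relation alphas_cumprod = 1 / (sigma^2 + 1): it shifts the schedule so the final step reaches zero SNR and rescales so the first value is unchanged. A self-contained sanity check of that behaviour, with the same arithmetic inlined and a made-up beta schedule:

import torch

# Illustrative check (beta schedule is made up): after the rescale the first
# sigma is preserved and the last sigma becomes very large (terminal SNR ~ 0).
betas = torch.linspace(0.00085, 0.012, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

a = (1 / (sigmas * sigmas + 1)).sqrt()       # sqrt(alphas_cumprod)
a = (a - a[-1]) * a[0] / (a[0] - a[-1])      # last -> 0, first unchanged
a = a ** 2
a[-1] = 4.8973451890853435e-08               # avoid dividing by zero
rescaled = ((1 - a) / a) ** 0.5

print(float(sigmas[0]), float(rescaled[0]))    # approximately equal
print(float(sigmas[-1]), float(rescaled[-1]))  # rescaled[-1] is very large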
4 changes: 3 additions & 1 deletion comfy_extras/nodes_post_processing.py
@@ -23,7 +23,7 @@ def INPUT_TYPES(s):
"max": 1.0,
"step": 0.01
}),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
},
}

@@ -54,6 +54,8 @@ def blend_mode(self, img1, img2, mode):
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else:
raise ValueError(f"Unsupported blend mode: {mode}")
