diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index e085186ef68..9e2e03d7238 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,5 +1,5 @@ -from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils -from .utils import load_torch_file, transformers_convert +from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils +from .utils import load_torch_file, transformers_convert, common_upscale import os import torch import contextlib @@ -7,6 +7,18 @@ import comfy.ops import comfy.model_patcher import comfy.model_management +import comfy.utils + +def clip_preprocess(image, size=224): + mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype) + std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype) + scale = (size / min(image.shape[1], image.shape[2])) + image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) class ClipVisionModel(): def __init__(self, json_config): @@ -23,25 +35,12 @@ def __init__(self, json_config): self.model.to(self.dtype) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) - self.processor = CLIPImageProcessor(crop_size=224, - do_center_crop=True, - do_convert_rgb=True, - do_normalize=True, - do_resize=True, - image_mean=[ 0.48145466,0.4578275,0.40821073], - image_std=[0.26862954,0.26130258,0.27577711], - resample=3, #bicubic - size=224) - def load_sd(self, sd): return self.model.load_state_dict(sd, strict=False) def encode_image(self, image): - img = torch.clip((255. * image), 0, 255).round().int() - img = list(map(lambda a: a, img)) - inputs = self.processor(images=img, return_tensors="pt") comfy.model_management.load_model_gpu(self.patcher) - pixel_values = inputs['pixel_values'].to(self.load_device) + pixel_values = clip_preprocess(image.to(self.load_device)) if self.dtype != torch.float32: precision_scope = torch.autocast diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index da7d8de5ef0..f111e7364bf 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -222,9 +222,14 @@ def attention_split(q, k, v, heads, mask=None): mem_free_total = model_management.get_free_memory(q.device) + if _ATTN_PRECISION =="fp32": + element_size = 4 + else: + element_size = q.element_size() + gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size + modifier = 3 if element_size == 2 else 2.5 mem_required = tensor_size * modifier steps = 1 diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 4d42059b5a8..8e8e8054dfd 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -83,7 +83,8 @@ def _summarize_chunk( ) max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() - torch.exp(attn_weights - max_score, out=attn_weights) + attn_weights -= max_score + torch.exp(attn_weights, out=attn_weights) exp_weights = attn_weights.to(value.dtype) exp_values = torch.bmm(exp_weights, value) max_score = max_score.squeeze(-1)