Merge branch 'master' into beta
jn-jairo committed Oct 26, 2023
2 parents 1f595b5 + 723847f commit ee65c5b
Showing 3 changed files with 24 additions and 19 deletions.
31 changes: 15 additions & 16 deletions comfy/clip_vision.py
@@ -1,12 +1,24 @@
-from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, CLIPImageProcessor, modeling_utils
-from .utils import load_torch_file, transformers_convert
+from transformers import CLIPVisionModelWithProjection, CLIPVisionConfig, modeling_utils
+from .utils import load_torch_file, transformers_convert, common_upscale
 import os
 import torch
 import contextlib
 
 import comfy.ops
 import comfy.model_patcher
 import comfy.model_management
+import comfy.utils
+
+def clip_preprocess(image, size=224):
+    mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
+    std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
+    scale = (size / min(image.shape[1], image.shape[2]))
+    image = torch.nn.functional.interpolate(image.movedim(-1, 1), size=(round(scale * image.shape[1]), round(scale * image.shape[2])), mode="bicubic", antialias=True)
+    h = (image.shape[2] - size)//2
+    w = (image.shape[3] - size)//2
+    image = image[:,:,h:h+size,w:w+size]
+    image = torch.clip((255. * image), 0, 255).round() / 255.0
+    return (image - mean.view([3,1,1])) / std.view([3,1,1])
 
 class ClipVisionModel():
     def __init__(self, json_config):
@@ -23,25 +35,12 @@ def __init__(self, json_config):
         self.model.to(self.dtype)
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
-        self.processor = CLIPImageProcessor(crop_size=224,
-                                            do_center_crop=True,
-                                            do_convert_rgb=True,
-                                            do_normalize=True,
-                                            do_resize=True,
-                                            image_mean=[ 0.48145466,0.4578275,0.40821073],
-                                            image_std=[0.26862954,0.26130258,0.27577711],
-                                            resample=3, #bicubic
-                                            size=224)
 
     def load_sd(self, sd):
         return self.model.load_state_dict(sd, strict=False)
 
     def encode_image(self, image):
-        img = torch.clip((255. * image), 0, 255).round().int()
-        img = list(map(lambda a: a, img))
-        inputs = self.processor(images=img, return_tensors="pt")
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = inputs['pixel_values'].to(self.load_device)
+        pixel_values = clip_preprocess(image.to(self.load_device))
 
         if self.dtype != torch.float32:
             precision_scope = torch.autocast
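
Note on the clip_vision.py change: the Hugging Face CLIPImageProcessor round-trip (GPU tensor to 8-bit CPU batch and back) is replaced by clip_preprocess, which scales the short side to 224 with antialiased bicubic interpolation, center-crops, snaps values to 8-bit levels, and applies CLIP's mean/std normalization directly on the load device. A minimal sketch of how the new function could be exercised; the synthetic input tensor and its shape are illustrative assumptions, not part of the commit:

# Sketch: sanity-check clip_preprocess on a fake image.
# Assumes a ComfyUI checkout importable as `comfy`; the input is synthetic.
import torch
from comfy.clip_vision import clip_preprocess

image = torch.rand(1, 512, 768, 3)     # ComfyUI images are NHWC floats in [0, 1]
pixel_values = clip_preprocess(image)  # short side -> 224, center crop, normalize

print(pixel_values.shape)                       # torch.Size([1, 3, 224, 224])
print(pixel_values.mean(), pixel_values.std())  # roughly zero mean, unit-ish std
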
9 changes: 7 additions & 2 deletions comfy/ldm/modules/attention.py
@@ -222,9 +222,14 @@ def attention_split(q, k, v, heads, mask=None):

     mem_free_total = model_management.get_free_memory(q.device)
 
+    if _ATTN_PRECISION =="fp32":
+        element_size = 4
+    else:
+        element_size = q.element_size()
+
     gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-    modifier = 3 if q.element_size() == 2 else 2.5
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
+    modifier = 3 if element_size == 2 else 2.5
     mem_required = tensor_size * modifier
     steps = 1
 
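
Note on the attention.py change: when _ATTN_PRECISION is "fp32", attention_split upcasts the q @ k^T scores to fp32, so each element of the intermediate tensor occupies 4 bytes even when q is fp16; estimating with q.element_size() (2 bytes) under-reserved memory and could push the chunking logic into an OOM. A back-of-the-envelope sketch of the corrected estimate; the shapes and dtype below are illustrative assumptions:

# Sketch: memory estimate for the attention score tensor, mirroring the fix.
import torch

q = torch.empty(16, 4096, 64, dtype=torch.float16)  # (batch*heads, tokens, dim_head)
k = torch.empty(16, 4096, 64, dtype=torch.float16)
_ATTN_PRECISION = "fp32"  # upcast attention enabled

# Scores are materialized in fp32 (4 bytes/element) despite fp16 inputs.
element_size = 4 if _ATTN_PRECISION == "fp32" else q.element_size()

tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3 if element_size == 2 else 2.5
mem_required = tensor_size * modifier
print(mem_required / 1024**3)  # 16 * 4096 * 4096 * 4 * 2.5 bytes = 2.5 GiB
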
3 changes: 2 additions & 1 deletion comfy/ldm/modules/sub_quadratic_attention.py
@@ -83,7 +83,8 @@ def _summarize_chunk(
     )
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
-    torch.exp(attn_weights - max_score, out=attn_weights)
+    attn_weights -= max_score
+    torch.exp(attn_weights, out=attn_weights)
     exp_weights = attn_weights.to(value.dtype)
     exp_values = torch.bmm(exp_weights, value)
     max_score = max_score.squeeze(-1)
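
Note on the sub_quadratic_attention.py change: torch.exp(attn_weights - max_score, out=attn_weights) still allocates a full chunk-sized temporary for attn_weights - max_score before the out= write, so the out= argument saved nothing; subtracting in place and then exponentiating in place removes that allocation. A small sketch of the equivalence; the tensor shapes are illustrative:

# Sketch: both formulations compute the same values; the split version
# avoids allocating a temporary for (attn_weights - max_score).
import torch

attn_weights = torch.randn(2, 128, 128)
max_score, _ = torch.max(attn_weights, -1, keepdim=True)

expected = torch.exp(attn_weights - max_score)  # old path: extra temporary

attn_weights -= max_score                  # in place, broadcasts over last dim
torch.exp(attn_weights, out=attn_weights)  # in place

assert torch.allclose(attn_weights, expected)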
