Commit

Merge branch 'master' into beta
jn-jairo committed Dec 15, 2023
2 parents 12cac0f + a5056cf · commit b34d978
Showing 2 changed files with 2 additions and 8 deletions.
4 changes: 1 addition & 3 deletions comfy/ldm/modules/attention.py
@@ -104,9 +104,7 @@ def attention_basic(q, k, v, heads, mask=None):
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale

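The change in both files drops the disabled-autocast wrapper and instead casts q and k to fp32 directly at the einsum call, so the similarity logits are still computed in fp32 when `_ATTN_PRECISION == "fp32"`. A minimal standalone sketch of the new behaviour (the `attention_scores_fp32` helper and the tensor shapes are illustrative, not part of the repository):

```python
import torch
from torch import einsum

def attention_scores_fp32(q, k, scale):
    # Cast q and k to fp32 right at the matmul so the similarity logits
    # cannot overflow fp16; no disabled-autocast context is needed.
    return einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale

# Hypothetical shapes: batch*heads = 2, 64 query/key tokens, head dim 40.
q = torch.randn(2, 64, 40, dtype=torch.float16)
k = torch.randn(2, 64, 40, dtype=torch.float16)
sim = attention_scores_fp32(q, k, scale=40 ** -0.5)
print(sim.dtype)  # torch.float32
```

Numerically this matches the removed block; the autocast wrapper only mattered if an enclosing autocast context could re-cast the operands back to half precision, which this commit evidently treats as not the case here.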
6 changes: 1 addition & 5 deletions comfy_extras/nodes_sag.py
@@ -27,9 +27,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
     # force cast to fp32 to avoid overflowing
     if _ATTN_PRECISION =="fp32":
-        with torch.autocast(enabled=False, device_type = 'cuda'):
-            q, k = q.float(), k.float()
-            sim = einsum('b i d, b j d -> b i j', q, k) * scale
+        sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
     else:
         sim = einsum('b i d, b j d -> b i j', q, k) * scale

@@ -111,7 +109,6 @@ def patch(self, model, scale, blur_sigma):
         m = model.clone()

         attn_scores = None
-        mid_block_shape = None

         # TODO: make this work properly with chunked batches
         # currently, we can only save the attn from one UNet call
@@ -134,7 +131,6 @@ def attn_and_record(q, k, v, extra_options):

         def post_cfg_function(args):
             nonlocal attn_scores
-            nonlocal mid_block_shape
             uncond_attn = attn_scores

             sag_scale = scale
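For context on the nodes_sag.py cleanup: `patch` installs closures that write captured attention data back into variables of the enclosing function via `nonlocal`. `mid_block_shape` was declared and listed as `nonlocal` but never read anywhere, so it was dead state and the commit removes it. A minimal, self-contained sketch of the same capture pattern (the function names and driver below are illustrative, not ComfyUI APIs):

```python
def make_sag_style_patch():
    # State shared between the two callbacks, captured by closure.
    attn_scores = None

    def attn_and_record(scores):
        # Called from inside the model: stash the attention scores.
        nonlocal attn_scores
        attn_scores = scores

    def post_cfg_function(args):
        # Called after CFG: read back whatever attn_and_record saved.
        nonlocal attn_scores
        uncond_attn = attn_scores
        return f"post-CFG step saw attn={uncond_attn!r}, args={args!r}"

    return attn_and_record, post_cfg_function

# Illustrative driver standing in for a UNet call followed by CFG.
record, post_cfg = make_sag_style_patch()
record([[0.1, 0.9]])             # model side saves the scores
print(post_cfg({"scale": 7.5}))  # post-CFG side reads them back
```

A variable wired up this way but never read on the consumer side contributes nothing, which is exactly why the two `mid_block_shape` lines could be deleted without changing behaviour.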
