Merge branch 'master' into beta
jn-jairo committed Jan 7, 2024
2 parents d110107 + 0c2c9fb commit e057c9e
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions comfy/ldm/modules/attention.py
@@ -244,6 +244,12 @@ def attention_split(q, k, v, heads, mask=None):
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
 
+                if mask is not None:
+                    if len(mask.shape) == 2:
+                        s1 += mask[i:end]
+                    else:
+                        s1 += mask[:, i:end]
+
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
                 first_op_done = True
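
A minimal sketch (not part of the diff; all sizes are made up) of why the mask slice above differs by rank: a 2-D mask of shape [seq_q, seq_k] is shared across the batch*heads dim and is sliced along its first axis, while a 3-D mask carries batch*heads as dim 0, so the query-chunk slice moves to dim 1.

    import torch

    bh, seq_q, seq_k, chunk = 4, 16, 16, 8    # hypothetical batch*heads, sequence lengths, chunk size
    scores = torch.randn(bh, seq_q, seq_k)    # stand-in for s1 over the full query dim
    mask2d = torch.zeros(seq_q, seq_k)
    mask2d[:, seq_k // 2:] = float('-inf')    # e.g. hide the second half of the keys
    mask3d = mask2d.expand(bh, -1, -1)        # same bias, with an explicit batch*heads dim

    for i in range(0, seq_q, chunk):
        end = i + chunk
        # 2-D mask: slice the query dim directly; broadcasting covers batch*heads.
        out2d = scores[:, i:end] + mask2d[i:end]
        # 3-D mask: dim 0 is batch*heads, so the query slice is dim 1.
        out3d = scores[:, i:end] + mask3d[:, i:end]
        assert torch.equal(out2d, out3d)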
@@ -301,11 +307,14 @@ def attention_xformers(q, k, v, heads, mask=None):
         (q, k, v),
     )
 
-    # actually compute the attention, what we cannot get enough of
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+    if mask is not None:
+        pad = 8 - q.shape[1] % 8
+        mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
+        mask_out[:, :, :mask.shape[-1]] = mask
+        mask = mask_out[:, :, :mask.shape[-1]]
+
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
 
-    if exists(mask):
-        raise NotImplementedError
     out = (
         out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
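
Likewise a hedged sketch of the padding step above (the helper name and sizes are invented, not from the commit): xformers' memory_efficient_attention is picky about the attn_bias memory layout, and the commonly cited requirement is that the bias row stride be a multiple of 8. Allocating the bias with a padded key dim and then slicing back to the real width satisfies that, because the slice is a view that keeps the padded stride.

    import torch

    def pad_bias(mask: torch.Tensor, seq_len: int) -> torch.Tensor:
        # Hypothetical helper mirroring the diff: allocate [B, seq, seq + pad],
        # copy the real bias in, then view back to the original width so the
        # returned tensor keeps the padded (aligned) row stride.
        # seq_len mirrors q.shape[1] in the diff. Note: when seq_len is already
        # a multiple of 8, pad evaluates to 8 rather than 0 -- wasteful but harmless.
        pad = 8 - seq_len % 8
        mask_out = torch.empty([mask.shape[0], seq_len, seq_len + pad],
                               dtype=mask.dtype, device=mask.device)
        mask_out[:, :, :mask.shape[-1]] = mask
        return mask_out[:, :, :mask.shape[-1]]

    bias = pad_bias(torch.zeros(2, 10, 10), seq_len=10)
    assert bias.shape == (2, 10, 10)
    assert bias.stride(-2) == 16    # rows sit 10 + pad(6) elements apart in memory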
