Commit 0c2c9fb

Support attention mask in split attention.
1 parent 3ad0191 commit 0c2c9fb

File tree

1 file changed (+6, -0 lines)

comfy/ldm/modules/attention.py

Lines changed: 6 additions & 0 deletions
@@ -239,6 +239,12 @@ def attention_split(q, k, v, heads, mask=None):
                 else:
                     s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

+                if mask is not None:
+                    if len(mask.shape) == 2:
+                        s1 += mask[i:end]
+                    else:
+                        s1 += mask[:, i:end]
+
                 s2 = s1.softmax(dim=-1).to(v.dtype)
                 del s1
                 first_op_done = True
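
For context, here is a minimal sketch of what the added branch does for each query chunk in split attention: a 2D additive mask is sliced along its first (query) dimension, while a 3D mask is sliced along its second dimension so it lines up with the current slice. The helper name apply_mask_to_slice and the tensor shapes are illustrative assumptions, not part of the commit.

import torch

def apply_mask_to_slice(scores, mask, i, end):
    # scores: (batch*heads, end - i, k_len) attention logits for one query slice.
    # mask (assumed shapes): 2D (q_len, k_len) broadcasts over batch*heads;
    # 3D (batch*heads, q_len, k_len) is sliced per batch/head.
    if mask is not None:
        if len(mask.shape) == 2:
            scores += mask[i:end]
        else:
            scores += mask[:, i:end]
    return scores

# Usage sketch: an additive mask with -inf hides the second half of the keys.
bh, q_len, k_len, slice_size = 2, 8, 8, 4
mask = torch.zeros(q_len, k_len)
mask[:, k_len // 2:] = float("-inf")

for i in range(0, q_len, slice_size):
    end = i + slice_size
    scores = torch.randn(bh, end - i, k_len)
    scores = apply_mask_to_slice(scores, mask, i, end)
    probs = scores.softmax(dim=-1)  # masked key positions get zero weight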
