Skip to content

Commit

Permalink
Refactor attention.application to preload q
Browse files Browse the repository at this point in the history
  • Loading branch information
voltjia committed Jan 2, 2025
1 parent f4d0fba commit d421777
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ def arrange_k_or_v(input):


def application(q, k, v, o):
q_loaded = (q * 1.44269504089).to(ntl.float16)

acc = ntl.zeros((q.shape[-2], q.shape[-1]), dtype=ntl.float32)
l_i = ntl.full((q.shape[-2],), 1, dtype=ntl.float32)
m_i = ntl.full((q.shape[-2],), float("-inf"), dtype=ntl.float32)

for i in range(k.shape[0]):
qk = ntl.dot((q * 1.44269504089).to(ntl.float16), ntl.trans(k[i]))
qk = ntl.dot(q_loaded, ntl.trans(k[i]))

m_ij = ntl.maximum(m_i, ntl.max(qk, 1))
p = ntl.exp2(qk - m_ij[:, None])
Expand Down

0 comments on commit d421777

Please sign in to comment.