Description
Hey Aladdin, thanks for your tutorials!
I've been implementing the Transformer architecture and learning about einsum. Comparing your einsum implementation against one written without einsum, I found differences in the final result. Here is the code for reproducibility:
import torch

b, s, h, d = 2, 2, 2, 2
q = torch.randn((b, s, h, d))
k = torch.randn((b, s, h, d))
v = torch.randn((b, s, h, d))
q_mod = q.permute(0, 2, 1, 3) # [b, h, s, d]
k_mod = k.permute(0, 2, 3, 1) # [b, h, d, s]
classic_scores = torch.matmul(q_mod, k_mod)  # [b, h, s, s]
classic_scores = torch.softmax(classic_scores / (d ** (1/2)), dim=3)
v_mod = v.permute(0, 2, 1, 3)  # [b, h, s, d]
classic_att = torch.matmul(classic_scores, v_mod).reshape(b, s, h * d)  # matmul output is [b, h, s, d]
einstein_scores = torch.einsum("bqhd,bkhd->bhqk", q, k)  # [b, h, s, s]
einstein_scores = torch.softmax(einstein_scores / (d ** (1/2)), dim=3)
einstein_att = torch.einsum("bhql,blhd->bqhd", einstein_scores, v).reshape(b, s, h * d)  # einsum output is [b, s, h, d]
assert torch.all(classic_scores == einstein_scores), "Scores don't match"
assert torch.all(classic_att == einstein_att), "Attention doesn't match"
The attention scores match perfectly, but the final attention output doesn't. With my inputs, here is the result:
>>> print(classic_att)
tensor([[[ 1.1246,  0.1376,  1.2368, -0.6316],
         [-2.1842, -0.0181, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132, -0.1727,  0.8552],
         [ 0.2701,  0.0846,  0.2370,  0.1205]]])
>>> print(einstein_att)
tensor([[[ 1.1246,  0.1376, -2.1842, -0.0181],
         [ 1.2368, -0.6316, -2.2082, -0.0023]],

        [[ 0.5911,  0.2132,  0.2701,  0.0846],
         [-0.1727,  0.8552,  0.2370,  0.1205]]])
It seems that the values aren't wrong, they are just transposed? I'm a newbie with einsum and couldn't figure it out. Hope someone can find the solution for this :)
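To test whether it really is just a transposition, here is a minimal check I have in mind (an untested sketch; it assumes the mismatch comes from reshaping the [b, h, s, d] matmul output straight into (b, s, h*d), and classic_att_fixed is just a name I made up for this experiment):

# permute the classic result back to [b, s, h, d] before flattening the heads
classic_att_fixed = (
    torch.matmul(classic_scores, v_mod)  # [b, h, s, d]
    .permute(0, 2, 1, 3)                 # [b, s, h, d], matching the einsum output order
    .reshape(b, s, h * d)
)
print(torch.allclose(classic_att_fixed, einstein_att))  # should print True if it's only a transposition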