Commit 59f9e3b

Stabilizes dot accumulation precision
Accumulates the dot product in float32, then casts the result back to the original dtype, to prevent precision loss with low-precision inputs.
1 parent 93aeb57 commit 59f9e3b

1 file changed: +3 -2 lines changed

kernel_course/python_ops/dot.py

Lines changed: 3 additions & 2 deletions
@@ -19,8 +19,9 @@ def dot(
     x = x.reshape(-1)
     y = y.reshape(-1)

-    z = torch.tensor(0.0, device=x.device, dtype=x.dtype)
+    z = torch.tensor(0.0, device=x.device, dtype=torch.float32)
     for i in range(len(x)):
-        z += x[i] * y[i]
+        z += (x[i] * y[i]).to(torch.float32)
+    z = z.to(x.dtype)

     return z
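
To illustrate why the accumulator dtype matters, here is a minimal sketch that emulates the old and new loops without torch. It is not code from this repository: the helpers `round_to` and `dot_emulated` are hypothetical names, and the rounding is imitated with the standard library's `struct` formats (`'e'` for IEEE half precision, `'f'` for single precision) on the assumption that this mirrors what happens with float16 inputs.

```python
import struct


def round_to(fmt: str, value: float) -> float:
    """Round a Python float to the nearest value representable in `fmt`.

    'e' is IEEE half precision (float16), 'f' is single precision (float32).
    """
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


def dot_emulated(x, y, acc_fmt: str, in_fmt: str = "e") -> float:
    """Hypothetical stand-in for the loop in the diff above.

    Each product is rounded to the input format, the running sum is rounded
    to `acc_fmt` after every addition, and the result is cast back to the
    input format at the end.
    """
    z = 0.0
    for a, b in zip(x, y):
        product = round_to(in_fmt, a * b)   # product formed in the input dtype
        z = round_to(acc_fmt, z + product)  # accumulator keeps acc_fmt precision
    return round_to(in_fmt, z)              # cast back to the input dtype


n = 4096
x = [round_to("e", 0.1)] * n  # 0.1 rounded to float16 is 0.0999755859375
y = [1.0] * n

old = dot_emulated(x, y, acc_fmt="e")  # accumulator in float16, as before the commit
new = dot_emulated(x, y, acc_fmt="f")  # accumulator in float32, as after the commit

print(old)  # 256.0
print(new)  # 409.5
```

Under these assumptions the exact sum is 409.5. The float16 accumulator stalls at 256.0 because float16 values in [256, 512) are spaced 0.25 apart, so adding roughly 0.1 rounds back to the same value. The float32 accumulator represents every partial sum exactly here, and the final cast back to float16 loses nothing.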
