debugged masking and added a print for the attention-weight matrix with example keywords
JakubSchwenkbeck committed Dec 8, 2024
1 parent 403a21f commit 9d0033b
Showing 3 changed files with 3 additions and 3 deletions.
src/attention/scaled_dot_attention.rs (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ pub fn scaled_dot_product(
     scores /= d_k.sqrt();
     if mask {
         let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(b, i, j)| {
-            if i >= j {
+            if i > j {
                 0.0
             } else {
                 f32::NEG_INFINITY
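
For reference, the closure builds an additive causal mask: 0.0 lets a score through, f32::NEG_INFINITY zeroes it out after the softmax. A minimal standalone sketch (assuming the ndarray crate, which the diff already uses) of the pattern the new strict condition i > j produces for one 3-token sequence:

use ndarray::{s, Array3};

fn main() {
    let (batch_size, l_q, l_k) = (1, 3, 3);
    // Same construction as in the diff: 0.0 where query i may attend to
    // key j, -inf where it may not. The strict `i > j` admits only
    // strictly earlier positions, so the diagonal is masked as well.
    let mask = Array3::from_shape_fn((batch_size, l_q, l_k), |(_b, i, j)| {
        if i > j { 0.0_f32 } else { f32::NEG_INFINITY }
    });
    println!("{}", mask.slice(s![0, .., ..]));
    // [[-inf, -inf, -inf],
    //  [   0, -inf, -inf],
    //  [   0,    0, -inf]]
}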
src/main.rs (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@ fn main() {
         [1.3, 1.4, 1.5]
     ]];
 
-    let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), false);
+    let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), true);
     let sm_scores = softmax_3d(&scores);
     // Words corresponding to the input
     let words = ["the", "cat", "sat", "on", "the", "mat"];
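
The commit message also mentions printing the attention-weight matrix with example keywords. A hypothetical helper in that spirit, labeling rows (queries) and columns (keys) of the softmaxed matrix with `words`; the function name and formatting here are illustrative assumptions, not the repository's code:

use ndarray::Array3;

// Prints the batch-0 attention weights as a table whose rows and
// columns carry the example words.
fn print_attention(words: &[&str], sm_scores: &Array3<f32>) {
    print!("{:>8}", "");
    for w in words {
        print!("{:>8}", w);
    }
    println!();
    for (i, w) in words.iter().enumerate() {
        print!("{:>8}", w);
        for j in 0..words.len() {
            print!("{:>8.3}", sm_scores[[0, i, j]]);
        }
        println!();
    }
}

It would be called as print_attention(&words, &sm_scores) after the softmax above.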
tests/scaled_dot_attention_test.rs (2 changes: 1 addition & 1 deletion)
@@ -103,7 +103,7 @@ fn full_scaled_dot_attention_test() {
     let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), false);
     let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
     let output = [0.7836, 0.8836, 0.9836];
-    for i in 0..output.len() {
+    for i in 0..result.len() {
         assert!(
             (result[i] - output[i]) < 0.001,
             "Softmax scaled dot is too far off!"
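
One caveat about the assertion shown above: (result[i] - output[i]) < 0.001 also passes when the difference is large and negative. A closeness check of this shape would usually compare the absolute error, for example:

// Sketch of an absolute-error variant of the same check.
for i in 0..result.len() {
    assert!(
        (result[i] - output[i]).abs() < 1e-3,
        "Softmax scaled dot is too far off at index {}: got {}, expected {}",
        i, result[i], output[i]
    );
}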
