diff --git a/src/attention/scaled_dot_attention.rs b/src/attention/scaled_dot_attention.rs
index 28f0765..aad6d4c 100644
--- a/src/attention/scaled_dot_attention.rs
+++ b/src/attention/scaled_dot_attention.rs
@@ -32,7 +32,7 @@ pub fn scaled_dot_product(
     scores /= d_k.sqrt();
     if mask {
         let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(b, i, j)| {
-            if i >= j {
+            if i > j {
                 0.0
             } else {
                 f32::NEG_INFINITY
diff --git a/src/main.rs b/src/main.rs
index 5a2f14e..e8cf130 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -14,7 +14,7 @@ fn main() {
         [1.3, 1.4, 1.5]
     ]];
 
-    let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), false);
+    let scores = scaled_dot_product(a.clone(), a.clone(), a.clone(), true);
     let sm_scores = softmax_3d(&scores);
     // Words corresponding to the input
     let words = ["the", "cat", "sat", "on", "the", "mat"];
diff --git a/tests/scaled_dot_attention_test.rs b/tests/scaled_dot_attention_test.rs
index 1c14285..a5f027b 100644
--- a/tests/scaled_dot_attention_test.rs
+++ b/tests/scaled_dot_attention_test.rs
@@ -103,7 +103,7 @@ fn full_scaled_dot_attention_test() {
     let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), false);
     let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
     let output = [0.7836, 0.8836, 0.9836];
-    for i in 0..output.len() {
+    for i in 0..result.len() {
         assert!(
             (result[i] - output[i]) < 0.001,
             "Softmax scaled dot is too far off!"
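For reference, here is a minimal standalone sketch (not part of the patch) of the additive mask the updated closure builds, assuming the ndarray crate used throughout the repository. Entries kept at 0.0 leave the attention scores unchanged, while f32::NEG_INFINITY entries are driven to zero weight by the subsequent softmax; note that with the strict i > j comparison, the diagonal is masked along with the upper triangle.

use ndarray::{s, Array3};

fn main() {
    let (batch_size, l_q, l_k) = (1, 4, 4);
    // Mirror of the closure in scaled_dot_product: a query position i keeps
    // only key positions j with i > j; everything else, including the
    // diagonal, is set to -inf so the softmax assigns it zero weight.
    let mask = Array3::from_shape_fn((batch_size, l_q, l_k), |(_b, i, j)| {
        if i > j { 0.0_f32 } else { f32::NEG_INFINITY }
    });
    // The masked scores would then be `scores + &mask` before the softmax.
    println!("{}", mask.slice(s![0, .., ..]));
}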