diff --git a/src/attention/scaled_dot_attention.rs b/src/attention/scaled_dot_attention.rs
index 9d99796..cfdb245 100644
--- a/src/attention/scaled_dot_attention.rs
+++ b/src/attention/scaled_dot_attention.rs
@@ -7,18 +7,19 @@ pub fn scaled_dot_product_attention(
     q: Array3<f32>, // Query
     k: Array3<f32>, // Key
     v: Array3<f32>, // Value
+    mask: bool,
 ) -> Array3<f32> {
-    let scores = scaled_dot_product(q, k, v.clone(), None);
+    let scores = scaled_dot_product(q, k, v.clone(), mask);
     let sm_scores = softmax_3d(&scores); // TODO implement masking
     tensor_product(&sm_scores, &v)
 }
 
 pub fn scaled_dot_product(
-    q: Array3<f32>,             // Shape: (B, L_Q, d_k)
-    k: Array3<f32>,             // Shape: (B, L_K, d_k)
-    v: Array3<f32>,             // Shape: (B, L_K, d_v)
-    _mask: Option<Array3<f32>>, // Shape: (B, L_Q, L_K)
+    q: Array3<f32>, // Shape: (B, L_Q, d_k)
+    k: Array3<f32>, // Shape: (B, L_K, d_k)
+    v: Array3<f32>, // Shape: (B, L_K, d_v)
+    mask: bool,
 ) -> Array3<f32> {
     let batch_size = q.shape()[0];
     assert_eq!(q.shape()[0], k.shape()[0], "Batch Size mismatch");
 
@@ -30,7 +31,30 @@ pub fn scaled_dot_product(
     // Scale the scores by sqrt(d_k)
     scores /= d_k.sqrt();
 
+    if mask {
+        let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(_b, i, j)| {
+            if i >= j {
+                0.0
+            } else {
+                f32::NEG_INFINITY
+            }
+        });
+        // Ensure the mask has the shape (B, L_Q, L_K)
+        assert_eq!(mask.shape(), &[batch_size, L_Q, L_K]);
+        // Apply the mask to the scores: set masked (future) positions to a large negative value
+        // so that after softmax these positions receive zero attention.
+        for b in 0..batch_size {
+            for i in 0..L_Q {
+                for j in 0..L_K {
+                    if mask[(b, i, j)] == f32::NEG_INFINITY {
+                        // Applying a large negative value to masked positions
+                        scores[(b, i, j)] = f32::NEG_INFINITY;
+                    }
+                }
+            }
+        }
+    }
     scores
 }
 
 pub fn query_key_product(
@@ -60,69 +84,3 @@ pub fn query_key_product(
 
     Ok(scores)
 }
-
-pub fn test_attention_matrices() {
-    // Query matrix Q: Shape (2, 3, 4) -> Batch size 2, sequence length 3, d_k 4
-    let q: Array3<f32> = array![
-        [
-            [1.0, 2.0, 3.0, 4.0],
-            [5.0, 6.0, 7.0, 8.0],
-            [9.0, 10.0, 11.0, 12.0]
-        ],
-        [
-            [13.0, 14.0, 15.0, 16.0],
-            [17.0, 18.0, 19.0, 20.0],
-            [21.0, 22.0, 23.0, 24.0]
-        ]
-    ];
-
-    // Key matrix K: Shape (2, 3, 4) -> Batch size 2, sequence length 3, d_k 4
-    let k: Array3<f32> = array![
-        [
-            [1.0, 0.0, 1.0, 0.0],
-            [0.0, 1.0, 0.0, 1.0],
-            [1.0, 1.0, 1.0, 1.0]
-        ],
-        [
-            [0.5, 0.5, 0.5, 0.5],
-            [1.0, 0.0, 0.0, 1.0],
-            [0.0, 1.0, 1.0, 0.0]
-        ]
-    ];
-
-    // Value matrix V: Shape (2, 3, 5) -> Batch size 2, sequence length 3, d_v 5
-    let v: Array3<f32> = array![
-        [
-            [1.0, 2.0, 3.0, 4.0, 5.0],
-            [6.0, 7.0, 8.0, 9.0, 10.0],
-            [11.0, 12.0, 13.0, 14.0, 15.0]
-        ],
-        [
-            [16.0, 17.0, 18.0, 19.0, 20.0],
-            [21.0, 22.0, 23.0, 24.0, 25.0],
-            [26.0, 27.0, 28.0, 29.0, 30.0]
-        ]
-    ];
-
-    let res = scaled_dot_product(q.clone(), k.clone(), v.clone(), None);
-    println!(
-        "The Query Matrix : \n {:?} \n with shape {:?} \n ",
-        q,
-        q.shape()
-    );
-    println!(
-        "The Key Matrix : \n {:?} \n with shape {:?} \n ",
-        k,
-        k.shape()
-    );
-    println!(
-        "The Value Matrix : \n {:?} \n with shape {:?} \n ",
-        v,
-        v.shape()
-    );
-    println!(
-        "The scaled Query and Key Product for Attention : \n {:?} \n with shape {:?} ",
-        res,
-        res.shape()
-    );
-}
diff --git a/src/main.rs b/src/main.rs
index 8caf122..b776082 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,6 +1,3 @@
-use Transformer::attention::scaled_dot_attention::test_attention_matrices;
-
 fn main() {
     println!("runs successfully!");
-    test_attention_matrices()
 }
diff --git a/tests/scaled_dot_attention_test.rs b/tests/scaled_dot_attention_test.rs
index e884c92..1c14285 100644
--- a/tests/scaled_dot_attention_test.rs
+++ b/tests/scaled_dot_attention_test.rs
@@ -37,7 +37,7 @@ fn test_scaled_dot() {
         [1.3, 1.4, 1.5]
     ]];
     // simple case, Q = V = K
-    let res = scaled_dot_product(a.clone(), a.clone(), a.clone(), None);
+    let res = scaled_dot_product(a.clone(), a.clone(), a.clone(), false);
 
     let expected: Array3<f32> = array![[
         [0.081, 0.185, 0.289, 0.392, 0.081, 0.496],
@@ -78,7 +78,7 @@ fn softmax_scaled_dot_test() {
         [1.3, 1.4, 1.5]
     ]];
     // simple case, Q = V = K
-    let res = softmax_3d(&scaled_dot_product(a.clone(), a.clone(), a.clone(), None));
+    let res = softmax_3d(&scaled_dot_product(a.clone(), a.clone(), a.clone(), false));
     let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
     let expected: Array1<f32> = array![0.145, 0.162, 0.171, 0.189, 0.145, 0.21];
     for i in 0..expected.len() {
@@ -100,7 +100,7 @@ fn full_scaled_dot_attention_test() {
         [1.3, 1.4, 1.5]
     ]];
     // simple case, Q = V = K
-    let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone());
+    let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), false);
     let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
    let output = [0.7836, 0.8836, 0.9836];
     for i in 0..output.len() {
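Usage sketch (not part of the patch above): one way the new mask flag could be exercised from the test suite. The crate path reuses the Transformer::attention::scaled_dot_attention import seen in the old main.rs, the ndarray imports mirror what the existing tests already use, and the test name and input values are made up for illustration; it also assumes softmax_3d maps f32::NEG_INFINITY scores to zero weight. With the causal mask enabled, the first query position can only attend to the first key, so the first output row should reproduce the first row of V.

use Transformer::attention::scaled_dot_attention::scaled_dot_product_attention; // path assumed from the old main.rs import
use ndarray::{array, s, Array1, Array3};

#[test]
fn causal_mask_first_row_test() {
    // Hypothetical test: a small (1, 3, 3) batch with Q = K = V.
    let a: Array3<f32> = array![[
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9]
    ]];
    // With mask = true every key after position 0 is masked for the first query,
    // so row 0 of the attention output must equal row 0 of V.
    let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), true);
    let first_row: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
    for (got, want) in first_row.iter().zip([0.1f32, 0.2, 0.3]) {
        assert!((got - want).abs() < 1e-5);
    }
}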