extended scaled_dot_attention.rs to support masking with negative infinity before softmax
JakubSchwenkbeck committed Dec 7, 2024
1 parent 38fd239 commit e3c8aff
Showing 3 changed files with 32 additions and 77 deletions.
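
The idea behind the change, sketched below as a small standalone example (this softmax helper is illustrative and is not the crate's softmax_3d): a score set to negative infinity before the softmax contributes exp(-inf) = 0, so the masked position ends up with exactly zero attention weight.

// Standalone sketch (not part of this commit): why -inf masking yields zero weight.
fn softmax(scores: &[f32]) -> Vec<f32> {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // The third position is masked with negative infinity before the softmax...
    let weights = softmax(&[0.3, 0.7, f32::NEG_INFINITY]);
    // ...so it receives exactly zero attention weight afterwards.
    assert_eq!(weights[2], 0.0);
    println!("{:?}", weights); // roughly [0.401, 0.599, 0.0]
}
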
100 changes: 29 additions & 71 deletions src/attention/scaled_dot_attention.rs
@@ -7,18 +7,19 @@ pub fn scaled_dot_product_attention(
q: Array3<f32>, // Query
k: Array3<f32>, // Key
v: Array3<f32>, // Value
mask: bool,
) -> Array3<f32> {
let scores = scaled_dot_product(q, k, v.clone(), None);
let scores = scaled_dot_product(q, k, v.clone(), mask);
let sm_scores = softmax_3d(&scores);
// TODO implement masking
tensor_product(&sm_scores, &v)
}

pub fn scaled_dot_product(
q: Array3<f32>, // Shape: (B, L_Q, d_k)
k: Array3<f32>, // Shape: (B, L_K, d_k)
v: Array3<f32>, // Shape: (B, L_K, d_v)
_mask: Option<Array3<f32>>, // Shape: (B, L_Q, L_K)
q: Array3<f32>, // Shape: (B, L_Q, d_k)
k: Array3<f32>, // Shape: (B, L_K, d_k)
v: Array3<f32>, // Shape: (B, L_K, d_v)
mask: bool,
) -> Array3<f32> {
let batch_size = q.shape()[0];
assert_eq!(q.shape()[0], k.shape()[0], "Batch Size mismatch");
@@ -30,7 +31,30 @@ pub fn scaled_dot_product(

// Scale the scores by sqrt(d_k)
scores /= d_k.sqrt();
if mask {
    // Build a causal mask of shape (B, L_Q, L_K): position i may attend to
    // positions j <= i (0.0); future positions j > i get negative infinity.
    let mask = Array3::from_shape_fn((batch_size, L_Q, L_K), |(_b, i, j)| {
        if i >= j {
            0.0
        } else {
            f32::NEG_INFINITY
        }
    });
    // Ensure the mask has the shape (B, L_Q, L_K)
    assert_eq!(mask.shape(), &[batch_size, L_Q, L_K]);

    // Apply the mask to the scores: set masked positions to negative infinity
    // so that after softmax these positions receive zero attention.
    for b in 0..batch_size {
        for i in 0..L_Q {
            for j in 0..L_K {
                if mask[(b, i, j)] == f32::NEG_INFINITY {
                    scores[(b, i, j)] = f32::NEG_INFINITY;
                }
            }
        }
    }
}
scores
}
pub fn query_key_product(
@@ -60,69 +84,3 @@ pub fn query_key_product(

Ok(scores)
}

pub fn test_attention_matrices() {
// Query matrix Q: Shape (2, 3, 4) -> Batch size 2, sequence length 3, d_k 4
let q: Array3<f32> = array![
[
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0]
],
[
[13.0, 14.0, 15.0, 16.0],
[17.0, 18.0, 19.0, 20.0],
[21.0, 22.0, 23.0, 24.0]
]
];

// Key matrix K: Shape (2, 3, 4) -> Batch size 2, sequence length 3, d_k 4
let k: Array3<f32> = array![
[
[1.0, 0.0, 1.0, 0.0],
[0.0, 1.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0]
],
[
[0.5, 0.5, 0.5, 0.5],
[1.0, 0.0, 0.0, 1.0],
[0.0, 1.0, 1.0, 0.0]
]
];

// Value matrix V: Shape (2, 3, 5) -> Batch size 2, sequence length 3, d_v 5
let v: Array3<f32> = array![
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[6.0, 7.0, 8.0, 9.0, 10.0],
[11.0, 12.0, 13.0, 14.0, 15.0]
],
[
[16.0, 17.0, 18.0, 19.0, 20.0],
[21.0, 22.0, 23.0, 24.0, 25.0],
[26.0, 27.0, 28.0, 29.0, 30.0]
]
];

let res = scaled_dot_product(q.clone(), k.clone(), v.clone(), None);
println!(
"The Query Matrix : \n {:?} \n with shape {:?} \n ",
q,
q.shape()
);
println!(
"The Key Matrix : \n {:?} \n with shape {:?} \n ",
k,
k.shape()
);
println!(
"The Value Matrix : \n {:?} \n with shape {:?} \n ",
v,
v.shape()
);
println!(
"The scaled Query and Key Product for Attention : \n {:?} \n with shape {:?} ",
res,
res.shape()
);
}
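
For orientation, a minimal usage sketch of the updated entry point (the import path mirrors the one removed from main.rs below; the input values and this example are illustrative, not part of the commit):

// Hypothetical call site for the new boolean `mask` parameter (not part of this commit).
use ndarray::{array, Array3};
use Transformer::attention::scaled_dot_attention::scaled_dot_product_attention;

fn main() {
    // One batch, sequence length 3, d_k = d_v = 3; Q = K = V for simplicity.
    let a: Array3<f32> = array![[
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9]
    ]];
    // With mask = true, position i only attends to positions j <= i, because
    // positions j > i are set to negative infinity before the softmax.
    let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), true);
    println!("{:?} with shape {:?}", res, res.shape());
}
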
3 changes: 0 additions & 3 deletions src/main.rs
@@ -1,6 +1,3 @@
use Transformer::attention::scaled_dot_attention::test_attention_matrices;

fn main() {
println!("runs successfully!");
test_attention_matrices()
}
6 changes: 3 additions & 3 deletions tests/scaled_dot_attention_test.rs
@@ -37,7 +37,7 @@ fn test_scaled_dot() {
[1.3, 1.4, 1.5]
]];
// simple case, Q = V = K
let res = scaled_dot_product(a.clone(), a.clone(), a.clone(), None);
let res = scaled_dot_product(a.clone(), a.clone(), a.clone(), false);

let expected: Array3<f32> = array![[
[0.081, 0.185, 0.289, 0.392, 0.081, 0.496],
@@ -78,7 +78,7 @@ fn softmax_scaled_dot_test() {
[1.3, 1.4, 1.5]
]];
// simple case, Q = V = K
let res = softmax_3d(&scaled_dot_product(a.clone(), a.clone(), a.clone(), None));
let res = softmax_3d(&scaled_dot_product(a.clone(), a.clone(), a.clone(), false));
let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
let expected: Array1<f32> = array![0.145, 0.162, 0.171, 0.189, 0.145, 0.21];
for i in 0..expected.len() {
@@ -100,7 +100,7 @@ fn full_scaled_dot_attention_test() {
[1.3, 1.4, 1.5]
]];
// simple case, Q = V = K
let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone());
let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), false);
let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
let output = [0.7836, 0.8836, 0.9836];
for i in 0..output.len() {
