tested softmax on scaled dot test
JakubSchwenkbeck committed Dec 7, 2024
1 parent 7dda9ff commit 2fe4ea7
Showing 2 changed files with 41 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/attention/scaled_dot_attention.rs
@@ -1,8 +1,19 @@
 #![allow(unused_variables)]
+use crate::attention::softmax::softmax_3d;
 use crate::math::linear_algebra::matmul;
 use ndarray::{array, Array3, Axis, ShapeError};
 
+pub fn scaled_dot_product_attention(
+    q: Array3<f32>, // Query
+    k: Array3<f32>, // Key
+    v: Array3<f32>, // Value
+) -> Array3<f32> {
+    let scores = scaled_dot_product(q, k, v, None);
+    let sm_scores = softmax_3d(&scores);
+    sm_scores
+}
+
 pub fn scaled_dot_product(
     q: Array3<f32>, // Shape: (B, L_Q, d_k)
     k: Array3<f32>, // Shape: (B, L_K, d_k)
     v: Array3<f32>, // Shape: (B, L_K, d_v)
@@ -92,7 +103,7 @@ pub fn test_attention_matrices() {
         ]
     ];
 
-    let res = scaled_dot_product_attention(q.clone(), k.clone(), v.clone(), None);
+    let res = scaled_dot_product(q.clone(), k.clone(), v.clone(), None);
     println!(
         "The Query Matrix : \n {:?} \n with shape {:?} \n ",
         q,
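For reference, a brief usage sketch of the new scaled_dot_product_attention wrapper (not part of this commit). It assumes the crate is linked under the name Transformer, as in the test file below, and relies on scaled_dot_product returning the scaled QK^T score matrix, so the wrapper yields row-wise softmaxed attention weights:

use ndarray::{array, Array3};
use Transformer::attention::scaled_dot_attention::scaled_dot_product_attention;

fn main() {
    // One batch, four positions, d_k = 3; Q = K = V keeps the example small.
    let x: Array3<f32> = array![[
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
        [0.7, 0.8, 0.9],
        [1.0, 1.1, 1.2]
    ]];

    // Softmaxed score matrix; expected shape (1, 4, 4), i.e. (batch, L_Q, L_K).
    let weights = scaled_dot_product_attention(x.clone(), x.clone(), x.clone());
    println!("attention weights:\n{:?}", weights);
}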
35 changes: 29 additions & 6 deletions tests/scaled_dot_attention_test.rs
@@ -1,7 +1,7 @@
-use ndarray::{array, Array3};
-use Transformer::attention::scaled_dot_attention::{
-    query_key_product, scaled_dot_product_attention,
-};
+use ndarray::{array, s, Array1, Array3};
+use Transformer::attention::scaled_dot_attention::{query_key_product, scaled_dot_product};
+use Transformer::attention::softmax::softmax_3d;
 
 #[test]
 fn test_query_key_product() {
     let q = array![
@@ -25,7 +25,7 @@ fn test_query_key_product() {

 // test values by https://medium.com/@saraswatp/understanding-scaled-dot-product-attention-in-transformer-models-5fe02b0f150c
 #[test]
-fn test_scaled_dot_attention() {
+fn test_scaled_dot() {
     let a: Array3<f32> = array![[
         [0.1, 0.2, 0.3],
         [0.4, 0.5, 0.6],
@@ -35,7 +35,7 @@ fn test_scaled_dot_attention() {
         [1.3, 1.4, 1.5]
     ]];
     // simple case, Q = V = K
-    let res = scaled_dot_product_attention(a.clone(), a.clone(), a.clone(), None);
+    let res = scaled_dot_product(a.clone(), a.clone(), a.clone(), None);
 
     let expected: Array3<f32> = array![[
         [0.081, 0.185, 0.289, 0.392, 0.081, 0.496],
@@ -63,3 +63,26 @@ fn test_scaled_dot_attention() {

     assert_eq!(res.shape()[2], expected.shape()[2]);
 }
+
+// test values by https://medium.com/@saraswatp/understanding-scaled-dot-product-attention-in-transformer-models-5fe02b0f150c
+#[test]
+fn softmax_scaled_dot_test() {
+    let a: Array3<f32> = array![[
+        [0.1, 0.2, 0.3],
+        [0.4, 0.5, 0.6],
+        [0.7, 0.8, 0.9],
+        [1.0, 1.1, 1.2],
+        [0.1, 0.2, 0.3],
+        [1.3, 1.4, 1.5]
+    ]];
+    // simple case, Q = V = K
+    let res = softmax_3d(&scaled_dot_product(a.clone(), a.clone(), a.clone(), None));
+    let result: Array1<f32> = res.slice(s![0, 0, ..]).to_owned();
+    let expected: Array1<f32> = array![0.145, 0.162, 0.171, 0.189, 0.145, 0.21];
+    for i in 0..expected.len() {
+        assert!(
+            (result[i] - expected[i]) < 0.001,
+            "Softmax scaled dot is too far off!"
+        );
+    }
+}
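For reference, a self-contained sketch of the row-wise softmax arithmetic the new test exercises, using the first row of scaled scores from test_scaled_dot above. The softmax helper here is illustrative plain Rust, not the crate's softmax_3d, and it assumes softmax_3d normalizes along the last axis, as the row-wise comparison in the test suggests:

// Illustrative softmax over one row, with the usual max-subtraction for numerical stability.
fn softmax(row: &[f32]) -> Vec<f32> {
    let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = row.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    // First row of the scaled QK^T scores from test_scaled_dot above.
    let scores = [0.081_f32, 0.185, 0.289, 0.392, 0.081, 0.496];
    let probs = softmax(&scores);

    // Every softmax row is a probability distribution: positive entries summing to 1.
    assert!(probs.iter().all(|p| *p > 0.0));
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
    println!("softmax of first score row: {:?}", probs);
}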
