Skip to content

Commit

Permalink
Fixed issues regarding positional_encoding and the unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubSchwenkbeck committed Dec 5, 2024
1 parent eee0fef commit b774014
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 53 deletions.
7 changes: 5 additions & 2 deletions src/math/positional_encoding.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding_size: usize) -> f32 {
if pos == 1 {return 0.0};
if pos == 0 {
return 0.0;
};
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); // 100000^(2*i / embedding size)

if index % 2 == 0 { // for an even index, use sin, else cos!
if index % 2 == 0 {
// for an even index, use sin, else cos!
(pos as f32 / divisor).sin()
} else {
(pos as f32 / divisor).cos()
Expand Down
107 changes: 56 additions & 51 deletions tests/positional_encoding_test.rs
Original file line number Diff line number Diff line change
@@ -1,57 +1,62 @@
use Transformer::math::positional_encoding::sinusoidal_pos_encoding;

#[test]
fn test_sin_encoding_even_index() {
let pos = 1;
let index = 2; // even index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: sin(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!((result - expected).abs() < 1e-6, "Failed for even index");
}

#[test]
fn test_sin_encoding_even_index() {
let pos = 1;
let index = 2; // even index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: sin(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!((result - expected).abs() < 1e-6, "Failed for even index");
}
#[test]
fn test_cos_encoding_odd_index() {
let pos = 1;
let index = 3; // odd index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: cos(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).cos();
assert!((result - expected).abs() < 1e-6, "Failed for odd index");
}

#[test]
fn test_cos_encoding_odd_index() {
let pos = 1;
let index = 3; // odd index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: cos(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).cos();
assert!((result - expected).abs() < 1e-6, "Failed for odd index");
}
#[test]
fn test_large_position() {
let pos = 1000;
let index = 2;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!(
(result - expected).abs() < 1e-6,
"Failed for large position"
);
}

#[test]
fn test_large_position() {
let pos = 1000;
let index = 2;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!((result - expected).abs() < 1e-6, "Failed for large position");
}
#[test]
fn test_zero_position() {
let pos = 0;
let index = 1;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
assert_eq!(result, 0.0, "Failed for zero position");
}

#[test]
fn test_zero_position() {
let pos = 0;
let index = 1;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
assert_eq!(result, 0.0, "Failed for zero position");
}

#[test]
fn test_large_embedding_size() {
let pos = 5;
let index = 10;
let embedding_size = 2048;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).cos();
assert!((result - expected).abs() < 1e-6, "Failed for large embedding size");
}
#[test]
fn test_large_embedding_size() {
let pos = 5;
let index = 10;
let embedding_size = 2048;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!(
(result - expected).abs() < 1e-6,
"Failed for large embedding size"
);
}

0 comments on commit b774014

Please sign in to comment.