Skip to content

Commit

Permalink
added unit tetsts for positional encoding and saw many issues with th…
Browse files Browse the repository at this point in the history
…e current implementation
  • Loading branch information
JakubSchwenkbeck committed Dec 5, 2024
1 parent a209353 commit eee0fef
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/math/positional_encoding.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding: usize) -> f32 {
let divisor = 10000f32.powf(2.0 * (index / embedding) as f32);
pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding_size: usize) -> f32 {
if pos == 1 {return 0.0};
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); // 100000^(2*i / embedding size)

if index % 2 == 0 {
if index % 2 == 0 { // for an even index, use sin, else cos!
(pos as f32 / divisor).sin()
} else {
(pos as f32 / divisor).cos()
Expand Down
57 changes: 57 additions & 0 deletions tests/positional_encoding_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use Transformer::math::positional_encoding::sinusoidal_pos_encoding;


#[test]
fn test_sin_encoding_even_index() {
let pos = 1;
let index = 2; // even index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: sin(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!((result - expected).abs() < 1e-6, "Failed for even index");
}

#[test]
fn test_cos_encoding_odd_index() {
let pos = 1;
let index = 3; // odd index
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
// Expected result: cos(1 / (10000^(2*index/embedding_size)))
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).cos();
assert!((result - expected).abs() < 1e-6, "Failed for odd index");
}

#[test]
fn test_large_position() {
let pos = 1000;
let index = 2;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).sin();
assert!((result - expected).abs() < 1e-6, "Failed for large position");
}

#[test]
fn test_zero_position() {
let pos = 0;
let index = 1;
let embedding_size = 512;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
assert_eq!(result, 0.0, "Failed for zero position");
}

#[test]
fn test_large_embedding_size() {
let pos = 5;
let index = 10;
let embedding_size = 2048;
let result = sinusoidal_pos_encoding(pos, index, embedding_size);
let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32));
let expected = (pos as f32 / divisor).cos();
assert!((result - expected).abs() < 1e-6, "Failed for large embedding size");
}

0 comments on commit eee0fef

Please sign in to comment.