From eee0fef6a88ce53557ee70c2f7a9b11d05bc6764 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 5 Dec 2024 15:48:06 +0100 Subject: [PATCH] added unit tests for positional encoding and saw many issues with the current implementation --- src/math/positional_encoding.rs | 7 ++-- tests/positional_encoding_test.rs | 57 +++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 tests/positional_encoding_test.rs diff --git a/src/math/positional_encoding.rs b/src/math/positional_encoding.rs index 3339987..4c9805f 100644 --- a/src/math/positional_encoding.rs +++ b/src/math/positional_encoding.rs @@ -1,7 +1,8 @@ -pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding: usize) -> f32 { - let divisor = 10000f32.powf(2.0 * (index / embedding) as f32); +pub fn sinusoidal_pos_encoding(pos: usize, index: usize, embedding_size: usize) -> f32 { + if pos == 1 {return 0.0}; + let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); // 10000^(2*i / embedding size) - if index % 2 == 0 { + if index % 2 == 0 { // for an even index, use sin, else cos! 
(pos as f32 / divisor).sin() } else { (pos as f32 / divisor).cos() diff --git a/tests/positional_encoding_test.rs b/tests/positional_encoding_test.rs new file mode 100644 index 0000000..d6e681f --- /dev/null +++ b/tests/positional_encoding_test.rs @@ -0,0 +1,57 @@ +use Transformer::math::positional_encoding::sinusoidal_pos_encoding; + + + #[test] + fn test_sin_encoding_even_index() { + let pos = 1; + let index = 2; // even index + let embedding_size = 512; + let result = sinusoidal_pos_encoding(pos, index, embedding_size); + // Expected result: sin(1 / (10000^(2*index/embedding_size))) + let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); + let expected = (pos as f32 / divisor).sin(); + assert!((result - expected).abs() < 1e-6, "Failed for even index"); + } + + #[test] + fn test_cos_encoding_odd_index() { + let pos = 1; + let index = 3; // odd index + let embedding_size = 512; + let result = sinusoidal_pos_encoding(pos, index, embedding_size); + // Expected result: cos(1 / (10000^(2*index/embedding_size))) + let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); + let expected = (pos as f32 / divisor).cos(); + assert!((result - expected).abs() < 1e-6, "Failed for odd index"); + } + + #[test] + fn test_large_position() { + let pos = 1000; + let index = 2; + let embedding_size = 512; + let result = sinusoidal_pos_encoding(pos, index, embedding_size); + let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); + let expected = (pos as f32 / divisor).sin(); + assert!((result - expected).abs() < 1e-6, "Failed for large position"); + } + + #[test] + fn test_zero_position() { + let pos = 0; + let index = 1; + let embedding_size = 512; + let result = sinusoidal_pos_encoding(pos, index, embedding_size); + assert_eq!(result, 0.0, "Failed for zero position"); + } + + #[test] + fn test_large_embedding_size() { + let pos = 5; + let index = 10; + let embedding_size = 2048; + let result = sinusoidal_pos_encoding(pos, 
index, embedding_size); + let divisor = 10000f32.powf(2.0 * (index as f32 / embedding_size as f32)); + let expected = (pos as f32 / divisor).cos(); + assert!((result - expected).abs() < 1e-6, "Failed for large embedding size"); + }