From 3a6cc9520cf9cbce73646c892b5b5eb774a93a5b Mon Sep 17 00:00:00 2001 From: Jakub Date: Wed, 4 Dec 2024 15:47:11 +0100 Subject: [PATCH] implemented dot product for vectors and several unit tests --- src/math/linear_algebra.rs | 8 ++++++- tests/linear_algebra_test.rs | 46 ++++++++++++++++++++++++++++++++++-- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/math/linear_algebra.rs b/src/math/linear_algebra.rs index 5c9f8f0..cac34b4 100644 --- a/src/math/linear_algebra.rs +++ b/src/math/linear_algebra.rs @@ -1,5 +1,5 @@ -use ndarray::Array2; +use ndarray::{Array1, Array2}; use ndarray::linalg::general_mat_mul; @@ -21,3 +21,9 @@ pub fn matmul(a: &Array2, b: &Array2) -> Result, &'static Ok(result) } + + +pub fn dotproduct(a: &Array1, b: &Array1) -> f32 { + a.dot(b) +} + diff --git a/tests/linear_algebra_test.rs b/tests/linear_algebra_test.rs index 669d554..2be6cbb 100644 --- a/tests/linear_algebra_test.rs +++ b/tests/linear_algebra_test.rs @@ -1,5 +1,5 @@ -use Transformer::math::linear_algebra::matmul; -use ndarray::array; // Assuming you're using ndarray for matrices +use Transformer::math::linear_algebra::{dotproduct, matmul}; +use ndarray::{array, Array1}; // Assuming you're using ndarray for matrices #[test] @@ -48,3 +48,45 @@ fn test_matmul_invalid_input() { // Assert: Ensure the result is an error due to incompatible dimensions assert_eq!(result, Err("Matrix dimensions are incompatible for multiplication.")); } + +#[test] +fn test_dotproduct(){ + let a : Array1 = array![ 1.0, 2.0, 3.0 ]; + let b : Array1 = array![ 4.0, 5.0, 5.0 ]; + + let expected = (4 + 10 + 15) as f32; + let result = dotproduct(&a, &b); + + assert_eq!(result, expected); +} +#[test] +fn test_floats_dotproduct(){ + let a : Array1 = array![ 2.9, 7.68, 2.333 ]; + let b : Array1 = array![ 0.74, 1.2, 5.111 ]; + + let expected = (2.9*0.74 + 7.68*1.2 + 2.333*5.111) as f32; + let result = dotproduct(&a, &b); + + assert_eq!(result, expected); +}#[test] +fn test_empty_dotproduct(){ + let a : Array1 = array![ ]; + let b : Array1 = array![ ]; + + let expected = 0.0; + let result = dotproduct(&a, &b); + + assert_eq!(result, expected); +} + +/* +#[test] +fn test_mismatch_dotproduct(){ + let a : Array1 = array![ 2.9, 7.68, 2.333 ,1.0]; + let b : Array1 = array![ 0.74, 1.2, 5.111 ]; + + let expected = (2.9*0.74 + 7.68*1.2 + 2.333*5.111 ) as f32; + let result = dotproduct(&a, &b); + + assert_eq!(result, expected); +}*/ \ No newline at end of file