Skip to content

Commit

Permalink
implemented dot product for vectors and several unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubSchwenkbeck committed Dec 4, 2024
1 parent 3a1e76f commit 3a6cc95
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/math/linear_algebra.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

use ndarray::Array2;
use ndarray::{Array1, Array2};
use ndarray::linalg::general_mat_mul;


Expand All @@ -21,3 +21,9 @@ pub fn matmul(a: &Array2<f32>, b: &Array2<f32>) -> Result<Array2<f32>, &'static
Ok(result)
}



pub fn dotproduct(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
a.dot(b)
}

46 changes: 44 additions & 2 deletions tests/linear_algebra_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use Transformer::math::linear_algebra::matmul;
use ndarray::array; // Assuming you're using ndarray for matrices
use Transformer::math::linear_algebra::{dotproduct, matmul};
use ndarray::{array, Array1}; // Assuming you're using ndarray for matrices


#[test]
Expand Down Expand Up @@ -48,3 +48,45 @@ fn test_matmul_invalid_input() {
// Assert: Ensure the result is an error due to incompatible dimensions
assert_eq!(result, Err("Matrix dimensions are incompatible for multiplication."));
}

#[test]
fn test_dotproduct(){
let a : Array1<f32> = array![ 1.0, 2.0, 3.0 ];
let b : Array1<f32> = array![ 4.0, 5.0, 5.0 ];

let expected = (4 + 10 + 15) as f32;
let result = dotproduct(&a, &b);

assert_eq!(result, expected);
}
#[test]
fn test_floats_dotproduct(){
let a : Array1<f32> = array![ 2.9, 7.68, 2.333 ];
let b : Array1<f32> = array![ 0.74, 1.2, 5.111 ];

let expected = (2.9*0.74 + 7.68*1.2 + 2.333*5.111) as f32;
let result = dotproduct(&a, &b);

assert_eq!(result, expected);
}#[test]
fn test_empty_dotproduct(){
let a : Array1<f32> = array![ ];
let b : Array1<f32> = array![ ];

let expected = 0.0;
let result = dotproduct(&a, &b);

assert_eq!(result, expected);
}

/*
#[test]
fn test_mismatch_dotproduct(){
let a : Array1<f32> = array![ 2.9, 7.68, 2.333 ,1.0];
let b : Array1<f32> = array![ 0.74, 1.2, 5.111 ];
let expected = (2.9*0.74 + 7.68*1.2 + 2.333*5.111 ) as f32;
let result = dotproduct(&a, &b);
assert_eq!(result, expected);
}*/

0 comments on commit 3a6cc95

Please sign in to comment.