From 1886f4eac3d13a99a84201854d472d73f1e5775e Mon Sep 17 00:00:00 2001 From: Giorgio Savastano Date: Sun, 16 Jun 2024 17:27:18 +0200 Subject: [PATCH] small improvements --- noxfile.py | 2 +- src/emd_classification.rs | 21 +++++++++++++++------ src/lib.rs | 14 ++++++++------ tests/test_benchmark.py | 12 ++++-------- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/noxfile.py b/noxfile.py index b6e5137..4fa6eed 100644 --- a/noxfile.py +++ b/noxfile.py @@ -3,6 +3,6 @@ @nox.session def tests(session): - session.install("pip", "numpy", "pytest", "scipy") + session.install("pip", "numpy", "pytest", "pytest-benchmark", "scipy") session.run("pip", "install", ".", "-v") session.run("pytest") diff --git a/src/emd_classification.rs b/src/emd_classification.rs index 4607922..ad0d0b5 100644 --- a/src/emd_classification.rs +++ b/src/emd_classification.rs @@ -27,13 +27,19 @@ fn argsort(data: &[T]) -> Vec { /// * `v1` - A 1-dimensional view of f64 data. /// * `v2` - A 1-dimensional view of f64 data. /// +/// # Panics +/// Panics if the input arrays `v1` and `v2` have different lengths. +/// /// # Returns /// The Euclidean distance as a floating-point number. 
fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 { - v1.iter() - .zip(v2.iter()) - .map(|(&x, &y)| (x - y).powi(2)) - .sum::<f64>() + if v1.len() != v2.len() { + panic!("Input arrays must have the same length"); + } + Zip::from(v1) + .and(v2) + .map_collect(|&x, &y| (x - y).powi(2)) + .sum() .sqrt() } @@ -49,6 +55,9 @@ fn euclidean_rdist_row( x: &ArrayView1<'_, f64>, y: &ArrayView2<'_, f64>, ) -> Array1<OrderedFloat<f64>> { + if x.is_empty() || y.is_empty() { + panic!("Input arrays must not be empty"); + } Zip::from(y.rows()).map_collect(|row| OrderedFloat::from(euclidean_distance(&row, &x))) } @@ -104,7 +113,7 @@ pub fn compute_emd_between_2dtensors( ) -> Result<OrderedFloat<f64>, MatrixFormatError> { let costs = euclidean_rdist_rust(x, y); let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?; - let (emd_dist, _assignments) = kuhn_munkres_min(&weights); + let (emd_dist, _) = kuhn_munkres_min(&weights); Ok(emd_dist) } @@ -122,7 +131,7 @@ pub fn compute_emd_between_2dtensors_par( ) -> Result<OrderedFloat<f64>, MatrixFormatError> { let costs = euclidean_rdist_par(x, y); let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?; - let (emd_dist, _assignments) = kuhn_munkres_min(&weights); + let (emd_dist, _) = kuhn_munkres_min(&weights); Ok(emd_dist) } diff --git a/src/lib.rs b/src/lib.rs index 4831c7b..16e39cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -51,9 +51,10 @@ fn compute_emd<'py>(x: PyReadonlyArray2<'py, f64>, y: PyReadonlyArray2<'py, f64> match z { Ok(z) => Ok(*z), - Err(_e) => Err(exceptions::PyTypeError::new_err( - "Failed to compute EMD distance.", - )), + Err(e) => Err(exceptions::PyTypeError::new_err(format!( + "Failed to compute EMD distance: {}", + e + ))), } } @@ -68,9 +69,10 @@ fn compute_emd_parallel<'py>( match z { Ok(z) => Ok(*z), - Err(_e) => Err(exceptions::PyTypeError::new_err( - "Failed to compute EMD distance.", - )), + Err(e) => Err(exceptions::PyTypeError::new_err(format!( + "Failed to compute EMD distance: {}", + e + ))), } } 
diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 20f2aaa..9ef9854 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -8,20 +8,16 @@ def compute_earth_mover_dist(first, second): """Compute earth's mover distance (EMD) between two data tensors using numpy and scipy.""" - emds = [] - for el in second: - d = cdist(first, el) - row_ind, col_ind = linear_sum_assignment(d) - emd = d[row_ind, col_ind].sum() - emds.append(emd) - return emds + d = cdist(first, second) + row_ind, col_ind = linear_sum_assignment(d) + return d[row_ind, col_ind].sum() def setup_data(): """Generates random data for benchmarking.""" rng = np.random.default_rng() data1 = rng.random((50, 50)) - data2 = rng.random((10, 50, 50)) + data2 = rng.random((50, 50)) return data1, data2