Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Giorgio Savastano authored and Giorgio Savastano committed Jun 19, 2024
1 parent ddb1be9 commit 36e1067
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
28 changes: 13 additions & 15 deletions src/emd_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,10 @@ fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
/// # Returns
/// The Euclidean distance as a floating-point number.
///
/// # Panics
/// Panics if `v1` and `v2` do not have the same length.
fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 {
    // `zip` stops at the shorter input, so without this guard a length
    // mismatch would produce a silently wrong distance instead of an error.
    assert_eq!(
        v1.len(),
        v2.len(),
        "Input arrays must have the same length"
    );
    v1.iter()
        .zip(v2.iter())
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f64>()
        .sqrt()
}

Expand All @@ -55,10 +52,10 @@ fn euclidean_rdist_row(
x: &ArrayView1<'_, f64>,
y: &ArrayView2<'_, f64>,
) -> Array1<OrderedFloat<f64>> {
if x.is_empty() || y.is_empty() {
panic!("Input arrays must not be empty");
}
Zip::from(y.rows()).map_collect(|row| OrderedFloat::from(euclidean_distance(&row, &x)))
y.rows()
.into_iter()
.map(|row| OrderedFloat::from(euclidean_distance(&row, &x)))
.collect()
}

/// Computes the Euclidean distances between rows of two 2-dimensional data arrays sequentially.
Expand All @@ -73,7 +70,8 @@ pub fn euclidean_rdist_rust(
x: ArrayView2<'_, f64>,
y: ArrayView2<'_, f64>,
) -> Array2<OrderedFloat<f64>> {
let mut c = Array2::<OrderedFloat<f64>>::zeros((x.nrows(), y.nrows()));
let mut c = Array2::from_elem((x.nrows(), y.nrows()), OrderedFloat::from(0.0));

Zip::from(x.rows())
.and(c.rows_mut())
.for_each(|row_x, mut row_c| row_c.assign(&euclidean_rdist_row(&row_x, &y)));
Expand All @@ -88,11 +86,11 @@ pub fn euclidean_rdist_rust(
///
/// # Returns
/// A 2-dimensional array where each element `(i, j)` is the distance between row `i` of `x` and row `j` of `y`.
pub fn euclidean_rdist_par(
pub fn euclidean_rdist_rust_par(
x: ArrayView2<'_, f64>,
y: ArrayView2<'_, f64>,
) -> Array2<OrderedFloat<f64>> {
let mut c = Array2::<OrderedFloat<f64>>::zeros((x.nrows(), y.nrows()));
let mut c = Array2::from_elem((x.nrows(), y.nrows()), OrderedFloat::from(0.0));
Zip::from(x.rows())
.and(c.rows_mut())
.par_for_each(|row_x, mut row_c| row_c.assign(&euclidean_rdist_row(&row_x, &y)));
Expand Down Expand Up @@ -129,7 +127,7 @@ pub fn compute_emd_between_2dtensors_par(
x: ArrayView2<'_, f64>,
y: ArrayView2<'_, f64>,
) -> Result<OrderedFloat<f64>, MatrixFormatError> {
let costs = euclidean_rdist_par(x, y);
let costs = euclidean_rdist_rust_par(x, y);
let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?;
let (emd_dist, _) = kuhn_munkres_min(&weights);
Ok(emd_dist)
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn euclidean_rdist_parallel<'py>(
) -> Bound<'py, PyArray2<f64>> {
let x = x.as_array();
let y = y.as_array();
let z = emd_classification::euclidean_rdist_par(x, y);
let z = emd_classification::euclidean_rdist_rust_par(x, y);
let res = z.mapv(|elem| elem.into_inner());
res.into_pyarray_bound(py)
}
Expand Down
6 changes: 6 additions & 0 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ def setup_data():


def test_rust_emd_benchmark(benchmark):
    """Time the Rust-backed EMD computation on the data from ``setup_data()``.

    The trailing ``False`` is forwarded unchanged to
    ``compute_earth_movers_distance_2d``; presumably it turns the parallel
    code path off -- confirm against that function's signature.
    """
    first, second = setup_data()
    benchmark(compute_earth_movers_distance_2d, first, second, False)


def test_rust_emd_par_benchmark(benchmark):
    """Benchmark the Rust-backed EMD calculation in its ``par`` variant.

    Runs ``compute_earth_movers_distance_2d`` on the data returned by
    ``setup_data()`` with the last argument set to ``True``.
    NOTE(review): that flag presumably selects the parallel implementation,
    as the test name suggests -- confirm against the function's signature.
    """
    data1, data2 = setup_data()
    benchmark(compute_earth_movers_distance_2d, data1, data2, True)
Expand Down

0 comments on commit 36e1067

Please sign in to comment.