Skip to content

Commit

Permalink
small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Giorgio Savastano authored and Giorgio Savastano committed Jun 16, 2024
1 parent f256581 commit 1886f4e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

@nox.session
def tests(session):
session.install("pip", "numpy", "pytest", "scipy")
session.install("pip", "numpy", "pytest", "pytest-benchmark" "scipy")
session.run("pip", "install", ".", "-v")
session.run("pytest")
21 changes: 15 additions & 6 deletions src/emd_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
/// * `v1` - A 1-dimensional view of f64 data.
/// * `v2` - A 1-dimensional view of f64 data.
///
/// # Panics
/// Panics if the input arrays `v1` and `v2` have different lengths.
///
/// # Returns
/// The Euclidean distance as a floating-point number.
fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 {
v1.iter()
.zip(v2.iter())
.map(|(&x, &y)| (x - y).powi(2))
.sum::<f64>()
if v1.len() != v2.len() {
panic!("Input arrays must have the same length");
}
Zip::from(v1)
.and(v2)
.map_collect(|&x, &y| (x - y).powi(2))
.sum()
.sqrt()
}

Expand All @@ -49,6 +55,9 @@ fn euclidean_rdist_row(
x: &ArrayView1<'_, f64>,
y: &ArrayView2<'_, f64>,
) -> Array1<OrderedFloat<f64>> {
if x.is_empty() || y.is_empty() {
panic!("Input arrays must not be empty");
}
Zip::from(y.rows()).map_collect(|row| OrderedFloat::from(euclidean_distance(&row, &x)))
}

Expand Down Expand Up @@ -104,7 +113,7 @@ pub fn compute_emd_between_2dtensors(
) -> Result<OrderedFloat<f64>, MatrixFormatError> {
let costs = euclidean_rdist_rust(x, y);
let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?;
let (emd_dist, _assignments) = kuhn_munkres_min(&weights);
let (emd_dist, _) = kuhn_munkres_min(&weights);
Ok(emd_dist)
}

Expand All @@ -122,7 +131,7 @@ pub fn compute_emd_between_2dtensors_par(
) -> Result<OrderedFloat<f64>, MatrixFormatError> {
let costs = euclidean_rdist_par(x, y);
let weights = Matrix::from_vec(costs.nrows(), costs.ncols(), costs.into_raw_vec())?;
let (emd_dist, _assignments) = kuhn_munkres_min(&weights);
let (emd_dist, _) = kuhn_munkres_min(&weights);
Ok(emd_dist)
}

Expand Down
14 changes: 8 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ fn compute_emd<'py>(x: PyReadonlyArray2<'py, f64>, y: PyReadonlyArray2<'py, f64>

match z {
Ok(z) => Ok(*z),
Err(_e) => Err(exceptions::PyTypeError::new_err(
"Failed to compute EMD distance.",
)),
Err(e) => Err(exceptions::PyTypeError::new_err(format!(
"Failed to compute EMD distance: {}",
e
))),
}
}

Expand All @@ -68,9 +69,10 @@ fn compute_emd_parallel<'py>(

match z {
Ok(z) => Ok(*z),
Err(_e) => Err(exceptions::PyTypeError::new_err(
"Failed to compute EMD distance.",
)),
Err(e) => Err(exceptions::PyTypeError::new_err(format!(
"Failed to compute EMD distance: {}",
e
))),
}
}

Expand Down
12 changes: 4 additions & 8 deletions tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@

def compute_earth_mover_dist(first, second):
"""Compute earth's mover distance (EMD) between two data tensors using numpy and scipy."""
emds = []
for el in second:
d = cdist(first, el)
row_ind, col_ind = linear_sum_assignment(d)
emd = d[row_ind, col_ind].sum()
emds.append(emd)
return emds
d = cdist(first, second)
row_ind, col_ind = linear_sum_assignment(d)
return d[row_ind, col_ind].sum()


def setup_data():
"""Generates random data for benchmarking."""
rng = np.random.default_rng()
data1 = rng.random((50, 50))
data2 = rng.random((10, 50, 50))
data2 = rng.random((50, 50))
return data1, data2


Expand Down

0 comments on commit 1886f4e

Please sign in to comment.