Skip to content

Commit

Permalink
added docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Giorgio Savastano authored and Giorgio Savastano committed Jun 15, 2024
1 parent 4603e5f commit 9065c9b
Showing 1 changed file with 82 additions and 21 deletions.
103 changes: 82 additions & 21 deletions src/emd_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@ use ndarray::Zip;
use ordered_float::OrderedFloat;
use pathfinding::prelude::{kuhn_munkres_min, Matrix, MatrixFormatError};

/// Represents an unusable or error state in distance calculations, initialized to infinity.
const BAD_VALUE: f64 = f64::INFINITY;

/// Sorts the indices of the elements in the provided slice in ascending order based on the elements themselves.
///
/// # Arguments
/// * `data` - A slice of data implementing the `Ord` trait.
///
/// # Returns
/// A vector of indices that, if used to index into `data`, will produce a sorted array.
fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
let mut indices = (0..data.len()).collect::<Vec<_>>();
unsafe {
Expand All @@ -13,6 +21,14 @@ fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
indices
}

/// Calculates the Euclidean distance between two 1-dimensional arrays.
///
/// # Arguments
/// * `v1` - A 1-dimensional view of f64 data.
/// * `v2` - A 1-dimensional view of f64 data.
///
/// # Returns
/// The Euclidean distance as a floating-point number.
fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 {
v1.iter()
.zip(v2.iter())
Expand All @@ -21,27 +37,26 @@ fn euclidean_distance(v1: &ArrayView1<f64>, v2: &ArrayView1<f64>) -> f64 {
.sqrt()
}

/// Compute Euclidean distance between two 2-D data tensors (e.g., images).
///
// pub fn euclidean_rdist_rust(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
// let mut c = Array2::<f64>::zeros((x.nrows(), y.nrows()));
// for i in 0..x.nrows() {
// for j in 0..y.nrows() {
// unsafe {
// *c.uget_mut([i, j]) = euclidean_distance(&x.row(i), &y.row(j));
// }
// }
// }
// c
// }

/// Computes the Euclidean distance between a single row and each row of a 2-dimensional array.
///
/// # Arguments
/// * `x` - A 1-dimensional view of a data row.
/// * `y` - A 2-dimensional array.
///
/// # Returns
/// A 1-dimensional array containing distances from `x` to each row in `y`.
fn euclidean_rdist_row(x: &ArrayView1<'_, f64>, y: &ArrayView2<'_, f64>) -> Array1<f64> {
let z = Zip::from(y.rows()).map_collect(|row| euclidean_distance(&row, &x));
z
Zip::from(y.rows()).map_collect(|row| euclidean_distance(&row, &x))
}

/// Computation of Euclidean distance between two 2-D data tensors (e.g., images)
/// Computes the Euclidean distances between rows of two 2-dimensional data arrays synchronously.
///
/// # Arguments
/// * `x` - A 2-dimensional view of data arrays.
/// * `y` - A 2-dimensional view of data arrays.
///
/// # Returns
/// A 2-dimensional array where each element `(i, j)` is the distance between row `i` of `x` and row `j` of `y`.
pub fn euclidean_rdist_rust(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
let mut c = Array2::<f64>::zeros((x.nrows(), y.nrows()));
Zip::from(x.rows())
Expand All @@ -50,8 +65,14 @@ pub fn euclidean_rdist_rust(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> A
c
}

/// Parallel computation of Euclidean distance between two 2-D data tensors (e.g., images)
/// Similar to `euclidean_rdist_rust` but performs the computation in parallel.
///
/// # Arguments
/// * `x` - A 2-dimensional view of data arrays.
/// * `y` - A 2-dimensional view of data arrays.
///
/// # Returns
/// A 2-dimensional array where each element `(i, j)` is the distance between row `i` of `x` and row `j` of `y`.
pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Array2<f64> {
let mut c = Array2::<f64>::zeros((x.nrows(), y.nrows()));
Zip::from(x.rows())
Expand All @@ -60,8 +81,14 @@ pub fn euclidean_rdist_par(x: ArrayView2<'_, f64>, y: ArrayView2<'_, f64>) -> Ar
c
}

/// Compute Earth Movers Distance (EMD) between two 2-D data tensors (e.g., images).
/// Computes the Earth Movers Distance (EMD) between two 2-dimensional data tensors.
///
/// # Arguments
/// * `x` - A 2-dimensional view of f64 data tensors.
/// * `y` - A 2-dimensional view of f64 data tensors.
///
/// # Returns
/// A result containing the EMD as `OrderedFloat<f64>` or an error of type `MatrixFormatError`.
pub fn compute_emd_between_2dtensors(
x: ArrayView2<'_, f64>,
y: ArrayView2<'_, f64>,
Expand All @@ -73,6 +100,14 @@ pub fn compute_emd_between_2dtensors(
Ok(emd_dist)
}

/// Computes the EMD between one 2D tensor and multiple 2D tensors contained in a 3D array, returning results for each computation.
///
/// # Arguments
/// * `x` - A 2-dimensional array view.
/// * `y` - A 3-dimensional array view.
///
/// # Returns
/// A 1-dimensional array where each element is the EMD from `x` to each 2D tensor in `y`.
pub fn compute_emd_bulk(
x: ArrayView2<'_, f64>,
y: ArrayView3<'_, f64>,
Expand All @@ -82,13 +117,21 @@ pub fn compute_emd_bulk(
.and(y.axis_iter(Axis(0)))
.for_each(|c, mat_y| {
*c = compute_emd_between_2dtensors(mat_y, x).unwrap_or_else(|err| {
println!("BAD_VALUE due to: {}", err);
eprintln!("BAD_VALUE due to: {}", err);
return OrderedFloat::from(BAD_VALUE);
})
});
c
}

/// Similar to `compute_emd_bulk` but performs the computation in parallel.
///
/// # Arguments
/// * `x` - A 2-dimensional array view.
/// * `y` - A 3-dimensional array view.
///
/// # Returns
/// A 1-dimensional array where each element is the EMD from `x` to each 2D tensor in `y`.
pub fn compute_emd_bulk_par(
x: ArrayView2<'_, f64>,
y: ArrayView3<'_, f64>,
Expand All @@ -98,13 +141,22 @@ pub fn compute_emd_bulk_par(
.and(y.axis_iter(Axis(0)))
.par_for_each(|c, mat_y| {
*c = compute_emd_between_2dtensors(mat_y, x).unwrap_or_else(|err| {
println!("BAD_VALUE due to: {}", err);
eprintln!("BAD_VALUE due to: {}", err);
return OrderedFloat::from(BAD_VALUE);
})
});
c
}

/// Identifies and returns the indices of the `n` closest tensors from `y` to `x` based on EMD.
///
/// # Arguments
/// * `x` - A 2-dimensional array view of f64.
/// * `y` - A 3-dimensional array view of f64 tensors.
/// * `n` - The number of closest tensors to identify.
///
/// # Returns
/// An array of indices corresponding to the closest tensors.
pub fn classify_closest_n(
x: ArrayView2<'_, f64>,
y: ArrayView3<'_, f64>,
Expand All @@ -116,6 +168,15 @@ pub fn classify_closest_n(
unsafe { Array::from_vec(res.get_unchecked(0..n).to_vec()) }
}

/// Applies `classify_closest_n` for each tensor in `x` against all tensors in `y`, in parallel.
///
/// # Arguments
/// * `x` - A 3-dimensional array view of f64 tensors.
/// * `y` - A 3-dimensional array view of f64 tensors.
/// * `n` - The number of closest tensors to identify for each tensor in `x`.
///
/// # Returns
/// A 2-dimensional array where each row contains indices of the `n` closest tensors for each tensor in `x`.
pub fn classify_closest_n_bulk(
x: ArrayView3<'_, f64>,
y: ArrayView3<'_, f64>,
Expand Down

0 comments on commit 9065c9b

Please sign in to comment.