diff --git a/src/harville.rs b/src/harville.rs index 5ff459f..da68206 100644 --- a/src/harville.rs +++ b/src/harville.rs @@ -19,23 +19,49 @@ pub fn harville(probs: &Matrix, podium: &[usize]) -> f64 { pub fn harville_summary(probs: &Matrix, ranks: usize) -> Matrix { let runners = probs.cols(); - let mut summary = Matrix::allocate(ranks, runners); + let mut summary = Matrix::allocate(ranks, runners); let cardinalities = vec![runners; ranks]; let mut podium = vec![0; ranks]; let mut bitmap = vec![false; runners]; - harville_summary_no_alloc(probs, ranks, &cardinalities, &mut podium, &mut bitmap, &mut summary); + harville_summary_no_alloc( + probs, + ranks, + &cardinalities, + &mut podium, + &mut bitmap, + &mut summary, + ); summary } -pub fn harville_summary_no_alloc(probs: &Matrix, ranks: usize, cardinalities: &[usize], podium: &mut [usize], bitmap: &mut [bool], summary: &mut Matrix) { - debug_assert_eq!(probs.rows(), ranks, "number of rows in the probabilities matrix must equal to the number of ranks"); +pub fn harville_summary_no_alloc( + probs: &Matrix, + ranks: usize, + cardinalities: &[usize], + podium: &mut [usize], + bitmap: &mut [bool], + summary: &mut Matrix, +) { + debug_assert_eq!( + probs.rows(), + ranks, + "number of rows in the probabilities matrix must equal to the number of ranks" + ); debug_assert_eq!(summary.rows(), probs.rows(), "number of rows in the probabilities matrix must equal to the number of rows in the summary matrix"); debug_assert_eq!(summary.cols(), probs.cols(), "number of columns in the probabilities matrix must equal to the number of columns in the summary matrix"); - debug_assert_eq!(probs.rows(), podium.len(), "number of rows in the probabilities matrix must equal to the podium length"); - debug_assert_eq!(probs.cols(), bitmap.len(), "number of columns in the probabilities matrix must equal to the bitmap length"); + debug_assert_eq!( + probs.rows(), + podium.len(), + "number of rows in the probabilities matrix must equal to the podium length" + ); + debug_assert_eq!( + probs.cols(), + bitmap.len(), + "number of columns in the probabilities matrix must equal to the bitmap length" + ); let combinations = count_combinations(cardinalities); for combination in 0..combinations { - pick(&cardinalities, combination, podium); + pick(cardinalities, combination, podium); if !is_unique_linear(podium, bitmap) { continue; } @@ -48,13 +74,14 @@ pub fn harville_summary_no_alloc(probs: &Matrix, ranks: usize, cardinalitie #[cfg(test)] mod tests { - use assert_float_eq::*; use assert_float_eq::assert_float_relative_eq; + use assert_float_eq::*; use crate::capture::Capture; - use crate::comb::{Combinator, is_unique_quadratic}; + use crate::comb::{is_unique_quadratic, Combinator}; use crate::mc::DilatedProbs; use crate::probs::SliceExt; + use crate::testing::assert_slice_f64_relative; use super::*; @@ -219,7 +246,7 @@ mod tests { let probs = Matrix::from( DilatedProbs::default() .with_win_probs(Capture::Borrowed(&WIN_PROBS)) - .with_dilatives(Capture::Borrowed(&DILATIVES)) + .with_dilatives(Capture::Borrowed(&DILATIVES)), ); let combinator = Combinator::new(&[RUNNERS; RANKS]); let probs = combinator @@ -249,9 +276,22 @@ mod tests { .with_podium_places(RANKS), ); let summary = harville_summary(&probs, RANKS); - assert_eq!(RANKS, summary.rows()); - assert_eq!(WIN_PROBS.len(), summary.cols()); println!("summary:\n{}", summary.verbose()); + assert_slice_f64_relative( + &[ + 0.6, + 0.3, + 0.1, + 0.32380952380952444, + 0.48333333333333445, + 0.19285714285714314, + 0.07619047619047627, + 0.216666666666667, + 0.7071428571428587, + ], + summary.flatten(), + 1e-9, + ); for row in summary.into_iter() { assert_float_relative_eq!(1.0, row.sum()); @@ -272,9 +312,25 @@ mod tests { .with_podium_places(RANKS), ); let summary = harville_summary(&probs, RANKS); - assert_eq!(RANKS, summary.rows()); - assert_eq!(WIN_PROBS.len(), summary.cols()); println!("summary:\n{}", summary.verbose()); + assert_slice_f64_relative( + &[ + 0.6, + 0.3, + 0.1, + 0.0, + 0.32380952380952444, + 0.48333333333333445, + 0.19285714285714314, + 0.0, + 0.07619047619047627, + 0.216666666666667, + 0.7071428571428587, + 0.0, + ], + summary.flatten(), + 1e-9, + ); for row in summary.into_iter() { assert_float_relative_eq!(1.0, row.sum()); @@ -301,6 +357,7 @@ mod tests { let summary = harville_summary(&probs, RANKS); assert_eq!(RANKS, summary.rows()); assert_eq!(WIN_PROBS.len(), summary.cols()); + assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9); println!("summary:\n{}", summary.verbose()); for row in summary.into_iter() { @@ -324,6 +381,7 @@ mod tests { let summary = harville_summary(&probs, RANKS); assert_eq!(RANKS, summary.rows()); assert_eq!(WIN_PROBS.len(), summary.cols()); + assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9); println!("summary:\n{}", summary.verbose()); for row in summary.into_iter() { @@ -347,6 +405,7 @@ mod tests { let summary = harville_summary(&probs, RANKS); assert_eq!(RANKS, summary.rows()); assert_eq!(WIN_PROBS.len(), summary.cols()); + assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9); println!("summary:\n{}", summary.verbose()); for row in summary.into_iter() { @@ -366,11 +425,12 @@ mod tests { let probs = Matrix::from( DilatedProbs::default() .with_win_probs(Capture::Borrowed(&WIN_PROBS)) - .with_dilatives(Capture::Borrowed(&DILATIVES)) + .with_dilatives(Capture::Borrowed(&DILATIVES)), ); let summary = harville_summary(&probs, RANKS); assert_eq!(RANKS, summary.rows()); assert_eq!(WIN_PROBS.len(), summary.cols()); + assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9); println!("summary:\n{}", summary.verbose()); for row in summary.into_iter() {