Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ekoutanov committed Nov 24, 2023
1 parent 9d2347b commit 548b80c
Showing 1 changed file with 75 additions and 15 deletions.
90 changes: 75 additions & 15 deletions src/harville.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,49 @@ pub fn harville(probs: &Matrix<f64>, podium: &[usize]) -> f64 {

pub fn harville_summary(probs: &Matrix<f64>, ranks: usize) -> Matrix<f64> {
let runners = probs.cols();
let mut summary = Matrix::allocate(ranks, runners);
let mut summary = Matrix::allocate(ranks, runners);
let cardinalities = vec![runners; ranks];
let mut podium = vec![0; ranks];
let mut bitmap = vec![false; runners];
harville_summary_no_alloc(probs, ranks, &cardinalities, &mut podium, &mut bitmap, &mut summary);
harville_summary_no_alloc(
probs,
ranks,
&cardinalities,
&mut podium,
&mut bitmap,
&mut summary,
);
summary
}

pub fn harville_summary_no_alloc(probs: &Matrix<f64>, ranks: usize, cardinalities: &[usize], podium: &mut [usize], bitmap: &mut [bool], summary: &mut Matrix<f64>) {
debug_assert_eq!(probs.rows(), ranks, "number of rows in the probabilities matrix must equal to the number of ranks");
pub fn harville_summary_no_alloc(
probs: &Matrix<f64>,
ranks: usize,
cardinalities: &[usize],
podium: &mut [usize],
bitmap: &mut [bool],
summary: &mut Matrix<f64>,
) {
debug_assert_eq!(
probs.rows(),
ranks,
"number of rows in the probabilities matrix must equal to the number of ranks"
);
debug_assert_eq!(summary.rows(), probs.rows(), "number of rows in the probabilities matrix must equal to the number of rows in the summary matrix");
debug_assert_eq!(summary.cols(), probs.cols(), "number of columns in the probabilities matrix must equal to the number of columns in the summary matrix");
debug_assert_eq!(probs.rows(), podium.len(), "number of rows in the probabilities matrix must equal to the podium length");
debug_assert_eq!(probs.cols(), bitmap.len(), "number of columns in the probabilities matrix must equal to the bitmap length");
debug_assert_eq!(
probs.rows(),
podium.len(),
"number of rows in the probabilities matrix must equal to the podium length"
);
debug_assert_eq!(
probs.cols(),
bitmap.len(),
"number of columns in the probabilities matrix must equal to the bitmap length"
);
let combinations = count_combinations(cardinalities);
for combination in 0..combinations {
pick(&cardinalities, combination, podium);
pick(cardinalities, combination, podium);
if !is_unique_linear(podium, bitmap) {
continue;
}
Expand All @@ -48,13 +74,14 @@ pub fn harville_summary_no_alloc(probs: &Matrix<f64>, ranks: usize, cardinalitie

#[cfg(test)]
mod tests {
use assert_float_eq::*;
use assert_float_eq::assert_float_relative_eq;
use assert_float_eq::*;

use crate::capture::Capture;
use crate::comb::{Combinator, is_unique_quadratic};
use crate::comb::{is_unique_quadratic, Combinator};
use crate::mc::DilatedProbs;
use crate::probs::SliceExt;
use crate::testing::assert_slice_f64_relative;

use super::*;

Expand Down Expand Up @@ -219,7 +246,7 @@ mod tests {
let probs = Matrix::from(
DilatedProbs::default()
.with_win_probs(Capture::Borrowed(&WIN_PROBS))
.with_dilatives(Capture::Borrowed(&DILATIVES))
.with_dilatives(Capture::Borrowed(&DILATIVES)),
);
let combinator = Combinator::new(&[RUNNERS; RANKS]);
let probs = combinator
Expand Down Expand Up @@ -249,9 +276,22 @@ mod tests {
.with_podium_places(RANKS),
);
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
println!("summary:\n{}", summary.verbose());
assert_slice_f64_relative(
&[
0.6,
0.3,
0.1,
0.32380952380952444,
0.48333333333333445,
0.19285714285714314,
0.07619047619047627,
0.216666666666667,
0.7071428571428587,
],
summary.flatten(),
1e-9,
);

for row in summary.into_iter() {
assert_float_relative_eq!(1.0, row.sum());
Expand All @@ -272,9 +312,25 @@ mod tests {
.with_podium_places(RANKS),
);
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
println!("summary:\n{}", summary.verbose());
assert_slice_f64_relative(
&[
0.6,
0.3,
0.1,
0.0,
0.32380952380952444,
0.48333333333333445,
0.19285714285714314,
0.0,
0.07619047619047627,
0.216666666666667,
0.7071428571428587,
0.0,
],
summary.flatten(),
1e-9,
);

for row in summary.into_iter() {
assert_float_relative_eq!(1.0, row.sum());
Expand All @@ -301,6 +357,7 @@ mod tests {
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9);
println!("summary:\n{}", summary.verbose());

for row in summary.into_iter() {
Expand All @@ -324,6 +381,7 @@ mod tests {
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9);
println!("summary:\n{}", summary.verbose());

for row in summary.into_iter() {
Expand All @@ -347,6 +405,7 @@ mod tests {
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9);
println!("summary:\n{}", summary.verbose());

for row in summary.into_iter() {
Expand All @@ -366,11 +425,12 @@ mod tests {
let probs = Matrix::from(
DilatedProbs::default()
.with_win_probs(Capture::Borrowed(&WIN_PROBS))
.with_dilatives(Capture::Borrowed(&DILATIVES))
.with_dilatives(Capture::Borrowed(&DILATIVES)),
);
let summary = harville_summary(&probs, RANKS);
assert_eq!(RANKS, summary.rows());
assert_eq!(WIN_PROBS.len(), summary.cols());
assert_slice_f64_relative(&WIN_PROBS, &summary[0], 1e-9);
println!("summary:\n{}", summary.verbose());

for row in summary.into_iter() {
Expand Down

0 comments on commit 548b80c

Please sign in to comment.