From a1b8270e70139c85703968f2dc3de0d27faf1c4a Mon Sep 17 00:00:00 2001
From: Shunsuke Kanda
Date: Sat, 26 Oct 2024 12:05:24 +0900
Subject: [PATCH] Add basic stats for cli (#48)

* add
* add
* add
* other metric test
---
 elinor-cli/README.md              | 25 ++++++++++++++-
 elinor-cli/src/bin/compare.rs     | 53 +++++++++++++++++++++++++++----
 elinor-cli/src/bin/evaluate.rs    | 11 +++++++
 scripts/compare_with_trec_eval.py | 10 ++++++
 src/relevance.rs                  | 14 ++++++++
 5 files changed, 105 insertions(+), 8 deletions(-)

diff --git a/elinor-cli/README.md b/elinor-cli/README.md
index d77e068..837381a 100644
--- a/elinor-cli/README.md
+++ b/elinor-cli/README.md
@@ -60,9 +60,14 @@ cargo run --release -p elinor-cli --bin elinor-evaluate -- \

 The available metrics are shown in [Metric](https://docs.rs/elinor/latest/elinor/metrics/enum.Metric.html).

-The output will show the macro-averaged scores for each metric:
+The output will show the basic statistics and the macro-averaged scores for each metric:

 ```
+n_queries_in_true 8
+n_queries_in_pred 8
+n_docs_in_true 20
+n_docs_in_pred 24
+n_true_relevant_docs 14
 precision@3 0.5833
 ap 0.8229
 rr 0.8125
@@ -135,6 +140,15 @@ cargo run --release -p elinor-cli --bin elinor-compare -- \
 The output will be:
 ```
+# Basic statistics
++-----------+-------+
+| Key       | Value |
++-----------+-------+
+| n_systems | 2     |
+| n_topics  | 8     |
+| n_metrics | 4     |
++-----------+-------+
+
 # Alias
 +----------+-----------------------------+
 | Alias    | Path                        |
 +----------+-----------------------------+
@@ -198,6 +212,15 @@ cargo run --release -p elinor-cli --bin elinor-compare -- \
 The output will be:
 ```
+# Basic statistics
++-----------+-------+
+| Key       | Value |
++-----------+-------+
+| n_systems | 3     |
+| n_topics  | 8     |
+| n_metrics | 4     |
++-----------+-------+
+
 # Alias
 +----------+-----------------------------+
 | Alias    | Path                        |
 +----------+-----------------------------+
diff --git a/elinor-cli/src/bin/compare.rs b/elinor-cli/src/bin/compare.rs
index af9c634..6e470fb 100644
--- a/elinor-cli/src/bin/compare.rs
+++ b/elinor-cli/src/bin/compare.rs
@@ -43,6 +43,14 @@ struct Args {
     /// Print mode for the output (pretty or raw).
     #[arg(short, long, default_value = "pretty")]
     print_mode: PrintMode,
+
+    /// Number of resamples for the bootstrap test.
+    #[arg(long, default_value = "10000")]
+    n_resamples: usize,
+
+    /// Number of iterations for the randomized test.
+    #[arg(long, default_value = "10000")]
+    n_iters: usize,
 }

 fn main() -> Result<()> {
@@ -85,9 +93,33 @@ fn main() -> Result<()> {
     }
     let topic_header = topic_headers[0].as_str();

+    println!("# Basic statistics");
+    {
+        let columns = vec![
+            Series::new(
+                "Key".into(),
+                vec![
+                    "n_systems".to_string(),
+                    "n_topics".to_string(),
+                    "n_metrics".to_string(),
+                ],
+            ),
+            Series::new(
+                "Value".into(),
+                vec![
+                    dfs.len() as u64,
+                    dfs[0].get_columns()[0].len() as u64,
+                    dfs[0].get_columns().len() as u64 - 1,
+                ],
+            ),
+        ];
+        let df = DataFrame::new(columns)?;
+        print_dataframe(&df, args.print_mode);
+    }
+
     // If there is only one input CSV file, just print the means.
     if args.input_csvs.len() == 1 {
-        println!("# Means");
+        println!("\n# Means");
         {
             let metrics = extract_metrics(&dfs[0]);
             let values = get_means(&dfs[0], &metrics, topic_header);
@@ -101,7 +133,7 @@ fn main() -> Result<()> {
         return Ok(());
     }

-    println!("# Alias");
+    println!("\n# Alias");
     {
         let columns = vec![
             Series::new(
@@ -123,10 +155,17 @@ fn main() -> Result<()> {
     }

     if dfs.len() == 2 {
-        compare_two_systems(&dfs[0], &dfs[1], topic_header, args.print_mode)?;
+        compare_two_systems(
+            &dfs[0],
+            &dfs[1],
+            topic_header,
+            args.print_mode,
+            args.n_resamples,
+            args.n_iters,
+        )?;
     }
     if dfs.len() > 2 {
-        compare_multiple_systems(&dfs, topic_header, args.print_mode)?;
+        compare_multiple_systems(&dfs, topic_header, args.print_mode, args.n_iters)?;
     }

     Ok(())
@@ -180,6 +219,8 @@ fn compare_two_systems(
     df_2: &DataFrame,
     topic_header: &str,
     print_mode: PrintMode,
+    n_resamples: usize,
+    n_iters: usize,
 ) -> Result<()> {
     let metrics = extract_common_metrics([df_1, df_2]);
     if metrics.is_empty() {
@@ -278,7 +319,6 @@ fn compare_two_systems(
         print_dataframe(&df, print_mode);
     }

-    let n_resamples = 10000;
     println!("\n# Two-sided paired Bootstrap test (n_resamples = {n_resamples})");
     {
         let mut stats = vec![];
@@ -306,7 +346,6 @@ fn compare_two_systems(
         print_dataframe(&df, print_mode);
     }

-    let n_iters = 10000;
     println!("\n# Fisher's randomized test (n_iters = {n_iters})");
     {
         let mut stats = vec![];
@@ -344,6 +383,7 @@ fn compare_multiple_systems(
     dfs: &[DataFrame],
     topic_header: &str,
     print_mode: PrintMode,
+    n_iters: usize,
 ) -> Result<()> {
     let metrics = extract_common_metrics(dfs);
     if metrics.is_empty() {
@@ -382,7 +422,6 @@ fn compare_multiple_systems(
         df_metrics.push(joined);
     }

-    let n_iters = 10000;
     let rthsd_tester = RandomizedTukeyHsdTester::new(dfs.len()).with_n_iters(n_iters);

     for (metric, df_metric) in metrics.iter().zip(df_metrics.iter()) {
diff --git a/elinor-cli/src/bin/evaluate.rs b/elinor-cli/src/bin/evaluate.rs
index a98b3b1..0260dff 100644
--- a/elinor-cli/src/bin/evaluate.rs
+++ b/elinor-cli/src/bin/evaluate.rs
@@ -54,6 +54,12 @@ fn main() -> Result<()> {
         args.metrics
     };

+    println!("n_queries_in_true\t{}", true_rels.n_queries());
+    println!("n_queries_in_pred\t{}", pred_rels.n_queries());
+    println!("n_docs_in_true\t{}", true_rels.n_docs());
+    println!("n_docs_in_pred\t{}", pred_rels.n_docs());
+    println!("n_true_relevant_docs\t{}", n_relevant_docs(&true_rels));
+
     let mut columns = vec![];
     for metric in metrics {
         let result = elinor::evaluate(&true_rels, &pred_rels, metric)?;
@@ -79,6 +85,11 @@ fn main() -> Result<()> {
     Ok(())
 }

+fn n_relevant_docs(true_rels: &TrueRelStore) -> usize {
+    let records = true_rels.records();
+    records.into_iter().filter(|r| r.score > 0).count()
+}
+
 fn default_metrics() -> Vec<Metric> {
     let mut metrics = Vec::new();
     for k in [1, 5, 10] {
diff --git a/scripts/compare_with_trec_eval.py b/scripts/compare_with_trec_eval.py
index a299c1a..e0a99c1 100755
--- a/scripts/compare_with_trec_eval.py
+++ b/scripts/compare_with_trec_eval.py
@@ -74,6 +74,16 @@ def compare_decimal_places(a: str, b: str, decimal_places: int) -> bool:
         [metric for _, metric in metric_pairs],
     )

+    # Add some additional basic metrics
+    metric_pairs.extend(
+        [
+            ("num_q", "n_queries_in_true"),
+            ("num_q", "n_queries_in_pred"),
+            ("num_ret", "n_docs_in_pred"),
+            ("num_rel", "n_true_relevant_docs"),
+        ]
+    )
+
     failed_rows: list[str] = []
     for trec_metric, elinor_metric in metric_pairs:
         trec_score = trec_results["trec_eval_output"][trec_metric]
diff --git a/src/relevance.rs b/src/relevance.rs
index d80c6cf..d7ddb1a 100644
--- a/src/relevance.rs
+++ b/src/relevance.rs
@@ -84,6 +84,20 @@ where
             .collect()
     }

+    /// Returns the relevance store as records.
+    pub fn records(&self) -> Vec<Record<Q, T>> {
+        self.map
+            .iter()
+            .flat_map(|(query_id, data)| {
+                data.sorted.iter().map(move |rel| Record {
+                    query_id: query_id.clone(),
+                    doc_id: rel.doc_id.clone(),
+                    score: rel.score.clone(),
+                })
+            })
+            .collect()
+    }
+
     /// Returns the score for a given query-document pair.
     pub fn get_score(&self, query_id: &Q, doc_id: &Q) -> Option<&T>
     where