From 079f0607b43c239bea93e92096b805ec06a8f486 Mon Sep 17 00:00:00 2001
From: Carlos
Date: Tue, 21 Nov 2023 12:28:55 +0100
Subject: [PATCH] add instrumentalization

Add an opt-in -instrumentalize flag to SearchCollection. When it is set,
the time from the start of each query task to the end of its search step
is recorded in milliseconds. After all queries finish, the QPS and the
mean, median and P99 latency are logged and written to a file named after
the run file with "_Efficiency" appended.

---
Review notes (below the "---", so not part of the commit message):
- calculateP99 / calculateMedian return 0 / NaN for an empty map instead of
  indexing at -1; calculateMean already evaluates to NaN in that case.
- The efficiency report is written inside try-with-resources.
- Two whitespace-only hunks (old lines 859 and 1438) were dropped.
- Indentation and blank context lines were rebuilt from a
  whitespace-collapsed copy of this patch; use
  "git apply --ignore-whitespace" if a hunk's context does not match.
- NOTE(review): context lines such as "ConcurrentSkipListMap results" carry
  no type arguments in that copy; confirm them against the real file.

 .../io/anserini/search/SearchCollection.java  | 89 ++++++++++++++++---
 1 file changed, 86 insertions(+), 3 deletions(-)

diff --git a/src/main/java/io/anserini/search/SearchCollection.java b/src/main/java/io/anserini/search/SearchCollection.java
index e5ca858803..e71e8c2a72 100644
--- a/src/main/java/io/anserini/search/SearchCollection.java
+++ b/src/main/java/io/anserini/search/SearchCollection.java
@@ -103,6 +103,7 @@
 import java.nio.file.StandardOpenOption;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -122,6 +123,59 @@
  * Main entry point for search.
  */
 public final class SearchCollection implements Closeable {
+  /**
+   * Returns the 99th-percentile latency (nearest-rank method).
+   *
+   * @param latencies per-query latencies keyed by query id; only the values are read
+   * @return the latency at rank {@code ceil(0.99 * n)}, or 0 if {@code latencies} is empty
+   */
+  public static long calculateP99(ConcurrentSkipListMap<?, Long> latencies) {
+    if (latencies.isEmpty()) {
+      return 0L;
+    }
+    List<Long> sorted = new ArrayList<>(latencies.values());
+    Collections.sort(sorted);
+    // ceil(0.99 * n) is a 1-based rank, so subtract 1 to get the list index.
+    int p99Index = (int) Math.ceil(sorted.size() * 0.99) - 1;
+    return sorted.get(p99Index);
+  }
+
+  /**
+   * Returns the median latency.
+   *
+   * @param latencies per-query latencies keyed by query id; only the values are read
+   * @return the middle value, or the mean of the two middle values for an even count;
+   *     NaN if {@code latencies} is empty
+   */
+  public static double calculateMedian(ConcurrentSkipListMap<?, Long> latencies) {
+    if (latencies.isEmpty()) {
+      return Double.NaN;
+    }
+    List<Long> sorted = new ArrayList<>(latencies.values());
+    Collections.sort(sorted);
+    int mid = sorted.size() / 2;
+    if (sorted.size() % 2 == 0) {
+      // Widen to double before adding so the sum of the two middle values cannot overflow.
+      return (sorted.get(mid - 1).doubleValue() + sorted.get(mid)) / 2.0;
+    }
+    return sorted.get(mid);
+  }
+
+  /**
+   * Returns the arithmetic mean latency.
+   *
+   * @param latencies per-query latencies keyed by query id; only the values are read
+   * @return the mean, or NaN if {@code latencies} is empty
+   */
+  public static double calculateMean(ConcurrentSkipListMap<?, Long> latencies) {
+    long sum = 0;
+    for (long value : latencies.values()) {
+      sum += value;
+    }
+    // With no entries this is 0.0 / 0, which evaluates to NaN instead of throwing.
+    return (double) sum / latencies.size();
+  }
+
   // These are the default tie-breaking rules for documents that end up with the same score with respect to a query.
   // For most collections, docids are strings, and we break ties by lexicographic sort order. For tweets, docids are
   // longs, and we break ties by reverse numerical sort order (i.e., most recent tweet first). This means that searching
@@ -231,6 +285,9 @@ public static class Args {
     @Option(name = "-arbitraryScoreTieBreak", usage = "Break score ties arbitrarily (not recommended)")
     public boolean arbitraryScoreTieBreak = false;
 
+    @Option(name = "-instrumentalize", usage = "Add instrumentation to better report latency and QPS")
+    public boolean instrumentalize = false;
+
     @Option(name = "-hits", metaVar = "[number]", required = false, usage = "max number of hits to return")
     public int hits = 1000;
 
@@ -749,6 +806,8 @@ public void run() {
         // Data structure for holding the per-query results, with the qid as the key and the results (the lines that
         // will go into the final run file) as the value.
         ConcurrentSkipListMap results = new ConcurrentSkipListMap<>();
+        // Per-query latency in ms, keyed by qid. NOTE(review): key is Object because qid's type is not visible here.
+        ConcurrentSkipListMap<Object, Long> latencies = new ConcurrentSkipListMap<>();
         AtomicInteger cnt = new AtomicInteger();
 
         // Initialize query encoder if specified
@@ -767,6 +826,8 @@ public void run() {
 
           // This is the per-query execution, in parallel.
           executor.execute(() -> {
+            final long queryStartNanos = System.nanoTime();
+
             // This is for holding the results.
             StringBuilder out = new StringBuilder();
 
@@ -810,6 +871,10 @@ public void run() {
             } catch (IOException e) {
               throw new CompletionException(e);
             }
+            if (args.instrumentalize) {
+              final long elapsedNanos = System.nanoTime() - queryStartNanos;
+              latencies.put(qid, TimeUnit.NANOSECONDS.toMillis(elapsedNanos));
+            }
 
             // For removing duplicate docids.
             Set docids = new HashSet<>();
@@ -881,9 +946,27 @@ public void run() {
         }
 
         final long durationMillis = TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS);
-        LOG.info(desc + ": " + topics.size() + " queries processed in " +
-            DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") +
-            String.format(" = ~%.2f q/s", topics.size() / (durationMillis / 1000.0)));
+        final double qps = topics.size() / (durationMillis / 1000.0);
+        String summary = desc + ": " + topics.size() + " queries processed in " +
+            DurationFormatUtils.formatDuration(durationMillis, "HH:mm:ss") +
+            String.format(" = ~%.2f q/s", qps);
+        if (args.instrumentalize) {
+          final double mean = calculateMean(latencies);
+          final double median = calculateMedian(latencies);
+          final double p99 = calculateP99(latencies);
+          // try-with-resources closes (and thereby flushes) the report writer on every exit path.
+          try (PrintWriter measures = new PrintWriter(
+              Files.newBufferedWriter(Paths.get(outputPath + "_Efficiency"), StandardCharsets.UTF_8))) {
+            measures.println("QPS: " + qps);
+            measures.println("Mean latency: " + mean);
+            measures.println("Median latency: " + median);
+            measures.println("P99 latency: " + p99);
+          }
+          summary += String.format(" with average latency: %.2f ms", mean) +
+              String.format(" with median latency: %.2f ms", median) +
+              String.format(" and p99 latency: %.2f ms", p99);
+        }
+        LOG.info(summary);
 
         // Now we write the results to a run file.
         PrintWriter out = new PrintWriter(Files.newBufferedWriter(Paths.get(outputPath), StandardCharsets.UTF_8));