From a24646c708a9f12a004730d7ae43fab74821de20 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Thu, 30 Jan 2025 11:36:18 -0800 Subject: [PATCH] Unify precomputation of aggregations behind a common API (#16733) (#17197) We've had a series of aggregation speedups that use the same strategy: instead of iterating through documents that match the query one-by-one, we can look at a Lucene segment and compute the aggregation directly (if some particular conditions are met). In every case, we've hooked that into custom logic hijacks the getLeafCollector method and throws CollectionTerminatedException. This creates the illusion that we're implementing a custom LeafCollector, when really we're not collecting at all (which is the whole point). With this refactoring, the mechanism (hijacking getLeafCollector) is moved into AggregatorBase. Aggregators that have a strategy to precompute their answer can override tryPrecomputeAggregationForLeaf, which is expected to return true if they managed to precompute. This should also make it easier to keep track of which aggregations have precomputation approaches (since they override this method). --------- (cherry picked from commit 2847695ad1fcc04c88b96c8bab0bfdf694fa05dc) Signed-off-by: Michael Froh Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../search/aggregations/AggregatorBase.java | 23 +++++- .../bucket/composite/CompositeAggregator.java | 9 ++- .../histogram/DateHistogramAggregator.java | 22 +++--- .../bucket/range/RangeAggregator.java | 12 ++-- .../GlobalOrdinalsStringTermsAggregator.java | 72 +++++++++---------- .../aggregations/metrics/AvgAggregator.java | 28 ++++---- .../aggregations/metrics/MaxAggregator.java | 36 +++++----- .../aggregations/metrics/MinAggregator.java | 36 ++++++---- .../aggregations/metrics/SumAggregator.java | 24 ++++--- .../metrics/ValueCountAggregator.java | 29 ++++---- .../search/startree/StarTreeQueryHelper.java | 9 +-- 11 files changed, 168 insertions(+), 132 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java index 47e9def094623..f91bf972a3d28 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java +++ b/server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java @@ -32,6 +32,7 @@ package org.opensearch.search.aggregations; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.ScoreMode; import org.opensearch.core.common.breaker.CircuitBreaker; @@ -200,6 +201,9 @@ public Map metadata() { @Override public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { + if (tryPrecomputeAggregationForLeaf(ctx)) { + throw new CollectionTerminatedException(); + } preGetSubLeafCollectors(ctx); final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx); return getLeafCollector(ctx, sub); @@ -216,6 +220,21 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException */ protected void doPreCollection() throws IOException {} + /** + * Subclasses may override this method if they have an efficient way of computing their aggregation for the given + * segment (versus collecting matching documents). If this method returns true, collection for the given segment + * will be terminated, rather than executing normally. + *

+ * If this method returns true, the aggregator's state should be identical to what it would be if matching + * documents from the segment were fully collected. If this method returns false, the aggregator's state should + * be unchanged from before this method is called. + * @param ctx the context for the given segment + * @return true if and only if results for this segment have been precomputed + */ + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + return false; + } + @Override public final void preCollection() throws IOException { List collectors = Arrays.asList(subAggregators); @@ -251,8 +270,8 @@ public Aggregator[] subAggregators() { public Aggregator subAggregator(String aggName) { if (subAggregatorbyName == null) { subAggregatorbyName = new HashMap<>(subAggregators.length); - for (int i = 0; i < subAggregators.length; i++) { - subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]); + for (Aggregator subAggregator : subAggregators) { + subAggregatorbyName.put(subAggregator.name(), subAggregator); } } return subAggregatorbyName.get(aggName); diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java index cfe716eb57ca8..0a200fcbc105b 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java @@ -556,10 +556,13 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t } @Override - protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); - if (optimized) throw new CollectionTerminatedException(); + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed + return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + } + @Override + protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { finishLeaf(); boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR; diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java index 7e8f4958e9d56..0482188a33b14 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java @@ -33,7 +33,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SortedNumericDocValues; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.CollectionUtil; @@ -187,22 +186,23 @@ public ScoreMode scoreMode() { } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - if (valuesSource == null) { - return LeafBucketCollector.NO_OP_COLLECTOR; - } - - boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); - if (optimized) throw new CollectionTerminatedException(); - - SortedNumericDocValues values = valuesSource.longValues(ctx); + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); if (supportedStarTree != null) { if (preComputeWithStarTree(ctx, supportedStarTree) == true) { - throw new CollectionTerminatedException(); + return true; } } + return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx)); + } + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + if (valuesSource == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + + SortedNumericDocValues values = valuesSource.longValues(ctx); return new LeafBucketCollectorBase(sub, values) { @Override public void collect(int doc, long owningBucketOrd) throws IOException { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java index 39dc1c36e8895..97e63078396f5 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java @@ -32,7 +32,6 @@ package org.opensearch.search.aggregations.bucket.range; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.ScoreMode; import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; @@ -310,10 +309,15 @@ public ScoreMode scoreMode() { } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) { - throw new CollectionTerminatedException(); + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + if (segmentMatchAll(context, ctx)) { + return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false); } + return false; + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); return new LeafBucketCollectorBase(sub, values) { diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index 9e40f7b4c9b3e..4ce2e618535a4 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -40,7 +40,6 @@ import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.Weight; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; @@ -166,35 +165,32 @@ public void setWeight(Weight weight) { @return A LeafBucketCollector implementation with collection termination, since collection is complete @throws IOException If an I/O error occurs during reading */ - LeafBucketCollector termDocFreqCollector( - LeafReaderContext ctx, - SortedSetDocValues globalOrds, - BiConsumer ordCountConsumer - ) throws IOException { + boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer ordCountConsumer) + throws IOException { if (weight == null) { // Weight not assigned - cannot use this optimization - return null; + return false; } else { if (weight.count(ctx) == 0) { // No documents matches top level query on this segment, we can skip the segment entirely - return LeafBucketCollector.NO_OP_COLLECTOR; + return true; } else if (weight.count(ctx) != ctx.reader().maxDoc()) { // weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and // top-level query matches all docs in the segment - return null; + return false; } } Terms segmentTerms = ctx.reader().terms(this.fieldName); if (segmentTerms == null) { // Field is not indexed. - return null; + return false; } NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME); if (docCountValues.nextDoc() != NO_MORE_DOCS) { // This segment has at least one document with the _doc_count field. - return null; + return false; } TermsEnum indexTermsEnum = segmentTerms.iterator(); @@ -218,31 +214,28 @@ LeafBucketCollector termDocFreqCollector( ordinalTerm = globalOrdinalTermsEnum.next(); } } - return new LeafBucketCollector() { - @Override - public void collect(int doc, long owningBucketOrd) throws IOException { - throw new CollectionTerminatedException(); - } - }; + return true; } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx); - collectionStrategy.globalOrdsReady(globalOrds); - if (collectionStrategy instanceof DenseGlobalOrds && this.resultStrategy instanceof StandardTermsResults - && sub == LeafBucketCollector.NO_OP_COLLECTOR) { - LeafBucketCollector termDocFreqCollector = termDocFreqCollector( + && subAggregators.length == 0) { + return tryCollectFromTermFrequencies( ctx, globalOrds, (ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount) ); - if (termDocFreqCollector != null) { - return termDocFreqCollector; - } } + return false; + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx); + collectionStrategy.globalOrdsReady(globalOrds); SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds); if (singleValues != null) { @@ -433,6 +426,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator { this.segmentDocCounts = context.bigArrays().newLongArray(1, true); } + @Override + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + if (subAggregators.length == 0) { + if (mapping != null) { + mapSegmentCountsToGlobalCounts(mapping); + } + final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx); + segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount()); + mapping = valuesSource.globalOrdinalsMapping(ctx); + return tryCollectFromTermFrequencies( + ctx, + segmentOrds, + (ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount) + ); + } + return false; + } + @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { if (mapping != null) { @@ -443,17 +454,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol assert sub == LeafBucketCollector.NO_OP_COLLECTOR; mapping = valuesSource.globalOrdinalsMapping(ctx); - if (this.resultStrategy instanceof StandardTermsResults) { - LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector( - ctx, - segmentOrds, - (ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount) - ); - if (termDocFreqCollector != null) { - return termDocFreqCollector; - } - } - final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds); if (singleValues != null) { segmentsWithSingleValuedOrds++; diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java index f71b6679a7c4d..5f99a9cc05558 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java @@ -32,7 +32,6 @@ package org.opensearch.search.aggregations.metrics; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.FixedBitSet; @@ -104,23 +103,29 @@ public ScoreMode scoreMode() { } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { if (valuesSource == null) { - return LeafBucketCollector.NO_OP_COLLECTOR; + return false; } CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); if (supportedStarTree != null) { if (parent != null && subAggregators.length == 0) { // If this a child aggregator, then the parent will trigger star-tree pre-computation. // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators - return LeafBucketCollector.NO_OP_COLLECTOR; + return true; } - return getStarTreeLeafCollector(ctx, sub, supportedStarTree); + precomputeLeafUsingStarTree(ctx, supportedStarTree); + return true; } - return getDefaultLeafCollector(ctx, sub); + return false; } - private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + if (valuesSource == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); final CompensatedSum kahanSummation = new CompensatedSum(0, 0); @@ -154,8 +159,7 @@ public void collect(int doc, long bucket) throws IOException { }; } - public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) - throws IOException { + private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree); assert starTreeValues != null; @@ -200,12 +204,6 @@ public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafB sums.set(0, kahanSummation.value()); compensations.set(0, kahanSummation.delta()); - return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) { - @Override - public void collect(int doc, long bucket) { - throw new CollectionTerminatedException(); - } - }; } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java index c64a6cf29fb63..8a2c8a6de923f 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java @@ -104,6 +104,24 @@ public ScoreMode scoreMode() { return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; } + @Override + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + if (valuesSource == null) { + return false; + } + CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); + if (supportedStarTree != null) { + if (parent != null && subAggregators.length == 0) { + // If this a child aggregator, then the parent will trigger star-tree pre-computation. + // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators + return true; + } + precomputeLeafUsingStarTree(ctx, supportedStarTree); + return true; + } + return false; + } + @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { if (valuesSource == null) { @@ -130,20 +148,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc } } - CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); - if (supportedStarTree != null) { - if (parent != null && subAggregators.length == 0) { - // If this a child aggregator, then the parent will trigger star-tree pre-computation. - // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators - return LeafBucketCollector.NO_OP_COLLECTOR; - } - getStarTreeCollector(ctx, sub, supportedStarTree); - } - return getDefaultLeafCollector(ctx, sub); - } - - private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); final NumericDoubleValues values = MultiValueMode.MAX.select(allValues); @@ -167,9 +171,9 @@ public void collect(int doc, long bucket) throws IOException { }; } - public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException { + private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { AtomicReference max = new AtomicReference<>(maxes.get(0)); - StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MAX.getTypeName(), value -> { + StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MAX.getTypeName(), value -> { max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value)))); }, () -> maxes.set(0, max.get())); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java index 5cdee536cde19..84dda7928aa90 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/MinAggregator.java @@ -104,6 +104,25 @@ public ScoreMode scoreMode() { return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; } + @Override + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { + if (valuesSource == null) { + return false; + } + CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); + if (supportedStarTree != null) { + if (parent != null && subAggregators.length == 0) { + // If this a child aggregator, then the parent will trigger star-tree pre-computation. + // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators + return true; + } + precomputeLeafUsingStarTree(ctx, supportedStarTree); + return true; + } + + return false; + } + @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { if (valuesSource == null) { @@ -129,19 +148,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc } } - CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); - if (supportedStarTree != null) { - if (parent != null && subAggregators.length == 0) { - // If this a child aggregator, then the parent will trigger star-tree pre-computation. - // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators - return LeafBucketCollector.NO_OP_COLLECTOR; - } - getStarTreeCollector(ctx, sub, supportedStarTree); - } - return getDefaultLeafCollector(ctx, sub); - } - - private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx); final NumericDoubleValues values = MultiValueMode.MIN.select(allValues); @@ -164,9 +170,9 @@ public void collect(int doc, long bucket) throws IOException { }; } - public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException { + private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { AtomicReference min = new AtomicReference<>(mins.get(0)); - StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MIN.getTypeName(), value -> { + StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MIN.getTypeName(), value -> { min.set(Math.min(min.get(), (NumericUtils.sortableLongToDouble(value)))); }, () -> mins.set(0, min.get())); } diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java index edcfb61263fc1..ba32592f75ea1 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/SumAggregator.java @@ -93,24 +93,29 @@ public ScoreMode scoreMode() { } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { if (valuesSource == null) { - return LeafBucketCollector.NO_OP_COLLECTOR; + return false; } - CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); if (supportedStarTree != null) { if (parent != null && subAggregators.length == 0) { // If this a child aggregator, then the parent will trigger star-tree pre-computation. // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators - return LeafBucketCollector.NO_OP_COLLECTOR; + return true; } - getStarTreeCollector(ctx, sub, supportedStarTree); + precomputeLeafUsingStarTree(ctx, supportedStarTree); + return true; } - return getDefaultLeafCollector(ctx, sub); + return false; } - private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + + if (valuesSource == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } final BigArrays bigArrays = context.bigArrays(); final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx); final CompensatedSum kahanSummation = new CompensatedSum(0, 0); @@ -140,14 +145,13 @@ public void collect(int doc, long bucket) throws IOException { }; } - public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException { + private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { final CompensatedSum kahanSummation = new CompensatedSum(sums.get(0), compensations.get(0)); - StarTreeQueryHelper.getStarTreeLeafCollector( + StarTreeQueryHelper.precomputeLeafUsingStarTree( context, valuesSource, ctx, - sub, starTree, MetricStat.SUM.getTypeName(), value -> kahanSummation.add(NumericUtils.sortableLongToDouble(value)), diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/ValueCountAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/ValueCountAggregator.java index d298361391ad9..3541753d94e6f 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/metrics/ValueCountAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/ValueCountAggregator.java @@ -88,24 +88,30 @@ public ValueCountAggregator( } @Override - public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { - if (valuesSource == null) { - return LeafBucketCollector.NO_OP_COLLECTOR; - } - final BigArrays bigArrays = context.bigArrays(); - + protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException { if (valuesSource instanceof ValuesSource.Numeric) { - CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext()); if (supportedStarTree != null) { if (parent != null && subAggregators.length == 0) { // If this a child aggregator, then the parent will trigger star-tree pre-computation. // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators - return LeafBucketCollector.NO_OP_COLLECTOR; + return true; } - getStarTreeCollector(ctx, sub, supportedStarTree); + precomputeLeafUsingStarTree(ctx, supportedStarTree); + return true; } + } + return false; + } + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + if (valuesSource == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + final BigArrays bigArrays = context.bigArrays(); + + if (valuesSource instanceof ValuesSource.Numeric) { final SortedNumericDocValues values = ((ValuesSource.Numeric) valuesSource).longValues(ctx); return new LeafBucketCollectorBase(sub, values) { @@ -145,12 +151,11 @@ public void collect(int doc, long bucket) throws IOException { }; } - public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException { - StarTreeQueryHelper.getStarTreeLeafCollector( + private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException { + StarTreeQueryHelper.precomputeLeafUsingStarTree( context, (ValuesSource.Numeric) valuesSource, ctx, - sub, starTree, MetricStat.VALUE_COUNT.getTypeName(), value -> counts.increment(0, value), diff --git a/server/src/main/java/org/opensearch/search/startree/StarTreeQueryHelper.java b/server/src/main/java/org/opensearch/search/startree/StarTreeQueryHelper.java index 4bc82cde3d4fb..c2ce7f41ca7bf 100644 --- a/server/src/main/java/org/opensearch/search/startree/StarTreeQueryHelper.java +++ b/server/src/main/java/org/opensearch/search/startree/StarTreeQueryHelper.java @@ -10,7 +10,6 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.SegmentReader; -import org.apache.lucene.search.CollectionTerminatedException; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.FixedBitSet; import org.opensearch.common.lucene.Lucene; @@ -21,7 +20,6 @@ import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils; import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator; import org.opensearch.index.query.QueryShardContext; -import org.opensearch.search.aggregations.LeafBucketCollector; import org.opensearch.search.aggregations.StarTreeBucketCollector; import org.opensearch.search.aggregations.support.ValuesSource; import org.opensearch.search.internal.SearchContext; @@ -71,11 +69,10 @@ public static StarTreeValues getStarTreeValues(LeafReaderContext context, Compos * Get the star-tree leaf collector * This collector computes the aggregation prematurely and invokes an early termination collector */ - public static void getStarTreeLeafCollector( + public static void precomputeLeafUsingStarTree( SearchContext context, ValuesSource.Numeric valuesSource, LeafReaderContext ctx, - LeafBucketCollector sub, CompositeIndexFieldInfo starTree, String metric, Consumer valueConsumer, @@ -113,10 +110,6 @@ public static void getStarTreeLeafCollector( // Call the final consumer after processing all entries finalConsumer.run(); - - // FIXME : Remove after @msfroh PR for precompute - // Terminate after pre-computing aggregation - throw new CollectionTerminatedException(); } /**