Unify precomputation of aggregations behind a common API (#16733) (#17197)

We've had a series of aggregation speedups that use the same strategy:
instead of iterating through documents that match the query
one-by-one, we can look at a Lucene segment and compute the
aggregation directly (if some particular conditions are met).

In every case, we've hooked that in with custom logic that hijacks the
getLeafCollector method and throws CollectionTerminatedException. This
creates the illusion that we're implementing a custom LeafCollector,
when really we're not collecting at all (which is the whole point).

With this refactoring, the mechanism (hijacking getLeafCollector) is
moved into AggregatorBase. Aggregators that have a strategy to
precompute their answer can override tryPrecomputeAggregationForLeaf,
which is expected to return true if they managed to precompute.

This should also make it easier to keep track of which aggregations
have precomputation approaches (since they override this method).
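
To make the new contract concrete, here is a minimal, self-contained
sketch of the pattern. The nested types below are simplified stand-ins
(assumptions for illustration only), not the real OpenSearch and Lucene
classes:

    import java.io.IOException;

    // Sketch of the hook introduced by this commit. Names mirror the real
    // API (AggregatorBase#tryPrecomputeAggregationForLeaf), but the types
    // are simplified stand-ins, not the OpenSearch classes.
    abstract class PrecomputingAggregatorSketch {

        static final class CollectionTerminatedException extends RuntimeException {}

        interface LeafContext {}

        interface LeafCollector {
            void collect(int doc) throws IOException;
        }

        // Subclasses override this when they can compute the leaf's result
        // directly from segment-level data structures.
        protected boolean tryPrecomputeAggregationForLeaf(LeafContext ctx) throws IOException {
            return false; // default: no fast path; collect documents normally
        }

        protected abstract LeafCollector newLeafCollector(LeafContext ctx) throws IOException;

        // The base class owns the "hijack": if the leaf was precomputed,
        // signal the search loop to skip collection by throwing here.
        public final LeafCollector getLeafCollector(LeafContext ctx) throws IOException {
            if (tryPrecomputeAggregationForLeaf(ctx)) {
                throw new CollectionTerminatedException();
            }
            return newLeafCollector(ctx);
        }
    }

A subclass that can answer from segment metadata overrides
tryPrecomputeAggregationForLeaf and returns true; everything else
inherits the default and is collected normally.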

---------


(cherry picked from commit 2847695)

Signed-off-by: Michael Froh <froh@amazon.com>
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent ffed717 commit a24646c
Showing 11 changed files with 168 additions and 132 deletions.
AggregatorBase.java

@@ -32,6 +32,7 @@
 package org.opensearch.search.aggregations;
 
 import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.search.ScoreMode;
 import org.opensearch.core.common.breaker.CircuitBreaker;
@@ -200,6 +201,9 @@ public Map<String, Object> metadata() {
 
     @Override
     public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
+        if (tryPrecomputeAggregationForLeaf(ctx)) {
+            throw new CollectionTerminatedException();
+        }
         preGetSubLeafCollectors(ctx);
         final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx);
         return getLeafCollector(ctx, sub);
@@ -216,6 +220,21 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException
      */
     protected void doPreCollection() throws IOException {}
 
+    /**
+     * Subclasses may override this method if they have an efficient way of computing their aggregation for the given
+     * segment (versus collecting matching documents). If this method returns true, collection for the given segment
+     * will be terminated, rather than executing normally.
+     * <p>
+     * If this method returns true, the aggregator's state should be identical to what it would be if matching
+     * documents from the segment were fully collected. If this method returns false, the aggregator's state should
+     * be unchanged from before this method is called.
+     * @param ctx the context for the given segment
+     * @return true if and only if results for this segment have been precomputed
+     */
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
+        return false;
+    }
+
     @Override
     public final void preCollection() throws IOException {
         List<BucketCollector> collectors = Arrays.asList(subAggregators);
@@ -251,8 +270,8 @@ public Aggregator[] subAggregators() {
     public Aggregator subAggregator(String aggName) {
         if (subAggregatorbyName == null) {
             subAggregatorbyName = new HashMap<>(subAggregators.length);
-            for (int i = 0; i < subAggregators.length; i++) {
-                subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]);
+            for (Aggregator subAggregator : subAggregators) {
+                subAggregatorbyName.put(subAggregator.name(), subAggregator);
             }
         }
         return subAggregatorbyName.get(aggName);
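The Javadoc above pins down the contract; throwing from getLeafCollector
is safe because Lucene's search loop catches CollectionTerminatedException
and simply moves to the next segment. A rough sketch of that driver,
reusing the stand-in types from the earlier sketch (a simplification, not
Lucene's actual IndexSearcher code):

    import java.io.IOException;

    // Hypothetical driver loop showing why throwing from getLeafCollector
    // works: the searcher treats the exception as "done with this leaf".
    final class SearchLoopSketch {
        static void search(Iterable<PrecomputingAggregatorSketch.LeafContext> leaves,
                           PrecomputingAggregatorSketch aggregator) throws IOException {
            for (PrecomputingAggregatorSketch.LeafContext leaf : leaves) {
                final PrecomputingAggregatorSketch.LeafCollector collector;
                try {
                    collector = aggregator.getLeafCollector(leaf);
                } catch (PrecomputingAggregatorSketch.CollectionTerminatedException e) {
                    continue; // leaf already accounted for; skip per-document collection
                }
                // ... iterate the leaf's matching docs and call collector.collect(doc) ...
            }
        }
    }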
CompositeAggregator.java

@@ -556,10 +556,13 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
     }
 
     @Override
-    protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
-        boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
-        if (optimized) throw new CollectionTerminatedException();
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
+        finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
+        return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
+    }
+
+    @Override
+    protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
         finishLeaf();
 
         boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;
DateHistogramAggregator.java

@@ -33,7 +33,6 @@
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedNumericDocValues;
-import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.CollectionUtil;
@@ -187,22 +186,23 @@ public ScoreMode scoreMode() {
     }
 
     @Override
-    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
-        if (valuesSource == null) {
-            return LeafBucketCollector.NO_OP_COLLECTOR;
-        }
-
-        boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
-        if (optimized) throw new CollectionTerminatedException();
-
-        SortedNumericDocValues values = valuesSource.longValues(ctx);
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
         CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
         if (supportedStarTree != null) {
             if (preComputeWithStarTree(ctx, supportedStarTree) == true) {
-                throw new CollectionTerminatedException();
+                return true;
             }
         }
+        return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
+    }
+
+    @Override
+    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
+        if (valuesSource == null) {
+            return LeafBucketCollector.NO_OP_COLLECTOR;
+        }
+
+        SortedNumericDocValues values = valuesSource.longValues(ctx);
         return new LeafBucketCollectorBase(sub, values) {
             @Override
             public void collect(int doc, long owningBucketOrd) throws IOException {
RangeAggregator.java

@@ -32,7 +32,6 @@
 package org.opensearch.search.aggregations.bucket.range;
 
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.ScoreMode;
 import org.opensearch.core.ParseField;
 import org.opensearch.core.common.io.stream.StreamInput;
@@ -310,10 +309,15 @@ public ScoreMode scoreMode() {
     }
 
     @Override
-    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
-        if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
-            throw new CollectionTerminatedException();
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
+        if (segmentMatchAll(context, ctx)) {
+            return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
         }
+        return false;
+    }
+
+    @Override
+    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
 
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
         return new LeafBucketCollectorBase(sub, values) {
GlobalOrdinalsStringTermsAggregator.java

@@ -40,7 +40,6 @@
 import org.apache.lucene.index.SortedSetDocValues;
 import org.apache.lucene.index.Terms;
 import org.apache.lucene.index.TermsEnum;
-import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.ArrayUtil;
 import org.apache.lucene.util.BytesRef;
@@ -166,35 +165,32 @@ public void setWeight(Weight weight) {
      * @return A LeafBucketCollector implementation with collection termination, since collection is complete
      * @throws IOException If an I/O error occurs during reading
      */
-    LeafBucketCollector termDocFreqCollector(
-        LeafReaderContext ctx,
-        SortedSetDocValues globalOrds,
-        BiConsumer<Long, Integer> ordCountConsumer
-    ) throws IOException {
+    boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer<Long, Integer> ordCountConsumer)
+        throws IOException {
         if (weight == null) {
             // Weight not assigned - cannot use this optimization
-            return null;
+            return false;
         } else {
             if (weight.count(ctx) == 0) {
                 // No documents matches top level query on this segment, we can skip the segment entirely
-                return LeafBucketCollector.NO_OP_COLLECTOR;
+                return true;
             } else if (weight.count(ctx) != ctx.reader().maxDoc()) {
                 // weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
                 // top-level query matches all docs in the segment
-                return null;
+                return false;
             }
         }
 
         Terms segmentTerms = ctx.reader().terms(this.fieldName);
         if (segmentTerms == null) {
             // Field is not indexed.
-            return null;
+            return false;
        }
 
         NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
         if (docCountValues.nextDoc() != NO_MORE_DOCS) {
             // This segment has at least one document with the _doc_count field.
-            return null;
+            return false;
         }
 
         TermsEnum indexTermsEnum = segmentTerms.iterator();
@@ -218,31 +214,28 @@ LeafBucketCollector termDocFreqCollector(
                 ordinalTerm = globalOrdinalTermsEnum.next();
             }
         }
-        return new LeafBucketCollector() {
-            @Override
-            public void collect(int doc, long owningBucketOrd) throws IOException {
-                throw new CollectionTerminatedException();
-            }
-        };
+        return true;
     }
 
     @Override
-    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
         SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
         collectionStrategy.globalOrdsReady(globalOrds);
 
         if (collectionStrategy instanceof DenseGlobalOrds
             && this.resultStrategy instanceof StandardTermsResults
-            && sub == LeafBucketCollector.NO_OP_COLLECTOR) {
-            LeafBucketCollector termDocFreqCollector = termDocFreqCollector(
+            && subAggregators.length == 0) {
+            return tryCollectFromTermFrequencies(
                 ctx,
                 globalOrds,
                 (ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount)
             );
-            if (termDocFreqCollector != null) {
-                return termDocFreqCollector;
-            }
         }
+        return false;
+    }
+
+    @Override
+    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
+        SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
+        collectionStrategy.globalOrdsReady(globalOrds);
 
         SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds);
         if (singleValues != null) {
@@ -433,6 +426,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator {
             this.segmentDocCounts = context.bigArrays().newLongArray(1, true);
         }
 
+        @Override
+        protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
+            if (subAggregators.length == 0) {
+                if (mapping != null) {
+                    mapSegmentCountsToGlobalCounts(mapping);
+                }
+                final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx);
+                segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount());
+                mapping = valuesSource.globalOrdinalsMapping(ctx);
+                return tryCollectFromTermFrequencies(
+                    ctx,
+                    segmentOrds,
+                    (ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
+                );
+            }
+            return false;
+        }
+
         @Override
         public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
             if (mapping != null) {
@@ -443,17 +454,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
             assert sub == LeafBucketCollector.NO_OP_COLLECTOR;
             mapping = valuesSource.globalOrdinalsMapping(ctx);
 
-            if (this.resultStrategy instanceof StandardTermsResults) {
-                LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector(
-                    ctx,
-                    segmentOrds,
-                    (ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
-                );
-                if (termDocFreqCollector != null) {
-                    return termDocFreqCollector;
-                }
-            }
-
             final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds);
             if (singleValues != null) {
                 segmentsWithSingleValuedOrds++;
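The fast path above answers a terms aggregation from segment-level term
statistics instead of per-document collection. A hedged sketch of the core
idea, using real Lucene types but omitting the global-ordinal remapping
and the precondition checks (weight count, deletions, _doc_count field)
that the real method performs first:

    import java.io.IOException;
    import java.util.function.ObjLongConsumer;

    import org.apache.lucene.index.Terms;
    import org.apache.lucene.index.TermsEnum;
    import org.apache.lucene.util.BytesRef;

    // Sketch only: count docs per term straight from index statistics.
    final class TermFrequencyCountsSketch {
        static void countTerms(Terms terms, ObjLongConsumer<BytesRef> consumer) throws IOException {
            TermsEnum te = terms.iterator();
            for (BytesRef term = te.next(); term != null; term = te.next()) {
                // docFreq() is the number of documents containing this term in the segment
                consumer.accept(term, te.docFreq());
            }
        }
    }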
SumAggregator.java

@@ -32,7 +32,6 @@
 package org.opensearch.search.aggregations.metrics;
 
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.CollectionTerminatedException;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.util.FixedBitSet;
@@ -104,23 +103,29 @@ public ScoreMode scoreMode() {
     }
 
     @Override
-    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
         if (valuesSource == null) {
-            return LeafBucketCollector.NO_OP_COLLECTOR;
+            return false;
         }
         CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
         if (supportedStarTree != null) {
             if (parent != null && subAggregators.length == 0) {
                 // If this a child aggregator, then the parent will trigger star-tree pre-computation.
                 // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
-                return LeafBucketCollector.NO_OP_COLLECTOR;
+                return true;
             }
-            return getStarTreeLeafCollector(ctx, sub, supportedStarTree);
+            precomputeLeafUsingStarTree(ctx, supportedStarTree);
+            return true;
         }
-        return getDefaultLeafCollector(ctx, sub);
+        return false;
     }
 
-    private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
+    @Override
+    public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
+        if (valuesSource == null) {
+            return LeafBucketCollector.NO_OP_COLLECTOR;
+        }
+
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
         final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
@@ -154,8 +159,7 @@ public void collect(int doc, long bucket) throws IOException {
         };
     }
 
-    public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
-        throws IOException {
+    private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
         StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
         assert starTreeValues != null;
 
@@ -200,12 +204,6 @@ public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafB
 
         sums.set(0, kahanSummation.value());
         compensations.set(0, kahanSummation.delta());
-        return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
-            @Override
-            public void collect(int doc, long bucket) {
-                throw new CollectionTerminatedException();
-            }
-        };
     }
 
     @Override
MaxAggregator.java

@@ -104,6 +104,24 @@ public ScoreMode scoreMode() {
         return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
     }
 
+    @Override
+    protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
+        if (valuesSource == null) {
+            return false;
+        }
+        CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
+        if (supportedStarTree != null) {
+            if (parent != null && subAggregators.length == 0) {
+                // If this a child aggregator, then the parent will trigger star-tree pre-computation.
+                // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
+                return true;
+            }
+            precomputeLeafUsingStarTree(ctx, supportedStarTree);
+            return true;
+        }
+        return false;
+    }
+
     @Override
     public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
         if (valuesSource == null) {
@@ -130,20 +148,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
             }
         }
 
-        CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
-        if (supportedStarTree != null) {
-            if (parent != null && subAggregators.length == 0) {
-                // If this a child aggregator, then the parent will trigger star-tree pre-computation.
-                // Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
-                return LeafBucketCollector.NO_OP_COLLECTOR;
-            }
-            getStarTreeCollector(ctx, sub, supportedStarTree);
-        }
-        return getDefaultLeafCollector(ctx, sub);
-    }
-
-    private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
-
         final BigArrays bigArrays = context.bigArrays();
         final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
         final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
@@ -167,9 +171,9 @@ public void collect(int doc, long bucket) throws IOException {
         };
     }
 
-    public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException {
+    private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
         AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
-        StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MAX.getTypeName(), value -> {
+        StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MAX.getTypeName(), value -> {
             max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
         }, () -> maxes.set(0, max.get()));
     }