From c15c3912cb49a77476cbe84c113f68b201318c68 Mon Sep 17 00:00:00 2001 From: Gonzalo Ortiz Jaureguizar Date: Mon, 1 Apr 2024 10:45:05 +0200 Subject: [PATCH] Percentile operations supporting null (#12271) * new test framework candidate * Improved test system * Improve framework to be able to specify segments as strings * fix headers * Improve assertions when there are nulls * Improve error text * Improvements in the framework * Add a base class single input aggregation operations can extend to support null handling * Fix issue in NullableSingleInputAggregationFunction.forEachNotNullInt * Improve error message in NullEnabledQueriesTest * Add new schema family * Rename test schemas and table config * Split AllNullQueriesTest into one test per query * Revert change in AllNullQueriesTest that belongs to mode-null-support branch * Add tests * Fix issue in bytes in aggregation case * Update to the new framework * Fix some tests * rollback a code style change --- .../function/AggregationFunctionFactory.java | 37 +- ...ullableSingleInputAggregationFunction.java | 9 + .../PercentileAggregationFunction.java | 57 +-- .../PercentileEstAggregationFunction.java | 108 +++--- .../PercentileEstMVAggregationFunction.java | 4 +- .../PercentileKLLAggregationFunction.java | 69 ++-- .../PercentileKLLMVAggregationFunction.java | 2 +- .../PercentileMVAggregationFunction.java | 4 +- .../PercentileRawEstAggregationFunction.java | 10 +- .../PercentileRawKLLAggregationFunction.java | 4 +- ...rcentileRawTDigestAggregationFunction.java | 17 +- ...entileSmartTDigestAggregationFunction.java | 96 +++-- .../PercentileTDigestAggregationFunction.java | 110 +++--- ...ercentileTDigestMVAggregationFunction.java | 6 +- ...ractPercentileAggregationFunctionTest.java | 333 ++++++++++++++++++ .../PercentileAggregationFunctionTest.java | 27 ++ .../PercentileEstAggregationFunctionTest.java | 45 +++ .../PercentileKLLAggregationFunctionTest.java | 47 +++ ...leSmartTDigestAggregationFunctionTest.java | 87 +++++ 
.../apache/pinot/queries/FluentQueryTest.java | 2 +- 20 files changed, 861 insertions(+), 213 deletions(-) create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractPercentileAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunctionTest.java create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunctionTest.java diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java index eeed8608a4ec..a82d421ebc96 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java @@ -61,16 +61,16 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio if (upperCaseFunctionName.startsWith("PERCENTILE")) { String remainingFunctionName = upperCaseFunctionName.substring(10); if (remainingFunctionName.equals("SMARTTDIGEST")) { - return new PercentileSmartTDigestAggregationFunction(arguments); + return new PercentileSmartTDigestAggregationFunction(arguments, nullHandlingEnabled); } if (remainingFunctionName.equals("KLL")) { - return new PercentileKLLAggregationFunction(arguments); + return new PercentileKLLAggregationFunction(arguments, nullHandlingEnabled); } if (remainingFunctionName.equals("KLLMV")) { return new 
PercentileKLLMVAggregationFunction(arguments); } if (remainingFunctionName.equals("RAWKLL")) { - return new PercentileRawKLLAggregationFunction(arguments); + return new PercentileRawKLLAggregationFunction(arguments, nullHandlingEnabled); } if (remainingFunctionName.equals("RAWKLLMV")) { return new PercentileRawKLLMVAggregationFunction(arguments); @@ -80,23 +80,28 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio // NOTE: This convention is deprecated. DO NOT add new functions here if (remainingFunctionName.matches("\\d+")) { // Percentile - return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName)); + return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName), + nullHandlingEnabled); } else if (remainingFunctionName.matches("EST\\d+")) { // PercentileEst String percentileString = remainingFunctionName.substring(3); - return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString)); + return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString), + nullHandlingEnabled); } else if (remainingFunctionName.matches("RAWEST\\d+")) { // PercentileRawEst String percentileString = remainingFunctionName.substring(6); - return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString)); + return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString), + nullHandlingEnabled); } else if (remainingFunctionName.matches("TDIGEST\\d+")) { // PercentileTDigest String percentileString = remainingFunctionName.substring(7); - return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString)); + return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString), + nullHandlingEnabled); } else if (remainingFunctionName.matches("RAWTDIGEST\\d+")) { // 
PercentileRawTDigest String percentileString = remainingFunctionName.substring(10); - return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString)); + return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString), + nullHandlingEnabled); } else if (remainingFunctionName.matches("\\d+MV")) { // PercentileMV String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2); @@ -125,23 +130,23 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio Preconditions.checkArgument(percentile >= 0 && percentile <= 100, "Invalid percentile: %s", percentile); if (remainingFunctionName.isEmpty()) { // Percentile - return new PercentileAggregationFunction(firstArgument, percentile); + return new PercentileAggregationFunction(firstArgument, percentile, nullHandlingEnabled); } if (remainingFunctionName.equals("EST")) { // PercentileEst - return new PercentileEstAggregationFunction(firstArgument, percentile); + return new PercentileEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled); } if (remainingFunctionName.equals("RAWEST")) { // PercentileRawEst - return new PercentileRawEstAggregationFunction(firstArgument, percentile); + return new PercentileRawEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled); } if (remainingFunctionName.equals("TDIGEST")) { // PercentileTDigest - return new PercentileTDigestAggregationFunction(firstArgument, percentile); + return new PercentileTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled); } if (remainingFunctionName.equals("RAWTDIGEST")) { // PercentileRawTDigest - return new PercentileRawTDigestAggregationFunction(firstArgument, percentile); + return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled); } if (remainingFunctionName.equals("MV")) { // PercentileMV @@ -175,11 +180,13 @@ public static 
AggregationFunction getAggregationFunction(FunctionContext functio Preconditions.checkArgument(compressionFactor >= 0, "Invalid compressionFactor: %d", compressionFactor); if (remainingFunctionName.equals("TDIGEST")) { // PercentileTDigest - return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor); + return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor, + nullHandlingEnabled); } if (remainingFunctionName.equals("RAWTDIGEST")) { // PercentileRawTDigest - return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor); + return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor, + nullHandlingEnabled); } if (remainingFunctionName.equals("TDIGESTMV")) { // PercentileTDigestMV diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java index 78f1ae12696c..907f0139d2a9 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java @@ -103,6 +103,15 @@ public void forEachNotNull(int length, IntIterator nullIndexIterator, BatchConsu } } + /** + * Folds over the non-null ranges of the blockValSet using the reducer. + * @param initialAcum the initial value of the accumulator + * @param The type of the accumulator + */ + public A foldNotNull(int length, BlockValSet blockValSet, A initialAcum, Reducer reducer) { + return foldNotNull(length, blockValSet.getNullBitmap(), initialAcum, reducer); + } + /** * Folds over the non-null ranges of the blockValSet using the reducer. 
* @param initialAcum the initial value of the accumulator diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java index 5d227caeadaa..c9c71744d267 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java @@ -31,7 +31,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType; -public class PercentileAggregationFunction extends BaseSingleInputAggregationFunction { +public class PercentileAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; //version 0 functions specified in the of form PERCENTILE<2-digits>(column) @@ -39,14 +39,14 @@ public class PercentileAggregationFunction extends BaseSingleInputAggregationFun protected final int _version; protected final double _percentile; - public PercentileAggregationFunction(ExpressionContext expression, int percentile) { - super(expression); + public PercentileAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 0; _percentile = percentile; } - public PercentileAggregationFunction(ExpressionContext expression, double percentile) { - super(expression); + public PercentileAggregationFunction(ExpressionContext expression, double percentile, boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 1; _percentile = percentile; } @@ -77,33 +77,42 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map blockValSetMap) { DoubleArrayList valueList = 
getValueList(aggregationResultHolder); - double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - valueList.add(valueArray[i]); - } + BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + valueList.add(valueArray[i]); + } + }); } @Override public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { - double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]); - valueList.add(valueArray[i]); - } + BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]); + valueList.add(valueArray[i]); + } + }); } @Override public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map blockValSetMap) { - double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - double value = valueArray[i]; - for (int groupKey : groupKeysArray[i]) { - DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey); - valueList.add(value); + BlockValSet blockValSet = blockValSetMap.get(_expression); + double[] valueArray = blockValSet.getDoubleValuesSV(); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + double value = valueArray[i]; + for (int groupKey : groupKeysArray[i]) { + DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey); + valueList.add(value); + } } - } + }); } @Override @@ -146,7 
+155,11 @@ public ColumnDataType getFinalResultColumnType() { public Double extractFinalResult(DoubleArrayList intermediateResult) { int size = intermediateResult.size(); if (size == 0) { - return DEFAULT_FINAL_RESULT; + if (_nullHandlingEnabled) { + return null; + } else { + return DEFAULT_FINAL_RESULT; + } } else { double[] values = intermediateResult.elements(); Arrays.sort(values, 0, size); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java index d055e4650541..e67a3f7d6500 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java @@ -32,7 +32,7 @@ import org.apache.pinot.spi.data.FieldSpec.DataType; -public class PercentileEstAggregationFunction extends BaseSingleInputAggregationFunction { +public class PercentileEstAggregationFunction extends NullableSingleInputAggregationFunction { public static final double DEFAULT_MAX_ERROR = 0.05; //version 0 functions specified in the of form PERCENTILEEST<2-digits>(column) @@ -40,14 +40,15 @@ public class PercentileEstAggregationFunction extends BaseSingleInputAggregation protected final int _version; protected final double _percentile; - public PercentileEstAggregationFunction(ExpressionContext expression, int percentile) { - super(expression); + public PercentileEstAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 0; _percentile = percentile; } - public PercentileEstAggregationFunction(ExpressionContext expression, double percentile) { - super(expression); + public PercentileEstAggregationFunction(ExpressionContext expression, double percentile, + boolean 
nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 1; _percentile = percentile; } @@ -81,24 +82,30 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde if (blockValSet.getValueType() != DataType.BYTES) { long[] longValues = blockValSet.getLongValuesSV(); QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder); - for (int i = 0; i < length; i++) { - quantileDigest.add(longValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + quantileDigest.add(longValues[i]); + } + }); } else { // Serialized QuantileDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - QuantileDigest quantileDigest = aggregationResultHolder.getResult(); - if (quantileDigest != null) { - for (int i = 0; i < length; i++) { - quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i])); + foldNotNull(length, blockValSet, (QuantileDigest) aggregationResultHolder.getResult(), (quantile, from, toEx) -> { + int start; + QuantileDigest quantileDigest; + if (quantile != null) { + start = from; + quantileDigest = quantile; + } else { + start = from + 1; + quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[from]); + aggregationResultHolder.setValue(quantileDigest); } - } else { - quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[0]); - aggregationResultHolder.setValue(quantileDigest); - for (int i = 1; i < length; i++) { + for (int i = start; i < toEx; i++) { quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i])); } - } + return quantileDigest; + }); } } @@ -108,22 +115,26 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol BlockValSet blockValSet = blockValSetMap.get(_expression); if (blockValSet.getValueType() != DataType.BYTES) { long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; 
i++) { - getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]); + } + }); } else { // Serialized QuantileDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]); - int groupKey = groupKeyArray[i]; - QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey); - if (quantileDigest != null) { - quantileDigest.merge(value); - } else { - groupByResultHolder.setValueForKey(groupKey, value); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]); + int groupKey = groupKeyArray[i]; + QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey); + if (quantileDigest != null) { + quantileDigest.merge(value); + } else { + groupByResultHolder.setValueForKey(groupKey, value); + } } - } + }); } } @@ -133,28 +144,32 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult BlockValSet blockValSet = blockValSetMap.get(_expression); if (blockValSet.getValueType() != DataType.BYTES) { long[] longValues = blockValSet.getLongValuesSV(); - for (int i = 0; i < length; i++) { - long value = longValues[i]; - for (int groupKey : groupKeysArray[i]) { - getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + long value = longValues[i]; + for (int groupKey : groupKeysArray[i]) { + getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value); + } } - } + }); } else { // Serialized QuantileDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < 
length; i++) { - QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]); - for (int groupKey : groupKeysArray[i]) { - QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey); - if (quantileDigest != null) { - quantileDigest.merge(value); - } else { - // Create a new QuantileDigest for the group - groupByResultHolder - .setValueForKey(groupKey, ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i])); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]); + for (int groupKey : groupKeysArray[i]) { + QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey); + if (quantileDigest != null) { + quantileDigest.merge(value); + } else { + // Create a new QuantileDigest for the group + groupByResultHolder.setValueForKey(groupKey, + ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i])); + } } } - } + }); } } @@ -202,6 +217,9 @@ public ColumnDataType getFinalResultColumnType() { @Override public Long extractFinalResult(QuantileDigest intermediateResult) { + if (intermediateResult.getCount() == 0 && _nullHandlingEnabled) { + return null; + } return intermediateResult.getQuantile(_percentile / 100.0); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java index c1001f25c7eb..5a861714620e 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java @@ -30,11 +30,11 @@ public class PercentileEstMVAggregationFunction extends PercentileEstAggregationFunction { public 
PercentileEstMVAggregationFunction(ExpressionContext expression, int percentile) { - super(expression, percentile); + super(expression, percentile, false); } public PercentileEstMVAggregationFunction(ExpressionContext expression, double percentile) { - super(expression, percentile); + super(expression, percentile, false); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java index 6d2b3b8697f9..bcf025a80149 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java @@ -61,14 +61,14 @@ *

*/ public class PercentileKLLAggregationFunction - extends BaseSingleInputAggregationFunction> { + extends NullableSingleInputAggregationFunction> { protected static final int DEFAULT_K_VALUE = 200; protected final double _percentile; protected int _kValue; - public PercentileKLLAggregationFunction(List arguments) { - super(arguments.get(0)); + public PercentileKLLAggregationFunction(List arguments, boolean nullHandlingEnabled) { + super(arguments.get(0), nullHandlingEnabled); // Check that there are correct number of arguments int numArguments = arguments.size(); @@ -107,14 +107,18 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde if (valueType == DataType.BYTES) { // Assuming the column contains serialized data sketch KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV()); - for (int i = 0; i < length; i++) { - sketch.merge(deserializedSketches[i]); - } + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + sketch.merge(deserializedSketches[i]); + } + }); } else { double[] values = valueSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - sketch.update(values[i]); - } + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + sketch.update(values[i]); + } + }); } } @@ -127,16 +131,20 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol if (valueType == DataType.BYTES) { // serialized sketch KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV()); - for (int i = 0; i < length; i++) { - KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]); - sketch.merge(deserializedSketches[i]); - } + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]); + 
sketch.merge(deserializedSketches[i]); + } + }); } else { double[] values = valueSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]); - sketch.update(values[i]); - } + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]); + sketch.update(values[i]); + } + }); } } @@ -149,20 +157,24 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult if (valueType == DataType.BYTES) { // serialized sketch KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV()); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey); - sketch.merge(deserializedSketches[i]); + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + for (int groupKey : groupKeysArray[i]) { + KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey); + sketch.merge(deserializedSketches[i]); + } } - } + }); } else { double[] values = valueSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey); - sketch.update(values[i]); + forEachNotNull(length, valueSet, (from, to) -> { + for (int i = from; i < to; i++) { + for (int groupKey : groupKeysArray[i]) { + KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey); + sketch.update(values[i]); + } } - } + }); } } @@ -241,6 +253,9 @@ public String getResultColumnName() { @Override public Comparable extractFinalResult(KllDoublesSketch sketch) { + if (sketch.isEmpty() && _nullHandlingEnabled) { + return null; + } return sketch.getQuantile(_percentile / 100); } } diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java index 4653e9051d38..26af8dea447d 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java @@ -32,7 +32,7 @@ public class PercentileKLLMVAggregationFunction extends PercentileKLLAggregationFunction { public PercentileKLLMVAggregationFunction(List arguments) { - super(arguments); + super(arguments, false); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java index 794a9896a7d4..620763ea7599 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java @@ -30,11 +30,11 @@ public class PercentileMVAggregationFunction extends PercentileAggregationFunction { public PercentileMVAggregationFunction(ExpressionContext expression, int percentile) { - super(expression, percentile); + super(expression, percentile, false); } public PercentileMVAggregationFunction(ExpressionContext expression, double percentile) { - super(expression, percentile); + super(expression, percentile, false); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java index 063359ec9604..04787e7d559b 100644 --- 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java @@ -37,12 +37,14 @@ public class PercentileRawEstAggregationFunction extends BaseSingleInputAggregationFunction { private final PercentileEstAggregationFunction _percentileEstAggregationFunction; - public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, double percentile) { - this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile)); + public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, double percentile, + boolean nullHandlingEnabled) { + this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile, nullHandlingEnabled)); } - public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, int percentile) { - this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile)); + public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, int percentile, + boolean nullHandlingEnabled) { + this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile, nullHandlingEnabled)); } protected PercentileRawEstAggregationFunction(ExpressionContext expression, diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java index 39c2022ff026..7e88cf009d88 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java @@ -28,8 +28,8 @@ public class 
PercentileRawKLLAggregationFunction extends PercentileKLLAggregationFunction { - public PercentileRawKLLAggregationFunction(List arguments) { - super(arguments); + public PercentileRawKLLAggregationFunction(List arguments, boolean nullHandlingEnabled) { + super(arguments, nullHandlingEnabled); } @Override diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java index 99a096c13063..fc618027a5fa 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java @@ -37,17 +37,22 @@ public class PercentileRawTDigestAggregationFunction extends BaseSingleInputAggregationFunction { private final PercentileTDigestAggregationFunction _percentileTDigestAggregationFunction; - public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, int percentile) { - this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile)); + public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, int percentile, + boolean nullHandlingEnabled) { + this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, + nullHandlingEnabled)); } - public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile) { - this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile)); + public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile, + boolean nullHandlingEnabled) { + this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, + nullHandlingEnabled)); } public 
PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile, - int compressionFactor) { - this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, compressionFactor)); + int compressionFactor, boolean nullHandlingEnabled) { + this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, compressionFactor, + nullHandlingEnabled)); } protected PercentileRawTDigestAggregationFunction(ExpressionContext expression, diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java index 92cd5fa09b9d..20d5372ca56f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java @@ -50,15 +50,15 @@ * - compression: Compression for the converted TDigest, 100 by default. 
* Example of third argument: 'threshold=10000;compression=50' */ -public class PercentileSmartTDigestAggregationFunction extends BaseSingleInputAggregationFunction { +public class PercentileSmartTDigestAggregationFunction extends NullableSingleInputAggregationFunction { private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; private final double _percentile; private final int _threshold; private final int _compression; - public PercentileSmartTDigestAggregationFunction(List arguments) { - super(arguments.get(0)); + public PercentileSmartTDigestAggregationFunction(List arguments, boolean nullHandlingEnabled) { + super(arguments.get(0), nullHandlingEnabled); try { _percentile = arguments.get(1).getLiteral().getDoubleValue(); } catch (Exception e) { @@ -128,39 +128,53 @@ private static void validateValueType(BlockValSet blockValSet) { blockValSet.isSingleValue() ? "" : "_MV"); } - private static void aggregateIntoTDigest(int length, AggregationResultHolder aggregationResultHolder, + private void aggregateIntoTDigest(int length, AggregationResultHolder aggregationResultHolder, BlockValSet blockValSet) { TDigest tDigest = aggregationResultHolder.getResult(); if (blockValSet.isSingleValue()) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - tDigest.add(doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + tDigest.add(doubleValues[i]); + } + }); } else { double[][] doubleValues = blockValSet.getDoubleValuesMV(); - for (int i = 0; i < length; i++) { - for (double value : doubleValues[i]) { - tDigest.add(value); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + for (double value : doubleValues[i]) { + tDigest.add(value); + } } - } + }); } } - private void aggregateIntoValueList(int length, AggregationResultHolder aggregationResultHolder, - BlockValSet blockValSet) { + private DoubleArrayList 
getOrCreateList(int length, AggregationResultHolder aggregationResultHolder) { DoubleArrayList valueList = aggregationResultHolder.getResult(); if (valueList == null) { valueList = new DoubleArrayList(length); aggregationResultHolder.setValue(valueList); } + return valueList; + } + + private void aggregateIntoValueList(int length, AggregationResultHolder aggregationResultHolder, + BlockValSet blockValSet) { + DoubleArrayList valueList = getOrCreateList(length, aggregationResultHolder); if (blockValSet.isSingleValue()) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - valueList.addElements(valueList.size(), doubleValues, 0, length); + forEachNotNull(length, blockValSet, (from, toEx) -> + valueList.addElements(valueList.size(), doubleValues, from, toEx - from) + ); } else { double[][] doubleValues = blockValSet.getDoubleValuesMV(); - for (int i = 0; i < length; i++) { - valueList.addElements(valueList.size(), doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, toEx) -> { + for (int i = from; i < toEx; i++) { + valueList.addElements(valueList.size(), doubleValues[i]); + } + } + ); } if (valueList.size() > _threshold) { aggregationResultHolder.setValue(convertValueListToTDigest(valueList)); @@ -183,16 +197,20 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol validateValueType(blockValSet); if (blockValSet.isSingleValue()) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]); - valueList.add(doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]); + valueList.add(doubleValues[i]); + } + }); } else { double[][] doubleValues = blockValSet.getDoubleValuesMV(); - for (int i = 0; i < length; i++) { - DoubleArrayList valueList = 
getValueList(groupByResultHolder, groupKeyArray[i]); - valueList.addElements(valueList.size(), doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]); + valueList.addElements(valueList.size(), doubleValues[i]); + } + }); } } @@ -212,19 +230,23 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult validateValueType(blockValSet); if (blockValSet.isSingleValue()) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - getValueList(groupByResultHolder, groupKey).add(doubleValues[i]); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + for (int groupKey : groupKeysArray[i]) { + getValueList(groupByResultHolder, groupKey).add(doubleValues[i]); + } } - } + }); } else { double[][] doubleValues = blockValSet.getDoubleValuesMV(); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey); - valueList.addElements(valueList.size(), doubleValues[i]); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + for (int groupKey : groupKeysArray[i]) { + DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey); + valueList.addElements(valueList.size(), doubleValues[i]); + } } - } + }); } } @@ -285,7 +307,11 @@ public Double extractFinalResult(Object intermediateResult) { DoubleArrayList valueList = (DoubleArrayList) intermediateResult; int size = valueList.size(); if (size == 0) { - return DEFAULT_FINAL_RESULT; + if (_nullHandlingEnabled) { + return null; + } else { + return DEFAULT_FINAL_RESULT; + } } else { double[] values = valueList.elements(); Arrays.sort(values, 0, size); diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java index d4224739c6ee..c831e52d2248 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestAggregationFunction.java @@ -39,7 +39,7 @@ * extra handling for two argument PERCENTILE functions to assess if v0 or v1. This can be revisited later if the * need arises */ -public class PercentileTDigestAggregationFunction extends BaseSingleInputAggregationFunction { +public class PercentileTDigestAggregationFunction extends NullableSingleInputAggregationFunction { public static final int DEFAULT_TDIGEST_COMPRESSION = 100; // version 0 functions specified in the of form PERCENTILETDIGEST<2-digits>(column). 
Uses default compression of 100 @@ -48,23 +48,25 @@ public class PercentileTDigestAggregationFunction extends BaseSingleInputAggrega protected final double _percentile; protected final int _compressionFactor; - public PercentileTDigestAggregationFunction(ExpressionContext expression, int percentile) { - super(expression); + public PercentileTDigestAggregationFunction(ExpressionContext expression, int percentile, + boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 0; _percentile = percentile; _compressionFactor = DEFAULT_TDIGEST_COMPRESSION; } - public PercentileTDigestAggregationFunction(ExpressionContext expression, double percentile) { - super(expression); + public PercentileTDigestAggregationFunction(ExpressionContext expression, double percentile, + boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 1; _percentile = percentile; _compressionFactor = DEFAULT_TDIGEST_COMPRESSION; } public PercentileTDigestAggregationFunction(ExpressionContext expression, double percentile, - int compressionFactor) { - super(expression); + int compressionFactor, boolean nullHandlingEnabled) { + super(expression, nullHandlingEnabled); _version = 1; _percentile = percentile; _compressionFactor = compressionFactor; @@ -104,24 +106,28 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde if (blockValSet.getValueType() != DataType.BYTES) { double[] doubleValues = blockValSet.getDoubleValuesSV(); TDigest tDigest = getDefaultTDigest(aggregationResultHolder, _compressionFactor); - for (int i = 0; i < length; i++) { - tDigest.add(doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + tDigest.add(doubleValues[i]); + } + }); } else { // Serialized TDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - TDigest tDigest = aggregationResultHolder.getResult(); - if (tDigest != null) { - for (int i = 0; i < length; i++) { - 
tDigest.add(ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); - } - } else { - tDigest = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[0]); - aggregationResultHolder.setValue(tDigest); - for (int i = 1; i < length; i++) { - tDigest.add(ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); + foldNotNull(length, blockValSet, (TDigest) aggregationResultHolder.getResult(), (tDigest, from, toEx) -> { + if (tDigest != null) { + for (int i = from; i < toEx; i++) { + tDigest.add(ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); + } + } else { + tDigest = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[from]); + aggregationResultHolder.setValue(tDigest); + for (int i = from + 1; i < toEx; i++) { + tDigest.add(ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); + } } - } + return tDigest; + }); } } @@ -131,22 +137,26 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol BlockValSet blockValSet = blockValSetMap.get(_expression); if (blockValSet.getValueType() != DataType.BYTES) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - getDefaultTDigest(groupByResultHolder, groupKeyArray[i], _compressionFactor).add(doubleValues[i]); - } + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + getDefaultTDigest(groupByResultHolder, groupKeyArray[i], _compressionFactor).add(doubleValues[i]); + } + }); } else { // Serialized TDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - TDigest value = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]); - int groupKey = groupKeyArray[i]; - TDigest tDigest = groupByResultHolder.getResult(groupKey); - if (tDigest != null) { - tDigest.add(value); - } else { - groupByResultHolder.setValueForKey(groupKey, value); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + TDigest value = 
ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]); + int groupKey = groupKeyArray[i]; + TDigest tDigest = groupByResultHolder.getResult(groupKey); + if (tDigest != null) { + tDigest.add(value); + } else { + groupByResultHolder.setValueForKey(groupKey, value); + } } - } + }); } } @@ -156,27 +166,31 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult BlockValSet blockValSet = blockValSetMap.get(_expression); if (blockValSet.getValueType() != DataType.BYTES) { double[] doubleValues = blockValSet.getDoubleValuesSV(); - for (int i = 0; i < length; i++) { - double value = doubleValues[i]; - for (int groupKey : groupKeysArray[i]) { - getDefaultTDigest(groupByResultHolder, groupKey, _compressionFactor).add(value); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + double value = doubleValues[i]; + for (int groupKey : groupKeysArray[i]) { + getDefaultTDigest(groupByResultHolder, groupKey, _compressionFactor).add(value); + } } - } + }); } else { // Serialized QuantileDigest byte[][] bytesValues = blockValSet.getBytesValuesSV(); - for (int i = 0; i < length; i++) { - TDigest value = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]); - for (int groupKey : groupKeysArray[i]) { - TDigest tDigest = groupByResultHolder.getResult(groupKey); - if (tDigest != null) { - tDigest.add(value); - } else { - // Create a new TDigest for the group - groupByResultHolder.setValueForKey(groupKey, ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); + forEachNotNull(length, blockValSet, (from, to) -> { + for (int i = from; i < to; i++) { + TDigest value = ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i]); + for (int groupKey : groupKeysArray[i]) { + TDigest tDigest = groupByResultHolder.getResult(groupKey); + if (tDigest != null) { + tDigest.add(value); + } else { + // Create a new TDigest for the group + groupByResultHolder.setValueForKey(groupKey, 
ObjectSerDeUtils.TDIGEST_SER_DE.deserialize(bytesValues[i])); + } } } - } + }); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java index 571f2ae9126a..a6b7884e6e87 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileTDigestMVAggregationFunction.java @@ -30,16 +30,16 @@ public class PercentileTDigestMVAggregationFunction extends PercentileTDigestAggregationFunction { public PercentileTDigestMVAggregationFunction(ExpressionContext expression, int percentile) { - super(expression, percentile); + super(expression, percentile, false); } public PercentileTDigestMVAggregationFunction(ExpressionContext expression, double percentile) { - super(expression, percentile); + super(expression, percentile, false); } public PercentileTDigestMVAggregationFunction(ExpressionContext expression, double percentile, int compressionFactor) { - super(expression, percentile, compressionFactor); + super(expression, percentile, compressionFactor, false); } @Override diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractPercentileAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractPercentileAggregationFunctionTest.java new file mode 100644 index 000000000000..fe9cc09f26a9 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractPercentileAggregationFunctionTest.java @@ -0,0 +1,333 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.queries.FluentQueryTest; +import org.apache.pinot.spi.data.FieldSpec; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + + +public abstract class AbstractPercentileAggregationFunctionTest extends AbstractAggregationFunctionTest { + + @DataProvider(name = "scenarios") + Object[] scenarios() { + return new Object[] { + new Scenario(FieldSpec.DataType.INT), + new Scenario(FieldSpec.DataType.LONG), + new Scenario(FieldSpec.DataType.FLOAT), + new Scenario(FieldSpec.DataType.DOUBLE), + }; + } + + public abstract String callStr(String column, int percent); + + public String getFinalResultColumnType() { + return "DOUBLE"; + } + + public class Scenario { + private final FieldSpec.DataType _dataType; + + public Scenario(FieldSpec.DataType dataType) { + _dataType = dataType; + } + + public FieldSpec.DataType getDataType() { + return _dataType; + } + + public FluentQueryTest.DeclaringTable getDeclaringTable(boolean nullHandlingEnabled) { + return givenSingleNullableFieldTable(_dataType, nullHandlingEnabled); + } + + @Override + public String toString() { + return "Scenario{" + "dt=" + _dataType + '}'; + } + } + + FluentQueryTest.TableWithSegments 
withDefaultData(Scenario scenario, boolean nullHandlingEnabled) { + return scenario.getDeclaringTable(nullHandlingEnabled) + .onFirstInstance("myField", + "null", + "0", + "null", + "1", + "null", + "2", + "null", + "3", + "null", + "4", + "null" + ).andSegment("myField", + "null", + "5", + "null", + "6", + "null", + "7", + "null", + "8", + "null", + "9", + "null" + ); + } + + String minValue(FieldSpec.DataType dataType) { + switch (dataType) { + case INT: return "-2.147483648E9"; + case LONG: return "-9.223372036854776E18"; + case FLOAT: return "-Infinity"; + case DOUBLE: return "-Infinity"; + default: + throw new IllegalArgumentException("Unexpected type " + dataType); + } + } + + String expectedAggrWithoutNull10(Scenario scenario) { + return minValue(scenario._dataType); + } + + String expectedAggrWithoutNull15(Scenario scenario) { + return minValue(scenario._dataType); + } + + String expectedAggrWithoutNull30(Scenario scenario) { + return minValue(scenario._dataType); + } + + String expectedAggrWithoutNull35(Scenario scenario) { + return minValue(scenario._dataType); + } + + String expectedAggrWithoutNull50(Scenario scenario) { + return minValue(scenario._dataType); + } + + String expectedAggrWithoutNull55(Scenario scenario) { + return "0"; + } + + String expectedAggrWithoutNull70(Scenario scenario) { + return "3"; + } + + String expectedAggrWithoutNull75(Scenario scenario) { + return "4"; + } + + String expectedAggrWithoutNull90(Scenario scenario) { + return "7"; + } + + String expectedAggrWithoutNull100(Scenario scenario) { + return "9"; + } + + @Test(dataProvider = "scenarios") + void aggrWithoutNull(Scenario scenario) { + + FluentQueryTest.TableWithSegments instance = withDefaultData(scenario, false); + + instance + .whenQuery("select " + callStr("myField", 10) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull10(scenario)); + + instance + .whenQuery("select " + callStr("myField", 15) + " from testTable") + 
.thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull15(scenario)); + + instance + .whenQuery("select " + callStr("myField", 30) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull30(scenario)); + instance + .whenQuery("select " + callStr("myField", 35) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull35(scenario)); + + instance + .whenQuery("select " + callStr("myField", 50) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull50(scenario)); + instance + .whenQuery("select " + callStr("myField", 55) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull55(scenario)); + + instance + .whenQuery("select " + callStr("myField", 70) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull70(scenario)); + + instance + .whenQuery("select " + callStr("myField", 75) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull75(scenario)); + + instance + .whenQuery("select " + callStr("myField", 90) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull90(scenario)); + + instance + .whenQuery("select " + callStr("myField", 100) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithoutNull100(scenario)); + } + + String expectedAggrWithNull10(Scenario scenario) { + return "1"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull10(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 10) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull10(scenario)); + } + + String expectedAggrWithNull15(Scenario scenario) { + return "1"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull15(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 15) + " from testTable") + 
.thenResultIs(getFinalResultColumnType(), expectedAggrWithNull15(scenario)); + } + + String expectedAggrWithNull30(Scenario scenario) { + return "3"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull30(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 30) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull30(scenario)); + } + + String expectedAggrWithNull35(Scenario scenario) { + return "3"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull35(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 35) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull35(scenario)); + } + + String expectedAggrWithNull50(Scenario scenario) { + return "5"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull50(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 50) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull50(scenario)); + } + + String expectedAggrWithNull55(Scenario scenario) { + return "5"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull55(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 55) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull55(scenario)); + } + + String expectedAggrWithNull70(Scenario scenario) { + return "7"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull70(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 70) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull70(scenario)); + } + + String expectedAggrWithNull75(Scenario scenario) { + return "7"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull75(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + 
callStr("myField", 75) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull75(scenario)); + } + + String expectedAggrWithNull100(Scenario scenario) { + return "9"; + } + + @Test(dataProvider = "scenarios") + void aggrWithNull100(Scenario scenario) { + withDefaultData(scenario, true) + .whenQuery("select " + callStr("myField", 100) + " from testTable") + .thenResultIs(getFinalResultColumnType(), expectedAggrWithNull100(scenario)); + } + + @Test(dataProvider = "scenarios") + void aggrSvWithoutNull(Scenario scenario) { + scenario.getDeclaringTable(false) + .onFirstInstance("myField", + "null", + "1", + "null" + ).andSegment("myField", + "9" + ).andSegment("myField", + "null", + "null", + "null" + ).whenQuery("select $segmentName, " + callStr("myField", 50) + " from testTable " + + "group by $segmentName order by $segmentName") + .thenResultIs("STRING | " + getFinalResultColumnType(), + "testTable_0 | " + minValue(scenario._dataType), + "testTable_1 | 9", + "testTable_2 | " + minValue(scenario._dataType) + ); + } + + @Test(dataProvider = "scenarios") + void aggrSvWithNull(Scenario scenario) { + scenario.getDeclaringTable(true) + .onFirstInstance("myField", + "null", + "1", + "null" + ).andSegment("myField", + "9" + ).andSegment("myField", + "null", + "null", + "null" + ).whenQuery("select $segmentName, " + callStr("myField", 50) + " from testTable " + + "group by $segmentName order by $segmentName") + .thenResultIs("STRING | " + getFinalResultColumnType(), + "testTable_0 | 1", + "testTable_1 | 9", + "testTable_2 | null" + ); + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunctionTest.java new file mode 100644 index 000000000000..3c2ecdde0112 --- /dev/null +++ 
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunctionTest.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.pinot.core.query.aggregation.function; + +public class PercentileAggregationFunctionTest extends AbstractPercentileAggregationFunctionTest { + @Override + public String callStr(String column, int percent) { + return "PERCENTILE(" + column + ", " + percent + ")"; + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunctionTest.java new file mode 100644 index 000000000000..4dda1614b7c8 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunctionTest.java @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + +import org.apache.pinot.spi.data.FieldSpec; + + +public class PercentileEstAggregationFunctionTest extends AbstractPercentileAggregationFunctionTest { + @Override + public String callStr(String column, int percent) { + return "PERCENTILEEST(" + column + ", " + percent + ")"; + } + + @Override + public String getFinalResultColumnType() { + return "LONG"; + } + + String minValue(FieldSpec.DataType dataType) { + switch (dataType) { + case INT: return "-2147483648"; + case LONG: return "-9223372036854775808"; + case FLOAT: return "-9223372036854775808"; + case DOUBLE: return "-9223372036854775808"; + default: + throw new IllegalArgumentException("Unexpected type " + dataType); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunctionTest.java new file mode 100644 index 000000000000..1eb6c991c22f --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunctionTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + + +public class PercentileKLLAggregationFunctionTest extends AbstractPercentileAggregationFunctionTest { + @Override + public String callStr(String column, int percent) { + return "PERCENTILEKLL(" + column + ", " + percent + ")"; + } + + @Override + String expectedAggrWithNull10(Scenario scenario) { + return "0"; + } + + @Override + String expectedAggrWithNull30(Scenario scenario) { + return "2"; + } + + @Override + String expectedAggrWithNull50(Scenario scenario) { + return "4"; + } + + @Override + String expectedAggrWithNull70(Scenario scenario) { + return "6"; + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunctionTest.java new file mode 100644 index 000000000000..b1eb471c704e --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunctionTest.java @@ -0,0 +1,87 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + + +public class PercentileSmartTDigestAggregationFunctionTest { + + public static class WithHighThreshold extends AbstractPercentileAggregationFunctionTest { + @Override + public String callStr(String column, int percent) { + return "PERCENTILESMARTTDIGEST(" + column + ", " + percent + ", 'THRESHOLD=10000')"; + } + } + + public static class WithSmallThreshold extends AbstractPercentileAggregationFunctionTest { + @Override + public String callStr(String column, int percent) { + return "PERCENTILESMARTTDIGEST(" + column + ", " + percent + ", 'THRESHOLD=1')"; + } + + @Override + String expectedAggrWithNull10(Scenario scenario) { + return "0.5"; + } + + @Override + String expectedAggrWithNull30(Scenario scenario) { + return "2.5"; + } + + @Override + String expectedAggrWithNull50(Scenario scenario) { + return "4.5"; + } + + @Override + String expectedAggrWithNull70(Scenario scenario) { + return "6.5"; + } + + @Override + String expectedAggrWithoutNull55(Scenario scenario) { + switch (scenario.getDataType()) { + case INT: + return "-6.442450943999939E8"; + case LONG: + return "-2.7670116110564065E18"; + case FLOAT: + case DOUBLE: + return "-Infinity"; + default: + throw new IllegalArgumentException("Unsupported 
datatype " + scenario.getDataType()); + } + } + + @Override + String expectedAggrWithoutNull75(Scenario scenario) { + return "4.0"; + } + + @Override + String expectedAggrWithoutNull90(Scenario scenario) { + return "7.100000000000001"; + } + + @Override + String expectedAggrWithoutNull100(Scenario scenario) { + return super.expectedAggrWithoutNull100(scenario); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java index ba6d22c429c4..8bd93cd42e3c 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/FluentQueryTest.java @@ -112,7 +112,7 @@ public OnFirstInstance onFirstInstance(Object[]... content) { } } - static class TableWithSegments { + public static class TableWithSegments { protected final TableConfig _tableConfig; protected final Schema _schema; protected final File _indexDir;