From c15c3912cb49a77476cbe84c113f68b201318c68 Mon Sep 17 00:00:00 2001
From: Gonzalo Ortiz Jaureguizar
Date: Mon, 1 Apr 2024 10:45:05 +0200
Subject: [PATCH] Percentile operations supporting null (#12271)
* new test framework candidate
* Improved test system
* Improve framework to be able to specify segments as strings
* fix headers
* Improve assertions when there are nulls
* Improve error text
* Improvements in the framework
* Add a base class single input aggregation operations can extend to support null handling
* Fix issue in NullableSingleInputAggregationFunction.forEachNotNullInt
* Improve error message in NullEnabledQueriesTest
* Add new schema family
* Rename test schemas and table config
* Split AllNullQueriesTest into on test per query
* Revert change in AllNullQueriesTest that belongs to mode-null-support branch
* Add tests
* Fix issue in bytes in aggregation case
* Update to the new framework
* Fix some tests
* rollback a code style change
---
.../function/AggregationFunctionFactory.java | 37 +-
...ullableSingleInputAggregationFunction.java | 9 +
.../PercentileAggregationFunction.java | 57 +--
.../PercentileEstAggregationFunction.java | 108 +++---
.../PercentileEstMVAggregationFunction.java | 4 +-
.../PercentileKLLAggregationFunction.java | 69 ++--
.../PercentileKLLMVAggregationFunction.java | 2 +-
.../PercentileMVAggregationFunction.java | 4 +-
.../PercentileRawEstAggregationFunction.java | 10 +-
.../PercentileRawKLLAggregationFunction.java | 4 +-
...rcentileRawTDigestAggregationFunction.java | 17 +-
...entileSmartTDigestAggregationFunction.java | 96 +++--
.../PercentileTDigestAggregationFunction.java | 110 +++---
...ercentileTDigestMVAggregationFunction.java | 6 +-
...ractPercentileAggregationFunctionTest.java | 333 ++++++++++++++++++
.../PercentileAggregationFunctionTest.java | 27 ++
.../PercentileEstAggregationFunctionTest.java | 45 +++
.../PercentileKLLAggregationFunctionTest.java | 47 +++
...leSmartTDigestAggregationFunctionTest.java | 87 +++++
.../apache/pinot/queries/FluentQueryTest.java | 2 +-
20 files changed, 861 insertions(+), 213 deletions(-)
create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AbstractPercentileAggregationFunctionTest.java
create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunctionTest.java
create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunctionTest.java
create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunctionTest.java
create mode 100644 pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunctionTest.java
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index eeed8608a4ec..a82d421ebc96 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -61,16 +61,16 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
if (remainingFunctionName.equals("SMARTTDIGEST")) {
- return new PercentileSmartTDigestAggregationFunction(arguments);
+ return new PercentileSmartTDigestAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("KLL")) {
- return new PercentileKLLAggregationFunction(arguments);
+ return new PercentileKLLAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("KLLMV")) {
return new PercentileKLLMVAggregationFunction(arguments);
}
if (remainingFunctionName.equals("RAWKLL")) {
- return new PercentileRawKLLAggregationFunction(arguments);
+ return new PercentileRawKLLAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWKLLMV")) {
return new PercentileRawKLLMVAggregationFunction(arguments);
@@ -80,23 +80,28 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// NOTE: This convention is deprecated. DO NOT add new functions here
if (remainingFunctionName.matches("\\d+")) {
// Percentile
- return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName));
+ return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName),
+ nullHandlingEnabled);
} else if (remainingFunctionName.matches("EST\\d+")) {
// PercentileEst
String percentileString = remainingFunctionName.substring(3);
- return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
+ return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
+ nullHandlingEnabled);
} else if (remainingFunctionName.matches("RAWEST\\d+")) {
// PercentileRawEst
String percentileString = remainingFunctionName.substring(6);
- return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
+ return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
+ nullHandlingEnabled);
} else if (remainingFunctionName.matches("TDIGEST\\d+")) {
// PercentileTDigest
String percentileString = remainingFunctionName.substring(7);
- return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
+ return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
+ nullHandlingEnabled);
} else if (remainingFunctionName.matches("RAWTDIGEST\\d+")) {
// PercentileRawTDigest
String percentileString = remainingFunctionName.substring(10);
- return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
+ return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
+ nullHandlingEnabled);
} else if (remainingFunctionName.matches("\\d+MV")) {
// PercentileMV
String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2);
@@ -125,23 +130,23 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
Preconditions.checkArgument(percentile >= 0 && percentile <= 100, "Invalid percentile: %s", percentile);
if (remainingFunctionName.isEmpty()) {
// Percentile
- return new PercentileAggregationFunction(firstArgument, percentile);
+ return new PercentileAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("EST")) {
// PercentileEst
- return new PercentileEstAggregationFunction(firstArgument, percentile);
+ return new PercentileEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWEST")) {
// PercentileRawEst
- return new PercentileRawEstAggregationFunction(firstArgument, percentile);
+ return new PercentileRawEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("TDIGEST")) {
// PercentileTDigest
- return new PercentileTDigestAggregationFunction(firstArgument, percentile);
+ return new PercentileTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWTDIGEST")) {
// PercentileRawTDigest
- return new PercentileRawTDigestAggregationFunction(firstArgument, percentile);
+ return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("MV")) {
// PercentileMV
@@ -175,11 +180,13 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
Preconditions.checkArgument(compressionFactor >= 0, "Invalid compressionFactor: %d", compressionFactor);
if (remainingFunctionName.equals("TDIGEST")) {
// PercentileTDigest
- return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor);
+ return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor,
+ nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWTDIGEST")) {
// PercentileRawTDigest
- return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor);
+ return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor,
+ nullHandlingEnabled);
}
if (remainingFunctionName.equals("TDIGESTMV")) {
// PercentileTDigestMV
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
index 78f1ae12696c..907f0139d2a9 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/NullableSingleInputAggregationFunction.java
@@ -103,6 +103,15 @@ public void forEachNotNull(int length, IntIterator nullIndexIterator, BatchConsu
}
}
+ /**
+ * Folds over the non-null ranges of the blockValSet using the reducer.
+ * @param initialAcum the initial value of the accumulator
+ * @param The type of the accumulator
+ */
+ public A foldNotNull(int length, BlockValSet blockValSet, A initialAcum, Reducer reducer) {
+ return foldNotNull(length, blockValSet.getNullBitmap(), initialAcum, reducer);
+ }
+
/**
* Folds over the non-null ranges of the blockValSet using the reducer.
* @param initialAcum the initial value of the accumulator
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
index 5d227caeadaa..c9c71744d267 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileAggregationFunction.java
@@ -31,7 +31,7 @@
import org.apache.pinot.segment.spi.AggregationFunctionType;
-public class PercentileAggregationFunction extends BaseSingleInputAggregationFunction {
+public class PercentileAggregationFunction extends NullableSingleInputAggregationFunction {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
//version 0 functions specified in the of form PERCENTILE<2-digits>(column)
@@ -39,14 +39,14 @@ public class PercentileAggregationFunction extends BaseSingleInputAggregationFun
protected final int _version;
protected final double _percentile;
- public PercentileAggregationFunction(ExpressionContext expression, int percentile) {
- super(expression);
+ public PercentileAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) {
+ super(expression, nullHandlingEnabled);
_version = 0;
_percentile = percentile;
}
- public PercentileAggregationFunction(ExpressionContext expression, double percentile) {
- super(expression);
+ public PercentileAggregationFunction(ExpressionContext expression, double percentile, boolean nullHandlingEnabled) {
+ super(expression, nullHandlingEnabled);
_version = 1;
_percentile = percentile;
}
@@ -77,33 +77,42 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map blockValSetMap) {
DoubleArrayList valueList = getValueList(aggregationResultHolder);
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- valueList.add(valueArray[i]);
- }
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueList.add(valueArray[i]);
+ }
+ });
}
@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
- valueList.add(valueArray[i]);
- }
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
+ valueList.add(valueArray[i]);
+ }
+ });
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map blockValSetMap) {
- double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- double value = valueArray[i];
- for (int groupKey : groupKeysArray[i]) {
- DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey);
- valueList.add(value);
+ BlockValSet blockValSet = blockValSetMap.get(_expression);
+ double[] valueArray = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ double value = valueArray[i];
+ for (int groupKey : groupKeysArray[i]) {
+ DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey);
+ valueList.add(value);
+ }
}
- }
+ });
}
@Override
@@ -146,7 +155,11 @@ public ColumnDataType getFinalResultColumnType() {
public Double extractFinalResult(DoubleArrayList intermediateResult) {
int size = intermediateResult.size();
if (size == 0) {
- return DEFAULT_FINAL_RESULT;
+ if (_nullHandlingEnabled) {
+ return null;
+ } else {
+ return DEFAULT_FINAL_RESULT;
+ }
} else {
double[] values = intermediateResult.elements();
Arrays.sort(values, 0, size);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
index d055e4650541..e67a3f7d6500 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstAggregationFunction.java
@@ -32,7 +32,7 @@
import org.apache.pinot.spi.data.FieldSpec.DataType;
-public class PercentileEstAggregationFunction extends BaseSingleInputAggregationFunction {
+public class PercentileEstAggregationFunction extends NullableSingleInputAggregationFunction {
public static final double DEFAULT_MAX_ERROR = 0.05;
//version 0 functions specified in the of form PERCENTILEEST<2-digits>(column)
@@ -40,14 +40,15 @@ public class PercentileEstAggregationFunction extends BaseSingleInputAggregation
protected final int _version;
protected final double _percentile;
- public PercentileEstAggregationFunction(ExpressionContext expression, int percentile) {
- super(expression);
+ public PercentileEstAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) {
+ super(expression, nullHandlingEnabled);
_version = 0;
_percentile = percentile;
}
- public PercentileEstAggregationFunction(ExpressionContext expression, double percentile) {
- super(expression);
+ public PercentileEstAggregationFunction(ExpressionContext expression, double percentile,
+ boolean nullHandlingEnabled) {
+ super(expression, nullHandlingEnabled);
_version = 1;
_percentile = percentile;
}
@@ -81,24 +82,30 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
- for (int i = 0; i < length; i++) {
- quantileDigest.add(longValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ quantileDigest.add(longValues[i]);
+ }
+ });
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
- QuantileDigest quantileDigest = aggregationResultHolder.getResult();
- if (quantileDigest != null) {
- for (int i = 0; i < length; i++) {
- quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
+ foldNotNull(length, blockValSet, (QuantileDigest) aggregationResultHolder.getResult(), (quantile, from, toEx) -> {
+ int start;
+ QuantileDigest quantileDigest;
+ if (quantile != null) {
+ start = from;
+ quantileDigest = quantile;
+ } else {
+ start = from + 1;
+ quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[from]);
+ aggregationResultHolder.setValue(quantileDigest);
}
- } else {
- quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[0]);
- aggregationResultHolder.setValue(quantileDigest);
- for (int i = 1; i < length; i++) {
+ for (int i = start; i < toEx; i++) {
quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
}
- }
+ return quantileDigest;
+ });
}
}
@@ -108,22 +115,26 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
- for (int i = 0; i < length; i++) {
- getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]);
- }
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]);
+ }
+ });
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
- for (int i = 0; i < length; i++) {
- QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
- int groupKey = groupKeyArray[i];
- QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
- if (quantileDigest != null) {
- quantileDigest.merge(value);
- } else {
- groupByResultHolder.setValueForKey(groupKey, value);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
+ int groupKey = groupKeyArray[i];
+ QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
+ if (quantileDigest != null) {
+ quantileDigest.merge(value);
+ } else {
+ groupByResultHolder.setValueForKey(groupKey, value);
+ }
}
- }
+ });
}
}
@@ -133,28 +144,32 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
- for (int i = 0; i < length; i++) {
- long value = longValues[i];
- for (int groupKey : groupKeysArray[i]) {
- getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value);
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ long value = longValues[i];
+ for (int groupKey : groupKeysArray[i]) {
+ getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value);
+ }
}
- }
+ });
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
- for (int i = 0; i < length; i++) {
- QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
- for (int groupKey : groupKeysArray[i]) {
- QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
- if (quantileDigest != null) {
- quantileDigest.merge(value);
- } else {
- // Create a new QuantileDigest for the group
- groupByResultHolder
- .setValueForKey(groupKey, ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
+ for (int groupKey : groupKeysArray[i]) {
+ QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
+ if (quantileDigest != null) {
+ quantileDigest.merge(value);
+ } else {
+ // Create a new QuantileDigest for the group
+ groupByResultHolder.setValueForKey(groupKey,
+ ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
+ }
}
}
- }
+ });
}
}
@@ -202,6 +217,9 @@ public ColumnDataType getFinalResultColumnType() {
@Override
public Long extractFinalResult(QuantileDigest intermediateResult) {
+ if (intermediateResult.getCount() == 0 && _nullHandlingEnabled) {
+ return null;
+ }
return intermediateResult.getQuantile(_percentile / 100.0);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
index c1001f25c7eb..5a861714620e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileEstMVAggregationFunction.java
@@ -30,11 +30,11 @@
public class PercentileEstMVAggregationFunction extends PercentileEstAggregationFunction {
public PercentileEstMVAggregationFunction(ExpressionContext expression, int percentile) {
- super(expression, percentile);
+ super(expression, percentile, false);
}
public PercentileEstMVAggregationFunction(ExpressionContext expression, double percentile) {
- super(expression, percentile);
+ super(expression, percentile, false);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java
index 6d2b3b8697f9..bcf025a80149 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLAggregationFunction.java
@@ -61,14 +61,14 @@
*
*/
public class PercentileKLLAggregationFunction
- extends BaseSingleInputAggregationFunction> {
+ extends NullableSingleInputAggregationFunction> {
protected static final int DEFAULT_K_VALUE = 200;
protected final double _percentile;
protected int _kValue;
- public PercentileKLLAggregationFunction(List arguments) {
- super(arguments.get(0));
+ public PercentileKLLAggregationFunction(List arguments, boolean nullHandlingEnabled) {
+ super(arguments.get(0), nullHandlingEnabled);
// Check that there are correct number of arguments
int numArguments = arguments.size();
@@ -107,14 +107,18 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
if (valueType == DataType.BYTES) {
// Assuming the column contains serialized data sketch
KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
- for (int i = 0; i < length; i++) {
- sketch.merge(deserializedSketches[i]);
- }
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ sketch.merge(deserializedSketches[i]);
+ }
+ });
} else {
double[] values = valueSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- sketch.update(values[i]);
- }
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ sketch.update(values[i]);
+ }
+ });
}
}
@@ -127,16 +131,20 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol
if (valueType == DataType.BYTES) {
// serialized sketch
KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
- for (int i = 0; i < length; i++) {
- KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
- sketch.merge(deserializedSketches[i]);
- }
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
+ sketch.merge(deserializedSketches[i]);
+ }
+ });
} else {
double[] values = valueSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
- sketch.update(values[i]);
- }
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKeyArray[i]);
+ sketch.update(values[i]);
+ }
+ });
}
}
@@ -149,20 +157,24 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
if (valueType == DataType.BYTES) {
// serialized sketch
KllDoublesSketch[] deserializedSketches = deserializeSketches(blockValSetMap.get(_expression).getBytesValuesSV());
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
- sketch.merge(deserializedSketches[i]);
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
+ sketch.merge(deserializedSketches[i]);
+ }
}
- }
+ });
} else {
double[] values = valueSet.getDoubleValuesSV();
- for (int i = 0; i < length; i++) {
- for (int groupKey : groupKeysArray[i]) {
- KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
- sketch.update(values[i]);
+ forEachNotNull(length, valueSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ KllDoublesSketch sketch = getOrCreateSketch(groupByResultHolder, groupKey);
+ sketch.update(values[i]);
+ }
}
- }
+ });
}
}
@@ -241,6 +253,9 @@ public String getResultColumnName() {
@Override
public Comparable> extractFinalResult(KllDoublesSketch sketch) {
+ if (sketch.isEmpty() && _nullHandlingEnabled) {
+ return null;
+ }
return sketch.getQuantile(_percentile / 100);
}
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java
index 4653e9051d38..26af8dea447d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileKLLMVAggregationFunction.java
@@ -32,7 +32,7 @@
public class PercentileKLLMVAggregationFunction extends PercentileKLLAggregationFunction {
public PercentileKLLMVAggregationFunction(List arguments) {
- super(arguments);
+ super(arguments, false);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
index 794a9896a7d4..620763ea7599 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileMVAggregationFunction.java
@@ -30,11 +30,11 @@
public class PercentileMVAggregationFunction extends PercentileAggregationFunction {
public PercentileMVAggregationFunction(ExpressionContext expression, int percentile) {
- super(expression, percentile);
+ super(expression, percentile, false);
}
public PercentileMVAggregationFunction(ExpressionContext expression, double percentile) {
- super(expression, percentile);
+ super(expression, percentile, false);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java
index 063359ec9604..04787e7d559b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawEstAggregationFunction.java
@@ -37,12 +37,14 @@ public class PercentileRawEstAggregationFunction
extends BaseSingleInputAggregationFunction {
private final PercentileEstAggregationFunction _percentileEstAggregationFunction;
- public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, double percentile) {
- this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile));
+ public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, double percentile,
+ boolean nullHandlingEnabled) {
+ this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile, nullHandlingEnabled));
}
- public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, int percentile) {
- this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile));
+ public PercentileRawEstAggregationFunction(ExpressionContext expressionContext, int percentile,
+ boolean nullHandlingEnabled) {
+ this(expressionContext, new PercentileEstAggregationFunction(expressionContext, percentile, nullHandlingEnabled));
}
protected PercentileRawEstAggregationFunction(ExpressionContext expression,
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java
index 39c2022ff026..7e88cf009d88 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawKLLAggregationFunction.java
@@ -28,8 +28,8 @@
public class PercentileRawKLLAggregationFunction extends PercentileKLLAggregationFunction {
- public PercentileRawKLLAggregationFunction(List arguments) {
- super(arguments);
+ public PercentileRawKLLAggregationFunction(List arguments, boolean nullHandlingEnabled) {
+ super(arguments, nullHandlingEnabled);
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java
index 99a096c13063..fc618027a5fa 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileRawTDigestAggregationFunction.java
@@ -37,17 +37,22 @@ public class PercentileRawTDigestAggregationFunction
extends BaseSingleInputAggregationFunction {
private final PercentileTDigestAggregationFunction _percentileTDigestAggregationFunction;
- public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, int percentile) {
- this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile));
+ public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, int percentile,
+ boolean nullHandlingEnabled) {
+ this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile,
+ nullHandlingEnabled));
}
- public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile) {
- this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile));
+ public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile,
+ boolean nullHandlingEnabled) {
+ this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile,
+ nullHandlingEnabled));
}
public PercentileRawTDigestAggregationFunction(ExpressionContext expressionContext, double percentile,
- int compressionFactor) {
- this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, compressionFactor));
+ int compressionFactor, boolean nullHandlingEnabled) {
+ this(expressionContext, new PercentileTDigestAggregationFunction(expressionContext, percentile, compressionFactor,
+ nullHandlingEnabled));
}
protected PercentileRawTDigestAggregationFunction(ExpressionContext expression,
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java
index 92cd5fa09b9d..20d5372ca56f 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/PercentileSmartTDigestAggregationFunction.java
@@ -50,15 +50,15 @@
* - compression: Compression for the converted TDigest, 100 by default.
* Example of third argument: 'threshold=10000;compression=50'
*/
-public class PercentileSmartTDigestAggregationFunction extends BaseSingleInputAggregationFunction