Skip to content

Commit

Permalink
Percentile operations supporting null (#12271)
Browse files Browse the repository at this point in the history
* new test framework candidate

* Improved test system

* Improve framework to be able to specify segments as strings

* fix headers

* Improve assertions when there are nulls

* Improve error text

* Improvements in the framework

* Add a base class single input aggregation operations can extend to support null handling

* Fix issue in NullableSingleInputAggregationFunction.forEachNotNullInt

* Improve error message in NullEnabledQueriesTest

* Add new schema family

* Rename test schemas and table config

* Split AllNullQueriesTest into on test per query

* Revert change in AllNullQueriesTest that belongs to mode-null-support branch

* Add tests

* Fix issue in bytes in aggregation case

* Update to the new framework

* Fix some tests

* rollback a code style change
  • Loading branch information
gortiz authored Apr 1, 2024
1 parent 3185e30 commit c15c391
Show file tree
Hide file tree
Showing 20 changed files with 861 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
if (upperCaseFunctionName.startsWith("PERCENTILE")) {
String remainingFunctionName = upperCaseFunctionName.substring(10);
if (remainingFunctionName.equals("SMARTTDIGEST")) {
return new PercentileSmartTDigestAggregationFunction(arguments);
return new PercentileSmartTDigestAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("KLL")) {
return new PercentileKLLAggregationFunction(arguments);
return new PercentileKLLAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("KLLMV")) {
return new PercentileKLLMVAggregationFunction(arguments);
}
if (remainingFunctionName.equals("RAWKLL")) {
return new PercentileRawKLLAggregationFunction(arguments);
return new PercentileRawKLLAggregationFunction(arguments, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWKLLMV")) {
return new PercentileRawKLLMVAggregationFunction(arguments);
Expand All @@ -80,23 +80,28 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
// NOTE: This convention is deprecated. DO NOT add new functions here
if (remainingFunctionName.matches("\\d+")) {
// Percentile
return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName));
return new PercentileAggregationFunction(firstArgument, parsePercentileToInt(remainingFunctionName),
nullHandlingEnabled);
} else if (remainingFunctionName.matches("EST\\d+")) {
// PercentileEst
String percentileString = remainingFunctionName.substring(3);
return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
return new PercentileEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
nullHandlingEnabled);
} else if (remainingFunctionName.matches("RAWEST\\d+")) {
// PercentileRawEst
String percentileString = remainingFunctionName.substring(6);
return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
return new PercentileRawEstAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
nullHandlingEnabled);
} else if (remainingFunctionName.matches("TDIGEST\\d+")) {
// PercentileTDigest
String percentileString = remainingFunctionName.substring(7);
return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
return new PercentileTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
nullHandlingEnabled);
} else if (remainingFunctionName.matches("RAWTDIGEST\\d+")) {
// PercentileRawTDigest
String percentileString = remainingFunctionName.substring(10);
return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString));
return new PercentileRawTDigestAggregationFunction(firstArgument, parsePercentileToInt(percentileString),
nullHandlingEnabled);
} else if (remainingFunctionName.matches("\\d+MV")) {
// PercentileMV
String percentileString = remainingFunctionName.substring(0, remainingFunctionName.length() - 2);
Expand Down Expand Up @@ -125,23 +130,23 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
Preconditions.checkArgument(percentile >= 0 && percentile <= 100, "Invalid percentile: %s", percentile);
if (remainingFunctionName.isEmpty()) {
// Percentile
return new PercentileAggregationFunction(firstArgument, percentile);
return new PercentileAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("EST")) {
// PercentileEst
return new PercentileEstAggregationFunction(firstArgument, percentile);
return new PercentileEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWEST")) {
// PercentileRawEst
return new PercentileRawEstAggregationFunction(firstArgument, percentile);
return new PercentileRawEstAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("TDIGEST")) {
// PercentileTDigest
return new PercentileTDigestAggregationFunction(firstArgument, percentile);
return new PercentileTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWTDIGEST")) {
// PercentileRawTDigest
return new PercentileRawTDigestAggregationFunction(firstArgument, percentile);
return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, nullHandlingEnabled);
}
if (remainingFunctionName.equals("MV")) {
// PercentileMV
Expand Down Expand Up @@ -175,11 +180,13 @@ public static AggregationFunction getAggregationFunction(FunctionContext functio
Preconditions.checkArgument(compressionFactor >= 0, "Invalid compressionFactor: %d", compressionFactor);
if (remainingFunctionName.equals("TDIGEST")) {
// PercentileTDigest
return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor);
return new PercentileTDigestAggregationFunction(firstArgument, percentile, compressionFactor,
nullHandlingEnabled);
}
if (remainingFunctionName.equals("RAWTDIGEST")) {
// PercentileRawTDigest
return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor);
return new PercentileRawTDigestAggregationFunction(firstArgument, percentile, compressionFactor,
nullHandlingEnabled);
}
if (remainingFunctionName.equals("TDIGESTMV")) {
// PercentileTDigestMV
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ public void forEachNotNull(int length, IntIterator nullIndexIterator, BatchConsu
}
}

/**
* Folds over the non-null ranges of the blockValSet using the reducer.
* @param initialAcum the initial value of the accumulator
* @param <A> The type of the accumulator
*/
public <A> A foldNotNull(int length, BlockValSet blockValSet, A initialAcum, Reducer<A> reducer) {
return foldNotNull(length, blockValSet.getNullBitmap(), initialAcum, reducer);
}

/**
* Folds over the non-null ranges of the blockValSet using the reducer.
* @param initialAcum the initial value of the accumulator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@
import org.apache.pinot.segment.spi.AggregationFunctionType;


public class PercentileAggregationFunction extends BaseSingleInputAggregationFunction<DoubleArrayList, Double> {
public class PercentileAggregationFunction extends NullableSingleInputAggregationFunction<DoubleArrayList, Double> {
private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;

//version 0 functions specified in the of form PERCENTILE<2-digits>(column)
//version 1 functions of form PERCENTILE(column, <2-digits>.<16-digits>)
protected final int _version;
protected final double _percentile;

public PercentileAggregationFunction(ExpressionContext expression, int percentile) {
super(expression);
public PercentileAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) {
super(expression, nullHandlingEnabled);
_version = 0;
_percentile = percentile;
}

public PercentileAggregationFunction(ExpressionContext expression, double percentile) {
super(expression);
public PercentileAggregationFunction(ExpressionContext expression, double percentile, boolean nullHandlingEnabled) {
super(expression, nullHandlingEnabled);
_version = 1;
_percentile = percentile;
}
Expand Down Expand Up @@ -77,33 +77,42 @@ public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int ma
public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
DoubleArrayList valueList = getValueList(aggregationResultHolder);
double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
valueList.add(valueArray[i]);
}
BlockValSet blockValSet = blockValSetMap.get(_expression);
double[] valueArray = blockValSet.getDoubleValuesSV();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
valueList.add(valueArray[i]);
}
});
}

@Override
public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
valueList.add(valueArray[i]);
}
BlockValSet blockValSet = blockValSetMap.get(_expression);
double[] valueArray = blockValSet.getDoubleValuesSV();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKeyArray[i]);
valueList.add(valueArray[i]);
}
});
}

@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
double[] valueArray = blockValSetMap.get(_expression).getDoubleValuesSV();
for (int i = 0; i < length; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey);
valueList.add(value);
BlockValSet blockValSet = blockValSetMap.get(_expression);
double[] valueArray = blockValSet.getDoubleValuesSV();
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
double value = valueArray[i];
for (int groupKey : groupKeysArray[i]) {
DoubleArrayList valueList = getValueList(groupByResultHolder, groupKey);
valueList.add(value);
}
}
}
});
}

@Override
Expand Down Expand Up @@ -146,7 +155,11 @@ public ColumnDataType getFinalResultColumnType() {
public Double extractFinalResult(DoubleArrayList intermediateResult) {
int size = intermediateResult.size();
if (size == 0) {
return DEFAULT_FINAL_RESULT;
if (_nullHandlingEnabled) {
return null;
} else {
return DEFAULT_FINAL_RESULT;
}
} else {
double[] values = intermediateResult.elements();
Arrays.sort(values, 0, size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,23 @@
import org.apache.pinot.spi.data.FieldSpec.DataType;


public class PercentileEstAggregationFunction extends BaseSingleInputAggregationFunction<QuantileDigest, Long> {
public class PercentileEstAggregationFunction extends NullableSingleInputAggregationFunction<QuantileDigest, Long> {
public static final double DEFAULT_MAX_ERROR = 0.05;

//version 0 functions specified in the of form PERCENTILEEST<2-digits>(column)
//version 1 functions of form PERCENTILEEST(column, <2-digits>.<16-digits>)
protected final int _version;
protected final double _percentile;

public PercentileEstAggregationFunction(ExpressionContext expression, int percentile) {
super(expression);
public PercentileEstAggregationFunction(ExpressionContext expression, int percentile, boolean nullHandlingEnabled) {
super(expression, nullHandlingEnabled);
_version = 0;
_percentile = percentile;
}

public PercentileEstAggregationFunction(ExpressionContext expression, double percentile) {
super(expression);
public PercentileEstAggregationFunction(ExpressionContext expression, double percentile,
boolean nullHandlingEnabled) {
super(expression, nullHandlingEnabled);
_version = 1;
_percentile = percentile;
}
Expand Down Expand Up @@ -81,24 +82,30 @@ public void aggregate(int length, AggregationResultHolder aggregationResultHolde
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
QuantileDigest quantileDigest = getDefaultQuantileDigest(aggregationResultHolder);
for (int i = 0; i < length; i++) {
quantileDigest.add(longValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
quantileDigest.add(longValues[i]);
}
});
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
QuantileDigest quantileDigest = aggregationResultHolder.getResult();
if (quantileDigest != null) {
for (int i = 0; i < length; i++) {
quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
foldNotNull(length, blockValSet, (QuantileDigest) aggregationResultHolder.getResult(), (quantile, from, toEx) -> {
int start;
QuantileDigest quantileDigest;
if (quantile != null) {
start = from;
quantileDigest = quantile;
} else {
start = from + 1;
quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[from]);
aggregationResultHolder.setValue(quantileDigest);
}
} else {
quantileDigest = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[0]);
aggregationResultHolder.setValue(quantileDigest);
for (int i = 1; i < length; i++) {
for (int i = start; i < toEx; i++) {
quantileDigest.merge(ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
}
}
return quantileDigest;
});
}
}

Expand All @@ -108,22 +115,26 @@ public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHol
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]);
}
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
getDefaultQuantileDigest(groupByResultHolder, groupKeyArray[i]).add(longValues[i]);
}
});
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
for (int i = 0; i < length; i++) {
QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
int groupKey = groupKeyArray[i];
QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
if (quantileDigest != null) {
quantileDigest.merge(value);
} else {
groupByResultHolder.setValueForKey(groupKey, value);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
int groupKey = groupKeyArray[i];
QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
if (quantileDigest != null) {
quantileDigest.merge(value);
} else {
groupByResultHolder.setValueForKey(groupKey, value);
}
}
}
});
}
}

Expand All @@ -133,28 +144,32 @@ public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResult
BlockValSet blockValSet = blockValSetMap.get(_expression);
if (blockValSet.getValueType() != DataType.BYTES) {
long[] longValues = blockValSet.getLongValuesSV();
for (int i = 0; i < length; i++) {
long value = longValues[i];
for (int groupKey : groupKeysArray[i]) {
getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value);
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
long value = longValues[i];
for (int groupKey : groupKeysArray[i]) {
getDefaultQuantileDigest(groupByResultHolder, groupKey).add(value);
}
}
}
});
} else {
// Serialized QuantileDigest
byte[][] bytesValues = blockValSet.getBytesValuesSV();
for (int i = 0; i < length; i++) {
QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
for (int groupKey : groupKeysArray[i]) {
QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
if (quantileDigest != null) {
quantileDigest.merge(value);
} else {
// Create a new QuantileDigest for the group
groupByResultHolder
.setValueForKey(groupKey, ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
forEachNotNull(length, blockValSet, (from, to) -> {
for (int i = from; i < to; i++) {
QuantileDigest value = ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]);
for (int groupKey : groupKeysArray[i]) {
QuantileDigest quantileDigest = groupByResultHolder.getResult(groupKey);
if (quantileDigest != null) {
quantileDigest.merge(value);
} else {
// Create a new QuantileDigest for the group
groupByResultHolder.setValueForKey(groupKey,
ObjectSerDeUtils.QUANTILE_DIGEST_SER_DE.deserialize(bytesValues[i]));
}
}
}
}
});
}
}

Expand Down Expand Up @@ -202,6 +217,9 @@ public ColumnDataType getFinalResultColumnType() {

@Override
public Long extractFinalResult(QuantileDigest intermediateResult) {
if (intermediateResult.getCount() == 0 && _nullHandlingEnabled) {
return null;
}
return intermediateResult.getQuantile(_percentile / 100.0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
public class PercentileEstMVAggregationFunction extends PercentileEstAggregationFunction {

public PercentileEstMVAggregationFunction(ExpressionContext expression, int percentile) {
super(expression, percentile);
super(expression, percentile, false);
}

public PercentileEstMVAggregationFunction(ExpressionContext expression, double percentile) {
super(expression, percentile);
super(expression, percentile, false);
}

@Override
Expand Down
Loading

0 comments on commit c15c391

Please sign in to comment.