Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DATE_TRUNC Optimizer #14385

Draft
wants to merge 22 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,42 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.sql.Time;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.function.DateTimeUtils;
import org.apache.pinot.common.function.TimeZoneKey;
import org.apache.pinot.common.function.scalar.DateTimeFunctions;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.common.utils.request.RequestUtils;
import org.apache.pinot.core.operator.transform.function.DateTimeConversionTransformFunction;
import org.apache.pinot.core.operator.transform.function.DateTruncTransformFunction;
import org.apache.pinot.core.operator.transform.function.LiteralTransformFunction;
import org.apache.pinot.core.operator.transform.function.TimeConversionTransformFunction;
import org.apache.pinot.spi.data.DateTimeFieldSpec.TimeFormat;
import org.apache.pinot.spi.data.DateTimeFormatSpec;
import org.apache.pinot.spi.data.DateTimeGranularitySpec;
import org.apache.pinot.spi.data.Schema;
import org.apache.pinot.spi.utils.TimeUtils;
import org.apache.pinot.sql.FilterKind;
import org.joda.time.DateTime;
import org.joda.time.DateTimeField;
import org.joda.time.DateTimeZone;
import org.joda.time.DurationField;
import org.joda.time.DurationFieldType;
import org.joda.time.chrono.ISOChronology;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.time.*;


/**
Expand All @@ -53,6 +69,17 @@
* <p>NOTE: Other predicates such as NOT_EQUALS, IN, NOT_IN are not supported for now because these predicates are
* not common on time column, and they cannot be optimized to a single range predicate.
* </li>
* <li>
* Optimizes DATE_TRUNC function with range/equality predicates by either rounding up or down to closest granularity
* step
* <p>E.g. "dateTrunc('DAY', col, 'MILLISECONDS') > 1620777600000" will be optimized
* to "col > 1620863999999" as 1620863999999 is the largest value that can be truncated to 1620777600000
* <p>E.g. "dateTrunc('DAY', col, 'MILLISECONDS') <= 1620777600010" will be optimized
* to "col <= 1620863999999" as the closest granularity step at or below 1620777600010 is 1620777600000, and
* 1620863999999 is the largest value that truncates to 1620777600000.
* <p>NOTE: Other predicates such as NOT_EQUALS, IN, NOT_IN are not supported for now because these predicates are
* not common on time column, and they cannot be optimized to a single range predicate.
* </li>
* </ul>
*
* NOTE: This optimizer is followed by the {@link MergeRangeFilterOptimizer}, which can merge the generated ranges.
Expand Down Expand Up @@ -84,6 +111,8 @@ Expression optimize(Expression filterExpression) {
optimizeTimeConvert(filterFunction, filterKind);
} else if (functionName.equalsIgnoreCase(DateTimeConversionTransformFunction.FUNCTION_NAME)) {
optimizeDateTimeConvert(filterFunction, filterKind);
} else if (functionName.equalsIgnoreCase(DateTruncTransformFunction.FUNCTION_NAME)) {
ashishjayamohan marked this conversation as resolved.
Show resolved Hide resolved
optimizeDateTrunc(filterFunction, filterKind);
}
}
}
Expand Down Expand Up @@ -411,6 +440,95 @@ && isStringLiteral(dateTimeConvertOperands.get(3)),
}
}

private void optimizeDateTrunc(Function filterFunction, FilterKind filterKind) {
List<Expression> filterOperands = filterFunction.getOperands();
List<Expression> dateTruncOperands = filterOperands.get(0).getFunctionCall().getOperands();

// TODO: Compute value and create query is date trunc is applied on a literal value
if (dateTruncOperands.get(1).isSetLiteral()) {
return;
}

Long lowerMillis = null;
Long upperMillis = null;
boolean lowerInclusive = true;
boolean upperInclusive = true;
List<Expression> operands = new ArrayList<>(dateTruncOperands);
String unit = operands.get(0).getLiteral().getStringValue();
String inputTimeUnit = (operands.size() >= 3) ? operands.get(2).getLiteral().getStringValue()
: TimeUnit.MILLISECONDS.name();
ISOChronology chronology = (operands.size() >= 4)
? DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(operands.get(3).getLiteral().getStringValue()))
: ISOChronology.getInstanceUTC();
String outputTimeUnit = (operands.size() == 5) ? operands.get(4).getLiteral().getStringValue()
: TimeUnit.MILLISECONDS.name();
System.out.println(Arrays.toString(
calculateRangeForDateTrunc(unit, getLongValue(filterOperands.get(1)), inputTimeUnit, chronology,
outputTimeUnit)));
switch (filterKind) {
case EQUALS:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
upperMillis = dateTruncCeil(operands);
lowerMillis = dateTruncFloor(operands);
if (lowerMillis != TimeUnit.MILLISECONDS.convert(getLongValue(filterOperands.get(1)), TimeUnit.valueOf(outputTimeUnit.toUpperCase()))) {
lowerMillis = Long.MAX_VALUE;
upperMillis = Long.MIN_VALUE;
String rangeString = new Range(lowerMillis, lowerInclusive, upperMillis, upperInclusive).getRangeString();
rewriteToRange(filterFunction, dateTruncOperands.get(1), rangeString);
return;
}
break;
case GREATER_THAN:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
lowerMillis = dateTruncCeil(operands);
lowerInclusive = false;
upperMillis = Long.MAX_VALUE;
break;
case GREATER_THAN_OR_EQUAL:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
lowerMillis = dateTruncFloor(operands);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be ceil?

upperMillis = Long.MAX_VALUE;
if (TimeUnit.valueOf(outputTimeUnit).convert(lowerMillis, TimeUnit.MILLISECONDS)
!= getLongValue(filterOperands.get(1))) {
lowerInclusive = false;
lowerMillis = dateTruncCeil(operands);
}
break;
case LESS_THAN:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
lowerMillis = Long.MIN_VALUE;
upperInclusive = false;
upperMillis = dateTruncFloor(operands);
if (upperMillis != TimeUnit.MILLISECONDS.convert(getLongValue(filterOperands.get(1)), TimeUnit.valueOf(outputTimeUnit.toUpperCase()))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we checking this here but not in in GREATER_THAN?

upperInclusive = true;
upperMillis = dateTruncCeil(operands);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this recomputed here?

}
break;
case LESS_THAN_OR_EQUAL:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
lowerMillis = Long.MIN_VALUE;
upperMillis = dateTruncCeil(operands);
break;
case BETWEEN:
operands.set(1, getExpression(getLongValue(filterOperands.get(1)), new DateTimeFormatSpec("TIMESTAMP")));
lowerMillis = dateTruncFloor(operands);
if (TimeUnit.valueOf(outputTimeUnit).convert(lowerMillis, TimeUnit.MILLISECONDS)
!= getLongValue(filterOperands.get(1))) {
lowerInclusive = false;
lowerMillis = dateTruncCeil(operands);
}
operands.set(1, getExpression(getLongValue(filterOperands.get(2)), new DateTimeFormatSpec("TIMESTAMP")));
upperMillis = dateTruncCeil(operands);
break;
default:
throw new IllegalStateException();
ashishjayamohan marked this conversation as resolved.
Show resolved Hide resolved
}
lowerMillis = TimeUnit.valueOf(inputTimeUnit).convert(lowerMillis, TimeUnit.MILLISECONDS);
upperMillis = TimeUnit.valueOf(inputTimeUnit).convert(upperMillis, TimeUnit.MILLISECONDS);
String rangeString = new Range(lowerMillis, lowerInclusive, upperMillis, upperInclusive).getRangeString();
rewriteToRange(filterFunction, dateTruncOperands.get(1), rangeString);
}

private boolean isStringLiteral(Expression expression) {
Literal literal = expression.getLiteral();
return literal != null && literal.isSetStringValue();
Expand Down Expand Up @@ -438,12 +556,89 @@ private long ceil(long millisValue, long granularityMillis) {
return (millisValue + granularityMillis - 1) / granularityMillis * granularityMillis;
}

private static void rewriteToRange(Function filterFunction, Expression expression, String rangeString) {
private void rewriteToRange(Function filterFunction, Expression expression, String rangeString) {
filterFunction.setOperator(FilterKind.RANGE.name());
// NOTE: Create an ArrayList because we might need to modify the list later
List<Expression> newOperands = new ArrayList<>(2);
newOperands.add(expression);
newOperands.add(RequestUtils.getLiteralExpression(rangeString));
filterFunction.setOperands(newOperands);
}


/**
 * Wraps a long in a LITERAL expression.
 *
 * @param value long stored in the literal
 * @param inputFormat not read by this method
 * @return a new LITERAL-typed expression carrying {@code value}
 */
private Expression getExpression(long value, DateTimeFormatSpec inputFormat) {
  Expression wrapped = new Expression(ExpressionType.LITERAL);
  Literal longLiteral = new Literal();
  longLiteral.setLongValue(value);
  wrapped.setLiteral(longLiteral);
  return wrapped;
}

/**
 * Rounds the time value carried in {@code operands} down to the start of its truncation bucket.
 *
 * <p>Operands are positional: index 0 is the truncation unit, index 1 the time value, index 3 the time zone
 * (UTC when fewer than four operands are present) and index 4 the time unit of the value (MILLISECONDS unless
 * exactly five operands are present).
 *
 * @param operands DATE_TRUNC-style operand list
 * @return the floor of the time value for the given unit, computed on its millisecond representation
 */
private long dateTruncFloor(List<Expression> operands) {
  String truncUnit = operands.get(0).getLiteral().getStringValue();
  long rawValue = getLongValue(operands.get(1));

  ISOChronology zoneChronology;
  if (operands.size() >= 4) {
    String zoneId = operands.get(3).getLiteral().getStringValue();
    zoneChronology = DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(zoneId));
  } else {
    zoneChronology = ISOChronology.getInstanceUTC();
  }

  String valueUnitName = TimeUnit.MILLISECONDS.name();
  if (operands.size() == 5) {
    valueUnitName = operands.get(4).getLiteral().getStringValue();
  }

  long valueMillis = TimeUnit.MILLISECONDS.convert(rawValue, TimeUnit.valueOf(valueUnitName.toUpperCase()));
  return DateTimeUtils.getTimestampField(zoneChronology, truncUnit).roundFloor(valueMillis);
}

/**
 * Finds the last millisecond of the truncation bucket that contains the time value carried in {@code operands},
 * i.e. the ceiling counterpart of {@code dateTruncFloor}.
 *
 * <p>Operands are positional: index 0 is the truncation unit, index 1 the time value, index 3 the time zone
 * (UTC when fewer than four operands are present) and index 4 the time unit of the value (MILLISECONDS unless
 * exactly five operands are present).
 *
 * @param operands DATE_TRUNC-style operand list
 * @return one millisecond before the next rounding boundary strictly after the time value
 */
private long dateTruncCeil(List<Expression> operands) {
  String truncUnit = operands.get(0).getLiteral().getStringValue();
  long rawValue = getLongValue(operands.get(1));

  ISOChronology zoneChronology;
  if (operands.size() >= 4) {
    String zoneId = operands.get(3).getLiteral().getStringValue();
    zoneChronology = DateTimeUtils.getChronology(TimeZoneKey.getTimeZoneKey(zoneId));
  } else {
    zoneChronology = ISOChronology.getInstanceUTC();
  }

  String valueUnitName = TimeUnit.MILLISECONDS.name();
  if (operands.size() == 5) {
    valueUnitName = operands.get(4).getLiteral().getStringValue();
  }

  long valueMillis = TimeUnit.MILLISECONDS.convert(rawValue, TimeUnit.valueOf(valueUnitName.toUpperCase()));
  // Step one millisecond forward so a value already on a boundary rounds up to the NEXT boundary, then step one
  // millisecond back to land on the last instant of the current bucket.
  long nextBoundaryMillis = DateTimeUtils.getTimestampField(zoneChronology, truncUnit).roundCeiling(valueMillis + 1);
  return nextBoundaryMillis - 1;
}

/**
 * Computes the inclusive interval of input values associated with a truncated timestamp.
 *
 * <p>The target value is converted from the output time unit to milliseconds and rounded up to the nearest
 * boundary of {@code unit} (a value already on a boundary is unchanged). The interval runs from that boundary to
 * one millisecond before the following boundary, and both ends are converted to the input time unit.
 * {@link TimeUnit#convert} truncates when converting to a coarser unit, so the returned ends may lose
 * sub-unit precision.
 *
 * @param unit truncation unit name understood by {@code DateTimeUtils.getTimestampField}
 * @param targetTruncatedValue truncated value, expressed in {@code outputTimeUnit}
 * @param inputTimeUnit name of the {@link TimeUnit} of the returned bounds (case-insensitive)
 * @param chronology chronology used for rounding; UTC is used when {@code null}
 * @param outputTimeUnit name of the {@link TimeUnit} of {@code targetTruncatedValue} (case-insensitive)
 * @return a two-element array {@code {intervalStart, intervalEnd}} in {@code inputTimeUnit}
 * @throws IllegalArgumentException if a time unit name does not match a {@link TimeUnit} constant
 */
public static long[] calculateRangeForDateTrunc(String unit, long targetTruncatedValue,
    String inputTimeUnit, ISOChronology chronology,
    String outputTimeUnit) {
  TimeUnit inputUnit = TimeUnit.valueOf(inputTimeUnit.toUpperCase());
  TimeUnit outputUnit = TimeUnit.valueOf(outputTimeUnit.toUpperCase());

  // Honor the caller's chronology (time zone); previously UTC was always used and the argument was ignored.
  ISOChronology effectiveChronology = (chronology != null) ? chronology : ISOChronology.getInstanceUTC();
  DateTimeField field = DateTimeUtils.getTimestampField(effectiveChronology, unit);

  // Rounding operates on milliseconds.
  long truncatedTimeInMs = TimeUnit.MILLISECONDS.convert(targetTruncatedValue, outputUnit);
  long intervalStartInMs = field.roundCeiling(truncatedTimeInMs);
  // +1 forces rounding to the next boundary; -1 steps back to the last millisecond of the current interval.
  long intervalEndInMs = field.roundCeiling(intervalStartInMs + 1) - 1;

  long intervalStart = inputUnit.convert(intervalStartInMs, TimeUnit.MILLISECONDS);
  long intervalEnd = inputUnit.convert(intervalEndInMs, TimeUnit.MILLISECONDS);
  return new long[] {intervalStart, intervalEnd};
}
}
Loading
Loading