From 4f63c7100ddd990a8113785334862666ff1e2d35 Mon Sep 17 00:00:00 2001 From: mohsaka <135669458+mohsaka@users.noreply.github.com> Date: Mon, 27 Oct 2025 14:25:10 -0700 Subject: [PATCH] Final analyzer/planner/optimizer changes for tvf except LocalExecutionPlanner and ExcludeColumns optimizer rule. Co-authored-by: kasiafi <30203062+kasiafi@users.noreply.github.com> Co-authored-by: Xin Zhang --- .../facebook/presto/sql/analyzer/Field.java | 8 - .../sql/analyzer/StatementAnalyzer.java | 2 +- .../sql/planner/BasePlanFragmenter.java | 18 + .../presto/sql/planner/PlanOptimizers.java | 18 +- .../presto/sql/planner/PlannerUtils.java | 19 + .../presto/sql/planner/QueryPlanner.java | 11 +- .../presto/sql/planner/RelationPlanner.java | 137 +- .../sql/planner/SchedulingOrderVisitor.java | 14 + .../sql/planner/TableFunctionUtils.java | 98 ++ .../rule/ImplementTableFunctionSource.java | 1033 ++++++++++++ .../PruneTableFunctionProcessorColumns.java | 88 ++ ...neTableFunctionProcessorSourceColumns.java | 104 ++ .../rule/RemoveRedundantTableFunction.java | 66 + .../rule/RewriteTableFunctionToTableScan.java | 41 +- .../planner/optimizations/AddExchanges.java | 56 +- .../optimizations/AddLocalExchanges.java | 86 + .../optimizations/PropertyDerivations.java | 46 + .../PruneUnreferencedOutputs.java | 21 + .../optimizations/QueryCardinalityUtil.java | 7 + .../StreamPropertyDerivations.java | 30 + .../planner/optimizations/SymbolMapper.java | 151 ++ .../UnaliasSymbolReferences.java | 92 +- .../sql/planner/plan/InternalPlanVisitor.java | 5 + .../presto/sql/planner/plan/Patterns.java | 5 + .../sql/planner/plan/SimplePlanRewriter.java | 5 + .../sql/planner/plan/TableFunctionNode.java | 139 +- .../plan/TableFunctionProcessorNode.java | 234 +++ .../sql/planner/planPrinter/PlanPrinter.java | 184 ++- .../sanity/ValidateDependenciesChecker.java | 112 ++ .../presto/testing/LocalQueryRunner.java | 3 +- .../facebook/presto/util/GraphvizPrinter.java | 24 + .../connector/tvf/TestingTableFunctions.java | 115 +- .../planner/TestTableFunctionInvocation.java | 272 ++++ .../planner/assertions/PlanMatchPattern.java | 26 + .../assertions/TableFunctionMatcher.java | 412 +++++ .../TableFunctionProcessorMatcher.java | 239 +++ .../TestImplementTableFunctionSource.java | 1404 +++++++++++++++++ ...estPruneTableFunctionProcessorColumns.java | 221 +++ ...neTableFunctionProcessorSourceColumns.java | 198 +++ .../TestRemoveRedundantTableFunction.java | 80 + .../iterative/rule/test/PlanBuilder.java | 30 + .../test/TableFunctionProcessorBuilder.java | 140 ++ .../presto/sql/parser/AstBuilder.java | 2 +- 43 files changed, 5876 insertions(+), 120 deletions(-) create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/TableFunctionUtils.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunction.java create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java create mode 100644 
presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunction.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java diff --git a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java index 15c33950ce14b..630f4670f6cc2 100644 --- a/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java +++ b/presto-analyzer/src/main/java/com/facebook/presto/sql/analyzer/Field.java @@ -86,14 +86,6 @@ public Field(Optional nodeLocation, Optional relati this.aliased = aliased; } - public static Field newUnqualified(Optional name, Type type) - { - requireNonNull(name, "name is null"); - requireNonNull(type, "type is null"); - - return new Field(Optional.empty(), Optional.empty(), name, type, false, Optional.empty(), Optional.empty(), false); - } - public Optional getNodeLocation() { return nodeLocation; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java index 8e0709d02cafe..8bc3e5029eb70 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/StatementAnalyzer.java @@ -1457,7 +1457,7 @@ private void verifyRequiredColumns(TableFunctionInvocation node, Map column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .filter(column -> column < 0 || column >= inputScope.getRelationType().getVisibleFieldCount()) .findFirst() .ifPresent(column -> { throw new SemanticException(TABLE_FUNCTION_IMPLEMENTATION_ERROR, "Invalid index: %s of required column from table argument %s", column, name); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java index bc77373a5ecbd..912268c8361b0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/BasePlanFragmenter.java @@ -47,6 +47,8 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.sanity.PlanChecker; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -270,6 
+272,22 @@ public PlanNode visitValues(ValuesNode node, RewriteContext return context.defaultRewrite(node, context.get()); } + @Override + public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + // context is mutable. The leaf node should set the PartitioningHandle. + context.get().addSourceDistribution(node.getId(), SOURCE_DISTRIBUTION, metadata, session); + } + return context.defaultRewrite(node, context.get()); + } + @Override public PlanNode visitExchange(ExchangeNode exchange, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 2d7c8be053645..98dbabf5392a3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -50,6 +50,7 @@ import com.facebook.presto.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; import com.facebook.presto.sql.planner.iterative.rule.ImplementFilteredAggregations; import com.facebook.presto.sql.planner.iterative.rule.ImplementOffset; +import com.facebook.presto.sql.planner.iterative.rule.ImplementTableFunctionSource; import com.facebook.presto.sql.planner.iterative.rule.InlineProjections; import com.facebook.presto.sql.planner.iterative.rule.InlineProjectionsOnValues; import com.facebook.presto.sql.planner.iterative.rule.InlineSqlFunctions; @@ -81,6 +82,8 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneRedundantProjectionAssignments; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneSemiJoinFilteringSourceColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorColumns; +import com.facebook.presto.sql.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTableScanColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneUpdateSourceColumns; @@ -120,6 +123,7 @@ import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantLimit; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSort; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantSortColumns; +import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTableFunction; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopN; import com.facebook.presto.sql.planner.iterative.rule.RemoveRedundantTopNColumns; import com.facebook.presto.sql.planner.iterative.rule.RemoveTrivialFilters; @@ -310,6 +314,8 @@ public PlanOptimizers( new PruneValuesColumns(), new PruneWindowColumns(), new PruneLimitColumns(), + new PruneTableFunctionProcessorColumns(), + new PruneTableFunctionProcessorSourceColumns(), new PruneTableScanColumns()); builder.add(new LogicalCteOptimizer(metadata)); @@ -367,6 +373,14 @@ public PlanOptimizers( PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser, expressionOptimizerManager, 
featuresConfig.isNativeExecutionEnabled())); PlanOptimizer prefilterForLimitingAggregation = new StatsRecordingPlanOptimizer(optimizerStats, new PrefilterForLimitingAggregation(metadata, statsCalculator)); + builder.add( + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new RewriteTableFunctionToTableScan(metadata)))); + builder.add( new IterativeOptimizer( metadata, @@ -409,6 +423,7 @@ public PlanOptimizers( .addAll(columnPruningRules) .addAll(ImmutableSet.of( new MergeDuplicateAggregation(metadata.getFunctionAndTypeManager()), + new ImplementTableFunctionSource(metadata), new RemoveRedundantIdentityProjections(), new RemoveFullSample(), new EvaluateZeroSample(), @@ -423,6 +438,7 @@ public PlanOptimizers( new PushLimitThroughSemiJoin(), new PushLimitThroughUnion(), new RemoveTrivialFilters(), + new RemoveRedundantTableFunction(), new ImplementFilteredAggregations(metadata.getFunctionAndTypeManager()), new SingleDistinctAggregationToGroupBy(), new MultipleDistinctAggregationToMarkDistinct(), @@ -762,7 +778,7 @@ public PlanOptimizers( ruleStats, statsCalculator, estimatedExchangesCostCalculator, - ImmutableSet.of(new RemoveRedundantIdentityProjections(), new PruneRedundantProjectionAssignments())), + ImmutableSet.of(new RemoveRedundantIdentityProjections(), new PruneRedundantProjectionAssignments(), new RemoveRedundantTableFunction())), new PushdownSubfields(metadata, expressionOptimizerManager)); builder.add(predicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index 826eaea10044b..0448dece8ca32 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -43,6 +43,7 @@ import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.analyzer.Field; +import com.facebook.presto.sql.analyzer.RelationType; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; @@ -205,6 +206,9 @@ public static PlanNode addOverrideProjection(PlanNode source, PlanNodeIdAllocato || source.getOutputVariables().stream().distinct().count() != source.getOutputVariables().size()) { return source; } + if (source instanceof ProjectNode && ((ProjectNode) source).getAssignments().getMap().equals(variableMap)) { + return source; + } Assignments.Builder assignmentsBuilder = Assignments.builder(); assignmentsBuilder.putAll(source.getOutputVariables().stream().collect(toImmutableMap(identity(), x -> variableMap.containsKey(x) ? variableMap.get(x) : x))); return new ProjectNode(source.getSourceLocation(), planNodeIdAllocator.getNextId(), source, assignmentsBuilder.build(), LOCAL); @@ -574,4 +578,19 @@ public static RowExpression randomizeJoinKey(Session session, FunctionAndTypeMan } return new SpecialFormExpression(COALESCE, VARCHAR, ImmutableList.of(castToVarchar, concatExpression)); } + + public static int[] getFieldIndexesForVisibleColumns(RelationPlan sourcePlan) + { + // required columns are a subset of visible columns of the source. 
remap required column indexes to field indexes in source relation type. + RelationType sourceRelationType = sourcePlan.getScope().getRelationType(); + int[] fieldIndexForVisibleColumn = new int[sourceRelationType.getVisibleFieldCount()]; + int visibleColumn = 0; + for (int i = 0; i < sourceRelationType.getAllFieldCount(); i++) { + if (!sourceRelationType.getFieldByIndex(i).isHidden()) { + fieldIndexForVisibleColumn[visibleColumn] = i; + visibleColumn++; + } + } + return fieldIndexForVisibleColumn; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java index e2c79cb02b29c..d81385d100a8b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/QueryPlanner.java @@ -143,7 +143,7 @@ import static java.lang.String.format; import static java.util.Objects.requireNonNull; -class QueryPlanner +public class QueryPlanner { private final Analysis analysis; private final VariableAllocator variableAllocator; @@ -524,7 +524,7 @@ private PlanBuilder project(PlanBuilder subPlan, Iterable expression * * @return the new subplan and a mapping of each expression to the symbol representing the coercion or an existing symbol if a coercion wasn't needed */ - private PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) + public PlanAndMappings coerce(PlanBuilder subPlan, List expressions, Analysis analysis, PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, Metadata metadata) { Assignments.Builder assignments = Assignments.builder(); assignments.putAll(subPlan.getRoot().getOutputVariables().stream().collect(toImmutableMap(Function.identity(), Function.identity()))); @@ -1346,7 +1346,7 @@ private RowExpression rowExpression(Expression expression, SqlPlannerContext con context.getTranslatorContext()); } - private static List toSymbolReferences(List variables) + public static List toSymbolReferences(List variables) { return variables.stream() .map(variable -> new SymbolReference( @@ -1355,6 +1355,11 @@ private static List toSymbolReferences(List new NodeLocation(location.getLine(), location.getColumn())), variable.getName()); + } + public static class PlanAndMappings { private final PlanBuilder subPlan; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java index 0bd4e47fc67b8..f661fd2653f9a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/RelationPlanner.java @@ -30,11 +30,13 @@ import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.CteReferenceNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.ExceptNode; import com.facebook.presto.spi.plan.FilterNode; import com.facebook.presto.spi.plan.IntersectNode; import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import 
com.facebook.presto.spi.plan.ProjectNode; @@ -50,14 +52,15 @@ import com.facebook.presto.sql.analyzer.Field; import com.facebook.presto.sql.analyzer.RelationId; import com.facebook.presto.sql.analyzer.RelationType; +import com.facebook.presto.sql.analyzer.ResolvedField; import com.facebook.presto.sql.analyzer.Scope; -import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.parser.SqlParser; import com.facebook.presto.sql.planner.optimizations.JoinNodeUtils; import com.facebook.presto.sql.planner.optimizations.SampleNodeUtil; import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; import com.facebook.presto.sql.tree.AliasedRelation; import com.facebook.presto.sql.tree.Cast; import com.facebook.presto.sql.tree.CoalesceExpression; @@ -87,9 +90,7 @@ import com.facebook.presto.sql.tree.SetOperation; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.Table; -import com.facebook.presto.sql.tree.TableFunctionDescriptorArgument; import com.facebook.presto.sql.tree.TableFunctionInvocation; -import com.facebook.presto.sql.tree.TableFunctionTableArgument; import com.facebook.presto.sql.tree.TableSubquery; import com.facebook.presto.sql.tree.Union; import com.facebook.presto.sql.tree.Unnest; @@ -117,6 +118,7 @@ import static com.facebook.presto.SystemSessionProperties.getQueryAnalyzerTimeout; import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.metadata.MetadataUtil.createQualifiedObjectName; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_PLAN_ERROR; import static com.facebook.presto.spi.StandardErrorCode.QUERY_PLANNING_TIMEOUT; import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL; @@ -125,9 +127,10 @@ import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isEqualComparisonExpression; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; import static com.facebook.presto.sql.analyzer.FeaturesConfig.CteMaterializationStrategy.NONE; -import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED; import static com.facebook.presto.sql.analyzer.SemanticExceptions.notSupportedException; import static com.facebook.presto.sql.planner.PlannerUtils.newVariable; +import static com.facebook.presto.sql.planner.TableFunctionUtils.addPassthroughColumns; +import static com.facebook.presto.sql.planner.TableFunctionUtils.getOrderingScheme; import static com.facebook.presto.sql.planner.TranslateExpressionsUtil.toRowExpression; import static com.facebook.presto.sql.tree.Join.Type.INNER; import static com.facebook.presto.sql.tree.Join.Type.LEFT; @@ -304,51 +307,109 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan, SqlPlannerCo return new RelationPlan(planBuilder.getRoot(), plan.getScope(), newMappings.build()); } + /** + * Processes a {@code TableFunctionInvocation} node to construct and return a {@link RelationPlan}. + * This involves preparing the necessary plan nodes, variable mappings, and associated properties + * to represent the execution plan for the invoked table function. + * + * @param node The {@code TableFunctionInvocation} syntax tree node to be processed. 
+ * @param context The SQL planner context used for planning and analysis tasks. + * @return A {@link RelationPlan} encapsulating the execution plan for the table function invocation. + */ @Override protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node, SqlPlannerContext context) { - node.getArguments().stream() - .forEach(argument -> { - if (argument.getValue() instanceof TableFunctionTableArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Table arguments are not yet supported for table functions"); - } - if (argument.getValue() instanceof TableFunctionDescriptorArgument) { - throw new SemanticException(NOT_SUPPORTED, argument, "Descriptor arguments are not yet supported for table functions"); - } - }); Analysis.TableFunctionInvocationAnalysis functionAnalysis = analysis.getTableFunctionAnalysis(node); + ImmutableList.Builder sources = ImmutableList.builder(); + ImmutableList.Builder sourceProperties = ImmutableList.builder(); + ImmutableList.Builder outputVariables = ImmutableList.builder(); + + // create new symbols for table function's proper columns + RelationType relationType = analysis.getScope(node).getRelationType(); + List properOutputs = IntStream.range(0, functionAnalysis.getProperColumnsCount()) + .mapToObj(relationType::getFieldByIndex) + .map(field -> variableAllocator.newVariable(getSourceLocation(node), field.getName().orElse("field"), field.getType())) + .collect(toImmutableList()); - // TODO handle input relations: - // 1. extract the input relations from node.getArguments() and plan them. Apply relation coercions if requested. - // 2. for each input relation, prepare the TableArgumentProperties record, consisting of: - // - row or set semantics (from the actualArgument) - // - prune when empty property (from the actualArgument) - // - pass through columns property (from the actualArgument) - // - optional Specification: ordering scheme and partitioning (from the node's argument) <- planned upon the source's RelationPlan (or combined RelationPlan from all sources) - // TODO add - argument name - // TODO add - mapping column name => Symbol // TODO mind the fields without names and duplicate field names in RelationType - List sources = ImmutableList.of(); - List inputRelationsProperties = ImmutableList.of(); - - Scope scope = analysis.getScope(node); - - ImmutableList.Builder outputVariablesBuilder = ImmutableList.builder(); - for (Field field : scope.getRelationType().getAllFields()) { - VariableReferenceExpression variable = variableAllocator.newVariable(getSourceLocation(node), field.getName().get(), field.getType()); - outputVariablesBuilder.add(variable); - } + outputVariables.addAll(properOutputs); + QueryPlanner partitionQueryPlanner = new QueryPlanner(analysis, variableAllocator, idAllocator, lambdaDeclarationToVariableMap, metadata, session, context, sqlParser); + processTableArguments(context, functionAnalysis, outputVariables, sources, sourceProperties, partitionQueryPlanner); - List outputVariables = outputVariablesBuilder.build(); PlanNode root = new TableFunctionNode( idAllocator.getNextId(), functionAnalysis.getFunctionName(), functionAnalysis.getArguments(), - outputVariablesBuilder.build(), - sources.stream().map(RelationPlan::getRoot).collect(toImmutableList()), - inputRelationsProperties, - new TableFunctionHandle(functionAnalysis.getConnectorId(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + properOutputs, + sources.build(), + sourceProperties.build(), + 
functionAnalysis.getCopartitioningLists(), + new TableFunctionHandle( + functionAnalysis.getConnectorId(), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); + + return new RelationPlan(root, analysis.getScope(node), outputVariables.build()); + } - return new RelationPlan(root, scope, outputVariables); + private void processTableArguments(SqlPlannerContext context, + Analysis.TableFunctionInvocationAnalysis functionAnalysis, + ImmutableList.Builder outputVariables, + ImmutableList.Builder sources, + ImmutableList.Builder sourceProperties, + QueryPlanner partitionQueryPlanner) + { + // process sources in order of argument declarations + for (Analysis.TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { + RelationPlan sourcePlan = process(tableArgument.getRelation(), context); + PlanBuilder sourcePlanBuilder = initializePlanBuilder(sourcePlan); + + int[] fieldIndexForVisibleColumn = PlannerUtils.getFieldIndexesForVisibleColumns(sourcePlan); + + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(column -> fieldIndexForVisibleColumn[column]) + .map(sourcePlan::getVariable) + .collect(toImmutableList()); + + Optional specification = Optional.empty(); + + // if the table argument has set semantics, create Specification + if (!tableArgument.isRowSemantics()) { + // partition by + List partitionBy = ImmutableList.of(); + // if there are partitioning columns, they might have to be coerced for copartitioning + if (tableArgument.getPartitionBy().isPresent() && !tableArgument.getPartitionBy().get().isEmpty()) { + List partitioningColumns = tableArgument.getPartitionBy().get(); + for (Expression partitionColumn : partitioningColumns) { + if (!sourcePlanBuilder.canTranslate(partitionColumn)) { + ResolvedField partition = sourcePlan.getScope().tryResolveField(partitionColumn).orElseThrow(() -> new PrestoException(INVALID_PLAN_ERROR, "Missing equivalent alias")); + sourcePlanBuilder.getTranslations().put(partitionColumn, sourcePlan.getVariable(partition.getRelationFieldIndex())); + } + } + QueryPlanner.PlanAndMappings copartitionCoercions = partitionQueryPlanner.coerce(sourcePlanBuilder, partitioningColumns, analysis, idAllocator, variableAllocator, metadata); + sourcePlanBuilder = copartitionCoercions.getSubPlan(); + partitionBy = partitioningColumns.stream() + .map(copartitionCoercions::get) + .collect(toImmutableList()); + } + + // order by + Optional orderBy = getOrderingScheme(tableArgument, sourcePlanBuilder, sourcePlan); + specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); + } + + // add output symbols passed from the table argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); + addPassthroughColumns(outputVariables, tableArgument, sourcePlan, specification, passThroughColumns, sourcePlanBuilder); + sources.add(sourcePlanBuilder.getRoot()); + + sourceProperties.add(new TableFunctionNode.TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new TableFunctionNode.PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); + } } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java index 
abb784cdaa298..471c797c426a8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/SchedulingOrderVisitor.java @@ -22,9 +22,11 @@ import com.facebook.presto.spi.plan.SpatialJoinNode; import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.sql.planner.plan.InternalPlanVisitor; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.NoSuchElementException; import java.util.function.Consumer; public class SchedulingOrderVisitor @@ -88,5 +90,17 @@ public Void visitTableScan(TableScanNode node, Consumer schedulingOr schedulingOrder.accept(node.getId()); return null; } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Consumer schedulingOrder) + { + if (!node.getSource().isPresent()) { + schedulingOrder.accept(node.getId()); + } + else { + node.getSource().orElseThrow(NoSuchElementException::new).accept(this, schedulingOrder); + } + return null; + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/TableFunctionUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/TableFunctionUtils.java new file mode 100644 index 0000000000000..0f8e6fb1685c4 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/TableFunctionUtils.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.common.block.SortOrder; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.analyzer.Analysis; +import com.facebook.presto.sql.analyzer.ResolvedField; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SortItem; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +public class TableFunctionUtils +{ + private TableFunctionUtils() {} + + static Optional getOrderingScheme(Analysis.TableArgumentAnalysis tableArgument, PlanBuilder sourcePlanBuilder, RelationPlan sourcePlan) + { + Optional orderBy = Optional.empty(); + if (tableArgument.getOrderBy().isPresent()) { + List sortItems = tableArgument.getOrderBy().get().getSortItems(); + + // Ensure all ORDER BY columns can be translated (populate missing translations if needed) + for (SortItem sortItem : sortItems) { + Expression sortKey = sortItem.getSortKey(); + if (!sourcePlanBuilder.canTranslate(sortKey)) { + Optional resolvedField = sourcePlan.getScope().tryResolveField(sortKey); + resolvedField.ifPresent(field -> sourcePlanBuilder.getTranslations().put( + sortKey, + sourcePlan.getVariable(field.getRelationFieldIndex()))); + } + } + + // The ordering symbols are coerced + List coerced = sortItems.stream() + .map(SortItem::getSortKey) + .map(sourcePlanBuilder::translate) + .collect(toImmutableList()); + + List sortOrders = sortItems.stream() + .map(PlannerUtils::toSortOrder) + .collect(toImmutableList()); + + orderBy = Optional.of(PlannerUtils.toOrderingScheme(coerced, sortOrders)); + } + return orderBy; + } + + static void addPassthroughColumns(ImmutableList.Builder outputVariables, + Analysis.TableArgumentAnalysis tableArgument, RelationPlan sourcePlan, + Optional specification, + ImmutableList.Builder passThroughColumns, + PlanBuilder sourcePlanBuilder) + { + if (tableArgument.isPassThroughColumns()) { + // the original output symbols from the source node, not coerced + // note: hidden columns are included. 
They are present in sourcePlan.fieldMappings + outputVariables.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(variable -> new TableFunctionNode.PassThroughColumn(variable, partitionBy.contains(variable))) + .forEach(passThroughColumns::add); + } + else if (tableArgument.getPartitionBy().isPresent()) { + tableArgument.getPartitionBy().get().stream() + .map(sourcePlanBuilder::translate) + // the original symbols for partitioning columns, not coerced + .forEach(variable -> { + outputVariables.add(variable); + passThroughColumns.add(new TableFunctionNode.PassThroughColumn(variable, true)); + }); + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java new file mode 100644 index 0000000000000..5673fa2882537 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -0,0 +1,1033 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.WindowNode; +import com.facebook.presto.spi.plan.WindowNode.Frame; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.tree.QualifiedName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.spi.plan.JoinType.FULL; +import static com.facebook.presto.spi.plan.JoinType.INNER; +import static com.facebook.presto.spi.plan.JoinType.LEFT; +import static com.facebook.presto.spi.plan.JoinType.RIGHT; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_FOLLOWING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.BoundType.UNBOUNDED_PRECEDING; +import static com.facebook.presto.spi.plan.WindowNode.Frame.WindowType.ROWS; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.relational.Expressions.coalesce; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.GREATER_THAN; 
+import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

+ * It rewrites a TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has a single + * source that combines the original sources. + *

+ * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

+ * The resulting source should be partitioned by the combined + * partitioning columns of the component sources and ordered by the + * combined row number. + *

+ * Example transformation for two sources, both with set semantics + * and the KEEP WHEN EMPTY property: + *

+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * 
+ * Is transformed into: + *
+ * - TableFunctionDataProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * 
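+ * For illustration, suppose one partition of T1 has 3 rows and one partition
+ * of T2 has 2 rows (sizes chosen arbitrarily; R denotes a table's row number
+ * and S its partition size). The join condition
+ *      R1 = R2 OR (R1 > S2 AND R2 = 1) OR (R2 > S1 AND R1 = 1)
+ * pairs their rows as (R1, R2) = (1, 1), (2, 2), (3, 1): the outstanding row
+ * of the bigger partition is matched to the first row of the smaller one,
+ * and combined_row_number = max(R1, R2) numbers the joined rows 1, 2, 3.
+ * In the pair (3, 1), marker_2 is null because table2_row_number does not
+ * equal combined_row_number, so the table function treats the T2 values of
+ * that row as "filler" rather than real input.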
+ */ +public class ImplementTableFunctionSource + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public ImplementTableFunctionSource(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. + TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + + FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); + + // Create call expression for row_number + FunctionHandle rowNumberFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("row_number")), + ImmutableList.of()); + + FunctionMetadata rowNumberFunctionMetadata = functionAndTypeManager.getFunctionMetadata(rowNumberFunctionHandle); + CallExpression rowNumberFunction = new CallExpression("row_number", rowNumberFunctionHandle, functionAndTypeManager.getType(rowNumberFunctionMetadata.getReturnType()), ImmutableList.of()); + + // Create call expression for count + FunctionHandle countFunctionHandle = functionAndTypeManager.resolveFunction(Optional.of(context.getSession().getSessionFunctions()), + context.getSession().getTransactionId(), + functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("count")), + ImmutableList.of()); + + FunctionMetadata countFunctionMetadata = functionAndTypeManager.getFunctionMetadata(countFunctionHandle); + CallExpression countFunction = new CallExpression("count", 
countFunctionHandle, functionAndTypeManager.getType(countFunctionMetadata.getReturnType()), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context, metadata)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithVariables finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + else { + NodeWithVariables first = intermediateResultSources.get(0); + NodeWithVariables second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context, metadata); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithVariables joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context, metadata); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context, metadata); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. + // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + VariableReferenceExpression finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context, metadata); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.variableToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. + // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. 
+ ImmutableList.Builder newOrderings = ImmutableList.builder(); + newOrderings.add(new Ordering(finalRowNumberSymbol, ASC_NULLS_LAST)); + Optional finalOrderBy = Optional.of(new OrderingScheme(newOrderings.build())); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredVariables = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredVariables, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithVariables planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context) + { + String argumentName = argumentProperties.getArgumentName(); + + VariableReferenceExpression rowNumber = context.getVariableAllocator().newVariable(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputVariables().stream() + .collect(toImmutableMap(identity(), symbol -> rowNumber)); + + VariableReferenceExpression partitionSize = context.getVariableAllocator().newVariable(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. 
+ DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode innerWindow = new WindowNode( + source.getSourceLocation(), + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + PlanNode window = new WindowNode( + innerWindow.getSourceLocation(), + context.getIdAllocator().getNextId(), + innerWindow, + specification, + ImmutableMap.of( + partitionSize, new WindowNode.Function(countFunction, FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithVariables(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithVariables copartition( + List sourceList, + CallExpression rowNumberFunction, + CallExpression countFunction, + Context context, + Metadata metadata) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? -1 : 1)) + .collect(toImmutableList()); + + NodeWithVariables first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithVariables second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context, metadata); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithVariables copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + NodeWithVariables next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context, metadata); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context, metadata); + } + + private static JoinedNodes copartition(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. 
+ checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + + Optional copartitionConjuncts = Streams.zip( + left.partitionBy.stream(), + right.partitionBy.stream(), + (leftColumn, rightColumn) -> new CallExpression("NOT", + functionResolution.notFunction(), + BOOLEAN, + ImmutableList.of( + new CallExpression(IS_DISTINCT_FROM.name(), + functionResolution.comparisonFunction(IS_DISTINCT_FROM, INTEGER, INTEGER), + BOOLEAN, + ImmutableList.of(leftColumn, rightColumn))))) + .map(expr -> expr) + .reduce((expr, conjunct) -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(expr, conjunct))); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. + // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... + // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + + SpecialFormExpression orExpression = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + RowExpression joinCondition = copartitionConjuncts.map( + conjunct -> new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of(conjunct, orExpression))) + .orElse(orExpression); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. 
+ // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context, + Metadata metadata) + { + checkArgument(copartitionedNodes.leftPartitionBy().size() == copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftRowNumber(), + copartitionedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + copartitionedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + copartitionedNodes.leftPartitionSize(), + copartitionedNodes.rightPartitionSize())); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. + ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + VariableReferenceExpression leftColumn = copartitionedNodes.leftPartitionBy().get(i); + VariableReferenceExpression rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getVariableAllocator().getVariables().get(leftColumn.getName()); + + VariableReferenceExpression joinedColumn = context.getVariableAllocator().newVariable("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, coalesce(leftColumn, rightColumn)); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putAll( + copartitionedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithVariables left, NodeWithVariables right, Context context, Metadata metadata) + { + // Align rows from left and right source according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. 
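+        // For example, a left partition of size 3 joined with a right partition of size 1 produces the row number
+        // pairs (1, 1), (2, 1) and (3, 1): the first pair through the equality disjunct, the other two through the
+        // disjunct that matches outstanding left rows to the first right row.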
+ // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + + FunctionResolution functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + RowExpression joinCondition = new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.rowNumber)), + new SpecialFormExpression(SpecialFormExpression.Form.OR, + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, right.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, new ConstantExpression(1L, BIGINT))))), + new SpecialFormExpression(SpecialFormExpression.Form.AND, + BOOLEAN, + ImmutableList.of( + new CallExpression(GREATER_THAN.name(), + functionResolution.comparisonFunction(GREATER_THAN, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(right.rowNumber, left.partitionSize)), + new CallExpression(EQUAL.name(), + functionResolution.comparisonFunction(EQUAL, BIGINT, BIGINT), + BOOLEAN, + ImmutableList.of(left.rowNumber, new ConstantExpression(1L, BIGINT))))))))); + JoinType joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + Optional.empty(), + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + Stream.concat(left.node().getOutputVariables().stream(), + right.node().getOutputVariables().stream()) + .collect(Collectors.toList()), + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithVariables appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context, Metadata metadata) + { + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. 
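+        // For example, the combined row number of (3, 1) is 3, and of (2, null) is 2: the bigger, non-null row number always wins.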
+ VariableReferenceExpression joinedRowNumber = context.getVariableAllocator().newVariable("combined_row_number", BIGINT); + RowExpression rowNumberExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftRowNumber(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightRowNumber(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftRowNumber(), + joinedNodes.rightRowNumber())); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. + VariableReferenceExpression joinedPartitionSize = context.getVariableAllocator().newVariable("combined_partition_size", BIGINT); + RowExpression partitionSizeExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + GREATER_THAN.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.GREATER_THAN, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of( + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.leftPartitionSize(), + new ConstantExpression(-1L, BIGINT)), + new SpecialFormExpression( + COALESCE, + BIGINT, + joinedNodes.rightPartitionSize(), + new ConstantExpression(-1L, BIGINT)))), + joinedNodes.leftPartitionSize(), + joinedNodes.rightPartitionSize())); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putAll( + joinedNodes.joinedNode().getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithVariables(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set variables, VariableReferenceExpression referenceSymbol, Context context, Metadata metadata) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putAll( + node.getOutputVariables().stream() + .collect(toImmutableMap(v -> v, v -> v))); + + ImmutableMap.Builder variablesToMarkers = ImmutableMap.builder(); + + for (VariableReferenceExpression variable : variables) { + VariableReferenceExpression marker = context.getVariableAllocator().newVariable("marker", BIGINT); + variablesToMarkers.put(variable, marker); + RowExpression ifExpression = new SpecialFormExpression( + IF, + BIGINT, + ImmutableList.of( + new CallExpression( + EQUAL.name(), + metadata.getFunctionAndTypeManager().resolveOperator( + OperatorType.EQUAL, + fromTypes(BIGINT, BIGINT)), + BOOLEAN, + ImmutableList.of(variable, referenceSymbol)), + variable, + new ConstantExpression(null, 
BIGINT))); + assignments.put(marker, ifExpression); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, variablesToMarkers.buildOrThrow()); + } + + private static class SourceWithProperties + { + private final PlanNode source; + private final TableArgumentProperties properties; + + public SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + this.source = requireNonNull(source, "source is null"); + this.properties = requireNonNull(properties, "properties is null"); + } + + public PlanNode source() + { + return source; + } + + public TableArgumentProperties properties() + { + return properties; + } + } + + public static final class NodeWithVariables + { + private final PlanNode node; + private final VariableReferenceExpression rowNumber; + private final VariableReferenceExpression partitionSize; + private final List partitionBy; + private final boolean pruneWhenEmpty; + private final Map rowNumberSymbolsMapping; + + public NodeWithVariables(PlanNode node, VariableReferenceExpression rowNumber, VariableReferenceExpression partitionSize, + List partitionBy, boolean pruneWhenEmpty, + Map rowNumberSymbolsMapping) + { + this.node = requireNonNull(node, "node is null"); + this.rowNumber = requireNonNull(rowNumber, "rowNumber is null"); + this.partitionSize = requireNonNull(partitionSize, "partitionSize is null"); + this.partitionBy = ImmutableList.copyOf(partitionBy); + this.pruneWhenEmpty = pruneWhenEmpty; + this.rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); + } + + public PlanNode node() + { + return node; + } + + public VariableReferenceExpression rowNumber() + { + return rowNumber; + } + + public VariableReferenceExpression partitionSize() + { + return partitionSize; + } + + public List partitionBy() + { + return partitionBy; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public Map rowNumberSymbolsMapping() + { + return rowNumberSymbolsMapping; + } + } + + private static class JoinedNodes + { + private final PlanNode joinedNode; + private final VariableReferenceExpression leftRowNumber; + private final VariableReferenceExpression leftPartitionSize; + private final List leftPartitionBy; + private final boolean leftPruneWhenEmpty; + private final Map leftRowNumberSymbolsMapping; + private final VariableReferenceExpression rightRowNumber; + private final VariableReferenceExpression rightPartitionSize; + private final List rightPartitionBy; + private final boolean rightPruneWhenEmpty; + private final Map rightRowNumberSymbolsMapping; + + public JoinedNodes( + PlanNode joinedNode, + VariableReferenceExpression leftRowNumber, + VariableReferenceExpression leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + VariableReferenceExpression rightRowNumber, + VariableReferenceExpression rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + this.joinedNode = requireNonNull(joinedNode, "joinedNode is null"); + this.leftRowNumber = requireNonNull(leftRowNumber, "leftRowNumber is null"); + this.leftPartitionSize = requireNonNull(leftPartitionSize, "leftPartitionSize is null"); + this.leftPartitionBy = ImmutableList.copyOf(requireNonNull(leftPartitionBy, "leftPartitionBy is null")); + this.leftPruneWhenEmpty = leftPruneWhenEmpty; + this.leftRowNumberSymbolsMapping = 
ImmutableMap.copyOf(requireNonNull(leftRowNumberSymbolsMapping, "leftRowNumberSymbolsMapping is null")); + this.rightRowNumber = requireNonNull(rightRowNumber, "rightRowNumber is null"); + this.rightPartitionSize = requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + this.rightPartitionBy = ImmutableList.copyOf(requireNonNull(rightPartitionBy, "rightPartitionBy is null")); + this.rightPruneWhenEmpty = rightPruneWhenEmpty; + this.rightRowNumberSymbolsMapping = ImmutableMap.copyOf(requireNonNull(rightRowNumberSymbolsMapping, "rightRowNumberSymbolsMapping is null")); + } + + public PlanNode joinedNode() + { + return joinedNode; + } + public VariableReferenceExpression leftRowNumber() + { + return leftRowNumber; + } + public VariableReferenceExpression leftPartitionSize() + { + return leftPartitionSize; + } + public List leftPartitionBy() + { + return leftPartitionBy; + } + public boolean leftPruneWhenEmpty() + { + return leftPruneWhenEmpty; + } + public Map leftRowNumberSymbolsMapping() + { + return leftRowNumberSymbolsMapping; + } + public VariableReferenceExpression rightRowNumber() + { + return rightRowNumber; + } + public VariableReferenceExpression rightPartitionSize() + { + return rightPartitionSize; + } + public List rightPartitionBy() + { + return rightPartitionBy; + } + public boolean rightPruneWhenEmpty() + { + return rightPruneWhenEmpty; + } + public Map rightRowNumberSymbolsMapping() + { + return rightRowNumberSymbolsMapping; + } + } + + private static class NodeWithMarkers + { + private final PlanNode node; + private final Map variableToMarker; + + public NodeWithMarkers(PlanNode node, Map variableToMarker) + { + this.node = requireNonNull(node, "node is null"); + this.variableToMarker = ImmutableMap.copyOf(requireNonNull(variableToMarker, "symbolToMarker is null")); + } + + public PlanNode node() + { + return node; + } + + public Map variableToMarker() + { + return variableToMarker; + } + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..cf8caeda8e9df --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorColumns.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.VariableAllocator; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; + +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** + * TableFunctionProcessorNode has two kinds of outputs: + * - proper outputs, which are the columns produced by the table function, + * - pass-through outputs, which are the columns copied from table arguments. + * This rule filters out unreferenced pass-through symbols. + * Unreferenced proper symbols are not pruned, because there is currently no way + * to communicate to the table function the request for not producing certain columns. + * // TODO prune table function's proper outputs + */ +public class PruneTableFunctionProcessorColumns + extends ProjectOffPushDownRule +{ + public PruneTableFunctionProcessorColumns() + { + super(tableFunctionProcessor()); + } + + @Override + protected Optional pushDownProjectOff(PlanNodeIdAllocator idAllocator, VariableAllocator variableAllocator, TableFunctionProcessorNode node, Set referencedOutputs) + { + List prunedPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(sourceSpecification -> { + List prunedPassThroughColumns = sourceSpecification.getColumns().stream() + .filter(column -> referencedOutputs.contains(column.getOutputVariables())) + .collect(toImmutableList()); + return new TableFunctionNode.PassThroughSpecification(sourceSpecification.isDeclaredAsPassThrough(), prunedPassThroughColumns); + }) + .collect(toImmutableList()); + + int originalPassThroughCount = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + int prunedPassThroughCount = prunedPassThroughSpecifications.stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .mapToInt(List::size) + .sum(); + + if (originalPassThroughCount == prunedPassThroughCount) { + return Optional.empty(); + } + + return Optional.of(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + node.getSource(), + node.isPruneWhenEmpty(), + prunedPassThroughSpecifications, + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle())); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..ee59afe81fcab --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.planner.plan.TableFunctionNode;
+import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+
+import static com.facebook.presto.sql.planner.iterative.rule.Util.restrictOutputs;
+import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor;
+import static com.google.common.collect.Maps.filterKeys;
+
+/**
+ * This rule prunes unreferenced outputs of the source of TableFunctionProcessorNode.
+ * First, it extracts all symbols required for:
+ * - pass-through
+ * - table function computation
+ * - partitioning and ordering (including the hashSymbol)
+ * Next, a mapping of input symbols to marker symbols is updated
+ * so that it only contains mappings for the required symbols.
+ * Last, all the remaining marker symbols are added to the collection
+ * of required symbols.
+ * Any source output symbols not included in the required symbols
+ * can be pruned.
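+ * Example (illustrative symbol names): if the source outputs a, b, c with markers marker_a, marker_b, marker_c,
+ * and only a is required for pass-through, partitioning, ordering or the table function computation,
+ * then the marker mapping is narrowed to a -> marker_a, and b, c, marker_b and marker_c
+ * are pruned from the source outputs.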
+ */ +public class PruneTableFunctionProcessorSourceColumns + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (!node.getSource().isPresent()) { + return Result.empty(); + } + + ImmutableSet.Builder requiredInputs = ImmutableSet.builder(); + + node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(requiredInputs::add); + + node.getRequiredVariables() + .forEach(requiredInputs::addAll); + + node.getSpecification().ifPresent(specification -> { + requiredInputs.addAll(specification.getPartitionBy()); + specification.getOrderingScheme().ifPresent(orderingScheme -> requiredInputs.addAll(orderingScheme.getOrderByVariables())); + }); + + node.getHashSymbol().ifPresent(requiredInputs::add); + + Optional> updatedMarkerSymbols = node.getMarkerVariables() + .map(mapping -> filterKeys(mapping, requiredInputs.build()::contains)); + + updatedMarkerSymbols.ifPresent(mapping -> requiredInputs.addAll(mapping.values())); + + return restrictOutputs(context.getIdAllocator(), node.getSource().orElseThrow(NoSuchElementException::new), requiredInputs.build()) + .map(child -> Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + updatedMarkerSymbols, + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle()))) + .orElse(Result.empty()); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunction.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunction.java new file mode 100644 index 0000000000000..c92b7af9fb93e --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RemoveRedundantTableFunction.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.google.common.collect.ImmutableList; + +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMost; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; + +/** + * Table function can take multiple table arguments. Each argument is either "prune when empty" or "keep when empty". + * "Prune when empty" means that if this argument has no rows, the function result is empty, so the function can be + * removed from the plan, and replaced with empty values. + * "Keep when empty" means that even if the argument has no rows, the function should still be executed, and it can + * return a non-empty result. + * All the table arguments are combined into a single source of a TableFunctionProcessorNode. If either argument is + * "prune when empty", the overall result is "prune when empty". This rule removes a redundant TableFunctionProcessorNode + * based on the "prune when empty" property. + */ +public class RemoveRedundantTableFunction + implements Rule +{ + private static final Pattern PATTERN = tableFunctionProcessor(); + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) + { + if (node.isPruneWhenEmpty() && node.getSource().isPresent()) { + if (isAtMost(node.getSource().orElseThrow(NoSuchElementException::new), context.getLookup(), 0)) { + return Result.ofPlanNode( + new ValuesNode(node.getSourceLocation(), + node.getId(), + node.getOutputVariables(), + ImmutableList.of(), + Optional.empty())); + } + } + + return Result.empty(); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java index 2418377c7ac53..d280a697866d8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteTableFunctionToTableScan.java @@ -23,7 +23,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.iterative.Rule; -import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.google.common.collect.ImmutableMap; import java.util.List; @@ -31,29 +31,32 @@ import static com.facebook.presto.matching.Pattern.empty; import static com.facebook.presto.sql.planner.plan.Patterns.sources; -import static com.facebook.presto.sql.planner.plan.Patterns.tableFunction; +import static com.facebook.presto.sql.planner.plan.Patterns.tableFunctionProcessor; import static com.google.common.base.Preconditions.checkState; import static java.util.Objects.requireNonNull; /* - * This process converts connector-resolvable TableFunctionNodes into equivalent - * TableScanNodes by invoking the connector’s applyTableFunction() 
during planning. - * It allows table-valued functions whose results can be expressed as a ConnectorTableHandle - * to be treated like regular scans and benefit from normal scan optimizations. + * This rule converts connector-resolvable TableFunctionProcessorNodes into equivalent + * TableScanNodes by invoking the connector's applyTableFunction() method during query planning. + * + * It enables table-valued functions whose results can be represented as a ConnectorTableHandle + * to be treated like regular table scans, allowing them to benefit from standard scan optimizations. * * Example: * Before Transformation: * TableFunction(my_function(arg1, arg2)) * * After Transformation: - * TableScan(my_function(arg1, arg2)).applyTableFunction_tableHandle) - * assignments: {outputVar1 -> my_function(arg1, arg2)).applyTableFunction_colHandle1, - * outputVar2 -> my_function(arg1, arg2)).applyTableFunction_colHandle2} + * TableScan(my_function(arg1, arg2)) + * assignments: { + * outputVar1 -> my_function(arg1, arg2)_colHandle1, + * outputVar2 -> my_function(arg1, arg2)_colHandle2 + * } */ public class RewriteTableFunctionToTableScan - implements Rule + implements Rule { - private static final Pattern PATTERN = tableFunction() + private static final Pattern PATTERN = tableFunctionProcessor() .with(empty(sources())); private final Metadata metadata; @@ -64,32 +67,32 @@ public RewriteTableFunctionToTableScan(Metadata metadata) } @Override - public Pattern getPattern() + public Pattern getPattern() { return PATTERN; } @Override - public Result apply(TableFunctionNode tableFunctionNode, Captures captures, Context context) + public Result apply(TableFunctionProcessorNode node, Captures captures, Context context) { - Optional> result = metadata.applyTableFunction(context.getSession(), tableFunctionNode.getHandle()); + Optional> result = metadata.applyTableFunction(context.getSession(), node.getHandle()); if (!result.isPresent()) { return Result.empty(); } List columnHandles = result.get().getColumnHandles(); - checkState(tableFunctionNode.getOutputVariables().size() == columnHandles.size(), "returned table does not match the node's output"); + checkState(node.getOutputVariables().size() == columnHandles.size(), "returned table does not match the node's output"); ImmutableMap.Builder assignments = ImmutableMap.builder(); for (int i = 0; i < columnHandles.size(); i++) { - assignments.put(tableFunctionNode.getOutputVariables().get(i), columnHandles.get(i)); + assignments.put(node.getOutputVariables().get(i), columnHandles.get(i)); } return Result.ofPlanNode(new TableScanNode( - tableFunctionNode.getSourceLocation(), - tableFunctionNode.getId(), + node.getSourceLocation(), + node.getId(), result.get().getTableHandle(), - tableFunctionNode.getOutputVariables(), + node.getOutputVariables(), assignments.buildOrThrow(), TupleDomain.all(), TupleDomain.all(), Optional.empty())); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 61b53b70905f0..a9aa22a2f4be9 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -78,6 +78,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import 
com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.annotations.VisibleForTesting; import com.google.common.cache.CacheBuilder; @@ -98,6 +99,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -415,7 +417,59 @@ public PlanWithProperties visitWindow(WindowNode node, PreferredProperties prefe @Override public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredProperties preferredProperties) { - throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + if (!node.getSource().isPresent()) { + return new PlanWithProperties(node, deriveProperties(node, ImmutableList.of())); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. A single source with row semantics can be distributed arbitrarily. + PlanWithProperties child = planChild(node, PreferredProperties.any()); + return rebaseAndDeriveProperties(node, child); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification().orElseThrow(NoSuchElementException::new) + .getOrderingScheme() + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + + PlanWithProperties child = planChild(node, PreferredProperties.partitionedWithLocal(ImmutableSet.copyOf(partitionBy), desiredProperties)); + + // TODO do not gather if already gathered + if (!node.isPruneWhenEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else if (!isStreamPartitionedOn(child.getProperties(), partitionBy) && + !isNodePartitionedOn(child.getProperties(), partitionBy)) { + if (partitionBy.isEmpty()) { + child = withDerivedProperties( + gatheringExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode()), + child.getProperties()); + } + else { + child = withDerivedProperties( + partitionedExchange(idAllocator.getNextId(), REMOTE_STREAMING, child.getNode(), Partitioning.create(FIXED_HASH_DISTRIBUTION, partitionBy), node.getHashSymbol()), + child.getProperties()); + } + } + + return rebaseAndDeriveProperties(node, child); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index c76c6b6252771..4b29dbe1073b2 100644 --- 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.VariableAllocator; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -57,6 +58,8 @@ import com.facebook.presto.sql.planner.plan.LateralJoinNode; import com.facebook.presto.sql.planner.plan.RowNumberNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.google.common.collect.ImmutableList; @@ -65,6 +68,7 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; @@ -110,6 +114,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -474,6 +479,87 @@ public PlanWithProperties visitWindow(WindowNode node, StreamPreferredProperties return deriveProperties(result, child.getProperties()); } + @Override + public PlanWithProperties visitTableFunction(TableFunctionNode node, StreamPreferredProperties parentPreferences) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, StreamPreferredProperties parentPreferences) + { + if (!node.getSource().isPresent()) { + return deriveProperties(node, ImmutableList.of()); + } + + if (!node.getSpecification().isPresent()) { + // node.getSpecification.isEmpty() indicates that there were no sources or a single source with row semantics. + // The case of no sources was addressed above. + // The case of a single source with row semantics is addressed here. 
Source's properties do not hold after the TableFunctionProcessorNode + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), StreamPreferredProperties.any(), StreamPreferredProperties.any()); + return rebaseAndDeriveProperties(node, ImmutableList.of(child)); + } + + List partitionBy = node.getSpecification().orElseThrow(NoSuchElementException::new).getPartitionBy(); + StreamPreferredProperties childRequirements; + if (!node.isPruneWhenEmpty()) { + childRequirements = singleStream(); + } + else { + childRequirements = parentPreferences + .constrainTo(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()) + .withDefaultParallelism(session) + .withPartitioning(partitionBy); + } + + PlanWithProperties child = planAndEnforce(node.getSource().orElseThrow(NoSuchElementException::new), childRequirements, childRequirements); + + List> desiredProperties = new ArrayList<>(); + if (!partitionBy.isEmpty()) { + desiredProperties.add(new GroupingProperty<>(partitionBy)); + } + node.getSpecification() + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .ifPresent(orderingScheme -> + orderingScheme.getOrderByVariables().stream() + .map(variable -> new SortingProperty<>(variable, orderingScheme.getOrdering(variable))) + .forEach(desiredProperties::add)); + Iterator>> matchIterator = LocalProperties.match(child.getProperties().getLocalProperties(), desiredProperties).iterator(); + + Set prePartitionedInputs = ImmutableSet.of(); + if (!partitionBy.isEmpty()) { + Optional> groupingRequirement = matchIterator.next(); + Set unPartitionedInputs = groupingRequirement.map(LocalProperty::getColumns).orElse(ImmutableSet.of()); + prePartitionedInputs = partitionBy.stream() + .filter(symbol -> !unPartitionedInputs.contains(symbol)) + .collect(toImmutableSet()); + } + + int preSortedOrderPrefix = 0; + if (prePartitionedInputs.equals(ImmutableSet.copyOf(partitionBy))) { + while (matchIterator.hasNext() && !matchIterator.next().isPresent()) { + preSortedOrderPrefix++; + } + } + + TableFunctionProcessorNode result = new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(child.getNode()), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + prePartitionedInputs, + preSortedOrderPrefix, + node.getHashSymbol(), + node.getHandle()); + + return deriveProperties(result, child.getProperties()); + } + @Override public PlanWithProperties visitDelete(DeleteNode node, StreamPreferredProperties parentPreferences) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java index b8b0089153b7b..35d4d0280a2cb 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PropertyDerivations.java @@ -25,6 +25,7 @@ import com.facebook.presto.spi.SortingProperty; import com.facebook.presto.spi.UniqueProperty; import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.DeleteNode; import com.facebook.presto.spi.plan.DistinctLimitNode; import com.facebook.presto.spi.plan.EquiJoinClause; @@ -70,6 +71,8 @@ import 
com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -108,6 +111,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.lang.String.format; import static java.util.stream.Collectors.toMap; public class PropertyDerivations @@ -284,6 +288,48 @@ public ActualProperties visitWindow(WindowNode node, List inpu .build(); } + @Override + public ActualProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public ActualProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + ImmutableList.Builder> localProperties = ImmutableList.builder(); + + if (node.getSource().isPresent()) { + ActualProperties properties = Iterables.getOnlyElement(inputProperties); + + // Only the partitioning properties of the source are passed-through, because the pass-through mechanism preserves the partitioning values. + // Sorting properties might be broken because input rows can be shuffled or nulls can be inserted as the result of pass-through. + // Constant properties might be broken because nulls can be inserted as the result of pass-through. + if (!node.getPrePartitioned().isEmpty()) { + GroupingProperty prePartitionedProperty = new GroupingProperty<>(node.getPrePartitioned()); + for (LocalProperty localProperty : properties.getLocalProperties()) { + if (!prePartitionedProperty.isSimplifiedBy(localProperty)) { + break; + } + localProperties.add(localProperty); + } + } + } + + List partitionBy = node.getSpecification() + .map(DataOrganizationSpecification::getPartitionBy) + .orElse(ImmutableList.of()); + if (!partitionBy.isEmpty()) { + localProperties.add(new GroupingProperty<>(partitionBy)); + } + + return ActualProperties.builder() + .local(localProperties.build()) + .build() + // Crop properties to output columns. + .translateVariable(variable -> node.getOutputVariables().contains(variable) ? 
Optional.of(variable) : Optional.empty()); + } + @Override public ActualProperties visitGroupId(GroupIdNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index d9fd049555be3..327b48e9260d3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -65,6 +65,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -499,6 +500,26 @@ public PlanNode visitWindow(WindowNode node, RewriteContext> context) + { + return node.getSource().map(source -> new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(context.rewrite(source, ImmutableSet.copyOf(source.getOutputVariables()))), + node.isPruneWhenEmpty(), + node.getPassThroughSpecifications(), + node.getRequiredVariables(), + node.getMarkerVariables(), + node.getSpecification(), + node.getPrePartitioned(), + node.getPreSorted(), + node.getHashSymbol(), + node.getHandle() + )).orElse(node); + } + @Override public PlanNode visitTableScan(TableScanNode node, RewriteContext> context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java index ffd4806665c2c..6e05bc4bcae29 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/QueryCardinalityUtil.java @@ -19,6 +19,7 @@ import com.facebook.presto.spi.plan.PlanNode; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.ValuesNode; +import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.sql.planner.iterative.GroupReference; import com.facebook.presto.sql.planner.iterative.Lookup; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; @@ -144,6 +145,12 @@ public Range visitValues(ValuesNode node, Void context) return Range.singleton((long) node.getRows().size()); } + @Override + public Range visitWindow(WindowNode node, Void context) + { + return node.getSource().accept(this, null); + } + @Override public Range visitOffset(OffsetNode node, Void context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index a56b3c773d6ed..42367d13be7f8 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -60,6 +60,8 @@ import 
com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -68,11 +70,13 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -582,6 +586,32 @@ public StreamProperties visitWindow(WindowNode node, List inpu return Iterables.getOnlyElement(inputProperties); } + @Override + public StreamProperties visitTableFunction(TableFunctionNode node, List inputProperties) + { + throw new IllegalStateException(format("Unexpected node: TableFunctionNode (%s)", node.getName())); + } + + @Override + public StreamProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, List inputProperties) + { + if (!node.getSource().isPresent()) { + return StreamProperties.singleStream(); + } + + StreamProperties properties = Iterables.getOnlyElement(inputProperties); + + Set passThroughInputs = Sets.intersection(ImmutableSet.copyOf(node.getSource().orElseThrow(NoSuchElementException::new).getOutputVariables()), ImmutableSet.copyOf(node.getOutputVariables())); + StreamProperties translatedProperties = properties.translate(column -> { + if (passThroughInputs.contains(column)) { + return Optional.of(column); + } + return Optional.empty(); + }); + + return translatedProperties.unordered(true); + } + @Override public StreamProperties visitRowNumber(RowNumberNode node, List inputProperties) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java index 9805efad17939..e6bba0b956261 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/SymbolMapper.java @@ -20,6 +20,7 @@ import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.AggregationNode.Aggregation; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; import com.facebook.presto.spi.plan.Ordering; import com.facebook.presto.spi.plan.OrderingScheme; import com.facebook.presto.spi.plan.PartitioningScheme; @@ -37,6 +38,8 @@ import com.facebook.presto.sql.planner.Symbol; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.ExpressionRewriter; @@ -51,12 +54,14 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import 
java.util.Set; import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY; import static com.facebook.presto.spi.plan.AggregationNode.groupingSets; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -157,6 +162,89 @@ else if (orderingMap.get(translated) != orderingScheme.getOrdering(variable)) { return new OrderingScheme(orderBy.build().stream().map(variable -> new Ordering(variable, orderingMap.get(variable))).collect(toImmutableList())); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newOrderings = ImmutableList.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + VariableReferenceExpression variable = orderingScheme.getOrderBy().get(i).getVariable(); + VariableReferenceExpression canonical = map(variable); + if (added.add(canonical)) { + newOrderings.add(new Ordering(canonical, orderingScheme.getOrdering(variable))); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newOrderings.build()), newPreSorted); + } + + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + // rewrite and deduplicate pass-through specifications + // note: Potentially, pass-through symbols from different sources might be recognized as semantically identical, and rewritten + // to the same symbol. Currently, we retrieve the first occurrence of a symbol, and skip all the following occurrences. + // For better performance, we could pick the occurrence with "isPartitioningColumn" property, since the pass-through mechanism + // is more efficient for partitioning columns which are guaranteed to be constant within partition. + // TODO choose a partitioning column to be retrieved while deduplicating + ImmutableList.Builder newPassThroughSpecifications = ImmutableList.builder(); + Set newPassThroughVariables = new HashSet<>(); + for (TableFunctionNode.PassThroughSpecification specification : node.getPassThroughSpecifications()) { + ImmutableList.Builder newColumns = ImmutableList.builder(); + for (TableFunctionNode.PassThroughColumn column : specification.getColumns()) { + VariableReferenceExpression newVariable = map(column.getOutputVariables()); + if (newPassThroughVariables.add(newVariable)) { + newColumns.add(new TableFunctionNode.PassThroughColumn(newVariable, column.isPartitioningColumn())); + } + } + newPassThroughSpecifications.add(new TableFunctionNode.PassThroughSpecification(specification.isDeclaredAsPassThrough(), newColumns.build())); + } + + // rewrite required symbols without deduplication. 
the table function expects specific input layout + List> newRequiredVariables = node.getRequiredVariables().stream() + .map(list -> list.stream() + .map(this::map) + .collect(toImmutableList())) + .collect(toImmutableList()); + + // rewrite and deduplicate marker mapping + Optional> newMarkerVariables = node.getMarkerVariables() + .map(mapping -> mapping.entrySet().stream() + .collect(toImmutableMap( + entry -> map(entry.getKey()), + entry -> map(entry.getValue()), + (first, second) -> { + checkState(first.equals(second), "Ambiguous marker symbols: %s and %s", first, second); + return first; + }))); + + // rewrite and deduplicate specification + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs().stream() + .map(this::map) + .collect(toImmutableList()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications.build(), + newRequiredVariables, + newMarkerVariables, + newSpecification.map(SpecificationWithPreSortedPrefix::getSpecification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::getPreSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + public AggregationNode map(AggregationNode node, PlanNode source) { return map(node, source, node.getId()); @@ -335,6 +423,25 @@ private List mapAndDistinctSymbol(List outputs) return builder.build(); } + private SpecificationWithPreSortedPrefix mapAndDistinct(DataOrganizationSpecification specification, int preSorted) + { + Optional newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getOrderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::getPreSorted).orElse(preSorted)); + } + + DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) + { + return new DataOrganizationSpecification( + mapAndDistinctVariable(specification.getPartitionBy()), + specification.getOrderingScheme().map(this::map)); + } + private List mapAndDistinctVariable(List outputs) { Set added = new HashSet<>(); @@ -379,4 +486,48 @@ public void put(VariableReferenceExpression from, VariableReferenceExpression to mappingsBuilder.put(from, to); } } + + private static class OrderingSchemeWithPreSortedPrefix + { + private final OrderingScheme orderingScheme; + private final int preSorted; + + public OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + + public OrderingScheme getOrderingScheme() + { + return orderingScheme; + } + + public int getPreSorted() + { + return preSorted; + } + } + + private static class SpecificationWithPreSortedPrefix + { + private final DataOrganizationSpecification specification; + private final int preSorted; + + public SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + + public 
DataOrganizationSpecification getSpecification() + { + return specification; + } + + public int getPreSorted() + { + return preSorted; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java index 91a92107f9f3c..62aa2ce9b3b5c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -75,6 +75,7 @@ import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -83,6 +84,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -158,6 +160,11 @@ private Rewriter(TypeProvider types, FunctionAndTypeManager functionAndTypeManag this.warningCollector = warningCollector; } + public Map getMapping() + { + return mapping; + } + @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { @@ -481,18 +488,93 @@ public PlanNode visitTableFinish(TableFinishNode node, RewriteContext cont @Override public PlanNode visitTableFunction(TableFunctionNode node, RewriteContext context) { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap()) + .orElseGet(HashMap::new); + + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + List newProperOutputs = node.getOutputVariables().stream() + .map(mapper::map) + .collect(toImmutableList()); + + ImmutableList.Builder newSources = ImmutableList.builder(); + ImmutableList.Builder newTableArgumentProperties = ImmutableList.builder(); + + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode newSource = node.getSources().get(i).accept(this, context); + newSources.add(newSource); + + SymbolMapper inputMapper = new SymbolMapper(new HashMap<>(), warningCollector); + + TableFunctionNode.TableArgumentProperties properties = node.getTableArgumentProperties().get(i); + + Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + TableFunctionNode.PassThroughSpecification newPassThroughSpecification = new TableFunctionNode.PassThroughSpecification( + properties.getPassThroughSpecification().isDeclaredAsPassThrough(), + properties.getPassThroughSpecification().getColumns().stream() + .map(column -> new TableFunctionNode.PassThroughColumn( + inputMapper.map(column.getOutputVariables()), + column.isPartitioningColumn())) + .collect(toImmutableList())); + newTableArgumentProperties.add(new TableFunctionNode.TableArgumentProperties( + properties.getArgumentName(), + properties.isRowSemantics(), + properties.isPruneWhenEmpty(), + newPassThroughSpecification, + properties.getRequiredColumns().stream() + .map(inputMapper::map) + .collect(toImmutableList()), + newSpecification)); + } + return new TableFunctionNode( - 
node.getSourceLocation(), node.getId(), - Optional.empty(), node.getName(), node.getArguments(), - node.getOutputVariables(), - node.getSources(), - node.getTableArgumentProperties(), + newProperOutputs, + newSources.build(), + newTableArgumentProperties.build(), + node.getCopartitioningLists(), node.getHandle()); } + @Override + public PlanNode visitTableFunctionProcessor(TableFunctionProcessorNode node, RewriteContext context) + { + if (!node.getSource().isPresent()) { + Map mappings = + Optional.ofNullable(context.get()) + .map(c -> new HashMap()) + .orElseGet(HashMap::new); + SymbolMapper mapper = new SymbolMapper(mappings, warningCollector); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs().stream() + .map(mapper::map) + .collect(toImmutableList()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()); + } + + PlanNode rewrittenSource = node.getSource().get().accept(this, context); + Map mappings = ((Rewriter) context.getNodeRewriter()).getMapping(); + SymbolMapper mapper = new SymbolMapper(mappings, types, warningCollector); + + return mapper.map(node, rewrittenSource); + } + @Override public PlanNode visitRowNumber(RowNumberNode node, RewriteContext context) { diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java index b33dfc48938d7..34ddf755e6e35 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/InternalPlanVisitor.java @@ -126,4 +126,9 @@ public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java index f1a00c4b1a128..0f74ec388347a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/Patterns.java @@ -234,6 +234,11 @@ public static Pattern tableFunction() return typeOf(TableFunctionNode.class); } + public static Pattern tableFunctionProcessor() + { + return typeOf(TableFunctionProcessorNode.class); + } + public static Pattern rowNumber() { return typeOf(RowNumberNode.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java index 22d4f18e42ff9..f87c1a1bba5c5 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/SimplePlanRewriter.java @@ -61,6 +61,11 @@ public C get() return userContext; } + public SimplePlanRewriter getNodeRewriter() + { + return nodeRewriter; + } + /** * Invoke the rewrite logic recursively on children of the given node and swap it * out with an identical copy with the rewritten children diff --git 
a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java index 97892523498c0..8838e82b48c91 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionNode.java @@ -22,13 +22,17 @@ import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; @Immutable @@ -40,6 +44,7 @@ public class TableFunctionNode private final List outputVariables; private final List sources; private final List tableArgumentProperties; + private final List> copartitioningLists; private final TableFunctionHandle handle; @JsonCreator @@ -50,9 +55,10 @@ public TableFunctionNode( @JsonProperty("outputVariables") List outputVariables, @JsonProperty("sources") List sources, @JsonProperty("tableArgumentProperties") List tableArgumentProperties, + @JsonProperty("copartitioningLists") List> copartitioningLists, @JsonProperty("handle") TableFunctionHandle handle) { - this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, handle); + this(Optional.empty(), id, Optional.empty(), name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public TableFunctionNode( @@ -64,14 +70,18 @@ public TableFunctionNode( List outputVariables, List sources, List tableArgumentProperties, + List> copartitioningLists, TableFunctionHandle handle) { super(sourceLocation, id, statsEquivalentPlanNode); this.name = requireNonNull(name, "name is null"); - this.arguments = requireNonNull(arguments, "arguments is null"); - this.outputVariables = requireNonNull(outputVariables, "outputVariables is null"); - this.sources = requireNonNull(sources, "sources is null"); - this.tableArgumentProperties = requireNonNull(tableArgumentProperties, "tableArgumentProperties is null"); + this.arguments = ImmutableMap.copyOf(arguments); + this.outputVariables = ImmutableList.copyOf(outputVariables); + this.sources = ImmutableList.copyOf(sources); + this.tableArgumentProperties = ImmutableList.copyOf(tableArgumentProperties); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); this.handle = requireNonNull(handle, "handle is null"); } @@ -87,8 +97,23 @@ public Map getArguments() return arguments; } - @JsonProperty + @Override public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + variables.addAll(outputVariables); + + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + public List getProperOutputs() { return outputVariables; } @@ -99,6 +124,12 
@@ public List getTableArgumentProperties() return tableArgumentProperties; } + @JsonProperty + public List> getCopartitioningLists() + { + return copartitioningLists; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -122,35 +153,47 @@ public R accept(InternalPlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { checkArgument(sources.size() == newSources.size(), "wrong number of new children"); - return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, handle); + return new TableFunctionNode(getId(), name, arguments, outputVariables, newSources, tableArgumentProperties, copartitioningLists, handle); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, handle); + return new TableFunctionNode(getSourceLocation(), getId(), statsEquivalentPlanNode, name, arguments, outputVariables, sources, tableArgumentProperties, copartitioningLists, handle); } public static class TableArgumentProperties { + private final String argumentName; private final boolean rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; + private final PassThroughSpecification passThroughSpecification; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( + @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { + this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } + @JsonProperty + public String getArgumentName() + { + return argumentName; + } + @JsonProperty public boolean isRowSemantics() { @@ -164,15 +207,83 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() + { + return passThroughSpecification; + } + + @JsonProperty + public List getRequiredColumns() { - return passThroughColumns; + return requiredColumns; } @JsonProperty - public Optional specification() + public Optional getSpecification() { return specification; } } + + /** + * Specifies how columns from source tables are passed through to the output of a table function. + * This class manages both explicitly declared pass-through columns and partitioning columns + * that must be preserved in the output. 
+ */ + public static class PassThroughSpecification + { + private final boolean declaredAsPassThrough; + private final List columns; + + @JsonCreator + public PassThroughSpecification( + @JsonProperty("declaredAsPassThrough") boolean declaredAsPassThrough, + @JsonProperty("columns") List columns) + { + this.declaredAsPassThrough = declaredAsPassThrough; + this.columns = ImmutableList.copyOf(requireNonNull(columns, "columns is null")); + checkArgument( + declaredAsPassThrough || this.columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + + @JsonProperty + public boolean isDeclaredAsPassThrough() + { + return declaredAsPassThrough; + } + + @JsonProperty + public List getColumns() + { + return columns; + } + } + + public static class PassThroughColumn + { + private final VariableReferenceExpression outputVariables; + private final boolean isPartitioningColumn; + + @JsonCreator + public PassThroughColumn( + @JsonProperty("outputVariables") VariableReferenceExpression outputVariables, + @JsonProperty("partitioningColumn") boolean isPartitioningColumn) + { + this.outputVariables = requireNonNull(outputVariables, "symbol is null"); + this.isPartitioningColumn = isPartitioningColumn; + } + + @JsonProperty + public VariableReferenceExpression getOutputVariables() + { + return outputVariables; + } + + @JsonProperty + public boolean isPartitioningColumn() + { + return isPartitioningColumn; + } + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 0000000000000..5f6d71cf3e1e5 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,234 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.plan; + +import com.facebook.presto.metadata.TableFunctionHandle; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeId; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.presto.spi.function.table.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends InternalPlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final Optional source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? + + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredVariables; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. 
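The marker mapping described in the comment below pairs each source column with a helper column whose null / non-null status says whether the value at that position is valid for processing or pass-through. A simplified illustration of that convention, over plain Java lists instead of plan symbols and pages (the row layout and names here are assumptions, not the engine's actual format); the marker simply repeats the source value, since only its null / non-null status matters:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch only: lay rows from two sources side by side; each source gets a marker column
// that is non-null exactly at positions where that source contributed a row.
public class MarkerColumnSketch
{
    public static void main(String[] args)
    {
        List<String> source1 = List.of("a1", "a2", "a3"); // 3 rows
        List<String> source2 = List.of("b1");             // 1 row

        int combinedSize = Math.max(source1.size(), source2.size());
        List<Object[]> combined = new ArrayList<>();
        for (int i = 0; i < combinedSize; i++) {
            Object v1 = i < source1.size() ? source1.get(i) : null;
            Object v2 = i < source2.size() ? source2.get(i) : null;
            // row layout: [source1 value, source1 marker, source2 value, source2 marker]
            combined.add(new Object[] {v1, v1, v2, v2});
        }
        combined.forEach(row -> System.out.println(Arrays.toString(row)));
        // prints: [a1, a1, b1, b1], [a2, a2, null, null], [a3, a3, null, null]
    }
}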
+ private final Optional> markerVariables; + + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredVariables") List> requiredVariables, + @JsonProperty("markerVariables") Optional> markerVariables, + @JsonProperty("specification") Optional specification, + @JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(Optional.empty(), id, Optional.empty()); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.requiredVariables = requiredVariables.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerVariables = markerVariables.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(prePartitioned); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public Optional getSource() + { + return source; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredVariables() + { + return requiredVariables; + } + + @JsonProperty + public Optional> getMarkerVariables() + { + return markerVariables; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + 
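The constructor above enforces three invariants on the pre-partition / pre-sort metadata. Restated as a small standalone check over plain symbol names (a hypothetical helper, not part of the planner API):

import java.util.List;
import java.util.Set;

// Sketch only: 1) every pre-partitioned symbol must appear in the partitioning list,
// 2) the pre-sorted prefix cannot exceed the number of ordering symbols,
// 3) a non-zero pre-sorted prefix requires the partitioning to be fully pre-partitioned.
public class PreSortedInvariantsSketch
{
    static void check(List<String> partitionBy, List<String> orderBy, Set<String> prePartitioned, int preSorted)
    {
        if (!partitionBy.containsAll(prePartitioned)) {
            throw new IllegalArgumentException("all pre-partitioned symbols must be contained in the partitioning list");
        }
        if (preSorted > orderBy.size()) {
            throw new IllegalArgumentException("the number of pre-sorted symbols cannot be greater than the number of all ordering symbols");
        }
        if (preSorted > 0 && !prePartitioned.containsAll(partitionBy)) {
            throw new IllegalArgumentException("to specify pre-sorted symbols, all partitioning symbols must be pre-partitioned");
        }
    }

    public static void main(String[] args)
    {
        check(List.of("p1", "p2"), List.of("o1", "o2"), Set.of("p1", "p2"), 1); // valid
        try {
            check(List.of("p1", "p2"), List.of("o1"), Set.of("p1"), 1); // invalid: not fully pre-partitioned
        }
        catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}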
@JsonProperty + @Override + public List getSources() + { + return source.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public List getOutputVariables() + { + ImmutableList.Builder variables = ImmutableList.builder(); + + variables.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .forEach(variables::add); + + return variables.build(); + } + + @Override + public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) + { + return this; + } + + @Override + public PlanNode replaceChildren(List newSources) + { + Optional newSource = newSources.isEmpty() ? Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredVariables, markerVariables, specification, prePartitioned, preSorted, hashSymbol, handle); + } + + @Override + public R accept(InternalPlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } +} diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java index 2b7059d12e02a..f2ca735649b8a 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/planPrinter/PlanPrinter.java @@ -35,6 +35,9 @@ import com.facebook.presto.spi.SourceLocation; import com.facebook.presto.spi.TableHandle; import com.facebook.presto.spi.function.FunctionHandle; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; import com.facebook.presto.spi.plan.AbstractJoinNode; import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; @@ -99,6 +102,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -110,6 +114,7 @@ import com.google.common.base.Functions; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSortedMap; import com.google.common.collect.Iterables; @@ -118,11 +123,13 @@ import io.airlift.slice.Slice; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -135,6 +142,7 @@ import static com.facebook.presto.execution.StageInfo.getAllStages; import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters; import static com.facebook.presto.metadata.CastType.CAST; +import static 
com.facebook.presto.spi.function.table.DescriptorArgument.NULL_DESCRIPTOR; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.planner.SortExpressionExtractor.getSortExpressionContext; import static com.facebook.presto.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -148,10 +156,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.lang.String.format; import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.function.Function.identity; import static java.util.stream.Collectors.toList; public class PlanPrinter @@ -1329,9 +1339,177 @@ public Void visitTableFunction(TableFunctionNode node, Void context) "TableFunction", node.getName()); - checkArgument( - node.getSources().isEmpty() && node.getTableArgumentProperties().isEmpty(), - "Table or descriptor arguments are not yet supported in PlanPrinter"); + if (!node.getArguments().isEmpty()) { + nodeOutput.appendDetails("Arguments:"); + + Map tableArguments = node.getTableArgumentProperties().stream() + .collect(toImmutableMap(TableFunctionNode.TableArgumentProperties::getArgumentName, identity())); + + node.getArguments().entrySet().stream() + .forEach(entry -> nodeOutput.appendDetailsLine(formatArgument(entry.getKey(), entry.getValue(), tableArguments))); + + if (!node.getCopartitioningLists().isEmpty()) { + nodeOutput.appendDetailsLine(node.getCopartitioningLists().stream() + .map(list -> list.stream().collect(Collectors.joining(", ", "(", ")"))) + .collect(Collectors.joining(", ", "Co-partition: [", "] "))); + } + } + + processChildren(node, context); + + return null; + } + + private String formatArgument(String argumentName, Argument argument, Map tableArguments) + { + if (argument instanceof ScalarArgument) { + ScalarArgument scalarArgument = (ScalarArgument) argument; + return formatScalarArgument(argumentName, scalarArgument); + } + if (argument instanceof DescriptorArgument) { + DescriptorArgument descriptorArgument = (DescriptorArgument) argument; + return formatDescriptorArgument(argumentName, descriptorArgument); + } + else { + TableFunctionNode.TableArgumentProperties argumentProperties = tableArguments.get(argumentName); + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + argument.getValue()); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow(() -> new IllegalStateException("Missing descriptor")).getFields().stream() + .map(field -> field.getName() + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(Collectors.joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableFunctionNode.TableArgumentProperties argumentProperties) + { + 
List properties = new ArrayList<>(); + + if (argumentProperties.isRowSemantics()) { + properties.add("row semantics "); + } + argumentProperties.getSpecification().ifPresent(specification -> { + StringBuilder specificationBuilder = new StringBuilder(); + specificationBuilder + .append("partition by: [") + .append(Joiner.on(", ").join(specification.getPartitionBy())) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + specificationBuilder + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); + }); + properties.add(specificationBuilder.toString()); + }); + + properties.add("required columns: [" + + Joiner.on(", ").join(argumentProperties.getRequiredColumns()) + "]"); + + if (argumentProperties.isPruneWhenEmpty()) { + properties.add("prune when empty"); + } + + if (argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + properties.add("pass through columns"); + } + + return format("%s => TableArgument{%s}", argumentName, Joiner.on(", ").join(properties)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme) + { + return formatCollection(orderingScheme.getOrderByVariables(), variable -> variable + " " + orderingScheme.getOrdering(variable)); + } + + private String formatOrderingScheme(OrderingScheme orderingScheme, int preSortedOrderPrefix) + { + List orderBy = Stream.concat( + orderingScheme.getOrderByVariables().stream() + .limit(preSortedOrderPrefix) + .map(variable -> "<" + variable + " " + orderingScheme.getOrdering(variable) + ">"), + orderingScheme.getOrderByVariables().stream() + .skip(preSortedOrderPrefix) + .map(variable -> variable + " " + orderingScheme.getOrdering(variable))) + .collect(toImmutableList()); + return formatCollection(orderBy, Objects::toString); + } + + public String formatCollection(Collection collection, Function formatter) + { + return collection.stream() + .map(formatter) + .collect(Collectors.joining(", ", "[", "]")); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(node.getProperOutputs()))); + + String specs = node.getPassThroughSpecifications().stream() + .map(spec -> spec.getColumns().stream() + .map(col -> col.getOutputVariables().toString()) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ")); + descriptor.put("passThroughSymbols", format("[%s]", specs)); + + String requiredSymbols = node.getRequiredVariables().stream() + .map(vars -> vars.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "[", "]"))) + .collect(Collectors.joining(", ", "[", "]")); + descriptor.put("requiredSymbols", format("[%s]", requiredSymbols)); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(prePartitioned.stream() + .map(VariableReferenceExpression::toString) + .collect(Collectors.joining(", ", "<", ">"))); + if 
(!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(notPrePartitioned)); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionProcessorNode" + descriptor.build()); return processChildren(node, context); } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java index 7bb6f516b9171..fef6b4ed58bd2 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateDependenciesChecker.java @@ -70,6 +70,7 @@ import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -120,6 +121,117 @@ public Void visitPlan(PlanNode node, Set boundVaria @Override public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) { + for (int i = 0; i < node.getSources().size(); i++) { + PlanNode source = node.getSources().get(i); + source.accept(this, boundSymbols); + Set inputs = createInputs(source, boundSymbols); + TableFunctionNode.TableArgumentProperties argumentProperties = node.getTableArgumentProperties().get(i); + + checkDependencies( + inputs, + argumentProperties.getRequiredColumns(), + "Invalid node. Required input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getRequiredColumns(), + source.getOutputVariables()); + argumentProperties.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. Order by symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + Set passThroughVariable = argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughVariable, + "Invalid node. 
Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughVariable, + source.getOutputVariables()); + } + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundVariables) + { + if (!node.getSource().isPresent()) { + return null; + } + + PlanNode source = node.getSource().get(); + source.accept(this, boundVariables); + + Set inputs = createInputs(source, boundVariables); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(TableFunctionNode.PassThroughSpecification::getColumns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputVariables()); + + Set requiredSymbols = node.getRequiredVariables().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputVariables()); + + node.getMarkerVariables().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputVariables()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputVariables()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputVariables()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderByVariables(), + "Invalid node. 
Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputVariables()); + }); + }); + return null; } diff --git a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index f7fd052f02c3a..6bdf3c742af26 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main-base/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -772,7 +772,8 @@ public void installPlugin(Plugin plugin) @Override public void createCatalog(String catalogName, String connectorName, Map properties) { - throw new UnsupportedOperationException(); + nodeManager.addCurrentNodeConnector(new ConnectorId(catalogName)); + connectorManager.createConnection(catalogName, connectorName, properties); } @Override diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 6c210e9e0848c..284b7dc629521 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -66,6 +66,8 @@ import com.facebook.presto.sql.planner.plan.SampleNode; import com.facebook.presto.sql.planner.plan.SequenceNode; import com.facebook.presto.sql.planner.plan.StatisticsWriterNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; import com.facebook.presto.sql.planner.plan.TableWriterMergeNode; import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.sql.planner.plan.UpdateNode; @@ -131,6 +133,8 @@ private enum NodeType ANALYZE_FINISH, EXPLAIN_ANALYZE, UPDATE, + TABLE_FUNCTION, + TABLE_FUNCTION_PROCESSOR } private static final Map NODE_COLORS = immutableEnumMap(ImmutableMap.builder() @@ -162,6 +166,8 @@ private enum NodeType .put(NodeType.ANALYZE_FINISH, "plum") .put(NodeType.EXPLAIN_ANALYZE, "cadetblue1") .put(NodeType.UPDATE, "blue") + .put(NodeType.TABLE_FUNCTION, "mediumorchid3") + .put(NodeType.TABLE_FUNCTION_PROCESSOR, "steelblue3") .build()); static { @@ -382,6 +388,24 @@ public Void visitWindow(WindowNode node, Void context) return node.getSource().accept(this, context); } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Void context) + { + printNode(node, "Table Function Processor", NODE_COLORS.get(NodeType.TABLE_FUNCTION_PROCESSOR)); + if (node.getSource().isPresent()) { + node.getSource().get().accept(this, context); + } + return null; + } + + @Override + public Void visitTableFunction(TableFunctionNode node, Void context) + { + printNode(node, "Table Function Node", NODE_COLORS.get(NodeType.TABLE_FUNCTION)); + node.getSources().stream().map(source -> source.accept(this, context)); + return null; + } + @Override public Void visitRowNumber(RowNumberNode node, Void context) { diff --git a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java index 96373d826b50a..316d98787cf31 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java +++ b/presto-main-base/src/test/java/com/facebook/presto/connector/tvf/TestingTableFunctions.java @@ -67,18 +67,17 @@ public class 
TestingTableFunctions public static class TestConnectorTableFunction extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION = "test_function"; - + private static final String FUNCTION_NAME = "test_function"; public TestConnectorTableFunction() { - super(SCHEMA_NAME, TEST_FUNCTION, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ReturnTypeSpecification.GenericTable.GENERIC_TABLE); } @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { return TableFunctionAnalysis.builder() - .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, TEST_FUNCTION))) + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("c1", Optional.of(BOOLEAN))))) .build(); } @@ -87,11 +86,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TestConnectorTableFunction2 extends AbstractConnectorTableFunction { - private static final String TEST_FUNCTION_2 = "test_function2"; - + private static final String FUNCTION_NAME = "test_function2"; public TestConnectorTableFunction2() { - super(SCHEMA_NAME, TEST_FUNCTION_2, ImmutableList.of(), ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, ImmutableList.of(), ONLY_PASS_THROUGH); } @Override @@ -104,11 +102,10 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class NullArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String NULL_ARGUMENTS_FUNCTION = "null_arguments_function"; - + private static final String FUNCTION_NAME = "null_arguments_function"; public NullArgumentsTableFunction() { - super(SCHEMA_NAME, NULL_ARGUMENTS_FUNCTION, null, ONLY_PASS_THROUGH); + super(SCHEMA_NAME, FUNCTION_NAME, null, ONLY_PASS_THROUGH); } @Override @@ -121,12 +118,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DuplicateArgumentsTableFunction extends AbstractConnectorTableFunction { - private static final String DUPLICATE_ARGUMENTS_FUNCTION = "duplicate_arguments_function"; + private static final String FUNCTION_NAME = "duplicate_arguments_function"; public DuplicateArgumentsTableFunction() { super( SCHEMA_NAME, - DUPLICATE_ARGUMENTS_FUNCTION, + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder().name("a").type(INTEGER).build(), ScalarArgumentSpecification.builder().name("a").type(INTEGER).build()), @@ -143,12 +140,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MultipleRSTableFunction extends AbstractConnectorTableFunction { - private static final String MULTIPLE_SOURCES_FUNCTION = "multiple_sources_function"; + private static final String FUNCTION_NAME = "multiple_sources_function"; public MultipleRSTableFunction() { super( SCHEMA_NAME, - MULTIPLE_SOURCES_FUNCTION, + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder().name("t").rowSemantics().build(), TableArgumentSpecification.builder().name("t2").rowSemantics().build()), ONLY_PASS_THROUGH); @@ -172,7 +169,6 @@ public static class SimpleTableFunction { private static final String FUNCTION_NAME = "simple_table_function"; private static final String TABLE_NAME = "simple_table"; - public SimpleTableFunction() { super( @@ -227,11 +223,12 @@ public 
TestTVFConnectorTableHandle getTableHandle() public static class TwoScalarArgumentsFunction extends AbstractConnectorTableFunction { + private static final String FUNCTION_NAME = "two_scalar_arguments_function"; public TwoScalarArgumentsFunction() { super( SCHEMA_NAME, - "two_arguments_function", + FUNCTION_NAME, ImmutableList.of( ScalarArgumentSpecification.builder() .name("TEXT") @@ -256,7 +253,6 @@ public static class TableArgumentFunction extends AbstractConnectorTableFunction { public static final String FUNCTION_NAME = "table_argument_function"; - public TableArgumentFunction() { super( @@ -284,11 +280,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class DescriptorArgumentFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "descriptor_argument_function"; public DescriptorArgumentFunction() { super( SCHEMA_NAME, - "descriptor_argument_function", + FUNCTION_NAME, ImmutableList.of( DescriptorArgumentSpecification.builder() .name("SCHEMA") @@ -327,11 +324,16 @@ public TestTVFConnectorTableHandle getTableHandle() public static class TestingTableFunctionHandle implements ConnectorTableFunctionHandle { + private final TestTVFConnectorTableHandle tableHandle; private final SchemaFunctionName schemaFunctionName; @JsonCreator public TestingTableFunctionHandle(@JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName) { + this.tableHandle = new TestTVFConnectorTableHandle( + new SchemaTableName(SCHEMA_NAME, TABLE_NAME), + Optional.of(ImmutableList.of(new TestTVFConnectorColumnHandle(COLUMN_NAME, BOOLEAN))), + TupleDomain.all()); this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); } @@ -340,16 +342,22 @@ public SchemaFunctionName getSchemaFunctionName() { return schemaFunctionName; } + + public TestTVFConnectorTableHandle getTableHandle() + { + return tableHandle; + } } public static class TableArgumentRowSemanticsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "table_argument_row_semantics_function"; public TableArgumentRowSemanticsFunction() { super( SCHEMA_NAME, - "table_argument_row_semantics_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -372,17 +380,20 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class TwoTableArgumentsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "two_table_arguments_function"; public TwoTableArgumentsFunction() { super( SCHEMA_NAME, - "two_table_arguments_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -402,11 +413,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class OnlyPassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "only_pass_through_function"; public OnlyPassThroughFunction() { super( SCHEMA_NAME, - "only_pass_through_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") @@ -425,11 +437,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class MonomorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static 
final String FUNCTION_NAME = "monomorphic_static_return_type_function"; public MonomorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "monomorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(), new DescribedTable(Descriptor.descriptor( ImmutableList.of("a", "b"), @@ -448,11 +461,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PolymorphicStaticReturnTypeFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "polymorphic_static_return_type_function"; public PolymorphicStaticReturnTypeFunction() { super( SCHEMA_NAME, - "polymorphic_static_return_type_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .build()), @@ -471,14 +485,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class PassThroughFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "pass_through_function"; public PassThroughFunction() { super( SCHEMA_NAME, - "pass_through_function", + FUNCTION_NAME, ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("x"), @@ -495,14 +511,16 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact public static class RequiredColumnsFunction extends AbstractConnectorTableFunction { + public static final String FUNCTION_NAME = "required_columns_function"; public RequiredColumnsFunction() { super( SCHEMA_NAME, - "required_columns_function", + FUNCTION_NAME, ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -517,4 +535,51 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact .build(); } } + + public static class DifferentArgumentTypesFunction + extends AbstractConnectorTableFunction + { + public static final String FUNCTION_NAME = "different_arguments_function"; + public DifferentArgumentTypesFunction() + { + super( + SCHEMA_NAME, + FUNCTION_NAME, + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT_1") + .passThroughColumns() + .keepWhenEmpty() + .build(), + DescriptorArgumentSpecification.builder() + .name("LAYOUT") + .build(), + TableArgumentSpecification.builder() + .name("INPUT_2") + .rowSemantics() + .passThroughColumns() + .build(), + ScalarArgumentSpecification.builder() + .name("ID") + .type(BIGINT) + .build(), + TableArgumentSpecification.builder() + .name("INPUT_3") + .pruneWhenEmpty() + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(new TestingTableFunctionHandle(new SchemaFunctionName(SCHEMA_NAME, FUNCTION_NAME))) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java new file mode 100644 index 0000000000000..6d236432e1d15 --- /dev/null +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/TestTableFunctionInvocation.java @@ -0,0 +1,272 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.connector.tvf.TestTVFConnectorFactory; +import com.facebook.presto.connector.tvf.TestTVFConnectorPlugin; +import com.facebook.presto.connector.tvf.TestingTableFunctions; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DescriptorArgumentFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.DifferentArgumentTypesFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TestingTableFunctionHandle; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoScalarArgumentsFunction; +import com.facebook.presto.connector.tvf.TestingTableFunctions.TwoTableArgumentsFunction; +import com.facebook.presto.spi.connector.TableFunctionApplicationResult; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.Descriptor.Field; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.facebook.presto.sql.planner.assertions.RowNumberSymbolMatcher; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.LongLiteral; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.Optimizer.PlanStage.CREATED; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictOutput; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.strictProject; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunction; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static 
com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.descriptorArgument; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.DescriptorArgumentValue.nullDescriptor; +import static com.facebook.presto.sql.planner.assertions.TableFunctionMatcher.TableArgumentValue.Builder.tableArgument; + +public class TestTableFunctionInvocation + extends BasePlanTest +{ + private static final String TESTING_CATALOG = "test"; + + @BeforeClass + public final void setup() + { + getQueryRunner().installPlugin(new TestTVFConnectorPlugin(TestTVFConnectorFactory.builder() + .withTableFunctions(ImmutableSet.of( + new DifferentArgumentTypesFunction(), + new TwoScalarArgumentsFunction(), + new TwoTableArgumentsFunction(), + new DescriptorArgumentFunction(), + new TestingTableFunctions.PassThroughFunction())) + .withApplyTableFunction((session, handle) -> { + if (handle instanceof TestingTableFunctionHandle) { + TestingTableFunctionHandle functionHandle = (TestingTableFunctionHandle) handle; + return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow(() -> new IllegalStateException("Missing columns")))); + } + throw new IllegalStateException("Unsupported table function handle: " + handle.getClass().getSimpleName()); + }) + .build())); + getQueryRunner().createCatalog(TESTING_CATALOG, "testTVF", ImmutableMap.of()); + } + + @Test + public void testTableFunctionInitialPlan() + { + assertPlan( + "SELECT * FROM TABLE(test.system.different_arguments_function(" + + "INPUT_1 => TABLE(SELECT 'a') t1(c1) PARTITION BY c1 ORDER BY c1," + + "INPUT_3 => TABLE(SELECT 'b') t3(c3) PARTITION BY c3," + + "INPUT_2 => TABLE(VALUES 1) t2(c2)," + + "ID => BIGINT '2001'," + + "LAYOUT => DESCRIPTOR (x boolean, y bigint)" + + "COPARTITION (t1, t3))) t", + CREATED, + anyTree(tableFunction(builder -> builder + .name("different_arguments_function") + .addTableArgument( + "INPUT_1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) + .passThroughVariables(ImmutableSet.of("c1")) + .passThroughColumns()) + .addTableArgument( + "INPUT_3", + tableArgument(2) + .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) + .pruneWhenEmpty() + .passThroughVariables(ImmutableSet.of("c3"))) + .addTableArgument( + "INPUT_2", + tableArgument(1) + .rowSemantics() + .passThroughVariables(ImmutableSet.of("c2")) + .passThroughColumns()) + .addScalarArgument("ID", 2001L) + .addDescriptorArgument( + "LAYOUT", + descriptorArgument(new Descriptor(ImmutableList.of( + new Field("X", Optional.of(BOOLEAN)), + new Field("Y", Optional.of(BIGINT)))))) + .addCopartitioning(ImmutableList.of("INPUT_1", "INPUT_3")) + .properOutputs(ImmutableList.of("OUTPUT")), + anyTree(project(ImmutableMap.of("c1", expression("'a'")), values(1))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("1"))))), + anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); + } + + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan("SELECT * FROM TABLE(test.system.two_table_arguments_function(" + + "INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1," + + "INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 " + + "COPARTITION (t1, t2))) t", + CREATED, + anyTree(tableFunction(builder -> builder + 
.name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughVariables(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new LongLiteral("1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new LongLiteral("2")))))))); + } + + @Test + public void testNullScalarArgument() + { + // the argument NUMBER has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.two_scalar_arguments_function(TEXT => null))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_scalar_arguments_function") + .addScalarArgument("TEXT", null) + .addScalarArgument("NUMBER", null) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testNullDescriptorArgument() + { + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function(SCHEMA => CAST(null AS DESCRIPTOR)))", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + + // the argument SCHEMA has null default value + assertPlan( + " SELECT * FROM TABLE(test.system.descriptor_argument_function())", + CREATED, + anyTree(tableFunction(builder -> builder + .name("descriptor_argument_function") + .addDescriptorArgument("SCHEMA", nullDescriptor()) + .properOutputs(ImmutableList.of("OUTPUT"))))); + } + + @Test + public void testPruneTableFunctionColumns() + { + // all table function outputs are referenced with SELECT *, no pruning + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("x", "a", "b"), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols( + ImmutableList.of(ImmutableList.of("a", "b"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'"), "b", expression("BOOLEAN'true'")), values(1))))); + + // no table function outputs are referenced. All pass-through symbols are pruned from the TableFunctionProcessorNode. The unused symbol "b" is pruned from the source values node. 
+ assertPlan("SELECT 'constant' c FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true) t(a, b)))", + strictOutput( + ImmutableList.of("c"), + strictProject( + ImmutableMap.of("c", expression("VARCHAR'constant'")), + tableFunctionProcessor( + builder -> builder + .name("pass_through_function") + .properOutputs(ImmutableList.of("x")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("a"))) + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())), + project(ImmutableMap.of("a", expression("INTEGER'1'")), values(1)))))); + } + + @Test + public void testRemoveRedundantTableFunction() + { + assertPlan("SELECT * FROM TABLE(test.system.pass_through_function(input => TABLE(SELECT 1, true WHERE false) t(a, b) PRUNE WHEN EMPTY))", + output(values(ImmutableList.of("x", "a", "b")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) PRUNE WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) PRUNE WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output(values(ImmutableList.of("column")))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false WHERE false) t2(c, d) KEEP WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + values(ImmutableList.of("a", "marker_1", "c", "marker_2", "row_number"))))); + + assertPlan("SELECT *\n" + + "FROM TABLE(test.system.two_table_arguments_function(\n" + + " input1 => TABLE(SELECT 1, true WHERE false) t1(a, b) KEEP WHEN EMPTY,\n" + + " input2 => TABLE(SELECT 2, false) t2(c, d) PRUNE WHEN EMPTY))\n", + output( + node(TableFunctionProcessorNode.class, + project( + project( + rowNumber( + builder -> builder.partitionBy(ImmutableList.of()), + project( + ImmutableMap.of("c", expression("INTEGER'2'")), + values(1)) + ).withAlias("input_2_row_number", new RowNumberSymbolMatcher())))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java index e5838185f495f..559f86dc3a6bf 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/PlanMatchPattern.java @@ -666,6 +666,11 @@ public static PlanMatchPattern values(List aliases, List aliases) { return values(aliases, Optional.empty()); @@ -701,6 +706,27 @@ public static PlanMatchPattern remoteSource(List sourceFragmentI return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases)); } + public static PlanMatchPattern tableFunction(Consumer handler, 
PlanMatchPattern... sources) + { + TableFunctionMatcher.Builder builder = new TableFunctionMatcher.Builder(sources); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java new file mode 100644 index 0000000000000..c14b68b443867 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionMatcher.java @@ -0,0 +1,412 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.table.Argument; +import com.facebook.presto.spi.function.table.Descriptor; +import com.facebook.presto.spi.function.table.DescriptorArgument; +import com.facebook.presto.spi.function.table.ScalarArgument; +import com.facebook.presto.spi.function.table.TableArgument; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReferences; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
com.google.common.collect.ImmutableSet.toImmutableSet; +import static java.util.Objects.requireNonNull; + +public class TableFunctionMatcher + implements Matcher +{ + private final String name; + private final Map arguments; + private final List properOutputs; + private final List> copartitioningLists; + + private TableFunctionMatcher( + String name, + Map arguments, + List properOutputs, + List> copartitioningLists) + { + this.name = requireNonNull(name, "name is null"); + this.arguments = ImmutableMap.copyOf(requireNonNull(arguments, "arguments is null")); + this.properOutputs = ImmutableList.copyOf(requireNonNull(properOutputs, "properOutputs is null")); + requireNonNull(copartitioningLists, "copartitioningLists is null"); + this.copartitioningLists = copartitioningLists.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionNode tableFunctionNode = (TableFunctionNode) node; + + if (!name.equals(tableFunctionNode.getName())) { + return NO_MATCH; + } + + if (arguments.size() != tableFunctionNode.getArguments().size()) { + return NO_MATCH; + } + for (Map.Entry entry : arguments.entrySet()) { + String name = entry.getKey(); + Argument actual = tableFunctionNode.getArguments().get(name); + if (actual == null) { + return NO_MATCH; + } + ArgumentValue expected = entry.getValue(); + switch (expected.getType()) { + case DescriptorArgumentValue.type: + DescriptorArgumentValue expectedDescriptor = (DescriptorArgumentValue) expected; + if (!(actual instanceof DescriptorArgument) || !expectedDescriptor.getDescriptor().equals(((DescriptorArgument) actual).getDescriptor())) { + return NO_MATCH; + } + break; + case ScalarArgumentValue.type: + ScalarArgumentValue expectedScalar = (ScalarArgumentValue) expected; + if (!(actual instanceof ScalarArgument) || !Objects.equals(expectedScalar.getValue(), ((ScalarArgument) actual).getValue())) { + return NO_MATCH; + } + break; + default: + if (!(actual instanceof TableArgument) || getMatchResult(symbolAliases, (TableArgumentValue) expected, tableFunctionNode, name).equals(NO_MATCH)) { + return NO_MATCH; + } + } + } + + if (!ImmutableSet.copyOf(copartitioningLists).equals(ImmutableSet.copyOf(tableFunctionNode.getCopartitioningLists()))) { + return NO_MATCH; + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + private MatchResult getMatchResult(SymbolAliases symbolAliases, TableArgumentValue expected, TableFunctionNode tableFunctionNode, String name) + { + TableArgumentValue expectedTableArgument = expected; + TableArgumentProperties argumentProperties = tableFunctionNode.getTableArgumentProperties().get(expectedTableArgument.sourceIndex()); + if (!name.equals(argumentProperties.getArgumentName())) { + return NO_MATCH; + } + if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || + expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || + expectedTableArgument.passThroughColumns() != 
argumentProperties.getPassThroughSpecification().isDeclaredAsPassThrough()) { + return NO_MATCH; + } + + if (expectedTableArgument.specification().isPresent() != argumentProperties.getSpecification().isPresent()) { + return NO_MATCH; + } + if (!expectedTableArgument.specification() + .map(expectedSpecification -> matchSpecification(argumentProperties.getSpecification().get(), expectedSpecification.getExpectedValue(symbolAliases))) + .orElse(true)) { + return NO_MATCH; + } + Set expectedPassThrough = expectedTableArgument.passThroughVariables().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = toSymbolReferences( + argumentProperties.getPassThroughSpecification().getColumns().stream() + .map(TableFunctionNode.PassThroughColumn::getOutputVariables) + .collect(Collectors.toList())) + .stream() + .map(SymbolReference.class::cast) + .collect(Collectors.toSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + return match(symbolAliases); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("arguments", arguments) + .add("properOutputs", properOutputs) + .add("copartitioningLists", copartitioningLists) + .toString(); + } + + public static class Builder + { + private final PlanMatchPattern[] sources; + private String name; + private final ImmutableMap.Builder arguments = ImmutableMap.builder(); + private List properOutputs = ImmutableList.of(); + private final ImmutableList.Builder> copartitioningLists = ImmutableList.builder(); + + Builder(PlanMatchPattern... sources) + { + this.sources = Arrays.copyOf(sources, sources.length); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder addDescriptorArgument(String name, DescriptorArgumentValue descriptor) + { + this.arguments.put(name, descriptor); + return this; + } + + public Builder addScalarArgument(String name, Object value) + { + this.arguments.put(name, new ScalarArgumentValue(value)); + return this; + } + + public Builder addTableArgument(String name, TableArgumentValue.Builder tableArgument) + { + this.arguments.put(name, tableArgument.build()); + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder addCopartitioning(List copartitioning) + { + this.copartitioningLists.add(copartitioning); + return this; + } + + public PlanMatchPattern build() + { + return node(TableFunctionNode.class, sources) + .with(new TableFunctionMatcher(name, arguments.buildOrThrow(), properOutputs, copartitioningLists.build())); + } + } + + interface ArgumentValue + { + String getType(); + } + + public static class DescriptorArgumentValue + implements ArgumentValue + { + private final Optional descriptor; + public static final String type = "Descriptor"; + + public DescriptorArgumentValue(Optional descriptor) + { + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + } + + public static DescriptorArgumentValue descriptorArgument(Descriptor descriptor) + { + return new DescriptorArgumentValue(Optional.of(requireNonNull(descriptor, "descriptor is null"))); + } + + public static DescriptorArgumentValue nullDescriptor() + { + return new DescriptorArgumentValue(Optional.empty()); + } + + public Optional getDescriptor() + { + return descriptor; + } + + @Override + public String getType() + { + return type; + } + } + + public static class ScalarArgumentValue + 
implements ArgumentValue + { + private final Object value; + public static final String type = "Scalar"; + + public ScalarArgumentValue(Object value) + { + this.value = value; + } + + public Object getValue() + { + return value; + } + + @Override + public String getType() + { + return type; + } + } + + public static class TableArgumentValue + implements ArgumentValue + { + private final int sourceIndex; + private final boolean rowSemantics; + private final boolean pruneWhenEmpty; + private final boolean passThroughColumns; + private final Optional> specification; + private final Set passThroughVariables; + public static final String type = "Table"; + + public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification, Set passThroughVariables) + { + this.sourceIndex = sourceIndex; + this.rowSemantics = rowSemantics; + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughColumns = passThroughColumns; + this.specification = requireNonNull(specification, "specification is null"); + this.passThroughVariables = ImmutableSet.copyOf(passThroughVariables); + } + + public int sourceIndex() + { + return sourceIndex; + } + + public boolean rowSemantics() + { + return rowSemantics; + } + + public boolean pruneWhenEmpty() + { + return pruneWhenEmpty; + } + + public boolean passThroughColumns() + { + return passThroughColumns; + } + + public Set passThroughVariables() + { + return passThroughVariables; + } + + public Optional> specification() + { + return specification; + } + + @Override + public String getType() + { + return type; + } + + public static class Builder + { + private final int sourceIndex; + private boolean rowSemantics; + private boolean pruneWhenEmpty; + private boolean passThroughColumns; + private Optional> specification = Optional.empty(); + private Set passThroughVariables = ImmutableSet.of(); + + private Builder(int sourceIndex) + { + this.sourceIndex = sourceIndex; + } + + public static Builder tableArgument(int sourceIndex) + { + return new Builder(sourceIndex); + } + + public Builder rowSemantics() + { + this.rowSemantics = true; + this.pruneWhenEmpty = true; + return this; + } + + public Builder pruneWhenEmpty() + { + this.pruneWhenEmpty = true; + return this; + } + + public Builder passThroughColumns() + { + this.passThroughColumns = true; + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder passThroughVariables(Set variables) + { + this.passThroughVariables = variables; + return this; + } + + private TableArgumentValue build() + { + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughVariables); + } + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 index 0000000000000..4891c3eb021dd --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,239 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.assertions; + +import com.facebook.presto.Session; +import com.facebook.presto.cost.StatsProvider; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.sql.planner.QueryPlanner; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Optional; + +import static com.facebook.presto.sql.planner.QueryPlanner.toSymbolReference; +import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH; +import static com.facebook.presto.sql.planner.assertions.MatchResult.match; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.SpecificationProvider.matchSpecification; +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List<String> properOutputs; + private final List<List<String>> passThroughSymbols; + private final List<List<String>> requiredSymbols; + private final Optional<Map<String, String>> markerSymbols; + private final Optional<ExpectedValueProvider<DataOrganizationSpecification>> specification; + private final Optional<String> hashSymbol; + + private TableFunctionProcessorMatcher( + String name, + List<String> properOutputs, + List<List<String>> passThroughSymbols, + List<List<String>> requiredSymbols, + Optional<Map<String, String>> markerSymbols, + Optional<ExpectedValueProvider<DataOrganizationSpecification>> specification, + Optional<String> hashSymbol) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = passThroughSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned
false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = (TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + List<List<SymbolReference>> expectedPassThrough = passThroughSymbols.stream() + .map(list -> list.stream() + .map(symbolAliases::get) + .collect(toImmutableList())) + .collect(toImmutableList()); + List<List<SymbolReference>> actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::getColumns) + .map(list -> list.stream() + .map(PassThroughColumn::getOutputVariables) + .map(QueryPlanner::toSymbolReference) + .collect(toImmutableList())) + .collect(toImmutableList()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerVariables().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map<SymbolReference, SymbolReference> expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map<SymbolReference, SymbolReference> actualMapping = tableFunctionProcessorNode.getMarkerVariables().get().entrySet().stream() + .collect(toImmutableMap(entry -> toSymbolReference(entry.getKey()), entry -> toSymbolReference(entry.getValue()))); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!matchSpecification(specification.get().getExpectedValue(symbolAliases), tableFunctionProcessorNode.getSpecification().orElseThrow(NoSuchElementException::new))) { + return NO_MATCH; + } + } + if (hashSymbol.isPresent()) { + if (!hashSymbol.map(symbolAliases::get).equals(tableFunctionProcessorNode.getHashSymbol().map(QueryPlanner::toSymbolReference))) { + return NO_MATCH; + } + } + + ImmutableMap.Builder<String, SymbolReference> properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), toSymbolReference(tableFunctionProcessorNode.getProperOutputs().get(i))); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("requiredSymbols", requiredSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .add("hashSymbol", hashSymbol) + .toString(); + } + + public static class Builder + { + private final Optional<PlanMatchPattern> source; + private String name; + private List<String> properOutputs = ImmutableList.of(); + private List<List<String>> passThroughSymbols = ImmutableList.of(); + private List<List<String>> requiredSymbols = ImmutableList.of(); + private Optional<Map<String, String>> markerSymbols = Optional.empty(); + private Optional<ExpectedValueProvider<DataOrganizationSpecification>> specification = Optional.empty(); + private Optional<String> hashSymbol = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List<String> properOutputs) +
{ + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(List<List<String>> passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder requiredSymbols(List<List<String>> requiredSymbols) + { + this.requiredSymbols = requiredSymbols; + return this; + } + + public Builder markerSymbols(Map<String, String> markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider<DataOrganizationSpecification> specification) + { + this.specification = Optional.of(specification); + return this; + } + + public Builder hashSymbol(String hashSymbol) + { + this.hashSymbol = Optional.of(hashSymbol); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, requiredSymbols, markerSymbols, specification, hashSymbol)); + } + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java new file mode 100644 index 0000000000000..b90eafe5b9fd8 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -0,0 +1,1404 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST; +import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.window; +import static com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; + +public class TestImplementTableFunctionSource + extends BaseRuleTest +{ + @Test + public void testNoSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.variable("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + + // pass-through columns + tester().assertThat(new 
ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableFunctionNode.TableArgumentProperties( + "table_argument", + true, + true, + new TableFunctionNode.PassThroughSpecification(true, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"))), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new TableFunctionNode.PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_LAST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new 
ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"))) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, 
input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + VariableReferenceExpression h = p.variable("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(h, DESC_NULLS_FIRST)))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"), ImmutableList.of())) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("f"), ImmutableList.of("h"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), 
ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = input_3_row_number OR " + + "(combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST)) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("g", "h")))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + 
tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + 
.specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", 
ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR " + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// 
append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. they are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT 
(d IS DISTINCT FROM c) " + + "AND (" + + " input_2_row_number = input_1_row_number OR" + + " (input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR" + + " input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT 
'-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d)" + + " AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + 
ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (combined_partition_column_1_2 IS DISTINCT FROM e) " + + "AND (" + + " combined_row_number_1_2 = input_3_row_number OR" + + " (combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR" + + " input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1'))"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND (" + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + 
.addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression g = p.variable("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(g, DESC_NULLS_FIRST)))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"), ImmutableList.of("g"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", "marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, null)"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + 
"combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + JoinType.LEFT, + ImmutableList.of(), + Optional.of("combined_row_number_1_2 = combined_row_number_3_4 OR " + + "(combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR " + + "combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + JoinType.INNER, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM d) " + + "AND ( " + + "input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))))), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + JoinType.FULL, + ImmutableList.of(), + Optional.of("NOT (e IS DISTINCT FROM f) " + + "AND ( " + + "input_3_row_number = input_4_row_number OR " + + "(input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR " + + "input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1')) "), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", 
ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))), + window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + window(builder -> builder + .specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST)) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())), + values("f", "g")))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, null)"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", 
expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + JoinType.INNER, + ImmutableList.of(), + Optional.of("combined_row_number_2_3 = input_1_row_number OR " + + "(combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR " + + "input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1')"), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (d IS DISTINCT FROM e) " + + "AND ( " + + "input_2_row_number = input_3_row_number OR " + + "(input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR " + + "input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("d"))), + window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())), + values("e"))))), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c", TINYINT); + VariableReferenceExpression cCoerced = p.variable("c_coerced", INTEGER); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e", INTEGER); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + 
Assignments.builder() + .put(c, p.rowExpression("c")) + .put(d, p.rowExpression("d")) + .put(cCoerced, p.rowExpression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(f, DESC_NULLS_FIRST)))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c_coerced IS DISTINCT FROM e) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // 
input_2 + window(builder -> builder + .specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST)) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c", "d"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + JoinType.LEFT, + ImmutableList.of(), + Optional.of("NOT (c IS DISTINCT FROM e) " + + "AND NOT (d IS DISTINCT FROM f) " + + "AND ( " + + " input_1_row_number = input_2_row_number OR" + + " (input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR" + + " input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1'))"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), 
ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("c"), ImmutableList.of("e", "f"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("d"), ImmutableList.of("e"))) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, null)"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, null)")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + JoinType.FULL, + ImmutableList.of(), + Optional.of("input_1_row_number = input_2_row_number OR " + + "(input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR " + + "input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT 
'1')"), + window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + window(builder -> builder + .specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())), + values("c", "d"))), + window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + window(builder -> builder + .specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of()) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())), + values("e", "f")))))))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java new file mode 100644 index 0000000000000..bcae22ae6c623 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorColumns.java @@ -0,0 +1,221 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorColumns + extends BaseRuleTest +{ + @Test + public void testDoNotPruneProperOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("p")) + .source(p.values(p.variable("x")))))) + .doesNotFire(); + } + + @Test + public void testPrunePassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of(), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of())), + values("a", "b")))); + } + + @Test + public void testReferencedPassThroughOutputs() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression x = p.variable("x"); + VariableReferenceExpression y = p.variable("y"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(y, y).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(x, y) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, 
false)))) + .source(p.values(a, b)))); + }) + .matches(project( + ImmutableMap.of("y", expression("y"), "b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("x", "y")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("b"))), + values("a", "b")))); + } + + @Test + public void testAllPassThroughOutputsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + return p.project( + Assignments.builder().put(a, a).put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications( + new PassThroughSpecification( + true, + ImmutableList.of( + new PassThroughColumn(a, true), + new PassThroughColumn(b, false)))) + .source(p.values(a, b)))); + }) + .doesNotFire(); + } + + @Test + public void testNoSource() + { + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> p.project( + Assignments.of(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper"))))) + .doesNotFire(); + } + + @Test + public void testMultipleTableArguments() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + tester().assertThat(new PruneTableFunctionProcessorColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + return p.project( + Assignments.builder().put(b, b).build(), + p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(p.variable("proper")) + .passThroughSpecifications( + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(a, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(b, true))), + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false)))) + .source(p.values(a, b, c, d)))); + }) + .matches(project( + ImmutableMap.of("b", expression("b")), + tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of(), ImmutableList.of("b"), ImmutableList.of())), + values("a", "b", "c", "d")))); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java new file mode 100644 index 0000000000000..68f56d320e396 --- /dev/null +++ 
b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPruneTableFunctionProcessorSourceColumns.java @@ -0,0 +1,198 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.DataOrganizationSpecification; +import com.facebook.presto.spi.plan.Ordering; +import com.facebook.presto.spi.plan.OrderingScheme; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.specification; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPruneTableFunctionProcessorSourceColumns + extends BaseRuleTest +{ + @Test + public void testPruneUnreferencedSymbol() + { + // symbols 'a', 'b', 'c', 'd', 'hash', and 'marker' are used by the node. + // symbol 'unreferenced' is pruned out. 
Also, the mapping for this symbol is removed from marker mappings + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression proper = p.variable("proper"); + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression hash = p.variable("hash"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false)))) + .requiredSymbols(ImmutableList.of(ImmutableList.of(b))) + .markerSymbols(ImmutableMap.of( + a, marker, + b, marker, + c, marker, + d, marker, + unreferenced, marker)) + .specification(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(new Ordering(d, ASC_NULLS_FIRST)))))) + .hashSymbol(hash) + .source(p.values(a, b, c, d, unreferenced, hash, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("proper")) + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"))) + .markerSymbols(ImmutableMap.of( + "a", "marker", + "b", "marker", + "c", "marker", + "d", "marker")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_FIRST))) + .hashSymbol("hash"), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "hash", expression("hash"), + "marker", expression("marker")), + values("a", "b", "c", "d", "unreferenced", "hash", "marker")))); + } + + @Test + public void testPruneUnusedMarkerSymbol() + { + // symbol 'unreferenced' is pruned out because the node does not use it. + // also, the mapping for this symbol is removed from marker mappings. + // because the marker symbol 'marker' is no longer used, it is pruned out too. + // note: currently a marker symbol cannot become unused because the function + // must use at least one symbol from each source. it might change in the future. + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of(unreferenced, marker)) + .source(p.values(unreferenced, marker))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .markerSymbols(ImmutableMap.of()), + project( + ImmutableMap.of(), + values("unreferenced", "marker")))); + } + + @Test + public void testMultipleSources() + { + // multiple pass-through specifications indicate that the table function has multiple table arguments + // the third argument provides symbols 'e', 'f', and 'unreferenced'. 
those symbols are mapped to common marker symbol 'marker3' + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression b = p.variable("b"); + VariableReferenceExpression c = p.variable("c"); + VariableReferenceExpression d = p.variable("d"); + VariableReferenceExpression e = p.variable("e"); + VariableReferenceExpression f = p.variable("f"); + VariableReferenceExpression marker1 = p.variable("marker1"); + VariableReferenceExpression marker2 = p.variable("marker2"); + VariableReferenceExpression marker3 = p.variable("marker3"); + VariableReferenceExpression unreferenced = p.variable("unreferenced"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .passThroughSpecifications( + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(a, false))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true)))) + .requiredSymbols(ImmutableList.of( + ImmutableList.of(b), + ImmutableList.of(d), + ImmutableList.of(f))) + .markerSymbols(ImmutableMap.of( + a, marker1, + b, marker1, + c, marker2, + d, marker2, + e, marker3, + f, marker3, + unreferenced, marker3)) + .source(p.values(a, b, c, d, e, f, marker1, marker2, marker3, unreferenced))); + }) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .passThroughSymbols(ImmutableList.of(ImmutableList.of("a"), ImmutableList.of("c"), ImmutableList.of("e"))) + .requiredSymbols(ImmutableList.of(ImmutableList.of("b"), ImmutableList.of("d"), ImmutableList.of("f"))) + .markerSymbols(ImmutableMap.of( + "a", "marker1", + "b", "marker1", + "c", "marker2", + "d", "marker2", + "e", "marker3", + "f", "marker3")), + project( + ImmutableMap.of( + "a", expression("a"), + "b", expression("b"), + "c", expression("c"), + "d", expression("d"), + "e", expression("e"), + "f", expression("f"), + "marker1", expression("marker1"), + "marker2", expression("marker2"), + "marker3", expression("marker3")), + values("a", "b", "c", "d", "e", "f", "marker1", "marker2", "marker3", "unreferenced")))); + } + + @Test + public void allSymbolsReferenced() + { + tester().assertThat(new PruneTableFunctionProcessorSourceColumns()) + .on(p -> { + VariableReferenceExpression a = p.variable("a"); + VariableReferenceExpression marker = p.variable("marker"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .requiredSymbols(ImmutableList.of(ImmutableList.of(a))) + .markerSymbols(ImmutableMap.of(a, marker)) + .source(p.values(a, marker))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunction.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunction.java new file mode 100644 index 0000000000000..d70fecf0f6283 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRemoveRedundantTableFunction.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestRemoveRedundantTableFunction + extends BaseRuleTest +{ + @Test + public void testRemoveTableFunction() + { + tester().assertThat(new RemoveRedundantTableFunction()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .matches(values("proper", "pass_through")); + } + + @Test + public void testDoNotRemoveKeepWhenEmpty() + { + tester().assertThat(new RemoveRedundantTableFunction()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(passThrough))); + }) + .doesNotFire(); + } + + @Test + public void testDoNotRemoveNonEmptyInput() + { + tester().assertThat(new RemoveRedundantTableFunction()) + .on(p -> { + VariableReferenceExpression passThrough = p.variable("pass_through"); + VariableReferenceExpression proper = p.variable("proper"); + return p.tableFunctionProcessor( + builder -> builder + .name("test_function") + .pruneWhenEmpty() + .properOutputs(proper) + .passThroughSpecifications(new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(passThrough, true)))) + .source(p.values(5, passThrough))); + }) + .doesNotFire(); + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 9482554a3dcfb..1868883f566ad 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -19,6 +19,7 @@ import com.facebook.presto.common.predicate.TupleDomain; import com.facebook.presto.common.type.Type; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.TableFunctionHandle; import com.facebook.presto.spi.ColumnHandle; import 
 import com.facebook.presto.spi.ConnectorId;
 import com.facebook.presto.spi.IndexHandle;
@@ -27,6 +28,7 @@
 import com.facebook.presto.spi.WarningCollector;
 import com.facebook.presto.spi.constraints.TableConstraint;
 import com.facebook.presto.spi.function.FunctionHandle;
+import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle;
 import com.facebook.presto.spi.plan.AggregationNode;
 import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
 import com.facebook.presto.spi.plan.AggregationNode.Step;
@@ -86,6 +88,8 @@
 import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
 import com.facebook.presto.sql.planner.plan.RowNumberNode;
 import com.facebook.presto.sql.planner.plan.SampleNode;
+import com.facebook.presto.sql.planner.plan.TableFunctionNode;
+import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode;
 import com.facebook.presto.sql.relational.FunctionResolution;
 import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
 import com.facebook.presto.sql.tree.Expression;
@@ -971,6 +975,32 @@ public WindowNode window(DataOrganizationSpecification specification, Map
+    public TableFunctionNode tableFunction(
+            String name,
+            List<VariableReferenceExpression> properOutputs,
+            List<PlanNode> sources,
+            List<TableFunctionNode.TableArgumentProperties> tableArgumentProperties,
+            List<List<String>> copartitioningLists)
+    {
+        return new TableFunctionNode(
+                idAllocator.getNextId(),
+                name,
+                ImmutableMap.of(),
+                properOutputs,
+                sources,
+                tableArgumentProperties,
+                copartitioningLists,
+                new TableFunctionHandle(new ConnectorId("connector_id"), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create()));
+    }
+
+    public TableFunctionProcessorNode tableFunctionProcessor(Consumer<TableFunctionProcessorBuilder> consumer)
+    {
+        TableFunctionProcessorBuilder tableFunctionProcessorBuilder = new TableFunctionProcessorBuilder();
+        consumer.accept(tableFunctionProcessorBuilder);
+        return tableFunctionProcessorBuilder.build(idAllocator);
+    }
+
     public RowNumberNode rowNumber(List<VariableReferenceExpression> partitionBy, Optional<Integer> maxRowCountPerPartition, VariableReferenceExpression rowNumberVariable, PlanNode source)
     {
         return new RowNumberNode(
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java
new file mode 100644
index 0000000000000..404831b10f0ef
--- /dev/null
+++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/TableFunctionProcessorBuilder.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner.iterative.rule.test;
+
+import com.facebook.presto.metadata.TableFunctionHandle;
+import com.facebook.presto.spi.ConnectorId;
+import com.facebook.presto.spi.function.table.ConnectorTableFunctionHandle;
+import com.facebook.presto.spi.plan.DataOrganizationSpecification;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.plan.TableFunctionNode.PassThroughSpecification;
+import com.facebook.presto.sql.planner.plan.TableFunctionProcessorNode;
+import com.facebook.presto.testing.TestingTransactionHandle;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+public class TableFunctionProcessorBuilder
+{
+    private String name;
+    private List<VariableReferenceExpression> properOutputs = ImmutableList.of();
+    private Optional<PlanNode> source = Optional.empty();
+    private boolean pruneWhenEmpty;
+    private List<PassThroughSpecification> passThroughSpecifications = ImmutableList.of();
+    private List<List<VariableReferenceExpression>> requiredSymbols = ImmutableList.of();
+    private Optional<Map<VariableReferenceExpression, VariableReferenceExpression>> markerSymbols = Optional.empty();
+    private Optional<DataOrganizationSpecification> specification = Optional.empty();
+    private Set<VariableReferenceExpression> prePartitioned = ImmutableSet.of();
+    private int preSorted;
+    private Optional<VariableReferenceExpression> hashSymbol = Optional.empty();
+    private ConnectorTableFunctionHandle connectorHandle = new ConnectorTableFunctionHandle() {};
+
+    public TableFunctionProcessorBuilder() {}
+
+    public TableFunctionProcessorBuilder name(String name)
+    {
+        this.name = name;
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder properOutputs(VariableReferenceExpression... properOutputs)
+    {
+        this.properOutputs = ImmutableList.copyOf(properOutputs);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder source(PlanNode source)
+    {
+        this.source = Optional.of(source);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder pruneWhenEmpty()
+    {
+        this.pruneWhenEmpty = true;
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder passThroughSpecifications(PassThroughSpecification... passThroughSpecifications)
+    {
+        this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder requiredSymbols(List<List<VariableReferenceExpression>> requiredSymbols)
+    {
+        this.requiredSymbols = requiredSymbols;
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder markerSymbols(Map<VariableReferenceExpression, VariableReferenceExpression> markerSymbols)
+    {
+        this.markerSymbols = Optional.of(markerSymbols);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder specification(DataOrganizationSpecification specification)
+    {
+        this.specification = Optional.of(specification);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder prePartitioned(Set<VariableReferenceExpression> prePartitioned)
+    {
+        this.prePartitioned = prePartitioned;
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder preSorted(int preSorted)
+    {
+        this.preSorted = preSorted;
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder hashSymbol(VariableReferenceExpression hashSymbol)
+    {
+        this.hashSymbol = Optional.of(hashSymbol);
+        return this;
+    }
+
+    public TableFunctionProcessorBuilder connectorHandle(ConnectorTableFunctionHandle connectorHandle)
+    {
+        this.connectorHandle = connectorHandle;
+        return this;
+    }
+
+    public TableFunctionProcessorNode build(PlanNodeIdAllocator idAllocator)
+    {
+        return new TableFunctionProcessorNode(
+                idAllocator.getNextId(),
+                name,
+                properOutputs,
+                source,
+                pruneWhenEmpty,
+                passThroughSpecifications,
+                requiredSymbols,
+                markerSymbols,
+                specification,
+                prePartitioned,
+                preSorted,
+                hashSymbol,
+                new TableFunctionHandle(new ConnectorId("connector_id"), connectorHandle, TestingTransactionHandle.create()));
+    }
+}
diff --git a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java
index 9d28b9d4b4219..8b3840e633f40 100644
--- a/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java
+++ b/presto-parser/src/main/java/com/facebook/presto/sql/parser/AstBuilder.java
@@ -1667,7 +1667,7 @@ public Node visitDescriptorArgument(SqlBaseParser.DescriptorArgumentContext cont
     @Override
     public Node visitDescriptorField(SqlBaseParser.DescriptorFieldContext context)
     {
-        return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.of(getType(context.type())));
+        return new DescriptorField(getLocation(context), (Identifier) visit(context.identifier()), Optional.ofNullable(context.type()).map(this::getType));
     }
 
     /**