Skip to content

Commit

Permalink
Additional recursive CTE work (#16616)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay authored Aug 26, 2024
1 parent d95e36f commit e6843dc
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 24 deletions.
118 changes: 117 additions & 1 deletion go/test/endtoend/vtgate/vitess_tester/cte/queries.test
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ CREATE TABLE employees
manager_id INT
);

# Simple recursive CTE using a real table. Select everything from empty table
with recursive cte as (select * from employees union all select * from cte)
select *
from cte;

# Insert data into the tables
INSERT INTO employees (id, name, manager_id)
VALUES (1, 'CEO', NULL),
Expand Down Expand Up @@ -107,4 +112,115 @@ GROUP BY manager_id;
--error infinite recursion
with recursive cte as (select 1 as n union all select n+1 from cte)
select *
from cte;
from cte;

# Define recursive CTE and then use it on the RHS of UNION
WITH RECURSIVE foo AS (SELECT id
FROM employees
WHERE id = 1
UNION ALL
SELECT id + 1
FROM foo
WHERE id < 5)
SELECT id
FROM foo;

# Recursive CTE with UNION DISTINCT
WITH RECURSIVE hierarchy AS (SELECT id, name, manager_id
FROM employees
UNION ALL
SELECT id, name, manager_id
FROM employees
UNION
DISTINCT
SELECT id * 2, name, manager_id
from hierarchy
WHERE id < 10)
SELECT *
FROM hierarchy;

# Select with false condition
with recursive cte as (select * from employees where false union all select * from cte)
select *
from cte;

# Select with no matching rows
with recursive cte as (select * from employees where id > 100 union all select * from cte)
select *
from cte;

# Recursive CTE joined with a normal table. Predicate on the outside should not be pushed in
WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id
FROM employees
WHERE manager_id IS NULL
UNION ALL
SELECT e.id, e.name, e.manager_id
FROM employees e
INNER JOIN emp_cte cte ON e.manager_id = cte.id)
SELECT *
FROM emp_cte
where name = 'Engineer1';

# Query with a recursive CTE in a subquery
SELECT *
FROM (SELECT 1 UNION ALL SELECT 2) AS dt(a)
WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a * 0 AS b UNION ALL SELECT b + 1 FROM qn WHERE b = 0)
SELECT *
FROM qn
WHERE b = a);

# Join with recursive CTE inside a derived table using data from DUAL
SELECT e.id, e.name, e.manager_id, d.id AS cte_id
FROM employees e
JOIN (WITH RECURSIVE foo AS (SELECT 1 AS id
UNION ALL
SELECT id + 1
FROM foo
WHERE id < 5)
SELECT id
FROM foo) d ON e.id = d.id;

# Join with recursive CTE inside a derived table using data from employees table
SELECT e.id, e.name, e.manager_id, d.id AS cte_id
FROM employees e
JOIN (WITH RECURSIVE foo AS (SELECT id
FROM employees
WHERE manager_id IS NULL
UNION ALL
SELECT e.id
FROM employees e
JOIN foo f ON e.manager_id = f.id)
SELECT id
FROM foo) d ON e.id = d.id;

# Recursive CTE within a correlated subquery as a select expression
SELECT e.id,
e.name,
e.manager_id,
(SELECT MAX(cte_id)
FROM (WITH RECURSIVE foo AS (SELECT 1 AS cte_id
UNION ALL
SELECT cte_id + 1
FROM foo
WHERE cte_id < e.id)
SELECT cte_id
FROM foo) AS recursive_result) AS max_cte_id
FROM employees e;

# Recursive CTE used twice in the same query
WITH RECURSIVE employee_hierarchy AS (SELECT id, name, manager_id, 1 AS level
FROM employees
WHERE manager_id IS NULL
UNION ALL
SELECT e.id, e.name, e.manager_id, h.level + 1
FROM employees e
JOIN employee_hierarchy h ON e.manager_id = h.id)
SELECT h1.id AS employee_id,
h1.name AS employee_name,
h1.level AS employee_level,
h2.name AS manager_name,
h2.level AS manager_level
FROM employee_hierarchy h1
LEFT JOIN
employee_hierarchy h2 ON h1.manager_id = h2.id
ORDER BY h1.level, h1.id;
15 changes: 11 additions & 4 deletions go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {
addPred = stmt.AddWhere
case *sqlparser.Delete:
addPred = stmt.AddWhere
case nil:
// this would happen if we are adding a predicate on a dual query.
// we use this when building recursive CTE queries
sel := &sqlparser.Select{}
addPred = sel.AddWhere
qb.stmt = sel
default:
panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt))
}
Expand Down Expand Up @@ -236,10 +242,11 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) {
}
}

func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string) {
func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool) {
cteUnion := &sqlparser.Union{
Left: qb.stmt.(sqlparser.SelectStatement),
Right: other.stmt.(sqlparser.SelectStatement),
Left: qb.stmt.(sqlparser.SelectStatement),
Right: other.stmt.(sqlparser.SelectStatement),
Distinct: distinct,
}

qb.stmt = &sqlparser.Select{
Expand Down Expand Up @@ -719,7 +726,7 @@ func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) {
panic(err)
}

qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String())
qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct)
}

func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE, ou
panic(err)
}

return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID)
return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID, union.Distinct)
}

func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet {
Expand Down
13 changes: 11 additions & 2 deletions go/vt/vtgate/planbuilder/operators/cte_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,22 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator
}

func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route {
seedRoute, termRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term)
if seedRoute == nil || !sameKeyspace {
seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term)
if seedRoute == nil {
return nil
}

switch {
case a == dual:
return mergeCTE(ctx, seedRoute, termRoute, routingB, in)
case b == dual:
return mergeCTE(ctx, seedRoute, termRoute, routingA, in)
case !sameKeyspace:
return nil
case a == anyShard:
return mergeCTE(ctx, seedRoute, termRoute, routingB, in)
case b == anyShard:
return mergeCTE(ctx, seedRoute, termRoute, routingA, in)
case a == sharded && b == sharded:
return tryMergeCTESharded(ctx, seedRoute, termRoute, in)
default:
Expand Down Expand Up @@ -80,6 +88,7 @@ func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in
Term: newTerm,
LeftID: in.LeftID,
OuterID: in.OuterID,
Distinct: in.Distinct,
},
MergedWith: []*Route{term},
}
Expand Down
1 change: 0 additions & 1 deletion go/vt/vtgate/planbuilder/operators/join_merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Ro
lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute)

a, b := getRoutingType(routingA), getRoutingType(routingB)

return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace
}

Expand Down
35 changes: 20 additions & 15 deletions go/vt/vtgate/planbuilder/operators/recurse_cte.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package operators

import (
"fmt"
"slices"
"strings"

"golang.org/x/exp/maps"
Expand Down Expand Up @@ -56,6 +57,9 @@ type RecurseCTE struct {

// The OuterID is the id for this use of the CTE
OuterID semantics.TableSet

// Distinct is copied from the Distinct flag of the UNION that defines the CTE; when set, the result set should be distinct
Distinct bool
}

var _ Operator = (*RecurseCTE)(nil)
Expand All @@ -67,6 +71,7 @@ func newRecurse(
predicates []*plancontext.RecurseExpression,
horizon *Horizon,
leftID, outerID semantics.TableSet,
distinct bool,
) *RecurseCTE {
for _, pred := range predicates {
ctx.AddJoinPredicates(pred.Original, pred.RightExpr)
Expand All @@ -79,21 +84,18 @@ func newRecurse(
Horizon: horizon,
LeftID: leftID,
OuterID: outerID,
Distinct: distinct,
}
}

func (r *RecurseCTE) Clone(inputs []Operator) Operator {
return &RecurseCTE{
Seed: inputs[0],
Term: inputs[1],
Def: r.Def,
Predicates: r.Predicates,
Projections: r.Projections,
Vars: maps.Clone(r.Vars),
Horizon: r.Horizon,
LeftID: r.LeftID,
OuterID: r.OuterID,
}
klone := *r
klone.Seed = inputs[0]
klone.Term = inputs[1]
klone.Vars = maps.Clone(r.Vars)
klone.Predicates = slices.Clone(r.Predicates)
klone.Projections = slices.Clone(r.Projections)
return &klone
}

func (r *RecurseCTE) Inputs() []Operator {
Expand All @@ -106,8 +108,7 @@ func (r *RecurseCTE) SetInputs(operators []Operator) {
}

func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator {
r.Term = newFilter(r, e)
return r
return newFilter(r, e)
}

func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int {
Expand Down Expand Up @@ -162,13 +163,17 @@ func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.
}

func (r *RecurseCTE) ShortDescription() string {
distinct := ""
if r.Distinct {
distinct = "distinct "
}
if len(r.Vars) > 0 {
return fmt.Sprintf("%v", r.Vars)
return fmt.Sprintf("%s%v", distinct, r.Vars)
}
expressions := slice.Map(r.expressions(), func(expr *plancontext.RecurseExpression) string {
return sqlparser.String(expr.Original)
})
return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(expressions, ", "))
return fmt.Sprintf("%s%v %v", distinct, r.Def.Name, strings.Join(expressions, ", "))
}

func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy {
Expand Down
Loading

0 comments on commit e6843dc

Please sign in to comment.