diff --git a/go/test/endtoend/vtgate/queries/misc/misc_test.go b/go/test/endtoend/vtgate/queries/misc/misc_test.go index 7ab0fe7ef54..9c1aef2eb0d 100644 --- a/go/test/endtoend/vtgate/queries/misc/misc_test.go +++ b/go/test/endtoend/vtgate/queries/misc/misc_test.go @@ -19,6 +19,7 @@ package misc import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "testing" @@ -34,7 +35,7 @@ import ( "vitess.io/vitess/go/test/endtoend/utils" ) -func start(t *testing.T) (utils.MySQLCompare, func()) { +func start(t testing.TB) (utils.MySQLCompare, func()) { mcmp, err := utils.NewMySQLCompare(t, vtParams, mysqlParams) require.NoError(t, err) @@ -53,6 +54,96 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { } } +func BenchmarkValuesJoin(b *testing.B) { + mcmp, closer := start(b) + defer closer() + + type Rep struct { + QueriesRouted map[string]int `json:"QueriesRouted"` + } + + getQueriesRouted := func(thisB *testing.B) int { + _, response, _ := clusterInstance.VtgateProcess.MakeAPICall("/debug/vars") + r := Rep{} + err := json.Unmarshal([]byte(response), &r) + require.NoError(thisB, err) + + var res int + for _, c := range r.QueriesRouted { + res += c + } + return res + } + + b.ReportAllocs() + + lhsRowCount := 0 + rhsRowCount := 0 + + insertLHS := func(count int) { + for ; lhsRowCount < count; lhsRowCount++ { + mcmp.Exec(fmt.Sprintf("insert into t1(id1, id2) values (%d, %d)", lhsRowCount, lhsRowCount)) + } + } + insertRHS := func(count int) { + for ; rhsRowCount < count; rhsRowCount++ { + mcmp.Exec(fmt.Sprintf("insert into tbl(id, unq_col, nonunq_col) values (%d, %d, %d)", rhsRowCount, rhsRowCount, rhsRowCount)) + } + } + + testCases := []struct { + lhsRowCount int + rhsRowCount int + }{ + { + lhsRowCount: 20, + rhsRowCount: 10, + }, + { + lhsRowCount: 50, + rhsRowCount: 25, + }, + { + lhsRowCount: 100, + rhsRowCount: 50, + }, + { + lhsRowCount: 200, + rhsRowCount: 100, + }, + { + lhsRowCount: 500, + rhsRowCount: 250, + }, + { + lhsRowCount: 1000, + 
rhsRowCount: 500, + }, + { + lhsRowCount: 2000, + rhsRowCount: 1000, + }, + } + + var previousQueriesRoutedSum int + for _, testCase := range testCases { + insertLHS(testCase.lhsRowCount) + insertRHS(testCase.rhsRowCount) + + b.Run(fmt.Sprintf("LHS(%d) RHS(%d)", testCase.lhsRowCount, testCase.rhsRowCount), func(b *testing.B) { + for range b.N { + mcmp.Exec("select t1.id1, tbl.id from t1, tbl where t1.id2 = tbl.nonunq_col") + } + b.StopTimer() + + totalQueriesRouted := getQueriesRouted(b) + queriesRouted := totalQueriesRouted - previousQueriesRoutedSum + previousQueriesRoutedSum = totalQueriesRouted + b.ReportMetric(float64(queriesRouted/b.N), "queries_routed/op") + }) + } +} + func TestBitVals(t *testing.T) { mcmp, closer := start(t) defer closer() diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index dff6f60e531..b48e2a4af92 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -49,6 +49,8 @@ const ( DirectiveAllowHashJoin = "ALLOW_HASH_JOIN" // DirectiveQueryPlanner lets the user specify per query which planner should be used DirectiveQueryPlanner = "PLANNER" + // DirectiveAllowValuesJoin allows the planner to use VALUES JOINS when possible. + DirectiveAllowValuesJoin = "ALLOW_VALUES_JOIN" // DirectiveVExplainRunDMLQueries tells vexplain queries/all that it is okay to also run the query. DirectiveVExplainRunDMLQueries = "EXECUTE_DML_QUERIES" // DirectiveConsolidator enables the query consolidator. 
@@ -554,6 +556,10 @@ func AllowScatterDirective(stmt Statement) bool { return checkDirective(stmt, DirectiveAllowScatter) } +func AllowValuesJoinDirective(stmt Statement) bool { + return checkDirective(stmt, DirectiveAllowValuesJoin) +} + func checkDirective(stmt Statement, key string) bool { cmt, ok := stmt.(Commented) if ok { diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 4a3365623fb..a65f4383440 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -1438,12 +1438,12 @@ func (cached *ValuesJoin) CachedSize(alloc bool) int64 { if cc, ok := cached.Right.(cachedObject); ok { size += cc.CachedSize(true) } - // field Vars []int + // field CopyColumnsToRHS []int { - size += hack.RuntimeAllocSize(int64(cap(cached.Vars)) * int64(8)) + size += hack.RuntimeAllocSize(int64(cap(cached.CopyColumnsToRHS)) * int64(8)) } - // field RowConstructorArg string - size += hack.RuntimeAllocSize(int64(len(cached.RowConstructorArg))) + // field BindVarName string + size += hack.RuntimeAllocSize(int64(len(cached.BindVarName))) // field Cols []int { size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(8)) diff --git a/go/vt/vtgate/engine/join_values_test.go b/go/vt/vtgate/engine/join_values_test.go deleted file mode 100644 index 068259a4e3e..00000000000 --- a/go/vt/vtgate/engine/join_values_test.go +++ /dev/null @@ -1,101 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package engine - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/sqltypes" - querypb "vitess.io/vitess/go/vt/proto/query" -) - -func TestJoinValuesExecute(t *testing.T) { - - /* - select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 - LHS: select col1, col2, col3 from left - RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4 - */ - - leftPrim := &fakePrimitive{ - useNewPrintBindVars: true, - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3", - "int64|varchar|varchar", - ), - "1|a|aa", - "2|b|bb", - "3|c|cc", - "4|d|dd", - ), - }, - } - rightPrim := &fakePrimitive{ - useNewPrintBindVars: true, - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col5|col6|id", - "varchar|varchar|int64", - ), - "d|dd|0", - "e|ee|1", - "f|ff|2", - "g|gg|3", - ), - }, - } - - bv := map[string]*querypb.BindVariable{ - "a": sqltypes.Int64BindVariable(10), - } - - vjn := &ValuesJoin{ - Left: leftPrim, - Right: rightPrim, - Vars: []int{0}, - RowConstructorArg: "v", - Cols: []int{-1, -2, -3, -1, 1, 2}, - ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"}, - } - - r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) - require.NoError(t, err) - leftPrim.ExpectLog(t, []string{ - `Execute a: type:INT64 value:"10" true`, - }) - rightPrim.ExpectLog(t, []string{ - `Execute a: type:INT64 value:"10" v: [[INT64(0) INT64(1)][INT64(1) INT64(2)][INT64(2) INT64(3)][INT64(3) INT64(4)]] true`, - }) - - result := sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3|col4|col5|col6", - "int64|varchar|varchar|int64|varchar|varchar", - ), - "1|a|aa|1|d|dd", - "2|b|bb|2|e|ee", - "3|c|cc|3|f|ff", - "4|d|dd|4|g|gg", - ) - expectResult(t, r, result) -} diff --git a/go/vt/vtgate/engine/join_values.go 
b/go/vt/vtgate/engine/values_join.go similarity index 72% rename from go/vt/vtgate/engine/join_values.go rename to go/vt/vtgate/engine/values_join.go index 7b4fc19e908..ce6711633c9 100644 --- a/go/vt/vtgate/engine/join_values.go +++ b/go/vt/vtgate/engine/values_join.go @@ -18,6 +18,7 @@ package engine import ( "context" + "fmt" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -34,10 +35,24 @@ type ValuesJoin struct { // of the Join. They can be any primitive. Left, Right Primitive - Vars []int - RowConstructorArg string - Cols []int - ColNames []string + // The name for the bind var containing the tuple-of-tuples being sent to the RHS + BindVarName string + + // RowID tells whether the LHS row ID is sent along to the RHS, which is needed to use columns from the LHS in the output + // If RowID is false, the output will be the same as the RHS, so the following fields are ignored - Cols, ColNames. + // We copy everything from the LHS to the RHS in this case, and column names are taken from the RHS. + RowID bool + + // CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS + // []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset + CopyColumnsToRHS []int + + // Cols tells us which side the output columns come from: + // negative numbers are offsets to the left, and positive to the right + Cols []int + + // ColNames are the output column names + ColNames []string } // TryExecute performs a non-streaming exec. 
@@ -47,7 +62,7 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } bv := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, + Type: querypb.Type_ROW_TUPLE, } if len(lresult.Rows) == 0 && wantfields { // If there are no rows, we still need to construct a single row @@ -60,27 +75,43 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars } bv.Values = append(bv.Values, sqltypes.TupleToProto(vals)) - bindVars[jv.RowConstructorArg] = bv + bindVars[jv.BindVarName] = bv + if jv.RowID { + panic("implement me") + } return jv.Right.GetFields(ctx, vcursor, bindVars) } + rowSize := len(jv.CopyColumnsToRHS) + if jv.RowID { + rowSize++ // +1 since we add the row ID + } for i, row := range lresult.Rows { - newRow := make(sqltypes.Row, 0, len(jv.Vars)+1) // +1 since we always add the row ID - newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID - - for _, loffset := range jv.Vars { - newRow = append(newRow, row[loffset]) + newRow := make(sqltypes.Row, 0, rowSize) + + if jv.RowID { + for _, loffset := range jv.CopyColumnsToRHS { + newRow = append(newRow, row[loffset]) + } + newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID + } else { + newRow = row } bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow)) } - bindVars[jv.RowConstructorArg] = bv + bindVars[jv.BindVarName] = bv rresult, err := vcursor.ExecutePrimitive(ctx, jv.Right, bindVars, wantfields) if err != nil { return nil, err } + if !jv.RowID { + // if we are not using the row ID, we can just return the result from the RHS + return rresult, nil + } + result := &sqltypes.Result{} result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols) @@ -143,8 +174,9 @@ func (jv *ValuesJoin) description() PrimitiveDescription { OperatorType: "Join", Variant: "Values", Other: map[string]any{ - "ValuesArg": jv.RowConstructorArg, - "Vars": jv.Vars, + "BindVarName": jv.BindVarName, + 
"CopyColumnsToRHS": jv.CopyColumnsToRHS, + "RowID": fmt.Sprintf("%t", jv.RowID), }, } } diff --git a/go/vt/vtgate/engine/values_join_test.go b/go/vt/vtgate/engine/values_join_test.go new file mode 100644 index 00000000000..29297d6aa32 --- /dev/null +++ b/go/vt/vtgate/engine/values_join_test.go @@ -0,0 +1,149 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + querypb "vitess.io/vitess/go/vt/proto/query" + + "vitess.io/vitess/go/sqltypes" +) + +func TestJoinValuesExecute(t *testing.T) { + + type testCase struct { + rowID bool + cols []int + CopyColumnsToRHS []int + rhsResults []*sqltypes.Result + expectedRHSLog []string + } + + testCases := []testCase{ + { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col5, col6, id from (values row(1,2), ...) 
left(id,col1) join right on left.col1 = right.col4 + */ + + rowID: true, + cols: []int{-1, -2, -3, -1, 1, 2}, + CopyColumnsToRHS: []int{0}, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col5|col6|id", + "varchar|varchar|int64", + ), + "d|dd|0", + "e|ee|1", + "f|ff|2", + "g|gg|3", + ), + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) INT64(0)][INT64(2) INT64(1)][INT64(3) INT64(2)][INT64(4) INT64(3)]] true`, + }, + }, { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col1, col2, col3, col4, col5, col6 from (values row(1,2,3), ...) left(col1,col2,col3) join right on left.col1 = right.col4 + */ + + rowID: false, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ), + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) VARCHAR("a") VARCHAR("aa")][INT64(2) VARCHAR("b") VARCHAR("bb")][INT64(3) VARCHAR("c") VARCHAR("cc")][INT64(4) VARCHAR("d") VARCHAR("dd")]] true`, + }, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("rowID:%t", tc.rowID), func(t *testing.T) { + leftPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + ), + }, + } + rightPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: tc.rhsResults, + } + + bv := map[string]*querypb.BindVariable{ + "a": sqltypes.Int64BindVariable(10), + } + + vjn := &ValuesJoin{ + Left: leftPrim, + Right: rightPrim, + CopyColumnsToRHS: tc.CopyColumnsToRHS, + BindVarName: "v", + Cols: tc.cols, + ColNames: 
[]string{"col1", "col2", "col3", "col4", "col5", "col6"}, + RowID: tc.rowID, + } + + r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute a: type:INT64 value:"10" true`, + }) + rightPrim.ExpectLog(t, tc.expectedRHSLog) + + result := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ) + expectResult(t, r, result) + }) + } +} diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 1b3313b540a..0315e524b3a 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -49,8 +49,6 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato return transformSubQuery(ctx, op) case *operators.Filter: return transformFilter(ctx, op) - case *operators.Horizon: - panic("should have been solved in the operator") case *operators.Projection: return transformProjection(ctx, op) case *operators.Limit: @@ -79,11 +77,37 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato return transformRecurseCTE(ctx, op) case *operators.PercentBasedMirror: return transformPercentBasedMirror(ctx, op) + case *operators.ValuesJoin: + return transformValuesJoin(ctx, op) + case *operators.Values: + panic("should have been pushed under a route") + case *operators.Horizon: + panic("should have been solved in the operator") } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToPrimitive)", op)) } +func transformValuesJoin(ctx *plancontext.PlanningContext, op *operators.ValuesJoin) (engine.Primitive, error) { + lhs, err := transformToPrimitive(ctx, op.LHS) + if err != nil { + return nil, err + } + rhs, err := transformToPrimitive(ctx, op.RHS) + 
if err != nil { + return nil, err + } + + return &engine.ValuesJoin{ + Left: lhs, + Right: rhs, + CopyColumnsToRHS: op.CopyColumnsToRHS, + BindVarName: op.ValuesDestination, + Cols: op.Columns, + ColNames: op.ColumnName, + }, nil +} + func transformPercentBasedMirror(ctx *plancontext.PlanningContext, op *operators.PercentBasedMirror) (engine.Primitive, error) { primitive, err := transformToPrimitive(ctx, op.Operator()) if err != nil { @@ -172,7 +196,7 @@ func transformInsertionSelection(ctx *plancontext.PlanningContext, op *operators return nil, vterrors.VT13001(fmt.Sprintf("Incorrect type encountered: %T (transformInsertionSelection)", op.Insert)) } - stmt, dmlOp, err := operators.ToSQL(ctx, rb.Source) + stmt, dmlOp, err := operators.ToAST(ctx, rb.Source) if err != nil { return nil, err } @@ -579,7 +603,7 @@ func getHints(cmt *sqlparser.ParsedComments) *queryHints { } func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) (engine.Primitive, error) { - stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) + stmt, dmlOp, err := operators.ToAST(ctx, op.Source) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index b11f49fe936..1fb5df92f12 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -509,7 +509,7 @@ func splitGroupingToLeftAndRight( rhs.addGrouping(ctx, groupBy) columns.addRight(groupBy.Inner) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): - jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID) for _, lhsExpr := range jc.LHSExprs { e := lhsExpr.Expr lhs.addGrouping(ctx, NewGroupBy(e)) diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 2297227a22d..18399997df0 100644 --- 
a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -146,7 +146,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql rhs := aj.RHS predicates := sqlparser.SplitAndExpression(nil, expr) for _, pred := range predicates { - col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS)) + col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS)) aj.JoinPredicates.add(col) ctx.AddJoinPredicates(pred, col.RHSExpr) rhs = rhs.AddPredicate(ctx, col.RHSExpr) @@ -199,7 +199,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq case deps.IsSolvedBy(rhs): col.RHSExpr = e case deps.IsSolvedBy(both): - col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) + col = breakApplyJoinExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) default: panic(vterrors.VT13001(fmt.Sprintf("expression depends on tables outside this join: %s", sqlparser.String(e)))) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 009f2c7c265..b536ee65791 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -156,7 +156,7 @@ func (jpc *joinPredicateCollector) inspectPredicate( // then we can use this predicate to connect the subquery to the outer query if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) { jpc.predicates = append(jpc.predicates, predicate) - jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, predicate, jpc.outerID) jpc.joinColumns = append(jpc.joinColumns, jc) pred = jc.RHSExpr } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index f42ec87404d..03dd5177ba0 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -22,9 +22,9 @@ import ( 
"vitess.io/vitess/go/vt/vtgate/semantics" ) -// breakExpressionInLHSandRHS takes an expression and +// breakApplyJoinExpressionInLHSandRHS takes an expression and // extracts the parts that are coming from one of the sides into `ColName`s that are needed -func breakExpressionInLHSandRHS( +func breakApplyJoinExpressionInLHSandRHS( ctx *plancontext.PlanningContext, expr sqlparser.Expr, lhs semantics.TableSet, @@ -129,3 +129,28 @@ func getFirstSelect(selStmt sqlparser.TableStatement) *sqlparser.Select { } return firstSelect } + +func breakValuesJoinExpressionInLHS(ctx *plancontext.PlanningContext, + expr sqlparser.Expr, + lhs semantics.TableSet, +) (result valuesJoinColumn) { + result.Original = sqlparser.Clone(expr) + result.PureLHS = true + result.RHS = expr + _ = sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool { // TODO: rewrite to use Walk instead (no pun intended, promise!) + node := cursor.Node() + col, ok := node.(*sqlparser.ColName) + if !ok { + return true + } + if ctx.SemTable.RecursiveDeps(col) == lhs { + result.LHS = append(result.LHS, col) + // TODO: Fine all the LHS columns, and + // rewrite the expression to use the value join name and the column. + } else { + result.PureLHS = false + } + return true + }, nil) + return +} diff --git a/go/vt/vtgate/planbuilder/operators/expressions_test.go b/go/vt/vtgate/planbuilder/operators/expressions_test.go new file mode 100644 index 00000000000..d162ee9a693 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/expressions_test.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func TestSplitComplexPredicateToLHS(t *testing.T) { + ast, err := sqlparser.NewTestParser().ParseExpr("l.foo + r.bar - l.baz / r.tata = 0") + require.NoError(t, err) + lID := semantics.SingleTableSet(0) + rID := semantics.SingleTableSet(1) + ctx := plancontext.CreateEmptyPlanningContext() + ctx.SemTable = semantics.EmptySemTable() + // simple sem analysis using the column prefix + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + col, ok := node.(*sqlparser.ColName) + if !ok { + return true, nil + } + if col.Qualifier.Name.String() == "l" { + ctx.SemTable.Recursive[col] = lID + } else { + ctx.SemTable.Recursive[col] = rID + } + return false, nil + }, ast) + + valuesJoinCols := breakValuesJoinExpressionInLHS(ctx, ast, lID) + nodes := slice.Map(valuesJoinCols.LHS, func(from sqlparser.Expr) string { + return sqlparser.String(from) + }) + + assert.Equal(t, []string{"l.foo", "l.baz"}, nodes) +} diff --git a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go index 1e15237f30d..2d54e012c7e 100644 --- a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go +++ b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go @@ -21,6 +21,8 @@ import ( "slices" "strings" + 
"vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -102,6 +104,10 @@ func (isr *InfoSchemaRouting) updateRoutingLogic(ctx *plancontext.PlanningContex return isr } +func (isr *InfoSchemaRouting) AddValuesTableID(id semantics.TableSet) { + panic(vterrors.VT13001("think about values and info schema routing")) +} + func (isr *InfoSchemaRouting) Cost() int { return 0 } diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 1673a7b68f2..ab79417c5c5 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -158,7 +158,7 @@ func addCTEPredicate( } func breakCTEExpressionInLhsAndRhs(ctx *plancontext.PlanningContext, pred sqlparser.Expr, lhsID semantics.TableSet) *plancontext.RecurseExpression { - col := breakExpressionInLHSandRHS(ctx, pred, lhsID) + col := breakApplyJoinExpressionInLHSandRHS(ctx, pred, lhsID) lhsExprs := slice.Map(col.LHSExprs, func(bve BindVarExpr) plancontext.BindVarExpr { col, ok := bve.Expr.(*sqlparser.ColName) diff --git a/go/vt/vtgate/planbuilder/operators/misc_routing.go b/go/vt/vtgate/planbuilder/operators/misc_routing.go index 5cfb52c0248..517bcccd3aa 100644 --- a/go/vt/vtgate/planbuilder/operators/misc_routing.go +++ b/go/vt/vtgate/planbuilder/operators/misc_routing.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -77,6 +78,7 @@ func (tr *TargetedRouting) Clone() Routing { func (tr *TargetedRouting) updateRoutingLogic(_ *plancontext.PlanningContext, _ sqlparser.Expr) Routing { return tr } +func (tr *TargetedRouting) AddValuesTableID(semantics.TableSet) {} func (tr *TargetedRouting) Cost() int { return 1 @@ -102,6 +104,8 @@ func (n *NoneRouting) 
updateRoutingLogic(*plancontext.PlanningContext, sqlparser return n } +func (*NoneRouting) AddValuesTableID(semantics.TableSet) {} + func (n *NoneRouting) Cost() int { return 0 } @@ -129,6 +133,8 @@ func (rr *AnyShardRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp return rr } +func (tr *AnyShardRouting) AddValuesTableID(semantics.TableSet) {} + func (rr *AnyShardRouting) Cost() int { return 0 } @@ -166,6 +172,8 @@ func (dr *DualRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlparse return dr } +func (tr *DualRouting) AddValuesTableID(semantics.TableSet) {} + func (dr *DualRouting) Cost() int { return 0 } @@ -191,14 +199,16 @@ func (sr *SequenceRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp return sr } -func (sr *SequenceRouting) Cost() int { +func (*SequenceRouting) AddValuesTableID(semantics.TableSet) {} + +func (*SequenceRouting) Cost() int { return 0 } -func (sr *SequenceRouting) OpCode() engine.Opcode { +func (*SequenceRouting) OpCode() engine.Opcode { return engine.Next } -func (sr *SequenceRouting) Keyspace() *vindexes.Keyspace { +func (*SequenceRouting) Keyspace() *vindexes.Keyspace { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/op_to_ast.go b/go/vt/vtgate/planbuilder/operators/op_to_ast.go new file mode 100644 index 00000000000..1f1798f7d13 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/op_to_ast.go @@ -0,0 +1,397 @@ +/* +Copyright 2022 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "fmt" + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +func ToAST(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) { + defer PanicHandler(&err) + + q := &queryBuilder{ctx: ctx} + buildAST(op, q) + if ctx.SemTable != nil { + q.sortTables() + } + return q.stmt, q.dmlOperator, nil +} + +// buildAST recursively builds the query into an AST, from an operator tree +func buildAST(op Operator, qb *queryBuilder) { + switch op := op.(type) { + case *Table: + buildTable(op, qb) + case *Projection: + buildProjection(op, qb) + case *ApplyJoin: + buildApplyJoin(op, qb) + case *Filter: + buildFilter(op, qb) + case *Horizon: + if op.TableId != nil { + buildDerived(op, qb) + return + } + buildHorizon(op, qb) + case *Limit: + buildLimit(op, qb) + case *Ordering: + buildOrdering(op, qb) + case *Aggregator: + buildAggregation(op, qb) + case *Union: + buildUnion(op, qb) + case *Distinct: + buildDistinct(op, qb) + case *Update: + buildUpdate(op, qb) + case *Delete: + buildDelete(op, qb) + case *Insert: + buildDML(op, qb) + case *RecurseCTE: + buildRecursiveCTE(op, qb) + case *Values: + buildValues(op, qb) + case *ValuesJoin: + buildValuesJoin(op, qb) + default: + panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) + } +} + +func buildDistinct(op *Distinct, qb *queryBuilder) { + buildAST(op.Source, qb) + statement := qb.asSelectStatement() + d, ok := statement.(sqlparser.Distinctable) + if !ok { + panic(vterrors.VT13001("expected a select statement with distinct")) + } + d.MakeDistinct() +} + +func buildValuesJoin(op *ValuesJoin, qb *queryBuilder) { + qb.ctx.SkipValuesArgument(op.ValuesDestination) + buildAST(op.LHS, qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.RHS, qbR) + qb.joinWith(qbR, nil, 
sqlparser.NormalJoinType) +} + +func buildValues(op *Values, qb *queryBuilder) { + buildAST(op.Source, qb) + if qb.ctx.IsValuesArgumentSkipped(op.Name) { + return + } + + expr := &sqlparser.DerivedTable{ + Select: &sqlparser.ValuesStatement{ + ListArg: sqlparser.NewListArg(op.Name), + }, + } + + deps := semantics.EmptyTableSet() + for _, ae := range qb.ctx.ValuesJoinColumns[op.Name] { + deps = deps.Merge(qb.ctx.SemTable.RecursiveDeps(ae.Expr)) + } + + qb.addTableExpr(op.Name, op.Name, TableID(op), expr, nil, op.getColumnNamesFromCtx(qb.ctx)) +} + +func buildDelete(op *Delete, qb *queryBuilder) { + qb.stmt = &sqlparser.Delete{ + Ignore: op.Ignore, + Targets: sqlparser.TableNames{op.Target.Name}, + } + buildAST(op.Source, qb) + + qb.dmlOperator = op +} + +func buildUpdate(op *Update, qb *queryBuilder) { + updExprs := getUpdateExprs(op) + upd := &sqlparser.Update{ + Ignore: op.Ignore, + Exprs: updExprs, + } + qb.stmt = upd + qb.dmlOperator = op + buildAST(op.Source, qb) +} + +func getUpdateExprs(op *Update) sqlparser.UpdateExprs { + updExprs := make(sqlparser.UpdateExprs, 0, len(op.Assignments)) + for _, se := range op.Assignments { + updExprs = append(updExprs, &sqlparser.UpdateExpr{ + Name: se.Name, + Expr: se.Expr.EvalExpr, + }) + } + return updExprs +} + +type OpWithAST interface { + Operator + Statement() sqlparser.Statement +} + +func buildDML(op OpWithAST, qb *queryBuilder) { + qb.stmt = op.Statement() + qb.dmlOperator = op +} + +func buildAggregation(op *Aggregator, qb *queryBuilder) { + buildAST(op.Source, qb) + + qb.clearProjections() + + cols := op.GetColumns(qb.ctx) + for _, column := range cols { + qb.addProjection(column) + } + + for _, by := range op.Grouping { + qb.addGroupBy(by.Inner) + simplified := by.Inner + if by.WSOffset != -1 { + qb.addGroupBy(weightStringFor(simplified)) + } + } + if op.WithRollup { + qb.setWithRollup() + } + + if op.DT != nil { + sel := qb.asSelectStatement() + qb.stmt = nil + qb.addTableExpr(op.DT.Alias, op.DT.Alias, 
TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.DT.Columns) + } +} + +func buildOrdering(op *Ordering, qb *queryBuilder) { + buildAST(op.Source, qb) + + for _, order := range op.Order { + qb.asOrderAndLimit().AddOrder(order.Inner) + } +} + +func buildLimit(op *Limit, qb *queryBuilder) { + buildAST(op.Source, qb) + qb.asOrderAndLimit().SetLimit(op.AST) +} + +func buildTable(op *Table, qb *queryBuilder) { + if !qb.includeTable(op) { + return + } + + dbName := "" + + if op.QTable.IsInfSchema { + dbName = op.QTable.Table.Qualifier.String() + } + qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints) + for _, pred := range op.QTable.Predicates { + qb.addPredicate(pred) + } + for _, name := range op.Columns { + qb.addProjection(&sqlparser.AliasedExpr{Expr: name}) + } +} + +func buildProjection(op *Projection, qb *queryBuilder) { + buildAST(op.Source, qb) + + _, isSel := qb.stmt.(*sqlparser.Select) + if isSel { + qb.clearProjections() + cols := op.GetSelectExprs(qb.ctx) + for _, column := range cols { + qb.addProjection(column) + } + } + + // if the projection is on derived table, we use the select we have + // created above and transform it into a derived table + if op.DT != nil { + sel := qb.asSelectStatement() + qb.stmt = nil + qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.DT.Columns) + } + + if !isSel { + for _, column := range op.GetSelectExprs(qb.ctx) { + qb.addProjection(column) + } + } +} + +func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { + predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) 
+ + buildAST(op.LHS, qb) + + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.RHS, qbR) + + switch { + // if we have a recursive cte, we might be missing a statement from one of the sides + case qbR.stmt == nil: + // do nothing + case qb.stmt == nil: + qb.stmt = qbR.stmt + default: + qb.joinWith(qbR, pred, op.JoinType) + } +} + +func buildUnion(op *Union, qb *queryBuilder) { + // the first input is built first + buildAST(op.Sources[0], qb) + + for i, src := range op.Sources { + if i == 0 { + continue + } + + // now we can go over the remaining inputs and UNION them together + qbOther := &queryBuilder{ctx: qb.ctx} + buildAST(src, qbOther) + qb.unionWith(qbOther, op.distinct) + } +} + +func buildFilter(op *Filter, qb *queryBuilder) { + buildAST(op.Source, qb) + + for _, pred := range op.Predicates { + qb.addPredicate(pred) + } +} + +func buildDerived(op *Horizon, qb *queryBuilder) { + buildAST(op.Source, qb) + + sqlparser.RemoveKeyspaceInCol(op.Query) + + stmt := qb.stmt + qb.stmt = nil + switch sel := stmt.(type) { + case *sqlparser.Select: + buildDerivedSelect(op, qb, sel) + return + case *sqlparser.Union: + buildDerivedUnion(op, qb, sel) + return + } + panic(fmt.Sprintf("unknown select statement type: %T", stmt)) +} + +func buildDerivedUnion(op *Horizon, qb *queryBuilder, union *sqlparser.Union) { + opQuery, ok := op.Query.(*sqlparser.Union) + if !ok { + panic(vterrors.VT12001("Horizon contained SELECT but statement was UNION")) + } + + union.Limit = opQuery.Limit + union.OrderBy = opQuery.OrderBy + union.Distinct = opQuery.Distinct + + qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: union, + }, nil, op.ColumnAliases) +} + +func buildDerivedSelect(op *Horizon, qb *queryBuilder, sel *sqlparser.Select) { + opQuery, ok := op.Query.(*sqlparser.Select) + if !ok { + panic(vterrors.VT12001("Horizon contained UNION but statement was SELECT")) + } + sel.Limit = opQuery.Limit + sel.OrderBy = opQuery.OrderBy + sel.GroupBy = opQuery.GroupBy 
+ sel.Having = mergeHaving(sel.Having, opQuery.Having) + sel.SelectExprs = opQuery.SelectExprs + sel.Distinct = opQuery.Distinct + qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ + Select: sel, + }, nil, op.ColumnAliases) + for _, col := range op.Columns { + qb.addProjection(&sqlparser.AliasedExpr{Expr: col}) + } +} + +func buildHorizon(op *Horizon, qb *queryBuilder) { + buildAST(op.Source, qb) + stripDownQuery(op.Query, qb.asSelectStatement()) + sqlparser.RemoveKeyspaceInCol(qb.stmt) +} + +func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { + predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { + // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done + err := qb.ctx.SkipJoinPredicates(jc.Original) + if err != nil { + panic(err) + } + return jc.Original + }) + pred := sqlparser.AndExpressions(predicates...) + buildAST(op.Seed(), qb) + qbR := &queryBuilder{ctx: qb.ctx} + buildAST(op.Term(), qbR) + qbR.addPredicate(pred) + infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) + if err != nil { + panic(err) + } + + qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct, op.Def.Columns) +} + +func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { + switch { + case h1 == nil && h2 == nil: + return nil + case h1 == nil: + return h2 + case h2 == nil: + return h1 + default: + h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr) + return h1 + } +} diff --git a/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go b/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go new file mode 100644 index 00000000000..b760df8fc7e --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/op_to_ast_test.go @@ -0,0 +1,117 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +func TestToSQLValues(t *testing.T) { + ctx := plancontext.CreateEmptyPlanningContext() + name := "toto" + ctx.ValuesJoinColumns[name] = []*sqlparser.AliasedExpr{{Expr: sqlparser.NewColName("user_id")}} + + tableName := sqlparser.NewTableName("x") + tableColumn := sqlparser.NewColName("id") + source := &Table{ + QTable: &QueryTable{ + Table: tableName, + Alias: sqlparser.NewAliasedTableExpr(tableName, ""), + }, + Columns: []*sqlparser.ColName{tableColumn}, + } + op := &Values{ + unaryOperator: newUnaryOp(source), + Name: name, + } + + stmt, _, err := ToAST(ctx, op) + require.NoError(t, err) + require.Equal(t, "select id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) + + // Now do the same test but with a projection on top + proj := newAliasedProjection(op) + proj.addUnexploredExpr(sqlparser.NewAliasedExpr(tableColumn, ""), tableColumn) + + userIdColName := sqlparser.NewColNameWithQualifier("user_id", sqlparser.NewTableName("t")) + proj.addUnexploredExpr( + sqlparser.NewAliasedExpr(userIdColName, ""), + userIdColName, + ) + + stmt, _, err = ToAST(ctx, proj) + require.NoError(t, err) + require.Equal(t, "select id, t.user_id from x, (values ::toto) as t(user_id)", sqlparser.String(stmt)) +} + +func TestToSQLValuesJoin(t *testing.T) { + // Build a SQL AST from a values join that has been pushed under a route + ctx := plancontext.CreateEmptyPlanningContext() + parser := 
sqlparser.NewTestParser() + + lhsTableName := sqlparser.NewTableName("x") + lhsTableColumn := sqlparser.NewColName("id") + lhsFilterPred, err := parser.ParseExpr("x.id = 42") + require.NoError(t, err) + + LHS := &Filter{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: lhsTableName, + Alias: sqlparser.NewAliasedTableExpr(lhsTableName, ""), + }, + Columns: []*sqlparser.ColName{lhsTableColumn}, + }), + Predicates: []sqlparser.Expr{lhsFilterPred}, + } + + const argumentName = "v" + ctx.ValuesJoinColumns[argumentName] = []*sqlparser.AliasedExpr{{Expr: sqlparser.NewColName("user_id")}} + rhsTableName := sqlparser.NewTableName("y") + rhsTableColumn := sqlparser.NewColName("tata") + rhsFilterPred, err := parser.ParseExpr("y.tata = 42") + require.NoError(t, err) + rhsJoinFilterPred, err := parser.ParseExpr("y.tata = x.id") + require.NoError(t, err) + + RHS := &Filter{ + unaryOperator: newUnaryOp(&Values{ + unaryOperator: newUnaryOp(&Table{ + QTable: &QueryTable{ + Table: rhsTableName, + Alias: sqlparser.NewAliasedTableExpr(rhsTableName, ""), + }, + Columns: []*sqlparser.ColName{rhsTableColumn}, + }), + Name: lhsTableName.Name.String(), + }), + Predicates: []sqlparser.Expr{rhsFilterPred, rhsJoinFilterPred}, + } + + vj := &ValuesJoin{ + binaryOperator: newBinaryOp(LHS, RHS), + } + + stmt, _, err := ToAST(ctx, vj) + require.NoError(t, err) + require.Equal(t, "select id, tata from x, y where x.id = 42 and y.tata = 42 and y.tata = x.id", sqlparser.String(stmt)) +} diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index eb6c42b8724..32c4f77e6d7 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -34,6 +34,7 @@ type ( const ( physicalTransform Phase = iota initialPlanning + rewriteApplyJoin pullDistinctFromUnion delegateAggregation recursiveCTEHorizons @@ -50,6 +51,8 @@ func (p Phase) String() string { return "physicalTransform" case 
initialPlanning: return "initial horizon planning optimization" + case rewriteApplyJoin: + return "rewrite ApplyJoin to ValuesJoin" case pullDistinctFromUnion: return "pull distinct from UNION" case delegateAggregation: @@ -69,8 +72,11 @@ func (p Phase) String() string { } } -func (p Phase) shouldRun(s semantics.QuerySignature) bool { +func (p Phase) shouldRun(ctx *plancontext.PlanningContext) bool { + s := ctx.SemTable.QuerySignature switch p { + case rewriteApplyJoin: + return ctx.AllowValuesJoin case pullDistinctFromUnion: return s.Union case delegateAggregation: @@ -85,6 +91,7 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool { return s.SubQueries case dmlWithInput: return s.DML + default: return true } @@ -106,11 +113,117 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator { return settleSubqueries(ctx, op) case dmlWithInput: return findDMLAboveRoute(ctx, op) + case rewriteApplyJoin: + return rewriteApplyToValues(ctx, op) + default: return op } } +func rewriteApplyToValues(ctx *plancontext.PlanningContext, op Operator) Operator { + var skipped []sqlparser.Expr + isSkipped := func(expr sqlparser.Expr) bool { + for _, skip := range skipped { + if ctx.SemTable.EqualsExpr(expr, skip) { + return true + } + } + return false + } + + // Traverse the operator tree to convert ApplyJoin to ValuesJoin. + // Then add a Values node to the RHS of the new ValuesJoin, + // and usually a filter containing the join predicates is placed there. 
+ visit := func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + aj, ok := op.(*ApplyJoin) + if !ok { + return op, NoRewrite + } + + vj, valuesTableID := newValuesJoin(ctx, aj.LHS, aj.RHS, aj.JoinType) + if vj == nil { + return op, NoRewrite + } + + for _, column := range aj.JoinColumns.columns { + vj.AddColumn(ctx, true, false, aeWrap(column.Original)) + } + + for _, pred := range aj.JoinPredicates.columns { + skipped = append(skipped, pred.RHSExpr) + err := ctx.SkipJoinPredicates(pred.Original) + if err != nil { + panic(err) + } + + newOriginal := sqlparser.Rewrite(pred.Original, nil, func(cursor *sqlparser.Cursor) bool { + col, isCol := cursor.Node().(*sqlparser.ColName) + if !isCol || ctx.SemTable.RecursiveDeps(col) != valuesTableID { + return true + } + + newCol := &sqlparser.ColName{ + Name: sqlparser.NewIdentifierCI(getValuesJoinColName(ctx, vj.ValuesDestination, valuesTableID, col)), + Qualifier: sqlparser.NewTableName(vj.ValuesDestination), + } + ctx.SemTable.CopyExprInfo(pred.Original, newCol) + cursor.Replace(newCol) + return true + }) + + vj.AddJoinPredicate(ctx, newOriginal.(sqlparser.Expr)) + } + + return vj, Rewrote("rewrote ApplyJoin to ValuesJoin") + } + + shouldVisit := func(op Operator) VisitRule { + rb, ok := op.(*Route) + if !ok { + return VisitChildren + } + + routing, ok := rb.Routing.(*ShardedRouting) + if !ok { + return SkipChildren + } + + // We need to skip the predicates that are already pushed down to the mysql - + // we will push down the JoinValues predicates, and they will be used for routing + var preds []sqlparser.Expr + for _, pred := range routing.SeenPredicates { + if !isSkipped(pred) { + preds = append(preds, pred) + } + } + routing.SeenPredicates = preds + + rb.Routing = routing.resetRoutingLogic(ctx) + return SkipChildren + } + + return TopDown(op, TableID, visit, shouldVisit) +} + +func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) 
(*ValuesJoin, semantics.TableSet) { + if !joinType.IsInner() { + return nil, semantics.EmptyTableSet() + } + + bindVariableName := ctx.ReservedVars.ReserveVariable("values") + ctx.ValueJoins[bindVariableName] = bindVariableName + v := &Values{ + unaryOperator: newUnaryOp(rhs), + Name: bindVariableName, + TableID: TableID(lhs), + } + return &ValuesJoin{ + binaryOperator: newBinaryOp(lhs, v), + ValuesDestination: bindVariableName, + }, v.TableID +} + type phaser struct { current Phase } @@ -124,7 +237,7 @@ func (p *phaser) next(ctx *plancontext.PlanningContext) Phase { p.current++ - if curr.shouldRun(ctx.SemTable.QuerySignature) { + if curr.shouldRun(ctx) { return curr } } diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/query_builder.go similarity index 51% rename from go/vt/vtgate/planbuilder/operators/SQL_builder.go rename to go/vt/vtgate/planbuilder/operators/query_builder.go index 248064b70fa..fc7358b8631 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/query_builder.go @@ -1,5 +1,5 @@ /* -Copyright 2022 The Vitess Authors. +Copyright 2025 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -21,7 +21,6 @@ import ( "slices" "sort" - "vitess.io/vitess/go/slice" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" @@ -45,17 +44,6 @@ func (qb *queryBuilder) asOrderAndLimit() sqlparser.OrderAndLimit { return qb.stmt.(sqlparser.OrderAndLimit) } -func ToSQL(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) { - defer PanicHandler(&err) - - q := &queryBuilder{ctx: ctx} - buildQuery(op, q) - if ctx.SemTable != nil { - q.sortTables() - } - return q.stmt, q.dmlOperator, nil -} - // includeTable will return false if the table is a CTE, and it is not merged // it will return true if the table is not a CTE or if it is a CTE and it is merged func (qb *queryBuilder) includeTable(op *Table) bool { @@ -142,7 +130,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { addPred = sel.AddWhere qb.stmt = sel default: - panic(fmt.Sprintf("cant add WHERE to %T, %s", qb.stmt, sqlparser.String(expr))) + panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) } for _, exp := range sqlparser.SplitAndExpression(nil, expr) { @@ -197,39 +185,6 @@ func (qb *queryBuilder) pushUnionInsideDerived() { qb.stmt = sel } -func unionSelects(exprs []sqlparser.SelectExpr) []sqlparser.SelectExpr { - var selectExprs []sqlparser.SelectExpr - for _, col := range exprs { - switch col := col.(type) { - case *sqlparser.AliasedExpr: - expr := sqlparser.NewColName(col.ColumnName()) - selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) - default: - selectExprs = append(selectExprs, col) - } - } - return selectExprs -} - -func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { - colName := column.Name.String() - firstSelect := getFirstSelect(sel) - exprs := firstSelect.GetColumns() - offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { - switch ae := expr.(type) { - case *sqlparser.StarExpr: - return true - case 
*sqlparser.AliasedExpr: - // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons - return ae.ColumnName() == colName - } - return false - }) - if offset == -1 { - panic(vterrors.VT12001(fmt.Sprintf("did not find column [%s] on UNION", sqlparser.String(column)))) - } -} - func (qb *queryBuilder) clearProjections() { sel, isSel := qb.stmt.(*sqlparser.Select) if !isSel { @@ -267,17 +222,6 @@ func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string qb.addTable("", name, alias, "", nil) } -type FromStatement interface { - GetFrom() []sqlparser.TableExpr - SetFrom([]sqlparser.TableExpr) - GetWherePredicate() sqlparser.Expr - SetWherePredicate(sqlparser.Expr) -} - -var _ FromStatement = (*sqlparser.Select)(nil) -var _ FromStatement = (*sqlparser.Update)(nil) -var _ FromStatement = (*sqlparser.Delete)(nil) - func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr, joinType sqlparser.JoinType) { stmt := qb.stmt.(FromStatement) otherStmt := other.stmt.(FromStatement) @@ -317,48 +261,6 @@ func (qb *queryBuilder) mergeWhereClauses(stmt, otherStmt FromStatement) { } } -func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { - var lhs sqlparser.TableExpr - fromClause := stmt.GetFrom() - if len(fromClause) == 1 { - lhs = fromClause[0] - } else { - lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} - } - var rhs sqlparser.TableExpr - otherFromClause := otherStmt.GetFrom() - if len(otherFromClause) == 1 { - rhs = otherFromClause[0] - } else { - rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} - } - - return &sqlparser.JoinTableExpr{ - LeftExpr: lhs, - RightExpr: rhs, - Join: joinType, - Condition: &sqlparser.JoinCondition{ - On: onCondition, - }, - } -} - -func (qb *queryBuilder) sortTables() { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - sel, isSel := 
node.(*sqlparser.Select) - if !isSel { - return true, nil - } - ts := &tableSorter{ - sel: sel, - tbl: qb.ctx.SemTable, - } - sort.Sort(ts) - return true, nil - }, qb.stmt) - -} - type tableSorter struct { sel *sqlparser.Select tbl *semantics.SemTable @@ -390,365 +292,87 @@ func (ts *tableSorter) Swap(i, j int) { ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i] } -func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { - switch expr := expr.(type) { - case *sqlparser.AliasedExpr: - sqlparser.RemoveKeyspaceInCol(expr.Expr) - case *sqlparser.StarExpr: - expr.TableName.Qualifier = sqlparser.NewIdentifierCS("") - } -} - -func stripDownQuery(from, to sqlparser.TableStatement) { - switch node := from.(type) { - case *sqlparser.Select: - toNode, ok := to.(*sqlparser.Select) - if !ok { - panic(vterrors.VT13001("AST did not match")) - } - toNode.Distinct = node.Distinct - toNode.GroupBy = node.GroupBy - toNode.Having = node.Having - toNode.OrderBy = node.OrderBy - toNode.Comments = node.Comments - toNode.Limit = node.Limit - toNode.SelectExprs = node.SelectExprs - for _, expr := range toNode.SelectExprs.Exprs { - removeKeyspaceFromSelectExpr(expr) - } - case *sqlparser.Union: - toNode, ok := to.(*sqlparser.Union) - if !ok { - panic(vterrors.VT13001("AST did not match")) - } - stripDownQuery(node.Left, toNode.Left) - stripDownQuery(node.Right, toNode.Right) - toNode.OrderBy = node.OrderBy - default: - panic(vterrors.VT13001(fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from))) - } -} - -// buildQuery recursively builds the query into an AST, from an operator tree -func buildQuery(op Operator, qb *queryBuilder) { - switch op := op.(type) { - case *Table: - buildTable(op, qb) - case *Projection: - buildProjection(op, qb) - case *ApplyJoin: - buildApplyJoin(op, qb) - case *Filter: - buildFilter(op, qb) - case *Horizon: - if op.TableId != nil { - buildDerived(op, qb) - return - } - buildHorizon(op, qb) - 
case *Limit: - buildLimit(op, qb) - case *Ordering: - buildOrdering(op, qb) - case *Aggregator: - buildAggregation(op, qb) - case *Union: - buildUnion(op, qb) - case *Distinct: - buildQuery(op.Source, qb) - statement := qb.asSelectStatement() - d, ok := statement.(sqlparser.Distinctable) - if !ok { - panic(vterrors.VT13001("expected a select statement with distinct")) +func (qb *queryBuilder) sortTables() { + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + sel, isSel := node.(*sqlparser.Select) + if !isSel { + return true, nil } - d.MakeDistinct() - case *Update: - buildUpdate(op, qb) - case *Delete: - buildDelete(op, qb) - case *Insert: - buildDML(op, qb) - case *RecurseCTE: - buildRecursiveCTE(op, qb) - default: - panic(vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op))) - } -} - -func buildDelete(op *Delete, qb *queryBuilder) { - qb.stmt = &sqlparser.Delete{ - Ignore: op.Ignore, - Targets: sqlparser.TableNames{op.Target.Name}, - } - buildQuery(op.Source, qb) - - qb.dmlOperator = op -} - -func buildUpdate(op *Update, qb *queryBuilder) { - updExprs := getUpdateExprs(op) - upd := &sqlparser.Update{ - Ignore: op.Ignore, - Exprs: updExprs, - } - qb.stmt = upd - qb.dmlOperator = op - buildQuery(op.Source, qb) -} - -func getUpdateExprs(op *Update) sqlparser.UpdateExprs { - updExprs := make(sqlparser.UpdateExprs, 0, len(op.Assignments)) - for _, se := range op.Assignments { - updExprs = append(updExprs, &sqlparser.UpdateExpr{ - Name: se.Name, - Expr: se.Expr.EvalExpr, - }) - } - return updExprs -} - -type OpWithAST interface { - Operator - Statement() sqlparser.Statement -} - -func buildDML(op OpWithAST, qb *queryBuilder) { - qb.stmt = op.Statement() - qb.dmlOperator = op -} - -func buildAggregation(op *Aggregator, qb *queryBuilder) { - buildQuery(op.Source, qb) - - qb.clearProjections() - - cols := op.GetColumns(qb.ctx) - for _, column := range cols { - qb.addProjection(column) - } - - for _, by := range 
op.Grouping { - qb.addGroupBy(by.Inner) - simplified := by.Inner - if by.WSOffset != -1 { - qb.addGroupBy(weightStringFor(simplified)) + ts := &tableSorter{ + sel: sel, + tbl: qb.ctx.SemTable, } - } - if op.WithRollup { - qb.setWithRollup() - } - - if op.DT != nil { - sel := qb.asSelectStatement() - qb.stmt = nil - qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.DT.Columns) - } -} - -func buildOrdering(op *Ordering, qb *queryBuilder) { - buildQuery(op.Source, qb) - - for _, order := range op.Order { - qb.asOrderAndLimit().AddOrder(order.Inner) - } -} - -func buildLimit(op *Limit, qb *queryBuilder) { - buildQuery(op.Source, qb) - qb.asOrderAndLimit().SetLimit(op.AST) -} - -func buildTable(op *Table, qb *queryBuilder) { - if !qb.includeTable(op) { - return - } - - dbName := "" - - if op.QTable.IsInfSchema { - dbName = op.QTable.Table.Qualifier.String() - } - qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints) - for _, pred := range op.QTable.Predicates { - qb.addPredicate(pred) - } - for _, name := range op.Columns { - qb.addProjection(&sqlparser.AliasedExpr{Expr: name}) - } + sort.Sort(ts) + return true, nil + }, qb.stmt) } -func buildProjection(op *Projection, qb *queryBuilder) { - buildQuery(op.Source, qb) - - _, isSel := qb.stmt.(*sqlparser.Select) - if isSel { - qb.clearProjections() - cols := op.GetSelectExprs(qb.ctx) - for _, column := range cols { - qb.addProjection(column) - } - } - - // if the projection is on derived table, we use the select we have - // created above and transform it into a derived table - if op.DT != nil { - sel := qb.asSelectStatement() - qb.stmt = nil - qb.addTableExpr(op.DT.Alias, op.DT.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.DT.Columns) - } - - if !isSel { - for _, column := range op.GetSelectExprs(qb.ctx) { - qb.addProjection(column) +func unionSelects(exprs 
[]sqlparser.SelectExpr) []sqlparser.SelectExpr { + var selectExprs []sqlparser.SelectExpr + for _, col := range exprs { + switch col := col.(type) { + case *sqlparser.AliasedExpr: + expr := sqlparser.NewColName(col.ColumnName()) + selectExprs = append(selectExprs, &sqlparser.AliasedExpr{Expr: expr}) + default: + selectExprs = append(selectExprs, col) } } + return selectExprs } -func buildApplyJoin(op *ApplyJoin, qb *queryBuilder) { - predicates := slice.Map(op.JoinPredicates.columns, func(jc applyJoinColumn) sqlparser.Expr { - // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done - err := qb.ctx.SkipJoinPredicates(jc.Original) - if err != nil { - panic(err) +func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableStatement) { + colName := column.Name.String() + firstSelect := getFirstSelect(sel) + exprs := firstSelect.GetColumns() + offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { + switch ae := expr.(type) { + case *sqlparser.StarExpr: + return true + case *sqlparser.AliasedExpr: + // When accessing columns on top of a UNION, we fall back to this simple strategy of string comparisons + return ae.ColumnName() == colName } - return jc.Original + return false }) - pred := sqlparser.AndExpressions(predicates...) 
- - buildQuery(op.LHS, qb) - - qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.RHS, qbR) - - switch { - // if we have a recursive cte, we might be missing a statement from one of the sides - case qbR.stmt == nil: - // do nothing - case qb.stmt == nil: - qb.stmt = qbR.stmt - default: - qb.joinWith(qbR, pred, op.JoinType) - } -} - -func buildUnion(op *Union, qb *queryBuilder) { - // the first input is built first - buildQuery(op.Sources[0], qb) - - for i, src := range op.Sources { - if i == 0 { - continue - } - - // now we can go over the remaining inputs and UNION them together - qbOther := &queryBuilder{ctx: qb.ctx} - buildQuery(src, qbOther) - qb.unionWith(qbOther, op.distinct) - } -} - -func buildFilter(op *Filter, qb *queryBuilder) { - buildQuery(op.Source, qb) - - for _, pred := range op.Predicates { - qb.addPredicate(pred) + if offset == -1 { + panic(vterrors.VT12001(fmt.Sprintf("did not find column [%s] on UNION", sqlparser.String(column)))) } } -func buildDerived(op *Horizon, qb *queryBuilder) { - buildQuery(op.Source, qb) - - sqlparser.RemoveKeyspaceInCol(op.Query) - - stmt := qb.stmt - qb.stmt = nil - switch sel := stmt.(type) { - case *sqlparser.Select: - buildDerivedSelect(op, qb, sel) - return - case *sqlparser.Union: - buildDerivedUnion(op, qb, sel) - return - } - panic(fmt.Sprintf("unknown select statement type: %T", stmt)) +type FromStatement interface { + GetFrom() []sqlparser.TableExpr + SetFrom([]sqlparser.TableExpr) + GetWherePredicate() sqlparser.Expr + SetWherePredicate(sqlparser.Expr) } -func buildDerivedUnion(op *Horizon, qb *queryBuilder, union *sqlparser.Union) { - opQuery, ok := op.Query.(*sqlparser.Union) - if !ok { - panic(vterrors.VT12001("Horizon contained SELECT but statement was UNION")) - } - - union.Limit = opQuery.Limit - union.OrderBy = opQuery.OrderBy - union.Distinct = opQuery.Distinct - - qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: union, - }, nil, op.ColumnAliases) -} +var _ 
FromStatement = (*sqlparser.Select)(nil) +var _ FromStatement = (*sqlparser.Update)(nil) +var _ FromStatement = (*sqlparser.Delete)(nil) -func buildDerivedSelect(op *Horizon, qb *queryBuilder, sel *sqlparser.Select) { - opQuery, ok := op.Query.(*sqlparser.Select) - if !ok { - panic(vterrors.VT12001("Horizon contained UNION but statement was SELECT")) - } - sel.Limit = opQuery.Limit - sel.OrderBy = opQuery.OrderBy - sel.GroupBy = opQuery.GroupBy - sel.Having = mergeHaving(sel.Having, opQuery.Having) - sel.SelectExprs = opQuery.SelectExprs - sel.Distinct = opQuery.Distinct - qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ - Select: sel, - }, nil, op.ColumnAliases) - for _, col := range op.Columns { - qb.addProjection(&sqlparser.AliasedExpr{Expr: col}) +func buildJoin(stmt FromStatement, otherStmt FromStatement, onCondition sqlparser.Expr, joinType sqlparser.JoinType) *sqlparser.JoinTableExpr { + var lhs sqlparser.TableExpr + fromClause := stmt.GetFrom() + if len(fromClause) == 1 { + lhs = fromClause[0] + } else { + lhs = &sqlparser.ParenTableExpr{Exprs: fromClause} } -} - -func buildHorizon(op *Horizon, qb *queryBuilder) { - buildQuery(op.Source, qb) - stripDownQuery(op.Query, qb.asSelectStatement()) - sqlparser.RemoveKeyspaceInCol(qb.stmt) -} - -func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { - predicates := slice.Map(op.Predicates, func(jc *plancontext.RecurseExpression) sqlparser.Expr { - // since we are adding these join predicates, we need to mark to broken up version (RHSExpr) of it as done - err := qb.ctx.SkipJoinPredicates(jc.Original) - if err != nil { - panic(err) - } - return jc.Original - }) - pred := sqlparser.AndExpressions(predicates...) 
- buildQuery(op.Seed(), qb) - qbR := &queryBuilder{ctx: qb.ctx} - buildQuery(op.Term(), qbR) - qbR.addPredicate(pred) - infoFor, err := qb.ctx.SemTable.TableInfoFor(op.OuterID) - if err != nil { - panic(err) + var rhs sqlparser.TableExpr + otherFromClause := otherStmt.GetFrom() + if len(otherFromClause) == 1 { + rhs = otherFromClause[0] + } else { + rhs = &sqlparser.ParenTableExpr{Exprs: otherFromClause} } - qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct, op.Def.Columns) -} - -func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { - switch { - case h1 == nil && h2 == nil: - return nil - case h1 == nil: - return h2 - case h2 == nil: - return h1 - default: - h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr) - return h1 + return &sqlparser.JoinTableExpr{ + LeftExpr: lhs, + RightExpr: rhs, + Join: joinType, + Condition: &sqlparser.JoinCondition{ + On: onCondition, + }, } } diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 935e4d4204d..3124f25efdc 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -64,13 +64,18 @@ func runPhases(ctx *plancontext.PlanningContext, root Operator) Operator { } op = phase.act(ctx, op) - op = runRewriters(ctx, op) + op = runPushDownRewriters(ctx, op) } + ctx.CurrentPhase = int(DONE) + + op = runPushDownRewriters(ctx, op) + op = compact(ctx, op) + return addGroupByOnRHSOfJoin(op) } -func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { +func runPushDownRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { switch in := in.(type) { case *Horizon: @@ -103,7 +108,8 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { return tryPushUpdate(in) case *RecurseCTE: return tryMergeRecurse(ctx, 
in) - + case *Values: + return tryPushValues(ctx, in) default: return in, NoRewrite } @@ -111,14 +117,28 @@ func runRewriters(ctx *plancontext.PlanningContext, root Operator) Operator { if pbm, ok := root.(*PercentBasedMirror); ok { pbm.SetInputs([]Operator{ - runRewriters(ctx, pbm.Operator()), - runRewriters(ctx.UseMirror(), pbm.Target()), + runPushDownRewriters(ctx, pbm.Operator()), + runPushDownRewriters(ctx.UseMirror(), pbm.Target()), }) } return FixedPointBottomUp(root, TableID, visitor, stopAtRoute) } +func tryPushValues(ctx *plancontext.PlanningContext, in *Values) (Operator, *ApplyResult) { + switch src := in.Source.(type) { + case *ValuesJoin: + src.LHS = in.Clone([]Operator{src.LHS}) + return src, Rewrote("pushed values to the LHS of values join") + case *Filter: + return Swap(in, src, "pushed values under filter") + case *Route: + src.Routing.AddValuesTableID(in.TableID) + return Swap(in, src, "pushed values under route") + } + return in, NoRewrite +} + func tryPushDelete(in *Delete) (Operator, *ApplyResult) { if src, ok := in.Source.(*Route); ok { return pushDMLUnderRoute(in, src, "pushed delete under route") @@ -698,6 +718,11 @@ func tryPushFilter(ctx *plancontext.PlanningContext, in *Filter) (Operator, *App } src.Outer, in.Source = in, src.Outer return src, Rewrote("push filter to outer query in subquery container") + case *ValuesJoin: + for _, pred := range in.Predicates { + src.AddPredicate(ctx, pred) + } + return src, Rewrote("pushed filter predicates through values join") case *Filter: if len(in.Predicates) == 0 { return in.Source, Rewrote("filter with no predicates removed") diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 9f455e1acec..346a75f719b 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -46,6 +46,8 @@ type ( Lock sqlparser.Lock ResultColumns int + + SeenValues semantics.TableSet } RouteOrdering struct { @@ -101,13 
+103,15 @@ type ( OpCode() engine.Opcode Keyspace() *vindexes.Keyspace // note that all routings do not have a keyspace, so this method can return nil + AddValuesTableID(id semantics.TableSet) + // updateRoutingLogic updates the routing to take predicates into account. This can be used for routing // using vindexes or for figuring out which keyspace an information_schema query should be sent to. updateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Routing } ) -// UpdateRoutingLogic first checks if we are dealing with a predicate that +// UpdateRoutingLogic first checks if we are dealing with a predicate that can be evaluated to false or NULL. func UpdateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr, r Routing) Routing { ks := r.Keyspace() if ks == nil { diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 891e3cf5862..258ef77a98b 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -46,7 +46,8 @@ type ShardedRouting struct { // SeenPredicates contains all the predicates that have had a chance to influence routing. 
// If we need to replan routing, we'll use this list - SeenPredicates []sqlparser.Expr + SeenPredicates []sqlparser.Expr + ValuesTablesIDs semantics.TableSet } var _ Routing = (*ShardedRouting)(nil) @@ -189,6 +190,10 @@ func (tr *ShardedRouting) Clone() Routing { } } +func (sr *ShardedRouting) AddValuesTableID(id semantics.TableSet) { + sr.ValuesTablesIDs = sr.ValuesTablesIDs.Merge(id) +} + func (tr *ShardedRouting) updateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Routing { tr.SeenPredicates = append(tr.SeenPredicates, expr) @@ -206,6 +211,7 @@ func (tr *ShardedRouting) updateRoutingLogic(ctx *plancontext.PlanningContext, e return tr } +// resetRoutingLogic resets the routing logic to the initial state, and uses the predicates to recompute the routing func (tr *ShardedRouting) resetRoutingLogic(ctx *plancontext.PlanningContext) Routing { tr.RouteOpCode = engine.Scatter tr.Selected = nil @@ -537,6 +543,20 @@ func (tr *ShardedRouting) planEqualOp(ctx *plancontext.PlanningContext, node *sq } val := makeEvalEngineExpr(ctx, vdValue) if val == nil { + col, ok := vdValue.(*sqlparser.ColName) + if !ok { + return false + } + from := ctx.SemTable.RecursiveDeps(col) + if from.IsSolvedBy(tr.ValuesTablesIDs) { + multiEqual := func(vindex *vindexes.ColumnVindex) engine.Opcode { + // TODO @harshit - what else should we do here? + return engine.MultiEqual + } + arg := sqlparser.NewListArg("values") // TODO: HACK - we need to store these names? 
+ + return tr.haveMatchingVindex(ctx, node, arg, column, val, multiEqual, justTheVindex) + } return false } diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 62b6e7a725e..ac057c1ed23 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -101,7 +101,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera } sq.outerID = outerID mapper := func(in sqlparser.Expr) (applyJoinColumn, error) { - return breakExpressionInLHSandRHS(ctx, in, outerID), nil + return breakApplyJoinExpressionInLHSandRHS(ctx, in, outerID), nil } joinPredicates, err := slice.MapWithError(sq.Predicates, mapper) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 87182fa7713..7a9c2d6c3fe 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -297,7 +297,7 @@ func extractLHSExpr( lhs semantics.TableSet, ) func(expr sqlparser.Expr) sqlparser.Expr { return func(expr sqlparser.Expr) sqlparser.Expr { - col := breakExpressionInLHSandRHS(ctx, expr, lhs) + col := breakApplyJoinExpressionInLHSandRHS(ctx, expr, lhs) if col.IsPureLeft() { panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here")) } @@ -676,7 +676,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out // We really need to figure out why this is not working as expected func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningContext, inner *Route) Operator { src := s.outer.Source - stmt, _, err := ToSQL(ctx, inner.Source) + stmt, _, err := ToAST(ctx, inner.Source) if err != nil { panic(err) } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 
759500a8dc8..c9191ce6caf 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -212,7 +212,7 @@ func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparse for _, ue := range upd.Exprs { target := ctx.SemTable.DirectDeps(ue.Name) exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr) - jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) + jc := breakApplyJoinExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc}) } diff --git a/go/vt/vtgate/planbuilder/operators/values.go b/go/vt/vtgate/planbuilder/operators/values.go new file mode 100644 index 00000000000..c67576f5237 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/values.go @@ -0,0 +1,105 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type Values struct { + unaryOperator + + Name string + TableID semantics.TableSet +} + +func (v *Values) Clone(inputs []Operator) Operator { + clone := *v + + if len(inputs) > 0 { + clone.Source = inputs[0] + } + return &clone +} + +func (v *Values) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser.Expr) Operator { + return newFilter(v, expr) +} + +func (v *Values) AddColumn(*plancontext.PlanningContext, bool, bool, *sqlparser.AliasedExpr) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) AddWSColumn(*plancontext.PlanningContext, int, bool) int { + panic(vterrors.VT13001("we cannot add new columns to a Values operator")) +} + +func (v *Values) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, _ bool) int { + for i, column := range v.getExprsFromCtx(ctx) { + if ctx.SemTable.EqualsExpr(column, expr) { + return i + } + } + return -1 +} + +func (v *Values) getColumnNamesFromCtx(ctx *plancontext.PlanningContext) sqlparser.Columns { + columns, found := ctx.ValuesJoinColumns[v.Name] + if !found { + panic(vterrors.VT13001("columns not found")) + } + return slice.Map(columns, func(ae *sqlparser.AliasedExpr) sqlparser.IdentifierCI { + return sqlparser.NewIdentifierCI(ae.ColumnName()) + }) +} + +func (v *Values) getExprsFromCtx(ctx *plancontext.PlanningContext) []sqlparser.Expr { + columns := ctx.ValuesJoinColumns[v.Name] + return slice.Map(columns, func(ae *sqlparser.AliasedExpr) sqlparser.Expr { + return ae.Expr + }) +} + +func (v *Values) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr { + columns := ctx.ValuesJoinColumns[v.Name] + return columns +} + +func (v *Values) GetSelectExprs(ctx *plancontext.PlanningContext) 
[]sqlparser.SelectExpr { + r := v.GetColumns(ctx) + var selectExprs []sqlparser.SelectExpr + for _, expr := range r { + selectExprs = append(selectExprs, expr) + } + return selectExprs +} + +func (v *Values) ShortDescription() string { + return v.Name +} + +func (v *Values) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + return v.Source.GetOrdering(ctx) +} + +func (v *Values) introducesTableID() semantics.TableSet { + return v.TableID +} diff --git a/go/vt/vtgate/planbuilder/operators/values_join.go b/go/vt/vtgate/planbuilder/operators/values_join.go new file mode 100644 index 00000000000..a120733d492 --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/values_join.go @@ -0,0 +1,218 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package operators + +import ( + "fmt" + "slices" + "strings" + + "vitess.io/vitess/go/slice" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + ValuesJoin struct { + binaryOperator + + ValuesDestination string + + JoinColumns []valuesJoinColumn + JoinPredicates []valuesJoinColumn + + // After offset planning + + // CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS + // []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset + CopyColumnsToRHS []int + + Columns []int + ColumnName []string + } + + valuesJoinColumn struct { + Original sqlparser.Expr + RHS sqlparser.Expr + LHS []sqlparser.Expr + PureLHS bool + } +) + +func (c valuesJoinColumn) String() string { + return fmt.Sprintf("[%s:%s]", sqlparser.SliceString(c.LHS), sqlparser.String(c.Original)) +} + +var _ Operator = (*ValuesJoin)(nil) +var _ JoinOp = (*ValuesJoin)(nil) + +func (vj *ValuesJoin) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { + if reuseExisting { + if offset := vj.FindCol(ctx, expr.Expr, false); offset >= 0 { + return offset + } + } + + vj.JoinColumns = append(vj.JoinColumns, breakValuesJoinExpressionInLHS(ctx, expr.Expr, TableID(vj.LHS))) + vj.ColumnName = append(vj.ColumnName, expr.ColumnName()) + return len(vj.JoinColumns) - 1 +} + +// AddWSColumn is used to add a weight_string column to the operator +func (vj *ValuesJoin) AddWSColumn(ctx *plancontext.PlanningContext, offset int, underRoute bool) int { + panic("oh no") +} + +func (vj *ValuesJoin) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { + for offset, column := range vj.JoinColumns { + if ctx.SemTable.EqualsExpr(column.Original, expr) { + return offset + } + } + return -1 +} + +func (vj *ValuesJoin) GetColumns(ctx 
*plancontext.PlanningContext) []*sqlparser.AliasedExpr { + results := make([]*sqlparser.AliasedExpr, len(vj.JoinColumns)) + for i, column := range vj.JoinColumns { + results[i] = sqlparser.NewAliasedExpr(column.Original, vj.ColumnName[i]) + } + return results +} + +func (vj *ValuesJoin) GetSelectExprs(ctx *plancontext.PlanningContext) []sqlparser.SelectExpr { + return transformColumnsToSelectExprs(ctx, vj) +} + +func (vj *ValuesJoin) GetLHS() Operator { + return vj.LHS +} + +func (vj *ValuesJoin) GetRHS() Operator { + return vj.RHS +} + +func (vj *ValuesJoin) SetLHS(operator Operator) { + vj.LHS = operator +} + +func (vj *ValuesJoin) SetRHS(operator Operator) { + vj.RHS = operator +} + +func (vj *ValuesJoin) MakeInner() { + // no-op for values-join +} + +func (vj *ValuesJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { + return AddPredicate(ctx, vj, expr, false, newFilterSinglePredicate) +} + +func (vj *ValuesJoin) IsInner() bool { + return true +} + +func (vj *ValuesJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) { + if expr == nil { + return + } + lID := TableID(vj.LHS) + lhsJoinCols := breakValuesJoinExpressionInLHS(ctx, expr, lID) + if lhsJoinCols.PureLHS { + vj.LHS = vj.LHS.AddPredicate(ctx, expr) + return + } + vj.RHS = vj.RHS.AddPredicate(ctx, expr) + vj.JoinPredicates = append(vj.JoinPredicates, lhsJoinCols) +} + +func (vj *ValuesJoin) Clone(inputs []Operator) Operator { + clone := *vj + clone.LHS = inputs[0] + clone.RHS = inputs[1] + return &clone +} + +func (vj *ValuesJoin) ShortDescription() string { + fn := func(cols []valuesJoinColumn) string { + out := slice.Map(cols, func(jc valuesJoinColumn) string { + return jc.String() + }) + return strings.Join(out, ", ") + } + + firstPart := fmt.Sprintf("%s on %s columns: %s", vj.ValuesDestination, fn(vj.JoinPredicates), fn(vj.JoinColumns)) + + return firstPart +} + +func (vj *ValuesJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { + 
return vj.RHS.GetOrdering(ctx) +} + +func (vj *ValuesJoin) planOffsets(ctx *plancontext.PlanningContext) Operator { + exprs := ctx.GetColumns(vj.ValuesDestination) + for _, jc := range vj.JoinColumns { + newExprs := vj.planOffsetsForLHSExprs(ctx, jc.LHS) + exprs = append(exprs, newExprs...) + offset := vj.RHS.AddColumn(ctx, true, false, aeWrap(jc.Original)) + vj.Columns = append(vj.Columns, ToRightOffset(offset)) + } + for _, jc := range vj.JoinPredicates { + // for join predicates, we only need to push the LHS dependencies. The RHS expressions are already pushed + newExprs := vj.planOffsetsForLHSExprs(ctx, jc.LHS) + exprs = append(exprs, newExprs...) + } + ctx.SetColumns(vj.ValuesDestination, exprs) + return vj +} + +func (vj *ValuesJoin) planOffsetsForLHSExprs(ctx *plancontext.PlanningContext, input []sqlparser.Expr) (exprs []*sqlparser.AliasedExpr) { + for _, lhsExpr := range input { + offset := vj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr)) + // only add it if we don't already have it + if slices.Index(vj.CopyColumnsToRHS, offset) == -1 { + vj.CopyColumnsToRHS = append(vj.CopyColumnsToRHS, offset) + newCol := sqlparser.NewColName(getValuesJoinColName(ctx, vj.ValuesDestination, TableID(vj.LHS), lhsExpr)) + exprs = append(exprs, aeWrap(newCol)) + } + } + return exprs +} + +func getValuesJoinColName(ctx *plancontext.PlanningContext, valuesDestination string, tableID semantics.TableSet, expr sqlparser.Expr) string { + col, isCol := expr.(*sqlparser.ColName) + if !isCol { + panic(fmt.Sprintf("expected a col named '%v'", expr)) + } + tableName := col.Qualifier.Name.String() + if tableName == "" { + ti, err := ctx.SemTable.TableInfoFor(tableID) + if err != nil { + tableName = valuesDestination + } else { + tblName, err := ti.Name() + if err != nil { + tableName = valuesDestination + } else { + tableName = tblName.Name.String() + } + } + } + return tableName + "_" + col.Name.String() +} diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go 
b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 016f5c877cf..c48fb9fa7b3 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -43,8 +43,16 @@ type PlanningContext struct { // a join predicate is reverted to its original form during planning. skipPredicates map[sqlparser.Expr]any + // skipValuesArgument tracks Values operator that should be skipped when + // rewriting the operator tree to an AST tree. + // This happens when a ValuesJoin is pushed under a route and we do not + // need to have a Values operator anymore on its RHS. + skipValuesArgument map[string]any + PlannerVersion querypb.ExecuteOptions_PlannerVersion + AllowValuesJoin bool + // If we during planning have turned this expression into an argument name, // we can continue using the same argument name ReservedArguments map[sqlparser.Expr]string @@ -77,10 +85,29 @@ type PlanningContext struct { // isMirrored indicates that mirrored tables should be used. isMirrored bool + // ValuesJoinColumns stores the columns we need for each values statement in the plan. + ValuesJoinColumns map[string][]*sqlparser.AliasedExpr + + // ValueJoins contains one entry for each value join that has been created. + // The key is the value-join ops ValuesDestination, and the value is the Values op associated with it. 
+ // When first created, these are one-to-one, but the Values are merged if they end up in the same route + ValueJoins map[string]string + emptyEnv *evalengine.ExpressionEnv constantCfg *evalengine.Config } +func CreateEmptyPlanningContext() *PlanningContext { + return &PlanningContext{ + joinPredicates: make(map[sqlparser.Expr][]sqlparser.Expr), + skipPredicates: make(map[sqlparser.Expr]any), + skipValuesArgument: make(map[string]any), + ReservedArguments: make(map[sqlparser.Expr]string), + ValuesJoinColumns: make(map[string][]*sqlparser.AliasedExpr), + ValueJoins: make(map[string]string), + } +} + // CreatePlanningContext initializes a new PlanningContext with the given parameters. // It analyzes the SQL statement within the given virtual schema context, // handling default keyspace settings and semantic analysis. @@ -104,14 +131,18 @@ func CreatePlanningContext(stmt sqlparser.Statement, vschema.PlannerWarning(semTable.Warning) return &PlanningContext{ - ReservedVars: reservedVars, - SemTable: semTable, - VSchema: vschema, - joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, - skipPredicates: map[sqlparser.Expr]any{}, - PlannerVersion: version, - ReservedArguments: map[sqlparser.Expr]string{}, - Statement: stmt, + ReservedVars: reservedVars, + SemTable: semTable, + VSchema: vschema, + joinPredicates: map[sqlparser.Expr][]sqlparser.Expr{}, + skipPredicates: map[sqlparser.Expr]any{}, + skipValuesArgument: map[string]any{}, + PlannerVersion: version, + ReservedArguments: map[sqlparser.Expr]string{}, + ValuesJoinColumns: make(map[string][]*sqlparser.AliasedExpr), + Statement: stmt, + AllowValuesJoin: sqlparser.AllowValuesJoinDirective(stmt), + ValueJoins: make(map[string]string), }, nil } @@ -176,6 +207,15 @@ func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error { return vterrors.VT13001("predicate does not exist: " + sqlparser.String(joinPred)) } +func (ctx *PlanningContext) SkipValuesArgument(name string) { + ctx.skipValuesArgument[name] 
= "" +} + +func (ctx *PlanningContext) IsValuesArgumentSkipped(name string) bool { + _, ok := ctx.skipValuesArgument[name] + return ok +} + // KeepPredicateInfo transfers join predicate information from another context. // This is useful when nesting queries, ensuring consistent predicate handling across contexts. func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) { @@ -434,6 +474,15 @@ func (ctx *PlanningContext) ActiveCTE() *ContextCTE { return ctx.CurrentCTE[len(ctx.CurrentCTE)-1] } +func (ctx *PlanningContext) GetColumns(joinName string) []*sqlparser.AliasedExpr { + valuesName := ctx.ValueJoins[joinName] + return ctx.ValuesJoinColumns[valuesName] +} +func (ctx *PlanningContext) SetColumns(joinName string, cols []*sqlparser.AliasedExpr) { + valuesName := ctx.ValueJoins[joinName] + ctx.ValuesJoinColumns[valuesName] = cols +} + func (ctx *PlanningContext) UseMirror() *PlanningContext { if ctx.isMirrored { panic(vterrors.VT13001("cannot mirror already mirrored planning context")) @@ -457,6 +506,7 @@ func (ctx *PlanningContext) UseMirror() *PlanningContext { CurrentCTE: ctx.CurrentCTE, emptyEnv: ctx.emptyEnv, isMirrored: true, + ValuesJoinColumns: ctx.ValuesJoinColumns, } return ctx.mirror } diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index 9d653b2f6e9..f34b2bd60d7 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,8 +1,52 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "select /*vt+ ALLOW_VALUES_JOIN */ user.foo, user_extra.user_id from user, user_extra where user.id = user_extra.toto and user.foo = 1 and user_extra.bar = 2", "plan": { + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_VALUES_JOIN */ user.foo, user_extra.user_id from user, user_extra where user.id = user_extra.toto", + "Instructions": { + "OperatorType": "Join", + "Variant": 
"Values", + "BindVarName": "values", + "CopyColumnsToRHS": [ + 0, + 1 + ], + "RowID": "false", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.user_id, user_extra.toto from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_VALUES_JOIN */ user_extra.user_id, user_extra.toto from user_extra where user_extra.bar = 2", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.foo, user_extra.user_id from (values ::values) as `values`(user_id, toto), `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_VALUES_JOIN */ `user`.foo, values.user_extra_user_id as user_id from (values ::values) as `values`(user_extra_user_id, user_extra_toto), `user` where `user`.foo = 1 and `user`.id = values.user_extra_toto", + "Table": "`user`", + "Values": [ + ":user_extra_toto" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] } } ] \ No newline at end of file diff --git a/go/vt/vtgate/semantics/table_set.go b/go/vt/vtgate/semantics/table_set.go index acc83306869..65d714fc340 100644 --- a/go/vt/vtgate/semantics/table_set.go +++ b/go/vt/vtgate/semantics/table_set.go @@ -18,6 +18,7 @@ package semantics import ( "fmt" + "strings" "vitess.io/vitess/go/vt/vtgate/semantics/bitset" ) @@ -41,6 +42,22 @@ func (ts TableSet) Format(f fmt.State, verb rune) { fmt.Fprintf(f, "}") } +func (ts TableSet) DebugString() string { + var f strings.Builder + first := true + f.WriteString("TableSet{") + bitset.Bitset(ts).ForEach(func(tid int) { + if first { + f.WriteString(fmt.Sprintf("%d", tid)) + first = false + } else { + f.WriteString(fmt.Sprintf(",%d", tid)) + } + }) + f.WriteString("}") + return f.String() +} + // IsOverlapping returns true if at least one table exists in both sets func (ts TableSet) IsOverlapping(other 
TableSet) bool { return bitset.Bitset(ts).Overlaps(bitset.Bitset(other))