diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 62cdc019ddf..c4e7dc55866 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -286,14 +286,22 @@ func containsStar(s sqlparser.SelectExprs) bool { } func checkUnionColumns(union *sqlparser.Union) error { - firstProj := sqlparser.GetFirstSelect(union).SelectExprs + lft, err := sqlparser.GetFirstSelect(union) + if err != nil { + return err + } + firstProj := lft.GetColumns() if containsStar(firstProj) { // if we still have *, we can't figure out if the query is invalid or not // we'll fail it at run time instead return nil } - secondProj := sqlparser.GetFirstSelect(union.Right).SelectExprs + rgt, err := sqlparser.GetFirstSelect(union.Right) + if err != nil { + return err + } + secondProj := rgt.GetColumns() if containsStar(secondProj) { return nil } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index de8fbdee0d7..3f7a21cb6ac 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -1003,8 +1003,10 @@ func TestUnionWithOrderBy(t *testing.T) { stmt, semTable := parseAndAnalyze(t, query, "") union, _ := stmt.(*sqlparser.Union) - sel1 := sqlparser.GetFirstSelect(union) - sel2 := sqlparser.GetFirstSelect(union.Right) + sel1, err := sqlparser.GetFirstSelect(union) + require.NoError(t, err) + sel2, err := sqlparser.GetFirstSelect(union.Right) + require.NoError(t, err) t1 := sel1.From[0].(*sqlparser.AliasedTableExpr) t2 := sel2.From[0].(*sqlparser.AliasedTableExpr) diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 498fc5076c1..29330e17ce2 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -150,7 +150,7 @@ func (cte *CTETable) GetMirrorRule() *vindexes.MirrorRule { type CTE struct { Name string - Query sqlparser.SelectStatement + Query sqlparser.TableSubquery isAuthoritative bool recursiveDeps *TableSet Columns sqlparser.Columns diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 3e53ed0816a..c802279789c 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -385,12 +385,16 @@ func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { // handleOrderBy processes the ORDER BY clause. func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) error { - stmt, ok := parent.(sqlparser.SelectStatement) + stmt, ok := parent.(sqlparser.TableSubquery) if !ok { return nil } - sel := sqlparser.GetFirstSelect(stmt) + sel, err := sqlparser.GetFirstSelect(stmt) + if err != nil { + return err + } + for e := iter.next(); e != nil; e = iter.next() { lit, err := r.replaceLiteralsInOrderBy(e, iter) if err != nil { @@ -419,12 +423,15 @@ func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) e // handleGroupBy processes the GROUP BY clause. func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) error { - stmt, ok := parent.(sqlparser.SelectStatement) + stmt, ok := parent.(*sqlparser.Select) if !ok { return nil } - sel := sqlparser.GetFirstSelect(stmt) + sel, err := sqlparser.GetFirstSelect(stmt) + if err != nil { + return err + } for e := iter.next(); e != nil; e = iter.next() { expr, err := r.replaceLiteralsInGroupBy(e) if err != nil { @@ -435,7 +442,6 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e if err != nil { return err } - } err = iter.replace(expr) if err != nil { diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index 9d596d9ecd1..e6df3c3a5b0 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -297,11 +297,11 @@ func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) erro for i, sel := range sqlparser.GetAllSelects(parent) { if i == 0 { nScope.stmt = sel - tableInfo = createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org) + tableInfo = createVTableInfoForExpressions(sel.GetColumns(), nil /*needed for star expressions*/, s.org) nScope.tables = append(nScope.tables, tableInfo) continue } - thisTableInfo := createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org) + thisTableInfo := createVTableInfoForExpressions(sel.GetColumns(), nil /*needed for star expressions*/, s.org) if len(tableInfo.cols) != len(thisTableInfo.cols) { return vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.WrongNumberOfColumnsInSelect, "The used SELECT statements have a different number of columns") } diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 45a50fd23a2..329ebcef254 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -154,8 +154,11 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr } func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { - firstSelect := sqlparser.GetFirstSelect(union) - expanded, selectExprs := getColumnNames(firstSelect.SelectExprs) + firstSelect, err := sqlparser.GetFirstSelect(union) + if err != nil { + return err + } + expanded, selectExprs := getColumnNames(firstSelect.GetColumns()) info := unionInfo{ isAuthoritative: expanded, exprs: selectExprs, @@ -165,12 +168,12 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { return nil } - size := len(firstSelect.SelectExprs) + size := firstSelect.GetColumnCount() info.recursive = make([]TableSet, size) typers := make([]evalengine.TypeAggregator, size) collations := tc.org.collationEnv() - err := sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error { + err = sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error { for i, expr := range s.SelectExprs { ae, ok := expr.(*sqlparser.AliasedExpr) if !ok { @@ -413,7 +416,10 @@ func checkValidRecursiveCTE(cteDef *CTE) error { return vterrors.VT09026(cteDef.Name) } - firstSelect := sqlparser.GetFirstSelect(union.Right) + firstSelect, err := sqlparser.GetFirstSelect(union.Right) + if err != nil { + return err + } if firstSelect.GroupBy != nil { return vterrors.VT09027(cteDef.Name) } @@ -470,8 +476,16 @@ func (tc *tableCollector) addSelectDerivedTable( return scope.addTable(tableInfo) } -func (tc *tableCollector) addUnionDerivedTable(union *sqlparser.Union, node *sqlparser.AliasedTableExpr, columns sqlparser.Columns, alias sqlparser.IdentifierCS) error { - firstSelect := sqlparser.GetFirstSelect(union) +func (tc *tableCollector) addUnionDerivedTable( + union *sqlparser.Union, + node *sqlparser.AliasedTableExpr, + columns sqlparser.Columns, + alias sqlparser.IdentifierCS, +) error { + firstSelect, err := sqlparser.GetFirstSelect(union) + if err != nil { + return err + } tables := tc.scoper.wScope[firstSelect] info, found := tc.unionInfo[union] if !found {