Skip to content

Commit

Permalink
move semantics to the new AST types
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Jan 10, 2025
1 parent 148c41b commit 7fe8d6b
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 19 deletions.
12 changes: 10 additions & 2 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,14 +286,22 @@ func containsStar(s sqlparser.SelectExprs) bool {
}

func checkUnionColumns(union *sqlparser.Union) error {
firstProj := sqlparser.GetFirstSelect(union).SelectExprs
lft, err := sqlparser.GetFirstSelect(union)
if err != nil {
return err
}
firstProj := lft.GetColumns()
if containsStar(firstProj) {
// if we still have *, we can't figure out if the query is invalid or not
// we'll fail it at run time instead
return nil
}

secondProj := sqlparser.GetFirstSelect(union.Right).SelectExprs
rgt, err := sqlparser.GetFirstSelect(union.Right)
if err != nil {
return err
}
secondProj := rgt.GetColumns()
if containsStar(secondProj) {
return nil
}
Expand Down
6 changes: 4 additions & 2 deletions go/vt/vtgate/semantics/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1003,8 +1003,10 @@ func TestUnionWithOrderBy(t *testing.T) {

stmt, semTable := parseAndAnalyze(t, query, "")
union, _ := stmt.(*sqlparser.Union)
sel1 := sqlparser.GetFirstSelect(union)
sel2 := sqlparser.GetFirstSelect(union.Right)
sel1, err := sqlparser.GetFirstSelect(union)
require.NoError(t, err)
sel2, err := sqlparser.GetFirstSelect(union.Right)
require.NoError(t, err)

t1 := sel1.From[0].(*sqlparser.AliasedTableExpr)
t2 := sel2.From[0].(*sqlparser.AliasedTableExpr)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/semantics/cte_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func (cte *CTETable) GetMirrorRule() *vindexes.MirrorRule {

type CTE struct {
Name string
Query sqlparser.SelectStatement
Query sqlparser.TableSubquery
isAuthoritative bool
recursiveDeps *TableSet
Columns sqlparser.Columns
Expand Down
16 changes: 11 additions & 5 deletions go/vt/vtgate/semantics/early_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,12 +385,16 @@ func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal {

// handleOrderBy processes the ORDER BY clause.
func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) error {
stmt, ok := parent.(sqlparser.SelectStatement)
stmt, ok := parent.(sqlparser.TableSubquery)
if !ok {
return nil
}

sel := sqlparser.GetFirstSelect(stmt)
sel, err := sqlparser.GetFirstSelect(stmt)
if err != nil {
return err
}

for e := iter.next(); e != nil; e = iter.next() {
lit, err := r.replaceLiteralsInOrderBy(e, iter)
if err != nil {
Expand Down Expand Up @@ -419,12 +423,15 @@ func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) e

// handleGroupBy processes the GROUP BY clause.
func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) error {
stmt, ok := parent.(sqlparser.SelectStatement)
stmt, ok := parent.(*sqlparser.Select)
if !ok {
return nil
}

sel := sqlparser.GetFirstSelect(stmt)
sel, err := sqlparser.GetFirstSelect(stmt)
if err != nil {
return err
}
for e := iter.next(); e != nil; e = iter.next() {
expr, err := r.replaceLiteralsInGroupBy(e)
if err != nil {
Expand All @@ -435,7 +442,6 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e
if err != nil {
return err
}

}
err = iter.replace(expr)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/semantics/scoper.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,11 @@ func (s *scoper) createSpecialScopePostProjection(parent sqlparser.SQLNode) erro
for i, sel := range sqlparser.GetAllSelects(parent) {
if i == 0 {
nScope.stmt = sel
tableInfo = createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org)
tableInfo = createVTableInfoForExpressions(sel.GetColumns(), nil /*needed for star expressions*/, s.org)
nScope.tables = append(nScope.tables, tableInfo)
continue
}
thisTableInfo := createVTableInfoForExpressions(sel.SelectExprs, nil /*needed for star expressions*/, s.org)
thisTableInfo := createVTableInfoForExpressions(sel.GetColumns(), nil /*needed for star expressions*/, s.org)
if len(tableInfo.cols) != len(thisTableInfo.cols) {
return vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.WrongNumberOfColumnsInSelect, "The used SELECT statements have a different number of columns")
}
Expand Down
28 changes: 21 additions & 7 deletions go/vt/vtgate/semantics/table_collector.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,11 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr
}

func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
firstSelect := sqlparser.GetFirstSelect(union)
expanded, selectExprs := getColumnNames(firstSelect.SelectExprs)
firstSelect, err := sqlparser.GetFirstSelect(union)
if err != nil {
return err
}
expanded, selectExprs := getColumnNames(firstSelect.GetColumns())
info := unionInfo{
isAuthoritative: expanded,
exprs: selectExprs,
Expand All @@ -165,12 +168,12 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
return nil
}

size := len(firstSelect.SelectExprs)
size := firstSelect.GetColumnCount()
info.recursive = make([]TableSet, size)
typers := make([]evalengine.TypeAggregator, size)
collations := tc.org.collationEnv()

err := sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error {
err = sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error {
for i, expr := range s.SelectExprs {
ae, ok := expr.(*sqlparser.AliasedExpr)
if !ok {
Expand Down Expand Up @@ -413,7 +416,10 @@ func checkValidRecursiveCTE(cteDef *CTE) error {
return vterrors.VT09026(cteDef.Name)
}

firstSelect := sqlparser.GetFirstSelect(union.Right)
firstSelect, err := sqlparser.GetFirstSelect(union.Right)
if err != nil {
return err
}
if firstSelect.GroupBy != nil {
return vterrors.VT09027(cteDef.Name)
}
Expand Down Expand Up @@ -470,8 +476,16 @@ func (tc *tableCollector) addSelectDerivedTable(
return scope.addTable(tableInfo)
}

func (tc *tableCollector) addUnionDerivedTable(union *sqlparser.Union, node *sqlparser.AliasedTableExpr, columns sqlparser.Columns, alias sqlparser.IdentifierCS) error {
firstSelect := sqlparser.GetFirstSelect(union)
func (tc *tableCollector) addUnionDerivedTable(
union *sqlparser.Union,
node *sqlparser.AliasedTableExpr,
columns sqlparser.Columns,
alias sqlparser.IdentifierCS,
) error {
firstSelect, err := sqlparser.GetFirstSelect(union)
if err != nil {
return err
}
tables := tc.scoper.wScope[firstSelect]
info, found := tc.unionInfo[union]
if !found {
Expand Down

0 comments on commit 7fe8d6b

Please sign in to comment.