Skip to content

Commit b51d049

Browse files
committed
wip - split analysis into two parts
Signed-off-by: Andres Taylor <andres@planetscale.com>
1 parent a404807 commit b51d049

File tree

4 files changed

+165
-39
lines changed

4 files changed

+165
-39
lines changed

go/vt/vtgate/semantics/analyzer.go

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,43 +28,49 @@ import (
2828
// analyzer controls the flow of the analysis.
2929
// It starts the tree walking and controls which part of the analysis sees which parts of the tree
3030
type analyzer struct {
31-
scoper *scoper
32-
tables *tableCollector
33-
binder *binder
34-
typer *typer
35-
rewriter *earlyRewriter
36-
sig QuerySignature
31+
scoper *scoper
32+
earlyTables *earlyTableCollector
33+
tables *tableCollector
34+
binder *binder
35+
typer *typer
36+
rewriter *earlyRewriter
37+
sig QuerySignature
38+
si SchemaInformation
39+
currentDb string
3740

3841
err error
3942
inProjection int
4043

41-
projErr error
42-
unshardedErr error
43-
warning string
44+
projErr error
45+
unshardedErr error
46+
warning string
47+
singleUnshardedKeyspace bool
4448
}
4549

4650
// newAnalyzer create the semantic analyzer
4751
func newAnalyzer(dbName string, si SchemaInformation) *analyzer {
4852
// TODO dependencies between these components are a little tangled. We should try to clean up
4953
s := newScoper()
5054
a := &analyzer{
51-
scoper: s,
52-
tables: newTableCollector(s, si, dbName),
53-
typer: newTyper(si.Environment().CollationEnv()),
55+
scoper: s,
56+
earlyTables: newEarlyTableCollector(si, dbName),
57+
typer: newTyper(si.Environment().CollationEnv()),
58+
si: si,
59+
currentDb: dbName,
5460
}
5561
s.org = a
56-
a.tables.org = a
62+
return a
63+
}
5764

58-
b := newBinder(s, a, a.tables, a.typer)
59-
a.binder = b
65+
func (a *analyzer) lateInit() {
66+
a.tables = a.earlyTables.newTableCollector(a.scoper, a)
67+
a.binder = newBinder(a.scoper, a, a.tables, a.typer)
6068
a.rewriter = &earlyRewriter{
61-
env: si.Environment(),
62-
scoper: s,
63-
binder: b,
69+
env: a.si.Environment(),
70+
scoper: a.scoper,
71+
binder: a.binder,
6472
expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{},
6573
}
66-
s.binder = b
67-
return a
6874
}
6975

7076
// Analyze analyzes the parsed query.
@@ -109,6 +115,32 @@ func (a *analyzer) newSemTable(
109115
if isCommented {
110116
comments = commentedStmt.GetParsedComments()
111117
}
118+
119+
if a.singleUnshardedKeyspace {
120+
return &SemTable{
121+
Tables: a.earlyTables.Tables,
122+
Comments: comments,
123+
Warning: a.warning,
124+
Collation: coll,
125+
ExprTypes: map[sqlparser.Expr]evalengine.Type{},
126+
NotSingleRouteErr: a.projErr,
127+
NotUnshardedErr: a.unshardedErr,
128+
Recursive: ExprDependencies{},
129+
Direct: ExprDependencies{},
130+
Targets: map[sqlparser.IdentifierCS]TableSet{},
131+
ColumnEqualities: map[columnName][]sqlparser.Expr{},
132+
ExpandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{},
133+
columns: map[*sqlparser.Union]sqlparser.SelectExprs{},
134+
comparator: nil,
135+
StatementIDs: a.scoper.statementIDs,
136+
QuerySignature: QuerySignature{},
137+
childForeignKeysInvolved: map[TableSet][]vindexes.ChildFKInfo{},
138+
parentForeignKeysInvolved: map[TableSet][]vindexes.ParentFKInfo{},
139+
childFkToUpdExprs: map[string]sqlparser.UpdateExprs{},
140+
collEnv: env,
141+
}, nil
142+
}
143+
112144
columns := map[*sqlparser.Union]sqlparser.SelectExprs{}
113145
for union, info := range a.tables.unionInfo {
114146
columns[union] = info.exprs
@@ -298,10 +330,38 @@ func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet,
298330
}
299331

300332
func (a *analyzer) analyze(statement sqlparser.Statement) error {
333+
_ = sqlparser.Rewrite(statement, nil, a.earlyUp)
334+
if a.err != nil {
335+
return a.err
336+
}
337+
ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables)
338+
if ks != nil {
339+
// if we found a single unsharded keyspace in the early walk we can stop here
340+
a.singleUnshardedKeyspace = true
341+
return nil
342+
}
343+
344+
a.lateInit()
345+
301346
_ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp)
302347
return a.err
303348
}
304349

350+
// earlyUp collects tables in the query, so we can check
351+
// if this a single unsharded query we are dealing with
352+
func (a *analyzer) earlyUp(cursor *sqlparser.Cursor) bool {
353+
if !a.shouldContinue() {
354+
return false
355+
}
356+
357+
if err := a.earlyTables.up(cursor); err != nil {
358+
a.setError(err)
359+
return false
360+
}
361+
362+
return a.shouldContinue()
363+
}
364+
305365
func (a *analyzer) shouldContinue() bool {
306366
return a.err == nil
307367
}

go/vt/vtgate/semantics/analyzer_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ func TestScopingWDerivedTables(t *testing.T) {
986986
recursiveExpectation: MergeTableSets(TS0, TS1),
987987
}, {
988988
query: "select t.id from (select * from user) as t join user as u on t.id = u.id",
989-
expectation: TS1,
989+
expectation: TS2,
990990
recursiveExpectation: TS0,
991991
}, {
992992
query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t",
@@ -1638,11 +1638,11 @@ func TestScopingSubQueryJoinClause(t *testing.T) {
16381638

16391639
var ks1 = &vindexes.Keyspace{
16401640
Name: "ks1",
1641-
Sharded: false,
1641+
Sharded: true,
16421642
}
16431643
var ks2 = &vindexes.Keyspace{
16441644
Name: "ks2",
1645-
Sharded: false,
1645+
Sharded: true,
16461646
}
16471647
var ks3 = &vindexes.Keyspace{
16481648
Name: "ks3",

go/vt/vtgate/semantics/semantic_state.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,10 @@ func (st *SemTable) ColumnLookup(col *sqlparser.ColName) (int, error) {
730730

731731
// SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace
732732
func (st *SemTable) SingleUnshardedKeyspace() (ks *vindexes.Keyspace, tables []*vindexes.Table) {
733+
return singleUnshardedKeyspace(st.Tables)
734+
}
735+
736+
func singleUnshardedKeyspace(tableInfos []TableInfo) (ks *vindexes.Keyspace, tables []*vindexes.Table) {
733737
validKS := func(this *vindexes.Keyspace) bool {
734738
if this == nil || this.Sharded {
735739
return false
@@ -744,7 +748,7 @@ func (st *SemTable) SingleUnshardedKeyspace() (ks *vindexes.Keyspace, tables []*
744748
return true
745749
}
746750

747-
for _, table := range st.Tables {
751+
for _, table := range tableInfos {
748752
if _, isDT := table.(*DerivedTable); isDT {
749753
continue
750754
}

go/vt/vtgate/semantics/table_collector.go

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,61 @@ type tableCollector struct {
3535
currentDb string
3636
org originable
3737
unionInfo map[*sqlparser.Union]unionInfo
38+
done map[*sqlparser.AliasedTableExpr]TableInfo
3839
}
3940

40-
func newTableCollector(scoper *scoper, si SchemaInformation, currentDb string) *tableCollector {
41-
return &tableCollector{
42-
scoper: scoper,
41+
type earlyTableCollector struct {
42+
si SchemaInformation
43+
currentDb string
44+
Tables []TableInfo
45+
done map[*sqlparser.AliasedTableExpr]TableInfo
46+
}
47+
48+
func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector {
49+
return &earlyTableCollector{
4350
si: si,
4451
currentDb: currentDb,
52+
done: map[*sqlparser.AliasedTableExpr]TableInfo{},
53+
}
54+
}
55+
56+
func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) error {
57+
aet, ok := cursor.Node().(*sqlparser.AliasedTableExpr)
58+
if !ok {
59+
return nil
60+
}
61+
return etc.visitAliasedTableExpr(aet)
62+
}
63+
64+
func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) error {
65+
tbl, ok := aet.Expr.(sqlparser.TableName)
66+
if !ok {
67+
return nil
68+
}
69+
return etc.handleTableName(tbl, aet)
70+
}
71+
72+
func (etc *earlyTableCollector) newTableCollector(scoper *scoper, org originable) *tableCollector {
73+
return &tableCollector{
74+
Tables: etc.Tables,
75+
scoper: scoper,
76+
si: etc.si,
77+
currentDb: etc.currentDb,
4578
unionInfo: map[*sqlparser.Union]unionInfo{},
79+
done: etc.done,
80+
org: org,
81+
}
82+
}
83+
84+
func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sqlparser.AliasedTableExpr) error {
85+
tableInfo, err := getTableInfo(aet, tbl, etc.si, etc.currentDb)
86+
if err != nil {
87+
return err
4688
}
89+
90+
etc.done[aet] = tableInfo
91+
etc.Tables = append(etc.Tables, tableInfo)
92+
return nil
4793
}
4894

4995
func (tc *tableCollector) up(cursor *sqlparser.Cursor) error {
@@ -103,28 +149,42 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr
103149
return nil
104150
}
105151

106-
func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) error {
152+
func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (err error) {
153+
var tableInfo TableInfo
154+
var found bool
155+
156+
tableInfo, found = tc.done[node]
157+
if !found {
158+
tableInfo, err = getTableInfo(node, t, tc.si, tc.currentDb)
159+
if err != nil {
160+
return err
161+
}
162+
tc.Tables = append(tc.Tables, tableInfo)
163+
}
164+
165+
scope := tc.scoper.currentScope()
166+
return scope.addTable(tableInfo)
167+
}
168+
169+
func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si SchemaInformation, currentDb string) (TableInfo, error) {
107170
var tbl *vindexes.Table
108171
var vindex vindexes.Vindex
109172
isInfSchema := sqlparser.SystemSchema(t.Qualifier.String())
110173
var err error
111-
tbl, vindex, _, _, _, err = tc.si.FindTableOrVindex(t)
174+
tbl, vindex, _, _, _, err = si.FindTableOrVindex(t)
112175
if err != nil && !isInfSchema {
113176
// if we are dealing with a system table, it might not be available in the vschema, but that is OK
114-
return err
177+
return nil, err
115178
}
116179
if tbl == nil && vindex != nil {
117180
tbl = newVindexTable(t.Name)
118181
}
119182

120-
scope := tc.scoper.currentScope()
121-
tableInfo, err := tc.createTable(t, node, tbl, isInfSchema, vindex)
183+
tableInfo, err := createTable(t, node, tbl, isInfSchema, vindex, si, currentDb)
122184
if err != nil {
123-
return err
185+
return nil, err
124186
}
125-
126-
tc.Tables = append(tc.Tables, tableInfo)
127-
return scope.addTable(tableInfo)
187+
return tableInfo, nil
128188
}
129189

130190
func (tc *tableCollector) handleDerivedTable(node *sqlparser.AliasedTableExpr, t *sqlparser.DerivedTable) error {
@@ -228,12 +288,14 @@ func (tc *tableCollector) tableInfoFor(id TableSet) (TableInfo, error) {
228288
return tc.Tables[offset], nil
229289
}
230290

231-
func (tc *tableCollector) createTable(
291+
func createTable(
232292
t sqlparser.TableName,
233293
alias *sqlparser.AliasedTableExpr,
234294
tbl *vindexes.Table,
235295
isInfSchema bool,
236296
vindex vindexes.Vindex,
297+
si SchemaInformation,
298+
currentDb string,
237299
) (TableInfo, error) {
238300
hint := getVindexHint(alias.Hints)
239301

@@ -247,13 +309,13 @@ func (tc *tableCollector) createTable(
247309
Table: tbl,
248310
VindexHint: hint,
249311
isInfSchema: isInfSchema,
250-
collationEnv: tc.si.Environment().CollationEnv(),
312+
collationEnv: si.Environment().CollationEnv(),
251313
}
252314

253315
if alias.As.IsEmpty() {
254316
dbName := t.Qualifier.String()
255317
if dbName == "" {
256-
dbName = tc.currentDb
318+
dbName = currentDb
257319
}
258320

259321
table.dbName = dbName

0 commit comments

Comments
 (0)