-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtable_collector.go
147 lines (134 loc) · 3.9 KB
/
table_collector.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
package sqlparser
// TableSchemaAndName identifies a table by its schema and its name. An empty
// schema means the table name is unqualified.
type TableSchemaAndName struct {
	schema string
	name   string
}

// NewTableSchemaAndName returns a TableSchemaAndName holding the given schema
// and table name. Pass an empty schema for an unqualified table.
func NewTableSchemaAndName(schema string, name string) TableSchemaAndName {
	var tn TableSchemaAndName
	tn.schema, tn.name = schema, name
	return tn
}

// GetSchema returns the schema part, which may be empty.
func (tn TableSchemaAndName) GetSchema() string {
	return tn.schema
}

// GetName returns the table name part.
func (tn TableSchemaAndName) GetName() string {
	return tn.name
}

// String renders the table as "schema.name", or as just "name" when the
// schema is empty.
func (tn TableSchemaAndName) String() string {
	if tn.schema != "" {
		return tn.schema + "." + tn.name
	}
	return tn.name
}
// CollectTables returns the distinct tables referenced by stmt.
//
// Table names that carry no schema are qualified with defaultTableSchema when
// it is non-empty. Statement types that reference no tables, or that are not
// listed in the switch below, yield no entries. Duplicates are removed while
// keeping the order of first appearance.
func CollectTables(stmt Statement, defaultTableSchema string) []TableSchemaAndName {
	var collected []TableSchemaAndName

	// Every Statement type must be covered by one of the cases below.
	switch s := stmt.(type) {
	case *Union, *Select:
		collected = collectFromSubQuery(stmt, collected)

	case *Insert:
		// Target table first, then anything selected from.
		collected = collectFromTableName(s.Table, collected)
		collected = collectFromSubQuery(stmt, collected)

	case *Update:
		collected = collectFromTableExprs(s.TableExprs, collected)
		collected = collectFromSubQuery(stmt, collected)

	case *Delete:
		collected = collectFromTableExprs(s.TableExprs, collected)
		collected = collectFromSubQuery(stmt, collected)

	case DDLStatement:
		for _, affected := range s.AffectedTables() {
			collected = collectFromTableName(affected, collected)
		}

	case *Flush:
		for _, flushed := range s.TableNames {
			collected = collectFromTableName(flushed, collected)
		}

	case
		*AlterMigration,
		*AlterDMLJob,
		*RevertMigration,
		*ShowMigrationLogs,
		*ShowThrottledApps,
		*ShowThrottlerStatus:
		// These statements contribute no tables.

	case *OtherAdmin, *CheckTable, *Kill, *CallProc, *Begin, *Commit, *Rollback,
		*Load, *Savepoint, *Release, *SRollback, *Set, *Show,
		*OtherRead, Explain, DBDDLStatement:
		// No tables are collected for these statements.

	default:
		// Unlisted statement types contribute no tables.
	}

	return removeDuplicateTables(addDefaultTableSchema(collected, defaultTableSchema))
}
// collectFromSubQuery walks stmt and appends to tables every table found in
// the FROM clause of each *Select the walk reaches. It returns the extended
// slice.
//
// The walk does not descend into TableExprs nodes: FROM clauses are handed to
// collectFromTableExprs explicitly when their *Select is visited.
func collectFromSubQuery(stmt Statement, tables []TableSchemaAndName) []TableSchemaAndName {
	visit := func(n SQLNode) (bool, error) {
		if sel, ok := n.(*Select); ok {
			tables = collectFromTableExprs(sel.From, tables)
			return true, nil
		}
		if _, ok := n.(TableExprs); ok {
			// Stop here; table expressions are processed via the *Select case.
			return false, nil
		}
		return true, nil
	}

	// The visitor itself never returns an error, so Walk's result is ignored.
	_ = Walk(visit, stmt)
	return tables
}
// collectFromTableName appends the schema (qualifier) and name of node to
// tables and returns the extended slice. The qualifier is stored as-is, so it
// is an empty string whenever node.Qualifier renders as one.
func collectFromTableName(node TableName, tables []TableSchemaAndName) []TableSchemaAndName {
	entry := TableSchemaAndName{
		schema: node.Qualifier.String(),
		name:   node.Name.String(),
	}
	return append(tables, entry)
}
// collectFromTableExprs appends the tables referenced by each table
// expression in node, in order, and returns the extended slice.
func collectFromTableExprs(node TableExprs, tables []TableSchemaAndName) []TableSchemaAndName {
	for i := 0; i < len(node); i++ {
		tables = buildTableExprPermissions(node[i], tables)
	}
	return tables
}
// buildTableExprPermissions appends the tables referenced by a single table
// expression to tables and returns the extended slice.
//
//   - *JoinTableExpr: both sides are processed, left side first.
//   - *ParenTableExpr: every contained expression is processed.
//   - *AliasedTableExpr: a plain TableName is recorded directly; a
//     *DerivedTable has its inner select handed to collectFromSubQuery.
//
// Any other expression kind leaves tables unchanged.
func buildTableExprPermissions(node TableExpr, tables []TableSchemaAndName) []TableSchemaAndName {
	switch expr := node.(type) {
	case *JoinTableExpr:
		tables = buildTableExprPermissions(expr.LeftExpr, tables)
		return buildTableExprPermissions(expr.RightExpr, tables)

	case *ParenTableExpr:
		return collectFromTableExprs(expr.Exprs, tables)

	case *AliasedTableExpr:
		if name, ok := expr.Expr.(TableName); ok {
			return collectFromTableName(name, tables)
		}
		if derived, ok := expr.Expr.(*DerivedTable); ok {
			return collectFromSubQuery(derived.Select, tables)
		}
	}
	return tables
}
// addDefaultTableSchema fills in dbName as the schema of every entry in tables
// whose schema is empty. Entries are updated in place, so the caller's slice
// is modified, and the same slice is returned. When dbName is empty nothing is
// changed.
func addDefaultTableSchema(tables []TableSchemaAndName, dbName string) []TableSchemaAndName {
	if dbName != "" {
		for i := range tables {
			entry := &tables[i]
			if entry.schema == "" {
				entry.schema = dbName
			}
		}
	}
	return tables
}
// removeDuplicateTables returns the entries of tables with duplicates removed,
// keeping the first occurrence of each entry and preserving input order. Two
// entries are duplicates when both schema and name are equal. The input slice
// is not modified. The result is nil when tables is empty.
func removeDuplicateTables(tables []TableSchemaAndName) []TableSchemaAndName {
	if len(tables) == 0 {
		// Keep returning nil (not an empty slice) when there is nothing to report.
		return nil
	}

	seen := make(map[TableSchemaAndName]struct{}, len(tables))
	unique := make([]TableSchemaAndName, 0, len(tables))
	for _, table := range tables {
		if _, dup := seen[table]; dup {
			continue
		}
		seen[table] = struct{}{}
		unique = append(unique, table)
	}
	return unique
}