-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
126 lines (96 loc) · 6.03 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import random
import ast
import json
import os
# Operators drawn for selections on 'text' / 'character varying' columns.
nominal_operations = [
    '=',
    '!=',
    'IS',
    'IN',
]
# Operators drawn for selections on columns of any other data type.
numeric_operations = [
    '<',
    '>',
    '<=',
    '>=',
    '=',
    '!=',
    'IS',
    'BETWEEN',
    'NOT BETWEEN',
    'IN',
]
def _random_between_bounds(selection_spec):
    """Return a ``(lower, upper)`` pair of bounds for a (NOT) BETWEEN predicate.

    A spec with a 'values' list yields two of those values; otherwise the
    bounds are integers drawn from the inclusive 'range' pair.
    """
    if 'values' in selection_spec:
        candidates = selection_spec['values']
        first = random.choice(candidates)
        others = [value for value in candidates if value != first]
        # With a single distinct value there is nothing to pair it with, so
        # fall back to the degenerate but valid ``BETWEEN x AND x`` instead of
        # crashing on random.choice([]).
        second = random.choice(others) if others else first
        try:
            # A plain BETWEEN with lower > upper matches no rows, so put the
            # bounds in order (Python ordering; presumably close enough to
            # the database's ordering for these values -- TODO confirm).
            lower, upper = sorted((first, second))
        except TypeError:
            # Values of mutually unorderable types: keep the drawn order.
            lower, upper = first, second
        return lower, upper
    low = selection_spec['range'][0]
    high = selection_spec['range'][1]
    if low == high:
        # A one-value range leaves no room for lower < upper.
        return low, high
    lower = random.randint(low, high - 1)
    upper = random.randint(lower + 1, high)
    return lower, upper


def _random_in_terms(selection_spec):
    """Return the list of raw values to place inside an IN (...) predicate."""
    if 'values' in selection_spec:
        values = selection_spec['values']
        # NOTE(review): this lists *every* value (in shuffled order), which
        # is what the code has always done; a random subset may have been
        # intended -- confirm before changing.
        return random.sample(values, k=len(values))
    population = range(selection_spec['range'][0], selection_spec['range'][1] + 1)
    # Never ask for more terms than the range holds; random.sample raises
    # ValueError when k exceeds the population size.
    size = random.randint(1, min(10, len(population)))
    return random.sample(population, size)


def _random_term(selection_spec):
    """Return one raw value to compare against with a binary operator."""
    if 'values' in selection_spec:
        return random.choice(selection_spec['values'])
    return random.randint(selection_spec['range'][0], selection_spec['range'][1])


def generate_queries():
    """Generate random SQL queries and write them to ``generated/<n>.sql``.

    Command-line arguments (read via argparse):
        num_queries: number of queries to attempt to generate.
        joins_file: JSON file holding a list of join specs, each with the keys
            'left_table_name', 'left_column_name', 'right_table_name' and
            'right_column_name'.
        selections_file: JSON file holding a list of selection specs, each
            with the keys 'table_name', 'column_name', 'data_type' and either
            'values' (a list of candidate values) or 'range' (an inclusive
            ``[low, high]`` integer pair).

    Every query selects MIN(...) over each column it touches and combines a
    random, non-empty sample of selections plus a random sample of joins with
    AND. Queries are de-duplicated on their SQL text, so fewer than
    ``num_queries`` files may be written.

    NOTE(review): the ``ast`` used here must provide ReferenceNode,
    ReferenceDotNode, FunctionNode, TermNode, BoundsNode, ListNode,
    OperationNode and SelectNode; the standard-library ``ast`` has none of
    these, so this is presumably a project-local module -- confirm.

    Raises:
        ValueError: if the selections file contains no selection specs.
    """
    argument_parser = argparse.ArgumentParser()
    argument_parser.add_argument('num_queries', type=int)
    argument_parser.add_argument('joins_file')
    argument_parser.add_argument('selections_file')
    args = argument_parser.parse_args()

    with open(args.joins_file, 'r', encoding='utf-8') as joins_file, \
            open(args.selections_file, 'r', encoding='utf-8') as selections_file:
        join_specs = json.load(joins_file)
        selection_specs = json.load(selections_file)

    if not selection_specs:
        # Each query needs at least one selection; fail with a clear message
        # instead of an opaque randint(1, 0) error.
        raise ValueError('selections file must contain at least one selection spec')

    queries = set()
    for _ in range(args.num_queries):
        tables = set()
        columns = set()
        selections = []

        # Random sample of N selections where 1 <= N <= len(selection_specs).
        selection_count = random.randint(1, len(selection_specs))
        for selection_spec in random.sample(selection_specs, k=selection_count):
            table_name = selection_spec['table_name']
            column_name = selection_spec['column_name']
            tables.add(ast.ReferenceNode(table_name))
            # Columns are projected through the MIN aggregate.
            columns.add(ast.FunctionNode('MIN', ast.ReferenceDotNode(table_name, column_name)))

            left = ast.ReferenceDotNode(table_name, column_name)
            if selection_spec['data_type'] in ('text', 'character varying'):
                operation = random.choice(nominal_operations)
            else:
                operation = random.choice(numeric_operations)

            # Build the right-hand side that matches the chosen operator.
            if operation == 'IS':
                right = ast.TermNode(random.choice(['NULL', 'NOT NULL']))
            elif operation in ('BETWEEN', 'NOT BETWEEN'):
                lower, upper = _random_between_bounds(selection_spec)
                right = ast.BoundsNode(ast.TermNode(lower), ast.TermNode(upper))
            elif operation == 'IN':
                right = ast.ListNode([ast.TermNode(term) for term in _random_in_terms(selection_spec)])
            else:
                right = ast.TermNode(_random_term(selection_spec))
            selections.append(ast.OperationNode(operation, left, right))

        # Keep only joins where at least one side's table already appears in
        # a selection (the original rationale: otherwise the query may be
        # intractable).
        selection_table_names = {table.reference for table in tables}
        filtered_join_specs = [
            spec for spec in join_specs
            if spec['left_table_name'] in selection_table_names
            or spec['right_table_name'] in selection_table_names
        ]

        joins = []
        # Random sample of N joins where 0 <= N <= len(filtered_join_specs).
        join_count = random.randint(0, len(filtered_join_specs))
        for join_spec in random.sample(filtered_join_specs, k=join_count):
            left_table = join_spec['left_table_name']
            left_column = join_spec['left_column_name']
            right_table = join_spec['right_table_name']
            right_column = join_spec['right_column_name']
            tables.add(ast.ReferenceNode(left_table))
            tables.add(ast.ReferenceNode(right_table))
            columns.add(ast.FunctionNode('MIN', ast.ReferenceDotNode(left_table, left_column)))
            columns.add(ast.FunctionNode('MIN', ast.ReferenceDotNode(right_table, right_column)))
            joins.append(ast.OperationNode(
                '=',
                ast.ReferenceDotNode(left_table, left_column),
                ast.ReferenceDotNode(right_table, right_column),
            ))

        # Build the predicate tree bottom-up with AND: selections first
        # (last-created innermost), then joins, matching the pop() order the
        # code has always used.
        conditions = selections[::-1] + joins[::-1]
        predicate = conditions[0]
        for condition in conditions[1:]:
            predicate = ast.OperationNode('AND', predicate, condition)

        query = ast.SelectNode(ast.ListNode(list(columns)), ast.ListNode(list(tables)), predicate)
        queries.add(query.to_sql())

    # Write all the distinct generated queries to numbered files.
    os.makedirs('generated', exist_ok=True)
    for index, sql in enumerate(queries, start=1):
        with open('generated/{}.sql'.format(index), 'w') as f:
            f.write(sql)
# Run the generator only when this file is executed as a script.
if __name__ == '__main__':
    generate_queries()