-
Notifications
You must be signed in to change notification settings - Fork 12
/
export_dataset.py
69 lines (58 loc) · 2.21 KB
/
export_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
import shutil
import sqlite3
from contextlib import closing

from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument("db")
args = parser.parse_args()
# Prepare dataset folder
try:
os.mkdir("dataset")
except FileExistsError:
pass
# Create release database
sac_db = sqlite3.connect('dataset/sac.sqlite')
cursor = sac_db.cursor()
cursor.execute('''CREATE TABLE generations
(id INTEGER PRIMARY KEY, method INTEGER, prompt TEXT, verified INTEGER)''')
cursor.execute('''CREATE TABLE ratings
(gid INTEGER, rating INTEGER,
FOREIGN KEY(gid) REFERENCES generations(id))''')
cursor.execute('''CREATE TABLE upscales
(gid INTEGER, choice INTEGER, method INTEGER,
FOREIGN KEY(gid) REFERENCES generations(id),
PRIMARY KEY(gid, choice))''')
sac_db.commit()
cursor.close()
# Retrieve generations from current database
bot_db = sqlite3.connect(args.db)
cursor = bot_db.cursor()
cursor.execute("select * from generations;")
gens = cursor.fetchall()
cursor.close()
bot_cursor = bot_db.cursor()
sac_cursor = sac_db.cursor()
for i, gen in enumerate(tqdm(gens)):
path_template = str(gen[0]) + "_" + gen[-1].replace(" ", "_").replace("/","_") + "_{}" + ".png"
paths = [path[1].format(path[0]+1) for path in enumerate([path_template] * 8)]
for path in paths:
shutil.copy(path, "dataset/{}".format(path))
# Correct index and redact info into new release database
bot_cursor.execute("SELECT * FROM users WHERE id=?", (gen[1],))
user = bot_cursor.fetchone()
verified = int(user[3])
bot_cursor.execute("SELECT * FROM ratings WHERE gid=?", (gen[0],))
ratings = bot_cursor.fetchall()
bot_cursor.execute("SELECT * FROM upscales WHERE gid=?", (gen[0],))
upscales = bot_cursor.fetchall()
sac_cursor.execute("INSERT INTO generations VALUES (?,?,?,?)", (i, gen[3], gen[4], verified))
for rating in ratings:
sac_cursor.execute("INSERT INTO ratings VALUES (?,?)", (i, rating[2]))
for upscale in upscales:
sac_cursor.execute("INSERT INTO upscales VALUES (?,?,?)", (i, upscale[1], upscale[2]))
sac_db.commit()
bot_cursor.close()
sac_cursor.close()
bot_db.close()
sac_db.close()