Skip to content

Commit d6de7df

Browse files
committed
mysql -> sqlite
1 parent d79cb00 commit d6de7df

File tree

6 files changed

+17
-69
lines changed

6 files changed

+17
-69
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ app
55
index
66
*.json
77
*.bib
8-
main.yml
8+
main.yml
9+
anthology.db

docker-compose.yml

Lines changed: 0 additions & 33 deletions
This file was deleted.

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ colbert-ir[torch,faiss-cpu] # faiss-gpu
33
faiss-cpu
44
huggingface_hub
55
bibtexparser
6-
mysql-connector-python
6+
pysqlite3
77
flask
88
gunicorn

src/constants.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
DATASET_PATH = os.path.join(INDEX_ROOT, 'dataset.json')
1212

1313
DB_NAME = 'anthology'
14-
DB_HOSTNAME = 'mysql_db' # localhost
15-
DB_PORT = 3306 # None
14+
DB_FILENAME = 'anthology.db'
1615

1716
VENUES = Literal['workshop', 'journal', 'short', 'demo', 'tutorial', 'industry', 'findings', 'main']

src/db.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import json
22
from typing import List, Optional, Union
33

4-
import mysql.connector
5-
from constants import DATASET_PATH, DB_HOSTNAME, DB_NAME, DB_PORT, VENUES
4+
import sqlite3
5+
from constants import DATASET_PATH, DB_FILENAME, DB_NAME, VENUES
66

77
PAPER_QUERY = """
88
SELECT *
@@ -19,27 +19,11 @@ def read_dataset():
1919

2020

2121
def create_database():
22-
db = mysql.connector.connect(
23-
host = DB_HOSTNAME,
24-
user = "root",
25-
password = "",
26-
port = DB_PORT
22+
db: sqlite3.Connection = sqlite3.connect(
23+
database = DB_FILENAME
2724
)
2825
cursor = db.cursor()
2926

30-
cursor.execute("SHOW DATABASES")
31-
db_exists = False
32-
for x in cursor:
33-
db_name = x[0]
34-
if db_name == DB_NAME:
35-
db_exists = True
36-
37-
# Create database
38-
if not db_exists:
39-
print("Creating new database...")
40-
cursor.execute(f'CREATE DATABASE {DB_NAME}')
41-
cursor.execute(f'USE {DB_NAME}')
42-
4327
# Create table
4428
print('Creating new table...')
4529
cursor.execute(f'DROP TABLE IF EXISTS paper')
@@ -61,6 +45,7 @@ def create_database():
6145
acl_data = read_dataset()
6246

6347
vals = []
48+
paper: dict
6449
for pid, paper in enumerate(acl_data):
6550
title = paper.get('title', '')
6651
author = paper.get('author', '')
@@ -80,7 +65,7 @@ def create_database():
8065
INSERT INTO paper (
8166
pid, title, author, year, abstract, url, type, venue, venue_type, is_findings
8267
) VALUES (
83-
%s, %s, %s, %s, %s, %s, %s, %s, %s, %s
68+
?, ?, ?, ?, ?, ?, ?, ?, ?, ?
8469
)
8570
"""
8671

@@ -123,17 +108,13 @@ def query_paper_metadata(
123108
):
124109
if not isinstance(venue_type, list): venue_type = [venue_type]
125110

126-
db = mysql.connector.connect(
127-
host = DB_HOSTNAME,
128-
user = "root",
129-
password = "",
130-
database = DB_NAME,
131-
port = DB_PORT
111+
db: sqlite3.Connection = sqlite3.connect(
112+
database = DB_FILENAME
132113
)
133114

134115
cursor = db.cursor()
135116

136-
pids_str = ', '.join(['%s'] * len(pids))
117+
pids_str = ', '.join(['?'] * len(pids))
137118

138119
constraints_str = ""
139120
if start_year: constraints_str += f" AND year >= {start_year}"

src/server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def query():
6969
# Get top passage IDs from ColBERT
7070
colbert_response = api_search_query(query)
7171

72-
# Query MySQL database for paper information
72+
# Query database for paper information
7373
pids = [r['pid'] for r in colbert_response["topk"]]
74-
mysql_response = query_paper_metadata(
74+
server_response = query_paper_metadata(
7575
pids,
7676
start_year=start_year,
7777
end_year=end_year,
@@ -80,9 +80,9 @@ def query():
8080
)
8181

8282
K = 20
83-
mysql_response = mysql_response[:K]
83+
server_response = server_response[:K]
8484

85-
return mysql_response
85+
return server_response
8686

8787

8888
# @app.route('/search', methods=['POST', 'GET'])

0 commit comments

Comments
 (0)