Skip to content

Commit b9ffd32

Browse files
committed
improved memory usage
1 parent 1351153 commit b9ffd32

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ python src/index.py
6060
**Deploy Web Server**
6161
```sh
6262
# Start an API endpoint
63-
gunicorn -w 4 -b 0.0.0.0:8080 server:app
63+
gunicorn -w 1 --threads 100 --worker-class gthread -b 0.0.0.0:8080 src.server:app
6464

6565
# Then visit:
6666
# http://localhost:8080

fly.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ primary_region = 'lax'
1414
processes = ['app']
1515

1616
[[vm]]
17-
memory = '4gb'
17+
memory = '2gb'
1818
cpu_kind = 'shared'
1919
cpus = 2

src/server.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
11
import os, math, re, requests, io, time
22
from typing import List, Optional, Union
33

4+
import sys
5+
from pathlib import Path
6+
sys.path.append(str(Path(__file__).parent))
7+
48
from flask import Flask, abort, request, render_template, jsonify
59
from functools import lru_cache
610

711
from constants import INDEX_PATH, VENUES
8-
912
from search import ColBERT
1013
from db import create_database, query_paper_metadata
1114
from utils import download_index_from_hf, print_estimate_cost
1215

1316
import PyPDF2
1417
from openai import OpenAI
1518

19+
from threading import local
20+
thread_local = local()
21+
1622
PORT = int(os.getenv("PORT", 8080))
1723
app = Flask(__name__)
18-
19-
# Load OpenAI API key
20-
if os.path.exists('.openai-api-key'):
21-
with open('.openai-api-key', 'r') as f:
22-
api_key = f.read().strip()
23-
client = OpenAI(api_key=api_key)
24+
colbert = None
25+
_is_initialized = False
2426

2527
@lru_cache(maxsize=1000000)
2628
def api_search_query(query):
@@ -114,17 +116,26 @@ def get_pdf_text(pdf_url: str) -> str:
114116
return " ".join(page.extract_text() for page in pdf_reader.pages)
115117

116118

119+
def get_openai_client():
120+
if not hasattr(thread_local, 'client'):
121+
if os.path.exists('.openai-api-key'):
122+
with open('.openai-api-key', 'r') as f:
123+
api_key = f.read().strip()
124+
thread_local.client = OpenAI(api_key=api_key)
125+
return thread_local.client
126+
127+
117128
@app.route('/api/llm', methods=['POST'])
118129
def query_llm():
130+
print(f'Started a new query!')
131+
client = get_openai_client()
119132
data = request.json
120133
title = data['title']
121134
abstract = data['abstract']
122135
question = data['question']
123136
pdf_url = data['pdf_url'] if 'pdf_url' in data else None
124137

125138
try:
126-
print(f'Started a new query!')
127-
128139
start_time = time.time()
129140
pdf_text = get_pdf_text(pdf_url)
130141
print(f"PDF text extraction took {time.time() - start_time:.2f} seconds")
@@ -135,7 +146,7 @@ def query_llm():
135146
Question: {question}
136147
Please provide a concise answer."""
137148

138-
print_estimate_cost(prompt, model='gpt-4o-mini', input_cost=0.15, output_cost=0.6, estimated_output_toks=100)
149+
# print_estimate_cost(prompt, model='gpt-4o-mini', input_cost=0.15, output_cost=0.6, estimated_output_toks=100)
139150

140151
start_llm_time = time.time()
141152
response = client.chat.completions.create(
@@ -157,20 +168,29 @@ def query_llm():
157168
return jsonify({'error': 'Failed to process request'}), 500
158169

159170

171+
def init_app():
172+
global colbert, _is_initialized
173+
if not _is_initialized:
174+
download_index_from_hf()
175+
create_database()
176+
colbert = ColBERT(index_path=INDEX_PATH)
177+
_is_initialized = True
178+
179+
# Remove client initialization from here since we're using thread-local storage
180+
181+
182+
@app.before_request
183+
def before_request():
184+
init_app()
185+
186+
160187
if __name__ == "__main__":
161188
"""
162189
Example usage:
163190
python server.py
164191
http://localhost:8080/api/colbert?query=Information retrevial with BERT
165192
http://localhost:8080/api/search?query=Information retrevial with BERT
166193
"""
167-
download_index_from_hf()
168-
create_database()
169-
global colbert
170-
colbert = ColBERT(index_path=INDEX_PATH)
171-
print(colbert.search('text simplificaiton'))
172-
print(api_search_query("text simplification")['topk'][:5])
173-
174194
# Watch web dirs for changes
175195
extra_files = [os.path.join(dirname, filename) for dirname, _, files in os.walk('templates') for filename in files]
176196

0 commit comments

Comments
 (0)