Skip to content

Commit

Permalink
refactor tables
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 committed Oct 10, 2024
1 parent cd98505 commit b95409f
Showing 1 changed file with 88 additions and 154 deletions.
242 changes: 88 additions & 154 deletions lumen/ai/interceptor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import json
import sqlite3
import uuid

from abc import abstractmethod
from functools import wraps
Expand All @@ -17,13 +18,16 @@ class Message(BaseModel):
content: str


class Batch(BaseModel):
batch_id: int
class Invocation(BaseModel):
input_id: int
prompt: str
messages: list[Message]
kwargs: dict[str, Any]
response: str | None
kwargs: dict[str, Any]
invocation_id: str

def serialize(self) -> list[dict[str, Any]]:
"""Serialize messages into a list of dictionaries."""
return [
{"role": message.role, "content": message.content}
for message in self.messages
Expand All @@ -32,7 +36,7 @@ def serialize(self) -> list[dict[str, Any]]:

class Session(BaseModel):
session_id: str
batches: list[Batch]
invocations: list[Invocation]


class Interceptor(param.Parameterized):
Expand All @@ -48,7 +52,7 @@ def __init__(self, **params):
if needs_init:
self.init_db()
self._client = self._original_create = self._original_create_response = None
self._last_batch_id = None
self._last_invocation_id = None
self.session_id = self._generate_session_id()

def _create_connection(self) -> sqlite3.Connection:
Expand All @@ -58,24 +62,11 @@ def _create_connection(self) -> sqlite3.Connection:
def _generate_session_id(self) -> str:
"""Generate a unique session ID."""
first_message_timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
return f"conv_{first_message_timestamp}"
return f"session_{first_message_timestamp}"

def _dump_response_model(self, response_model: BaseModel) -> str:
"""Dump the response model to a JSON string."""
# using json.dumps instead of model_dump_json
# for more consistent serialization
content = json.dumps(response_model.model_dump())
return content

def _select_max_batch_id(self) -> int:
"""Get the maximum batch ID for the current session."""
cursor = self.conn.cursor()
cursor.execute(
"SELECT MAX(batch_id) FROM messages WHERE session_id = ?",
(self.session_id,),
)
max_batch_id = cursor.fetchone()[0]
return 0 if max_batch_id is None else max_batch_id
return json.dumps(response_model.model_dump())

@abstractmethod
def patch_client(self, client) -> None:
Expand All @@ -91,57 +82,42 @@ def init_db(self) -> None:
cursor = self.conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS message_batches (
id TEXT PRIMARY KEY,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
CREATE TABLE IF NOT EXISTS invocations (
invocation_id TEXT PRIMARY KEY,
session_id TEXT,
role TEXT,
content TEXT,
batch_id INTEGER,
messages TEXT,
message_kwargs TEXT,
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES message_batches(id)
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS message_kwargs (
session_id TEXT,
batch_id INTEGER,
key TEXT,
value TEXT,
FOREIGN KEY (session_id) REFERENCES message_batches(id),
FOREIGN KEY (batch_id) REFERENCES messages(batch_id)
FOREIGN KEY (session_id) REFERENCES sessions(session_id)
)
"""
)
self.conn.commit()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS responses (
session_id TEXT,
batch_id INTEGER,
response_id INTEGER PRIMARY KEY AUTOINCREMENT,
invocation_id INTEGER,
content TEXT,
FOREIGN KEY (session_id) REFERENCES message_batches(id),
FOREIGN KEY (batch_id) REFERENCES messages(batch_id)
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (invocation_id) REFERENCES invocations(invocation_id)
)
"""
)
self.conn.commit()

def reset_db(self) -> None:
"""Reset the database by deleting all tables."""
cursor = self.conn.cursor()
cursor.execute("DROP TABLE IF EXISTS message_batches")
cursor.execute("DROP TABLE IF EXISTS messages")
cursor.execute("DROP TABLE IF EXISTS message_kwargs")
cursor.execute("DROP TABLE IF EXISTS responses")
cursor.execute("DROP TABLE IF EXISTS invocations")
cursor.execute("DROP TABLE IF EXISTS sessions")
self.conn.commit()
self.init_db()

Expand All @@ -150,13 +126,10 @@ def delete_session(self, session_id: str | None = None) -> None:
cursor = self.conn.cursor()
if session_id is None:
session_id = self.get_session_ids()[-1]
cursor.execute("DELETE FROM message_batches WHERE id = ?", (session_id,))
cursor.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
cursor.execute("DELETE FROM message_kwargs WHERE session_id = ?", (session_id,))
cursor.execute("DELETE FROM responses WHERE session_id = ?", (session_id,))
cursor.execute("DELETE FROM invocations WHERE session_id = ?", (session_id,))
self.conn.commit()

def store_messages(
def store_invocation(
self, messages: list[dict[str, str]], **kwargs: dict[str, Any]
) -> None:
"""
Expand All @@ -167,159 +140,120 @@ def store_messages(
kwargs: The keyword arguments passed to the create method.
"""
cursor = self.conn.cursor()
self._last_invocation_id = invocation_id = uuid.uuid4().hex

cursor.execute(
"INSERT OR IGNORE INTO message_batches (id) VALUES (?)",
(self.session_id,),
"""
INSERT INTO invocations (invocation_id, session_id, messages, message_kwargs)
VALUES (?, ?, ?, ?)
""",
(invocation_id, self.session_id, json.dumps(messages), json.dumps(kwargs)),
)

batch_id = self._select_max_batch_id() + 1

for message in messages:
cursor.execute(
"""
INSERT INTO messages (session_id, role, content, batch_id) VALUES (?, ?, ?, ?)
""",
(self.session_id, message["role"], message["content"], batch_id),
)

for key, value in kwargs.items():
if key == "messages":
continue

try:
cursor.execute(
"""
INSERT INTO message_kwargs (session_id, batch_id, key, value) VALUES (?, ?, ?, ?)
""",
(self.session_id, batch_id, key, json.dumps(value)),
)
except TypeError as exc:
raise RuntimeError(
"The method patch_client must be called before instructor.patch"
) from exc

self.conn.commit()

self._last_batch_id = batch_id

def store_response(self, content: str) -> None:
cursor = self.conn.cursor()

# Store the response based on the last batch_id
cursor.execute(
"""
INSERT INTO responses (session_id, batch_id, content)
INSERT INTO responses (invocation_id, content, timestamp)
VALUES (?, ?, ?)
""",
(self.session_id, self._last_batch_id, content),
(self._last_invocation_id, content, datetime.datetime.now()),
)
self.conn.commit()

def get_session(self, session_id: str | None = None) -> Session:
"""
Retrieve the session batches of inputs from the last session, or a specific session if provided.
Retrieve the session invocations of inputs from the last session, or a specific session if provided.
Args:
session_id: The session ID to retrieve batches from. If not provided, the last session is used.
session_id: The session ID to retrieve invocations from. If not provided, the last session is used.
Returns:
A list of dictionaries containing batch_id, messages, and kwargs for each batch.
A Session object containing invocations for the session.
"""
cursor = self.conn.cursor()

if session_id is None:
cursor.execute(
"SELECT id FROM message_batches ORDER BY timestamp DESC LIMIT 1"
"SELECT session_id FROM invocations ORDER BY timestamp DESC LIMIT 1"
)
try:
session_id = cursor.fetchone()[0]
except TypeError:
return []
return Session(session_id="", invocations=[])

cursor.execute(
"SELECT DISTINCT batch_id FROM messages WHERE session_id = ? ORDER BY batch_id",
"""
SELECT
invocations.invocation_id,
messages,
message_kwargs,
responses.content AS response
FROM invocations
LEFT JOIN responses ON invocations.invocation_id = responses.invocation_id
WHERE invocations.session_id = ?
ORDER BY invocations.invocation_id DESC
""",
(session_id,),
)
batch_ids = cursor.fetchall()

batches = []
for (batch_id,) in batch_ids:
cursor.execute(
"SELECT role, content FROM messages WHERE session_id = ? AND batch_id = ? ORDER BY timestamp",
(session_id, batch_id),
)
messages = cursor.fetchall()
batch_messages = [
{"role": role, "content": content} for role, content in messages
]

cursor.execute(
"""
SELECT key, value
FROM message_kwargs
WHERE session_id = ? AND batch_id = ?
""",
(session_id, batch_id),
)
kwargs_data = cursor.fetchall()
kwargs_dict = {key: json.loads(value) for key, value in kwargs_data}

cursor.execute(
"""
SELECT content
FROM responses
WHERE session_id = ? AND batch_id = ?
""",
(session_id, batch_id),
)
response_data = cursor.fetchone()
if response_data is None:
response_content = None
else:
response_content = response_data[0]

batches.append(
{
"batch_id": batch_id,
"messages": batch_messages,
"kwargs": kwargs_dict,
"response": response_content,
}
invocation_data = cursor.fetchall()

input_id = -1
invocations = []
prev_user_content = None
for invocation_id, invocation_messages, message_kwargs, response in invocation_data:
messages = []
for message in json.loads(invocation_messages):
if message["role"] == "user":
user_content = message["content"]
if prev_user_content != user_content:
prev_user_content = user_content
input_id += 1
messages.append(Message(role=message["role"], content=message["content"]))

invocations.append(
Invocation(
prompt=user_content,
invocation_id=invocation_id,
input_id=input_id,
messages=messages,
kwargs=json.loads(message_kwargs),
response=response,
)
)

return Session(
session_id=session_id, batches=[Batch(**batch) for batch in batches]
)
return Session(session_id=session_id, invocations=invocations)

def get_all_sessions(self) -> dict[str, Session]:
"""
Retrieve the batches of messages from all sessions.
Retrieve the invocations of messages from all sessions.
Returns:
A dictionary containing session_id as keys and the corresponding
list of message batches for each session.
Session object for each session.
"""
all_batches = {}
all_sessions = {}
for session_id in self.get_session_ids():
all_batches[session_id] = self.get_session(session_id)
all_sessions[session_id] = self.get_session(session_id)

return all_batches
return all_sessions

def get_session_ids(self) -> list[str]:
cursor = self.conn.cursor()
cursor.execute("SELECT id FROM message_batches")
cursor.execute("SELECT DISTINCT session_id FROM invocations")
return [row[0] for row in cursor.fetchall()]

def unpatch(self) -> None:
"""Close the database connection and reverts the client create."""
"""Close the database connection and revert the client create."""
if self._original_create_response is not None:
self._client.chat.completions.create = self._original_create_response
if self._original_create is not None:
self._client.chat.completions.create = self._original_create
self.conn.close()

def __del__(self) -> None:
"""Close the database connection and reverts the client create when the object is deleted."""
"""Close the database connection and revert the client create when the object is deleted."""
self.unpatch()


Expand All @@ -343,11 +277,11 @@ def patch_client(
async def stream_response(*args: Any, **kwargs: Any):
async for chunk in await self._original_create(*args, **kwargs):
yield chunk
self.store_messages(**kwargs)
self.store_invocation(**kwargs)

async def non_stream_response(*args: Any, **kwargs: Any):
response = await self._original_create(*args, **kwargs)
self.store_messages(**kwargs)
self.store_invocation(**kwargs)
return response

@wraps(client.chat.completions.create)
Expand Down

0 comments on commit b95409f

Please sign in to comment.