diff --git a/lumen/ai/interceptor.py b/lumen/ai/interceptor.py
new file mode 100644
index 00000000..1f3a120c
--- /dev/null
+++ b/lumen/ai/interceptor.py
@@ -0,0 +1,352 @@
+import datetime
+import json
+import sqlite3
+import uuid
+
+from abc import abstractmethod
+from functools import wraps
+from pathlib import Path
+from typing import Any, Literal
+
+import param
+
+from pydantic import BaseModel
+
+
+class Message(BaseModel):
+    role: str
+    content: str
+
+
+class Invocation(BaseModel):
+    input_id: int
+    prompt: str
+    messages: list[Message]
+    response: str | None
+    kwargs: dict[str, Any]
+    invocation_id: str
+
+    def serialize(self) -> list[dict[str, Any]]:
+        """Serialize messages into a list of dictionaries."""
+        return [
+            {"role": message.role, "content": message.content}
+            for message in self.messages
+        ]
+
+
+class Session(BaseModel):
+    session_id: str
+    invocations: list[Invocation]
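For orientation, a minimal sketch (not part of the diff; all values are invented) of the record these Pydantic models describe and what serialize() returns:

    invocation = Invocation(
        input_id=0,
        prompt="What is Lumen?",
        messages=[Message(role="user", content="What is Lumen?")],
        response="Lumen is a Python dashboarding library.",
        kwargs={"model": "gpt-4o-mini"},
        invocation_id="d41d8cd98f00b204e9800998ecf8427e",
    )
    invocation.serialize()
    # [{'role': 'user', 'content': 'What is Lumen?'}]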
+ """ + + def init_db(self) -> None: + """Initialize the database by creating necessary tables if they don't exist.""" + cursor = self.conn.cursor() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS sessions ( + session_id TEXT PRIMARY KEY, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP + ) + """ + ) + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS invocations ( + invocation_id TEXT PRIMARY KEY, + session_id TEXT, + messages TEXT, + message_kwargs TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES sessions(session_id) + ) + """ + ) + self.conn.commit() + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS responses ( + response_id INTEGER PRIMARY KEY AUTOINCREMENT, + invocation_id INTEGER, + content TEXT, + timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (invocation_id) REFERENCES invocations(invocation_id) + ) + """ + ) + + def reset_db(self) -> None: + """Reset the database by deleting all tables.""" + cursor = self.conn.cursor() + cursor.execute("DROP TABLE IF EXISTS invocations") + cursor.execute("DROP TABLE IF EXISTS sessions") + self.conn.commit() + self.init_db() + + def delete_session(self, session_id: str | None = None) -> None: + """Delete the last session from the database.""" + cursor = self.conn.cursor() + if session_id is None: + session_id = self.get_session_ids()[-1] + cursor.execute("DELETE FROM invocations WHERE session_id = ?", (session_id,)) + self.conn.commit() + + def store_invocation( + self, messages: list[dict[str, str]], **kwargs: dict[str, Any] + ) -> None: + """ + Store messages and keyword arguments in the database for the current session. + + Args: + messages: List of message dictionaries to store. + kwargs: The keyword arguments passed to the create method. + """ + cursor = self.conn.cursor() + self._last_invocation_id = invocation_id = uuid.uuid4().hex + + cursor.execute( + """ + INSERT INTO invocations (invocation_id, session_id, messages, message_kwargs) + VALUES (?, ?, ?, ?) + """, + (invocation_id, self.session_id, json.dumps(messages), json.dumps(kwargs)), + ) + self.conn.commit() + + def store_response(self, content: str) -> None: + cursor = self.conn.cursor() + cursor.execute( + """ + INSERT INTO responses (invocation_id, content, timestamp) + VALUES (?, ?, ?) + """, + (self._last_invocation_id, content, datetime.datetime.now()), + ) + self.conn.commit() + + def get_session(self, session_id: str | None = None) -> Session: + """ + Retrieve the session invocations of inputs from the last session, or a specific session if provided. + + Args: + session_id: The session ID to retrieve invocations from. If not provided, the last session is used. + + Returns: + A Session object containing invocations for the session. + """ + cursor = self.conn.cursor() + + if session_id is None: + cursor.execute( + "SELECT session_id FROM invocations ORDER BY timestamp DESC LIMIT 1" + ) + try: + session_id = cursor.fetchone()[0] + except TypeError: + return Session(session_id="", invocations=[]) + + cursor.execute( + """ + SELECT + invocations.invocation_id, + messages, + message_kwargs, + responses.content AS response + FROM invocations + LEFT JOIN responses ON invocations.invocation_id = responses.invocation_id + WHERE invocations.session_id = ? 
+    def get_session(self, session_id: str | None = None) -> Session:
+        """
+        Retrieve the invocations from the last session, or from a specific session if provided.
+
+        Args:
+            session_id: The session ID to retrieve invocations from. If not provided, the last session is used.
+
+        Returns:
+            A Session object containing the invocations for the session.
+        """
+        cursor = self.conn.cursor()
+
+        if session_id is None:
+            cursor.execute(
+                "SELECT session_id FROM invocations ORDER BY timestamp DESC LIMIT 1"
+            )
+            try:
+                session_id = cursor.fetchone()[0]
+            except TypeError:
+                return Session(session_id="", invocations=[])
+
+        cursor.execute(
+            """
+            SELECT
+                invocations.invocation_id,
+                messages,
+                message_kwargs,
+                responses.content AS response
+            FROM invocations
+            LEFT JOIN responses ON invocations.invocation_id = responses.invocation_id
+            WHERE invocations.session_id = ?
+            ORDER BY invocations.timestamp ASC
+            """,
+            (session_id,),
+        )
+        invocation_data = cursor.fetchall()
+
+        input_id = -1
+        invocations = []
+        prev_user_content = None
+        for invocation_id, invocation_messages, message_kwargs, response in invocation_data:
+            messages = []
+            user_content = prev_user_content or ""  # fall back if this invocation has no user message
+            for message in json.loads(invocation_messages):
+                if message["role"] == "user":
+                    user_content = message["content"]
+                    if prev_user_content != user_content:
+                        prev_user_content = user_content
+                        input_id += 1
+                messages.append(Message(role=message["role"], content=message["content"]))
+
+            invocations.append(
+                Invocation(
+                    prompt=user_content,
+                    invocation_id=invocation_id,
+                    input_id=input_id,
+                    messages=messages,
+                    kwargs=json.loads(message_kwargs),
+                    response=response,
+                )
+            )
+
+        return Session(session_id=session_id, invocations=invocations)
+
+    def get_all_sessions(self) -> dict[str, Session]:
+        """
+        Retrieve the invocations of messages from all sessions.
+
+        Returns:
+            A dictionary mapping each session_id to its corresponding Session object.
+        """
+        all_sessions = {}
+        for session_id in self.get_session_ids():
+            all_sessions[session_id] = self.get_session(session_id)
+
+        return all_sessions
+
+    def get_session_ids(self) -> list[str]:
+        """Return the IDs of all stored sessions."""
+        cursor = self.conn.cursor()
+        cursor.execute("SELECT DISTINCT session_id FROM invocations")
+        return [row[0] for row in cursor.fetchall()]
+
+    def unpatch(self) -> None:
+        """Revert the client's create method to the original and close the database connection."""
+        if self._original_create is not None:
+            self._client.chat.completions.create = self._original_create
+        elif self._original_create_response is not None:
+            self._client.chat.completions.create = self._original_create_response
+        self.conn.close()
+
+    def __del__(self) -> None:
+        """Revert the client's create method and close the database connection when the object is deleted."""
+        self.unpatch()
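A short sketch (not part of the diff) of reading history back out of the store, assuming the OpenAIInterceptor defined below and an existing messages.db:

    interceptor = OpenAIInterceptor(db_path="messages.db")
    session = interceptor.get_session()  # most recent session by default
    for invocation in session.invocations:
        print(invocation.input_id, invocation.prompt, "->", invocation.response)
    all_sessions = interceptor.get_all_sessions()  # {session_id: Session}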
+ """ + self._client = client + self._original_create_response = client.chat.completions.create + + async def stream_response(*args: Any, **kwargs: Any): + content = "" + response = await self._original_create_response(*args, **kwargs) + async for chunk in response: + yield chunk + if hasattr(chunk, "choices") and chunk.choices: + delta = chunk.choices[0].delta + if delta.content is not None: + content += delta.content + + # last chunk contains all the content + if not content and isinstance(chunk, BaseModel): + content = self._dump_response_model(chunk) + + if content: + self.store_response(content) + else: + print("Could not intercept a response") + + async def non_stream_response(*args: Any, **kwargs: Any): + content = "" + response = await self._original_create_response(*args, **kwargs) + if hasattr(response, "choices"): + content = response.choices[0].message.content + + if not content and isinstance(response, BaseModel): + content = self._dump_response_model(response) + + if content: + self.store_response(content) + else: + print("Could not intercept a response.") + + return response + + @wraps(client.chat.completions.create) + async def patched_create_response(*args: Any, **kwargs: Any) -> Any: + stream = kwargs.get("stream", False) + if stream: + return stream_response(*args, **kwargs) + else: + return await non_stream_response(*args, **kwargs) + + client.chat.completions.create = patched_create_response diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py index 3040094e..8da21a3e 100644 --- a/lumen/ai/llm.py +++ b/lumen/ai/llm.py @@ -3,6 +3,7 @@ import os from functools import partial +from types import SimpleNamespace import instructor import panel as pn @@ -12,6 +13,8 @@ from instructor.patch import Mode, patch from pydantic import BaseModel +from .interceptor import Interceptor + class Llm(param.Parameterized): @@ -21,6 +24,8 @@ class Llm(param.Parameterized): use_logfire = param.Boolean(default=False) + interceptor = param.ClassSelector(default=None, class_=Interceptor) + # Allows defining a dictionary of default models. 
diff --git a/lumen/ai/llm.py b/lumen/ai/llm.py
index 3040094e..8da21a3e 100644
--- a/lumen/ai/llm.py
+++ b/lumen/ai/llm.py
@@ -3,6 +3,7 @@
 import os
 
 from functools import partial
+from types import SimpleNamespace
 
 import instructor
 import panel as pn
@@ -12,6 +13,8 @@
 from instructor.patch import Mode, patch
 from pydantic import BaseModel
 
+from .interceptor import Interceptor
+
 
 class Llm(param.Parameterized):
@@ -21,6 +24,8 @@ class Llm(param.Parameterized):
 
     use_logfire = param.Boolean(default=False)
 
+    interceptor = param.ClassSelector(default=None, class_=Interceptor)
+
     # Allows defining a dictionary of default models.
     model_kwargs = param.Dict(default={})
@@ -28,6 +33,9 @@ class Llm(param.Parameterized):
 
     __abstract = True
 
+    def __init__(self, **params):
+        super().__init__(**params)
+
     def _get_model_kwargs(self, model_key):
         if model_key in self.model_kwargs:
             model_kwargs = self.model_kwargs.get(model_key)
@@ -219,11 +227,18 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
         if self.organization:
             model_kwargs["organization"] = self.organization
         llm = openai.AsyncOpenAI(**model_kwargs)
+
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
+
         if response_model:
-            client = instructor.from_openai(llm)
-            client_callable = partial(client.chat.completions.create, model=model)
-        else:
-            client_callable = partial(llm.chat.completions.create, model=model)
+            llm = patch(llm)
+
+        if self.interceptor:
+            # must be called after the instructor patch so the final response is intercepted
+            self.interceptor.patch_client_response(llm)
+
+        client_callable = partial(llm.chat.completions.create, model=model)
 
         if self.use_logfire:
             import logfire
@@ -259,11 +274,18 @@ def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
         if self.azure_endpoint:
             model_kwargs["azure_endpoint"] = self.azure_endpoint
         llm = openai.AsyncAzureOpenAI(**model_kwargs)
+
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
+
         if response_model:
-            client = instructor.from_openai(llm)
-            client_callable = partial(client.chat.completions.create, model=model)
-        else:
-            client_callable = partial(llm.chat.completions.create, model=model)
+            llm = patch(llm)
+
+        if self.interceptor:
+            # must be called after the instructor patch so the final response is intercepted
+            self.interceptor.patch_client_response(llm)
+
+        client_callable = partial(llm.chat.completions.create, model=model)
 
         return client_callable
@@ -301,27 +323,26 @@ def _client_kwargs(self):
     def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
         from mistralai import Mistral
 
-        async def llm_chat_non_stream_async(*args, **kwargs):
-            response = await llm.chat.complete_async(*args, **kwargs)
-            return response.choices[0].message.content
-
         model_kwargs = self._get_model_kwargs(model_key)
+        model_kwargs["api_key"] = self.api_key
         model = model_kwargs.pop("model")
+        llm = Mistral(**model_kwargs)
+
+        stream = kwargs.get("stream", False)
+        llm.chat.completions = SimpleNamespace(create=None)  # mimic the OpenAI client layout for simplicity
+        llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
+
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
-        llm = Mistral(api_key=self.api_key)
         if response_model:
-            # can't use from_mistral due to new mistral API
-            # https://github.com/jxnl/instructor/issues/969
-            return patch(
-                create=partial(llm.chat.complete_async, model=model),
-                mode=self.mode,
-            )
+            llm = patch(llm)
 
-        stream = kwargs.get("stream", False)
-        if stream:
-            return partial(llm.chat.stream_async, model=model)
-        else:
-            return partial(llm_chat_non_stream_async, model=model)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
+
+        client_callable = partial(llm.chat.completions.create, model=model)
+        return client_callable
 
     @classmethod
     def _get_delta(cls, chunk):
@@ -355,7 +376,6 @@ async def invoke(
         )
 
 
-
 class AzureMistralAI(MistralAI):
 
     api_key = param.String(default=os.getenv("AZURE_API_KEY"))
@@ -374,20 +394,26 @@ async def llm_chat_non_stream_async(*args, **kwargs):
             return response.choices[0].message.content
 
         model_kwargs = self._get_model_kwargs(model_key)
+        model_kwargs["azure_api_key"] = self.api_key
+        model_kwargs["azure_endpoint"] = self.azure_endpoint
         model = model_kwargs.pop("model")
+        llm = MistralAzure(**model_kwargs)
+
+        stream = kwargs.get("stream", False)
+        llm.chat.completions = SimpleNamespace(create=None)  # mimic the OpenAI client layout for simplicity
+        llm.chat.completions.create = llm.chat.stream_async if stream else llm.chat.complete_async
+
+        if self.interceptor:
+            self.interceptor.patch_client(llm, mode="store_inputs")
 
-        llm = MistralAzure(azure_api_key=self.api_key, azure_endpoint=self.azure_endpoint)
         if response_model:
-            return patch(
-                create=partial(llm.chat.complete_async, model=model),
-                mode=self.mode,
-            )
+            llm = patch(llm)
 
-        stream = kwargs.get("stream", False)
-        if stream:
-            return partial(llm.chat.stream_async, model=model)
-        else:
-            return partial(llm_chat_non_stream_async, model=model)
+        if self.interceptor:
+            self.interceptor.patch_client_response(llm)
+
+        client_callable = partial(llm.chat.completions.create, model=model)
+        return client_callable
 
 
 class AnthropicAI(Llm):
@@ -410,12 +436,15 @@ def _client_kwargs(self):
         return {"temperature": self.temperature, "max_tokens": 1024}
 
     def get_client(self, model_key: str, response_model: BaseModel | None = None, **kwargs):
+        if self.interceptor:
+            raise NotImplementedError("Interceptors are not supported for AnthropicAI.")
+
         from anthropic import AsyncAnthropic
 
         model_kwargs = self._get_model_kwargs(model_key)
         model = model_kwargs.pop("model")
 
-        llm = AsyncAnthropic(api_key=self.api_key)
+        llm = AsyncAnthropic(api_key=self.api_key, **model_kwargs)
 
         if response_model:
             client = instructor.from_anthropic(llm)
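Putting the two files together, enabling interception on an Llm wrapper is expected to look roughly like this sketch. It assumes an OpenAI-backed Llm subclass named OpenAI in lumen/ai/llm.py and a model_key of "default"; both are assumptions, neither is shown in this diff:

    from lumen.ai.interceptor import OpenAIInterceptor
    from lumen.ai.llm import OpenAI

    llm = OpenAI(interceptor=OpenAIInterceptor(db_path="messages.db"))
    # get_client() patches the underlying client: inputs are stored before the
    # instructor patch runs, and responses are stored after it.
    client_callable = llm.get_client("default")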