From ea7d5785577767dcae673aafd1df3e6d0de7d240 Mon Sep 17 00:00:00 2001
From: Andrew Huang
Date: Tue, 24 Sep 2024 10:32:58 -0700
Subject: [PATCH] Add embeddings into assistant

---
 lumen/ai/agents.py      | 20 ++++++++++----------
 lumen/ai/assistant.py   |  7 +++++++
 lumen/ai/embeddings.py  | 75 +++++++++++++++++++++++++++++++---------
 lumen/ai/interceptor.db | Bin 0 -> 28672 bytes
 4 files changed, 75 insertions(+), 27 deletions(-)
 create mode 100644 lumen/ai/interceptor.db

diff --git a/lumen/ai/agents.py b/lumen/ai/agents.py
index de7cc3f4..4996006a 100644
--- a/lumen/ai/agents.py
+++ b/lumen/ai/agents.py
@@ -115,10 +115,16 @@ async def _system_prompt_with_context(
         self, messages: list | str, context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
+        table_name = memory.get("current_table")
         if self.embeddings:
-            context = self.embeddings.query(messages)
+            # TODO: refactor this so it joins messages in a more robust way
+            text = "\n".join([message["content"] for message in messages])
+            # TODO: refactor this so it's not subsetting by index; rows look like
+            # [(0, 'The creator of this dataset is named Andrew HH', 0.7491879463195801, 'windturbines.parquet')]
+            result = self.embeddings.query(text, table_name=table_name)[0][1]
+            context += "\n" + result
         if context:
-            system_prompt += f"\n### CONTEXT: {context}"
+            system_prompt = f"{system_prompt}\n### CONTEXT: {context}".strip()
         return system_prompt
 
     async def _get_closest_tables(self, messages: list | str, tables: list[str], n: int = 3) -> list[str]:
@@ -283,10 +289,7 @@ async def _system_prompt_with_context(
             f"\nHere's a summary of the dataset the user just asked about:\n```\n{memory['current_data']}\n```"
         )
 
-        system_prompt = self.system_prompt
-        if context:
-            system_prompt += f"\n### CONTEXT: {context}"
-        return system_prompt
+        return await super()._system_prompt_with_context(messages, context=context)
 
 
 class ChatDetailsAgent(ChatAgent):
@@ -313,7 +316,6 @@ class ChatDetailsAgent(ChatAgent):
     async def _system_prompt_with_context(
         self, messages: list | str, context: str = ""
     ) -> str:
-        system_prompt = self.system_prompt
         topic = (await self.llm.invoke(
             messages,
             system="What is the topic of the table?",
@@ -329,9 +331,7 @@ async def _system_prompt_with_context(
             columns = list(current_data.columns)
             context += f"\nHere are the columns of the table: {columns}"
 
-        if context:
-            system_prompt += f"\n### CONTEXT: {context}"
-        return system_prompt
+        return await super()._system_prompt_with_context(messages, context=context)
 
 
 class LumenBaseAgent(Agent):
diff --git a/lumen/ai/assistant.py b/lumen/ai/assistant.py
index 501a3c6d..c99c6f48 100644
--- a/lumen/ai/assistant.py
+++ b/lumen/ai/assistant.py
@@ -22,6 +22,7 @@
     Agent, AnalysisAgent, ChatAgent, SQLAgent,
 )
 from .config import DEMO_MESSAGES, GETTING_STARTED_SUGGESTIONS
+from .embeddings import Embeddings
 from .export import export_notebook
 from .llm import Llama, Llm
 from .logs import ChatLogs
@@ -37,6 +38,8 @@ class Assistant(Viewer):
 
     agents = param.List(default=[ChatAgent])
 
+    embeddings = param.ClassSelector(class_=Embeddings)
+
     llm = param.ClassSelector(class_=Llm, default=Llama())
 
     interface = param.ClassSelector(class_=ChatInterface)
@@ -54,6 +57,7 @@ class Assistant(Viewer):
     def __init__(
         self,
        llm: Llm | None = None,
+        embeddings: Embeddings | None = None,
         interface: ChatInterface | None = None,
         agents: list[Agent | type[Agent]] | None = None,
         logs_filename: str = "",
@@ -111,11 +115,14 @@ def download_notebook():
         interface.post_hook = on_message
 
         llm = llm or self.llm
+        embeddings = embeddings or self.embeddings
         instantiated = []
         self._analyses = []
         for agent in agents or self.agents:
             if not isinstance(agent, Agent):
                 kwargs = {"llm": llm} if agent.llm is None else {}
+                if embeddings:
+                    kwargs["embeddings"] = embeddings
                 agent = agent(interface=interface, **kwargs)
             if agent.llm is None:
                 agent.llm = llm
diff --git a/lumen/ai/embeddings.py b/lumen/ai/embeddings.py
index 871d4951..dd28a2b7 100644
--- a/lumen/ai/embeddings.py
+++ b/lumen/ai/embeddings.py
@@ -21,23 +21,50 @@ def setup_database(self):
             CREATE TABLE document_data (
                 id INTEGER,
                 text VARCHAR,
-                embedding FLOAT[1536]
+                embedding FLOAT[1536],
+                table_name VARCHAR
             );
             CREATE INDEX embedding_index ON document_data USING HNSW (embedding) WITH (metric = 'cosine');
             """
         )
 
-    def add_directory(self, data_dir: Path, file_type: str = "json"):
+    @classmethod
+    def from_directory(
+        cls,
+        data_dir: Path,
+        file_type: str = "json",
+        database_path: str = ":memory:",
+        table_name: str = "default",
+    ):
+        embeddings = cls(database_path)
         for i, path in enumerate(data_dir.glob(f"**/*.{file_type}")):
             text = path.read_text()
-            embedding = self.get_embedding(text)
-            self.connection.execute(
+            embedding = embeddings.get_embedding(text)
+            embeddings.connection.execute(
                 """
-                INSERT INTO document_data (id, text, embedding)
-                VALUES (?, ?, ?);
+                INSERT INTO document_data (id, text, embedding, table_name)
+                VALUES (?, ?, ?, ?);
                 """,
-                [i, text, embedding],
+                [i, text, embedding, table_name],
             )
+        return embeddings
+
+    @classmethod
+    def from_dict(cls, data: dict, database_path: str = ":memory:"):
+        embeddings = cls(database_path)
+        global_id = 0
+        for table_name, texts in data.items():
+            for text in texts:
+                embedding = embeddings.get_embedding(text)
+                embeddings.connection.execute(
+                    """
+                    INSERT INTO document_data (id, text, embedding, table_name)
+                    VALUES (?, ?, ?, ?);
+                    """,
+                    [global_id, text, embedding, table_name],
+                )
+                global_id += 1
+        return embeddings
 
     def get_embedding(self, text: str) -> list:
         raise NotImplementedError
@@ -58,17 +85,31 @@ def get_combined_embedding(self, text: str) -> list:
         combined_embedding = [sum(x) / len(x) for x in zip(*embeddings)]
         return combined_embedding
 
-    def query(self, query_text: str, top_k: int = 10) -> list:
+    def query(self, query_text: str, top_k: int = 1, table_name: str | None = None) -> list:
         query_embedding = self.get_combined_embedding(query_text)
-        result = self.connection.execute(
-            """
-            SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity
-            FROM document_data
-            ORDER BY similarity DESC
-            LIMIT ?;
-            """,
-            [query_embedding, top_k],
-        ).fetchall()
+
+        if table_name:
+            result = self.connection.execute(
+                """
+                SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity, table_name
+                FROM document_data
+                WHERE table_name = ?
+                ORDER BY similarity DESC
+                LIMIT ?;
+                """,
+                [query_embedding, table_name, top_k],
+            ).fetchall()
+        else:
+            result = self.connection.execute(
+                """
+                SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity, table_name
+                FROM document_data
+                ORDER BY similarity DESC
+                LIMIT ?;
+                """,
+                [query_embedding, top_k],
+            ).fetchall()
+
         return result
 
     def close(self):
diff --git a/lumen/ai/interceptor.db b/lumen/ai/interceptor.db
new file mode 100644
index 0000000000000000000000000000000000000000..d5bce5a49de8883daba2d8bc7ad50da2313a4727
GIT binary patch
literal 28672
[base85 binary payload truncated in the original; contents elided]
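
Reviewer note, not part of the patch: get_embedding stays abstract in this
diff, so exercising it locally needs a concrete subclass. Below is a minimal
sketch, assuming OpenAI's embeddings endpoint; the OpenAIEmbeddings name is
hypothetical, and it relies only on what the diff shows, namely that
Embeddings.__init__ takes a database_path and sets up the DuckDB table.

# Hypothetical subclass for local testing; not included in this patch.
from openai import OpenAI

from lumen.ai.embeddings import Embeddings


class OpenAIEmbeddings(Embeddings):

    def __init__(self, database_path: str = ":memory:"):
        super().__init__(database_path)
        self.client = OpenAI()  # reads OPENAI_API_KEY from the environment

    def get_embedding(self, text: str) -> list:
        # text-embedding-3-small returns 1536-dimensional vectors, matching
        # the FLOAT[1536] column created in setup_database().
        response = self.client.embeddings.create(
            model="text-embedding-3-small",
            input=text.replace("\n", " "),
        )
        return response.data[0].embedding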
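
And an end-to-end usage sketch under the same assumptions, exercising the new
from_dict and query(table_name=...) APIs from this diff. The dict keys become
the table_name column, so they must match memory["current_table"] for the
per-table lookup in _system_prompt_with_context to find anything:

from lumen.ai.assistant import Assistant

# Seed one document per table; the keys land in the table_name column.
embeddings = OpenAIEmbeddings.from_dict(
    {"windturbines.parquet": ["The creator of this dataset is named Andrew HH"]}
)

# Rows come back as (id, text, similarity, table_name) tuples, restricted
# to a single table when table_name is passed.
rows = embeddings.query("Who created this dataset?", table_name="windturbines.parquet")

# Every agent the Assistant instantiates receives the shared store and
# injects the top match into its system prompt under ### CONTEXT.
assistant = Assistant(embeddings=embeddings)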