Skip to content

Commit

Permalink
add embeddings into assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 committed Sep 24, 2024
1 parent ca5aeff commit ea7d578
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 27 deletions.
20 changes: 10 additions & 10 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
async def _system_prompt_with_context(
    self, messages: list | str, context: str = ""
) -> str:
    """Return the system prompt, optionally augmented with embeddings context.

    When an embeddings store is configured, the best-matching document for
    the joined message contents (scoped to the current table, if any) is
    appended to ``context``; any non-empty context is then appended to the
    base system prompt under a ``### CONTEXT:`` heading.
    """
    system_prompt = self.system_prompt
    # NOTE(review): `memory` is a module-level store defined elsewhere in
    # this file — "current_table" is presumably set by the SQL agent; verify.
    table_name = memory.get("current_table")
    if self.embeddings:
        # TODO: refactor this so it joins messages in a more robust way
        text = "\n".join(message["content"] for message in messages)
        # TODO: refactor this so it's not subsetting by index
        # query() rows look like (id, text, similarity, table_name),
        # e.g. [(0, 'The creator ... Andrew HH', 0.749..., 'windturbines.parquet')];
        # [0][1] extracts the best match's text.
        result = self.embeddings.query(text, table_name=table_name)[0][1]
        context += "\n" + result
    if context:
        # BUG FIX: the original used `+=` with an f-string that already
        # embedded `system_prompt`, duplicating the whole prompt. Assign
        # instead of appending.
        system_prompt = f"{system_prompt}\n### CONTEXT: {context}".strip()
    return system_prompt

async def _get_closest_tables(self, messages: list | str, tables: list[str], n: int = 3) -> list[str]:
Expand Down Expand Up @@ -283,10 +289,7 @@ async def _system_prompt_with_context(
f"\nHere's a summary of the dataset the user just asked about:\n```\n{memory['current_data']}\n```"
)

system_prompt = self.system_prompt
if context:
system_prompt += f"\n### CONTEXT: {context}"
return system_prompt
return await super()._system_prompt_with_context(messages, context=context)


class ChatDetailsAgent(ChatAgent):
Expand All @@ -313,7 +316,6 @@ class ChatDetailsAgent(ChatAgent):
async def _system_prompt_with_context(
self, messages: list | str, context: str = ""
) -> str:
system_prompt = self.system_prompt
topic = (await self.llm.invoke(
messages,
system="What is the topic of the table?",
Expand All @@ -329,9 +331,7 @@ async def _system_prompt_with_context(
columns = list(current_data.columns)
context += f"\nHere are the columns of the table: {columns}"

if context:
system_prompt += f"\n### CONTEXT: {context}"
return system_prompt
return await super()._system_prompt_with_context(messages, context=context)


class LumenBaseAgent(Agent):
Expand Down
8 changes: 8 additions & 0 deletions lumen/ai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Agent, AnalysisAgent, ChatAgent, SQLAgent,
)
from .config import DEMO_MESSAGES, GETTING_STARTED_SUGGESTIONS
from .embeddings import Embeddings
from .export import export_notebook
from .llm import Llama, Llm
from .logs import ChatLogs
Expand All @@ -37,6 +38,8 @@ class Assistant(Viewer):

agents = param.List(default=[ChatAgent])

embeddings = param.ClassSelector(class_=Embeddings)

llm = param.ClassSelector(class_=Llm, default=Llama())

interface = param.ClassSelector(class_=ChatInterface)
Expand All @@ -54,6 +57,7 @@ class Assistant(Viewer):
def __init__(
self,
llm: Llm | None = None,
embeddings: Embeddings | None = None,
interface: ChatInterface | None = None,
agents: list[Agent | type[Agent]] | None = None,
logs_filename: str = "",
Expand Down Expand Up @@ -111,11 +115,15 @@ def download_notebook():
interface.post_hook = on_message

llm = llm or self.llm
embeddings = embeddings or self.embeddings
instantiated = []
self._analyses = []
for agent in agents or self.agents:
if not isinstance(agent, Agent):
kwargs = {"llm": llm} if agent.llm is None else {}
if embeddings:
print(f"embeddings for {agent}")
kwargs["embeddings"] = embeddings
agent = agent(interface=interface, **kwargs)
if agent.llm is None:
agent.llm = llm
Expand Down
76 changes: 59 additions & 17 deletions lumen/ai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,50 @@ def setup_database(self):
CREATE TABLE document_data (
id INTEGER,
text VARCHAR,
embedding FLOAT[1536]
embedding FLOAT[1536],
table_name VARCHAR
);
CREATE INDEX embedding_index ON document_data USING HNSW (embedding) WITH (metric = 'cosine');
"""
)

def add_directory(self, data_dir: Path, file_type: str = "json"):
@classmethod
def from_directory(
    cls,
    data_dir: Path,
    file_type: str = "json",
    database_path: str = ":memory:",
    table_name: str = "default",
):
    """Build an embeddings store from every ``*.file_type`` file under *data_dir*.

    Recursively globs *data_dir*, embeds each file's full text, and inserts
    one row per file into ``document_data`` tagged with *table_name*.
    Returns the populated ``Embeddings`` instance (duckdb-backed, in-memory
    by default).
    """
    embeddings = cls(database_path)
    # Enumerate gives each document a stable, per-call unique integer id.
    for i, path in enumerate(data_dir.glob(f"**/*.{file_type}")):
        text = path.read_text()
        embedding = embeddings.get_embedding(text)
        embeddings.connection.execute(
            """
            INSERT INTO document_data (id, text, embedding, table_name)
            VALUES (?, ?, ?, ?);
            """,
            [i, text, embedding, table_name],
        )
    return embeddings

@classmethod
def from_dict(cls, data: dict, database_path: str = ":memory:"):
    """Build an embeddings store from a ``{table_name: [text, ...]}`` mapping.

    Every text in every table's list becomes one ``document_data`` row,
    with ids assigned sequentially across all tables. Returns the
    populated store.
    """
    store = cls(database_path)
    # Flatten the mapping into (table, text) pairs so a single enumerate
    # replaces a manually-incremented global counter.
    pairs = (
        (table, text)
        for table, texts in data.items()
        for text in texts
    )
    for doc_id, (table, text) in enumerate(pairs):
        vector = store.get_embedding(text)
        store.connection.execute(
            """
            INSERT INTO document_data (id, text, embedding, table_name)
            VALUES (?, ?, ?, ?);
            """,
            [doc_id, text, vector, table],
        )
    return store

def get_embedding(self, text: str) -> list:
    """Return the embedding vector for *text*.

    Abstract hook: concrete subclasses must override this with a real
    embedding backend.
    """
    raise NotImplementedError
Expand All @@ -58,17 +85,32 @@ def get_combined_embedding(self, text: str) -> list:
combined_embedding = [sum(x) / len(x) for x in zip(*embeddings)]
return combined_embedding

def query(self, query_text: str, top_k: int = 1, table_name: str | None = None) -> list:
    """Return the *top_k* documents most similar to *query_text*.

    Each result row is ``(id, text, similarity, table_name)``, ordered by
    descending cosine similarity of the stored embedding against the
    combined embedding of *query_text*. When *table_name* is given,
    results are restricted to that table's documents.
    """
    # Removed leftover debug print of the query text.
    query_embedding = self.get_combined_embedding(query_text)
    # Build one statement instead of duplicating the whole query for the
    # filtered/unfiltered cases; params line up with the placeholders.
    where_clause = "WHERE table_name = ?" if table_name else ""
    sql = f"""
        SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity, table_name
        FROM document_data
        {where_clause}
        ORDER BY similarity DESC
        LIMIT ?;
    """
    params = [query_embedding, table_name, top_k] if table_name else [query_embedding, top_k]
    return self.connection.execute(sql, params).fetchall()

def close(self):
Expand Down
Binary file added lumen/ai/interceptor.db
Binary file not shown.

0 comments on commit ea7d578

Please sign in to comment.