Skip to content

Commit

Permalink
fix: added ArcanSpellsAgent and api/serving methods
Browse files Browse the repository at this point in the history
  • Loading branch information
broomva committed May 12, 2024
1 parent 67d56d4 commit cc720b1
Show file tree
Hide file tree
Showing 10 changed files with 1,361 additions and 1,206 deletions.
233 changes: 108 additions & 125 deletions arcan/ai/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,25 @@
# %%
import ast
import asyncio
import os
import pickle
import weakref
from datetime import datetime
from tempfile import TemporaryDirectory
from typing import Dict

from fastapi.responses import StreamingResponse
from langchain.agents import AgentExecutor, load_tools
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents import (AgentExecutor, create_tool_calling_agent,
load_tools)
from langchain.agents.format_scratchpad.openai_tools import \
format_to_openai_tool_messages
from langchain.agents.output_parsers.openai_tools import \
OpenAIToolsAgentOutputParser
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit
from langchain_community.agent_toolkits import (FileManagementToolkit,
SQLDatabaseToolkit)
from langchain_core.messages import AIMessage, HumanMessage
from sqlalchemy.dialects.postgresql import insert

from arcan.ai.agents.helpers import AsyncIteratorCallbackHandler
from arcan.ai.llm import LLM
from arcan.ai.prompts import arcan_prompt
from arcan.ai.prompts import arcan_prompt, spells_agent_prompt
from arcan.ai.router import semantic_layer
from arcan.ai.tools import tools as spells
from arcan.api.datamodels.chat_history import ChatsHistory
from arcan.api.datamodels.conversation import Conversation


class ArcanAgent:
Expand Down Expand Up @@ -130,117 +124,7 @@ def get_response(self, user_content: str):
return response["output"]


class ArcanSession:
def __init__(self, session_factory):
"""
Initializes a new instance of the ArcanSession class.
:param session_factory: A callable that returns a new SQLAlchemy Session instance when called.
"""
self.session_factory = session_factory
self.agents: Dict[str, weakref.ref] = weakref.WeakValueDictionary()

def get_or_create_agent(
self, user_id: str, provided_agent: ArcanAgent = None
) -> ArcanAgent:
"""
Retrieves or creates a ArcanAgent for a given user_id.
:param user_id: The unique identifier for the user.
:return: An instance of ArcanAgent.
"""
if provided_agent is not None:
provided_agent.user_id = user_id
self.agents[user_id] = provided_agent
return provided_agent

agent = self.agents.get(user_id)
chat_history = []

# Obtain a new database session
try:
chat_history = self.get_chat_history(user_id)
except Exception as e:
print(f"Error getting chat history for {user_id}: {e}")

if agent is not None and chat_history:
print(f"Using existing agent {agent}")
elif agent is None and chat_history:
print(f"Using reloaded agent with history {chat_history}")
agent = ArcanAgent(
context=chat_history, user_id=user_id
) # Initialize with chat history
elif agent is None and not chat_history:
print("Using a new agent")
agent = ArcanAgent(user_id=user_id) # Initialize without chat history

self.agents[user_id] = agent
return agent

def store_message(self, user_id: str, body: str, response: str):
"""
Stores a message in the database.

:param user_id: The unique identifier for the user.
:param Body: The body of the message sent by the user.
:param response: The response generated by the system.
"""
with self.session_factory as db_session:
conversation = Conversation(sender=user_id, message=body, response=response)
db_session.add(conversation)
db_session.commit()
print(f"Conversation #{conversation.id} stored in database")

def store_chat_history(self, user_id, agent_history):
"""
Stores or updates the chat history for a user in the database.
:param user_id: The unique identifier for the user.
:param agent_history: The chat history to be stored.
"""
history = pickle.dumps(agent_history)
# Upsert statement
stmt = (
insert(ChatsHistory)
.values(
sender=user_id,
history=str(history),
updated_at=datetime.utcnow(), # Explicitly set updated_at on insert
)
.on_conflict_do_update(
index_elements=["sender"], # Specify the conflict target
set_={
"history": str(history), # Update the history field upon conflict
"updated_at": datetime.utcnow(), # Update the updated_at field upon conflict
},
)
)
# Execute the upsert
with self.session_factory as db:
db.execute(stmt)
db.commit()
print(f"Upsert chat history for user {user_id} with statement {stmt}")

def get_chat_history(self, user_id: str) -> list:
"""
Retrieves the chat history for a user from the database.
:param db_session: The SQLAlchemy Session instance.
:param user_id: The unique identifier for the user.
:return: A list representing the chat history.
"""
with self.session_factory as db_session:
history = (
db_session.query(ChatsHistory)
.filter(ChatsHistory.sender == user_id)
.order_by(ChatsHistory.updated_at.asc())
.all()
) or []
if not history:
return []
chat_history = history[0].history
loaded = pickle.loads(ast.literal_eval(chat_history))
return loaded


# %%
Expand Down Expand Up @@ -312,3 +196,102 @@ async def agent_chat(text: str, agent): # query: Query = Body(...),):
except Exception as e:
raise (e)
return StreamingResponse(gen, media_type="text/event-stream")




#%%
from langchain.agents import AgentExecutor, create_tool_calling_agent


class ArcanSpellsAgent(ArcanAgent):
"""
Represents a Arcan Agent that interacts with the user and provides responses using OpenAI tools.
Attributes:
llm (LLM): The Language Model Manager used by the agent.
tools (list): The list of tools used by the agent.
hub_prompt (str): The prompt for the OpenAI tools agent.
agent_type (str): The type of the agent.
chat_history (list): The chat history of the agent.
llm_with_tools: The Language Model Manager with the tools bound.
prompt: The chat prompt template for the agent.
agent: The agent pipeline.
agent_executor: The executor for the agent.
user_id: The unique identifier for the user.
verbose: A boolean indicating whether to print verbose output.
Methods:
get_response: Gets the response from the agent given user input.
"""

def __init__(
self,
llm: LLM = LLM().llm,
tools: list = spells,
prompt: str = spells_agent_prompt,
agent_type="arcan_spells_agent",
context: list = [], # represents the chat history, can be pulled from a db
user_id: str = None,
verbose: bool = False,
database: SQLDatabase = SQLDatabase.from_uri(os.environ.get("SQLALCHEMY_URL")),
):
self.llm: LLM = llm
self.tools: list = tools
self.agent_type: str = agent_type
self.chat_history: list = context
self.user_id: str = user_id
self.verbose: bool = verbose
self.database = database
self.toolkit = SQLDatabaseToolkit(db=self.database, llm=self.llm)
self.context = self.toolkit.get_context()
self.prompt = prompt # arcan_prompt.partial(**self.context)
self.sql_tools = self.toolkit.get_tools()
self.working_directory = TemporaryDirectory()
self.file_system_tools = FileManagementToolkit(
root_dir=str(self.working_directory.name)
).get_tools()
self.parser = OpenAIToolsAgentOutputParser()
self.bare_tools = load_tools(
[
"llm-math",
# "human",
# "wolfram-alpha"
],
llm=self.llm,
)
self.agent_tools = (
self.tools + self.bare_tools # + self.sql_tools + self.file_system_tools
)
self.llm_with_tools = self.llm.bind_tools(self.agent_tools)
# Construct the Tools agent
self.agent = create_tool_calling_agent(self.llm, self.agent_tools, self.prompt)
self.agent_executor = AgentExecutor(
agent=self.agent, tools=self.agent_tools, verbose=self.verbose
)

def get_response(self, user_content: str):
"""
Gets the response from the agent given user input.
Args:
user_content (str): The user input.
Returns:
str: The response from the agent.
"""
routed_content = semantic_layer(query=user_content, user_id=self.user_id)
response = self.agent_executor.invoke(
{"input": routed_content, "chat_history": self.chat_history}
)
self.chat_history.extend(
[
HumanMessage(content=user_content),
AIMessage(content=response["output"]),
]
)
return response["output"]

# %%
18 changes: 13 additions & 5 deletions arcan/ai/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
from typing import cast

from langchain_core.messages import AIMessage, SystemMessage
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
)
from langchain_core.prompts import (ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder)

ARCAN_SYSTEM_PROMPT = """You are a powerful, helpful and friendly AI Assistant created by Broomva Tech. Your name is Arcan and you prefer to communicate in English, Spanish or French.
You were created by Carlos D. Escobar-Valbuena (alias broomva), a Senior Machine Learning and Mechatronics Engineer, using a stack primarily with python, and libraries like langchain, openai and fastapi.
Expand Down Expand Up @@ -76,6 +74,16 @@

arcan_prompt = ChatPromptTemplate.from_messages(ARCAN_DEFAULT_PROMPT)


SPELLS_AGENT_DEFAULT_PROMPT = [
SystemMessage(content=cast(str, ARCAN_SYSTEM_PROMPT)),
MessagesPlaceholder(variable_name=MEMORY_KEY),
HumanMessagePromptTemplate.from_template("{input}"),
MessagesPlaceholder(variable_name=AGENT_SCRATCHPAD),
]

spells_agent_prompt = ChatPromptTemplate.from_messages(SPELLS_AGENT_DEFAULT_PROMPT)

# %%
# from langchain import hub
# hub.push("broomva/arcan", arcan_prompt, new_repo_description="Arcan AI Assistant Prompt")
Expand Down
35 changes: 17 additions & 18 deletions arcan/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
#%%
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Form, Request
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from arcan.ai.agents import ArcanSession
from arcan.api.datamodels import get_db
from arcan.api.datamodels import get_db, get_db_context
from arcan.api.session import ArcanSession, run_agent

#%%
# from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

# from arcan.api.session.auth import requires_auth
Expand All @@ -29,23 +31,20 @@ async def index():
return {"message": "Arcan is Running!"}





@app.get("/api/chat/{user_id}")
async def api_user_chat(user_id: str, query: str, db: Session = Depends(get_db)):
arcan_session = ArcanSession(db)
response = run_agent(session=arcan_session, user_id=user_id, query=query)
return {"response": response}

# @requires_auth
@app.get("/api/chat")
async def chat(user_id: str, query: str, db: Session = Depends(get_db)):
arcan_session = ArcanSession(db)
print(f"Sending the LangChain response to user: {user_id}")
agent = arcan_session.get_or_create_agent(user_id)
# Get the generated text from the LangChain agent
langchain_response = agent.get_response(user_content=query)
# Store the conversation in the database
try:
arcan_session.store_message(
user_id=user_id, body=query, response=langchain_response
)
arcan_session.store_chat_history(
user_id=user_id, agent_history=agent.chat_history
)
except SQLAlchemyError as e:
db.rollback()
print(f"Error storing conversation in database: {e}")
return {"response": langchain_response}
response = run_agent(session=arcan_session, user_id=user_id, query=query)
return {"response": response}

#%%
4 changes: 4 additions & 0 deletions arcan/api/datamodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

load_dotenv()

#%%

engine = create_engine(os.environ.get("SQLALCHEMY_URL"))
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()
Expand Down Expand Up @@ -38,3 +40,5 @@ def get_db_context():
yield db
finally:
db.close()

# %%
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Base.metadata.create_all(engine)


class Conversation(Base):
class Conversations(Base):
"""
Represents a conversation entity.
Expand Down
Loading

0 comments on commit cc720b1

Please sign in to comment.