Skip to content

Commit d550e66

Browse files
committed
chore: formatted with black
1 parent 2802548 commit d550e66

File tree

10 files changed

+159
-102
lines changed

10 files changed

+159
-102
lines changed

arcan/ai/agents/__init__.py

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,60 @@
88
import pickle
99
import weakref
1010
from datetime import datetime
11+
1112
# Ensure necessary imports for ArcanAgent
1213
from tempfile import TemporaryDirectory
1314
from typing import Any, AsyncIterator, Dict, List, Optional, cast
1415

1516
from fastapi import Depends
1617
from fastapi.responses import StreamingResponse
17-
from langchain.agents import (AgentExecutor, AgentType,
18-
create_tool_calling_agent, initialize_agent,
19-
load_tools)
18+
from langchain.agents import (
19+
AgentExecutor,
20+
AgentType,
21+
create_tool_calling_agent,
22+
initialize_agent,
23+
load_tools,
24+
)
2025
from langchain.agents.agent_types import AgentType
21-
from langchain.agents.format_scratchpad.openai_tools import \
22-
format_to_openai_tool_messages
26+
from langchain.agents.format_scratchpad.openai_tools import (
27+
format_to_openai_tool_messages,
28+
)
2329
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
2430
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
2531
from langchain.embeddings.openai import OpenAIEmbeddings
2632
from langchain.memory import ConversationBufferMemory
2733
from langchain.pydantic_v1 import BaseModel
2834
from langchain.sql_database import SQLDatabase
29-
from langchain_community.agent_toolkits import (FileManagementToolkit,
30-
SQLDatabaseToolkit)
35+
from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit
3136
from langchain_core.callbacks import CallbackManagerForChainRun
32-
from langchain_core.load.serializable import (Serializable,
33-
SerializedConstructor,
34-
SerializedNotImplemented)
37+
from langchain_core.load.serializable import (
38+
Serializable,
39+
SerializedConstructor,
40+
SerializedNotImplemented,
41+
)
3542
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
3643
from langchain_core.prompts import ChatPromptTemplate
44+
3745
# from langchain_core.pydantic_v1 import BaseModel
38-
from langchain_core.runnables import (ConfigurableField, ConfigurableFieldSpec,
39-
Runnable, RunnableConfig,
40-
RunnablePassthrough,
41-
RunnableSerializable)
46+
from langchain_core.runnables import (
47+
ConfigurableField,
48+
ConfigurableFieldSpec,
49+
Runnable,
50+
RunnableConfig,
51+
RunnablePassthrough,
52+
RunnableSerializable,
53+
)
4254
from langchain_core.runnables.base import Runnable, RunnableBindingBase
43-
from langchain_core.runnables.utils import (AddableDict, AnyConfigurableField,
44-
ConfigurableField,
45-
ConfigurableFieldSpec, Input,
46-
Output, create_model,
47-
get_unique_config_specs)
55+
from langchain_core.runnables.utils import (
56+
AddableDict,
57+
AnyConfigurableField,
58+
ConfigurableField,
59+
ConfigurableFieldSpec,
60+
Input,
61+
Output,
62+
create_model,
63+
get_unique_config_specs,
64+
)
4865
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
4966
from pydantic import BaseModel, Field
5067
from sqlalchemy.dialects.postgresql import insert
@@ -56,8 +73,10 @@
5673
from arcan.ai.llm import LLM
5774
from arcan.ai.parser import ArcanOutputParser
5875
from arcan.ai.prompts import arcan_prompt, spells_agent_prompt
76+
5977
# from arcan.ai.router import semantic_layer
6078
from arcan.ai.tools import tools as spells
79+
6180
# from arcan.api.session import ArcanSession
6281
# from arcan.ai.agents import ArcanAgent
6382
from arcan.datamodel.chat_history import ChatHistory
@@ -73,7 +92,7 @@ class ArcanAgent(RunnableSerializable):
7392
chat_history: List = Field(default_factory=list)
7493
user_id: Optional[str] = None
7594
verbose: bool = False
76-
prompt: ChatPromptTemplate = spells_agent_prompt,
95+
prompt: ChatPromptTemplate = (spells_agent_prompt,)
7796
configs: List[ConfigurableFieldSpec] = Field(default_factory=list)
7897
llm_with_tools: LLM = Field(default_factory=lambda: LLM().llm)
7998
agent: Runnable = Field(default_factory=RunnablePassthrough)
@@ -82,22 +101,38 @@ class ArcanAgent(RunnableSerializable):
82101

83102
class Config:
84103
arbitrary_types_allowed = True
85-
extra = 'allow' # This allows additional fields not explicitly defined
86-
87-
def __init__(self, llm=None, tools: list = spells, prompt: ChatPromptTemplate = spells_agent_prompt,
88-
agent_type="arcan_spells_agent", chat_history: list = [],
89-
user_id: str = None, verbose: bool = False, configs: list = [],
90-
**kwargs):
91-
super().__init__(tools=tools, agent_type=agent_type, chat_history=chat_history,
92-
user_id=user_id, verbose=verbose, prompt=prompt, configs=configs, **kwargs)
93-
object.__setattr__(self, '_llm', llm or LLM().llm)
104+
extra = "allow" # This allows additional fields not explicitly defined
105+
106+
def __init__(
107+
self,
108+
llm=None,
109+
tools: list = spells,
110+
prompt: ChatPromptTemplate = spells_agent_prompt,
111+
agent_type="arcan_spells_agent",
112+
chat_history: list = [],
113+
user_id: str = None,
114+
verbose: bool = False,
115+
configs: list = [],
116+
**kwargs,
117+
):
118+
super().__init__(
119+
tools=tools,
120+
agent_type=agent_type,
121+
chat_history=chat_history,
122+
user_id=user_id,
123+
verbose=verbose,
124+
prompt=prompt,
125+
configs=configs,
126+
**kwargs,
127+
)
128+
object.__setattr__(self, "_llm", llm or LLM().llm)
94129
# Initialize other fields after the main Pydantic initialization
95130
self.session: ArcanSession = ArcanSession()
96131
self.bare_tools = load_tools(["llm-math"], llm=self.llm)
97132
self.agent_tools = self.tools + self.bare_tools
98133
self.llm_with_tools = self.llm.bind_tools(self.agent_tools)
99134
self.agent, self.runnable = self.get_or_create_agent(self.user_id)
100-
135+
101136
@property
102137
def llm(self):
103138
return self._llm
@@ -238,8 +273,12 @@ def invoke(
238273
]
239274
)
240275
try:
241-
self.session.store_message(user_id=self.user_id, body=user_content, response=response['output'])
242-
self.session.store_chat_history(user_id=self.user_id, agent_history=self.chat_history)
276+
self.session.store_message(
277+
user_id=self.user_id, body=user_content, response=response["output"]
278+
)
279+
self.session.store_chat_history(
280+
user_id=self.user_id, agent_history=self.chat_history
281+
)
243282
except SQLAlchemyError as e:
244283
self.session.database.rollback()
245284
print(f"Error storing conversation in database: {e}")

arcan/ai/agents/session.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#%%
1+
# %%
22

33
import ast
44
import os
@@ -59,7 +59,7 @@ def get_chat_history(self, user_id: str) -> list:
5959
with self._get_session() as db_session:
6060
history = (
6161
db_session.query(ChatHistory)
62-
# .options(joinedload(ChatHistory.history))
62+
# .options(joinedload(ChatHistory.history))
6363
.filter(ChatHistory.sender == user_id)
6464
.order_by(ChatHistory.updated_at.asc())
6565
.all()
@@ -87,19 +87,19 @@ def rollback(self):
8787
# self.database_uri = os.environ.get("SQLALCHEMY_URL")
8888
# self.agents: Dict[str, weakref.ref] = weakref.WeakValueDictionary()
8989

90-
# def store_message(self, user_id: str, body: str, response: str):
91-
# """
92-
# Stores a message in the database.
93-
94-
# :param user_id: The unique identifier for the user.
95-
# :param Body: The body of the message sent by the user.
96-
# :param response: The response generated by the system.
97-
# """
98-
# with self.database as db_session:
99-
# conversation = Conversation(sender=user_id, message=body, response=response)
100-
# db_session.add(conversation)
101-
# db_session.commit()
102-
# print(f"Conversation #{conversation.id} stored in database")
90+
# def store_message(self, user_id: str, body: str, response: str):
91+
# """
92+
# Stores a message in the database.
93+
94+
# :param user_id: The unique identifier for the user.
95+
# :param Body: The body of the message sent by the user.
96+
# :param response: The response generated by the system.
97+
# """
98+
# with self.database as db_session:
99+
# conversation = Conversation(sender=user_id, message=body, response=response)
100+
# db_session.add(conversation)
101+
# db_session.commit()
102+
# print(f"Conversation #{conversation.id} stored in database")
103103

104104
# def store_chat_history(self, user_id, agent_history):
105105
# """

arcan/ai/llm/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ class LLMFactory:
7070
os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192"),
7171
),
7272
),
73-
'ChatOllama' : lambda **kwargs: ChatOllama(
74-
model = kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")),
75-
)
73+
"ChatOllama": lambda **kwargs: ChatOllama(
74+
model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")),
75+
),
7676
}
7777

7878
@staticmethod

arcan/ai/router/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def get_response(self, query: str, user_id: str) -> str:
6868
return route_text, query
6969
else:
7070
print(f"No route found for query: {query}")
71-
return 'No Router Matched', query
71+
return "No Router Matched", query
7272

7373

7474
# Initialize RouteManager with an encoder

arcan/ai/runnables/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def get_runnable(self, runnable_name: str, cache: bool = True) -> RemoteRunnable
2424
class ArcanRunnables:
2525
def __init__(self, base_url: str = "http://localhost:8000/"):
2626
self.factory = RunnableFactory(base_url=base_url)
27-
27+
2828
def get_spells_runnable(self) -> AgentExecutor:
2929
return self.factory.get_runnable(runnable_name="spells")
3030

@@ -39,9 +39,9 @@ def get_ollama_runnable(self) -> AgentExecutor:
3939

4040
def get_auth_spells_runnable(self) -> AgentExecutor:
4141
return self.factory.get_runnable(runnable_name="auth_spells")
42-
42+
4343
def get_chain_with_history_runnable(self) -> AgentExecutor:
4444
return self.factory.get_runnable(runnable_name="chain_with_history")
4545

4646

47-
#%%
47+
# %%

0 commit comments

Comments
 (0)