diff --git a/arcan/ai/agents/casters/researcher.py b/arcan/ai/agents/casters/researcher.py index cbf94e6..a905033 100644 --- a/arcan/ai/agents/casters/researcher.py +++ b/arcan/ai/agents/casters/researcher.py @@ -27,14 +27,9 @@ def search(query): url = "https://google.serper.dev/search" - payload = json.dumps({ - "q": query - }) + payload = json.dumps({"q": query}) - headers = { - 'X-API-KEY': serper_api_key, - 'Content-Type': 'application/json' - } + headers = {"X-API-KEY": serper_api_key, "Content-Type": "application/json"} response = requests.request("POST", url, headers=headers, data=payload) @@ -51,14 +46,12 @@ def scrape_website(objective: str, url: str): print("Scraping website...") # Define the headers for the request headers = { - 'Cache-Control': 'no-cache', - 'Content-Type': 'application/json', + "Cache-Control": "no-cache", + "Content-Type": "application/json", } # Define the data to be sent in the request - data = { - "url": url - } + data = {"url": url} # Convert Python object to JSON string data_json = json.dumps(data) @@ -86,7 +79,8 @@ def summary(objective, content): llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k-0613") text_splitter = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n"], chunk_size=10000, chunk_overlap=500) + separators=["\n\n", "\n"], chunk_size=10000, chunk_overlap=500 + ) docs = text_splitter.create_documents([content]) map_prompt = """ Write a summary of the following text for {objective}: @@ -94,14 +88,15 @@ def summary(objective, content): SUMMARY: """ map_prompt_template = PromptTemplate( - template=map_prompt, input_variables=["text", "objective"]) + template=map_prompt, input_variables=["text", "objective"] + ) summary_chain = load_summarize_chain( llm=llm, - chain_type='map_reduce', + chain_type="map_reduce", map_prompt=map_prompt_template, combine_prompt=map_prompt_template, - verbose=True + verbose=True, ) output = summary_chain.run(input_documents=docs, objective=objective) @@ -111,8 +106,10 @@ def summary(objective, content): class ScrapeWebsiteInput(BaseModel): """Inputs for scrape_website""" + objective: str = Field( - description="The objective & task that users give to the agent") + description="The objective & task that users give to the agent" + ) url: str = Field(description="The url of the website to be scraped") @@ -133,7 +130,7 @@ def _arun(self, url: str): Tool( name="Search", func=search, - description="useful for when you need to answer questions about current events, data. You should ask targeted questions" + description="useful for when you need to answer questions about current events, data. You should ask targeted questions", ), ScrapeWebsiteTool(), ] @@ -158,7 +155,8 @@ def _arun(self, url: str): llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k-0613") memory = ConversationSummaryBufferMemory( - memory_key="memory", return_messages=True, llm=llm, max_token_limit=1000) + memory_key="memory", return_messages=True, llm=llm, max_token_limit=1000 +) agent = initialize_agent( tools, @@ -201,5 +199,5 @@ class Query(BaseModel): def researchAgent(query: Query): query = query.query content = agent({"input": query}) - actual_content = content['output'] - return actual_content \ No newline at end of file + actual_content = content["output"] + return actual_content diff --git a/arcan/ai/graphs/__init__.py b/arcan/ai/graphs/__init__.py index f7c0751..4db6978 100644 --- a/arcan/ai/graphs/__init__.py +++ b/arcan/ai/graphs/__init__.py @@ -31,7 +31,6 @@ def add(x: float, y: float) -> float: ) - import operator from typing import Annotated, Sequence, TypedDict @@ -81,7 +80,7 @@ def call_tools(state): graph = workflow.compile() -#%% +# %% from typing import TypedDict, Annotated, List, Union @@ -94,8 +93,10 @@ class AgentState(TypedDict): agent_out: Union[AgentAction, AgentFinish, None] intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add] + from langchain_core.tools import tool + @tool("search") def search_tool(query: str): """Searches for information on the topic of artificial intelligence (AI). @@ -104,16 +105,15 @@ def search_tool(query: str): # this is a "RAG" emulator return ehi_information + @tool("final_answer") -def final_answer_tool( - answer: str, - source: str -): +def final_answer_tool(answer: str, source: str): """Returns a natural language response to the user in `answer`, and a `source` which provides citations for where this information came from. """ return "" + import os from langchain.agents import create_openai_tools_agent from langchain import hub @@ -126,29 +126,28 @@ def final_answer_tool( prompt = hub.pull("hwchase17/openai-functions-agent") query_agent_runnable = create_openai_tools_agent( - llm=llm, - tools=[final_answer_tool, search_tool], - prompt=prompt + llm=llm, tools=[final_answer_tool, search_tool], prompt=prompt ) from langchain_core.agents import AgentFinish import json + def run_query_agent(state: list): print("> run_query_agent") agent_out = query_agent_runnable.invoke(state) return {"agent_out": agent_out} + def execute_search(state: list): print("> execute_search") action = state["agent_out"] tool_call = action[-1].message_log[-1].additional_kwargs["tool_calls"][-1] - out = search_tool.invoke( - json.loads(tool_call["function"]["arguments"]) - ) + out = search_tool.invoke(json.loads(tool_call["function"]["arguments"])) return {"intermediate_steps": [{"search": str(out)}]} + def router(state: list): print("> router") if isinstance(state["agent_out"], list): @@ -156,9 +155,11 @@ def router(state: list): else: return "error" + # finally, we will have a single LLM call that MUST use the final_answer structure final_answer_llm = llm.bind_tools([final_answer_tool], tool_choice="final_answer") + # this forced final_answer LLM call will be used to structure output from our # RAG endpoint def rag_final_answer(state: list): @@ -177,6 +178,7 @@ def rag_final_answer(state: list): function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"] return {"agent_out": function_call} + # we use the same forced final_answer LLM call to handle incorrectly formatted # output from our query_agent def handle_error(state: list): @@ -207,7 +209,4 @@ def handle_error(state: list): runnable = graph.compile() -out = runnable.invoke({ - "input": "what is AI?", - "chat_history": [] -}) \ No newline at end of file +out = runnable.invoke({"input": "what is AI?", "chat_history": []}) diff --git a/arcan/ai/interface/app.py b/arcan/ai/interface/app.py index 2eeecb7..3712cd5 100644 --- a/arcan/ai/interface/app.py +++ b/arcan/ai/interface/app.py @@ -1,5 +1,3 @@ - - from typing import Optional import chainlit as cl @@ -42,7 +40,9 @@ def auth_callback( def get_runnable(): from langserve import RemoteRunnable - spells_runnable = RemoteRunnable("https://api.arcanai.tech/spells/", headers={"arcanai_api_key": '1234'}) + spells_runnable = RemoteRunnable( + "https://api.arcanai.tech/spells/", headers={"arcanai_api_key": "1234"} + ) return spells_runnable @@ -52,23 +52,23 @@ def get_runnable(): # response - - @cl.on_message async def on_msg(msg: cl.Message): res = await get_runnable().ainvoke( - {"input": msg.content,}, - config={"configurable": {"user_id": "broomva"},} + { + "input": msg.content, + }, + config={ + "configurable": {"user_id": "broomva"}, + }, ) - await cl.Message(content=res['output']).send() - - + await cl.Message(content=res["output"]).send() # @cl.on_message # async def on_msg(msg: cl.Message): # msg = cl.Message(content="") - + # async for chunk in get_runnable().astream( # {"input": msg.content, "chat_history": []}, # # config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]), @@ -83,4 +83,4 @@ async def on_msg(msg: cl.Message): # res = await agent.ainvoke( # message.content # ) -# await cl.Message(content=res).send() \ No newline at end of file +# await cl.Message(content=res).send() diff --git a/arcan/api/__init__.py b/arcan/api/__init__.py index 5e1e7a6..27b300c 100644 --- a/arcan/api/__init__.py +++ b/arcan/api/__init__.py @@ -6,13 +6,17 @@ from typing import Annotated, Any, Callable, Dict, List, Optional, Union from dotenv import load_dotenv -from fastapi import (Depends, FastAPI, Form, Header, HTTPException, Request, - status) +from fastapi import Depends, FastAPI, Form, Header, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse + # %% -from fastapi.security import (HTTPAuthorizationCredentials, HTTPBearer, - OAuth2PasswordBearer, OAuth2PasswordRequestForm) +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, +) from langchain_community.chat_message_histories import FileChatMessageHistory from langchain_core import __version__ from langchain_core.chat_history import BaseChatMessageHistory @@ -33,9 +37,15 @@ from arcan.ai.llm import LLM from arcan.api.auth import fetch_session_from_header from arcan.datamodel.engine import session_scope # , session_scope_context -from arcan.datamodel.user import (ACCESS_TOKEN_EXPIRE_MINUTES, TokenModel, - UserModel, UserRepository, UserService, - oauth2_scheme, pwd_context) +from arcan.datamodel.user import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + TokenModel, + UserModel, + UserRepository, + UserService, + oauth2_scheme, + pwd_context, +) # from arcan.spells.vector_search import (get_per_user_retriever, # per_req_config_modifier, pgVectorStore) @@ -189,8 +199,6 @@ async def login_for_access_token( ) - - async def get_current_active_user_from_request( request: Request, session: Session = Depends(session_scope) ) -> UserModel: diff --git a/arcan/datamodel/user.py b/arcan/datamodel/user.py index 3fa8490..d503fd9 100644 --- a/arcan/datamodel/user.py +++ b/arcan/datamodel/user.py @@ -8,8 +8,7 @@ from jose import JWTError, jwt from passlib.context import CryptContext from pydantic import BaseModel -from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, - Text) +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text from sqlalchemy.orm import Session, relationship from arcan.datamodel.engine import Base, engine, session_scope