Skip to content

Commit

Permalink
fix: added user datamaodel and updated api
Browse files Browse the repository at this point in the history
  • Loading branch information
broomva committed May 13, 2024
1 parent 7cff114 commit e2de273
Show file tree
Hide file tree
Showing 14 changed files with 641 additions and 98 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ package_build:
package_list:
unzip -l dist/*.whl

server:
poetry run uvicorn arcan.api:app --port 8000 --host 0.0.0.0
serve:
poetry run uvicorn arcan.api:app --port 8000 --host 0.0.0.0 --reload
27 changes: 14 additions & 13 deletions arcan/ai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,20 @@
from tempfile import TemporaryDirectory

from fastapi.responses import StreamingResponse
from langchain.agents import AgentExecutor, create_tool_calling_agent, load_tools
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents import (AgentExecutor, create_tool_calling_agent,
load_tools)
from langchain.agents.format_scratchpad.openai_tools import \
format_to_openai_tool_messages
from langchain.agents.output_parsers.openai_tools import \
OpenAIToolsAgentOutputParser
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit
from langchain_community.agent_toolkits import (FileManagementToolkit,
SQLDatabaseToolkit)
from langchain_core.messages import AIMessage, HumanMessage

from arcan.ai.agents.helpers import AsyncIteratorCallbackHandler
from arcan.ai.llm import LLM
from arcan.ai.prompts import arcan_prompt, spells_agent_prompt

# from arcan.ai.router import semantic_layer
from arcan.ai.tools import tools as spells

Expand Down Expand Up @@ -45,14 +46,14 @@ class ArcanAgent:

def __init__(
self,
database: SQLDatabase,
llm: LLM = LLM().llm,
tools: list = spells,
hub_prompt: str = "broomva/arcan",
agent_type="arcan_spells_agent",
context: list = [], # represents the chat history, can be pulled from a db
user_id: str = None,
verbose: bool = False,
database: SQLDatabase = SQLDatabase.from_uri(os.environ.get("SQLALCHEMY_URL")),
):
self.llm: LLM = llm
self.tools: list = tools
Expand Down Expand Up @@ -223,26 +224,26 @@ class ArcanSpellsAgent(ArcanAgent):

def __init__(
self,
# database: SQLDatabase,
llm: LLM = LLM().llm,
tools: list = spells,
prompt: str = spells_agent_prompt,
agent_type="arcan_spells_agent",
context: list = [], # represents the chat history, can be pulled from a db
user_id: str = None,
verbose: bool = False,
database: SQLDatabase = SQLDatabase.from_uri(os.environ.get("SQLALCHEMY_URL")),
):
self.llm: LLM = llm
self.tools: list = tools
self.agent_type: str = agent_type
self.chat_history: list = context
self.user_id: str = user_id
self.verbose: bool = verbose
self.database = database
self.toolkit = SQLDatabaseToolkit(db=database, llm=self.llm)
self.context = self.toolkit.get_context()
# self.database = database
# self.toolkit = SQLDatabaseToolkit(db=database, llm=self.llm)
# self.context = self.toolkit.get_context()
# self.sql_tools = self.toolkit.get_tools()
self.prompt = prompt # arcan_prompt.partial(**self.context)
self.sql_tools = self.toolkit.get_tools()
self.working_directory = TemporaryDirectory()
self.file_system_tools = FileManagementToolkit(
root_dir=str(self.working_directory.name)
Expand Down
88 changes: 84 additions & 4 deletions arcan/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
# %%
from typing import Any, List, Union
from datetime import datetime, timedelta, timezone
from typing import Annotated, Any, Dict, List, Optional, Union

from dotenv import load_dotenv
from fastapi import Depends, FastAPI, Form, Request
from fastapi import Depends, FastAPI, Form, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from arcan.ai.agents import ArcanSpellsAgent
from arcan.ai.llm import LLM
from arcan.api.datamodels import get_db, get_db_context
from arcan.api.datamodel import get_db, get_db_context
from arcan.api.datamodel.chat_history import ChatHistory
from arcan.api.datamodel.conversation import Conversation
from arcan.api.datamodel.user import (ACCESS_TOKEN_EXPIRE_MINUTES, TokenModel,
User, UserInDB, UserModel,
UserRepository, UserService,
oauth2_scheme, pwd_context)
from arcan.api.session import ArcanSession, run_agent
from arcan.spells.vector_search import (get_per_user_retriever,
per_req_config_modifier, pgVectorStore)

# %%
# from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand All @@ -24,6 +37,7 @@

load_dotenv()


app = FastAPI()


Expand Down Expand Up @@ -73,7 +87,9 @@ class Output(BaseModel):

add_routes(
app=app,
runnable=ArcanSpellsAgent()
runnable=ArcanSpellsAgent(
# database=SQLDatabase.from_uri(os.environ.get("SQLALCHEMY_URL"))
)
.agent_executor.with_types(input_type=Input, output_type=Output)
.with_config({"run_name": "agent"}),
path="/spells_agent",
Expand All @@ -97,3 +113,67 @@ class Output(BaseModel):
LLM(provider="ChatTogetherAI").llm,
path="/together",
)


@app.post("/token")
async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], session: Session = Depends(get_db)
) -> TokenModel:
user_repo = UserRepository(session)
user_interface = UserService(user_repository=user_repo, pwd_context=pwd_context)
user = user_interface.authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = user_interface.create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires
)
return TokenModel(
id = 1,
access_token = access_token,
token_type = "bearer",
user_id = user.username,
user = user)


async def get_current_active_user_from_request(request: Request, session: Session = Depends(get_db)) -> UserModel:
"""Get the current active user from the request."""
user_repo = UserRepository(session)
user_interface = UserService(user_repository=user_repo, pwd_context=pwd_context)
token = await oauth2_scheme(request)
user = user_interface.get_current_user(token=token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
if user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return user

@app.get("/users/me/", response_model=UserModel)
async def read_users_me(
current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)],
):
return current_user

add_routes(
app,
get_per_user_retriever(vectorstore=pgVectorStore().get_vector_store()),
per_req_config_modifier=per_req_config_modifier,
enabled_endpoints=["invoke"],
)

#%%

if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="localhost", port=8000)

# %%
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@

from dotenv import load_dotenv
from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

load_dotenv()

# %%

engine = create_engine(
"postgresql://postgres.vkugdscspjbjgquljqle:ARCAN.broomva2024@aws-0-us-east-1.pooler.supabase.com:5432/postgres?client_encoding=utf8"
) # os.environ.get("SQLALCHEMY_URL"))

DATABASE_URL = str(os.environ.get("SQLALCHEMY_URL"))
print(DATABASE_URL)
assert DATABASE_URL is not None, "SQLALCHEMY_URL environment variable not found"

engine = create_engine(DATABASE_URL) # Oddly requires the hard coded string or else fails to connect
SessionLocal = sessionmaker(bind=engine)
Base = declarative_base()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from sqlalchemy import Column, DateTime, String, Text

from arcan.api.datamodels import Base, engine
from arcan.api.datamodel import Base, engine

Base.metadata.create_all(engine)


class ChatsHistory(Base):
class ChatHistory(Base):
"""
Represents the chat history for a sender.
Expand All @@ -17,7 +17,7 @@ class ChatsHistory(Base):
updated_at (datetime): The timestamp of when the chat history was last updated.
"""

__tablename__ = "chats_history"
__tablename__ = "chat_history"
sender = Column(String, primary_key=True, index=True)
history = Column(Text)
updated_at = Column(DateTime, default=datetime.utcnow)
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from sqlalchemy import Column, DateTime, Integer, String

from arcan.api.datamodels import Base, engine
from arcan.api.datamodel import Base, engine

Base.metadata.create_all(engine)


class Conversations(Base):
class Conversation(Base):
"""
Represents a conversation entity.
Expand All @@ -19,7 +19,7 @@ class Conversations(Base):
created_at (datetime): The timestamp of when the conversation was created.
"""

__tablename__ = "conversations"
__tablename__ = "conversation"
id = Column(Integer, primary_key=True, index=True)
sender = Column(String)
message = Column(String)
Expand Down
Loading

0 comments on commit e2de273

Please sign in to comment.