Skip to content

Commit

Permalink
fix: updated the auth settings
Browse files Browse the repository at this point in the history
  • Loading branch information
broomva committed May 13, 2024
1 parent e91d8ea commit cfbf3fc
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 33 deletions.
52 changes: 22 additions & 30 deletions arcan/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
from fastapi import Depends, FastAPI, Form, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
# %%
from fastapi.security import (HTTPAuthorizationCredentials, HTTPBearer,
OAuth2PasswordBearer, OAuth2PasswordRequestForm)
from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
from langserve import add_routes
from langserve.pydantic_v1 import BaseModel, Field
Expand All @@ -20,30 +22,16 @@
from arcan.api.datamodel import get_db, get_db_context
from arcan.api.datamodel.chat_history import ChatHistory
from arcan.api.datamodel.conversation import Conversation
from arcan.api.datamodel.user import (
ACCESS_TOKEN_EXPIRE_MINUTES,
TokenModel,
User,
UserInDB,
UserModel,
UserRepository,
UserService,
oauth2_scheme,
pwd_context,
)
from arcan.api.datamodel.user import (ACCESS_TOKEN_EXPIRE_MINUTES, TokenModel,
User, UserInDB, UserModel,
UserRepository, UserService,
oauth2_scheme, pwd_context)
from arcan.api.session import ArcanSession, run_agent
from arcan.spells.vector_search import (
get_per_user_retriever,
per_req_config_modifier,
pgVectorStore,
)

# %%
# from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

# from arcan.api.session.auth import requires_auth
from arcan.spells.vector_search import (get_per_user_retriever,
per_req_config_modifier, pgVectorStore)

# auth_scheme = HTTPBearer()
auth_scheme = HTTPBearer()

load_dotenv()

Expand Down Expand Up @@ -72,13 +60,6 @@ async def index():
return {"message": "Arcan is Running!"}


# @requires_auth
@app.get("/api/chat")
async def chat(user_id: str, query: str, db: Session = Depends(get_db)):
arcan_session = ArcanSession(db)
response = run_agent(session=arcan_session, user_id=user_id, query=query)
return {"response": response}


# %%

Expand Down Expand Up @@ -109,12 +90,14 @@ class Output(BaseModel):
add_routes(
app,
LLM(provider="ChatOpenAI").llm,
per_req_config_modifier=per_req_config_modifier,
path="/openai",
)

add_routes(
app,
LLM(provider="ChatGroq").llm,
per_req_config_modifier=per_req_config_modifier,
path="/groq",
)

Expand Down Expand Up @@ -185,7 +168,16 @@ async def read_users_me(
enabled_endpoints=["invoke"],
)

# %%
#%%


# @requires_auth
@app.get("/api/chat")
async def chat(user_id: str, query: str, current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)], db: Session = Depends(get_db)):
arcan_session = ArcanSession(db)
response = run_agent(session=arcan_session, user_id=current_user, query=query)
return {"response": response}


if __name__ == "__main__":
import uvicorn
Expand Down
2 changes: 1 addition & 1 deletion arcan/api/datamodel/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def create_access_token(self, data: dict, expires_delta: timedelta | None = None
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

async def get_current_user(self, token: Annotated[str, Depends(oauth2_scheme)]):
async def get_current_user(self, token: Annotated[str, Depends(oauth2_scheme)]) -> str:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
Expand Down
4 changes: 2 additions & 2 deletions arcan/api/session/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def requires_auth(func):
@wraps(func)
def wrapper(*args, token: HTTPAuthorizationCredentials = security, **kwargs):
if token.credentials != os.environ["AUTH_TOKEN"]:
if token.credentials != os.environ["ARCAN_AUTH_TOKEN"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
Expand All @@ -25,7 +25,7 @@ def wrapper(*args, token: HTTPAuthorizationCredentials = security, **kwargs):
def aio_requires_auth(func):
@wraps(func)
async def wrapper(*args, token: HTTPAuthorizationCredentials = None, **kwargs):
if token is None or token.credentials != os.environ["AUTH_TOKEN"]:
if token is None or token.credentials != os.environ["ARCAN_AUTH_TOKEN"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
Expand Down

0 comments on commit cfbf3fc

Please sign in to comment.