Skip to content

Commit

Permalink
Merge pull request #9 from btambara/add-security-endpoints
Browse files Browse the repository at this point in the history
Add security endpoints
  • Loading branch information
btambara authored Apr 18, 2024
2 parents be18b0b + b12f1cd commit e93458c
Show file tree
Hide file tree
Showing 23 changed files with 475 additions and 9 deletions.
3 changes: 3 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ POSTGRES_USER=postgres
POSTGRES_PASSWORD=example

ENVIRONMENT=DEVELOPMENT

SECRET_KEY=6820ede1abd21c37a7199e7f61e2d12a00bada25ae858d90ee536765551793c7
ACCESS_TOKEN_EXPIRE_MINUTES=60
23 changes: 23 additions & 0 deletions api/alembic/versions/fa7218c49f34_create_user_and_token_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Create user and token tables
Revision ID: fa7218c49f34
Revises: e5fa7cab21b2
Create Date: 2024-04-13 18:03:07.208993
"""

from typing import Sequence, Union

# revision identifiers, used by Alembic.
revision: str = "fa7218c49f34"
down_revision: Union[str, None] = "e5fa7cab21b2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
pass


def downgrade() -> None:
pass
4 changes: 4 additions & 0 deletions api/src/api/api_v1/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
players_endpoints,
stats_endpoints,
)
from user import user_endpoints

player_router = APIRouter()
player_router.include_router(
Expand All @@ -17,3 +18,6 @@
pitches_endpoints.router, prefix="/player/pitches", tags=["pitches"]
)
player_router.include_router(celery_endpoints.router, prefix="/celery", tags=["celery"])
player_router.include_router(
user_endpoints.router, prefix="/authenticate", tags=["Authentication"]
)
Empty file added api/src/auth_token/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions api/src/auth_token/token_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Token(BaseModel): # type: ignore[misc]
access_token: str
token_type: str


class TokenData(BaseModel): # type: ignore[misc]
username: str
4 changes: 4 additions & 0 deletions api/src/db/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Annotated

from auth_token import token_model
from fastapi import Depends
from player.models import player
from settings.config import Settings, get_settings
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from user import user_model


def get_session_local(settings: Annotated[Settings, Depends(get_settings)]) -> Session:
Expand All @@ -23,5 +25,7 @@ def get_session_local(settings: Annotated[Settings, Depends(get_settings)]) -> S
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

player.Base.metadata.create_all(bind=engine)
token_model.Base.metadata.create_all(bind=engine)
user_model.Base.metadata.create_all(bind=engine)

return SessionLocal()
7 changes: 6 additions & 1 deletion api/src/player/routers/pitches_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Dict, List, Optional, Union
from typing import Annotated, Dict, List, Optional, Union

from api.deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from player.crud import pitches_crud
from player.models.player import Pitches, PitchType
from player.schemas import pitches_schemas
from security.helpers import get_current_user
from sqlalchemy.orm import Session

router = APIRouter()
Expand All @@ -13,6 +14,7 @@
@router.post("/{mlb_id}", response_model=pitches_schemas.Pitches) # type: ignore[misc]
async def create_pitches(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
mlb_id: int,
pitches: pitches_schemas.PitchesCreate,
Expand Down Expand Up @@ -42,6 +44,7 @@ async def read_pitches(
@router.put("/{id}", response_model=pitches_schemas.Pitches) # type: ignore[misc]
async def update_pitches(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
pitches_in: pitches_schemas.PitchesUpdate,
Expand All @@ -59,6 +62,7 @@ async def update_pitches(
@router.delete("/{id}", response_model=pitches_schemas.Pitches) # type: ignore[misc]
async def delete_pitches(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
) -> Optional[Pitches]:
Expand Down Expand Up @@ -92,6 +96,7 @@ async def read_all_pitches_by_mlb_id(
@router.post("/pitch_type/{id}", response_model=pitches_schemas.PitchType) # type: ignore[misc]
async def create_pitch_type(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
pitch_type: pitches_schemas.PitchTypeCreate,
Expand Down
6 changes: 5 additions & 1 deletion api/src/player/routers/players_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Optional
from typing import Annotated, List, Optional

from api.deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from player.crud import player_crud
from player.models.player import Player
from player.schemas import player_schemas
from security.helpers import get_current_user
from sqlalchemy.orm import Session

router = APIRouter()
Expand All @@ -13,6 +14,7 @@
@router.post("/", response_model=player_schemas.Player) # type: ignore[misc]
async def create_player(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
player: player_schemas.PlayerCreate,
) -> Player:
Expand Down Expand Up @@ -88,6 +90,7 @@ async def read_player_by_primary_number(
@router.put("/{id}", response_model=player_schemas.Player) # type: ignore[misc]
async def update_player(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
player_in: player_schemas.PlayerUpdate,
Expand All @@ -105,6 +108,7 @@ async def update_player(
@router.delete("/{id}", response_model=player_schemas.Player) # type: ignore[misc]
async def delete_player(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
) -> Optional[Player]:
Expand Down
6 changes: 5 additions & 1 deletion api/src/player/routers/stats_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Optional
from typing import Annotated, List, Optional

from api.deps import get_db
from fastapi import APIRouter, Depends, HTTPException
from player.crud import stats_crud
from player.models.player import Stats
from player.schemas import stats_schemas
from security.helpers import get_current_user
from sqlalchemy.orm import Session

router = APIRouter()
Expand All @@ -13,6 +14,7 @@
@router.post("/{mlb_id}", response_model=stats_schemas.Stats) # type: ignore[misc]
async def create_stats(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
mlb_id: int,
stats: stats_schemas.StatsCreate,
Expand Down Expand Up @@ -42,6 +44,7 @@ async def read_stats(
@router.put("/{id}", response_model=stats_schemas.Stats) # type: ignore[misc]
async def update_stats(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
stats_in: stats_schemas.StatsUpdate,
Expand All @@ -59,6 +62,7 @@ async def update_stats(
@router.delete("/{id}", response_model=stats_schemas.Stats) # type: ignore[misc]
async def delete_stats(
*,
token: Annotated[str, Depends(get_current_user)],
db: Session = Depends(get_db),
id: int,
) -> Optional[Stats]:
Expand Down
Empty file added api/src/security/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions api/src/security/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated, Dict, Literal

import bcrypt
from api.deps import get_db
from auth_token.token_model import TokenData
from fastapi import Depends, HTTPException, status
from jose import JWTError, jwt
from jose.constants import ALGORITHMS
from security.oauth2 import oauth2_scheme
from settings.config import Settings, get_settings
from sqlalchemy.orm import Session
from user.user_crud import get_user_by_email
from user.user_model import User


def verify_password(plain_password: str, hashed_password: str) -> bool:
return bcrypt.checkpw(bytes(plain_password, "utf-8"), bytes(hashed_password, "utf-8")) # type: ignore[no-any-return]


def authenticate_user(db: Session, email: str, password: str) -> User | Literal[False]:
user = get_user_by_email(db, email)
if not user:
return False
if not verify_password(password, user.password):
return False
return user


def create_access_token(
data: Dict[str, str | datetime],
secret_key: str,
algorithm: str = ALGORITHMS.HS256,
expires_delta: timedelta | None = None,
) -> str:
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt # type: ignore[no-any-return]


ret_key: str


async def get_current_user(
settings: Annotated[Settings, Depends(get_settings)],
token: Annotated[str, Depends(oauth2_scheme)],
db: Session = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.secret_key, algorithms=[settings.algorithm]
)
username: str = payload.get("sub")
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except JWTError:
raise credentials_exception
user = get_user_by_email(db, token_data.username)
if user is None:
raise credentials_exception
return user
3 changes: 3 additions & 0 deletions api/src/security/oauth2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from fastapi.security import OAuth2PasswordBearer

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/authenticate/token")
5 changes: 5 additions & 0 deletions api/src/settings/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import lru_cache

from jose.constants import ALGORITHMS
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

Expand All @@ -12,6 +13,10 @@ class Settings(BaseSettings): # type:ignore[misc]
postgres_password: str = Field()
environment: str = Field()

secret_key: str = Field()
access_token_expire_minutes: int = Field()
algorithm: str = Field(ALGORITHMS.HS256)


@lru_cache
def get_settings() -> Settings:
Expand Down
45 changes: 45 additions & 0 deletions api/src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,58 @@
from typing import Any, Generator

import pytest
from auth_token.token_model import Token
from fastapi.testclient import TestClient
from httpx import Headers
from main import app
from settings.config import get_settings, get_settings_override
from user.user_model import User

test_user_info = {"email": "test@testing.com", "password": "password"}


@pytest.fixture # type: ignore[misc]
def test_client() -> Generator[TestClient, Any, None]:
app.dependency_overrides[get_settings] = get_settings_override
client = TestClient(app)
yield client


@pytest.fixture # type: ignore[misc]
def get_test_user(test_client: TestClient) -> User:
response = test_client.get("/api/v1/authenticate/email/" + test_user_info["email"])

if response.status_code == 404:
response = test_client.post(
"/api/v1/authenticate",
json={
"email": test_user_info["email"],
"password": test_user_info["password"],
},
)

response_json = response.json()

return User(email=response_json["email"], password=response_json["password"])


@pytest.fixture # type: ignore[misc]
def get_test_user_token(get_test_user: User, test_client: TestClient) -> Token:
headers = Headers({"Content-Type": "application/x-www-form-urlencoded"})
data = {
"grant_type": "",
"username": get_test_user.email,
"password": test_user_info["password"],
"scope": "",
"client_id": "",
"client_secret": "",
}
response = test_client.post(
"/api/v1/authenticate/token", data=data, headers=headers
)
response_json = response.json()

return Token(
access_token=response_json["access_token"],
token_type=response_json["token_type"],
)
Loading

0 comments on commit e93458c

Please sign in to comment.