Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .github/workflows/lint-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ name: Lint & Test
on:
push:
branches:
- '**'
- "**"
pull_request:
branches:
- '**'

- "**"

jobs:
lint-and-test:
Expand All @@ -29,10 +28,11 @@ jobs:
# needed because the postgres container does not provide a healthcheck
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5


steps:
- name: Checkout repository
uses: actions/checkout@v2
with:
submodules: recursive

- name: Set up Python3.8
uses: actions/setup-python@v2
Expand All @@ -58,13 +58,13 @@ jobs:
# ::error file={filename},line={line},col={col}::{message}
- name: Run flake8
run: "flake8 \
--format='::error file=%(path)s,line=%(row)d,col=%(col)d::\
[flake8] %(code)s: %(text)s'"
--format='::error file=%(path)s,line=%(row)d,col=%(col)d::\
[flake8] %(code)s: %(text)s'"

- name: Run pytest
run: |
pytest
env:
TEST_DB_URI: postgresql://postgres:postgres@localhost:5432/api
TEST_POSTGRES_URI: postgresql://postgres:postgres@localhost:5432/api
# This isn't the real SECRET_KEY but the one used for testing
SECRET_KEY: nqk8umrpc4f968_2%jz_%r-r2o@v4!21#%)h&-s_7qm150=o@6
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "api/models"]
path = api/models
url = https://github.com/Tech-With-Tim/models
3 changes: 1 addition & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ WORKDIR /app

RUN apt-get update && apt-get install gcc -y

COPY Pipfile Pipfile.lock ./

RUN pip install pipenv
COPY Pipfile Pipfile.lock ./
RUN pipenv install --deploy --system

ADD . /app
Expand Down
4 changes: 3 additions & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ verify_ssl = true
[dev-packages]
black = "*"
flake8 = "*"
requests = "*"
pre-commit = "*"
pytest-asyncio = "*"
httpx = "*"
pytest = "*"
pytest-mock = "*"

[packages]
pyjwt = "*"
Expand Down
344 changes: 154 additions & 190 deletions Pipfile.lock

Large diffs are not rendered by default.

61 changes: 35 additions & 26 deletions api/app.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,54 @@
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException
from aiohttp import ClientSession

from utils.response import JSONResponse
from api import versions
import logging
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ordering isn't logical at all, ordering by length doesn't make sense to me. Order by first party, third party and local packages and in each of them put imports then from ... imports and each of these categories order alphabetically or just use isort.


import logging

log = logging.getLogger()

app = FastAPI()
app.router.prefix = "/api"
app.router.default_response_class = JSONResponse

class API(FastAPI):
"""FastAPI subclass to implement more API like handling."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

async def handle_http_exception(self, error: HTTPException):
"""
Returns errors as JSON instead of default HTML
Uses custom error handler if one exists.
"""
origins = ["*"] # TODO: change origins later
app.add_middleware(
CORSMiddleware,
allow_methods=["*"],
allow_headers=["*"],
allow_origins=origins,
expose_headers=["Location"],
)
app.include_router(versions.v1.router)

handler = self._find_exception_handler(error=error)

if handler is not None:
return await handler(error)
@app.on_event("startup")
async def on_startup():
"""Creates a ClientSession to be used app-wide."""
from api import http_session

headers = error.get_headers()
headers["Content-Type"] = "application/json"
if http_session.session is None or http_session.session.closed:
http_session.session = ClientSession()
log.info("Set http_session.")

return JSONResponse(
headers=headers,
status_code=error.status_code,
content={"error": error.name, "message": error.description},
)

@app.on_event("shutdown")
async def on_shutdown():
"""Closes the app-wide ClientSession"""
from api import http_session

app = API()
app.router.default_response_class = JSONResponse
if http_session.session is not None and not http_session.session.closed:
await http_session.session.close()

app.include_router(versions.v1.router)

app.add_exception_handler(HTTPException, app.handle_http_exception)
@app.exception_handler(RequestValidationError)
async def validation_handler(request, err: RequestValidationError):
return JSONResponse(
status_code=422, content={"error": "Invalid data", "data": err.errors()}
)


@app.exception_handler(500)
Expand Down
6 changes: 6 additions & 0 deletions api/http_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from aiohttp import ClientSession
from typing import Optional

session: Optional[ClientSession] = None

__all__ = (session,)
1 change: 1 addition & 0 deletions api/models
Submodule models added at 50531e
5 changes: 0 additions & 5 deletions api/models/__init__.py

This file was deleted.

4 changes: 4 additions & 0 deletions api/versions/v1/routers/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .routes import router


__all__ = (router,)
58 changes: 58 additions & 0 deletions api/versions/v1/routers/auth/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import config
import typing

from urllib.parse import quote_plus

from api.http_session import session

DISCORD_ENDPOINT = "https://discord.com/api"
SCOPES = ["identify"]


async def exchange_code(
*, code: str, scope: str, redirect_uri: str, grant_type: str = "authorization_code"
) -> typing.Tuple[dict, int]:
"""Exchange discord oauth code for access and refresh tokens."""
async with session.post(
"%s/v6/oauth2/token" % DISCORD_ENDPOINT,
data=dict(
code=code,
scope=scope,
grant_type=grant_type,
redirect_uri=redirect_uri,
client_id=config.discord_client_id(),
client_secret=config.discord_client_secret(),
),
headers={"Content-Type": "application/x-www-form-urlencoded"},
) as response:
return await response.json(), response.status


async def get_user(access_token: str) -> dict:
"""Coroutine to fetch User data from discord using the users `access_token`"""
async with session.get(
"%s/v6/users/@me" % DISCORD_ENDPOINT,
headers={"Authorization": "Bearer %s" % access_token},
) as response:
return await response.json()


def format_scopes(scopes: typing.List[str]) -> str:
"""Format a list of scopes."""
return " ".join(scopes)


def get_redirect(callback: str, scopes: typing.List[str]) -> str:
"""Generates the correct oauth link depending on our provided arguments."""
return (
"{BASE}/oauth2/authorize?response_type=code"
"&client_id={client_id}"
"&scope={scopes}"
"&redirect_uri={redirect_uri}"
"&prompt=consent"
).format(
BASE=DISCORD_ENDPOINT,
scopes=format_scopes(scopes),
redirect_uri=quote_plus(callback),
client_id=config.discord_client_id(),
)
12 changes: 12 additions & 0 deletions api/versions/v1/routers/auth/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from datetime import datetime
from pydantic import BaseModel, HttpUrl


class CallbackResponse(BaseModel):
token: str
exp: datetime


class CallbackBody(BaseModel):
code: str
callback: HttpUrl
117 changes: 117 additions & 0 deletions api/versions/v1/routers/auth/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import jwt
import utils
import config

from pydantic import HttpUrl
from fastapi import APIRouter, Request
from datetime import datetime, timedelta
from fastapi.responses import RedirectResponse

from api.models import User, Token
from .models import CallbackBody, CallbackResponse
from .helpers import (
SCOPES,
get_user,
get_redirect,
exchange_code,
format_scopes,
)

router = APIRouter(prefix="/auth")


@router.get(
"/discord/redirect",
tags=["auth"],
status_code=307,
)
async def redirect_to_discord_oauth_portal(request: Request, callback: HttpUrl = None):
"""Redirect user to correct oauth link depending on specified domain and requested scopes."""
callback = callback or (str(request.base_url) + "v1/auth/discord/callback")

return RedirectResponse(
get_redirect(callback=callback, scopes=SCOPES), status_code=307
)


if config.debug():

@router.get(
"/discord/callback",
tags=["auth"],
response_model=CallbackResponse,
response_description="GET Discord OAuth Callback",
)
async def get_discord_oauth_callback(
request: Request, code: str, callback: HttpUrl = None
):
"""
Callback endpoint for finished discord authorization flow.
"""
callback = callback or (str(request.base_url) + "v1/auth/discord/callback")
return await post_discord_oauth_callback(code, callback)


@router.post(
"/discord/callback",
tags=["auth"],
response_model=CallbackResponse,
response_description="POST Discord OAuth Callback",
)
async def post_discord_oauth_callback(data: CallbackBody):
"""
Callback endpoint for finished discord authorization flow.
"""
access_data, status_code = await exchange_code(
code=data.code, scope=format_scopes(SCOPES), redirect_uri=data.callback
)

if access_data.get("error", False):
if status_code == 400:
return utils.JSONResponse(
{
"error": "Bad Request",
"message": "Discord returned 400 status.",
"data": access_data,
},
400,
)

if status_code < 200 or status_code >= 300:
return utils.JSONResponse(
{
"error": "Bad Gateway",
"message": "Discord returned non 2xx status code",
},
502,
)

expires_at = datetime.utcnow() + timedelta(seconds=access_data["expires_in"])
expires_at = expires_at.replace(microsecond=0)

user_data = await get_user(access_token=access_data["access_token"])
user_data["id"] = uid = int(user_data["id"])

user = await User.fetch(id=uid)

if user is None:
user = await User.create(
id=user_data["id"],
username=user_data["username"],
discriminator=user_data["discriminator"],
avatar=user_data["avatar"],
)

await Token(
user_id=user.id,
data=access_data,
expires_at=expires_at,
token=access_data["access_token"],
).update()

token = jwt.encode(
{"uid": user.id, "exp": expires_at, "iat": datetime.utcnow()},
key=config.secret_key(),
)

return {"token": token, "exp": expires_at}
3 changes: 3 additions & 0 deletions api/versions/v1/routers/router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from fastapi import APIRouter
from . import auth

router = APIRouter(prefix="/v1")

router.include_router(auth.router)
Loading