Skip to content

Commit

Permalink
impr: add middlewares
Browse files Browse the repository at this point in the history
  • Loading branch information
ukwhatn committed Jun 1, 2024
1 parent acd83b8 commit 4d0f8f1
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
9 changes: 5 additions & 4 deletions redis/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
class TemplateSchema:
attr: int
from dataclasses import dataclass

def __init__(self, attr: int = None):
self.attr = attr

@dataclass
class SessionSchema:
data: any = None
22 changes: 17 additions & 5 deletions redis/session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import secrets
import os
import secrets

from .redis import RedisCrud
from .schemas import SessionSchema


class SessionCrud:
Expand All @@ -26,18 +27,29 @@ def _set(self, key, value):
def _delete(self, key):
return self.crud.delete(key)

def create(self, response, data):
session_id = secrets.token_urlsafe(16)
def create(self, response, data: SessionSchema) -> SessionSchema | None:
session_id = secrets.token_urlsafe(64)
self._set(session_id, data)
response.set_cookie(key=self.cookie_name, value=session_id)
return self._get(session_id)

def get(self, request):
def get(self, request) -> SessionSchema | None:
sess_id = request.cookies.get(self.cookie_name)
if sess_id is None:
return None
return self._get(sess_id)

def delete(self, request, response):
def update(self, request, response, data: SessionSchema) -> SessionSchema | None:
sess_id = request.cookies.get(self.cookie_name)

# create new session if not exists
if sess_id is None:
return self.create(response, data)

self._set(sess_id, data)
return data

def delete(self, request, response) -> None:
sess_id = request.cookies.get(self.cookie_name)
if sess_id is None:
return None
Expand Down
67 changes: 53 additions & 14 deletions server/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import logging

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, Request, Response, Depends

from db.package.session import get_db
from redis_crud import SessionCrud
from redis_crud.schemas import SessionSchema
from routers.v1 import main as v1_router
from util.env import get_env

Expand All @@ -25,24 +27,61 @@
app_params["docs_url"] = None
app_params["redoc_url"] = None
app_params["openapi_url"] = None
else:
app_params["docs_url"] = "/api/docs"
app_params["redoc_url"] = "/api/redoc"
app_params["openapi_url"] = "/api/openapi.json"

# create app
app = FastAPI(**app_params)

origins = [
"http://example.com",
]

app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# origins = [
# "http://example.com",
# ]
#
# app.add_middleware(
# CORSMiddleware,
# allow_origins=origins,
# allow_credentials=True,
# allow_methods=["*"],
# allow_headers=["*"],
# )

@app.middleware("http")
def error_response(request: Request, call_next):
response = Response(json.dumps({"status": "internal server error"}), status_code=500)
try:
response = call_next(request)
except Exception as e:
logger.error(e)
return response


@app.middleware("http")
async def db_opener(request: Request, call_next, db=Depends(get_db)):
request.state.db = db
return await call_next(request)


@app.middleware("http")
async def session_creator(request: Request, call_next):
with SessionCrud() as session_crud:
req_session_data = session_crud.get(request)
if req_session_data is None:
req_session_data = SessionSchema()
request.state.session = req_session_data

response = await call_next(request)

with SessionCrud() as session_crud:
res_session_data = request.state.session
session_crud.update(request, response, res_session_data)
return response


# mount static folder
app.mount("/static", StaticFiles(directory="/app/static"), name="static")
# app.mount("/static", StaticFiles(directory="/app/static"), name="static")

# add routers
app.include_router(
Expand Down

0 comments on commit 4d0f8f1

Please sign in to comment.