From 4d0f8f1e7d7c95e2c6b03c6fad41cde9e9a042b6 Mon Sep 17 00:00:00 2001 From: Yuki Watanabe Date: Sun, 2 Jun 2024 00:12:14 +0900 Subject: [PATCH] impr: add middlewares --- redis/schemas.py | 9 ++++--- redis/session.py | 22 ++++++++++++---- server/main.py | 67 ++++++++++++++++++++++++++++++++++++++---------- 3 files changed, 75 insertions(+), 23 deletions(-) diff --git a/redis/schemas.py b/redis/schemas.py index fd4fdf8..529ac74 100644 --- a/redis/schemas.py +++ b/redis/schemas.py @@ -1,5 +1,6 @@ -class TemplateSchema: - attr: int +from dataclasses import dataclass - def __init__(self, attr: int = None): - self.attr = attr + +@dataclass +class SessionSchema: + data: any = None diff --git a/redis/session.py b/redis/session.py index 2b3e4bf..09e9827 100644 --- a/redis/session.py +++ b/redis/session.py @@ -1,7 +1,8 @@ -import secrets import os +import secrets from .redis import RedisCrud +from .schemas import SessionSchema class SessionCrud: @@ -26,18 +27,29 @@ def _set(self, key, value): def _delete(self, key): return self.crud.delete(key) - def create(self, response, data): - session_id = secrets.token_urlsafe(16) + def create(self, response, data: SessionSchema) -> SessionSchema | None: + session_id = secrets.token_urlsafe(64) self._set(session_id, data) response.set_cookie(key=self.cookie_name, value=session_id) + return self._get(session_id) - def get(self, request): + def get(self, request) -> SessionSchema | None: sess_id = request.cookies.get(self.cookie_name) if sess_id is None: return None return self._get(sess_id) - def delete(self, request, response): + def update(self, request, response, data: SessionSchema) -> SessionSchema | None: + sess_id = request.cookies.get(self.cookie_name) + + # create new session if not exists + if sess_id is None: + return self.create(response, data) + + self._set(sess_id, data) + return data + + def delete(self, request, response) -> None: sess_id = request.cookies.get(self.cookie_name) if sess_id is None: return None diff --git a/server/main.py b/server/main.py index 7ddb471..8639cc3 100644 --- a/server/main.py +++ b/server/main.py @@ -1,9 +1,11 @@ +import json import logging -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles +from fastapi import FastAPI, Request, Response, Depends +from db.package.session import get_db +from redis_crud import SessionCrud +from redis_crud.schemas import SessionSchema from routers.v1 import main as v1_router from util.env import get_env @@ -25,24 +27,61 @@ app_params["docs_url"] = None app_params["redoc_url"] = None app_params["openapi_url"] = None +else: + app_params["docs_url"] = "/api/docs" + app_params["redoc_url"] = "/api/redoc" + app_params["openapi_url"] = "/api/openapi.json" # create app app = FastAPI(**app_params) -origins = [ - "http://example.com", -] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# origins = [ +# "http://example.com", +# ] +# +# app.add_middleware( +# CORSMiddleware, +# allow_origins=origins, +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) + +@app.middleware("http") +def error_response(request: Request, call_next): + response = Response(json.dumps({"status": "internal server error"}), status_code=500) + try: + response = call_next(request) + except Exception as e: + logger.error(e) + return response + + +@app.middleware("http") +async def db_opener(request: Request, call_next, db=Depends(get_db)): + request.state.db = db + return await call_next(request) + + +@app.middleware("http") +async def session_creator(request: Request, call_next): + with SessionCrud() as session_crud: + req_session_data = session_crud.get(request) + if req_session_data is None: + req_session_data = SessionSchema() + request.state.session = req_session_data + + response = await call_next(request) + + with SessionCrud() as session_crud: + res_session_data = request.state.session + session_crud.update(request, response, res_session_data) + return response + # mount static folder -app.mount("/static", StaticFiles(directory="/app/static"), name="static") +# app.mount("/static", StaticFiles(directory="/app/static"), name="static") # add routers app.include_router(