Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions app/api/middleware/cache_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from starlette.types import ASGIApp, Receive, Scope, Send
from fastapi import Request, Response
from app.services.cache_service import cache_service
from app.core.config import settings
import json
from structlog import get_logger

logger = get_logger()


class CacheMiddleware:
"""Middleware to cache GET responses, including headers."""

def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
# Only cache HTTP GET requests and if caching is enabled
if scope["type"] == "http" and scope["method"] == "GET" and settings.CACHE_ENABLED:
request = Request(scope, receive)
key = request.url.path + "?" + (request.url.query or "")
try:
cached = await cache_service.get(key)
except Exception as e:
logger.error("Redis failure on cache GET", error=str(e))
cached = None

if cached:
# Try to restore full response (body + headers + status)
try:
cached_obj = json.loads(cached)
response = Response(
content=cached_obj.get("body", ""),
status_code=cached_obj.get("status_code", 200),
headers=cached_obj.get("headers", {}),
)
logger.debug("Cache hit", key=key)
except Exception as e:
logger.error("Failed to parse cached response", error=str(e))
response = Response(content=cached, status_code=200)
await response(scope, receive, send)
return

# Capture and cache the response if not in cache
responder = _ResponseCatcher(self.app, key)
await responder(scope, receive, send)
return

# Non-GET or caching disabled: continue normally
await self.app(scope, receive, send)


class _ResponseCatcher:
"""Helper to capture response body, headers, and cache them."""

def __init__(self, app: ASGIApp, key: str):
self.app = app
self.key = key
self.body = b""
self.status_code = 200
self.headers = {}

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def send_wrapper(message):
if message["type"] == "http.response.start":
self.status_code = message["status"]
raw_headers = message.get("headers", [])
# Decode headers into dict[str, str]
self.headers = {k.decode(): v.decode() for k, v in raw_headers}
elif message["type"] == "http.response.body":
self.body += message.get("body", b"")
await send(message)

await self.app(scope, receive, send_wrapper)

# Cache successful GET responses, only if enabled and Redis is available
if self.status_code == 200 and settings.CACHE_ENABLED:
try:
payload = {
"body": self.body.decode(),
"headers": self.headers,
"status_code": self.status_code,
}
await cache_service.set(self.key, json.dumps(payload), settings.CACHE_TTL_DEFAULT)
logger.debug("Response cached", key=self.key, ttl=settings.CACHE_TTL_DEFAULT)
except Exception as e:
logger.error("Redis failure on cache SET", error=str(e))
37 changes: 37 additions & 0 deletions app/api/v1/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from fastapi import APIRouter, HTTPException
from app.services.cache_service import cache_service
from app.core.config import settings

router = APIRouter()

@router.get("/status")
async def cache_status():
"""
Get cache status (enabled flag and Redis connection health).
"""
enabled = settings.CACHE_ENABLED
try:
pong = await cache_service.redis.ping()
healthy = pong is True
except Exception as e:
healthy = False
return {"enabled": enabled, "healthy": healthy}

@router.post("/clear")
async def cache_clear():
"""
Clear the entire cache.
"""
if not settings.CACHE_ENABLED:
raise HTTPException(status_code=400, detail="Caching is disabled")
await cache_service.clear()
return {"status": "cleared"}

@router.get("/stats")
async def cache_stats():
"""
Retrieve cache hit/miss statistics.
"""
stats = cache_service.stats()
return stats
return stats
37 changes: 35 additions & 2 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ class Settings(BaseSettings):
default="https://graph.microsoft.com/v1.0",
description="Base URL for Microsoft Graph API",
)
# Redis Configuration
REDIS_URL: Optional[str] = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
REDIS_HOST: Optional[str] = Field(
default="localhost",
description="Redis host",
)
REDIS_PORT: Optional[int] = Field(
default=6379,
description="Redis port",
)
REDIS_DB: Optional[int] = Field(
default=0,
description="Redis database index",
)
REDIS_PASSWORD: Optional[str] = Field(
default=None,
description="Redis password",
)

# Cache Settings
CACHE_ENABLED: bool = Field(
default=True,
description="Enable or disable caching",
)
CACHE_TTL_DEFAULT: int = Field(
default=300,
description="Default TTL (seconds) for cache entries",
)
CACHE_KEY_PREFIX: str = Field(
default="autoaudit",
description="Prefix for all cache keys",
)

# Database (needed by health checks)
DATABASE_URL: Optional[str] = Field(
Expand Down Expand Up @@ -111,10 +146,8 @@ def _db_url_if_set(cls, v: Optional[str]) -> Optional[str]:

class Config:
"""Pydantic configuration."""

env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = True


settings = Settings()
49 changes: 36 additions & 13 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.api.v1 import health
from app.api.v1 import auth
from app.api.v1 import auth, cache, health
from app.api.v1.graph import router as graph_router
from app.api.middleware.cache_middleware import CacheMiddleware
from app.utils.logger import logger
from app.api.v1 import graph
from app.services.cache_service import cache_service


def create_app() -> FastAPI:
Expand All @@ -22,7 +24,16 @@ def create_app() -> FastAPI:
configure_middleware(app, settings)
configure_routing(app, settings)
configure_exception_handlers(app)


# Redis connection management
@app.on_event("startup")
async def startup_event():
await cache_service.init()

@app.on_event("shutdown")
async def shutdown_event():
await cache_service.close()
await cache_service.wait_closed()

return app

Expand All @@ -37,6 +48,9 @@ def configure_middleware(app: FastAPI, settings):
allow_headers=["*"],
)

# Caches GET responses; keep enabled for /graph/* and any other GETs
app.add_middleware(CacheMiddleware)

# Trusted host middleware
app.add_middleware(
TrustedHostMiddleware,
Expand All @@ -48,7 +62,9 @@ def configure_middleware(app: FastAPI, settings):


def configure_routing(app: FastAPI, settings):
# ----------------------------
# Authentication endpoints
# ----------------------------
app.include_router(
auth.router,
prefix=f"{settings.API_PREFIX}/auth",
Expand All @@ -64,12 +80,22 @@ def configure_routing(app: FastAPI, settings):

)

# Graph API endpoints

# Graph API endpoints
# Mounted under /api/v1/graph/*
app.include_router(
graph.router,
graph_router,
prefix=f"{settings.API_PREFIX}/graph",
tags=["Graph API"],
responses={404: {"description": "Not found & Unsuccessfull"}}, #need to change this later
)


# Cache endpoints
# /api/v1/cache/status, /api/v1/cache/clear, /api/v1/cache/stats
app.include_router(
cache.router,
prefix=f"{settings.API_PREFIX}/cache",
tags=["Cache"],
)


Expand Down Expand Up @@ -119,11 +145,8 @@ async def root():
@app.get("/health")
async def health_check():
"""Health check endpoint."""
return {
"status": "healthy",
"version": settings.VERSION,
}
return {"status": "healthy", "version": settings.VERSION}

# At the very bottom of main.py
app = create_app()

# App entrypoint
app = create_app()
90 changes: 90 additions & 0 deletions app/services/cache_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import redis.asyncio as redis
import asyncio
from typing import Optional
from app.core.config import settings
from structlog import get_logger

logger = get_logger()

class CacheService:
"""Redis-based cache service for AutoAudit API."""

def __init__(self):
self.redis = None
# Stats
self.hits = 0
self.misses = 0

async def init(self):
# Initialize Redis connection pool
self.redis = await redis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=True,
)

async def close(self):
if self.redis:
await self.redis.close()

async def wait_closed(self):
if self.redis:
await self.redis.connection_pool.disconnect()

async def get(self, key: str) -> Optional[str]:
"""Retrieve a value from cache."""
if not self.redis:
logger.warning("Redis not initialized")
self.misses += 1
return None
# Use namespaced cache keys to avoid collisions across environments/projects
value = await self.redis.get(f"{settings.CACHE_KEY_PREFIX}:{key}")
if value is None:
self.misses += 1
logger.debug("Cache miss", key=key)
else:
self.hits += 1
logger.debug("Cache hit", key=key)
return value

async def set(self, key: str, value: str, ttl: Optional[int] = None) -> None:
"""Set a value in cache with TTL."""
if not self.redis:
logger.warning("Redis not initialized")
return
expire = ttl or settings.CACHE_TTL_DEFAULT
await self.redis.set(
f"{settings.CACHE_KEY_PREFIX}:{key}",
value,
ex=expire,
)
logger.debug("Cache set", key=key, ttl=expire)

async def delete(self, key: str) -> None:
"""Delete a key from cache."""
if not self.redis:
logger.warning("Redis not initialized")
return
await self.redis.delete(f"{settings.CACHE_KEY_PREFIX}:{key}")
logger.debug("Cache delete", key=key)

async def clear(self) -> None:
"""Clear the entire cache (use with caution)."""
if not self.redis:
logger.warning("Redis not initialized")
return
await self.redis.flushdb()
logger.warning("Cache cleared")

def stats(self) -> dict:
"""Return cache hit/miss statistics."""
total = self.hits + self.misses
hit_rate = (self.hits / total * 100) if total > 0 else 0.0
return {
"hits": self.hits,
"misses": self.misses,
"hit_rate": f"{hit_rate:.2f}%",
}

# Singleton pattern with async init
cache_service = CacheService()
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies = [
"uvicorn>=0.35.0",
"python-dotenv>=1.1.1",
"structlog>=25.4.0",
"redis>=6.4.0",
]

[project.optional-dependencies]
Expand All @@ -22,6 +23,8 @@ dev = [

[tool.uv]
dev-dependencies = [
"fakeredis>=2.31.0",
"pytest>=7.0.0",
"pytest-asyncio>=0.21.0",
"pytest-redis>=3.1.3",
]
Loading