Skip to content

Commit

Permalink
Implement rate limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Jan 19, 2024
1 parent 0b7197d commit 1370058
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
24 changes: 20 additions & 4 deletions diana_services_api/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,21 @@
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jwt import DecodeError
from token_throttler import TokenThrottler, TokenBucket
from token_throttler.storage import RuntimeStorage


class ClientManager:
def __init__(self, config: dict):
self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage())

self.authorized_clients: Dict[str, dict] = dict()
self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24)
self._refresh_token_lifetime = config.get("refresh_token_ttl",
3600 * 24 * 7)
self._access_secret = config.get("access_token_secret")
self._refresh_secret = config.get("refresh_token_secret")
self._access_secret = config.get("access_token_secret") or "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b"
self._refresh_secret = config.get("refresh_token_secret") or "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b"
self._rpm = config.get("requests_per_minute", 60)
self._disable_auth = config.get("disable_auth")
self._jwt_algo = "HS256"

Expand Down Expand Up @@ -102,7 +107,16 @@ def check_refresh_request(self, access_token: str, refresh_token: str,
new_auth = self._create_tokens(encode_data)
return new_auth

def validate_auth(self, token: str) -> bool:
def validate_auth(self, token: str, origin_ip: str) -> bool:
if not self.rate_limiter.get_all_buckets(origin_ip):
self.rate_limiter.add_bucket(origin_ip,
TokenBucket(replenish_time=60,
max_tokens=self._rpm))
if not self.rate_limiter.consume(origin_ip) and self._rpm > 0:
raise HTTPException(status_code=429,
detail=f"Requests limited to {self._rpm}/min"
f"per client connection")

if self._disable_auth:
return True
try:
Expand All @@ -112,6 +126,7 @@ def validate_auth(self, token: str) -> bool:
return False
# Keep track of authorized client connections
self.authorized_clients[auth['client_id']] = auth
# TODO: Consider consuming an extra request for guest sessions
return True
except DecodeError:
# Invalid token supplied
Expand All @@ -131,7 +146,8 @@ async def __call__(self, request: Request):
if not credentials.scheme == "Bearer":
raise HTTPException(status_code=403,
detail="Invalid authentication scheme.")
if not self.client_manager.validate_auth(credentials.credentials):
if not self.client_manager.validate_auth(credentials.credentials,
request.client.host):
raise HTTPException(status_code=403,
detail="Invalid or expired token.")
return credentials.credentials
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ fastapi~=0.95
uvicorn~=0.25
pydantic~=2.5
pyjwt~=2.8
token-throttler~=1.4
neon-mq-connector~=0.7
ovos-config~=0.0.12

0 comments on commit 1370058

Please sign in to comment.