Skip to content

Commit

Permalink
Implement JWT token refresh
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Jan 18, 2024
1 parent 3ce54b4 commit 0b7197d
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 9 deletions.
2 changes: 1 addition & 1 deletion diana_services_api/app/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


def main():
config = Configuration().get("diana_services_api")
config = Configuration().get("diana_services_api", {})
app = create_app(config)
uvicorn.run(app, host=config.get('server_host', "0.0.0.0"),
port=config.get('port', 8080))
Expand Down
5 changes: 5 additions & 0 deletions diana_services_api/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@
@auth_route.post("/login")
async def check_login(request: AuthenticationRequest) -> AuthenticationResponse:
return client_manager.check_auth_request(**dict(request))


@auth_route.post("/refresh")
async def check_refresh(request: RefreshRequest) -> AuthenticationResponse:
return client_manager.check_refresh_request(**dict(request))
51 changes: 43 additions & 8 deletions diana_services_api/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ def __init__(self, config: dict):
self._disable_auth = config.get("disable_auth")
self._jwt_algo = "HS256"

def _create_tokens(self, encode_data: dict) -> dict:
token = jwt.encode(encode_data, self._access_secret, self._jwt_algo)
encode_data['expire'] = time() + self._refresh_token_lifetime
encode_data['access_token'] = token
refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo)
# TODO: Store refresh token on server to allow invalidating clients
return {"username": encode_data['username'],
"client_id": encode_data['client_id'],
"access_token": token,
"refresh_token": refresh}

def check_auth_request(self, client_id: str, username: str,
password: Optional[str] = None):
if client_id in self.authorized_clients:
Expand All @@ -56,17 +67,41 @@ def check_auth_request(self, client_id: str, username: str,
"username": username,
"password": password,
"expire": expiration}
token = jwt.encode(encode_data, self._access_secret, self._jwt_algo)
encode_data['expire'] = time() + self._refresh_token_lifetime
refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo)
# TODO: Store refresh token on server to validate refresh requests
auth = {"username": username,
"client_id": client_id,
"access_token": token,
"refresh_token": refresh}
auth = self._create_tokens(encode_data)
self.authorized_clients[client_id] = auth
return auth

def check_refresh_request(self, access_token: str, refresh_token: str,
client_id: str):
# Read and validate refresh token
try:
refresh_data = jwt.decode(refresh_token, self._refresh_secret,
self._jwt_algo)
except DecodeError:
raise HTTPException(status_code=400,
detail="Invalid refresh token supplied")
if refresh_data['access_token'] != access_token:
raise HTTPException(status_code=403,
detail="Refresh and access token mismatch")
if time() > refresh_data['expire']:
raise HTTPException(status_code=401,
detail="Refresh token is expired")
# Read access token and re-generate a new pair of tokens
try:
token_data = jwt.decode(access_token, self._access_secret,
self._jwt_algo)
except DecodeError:
raise HTTPException(status_code=400,
detail="Invalid access token supplied")
if token_data['client_id'] != client_id:
raise HTTPException(status_code=403,
detail="Access token does not match client_id")
encode_data = {k: token_data[k] for k in
("client_id", "username", "password")}
encode_data["expire"] = time() + self._access_token_lifetime
new_auth = self._create_tokens(encode_data)
return new_auth

def validate_auth(self, token: str) -> bool:
if self._disable_auth:
return True
Expand Down
6 changes: 6 additions & 0 deletions diana_services_api/schema/auth_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ class AuthenticationResponse(BaseModel):
"access_token": "<redacted>",
"refresh_token": "<redacted>"
}]}}


class RefreshRequest(BaseModel):
access_token: str
refresh_token: str
client_id: str

0 comments on commit 0b7197d

Please sign in to comment.