diff --git a/Dockerfile b/Dockerfile index d8394da..6879709 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,9 +3,9 @@ FROM python:3.9-slim LABEL vendor=neon.ai \ ai.neon.name="neon-hana" -ENV OVOS_CONFIG_BASE_FOLDER neon -ENV OVOS_CONFIG_FILENAME diana.yaml -ENV XDG_CONFIG_HOME /config +ENV OVOS_CONFIG_BASE_FOLDER=neon +ENV OVOS_CONFIG_FILENAME=diana.yaml +ENV XDG_CONFIG_HOME=/config RUN apt update && apt install -y swig gcc libpulse-dev portaudio19-dev diff --git a/README.md b/README.md index d741025..41d25a1 100644 --- a/README.md +++ b/README.md @@ -18,18 +18,18 @@ hana: mq_default_timeout: 10 access_token_ttl: 86400 # 1 day refresh_token_ttl: 604800 # 1 week - requests_per_minute: 60 + requests_per_minute: 60 # All other requests (auth, registration, etc) also count towards this limit auth_requests_per_minute: 6 # This counts valid and invalid requests from an IP address + registration_requests_per_hour: 4 # This is low to prevent malicious user creation that will pollute the database access_token_secret: a800445648142061fc238d1f84e96200da87f4f9fa7835cac90db8b4391b117b refresh_token_secret: 833d369ac73d883123743a44b4a7fe21203cffc956f4c8fec712e71aafa8e1aa + jwt_issuer: neon.ai # Used in the `iss` field of generated JWT tokens. fastapi_title: "My HANA API Host" fastapi_summary: "Personal HTTP API to access my DIANA backend." - disable_auth: True + disable_auth: True # If true, no authentication will be attempted; all connections will be allowed stt_max_length_encoded: 500000 # Arbitrary limit that is larger than any expected voice command tts_max_words: 128 # Arbitrary limit that is longer than any default LLM token limit enable_email: True # Disabled by default; anyone with access to the API will be able to send emails from the configured address - node_username: node_user # Username to authenticate Node API access; leave empty to disable Node API access - node_password: node_password # Password associated with node_username max_streaming_clients: -1 # Maximum audio streaming clients allowed (including 0). Default unset value allows infinite clients ``` It is recommended to generate unique values for configured tokens, these are 32 @@ -45,7 +45,27 @@ docker run -p 8080:8080 -v ~/.config/neon:/config/neon ghcr.io/neongeckocom/neon are using the default port 8080 ## Usage -Full API documentation is available at `/docs`. The `/auth/login` endpoint should -be used to generate a `client_id`, `access_token`, and `refresh_token`. The -`access_token` should be included in every request and upon expiration of the -`access_token`, a new token can be obtained from the `auth/refresh` endpoint. +Full API documentation is available at `/docs`. + +### Registration +The `/auth/register` endpoint may be used to create a new user if auth is enabled. +If auth is disabled, any login requests will return a successful response. + +### Token Generation +The `/auth/login` endpoint should be used to generate a `client_id`, +`access_token`, and `refresh_token`. The `access_token` should be included in +every request and upon expiration of the `access_token`, a new token can be +obtained from the `auth/refresh` endpoint. Tokens are client-specific and clients +are expected to include the same `client_id` and valid tokens for that client +with every request. + +### Token Management +`access_token`s should not be saved to persistent storage; they are only valid +for a short period of time and a new `access_token` should be generated for +every new session. + +`refresh_token`s should be saved to persistent storage and used to generate a new +`access_token` and `refresh_token` at the beginning of a session, or when the +current `access_token` expires. A `refresh_token` may only be used once; a new +`refresh_token` returned from the `/auth/refresh` endpoint will replace the one +included in the request. diff --git a/neon_hana/app/__init__.py b/neon_hana/app/__init__.py index 9fd7f1d..299f963 100644 --- a/neon_hana/app/__init__.py +++ b/neon_hana/app/__init__.py @@ -32,6 +32,7 @@ from neon_hana.app.routers.llm import llm_route from neon_hana.app.routers.mq_backend import mq_route from neon_hana.app.routers.auth import auth_route +from neon_hana.app.routers.user import user_route from neon_hana.app.routers.util import util_route from neon_hana.app.routers.node_server import node_route from neon_hana.version import __version__ @@ -49,5 +50,6 @@ def create_app(config: dict): app.include_router(llm_route) app.include_router(util_route) app.include_router(node_route) + app.include_router(user_route) return app diff --git a/neon_hana/app/dependencies.py b/neon_hana/app/dependencies.py index 0c9dcf5..e6e4726 100644 --- a/neon_hana/app/dependencies.py +++ b/neon_hana/app/dependencies.py @@ -31,5 +31,5 @@ config = Configuration().get("hana") or dict() mq_connector = MQServiceManager(config) -client_manager = ClientManager(config) +client_manager = ClientManager(config, mq_connector) jwt_bearer = UserTokenAuth(client_manager) diff --git a/neon_hana/app/routers/auth.py b/neon_hana/app/routers/auth.py index 14a9359..605db9b 100644 --- a/neon_hana/app/routers/auth.py +++ b/neon_hana/app/routers/auth.py @@ -28,6 +28,7 @@ from neon_hana.app.dependencies import client_manager from neon_hana.schema.auth_requests import * +from neon_data_models.models.user import User auth_route = APIRouter(prefix="/auth", tags=["authentication"]) @@ -43,3 +44,10 @@ async def check_login(auth_request: AuthenticationRequest, @auth_route.post("/refresh") async def check_refresh(request: RefreshRequest) -> AuthenticationResponse: return client_manager.check_refresh_request(**dict(request)) + + +@auth_route.post("/register") +async def register_user(register_request: RegistrationRequest, + request: Request) -> User: + return client_manager.check_registration_request(**dict(register_request), + origin_ip=request.client.host) diff --git a/neon_hana/app/routers/node_server.py b/neon_hana/app/routers/node_server.py index c7acd3f..3a2782a 100644 --- a/neon_hana/app/routers/node_server.py +++ b/neon_hana/app/routers/node_server.py @@ -26,7 +26,6 @@ from asyncio import Event from signal import signal, SIGINT -from time import sleep from typing import Optional, Union from fastapi import APIRouter, WebSocket, HTTPException @@ -36,13 +35,17 @@ from neon_hana.app.dependencies import config, client_manager from neon_hana.mq_websocket_api import MQWebsocketAPI, ClientNotKnown -from neon_hana.schema.node_v1 import (NodeAudioInput, NodeGetStt, - NodeGetTts, NodeKlatResponse, - NodeAudioInputResponse, - NodeGetSttResponse, - NodeGetTtsResponse, CoreWWDetected, - CoreIntentFailure, CoreErrorResponse, - CoreClearData, CoreAlertExpired) +from neon_data_models.models.api.node_v1 import (NodeAudioInput, NodeGetStt, + NodeGetTts, NodeKlatResponse, + NodeAudioInputResponse, + NodeGetSttResponse, + NodeGetTtsResponse, + CoreWWDetected, + CoreIntentFailure, + CoreErrorResponse, + CoreClearData, + CoreAlertExpired) + node_route = APIRouter(prefix="/node", tags=["node"]) socket_api = MQWebsocketAPI(config) diff --git a/neon_hana/app/routers/user.py b/neon_hana/app/routers/user.py new file mode 100644 index 0000000..729760b --- /dev/null +++ b/neon_hana/app/routers/user.py @@ -0,0 +1,48 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2024 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from fastapi import APIRouter, Depends +from neon_hana.app.dependencies import jwt_bearer, mq_connector +from neon_hana.schema.user_requests import GetUserRequest, UpdateUserRequest +from neon_data_models.models.user import User + +user_route = APIRouter(tags=["user"], dependencies=[Depends(jwt_bearer)]) + + +@user_route.post("/get") +async def get_user(request: GetUserRequest, + token: str = Depends(jwt_bearer)) -> User: + hana_token = jwt_bearer.client_manager.get_token_data(token) + return mq_connector.read_user(access_token=hana_token, + auth_user=hana_token.sub, + **dict(request)) + + +@user_route.post("/update") +async def update_user(request: UpdateUserRequest, + token: str = Depends(jwt_bearer)) -> User: + return mq_connector.update_user(access_token=token, + **dict(request)) diff --git a/neon_hana/app/routers/util.py b/neon_hana/app/routers/util.py index 2d62d94..ed3a9b9 100644 --- a/neon_hana/app/routers/util.py +++ b/neon_hana/app/routers/util.py @@ -50,6 +50,7 @@ async def api_client_ip(request: Request) -> str: # Reported host is a hostname, not an IP address. Return a generic # loopback value ip_addr = "127.0.0.1" + # Validation will fail, but this increments the rate-limiting client_manager.validate_auth("", ip_addr) return ip_addr @@ -57,5 +58,6 @@ async def api_client_ip(request: Request) -> str: @util_route.get("/headers") async def api_headers(request: Request): ip_addr = request.client.host if request.client else "127.0.0.1" + # Validation will fail, but this increments the rate-limiting client_manager.validate_auth("", ip_addr) return request.headers diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 7ad2e26..daa89d5 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -23,74 +23,121 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from threading import Lock import jwt +from uuid import uuid4 +from datetime import datetime +from threading import Lock from time import time from typing import Dict, Optional from fastapi import Request, HTTPException from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials -from jwt import DecodeError +from jwt import DecodeError, ExpiredSignatureError from ovos_utils import LOG +from ovos_utils.log import log_deprecation +from pydantic import ValidationError from token_throttler import TokenThrottler, TokenBucket from token_throttler.storage import RuntimeStorage -from neon_hana.auth.permissions import ClientPermissions +from neon_data_models.models.api.jwt import HanaToken +from neon_hana.mq_service_api import MQServiceManager +from neon_data_models.models.user import (User, NeonUserConfig, + PermissionsConfig) +from neon_data_models.enum import AccessRoles +from neon_hana.schema.auth_requests import AuthenticationResponse + +_DEFAULT_USER_PERMISSIONS = PermissionsConfig(klat=AccessRoles.USER, + core=AccessRoles.USER, + diana=AccessRoles.USER, + node=AccessRoles.USER, + hub=AccessRoles.USER, + llm=AccessRoles.USER) class ClientManager: - def __init__(self, config: dict): + def __init__(self, config: dict, + mq_connector: Optional[MQServiceManager] = None): self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage()) - self.authorized_clients: Dict[str, dict] = dict() + # Keep a dict of `client_id` to auth tokens that have authenticated to + # this instance + self._authorized_clients: Dict[str, AuthenticationResponse] = dict() self._access_token_lifetime = config.get("access_token_ttl", 3600 * 24) self._refresh_token_lifetime = config.get("refresh_token_ttl", - 3600 * 24 * 7) + 3600 * 24 * 90) + self._jwt_issuer = config.get("jwt_issuer", "neon.ai") self._access_secret = config.get("access_token_secret") self._refresh_secret = config.get("refresh_token_secret") self._rpm = config.get("requests_per_minute", 60) self._auth_rpm = config.get("auth_requests_per_minute", 6) + self._register_rph = config.get("registration_requests_per_hour", 4) self._disable_auth = config.get("disable_auth") - self._node_username = config.get("node_username") - self._node_password = config.get("node_password") self._max_streaming_clients = config.get("max_streaming_clients") self._jwt_algo = "HS256" self._connected_streams = 0 self._stream_check_lock = Lock() + # If authentication is explicitly disabled, don't try to query the + # users service + self._mq_connector = None if self._disable_auth else mq_connector - def _create_tokens(self, encode_data: dict) -> dict: - # Permissions were not included in old tokens, allow refreshing with - # default permissions - encode_data.setdefault("permissions", ClientPermissions().as_dict()) - - token_expiration = encode_data['expire'] - token = jwt.encode(encode_data, self._access_secret, self._jwt_algo) - encode_data['expire'] = time() + self._refresh_token_lifetime - encode_data['access_token'] = token - refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) - # TODO: Store refresh token on server to allow invalidating clients - return {"username": encode_data['username'], - "client_id": encode_data['client_id'], - "permissions": encode_data['permissions'], - "access_token": token, - "refresh_token": refresh, - "expiration": token_expiration} - - def get_permissions(self, client_id: str) -> ClientPermissions: + @property + def authorized_clients(self) -> Dict[str, AuthenticationResponse]: """ - Get ClientPermissions model for the given client_id - @param client_id: Client ID to get permissions for - @return: ClientPermissions object for the specified client + Dict of `client_id` to `AuthenticationResponse` objects for clients + known by this instance. NOTE: Refresh tokens are not reliably stored + here and should never be retrievable after generation for security. """ - if self._disable_auth: - LOG.debug("Auth disabled, allow full client permissions") - return ClientPermissions(assist=True, backend=True, node=True) - if client_id not in self.authorized_clients: - LOG.warning(f"{client_id} not known to this server") - return ClientPermissions(assist=False, backend=False, node=False) - client = self.authorized_clients[client_id] - return ClientPermissions(**client.get('permissions', dict())) + log_deprecation("This property is deprecated with no replacement", "1.0.0") + return self._authorized_clients + + def _create_tokens(self, + user_id: str, + client_id: str, + token_name: Optional[str] = None, + permissions: Optional[PermissionsConfig] = None, + **kwargs) -> (str, str, Dict[str, HanaToken]): + token_id = str(uuid4()) + # Subtract a second from creation so the token may be used immediately + # upon return + creation_timestamp = round(time()) - 1 + expiration_timestamp = creation_timestamp + self._access_token_lifetime + refresh_expiration_timestamp = creation_timestamp + self._refresh_token_lifetime + permissions = permissions or PermissionsConfig(core=AccessRoles.GUEST, + diana=AccessRoles.GUEST, + node=AccessRoles.GUEST, + llm=AccessRoles.GUEST) + token_name = token_name or kwargs.get("name") or \ + datetime.fromtimestamp(creation_timestamp).isoformat() + access_token_data = HanaToken(iss=self._jwt_issuer, + sub=user_id, + exp=expiration_timestamp, + iat=creation_timestamp, + jti=token_id, + client_id=client_id, + roles=permissions.to_roles(), + token_name=token_name, + creation_timestamp=creation_timestamp, + last_refresh_timestamp=creation_timestamp, + purpose="access") + refresh_token_data = HanaToken(iss=self._jwt_issuer, + sub=user_id, + exp=refresh_expiration_timestamp, + iat=creation_timestamp, + jti=f"{token_id}.refresh", + client_id=client_id, + roles=permissions.to_roles(), + token_name=token_name, + creation_timestamp=creation_timestamp, + last_refresh_timestamp=creation_timestamp, + purpose="refresh") + access_token = jwt.encode(access_token_data.model_dump(), + self._access_secret, self._jwt_algo) + refresh_token = jwt.encode(refresh_token_data.model_dump(), + self._refresh_secret, self._jwt_algo) + + return access_token, refresh_token, {"access": access_token_data, + "refresh": refresh_token_data} def check_connect_stream(self) -> bool: """ @@ -112,114 +159,220 @@ def disconnect_stream(self): with self._stream_check_lock: self._connected_streams -= 1 + def _consume_rate_limit_token(self, ratelimit_id: str): + if not self.rate_limiter.consume(ratelimit_id): + bucket = list(self.rate_limiter.get_all_buckets(ratelimit_id). + values())[0] + replenish_time = bucket.last_replenished + bucket.replenish_time + wait_time = round(replenish_time - time()) + ip_addr, request_cls = ratelimit_id.split('-', 1) + raise HTTPException(status_code=429, + detail=f"Too many {request_cls} requests from: " + f"{ip_addr}. Wait {wait_time}s.") + + def check_registration_request(self, username: str, password: str, + user_config: NeonUserConfig, + origin_ip: str = "127.0.0.1") -> User: + """ + Handle a request to register a new user. + """ + + ratelimit_id = f"{origin_ip}-register" + if not self.rate_limiter.get_all_buckets(ratelimit_id): + self.rate_limiter.add_bucket(ratelimit_id, + TokenBucket(replenish_time=3600, + max_tokens=self._register_rph)) + self._consume_rate_limit_token(ratelimit_id) + + new_user = User(username=username, password_hash=password, + neon=user_config, permissions=_DEFAULT_USER_PERMISSIONS) + if self._mq_connector: + return self._mq_connector.create_user(new_user) + else: + LOG.debug("No User Database connected. Return valid registration.") + return new_user + def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, - origin_ip: str = "127.0.0.1") -> dict: + token_name: Optional[str] = None, + origin_ip: str = "127.0.0.1") -> AuthenticationResponse: """ Authenticate and Authorize a new client connection with the specified username, password, and origin IP address. @param client_id: Client ID of the connection to auth @param username: Supplied username to authenticate @param password: Supplied password to authenticate + @param token_name: Token name to add to user database @param origin_ip: Origin IP address of request @return: response tokens, permissions, and other metadata """ - if client_id in self.authorized_clients: - print(f"Using cached client: {self.authorized_clients[client_id]}") - return self.authorized_clients[client_id] + # Caching does not work here because there is no guarantee that this + # instance knows the client's refresh token. One client may also want + # to generate multiple tokens. + # if client_id in self.authorized_clients: + # print(f"Using cached client: {self.authorized_clients[client_id]}") + # return self.authorized_clients[client_id] - ratelimit_id = f"auth{origin_ip}" + ratelimit_id = f"{origin_ip}-auth" if not self.rate_limiter.get_all_buckets(ratelimit_id): self.rate_limiter.add_bucket(ratelimit_id, TokenBucket(replenish_time=60, max_tokens=self._auth_rpm)) - if not self.rate_limiter.consume(ratelimit_id): - bucket = list(self.rate_limiter.get_all_buckets(ratelimit_id). - values())[0] - replenish_time = bucket.last_replenished + bucket.replenish_time - wait_time = round(replenish_time - time()) - raise HTTPException(status_code=429, - detail=f"Too many auth requests from: " - f"{origin_ip}. Wait {wait_time}s.") + self._consume_rate_limit_token(ratelimit_id) - node_access = False - if username != "guest": - # TODO: Validate password here - pass - if all((self._node_username, username == self._node_username, - password == self._node_password)): - node_access = True - permissions = ClientPermissions(node=node_access) - expiration = time() + self._access_token_lifetime + if self._mq_connector is None: + # Auth is disabled; every auth request gets a successful response + user = User(username=username, password_hash=password, + permissions=_DEFAULT_USER_PERMISSIONS) + # elif all((self._node_username, username == self._node_username, + # password == self._node_password)): + # # User matches configured node username/password + # user = User(username=username, password_hash=password, + # permissions=_DEFAULT_USER_PERMISSIONS) + # user.permissions.node = AccessRoles.USER + else: + user = self._mq_connector.read_user(username, password) + + create_time = round(time()) encode_data = {"client_id": client_id, - "username": username, - "password": password, - "permissions": permissions.as_dict(), - "expire": expiration} - auth = self._create_tokens(encode_data) - self.authorized_clients[client_id] = auth - return auth - - def check_refresh_request(self, access_token: str, refresh_token: str, - client_id: str): + "user_id": user.user_id, + "permissions": user.permissions, + "token_name": token_name, + "last_refresh_timestamp": create_time} + access, refresh, config = self._create_tokens(**encode_data) + + auth_response = AuthenticationResponse(username=user.username, + client_id=client_id, + access_token=access, + refresh_token=refresh, + expiration=config['access'].exp) + self.authorized_clients[client_id] = auth_response + self._add_token_to_userdb(user, config['refresh']) + return auth_response + + def check_refresh_request(self, access_token: Optional[str], + refresh_token: str, + client_id: str) -> AuthenticationResponse: # Read and validate refresh token try: - refresh_data = jwt.decode(refresh_token, self._refresh_secret, - self._jwt_algo) + refresh_data = HanaToken(**jwt.decode(refresh_token, + self._refresh_secret, + self._jwt_algo)) + token_data = HanaToken(**jwt.decode(access_token, + self._access_secret, + self._jwt_algo, + options={"verify_signature": False})) + if refresh_data.purpose != "refresh": + raise HTTPException(status_code=400, + detail="Supplied refresh token not valid") + # if token_data.purpose != "access": + # raise HTTPException(status_code=400, + # detail="Supplied refresh token not valid") except DecodeError: raise HTTPException(status_code=400, - detail="Invalid refresh token supplied") - if refresh_data['access_token'] != access_token: + detail="Invalid token supplied") + except ExpiredSignatureError: + raise HTTPException(status_code=401, + detail="Refresh token is expired") + if refresh_data.jti != token_data.jti + ".refresh": raise HTTPException(status_code=403, detail="Refresh and access token mismatch") - if time() > refresh_data['expire']: + if time() > refresh_data.exp: raise HTTPException(status_code=401, detail="Refresh token is expired") - # Read access token and re-generate a new pair of tokens - # This is already known to be a valid token based on the refresh token - token_data = jwt.decode(access_token, self._access_secret, - self._jwt_algo) - if token_data['client_id'] != client_id: + if refresh_data.client_id != client_id: raise HTTPException(status_code=403, detail="Access token does not match client_id") - encode_data = {k: token_data[k] for k in - ("client_id", "username", "password")} - encode_data["expire"] = time() + self._access_token_lifetime - new_auth = self._create_tokens(encode_data) - return new_auth + + encode_data = {"user_id": refresh_data.sub, + "client_id": client_id, + "token_name": refresh_data.token_name, + "permissions": PermissionsConfig.from_roles(refresh_data.roles) + } + access, refresh, tokens = self._create_tokens(**encode_data) + username = refresh_data.sub + if self._mq_connector: + user = self._mq_connector.read_user(username=refresh_data.sub, + access_token=token_data) + if not user.password_hash: + # This should not be possible, but don't let an error in the + # users service allow for injecting a new valid token to the db + raise HTTPException(status_code=500, detail="Error Fetching User") + self._add_token_to_userdb(user, tokens['refresh']) + + auth_response = AuthenticationResponse(username=username, + client_id=client_id, + access_token=access, + refresh_token=refresh, + expiration=tokens['refresh'].exp) + self._authorized_clients[client_id] = auth_response + return auth_response + + def _add_token_to_userdb(self, user: User, new_token: HanaToken): + if new_token.purpose != "refresh": + raise ValueError(f"Expected a refresh token, got: " + f"{new_token.purpose}") + if self._mq_connector is None: + LOG.debug("No MQ Connection to a user database") + return + for idx, token in enumerate(user.tokens): + # If the token is already defined, maintain the original + # creation timestamp + if token.jti == new_token.jti: + new_token.creation_timestamp = token.creation_timestamp + user.tokens.remove(token) + user.tokens.append(new_token) + self._mq_connector.update_user(user) def get_client_id(self, token: str) -> str: """ - Extract the client_id from a JWT token - @param token: JWT token to parse + Extract the client_id from a JWT string + @param token: JWT to parse @return: client_id associated with token """ - auth = jwt.decode(token, self._access_secret, self._jwt_algo) - return auth['client_id'] + auth = HanaToken(**jwt.decode(token, self._access_secret, + self._jwt_algo)) + return auth.client_id + + def get_token_data(self, token: str) -> HanaToken: + """ + Extract the user_id from a JWT string + @param token: JWT to parse + @retrun: user_id associated with token + """ + return HanaToken(**jwt.decode(token, self._access_secret, + self._jwt_algo)) def validate_auth(self, token: str, origin_ip: str) -> bool: - if not self.rate_limiter.get_all_buckets(origin_ip): - self.rate_limiter.add_bucket(origin_ip, + ratelimit_id = f"{origin_ip}-total" + if not self.rate_limiter.get_all_buckets(ratelimit_id): + self.rate_limiter.add_bucket(ratelimit_id, TokenBucket(replenish_time=60, max_tokens=self._rpm)) - if not self.rate_limiter.consume(origin_ip) and self._rpm > 0: - raise HTTPException(status_code=429, - detail=f"Requests limited to {self._rpm}/min " - f"per client connection") + if self._rpm > 0: + self._consume_rate_limit_token(ratelimit_id) if self._disable_auth: return True try: - auth = jwt.decode(token, self._access_secret, self._jwt_algo) - if auth['expire'] < time(): - self.authorized_clients.pop(auth['client_id'], None) + auth = HanaToken(**jwt.decode(token, self._access_secret, + self._jwt_algo)) + if auth.exp < time(): + self.authorized_clients.pop(auth.client_id, None) return False - self.authorized_clients[auth['client_id']] = auth + self.authorized_clients[auth.client_id] = AuthenticationResponse( + username=auth.sub, client_id=auth.client_id, access_token=token, + refresh_token="", expiration=auth.exp) return True + except ValidationError: + LOG.error(f"Invalid token data received from {origin_ip}.") except DecodeError: # Invalid token supplied pass + except ExpiredSignatureError: + # Expired token + pass return False diff --git a/neon_hana/mq_service_api.py b/neon_hana/mq_service_api.py index a36e5e3..e1fd292 100644 --- a/neon_hana/mq_service_api.py +++ b/neon_hana/mq_service_api.py @@ -27,13 +27,17 @@ import json from time import time -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union from uuid import uuid4 from fastapi import HTTPException -from neon_hana.schema.node_model import NodeData -from neon_hana.schema.user_profile import UserProfile +from neon_data_models.models.api import CreateUserRequest, ReadUserRequest, \ + UpdateUserRequest, DeleteUserRequest +from neon_data_models.models.api.jwt import HanaToken from neon_mq_connector.utils.client_utils import send_mq_request +from neon_data_models.models.client.node import NodeData +from neon_data_models.models.user.neon_profile import UserProfile +from neon_data_models.models.user import User class APIError(HTTPException): @@ -77,6 +81,28 @@ def _validate_api_proxy_response(response: dict, query_params: dict): code = response['status_code'] if response['status_code'] > 200 else 500 raise APIError(status_code=code, detail=response['content']) + @staticmethod + def _query_users_api(user_db_request: Union[CreateUserRequest, + ReadUserRequest, + UpdateUserRequest, + DeleteUserRequest]) -> \ + (int, Union[User, str]): + """ + Query the users API and return a status code and either a valid User or + a string error message. Authentication may use EITHER a password or + a token. + @param user_db_request: UserDbRequest object describing CRUD operation + to return + @return: success bool, HTTP status code, User object or string error + """ + response = send_mq_request("/neon_users", + user_db_request.model_dump(exclude={ + "message_id"}), + target_queue="neon_users_input") + if response.get("success"): + return 200, User(**response.get("user")) + return response.get("code", 500), response.get("error", "") + def get_session(self, node_data: NodeData) -> dict: """ Get a serialized Session object for the specified Node. @@ -89,6 +115,62 @@ def get_session(self, node_data: NodeData) -> dict: "site_id": node_data.location.site_id}) return self.sessions_by_id[session_id] + def create_user(self, user: User) -> User: + """ + Create a new user. + @param user: User object to add to the users service database + @returns: User object added to the database + """ + create_user_request = CreateUserRequest(user=user, message_id="") + code, err_or_user = self._query_users_api(create_user_request) + if code != 200: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + + def read_user(self, username: str, password: Optional[str] = None, + access_token: Optional[HanaToken] = None, + auth_user: Optional[str] = None) -> User: + """ + Get a User object for a user. This requires that a valid password OR + access token be provided to prevent arbitrary users from reading + private profile info. + @param username: Valid username to get a User object for + @param password: Valid password to use for authentication + @param access_token: Valid access token to use for authentication + @param auth_user: Optional username or user ID to use for authentication + @returns: User object from the Users service. + """ + auth_user = auth_user or username + read_user_request = ReadUserRequest(user_spec=username, + auth_user_spec=auth_user, + access_token=access_token, + password=password, message_id="") + code, err_or_user = self._query_users_api(read_user_request) + if code != 200: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + + def update_user(self, user: User, + auth_user: Optional[str] = None, + auth_password: Optional[str] = None) -> User: + """ + Update an existing user in the database. + @param user: Updated user object to write + @param auth_user: Username to use for authentication + @param auth_password: Password associated with `auth_user` + @returns: User as read from the database + """ + auth_user = auth_user or user.username + auth_password = auth_password or user.password_hash + update_user_request = UpdateUserRequest(user=user, + auth_username=auth_user, + auth_password=auth_password, + message_id="") + code, err_or_user = self._query_users_api(update_user_request) + if code != 200: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + def query_api_proxy(self, service_name: str, query_params: dict, timeout: int = 10): query_params['service'] = service_name diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 6b9b96e..d9d0bae 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -33,7 +33,7 @@ from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message from threading import RLock -from ovos_utils import LOG +from ovos_utils.log import LOG class ClientNotKnown(RuntimeError): diff --git a/neon_hana/schema/assist_requests.py b/neon_hana/schema/assist_requests.py index 7af09b7..472b708 100644 --- a/neon_hana/schema/assist_requests.py +++ b/neon_hana/schema/assist_requests.py @@ -27,8 +27,8 @@ from typing import List, Optional from pydantic import BaseModel -from neon_hana.schema.node_model import NodeData -from neon_hana.schema.user_profile import UserProfile +from neon_data_models.models.client.node import NodeData +from neon_data_models.models.user.neon_profile import UserProfile class STTRequest(BaseModel): diff --git a/neon_hana/schema/auth_requests.py b/neon_hana/schema/auth_requests.py index d02724d..cd215e2 100644 --- a/neon_hana/schema/auth_requests.py +++ b/neon_hana/schema/auth_requests.py @@ -24,22 +24,27 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from datetime import datetime from typing import Optional from uuid import uuid4 from pydantic import BaseModel, Field +from neon_data_models.models.user import NeonUserConfig + class AuthenticationRequest(BaseModel): username: str = "guest" password: Optional[str] = None + token_name: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) client_id: str = Field(default_factory=lambda: str(uuid4())) model_config = { "json_schema_extra": { "examples": [{ "username": "guest", - "password": "password" + "password": "password", + "token_name": "My Client" }]}} @@ -48,7 +53,8 @@ class AuthenticationResponse(BaseModel): client_id: str access_token: str refresh_token: str - expiration: float + expiration: float = Field( + description="Expiration timestamp of the refresh token") model_config = { "json_schema_extra": { @@ -60,8 +66,29 @@ class AuthenticationResponse(BaseModel): "expiration": 1706045776.4168212 }]}} + def __getitem__(self, item): + if hasattr(self, item): + return getattr(self, item) + raise KeyError(item) + class RefreshRequest(BaseModel): - access_token: str + access_token: Optional[str] = None refresh_token: str client_id: str + + +class RegistrationRequest(BaseModel): + username: str + password: str + user_config: NeonUserConfig = NeonUserConfig() + + model_config = { + "json_schema_extra": { + "examples": [{ + "username": "guest", + "password": "password", + "user_config": NeonUserConfig().model_dump() + }, {"username": "guest", + "password": "password"} + ]}} diff --git a/neon_hana/schema/node_model.py b/neon_hana/schema/node_model.py index 3fbdc50..b9f0080 100644 --- a/neon_hana/schema/node_model.py +++ b/neon_hana/schema/node_model.py @@ -23,35 +23,9 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from uuid import uuid4 -from pydantic import BaseModel, Field -from typing import Optional, Dict +from neon_data_models.models.client.node import NodeSoftware, NodeNetworking, NodeLocation, NodeData +from ovos_utils.log import log_deprecation - -class NodeSoftware(BaseModel): - operating_system: str = "" - os_version: str = "" - neon_packages: Optional[Dict[str, str]] = None - - -class NodeNetworking(BaseModel): - local_ip: str = "127.0.0.1" - public_ip: str = "" - mac_address: str = "" - - -class NodeLocation(BaseModel): - lat: Optional[float] = None - lon: Optional[float] = None - site_id: Optional[str] = None - - -class NodeData(BaseModel): - device_id: str = Field(default_factory=lambda: str(uuid4())) - device_name: str = "" - device_description: str = "" - platform: str = "" - networking: NodeNetworking = NodeNetworking() - software: NodeSoftware = NodeSoftware() - location: NodeLocation = NodeLocation() +log_deprecation('Imports moved to `neon_data_models.models.client.node`', + '1.0.0') diff --git a/neon_hana/schema/node_v1.py b/neon_hana/schema/node_v1.py deleted file mode 100644 index 913c6b0..0000000 --- a/neon_hana/schema/node_v1.py +++ /dev/null @@ -1,154 +0,0 @@ -# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System -# All trademark and other rights reserved by their respective owners -# Copyright 2008-2021 Neongecko.com Inc. -# BSD-3 -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# 1. Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from this -# software without specific prior written permission. -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, -# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -from pydantic import BaseModel, Field -from typing import Optional, Dict, List, Literal -from neon_hana.schema.node_model import NodeData - - -class NodeInputContext(BaseModel): - node_data: Optional[NodeData] = Field(description="Node Data") - - -class AudioInputData(BaseModel): - audio_data: str = Field(description="base64-encoded audio") - lang: str = Field(description="BCP-47 language code") - - -class TextInputData(BaseModel): - text: str = Field(description="String text input") - lang: str = Field(description="BCP-47 language code") - - -class UtteranceInputData(BaseModel): - utterances: List[str] = Field(description="List of input utterance(s)") - lang: str = Field(description="BCP-47 language") - - -class KlatResponse(BaseModel): - sentence: str = Field(description="Text response") - audio: dict = {Field(description="Audio Gender", - type=Literal["male", "female"]): - Field(description="b64-encoded audio", type=str)} - - -class TtsResponse(KlatResponse): - translated: bool = Field(description="True if sentence was translated") - phonemes: List[str] = Field(description="List of phonemes") - genders: List[str] = Field(description="List of audio genders") - - -class KlatResponseData(BaseModel): - responses: dict = {Field(type=str, - description="BCP-47 language"): KlatResponse} - - -class NodeAudioInput(BaseModel): - msg_type: str = "neon.audio_input" - data: AudioInputData - context: NodeInputContext - - -class NodeTextInput(BaseModel): - msg_type: str = "recognizer_loop:utterance" - data: UtteranceInputData - context: NodeInputContext - - -class NodeGetStt(BaseModel): - msg_type: str = "neon.get_stt" - data: AudioInputData - context: NodeInputContext - - -class NodeGetTts(BaseModel): - msg_type: str = "neon.get_tts" - data: TextInputData - context: NodeInputContext - - -class NodeKlatResponse(BaseModel): - msg_type: str = "klat.response" - data: dict = {Field(type=str, description="BCP-47 language"): KlatResponse} - context: dict - - -class NodeAudioInputResponse(BaseModel): - msg_type: str = "neon.audio_input.response" - data: dict = {"parser_data": Field(description="Dict audio parser data", - type=dict), - "transcripts": Field(description="Transcribed text", - type=List[str]), - "skills_recv": Field(description="Skills service acknowledge", - type=bool)} - context: dict - - -class NodeGetSttResponse(BaseModel): - msg_type: str = "neon.get_stt.response" - data: dict = {"parser_data": Field(description="Dict audio parser data", - type=dict), - "transcripts": Field(description="Transcribed text", - type=List[str]), - "skills_recv": Field(description="Skills service acknowledge", - type=bool)} - context: dict - - -class NodeGetTtsResponse(BaseModel): - msg_type: str = "neon.get_tts.response" - data: KlatResponseData - context: dict - - -class CoreWWDetected(BaseModel): - msg_type: str = "neon.ww_detected" - data: dict - context: dict - - -class CoreIntentFailure(BaseModel): - msg_type: str = "complete.intent.failure" - data: dict - context: dict - - -class CoreErrorResponse(BaseModel): - msg_type: str = "klat.error" - data: dict - context: dict - - -class CoreClearData(BaseModel): - msg_type: str = "neon.clear_data" - data: dict - context: dict - - -class CoreAlertExpired(BaseModel): - msg_type: str = "neon.alert_expired" - data: dict - context: dict diff --git a/neon_hana/schema/user_profile.py b/neon_hana/schema/user_profile.py index 91a9a05..231c39b 100644 --- a/neon_hana/schema/user_profile.py +++ b/neon_hana/schema/user_profile.py @@ -24,81 +24,8 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Optional, List -from pydantic import BaseModel +from neon_data_models.models.user.neon_profile import * +from ovos_utils.log import log_deprecation - -class ProfileUser(BaseModel): - first_name: str = "" - middle_name: str = "" - last_name: str = "" - preferred_name: str = "" - full_name: str = "" - dob: str = "YYYY/MM/DD" - age: str = "" - email: str = "" - username: str = "" - password: str = "" - picture: str = "" - about: str = "" - phone: str = "" - phone_verified: bool = False - email_verified: bool = False - - -class ProfileBrands(BaseModel): - ignored_brands: dict = {} - favorite_brands: dict = {} - specially_requested: dict = {} - - -class ProfileSpeech(BaseModel): - stt_language: str = "en-us" - alt_languages: List[str] = ['en'] - tts_language: str = "en-us" - tts_gender: str = "female" - neon_voice: Optional[str] = '' - secondary_tts_language: Optional[str] = '' - secondary_tts_gender: str = "male" - secondary_neon_voice: str = '' - speed_multiplier: float = 1.0 - - -class ProfileUnits(BaseModel): - time: int = 12 - # 12, 24 - date: str = "MDY" - # MDY, YMD, YDM - measure: str = "imperial" - # imperial, metric - - -class ProfileLocation(BaseModel): - lat: Optional[float] = None - lng: Optional[float] = None - city: Optional[str] = None - state: Optional[str] = None - country: Optional[str] = None - tz: Optional[str] = None - utc: Optional[float] = None - - -class ProfileResponseMode(BaseModel): - speed_mode: str = "quick" - hesitation: bool = False - limit_dialog: bool = False - - -class ProfilePrivacy(BaseModel): - save_audio: bool = False - save_text: bool = False - - -class UserProfile(BaseModel): - user: ProfileUser = ProfileUser() - # brands: ProfileBrands - speech: ProfileSpeech = ProfileSpeech() - units: ProfileUnits = ProfileUnits() - location: ProfileLocation = ProfileLocation() - response_mode: ProfileResponseMode = ProfileResponseMode() - privacy: ProfilePrivacy = ProfilePrivacy() +log_deprecation('Imports moved to `neon_data_models.models.user.neon_profile`', + '1.0.0') diff --git a/neon_hana/auth/permissions.py b/neon_hana/schema/user_requests.py similarity index 66% rename from neon_hana/auth/permissions.py rename to neon_hana/schema/user_requests.py index 287cc38..3e1618b 100644 --- a/neon_hana/auth/permissions.py +++ b/neon_hana/schema/user_requests.py @@ -1,6 +1,6 @@ # NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System # All trademark and other rights reserved by their respective owners -# Copyright 2008-2021 Neongecko.com Inc. +# Copyright 2008-2024 Neongecko.com Inc. # BSD-3 # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: @@ -24,20 +24,33 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from dataclasses import dataclass, asdict +from pydantic import BaseModel +from typing import Optional +from neon_data_models.models.user.database import User -@dataclass -class ClientPermissions: - """ - Data class representing permissions of a particular client connection. - """ - assist: bool = True - backend: bool = True - node: bool = False - def as_dict(self) -> dict: - """ - Get a dict representation of this instance. - """ - return asdict(self) +class GetUserRequest(BaseModel): + username: str = "guest" + + model_config = { + "json_schema_extra": { + "examples": [{ + "username": "guest" + }]}} + + +class UpdateUserRequest(BaseModel): + user: User + auth_username: Optional[str] = None + auth_password: Optional[str] = None + + model_config = { + "json_schema_extra": { + "examples": [{ + "user": User(username="guest").model_dump() + }, + {"user": User(username="some_user").model_dump(), + "auth_username": "admin_user", + "auth_password": "admin_password"} + ]}} diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0e5f1c8..2116970 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,4 +5,6 @@ pydantic~=2.5 pyjwt~=2.8 token-throttler~=1.4 neon-mq-connector~=0.7 -ovos-config~=0.0.12 \ No newline at end of file +ovos-config~=0.0,>=0.0.12 +ovos-utils~=0.0,>=0.0.38 +neon-data-models diff --git a/tests/test_app.py b/tests/test_app.py index ef546b1..11c69b6 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -5,6 +5,8 @@ from fastapi.testclient import TestClient +from neon_data_models.models.user import User + _TEST_CONFIG = { "mq_default_timeout": 10, "access_token_ttl": 86400, # 1 day @@ -36,7 +38,11 @@ def setUpClass(cls, ws_api, config): app = create_app(_TEST_CONFIG) cls.test_app = TestClient(app) - def _get_tokens(self): + @patch("neon_hana.mq_service_api.send_mq_request") + def _get_tokens(self, send_request): + valid_user = User(username="guest", password_hash="password") + send_request.return_value = {"user": valid_user.model_dump(), + "success": True} if not self.tokens: response = self.test_app.post("/auth/login", json={"username": "guest", @@ -52,12 +58,14 @@ def test_app_init(self): @patch("neon_hana.mq_service_api.send_mq_request") def test_auth_login(self, send_request): - send_request.return_value = {} # TODO: Define valid login + valid_user = User(username="guest", password_hash="password") + send_request.return_value = {"user": valid_user.model_dump(), + "success": True} # Valid Login response = self.test_app.post("/auth/login", - json={"username": "guest", - "password": "password"}) + json={"username": valid_user.username, + "password": valid_user.password_hash}) response_data = response.json() self.assertEqual(response.status_code, 200, response.text) self.assertEqual(response_data['username'], "guest") @@ -66,7 +74,13 @@ def test_auth_login(self, send_request): self.assertGreater(response_data['expiration'], time()) # Invalid Login - # TODO: Define invalid login request + send_request.return_value = {"code": 404, "error": "User not found"} + response = self.test_app.post("/auth/login", + json={"username": valid_user.username, + "password": valid_user.password_hash}) + self.assertEqual(response.status_code, 404, response.status_code) + self.assertEqual(response.json()['detail'], + "User not found", response.text) # Invalid Request self.assertEqual(self.test_app.post("/auth/login").status_code, 422) @@ -76,7 +90,9 @@ def test_auth_login(self, send_request): @patch("neon_hana.mq_service_api.send_mq_request") def test_auth_refresh(self, send_request): - send_request.return_value = {} # TODO: Define valid refresh + valid_user = User(username="guest", password_hash="password") + send_request.return_value = {"user": valid_user.model_dump(), + "success": True} valid_tokens = self._get_tokens() @@ -86,14 +102,12 @@ def test_auth_refresh(self, send_request): response_data = response.json() self.assertNotEqual(response_data, valid_tokens) - # # TODO - # # Refresh with old tokens fails - # response = self.test_app.post("/auth/refresh", json=valid_tokens) - # self.assertEqual(response.status_code, 422, response.text) - - # Valid request with new tokens - response = self.test_app.post("/auth/refresh", json=response_data) - self.assertEqual(response.status_code, 200, response.text) + # Refresh with old tokens fails (mocked return from users service) + send_request.return_value = {"code": 422, + "detail": "Invalid token", + "success": False} + response = self.test_app.post("/auth/refresh", json=valid_tokens) + self.assertEqual(response.status_code, 422, response.text) # TODO: Test with expired token diff --git a/tests/test_auth.py b/tests/test_auth.py index 88422fb..b5270d1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -25,7 +25,7 @@ # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import unittest -from time import time +from time import time, sleep from uuid import uuid4 from fastapi import HTTPException @@ -34,7 +34,7 @@ class TestClientManager(unittest.TestCase): from neon_hana.auth.client_manager import ClientManager client_manager = ClientManager({"access_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b", - "refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391b117b", + "refresh_token_secret": "a800445648142061fc238d1f84e96200da87f4f9f784108ac90db8b4391ba800", "disable_auth": False}) def test_check_auth_request(self): @@ -62,25 +62,30 @@ def test_check_auth_request(self): # TODO: Test permissions - # Check auth already authorized - self.assertEqual(auth_resp_2, - self.client_manager.check_auth_request(**request_2)) + # Check auth already authorized. New tokens are generated with new + # expirations + self.assertNotEqual(auth_resp_2, + self.client_manager.check_auth_request(**request_2)) def test_validate_auth(self): + # Test valid client valid_client = str(uuid4()) - invalid_client = str(uuid4()) auth_response = self.client_manager.check_auth_request( - username="valid", client_id=valid_client)['access_token'] - + username="valid", client_id=valid_client).access_token self.assertTrue(self.client_manager.validate_auth(auth_response, "127.0.0.1")) + + # Unauthenticated client fails + invalid_client = str(uuid4()) self.assertFalse(self.client_manager.validate_auth(invalid_client, "127.0.0.1")) - - expired_token = self.client_manager._create_tokens( - {"client_id": invalid_client, "username": "test", - "password": "test", "expire": time(), - "permissions": {}})['access_token'] + # Test expired token fails auth + self.client_manager._access_token_lifetime = 1 + self.client_manager._refresh_token_lifetime = 1 + expired_token, _, _ = self.client_manager._create_tokens( + user_id=str(uuid4()), + client_id=str(uuid4())) + sleep(1) self.assertFalse(self.client_manager.validate_auth(expired_token, "127.0.0.1")) @@ -93,116 +98,52 @@ def test_validate_auth(self): def test_check_refresh_request(self): valid_client = str(uuid4()) - tokens = self.client_manager._create_tokens({"client_id": valid_client, - "username": "test", - "password": "test", - "expire": time(), - "permissions": {}}) - self.assertEqual(tokens['client_id'], valid_client) + self.client_manager._access_token_lifetime = 60 + self.client_manager._refresh_token_lifetime = 3600 + access, refresh, config = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=valid_client) + access2, refresh2, config2 = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=str(uuid4())) + self.assertEqual(config['access'].client_id, valid_client) + self.assertEqual(config['refresh'].client_id, valid_client) # Test invalid refresh token with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - valid_client, + self.client_manager.check_refresh_request(access, access, valid_client) self.assertEqual(e.exception.status_code, 400) # Test incorrect access token with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['refresh_token'], - tokens['refresh_token'], + self.client_manager.check_refresh_request(access2, refresh, valid_client) self.assertEqual(e.exception.status_code, 403) # Test invalid client_id with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - tokens['refresh_token'], + self.client_manager.check_refresh_request(access, refresh, str(uuid4())) self.assertEqual(e.exception.status_code, 403) # Test valid refresh valid_refresh = self.client_manager.check_refresh_request( - tokens['access_token'], tokens['refresh_token'], - tokens['client_id']) - self.assertEqual(valid_refresh['client_id'], tokens['client_id']) - self.assertNotEqual(valid_refresh['access_token'], - tokens['access_token']) - self.assertNotEqual(valid_refresh['refresh_token'], - tokens['refresh_token']) + access, refresh, config['access'].client_id) + self.assertEqual(valid_refresh.client_id, config['access'].client_id) + self.assertNotEqual(valid_refresh.access_token, access) + self.assertNotEqual(valid_refresh.refresh_token, refresh) # Test expired refresh token real_refresh = self.client_manager._refresh_token_lifetime self.client_manager._refresh_token_lifetime = 0 - tokens = self.client_manager._create_tokens({"client_id": valid_client, - "username": "test", - "password": "test", - "expire": time(), - "permissions": {}}) + + access, refresh, config = self.client_manager._create_tokens( + user_id=str(uuid4()), client_id=valid_client) with self.assertRaises(HTTPException) as e: - self.client_manager.check_refresh_request(tokens['access_token'], - tokens['refresh_token'], - tokens['client_id']) + self.client_manager.check_refresh_request(access, refresh, + config['access'].client_id) self.assertEqual(e.exception.status_code, 401) self.client_manager._refresh_token_lifetime = real_refresh - def test_get_permissions(self): - from neon_hana.auth.permissions import ClientPermissions - - node_user = "node_test" - rest_user = "rest_user" - self.client_manager._node_username = node_user - self.client_manager._node_password = node_user - - rest_resp = self.client_manager.check_auth_request(rest_user, rest_user) - node_resp = self.client_manager.check_auth_request(node_user, node_user, - node_user) - node_fail = self.client_manager.check_auth_request("node_fail", - node_user, rest_user) - - rest_cid = rest_resp['client_id'] - node_cid = node_resp['client_id'] - fail_cid = node_fail['client_id'] - - permissive = ClientPermissions(True, True, True) - no_node = ClientPermissions(True, True, False) - no_perms = ClientPermissions(False, False, False) - - # Auth disabled, returns all True - self.client_manager._disable_auth = True - self.assertEqual(self.client_manager.get_permissions(rest_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(node_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(rest_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(fail_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions("fake_user"), - permissive) - - # Auth enabled - self.client_manager._disable_auth = False - self.assertEqual(self.client_manager.get_permissions(rest_cid), no_node) - self.assertEqual(self.client_manager.get_permissions(node_cid), - permissive) - self.assertEqual(self.client_manager.get_permissions(fail_cid), no_node) - self.assertEqual(self.client_manager.get_permissions("fake_user"), - no_perms) - - def test_client_permissions(self): - from neon_hana.auth.permissions import ClientPermissions - default_perms = ClientPermissions() - restricted_perms = ClientPermissions(False, False, False) - permissive_perms = ClientPermissions(True, True, True) - self.assertIsInstance(default_perms.as_dict(), dict) - for v in default_perms.as_dict().values(): - self.assertIsInstance(v, bool) - self.assertIsInstance(restricted_perms.as_dict(), dict) - self.assertFalse(any([v for v in restricted_perms.as_dict().values()])) - self.assertIsInstance(permissive_perms.as_dict(), dict) - self.assertTrue(all([v for v in permissive_perms.as_dict().values()])) - def test_stream_connections(self): # Test configured maximum self.client_manager._max_streaming_clients = 1 diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..53e6dc0 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,18 @@ +from unittest import TestCase + +from neon_hana.schema.user_profile import UserProfile + +from neon_data_models.models.user import User + + +class TestUserProfile(TestCase): + def test_user_profile(self): + # Test default + profile = UserProfile() + self.assertIsInstance(profile, UserProfile) + + # Test from User + default_user = User(username="test_user") + profile = UserProfile.from_user_object(default_user) + self.assertIsInstance(profile, UserProfile) + self.assertEqual(default_user.username, profile.user.username)