diff --git a/neon_data_models/enum.py b/neon_data_models/enum.py index 305fb6f..3fcd4ba 100644 --- a/neon_data_models/enum.py +++ b/neon_data_models/enum.py @@ -38,9 +38,11 @@ class AccessRoles(IntEnum): NONE = 0 GUEST = 1 USER = 2 - ADMIN = 3 - OWNER = 4 - + # 3-5 Reserved for "premium users" + ADMIN = 6 + # 7-8 Reserved for "restricted owners" + OWNER = 9 + # 10 Reserved for "unlimited access" NODE = -1 diff --git a/neon_data_models/models/api/__init__.py b/neon_data_models/models/api/__init__.py index 24ff4a3..5881279 100644 --- a/neon_data_models/models/api/__init__.py +++ b/neon_data_models/models/api/__init__.py @@ -26,3 +26,4 @@ from neon_data_models.models.api.node_v1 import * from neon_data_models.models.api.mq import * +from neon_data_models.models.api.jwt import * diff --git a/neon_data_models/models/api/jwt.py b/neon_data_models/models/api/jwt.py new file mode 100644 index 0000000..61680a3 --- /dev/null +++ b/neon_data_models/models/api/jwt.py @@ -0,0 +1,86 @@ +# NEON AI (TM) SOFTWARE, Software Development Kit & Application Development System +# All trademark and other rights reserved by their respective owners +# Copyright 2008-2024 Neongecko.com Inc. +# BSD-3 +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from this +# software without specific prior written permission. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, +# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, +# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from time import time +from typing import Optional, List, Literal +from uuid import uuid4 + +from pydantic import Field +from neon_data_models.enum import AccessRoles +from neon_data_models.models.base import BaseModel + + +class JWT(BaseModel): + iss: Optional[str] = Field(None, validate_default=True, + description="Token issuer") + sub: Optional[str] = Field(None, validate_default=True, + description="Unique token subject, ie a user ID") + exp: int = Field(description="Expiration time in epoch seconds") + iat: int = Field(description="Token creation time in epoch seconds") + jti: str = Field(description="Unique token identifier", + default_factory=lambda: str(uuid4())) + + client_id: str = Field(description="Client identifier") + roles: List[str] = Field(description="List of roles, " + "formatted as ` `. " + "See PermissionsConfig for role names") + + +class HanaToken(JWT): + def __init__(self, **kwargs): + from neon_data_models.models.user import PermissionsConfig + permissions = kwargs.get("permissions") + if permissions and isinstance(permissions, PermissionsConfig): + kwargs["roles"] = permissions.to_roles() + elif permissions and isinstance(permissions, dict): + core_permissions = AccessRoles.GUEST if \ + permissions.get("assist") else AccessRoles.NONE + diana_permissions = AccessRoles.GUEST if \ + permissions.get("backend") else AccessRoles.NONE + node_permissions = AccessRoles.USER if \ + permissions.get("node") else AccessRoles.NONE + kwargs["roles"] = [f"core {core_permissions.value}", + f"diana {diana_permissions.value}", + f"node {node_permissions.value}"] + if kwargs.get("expire") and isinstance(kwargs["expire"], float): + kwargs["expire"] = round(kwargs["expire"]) + BaseModel.__init__(self, **kwargs) + + # Private parameters + token_name: str = Field(default="", + description="Friendly name to identify this token.") + creation_timestamp: int = Field(default_factory=lambda: int(time()), + description="Timestamp of initial token " + "creation (not counting " + "refreshes).") + last_refresh_timestamp: Optional[int] = Field(default=None, + description="Timestamp of " + "most recent " + "refresh.") + purpose: Literal["access", "refresh"] = "access" + + +__all__ = [JWT.__name__, HanaToken.__name__] diff --git a/neon_data_models/models/api/node_v1/__init__.py b/neon_data_models/models/api/node_v1/__init__.py index 6ca1cc4..3440595 100644 --- a/neon_data_models/models/api/node_v1/__init__.py +++ b/neon_data_models/models/api/node_v1/__init__.py @@ -80,8 +80,8 @@ class TextInputData(BaseModel): class NodeKlatResponse(BaseMessage): msg_type: Literal["klat.response"] = "klat.response" - data: Dict[str, KlatResponse] = Field(type=Dict[str, KlatResponse], - description="dict of BCP-47 language: KlatResponse") + data: Dict[str, KlatResponse] = Field( + description="dict of BCP-47 language: KlatResponse") class NodeAudioInputResponse(BaseMessage): @@ -97,8 +97,7 @@ class NodeGetSttResponse(BaseMessage): class NodeGetTtsResponse(BaseMessage): msg_type: Literal["neon.get_tts.response"] = "neon.get_tts.response" data: Dict[str, KlatResponse] = ( - Field(type=Dict[str, KlatResponse], - description="dict of BCP-47 language: KlatResponse")) + Field(description="dict of BCP-47 language: KlatResponse")) class CoreWWDetected(BaseMessage): diff --git a/neon_data_models/models/user/database.py b/neon_data_models/models/user/database.py index dd842c0..af6f5fc 100644 --- a/neon_data_models/models/user/database.py +++ b/neon_data_models/models/user/database.py @@ -26,7 +26,10 @@ from time import time from typing import Dict, Any, List, Literal, Optional +from typing_extensions import deprecated from uuid import uuid4 + +from neon_data_models.models.api.jwt import HanaToken from neon_data_models.models.base import BaseModel from pydantic import Field from datetime import date @@ -124,7 +127,28 @@ class PermissionsConfig(BaseModel): class Config: use_enum_values = True - + @classmethod + def from_roles(cls, roles: List[str]): + """ + Parse PermissionsConfig from standard JWT roles configuration. + """ + kwargs = {} + for role in roles: + name, value = role.split(' ') + kwargs[name] = AccessRoles[value.upper()] + return cls(**kwargs) + + def to_roles(self): + """ + Dump a PermissionsConfig to standard JWT roles to be included in a JWT. + """ + roles = [] + for key, val in self.model_dump().items(): + roles.append(f"{key} {AccessRoles(val).name}") + return roles + + +@deprecated(f"Use `neon_data_models.models.api.jwt.HanaToken`") class TokenConfig(BaseModel): username: str client_id: str @@ -136,9 +160,9 @@ class TokenConfig(BaseModel): description="Unix timestamp of refresh token expiration") token_name: str creation_timestamp: int = Field( - description="Unix timestamp of auth token creation") + description="Unix timestamp of token creation (auth+refresh)") last_refresh_timestamp: int = Field( - description="Unix timestamp of last auth token refresh") + description="Unix timestamp of last token refresh (auth+refresh)") access_token: Optional[str] = None @@ -151,12 +175,11 @@ class User(BaseModel): klat: KlatConfig = KlatConfig() llm: BrainForgeConfig = BrainForgeConfig() permissions: PermissionsConfig = PermissionsConfig() - tokens: Optional[List[TokenConfig]] = [] + tokens: Optional[List[HanaToken]] = [] def __eq__(self, other): return self.model_dump() == other.model_dump() __all__ = [NeonUserConfig.__name__, KlatConfig.__name__, - BrainForgeConfig.__name__, PermissionsConfig.__name__, - TokenConfig.__name__, User.__name__] + BrainForgeConfig.__name__, PermissionsConfig.__name__, User.__name__] diff --git a/tests/models/test_user.py b/tests/models/test_user.py index 8762651..c67c862 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -26,9 +26,13 @@ from time import time from unittest import TestCase +from uuid import uuid4 + from pydantic import ValidationError from datetime import date -from neon_data_models.models.user.database import NeonUserConfig, TokenConfig, User + +from neon_data_models.models.api.jwt import HanaToken +from neon_data_models.models.user.database import NeonUserConfig, User, PermissionsConfig class TestDatabase(TestCase): @@ -83,17 +87,17 @@ def test_neon_user_config(self): def test_user(self): user_kwargs = dict(username="test", password_hash="test", - tokens=[{"username": "test", - "client_id": "test_id", - "permissions": {}, - "refresh_token": "", - "expiration": round(time()), - "refresh_expiration": round(time()), - "token_name": "test_token", + tokens=[{"token_name": "test_token", + "jti": str(uuid4()), + "sub": str(uuid4()), + "client_id": str(uuid4()), + "roles": PermissionsConfig().to_roles(), + "iat": round(time()) - 1, + "exp": round(time()) + 1, "creation_timestamp": round(time()), "last_refresh_timestamp": round(time())}]) default_user = User(**user_kwargs) - self.assertIsInstance(default_user.tokens[0], TokenConfig) + self.assertIsInstance(default_user.tokens[0], HanaToken) with self.assertRaises(ValidationError): User() @@ -105,6 +109,65 @@ def test_user(self): self.assertNotEqual(default_user, duplicate_user) self.assertEqual(default_user.tokens, duplicate_user.tokens) + def test_permissions_config(self): + from neon_data_models.models.user.database import PermissionsConfig + from neon_data_models.enum import AccessRoles + + # Test Default + default_config = PermissionsConfig() + for _, value in default_config.model_dump().items(): + self.assertEqual(value, AccessRoles.NONE) + + test_config = PermissionsConfig(klat=AccessRoles.USER, + core=AccessRoles.GUEST, + diana=AccessRoles.GUEST, + node=AccessRoles.NODE, + hub=AccessRoles.NODE, + llm=AccessRoles.NONE) + # Test dump/load + self.assertEqual(PermissionsConfig(**test_config.model_dump()), + test_config) + + # Test to/from roles + roles = test_config.to_roles() + self.assertIsInstance(roles, list) + for role in roles: + self.assertEqual(len(role.split()), 2) + self.assertEqual(PermissionsConfig.from_roles(roles), test_config) + + def test_token_config(self): + from neon_data_models.models.user.database import PermissionsConfig + token_id = str(uuid4()) + user_id = str(uuid4()) + client_id = str(uuid4()) + token_name = "Test Token" + permissions = PermissionsConfig() + refresh_expiration = round(time()) + 3600 + creation = round(time()) - 3600 + last_refresh = round(time()) + + from_database = HanaToken(token_name=token_name, + jti=token_id, + sub=user_id, + client_id=client_id, + roles=permissions.to_roles(), + exp=refresh_expiration, + iat=creation, + last_refresh_timestamp=last_refresh) + + from_token = HanaToken(jti=token_id, + sub=user_id, + iat=creation, + exp=refresh_expiration, + token_name=token_name, + client_id=client_id, + permissions=permissions, + last_refresh_timestamp=last_refresh) + + self.assertEqual(from_database, from_token) + self.assertEqual(from_database.model_dump_json(), + from_token.model_dump_json()) + class TestNeonProfile(TestCase): def test_create(self):