Skip to content

Commit

Permalink
Implement permissions endpoint with unit test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
NeonDaniel committed Dec 28, 2024
1 parent 994640a commit a1354dc
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
12 changes: 9 additions & 3 deletions neon_hana/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Depends
from neon_data_models.models.user.database import PermissionsConfig
from neon_data_models.models.user import User

from neon_hana.app.dependencies import client_manager
from neon_hana.app.dependencies import client_manager, jwt_bearer
from neon_hana.schema.auth_requests import *
from neon_data_models.models.user import User

auth_route = APIRouter(prefix="/auth", tags=["authentication"])

Expand All @@ -51,3 +52,8 @@ async def register_user(register_request: RegistrationRequest,
request: Request) -> User:
return client_manager.check_registration_request(**dict(register_request),
origin_ip=request.client.host)


@auth_route.post("/permissions")
async def check_permissions(token: str = Depends(jwt_bearer)) -> PermissionsConfig:
return client_manager.get_token_permissions(token)
11 changes: 10 additions & 1 deletion neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,20 @@ def get_token_data(self, token: str) -> HanaToken:
"""
Extract the user_id from a JWT string
@param token: JWT to parse
@retrun: user_id associated with token
@return: user_id associated with token
"""
return HanaToken(**jwt.decode(token, self._access_secret,
self._jwt_algo))

def get_token_permissions(self, token: str) -> PermissionsConfig:
"""
Get a PermissionsConfig object from a JWT Token
@param token: JWT to parse
@return: PermissionsConfig object representing the token permissions
"""
roles = self.get_token_data(token).roles
return PermissionsConfig.from_roles(roles)

def validate_auth(self, token: str, origin_ip: str) -> bool:
ratelimit_id = f"{origin_ip}-total"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from uuid import uuid4

from fastapi import HTTPException
from jwt import DecodeError

from neon_data_models.enum import AccessRoles
from neon_data_models.models.user.database import PermissionsConfig


class TestClientManager(unittest.TestCase):
Expand Down Expand Up @@ -173,3 +177,23 @@ def test_stream_connections(self):
self.client_manager._max_streaming_clients = False
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 5)

def test_get_token_permissions(self):
permissions = PermissionsConfig(core=AccessRoles.USER,
diana=AccessRoles.USER,
node=AccessRoles.USER,
llm=AccessRoles.USER,
users=AccessRoles.USER)
valid_token, _, _ = self.client_manager._create_tokens("test_user",
"test_client",
"test_name",
permissions)

# Valid token decodes
self.assertEqual(self.client_manager.get_token_permissions(valid_token),
permissions)

# Invalid token raises exception
with self.assertRaises(DecodeError):
self.client_manager.get_token_permissions("invalid_token_string")

0 comments on commit a1354dc

Please sign in to comment.