Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add endpoint to get client permissions #32

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions neon_hana/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from fastapi import APIRouter, Request
from fastapi import APIRouter, Request, Depends
from neon_data_models.models.user.database import PermissionsConfig
from neon_data_models.models.user import User

from neon_hana.app.dependencies import client_manager
from neon_hana.app.dependencies import client_manager, jwt_bearer
from neon_hana.schema.auth_requests import *
from neon_data_models.models.user import User

auth_route = APIRouter(prefix="/auth", tags=["authentication"])

Expand All @@ -51,3 +52,8 @@ async def register_user(register_request: RegistrationRequest,
request: Request) -> User:
return client_manager.check_registration_request(**dict(register_request),
origin_ip=request.client.host)


@auth_route.post("/permissions")
async def check_permissions(token: str = Depends(jwt_bearer)) -> PermissionsConfig:
return client_manager.get_token_permissions(token)
11 changes: 10 additions & 1 deletion neon_hana/auth/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,11 +344,20 @@ def get_token_data(self, token: str) -> HanaToken:
"""
Extract the user_id from a JWT string
@param token: JWT to parse
@retrun: user_id associated with token
@return: user_id associated with token
"""
return HanaToken(**jwt.decode(token, self._access_secret,
self._jwt_algo))

def get_token_permissions(self, token: str) -> PermissionsConfig:
"""
Get a PermissionsConfig object from a JWT Token
@param token: JWT to parse
@return: PermissionsConfig object representing the token permissions
"""
roles = self.get_token_data(token).roles
return PermissionsConfig.from_roles(roles)

def validate_auth(self, token: str, origin_ip: str) -> bool:
ratelimit_id = f"{origin_ip}-total"
if not self.rate_limiter.get_all_buckets(ratelimit_id):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
from uuid import uuid4

from fastapi import HTTPException
from jwt import DecodeError

from neon_data_models.enum import AccessRoles
from neon_data_models.models.user.database import PermissionsConfig


class TestClientManager(unittest.TestCase):
Expand Down Expand Up @@ -173,3 +177,23 @@ def test_stream_connections(self):
self.client_manager._max_streaming_clients = False
self.assertTrue(self.client_manager.check_connect_stream())
self.assertEqual(self.client_manager._connected_streams, 5)

def test_get_token_permissions(self):
permissions = PermissionsConfig(core=AccessRoles.USER,
diana=AccessRoles.USER,
node=AccessRoles.USER,
llm=AccessRoles.USER,
users=AccessRoles.USER)
valid_token, _, _ = self.client_manager._create_tokens("test_user",
"test_client",
"test_name",
permissions)

# Valid token decodes
self.assertEqual(self.client_manager.get_token_permissions(valid_token),
permissions)

# Invalid token raises exception
with self.assertRaises(DecodeError):
self.client_manager.get_token_permissions("invalid_token_string")

Loading