diff --git a/auth-api/src/auth_api/models/user.py b/auth-api/src/auth_api/models/user.py index fdefffdcc..b23cd8bfa 100644 --- a/auth-api/src/auth_api/models/user.py +++ b/auth-api/src/auth_api/models/user.py @@ -258,6 +258,8 @@ def _get_type(cls, user_from_context: UserContext) -> str: or user_from_context.login_source == LoginSource.BCROS.value ): user_type = Role.ANONYMOUS_USER.name + elif user_from_context.is_staff(): + user_type = Role.STAFF.name elif Role.GOV_ACCOUNT_USER.value in user_from_context.roles: user_type = Role.GOV_ACCOUNT_USER.name elif Role.PUBLIC_USER.value in user_from_context.roles or user_from_context.login_source in [ @@ -265,8 +267,6 @@ def _get_type(cls, user_from_context: UserContext) -> str: LoginSource.BCSC.value, ]: user_type = Role.PUBLIC_USER.name - elif user_from_context.is_staff(): - user_type = Role.STAFF.name elif user_from_context.is_system(): user_type = Role.SYSTEM.name diff --git a/auth-api/src/auth_api/resources/v1/user.py b/auth-api/src/auth_api/resources/v1/user.py index 8f4d4a272..5b58e5f8c 100644 --- a/auth-api/src/auth_api/resources/v1/user.py +++ b/auth-api/src/auth_api/resources/v1/user.py @@ -15,7 +15,7 @@ from http import HTTPStatus -from flask import Blueprint, abort, g, jsonify, request +from flask import Blueprint, abort, current_app, g, jsonify, request from flask_cors import cross_origin from auth_api.exceptions import BusinessException @@ -29,6 +29,7 @@ from auth_api.services.org import Org as OrgService from auth_api.services.user import User as UserService from auth_api.utils.auth import jwt as _jwt +from auth_api.utils.constants import GROUP_GOV_ACCOUNT_USERS from auth_api.utils.endpoints_enums import EndpointEnum from auth_api.utils.enums import LoginSource, Status from auth_api.utils.roles import Role @@ -108,6 +109,12 @@ def post_user(): if not valid_format: return {"message": schema_utils.serialize(errors)}, HTTPStatus.BAD_REQUEST + # Ensure STAFF doesn't have GOV_ACCOUNT_USER, otherwise they get extra permissions they shouldn't have. + roles = token.get("realm_access", {}).get("roles", []) + if Role.STAFF.name in roles and Role.GOV_ACCOUNT_USER.value in roles: + current_app.logger.info("Removing GOV_ACCOUNT_USER group from STAFF user") + KeycloakService.remove_user_from_group(token.get("sub"), GROUP_GOV_ACCOUNT_USERS) + user = UserService.save_from_jwt_token(request_json) response, status = user.as_dict(), HTTPStatus.CREATED # Add the user to public_users group if the user doesn't have public_user group diff --git a/auth-api/src/auth_api/services/keycloak.py b/auth-api/src/auth_api/services/keycloak.py index ff0f31513..22800e4e8 100644 --- a/auth-api/src/auth_api/services/keycloak.py +++ b/auth-api/src/auth_api/services/keycloak.py @@ -228,7 +228,7 @@ def remove_from_account_holders_group(keycloak_guid: str = None, **kwargs): keycloak_guid: Dict = user_from_context.sub if Role.ACCOUNT_HOLDER.value in user_from_context.roles: - KeycloakService._remove_user_from_group(keycloak_guid, GROUP_ACCOUNT_HOLDERS) + KeycloakService.remove_user_from_group(keycloak_guid, GROUP_ACCOUNT_HOLDERS) @staticmethod @user_context @@ -339,7 +339,7 @@ def add_user_to_group(user_id: str, group_name: str): response.raise_for_status() @staticmethod - def _remove_user_from_group(user_id: str, group_name: str): + def remove_user_from_group(user_id: str, group_name: str): """Remove user from the keycloak group.""" config = current_app.config base_url = config.get("KEYCLOAK_BASE_URL")