diff --git a/auth-api/src/auth_api/exceptions/errors.py b/auth-api/src/auth_api/exceptions/errors.py index 8e349a187e..a4f6b61d66 100644 --- a/auth-api/src/auth_api/exceptions/errors.py +++ b/auth-api/src/auth_api/exceptions/errors.py @@ -23,6 +23,7 @@ class Error(Enum): """Error Codes.""" + INVALID_ORG = "The organization ID is in an incorrect format.", HTTPStatus.BAD_REQUEST INVALID_INPUT = "Invalid input, please check.", HTTPStatus.BAD_REQUEST DATA_NOT_FOUND = "No matching record found.", HTTPStatus.NOT_FOUND DATA_ALREADY_EXISTS = "The data you want to insert already exists.", HTTPStatus.BAD_REQUEST diff --git a/auth-api/src/auth_api/models/product_subscription.py b/auth-api/src/auth_api/models/product_subscription.py index 64482456f0..b300d2404e 100644 --- a/auth-api/src/auth_api/models/product_subscription.py +++ b/auth-api/src/auth_api/models/product_subscription.py @@ -15,6 +15,7 @@ The ProductSubscription object connects Org models to one or more ProductSubscription models. """ +from typing import Self from sql_versioning import Versioned from sqlalchemy import Column, ForeignKey, Integer, and_ @@ -45,7 +46,7 @@ def find_by_org_ids(cls, org_ids, valid_statuses=VALID_SUBSCRIPTION_STATUSES): ).all() @classmethod - def find_by_org_id_product_code(cls, org_id: int, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES): + def find_by_org_id_product_code(cls, org_id: int, product_code, valid_statuses=VALID_SUBSCRIPTION_STATUSES) -> Self: """Find an product subscription instance that matches the provided id.""" return cls.query.filter( and_( diff --git a/auth-api/src/auth_api/models/task.py b/auth-api/src/auth_api/models/task.py index 0badd670ec..b31e0846a5 100644 --- a/auth-api/src/auth_api/models/task.py +++ b/auth-api/src/auth_api/models/task.py @@ -13,6 +13,7 @@ # limitations under the License. """This model manages a Task item in the Auth Service.""" import datetime as dt +from typing import Self import pytz from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, text @@ -93,14 +94,14 @@ def fetch_tasks(cls, task_search: TaskSearch): return pagination.items, pagination.total @classmethod - def find_by_task_id(cls, task_id: int): + def find_by_task_id(cls, task_id: int) -> Self: """Find a task instance that matches the provided id.""" return db.session.query(Task).filter_by(id=int(task_id or -1)).first() @classmethod def find_by_task_relationship_id( cls, relationship_id: int, task_relationship_type: str, task_status: str = TaskStatus.OPEN.value - ): + ) -> Self: """Find a task instance that related to the relationship id ( may be an ORG or a PRODUCT.""" return ( db.session.query(Task) @@ -112,6 +113,23 @@ def find_by_task_relationship_id( .first() ) + @classmethod + def find_by_incomplete_task_relationship_id( + cls, relationship_id: int, task_relationship_type: str, relationship_status: str = None) -> Self: + """Find a task instance that related to the relationship id ( may be an ORG or a PRODUCT) that is incomplete.""" + query = ( + db.session.query(Task) + .filter( + Task.relationship_id == int(relationship_id or -1), + Task.relationship_type == task_relationship_type, + Task.status.in_((TaskStatus.OPEN.value, TaskStatus.HOLD.value)) + )) + + if relationship_status is not None: + query = query.filter(Task.relationship_status == relationship_status) + + return query.first() + @classmethod def find_by_task_for_account(cls, org_id: int, status): """Find a task instance that matches the provided id.""" diff --git a/auth-api/src/auth_api/resources/v1/org_products.py b/auth-api/src/auth_api/resources/v1/org_products.py index acfc9e87fe..aabb7eabc1 100644 --- a/auth-api/src/auth_api/resources/v1/org_products.py +++ b/auth-api/src/auth_api/resources/v1/org_products.py @@ -18,7 +18,7 @@ from flask import Blueprint, g, request from flask_cors import cross_origin -from auth_api.exceptions import BusinessException +from auth_api.exceptions import BusinessException, Error from auth_api.schemas import utils as schema_utils from auth_api.services import Product as ProductService from auth_api.utils.auth import jwt as _jwt @@ -34,10 +34,8 @@ def get_org_product_subscriptions(org_id): """GET a new product subscription to the org using the request body.""" - if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0: - return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST - try: + validate_organization(org_id) include_hidden = request.args.get("include_hidden", None) == "true" # used by NDS response, status = ( json.dumps(ProductService.get_all_product_subscription(org_id=int(org_id), include_hidden=include_hidden)), @@ -54,15 +52,13 @@ def get_org_product_subscriptions(org_id): def post_org_product_subscription(org_id): """Post a new product subscription to the org using the request body.""" - if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0: - return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST - request_json = request.get_json() valid_format, errors = schema_utils.validate(request_json, "org_product_subscription") if not valid_format: return {"message": schema_utils.serialize(errors)}, HTTPStatus.BAD_REQUEST try: + validate_organization(org_id) roles = g.jwt_oidc_token_info.get("realm_access").get("roles") subscriptions = ProductService.create_product_subscription( int(org_id), request_json, skip_auth=Role.SYSTEM.value in roles, auto_approve=Role.SYSTEM.value in roles @@ -80,17 +76,35 @@ def post_org_product_subscription(org_id): def patch_org_product_subscription(org_id): """Patch existing product subscription to resubmit it for review.""" - if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0: - return {"message": "The organization ID is in an incorrect format."}, HTTPStatus.BAD_REQUEST - request_json = request.get_json() valid_format, errors = schema_utils.validate(request_json, "org_product_subscription") if not valid_format: return {"message": schema_utils.serialize(errors)}, HTTPStatus.BAD_REQUEST try: + validate_organization(org_id) subscriptions = ProductService.resubmit_product_subscription(int(org_id), request_json) response, status = {"subscriptions": subscriptions}, HTTPStatus.OK except BusinessException as exception: response, status = {"code": exception.code, "message": exception.message}, exception.status_code return response, status + + +@bp.route("/", methods=["DELETE", "OPTIONS"]) +@cross_origin(origins="*", methods=["DELETE"]) +@_jwt.has_one_of_roles([Role.STAFF_CREATE_ACCOUNTS.value, Role.PUBLIC_USER.value, Role.SYSTEM.value]) +def delete_product_subscription(org_id, product_code): + """Delete existing product subscription.""" + + try: + validate_organization(org_id) + subscriptions = ProductService.remove_product_subscription(int(org_id), product_code) + response, status = {"subscriptions": subscriptions}, HTTPStatus.OK + except BusinessException as exception: + response, status = {"code": exception.code, "message": exception.message}, exception.status_code + return response, status + + +def validate_organization(org_id): + if not org_id or org_id == "None" or not org_id.isdigit() or int(org_id) < 0: + raise BusinessException(Error.INVALID_ORG, None) diff --git a/auth-api/src/auth_api/services/products.py b/auth-api/src/auth_api/services/products.py index e078dc3fe9..af55bd75cf 100644 --- a/auth-api/src/auth_api/services/products.py +++ b/auth-api/src/auth_api/services/products.py @@ -143,13 +143,33 @@ def resubmit_product_subscription(org_id, subscription_data: Dict[str, Any], ski return Product.get_all_product_subscription(org_id=org_id, skip_auth=True) + @staticmethod + def _is_previously_approved(org_id: int, product_code: str): + """Check if this product has a task that was previously approved.""" + inactive_sub = (ProductSubscriptionModel + .find_by_org_id_product_code(org_id=org_id, + product_code=product_code, + valid_statuses=(ProductSubscriptionStatus.INACTIVE.value,))) + if not inactive_sub: + return False, None + + task = TaskModel.find_by_task_relationship_id( + inactive_sub.id, TaskRelationshipType.PRODUCT.value, TaskStatus.COMPLETED.value + ) + if (task is None + or (task.relationship_status != TaskRelationshipStatus.ACTIVE.value + and task.action == TaskAction.PRODUCT_REVIEW.value)): + return False, None + + return True, inactive_sub + @staticmethod def create_product_subscription( - org_id, - subscription_data: Dict[str, Any], # pylint: disable=too-many-locals - is_new_transaction: bool = True, - skip_auth=False, - auto_approve=False, + org_id, + subscription_data: Dict[str, Any], # pylint: disable=too-many-locals + is_new_transaction: bool = True, + skip_auth=False, + auto_approve=False, ): """Create product subscription for the user. @@ -176,15 +196,18 @@ def create_product_subscription( check_auth(system_required=True, org_id=org_id) # Check if product needs premium account, if yes skip and continue. if ( - flags.is_on("remove-premium-restrictions", default=False) is False - and product_model.premium_only - and org.type_code not in PREMIUM_ORG_TYPES + flags.is_on("remove-premium-restrictions", default=False) is False + and product_model.premium_only + and org.type_code not in PREMIUM_ORG_TYPES ): continue + previously_approved, inactive_sub = Product._is_previously_approved(org_id, product_code) + if previously_approved: + auto_approve = True subscription_status = Product.find_subscription_status(org, product_model, auto_approve) product_subscription = Product._subscribe_and_publish_activity( - org_id, product_code, subscription_status, product_model.description + org_id, product_code, subscription_status, product_model.description, inactive_sub ) # If there is a linked product, add subscription to that too. @@ -229,6 +252,32 @@ def create_product_subscription( return Product.get_all_product_subscription(org_id=org_id, skip_auth=True) + @staticmethod + def remove_product_subscription(org_id: int, product_code: str, skip_auth=False): + """Deactivate org product subscription by code.""" + org: OrgModel = OrgModel.find_by_org_id(org_id) + if not org: + raise BusinessException(Error.DATA_NOT_FOUND, None) + + if not skip_auth: + check_auth(one_of_roles=(*CLIENT_ADMIN_ROLES, STAFF), org_id=org_id) + + existing_sub = ProductSubscriptionModel.find_by_org_id_product_code(org_id, product_code) + + if existing_sub: + existing_sub.status_code = ProductSubscriptionStatus.INACTIVE.value + existing_sub.save() + + pending_task = TaskModel.find_by_incomplete_task_relationship_id( + relationship_id=existing_sub.id, + task_relationship_type=TaskRelationshipType.PRODUCT.value, + relationship_status=ProductSubscriptionStatus.PENDING_STAFF_REVIEW.value + ) + if pending_task: + pending_task.delete() + + return Product.get_all_product_subscription(org_id=org_id, skip_auth=True) + @staticmethod def _send_product_subscription_confirmation(product_notification_info: ProductNotificationInfo, org_id: int): admin_emails = UserService.get_admin_emails_for_org(org_id) @@ -256,11 +305,19 @@ def _update_parent_subscription(org_id, sub_product_model, subscription_status): @staticmethod def _subscribe_and_publish_activity( - org_id: int, product_code: str, status_code: str, product_model_description: str + org_id: int, product_code: str, status_code: str, product_model_description: str, + inactive_sub: ProductSubscriptionModel = None ): - subscription = ProductSubscriptionModel( - org_id=org_id, product_code=product_code, status_code=status_code - ).flush() + subscription = None + if inactive_sub: + subscription = inactive_sub + subscription.status_code = status_code + subscription.flush() + else: + subscription = ProductSubscriptionModel( + org_id=org_id, product_code=product_code, status_code=status_code + ).flush() + if status_code == ProductSubscriptionStatus.ACTIVE.value: ActivityLogPublisher.publish_activity( Activity(org_id, ActivityAction.ADD_PRODUCT_AND_SERVICE.value, name=product_model_description) @@ -269,7 +326,8 @@ def _subscribe_and_publish_activity( @staticmethod def _reset_subscription_and_review_task( - review_task: TaskModel, product_model: ProductCodeModel, subscription: ProductSubscriptionModel, user_id: str + review_task: TaskModel, product_model: ProductCodeModel, subscription: ProductSubscriptionModel, + user_id: str ): review_task.status = TaskStatus.OPEN.value review_task.related_to = user_id @@ -342,8 +400,8 @@ def create_subscription_from_bcol_profile(org_id: int, bcol_profile_flags: List[ org_id=org_id, product_code=product_code, status_code=ProductSubscriptionStatus.ACTIVE.value ).flush() elif ( - subscription - and (existing_sub := subscription).status_code != ProductSubscriptionStatus.ACTIVE.value + subscription + and (existing_sub := subscription).status_code != ProductSubscriptionStatus.ACTIVE.value ): existing_sub.status_code = ProductSubscriptionStatus.ACTIVE.value existing_sub.flush() @@ -375,9 +433,9 @@ def get_all_product_subscription(org_id, skip_auth=False, **kwargs): # Include hidden products only for staff and SBC staff include_hidden = ( - user_from_context.is_staff() - or org.type_code == OrgType.SBC_STAFF.value - or kwargs.get("include_hidden", False) + user_from_context.is_staff() + or org.type_code == OrgType.SBC_STAFF.value + or kwargs.get("include_hidden", False) ) products = Product.get_products(include_hidden=include_hidden, staff_check=False) @@ -447,7 +505,7 @@ def update_product_subscription(product_sub_info: ProductSubscriptionInfo, is_ne @staticmethod def approve_reject_parent_subscription( - parent_product_code: int, is_approved: bool, is_hold: bool, org_id: int, is_new_transaction: bool = True + parent_product_code: int, is_approved: bool, is_hold: bool, org_id: int, is_new_transaction: bool = True ): """Approve or reject Parent Product Subscription.""" logger.debug("