From 1c298abd9edd967437beff8a035babba53eb2f4e Mon Sep 17 00:00:00 2001 From: Peter Boers Date: Tue, 19 Sep 2023 08:35:28 -0400 Subject: [PATCH] Refactor reporter handling to be able to pass it through request params (#355) --- .../api/api_v1/endpoints/processes.py | 37 +++++++++---------- orchestrator/app.py | 7 +--- orchestrator/db/filters/resource_type.py | 4 +- orchestrator/schemas/__init__.py | 4 ++ orchestrator/schemas/process.py | 6 +++ 5 files changed, 30 insertions(+), 28 deletions(-) diff --git a/orchestrator/api/api_v1/endpoints/processes.py b/orchestrator/api/api_v1/endpoints/processes.py index 7ee7a0c80..9dfbe40d9 100644 --- a/orchestrator/api/api_v1/endpoints/processes.py +++ b/orchestrator/api/api_v1/endpoints/processes.py @@ -44,10 +44,11 @@ ProcessDeprecationsSchema, ProcessIdSchema, ProcessResumeAllSchema, + ProcessStatusCounts, ProcessSubscriptionBaseSchema, ProcessSubscriptionSchema, + Reporter, ) -from orchestrator.schemas.process import ProcessStatusCounts from orchestrator.security import oidc_user from orchestrator.services.processes import ( SYSTEM_USER, @@ -85,6 +86,14 @@ def check_global_lock() -> None: ) +def resolve_user_name( + reporter: Optional[Reporter] = None, resolved_user: Optional[OIDCUserModel] = Depends(oidc_user) +) -> str: + if reporter: + return reporter + return resolved_user.name if resolved_user and resolved_user.name else SYSTEM_USER + + @router.delete("/{process_id}", response_model=None, status_code=HTTPStatus.NO_CONTENT) def delete(process_id: UUID) -> None: process = ProcessTable.query.filter_by(process_id=process_id).one_or_none() @@ -108,11 +117,10 @@ def new_process( workflow_key: str, request: Request, json_data: Optional[List[Dict[str, Any]]] = Body(...), - user: Optional[OIDCUserModel] = Depends(oidc_user), + user: str = Depends(resolve_user_name), ) -> Dict[str, UUID]: - user_name = user.name if user and user.name else SYSTEM_USER broadcast_func = api_broadcast_process_data(request) - process_id = start_process(workflow_key, user_inputs=json_data, user=user_name, broadcast_func=broadcast_func) + process_id = start_process(workflow_key, user_inputs=json_data, user=user, broadcast_func=broadcast_func) return {"id": process_id} @@ -124,7 +132,7 @@ def new_process( dependencies=[Depends(check_global_lock, use_cache=False)], ) def resume_process_endpoint( - process_id: UUID, request: Request, json_data: JSON = Body(...), user: Optional[OIDCUserModel] = Depends(oidc_user) + process_id: UUID, request: Request, json_data: JSON = Body(...), user: str = Depends(resolve_user_name) ) -> None: process = _get_process(process_id) @@ -137,26 +145,20 @@ def resume_process_endpoint( if process.last_status == ProcessStatus.RESUMED: raise_status(HTTPStatus.CONFLICT, "Resuming a resumed workflow is not possible") - user_name = user.name if user and user.name else SYSTEM_USER - broadcast_func = api_broadcast_process_data(request) - resume_process(process, user=user_name, user_inputs=json_data, broadcast_func=broadcast_func) + resume_process(process, user=user, user_inputs=json_data, broadcast_func=broadcast_func) @router.put( "/resume-all", response_model=ProcessResumeAllSchema, dependencies=[Depends(check_global_lock, use_cache=False)] ) -async def resume_all_processess_endpoint( - request: Request, user: Optional[OIDCUserModel] = Depends(oidc_user) -) -> Dict[str, int]: +async def resume_all_processess_endpoint(request: Request, user: str = Depends(resolve_user_name)) -> Dict[str, int]: """Retry all task processes in status Failed, Waiting, API Unavailable or Inconsistent Data. The retry is started in the background, returning status 200 and number of processes in message. When it is already running, refuse and return status 409 instead. """ - user_name = user.name if user and user.name else SYSTEM_USER - # Retrieve processes eligible for resuming processes_to_resume = ( ProcessTable.query.filter( @@ -174,7 +176,7 @@ async def resume_all_processess_endpoint( ) broadcast_func = api_broadcast_process_data(request) - if not await _async_resume_processes(processes_to_resume, user_name, broadcast_func=broadcast_func): + if not await _async_resume_processes(processes_to_resume, user, broadcast_func=broadcast_func): raise_status(HTTPStatus.CONFLICT, "Another request to resume all processes is in progress") logger.info("Resuming all processes", count=len(processes_to_resume)) @@ -183,15 +185,12 @@ async def resume_all_processess_endpoint( @router.put("/{process_id}/abort", response_model=None, status_code=HTTPStatus.NO_CONTENT) -def abort_process_endpoint( - process_id: UUID, request: Request, user: Optional[OIDCUserModel] = Depends(oidc_user) -) -> None: +def abort_process_endpoint(process_id: UUID, request: Request, user: str = Depends(resolve_user_name)) -> None: process = _get_process(process_id) - user_name = user.name if user and user.name else SYSTEM_USER broadcast_func = api_broadcast_process_data(request) try: - abort_process(process, user_name, broadcast_func=broadcast_func) + abort_process(process, user, broadcast_func=broadcast_func) return except Exception as e: raise_status(HTTPStatus.INTERNAL_SERVER_ERROR, str(e)) diff --git a/orchestrator/app.py b/orchestrator/app.py index b95a829a5..6f855535f 100644 --- a/orchestrator/app.py +++ b/orchestrator/app.py @@ -41,12 +41,7 @@ from orchestrator.distlock import init_distlock_manager from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY, SubscriptionModel from orchestrator.exception_handlers import problem_detail_handler -from orchestrator.graphql import ( - Mutation, - Query, - create_graphql_router, - register_domain_models, -) +from orchestrator.graphql import Mutation, Query, create_graphql_router, register_domain_models from orchestrator.services.processes import ProcessDataBroadcastThread from orchestrator.settings import AppSettings, ExecutorType, app_settings from orchestrator.utils.vlans import VlanRanges diff --git a/orchestrator/db/filters/resource_type.py b/orchestrator/db/filters/resource_type.py index 09da4e36a..c5ab99f24 100644 --- a/orchestrator/db/filters/resource_type.py +++ b/orchestrator/db/filters/resource_type.py @@ -5,9 +5,7 @@ from orchestrator.db import ProductBlockTable, ResourceTypeTable from orchestrator.db.database import SearchQuery from orchestrator.db.filters import generic_filter -from orchestrator.db.filters.generic_filters import ( - generic_is_like_filter, -) +from orchestrator.db.filters.generic_filters import generic_is_like_filter logger = structlog.get_logger(__name__) diff --git a/orchestrator/schemas/__init__.py b/orchestrator/schemas/__init__.py index a512c28c6..704fa4bb4 100644 --- a/orchestrator/schemas/__init__.py +++ b/orchestrator/schemas/__init__.py @@ -25,8 +25,10 @@ ProcessIdSchema, ProcessResumeAllSchema, ProcessSchema, + ProcessStatusCounts, ProcessSubscriptionBaseSchema, ProcessSubscriptionSchema, + Reporter, ) from orchestrator.schemas.product import ProductBaseSchema, ProductCRUDSchema, ProductSchema from orchestrator.schemas.product_block import ProductBlockBaseSchema, ProductBlockEnrichedSchema @@ -68,4 +70,6 @@ "WorkerStatus", "WorkflowSchema", "WorkflowWithProductTagsSchema", + "Reporter", + "ProcessStatusCounts", ) diff --git a/orchestrator/schemas/process.py b/orchestrator/schemas/process.py index 550d76d63..cb354bb0a 100644 --- a/orchestrator/schemas/process.py +++ b/orchestrator/schemas/process.py @@ -15,6 +15,8 @@ from typing import Any, Dict, List, Optional from uuid import UUID +from pydantic import ConstrainedStr + from orchestrator.config.assignee import Assignee from orchestrator.schemas.base import OrchestratorBaseModel from orchestrator.schemas.subscription import SubscriptionSchema @@ -110,3 +112,7 @@ class ProcessResumeAllSchema(OrchestratorBaseModel): class ProcessStatusCounts(OrchestratorBaseModel): process_counts: Dict[ProcessStatus, int] task_counts: Dict[ProcessStatus, int] + + +class Reporter(ConstrainedStr): + max_length = 100