Skip to content

Commit

Permalink
Refactor reporter handling to be able to pass it through request parameters (#355)
Browse files Browse the repository at this point in the history
  • Loading branch information
pboers1988 authored Sep 19, 2023
1 parent 937b0ce commit 1c298ab
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 28 deletions.
37 changes: 18 additions & 19 deletions orchestrator/api/api_v1/endpoints/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@
ProcessDeprecationsSchema,
ProcessIdSchema,
ProcessResumeAllSchema,
ProcessStatusCounts,
ProcessSubscriptionBaseSchema,
ProcessSubscriptionSchema,
Reporter,
)
from orchestrator.schemas.process import ProcessStatusCounts
from orchestrator.security import oidc_user
from orchestrator.services.processes import (
SYSTEM_USER,
Expand Down Expand Up @@ -85,6 +86,14 @@ def check_global_lock() -> None:
)


def resolve_user_name(
    reporter: Optional[Reporter] = None, resolved_user: Optional[OIDCUserModel] = Depends(oidc_user)
) -> str:
    """Resolve the user name to attribute a process action to.

    An explicitly supplied ``reporter`` request parameter takes precedence.
    Otherwise the name of the OIDC-authenticated user is used, falling back
    to ``SYSTEM_USER`` when no authenticated user (or name) is available.
    """
    if reporter:
        return reporter
    if resolved_user and resolved_user.name:
        return resolved_user.name
    return SYSTEM_USER


@router.delete("/{process_id}", response_model=None, status_code=HTTPStatus.NO_CONTENT)
def delete(process_id: UUID) -> None:
process = ProcessTable.query.filter_by(process_id=process_id).one_or_none()
Expand All @@ -108,11 +117,10 @@ def new_process(
workflow_key: str,
request: Request,
json_data: Optional[List[Dict[str, Any]]] = Body(...),
user: Optional[OIDCUserModel] = Depends(oidc_user),
user: str = Depends(resolve_user_name),
) -> Dict[str, UUID]:
user_name = user.name if user and user.name else SYSTEM_USER
broadcast_func = api_broadcast_process_data(request)
process_id = start_process(workflow_key, user_inputs=json_data, user=user_name, broadcast_func=broadcast_func)
process_id = start_process(workflow_key, user_inputs=json_data, user=user, broadcast_func=broadcast_func)

return {"id": process_id}

Expand All @@ -124,7 +132,7 @@ def new_process(
dependencies=[Depends(check_global_lock, use_cache=False)],
)
def resume_process_endpoint(
process_id: UUID, request: Request, json_data: JSON = Body(...), user: Optional[OIDCUserModel] = Depends(oidc_user)
process_id: UUID, request: Request, json_data: JSON = Body(...), user: str = Depends(resolve_user_name)
) -> None:
process = _get_process(process_id)

Expand All @@ -137,26 +145,20 @@ def resume_process_endpoint(
if process.last_status == ProcessStatus.RESUMED:
raise_status(HTTPStatus.CONFLICT, "Resuming a resumed workflow is not possible")

user_name = user.name if user and user.name else SYSTEM_USER

broadcast_func = api_broadcast_process_data(request)
resume_process(process, user=user_name, user_inputs=json_data, broadcast_func=broadcast_func)
resume_process(process, user=user, user_inputs=json_data, broadcast_func=broadcast_func)


@router.put(
"/resume-all", response_model=ProcessResumeAllSchema, dependencies=[Depends(check_global_lock, use_cache=False)]
)
async def resume_all_processess_endpoint(
request: Request, user: Optional[OIDCUserModel] = Depends(oidc_user)
) -> Dict[str, int]:
async def resume_all_processess_endpoint(request: Request, user: str = Depends(resolve_user_name)) -> Dict[str, int]:
"""Retry all task processes in status Failed, Waiting, API Unavailable or Inconsistent Data.
The retry is started in the background, returning status 200 and number of processes in message.
When it is already running, refuse and return status 409 instead.
"""

user_name = user.name if user and user.name else SYSTEM_USER

# Retrieve processes eligible for resuming
processes_to_resume = (
ProcessTable.query.filter(
Expand All @@ -174,7 +176,7 @@ async def resume_all_processess_endpoint(
)

broadcast_func = api_broadcast_process_data(request)
if not await _async_resume_processes(processes_to_resume, user_name, broadcast_func=broadcast_func):
if not await _async_resume_processes(processes_to_resume, user, broadcast_func=broadcast_func):
raise_status(HTTPStatus.CONFLICT, "Another request to resume all processes is in progress")

logger.info("Resuming all processes", count=len(processes_to_resume))
Expand All @@ -183,15 +185,12 @@ async def resume_all_processess_endpoint(


@router.put("/{process_id}/abort", response_model=None, status_code=HTTPStatus.NO_CONTENT)
def abort_process_endpoint(process_id: UUID, request: Request, user: str = Depends(resolve_user_name)) -> None:
    """Abort a running process.

    Args:
        process_id: id of the process to abort.
        request: incoming request; used to build the websocket broadcast function.
        user: name of the acting user, resolved from the ``reporter`` request
            parameter or the OIDC user (see ``resolve_user_name``).

    Raises:
        HTTPException: INTERNAL_SERVER_ERROR when aborting the process fails.
    """
    process = _get_process(process_id)

    broadcast_func = api_broadcast_process_data(request)
    try:
        abort_process(process, user, broadcast_func=broadcast_func)
        return
    except Exception as e:
        # Surface any abort failure as a 500 with the original message.
        raise_status(HTTPStatus.INTERNAL_SERVER_ERROR, str(e))
Expand Down
7 changes: 1 addition & 6 deletions orchestrator/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@
from orchestrator.distlock import init_distlock_manager
from orchestrator.domain import SUBSCRIPTION_MODEL_REGISTRY, SubscriptionModel
from orchestrator.exception_handlers import problem_detail_handler
from orchestrator.graphql import (
Mutation,
Query,
create_graphql_router,
register_domain_models,
)
from orchestrator.graphql import Mutation, Query, create_graphql_router, register_domain_models
from orchestrator.services.processes import ProcessDataBroadcastThread
from orchestrator.settings import AppSettings, ExecutorType, app_settings
from orchestrator.utils.vlans import VlanRanges
Expand Down
4 changes: 1 addition & 3 deletions orchestrator/db/filters/resource_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from orchestrator.db import ProductBlockTable, ResourceTypeTable
from orchestrator.db.database import SearchQuery
from orchestrator.db.filters import generic_filter
from orchestrator.db.filters.generic_filters import (
generic_is_like_filter,
)
from orchestrator.db.filters.generic_filters import generic_is_like_filter

logger = structlog.get_logger(__name__)

Expand Down
4 changes: 4 additions & 0 deletions orchestrator/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
ProcessIdSchema,
ProcessResumeAllSchema,
ProcessSchema,
ProcessStatusCounts,
ProcessSubscriptionBaseSchema,
ProcessSubscriptionSchema,
Reporter,
)
from orchestrator.schemas.product import ProductBaseSchema, ProductCRUDSchema, ProductSchema
from orchestrator.schemas.product_block import ProductBlockBaseSchema, ProductBlockEnrichedSchema
Expand Down Expand Up @@ -68,4 +70,6 @@
"WorkerStatus",
"WorkflowSchema",
"WorkflowWithProductTagsSchema",
"Reporter",
"ProcessStatusCounts",
)
6 changes: 6 additions & 0 deletions orchestrator/schemas/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Any, Dict, List, Optional
from uuid import UUID

from pydantic import ConstrainedStr

from orchestrator.config.assignee import Assignee
from orchestrator.schemas.base import OrchestratorBaseModel
from orchestrator.schemas.subscription import SubscriptionSchema
Expand Down Expand Up @@ -110,3 +112,7 @@ class ProcessResumeAllSchema(OrchestratorBaseModel):
class ProcessStatusCounts(OrchestratorBaseModel):
    """Counts of processes and tasks, keyed by their ProcessStatus."""

    # Count per status for processes (presumably workflow processes as
    # opposed to tasks — confirm against the code that populates this schema).
    process_counts: Dict[ProcessStatus, int]
    # Count per status for task processes.
    task_counts: Dict[ProcessStatus, int]


class Reporter(ConstrainedStr):
    """Name of the user reporting a process action.

    A pydantic constrained string, limited to 100 characters, suitable for
    use as a request parameter.
    """

    max_length = 100

0 comments on commit 1c298ab

Please sign in to comment.