Skip to content

Commit

Permalink
fucking mess
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Mar 27, 2024
1 parent 02b4374 commit c0af7d4
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 158 deletions.
8 changes: 6 additions & 2 deletions diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def create_app_inner(
for entry_point in select_from_extension(group="diracx.access_policies")
]
)

available_access_policy_names = []
for access_policy_name in available_access_policy_names:

access_policy_classes = BaseAccessPolicy.available_implementations(
access_policy_name
)
Expand All @@ -95,6 +94,11 @@ def create_app_inner(
app.dependency_overrides[access_policy_class.check] = partial(
check_permissions, access_policy
)
from diracx.routers.job_manager.access_policies import WMSAccessPolicy

app.dependency_overrides[WMSAccessPolicy.check] = partial(
check_permissions, WMSAccessPolicy()
)

fail_startup = True
# Add the SQL DBs to the application
Expand Down
Empty file.
5 changes: 3 additions & 2 deletions diracx-routers/src/diracx/routers/job_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Annotated, Any, TypedDict
from typing import Annotated, Any, Callable, TypedDict

from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query
from pydantic import BaseModel, root_validator
Expand Down Expand Up @@ -32,9 +32,10 @@
from ..auth import AuthorizedUserInfo, verify_dirac_access_token
from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB
from ..fastapi_classes import DiracxRouter
from .access_policies import ActionType, WMSAccessPolicyCallable
from .access_policies import ActionType, WMSAccessPolicy
from .sandboxes import router as sandboxes_router

WMSAccessPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)]
MAX_PARAMETRIC_JOBS = 20

logger = logging.getLogger(__name__)
Expand Down
186 changes: 186 additions & 0 deletions diracx-routers/src/diracx/routers/job_manager/access_policies copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# from __future__ import annotations

# import contextlib
# import functools
# import os
# from enum import StrEnum, auto
# from typing import Annotated, AsyncIterator, Callable, Self

# from fastapi import Depends, HTTPException, status

# from diracx.core.extensions import select_from_extension
# from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
# from diracx.db.sql import JobDB

# from ..auth import AuthorizedUserInfo, verify_dirac_access_token


# class ActionType(StrEnum):
# CREATE = auto()
# READ = auto()
# MANAGE = auto()
# QUERY = auto()


# async def default_wms_policy(
# user_info: AuthorizedUserInfo,
# /,
# *,
# action: ActionType,
# job_db: JobDB,
# job_ids: list[int] | None = None,
# ):
# """Implement the JobPolicy"""
# if action == ActionType.CREATE:
# if job_ids is not None:
# raise NotImplementedError(
# "job_ids is not None with ActionType.CREATE. This shouldn't happen"
# )
# if NORMAL_USER not in user_info.properties:
# raise HTTPException(status.HTTP_403_FORBIDDEN)
# return

# if JOB_ADMINISTRATOR in user_info.properties:
# return

# if NORMAL_USER not in user_info.properties:
# raise HTTPException(status.HTTP_403_FORBIDDEN)

# if action == ActionType.QUERY:
# if job_ids is not None:
# raise NotImplementedError(
# "job_ids is not None with ActionType.QUERY. This shouldn't happen"
# )
# return

# if job_ids is None:
# raise NotImplementedError("job_ids is None. his shouldn't happen")

# # TODO: check the CS global job monitoring flag

# job_owners = await job_db.summary(
# ["Owner", "VO"],
# [{"parameter": "JobID", "operator": "in", "values": job_ids}],
# )

# expected_owner = {
# "Owner": user_info.preferred_username,
# "VO": user_info.vo,
# "count": len(set(job_ids)),
# }
# # All the jobs belong to the user doing the query
# # and all of them are present
# if job_owners == [expected_owner]:
# return

# raise HTTPException(status.HTTP_403_FORBIDDEN)


# class BaseAccessPolicy:

# policy: Callable

# @classmethod
# def check(cls) -> Self:
# raise NotImplementedError("This should never be called")

# @contextlib.asynccontextmanager
# async def lifetime_function(self) -> AsyncIterator[None]:
# """A context manager that can be used to run code at startup and shutdown."""
# yield

# @classmethod
# def available_implementations(
# cls, access_policy_name: str
# ) -> list[type[BaseAccessPolicy]]:
# """Return the available implementations of the AccessPolicy in reverse priority order."""
# policy_classes: list[type[BaseAccessPolicy]] = [
# entry_point.load()
# for entry_point in select_from_extension(
# group="diracx.access_policies", name=access_policy_name
# )
# ]
# if not policy_classes:
# raise NotImplementedError(
# f"Could not find any matches for {access_policy_name=}"
# )
# return policy_classes


# class WMSAccessPolicy(BaseAccessPolicy):
# policy = staticmethod(default_wms_policy)


# def check_permissions(
# access_policy_instance,
# user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
# ):
# """
# This is what every route should depend on to check user permissions.

# It yield an access policy that needs to be checked.
# If this is declared as a dependency but not called
# """
# has_been_called = False

# # # TODO: query the CS to find the actual policy
# # policy = default_wms_policy

# @functools.wraps(access_policy_instance.policy)
# async def wrapped_policy(**kwargs):
# """This wrapper is just to update the has_been_called flag"""
# nonlocal has_been_called
# has_been_called = True
# return await access_policy_instance.policy(user_info, **kwargs)

# try:
# yield wrapped_policy
# finally:
# if not has_been_called:
# # TODO nice error message with inspect
# # That should really not happen
# print(
# "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION",
# "(PS: I hope you are in a CI)",
# flush=True,
# )
# os._exit(1)


# # def check_permissions_alone(
# # policy,
# # user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
# # ):
# # """
# # This is what every route should depend on to check user permissions.

# # It yield an access policy that needs to be checked.
# # If this is declared as a dependency but not called
# # """
# # has_been_called = False

# # # # TODO: query the CS to find the actual policy
# # # policy = default_wms_policy

# # @functools.wraps(policy)
# # async def wrapped_policy(**kwargs):
# # """This wrapper is just to update the has_been_called flag"""
# # nonlocal has_been_called
# # has_been_called = True
# # return await policy(user_info, **kwargs)

# # try:
# # yield wrapped_policy
# # finally:
# # if not has_been_called:
# # # TODO nice error message with inspect
# # # That should really not happen
# # print(
# # "THIS SHOULD NOT HAPPEN, ALWAYS VERIFY PERMISSION",
# # "(PS: I hope you are in a CI)",
# # flush=True,
# # )
# # os._exit(1)


# WMSAccessPolicyCallable = Annotated[Callable, Depends(WMSAccessPolicy.check)]
Loading

0 comments on commit c0af7d4

Please sign in to comment.