Skip to content

Commit

Permalink
Implement Policy Mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Jun 3, 2024
1 parent a0b0af6 commit 01b4530
Show file tree
Hide file tree
Showing 29 changed files with 1,218 additions and 162 deletions.
21 changes: 12 additions & 9 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ description = "TODO"
readme = "README.md"
requires-python = ">=3.10"
keywords = []
license = {text = "GPL-3.0-only"}
license = { text = "GPL-3.0-only" }
classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
Expand All @@ -20,7 +20,7 @@ dependencies = [
"dirac",
"diracx-core",
"diracx-db",
"python-dotenv", # TODO: We might not need this
"python-dotenv", # TODO: We might not need this
"python-multipart",
"fastapi",
"httpx",
Expand All @@ -35,11 +35,7 @@ dependencies = [
dynamic = ["version"]

[project.optional-dependencies]
testing = [
"diracx-testing",
"moto[server]",
"pytest-httpx",
]
testing = ["diracx-testing", "moto[server]", "pytest-httpx"]
types = [
"boto3-stubs",
"types-aiobotocore[essential]",
Expand All @@ -56,6 +52,11 @@ config = "diracx.routers.configuration:router"
auth = "diracx.routers.auth:router"
".well-known" = "diracx.routers.auth.well_known:router"

[project.entry-points."diracx.access_policies"]
WMSAccessPolicy = "diracx.routers.job_manager.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.job_manager.access_policies:SandboxAccessPolicy"


[tool.setuptools.packages.find]
where = ["src"]

Expand All @@ -70,8 +71,10 @@ root = ".."
testpaths = ["tests"]
addopts = [
"-v",
"--cov=diracx.routers", "--cov-report=term-missing",
"-pdiracx.testing", "-pdiracx.testing.osdb",
"--cov=diracx.routers",
"--cov-report=term-missing",
"-pdiracx.testing",
"-pdiracx.testing.osdb",
"--import-mode=importlib",
]
asyncio_mode = "auto"
Expand Down
117 changes: 113 additions & 4 deletions diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
"""
# Startup sequence
uvicorn is called with `create_app` as a factory
create_app loads the environment configuration
"""

from __future__ import annotations

import inspect
Expand All @@ -6,7 +17,7 @@
from collections.abc import AsyncGenerator
from functools import partial
from logging import Formatter, StreamHandler
from typing import Any, Awaitable, Callable, Iterable, TypeVar, cast
from typing import Any, Awaitable, Callable, Iterable, Sequence, TypeVar, cast

import dotenv
from cachetools import TTLCache
Expand All @@ -28,10 +39,11 @@
from diracx.db.exceptions import DBUnavailable
from diracx.db.os.utils import BaseOSDB
from diracx.db.sql.utils import BaseSQLDB
from diracx.routers.access_policies import BaseAccessPolicy, check_permissions

from .auth import verify_dirac_access_token
from .fastapi_classes import DiracFastAPI, DiracxRouter
from .otel import instrument_otel
from .utils.users import verify_dirac_access_token

T = TypeVar("T")
T2 = TypeVar("T2", bound=BaseSQLDB | BaseOSDB)
Expand Down Expand Up @@ -83,6 +95,7 @@ def configure_logger():
# All routes must have tags (needed for auto gen of client)
# Form headers must have a description (autogen)
# methods name should follow the generate_unique_id_function pattern
# All routes should have a policy mechanism


def create_app_inner(
Expand All @@ -92,21 +105,83 @@ def create_app_inner(
database_urls: dict[str, str],
os_database_conn_kwargs: dict[str, Any],
config_source: ConfigSource,
all_access_policies: dict[str, Sequence[BaseAccessPolicy]],
) -> DiracFastAPI:
"""
This method does the heavy lifting work of putting all the pieces together.
When starting the application normaly, this method is called by create_app,
and the values of the parameters are taken from environment variables or
entrypoints.
When running tests, the parameters are mocks or test settings.
We rely on the dependency_override mechanism to implement
the actual behavior we are interested in for settings, DBs or policy.
This allows an extension to override any of these components
:param enabled_system:
this contains the name of all the routers we have to load
:param all_service_settings:
list of instance of each Settings type required
:param database_urls:
dict <db_name: url>. When testing, sqlite urls are used
:param os_database_conn_kwargs:
<db_name:dict> containing all the parameters the OpenSearch client takes
:param config_source:
Source of the configuration to use
:param all_access_policies:
<policy_name: [implementations]>
"""

app = DiracFastAPI()

# Find which settings classes are available and add them to dependency_overrides
# We use a single instance of each Setting classes for performance reasons,
# since it avoids recreating a pydantic model every time
# We add the Settings lifetime_function to the application lifetime_function,
# Please see ServiceSettingsBase for more details

available_settings_classes: set[type[ServiceSettingsBase]] = set()
for service_settings in all_service_settings:
cls = type(service_settings)
assert cls not in available_settings_classes
available_settings_classes.add(cls)
app.lifetime_functions.append(service_settings.lifetime_function)
# We always return the same setting instance for perf reasons
app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings)

# Override the configuration source
# Override the ConfigSource.create by the actual reading of the config
app.dependency_overrides[ConfigSource.create] = config_source.read_config

all_access_policies_used = {}

for access_policy_name, access_policy_classes in all_access_policies.items():

# The first AccessPolicy is the highest priority one
access_policy_used = access_policy_classes[0].policy
all_access_policies_used[access_policy_name] = access_policy_classes[0]

# app.lifetime_functions.append(access_policy.lifetime_function)
# Add overrides for all the AccessPolicy classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's AccessPolicy
for access_policy_class in access_policy_classes:
# Here we do not check that access_policy_class.check is
# not already in the dependency_overrides becaue the same
# policy could be used for multiple purpose
# (e.g. open access)
# assert access_policy_class.check not in app.dependency_overrides
app.dependency_overrides[access_policy_class.check] = partial(
check_permissions, access_policy_used, access_policy_name
)

app.dependency_overrides[BaseAccessPolicy.all_used_access_policies] = (
lambda: all_access_policies_used
)

fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()
Expand Down Expand Up @@ -237,7 +312,22 @@ def create_app_inner(


def create_app() -> DiracFastAPI:
"""Load settings from the environment and create the application object"""
"""Load settings from the environment and create the application object
The configuration may be placed in .env files pointed to by
environment variables DIRACX_SERVICE_DOTENV.
They can be followed by "_X" where X is a number, and the order
is respected.
We then loop over all the diracx.services definitions.
A specific route can be disabled with an environment variable
DIRACX_SERVICE_<name>_ENABLED=false
For each of the enabled route, we inspect which Setting classes
are needed.
We attempt to load each setting classes to make sure that the
settings are correctly defined.
"""
for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"):
logger.debug("Loading dotenv file: %s", env_file)
if not dotenv.load_dotenv(env_file):
Expand All @@ -261,12 +351,31 @@ def create_app() -> DiracFastAPI:
# Load settings classes required by the routers
all_service_settings = [settings_class() for settings_class in settings_classes]

# Find all the access policies

available_access_policy_names = set(
[
entry_point.name
for entry_point in select_from_extension(group="diracx.access_policies")
]
)

all_access_policies = {}

for access_policy_name in available_access_policy_names:

access_policy_classes = BaseAccessPolicy.available_implementations(
access_policy_name
)
all_access_policies[access_policy_name] = access_policy_classes

return create_app_inner(
enabled_systems=enabled_systems,
all_service_settings=all_service_settings,
database_urls=BaseSQLDB.available_urls(),
os_database_conn_kwargs=BaseOSDB.available_urls(),
config_source=ConfigSource.create(),
all_access_policies=all_access_policies,
)


Expand Down
Loading

0 comments on commit 01b4530

Please sign in to comment.