Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismeyersfsu committed Dec 10, 2024
1 parent e411283 commit 1e7f409
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 78 deletions.
24 changes: 15 additions & 9 deletions src/awx_plugins/interfaces/_temporary_private_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
The hope is that it will be refactored into something more standardized.
"""

import collections
import os
import re
import stat
Expand All @@ -21,6 +20,7 @@
GenericOptionalPrimitiveType,
)


InputSchemaType = dict[str, list[dict[str, str | bool]]]

HIDDEN_PASSWORD = '*' * 10
Expand Down Expand Up @@ -50,9 +50,10 @@
'AWX_HOST',
'PROJECT_REVISION',
'SUPERVISOR_CONFIG_PATH',
)
),
)


def build_safe_env(
env: dict[str, GenericOptionalPrimitiveType],
) -> dict[str, GenericOptionalPrimitiveType]:
Expand Down Expand Up @@ -209,20 +210,23 @@ class TowerNamespace:
safe_namespace[field_id] = False
# make sure private keys end with a \n
if field.get('format') == 'ssh_private_key':
if field_id in namespace and not str(namespace[field_id]).endswith(
'\n',
):
if field_id in namespace and not str(
namespace[field_id],
).endswith('\n'):
namespace[field_id] = str(namespace[field_id]) + '\n'

file_tmpls = self.injectors.get('file', {})
# If any file templates are provided, render the files and update the
# special `tower` template namespace so the filename can be
# referenced in other injectors

sandbox_env = sandbox.ImmutableSandboxedEnvironment() # type: ignore[misc]
# type: ignore[misc]
sandbox_env = sandbox.ImmutableSandboxedEnvironment()

for file_label, file_tmpl in file_tmpls.items():
data: str = sandbox_env.from_string(file_tmpl).render(**namespace) # type: ignore[misc]
data: str = sandbox_env.from_string(file_tmpl).render(
**namespace,
) # type: ignore[misc]
env_dir = os.path.join(private_data_dir, 'env')
_, path = tempfile.mkstemp(dir=env_dir)
with open(path, 'w') as f:
Expand All @@ -249,7 +253,9 @@ class TowerNamespace:

if 'INVENTORY_UPDATE_ID' not in env:
# awx-manage inventory_update does not support extra_vars via -e
def build_extra_vars(node: dict[str, str | list[str]] | list[str] | str) -> dict[str, str] | list[str] | str:
def build_extra_vars(
node: dict[str, str | list[str]] | list[str] | str,
) -> dict[str, str] | list[str] | str:
if isinstance(node, dict):
return {
build_extra_vars(k): build_extra_vars(v) for k,
Expand Down Expand Up @@ -283,7 +289,7 @@ class CredentialPlugin:
inputs: InputSchemaType
backend: Callable[
[
InputSchemaType
InputSchemaType,
], None,
]

Expand Down
87 changes: 53 additions & 34 deletions src/awx_plugins/interfaces/registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
from importlib.metadata import EntryPoint, entry_points
"""Credential registry.
Load and track custom credentials
"""

import inspect
from importlib.metadata import EntryPoint, entry_points

from awx_plugins.interfaces._temporary_private_api import ManagedCredentialType, CredentialPlugin
from awx_plugins.interfaces._temporary_private_licensing_api import (
from ._temporary_private_api import ( # noqa: WPS436
CredentialPlugin,
ManagedCredentialType,
)
from ._temporary_private_licensing_api import ( # noqa: WPS436
detect_server_product_name,
)

Expand All @@ -11,30 +19,33 @@ def detect_server_product_name():
return 'NO_AWX'


class BasePluginRegistry:
def _get_all_entry_points_for(self,
entry_point_subsections: list[str], /) -> dict[str, EntryPoint]:
return {
ep.name: ep
for entry_point_category in entry_point_subsections
for ep in entry_points(group=f'awx_plugins.{entry_point_category}')
}
def _get_all_entry_points_for(
entry_point_subsections: list[str], /,
) -> dict[str, EntryPoint]:
return {
ep.name: ep
for entry_point_category in entry_point_subsections
for ep in entry_points(group=f'awx_plugins.{entry_point_category}')
}


class BaseCredentialTypeRegistry(BasePluginRegistry):
class BaseCredentialTypeRegistry:
"""Load and track ManagedCredentialType plugins."""

_registry: dict[str, ManagedCredentialType] = {}

def add(self, credential_type: ManagedCredentialType) -> None:
"""Add the credential type to the registry
:param credential_type: ManagedCredentialType to add to the registery
"""Add the credential type to the registry.
:param credential_type: ManagedCredentialType to add to the
registry
"""
namespace = credential_type.namespace
if namespace in self._registry:
raise ValueError('a ManagedCredentialType with namespace={} is already defined in {}'.format(
namespace, inspect.getsourcefile(self._registry[namespace].__class__)
))
raise ValueError(
'existing entry for namespace={} defined in {}'.format(
namespace, inspect.getsourcefile(
self._registry[namespace].__class__)))
self._registry[namespace] = credential_type

def get(self, namespace: str) -> ManagedCredentialType:
Expand All @@ -44,14 +55,14 @@ def get(self, namespace: str) -> ManagedCredentialType:
:returns: The ManagedCredentialType object or None if not found
"""
return self._registry.get(namespace, None)

def get_keys(self) -> list[str]:
"""All loaded plugin names.
:returns: A list of plugin names.
"""
return self._registry.keys()

def get_all(self) -> list[ManagedCredentialType]:
"""Access all loaded credential types.
Expand All @@ -61,24 +72,31 @@ def get_all(self) -> list[ManagedCredentialType]:

def _discover_all(self) -> list[EntryPoint]:
"""Find all entry points to be loaded."""
raise ValueError("Implement me")
raise ValueError('Implement me')

def load_all(self):
"""Load all the discovered plugins"""
for entry_point_name, entry_point in self._discover_all().items():
"""Load all the discovered plugins."""
for entry_point in self._discover_all().values():
credential_type = entry_point.load()
self.add(credential_type)


class _ManagedCredentialTypeRegistry(BaseCredentialTypeRegistry):
def _discover_all(self) -> list[EntryPoint]:
"""Find all relevant awx managed credential entry points."""

is_awx = detect_server_product_name() == 'AWX'
return self._get_all_entry_points_for(['managed_credentials'] if is_awx else ['managed_credentials', 'managed_credentials.supported'])
return _get_all_entry_points_for(
['managed_credentials'] if is_awx else [
'managed_credentials',
'managed_credentials.supported',
],
)


class _CredentialPluginRegistry(BaseCredentialTypeRegistry):
"""Load and manage lookup credential plugin types"""
"""Load and manage lookup credential plugin types."""

_registry: dict[str, CredentialPlugin] = {}

def get_all(self) -> list[CredentialPlugin]:
Expand All @@ -87,20 +105,21 @@ def get_all(self) -> list[CredentialPlugin]:
:returns: All loaded credential plugins
"""
return self._registry.values()

def _discover_all(self) -> list[EntryPoint]:
"""Find all user managed credential entry points."""
return self._get_all_entry_points_for(['credentials',])
return _get_all_entry_points_for(['credentials'])

def load_all(self):
"""Load all the user discovered plugins"""
"""Load all the user discovered plugins."""
for entry_point_name, entry_point in self._discover_all().items():
cred_plugin = entry_point.load()
self._registry[entry_point_name] = CredentialPlugin(
name=cred_plugin.name,
inputs=cred_plugin.inputs,
backend=cred_plugin.backend
backend=cred_plugin.backend,
)



ManagedCredentialTypeRegistry = _ManagedCredentialTypeRegistry()
CredentialPluginRegistry = _CredentialPluginRegistry()
CredentialPluginRegistry = _CredentialPluginRegistry()
Loading

0 comments on commit 1e7f409

Please sign in to comment.