From 1e7f4097d43c886e45472d290457f58c8a3a4b32 Mon Sep 17 00:00:00 2001 From: Chris Meyers Date: Mon, 9 Dec 2024 23:53:09 -0500 Subject: [PATCH] Formatting --- .../interfaces/_temporary_private_api.py | 24 +-- src/awx_plugins/interfaces/registry.py | 87 ++++++----- tests/_temporary_private_api_test.py | 138 +++++++++++++----- 3 files changed, 171 insertions(+), 78 deletions(-) diff --git a/src/awx_plugins/interfaces/_temporary_private_api.py b/src/awx_plugins/interfaces/_temporary_private_api.py index 07e4b9de..38640c04 100644 --- a/src/awx_plugins/interfaces/_temporary_private_api.py +++ b/src/awx_plugins/interfaces/_temporary_private_api.py @@ -4,7 +4,6 @@ The hope is that it will be refactored into something more standardized. """ -import collections import os import re import stat @@ -21,6 +20,7 @@ GenericOptionalPrimitiveType, ) + InputSchemaType = dict[str, list[dict[str, str | bool]]] HIDDEN_PASSWORD = '*' * 10 @@ -50,9 +50,10 @@ 'AWX_HOST', 'PROJECT_REVISION', 'SUPERVISOR_CONFIG_PATH', - ) + ), ) + def build_safe_env( env: dict[str, GenericOptionalPrimitiveType], ) -> dict[str, GenericOptionalPrimitiveType]: @@ -209,9 +210,9 @@ class TowerNamespace: safe_namespace[field_id] = False # make sure private keys end with a \n if field.get('format') == 'ssh_private_key': - if field_id in namespace and not str(namespace[field_id]).endswith( - '\n', - ): + if field_id in namespace and not str( + namespace[field_id], + ).endswith('\n'): namespace[field_id] = str(namespace[field_id]) + '\n' file_tmpls = self.injectors.get('file', {}) @@ -219,10 +220,13 @@ class TowerNamespace: # special `tower` template namespace so the filename can be # referenced in other injectors - sandbox_env = sandbox.ImmutableSandboxedEnvironment() # type: ignore[misc] + # type: ignore[misc] + sandbox_env = sandbox.ImmutableSandboxedEnvironment() for file_label, file_tmpl in file_tmpls.items(): - data: str = sandbox_env.from_string(file_tmpl).render(**namespace) # type: ignore[misc] + data: str = sandbox_env.from_string(file_tmpl).render( + **namespace, + ) # type: ignore[misc] env_dir = os.path.join(private_data_dir, 'env') _, path = tempfile.mkstemp(dir=env_dir) with open(path, 'w') as f: @@ -249,7 +253,9 @@ class TowerNamespace: if 'INVENTORY_UPDATE_ID' not in env: # awx-manage inventory_update does not support extra_vars via -e - def build_extra_vars(node: dict[str, str | list[str]] | list[str] | str) -> dict[str, str] | list[str] | str: + def build_extra_vars( + node: dict[str, str | list[str]] | list[str] | str, + ) -> dict[str, str] | list[str] | str: if isinstance(node, dict): return { build_extra_vars(k): build_extra_vars(v) for k, @@ -283,7 +289,7 @@ class CredentialPlugin: inputs: InputSchemaType backend: Callable[ [ - InputSchemaType + InputSchemaType, ], None, ] diff --git a/src/awx_plugins/interfaces/registry.py b/src/awx_plugins/interfaces/registry.py index cce24136..7d0aae66 100644 --- a/src/awx_plugins/interfaces/registry.py +++ b/src/awx_plugins/interfaces/registry.py @@ -1,8 +1,16 @@ -from importlib.metadata import EntryPoint, entry_points +"""Credential registry. + +Load and track custom credentials +""" + import inspect +from importlib.metadata import EntryPoint, entry_points -from awx_plugins.interfaces._temporary_private_api import ManagedCredentialType, CredentialPlugin -from awx_plugins.interfaces._temporary_private_licensing_api import ( +from ._temporary_private_api import ( # noqa: WPS436 + CredentialPlugin, + ManagedCredentialType, +) +from ._temporary_private_licensing_api import ( # noqa: WPS436 detect_server_product_name, ) @@ -11,30 +19,33 @@ def detect_server_product_name(): return 'NO_AWX' -class BasePluginRegistry: - def _get_all_entry_points_for(self, - entry_point_subsections: list[str], /) -> dict[str, EntryPoint]: - return { - ep.name: ep - for entry_point_category in entry_point_subsections - for ep in entry_points(group=f'awx_plugins.{entry_point_category}') - } +def _get_all_entry_points_for( + entry_point_subsections: list[str], /, +) -> dict[str, EntryPoint]: + return { + ep.name: ep + for entry_point_category in entry_point_subsections + for ep in entry_points(group=f'awx_plugins.{entry_point_category}') + } -class BaseCredentialTypeRegistry(BasePluginRegistry): +class BaseCredentialTypeRegistry: """Load and track ManagedCredentialType plugins.""" + _registry: dict[str, ManagedCredentialType] = {} def add(self, credential_type: ManagedCredentialType) -> None: - """Add the credential type to the registry - - :param credential_type: ManagedCredentialType to add to the registery + """Add the credential type to the registry. + + :param credential_type: ManagedCredentialType to add to the + registry """ namespace = credential_type.namespace if namespace in self._registry: - raise ValueError('a ManagedCredentialType with namespace={} is already defined in {}'.format( - namespace, inspect.getsourcefile(self._registry[namespace].__class__) - )) + raise ValueError( + 'existing entry for namespace={} defined in {}'.format( + namespace, inspect.getsourcefile( + self._registry[namespace].__class__))) self._registry[namespace] = credential_type def get(self, namespace: str) -> ManagedCredentialType: @@ -44,14 +55,14 @@ def get(self, namespace: str) -> ManagedCredentialType: :returns: The ManagedCredentialType object or None if not found """ return self._registry.get(namespace, None) - + def get_keys(self) -> list[str]: """All loaded plugin names. - + :returns: A list of plugin names. """ return self._registry.keys() - + def get_all(self) -> list[ManagedCredentialType]: """Access all loaded credential types. @@ -61,11 +72,11 @@ def get_all(self) -> list[ManagedCredentialType]: def _discover_all(self) -> list[EntryPoint]: """Find all entry points to be loaded.""" - raise ValueError("Implement me") - + raise ValueError('Implement me') + def load_all(self): - """Load all the discovered plugins""" - for entry_point_name, entry_point in self._discover_all().items(): + """Load all the discovered plugins.""" + for entry_point in self._discover_all().values(): credential_type = entry_point.load() self.add(credential_type) @@ -73,12 +84,19 @@ def load_all(self): class _ManagedCredentialTypeRegistry(BaseCredentialTypeRegistry): def _discover_all(self) -> list[EntryPoint]: """Find all relevant awx managed credential entry points.""" + is_awx = detect_server_product_name() == 'AWX' - return self._get_all_entry_points_for(['managed_credentials'] if is_awx else ['managed_credentials', 'managed_credentials.supported']) + return _get_all_entry_points_for( + ['managed_credentials'] if is_awx else [ + 'managed_credentials', + 'managed_credentials.supported', + ], + ) class _CredentialPluginRegistry(BaseCredentialTypeRegistry): - """Load and manage lookup credential plugin types""" + """Load and manage lookup credential plugin types.""" + _registry: dict[str, CredentialPlugin] = {} def get_all(self) -> list[CredentialPlugin]: @@ -87,20 +105,21 @@ def get_all(self) -> list[CredentialPlugin]: :returns: All loaded credential plugins """ return self._registry.values() - + def _discover_all(self) -> list[EntryPoint]: """Find all user managed credential entry points.""" - return self._get_all_entry_points_for(['credentials',]) - + return _get_all_entry_points_for(['credentials']) + def load_all(self): - """Load all the user discovered plugins""" + """Load all the user discovered plugins.""" for entry_point_name, entry_point in self._discover_all().items(): cred_plugin = entry_point.load() self._registry[entry_point_name] = CredentialPlugin( name=cred_plugin.name, inputs=cred_plugin.inputs, - backend=cred_plugin.backend + backend=cred_plugin.backend, ) - + + ManagedCredentialTypeRegistry = _ManagedCredentialTypeRegistry() -CredentialPluginRegistry = _CredentialPluginRegistry() \ No newline at end of file +CredentialPluginRegistry = _CredentialPluginRegistry() diff --git a/tests/_temporary_private_api_test.py b/tests/_temporary_private_api_test.py index 0df63f8c..c73320b8 100644 --- a/tests/_temporary_private_api_test.py +++ b/tests/_temporary_private_api_test.py @@ -1,46 +1,58 @@ """Tests for the temporarily hosted private helpers.""" import os -import jinja2 -import pytest import shutil import tempfile - from pathlib import Path, PurePath +import pytest + +import jinja2 import yaml -from awx_plugins.interfaces._temporary_private_api import HIDDEN_PASSWORD, ManagedCredentialType -from awx_plugins.interfaces._temporary_private_credential_api import Credential -from awx_plugins.interfaces._temporary_private_container_api import CONTAINER_ROOT +from awx_plugins.interfaces._temporary_private_api import ( + HIDDEN_PASSWORD, + ManagedCredentialType, +) +from awx_plugins.interfaces._temporary_private_container_api import ( + CONTAINER_ROOT, +) +from awx_plugins.interfaces._temporary_private_credential_api import Credential def to_host_path(path, private_data_dir): - """Given a path inside of the EE container, this gives the absolute path - on the host machine within the private_data_dir - """ + """Given a path inside of the EE container, this gives the absolute path on + the host machine within the private_data_dir.""" if not os.path.isabs(private_data_dir): raise RuntimeError('The private_data_dir path must be absolute') - if CONTAINER_ROOT != path and Path(CONTAINER_ROOT) not in Path(path).resolve().parents: - raise RuntimeError(f'Cannot convert path {path} unless it is a subdir of {CONTAINER_ROOT}') + if CONTAINER_ROOT != path and Path( + CONTAINER_ROOT, + ) not in Path(path).resolve().parents: + raise RuntimeError( + f'Cannot convert path {path} unless it is a subdir of {CONTAINER_ROOT}', ) return path.replace(CONTAINER_ROOT, private_data_dir, 1) + def read_extra_vars(private_data_dir: str, args: list[str]) -> dict[str, str]: fname = to_host_path(args[1][1:], private_data_dir) - with open(fname, 'r') as f: + with open(fname) as f: return yaml.safe_load(f) + def assert_dict_subset(subset, full_dict): - """ - Recursively asserts that `subset` is a subset of `full_dict`. - """ + """Recursively asserts that `subset` is a subset of `full_dict`.""" for key, value in subset.items(): assert key in full_dict, f"Key '{key}' not found in full_dict" if isinstance(value, dict): - assert isinstance(full_dict[key], dict), f"Key '{key}' is not a dictionary in full_dict" + assert isinstance( + full_dict[key], dict, + ), f"Key '{key}' is not a dictionary in full_dict" assert_dict_subset(value, full_dict[key]) else: - assert value == full_dict[key], f"Value mismatch for key '{key}': {value} != {full_dict[key]}" + assert value == full_dict[key], f"Value mismatch for key '{key}': {value} != { + full_dict[key] + }" + @pytest.fixture def private_data_dir(): @@ -87,7 +99,10 @@ def test_managed_credential_type_inject_cred() -> None: assert env['PET_NAME'] == 'birdie' -def test_custom_environment_injectors_with_jinja_syntax_error(private_data_dir): + +def test_custom_environment_injectors_with_jinja_syntax_error( + private_data_dir, +): cred_type = ManagedCredentialType( kind='cloud', name='SomeCloud', @@ -101,6 +116,7 @@ def test_custom_environment_injectors_with_jinja_syntax_error(private_data_dir): with pytest.raises(jinja2.exceptions.UndefinedError): cred_type.inject_credential(credential, {}, {}, [], private_data_dir) + def test_custom_environment_injectors_with_reserved_env_var(private_data_dir): cred_type = ManagedCredentialType( kind='cloud', @@ -117,6 +133,7 @@ def test_custom_environment_injectors_with_reserved_env_var(private_data_dir): assert 'JOB_ID' not in env + def test_custom_environment_injectors_with_secret_field(private_data_dir): cred_type = ManagedCredentialType( kind='cloud', @@ -130,12 +147,15 @@ def test_custom_environment_injectors_with_secret_field(private_data_dir): env = {} safe_env = {} - cred_type.inject_credential(credential, env, safe_env, [], private_data_dir) + cred_type.inject_credential( + credential, env, safe_env, [], private_data_dir, + ) assert env['MY_CLOUD_PRIVATE_VAR'] == 'SUPER-SECRET-123' assert 'SUPER-SECRET-123' not in safe_env.values() assert safe_env['MY_CLOUD_PRIVATE_VAR'] == HIDDEN_PASSWORD + @pytest.mark.parametrize( ('inputs', 'injectors', 'cred_inputs', 'expected_extra_vars'), ( @@ -146,11 +166,11 @@ def test_custom_environment_injectors_with_secret_field(private_data_dir): {'api_token': 'ABC123'}, id='happy-path', ), - pytest.param ( + pytest.param( {'fields': [{'id': 'turbo_button', 'label': 'Turbo Button', 'type': 'boolean'}]}, {'extra_vars': {'turbo_button': '{{turbo_button}}'}}, {'turbo_button': True}, - {'turbo_button': "True"}, + {'turbo_button': 'True'}, id='boolean', ), pytest.param( @@ -173,10 +193,12 @@ def test_custom_environment_injectors_with_secret_field(private_data_dir): {'turbo_button': True}, {'turbo_button': 'FAST!'}, id='templated-bool', - ) + ), ), ) -def test_custom_environment_injectors_with_extra_vars(private_data_dir, inputs, injectors, cred_inputs, expected_extra_vars): +def test_custom_environment_injectors_with_extra_vars( + private_data_dir, inputs, injectors, cred_inputs, expected_extra_vars, +): cred_type = ManagedCredentialType( kind='cloud', name='SomeCloud', @@ -189,17 +211,32 @@ def test_custom_environment_injectors_with_extra_vars(private_data_dir, inputs, args = [] cred_type.inject_credential(credential, {}, {}, args, private_data_dir) - + extra_vars = read_extra_vars(private_data_dir, args) assert_dict_subset(expected_extra_vars, extra_vars) + @pytest.mark.parametrize( - ('inputs', 'injectors', 'cred_inputs', 'expected_file_content'), + ( + 'inputs', + 'injectors', + 'cred_inputs', + 'expected_file_content', + ), ( pytest.param( - {'fields': [{'id': 'api_token', 'label': 'API Token', 'type': 'string'}]}, - {'file': {'template': '[mycloud]\n{{api_token}}'}, 'env': {'MY_CLOUD_INI_FILE': '{{tower.filename}}'}}, + { + 'fields': [{ + 'id': 'api_token', + 'label': 'API Token', + 'type': 'string', + }], + }, + { + 'file': {'template': '[mycloud]\n{{api_token}}'}, + 'env': {'MY_CLOUD_INI_FILE': '{{tower.filename}}'}, + }, {'api_token': 'ABC123'}, { 'MY_CLOUD_INI_FILE': '[mycloud]\nABC123', @@ -208,7 +245,10 @@ def test_custom_environment_injectors_with_extra_vars(private_data_dir, inputs, ), pytest.param( {'fields': []}, - {'file': {'template': 'Iñtërnâtiônàlizætiøn'}, 'env': {'MY_CLOUD_INI_FILE': '{{tower.filename}}'}}, + { + 'file': {'template': 'Iñtërnâtiônàlizætiøn'}, + 'env': {'MY_CLOUD_INI_FILE': '{{tower.filename}}'}, + }, {}, { 'MY_CLOUD_INI_FILE': 'Iñtërnâtiônàlizætiøn', @@ -216,21 +256,49 @@ def test_custom_environment_injectors_with_extra_vars(private_data_dir, inputs, id='unicode', ), pytest.param( - {'fields': [{'id': 'cert', 'label': 'Certificate', 'type': 'string'}, {'id': 'key', 'label': 'Key', 'type': 'string'}]}, { - 'file': {'template.cert': '[mycert]\n{{cert}}', 'template.key': '[mykey]\n{{key}}'}, - 'env': {'MY_CERT_INI_FILE': '{{tower.filename.cert}}', 'MY_KEY_INI_FILE': '{{tower.filename.key}}'}, + 'fields': [ + { + 'id': 'cert', + 'label': 'Certificate', + 'type': 'string', + }, + { + 'id': 'key', + 'label': 'Key', + 'type': 'string', + }, + ], + }, + { + 'file': { + 'template.cert': '[mycert]\n{{cert}}', + 'template.key': '[mykey]\n{{key}}', + }, + 'env': { + 'MY_CERT_INI_FILE': '{{tower.filename.cert}}', + 'MY_KEY_INI_FILE': '{{tower.filename.key}}', + }, + }, + { + 'cert': 'CERT123', + 'key': 'KEY123', }, - {'cert': 'CERT123', 'key': 'KEY123'}, { 'MY_CERT_INI_FILE': '[mycert]\nCERT123', 'MY_KEY_INI_FILE': '[mykey]\nKEY123', }, id='multiple-files', - ) + ), ), ) -def test_custom_environment_injectors_with_file(private_data_dir, inputs, injectors, cred_inputs, expected_file_content): +def test_custom_environment_injectors_with_file( + private_data_dir, + inputs, + injectors, + cred_inputs, + expected_file_content, +): cred_type = ManagedCredentialType( kind='cloud', name='SomeCloud', @@ -246,5 +314,5 @@ def test_custom_environment_injectors_with_file(private_data_dir, inputs, inject for env_fname_key, expected_content in expected_file_content.items(): path = to_host_path(env[env_fname_key], private_data_dir) - with open(path, 'r') as f: + with open(path) as f: assert f.read() == expected_content