Skip to content

Commit

Permalink
Support for extending cfn cli with custom commands (#1020)
Browse files Browse the repository at this point in the history
* code: testing command extension

* code: finishing ExtensionPlugin base class

* code: removing unnecessary abstract class

* code: renaming parameter

* test: added unit tests for cli and extensions

* refactor: renaming parameter

* test: testing ExtensionPlugin

* test: testing PluginRegistry

* code: updated extension plugin to provider parser to plugin instead

* isort: fixing isort

* format: fixing black formatting

* code: safety check for collisions

---------

Co-authored-by: Adrian Chouza <achouza@amazon.com>
  • Loading branch information
TheChouzanOne and achouzamz authored Aug 31, 2023
1 parent 74ed3d9 commit 61f7d71
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/rpdk/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .build_image import setup_subparser as build_image_setup_subparser
from .data_loaders import resource_yaml
from .exceptions import DownstreamError, SysExitRecommendedError
from .extensions import setup_subparsers as extensions_setup_subparser
from .generate import setup_subparser as generate_setup_subparser
from .init import setup_subparser as init_setup_subparser
from .invoke import setup_subparser as invoke_setup_subparser
Expand Down Expand Up @@ -88,6 +89,7 @@ def no_command(args):
invoke_setup_subparser(subparsers, parents)
unittest_patch_setup_subparser(subparsers, parents)
build_image_setup_subparser(subparsers, parents)
extensions_setup_subparser(subparsers, parents)
args = parser.parse_args(args=args_in)

setup_logging(args.verbose)
Expand Down
18 changes: 18 additions & 0 deletions src/rpdk/core/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .plugin_registry import get_extensions


def _check_command_name_collision(subparsers, command_name):
if command_name in subparsers.choices:
raise RuntimeError(
f'"{command_name}" is already registered as an extension. Please use a different name.'
)


def setup_subparsers(subparsers, parents):
extensions = get_extensions()

for extension_cls in extensions.values():
extension = extension_cls()()
_check_command_name_collision(subparsers, extension.command_name)
parser = subparsers.add_parser(extension.command_name, parents=parents)
extension.setup_parser(parser)
16 changes: 16 additions & 0 deletions src/rpdk/core/plugin_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,19 @@ def generate(self, project):
@abstractmethod
def package(self, project, zip_file):
pass


class ExtensionPlugin(ABC):
COMMAND_NAME = None

@property
def command_name(self):
if not self.COMMAND_NAME:
raise RuntimeError(
"Set COMMAND_NAME to the command you want to extend cfn with: `cfn COMMAND_NAME`."
)
return self.COMMAND_NAME

@abstractmethod
def setup_parser(self, parser):
pass
9 changes: 9 additions & 0 deletions src/rpdk/core/plugin_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,14 @@ def get_parsers():
return parsers


def get_extensions():
extensions = {
entry_point.name: entry_point.load
for entry_point in pkg_resources.iter_entry_points("rpdk.v1.extensions")
}

return extensions


def load_plugin(language):
return PLUGIN_REGISTRY[language]()()
9 changes: 9 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def test_main_no_args_prints_help(capsys):
assert "--help" in out


def test_main_setup_extensions():
with patch(
"rpdk.core.cli.extensions_setup_subparser"
) as extensions_setup_subparser:
main(args_in=[])

extensions_setup_subparser.assert_called_once()


def test_main_version_arg_prints_version(capsys):
main(args_in=["--version"])
out, err = capsys.readouterr()
Expand Down
60 changes: 60 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import argparse
from unittest import TestCase
from unittest.mock import MagicMock, patch

from rpdk.core.extensions import setup_subparsers


class ExtensionTest(TestCase):
def test_setup_subparsers(self): # pylint: disable=no-self-use
expeted_command_name = "expected-command-name"

mock_extension = MagicMock()
mock_extension.command_name = expeted_command_name

mock_extension_entry_point = MagicMock()
mock_extension_entry_point.return_value.return_value = mock_extension

mock_extension_entry_points = {"key": mock_extension_entry_point}

subparsers, parents, parser = MagicMock(), MagicMock(), MagicMock()
subparsers.add_parser.return_value = parser

with patch("rpdk.core.extensions.get_extensions") as mock_get_extensions:
mock_get_extensions.return_value = mock_extension_entry_points
setup_subparsers(subparsers, parents)

mock_extension.setup_parser.assert_called_once_with(parser)
subparsers.add_parser.assert_called_with(expeted_command_name, parents=parents)

def test_setup_subparsers_should_raise_error_when_collision_occur(self):
command_name = "command-name"

mock_extension_1, mock_extension_2 = MagicMock(), MagicMock()
mock_extension_1.command_name = command_name
mock_extension_2.command_name = command_name

mock_extension_1_entry_point = MagicMock()
mock_extension_1_entry_point.return_value.return_value = mock_extension_1

mock_extension_2_entry_point = MagicMock()
mock_extension_2_entry_point.return_value.return_value = mock_extension_2

mock_extension_entry_points = {
"key_1": mock_extension_1_entry_point,
"key_2": mock_extension_2_entry_point,
}

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers()

with patch(
"rpdk.core.extensions.get_extensions"
) as mock_get_extensions, self.assertRaises(RuntimeError) as context:
mock_get_extensions.return_value = mock_extension_entry_points
setup_subparsers(subparsers, [])

assert (
str(context.exception)
== '"command-name" is already registered as an extension. Please use a different name.'
)
64 changes: 48 additions & 16 deletions tests/test_plugin_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
import pytest

from rpdk.core.filters import FILTER_REGISTRY
from rpdk.core.plugin_base import LanguagePlugin, __name__ as plugin_base_name
from rpdk.core.plugin_base import (
ExtensionPlugin,
LanguagePlugin,
__name__ as plugin_base_name,
)


class TestLanguagePlugin(LanguagePlugin):
Expand All @@ -22,7 +26,7 @@ def package(self, project, zip_file):


@pytest.fixture
def plugin():
def language_plugin():
return TestLanguagePlugin()


Expand All @@ -34,20 +38,20 @@ def test_language_plugin_module_not_set():
plugin._module_name # pylint: disable=pointless-statement


def test_language_plugin_init_no_op(plugin):
plugin.init(None)
def test_language_plugin_init_no_op(language_plugin):
language_plugin.init(None)


def test_language_plugin_generate_no_op(plugin):
plugin.generate(None)
def test_language_plugin_generate_no_op(language_plugin):
language_plugin.generate(None)


def test_language_plugin_package_no_op(plugin):
plugin.package(None, None)
def test_language_plugin_package_no_op(language_plugin):
language_plugin.package(None, None)


def test_language_plugin_setup_jinja_env_defaults(plugin):
env = plugin._setup_jinja_env()
def test_language_plugin_setup_jinja_env_defaults(language_plugin):
env = language_plugin._setup_jinja_env()
assert env.loader
assert env.autoescape

Expand All @@ -57,28 +61,56 @@ def test_language_plugin_setup_jinja_env_defaults(plugin):
assert env.get_template("test.txt")


def test_language_plugin_setup_jinja_env_overrides(plugin):
def test_language_plugin_setup_jinja_env_overrides(language_plugin):
loader = object()
autoescape = object()
env = plugin._setup_jinja_env(autoescape=autoescape, loader=loader)
env = language_plugin._setup_jinja_env(autoescape=autoescape, loader=loader)
assert env.loader is loader
assert env.autoescape is autoescape

for name in FILTER_REGISTRY:
assert name in env.filters


def test_language_plugin_setup_jinja_env_no_spec(plugin):
def test_language_plugin_setup_jinja_env_no_spec(language_plugin):
with patch(
"rpdk.core.plugin_base.importlib.util.find_spec", return_value=None
) as mock_spec, patch("rpdk.core.plugin_base.PackageLoader") as mock_loader:
env = plugin._setup_jinja_env()
env = language_plugin._setup_jinja_env()

mock_spec.assert_called_once_with(plugin._module_name)
mock_loader.assert_has_calls([call(plugin._module_name), call(plugin_base_name)])
mock_spec.assert_called_once_with(language_plugin._module_name)
mock_loader.assert_has_calls(
[call(language_plugin._module_name), call(plugin_base_name)]
)

assert env.loader
assert env.autoescape

for name in FILTER_REGISTRY:
assert name in env.filters


class TestExtensionPlugin(ExtensionPlugin):
COMMAND_NAME = "test-extension"

def setup_parser(self, parser):
super().setup_parser(parser)


@pytest.fixture
def extension_plugin():
return TestExtensionPlugin()


def test_extension_plugin_command_name(extension_plugin):
assert extension_plugin.command_name == "test-extension"


def test_extension_plugin_command_name_error(extension_plugin):
extension_plugin.COMMAND_NAME = None
with pytest.raises(RuntimeError):
extension_plugin.command_name # pylint: disable=pointless-statement


def test_extension_plugin_setup_parser_no_op(extension_plugin):
extension_plugin.setup_parser(None)
20 changes: 19 additions & 1 deletion tests/test_plugin_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from unittest.mock import Mock, patch

from rpdk.core.plugin_registry import load_plugin
from rpdk.core.plugin_registry import get_extensions, load_plugin


def test_load_plugin():
Expand All @@ -11,3 +11,21 @@ def test_load_plugin():
load_plugin("test")
plugin.assert_called_once_with()
plugin.return_value.assert_called_once_with()


def test_get_extensions():
mock_entrypoint_1 = Mock()
mock_entrypoint_2 = Mock()

patch_iter_entry_points = patch(
"rpdk.core.plugin_registry.pkg_resources.iter_entry_points"
)
with patch_iter_entry_points as mock_iter_entry_points:
mock_iter_entry_points.return_value = [mock_entrypoint_1, mock_entrypoint_2]

extensions = get_extensions()

assert extensions == {
mock_entrypoint_1.name: mock_entrypoint_1.load,
mock_entrypoint_2.name: mock_entrypoint_2.load,
}

0 comments on commit 61f7d71

Please sign in to comment.