Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support selecting steps by package name for strun #202

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/stpipe/cli/strun.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys

from stpipe import Step
from stpipe import cmdline
from stpipe.cli.main import _print_versions
from stpipe.exceptions import StpipeExitException

Expand All @@ -21,7 +21,7 @@ def main():
sys.exit(0)

try:
Step.from_cmdline(sys.argv[1:])
cmdline.step_from_cmdline(sys.argv[1:])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not required by this PR (I'm fine removing it) but during debugging it was confusing that this interface calls Step.from_cmdline which then just calls cmdline.step_from_cmdline.

stpipe/src/stpipe/step.py

Lines 170 to 191 in 799be65

@staticmethod
def from_cmdline(args):
"""
Create a step from a configuration file.
Parameters
----------
args : list of str
Commandline arguments
Returns
-------
step : Step instance
If the config file has a ``class`` parameter, the return
value will be as instance of that class.
Any parameters found in the config file will be set
as member variables on the returned `Step` instance.
"""
from . import cmdline
return cmdline.step_from_cmdline(args)

except StpipeExitException as e:
sys.exit(e.exit_status)
except Exception:
Expand Down
4 changes: 2 additions & 2 deletions src/stpipe/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings
from collections import namedtuple

from importlib_metadata import entry_points
import importlib_metadata
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed to change this import to allow the monkeypatching of "entry_points" in the unit test to work.


STEPS_GROUP = "stpipe.steps"

Expand All @@ -26,7 +26,7 @@ class alias, and the third is a bool indicating whether the class is to be
"""
steps = []

for entry_point in entry_points(group=STEPS_GROUP):
for entry_point in importlib_metadata.entry_points(group=STEPS_GROUP):
package_name = entry_point.dist.name
package_version = entry_point.dist.version
package_steps = []
Expand Down
13 changes: 12 additions & 1 deletion src/stpipe/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,24 @@ def resolve_step_class_alias(name):
Parameters
----------
name : str
If name contains "::" only the package with
a name matching the characters before "::"
will be searched for the matching step.

Returns
-------
str
"""
# check if the name contains a package name
if "::" in name:
scope, class_name = name.split("::", maxsplit=1)
else:
scope, class_name = None, name

for info in entry_points.get_steps():
if info.class_alias is not None and name == info.class_alias:
if scope and info.package_name != scope:
continue
if info.class_alias is not None and class_name == info.class_alias:
return info.class_name

return name
Expand Down
42 changes: 41 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from stpipe import Step
from stpipe.utilities import import_class, import_func
from stpipe.utilities import import_class, import_func, resolve_step_class_alias


def what_is_your_quest():
Expand All @@ -13,6 +13,8 @@ class HovercraftFullOfEels:


class Foo(Step):
class_alias = "foo_step"

def process(self, input_data):
pass

Expand Down Expand Up @@ -52,3 +54,41 @@ def test_import_class_no_module():
def test_import_func_no_module():
with pytest.raises(ImportError):
import_func("foo")


@pytest.mark.parametrize(
"name, resolve",
(
("foo_step", True),
("stpipe::foo_step", True),
("some_other_package::foo_step", False),
),
)
def test_class_alias_lookup(name, resolve, monkeypatch):
# as the test class above isn't registered via an entry point
# we mock the entry points here
class FakeDist:
name = "stpipe"
version = "dev"

class FakeEntryPoint:
dist = FakeDist()

def load(self):
def loader():
return [("Foo", "foo_step", False)]

return loader

def fake_entrypoints(group=None):
return [FakeEntryPoint()]

import importlib_metadata

monkeypatch.setattr(importlib_metadata, "entry_points", fake_entrypoints)

resolved_name = resolve_step_class_alias(name)
if resolve:
assert resolved_name == Foo.__name__
else:
assert resolved_name == name
Loading