Skip to content

Commit 21d7352

Browse files
Make source check more efficient
1 parent 30acc5b commit 21d7352

File tree

4 files changed

+13
-17
lines changed

4 files changed

+13
-17
lines changed

src/blueapi/core/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ def plan_2(...) -> MsgGenerator:
108108
# they are valid plans, unless there is an __all__ defined in the module,
109109
# in which case we only inspect objects listed there, regardless of their
110110
# original source module.
111-
if (
111+
if is_bluesky_plan_generator(obj) and (
112112
hasattr(module, "__all__")
113113
or is_function_sourced_from_module(obj, module)
114-
) and is_bluesky_plan_generator(obj):
114+
):
115115
self.register_plan(obj)
116116

117117
def with_device_module(self, module: ModuleType) -> None:

src/blueapi/utils/modules.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import importlib
2-
from collections.abc import Iterable
2+
from collections.abc import Callable, Iterable
33
from types import ModuleType
44
from typing import Any
55

@@ -37,7 +37,9 @@ def get_named_subset(names: list[str]):
3737
yield value
3838

3939

40-
def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool:
40+
def is_function_sourced_from_module(
41+
obj: Callable[..., Any], module: ModuleType
42+
) -> bool:
4143
"""
4244
Check if an object is originally from a particular module, useful to detect
4345
whether it actually comes from a nested import.
@@ -46,6 +48,4 @@ def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool:
4648
obj: Object to check
4749
module: Module to check against object
4850
"""
49-
return (
50-
hasattr(obj, "__module__") and importlib.import_module(obj.__module__) is module
51-
)
51+
return importlib.import_module(obj.__module__) is module

tests/unit_tests/utils/functions_b.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,3 @@ def c(): ...
55

66

77
def d(): ...
8-
9-
10-
e = 1

tests/unit_tests/utils/test_modules.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@ def test_imports_everything_without_all():
1515

1616
def test_source_is_in_module():
1717
module = import_module(".functions_b", package="tests.unit_tests.utils")
18-
assert is_function_sourced_from_module(module.__dict__["c"], module)
18+
c = module.__dict__["c"]
19+
assert callable(c)
20+
assert is_function_sourced_from_module(c, module)
1921

2022

2123
def test_source_is_not_in_module():
2224
module = import_module(".functions_b", package="tests.unit_tests.utils")
23-
assert not is_function_sourced_from_module(module.__dict__["a"], module)
24-
25-
26-
def test_source_check_on_non_function():
27-
module = import_module(".functions_b", package="tests.unit_tests.utils")
28-
assert not is_function_sourced_from_module(module.__dict__["e"], module)
25+
a = module.__dict__["a"]
26+
assert callable(a)
27+
assert not is_function_sourced_from_module(a, module)

0 commit comments

Comments
 (0)