Skip to content

Commit abcab4c

Browse files
Only check functions for source module
1 parent 1d617e4 commit abcab4c

File tree

4 files changed

+31
-3
lines changed

4 files changed

+31
-3
lines changed

src/blueapi/utils/modules.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def get_named_subset(names: list[str]):
3737
yield value
3838

3939

40-
def is_sourced_from_module(obj: Any, module: ModuleType) -> bool:
40+
def is_function_sourced_from_module(obj: Any, module: ModuleType) -> bool:
4141
"""
4242
Check if an object is originally from a particular module, useful to detect
4343
whether it actually comes from a nested import.
@@ -46,4 +46,6 @@ def is_sourced_from_module(obj: Any, module: ModuleType) -> bool:
4646
obj: Object to check
4747
module: Module to check against object
4848
"""
49-
return importlib.import_module(obj.__module__) is module
49+
return (
50+
hasattr(obj, "__module__") and importlib.import_module(obj.__module__) is module
51+
)

tests/unit_tests/utils/functions_a.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def a(): ...
2+
3+
4+
def b(): ...

tests/unit_tests/utils/functions_b.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def c(): ...
2+
3+
4+
def d(): ...
5+
6+
7+
e = 1

tests/unit_tests/utils/test_modules.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from importlib import import_module
22

3-
from blueapi.utils import load_module_all
3+
from blueapi.utils import is_sourced_from_module, load_module_all
44

55

66
def test_imports_all():
@@ -11,3 +11,18 @@ def test_imports_all():
1111
def test_imports_everything_without_all():
1212
module = import_module(".lacksall", package="tests.unit_tests.utils")
1313
assert list(load_module_all(module)) == [3, "hello", 9]
14+
15+
16+
def test_source_is_in_module():
17+
module = import_module(".functions_b", package="tests.unit_tests.utils")
18+
assert is_sourced_from_module(module.__dict__["c"], module)
19+
20+
21+
def test_source_is_not_in_module():
22+
module = import_module(".functions_b", package="tests.unit_tests.utils")
23+
assert not is_sourced_from_module(module.__dict__["a"], module)
24+
25+
26+
def test_source_check_on_non_function():
27+
module = import_module(".functions_b", package="tests.unit_tests.utils")
28+
assert not is_sourced_from_module(module.__dict__["e"], module)

0 commit comments

Comments
 (0)