Skip to content

Commit 1733788

Browse files
committed
fix: error with checking the imported module
1 parent 8d85d69 commit 1733788

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

src/pytest_cython/plugin.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from _pytest.nodes import Collector
1313
from _pytest.doctest import skip, DoctestModule, DoctestItem
14-
from _pytest.pathlib import ImportMode
14+
from _pytest.pathlib import resolve_package_path, ImportMode
1515

1616

1717
CYTHON_SUFFIXES = ['.py', '.pyx']
@@ -74,23 +74,37 @@ def collect(self) -> Iterable[DoctestItem]:
7474
return _add_line_numbers(module, items)
7575

7676

77-
def _true_stem(path: str | pathlib.Path) -> str:
78-
stem = pathlib.Path(path).stem
79-
if stem == path:
80-
return stem
77+
def _without_suffixes(path: str | pathlib.Path) -> pathlib.Path:
78+
path = pathlib.Path(path)
79+
return path.with_name(path.name.split('.')[0]).with_suffix('')
8180

82-
return _true_stem(stem)
81+
82+
def _get_module_name(path: pathlib.Path) -> str:
83+
pkg_path = resolve_package_path(path)
84+
if pkg_path is not None:
85+
pkg_root = pkg_path.parent
86+
names = list(path.with_suffix("").relative_to(pkg_root).parts)
87+
if names[-1] == "__init__":
88+
names.pop()
89+
module_name = ".".join(names)
90+
else:
91+
pkg_root = path.parent
92+
module_name = path.stem
93+
94+
return module_name
8395

8496

8597
def _check_module_import(module: Any, path: pathlib.Path, mode: ImportMode) -> None:
98+
# double check that the only difference is the extension else raise an exception
99+
86100
if mode is ImportMode.importlib or IGNORE_IMPORTMISMATCH == "1":
87101
return
88102

89-
# double check that the only difference is the extension else raise an exception
90-
module_file = _true_stem(module.__file__)
91-
module_name = _true_stem(path)
103+
module_name = _get_module_name(path)
104+
module_file = _without_suffixes(module.__file__)
105+
import_file = _without_suffixes(path)
92106

93-
if pathlib.Path(module_file) == pathlib.Path(module_name):
107+
if module_file == import_file:
94108
return
95109

96110
raise Collector.CollectError(
@@ -100,7 +114,7 @@ def _check_module_import(module: Any, path: pathlib.Path, mode: ImportMode) -> N
100114
"which is not the same as the test file we want to collect:\n"
101115
" %s\n"
102116
"HINT: remove __pycache__ / .pyc files and/or use a "
103-
"unique basename for your test file modules" % (module_name, module_file, path)
117+
"unique basename for your test file modules" % (module_name, module_file, import_file)
104118
)
105119

106120

0 commit comments

Comments
 (0)