From 468f5cec92253fb619a2ff18c38bcc90b6234c2b Mon Sep 17 00:00:00 2001 From: Damian Shaw Date: Sat, 3 Aug 2024 12:22:00 -0400 Subject: [PATCH] Add functional tests for narrow_requirement_selection --- .../python/test_resolvers_python.py | 42 ++++++++++++++++--- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/functional/python/test_resolvers_python.py b/tests/functional/python/test_resolvers_python.py index 18c1550..c1e3038 100644 --- a/tests/functional/python/test_resolvers_python.py +++ b/tests/functional/python/test_resolvers_python.py @@ -121,6 +121,24 @@ def get_dependencies(self, candidate): return list(self._iter_dependencies(candidate)) +class PythonInputProviderNarrowRequirements(PythonInputProvider): + def narrow_requirement_selection( + self, identifiers, resolutions, candidates, information, backtrack_causes + ): + # Consider requirements that have 0 candidates (a resolution end point + # that can be backtracked from) or 1 candidate (speeds up situations where + # ever requirement is pinned to 1 specific version) + number_of_candidates = defaultdict(list) + for identifier in identifiers: + number_of_candidates[len(list(candidates[identifier]))].append(identifier) + + min_candidates = min(number_of_candidates.keys()) + if min_candidates in (0, 1): + return number_of_candidates[min_candidates] + + return identifiers + + INPUTS_DIR = os.path.abspath(os.path.join(__file__, "..", "inputs")) CASE_DIR = os.path.join(INPUTS_DIR, "case") @@ -133,20 +151,32 @@ def get_dependencies(self, candidate): } -@pytest.fixture( - params=[ +def create_params(provider_class): + return [ pytest.param( - os.path.join(CASE_DIR, n), + (os.path.join(CASE_DIR, n), provider_class), marks=pytest.mark.xfail(strict=True, reason=XFAIL_CASES[n]), ) if n in XFAIL_CASES - else os.path.join(CASE_DIR, n) + else (os.path.join(CASE_DIR, n), provider_class) + for n in CASE_NAMES + ] + + +@pytest.fixture( + params=[ + *create_params(PythonInputProvider), + *create_params(PythonInputProviderNarrowRequirements), + ], + ids=[ + f"{n[:-5]}-{cls.__name__}" + for cls in [PythonInputProvider, PythonInputProviderNarrowRequirements] for n in CASE_NAMES ], - ids=[n[:-5] for n in CASE_NAMES], ) def provider(request): - return PythonInputProvider(request.param) + path, provider_class = request.param + return provider_class(path) def _format_confliction(exception):