Skip to content

Commit

Permalink
Add functional tests for narrow_requirement_selection
Browse files Browse the repository at this point in the history
  • Loading branch information
notatallshaw committed Aug 3, 2024
1 parent 9ab1631 commit 468f5ce
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions tests/functional/python/test_resolvers_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,24 @@ def get_dependencies(self, candidate):
return list(self._iter_dependencies(candidate))


class PythonInputProviderNarrowRequirements(PythonInputProvider):
def narrow_requirement_selection(
self, identifiers, resolutions, candidates, information, backtrack_causes
):
# Consider requirements that have 0 candidates (a resolution end point
# that can be backtracked from) or 1 candidate (speeds up situations where
# ever requirement is pinned to 1 specific version)
number_of_candidates = defaultdict(list)
for identifier in identifiers:
number_of_candidates[len(list(candidates[identifier]))].append(identifier)

min_candidates = min(number_of_candidates.keys())
if min_candidates in (0, 1):
return number_of_candidates[min_candidates]

return identifiers


INPUTS_DIR = os.path.abspath(os.path.join(__file__, "..", "inputs"))

CASE_DIR = os.path.join(INPUTS_DIR, "case")
Expand All @@ -133,20 +151,32 @@ def get_dependencies(self, candidate):
}


@pytest.fixture(
params=[
def create_params(provider_class):
return [
pytest.param(
os.path.join(CASE_DIR, n),
(os.path.join(CASE_DIR, n), provider_class),
marks=pytest.mark.xfail(strict=True, reason=XFAIL_CASES[n]),
)
if n in XFAIL_CASES
else os.path.join(CASE_DIR, n)
else (os.path.join(CASE_DIR, n), provider_class)
for n in CASE_NAMES
]


@pytest.fixture(
params=[
*create_params(PythonInputProvider),
*create_params(PythonInputProviderNarrowRequirements),
],
ids=[
f"{n[:-5]}-{cls.__name__}"
for cls in [PythonInputProvider, PythonInputProviderNarrowRequirements]
for n in CASE_NAMES
],
ids=[n[:-5] for n in CASE_NAMES],
)
def provider(request):
return PythonInputProvider(request.param)
path, provider_class = request.param
return provider_class(path)


def _format_confliction(exception):
Expand Down

0 comments on commit 468f5ce

Please sign in to comment.