diff --git a/src/resolvelib/reporters.py b/src/resolvelib/reporters.py index 9a6bfad..b000043 100644 --- a/src/resolvelib/reporters.py +++ b/src/resolvelib/reporters.py @@ -61,5 +61,5 @@ def rejecting_candidate( def pinning(self, candidate: CT) -> None: """Called when adding a candidate to the potential solution.""" - def fallback(self, from_: str, to: str) -> None: - """Called when falling back from one backtrack strategy to another.""" + def fallback(self) -> None: + """Called when falling back from backjumping to backtracking.""" diff --git a/src/resolvelib/resolvers.py b/src/resolvelib/resolvers.py index 8cae235..8fcff6d 100644 --- a/src/resolvelib/resolvers.py +++ b/src/resolvelib/resolvers.py @@ -31,9 +31,6 @@ if TYPE_CHECKING: from .providers import Preference - from typing_extensions import Literal - - BacktrackStrategy = Literal["backjump_fallback", "backjump", "backtrack"] class Result(NamedTuple, Generic[RT, CT, KT]): mapping: Mapping[KT, CT] @@ -104,11 +101,10 @@ def __init__( self, provider: AbstractProvider[RT, CT, KT], reporter: BaseReporter[RT, CT, KT], - backtrack_strategy: BacktrackStrategy = "backjump_fallback", ) -> None: self._p = provider self._r = reporter - self._backtrack_strategy: BacktrackStrategy = backtrack_strategy + self._started_fallback = False self._fallback_states: Optional[list[State[RT, CT, KT]]] = None self._states: list[State[RT, CT, KT]] = [] @@ -295,10 +291,7 @@ def _backjump_iteration( incompatible_state = False name, candidate, broken_state = None, None, None - if ( - self._backtrack_strategy == "backjump_fallback" - and self._fallback_states is None - ): + if self._fallback_states is None: fallback_states = [ State( s.mapping.copy(), @@ -330,11 +323,7 @@ def _backjump_iteration( # Backup states first time a backjump goes # further than a backtrack would have - if ( - self._backtrack_strategy == "backjump_fallback" - and self._fallback_states is None - and backjump_count == 2 - ): + if self._fallback_states is None and backjump_count == 2: self._fallback_states = fallback_states if name is None or candidate is None or broken_state is None: @@ -400,7 +389,7 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool: name, candidate, incompatibilities_from_broken = None, None, None - if self._backtrack_strategy in ("backjump", "backjump_fallback"): + if not self._started_fallback: try: ( name, @@ -410,18 +399,15 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool: causes=causes, incompatible_deps=incompatible_deps ) except ResolutionImpossible: - if ( - self._backtrack_strategy == "backjump" - or self._fallback_states is None - ): + if self._fallback_states is None: raise # Backjumping failed but fallback to backtracking was requested self._states = self._fallback_states - self._backtrack_strategy = "backtrack" - self._r.fallback("backjump_fallback", "backtrack") + self._started_fallback = True + self._r.fallback() - if self._backtrack_strategy == "backtrack": + if self._started_fallback: ( name, candidate,