diff --git a/src/class_resolver/__init__.py b/src/class_resolver/__init__.py index f527c02..1aef313 100644 --- a/src/class_resolver/__init__.py +++ b/src/class_resolver/__init__.py @@ -57,6 +57,7 @@ class is ``Algorithm`` and it can infer what you mean. RegistrationError, RegistrationNameConflict, RegistrationSynonymConflict, + SimpleResolver, ) from .func import FunctionResolver from .utils import ( @@ -93,6 +94,7 @@ class is ``Algorithm`` and it can infer what you mean. "Resolver", "ClassResolver", "FunctionResolver", + "SimpleResolver", # Utilities "get_cls", "get_subclasses", diff --git a/src/class_resolver/base.py b/src/class_resolver/base.py index df07eec..491bd25 100644 --- a/src/class_resolver/base.py +++ b/src/class_resolver/base.py @@ -312,3 +312,56 @@ def objective(trial: optuna.Trial) -> float: """ key = trial.suggest_categorical(name, sorted(self.lookup_dict)) return self.lookup(key) + + +class SimpleResolver(BaseResolver[X, X], Generic[X]): + """ + A simple resolver which uses the string representations as key. + + While very minimalistic, it can be quite handy when dealing with simple objects, e.g., + + >>> log_level_resolver = SimpleResolver(["debug", "info", "warning", "error"], default="info") + >>> log_level_resolver.make(None) + "info" + >>> r.make("WARNING") + "warning" + >>> r.make("fatal") + Traceback (most recent call last): + ... + ValueError: Invalid query=fatal. Possible queries are {"debug", "info", "warning", "error"}. + + We can also benefit from, e.g., creation of command-line options for click + + >>> log_level_option = log_level_resolver.get_option("--log-level") + + Or use the resolver to ensure a type-safe normalization + + >>> import typing + >>> LogLevel = typing.Literal["debug", "info", "warning", "error"] + >>> r: SimpleResolver[LogLevel] = SimpleResolver(["debug", "info", "warning", "error"], default="info") + """ + + # docstr-coverage: inherited + def extract_name(self, element: X) -> str: # noqa: D102 + return str(element) + + # docstr-coverage: inherited + def lookup(self, query: Hint[X], default: Optional[X] = None) -> X: # noqa: D102 + str_query = self.normalize(str(query)) + if str_query in self.lookup_dict: + return self.lookup_dict[str_query] + if query is not None: + raise ValueError(f"Invalid query={query}. Possible queries are {self.options}.") + if default is not None: + return default + if self.default is not None: + return self.default + raise ValueError( + "If query and default are None, a default must be set in the resolver, but it is None, too." + ) + + # docstr-coverage: inherited + def make(self, query, pos_kwargs: OptionalKwargs = None, **kwargs) -> X: # noqa: D102 + if pos_kwargs is not None: + raise ValueError(f"{self.__class__.__name__} does not support positional arguments.") + return self.lookup(query=query, **kwargs) diff --git a/tests/test_api.py b/tests/test_api.py index abe65c8..eb89f80 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -17,6 +17,7 @@ RegistrationNameConflict, RegistrationSynonymConflict, Resolver, + SimpleResolver, UnexpectedKeywordError, ) @@ -490,3 +491,30 @@ class AAlt3Base(Alt3Base): with self.assertRaises(TypeError) as e: resolver.make("a") self.assertEqual("surprise!", str(e.exception)) + + +class TestSimpleResolver(unittest.TestCase): + """Tests for the simple resolver.""" + + def setUp(self) -> None: + """Create test instance.""" + self.instance = SimpleResolver([0, 1, 2, 3]) + + def test_make(self): + """Test making valid objects.""" + for i in range(4): + self.assertEqual(self.instance.make(i), i) + self.assertEqual(self.instance.make(str(i)), i) + + def test_make_invalid(self): + """Test making invalid choices.""" + with self.assertRaises(ValueError): + self.instance.make(-1) + with self.assertRaises(ValueError): + self.instance.make(4) + + def test_default(self): + """Test make's interaction with default.""" + with self.assertRaises(ValueError): + self.instance.make(None) + self.assertEqual(self.instance.make(None, default=2), 2)