Skip to content
2 changes: 2 additions & 0 deletions src/class_resolver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class is ``Algorithm`` and it can infer what you mean.
RegistrationError,
RegistrationNameConflict,
RegistrationSynonymConflict,
SimpleResolver,
)
from .func import FunctionResolver
from .utils import (
Expand Down Expand Up @@ -93,6 +94,7 @@ class is ``Algorithm`` and it can infer what you mean.
"Resolver",
"ClassResolver",
"FunctionResolver",
"SimpleResolver",
# Utilities
"get_cls",
"get_subclasses",
Expand Down
53 changes: 53 additions & 0 deletions src/class_resolver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,56 @@ def objective(trial: optuna.Trial) -> float:
"""
key = trial.suggest_categorical(name, sorted(self.lookup_dict))
return self.lookup(key)


class SimpleResolver(BaseResolver[X, X], Generic[X]):
"""
A simple resolver which uses the string representations as key.

While very minimalistic, it can be quite handy when dealing with simple objects, e.g.,

>>> log_level_resolver = SimpleResolver(["debug", "info", "warning", "error"], default="info")
>>> log_level_resolver.make(None)
"info"
>>> r.make("WARNING")
"warning"
>>> r.make("fatal")
Traceback (most recent call last):
...
ValueError: Invalid query=fatal. Possible queries are {"debug", "info", "warning", "error"}.

We can also benefit from, e.g., creation of command-line options for click

>>> log_level_option = log_level_resolver.get_option("--log-level")

Or use the resolver to ensure a type-safe normalization

>>> import typing
>>> LogLevel = typing.Literal["debug", "info", "warning", "error"]
>>> r: SimpleResolver[LogLevel] = SimpleResolver(["debug", "info", "warning", "error"], default="info")
"""

# docstr-coverage: inherited
def extract_name(self, element: X) -> str: # noqa: D102
return str(element)

# docstr-coverage: inherited
def lookup(self, query: Hint[X], default: Optional[X] = None) -> X: # noqa: D102
str_query = self.normalize(str(query))
if str_query in self.lookup_dict:
return self.lookup_dict[str_query]
if query is not None:
raise ValueError(f"Invalid query={query}. Possible queries are {self.options}.")
if default is not None:
return default
if self.default is not None:
return self.default
raise ValueError(
"If query and default are None, a default must be set in the resolver, but it is None, too."
)

# docstr-coverage: inherited
def make(self, query, pos_kwargs: OptionalKwargs = None, **kwargs) -> X: # noqa: D102
if pos_kwargs is not None:
raise ValueError(f"{self.__class__.__name__} does not support positional arguments.")
return self.lookup(query=query, **kwargs)
28 changes: 28 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RegistrationNameConflict,
RegistrationSynonymConflict,
Resolver,
SimpleResolver,
UnexpectedKeywordError,
)

Expand Down Expand Up @@ -490,3 +491,30 @@ class AAlt3Base(Alt3Base):
with self.assertRaises(TypeError) as e:
resolver.make("a")
self.assertEqual("surprise!", str(e.exception))


class TestSimpleResolver(unittest.TestCase):
"""Tests for the simple resolver."""

def setUp(self) -> None:
"""Create test instance."""
self.instance = SimpleResolver([0, 1, 2, 3])

def test_make(self):
"""Test making valid objects."""
for i in range(4):
self.assertEqual(self.instance.make(i), i)
self.assertEqual(self.instance.make(str(i)), i)

def test_make_invalid(self):
"""Test making invalid choices."""
with self.assertRaises(ValueError):
self.instance.make(-1)
with self.assertRaises(ValueError):
self.instance.make(4)

def test_default(self):
"""Test make's interaction with default."""
with self.assertRaises(ValueError):
self.instance.make(None)
self.assertEqual(self.instance.make(None, default=2), 2)