Skip to content

Commit

Permalink
provide generic type for component subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jan 11, 2025
1 parent 2ff7d48 commit 204dcee
Show file tree
Hide file tree
Showing 18 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion docs/guide/examples/blendcomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LinearBlendConfig(BaseModel):
"""


class LinearBlendScorer(Component):
class LinearBlendScorer(Component[ItemList]):
r"""
Score items with a linear blend of two other scores.
Expand Down
2 changes: 1 addition & 1 deletion lenskit-funksvd/lenskit/funksvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def _align_add_bias(bias, index, keys, series):
return bias, series


class FunkSVDScorer(Trainable, Component):
class FunkSVDScorer(Trainable, Component[ItemList]):
"""
FunkSVD explicit-feedback matrix factoriation. FunkSVD is a regularized
biased matrix factorization technique trained with featurewise stochastic
Expand Down
2 changes: 1 addition & 1 deletion lenskit-hpf/lenskit/hpf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
_logger = logging.getLogger(__name__)


class HPFScorer(Component, Trainable):
class HPFScorer(Component[ItemList], Trainable):
"""
Hierarchical Poisson factorization, provided by
`hpfrec <https://hpfrec.readthedocs.io/en/latest/>`_.
Expand Down
2 changes: 1 addition & 1 deletion lenskit-implicit/lenskit/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class ImplicitALSConfig(ImplicitConfig, extra="allow"):
weight: float = 40.0


class BaseRec(Component, Trainable):
class BaseRec(Component[ItemList], Trainable):
"""
Base class for Implicit-backed recommenders.
Expand Down
2 changes: 1 addition & 1 deletion lenskit-sklearn/lenskit/sklearn/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BiasedSVDConfig:
n_iter: int = 5


class BiasedSVDScorer(Component, Trainable):
class BiasedSVDScorer(Component[ItemList], Trainable):
"""
Biased matrix factorization for explicit feedback using SciKit-Learn's
:class:`~sklearn.decomposition.TruncatedSVD`. It operates by first
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/als/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def to(self, device):
return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))


class ALSBase(ABC, Component, Trainable):
class ALSBase(ABC, Component[ItemList], Trainable):
"""
Base class for ALS models.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def entity_damping(self, entity: Literal["user", "item"]) -> float:
return entity_damping(self.damping, entity)


class BiasScorer(Component):
class BiasScorer(Component[ItemList]):
"""
A user-item bias rating prediction model. This component uses
:class:`BiasModel` to predict ratings for users and items.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/candidates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
_logger = logging.getLogger(__name__)


class TrainingCandidateSelectorBase(Component, Trainable):
class TrainingCandidateSelectorBase(Component[ItemList], Trainable):
"""
Base class for candidate selectors using the training data.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
_logger = logging.getLogger(__name__)


class FallbackScorer(Component):
class FallbackScorer(Component[ItemList]):
"""
Scoring component that fills in missing scores using a fallback.
Expand Down
4 changes: 2 additions & 2 deletions lenskit/lenskit/basic/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
_logger = logging.getLogger(__name__)


class UserTrainingHistoryLookup(Component, Trainable):
class UserTrainingHistoryLookup(Component[ItemList], Trainable):
"""
Look up a user's history from the training data.
Expand Down Expand Up @@ -57,7 +57,7 @@ def __str__(self):
return self.__class__.__name__


class KnownRatingScorer(Component, Trainable):
class KnownRatingScorer(Component[ItemList], Trainable):
"""
Score items by returning their values from the training data.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/popularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PopConfig(BaseModel):
"""


class PopScorer(Component, Trainable):
class PopScorer(Component[ItemList], Trainable):
"""
Score items by their popularity. Use with :py:class:`TopN` to get a
most-popular-items recommender.
Expand Down
4 changes: 2 additions & 2 deletions lenskit/lenskit/basic/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class RandomConfig(BaseModel, arbitrary_types_allowed=True):
"""


class RandomSelector(Component):
class RandomSelector(Component[ItemList]):
"""
Randomly select items from a candidate list.
Expand Down Expand Up @@ -74,7 +74,7 @@ def __call__(
return items[np.zeros(0, dtype=np.int32)]


class SoftmaxRanker(Component):
class SoftmaxRanker(Component[ItemList]):
"""
Stochastic top-N ranking with softmax sampling.
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/basic/topn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TopNConfig(BaseModel):
"""


class TopNRanker(Component):
class TopNRanker(Component[ItemList]):
"""
Rank scored items by their score and take the top *N*. The ranking length
can be passed either at runtime or at component instantiation time, with the
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/knn/item.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def explicit(self) -> bool:
return self.feedback == "explicit"


class ItemKNNScorer(Component, Trainable):
class ItemKNNScorer(Component[ItemList], Trainable):
"""
Item-item nearest-neighbor collaborative filtering feedback. This item-item
implementation is based on the description of item-based CF by
Expand Down
2 changes: 1 addition & 1 deletion lenskit/lenskit/knn/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def explicit(self) -> bool:
return self.feedback == "explicit"


class UserKNNScorer(Component, Trainable):
class UserKNNScorer(Component[ItemList], Trainable):
"""
User-user nearest-neighbor collaborative filtering with ratings. This
user-user implementation is not terribly configurable; it hard-codes design
Expand Down
6 changes: 3 additions & 3 deletions lenskit/tests/pipeline/test_component_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ class PrefixConfigPYDC:
prefix: str = "UNDEFINED"


class PrefixerDC(Component):
class PrefixerDC(Component[str]):
config: PrefixConfigDC

def __call__(self, msg: str) -> str:
return self.config.prefix + msg


class PrefixerM(Component):
class PrefixerM(Component[str]):
config: PrefixConfigM

def __call__(self, msg: str) -> str:
Expand All @@ -51,7 +51,7 @@ class PrefixerM2(PrefixerM):
config: PrefixConfigM


class PrefixerPYDC(Component):
class PrefixerPYDC(Component[str]):
config: PrefixConfigPYDC

def __call__(self, msg: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion lenskit/tests/pipeline/test_pipeline_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PrefixConfig:
prefix: str


class Prefixer(Component):
class Prefixer(Component[str]):
config: PrefixConfig

def __call__(self, msg: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion lenskit/tests/pipeline/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class PrefixConfig:
prefix: str


class Prefixer(Component):
class Prefixer(Component[str]):
config: PrefixConfig

def __call__(self, msg: str) -> str:
Expand Down

0 comments on commit 204dcee

Please sign in to comment.