Skip to content

Commit 8418807

Browse files
authored
Merge pull request #576 from lenskit/feature/common-tests
Implement and use common component test suites
2 parents 8146c6a + 9e79952 commit 8418807

File tree

24 files changed: 362 additions and 24 deletions.

lenskit-funksvd/lenskit/funksvd.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -238,7 +238,7 @@ class FunkSVDScorer(Component, Trainable):
238238

239239
def __init__(
240240
self,
241-
features: int,
241+
features: int = 50,
242242
iterations: int = 100,
243243
*,
244244
lrate: float = 0.001,
@@ -314,6 +314,7 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
314314
query = RecQuery.create(query)
315315

316316
user_id = query.user_id
317+
user_num = None
317318
if user_id is not None:
318319
user_num = self.users_.number(user_id, missing=None)
319320
if user_num is None:

lenskit-funksvd/tests/test_funksvd.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from lenskit.data.bulk import dict_to_df, iter_item_lists
1818
from lenskit.funksvd import FunkSVDScorer
1919
from lenskit.metrics import call_metric, quick_measure_model
20-
from lenskit.testing import ml_100k, ml_ds, wantjit # noqa: F401
20+
from lenskit.testing import BasicComponentTests, ScorerTests, wantjit
2121

2222
_log = logging.getLogger(__name__)
2323

@@ -27,6 +27,10 @@
2727
simple_ds = from_interactions_df(simple_df)
2828

2929

30+
class TestFunkSVD(BasicComponentTests, ScorerTests):
31+
component = FunkSVDScorer
32+
33+
3034
def test_fsvd_basic_build():
3135
algo = FunkSVDScorer(20, iterations=20)
3236
algo.train(simple_ds)

lenskit-hpf/lenskit/hpf.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ class HPFScorer(Component, Trainable):
3939
items_: Vocabulary
4040
item_features_: np.ndarray[tuple[int, int], np.dtype[np.float64]]
4141

42-
def __init__(self, features: int, **kwargs):
42+
def __init__(self, features: int = 50, **kwargs):
4343
self.features = features
4444
self._kwargs = kwargs
4545

@@ -78,6 +78,7 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
7878
query = RecQuery.create(query)
7979

8080
user_id = query.user_id
81+
user_num = None
8182
if user_id is not None:
8283
user_num = self.users_.number(user_id, missing=None)
8384
if user_num is None:

lenskit-hpf/tests/test_hpf.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,17 @@
1515
from lenskit.data import ItemList, from_interactions_df
1616
from lenskit.metrics import quick_measure_model
1717
from lenskit.pipeline import topn_pipeline
18+
from lenskit.testing import BasicComponentTests, ScorerTests
1819

1920
hpf = importorskip("lenskit.hpf")
2021

2122
_log = logging.getLogger(__name__)
2223

2324

25+
class TestHPF(BasicComponentTests, ScorerTests):
26+
component = hpf.HPFScorer
27+
28+
2429
@mark.slow
2530
def test_hpf_train_large(tmp_path, ml_ratings):
2631
algo = hpf.HPFScorer(20)

lenskit-implicit/tests/test_implicit.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,10 +15,19 @@
1515
from lenskit.data import ItemList, from_interactions_df
1616
from lenskit.implicit import ALS, BPR
1717
from lenskit.metrics import quick_measure_model
18+
from lenskit.testing import BasicComponentTests, ScorerTests
1819

1920
_log = logging.getLogger(__name__)
2021

2122

23+
class TestImplicitALS(BasicComponentTests, ScorerTests):
24+
component = ALS
25+
26+
27+
class TestImplicitBPR(BasicComponentTests, ScorerTests):
28+
component = BPR
29+
30+
2231
@mark.slow
2332
def test_implicit_als_train_rec(ml_ds):
2433
algo = ALS(25)

lenskit-sklearn/lenskit/sklearn/svd.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ class BiasedSVDScorer(Component, Trainable):
5353

5454
def __init__(
5555
self,
56-
features: int,
56+
features: int = 50,
5757
*,
5858
damping: UITuple[float] | float | tuple[float, float] = 5,
5959
algorithm: Literal["arpack", "randomized"] = "randomized",

lenskit-sklearn/tests/test_svd.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from lenskit.data import Dataset, ItemList, from_interactions_df
1616
from lenskit.metrics import call_metric, quick_measure_model
1717
from lenskit.sklearn import svd
18+
from lenskit.testing import BasicComponentTests, ScorerTests
1819

1920
_log = logging.getLogger(__name__)
2021

@@ -26,6 +27,10 @@
2627
need_skl = mark.skipif(not svd.SKL_AVAILABLE, reason="scikit-learn not installed")
2728

2829

30+
class TestBiasedSVD(BasicComponentTests, ScorerTests):
31+
component = svd.BiasedSVDScorer
32+
33+
2934
@need_skl
3035
def test_svd_basic_build():
3136
algo = svd.BiasedSVDScorer(2)

lenskit/lenskit/als/_explicit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ class BiasedMFScorer(ALSBase):
5858

5959
def __init__(
6060
self,
61-
features: int,
61+
features: int = 50,
6262
*,
6363
epochs: int = 10,
6464
reg: float | tuple[float, float] = 0.1,

lenskit/lenskit/als/_implicit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ class ImplicitMFScorer(ALSBase):
8181

8282
def __init__(
8383
self,
84-
features: int,
84+
features: int = 50,
8585
*,
8686
epochs: int = 20,
8787
reg: float | tuple[float, float] = 0.1,

lenskit/lenskit/data/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
from .convert import from_interactions_df
1414
from .dataset import Dataset, FieldError
1515
from .items import ItemList
16+
from .lazy import LazyDataset
17+
from .matrix import MatrixDataset
1618
from .movielens import load_movielens, load_movielens_df
1719
from .mtarray import MTArray, MTFloatArray, MTGenericArray, MTIntArray
1820
from .query import QueryInput, RecQuery
@@ -23,6 +25,8 @@
2325
"Dataset",
2426
"FieldError",
2527
"from_interactions_df",
28+
"LazyDataset",
29+
"MatrixDataset",
2630
"ID",
2731
"NPID",
2832
"UITuple",

lenskit/lenskit/knn/item.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ class ItemKNNScorer(Component, Trainable):
7979

8080
def __init__(
8181
self,
82-
nnbrs: int,
82+
nnbrs: int = 20,
8383
min_nbrs: int = 1,
8484
min_sim: float = 1.0e-6,
8585
save_nbrs: int | None = None,
@@ -202,7 +202,10 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
202202
ratings = query.user_items
203203
if ratings is None:
204204
if query.user_id is None:
205-
raise ValueError("cannot recommend without without either user ID or items")
205+
warnings.warn(
206+
"cannot recommend without without either user ID or items", DataWarning
207+
)
208+
return ItemList(items, scores=np.nan)
206209

207210
upos = self.users_.number(query.user_id, missing=None)
208211
if upos is None:

lenskit/lenskit/knn/user.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class UserKNNScorer(Component, Trainable):
7979

8080
def __init__(
8181
self,
82-
nnbrs: int,
82+
nnbrs: int = 20,
8383
min_nbrs: int = 1,
8484
min_sim: float = 1.0e-6,
8585
feedback: FeedbackType = "explicit",
@@ -155,6 +155,9 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
155155
query = RecQuery.create(query)
156156
watch = util.Stopwatch()
157157
log = _log.bind(user_id=query.user_id, n_items=len(items))
158+
if len(items) == 0:
159+
log.debug("no candidate items, skipping")
160+
return ItemList(items, scores=np.nan)
158161

159162
udata = self._get_user_data(query)
160163
if udata is None:

lenskit/lenskit/testing/__init__.py

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,9 @@
88
import os
99
from contextlib import contextmanager
1010

11-
import pytest
12-
1311
from ._arrays import coo_arrays, scored_lists, sparse_arrays, sparse_tensors
12+
from ._components import BasicComponentTests, ScorerTests
13+
from ._markers import jit_enabled, wantjit
1414
from ._movielens import (
1515
demo_recs,
1616
ml_100k,
@@ -36,16 +36,10 @@
3636
"wantjit",
3737
"jit_enabled",
3838
"set_env_var",
39+
"BasicComponentTests",
40+
"ScorerTests",
3941
]
4042

41-
jit_enabled = True
42-
if "NUMBA_DISABLE_JIT" in os.environ:
43-
jit_enabled = False
44-
if os.environ.get("PYTORCH_JIT", None) == "0":
45-
jit_enabled = False
46-
47-
wantjit = pytest.mark.skipif(not jit_enabled, reason="JIT required")
48-
4943

5044
@contextmanager
5145
def set_env_var(var, val):

0 commit comments

Comments (0)