Skip to content

Commit 6dbdb1a

Browse files
authored
Merge pull request #573 from lenskit/fix/fallback-bias
Improve pipeline logging and builder API
2 parents ce74335 + 6b1c07f commit 6dbdb1a

File tree

7 files changed: 49 additions and 7 deletions.

lenskit/lenskit/basic/bias.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def __init__(
286286

287287
@property
288288
def is_trained(self) -> bool:
289-
return hasattr(self, "bias_")
289+
return hasattr(self, "model_")
290290

291291
def train(self, data: Dataset):
292292
"""

lenskit/lenskit/logging/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Logging, progress, and resource records.
33
"""
44

5+
import os
56
from typing import Any
67

78
import structlog
@@ -21,6 +22,7 @@
2122
]
2223

2324
get_logger = structlog.stdlib.get_logger
25+
_trace_debug = os.environ.get("LK_TRACE", "no").lower() == "debug"
2426

2527

2628
def trace(logger: structlog.stdlib.BoundLogger, *args: Any, **kwargs: Any):
@@ -32,3 +34,5 @@ def trace(logger: structlog.stdlib.BoundLogger, *args: Any, **kwargs: Any):
3234
meth = getattr(logger, "trace", None)
3335
if meth is not None:
3436
meth(*args, **kwargs)
37+
elif _trace_debug:
38+
logger.debug(*args, **kwargs)

lenskit/lenskit/pipeline/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
# pyright: strict
1212
from __future__ import annotations
1313

14-
import logging
1514
import typing
1615
import warnings
1716
from types import FunctionType, UnionType
1817
from uuid import NAMESPACE_URL, uuid4, uuid5
1918

19+
import structlog
2020
from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, cast, overload
2121

2222
from lenskit.data import Dataset
@@ -51,7 +51,7 @@
5151
"topn_pipeline",
5252
]
5353

54-
_log = logging.getLogger(__name__)
54+
_log = structlog.stdlib.get_logger(__name__)
5555

5656
# common type var for quick use
5757
T = TypeVar("T")
@@ -707,6 +707,7 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState:
707707

708708
runner = PipelineRunner(self, kwargs)
709709
node_list = [self.node(n) for n in nodes]
710+
_log.debug("running pipeline", name=self.name, nodes=[n.name for n in node_list])
710711
if not node_list:
711712
node_list = self.nodes
712713

lenskit/lenskit/pipeline/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def candidate_selector(self, sel: Component):
5858
self._selector = sel
5959

6060
def predicts_ratings(
61-
self, transform: Component | None = None, *, fallback: Component | None = None
61+
self, *, transform: Component | None = None, fallback: Component | None = None
6262
):
6363
"""
6464
Specify that this pipeline will predict ratings, optionally providing a

lenskit/lenskit/pipeline/components.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from __future__ import annotations
1111

1212
import inspect
13+
import json
1314
from abc import abstractmethod
1415
from importlib import import_module
1516
from types import FunctionType
@@ -214,6 +215,10 @@ def __call__(self, **kwargs: Any) -> COut:
214215
"""
215216
...
216217

218+
def __repr__(self) -> str:
219+
params = json.dumps(self.get_config(), indent=2)
220+
return f"<{self.__class__.__name__} {params}>"
221+
217222

218223
def instantiate_component(
219224
comp: str | type | FunctionType, config: dict[str, Any] | None

lenskit/lenskit/pipeline/runner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def run(self, node: Node[Any], *, required: bool = True) -> Any:
5959
elif status == "failed": # pragma: nocover
6060
raise RuntimeError(f"{node} previously failed")
6161

62-
trace(self.log, "processing node %s", node)
62+
trace(self.log, "processing node", node=node.name)
6363
self.status[node.name] = "in-progress"
6464
try:
6565
self._run_node(node, required)
@@ -96,6 +96,7 @@ def _inject_input(self, name: str, types: set[type] | None, required: bool) -> N
9696
if val is not None and types and not is_compatible_data(val, *types):
9797
raise TypeError(f"invalid data for input {name} (expected {types}, got {type(val)})")
9898

99+
trace(self.log, "injecting input", name=name, value=val)
99100
self.state[name] = val
100101

101102
def _run_component(
@@ -107,7 +108,7 @@ def _run_component(
107108
required: bool,
108109
) -> None:
109110
in_data = {}
110-
log = self.log.bind(component=name)
111+
log = self.log.bind(node=name)
111112
trace(log, "processing inputs")
112113
for iname, itype in inputs.items():
113114
# look up the input wiring for this parameter input
@@ -158,7 +159,7 @@ def _run_component(
158159

159160
in_data[iname] = ival
160161

161-
trace(log, "running component")
162+
trace(log, "running component", component=comp)
162163
self.state[name] = comp(**in_data)
163164

164165

lenskit/tests/basic/test_composite.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Licensed under the MIT license, see LICENSE.md for details.
55
# SPDX-License-Identifier: MIT
66

7+
import logging
78
import pickle
89
from typing import Any
910

@@ -20,9 +21,13 @@
2021
from lenskit.data import Dataset
2122
from lenskit.data.items import ItemList
2223
from lenskit.data.types import ID
24+
from lenskit.operations import predict, score
2325
from lenskit.pipeline import Pipeline
26+
from lenskit.pipeline.common import RecPipelineBuilder
2427
from lenskit.util.test import ml_ds, ml_ratings # noqa: F401
2528

29+
_log = logging.getLogger(__name__)
30+
2631

2732
def test_fallback_fill_missing(ml_ds: Dataset):
2833
pipe = Pipeline()
@@ -51,3 +56,29 @@ def test_fallback_fill_missing(ml_ds: Dataset):
5156

5257
assert scores[:2] == approx(known(2, ItemList(item_ids=items[:2])).scores())
5358
assert scores[2:] == approx(bias(2, ItemList(item_ids=items[2:])).scores())
59+
60+
61+
def test_fallback_double_bias(rng: np.random.Generator, ml_ds: Dataset):
62+
builder = RecPipelineBuilder()
63+
builder.scorer(BiasScorer(damping=50))
64+
builder.predicts_ratings(fallback=BiasScorer(damping=0))
65+
pipe = builder.build("double-bias")
66+
67+
_log.info("pipeline configuration: %s", pipe.get_config().model_dump_json(indent=2))
68+
69+
pipe.train(ml_ds)
70+
71+
for user in rng.choice(ml_ds.users.ids(), 100):
72+
items = rng.choice(ml_ds.items.ids(), 500)
73+
scores = score(pipe, user, items)
74+
scores = scores.scores()
75+
assert scores is not None
76+
assert not np.any(np.isnan(scores))
77+
78+
preds = predict(pipe, user, items)
79+
80+
preds = preds.scores()
81+
assert preds is not None
82+
assert not np.any(np.isnan(preds))
83+
84+
assert scores == approx(preds)

0 commit comments

Comments
 (0)