
Commit

Merge pull request #579 from mdekstrand/feature/knn-accel
Modest user-KNN speedups + item KNN logging
mdekstrand authored Dec 27, 2024
2 parents 88fc375 + 8a07e3c commit b193bd1
Showing 2 changed files with 78 additions and 78 deletions.
61 changes: 26 additions & 35 deletions lenskit/lenskit/knn/item.py
@@ -11,24 +11,25 @@
# pyright: basic
from __future__ import annotations

import logging
import warnings

import numpy as np
import structlog
import torch
from scipy.sparse import csr_array
from typing_extensions import Optional, override

from lenskit import util
from lenskit.data import Dataset, FeedbackType, ItemList, QueryInput, RecQuery, Vocabulary
from lenskit.diagnostics import DataWarning
from lenskit.logging import trace
from lenskit.logging.progress import item_progress_handle, pbh_update
from lenskit.math.sparse import normalize_sparse_rows, safe_spmv
from lenskit.parallel import ensure_parallel_init
from lenskit.pipeline import Component, Trainable
from lenskit.util.torch import inference_mode

_log = logging.getLogger(__name__)
_log = structlog.stdlib.get_logger(__name__)
MAX_BLOCKS = 1024
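
The move from logging.getLogger to structlog.stdlib.get_logger is what enables the bound-logger pattern used through the rest of this diff: bind() returns a logger that attaches key-value context to every event it emits, so values are stated once instead of being interpolated into each message. A minimal sketch of the pattern (the event text is illustrative):

    import structlog

    _log = structlog.stdlib.get_logger(__name__)

    def train(n_items: int, feedback: str) -> None:
        # bind() attaches these fields to every event from this logger
        log = _log.bind(n_items=n_items, feedback=feedback)
        log.info("beginning training")       # carries n_items and feedback
        log.debug("similarity matrix done")  # same context, stated once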


@@ -130,56 +131,53 @@ def train(self, data: Dataset):
(user,item,rating) data for computing item similarities.
"""
ensure_parallel_init()
log = _log.bind(n_items=data.item_count, feedback=self.feedback)
# Training proceeds in 2 steps:
# 1. Normalize item vectors to be mean-centered and unit-normalized
# 2. Compute similarities with pairwise dot products
self._timer = util.Stopwatch()
_log.info("training IKNN for %d users in %s feedback mode", data.item_count, self.feedback)

_log.debug("[%s] beginning fit, memory use %s", self._timer, util.max_memory())
log.info("beginning IKNN training")

field = "rating" if self.feedback == "explicit" else None
init_rmat = data.interaction_matrix("torch", field=field)
n_items = data.item_count
_log.info(
"[%s] made sparse matrix for %d items (%d ratings from %d users)",
log.info(
"[%s] made sparse matrix",
self._timer,
n_items,
len(init_rmat.values()),
data.user_count,
n_ratings=len(init_rmat.values()),
n_users=data.user_count,
)
_log.debug("[%s] made matrix, memory use %s", self._timer, util.max_memory())

# we operate on *transposed* rating matrix: items on the rows
rmat = init_rmat.transpose(0, 1).to_sparse_csr().to(torch.float64)

if self.feedback == "explicit":
rmat, means = normalize_sparse_rows(rmat, "center")
if np.allclose(rmat.values(), 0.0):
_log.warning("normalized ratings are zero, centering is not recommended")
log.warning("normalized ratings are zero, centering is not recommended")
warnings.warn(
"Ratings seem to have the same value, centering is not recommended.",
DataWarning,
)
else:
means = None
_log.debug("[%s] centered, memory use %s", self._timer, util.max_memory())
log.debug("[%s] centered, memory use %s", self._timer, util.max_memory())

rmat, _norms = normalize_sparse_rows(rmat, "unit")
_log.debug("[%s] normalized, memory use %s", self._timer, util.max_memory())
log.debug("[%s] normalized, memory use %s", self._timer, util.max_memory())

_log.info("[%s] computing similarity matrix", self._timer)
log.info("[%s] computing similarity matrix", self._timer)
smat = self._compute_similarities(rmat)
_log.debug("[%s] computed, memory use %s", self._timer, util.max_memory())
log.debug("[%s] computed, memory use %s", self._timer, util.max_memory())

_log.info(
log.info(
"[%s] got neighborhoods for %d of %d items",
self._timer,
np.sum(np.diff(smat.crow_indices()) > 0),
n_items,
)

_log.info("[%s] computed %d neighbor pairs", self._timer, len(smat.col_indices()))
log.info("[%s] computed %d neighbor pairs", self._timer, len(smat.col_indices()))

self.items_ = data.items
self.item_means_ = means.numpy() if means is not None else None
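
The two training steps named in the comment above (mean-center and unit-normalize the item vectors, then take pairwise dot products) amount to computing adjusted-cosine similarity. A dense NumPy sketch of the math, not the component's actual sparse torch implementation:

    import numpy as np

    def item_cosine_similarities(ratings: np.ndarray) -> np.ndarray:
        # ratings: items x users array, with 0 marking "unrated"
        rated = ratings != 0
        # step 1a: center each item's ratings on that item's mean
        means = ratings.sum(axis=1) / np.maximum(rated.sum(axis=1), 1)
        centered = np.where(rated, ratings - means[:, None], 0.0)
        # step 1b: scale each item vector to unit length
        norms = np.maximum(np.linalg.norm(centered, axis=1), 1e-12)
        unit = centered / norms[:, None]
        # step 2: pairwise dot products of unit vectors = cosine similarity
        return unit @ unit.T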
@@ -188,7 +186,7 @@ def train(self, data: Dataset):
(smat.values(), smat.col_indices(), smat.crow_indices()), smat.shape
)
self.users_ = data.users
_log.debug("[%s] done, memory use %s", self._timer, util.max_memory())
log.debug("[%s] done, memory use %s", self._timer, util.max_memory())

def _compute_similarities(self, rmat: torch.Tensor) -> torch.Tensor:
nitems, nusers = rmat.shape
Expand All @@ -204,13 +202,14 @@ def _compute_similarities(self, rmat: torch.Tensor) -> torch.Tensor:
@inference_mode
def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
query = RecQuery.create(query)
_log.debug("predicting %d items for user %s", len(items), query.user_id)
log = _log.bind(user_id=query.user_id, n_items=len(items))
trace(log, "beginning prediction")

ratings = query.user_items
if ratings is None or len(ratings) == 0:
if ratings is None:
warnings.warn("no user history, did you omit a history component?", DataWarning)
_log.debug("user has no history, returning")
log.debug("user has no history, returning")
return ItemList(items, scores=np.nan)

# set up rating array
@@ -219,7 +218,7 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
ri_mask = ri_nums >= 0
ri_valid_nums = ri_nums[ri_mask]
n_valid = len(ri_valid_nums)
_log.debug("user %s: %d of %d rated items in model", query.user_id, n_valid, len(ratings))
trace(log, "%d of %d rated items in model", n_valid, len(ratings))

if self.feedback == "explicit":
ri_vals = ratings.field("rating", "numpy")
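
The masking above leans on unknown items mapping to a negative index, so only rated items that exist in the trained model contribute to scoring (the trace() calls replace debug logging in this hot path). A small NumPy sketch of that bookkeeping, with a plain dict standing in for the model's item vocabulary:

    import numpy as np

    vocab = {"a": 0, "b": 1, "c": 2}    # stand-in item vocabulary
    rated = ["a", "x", "c"]             # "x" never appeared in training
    nums = np.array([vocab.get(i, -1) for i in rated])
    mask = nums >= 0                    # array([ True, False,  True])
    valid = nums[mask]                  # model indices of usable ratings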
@@ -280,9 +279,9 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
slow_trimmed, slow_inds = torch.topk(slow_mat, self.nnbrs)
assert slow_trimmed.shape == (n_slow, self.nnbrs)
if self.feedback == "explicit":
scores[ti_slow_mask] = torch.sum(
slow_trimmed * torch.from_numpy(ri_vals)[slow_inds], axis=1
).numpy()
svals = torch.from_numpy(ri_vals)[slow_inds]
assert svals.shape == slow_trimmed.shape
scores[ti_slow_mask] = torch.sum(slow_trimmed * svals, axis=1).numpy()
scores[ti_slow_mask] /= torch.sum(slow_trimmed, axis=1).numpy()
else:
scores[ti_slow_mask] = torch.sum(slow_trimmed, axis=1).numpy()
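
The slow path above handles items whose candidate neighborhoods exceed nnbrs: torch.topk keeps only the k most similar rated items per row, and for explicit feedback the kept similarities weight the neighbor ratings into an average. A toy version of that computation (values invented for illustration):

    import torch

    sims = torch.tensor([[0.9, 0.1, 0.5, 0.3, 0.7],    # 2 target items x
                         [0.2, 0.8, 0.4, 0.6, 0.1]])   # 5 rated items
    ratings = torch.tensor([4.0, 3.0, 5.0, 2.0, 1.0])
    k = 3

    top_sims, top_idx = torch.topk(sims, k)         # k best neighbors per row
    weighted = torch.sum(top_sims * ratings[top_idx], dim=1)
    scores = weighted / torch.sum(top_sims, dim=1)  # weighted average rating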
@@ -291,11 +290,9 @@ def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
if self.item_means_ is not None:
scores[ti_mask] += self.item_means_[ti_valid_nums]

_log.debug(
"user %s: predicted for %d of %d items",
query.user_id,
log.debug(
"scored %d items",
int(np.isfinite(scores).sum()),
len(items),
)

return ItemList(items, scores=scores)
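
Since training centered each item's ratings on its mean, the final step adds the stored means back; the neighbor-weighted sums are offsets from each item's average rating. A tiny numeric sketch:

    import numpy as np

    centered = np.array([0.4, -0.2])     # neighbor-weighted centered scores
    item_means = np.array([3.5, 4.2])    # stored as item_means_ at training
    predictions = item_means + centered  # -> [3.9, 4.0]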
@@ -304,12 +301,6 @@ def __str__(self):
return "ItemItem(nnbrs={}, msize={})".format(self.nnbrs, self.save_nbrs)


@torch.jit.ignore # type: ignore
def _msg(level, msg):
# type: (int, str) -> None
_log.log(level, msg)


@torch.jit.script
def _sim_row(
item: int, matrix: torch.Tensor, row: torch.Tensor, min_sim: float, max_nbrs: Optional[int]
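The deleted _msg helper existed so that TorchScript-compiled code could log: @torch.jit.ignore marks a function the script compiler leaves uncompiled but still callable, dispatching back to the Python interpreter. A self-contained illustration of that mechanism (the function names here are invented):

    import torch

    @torch.jit.ignore  # type: ignore
    def _report(msg):
        # type: (str) -> None
        print(msg)  # runs in plain Python, not TorchScript

    @torch.jit.script
    def double(x: torch.Tensor) -> torch.Tensor:
        _report("doubling a tensor")  # call falls back to Python
        return x * 2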
95 changes: 52 additions & 43 deletions lenskit/lenskit/knn/user.py
@@ -17,7 +17,7 @@
import pandas as pd
import structlog
import torch
from scipy.sparse import csr_array
from scipy.sparse import csc_array
from typing_extensions import NamedTuple, Optional, Self, override

from lenskit import util
Expand All @@ -37,6 +37,12 @@ class UserKNNScorer(Component, Trainable):
user-user implementation is not terribly configurable; it hard-codes design
decisions found to work well in the previous Java-based LensKit code.
.. note::
This component must be used with queries containing the user's history,
either directly in the input or by wiring its query input to the output of a
user history component (e.g., :class:`~lenskit.basic.UserTrainingHistoryLookup`).
Args:
nnbrs:
the maximum number of neighbors for scoring each item (``None`` for
@@ -74,7 +80,7 @@ class UserKNNScorer(Component, Trainable):
"Mean rating for each known user."
user_vectors_: torch.Tensor
"Normalized rating matrix (CSR) to find neighbors at prediction time."
user_ratings_: csr_array
user_ratings_: csc_array
"Centered but un-normalized rating matrix (COO) to find neighbor ratings."

def __init__(
@@ -130,7 +136,7 @@ def train(self, data: Dataset) -> Self:
normed, _norms = normalize_sparse_rows(rmat, "unit")

self.user_vectors_ = normed
self.user_ratings_ = torch_sparse_to_scipy(rmat).tocsr()
self.user_ratings_ = torch_sparse_to_scipy(rmat).tocsc()
self.users_ = data.users.copy()
self.user_means_ = means
self.items_ = data.items.copy()
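
Storing the ratings as CSC rather than CSR matters because scoring slices this matrix by item: in CSC form, indptr delimits columns, so per-item nonzero counts (and each item's run of stored entries) come straight out of the index arrays. For example:

    import numpy as np
    from scipy.sparse import csc_array

    mat = csc_array(np.array([[1.0, 0.0, 2.0],
                              [0.0, 3.0, 4.0]]))
    counts = np.diff(mat.indptr)  # nonzeros per column -> array([1, 1, 2])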
@@ -282,64 +288,67 @@ def score_items_with_neighbors(
items: torch.Tensor,
nbr_rows: torch.Tensor,
nbr_sims: torch.Tensor,
ratings: csr_array,
ratings: csc_array,
max_nbrs: int,
min_nbrs: int,
average: bool,
) -> np.ndarray[tuple[int], np.dtype[np.float64]]:
# select a sub-matrix for further manipulation
items = items.numpy()
(ni,) = items.shape
(nrow, ncol) = ratings.shape
# do matrix surgery
nbr_rates = ratings[nbr_rows.numpy(), :]
nbr_rates = nbr_rates[:, items.numpy()]

nbr_t = nbr_rates.transpose().tocsr()
# sort neighbors by similarity
nbr_order = np.argsort(-nbr_sims)
nbr_rows = nbr_rows[nbr_order].numpy()
nbr_sims = nbr_sims[nbr_order].numpy()

# get the rating rows for our neighbors
nbr_rates = ratings[nbr_rows, :]

# count nbrs for each item
counts = np.diff(nbr_t.indptr)
assert counts.shape == items.shape
# which items are scorable?
counts = np.diff(nbr_rates.indptr)
min_nbr_mask = counts >= min_nbrs
is_nbr_mask = min_nbr_mask[items]
is_scorable = items[is_nbr_mask]

# get the ratings for requested scorable items
nbr_rates = nbr_rates[:, is_scorable]
assert isinstance(nbr_rates, csc_array)
nbr_rates.sort_indices()
counts = counts[is_scorable]

log.debug(
"scoring items",
max_count=np.max(counts),
max_count=np.max(counts) if len(counts) else 0,
nbr_shape=nbr_rates.shape,
)

# fast-path items with small neighborhoods
fp_mask = counts <= max_nbrs
# Now, for our next trick - we have a CSC matrix, whose rows (users) are
# sorted by decreasing similarity. So we can *zero* any entries past the
# first max_nbrs in each column. This can be done with a little bit of
# jiggery-pokery.

# step 1: create a list of column start indices
starts = np.repeat(nbr_rates.indptr[:-1], counts)
# step 2: create a ranking from start to end
ranks = np.arange(nbr_rates.nnz, dtype=np.int32)
# step 3: subtract the column starts; this gives us each entry's rank within its column
ranks -= starts
rmask = ranks >= max_nbrs
# step 4: zero out rating values for everything past max_nbrs
nbr_rates.data[rmask] = 0

# now we can just do a matrix-vector multiply to compute the scores
results = np.full(ni, np.nan)
nbr_fp = nbr_rates[:, fp_mask]
results[fp_mask] = nbr_fp.T @ nbr_sims
results[is_nbr_mask] = nbr_rates.T @ nbr_sims

if average:
nbr_fp_ones = csr_array((np.ones(nbr_fp.nnz), nbr_fp.indices, nbr_fp.indptr), nbr_fp.shape)
tot_sims = nbr_fp_ones.T @ nbr_sims
nbr_ones = csc_array(
(np.where(rmask, 0, 1), nbr_rates.indices, nbr_rates.indptr), nbr_rates.shape
)
tot_sims = nbr_ones.T @ nbr_sims
assert np.all(np.isfinite(tot_sims))
results[fp_mask] /= tot_sims

# clear out too-small neighborhoods
results[counts < min_nbrs] = torch.nan

# deal with too-large items
exc_mask = counts > max_nbrs
n_bad = np.sum(exc_mask)
if n_bad:
log.debug("scoring %d slow-path items", n_bad)

bads = np.argwhere(exc_mask)[:, 0]
for badi in bads:
s, e = nbr_t.indptr[badi : (badi + 2)]

bi_users = nbr_t.indices[s:e]
bi_rates = torch.from_numpy(nbr_t.data[s:e])
bi_sims = nbr_sims[bi_users]

tk_vs, tk_is = torch.topk(bi_sims, max_nbrs)
sum = torch.sum(tk_vs)
if average:
results[badi] = torch.dot(tk_vs, bi_rates[tk_is]) / sum
else:
results[badi] = sum
results[is_nbr_mask] /= tot_sims

return results
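
Putting score_items_with_neighbors together: because the neighbor rows are pre-sorted by decreasing similarity, each CSC column lists an item's raters in similarity order, so ranking entries within their columns and zeroing everything past max_nbrs trims every neighborhood at once, and one sparse matrix-vector product then scores all items. A toy end-to-end run of that trick (values invented for illustration):

    import numpy as np
    from scipy.sparse import csc_array

    # 4 neighbors (rows, sorted by decreasing similarity) x 3 items
    nbr_rates = csc_array(np.array([
        [4.0, 0.0, 3.0],
        [5.0, 2.0, 0.0],
        [3.0, 4.0, 1.0],
        [0.0, 5.0, 2.0],
    ]))
    nbr_rates.sort_indices()            # column entries ordered by row
    nbr_sims = np.array([0.9, 0.7, 0.5, 0.2])
    max_nbrs = 2

    # rank each stored entry within its column, zero those past max_nbrs
    counts = np.diff(nbr_rates.indptr)
    starts = np.repeat(nbr_rates.indptr[:-1], counts)
    ranks = np.arange(nbr_rates.nnz, dtype=np.int32) - starts
    rmask = ranks >= max_nbrs
    nbr_rates.data[rmask] = 0

    # one mat-vec scores every item; a 0/1 copy of the same structure
    # yields the similarity totals for the weighted average
    scores = nbr_rates.T @ nbr_sims
    ones = csc_array(
        (np.where(rmask, 0, 1), nbr_rates.indices, nbr_rates.indptr),
        shape=nbr_rates.shape,
    )
    scores /= ones.T @ nbr_sims         # similarity-weighted averages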
