diff --git a/ragatouille/models/colbert.py b/ragatouille/models/colbert.py index b2781df..73d8f95 100644 --- a/ragatouille/models/colbert.py +++ b/ragatouille/models/colbert.py @@ -486,9 +486,8 @@ def _index_free_search( for query in embedded_queries: results_for_query = [] scores = self._colbert_score(query, embedded_docs, doc_mask) - sorted_scores = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) - high_score_idxes = [index for index, _ in sorted_scores[:k]] - for rank, doc_idx in enumerate(high_score_idxes): + sorted_scores = torch.topk(scores, k) + for rank, doc_idx in enumerate(sorted_scores.indices.tolist()): result = { "content": documents[doc_idx], "score": float(scores[doc_idx]),