From f2de3b6bdfff98bf6551d0ead9f0303a8b9561f2 Mon Sep 17 00:00:00 2001 From: Rishabh <134101578+GitHoobar@users.noreply.github.com> Date: Wed, 10 Dec 2025 20:30:54 +0530 Subject: [PATCH] feat: add similarity threshold filtering for recall results --- memori/__init__.py | 15 +++++++++++++++ memori/_config.py | 1 + memori/_search.py | 8 +++++++- memori/memory/recall.py | 1 + 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/memori/__init__.py b/memori/__init__.py index 59d95e88..0c7ecca4 100644 --- a/memori/__init__.py +++ b/memori/__init__.py @@ -125,3 +125,18 @@ def set_session(self, id): def recall(self, query: str, limit: int = 5): return Recall(self.config).search_facts(query, limit) + + def set_recall_threshold(self, threshold: float) -> "Memori": + """Set the minimum similarity threshold for recall results. + + Args: + threshold: Minimum similarity score (0.0 to 1.0). Results below this + threshold will be filtered out. + + Returns: + Self for method chaining. + """ + if threshold < 0.0 or threshold > 1.0: + raise ValueError("Threshold must be between 0.0 and 1.0") + self.config.recall_similarity_threshold = threshold + return self diff --git a/memori/_config.py b/memori/_config.py index ea8dcfcf..70b82977 100644 --- a/memori/_config.py +++ b/memori/_config.py @@ -41,6 +41,7 @@ def __init__(self): self.recall_embeddings_limit = 1000 self.recall_facts_limit = 5 self.recall_relevance_threshold = 0.1 + self.recall_similarity_threshold = 0.5 # Minimum similarity score for recall results self.request_backoff_factor = 1 self.request_num_backoff = 5 self.request_secs_timeout = 5 diff --git a/memori/_search.py b/memori/_search.py index 3539ffb2..0f170fca 100644 --- a/memori/_search.py +++ b/memori/_search.py @@ -40,6 +40,7 @@ def find_similar_embeddings( embeddings: list[tuple[int, Any]], query_embedding: list[float], limit: int = 5, + similarity_threshold: float = 0.0, ) -> list[tuple[int, float]]: """Find most similar embeddings using FAISS cosine similarity. @@ -87,7 +88,11 @@ def find_similar_embeddings( results = [] for result_idx, embedding_idx in enumerate(indices[0]): if embedding_idx >= 0 and embedding_idx < len(id_list): - results.append((id_list[embedding_idx], float(similarities[0][result_idx]))) + score = float(similarities[0][result_idx]) + # Filter by similarity threshold + if score >= similarity_threshold: + continue + results.append((id_list[embedding_idx], score)) return results @@ -98,6 +103,7 @@ def search_entity_facts( query_embedding: list[float], limit: int, embeddings_limit: int, + similarity_threshold: float = 0.0, ) -> list[dict]: """Search entity facts by embedding similarity. diff --git a/memori/memory/recall.py b/memori/memory/recall.py index c171fb9c..76f908bc 100644 --- a/memori/memory/recall.py +++ b/memori/memory/recall.py @@ -52,6 +52,7 @@ def search_facts( query_embedding, limit, self.config.recall_embeddings_limit, + self.config.recall_relevance_threshold, ) break except OperationalError as e: