diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index ceede3e0..4f445e9a 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest): ), ) - info: dict[str, str] | None = Field( + info: dict[str, Any] | None = Field( None, description=( "Additional metadata for the add request. " diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 0e64ea9a..e25c7cb1 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -138,7 +138,8 @@ def mix_search_memories( target_session_id = search_req.session_id if not target_session_id: target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter # Rerank Memories - reranker expects TextualMemoryItem objects @@ -155,6 +156,7 @@ def mix_search_memories( mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, search_filter=search_filter, + search_priority=search_priority, info=info, ) @@ -178,7 +180,7 @@ def mix_search_memories( query=search_req.query, # Use search_req.query instead of undefined query graph_results=history_memories, # Pass TextualMemoryItem objects directly top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k - search_filter=search_filter, + search_priority=search_priority, ) logger.info(f"Reranked {len(sorted_history_memories)} history memories.") processed_hist_mem = self.searcher.post_retrieve( @@ -234,6 +236,7 @@ def mix_search_memories( mode=SearchMode.FAST, memory_type="All", search_filter=search_filter, + search_priority=search_priority, info=info, ) else: diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py 
b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 534f5d67..6352d584 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No @abstractmethod def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None + self, + query: str, + top_k: int, + info: dict[str, Any] | None = None, + search_filter: dict[str, Any] | None = None, ) -> list[TextualMemoryItem]: """Retrieve memories from the retriever.""" @@ -76,7 +80,11 @@ def _original_text_reranker( return prefs_mem def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None + self, + query: str, + top_k: int, + info: dict[str, Any] | None = None, + search_filter: dict[str, Any] | None = None, ) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" # TODO: un-support rewrite query and session filter now @@ -84,6 +92,7 @@ def retrieve( info = info.copy() # Create a copy to avoid modifying the original info.pop("chat_history", None) info.pop("session_id", None) + search_filter = {"and": [info, search_filter]} query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings query_embedding = query_embeddings[0] # Get the first (and only) embedding @@ -96,7 +105,7 @@ def retrieve( query, "explicit_preference", top_k * 2, - info, + search_filter, ) future_implicit = executor.submit( self.vector_db.search, @@ -104,7 +113,7 @@ def retrieve( query, "implicit_preference", top_k * 2, - info, + search_filter, ) # Wait for all results diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 6e196e23..c0ed1217 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -76,7 +76,9 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def 
search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + def search( + self, query: str, top_k: int, info=None, search_filter=None, **kwargs + ) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. @@ -85,7 +87,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem Returns: list[TextualMemoryItem]: List of matching memories. """ - return self.retriever.retrieve(query, top_k, info) + logger.info(f"search_filter for preference memory: {search_filter}") + return self.retriever.retrieve(query, top_k, info, search_filter) def load(self, dir: str) -> None: """Load memories from the specified directory. diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 29f30d38..1f02132b 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -50,7 +50,9 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + def search( + self, query: str, top_k: int, info=None, search_filter=None, **kwargs + ) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. @@ -59,7 +61,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem Returns: list[TextualMemoryItem]: List of matching memories. """ - return self.retriever.retrieve(query, top_k, info) + return self.retriever.retrieve(query, top_k, info, search_filter) def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: """Add memories. 
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index df5e05a1..2a109bf7 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -162,6 +162,7 @@ def search( mode: str = "fast", memory_type: str = "All", manual_close_internet: bool = True, + search_priority: dict | None = None, search_filter: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: @@ -209,7 +210,14 @@ def search( manual_close_internet=manual_close_internet, ) return searcher.search( - query, top_k, info, mode, memory_type, search_filter, user_name=user_name + query, + top_k, + info, + mode, + memory_type, + search_filter, + search_priority, + user_name=user_name, ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 37504890..7fa8a87b 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -38,6 +38,7 @@ def retrieve( memory_scope: str, query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, use_fast_graph: bool = False, @@ -62,9 +63,12 @@ def retrieve( raise ValueError(f"Unsupported memory scope: {memory_scope}") if memory_scope == "WorkingMemory": - # For working memory, retrieve all entries (no filtering) + # For working memory, retrieve all entries (no session-oriented filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False, user_name=user_name + scope="WorkingMemory", + include_embedding=False, + user_name=user_name, + filter=search_filter, ) return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]] @@ -84,6 +88,7 @@ def retrieve( memory_scope, top_k, 
search_filter=search_filter, + search_priority=search_priority, user_name=user_name, ) if self.use_bm25: @@ -274,6 +279,7 @@ def _vector_recall( status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: """ @@ -283,7 +289,7 @@ def _vector_recall( if not query_embedding: return [] - def search_single(vec, filt=None): + def search_single(vec, search_priority=None, search_filter=None): return ( self.graph_store.search_by_embedding( vector=vec, @@ -291,31 +297,33 @@ def search_single(vec, filt=None): status=status, scope=memory_scope, cube_name=cube_name, - search_filter=filt, + search_filter=search_priority, + filter=search_filter, user_name=user_name, ) or [] ) def search_path_a(): - """Path A: search without filter""" + """Path A: search without priority""" path_a_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(search_single, vec, None) for vec in query_embedding[:max_num] + executor.submit(search_single, vec, None, search_filter) + for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): path_a_hits.extend(f.result() or []) return path_a_hits def search_path_b(): - """Path B: search with filter""" - if not search_filter: + """Path B: search with priority""" + if not search_priority: return [] path_b_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(search_single, vec, search_filter) + executor.submit(search_single, vec, search_priority, search_filter) for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 26ae1a72..976be6a5 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ 
b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -69,6 +69,7 @@ def retrieve( mode="fast", memory_type="All", search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: @@ -76,7 +77,12 @@ def retrieve( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter, user_name=user_name + query, + info, + mode, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, ) results = self._retrieve_paths( query, @@ -87,6 +93,7 @@ def retrieve( mode, memory_type, search_filter, + search_priority, user_name, ) return results @@ -112,6 +119,7 @@ def search( mode="fast", memory_type="All", search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: """ @@ -128,6 +136,7 @@ def search( memory_type (str): Type restriction for search. ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] search_filter (dict, optional): Optional metadata filters for search results. + search_priority (dict, optional): Optional metadata priority for search results. Returns: list[TextualMemoryItem]: List of matching memories. 
""" @@ -147,6 +156,7 @@ def search( mode=mode, memory_type=memory_type, search_filter=search_filter, + search_priority=search_priority, user_name=user_name, ) @@ -174,6 +184,7 @@ def _parse_task( mode, top_k=5, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ): """Parse user query, do embedding search and create context""" @@ -192,7 +203,8 @@ def _parse_task( query_embedding, top_k=top_k, status="activated", - search_filter=search_filter, + search_filter=search_priority, + filter=search_filter, user_name=user_name, ) ] @@ -244,6 +256,7 @@ def _retrieve_paths( mode, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ): """Run A/B/C retrieval paths in parallel""" @@ -264,6 +277,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + search_priority, user_name, id_filter, ) @@ -277,6 +291,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + search_priority, user_name, id_filter, mode=mode, @@ -313,6 +328,7 @@ def _retrieve_from_working_memory( top_k, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, ): @@ -326,6 +342,7 @@ def _retrieve_from_working_memory( top_k=top_k, memory_scope="WorkingMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, @@ -349,6 +366,7 @@ def _retrieve_from_long_term_and_user( top_k, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, mode: str = "fast", @@ -378,6 +396,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="LongTermMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, @@ -393,6 +412,7 @@ def 
_retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 9c5be2fa..e346bdf1 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -237,7 +237,8 @@ def _fine_search( return self._agentic_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter info = { "user_id": search_req.user_id, @@ -254,6 +255,7 @@ def _fine_search( manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, + search_priority=search_priority, info=info, ) @@ -289,6 +291,7 @@ def _fine_search( top_k=retrieval_size, mode=SearchMode.FAST, memory_type="All", + search_priority=search_priority, search_filter=search_filter, info=info, ) @@ -324,7 +327,8 @@ def _search_pref( """ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - + print(f"search_req.filter for preference memory: {search_req.filter}") + print(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") try: results = self.naive_mem_cube.pref_mem.search( query=search_req.query, @@ -334,6 +338,7 @@ def _search_pref( "session_id": search_req.session_id, "chat_history": search_req.chat_history, }, + search_filter=search_req.filter, ) return [format_memory_item(data) for data in results] except Exception as e: @@ -356,8 +361,9 @@ def _fast_search( List of search results """ target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": 
search_req.session_id} if search_req.session_id else None - + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter or None + print(f"type of text_mem: {type(self.naive_mem_cube.text_mem)}") search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -365,6 +371,7 @@ def _fast_search( mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, search_filter=search_filter, + search_priority=search_priority, info={ "user_id": search_req.user_id, "session_id": target_session_id, diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index db5a51fc..764b5303 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -125,7 +125,7 @@ def rerank( query: str, graph_results: list[TextualMemoryItem], top_k: int, - search_filter: dict | None = None, + search_priority: dict | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: """ @@ -140,7 +140,7 @@ def rerank( `.memory` str field; non-strings are ignored. top_k : int Return at most this many items. - search_filter : dict | None + search_priority : dict | None, optional Currently unused. Present to keep signature compatible. 
Returns @@ -194,7 +194,7 @@ def rerank( raw_score = float(r.get("relevance_score", r.get("score", 0.0))) item = graph_results[idx] # generic boost - score = self._apply_boost_generic(item, raw_score, search_filter) + score = self._apply_boost_generic(item, raw_score, search_priority) scored_items.append((item, score)) scored_items.sort(key=lambda x: x[1], reverse=True) @@ -213,7 +213,7 @@ def rerank( scored_items = [] for item, raw_score in zip(graph_results, score_list, strict=False): - score = self._apply_boost_generic(item, raw_score, search_filter) + score = self._apply_boost_generic(item, raw_score, search_priority) scored_items.append((item, score)) scored_items.sort(key=lambda x: x[1], reverse=True) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index eafee263..2181961d 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -229,6 +229,7 @@ def search( List of search results with distance scores and payloads. """ # Convert filter to Milvus expression + print(f"filter for milvus: {filter}") expr = self._dict_to_expr(filter) if filter else "" search_func_map = { @@ -267,27 +268,175 @@ def search( return items def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: - """Convert a dictionary filter to a Milvus expression string.""" + """Convert a dictionary filter to a Milvus expression string. + + Supports complex query syntax with logical operators, comparison operators, + arithmetic operators, array operators, and string pattern matching. 
+ + Args: + filter_dict: Dictionary containing filter conditions + + Returns: + Milvus expression string + """ if not filter_dict: return "" + return self._build_expression(filter_dict) + + def _build_expression(self, condition: Any) -> str: + """Build expression from condition dict or value.""" + if isinstance(condition, dict): + # Handle logical operators + if "and" in condition: + return self._handle_logical_and(condition["and"]) + elif "or" in condition: + return self._handle_logical_or(condition["or"]) + elif "not" in condition: + return self._handle_logical_not(condition["not"]) + else: + # Handle field conditions + return self._handle_field_conditions(condition) + else: + # Simple value comparison + return f"{condition}" + + def _handle_logical_and(self, conditions: list) -> str: + """Handle AND logical operator.""" + if not conditions: + return "" + expressions = [self._build_expression(cond) for cond in conditions if cond is not None] + expressions = [expr for expr in expressions if expr] + if not expressions: + return "" + return f"({' and '.join(expressions)})" + + def _handle_logical_or(self, conditions: list) -> str: + """Handle OR logical operator.""" + if not conditions: + return "" + expressions = [self._build_expression(cond) for cond in conditions if cond is not None] + expressions = [expr for expr in expressions if expr] + if not expressions: + return "" + return f"({' or '.join(expressions)})" + + def _handle_logical_not(self, condition: Any) -> str: + """Handle NOT logical operator.""" + expr = self._build_expression(condition) + if not expr: + return "" + return f"(not {expr})" + + def _handle_field_conditions(self, condition_dict: dict[str, Any]) -> str: + """Handle field-specific conditions.""" conditions = [] - for field, value in filter_dict.items(): - # Skip None values as they cause Milvus query syntax errors + + for field, value in condition_dict.items(): if value is None: continue - # For JSON fields, we need to use payload["field"] 
syntax - elif isinstance(value, str): - conditions.append(f"payload['{field}'] == '{value}'") - elif isinstance(value, list) and len(value) == 0: - # Skip empty lists as they cause Milvus query syntax errors - continue - elif isinstance(value, list) and len(value) > 0: - conditions.append(f"payload['{field}'] in {value}") - else: - conditions.append(f"payload['{field}'] == '{value}'") + + field_expr = self._build_field_expression(field, value) + if field_expr: + conditions.append(field_expr) + + if not conditions: + return "" return " and ".join(conditions) + def _build_field_expression(self, field: str, value: Any) -> str: + """Build expression for a single field.""" + # Handle comparison operators + if isinstance(value, dict): + if len(value) == 1: + op, operand = next(iter(value.items())) + op_lower = op.lower() + + if op_lower == "in": + return self._handle_in_operator(field, operand) + elif op_lower == "contains": + return self._handle_contains_operator(field, operand, case_sensitive=True) + elif op_lower == "icontains": + return self._handle_contains_operator(field, operand, case_sensitive=False) + elif op_lower == "like": + return self._handle_like_operator(field, operand) + elif op_lower in ["gte", "lte", "gt", "lt", "ne"]: + return self._handle_comparison_operator(field, op_lower, operand) + else: + # Unknown operator, treat as equality + return f"payload['{field}'] == {self._format_value(operand)}" + else: + # Multiple operators, handle each one + sub_conditions = [] + for op, operand in value.items(): + op_lower = op.lower() + if op_lower in [ + "gte", + "lte", + "gt", + "lt", + "ne", + "in", + "contains", + "icontains", + "like", + ]: + sub_expr = self._build_field_expression(field, {op: operand}) + if sub_expr: + sub_conditions.append(sub_expr) + + if sub_conditions: + return f"({' and '.join(sub_conditions)})" + return "" + else: + # Simple equality + return f"payload['{field}'] == {self._format_value(value)}" + + def _handle_in_operator(self, field: 
str, values: list) -> str:
+        """Handle IN operator for arrays."""
+        if not isinstance(values, list) or not values:
+            return ""
+
+        formatted_values = [self._format_value(v) for v in values]
+        return f"payload['{field}'] in [{', '.join(formatted_values)}]"
+
+    def _handle_contains_operator(self, field: str, value: Any, case_sensitive: bool = True) -> str:
+        """Handle CONTAINS/ICONTAINS operator."""
+        formatted_value = self._format_value(value)
+        if case_sensitive:
+            return f"json_contains(payload['{field}'], {formatted_value})"
+        # Milvus json_contains has no case-insensitive variant, so fall back to
+        # the case-sensitive form. (Returning "not json_contains(...)" here would
+        # compute NOT-contains, which is the opposite of icontains.)
+        # TODO: implement true case-insensitive matching if/when Milvus supports it.
+        return f"json_contains(payload['{field}'], {formatted_value})"
+
+    def _handle_like_operator(self, field: str, pattern: str) -> str:
+        """Handle LIKE operator for string pattern matching.
+
+        The pattern is passed through unchanged; Milvus `like` uses SQL-style
+        % / _ wildcards already.
+        """
+        return f"payload['{field}'] like '{pattern}'"
+
+    def _handle_comparison_operator(self, field: str, operator: str, value: Any) -> str:
+        """Handle comparison operators (gte, lte, gt, lt, ne)."""
+        milvus_op = {"gte": ">=", "lte": "<=", "gt": ">", "lt": "<", "ne": "!="}.get(operator, "==")
+
+        formatted_value = self._format_value(value)
+        return f"payload['{field}'] {milvus_op} {formatted_value}"
+
+    def _format_value(self, value: Any) -> str:
+        """Format value for Milvus expression.
+
+        NOTE: bool must be checked before int/float because bool is a subclass
+        of int in Python; otherwise True/False would be emitted as "True"/"False"
+        instead of the lowercase literals Milvus expects.
+        """
+        if isinstance(value, bool):
+            return str(value).lower()
+        elif isinstance(value, str):
+            return f"'{value}'"
+        elif isinstance(value, int | float):
+            return str(value)
+        elif isinstance(value, list):
+            formatted_items = [self._format_value(item) for item in value]
+            return f"[{', '.join(formatted_items)}]"
+        elif value is None:
+            return "null"
+        else:
+            return f"'{value!s}'"
+
     def _get_metric_type(self) -> str:
         """Get the metric type for search."""
         metric_map = {