Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/memos/api/product_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest):
),
)

info: dict[str, str] | None = Field(
info: dict[str, Any] | None = Field(
None,
description=(
"Additional metadata for the add request. "
Expand Down
7 changes: 5 additions & 2 deletions src/memos/mem_scheduler/optimized_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def mix_search_memories(
target_session_id = search_req.session_id
if not target_session_id:
target_session_id = "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
search_priority = {"session_id": search_req.session_id} if search_req.session_id else None
search_filter = search_req.filter

# Rerank Memories - reranker expects TextualMemoryItem objects

Expand All @@ -155,6 +156,7 @@ def mix_search_memories(
mode=SearchMode.FAST,
manual_close_internet=not search_req.internet_search,
search_filter=search_filter,
search_priority=search_priority,
info=info,
)

Expand All @@ -178,7 +180,7 @@ def mix_search_memories(
query=search_req.query, # Use search_req.query instead of undefined query
graph_results=history_memories, # Pass TextualMemoryItem objects directly
top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
search_filter=search_filter,
search_priority=search_priority,
)
logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
processed_hist_mem = self.searcher.post_retrieve(
Expand Down Expand Up @@ -234,6 +236,7 @@ def mix_search_memories(
mode=SearchMode.FAST,
memory_type="All",
search_filter=search_filter,
search_priority=search_priority,
info=info,
)
else:
Expand Down
17 changes: 13 additions & 4 deletions src/memos/memories/textual/prefer_text_memory/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No

@abstractmethod
def retrieve(
self, query: str, top_k: int, info: dict[str, Any] | None = None
self,
query: str,
top_k: int,
info: dict[str, Any] | None = None,
search_filter: dict[str, Any] | None = None,
) -> list[TextualMemoryItem]:
"""Retrieve memories from the retriever."""

Expand Down Expand Up @@ -76,14 +80,19 @@ def _original_text_reranker(
return prefs_mem

def retrieve(
self, query: str, top_k: int, info: dict[str, Any] | None = None
self,
query: str,
top_k: int,
info: dict[str, Any] | None = None,
search_filter: dict[str, Any] | None = None,
) -> list[TextualMemoryItem]:
"""Retrieve memories from the naive retriever."""
# TODO: un-support rewrite query and session filter now
if info:
info = info.copy() # Create a copy to avoid modifying the original
info.pop("chat_history", None)
info.pop("session_id", None)
search_filter = {"and": [info, search_filter]}
query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings
query_embedding = query_embeddings[0] # Get the first (and only) embedding

Expand All @@ -96,15 +105,15 @@ def retrieve(
query,
"explicit_preference",
top_k * 2,
info,
search_filter,
)
future_implicit = executor.submit(
self.vector_db.search,
query_embedding,
query,
"implicit_preference",
top_k * 2,
info,
search_filter,
)

# Wait for all results
Expand Down
7 changes: 5 additions & 2 deletions src/memos/memories/textual/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def get_memory(
"""
return self.extractor.extract(messages, type, info)

def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
def search(
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
Args:
query (str): The query to search for.
Expand All @@ -85,7 +87,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
return self.retriever.retrieve(query, top_k, info)
print(f"search_filter for preference memory: {search_filter}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this debug `print` statement and use `logger.info(...)` instead, consistent with the module-level logger used elsewhere in this package.

return self.retriever.retrieve(query, top_k, info, search_filter)

def load(self, dir: str) -> None:
"""Load memories from the specified directory.
Expand Down
6 changes: 4 additions & 2 deletions src/memos/memories/textual/simple_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def get_memory(
"""
return self.extractor.extract(messages, type, info)

def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]:
def search(
self, query: str, top_k: int, info=None, search_filter=None, **kwargs
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
Args:
query (str): The query to search for.
Expand All @@ -59,7 +61,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
return self.retriever.retrieve(query, top_k, info)
return self.retriever.retrieve(query, top_k, info, search_filter)

def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
"""Add memories.
Expand Down
10 changes: 9 additions & 1 deletion src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def search(
mode: str = "fast",
memory_type: str = "All",
manual_close_internet: bool = True,
search_priority: dict | None = None,
search_filter: dict | None = None,
user_name: str | None = None,
) -> list[TextualMemoryItem]:
Expand Down Expand Up @@ -209,7 +210,14 @@ def search(
manual_close_internet=manual_close_internet,
)
return searcher.search(
query, top_k, info, mode, memory_type, search_filter, user_name=user_name
query,
top_k,
info,
mode,
memory_type,
search_filter,
search_priority,
user_name=user_name,
)

def get_relevant_subgraph(
Expand Down
26 changes: 17 additions & 9 deletions src/memos/memories/textual/tree_text_memory/retrieve/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def retrieve(
memory_scope: str,
query_embedding: list[list[float]] | None = None,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
use_fast_graph: bool = False,
Expand All @@ -62,9 +63,12 @@ def retrieve(
raise ValueError(f"Unsupported memory scope: {memory_scope}")

if memory_scope == "WorkingMemory":
# For working memory, retrieve all entries (no filtering)
# For working memory, retrieve all entries (no session-oriented filtering)
working_memories = self.graph_store.get_all_memory_items(
scope="WorkingMemory", include_embedding=False, user_name=user_name
scope="WorkingMemory",
include_embedding=False,
user_name=user_name,
filter=search_filter,
)
return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]]

Expand All @@ -84,6 +88,7 @@ def retrieve(
memory_scope,
top_k,
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
)
if self.use_bm25:
Expand Down Expand Up @@ -274,6 +279,7 @@ def _vector_recall(
status: str = "activated",
cube_name: str | None = None,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
) -> list[TextualMemoryItem]:
"""
Expand All @@ -283,39 +289,41 @@ def _vector_recall(
if not query_embedding:
return []

def search_single(vec, filt=None):
def search_single(vec, search_priority=None, search_filter=None):
return (
self.graph_store.search_by_embedding(
vector=vec,
top_k=top_k,
status=status,
scope=memory_scope,
cube_name=cube_name,
search_filter=filt,
search_filter=search_priority,
filter=search_filter,
user_name=user_name,
)
or []
)

def search_path_a():
"""Path A: search without filter"""
"""Path A: search without priority"""
path_a_hits = []
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(search_single, vec, None) for vec in query_embedding[:max_num]
executor.submit(search_single, vec, None, search_filter)
for vec in query_embedding[:max_num]
]
for f in concurrent.futures.as_completed(futures):
path_a_hits.extend(f.result() or [])
return path_a_hits

def search_path_b():
"""Path B: search with filter"""
if not search_filter:
"""Path B: search with priority"""
if not search_priority:
return []
path_b_hits = []
with ContextThreadPoolExecutor() as executor:
futures = [
executor.submit(search_single, vec, search_filter)
executor.submit(search_single, vec, search_priority, search_filter)
for vec in query_embedding[:max_num]
]
for f in concurrent.futures.as_completed(futures):
Expand Down
24 changes: 22 additions & 2 deletions src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,20 @@ def retrieve(
mode="fast",
memory_type="All",
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
**kwargs,
) -> list[tuple[TextualMemoryItem, float]]:
logger.info(
f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
)
parsed_goal, query_embedding, context, query = self._parse_task(
query, info, mode, search_filter=search_filter, user_name=user_name
query,
info,
mode,
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
)
results = self._retrieve_paths(
query,
Expand All @@ -87,6 +93,7 @@ def retrieve(
mode,
memory_type,
search_filter,
search_priority,
user_name,
)
return results
Expand All @@ -112,6 +119,7 @@ def search(
mode="fast",
memory_type="All",
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
) -> list[TextualMemoryItem]:
"""
Expand All @@ -128,6 +136,7 @@ def search(
memory_type (str): Type restriction for search.
['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
search_filter (dict, optional): Optional metadata filters for search results.
search_priority (dict, optional): Optional metadata priority for search results.
Returns:
list[TextualMemoryItem]: List of matching memories.
"""
Expand All @@ -147,6 +156,7 @@ def search(
mode=mode,
memory_type=memory_type,
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
)

Expand Down Expand Up @@ -174,6 +184,7 @@ def _parse_task(
mode,
top_k=5,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
):
"""Parse user query, do embedding search and create context"""
Expand All @@ -192,7 +203,8 @@ def _parse_task(
query_embedding,
top_k=top_k,
status="activated",
search_filter=search_filter,
search_filter=search_priority,
filter=search_filter,
user_name=user_name,
)
]
Expand Down Expand Up @@ -244,6 +256,7 @@ def _retrieve_paths(
mode,
memory_type,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
):
"""Run A/B/C retrieval paths in parallel"""
Expand All @@ -264,6 +277,7 @@ def _retrieve_paths(
top_k,
memory_type,
search_filter,
search_priority,
user_name,
id_filter,
)
Expand All @@ -277,6 +291,7 @@ def _retrieve_paths(
top_k,
memory_type,
search_filter,
search_priority,
user_name,
id_filter,
mode=mode,
Expand Down Expand Up @@ -313,6 +328,7 @@ def _retrieve_from_working_memory(
top_k,
memory_type,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
):
Expand All @@ -326,6 +342,7 @@ def _retrieve_from_working_memory(
top_k=top_k,
memory_scope="WorkingMemory",
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
Expand All @@ -349,6 +366,7 @@ def _retrieve_from_long_term_and_user(
top_k,
memory_type,
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
id_filter: dict | None = None,
mode: str = "fast",
Expand Down Expand Up @@ -378,6 +396,7 @@ def _retrieve_from_long_term_and_user(
top_k=top_k * 2,
memory_scope="LongTermMemory",
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
Expand All @@ -393,6 +412,7 @@ def _retrieve_from_long_term_and_user(
top_k=top_k * 2,
memory_scope="UserMemory",
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
id_filter=id_filter,
use_fast_graph=self.use_fast_graph,
Expand Down
Loading
Loading