diff --git a/mem0/memory/main.py b/mem0/memory/main.py index db19a48beb..702551d9d0 100644 --- a/mem0/memory/main.py +++ b/mem0/memory/main.py @@ -289,6 +289,7 @@ def add( infer: bool = True, memory_type: Optional[str] = None, prompt: Optional[str] = None, + custom_fact_extraction_prompt: Optional[str] = None, ): """ Create a new memory. @@ -367,7 +368,7 @@ def add( messages = parse_vision_messages(messages) with concurrent.futures.ThreadPoolExecutor() as executor: - future1 = executor.submit(self._add_to_vector_store, messages, processed_metadata, effective_filters, infer) + future1 = executor.submit(self._add_to_vector_store, messages, processed_metadata, effective_filters, infer, custom_fact_extraction_prompt) future2 = executor.submit(self._add_to_graph, messages, effective_filters) concurrent.futures.wait([future1, future2]) @@ -383,7 +384,7 @@ def add( return {"results": vector_store_result} - def _add_to_vector_store(self, messages, metadata, filters, infer): + def _add_to_vector_store(self, messages, metadata, filters, infer, custom_fact_extraction_prompt): if not infer: returned_memories = [] for message_dict in messages: @@ -421,8 +422,10 @@ def _add_to_vector_store(self, messages, metadata, filters, infer): return returned_memories parsed_messages = parse_messages(messages) - - if self.config.custom_fact_extraction_prompt: + if custom_fact_extraction_prompt: + system_prompt = custom_fact_extraction_prompt + user_prompt = f"Input:\n{parsed_messages}" + elif self.config.custom_fact_extraction_prompt: system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" else: @@ -1332,6 +1335,7 @@ async def add( memory_type: Optional[str] = None, prompt: Optional[str] = None, llm=None, + custom_fact_extraction_prompt: Optional[str] = None, ): """ Create a new memory asynchronously. @@ -1385,7 +1389,7 @@ async def add( messages = parse_vision_messages(messages) vector_store_task = asyncio.create_task( - self._add_to_vector_store(messages, processed_metadata, effective_filters, infer) + self._add_to_vector_store(messages, processed_metadata, effective_filters, infer, custom_fact_extraction_prompt) ) graph_task = asyncio.create_task(self._add_to_graph(messages, effective_filters)) @@ -1405,6 +1409,7 @@ async def _add_to_vector_store( metadata: dict, effective_filters: dict, infer: bool, + custom_fact_extraction_prompt: Optional[str] = None, ): if not infer: returned_memories = [] @@ -1443,7 +1448,10 @@ async def _add_to_vector_store( return returned_memories parsed_messages = parse_messages(messages) - if self.config.custom_fact_extraction_prompt: + if custom_fact_extraction_prompt: + system_prompt = custom_fact_extraction_prompt + user_prompt = f"Input:\n{parsed_messages}" + elif self.config.custom_fact_extraction_prompt: system_prompt = self.config.custom_fact_extraction_prompt user_prompt = f"Input:\n{parsed_messages}" else: diff --git a/mem0/vector_stores/pgvector.py b/mem0/vector_stores/pgvector.py index e2d020a66d..5c199f2ecc 100644 --- a/mem0/vector_stores/pgvector.py +++ b/mem0/vector_stores/pgvector.py @@ -231,10 +231,10 @@ def search( with self._get_cursor() as cur: cur.execute( f""" - SELECT id, vector <=> %s::vector AS distance, payload + SELECT id, 1.0 - (vector <=> %s::vector) AS distance, payload FROM {self.collection_name} {filter_clause} - ORDER BY distance + ORDER BY distance DESC LIMIT %s """, (vectors, *filter_params, limit), diff --git a/server/main.py b/server/main.py index 85c7cc7ea2..3b034a7be6 100644 --- a/server/main.py +++ b/server/main.py @@ -76,6 +76,7 @@ class MemoryCreate(BaseModel): agent_id: Optional[str] = None run_id: Optional[str] = None metadata: Optional[Dict[str, Any]] = None + custom_fact_extraction_prompt: Optional[str] = Field(None, description="Custom prompt for fact extraction.") class SearchRequest(BaseModel):