diff --git a/mem0/vector_stores/milvus.py b/mem0/vector_stores/milvus.py index 09e49a954f..70bff4ac93 100644 --- a/mem0/vector_stores/milvus.py +++ b/mem0/vector_stores/milvus.py @@ -178,9 +178,18 @@ def update(self, vector_id=None, vector=None, payload=None): Args: vector_id (str): ID of the vector to update. - vector (List[float], optional): Updated vector. + vector (List[float], optional): Updated vector. If None, the existing vector will be preserved. payload (Dict, optional): Updated payload. """ + # If vector is None, fetch the existing vector to preserve it + if vector is None: + existing_data = self.client.get(collection_name=self.collection_name, ids=vector_id) + if not existing_data: + raise ValueError(f"Vector with ID {vector_id} not found") + vector = existing_data[0].get("vectors") + if vector is None: + raise ValueError(f"Could not retrieve existing vector for ID {vector_id}") + schema = {"id": vector_id, "vectors": vector, "metadata": payload} self.client.upsert(collection_name=self.collection_name, data=schema) diff --git a/tests/vector_stores/test_milvus.py b/tests/vector_stores/test_milvus.py index d296620be8..06c8e45abb 100644 --- a/tests/vector_stores/test_milvus.py +++ b/tests/vector_stores/test_milvus.py @@ -157,18 +157,58 @@ def test_update_uses_upsert(self, milvus_db, mock_milvus_client): vector_id = "test_id" vector = [0.1] * 1536 payload = {"user_id": "alice", "data": "Updated memory"} - + milvus_db.update(vector_id=vector_id, vector=vector, payload=payload) - + # Verify upsert was called (not delete+insert) mock_milvus_client.upsert.assert_called_once() - + call_args = mock_milvus_client.upsert.call_args assert call_args[1]['collection_name'] == "test_collection" assert call_args[1]['data']['id'] == vector_id assert call_args[1]['data']['vectors'] == vector assert call_args[1]['data']['metadata'] == payload + def test_update_with_vector_none(self, milvus_db, mock_milvus_client): + """Test that update with vector=None fetches and preserves existing vector.""" + vector_id = "test_id" + existing_vector = [0.5] * 1536 + payload = {"user_id": "alice", "data": "Updated metadata only"} + + # Mock the get call to return existing vector + mock_milvus_client.get.return_value = [ + {"id": vector_id, "vectors": existing_vector, "metadata": {"user_id": "alice", "data": "Old data"}} + ] + + # Update with vector=None should fetch existing vector + milvus_db.update(vector_id=vector_id, vector=None, payload=payload) + + # Verify get was called to fetch existing vector + mock_milvus_client.get.assert_called_once_with( + collection_name="test_collection", + ids=vector_id + ) + + # Verify upsert was called with the existing vector + mock_milvus_client.upsert.assert_called_once() + call_args = mock_milvus_client.upsert.call_args + assert call_args[1]['collection_name'] == "test_collection" + assert call_args[1]['data']['id'] == vector_id + assert call_args[1]['data']['vectors'] == existing_vector # Should use existing vector + assert call_args[1]['data']['metadata'] == payload + + def test_update_with_vector_none_raises_error_if_not_found(self, milvus_db, mock_milvus_client): + """Test that update with vector=None raises error if vector not found.""" + vector_id = "nonexistent_id" + payload = {"user_id": "alice"} + + # Mock the get call to return empty list + mock_milvus_client.get.return_value = [] + + # Should raise ValueError + with pytest.raises(ValueError, match=f"Vector with ID {vector_id} not found"): + milvus_db.update(vector_id=vector_id, vector=None, payload=payload) + def test_delete(self, milvus_db, mock_milvus_client): """Test vector deletion.""" vector_id = "test_id"