From 7dd23302ec19225657d40ec4c36c9d639f9b55a0 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 8 Jan 2025 16:37:11 -0600 Subject: [PATCH] INTPYTHON-452 Add hybrid retriever test with nested field (#54) --- .../langchain_mongodb/vectorstores.py | 2 + .../integration_tests/test_retrievers.py | 78 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py index 00ee1e3..4c7b9ae 100644 --- a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py +++ b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py @@ -758,6 +758,8 @@ def _similarity_search_with_score( # Format for res in cursor: + if self._text_key not in res: + continue text = res.pop(self._text_key) score = res.pop("score") make_serializable(res) diff --git a/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py b/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py index abf85af..e7bd2ff 100644 --- a/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py +++ b/libs/langchain-mongodb/tests/integration_tests/test_retrievers.py @@ -21,10 +21,13 @@ DB_NAME = "langchain_test_db" COLLECTION_NAME = "langchain_test_retrievers" +COLLECTION_NAME_NESTED = "langchain_test_retrievers_nested" VECTOR_INDEX_NAME = "vector_index" EMBEDDING_FIELD = "embedding" PAGE_CONTENT_FIELD = "text" +PAGE_CONTENT_FIELD_NESTED = "title.text" SEARCH_INDEX_NAME = "text_index" +SEARCH_INDEX_NAME_NESTED = "text_index_nested" TIMEOUT = 60.0 INTERVAL = 0.5 @@ -71,6 +74,39 @@ def collection(client: MongoClient, dimensions: int) -> Collection: return clxn +@pytest.fixture(scope="module") +def collection_nested(client: MongoClient, dimensions: int) -> Collection: + """A Collection with both a Vector and a Full-text Search Index""" + if COLLECTION_NAME_NESTED not in client[DB_NAME].list_collection_names(): + clxn = client[DB_NAME].create_collection(COLLECTION_NAME_NESTED) + else: + clxn = client[DB_NAME][COLLECTION_NAME_NESTED] + + clxn.delete_many({}) + + if not any([VECTOR_INDEX_NAME == ix["name"] for ix in clxn.list_search_indexes()]): + create_vector_search_index( + collection=clxn, + index_name=VECTOR_INDEX_NAME, + dimensions=dimensions, + path="embedding", + similarity="cosine", + wait_until_complete=TIMEOUT, + ) + + if not any( + [SEARCH_INDEX_NAME_NESTED == ix["name"] for ix in clxn.list_search_indexes()] + ): + create_fulltext_search_index( + collection=clxn, + index_name=SEARCH_INDEX_NAME_NESTED, + field=PAGE_CONTENT_FIELD_NESTED, + wait_until_complete=TIMEOUT, + ) + + return clxn + + @pytest.fixture(scope="module") def indexed_vectorstore( collection: Collection, @@ -93,6 +129,28 @@ def indexed_vectorstore( vectorstore.collection.delete_many({}) +@pytest.fixture(scope="module") +def indexed_nested_vectorstore( + collection_nested: Collection, + example_documents: List[Document], + embedding: Embeddings, +) -> Generator[MongoDBAtlasVectorSearch, None, None]: + """Return a VectorStore with example document embeddings indexed.""" + + vectorstore = PatchedMongoDBAtlasVectorSearch( + collection=collection_nested, + embedding=embedding, + index_name=VECTOR_INDEX_NAME, + text_key=PAGE_CONTENT_FIELD_NESTED, + ) + + vectorstore.add_documents(example_documents) + + yield vectorstore + + vectorstore.collection.delete_many({}) + + def test_vector_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) -> None: """Test VectorStoreRetriever""" retriever = indexed_vectorstore.as_retriever() @@ -125,6 +183,26 @@ def test_hybrid_retriever(indexed_vectorstore: PatchedMongoDBAtlasVectorSearch) assert "New Orleans" in results[0].page_content +def test_hybrid_retriever_nested( + indexed_nested_vectorstore: PatchedMongoDBAtlasVectorSearch, +) -> None: + """Test basic usage of MongoDBAtlasHybridSearchRetriever""" + retriever = MongoDBAtlasHybridSearchRetriever( + vectorstore=indexed_nested_vectorstore, + search_index_name=SEARCH_INDEX_NAME_NESTED, + top_k=3, + ) + + query1 = "What did I visit France?" + results = retriever.invoke(query1) + assert len(results) == 3 + assert "Paris" in results[0].page_content + + query2 = "When was the last time I visited new orleans?" + results = retriever.invoke(query2) + assert "New Orleans" in results[0].page_content + + def test_fulltext_retriever( indexed_vectorstore: PatchedMongoDBAtlasVectorSearch, ) -> None: