diff --git a/.env.example b/.env.example
index ecddd852..5c786bac 100644
--- a/.env.example
+++ b/.env.example
@@ -39,9 +39,23 @@ SEARXNG_STRICT=False
 SEARXNG_DEBUG=False
 SEARXNG_ENGINES="google,arxiv,google_scholar"
 
+# For FactCheck
+FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud"
+FACT_CHECK_JOB_TIMEOUT="1800"
+FACT_CHECK_REQUEST_TIMEOUT="60"
+FACT_CHECK_START_ENDPOINT="/fact-check/start"
+FACT_CHECK_STATUS_ENDPOINT="/fact-check/status"
+FACT_CHECK_CORRECT_ENDPOINT="/correct"
+FACT_CHECK_DISPLAY_GRAPH_ENDPOINT="/display_graph"
+FACT_CHECK_GRAPH_JSON_ENDPOINT="/graph/json"
+
 # for your vector store
 EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2"
+VECTOR_DB_PATH="./chroma_db"
+# API key for various embedding functions (chroma)
+EMBEDDING_MODEL_API_KEY=""
+
 # number of wiki results to return
 TOP_N_WIKI_RESULTS = 1
 
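
The `FACT_CHECK_*` and vector-store values above are picked up by the new tool configs below through `os.getenv` at import time. A minimal sketch of that wiring (assuming the `.env` values have been exported into the process environment, e.g. via python-dotenv, before the module is imported; the printed values are illustrative):

```python
import os

# The config defaults below call os.getenv when the module is imported,
# so the environment must be populated first (e.g. load_dotenv() beforehand).
os.environ.setdefault("FACT_CHECK_API_URL", "http://localhost:8011")
os.environ.setdefault("FACT_CHECK_JOB_TIMEOUT", "1800")

from akd.tools.fact_check import FactCheckToolConfig

config = FactCheckToolConfig()
print(config.base_url)             # e.g. http://localhost:8011/
print(config.job_timeout_seconds)  # e.g. 1800
```
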
diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py
new file mode 100644
index 00000000..a6a51c31
--- /dev/null
+++ b/akd/tools/fact_check.py
@@ -0,0 +1,178 @@
+import asyncio
+import os
+from typing import Any, Dict, List, Optional
+
+import httpx
+from loguru import logger
+from pydantic import Field, HttpUrl
+
+from akd._base import InputSchema, OutputSchema
+from akd.tools._base import BaseTool, BaseToolConfig
+
+
+class FactCheckInputSchema(InputSchema):
+    """Input schema for the Fact-Checking Tool."""
+
+    question: str = Field(..., description="The original question that was asked.")
+    answer: str = Field(..., description="The LLM answer to be fact-checked.")
+
+
+class FactCheckOutputSchema(OutputSchema):
+    """Output schema for the Fact-Checking Tool's results."""
+
+    fact_reasoner_score: Dict[str, Any] = Field(
+        ...,
+        description="The full scoring dictionary from the FactReasoner.",
+    )
+    supported_atoms: List[Dict[str, Any]] = Field(
+        ...,
+        description="List of atoms determined to be supported.",
+    )
+    not_supported_atoms: List[Dict[str, Any]] = Field(
+        ...,
+        description="List of atoms determined to be not supported.",
+    )
+    contexts: List[Dict[str, Any]] = Field(
+        ...,
+        description="List of retrieved contexts used for the check.",
+    )
+    graph_id: Optional[str] = Field(
+        None,
+        description="The unique ID for the generated fact graph.",
+    )
+    logging_metadata: Dict[str, Any] = Field(
+        {},
+        description="Additional logging metadata from the run.",
+    )
+
+
+class FactCheckToolConfig(BaseToolConfig):
+    """Configuration for the FactCheckTool."""
+
+    base_url: HttpUrl = Field(
+        default=HttpUrl(os.getenv("FACT_CHECK_API_URL", "http://localhost:8011")),
+        description="The base URL of the remote Fact-Checking and Correction Service.",
+    )
+    start_endpoint: str = Field(
+        default=os.getenv("FACT_CHECK_START_ENDPOINT", "/fact-check/start"),
+        description="Endpoint to start a new fact-checking job."
+    )
+    status_endpoint: str = Field(
+        default=os.getenv("FACT_CHECK_STATUS_ENDPOINT", "/fact-check/status/"),
+        description="Endpoint to poll for the status of a job; the job ID is appended to this path."
+    )
+    correct_endpoint: str = Field(
+        default=os.getenv("FACT_CHECK_CORRECT_ENDPOINT", "/correct/"),
+        description="Endpoint for single correction steps."
+    )
+    display_graph_endpoint: str = Field(
+        default=os.getenv("FACT_CHECK_DISPLAY_GRAPH_ENDPOINT", "/display_graph/"),
+        description="Endpoint to display a saved fact graph."
+    )
+    graph_json_endpoint: str = Field(
+        default=os.getenv("FACT_CHECK_GRAPH_JSON_ENDPOINT", "/graph/json/"),
+        description="Endpoint to retrieve graph data as JSON."
+    )
+    polling_interval_seconds: int = Field(
+        default=120,
+        description="Seconds to wait between polls of the job status endpoint.",
+    )
+    job_timeout_seconds: int = Field(
+        default=int(os.getenv("FACT_CHECK_JOB_TIMEOUT", "1800")),
+        description="Maximum time in seconds to wait for the entire job to complete (defaults to 30 minutes).",
+    )
+    request_timeout_seconds: int = Field(
+        default=int(os.getenv("FACT_CHECK_REQUEST_TIMEOUT", "60")),
+        description="Timeout in seconds for each individual API request.",
+    )
+
+
+class FactCheckTool(
+    BaseTool[FactCheckInputSchema, FactCheckOutputSchema],
+):
+    """
+    A tool that calls an API to perform fact-checking on a given answer.
+    """
+
+    name = "fact_check_tool"
+    description = (
+        "Calls an API to run the FactReasoner pipeline on a question and answer."
+    )
+    input_schema = FactCheckInputSchema
+    output_schema = FactCheckOutputSchema
+    config_schema = FactCheckToolConfig
+
+    def __init__(
+        self,
+        config: FactCheckToolConfig | None = None,
+        debug: bool = False,
+    ):
+        """Initializes the FactCheckTool and its HTTP client."""
+        config = config or FactCheckToolConfig()
+        super().__init__(config, debug)
+
+        logger.info("Initializing FactCheckTool...")
+        # Set a timeout on the API requests
+        timeout = httpx.Timeout(self.config.request_timeout_seconds, connect=60.0)
+        self.api_client = httpx.AsyncClient(
+            base_url=str(self.config.base_url),
+            timeout=timeout,
+        )
+
+    async def _arun(
+        self,
+        params: FactCheckInputSchema,
+    ) -> FactCheckOutputSchema:
+        """
+        Starts a fact-checking job and polls for its completion.
+        """
+        logger.info(
+            f"Sending fact-check request for question: '{params.question[:50]}...'",
+        )
+
+        try:
+            # Start the job
+            start_response = await self.api_client.post(
+                self.config.start_endpoint,
+                json=params.model_dump(),
+            )
+            start_response.raise_for_status()
+            job_id = start_response.json()["job_id"]
+            logger.info(f"Successfully started job with ID: {job_id}")
+
+            # Poll for the result
+            total_wait_time = 0
+            while total_wait_time < self.config.job_timeout_seconds:
+                logger.info(f"Polling status for job {job_id}...")
+                status_response = await self.api_client.get(
+                    f"{self.config.status_endpoint.rstrip('/')}/{job_id}",
+                )
+                status_response.raise_for_status()
+                status_data = status_response.json()
+
+                if status_data["status"] == "completed":
+                    logger.info(f"Job {job_id} completed successfully.")
+                    return FactCheckOutputSchema(**status_data["result"])
+                elif status_data["status"] == "failed":
+                    raise Exception(
+                        f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}",
+                    )
+                elif status_data["status"] == "pending":
+                    logger.info(
+                        f"Job {job_id} is in progress... (waited {total_wait_time}s)",
+                    )
+
+                await asyncio.sleep(self.config.polling_interval_seconds)
+                total_wait_time += self.config.polling_interval_seconds
+
+            raise asyncio.TimeoutError(
+                f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout.",
+            )
+
+        except httpx.HTTPStatusError as e:
+            logger.error(
+                f"HTTP error occurred while calling fact-check API: {e.response.status_code} - {e.response.text}",
+            )
+            raise
+        except Exception as e:
+            logger.error(f"An unexpected error occurred: {e}")
+            raise
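
A minimal usage sketch of the tool defined above (assumes a FactReasoner service is reachable at `FACT_CHECK_API_URL`; the question and answer are made up):

```python
import asyncio

from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool


async def demo() -> None:
    # Picks up the FACT_CHECK_* settings from the environment.
    tool = FactCheckTool()
    result = await tool.arun(
        FactCheckInputSchema(
            question="What colour is the sky?",
            answer="The sky is blue.",
        ),
    )
    # The output schema exposes the FactReasoner score plus the atom lists.
    print(result.fact_reasoner_score.get("factuality_score"))
    print(len(result.supported_atoms), len(result.not_supported_atoms))


if __name__ == "__main__":
    asyncio.run(demo())
```
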
diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py
new file mode 100644
index 00000000..257c86e4
--- /dev/null
+++ b/akd/tools/vector_db_tool.py
@@ -0,0 +1,146 @@
+import os
+from typing import Any, Dict, List, Optional
+
+import chromadb
+import chromadb.utils.embedding_functions as embedding_functions
+from loguru import logger
+from pydantic import Field
+
+from akd._base import InputSchema, OutputSchema
+from akd.tools._base import BaseTool, BaseToolConfig
+from akd.utils import get_akd_root
+
+
+class VectorDBIndexInputSchema(InputSchema):
+    """Input schema for indexing documents into the Vector Database."""
+
+    ids: List[str] = Field(..., description="A unique list of document IDs.")
+    documents: List[str] = Field(
+        ...,
+        description="A list of document contents to index.",
+    )
+    metadatas: Optional[List[Dict[str, Any]]] = Field(
+        None,
+        description="Optional list of metadata for each document.",
+    )
+
+
+class VectorDBQueryInputSchema(InputSchema):
+    """Input schema for querying documents from the Vector Database."""
+
+    query: str = Field(..., description="The query string for retrieval.")
+    k: int = Field(3, description="Number of documents to retrieve.")
+
+
+class VectorDBQueryOutputSchema(OutputSchema):
+    """Output schema for the Vector Database tool's query results."""
+
+    results: List[Dict[str, Any]] = Field(
+        ...,
+        description="List of retrieved documents, each as a dictionary with 'page_content' and 'metadata'.",
+    )
+
+
+class VectorDBToolConfig(BaseToolConfig):
+    """Configuration for the VectorDBTool, partly loaded from environment variables."""
+
+    embedding_model_name: str = Field(
+        default="sentence-transformers/all-MiniLM-L6-v2",
+        description="The name of the Hugging Face embedding model to use.",
+    )
+    embedding_model_api_key: Optional[str] = Field(
+        default=os.getenv("EMBEDDING_MODEL_API_KEY", None),
+        description="The API key for the embedding model provider, if required.",
+    )
+    db_path: str = Field(
+        default=os.getenv("VECTOR_DB_PATH", str(get_akd_root() / "chroma_db")),
+        description="Path to the persistent ChromaDB directory.",
+    )
+    collection_name: str = Field(
+        default="akd_vdb",
+        description="Name of the collection within ChromaDB.",
+    )
+
+
+class VectorDBTool(
+    BaseTool[VectorDBQueryInputSchema, VectorDBQueryOutputSchema],
+):
+    """
+    A tool for indexing and retrieving documents from a Chroma vector database.
+    """
+
+    name = "vector_db_tool"
+    description = (
+        "Indexes documents into a vector database and retrieves them based on a query."
+    )
+    input_schema = VectorDBQueryInputSchema
+    output_schema = VectorDBQueryOutputSchema
+    config_schema = VectorDBToolConfig
+
+    def __init__(
+        self,
+        config: Optional[VectorDBToolConfig] = None,
+        debug: bool = False,
+    ):
+        """Initializes the VectorDBTool and its ChromaDB client."""
+        config = config or VectorDBToolConfig()
+        super().__init__(config, debug)
+
+        logger.info("Initializing VectorDBTool...")
+        self.client = chromadb.PersistentClient(path=self.config.db_path)
+
+        embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=self.config.embedding_model_name,
+        )
+        self.collection = self.client.get_or_create_collection(
+            name=self.config.collection_name,
+            embedding_function=embedding_function,
+        )
+        logger.info(
+            f"Connected to ChromaDB collection '{self.config.collection_name}'.",
+        )
+
+    def index(self, params: VectorDBIndexInputSchema):
+        """
+        Adds documents to the vector database collection.
+        """
+        logger.info(f"Indexing {len(params.documents)} documents...")
+        self.collection.add(
+            ids=params.ids,
+            documents=params.documents,
+            metadatas=params.metadatas,
+        )
+        logger.info("Indexing complete.")
+
+    async def _arun(
+        self,
+        params: VectorDBQueryInputSchema,
+    ) -> VectorDBQueryOutputSchema:
+        """
+        Retrieves documents and returns them as a list of dictionaries.
+        """
+        logger.info(
+            f"Querying collection with query: '{params.query}', retrieving top-{params.k} documents",
+        )
+
+        results = self.collection.query(
+            query_texts=[params.query],
+            n_results=params.k,
+            include=["metadatas", "documents"],
+        )
+
+        retrieved_docs = []
+        if results and results.get("ids") and results["ids"][0]:
+            result_documents = results["documents"][0]
+            result_metadatas = results["metadatas"][0]
+
+            for i in range(len(result_documents)):
+                doc = {
+                    "page_content": result_documents[i],
+                    "metadata": result_metadatas[i]
+                    if result_metadatas and result_metadatas[i]
+                    else {},
+                }
+                retrieved_docs.append(doc)
+
+        return VectorDBQueryOutputSchema(results=retrieved_docs)
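
A short sketch of indexing and querying with the new tool (the path, documents, and metadata are illustrative; the default sentence-transformers model is downloaded on first use):

```python
import asyncio

from akd.tools.vector_db_tool import (
    VectorDBIndexInputSchema,
    VectorDBQueryInputSchema,
    VectorDBTool,
    VectorDBToolConfig,
)


async def demo() -> None:
    tool = VectorDBTool(config=VectorDBToolConfig(db_path="./chroma_db_demo"))

    # Index a couple of toy documents.
    tool.index(
        VectorDBIndexInputSchema(
            ids=["d1", "d2"],
            documents=["Mars has polar ice caps.", "The Moon has no atmosphere."],
            metadatas=[{"source": "notes.txt"}, {"source": "notes.txt"}],
        ),
    )

    # Retrieve the single closest chunk for a query.
    output = await tool.arun(VectorDBQueryInputSchema(query="water ice on Mars", k=1))
    print(output.results[0]["page_content"], output.results[0]["metadata"])


if __name__ == "__main__":
    asyncio.run(demo())
```
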
diff --git a/pyproject.toml b/pyproject.toml
index fc5ec0f3..a11cb97c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,9 @@ dependencies = [
     "ollama>=0.5.1",
     "tiktoken>=0.9.0",
     "rapidfuzz>=3.13.0",
+    "deepeval>=3.4.0",
+    "markdown>=3.8.2",
+    "chromadb>=1.0.13",
     "undetected-chromedriver>=3.5.5",
     "pypaperbot @ git+https://github.com/NISH1001/PyPaperBot.git@develop",
 ]
@@ -89,6 +92,9 @@ dev = [
     "pytest>=8.0.0",
     "pytest-asyncio>=1.0.0",
     "pytest-cov>=6.0.0",
+    "pytest-mock>=3.14.1",
+    "ipykernel>=6.30.0",
+    "ipywidgets>=8.1.7",
     "pre-commit>=4.2.0",
 ]
 local = [
@@ -98,6 +104,7 @@ local = [
     "ipywidgets>=8.1.7",
 ]
 ml = [
+    "chromadb>=1.0.13",
     "pandas>=2.3.1",
     "sentence-transformers>=5.0.0",
     "docling>=2.37.0",
+""" + +import asyncio +import sys + +import markdown +from bs4 import BeautifulSoup +from langchain_text_splitters import RecursiveCharacterTextSplitter + +from akd.agents.search import ( + DeepLitSearchAgent, + DeepLitSearchAgentConfig, + LitSearchAgentInputSchema, +) +from akd.configs.project import get_project_settings +from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool +from akd.tools.vector_db_tool import VectorDBIndexInputSchema, VectorDBTool + + +def process_report(report): + """Processes a report (in markdown) to remove headers and reference lists.""" + html = markdown.markdown(report) + soup = BeautifulSoup(html, "html.parser") + + # Ignore header values + content_tags = soup.find_all(["p", "li"]) + + prose_fragments = [] + for tag in content_tags: + is_reference_link = ( + tag.name == "li" and tag.find("a") and len(tag.contents) == 1 + ) + + if is_reference_link: + continue + else: + prose_fragments.append(tag.get_text(strip=True)) + + cleaned_markdown = "\n".join(prose_fragments) + return cleaned_markdown + + +async def main(): + # Check for API keys + settings = get_project_settings() + if not settings.model_config_settings.api_keys.openai: + print( + "No OpenAI API key found. Please set OPENAI_API_KEY environment variable.", + ) + return + + # Configure and run the agent + agent_config = DeepLitSearchAgentConfig( + max_research_iterations=1, + use_semantic_scholar=False, # avoid rate limits + enable_full_content_scraping=True, + debug=True, + ) + agent = DeepLitSearchAgent(config=agent_config) + + research_query = "What evidence is there for water on Mars?" + input_params = LitSearchAgentInputSchema(query=research_query, max_results=3) + + print(f"--- Starting research for: '{research_query}' ---") + research_output = await agent._arun(input_params) + + report_results = [ + res + for res in research_output.results + if res.get("url") == "deep-research://report" + ] + source_results = [ + res + for res in research_output.results + if res.get("url") != "deep-research://report" + ] + + if not source_results: + print("No source documents were found. Exiting.") + return + + if not report_results: + print("No research report was generated. Exiting.") + return + + research_report = report_results[0] + + print(f"\n--- Found {len(source_results)} source documents and research report ---") + + print(f"\n--- Splitting {len(source_results)} documents into smaller chunks ---") + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=2000, + chunk_overlap=200, + ) + + all_chunks = [] + all_metadatas = [] + all_ids = [] + + for i, res in enumerate(source_results): + if res.get("content"): + # Use split_text on the raw content string + chunks = text_splitter.split_text(res["content"]) + for j, chunk in enumerate(chunks): + all_chunks.append(chunk) + # Create a unique ID and metadata for each chunk + metadata = {"source": res["url"], "title": res["title"]} + all_metadatas.append(metadata) + all_ids.append(f"res_{i}_chunk_{j}") + + print(f"Created {len(all_chunks)} chunks.") + + vector_db_tool = VectorDBTool() + + print(f"\n--- Indexing {len(all_chunks)} chunks into ChromaDB ---") + + index_params = VectorDBIndexInputSchema( + ids=all_ids, + documents=all_chunks, + metadatas=all_metadatas, + ) + vector_db_tool.index(index_params) + + print(f"Indexing complete. 
diff --git a/tests/tools/test_fact_check.py b/tests/tools/test_fact_check.py
new file mode 100644
index 00000000..bdee4a33
--- /dev/null
+++ b/tests/tools/test_fact_check.py
@@ -0,0 +1,101 @@
+from unittest.mock import AsyncMock
+
+import pytest
+from pydantic import HttpUrl
+
+from akd.tools.fact_check import (
+    FactCheckInputSchema,
+    FactCheckOutputSchema,
+    FactCheckTool,
+    FactCheckToolConfig,
+)
+
+
+@pytest.fixture
+def mock_api_start_response() -> dict:
+    """Provides a mock JSON response for a newly started fact-check job."""
+    return {"job_id": "mock-job-1234"}
+
+
+@pytest.fixture
+def mock_api_final_response() -> dict:
+    """Provides a mock, successful JSON response from the fact-check API."""
+    return {
+        "fact_reasoner_score": {
+            "factuality_score_per_atom": [{"a0": {"score": 0.9877, "support": "S"}}],
+            "factuality_score": 1.0,
+        },
+        "supported_atoms": [{"id": "a0", "text": "The sky is blue"}],
+        "not_supported_atoms": [],
+        "contexts": [{"id": "c_a0_0", "title": "Why Is the Sky Blue?"}],
+        "graph_id": "mock-graph-id-123",
+        "logging_metadata": {},
+    }
+
+
+@pytest.fixture
+def fact_check_tool():
+    config = FactCheckToolConfig(
+        base_url=HttpUrl("http://localhost:8011"),
+        polling_interval_seconds=1,
+    )
+    tool = FactCheckTool(config=config)
+    return tool
+
+
+@pytest.mark.asyncio
+async def test_fact_check_tool_polling_workflow(
+    mocker,
+    fact_check_tool,
+    mock_api_start_response,
+    mock_api_final_response,
+):
+    """
+    Tests the full polling workflow by mocking both the POST and GET calls.
+    """
+
+    # Mock the POST call to /fact-check/start
+    mock_post_response = mocker.Mock()
+    mock_post_response.status_code = 200
+    mock_post_response.json.return_value = mock_api_start_response
+
+    # Mock the GET call to /fact-check/status/{job_id}
+    mock_get_response = mocker.Mock()
+    mock_get_response.status_code = 200
+    mock_get_response.json.return_value = {
+        "status": "completed",
+        "result": mock_api_final_response,
+    }
+
+    # Patch both 'post' and 'get' methods of the httpx.AsyncClient
+    mocker.patch(
+        "akd.tools.fact_check.httpx.AsyncClient.post",
+        new_callable=AsyncMock,
+        return_value=mock_post_response,
+    )
+    mocker.patch(
+        "akd.tools.fact_check.httpx.AsyncClient.get",
+        new_callable=AsyncMock,
+        return_value=mock_get_response,
+    )
+
+    # Prepare the input for the tool
+    input_data = FactCheckInputSchema(
+        question="What colour is the sky?",
+        answer="The sky is blue.",
+    )
+
+    # This will call the mocked versions of post and then get.
+    output = await fact_check_tool.arun(input_data)
+
+    # Assert:
+    assert isinstance(output, FactCheckOutputSchema)
+    assert output.fact_reasoner_score["factuality_score"] == 1.0
+    assert output.graph_id == "mock-graph-id-123"
+    assert len(output.supported_atoms) == 1
+
+
+def test_fact_check_tool_initialization(fact_check_tool):
+    """Tests that the tool initializes with the correct config."""
+    # Sanity check that the fixture's config was applied.
+    assert fact_check_tool.config.base_url == HttpUrl("http://localhost:8011")
+ """ + + # Mock the POST call to /fact-check/start + mock_post_response = mocker.Mock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = mock_api_start_response + + # Mock the GET call to /fact-check/status/{job_id} + mock_get_response = mocker.Mock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "status": "completed", + "result": mock_api_final_response, + } + + # Patch both 'post' and 'get' methods of the httpx.AsyncClient + mocker.patch( + "akd.tools.fact_check.httpx.AsyncClient.post", + new_callable=AsyncMock, + return_value=mock_post_response, + ) + mocker.patch( + "akd.tools.fact_check.httpx.AsyncClient.get", + new_callable=AsyncMock, + return_value=mock_get_response, + ) + + # Prepare the input for the tool + input_data = FactCheckInputSchema( + question="What colour is the sky?", + answer="The sky is blue.", + ) + + # This will call the mocked versions of post and then get. + output = await fact_check_tool.arun(input_data) + + # Assert: + assert isinstance(output, FactCheckOutputSchema) + assert output.fact_reasoner_score["factuality_score"] == 1.0 + assert output.graph_id == "mock-graph-id-123" + assert len(output.supported_atoms) == 1 + + +def test_fact_check_tool_initialization(fact_check_tool): + """Tests that the tool initializes with the correct config.""" + # This test remains the same and is a good check. + assert fact_check_tool.config.base_url == HttpUrl("http://localhost:8011") diff --git a/tests/tools/test_semantic_scholar_search.py b/tests/tools/test_semantic_scholar_search.py new file mode 100644 index 00000000..2f0c73e1 --- /dev/null +++ b/tests/tools/test_semantic_scholar_search.py @@ -0,0 +1,65 @@ +from unittest.mock import AsyncMock + +import pytest +from pydantic import HttpUrl + +from akd.structures import SearchResultItem +from akd.tools.search.semantic_scholar_search import ( + SemanticScholarSearchTool, + SemanticScholarSearchToolInputSchema, + SemanticScholarSearchToolOutputSchema, +) + + +@pytest.fixture +def mock_arun_output() -> SemanticScholarSearchToolOutputSchema: + """Provides a realistic, successful output object from the arun method.""" + return SemanticScholarSearchToolOutputSchema( + results=[ + SearchResultItem( + title="A Mock Paper on Transformer Architectures", + url=HttpUrl("https://www.semanticscholar.org/paper/a1b2c3d4e5f6"), + content="This paper discusses recent advances in transformer models for NLP.", + query="Recent advances in transformer architectures", + ), + ], + category="science", + ) + + +def test_from_params_constructor(): + """ + Tests that the from_params classmethod correctly initializes the tool + and its configuration. + """ + search_tool = SemanticScholarSearchTool.from_params(max_results=5, debug=True) + assert search_tool.config.max_results == 5 + # Test a default value from the config + assert search_tool.config.base_url == HttpUrl("https://api.semanticscholar.org") + + +@pytest.mark.asyncio +async def test_arun_with_direct_mock(mocker, mock_arun_output): + """ + Tests the main `arun` method by directly mocking the internal `_arun` method. + """ + # Patch the internal _arun method to return our mock output directly. + mocker.patch( + "akd.tools.search.semantic_scholar_search.SemanticScholarSearchTool._arun", + new_callable=AsyncMock, + return_value=mock_arun_output, + ) + + # Initialize the tool and input data. 
diff --git a/tests/tools/test_vector_db_tool.py b/tests/tools/test_vector_db_tool.py
new file mode 100644
index 00000000..6f4d07b7
--- /dev/null
+++ b/tests/tools/test_vector_db_tool.py
@@ -0,0 +1,114 @@
+import shutil
+from pathlib import Path
+from typing import Dict, Generator, List
+
+import pytest
+
+from akd.tools.vector_db_tool import (
+    VectorDBIndexInputSchema,
+    VectorDBQueryInputSchema,
+    VectorDBQueryOutputSchema,
+    VectorDBTool,
+    VectorDBToolConfig,
+)
+
+pytestmark = pytest.mark.asyncio
+
+
+@pytest.fixture
+def temp_db_path(tmp_path: Path) -> Generator[str, None, None]:
+    """Create a temporary directory for the ChromaDB database."""
+    db_path = str(tmp_path / "test_chroma_db")
+    yield db_path
+    shutil.rmtree(db_path, ignore_errors=True)
+
+
+@pytest.fixture
+def sample_data() -> Dict[str, List]:
+    """Provides sample data as a dictionary of lists."""
+    return {
+        "ids": ["doc1", "doc2", "doc3"],
+        "documents": [
+            "The sky is blue. It is thirty degrees Celsius.",
+            "The ingredients include apples, pears and grapes.",
+            "A computer is an electronic device. Keyboards are used for input.",
+        ],
+        "metadatas": [
+            {"source": "weather.txt", "title": "weather_report"},
+            {"source": "ingredients.txt", "title": "ingredients_list"},
+            {"source": "tech.txt", "url": "www.tech-example.com"},
+        ],
+    }
+
+
+@pytest.fixture
+def configured_db_tool(
+    temp_db_path: str,
+    sample_data: Dict[str, List],
+) -> VectorDBTool:
+    """
+    Provides a fully configured and pre-populated VectorDBTool instance.
+    """
+    config = VectorDBToolConfig(
+        db_path=temp_db_path,
+        collection_name="test_collection",
+    )
+    db_tool = VectorDBTool(config=config)
+
+    # Index the sample documents
+    index_params = VectorDBIndexInputSchema(**sample_data)
+    db_tool.index(index_params)
+
+    return db_tool
+
+
+def test_vectordb_initialization(temp_db_path: str):
+    """
+    Tests that the VectorDBTool initializes correctly.
+    """
+    config = VectorDBToolConfig(db_path=temp_db_path, collection_name="init_test")
+    db_tool = VectorDBTool(config=config)
+
+    assert db_tool.client is not None
+    assert db_tool.collection is not None
+    assert db_tool.collection.name == "init_test"
+    assert Path(temp_db_path).exists()
+
+
+def test_index_method(temp_db_path: str, sample_data: Dict[str, List]):
+    """
+    Tests that the `index` method correctly adds documents.
+    """
+    config = VectorDBToolConfig(db_path=temp_db_path, collection_name="index_test")
+    db_tool = VectorDBTool(config=config)
+
+    index_params = VectorDBIndexInputSchema(**sample_data)
+    db_tool.index(index_params)
+
+    assert db_tool.collection.count() == len(sample_data["documents"])
+
+    retrieved = db_tool.collection.get(ids=["doc2"], include=["metadatas", "documents"])
+    assert (
+        retrieved["documents"][0] == "The ingredients include apples, pears and grapes."
+    )
+    assert retrieved["metadatas"][0]["source"] == "ingredients.txt"
+    assert retrieved["metadatas"][0]["title"] == "ingredients_list"
+ """ + query = "What color is the sky?" + input_params = VectorDBQueryInputSchema(query=query, k=1) + + output = await configured_db_tool._arun(input_params) + + assert isinstance(output, VectorDBQueryOutputSchema) + assert len(output.results) == 1 + + retrieved_doc = output.results[0] + assert isinstance(retrieved_doc, dict) + + assert retrieved_doc["metadata"]["source"] == "weather.txt" + assert "The sky is blue" in retrieved_doc["page_content"]