From 727d5f3cbafbb131a38b128b40e8c408579e4dff Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 23 Jul 2025 10:54:39 +0100 Subject: [PATCH 01/33] Add `VectorDBTool` based on Chroma --- akd/tools/vector_db_tool.py | 142 ++++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 akd/tools/vector_db_tool.py diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py new file mode 100644 index 00000000..332698d3 --- /dev/null +++ b/akd/tools/vector_db_tool.py @@ -0,0 +1,142 @@ +from typing import List, Optional + +import chromadb +import chromadb.utils.embedding_functions as embedding_functions +from langchain_core.documents import Document +from loguru import logger +from pydantic import Field + +from akd._base import InputSchema, OutputSchema +from akd.tools._base import BaseTool, BaseToolConfig + + +class VectorDBInputSchema(InputSchema): + """Input schema for querying documents from the Vector Database.""" + + query: str = Field(..., description="The query string for retrieval.") + k: int = Field(3, description="Number of documents to retrieve.") + + +class VectorDBOutputSchema(OutputSchema): + """Output schema for the Vector Database tool's query results.""" + + results: List[Document] = Field( + ..., + description="List of retrieved Langchain Document objects.", + ) + + +class VectorDBToolConfig(BaseToolConfig): + """Configuration for the VectorDBTool, loaded from environment variables.""" + + embedding_model_name: str = Field( + default="sentence-transformers/all-MiniLM-L6-v2", + description="The name of the Hugging Face embedding model to use.", + ) + embedding_model_api_key: Optional[str] = Field( + default=None, + description="The API key for the embedding model provider, currently using HuggingFace.", + ) + db_path: str = Field( + default="./chroma_db", + description="Path to the persistent ChromaDB directory.", + ) + collection_name: str = Field( + default="litagent_demo", + description="Name of the collection within ChromaDB.", + ) + + +class VectorDBTool( + BaseTool[VectorDBInputSchema, VectorDBOutputSchema], +): + """ + A tool for indexing and retrieving documents from a Chroma vector database. + """ + + name = "vector_db_tool" + description = ( + "Indexes documents into a vector database and retrieves them based on a query." + ) + input_schema = VectorDBInputSchema + output_schema = VectorDBOutputSchema + config_schema = VectorDBToolConfig + + def __init__( + self, + config: Optional[VectorDBToolConfig] = None, + debug: bool = False, + ): + """Initializes the VectorDBTool and its ChromaDB client.""" + config = config or self.config_schema() + super().__init__(config, debug) + + logger.info("Initializing VectorDBTool...") + + self.client = chromadb.PersistentClient(path=self.config.db_path) + + embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction( + model_name=self.config.embedding_model_name, + ) + self.collection = self.client.get_or_create_collection( + name=self.config.collection_name, + embedding_function=embedding_function, + ) + logger.info( + f"Connected to ChromaDB collection '{self.config.collection_name}'.", + ) + + def index(self, documents: List[Document]): + """ + Adds or updates documents in the vector database collection from Langchain Documents. 
+ """ + logger.info(f"Indexing {len(documents)} documents...") + + # Extract components from the Document objects for ChromaDB + ids = [doc.metadata.get("id", f"doc_{i}") for i, doc in enumerate(documents)] + contents = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + + self.collection.add( + ids=ids, + documents=contents, + metadatas=metadatas, + ) + logger.info("Indexing complete.") + + async def _arun( + self, + params: VectorDBInputSchema, + ) -> VectorDBOutputSchema: + """ + Retrieves documents and returns them as a list of Langchain Document objects. + """ + logger.info( + f"Querying collection with query: '{params.query}', retrieving top-{params.k} documents", + ) + + # Include metadatas and documents to reconstruct the Document objects + results = self.collection.query( + query_texts=[params.query], + n_results=params.k, + include=["metadatas", "documents"], + ) + + retrieved_docs = [] + # The result is batched; we process the first (and only) query's results + if results and results.get("ids") and results["ids"][0]: + result_ids = results["ids"][0] + result_documents = results["documents"][0] + result_metadatas = results["metadatas"][0] + + for i in range(len(result_ids)): + # Reconstruct the Langchain Document object + doc = Document( + page_content=result_documents[i], + metadata=result_metadatas[i] + if result_metadatas and result_metadatas[i] + else {}, + ) + retrieved_docs.append(doc) + + return VectorDBOutputSchema(results=retrieved_docs) From d85bf70a651620e915b778a7482854f76840f905 Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 23 Jul 2025 15:23:14 +0100 Subject: [PATCH 02/33] Adopt config strategy and add text splitter and vector db tools --- scripts/run_lit_agent.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/scripts/run_lit_agent.py b/scripts/run_lit_agent.py index 2be7fddc..d918bba1 100644 --- a/scripts/run_lit_agent.py +++ b/scripts/run_lit_agent.py @@ -2,10 +2,7 @@ import asyncio import json -import openai from loguru import logger -from pydantic import ConfigDict -from langchain_openai import ChatOpenAI from akd.agents.extraction import ( EstimationExtractionAgent, @@ -14,25 +11,23 @@ from akd.agents.factory import create_query_agent from akd.agents.intents import IntentAgent from akd.agents.litsearch import LitAgent, LitAgentInputSchema -from akd.configs.lit_config import get_lit_agent_settings -from akd.tools.scrapers.composite import CompositeWebScraper, ResearchArticleResolver +from akd.tools.scrapers.composite import CompositeScraper, ResearchArticleResolver from akd.tools.scrapers.pdf_scrapers import SimplePDFScraper from akd.tools.scrapers.resolvers import ADSResolver, ArxivResolver, IdentityResolver from akd.tools.scrapers.web_scrapers import Crawl4AIWebScraper, SimpleWebScraper -from akd.tools.search import SearxNGSearchTool +from akd.tools.search import SearxNGSearchTool, SearxNGSearchToolConfig +from akd.tools.text_splitter import TextSplitterTool +from akd.tools.vector_db_tool import VectorDBTool, VectorDBToolConfig async def main(args): - lit_agent_config = get_lit_agent_settings(args.config) - search_config = lit_agent_config.search - scraper_config = lit_agent_config.scraper - + search_config = SearxNGSearchToolConfig(max_results=3) search_tool = SearxNGSearchTool(config=search_config) - scraper = CompositeWebScraper( - SimpleWebScraper(scraper_config), - Crawl4AIWebScraper(scraper_config), - SimplePDFScraper(scraper_config), + scraper = 
CompositeScraper( + SimpleWebScraper(), + Crawl4AIWebScraper(), + SimplePDFScraper(), debug=True, ) @@ -42,11 +37,15 @@ async def main(args): IdentityResolver(), ) - intent_agent = IntentAgent( - config=ConfigDict(client=ChatOpenAI()), + text_splitter = TextSplitterTool() + vector_db_config = VectorDBToolConfig( + collection_name="lit_agent_demo", ) + vector_db_tool = VectorDBTool(config=vector_db_config) + intent_agent = IntentAgent() query_agent = create_query_agent() + schema_mapper = IntentBasedExtractionSchemaMapper() extraction_agent = EstimationExtractionAgent() @@ -58,12 +57,14 @@ async def main(args): search_tool=search_tool, web_scraper=scraper, article_resolver=article_resolver, + text_splitter=text_splitter, + vector_db_tool=vector_db_tool, ) lit_agent.clear_history() result = await lit_agent.arun( - LitAgentInputSchema(query=args.query, max_search_results=5) + LitAgentInputSchema(query=args.query, max_search_results=5), ) logger.info(result.model_dump()) From 63e5cbeb9cda73e6ac7c0ef45e64f93ac5833a3e Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 23 Jul 2025 15:27:23 +0100 Subject: [PATCH 03/33] Update litsearch agent with new indexing capabilities --- akd/agents/litsearch.py | 46 +++++++++++++++++++ akd/tools/text_splitter.py | 93 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+) create mode 100644 akd/tools/text_splitter.py diff --git a/akd/agents/litsearch.py b/akd/agents/litsearch.py index 8be51fac..52c76c24 100644 --- a/akd/agents/litsearch.py +++ b/akd/agents/litsearch.py @@ -1,5 +1,6 @@ from typing import List +from langchain_core.documents import Document from loguru import logger from pydantic import Field @@ -15,6 +16,10 @@ ) from akd.tools.search import SearxNGSearchTool, SearxNGSearchToolInputSchema +# --- MODIFIED: Re-importing the TextSplitterTool and its schema --- +from akd.tools.text_splitter import TextSplitterInputSchema, TextSplitterTool +from akd.tools.vector_db_tool import VectorDBTool + from .extraction import ( EstimationExtractionAgent, ExtractionInputSchema, @@ -60,6 +65,8 @@ def __init__( search_tool: SearxNGSearchTool, web_scraper: WebScraperToolBase, article_resolver: BaseArticleResolver, + text_splitter: TextSplitterTool, + vector_db_tool: VectorDBTool, n_queries: int = 3, debug: bool = False, ) -> None: @@ -71,10 +78,19 @@ def __init__( self.search_tool = search_tool self.web_scraper = web_scraper self.article_resolver = article_resolver + self.text_splitter = text_splitter + self.vector_db_tool = vector_db_tool self.n_queries = n_queries super().__init__(debug=debug) + async def get_response_async(self, *args, **kwargs) -> LitAgentOutputSchema: + """ + This method is required by the BaseAgent but is not used by LitAgent, + which is an orchestrator. The main entry point is the `_arun` method. + """ + raise NotImplementedError() + async def _arun( self, params: LitAgentInputSchema, @@ -110,6 +126,7 @@ async def _arun( # Log search results logger.info(f"Found {len(search_results.results)} relevant web pages:") contents = [] + docs_to_split = [] for i, result in enumerate(search_results.results, 1): logger.debug(f"Result {i} : Scraping the url {result.url}") resolver_output = await self.article_resolver.arun( @@ -132,8 +149,37 @@ async def _arun( f"Result {i}: {result.title} | " f"{url} | {content[:100]}.. 
| words={len(content.split())}", ) + contents.append(ExtractionDTO(source=str(url), result=content)) + if content: + doc = Document( + page_content=content, + metadata={ + "id": str(url), + "source": str(url), + "title": result.title, + }, + ) + docs_to_split.append(doc) + + # Split and index the documents + if docs_to_split: + logger.info(f"Splitting {len(docs_to_split)} documents into chunks...") + splitter_output = await self.text_splitter.arun( + TextSplitterInputSchema( + documents=docs_to_split, + ), + ) + docs_to_index = splitter_output.chunks + + if docs_to_index: + logger.info(f"Indexing {len(docs_to_index)} document chunks...") + try: + self.vector_db_tool.index(documents=docs_to_index) + except Exception as e: + logger.error(f"Failed to index document chunks in VectorDB: {e}") + results = [] for content in contents: self.extraction_agent.reset_memory() diff --git a/akd/tools/text_splitter.py b/akd/tools/text_splitter.py new file mode 100644 index 00000000..1fe6f42b --- /dev/null +++ b/akd/tools/text_splitter.py @@ -0,0 +1,93 @@ +from typing import List, Optional + +from langchain_core.documents import Document +from langchain_text_splitters import RecursiveCharacterTextSplitter +from loguru import logger +from pydantic import Field + +from akd._base import InputSchema, OutputSchema +from akd.tools._base import BaseTool, BaseToolConfig + + +class TextSplitterInputSchema(InputSchema): + """Input schema for the Text Splitter Tool.""" + + documents: List[Document] = Field( + ..., + description="A list of Langchain Document objects to split.", + ) + + +class TextSplitterOutputSchema(OutputSchema): + """Output schema for the Text Splitter Tool.""" + + chunks: List[Document] = Field( + ..., + description="A list of smaller Langchain Document objects (chunks).", + ) + + +class TextSplitterToolConfig(BaseToolConfig): + """Configuration for the TextSplitterTool.""" + + chunk_size: int = Field( + default=1000, + description="The maximum size of each text chunk.", + ) + chunk_overlap: int = Field( + default=100, + description="The number of characters to overlap between chunks.", + ) + + +class TextSplitterTool( + BaseTool[TextSplitterInputSchema, TextSplitterOutputSchema], +): + """ + A tool for splitting large documents into smaller, more manageable chunks. + """ + + name = "text_splitter_tool" + description = "Splits a list of documents into smaller text chunks." + input_schema = TextSplitterInputSchema + output_schema = TextSplitterOutputSchema + config_schema = TextSplitterToolConfig + + def __init__( + self, + config: Optional[TextSplitterToolConfig] = None, + debug: bool = False, + ): + """Initializes the TextSplitterTool.""" + config = config or self.config_schema() + super().__init__(config, debug) + + logger.info("Initializing TextSplitterTool...") + self._splitter = RecursiveCharacterTextSplitter( + chunk_size=self.config.chunk_size, + chunk_overlap=self.config.chunk_overlap, + ) + + async def _arun( + self, + params: TextSplitterInputSchema, + **kwargs, + ) -> TextSplitterOutputSchema: + """ + Splits the provided documents into smaller chunks. 
+ """ + logger.info(f"Splitting {len(params.documents)} document(s)...") + all_chunks = [] + for doc in params.documents: + chunks = self._splitter.split_documents([doc]) + # Add unique IDs to each chunk's metadata + for i, chunk in enumerate(chunks): + source_id = chunk.metadata.get( + "id", + chunk.metadata.get("source", "unknown"), + ) + chunk.metadata["id"] = f"{source_id}_{i}" + all_chunks.extend(chunks) + + logger.info(f"Created {len(all_chunks)} chunks.") + return TextSplitterOutputSchema(chunks=all_chunks) From f1f6a76dd7029bd62865aca9da125adeba36f530 Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 23 Jul 2025 16:07:00 +0100 Subject: [PATCH 04/33] Remove comment line --- akd/agents/litsearch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/akd/agents/litsearch.py b/akd/agents/litsearch.py index 52c76c24..7925527e 100644 --- a/akd/agents/litsearch.py +++ b/akd/agents/litsearch.py @@ -15,8 +15,6 @@ WebScraperToolBase, ) from akd.tools.search import SearxNGSearchTool, SearxNGSearchToolInputSchema - -# --- MODIFIED: Re-importing the TextSplitterTool and its schema --- from akd.tools.text_splitter import TextSplitterInputSchema, TextSplitterTool from akd.tools.vector_db_tool import VectorDBTool From 2e783259c6d7bfd15742271495d41debaed99d45 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 7 Aug 2025 11:30:49 +0100 Subject: [PATCH 05/33] Changes to standalone script for vectordb --- scripts/run_lit_agent.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/run_lit_agent.py b/scripts/run_lit_agent.py index d918bba1..fdb4ab15 100644 --- a/scripts/run_lit_agent.py +++ b/scripts/run_lit_agent.py @@ -21,8 +21,10 @@ async def main(args): - search_config = SearxNGSearchToolConfig(max_results=3) - search_tool = SearxNGSearchTool(config=search_config) + # search_config = SearxNGSearchToolConfig(max_results=3) + search_tool = SearxNGSearchTool( + # config=search_config + ) scraper = CompositeScraper( SimpleWebScraper(), @@ -39,6 +41,7 @@ async def main(args): text_splitter = TextSplitterTool() vector_db_config = VectorDBToolConfig( + db_path="./", collection_name="lit_agent_demo", ) vector_db_tool = VectorDBTool(config=vector_db_config) From 25866215615de934dc2430dab2894efa7cba0e1e Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 8 Aug 2025 15:07:08 +0100 Subject: [PATCH 06/33] Add test to run `SemanticScholarSearchTool` --- tests/tools/semantic_scholar_search_test.py | 72 +++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 tests/tools/semantic_scholar_search_test.py diff --git a/tests/tools/semantic_scholar_search_test.py b/tests/tools/semantic_scholar_search_test.py new file mode 100644 index 00000000..0249d0ad --- /dev/null +++ b/tests/tools/semantic_scholar_search_test.py @@ -0,0 +1,72 @@ +import asyncio + +import pytest + +from akd.tools.search import ( + SemanticScholarSearchTool, + SemanticScholarSearchToolConfig, + SemanticScholarSearchToolInputSchema, + SemanticScholarSearchToolOutputSchema, +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_fetch_paper_by_external_id(): # Renamed for clarity + """ + Tests that fetch_paper_by_external_id can successfully retrieve + and parse a specific paper using its ARXIV ID. 
+ """ + config = SemanticScholarSearchToolConfig() + search_tool = SemanticScholarSearchTool(config=config, debug=True) + + known_arxiv_id = "1706.03762" + input_schema = SemanticScholarSearchToolInputSchema(queries=[known_arxiv_id]) + + results = await search_tool.fetch_paper_by_external_id( + input_schema, + external_id="ARXIV", + ) + + assert isinstance(results, list) + assert len(results) == 1, ( + "Expected to find exactly one paper for the given ArXiv ID." + ) + + paper = results[0] + # Check that the title and ArXiv ID match the paper we requested. + assert paper.external_ids["ArXiv"] == known_arxiv_id + + +@pytest.mark.asyncio +async def test_arun(): + """ + Tests the main `arun` method to ensure the full process works. + """ + config = SemanticScholarSearchToolConfig() + search_tool = SemanticScholarSearchTool(config=config, debug=True) + + queries = ["Enhanced dependency parsing approaches"] + input_schema = SemanticScholarSearchToolInputSchema(queries=queries, max_results=3) + + output = await search_tool.arun(input_schema) + + # Assertions to check the final, processed output + assert isinstance(output, SemanticScholarSearchToolOutputSchema) + assert len(output.results) > 0, "No results found" + + first_result = output.results[0] + assert first_result.url, "No url included" + assert first_result.title, "No title included" + assert first_result.content, "No content included" + + +async def main(): + """Runs all the defined tests.""" + await test_fetch_paper_by_external_id() + await test_arun() + + +if __name__ == "__main__": + asyncio.run(main()) From 25ccec3e9f9b2cb94f402f2cf7ddc2615a32b0d6 Mon Sep 17 00:00:00 2001 From: jbrry Date: Mon, 11 Aug 2025 10:18:53 +0100 Subject: [PATCH 07/33] Add test for text splitter tool --- tests/tools/text_splitter_test.py | 88 +++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 tests/tools/text_splitter_test.py diff --git a/tests/tools/text_splitter_test.py b/tests/tools/text_splitter_test.py new file mode 100644 index 00000000..b2f1aa72 --- /dev/null +++ b/tests/tools/text_splitter_test.py @@ -0,0 +1,88 @@ +import pytest +from langchain_core.documents import Document + +from akd.tools.text_splitter import ( + TextSplitterInputSchema, + TextSplitterOutputSchema, + TextSplitterTool, + TextSplitterToolConfig, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def sample_documents(): + """Provides sample documents for testing.""" + long_text = " ".join(["This is sentence " + str(i) + "." for i in range(200)]) + + return [ + Document( + page_content=long_text, + metadata={"source": "doc1.txt", "url": "www.example1.com"}, + ), + Document( + page_content="This is a short document that should not be split.", + metadata={"source": "doc2.txt", "url": "www.example2.com"}, + ), + ] + + +async def test_text_splitter_with_default_config(sample_documents): + """ + Tests the TextSplitterTool with its default configuration to ensure + it splits long documents and assigns new IDs. + """ + + splitter_tool = TextSplitterTool() + input_data = TextSplitterInputSchema(documents=sample_documents) + + # Run the tool + output = await splitter_tool._arun(input_data) + + assert isinstance(output, TextSplitterOutputSchema) + + # The long document should be split, and the short one should remain as one chunk. 
+ assert len(output.chunks) > len(sample_documents) + + # Check the properties of each chunk + for chunk in output.chunks: + assert isinstance(chunk, Document) + assert len(chunk.page_content) <= splitter_tool.config.chunk_size + # Verify that a new, unique chunk ID has been added to the metadata + assert "id" in chunk.metadata + assert chunk.metadata["id"].startswith(chunk.metadata["source"]) + + +async def test_text_splitter_with_custom_config(): + """ + Tests the TextSplitterTool with a custom configuration (smaller chunk size) + to verify it produces more chunks. + """ + small_text = ( + "The cat sat on the mat. Rug is another word for mat. This is a third sentence." + ) + doc = Document(page_content=small_text, metadata={"source": "custom.txt"}) + + # Use a very small chunk size to force splitting + custom_config = TextSplitterToolConfig(chunk_size=30, chunk_overlap=5) + splitter_tool = TextSplitterTool(config=custom_config) + + input_data = TextSplitterInputSchema(documents=[doc]) + + # Run the tool + output = await splitter_tool._arun(input_data) + + # Assert the output + assert isinstance(output, TextSplitterOutputSchema) + + assert len(output.chunks) > 1 + + first_chunk = output.chunks[0] + second_chunk = output.chunks[1] + + # Assert the index of the start index of the second chunk + overlap_start_index = first_chunk.page_content.find( + second_chunk.page_content[: custom_config.chunk_overlap], + ) + assert overlap_start_index != -1, "Chunks should have overlapping content" From 384c9d35ca624524c57babd967cb221c0314ab0c Mon Sep 17 00:00:00 2001 From: jbrry Date: Mon, 11 Aug 2025 12:50:19 +0100 Subject: [PATCH 08/33] Test the `from_params` method of the class --- tests/tools/semantic_scholar_search_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/tools/semantic_scholar_search_test.py b/tests/tools/semantic_scholar_search_test.py index 0249d0ad..a86eefce 100644 --- a/tests/tools/semantic_scholar_search_test.py +++ b/tests/tools/semantic_scholar_search_test.py @@ -11,6 +11,16 @@ pytest_plugins = ("pytest_asyncio",) +def test_from_params_constructor(): + """ + Tests that the from_params classmethod correctly initializes the tool + and its configuration. 
+ """ + search_tool = SemanticScholarSearchTool.from_params(max_results=5, debug=True) + assert search_tool.config.max_results == 5 + # Test a default value + assert search_tool.config.external_id == "DOI" + @pytest.mark.asyncio async def test_fetch_paper_by_external_id(): # Renamed for clarity @@ -64,6 +74,7 @@ async def test_arun(): async def main(): """Runs all the defined tests.""" + test_from_params_constructor() await test_fetch_paper_by_external_id() await test_arun() From 231e6ffb873b4b81f54f94628dccaaf00e1ea919 Mon Sep 17 00:00:00 2001 From: jbrry Date: Mon, 11 Aug 2025 15:41:51 +0100 Subject: [PATCH 09/33] Add tests for vector db and text splitter tools --- tests/tools/text_splitter_test.py | 2 +- tests/tools/vector_db_tool_test.py | 125 +++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/tools/vector_db_tool_test.py diff --git a/tests/tools/text_splitter_test.py b/tests/tools/text_splitter_test.py index b2f1aa72..3ed158fd 100644 --- a/tests/tools/text_splitter_test.py +++ b/tests/tools/text_splitter_test.py @@ -64,7 +64,7 @@ async def test_text_splitter_with_custom_config(): ) doc = Document(page_content=small_text, metadata={"source": "custom.txt"}) - # Use a very small chunk size to force splitting + # Use a small chunk size to force splitting custom_config = TextSplitterToolConfig(chunk_size=30, chunk_overlap=5) splitter_tool = TextSplitterTool(config=custom_config) diff --git a/tests/tools/vector_db_tool_test.py b/tests/tools/vector_db_tool_test.py new file mode 100644 index 00000000..3be016c8 --- /dev/null +++ b/tests/tools/vector_db_tool_test.py @@ -0,0 +1,125 @@ +import shutil +from pathlib import Path + +import pytest +from langchain_core.documents import Document + +from akd.tools.vector_db_tool import ( + VectorDBInputSchema, + VectorDBOutputSchema, + VectorDBTool, + VectorDBToolConfig, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def temp_db_path(tmp_path: Path) -> str: + """Create a temporary directory for the ChromaDB database.""" + db_path = str(tmp_path / "test_chroma_db") + yield db_path + # Clean up the database directory after the test runs + shutil.rmtree(db_path, ignore_errors=True) + + +@pytest.fixture +def sample_documents() -> list[Document]: + """Provides a list of sample documents for indexing.""" + return [ + Document( + page_content="The sky is blue. It is a beautiful day.", + metadata={"id": "doc1", "source": "weather.txt"}, + ), + Document( + page_content="Apples are a type of fruit. They are often red or green.", + metadata={"id": "doc2", "source": "fruits.txt"}, + ), + Document( + page_content="A computer is an electronic device. Keyboards are used for input.", + metadata={"id": "doc3", "source": "tech.txt"}, + ), + ] + + +@pytest.fixture +def configured_db_tool( + temp_db_path: str, + sample_documents: list[Document], +) -> VectorDBTool: + """ + Provides a fully configured and pre-populated VectorDBTool instance + for testing retrieval. + """ + # Use a unique collection name for each test run to ensure isolation + config = VectorDBToolConfig( + db_path=temp_db_path, + collection_name="test_collection", + ) + db_tool = VectorDBTool(config=config) + + # Index the sample documents into the fresh database + db_tool.index(sample_documents) + + return db_tool + + +def test_vectordb_initialization(temp_db_path: str): + """ + Tests that the VectorDBTool initializes correctly and creates the + database directory and collection. 
+ """ + config = VectorDBToolConfig(db_path=temp_db_path, collection_name="init_test") + db_tool = VectorDBTool(config=config) + + # Assert that the client and collection were created + assert db_tool.client is not None + assert db_tool.collection is not None + assert db_tool.collection.name == "init_test" + + # Assert that the database directory was created on disk + assert Path(temp_db_path).exists() + + +def test_index_method(temp_db_path: str, sample_documents: list[Document]): + """ + Tests that the `index` method correctly adds documents to the collection. + """ + config = VectorDBToolConfig(db_path=temp_db_path, collection_name="index_test") + db_tool = VectorDBTool(config=config) + + # Index the documents + db_tool.index(sample_documents) + + # Verify the documents were added to the collection + assert db_tool.collection.count() == len(sample_documents) + + # Retrieve one document by ID to confirm its content + retrieved = db_tool.collection.get(ids=["doc2"], include=["metadatas", "documents"]) + assert ( + retrieved["documents"][0] + == "Apples are a type of fruit. They are often red or green." + ) + assert retrieved["metadatas"][0]["source"] == "fruits.txt" + + +async def test_arun_retrieval(configured_db_tool: VectorDBTool): + """ + Tests that the `_arun` method correctly retrieves the most relevant + documents for a given query. + """ + + query = "What color is the sky?" + input_params = VectorDBInputSchema(query=query, k=1) + + output = await configured_db_tool._arun(input_params) + + assert isinstance(output, VectorDBOutputSchema) + assert len(output.results) == 1 + + retrieved_doc = output.results[0] + assert isinstance(retrieved_doc, Document) + + # Check that the most relevant document was returned + assert retrieved_doc.metadata["source"] == "weather.txt" + assert "The sky is blue" in retrieved_doc.page_content From c4a57c9554d964ab15207b251b241e1fac982ba0 Mon Sep 17 00:00:00 2001 From: jbrry Date: Mon, 11 Aug 2025 16:32:51 +0100 Subject: [PATCH 10/33] Fix chunking test --- tests/tools/text_splitter_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tools/text_splitter_test.py b/tests/tools/text_splitter_test.py index 3ed158fd..1dcead89 100644 --- a/tests/tools/text_splitter_test.py +++ b/tests/tools/text_splitter_test.py @@ -65,7 +65,7 @@ async def test_text_splitter_with_custom_config(): doc = Document(page_content=small_text, metadata={"source": "custom.txt"}) # Use a small chunk size to force splitting - custom_config = TextSplitterToolConfig(chunk_size=30, chunk_overlap=5) + custom_config = TextSplitterToolConfig(chunk_size=30, chunk_overlap=10) splitter_tool = TextSplitterTool(config=custom_config) input_data = TextSplitterInputSchema(documents=[doc]) @@ -81,8 +81,8 @@ async def test_text_splitter_with_custom_config(): first_chunk = output.chunks[0] second_chunk = output.chunks[1] - # Assert the index of the start index of the second chunk - overlap_start_index = first_chunk.page_content.find( - second_chunk.page_content[: custom_config.chunk_overlap], - ) - assert overlap_start_index != -1, "Chunks should have overlapping content" + # Check the first word 'Rug' is in first chunk + assert ( + second_chunk.page_content[: custom_config.chunk_overlap].split()[0] + in first_chunk.page_content + ), "Chunks should have overlapping content" From d947cb7638ca6b8d3d2f3e6095fa97d709bafeb7 Mon Sep 17 00:00:00 2001 From: jbrry Date: Mon, 11 Aug 2025 16:34:05 +0100 Subject: [PATCH 11/33] Leave config initialisation to super class --- 
akd/tools/text_splitter.py | 7 ++++--- akd/tools/vector_db_tool.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/akd/tools/text_splitter.py b/akd/tools/text_splitter.py index 1fe6f42b..a5d9c021 100644 --- a/akd/tools/text_splitter.py +++ b/akd/tools/text_splitter.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter @@ -55,11 +55,12 @@ class TextSplitterTool( def __init__( self, - config: Optional[TextSplitterToolConfig] = None, + config: TextSplitterToolConfig | None = None, + # config: Optional[TextSplitterToolConfig] = None, debug: bool = False, ): """Initializes the TextSplitterTool.""" - config = config or self.config_schema() + config = config or TextSplitterToolConfig() super().__init__(config, debug) logger.info("Initializing TextSplitterTool...") diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py index 332698d3..e972b9e0 100644 --- a/akd/tools/vector_db_tool.py +++ b/akd/tools/vector_db_tool.py @@ -64,11 +64,11 @@ class VectorDBTool( def __init__( self, - config: Optional[VectorDBToolConfig] = None, + config: VectorDBToolConfig | None = None, debug: bool = False, ): """Initializes the VectorDBTool and its ChromaDB client.""" - config = config or self.config_schema() + config = config or VectorDBToolConfig() super().__init__(config, debug) logger.info("Initializing VectorDBTool...") From 2220443115490aade38c70d51602ce3c5fa1b221 Mon Sep 17 00:00:00 2001 From: jbrry Date: Tue, 12 Aug 2025 12:31:47 +0100 Subject: [PATCH 12/33] Add FactReasoner fact-check to deep research pipeline --- akd/tools/fact_check.py | 109 ++++++++++++++++ scripts/demo_deep_search_with_fact_check.py | 138 ++++++++++++++++++++ 2 files changed, 247 insertions(+) create mode 100644 akd/tools/fact_check.py create mode 100644 scripts/demo_deep_search_with_fact_check.py diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py new file mode 100644 index 00000000..de0048b3 --- /dev/null +++ b/akd/tools/fact_check.py @@ -0,0 +1,109 @@ +from typing import Any, Dict, List, Optional + +import httpx +from loguru import logger +from pydantic import Field, HttpUrl + +from akd._base import InputSchema, OutputSchema +from akd.tools._base import BaseTool, BaseToolConfig + + +class FactCheckInputSchema(InputSchema): + """Input schema for the Fact-Checking Tool.""" + + question: str = Field(..., description="The original question that was asked.") + answer: str = Field(..., description="The LLM answer to be fact-checked.") + + +class FactCheckOutputSchema(OutputSchema): + """Output schema for the Fact-Checking Tool's results.""" + + fact_reasoner_score: Dict[str, Any] = Field( + ..., + description="The full scoring dictionary from the FactReasoner.", + ) + supported_atoms: List[Dict[str, Any]] = Field( + ..., + description="List of atoms determined to be supported.", + ) + not_supported_atoms: List[Dict[str, Any]] = Field( + ..., + description="List of atoms determined to be not supported.", + ) + contexts: List[Dict[str, Any]] = Field( + ..., + description="List of retrieved contexts used for the check.", + ) + graph_id: Optional[str] = Field( + None, + description="The unique ID for the generated fact graph.", + ) + + +class FactCheckToolConfig(BaseToolConfig): + """Configuration for the FactCheckTool.""" + + base_url: HttpUrl = Field( + # default="http://localhost:8011", + 
default="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", + description="The base URL of the remote Fact-Checking and Correction Service.", + ) + + +class FactCheckTool( + BaseTool[FactCheckInputSchema, FactCheckOutputSchema], +): + """ + A tool that calls an API to perform fact-checking on a given answer. + """ + + name = "fact_check_tool" + description = ( + "Calls an API to run the FactReasoner pipeline on a question and answer." + ) + input_schema = FactCheckInputSchema + output_schema = FactCheckOutputSchema + config_schema = FactCheckToolConfig + + def __init__( + self, + config: FactCheckToolConfig | None = None, + debug: bool = False, + ): + """Initializes the FactCheckTool and its HTTP client.""" + config = config or FactCheckToolConfig() + super().__init__(config, debug) + + logger.info("Initializing FactCheckTool...") + self.api_client = httpx.AsyncClient(base_url=str(self.config.base_url)) + + async def _arun( + self, + params: FactCheckInputSchema, + ) -> FactCheckOutputSchema: + """ + Calls the /fact-check/ endpoint on the remote service. + """ + logger.info( + f"Sending fact-check request for question: '{params.question[:50]}...'", + ) + + try: + response = await self.api_client.post( + "/fact-check/", + json=params.model_dump(), + timeout=1500.0, # 25 mins timeout for potentially slow API calls + ) + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + + results = response.json() + return FactCheckOutputSchema(**results) + + except httpx.HTTPStatusError as e: + logger.error( + f"HTTP error occurred while calling fact-check API: {e.response.status_code} - {e.response.text}", + ) + raise + except Exception as e: + logger.error(f"An unexpected error occurred: {e}") + raise diff --git a/scripts/demo_deep_search_with_fact_check.py b/scripts/demo_deep_search_with_fact_check.py new file mode 100644 index 00000000..f3fd184f --- /dev/null +++ b/scripts/demo_deep_search_with_fact_check.py @@ -0,0 +1,138 @@ +""" +Demo script for running a full research and fact-checking workflow. + +This script demonstrates a complete pipeline: +1. Run the DeepLitSearchAgent to find relevant literature and generate a report. +2. Use the TextSplitterTool to chunk the retrieved source documents. +3. Use the VectorDBTool to index the chunks into a persistent ChromaDB. + This step is so that we can also use the same documents to verify the answer. +4. Use the FactCheckTool to verify the generated report. +""" + +import asyncio +import sys + +from langchain_core.documents import Document + +from akd.agents.search import ( + DeepLitSearchAgent, + DeepLitSearchAgentConfig, + LitSearchAgentInputSchema, +) +from akd.configs.project import get_project_settings +from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool +from akd.tools.text_splitter import ( + TextSplitterInputSchema, + TextSplitterTool, + TextSplitterToolConfig, +) +from akd.tools.vector_db_tool import VectorDBTool + + +async def main(): + # Check for API keys + settings = get_project_settings() + if not settings.model_config_settings.api_keys.openai: + print( + "No OpenAI API key found. 
Please set OPENAI_API_KEY environment variable.", + ) + return + + # Configure and run the agent + agent_config = DeepLitSearchAgentConfig( + max_research_iterations=1, + use_semantic_scholar=False, # avoid rate limits + enable_full_content_scraping=True, + debug=True, + ) + agent = DeepLitSearchAgent(config=agent_config) + + research_query = "recent advances in transformer architectures" + input_params = LitSearchAgentInputSchema(query=research_query, max_results=3) + + print(f"--- Starting research for: '{research_query}' ---") + research_output = await agent._arun(input_params) + + report_results = [ + res + for res in research_output.results + if res.get("url") == "deep-research://report" + ] + source_results = [ + res + for res in research_output.results + if res.get("url") != "deep-research://report" + ] + + if not source_results: + print("No source documents were found. Exiting.") + return + + if not report_results: + print("No research report was generated. Exiting.") + return + + research_report = report_results[0] + + print(f"\n--- Found {len(source_results)} source documents and research report ---") + + # Convert search results to Langchain Document objects for processing + documents_to_process = [ + Document( + page_content=res["content"], + metadata={"source": res["url"], "title": res["title"]}, + ) + for res in source_results + if res.get("content") + ] + + splitter_config = TextSplitterToolConfig(chunk_size=2000) + splitter_tool = TextSplitterTool(config=splitter_config, debug=True) + splitter_input = TextSplitterInputSchema(documents=documents_to_process) + + print( + f"\n--- Splitting {len(documents_to_process)} documents into smaller chunks ---", + ) + split_output = await splitter_tool._arun(splitter_input) + chunks = split_output.chunks + + print(f"Created {len(chunks)} chunks.") + + vector_db_tool = VectorDBTool() + + print(f"\n--- Indexing {len(chunks)} chunks into ChromaDB ---") + + vector_db_tool.index(chunks) + + print(f"Indexing complete! 
Database is located at: {vector_db_tool.config.db_path}") + + print("\n--- Fact-checking the generated research report ---") + fact_check_tool = FactCheckTool() + + fact_check_input = FactCheckInputSchema( + question=research_report.get("query", research_query), + answer=research_report.get("content", ""), + ) + + fact_check_result = await fact_check_tool.arun(params=fact_check_input) + + print("\n--- Fact-Check Complete ---") + score = fact_check_result.fact_reasoner_score.get("factuality_score", 0) + num_supported = len(fact_check_result.supported_atoms) + num_not_supported = len(fact_check_result.not_supported_atoms) + + print(f"Factuality Score: {score:.2%}") + print(f"Supported Atoms: {num_supported}") + print(f"Not Supported Atoms: {num_not_supported}") + print(f"Graph ID: {fact_check_result.graph_id}") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except Exception as e: + print(f"\nDemo failed with an unexpected error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) From 7a2c13d5f5b6c8e226c73681603c1276a5b217c6 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 14 Aug 2025 14:40:44 +0100 Subject: [PATCH 13/33] Add test for FactCheck tool --- tests/tools/test_fact_check.py | 55 ++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 tests/tools/test_fact_check.py diff --git a/tests/tools/test_fact_check.py b/tests/tools/test_fact_check.py new file mode 100644 index 00000000..5b4a528f --- /dev/null +++ b/tests/tools/test_fact_check.py @@ -0,0 +1,55 @@ +from unittest.mock import AsyncMock + +import pytest + +from akd.tools.fact_check import ( + FactCheckInputSchema, + FactCheckOutputSchema, + FactCheckTool, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def mock_api_response() -> dict: + """Provides a mock, successful JSON response from the fact-check API.""" + return { + "fact_reasoner_score": { + "factuality_score_per_atom": [{"a0": {"score": 0.987795, "support": "S"}}], + "factuality_score": 1.0, + }, + "supported_atoms": [{"id": "a0", "text": "The sky is blue"}], + "not_supported_atoms": [], + "contexts": [{"id": "c_a0_0", "title": "Why Is the Sky Blue?"}], + "graph_id": "mock-graph-id-123", + } + + +async def test_fact_check_tool_with_mock(mocker, mock_api_response): + """ """ + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_api_response + + mocker.patch( + new_callable=AsyncMock, + return_value=mock_response, + ) + + tool = FactCheckTool() + input_data = FactCheckInputSchema( + question="What colour is the sky?", + answer="The sky is blue.", + ) + + # This will call the mocked version of the tool. 
+ output = await tool.arun(input_data) + + assert isinstance(output, FactCheckOutputSchema) + + assert output.fact_reasoner_score["factuality_score"] == 1.0 + assert len(output.fact_reasoner_score["factuality_score_per_atom"]) == 1 + assert output.graph_id == "mock-graph-id-123" + assert len(output.supported_atoms) == 1 From 595cb3884d451d765c9a6a8604a6905f2d6fee74 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 14 Aug 2025 15:45:58 +0100 Subject: [PATCH 14/33] Updated test file --- tests/tools/test_fact_check.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/tools/test_fact_check.py b/tests/tools/test_fact_check.py index 5b4a528f..75242f80 100644 --- a/tests/tools/test_fact_check.py +++ b/tests/tools/test_fact_check.py @@ -1,11 +1,13 @@ from unittest.mock import AsyncMock import pytest +from pydantic import HttpUrl from akd.tools.fact_check import ( FactCheckInputSchema, FactCheckOutputSchema, FactCheckTool, + FactCheckToolConfig, ) pytestmark = pytest.mark.asyncio @@ -16,7 +18,7 @@ def mock_api_response() -> dict: """Provides a mock, successful JSON response from the fact-check API.""" return { "fact_reasoner_score": { - "factuality_score_per_atom": [{"a0": {"score": 0.987795, "support": "S"}}], + "factuality_score_per_atom": [{"a0": {"score": 0.9877, "support": "S"}}], "factuality_score": 1.0, }, "supported_atoms": [{"id": "a0", "text": "The sky is blue"}], @@ -26,14 +28,22 @@ def mock_api_response() -> dict: } +@pytest.fixture +def fact_check_tool(): + config = FactCheckToolConfig(base_url="http://localhost:8011") + tool = FactCheckTool(config=config) + return tool + + async def test_fact_check_tool_with_mock(mocker, mock_api_response): - """ """ + """Tests the FactCheck tool's post method. The post method is patched with a mock request.""" mock_response = mocker.Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_api_response mocker.patch( + "akd.tools.fact_check.httpx.AsyncClient.post", new_callable=AsyncMock, return_value=mock_response, ) @@ -53,3 +63,7 @@ async def test_fact_check_tool_with_mock(mocker, mock_api_response): assert len(output.fact_reasoner_score["factuality_score_per_atom"]) == 1 assert output.graph_id == "mock-graph-id-123" assert len(output.supported_atoms) == 1 + + +def test_fact_check_tool_initialization(fact_check_tool): + assert fact_check_tool.base_url == HttpUrl("http://localhost:8011") From 8d9ed477c12288b0b22f21a3549c4bb7429eb073 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 14 Aug 2025 15:49:36 +0100 Subject: [PATCH 15/33] Wrap string in HttpUrl --- akd/tools/fact_check.py | 4 +++- tests/tools/test_fact_check.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index de0048b3..324fed98 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -45,7 +45,9 @@ class FactCheckToolConfig(BaseToolConfig): base_url: HttpUrl = Field( # default="http://localhost:8011", - default="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", + default=HttpUrl( + "https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", + ), description="The base URL of the remote Fact-Checking and Correction Service.", ) diff --git a/tests/tools/test_fact_check.py b/tests/tools/test_fact_check.py index 75242f80..ec5dd0b1 100644 --- a/tests/tools/test_fact_check.py +++ b/tests/tools/test_fact_check.py @@ -10,8 +10,6 @@ FactCheckToolConfig, ) -pytestmark = pytest.mark.asyncio - 
@pytest.fixture def mock_api_response() -> dict: @@ -35,6 +33,7 @@ def fact_check_tool(): return tool +@pytest.mark.asyncio async def test_fact_check_tool_with_mock(mocker, mock_api_response): """Tests the FactCheck tool's post method. The post method is patched with a mock request.""" From fce0cfbd1b595c8db017706723ec0b0328e9c8db Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 20 Aug 2025 15:51:38 +0100 Subject: [PATCH 16/33] Update demo script with fact check --- scripts/demo_deep_search_with_fact_check.py | 35 +++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/scripts/demo_deep_search_with_fact_check.py b/scripts/demo_deep_search_with_fact_check.py index f3fd184f..6a48ef3f 100644 --- a/scripts/demo_deep_search_with_fact_check.py +++ b/scripts/demo_deep_search_with_fact_check.py @@ -12,6 +12,8 @@ import asyncio import sys +import markdown +from bs4 import BeautifulSoup from langchain_core.documents import Document from akd.agents.search import ( @@ -29,6 +31,29 @@ from akd.tools.vector_db_tool import VectorDBTool +def process_report(report): + """Processes a report (in markdown) to remove headers and reference lists.""" + html = markdown.markdown(report) + soup = BeautifulSoup(html, "html.parser") + + # Ignore header values + content_tags = soup.find_all(["p", "li"]) + + prose_fragments = [] + for tag in content_tags: + is_reference_link = ( + tag.name == "li" and tag.find("a") and len(tag.contents) == 1 + ) + + if is_reference_link: + continue + else: + prose_fragments.append(tag.get_text(strip=True)) + + cleaned_markdown = "\n".join(prose_fragments) + return cleaned_markdown + + async def main(): # Check for API keys settings = get_project_settings() @@ -47,7 +72,7 @@ async def main(): ) agent = DeepLitSearchAgent(config=agent_config) - research_query = "recent advances in transformer architectures" + research_query = "What evidence is there for water on Mars?" input_params = LitSearchAgentInputSchema(query=research_query, max_results=3) print(f"--- Starting research for: '{research_query}' ---") @@ -109,11 +134,17 @@ async def main(): print("\n--- Fact-checking the generated research report ---") fact_check_tool = FactCheckTool() + # Clean markdown formatting from report + report = research_report.get("content", "") + + plaintext_report = process_report(report) + fact_check_input = FactCheckInputSchema( question=research_report.get("query", research_query), - answer=research_report.get("content", ""), + answer=plaintext_report, ) + print(plaintext_report) fact_check_result = await fact_check_tool.arun(params=fact_check_input) print("\n--- Fact-Check Complete ---") From 6c50cb7042a6a9346e6eca6e18e180d2e52c1b5b Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 16:26:58 +0100 Subject: [PATCH 17/33] Enable polling for long running processes --- akd/tools/fact_check.py | 75 ++++++++++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index 324fed98..bc871bea 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -1,3 +1,5 @@ +import asyncio +import json from typing import Any, Dict, List, Optional import httpx @@ -19,24 +21,22 @@ class FactCheckOutputSchema(OutputSchema): """Output schema for the Fact-Checking Tool's results.""" fact_reasoner_score: Dict[str, Any] = Field( - ..., - description="The full scoring dictionary from the FactReasoner.", + ..., description="The full scoring dictionary from the FactReasoner." 
) supported_atoms: List[Dict[str, Any]] = Field( - ..., - description="List of atoms determined to be supported.", + ..., description="List of atoms determined to be supported." ) not_supported_atoms: List[Dict[str, Any]] = Field( - ..., - description="List of atoms determined to be not supported.", + ..., description="List of atoms determined to be not supported." ) contexts: List[Dict[str, Any]] = Field( - ..., - description="List of retrieved contexts used for the check.", + ..., description="List of retrieved contexts used for the check." ) graph_id: Optional[str] = Field( - None, - description="The unique ID for the generated fact graph.", + None, description="The unique ID for the generated fact graph." + ) + logging_metadata: Dict[str, Any] = Field( + {}, description="Additional logging metadata from the run." ) @@ -45,11 +45,15 @@ class FactCheckToolConfig(BaseToolConfig): base_url: HttpUrl = Field( # default="http://localhost:8011", - default=HttpUrl( - "https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", - ), + default="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", description="The base URL of the remote Fact-Checking and Correction Service.", ) + polling_interval_seconds: int = Field( + default=120, description="How often to poll for job results." + ) + job_timeout_seconds: int = Field( + default=1800, description="Maximum time to wait for a job to complete (30 minutes)." + ) class FactCheckTool( @@ -79,27 +83,52 @@ def __init__( logger.info("Initializing FactCheckTool...") self.api_client = httpx.AsyncClient(base_url=str(self.config.base_url)) + async def _arun( self, params: FactCheckInputSchema, ) -> FactCheckOutputSchema: """ - Calls the /fact-check/ endpoint on the remote service. + Starts a fact-checking job and polls for its completion. """ logger.info( f"Sending fact-check request for question: '{params.question[:50]}...'", ) try: - response = await self.api_client.post( - "/fact-check/", - json=params.model_dump(), - timeout=1500.0, # 25 mins timeout for potentially slow API calls + # Start the job + start_response = await self.api_client.post( + "/fact-check/start", json=params.model_dump(), timeout=60.0 + ) + start_response.raise_for_status() + job_id = start_response.json()["job_id"] + logger.info(f"Successfully started job with ID: {job_id}") + + # Poll for the result + total_wait_time = 0 + while total_wait_time < self.config.job_timeout_seconds: + logger.info(f"Polling status for job {job_id}...") + status_response = await self.api_client.get( + f"/fact-check/status/{job_id}", timeout=60.0 + ) + status_response.raise_for_status() + status_data = status_response.json() + + if status_data["status"] == "completed": + logger.info(f"Job {job_id} completed successfully.") + return FactCheckOutputSchema(**status_data["result"]) + elif status_data["status"] == "failed": + raise Exception( + f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}" + ) + elif status_data["status"] == "pending": + logger.info(f"Job {job_id} is in progress... (waited {total_wait_time}s)") + await asyncio.sleep(self.config.polling_interval_seconds) + total_wait_time += self.config.polling_interval_seconds + + raise asyncio.TimeoutError( + f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout." 
) - response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) - - results = response.json() - return FactCheckOutputSchema(**results) except httpx.HTTPStatusError as e: logger.error( @@ -108,4 +137,4 @@ async def _arun( raise except Exception as e: logger.error(f"An unexpected error occurred: {e}") - raise + raise \ No newline at end of file From 84cecff53582c8d322e52216ee8da97980b59920 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 22:30:46 +0100 Subject: [PATCH 18/33] Do not use text splitter tool --- scripts/demo_deep_search_with_fact_check.py | 28 +++--- ...est.py => test_semantic_scholar_search.py} | 0 ...db_tool_test.py => test_vector_db_tool.py} | 0 tests/tools/text_splitter_test.py | 88 ------------------- 4 files changed, 11 insertions(+), 105 deletions(-) rename tests/tools/{semantic_scholar_search_test.py => test_semantic_scholar_search.py} (100%) rename tests/tools/{vector_db_tool_test.py => test_vector_db_tool.py} (100%) delete mode 100644 tests/tools/text_splitter_test.py diff --git a/scripts/demo_deep_search_with_fact_check.py b/scripts/demo_deep_search_with_fact_check.py index 6a48ef3f..52b51f2f 100644 --- a/scripts/demo_deep_search_with_fact_check.py +++ b/scripts/demo_deep_search_with_fact_check.py @@ -3,9 +3,8 @@ This script demonstrates a complete pipeline: 1. Run the DeepLitSearchAgent to find relevant literature and generate a report. -2. Use the TextSplitterTool to chunk the retrieved source documents. +2. Convert the search results into text chunks using a Langchain text splitter. 3. Use the VectorDBTool to index the chunks into a persistent ChromaDB. - This step is so that we can also use the same documents to verify the answer. 4. Use the FactCheckTool to verify the generated report. """ @@ -15,6 +14,7 @@ import markdown from bs4 import BeautifulSoup from langchain_core.documents import Document +from langchain_text_splitters import RecursiveCharacterTextSplitter from akd.agents.search import ( DeepLitSearchAgent, @@ -23,11 +23,6 @@ ) from akd.configs.project import get_project_settings from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool -from akd.tools.text_splitter import ( - TextSplitterInputSchema, - TextSplitterTool, - TextSplitterToolConfig, -) from akd.tools.vector_db_tool import VectorDBTool @@ -111,16 +106,14 @@ async def main(): if res.get("content") ] - splitter_config = TextSplitterToolConfig(chunk_size=2000) - splitter_tool = TextSplitterTool(config=splitter_config, debug=True) - splitter_input = TextSplitterInputSchema(documents=documents_to_process) - print( - f"\n--- Splitting {len(documents_to_process)} documents into smaller chunks ---", + f"\n--- Splitting {len(documents_to_process)} documents into smaller chunks ---" ) - split_output = await splitter_tool._arun(splitter_input) - chunks = split_output.chunks - + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=2000, + chunk_overlap=200, + ) + chunks = text_splitter.split_documents(documents_to_process) print(f"Created {len(chunks)} chunks.") vector_db_tool = VectorDBTool() @@ -129,7 +122,7 @@ async def main(): vector_db_tool.index(chunks) - print(f"Indexing complete! Database is located at: {vector_db_tool.config.db_path}") + print(f"Indexing complete. 
Database is located at: {vector_db_tool.config.db_path}") print("\n--- Fact-checking the generated research report ---") fact_check_tool = FactCheckTool() @@ -144,7 +137,8 @@ async def main(): answer=plaintext_report, ) - print(plaintext_report) + print(f"\n--- Running Fact-Check on report ---\n{plaintext_report}\n") + fact_check_result = await fact_check_tool.arun(params=fact_check_input) print("\n--- Fact-Check Complete ---") diff --git a/tests/tools/semantic_scholar_search_test.py b/tests/tools/test_semantic_scholar_search.py similarity index 100% rename from tests/tools/semantic_scholar_search_test.py rename to tests/tools/test_semantic_scholar_search.py diff --git a/tests/tools/vector_db_tool_test.py b/tests/tools/test_vector_db_tool.py similarity index 100% rename from tests/tools/vector_db_tool_test.py rename to tests/tools/test_vector_db_tool.py diff --git a/tests/tools/text_splitter_test.py b/tests/tools/text_splitter_test.py deleted file mode 100644 index 1dcead89..00000000 --- a/tests/tools/text_splitter_test.py +++ /dev/null @@ -1,88 +0,0 @@ -import pytest -from langchain_core.documents import Document - -from akd.tools.text_splitter import ( - TextSplitterInputSchema, - TextSplitterOutputSchema, - TextSplitterTool, - TextSplitterToolConfig, -) - -pytestmark = pytest.mark.asyncio - - -@pytest.fixture -def sample_documents(): - """Provides sample documents for testing.""" - long_text = " ".join(["This is sentence " + str(i) + "." for i in range(200)]) - - return [ - Document( - page_content=long_text, - metadata={"source": "doc1.txt", "url": "www.example1.com"}, - ), - Document( - page_content="This is a short document that should not be split.", - metadata={"source": "doc2.txt", "url": "www.example2.com"}, - ), - ] - - -async def test_text_splitter_with_default_config(sample_documents): - """ - Tests the TextSplitterTool with its default configuration to ensure - it splits long documents and assigns new IDs. - """ - - splitter_tool = TextSplitterTool() - input_data = TextSplitterInputSchema(documents=sample_documents) - - # Run the tool - output = await splitter_tool._arun(input_data) - - assert isinstance(output, TextSplitterOutputSchema) - - # The long document should be split, and the short one should remain as one chunk. - assert len(output.chunks) > len(sample_documents) - - # Check the properties of each chunk - for chunk in output.chunks: - assert isinstance(chunk, Document) - assert len(chunk.page_content) <= splitter_tool.config.chunk_size - # Verify that a new, unique chunk ID has been added to the metadata - assert "id" in chunk.metadata - assert chunk.metadata["id"].startswith(chunk.metadata["source"]) - - -async def test_text_splitter_with_custom_config(): - """ - Tests the TextSplitterTool with a custom configuration (smaller chunk size) - to verify it produces more chunks. - """ - small_text = ( - "The cat sat on the mat. Rug is another word for mat. This is a third sentence." 
- ) - doc = Document(page_content=small_text, metadata={"source": "custom.txt"}) - - # Use a small chunk size to force splitting - custom_config = TextSplitterToolConfig(chunk_size=30, chunk_overlap=10) - splitter_tool = TextSplitterTool(config=custom_config) - - input_data = TextSplitterInputSchema(documents=[doc]) - - # Run the tool - output = await splitter_tool._arun(input_data) - - # Assert the output - assert isinstance(output, TextSplitterOutputSchema) - - assert len(output.chunks) > 1 - - first_chunk = output.chunks[0] - second_chunk = output.chunks[1] - - # Check the first word 'Rug' is in first chunk - assert ( - second_chunk.page_content[: custom_config.chunk_overlap].split()[0] - in first_chunk.page_content - ), "Chunks should have overlapping content" From ae2f2a12c1d07db8f1524550ced6c9f4d05857ca Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 22:31:20 +0100 Subject: [PATCH 19/33] Remove text splitter tool --- akd/tools/text_splitter.py | 94 -------------------------------------- 1 file changed, 94 deletions(-) delete mode 100644 akd/tools/text_splitter.py diff --git a/akd/tools/text_splitter.py b/akd/tools/text_splitter.py deleted file mode 100644 index a5d9c021..00000000 --- a/akd/tools/text_splitter.py +++ /dev/null @@ -1,94 +0,0 @@ -from typing import List - -from langchain_core.documents import Document -from langchain_text_splitters import RecursiveCharacterTextSplitter -from loguru import logger -from pydantic import Field - -from akd._base import InputSchema, OutputSchema -from akd.tools._base import BaseTool, BaseToolConfig - - -class TextSplitterInputSchema(InputSchema): - """Input schema for the Text Splitter Tool.""" - - documents: List[Document] = Field( - ..., - description="A list of Langchain Document objects to split.", - ) - - -class TextSplitterOutputSchema(OutputSchema): - """Output schema for the Text Splitter Tool.""" - - chunks: List[Document] = Field( - ..., - description="A list of smaller Langchain Document objects (chunks).", - ) - - -class TextSplitterToolConfig(BaseToolConfig): - """Configuration for the TextSplitterTool.""" - - chunk_size: int = Field( - default=1000, - description="The maximum size of each text chunk.", - ) - chunk_overlap: int = Field( - default=100, - description="The number of characters to overlap between chunks.", - ) - - -class TextSplitterTool( - BaseTool[TextSplitterInputSchema, TextSplitterOutputSchema], -): - """ - A tool for splitting large documents into smaller, more manageable chunks. - """ - - name = "text_splitter_tool" - description = "Splits a list of documents into smaller text chunks." - input_schema = TextSplitterInputSchema - output_schema = TextSplitterOutputSchema - config_schema = TextSplitterToolConfig - - def __init__( - self, - config: TextSplitterToolConfig | None = None, - # config: Optional[TextSplitterToolConfig] = None, - debug: bool = False, - ): - """Initializes the TextSplitterTool.""" - config = config or TextSplitterToolConfig() - super().__init__(config, debug) - - logger.info("Initializing TextSplitterTool...") - self._splitter = RecursiveCharacterTextSplitter( - chunk_size=self.config.chunk_size, - chunk_overlap=self.config.chunk_overlap, - ) - - async def _arun( - self, - params: TextSplitterInputSchema, - **kwargs, - ) -> TextSplitterOutputSchema: - """ - Splits the provided documents into smaller chunks. 
- """ - logger.info(f"Splitting {len(params.documents)} document(s)...") - all_chunks = [] - for doc in params.documents: - chunks = self._splitter.split_documents([doc]) - # Add unique IDs to each chunk's metadata - for i, chunk in enumerate(chunks): - source_id = chunk.metadata.get( - "id", - chunk.metadata.get("source", "unknown"), - ) - chunk.metadata["id"] = f"{source_id}_{i}" - all_chunks.extend(chunks) - - logger.info(f"Created {len(all_chunks)} chunks.") - return TextSplitterOutputSchema(chunks=all_chunks) From 024005cc53913a2b3a65b22e2f6a147f85b79694 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 22:52:41 +0100 Subject: [PATCH 20/33] Get URL paramater from .env file --- .env.example | 3 ++- akd/tools/fact_check.py | 52 ++++++++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/.env.example b/.env.example index 622e0162..ebc7e68f 100644 --- a/.env.example +++ b/.env.example @@ -31,6 +31,7 @@ GOOGLE_API_KEY="" GOOGLE_CSE_ID="" SEARXNG_BASE_URL="" #"http://localhost:8080" +FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud" # for your vector store EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2" @@ -45,4 +46,4 @@ MAX_NUM_TURNS = 5 RETRY_ATTEMPTS = 3 # number of editors -NUM_EDITORS = 3 \ No newline at end of file +NUM_EDITORS = 3 diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index bc871bea..e284ab41 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -1,5 +1,5 @@ import asyncio -import json +import os from typing import Any, Dict, List, Optional import httpx @@ -21,22 +21,28 @@ class FactCheckOutputSchema(OutputSchema): """Output schema for the Fact-Checking Tool's results.""" fact_reasoner_score: Dict[str, Any] = Field( - ..., description="The full scoring dictionary from the FactReasoner." + ..., + description="The full scoring dictionary from the FactReasoner.", ) supported_atoms: List[Dict[str, Any]] = Field( - ..., description="List of atoms determined to be supported." + ..., + description="List of atoms determined to be supported.", ) not_supported_atoms: List[Dict[str, Any]] = Field( - ..., description="List of atoms determined to be not supported." + ..., + description="List of atoms determined to be not supported.", ) contexts: List[Dict[str, Any]] = Field( - ..., description="List of retrieved contexts used for the check." + ..., + description="List of retrieved contexts used for the check.", ) graph_id: Optional[str] = Field( - None, description="The unique ID for the generated fact graph." + None, + description="The unique ID for the generated fact graph.", ) logging_metadata: Dict[str, Any] = Field( - {}, description="Additional logging metadata from the run." + {}, + description="Additional logging metadata from the run.", ) @@ -44,15 +50,19 @@ class FactCheckToolConfig(BaseToolConfig): """Configuration for the FactCheckTool.""" base_url: HttpUrl = Field( - # default="http://localhost:8011", - default="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud", + default=os.getenv( + "FACT_CHECK_API_URL", + default="http://localhost:8011", + ), description="The base URL of the remote Fact-Checking and Correction Service.", ) polling_interval_seconds: int = Field( - default=120, description="How often to poll for job results." 
+ default=120, + description="How often to poll for job results.", ) job_timeout_seconds: int = Field( - default=1800, description="Maximum time to wait for a job to complete (30 minutes)." + default=1800, + description="Maximum time to wait for a job to complete (30 minutes).", ) @@ -83,7 +93,6 @@ def __init__( logger.info("Initializing FactCheckTool...") self.api_client = httpx.AsyncClient(base_url=str(self.config.base_url)) - async def _arun( self, params: FactCheckInputSchema, @@ -98,7 +107,9 @@ async def _arun( try: # Start the job start_response = await self.api_client.post( - "/fact-check/start", json=params.model_dump(), timeout=60.0 + "/fact-check/start", + json=params.model_dump(), + timeout=60.0, ) start_response.raise_for_status() job_id = start_response.json()["job_id"] @@ -109,7 +120,8 @@ async def _arun( while total_wait_time < self.config.job_timeout_seconds: logger.info(f"Polling status for job {job_id}...") status_response = await self.api_client.get( - f"/fact-check/status/{job_id}", timeout=60.0 + f"/fact-check/status/{job_id}", + timeout=60.0, ) status_response.raise_for_status() status_data = status_response.json() @@ -119,15 +131,17 @@ async def _arun( return FactCheckOutputSchema(**status_data["result"]) elif status_data["status"] == "failed": raise Exception( - f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}" + f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}", ) elif status_data["status"] == "pending": - logger.info(f"Job {job_id} is in progress... (waited {total_wait_time}s)") + logger.info( + f"Job {job_id} is in progress... (waited {total_wait_time}s)", + ) await asyncio.sleep(self.config.polling_interval_seconds) total_wait_time += self.config.polling_interval_seconds - + raise asyncio.TimeoutError( - f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout." 
+ f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout.", ) except httpx.HTTPStatusError as e: @@ -137,4 +151,4 @@ async def _arun( raise except Exception as e: logger.error(f"An unexpected error occurred: {e}") - raise \ No newline at end of file + raise From 4efd005af7d81ca8403ae8570812d5fbc4a89bdc Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 23:06:41 +0100 Subject: [PATCH 21/33] Take parameters from env file --- .env.example | 4 ++++ akd/tools/vector_db_tool.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index ebc7e68f..e83be831 100644 --- a/.env.example +++ b/.env.example @@ -36,6 +36,10 @@ FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codee # for your vector store EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2" +VECTOR_DB_PATH="./chroma_db" +# API key for various embedding functions (chroma) +EMBEDDING_MODEL_API_KEY="" + # number of wiki results to return TOP_N_WIKI_RESULTS = 1 diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py index e972b9e0..5276376e 100644 --- a/akd/tools/vector_db_tool.py +++ b/akd/tools/vector_db_tool.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional import chromadb @@ -34,11 +35,11 @@ class VectorDBToolConfig(BaseToolConfig): description="The name of the Hugging Face embedding model to use.", ) embedding_model_api_key: Optional[str] = Field( - default=None, - description="The API key for the embedding model provider, currently using HuggingFace.", + default=os.getenv("EMBEDDING_MODEL_API_KEY", None), + description="The API key for the embedding model provider, if required.", ) db_path: str = Field( - default="./chroma_db", + default=os.getenv("VECTOR_DB_PATH", "./chroma_db"), description="Path to the persistent ChromaDB directory.", ) collection_name: str = Field( From 5b39e840ecfce1d0f1f9c34a8d3ce15a04b66224 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 21 Aug 2025 23:15:11 +0100 Subject: [PATCH 22/33] Make the URL an `HttpUrl` --- akd/tools/fact_check.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index e284ab41..38893b16 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -50,10 +50,7 @@ class FactCheckToolConfig(BaseToolConfig): """Configuration for the FactCheckTool.""" base_url: HttpUrl = Field( - default=os.getenv( - "FACT_CHECK_API_URL", - default="http://localhost:8011", - ), + default=HttpUrl(os.getenv("FACT_CHECK_API_URL", "http://localhost:8011")), description="The base URL of the remote Fact-Checking and Correction Service.", ) polling_interval_seconds: int = Field( From e3f79c90a6c8ffed4f3028e24f868739cc197af6 Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 14:03:53 +0100 Subject: [PATCH 23/33] Add dependencies to pyproject.toml --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 17503a2a..33e56885 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,9 @@ dependencies = [ "pytest-cov>=6.2.1", "sentence-transformers>=5.0.0", "ollama>=0.5.1", - "tiktoken>=0.9.0" + "tiktoken>=0.9.0", + "markdown>=3.8.2", + "chromadb>=1.0.20", ] [project.urls] From e2988d0a6cf3c8e44d3c3f28ac95c09713b3e5c9 Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 15:15:01 +0100 Subject: [PATCH 24/33] Use updated pyproject.toml file --- pyproject.toml | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 
insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 33e56885..d023cdcc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,15 +41,16 @@ dependencies = [ "wikipedia>=1.4.0", "docling>=2.37.0", "langchain-openai>=0.3.27", - "pytest-asyncio>=1.0.0", "tenacity>=9.1.2", "httpx>=0.28.1", - "pytest-cov>=6.2.1", "sentence-transformers>=5.0.0", "ollama>=0.5.1", "tiktoken>=0.9.0", + "rapidfuzz>=3.13.0", + "deepeval>=3.4.0", "markdown>=3.8.2", - "chromadb>=1.0.20", + "chromadb>=1.0.13", + "pytest-mock>=3.14.1", ] [project.urls] @@ -66,11 +67,33 @@ build-backend = "setuptools.build_meta" include = ["akd*"] exclude = ["tests*", "examples*", "scripts*", "notebooks*"] -[tool.poetry.group.dev.dependencies] -ipykernel = "^6.30.0" +# Pytest configuration +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] +python_files = ["test_*.py", "*_test.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] +addopts = [ + "--strict-markers", + "--strict-config", + "-ra", + "--cov=akd", + "--cov-report=term-missing", + "--cov-report=xml", +] +markers = [ + "unit: Unit tests", + "integration: Integration tests", + "slow: Slow tests", +] -[dependency-groups] +# Optional dependencies (extras) +[project.optional-dependencies] dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=1.0.0", + "pytest-cov>=6.0.0", "ipykernel>=6.30.0", "ipywidgets>=8.1.7", "pre-commit>=4.2.0", From 32b7dac7d37458f78ffa793787bd479127e1e8dc Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 17:12:30 +0100 Subject: [PATCH 25/33] Update test for fact-check to correspond to new post and get approach --- tests/tools/test_fact_check.py | 63 ++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 15 deletions(-) diff --git a/tests/tools/test_fact_check.py b/tests/tools/test_fact_check.py index ec5dd0b1..bdee4a33 100644 --- a/tests/tools/test_fact_check.py +++ b/tests/tools/test_fact_check.py @@ -12,7 +12,13 @@ @pytest.fixture -def mock_api_response() -> dict: +def mock_api_start_response() -> dict: + """Provides a mock JSON response for when the tool is first started.""" + return {"job_id": "mock-job-1234"} + + +@pytest.fixture +def mock_api_final_response() -> dict: """Provides a mock, successful JSON response from the fact-check API.""" return { "fact_reasoner_score": { @@ -23,46 +29,73 @@ def mock_api_response() -> dict: "not_supported_atoms": [], "contexts": [{"id": "c_a0_0", "title": "Why Is the Sky Blue?"}], "graph_id": "mock-graph-id-123", + "logging_metadata": {}, } @pytest.fixture def fact_check_tool(): - config = FactCheckToolConfig(base_url="http://localhost:8011") + config = FactCheckToolConfig( + base_url=HttpUrl("http://localhost:8011"), + polling_interval_seconds=1, + ) tool = FactCheckTool(config=config) return tool @pytest.mark.asyncio -async def test_fact_check_tool_with_mock(mocker, mock_api_response): - """Tests the FactCheck tool's post method. The post method is patched with a mock request.""" - - mock_response = mocker.Mock() - mock_response.status_code = 200 - mock_response.json.return_value = mock_api_response +async def test_fact_check_tool_polling_workflow( + mocker, + fact_check_tool, + mock_api_start_response, + mock_api_final_response, +): + """ + Tests the full polling workflow by mocking both the POST and GET calls. 
+ """ + + # Mock the POST call to /fact-check/start + mock_post_response = mocker.Mock() + mock_post_response.status_code = 200 + mock_post_response.json.return_value = mock_api_start_response + + # Mock the GET call to /fact-check/status/{job_id} + mock_get_response = mocker.Mock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "status": "completed", + "result": mock_api_final_response, + } + # Patch both 'post' and 'get' methods of the httpx.AsyncClient mocker.patch( "akd.tools.fact_check.httpx.AsyncClient.post", new_callable=AsyncMock, - return_value=mock_response, + return_value=mock_post_response, + ) + mocker.patch( + "akd.tools.fact_check.httpx.AsyncClient.get", + new_callable=AsyncMock, + return_value=mock_get_response, ) - tool = FactCheckTool() + # Prepare the input for the tool input_data = FactCheckInputSchema( question="What colour is the sky?", answer="The sky is blue.", ) - # This will call the mocked version of the tool. - output = await tool.arun(input_data) + # This will call the mocked versions of post and then get. + output = await fact_check_tool.arun(input_data) + # Assert: assert isinstance(output, FactCheckOutputSchema) - assert output.fact_reasoner_score["factuality_score"] == 1.0 - assert len(output.fact_reasoner_score["factuality_score_per_atom"]) == 1 assert output.graph_id == "mock-graph-id-123" assert len(output.supported_atoms) == 1 def test_fact_check_tool_initialization(fact_check_tool): - assert fact_check_tool.base_url == HttpUrl("http://localhost:8011") + """Tests that the tool initializes with the correct config.""" + # This test remains the same and is a good check. + assert fact_check_tool.config.base_url == HttpUrl("http://localhost:8011") From dadb3388126f4bd9638f69332c9c0dadd43ed47a Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 18:05:34 +0100 Subject: [PATCH 26/33] Update pyproject, and get API and job timeouts from env --- .env.example | 4 ++++ akd/tools/fact_check.py | 17 ++++++++++++----- pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/.env.example b/.env.example index e83be831..fd83edaa 100644 --- a/.env.example +++ b/.env.example @@ -31,7 +31,11 @@ GOOGLE_API_KEY="" GOOGLE_CSE_ID="" SEARXNG_BASE_URL="" #"http://localhost:8080" + +# For FactCheck FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud" +FACT_CHECK_JOB_TIMEOUT="1800" +FACT_CHECK_REQUEST_TIMEOUT="60" # for your vector store EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2" diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index 38893b16..a5c894cf 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -58,8 +58,12 @@ class FactCheckToolConfig(BaseToolConfig): description="How often to poll for job results.", ) job_timeout_seconds: int = Field( - default=1800, - description="Maximum time to wait for a job to complete (30 minutes).", + default=int(os.getenv("FACT_CHECK_JOB_TIMEOUT", "1800")), + description="Maximum time to wait for the entire job to complete (30 minutes).", + ) + request_timeout_seconds: int = Field( + default=int(os.getenv("FACT_CHECK_REQUEST_TIMEOUT", "60")), + description="Timeout in seconds for each individual API request.", ) @@ -88,7 +92,12 @@ def __init__( super().__init__(config, debug) logger.info("Initializing FactCheckTool...") - self.api_client = httpx.AsyncClient(base_url=str(self.config.base_url)) + # Set a timeout on the API requests + timeout = 
httpx.Timeout(self.config.request_timeout_seconds, connect=60.0) + self.api_client = httpx.AsyncClient( + base_url=str(self.config.base_url), + timeout=timeout, + ) async def _arun( self, @@ -106,7 +115,6 @@ async def _arun( start_response = await self.api_client.post( "/fact-check/start", json=params.model_dump(), - timeout=60.0, ) start_response.raise_for_status() job_id = start_response.json()["job_id"] @@ -118,7 +126,6 @@ async def _arun( logger.info(f"Polling status for job {job_id}...") status_response = await self.api_client.get( f"/fact-check/status/{job_id}", - timeout=60.0, ) status_response.raise_for_status() status_data = status_response.json() diff --git a/pyproject.toml b/pyproject.toml index d023cdcc..df4f2a8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dependencies = [ "deepeval>=3.4.0", "markdown>=3.8.2", "chromadb>=1.0.13", - "pytest-mock>=3.14.1", ] [project.urls] @@ -94,6 +93,7 @@ dev = [ "pytest>=8.0.0", "pytest-asyncio>=1.0.0", "pytest-cov>=6.0.0", + "pytest-mock>=3.14.1", "ipykernel>=6.30.0", "ipywidgets>=8.1.7", "pre-commit>=4.2.0", From 14d4be3dedc8ee371f7094f403bddf9f1a4961f6 Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 18:11:49 +0100 Subject: [PATCH 27/33] Set vector db path from util function --- akd/tools/vector_db_tool.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py index 5276376e..6f354a9e 100644 --- a/akd/tools/vector_db_tool.py +++ b/akd/tools/vector_db_tool.py @@ -9,6 +9,7 @@ from akd._base import InputSchema, OutputSchema from akd.tools._base import BaseTool, BaseToolConfig +from akd.utils import get_akd_root class VectorDBInputSchema(InputSchema): @@ -42,6 +43,11 @@ class VectorDBToolConfig(BaseToolConfig): default=os.getenv("VECTOR_DB_PATH", "./chroma_db"), description="Path to the persistent ChromaDB directory.", ) + db_path: str = Field( + default=os.getenv("VECTOR_DB_PATH", str(get_akd_root() / "chroma_db")), + description="Path to the persistent ChromaDB directory.", + ) + collection_name: str = Field( default="litagent_demo", description="Name of the collection within ChromaDB.", From 1d2627a0a21ac8cc0cfd8509b50d409541787b86 Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 22:48:09 +0100 Subject: [PATCH 28/33] Remove dependency on Langchain Documents --- akd/tools/vector_db_tool.py | 77 ++++++++--------- scripts/demo_deep_search_with_fact_check.py | 45 +++++----- tests/tools/test_semantic_scholar_search.py | 92 +++++++++------------ tests/tools/test_vector_db_tool.py | 88 +++++++++----------- 4 files changed, 138 insertions(+), 164 deletions(-) diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py index 6f354a9e..3262b49b 100644 --- a/akd/tools/vector_db_tool.py +++ b/akd/tools/vector_db_tool.py @@ -1,9 +1,8 @@ import os -from typing import List, Optional +from typing import Any, Dict, List, Optional import chromadb import chromadb.utils.embedding_functions as embedding_functions -from langchain_core.documents import Document from loguru import logger from pydantic import Field @@ -12,19 +11,31 @@ from akd.utils import get_akd_root -class VectorDBInputSchema(InputSchema): +class VectorDBIndexInputSchema(InputSchema): + """Input schema for indexing documents into the Vector Database.""" + + ids: List[str] = Field(..., description="A unique list of document IDs.") + documents: List[str] = Field( + ..., description="A list of document contents to index." 
+ ) + metadatas: Optional[List[Dict[str, Any]]] = Field( + None, description="Optional list of metadata for each document." + ) + + +class VectorDBQueryInputSchema(InputSchema): """Input schema for querying documents from the Vector Database.""" query: str = Field(..., description="The query string for retrieval.") k: int = Field(3, description="Number of documents to retrieve.") -class VectorDBOutputSchema(OutputSchema): +class VectorDBQueryOutputSchema(OutputSchema): """Output schema for the Vector Database tool's query results.""" - results: List[Document] = Field( + results: List[Dict[str, Any]] = Field( ..., - description="List of retrieved Langchain Document objects.", + description="List of retrieved documents, each as a dictionary with 'page_content' and 'metadata'.", ) @@ -39,15 +50,10 @@ class VectorDBToolConfig(BaseToolConfig): default=os.getenv("EMBEDDING_MODEL_API_KEY", None), description="The API key for the embedding model provider, if required.", ) - db_path: str = Field( - default=os.getenv("VECTOR_DB_PATH", "./chroma_db"), - description="Path to the persistent ChromaDB directory.", - ) db_path: str = Field( default=os.getenv("VECTOR_DB_PATH", str(get_akd_root() / "chroma_db")), description="Path to the persistent ChromaDB directory.", ) - collection_name: str = Field( default="litagent_demo", description="Name of the collection within ChromaDB.", @@ -55,7 +61,7 @@ class VectorDBToolConfig(BaseToolConfig): class VectorDBTool( - BaseTool[VectorDBInputSchema, VectorDBOutputSchema], + BaseTool[VectorDBQueryInputSchema, VectorDBQueryOutputSchema], ): """ A tool for indexing and retrieving documents from a Chroma vector database. @@ -65,13 +71,13 @@ class VectorDBTool( description = ( "Indexes documents into a vector database and retrieves them based on a query." ) - input_schema = VectorDBInputSchema - output_schema = VectorDBOutputSchema + input_schema = VectorDBQueryInputSchema + output_schema = VectorDBQueryOutputSchema config_schema = VectorDBToolConfig def __init__( self, - config: VectorDBToolConfig | None = None, + config: Optional[VectorDBToolConfig] = None, debug: bool = False, ): """Initializes the VectorDBTool and its ChromaDB client.""" @@ -79,7 +85,6 @@ def __init__( super().__init__(config, debug) logger.info("Initializing VectorDBTool...") - self.client = chromadb.PersistentClient(path=self.config.db_path) embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction( @@ -93,36 +98,29 @@ def __init__( f"Connected to ChromaDB collection '{self.config.collection_name}'.", ) - def index(self, documents: List[Document]): + def index(self, params: VectorDBIndexInputSchema): """ - Adds or updates documents in the vector database collection from Langchain Documents. + Adds or updates documents in the vector database collection. 
""" - logger.info(f"Indexing {len(documents)} documents...") - - # Extract components from the Document objects for ChromaDB - ids = [doc.metadata.get("id", f"doc_{i}") for i, doc in enumerate(documents)] - contents = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] - + logger.info(f"Indexing {len(params.documents)} documents...") self.collection.add( - ids=ids, - documents=contents, - metadatas=metadatas, + ids=params.ids, + documents=params.documents, + metadatas=params.metadatas, ) logger.info("Indexing complete.") async def _arun( self, - params: VectorDBInputSchema, - ) -> VectorDBOutputSchema: + params: VectorDBQueryInputSchema, + ) -> VectorDBQueryOutputSchema: """ - Retrieves documents and returns them as a list of Langchain Document objects. + Retrieves documents and returns them as a list of dictionaries. """ logger.info( f"Querying collection with query: '{params.query}', retrieving top-{params.k} documents", ) - # Include metadatas and documents to reconstruct the Document objects results = self.collection.query( query_texts=[params.query], n_results=params.k, @@ -130,20 +128,17 @@ async def _arun( ) retrieved_docs = [] - # The result is batched; we process the first (and only) query's results if results and results.get("ids") and results["ids"][0]: - result_ids = results["ids"][0] result_documents = results["documents"][0] result_metadatas = results["metadatas"][0] - for i in range(len(result_ids)): - # Reconstruct the Langchain Document object - doc = Document( - page_content=result_documents[i], - metadata=result_metadatas[i] + for i in range(len(result_documents)): + doc = { + "page_content": result_documents[i], + "metadata": result_metadatas[i] if result_metadatas and result_metadatas[i] else {}, - ) + } retrieved_docs.append(doc) - return VectorDBOutputSchema(results=retrieved_docs) + return VectorDBQueryOutputSchema(results=retrieved_docs) diff --git a/scripts/demo_deep_search_with_fact_check.py b/scripts/demo_deep_search_with_fact_check.py index 52b51f2f..139ad966 100644 --- a/scripts/demo_deep_search_with_fact_check.py +++ b/scripts/demo_deep_search_with_fact_check.py @@ -13,7 +13,6 @@ import markdown from bs4 import BeautifulSoup -from langchain_core.documents import Document from langchain_text_splitters import RecursiveCharacterTextSplitter from akd.agents.search import ( @@ -23,7 +22,7 @@ ) from akd.configs.project import get_project_settings from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool -from akd.tools.vector_db_tool import VectorDBTool +from akd.tools.vector_db_tool import VectorDBIndexInputSchema, VectorDBTool def process_report(report): @@ -96,31 +95,39 @@ async def main(): print(f"\n--- Found {len(source_results)} source documents and research report ---") - # Convert search results to Langchain Document objects for processing - documents_to_process = [ - Document( - page_content=res["content"], - metadata={"source": res["url"], "title": res["title"]}, - ) - for res in source_results - if res.get("content") - ] - - print( - f"\n--- Splitting {len(documents_to_process)} documents into smaller chunks ---" - ) + print(f"\n--- Splitting {len(source_results)} documents into smaller chunks ---") text_splitter = RecursiveCharacterTextSplitter( chunk_size=2000, chunk_overlap=200, ) - chunks = text_splitter.split_documents(documents_to_process) - print(f"Created {len(chunks)} chunks.") + + all_chunks = [] + all_metadatas = [] + all_ids = [] + + for i, res in enumerate(source_results): + if 
res.get("content"): + # Use split_text on the raw content string + chunks = text_splitter.split_text(res["content"]) + for j, chunk in enumerate(chunks): + all_chunks.append(chunk) + # Create a unique ID and metadata for each chunk + metadata = {"source": res["url"], "title": res["title"]} + all_metadatas.append(metadata) + all_ids.append(f"res_{i}_chunk_{j}") + + print(f"Created {len(all_chunks)} chunks.") vector_db_tool = VectorDBTool() - print(f"\n--- Indexing {len(chunks)} chunks into ChromaDB ---") + print(f"\n--- Indexing {len(all_chunks)} chunks into ChromaDB ---") - vector_db_tool.index(chunks) + index_params = VectorDBIndexInputSchema( + ids=all_ids, + documents=all_chunks, + metadatas=all_metadatas, + ) + vector_db_tool.index(index_params) print(f"Indexing complete. Database is located at: {vector_db_tool.config.db_path}") diff --git a/tests/tools/test_semantic_scholar_search.py b/tests/tools/test_semantic_scholar_search.py index a86eefce..2f0c73e1 100644 --- a/tests/tools/test_semantic_scholar_search.py +++ b/tests/tools/test_semantic_scholar_search.py @@ -1,15 +1,31 @@ -import asyncio +from unittest.mock import AsyncMock import pytest +from pydantic import HttpUrl -from akd.tools.search import ( +from akd.structures import SearchResultItem +from akd.tools.search.semantic_scholar_search import ( SemanticScholarSearchTool, - SemanticScholarSearchToolConfig, SemanticScholarSearchToolInputSchema, SemanticScholarSearchToolOutputSchema, ) -pytest_plugins = ("pytest_asyncio",) + +@pytest.fixture +def mock_arun_output() -> SemanticScholarSearchToolOutputSchema: + """Provides a realistic, successful output object from the arun method.""" + return SemanticScholarSearchToolOutputSchema( + results=[ + SearchResultItem( + title="A Mock Paper on Transformer Architectures", + url=HttpUrl("https://www.semanticscholar.org/paper/a1b2c3d4e5f6"), + content="This paper discusses recent advances in transformer models for NLP.", + query="Recent advances in transformer architectures", + ), + ], + category="science", + ) + def test_from_params_constructor(): """ @@ -18,66 +34,32 @@ def test_from_params_constructor(): """ search_tool = SemanticScholarSearchTool.from_params(max_results=5, debug=True) assert search_tool.config.max_results == 5 - # Test a default value - assert search_tool.config.external_id == "DOI" + # Test a default value from the config + assert search_tool.config.base_url == HttpUrl("https://api.semanticscholar.org") @pytest.mark.asyncio -async def test_fetch_paper_by_external_id(): # Renamed for clarity +async def test_arun_with_direct_mock(mocker, mock_arun_output): """ - Tests that fetch_paper_by_external_id can successfully retrieve - and parse a specific paper using its ARXIV ID. + Tests the main `arun` method by directly mocking the internal `_arun` method. """ - config = SemanticScholarSearchToolConfig() - search_tool = SemanticScholarSearchTool(config=config, debug=True) - - known_arxiv_id = "1706.03762" - input_schema = SemanticScholarSearchToolInputSchema(queries=[known_arxiv_id]) - - results = await search_tool.fetch_paper_by_external_id( - input_schema, - external_id="ARXIV", + # Patch the internal _arun method to return our mock output directly. + mocker.patch( + "akd.tools.search.semantic_scholar_search.SemanticScholarSearchTool._arun", + new_callable=AsyncMock, + return_value=mock_arun_output, ) - assert isinstance(results, list) - assert len(results) == 1, ( - "Expected to find exactly one paper for the given ArXiv ID." + # Initialize the tool and input data. 
+ tool = SemanticScholarSearchTool() + input_schema = SemanticScholarSearchToolInputSchema( + queries=["Recent advances in transformer architectures"], + max_results=3, ) - paper = results[0] - # Check that the title and ArXiv ID match the paper we requested. - assert paper.external_ids["ArXiv"] == known_arxiv_id + # This will now call the mocked version of _arun. + output = await tool.arun(input_schema) - -@pytest.mark.asyncio -async def test_arun(): - """ - Tests the main `arun` method to ensure the full process works. - """ - config = SemanticScholarSearchToolConfig() - search_tool = SemanticScholarSearchTool(config=config, debug=True) - - queries = ["Enhanced dependency parsing approaches"] - input_schema = SemanticScholarSearchToolInputSchema(queries=queries, max_results=3) - - output = await search_tool.arun(input_schema) - - # Assertions to check the final, processed output assert isinstance(output, SemanticScholarSearchToolOutputSchema) assert len(output.results) > 0, "No results found" - - first_result = output.results[0] - assert first_result.url, "No url included" - assert first_result.title, "No title included" - assert first_result.content, "No content included" - - -async def main(): - """Runs all the defined tests.""" - test_from_params_constructor() - await test_fetch_paper_by_external_id() - await test_arun() - - -if __name__ == "__main__": - asyncio.run(main()) + assert output.results[0].title == "A Mock Paper on Transformer Architectures" diff --git a/tests/tools/test_vector_db_tool.py b/tests/tools/test_vector_db_tool.py index 3be016c8..b26f7566 100644 --- a/tests/tools/test_vector_db_tool.py +++ b/tests/tools/test_vector_db_tool.py @@ -1,12 +1,13 @@ import shutil from pathlib import Path +from typing import Dict, List import pytest -from langchain_core.documents import Document from akd.tools.vector_db_tool import ( - VectorDBInputSchema, - VectorDBOutputSchema, + VectorDBIndexInputSchema, + VectorDBQueryInputSchema, + VectorDBQueryOutputSchema, VectorDBTool, VectorDBToolConfig, ) @@ -19,107 +20,96 @@ def temp_db_path(tmp_path: Path) -> str: """Create a temporary directory for the ChromaDB database.""" db_path = str(tmp_path / "test_chroma_db") yield db_path - # Clean up the database directory after the test runs shutil.rmtree(db_path, ignore_errors=True) @pytest.fixture -def sample_documents() -> list[Document]: - """Provides a list of sample documents for indexing.""" - return [ - Document( - page_content="The sky is blue. It is a beautiful day.", - metadata={"id": "doc1", "source": "weather.txt"}, - ), - Document( - page_content="Apples are a type of fruit. They are often red or green.", - metadata={"id": "doc2", "source": "fruits.txt"}, - ), - Document( - page_content="A computer is an electronic device. Keyboards are used for input.", - metadata={"id": "doc3", "source": "tech.txt"}, - ), - ] +def sample_data() -> Dict[str, List]: + """Provides sample data as a dictionary of lists.""" + return { + "ids": ["doc1", "doc2", "doc3"], + "documents": [ + "The sky is blue. It is thirty degrees Celsius.", + "The ingredients include apples, pears and grapes.", + "A computer is an electronic device. 
Keyboards are used for input.", + ], + "metadatas": [ + {"source": "weather.txt", "title": "weather_report"}, + {"source": "ingredients.txt", "title": "ingredients_list"}, + {"source": "tech.txt", "url": "www.tech-example.com"}, + ], + } @pytest.fixture def configured_db_tool( temp_db_path: str, - sample_documents: list[Document], + sample_data: Dict[str, List], ) -> VectorDBTool: """ - Provides a fully configured and pre-populated VectorDBTool instance - for testing retrieval. + Provides a fully configured and pre-populated VectorDBTool instance. """ - # Use a unique collection name for each test run to ensure isolation config = VectorDBToolConfig( db_path=temp_db_path, collection_name="test_collection", ) db_tool = VectorDBTool(config=config) - # Index the sample documents into the fresh database - db_tool.index(sample_documents) + # Index the sample documents + index_params = VectorDBIndexInputSchema(**sample_data) + db_tool.index(index_params) return db_tool def test_vectordb_initialization(temp_db_path: str): """ - Tests that the VectorDBTool initializes correctly and creates the - database directory and collection. + Tests that the VectorDBTool initializes correctly. """ config = VectorDBToolConfig(db_path=temp_db_path, collection_name="init_test") db_tool = VectorDBTool(config=config) - # Assert that the client and collection were created assert db_tool.client is not None assert db_tool.collection is not None assert db_tool.collection.name == "init_test" - - # Assert that the database directory was created on disk assert Path(temp_db_path).exists() -def test_index_method(temp_db_path: str, sample_documents: list[Document]): +def test_index_method(temp_db_path: str, sample_data: Dict[str, List]): """ - Tests that the `index` method correctly adds documents to the collection. + Tests that the `index` method correctly adds documents. """ config = VectorDBToolConfig(db_path=temp_db_path, collection_name="index_test") db_tool = VectorDBTool(config=config) - # Index the documents - db_tool.index(sample_documents) + index_params = VectorDBIndexInputSchema(**sample_data) + db_tool.index(index_params) - # Verify the documents were added to the collection - assert db_tool.collection.count() == len(sample_documents) + assert db_tool.collection.count() == len(sample_data["documents"]) - # Retrieve one document by ID to confirm its content retrieved = db_tool.collection.get(ids=["doc2"], include=["metadatas", "documents"]) assert ( retrieved["documents"][0] - == "Apples are a type of fruit. They are often red or green." + == "The ingredients include apples, pears and grapes." ) - assert retrieved["metadatas"][0]["source"] == "fruits.txt" + assert retrieved["metadatas"][0]["source"] == "ingredients.txt" + assert retrieved["metadatas"][0]["title"] == "ingredients_list" async def test_arun_retrieval(configured_db_tool: VectorDBTool): """ - Tests that the `_arun` method correctly retrieves the most relevant - documents for a given query. + Tests that the `_arun` method correctly retrieves relevant documents. """ - query = "What color is the sky?" 
- input_params = VectorDBInputSchema(query=query, k=1) + input_params = VectorDBQueryInputSchema(query=query, k=1) output = await configured_db_tool._arun(input_params) - assert isinstance(output, VectorDBOutputSchema) + assert isinstance(output, VectorDBQueryOutputSchema) assert len(output.results) == 1 retrieved_doc = output.results[0] - assert isinstance(retrieved_doc, Document) + assert isinstance(retrieved_doc, dict) - # Check that the most relevant document was returned - assert retrieved_doc.metadata["source"] == "weather.txt" - assert "The sky is blue" in retrieved_doc.page_content + assert retrieved_doc["metadata"]["source"] == "weather.txt" + assert "The sky is blue" in retrieved_doc["page_content"] From ccc37f70b1ff6bc40428ad52d2991e47e52963ea Mon Sep 17 00:00:00 2001 From: jbrry Date: Fri, 22 Aug 2025 22:52:04 +0100 Subject: [PATCH 29/33] Update test for vector db tool --- tests/tools/test_vector_db_tool.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tools/test_vector_db_tool.py b/tests/tools/test_vector_db_tool.py index b26f7566..6f4d07b7 100644 --- a/tests/tools/test_vector_db_tool.py +++ b/tests/tools/test_vector_db_tool.py @@ -89,8 +89,7 @@ def test_index_method(temp_db_path: str, sample_data: Dict[str, List]): retrieved = db_tool.collection.get(ids=["doc2"], include=["metadatas", "documents"]) assert ( - retrieved["documents"][0] - == "The ingredients include apples, pears and grapes." + retrieved["documents"][0] == "The ingredients include apples, pears and grapes." ) assert retrieved["metadatas"][0]["source"] == "ingredients.txt" assert retrieved["metadatas"][0]["title"] == "ingredients_list" From 0b812771e2f53a8956f3604f959f11b9e9b5c998 Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 28 Aug 2025 17:54:26 +0100 Subject: [PATCH 30/33] Update fact-check tool config to use endpoints given in .env file --- .env.example | 5 +++++ akd/tools/fact_check.py | 24 ++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/.env.example b/.env.example index fd83edaa..693e5a7a 100644 --- a/.env.example +++ b/.env.example @@ -36,6 +36,11 @@ SEARXNG_BASE_URL="" #"http://localhost:8080" FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud" FACT_CHECK_JOB_TIMEOUT="1800" FACT_CHECK_REQUEST_TIMEOUT="60" +FACT_CHECK_START_ENDPOINT="/fact-check/start" +FACT_CHECK_STATUS_ENDPOINT="/fact-check/status" +FACT_CHECK_CORRECT_ENDPOINT="/correct" +FACT_CHECK_DISPLAY_GRAPH_ENDPOINT="/display_graph" +FACT_CHECK_GRAPH_JSON_ENDPOINT="/graph/json" # for your vector store EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2" diff --git a/akd/tools/fact_check.py b/akd/tools/fact_check.py index a5c894cf..a6a51c31 100644 --- a/akd/tools/fact_check.py +++ b/akd/tools/fact_check.py @@ -53,6 +53,26 @@ class FactCheckToolConfig(BaseToolConfig): default=HttpUrl(os.getenv("FACT_CHECK_API_URL", "http://localhost:8011")), description="The base URL of the remote Fact-Checking and Correction Service.", ) + start_endpoint: str = Field( + default=os.getenv("FACT_CHECK_START_ENDPOINT", "/fact-check/start"), + description="Endpoint to start a new fact-checking job." + ) + status_endpoint: str = Field( + default=os.getenv("FACT_CHECK_STATUS_ENDPOINT", "/fact-check/status/"), + description="Endpoint to get the status of a job. Must end with a slash." + ) + correct_endpoint: str = Field( + default=os.getenv("FACT_CHECK_CORRECT_ENDPOINT", "/correct/"), + description="Endpoint for single correction steps." 
+ ) + display_graph_endpoint: str = Field( + default=os.getenv("FACT_CHECK_DISPLAY_GRAPH_ENDPOINT", "/display_graph/"), + description="Endpoint to display a saved fact graph." + ) + graph_json_endpoint: str = Field( + default=os.getenv("FACT_CHECK_GRAPH_JSON_ENDPOINT", "/graph/json/"), + description="Endpoint to retrieve graph data as JSON." + ) polling_interval_seconds: int = Field( default=120, description="How often to poll for job results.", @@ -113,7 +133,7 @@ async def _arun( try: # Start the job start_response = await self.api_client.post( - "/fact-check/start", + self.config.start_endpoint, json=params.model_dump(), ) start_response.raise_for_status() @@ -125,7 +145,7 @@ async def _arun( while total_wait_time < self.config.job_timeout_seconds: logger.info(f"Polling status for job {job_id}...") status_response = await self.api_client.get( - f"/fact-check/status/{job_id}", + f"{self.config.status_endpoint}/{job_id}", ) status_response.raise_for_status() status_data = status_response.json() From e7ef73387b08e4b66aabb673aa8dc28736a08a24 Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 3 Sep 2025 20:26:39 +0100 Subject: [PATCH 31/33] Fix conflict from develop --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1c79b68c..331cd481 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,12 +92,9 @@ dev = [ "pytest>=8.0.0", "pytest-asyncio>=1.0.0", "pytest-cov>=6.0.0", -<<<<<<< HEAD "pytest-mock>=3.14.1", "ipykernel>=6.30.0", "ipywidgets>=8.1.7", -======= ->>>>>>> develop "pre-commit>=4.2.0", ] local = [ From 8dbdc077da304da159e7ee38b840c504ea349100 Mon Sep 17 00:00:00 2001 From: jbrry Date: Wed, 3 Sep 2025 21:17:04 +0100 Subject: [PATCH 32/33] Sync version of pyproject.toml from integration branch --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 331cd481..a11cb97c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ local = [ "ipywidgets>=8.1.7", ] ml = [ + "chromadb>=1.0.13", "pandas>=2.3.1", "sentence-transformers>=5.0.0", "docling>=2.37.0", From 2a50d764d9ccd2ed40b7277b1916e9956586dc5a Mon Sep 17 00:00:00 2001 From: jbrry Date: Thu, 4 Sep 2025 18:04:17 +0100 Subject: [PATCH 33/33] Change vector db collection name to something more generic --- akd/tools/vector_db_tool.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/akd/tools/vector_db_tool.py b/akd/tools/vector_db_tool.py index 3262b49b..257c86e4 100644 --- a/akd/tools/vector_db_tool.py +++ b/akd/tools/vector_db_tool.py @@ -16,10 +16,12 @@ class VectorDBIndexInputSchema(InputSchema): ids: List[str] = Field(..., description="A unique list of document IDs.") documents: List[str] = Field( - ..., description="A list of document contents to index." + ..., + description="A list of document contents to index.", ) metadatas: Optional[List[Dict[str, Any]]] = Field( - None, description="Optional list of metadata for each document." + None, + description="Optional list of metadata for each document.", ) @@ -55,7 +57,7 @@ class VectorDBToolConfig(BaseToolConfig): description="Path to the persistent ChromaDB directory.", ) collection_name: str = Field( - default="litagent_demo", + default="akd_vdb", description="Name of the collection within ChromaDB.", )