42 commits
727d5f3  Add `VectorDBTool` based on Chroma (jbrry, Jul 23, 2025)
1c75c1a  Merge branch 'develop' into feature/vector_database_tool (jbrry, Jul 23, 2025)
d85bf70  Adopt config strategy and add text splitter and vector db tools (jbrry, Jul 23, 2025)
63e5cbe  Update litsearch agent with new indexing capabilities (jbrry, Jul 23, 2025)
f1f6a76  Remove comment line (jbrry, Jul 23, 2025)
2e78325  Changes to standalone script for vectordb (jbrry, Aug 7, 2025)
2586621  Add test to run `SemanticScholarSearchTool` (jbrry, Aug 8, 2025)
21727e0  Merge changes from develop (jbrry, Aug 8, 2025)
25ccec3  Add test for text splitter tool (jbrry, Aug 11, 2025)
8d3e4b0  Merge branch 'develop' into feature/add_semantic_scholar_search_test (jbrry, Aug 11, 2025)
384c9d3  Test the `from_params` method of the class (jbrry, Aug 11, 2025)
231e6ff  Add tests for vector db and text splitter tools (jbrry, Aug 11, 2025)
c4a57c9  Fix chunking test (jbrry, Aug 11, 2025)
d947cb7  Leave config initialisation to super class (jbrry, Aug 11, 2025)
2220443  Add FactReasoner fact-check to deep research pipeline (jbrry, Aug 12, 2025)
7a2c13d  Add test for FactCheck tool (jbrry, Aug 14, 2025)
a492e6b  Merge branch 'feature/add_semantic_scholar_search_test' into feature/… (jbrry, Aug 14, 2025)
595cb38  Updated test file (jbrry, Aug 14, 2025)
8d9ed47  Wrap string in HttpUrl (jbrry, Aug 14, 2025)
fce0cfb  Update demo script with fact check (jbrry, Aug 20, 2025)
6c50cb7  Enable polling for long running processes (jbrry, Aug 21, 2025)
84cecff  Do not use text splitter tool (jbrry, Aug 21, 2025)
ae2f2a1  Remove text splitter tool (jbrry, Aug 21, 2025)
024005c  Get URL paramater from .env file (jbrry, Aug 21, 2025)
4efd005  Take parameters from env file (jbrry, Aug 21, 2025)
5b39e84  Make the URL an `HttpUrl` (jbrry, Aug 21, 2025)
e3f79c9  Add dependencies to pyproject.toml (jbrry, Aug 22, 2025)
e2988d0  Use updated pyproject.toml file (jbrry, Aug 22, 2025)
66e6d88  Fix merge conflict in pyproject.toml file (jbrry, Aug 22, 2025)
32b7dac  Update test for fact-check to correspond to new post and get approach (jbrry, Aug 22, 2025)
dadb338  Update pyproject, and get API and job timeouts from env (jbrry, Aug 22, 2025)
14d4be3  Set vector db path from util function (jbrry, Aug 22, 2025)
1d2627a  Remove dependency on Langchain Documents (jbrry, Aug 22, 2025)
ccc37f7  Update test for vector db tool (jbrry, Aug 22, 2025)
f7964b9  Merge from `develop` and fix conflict in pyproject.toml (jbrry, Aug 22, 2025)
0b81277  Update fact-check tool config to use endpoints given in .env file (jbrry, Aug 28, 2025)
8582873  Merge branch 'develop' into feature/vector_database_tool (jbrry, Aug 28, 2025)
03f0f7f  Fix conflict from develop (jbrry, Sep 3, 2025)
e7ef733  Fix conflict from develop (jbrry, Sep 3, 2025)
8dbdc07  Sync version of pyproject.toml from integration branch (jbrry, Sep 3, 2025)
2a50d76  Change vector db collection name to something more generic (jbrry, Sep 4, 2025)
39c8b8e  Merge branch 'develop' into feature/vector_database_tool (jbrry, Sep 4, 2025)
14 changes: 14 additions & 0 deletions .env.example
@@ -39,9 +39,23 @@ SEARXNG_STRICT=False
SEARXNG_DEBUG=False
SEARXNG_ENGINES="google,arxiv,google_scholar"

# For FactCheck
FACT_CHECK_API_URL="https://factreasoner-service-app.1yhbkn094k2v.us-south.codeengine.appdomain.cloud"
FACT_CHECK_JOB_TIMEOUT="1800"
FACT_CHECK_REQUEST_TIMEOUT="60"
FACT_CHECK_START_ENDPOINT="/fact-check/start"
FACT_CHECK_STATUS_ENDPOINT="/fact-check/status"
FACT_CHECK_CORRECT_ENDPOINT="/correct"
FACT_CHECK_DISPLAY_GRAPH_ENDPOINT="/display_graph"
FACT_CHECK_GRAPH_JSON_ENDPOINT="/graph/json"

# for your vector store
EMBEDDING_MODEL_ID="nasa-impact/nasa-smd-ibm-st-v2"

VECTOR_DB_PATH="./chroma_db"
# API key for various embedding functions (chroma)
EMBEDDING_MODEL_API_KEY=""

# number of wiki results to return
TOP_N_WIKI_RESULTS = 1

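Note that the tool configs below read these variables through `os.getenv` in their field defaults, which runs at import time, so the environment must be populated before `akd.tools` is imported. A minimal sketch of that ordering (using `python-dotenv` here is an assumption for illustration, not something this diff adds):

```python
# Populate os.environ from .env BEFORE importing the tools: the config
# classes evaluate os.getenv(...) in their Field defaults at import time.
from dotenv import load_dotenv

load_dotenv()  # reads FACT_CHECK_*, VECTOR_DB_PATH, EMBEDDING_MODEL_* from .env

from akd.tools.fact_check import FactCheckToolConfig

config = FactCheckToolConfig()
print(config.base_url)             # from FACT_CHECK_API_URL
print(config.job_timeout_seconds)  # from FACT_CHECK_JOB_TIMEOUT
```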
178 changes: 178 additions & 0 deletions akd/tools/fact_check.py
@@ -0,0 +1,178 @@
import asyncio
import os
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger
from pydantic import Field, HttpUrl

from akd._base import InputSchema, OutputSchema
from akd.tools._base import BaseTool, BaseToolConfig


class FactCheckInputSchema(InputSchema):
"""Input schema for the Fact-Checking Tool."""

question: str = Field(..., description="The original question that was asked.")
answer: str = Field(..., description="The LLM answer to be fact-checked.")


class FactCheckOutputSchema(OutputSchema):
"""Output schema for the Fact-Checking Tool's results."""

fact_reasoner_score: Dict[str, Any] = Field(
...,
description="The full scoring dictionary from the FactReasoner.",
)
supported_atoms: List[Dict[str, Any]] = Field(
...,
description="List of atoms determined to be supported.",
)
not_supported_atoms: List[Dict[str, Any]] = Field(
...,
description="List of atoms determined to be not supported.",
)
contexts: List[Dict[str, Any]] = Field(
...,
description="List of retrieved contexts used for the check.",
)
graph_id: Optional[str] = Field(
None,
description="The unique ID for the generated fact graph.",
)
logging_metadata: Dict[str, Any] = Field(
{},
description="Additional logging metadata from the run.",
)


class FactCheckToolConfig(BaseToolConfig):
"""Configuration for the FactCheckTool."""

base_url: HttpUrl = Field(
default=HttpUrl(os.getenv("FACT_CHECK_API_URL", "http://localhost:8011")),
description="The base URL of the remote Fact-Checking and Correction Service.",
)
start_endpoint: str = Field(
default=os.getenv("FACT_CHECK_START_ENDPOINT", "/fact-check/start"),
description="Endpoint to start a new fact-checking job."
)
    status_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_STATUS_ENDPOINT", "/fact-check/status"),
        description="Endpoint to get the status of a job; the job ID is appended.",
    )
correct_endpoint: str = Field(
default=os.getenv("FACT_CHECK_CORRECT_ENDPOINT", "/correct/"),
description="Endpoint for single correction steps."
)
display_graph_endpoint: str = Field(
default=os.getenv("FACT_CHECK_DISPLAY_GRAPH_ENDPOINT", "/display_graph/"),
description="Endpoint to display a saved fact graph."
)
graph_json_endpoint: str = Field(
default=os.getenv("FACT_CHECK_GRAPH_JSON_ENDPOINT", "/graph/json/"),
description="Endpoint to retrieve graph data as JSON."
)
polling_interval_seconds: int = Field(
default=120,
description="How often to poll for job results.",
)
    job_timeout_seconds: int = Field(
        default=int(os.getenv("FACT_CHECK_JOB_TIMEOUT", "1800")),
        description="Maximum time in seconds to wait for the entire job to complete (default: 1800, i.e. 30 minutes).",
    )
request_timeout_seconds: int = Field(
default=int(os.getenv("FACT_CHECK_REQUEST_TIMEOUT", "60")),
description="Timeout in seconds for each individual API request.",
)


class FactCheckTool(
BaseTool[FactCheckInputSchema, FactCheckOutputSchema],
):
"""
A tool that calls an API to perform fact-checking on a given answer.
"""

name = "fact_check_tool"
description = (
"Calls an API to run the FactReasoner pipeline on a question and answer."
)
input_schema = FactCheckInputSchema
output_schema = FactCheckOutputSchema
config_schema = FactCheckToolConfig

def __init__(
self,
config: FactCheckToolConfig | None = None,
debug: bool = False,
):
"""Initializes the FactCheckTool and its HTTP client."""
config = config or FactCheckToolConfig()
super().__init__(config, debug)

logger.info("Initializing FactCheckTool...")
# Set a timeout on the API requests
timeout = httpx.Timeout(self.config.request_timeout_seconds, connect=60.0)
self.api_client = httpx.AsyncClient(
base_url=str(self.config.base_url),
timeout=timeout,
)

async def _arun(
self,
params: FactCheckInputSchema,
) -> FactCheckOutputSchema:
"""
Starts a fact-checking job and polls for its completion.
"""
logger.info(
f"Sending fact-check request for question: '{params.question[:50]}...'",
)

try:
# Start the job
start_response = await self.api_client.post(
self.config.start_endpoint,
json=params.model_dump(),
)
start_response.raise_for_status()
job_id = start_response.json()["job_id"]
logger.info(f"Successfully started job with ID: {job_id}")

# Poll for the result
total_wait_time = 0
while total_wait_time < self.config.job_timeout_seconds:
logger.info(f"Polling status for job {job_id}...")
                # Normalise the join so a configured trailing slash cannot
                # produce a double "//" in the request path.
                status_response = await self.api_client.get(
                    f"{self.config.status_endpoint.rstrip('/')}/{job_id}",
                )
status_response.raise_for_status()
status_data = status_response.json()

if status_data["status"] == "completed":
logger.info(f"Job {job_id} completed successfully.")
return FactCheckOutputSchema(**status_data["result"])
                elif status_data["status"] == "failed":
                    raise RuntimeError(
                        f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}",
                    )
                elif status_data["status"] == "pending":
                    logger.info(
                        f"Job {job_id} is in progress... (waited {total_wait_time}s)",
                    )

                # Sleep on any non-terminal status so an unexpected value
                # from the server cannot spin this loop without waiting.
                await asyncio.sleep(self.config.polling_interval_seconds)
                total_wait_time += self.config.polling_interval_seconds

raise asyncio.TimeoutError(
f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout.",
)

except httpx.HTTPStatusError as e:
logger.error(
f"HTTP error occurred while calling fact-check API: {e.response.status_code} - {e.response.text}",
)
raise
except Exception as e:
logger.error(f"An unexpected error occurred: {e}")
raise
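
For reference, a minimal usage sketch of the polling flow above. The question and answer strings are illustrative, and calling the private `_arun` directly is a shortcut for the example; in the pipeline the tool would normally be invoked through its `BaseTool` entry point:

```python
import asyncio

from akd.tools.fact_check import FactCheckInputSchema, FactCheckTool


async def main() -> None:
    # Endpoints and timeouts come from the FACT_CHECK_* env vars.
    tool = FactCheckTool()
    params = FactCheckInputSchema(
        question="What is the boiling point of water at sea level?",
        answer="Water boils at 100 degrees Celsius at sea level.",
    )
    # Starts the job, then polls the status endpoint until the job
    # completes, fails, or job_timeout_seconds elapses.
    result = await tool._arun(params)
    print(result.fact_reasoner_score)
    print(f"{len(result.not_supported_atoms)} atoms were not supported")


asyncio.run(main())
```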
146 changes: 146 additions & 0 deletions akd/tools/vector_db_tool.py
@@ -0,0 +1,146 @@
import os
from typing import Any, Dict, List, Optional

import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from loguru import logger
from pydantic import Field

from akd._base import InputSchema, OutputSchema
from akd.tools._base import BaseTool, BaseToolConfig
from akd.utils import get_akd_root


class VectorDBIndexInputSchema(InputSchema):
"""Input schema for indexing documents into the Vector Database."""

ids: List[str] = Field(..., description="A unique list of document IDs.")
documents: List[str] = Field(
...,
description="A list of document contents to index.",
)
metadatas: Optional[List[Dict[str, Any]]] = Field(
None,
description="Optional list of metadata for each document.",
)


class VectorDBQueryInputSchema(InputSchema):
"""Input schema for querying documents from the Vector Database."""

query: str = Field(..., description="The query string for retrieval.")
k: int = Field(3, description="Number of documents to retrieve.")


class VectorDBQueryOutputSchema(OutputSchema):
"""Output schema for the Vector Database tool's query results."""

results: List[Dict[str, Any]] = Field(
...,
description="List of retrieved documents, each as a dictionary with 'page_content' and 'metadata'.",
)


class VectorDBToolConfig(BaseToolConfig):
"""Configuration for the VectorDBTool, loaded from environment variables."""

    embedding_model_name: str = Field(
        default=os.getenv(
            "EMBEDDING_MODEL_ID",
            "sentence-transformers/all-MiniLM-L6-v2",
        ),
        description="The name of the Hugging Face embedding model to use, "
        "overridable via the EMBEDDING_MODEL_ID environment variable.",
    )
embedding_model_api_key: Optional[str] = Field(
default=os.getenv("EMBEDDING_MODEL_API_KEY", None),
description="The API key for the embedding model provider, if required.",
)
db_path: str = Field(
default=os.getenv("VECTOR_DB_PATH", str(get_akd_root() / "chroma_db")),
description="Path to the persistent ChromaDB directory.",
)
collection_name: str = Field(
default="akd_vdb",
description="Name of the collection within ChromaDB.",
)


class VectorDBTool(
BaseTool[VectorDBQueryInputSchema, VectorDBQueryOutputSchema],
):
"""
A tool for indexing and retrieving documents from a Chroma vector database.
"""

name = "vector_db_tool"
description = (
"Indexes documents into a vector database and retrieves them based on a query."
)
input_schema = VectorDBQueryInputSchema
output_schema = VectorDBQueryOutputSchema
config_schema = VectorDBToolConfig

def __init__(
self,
config: Optional[VectorDBToolConfig] = None,
debug: bool = False,
):
"""Initializes the VectorDBTool and its ChromaDB client."""
config = config or VectorDBToolConfig()
super().__init__(config, debug)

logger.info("Initializing VectorDBTool...")
self.client = chromadb.PersistentClient(path=self.config.db_path)

embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=self.config.embedding_model_name,
)
self.collection = self.client.get_or_create_collection(
name=self.config.collection_name,
embedding_function=embedding_function,
)
logger.info(
f"Connected to ChromaDB collection '{self.config.collection_name}'.",
)

def index(self, params: VectorDBIndexInputSchema):
"""
Adds or updates documents in the vector database collection.
"""
logger.info(f"Indexing {len(params.documents)} documents...")
self.collection.add(
ids=params.ids,
documents=params.documents,
metadatas=params.metadatas,
)
logger.info("Indexing complete.")

async def _arun(
self,
params: VectorDBQueryInputSchema,
) -> VectorDBQueryOutputSchema:
"""
Retrieves documents and returns them as a list of dictionaries.
"""
logger.info(
f"Querying collection with query: '{params.query}', retrieving top-{params.k} documents",
)

results = self.collection.query(
query_texts=[params.query],
n_results=params.k,
include=["metadatas", "documents"],
)

retrieved_docs = []
if results and results.get("ids") and results["ids"][0]:
result_documents = results["documents"][0]
result_metadatas = results["metadatas"][0]

for i in range(len(result_documents)):
doc = {
"page_content": result_documents[i],
"metadata": result_metadatas[i]
if result_metadatas and result_metadatas[i]
else {},
}
retrieved_docs.append(doc)

return VectorDBQueryOutputSchema(results=retrieved_docs)
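
A usage sketch tying `index` and `_arun` together. The document texts and metadata are toy values, and, as above, `_arun` is called directly only for illustration:

```python
import asyncio

from akd.tools.vector_db_tool import (
    VectorDBIndexInputSchema,
    VectorDBQueryInputSchema,
    VectorDBTool,
)


async def main() -> None:
    # Persists to VECTOR_DB_PATH (default: ./chroma_db under the project root).
    tool = VectorDBTool()

    # Index two toy documents; ids must be unique within the collection.
    tool.index(
        VectorDBIndexInputSchema(
            ids=["doc-1", "doc-2"],
            documents=[
                "Chroma is an open-source embedding database.",
                "Sentence-transformers models produce dense text embeddings.",
            ],
            metadatas=[{"source": "notes"}, {"source": "notes"}],
        ),
    )

    # Retrieve the top-2 matches for a semantic query.
    output = await tool._arun(
        VectorDBQueryInputSchema(query="embedding database", k=2),
    )
    for doc in output.results:
        print(doc["page_content"], doc["metadata"])


asyncio.run(main())
```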
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -44,6 +44,9 @@ dependencies = [
"ollama>=0.5.1",
"tiktoken>=0.9.0",
"rapidfuzz>=3.13.0",
"deepeval>=3.4.0",
"markdown>=3.8.2",
"chromadb>=1.0.13",
"undetected-chromedriver>=3.5.5",
"pypaperbot @ git+https://github.com/NISH1001/PyPaperBot.git@develop",
]
@@ -89,6 +92,9 @@ dev = [
"pytest>=8.0.0",
"pytest-asyncio>=1.0.0",
"pytest-cov>=6.0.0",
"pytest-mock>=3.14.1",
"ipykernel>=6.30.0",
"ipywidgets>=8.1.7",
"pre-commit>=4.2.0",
]
local = [
@@ -98,6 +104,7 @@ local = [
"ipywidgets>=8.1.7",
]
ml = [
"chromadb>=1.0.13",
"pandas>=2.3.1",
"sentence-transformers>=5.0.0",
"docling>=2.37.0",