Add wrapper tool for FactReasoner #79
Status: Open

jbrry wants to merge 42 commits into develop from feature/vector_database_tool.

Commits (42, all by jbrry):
- `727d5f3` Add `VectorDBTool` based on Chroma
- `1c75c1a` Merge branch 'develop' into feature/vector_database_tool
- `d85bf70` Adopt config strategy and add text splitter and vector db tools
- `63e5cbe` Update litsearch agent with new indexing capabilities
- `f1f6a76` Remove comment line
- `2e78325` Changes to standalone script for vectordb
- `2586621` Add test to run `SemanticScholarSearchTool`
- `21727e0` Merge changes from develop
- `25ccec3` Add test for text splitter tool
- `8d3e4b0` Merge branch 'develop' into feature/add_semantic_scholar_search_test
- `384c9d3` Test the `from_params` method of the class
- `231e6ff` Add tests for vector db and text splitter tools
- `c4a57c9` Fix chunking test
- `d947cb7` Leave config initialisation to super class
- `2220443` Add FactReasoner fact-check to deep research pipeline
- `7a2c13d` Add test for FactCheck tool
- `a492e6b` Merge branch 'feature/add_semantic_scholar_search_test' into feature/…
- `595cb38` Updated test file
- `8d9ed47` Wrap string in HttpUrl
- `fce0cfb` Update demo script with fact check
- `6c50cb7` Enable polling for long running processes
- `84cecff` Do not use text splitter tool
- `ae2f2a1` Remove text splitter tool
- `024005c` Get URL paramater from .env file
- `4efd005` Take parameters from env file
- `5b39e84` Make the URL an `HttpUrl`
- `e3f79c9` Add dependencies to pyproject.toml
- `e2988d0` Use updated pyproject.toml file
- `66e6d88` Fix merge conflict in pyproject.toml file
- `32b7dac` Update test for fact-check to correspond to new post and get approach
- `dadb338` Update pyproject, and get API and job timeouts from env
- `14d4be3` Set vector db path from util function
- `1d2627a` Remove dependency on Langchain Documents
- `ccc37f7` Update test for vector db tool
- `f7964b9` Merge from `develop` and fix conflict in pyproject.toml
- `0b81277` Update fact-check tool config to use endpoints given in .env file
- `8582873` Merge branch 'develop' into feature/vector_database_tool
- `03f0f7f` Fix conflict from develop
- `e7ef733` Fix conflict from develop
- `8dbdc07` Sync version of pyproject.toml from integration branch
- `2a50d76` Change vector db collection name to something more generic
- `39c8b8e` Merge branch 'develop' into feature/vector_database_tool
New file (+178 lines): the `FactCheckTool`, a wrapper around the remote FactReasoner service.

```python
import asyncio
import os
from typing import Any, Dict, List, Optional

import httpx
from loguru import logger
from pydantic import Field, HttpUrl

from akd._base import InputSchema, OutputSchema
from akd.tools._base import BaseTool, BaseToolConfig


class FactCheckInputSchema(InputSchema):
    """Input schema for the Fact-Checking Tool."""

    question: str = Field(..., description="The original question that was asked.")
    answer: str = Field(..., description="The LLM answer to be fact-checked.")


class FactCheckOutputSchema(OutputSchema):
    """Output schema for the Fact-Checking Tool's results."""

    fact_reasoner_score: Dict[str, Any] = Field(
        ...,
        description="The full scoring dictionary from the FactReasoner.",
    )
    supported_atoms: List[Dict[str, Any]] = Field(
        ...,
        description="List of atoms determined to be supported.",
    )
    not_supported_atoms: List[Dict[str, Any]] = Field(
        ...,
        description="List of atoms determined to be not supported.",
    )
    contexts: List[Dict[str, Any]] = Field(
        ...,
        description="List of retrieved contexts used for the check.",
    )
    graph_id: Optional[str] = Field(
        None,
        description="The unique ID for the generated fact graph.",
    )
    logging_metadata: Dict[str, Any] = Field(
        {},
        description="Additional logging metadata from the run.",
    )


class FactCheckToolConfig(BaseToolConfig):
    """Configuration for the FactCheckTool."""

    base_url: HttpUrl = Field(
        default=HttpUrl(os.getenv("FACT_CHECK_API_URL", "http://localhost:8011")),
        description="The base URL of the remote Fact-Checking and Correction Service.",
    )
    start_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_START_ENDPOINT", "/fact-check/start"),
        description="Endpoint to start a new fact-checking job.",
    )
    status_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_STATUS_ENDPOINT", "/fact-check/status/"),
        description="Endpoint to get the status of a job. Must end with a slash.",
    )
    correct_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_CORRECT_ENDPOINT", "/correct/"),
        description="Endpoint for single correction steps.",
    )
    display_graph_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_DISPLAY_GRAPH_ENDPOINT", "/display_graph/"),
        description="Endpoint to display a saved fact graph.",
    )
    graph_json_endpoint: str = Field(
        default=os.getenv("FACT_CHECK_GRAPH_JSON_ENDPOINT", "/graph/json/"),
        description="Endpoint to retrieve graph data as JSON.",
    )
    polling_interval_seconds: int = Field(
        default=120,
        description="How often to poll for job results.",
    )
    job_timeout_seconds: int = Field(
        default=int(os.getenv("FACT_CHECK_JOB_TIMEOUT", "1800")),
        description="Maximum time to wait for the entire job to complete (default 30 minutes).",
    )
    request_timeout_seconds: int = Field(
        default=int(os.getenv("FACT_CHECK_REQUEST_TIMEOUT", "60")),
        description="Timeout in seconds for each individual API request.",
    )


class FactCheckTool(
    BaseTool[FactCheckInputSchema, FactCheckOutputSchema],
):
    """A tool that calls an API to perform fact-checking on a given answer."""

    name = "fact_check_tool"
    description = (
        "Calls an API to run the FactReasoner pipeline on a question and answer."
    )
    input_schema = FactCheckInputSchema
    output_schema = FactCheckOutputSchema
    config_schema = FactCheckToolConfig

    def __init__(
        self,
        config: FactCheckToolConfig | None = None,
        debug: bool = False,
    ):
        """Initializes the FactCheckTool and its HTTP client."""
        config = config or FactCheckToolConfig()
        super().__init__(config, debug)

        logger.info("Initializing FactCheckTool...")
        # Set a timeout on the API requests.
        timeout = httpx.Timeout(self.config.request_timeout_seconds, connect=60.0)
        self.api_client = httpx.AsyncClient(
            base_url=str(self.config.base_url),
            timeout=timeout,
        )

    async def _arun(
        self,
        params: FactCheckInputSchema,
    ) -> FactCheckOutputSchema:
        """Starts a fact-checking job and polls for its completion."""
        logger.info(
            f"Sending fact-check request for question: '{params.question[:50]}...'",
        )

        try:
            # Start the job.
            start_response = await self.api_client.post(
                self.config.start_endpoint,
                json=params.model_dump(),
            )
            start_response.raise_for_status()
            job_id = start_response.json()["job_id"]
            logger.info(f"Successfully started job with ID: {job_id}")

            # Poll for the result. The status endpoint already ends with a
            # slash, so the job ID is appended directly (avoids a double slash).
            total_wait_time = 0
            while total_wait_time < self.config.job_timeout_seconds:
                logger.info(f"Polling status for job {job_id}...")
                status_response = await self.api_client.get(
                    f"{self.config.status_endpoint}{job_id}",
                )
                status_response.raise_for_status()
                status_data = status_response.json()

                if status_data["status"] == "completed":
                    logger.info(f"Job {job_id} completed successfully.")
                    return FactCheckOutputSchema(**status_data["result"])
                elif status_data["status"] == "failed":
                    raise Exception(
                        f"Job {job_id} failed on the server: {status_data.get('error', 'Unknown error')}",
                    )
                elif status_data["status"] == "pending":
                    logger.info(
                        f"Job {job_id} is in progress... (waited {total_wait_time}s)",
                    )
                    await asyncio.sleep(self.config.polling_interval_seconds)
                    total_wait_time += self.config.polling_interval_seconds

            raise asyncio.TimeoutError(
                f"Job {job_id} did not complete within the {self.config.job_timeout_seconds}s timeout.",
            )

        except httpx.HTTPStatusError as e:
            logger.error(
                f"HTTP error occurred while calling fact-check API: {e.response.status_code} - {e.response.text}",
            )
            raise
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
            raise
```
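Below is a minimal usage sketch for `FactCheckTool`. It assumes the FactReasoner service is reachable at the configured `base_url`, and that `BaseTool` exposes an async `arun` entry point that delegates to `_arun` (not shown in this diff); the module path in the import is hypothetical, since the diff does not show file names.

```python
import asyncio

# Hypothetical module path; adjust to wherever FactCheckTool lives in akd.
from akd.tools.fact_check import (
    FactCheckInputSchema,
    FactCheckTool,
    FactCheckToolConfig,
)


async def main() -> None:
    # Poll every 10 seconds instead of the default 120 for a quick demo.
    config = FactCheckToolConfig(polling_interval_seconds=10)
    tool = FactCheckTool(config=config)
    result = await tool.arun(
        FactCheckInputSchema(
            question="When was the Hubble Space Telescope launched?",
            answer="The Hubble Space Telescope was launched in 1990.",
        ),
    )
    print(result.fact_reasoner_score)
    print(
        f"{len(result.supported_atoms)} supported, "
        f"{len(result.not_supported_atoms)} not supported",
    )


if __name__ == "__main__":
    asyncio.run(main())
```

Because `_arun` only sleeps between polls, several fact-check jobs could be awaited concurrently (e.g. via `asyncio.gather`) over the same client.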
New file (+146 lines): the `VectorDBTool` based on Chroma.

```python
import os
from typing import Any, Dict, List, Optional

import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from loguru import logger
from pydantic import Field

from akd._base import InputSchema, OutputSchema
from akd.tools._base import BaseTool, BaseToolConfig
from akd.utils import get_akd_root


class VectorDBIndexInputSchema(InputSchema):
    """Input schema for indexing documents into the Vector Database."""

    ids: List[str] = Field(..., description="A unique list of document IDs.")
    documents: List[str] = Field(
        ...,
        description="A list of document contents to index.",
    )
    metadatas: Optional[List[Dict[str, Any]]] = Field(
        None,
        description="Optional list of metadata for each document.",
    )


class VectorDBQueryInputSchema(InputSchema):
    """Input schema for querying documents from the Vector Database."""

    query: str = Field(..., description="The query string for retrieval.")
    k: int = Field(3, description="Number of documents to retrieve.")


class VectorDBQueryOutputSchema(OutputSchema):
    """Output schema for the Vector Database tool's query results."""

    results: List[Dict[str, Any]] = Field(
        ...,
        description="List of retrieved documents, each as a dictionary with 'page_content' and 'metadata'.",
    )


class VectorDBToolConfig(BaseToolConfig):
    """Configuration for the VectorDBTool, loaded from environment variables."""

    embedding_model_name: str = Field(
        default="sentence-transformers/all-MiniLM-L6-v2",
        description="The name of the Hugging Face embedding model to use.",
    )
    embedding_model_api_key: Optional[str] = Field(
        default=os.getenv("EMBEDDING_MODEL_API_KEY", None),
        description="The API key for the embedding model provider, if required.",
    )
    db_path: str = Field(
        default=os.getenv("VECTOR_DB_PATH", str(get_akd_root() / "chroma_db")),
        description="Path to the persistent ChromaDB directory.",
    )
    collection_name: str = Field(
        default="akd_vdb",
        description="Name of the collection within ChromaDB.",
    )


class VectorDBTool(
    BaseTool[VectorDBQueryInputSchema, VectorDBQueryOutputSchema],
):
    """A tool for indexing and retrieving documents from a Chroma vector database."""

    name = "vector_db_tool"
    description = (
        "Indexes documents into a vector database and retrieves them based on a query."
    )
    input_schema = VectorDBQueryInputSchema
    output_schema = VectorDBQueryOutputSchema
    config_schema = VectorDBToolConfig

    def __init__(
        self,
        config: Optional[VectorDBToolConfig] = None,
        debug: bool = False,
    ):
        """Initializes the VectorDBTool and its ChromaDB client."""
        config = config or VectorDBToolConfig()
        super().__init__(config, debug)

        logger.info("Initializing VectorDBTool...")
        self.client = chromadb.PersistentClient(path=self.config.db_path)

        embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=self.config.embedding_model_name,
        )
        self.collection = self.client.get_or_create_collection(
            name=self.config.collection_name,
            embedding_function=embedding_function,
        )
        logger.info(
            f"Connected to ChromaDB collection '{self.config.collection_name}'.",
        )

    def index(self, params: VectorDBIndexInputSchema):
        """Adds documents to the vector database collection."""
        logger.info(f"Indexing {len(params.documents)} documents...")
        self.collection.add(
            ids=params.ids,
            documents=params.documents,
            metadatas=params.metadatas,
        )
        logger.info("Indexing complete.")

    async def _arun(
        self,
        params: VectorDBQueryInputSchema,
    ) -> VectorDBQueryOutputSchema:
        """Retrieves documents and returns them as a list of dictionaries."""
        logger.info(
            f"Querying collection with query: '{params.query}', retrieving top-{params.k} documents",
        )

        results = self.collection.query(
            query_texts=[params.query],
            n_results=params.k,
            include=["metadatas", "documents"],
        )

        # Chroma returns one result list per query text; take the first
        # (and only) entry for the single query issued above.
        retrieved_docs = []
        if results and results.get("ids") and results["ids"][0]:
            result_documents = results["documents"][0]
            result_metadatas = results["metadatas"][0]

            for i in range(len(result_documents)):
                doc = {
                    "page_content": result_documents[i],
                    "metadata": result_metadatas[i]
                    if result_metadatas and result_metadatas[i]
                    else {},
                }
                retrieved_docs.append(doc)

        return VectorDBQueryOutputSchema(results=retrieved_docs)
```
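And a similar sketch for `VectorDBTool`: index two short documents, then query them back. The module path is again hypothetical, and the same `arun` wrapper around `_arun` is assumed; note that `index` is synchronous in this diff, so it is called directly.

```python
import asyncio

# Hypothetical module path; adjust to wherever VectorDBTool lives in akd.
from akd.tools.vector_db import (
    VectorDBIndexInputSchema,
    VectorDBQueryInputSchema,
    VectorDBTool,
)


async def main() -> None:
    tool = VectorDBTool()  # uses VECTOR_DB_PATH or <akd root>/chroma_db

    # Index a couple of toy documents with metadata.
    tool.index(
        VectorDBIndexInputSchema(
            ids=["doc-1", "doc-2"],
            documents=[
                "Chroma is an open-source embedding database.",
                "Sentence-transformers produce dense text embeddings.",
            ],
            metadatas=[{"source": "notes"}, {"source": "notes"}],
        ),
    )

    # Retrieve the single closest document for a query.
    output = await tool.arun(
        VectorDBQueryInputSchema(query="embedding database", k=1),
    )
    for doc in output.results:
        print(doc["page_content"], doc["metadata"])


if __name__ == "__main__":
    asyncio.run(main())
```

Since the ChromaDB client is persistent, re-running this with the same IDs may warn or fail on duplicates, depending on the Chroma version; `collection.upsert` would be the idempotent alternative.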