From d7791f2b33389e07f6dabc83a0a925e47a42599d Mon Sep 17 00:00:00 2001
From: emrgnt-cmplxty
Date: Wed, 6 Nov 2024 12:02:28 -0800
Subject: [PATCH 01/21] complete simple tests, cleanup routers

---
 .../scripts/test_v3_sdk/test_v3_sdk_chunks.py | 72 ++++++++++++++++---
 .../test_v3_sdk/test_v3_sdk_collections.py    | 10 +--
 .../test_v3_sdk/test_v3_sdk_documents.py      |  5 +-
 .../scripts/test_v3_sdk/test_v3_sdk_users.py  | 39 +++++++---
 py/core/main/api/v3/chunks_router.py          |  2 +-
 py/core/main/api/v3/documents_router.py       |  2 +
 6 files changed, 105 insertions(+), 25 deletions(-)

diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_chunks.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_chunks.py
index 77fcccfcf..ca0846aee 100644
--- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_chunks.py
+++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_chunks.py
@@ -1,11 +1,40 @@
+import random
+import string
+
 from r2r import R2RClient
 
-second_ingested_document_id = "b4ac4dd6-5f28-596e-a55b-7cf242ca30aa"
-first_chunk_id = "b4ac4dd6-5f28-596e-a55b-7cf242ca30aa"
-user_email = "John.Doe1@email.com"
+first_created_chunk_id = "abcc4dd6-5f28-596e-a55b-7cf242ca30aa"
+second_created_chunk_id = "abcc4dd6-5f28-596e-a55b-7cf242ca30bb"
+created_document_id = "defc4dd6-5f28-596e-a55b-7cf242ca30aa"
+
+
+# Function to generate a random email
+def generate_random_email():
+    username_length = 8
+    username = "".join(
+        random.choices(
+            string.ascii_lowercase + string.digits, k=username_length
+        )
+    )
+    domain = random.choice(
+        ["example.com", "test.com", "fake.org", "random.net"]
+    )
+    return f"{username}@{domain}"
+
+
+user_email = generate_random_email()
 
 client = R2RClient("http://localhost:7276", prefix="/v3")
 
+# First create and authenticate a user if not already done
+try:
+    new_user = client.users.register(
+        email=user_email, password="new_secure_password123"
+    )
+    print("New user created:", new_user)
+except Exception as e:
+    print("User might already exist:", str(e))
+
 # Login
 result = client.users.login(
     email=user_email, password="new_secure_password123"
@@ -24,13 +53,13 @@
 )
 print("Chunks list:", list_result)
 
-# Test 2: Create chunk
-print("\n=== Test 2: Create Chunk ===")
+# Test 2: Create chunk and document
+print("\n=== Test 2: Create Chunk & Doc. ===")
 create_result = client.chunks.create(
     chunks=[
         {
-            "id": first_chunk_id,
-            "document_id": second_ingested_document_id,
+            "id": first_created_chunk_id,
+            "document_id": created_document_id,
             "collection_ids": ["b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"],
             "metadata": {"key": "value"},
             "text": "Some text content",
@@ -40,6 +69,25 @@
 )
 print("Created chunk:", create_result)
 
+# TODO - Update router and uncomment this test
+# TODO - Update router and uncomment this test
+# Test 3: Create chunk
+# print("\n=== Test 3: Create Chunk & Doc. ===")
+# create_result = client.chunks.create(
+#     chunks=[
+#         {
+#             "id": second_created_chunk_id,
+#             "document_id": created_document_id,
+#             "collection_ids": ["b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"],
+#             "metadata": {"key": "value"},
+#             "text": "Some text content",
+#         }
+#     ],
+#     run_with_orchestration=False,
+# )
+# print("Created chunk:", create_result)
+
+
 # Test 3: Search chunks
 print("\n=== Test 3: Search Chunks ===")
 search_result = client.chunks.search(query="whoami?")
@@ -47,16 +95,22 @@
 print("Search results:", search_result)
 
 # Test 4: Retrieve chunk
 print("\n=== Test 4: Retrieve Chunk ===")
-retrieve_result = client.chunks.retrieve(id=first_chunk_id)
+retrieve_result = client.chunks.retrieve(id=first_created_chunk_id)
 print("Retrieved chunk:", retrieve_result)
 
 # Test 5: Update chunk
 print("\n=== Test 5: Update Chunk ===")
 update_result = client.chunks.update(
     {
-        "id": first_chunk_id,
+        "id": first_created_chunk_id,
         "text": "Updated content",
         "metadata": {"key": "new value"},
     }
 )
 print("Updated chunk:", update_result)
+
+
+# Test 6: Retrieve updated chunk
+print("\n=== Test 6: Retrieve Updated Chunk ===")
+retrieve_result = client.chunks.retrieve(id=first_created_chunk_id)
+print("Retrieved updated chunk:", retrieve_result)
diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_collections.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_collections.py
index c1c1cdd16..296de8f16 100644
--- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_collections.py
+++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_collections.py
@@ -1,6 +1,5 @@
 from r2r import R2RClient
 
-first_ingested_document_id = "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
 user_email = "John.Doe1@email.com"
 
 client = R2RClient("http://localhost:7276", prefix="/v3")
@@ -20,7 +19,6 @@
 )
 print("Login successful")
 
-collection_id = "b773f631-eff1-4ad7-bc00-9f10e5e554b3"
 # # Test 1: Create a new collection
 print("\n=== Test 1: Create Collection ===")
 create_result = client.collections.create(
@@ -52,9 +50,13 @@
 print("Updated collection:", update_result)
 
 # # Test 5: Add document to collection
+# list user documents
+documents = client.documents.list(limit=10, offset=0)
+print(documents)
+
 print("\n=== Test 5: Add Document to Collection ===")
 add_doc_result = client.collections.add_document(
-    id=collection_id, document_id=first_ingested_document_id
+    id=collection_id, document_id=documents["results"][0]["id"]
 )
 print("Added document to collection:", add_doc_result)
 
@@ -75,7 +77,7 @@
 # Test 8: Remove document from collection
 print("\n=== Test 8: Remove Document from Collection ===")
 remove_doc_result = client.collections.remove_document(
-    id=collection_id, document_id=first_ingested_document_id
+    id=collection_id, document_id=documents["results"][0]["id"]
 )
 print("Removed document from collection:", remove_doc_result)
diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_documents.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_documents.py
index 06c1bae47..85c91baec 100644
--- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_documents.py
+++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_documents.py
@@ -1,7 +1,7 @@
 from r2r import R2RClient
 
-first_ingested_document_id = "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa"
-first_ingested_file_path = "core/examples/data/pg_essay_1.html"
+first_ingested_document_id = "1b594aea-583a-5a4b-92f4-229d6e5eb886"
+first_ingested_file_path = "../../data/pg_essay_1.html"
 user_email = "John.Doe1@email.com"
 
 client = R2RClient("http://localhost:7276", prefix="/v3")
@@ -54,6 +54,7 @@
 print("Document chunks:",
chunks_result) # Test 6: List document collections +client.users.logout() print("\n=== Test 6: List Document Collections ===") collections_result = client.documents.list_collections( id=first_ingested_document_id, offset=0, limit=10 diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_users.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_users.py index dbf142ded..8f6615e10 100644 --- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_users.py +++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_users.py @@ -1,6 +1,26 @@ +import random +import string + from r2r import R2RClient -user_email = "John.Doe1@email.com" +# user_email = "John.Doe1@email.com" + + +# Function to generate a random email +def generate_random_email(): + username_length = 8 + username = "".join( + random.choices( + string.ascii_lowercase + string.digits, k=username_length + ) + ) + domain = random.choice( + ["example.com", "test.com", "fake.org", "random.net"] + ) + return f"{username}@{domain}" + + +user_email = generate_random_email() client = R2RClient("http://localhost:7276", prefix="/v3") @@ -35,13 +55,19 @@ reset_request_result = client.users.request_password_reset(email=user_email) print("Password reset request result:", reset_request_result) +# logout, to use super user +# Test 9: Logout user +print("\n=== Test 6: Logout User ===") +logout_result = client.users.logout() +print("Logout result:", logout_result) + # Test 6: List users -print("\n=== Test 6: List Users ===") +print("\n=== Test 7: List Users ===") users_list = client.users.list() print("Users list:", users_list) # Test 7: Retrieve user -print("\n=== Test 7: Retrieve User ===") +print("\n=== Test 8: Retrieve User ===") user_id = users_list["results"][0][ "user_id" ] # Assuming we have at least one user @@ -49,11 +75,6 @@ print("User details:", user_details) # Test 8: Update user -print("\n=== Test 8: Update User ===") +print("\n=== Test 9: Update User ===") update_result = client.users.update(user_id, name="Jane Doe") print("Update user result:", update_result) - -# Test 9: Logout user -print("\n=== Test 9: Logout User ===") -logout_result = client.users.logout() -print("Logout result:", logout_result) diff --git a/py/core/main/api/v3/chunks_router.py b/py/core/main/api/v3/chunks_router.py index c95759a08..7708d0968 100644 --- a/py/core/main/api/v3/chunks_router.py +++ b/py/core/main/api/v3/chunks_router.py @@ -365,7 +365,7 @@ async def retrieve_chunk( # document = await self.services["management"].get_document(chunk.document_id) # TODO - Add collection ID check if not auth_user.is_superuser and str(auth_user.id) != str( - chunk.user_id + chunk["user_id"] ): raise R2RException("Not authorized to access this chunk", 403) diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index f082ee171..cc33d35f1 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -970,6 +970,8 @@ async def get_document_collections( The results are paginated and ordered by collection creation date, with the most recently created collections appearing first. + + NOTE - This endpoint is only available to superusers, it will be extended to regular users in a future release. 
""" if not auth_user.is_superuser: raise R2RException( From 370d947b089e13d3210af0358a8a3537c998696b Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Wed, 6 Nov 2024 17:52:24 -0800 Subject: [PATCH 02/21] up --- .../scripts/test_v3_sdk/test_v3_sdk_graph.py | 202 ++++++++ py/core/main/api/v2/kg_router.py | 5 +- py/core/main/api/v3/collections_router.py | 1 - py/core/main/api/v3/graph_router.py | 481 ++++++++++-------- py/core/main/services/kg_service.py | 54 ++ py/core/pipes/kg/community_summary.py | 2 +- py/core/pipes/kg/entity_description.py | 2 +- py/core/pipes/kg/prompt_tuning.py | 4 +- py/core/pipes/kg/triples_extraction.py | 9 +- .../database/prompts/prompt_tuning.yaml | 1 + py/sdk/v3/graphs.py | 37 +- py/shared/abstractions/kg.py | 8 +- 12 files changed, 569 insertions(+), 237 deletions(-) create mode 100644 py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py new file mode 100644 index 000000000..0f7561a1e --- /dev/null +++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py @@ -0,0 +1,202 @@ +from r2r import R2RClient +import uuid +import time + +# Initialize client +client = R2RClient("http://localhost:7276", prefix="/v3") + +def setup_prerequisites(): + """Setup necessary document and collection""" + print("\n=== Setting up prerequisites ===") + + # # Login + # try: + # client.users.register(email=user_email, password="new_secure_password123") + # except Exception as e: + # print("User might already exist:", str(e)) + + # result = client.users.login(email=user_email, password="new_secure_password123") + # print("Login successful") + + + try: + # Create document + doc_result = client.documents.create( + file_path="../../data/pg_essay_1.html", + metadata={"source": "test"}, + run_with_orchestration=False + ) + print('doc_id = ', doc_result) + doc_id = doc_result['results']['document_id'] + print(f"Created document with ID: {doc_id}") + except Exception as e: + doc_id = "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" + pass + + # Create collection + # collection_id = str(uuid.uuid4()) + collection_result = client.collections.create( + # collection_id=collection_id, + name="Test Collection", + description="Collection for testing graph operations" + ) + print("Created collection with ID: " + str(collection_result["results"]["collection_id"])) + collection_id = collection_result["results"]["collection_id"] + # Add document to collection + client.collections.add_document( + id=collection_id, + document_id=doc_id + ) + print(f"Added document {doc_id} to collection {collection_id}") + + return collection_id, doc_id + +def test_graph_operations(collection_id): + """Test graph CRUD operations""" + print("\n=== Testing Graph Operations ===") + + # Test 1: Create Graph + print("\n--- Test 1: Create Graph ---") + create_result = client.graphs.create( + collection_id=collection_id, + settings={ + "entity_types": ["PERSON", "ORG", "GPE"], + "min_confidence": 0.8 + }, + run_type="estimate", + run_with_orchestration=False + ) + print("Graph estimation result:", create_result) + + create_result = client.graphs.create( + collection_id=collection_id, + settings={ + "entity_types": ["PERSON", "ORG", "GPE"], + "min_confidence": 0.8 + }, + run_type="run", + run_with_orchestration=False + ) + print("Graph creation result:", create_result) + + # # # Test 2: Get Graph Status + # # print("\n--- Test 2: Get Graph Status ---") + # # status_result = 
client.graphs.get_status(collection_id=collection_id) + # # print("Graph status:", status_result) + + # Test 3: List Entities + print("\n--- Test 3: List Entities ---") + entities_result = client.graphs.list_entities( + collection_id=collection_id, + # level="collection", + offset=0, + limit=10 + ) + print("Entities:", entities_result) + + # Test 4: Get Specific Entity + print('entities_result["results"]["entities"][0] = ', entities_result["results"]["entities"][0]) + entity_id = entities_result["results"]["entities"][0]["id"] #entities_result['items'][0]['id'] + print('entity_id = ', entity_id) + print(f"\n--- Test 4: Get Entity {entity_id} ---") + entity_result = client.graphs.get_entity( + collection_id=collection_id, + entity_id=entity_id + ) + print("Entity details:", entity_result) + + # # # # Test 5: List Relationships + # # # print("\n--- Test 5: List Relationships ---") + # # relationships_result = client.graphs.list_relationships( + # # collection_id=collection_id, + # # offset=0, + # # limit=10 + # # ) + # # print("Relationships:", relationships_result) + + # Test 6: Create Communities + print("\n--- Test 6: Create Communities ---") + communities_result = client.graphs.create_communities( + run_type="estimate", + collection_id=collection_id, + run_with_orchestration=False + # settings={ + # "algorithm": "louvain", + # "resolution": 1.0, + # "min_community_size": 3 + # } + ) + print("Communities estimation result:", communities_result) + + communities_result = client.graphs.create_communities( + run_type="run", + collection_id=collection_id, + run_with_orchestration=False + # settings={ + # "algorithm": "louvain", + # "resolution": 1.0, + # "min_community_size": 3 + # } + ) + print("Communities creation result:", communities_result) + + # Wait for community creation to complete + + # Test 7: List Communities + print("\n--- Test 7: List Communities ---") + communities_list = client.graphs.list_communities( + collection_id=collection_id, + offset=0, + limit=10 + ) + print("Communities:", communities_list) + + # Test 8: Tune Prompt + print("\n--- Test 8: Tune Prompt ---") + tune_result = client.graphs.tune_prompt( + collection_id=collection_id, + prompt_name="graphrag_triples_extraction_few_shot", + documents_limit=100, + chunks_limit=1000 + ) + print("Prompt tuning result:", tune_result) + + # Test 9: Entity Deduplication + print("\n--- Test 9: Entity Deduplication ---") + dedup_result = client.graphs.deduplicate_entities( + collection_id=collection_id, + settings={ + "kg_entity_deduplication_type": "by_name", + "max_description_input_length": 65536 + } + ) + print("Deduplication result:", dedup_result) + + # Optional: Clean up + # Test 10: Delete Graph + print("\n--- Test 10: Delete Graph ---") + delete_result = client.graphs.delete( + collection_id=collection_id, + cascade=True + ) + print("Graph deletion result:", delete_result) + +def main(): + try: + # Setup prerequisites + # collection_id, doc_id = setup_prerequisites() + collection_id = "42e0efa8-ab92-49e8-ae5b-84215876a632" + + # Run graph operations tests + test_graph_operations(collection_id) + + except Exception as e: + print(f"Error occurred: {str(e)}") + finally: + pass + # Cleanup: Logout + # client.users.logout() + # print("\nLogged out successfully") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 55a99855d..9f45f7fd7 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -17,6 
+17,8 @@ WrappedKGTriplesResponse, WrappedKGTunePromptResponse, ) + + from core.base.logger.base import RunType from core.providers import ( HatchetOrchestrationProvider, @@ -287,6 +289,7 @@ async def get_entities( limit, ) + @self.router.get("/triples") @self.base_endpoint async def get_triples( @@ -432,7 +435,7 @@ async def deduplicate_entities( async def get_tuned_prompt( prompt_name: str = Query( ..., - description="The name of the prompt to tune. Valid options are 'kg_triples_extraction_prompt', 'kg_entity_description_prompt' and 'community_reports_prompt'.", + description="The name of the prompt to tune. Valid options are 'graphrag_triples_extraction_few_shot', 'graphrag_entity_description' and 'graphrag_community_reports'.", ), collection_id: Optional[UUID] = Query( None, description="Collection ID to retrieve communities from." diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 4fcb7e555..0e6c2b3e1 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -103,7 +103,6 @@ async def create_collection( await self.services["management"].add_user_to_collection( auth_user.id, collection.collection_id ) - print("collection = ", collection) return collection @self.router.get( diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 48e2fb09a..96c803162 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -6,14 +6,15 @@ from fastapi import Body, Depends, Path, Query from pydantic import BaseModel, Field, Json -from core.base import R2RException, RunType -from core.base.abstractions import EntityLevel, KGRunType +from core.base import R2RException, RunType, KGCreationSettings +from core.base.abstractions import EntityLevel, KGRunType, Entity from core.base.api.models import ( - WrappedKGCommunitiesResponse, WrappedKGCreationResponse, WrappedKGEnrichmentResponse, WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, + WrappedKGEntitiesResponse, + WrappedKGCommunitiesResponse ) from core.providers import ( HatchetOrchestrationProvider, @@ -30,29 +31,29 @@ logger = logging.getLogger() -class Entity(BaseModel): - """Model representing a graph entity.""" +# class Entity(BaseModel): +# """Model representing a graph entity.""" - id: UUID - name: str - type: str - metadata: dict = Field(default_factory=dict) - level: EntityLevel - collection_ids: list[UUID] - embedding: Optional[list[float]] = None +# id: UUID +# name: str +# type: str +# metadata: dict = Field(default_factory=dict) +# level: EntityLevel +# collection_ids: list[UUID] +# embedding: Optional[list[float]] = None - class Config: - json_schema_extra = { - "example": { - "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - "name": "John Smith", - "type": "PERSON", - "metadata": {"confidence": 0.95}, - "level": "DOCUMENT", - "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], - "embedding": [0.1, 0.2, 0.3], - } - } +# class Config: +# json_schema_extra = { +# "example": { +# "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", +# "name": "John Smith", +# "type": "PERSON", +# "metadata": {"confidence": 0.95}, +# "level": "DOCUMENT", +# "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], +# "embedding": [0.1, 0.2, 0.3], +# } +# } class Relationship(BaseModel): @@ -132,9 +133,7 @@ def _setup_routes(self): result = client.graphs.create( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", settings={ - "extraction_method": "spacy", - "entity_types": 
["PERSON", "ORG", "GPE"], - "min_confidence": 0.8 + "entity_types": ["PERSON", "ORG", "GPE"] } )""" ), @@ -148,8 +147,7 @@ def _setup_routes(self): -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ "settings": { - "entity_types": ["PERSON", "ORG", "GPE"], - "min_confidence": 0.8 + "entity_types": ["PERSON", "ORG", "GPE"] } }'""" ), @@ -160,14 +158,20 @@ def _setup_routes(self): @self.base_endpoint async def create_graph( collection_id: UUID = Path( - ..., description="Collection ID to create graph for" + default=..., + description="Collection ID to create graph for.", ), - settings: Optional[dict] = Body( - None, description="Graph creation settings" + run_type: Optional[KGRunType] = Body( + default=None, + description="Run type for the graph creation process.", ), - run_with_orchestration: bool = Query(True), + settings: Optional[KGCreationSettings] = Body( + default=None, + description="Settings for the graph creation process.", + ), + run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[WrappedKGCreationResponse]: + ) -> WrappedKGCreationResponse: """Creates a new knowledge graph by extracting entities and relationships from documents in a collection. The graph creation process involves: @@ -176,36 +180,60 @@ async def create_graph( 3. Building a connected knowledge graph structure """ + settings = settings.dict() if settings else None if not auth_user.is_superuser: - raise R2RException("Only superusers can create graphs", 403) + logger.warning("Implement permission checks here.") - server_settings = ( + logger.info(f"Running create-graph on collection {collection_id}") + + # If no collection ID is provided, use the default user collection + if not collection_id: + collection_id = generate_default_user_collection_id( + auth_user.id + ) + + # If no run type is provided, default to estimate + if not run_type: + run_type = KGRunType.ESTIMATE + + # Apply runtime settings overrides + server_kg_creation_settings = ( self.providers.database.config.kg_creation_settings ) + if settings: - server_settings = update_settings_from_dict( - server_settings, settings + server_kg_creation_settings = update_settings_from_dict( + server_kg_creation_settings, settings ) - workflow_input = { - "collection_id": str(collection_id), - "kg_creation_settings": server_settings.model_dump_json(), - "user": auth_user.model_dump_json(), - } - - if run_with_orchestration: - return await self.orchestration_provider.run_workflow( - "create-graph", {"request": workflow_input}, {} + # If the run type is estimate, return an estimate of the creation cost + if run_type is KGRunType.ESTIMATE: + return await self.services["kg"].get_creation_estimate( + collection_id, server_kg_creation_settings ) else: - from core.main.orchestration import simple_kg_factory - simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["create-graph"](workflow_input) - return { # type: ignore - "message": "Graph created successfully.", - "task_id": None, - } + # Otherwise, create the graph + if run_with_orchestration: + workflow_input = { + "collection_id": str(collection_id), + "kg_creation_settings": server_kg_creation_settings.model_dump_json(), + "user": auth_user.json(), + } + + return await self.orchestration_provider.run_workflow( # type: ignore + "create-graph", {"request": workflow_input}, {} + ) + else: + from core.main.orchestration import simple_kg_factory + + logger.info("Running create-graph without orchestration.") + simple_kg = 
simple_kg_factory(self.service) + await simple_kg["create-graph"](workflow_input) + return { + "message": "Graph created successfully.", + "task_id": None, + } @self.router.get( "/graphs/{collection_id}", @@ -252,13 +280,14 @@ async def get_graph_status( - Community statistics - Current settings """ - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can view graph status", 403 - ) + raise NotImplementedError("Not implemented", 501) + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can view graph status", 403 + # ) - status = await self.services["kg"].get_graph_status(collection_id) - return status # type: ignore + # status = await self.services["kg"].get_graph_status(collection_id) + # return status # type: ignore # @self.router.post( # "/graphs/{collection_id}/enrich", @@ -439,13 +468,14 @@ async def create_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Entity]: """Creates a new entity in the graph.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can create entities", 403) + raise NotImplementedError("Not implemented", 501) + # if not auth_user.is_superuser: + # raise R2RException("Only superusers can create entities", 403) - new_entity = await self.services["kg"].create_entity( - collection_id, entity - ) - return new_entity # type: ignore + # new_entity = await self.services["kg"].create_entity( + # collection_id, entity + # ) + # return new_entity # type: ignore @self.router.delete( "/graphs/{collection_id}/entities/{entity_id}", @@ -490,13 +520,14 @@ async def delete_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[dict]: """Deletes an entity and optionally its relationships.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can delete entities", 403) + raise NotImplementedError("Not implemented", 501) + # if not auth_user.is_superuser: + # raise R2RException("Only superusers can delete entities", 403) - await self.services["kg"].delete_entity( - collection_id, entity_id, cascade - ) - return {"message": "Entity deleted successfully"} # type: ignore + # await self.services["kg"].delete_entity( + # collection_id, entity_id, cascade + # ) + # return {"message": "Entity deleted successfully"} # type: ignore @self.router.get( "/graphs/{collection_id}/entities", @@ -539,9 +570,9 @@ async def list_entities( level: EntityLevel = Query(EntityLevel.DOCUMENT), offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), - include_embeddings: bool = Query(False), + # include_embeddings: bool = Query(False), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedKGEntitiesResponse: # PaginatedResultsWrapper[list[Entity]]: """Lists entities in the graph with filtering and pagination support. Entities represent the nodes in the knowledge graph, extracted from documents. 
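
Note: the reworked entity endpoints above are driven from the v3 SDK in
test_v3_sdk_graph.py; as a minimal sketch, assuming a placeholder collection
UUID and the wrapped response shape these tests rely on:

from r2r import R2RClient

client = R2RClient("http://localhost:7276", prefix="/v3")
collection_id = "d09dedb1-b2ab-48a5-b950-6e1f464d83e7"  # placeholder UUID

# Defaults to document-level entities; the router maps the level onto the
# chunk_entity / document_entity / collection_entity tables shown above.
entities = client.graphs.list_entities(
    collection_id=collection_id, offset=0, limit=10
)
entity_id = entities["results"]["entities"][0]["id"]
entity = client.graphs.get_entity(
    collection_id=collection_id, entity_id=entity_id
)
print("Entity details:", entity)
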
@@ -553,10 +584,16 @@ async def list_entities( - Community memberships - Optional vector embedding """ - entities = await self.services["kg"].list_entities( - collection_id, level, offset, limit, include_embeddings + if level == EntityLevel.CHUNK: + entity_table_name = "chunk_entity" + elif level == EntityLevel.DOCUMENT: + entity_table_name = "document_entity" + else: + entity_table_name = "collection_entity" + + return await self.services["kg"].list_entities( + collection_id=collection_id, entity_ids=[], entity_table_name=entity_table_name, offset=offset, limit=limit ) - return entities # type: ignore @self.router.get( "/graphs/{collection_id}/entities/{entity_id}", @@ -565,17 +602,24 @@ async def list_entities( @self.base_endpoint async def get_entity( collection_id: UUID = Path(...), - entity_id: UUID = Path(...), - include_embeddings: bool = Query(False), + level: EntityLevel = Query(EntityLevel.DOCUMENT), + entity_id: int = Path(...), + # include_embeddings: bool = Query(False), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Entity]: """Retrieves details of a specific entity.""" - entity = await self.services["kg"].get_entity( - collection_id, entity_id, include_embeddings - ) - if not entity: - raise R2RException("Entity not found", 404) - return entity + + if level == EntityLevel.CHUNK: + entity_table_name = "chunk_entity" + elif level == EntityLevel.DOCUMENT: + entity_table_name = "document_entity" + else: + entity_table_name = "collection_entity" + + result = (await self.services["kg"].list_entities( + collection_id=collection_id, entity_ids=[entity_id], entity_table_name=entity_table_name # , offset=offset, limit=limit + )) + return result['entities'][0] # type: ignore @self.router.post( "/graphs/{collection_id}/entities/{entity_id}", @@ -631,13 +675,14 @@ async def update_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Entity]: """Updates an existing entity.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can update entities", 403) + raise NotImplementedError("Not implemented", 501) + # if not auth_user.is_superuser: + # raise R2RException("Only superusers can update entities", 403) - updated_entity = await self.services["kg"].update_entity( - collection_id, entity_id, entity_update - ) - return updated_entity # type: ignore + # updated_entity = await self.services["kg"].update_entity( + # collection_id, entity_id, entity_update + # ) + # return updated_entity # type: ignore @self.router.post( "/graphs/{collection_id}/entities/deduplicate", @@ -823,15 +868,16 @@ async def create_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Relationship]: """Creates a new relationship between entities.""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can create relationships", 403 - ) + raise NotImplementedError("Not implemented", 501) + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can create relationships", 403 + # ) - new_relationship = await self.services["kg"].create_relationship( - collection_id, relationship - ) - return new_relationship # type: ignore + # new_relationship = await self.services["kg"].create_relationship( + # collection_id, relationship + # ) + # return new_relationship # type: ignore # Relationship operations @self.router.get( @@ -888,16 +934,16 @@ async def list_relationships( - Confidence score and metadata - Source documents and extractions """ - - relationships = await 
self.services["kg"].list_relationships( - collection_id, - source_id, - target_id, - relationship_type, - offset, - limit, - ) - return relationships # type: ignore + raise R2RException("Not implemented", 501) + # relationships = await self.services["kg"].list_relationships( + # collection_id, + # source_id, + # target_id, + # relationship_type, + # offset, + # limit, + # ) + # return relationships # type: ignore @self.router.get( "/graphs/{collection_id}/relationships/{relationship_id}", @@ -910,12 +956,13 @@ async def get_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Relationship]: """Retrieves details of a specific relationship.""" - relationship = await self.services["kg"].get_relationship( - collection_id, relationship_id - ) - if not relationship: - raise R2RException("Relationship not found", 404) - return relationship # type: ignore + raise R2RException("Not implemented", 501) + # relationship = await self.services["kg"].get_relationship( + # collection_id, relationship_id + # ) + # if not relationship: + # raise R2RException("Relationship not found", 404) + # return relationship # type: ignore @self.router.post( "/graphs/{collection_id}/relationships/{relationship_id}", @@ -971,17 +1018,18 @@ async def update_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Relationship]: """Updates an existing relationship.""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can update relationships", 403 - ) - - updated_relationship = await self.services[ - "kg" - ].update_relationship( - collection_id, relationship_id, relationship_update - ) - return updated_relationship # type: ignore + raise NotImplementedError("Not implemented") + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can update relationships", 403 + # ) + + # updated_relationship = await self.services[ + # "kg" + # ].update_relationship( + # collection_id, relationship_id, relationship_update + # ) + # return updated_relationship # type: ignore @self.router.delete( "/graphs/{collection_id}/relationships/{relationship_id}", @@ -1021,15 +1069,16 @@ async def delete_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[dict]: """Deletes a relationship.""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can delete relationships", 403 - ) + raise NotImplementedError("Not implemented") + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can delete relationships", 403 + # ) - await self.services["kg"].delete_relationship( - collection_id, relationship_id - ) - return {"message": "Relationship deleted successfully"} # type: ignore + # await self.services["kg"].delete_relationship( + # collection_id, relationship_id + # ) + # return {"message": "Relationship deleted successfully"} # type: ignore # Community operations @self.router.post( @@ -1049,11 +1098,7 @@ async def delete_relationship( result = client.graphs.create_communities( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", settings={ - "algorithm": "louvain", - "resolution": 1.0, - "min_community_size": 3, - "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", - "similarity_threshold": 0.7 + "max_summary_input_length": 65536, } )""" ), @@ -1067,11 +1112,7 @@ async def delete_relationship( -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ "settings": { - "algorithm": "louvain", - "resolution": 1.0, - "min_community_size": 3, - "embedding_model": 
"sentence-transformers/all-MiniLM-L6-v2", - "similarity_threshold": 0.7 + "max_summary_input_length": 65536, } }'""" ), @@ -1083,9 +1124,13 @@ async def delete_relationship( async def create_communities( collection_id: UUID = Path(...), settings: Optional[dict] = Body(None), + run_type: Optional[KGRunType] = Body( + default=None, + description="Run type for the graph creation process.", + ), run_with_orchestration: bool = Query(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[WrappedKGCommunitiesResponse]: + ) -> WrappedKGEnrichmentResponse: """Creates communities in the graph by analyzing entity relationships and similarities. Communities are created by: @@ -1099,33 +1144,44 @@ async def create_communities( "Only superusers can create communities", 403 ) - server_settings = ( - self.providers.database.config.community_detection_settings + # Apply runtime settings overrides + server_kg_enrichment_settings = ( + self.providers.database.config.kg_enrichment_settings ) if settings: - server_settings = update_settings_from_dict( - server_settings, settings + server_kg_enrichment_settings = update_settings_from_dict( + server_kg_enrichment_settings, settings ) workflow_input = { "collection_id": str(collection_id), - "community_detection_settings": server_settings.model_dump_json(), + "kg_enrichment_settings": server_kg_enrichment_settings.model_dump_json(), "user": auth_user.model_dump_json(), } - if run_with_orchestration: - return await self.orchestration_provider.run_workflow( # type: ignore - "create-communities", {"request": workflow_input}, {} + if not run_type: + run_type = KGRunType.ESTIMATE + + # If the run type is estimate, return an estimate of the enrichment cost + if run_type is KGRunType.ESTIMATE: + return await self.services["kg"].get_enrichment_estimate( + collection_id, server_kg_enrichment_settings ) - else: - from core.main.orchestration import simple_kg_factory - simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["create-communities"](workflow_input) - return { # type: ignore - "message": "Communities created successfully.", - "task_id": None, - } + else: + if run_with_orchestration: + return await self.orchestration_provider.run_workflow( # type: ignore + "enrich-graph", {"request": workflow_input}, {} + ) + else: + from core.main.orchestration import simple_kg_factory + + simple_kg = simple_kg_factory(self.services["kg"]) + await simple_kg["enrich-graph"](workflow_input) + return { # type: ignore + "message": "Communities created successfully.", + "task_id": None, + } @self.router.post( "/graphs/{collection_id}/communities/{community_id}", @@ -1179,15 +1235,16 @@ async def update_community( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Community]: """Updates a community's metadata.""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can update communities", 403 - ) + raise NotImplementedError("Not implemented") + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can update communities", 403 + # ) - updated_community = await self.services["kg"].update_community( - collection_id, community_id, community_update - ) - return updated_community # type: ignore + # updated_community = await self.services["kg"].update_community( + # collection_id, community_id, community_update + # ) + # return updated_community # type: ignore @self.router.get( "/graphs/{collection_id}/communities", @@ -1226,11 +1283,16 @@ async def update_community( 
@self.base_endpoint async def list_communities( collection_id: UUID = Path(...), - level: Optional[int] = Query(None), + community_numbers: Optional[list[int]] = Query( + None, description="Community numbers to filter by." + ), + levels: Optional[list[int]] = Query( + None, description="Levels to filter by." + ), offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Community]]: + ) -> WrappedKGCommunitiesResponse: # PaginatedResultsWrapper[list[Community]]: """Lists communities in the graph with optional filtering and pagination. Each community represents a group of related entities with: @@ -1241,7 +1303,7 @@ async def list_communities( - Impact rating and explanation """ communities = await self.services["kg"].list_communities( - collection_id, level, offset, limit + collection_id, levels, community_numbers, offset, limit ) return communities # type: ignore @@ -1256,12 +1318,13 @@ async def get_community( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Community]: """Retrieves details of a specific community.""" - community = await self.services["kg"].get_community( - collection_id, community_id - ) - if not community: - raise R2RException("Community not found", 404) - return community # type: ignore + raise NotImplementedError("Not implemented") + # community = await self.services["kg"].get_community( + # collection_id, community_id + # ) + # if not community: + # raise R2RException("Community not found", 404) + # return community # type: ignore @self.router.delete( "/graphs/{collection_id}/communities", @@ -1318,18 +1381,19 @@ async def delete_communities( Deletes communities from the graph. Can delete all communities or a specific level. This is useful when you want to recreate communities with different parameters. """ - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can delete communities", 403 - ) + raise NotImplementedError("Not implemented") + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can delete communities", 403 + # ) - await self.services["kg"].delete_communities(collection_id, level) + # await self.services["kg"].delete_communities(collection_id, level) - if level is not None: - return { # type: ignore - "message": f"Communities at level {level} deleted successfully" - } - return {"message": "All communities deleted successfully"} # type: ignore + # if level is not None: + # return { # type: ignore + # "message": f"Communities at level {level} deleted successfully" + # } + # return {"message": "All communities deleted successfully"} # type: ignore @self.router.delete( "/graphs/{collection_id}/communities/{community_id}", @@ -1372,22 +1436,23 @@ async def delete_community( Deletes a specific community by ID. This operation will not affect other communities or the underlying entities. 
""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can delete communities", 403 - ) - - # First check if community exists - community = await self.services["kg"].get_community( - collection_id, community_id - ) - if not community: - raise R2RException("Community not found", 404) - - await self.services["kg"].delete_community( - collection_id, community_id - ) - return True # type: ignore + raise NotImplementedError("Not implemented") + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can delete communities", 403 + # ) + + # # First check if community exists + # community = await self.services["kg"].get_community( + # collection_id, community_id + # ) + # if not community: + # raise R2RException("Community not found", 404) + + # await self.services["kg"].delete_community( + # collection_id, community_id + # ) + # return True # type: ignore @self.router.post( "/graphs/{collection_id}/tune-prompt", @@ -1405,7 +1470,7 @@ async def delete_community( result = client.graphs.tune_prompt( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - prompt_name="kg_triples_extraction_prompt", + prompt_name="graphrag_triples_extraction_few_shot", documents_limit=100, chunks_limit=1000 )""" @@ -1419,7 +1484,7 @@ async def delete_community( -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ - "prompt_name": "kg_triples_extraction_prompt", + "prompt_name": "graphrag_triples_extraction_few_shot", "documents_limit": 100, "chunks_limit": 1000 }'""" @@ -1433,14 +1498,14 @@ async def tune_prompt( collection_id: UUID = Path(...), prompt_name: str = Body( ..., - description="The prompt to tune. Valid options: kg_triples_extraction_prompt, kg_entity_description_prompt, community_reports_prompt", + description="The prompt to tune. Valid options: graphrag_triples_extraction_few_shot, graphrag_entity_description, graphrag_community_reports", ), documents_offset: int = Body(0, ge=0), documents_limit: int = Body(100, ge=1), chunks_offset: int = Body(0, ge=0), chunks_limit: int = Body(100, ge=1), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[WrappedKGTunePromptResponse]: + ) -> WrappedKGTunePromptResponse: """Tunes a graph operation prompt using collection data. 
Uses sample documents and chunks from the collection to tune prompts for: diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 3ad5ccb0c..4fd840201 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -316,6 +316,24 @@ async def get_enrichment_estimate( collection_id, kg_enrichment_settings ) + @telemetry_event("list_entities") + async def list_entities( + self, + collection_id: Optional[UUID] = None, + entity_ids: Optional[list[str]] = None, + entity_table_name: str = "document_entity", + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_entities( + collection_id=collection_id, + entity_ids=entity_ids, + entity_table_name=entity_table_name, + offset=offset or 0, + limit=limit or -1, + ) + @telemetry_event("get_entities") async def get_entities( self, @@ -352,6 +370,23 @@ async def get_triples( limit=limit or -1, ) + @telemetry_event("list_triples") + async def list_triples( + self, + collection_id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + triple_ids: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_triples( + collection_id=collection_id, + entity_names=entity_names, + triple_ids=triple_ids, + offset=offset or 0, + limit=limit or -1, + ) @telemetry_event("get_communities") async def get_communities( self, @@ -370,6 +405,25 @@ async def get_communities( limit=limit or -1, ) + @telemetry_event("list_communities") + async def list_communities( + self, + collection_id: Optional[UUID] = None, + levels: Optional[list[int]] = None, + community_numbers: Optional[list[int]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_communities( + collection_id=collection_id, + levels=levels, + community_numbers=community_numbers, + offset=offset or 0, + limit=limit or -1, + ) + + @telemetry_event("get_deduplication_estimate") async def get_deduplication_estimate( self, diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 0afb3050d..81ceccc0b 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -155,7 +155,7 @@ async def process_community( ( await self.llm_provider.aget_completion( messages=await self.database_provider.prompt_handler.get_message_payload( - task_prompt_name=self.database_provider.config.kg_enrichment_settings.community_reports_prompt, + task_prompt_name=self.database_provider.config.kg_enrichment_settings.graphrag_community_reports, task_inputs={ "input_text": ( await self.community_summary_prompt( diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 5798ec307..1787d95e0 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -101,7 +101,7 @@ async def process_entity( ( await self.llm_provider.aget_completion( messages=await self.database_provider.prompt_handler.get_message_payload( - task_prompt_name=self.database_provider.config.kg_creation_settings.kg_entity_description_prompt, + task_prompt_name=self.database_provider.config.kg_creation_settings.graphrag_entity_description, task_inputs={ "entity_info": truncate_info( entity_info, diff --git a/py/core/pipes/kg/prompt_tuning.py b/py/core/pipes/kg/prompt_tuning.py index 4ccdd5d4d..719a66ed2 100644 --- 
a/py/core/pipes/kg/prompt_tuning.py +++ b/py/core/pipes/kg/prompt_tuning.py @@ -67,8 +67,8 @@ async def _run_logic( messages=await self.database_provider.prompt_handler.get_message_payload( task_prompt_name="prompt_tuning_task", task_inputs={ - "prompt_template": current_prompt.template, - "input_types": str(current_prompt.input_types), + "prompt_template": current_prompt["template"], + "input_types": str(current_prompt["input_types"]), "sample_data": chunks, }, ), diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/triples_extraction.py index 93f4921d1..0b5f4f42c 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/triples_extraction.py @@ -85,7 +85,7 @@ async def extract_kg( combined_extraction: str = " ".join([extraction.data for extraction in extractions]) # type: ignore messages = await self.database_provider.prompt_handler.get_message_payload( - task_prompt_name=self.database_provider.config.kg_creation_settings.kg_triples_extraction_prompt, + task_prompt_name=self.database_provider.config.kg_creation_settings.graphrag_triples_extraction_few_shot, task_inputs={ "input": combined_extraction, "max_knowledge_triples": max_knowledge_triples, @@ -240,15 +240,10 @@ async def _run_logic( # type: ignore f"KGTriplesExtractionPipe: Processing document {document_id} for KG extraction", ) - # First get the chunks response - chunks_response = await self.database_provider.list_document_chunks( - document_id=document_id - ) - # Then create the extractions from the results extractions = [ DocumentChunk( - id=extraction["chunk_id"], + id=extraction["id"], document_id=extraction["document_id"], user_id=extraction["user_id"], collection_ids=extraction["collection_ids"], diff --git a/py/core/providers/database/prompts/prompt_tuning.yaml b/py/core/providers/database/prompts/prompt_tuning.yaml index 7f2f3d687..a3924c684 100644 --- a/py/core/providers/database/prompts/prompt_tuning.yaml +++ b/py/core/providers/database/prompts/prompt_tuning.yaml @@ -22,4 +22,5 @@ prompt_tuning_task: Return only the new prompt template, maintaining the exact format required for the input types. 
input_types: prompt_template: str + sample_data: str input_types: str diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index e7114c8fd..c6f0cb1eb 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -4,6 +4,7 @@ from uuid import UUID from ..base.base_client import sync_generator_wrapper, sync_wrapper +from core.base.abstractions import EntityLevel, KGRunType # from shared.abstractions import EntityLevel @@ -14,6 +15,10 @@ # WrappedKGEntityDeduplicationResponse, # WrappedKGTunePromptResponse, # ) +from ..models import ( + KGCreationSettings, + KGRunType +) class GraphsSDK: @@ -27,7 +32,8 @@ def __init__(self, client): async def create( self, collection_id: Union[str, UUID], - settings: Optional[Dict[str, Any]] = None, + run_type: Optional[Union[str, KGRunType]] = None, + settings: Optional[Union[dict, KGCreationSettings]] = None, run_with_orchestration: Optional[bool] = True, ): # -> WrappedKGCreationResponse: """ @@ -41,14 +47,17 @@ async def create( Returns: WrappedKGCreationResponse: Creation results """ - params = {"run_with_orchestration": run_with_orchestration} - data = {} - if settings: - data["settings"] = settings + if isinstance(settings, KGCreationSettings): + settings = settings.model_dump() - return await self.client._make_request( - "POST", f"graphs/{str(collection_id)}", json=data, params=params - ) + data = { + # "collection_id": str(collection_id) if collection_id else None, + "run_type": str(run_type) if run_type else None, + "settings": settings or {}, + "run_with_orchestration": run_with_orchestration or True, + } + + return await self.client._make_request("POST", f"graphs/{collection_id}", json=data) # type: ignore async def get_status(self, collection_id: Union[str, UUID]) -> dict: """ @@ -103,7 +112,7 @@ async def create_entity( async def get_entity( self, collection_id: Union[str, UUID], - entity_id: Union[str, UUID], + entity_id: Union[str, int], include_embeddings: bool = False, ) -> dict: """ @@ -174,7 +183,7 @@ async def delete_entity( async def list_entities( self, collection_id: Union[str, UUID], - level, # : EntityLevel = EntityLevel.DOCUMENT, + level = EntityLevel.DOCUMENT, offset: int = 0, limit: int = 100, include_embeddings: bool = False, @@ -363,6 +372,7 @@ async def list_relationships( async def create_communities( self, collection_id: Union[str, UUID], + run_type: Optional[Union[str, KGRunType]] = None, settings: Optional[Dict[str, Any]] = None, run_with_orchestration: bool = True, ): # -> WrappedKGCommunitiesResponse: @@ -382,6 +392,9 @@ async def create_communities( if settings: data["settings"] = settings + if run_type: + data["run_type"] = str(run_type) + return await self.client._make_request( "POST", f"graphs/{str(collection_id)}/communities", @@ -523,8 +536,8 @@ async def tune_prompt( Args: collection_id (Union[str, UUID]): Collection ID to tune prompt for - prompt_name (str): Name of prompt to tune (kg_triples_extraction_prompt, - kg_entity_description_prompt, or community_reports_prompt) + prompt_name (str): Name of prompt to tune (graphrag_triples_extraction_few_shot, + graphrag_entity_description, or graphrag_community_reports) documents_offset (int): Document pagination offset documents_limit (int): Maximum number of documents to use chunks_offset (int): Chunk pagination offset diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index e50ae95f0..beb2d6b1d 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -29,13 +29,13 @@ def __str__(self): class 
KGCreationSettings(R2RSerializable): """Settings for knowledge graph creation.""" - kg_triples_extraction_prompt: str = Field( + graphrag_triples_extraction_few_shot: str = Field( default="graphrag_triples_extraction_few_shot", description="The prompt to use for knowledge graph extraction.", alias="graphrag_triples_extraction_few_shot_prompt", # TODO - mark deprecated & remove ) - kg_entity_description_prompt: str = Field( + graphrag_entity_description: str = Field( default="graphrag_entity_description", description="The prompt to use for entity description generation.", alias="graphrag_entity_description_prompt", # TODO - mark deprecated & remove @@ -109,10 +109,10 @@ class KGEnrichmentSettings(R2RSerializable): description="Force run the enrichment step even if graph creation is still in progress for some documents.", ) - community_reports_prompt: str = Field( + graphrag_community_reports: str = Field( default="graphrag_community_reports", description="The prompt to use for knowledge graph enrichment.", - alias="community_reports_prompt", # TODO - mark deprecated & remove + alias="graphrag_community_reports", # TODO - mark deprecated & remove ) max_summary_input_length: int = Field( From 6fdbd6c3ddccf525c59b989737ea0bf8da404c6e Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 8 Nov 2024 06:54:30 -0800 Subject: [PATCH 03/21] up --- py/core/__init__.py | 2 +- py/core/base/__init__.py | 2 +- py/core/base/abstractions/__init__.py | 4 +- py/core/base/providers/database.py | 20 +- .../scripts/test_v3_sdk/test_v3_sdk_graph.py | 96 +++--- py/core/main/api/v2/kg_router.py | 1 - py/core/main/api/v3/graph_router.py | 293 ++++++++---------- py/core/main/services/kg_service.py | 66 +++- py/core/pipes/kg/community_summary.py | 4 +- py/core/pipes/kg/triples_extraction.py | 4 +- py/core/providers/database/kg.py | 171 +++++++++- py/sdk/v3/graphs.py | 7 +- py/shared/abstractions/__init__.py | 4 +- py/shared/abstractions/graph.py | 9 +- py/shared/api/models/kg/responses.py | 4 +- .../pipes/test_kg_community_summary_pipe.py | 6 +- py/tests/core/providers/kg/test_kg_logic.py | 12 +- 17 files changed, 453 insertions(+), 252 deletions(-) diff --git a/py/core/__init__.py b/py/core/__init__.py index b5dcc8d15..dc15d1e1c 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -65,7 +65,7 @@ # KG abstractions "Entity", "KGExtraction", - "Triple", + "Relationship", # LLM abstractions "GenerationConfig", "LLMChatCompletion", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index c83c3eb19..45261c484 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -33,7 +33,7 @@ # KG abstractions "Entity", "KGExtraction", - "Triple", + "Relationship", # LLM abstractions "GenerationConfig", "LLMChatCompletion", diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index ad31fdc62..9bc5e1d2d 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -28,7 +28,7 @@ EntityType, KGExtraction, RelationshipType, - Triple, + Relationship, ) from shared.abstractions.ingestion import ( ChunkEnrichmentSettings, @@ -111,7 +111,7 @@ "Community", "CommunityReport", "KGExtraction", - "Triple", + "Relationship", "EntityLevel", # LLM abstractions "GenerationConfig", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 1b981b08b..dd2fe8d6a 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -22,7 +22,7 @@ Entity, KGExtraction, Message, - 
Triple, + Relationship, VectorEntry, ) from core.base.abstractions import ( @@ -71,7 +71,7 @@ KGExtraction, KGSearchSettings, RelationshipType, - Triple, + Relationship, ) from .base import ProviderConfig @@ -644,7 +644,7 @@ async def add_entities( @abstractmethod async def add_triples( self, - triples: list[Triple], + triples: list[Relationship], table_name: str = "chunk_triple", ) -> None: """Add triples to storage.""" @@ -701,7 +701,7 @@ async def add_community_report( @abstractmethod async def get_community_details( self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Triple]]: + ) -> Tuple[int, list[Entity], list[Relationship]]: """Get detailed information about a community.""" pass @@ -743,7 +743,7 @@ async def delete_node_via_document_id( """Delete a node using document ID.""" pass - # Entity and Triple management + # Entity and Relationship management @abstractmethod async def get_entities( self, @@ -857,7 +857,7 @@ async def get_existing_entity_extraction_ids( raise NotImplementedError @abstractmethod - async def get_all_triples(self, collection_id: UUID) -> List[Triple]: + async def get_all_triples(self, collection_id: UUID) -> List[Relationship]: raise NotImplementedError @abstractmethod @@ -1527,7 +1527,7 @@ async def add_entities( async def add_triples( self, - triples: list[Triple], + triples: list[Relationship], table_name: str = "chunk_triple", ) -> None: """Forward to KG handler add_triples method.""" @@ -1577,7 +1577,7 @@ async def add_community_report( async def get_community_details( self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Triple]]: + ) -> Tuple[int, list[Entity], list[Relationship]]: """Forward to KG handler get_community_details method.""" return await self.kg_handler.get_community_details( community_number, collection_id @@ -1624,7 +1624,7 @@ async def delete_node_via_document_id( document_id, collection_id ) - # Entity and Triple operations + # Entity and Relationship operations async def get_entities( self, collection_id: Optional[UUID], @@ -1710,7 +1710,7 @@ async def get_deduplication_estimate( collection_id, kg_deduplication_settings ) - async def get_all_triples(self, collection_id: UUID) -> List[Triple]: + async def get_all_triples(self, collection_id: UUID) -> List[Relationship]: return await self.kg_handler.get_all_triples(collection_id) async def update_entity_descriptions(self, entities: list[Entity]): diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py index 0f7561a1e..46e107580 100644 --- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py +++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py @@ -5,66 +5,67 @@ # Initialize client client = R2RClient("http://localhost:7276", prefix="/v3") + def setup_prerequisites(): """Setup necessary document and collection""" print("\n=== Setting up prerequisites ===") - + # # Login # try: # client.users.register(email=user_email, password="new_secure_password123") # except Exception as e: # print("User might already exist:", str(e)) - + # result = client.users.login(email=user_email, password="new_secure_password123") # print("Login successful") - try: # Create document doc_result = client.documents.create( file_path="../../data/pg_essay_1.html", metadata={"source": "test"}, - run_with_orchestration=False + run_with_orchestration=False, ) - print('doc_id = ', doc_result) - doc_id = doc_result['results']['document_id'] + print("doc_id = ", 
doc_result) + doc_id = doc_result["results"]["document_id"] print(f"Created document with ID: {doc_id}") except Exception as e: doc_id = "b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" pass - + # Create collection # collection_id = str(uuid.uuid4()) collection_result = client.collections.create( # collection_id=collection_id, name="Test Collection", - description="Collection for testing graph operations" + description="Collection for testing graph operations", + ) + print( + "Created collection with ID: " + + str(collection_result["results"]["collection_id"]) ) - print("Created collection with ID: " + str(collection_result["results"]["collection_id"])) collection_id = collection_result["results"]["collection_id"] # Add document to collection - client.collections.add_document( - id=collection_id, - document_id=doc_id - ) + client.collections.add_document(id=collection_id, document_id=doc_id) print(f"Added document {doc_id} to collection {collection_id}") - + return collection_id, doc_id + def test_graph_operations(collection_id): """Test graph CRUD operations""" print("\n=== Testing Graph Operations ===") - + # Test 1: Create Graph print("\n--- Test 1: Create Graph ---") create_result = client.graphs.create( collection_id=collection_id, settings={ "entity_types": ["PERSON", "ORG", "GPE"], - "min_confidence": 0.8 + "min_confidence": 0.8, }, run_type="estimate", - run_with_orchestration=False + run_with_orchestration=False, ) print("Graph estimation result:", create_result) @@ -72,13 +73,13 @@ def test_graph_operations(collection_id): collection_id=collection_id, settings={ "entity_types": ["PERSON", "ORG", "GPE"], - "min_confidence": 0.8 + "min_confidence": 0.8, }, run_type="run", - run_with_orchestration=False + run_with_orchestration=False, ) print("Graph creation result:", create_result) - + # # # Test 2: Get Graph Status # # print("\n--- Test 2: Get Graph Status ---") # # status_result = client.graphs.get_status(collection_id=collection_id) @@ -90,21 +91,25 @@ def test_graph_operations(collection_id): collection_id=collection_id, # level="collection", offset=0, - limit=10 + limit=10, ) print("Entities:", entities_result) - + # Test 4: Get Specific Entity - print('entities_result["results"]["entities"][0] = ', entities_result["results"]["entities"][0]) - entity_id = entities_result["results"]["entities"][0]["id"] #entities_result['items'][0]['id'] - print('entity_id = ', entity_id) + print( + 'entities_result["results"]["entities"][0] = ', + entities_result["results"]["entities"][0], + ) + entity_id = entities_result["results"]["entities"][0][ + "id" + ] # entities_result['items'][0]['id'] + print("entity_id = ", entity_id) print(f"\n--- Test 4: Get Entity {entity_id} ---") entity_result = client.graphs.get_entity( - collection_id=collection_id, - entity_id=entity_id + collection_id=collection_id, entity_id=entity_id ) print("Entity details:", entity_result) - + # # # # Test 5: List Relationships # # # print("\n--- Test 5: List Relationships ---") # # relationships_result = client.graphs.list_relationships( @@ -113,13 +118,13 @@ def test_graph_operations(collection_id): # # limit=10 # # ) # # print("Relationships:", relationships_result) - + # Test 6: Create Communities print("\n--- Test 6: Create Communities ---") communities_result = client.graphs.create_communities( run_type="estimate", collection_id=collection_id, - run_with_orchestration=False + run_with_orchestration=False, # settings={ # "algorithm": "louvain", # "resolution": 1.0, @@ -131,7 +136,7 @@ def 
test_graph_operations(collection_id): communities_result = client.graphs.create_communities( run_type="run", collection_id=collection_id, - run_with_orchestration=False + run_with_orchestration=False, # settings={ # "algorithm": "louvain", # "resolution": 1.0, @@ -139,48 +144,46 @@ def test_graph_operations(collection_id): # } ) print("Communities creation result:", communities_result) - + # Wait for community creation to complete - + # Test 7: List Communities print("\n--- Test 7: List Communities ---") communities_list = client.graphs.list_communities( - collection_id=collection_id, - offset=0, - limit=10 + collection_id=collection_id, offset=0, limit=10 ) print("Communities:", communities_list) - + # Test 8: Tune Prompt print("\n--- Test 8: Tune Prompt ---") tune_result = client.graphs.tune_prompt( collection_id=collection_id, prompt_name="graphrag_triples_extraction_few_shot", documents_limit=100, - chunks_limit=1000 + chunks_limit=1000, ) print("Prompt tuning result:", tune_result) - + # Test 9: Entity Deduplication print("\n--- Test 9: Entity Deduplication ---") dedup_result = client.graphs.deduplicate_entities( collection_id=collection_id, settings={ "kg_entity_deduplication_type": "by_name", - "max_description_input_length": 65536 - } + "max_description_input_length": 65536, + }, ) print("Deduplication result:", dedup_result) - + # Optional: Clean up # Test 10: Delete Graph print("\n--- Test 10: Delete Graph ---") delete_result = client.graphs.delete( - collection_id=collection_id, - cascade=True + collection_id=collection_id, cascade=True ) print("Graph deletion result:", delete_result) + def main(): try: # Setup prerequisites @@ -189,7 +192,7 @@ def main(): # Run graph operations tests test_graph_operations(collection_id) - + except Exception as e: print(f"Error occurred: {str(e)}") finally: @@ -198,5 +201,6 @@ def main(): # client.users.logout() # print("\nLogged out successfully") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 9f45f7fd7..fab231fdd 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -289,7 +289,6 @@ async def get_entities( limit, ) - @self.router.get("/triples") @self.base_endpoint async def get_triples( diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 96c803162..5a56dfa2d 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -7,14 +7,14 @@ from pydantic import BaseModel, Field, Json from core.base import R2RException, RunType, KGCreationSettings -from core.base.abstractions import EntityLevel, KGRunType, Entity +from core.base.abstractions import EntityLevel, KGRunType, Entity, Relationship, Community from core.base.api.models import ( WrappedKGCreationResponse, WrappedKGEnrichmentResponse, WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, WrappedKGEntitiesResponse, - WrappedKGCommunitiesResponse + WrappedKGCommunitiesResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -31,77 +31,6 @@ logger = logging.getLogger() -# class Entity(BaseModel): -# """Model representing a graph entity.""" - -# id: UUID -# name: str -# type: str -# metadata: dict = Field(default_factory=dict) -# level: EntityLevel -# collection_ids: list[UUID] -# embedding: Optional[list[float]] = None - -# class Config: -# json_schema_extra = { -# "example": { -# "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", -# "name": "John 
Smith", -# "type": "PERSON", -# "metadata": {"confidence": 0.95}, -# "level": "DOCUMENT", -# "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], -# "embedding": [0.1, 0.2, 0.3], -# } -# } - - -class Relationship(BaseModel): - """Model representing a graph relationship.""" - - id: UUID - source_id: UUID - target_id: UUID - type: str - metadata: dict = Field(default_factory=dict) - collection_ids: list[UUID] - - class Config: - json_schema_extra = { - "example": { - "id": "8abc123d-ef45-678g-hi90-jklmno123456", - "source_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - "target_id": "7cde891f-2a3b-4c5d-6e7f-gh8i9j0k1l2m", - "type": "WORKS_FOR", - "metadata": {"confidence": 0.85}, - "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], - } - } - - -class Community(BaseModel): - """Model representing a graph community.""" - - id: UUID - level: int - number: int - entities: list[UUID] - metadata: dict = Field(default_factory=dict) - collection_id: UUID - - class Config: - json_schema_extra = { - "example": { - "id": "5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8", - "level": 1, - "number": 3, - "entities": ["9fbe403b-c11c-5aae-8ade-ef22980c3ad1"], - "metadata": {"topic": "Finance"}, - "collection_id": "d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - } - } - - class GraphRouter(BaseRouterV3): def __init__( self, @@ -267,7 +196,7 @@ async def create_graph( ) @self.base_endpoint async def get_graph_status( - collection_id: UUID = Path(...), + collection_id: UUID = Path(...), # TODO: change to id? auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[dict]: """ @@ -280,14 +209,10 @@ async def get_graph_status( - Community statistics - Current settings """ - raise NotImplementedError("Not implemented", 501) - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can view graph status", 403 - # ) + # check if user has access the collection_id - # status = await self.services["kg"].get_graph_status(collection_id) - # return status # type: ignore + status = await self.services["kg"].get_graph_status(collection_id) + return status # type: ignore # @self.router.post( # "/graphs/{collection_id}/enrich", @@ -413,7 +338,7 @@ async def delete_graph( # Entity operations @self.router.post( - "/graphs/{collection_id}/entities", + "/graphs/{collection_id}/entities/{level}", summary="Create a new entity", openapi_extra={ "x-codeSamples": [ @@ -435,7 +360,6 @@ async def delete_graph( "source": "manual", "confidence": 1.0 }, - "level": "DOCUMENT" } )""" ), @@ -444,7 +368,7 @@ async def delete_graph( "lang": "cURL", "source": textwrap.dedent( """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities" \\ + curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities/document" \\ -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ @@ -454,7 +378,6 @@ async def delete_graph( "source": "manual", "confidence": 1.0 }, - "level": "DOCUMENT" }'""" ), }, @@ -468,14 +391,55 @@ async def create_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Entity]: """Creates a new entity in the graph.""" - raise NotImplementedError("Not implemented", 501) - # if not auth_user.is_superuser: - # raise R2RException("Only superusers can create entities", 403) + # entity validation. 
+ entity = Entity(**entity) + level = entity.level - # new_entity = await self.services["kg"].create_entity( - # collection_id, entity - # ) - # return new_entity # type: ignore + if level is None: + raise R2RException( + "Entity level must be provided. Value is one of: collection, document, chunk", + 400, + ) + + if level == EntityLevel.DOCUMENT and not entity.document_id: + raise R2RException( + "document_id must be provided for all entities if level is DOCUMENT", + 400, + ) + + if ( + level == EntityLevel.COLLECTION + and not entity.collection_id + and not entity.document_ids + ): + raise R2RException( + "collection_id or document_ids must be provided for all entities if level is COLLECTION", + 400, + ) + + if level == EntityLevel.CHUNK and not entity.document_id: + raise R2RException( + "document_id must be provided for all entities if level is CHUNK", + 400, + ) + + # check if entity level is not chunk, then description embedding must be provided + if level != EntityLevel.CHUNK and entity.description_embedding: + raise R2RException( + "Please do not provide a description_embedding. R2R will automatically generate embeddings", + 400, + ) + + # check that ID is not provided for any entity + if entity.id: + raise R2RException( + "ID is not allowed to be provided for any entity. It is automatically generated when the entity is added to the graph.", + 400, + ) + + return await self.services["kg"].create_entity( + collection_id, entity + ) @self.router.delete( "/graphs/{collection_id}/entities/{entity_id}", @@ -520,14 +484,26 @@ async def delete_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[dict]: """Deletes an entity and optionally its relationships.""" - raise NotImplementedError("Not implemented", 501) - # if not auth_user.is_superuser: - # raise R2RException("Only superusers can delete entities", 403) - # await self.services["kg"].delete_entity( - # collection_id, entity_id, cascade - # ) - # return {"message": "Entity deleted successfully"} # type: ignore + # implement permission check. + if cascade == True: + # we don't currently have entity IDs in the triples table, so we can't cascade delete. + # we will be able to delete by name. + raise NotImplementedError( + "Cascade deletion is not implemented", 501 + ) + + if type(entity_id) == UUID: + # FIXME: currently entity ID is an integer in the graph. we need to change it to UUID + raise ValueError( + "Currently Entity ID is an integer in the graph. we need to change it to UUID for this endpoint to work", + 400, + ) + + await self.services["kg"].delete_entity( + collection_id, entity_id, cascade + ) + return {"message": "Entity deleted successfully"} # type: ignore @self.router.get( "/graphs/{collection_id}/entities", @@ -572,7 +548,9 @@ async def list_entities( limit: int = Query(100, ge=1, le=1000), # include_embeddings: bool = Query(False), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGEntitiesResponse: # PaginatedResultsWrapper[list[Entity]]: + ) -> ( + WrappedKGEntitiesResponse + ): # PaginatedResultsWrapper[list[Entity]]: """Lists entities in the graph with filtering and pagination support. Entities represent the nodes in the knowledge graph, extracted from documents. 
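             The listing is backed by a level-specific table (chunk_entity,
             document_entity, or collection_entity), selected from the level
             query parameter.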
@@ -584,15 +562,14 @@ async def list_entities( - Community memberships - Optional vector embedding """ - if level == EntityLevel.CHUNK: - entity_table_name = "chunk_entity" - elif level == EntityLevel.DOCUMENT: - entity_table_name = "document_entity" - else: - entity_table_name = "collection_entity" + entity_table_name = level.value + "_entity" return await self.services["kg"].list_entities( - collection_id=collection_id, entity_ids=[], entity_table_name=entity_table_name, offset=offset, limit=limit + collection_id=collection_id, + entity_ids=[], + entity_table_name=entity_table_name, + offset=offset, + limit=limit, ) @self.router.get( @@ -609,17 +586,13 @@ async def get_entity( ) -> ResultsWrapper[Entity]: """Retrieves details of a specific entity.""" - if level == EntityLevel.CHUNK: - entity_table_name = "chunk_entity" - elif level == EntityLevel.DOCUMENT: - entity_table_name = "document_entity" - else: - entity_table_name = "collection_entity" - - result = (await self.services["kg"].list_entities( - collection_id=collection_id, entity_ids=[entity_id], entity_table_name=entity_table_name # , offset=offset, limit=limit - )) - return result['entities'][0] # type: ignore + entity_table_name = level.value + "_entity" + result = await self.services["kg"].list_entities( + collection_id=collection_id, + entity_ids=[entity_id], + entity_table_name=entity_table_name, # , offset=offset, limit=limit + ) + return result["entities"][0] # type: ignore @self.router.post( "/graphs/{collection_id}/entities/{entity_id}", @@ -675,14 +648,14 @@ async def update_entity( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Entity]: """Updates an existing entity.""" - raise NotImplementedError("Not implemented", 501) - # if not auth_user.is_superuser: - # raise R2RException("Only superusers can update entities", 403) - # updated_entity = await self.services["kg"].update_entity( - # collection_id, entity_id, entity_update - # ) - # return updated_entity # type: ignore + if not auth_user.is_superuser: + raise R2RException("Only superusers can update entities", 403) + + updated_entity = await self.services["kg"].update_entity( + collection_id, entity_id, entity_update + ) + return updated_entity # type: ignore @self.router.post( "/graphs/{collection_id}/entities/deduplicate", @@ -813,7 +786,7 @@ async def deduplicate_entities( } @self.router.post( - "/graphs/{collection_id}/relationships", + "/graphs/{document_id}/relationships", summary="Create a new relationship", openapi_extra={ "x-codeSamples": [ @@ -827,7 +800,7 @@ async def deduplicate_entities( # when using auth, do client.login(...) 
result = client.graphs.create_relationship( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + document_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", relationship={ "source_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "target_id": "7cde891f-2a3b-4c5d-6e7f-gh8i9j0k1l2m", @@ -863,21 +836,25 @@ async def deduplicate_entities( ) @self.base_endpoint async def create_relationship( - collection_id: UUID = Path(...), relationship: dict = Body(...), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Relationship]: """Creates a new relationship between entities.""" - raise NotImplementedError("Not implemented", 501) - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can create relationships", 403 - # ) - # new_relationship = await self.services["kg"].create_relationship( - # collection_id, relationship - # ) - # return new_relationship # type: ignore + # we define relationships only at a document level + # when a user creates a graph on two collections with a document in common, the the work is not duplicated + + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can create relationships", 403 + ) + + # validate if document_id is valid + + new_relationship = await self.services["kg"].create_relationship( + document_id, relationship + ) + return new_relationship # type: ignore # Relationship operations @self.router.get( @@ -1018,18 +995,17 @@ async def update_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[Relationship]: """Updates an existing relationship.""" - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can update relationships", 403 - # ) + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can update relationships", 403 + ) - # updated_relationship = await self.services[ - # "kg" - # ].update_relationship( - # collection_id, relationship_id, relationship_update - # ) - # return updated_relationship # type: ignore + updated_relationship = await self.services[ + "kg" + ].update_relationship( + relationship_id, relationship_update + ) + return updated_relationship # type: ignore @self.router.delete( "/graphs/{collection_id}/relationships/{relationship_id}", @@ -1069,16 +1045,15 @@ async def delete_relationship( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[dict]: """Deletes a relationship.""" - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can delete relationships", 403 - # ) + if not auth_user.is_superuser: + raise R2RException( + "Only superusers can delete relationships", 403 + ) - # await self.services["kg"].delete_relationship( - # collection_id, relationship_id - # ) - # return {"message": "Relationship deleted successfully"} # type: ignore + await self.services["kg"].delete_relationship( + collection_id, relationship_id + ) + return {"message": "Relationship deleted successfully"} # type: ignore # Community operations @self.router.post( @@ -1292,7 +1267,9 @@ async def list_communities( offset: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=1000), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCommunitiesResponse: # PaginatedResultsWrapper[list[Community]]: + ) -> ( + WrappedKGCommunitiesResponse + ): # PaginatedResultsWrapper[list[Community]]: """Lists communities in the graph with optional filtering and pagination. 
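
                Communities are produced by hierarchically clustering the entity
                graph (Leiden, via `_compute_leiden_communities`) and then
                summarizing each cluster into a community report.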
Each community represents a group of related entities with: diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 4fd840201..4d247e47b 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -12,6 +12,8 @@ KGEntityDeduplicationSettings, KGEntityDeduplicationType, R2RException, + Entity, + Relationship, ) from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event @@ -20,6 +22,7 @@ from ..config import R2RConfig from .base import Service + logger = logging.getLogger() @@ -113,6 +116,50 @@ async def kg_triples_extraction( return await _collect_results(result_gen) + @telemetry_event("create_entity") + async def create_entity( + self, + collection_id: UUID, + entity: Entity, + **kwargs, + ): + return await self.providers.database.create_entity( + collection_id, entity + ) + + @telemetry_event("update_entity") + async def update_entity( + self, + collection_id: UUID, + entity: Entity, + **kwargs, + ): + return await self.providers.database.update_entity( + collection_id, entity + ) + + @telemetry_event("delete_entity") + async def delete_entity( + self, + collection_id: UUID, + entity: Entity, + **kwargs, + ): + return await self.providers.database.delete_entity( + collection_id, entity + ) + + @telemetry_event("create_relationship") + async def create_relationship( + self, + collection_id: UUID, + relationship: Relationship, + **kwargs, + ): + return await self.providers.database.create_relationship( + collection_id, relationship + ) + @telemetry_event("get_document_ids_for_create_graph") async def get_document_ids_for_create_graph( self, @@ -210,6 +257,14 @@ async def kg_entity_description( return all_results + @telemetry_event("get_graph_status") + async def get_graph_status( + self, + collection_id: UUID, + **kwargs, + ): + return await self.providers.database.get_graph_status(collection_id) + @telemetry_event("kg_clustering") async def kg_clustering( self, @@ -271,6 +326,15 @@ async def delete_graph_for_documents( # TODO: Implement this, as it needs some checks. 
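        # (for instance, confirming that the documents' entities are not also
        # referenced by another collection's graph is one assumed precondition)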
raise NotImplementedError + @telemetry_event("delete_graph") + async def delete_graph( + self, + collection_id: UUID, + cascade: bool, + **kwargs, + ): + return await self.delete_graph_for_collection(collection_id, cascade) + @telemetry_event("delete_graph_for_collection") async def delete_graph_for_collection( self, @@ -387,6 +451,7 @@ async def list_triples( offset=offset or 0, limit=limit or -1, ) + @telemetry_event("get_communities") async def get_communities( self, @@ -423,7 +488,6 @@ async def list_communities( limit=limit or -1, ) - @telemetry_event("get_deduplication_estimate") async def get_deduplication_estimate( self, diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 81ceccc0b..2e2653654 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -15,7 +15,7 @@ EmbeddingProvider, GenerationConfig, ) -from core.base.abstractions import Entity, Triple +from core.base.abstractions import Entity, Relationship from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -51,7 +51,7 @@ def __init__( async def community_summary_prompt( self, entities: list[Entity], - triples: list[Triple], + triples: list[Relationship], max_summary_input_length: int, ): diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/triples_extraction.py index 0b5f4f42c..b0f86d58c 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/triples_extraction.py @@ -15,7 +15,7 @@ KGExtraction, R2RDocumentProcessingError, R2RException, - Triple, + Relationship, ) from core.base.pipes.base_pipe import AsyncPipe from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider @@ -159,7 +159,7 @@ def parse_fn(response_str: str) -> Any: # check if subject and object are in entities_dict relations_arr.append( - Triple( + Relationship( subject=subject, predicate=predicate, object=object, diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 31162a59a..13d4a4517 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -14,7 +14,7 @@ KGExtractionStatus, KGHandler, R2RException, - Triple, + Relationship ) from core.base.abstractions import ( EntityLevel, @@ -247,16 +247,67 @@ async def add_entities( cleaned_entities, table_name, conflict_columns ) + async def get_graph_status(self, collection_id: UUID) -> dict: + # check document_info table for the documents in the collection and return the status of each document + kg_extraction_statuses = await self.connection_manager.fetch_query( + f"SELECT document_id, kg_extraction_status FROM {self._get_table_name('document_info')} WHERE collection_id = $1", + [collection_id], + ) + + document_ids = [doc_id["document_id"] for doc_id in kg_extraction_statuses] + + kg_enrichment_statuses = await self.connection_manager.fetch_query( + f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", + [collection_id], + ) + + # entity and relationship counts + chunk_entity_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1)", + [document_ids], + ) + + chunk_triple_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_triple')} WHERE document_id = ANY($1)", + [document_ids], + ) + + document_entity_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) 
FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1)", + [document_ids], + ) + + collection_entity_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1", + [collection_id], + ) + + community_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('community_report')} WHERE collection_id = $1", + [collection_id], + ) + + return { + "kg_extraction_statuses": kg_extraction_statuses, + "kg_enrichment_status": kg_enrichment_statuses[0]["enrichment_status"], + "chunk_entity_count": chunk_entity_count[0]["count"], + "chunk_triple_count": chunk_triple_count[0]["count"], + "document_entity_count": document_entity_count[0]["count"], + "collection_entity_count": collection_entity_count[0]["count"], + "community_count": community_count[0]["count"], + } + + async def add_triples( self, - triples: list[Triple], + triples: list[Relationship], table_name: str = "chunk_triple", ) -> None: """ Upsert triples into the chunk_triple table. These are raw triples extracted from the document. Args: - triples: list[Triple]: list of triples to upsert + triples: list[Relationship]: list of triples to upsert table_name: str: name of the table to upsert into Returns: @@ -375,7 +426,7 @@ async def get_entity_map( QUERY2, [document_id] ) triples_list = [ - Triple( + Relationship( subject=triple["subject"], predicate=triple["predicate"], object=triple["object"], @@ -500,7 +551,7 @@ async def vector_query( # type: ignore for property_name in property_names } - async def get_all_triples(self, collection_id: UUID) -> list[Triple]: + async def get_all_triples(self, collection_id: UUID) -> list[Relationship]: # getting all documents for a collection QUERY = f""" @@ -517,7 +568,7 @@ async def get_all_triples(self, collection_id: UUID) -> list[Triple]: triples = await self.connection_manager.fetch_query( QUERY, [document_ids] ) - return [Triple(**triple) for triple in triples] + return [Relationship(**triple) for triple in triples] async def add_communities(self, communities: list[Any]) -> None: QUERY = f""" @@ -736,7 +787,7 @@ async def _compute_leiden_communities( async def get_community_details( self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Triple]]: + ) -> Tuple[int, list[Entity], list[Relationship]]: QUERY = f""" SELECT level FROM {self._get_table_name("community_info")} WHERE cluster = $1 AND collection_id = $2 @@ -792,13 +843,115 @@ async def get_community_details( triples = await self.connection_manager.fetch_query( QUERY, [community_number, collection_id] ) - triples = [Triple(**triple) for triple in triples] + triples = [Relationship(**triple) for triple in triples] return level, entities, triples # async def client(self): # return None + ############################################################ + ########## Entity CRUD Operations ########################## + ############################################################ + + async def create_entity( + self, collection_id: UUID, entity: Entity + ) -> None: + + table_name = entity.level.value + "_entity" + entity.level = None + + # check if the entity already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + count = ( + await self.connection_manager.fetch_query(QUERY, [entity.id, collection_id]) + )[0]["count"] + + if count > 0: + raise R2RException("Entity already exists", 400) + + await 
self._add_objects([entity], table_name) + + async def update_entity( + self, collection_id: UUID, entity: Entity + ) -> None: + table_name = entity.level.value + "_entity" + + # check if the entity already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + count = ( + await self.connection_manager.fetch_query(QUERY, [entity.id, collection_id]) + )[0]["count"] + + if count == 0: + raise R2RException("Entity does not exist", 404) + + await self._add_objects([entity], table_name) + + async def delete_entity( + self, collection_id: UUID, entity: Entity + ) -> None: + + table_name = entity.level.value + "_entity" + QUERY = f""" + DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + await self.connection_manager.execute_query(QUERY, [entity.id, collection_id]) + + ############################################################ + ########## Relationship CRUD Operations #################### + ############################################################ + + async def create_relationship( + self, collection_id: UUID, relationship: Relationship + ) -> None: + + # check if the relationship already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 + """ + count = ( + await self.connection_manager.fetch_query(QUERY, [relationship.subject, relationship.predicate, relationship.object, collection_id]) + )[0]["count"] + + if count > 0: + raise R2RException("Relationship already exists", 400) + + await self._add_objects([relationship], "chunk_triple") + + async def update_relationship( + self, relationship_id: UUID, relationship: Relationship + ) -> None: + + # check if relationship_id exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} WHERE id = $1 + """ + count = ( + await self.connection_manager.fetch_query(QUERY, [relationship.id]) + )[0]["count"] + + if count == 0: + raise R2RException("Relationship does not exist", 404) + + await self._add_objects([relationship], "chunk_triple") + + async def delete_relationship( + self, relationship_id: UUID + ) -> None: + QUERY = f""" + DELETE FROM {self._get_table_name("chunk_triple")} WHERE id = $1 + """ + await self.connection_manager.execute_query(QUERY, [relationship_id]) + + ############################################################ + ########## Community CRUD Operations ####################### + ############################################################ + async def get_community_reports( self, collection_id: UUID ) -> list[CommunityReport]: @@ -1236,7 +1389,7 @@ async def get_triples( """ triples = await self.connection_manager.fetch_query(query, params) - triples = [Triple(**triple) for triple in triples] + triples = [Relationship(**triple) for triple in triples] total_entries = await self.get_triple_count( collection_id=collection_id ) diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index c6f0cb1eb..4e23c7874 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -15,10 +15,7 @@ # WrappedKGEntityDeduplicationResponse, # WrappedKGTunePromptResponse, # ) -from ..models import ( - KGCreationSettings, - KGRunType -) +from ..models import KGCreationSettings, KGRunType class GraphsSDK: @@ -183,7 +180,7 @@ async def delete_entity( async def list_entities( self, collection_id: Union[str, UUID], - level = EntityLevel.DOCUMENT, + level=EntityLevel.DOCUMENT, offset: int = 0, limit: int = 100, 
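        # embeddings are excluded by default, since description vectors are large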
include_embeddings: bool = False, diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index ad88957ca..c5d5b254b 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -20,7 +20,7 @@ EntityType, KGExtraction, RelationshipType, - Triple, + Relationship, ) from .kg import ( KGCreationSettings, @@ -96,7 +96,7 @@ "Community", "CommunityReport", "KGExtraction", - "Triple", + "Relationship", # LLM abstractions "GenerationConfig", "LLMChatCompletion", diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 049679cdc..079d9be58 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -58,6 +58,7 @@ class Entity(R2RSerializable): name: str id: Optional[int] = None + level: Optional[EntityLevel] = None category: Optional[str] = None description: Optional[str] = None description_embedding: Optional[Union[list[float], str]] = None @@ -88,7 +89,7 @@ def __init__(self, **kwargs): self.attributes = self.attributes -class Triple(R2RSerializable): +class Relationship(R2RSerializable): """A relationship between two entities. This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" id: Optional[int] = None @@ -142,10 +143,10 @@ def from_dict( # type: ignore extraction_ids_key: str = "extraction_ids", document_id_key: str = "document_id", attributes_key: str = "attributes", - ) -> "Triple": + ) -> "Relationship": """Create a new relationship from the dict data.""" - return Triple( + return Relationship( id=d[id_key], short_id=d.get(short_id_key), subject=d[source_key], @@ -302,4 +303,4 @@ class KGExtraction(R2RSerializable): extraction_ids: list[uuid.UUID] document_id: uuid.UUID entities: list[Entity] - triples: list[Triple] + triples: list[Relationship] diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index 3438b26e1..36813c727 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from shared.abstractions.base import R2RSerializable -from shared.abstractions.graph import CommunityReport, Entity, Triple +from shared.abstractions.graph import CommunityReport, Entity, Relationship from shared.api.models.base import ResultsWrapper @@ -210,7 +210,7 @@ class Config: class KGTriplesResponse(R2RSerializable): """Response for knowledge graph triples.""" - triples: list[Triple] = Field( + triples: list[Relationship] = Field( ..., description="The list of triples in the graph.", ) diff --git a/py/tests/core/pipes/test_kg_community_summary_pipe.py b/py/tests/core/pipes/test_kg_community_summary_pipe.py index 04f519d32..02ec7d280 100644 --- a/py/tests/core/pipes/test_kg_community_summary_pipe.py +++ b/py/tests/core/pipes/test_kg_community_summary_pipe.py @@ -9,7 +9,7 @@ CommunityReport, Entity, KGExtraction, - Triple, + Relationship, ) from core.pipes.kg.community_summary import KGCommunitySummaryPipe from shared.abstractions.vector import VectorQuantizationType @@ -125,7 +125,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") def triples_raw_list(embedding_vectors, extraction_ids, document_id): return [ - Triple( + Relationship( id=1, subject="Entity1", predicate="predicate1", @@ -137,7 +137,7 @@ def triples_raw_list(embedding_vectors, extraction_ids, document_id): document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), - Triple( 
+ Relationship( id=2, subject="Entity2", predicate="predicate2", diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py index 9b018fec1..1b5108b11 100644 --- a/py/tests/core/providers/kg/test_kg_logic.py +++ b/py/tests/core/providers/kg/test_kg_logic.py @@ -4,7 +4,13 @@ import pytest -from core.base import Community, CommunityReport, Entity, KGExtraction, Triple +from core.base import ( + Community, + CommunityReport, + Entity, + KGExtraction, + Relationship, +) from shared.abstractions.vector import VectorQuantizationType @@ -89,7 +95,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") def triples_raw_list(embedding_vectors, extraction_ids, document_id): return [ - Triple( + Relationship( subject="Entity1", predicate="predicate1", object="object1", @@ -100,7 +106,7 @@ def triples_raw_list(embedding_vectors, extraction_ids, document_id): document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), - Triple( + Relationship( subject="Entity2", predicate="predicate2", object="object2", From 547b9e9809c254852fd2789adc7adece2f3fee2f Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Mon, 11 Nov 2024 06:23:35 -0800 Subject: [PATCH 04/21] checkin --- py/core/main/api/v3/graph_router.py | 207 ++++++++++++++++++++++++++++ py/core/main/services/kg_service.py | 32 +++-- py/core/providers/database/kg.py | 70 +++++++--- 3 files changed, 278 insertions(+), 31 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 5a56dfa2d..8bdbcd78c 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -44,6 +44,213 @@ def __init__( super().__init__(providers, services, orchestration_provider, run_type) def _setup_routes(self): + + ##### CHUNK LEVEL OPERATIONS ##### + + ##### ENTITIES ###### + @self.router.get( + "/chunks/{id}/entities", + summary="List entities for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.list_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def list_entities( + id: UUID = Path(..., description="The ID of the chunk to retrieve entities for."), + entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the entities by."), + entity_categories: Optional[list[str]] = Query(None, description="A list of entity categories to filter the entities by."), + attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."), + offset: int = Query(0, ge=0, description="The offset of the first entity to retrieve."), + limit: int = Query(100, ge=0, le=20_000, description="The maximum number of entities to retrieve, up to 20,000."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> PaginatedResultsWrapper[list[Entity]]: + """ + Retrieves a list of entities associated with a specific chunk. + + Note that when entities are extracted, neighboring chunks are also processed together to extract entities. + + So, the entity returned here may not be in the same chunk as the one specified, but rather in a neighboring chunk (upto 2 chunks by default). 
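+
+            Each returned entity carries the Entity fields (id, name, category,
+            description, attributes); the `attributes` parameter narrows what
+            is returned.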
+ """ + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].list_entities( + level=EntityLevel.CHUNK, + id=id, + offset=offset, + limit=limit, + entity_names=entity_names, + entity_categories=entity_categories, + attributes=attributes + ) + + @self.router.post( + "/chunks/{id}/entities", + summary="Create entities for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.create_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def create_entities( + id: UUID = Path(..., description="The ID of the chunk to create entities for."), + entities: list[Union[Entity, dict]] = Body(..., description="The entities to create."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + entities = [Entity(**entity) if isinstance(entity, dict) else entity for entity in entities] + # for each entity, set the level to CHUNK + for entity in entities: + if entity.level is None: + entity.level = EntityLevel.CHUNK + else: + raise R2RException("Entity level must be chunk or empty.", 400) + + return await self.services["kg"].create_entities( + level=EntityLevel.CHUNK, + id=id, + entities=entities, + ) + + @self.router.post( + "/chunks/{id}/entities/{entity_id}", + summary="Update an entity for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.update_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def update_entity( + id: UUID = Path(..., description="The ID of the chunk to update the entity for."), + entity_id: UUID = Path(..., description="The ID of the entity to update."), + entity: Entity = Body(..., description="The updated entity."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].update_entity( + level=EntityLevel.CHUNK, + id=id, + entity_id=entity_id, + entity=entity, + ) + + + @self.router.delete( + "/chunks/{id}/entities/{entity_id}", + summary="Delete an entity for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ + result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, + ) + + @self.base_endpoint + async def delete_entity( + id: UUID = Path(..., description="The ID of the chunk to delete the entity for."), + entity_id: UUID = Path(..., description="The ID of the entity to delete."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + ##### RELATIONSHIPS ##### + + + + + + + + + + + + + + + + + + + + + + ##### DOCUMENT LEVEL OPERATIONS ##### + + + + + + + + + + ##### COLLECTION LEVEL OPERATIONS ##### + + + + + # Graph-level operations @self.router.post( "/graphs/{collection_id}", diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 4d247e47b..f9bf5571a 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -6,6 +6,7 @@ from core.base import KGExtractionStatus, RunManager from core.base.abstractions import ( + EntityLevel, GenerationConfig, KGCreationSettings, KGEnrichmentSettings, @@ -116,15 +117,16 @@ async def kg_triples_extraction( return await _collect_results(result_gen) - @telemetry_event("create_entity") - async def create_entity( + @telemetry_event("create_entities") + async def create_entities( self, - collection_id: UUID, - entity: Entity, + level: EntityLevel, + id: UUID, + entities: list[Entity], **kwargs, ): - return await self.providers.database.create_entity( - collection_id, entity + return await self.providers.database.create_entities( + level, id, entities, **kwargs ) @telemetry_event("update_entity") @@ -383,17 +385,21 @@ async def get_enrichment_estimate( @telemetry_event("list_entities") async def list_entities( self, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_table_name: str = "document_entity", + level: EntityLevel, + id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + entity_categories: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, offset: Optional[int] = None, limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.get_entities( - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, + return await self.providers.database.get_entities_v3( + level=level, + id=id, + entity_names=entity_names, + entity_categories=entity_categories, + attributes=attributes, offset=offset or 0, limit=limit or -1, ) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 13d4a4517..c69b4df23 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -854,25 +854,12 @@ async def get_community_details( ########## Entity CRUD Operations ########################## ############################################################ - async def create_entity( - self, collection_id: UUID, entity: Entity + async def create_entities( + self, level: EntityLevel, id: UUID, entities: list[Entity] ) -> None: - - table_name = entity.level.value + "_entity" - entity.level = None - - # check if the entity already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - """ - count = ( - await self.connection_manager.fetch_query(QUERY, [entity.id, collection_id]) - )[0]["count"] - - if count > 0: - raise R2RException("Entity already exists", 400) - - 
await self._add_objects([entity], table_name) + + # TODO: check if already exists + await self._add_objects(entities, level.table_name) async def update_entity( self, collection_id: UUID, entity: Entity @@ -1272,6 +1259,53 @@ async def get_schema(self): # somehow get the rds from the postgres db. raise NotImplementedError + async def get_entities_v3( + self, + level: EntityLevel, + id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + entity_categories: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: int = 0, + limit: int = -1, + ): + + params: list = [id] + + if level != EntityLevel.CHUNK and entity_categories: + raise ValueError("entity_categories are only supported for chunk level entities") + + filter = { + EntityLevel.CHUNK: "chunk_ids = ANY($1)", + EntityLevel.DOCUMENT: "document_id = $1", + EntityLevel.COLLECTION: "collection_id = $1", + }[level] + + if entity_names: + filter += " AND name = ANY($2)" + params.append(entity_names) + + if entity_categories: + filter += " AND category = ANY($3)" + params.append(entity_categories) + + QUERY = f""" + SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} + OFFSET ${len(params)} LIMIT ${len(params) + 1} + """ + + params.extend([offset, limit]) + + output = await self.connection_manager.fetch_query(QUERY, params) + + if attributes: + output = [entity for entity in output if entity["name"] in attributes] + + return output + + + + # TODO: deprecate this async def get_entities( self, collection_id: Optional[UUID] = None, From a2dc253f00624ea58223ba6968d61368ce95d701 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Mon, 11 Nov 2024 16:23:22 -0800 Subject: [PATCH 05/21] up --- py/core/main/api/v3/graph_router.py | 636 +++++++++++++++++++++++++++- py/core/main/services/kg_service.py | 18 + py/core/providers/database/kg.py | 48 ++- 3 files changed, 690 insertions(+), 12 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 8bdbcd78c..ca7aece4a 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -28,8 +28,22 @@ from .base_router import BaseRouterV3 +from fastapi import Request + logger = logging.getLogger() +class EntityResponse(BaseModel): + id: UUID + name: str + category: str + +class RelationshipResponse(BaseModel): + id: UUID + subject_id: UUID + object_id: UUID + subject_name: str + object_name: str + predicate: str class GraphRouter(BaseRouterV3): def __init__( @@ -43,10 +57,16 @@ def __init__( ): super().__init__(providers, services, orchestration_provider, run_type) - def _setup_routes(self): - - ##### CHUNK LEVEL OPERATIONS ##### + def _get_path_level(self, request: Request) -> EntityLevel: + path = request.url.path + if "/chunks/" in path: + return EntityLevel.CHUNK + elif "/documents/" in path: + return EntityLevel.DOCUMENT + else: + return EntityLevel.COLLECTION + def _setup_routes(self): ##### ENTITIES ###### @self.router.get( "/chunks/{id}/entities", @@ -69,8 +89,51 @@ def _setup_routes(self): ] }, ) + @self.router.get( + "/documents/{id}/entities", + summary="List entities for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
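+                            # paging mirrors the chunk endpoint: offset/limit, with limit capped at 20,000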
+ + result = client.documents.list_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + """ + ), + }, + ] + }, + ) + @self.router.get( + "/collections/{id}/entities", + summary="List entities for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.collections.list_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + """ + ), + }, + ] + }, + ) @self.base_endpoint async def list_entities( + request: Request, id: UUID = Path(..., description="The ID of the chunk to retrieve entities for."), entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the entities by."), entity_categories: Optional[list[str]] = Query(None, description="A list of entity categories to filter the entities by."), @@ -89,8 +152,8 @@ async def list_entities( if not auth_user.is_superuser: raise R2RException("Only superusers can access this endpoint.", 403) - return await self.services["kg"].list_entities( - level=EntityLevel.CHUNK, + return await self.services["kg"].list_entities_v3( + level=self._get_path_level(request), id=id, offset=offset, limit=limit, @@ -120,8 +183,51 @@ async def list_entities( ] }, ) + @self.router.post( + "/documents/{id}/entities", + summary="Create entities for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.create_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + """ + ), + }, + ] + }, + ) + @self.router.post( + "/collections/{id}/entities", + summary="Create entities for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
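+                            # entity1 and entity2 are Entity objects or dicts; leave
+                            # entity.level unset, the endpoint assigns it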
+ + result = client.collections.create_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + """ + ), + }, + ] + }, + ) @self.base_endpoint async def create_entities( + request: Request, id: UUID = Path(..., description="The ID of the chunk to create entities for."), entities: list[Union[Entity, dict]] = Body(..., description="The entities to create."), auth_user=Depends(self.providers.auth.auth_wrapper), @@ -137,8 +243,8 @@ async def create_entities( else: raise R2RException("Entity level must be chunk or empty.", 400) - return await self.services["kg"].create_entities( - level=EntityLevel.CHUNK, + return await self.services["kg"].create_entities_v3( + level=self._get_path_level(request), id=id, entities=entities, ) @@ -166,6 +272,7 @@ async def create_entities( ) @self.base_endpoint async def update_entity( + request: Request, id: UUID = Path(..., description="The ID of the chunk to update the entity for."), entity_id: UUID = Path(..., description="The ID of the entity to update."), entity: Entity = Body(..., description="The updated entity."), @@ -174,14 +281,13 @@ async def update_entity( if not auth_user.is_superuser: raise R2RException("Only superusers can access this endpoint.", 403) - return await self.services["kg"].update_entity( - level=EntityLevel.CHUNK, + return await self.services["kg"].update_entity_v3( + level=self._get_path_level(request), id=id, entity_id=entity_id, entity=entity, ) - @self.router.delete( "/chunks/{id}/entities/{entity_id}", summary="Delete an entity for a chunk", @@ -203,9 +309,51 @@ async def update_entity( ] }, ) + @self.router.delete( + "/documents/{id}/entities/{entity_id}", + summary="Delete an entity for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, + ) + @self.router.delete( + "/collections/{id}/entities/{entity_id}", + summary="Delete an entity for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, + ) @self.base_endpoint async def delete_entity( + request: Request, id: UUID = Path(..., description="The ID of the chunk to delete the entity for."), entity_id: UUID = Path(..., description="The ID of the entity to delete."), auth_user=Depends(self.providers.auth.auth_wrapper), @@ -213,44 +361,510 @@ async def delete_entity( if not auth_user.is_superuser: raise R2RException("Only superusers can access this endpoint.", 403) + return await self.services["kg"].delete_entity_v3( + level=self._get_path_level(request), + id=id, + entity_id=entity_id, + ) + ##### RELATIONSHIPS ##### + @self.router.get( + "/chunks/{id}/relationships", + summary="List relationships for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
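+                            # superuser-only, like the other graph mutation endpoints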
+ + result = client.chunks.list_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + ] + }, + ) + @self.router.get( + "/documents/{id}/relationships", + summary="List relationships for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + ] + }, + ) + + @self.router.get( + "/chunks/{id}/relationships", + summary="List relationships for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.collections.list_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def list_relationships( + id: UUID = Path(..., description="The ID of the chunk to retrieve relationships for."), + entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the relationships by."), + relationship_types: Optional[list[str]] = Query(None, description="A list of relationship types to filter the relationships by."), + attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."), + offset: int = Query(0, ge=0, description="The offset of the first relationship to retrieve."), + limit: int = Query(100, ge=0, le=20_000, description="The maximum number of relationships to retrieve, up to 20,000."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> PaginatedResultsWrapper[list[Relationship]]: + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].list_relationships_v3( + level=EntityLevel.CHUNK, + id=id, + entity_names=entity_names, + relationship_types=relationship_types, + attributes=attributes, + offset=offset, + limit=limit, + ) + + + @self.router.post( + "/chunks/{id}/relationships", + summary="Create relationships for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
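+                            # filter with entity_names / relationship_types; offset and limit page the results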
+ + result = client.chunks.create_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def create_relationships( + id: UUID = Path(..., description="The ID of the chunk to create relationships for."), + relationships: list[Union[Relationship, dict]] = Body(..., description="The relationships to create."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> ResultsWrapper[list[RelationshipResponse]]: + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + relationships = [Relationship(**relationship) if isinstance(relationship, dict) else relationship for relationship in relationships] + + return await self.services["kg"].create_relationships_v3( + level=EntityLevel.CHUNK, + id=id, + relationships=relationships, + ) + + @self.router.post( + "/chunks/{id}/relationships/{relationship_id}", + summary="Update a relationship for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.chunks.update_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def update_relationship( + id: UUID = Path(..., description="The ID of the chunk to update the relationship for."), + relationship_id: UUID = Path(..., description="The ID of the relationship to update."), + relationship: Relationship = Body(..., description="The updated relationship."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + + return await self.services["kg"].update_relationship_v3( + level=EntityLevel.CHUNK, + id=id, + relationship_id=relationship_id, + relationship=relationship, + ) + + @self.router.delete( + "/chunks/{id}/relationships/{relationship_id}", + summary="Delete a relationship for a chunk", + ) + @self.base_endpoint + async def delete_relationship( + id: UUID = Path(..., description="The ID of the chunk to delete the relationship for."), + relationship_id: UUID = Path(..., description="The ID of the relationship to delete."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].delete_relationship_v3( + level=EntityLevel.CHUNK, + id=id, + relationship_id=relationship_id, + ) + + ##### DOCUMENT LEVEL OPERATIONS ##### + @self.router.get( + "/documents/{id}/entities", + summary="List entities for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
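+
+                                # Listing is paginated; the offset/limit arguments below mirror
+                                # this route's query parameters (a sketch, not confirmed SDK docs).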
+
+                                result = client.documents.list_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100)
+                                """
+                            ),
+                        },
+                    ]
+                },
+            )
+        @self.base_endpoint
+        async def list_entities(
+            id: UUID = Path(..., description="The ID of the document to retrieve entities for."),
+            entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the entities by."),
+            entity_categories: Optional[list[str]] = Query(None, description="A list of entity categories to filter the entities by."),
+            attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."),
+            offset: int = Query(0, ge=0, description="The offset of the first entity to retrieve."),
+            limit: int = Query(100, ge=0, le=20_000, description="The maximum number of entities to retrieve, up to 20,000."),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ) -> PaginatedResultsWrapper[list[Entity]]:
+            """
+            Retrieves a list of entities associated with a specific document.
+            """
+            if not auth_user.is_superuser:
+                raise R2RException("Only superusers can access this endpoint.", 403)
+
+            return await self.services["kg"].list_entities_v3(
+                level=EntityLevel.DOCUMENT,
+                id=id,
+                offset=offset,
+                limit=limit,
+                entity_names=entity_names,
+                entity_categories=entity_categories,
+                attributes=attributes,
+            )
+
+        @self.router.post(
+            "/documents/{id}/entities",
+            summary="Create entities for a document",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                                from r2r import R2RClient
+
+                                client = R2RClient("http://localhost:7272")
+                                # when using auth, do client.login(...)
+
+                                result = client.documents.create_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2])
+                                """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def create_entities(
+            id: UUID = Path(..., description="The ID of the document to create entities for."),
+            entities: list[Union[Entity, dict]] = Body(..., description="The entities to create."),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ):
+            if not auth_user.is_superuser:
+                raise R2RException("Only superusers can access this endpoint.", 403)
+
+            entities = [Entity(**entity) if isinstance(entity, dict) else entity for entity in entities]
+
+            # for each entity, set the level to DOCUMENT
+            for entity in entities:
+                if entity.level is None:
+                    entity.level = EntityLevel.DOCUMENT
+                elif entity.level != EntityLevel.DOCUMENT:
+                    raise R2RException("Entity level must be document or empty.", 400)
+
+            return await self.services["kg"].create_entities_v3(
+                level=EntityLevel.DOCUMENT,
+                id=id,
+                entities=entities,
+            )
+
+        @self.router.post(
+            "/documents/{id}/entities/{entity_id}",
+            summary="Update an entity for a document",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                                from r2r import R2RClient
+
+                                client = R2RClient("http://localhost:7272")
+                                # when using auth, do client.login(...)
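+
+                                # 'entity' is referenced below but never defined in this sample;
+                                # a hypothetical payload, field names assumed from this router's
+                                # Entity model:
+                                entity = {
+                                    "name": "Acme Corp",
+                                    "description": "Updated description of Acme Corp.",
+                                }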
+
+                                result = client.documents.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity)
+                                """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def update_entity(
+            id: UUID = Path(..., description="The ID of the document to update the entity for."),
+            entity_id: UUID = Path(..., description="The ID of the entity to update."),
+            entity: Entity = Body(..., description="The updated entity."),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ):
+            if not auth_user.is_superuser:
+                raise R2RException("Only superusers can access this endpoint.", 403)
+
+            return await self.services["kg"].update_entity_v3(
+                level=EntityLevel.DOCUMENT,
+                id=id,
+                entity_id=entity_id,
+                entity=entity,
+            )
+
+        @self.router.delete(
+            "/documents/{id}/entities/{entity_id}",
+            summary="Delete an entity for a document",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                                from r2r import R2RClient
+
+                                client = R2RClient("http://localhost:7272")
+                                # when using auth, do client.login(...)
+
+                                result = client.documents.delete_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000")
+                                """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def delete_entity(
+            id: UUID = Path(..., description="The ID of the document to delete the entity for."),
+            entity_id: UUID = Path(..., description="The ID of the entity to delete."),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ):
+            if not auth_user.is_superuser:
+                raise R2RException("Only superusers can access this endpoint.", 403)
+
+            return await self.services["kg"].delete_entity_v3(
+                level=EntityLevel.DOCUMENT,
+                id=id,
+                entity_id=entity_id,
+            )
+
+        ##### RELATIONSHIPS #####
+        @self.router.get(
+            "/documents/{id}/relationships",
+            summary="List relationships for a document",
+            openapi_extra={
+                "x-codeSamples": [
+                    {
+                        "lang": "Python",
+                        "source": textwrap.dedent(
+                            """
+                                from r2r import R2RClient
+
+                                client = R2RClient("http://localhost:7272")
+                                # when using auth, do client.login(...)
+
+                                result = client.documents.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1")
+                                """
+                        ),
+                    },
+                ]
+            },
+        )
+        @self.base_endpoint
+        async def list_relationships(
+            id: UUID = Path(..., description="The ID of the document to retrieve relationships for."),
+            entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the relationships by."),
+            relationship_types: Optional[list[str]] = Query(None, description="A list of relationship types to filter the relationships by."),
+            attributes: Optional[list[str]] = Query(None, description="A list of attributes to return.
By default, all attributes are returned."), + offset: int = Query(0, ge=0, description="The offset of the first relationship to retrieve."), + limit: int = Query(100, ge=0, le=20_000, description="The maximum number of relationships to retrieve, up to 20,000."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> PaginatedResultsWrapper[list[Relationship]]: + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + return await self.services["kg"].list_relationships_v3( + level=EntityLevel.DOCUMENT, + id=id, + entity_names=entity_names, + relationship_types=relationship_types, + attributes=attributes, + offset=offset, + limit=limit, + ) - ##### DOCUMENT LEVEL OPERATIONS ##### + @self.router.post( + "/documents/{id}/relationships", + summary="Create relationships for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + result = client.documents.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def create_relationships( + id: UUID = Path(..., description="The ID of the document to create relationships for."), + relationships: list[Union[Relationship, dict]] = Body(..., description="The relationships to create."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> ResultsWrapper[list[RelationshipResponse]]: + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + relationships = [Relationship(**relationship) if isinstance(relationship, dict) else relationship for relationship in relationships] + return await self.services["kg"].create_relationships_v3( + level=EntityLevel.DOCUMENT, + id=id, + relationships=relationships, + ) + + @self.router.post( + "/documents/{id}/relationships/{relationship_id}", + summary="Update a relationship for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
+ result = client.documents.update_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + """ + ), + }, + ] + }, + ) + @self.base_endpoint + async def update_relationship( + id: UUID = Path(..., description="The ID of the document to update the relationship for."), + relationship_id: UUID = Path(..., description="The ID of the relationship to update."), + relationship: Relationship = Body(..., description="The updated relationship."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + return await self.services["kg"].update_relationship_v3( + level=EntityLevel.DOCUMENT, + id=id, + relationship_id=relationship_id, + relationship=relationship, + ) + + @self.router.delete( + "/documents/{id}/relationships/{relationship_id}", + summary="Delete a relationship for a document", + ) + @self.base_endpoint + async def delete_relationship( + id: UUID = Path(..., description="The ID of the document to delete the relationship for."), + relationship_id: UUID = Path(..., description="The ID of the relationship to delete."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].delete_relationship_v3( + level=EntityLevel.DOCUMENT, + id=id, + relationship_id=relationship_id, + ) + ##### COLLECTION LEVEL OPERATIONS ##### + # Graph-level operations @self.router.post( "/graphs/{collection_id}", diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index f9bf5571a..028488d1f 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -440,6 +440,8 @@ async def get_triples( limit=limit or -1, ) + ##### Relationships ##### + @telemetry_event("list_triples") async def list_triples( self, @@ -458,6 +460,22 @@ async def list_triples( limit=limit or -1, ) + @telemetry_event("list_relationships") + async def list_relationships_v3( + self, + level: EntityLevel, + id: UUID, + entity_names: Optional[list[str]] = None, + relationship_types: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + ): + return await self.providers.database.list_relationships_v3( + level, id, entity_names, relationship_types, attributes, offset, limit + ) + + ##### Communities ##### @telemetry_event("get_communities") async def get_communities( self, diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index c69b4df23..a6d16ab58 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -297,7 +297,7 @@ async def get_graph_status(self, collection_id: UUID) -> dict: "community_count": community_count[0]["count"], } - + ### Relationships BEGIN #### async def add_triples( self, triples: list[Relationship], @@ -317,6 +317,52 @@ async def add_triples( [ele.to_dict() for ele in triples], table_name ) + async def list_relationships_v3( + self, + level: EntityLevel, + id: UUID, + entity_names: Optional[list[str]] = None, + relationship_types: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + ): + filter_query = "" + if entity_names: + filter_query += "AND (subject IN ($2) OR object IN ($2))" + if relationship_types: 
+ filter_query += "AND predicate IN ($3)" + + if level == EntityLevel.CHUNK: + QUERY = f""" + SELECT * FROM {self._get_table_name("chunk_triple")} WHERE $1 = ANY(chunk_ids) + {filter_query} + """ + elif level == EntityLevel.DOCUMENT: + QUERY = f""" + SELECT * FROM {self._get_table_name("chunk_triple")} WHERE $1 = document_id + {filter_query} + """ + elif level == EntityLevel.COLLECTION: + QUERY = f""" + WITH document_ids AS ( + SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) + ) + SELECT * FROM {self._get_table_name("chunk_triple")} WHERE document_id IN (SELECT document_id FROM document_ids) + {filter_query} + """ + + results = await self.connection_manager.fetch_query(QUERY, [id, entity_names, relationship_types]) + + if attributes: + results = [ + {k: v for k, v in result.items() if k in attributes} for result in results + ] + + return results + + ### Relationships END #### + async def add_kg_extractions( self, kg_extractions: list[KGExtraction], From e1b5443852477e1d1e2bf999b5f99267aba61c02 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 13 Nov 2024 09:29:11 -0800 Subject: [PATCH 06/21] up --- .../r2rV2ClientIntegrationSuperUser.test.ts | 2 +- .../r2rV2ClientIntegrationUser.test.ts | 2 +- js/sdk/src/models.tsx | 4 +- js/sdk/src/r2rClient.ts | 16 +- py/cli/commands/kg.py | 18 +- py/core/__init__.py | 2 +- py/core/base/api/models/__init__.py | 2 +- py/core/base/providers/database.py | 66 ++-- py/core/configs/full_local_llm.toml | 6 +- py/core/configs/local_llm.toml | 6 +- .../examples/scripts/advanced_kg_cookbook.py | 2 +- .../scripts/test_v3_sdk/test_v3_sdk_graph.py | 2 +- py/core/main/abstractions.py | 2 +- py/core/main/api/v2/kg_router.py | 28 +- py/core/main/api/v3/graph_router.py | 11 +- py/core/main/assembly/factory.py | 14 +- .../main/orchestration/hatchet/kg_workflow.py | 6 +- .../main/orchestration/simple/kg_workflow.py | 4 +- py/core/main/services/kg_service.py | 44 +-- py/core/pipes/__init__.py | 4 +- py/core/pipes/kg/clustering.py | 6 +- py/core/pipes/kg/community_summary.py | 48 +-- py/core/pipes/kg/entity_description.py | 16 +- ...raction.py => relationships_extraction.py} | 38 +-- py/core/providers/database/kg.py | 282 +++++++++--------- .../prompts/graphrag_community_reports.yaml | 6 +- .../prompts/graphrag_entity_description.yaml | 10 +- ...ag_relationships_extraction_few_shot.yaml} | 6 +- .../d342e632358a_migrate_to_asyncpg.py | 4 +- py/r2r.toml | 6 +- py/sdk/v2/mixins/kg.py | 20 +- py/sdk/v3/graphs.py | 6 +- py/shared/abstractions/graph.py | 6 +- py/shared/abstractions/kg.py | 10 +- py/shared/abstractions/vector.py | 4 +- py/shared/api/models/__init__.py | 1 + py/shared/api/models/kg/responses.py | 30 +- .../pipes/test_kg_community_summary_pipe.py | 12 +- py/tests/core/providers/kg/test_kg_logic.py | 84 +++--- py/tests/integration/runner_cli.py | 4 +- .../python-backend/prompts.yaml | 6 +- .../ycombinator_graphrag/web-app/types.ts | 2 +- 42 files changed, 426 insertions(+), 422 deletions(-) rename py/core/pipes/kg/{triples_extraction.py => relationships_extraction.py} (87%) rename py/core/providers/database/prompts/{graphrag_triples_extraction_few_shot.yaml => graphrag_relationships_extraction_few_shot.yaml} (98%) diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts index 1746830a3..302348e65 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationSuperUser.test.ts @@ -81,7 
+81,7 @@ let newCollectionId: string; * X createGraph * X enrichGraph * X getEntities - * X getTriples + * X getRelationships * X getCommunities * X getTunedPrompt * X deduplicateEntities diff --git a/js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts b/js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts index 93f0d3131..5c290d96d 100644 --- a/js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts +++ b/js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts @@ -77,7 +77,7 @@ const baseUrl = "http://localhost:7272"; * X createGraph * X enrichGraph * X getEntities - * X getTriples + * X getRelationships * X getCommunities * X getTunedPrompt * X deduplicateEntities diff --git a/js/sdk/src/models.tsx b/js/sdk/src/models.tsx index 1684ba14d..b0aaa9a05 100644 --- a/js/sdk/src/models.tsx +++ b/js/sdk/src/models.tsx @@ -74,13 +74,13 @@ export enum KGRunType { } export interface KGCreationSettings { - kg_triples_extraction_prompt?: string; + kg_relationships_extraction_prompt?: string; kg_entity_description_prompt?: string; force_kg_creation?: boolean; entity_types?: string[]; relation_types?: string[]; extractions_merge_count?: number; - max_knowledge_triples?: number; + max_knowledge_relationships?: number; max_description_input_length?: number; generation_config?: GenerationConfig; } diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 150d38229..3a658bbfa 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -1682,21 +1682,21 @@ export class r2rClient extends BaseClient { } /** - * Retrieve triples from the knowledge graph. + * Retrieve relationships from the knowledge graph. * @returns A promise that resolves to the response from the server. * @param collection_id The ID of the collection to retrieve entities for. * @param offset The offset for pagination. * @param limit The limit for pagination. * @param entity_level The level of entity to filter by. - * @param triple_ids Triple IDs to filter by. + * @param relationship_ids Relationship IDs to filter by. 
*/ - @feature("getTriples") - async getTriples( + @feature("getRelationships") + async getRelationships( collection_id?: string, offset?: number, limit?: number, entity_level?: string, - triple_ids?: string[], + relationship_ids?: string[], ): Promise { this._ensureAuthenticated(); @@ -1713,11 +1713,11 @@ export class r2rClient extends BaseClient { if (entity_level !== undefined) { params.entity_level = entity_level; } - if (triple_ids !== undefined) { - params.entity_ids = triple_ids; + if (relationship_ids !== undefined) { + params.entity_ids = relationship_ids; } - return this._makeRequest("GET", `triples`, { params }); + return this._makeRequest("GET", `relationships`, { params }); } /** diff --git a/py/cli/commands/kg.py b/py/cli/commands/kg.py index cb989dc25..f01236b15 100644 --- a/py/cli/commands/kg.py +++ b/py/cli/commands/kg.py @@ -223,7 +223,7 @@ async def get_entities( @click.option( "--collection-id", required=True, - help="Collection ID to retrieve triples from.", + help="Collection ID to retrieve relationships from.", ) @click.option( "--offset", @@ -238,9 +238,9 @@ async def get_entities( help="Limit for pagination.", ) @click.option( - "--triple-ids", + "--relationship-ids", multiple=True, - help="Triple IDs to filter by.", + help="Relationship IDs to filter by.", ) @click.option( "--entity-names", @@ -248,21 +248,21 @@ async def get_entities( help="Entity names to filter by.", ) @pass_context -async def get_triples( - ctx, collection_id, offset, limit, triple_ids, entity_names +async def get_relationships( + ctx, collection_id, offset, limit, relationship_ids, entity_names ): """ - Retrieve triples from the knowledge graph. + Retrieve relationships from the knowledge graph. """ client = ctx.obj with timer(): - response = await client.get_triples( + response = await client.get_relationships( collection_id, offset, limit, list(entity_names), - list(triple_ids), + list(relationship_ids), ) click.echo(json.dumps(response, indent=2)) @@ -284,7 +284,7 @@ async def delete_graph_for_collection(ctx, collection_id, cascade): """ Delete the graph for a given collection. - NOTE: Setting the cascade flag to true will delete entities and triples for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and triples for all documents in the collection. + NOTE: Setting the cascade flag to true will delete entities and relationships for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and relationships for all documents in the collection. 
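+
+    Example (a sketch; the command name is assumed from Click's convention of
+    hyphenating the function name, and is not taken from the docs):
+
+        r2r delete-graph-for-collection --collection-id <collection-id> --cascade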
""" client = ctx.obj diff --git a/py/core/__init__.py b/py/core/__init__.py index 5a512c7e7..35dad1455 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -215,7 +215,7 @@ ## PIPES "SearchPipe", "EmbeddingPipe", - "KGTriplesExtractionPipe", + "KGRelationshipsExtractionPipe", "ParsingPipe", "QueryTransformPipe", "SearchRAGPipe", diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 4403e9352..d71a4eef0 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -30,7 +30,7 @@ WrappedKGEnrichmentResponse, WrappedKGEntitiesResponse, WrappedKGEntityDeduplicationResponse, - WrappedKGTriplesResponse, + WrappedKGRelationshipsResponse, WrappedKGTunePromptResponse, ) from shared.api.models.management.responses import ( diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 2788c364c..cf5ae10b7 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -623,12 +623,12 @@ async def add_entities( pass @abstractmethod - async def add_triples( + async def add_relationships( self, - triples: list[Relationship], - table_name: str = "chunk_triple", + relationships: list[Relationship], + table_name: str = "chunk_relationship", ) -> None: - """Add triples to storage.""" + """Add relationships to storage.""" pass @abstractmethod @@ -740,15 +740,15 @@ async def get_entities( pass @abstractmethod - async def get_triples( + async def get_relationships( self, offset: int, limit: int, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, ) -> dict: - """Get triples from storage.""" + """Get relationships from storage.""" pass @abstractmethod @@ -763,12 +763,12 @@ async def get_entity_count( pass @abstractmethod - async def get_triple_count( + async def get_relationship_count( self, collection_id: Optional[UUID] = None, document_id: Optional[UUID] = None, ) -> int: - """Get triple count.""" + """Get relationship count.""" pass # Cost estimation methods @@ -802,8 +802,8 @@ async def create_vector_index(self) -> None: raise NotImplementedError @abstractmethod - async def delete_triples(self, triple_ids: list[int]) -> None: - """Delete triples.""" + async def delete_relationships(self, relationship_ids: list[int]) -> None: + """Delete relationships.""" raise NotImplementedError @abstractmethod @@ -827,8 +827,8 @@ async def update_kg_search_prompt(self) -> None: raise NotImplementedError @abstractmethod - async def upsert_triples(self) -> None: - """Upsert triples.""" + async def upsert_relationships(self) -> None: + """Upsert relationships.""" raise NotImplementedError @abstractmethod @@ -839,7 +839,7 @@ async def get_existing_entity_extraction_ids( raise NotImplementedError @abstractmethod - async def get_all_triples(self, collection_id: UUID) -> List[Triple]: + async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: raise NotImplementedError @abstractmethod @@ -1517,13 +1517,13 @@ async def add_entities( entities, table_name, conflict_columns ) - async def add_triples( + async def add_relationships( self, - triples: list[Relationship], - table_name: str = "chunk_triple", + relationships: list[Relationship], + table_name: str = "chunk_relationship", ) -> None: - """Forward to KG handler add_triples method.""" - return await self.kg_handler.add_triples(triples, table_name) + """Forward to KG handler add_relationships method.""" + 
return await self.kg_handler.add_relationships(relationships, table_name) async def get_entity_map( self, offset: int, limit: int, document_id: UUID @@ -1638,21 +1638,21 @@ async def get_entities( extra_columns=extra_columns, ) - async def get_triples( + async def get_relationships( self, offset: int, limit: int, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, ) -> dict: - """Forward to KG handler get_triples method.""" - return await self.kg_handler.get_triples( + """Forward to KG handler get_relationships method.""" + return await self.kg_handler.get_relationships( offset=offset, limit=limit, collection_id=collection_id, entity_names=entity_names, - triple_ids=triple_ids, + relationship_ids=relationship_ids, ) async def get_entity_count( @@ -1667,13 +1667,13 @@ async def get_entity_count( collection_id, document_id, distinct, entity_table_name ) - async def get_triple_count( + async def get_relationship_count( self, collection_id: Optional[UUID] = None, document_id: Optional[UUID] = None, ) -> int: - """Forward to KG handler get_triple_count method.""" - return await self.kg_handler.get_triple_count( + """Forward to KG handler get_relationship_count method.""" + return await self.kg_handler.get_relationship_count( collection_id, document_id ) @@ -1704,8 +1704,8 @@ async def get_deduplication_estimate( collection_id, kg_deduplication_settings ) - async def get_all_triples(self, collection_id: UUID) -> List[Triple]: - return await self.kg_handler.get_all_triples(collection_id) + async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: + return await self.kg_handler.get_all_relationships(collection_id) async def update_entity_descriptions(self, entities: list[Entity]): return await self.kg_handler.update_entity_descriptions(entities) @@ -1718,8 +1718,8 @@ async def vector_query( async def create_vector_index(self) -> None: return await self.kg_handler.create_vector_index() - async def delete_triples(self, triple_ids: list[int]) -> None: - return await self.kg_handler.delete_triples(triple_ids) + async def delete_relationships(self, relationship_ids: list[int]) -> None: + return await self.kg_handler.delete_relationships(relationship_ids) async def get_schema(self) -> Any: return await self.kg_handler.get_schema() @@ -1733,8 +1733,8 @@ async def update_extraction_prompt(self) -> None: async def update_kg_search_prompt(self) -> None: return await self.kg_handler.update_kg_search_prompt() - async def upsert_triples(self) -> None: - return await self.kg_handler.upsert_triples() + async def upsert_relationships(self) -> None: + return await self.kg_handler.upsert_relationships() async def get_existing_entity_extraction_ids( self, document_id: UUID diff --git a/py/core/configs/full_local_llm.toml b/py/core/configs/full_local_llm.toml index 548e5ca84..97e5e253c 100644 --- a/py/core/configs/full_local_llm.toml +++ b/py/core/configs/full_local_llm.toml @@ -23,13 +23,13 @@ provider = "postgres" [database.kg_creation_settings] kg_entity_description_prompt = "graphrag_entity_description" - kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot" + kg_relationships_extraction_prompt = "graphrag_relationships_extraction_few_shot" entity_types = [] # if empty, all entities are extracted relation_types = [] # if empty, all relations are extracted fragment_merge_count = 4 # number of fragments to merge into a single extraction - max_knowledge_triples 
= 100
+    max_knowledge_relationships = 100
     max_description_input_length = 65536
-    generation_config = { model = "ollama/llama3.1" } # and other params, model used for triplet extraction
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for relationship extraction
 
 [database.kg_entity_deduplication_settings]
   kg_entity_deduplication_type = "by_name"
diff --git a/py/core/configs/local_llm.toml b/py/core/configs/local_llm.toml
index 9a0196f96..71338b8b8 100644
--- a/py/core/configs/local_llm.toml
+++ b/py/core/configs/local_llm.toml
@@ -30,13 +30,13 @@ provider = "postgres"
 
   [database.kg_creation_settings]
     kg_entity_description_prompt = "graphrag_entity_description"
-    kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
+    kg_relationships_extraction_prompt = "graphrag_relationships_extraction_few_shot"
    entity_types = [] # if empty, all entities are extracted
    relation_types = [] # if empty, all relations are extracted
    fragment_merge_count = 4 # number of fragments to merge into a single extraction
-    max_knowledge_triples = 100
+    max_knowledge_relationships = 100
    max_description_input_length = 65536
-    generation_config = { model = "ollama/llama3.1" } # and other params, model used for triplet extraction
+    generation_config = { model = "ollama/llama3.1" } # and other params, model used for relationship extraction
 
 [database.kg_entity_deduplication_settings]
   kg_entity_deduplication_type = "by_name"
diff --git a/py/core/examples/scripts/advanced_kg_cookbook.py b/py/core/examples/scripts/advanced_kg_cookbook.py
index 4084642fc..586888dfe 100644
--- a/py/core/examples/scripts/advanced_kg_cookbook.py
+++ b/py/core/examples/scripts/advanced_kg_cookbook.py
@@ -126,7 +126,7 @@ def main(
 
     client = R2RClient(base_url=base_url)
 
-    prompt = "graphrag_triples_extraction_few_shot"
+    prompt = "graphrag_relationships_extraction_few_shot"
 
     client.update_prompt(
         prompt,
diff --git a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py
index 46e107580..1cfb5435d 100644
--- a/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py
+++ b/py/core/examples/scripts/test_v3_sdk/test_v3_sdk_graph.py
@@ -158,7 +158,7 @@ def test_graph_operations(collection_id):
     print("\n--- Test 8: Tune Prompt ---")
     tune_result = client.graphs.tune_prompt(
         collection_id=collection_id,
-        prompt_name="graphrag_triples_extraction_few_shot",
+        prompt_name="graphrag_relationships_extraction_few_shot",
         documents_limit=100,
         chunks_limit=1000,
     )
diff --git a/py/core/main/abstractions.py b/py/core/main/abstractions.py
index 4a7b71f28..20cea75bc 100644
--- a/py/core/main/abstractions.py
+++ b/py/core/main/abstractions.py
@@ -48,7 +48,7 @@ class R2RPipes(BaseModel):
     parsing_pipe: AsyncPipe
     embedding_pipe: AsyncPipe
     kg_search_pipe: AsyncPipe
-    kg_triples_extraction_pipe: AsyncPipe
+    kg_relationships_extraction_pipe: AsyncPipe
     kg_storage_pipe: AsyncPipe
     kg_entity_description_pipe: AsyncPipe
     kg_clustering_pipe: AsyncPipe
diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py
index f5894d89e..ad335f4b6 100644
--- a/py/core/main/api/v2/kg_router.py
+++ b/py/core/main/api/v2/kg_router.py
@@ -14,7 +14,7 @@
     WrappedKGEnrichmentResponse,
     WrappedKGEntitiesResponse,
     WrappedKGEntityDeduplicationResponse,
-    WrappedKGTriplesResponse,
+    WrappedKGRelationshipsResponse,
     WrappedKGTunePromptResponse,
 )
 
@@ -289,17 +289,17 @@ async def get_entities(
                 limit=limit,
             )
 
-    @self.router.get("/triples")
+    @self.router.get("/relationships")
@self.base_endpoint - async def get_triples( + async def get_relationships( collection_id: Optional[UUID] = Query( - None, description="Collection ID to retrieve triples from." + None, description="Collection ID to retrieve relationships from." ), entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." ), - triple_ids: Optional[list[str]] = Query( - None, description="Triple IDs to filter by." + relationship_ids: Optional[list[str]] = Query( + None, description="Relationship IDs to filter by." ), offset: int = Query(0, ge=0, description="Offset for pagination."), limit: int = Query( @@ -308,9 +308,9 @@ async def get_triples( description="Number of items to return. Use -1 to return all items.", ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedKGTriplesResponse: + ) -> WrappedKGRelationshipsResponse: """ - Retrieve triples from the knowledge graph. + Retrieve relationships from the knowledge graph. """ if not auth_user.is_superuser: logger.warning("Implement permission checks here.") @@ -320,12 +320,12 @@ async def get_triples( auth_user.id ) - return await self.service.get_triples( + return await self.service.get_relationships( offset=offset, limit=limit, collection_id=collection_id, entity_names=entity_names, - triple_ids=triple_ids, + relationship_ids=relationship_ids, ) @self.router.get("/communities") @@ -434,7 +434,7 @@ async def deduplicate_entities( async def get_tuned_prompt( prompt_name: str = Query( ..., - description="The name of the prompt to tune. Valid options are 'graphrag_triples_extraction_few_shot', 'graphrag_entity_description' and 'graphrag_community_reports'.", + description="The name of the prompt to tune. Valid options are 'graphrag_relationships_extraction_few_shot', 'graphrag_entity_description' and 'graphrag_community_reports'.", ), collection_id: Optional[UUID] = Query( None, description="Collection ID to retrieve communities from." @@ -481,7 +481,7 @@ async def delete_graph_for_collection( ), cascade: bool = Body( # FIXME: This should be a query parameter default=False, - description="Whether to cascade the deletion, and delete entities and triples belonging to the collection.", + description="Whether to cascade the deletion, and delete entities and relationships belonging to the collection.", ), auth_user=Depends(self.service.providers.auth.auth_wrapper), ): @@ -489,9 +489,9 @@ async def delete_graph_for_collection( Delete the graph for a given collection. Note that this endpoint may delete a large amount of data created by the KG pipeline, this deletion is irreversible, and recreating the graph may be an expensive operation. Notes: - The endpoint deletes all communities for a given collection. If the cascade flag is set to true, the endpoint also deletes all the entities and triples associated with the collection. + The endpoint deletes all communities for a given collection. If the cascade flag is set to true, the endpoint also deletes all the entities and relationships associated with the collection. - WARNING: Setting this flag to true will delete entities and triples for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and triples for all documents in the collection. + WARNING: Setting this flag to true will delete entities and relationships for documents that are shared across multiple collections. 
Do not set this flag unless you are absolutely sure that you want to delete the entities and relationships for all documents in the collection. """ if not auth_user.is_superuser: diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 322656e8e..3661f8be0 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1,6 +1,6 @@ import logging import textwrap -from typing import Optional +from typing import Optional, Union from uuid import UUID from fastapi import Body, Depends, Path, Query @@ -13,6 +13,7 @@ WrappedKGEnrichmentResponse, WrappedKGEntityDeduplicationResponse, WrappedKGTunePromptResponse, + WrappedKGRelationshipsResponse, ) from core.providers import ( HatchetOrchestrationProvider, @@ -503,7 +504,7 @@ async def create_relationships( id: UUID = Path(..., description="The ID of the chunk to create relationships for."), relationships: list[Union[Relationship, dict]] = Body(..., description="The relationships to create."), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[list[RelationshipResponse]]: + ) -> WrappedKGRelationshipsResponse: if not auth_user.is_superuser: raise R2RException("Only superusers can access this endpoint.", 403) @@ -2229,7 +2230,7 @@ async def delete_community( result = client.graphs.tune_prompt( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - prompt_name="graphrag_triples_extraction_few_shot", + prompt_name="graphrag_relationships_extraction_few_shot", documents_limit=100, chunks_limit=1000 )""" @@ -2243,7 +2244,7 @@ async def delete_community( -H "Content-Type: application/json" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -d '{ - "prompt_name": "graphrag_triples_extraction_few_shot", + "prompt_name": "graphrag_relationships_extraction_few_shot", "documents_limit": 100, "chunks_limit": 1000 }'""" @@ -2257,7 +2258,7 @@ async def tune_prompt( collection_id: UUID = Path(...), prompt_name: str = Body( ..., - description="The prompt to tune. Valid options: graphrag_triples_extraction_few_shot, graphrag_entity_description, graphrag_community_reports", + description="The prompt to tune. 
Valid options: graphrag_relationships_extraction_few_shot, graphrag_entity_description, graphrag_community_reports", ), documents_offset: int = Body(0, ge=0), documents_limit: int = Body(100, ge=1), diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index 8ba3efc76..fbe6f7a8d 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -364,7 +364,7 @@ def create_pipes( self, parsing_pipe_override: Optional[AsyncPipe] = None, embedding_pipe_override: Optional[AsyncPipe] = None, - kg_triples_extraction_pipe_override: Optional[AsyncPipe] = None, + kg_relationships_extraction_pipe_override: Optional[AsyncPipe] = None, kg_storage_pipe_override: Optional[AsyncPipe] = None, kg_search_pipe_override: Optional[AsyncPipe] = None, vector_storage_pipe_override: Optional[AsyncPipe] = None, @@ -389,8 +389,8 @@ def create_pipes( ), embedding_pipe=embedding_pipe_override or self.create_embedding_pipe(*args, **kwargs), - kg_triples_extraction_pipe=kg_triples_extraction_pipe_override - or self.create_kg_triples_extraction_pipe(*args, **kwargs), + kg_relationships_extraction_pipe=kg_relationships_extraction_pipe_override + or self.create_kg_relationships_extraction_pipe(*args, **kwargs), kg_storage_pipe=kg_storage_pipe_override or self.create_kg_storage_pipe(*args, **kwargs), vector_storage_pipe=vector_storage_pipe_override @@ -535,14 +535,14 @@ def create_vector_search_pipe(self, *args, **kwargs) -> Any: config=AsyncPipe.PipeConfig(name="routing_search_pipe"), ) - def create_kg_triples_extraction_pipe(self, *args, **kwargs) -> Any: - from core.pipes import KGTriplesExtractionPipe + def create_kg_relationships_extraction_pipe(self, *args, **kwargs) -> Any: + from core.pipes import KGRelationshipsExtractionPipe - return KGTriplesExtractionPipe( + return KGRelationshipsExtractionPipe( logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, - config=AsyncPipe.PipeConfig(name="kg_triples_extraction_pipe"), + config=AsyncPipe.PipeConfig(name="kg_relationships_extraction_pipe"), ) def create_kg_storage_pipe(self, *args, **kwargs) -> Any: diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 11d9c05a3..63e408692 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -89,17 +89,17 @@ async def kg_extract(self, context: Context) -> dict: # context.log(f"Running KG Extraction for collection ID: {input_data['collection_id']}") document_id = input_data["document_id"] - await self.kg_service.kg_triples_extraction( + await self.kg_service.kg_relationships_extraction( document_id=uuid.UUID(document_id), **input_data["kg_creation_settings"], ) logger.info( - f"Successfully ran kg triples extraction for document {document_id}" + f"Successfully ran kg relationships extraction for document {document_id}" ) return { - "result": f"successfully ran kg triples extraction for document {document_id} in {time.time() - start_time:.2f} seconds", + "result": f"successfully ran kg relationships extraction for document {document_id} in {time.time() - start_time:.2f} seconds", } @orchestration_provider.step( diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index 60bc47735..6122a6148 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -45,10 +45,10 @@ async def 
create_graph(input_data): ) for _, document_id in enumerate(document_ids): - # Extract triples from the document + # Extract relationships from the document try: - await service.kg_triples_extraction( + await service.kg_relationships_extraction( document_id=document_id, **input_data["kg_creation_settings"], ) diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 29ac0f10d..d2f9a2a15 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -58,13 +58,13 @@ def __init__( logging_connection, ) - @telemetry_event("kg_triples_extraction") - async def kg_triples_extraction( + @telemetry_event("kg_relationships_extraction") + async def kg_relationships_extraction( self, document_id: UUID, generation_config: GenerationConfig, extraction_merge_count: int, - max_knowledge_triples: int, + max_knowledge_relationships: int, entity_types: list[str], relation_types: list[str], **kwargs, @@ -81,13 +81,13 @@ async def kg_triples_extraction( status=KGExtractionStatus.PROCESSING, ) - triples = await self.pipes.kg_triples_extraction_pipe.run( - input=self.pipes.kg_triples_extraction_pipe.Input( + relationships = await self.pipes.kg_relationships_extraction_pipe.run( + input=self.pipes.kg_relationships_extraction_pipe.Input( message={ "document_id": document_id, "generation_config": generation_config, "extraction_merge_count": extraction_merge_count, - "max_knowledge_triples": max_knowledge_triples, + "max_knowledge_relationships": max_knowledge_relationships, "entity_types": entity_types, "relation_types": relation_types, "logger": logger, @@ -102,7 +102,7 @@ async def kg_triples_extraction( ) result_gen = await self.pipes.kg_storage_pipe.run( - input=self.pipes.kg_storage_pipe.Input(message=triples), + input=self.pipes.kg_storage_pipe.Input(message=relationships), state=None, run_manager=self.run_manager, ) @@ -440,58 +440,58 @@ async def get_entities( entity_table_name=entity_table_name, ) - @telemetry_event("get_triples") - async def get_triples( + @telemetry_event("get_relationships") + async def get_relationships( self, offset: int, limit: int, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, **kwargs, ): - return await self.providers.database.get_triples( + return await self.providers.database.get_relationships( offset=offset, limit=limit, collection_id=collection_id, entity_names=entity_names, - triple_ids=triple_ids, + relationship_ids=relationship_ids, ) - @telemetry_event("list_triples") - async def list_triples( + @telemetry_event("list_relationships") + async def list_relationships( self, offset: int, limit: int, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, **kwargs, ): - return await self.providers.database.get_triples( + return await self.providers.database.get_relationships( offset=offset, limit=limit, collection_id=collection_id, entity_names=entity_names, - triple_ids=triple_ids, + relationship_ids=relationship_ids, ) ##### Relationships ##### - @telemetry_event("list_triples") - async def list_triples( + @telemetry_event("list_relationships") + async def list_relationships( self, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, offset: Optional[int] = None, 
limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.get_triples( + return await self.providers.database.get_relationships( collection_id=collection_id, entity_names=entity_names, - triple_ids=triple_ids, + relationship_ids=relationship_ids, offset=offset or 0, limit=limit or -1, ) diff --git a/py/core/pipes/__init__.py b/py/core/pipes/__init__.py index a83373be8..c51d3907c 100644 --- a/py/core/pipes/__init__.py +++ b/py/core/pipes/__init__.py @@ -10,7 +10,7 @@ from .kg.entity_description import KGEntityDescriptionPipe from .kg.prompt_tuning import KGPromptTuningPipe from .kg.storage import KGStoragePipe -from .kg.triples_extraction import KGTriplesExtractionPipe +from .kg.relationships_extraction import KGRelationshipsExtractionPipe from .retrieval.kg_search_pipe import KGSearchSearchPipe from .retrieval.multi_search import MultiSearchPipe from .retrieval.query_transform_pipe import QueryTransformPipe @@ -23,7 +23,7 @@ "SearchPipe", "GeneratorPipe", "EmbeddingPipe", - "KGTriplesExtractionPipe", + "KGRelationshipsExtractionPipe", "KGSearchSearchPipe", "KGEntityDescriptionPipe", "ParsingPipe", diff --git a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index 1e06be3fd..acbcdf944 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -16,7 +16,7 @@ class KGClusteringPipe(AsyncPipe): """ - Clusters entities and triples into communities within the knowledge graph using hierarchical Leiden algorithm. + Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm. """ def __init__( @@ -46,7 +46,7 @@ async def cluster_kg( leiden_params: dict, ): """ - Clusters the knowledge graph triples into communities using hierarchical Leiden algorithm. Uses graspologic library. + Clusters the knowledge graph relationships into communities using hierarchical Leiden algorithm. Uses graspologic library. """ num_communities = ( @@ -73,7 +73,7 @@ async def _run_logic( # type: ignore **kwargs: Any, ) -> AsyncGenerator[dict, None]: """ - Executes the KG clustering pipe: clustering entities and triples into communities. + Executes the KG clustering pipe: clustering entities and relationships into communities. """ collection_id = input.message["collection_id"] diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 2e2653654..69afaa7da 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -23,7 +23,7 @@ class KGCommunitySummaryPipe(AsyncPipe): """ - Clusters entities and triples into communities within the knowledge graph using hierarchical Leiden algorithm. + Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm. 
""" def __init__( @@ -51,28 +51,28 @@ def __init__( async def community_summary_prompt( self, entities: list[Entity], - triples: list[Relationship], + relationships: list[Relationship], max_summary_input_length: int, ): entity_map: dict[str, dict[str, list[Any]]] = {} for entity in entities: if not entity.name in entity_map: - entity_map[entity.name] = {"entities": [], "triples": []} + entity_map[entity.name] = {"entities": [], "relationships": []} entity_map[entity.name]["entities"].append(entity) - for triple in triples: - if not triple.subject in entity_map: - entity_map[triple.subject] = { + for relationship in relationships: + if not relationship.subject in entity_map: + entity_map[relationship.subject] = { "entities": [], - "triples": [], + "relationships": [], } - entity_map[triple.subject]["triples"].append(triple) + entity_map[relationship.subject]["relationships"].append(relationship) - # sort in descending order of triple count + # sort in descending order of relationship count sorted_entity_map = sorted( entity_map.items(), - key=lambda x: len(x[1]["triples"]), + key=lambda x: len(x[1]["relationships"]), reverse=True, ) @@ -90,15 +90,15 @@ async def _get_entity_descriptions_string( for entity in sampled_entities ) - async def _get_triples_string(triples: list, max_count: int = 100): - sampled_triples = ( - random.sample(triples, max_count) - if len(triples) > max_count - else triples + async def _get_relationships_string(relationships: list, max_count: int = 100): + sampled_relationships = ( + random.sample(relationships, max_count) + if len(relationships) > max_count + else relationships ) return "\n".join( - f"{triple.id},{triple.subject},{triple.object},{triple.predicate},{triple.description}" - for triple in sampled_triples + f"{relationship.id},{relationship.subject},{relationship.object},{relationship.predicate},{relationship.description}" + for relationship in sampled_relationships ) prompt = "" @@ -106,14 +106,14 @@ async def _get_triples_string(triples: list, max_count: int = 100): entity_descriptions = await _get_entity_descriptions_string( entity_data["entities"] ) - triples = await _get_triples_string(entity_data["triples"]) + relationships = await _get_relationships_string(entity_data["relationships"]) prompt += f""" Entity: {entity_name} Descriptions: {entity_descriptions} - Triples: - {triples} + Relationships: + {relationships} """ if len(prompt) > max_summary_input_length: @@ -137,16 +137,16 @@ async def process_community( Process a community by summarizing it and creating a summary embedding and storing it to a database. """ - community_level, entities, triples = ( + community_level, entities, relationships = ( await self.database_provider.get_community_details( community_number=community_number, collection_id=collection_id, ) ) - if entities == [] and triples == []: + if entities == [] and relationships == []: raise ValueError( - f"Community {community_number} has no entities or triples." + f"Community {community_number} has no entities or relationships." 
) for attempt in range(3): @@ -160,7 +160,7 @@ async def process_community( "input_text": ( await self.community_summary_prompt( entities, - triples, + relationships, max_summary_input_length, ) ), diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 1787d95e0..475685fb6 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -1,4 +1,4 @@ -# pipe to extract nodes/triples etc +# pipe to extract nodes/relationships etc import asyncio import logging @@ -73,16 +73,16 @@ def truncate_info(info_list, max_length): return truncated_info async def process_entity( - entities, triples, max_description_input_length, document_id + entities, relationships, max_description_input_length, document_id ): entity_info = [ f"{entity.name}, {entity.description}" for entity in entities ] - triples_txt = [ - f"{i+1}: {triple.subject}, {triple.object}, {triple.predicate} - Summary: {triple.description}" - for i, triple in enumerate(triples) + relationships_txt = [ + f"{i+1}: {relationship.subject}, {relationship.object}, {relationship.predicate} - Summary: {relationship.description}" + for i, relationship in enumerate(relationships) ] # potentially slow at scale, but set to avoid duplicates @@ -107,8 +107,8 @@ async def process_entity( entity_info, max_description_input_length, ), - "triples_txt": truncate_info( - triples_txt, + "relationships_txt": truncate_info( + relationships_txt, max_description_input_length, ), }, @@ -171,7 +171,7 @@ async def process_entity( workflows.append( process_entity( entity_info["entities"], - entity_info["triples"], + entity_info["relationships"], input.message["max_description_input_length"], document_id, ) diff --git a/py/core/pipes/kg/triples_extraction.py b/py/core/pipes/kg/relationships_extraction.py similarity index 87% rename from py/core/pipes/kg/triples_extraction.py rename to py/core/pipes/kg/relationships_extraction.py index f9dcd0514..ce103f4c3 100644 --- a/py/core/pipes/kg/triples_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -32,7 +32,7 @@ class ClientError(Exception): pass -class KGTriplesExtractionPipe(AsyncPipe[dict]): +class KGRelationshipsExtractionPipe(AsyncPipe[dict]): """ Extracts knowledge graph information from document extractions. """ @@ -56,7 +56,7 @@ def __init__( super().__init__( logging_provider=logging_provider, config=config - or AsyncPipe.PipeConfig(name="default_kg_triples_extraction_pipe"), + or AsyncPipe.PipeConfig(name="default_kg_relationships_extraction_pipe"), ) self.database_provider = database_provider self.llm_provider = llm_provider @@ -69,7 +69,7 @@ async def extract_kg( self, extractions: list[DocumentChunk], generation_config: GenerationConfig, - max_knowledge_triples: int, + max_knowledge_relationships: int, entity_types: list[str], relation_types: list[str], retries: int = 5, @@ -78,17 +78,17 @@ async def extract_kg( total_tasks: Optional[int] = None, ) -> KGExtraction: """ - Extracts NER triples from a extraction with retries. + Extracts NER relationships from a extraction with retries. 
""" # combine all extractions into a single string combined_extraction: str = " ".join([extraction.data for extraction in extractions]) # type: ignore messages = await self.database_provider.prompt_handler.get_message_payload( - task_prompt_name=self.database_provider.config.kg_creation_settings.graphrag_triples_extraction_few_shot, + task_prompt_name=self.database_provider.config.kg_creation_settings.graphrag_relationships_extraction_few_shot, task_inputs={ "input": combined_extraction, - "max_knowledge_triples": max_knowledge_triples, + "max_knowledge_relationships": max_knowledge_relationships, "entity_types": "\n".join(entity_types), "relation_types": "\n".join(relation_types), }, @@ -175,14 +175,14 @@ def parse_fn(response_str: str) -> Any: return entities_arr, relations_arr - entities, triples = parse_fn(kg_extraction) + entities, relationships = parse_fn(kg_extraction) return KGExtraction( extraction_ids=[ extraction.id for extraction in extractions ], document_id=extractions[0].document_id, entities=entities, - triples=triples, + relationships=relationships, ) except ( @@ -199,7 +199,7 @@ def parse_fn(response_str: str) -> Any: f"Failed after retries with for extraction {extractions[0].id} of document {extractions[0].document_id}: {e}" ) # raise e # you should raise an error. - # add metadata to entities and triples + # add metadata to entities and relationships logger.info( f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}", @@ -209,7 +209,7 @@ def parse_fn(response_str: str) -> Any: extraction_ids=[extraction.id for extraction in extractions], document_id=extractions[0].document_id, entities=[], - triples=[], + relationships=[], ) async def _run_logic( # type: ignore @@ -226,7 +226,7 @@ async def _run_logic( # type: ignore document_id = input.message["document_id"] generation_config = input.message["generation_config"] extraction_merge_count = input.message["extraction_merge_count"] - max_knowledge_triples = input.message["max_knowledge_triples"] + max_knowledge_relationships = input.message["max_knowledge_relationships"] entity_types = input.message["entity_types"] relation_types = input.message["relation_types"] @@ -237,7 +237,7 @@ async def _run_logic( # type: ignore logger = input.message.get("logger", logging.getLogger()) logger.info( - f"KGTriplesExtractionPipe: Processing document {document_id} for KG extraction", + f"KGRelationshipsExtractionPipe: Processing document {document_id} for KG extraction", ) # Then create the extractions from the results @@ -281,7 +281,7 @@ async def _run_logic( # type: ignore return logger.info( - f"KGTriplesExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds", + f"KGRelationshipsExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds", ) # sort the extractions accroding to chunk_order field in metadata in ascending order @@ -296,7 +296,7 @@ async def _run_logic( # type: ignore ] logger.info( - f"KGTriplesExtractionPipe: Extracting KG Triples for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds", + f"KGRelationshipsExtractionPipe: Extracting KG Relationships for document and created {len(extractions_groups)} tasks, time from start: {time.time() - start_time:.2f} seconds", ) tasks = [ @@ -304,7 +304,7 @@ async def _run_logic( # type: ignore self.extract_kg( extractions=extractions_group, 
generation_config=generation_config, - max_knowledge_triples=max_knowledge_triples, + max_knowledge_relationships=max_knowledge_relationships, entity_types=entity_types, relation_types=relation_types, task_id=task_id, @@ -318,7 +318,7 @@ async def _run_logic( # type: ignore total_tasks = len(tasks) logger.info( - f"KGTriplesExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete", + f"KGRelationshipsExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete", ) for completed_task in asyncio.as_completed(tasks): @@ -327,15 +327,15 @@ async def _run_logic( # type: ignore completed_tasks += 1 if completed_tasks % 100 == 0: logger.info( - f"KGTriplesExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks", + f"KGRelationshipsExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks", ) except Exception as e: - logger.error(f"Error in Extracting KG Triples: {e}") + logger.error(f"Error in Extracting KG Relationships: {e}") yield R2RDocumentProcessingError( document_id=document_id, error_message=str(e), ) logger.info( - f"KGTriplesExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds", + f"KGRelationshipsExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds", ) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index edf2b6b03..8da324acf 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -91,9 +91,9 @@ async def create_tables(self): """ await self.connection_manager.execute_query(query) - # raw triples table, also the final table. this will have embeddings. + # raw relationships table, also the final table. this will have embeddings. 
query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_triple")} ( + CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_relationship")} ( id SERIAL PRIMARY KEY, subject TEXT NOT NULL, predicate TEXT NOT NULL, @@ -148,7 +148,7 @@ async def create_tables(self): parent_cluster INT, level INT NOT NULL, is_final_cluster BOOLEAN NOT NULL, - triple_ids INT[] NOT NULL, + relationship_ids INT[] NOT NULL, collection_id UUID NOT NULL );""" @@ -270,8 +270,8 @@ async def get_graph_status(self, collection_id: UUID) -> dict: [document_ids], ) - chunk_triple_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_triple')} WHERE document_id = ANY($1)", + chunk_relationship_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1)", [document_ids], ) @@ -294,30 +294,30 @@ async def get_graph_status(self, collection_id: UUID) -> dict: "kg_extraction_statuses": kg_extraction_statuses, "kg_enrichment_status": kg_enrichment_statuses[0]["enrichment_status"], "chunk_entity_count": chunk_entity_count[0]["count"], - "chunk_triple_count": chunk_triple_count[0]["count"], + "chunk_relationship_count": chunk_relationship_count[0]["count"], "document_entity_count": document_entity_count[0]["count"], "collection_entity_count": collection_entity_count[0]["count"], "community_count": community_count[0]["count"], } ### Relationships BEGIN #### - async def add_triples( + async def add_relationships( self, - triples: list[Relationship], - table_name: str = "chunk_triple", + relationships: list[Relationship], + table_name: str = "chunk_relationship", ) -> None: """ - Upsert triples into the chunk_triple table. These are raw triples extracted from the document. + Upsert relationships into the chunk_relationship table. These are raw relationships extracted from the document. Args: - triples: list[Relationship]: list of triples to upsert + relationships: list[Relationship]: list of relationships to upsert table_name: str: name of the table to upsert into Returns: result: asyncpg.Record: result of the upsert operation """ return await self._add_objects( - [ele.to_dict() for ele in triples], table_name + [ele.to_dict() for ele in relationships], table_name ) async def list_relationships_v3( @@ -338,12 +338,12 @@ async def list_relationships_v3( if level == EntityLevel.CHUNK: QUERY = f""" - SELECT * FROM {self._get_table_name("chunk_triple")} WHERE $1 = ANY(chunk_ids) + SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE $1 = ANY(chunk_ids) {filter_query} """ elif level == EntityLevel.DOCUMENT: QUERY = f""" - SELECT * FROM {self._get_table_name("chunk_triple")} WHERE $1 = document_id + SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE $1 = document_id {filter_query} """ elif level == EntityLevel.COLLECTION: @@ -351,7 +351,7 @@ async def list_relationships_v3( WITH document_ids AS ( SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) ) - SELECT * FROM {self._get_table_name("chunk_triple")} WHERE document_id IN (SELECT document_id FROM document_ids) + SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE document_id IN (SELECT document_id FROM document_ids) {filter_query} """ @@ -372,7 +372,7 @@ async def add_kg_extractions( table_prefix: str = "chunk_", ) -> Tuple[int, int]: """ - Upsert entities and triples into the database. 
These are raw entities and triples extracted from the document fragments. + Upsert entities and relationships into the database. These are raw entities and relationships extracted from the document fragments. Args: kg_extractions: list[KGExtraction]: list of KG extractions to upsert @@ -389,7 +389,7 @@ async def add_kg_extractions( total_entities, total_relationships = ( total_entities + len(extraction.entities), - total_relationships + len(extraction.triples), + total_relationships + len(extraction.relationships), ) if extraction.entities: @@ -406,16 +406,16 @@ async def add_kg_extractions( extraction.entities, table_name=f"{table_prefix}entity" ) - if extraction.triples: - if not extraction.triples[0].extraction_ids: - for i in range(len(extraction.triples)): - extraction.triples[i].extraction_ids = ( + if extraction.relationships: + if not extraction.relationships[0].extraction_ids: + for i in range(len(extraction.relationships)): + extraction.relationships[i].extraction_ids = ( extraction.extraction_ids ) - extraction.triples[i].document_id = extraction.document_id + extraction.relationships[i].document_id = extraction.document_id - await self.add_triples( - extraction.triples, table_name=f"{table_prefix}triple" + await self.add_relationships( + extraction.relationships, table_name=f"{table_prefix}relationship" ) return (total_entities, total_relationships) @@ -466,38 +466,38 @@ async def get_entity_map( SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description, (SELECT array_agg(DISTINCT x) FROM unnest(t.extraction_ids) x) AS extraction_ids, t.document_id - FROM {self._get_table_name("chunk_triple")} t + FROM {self._get_table_name("chunk_relationship")} t JOIN entities_list el ON t.subject = el.name ORDER BY t.subject, t.predicate, t.object; """ - triples_list = await self.connection_manager.fetch_query( + relationships_list = await self.connection_manager.fetch_query( QUERY2, [document_id] ) - triples_list = [ + relationships_list = [ Relationship( - subject=triple["subject"], - predicate=triple["predicate"], - object=triple["object"], - weight=triple["weight"], - description=triple["description"], - extraction_ids=triple["extraction_ids"], - document_id=triple["document_id"], + subject=relationship["subject"], + predicate=relationship["predicate"], + object=relationship["object"], + weight=relationship["weight"], + description=relationship["description"], + extraction_ids=relationship["extraction_ids"], + document_id=relationship["document_id"], ) - for triple in triples_list + for relationship in relationships_list ] entity_map: dict[str, dict[str, list[Any]]] = {} for entity in entities_list: if entity.name not in entity_map: - entity_map[entity.name] = {"entities": [], "triples": []} + entity_map[entity.name] = {"entities": [], "relationships": []} entity_map[entity.name]["entities"].append(entity) - for triple in triples_list: - if triple.subject in entity_map: - entity_map[triple.subject]["triples"].append(triple) - if triple.object in entity_map: - entity_map[triple.object]["triples"].append(triple) + for relationship in relationships_list: + if relationship.subject in entity_map: + entity_map[relationship.subject]["relationships"].append(relationship) + if relationship.object in entity_map: + entity_map[relationship.object]["relationships"].append(relationship) return entity_map @@ -547,7 +547,7 @@ async def vector_query( # type: ignore else "document_entity" ) elif search_type == "__Relationship__": - table_name = "chunk_triple" + table_name = 
"chunk_relationship" elif search_type == "__Community__": table_name = "community_report" else: @@ -600,7 +600,7 @@ async def vector_query( # type: ignore for property_name in property_names } - async def get_all_triples(self, collection_id: UUID) -> list[Relationship]: + async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: # getting all documents for a collection if document_ids is None: @@ -615,18 +615,18 @@ async def get_all_triples(self, collection_id: UUID) -> list[Relationship]: ] QUERY = f""" - SELECT id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_triple")} WHERE document_id = ANY($1) + SELECT id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) """ - triples = await self.connection_manager.fetch_query( + relationships = await self.connection_manager.fetch_query( QUERY, [document_ids] ) - return [Relationship(**triple) for triple in triples] + return [Relationship(**relationship) for relationship in relationships] async def add_community_info( self, communities: list[CommunityInfo] ) -> None: QUERY = f""" - INSERT INTO {self._get_table_name("community_info")} (node, cluster, parent_cluster, level, is_final_cluster, triple_ids, collection_id) + INSERT INTO {self._get_table_name("community_info")} (node, cluster, parent_cluster, level, is_final_cluster, relationship_ids, collection_id) VALUES ($1, $2, $3, $4, $5, $6, $7) """ communities_tuples_list = [ @@ -636,7 +636,7 @@ async def add_community_info( community.parent_cluster, community.level, community.is_final_cluster, - community.triple_ids, + community.relationship_ids, community.collection_id, ) for community in communities @@ -728,16 +728,16 @@ async def add_community_report( ) async def _create_graph_and_cluster( - self, triples: list[Triple], leiden_params: dict[str, Any] + self, relationships: list[Relationship], leiden_params: dict[str, Any] ) -> Any: G = self.nx.Graph() - for triple in triples: + for relationship in relationships: G.add_edge( - triple.subject, - triple.object, - weight=triple.weight, - id=triple.id, + relationship.subject, + relationship.object, + weight=relationship.weight, + id=relationship.id, ) hierarchical_communities = await self._compute_leiden_communities( @@ -748,8 +748,8 @@ async def _create_graph_and_cluster( async def _cluster_and_add_community_info( self, - triples: list[Triple], - triple_ids_cache: dict[str, list[int]], + relationships: list[Relationship], + relationship_ids_cache: dict[str, list[int]], leiden_params: dict[str, Any], collection_id: UUID, ) -> int: @@ -768,18 +768,18 @@ async def _cluster_and_add_community_info( start_time = time.time() hierarchical_communities = await self._create_graph_and_cluster( - triples, leiden_params + relationships, leiden_params ) logger.info( f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds." ) - def triple_ids(node: str) -> list[int]: - return triple_ids_cache.get(node, []) + def relationship_ids(node: str) -> list[int]: + return relationship_ids_cache.get(node, []) logger.info( - f"Cached {len(triple_ids_cache)} triple ids, time {time.time() - start_time:.2f} seconds." + f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds." ) # upsert the communities into the database. 
@@ -790,7 +790,7 @@ def triple_ids(node: str) -> list[int]: parent_cluster=item.parent_cluster, level=item.level, is_final_cluster=item.is_final_cluster, - triple_ids=triple_ids(item.node), + relationship_ids=relationship_ids(item.node), collection_id=collection_id, ) for item in hierarchical_communities @@ -809,7 +809,7 @@ def triple_ids(node: str) -> list[int]: return num_communities async def _use_community_cache( - self, collection_id: UUID, triple_ids_cache: dict[str, list[int]] + self, collection_id: UUID, relationship_ids_cache: dict[str, list[int]] ) -> bool: # check if status is enriched or stale @@ -835,50 +835,50 @@ async def _use_community_cache( )["count"] # a hard threshold of 80% of the entities in the cache. - if num_entities > 0.8 * len(triple_ids_cache): + if num_entities > 0.8 * len(relationship_ids_cache): return True else: return False - async def _get_triple_ids_cache( - self, triples: list[Triple] + async def _get_relationship_ids_cache( + self, relationships: list[Relationship] ) -> dict[str, list[int]]: - # caching the triple ids - triple_ids_cache = dict[str, list[int]]() - for triple in triples: + # caching the relationship ids + relationship_ids_cache = dict[str, list[int]]() + for relationship in relationships: if ( - triple.subject not in triple_ids_cache - and triple.subject is not None + relationship.subject not in relationship_ids_cache + and relationship.subject is not None ): - triple_ids_cache[triple.subject] = [] + relationship_ids_cache[relationship.subject] = [] if ( - triple.object not in triple_ids_cache - and triple.object is not None + relationship.object not in relationship_ids_cache + and relationship.object is not None ): - triple_ids_cache[triple.object] = [] - if triple.subject is not None and triple.id is not None: - triple_ids_cache[triple.subject].append(triple.id) - if triple.object is not None and triple.id is not None: - triple_ids_cache[triple.object].append(triple.id) + relationship_ids_cache[relationship.object] = [] + if relationship.subject is not None and relationship.id is not None: + relationship_ids_cache[relationship.subject].append(relationship.id) + if relationship.object is not None and relationship.id is not None: + relationship_ids_cache[relationship.object].append(relationship.id) - return triple_ids_cache + return relationship_ids_cache async def _incremental_clustering( self, - triple_ids_cache: dict[str, list[int]], + relationship_ids_cache: dict[str, list[int]], leiden_params: dict[str, Any], collection_id: UUID, ) -> int: """ - Performs incremental clustering on new triples by: - 1. Getting all triples and new triples - 2. Getting community mapping for all existing triples - 3. For each new triple: + Performs incremental clustering on new relationships by: + 1. Getting all relationships and new relationships + 2. Getting community mapping for all existing relationships + 3. For each new relationship: - Check if subject/object exists in community mapping - If exists, add its cluster to updated communities set - - If not, append triple to new_triple_ids list for clustering - 4. Run hierarchical clustering on new_triple_ids list + - If not, append relationship to new_relationship_ids list for clustering + 4. Run hierarchical clustering on new_relationship_ids list 5. 
Update community info table with new clusters, offsetting IDs by max_cluster_id """ @@ -908,23 +908,23 @@ async def _incremental_clustering( QUERY, [collection_id, KGExtractionStatus.SUCCESS] ) - new_triple_ids = await self.get_all_triples( + new_relationship_ids = await self.get_all_relationships( collection_id, new_document_ids ) - # community mapping for new triples + # community mapping for new relationships updated_communities = set() - new_triples = [] - for triple in new_triple_ids: + new_relationships = [] + for relationship in new_relationship_ids: # bias towards subject - if triple.subject in communities_dict: - for community in communities_dict[triple.subject]: + if relationship.subject in communities_dict: + for community in communities_dict[relationship.subject]: updated_communities.add(community["cluster"]) - elif triple.object in communities_dict: - for community in communities_dict[triple.object]: + elif relationship.object in communities_dict: + for community in communities_dict[relationship.object]: updated_communities.add(community["cluster"]) else: - new_triples.append(triple) + new_relationships.append(relationship) # delete the communities information for the updated communities QUERY = f""" @@ -935,7 +935,7 @@ async def _incremental_clustering( ) hierarchical_communities_output = await self._create_graph_and_cluster( - new_triples, leiden_params + new_relationships, leiden_params ) community_info = [] @@ -950,7 +950,7 @@ async def _incremental_clustering( else None ), level=community.level, - triple_ids=[], # FIXME: need to get the triple ids for the community + relationship_ids=[], # FIXME: need to get the relationship ids for the community is_final_cluster=community.is_final_cluster, collection_id=collection_id, ) @@ -966,7 +966,7 @@ async def perform_graph_clustering( leiden_params: dict[str, Any], ) -> int: """ - Leiden clustering algorithm to cluster the knowledge graph triples into communities. + Leiden clustering algorithm to cluster the knowledge graph relationships into communities. 
Available parameters and defaults: max_cluster_size: int = 1000, @@ -984,19 +984,19 @@ async def perform_graph_clustering( start_time = time.time() - triples = await self.get_all_triples(collection_id) + relationships = await self.get_all_relationships(collection_id) logger.info(f"Clustering with settings: {leiden_params}") - triple_ids_cache = await self._get_triple_ids_cache(triples) + relationship_ids_cache = await self._get_relationship_ids_cache(relationships) - if await self._use_community_cache(collection_id, triple_ids_cache): + if await self._use_community_cache(collection_id, relationship_ids_cache): num_communities = await self._incremental_clustering( - triple_ids_cache, leiden_params, collection_id + relationship_ids_cache, leiden_params, collection_id ) else: num_communities = await self._cluster_and_add_community_info( - triples, triple_ids_cache, leiden_params, collection_id + relationships, relationship_ids_cache, leiden_params, collection_id ) return num_communities @@ -1057,8 +1057,8 @@ async def get_community_details( ) QUERY = f""" - WITH node_triple_ids AS ( - SELECT node, triple_ids + WITH node_relationship_ids AS ( + SELECT node, relationship_ids FROM {self._get_table_name("community_info")} WHERE cluster = $1 AND collection_id = $2 ) @@ -1066,7 +1066,7 @@ async def get_community_details( e.id AS id, e.name AS name, e.description AS description - FROM node_triple_ids nti + FROM node_relationship_ids nti JOIN {self._get_table_name(table_name)} e ON e.name = nti.node; """ entities = await self.connection_manager.fetch_query( @@ -1075,22 +1075,22 @@ async def get_community_details( entities = [Entity(**entity) for entity in entities] QUERY = f""" - WITH node_triple_ids AS ( - SELECT node, triple_ids + WITH node_relationship_ids AS ( + SELECT node, relationship_ids FROM {self._get_table_name("community_info")} WHERE cluster = $1 and collection_id = $2 ) SELECT DISTINCT t.id, t.subject, t.predicate, t.object, t.weight, t.description - FROM node_triple_ids nti - JOIN {self._get_table_name("chunk_triple")} t ON t.id = ANY(nti.triple_ids); + FROM node_relationship_ids nti + JOIN {self._get_table_name("chunk_relationship")} t ON t.id = ANY(nti.relationship_ids); """ - triples = await self.connection_manager.fetch_query( + relationships = await self.connection_manager.fetch_query( QUERY, [community_number, collection_id] ) - triples = [Relationship(**triple) for triple in triples] + relationships = [Relationship(**relationship) for relationship in relationships] - return level, entities, triples + return level, entities, relationships # async def client(self): # return None @@ -1144,7 +1144,7 @@ async def create_relationship( # check if the relationship already exists QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 """ count = ( await self.connection_manager.fetch_query(QUERY, [relationship.subject, relationship.predicate, relationship.object, collection_id]) @@ -1153,7 +1153,7 @@ async def create_relationship( if count > 0: raise R2RException("Relationship already exists", 400) - await self._add_objects([relationship], "chunk_triple") + await self._add_objects([relationship], "chunk_relationship") async def update_relationship( self, relationship_id: UUID, relationship: Relationship @@ -1161,7 +1161,7 @@ async def 
update_relationship( # check if relationship_id exists QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} WHERE id = $1 + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 """ count = ( await self.connection_manager.fetch_query(QUERY, [relationship.id]) @@ -1170,13 +1170,13 @@ async def update_relationship( if count == 0: raise R2RException("Relationship does not exist", 404) - await self._add_objects([relationship], "chunk_triple") + await self._add_objects([relationship], "chunk_relationship") async def delete_relationship( self, relationship_id: UUID ) -> None: QUERY = f""" - DELETE FROM {self._get_table_name("chunk_triple")} WHERE id = $1 + DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 """ await self.connection_manager.execute_query(QUERY, [relationship_id]) @@ -1219,7 +1219,7 @@ async def delete_graph_for_collection( if status == KGExtractionStatus.PROCESSING.value: return - # remove all triples for these documents. + # remove all relationships for these documents. DELETE_QUERIES = [ f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1;", f"DELETE FROM {self._get_table_name('community_report')} WHERE collection_id = $1;", @@ -1241,7 +1241,7 @@ async def delete_graph_for_collection( if cascade: DELETE_QUERIES += [ f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1::uuid[]);", - f"DELETE FROM {self._get_table_name('chunk_triple')} WHERE document_id = ANY($1::uuid[]);", + f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1::uuid[]);", f"DELETE FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1;", ] @@ -1288,7 +1288,7 @@ async def delete_node_via_document_id( # Execute separate DELETE queries delete_queries = [ f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = $1", - f"DELETE FROM {self._get_table_name('chunk_triple')} WHERE document_id = $1", + f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = $1", f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = $1", ] @@ -1375,10 +1375,10 @@ async def get_creation_estimate( total_chunks * 10, total_chunks * 20, ) # 25 entities per 4 chunks - estimated_triples = ( + estimated_relationships = ( int(estimated_entities[0] * 1.25), int(estimated_entities[1] * 1.5), - ) # Assuming 1.25 triples per entity on average + ) # Assuming 1.25 relationships per entity on average estimated_llm_calls = ( total_chunks * 2 + estimated_entities[0], @@ -1414,8 +1414,8 @@ async def get_creation_estimate( estimated_entities=self._get_str_estimation_output( estimated_entities ), - estimated_triples=self._get_str_estimation_output( - estimated_triples + estimated_relationships=self._get_str_estimation_output( + estimated_relationships ), estimated_llm_calls=self._get_str_estimation_output( estimated_llm_calls @@ -1454,9 +1454,9 @@ async def get_enrichment_estimate( ) QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} WHERE document_id = ANY($1); + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1); """ - triple_count = ( + relationship_count = ( await self.connection_manager.fetch_query(QUERY, [document_ids]) )[0]["count"] @@ -1483,7 +1483,7 @@ async def get_enrichment_estimate( return KGEnrichmentEstimationResponse( message='Ran Graph Enrichment 
Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', total_entities=entity_count, - total_triples=triple_count, + total_relationships=relationship_count, estimated_llm_calls=self._get_str_estimation_output( estimated_llm_calls ), @@ -1502,7 +1502,7 @@ async def create_vector_index(self): # this needs to be run periodically for every collection. raise NotImplementedError - async def delete_triples(self, triple_ids: list[int]): + async def delete_relationships(self, relationship_ids: list[int]): # need to implement this. raise NotImplementedError @@ -1625,21 +1625,21 @@ async def get_entities( return {"entities": entities, "total_entries": total_entries} - async def get_triples( + async def get_relationships( self, offset: int, limit: int, collection_id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, - triple_ids: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, ) -> dict: conditions = [] params: list = [str(collection_id)] param_index = 2 - if triple_ids: + if relationship_ids: conditions.append(f"id = ANY(${param_index})") - params.append(triple_ids) + params.append(relationship_ids) param_index += 1 if entity_names: @@ -1664,7 +1664,7 @@ async def get_triples( query = f""" SELECT id, subject, predicate, object, description - FROM {self._get_table_name("chunk_triple")} + FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY( SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) @@ -1674,13 +1674,13 @@ async def get_triples( {pagination_clause} """ - triples = await self.connection_manager.fetch_query(query, params) - triples = [Relationship(**triple) for triple in triples] - total_entries = await self.get_triple_count( + relationships = await self.connection_manager.fetch_query(query, params) + relationships = [Relationship(**relationship) for relationship in relationships] + total_entries = await self.get_relationship_count( collection_id=collection_id ) - return {"triples": triples, "total_entries": total_entries} + return {"relationships": relationships, "total_entries": total_entries} async def structured_query(self): raise NotImplementedError @@ -1691,7 +1691,7 @@ async def update_extraction_prompt(self): async def update_kg_search_prompt(self): raise NotImplementedError - async def upsert_triples(self): + async def upsert_relationships(self): raise NotImplementedError async def get_entity_count( @@ -1740,7 +1740,7 @@ async def get_entity_count( "count" ] - async def get_triple_count( + async def get_relationship_count( self, collection_id: Optional[UUID] = None, document_id: Optional[UUID] = None, @@ -1768,7 +1768,7 @@ async def get_triple_count( params.append(str(document_id)) QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_triple")} + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE {" AND ".join(conditions)} """ return (await self.connection_manager.fetch_query(QUERY, params))[0][ diff --git a/py/core/providers/database/prompts/graphrag_community_reports.yaml b/py/core/providers/database/prompts/graphrag_community_reports.yaml index 7b7828f69..8b78b94d7 100644 --- a/py/core/providers/database/prompts/graphrag_community_reports.yaml +++ b/py/core/providers/database/prompts/graphrag_community_reports.yaml @@ -55,14 +55,14 @@ graphrag_community_reports: Entity: OpenAI descriptions: 
101,OpenAI is an AI research and deployment company.
-        triples:
+        relationships:
            201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
            203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
            204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
        Entity: Stripe
        descriptions:
            102,Stripe is a technology company that builds economic infrastructure for the internet.
-        triples:
+        relationships:
            201,OpenAI,Stripe,OpenAI partnered with Stripe to integrate payment solutions.
            202,Stripe,Airbnb,Stripe provides payment processing services to Airbnb.
            204,Stripe,OpenAI,Stripe invested in OpenAI's latest funding round.
@@ -70,7 +70,7 @@ graphrag_community_reports:
        Entity: Airbnb
        descriptions:
            103,Airbnb is an online marketplace for lodging and tourism experiences.
-        triples:
+        relationships:
            203,Airbnb,OpenAI,Airbnb utilizes OpenAI's AI tools for customer service.
            205,Airbnb,Stripe,Airbnb and Stripe collaborate on expanding global payment options.

diff --git a/py/core/providers/database/prompts/graphrag_entity_description.yaml b/py/core/providers/database/prompts/graphrag_entity_description.yaml
index ea0066a41..bfed919a0 100644
--- a/py/core/providers/database/prompts/graphrag_entity_description.yaml
+++ b/py/core/providers/database/prompts/graphrag_entity_description.yaml
@@ -1,15 +1,15 @@
 graphrag_entity_description:
   template: |
-    Provide a comprehensive yet concise summary of the given entity, incorporating its description and associated triples:
+    Provide a comprehensive yet concise summary of the given entity, incorporating its description and associated relationships:

     Entity Info:
     {entity_info}

-    Triples:
-    {triples_txt}
+    Relationships:
+    {relationships_txt}

     Your summary should:
     1. Clearly define the entity's core concept or purpose
-    2. Highlight key relationships or attributes from the triples
+    2. Highlight key relationships or attributes from the relationships
     3. Integrate any relevant information from the existing description
     4. Maintain a neutral, factual tone
     5. Be approximately 2-3 sentences long
@@ -17,4 +17,4 @@ graphrag_entity_description:
     Ensure the summary is coherent, informative, and captures the essence of the entity within the context of the provided information.
   input_types:
     entity_info: str
-    triples_txt: str
+    relationships_txt: str
diff --git a/py/core/providers/database/prompts/graphrag_triples_extraction_few_shot.yaml b/py/core/providers/database/prompts/graphrag_relationships_extraction_few_shot.yaml
similarity index 98%
rename from py/core/providers/database/prompts/graphrag_triples_extraction_few_shot.yaml
rename to py/core/providers/database/prompts/graphrag_relationships_extraction_few_shot.yaml
index 6bfb1bb26..867c0ac70 100644
--- a/py/core/providers/database/prompts/graphrag_triples_extraction_few_shot.yaml
+++ b/py/core/providers/database/prompts/graphrag_relationships_extraction_few_shot.yaml
@@ -1,8 +1,8 @@
-graphrag_triples_extraction_few_shot:
+graphrag_relationships_extraction_few_shot:
   template: >
     -Goal-
     Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
-    Given the text, extract up to {max_knowledge_triples} entity-relation triplets.
+    Given the text, extract up to {max_knowledge_relationships} entity-relation relationships.

    -Steps-
    1. Identify all entities.
For each identified entity, extract the following information:
    - entity_name: Name of the entity, capitalized
@@ -117,7 +117,7 @@ graphrag_triples_extraction_few_shot:
     Output:
   input_types:
-    max_knowledge_triples: int
+    max_knowledge_relationships: int
     input: str
     entity_types: list[str]
     relation_types: list[str]
diff --git a/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py b/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py
index 9b786ccb4..6f839fea9 100644
--- a/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py
+++ b/py/migrations/versions/d342e632358a_migrate_to_asyncpg.py
@@ -43,7 +43,7 @@ def upgrade() -> None:
         f"ALTER TABLE IF EXISTS {project_name}.entity_raw RENAME TO chunk_entity"
     )
     op.execute(
-        f"ALTER TABLE IF EXISTS {project_name}.triple_raw RENAME TO chunk_triple"
+        f"ALTER TABLE IF EXISTS {project_name}.relationship_raw RENAME TO chunk_relationship"
     )
     op.execute(
         f"ALTER TABLE IF EXISTS {project_name}.entity_embedding RENAME TO document_entity"
     )
@@ -158,7 +158,7 @@ def downgrade() -> None:
         f"ALTER TABLE IF EXISTS {project_name}.chunk_entity RENAME TO entity_raw"
     )
     op.execute(
-        f"ALTER TABLE IF EXISTS {project_name}.chunk_triple RENAME TO triple_raw"
+        f"ALTER TABLE IF EXISTS {project_name}.chunk_relationship RENAME TO relationship_raw"
     )
     op.execute(
         f"ALTER TABLE IF EXISTS {project_name}.document_entity RENAME TO entity_embedding"
     )
diff --git a/py/r2r.toml b/py/r2r.toml
index e91479173..a64e5945c 100644
--- a/py/r2r.toml
+++ b/py/r2r.toml
@@ -44,13 +44,13 @@ batch_size = 256

   [database.kg_creation_settings]
     kg_entity_description_prompt = "graphrag_entity_description"
-    kg_triples_extraction_prompt = "graphrag_triples_extraction_few_shot"
+    kg_relationships_extraction_prompt = "graphrag_relationships_extraction_few_shot"
     entity_types = [] # if empty, all entities are extracted
     relation_types = [] # if empty, all relations are extracted
     fragment_merge_count = 4 # number of fragments to merge into a single extraction
-    max_knowledge_triples = 100
+    max_knowledge_relationships = 100
     max_description_input_length = 65536
-    generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for triplet extraction
+    generation_config = { model = "azure/gpt-4o-mini" } # and other params, model used for relationship extraction

   [database.kg_entity_deduplication_settings]
     kg_entity_deduplication_type = "by_name"
diff --git a/py/sdk/v2/mixins/kg.py b/py/sdk/v2/mixins/kg.py
index 4310dac62..4cb1cdfc2 100644
--- a/py/sdk/v2/mixins/kg.py
+++ b/py/sdk/v2/mixins/kg.py
@@ -103,39 +103,39 @@ async def get_entities(

         return await self._make_request("GET", "entities", params=params)  # type: ignore

-    async def get_triples(
+    async def get_relationships(
         self,
         collection_id: Optional[Union[UUID, str]] = None,
         entity_names: Optional[list[str]] = None,
-        triple_ids: Optional[list[str]] = None,
+        relationship_ids: Optional[list[str]] = None,
         offset: Optional[int] = None,
         limit: Optional[int] = None,
     ) -> dict:
         """
-        Retrieve triples from the knowledge graph.
+        Retrieve relationships from the knowledge graph.

         Args:
-            collection_id (str): The ID of the collection to retrieve triples from.
+            collection_id (str): The ID of the collection to retrieve relationships from.
             offset (int): The offset for pagination.
             limit (int): The limit for pagination.
             entity_names (Optional[List[str]]): Optional list of entity names to filter by.
-            triple_ids (Optional[List[str]]): Optional list of triple IDs to filter by.
+ relationship_ids (Optional[List[str]]): Optional list of relationship IDs to filter by. Returns: - dict: A dictionary containing the retrieved triples and total count. + dict: A dictionary containing the retrieved relationships and total count. """ params = { "collection_id": collection_id, "entity_names": entity_names, - "triple_ids": triple_ids, + "relationship_ids": relationship_ids, "offset": offset, "limit": limit, } params = {k: v for k, v in params.items() if v is not None} - return await self._make_request("GET", "triples", params=params) # type: ignore + return await self._make_request("GET", "relationships", params=params) # type: ignore async def get_communities( self, @@ -245,9 +245,9 @@ async def delete_graph_for_collection( Args: collection_id (Union[UUID, str]): The ID of the collection to delete the graph for. - cascade (bool): Whether to cascade the deletion, and delete entities and triples belonging to the collection. + cascade (bool): Whether to cascade the deletion, and delete entities and relationships belonging to the collection. - NOTE: Setting this flag to true will delete entities and triples for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and triples for all documents in the collection. + NOTE: Setting this flag to true will delete entities and relationships for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and relationships for all documents in the collection. """ data = { diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index 6aab179b6..a0fb7cb7e 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -181,8 +181,6 @@ async def list_entities( self, collection_id: Union[str, UUID], level=EntityLevel.DOCUMENT, - offset: int = 0, - limit: int = 100, include_embeddings: bool = False, offset: Optional[int] = 0, limit: Optional[int] = 100, @@ -372,7 +370,7 @@ async def create_communities( self, collection_id: Union[str, UUID], run_type: Optional[Union[str, KGRunType]] = None, - settings: Optional[Dict[str, Any]] = None, + settings: Optional[dict[str, Any]] = None, run_with_orchestration: bool = True, ): # -> WrappedKGCommunitiesResponse: """ @@ -535,7 +533,7 @@ async def tune_prompt( Args: collection_id (Union[str, UUID]): Collection ID to tune prompt for - prompt_name (str): Name of prompt to tune (graphrag_triples_extraction_few_shot, + prompt_name (str): Name of prompt to tune (graphrag_relationships_extraction_few_shot, graphrag_entity_description, or graphrag_community_reports) documents_offset (int): Document pagination offset documents_limit (int): Maximum number of documents to use diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 088f74115..b6b4417e3 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -238,7 +238,7 @@ class CommunityInfo(BaseModel): level: int is_final_cluster: bool collection_id: uuid.UUID - triple_ids: Optional[list[int]] = None + relationship_ids: Optional[list[int]] = None def __init__(self, **kwargs): super().__init__(**kwargs) @@ -251,7 +251,7 @@ def from_dict(cls, d: dict[str, Any]) -> "CommunityInfo": parent_cluster=d["parent_cluster"], level=d["level"], is_final_cluster=d["is_final_cluster"], - triple_ids=d["triple_ids"], + relationship_ids=d["relationship_ids"], collection_id=d["collection_id"], ) @@ -331,4 +331,4 @@ class KGExtraction(R2RSerializable): 
extraction_ids: list[uuid.UUID] document_id: uuid.UUID entities: list[Entity] - triples: list[Relationship] + relationships: list[Relationship] diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index e6b8d7f9f..699a5e1d6 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -30,10 +30,10 @@ def __str__(self): class KGCreationSettings(R2RSerializable): """Settings for knowledge graph creation.""" - graphrag_triples_extraction_few_shot: str = Field( - default="graphrag_triples_extraction_few_shot", + graphrag_relationships_extraction_few_shot: str = Field( + default="graphrag_relationships_extraction_few_shot", description="The prompt to use for knowledge graph extraction.", - alias="graphrag_triples_extraction_few_shot_prompt", # TODO - mark deprecated & remove + alias="graphrag_relationships_extraction_few_shot_prompt", # TODO - mark deprecated & remove ) graphrag_entity_description: str = Field( @@ -62,9 +62,9 @@ class KGCreationSettings(R2RSerializable): description="The number of extractions to merge into a single KG extraction.", ) - max_knowledge_triples: int = Field( + max_knowledge_relationships: int = Field( default=100, - description="The maximum number of knowledge triples to extract from each chunk.", + description="The maximum number of knowledge relationships to extract from each chunk.", ) max_description_input_length: int = Field( diff --git a/py/shared/abstractions/vector.py b/py/shared/abstractions/vector.py index fc7beb8f6..85742da8b 100644 --- a/py/shared/abstractions/vector.py +++ b/py/shared/abstractions/vector.py @@ -117,8 +117,8 @@ class VectorTableName(str, Enum): VECTORS = "vectors" ENTITIES_DOCUMENT = "document_entity" ENTITIES_COLLECTION = "collection_entity" - # TODO: Add support for triples - # TRIPLES = "chunk_triple" + # TODO: Add support for relationships + # TRIPLES = "chunk_relationship" COMMUNITIES = "community_report" def __str__(self) -> str: diff --git a/py/shared/api/models/__init__.py b/py/shared/api/models/__init__.py index e75900a91..b49d75be3 100644 --- a/py/shared/api/models/__init__.py +++ b/py/shared/api/models/__init__.py @@ -23,6 +23,7 @@ WrappedKGCreationResponse, WrappedKGEnrichmentResponse, WrappedKGEntityDeduplicationResponse, + WrappedKGRelationshipsResponse, ) from shared.api.models.management.responses import ( AnalyticsResponse, diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index 36813c727..39364cb75 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -5,7 +5,7 @@ from shared.abstractions.base import R2RSerializable from shared.abstractions.graph import CommunityReport, Entity, Relationship -from shared.api.models.base import ResultsWrapper +from shared.api.models.base import ResultsWrapper, PaginatedResultsWrapper class KGCreationResponse(BaseModel): @@ -75,9 +75,9 @@ class KGCreationEstimationResponse(R2RSerializable): description="The estimated number of entities in the graph.", ) - estimated_triples: Optional[str] = Field( + estimated_relationships: Optional[str] = Field( default=None, - description="The estimated number of triples in the graph.", + description="The estimated number of relationships in the graph.", ) estimated_llm_calls: Optional[str] = Field( @@ -148,9 +148,9 @@ class KGEnrichmentEstimationResponse(R2RSerializable): description="The total number of entities in the graph.", ) - total_triples: Optional[int] = Field( + total_relationships: Optional[int] = Field( default=None, - 
description="The total number of triples in the graph.", + description="The total number of relationships in the graph.", ) estimated_llm_calls: Optional[str] = Field( @@ -207,23 +207,23 @@ class Config: } -class KGTriplesResponse(R2RSerializable): - """Response for knowledge graph triples.""" +class KGRelationshipsResponse(R2RSerializable): + """Response for knowledge graph relationships.""" - triples: list[Relationship] = Field( + relationships: list[Relationship] = Field( ..., - description="The list of triples in the graph.", + description="The list of relationships in the graph.", ) total_entries: int = Field( ..., - description="The total number of triples in the graph for the collection or document.", + description="The total number of relationships in the graph for the collection or document.", ) class Config: json_schema_extra = { "example": { - "triples": [ + "relationships": [ { "subject": "Paris", "predicate": "is capital of", @@ -306,8 +306,12 @@ class Config: WrappedKGEnrichmentResponse = ResultsWrapper[ Union[KGEnrichmentResponse, KGEnrichmentEstimationResponse] ] -WrappedKGEntitiesResponse = ResultsWrapper[KGEntitiesResponse] -WrappedKGTriplesResponse = ResultsWrapper[KGTriplesResponse] + +# KG Entities +WrappedKGEntityResponse = ResultsWrapper[KGEntitiesResponse] +WrappedKGEntitiesResponse = PaginatedResultsWrapper[KGEntitiesResponse] +WrappedKGRelationshipsResponse = PaginatedResultsWrapper[KGRelationshipsResponse] + WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse] WrappedKGCommunitiesResponse = ResultsWrapper[KGCommunitiesResponse] WrappedKGEntityDeduplicationResponse = ResultsWrapper[ diff --git a/py/tests/core/pipes/test_kg_community_summary_pipe.py b/py/tests/core/pipes/test_kg_community_summary_pipe.py index 02ec7d280..33cb7667a 100644 --- a/py/tests/core/pipes/test_kg_community_summary_pipe.py +++ b/py/tests/core/pipes/test_kg_community_summary_pipe.py @@ -123,7 +123,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") -def triples_raw_list(embedding_vectors, extraction_ids, document_id): +def relationships_raw_list(embedding_vectors, extraction_ids, document_id): return [ Relationship( id=1, @@ -156,24 +156,24 @@ def triples_raw_list(embedding_vectors, extraction_ids, document_id): async def test_community_summary_prompt( kg_community_summary_pipe, entities_list, - triples_raw_list, + relationships_raw_list, max_summary_input_length, ): summary = await kg_community_summary_pipe.community_summary_prompt( - entities_list, triples_raw_list, max_summary_input_length + entities_list, relationships_raw_list, max_summary_input_length ) expected_summary = """ Entity: Entity1 Descriptions: 1,Description1 - Triples: + Relationships: 1,Entity1,object1,predicate1,description1 Entity: Entity2 Descriptions: 2,Description2 - Triples: + Relationships: 2,Entity2,object2,predicate2,description2 """ - # "\n Entity: Entity1\n Descriptions: \n 1,Description1\n Triples: \n 1,Entity1,object1,predicate1,description1\n \n Entity: Entity2\n Descriptions: \n 2,Description2\n Triples: \n 2,Entity2,object2,predicate2,description2\n " + # "\n Entity: Entity1\n Descriptions: \n 1,Description1\n Relationships: \n 1,Entity1,object1,predicate1,description1\n \n Entity: Entity2\n Descriptions: \n 2,Description2\n Relationships: \n 2,Entity2,object2,predicate2,description2\n " assert summary.strip() == expected_summary.strip() diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py 
index 0357a6a5c..3c74ef041 100644 --- a/py/tests/core/providers/kg/test_kg_logic.py +++ b/py/tests/core/providers/kg/test_kg_logic.py @@ -93,7 +93,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") -def triples_raw_list(embedding_vectors, extraction_ids, document_id): +def relationships_raw_list(embedding_vectors, extraction_ids, document_id): return [ Relationship( subject="Entity1", @@ -121,19 +121,19 @@ def triples_raw_list(embedding_vectors, extraction_ids, document_id): @pytest.fixture(scope="function") -def communities_list(entities_list, triples_raw_list): +def communities_list(entities_list, relationships_raw_list): return [ Community( name="Community1", description="Description1", entities=[entities_list[0]], - triples=[triples_raw_list[0]], + relationships=[relationships_raw_list[0]], ), Community( name="Community2", description="Description2", entities=[entities_list[1]], - triples=[triples_raw_list[1]], + relationships=[relationships_raw_list[1]], ), ] @@ -148,13 +148,13 @@ def community_table_info(collection_id): @pytest.fixture(scope="function") def kg_extractions( - extraction_ids, entities_raw_list, triples_raw_list, document_id + extraction_ids, entities_raw_list, relationships_raw_list, document_id ): return [ KGExtraction( extraction_ids=extraction_ids, entities=entities_raw_list, - triples=triples_raw_list, + relationships=relationships_raw_list, document_id=document_id, ) ] @@ -199,8 +199,8 @@ async def test_create_tables( "entities": [], "total_entries": 0, } - assert await postgres_db_provider.get_triples(collection_id) == { - "triples": [], + assert await postgres_db_provider.get_relationships(collection_id) == { + "relationships": [], "total_entries": 0, } assert await postgres_db_provider.get_communities(collection_id) == { @@ -242,17 +242,17 @@ async def test_add_entities( @pytest.mark.asyncio -async def test_add_triples( - postgres_db_provider, triples_raw_list, collection_id +async def test_add_relationships( + postgres_db_provider, relationships_raw_list, collection_id ): - await postgres_db_provider.add_triples( - triples_raw_list, table_name="chunk_triple" + await postgres_db_provider.add_relationships( + relationships_raw_list, table_name="chunk_relationship" ) - triples = await postgres_db_provider.get_triples(collection_id) - assert triples["triples"][0].subject == "Entity1" - assert triples["triples"][1].subject == "Entity2" - assert len(triples["triples"]) == 2 - assert triples["total_entries"] == 2 + relationships = await postgres_db_provider.get_relationships(collection_id) + assert relationships["relationships"][0].subject == "Entity1" + assert relationships["relationships"][1].subject == "Entity2" + assert len(relationships["relationships"]) == 2 + assert relationships["total_entries"] == 2 @pytest.mark.asyncio @@ -273,16 +273,16 @@ async def test_add_kg_extractions( assert len(entities["entities"]) == 2 assert entities["total_entries"] == 2 - triples = await postgres_db_provider.get_triples(collection_id) - assert triples["triples"][0].subject == "Entity1" - assert triples["triples"][1].subject == "Entity2" - assert len(triples["triples"]) == 2 - assert triples["total_entries"] == 2 + relationships = await postgres_db_provider.get_relationships(collection_id) + assert relationships["relationships"][0].subject == "Entity1" + assert relationships["relationships"][1].subject == "Entity2" + assert len(relationships["relationships"]) == 2 + assert relationships["total_entries"] == 2 
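
To make the entity-map assertions below concrete: get_entity_map groups each entity name with its entity rows and with every relationship that names it as subject or object. A rough in-memory sketch of that bucketing (illustrative tuples, not the provider's actual records):

# Illustrative relationships as (subject, predicate, object) tuples
relationships = [
    ("Entity1", "predicate1", "object1"),
    ("Entity2", "predicate2", "object2"),
]

entity_map: dict[str, dict[str, list]] = {
    name: {"entities": [], "relationships": []}
    for name in ("Entity1", "Entity2")
}

# Bucket each relationship under its subject and, when present, its object
for rel in relationships:
    subject, _, obj = rel
    if subject in entity_map:
        entity_map[subject]["relationships"].append(rel)
    if obj in entity_map:
        entity_map[obj]["relationships"].append(rel)

assert entity_map["Entity1"]["relationships"][0][0] == "Entity1"
assert entity_map["Entity2"]["relationships"][0][0] == "Entity2"
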
@pytest.mark.asyncio
 async def test_get_entity_map(
-    postgres_db_provider, entities_raw_list, triples_raw_list, document_id
+    postgres_db_provider, entities_raw_list, relationships_raw_list, document_id
 ):
     await postgres_db_provider.add_entities(
         entities_raw_list, table_name="chunk_entity"
@@ -291,13 +291,13 @@ async def test_get_entity_map(
     assert entity_map["Entity1"]["entities"][0].name == "Entity1"
     assert entity_map["Entity2"]["entities"][0].name == "Entity2"

-    await postgres_db_provider.add_triples(triples_raw_list)
+    await postgres_db_provider.add_relationships(relationships_raw_list)
     entity_map = await postgres_db_provider.get_entity_map(0, 2, document_id)
     assert entity_map["Entity1"]["entities"][0].name == "Entity1"
     assert entity_map["Entity2"]["entities"][0].name == "Entity2"

-    assert entity_map["Entity1"]["triples"][0].subject == "Entity1"
-    assert entity_map["Entity2"]["triples"][0].subject == "Entity2"
+    assert entity_map["Entity1"]["relationships"][0].subject == "Entity1"
+    assert entity_map["Entity2"]["relationships"][0].subject == "Entity2"


 @pytest.mark.asyncio
@@ -329,14 +329,14 @@ async def test_upsert_embeddings(


 @pytest.mark.asyncio
-async def test_get_all_triples(
-    postgres_db_provider, collection_id, triples_raw_list
+async def test_get_all_relationships(
+    postgres_db_provider, collection_id, relationships_raw_list
 ):
-    await postgres_db_provider.add_triples(triples_raw_list)
-    triples = await postgres_db_provider.get_triples(collection_id)
-    assert triples["triples"][0].subject == "Entity1"
-    assert triples["triples"][1].subject == "Entity2"
-    assert len(triples["triples"]) == 2
+    await postgres_db_provider.add_relationships(relationships_raw_list)
+    relationships = await postgres_db_provider.get_relationships(collection_id)
+    assert relationships["relationships"][0].subject == "Entity1"
+    assert relationships["relationships"][1].subject == "Entity2"
+    assert len(relationships["relationships"]) == 2


 @pytest.mark.asyncio
@@ -366,15 +366,15 @@ async def test_perform_graph_clustering(
     collection_id,
     leiden_params_1,
     entities_list,
-    triples_raw_list,
+    relationships_raw_list,
 ):

-    # addd entities and triples
+    # add entities and relationships
     await postgres_db_provider.add_entities(
         entities_list, table_name="document_entity"
     )
-    await postgres_db_provider.add_triples(
-        triples_raw_list, table_name="chunk_triple"
+    await postgres_db_provider.add_relationships(
+        relationships_raw_list, table_name="chunk_relationship"
     )

     num_communities = await postgres_db_provider.perform_graph_clustering(
@@ -387,7 +387,7 @@
 async def test_get_community_details(
     postgres_db_provider,
     entities_list,
-    triples_raw_list,
+    relationships_raw_list,
     collection_id,
     community_report_list,
     community_table_info,
@@ -396,13 +396,13 @@ async def test_get_community_details(
     await postgres_db_provider.add_entities(
         entities_list, table_name="document_entity"
     )
-    await postgres_db_provider.add_triples(
-        triples_raw_list, table_name="chunk_triple"
+    await postgres_db_provider.add_relationships(
+        relationships_raw_list, table_name="chunk_relationship"
     )
     await postgres_db_provider.add_community_info(community_table_info)
     await postgres_db_provider.add_community_report(community_report_list[0])

-    community_level, entities, triples = (
+    community_level, entities, relationships = (
         await postgres_db_provider.get_community_details(
             community_number=1, collection_id=collection_id
         )
     )

     assert community_level
== 0

     # TODO: change these to objects
     assert entities[0].name == "Entity1"
-    assert triples[0].subject == "Entity1"
+    assert relationships[0].subject == "Entity1"
diff --git a/py/tests/integration/runner_cli.py b/py/tests/integration/runner_cli.py
index 576d82f5a..7a966ad73 100644
--- a/py/tests/integration/runner_cli.py
+++ b/py/tests/integration/runner_cli.py
@@ -466,11 +466,11 @@ def test_kg_delete_graph_with_cascading_sample_file_cli():
     assert response.json()["results"]["entities"] == []

     response = requests.get(
-        "http://localhost:7272/v2/triples",
+        "http://localhost:7272/v2/relationships",
         params={"collection_id": "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"},
     )

-    assert response.json()["results"]["triples"] == []
+    assert response.json()["results"]["relationships"] == []

     print("KG delete graph with cascading test passed")
     print("~" * 100)
diff --git a/templates/ycombinator_graphrag/python-backend/prompts.yaml b/templates/ycombinator_graphrag/python-backend/prompts.yaml
index 1506c1f7d..d65a14dbf 100644
--- a/templates/ycombinator_graphrag/python-backend/prompts.yaml
+++ b/templates/ycombinator_graphrag/python-backend/prompts.yaml
@@ -1,8 +1,8 @@
-graphrag_triples_extraction_few_shot:
+graphrag_relationships_extraction_few_shot:
   template: >
     -Goal-
     Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
-    Given the text, extract up to {max_knowledge_triples} entity-relation triplets.
+    Given the text, extract up to {max_knowledge_relationships} entity-relation relationships.

     -Steps-
     1. Identify all entities. For each identified entity, extract the following information:
     - entity_name: Name of the entity, capitalized
@@ -117,7 +117,7 @@ graphrag_triples_extraction_few_shot:
     Output:
   input_types:
-    max_knowledge_triples: int
+    max_knowledge_relationships: int
     input: str
     entity_types: list[str]
     relation_types: list[str]
diff --git a/templates/ycombinator_graphrag/web-app/types.ts b/templates/ycombinator_graphrag/web-app/types.ts
index cccb9ecbe..fae667360 100644
--- a/templates/ycombinator_graphrag/web-app/types.ts
+++ b/templates/ycombinator_graphrag/web-app/types.ts
@@ -335,7 +335,7 @@ export interface KGEntity {
   description: string;
 }

-export interface KGTriple {
+export interface KGRelationship {
   subject: string;
   predicate: string;
   object: string;

From db2cd7741b8336954a46d52547d095e7b2dac588 Mon Sep 17 00:00:00 2001
From: Shreyas Pimpalgaonkar
Date: Wed, 13 Nov 2024 10:34:19 -0800
Subject: [PATCH 07/21] response models

---
 py/core/base/api/models/__init__.py          |  38 +-
 py/core/base/providers/database.py           |  27 +-
 py/core/main/api/v2/kg_router.py             |   5 +-
 py/core/main/api/v3/graph_router.py          | 407 ++++++++++++++-----
 py/core/main/assembly/factory.py             |   4 +-
 py/core/main/services/kg_service.py          |   8 +-
 py/core/pipes/kg/community_summary.py        |  12 +-
 py/core/pipes/kg/relationships_extraction.py |   8 +-
 py/core/providers/database/kg.py             | 130 ++++--
 py/shared/api/models/kg/responses.py         |   4 +-
 py/shared/api/models/kg/responses_v3.py      | 258 ++++++++++++
 py/tests/core/providers/kg/test_kg_logic.py  |   5 +-
 12 files changed, 726 insertions(+), 180 deletions(-)
 create mode 100644 py/shared/api/models/kg/responses_v3.py

diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py
index d71a4eef0..cbffaa69c 100644
--- a/py/core/base/api/models/__init__.py
+++ b/py/core/base/api/models/__init__.py
@@ -19,10 +19,7 @@
     WrappedUpdateResponse,
 )
 from shared.api.models.kg.responses import (
-
KGCreationEstimationResponse, KGCreationResponse, - KGDeduplicationEstimationResponse, - KGEnrichmentEstimationResponse, KGEnrichmentResponse, KGEntityDeduplicationResponse, WrappedKGCommunitiesResponse, @@ -33,6 +30,23 @@ WrappedKGRelationshipsResponse, WrappedKGTunePromptResponse, ) + + +from shared.api.models.kg.responses_v3 import ( + WrappedKGEntitiesResponse as WrappedKGEntitiesResponseV3, + WrappedKGRelationshipsResponse as WrappedKGRelationshipsResponseV3, + WrappedKGCommunitiesResponse as WrappedKGCommunitiesResponseV3, + WrappedKGCreationResponse as WrappedKGCreationResponseV3, + WrappedKGEnrichmentResponse as WrappedKGEnrichmentResponseV3, + WrappedKGTunePromptResponse as WrappedKGTunePromptResponseV3, + WrappedKGEntityDeduplicationResponse as WrappedKGEntityDeduplicationResponseV3, + KGCreationResponse as KGCreationResponseV3, + KGEnrichmentResponse as KGEnrichmentResponseV3, + KGEntityDeduplicationResponse as KGEntityDeduplicationResponseV3, + KGTunePromptResponse as KGTunePromptResponseV3, +) + + from shared.api.models.management.responses import ( AnalyticsResponse, AppSettingsResponse, @@ -97,7 +111,8 @@ "WrappedMetadataUpdateResponse", "WrappedListVectorIndicesResponse", "UpdateResponse", - # Knowledge Graph Responses + # Knowledge Graph Responses for V2 + # will be removed eventually "KGCreationResponse", "WrappedKGCreationResponse", "KGEnrichmentResponse", @@ -105,9 +120,18 @@ "KGEntityDeduplicationResponse", "WrappedKGEntityDeduplicationResponse", "WrappedKGTunePromptResponse", - "KGCreationEstimationResponse", - "KGDeduplicationEstimationResponse", - "KGEnrichmentEstimationResponse", + # Knowledge Graph Responses for V3 + "WrappedKGEntitiesResponseV3", + "WrappedKGRelationshipsResponseV3", + "WrappedKGCommunitiesResponseV3", + "WrappedKGCreationResponseV3", + "WrappedKGEnrichmentResponseV3", + "WrappedKGTunePromptResponseV3", + "WrappedKGEntityDeduplicationResponseV3", + "KGCreationResponseV3", + "KGEnrichmentResponseV3", + "KGEntityDeduplicationResponseV3", + "KGTunePromptResponseV3", # Management Responses "PromptResponse", "ServerStats", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index cf5ae10b7..7600bc3b1 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -42,9 +42,6 @@ ) from core.base.api.models import ( CollectionResponse, - KGCreationEstimationResponse, - KGDeduplicationEstimationResponse, - KGEnrichmentEstimationResponse, UserResponse, ) @@ -775,14 +772,14 @@ async def get_relationship_count( @abstractmethod async def get_creation_estimate( self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ) -> KGCreationEstimationResponse: + ): """Get creation cost estimate.""" pass @abstractmethod async def get_enrichment_estimate( self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ) -> KGEnrichmentEstimationResponse: + ): """Get enrichment cost estimate.""" pass @@ -791,7 +788,7 @@ async def get_deduplication_estimate( self, collection_id: UUID, kg_deduplication_settings: KGEntityDeduplicationSettings, - ) -> KGDeduplicationEstimationResponse: + ): """Get deduplication cost estimate.""" pass @@ -839,7 +836,9 @@ async def get_existing_entity_extraction_ids( raise NotImplementedError @abstractmethod - async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: + async def get_all_relationships( + self, collection_id: UUID + ) -> list[Relationship]: raise NotImplementedError @abstractmethod @@ -1523,7 +1522,9 @@ async def 
add_relationships( table_name: str = "chunk_relationship", ) -> None: """Forward to KG handler add_relationships method.""" - return await self.kg_handler.add_relationships(relationships, table_name) + return await self.kg_handler.add_relationships( + relationships, table_name + ) async def get_entity_map( self, offset: int, limit: int, document_id: UUID @@ -1680,7 +1681,7 @@ async def get_relationship_count( # Estimation methods async def get_creation_estimate( self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ) -> KGCreationEstimationResponse: + ): """Forward to KG handler get_creation_estimate method.""" return await self.kg_handler.get_creation_estimate( collection_id, kg_creation_settings @@ -1688,7 +1689,7 @@ async def get_creation_estimate( async def get_enrichment_estimate( self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ) -> KGEnrichmentEstimationResponse: + ): """Forward to KG handler get_enrichment_estimate method.""" return await self.kg_handler.get_enrichment_estimate( collection_id, kg_enrichment_settings @@ -1698,13 +1699,15 @@ async def get_deduplication_estimate( self, collection_id: UUID, kg_deduplication_settings: KGEntityDeduplicationSettings, - ) -> KGDeduplicationEstimationResponse: + ): """Forward to KG handler get_deduplication_estimate method.""" return await self.kg_handler.get_deduplication_estimate( collection_id, kg_deduplication_settings ) - async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: + async def get_all_relationships( + self, collection_id: UUID + ) -> list[Relationship]: return await self.kg_handler.get_all_relationships(collection_id) async def update_entity_descriptions(self, entities: list[Entity]): diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index ad335f4b6..a591f0e31 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -289,11 +289,12 @@ async def get_entities( limit=limit, ) - @self.router.get("/relationships") + @self.router.get("/triples") @self.base_endpoint async def get_relationships( collection_id: Optional[UUID] = Query( - None, description="Collection ID to retrieve relationships from." + None, + description="Collection ID to retrieve relationships from.", ), entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." 
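The v2 surface above keeps the public route at /triples for backward compatibility while the handler and payload adopt the relationships naming. A minimal sketch of calling it, assuming a local R2R server on port 7272 as the integration tests in this series do, and assuming the response key follows the renamed schema asserted in runner_cli.py (the UUID is the test fixture value, not real data):

```python
import requests

# Sketch only: server address and collection_id are the values used by the
# integration tests in this series.
response = requests.get(
    "http://localhost:7272/v2/triples",  # legacy route name kept by this patch
    params={"collection_id": "122fdf6a-e116-546b-a8f6-e4cb2e2c0a09"},
)
# runner_cli.py asserts the renamed key on the results payload:
relationships = response.json()["results"]["relationships"]
print(f"fetched {len(relationships)} relationships")
```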
diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 3661f8be0..c44488794 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -8,13 +8,20 @@ from core.base import R2RException, RunType from core.base.abstractions import EntityLevel, KGRunType + from core.base.api.models import ( - WrappedKGCreationResponse, - WrappedKGEnrichmentResponse, - WrappedKGEntityDeduplicationResponse, - WrappedKGTunePromptResponse, - WrappedKGRelationshipsResponse, + WrappedKGCreationResponseV3 as WrappedKGCreationResponse, + WrappedKGEnrichmentResponseV3 as WrappedKGEnrichmentResponse, + WrappedKGEntityDeduplicationResponseV3 as WrappedKGEntityDeduplicationResponse, + WrappedKGTunePromptResponseV3 as WrappedKGTunePromptResponse, + WrappedKGRelationshipsResponseV3 as WrappedKGRelationshipsResponse, + WrappedKGCommunitiesResponseV3 as WrappedKGCommunitiesResponse, + KGCreationResponseV3 as KGCreationResponse, + KGEnrichmentResponseV3 as KGEnrichmentResponse, + KGEntityDeduplicationResponseV3 as KGEntityDeduplicationResponse, + KGTunePromptResponseV3 as KGTunePromptResponse, ) + from core.providers import ( HatchetOrchestrationProvider, SimpleOrchestrationProvider, @@ -23,7 +30,8 @@ generate_default_user_collection_id, update_settings_from_dict, ) -from shared.api.models.base import PaginatedResultsWrapper, ResultsWrapper + +from core.base.api.models import PaginatedResultsWrapper, ResultsWrapper from .base_router import BaseRouterV3 @@ -67,6 +75,7 @@ class Relationship(BaseModel): object_name: str predicate: str + class GraphRouter(BaseRouterV3): def __init__( self, @@ -156,23 +165,46 @@ def _setup_routes(self): @self.base_endpoint async def list_entities( request: Request, - id: UUID = Path(..., description="The ID of the chunk to retrieve entities for."), - entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the entities by."), - entity_categories: Optional[list[str]] = Query(None, description="A list of entity categories to filter the entities by."), - attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."), - offset: int = Query(0, ge=0, description="The offset of the first entity to retrieve."), - limit: int = Query(100, ge=0, le=20_000, description="The maximum number of entities to retrieve, up to 20,000."), + id: UUID = Path( + ..., + description="The ID of the chunk to retrieve entities for.", + ), + entity_names: Optional[list[str]] = Query( + None, + description="A list of entity names to filter the entities by.", + ), + entity_categories: Optional[list[str]] = Query( + None, + description="A list of entity categories to filter the entities by.", + ), + attributes: Optional[list[str]] = Query( + None, + description="A list of attributes to return. By default, all attributes are returned.", + ), + offset: int = Query( + 0, + ge=0, + description="The offset of the first entity to retrieve.", + ), + limit: int = Query( + 100, + ge=0, + le=20_000, + description="The maximum number of entities to retrieve, up to 20,000.", + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> PaginatedResultsWrapper[list[Entity]]: """ - Retrieves a list of entities associated with a specific chunk. - - Note that when entities are extracted, neighboring chunks are also processed together to extract entities. 
- - So, the entity returned here may not be in the same chunk as the one specified, but rather in a neighboring chunk (upto 2 chunks by default). + Retrieves a list of entities associated with a specific chunk. + + Note that when entities are extracted, neighboring chunks are also processed together to extract entities. + + So, the entity returned here may not be in the same chunk as the one specified, but rather in a neighboring chunk (upto 2 chunks by default). """ if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].list_entities_v3( level=self._get_path_level(request), @@ -181,9 +213,9 @@ async def list_entities( limit=limit, entity_names=entity_names, entity_categories=entity_categories, - attributes=attributes + attributes=attributes, ) - + @self.router.post( "/chunks/{id}/entities", summary="Create entities for a chunk", @@ -250,20 +282,31 @@ async def list_entities( @self.base_endpoint async def create_entities( request: Request, - id: UUID = Path(..., description="The ID of the chunk to create entities for."), - entities: list[Union[Entity, dict]] = Body(..., description="The entities to create."), + id: UUID = Path( + ..., description="The ID of the chunk to create entities for." + ), + entities: list[Union[Entity, dict]] = Body( + ..., description="The entities to create." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) - entities = [Entity(**entity) if isinstance(entity, dict) else entity for entity in entities] + entities = [ + Entity(**entity) if isinstance(entity, dict) else entity + for entity in entities + ] # for each entity, set the level to CHUNK for entity in entities: if entity.level is None: entity.level = EntityLevel.CHUNK else: - raise R2RException("Entity level must be chunk or empty.", 400) + raise R2RException( + "Entity level must be chunk or empty.", 400 + ) return await self.services["kg"].create_entities_v3( level=self._get_path_level(request), @@ -295,13 +338,20 @@ async def create_entities( @self.base_endpoint async def update_entity( request: Request, - id: UUID = Path(..., description="The ID of the chunk to update the entity for."), - entity_id: UUID = Path(..., description="The ID of the entity to update."), + id: UUID = Path( + ..., + description="The ID of the chunk to update the entity for.", + ), + entity_id: UUID = Path( + ..., description="The ID of the entity to update." + ), entity: Entity = Body(..., description="The updated entity."), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].update_entity_v3( level=self._get_path_level(request), @@ -376,12 +426,19 @@ async def update_entity( @self.base_endpoint async def delete_entity( request: Request, - id: UUID = Path(..., description="The ID of the chunk to delete the entity for."), - entity_id: UUID = Path(..., description="The ID of the entity to delete."), + id: UUID = Path( + ..., + description="The ID of the chunk to delete the entity for.", + ), + entity_id: UUID = Path( + ..., description="The ID of the entity to delete." 
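+                # e.g., with a hypothetical SDK helper mirroring the
+                # create/update examples in this router:
+                #   client.chunks.delete_entity(id=chunk_id, entity_id=entity_id)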
+ ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].delete_entity_v3( level=self._get_path_level(request), @@ -432,7 +489,6 @@ async def delete_entity( ] }, ) - @self.router.get( "/chunks/{id}/relationships", summary="List relationships for a chunk", @@ -456,16 +512,39 @@ async def delete_entity( ) @self.base_endpoint async def list_relationships( - id: UUID = Path(..., description="The ID of the chunk to retrieve relationships for."), - entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the relationships by."), - relationship_types: Optional[list[str]] = Query(None, description="A list of relationship types to filter the relationships by."), - attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."), - offset: int = Query(0, ge=0, description="The offset of the first relationship to retrieve."), - limit: int = Query(100, ge=0, le=20_000, description="The maximum number of relationships to retrieve, up to 20,000."), + id: UUID = Path( + ..., + description="The ID of the chunk to retrieve relationships for.", + ), + entity_names: Optional[list[str]] = Query( + None, + description="A list of entity names to filter the relationships by.", + ), + relationship_types: Optional[list[str]] = Query( + None, + description="A list of relationship types to filter the relationships by.", + ), + attributes: Optional[list[str]] = Query( + None, + description="A list of attributes to return. By default, all attributes are returned.", + ), + offset: int = Query( + 0, + ge=0, + description="The offset of the first relationship to retrieve.", + ), + limit: int = Query( + 100, + ge=0, + le=20_000, + description="The maximum number of relationships to retrieve, up to 20,000.", + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> PaginatedResultsWrapper[list[Relationship]]: if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].list_relationships_v3( level=EntityLevel.CHUNK, @@ -477,7 +556,6 @@ async def list_relationships( limit=limit, ) - @self.router.post( "/chunks/{id}/relationships", summary="Create relationships for a chunk", @@ -501,21 +579,34 @@ async def list_relationships( ) @self.base_endpoint async def create_relationships( - id: UUID = Path(..., description="The ID of the chunk to create relationships for."), - relationships: list[Union[Relationship, dict]] = Body(..., description="The relationships to create."), + id: UUID = Path( + ..., + description="The ID of the chunk to create relationships for.", + ), + relationships: list[Union[Relationship, dict]] = Body( + ..., description="The relationships to create." 
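+                # Each item may be a Relationship model or a plain dict using the
+                # subject/predicate/object fields seen throughout this patch, e.g.
+                #   {"subject": "Entity1", "predicate": "RELATED_TO", "object": "Entity2"}
+                # (illustrative values, not fixture data)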
+ ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedKGRelationshipsResponse: if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) - relationships = [Relationship(**relationship) if isinstance(relationship, dict) else relationship for relationship in relationships] + relationships = [ + ( + Relationship(**relationship) + if isinstance(relationship, dict) + else relationship + ) + for relationship in relationships + ] return await self.services["kg"].create_relationships_v3( level=EntityLevel.CHUNK, id=id, relationships=relationships, ) - @self.router.post( "/chunks/{id}/relationships/{relationship_id}", @@ -538,16 +629,24 @@ async def create_relationships( ] }, ) - @self.base_endpoint + @self.base_endpoint async def update_relationship( - id: UUID = Path(..., description="The ID of the chunk to update the relationship for."), - relationship_id: UUID = Path(..., description="The ID of the relationship to update."), - relationship: Relationship = Body(..., description="The updated relationship."), + id: UUID = Path( + ..., + description="The ID of the chunk to update the relationship for.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to update." + ), + relationship: Relationship = Body( + ..., description="The updated relationship." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) - + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].update_relationship_v3( level=EntityLevel.CHUNK, @@ -562,12 +661,19 @@ async def update_relationship( ) @self.base_endpoint async def delete_relationship( - id: UUID = Path(..., description="The ID of the chunk to delete the relationship for."), - relationship_id: UUID = Path(..., description="The ID of the relationship to delete."), + id: UUID = Path( + ..., + description="The ID of the chunk to delete the relationship for.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to delete." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].delete_relationship_v3( level=EntityLevel.CHUNK, @@ -599,19 +705,42 @@ async def delete_relationship( ) @self.base_endpoint async def list_entities( - id: UUID = Path(..., description="The ID of the document to retrieve entities for."), - entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the entities by."), - entity_categories: Optional[list[str]] = Query(None, description="A list of entity categories to filter the entities by."), - attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. 
By default, all attributes are returned."),
-            offset: int = Query(0, ge=0, description="The offset of the first entity to retrieve."),
-            limit: int = Query(100, ge=0, le=20_000, description="The maximum number of entities to retrieve, up to 20,000."),
+            id: UUID = Path(
+                ...,
+                description="The ID of the document to retrieve entities for.",
+            ),
+            entity_names: Optional[list[str]] = Query(
+                None,
+                description="A list of entity names to filter the entities by.",
+            ),
+            entity_categories: Optional[list[str]] = Query(
+                None,
+                description="A list of entity categories to filter the entities by.",
+            ),
+            attributes: Optional[list[str]] = Query(
+                None,
+                description="A list of attributes to return. By default, all attributes are returned.",
+            ),
+            offset: int = Query(
+                0,
+                ge=0,
+                description="The offset of the first entity to retrieve.",
+            ),
+            limit: int = Query(
+                100,
+                ge=0,
+                le=20_000,
+                description="The maximum number of entities to retrieve, up to 20,000.",
+            ),
             auth_user=Depends(self.providers.auth.auth_wrapper),
         ) -> PaginatedResultsWrapper[list[Entity]]:
             """
-                Retrieves a list of entities associated with a specific document.
+            Retrieves a list of entities associated with a specific document.
             """
             if not auth_user.is_superuser:
-                raise R2RException("Only superusers can access this endpoint.", 403)
+                raise R2RException(
+                    "Only superusers can access this endpoint.", 403
+                )
 
             return await self.services["kg"].list_entities_v3(
                 level=EntityLevel.DOCUMENT,
                 id=id,
                 offset=offset,
                 limit=limit,
                 entity_names=entity_names,
                 entity_categories=entity_categories,
-                attributes=attributes
+                attributes=attributes,
             )
-
+
         @self.router.post(
             "/documents/{id}/entities",
             summary="Create entities for a document",
@@ -646,20 +775,31 @@ async def list_entities(
         )
         @self.base_endpoint
         async def create_entities(
-            id: UUID = Path(..., description="The ID of the chunk to create entities for."),
-            entities: list[Union[Entity, dict]] = Body(..., description="The entities to create."),
+            id: UUID = Path(
+                ..., description="The ID of the document to create entities for."
+            ),
+            entities: list[Union[Entity, dict]] = Body(
+                ..., description="The entities to create."
+            ),
             auth_user=Depends(self.providers.auth.auth_wrapper),
         ):
             if not auth_user.is_superuser:
-                raise R2RException("Only superusers can access this endpoint.", 403)
+                raise R2RException(
+                    "Only superusers can access this endpoint.", 403
+                )
 
-            entities = [Entity(**entity) if isinstance(entity, dict) else entity for entity in entities]
+            entities = [
+                Entity(**entity) if isinstance(entity, dict) else entity
+                for entity in entities
+            ]
             # for each entity, set the level to DOCUMENT
             for entity in entities:
                 if entity.level is None:
                     entity.level = EntityLevel.DOCUMENT
                 else:
-                    raise R2RException("Entity level must be chunk or empty.", 400)
+                    raise R2RException(
+                        "Entity level must be document or empty.", 400
+                    )
 
             return await self.services["kg"].create_entities_v3(
                 level=EntityLevel.DOCUMENT,
@@ -690,13 +830,20 @@ async def create_entities(
         )
         @self.base_endpoint
         async def update_entity(
-            id: UUID = Path(..., description="The ID of the document to update the entity for."),
-            entity_id: UUID = Path(..., description="The ID of the entity to update."),
+            id: UUID = Path(
+                ...,
+                description="The ID of the document to update the entity for.",
+            ),
+            entity_id: UUID = Path(
+                ..., description="The ID of the entity to update."
+ ), entity: Entity = Body(..., description="The updated entity."), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].update_entity_v3( level=EntityLevel.DOCUMENT, @@ -705,7 +852,6 @@ async def update_entity( entity=entity, ) - @self.router.delete( "/documents/{id}/entities/{entity_id}", summary="Delete an entity for a document", @@ -727,15 +873,21 @@ async def update_entity( ] }, ) - @self.base_endpoint async def delete_entity( - id: UUID = Path(..., description="The ID of the document to delete the entity for."), - entity_id: UUID = Path(..., description="The ID of the entity to delete."), + id: UUID = Path( + ..., + description="The ID of the document to delete the entity for.", + ), + entity_id: UUID = Path( + ..., description="The ID of the entity to delete." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) ##### RELATIONSHIPS ##### @self.router.get( @@ -761,16 +913,39 @@ async def delete_entity( ) @self.base_endpoint async def list_relationships( - id: UUID = Path(..., description="The ID of the document to retrieve relationships for."), - entity_names: Optional[list[str]] = Query(None, description="A list of entity names to filter the relationships by."), - relationship_types: Optional[list[str]] = Query(None, description="A list of relationship types to filter the relationships by."), - attributes: Optional[list[str]] = Query(None, description="A list of attributes to return. By default, all attributes are returned."), - offset: int = Query(0, ge=0, description="The offset of the first relationship to retrieve."), - limit: int = Query(100, ge=0, le=20_000, description="The maximum number of relationships to retrieve, up to 20,000."), + id: UUID = Path( + ..., + description="The ID of the document to retrieve relationships for.", + ), + entity_names: Optional[list[str]] = Query( + None, + description="A list of entity names to filter the relationships by.", + ), + relationship_types: Optional[list[str]] = Query( + None, + description="A list of relationship types to filter the relationships by.", + ), + attributes: Optional[list[str]] = Query( + None, + description="A list of attributes to return. 
By default, all attributes are returned.", + ), + offset: int = Query( + 0, + ge=0, + description="The offset of the first relationship to retrieve.", + ), + limit: int = Query( + 100, + ge=0, + le=20_000, + description="The maximum number of relationships to retrieve, up to 20,000.", + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> PaginatedResultsWrapper[list[Relationship]]: if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].list_relationships_v3( level=EntityLevel.DOCUMENT, @@ -782,7 +957,6 @@ async def list_relationships( limit=limit, ) - @self.router.post( "/documents/{id}/relationships", summary="Create relationships for a document", @@ -806,21 +980,34 @@ async def list_relationships( ) @self.base_endpoint async def create_relationships( - id: UUID = Path(..., description="The ID of the document to create relationships for."), - relationships: list[Union[Relationship, dict]] = Body(..., description="The relationships to create."), + id: UUID = Path( + ..., + description="The ID of the document to create relationships for.", + ), + relationships: list[Union[Relationship, dict]] = Body( + ..., description="The relationships to create." + ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[list[RelationshipResponse]]: + ) -> WrappedKGCreationResponse: if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) - relationships = [Relationship(**relationship) if isinstance(relationship, dict) else relationship for relationship in relationships] + relationships = [ + ( + Relationship(**relationship) + if isinstance(relationship, dict) + else relationship + ) + for relationship in relationships + ] return await self.services["kg"].create_relationships_v3( level=EntityLevel.DOCUMENT, id=id, relationships=relationships, ) - @self.router.post( "/documents/{id}/relationships/{relationship_id}", @@ -843,16 +1030,24 @@ async def create_relationships( ] }, ) - @self.base_endpoint + @self.base_endpoint async def update_relationship( - id: UUID = Path(..., description="The ID of the document to update the relationship for."), - relationship_id: UUID = Path(..., description="The ID of the relationship to update."), - relationship: Relationship = Body(..., description="The updated relationship."), + id: UUID = Path( + ..., + description="The ID of the document to update the relationship for.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to update." + ), + relationship: Relationship = Body( + ..., description="The updated relationship." 
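+                # The handler forwards this object whole to update_relationship_v3;
+                # no partial-update path is defined in this patch, so send a
+                # complete Relationship (subject, predicate, object, ...).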
+ ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) - + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].update_relationship_v3( level=EntityLevel.DOCUMENT, @@ -867,12 +1062,19 @@ async def update_relationship( ) @self.base_endpoint async def delete_relationship( - id: UUID = Path(..., description="The ID of the document to delete the relationship for."), - relationship_id: UUID = Path(..., description="The ID of the relationship to delete."), + id: UUID = Path( + ..., + description="The ID of the document to delete the relationship for.", + ), + relationship_id: UUID = Path( + ..., description="The ID of the relationship to delete." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].delete_relationship_v3( level=EntityLevel.DOCUMENT, @@ -882,11 +1084,6 @@ async def delete_relationship( ##### COLLECTION LEVEL OPERATIONS ##### - - - - - # Graph-level operations @self.router.post( "/graphs/{collection_id}", diff --git a/py/core/main/assembly/factory.py b/py/core/main/assembly/factory.py index fbe6f7a8d..3d7f21eb3 100644 --- a/py/core/main/assembly/factory.py +++ b/py/core/main/assembly/factory.py @@ -542,7 +542,9 @@ def create_kg_relationships_extraction_pipe(self, *args, **kwargs) -> Any: logging_provider=self.providers.logging, llm_provider=self.providers.llm, database_provider=self.providers.database, - config=AsyncPipe.PipeConfig(name="kg_relationships_extraction_pipe"), + config=AsyncPipe.PipeConfig( + name="kg_relationships_extraction_pipe" + ), ) def create_kg_storage_pipe(self, *args, **kwargs) -> Any: diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index d2f9a2a15..87270aafd 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -508,7 +508,13 @@ async def list_relationships_v3( limit: Optional[int] = None, ): return await self.providers.database.list_relationships_v3( - level, id, entity_names, relationship_types, attributes, offset, limit + level, + id, + entity_names, + relationship_types, + attributes, + offset, + limit, ) ##### Communities ##### diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 69afaa7da..19b386e62 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -67,7 +67,9 @@ async def community_summary_prompt( "entities": [], "relationships": [], } - entity_map[relationship.subject]["relationships"].append(relationship) + entity_map[relationship.subject]["relationships"].append( + relationship + ) # sort in descending order of relationship count sorted_entity_map = sorted( @@ -90,7 +92,9 @@ async def _get_entity_descriptions_string( for entity in sampled_entities ) - async def _get_relationships_string(relationships: list, max_count: int = 100): + async def _get_relationships_string( + relationships: list, max_count: int = 100 + ): sampled_relationships = ( random.sample(relationships, max_count) if len(relationships) > max_count @@ -106,7 +110,9 @@ async def _get_relationships_string(relationships: list, max_count: int = 100): entity_descriptions = await _get_entity_descriptions_string( entity_data["entities"] ) - 
relationships = await _get_relationships_string(entity_data["relationships"]) + relationships = await _get_relationships_string( + entity_data["relationships"] + ) prompt += f""" Entity: {entity_name} diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index ce103f4c3..f7867df20 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -56,7 +56,9 @@ def __init__( super().__init__( logging_provider=logging_provider, config=config - or AsyncPipe.PipeConfig(name="default_kg_relationships_extraction_pipe"), + or AsyncPipe.PipeConfig( + name="default_kg_relationships_extraction_pipe" + ), ) self.database_provider = database_provider self.llm_provider = llm_provider @@ -226,7 +228,9 @@ async def _run_logic( # type: ignore document_id = input.message["document_id"] generation_config = input.message["generation_config"] extraction_merge_count = input.message["extraction_merge_count"] - max_knowledge_relationships = input.message["max_knowledge_relationships"] + max_knowledge_relationships = input.message[ + "max_knowledge_relationships" + ] entity_types = input.message["entity_types"] relation_types = input.message["relation_types"] diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 8da324acf..016839c13 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -15,7 +15,7 @@ KGExtractionStatus, KGHandler, R2RException, - Relationship + Relationship, ) from core.base.abstractions import ( CommunityInfo, @@ -26,11 +26,7 @@ KGEntityDeduplicationSettings, VectorQuantizationType, ) -from core.base.api.models import ( - KGCreationEstimationResponse, - KGDeduplicationEstimationResponse, - KGEnrichmentEstimationResponse, -) + from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens from .base import PostgresConnectionManager @@ -257,7 +253,9 @@ async def get_graph_status(self, collection_id: UUID) -> dict: [collection_id], ) - document_ids = [doc_id["document_id"] for doc_id in kg_extraction_statuses] + document_ids = [ + doc_id["document_id"] for doc_id in kg_extraction_statuses + ] kg_enrichment_statuses = await self.connection_manager.fetch_query( f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", @@ -292,7 +290,9 @@ async def get_graph_status(self, collection_id: UUID) -> dict: return { "kg_extraction_statuses": kg_extraction_statuses, - "kg_enrichment_status": kg_enrichment_statuses[0]["enrichment_status"], + "kg_enrichment_status": kg_enrichment_statuses[0][ + "enrichment_status" + ], "chunk_entity_count": chunk_entity_count[0]["count"], "chunk_relationship_count": chunk_relationship_count[0]["count"], "document_entity_count": document_entity_count[0]["count"], @@ -355,11 +355,14 @@ async def list_relationships_v3( {filter_query} """ - results = await self.connection_manager.fetch_query(QUERY, [id, entity_names, relationship_types]) + results = await self.connection_manager.fetch_query( + QUERY, [id, entity_names, relationship_types] + ) if attributes: results = [ - {k: v for k, v in result.items() if k in attributes} for result in results + {k: v for k, v in result.items() if k in attributes} + for result in results ] return results @@ -412,10 +415,13 @@ async def add_kg_extractions( extraction.relationships[i].extraction_ids = ( extraction.extraction_ids ) - extraction.relationships[i].document_id = extraction.document_id + 
extraction.relationships[i].document_id = ( + extraction.document_id + ) await self.add_relationships( - extraction.relationships, table_name=f"{table_prefix}relationship" + extraction.relationships, + table_name=f"{table_prefix}relationship", ) return (total_entities, total_relationships) @@ -495,9 +501,13 @@ async def get_entity_map( for relationship in relationships_list: if relationship.subject in entity_map: - entity_map[relationship.subject]["relationships"].append(relationship) + entity_map[relationship.subject]["relationships"].append( + relationship + ) if relationship.object in entity_map: - entity_map[relationship.object]["relationships"].append(relationship) + entity_map[relationship.object]["relationships"].append( + relationship + ) return entity_map @@ -600,7 +610,9 @@ async def vector_query( # type: ignore for property_name in property_names } - async def get_all_relationships(self, collection_id: UUID) -> list[Relationship]: + async def get_all_relationships( + self, collection_id: UUID + ) -> list[Relationship]: # getting all documents for a collection if document_ids is None: @@ -857,10 +869,17 @@ async def _get_relationship_ids_cache( and relationship.object is not None ): relationship_ids_cache[relationship.object] = [] - if relationship.subject is not None and relationship.id is not None: - relationship_ids_cache[relationship.subject].append(relationship.id) + if ( + relationship.subject is not None + and relationship.id is not None + ): + relationship_ids_cache[relationship.subject].append( + relationship.id + ) if relationship.object is not None and relationship.id is not None: - relationship_ids_cache[relationship.object].append(relationship.id) + relationship_ids_cache[relationship.object].append( + relationship.id + ) return relationship_ids_cache @@ -988,15 +1007,22 @@ async def perform_graph_clustering( logger.info(f"Clustering with settings: {leiden_params}") - relationship_ids_cache = await self._get_relationship_ids_cache(relationships) + relationship_ids_cache = await self._get_relationship_ids_cache( + relationships + ) - if await self._use_community_cache(collection_id, relationship_ids_cache): + if await self._use_community_cache( + collection_id, relationship_ids_cache + ): num_communities = await self._incremental_clustering( relationship_ids_cache, leiden_params, collection_id ) else: num_communities = await self._cluster_and_add_community_info( - relationships, relationship_ids_cache, leiden_params, collection_id + relationships, + relationship_ids_cache, + leiden_params, + collection_id, ) return num_communities @@ -1088,7 +1114,9 @@ async def get_community_details( relationships = await self.connection_manager.fetch_query( QUERY, [community_number, collection_id] ) - relationships = [Relationship(**relationship) for relationship in relationships] + relationships = [ + Relationship(**relationship) for relationship in relationships + ] return level, entities, relationships @@ -1102,13 +1130,11 @@ async def get_community_details( async def create_entities( self, level: EntityLevel, id: UUID, entities: list[Entity] ) -> None: - + # TODO: check if already exists await self._add_objects(entities, level.table_name) - async def update_entity( - self, collection_id: UUID, entity: Entity - ) -> None: + async def update_entity(self, collection_id: UUID, entity: Entity) -> None: table_name = entity.level.value + "_entity" # check if the entity already exists @@ -1116,7 +1142,9 @@ async def update_entity( SELECT COUNT(*) FROM {self._get_table_name(table_name)} 
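             -- table_name resolves to "<level>_entity" (e.g. chunk_entity or
             -- document_entity), per table_name = entity.level.value + "_entity" above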
WHERE id = $1 AND collection_id = $2 """ count = ( - await self.connection_manager.fetch_query(QUERY, [entity.id, collection_id]) + await self.connection_manager.fetch_query( + QUERY, [entity.id, collection_id] + ) )[0]["count"] if count == 0: @@ -1124,15 +1152,15 @@ async def update_entity( await self._add_objects([entity], table_name) - async def delete_entity( - self, collection_id: UUID, entity: Entity - ) -> None: + async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: table_name = entity.level.value + "_entity" QUERY = f""" DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 """ - await self.connection_manager.execute_query(QUERY, [entity.id, collection_id]) + await self.connection_manager.execute_query( + QUERY, [entity.id, collection_id] + ) ############################################################ ########## Relationship CRUD Operations #################### @@ -1141,13 +1169,21 @@ async def delete_entity( async def create_relationship( self, collection_id: UUID, relationship: Relationship ) -> None: - + # check if the relationship already exists QUERY = f""" SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 """ count = ( - await self.connection_manager.fetch_query(QUERY, [relationship.subject, relationship.predicate, relationship.object, collection_id]) + await self.connection_manager.fetch_query( + QUERY, + [ + relationship.subject, + relationship.predicate, + relationship.object, + collection_id, + ], + ) )[0]["count"] if count > 0: @@ -1158,7 +1194,7 @@ async def create_relationship( async def update_relationship( self, relationship_id: UUID, relationship: Relationship ) -> None: - + # check if relationship_id exists QUERY = f""" SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 @@ -1172,9 +1208,7 @@ async def update_relationship( await self._add_objects([relationship], "chunk_relationship") - async def delete_relationship( - self, relationship_id: UUID - ) -> None: + async def delete_relationship(self, relationship_id: UUID) -> None: QUERY = f""" DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 """ @@ -1346,7 +1380,7 @@ async def get_existing_entity_extraction_ids( async def get_creation_estimate( self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ) -> KGCreationEstimationResponse: + ): # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
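         # Note: the figures computed below are heuristic estimates; document and
         # chunk counts come from the collection, and costs are priced with helpers
         # such as llm_cost_per_million_tokens (imported above).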
document_ids = [ @@ -1520,11 +1554,13 @@ async def get_entities_v3( offset: int = 0, limit: int = -1, ): - + params: list = [id] if level != EntityLevel.CHUNK and entity_categories: - raise ValueError("entity_categories are only supported for chunk level entities") + raise ValueError( + "entity_categories are only supported for chunk level entities" + ) filter = { EntityLevel.CHUNK: "chunk_ids = ANY($1)", @@ -1550,12 +1586,12 @@ async def get_entities_v3( output = await self.connection_manager.fetch_query(QUERY, params) if attributes: - output = [entity for entity in output if entity["name"] in attributes] + output = [ + entity for entity in output if entity["name"] in attributes + ] return output - - # TODO: deprecate this async def get_entities( self, @@ -1674,8 +1710,12 @@ async def get_relationships( {pagination_clause} """ - relationships = await self.connection_manager.fetch_query(query, params) - relationships = [Relationship(**relationship) for relationship in relationships] + relationships = await self.connection_manager.fetch_query( + query, params + ) + relationships = [ + Relationship(**relationship) for relationship in relationships + ] total_entries = await self.get_relationship_count( collection_id=collection_id ) diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index 39364cb75..3cde5a192 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -310,7 +310,9 @@ class Config: # KG Entities WrappedKGEntityResponse = ResultsWrapper[KGEntitiesResponse] WrappedKGEntitiesResponse = PaginatedResultsWrapper[KGEntitiesResponse] -WrappedKGRelationshipsResponse = PaginatedResultsWrapper[KGRelationshipsResponse] +WrappedKGRelationshipsResponse = PaginatedResultsWrapper[ + KGRelationshipsResponse +] WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse] WrappedKGCommunitiesResponse = ResultsWrapper[KGCommunitiesResponse] diff --git a/py/shared/api/models/kg/responses_v3.py b/py/shared/api/models/kg/responses_v3.py new file mode 100644 index 000000000..b2f585142 --- /dev/null +++ b/py/shared/api/models/kg/responses_v3.py @@ -0,0 +1,258 @@ +from typing import Optional, Union +from uuid import UUID + +from pydantic import BaseModel, Field + +from shared.abstractions.base import R2RSerializable +from shared.abstractions.graph import CommunityReport, Entity, Relationship +from shared.api.models.base import ResultsWrapper, PaginatedResultsWrapper + + +############# ESTIMATE MODELS ############# + + +class KGCreationEstimate(R2RSerializable): + """Response for knowledge graph creation estimation.""" + + document_count: Optional[int] = Field( + default=None, + description="The number of documents in the collection.", + ) + + number_of_jobs_created: Optional[int] = Field( + default=None, + description="The number of jobs created for the graph creation process.", + ) + + total_chunks: Optional[int] = Field( + default=None, + description="The estimated total number of chunks.", + ) + + estimated_entities: Optional[str] = Field( + default=None, + description="The estimated number of entities in the graph.", + ) + + estimated_relationships: Optional[str] = Field( + default=None, + description="The estimated number of relationships in the graph.", + ) + + estimated_llm_calls: Optional[str] = Field( + default=None, + description="The estimated number of LLM calls in millions.", + ) + + estimated_total_in_out_tokens_in_millions: Optional[str] = Field( + default=None, + description="The estimated total number of 
input and output tokens in millions.",
    )

    estimated_total_time_in_minutes: Optional[str] = Field(
        default=None,
        description="The estimated total time to run the graph creation process in minutes.",
    )

    estimated_cost_in_usd: Optional[str] = Field(
        default=None,
        description="The estimated cost to run the graph creation process in USD.",
    )


class KGEnrichmentEstimate(BaseModel):
    total_entities: Optional[int] = Field(
        default=None,
        description="The total number of entities in the graph.",
    )

    total_relationships: Optional[int] = Field(
        default=None,
        description="The total number of relationships in the graph.",
    )

    estimated_llm_calls: Optional[str] = Field(
        default=None,
        description="The estimated number of LLM calls.",
    )

    estimated_total_in_out_tokens_in_millions: Optional[str] = Field(
        default=None,
        description="The estimated total number of input and output tokens in millions.",
    )

    estimated_cost_in_usd: Optional[str] = Field(
        default=None,
        description="The estimated cost to run the graph enrichment process.",
    )

    estimated_total_time_in_minutes: Optional[str] = Field(
        default=None,
        description="The estimated total time to run the graph enrichment process.",
    )


class KGDeduplicationEstimate(R2RSerializable):
    """Response for knowledge graph deduplication estimation."""

    num_entities: Optional[int] = Field(
        default=None,
        description="The number of entities in the collection.",
    )

    estimated_llm_calls: Optional[str] = Field(
        default=None,
        description="The estimated number of LLM calls.",
    )

    estimated_total_in_out_tokens_in_millions: Optional[str] = Field(
        default=None,
        description="The estimated total number of input and output tokens in millions.",
    )

    estimated_cost_in_usd: Optional[str] = Field(
        default=None,
        description="The estimated cost in USD.",
    )

    estimated_total_time_in_minutes: Optional[str] = Field(
        default=None,
        description="The estimated time in minutes.",
    )


############# RESPONSE MODELS #############


class KGCreationResponse(BaseModel):
    message: str = Field(
        ...,
        description="A message describing the result of the KG creation request.",
    )
    id: Optional[UUID] = Field(
        None,
        description="The ID of the created object.",
    )
    task_id: Optional[UUID] = Field(
        None,
        description="The task ID of the KG creation request.",
    )
    estimate: Optional[KGCreationEstimate] = Field(
        None,
        description="The estimation of the KG creation request.",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "message": "Graph creation queued successfully.",
                "id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
                "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
                "estimate": {
                    "document_count": 100,
                    "number_of_jobs_created": 10,
                    "total_chunks": 1000,
                    "estimated_entities": "1000",
                    "estimated_relationships": "1000",
                    "estimated_llm_calls": "1000",
                    "estimated_total_in_out_tokens_in_millions": "1000",
                    "estimated_total_time_in_minutes": "1000",
                    "estimated_cost_in_usd": "1000",
                },
            }
        }


class KGEnrichmentResponse(BaseModel):
    message: str = Field(
        ...,
        description="A message describing the result of the KG enrichment request.",
    )
    task_id: UUID = Field(
        ...,
        description="The task ID of the KG enrichment request.",
    )
    estimate: Optional[KGEnrichmentEstimate] = Field(
        None,
        description="The estimation of the KG enrichment request.",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "message": "Graph enrichment queued successfully.",
                "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
                "estimate": {
                    "total_entities": 1000,
                    "total_relationships": 1000,
                    "estimated_llm_calls": "1000",
                    "estimated_total_in_out_tokens_in_millions": "1000",
                    "estimated_cost_in_usd": "1000",
                    "estimated_total_time_in_minutes": "1000",
                },
            }
        }


class KGEntityDeduplicationResponse(BaseModel):
    """Response for knowledge graph entity deduplication."""

    message: str = Field(
        ...,
        description="The message to display to the user.",
    )

    task_id: Optional[UUID] = Field(
        None,
        description="The task ID of the KG entity deduplication request.",
    )

    estimate: Optional[KGDeduplicationEstimate] = Field(
        None,
        description="The estimation of the KG entity deduplication request.",
    )

    class Config:
        json_schema_extra = {
            "example": {
                "message": "Entity deduplication queued successfully.",
                "task_id": "c68dc72e-fc23-5452-8f49-d7bd46088a96",
                "estimate": {
                    "num_entities": 1000,
                    "estimated_llm_calls": "1000",
                    "estimated_total_in_out_tokens_in_millions": "1000",
                    "estimated_cost_in_usd": "1000",
                    "estimated_total_time_in_minutes": "1000",
                },
            }
        }


class KGTunePromptResponse(R2RSerializable):
    """Response containing just the tuned prompt string."""

    tuned_prompt: str = Field(
        ...,
        description="The updated prompt.",
    )

    class Config:
        json_schema_extra = {"example": {"tuned_prompt": "The updated prompt"}}


# GET
WrappedKGEntitiesResponse = PaginatedResultsWrapper[list[Entity]]
WrappedKGRelationshipsResponse = PaginatedResultsWrapper[list[Relationship]]
WrappedKGCommunitiesResponse = PaginatedResultsWrapper[list[CommunityReport]]


# CREATE
WrappedKGCreationResponse = ResultsWrapper[KGCreationResponse]
WrappedKGEnrichmentResponse = ResultsWrapper[KGEnrichmentResponse]


WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse]
WrappedKGEntityDeduplicationResponse = ResultsWrapper[
    KGEntityDeduplicationResponse
]
diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py
index 3c74ef041..16c96372d 100644
--- a/py/tests/core/providers/kg/test_kg_logic.py
+++ b/py/tests/core/providers/kg/test_kg_logic.py
@@ -282,7 +282,10 @@ async def test_add_kg_extractions(
 
 @pytest.mark.asyncio
 async def test_get_entity_map(
-    postgres_db_provider, entities_raw_list, relationships_raw_list, document_id
+    postgres_db_provider,
+    entities_raw_list,
+    relationships_raw_list,
+    document_id,
 ):
     await postgres_db_provider.add_entities(
         entities_raw_list, table_name="chunk_entity"

From 541215bd50d96b36c27783347d68e54697f25a76 Mon Sep 17 00:00:00 2001
From: Shreyas Pimpalgaonkar
Date: Wed, 13 Nov 2024 13:30:47 -0800
Subject: [PATCH 08/21] checkin

---
 py/core/base/abstractions/__init__.py         |   4 +-
 py/core/base/api/models/__init__.py           |   4 +
 py/core/base/providers/database.py            | 124 +++++++------
 py/core/main/api/v2/kg_router.py              |   5 +-
 py/core/main/api/v3/graph_router.py           |  98 +++++-----
 py/core/main/services/kg_service.py           |   6 +-
 py/core/pipes/kg/community_summary.py         |  14 +-
 py/core/pipes/kg/storage.py                   |  45 ++++-
 py/core/pipes/retrieval/kg_search_pipe.py     |   6 +-
 py/core/providers/database/kg.py              | 172 +++++-------------
 ...reports.yaml => graphrag_communities.yaml} |   2 +-
 py/sdk/v3/graphs.py                           |   2 +-
 py/shared/abstractions/__init__.py            |   4 +-
 py/shared/abstractions/graph.py               |  75 +-------
 py/shared/abstractions/kg.py                  |   6 +-
 py/shared/abstractions/vector.py              |   2 +-
 py/shared/api/models/kg/responses.py          |   4 +-
py/shared/api/models/kg/responses_v3.py | 27 ++- .../pipes/test_kg_community_summary_pipe.py | 2 +- py/tests/core/providers/kg/test_kg_logic.py | 46 +---- 20 files changed, 275 insertions(+), 373 deletions(-) rename py/core/providers/database/prompts/{graphrag_community_reports.yaml => graphrag_communities.yaml} (99%) diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index e371bdebb..2e3d56a85 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -22,7 +22,7 @@ from shared.abstractions.graph import ( Community, CommunityInfo, - CommunityReport, + Community, Entity, EntityLevel, EntityType, @@ -110,7 +110,7 @@ "EntityType", "RelationshipType", "Community", - "CommunityReport", + "Community", "KGExtraction", "Relationship", "EntityLevel", diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index cbffaa69c..708f0d914 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -40,10 +40,12 @@ WrappedKGEnrichmentResponse as WrappedKGEnrichmentResponseV3, WrappedKGTunePromptResponse as WrappedKGTunePromptResponseV3, WrappedKGEntityDeduplicationResponse as WrappedKGEntityDeduplicationResponseV3, + WrappedKGDeletionResponse as WrappedKGDeletionResponseV3, KGCreationResponse as KGCreationResponseV3, KGEnrichmentResponse as KGEnrichmentResponseV3, KGEntityDeduplicationResponse as KGEntityDeduplicationResponseV3, KGTunePromptResponse as KGTunePromptResponseV3, + KGDeletionResponse as KGDeletionResponseV3, ) @@ -132,6 +134,8 @@ "KGEnrichmentResponseV3", "KGEntityDeduplicationResponseV3", "KGTunePromptResponseV3", + "WrappedKGDeletionResponseV3", + "KGDeletionResponseV3", # Management Responses "PromptResponse", "ServerStats", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 7600bc3b1..0dcae7472 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -16,12 +16,13 @@ from pydantic import BaseModel from core.base import ( - CommunityReport, + Community, Entity, KGExtraction, Message, Relationship, VectorEntry, + EntityLevel, ) from core.base.abstractions import ( DocumentResponse, @@ -57,7 +58,7 @@ from uuid import UUID from ..abstractions import ( - CommunityReport, + Community, Entity, KGCreationSettings, KGEnrichmentSettings, @@ -600,25 +601,52 @@ async def create_tables(self) -> None: """Create required database tables.""" pass + ### ENTITIES CRUD OPS ### @abstractmethod - async def add_kg_extractions( + async def create_entities( self, - kg_extractions: list[KGExtraction], - table_prefix: str = "chunk_", - ) -> Tuple[int, int]: - """Add KG extractions to storage.""" + entities: list[Entity], + table_name: str, + conflict_columns: list[str] = [], + ) -> Any: + """Add entities to storage.""" pass @abstractmethod - async def add_entities( + async def get_entities( self, - entities: list[Entity], + level: EntityLevel, + entity_names: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: int = 0, + limit: int = -1, + ) -> dict: + """Get entities from storage.""" + pass + + @abstractmethod + async def update_entity( + self, + entity: Entity, table_name: str, conflict_columns: list[str] = [], ) -> Any: - """Add entities to storage.""" + """Update an entity in storage.""" + pass + + @abstractmethod + async def delete_entity( + self, + id: UUID, + chunk_id: Optional[UUID] = None, + document_id: Optional[UUID] = None, + collection_id: Optional[UUID] = 
None, + graph_id: Optional[UUID] = None, + ) -> None: + """Delete an entity from storage.""" pass + ### RELATIONSHIPS CRUD OPS ### @abstractmethod async def add_relationships( self, @@ -636,16 +664,7 @@ async def get_entity_map( pass @abstractmethod - async def upsert_embeddings( - self, - data: list[Tuple[Any]], - table_name: str, - ) -> None: - """Upsert embeddings into storage.""" - pass - - @abstractmethod - async def vector_query( + async def graph_search( self, query: str, **kwargs: Any ) -> AsyncGenerator[Any, None]: """Perform vector similarity search.""" @@ -670,8 +689,8 @@ async def get_communities( pass @abstractmethod - async def add_community_report( - self, community_report: CommunityReport + async def add_community( + self, community: Community ) -> None: """Add a community report.""" pass @@ -684,14 +703,14 @@ async def get_community_details( pass @abstractmethod - async def get_community_reports( + async def get_community( self, collection_id: UUID - ) -> list[CommunityReport]: + ) -> list[Community]: """Get community reports for a collection.""" pass @abstractmethod - async def check_community_reports_exist( + async def check_community_exists( self, collection_id: UUID, offset: int, limit: int ) -> list[int]: """Check which community reports exist.""" @@ -736,6 +755,12 @@ async def get_entities( """Get entities from storage.""" pass + @abstractmethod + async def create_entities_v3(self, entities: list[Entity]) -> None: + """Create entities in storage.""" + pass + + @abstractmethod async def get_relationships( self, @@ -844,7 +869,7 @@ async def get_all_relationships( @abstractmethod async def update_entity_descriptions(self, entities: list[Entity]): raise NotImplementedError - + class PromptHandler(Handler): """Abstract base class for prompt handling operations.""" @@ -1495,16 +1520,6 @@ async def get_semantic_neighbors( similarity_threshold=similarity_threshold, ) - async def add_kg_extractions( - self, - kg_extractions: list[KGExtraction], - table_prefix: str = "chunk_", - ) -> Tuple[int, int]: - """Forward to KG handler add_kg_extractions method.""" - return await self.kg_handler.add_kg_extractions( - kg_extractions, table_prefix - ) - async def add_entities( self, entities: list[Entity], @@ -1532,14 +1547,6 @@ async def get_entity_map( """Forward to KG handler get_entity_map method.""" return await self.kg_handler.get_entity_map(offset, limit, document_id) - async def upsert_embeddings( - self, - data: list[Tuple[Any]], - table_name: str, - ) -> None: - """Forward to KG handler upsert_embeddings method.""" - return await self.kg_handler.upsert_embeddings(data, table_name) - # Community methods async def add_community_info(self, communities: list[Any]) -> None: """Forward to KG handler add_communities method.""" @@ -1562,11 +1569,11 @@ async def get_communities( community_numbers=community_numbers, ) - async def add_community_report( - self, community_report: CommunityReport + async def add_community( + self, community: Community ) -> None: - """Forward to KG handler add_community_report method.""" - return await self.kg_handler.add_community_report(community_report) + """Forward to KG handler add_community method.""" + return await self.kg_handler.add_community(community) async def get_community_details( self, community_number: int, collection_id: UUID @@ -1576,17 +1583,17 @@ async def get_community_details( community_number, collection_id ) - async def get_community_reports( + async def get_community( self, collection_id: UUID - ) -> list[CommunityReport]: - 
"""Forward to KG handler get_community_reports method.""" - return await self.kg_handler.get_community_reports(collection_id) + ) -> list[Community]: + """Forward to KG handler get_community method.""" + return await self.kg_handler.get_community(collection_id) - async def check_community_reports_exist( + async def check_community_exists( self, collection_id: UUID, offset: int, limit: int ) -> list[int]: - """Forward to KG handler check_community_reports_exist method.""" - return await self.kg_handler.check_community_reports_exist( + """Forward to KG handler check_community_exists method.""" + return await self.kg_handler.check_community_exists( collection_id, offset, limit ) @@ -1638,6 +1645,8 @@ async def get_entities( entity_table_name=entity_table_name, extra_columns=extra_columns, ) + + async async def get_relationships( self, @@ -1713,10 +1722,10 @@ async def get_all_relationships( async def update_entity_descriptions(self, entities: list[Entity]): return await self.kg_handler.update_entity_descriptions(entities) - async def vector_query( + async def graph_search( self, query: str, **kwargs: Any ) -> AsyncGenerator[Any, None]: - return self.kg_handler.vector_query(query, **kwargs) # type: ignore + return self.kg_handler.graph_search(query, **kwargs) # type: ignore async def create_vector_index(self) -> None: return await self.kg_handler.create_vector_index() @@ -1945,3 +1954,4 @@ async def list_chunks( return await self.vector_handler.list_chunks( offset, limit, filters, include_vectors ) + diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index a591f0e31..8fdbdfeea 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -293,8 +293,7 @@ async def get_entities( @self.base_endpoint async def get_relationships( collection_id: Optional[UUID] = Query( - None, - description="Collection ID to retrieve relationships from.", + None, description="Collection ID to retrieve relationships from." ), entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." @@ -435,7 +434,7 @@ async def deduplicate_entities( async def get_tuned_prompt( prompt_name: str = Query( ..., - description="The name of the prompt to tune. Valid options are 'graphrag_relationships_extraction_few_shot', 'graphrag_entity_description' and 'graphrag_community_reports'.", + description="The name of the prompt to tune. Valid options are 'graphrag_relationships_extraction_few_shot', 'graphrag_entity_description' and 'graphrag_communities'.", ), collection_id: Optional[UUID] = Query( None, description="Collection ID to retrieve communities from." 
diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py
index c44488794..59e224d83 100644
--- a/py/core/main/api/v3/graph_router.py
+++ b/py/core/main/api/v3/graph_router.py
@@ -8,6 +8,7 @@
 from core.base import R2RException, RunType
 from core.base.abstractions import EntityLevel, KGRunType
+from core.base.abstractions import Community, Entity, Relationship
 from core.base.api.models import (
     WrappedKGCreationResponseV3 as WrappedKGCreationResponse,
@@ -20,8 +21,15 @@
     KGEnrichmentResponseV3 as KGEnrichmentResponse,
     KGEntityDeduplicationResponseV3 as KGEntityDeduplicationResponse,
     KGTunePromptResponseV3 as KGTunePromptResponse,
+    WrappedKGEntitiesResponseV3 as WrappedKGEntitiesResponse,
+    WrappedKGRelationshipsResponseV3 as WrappedKGRelationshipsResponse,
+    WrappedKGCommunitiesResponseV3 as WrappedKGCommunitiesResponse,
+    WrappedKGDeletionResponseV3 as WrappedKGDeletionResponse,
 )
+
+
+
 from core.providers import (
     HatchetOrchestrationProvider,
     SimpleOrchestrationProvider,
@@ -33,6 +41,8 @@
 from core.base.api.models import PaginatedResultsWrapper, ResultsWrapper
+from core.base.abstractions import KGCreationSettings
+
 from .base_router import BaseRouterV3
 from fastapi import Request
@@ -40,40 +50,40 @@
 logger = logging.getLogger()

-class Entity(BaseModel):
-    """Model representing a graph entity."""
-
-    id: UUID
-    name: str
-    type: str
-    metadata: dict = Field(default_factory=dict)
-    level: EntityLevel
-    collection_ids: list[UUID]
-    embedding: Optional[list[float]] = None
-
-    class Config:
-        json_schema_extra = {
-            "example": {
-                "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
-                "name": "John Smith",
-                "type": "PERSON",
-                "metadata": {"confidence": 0.95},
-                "level": "DOCUMENT",
-                "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"],
-                "embedding": [0.1, 0.2, 0.3],
-            }
-        }
+# class Entity(BaseModel):
+#     """Model representing a graph entity."""
+
+#     id: UUID
+#     name: str
+#     type: str
+#     metadata: dict = Field(default_factory=dict)
+#     level: EntityLevel
+#     collection_ids: list[UUID]
+#     embedding: Optional[list[float]] = None
+#     class Config:
+#         json_schema_extra = {
+#             "example": {
+#                 "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
+#                 "name": "John Smith",
+#                 "type": "PERSON",
+#                 "metadata": {"confidence": 0.95},
+#                 "level": "DOCUMENT",
+#                 "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"],
+#                 "embedding": [0.1, 0.2, 0.3],
+#             }
+#         }

-class Relationship(BaseModel):
-    """Model representing a graph relationship."""

-    id: UUID
-    subject_id: UUID
-    object_id: UUID
-    subject_name: str
-    object_name: str
-    predicate: str
+# class Relationship(BaseModel):
+#     """Model representing a graph relationship."""
+
+#     id: UUID
+#     subject_id: UUID
+#     object_id: UUID
+#     subject_name: str
+#     object_name: str
+#     predicate: str


 class GraphRouter(BaseRouterV3):
@@ -230,7 +240,7 @@ async def list_entities(
                             client = R2RClient("http://localhost:7272")
                             # when using auth, do client.login(...)

-                            result = client.chunks.create_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2])
+                            result = client.chunks.create_entities_v3(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2])
                             """
                         ),
                     },
@@ -251,7 +261,7 @@ async def list_entities(
                             client = R2RClient("http://localhost:7272")
                             # when using auth, do client.login(...)
- result = client.documents.create_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.documents.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -272,7 +282,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.create_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.collections.create_entities_v3(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -280,12 +290,12 @@ async def list_entities( }, ) @self.base_endpoint - async def create_entities( + async def create_entities_v3( request: Request, id: UUID = Path( ..., description="The ID of the chunk to create entities for." ), - entities: list[Union[Entity, dict]] = Body( + entities: list[Entity] = Body( ..., description="The entities to create." ), auth_user=Depends(self.providers.auth.auth_wrapper), @@ -295,10 +305,6 @@ async def create_entities( "Only superusers can access this endpoint.", 403 ) - entities = [ - Entity(**entity) if isinstance(entity, dict) else entity - for entity in entities - ] # for each entity, set the level to CHUNK for entity in entities: if entity.level is None: @@ -766,7 +772,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.create_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.documents.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -774,7 +780,7 @@ async def list_entities( }, ) @self.base_endpoint - async def create_entities( + async def create_entities_v3( id: UUID = Path( ..., description="The ID of the chunk to create entities for." ), @@ -1536,9 +1542,7 @@ async def list_entities( limit: int = Query(100, ge=1, le=1000), include_embeddings: bool = Query(False), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ( - WrappedKGEntitiesResponse - ): # PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedKGEntitiesResponse: """Lists entities in the graph with filtering and pagination support. Entities represent the nodes in the knowledge graph, extracted from documents. @@ -2388,7 +2392,7 @@ async def delete_community( collection_id: UUID = Path(...), community_id: UUID = Path(...), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedBooleanResponse: + ) -> WrappedKGDeletionResponse: """ Deletes a specific community by ID. This operation will not affect other communities or the underlying entities. @@ -2455,7 +2459,7 @@ async def tune_prompt( collection_id: UUID = Path(...), prompt_name: str = Body( ..., - description="The prompt to tune. Valid options: graphrag_relationships_extraction_few_shot, graphrag_entity_description, graphrag_community_reports", + description="The prompt to tune. 
Valid options: graphrag_relationships_extraction_few_shot, graphrag_entity_description, graphrag_communities", ), documents_offset: int = Body(0, ge=0), documents_limit: int = Body(100, ge=1), diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 87270aafd..53e97f3f9 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -118,15 +118,15 @@ async def kg_relationships_extraction( return await _collect_results(result_gen) - @telemetry_event("create_entities") - async def create_entities( + @telemetry_event("create_entities_v3") + async def create_entities_v3( self, level: EntityLevel, id: UUID, entities: list[Entity], **kwargs, ): - return await self.providers.database.create_entities( + return await self.providers.database.create_entities_v3( level, id, entities, **kwargs ) diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 19b386e62..5eb54ab99 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -9,7 +9,7 @@ from core.base import ( AsyncPipe, AsyncState, - CommunityReport, + Community, CompletionProvider, DatabaseProvider, EmbeddingProvider, @@ -161,7 +161,7 @@ async def process_community( ( await self.llm_provider.aget_completion( messages=await self.database_provider.prompt_handler.get_message_payload( - task_prompt_name=self.database_provider.config.kg_enrichment_settings.graphrag_community_reports, + task_prompt_name=self.database_provider.config.kg_enrichment_settings.graphrag_communities, task_inputs={ "input_text": ( await self.community_summary_prompt( @@ -206,7 +206,7 @@ async def process_community( "error": str(e), } - community_report = CommunityReport( + community = Community( community_number=community_number, collection_id=collection_id, level=community_level, @@ -223,11 +223,11 @@ async def process_community( ), ) - await self.database_provider.add_community_report(community_report) + await self.database_provider.add_community(community) return { - "community_number": community_report.community_number, - "name": community_report.name, + "community_number": community.community_number, + "name": community.name, } async def _run_logic( # type: ignore @@ -257,7 +257,7 @@ async def _run_logic( # type: ignore f"KGCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}" ) community_numbers_exist = ( - await self.database_provider.check_community_reports_exist( + await self.database_provider.check_communities_exist( collection_id=collection_id, offset=offset, limit=limit ) ) diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 9cf592e6a..99c886359 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -53,8 +53,49 @@ async def store( Stores a batch of knowledge graph extractions in the graph database. """ try: - await self.database_provider.add_kg_extractions(kg_extractions) - return + # clean up and remove this method. 
+ # make add_kg_extractions a method in the KGHandler + + total_entities, total_relationships = 0, 0 + + for extraction in kg_extractions: + + total_entities, total_relationships = ( + total_entities + len(extraction.entities), + total_relationships + len(extraction.relationships), + ) + + if extraction.entities: + if not extraction.entities[0].extraction_ids: + for i in range(len(extraction.entities)): + extraction.entities[i].extraction_ids = ( + extraction.extraction_ids + ) + extraction.entities[i].document_id = ( + extraction.document_id + ) + + await self.database_provider.add_entities( + extraction.entities, table_name=f"chunk_entity" + ) + + if extraction.relationships: + if not extraction.relationships[0].extraction_ids: + for i in range(len(extraction.relationships)): + extraction.relationships[i].extraction_ids = ( + extraction.extraction_ids + ) + extraction.relationships[i].document_id = ( + extraction.document_id + ) + + await self.database_provider.add_relationships( + extraction.relationships, + table_name=f"chunk_relationship", + ) + + return (total_entities, total_relationships) + except Exception as e: error_message = f"Failed to store knowledge graph extractions in the database: {e}" logger.error(error_message) diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index ec7fe15b6..60aa68df4 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -110,7 +110,7 @@ async def local_search( # entity search search_type = "__Entity__" - async for search_result in await self.database_provider.vector_query( # type: ignore + async for search_result in await self.database_provider.graph_search( # type: ignore message, search_type=search_type, search_type_limits=kg_search_settings.local_search_limits[ @@ -139,7 +139,7 @@ async def local_search( # relationship search # disabled for now. 
We will check evaluations and see if we need it # search_type = "__Relationship__" - # async for search_result in self.database_provider.vector_query( # type: ignore + # async for search_result in self.database_provider.graph_search( # type: ignore # input, # search_type=search_type, # search_type_limits=kg_search_settings.local_search_limits[ @@ -167,7 +167,7 @@ async def local_search( # community search search_type = "__Community__" - async for search_result in await self.database_provider.vector_query( # type: ignore + async for search_result in await self.database_provider.graph_search( # type: ignore message, search_type=search_type, search_type_limits=kg_search_settings.local_search_limits[ diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 016839c13..9746b1a70 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -9,7 +9,7 @@ from asyncpg.exceptions import PostgresError, UndefinedTableError from core.base import ( - CommunityReport, + Community, Entity, KGExtraction, KGExtractionStatus, @@ -152,7 +152,7 @@ async def create_tables(self): # communities_report table query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("community_report")} ( + CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} ( id SERIAL PRIMARY KEY, community_number INT NOT NULL, collection_id UUID NOT NULL, @@ -284,7 +284,7 @@ async def get_graph_status(self, collection_id: UUID) -> dict: ) community_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('community_report')} WHERE collection_id = $1", + f"SELECT COUNT(*) FROM {self._get_table_name('community')} WHERE collection_id = $1", [collection_id], ) @@ -368,64 +368,6 @@ async def list_relationships_v3( return results ### Relationships END #### - - async def add_kg_extractions( - self, - kg_extractions: list[KGExtraction], - table_prefix: str = "chunk_", - ) -> Tuple[int, int]: - """ - Upsert entities and relationships into the database. These are raw entities and relationships extracted from the document fragments. 
- - Args: - kg_extractions: list[KGExtraction]: list of KG extractions to upsert - table_prefix: str: prefix to add to the table names - - Returns: - total_entities: int: total number of entities upserted - total_relationships: int: total number of relationships upserted - """ - - total_entities, total_relationships = 0, 0 - - for extraction in kg_extractions: - - total_entities, total_relationships = ( - total_entities + len(extraction.entities), - total_relationships + len(extraction.relationships), - ) - - if extraction.entities: - if not extraction.entities[0].extraction_ids: - for i in range(len(extraction.entities)): - extraction.entities[i].extraction_ids = ( - extraction.extraction_ids - ) - extraction.entities[i].document_id = ( - extraction.document_id - ) - - await self.add_entities( - extraction.entities, table_name=f"{table_prefix}entity" - ) - - if extraction.relationships: - if not extraction.relationships[0].extraction_ids: - for i in range(len(extraction.relationships)): - extraction.relationships[i].extraction_ids = ( - extraction.extraction_ids - ) - extraction.relationships[i].document_id = ( - extraction.document_id - ) - - await self.add_relationships( - extraction.relationships, - table_name=f"{table_prefix}relationship", - ) - - return (total_entities, total_relationships) - async def get_entity_map( self, offset: int, limit: int, document_id: UUID ) -> dict[str, dict[str, list[dict[str, Any]]]]: @@ -511,33 +453,7 @@ async def get_entity_map( return entity_map - async def upsert_embeddings( - self, - data: list[Tuple[Any]], - table_name: str, - ) -> None: - QUERY = f""" - INSERT INTO {self._get_table_name(table_name)} (name, description, description_embedding, extraction_ids, document_id) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (name, document_id) DO UPDATE SET - description = EXCLUDED.description, - description_embedding = EXCLUDED.description_embedding, - extraction_ids = EXCLUDED.extraction_ids, - document_id = EXCLUDED.document_id - """ - return await self.connection_manager.execute_many(QUERY, data) - - async def upsert_entities(self, entities: list[Entity]) -> None: - QUERY = """ - INSERT INTO $1.$2 (category, name, description, description_embedding, extraction_ids, document_id, attributes) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """ - - table_name = self._get_table_name("entities") - query = QUERY.format(table_name) - await self.connection_manager.execute_query(query, entities) - - async def vector_query( # type: ignore + async def graph_search( # type: ignore self, query: str, **kwargs: Any ) -> AsyncGenerator[Any, None]: @@ -559,7 +475,7 @@ async def vector_query( # type: ignore elif search_type == "__Relationship__": table_name = "chunk_relationship" elif search_type == "__Community__": - table_name = "community_report" + table_name = "community" else: raise ValueError(f"Invalid search type: {search_type}") @@ -694,7 +610,7 @@ async def get_communities( query = f""" SELECT id, community_number, collection_id, level, name, summary, findings, rating, rating_explanation, COUNT(*) OVER() AS total_entries - FROM {self._get_table_name('community_report')} + FROM {self._get_table_name('community')} WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} ORDER BY community_number @@ -703,23 +619,23 @@ async def get_communities( results = await self.connection_manager.fetch_query(query, params) total_entries = results[0]["total_entries"] if results else 0 - communities = [CommunityReport(**community) for community in results] + 
communities = [Community(**community) for community in results]

         return {
             "communities": communities,
             "total_entries": total_entries,
         }

-    async def add_community_report(
-        self, community_report: CommunityReport
+    async def add_community(
+        self, community: Community
     ) -> None:

         # TODO: Fix in the short term.
         # we need to do this because postgres insert needs to be a string
-        community_report.embedding = str(community_report.embedding)  # type: ignore[assignment]
+        community.embedding = str(community.embedding)  # type: ignore[assignment]

         non_null_attrs = {
-            k: v for k, v in community_report.__dict__.items() if v is not None
+            k: v for k, v in community.__dict__.items() if v is not None
         }
         columns = ", ".join(non_null_attrs.keys())
         placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs)))
@@ -729,7 +645,7 @@
         )

         QUERY = f"""
-            INSERT INTO {self._get_table_name("community_report")} ({columns})
+            INSERT INTO {self._get_table_name("community")} ({columns})
             VALUES ({placeholders})
             ON CONFLICT (community_number, level, collection_id) DO UPDATE SET
                 {conflict_columns}
@@ -773,7 +689,7 @@ async def _cluster_and_add_community_info(
         await self.connection_manager.execute_query(QUERY, [collection_id])

         QUERY = f"""
-            DELETE FROM {self._get_table_name("community_report")} WHERE collection_id = $1
+            DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1
         """
         await self.connection_manager.execute_query(QUERY, [collection_id])

@@ -947,7 +863,7 @@ async def _incremental_clustering(

         # delete the communities information for the updated communities
         QUERY = f"""
-            DELETE FROM {self._get_table_name("community_report")} WHERE collection_id = $1 AND community_number = ANY($2)
+            DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number = ANY($2)
         """
         await self.connection_manager.execute_query(
             QUERY, [collection_id, updated_communities]
         )
@@ -1127,7 +1043,7 @@ async def get_community_details(
     ##########  Entity CRUD Operations  ##########################
     ############################################################

-    async def create_entities(
+    async def create_entities_v3(
         self, level: EntityLevel, id: UUID, entities: list[Entity]
     ) -> None:

@@ -1218,21 +1134,21 @@ async def delete_relationship(self, relationship_id: UUID) -> None:
     ##########  Community CRUD Operations  #######################
     ############################################################

-    async def get_community_reports(
+    async def get_communities(
         self, collection_id: UUID
-    ) -> list[CommunityReport]:
+    ) -> list[Community]:
         QUERY = f"""
-            SELECT *c FROM {self._get_table_name("community_report")} WHERE collection_id = $1
+            SELECT * FROM {self._get_table_name("community")} WHERE collection_id = $1
         """
         return await self.connection_manager.fetch_query(
             QUERY, [collection_id]
         )

-    async def check_community_reports_exist(
+    async def check_communities_exist(
         self, collection_id: UUID, offset: int, limit: int
     ) -> list[int]:
         QUERY = f"""
-            SELECT distinct community_number FROM {self._get_table_name("community_report")} WHERE collection_id = $1 AND community_number >= $2 AND community_number < $3
+            SELECT distinct community_number FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number >= $2 AND community_number < $3
         """
         community_numbers = await self.connection_manager.fetch_query(
             QUERY, [collection_id, offset, offset + limit]
         )
@@ -1256,7 +1172,7 @@ async def delete_graph_for_collection(

         # remove all relationships for these documents.
DELETE_QUERIES = [ f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1;", - f"DELETE FROM {self._get_table_name('community_report')} WHERE collection_id = $1;", + f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1;", ] # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. @@ -1342,7 +1258,7 @@ async def delete_node_via_document_id( # If it's the last document, delete collection-related data collection_queries = [ f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1", - f"DELETE FROM {self._get_table_name('community_report')} WHERE collection_id = $1", + f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1", ] for query in collection_queries: await self.connection_manager.execute_query( @@ -1440,33 +1356,33 @@ async def get_creation_estimate( total_in_out_tokens[1] * 10 / 60, ) # 10 minutes per million tokens - return KGCreationEstimationResponse( - message='Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', - document_count=len(document_ids), - number_of_jobs_created=len(document_ids) + 1, - total_chunks=total_chunks, - estimated_entities=self._get_str_estimation_output( + return { + "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "document_count": len(document_ids), + "number_of_jobs_created": len(document_ids) + 1, + "total_chunks": total_chunks, + "estimated_entities": self._get_str_estimation_output( estimated_entities ), - estimated_relationships=self._get_str_estimation_output( + "estimated_relationships": self._get_str_estimation_output( estimated_relationships ), - estimated_llm_calls=self._get_str_estimation_output( + "estimated_llm_calls": self._get_str_estimation_output( estimated_llm_calls ), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( + "estimated_total_in_out_tokens_in_millions": self._get_str_estimation_output( total_in_out_tokens ), - estimated_cost_in_usd=self._get_str_estimation_output( + "estimated_cost_in_usd": self._get_str_estimation_output( estimated_cost ), - estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + self._get_str_estimation_output(total_time_in_minutes), - ) + } async def get_enrichment_estimate( self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ) -> KGEnrichmentEstimationResponse: + ): document_ids = [ doc.id @@ -1514,22 +1430,22 @@ async def get_enrichment_estimate( estimated_total_in_out_tokens_in_millions[1] * 10 / 60, ) - return KGEnrichmentEstimationResponse( - message='Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', - total_entities=entity_count, - total_relationships=relationship_count, - estimated_llm_calls=self._get_str_estimation_output( + return { + "message": 'Ran Graph Enrichment Estimate (not the actual run). 
Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "total_entities": entity_count, + "total_relationships": relationship_count, + "estimated_llm_calls": self._get_str_estimation_output( estimated_llm_calls ), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( + "estimated_total_in_out_tokens_in_millions": self._get_str_estimation_output( estimated_total_in_out_tokens_in_millions ), - estimated_cost_in_usd=self._get_str_estimation_output( + "estimated_cost_in_usd": self._get_str_estimation_output( estimated_cost ), - estimated_total_time_in_minutes="Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + self._get_str_estimation_output(estimated_total_time), - ) + } async def create_vector_index(self): # need to implement this. Just call vector db provider's create_vector_index method. diff --git a/py/core/providers/database/prompts/graphrag_community_reports.yaml b/py/core/providers/database/prompts/graphrag_communities.yaml similarity index 99% rename from py/core/providers/database/prompts/graphrag_community_reports.yaml rename to py/core/providers/database/prompts/graphrag_communities.yaml index 8b78b94d7..be68b3d6e 100644 --- a/py/core/providers/database/prompts/graphrag_community_reports.yaml +++ b/py/core/providers/database/prompts/graphrag_communities.yaml @@ -1,4 +1,4 @@ -graphrag_community_reports: +graphrag_communities: template: | You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network. 
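Since get_creation_estimate and get_enrichment_estimate now return plain dicts rather than the old response models, callers switch from attribute access to key access. A minimal sketch, not part of the patch: the provider handle, the settings constructor, and the placeholder collection UUID are assumptions; the keys are the ones in the dict literals above.

    # Hypothetical usage sketch under the assumptions stated above.
    from uuid import UUID

    from core.base.abstractions import KGEnrichmentSettings  # assumed import path

    async def show_enrichment_estimate(provider) -> None:
        estimate = await provider.get_enrichment_estimate(
            collection_id=UUID("9fbe403b-c11c-5aae-8ade-ef22980c3ad1"),  # placeholder
            kg_enrichment_settings=KGEnrichmentSettings(),
        )
        # key access replaces e.g. response.estimated_cost_in_usd
        print(estimate["message"])
        print(estimate["estimated_llm_calls"])
        print(estimate["estimated_cost_in_usd"])
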
diff --git a/py/sdk/v3/graphs.py b/py/sdk/v3/graphs.py index a0fb7cb7e..bcdaa8413 100644 --- a/py/sdk/v3/graphs.py +++ b/py/sdk/v3/graphs.py @@ -534,7 +534,7 @@ async def tune_prompt( Args: collection_id (Union[str, UUID]): Collection ID to tune prompt for prompt_name (str): Name of prompt to tune (graphrag_relationships_extraction_few_shot, - graphrag_entity_description, or graphrag_community_reports) + graphrag_entity_description, or graphrag_communities) documents_offset (int): Document pagination offset documents_limit (int): Maximum number of documents to use chunks_offset (int): Chunk pagination offset diff --git a/py/shared/abstractions/__init__.py b/py/shared/abstractions/__init__.py index f0cfb9cbb..60f834134 100644 --- a/py/shared/abstractions/__init__.py +++ b/py/shared/abstractions/__init__.py @@ -14,7 +14,7 @@ from .exception import R2RDocumentProcessingError, R2RException from .graph import ( Community, - CommunityReport, + Community, Entity, EntityType, KGExtraction, @@ -93,7 +93,7 @@ "EntityType", "RelationshipType", "Community", - "CommunityReport", + "Community", "KGExtraction", "Relationship", # LLM abstractions diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index b6b4417e3..b5a32648e 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -159,75 +159,6 @@ def from_dict( # type: ignore attributes=d.get(attributes_key, {}), ) - -@dataclass -class Community(BaseModel): - """A protocol for a community in the system.""" - - id: int | None = None - """The ID of the community.""" - - community_number: int | None = None - """The community number.""" - - collection_id: uuid.UUID | None = None - """The ID of the collection this community is associated with.""" - - level: int | None = None - """Community level.""" - - name: str = "" - """The name of the community.""" - - summary: str = "" - """Summary of the report.""" - - findings: list[str] = [] - """Findings of the report.""" - - rating: float | None = None - """Rating of the report.""" - - rating_explanation: str | None = None - """Explanation of the rating.""" - - embedding: list[float] | None = None - """Embedding of summary and findings.""" - - attributes: dict[str, Any] | None = None - """A dictionary of additional attributes associated with the community (optional). 
To be included in the search prompt.""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if isinstance(self.attributes, str): - self.attributes = json.loads(self.attributes) - - @classmethod - def from_dict( - cls, - d: dict[str, Any], - id_key: str = "id", - title_key: str = "title", - short_id_key: str = "short_id", - level_key: str = "level", - entities_key: str = "entity_ids", - relationships_key: str = "relationship_ids", - covariates_key: str = "covariate_ids", - attributes_key: str = "attributes", - ) -> "Community": - """Create a new community from the dict data.""" - return Community( - id=d[id_key], - title=d[title_key], - short_id=d.get(short_id_key), - level=d[level_key], - entity_ids=d.get(entities_key), - relationship_ids=d.get(relationships_key), - covariate_ids=d.get(covariates_key), - attributes=d.get(attributes_key), - ) - - @dataclass class CommunityInfo(BaseModel): """A protocol for a community in the system.""" @@ -257,7 +188,7 @@ def from_dict(cls, d: dict[str, Any]) -> "CommunityInfo": @dataclass -class CommunityReport(BaseModel): +class Community(BaseModel): """Defines an LLM-generated summary report of a community.""" community_number: int @@ -309,9 +240,9 @@ def from_dict( summary_embedding_key: str = "summary_embedding", embedding_key: str = "embedding", attributes_key: str = "attributes", - ) -> "CommunityReport": + ) -> "Community": """Create a new community report from the dict data.""" - return CommunityReport( + return Community( id=d[id_key], title=d[title_key], community_number=d[community_number_key], diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index 699a5e1d6..a942a2bdf 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -110,10 +110,10 @@ class KGEnrichmentSettings(R2RSerializable): description="Force run the enrichment step even if graph creation is still in progress for some documents.", ) - graphrag_community_reports: str = Field( - default="graphrag_community_reports", + graphrag_communities: str = Field( + default="graphrag_communities", description="The prompt to use for knowledge graph enrichment.", - alias="graphrag_community_reports", # TODO - mark deprecated & remove + alias="graphrag_communities", # TODO - mark deprecated & remove ) max_summary_input_length: int = Field( diff --git a/py/shared/abstractions/vector.py b/py/shared/abstractions/vector.py index 85742da8b..1975c0b5d 100644 --- a/py/shared/abstractions/vector.py +++ b/py/shared/abstractions/vector.py @@ -119,7 +119,7 @@ class VectorTableName(str, Enum): ENTITIES_COLLECTION = "collection_entity" # TODO: Add support for relationships # TRIPLES = "chunk_relationship" - COMMUNITIES = "community_report" + COMMUNITIES = "community" def __str__(self) -> str: return self.value diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index 3cde5a192..165c0f604 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from shared.abstractions.base import R2RSerializable -from shared.abstractions.graph import CommunityReport, Entity, Relationship +from shared.abstractions.graph import Community, Entity, Relationship from shared.api.models.base import ResultsWrapper, PaginatedResultsWrapper @@ -239,7 +239,7 @@ class Config: class KGCommunitiesResponse(R2RSerializable): """Response for knowledge graph communities.""" - communities: list[CommunityReport] = Field( + communities: list[Community] = Field( 
..., description="The list of communities in the graph for the collection.", ) diff --git a/py/shared/api/models/kg/responses_v3.py b/py/shared/api/models/kg/responses_v3.py index b2f585142..af6bd2ad1 100644 --- a/py/shared/api/models/kg/responses_v3.py +++ b/py/shared/api/models/kg/responses_v3.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field from shared.abstractions.base import R2RSerializable -from shared.abstractions.graph import CommunityReport, Entity, Relationship +from shared.abstractions.graph import Community, Entity, Relationship from shared.api.models.base import ResultsWrapper, PaginatedResultsWrapper @@ -229,6 +229,28 @@ class Config: } +class KGDeletionResponse(BaseModel): + """Response for knowledge graph deletion.""" + + message: str = Field( + ..., + description="The message to display to the user.", + ) + id: UUID = Field( + ..., + description="The ID of the deleted graph.", + ) + + class Config: + json_schema_extra = { + "example": { + "message": "Entity deleted successfully.", + "id": "123e4567-e89b-12d3-a456-426614174000", + } + } + + + class KGTunePromptResponse(R2RSerializable): """Response containing just the tuned prompt string.""" @@ -244,7 +266,7 @@ class Config: # GET WrappedKGEntitiesResponse = PaginatedResultsWrapper[list[Entity]] WrappedKGRelationshipsResponse = PaginatedResultsWrapper[list[Relationship]] -WrappedKGCommunitiesResponse = PaginatedResultsWrapper[list[CommunityReport]] +WrappedKGCommunitiesResponse = PaginatedResultsWrapper[list[Community]] # CREATE @@ -256,3 +278,4 @@ class Config: WrappedKGEntityDeduplicationResponse = ResultsWrapper[ KGEntityDeduplicationResponse ] +WrappedKGDeletionResponse = ResultsWrapper[KGDeletionResponse] \ No newline at end of file diff --git a/py/tests/core/pipes/test_kg_community_summary_pipe.py b/py/tests/core/pipes/test_kg_community_summary_pipe.py index 33cb7667a..58c683ac9 100644 --- a/py/tests/core/pipes/test_kg_community_summary_pipe.py +++ b/py/tests/core/pipes/test_kg_community_summary_pipe.py @@ -6,7 +6,7 @@ from core.base import ( AsyncPipe, Community, - CommunityReport, + Community, Entity, KGExtraction, Relationship, diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py index 16c96372d..84f41dd6e 100644 --- a/py/tests/core/providers/kg/test_kg_logic.py +++ b/py/tests/core/providers/kg/test_kg_logic.py @@ -6,7 +6,7 @@ from core.base import ( Community, - CommunityReport, + Community, Entity, KGExtraction, Relationship, @@ -161,9 +161,9 @@ def kg_extractions( @pytest.fixture(scope="function") -def community_report_list(embedding_vectors, collection_id): +def community_list(embedding_vectors, collection_id): return [ - CommunityReport( + Community( community_number=1, level=0, collection_id=collection_id, @@ -174,7 +174,7 @@ def community_report_list(embedding_vectors, collection_id): findings=["Findings of the community report"], embedding=embedding_vectors[0], ), - CommunityReport( + Community( community_number=2, level=0, collection_id=collection_id, @@ -254,32 +254,6 @@ async def test_add_relationships( assert len(relationships["relationships"]) == 2 assert relationships["total_entries"] == 2 - -@pytest.mark.asyncio -async def test_add_kg_extractions( - postgres_db_provider, kg_extractions, collection_id -): - added_extractions = await postgres_db_provider.add_kg_extractions( - kg_extractions, table_prefix="chunk_" - ) - - assert added_extractions == (2, 2) - - entities = await postgres_db_provider.get_entities( - collection_id, 
entity_table_name="chunk_entity" - ) - assert entities["entities"][0].name == "Entity1" - assert entities["entities"][1].name == "Entity2" - assert len(entities["entities"]) == 2 - assert entities["total_entries"] == 2 - - relationships = await postgres_db_provider.get_relationships(collection_id) - assert relationships["relationships"][0].subject == "Entity1" - assert relationships["relationships"][1].subject == "Entity2" - assert len(relationships["relationships"]) == 2 - assert relationships["total_entries"] == 2 - - @pytest.mark.asyncio async def test_get_entity_map( postgres_db_provider, @@ -320,7 +294,7 @@ async def test_upsert_embeddings( for entity in entities_list ] - await postgres_db_provider.upsert_embeddings( + await postgres_db_provider.add_entities( entities_list_to_upsert, table_name ) @@ -344,10 +318,10 @@ async def test_get_all_relationships( @pytest.mark.asyncio async def test_get_communities( - postgres_db_provider, collection_id, community_report_list + postgres_db_provider, collection_id, community_list ): - await postgres_db_provider.add_community_report(community_report_list[0]) - await postgres_db_provider.add_community_report(community_report_list[1]) + await postgres_db_provider.add_community(community_list[0]) + await postgres_db_provider.add_community(community_list[1]) communities = await postgres_db_provider.get_communities(collection_id) assert communities["communities"][0].name == "Community Report 1" assert len(communities["communities"]) == 2 @@ -392,7 +366,7 @@ async def test_get_community_details( entities_list, relationships_raw_list, collection_id, - community_report_list, + community_list, community_table_info, ): @@ -403,7 +377,7 @@ async def test_get_community_details( relationships_raw_list, table_name="chunk_relationship" ) await postgres_db_provider.add_community_info(community_table_info) - await postgres_db_provider.add_community_report(community_report_list[0]) + await postgres_db_provider.add_community(community_list[0]) community_level, entities, relationships = ( await postgres_db_provider.get_community_details( From 819d6ddcdc724d0231aa0501a72c110e58971daa Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 13 Nov 2024 15:43:34 -0800 Subject: [PATCH 09/21] up --- py/core/base/__init__.py | 9 +- py/core/base/abstractions/__init__.py | 3 +- py/core/base/providers/__init__.py | 4 +- py/core/base/providers/database.py | 609 +++-- py/core/pipes/kg/storage.py | 2 +- py/core/providers/database/kg.py | 2254 +++++++++-------- py/core/providers/database/kg_tmp/__init__.py | 0 .../providers/database/kg_tmp/community.py | 0 .../database/kg_tmp/community_info.py | 0 py/core/providers/database/kg_tmp/entity.py | 8 + py/core/providers/database/kg_tmp/graph.py | 0 py/core/providers/database/kg_tmp/main.py | 52 + .../providers/database/kg_tmp/relationship.py | 0 py/core/providers/database/postgres.py | 10 +- py/shared/abstractions/graph.py | 39 +- 15 files changed, 1708 insertions(+), 1282 deletions(-) create mode 100644 py/core/providers/database/kg_tmp/__init__.py create mode 100644 py/core/providers/database/kg_tmp/community.py create mode 100644 py/core/providers/database/kg_tmp/community_info.py create mode 100644 py/core/providers/database/kg_tmp/entity.py create mode 100644 py/core/providers/database/kg_tmp/graph.py create mode 100644 py/core/providers/database/kg_tmp/main.py create mode 100644 py/core/providers/database/kg_tmp/relationship.py diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index e514209cc..8d59f5ff0 
100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -33,6 +33,11 @@ "Entity", "KGExtraction", "Relationship", + "Community", + "CommunityInfo", + "KGCreationSettings", + "KGEnrichmentSettings", + "KGRunType", # LLM abstractions "GenerationConfig", "LLMChatCompletion", @@ -48,10 +53,6 @@ "VectorSearchSettings", "DocumentSearchSettings", "HybridSearchSettings", - # KG abstractions - "KGCreationSettings", - "KGEnrichmentSettings", - "KGRunType", # User abstractions "Token", "TokenData", diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 2e3d56a85..6daf753aa 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -26,6 +26,7 @@ Entity, EntityLevel, EntityType, + Graph, KGExtraction, RelationshipType, Relationship, @@ -110,7 +111,7 @@ "EntityType", "RelationshipType", "Community", - "Community", + "CommunityInfo", "KGExtraction", "Relationship", "EntityLevel", diff --git a/py/core/base/providers/__init__.py b/py/core/base/providers/__init__.py index 37af2b8f3..810bbdc70 100644 --- a/py/core/base/providers/__init__.py +++ b/py/core/base/providers/__init__.py @@ -8,7 +8,7 @@ DatabaseProvider, DocumentHandler, FileHandler, - KGHandler, + GraphHandler, LoggingHandler, PostgresConfigurationSettings, PromptHandler, @@ -48,7 +48,7 @@ "UserHandler", "LoggingHandler", "VectorHandler", - "KGHandler", + "GraphHandler", "PromptHandler", "FileHandler", "DatabaseConfig", diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 0dcae7472..ecb9ed497 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -15,9 +15,11 @@ from pydantic import BaseModel -from core.base import ( +from core.base.abstractions import ( Community, + CommunityInfo, Entity, + Graph, KGExtraction, Message, Relationship, @@ -593,283 +595,408 @@ async def list_chunks( pass -class KGHandler(Handler): - """Base handler for Knowledge Graph operations.""" - - @abstractmethod - async def create_tables(self) -> None: - """Create required database tables.""" - pass - - ### ENTITIES CRUD OPS ### +# class GraphHandler(Handler): +# """Base handler for Knowledge Graph operations.""" + +# @abstractmethod +# async def create_tables(self) -> None: +# """Create required database tables.""" +# pass + +# ### ENTITIES CRUD OPS ### +# @abstractmethod +# async def create_entities( +# self, +# entities: list[Entity], +# table_name: str, +# conflict_columns: list[str] = [], +# ) -> Any: +# """Add entities to storage.""" +# pass + +# @abstractmethod +# async def get_entities( +# self, +# level: EntityLevel, +# entity_names: Optional[list[str]] = None, +# attributes: Optional[list[str]] = None, +# offset: int = 0, +# limit: int = -1, +# ) -> dict: +# """Get entities from storage.""" +# pass + +# @abstractmethod +# async def update_entity( +# self, +# entity: Entity, +# table_name: str, +# conflict_columns: list[str] = [], +# ) -> Any: +# """Update an entity in storage.""" +# pass + +# @abstractmethod +# async def delete_entity( +# self, +# id: UUID, +# chunk_id: Optional[UUID] = None, +# document_id: Optional[UUID] = None, +# collection_id: Optional[UUID] = None, +# graph_id: Optional[UUID] = None, +# ) -> None: +# """Delete an entity from storage.""" +# pass + +# ### RELATIONSHIPS CRUD OPS ### +# @abstractmethod +# async def add_relationships( +# self, +# relationships: list[Relationship], +# table_name: str = "chunk_relationship", +# ) -> None: +# """Add relationships to storage.""" +# 
pass + +# @abstractmethod +# async def get_entity_map( +# self, offset: int, limit: int, document_id: UUID +# ) -> dict[str, dict[str, list[dict[str, Any]]]]: +# """Get entity map for a document.""" +# pass + +# @abstractmethod +# async def graph_search( +# self, query: str, **kwargs: Any +# ) -> AsyncGenerator[Any, None]: +# """Perform vector similarity search.""" +# pass + +# # Community management +# @abstractmethod +# async def add_community_info(self, communities: list[Any]) -> None: +# """Add communities to storage.""" +# pass + +# @abstractmethod +# async def get_communities( +# self, +# offset: int, +# limit: int, +# collection_id: Optional[UUID] = None, +# levels: Optional[list[int]] = None, +# community_numbers: Optional[list[int]] = None, +# ) -> dict: +# """Get communities for a collection.""" +# pass + +# @abstractmethod +# async def add_community( +# self, community: Community +# ) -> None: +# """Add a community report.""" +# pass + +# @abstractmethod +# async def get_community_details( +# self, community_number: int, collection_id: UUID +# ) -> Tuple[int, list[Entity], list[Relationship]]: +# """Get detailed information about a community.""" +# pass + +# @abstractmethod +# async def get_community( +# self, collection_id: UUID +# ) -> list[Community]: +# """Get community reports for a collection.""" +# pass + +# @abstractmethod +# async def check_community_exists( +# self, collection_id: UUID, offset: int, limit: int +# ) -> list[int]: +# """Check which community reports exist.""" +# pass + +# @abstractmethod +# async def perform_graph_clustering( +# self, +# collection_id: UUID, +# leiden_params: dict[str, Any], +# ) -> int: +# """Perform graph clustering.""" +# pass + +# # Graph operations +# @abstractmethod +# async def delete_graph_for_collection( +# self, collection_id: UUID, cascade: bool = False +# ) -> None: +# """Delete graph data for a collection.""" +# pass + +# @abstractmethod +# async def delete_node_via_document_id( +# self, document_id: UUID, collection_id: UUID +# ) -> None: +# """Delete a node using document ID.""" +# pass + +# # Entity and Relationship management +# @abstractmethod +# async def get_entities( +# self, +# offset: int, +# limit: int, +# collection_id: Optional[UUID] = None, +# entity_ids: Optional[list[str]] = None, +# entity_names: Optional[list[str]] = None, +# entity_table_name: str = "document_entity", +# extra_columns: Optional[list[str]] = None, +# ) -> dict: +# """Get entities from storage.""" +# pass + +# @abstractmethod +# async def create_entities_v3(self, entities: list[Entity]) -> None: +# """Create entities in storage.""" +# pass + + +# @abstractmethod +# async def get_relationships( +# self, +# offset: int, +# limit: int, +# collection_id: Optional[UUID] = None, +# entity_names: Optional[list[str]] = None, +# relationship_ids: Optional[list[str]] = None, +# ) -> dict: +# """Get relationships from storage.""" +# pass + +# @abstractmethod +# async def get_entity_count( +# self, +# collection_id: Optional[UUID] = None, +# document_id: Optional[UUID] = None, +# distinct: bool = False, +# entity_table_name: str = "document_entity", +# ) -> int: +# """Get entity count.""" +# pass + +# @abstractmethod +# async def get_relationship_count( +# self, +# collection_id: Optional[UUID] = None, +# document_id: Optional[UUID] = None, +# ) -> int: +# """Get relationship count.""" +# pass + +# # Cost estimation methods +# @abstractmethod +# async def get_creation_estimate( +# self, collection_id: UUID, kg_creation_settings: KGCreationSettings +# ): 
+# """Get creation cost estimate.""" +# pass + +# @abstractmethod +# async def get_enrichment_estimate( +# self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings +# ): +# """Get enrichment cost estimate.""" +# pass + +# @abstractmethod +# async def get_deduplication_estimate( +# self, +# collection_id: UUID, +# kg_deduplication_settings: KGEntityDeduplicationSettings, +# ): +# """Get deduplication cost estimate.""" +# pass + +# # Other operations +# @abstractmethod +# async def create_vector_index(self) -> None: +# """Create vector index.""" +# raise NotImplementedError + +# @abstractmethod +# async def delete_relationships(self, relationship_ids: list[int]) -> None: +# """Delete relationships.""" +# raise NotImplementedError + +# @abstractmethod +# async def get_schema(self) -> Any: +# """Get schema.""" +# raise NotImplementedError + +# @abstractmethod +# async def structured_query(self) -> Any: +# """Perform structured query.""" +# raise NotImplementedError + +# @abstractmethod +# async def update_extraction_prompt(self) -> None: +# """Update extraction prompt.""" +# raise NotImplementedError + +# @abstractmethod +# async def update_kg_search_prompt(self) -> None: +# """Update KG search prompt.""" +# raise NotImplementedError + +# @abstractmethod +# async def upsert_relationships(self) -> None: +# """Upsert relationships.""" +# raise NotImplementedError + +# @abstractmethod +# async def get_existing_entity_extraction_ids( +# self, document_id: UUID +# ) -> list[str]: +# """Get existing entity extraction IDs.""" +# raise NotImplementedError + +# @abstractmethod +# async def get_all_relationships( +# self, collection_id: UUID +# ) -> list[Relationship]: +# raise NotImplementedError + +# @abstractmethod +# async def update_entity_descriptions(self, entities: list[Entity]): +# raise NotImplementedError + + + +class EntityHandler(Handler): + @abstractmethod - async def create_entities( - self, - entities: list[Entity], - table_name: str, - conflict_columns: list[str] = [], - ) -> Any: - """Add entities to storage.""" + async def create(self, *args: Any, **kwargs: Any) -> None: + """Create entities in storage.""" pass @abstractmethod - async def get_entities( - self, - level: EntityLevel, - entity_names: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - offset: int = 0, - limit: int = -1, - ) -> dict: + async def get(self, *args: Any, **kwargs: Any) -> list[Entity]: """Get entities from storage.""" pass @abstractmethod - async def update_entity( - self, - entity: Entity, - table_name: str, - conflict_columns: list[str] = [], - ) -> Any: - """Update an entity in storage.""" - pass - - @abstractmethod - async def delete_entity( - self, - id: UUID, - chunk_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - collection_id: Optional[UUID] = None, - graph_id: Optional[UUID] = None, - ) -> None: - """Delete an entity from storage.""" + async def update(self, *args: Any, **kwargs: Any) -> None: + """Update entities in storage.""" pass - ### RELATIONSHIPS CRUD OPS ### @abstractmethod - async def add_relationships( - self, - relationships: list[Relationship], - table_name: str = "chunk_relationship", - ) -> None: - """Add relationships to storage.""" + async def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete entities from storage.""" pass - @abstractmethod - async def get_entity_map( - self, offset: int, limit: int, document_id: UUID - ) -> dict[str, dict[str, list[dict[str, Any]]]]: - """Get entity map for a document.""" - pass +class 
RelationshipHandler(Handler): @abstractmethod - async def graph_search( - self, query: str, **kwargs: Any - ) -> AsyncGenerator[Any, None]: - """Perform vector similarity search.""" + async def create(self, *args: Any, **kwargs: Any) -> None: + """Add relationships to storage.""" pass - # Community management @abstractmethod - async def add_community_info(self, communities: list[Any]) -> None: - """Add communities to storage.""" + async def get(self, *args: Any, **kwargs: Any) -> list[Relationship]: + """Get relationships from storage.""" pass @abstractmethod - async def get_communities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_numbers: Optional[list[int]] = None, - ) -> dict: - """Get communities for a collection.""" + async def update(self, *args: Any, **kwargs: Any) -> None: + """Update relationships in storage.""" pass @abstractmethod - async def add_community( - self, community: Community - ) -> None: - """Add a community report.""" + async def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete relationships from storage.""" pass +class CommunityHandler(Handler): @abstractmethod - async def get_community_details( - self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Relationship]]: - """Get detailed information about a community.""" + async def create(self, *args: Any, **kwargs: Any) -> None: + """Create communities in storage.""" pass @abstractmethod - async def get_community( - self, collection_id: UUID - ) -> list[Community]: - """Get community reports for a collection.""" + async def get(self, *args: Any, **kwargs: Any) -> list[Community]: + """Get communities from storage.""" pass @abstractmethod - async def check_community_exists( - self, collection_id: UUID, offset: int, limit: int - ) -> list[int]: - """Check which community reports exist.""" + async def update(self, *args: Any, **kwargs: Any) -> None: + """Update communities in storage.""" pass @abstractmethod - async def perform_graph_clustering( - self, - collection_id: UUID, - leiden_params: dict[str, Any], - ) -> int: - """Perform graph clustering.""" + async def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete communities from storage.""" pass - # Graph operations +class CommunityInfoHandler(Handler): @abstractmethod - async def delete_graph_for_collection( - self, collection_id: UUID, cascade: bool = False - ) -> None: - """Delete graph data for a collection.""" + async def create(self, *args: Any, **kwargs: Any) -> None: + """Create community info in storage.""" pass @abstractmethod - async def delete_node_via_document_id( - self, document_id: UUID, collection_id: UUID - ) -> None: - """Delete a node using document ID.""" - pass + async def get(self, *args: Any, **kwargs: Any) -> list[CommunityInfo]: + """Get community info from storage.""" + pass - # Entity and Relationship management @abstractmethod - async def get_entities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_names: Optional[list[str]] = None, - entity_table_name: str = "document_entity", - extra_columns: Optional[list[str]] = None, - ) -> dict: - """Get entities from storage.""" + async def update(self, *args: Any, **kwargs: Any) -> None: + """Update community info in storage.""" pass @abstractmethod - async def create_entities_v3(self, entities: list[Entity]) -> None: - """Create entities in storage.""" + async def delete(self, *args: 
Any, **kwargs: Any) -> None: + """Delete community info from storage.""" pass +class GraphHandler(Handler): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) @abstractmethod - async def get_relationships( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - ) -> dict: - """Get relationships from storage.""" + async def create(self, *args: Any, **kwargs: Any) -> None: + """Create graph in storage.""" pass @abstractmethod - async def get_entity_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - distinct: bool = False, - entity_table_name: str = "document_entity", - ) -> int: - """Get entity count.""" + async def get(self, *args: Any, **kwargs: Any) -> list[Graph]: + """Get graph from storage.""" pass @abstractmethod - async def get_relationship_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - ) -> int: - """Get relationship count.""" + async def update(self, *args: Any, **kwargs: Any) -> None: + """Update graph in storage.""" pass - # Cost estimation methods @abstractmethod - async def get_creation_estimate( - self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ): - """Get creation cost estimate.""" + async def delete(self, *args: Any, **kwargs: Any) -> None: + """Delete graph from storage.""" pass + # add documents to the graph @abstractmethod - async def get_enrichment_estimate( - self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ): - """Get enrichment cost estimate.""" + async def add_document(self, *args: Any, **kwargs: Any) -> None: + """Add document to graph.""" pass @abstractmethod - async def get_deduplication_estimate( - self, - collection_id: UUID, - kg_deduplication_settings: KGEntityDeduplicationSettings, - ): - """Get deduplication cost estimate.""" + async def remove_document(self, *args: Any, **kwargs: Any) -> None: + """Delete document from graph.""" pass - # Other operations - @abstractmethod - async def create_vector_index(self) -> None: - """Create vector index.""" - raise NotImplementedError - - @abstractmethod - async def delete_relationships(self, relationship_ids: list[int]) -> None: - """Delete relationships.""" - raise NotImplementedError - - @abstractmethod - async def get_schema(self) -> Any: - """Get schema.""" - raise NotImplementedError - - @abstractmethod - async def structured_query(self) -> Any: - """Perform structured query.""" - raise NotImplementedError - @abstractmethod - async def update_extraction_prompt(self) -> None: - """Update extraction prompt.""" - raise NotImplementedError - - @abstractmethod - async def update_kg_search_prompt(self) -> None: - """Update KG search prompt.""" - raise NotImplementedError - - @abstractmethod - async def upsert_relationships(self) -> None: - """Upsert relationships.""" - raise NotImplementedError - - @abstractmethod - async def get_existing_entity_extraction_ids( - self, document_id: UUID - ) -> list[str]: - """Get existing entity extraction IDs.""" - raise NotImplementedError - - @abstractmethod - async def get_all_relationships( - self, collection_id: UUID - ) -> list[Relationship]: - raise NotImplementedError - - @abstractmethod - async def update_entity_descriptions(self, entities: list[Entity]): - raise NotImplementedError - class PromptHandler(Handler): """Abstract base class for prompt handling operations.""" @@ -1102,7 
+1229,7 @@ class DatabaseProvider(Provider): token_handler: TokenHandler user_handler: UserHandler vector_handler: VectorHandler - kg_handler: KGHandler + graph_handler: GraphHandler prompt_handler: PromptHandler file_handler: FileHandler logging_handler: LoggingHandler @@ -1527,7 +1654,7 @@ async def add_entities( conflict_columns: list[str] = [], ) -> Any: """Forward to KG handler add_entities method.""" - return await self.kg_handler.add_entities( + return await self.graph_handler.add_entities( entities, table_name, conflict_columns ) @@ -1537,7 +1664,7 @@ async def add_relationships( table_name: str = "chunk_relationship", ) -> None: """Forward to KG handler add_relationships method.""" - return await self.kg_handler.add_relationships( + return await self.graph_handler.add_relationships( relationships, table_name ) @@ -1545,12 +1672,12 @@ async def get_entity_map( self, offset: int, limit: int, document_id: UUID ) -> dict[str, dict[str, list[dict[str, Any]]]]: """Forward to KG handler get_entity_map method.""" - return await self.kg_handler.get_entity_map(offset, limit, document_id) + return await self.graph_handler.get_entity_map(offset, limit, document_id) # Community methods async def add_community_info(self, communities: list[Any]) -> None: """Forward to KG handler add_communities method.""" - return await self.kg_handler.add_community_info(communities) + return await self.graph_handler.add_community_info(communities) async def get_communities( self, @@ -1561,7 +1688,7 @@ async def get_communities( community_numbers: Optional[list[int]] = None, ) -> dict: """Forward to KG handler get_communities method.""" - return await self.kg_handler.get_communities( + return await self.graph_handler.get_communities( offset=offset, limit=limit, collection_id=collection_id, @@ -1573,13 +1700,13 @@ async def add_community( self, community: Community ) -> None: """Forward to KG handler add_community method.""" - return await self.kg_handler.add_community(community) + return await self.graph_handler.add_community(community) async def get_community_details( self, community_number: int, collection_id: UUID ) -> Tuple[int, list[Entity], list[Relationship]]: """Forward to KG handler get_community_details method.""" - return await self.kg_handler.get_community_details( + return await self.graph_handler.get_community_details( community_number, collection_id ) @@ -1587,13 +1714,13 @@ async def get_community( self, collection_id: UUID ) -> list[Community]: """Forward to KG handler get_community method.""" - return await self.kg_handler.get_community(collection_id) + return await self.graph_handler.get_community(collection_id) async def check_community_exists( self, collection_id: UUID, offset: int, limit: int ) -> list[int]: """Forward to KG handler check_community_exists method.""" - return await self.kg_handler.check_community_exists( + return await self.graph_handler.check_community_exists( collection_id, offset, limit ) @@ -1603,7 +1730,7 @@ async def perform_graph_clustering( leiden_params: dict[str, Any], ) -> int: """Forward to KG handler perform_graph_clustering method.""" - return await self.kg_handler.perform_graph_clustering( + return await self.graph_handler.perform_graph_clustering( collection_id, leiden_params ) @@ -1612,7 +1739,7 @@ async def delete_graph_for_collection( self, collection_id: UUID, cascade: bool = False ) -> None: """Forward to KG handler delete_graph_for_collection method.""" - return await self.kg_handler.delete_graph_for_collection( + return await 
self.graph_handler.delete_graph_for_collection( collection_id, cascade ) @@ -1620,7 +1747,7 @@ async def delete_node_via_document_id( self, document_id: UUID, collection_id: UUID ) -> None: """Forward to KG handler delete_node_via_document_id method.""" - return await self.kg_handler.delete_node_via_document_id( + return await self.graph_handler.delete_node_via_document_id( document_id, collection_id ) @@ -1636,7 +1763,7 @@ async def get_entities( extra_columns: Optional[list[str]] = None, ) -> dict: """Forward to KG handler get_entities method.""" - return await self.kg_handler.get_entities( + return await self.graph_handler.get_entities( offset=offset, limit=limit, collection_id=collection_id, @@ -1645,8 +1772,6 @@ async def get_entities( entity_table_name=entity_table_name, extra_columns=extra_columns, ) - - async async def get_relationships( self, @@ -1657,7 +1782,7 @@ async def get_relationships( relationship_ids: Optional[list[str]] = None, ) -> dict: """Forward to KG handler get_relationships method.""" - return await self.kg_handler.get_relationships( + return await self.graph_handler.get_relationships( offset=offset, limit=limit, collection_id=collection_id, @@ -1673,7 +1798,7 @@ async def get_entity_count( entity_table_name: str = "document_entity", ) -> int: """Forward to KG handler get_entity_count method.""" - return await self.kg_handler.get_entity_count( + return await self.graph_handler.get_entity_count( collection_id, document_id, distinct, entity_table_name ) @@ -1683,7 +1808,7 @@ async def get_relationship_count( document_id: Optional[UUID] = None, ) -> int: """Forward to KG handler get_relationship_count method.""" - return await self.kg_handler.get_relationship_count( + return await self.graph_handler.get_relationship_count( collection_id, document_id ) @@ -1692,7 +1817,7 @@ async def get_creation_estimate( self, collection_id: UUID, kg_creation_settings: KGCreationSettings ): """Forward to KG handler get_creation_estimate method.""" - return await self.kg_handler.get_creation_estimate( + return await self.graph_handler.get_creation_estimate( collection_id, kg_creation_settings ) @@ -1700,7 +1825,7 @@ async def get_enrichment_estimate( self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings ): """Forward to KG handler get_enrichment_estimate method.""" - return await self.kg_handler.get_enrichment_estimate( + return await self.graph_handler.get_enrichment_estimate( collection_id, kg_enrichment_settings ) @@ -1710,48 +1835,48 @@ async def get_deduplication_estimate( kg_deduplication_settings: KGEntityDeduplicationSettings, ): """Forward to KG handler get_deduplication_estimate method.""" - return await self.kg_handler.get_deduplication_estimate( + return await self.graph_handler.get_deduplication_estimate( collection_id, kg_deduplication_settings ) async def get_all_relationships( self, collection_id: UUID ) -> list[Relationship]: - return await self.kg_handler.get_all_relationships(collection_id) + return await self.graph_handler.get_all_relationships(collection_id) async def update_entity_descriptions(self, entities: list[Entity]): - return await self.kg_handler.update_entity_descriptions(entities) + return await self.graph_handler.update_entity_descriptions(entities) async def graph_search( self, query: str, **kwargs: Any ) -> AsyncGenerator[Any, None]: - return self.kg_handler.graph_search(query, **kwargs) # type: ignore + return self.graph_handler.graph_search(query, **kwargs) # type: ignore async def create_vector_index(self) -> None: - return 
await self.kg_handler.create_vector_index() + return await self.graph_handler.create_vector_index() async def delete_relationships(self, relationship_ids: list[int]) -> None: - return await self.kg_handler.delete_relationships(relationship_ids) + return await self.graph_handler.delete_relationships(relationship_ids) async def get_schema(self) -> Any: - return await self.kg_handler.get_schema() + return await self.graph_handler.get_schema() async def structured_query(self) -> Any: - return await self.kg_handler.structured_query() + return await self.graph_handler.structured_query() async def update_extraction_prompt(self) -> None: - return await self.kg_handler.update_extraction_prompt() + return await self.graph_handler.update_extraction_prompt() async def update_kg_search_prompt(self) -> None: - return await self.kg_handler.update_kg_search_prompt() + return await self.graph_handler.update_kg_search_prompt() async def upsert_relationships(self) -> None: - return await self.kg_handler.upsert_relationships() + return await self.graph_handler.upsert_relationships() async def get_existing_entity_extraction_ids( self, document_id: UUID ) -> list[str]: - return await self.kg_handler.get_existing_entity_extraction_ids( + return await self.graph_handler.get_existing_entity_extraction_ids( document_id ) diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 99c886359..aa572b680 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -54,7 +54,7 @@ async def store( """ try: # clean up and remove this method. - # make add_kg_extractions a method in the KGHandler + # make add_kg_extractions a method in the GraphHandler total_entities, total_relationships = 0, 0 diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 9746b1a70..1059712a9 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -8,15 +8,18 @@ import asyncpg from asyncpg.exceptions import PostgresError, UndefinedTableError -from core.base import ( +from core.base.abstractions import ( Community, Entity, KGExtraction, KGExtractionStatus, - KGHandler, + Graph, R2RException, Relationship, ) + +from core.base.providers.database import GraphHandler, EntityHandler, RelationshipHandler, CommunityHandler, CommunityInfoHandler + from core.base.abstractions import ( CommunityInfo, EntityLevel, @@ -35,8 +38,191 @@ logger = logging.getLogger() -class PostgresKGHandler(KGHandler): - """Handler for Knowledge Graph operations in PostgreSQL.""" +class PostgresEntityHandler(EntityHandler): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.dimension = kwargs.get("dimension") + self.quantization_type = kwargs.get("quantization_type") + + async def create_tables(self) -> None: + + vector_column_str = _decorate_vector_type( + f"({self.dimension})", self.quantization_type + ) + + query = f""" + CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_entity")} ( + id SERIAL PRIMARY KEY, + category TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT NOT NULL, + extraction_ids UUID[] NOT NULL, + document_id UUID NOT NULL, + attributes JSONB + ); + """ + await self.connection_manager.execute_query(query) + + # embeddings tables + query = f""" + CREATE TABLE IF NOT EXISTS {self._get_table_name("document_entity")} ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + description TEXT NOT NULL, + extraction_ids UUID[] NOT NULL, + description_embedding {vector_column_str} NOT NULL, + document_id UUID NOT NULL, 
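+            -- at most one embedding-bearing row per (name, document_id) pair,
+            -- enforced by the UNIQUE constraint below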
+            UNIQUE (name, document_id)
+            );
+        """
+
+        await self.connection_manager.execute_query(query)
+
+        # deduplicated entities table
+        query = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name("collection_entity")} (
+            id SERIAL PRIMARY KEY,
+            name TEXT NOT NULL,
+            description TEXT,
+            extraction_ids UUID[] NOT NULL,
+            document_ids UUID[] NOT NULL,
+            collection_id UUID NOT NULL,
+            description_embedding {vector_column_str},
+            attributes JSONB,
+            UNIQUE (name, collection_id, attributes)
+        );"""
+
+        await self.connection_manager.execute_query(query)
+
+
+class PostgresRelationshipHandler(RelationshipHandler):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+    async def create_tables(self) -> None:
+        pass
+
+
+class PostgresCommunityHandler(CommunityHandler):
+    pass
+
+
+class PostgresCommunityInfoHandler(CommunityInfoHandler):
+    pass
+
+
+class PostgresGraphHandler(GraphHandler):
+    """Handler for knowledge graph operations in PostgreSQL."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.project_name = kwargs.get("project_name")
+        self.connection_manager = kwargs.get("connection_manager")
+        self.dimension = kwargs.get("dimension")
+        self.quantization_type = kwargs.get("quantization_type")
+        self.collection_handler = kwargs.get("collection_handler")
+
+        self.entity_handler = PostgresEntityHandler(*args, **kwargs)
+        self.relationship_handler = PostgresRelationshipHandler(*args, **kwargs)
+        self.community_handler = PostgresCommunityHandler(*args, **kwargs)
+        self.community_info_handler = PostgresCommunityInfoHandler(*args, **kwargs)
+
+        self.handlers = [
+            self.entity_handler,
+            self.relationship_handler,
+            self.community_handler,
+            self.community_info_handler,
+        ]
+
+    async def create_tables(self) -> None:
+        QUERY = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name("graph")} (
+            id UUID PRIMARY KEY,
+            status TEXT NOT NULL,
+            created_at TIMESTAMP NOT NULL,
+            updated_at TIMESTAMP NOT NULL,
+            document_ids UUID[] NOT NULL,
+            collection_ids UUID[] NOT NULL,
+            attributes JSONB NOT NULL
+            );
+        """
+
+        await self.connection_manager.execute_query(QUERY)
+
+        for handler in self.handlers:
+            await handler.create_tables()
+
+    async def create(self, graph: Graph) -> None:
+        QUERY = f"""
+            INSERT INTO {self._get_table_name("graph")} (id, status, created_at, updated_at, document_ids, collection_ids, attributes)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+        """
+        await self.connection_manager.execute_query(QUERY, *graph.to_dict().values())
+
+    async def update(self, graph: Graph) -> None:
+        QUERY = f"""
+            UPDATE {self._get_table_name("graph")} SET status = $2, updated_at = $3, document_ids = $4, collection_ids = $5, attributes = $6 WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, *graph.to_dict().values())
+
+    async def delete(self, graph_id: UUID) -> None:
+        QUERY = f"""
+            DELETE FROM {self._get_table_name("graph")} WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, graph_id)
+
+    async def get(self, graph_id: UUID) -> Graph:
+        QUERY = f"""
+            SELECT * FROM {self._get_table_name("graph")} WHERE id = $1
+        """
+        return Graph.from_dict(await self.connection_manager.fetch_query(QUERY, graph_id))
+
+    async def add_document(self, graph_id: UUID, document_id: UUID) -> None:
+        QUERY = f"""
+            UPDATE {self._get_table_name("graph")} SET document_ids = array_append(document_ids, $2) WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, graph_id, document_id)
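+
+    # Illustrative sketch only, not part of the upstream API: it assumes an
+    # already-initialized `connection_manager` and `collection_handler`, and a
+    # `Graph` whose `to_dict()` yields values in the column order `create`
+    # expects above.
+    #
+    #     handler = PostgresGraphHandler(
+    #         project_name="r2r_default",            # hypothetical name
+    #         connection_manager=connection_manager,
+    #         dimension=512,
+    #         quantization_type=None,
+    #         collection_handler=collection_handler,
+    #     )
+    #     await handler.create_tables()  # graph table plus each sub-handler's tables
+    #     await handler.create(graph)
+    #     await handler.add_document(graph.id, document_id)
+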
+    async def remove_document(self, graph_id: UUID, document_id: UUID) -> None:
+        QUERY = f"""
+            UPDATE {self._get_table_name("graph")} SET document_ids = array_remove(document_ids, $2) WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, graph_id, document_id)
+
+    async def add_collection(self, graph_id: UUID, collection_id: UUID) -> None:
+        QUERY = f"""
+            UPDATE {self._get_table_name("graph")} SET collection_ids = array_append(collection_ids, $2) WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, graph_id, collection_id)
+
+    async def remove_collection(self, graph_id: UUID, collection_id: UUID) -> None:
+        QUERY = f"""
+            UPDATE {self._get_table_name("graph")} SET collection_ids = array_remove(collection_ids, $2) WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, graph_id, collection_id)
+
+
+class PostgresGraphHandler_v1(GraphHandler):
+    """Handler for knowledge graph operations in PostgreSQL."""

     def __init__(
         self,
@@ -66,6 +252,8 @@ def _get_table_name(self, base_name: str) -> str:
         """Get the fully qualified table name."""
         return f"{self.project_name}.{base_name}"

+    ####################### TABLE CREATION METHODS #######################
+
     async def create_tables(self):
         # raw entities table
         # create schema
@@ -169,47 +357,127 @@ async def create_tables(self):

         await self.connection_manager.execute_query(query)

-    async def _add_objects(
+    ################### ENTITY METHODS ###################
+
+    async def get_entities_v3(
         self,
-        objects: list[Any],
-        table_name: str,
-        conflict_columns: list[str] = [],
-    ) -> asyncpg.Record:
-        """
-        Upsert objects into the specified table.
-        """
-        # Get non-null attributes from the first object
-        non_null_attrs = {k: v for k, v in objects[0].items() if v is not None}
-        columns = ", ".join(non_null_attrs.keys())
+        level: EntityLevel,
+        id: Optional[UUID] = None,
+        entity_names: Optional[list[str]] = None,
+        entity_categories: Optional[list[str]] = None,
+        attributes: Optional[list[str]] = None,
+        offset: int = 0,
+        limit: int = -1,
+    ):

-        placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs)))
+        params: list = [id]

-        if conflict_columns:
-            conflict_columns_str = ", ".join(conflict_columns)
-            replace_columns_str = ", ".join(
-                f"{column} = EXCLUDED.{column}" for column in non_null_attrs
+        if level != EntityLevel.CHUNK and entity_categories:
+            raise ValueError(
+                "entity_categories are only supported for chunk level entities"
             )
-            on_conflict_query = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {replace_columns_str}"
-        else:
-            on_conflict_query = ""
+
+        filter = {
+            EntityLevel.CHUNK: "chunk_ids = ANY($1)",
+            EntityLevel.DOCUMENT: "document_id = $1",
+            EntityLevel.COLLECTION: "collection_id = $1",
+        }[level]
+
+        # placeholder numbers are derived from the current length of `params`
+        # so the filters stay correct whichever optional arguments are passed
+        if entity_names:
+            filter += f" AND name = ANY(${len(params) + 1})"
+            params.append(entity_names)
+
+        if entity_categories:
+            filter += f" AND category = ANY(${len(params) + 1})"
+            params.append(entity_categories)

         QUERY = f"""
-            INSERT INTO {self._get_table_name(table_name)} ({columns})
-            VALUES ({placeholders})
-            {on_conflict_query}
+            SELECT * FROM {self._get_table_name(level.table_name)} WHERE {filter}
+            OFFSET ${len(params) + 1} LIMIT ${len(params) + 2}
         """

-        # Filter out null values for each object
-        params = [
-            tuple(
-                (json.dumps(v) if isinstance(v, dict) else v)
-                for v in obj.values()
-                if v is not None
+        params.extend([offset, limit])
+
+        output = await self.connection_manager.fetch_query(QUERY, params)
+
+        if attributes:
+            output = [
+                entity for entity in
output if entity["name"] in attributes + ] + + return output + + # TODO: deprecate this + async def get_entities( + self, + offset: int, + limit: int, + collection_id: Optional[UUID] = None, + entity_ids: Optional[list[str]] = None, + entity_names: Optional[list[str]] = None, + entity_table_name: str = "document_entity", + extra_columns: Optional[list[str]] = None, + ) -> dict: + conditions = [] + params: list = [collection_id] + param_index = 2 + + if entity_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(entity_ids) + param_index += 1 + + if entity_names: + conditions.append(f"name = ANY(${param_index})") + params.append(entity_names) + param_index += 1 + + pagination_params = [] + if offset: + pagination_params.append(f"OFFSET ${param_index}") + params.append(offset) + param_index += 1 + + if limit != -1: + pagination_params.append(f"LIMIT ${param_index}") + params.append(limit) + param_index += 1 + + pagination_clause = " ".join(pagination_params) + + if entity_table_name == "collection_entity": + query = f""" + SELECT id, name, description, extraction_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} + FROM {self._get_table_name(entity_table_name)} + WHERE collection_id = $1 + {" AND " + " AND ".join(conditions) if conditions else ""} + ORDER BY id + {pagination_clause} + """ + else: + query = f""" + SELECT id, name, description, extraction_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} + FROM {self._get_table_name(entity_table_name)} + WHERE document_id = ANY( + SELECT document_id FROM {self._get_table_name("document_info")} + WHERE $1 = ANY(collection_ids) ) - for obj in objects - ] + {" AND " + " AND ".join(conditions) if conditions else ""} + ORDER BY id + {pagination_clause} + """ + + results = await self.connection_manager.fetch_query(query, params) + entities = [Entity(**entity) for entity in results] + + total_entries = await self.get_entity_count( + collection_id=collection_id, entity_table_name=entity_table_name + ) + + return {"entities": entities, "total_entries": total_entries} - return await self.connection_manager.execute_many(QUERY, params) # type: ignore async def add_entities( self, @@ -246,61 +514,100 @@ async def add_entities( cleaned_entities, table_name, conflict_columns ) - async def get_graph_status(self, collection_id: UUID) -> dict: - # check document_info table for the documents in the collection and return the status of each document - kg_extraction_statuses = await self.connection_manager.fetch_query( - f"SELECT document_id, kg_extraction_status FROM {self._get_table_name('document_info')} WHERE collection_id = $1", - [collection_id], + + async def create_entities_v3( + self, level: EntityLevel, id: UUID, entities: list[Entity] + ) -> None: + + # TODO: check if already exists + await self._add_objects(entities, level.table_name) + + async def update_entity(self, collection_id: UUID, entity: Entity) -> None: + table_name = entity.level.value + "_entity" + + # check if the entity already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + count = ( + await self.connection_manager.fetch_query( + QUERY, [entity.id, collection_id] + ) + )[0]["count"] + + if count == 0: + raise R2RException("Entity does not exist", 404) + + await self._add_objects([entity], table_name) + + async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: + + table_name = entity.level.value + "_entity" + QUERY = f""" + 
DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + await self.connection_manager.execute_query( + QUERY, [entity.id, collection_id] ) - document_ids = [ - doc_id["document_id"] for doc_id in kg_extraction_statuses + + async def delete_node_via_document_id( + self, document_id: UUID, collection_id: UUID + ) -> None: + # don't delete if status is PROCESSING. + QUERY = f""" + SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 + """ + status = ( + await self.connection_manager.fetch_query(QUERY, [collection_id]) + )[0]["kg_enrichment_status"] + if status == KGExtractionStatus.PROCESSING.value: + return + + # Execute separate DELETE queries + delete_queries = [ + f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = $1", + f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = $1", + f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = $1", ] - kg_enrichment_statuses = await self.connection_manager.fetch_query( - f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", - [collection_id], - ) + for query in delete_queries: + await self.connection_manager.execute_query(query, [document_id]) - # entity and relationship counts - chunk_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1)", - [document_ids], + # Check if this is the last document in the collection + # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + documents = await self.collection_handler.documents_in_collection( + offset=0, + limit=100, + collection_id=collection_id, ) + count = documents["total_entries"] - chunk_relationship_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1)", - [document_ids], - ) + if count == 0: + # If it's the last document, delete collection-related data + collection_queries = [ + f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1", + f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1", + ] + for query in collection_queries: + await self.connection_manager.execute_query( + query, [collection_id] + ) # Ensure collection_id is in a list - document_entity_count = await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1)", - [document_ids], - ) + # set status to PENDING for this collection. 
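+            # with the community rows gone, PENDING lets the next enrichment
+            # run rebuild this collection's graph from scratch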
+            QUERY = f"""
+                UPDATE {self._get_table_name("collections")} SET kg_enrichment_status = $1 WHERE collection_id = $2
+            """
+            await self.connection_manager.execute_query(
+                QUERY, [KGExtractionStatus.PENDING, collection_id]
+            )
+            return None
+        return None

-        collection_entity_count = await self.connection_manager.fetch_query(
-            f"SELECT COUNT(*) FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1",
-            [collection_id],
-        )

-        community_count = await self.connection_manager.fetch_query(
-            f"SELECT COUNT(*) FROM {self._get_table_name('community')} WHERE collection_id = $1",
-            [collection_id],
-        )

-        return {
-            "kg_extraction_statuses": kg_extraction_statuses,
-            "kg_enrichment_status": kg_enrichment_statuses[0][
-                "enrichment_status"
-            ],
-            "chunk_entity_count": chunk_entity_count[0]["count"],
-            "chunk_relationship_count": chunk_relationship_count[0]["count"],
-            "document_entity_count": document_entity_count[0]["count"],
-            "collection_entity_count": collection_entity_count[0]["count"],
-            "community_count": community_count[0]["count"],
-        }

+    ##################### RELATIONSHIP METHODS #####################

-    ### Relationships BEGIN ####
     async def add_relationships(
         self,
         relationships: list[Relationship],
@@ -367,188 +674,165 @@ async def list_relationships_v3(

         return results

-    ### Relationships END ####

-    async def get_entity_map(
-        self, offset: int, limit: int, document_id: UUID
-    ) -> dict[str, dict[str, list[dict[str, Any]]]]:
+    async def get_all_relationships(
+        self,
+        collection_id: UUID,
+        document_ids: Optional[list[UUID]] = None,
+    ) -> list[Relationship]:

-        QUERY1 = f"""
-            WITH entities_list AS (
-                SELECT DISTINCT name
-                FROM {self._get_table_name("chunk_entity")}
-                WHERE document_id = $1
-                ORDER BY name ASC
-                LIMIT {limit} OFFSET {offset}
+        # getting all documents for a collection
+        if document_ids is None:
+            QUERY = f"""
+                select distinct document_id from {self._get_table_name("document_info")} where $1 = ANY(collection_ids)
+            """
+            document_ids_list = await self.connection_manager.fetch_query(
+                QUERY, [collection_id]
             )
-            SELECT e.name, e.description, e.category,
-                (SELECT array_agg(DISTINCT x) FROM unnest(e.extraction_ids) x) AS extraction_ids,
-            e.document_id
-            FROM {self._get_table_name("chunk_entity")} e
-            JOIN entities_list el ON e.name = el.name
-            GROUP BY e.name, e.description, e.category, e.extraction_ids, e.document_id
-            ORDER BY e.name;"""
+            document_ids = [
+                doc_id["document_id"] for doc_id in document_ids_list
+            ]

-        entities_list = await self.connection_manager.fetch_query(
-            QUERY1, [document_id]
+        QUERY = f"""
+            SELECT id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1)
+        """
+        relationships = await self.connection_manager.fetch_query(
+            QUERY, [document_ids]
         )
-        entities_list = [
-            Entity(
-                name=entity["name"],
-                description=entity["description"],
-                category=entity["category"],
-                extraction_ids=entity["extraction_ids"],
-                document_id=entity["document_id"],
-            )
-            for entity in entities_list
-        ]
+        return [Relationship(**relationship) for relationship in relationships]

-        QUERY2 = f"""
-            WITH entities_list AS (
-                SELECT DISTINCT name
-                FROM {self._get_table_name("chunk_entity")}
-                WHERE document_id = $1
-                ORDER BY name ASC
-                LIMIT {limit} OFFSET {offset}
-            )
+    async def create_relationship(
+        self, collection_id: UUID, relationship: Relationship
+    ) -> None:

-        SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description,
-            (SELECT array_agg(DISTINCT x) FROM unnest(t.extraction_ids) x) AS extraction_ids,
t.document_id - FROM {self._get_table_name("chunk_relationship")} t - JOIN entities_list el ON t.subject = el.name - ORDER BY t.subject, t.predicate, t.object; + # check if the relationship already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 """ - - relationships_list = await self.connection_manager.fetch_query( - QUERY2, [document_id] - ) - relationships_list = [ - Relationship( - subject=relationship["subject"], - predicate=relationship["predicate"], - object=relationship["object"], - weight=relationship["weight"], - description=relationship["description"], - extraction_ids=relationship["extraction_ids"], - document_id=relationship["document_id"], + count = ( + await self.connection_manager.fetch_query( + QUERY, + [ + relationship.subject, + relationship.predicate, + relationship.object, + collection_id, + ], ) - for relationship in relationships_list - ] + )[0]["count"] - entity_map: dict[str, dict[str, list[Any]]] = {} - for entity in entities_list: - if entity.name not in entity_map: - entity_map[entity.name] = {"entities": [], "relationships": []} - entity_map[entity.name]["entities"].append(entity) + if count > 0: + raise R2RException("Relationship already exists", 400) - for relationship in relationships_list: - if relationship.subject in entity_map: - entity_map[relationship.subject]["relationships"].append( - relationship - ) - if relationship.object in entity_map: - entity_map[relationship.object]["relationships"].append( - relationship - ) + await self._add_objects([relationship], "chunk_relationship") - return entity_map + async def update_relationship( + self, relationship_id: UUID, relationship: Relationship + ) -> None: - async def graph_search( # type: ignore - self, query: str, **kwargs: Any - ) -> AsyncGenerator[Any, None]: + # check if relationship_id exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 + """ + count = ( + await self.connection_manager.fetch_query(QUERY, [relationship.id]) + )[0]["count"] - query_embedding = kwargs.get("query_embedding", None) - search_type = kwargs.get("search_type", "__Entity__") - embedding_type = kwargs.get("embedding_type", "description_embedding") - property_names = kwargs.get("property_names", ["name", "description"]) - filters = kwargs.get("filters", {}) - entities_level = kwargs.get("entities_level", EntityLevel.DOCUMENT) - limit = kwargs.get("limit", 10) + if count == 0: + raise R2RException("Relationship does not exist", 404) - table_name = "" - if search_type == "__Entity__": - table_name = ( - "collection_entity" - if entities_level == EntityLevel.COLLECTION - else "document_entity" - ) - elif search_type == "__Relationship__": - table_name = "chunk_relationship" - elif search_type == "__Community__": - table_name = "community" - else: - raise ValueError(f"Invalid search type: {search_type}") - - property_names_str = ", ".join(property_names) + await self._add_objects([relationship], "chunk_relationship") - collection_ids_dict = filters.get("collection_ids", {}) - filter_query = "" - if collection_ids_dict: - filter_query = "WHERE collection_id = ANY($3)" - filter_ids = collection_ids_dict["$overlap"] + async def delete_relationship(self, relationship_id: UUID) -> None: + QUERY = f""" + DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 + """ + await self.connection_manager.execute_query(QUERY, [relationship_id]) - if ( - search_type 
== "__Community__" - or table_name == "collection_entity" - ): - logger.info(f"Searching in collection ids: {filter_ids}") - elif search_type in ["__Entity__", "__Relationship__"]: - filter_query = "WHERE document_id = ANY($3)" - # TODO - This seems like a hack, we will need a better way to filter by collection ids for entities and relationships - query = f""" - SELECT distinct document_id FROM {self._get_table_name('document_info')} WHERE $1 = ANY(collection_ids) - """ - filter_ids = [ - doc_id["document_id"] - for doc_id in await self.connection_manager.fetch_query( - query, filter_ids - ) - ] - logger.info(f"Searching in document ids: {filter_ids}") + async def get_relationships( + self, + offset: int, + limit: int, + collection_id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, + ) -> dict: + conditions = [] + params: list = [str(collection_id)] + param_index = 2 - QUERY = f""" - SELECT {property_names_str} FROM {self._get_table_name(table_name)} {filter_query} ORDER BY {embedding_type} <=> $1 LIMIT $2; - """ + if relationship_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(relationship_ids) + param_index += 1 - if filter_query != "": - results = await self.connection_manager.fetch_query( - QUERY, (str(query_embedding), limit, filter_ids) - ) - else: - results = await self.connection_manager.fetch_query( - QUERY, (str(query_embedding), limit) + if entity_names: + conditions.append( + f"subject = ANY(${param_index}) or object = ANY(${param_index})" ) + params.append(entity_names) + param_index += 1 - for result in results: - yield { - property_name: result[property_name] - for property_name in property_names - } + pagination_params = [] + if offset: + pagination_params.append(f"OFFSET ${param_index}") + params.append(offset) + param_index += 1 - async def get_all_relationships( - self, collection_id: UUID - ) -> list[Relationship]: + if limit != -1: + pagination_params.append(f"LIMIT ${param_index}") + params.append(limit) + param_index += 1 - # getting all documents for a collection - if document_ids is None: - QUERY = f""" - select distinct document_id from {self._get_table_name("document_info")} where $1 = ANY(collection_ids) - """ - document_ids_list = await self.connection_manager.fetch_query( - QUERY, [collection_id] + pagination_clause = " ".join(pagination_params) + + query = f""" + SELECT id, subject, predicate, object, description + FROM {self._get_table_name("chunk_relationship")} + WHERE document_id = ANY( + SELECT document_id FROM {self._get_table_name("document_info")} + WHERE $1 = ANY(collection_ids) ) - document_ids = [ - doc_id["document_id"] for doc_id in document_ids_list - ] + {" AND " + " AND ".join(conditions) if conditions else ""} + ORDER BY id + {pagination_clause} + """ + + relationships = await self.connection_manager.fetch_query( + query, params + ) + relationships = [ + Relationship(**relationship) for relationship in relationships + ] + total_entries = await self.get_relationship_count( + collection_id=collection_id + ) + + return {"relationships": relationships, "total_entries": total_entries} + + ####################### COMMUNITY METHODS ####################### + + async def get_communities( + self, collection_id: UUID + ) -> list[Community]: + QUERY = f""" + SELECT *c FROM {self._get_table_name("community")} WHERE collection_id = $1 + """ + return await self.connection_manager.fetch_query( + QUERY, [collection_id] + ) + async def check_communities_exist( + self, 
collection_id: UUID, offset: int, limit: int + ) -> list[int]: QUERY = f""" - SELECT id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) + SELECT distinct community_number FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number >= $2 AND community_number < $3 """ - relationships = await self.connection_manager.fetch_query( - QUERY, [document_ids] + community_numbers = await self.connection_manager.fetch_query( + QUERY, [collection_id, offset, offset + limit] ) - return [Relationship(**relationship) for relationship in relationships] + return [item["community_number"] for item in community_numbers] + + async def add_community_info( self, communities: list[CommunityInfo] @@ -625,6 +909,73 @@ async def get_communities( "communities": communities, "total_entries": total_entries, } + + + + async def get_community_details( + self, community_number: int, collection_id: UUID + ) -> Tuple[int, list[Entity], list[Relationship]]: + + QUERY = f""" + SELECT level FROM {self._get_table_name("community_info")} WHERE cluster = $1 AND collection_id = $2 + LIMIT 1 + """ + level = ( + await self.connection_manager.fetch_query( + QUERY, [community_number, collection_id] + ) + )[0]["level"] + + # selecting table name based on entity level + # check if there are any entities in the community that are not in the entity_embedding table + query = f""" + SELECT COUNT(*) FROM {self._get_table_name("collection_entity")} WHERE collection_id = $1 + """ + entity_count = ( + await self.connection_manager.fetch_query(query, [collection_id]) + )[0]["count"] + table_name = ( + "collection_entity" if entity_count > 0 else "document_entity" + ) + + QUERY = f""" + WITH node_relationship_ids AS ( + SELECT node, relationship_ids + FROM {self._get_table_name("community_info")} + WHERE cluster = $1 AND collection_id = $2 + ) + SELECT DISTINCT + e.id AS id, + e.name AS name, + e.description AS description + FROM node_relationship_ids nti + JOIN {self._get_table_name(table_name)} e ON e.name = nti.node; + """ + entities = await self.connection_manager.fetch_query( + QUERY, [community_number, collection_id] + ) + entities = [Entity(**entity) for entity in entities] + + QUERY = f""" + WITH node_relationship_ids AS ( + SELECT node, relationship_ids + FROM {self._get_table_name("community_info")} + WHERE cluster = $1 and collection_id = $2 + ) + SELECT DISTINCT + t.id, t.subject, t.predicate, t.object, t.weight, t.description + FROM node_relationship_ids nti + JOIN {self._get_table_name("chunk_relationship")} t ON t.id = ANY(nti.relationship_ids); + """ + relationships = await self.connection_manager.fetch_query( + QUERY, [community_number, collection_id] + ) + relationships = [ + Relationship(**relationship) for relationship in relationships + ] + + return level, entities, relationships + async def add_community( self, community: Community @@ -655,246 +1006,73 @@ async def add_community( QUERY, [tuple(non_null_attrs.values())] ) - async def _create_graph_and_cluster( - self, relationships: list[Relationship], leiden_params: dict[str, Any] - ) -> Any: + async def delete_graph_for_collection( + self, collection_id: UUID, cascade: bool = False + ) -> None: - G = self.nx.Graph() - for relationship in relationships: - G.add_edge( - relationship.subject, - relationship.object, - weight=relationship.weight, - id=relationship.id, - ) + # don't delete if status is PROCESSING. 
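+        # an enrichment run that is still PROCESSING may be writing to these
+        # tables; bail out rather than delete rows out from under it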
+ QUERY = f""" + SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 + """ + status = ( + await self.connection_manager.fetch_query(QUERY, [collection_id]) + )[0]["kg_enrichment_status"] + if status == KGExtractionStatus.PROCESSING.value: + return - hierarchical_communities = await self._compute_leiden_communities( - G, leiden_params + # remove all relationships for these documents. + DELETE_QUERIES = [ + f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1;", + f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1;", + ] + + # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + document_ids_response = ( + await self.collection_handler.documents_in_collection( + offset=0, + limit=100, + collection_id=collection_id, + ) ) - return hierarchical_communities + # This type ignore is due to insufficient typing of the documents_in_collection method + document_ids = [doc.id for doc in document_ids_response["results"]] # type: ignore - async def _cluster_and_add_community_info( - self, - relationships: list[Relationship], - relationship_ids_cache: dict[str, list[int]], - leiden_params: dict[str, Any], - collection_id: UUID, - ) -> int: - - # clear if there is any old information - QUERY = f""" - DELETE FROM {self._get_table_name("community_info")} WHERE collection_id = $1 - """ - await self.connection_manager.execute_query(QUERY, [collection_id]) - - QUERY = f""" - DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1 - """ - await self.connection_manager.execute_query(QUERY, [collection_id]) - - start_time = time.time() - - hierarchical_communities = await self._create_graph_and_cluster( - relationships, leiden_params - ) - - logger.info( - f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds." - ) - - def relationship_ids(node: str) -> list[int]: - return relationship_ids_cache.get(node, []) - - logger.info( - f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds." - ) - - # upsert the communities into the database. - inputs = [ - CommunityInfo( - node=str(item.node), - cluster=item.cluster, - parent_cluster=item.parent_cluster, - level=item.level, - is_final_cluster=item.is_final_cluster, - relationship_ids=relationship_ids(item.node), - collection_id=collection_id, - ) - for item in hierarchical_communities - ] - - await self.add_community_info(inputs) - - num_communities = ( - max([item.cluster for item in hierarchical_communities]) + 1 - ) - - logger.info( - f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds." - ) - - return num_communities - - async def _use_community_cache( - self, collection_id: UUID, relationship_ids_cache: dict[str, list[int]] - ) -> bool: - - # check if status is enriched or stale - QUERY = f""" - SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 - """ - status = ( - await self.connection_manager.fetchrow_query( - QUERY, [collection_id] - ) - )["kg_enrichment_status"] - if status == KGEnrichmentStatus.PENDING: - return False + # TODO: make these queries more efficient. Pass the document_ids as params. 
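+        # cascade=False only clears derived community data; cascade=True also
+        # drops the chunk/document entities and relationships extracted from
+        # this collection's documents (see the extra queries below)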
+ if cascade: + DELETE_QUERIES += [ + f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1::uuid[]);", + f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1::uuid[]);", + f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1::uuid[]);", + f"DELETE FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1;", + ] - # check the number of entities in the cache. - QUERY = f""" - SELECT COUNT(distinct node) FROM {self._get_table_name("community_info")} WHERE collection_id = $1 - """ - num_entities = ( - await self.connection_manager.fetchrow_query( - QUERY, [collection_id] + # setting the kg_creation_status to PENDING for this collection. + QUERY = f""" + UPDATE {self._get_table_name("document_info")} SET kg_extraction_status = $1 WHERE $2::uuid = ANY(collection_ids) + """ + await self.connection_manager.execute_query( + QUERY, [KGExtractionStatus.PENDING, collection_id] ) - )["count"] - - # a hard threshold of 80% of the entities in the cache. - if num_entities > 0.8 * len(relationship_ids_cache): - return True - else: - return False - async def _get_relationship_ids_cache( - self, relationships: list[Relationship] - ) -> dict[str, list[int]]: - - # caching the relationship ids - relationship_ids_cache = dict[str, list[int]]() - for relationship in relationships: - if ( - relationship.subject not in relationship_ids_cache - and relationship.subject is not None - ): - relationship_ids_cache[relationship.subject] = [] - if ( - relationship.object not in relationship_ids_cache - and relationship.object is not None - ): - relationship_ids_cache[relationship.object] = [] - if ( - relationship.subject is not None - and relationship.id is not None - ): - relationship_ids_cache[relationship.subject].append( - relationship.id - ) - if relationship.object is not None and relationship.id is not None: - relationship_ids_cache[relationship.object].append( - relationship.id + for query in DELETE_QUERIES: + if "community" in query or "collection_entity" in query: + await self.connection_manager.execute_query( + query, [collection_id] ) - - return relationship_ids_cache - - async def _incremental_clustering( - self, - relationship_ids_cache: dict[str, list[int]], - leiden_params: dict[str, Any], - collection_id: UUID, - ) -> int: - """ - Performs incremental clustering on new relationships by: - 1. Getting all relationships and new relationships - 2. Getting community mapping for all existing relationships - 3. For each new relationship: - - Check if subject/object exists in community mapping - - If exists, add its cluster to updated communities set - - If not, append relationship to new_relationship_ids list for clustering - 4. Run hierarchical clustering on new_relationship_ids list - 5. 
Update community info table with new clusters, offsetting IDs by max_cluster_id - """ - - QUERY = f""" - SELECT node, cluster, is_final_cluster FROM {self._get_table_name("community_info")} WHERE collection_id = $1 - """ - - communities = await self.connection_manager.fetch_query( - QUERY, [collection_id] - ) - max_cluster_id = max( - [community["cluster"] for community in communities] - ) - - # TODO: modify above query to get a dict grouped by node (without aggregation) - communities_dict = {} # type: ignore - for community in communities: - if community["node"] not in communities_dict: - communities_dict[community["node"]] = [] - communities_dict[community["node"]].append(community) - - QUERY = f""" - SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) and kg_extraction_status = $2 - """ - - new_document_ids = await self.connection_manager.fetch_query( - QUERY, [collection_id, KGExtractionStatus.SUCCESS] - ) - - new_relationship_ids = await self.get_all_relationships( - collection_id, new_document_ids - ) - - # community mapping for new relationships - updated_communities = set() - new_relationships = [] - for relationship in new_relationship_ids: - # bias towards subject - if relationship.subject in communities_dict: - for community in communities_dict[relationship.subject]: - updated_communities.add(community["cluster"]) - elif relationship.object in communities_dict: - for community in communities_dict[relationship.object]: - updated_communities.add(community["cluster"]) else: - new_relationships.append(relationship) + await self.connection_manager.execute_query( + query, [document_ids] + ) - # delete the communities information for the updated communities + # set status to PENDING for this collection. 
QUERY = f""" - DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number = ANY($2) + UPDATE {self._get_table_name("collections")} SET kg_enrichment_status = $1 WHERE collection_id = $2 """ await self.connection_manager.execute_query( - QUERY, [collection_id, updated_communities] - ) - - hierarchical_communities_output = await self._create_graph_and_cluster( - new_relationships, leiden_params + QUERY, [KGExtractionStatus.PENDING, collection_id] ) - community_info = [] - for community in hierarchical_communities_output: - community_info.append( - CommunityInfo( - node=community.node, - cluster=community.cluster + max_cluster_id, - parent_cluster=( - community.parent_cluster + max_cluster_id - if community.parent_cluster is not None - else None - ), - level=community.level, - relationship_ids=[], # FIXME: need to get the relationship ids for the community - is_final_cluster=community.is_final_cluster, - collection_id=collection_id, - ) - ) - - await self.add_community_info(community_info) - num_communities = max([item.cluster for item in community_info]) + 1 - return num_communities - async def perform_graph_clustering( self, collection_id: UUID, @@ -943,379 +1121,173 @@ async def perform_graph_clustering( return num_communities - async def _compute_leiden_communities( - self, - graph: Any, - leiden_params: dict[str, Any], - ) -> Any: - """Compute Leiden communities.""" - try: - from graspologic.partition import hierarchical_leiden - - if "random_seed" not in leiden_params: - leiden_params["random_seed"] = ( - 7272 # add seed to control randomness - ) - - start_time = time.time() - logger.info( - f"Running Leiden clustering with params: {leiden_params}" - ) + ####################### MANAGEMENT METHODS ####################### - community_mapping = hierarchical_leiden(graph, **leiden_params) + async def get_entity_map( + self, offset: int, limit: int, document_id: UUID + ) -> dict[str, dict[str, list[dict[str, Any]]]]: - logger.info( - f"Leiden clustering completed in {time.time() - start_time:.2f} seconds." 
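+        # two passes over chunk_entity: page through distinct entity names
+        # first, then fetch the relationships touching those names, and
+        # group both per entity at the end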
+ QUERY1 = f""" + WITH entities_list AS ( + SELECT DISTINCT name + FROM {self._get_table_name("chunk_entity")} + WHERE document_id = $1 + ORDER BY name ASC + LIMIT {limit} OFFSET {offset} ) - return community_mapping + SELECT e.name, e.description, e.category, + (SELECT array_agg(DISTINCT x) FROM unnest(e.extraction_ids) x) AS extraction_ids, + e.document_id + FROM {self._get_table_name("chunk_entity")} e + JOIN entities_list el ON e.name = el.name + GROUP BY e.name, e.description, e.category, e.extraction_ids, e.document_id + ORDER BY e.name;""" - except ImportError as e: - raise ImportError("Please install the graspologic package.") from e + entities_list = await self.connection_manager.fetch_query( + QUERY1, [document_id] + ) + entities_list = [ + Entity( + name=entity["name"], + description=entity["description"], + category=entity["category"], + extraction_ids=entity["extraction_ids"], + document_id=entity["document_id"], + ) + for entity in entities_list + ] - async def get_community_details( - self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Relationship]]: + QUERY2 = f""" + WITH entities_list AS ( - QUERY = f""" - SELECT level FROM {self._get_table_name("community_info")} WHERE cluster = $1 AND collection_id = $2 - LIMIT 1 - """ - level = ( - await self.connection_manager.fetch_query( - QUERY, [community_number, collection_id] + SELECT DISTINCT name + FROM {self._get_table_name("chunk_entity")} + WHERE document_id = $1 + ORDER BY name ASC + LIMIT {limit} OFFSET {offset} ) - )[0]["level"] - # selecting table name based on entity level - # check if there are any entities in the community that are not in the entity_embedding table - query = f""" - SELECT COUNT(*) FROM {self._get_table_name("collection_entity")} WHERE collection_id = $1 + SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description, + (SELECT array_agg(DISTINCT x) FROM unnest(t.extraction_ids) x) AS extraction_ids, t.document_id + FROM {self._get_table_name("chunk_relationship")} t + JOIN entities_list el ON t.subject = el.name + ORDER BY t.subject, t.predicate, t.object; """ - entity_count = ( - await self.connection_manager.fetch_query(query, [collection_id]) - )[0]["count"] - table_name = ( - "collection_entity" if entity_count > 0 else "document_entity" - ) - QUERY = f""" - WITH node_relationship_ids AS ( - SELECT node, relationship_ids - FROM {self._get_table_name("community_info")} - WHERE cluster = $1 AND collection_id = $2 - ) - SELECT DISTINCT - e.id AS id, - e.name AS name, - e.description AS description - FROM node_relationship_ids nti - JOIN {self._get_table_name(table_name)} e ON e.name = nti.node; - """ - entities = await self.connection_manager.fetch_query( - QUERY, [community_number, collection_id] + relationships_list = await self.connection_manager.fetch_query( + QUERY2, [document_id] ) - entities = [Entity(**entity) for entity in entities] - - QUERY = f""" - WITH node_relationship_ids AS ( - SELECT node, relationship_ids - FROM {self._get_table_name("community_info")} - WHERE cluster = $1 and collection_id = $2 + relationships_list = [ + Relationship( + subject=relationship["subject"], + predicate=relationship["predicate"], + object=relationship["object"], + weight=relationship["weight"], + description=relationship["description"], + extraction_ids=relationship["extraction_ids"], + document_id=relationship["document_id"], ) - SELECT DISTINCT - t.id, t.subject, t.predicate, t.object, t.weight, t.description - FROM node_relationship_ids nti - JOIN 
{self._get_table_name("chunk_relationship")} t ON t.id = ANY(nti.relationship_ids); - """ - relationships = await self.connection_manager.fetch_query( - QUERY, [community_number, collection_id] - ) - relationships = [ - Relationship(**relationship) for relationship in relationships + for relationship in relationships_list ] - return level, entities, relationships - - # async def client(self): - # return None - - ############################################################ - ########## Entity CRUD Operations ########################## - ############################################################ - - async def create_entities_v3( - self, level: EntityLevel, id: UUID, entities: list[Entity] - ) -> None: + entity_map: dict[str, dict[str, list[Any]]] = {} + for entity in entities_list: + if entity.name not in entity_map: + entity_map[entity.name] = {"entities": [], "relationships": []} + entity_map[entity.name]["entities"].append(entity) - # TODO: check if already exists - await self._add_objects(entities, level.table_name) + for relationship in relationships_list: + if relationship.subject in entity_map: + entity_map[relationship.subject]["relationships"].append( + relationship + ) + if relationship.object in entity_map: + entity_map[relationship.object]["relationships"].append( + relationship + ) - async def update_entity(self, collection_id: UUID, entity: Entity) -> None: - table_name = entity.level.value + "_entity" + return entity_map - # check if the entity already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY, [entity.id, collection_id] - ) - )[0]["count"] - if count == 0: - raise R2RException("Entity does not exist", 404) + async def get_graph_status(self, collection_id: UUID) -> dict: + # check document_info table for the documents in the collection and return the status of each document + kg_extraction_statuses = await self.connection_manager.fetch_query( + f"SELECT document_id, kg_extraction_status FROM {self._get_table_name('document_info')} WHERE collection_id = $1", + [collection_id], + ) - await self._add_objects([entity], table_name) + document_ids = [ + doc_id["document_id"] for doc_id in kg_extraction_statuses + ] - async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: + kg_enrichment_statuses = await self.connection_manager.fetch_query( + f"SELECT enrichment_status FROM {self._get_table_name(PostgresCollectionHandler.TABLE_NAME)} WHERE id = $1", + [collection_id], + ) - table_name = entity.level.value + "_entity" - QUERY = f""" - DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - """ - await self.connection_manager.execute_query( - QUERY, [entity.id, collection_id] + # entity and relationship counts + chunk_entity_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1)", + [document_ids], ) - ############################################################ - ########## Relationship CRUD Operations #################### - ############################################################ + chunk_relationship_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1)", + [document_ids], + ) - async def create_relationship( - self, collection_id: UUID, relationship: Relationship - ) -> None: + document_entity_count = 
await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1)", + [document_ids], + ) - # check if the relationship already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY, - [ - relationship.subject, - relationship.predicate, - relationship.object, - collection_id, - ], - ) - )[0]["count"] + collection_entity_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1", + [collection_id], + ) - if count > 0: - raise R2RException("Relationship already exists", 400) + community_count = await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('community')} WHERE collection_id = $1", + [collection_id], + ) - await self._add_objects([relationship], "chunk_relationship") + return { + "kg_extraction_statuses": kg_extraction_statuses, + "kg_enrichment_status": kg_enrichment_statuses[0][ + "enrichment_status" + ], + "chunk_entity_count": chunk_entity_count[0]["count"], + "chunk_relationship_count": chunk_relationship_count[0]["count"], + "document_entity_count": document_entity_count[0]["count"], + "collection_entity_count": collection_entity_count[0]["count"], + "community_count": community_count[0]["count"], + } - async def update_relationship( - self, relationship_id: UUID, relationship: Relationship - ) -> None: - # check if relationship_id exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 - """ - count = ( - await self.connection_manager.fetch_query(QUERY, [relationship.id]) - )[0]["count"] + ####################### ESTIMATION METHODS ####################### - if count == 0: - raise R2RException("Relationship does not exist", 404) + async def get_creation_estimate( + self, collection_id: UUID, kg_creation_settings: KGCreationSettings + ): - await self._add_objects([relationship], "chunk_relationship") + # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
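+        # rough sizing: count the chunks behind each document in the
+        # collection, then scale the totals into LLM-call / token / cost
+        # ranges further down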
+ document_ids = [ + doc.id + for doc in ( + await self.collection_handler.documents_in_collection(collection_id) # type: ignore + )["results"] + ] - async def delete_relationship(self, relationship_id: UUID) -> None: - QUERY = f""" - DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 + query = f""" + SELECT document_id, COUNT(*) as chunk_count + FROM {self._get_table_name("vectors")} + WHERE document_id = ANY($1) + GROUP BY document_id """ - await self.connection_manager.execute_query(QUERY, [relationship_id]) - ############################################################ - ########## Community CRUD Operations ####################### - ############################################################ - - async def get_communities( - self, collection_id: UUID - ) -> list[Community]: - QUERY = f""" - SELECT *c FROM {self._get_table_name("community")} WHERE collection_id = $1 - """ - return await self.connection_manager.fetch_query( - QUERY, [collection_id] - ) - - async def check_communities_exist( - self, collection_id: UUID, offset: int, limit: int - ) -> list[int]: - QUERY = f""" - SELECT distinct community_number FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number >= $2 AND community_number < $3 - """ - community_numbers = await self.connection_manager.fetch_query( - QUERY, [collection_id, offset, offset + limit] - ) - return [item["community_number"] for item in community_numbers] - - async def delete_graph_for_collection( - self, collection_id: UUID, cascade: bool = False - ) -> None: - - # don't delete if status is PROCESSING. - QUERY = f""" - SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 - """ - status = ( - await self.connection_manager.fetch_query(QUERY, [collection_id]) - )[0]["kg_enrichment_status"] - if status == KGExtractionStatus.PROCESSING.value: - return - - # remove all relationships for these documents. - DELETE_QUERIES = [ - f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1;", - f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1;", - ] - - # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. - document_ids_response = ( - await self.collection_handler.documents_in_collection( - offset=0, - limit=100, - collection_id=collection_id, - ) - ) - - # This type ignore is due to insufficient typing of the documents_in_collection method - document_ids = [doc.id for doc in document_ids_response["results"]] # type: ignore - - # TODO: make these queries more efficient. Pass the document_ids as params. - if cascade: - DELETE_QUERIES += [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = ANY($1::uuid[]);", - f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1::uuid[]);", - f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1::uuid[]);", - f"DELETE FROM {self._get_table_name('collection_entity')} WHERE collection_id = $1;", - ] - - # setting the kg_creation_status to PENDING for this collection. 
- QUERY = f""" - UPDATE {self._get_table_name("document_info")} SET kg_extraction_status = $1 WHERE $2::uuid = ANY(collection_ids) - """ - await self.connection_manager.execute_query( - QUERY, [KGExtractionStatus.PENDING, collection_id] - ) - - for query in DELETE_QUERIES: - if "community" in query or "collection_entity" in query: - await self.connection_manager.execute_query( - query, [collection_id] - ) - else: - await self.connection_manager.execute_query( - query, [document_ids] - ) - - # set status to PENDING for this collection. - QUERY = f""" - UPDATE {self._get_table_name("collections")} SET kg_enrichment_status = $1 WHERE collection_id = $2 - """ - await self.connection_manager.execute_query( - QUERY, [KGExtractionStatus.PENDING, collection_id] - ) - - async def delete_node_via_document_id( - self, document_id: UUID, collection_id: UUID - ) -> None: - # don't delete if status is PROCESSING. - QUERY = f""" - SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 - """ - status = ( - await self.connection_manager.fetch_query(QUERY, [collection_id]) - )[0]["kg_enrichment_status"] - if status == KGExtractionStatus.PROCESSING.value: - return - - # Execute separate DELETE queries - delete_queries = [ - f"DELETE FROM {self._get_table_name('chunk_entity')} WHERE document_id = $1", - f"DELETE FROM {self._get_table_name('chunk_relationship')} WHERE document_id = $1", - f"DELETE FROM {self._get_table_name('document_entity')} WHERE document_id = $1", - ] - - for query in delete_queries: - await self.connection_manager.execute_query(query, [document_id]) - - # Check if this is the last document in the collection - # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. - documents = await self.collection_handler.documents_in_collection( - offset=0, - limit=100, - collection_id=collection_id, - ) - count = documents["total_entries"] - - if count == 0: - # If it's the last document, delete collection-related data - collection_queries = [ - f"DELETE FROM {self._get_table_name('community_info')} WHERE collection_id = $1", - f"DELETE FROM {self._get_table_name('community')} WHERE collection_id = $1", - ] - for query in collection_queries: - await self.connection_manager.execute_query( - query, [collection_id] - ) # Ensure collection_id is in a list - - # set status to PENDING for this collection. - QUERY = f""" - UPDATE {self._get_table_name("collections")} SET kg_enrichment_status = $1 WHERE collection_id = $2 - """ - await self.connection_manager.execute_query( - QUERY, [KGExtractionStatus.PENDING, collection_id] - ) - return None - return None - - def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: - if isinstance(x[0], int) and isinstance(x[1], int): - return " - ".join(map(str, x)) - else: - return " - ".join(f"{round(a, 2)}" for a in x) - - async def get_existing_entity_extraction_ids( - self, document_id: UUID - ) -> list[str]: - QUERY = f""" - SELECT DISTINCT unnest(extraction_ids) AS chunk_id FROM {self._get_table_name("chunk_entity")} WHERE document_id = $1 - """ - return [ - item["chunk_id"] - for item in await self.connection_manager.fetch_query( - QUERY, [document_id] - ) - ] - - async def get_creation_estimate( - self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ): - - # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
- document_ids = [ - doc.id - for doc in ( - await self.collection_handler.documents_in_collection(collection_id) # type: ignore - )["results"] - ] - - query = f""" - SELECT document_id, COUNT(*) as chunk_count - FROM {self._get_table_name("vectors")} - WHERE document_id = ANY($1) - GROUP BY document_id - """ - - chunk_counts = await self.connection_manager.fetch_query( - query, [document_ids] - ) + chunk_counts = await self.connection_manager.fetch_query( + query, [document_ids] + ) total_chunks = ( sum(doc["chunk_count"] for doc in chunk_counts) @@ -1447,196 +1419,466 @@ async def get_enrichment_estimate( + self._get_str_estimation_output(estimated_total_time), } - async def create_vector_index(self): - # need to implement this. Just call vector db provider's create_vector_index method. - # this needs to be run periodically for every collection. - raise NotImplementedError - - async def delete_relationships(self, relationship_ids: list[int]): - # need to implement this. - raise NotImplementedError - - async def get_schema(self): - # somehow get the rds from the postgres db. - raise NotImplementedError - - async def get_entities_v3( + async def get_deduplication_estimate( self, - level: EntityLevel, - id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - entity_categories: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - offset: int = 0, - limit: int = -1, + collection_id: UUID, + kg_deduplication_settings: KGEntityDeduplicationSettings, ): - - params: list = [id] - - if level != EntityLevel.CHUNK and entity_categories: - raise ValueError( - "entity_categories are only supported for chunk level entities" + try: + # number of documents in collection + query = f""" + SELECT name, count(name) + FROM {self._get_table_name("document_entity")} + WHERE document_id = ANY( + SELECT document_id FROM {self._get_table_name("document_info")} + WHERE $1 = ANY(collection_ids) + ) + GROUP BY name + HAVING count(name) >= 5 + """ + entities = await self.connection_manager.fetch_query( + query, [collection_id] ) + num_entities = len(entities) - filter = { - EntityLevel.CHUNK: "chunk_ids = ANY($1)", - EntityLevel.DOCUMENT: "document_id = $1", - EntityLevel.COLLECTION: "collection_id = $1", - }[level] - - if entity_names: - filter += " AND name = ANY($2)" - params.append(entity_names) + estimated_llm_calls = (num_entities, num_entities) + estimated_total_in_out_tokens_in_millions = ( + estimated_llm_calls[0] * 1000 / 1000000, + estimated_llm_calls[1] * 5000 / 1000000, + ) + estimated_cost_in_usd = ( + estimated_total_in_out_tokens_in_millions[0] + * llm_cost_per_million_tokens( + kg_deduplication_settings.generation_config.model + ), + estimated_total_in_out_tokens_in_millions[1] + * llm_cost_per_million_tokens( + kg_deduplication_settings.generation_config.model + ), + ) - if entity_categories: - filter += " AND category = ANY($3)" - params.append(entity_categories) + estimated_total_time_in_minutes = ( + estimated_total_in_out_tokens_in_millions[0] * 10 / 60, + estimated_total_in_out_tokens_in_millions[1] * 10 / 60, + ) - QUERY = f""" - SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} - OFFSET ${len(params)} LIMIT ${len(params) + 1} - """ + return KGDeduplicationEstimationResponse( + message='Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. 
To run the Deduplication process, run `deduplicate-entities` with `--run` in the cli, or `run_type="run"` in the client.', + num_entities=num_entities, + estimated_llm_calls=self._get_str_estimation_output( + estimated_llm_calls + ), + estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( + estimated_total_in_out_tokens_in_millions + ), + estimated_cost_in_usd=self._get_str_estimation_output( + estimated_cost_in_usd + ), + estimated_total_time_in_minutes=self._get_str_estimation_output( + estimated_total_time_in_minutes + ), + ) + except UndefinedTableError as e: + logger.error( + f"Entity embedding table not found. Please run `create-graph` first. {str(e)}" + ) + raise R2RException( + message="Entity embedding table not found. Please run `create-graph` first.", + status_code=404, + ) + except PostgresError as e: + logger.error( + f"Database error in get_deduplication_estimate: {str(e)}" + ) + raise HTTPException( + status_code=500, + detail="An error occurred while fetching the deduplication estimate.", + ) + except Exception as e: + logger.error( + f"Unexpected error in get_deduplication_estimate: {str(e)}" + ) + raise HTTPException( + status_code=500, + detail="An unexpected error occurred while fetching the deduplication estimate.", + ) - params.extend([offset, limit]) - output = await self.connection_manager.fetch_query(QUERY, params) + ####################### GRAPH SEARCH METHODS ####################### - if attributes: - output = [ - entity for entity in output if entity["name"] in attributes - ] + async def graph_search( # type: ignore + self, query: str, **kwargs: Any + ) -> AsyncGenerator[Any, None]: - return output + query_embedding = kwargs.get("query_embedding", None) + search_type = kwargs.get("search_type", "__Entity__") + embedding_type = kwargs.get("embedding_type", "description_embedding") + property_names = kwargs.get("property_names", ["name", "description"]) + filters = kwargs.get("filters", {}) + entities_level = kwargs.get("entities_level", EntityLevel.DOCUMENT) + limit = kwargs.get("limit", 10) - # TODO: deprecate this - async def get_entities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_names: Optional[list[str]] = None, - entity_table_name: str = "document_entity", - extra_columns: Optional[list[str]] = None, - ) -> dict: - conditions = [] - params: list = [collection_id] - param_index = 2 + table_name = "" + if search_type == "__Entity__": + table_name = ( + "collection_entity" + if entities_level == EntityLevel.COLLECTION + else "document_entity" + ) + elif search_type == "__Relationship__": + table_name = "chunk_relationship" + elif search_type == "__Community__": + table_name = "community" + else: + raise ValueError(f"Invalid search type: {search_type}") - if entity_ids: - conditions.append(f"id = ANY(${param_index})") - params.append(entity_ids) - param_index += 1 + property_names_str = ", ".join(property_names) - if entity_names: - conditions.append(f"name = ANY(${param_index})") - params.append(entity_names) - param_index += 1 + collection_ids_dict = filters.get("collection_ids", {}) + filter_query = "" + if collection_ids_dict: + filter_query = "WHERE collection_id = ANY($3)" + filter_ids = collection_ids_dict["$overlap"] - pagination_params = [] - if offset: - pagination_params.append(f"OFFSET ${param_index}") - params.append(offset) - param_index += 1 + if ( + search_type == "__Community__" + or table_name == "collection_entity" + ): + 
logger.info(f"Searching in collection ids: {filter_ids}") - if limit != -1: - pagination_params.append(f"LIMIT ${param_index}") - params.append(limit) - param_index += 1 + elif search_type in ["__Entity__", "__Relationship__"]: + filter_query = "WHERE document_id = ANY($3)" + # TODO - This seems like a hack, we will need a better way to filter by collection ids for entities and relationships + query = f""" + SELECT distinct document_id FROM {self._get_table_name('document_info')} WHERE $1 = ANY(collection_ids) + """ + filter_ids = [ + doc_id["document_id"] + for doc_id in await self.connection_manager.fetch_query( + query, filter_ids + ) + ] + logger.info(f"Searching in document ids: {filter_ids}") - pagination_clause = " ".join(pagination_params) + QUERY = f""" + SELECT {property_names_str} FROM {self._get_table_name(table_name)} {filter_query} ORDER BY {embedding_type} <=> $1 LIMIT $2; + """ - if entity_table_name == "collection_entity": - query = f""" - SELECT id, name, description, extraction_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} - FROM {self._get_table_name(entity_table_name)} - WHERE collection_id = $1 - {" AND " + " AND ".join(conditions) if conditions else ""} - ORDER BY id - {pagination_clause} - """ + if filter_query != "": + results = await self.connection_manager.fetch_query( + QUERY, (str(query_embedding), limit, filter_ids) + ) else: - query = f""" - SELECT id, name, description, extraction_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} - FROM {self._get_table_name(entity_table_name)} - WHERE document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) + results = await self.connection_manager.fetch_query( + QUERY, (str(query_embedding), limit) ) - {" AND " + " AND ".join(conditions) if conditions else ""} - ORDER BY id - {pagination_clause} - """ - results = await self.connection_manager.fetch_query(query, params) - entities = [Entity(**entity) for entity in results] + for result in results: + yield { + property_name: result[property_name] + for property_name in property_names + } + + ####################### GRAPH CLUSTERING METHODS ####################### + + async def _create_graph_and_cluster( + self, relationships: list[Relationship], leiden_params: dict[str, Any] + ) -> Any: + + G = self.nx.Graph() + for relationship in relationships: + G.add_edge( + relationship.subject, + relationship.object, + weight=relationship.weight, + id=relationship.id, + ) + + hierarchical_communities = await self._compute_leiden_communities( + G, leiden_params + ) + + return hierarchical_communities + + async def _cluster_and_add_community_info( + self, + relationships: list[Relationship], + relationship_ids_cache: dict[str, list[int]], + leiden_params: dict[str, Any], + collection_id: UUID, + ) -> int: + + # clear if there is any old information + QUERY = f""" + DELETE FROM {self._get_table_name("community_info")} WHERE collection_id = $1 + """ + await self.connection_manager.execute_query(QUERY, [collection_id]) + + QUERY = f""" + DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1 + """ + await self.connection_manager.execute_query(QUERY, [collection_id]) + + start_time = time.time() + + hierarchical_communities = await self._create_graph_and_cluster( + relationships, leiden_params + ) + + logger.info( + f"Computing Leiden communities completed, time {time.time() - start_time:.2f} seconds." 
+ ) + + def relationship_ids(node: str) -> list[int]: + return relationship_ids_cache.get(node, []) + + logger.info( + f"Cached {len(relationship_ids_cache)} relationship ids, time {time.time() - start_time:.2f} seconds." + ) + + # upsert the communities into the database. + inputs = [ + CommunityInfo( + node=str(item.node), + cluster=item.cluster, + parent_cluster=item.parent_cluster, + level=item.level, + is_final_cluster=item.is_final_cluster, + relationship_ids=relationship_ids(item.node), + collection_id=collection_id, + ) + for item in hierarchical_communities + ] + + await self.add_community_info(inputs) + + num_communities = ( + max([item.cluster for item in hierarchical_communities]) + 1 + ) + + logger.info( + f"Generated {num_communities} communities, time {time.time() - start_time:.2f} seconds." + ) + + return num_communities + + async def _use_community_cache( + self, collection_id: UUID, relationship_ids_cache: dict[str, list[int]] + ) -> bool: + + # check if status is enriched or stale + QUERY = f""" + SELECT kg_enrichment_status FROM {self._get_table_name("collections")} WHERE collection_id = $1 + """ + status = ( + await self.connection_manager.fetchrow_query( + QUERY, [collection_id] + ) + )["kg_enrichment_status"] + if status == KGEnrichmentStatus.PENDING: + return False + + # check the number of entities in the cache. + QUERY = f""" + SELECT COUNT(distinct node) FROM {self._get_table_name("community_info")} WHERE collection_id = $1 + """ + num_entities = ( + await self.connection_manager.fetchrow_query( + QUERY, [collection_id] + ) + )["count"] + + # a hard threshold of 80% of the entities in the cache. + if num_entities > 0.8 * len(relationship_ids_cache): + return True + else: + return False + + async def _get_relationship_ids_cache( + self, relationships: list[Relationship] + ) -> dict[str, list[int]]: + + # caching the relationship ids + relationship_ids_cache = dict[str, list[int]]() + for relationship in relationships: + if ( + relationship.subject not in relationship_ids_cache + and relationship.subject is not None + ): + relationship_ids_cache[relationship.subject] = [] + if ( + relationship.object not in relationship_ids_cache + and relationship.object is not None + ): + relationship_ids_cache[relationship.object] = [] + if ( + relationship.subject is not None + and relationship.id is not None + ): + relationship_ids_cache[relationship.subject].append( + relationship.id + ) + if relationship.object is not None and relationship.id is not None: + relationship_ids_cache[relationship.object].append( + relationship.id + ) + + return relationship_ids_cache + + async def _incremental_clustering( + self, + relationship_ids_cache: dict[str, list[int]], + leiden_params: dict[str, Any], + collection_id: UUID, + ) -> int: + """ + Performs incremental clustering on new relationships by: + 1. Getting all relationships and new relationships + 2. Getting community mapping for all existing relationships + 3. For each new relationship: + - Check if subject/object exists in community mapping + - If exists, add its cluster to updated communities set + - If not, append relationship to new_relationship_ids list for clustering + 4. Run hierarchical clustering on new_relationship_ids list + 5. 
Update community info table with new clusters, offsetting IDs by max_cluster_id + """ + + QUERY = f""" + SELECT node, cluster, is_final_cluster FROM {self._get_table_name("community_info")} WHERE collection_id = $1 + """ + + communities = await self.connection_manager.fetch_query( + QUERY, [collection_id] + ) + max_cluster_id = max( + [community["cluster"] for community in communities] + ) + + # TODO: modify above query to get a dict grouped by node (without aggregation) + communities_dict = {} # type: ignore + for community in communities: + if community["node"] not in communities_dict: + communities_dict[community["node"]] = [] + communities_dict[community["node"]].append(community) + + QUERY = f""" + SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) and kg_extraction_status = $2 + """ + + new_document_ids = await self.connection_manager.fetch_query( + QUERY, [collection_id, KGExtractionStatus.SUCCESS] + ) + + new_relationship_ids = await self.get_all_relationships( + collection_id, new_document_ids + ) + + # community mapping for new relationships + updated_communities = set() + new_relationships = [] + for relationship in new_relationship_ids: + # bias towards subject + if relationship.subject in communities_dict: + for community in communities_dict[relationship.subject]: + updated_communities.add(community["cluster"]) + elif relationship.object in communities_dict: + for community in communities_dict[relationship.object]: + updated_communities.add(community["cluster"]) + else: + new_relationships.append(relationship) + + # delete the communities information for the updated communities + QUERY = f""" + DELETE FROM {self._get_table_name("community")} WHERE collection_id = $1 AND community_number = ANY($2) + """ + await self.connection_manager.execute_query( + QUERY, [collection_id, updated_communities] + ) - total_entries = await self.get_entity_count( - collection_id=collection_id, entity_table_name=entity_table_name + hierarchical_communities_output = await self._create_graph_and_cluster( + new_relationships, leiden_params ) - return {"entities": entities, "total_entries": total_entries} + community_info = [] + for community in hierarchical_communities_output: + community_info.append( + CommunityInfo( + node=community.node, + cluster=community.cluster + max_cluster_id, + parent_cluster=( + community.parent_cluster + max_cluster_id + if community.parent_cluster is not None + else None + ), + level=community.level, + relationship_ids=[], # FIXME: need to get the relationship ids for the community + is_final_cluster=community.is_final_cluster, + collection_id=collection_id, + ) + ) - async def get_relationships( + await self.add_community_info(community_info) + num_communities = max([item.cluster for item in community_info]) + 1 + return num_communities + + + async def _compute_leiden_communities( self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - ) -> dict: - conditions = [] - params: list = [str(collection_id)] - param_index = 2 + graph: Any, + leiden_params: dict[str, Any], + ) -> Any: + """Compute Leiden communities.""" + try: + from graspologic.partition import hierarchical_leiden - if relationship_ids: - conditions.append(f"id = ANY(${param_index})") - params.append(relationship_ids) - param_index += 1 + if "random_seed" not in leiden_params: + leiden_params["random_seed"] = ( + 7272 # add seed to control randomness + ) - 
if entity_names: - conditions.append( - f"subject = ANY(${param_index}) or object = ANY(${param_index})" + start_time = time.time() + logger.info( + f"Running Leiden clustering with params: {leiden_params}" ) - params.append(entity_names) - param_index += 1 - pagination_params = [] - if offset: - pagination_params.append(f"OFFSET ${param_index}") - params.append(offset) - param_index += 1 + community_mapping = hierarchical_leiden(graph, **leiden_params) - if limit != -1: - pagination_params.append(f"LIMIT ${param_index}") - params.append(limit) - param_index += 1 + logger.info( + f"Leiden clustering completed in {time.time() - start_time:.2f} seconds." + ) + return community_mapping - pagination_clause = " ".join(pagination_params) + except ImportError as e: + raise ImportError("Please install the graspologic package.") from e - query = f""" - SELECT id, subject, predicate, object, description - FROM {self._get_table_name("chunk_relationship")} - WHERE document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) - ) - {" AND " + " AND ".join(conditions) if conditions else ""} - ORDER BY id - {pagination_clause} - """ - relationships = await self.connection_manager.fetch_query( - query, params - ) - relationships = [ - Relationship(**relationship) for relationship in relationships + ####################### UTILITY METHODS ####################### + + def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: + if isinstance(x[0], int) and isinstance(x[1], int): + return " - ".join(map(str, x)) + else: + return " - ".join(f"{round(a, 2)}" for a in x) + + async def get_existing_entity_extraction_ids( + self, document_id: UUID + ) -> list[str]: + QUERY = f""" + SELECT DISTINCT unnest(extraction_ids) AS chunk_id FROM {self._get_table_name("chunk_entity")} WHERE document_id = $1 + """ + return [ + item["chunk_id"] + for item in await self.connection_manager.fetch_query( + QUERY, [document_id] + ) ] - total_entries = await self.get_relationship_count( - collection_id=collection_id - ) - return {"relationships": relationships, "total_entries": total_entries} + async def create_vector_index(self): + # need to implement this. Just call vector db provider's create_vector_index method. + # this needs to be run periodically for every collection. + raise NotImplementedError + + async def structured_query(self): raise NotImplementedError @@ -1751,86 +1993,46 @@ async def update_entity_descriptions(self, entities: list[Entity]): await self.connection_manager.execute_many(query, inputs) # type: ignore - async def get_deduplication_estimate( + ####################### PRIVATE METHODS ########################## + + async def _add_objects( self, - collection_id: UUID, - kg_deduplication_settings: KGEntityDeduplicationSettings, - ): - try: - # number of documents in collection - query = f""" - SELECT name, count(name) - FROM {self._get_table_name("document_entity")} - WHERE document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) - ) - GROUP BY name - HAVING count(name) >= 5 - """ - entities = await self.connection_manager.fetch_query( - query, [collection_id] - ) - num_entities = len(entities) + objects: list[Any], + table_name: str, + conflict_columns: list[str] = [], + ) -> asyncpg.Record: + """ + Upsert objects into the specified table. 
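+
+        Builds a single parameterized INSERT from the non-null attributes of
+        the first object; when conflict_columns is given, an ON CONFLICT ...
+        DO UPDATE clause turns the insert into an upsert. Dict values are
+        JSON-encoded before being passed to execute_many.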
+ """ + # Get non-null attributes from the first object + non_null_attrs = {k: v for k, v in objects[0].items() if v is not None} + columns = ", ".join(non_null_attrs.keys()) - estimated_llm_calls = (num_entities, num_entities) - estimated_total_in_out_tokens_in_millions = ( - estimated_llm_calls[0] * 1000 / 1000000, - estimated_llm_calls[1] * 5000 / 1000000, - ) - estimated_cost_in_usd = ( - estimated_total_in_out_tokens_in_millions[0] - * llm_cost_per_million_tokens( - kg_deduplication_settings.generation_config.model - ), - estimated_total_in_out_tokens_in_millions[1] - * llm_cost_per_million_tokens( - kg_deduplication_settings.generation_config.model - ), - ) + placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs))) - estimated_total_time_in_minutes = ( - estimated_total_in_out_tokens_in_millions[0] * 10 / 60, - estimated_total_in_out_tokens_in_millions[1] * 10 / 60, + if conflict_columns: + conflict_columns_str = ", ".join(conflict_columns) + replace_columns_str = ", ".join( + f"{column} = EXCLUDED.{column}" for column in non_null_attrs ) + on_conflict_query = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {replace_columns_str}" + else: + on_conflict_query = "" - return KGDeduplicationEstimationResponse( - message='Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the Deduplication process, run `deduplicate-entities` with `--run` in the cli, or `run_type="run"` in the client.', - num_entities=num_entities, - estimated_llm_calls=self._get_str_estimation_output( - estimated_llm_calls - ), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( - estimated_total_in_out_tokens_in_millions - ), - estimated_cost_in_usd=self._get_str_estimation_output( - estimated_cost_in_usd - ), - estimated_total_time_in_minutes=self._get_str_estimation_output( - estimated_total_time_in_minutes - ), - ) - except UndefinedTableError as e: - logger.error( - f"Entity embedding table not found. Please run `create-graph` first. {str(e)}" - ) - raise R2RException( - message="Entity embedding table not found. 
Please run `create-graph` first.", - status_code=404, - ) - except PostgresError as e: - logger.error( - f"Database error in get_deduplication_estimate: {str(e)}" - ) - raise HTTPException( - status_code=500, - detail="An error occurred while fetching the deduplication estimate.", - ) - except Exception as e: - logger.error( - f"Unexpected error in get_deduplication_estimate: {str(e)}" - ) - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while fetching the deduplication estimate.", + QUERY = f""" + INSERT INTO {self._get_table_name(table_name)} ({columns}) + VALUES ({placeholders}) + {on_conflict_query} + """ + + # Filter out null values for each object + params = [ + tuple( + (json.dumps(v) if isinstance(v, dict) else v) + for v in obj.values() + if v is not None ) + for obj in objects + ] + + return await self.connection_manager.execute_many(QUERY, params) # type: ignore diff --git a/py/core/providers/database/kg_tmp/__init__.py b/py/core/providers/database/kg_tmp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/providers/database/kg_tmp/community.py b/py/core/providers/database/kg_tmp/community.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/providers/database/kg_tmp/community_info.py b/py/core/providers/database/kg_tmp/community_info.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/providers/database/kg_tmp/entity.py b/py/core/providers/database/kg_tmp/entity.py new file mode 100644 index 000000000..bbd077427 --- /dev/null +++ b/py/core/providers/database/kg_tmp/entity.py @@ -0,0 +1,8 @@ +from core.base.providers.database import Handler +from core.providers.database.kg import PostgresConnectionManager + +class PostgresEntityHandler(Handler): + def __init__(self, project_name: str, connection_manager: PostgresConnectionManager): + super().__init__(project_name, connection_manager) + + diff --git a/py/core/providers/database/kg_tmp/graph.py b/py/core/providers/database/kg_tmp/graph.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/providers/database/kg_tmp/main.py b/py/core/providers/database/kg_tmp/main.py new file mode 100644 index 000000000..3540403e7 --- /dev/null +++ b/py/core/providers/database/kg_tmp/main.py @@ -0,0 +1,52 @@ +# import json +# import logging +# import time +# from typing import Any, AsyncGenerator, Optional, Tuple +# from uuid import UUID +# from fastapi import HTTPException + +# import asyncpg +# from asyncpg.exceptions import PostgresError, UndefinedTableError + +# from core.base import ( +# Community, +# Entity, +# KGExtraction, +# KGExtractionStatus, +# GraphHandler, +# R2RException, +# Relationship, +# ) +# from core.base.abstractions import ( +# CommunityInfo, +# EntityLevel, +# KGCreationSettings, +# KGEnrichmentSettings, +# KGEnrichmentStatus, +# KGEntityDeduplicationSettings, +# VectorQuantizationType, +# ) + +# from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens + +# from .base import PostgresConnectionManager +# from .collection import PostgresCollectionHandler +# from .entity import PostgresEntityHandler +# from .relationship import PostgresRelationshipHandler +# from .community import PostgresCommunityHandler +# from .graph import PostgresGraphHandler + +# logger = logging.getLogger() + + +# class PostgresGraphHandler(GraphHandler): +# """Handler for Knowledge Graph METHODS in PostgreSQL.""" + +# entity_handler: PostgresEntityHandler +# relationship_handler: PostgresRelationshipHandler +# 
community_handler: PostgresCommunityHandler +# graph_handler: PostgresGraphHandler + +# def __init__(self, project_name: str, connection_manager: PostgresConnectionManager): +# super().__init__(project_name, connection_manager) + diff --git a/py/core/providers/database/kg_tmp/relationship.py b/py/core/providers/database/kg_tmp/relationship.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index dfbf3541f..5ba475c96 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -15,7 +15,7 @@ from core.providers.database.collection import PostgresCollectionHandler from core.providers.database.document import PostgresDocumentHandler from core.providers.database.file import PostgresFileHandler -from core.providers.database.kg import PostgresKGHandler +from core.providers.database.kg import PostgresGraphHandler from core.providers.database.logging import PostgresLoggingHandler from core.providers.database.prompt import PostgresPromptHandler from core.providers.database.tokens import PostgresTokenHandler @@ -62,7 +62,7 @@ class PostgresDBProvider(DatabaseProvider): token_handler: PostgresTokenHandler user_handler: PostgresUserHandler vector_handler: PostgresVectorHandler - kg_handler: PostgresKGHandler + graph_handler: PostgresGraphHandler prompt_handler: PostgresPromptHandler file_handler: PostgresFileHandler logging_handler: PostgresLoggingHandler @@ -154,13 +154,15 @@ def __init__( self.quantization_type, self.enable_fts, ) - self.kg_handler = PostgresKGHandler( + + self.graph_handler = PostgresGraphHandler( self.project_name, self.connection_manager, self.collection_handler, self.dimension, self.quantization_type, ) + self.prompt_handler = PostgresPromptHandler( self.project_name, self.connection_manager ) @@ -197,7 +199,7 @@ async def initialize(self): await self.vector_handler.create_tables() await self.prompt_handler.create_tables() await self.file_handler.create_tables() - await self.kg_handler.create_tables() + await self.graph_handler.create_tables() await self.logging_handler.create_tables() def _get_postgres_configuration_settings( diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index b5a32648e..c79b9a999 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -5,6 +5,7 @@ from enum import Enum from typing import Any, Optional, Union from uuid import UUID +from datetime import datetime from pydantic import BaseModel @@ -57,7 +58,7 @@ class Entity(R2RSerializable): """An entity extracted from a document.""" name: str - id: Optional[int] = None + id: Optional[Union[int, UUID]] = None level: Optional[EntityLevel] = None category: Optional[str] = None description: Optional[str] = None @@ -92,7 +93,7 @@ def __init__(self, **kwargs): class Relationship(R2RSerializable): """A relationship between two entities. 
This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" - id: Optional[int] = None + id: Optional[Union[int, UUID]] = None subject: str """The source entity name.""" @@ -189,6 +190,9 @@ def from_dict(cls, d: dict[str, Any]) -> "CommunityInfo": @dataclass class Community(BaseModel): + + id: Optional[Union[int, UUID]] = None + """Defines an LLM-generated summary report of a community.""" community_number: int @@ -256,6 +260,34 @@ def from_dict( ) +class Graph(BaseModel): + """A graph in the system.""" + + id: uuid.UUID + status: str + created_at: datetime + updated_at: datetime + document_ids: list[uuid.UUID] = [] + collection_ids: list[uuid.UUID] = [] + attributes: dict[str, Any] = {} + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if isinstance(self.attributes, str): + self.attributes = json.loads(self.attributes) + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "Graph": + return Graph( + id=d["id"], + status=d["status"], + created_at=d["created_at"], + updated_at=d["updated_at"], + document_ids=d["document_ids"], + collection_ids=d["collection_ids"], + attributes=d["attributes"], + ) + class KGExtraction(R2RSerializable): """An extraction from a document that is part of a knowledge graph.""" @@ -263,3 +295,6 @@ class KGExtraction(R2RSerializable): document_id: uuid.UUID entities: list[Entity] relationships: list[Relationship] + + + From 110e87d4e281f37dd06d6a0c679e7d18a64274c4 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 13 Nov 2024 16:40:32 -0800 Subject: [PATCH 10/21] checkin --- py/core/providers/database/kg.py | 404 ++++++++++++++++++++----- py/core/providers/database/postgres.py | 10 +- py/shared/abstractions/graph.py | 6 + 3 files changed, 337 insertions(+), 83 deletions(-) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 1059712a9..73a419603 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -39,14 +39,39 @@ class PostgresEntityHandler(EntityHandler): + """Handler for managing entities in PostgreSQL database. + + Provides methods for CRUD operations on entities at different levels (chunk, document, collection). + Handles creation of database tables and management of entity data. + """ def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + """Initialize the PostgresEntityHandler. + + Args: + *args: Variable length argument list + **kwargs: Arbitrary keyword arguments. Must include: + - dimension: Dimension size for vector embeddings + - quantization_type: Type of vector quantization to use + """ self.dimension = kwargs.get("dimension") self.quantization_type = kwargs.get("quantization_type") + super().__init__( + project_name=kwargs.get("project_name"), + connection_manager=kwargs.get("connection_manager") + ) + async def create_tables(self) -> None: + """Create the necessary database tables for storing entities. + Creates three tables: + - chunk_entity: For storing chunk-level entities + - document_entity: For storing document-level entities with embeddings + - collection_entity: For storing deduplicated collection-level entities + + Each table has appropriate columns and constraints for its level. 
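+
+        The embedding columns are sized from ``self.dimension`` and decorated
+        via ``_decorate_vector_type`` with ``self.quantization_type``; both are
+        provider configuration values (e.g. dimension=512 is one possible
+        setting), so the exact column types vary by deployment.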
+ """ vector_column_str = _decorate_vector_type( f"({self.dimension})", self.quantization_type ) @@ -95,19 +120,243 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(query) + async def create(self, entities: list[Entity]) -> None: + """Create new entities in the database. + + Args: + entities: List of Entity objects to create. All entities must be of the same level. + + Raises: + ValueError: If entity level is not set or if entities have different levels. + """ + # assert that all entities are of the same level + entity_level = entities[0].level + if entity_level is None: + raise ValueError("Entity level is not set") + + for entity in entities: + if entity.level != entity_level: + raise ValueError("All entities must be of the same level") + + return await self._add_objects( + entities, entity_level.table_name + ) + + async def get( + self, + level: EntityLevel, + id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + entity_categories: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: int = 0, + limit: int = -1, + ) -> list[Entity]: + """Retrieve entities from the database based on various filters. + + Args: + level: Level of entities to retrieve (chunk, document, or collection) + id: Optional UUID to filter by + entity_names: Optional list of entity names to filter by + entity_categories: Optional list of categories (only for chunk level) + attributes: Optional list of attributes to filter by + offset: Number of records to skip + limit: Maximum number of records to return (-1 for no limit) + + Returns: + List of matching Entity objects + + Raises: + ValueError: If entity_categories is used with non-chunk level entities + """ + params: list = [id] + + if level != EntityLevel.CHUNK and entity_categories: + raise ValueError( + "entity_categories are only supported for chunk level entities" + ) + + filter = { + EntityLevel.CHUNK: "chunk_ids = ANY($1)", + EntityLevel.DOCUMENT: "document_id = $1", + EntityLevel.COLLECTION: "collection_id = $1", + }[level] + + if entity_names: + filter += " AND name = ANY($2)" + params.append(entity_names) + + if entity_categories: + filter += " AND category = ANY($3)" + params.append(entity_categories) + + QUERY = f""" + SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} + OFFSET ${len(params)} LIMIT ${len(params) + 1} + """ + + params.extend([offset, limit]) + + output = await self.connection_manager.fetch_query(QUERY, params) + + if attributes: + output = [ + entity for entity in output if entity["name"] in attributes + ] + + return output + + + async def update(self, entity: Entity) -> None: + """Update an existing entity in the database. + + Args: + entity: Entity object containing updated data + + Raises: + R2RException: If the entity does not exist in the database + """ + table_name = entity.level.value + "_entity" + + # check if the entity already exists + QUERY = f""" + SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + """ + count = ( + await self.connection_manager.fetch_query( + QUERY, [entity.id, entity.collection_id] + ) + )[0]["count"] + + if count == 0: + raise R2RException("Entity does not exist", 404) + + await self._add_objects([entity], table_name) + + async def delete(self, entity_id: UUID, level: EntityLevel) -> None: + """Delete an entity from the database. 
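+
+        A minimal usage sketch (hypothetical handler instance and UUID):
+
+            await handler.delete(entity_id=some_entity_id, level=EntityLevel.CHUNK)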
+
+        Args:
+            entity_id: UUID of the entity to delete
+            level: Level of the entity (chunk, document, or collection)
+        """
+        table_name = level.value + "_entity"
+        QUERY = f"""
+            DELETE FROM {self._get_table_name(table_name)} WHERE id = $1
+        """
+        await self.connection_manager.execute_query(
+            QUERY, [entity_id]
+        )
+
 class PostgresRelationshipHandler(RelationshipHandler):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
+        self.project_name = kwargs.get("project_name")
+        self.connection_manager = kwargs.get("connection_manager")

     async def create_tables(self) -> None:
-        pass
+        """Create the relationships table if it doesn't exist."""
+        QUERY = f"""
+            CREATE TABLE IF NOT EXISTS {self._get_table_name("relationship")} (
+                id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                subject TEXT NOT NULL,
+                predicate TEXT NOT NULL,
+                object TEXT NOT NULL,
+                subject_id UUID,
+                object_id UUID,
+                weight FLOAT DEFAULT 1.0,
+                description TEXT,
+                predicate_embedding FLOAT[],
+                extraction_ids UUID[],
+                document_id UUID,
+                attributes JSONB DEFAULT '{{}}'::jsonb,
+                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
+                updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
+            );
+            CREATE INDEX IF NOT EXISTS relationship_subject_idx ON {self._get_table_name("relationship")} (subject);
+            CREATE INDEX IF NOT EXISTS relationship_object_idx ON {self._get_table_name("relationship")} (object);
+            CREATE INDEX IF NOT EXISTS relationship_predicate_idx ON {self._get_table_name("relationship")} (predicate);
+            CREATE INDEX IF NOT EXISTS relationship_document_id_idx ON {self._get_table_name("relationship")} (document_id);
+        """
+        await self.connection_manager.execute_query(QUERY)
+
+    def _get_table_name(self, table: str) -> str:
+        """Get the fully qualified table name."""
+        return f'"{self.project_name}"."{table}"'
+
+    async def create(self, relationships: list[Relationship]) -> None:
+        """Create a new relationship in the database."""
+        await self._add_objects(relationships, "relationship")
+
+    async def get(self, relationship_id: UUID) -> list[Relationship]:
+        """Get relationships from storage by ID."""
+        QUERY = f"""
+            SELECT * FROM {self._get_table_name("relationship")}
+            WHERE id = $1
+        """
+        rows = await self.connection_manager.fetch_query(QUERY, [relationship_id])
+        return [Relationship(**row) for row in rows]
+
+    async def update(self, relationship: Relationship) -> None:
+        # check if the relationship already exists
+        QUERY = f"""
+            SELECT COUNT(*) FROM {self._get_table_name("relationship")} WHERE id = $1
+        """
+        count = (await self.connection_manager.fetch_query(QUERY, [relationship.id]))[0]["count"]
+        if count == 0:
+            raise R2RException("Relationship does not exist", 404)
+        # upsert on the primary key column
+        return await self._add_objects([relationship], "relationship", ["id"])
+
+    async def delete(self, relationship_id: UUID) -> None:
+        """Delete a relationship from the database."""
+        QUERY = f"""
+            DELETE FROM {self._get_table_name("relationship")}
+            WHERE id = $1
+        """
+        await self.connection_manager.execute_query(QUERY, [relationship_id])

 class PostgresCommunityHandler(CommunityHandler):
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.project_name = kwargs.get("project_name")
+        self.connection_manager = kwargs.get("connection_manager")
+
+    async def create_tables(self) -> None:
+        pass
+
+    async def create(self, communities: list[Community]) -> None:
+        pass
+
+    async def get(self, community_id: UUID) -> list[Community]:
+        pass
+
+    async def update(self,
community: Community) -> None: + pass + + async def delete(self, community_id: UUID) -> None: + pass class PostgresCommunityInfoHandler(CommunityInfoHandler): - pass + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.project_name = kwargs.get("project_name") + self.connection_manager = kwargs.get("connection_manager") + + async def create_tables(self) -> None: + pass + + async def create(self, community_infos: list[CommunityInfo]) -> None: + pass + + async def get(self, community_info_id: UUID) -> list[CommunityInfo]: + pass + + async def update(self, community_info: CommunityInfo) -> None: + pass + + async def delete(self, community_info_id: UUID) -> None: + pass @@ -122,8 +371,16 @@ class PostgresCommunityInfoHandler(CommunityInfoHandler): class PostgresGraphHandler(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) + def __init__( + self, + # project_name: str, + # connection_manager: PostgresConnectionManager, + # collection_handler: PostgresCollectionHandler, + # dimension: int, + # quantization_type: VectorQuantizationType, + *args: Any, + **kwargs: Any, + ) -> None: self.project_name = kwargs.get("project_name") self.connection_manager = kwargs.get("connection_manager") @@ -159,6 +416,7 @@ async def create_tables(self) -> None: await self.connection_manager.execute_query(QUERY) for handler in self.handlers: + print(f"Creating tables for {handler.__class__.__name__}") await handler.create_tables() async def create(self, graph: Graph) -> None: @@ -212,15 +470,6 @@ async def remove_collection(self, graph_id: UUID, collection_id: UUID) -> None: await self.connection_manager.execute_query(QUERY, graph_id, collection_id) - - - - - - - - - class PostgresGraphHandler_v1(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL.""" @@ -361,53 +610,53 @@ async def create_tables(self): ################### ENTITY METHODS ################### - async def get_entities_v3( - self, - level: EntityLevel, - id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - entity_categories: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - offset: int = 0, - limit: int = -1, - ): + # async def get_entities_v3( + # self, + # level: EntityLevel, + # id: Optional[UUID] = None, + # entity_names: Optional[list[str]] = None, + # entity_categories: Optional[list[str]] = None, + # attributes: Optional[list[str]] = None, + # offset: int = 0, + # limit: int = -1, + # ): - params: list = [id] + # params: list = [id] - if level != EntityLevel.CHUNK and entity_categories: - raise ValueError( - "entity_categories are only supported for chunk level entities" - ) + # if level != EntityLevel.CHUNK and entity_categories: + # raise ValueError( + # "entity_categories are only supported for chunk level entities" + # ) - filter = { - EntityLevel.CHUNK: "chunk_ids = ANY($1)", - EntityLevel.DOCUMENT: "document_id = $1", - EntityLevel.COLLECTION: "collection_id = $1", - }[level] + # filter = { + # EntityLevel.CHUNK: "chunk_ids = ANY($1)", + # EntityLevel.DOCUMENT: "document_id = $1", + # EntityLevel.COLLECTION: "collection_id = $1", + # }[level] - if entity_names: - filter += " AND name = ANY($2)" - params.append(entity_names) + # if entity_names: + # filter += " AND name = ANY($2)" + # params.append(entity_names) - if entity_categories: - filter += " AND category = ANY($3)" - params.append(entity_categories) + # if entity_categories: + # filter += " AND 
category = ANY($3)" + # params.append(entity_categories) - QUERY = f""" - SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} - OFFSET ${len(params)} LIMIT ${len(params) + 1} - """ + # QUERY = f""" + # SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} + # OFFSET ${len(params)} LIMIT ${len(params) + 1} + # """ - params.extend([offset, limit]) + # params.extend([offset, limit]) - output = await self.connection_manager.fetch_query(QUERY, params) + # output = await self.connection_manager.fetch_query(QUERY, params) - if attributes: - output = [ - entity for entity in output if entity["name"] in attributes - ] + # if attributes: + # output = [ + # entity for entity in output if entity["name"] in attributes + # ] - return output + # return output # TODO: deprecate this async def get_entities( @@ -522,33 +771,33 @@ async def create_entities_v3( # TODO: check if already exists await self._add_objects(entities, level.table_name) - async def update_entity(self, collection_id: UUID, entity: Entity) -> None: - table_name = entity.level.value + "_entity" + # async def update_entity(self, collection_id: UUID, entity: Entity) -> None: + # table_name = entity.level.value + "_entity" - # check if the entity already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY, [entity.id, collection_id] - ) - )[0]["count"] + # # check if the entity already exists + # QUERY = f""" + # SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + # """ + # count = ( + # await self.connection_manager.fetch_query( + # QUERY, [entity.id, collection_id] + # ) + # )[0]["count"] - if count == 0: - raise R2RException("Entity does not exist", 404) + # if count == 0: + # raise R2RException("Entity does not exist", 404) - await self._add_objects([entity], table_name) + # await self._add_objects([entity], table_name) - async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: + # async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: - table_name = entity.level.value + "_entity" - QUERY = f""" - DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - """ - await self.connection_manager.execute_query( - QUERY, [entity.id, collection_id] - ) + # table_name = entity.level.value + "_entity" + # QUERY = f""" + # DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + # """ + # await self.connection_manager.execute_query( + # QUERY, [entity.id, collection_id] + # ) async def delete_node_via_document_id( @@ -2028,9 +2277,8 @@ async def _add_objects( # Filter out null values for each object params = [ tuple( - (json.dumps(v) if isinstance(v, dict) else v) - for v in obj.values() - if v is not None + (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v) + for k, v in ((k, v) for k, v in obj.items() if v is not None) ) for obj in objects ] diff --git a/py/core/providers/database/postgres.py b/py/core/providers/database/postgres.py index 5ba475c96..d63a9e2f9 100644 --- a/py/core/providers/database/postgres.py +++ b/py/core/providers/database/postgres.py @@ -156,11 +156,11 @@ def __init__( ) self.graph_handler = PostgresGraphHandler( - self.project_name, - self.connection_manager, - self.collection_handler, - self.dimension, - self.quantization_type, + project_name=self.project_name, + 
connection_manager=self.connection_manager, + collection_handler=self.collection_handler, + dimension=self.dimension, + quantization_type=self.quantization_type, ) self.prompt_handler = PostgresPromptHandler( diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index c79b9a999..7fb8c2894 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -101,6 +101,12 @@ class Relationship(R2RSerializable): predicate: str """A description of the relationship (optional).""" + subject_id: UUID | None = None + """The source entity ID (optional).""" + + object_id: UUID | None = None + """The target entity ID (optional).""" + object: str """The target entity name.""" From 257bda8ec8bb3906212a7a70905fd1d57e86557e Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Wed, 13 Nov 2024 18:30:30 -0800 Subject: [PATCH 11/21] up --- py/core/base/providers/database.py | 6 +- py/core/main/api/v3/graph_router.py | 103 ++++++++- py/core/main/services/kg_service.py | 39 +++- py/core/pipes/kg/deduplication.py | 20 +- py/core/pipes/kg/entity_description.py | 10 +- py/core/pipes/kg/relationships_extraction.py | 14 +- py/core/pipes/kg/storage.py | 12 +- py/core/pipes/retrieval/kg_search_pipe.py | 8 +- py/core/providers/database/kg.py | 216 +++++++++++------- py/core/providers/database/kg_tmp/main.py | 4 +- py/shared/abstractions/graph.py | 17 +- py/shared/abstractions/search.py | 4 +- .../pipes/test_kg_community_summary_pipe.py | 20 +- .../ingestion/test_contextual_embedding.py | 10 +- py/tests/core/providers/kg/test_kg_logic.py | 26 +-- 15 files changed, 338 insertions(+), 171 deletions(-) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index ecb9ed497..328433ce6 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -856,7 +856,7 @@ async def list_chunks( # raise NotImplementedError # @abstractmethod -# async def get_existing_entity_extraction_ids( +# async def get_existing_entity_chunk_ids( # self, document_id: UUID # ) -> list[str]: # """Get existing entity extraction IDs.""" @@ -1873,10 +1873,10 @@ async def update_kg_search_prompt(self) -> None: async def upsert_relationships(self) -> None: return await self.graph_handler.upsert_relationships() - async def get_existing_entity_extraction_ids( + async def get_existing_entity_chunk_ids( self, document_id: UUID ) -> list[str]: - return await self.graph_handler.get_existing_entity_extraction_ids( + return await self.graph_handler.get_existing_entity_chunk_ids( document_id ) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 59e224d83..e4d48678b 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -108,6 +108,7 @@ def _get_path_level(self, request: Request) -> EntityLevel: return EntityLevel.COLLECTION def _setup_routes(self): + ##### ENTITIES ###### @self.router.get( "/chunks/{id}/entities", @@ -216,7 +217,7 @@ async def list_entities( "Only superusers can access this endpoint.", 403 ) - return await self.services["kg"].list_entities_v3( + entities, count = await self.services["kg"].list_entities_v3( level=self._get_path_level(request), id=id, offset=offset, @@ -226,6 +227,10 @@ async def list_entities( attributes=attributes, ) + return entities, { + "total_entries": count, + } + @self.router.post( "/chunks/{id}/entities", summary="Create entities for a chunk", @@ -293,7 +298,7 @@ async def list_entities( async def create_entities_v3( request: Request, id: 
UUID = Path( - ..., description="The ID of the chunk to create entities for." + ..., description="The ID of the object to create entities for." ), entities: list[Entity] = Body( ..., description="The entities to create." @@ -305,18 +310,54 @@ async def create_entities_v3( "Only superusers can access this endpoint.", 403 ) - # for each entity, set the level to CHUNK + # get entity level from path + path = request.url.path + if "/chunks/" in path: + level = EntityLevel.CHUNK + elif "/documents/" in path: + level = EntityLevel.DOCUMENT + else: + level = EntityLevel.COLLECTION + + # set entity level if not set for entity in entities: - if entity.level is None: - entity.level = EntityLevel.CHUNK + if entity.level: + if entity.level != level: + raise R2RException( + "Entity level must match the path level.", 400 + ) else: - raise R2RException( - "Entity level must be chunk or empty.", 400 - ) + entity.level = level + + # depending on the level, perform validation + if level == EntityLevel.CHUNK: + for entity in entities: + if entity.chunk_ids and id not in entity.chunk_ids: + raise R2RException( + "Entity extraction IDs must include the chunk ID or should be empty.", 400 + ) + + elif level == EntityLevel.DOCUMENT: + for entity in entities: + if entity.document_id: + if entity.document_id != id: + raise R2RException( + "Entity document IDs must match the document ID or should be empty.", 400 + ) + else: + entity.document_id = id + + elif level == EntityLevel.COLLECTION: + for entity in entities: + if entity.collection_id: + if entity.collection_id != id: + raise R2RException( + "Entity collection IDs must match the collection ID or should be empty.", 400 + ) + else: + entity.collection_id = id return await self.services["kg"].create_entities_v3( - level=self._get_path_level(request), - id=id, entities=entities, ) @@ -341,6 +382,48 @@ async def create_entities_v3( ] }, ) + @self.router.post( + "/documents/{id}/entities/{entity_id}", + summary="Update an entity for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + """ + ), + }, + ] + }, + ) + @self.router.post( + "/collections/{id}/entities/{entity_id}", + summary="Update an entity for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
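+
+                            # 'entity' is a hypothetical payload for this sketch;
+                            # field names mirror the Entity model (name, description)
+                            entity = {"name": "Aristotle", "description": "Greek philosopher"}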
+ + result = client.collections.update_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + """ + ), + }, + ] + }, + ) @self.base_endpoint async def update_entity( request: Request, diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 53e97f3f9..ef3e9c445 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -58,6 +58,8 @@ def __init__( logging_connection, ) + ################### EXTRACTION ################### + @telemetry_event("kg_relationships_extraction") async def kg_relationships_extraction( self, @@ -118,16 +120,45 @@ async def kg_relationships_extraction( return await _collect_results(result_gen) + ################### ENTITIES ################### + @telemetry_event("create_entities_v3") async def create_entities_v3( + self, + entities: list[Entity], + **kwargs, + ): + return await self.providers.database.graph_handler.entities.create( + entities, **kwargs + ) + + @telemetry_event("list_entities_v3") + async def list_entities_v3( self, level: EntityLevel, id: UUID, - entities: list[Entity], + entity_names: Optional[list[str]] = None, + entity_categories: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.create_entities_v3( - level, id, entities, **kwargs + return await self.providers.database.graph_handler.entities.get( + level, id, entity_names, entity_categories, attributes, offset, limit + ) + + @telemetry_event("update_entity_v3") + async def update_entity_v3( + self, + level: EntityLevel, + id: UUID, + entity_id: UUID, + entity: Entity, + **kwargs, + ): + return await self.providers.database.graph_handler.entities.update( + level, id, entity_id, entity ) @telemetry_event("update_entity") @@ -152,6 +183,8 @@ async def delete_entity( collection_id, entity ) + ################### RELATIONSHIPS ################### + @telemetry_event("create_relationship") async def create_relationship( self, diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index f7ff24f75..4809919d6 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -73,12 +73,12 @@ async def kg_named_entity_deduplication( # deduplicate entities by name deduplicated_entities: dict[str, dict[str, list[str]]] = {} deduplication_source_keys = [ - "extraction_ids", + "chunk_ids", "document_id", "attributes", ] deduplication_target_keys = [ - "extraction_ids", + "chunk_ids", "document_ids", "attributes", ] @@ -110,7 +110,7 @@ async def kg_named_entity_deduplication( Entity( name=name, collection_id=collection_id, - extraction_ids=entity["extraction_ids"], + chunk_ids=entity["chunk_ids"], document_ids=entity["document_ids"], attributes={}, ) @@ -163,12 +163,12 @@ async def kg_description_entity_deduplication( ) deduplication_source_keys = [ - "extraction_ids", + "chunk_ids", "document_id", "attributes", ] deduplication_target_keys = [ - "extraction_ids", + "chunk_ids", "document_ids", "attributes", ] @@ -219,15 +219,15 @@ async def kg_description_entity_deduplication( description = "\n".join(descriptions[:5]) # Collect all extraction IDs from entities in the cluster - extraction_ids = set() + chunk_ids = set() document_ids = set() for entity in entities: - if entity.extraction_ids: - extraction_ids.update(entity.extraction_ids) + if entity.chunk_ids: + 
chunk_ids.update(entity.chunk_ids) if entity.document_id: document_ids.add(entity.document_id) - extraction_ids_list = list(extraction_ids) + chunk_ids_list = list(chunk_ids) document_ids_list = list(document_ids) deduplicated_entities_list.append( @@ -235,7 +235,7 @@ async def kg_description_entity_deduplication( name=longest_name, description=description, collection_id=collection_id, - extraction_ids=extraction_ids_list, + chunk_ids=chunk_ids_list, document_ids=document_ids_list, attributes={ "aliases": list(aliases), diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 475685fb6..83066b467 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -86,14 +86,14 @@ async def process_entity( ] # potentially slow at scale, but set to avoid duplicates - unique_extraction_ids = set() + unique_chunk_ids = set() for entity in entities: - for chunk_id in entity.extraction_ids: - unique_extraction_ids.add(chunk_id) + for chunk_id in entity.chunk_ids: + unique_chunk_ids.add(chunk_id) out_entity = Entity( name=entities[0].name, - extraction_ids=list(unique_extraction_ids), + chunk_ids=list(unique_chunk_ids), document_ids=[document_id], ) @@ -137,7 +137,7 @@ async def process_entity( out_entity.name, out_entity.description, str(out_entity.description_embedding), - out_entity.extraction_ids, + out_entity.chunk_ids, document_id, ) ], diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index f7867df20..f2dba7af4 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -144,7 +144,7 @@ def parse_fn(response_str: str) -> Any: description=entity_description, name=entity_value, document_id=extractions[0].document_id, - extraction_ids=[ + chunk_ids=[ extraction.id for extraction in extractions ], attributes={}, @@ -168,7 +168,7 @@ def parse_fn(response_str: str) -> Any: description=description, weight=weight, document_id=extractions[0].document_id, - extraction_ids=[ + chunk_ids=[ extraction.id for extraction in extractions ], attributes={}, @@ -179,7 +179,7 @@ def parse_fn(response_str: str) -> Any: entities, relationships = parse_fn(kg_extraction) return KGExtraction( - extraction_ids=[ + chunk_ids=[ extraction.id for extraction in extractions ], document_id=extractions[0].document_id, @@ -208,7 +208,7 @@ def parse_fn(response_str: str) -> Any: ) return KGExtraction( - extraction_ids=[extraction.id for extraction in extractions], + chunk_ids=[extraction.id for extraction in extractions], document_id=extractions[0].document_id, entities=[], relationships=[], @@ -268,16 +268,16 @@ async def _run_logic( # type: ignore ) if filter_out_existing_chunks: - existing_extraction_ids = await self.database_provider.get_existing_entity_extraction_ids( + existing_chunk_ids = await self.database_provider.get_existing_entity_chunk_ids( document_id=document_id ) extractions = [ extraction for extraction in extractions - if extraction.id not in existing_extraction_ids + if extraction.id not in existing_chunk_ids ] logger.info( - f"Filtered out {len(existing_extraction_ids)} existing extractions, remaining {len(extractions)} extractions for document {document_id}" + f"Filtered out {len(existing_chunk_ids)} existing extractions, remaining {len(extractions)} extractions for document {document_id}" ) if len(extractions) == 0: diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index aa572b680..8e846d5dd 100644 --- 
a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -66,10 +66,10 @@ async def store( ) if extraction.entities: - if not extraction.entities[0].extraction_ids: + if not extraction.entities[0].chunk_ids: for i in range(len(extraction.entities)): - extraction.entities[i].extraction_ids = ( - extraction.extraction_ids + extraction.entities[i].chunk_ids = ( + extraction.chunk_ids ) extraction.entities[i].document_id = ( extraction.document_id @@ -80,10 +80,10 @@ async def store( ) if extraction.relationships: - if not extraction.relationships[0].extraction_ids: + if not extraction.relationships[0].chunk_ids: for i in range(len(extraction.relationships)): - extraction.relationships[i].extraction_ids = ( - extraction.extraction_ids + extraction.relationships[i].chunk_ids = ( + extraction.chunk_ids ) extraction.relationships[i].document_id = ( extraction.document_id diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index 60aa68df4..9c5f09c08 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -120,7 +120,7 @@ async def local_search( property_names=[ "name", "description", - "extraction_ids", + "chunk_ids", ], filters=kg_search_settings.filters, entities_level=kg_search_settings.entities_level, @@ -132,7 +132,7 @@ async def local_search( ), method=KGSearchMethod.LOCAL, result_type=KGSearchResultType.ENTITY, - extraction_ids=search_result["extraction_ids"], + chunk_ids=search_result["chunk_ids"], metadata={"associated_query": message}, ) @@ -149,7 +149,7 @@ async def local_search( # property_names=[ # "name", # "description", - # "extraction_ids", + # "chunk_ids", # "document_ids", # ], # ): @@ -160,7 +160,7 @@ async def local_search( # ), # method=KGSearchMethod.LOCAL, # result_type=KGSearchResultType.RELATIONSHIP, - # # extraction_ids=search_result["extraction_ids"], + # # chunk_ids=search_result["chunk_ids"], # # document_ids=search_result["document_ids"], # metadata={"associated_query": message}, # ) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 73a419603..6da3ace45 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -78,11 +78,12 @@ async def create_tables(self) -> None: query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL NOT NULL, category TEXT NOT NULL, name TEXT NOT NULL, description TEXT NOT NULL, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, document_id UUID NOT NULL, attributes JSONB ); @@ -92,10 +93,11 @@ async def create_tables(self) -> None: # embeddings tables query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("document_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL NOT NULL, name TEXT NOT NULL, description TEXT NOT NULL, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, description_embedding {vector_column_str} NOT NULL, document_id UUID NOT NULL, UNIQUE (name, document_id) @@ -107,10 +109,11 @@ async def create_tables(self) -> None: # deduplicated entities table query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("collection_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL NOT NULL, name TEXT NOT NULL, description TEXT, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, document_ids UUID[] NOT NULL, 
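                -- deduplicated, collection-level entities; document_ids lists every source document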
collection_id UUID NOT NULL, description_embedding {vector_column_str}, @@ -129,6 +132,8 @@ async def create(self, entities: list[Entity]) -> None: Raises: ValueError: If entity level is not set or if entities have different levels. """ + + # TODO: move this the router layer # assert that all entities are of the same level entity_level = entities[0].level if entity_level is None: @@ -138,8 +143,11 @@ async def create(self, entities: list[Entity]) -> None: if entity.level != entity_level: raise ValueError("All entities must be of the same level") - return await self._add_objects( - entities, entity_level.table_name + return await _add_objects( + objects=[entity.__dict__ for entity in entities], + full_table_name=self._get_table_name(entity_level + "_entity"), + connection_manager=self.connection_manager, + exclude_attributes=["level"] ) async def get( @@ -191,12 +199,15 @@ async def get( params.append(entity_categories) QUERY = f""" - SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} - OFFSET ${len(params)} LIMIT ${len(params) + 1} + SELECT * from {self._get_table_name(level + "_entity")} WHERE {filter} + OFFSET ${len(params)+1} LIMIT ${len(params) + 2} """ params.extend([offset, limit]) + print(QUERY) + print(params) + output = await self.connection_manager.fetch_query(QUERY, params) if attributes: @@ -204,7 +215,14 @@ async def get( entity for entity in output if entity["name"] in attributes ] - return output + output = [Entity(**entity) for entity in output] + + QUERY = f""" + SELECT COUNT(*) from {self._get_table_name(level + "_entity")} WHERE {filter} + """ + count = (await self.connection_manager.fetch_query(QUERY, params[:-2]))[0]["count"] + + return output, count async def update(self, entity: Entity) -> None: @@ -218,20 +236,45 @@ async def update(self, entity: Entity) -> None: """ table_name = entity.level.value + "_entity" + filter = "id = $1" + params = [entity.id] + if entity.level == EntityLevel.CHUNK: + filter += " AND chunk_ids = ANY($2)" + params.append(entity.chunk_ids) + elif entity.level == EntityLevel.DOCUMENT: + filter += " AND document_id = $2" + params.append(entity.document_id) + else: + filter += " AND collection_id = $2" + params.append(entity.collection_id) + # check if the entity already exists QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 + SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE {filter} """ count = ( await self.connection_manager.fetch_query( - QUERY, [entity.id, entity.collection_id] + QUERY, params ) )[0]["count"] - if count == 0: - raise R2RException("Entity does not exist", 404) + # don't override the chunk_ids + entity.chunk_ids = None + entity.level = None + + # get non null attributes + non_null_attributes = [k for k, v in entity.to_dict().items() if v is not None] - await self._add_objects([entity], table_name) + if count == 0: + raise R2RException("Entity does not exist", 204) + + await _add_objects( + objects=[entity], + full_table_name=self._get_table_name(table_name), + connection_manager=self.connection_manager, + conflict_columns=non_null_attributes, + exclude_attributes=["level"] + ) async def delete(self, entity_id: UUID, level: EntityLevel) -> None: """Delete an entity from the database. 
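
For orientation, a minimal sketch of how a caller consumes the reworked getter
above; the handler path and keyword style are assumed from the surrounding
hunks, and the (rows, count) return shape matches the two queries shown:

    # sketch only -- fetch one page of document-level entities plus the total count
    entities, total = await graph_handler.entities.get(
        level=EntityLevel.DOCUMENT,
        id=document_id,
        offset=0,
        limit=50,
    )
    print(f"fetched {len(entities)} of {total} entities")
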
@@ -267,7 +310,7 @@ async def create_tables(self) -> None: weight FLOAT DEFAULT 1.0, description TEXT, predicate_embedding FLOAT[], - extraction_ids UUID[], + chunk_ids UUID[], document_id UUID, attributes JSONB DEFAULT '{{}}'::jsonb, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, @@ -388,16 +431,16 @@ def __init__( self.quantization_type = kwargs.get("quantization_type") self.collection_handler = kwargs.get("collection_handler") - self.entity_handler = PostgresEntityHandler(*args, **kwargs) - self.relationship_handler = PostgresRelationshipHandler(*args, **kwargs) - self.community_handler = PostgresCommunityHandler(*args, **kwargs) - self.community_info_handler = PostgresCommunityInfoHandler(*args, **kwargs) + self.entities = PostgresEntityHandler(*args, **kwargs) + self.relationships = PostgresRelationshipHandler(*args, **kwargs) + self.communities = PostgresCommunityHandler(*args, **kwargs) + self.community_infos = PostgresCommunityInfoHandler(*args, **kwargs) self.handlers = [ - self.entity_handler, - self.relationship_handler, - self.community_handler, - self.community_info_handler, + self.entities, + self.relationships, + self.communities, + self.community_infos, ] async def create_tables(self) -> None: @@ -513,11 +556,12 @@ async def create_tables(self): query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, category TEXT NOT NULL, name TEXT NOT NULL, description TEXT NOT NULL, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, document_id UUID NOT NULL, attributes JSONB ); @@ -527,14 +571,15 @@ async def create_tables(self): # raw relationships table, also the final table. this will have embeddings. 
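        # note: uuid_generate_v4() in these tables comes from PostgreSQL's
        # uuid-ossp extension, so `CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`
        # is assumed to have been run already (not shown in this patch)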
query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_relationship")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, subject TEXT NOT NULL, predicate TEXT NOT NULL, object TEXT NOT NULL, weight FLOAT NOT NULL, description TEXT NOT NULL, embedding {vector_column_str}, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, document_id UUID NOT NULL, attributes JSONB NOT NULL ); @@ -544,10 +589,11 @@ async def create_tables(self): # embeddings tables query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("document_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, name TEXT NOT NULL, description TEXT NOT NULL, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, description_embedding {vector_column_str} NOT NULL, document_id UUID NOT NULL, UNIQUE (name, document_id) @@ -559,10 +605,11 @@ async def create_tables(self): # deduplicated entities table query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("collection_entity")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, name TEXT NOT NULL, description TEXT, - extraction_ids UUID[] NOT NULL, + chunk_ids UUID[] NOT NULL, document_ids UUID[] NOT NULL, collection_id UUID NOT NULL, description_embedding {vector_column_str}, @@ -575,7 +622,8 @@ async def create_tables(self): # communities table, result of the Leiden algorithm query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, node TEXT NOT NULL, cluster INT NOT NULL, parent_cluster INT, @@ -590,7 +638,8 @@ async def create_tables(self): # communities_report table query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} ( - id SERIAL PRIMARY KEY, + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, community_number INT NOT NULL, collection_id UUID NOT NULL, level INT NOT NULL, @@ -698,7 +747,7 @@ async def get_entities( if entity_table_name == "collection_entity": query = f""" - SELECT id, name, description, extraction_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT id, name, description, chunk_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} @@ -707,7 +756,7 @@ async def get_entities( """ else: query = f""" - SELECT id, name, description, extraction_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT id, name, description, chunk_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE document_id = ANY( SELECT document_id FROM {self._get_table_name("document_info")} @@ -747,9 +796,9 @@ async def add_entities( cleaned_entities = [] for entity in entities: entity_dict = entity.to_dict() - entity_dict["extraction_ids"] = ( - entity_dict["extraction_ids"] - if entity_dict.get("extraction_ids") + entity_dict["chunk_ids"] = ( + entity_dict["chunk_ids"] + if entity_dict.get("chunk_ids") else [] ) entity_dict["description_embedding"] = ( @@ -1385,11 +1434,11 @@ async def get_entity_map( LIMIT {limit} OFFSET {offset} ) SELECT e.name, e.description, e.category, - (SELECT array_agg(DISTINCT x) FROM unnest(e.extraction_ids) x) 
AS extraction_ids, + (SELECT array_agg(DISTINCT x) FROM unnest(e.chunk_ids) x) AS chunk_ids, e.document_id FROM {self._get_table_name("chunk_entity")} e JOIN entities_list el ON e.name = el.name - GROUP BY e.name, e.description, e.category, e.extraction_ids, e.document_id + GROUP BY e.name, e.description, e.category, e.chunk_ids, e.document_id ORDER BY e.name;""" entities_list = await self.connection_manager.fetch_query( @@ -1400,7 +1449,7 @@ async def get_entity_map( name=entity["name"], description=entity["description"], category=entity["category"], - extraction_ids=entity["extraction_ids"], + chunk_ids=entity["chunk_ids"], document_id=entity["document_id"], ) for entity in entities_list @@ -1417,7 +1466,7 @@ async def get_entity_map( ) SELECT DISTINCT t.subject, t.predicate, t.object, t.weight, t.description, - (SELECT array_agg(DISTINCT x) FROM unnest(t.extraction_ids) x) AS extraction_ids, t.document_id + (SELECT array_agg(DISTINCT x) FROM unnest(t.chunk_ids) x) AS chunk_ids, t.document_id FROM {self._get_table_name("chunk_relationship")} t JOIN entities_list el ON t.subject = el.name ORDER BY t.subject, t.predicate, t.object; @@ -1433,7 +1482,7 @@ async def get_entity_map( object=relationship["object"], weight=relationship["weight"], description=relationship["description"], - extraction_ids=relationship["extraction_ids"], + chunk_ids=relationship["chunk_ids"], document_id=relationship["document_id"], ) for relationship in relationships_list @@ -2109,11 +2158,11 @@ def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: else: return " - ".join(f"{round(a, 2)}" for a in x) - async def get_existing_entity_extraction_ids( + async def get_existing_entity_chunk_ids( self, document_id: UUID ) -> list[str]: QUERY = f""" - SELECT DISTINCT unnest(extraction_ids) AS chunk_id FROM {self._get_table_name("chunk_entity")} WHERE document_id = $1 + SELECT DISTINCT unnest(chunk_ids) AS chunk_id FROM {self._get_table_name("chunk_entity")} WHERE document_id = $1 """ return [ item["chunk_id"] @@ -2244,43 +2293,44 @@ async def update_entity_descriptions(self, entities: list[Entity]): ####################### PRIVATE METHODS ########################## - async def _add_objects( - self, - objects: list[Any], - table_name: str, - conflict_columns: list[str] = [], - ) -> asyncpg.Record: - """ - Upsert objects into the specified table. - """ - # Get non-null attributes from the first object - non_null_attrs = {k: v for k, v in objects[0].items() if v is not None} - columns = ", ".join(non_null_attrs.keys()) - - placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs))) +async def _add_objects( + objects: list[Any], + full_table_name: str, + connection_manager: PostgresConnectionManager, + conflict_columns: list[str] = [], + exclude_attributes: list[str] = [], +) -> asyncpg.Record: + """ + Upsert objects into the specified table. 
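+
+    Column order and names come from the first object's non-null attributes, so
+    all objects are expected to share that shape. `conflict_columns` feeds the
+    ON CONFLICT clause, and `exclude_attributes` drops transient fields (such as
+    `level`) before serialization. With conflict_columns=["id"], the rendered
+    statement looks roughly like:
+
+        INSERT INTO <schema>.document_entity (id, name, description)
+        VALUES ($1, $2, $3)
+        ON CONFLICT (id) DO UPDATE SET
+            id = EXCLUDED.id, name = EXCLUDED.name, description = EXCLUDED.description
+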
+ """ + # Get non-null attributes from the first object + non_null_attrs = {k: v for k, v in objects[0].items() if v is not None and k not in exclude_attributes} + columns = ", ".join(non_null_attrs.keys()) - if conflict_columns: - conflict_columns_str = ", ".join(conflict_columns) - replace_columns_str = ", ".join( - f"{column} = EXCLUDED.{column}" for column in non_null_attrs - ) - on_conflict_query = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {replace_columns_str}" - else: - on_conflict_query = "" + placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs))) - QUERY = f""" - INSERT INTO {self._get_table_name(table_name)} ({columns}) - VALUES ({placeholders}) - {on_conflict_query} - """ + if conflict_columns: + conflict_columns_str = ", ".join(conflict_columns) + replace_columns_str = ", ".join( + f"{column} = EXCLUDED.{column}" for column in non_null_attrs + ) + on_conflict_query = f"ON CONFLICT ({conflict_columns_str}) DO UPDATE SET {replace_columns_str}" + else: + on_conflict_query = "" + + QUERY = f""" + INSERT INTO {full_table_name} ({columns}) + VALUES ({placeholders}) + {on_conflict_query} + """ - # Filter out null values for each object - params = [ - tuple( - (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v) - for k, v in ((k, v) for k, v in obj.items() if v is not None) - ) - for obj in objects - ] + # Filter out null values for each object + params = [ + tuple( + (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v) + for k, v in ((k, v) for k, v in obj.items() if v is not None and k not in exclude_attributes) + ) + for obj in objects + ] - return await self.connection_manager.execute_many(QUERY, params) # type: ignore + return await connection_manager.execute_many(QUERY, params) # type: ignore diff --git a/py/core/providers/database/kg_tmp/main.py b/py/core/providers/database/kg_tmp/main.py index 3540403e7..ef209ab3d 100644 --- a/py/core/providers/database/kg_tmp/main.py +++ b/py/core/providers/database/kg_tmp/main.py @@ -13,7 +13,7 @@ # Entity, # KGExtraction, # KGExtractionStatus, -# GraphHandler, +# KGHandler, # R2RException, # Relationship, # ) @@ -39,7 +39,7 @@ # logger = logging.getLogger() -# class PostgresGraphHandler(GraphHandler): +# class PostgresKGHandler(KGHandler): # """Handler for Knowledge Graph METHODS in PostgreSQL.""" # entity_handler: PostgresEntityHandler diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 7fb8c2894..3d3d4e08c 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -53,18 +53,18 @@ class EntityLevel(str, Enum): def __str__(self): return self.value - class Entity(R2RSerializable): """An entity extracted from a document.""" name: str - id: Optional[Union[int, UUID]] = None + id: Optional[UUID] = None + sid: Optional[int] = None #serial ID level: Optional[EntityLevel] = None category: Optional[str] = None description: Optional[str] = None description_embedding: Optional[Union[list[float], str]] = None community_numbers: Optional[list[str]] = None - extraction_ids: Optional[list[UUID]] = None + chunk_ids: Optional[list[UUID]] = None collection_id: Optional[UUID] = None document_id: Optional[UUID] = None document_ids: Optional[list[UUID]] = None @@ -93,7 +93,8 @@ def __init__(self, **kwargs): class Relationship(R2RSerializable): """A relationship between two entities. 
This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" - id: Optional[Union[int, UUID]] = None + id: Optional[UUID] = None + sid: Optional[int] = None #serial ID subject: str """The source entity name.""" @@ -119,7 +120,7 @@ class Relationship(R2RSerializable): predicate_embedding: list[float] | None = None """The semantic embedding for the relationship description (optional).""" - extraction_ids: list[UUID] = [] + chunk_ids: list[UUID] = [] """List of text unit IDs in which the relationship appears (optional).""" document_id: UUID | None = None @@ -147,7 +148,7 @@ def from_dict( # type: ignore predicate_key: str = "predicate", description_key: str = "description", weight_key: str = "weight", - extraction_ids_key: str = "extraction_ids", + chunk_ids_key: str = "chunk_ids", document_id_key: str = "document_id", attributes_key: str = "attributes", ) -> "Relationship": @@ -161,7 +162,7 @@ def from_dict( # type: ignore predicate=d.get(predicate_key), description=d.get(description_key), weight=d.get(weight_key, 1.0), - extraction_ids=d.get(extraction_ids_key), + chunk_ids=d.get(chunk_ids_key), document_id=d.get(document_id_key), attributes=d.get(attributes_key, {}), ) @@ -297,7 +298,7 @@ def from_dict(cls, d: dict[str, Any]) -> "Graph": class KGExtraction(R2RSerializable): """An extraction from a document that is part of a knowledge graph.""" - extraction_ids: list[uuid.UUID] + chunk_ids: list[uuid.UUID] document_id: uuid.UUID entities: list[Entity] relationships: list[Relationship] diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 1d3e704c6..78adeb5d9 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -129,7 +129,7 @@ class KGSearchResult(R2RSerializable): KGEntityResult, KGRelationshipResult, KGCommunityResult, KGGlobalResult ] result_type: Optional[KGSearchResultType] = None - extraction_ids: Optional[list[UUID]] = None + chunk_ids: Optional[list[UUID]] = None metadata: dict[str, Any] = {} class Config: @@ -137,7 +137,7 @@ class Config: "method": "local", "content": KGEntityResult.Config.json_schema_extra, "result_type": "entity", - "extraction_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], + "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"], "metadata": {"associated_query": "What is the capital of France?"}, } diff --git a/py/tests/core/pipes/test_kg_community_summary_pipe.py b/py/tests/core/pipes/test_kg_community_summary_pipe.py index 58c683ac9..f5e2e6497 100644 --- a/py/tests/core/pipes/test_kg_community_summary_pipe.py +++ b/py/tests/core/pipes/test_kg_community_summary_pipe.py @@ -53,7 +53,7 @@ def document_id(): @pytest.fixture(scope="function") -def extraction_ids(): +def chunk_ids(): return [ uuid.UUID("32ff6daf-6e67-44fa-b2a9-19384f5d9d19"), uuid.UUID("42ff6daf-6e67-44fa-b2a9-19384f5d9d19"), @@ -79,13 +79,13 @@ def embedding_vectors(embedding_dimension): @pytest.fixture(scope="function") -def entities_raw_list(document_id, extraction_ids): +def entities_raw_list(document_id, chunk_ids): return [ Entity( name="Entity1", description="Description1", category="Category1", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), @@ -93,7 +93,7 @@ def entities_raw_list(document_id, extraction_ids): name="Entity2", description="Description2", category="Category2", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr3": "value3", 
"attr4": "value4"}, ), @@ -101,13 +101,13 @@ def entities_raw_list(document_id, extraction_ids): @pytest.fixture(scope="function") -def entities_list(extraction_ids, document_id, embedding_vectors): +def entities_list(chunk_ids, document_id, embedding_vectors): return [ Entity( id=1, name="Entity1", description="Description1", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, description_embedding=embedding_vectors[0], ), @@ -115,7 +115,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): id=2, name="Entity2", description="Description2", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, description_embedding=embedding_vectors[1], ), @@ -123,7 +123,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") -def relationships_raw_list(embedding_vectors, extraction_ids, document_id): +def relationships_raw_list(embedding_vectors, chunk_ids, document_id): return [ Relationship( id=1, @@ -133,7 +133,7 @@ def relationships_raw_list(embedding_vectors, extraction_ids, document_id): weight=1.0, description="description1", embedding=embedding_vectors[0], - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), @@ -145,7 +145,7 @@ def relationships_raw_list(embedding_vectors, extraction_ids, document_id): weight=1.0, description="description2", embedding=embedding_vectors[1], - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr3": "value3", "attr4": "value4"}, ), diff --git a/py/tests/core/providers/ingestion/test_contextual_embedding.py b/py/tests/core/providers/ingestion/test_contextual_embedding.py index 666f65066..c00355fe5 100644 --- a/py/tests/core/providers/ingestion/test_contextual_embedding.py +++ b/py/tests/core/providers/ingestion/test_contextual_embedding.py @@ -45,7 +45,7 @@ def collection_ids(): @pytest.fixture -def extraction_ids(): +def chunk_ids(): return [ UUID("fce959df-46a2-4983-aa8b-dd1f93777e02"), UUID("9a85269c-84cd-4dff-bf21-7bd09974f668"), @@ -55,11 +55,11 @@ def extraction_ids(): @pytest.fixture def sample_chunks( - sample_document_id, sample_user, collection_ids, extraction_ids + sample_document_id, sample_user, collection_ids, chunk_ids ): return [ VectorEntry( - chunk_id=extraction_ids[0], + chunk_id=chunk_ids[0], document_id=sample_document_id, user_id=sample_user.id, collection_ids=collection_ids, @@ -72,7 +72,7 @@ def sample_chunks( metadata={"chunk_order": 0}, ), VectorEntry( - chunk_id=extraction_ids[1], + chunk_id=chunk_ids[1], document_id=sample_document_id, user_id=sample_user.id, collection_ids=collection_ids, @@ -85,7 +85,7 @@ def sample_chunks( metadata={"chunk_order": 1}, ), VectorEntry( - chunk_id=extraction_ids[2], + chunk_id=chunk_ids[2], document_id=sample_document_id, user_id=sample_user.id, collection_ids=collection_ids, diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py index 84f41dd6e..6edb0324b 100644 --- a/py/tests/core/providers/kg/test_kg_logic.py +++ b/py/tests/core/providers/kg/test_kg_logic.py @@ -25,7 +25,7 @@ def document_id(): @pytest.fixture(scope="function") -def extraction_ids(): +def chunk_ids(): return [ uuid.UUID("32ff6daf-6e67-44fa-b2a9-19384f5d9d19"), uuid.UUID("42ff6daf-6e67-44fa-b2a9-19384f5d9d19"), @@ -51,13 +51,13 @@ def embedding_vectors(embedding_dimension): @pytest.fixture(scope="function") -def entities_raw_list(document_id, 
extraction_ids): +def entities_raw_list(document_id, chunk_ids): return [ Entity( name="Entity1", description="Description1", category="Category1", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), @@ -65,7 +65,7 @@ def entities_raw_list(document_id, extraction_ids): name="Entity2", description="Description2", category="Category2", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr3": "value3", "attr4": "value4"}, ), @@ -73,19 +73,19 @@ def entities_raw_list(document_id, extraction_ids): @pytest.fixture(scope="function") -def entities_list(extraction_ids, document_id, embedding_vectors): +def entities_list(chunk_ids, document_id, embedding_vectors): return [ Entity( name="Entity1", description="Description1", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, description_embedding=embedding_vectors[0], ), Entity( name="Entity2", description="Description2", - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, description_embedding=embedding_vectors[1], ), @@ -93,7 +93,7 @@ def entities_list(extraction_ids, document_id, embedding_vectors): @pytest.fixture(scope="function") -def relationships_raw_list(embedding_vectors, extraction_ids, document_id): +def relationships_raw_list(embedding_vectors, chunk_ids, document_id): return [ Relationship( subject="Entity1", @@ -102,7 +102,7 @@ def relationships_raw_list(embedding_vectors, extraction_ids, document_id): weight=1.0, description="description1", embedding=embedding_vectors[0], - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr1": "value1", "attr2": "value2"}, ), @@ -113,7 +113,7 @@ def relationships_raw_list(embedding_vectors, extraction_ids, document_id): weight=1.0, description="description2", embedding=embedding_vectors[1], - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, document_id=document_id, attributes={"attr3": "value3", "attr4": "value4"}, ), @@ -148,11 +148,11 @@ def community_table_info(collection_id): @pytest.fixture(scope="function") def kg_extractions( - extraction_ids, entities_raw_list, relationships_raw_list, document_id + chunk_ids, entities_raw_list, relationships_raw_list, document_id ): return [ KGExtraction( - extraction_ids=extraction_ids, + chunk_ids=chunk_ids, entities=entities_raw_list, relationships=relationships_raw_list, document_id=document_id, @@ -288,7 +288,7 @@ async def test_upsert_embeddings( entity.name, entity.description, str(entity.description_embedding), - entity.extraction_ids, + entity.chunk_ids, entity.document_id, ) for entity in entities_list From deedf5642319a21ee98df42ffce2e1369db2c9ef Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 14 Nov 2024 09:12:08 -0800 Subject: [PATCH 12/21] up --- py/core/main/api/v3/graph_router.py | 2 +- py/core/main/services/kg_service.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index e4d48678b..fd68ad9a0 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -672,7 +672,7 @@ async def create_relationships( ..., description="The ID of the chunk to create relationships for.", ), - relationships: list[Union[Relationship, dict]] = Body( + relationships: list[Relationship] = Body( ..., description="The relationships to create." 
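                # with list[Relationship], FastAPI validates each JSON item into a
                # Relationship model, so the handler no longer coerces raw dicts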
), auth_user=Depends(self.providers.auth.auth_wrapper), diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index ef3e9c445..5b6b378a2 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -137,15 +137,21 @@ async def list_entities_v3( self, level: EntityLevel, id: UUID, + offset: int, + limit: int, entity_names: Optional[list[str]] = None, entity_categories: Optional[list[str]] = None, attributes: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, **kwargs, ): return await self.providers.database.graph_handler.entities.get( - level, id, entity_names, entity_categories, attributes, offset, limit + level=level, + id=id, + entity_names=entity_names, + entity_categories=entity_categories, + attributes=attributes, + offset=offset, + limit=limit ) @telemetry_event("update_entity_v3") @@ -434,8 +440,8 @@ async def list_entities( entity_names=entity_names, entity_categories=entity_categories, attributes=attributes, - offset=offset or 0, - limit=limit or -1, + offset=offset, + limit=limit, ) @telemetry_event("get_entities") From d6044b5623cc2d54a998f1a907d79fafe5863946 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 14 Nov 2024 14:26:37 -0800 Subject: [PATCH 13/21] up --- py/core/main/api/v3/graph_router.py | 1410 +++++---------------------- py/core/main/services/kg_service.py | 74 +- py/core/providers/database/kg.py | 123 ++- py/shared/abstractions/graph.py | 2 +- 4 files changed, 371 insertions(+), 1238 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index fd68ad9a0..a586a8989 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -442,10 +442,33 @@ async def update_entity( "Only superusers can access this endpoint.", 403 ) + if not entity.level: + entity.level = self._get_path_level(request) + else: + if entity.level != self._get_path_level(request): + raise R2RException( + "Entity level must match the path level.", 400 + ) + + if entity.level == EntityLevel.CHUNK: + # don't override the chunk_ids + entity.chunk_ids = None + + elif entity.level == EntityLevel.DOCUMENT: + entity.document_id = id + + elif entity.level == EntityLevel.COLLECTION: + entity.collection_id = id + + if not entity.id: + entity.id = entity_id + else: + if entity.id != entity_id: + raise R2RException( + "Entity ID must match the entity ID or should be empty.", 400 + ) + return await self.services["kg"].update_entity_v3( - level=self._get_path_level(request), - id=id, - entity_id=entity_id, entity=entity, ) @@ -529,10 +552,10 @@ async def delete_entity( "Only superusers can access this endpoint.", 403 ) + entity = Entity(id=entity_id, level=self._get_path_level(request)) + return await self.services["kg"].delete_entity_v3( - level=self._get_path_level(request), - id=id, - entity_id=entity_id, + entity=entity, ) ##### RELATIONSHIPS ##### @@ -601,6 +624,7 @@ async def delete_entity( ) @self.base_endpoint async def list_relationships( + request: Request, id: UUID = Path( ..., description="The ID of the chunk to retrieve relationships for.", @@ -635,8 +659,8 @@ async def list_relationships( "Only superusers can access this endpoint.", 403 ) - return await self.services["kg"].list_relationships_v3( - level=EntityLevel.CHUNK, + relationships, count = await self.services["kg"].list_relationships_v3( + level=self._get_path_level(request), id=id, entity_names=entity_names, relationship_types=relationship_types, @@ 
-645,6 +669,10 @@ async def list_relationships( limit=limit, ) + return relationships, { + "total_entries": count, + } + @self.router.post( "/chunks/{id}/relationships", summary="Create relationships for a chunk", @@ -666,8 +694,51 @@ async def list_relationships( ] }, ) + @self.router.post( + "/documents/{id}/relationships", + summary="Create relationships for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + """ + ), + }, + ] + }, + ) + @self.router.post( + "/collections/{id}/relationships", + summary="Create relationships for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.collections.create_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + """ + ), + }, + ] + }, + ) @self.base_endpoint async def create_relationships( + request: Request, id: UUID = Path( ..., description="The ID of the chunk to create relationships for.", @@ -676,26 +747,19 @@ async def create_relationships( ..., description="The relationships to create." ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGRelationshipsResponse: + ) -> WrappedKGCreationResponse: if not auth_user.is_superuser: raise R2RException( "Only superusers can access this endpoint.", 403 ) - relationships = [ - ( - Relationship(**relationship) - if isinstance(relationship, dict) - else relationship - ) - for relationship in relationships - ] - - return await self.services["kg"].create_relationships_v3( - level=EntityLevel.CHUNK, - id=id, - relationships=relationships, - ) + return { + "message": "Relationships created successfully.", + "count": await self.services["kg"].create_relationships_v3( + id=id, + relationships=relationships, + ), + } @self.router.post( "/chunks/{id}/relationships/{relationship_id}", @@ -711,7 +775,7 @@ async def create_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.update_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + result = client.chunks.update_relationship(chunk_id="9fbe403b -c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) """ ), }, @@ -769,11 +833,12 @@ async def delete_relationship( id=id, relationship_id=relationship_id, ) + - ##### DOCUMENT LEVEL OPERATIONS ##### - @self.router.get( - "/documents/{id}/entities", - summary="List entities for a document", + # Graph-level operations + @self.router.post( + "/graphs/{collection_id}", + summary="Create a new graph", openapi_extra={ "x-codeSamples": [ { @@ -785,120 +850,114 @@ async def delete_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
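+                            # relationship1 / relationship2 are Relationship models (or JSON
+                            # objects matching that schema) built by the caller -- assumed here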
- result = client.chunks.list_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.graphs.create( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + settings={ + "entity_types": ["PERSON", "ORG", "GPE"] + } + )""" + ), + }, + { + "lang": "cURL", + "source": textwrap.dedent( """ + curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ + -H "Content-Type: application/json" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -d '{ + "settings": { + "entity_types": ["PERSON", "ORG", "GPE"] + } + }'""" ), }, ] }, ) @self.base_endpoint - async def list_entities( - id: UUID = Path( - ..., - description="The ID of the document to retrieve entities for.", - ), - entity_names: Optional[list[str]] = Query( - None, - description="A list of entity names to filter the entities by.", - ), - entity_categories: Optional[list[str]] = Query( - None, - description="A list of entity categories to filter the entities by.", - ), - attributes: Optional[list[str]] = Query( - None, - description="A list of attributes to return. By default, all attributes are returned.", + async def create_graph( + collection_id: UUID = Path( + default=..., + description="Collection ID to create graph for.", ), - offset: int = Query( - 0, - ge=0, - description="The offset of the first entity to retrieve.", + run_type: Optional[KGRunType] = Body( + default=None, + description="Run type for the graph creation process.", ), - limit: int = Query( - 100, - ge=0, - le=20_000, - description="The maximum number of entities to retrieve, up to 20,000.", + settings: Optional[KGCreationSettings] = Body( + default=None, + description="Settings for the graph creation process.", ), + run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: - """ - Retrieves a list of entities associated with a specific document. + ) -> WrappedKGCreationResponse: + """Creates a new knowledge graph by extracting entities and relationships from documents in a collection. + + The graph creation process involves: + 1. Parsing documents into semantic chunks + 2. Extracting entities and relationships using LLMs or NER + 3. Building a connected knowledge graph structure """ + + settings = settings.dict() if settings else None if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 + logger.warning("Implement permission checks here.") + + logger.info(f"Running create-graph on collection {collection_id}") + + # If no collection ID is provided, use the default user collection + if not collection_id: + collection_id = generate_default_user_collection_id( + auth_user.id ) - return await self.services["kg"].list_entities_v3( - level=EntityLevel.DOCUMENT, - id=id, - offset=offset, - limit=limit, - entity_names=entity_names, - entity_categories=entity_categories, - attributes=attributes, - ) + # If no run type is provided, default to estimate + if not run_type: + run_type = KGRunType.ESTIMATE - @self.router.post( - "/documents/{id}/entities", - summary="Create entities for a document", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient + # Apply runtime settings overrides + server_kg_creation_settings = ( + self.providers.database.config.kg_creation_settings + ) - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
+ if settings: + server_kg_creation_settings = update_settings_from_dict( + server_kg_creation_settings, settings + ) - result = client.documents.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) - """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def create_entities_v3( - id: UUID = Path( - ..., description="The ID of the chunk to create entities for." - ), - entities: list[Union[Entity, dict]] = Body( - ..., description="The entities to create." - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 + # If the run type is estimate, return an estimate of the creation cost + if run_type is KGRunType.ESTIMATE: + return await self.services["kg"].get_creation_estimate( + collection_id, server_kg_creation_settings ) + else: - entities = [ - Entity(**entity) if isinstance(entity, dict) else entity - for entity in entities - ] - # for each entity, set the level to CHUNK - for entity in entities: - if entity.level is None: - entity.level = EntityLevel.DOCUMENT - else: - raise R2RException( - "Entity level must be chunk or empty.", 400 + # Otherwise, create the graph + if run_with_orchestration: + workflow_input = { + "collection_id": str(collection_id), + "kg_creation_settings": server_kg_creation_settings.model_dump_json(), + "user": auth_user.json(), + } + + return await self.orchestration_provider.run_workflow( # type: ignore + "create-graph", {"request": workflow_input}, {} ) + else: + from core.main.orchestration import simple_kg_factory - return await self.services["kg"].create_entities_v3( - level=EntityLevel.DOCUMENT, - id=id, - entities=entities, - ) + logger.info("Running create-graph without orchestration.") + simple_kg = simple_kg_factory(self.service) + await simple_kg["create-graph"](workflow_input) + return { + "message": "Graph created successfully.", + "task_id": None, + } - @self.router.post( - "/documents/{id}/entities/{entity_id}", - summary="Update an entity for a document", + @self.router.get( + "/graphs/{collection_id}", + summary="Get graph status", openapi_extra={ "x-codeSamples": [ { @@ -910,414 +969,17 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.graphs.get_status( + collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + )""" + ), + }, + { + "lang": "cURL", + "source": textwrap.dedent( """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def update_entity( - id: UUID = Path( - ..., - description="The ID of the document to update the entity for.", - ), - entity_id: UUID = Path( - ..., description="The ID of the entity to update." 
- ), - entity: Entity = Body(..., description="The updated entity."), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - return await self.services["kg"].update_entity_v3( - level=EntityLevel.DOCUMENT, - id=id, - entity_id=entity_id, - entity=entity, - ) - - @self.router.delete( - "/documents/{id}/entities/{entity_id}", - summary="Delete an entity for a document", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.documents.delete_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") - """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_entity( - id: UUID = Path( - ..., - description="The ID of the document to delete the entity for.", - ), - entity_id: UUID = Path( - ..., description="The ID of the entity to delete." - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - ##### RELATIONSHIPS ##### - @self.router.get( - "/documents/{id}/relationships", - summary="List relationships for a document", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.documents.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") - """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def list_relationships( - id: UUID = Path( - ..., - description="The ID of the document to retrieve relationships for.", - ), - entity_names: Optional[list[str]] = Query( - None, - description="A list of entity names to filter the relationships by.", - ), - relationship_types: Optional[list[str]] = Query( - None, - description="A list of relationship types to filter the relationships by.", - ), - attributes: Optional[list[str]] = Query( - None, - description="A list of attributes to return. By default, all attributes are returned.", - ), - offset: int = Query( - 0, - ge=0, - description="The offset of the first relationship to retrieve.", - ), - limit: int = Query( - 100, - ge=0, - le=20_000, - description="The maximum number of relationships to retrieve, up to 20,000.", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - return await self.services["kg"].list_relationships_v3( - level=EntityLevel.DOCUMENT, - id=id, - entity_names=entity_names, - relationship_types=relationship_types, - attributes=attributes, - offset=offset, - limit=limit, - ) - - @self.router.post( - "/documents/{id}/relationships", - summary="Create relationships for a document", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
- - result = client.documents.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) - """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def create_relationships( - id: UUID = Path( - ..., - description="The ID of the document to create relationships for.", - ), - relationships: list[Union[Relationship, dict]] = Body( - ..., description="The relationships to create." - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - relationships = [ - ( - Relationship(**relationship) - if isinstance(relationship, dict) - else relationship - ) - for relationship in relationships - ] - - return await self.services["kg"].create_relationships_v3( - level=EntityLevel.DOCUMENT, - id=id, - relationships=relationships, - ) - - @self.router.post( - "/documents/{id}/relationships/{relationship_id}", - summary="Update a relationship for a document", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.documents.update_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) - """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def update_relationship( - id: UUID = Path( - ..., - description="The ID of the document to update the relationship for.", - ), - relationship_id: UUID = Path( - ..., description="The ID of the relationship to update." - ), - relationship: Relationship = Body( - ..., description="The updated relationship." - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - return await self.services["kg"].update_relationship_v3( - level=EntityLevel.DOCUMENT, - id=id, - relationship_id=relationship_id, - relationship=relationship, - ) - - @self.router.delete( - "/documents/{id}/relationships/{relationship_id}", - summary="Delete a relationship for a document", - ) - @self.base_endpoint - async def delete_relationship( - id: UUID = Path( - ..., - description="The ID of the document to delete the relationship for.", - ), - relationship_id: UUID = Path( - ..., description="The ID of the relationship to delete." - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can access this endpoint.", 403 - ) - - return await self.services["kg"].delete_relationship_v3( - level=EntityLevel.DOCUMENT, - id=id, - relationship_id=relationship_id, - ) - - ##### COLLECTION LEVEL OPERATIONS ##### - - # Graph-level operations - @self.router.post( - "/graphs/{collection_id}", - summary="Create a new graph", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
- - result = client.graphs.create( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - settings={ - "entity_types": ["PERSON", "ORG", "GPE"] - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "settings": { - "entity_types": ["PERSON", "ORG", "GPE"] - } - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def create_graph( - collection_id: UUID = Path( - default=..., - description="Collection ID to create graph for.", - ), - run_type: Optional[KGRunType] = Body( - default=None, - description="Run type for the graph creation process.", - ), - settings: Optional[KGCreationSettings] = Body( - default=None, - description="Settings for the graph creation process.", - ), - run_with_orchestration: Optional[bool] = Body(True), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: - """Creates a new knowledge graph by extracting entities and relationships from documents in a collection. - - The graph creation process involves: - 1. Parsing documents into semantic chunks - 2. Extracting entities and relationships using LLMs or NER - 3. Building a connected knowledge graph structure - """ - - settings = settings.dict() if settings else None - if not auth_user.is_superuser: - logger.warning("Implement permission checks here.") - - logger.info(f"Running create-graph on collection {collection_id}") - - # If no collection ID is provided, use the default user collection - if not collection_id: - collection_id = generate_default_user_collection_id( - auth_user.id - ) - - # If no run type is provided, default to estimate - if not run_type: - run_type = KGRunType.ESTIMATE - - # Apply runtime settings overrides - server_kg_creation_settings = ( - self.providers.database.config.kg_creation_settings - ) - - if settings: - server_kg_creation_settings = update_settings_from_dict( - server_kg_creation_settings, settings - ) - - # If the run type is estimate, return an estimate of the creation cost - if run_type is KGRunType.ESTIMATE: - return await self.services["kg"].get_creation_estimate( - collection_id, server_kg_creation_settings - ) - else: - - # Otherwise, create the graph - if run_with_orchestration: - workflow_input = { - "collection_id": str(collection_id), - "kg_creation_settings": server_kg_creation_settings.model_dump_json(), - "user": auth_user.json(), - } - - return await self.orchestration_provider.run_workflow( # type: ignore - "create-graph", {"request": workflow_input}, {} - ) - else: - from core.main.orchestration import simple_kg_factory - - logger.info("Running create-graph without orchestration.") - simple_kg = simple_kg_factory(self.service) - await simple_kg["create-graph"](workflow_input) - return { - "message": "Graph created successfully.", - "task_id": None, - } - - @self.router.get( - "/graphs/{collection_id}", - summary="Get graph status", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
- - result = client.graphs.get_status( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -1405,328 +1067,28 @@ async def get_graph_status( # raise R2RException("Only superusers can enrich graphs", 403) # server_settings = self.providers.database.config.kg_enrichment_settings - # if settings: - # server_settings = update_settings_from_dict(server_settings, settings) - - # workflow_input = { - # "collection_id": str(collection_id), - # "kg_enrichment_settings": server_settings.model_dump_json(), - # "user": auth_user.model_dump_json(), - # } - - # if run_with_orchestration: - # return await self.orchestration_provider.run_workflow( - # "enrich-graph", {"request": workflow_input}, {} - # ) - # else: - # from core.main.orchestration import simple_kg_factory - # simple_kg = simple_kg_factory(self.services["kg"]) - # await simple_kg["enrich-graph"](workflow_input) - # return {"message": "Graph enriched successfully.", "task_id": None} - - @self.router.delete( - "/graphs/{collection_id}", - summary="Delete a graph", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.delete( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - cascade=True - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7?cascade=true" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_graph( - collection_id: UUID = Path(...), - cascade: bool = Query(False), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[dict]: - """Deletes a graph and optionally its associated entities and relationships.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can delete graphs", 403) - - await self.services["kg"].delete_graph(collection_id, cascade) - return {"message": "Graph deleted successfully"} # type: ignore - - # Entity operations - @self.router.post( - "/graphs/{collection_id}/entities/{level}", - summary="Create a new entity", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
- - result = client.graphs.create_entity( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - entity={ - "name": "John Smith", - "type": "PERSON", - "metadata": { - "source": "manual", - "confidence": 1.0 - }, - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities/document" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "name": "John Smith", - "type": "PERSON", - "metadata": { - "source": "manual", - "confidence": 1.0 - }, - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def create_entity( - collection_id: UUID = Path(...), - entity: dict = Body(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Entity]: - """Creates a new entity in the graph.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can create entities", 403) - - new_entity = await self.services["kg"].create_entity( - collection_id, entity - ) - return new_entity # type: ignore - - @self.router.delete( - "/graphs/{collection_id}/entities/{entity_id}", - summary="Delete an entity", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.delete_entity( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - entity_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - cascade=True - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities/9fbe403b-c11c-5aae-8ade-ef22980c3ad1?cascade=true" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_entity( - collection_id: UUID = Path(...), - entity_id: UUID = Path(...), - cascade: bool = Query( - False, - description="Whether to also delete related relationships", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[dict]: - """Deletes an entity and optionally its relationships.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can delete entities", 403) - - # await self.services["kg"].delete_entity( - # collection_id, entity_id, cascade - # ) - # return {"message": "Entity deleted successfully"} # type: ignore - - @self.router.get( - "/graphs/{collection_id}/entities", - summary="List entities", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
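                            # Editor's sketch (hypothetical payload): the entity dict shape the
                            # create endpoint above expects, mirroring the documented sample;
                            # a superuser token is required or the handler returns 403.
                            example_entity = {
                                "name": "John Smith",
                                "type": "PERSON",
                                "metadata": {"source": "manual", "confidence": 1.0},
                            }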
- - result = client.graphs.list_entities( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - level="DOCUMENT", - offset=0, - limit=100, - include_embeddings=False - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities?\\ - level=DOCUMENT&offset=0&limit=100&include_embeddings=false" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def list_entities( - collection_id: UUID = Path(...), - level: EntityLevel = Query(EntityLevel.DOCUMENT), - offset: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - include_embeddings: bool = Query(False), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGEntitiesResponse: - """Lists entities in the graph with filtering and pagination support. - - Entities represent the nodes in the knowledge graph, extracted from documents. - Each entity has: - - Unique identifier and name - - Entity type (e.g. PERSON, ORG, LOCATION) - - Source documents and extractions - - Generated description - - Community memberships - - Optional vector embedding - """ - entities = await self.services["kg"].list_entities( - collection_id, level, offset, limit, include_embeddings - ) - return entities # type: ignore - - @self.router.get( - "/graphs/{collection_id}/entities/{entity_id}", - summary="Get entity details", - ) - @self.base_endpoint - async def get_entity( - collection_id: UUID = Path(...), - level: EntityLevel = Query(EntityLevel.DOCUMENT), - entity_id: int = Path(...), - # include_embeddings: bool = Query(False), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Entity]: - """Retrieves details of a specific entity.""" - entity = await self.services["kg"].get_entity( - collection_id, entity_id, include_embeddings - ) - if not entity: - raise R2RException("Entity not found", 404) - return entity - - @self.router.post( - "/graphs/{collection_id}/entities/{entity_id}", - summary="Update entity", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
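                            # Editor's sketch (hypothetical): draining the paginated entity
                            # listing described above; assumes the response exposes "results"
                            # like the other list endpoints in this API.
                            page_size, offset, all_entities = 100, 0, []
                            while True:
                                page = client.graphs.list_entities(
                                    collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                                    offset=offset,
                                    limit=page_size,
                                )
                                batch = page["results"]
                                all_entities.extend(batch)
                                if len(batch) < page_size:
                                    break  # a short page means we reached the end
                                offset += page_size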
- - result = client.graphs.update_entity( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - entity_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - entity_update={ - "name": "Updated Entity Name", - "metadata": { - "confidence": 0.95, - "source": "manual_correction" - } - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities/9fbe403b-c11c-5aae-8ade-ef22980c3ad1" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "name": "Updated Entity Name", - "metadata": { - "confidence": 0.95, - "source": "manual_correction" - } - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def update_entity( - collection_id: UUID = Path(...), - entity_id: UUID = Path(...), - entity_update: dict = Body(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Entity]: - """Updates an existing entity.""" - if not auth_user.is_superuser: - raise R2RException("Only superusers can update entities", 403) + # if settings: + # server_settings = update_settings_from_dict(server_settings, settings) - # updated_entity = await self.services["kg"].update_entity( - # collection_id, entity_id, entity_update - # ) - # return updated_entity # type: ignore + # workflow_input = { + # "collection_id": str(collection_id), + # "kg_enrichment_settings": server_settings.model_dump_json(), + # "user": auth_user.model_dump_json(), + # } - @self.router.post( - "/graphs/{collection_id}/entities/deduplicate", - summary="Deduplicate entities in the graph", + # if run_with_orchestration: + # return await self.orchestration_provider.run_workflow( + # "enrich-graph", {"request": workflow_input}, {} + # ) + # else: + # from core.main.orchestration import simple_kg_factory + # simple_kg = simple_kg_factory(self.services["kg"]) + # await simple_kg["enrich-graph"](workflow_input) + # return {"message": "Graph enriched successfully.", "task_id": None} + + @self.router.delete( + "/graphs/{collection_id}", + summary="Delete a graph", openapi_extra={ "x-codeSamples": [ { @@ -1738,17 +1100,9 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
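                            # Editor's sketch (hypothetical): the grouping that a "by_name"
                            # deduplication strategy, as named in the settings below, could
                            # perform; this is not the server's implementation.
                            from collections import defaultdict

                            def duplicate_groups(entities: list[dict]) -> list[list[dict]]:
                                groups: dict[str, list[dict]] = defaultdict(list)
                                for entity in entities:
                                    groups[entity["name"].strip().lower()].append(entity)
                                # only names shared by two or more entities need merging
                                return [g for g in groups.values() if len(g) > 1]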
- result = client.graphs.deduplicate_entities( + result = client.graphs.delete( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - settings={ - "kg_entity_deduplication_type": "by_name", - "kg_entity_deduplication_prompt": "default", - "generation_config": { - "model": "openai/gpt-4o-mini", - "temperature": 0.12 - }, - "max_description_input_length": 65536 - } + cascade=True )""" ), }, @@ -1756,25 +1110,26 @@ async def update_entity( "lang": "cURL", "source": textwrap.dedent( """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/entities/deduplicate" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "settings": { - "kg_entity_deduplication_type": "by_name", - "kg_entity_deduplication_prompt": "default", - "max_description_input_length": 65536, - "generation_config": { - "model": "openai/gpt-4o-mini", - "temperature": 0.12 - } - } - }'""" + curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7?cascade=true" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] }, ) + @self.base_endpoint + async def delete_graph( + collection_id: UUID = Path(...), + cascade: bool = Query(False), + auth_user=Depends(self.providers.auth.auth_wrapper), + ) -> ResultsWrapper[dict]: + """Deletes a graph and optionally its associated entities and relationships.""" + if not auth_user.is_superuser: + raise R2RException("Only superusers can delete graphs", 403) + + await self.services["kg"].delete_graph(collection_id, cascade) + return {"message": "Graph deleted successfully"} # type: ignore + @self.base_endpoint async def deduplicate_entities( collection_id: UUID = Path(...), @@ -1852,321 +1207,6 @@ async def deduplicate_entities( "task_id": None, } - @self.router.post( - "/graphs/{document_id}/relationships", - summary="Create a new relationship", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
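                            # Editor's sketch (hypothetical helper): the orchestration-or-inline
                            # dispatch pattern the graph and deduplication handlers above repeat;
                            # names are illustrative, not part of this patch.
                            async def dispatch(workflow, payload, orchestrated, provider, simple_impl):
                                if orchestrated:
                                    # queue the workflow and return a pollable task id
                                    return await provider.run_workflow(workflow, {"request": payload}, {})
                                # otherwise run inline and mirror the orchestrated response shape
                                await simple_impl[workflow](payload)
                                return {"message": f"{workflow} completed.", "task_id": None}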
- - result = client.graphs.create_relationship( - document_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - relationship={ - "source_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - "target_id": "7cde891f-2a3b-4c5d-6e7f-gh8i9j0k1l2m", - "type": "WORKS_FOR", - "metadata": { - "source": "manual", - "confidence": 1.0 - } - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/relationships" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "source_id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - "target_id": "7cde891f-2a3b-4c5d-6e7f-gh8i9j0k1l2m", - "type": "WORKS_FOR", - "metadata": { - "source": "manual", - "confidence": 1.0 - } - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def create_relationship( - relationship: dict = Body(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Relationship]: - """Creates a new relationship between entities.""" - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can create relationships", 403 - ) - - new_relationship = await self.services["kg"].create_relationship( - collection_id, relationship - ) - return new_relationship # type: ignore - - # Relationship operations - @self.router.get( - "/graphs/{collection_id}/relationships", - summary="List relationships", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.list_relationships( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - source_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", - relationship_type="WORKS_FOR", - offset=0, - limit=100 - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/relationships?\\ - source_id=9fbe403b-c11c-5aae-8ade-ef22980c3ad1&\\ - relationship_type=WORKS_FOR&offset=0&limit=100" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def list_relationships( - collection_id: UUID = Path(...), - source_id: Optional[UUID] = Query(None), - target_id: Optional[UUID] = Query(None), - relationship_type: Optional[str] = Query(None), - offset: int = Query( - 0, - ge=0, - description="Specifies the number of objects to skip. Defaults to 0.", - ), - limit: int = Query( - 100, - ge=1, - le=1000, - description="Specifies a limit on the number of objects to return, ranging between 1 and 100. Defaults to 100.", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: - """Lists relationships (edges) between entities in the knowledge graph. 
- - Relationships represent connections between entities with: - - Source and target entities - - Relationship type and description - - Confidence score and metadata - - Source documents and extractions - """ - raise R2RException("Not implemented", 501) - # relationships = await self.services["kg"].list_relationships( - # collection_id, - # source_id, - # target_id, - # relationship_type, - # offset, - # limit, - # ) - # return relationships # type: ignore - - @self.router.get( - "/graphs/{collection_id}/relationships/{relationship_id}", - summary="Get relationship details", - ) - @self.base_endpoint - async def get_relationship( - collection_id: UUID = Path(...), - relationship_id: UUID = Path(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Relationship]: - """Retrieves details of a specific relationship.""" - raise R2RException("Not implemented", 501) - # relationship = await self.services["kg"].get_relationship( - # collection_id, relationship_id - # ) - # if not relationship: - # raise R2RException("Relationship not found", 404) - # return relationship # type: ignore - - @self.router.post( - "/graphs/{collection_id}/relationships/{relationship_id}", - summary="Update relationship", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.update_relationship( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - relationship_id="8abc123d-ef45-678g-hi90-jklmno123456", - relationship_update={ - "type": "EMPLOYED_BY", - "metadata": { - "confidence": 0.95, - "source": "manual_correction" - } - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/relationships/8abc123d-ef45-678g-hi90-jklmno123456" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "type": "EMPLOYED_BY", - "metadata": { - "confidence": 0.95, - "source": "manual_correction" - } - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def update_relationship( - collection_id: UUID = Path(...), - relationship_id: UUID = Path(...), - relationship_update: dict = Body(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Relationship]: - """Updates an existing relationship.""" - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can update relationships", 403 - # ) - - updated_relationship = await self.services[ - "kg" - ].update_relationship( - collection_id, relationship_id, relationship_update - ) - return updated_relationship # type: ignore - - @self.router.delete( - "/graphs/{collection_id}/relationships/{relationship_id}", - summary="Delete a relationship", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
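                            # Editor's sketch (hypothetical): composing the optional relationship
                            # filters documented above into a parameterized WHERE clause, in the
                            # $1..$n binding style this patch's kg.py uses.
                            from typing import Any, Optional
                            from uuid import UUID

                            def relationship_filter(
                                source_id: Optional[UUID],
                                target_id: Optional[UUID],
                                relationship_type: Optional[str],
                            ) -> tuple[str, list[Any]]:
                                clauses: list[str] = []
                                params: list[Any] = []
                                for column, value in (
                                    ("source_id", source_id),
                                    ("target_id", target_id),
                                    ("predicate", relationship_type),
                                ):
                                    if value is not None:
                                        params.append(value)
                                        # the placeholder number always matches the parameter position
                                        clauses.append(f"{column} = ${len(params)}")
                                return (" AND ".join(clauses) or "TRUE"), params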
- - result = client.graphs.delete_relationship( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - relationship_id="8abc123d-ef45-678g-hi90-jklmno123456" - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/relationships/8abc123d-ef45-678g-hi90-jklmno123456" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_relationship( - collection_id: UUID = Path(...), - relationship_id: UUID = Path(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[dict]: - """Deletes a relationship.""" - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can delete relationships", 403 - # ) - - # await self.services["kg"].delete_relationship( - # collection_id, relationship_id - # ) - # return {"message": "Relationship deleted successfully"} # type: ignore - - # Community operations - @self.router.post( - "/graphs/{collection_id}/communities", - summary="Create communities in the graph", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) - - result = client.graphs.create_communities( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - settings={ - "max_summary_input_length": 65536, - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities/create" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "settings": { - "max_summary_input_length": 65536, - } - }'""" - ), - }, - ] - }, - ) @self.base_endpoint async def create_communities( collection_id: UUID = Path(...), diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 5b6b378a2..7751ee5d0 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -157,49 +157,55 @@ async def list_entities_v3( @telemetry_event("update_entity_v3") async def update_entity_v3( self, - level: EntityLevel, - id: UUID, - entity_id: UUID, entity: Entity, **kwargs, ): return await self.providers.database.graph_handler.entities.update( - level, id, entity_id, entity + entity ) - @telemetry_event("update_entity") - async def update_entity( + @telemetry_event(" ") + async def delete_entity_v3( self, - collection_id: UUID, entity: Entity, **kwargs, ): - return await self.providers.database.update_entity( - collection_id, entity + return await self.providers.database.graph_handler.entities.delete( + entity ) - @telemetry_event("delete_entity") - async def delete_entity( + ################### RELATIONSHIPS ################### + + + @telemetry_event("list_relationships_v3") + async def list_relationships_v3( self, - collection_id: UUID, - entity: Entity, - **kwargs, + id: UUID, + level: EntityLevel, + offset: int, + limit: int, + entity_names: Optional[list[str]] = None, + relationship_types: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, ): - return await self.providers.database.delete_entity( - collection_id, entity + return await self.providers.database.graph_handler.relationships.get( + id=id, + level=level, + entity_names=entity_names, + relationship_types=relationship_types, + 
attributes=attributes, + offset=offset, + limit=limit, ) - ################### RELATIONSHIPS ################### - - @telemetry_event("create_relationship") - async def create_relationship( + @telemetry_event("create_relationships_v3") + async def create_relationships_v3( self, - collection_id: UUID, - relationship: Relationship, + relationships: list[Relationship], **kwargs, ): - return await self.providers.database.create_relationship( - collection_id, relationship + return await self.providers.database.graph_handler.relationships.create( + relationships ) @telemetry_event("get_document_ids_for_create_graph") @@ -540,20 +546,20 @@ async def list_relationships_v3( self, level: EntityLevel, id: UUID, + offset: int, + limit: int, entity_names: Optional[list[str]] = None, relationship_types: Optional[list[str]] = None, attributes: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, ): - return await self.providers.database.list_relationships_v3( - level, - id, - entity_names, - relationship_types, - attributes, - offset, - limit, + return await self.providers.database.graph_handler.relationships.get( + level=level, + id=id, + entity_names=entity_names, + relationship_types=relationship_types, + attributes=attributes, + offset=offset, + limit=limit, ) ##### Communities ##### diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 6da3ace45..cd4fbacb7 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -268,30 +268,27 @@ async def update(self, entity: Entity) -> None: if count == 0: raise R2RException("Entity does not exist", 204) - await _add_objects( - objects=[entity], + return await _update_object( + object=entity.__dict__, full_table_name=self._get_table_name(table_name), connection_manager=self.connection_manager, - conflict_columns=non_null_attributes, - exclude_attributes=["level"] + id_column="id", ) - async def delete(self, entity_id: UUID, level: EntityLevel) -> None: + async def delete(self, entity: Entity) -> None: """Delete an entity from the database. 
        Args:
-            entity_id: UUID of the entity to delete
-            level: Level of the entity (chunk, document, or collection)
+            entity: the Entity to delete; its level selects the chunk,
+                document, or collection entity table
        """
-        table_name = level.value + "_entity"
-        QUERY = f"""
-            DELETE FROM {self._get_table_name(table_name)} WHERE id = $1
-        """
-        await self.connection_manager.execute_query(
-            QUERY, [entity_id]
+        table_name = entity.level.value + "_entity"
+        return await _delete_object(
+            object_id=entity.id,
+            full_table_name=self._get_table_name(table_name),
+            connection_manager=self.connection_manager,
         )

-
 class PostgresRelationshipHandler(RelationshipHandler):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.project_name = kwargs.get("project_name")
@@ -302,6 +299,7 @@ async def create_tables(self) -> None:
         QUERY = f"""
             CREATE TABLE IF NOT EXISTS {self._get_table_name("relationship")} (
                 id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+                sid SERIAL NOT NULL,
                 subject TEXT NOT NULL,
                 predicate TEXT NOT NULL,
                 object TEXT NOT NULL,
@@ -329,16 +327,54 @@ def _get_table_name(self, table: str) -> str:

     async def create(self, relationships: list[Relationship]) -> None:
         """Create a new relationship in the database."""
-        await self._add_objects(relationships, "relationship")
+        await _add_objects(
+            objects=[relationship.__dict__ for relationship in relationships],
+            full_table_name=self._get_table_name("relationship"),
+            connection_manager=self.connection_manager,
+        )

-    async def get(self, relationship_id: UUID) -> list[Relationship]:
+    async def get(self,
+        id: UUID,
+        level: EntityLevel,
+        entity_names: Optional[list[str]] = None,
+        relationship_types: Optional[list[str]] = None,
+        attributes: Optional[list[str]] = None,
+        offset: int = 0,
+        limit: int = -1
+    ) -> tuple[list[Relationship], int]:
-        """Get relationships from storage by ID."""
+        """Get relationships for an object, with optional filters and pagination."""
+
+        # chunk- and document-level lookups are the ones supported here
+        filter = {
+            EntityLevel.CHUNK: "chunk_ids = ANY($1)",
+            EntityLevel.DOCUMENT: "document_id = $1",
+        }[level]
+
+        params = [id]
+
+        if entity_names:
+            params.append(entity_names)
+            # the placeholder index must track the parameter's position
+            filter += f" AND (subject = ANY(${len(params)}) OR object = ANY(${len(params)}))"
+
+        if relationship_types:
+            params.append(relationship_types)
+            filter += f" AND predicate = ANY(${len(params)})"
+
         QUERY = f"""
-            SELECT * FROM {self._get_table_name("relationship_chunk")}
-            WHERE id = $1
+            SELECT * FROM {self._get_table_name("relationship")}
+            WHERE {filter}
+            OFFSET ${len(params)+1} LIMIT ${len(params) + 2}
         """
-        rows = await self.connection_manager.fetch_query(QUERY, [relationship_id])
-        return [Relationship(**row) for row in rows]
+
+        params.extend([offset, limit])
+        rows = await self.connection_manager.fetch_query(QUERY, params)
+
+        QUERY_COUNT = f"""
+            SELECT COUNT(*) FROM {self._get_table_name("relationship")} WHERE {filter}
+        """
+        count = (await self.connection_manager.fetch_query(QUERY_COUNT, params[:-2]))[0]["count"]
+
+        return [Relationship(**row) for row in rows], count

     async def update(self, relationship: Relationship) -> None:
@@ -2334,3 +2370,54 @@ async def _add_objects(
     ]

     return await connection_manager.execute_many(QUERY, params)  # type: ignore
+
+async def _update_object(
+    object: dict[str, Any],
+    full_table_name: str,
+    connection_manager: PostgresConnectionManager,
+    id_column: str = "id",
+    exclude_attributes: list[str] = [],
+) -> asyncpg.Record:
+    """
+    Update a single object in the specified table.
+ + Args: + object: Dictionary containing the fields to update + full_table_name: Name of the table to update + connection_manager: Database connection manager + id_column: Name of the ID column to use in WHERE clause (default: "id") + exclude_attributes: List of attributes to exclude from update + """ + # Get non-null attributes, excluding the ID and any excluded attributes + non_null_attrs = { + k: v for k, v in object.items() + if v is not None and k != id_column and k not in exclude_attributes + } + + # Create SET clause with placeholders + set_clause = ", ".join(f"{k} = ${i+1}" for i, k in enumerate(non_null_attrs.keys())) + + QUERY = f""" + UPDATE {full_table_name} + SET {set_clause} + WHERE {id_column} = ${len(non_null_attrs) + 1} + """ + + # Prepare parameters: values for SET clause + ID value for WHERE clause + params = [ + (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v) + for k, v in non_null_attrs.items() + ] + params.append(object[id_column]) + + return await connection_manager.execute_many(QUERY, [tuple(params)]) # type: ignore + +async def _delete_object( + object_id: UUID, + full_table_name: str, + connection_manager: PostgresConnectionManager, +): + QUERY = f""" + DELETE FROM {full_table_name} WHERE id = $1 + """ + return await connection_manager.execute_query(QUERY, [object_id]) \ No newline at end of file diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 3d3d4e08c..25ef7f2a8 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -56,7 +56,7 @@ def __str__(self): class Entity(R2RSerializable): """An entity extracted from a document.""" - name: str + name: Optional[str] = None id: Optional[UUID] = None sid: Optional[int] = None #serial ID level: Optional[EntityLevel] = None From 46550431cde9ba87d0115e348f1e992dac5cf98e Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Thu, 14 Nov 2024 17:16:22 -0800 Subject: [PATCH 14/21] up --- py/core/main/api/v3/graph_router.py | 486 ++++++++++++------ .../main/orchestration/hatchet/kg_workflow.py | 37 +- py/core/main/services/kg_service.py | 57 +- py/core/providers/database/kg.py | 146 +++--- py/shared/abstractions/graph.py | 40 +- py/shared/abstractions/kg.py | 2 +- 6 files changed, 529 insertions(+), 239 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index a586a8989..a8ed83091 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -28,8 +28,6 @@ ) - - from core.providers import ( HatchetOrchestrationProvider, SimpleOrchestrationProvider, @@ -50,42 +48,6 @@ logger = logging.getLogger() -# class Entity(BaseModel): -# """Model representing a graph entity.""" - -# id: UUID -# name: str -# type: str -# metadata: dict = Field(default_factory=dict) -# level: EntityLevel -# collection_ids: list[UUID] -# embedding: Optional[list[float]] = None - -# class Config: -# json_schema_extra = { -# "example": { -# "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", -# "name": "John Smith", -# "type": "PERSON", -# "metadata": {"confidence": 0.95}, -# "level": "DOCUMENT", -# "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], -# "embedding": [0.1, 0.2, 0.3], -# } -# } - - -# class Relationship(BaseModel): -# """Model representing a graph relationship.""" - -# id: UUID -# subject_id: UUID -# object_id: UUID -# subject_name: str -# object_name: str -# predicate: str - - class GraphRouter(BaseRouterV3): def __init__( self, @@ -109,9 +71,82 @@ def 
_get_path_level(self, request: Request) -> EntityLevel:

    def _setup_routes(self):

+        # New document-level graph creation endpoint
+
+        @self.router.post(
+            "/documents/{id}/graphs/{run_type}",
+        )
+        @self.base_endpoint
+        async def create_graph(
+            id: UUID = Path(..., description="The ID of the document to create a graph for."),
+            run_type: KGRunType = Path(
+                default=KGRunType.CREATE,
+                description="Run type for the graph creation process.",
+            ),
+            settings: Optional[KGCreationSettings] = Body(
+                default=None,
+                description="Settings for the graph creation process.",
+            ),
+            run_with_orchestration: Optional[bool] = Body(True),
+            auth_user=Depends(self.providers.auth.auth_wrapper),
+        ) -> WrappedKGCreationResponse:
+            """
+            Creates a new knowledge graph by extracting entities and relationships from a document.
+            The graph creation process involves:
+            1. Parsing documents into semantic chunks
+            2. Extracting entities and relationships using LLMs or NER
+            """
+
+            settings = settings.dict() if settings else None
+            if not auth_user.is_superuser:
+                logger.warning("Implement permission checks here.")
+
+            # Fall back to a cost estimate when no run type is supplied
+            if not run_type:
+                run_type = KGRunType.ESTIMATE
+
+            # Apply runtime settings overrides
+            server_kg_creation_settings = (
+                self.providers.database.config.kg_creation_settings
+            )
+
+            if settings:
+                server_kg_creation_settings = update_settings_from_dict(
+                    server_kg_creation_settings, settings
+                )
+
+            # If the run type is estimate, return an estimate of the creation cost
+            if run_type is KGRunType.ESTIMATE:
+                raise NotImplementedError("Estimate is not implemented yet.")
+                # return await self.services["kg"].get_creation_estimate(
+                #     document_id=id, settings=server_kg_creation_settings
+                # )
+            else:
+                # Otherwise, create the graph; build the workflow input once so
+                # both the orchestrated and the inline paths can use it
+                workflow_input = {
+                    "document_id": str(id),
+                    "kg_creation_settings": server_kg_creation_settings.model_dump_json(),
+                    "user": auth_user.json(),
+                }
+
+                if run_with_orchestration:
+                    return await self.orchestration_provider.run_workflow(  # type: ignore
+                        "create-graph", {"request": workflow_input}, {}
+                    )
+                else:
+                    from core.main.orchestration import simple_kg_factory
+
+                    logger.info("Running create-graph without orchestration.")
+                    simple_kg = simple_kg_factory(self.services["kg"])
+                    await simple_kg["create-graph"](workflow_input)
+                    return {
+                        "message": "Graph created successfully.",
+                        "task_id": None,
+                    }
+
+        ##### ENTITIES ######
         @self.router.get(
-            "/chunks/{id}/entities",
+            "/chunks/{id}/graphs/entities",
             summary="List entities for a chunk",
             openapi_extra={
                 "x-codeSamples": [
                     {
                         "lang": "Python",
                         "source": textwrap.dedent(
                             """
                             from r2r import R2RClient

                             client = R2RClient("http://localhost:7272")
                             # when using auth, do client.login(...)

-                            result = client.chunks.list_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100)
+                            result = client.chunks.graphs.list_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100)
                             """
                         ),
                     },
                 ]
             },
         )
         @self.router.get(
-            "/documents/{id}/entities",
+            "/documents/{id}/graphs/entities",
             summary="List entities for a document",
             openapi_extra={
                 "x-codeSamples": [
                     {
                         "lang": "Python",
                         "source": textwrap.dedent(
                             """
                             from r2r import R2RClient

                             client = R2RClient("http://localhost:7272")
                             # when using auth, do client.login(...)
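                            # Editor's sketch (hypothetical client call): exercising the new
                            # POST /documents/{id}/graphs/{run_type} route above over plain HTTP,
                            # since an SDK method for it is not defined in this patch; the URL,
                            # port, and /v3 prefix are assumed from the surrounding samples.
                            import requests

                            response = requests.post(
                                "http://localhost:7272/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1/graphs/create",
                                headers={"Authorization": "Bearer YOUR_API_KEY"},
                                json={
                                    "settings": {"entity_types": ["PERSON", "ORG"]},
                                    "run_with_orchestration": True,
                                },
                            )
                            print(response.json())  # a task id is expected when orchestrated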
- result = client.documents.list_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.documents.graphs.list_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) """ ), }, @@ -153,7 +188,7 @@ def _setup_routes(self): }, ) @self.router.get( - "/collections/{id}/entities", + "/collections/{id}/graphs/entities", summary="List entities for a collection", openapi_extra={ "x-codeSamples": [ @@ -166,7 +201,7 @@ def _setup_routes(self): client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.list_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.collections.graphs.list_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) """ ), }, @@ -232,7 +267,7 @@ async def list_entities( } @self.router.post( - "/chunks/{id}/entities", + "/chunks/{id}/graphs/entities", summary="Create entities for a chunk", openapi_extra={ "x-codeSamples": [ @@ -245,7 +280,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.create_entities_v3(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.chunks.graphs.create_entities_v3(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -253,7 +288,7 @@ async def list_entities( }, ) @self.router.post( - "/documents/{id}/entities", + "/documents/{id}/graphs/entities", summary="Create entities for a document", openapi_extra={ "x-codeSamples": [ @@ -266,7 +301,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.documents.graphs.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -274,7 +309,7 @@ async def list_entities( }, ) @self.router.post( - "/collections/{id}/entities", + "/collections/{id}/graphs/entities", summary="Create entities for a collection", openapi_extra={ "x-codeSamples": [ @@ -287,7 +322,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.create_entities_v3(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.collections.graphs.create_entities_v3(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -362,7 +397,7 @@ async def create_entities_v3( ) @self.router.post( - "/chunks/{id}/entities/{entity_id}", + "/chunks/{id}/graphs/entities/{entity_id}", summary="Update an entity for a chunk", openapi_extra={ "x-codeSamples": [ @@ -375,7 +410,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
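                            # Editor's sketch (hypothetical): the path-based level inference the
                            # /chunks, /documents, and /collections route variants above rely on,
                            # modeled on _get_path_level in this patch; the import path is assumed.
                            from shared.abstractions.graph import EntityLevel

                            def level_from_path(path: str) -> EntityLevel:
                                if path.startswith("/chunks"):
                                    return EntityLevel.CHUNK
                                if path.startswith("/documents"):
                                    return EntityLevel.DOCUMENT
                                return EntityLevel.COLLECTION  # the remaining variant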
- result = client.chunks.update_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.chunks.graphs.update_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -383,7 +418,7 @@ async def create_entities_v3( }, ) @self.router.post( - "/documents/{id}/entities/{entity_id}", + "/documents/{id}/graphs/entities/{entity_id}", summary="Update an entity for a document", openapi_extra={ "x-codeSamples": [ @@ -396,7 +431,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.documents.graphs.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -404,7 +439,7 @@ async def create_entities_v3( }, ) @self.router.post( - "/collections/{id}/entities/{entity_id}", + "/collections/{id}/graphs/entities/{entity_id}", summary="Update an entity for a collection", openapi_extra={ "x-codeSamples": [ @@ -417,7 +452,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.update_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.collections.graphs.update_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -473,7 +508,7 @@ async def update_entity( ) @self.router.delete( - "/chunks/{id}/entities/{entity_id}", + "/chunks/{id}/graphs/entities/{entity_id}", summary="Delete an entity for a chunk", openapi_extra={ "x-codeSamples": [ @@ -486,7 +521,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.chunks.graphs.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -494,7 +529,7 @@ async def update_entity( }, ) @self.router.delete( - "/documents/{id}/entities/{entity_id}", + "/documents/{id}/graphs/entities/{entity_id}", summary="Delete an entity for a document", openapi_extra={ "x-codeSamples": [ @@ -507,7 +542,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.documents.graphs.delete_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -515,7 +550,7 @@ async def update_entity( }, ) @self.router.delete( - "/collections/{id}/entities/{entity_id}", + "/collections/{id}/graphs/entities/{entity_id}", summary="Delete an entity for a collection", openapi_extra={ "x-codeSamples": [ @@ -528,7 +563,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
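                            # Editor's sketch (hypothetical): the path/body id-consistency rule
                            # this patch's update handlers enforce (see update_relationship below).
                            from typing import Optional
                            from uuid import UUID

                            def reconcile_ids(path_id: UUID, body_id: Optional[UUID]) -> UUID:
                                if body_id is not None and body_id != path_id:
                                    raise ValueError("ID in path and body do not match")
                                return path_id  # the body may simply omit the id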
- result = client.chunks.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.collections.graphs.delete_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -560,7 +595,7 @@ async def delete_entity( ##### RELATIONSHIPS ##### @self.router.get( - "/chunks/{id}/relationships", + "/chunks/{id}/graphs/relationships", summary="List relationships for a chunk", openapi_extra={ "x-codeSamples": [ @@ -573,7 +608,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.list_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.chunks.graphs.list_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -581,7 +616,7 @@ async def delete_entity( }, ) @self.router.get( - "/documents/{id}/relationships", + "/documents/{id}/graphs/relationships", summary="List relationships for a document", openapi_extra={ "x-codeSamples": [ @@ -594,7 +629,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.documents.graphs.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -615,7 +650,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.list_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.collections.graphs.list_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -674,7 +709,7 @@ async def list_relationships( } @self.router.post( - "/chunks/{id}/relationships", + "/chunks/{id}/graphs/relationships", summary="Create relationships for a chunk", openapi_extra={ "x-codeSamples": [ @@ -687,7 +722,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.create_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.chunks.graphs.create_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -695,7 +730,7 @@ async def list_relationships( }, ) @self.router.post( - "/documents/{id}/relationships", + "/documents/{id}/graphs/relationships", summary="Create relationships for a document", openapi_extra={ "x-codeSamples": [ @@ -708,7 +743,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.documents.graphs.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -716,7 +751,7 @@ async def list_relationships( }, ) @self.router.post( - "/collections/{id}/relationships", + "/collections/{id}/graphs/relationships", summary="Create relationships for a collection", openapi_extra={ "x-codeSamples": [ @@ -729,7 +764,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
- result = client.collections.create_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.collections.graphs.create_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -762,7 +797,7 @@ async def create_relationships( } @self.router.post( - "/chunks/{id}/relationships/{relationship_id}", + "/chunks/{id}/graphs/relationships/{relationship_id}", summary="Update a relationship for a chunk", openapi_extra={ "x-codeSamples": [ @@ -775,7 +810,49 @@ async def create_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.update_relationship(chunk_id="9fbe403b -c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + result = client.chunks.graphs.update_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + """ + ), + }, + ] + }, + ) + @self.router.post( + "/documents/{id}/graphs/relationships/{relationship_id}", + summary="Update a relationship for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.update_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + """ + ), + }, + ] + }, + ) + @self.router.post( + "/collections/{id}/graphs/relationships/{relationship_id}", + summary="Update a relationship for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.collections.graphs.update_relationship(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) """ ), }, @@ -784,6 +861,7 @@ async def create_relationships( ) @self.base_endpoint async def update_relationship( + request: Request, id: UUID = Path( ..., description="The ID of the chunk to update the relationship for.", @@ -801,19 +879,84 @@ async def update_relationship( "Only superusers can access this endpoint.", 403 ) + level = self._get_path_level(request) + + if not relationship.id: + relationship.id = relationship_id + else: + if relationship.id != relationship_id: + raise ValueError("Relationship ID in path and body do not match") + return await self.services["kg"].update_relationship_v3( - level=EntityLevel.CHUNK, - id=id, - relationship_id=relationship_id, relationship=relationship, ) @self.router.delete( - "/chunks/{id}/relationships/{relationship_id}", + "/chunks/{id}/graphs/relationships/{relationship_id}", summary="Delete a relationship for a chunk", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) 
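                            # Editor's sketch (hypothetical payload): relationship objects for the
                            # create endpoints above; field names follow the chunk_relationship
                            # columns (subject, predicate, object) defined later in this patch.
                            example_relationships = [
                                {
                                    "subject": "John Smith",
                                    "predicate": "WORKS_FOR",
                                    "object": "Acme Corp",
                                    "attributes": {"source": "manual", "confidence": 1.0},
                                },
                            ]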
+ + result = client.chunks.graphs.delete_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, + ) + @self.router.delete( + "/documents/{id}/graphs/relationships/{relationship_id}", + summary="Delete a relationship for a document", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.documents.graphs.delete_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, + ) + @self.router.delete( + "/collections/{id}/graphs/relationships/{relationship_id}", + summary="Delete a relationship for a collection", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": textwrap.dedent( + """ + from r2r import R2RClient + + client = R2RClient("http://localhost:7272") + # when using auth, do client.login(...) + + result = client.collections.graphs.delete_relationship(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + """ + ), + }, + ] + }, ) @self.base_endpoint async def delete_relationship( + request: Request, id: UUID = Path( ..., description="The ID of the chunk to delete the relationship for.", @@ -828,16 +971,113 @@ async def delete_relationship( "Only superusers can access this endpoint.", 403 ) + level = self._get_path_level(request) + if level == EntityLevel.CHUNK: + chunk_ids = [id] + relationship = Relationship(id = relationship_id, chunk_ids = chunk_ids) + elif level == EntityLevel.DOCUMENT: + relationship = Relationship(id = relationship_id, document_id = id) + else: + relationship = Relationship(id = relationship_id, collection_id = id) + return await self.services["kg"].delete_relationship_v3( - level=EntityLevel.CHUNK, - id=id, - relationship_id=relationship_id, + relationship=relationship, + ) + + + ################### COMMUNITIES ################### + + @self.router.post( + "/collections/{id}/graphs/", + ) + @self.base_endpoint + async def create_communities( + request: Request, + id: UUID = Path(..., description="The ID of the collection to create communities for."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + # run enrich graph workflow + + @self.router.post( + "/collections/{id}/graphs/communities", + summary="Create communities", + ) + @self.base_endpoint + async def create_communities( + request: Request, + id: UUID = Path(..., description="The ID of the collection to create communities for."), + communities: list[Community] = Body(..., description="The communities to create."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + for community in communities: + if not community.collection_id: + community.collection_id = id + else: + if community.collection_id != id: + raise ValueError("Collection ID in path and body do not match") + + return await self.services["kg"].create_communities_v3(communities) + + + @self.router.get( + "/collections/{id}/graphs/communities", + summary="Get communities", + ) + @self.base_endpoint + async def get_communities( + request: Request, + id: UUID = Path(..., description="The ID of 
the collection to get communities for."), + offset: int = Query(0, description="Number of communities to skip"), + limit: int = Query(100, description="Maximum number of communities to return"), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + return await self.services["kg"].get_communities_v3( + collection_id=id, + offset=offset, + limit=limit, + ) + + @self.router.delete( + "/collections/{id}/graphs/communities/{community_id}", + summary="Delete a community", + ) + @self.base_endpoint + async def delete_community( + request: Request, + id: UUID = Path(..., description="The ID of the collection to delete the community from."), + community_id: UUID = Path(..., description="The ID of the community to delete."), + auth_user=Depends(self.providers.auth.auth_wrapper), + ): + if not auth_user.is_superuser: + raise R2RException("Only superusers can access this endpoint.", 403) + + community = Community(id=community_id, collection_id=id) + + return await self.services["kg"].delete_community_v3( + community=community, ) - + + + ################### GRAPHS ################### + + @self.base_endpoint + async def create_entities( + request: Request, + ): + pass # Graph-level operations @self.router.post( - "/graphs/{collection_id}", + "/graphs/", summary="Create a new graph", openapi_extra={ "x-codeSamples": [ @@ -876,84 +1116,12 @@ async def delete_relationship( }, ) @self.base_endpoint - async def create_graph( - collection_id: UUID = Path( - default=..., - description="Collection ID to create graph for.", - ), - run_type: Optional[KGRunType] = Body( - default=None, - description="Run type for the graph creation process.", - ), - settings: Optional[KGCreationSettings] = Body( - default=None, - description="Settings for the graph creation process.", - ), - run_with_orchestration: Optional[bool] = Body(True), + async def create_empty_graph( auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: - """Creates a new knowledge graph by extracting entities and relationships from documents in a collection. - - The graph creation process involves: - 1. Parsing documents into semantic chunks - 2. Extracting entities and relationships using LLMs or NER - 3. 
Building a connected knowledge graph structure - """ - - settings = settings.dict() if settings else None - if not auth_user.is_superuser: - logger.warning("Implement permission checks here.") - - logger.info(f"Running create-graph on collection {collection_id}") - - # If no collection ID is provided, use the default user collection - if not collection_id: - collection_id = generate_default_user_collection_id( - auth_user.id - ) - - # If no run type is provided, default to estimate - if not run_type: - run_type = KGRunType.ESTIMATE - - # Apply runtime settings overrides - server_kg_creation_settings = ( - self.providers.database.config.kg_creation_settings - ) - - if settings: - server_kg_creation_settings = update_settings_from_dict( - server_kg_creation_settings, settings - ) - - # If the run type is estimate, return an estimate of the creation cost - if run_type is KGRunType.ESTIMATE: - return await self.services["kg"].get_creation_estimate( - collection_id, server_kg_creation_settings - ) - else: - - # Otherwise, create the graph - if run_with_orchestration: - workflow_input = { - "collection_id": str(collection_id), - "kg_creation_settings": server_kg_creation_settings.model_dump_json(), - "user": auth_user.json(), - } + ): + - return await self.orchestration_provider.run_workflow( # type: ignore - "create-graph", {"request": workflow_input}, {} - ) - else: - from core.main.orchestration import simple_kg_factory - logger.info("Running create-graph without orchestration.") - simple_kg = simple_kg_factory(self.service) - await simple_kg["create-graph"](workflow_input) - return { - "message": "Graph created successfully.", - "task_id": None, - } @self.router.get( "/graphs/{collection_id}", diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 63e408692..34509d9fe 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -27,6 +27,9 @@ def hatchet_kg_factory( def get_input_data_dict(input_data): for key, value in input_data.items(): + if key == "document_id": + input_data[key] = uuid.UUID(value) + if key == "collection_id": input_data[key] = uuid.UUID(value) @@ -164,22 +167,30 @@ async def get_document_ids_for_create_graph( input_data = get_input_data_dict( context.workflow_input()["request"] ) - collection_id = input_data["collection_id"] - return_val = { - "document_ids": [ - str(doc_id) - for doc_id in await self.kg_service.get_document_ids_for_create_graph( - collection_id=collection_id, - **input_data["kg_creation_settings"], + if "collection_id" in input_data: + + collection_id = input_data["collection_id"] + + return_val = { + "document_ids": [ + str(doc_id) + for doc_id in await self.kg_service.get_document_ids_for_create_graph( + collection_id=collection_id, + **input_data["kg_creation_settings"], + ) + ] + } + + if len(return_val["document_ids"]) == 0: + raise ValueError( + "No documents to process, either all documents to create the graph were already created or in progress, or the collection is empty." ) - ] - } - if len(return_val["document_ids"]) == 0: - raise ValueError( - "No documents to process, either all documents to create the graph were already created or in progress, or the collection is empty." 
- ) + else: + return_val = { + "document_ids": [str(input_data["document_id"])] + } return return_val diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 7751ee5d0..6660bded4 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -16,6 +16,7 @@ R2RException, Entity, Relationship, + Community, ) from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider from core.telemetry.telemetry_decorator import telemetry_event @@ -176,7 +177,6 @@ async def delete_entity_v3( ################### RELATIONSHIPS ################### - @telemetry_event("list_relationships_v3") async def list_relationships_v3( self, @@ -207,6 +207,61 @@ async def create_relationships_v3( return await self.providers.database.graph_handler.relationships.create( relationships ) + + @telemetry_event("delete_relationship_v3") + async def delete_relationship_v3( + self, + relationship: Relationship, + **kwargs, + ): + return await self.providers.database.graph_handler.relationships.delete(relationship) + + + @telemetry_event("update_relationship_v3") + async def update_relationship_v3( + self, + relationship: Relationship, + **kwargs, + ): + return await self.providers.database.graph_handler.relationships.update(relationship) + + + ################### COMMUNITIES ################### + + @telemetry_event("create_communities_v3") + async def create_communities_v3( + self, + communities: list[Community], + **kwargs, + ): + return await self.providers.database.graph_handler.communities.create(communities) + + @telemetry_event("update_community_v3") + async def update_community_v3( + self, + community: Community, + **kwargs, + ): + return await self.providers.database.graph_handler.communities.update(community) + + @telemetry_event("delete_community_v3") + async def delete_community_v3( + self, + community: Community, + **kwargs, + ): + return await self.providers.database.graph_handler.communities.delete(community) + + @telemetry_event("list_communities_v3") + async def list_communities_v3( + self, + id: UUID, + level: EntityLevel, + **kwargs, + ): + return await self.providers.database.graph_handler.communities.get(id, level) + + ################### GRAPH ################### @telemetry_event("get_document_ids_for_create_graph") async def get_document_ids_for_create_graph( diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index cd4fbacb7..43fafb4ef 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -297,7 +297,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: async def create_tables(self) -> None: """Create the relationships table if it doesn't exist.""" QUERY = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("relationship")} ( + CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_relationship")} ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), sid SERIAL NOT NULL, subject TEXT NOT NULL, @@ -310,14 +310,15 @@ async def create_tables(self) -> None: predicate_embedding FLOAT[], chunk_ids UUID[], document_id UUID, + collection_id UUID, attributes JSONB DEFAULT '{{}}'::jsonb, created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP ); - CREATE INDEX IF NOT EXISTS relationship_subject_idx ON {self._get_table_name("relationship")} (subject); - CREATE INDEX IF NOT EXISTS relationship_object_idx ON {self._get_table_name("relationship")} (object); - CREATE INDEX IF NOT EXISTS 
relationship_predicate_idx ON {self._get_table_name("relationship")} (predicate); - CREATE INDEX IF NOT EXISTS relationship_document_id_idx ON {self._get_table_name("relationship")} (document_id); + CREATE INDEX IF NOT EXISTS relationship_subject_idx ON {self._get_table_name("chunk_relationship")} (subject); + CREATE INDEX IF NOT EXISTS relationship_object_idx ON {self._get_table_name("chunk_relationship")} (object); + CREATE INDEX IF NOT EXISTS relationship_predicate_idx ON {self._get_table_name("chunk_relationship")} (predicate); + CREATE INDEX IF NOT EXISTS relationship_document_id_idx ON {self._get_table_name("chunk_relationship")} (document_id); """ await self.connection_manager.execute_query(QUERY) @@ -329,7 +330,7 @@ async def create(self, relationships: list[Relationship]) -> None: """Create a new relationship in the database.""" await _add_objects( objects=[relationship.__dict__ for relationship in relationships], - full_table_name=self._get_table_name("relationship"), + full_table_name=self._get_table_name("chunk_relationship"), connection_manager=self.connection_manager, ) @@ -360,7 +361,7 @@ async def get(self, params.append(relationship_types) QUERY = f""" - SELECT * FROM {self._get_table_name("relationship")} + SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE {filter} OFFSET ${len(params)+1} LIMIT ${len(params) + 2} """ @@ -370,81 +371,107 @@ async def get(self, rows = await self.connection_manager.fetch_query(QUERY, params) QUERY_COUNT = f""" - SELECT COUNT(*) FROM {self._get_table_name("relationship")} WHERE {filter} + SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE {filter} """ count = (await self.connection_manager.fetch_query(QUERY_COUNT, params[:-2]))[0]["count"] return [Relationship(**row) for row in rows], count async def update(self, relationship: Relationship) -> None: + return await _update_object( + object=relationship.__dict__, + full_table_name=self._get_table_name("chunk_relationship"), + connection_manager=self.connection_manager, + id_column="id", + ) - # check if the relationship already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("relationship")} WHERE id = $1 - """ - count = (await self.connection_manager.fetch_query(QUERY, [relationship.id]))[0]["count"] - if count == 0: - raise R2RException("Relationship does not exist", 204) - - return await self._add_objects([relationship], "relationship", [relationship.id]) - - async def delete(self, relationship_id: UUID) -> None: + async def delete(self, relationship: Relationship) -> None: """Delete a relationship from the database.""" QUERY = f""" - DELETE FROM {self._get_table_name("relationship")} + DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, [relationship_id]) -class PostgresCommunityHandler(CommunityHandler): - - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.project_name = kwargs.get("project_name") - self.connection_manager = kwargs.get("connection_manager") - - async def create_tables(self) -> None: - pass - - async def create(self, communities: list[Community]) -> None: - pass + return await self.connection_manager.execute_query(QUERY, [relationship.id]) - async def get(self, community_id: UUID) -> list[Community]: - pass - - async def update(self, community: Community) -> None: - pass - - async def delete(self, community_id: UUID) -> None: - pass - -class PostgresCommunityInfoHandler(CommunityInfoHandler): +class 
PostgresCommunityHandler(CommunityHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: self.project_name = kwargs.get("project_name") self.connection_manager = kwargs.get("connection_manager") + self.dimension = kwargs.get("dimension") + self.quantization_type = kwargs.get("quantization_type") async def create_tables(self) -> None: - pass - - async def create(self, community_infos: list[CommunityInfo]) -> None: - pass - - async def get(self, community_info_id: UUID) -> list[CommunityInfo]: - pass - - async def update(self, community_info: CommunityInfo) -> None: - pass - - async def delete(self, community_info_id: UUID) -> None: - pass - + vector_column_str = _decorate_vector_type( + f"({self.dimension})", self.quantization_type + ) + + # communities table, result of the Leiden algorithm + query = f""" + CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, + node TEXT NOT NULL, + cluster INT NOT NULL, + parent_cluster INT, + level INT NOT NULL, + is_final_cluster BOOLEAN NOT NULL, + relationship_ids INT[] NOT NULL, + collection_id UUID NOT NULL + );""" + await self.connection_manager.execute_query(query) + # communities_report table + query = f""" + CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + sid SERIAL PRIMARY KEY, + community_number INT NOT NULL, + collection_id UUID NOT NULL, + level INT NOT NULL, + name TEXT NOT NULL, + summary TEXT NOT NULL, + findings TEXT[] NOT NULL, + rating FLOAT NOT NULL, + rating_explanation TEXT NOT NULL, + embedding {vector_column_str} NOT NULL, + attributes JSONB, + UNIQUE (community_number, level, collection_id) + );""" + await self.connection_manager.execute_query(query) + async def create(self, communities: list[Community]) -> None: + await _add_objects( + objects=[community.__dict__ for community in communities], + full_table_name=self._get_table_name("community"), + connection_manager=self.connection_manager, + ) + async def update(self, community: Community) -> None: + return await _update_object( + object=community.__dict__, + full_table_name=self._get_table_name("community"), + connection_manager=self.connection_manager, + id_column="id", + ) + async def delete(self, community: Community) -> None: + return await _delete_object( + object_id=community.id, + full_table_name=self._get_table_name("community"), + connection_manager=self.connection_manager, + ) + async def get(self, collection_id: UUID, offset: int, limit: int) -> list[Community]: + QUERY = f""" + SELECT * FROM {self._get_table_name("community")} WHERE collection_id = $1 + OFFSET $2 LIMIT $3 + """ + params = [collection_id, offset, limit] + return [Community(**row) for row in await self.connection_manager.fetch_query(QUERY, params)] class PostgresGraphHandler(GraphHandler): @@ -470,13 +497,11 @@ def __init__( self.entities = PostgresEntityHandler(*args, **kwargs) self.relationships = PostgresRelationshipHandler(*args, **kwargs) self.communities = PostgresCommunityHandler(*args, **kwargs) - self.community_infos = PostgresCommunityInfoHandler(*args, **kwargs) self.handlers = [ self.entities, self.relationships, self.communities, - self.community_infos, ] async def create_tables(self) -> None: @@ -2401,6 +2426,7 @@ async def _update_object( UPDATE {full_table_name} SET {set_clause} WHERE {id_column} = ${len(non_null_attrs) + 1} + RETURNING id """ # Prepare parameters: values for SET clause + ID value for WHERE clause @@ 
-2410,7 +2436,9 @@ async def _update_object(
     ]
     params.append(object[id_column])
 
-    return await connection_manager.execute_many(QUERY, [tuple(params)])  # type: ignore
+    ret = await connection_manager.execute_many(QUERY, [tuple(params)])  # type: ignore
+    return ret
 
 async def _delete_object(
     object_id: UUID,
diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py
index 25ef7f2a8..c0e795839 100644
--- a/py/shared/abstractions/graph.py
+++ b/py/shared/abstractions/graph.py
@@ -96,19 +96,19 @@ class Relationship(R2RSerializable):
 
     id: Optional[UUID] = None
     sid: Optional[int] = None #serial ID
 
-    subject: str
+    subject: Optional[str] = None
     """The source entity name."""
 
-    predicate: str
+    predicate: Optional[str] = None
     """A description of the relationship (optional)."""
 
-    subject_id: UUID | None = None
+    subject_id: Optional[UUID] = None
     """The source entity ID (optional)."""
 
-    object_id: UUID | None = None
+    object_id: Optional[UUID] = None
    """The target entity ID (optional)."""
 
-    object: str
+    object: Optional[str] = None
     """The target entity name."""
 
     weight: float | None = 1.0
@@ -123,7 +123,7 @@ class Relationship(R2RSerializable):
     chunk_ids: list[UUID] = []
     """List of text unit IDs in which the relationship appears (optional)."""
 
-    document_id: UUID | None = None
+    document_id: Optional[UUID] = None
     """Document ID in which the relationship appears (optional)."""
 
     attributes: dict[str, Any] | str = {}
@@ -303,5 +303,33 @@ class KGExtraction(R2RSerializable):
     entities: list[Entity]
     relationships: list[Relationship]
 
+class Graph(R2RSerializable):
+    """A request to create a graph."""
+
+    id: Optional[uuid.UUID] = None
+    name: Optional[str] = None
+    description: Optional[str] = None
+    document_ids: list[uuid.UUID] = []
+    collection_ids: list[uuid.UUID] = []
+    statistics: dict[str, Any] = {}
+    created_at: datetime
+    updated_at: datetime # Implementation is not yet complete
+    status: str = "pending"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    @classmethod
+    def from_dict(cls, d: dict[str, Any]) -> "Graph":
+        return Graph(
+            id=d["id"],
+            name=d["name"],
+            description=d["description"],
+            document_ids=d["document_ids"],
+            collection_ids=d["collection_ids"],
+            statistics=d["statistics"],
+            created_at=d["created_at"],
+            updated_at=d["updated_at"],
+            status=d["status"],
+        )
diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py
index a942a2bdf..7c9bbcef0 100644
--- a/py/shared/abstractions/kg.py
+++ b/py/shared/abstractions/kg.py
@@ -10,7 +10,7 @@ class KGRunType(str, Enum):
     """Type of KG run."""
 
     ESTIMATE = "estimate"
-    RUN = "run"
+    CREATE = "create"
 
     def __str__(self):
         return self.value

From 314ad01b101d51aee0745b1b7e1cd6826308b04d Mon Sep 17 00:00:00 2001
From: Shreyas Pimpalgaonkar
Date: Thu, 14 Nov 2024 17:52:51 -0800
Subject: [PATCH 15/21] up

---
 py/core/base/providers/database.py            |  28 ++-
 py/core/main/api/v2/kg_router.py              |   3 +-
 py/core/main/api/v3/graph_router.py           |  13 +-
 .../main/orchestration/hatchet/kg_workflow.py |   4 +-
 py/core/main/services/kg_service.py           |  60 +++--
 py/core/pipes/kg/relationships_extraction.py  |  10 +-
 py/core/pipes/kg/storage.py                   |   6 +-
 py/core/providers/database/kg.py              | 214 ++++++++++--------
 py/core/providers/database/kg_tmp/entity.py   |   7 +-
 py/core/providers/database/kg_tmp/main.py     |   1 -
 py/shared/abstractions/graph.py               |  11 +-
 py/shared/api/models/kg/responses_v3.py       |   3 +-
 .../ingestion/test_contextual_embedding.py    |   4 +-
 py/tests/core/providers/kg/test_kg_logic.py   |   1 +
 14 files changed,
206 insertions(+), 159 deletions(-) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 328433ce6..a6556c2ac 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -17,7 +17,7 @@ from core.base.abstractions import ( Community, - CommunityInfo, + CommunityInfo, Entity, Graph, KGExtraction, @@ -635,7 +635,7 @@ async def list_chunks( # ) -> Any: # """Update an entity in storage.""" # pass - + # @abstractmethod # async def delete_entity( # self, @@ -873,9 +873,8 @@ async def list_chunks( # raise NotImplementedError - class EntityHandler(Handler): - + @abstractmethod async def create(self, *args: Any, **kwargs: Any) -> None: """Create entities in storage.""" @@ -918,6 +917,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: """Delete relationships from storage.""" pass + class CommunityHandler(Handler): @abstractmethod async def create(self, *args: Any, **kwargs: Any) -> None: @@ -939,6 +939,7 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: """Delete communities from storage.""" pass + class CommunityInfoHandler(Handler): @abstractmethod async def create(self, *args: Any, **kwargs: Any) -> None: @@ -948,7 +949,7 @@ async def create(self, *args: Any, **kwargs: Any) -> None: @abstractmethod async def get(self, *args: Any, **kwargs: Any) -> list[CommunityInfo]: """Get community info from storage.""" - pass + pass @abstractmethod async def update(self, *args: Any, **kwargs: Any) -> None: @@ -960,8 +961,9 @@ async def delete(self, *args: Any, **kwargs: Any) -> None: """Delete community info from storage.""" pass + class GraphHandler(Handler): - + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) @@ -997,7 +999,6 @@ async def remove_document(self, *args: Any, **kwargs: Any) -> None: pass - class PromptHandler(Handler): """Abstract base class for prompt handling operations.""" @@ -1672,7 +1673,9 @@ async def get_entity_map( self, offset: int, limit: int, document_id: UUID ) -> dict[str, dict[str, list[dict[str, Any]]]]: """Forward to KG handler get_entity_map method.""" - return await self.graph_handler.get_entity_map(offset, limit, document_id) + return await self.graph_handler.get_entity_map( + offset, limit, document_id + ) # Community methods async def add_community_info(self, communities: list[Any]) -> None: @@ -1696,9 +1699,7 @@ async def get_communities( community_numbers=community_numbers, ) - async def add_community( - self, community: Community - ) -> None: + async def add_community(self, community: Community) -> None: """Forward to KG handler add_community method.""" return await self.graph_handler.add_community(community) @@ -1710,9 +1711,7 @@ async def get_community_details( community_number, collection_id ) - async def get_community( - self, collection_id: UUID - ) -> list[Community]: + async def get_community(self, collection_id: UUID) -> list[Community]: """Forward to KG handler get_community method.""" return await self.graph_handler.get_community(collection_id) @@ -2079,4 +2078,3 @@ async def list_chunks( return await self.vector_handler.list_chunks( offset, limit, filters, include_vectors ) - diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 8fdbdfeea..7b00a1a2f 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -293,7 +293,8 @@ async def get_entities( @self.base_endpoint async def get_relationships( collection_id: Optional[UUID] = Query( - None, description="Collection 
ID to retrieve relationships from." + None, + description="Collection ID to retrieve relationships from.", ), entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index a8ed83091..e24b39318 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -72,7 +72,6 @@ def _get_path_level(self, request: Request) -> EntityLevel: def _setup_routes(self): # create graph new endpoint - @self.router.post( "/documents/{id}/graphs/{run_type}", ) @@ -354,7 +353,7 @@ async def create_entities_v3( else: level = EntityLevel.COLLECTION - # set entity level if not set + # set entity level if not set for entity in entities: if entity.level: if entity.level != level: @@ -371,9 +370,9 @@ async def create_entities_v3( raise R2RException( "Entity extraction IDs must include the chunk ID or should be empty.", 400 ) - + elif level == EntityLevel.DOCUMENT: - for entity in entities: + for entity in entities: if entity.document_id: if entity.document_id != id: raise R2RException( @@ -384,7 +383,7 @@ async def create_entities_v3( elif level == EntityLevel.COLLECTION: for entity in entities: - if entity.collection_id: + if entity.collection_id: if entity.collection_id != id: raise R2RException( "Entity collection IDs must match the collection ID or should be empty.", 400 @@ -880,7 +879,7 @@ async def update_relationship( ) level = self._get_path_level(request) - + if not relationship.id: relationship.id = relationship_id else: @@ -1119,7 +1118,7 @@ async def create_entities( async def create_empty_graph( auth_user=Depends(self.providers.auth.auth_wrapper), ): - + diff --git a/py/core/main/orchestration/hatchet/kg_workflow.py b/py/core/main/orchestration/hatchet/kg_workflow.py index 34509d9fe..2462c7116 100644 --- a/py/core/main/orchestration/hatchet/kg_workflow.py +++ b/py/core/main/orchestration/hatchet/kg_workflow.py @@ -188,9 +188,7 @@ async def get_document_ids_for_create_graph( ) else: - return_val = { - "document_ids": [str(input_data["document_id"])] - } + return_val = {"document_ids": [str(input_data["document_id"])]} return return_val diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 6660bded4..d554f9412 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -132,7 +132,7 @@ async def create_entities_v3( return await self.providers.database.graph_handler.entities.create( entities, **kwargs ) - + @telemetry_event("list_entities_v3") async def list_entities_v3( self, @@ -146,13 +146,13 @@ async def list_entities_v3( **kwargs, ): return await self.providers.database.graph_handler.entities.get( - level=level, - id=id, - entity_names=entity_names, - entity_categories=entity_categories, - attributes=attributes, - offset=offset, - limit=limit + level=level, + id=id, + entity_names=entity_names, + entity_categories=entity_categories, + attributes=attributes, + offset=offset, + limit=limit, ) @telemetry_event("update_entity_v3") @@ -204,18 +204,23 @@ async def create_relationships_v3( relationships: list[Relationship], **kwargs, ): - return await self.providers.database.graph_handler.relationships.create( - relationships + return ( + await self.providers.database.graph_handler.relationships.create( + relationships + ) ) - + @telemetry_event("delete_relationship_v3") async def delete_relationship_v3( self, relationship: Relationship, **kwargs, ): - return await 
self.providers.database.graph_handler.relationships.delete(relationship) - + return ( + await self.providers.database.graph_handler.relationships.delete( + relationship + ) + ) @telemetry_event("update_relationship_v3") async def update_relationship_v3( @@ -223,34 +228,43 @@ async def update_relationship_v3( relationship: Relationship, **kwargs, ): - return await self.providers.database.graph_handler.relationships.update(relationship) - + return ( + await self.providers.database.graph_handler.relationships.update( + relationship + ) + ) ################### COMMUNITIES ################### - @telemetry_event("create_communities_v3") + @telemetry_event("create_communities_v3") async def create_communities_v3( self, communities: list[Community], **kwargs, ): - return await self.providers.database.graph_handler.communities.create(communities) - + return await self.providers.database.graph_handler.communities.create( + communities + ) + @telemetry_event("update_community_v3") async def update_community_v3( self, community: Community, **kwargs, ): - return await self.providers.database.graph_handler.communities.update(community) - + return await self.providers.database.graph_handler.communities.update( + community + ) + @telemetry_event("delete_community_v3") async def delete_community_v3( self, community: Community, **kwargs, ): - return await self.providers.database.graph_handler.communities.delete(community) + return await self.providers.database.graph_handler.communities.delete( + community + ) @telemetry_event("list_communities_v3") async def list_communities_v3( @@ -259,7 +273,9 @@ async def list_communities_v3( level: EntityLevel, **kwargs, ): - return await self.providers.database.graph_handler.communities.get(id, level) + return await self.providers.database.graph_handler.communities.get( + id, level + ) ################### GRAPH ################### diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index f2dba7af4..3513c0728 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -179,9 +179,7 @@ def parse_fn(response_str: str) -> Any: entities, relationships = parse_fn(kg_extraction) return KGExtraction( - chunk_ids=[ - extraction.id for extraction in extractions - ], + chunk_ids=[extraction.id for extraction in extractions], document_id=extractions[0].document_id, entities=entities, relationships=relationships, @@ -268,8 +266,10 @@ async def _run_logic( # type: ignore ) if filter_out_existing_chunks: - existing_chunk_ids = await self.database_provider.get_existing_entity_chunk_ids( - document_id=document_id + existing_chunk_ids = ( + await self.database_provider.get_existing_entity_chunk_ids( + document_id=document_id + ) ) extractions = [ extraction diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 8e846d5dd..b5cdb31e4 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -53,9 +53,9 @@ async def store( Stores a batch of knowledge graph extractions in the graph database. """ try: - # clean up and remove this method. + # clean up and remove this method. 
# make add_kg_extractions a method in the GraphHandler - + total_entities, total_relationships = 0, 0 for extraction in kg_extractions: @@ -95,7 +95,7 @@ async def store( ) return (total_entities, total_relationships) - + except Exception as e: error_message = f"Failed to store knowledge graph extractions in the database: {e}" logger.error(error_message) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 43fafb4ef..0c6909fc1 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -18,7 +18,13 @@ Relationship, ) -from core.base.providers.database import GraphHandler, EntityHandler, RelationshipHandler, CommunityHandler, CommunityInfoHandler +from core.base.providers.database import ( + GraphHandler, + EntityHandler, + RelationshipHandler, + CommunityHandler, + CommunityInfoHandler, +) from core.base.abstractions import ( CommunityInfo, @@ -40,11 +46,11 @@ class PostgresEntityHandler(EntityHandler): """Handler for managing entities in PostgreSQL database. - + Provides methods for CRUD operations on entities at different levels (chunk, document, collection). Handles creation of database tables and management of entity data. """ - + def __init__(self, *args: Any, **kwargs: Any) -> None: """Initialize the PostgresEntityHandler. @@ -59,7 +65,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( project_name=kwargs.get("project_name"), - connection_manager=kwargs.get("connection_manager") + connection_manager=kwargs.get("connection_manager"), ) async def create_tables(self) -> None: @@ -69,7 +75,7 @@ async def create_tables(self) -> None: - chunk_entity: For storing chunk-level entities - document_entity: For storing document-level entities with embeddings - collection_entity: For storing deduplicated collection-level entities - + Each table has appropriate columns and constraints for its level. """ vector_column_str = _decorate_vector_type( @@ -125,20 +131,20 @@ async def create_tables(self) -> None: async def create(self, entities: list[Entity]) -> None: """Create new entities in the database. - + Args: entities: List of Entity objects to create. All entities must be of the same level. - + Raises: ValueError: If entity level is not set or if entities have different levels. """ # TODO: move this the router layer # assert that all entities are of the same level - entity_level = entities[0].level + entity_level = entities[0].level if entity_level is None: raise ValueError("Entity level is not set") - + for entity in entities: if entity.level != entity_level: raise ValueError("All entities must be of the same level") @@ -147,11 +153,11 @@ async def create(self, entities: list[Entity]) -> None: objects=[entity.__dict__ for entity in entities], full_table_name=self._get_table_name(entity_level + "_entity"), connection_manager=self.connection_manager, - exclude_attributes=["level"] + exclude_attributes=["level"], ) async def get( - self, + self, level: EntityLevel, id: Optional[UUID] = None, entity_names: Optional[list[str]] = None, @@ -220,11 +226,12 @@ async def get( QUERY = f""" SELECT COUNT(*) from {self._get_table_name(level + "_entity")} WHERE {filter} """ - count = (await self.connection_manager.fetch_query(QUERY, params[:-2]))[0]["count"] + count = ( + await self.connection_manager.fetch_query(QUERY, params[:-2]) + )[0]["count"] return output, count - async def update(self, entity: Entity) -> None: """Update an existing entity in the database. 
@@ -252,25 +259,25 @@ async def update(self, entity: Entity) -> None: QUERY = f""" SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE {filter} """ - count = ( - await self.connection_manager.fetch_query( - QUERY, params - ) - )[0]["count"] + count = (await self.connection_manager.fetch_query(QUERY, params))[0][ + "count" + ] # don't override the chunk_ids entity.chunk_ids = None entity.level = None # get non null attributes - non_null_attributes = [k for k, v in entity.to_dict().items() if v is not None] + non_null_attributes = [ + k for k, v in entity.to_dict().items() if v is not None + ] if count == 0: raise R2RException("Entity does not exist", 204) return await _update_object( - object=entity.__dict__, - full_table_name=self._get_table_name(table_name), + object=entity.__dict__, + full_table_name=self._get_table_name(table_name), connection_manager=self.connection_manager, id_column="id", ) @@ -289,6 +296,7 @@ async def delete(self, entity: Entity) -> None: connection_manager=self.connection_manager, ) + class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: self.project_name = kwargs.get("project_name") @@ -334,14 +342,15 @@ async def create(self, relationships: list[Relationship]) -> None: connection_manager=self.connection_manager, ) - async def get(self, - id: UUID, - level: EntityLevel, - entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - offset: int = 0, - limit: int = -1 + async def get( + self, + id: UUID, + level: EntityLevel, + entity_names: Optional[list[str]] = None, + relationship_types: Optional[list[str]] = None, + attributes: Optional[list[str]] = None, + offset: int = 0, + limit: int = -1, ) -> list[Relationship]: """Get relationships from storage by ID.""" @@ -351,7 +360,7 @@ async def get(self, }[level] params = [id] - + if entity_names: filter += " AND (subject = ANY($2) OR object = ANY($2))" params.append(entity_names) @@ -359,21 +368,22 @@ async def get(self, if relationship_types: filter += " AND predicate = ANY($3)" params.append(relationship_types) - + QUERY = f""" SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE {filter} OFFSET ${len(params)+1} LIMIT ${len(params) + 2} """ - params.extend([offset, limit]) rows = await self.connection_manager.fetch_query(QUERY, params) - + QUERY_COUNT = f""" SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE {filter} """ - count = (await self.connection_manager.fetch_query(QUERY_COUNT, params[:-2]))[0]["count"] + count = ( + await self.connection_manager.fetch_query(QUERY_COUNT, params[:-2]) + )[0]["count"] return [Relationship(**row) for row in rows], count @@ -391,7 +401,10 @@ async def delete(self, relationship: Relationship) -> None: DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 """ - return await self.connection_manager.execute_query(QUERY, [relationship.id]) + return await self.connection_manager.execute_query( + QUERY, [relationship.id] + ) + class PostgresCommunityHandler(CommunityHandler): @@ -406,7 +419,7 @@ async def create_tables(self) -> None: vector_column_str = _decorate_vector_type( f"({self.dimension})", self.quantization_type ) - + # communities table, result of the Leiden algorithm query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( @@ -465,20 +478,25 @@ async def delete(self, community: Community) -> None: connection_manager=self.connection_manager, ) - async def 
get(self, collection_id: UUID, offset: int, limit: int) -> list[Community]: + async def get( + self, collection_id: UUID, offset: int, limit: int + ) -> list[Community]: QUERY = f""" SELECT * FROM {self._get_table_name("community")} WHERE collection_id = $1 OFFSET $2 LIMIT $3 """ params = [collection_id, offset, limit] - return [Community(**row) for row in await self.connection_manager.fetch_query(QUERY, params)] + return [ + Community(**row) + for row in await self.connection_manager.fetch_query(QUERY, params) + ] class PostgresGraphHandler(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL.""" def __init__( - self, + self, # project_name: str, # connection_manager: PostgresConnectionManager, # collection_handler: PostgresCollectionHandler, @@ -516,7 +534,7 @@ async def create_tables(self) -> None: attributes JSONB NOT NULL ); """ - + await self.connection_manager.execute_query(QUERY) for handler in self.handlers: @@ -528,13 +546,17 @@ async def create(self, graph: Graph) -> None: INSERT INTO {self._get_table_name("graph")} (id, status, created_at, updated_at, document_ids, collection_ids, attributes) VALUES ($1, $2, $3, $4, $5, $6, $7) """ - await self.connection_manager.execute_query(QUERY, *graph.to_dict().values()) + await self.connection_manager.execute_query( + QUERY, *graph.to_dict().values() + ) - async def update(self, graph: Graph) -> None: + async def update(self, graph: Graph) -> None: QUERY = f""" UPDATE {self._get_table_name("graph")} SET status = $2, updated_at = $3, document_ids = $4, collection_ids = $5, attributes = $6 WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, *graph.to_dict().values()) + await self.connection_manager.execute_query( + QUERY, *graph.to_dict().values() + ) async def delete(self, graph_id: UUID) -> None: QUERY = f""" @@ -542,36 +564,49 @@ async def delete(self, graph_id: UUID) -> None: """ await self.connection_manager.execute_query(QUERY, graph_id) - async def get(self, graph_id: UUID) -> Graph: QUERY = f""" SELECT * FROM {self._get_table_name("graph")} WHERE id = $1 """ - return Graph.from_dict(await self.connection_manager.fetch_query(QUERY, graph_id)) - + return Graph.from_dict( + await self.connection_manager.fetch_query(QUERY, graph_id) + ) + async def add_document(self, graph_id: UUID, document_id: UUID) -> None: QUERY = f""" UPDATE {self._get_table_name("graph")} SET document_ids = array_append(document_ids, $2) WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, graph_id, document_id) + await self.connection_manager.execute_query( + QUERY, graph_id, document_id + ) async def remove_document(self, graph_id: UUID, document_id: UUID) -> None: QUERY = f""" UPDATE {self._get_table_name("graph")} SET document_ids = array_remove(document_ids, $2) WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, graph_id, document_id) + await self.connection_manager.execute_query( + QUERY, graph_id, document_id + ) - async def add_collection(self, graph_id: UUID, collection_id: UUID) -> None: + async def add_collection( + self, graph_id: UUID, collection_id: UUID + ) -> None: QUERY = f""" UPDATE {self._get_table_name("graph")} SET collection_ids = array_append(collection_ids, $2) WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, graph_id, collection_id) + await self.connection_manager.execute_query( + QUERY, graph_id, collection_id + ) - async def remove_collection(self, graph_id: UUID, collection_id: UUID) -> None: + async def remove_collection( + self, graph_id: 
UUID, collection_id: UUID
+    ) -> None:
         QUERY = f"""
             UPDATE {self._get_table_name("graph")} SET collection_ids = array_remove(collection_ids, $2) WHERE id = $1
         """
-        await self.connection_manager.execute_query(QUERY, graph_id, collection_id)
+        await self.connection_manager.execute_query(
+            QUERY, graph_id, collection_id
+        )
 
 
 class PostgresGraphHandler_v1(GraphHandler):
@@ -716,10 +751,8 @@ async def create_tables(self):
 
         await self.connection_manager.execute_query(query)
 
-
     ################### ENTITY METHODS ###################
 
-
     # async def get_entities_v3(
     #     self,
     #     level: EntityLevel,
@@ -837,7 +870,6 @@ async def get_entities(
 
         return {"entities": entities, "total_entries": total_entries}
 
-
     async def add_entities(
         self,
         entities: list[Entity],
@@ -873,7 +905,6 @@ async def add_entities(
             cleaned_entities, table_name, conflict_columns
         )
 
-
     async def create_entities_v3(
         self, level: EntityLevel, id: UUID, entities: list[Entity]
     ) -> None:
@@ -909,7 +940,6 @@ async def create_entities_v3(
     #             QUERY, [entity.id, collection_id]
    #         )
 
-
     async def delete_node_via_document_id(
         self, document_id: UUID, collection_id: UUID
     ) -> None:
@@ -963,8 +993,6 @@ async def delete_node_via_document_id(
             return None
         return None
 
-
-
     ##################### RELATIONSHIP METHODS #####################
 
     async def add_relationships(
@@ -1057,7 +1085,6 @@ async def get_all_relationships(
         )
         return [Relationship(**relationship) for relationship in relationships]
 
-
     async def create_relationship(
         self, collection_id: UUID, relationship: Relationship
     ) -> None:
@@ -1106,7 +1133,6 @@ async def delete_relationship(self, relationship_id: UUID) -> None:
         """
         await self.connection_manager.execute_query(QUERY, [relationship_id])
 
-
     async def get_relationships(
         self,
         offset: int,
@@ -1170,9 +1196,7 @@ async def get_relationships(
 
     ####################### COMMUNITY METHODS #######################
 
-    async def get_communities(
-        self, collection_id: UUID
-    ) -> list[Community]:
+    async def get_communities(self, collection_id: UUID) -> list[Community]:
         QUERY = f"""
             SELECT * FROM {self._get_table_name("community")} WHERE collection_id = $1
         """
@@ -1191,8 +1215,6 @@ async def check_communities_exist(
         )
         return [item["community_number"] for item in community_numbers]
 
-
-
     async def add_community_info(
         self, communities: list[CommunityInfo]
     ) -> None:
@@ -1268,8 +1290,6 @@ async def get_communities(
             "communities": communities,
             "total_entries": total_entries,
         }
 
-
-
     async def get_community_details(
         self, community_number: int, collection_id: UUID
@@ -1335,10 +1355,7 @@ async def get_community_details(
 
         return level, entities, relationships
 
-
-
-    async def add_community(
-        self, community: Community
-    ) -> None:
+    async def add_community(self, community: Community) -> None:
 
         # TODO: Fix in the short term.
# we need to do this because postgres insert needs to be a string @@ -1567,7 +1584,6 @@ async def get_entity_map( return entity_map - async def get_graph_status(self, collection_id: UUID) -> dict: # check document_info table for the documents in the collection and return the status of each document kg_extraction_statuses = await self.connection_manager.fetch_query( @@ -1622,7 +1638,6 @@ async def get_graph_status(self, collection_id: UUID) -> dict: "community_count": community_count[0]["count"], } - ####################### ESTIMATION METHODS ####################### async def get_creation_estimate( @@ -1862,7 +1877,6 @@ async def get_deduplication_estimate( detail="An unexpected error occurred while fetching the deduplication estimate.", ) - ####################### GRAPH SEARCH METHODS ####################### async def graph_search( # type: ignore @@ -2180,7 +2194,6 @@ async def _incremental_clustering( num_communities = max([item.cluster for item in community_info]) + 1 return num_communities - async def _compute_leiden_communities( self, graph: Any, @@ -2210,7 +2223,6 @@ async def _compute_leiden_communities( except ImportError as e: raise ImportError("Please install the graspologic package.") from e - ####################### UTILITY METHODS ####################### def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: @@ -2237,8 +2249,6 @@ async def create_vector_index(self): # this needs to be run periodically for every collection. raise NotImplementedError - - async def structured_query(self): raise NotImplementedError @@ -2354,6 +2364,7 @@ async def update_entity_descriptions(self, entities: list[Entity]): ####################### PRIVATE METHODS ########################## + async def _add_objects( objects: list[Any], full_table_name: str, @@ -2365,7 +2376,11 @@ async def _add_objects( Upsert objects into the specified table. """ # Get non-null attributes from the first object - non_null_attrs = {k: v for k, v in objects[0].items() if v is not None and k not in exclude_attributes} + non_null_attrs = { + k: v + for k, v in objects[0].items() + if v is not None and k not in exclude_attributes + } columns = ", ".join(non_null_attrs.keys()) placeholders = ", ".join(f"${i+1}" for i in range(len(non_null_attrs))) @@ -2388,14 +2403,23 @@ async def _add_objects( # Filter out null values for each object params = [ tuple( - (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v) - for k, v in ((k, v) for k, v in obj.items() if v is not None and k not in exclude_attributes) + ( + json.dumps(v) + if isinstance(v, dict) + else str(v) if "embedding" in k else v + ) + for k, v in ( + (k, v) + for k, v in obj.items() + if v is not None and k not in exclude_attributes + ) ) for obj in objects ] return await connection_manager.execute_many(QUERY, params) # type: ignore + async def _update_object( object: dict[str, Any], full_table_name: str, @@ -2405,7 +2429,7 @@ async def _update_object( ) -> asyncpg.Record: """ Update a single object in the specified table. 
-
+
     Args:
         object: Dictionary containing the fields to update
         full_table_name: Name of the table to update
@@ -2415,13 +2439,16 @@ async def _update_object(
     """
     # Get non-null attributes, excluding the ID and any excluded attributes
     non_null_attrs = {
-        k: v for k, v in object.items()
+        k: v
+        for k, v in object.items()
         if v is not None and k != id_column and k not in exclude_attributes
     }
-
+
     # Create SET clause with placeholders
-    set_clause = ", ".join(f"{k} = ${i+1}" for i, k in enumerate(non_null_attrs.keys()))
-
+    set_clause = ", ".join(
+        f"{k} = ${i+1}" for i, k in enumerate(non_null_attrs.keys())
+    )
+
     QUERY = f"""
         UPDATE {full_table_name}
         SET {set_clause}
@@ -2431,15 +2458,22 @@ async def _update_object(
 
     # Prepare parameters: values for SET clause + ID value for WHERE clause
     params = [
-        (json.dumps(v) if isinstance(v, dict) else str(v) if "embedding" in k else v)
+        (
+            json.dumps(v)
+            if isinstance(v, dict)
+            else str(v) if "embedding" in k else v
+        )
         for k, v in non_null_attrs.items()
     ]
     params.append(object[id_column])
 
     ret = await connection_manager.execute_many(QUERY, [tuple(params)])  # type: ignore
     return ret
 
+
 async def _delete_object(
     object_id: UUID,
     full_table_name: str,
@@ -2448,4 +2482,4 @@ async def _delete_object(
     QUERY = f"""
         DELETE FROM {full_table_name} WHERE id = $1
     """
-    return await connection_manager.execute_query(QUERY, [object_id])
\ No newline at end of file
+    return await connection_manager.execute_query(QUERY, [object_id])
diff --git a/py/core/providers/database/kg_tmp/entity.py b/py/core/providers/database/kg_tmp/entity.py
index bbd077427..3e8f5a48b 100644
--- a/py/core/providers/database/kg_tmp/entity.py
+++ b/py/core/providers/database/kg_tmp/entity.py
@@ -1,8 +1,9 @@
 from core.base.providers.database import Handler
 from core.providers.database.kg import PostgresConnectionManager
 
+
 class PostgresEntityHandler(Handler):
-    def __init__(self, project_name: str, connection_manager: PostgresConnectionManager):
+    def __init__(
+        self, project_name: str, connection_manager: PostgresConnectionManager
+    ):
         super().__init__(project_name, connection_manager)
-
-
diff --git a/py/core/providers/database/kg_tmp/main.py b/py/core/providers/database/kg_tmp/main.py
index ef209ab3d..644c6db9b 100644
--- a/py/core/providers/database/kg_tmp/main.py
+++ b/py/core/providers/database/kg_tmp/main.py
@@ -49,4 +49,3 @@
 
 # def __init__(self, project_name: str, connection_manager: PostgresConnectionManager):
 #     super().__init__(project_name, connection_manager)
-
diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py
index c0e795839..2556541e1 100644
--- a/py/shared/abstractions/graph.py
+++ b/py/shared/abstractions/graph.py
@@ -53,12 +53,13 @@ class EntityLevel(str, Enum):
     def __str__(self):
         return self.value
 
+
 class Entity(R2RSerializable):
     """An entity extracted from a document."""
 
     name: Optional[str] = None
     id: Optional[UUID] = None
-    sid: Optional[int] = None #serial ID
+    sid: Optional[int] = None  # serial ID
     level: Optional[EntityLevel] = None
     category: Optional[str] = None
     description: Optional[str] = None
@@ -94,7 +95,7 @@ class Relationship(R2RSerializable):
     """A relationship between two entities.
This is a generic relationship, and can be used to represent any type of relationship between any two entities."""
 
     id: Optional[UUID] = None
-    sid: Optional[int] = None #serial ID
+    sid: Optional[int] = None  # serial ID
 
     subject: Optional[str] = None
     """The source entity name."""
@@ -167,6 +168,7 @@ def from_dict(  # type: ignore
         attributes=d.get(attributes_key, {}),
     )
 
+
 @dataclass
 class CommunityInfo(BaseModel):
     """A protocol for a community in the system."""
@@ -295,6 +297,7 @@ def from_dict(cls, d: dict[str, Any]) -> "Graph":
             attributes=d["attributes"],
         )
 
+
 class KGExtraction(R2RSerializable):
     """An extraction from a document that is part of a knowledge graph."""
 
@@ -303,6 +306,7 @@ class KGExtraction(R2RSerializable):
     entities: list[Entity]
     relationships: list[Relationship]
 
+
 class Graph(R2RSerializable):
     """A request to create a graph."""
 
@@ -313,13 +317,12 @@ class Graph(R2RSerializable):
     collection_ids: list[uuid.UUID] = []
     statistics: dict[str, Any] = {}
     created_at: datetime
-    updated_at: datetime # Implementation is not yet complete
+    updated_at: datetime  # Implementation is not yet complete
     status: str = "pending"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
-
     @classmethod
     def from_dict(cls, d: dict[str, Any]) -> "Graph":
         return Graph(
diff --git a/py/shared/api/models/kg/responses_v3.py b/py/shared/api/models/kg/responses_v3.py
index af6bd2ad1..ee69715d7 100644
--- a/py/shared/api/models/kg/responses_v3.py
+++ b/py/shared/api/models/kg/responses_v3.py
@@ -250,7 +250,6 @@ class Config:
         }
 
 
-
 class KGTunePromptResponse(R2RSerializable):
     """Response containing just the tuned prompt string."""
 
@@ -278,4 +277,4 @@ class Config:
 WrappedKGEntityDeduplicationResponse = ResultsWrapper[
     KGEntityDeduplicationResponse
 ]
-WrappedKGDeletionResponse = ResultsWrapper[KGDeletionResponse]
\ No newline at end of file
+WrappedKGDeletionResponse = ResultsWrapper[KGDeletionResponse]
diff --git a/py/tests/core/providers/ingestion/test_contextual_embedding.py b/py/tests/core/providers/ingestion/test_contextual_embedding.py
index c00355fe5..c63b438ed 100644
--- a/py/tests/core/providers/ingestion/test_contextual_embedding.py
+++ b/py/tests/core/providers/ingestion/test_contextual_embedding.py
@@ -54,9 +54,7 @@ def chunk_ids():
 
 
 @pytest.fixture
-def sample_chunks(
-    sample_document_id, sample_user, collection_ids, chunk_ids
-):
+def sample_chunks(sample_document_id, sample_user, collection_ids, chunk_ids):
     return [
         VectorEntry(
             chunk_id=chunk_ids[0],
diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py
index 6edb0324b..bc129609c 100644
--- a/py/tests/core/providers/kg/test_kg_logic.py
+++ b/py/tests/core/providers/kg/test_kg_logic.py
@@ -254,6 +254,7 @@ async def test_add_relationships(
     assert len(relationships["relationships"]) == 2
     assert relationships["total_entries"] == 2
 
+
 @pytest.mark.asyncio
 async def test_get_entity_map(
     postgres_db_provider,
From 623de531ccdf3c9e7a932b3eeb1f37d79efe76b6 Mon Sep 17 00:00:00 2001
From: Shreyas Pimpalgaonkar
Date: Thu, 14 Nov 2024 17:59:16 -0800
Subject: [PATCH 16/21] up

---
 py/core/base/providers/database.py            | 279 ------------------
 py/core/providers/database/kg_tmp/__init__.py |   0
 .../providers/database/kg_tmp/community.py    |   0
 .../database/kg_tmp/community_info.py         |   0
 py/core/providers/database/kg_tmp/entity.py   |   9 -
 py/core/providers/database/kg_tmp/graph.py    |   0
 py/core/providers/database/kg_tmp/main.py     |  51 ----
 .../providers/database/kg_tmp/relationship.py |   0
 8 files changed, 339 deletions(-)
 delete mode
100644 py/core/providers/database/kg_tmp/__init__.py delete mode 100644 py/core/providers/database/kg_tmp/community.py delete mode 100644 py/core/providers/database/kg_tmp/community_info.py delete mode 100644 py/core/providers/database/kg_tmp/entity.py delete mode 100644 py/core/providers/database/kg_tmp/graph.py delete mode 100644 py/core/providers/database/kg_tmp/main.py delete mode 100644 py/core/providers/database/kg_tmp/relationship.py diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index fa44990e1..cf659979c 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -596,285 +596,6 @@ async def list_chunks( ) -> dict[str, Any]: pass - -# class GraphHandler(Handler): -# """Base handler for Knowledge Graph operations.""" - -# @abstractmethod -# async def create_tables(self) -> None: -# """Create required database tables.""" -# pass - -# ### ENTITIES CRUD OPS ### -# @abstractmethod -# async def create_entities( -# self, -# entities: list[Entity], -# table_name: str, -# conflict_columns: list[str] = [], -# ) -> Any: -# """Add entities to storage.""" -# pass - -# @abstractmethod -# async def get_entities( -# self, -# level: EntityLevel, -# entity_names: Optional[list[str]] = None, -# attributes: Optional[list[str]] = None, -# offset: int = 0, -# limit: int = -1, -# ) -> dict: -# """Get entities from storage.""" -# pass - -# @abstractmethod -# async def update_entity( -# self, -# entity: Entity, -# table_name: str, -# conflict_columns: list[str] = [], -# ) -> Any: -# """Update an entity in storage.""" -# pass - -# @abstractmethod -# async def delete_entity( -# self, -# id: UUID, -# chunk_id: Optional[UUID] = None, -# document_id: Optional[UUID] = None, -# collection_id: Optional[UUID] = None, -# graph_id: Optional[UUID] = None, -# ) -> None: -# """Delete an entity from storage.""" -# pass - -# ### RELATIONSHIPS CRUD OPS ### -# @abstractmethod -# async def add_relationships( -# self, -# relationships: list[Relationship], -# table_name: str = "chunk_relationship", -# ) -> None: -# """Add relationships to storage.""" -# pass - -# @abstractmethod -# async def get_entity_map( -# self, offset: int, limit: int, document_id: UUID -# ) -> dict[str, dict[str, list[dict[str, Any]]]]: -# """Get entity map for a document.""" -# pass - -# @abstractmethod -# async def graph_search( -# self, query: str, **kwargs: Any -# ) -> AsyncGenerator[Any, None]: -# """Perform vector similarity search.""" -# pass - -# # Community management -# @abstractmethod -# async def add_community_info(self, communities: list[Any]) -> None: -# """Add communities to storage.""" -# pass - -# @abstractmethod -# async def get_communities( -# self, -# offset: int, -# limit: int, -# collection_id: Optional[UUID] = None, -# levels: Optional[list[int]] = None, -# community_numbers: Optional[list[int]] = None, -# ) -> dict: -# """Get communities for a collection.""" -# pass - -# @abstractmethod -# async def add_community( -# self, community: Community -# ) -> None: -# """Add a community report.""" -# pass - -# @abstractmethod -# async def get_community_details( -# self, community_number: int, collection_id: UUID -# ) -> Tuple[int, list[Entity], list[Relationship]]: -# """Get detailed information about a community.""" -# pass - -# @abstractmethod -# async def get_community( -# self, collection_id: UUID -# ) -> list[Community]: -# """Get community reports for a collection.""" -# pass - -# @abstractmethod -# async def check_community_exists( -# self, collection_id: UUID, 
offset: int, limit: int -# ) -> list[int]: -# """Check which community reports exist.""" -# pass - -# @abstractmethod -# async def perform_graph_clustering( -# self, -# collection_id: UUID, -# leiden_params: dict[str, Any], -# ) -> int: -# """Perform graph clustering.""" -# pass - -# # Graph operations -# @abstractmethod -# async def delete_graph_for_collection( -# self, collection_id: UUID, cascade: bool = False -# ) -> None: -# """Delete graph data for a collection.""" -# pass - -# @abstractmethod -# async def delete_node_via_document_id( -# self, document_id: UUID, collection_id: UUID -# ) -> None: -# """Delete a node using document ID.""" -# pass - -# # Entity and Relationship management -# @abstractmethod -# async def get_entities( -# self, -# offset: int, -# limit: int, -# collection_id: Optional[UUID] = None, -# entity_ids: Optional[list[str]] = None, -# entity_names: Optional[list[str]] = None, -# entity_table_name: str = "document_entity", -# extra_columns: Optional[list[str]] = None, -# ) -> dict: -# """Get entities from storage.""" -# pass - -# @abstractmethod -# async def create_entities_v3(self, entities: list[Entity]) -> None: -# """Create entities in storage.""" -# pass - - -# @abstractmethod -# async def get_relationships( -# self, -# offset: int, -# limit: int, -# collection_id: Optional[UUID] = None, -# entity_names: Optional[list[str]] = None, -# relationship_ids: Optional[list[str]] = None, -# ) -> dict: -# """Get relationships from storage.""" -# pass - -# @abstractmethod -# async def get_entity_count( -# self, -# collection_id: Optional[UUID] = None, -# document_id: Optional[UUID] = None, -# distinct: bool = False, -# entity_table_name: str = "document_entity", -# ) -> int: -# """Get entity count.""" -# pass - -# @abstractmethod -# async def get_relationship_count( -# self, -# collection_id: Optional[UUID] = None, -# document_id: Optional[UUID] = None, -# ) -> int: -# """Get relationship count.""" -# pass - -# # Cost estimation methods -# @abstractmethod -# async def get_creation_estimate( -# self, collection_id: UUID, kg_creation_settings: KGCreationSettings -# ): -# """Get creation cost estimate.""" -# pass - -# @abstractmethod -# async def get_enrichment_estimate( -# self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings -# ): -# """Get enrichment cost estimate.""" -# pass - -# @abstractmethod -# async def get_deduplication_estimate( -# self, -# collection_id: UUID, -# kg_deduplication_settings: KGEntityDeduplicationSettings, -# ): -# """Get deduplication cost estimate.""" -# pass - -# # Other operations -# @abstractmethod -# async def create_vector_index(self) -> None: -# """Create vector index.""" -# raise NotImplementedError - -# @abstractmethod -# async def delete_relationships(self, relationship_ids: list[int]) -> None: -# """Delete relationships.""" -# raise NotImplementedError - -# @abstractmethod -# async def get_schema(self) -> Any: -# """Get schema.""" -# raise NotImplementedError - -# @abstractmethod -# async def structured_query(self) -> Any: -# """Perform structured query.""" -# raise NotImplementedError - -# @abstractmethod -# async def update_extraction_prompt(self) -> None: -# """Update extraction prompt.""" -# raise NotImplementedError - -# @abstractmethod -# async def update_kg_search_prompt(self) -> None: -# """Update KG search prompt.""" -# raise NotImplementedError - -# @abstractmethod -# async def upsert_relationships(self) -> None: -# """Upsert relationships.""" -# raise NotImplementedError - -# @abstractmethod -# async def 
get_existing_entity_chunk_ids( -# self, document_id: UUID -# ) -> list[str]: -# """Get existing entity extraction IDs.""" -# raise NotImplementedError - -# @abstractmethod -# async def get_all_relationships( -# self, collection_id: UUID -# ) -> list[Relationship]: -# raise NotImplementedError - -# @abstractmethod -# async def update_entity_descriptions(self, entities: list[Entity]): -# raise NotImplementedError - - class EntityHandler(Handler): @abstractmethod diff --git a/py/core/providers/database/kg_tmp/__init__.py b/py/core/providers/database/kg_tmp/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/py/core/providers/database/kg_tmp/community.py b/py/core/providers/database/kg_tmp/community.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/py/core/providers/database/kg_tmp/community_info.py b/py/core/providers/database/kg_tmp/community_info.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/py/core/providers/database/kg_tmp/entity.py b/py/core/providers/database/kg_tmp/entity.py deleted file mode 100644 index 3e8f5a48b..000000000 --- a/py/core/providers/database/kg_tmp/entity.py +++ /dev/null @@ -1,9 +0,0 @@ -from core.base.providers.database import Handler -from core.providers.database.kg import PostgresConnectionManager - - -class PostgresEntityHandler(Handler): - def __init__( - self, project_name: str, connection_manager: PostgresConnectionManager - ): - super().__init__(project_name, connection_manager) diff --git a/py/core/providers/database/kg_tmp/graph.py b/py/core/providers/database/kg_tmp/graph.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/py/core/providers/database/kg_tmp/main.py b/py/core/providers/database/kg_tmp/main.py deleted file mode 100644 index 644c6db9b..000000000 --- a/py/core/providers/database/kg_tmp/main.py +++ /dev/null @@ -1,51 +0,0 @@ -# import json -# import logging -# import time -# from typing import Any, AsyncGenerator, Optional, Tuple -# from uuid import UUID -# from fastapi import HTTPException - -# import asyncpg -# from asyncpg.exceptions import PostgresError, UndefinedTableError - -# from core.base import ( -# Community, -# Entity, -# KGExtraction, -# KGExtractionStatus, -# KGHandler, -# R2RException, -# Relationship, -# ) -# from core.base.abstractions import ( -# CommunityInfo, -# EntityLevel, -# KGCreationSettings, -# KGEnrichmentSettings, -# KGEnrichmentStatus, -# KGEntityDeduplicationSettings, -# VectorQuantizationType, -# ) - -# from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens - -# from .base import PostgresConnectionManager -# from .collection import PostgresCollectionHandler -# from .entity import PostgresEntityHandler -# from .relationship import PostgresRelationshipHandler -# from .community import PostgresCommunityHandler -# from .graph import PostgresGraphHandler - -# logger = logging.getLogger() - - -# class PostgresKGHandler(KGHandler): -# """Handler for Knowledge Graph METHODS in PostgreSQL.""" - -# entity_handler: PostgresEntityHandler -# relationship_handler: PostgresRelationshipHandler -# community_handler: PostgresCommunityHandler -# graph_handler: PostgresGraphHandler - -# def __init__(self, project_name: str, connection_manager: PostgresConnectionManager): -# super().__init__(project_name, connection_manager) diff --git a/py/core/providers/database/kg_tmp/relationship.py b/py/core/providers/database/kg_tmp/relationship.py deleted file mode 100644 index e69de29bb..000000000 From 3289e3d321cb4895c4bfa114e3c8104c53e22010 
Mon Sep 17 00:00:00 2001
From: Shreyas Pimpalgaonkar
Date: Thu, 14 Nov 2024 21:13:06 -0800
Subject: [PATCH 17/21] v2

---
 py/core/base/providers/database.py           |   3 +-
 py/core/base/utils/__init__.py               |   2 +
 py/core/main/api/v3/graph_router.py          |   5 +-
 py/core/main/services/kg_service.py          | 192 +---------
 py/core/pipes/kg/relationships_extraction.py |   2 +-
 py/core/providers/database/kg.py             | 383 +++++++------------
 py/shared/abstractions/kg.py                 |   1 +
 py/shared/utils/__init__.py                  |   2 +
 py/shared/utils/base_utils.py                |   7 +
 9 files changed, 155 insertions(+), 442 deletions(-)

diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py
index cf659979c..e53383443 100644
--- a/py/core/base/providers/database.py
+++ b/py/core/base/providers/database.py
@@ -1545,7 +1545,8 @@ async def get_creation_estimate(
     ):
         """Forward to KG handler get_creation_estimate method."""
         return await self.graph_handler.get_creation_estimate(
-            collection_id, kg_creation_settings
+            collection_id=collection_id,
+            kg_creation_settings=kg_creation_settings,
         )
 
     async def get_enrichment_estimate(
diff --git a/py/core/base/utils/__init__.py b/py/core/base/utils/__init__.py
index aead42b43..2544e0f32 100644
--- a/py/core/base/utils/__init__.py
+++ b/py/core/base/utils/__init__.py
@@ -18,6 +18,7 @@
     run_pipeline,
     to_async_generator,
     validate_uuid,
+    _get_str_estimation_output,
 )
 
 __all__ = [
@@ -40,4 +41,5 @@
     "llm_cost_per_million_tokens",
     "validate_uuid",
     "_decorate_vector_type",
+    "_get_str_estimation_output",
 ]
diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py
index e24b39318..399e5a491 100644
--- a/py/core/main/api/v3/graph_router.py
+++ b/py/core/main/api/v3/graph_router.py
@@ -79,7 +79,6 @@ async def create_graph(
             id: UUID = Path(..., description="The ID of the document to create a graph for."),
             run_type: KGRunType = Path(
-                default=KGRunType.CREATE,
                 description="Run type for the graph creation process.",
             ),
             settings: Optional[KGCreationSettings] = Body(
@@ -1118,9 +1117,7 @@ async def create_entities(
         async def create_empty_graph(
             auth_user=Depends(self.providers.auth.auth_wrapper),
         ):
-
-
-
+            pass
 
         @self.router.get(
             "/graphs/{collection_id}",
diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py
index 13bb7de2f..25c349058 100644
--- a/py/core/main/services/kg_service.py
+++ b/py/core/main/services/kg_service.py
@@ -485,7 +485,8 @@ async def get_creation_estimate(
         **kwargs,
     ):
         return await self.providers.database.get_creation_estimate(
-            collection_id, kg_creation_settings
+            collection_id=collection_id,
+            kg_creation_settings=kg_creation_settings,
        )
 
     @telemetry_event("get_enrichment_estimate")
@@ -500,195 +501,6 @@ async def get_enrichment_estimate(
             collection_id, kg_enrichment_settings
         )
 
-    @telemetry_event("list_entities")
-    async def list_entities(
-        self,
-        level: EntityLevel,
-        id: Optional[UUID] = None,
-        entity_names: Optional[list[str]] = None,
-        entity_categories: Optional[list[str]] = None,
-        attributes: Optional[list[str]] = None,
-        offset: Optional[int] = None,
-        limit: Optional[int] = None,
-        **kwargs,
-    ):
-        return await self.providers.database.get_entities_v3(
-            level=level,
-            id=id,
-            entity_names=entity_names,
-            entity_categories=entity_categories,
-            attributes=attributes,
-            offset=offset,
-            limit=limit,
-        )
-
-    @telemetry_event("get_entities")
-    async def get_entities(
-        self,
-        offset: int,
-        limit: int,
-        collection_id: Optional[UUID] = None,
-        entity_ids: Optional[list[str]] = None,
-
entity_table_name: str = "document_entity", - **kwargs, - ): - return await self.providers.database.get_entities( - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, - offset=offset, - limit=limit, - ) - - @telemetry_event("get_entities") - async def get_entities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_ids: Optional[list[str]] = None, - entity_table_name: str = "document_entity", - ): - return await self.providers.database.get_entities( - offset=offset, - limit=limit, - collection_id=collection_id, - entity_ids=entity_ids, - entity_table_name=entity_table_name, - ) - - @telemetry_event("get_relationships") - async def get_relationships( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - **kwargs, - ): - return await self.providers.database.get_relationships( - offset=offset, - limit=limit, - collection_id=collection_id, - entity_names=entity_names, - relationship_ids=relationship_ids, - ) - - @telemetry_event("list_relationships") - async def list_relationships( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - **kwargs, - ): - return await self.providers.database.get_relationships( - offset=offset, - limit=limit, - collection_id=collection_id, - entity_names=entity_names, - relationship_ids=relationship_ids, - ) - - ##### Relationships ##### - - @telemetry_event("list_relationships") - async def list_relationships( - self, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.get_relationships( - collection_id=collection_id, - entity_names=entity_names, - relationship_ids=relationship_ids, - offset=offset or 0, - limit=limit or -1, - ) - - @telemetry_event("list_relationships") - async def list_relationships_v3( - self, - level: EntityLevel, - id: UUID, - offset: int, - limit: int, - entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - ): - return await self.providers.database.graph_handler.relationships.get( - level=level, - id=id, - entity_names=entity_names, - relationship_types=relationship_types, - attributes=attributes, - offset=offset, - limit=limit, - ) - - ##### Communities ##### - @telemetry_event("get_communities") - async def get_communities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_numbers: Optional[list[int]] = None, - **kwargs, - ): - return await self.providers.database.get_communities( - offset=offset, - limit=limit, - collection_id=collection_id, - levels=levels, - community_numbers=community_numbers, - ) - - @telemetry_event("list_communities") - async def list_communities( - self, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_numbers: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.get_communities( - offset=offset, - limit=limit, - collection_id=collection_id, - levels=levels, - community_numbers=community_numbers, - ) - - 
@telemetry_event("list_communities") - async def list_communities( - self, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_numbers: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - **kwargs, - ): - return await self.providers.database.get_communities( - collection_id=collection_id, - levels=levels, - community_numbers=community_numbers, - offset=offset or 0, - limit=limit or -1, - ) - @telemetry_event("get_deduplication_estimate") async def get_deduplication_estimate( self, diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index cab5c78d7..1b53726cb 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -267,7 +267,7 @@ async def _run_logic( # type: ignore if filter_out_existing_chunks: existing_chunk_ids = ( - await self.database_provider.get_existing_entity_chunk_ids( + await self.database_provider.graph_handler.relationships_handler.get( document_id=document_id ) ) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index c02b50c98..fed673ceb 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -36,7 +36,7 @@ VectorQuantizationType, ) -from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens +from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens, _get_str_estimation_output from .base import PostgresConnectionManager from .collection import PostgresCollectionHandler @@ -424,7 +424,7 @@ async def create_tables(self) -> None: query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, + sid SERIAL, node TEXT NOT NULL, cluster INT NOT NULL, parent_cluster INT, @@ -440,7 +440,7 @@ async def create_tables(self) -> None: query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, + sid SERIAL, community_number INT NOT NULL, collection_id UUID NOT NULL, level INT NOT NULL, @@ -609,6 +609,139 @@ async def remove_collection( ) + ###### ESTIMATION METHODS ###### + + async def get_creation_estimate( + self, + kg_creation_settings: KGCreationSettings, + document_id: Optional[UUID] = None, + collection_id: Optional[UUID] = None, + ): + """Get the estimated cost and time for creating a KG.""" + + if bool(document_id) ^ bool(collection_id) is False: + raise ValueError("Exactly one of document_id or collection_id must be provided.") + + # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
+ + document_ids = [document_id] if document_id else [ + doc.id for doc in (await self.collection_handler.documents_in_collection(collection_id, offset=0, limit=-1))["results"] # type: ignore + ] + + chunk_counts = await self.connection_manager.fetch_query( + f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} " + f"WHERE document_id = ANY($1) GROUP BY document_id", [document_ids] + ) + + total_chunks = sum(doc["chunk_count"] for doc in chunk_counts) // kg_creation_settings.extraction_merge_count + estimated_entities = (total_chunks * 10, total_chunks * 20) + estimated_relationships = (int(estimated_entities[0] * 1.25), int(estimated_entities[1] * 1.5)) + estimated_llm_calls = (total_chunks * 2 + estimated_entities[0], total_chunks * 2 + estimated_entities[1]) + total_in_out_tokens = tuple(2000 * calls // 1000000 for calls in estimated_llm_calls) + cost_per_million = llm_cost_per_million_tokens(kg_creation_settings.generation_config.model) + estimated_cost = tuple(tokens * cost_per_million for tokens in total_in_out_tokens) + total_time_in_minutes = tuple(tokens * 10 / 60 for tokens in total_in_out_tokens) + + return { + "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "document_count": len(document_ids), + "number_of_jobs_created": len(document_ids) + 1, + "total_chunks": total_chunks, + "estimated_entities": _get_str_estimation_output(estimated_entities), + "estimated_relationships": _get_str_estimation_output(estimated_relationships), + "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(total_in_out_tokens), + "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), + "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + _get_str_estimation_output(total_time_in_minutes), + } + + async def get_enrichment_estimate( + self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings + ): + """Get the estimated cost and time for enriching a KG.""" + + document_ids = [ + doc.id + for doc in ( + await self.collection_handler.documents_in_collection(collection_id) # type: ignore + )["results"] + ] + + # Get entity and relationship counts + entity_count = (await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1);", + [document_ids] + ))[0]["count"] + + if not entity_count: + raise ValueError("No entities found in the graph. Please run `create-graph` first.") + + relationship_count = (await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1);", + [document_ids] + ))[0]["count"] + + # Calculate estimates + estimated_llm_calls = (entity_count // 10, entity_count // 5) + tokens_in_millions = tuple(2000 * calls / 1000000 for calls in estimated_llm_calls) + cost_per_million = llm_cost_per_million_tokens(kg_enrichment_settings.generation_config.model) + estimated_cost = tuple(tokens * cost_per_million for tokens in tokens_in_millions) + estimated_time = tuple(tokens * 10 / 60 for tokens in tokens_in_millions) + + return { + "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. 
To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', + "total_entities": entity_count, + "total_relationships": relationship_count, + "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(tokens_in_millions), + "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), + "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + + _get_str_estimation_output(estimated_time), + } + + async def get_deduplication_estimate( + self, + collection_id: UUID, + kg_deduplication_settings: KGEntityDeduplicationSettings, + ): + """Get the estimated cost and time for deduplicating entities in a KG.""" + try: + query = f""" + SELECT name, count(name) + FROM {self._get_table_name("document_entity")} + WHERE document_id = ANY( + SELECT document_id FROM {self._get_table_name("document_info")} + WHERE $1 = ANY(collection_ids) + ) + GROUP BY name + HAVING count(name) >= 5 + """ + entities = await self.connection_manager.fetch_query(query, [collection_id]) + num_entities = len(entities) + + estimated_llm_calls = (num_entities, num_entities) + tokens_in_millions = ( + estimated_llm_calls[0] * 1000 / 1000000, + estimated_llm_calls[1] * 5000 / 1000000, + ) + cost_per_million = llm_cost_per_million_tokens(kg_deduplication_settings.generation_config.model) + estimated_cost = (tokens_in_millions[0] * cost_per_million, tokens_in_millions[1] * cost_per_million) + estimated_time = (tokens_in_millions[0] * 10 / 60, tokens_in_millions[1] * 10 / 60) + + return { + "message": 'Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.', + "num_entities": num_entities, + "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(tokens_in_millions), + "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), + "estimated_total_time_in_minutes": _get_str_estimation_output(estimated_time), + } + except UndefinedTableError: + raise R2RException("Entity embedding table not found. Please run `create-graph` first.", 404) + except Exception as e: + logger.error(f"Error in get_deduplication_estimate: {str(e)}") + raise HTTPException(500, "Error fetching deduplication estimate.") + class PostgresGraphHandler_v1(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL.""" @@ -719,7 +852,7 @@ async def create_tables(self): query = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, + sid SERIAL, node TEXT NOT NULL, cluster INT NOT NULL, parent_cluster INT, @@ -1640,243 +1773,6 @@ async def get_graph_status(self, collection_id: UUID) -> dict: ####################### ESTIMATION METHODS ####################### - async def get_creation_estimate( - self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ): - - # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
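[Editor's note] For a concrete sense of the ranges above (the arithmetic is identical in the new `get_creation_estimate` and in the legacy block being deleted below), assume a collection of 400 chunks with `extraction_merge_count=4` — hypothetical numbers, not from the diff:

    total_chunks = 400 // 4                                        # 100 merged LLM inputs
    estimated_entities = (100 * 10, 100 * 20)                      # (1000, 2000)
    estimated_relationships = (int(1000 * 1.25), int(2000 * 1.5))  # (1250, 3000)
    estimated_llm_calls = (100 * 2 + 1000, 100 * 2 + 2000)         # (1200, 2200)
    total_in_out_tokens = (2000 * 1200 // 10**6, 2000 * 2200 // 10**6)  # (2, 4) million
    # estimated cost = tokens in millions * llm_cost_per_million_tokens(model)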
- document_ids = [ - doc.id - for doc in ( - await self.collection_handler.documents_in_collection(collection_id) # type: ignore - )["results"] - ] - - query = f""" - SELECT document_id, COUNT(*) as chunk_count - FROM {self._get_table_name("vectors")} - WHERE document_id = ANY($1) - GROUP BY document_id - """ - - chunk_counts = await self.connection_manager.fetch_query( - query, [document_ids] - ) - - total_chunks = ( - sum(doc["chunk_count"] for doc in chunk_counts) - // kg_creation_settings.extraction_merge_count - ) # 4 chunks per llm - estimated_entities = ( - total_chunks * 10, - total_chunks * 20, - ) # 25 entities per 4 chunks - estimated_relationships = ( - int(estimated_entities[0] * 1.25), - int(estimated_entities[1] * 1.5), - ) # Assuming 1.25 relationships per entity on average - - estimated_llm_calls = ( - total_chunks * 2 + estimated_entities[0], - total_chunks * 2 + estimated_entities[1], - ) - - total_in_out_tokens = ( - 2000 * estimated_llm_calls[0] // 1000000, - 2000 * estimated_llm_calls[1] // 1000000, - ) # in millions - - estimated_cost = ( - total_in_out_tokens[0] - * llm_cost_per_million_tokens( - kg_creation_settings.generation_config.model - ), - total_in_out_tokens[1] - * llm_cost_per_million_tokens( - kg_creation_settings.generation_config.model - ), - ) - - total_time_in_minutes = ( - total_in_out_tokens[0] * 10 / 60, - total_in_out_tokens[1] * 10 / 60, - ) # 10 minutes per million tokens - - return { - "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', - "document_count": len(document_ids), - "number_of_jobs_created": len(document_ids) + 1, - "total_chunks": total_chunks, - "estimated_entities": self._get_str_estimation_output( - estimated_entities - ), - "estimated_relationships": self._get_str_estimation_output( - estimated_relationships - ), - "estimated_llm_calls": self._get_str_estimation_output( - estimated_llm_calls - ), - "estimated_total_in_out_tokens_in_millions": self._get_str_estimation_output( - total_in_out_tokens - ), - "estimated_cost_in_usd": self._get_str_estimation_output( - estimated_cost - ), - "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " - + self._get_str_estimation_output(total_time_in_minutes), - } - - async def get_enrichment_estimate( - self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ): - - document_ids = [ - doc.id - for doc in ( - await self.collection_handler.documents_in_collection(collection_id) # type: ignore - )["results"] - ] - - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("document_entity")} WHERE document_id = ANY($1); - """ - entity_count = ( - await self.connection_manager.fetch_query(QUERY, [document_ids]) - )[0]["count"] - - if not entity_count: - raise ValueError( - "No entities found in the graph. Please run `create-graph` first." 
- ) - - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1); - """ - relationship_count = ( - await self.connection_manager.fetch_query(QUERY, [document_ids]) - )[0]["count"] - - estimated_llm_calls = (entity_count // 10, entity_count // 5) - estimated_total_in_out_tokens_in_millions = ( - 2000 * estimated_llm_calls[0] / 1000000, - 2000 * estimated_llm_calls[1] / 1000000, - ) - cost_per_million_tokens = llm_cost_per_million_tokens( - kg_enrichment_settings.generation_config.model - ) - estimated_cost = ( - estimated_total_in_out_tokens_in_millions[0] - * cost_per_million_tokens, - estimated_total_in_out_tokens_in_millions[1] - * cost_per_million_tokens, - ) - - estimated_total_time = ( - estimated_total_in_out_tokens_in_millions[0] * 10 / 60, - estimated_total_in_out_tokens_in_millions[1] * 10 / 60, - ) - - return { - "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', - "total_entities": entity_count, - "total_relationships": relationship_count, - "estimated_llm_calls": self._get_str_estimation_output( - estimated_llm_calls - ), - "estimated_total_in_out_tokens_in_millions": self._get_str_estimation_output( - estimated_total_in_out_tokens_in_millions - ), - "estimated_cost_in_usd": self._get_str_estimation_output( - estimated_cost - ), - "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " - + self._get_str_estimation_output(estimated_total_time), - } - - async def get_deduplication_estimate( - self, - collection_id: UUID, - kg_deduplication_settings: KGEntityDeduplicationSettings, - ): - try: - # number of documents in collection - query = f""" - SELECT name, count(name) - FROM {self._get_table_name("document_entity")} - WHERE document_id = ANY( - SELECT document_id FROM {self._get_table_name("document_info")} - WHERE $1 = ANY(collection_ids) - ) - GROUP BY name - HAVING count(name) >= 5 - """ - entities = await self.connection_manager.fetch_query( - query, [collection_id] - ) - num_entities = len(entities) - - estimated_llm_calls = (num_entities, num_entities) - estimated_total_in_out_tokens_in_millions = ( - estimated_llm_calls[0] * 1000 / 1000000, - estimated_llm_calls[1] * 5000 / 1000000, - ) - estimated_cost_in_usd = ( - estimated_total_in_out_tokens_in_millions[0] - * llm_cost_per_million_tokens( - kg_deduplication_settings.generation_config.model - ), - estimated_total_in_out_tokens_in_millions[1] - * llm_cost_per_million_tokens( - kg_deduplication_settings.generation_config.model - ), - ) - - estimated_total_time_in_minutes = ( - estimated_total_in_out_tokens_in_millions[0] * 10 / 60, - estimated_total_in_out_tokens_in_millions[1] * 10 / 60, - ) - - return KGDeduplicationEstimationResponse( - message='Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. 
To run the Deduplication process, run `deduplicate-entities` with `--run` in the cli, or `run_type="run"` in the client.', - num_entities=num_entities, - estimated_llm_calls=self._get_str_estimation_output( - estimated_llm_calls - ), - estimated_total_in_out_tokens_in_millions=self._get_str_estimation_output( - estimated_total_in_out_tokens_in_millions - ), - estimated_cost_in_usd=self._get_str_estimation_output( - estimated_cost_in_usd - ), - estimated_total_time_in_minutes=self._get_str_estimation_output( - estimated_total_time_in_minutes - ), - ) - except UndefinedTableError as e: - logger.error( - f"Entity embedding table not found. Please run `create-graph` first. {str(e)}" - ) - raise R2RException( - message="Entity embedding table not found. Please run `create-graph` first.", - status_code=404, - ) - except PostgresError as e: - logger.error( - f"Database error in get_deduplication_estimate: {str(e)}" - ) - raise HTTPException( - status_code=500, - detail="An error occurred while fetching the deduplication estimate.", - ) - except Exception as e: - logger.error( - f"Unexpected error in get_deduplication_estimate: {str(e)}" - ) - raise HTTPException( - status_code=500, - detail="An unexpected error occurred while fetching the deduplication estimate.", - ) - ####################### GRAPH SEARCH METHODS ####################### async def graph_search( # type: ignore @@ -2225,11 +2121,6 @@ async def _compute_leiden_communities( ####################### UTILITY METHODS ####################### - def _get_str_estimation_output(self, x: tuple[Any, Any]) -> str: - if isinstance(x[0], int) and isinstance(x[1], int): - return " - ".join(map(str, x)) - else: - return " - ".join(f"{round(a, 2)}" for a in x) async def get_existing_entity_chunk_ids( self, document_id: UUID diff --git a/py/shared/abstractions/kg.py b/py/shared/abstractions/kg.py index 7c9bbcef0..28037f419 100644 --- a/py/shared/abstractions/kg.py +++ b/py/shared/abstractions/kg.py @@ -11,6 +11,7 @@ class KGRunType(str, Enum): ESTIMATE = "estimate" CREATE = "create" + RUN = "run" # deprecated def __str__(self): return self.value diff --git a/py/shared/utils/__init__.py b/py/shared/utils/__init__.py index eabefcc74..f5ff8b280 100644 --- a/py/shared/utils/__init__.py +++ b/py/shared/utils/__init__.py @@ -16,6 +16,7 @@ run_pipeline, to_async_generator, validate_uuid, + _get_str_estimation_output, ) from .splitter.text import RecursiveCharacterTextSplitter, TextSplitter @@ -43,4 +44,5 @@ "TextSplitter", # Vector utils "_decorate_vector_type", + "_get_str_estimation_output", ] diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py index 5d584a79c..beaef95b6 100644 --- a/py/shared/utils/base_utils.py +++ b/py/shared/utils/base_utils.py @@ -254,3 +254,10 @@ def _decorate_vector_type( quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, ) -> str: return f"{quantization_type.db_type}{input_str}" + + +def _get_str_estimation_output(x: tuple[Any, Any]) -> str: + if isinstance(x[0], int) and isinstance(x[1], int): + return " - ".join(map(str, x)) + else: + return " - ".join(f"{round(a, 2)}" for a in x) From ad1eadae763b3b90c0272b6e12cdfbccd5d15cfe Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 15 Nov 2024 10:29:26 -0800 Subject: [PATCH 18/21] up --- py/core/base/providers/database.py | 5 +- py/core/main/api/v2/kg_router.py | 12 +- py/core/main/api/v3/graph_router.py | 156 +++-- .../main/orchestration/simple/kg_workflow.py | 1 + py/core/main/services/kg_service.py | 63 +- 
py/core/pipes/kg/community_summary.py | 2 +- py/core/pipes/kg/entity_description.py | 16 +- py/core/pipes/kg/relationships_extraction.py | 2 +- py/core/providers/database/kg.py | 538 +++++------------- py/shared/abstractions/graph.py | 4 +- py/shared/api/models/kg/responses.py | 29 +- 11 files changed, 355 insertions(+), 473 deletions(-) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index e53383443..92b382eeb 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -596,6 +596,7 @@ async def list_chunks( ) -> dict[str, Any]: pass + class EntityHandler(Handler): @abstractmethod @@ -1545,8 +1546,8 @@ async def get_creation_estimate( ): """Forward to KG handler get_creation_estimate method.""" return await self.graph_handler.get_creation_estimate( - collection_id = collection_id, - kg_creation_settings = kg_creation_settings + collection_id=collection_id, + kg_creation_settings=kg_creation_settings, ) async def get_enrichment_estimate( diff --git a/py/core/main/api/v2/kg_router.py b/py/core/main/api/v2/kg_router.py index 7b00a1a2f..b1d4bcc9a 100644 --- a/py/core/main/api/v2/kg_router.py +++ b/py/core/main/api/v2/kg_router.py @@ -115,14 +115,14 @@ async def create_graph( if not auth_user.is_superuser: logger.warning("Implement permission checks here.") - logger.info(f"Running create-graph on collection {collection_id}") - # If no collection ID is provided, use the default user collection if not collection_id: collection_id = generate_default_user_collection_id( auth_user.id ) + logger.info(f"Running create-graph on collection {collection_id}") + # If no run type is provided, default to estimate if not run_type: run_type = KGRunType.ESTIMATE @@ -281,7 +281,7 @@ async def get_entities( else: entity_table_name = "collection_entity" - return await self.service.get_entities( + entities = await self.service.get_entities( collection_id=collection_id, entity_ids=entity_ids, entity_table_name=entity_table_name, @@ -289,6 +289,8 @@ async def get_entities( limit=limit, ) + return entities + @self.router.get("/triples") @self.base_endpoint async def get_relationships( @@ -299,7 +301,7 @@ async def get_relationships( entity_names: Optional[list[str]] = Query( None, description="Entity names to filter by." ), - relationship_ids: Optional[list[str]] = Query( + triple_ids: Optional[list[str]] = Query( None, description="Relationship IDs to filter by." ), offset: int = Query(0, ge=0, description="Offset for pagination."), @@ -326,7 +328,7 @@ async def get_relationships( limit=limit, collection_id=collection_id, entity_names=entity_names, - relationship_ids=relationship_ids, + relationship_ids=triple_ids, ) @self.router.get("/communities") diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 399e5a491..496bde99d 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -77,7 +77,10 @@ def _setup_routes(self): ) @self.base_endpoint async def create_graph( - id: UUID = Path(..., description="The ID of the document to create a graph for."), + id: UUID = Path( + ..., + description="The ID of the document to create a graph for.", + ), run_type: KGRunType = Path( description="Run type for the graph creation process.", ), @@ -89,10 +92,10 @@ async def create_graph( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedKGCreationResponse: """ - Creates a new knowledge graph by extracting entities and relationships from a document. 
- The graph creation process involves: - 1. Parsing documents into semantic chunks - 2. Extracting entities and relationships using LLMs or NER + Creates a new knowledge graph by extracting entities and relationships from a document. + The graph creation process involves: + 1. Parsing documents into semantic chunks + 2. Extracting entities and relationships using LLMs or NER """ settings = settings.dict() if settings else None @@ -157,7 +160,7 @@ async def create_graph( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.list_entities(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.chunks.graphs.entities.list(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) """ ), }, @@ -178,7 +181,7 @@ async def create_graph( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.list_entities(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.documents.graphs.entities.list(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) """ ), }, @@ -199,7 +202,7 @@ async def create_graph( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.graphs.list_entities(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) + result = client.collections.graphs.entities.list(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", offset=0, limit=100) """ ), }, @@ -278,7 +281,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.create_entities_v3(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.chunks.graphs.entities.create(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -299,7 +302,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.create_entities_v3(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.documents.graphs.entities.create(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -320,7 +323,7 @@ async def list_entities( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
- result = client.collections.graphs.create_entities_v3(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) + result = client.collections.graphs.entities.create(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entities=[entity1, entity2]) """ ), }, @@ -367,7 +370,8 @@ async def create_entities_v3( for entity in entities: if entity.chunk_ids and id not in entity.chunk_ids: raise R2RException( - "Entity extraction IDs must include the chunk ID or should be empty.", 400 + "Entity extraction IDs must include the chunk ID or should be empty.", + 400, ) elif level == EntityLevel.DOCUMENT: @@ -375,7 +379,8 @@ async def create_entities_v3( if entity.document_id: if entity.document_id != id: raise R2RException( - "Entity document IDs must match the document ID or should be empty.", 400 + "Entity document IDs must match the document ID or should be empty.", + 400, ) else: entity.document_id = id @@ -385,7 +390,8 @@ async def create_entities_v3( if entity.collection_id: if entity.collection_id != id: raise R2RException( - "Entity collection IDs must match the collection ID or should be empty.", 400 + "Entity collection IDs must match the collection ID or should be empty.", + 400, ) else: entity.collection_id = id @@ -408,7 +414,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.update_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.chunks.graphs.entities.update(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -429,7 +435,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.update_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.documents.graphs.entities.update(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -450,7 +456,7 @@ async def create_entities_v3( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.graphs.update_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) + result = client.collections.graphs.entities.update(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000", entity=entity) """ ), }, @@ -498,7 +504,8 @@ async def update_entity( else: if entity.id != entity_id: raise R2RException( - "Entity ID must match the entity ID or should be empty.", 400 + "Entity ID must match the entity ID or should be empty.", + 400, ) return await self.services["kg"].update_entity_v3( @@ -519,7 +526,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.delete_entity(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.chunks.graphs.entities.delete(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -540,7 +547,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
- result = client.documents.graphs.delete_entity(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.documents.graphs.entities.delete(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -561,7 +568,7 @@ async def update_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.graphs.delete_entity(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") + result = client.collections.graphs.entities.delete(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", entity_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -606,7 +613,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.list_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.chunks.graphs.relationships.list(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -627,7 +634,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.list_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.documents.graphs.relationships.list(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -648,7 +655,7 @@ async def delete_entity( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.graphs.list_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") + result = client.collections.graphs.relationships.list(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1") """ ), }, @@ -692,7 +699,9 @@ async def list_relationships( "Only superusers can access this endpoint.", 403 ) - relationships, count = await self.services["kg"].list_relationships_v3( + relationships, count = await self.services[ + "kg" + ].list_relationships_v3( level=self._get_path_level(request), id=id, entity_names=entity_names, @@ -720,7 +729,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.create_relationships(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.chunks.graphs.relationships.create(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -741,7 +750,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.create_relationships(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.documents.graphs.relationships.create(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -762,7 +771,7 @@ async def list_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
- result = client.collections.graphs.create_relationships(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) + result = client.collections.graphs.relationships.create(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationships=[relationship1, relationship2]) """ ), }, @@ -808,7 +817,7 @@ async def create_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.update_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + result = client.chunks.graphs.relationships.update(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) """ ), }, @@ -829,7 +838,7 @@ async def create_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.update_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + result = client.documents.relationships.update(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) """ ), }, @@ -850,7 +859,7 @@ async def create_relationships( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.collections.graphs.update_relationship(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) + result = client.collections.graphs.relationships.update(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000", relationship=relationship) """ ), }, @@ -883,7 +892,9 @@ async def update_relationship( relationship.id = relationship_id else: if relationship.id != relationship_id: - raise ValueError("Relationship ID in path and body do not match") + raise ValueError( + "Relationship ID in path and body do not match" + ) return await self.services["kg"].update_relationship_v3( relationship=relationship, @@ -903,7 +914,7 @@ async def update_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.chunks.graphs.delete_relationship(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + result = client.chunks.graphs.relationships.delete(chunk_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -924,7 +935,7 @@ async def update_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.documents.graphs.delete_relationship(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + result = client.documents.graphs.relationships.delete(document_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -945,7 +956,7 @@ async def update_relationship( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) 
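# (editor's note) The docstring edits in this hunk apply one mechanical rename:
# verb-style helpers (create_entities_v3, update_entity, delete_relationship, ...)
# become nested sub-resource calls (entities.create, entities.update,
# relationships.delete, communities.list) on chunks/documents/collections.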
- result = client.collections.graphs.delete_relationship(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") + result = client.collections.graphs.relationships.delete(collection_id="9fbe403b-c11c-5aae-8ade-ef22980c3ad1", relationship_id="123e4567-e89b-12d3-a456-426614174000") """ ), }, @@ -972,17 +983,20 @@ async def delete_relationship( level = self._get_path_level(request) if level == EntityLevel.CHUNK: chunk_ids = [id] - relationship = Relationship(id = relationship_id, chunk_ids = chunk_ids) + relationship = Relationship( + id=relationship_id, chunk_ids=chunk_ids + ) elif level == EntityLevel.DOCUMENT: - relationship = Relationship(id = relationship_id, document_id = id) + relationship = Relationship(id=relationship_id, document_id=id) else: - relationship = Relationship(id = relationship_id, collection_id = id) + relationship = Relationship( + id=relationship_id, collection_id=id + ) return await self.services["kg"].delete_relationship_v3( relationship=relationship, ) - ################### COMMUNITIES ################### @self.router.post( @@ -991,11 +1005,16 @@ async def delete_relationship( @self.base_endpoint async def create_communities( request: Request, - id: UUID = Path(..., description="The ID of the collection to create communities for."), + id: UUID = Path( + ..., + description="The ID of the collection to create communities for.", + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) # run enrich graph workflow @@ -1006,23 +1025,31 @@ async def create_communities( @self.base_endpoint async def create_communities( request: Request, - id: UUID = Path(..., description="The ID of the collection to create communities for."), - communities: list[Community] = Body(..., description="The communities to create."), + id: UUID = Path( + ..., + description="The ID of the collection to create communities for.", + ), + communities: list[Community] = Body( + ..., description="The communities to create." 
+ ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) for community in communities: if not community.collection_id: community.collection_id = id else: if community.collection_id != id: - raise ValueError("Collection ID in path and body do not match") + raise ValueError( + "Collection ID in path and body do not match" + ) return await self.services["kg"].create_communities_v3(communities) - @self.router.get( "/collections/{id}/graphs/communities", summary="Get communities", @@ -1030,13 +1057,22 @@ async def create_communities( @self.base_endpoint async def get_communities( request: Request, - id: UUID = Path(..., description="The ID of the collection to get communities for."), - offset: int = Query(0, description="Number of communities to skip"), - limit: int = Query(100, description="Maximum number of communities to return"), + id: UUID = Path( + ..., + description="The ID of the collection to get communities for.", + ), + offset: int = Query( + 0, description="Number of communities to skip" + ), + limit: int = Query( + 100, description="Maximum number of communities to return" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) return await self.services["kg"].get_communities_v3( collection_id=id, @@ -1051,12 +1087,19 @@ async def get_communities( @self.base_endpoint async def delete_community( request: Request, - id: UUID = Path(..., description="The ID of the collection to delete the community from."), - community_id: UUID = Path(..., description="The ID of the community to delete."), + id: UUID = Path( + ..., + description="The ID of the collection to delete the community from.", + ), + community_id: UUID = Path( + ..., description="The ID of the community to delete." + ), auth_user=Depends(self.providers.auth.auth_wrapper), ): if not auth_user.is_superuser: - raise R2RException("Only superusers can access this endpoint.", 403) + raise R2RException( + "Only superusers can access this endpoint.", 403 + ) community = Community(id=community_id, collection_id=id) @@ -1064,7 +1107,6 @@ async def delete_community( community=community, ) - ################### GRAPHS ################### @self.base_endpoint @@ -1511,7 +1553,7 @@ async def update_community( client = R2RClient("http://localhost:7272") # when using auth, do client.login(...) - result = client.graphs.list_communities( + result = client.graphs.communities.list( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", level=1, offset=0, @@ -1589,12 +1631,12 @@ async def get_community( # when using auth, do client.login(...) 
# Delete all communities - result = client.graphs.delete_communities( + result = client.graphs.communities.delete( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" ) # Delete specific level - result = client.graphs.delete_communities( + result = client.graphs.communities.delete( collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", level=1 )""" diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index 6122a6148..d30045bd0 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -62,6 +62,7 @@ async def create_graph(input_data): logger.error( f"Error in creating graph for document {document_id}: {e}" ) + raise e async def enrich_graph(input_data): diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 25c349058..6d0cb571d 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -166,7 +166,7 @@ async def update_entity_v3( entity ) - @telemetry_event(" ") + @telemetry_event("delete_entity_v3") async def delete_entity_v3( self, entity: Entity, @@ -176,6 +176,25 @@ async def delete_entity_v3( entity ) + # TODO: deprecate this + @telemetry_event("get_entities") + async def get_entities( + self, + collection_id: Optional[UUID] = None, + entity_ids: Optional[list[str]] = None, + entity_table_name: str = "document_entity", + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_entities( + collection_id=collection_id, + entity_ids=entity_ids, + entity_table_name=entity_table_name, + offset=offset or 0, + limit=limit or -1, + ) + ################### RELATIONSHIPS ################### @telemetry_event("list_relationships_v3") @@ -235,6 +254,25 @@ async def update_relationship_v3( ) ) + # TODO: deprecate this + @telemetry_event("get_triples") + async def get_relationships( + self, + collection_id: Optional[UUID] = None, + entity_names: Optional[list[str]] = None, + relationship_ids: Optional[list[str]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_relationships( + collection_id=collection_id, + entity_names=entity_names, + relationship_ids=relationship_ids, + offset=offset or 0, + limit=limit or -1, + ) + ################### COMMUNITIES ################### @telemetry_event("create_communities_v3") @@ -278,6 +316,25 @@ async def list_communities_v3( id, level ) + # TODO: deprecate this + @telemetry_event("get_communities") + async def get_communities( + self, + collection_id: Optional[UUID] = None, + levels: Optional[list[int]] = None, + community_numbers: Optional[list[int]] = None, + offset: Optional[int] = None, + limit: Optional[int] = None, + **kwargs, + ): + return await self.providers.database.get_communities( + collection_id=collection_id, + levels=levels, + community_numbers=community_numbers, + offset=offset or 0, + limit=limit or -1, + ) + ################### GRAPH ################### @telemetry_event("get_document_ids_for_create_graph") @@ -485,8 +542,8 @@ async def get_creation_estimate( **kwargs, ): return await self.providers.database.get_creation_estimate( - collection_id = collection_id, - kg_creation_settings = kg_creation_settings + collection_id=collection_id, + kg_creation_settings=kg_creation_settings, ) @telemetry_event("get_enrichment_estimate") diff --git a/py/core/pipes/kg/community_summary.py 
b/py/core/pipes/kg/community_summary.py index 5eb54ab99..5481bdf3a 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -257,7 +257,7 @@ async def _run_logic( # type: ignore f"KGCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}" ) community_numbers_exist = ( - await self.database_provider.check_communities_exist( + await self.database_provider.graph_handler.check_communities_exist( collection_id=collection_id, offset=offset, limit=limit ) ) diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index 83066b467..ff5b6e9b9 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -94,7 +94,7 @@ async def process_entity( out_entity = Entity( name=entities[0].name, chunk_ids=list(unique_chunk_ids), - document_ids=[document_id], + document_id=document_id, ) out_entity.description = ( @@ -131,17 +131,9 @@ async def process_entity( )[0] # upsert the entity and its embedding - await self.database_provider.upsert_embeddings( - [ - ( - out_entity.name, - out_entity.description, - str(out_entity.description_embedding), - out_entity.chunk_ids, - document_id, - ) - ], - "document_entity", + await self.database_provider.add_entities( + [out_entity], + table_name="document_entity", ) return out_entity.name diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index 1b53726cb..cab5c78d7 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -267,7 +267,7 @@ async def _run_logic( # type: ignore if filter_out_existing_chunks: existing_chunk_ids = ( - await self.database_provider.graph_handler.relationships_handler.get( + await self.database_provider.get_existing_entity_chunk_ids( document_id=document_id ) ) diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index fed673ceb..25b61f732 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -36,7 +36,11 @@ VectorQuantizationType, ) -from core.base.utils import _decorate_vector_type, llm_cost_per_million_tokens, _get_str_estimation_output +from core.base.utils import ( + _decorate_vector_type, + llm_cost_per_million_tokens, + _get_str_estimation_output, +) from .base import PostgresConnectionManager from .collection import PostgresCollectionHandler @@ -522,6 +526,10 @@ def __init__( self.communities, ] + import networkx as nx + + self.nx = nx + async def create_tables(self) -> None: QUERY = f""" CREATE TABLE IF NOT EXISTS {self._get_table_name("graph")} ( @@ -608,11 +616,10 @@ async def remove_collection( QUERY, graph_id, collection_id ) - ###### ESTIMATION METHODS ###### - + async def get_creation_estimate( - self, + self, kg_creation_settings: KGCreationSettings, document_id: Optional[UUID] = None, collection_id: Optional[UUID] = None, @@ -620,39 +627,74 @@ async def get_creation_estimate( """Get the estimated cost and time for creating a KG.""" if bool(document_id) ^ bool(collection_id) is False: - raise ValueError("Exactly one of document_id or collection_id must be provided.") + raise ValueError( + "Exactly one of document_id or collection_id must be provided." + ) # todo: harmonize the document_id and id fields: postgres table contains document_id, but other places use id. 
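[Editor's note] The `(low, high)` tuples computed below are rendered by `_get_str_estimation_output`, the helper moved into `shared/utils` in the previous patch; from that definition:

    _get_str_estimation_output((1200, 2200))  # -> "1200 - 2200"  (ints: str())
    _get_str_estimation_output((2.4, 4.438))  # -> "2.4 - 4.44"   (floats: round(_, 2))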
- document_ids = [document_id] if document_id else [ - doc.id for doc in (await self.collection_handler.documents_in_collection(collection_id, offset=0, limit=-1))["results"] # type: ignore - ] + document_ids = ( + [document_id] + if document_id + else [ + doc.id for doc in (await self.collection_handler.documents_in_collection(collection_id, offset=0, limit=-1))["results"] # type: ignore + ] + ) chunk_counts = await self.connection_manager.fetch_query( f"SELECT document_id, COUNT(*) as chunk_count FROM {self._get_table_name('vectors')} " - f"WHERE document_id = ANY($1) GROUP BY document_id", [document_ids] + f"WHERE document_id = ANY($1) GROUP BY document_id", + [document_ids], ) - total_chunks = sum(doc["chunk_count"] for doc in chunk_counts) // kg_creation_settings.extraction_merge_count + total_chunks = ( + sum(doc["chunk_count"] for doc in chunk_counts) + // kg_creation_settings.extraction_merge_count + ) estimated_entities = (total_chunks * 10, total_chunks * 20) - estimated_relationships = (int(estimated_entities[0] * 1.25), int(estimated_entities[1] * 1.5)) - estimated_llm_calls = (total_chunks * 2 + estimated_entities[0], total_chunks * 2 + estimated_entities[1]) - total_in_out_tokens = tuple(2000 * calls // 1000000 for calls in estimated_llm_calls) - cost_per_million = llm_cost_per_million_tokens(kg_creation_settings.generation_config.model) - estimated_cost = tuple(tokens * cost_per_million for tokens in total_in_out_tokens) - total_time_in_minutes = tuple(tokens * 10 / 60 for tokens in total_in_out_tokens) + estimated_relationships = ( + int(estimated_entities[0] * 1.25), + int(estimated_entities[1] * 1.5), + ) + estimated_llm_calls = ( + total_chunks * 2 + estimated_entities[0], + total_chunks * 2 + estimated_entities[1], + ) + total_in_out_tokens = tuple( + 2000 * calls // 1000000 for calls in estimated_llm_calls + ) + cost_per_million = llm_cost_per_million_tokens( + kg_creation_settings.generation_config.model + ) + estimated_cost = tuple( + tokens * cost_per_million for tokens in total_in_out_tokens + ) + total_time_in_minutes = tuple( + tokens * 10 / 60 for tokens in total_in_out_tokens + ) return { "message": 'Ran Graph Creation Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG creation process, run `create-graph` with `--run` in the cli, or `run_type="run"` in the client.', "document_count": len(document_ids), "number_of_jobs_created": len(document_ids) + 1, "total_chunks": total_chunks, - "estimated_entities": _get_str_estimation_output(estimated_entities), - "estimated_relationships": _get_str_estimation_output(estimated_relationships), - "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), - "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(total_in_out_tokens), - "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), - "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. 
Rough estimate: " + _get_str_estimation_output(total_time_in_minutes), + "estimated_entities": _get_str_estimation_output( + estimated_entities + ), + "estimated_relationships": _get_str_estimation_output( + estimated_relationships + ), + "estimated_llm_calls": _get_str_estimation_output( + estimated_llm_calls + ), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output( + total_in_out_tokens + ), + "estimated_cost_in_usd": _get_str_estimation_output( + estimated_cost + ), + "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. Rough estimate: " + + _get_str_estimation_output(total_time_in_minutes), } async def get_enrichment_estimate( @@ -668,33 +710,53 @@ async def get_enrichment_estimate( ] # Get entity and relationship counts - entity_count = (await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1);", - [document_ids] - ))[0]["count"] + entity_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('document_entity')} WHERE document_id = ANY($1);", + [document_ids], + ) + )[0]["count"] if not entity_count: - raise ValueError("No entities found in the graph. Please run `create-graph` first.") + raise ValueError( + "No entities found in the graph. Please run `create-graph` first." + ) - relationship_count = (await self.connection_manager.fetch_query( - f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1);", - [document_ids] - ))[0]["count"] + relationship_count = ( + await self.connection_manager.fetch_query( + f"SELECT COUNT(*) FROM {self._get_table_name('chunk_relationship')} WHERE document_id = ANY($1);", + [document_ids], + ) + )[0]["count"] # Calculate estimates estimated_llm_calls = (entity_count // 10, entity_count // 5) - tokens_in_millions = tuple(2000 * calls / 1000000 for calls in estimated_llm_calls) - cost_per_million = llm_cost_per_million_tokens(kg_enrichment_settings.generation_config.model) - estimated_cost = tuple(tokens * cost_per_million for tokens in tokens_in_millions) - estimated_time = tuple(tokens * 10 / 60 for tokens in tokens_in_millions) + tokens_in_millions = tuple( + 2000 * calls / 1000000 for calls in estimated_llm_calls + ) + cost_per_million = llm_cost_per_million_tokens( + kg_enrichment_settings.generation_config.model + ) + estimated_cost = tuple( + tokens * cost_per_million for tokens in tokens_in_millions + ) + estimated_time = tuple( + tokens * 10 / 60 for tokens in tokens_in_millions + ) return { "message": 'Ran Graph Enrichment Estimate (not the actual run). Note that these are estimated ranges, actual values may vary. To run the KG enrichment process, run `enrich-graph` with `--run` in the cli, or `run_type="run"` in the client.', "total_entities": entity_count, "total_relationships": relationship_count, - "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), - "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(tokens_in_millions), - "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), + "estimated_llm_calls": _get_str_estimation_output( + estimated_llm_calls + ), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output( + tokens_in_millions + ), + "estimated_cost_in_usd": _get_str_estimation_output( + estimated_cost + ), "estimated_total_time_in_minutes": "Depends on your API key tier. Accurate estimate coming soon. 
Rough estimate: " + _get_str_estimation_output(estimated_time), } @@ -716,7 +778,9 @@ async def get_deduplication_estimate( GROUP BY name HAVING count(name) >= 5 """ - entities = await self.connection_manager.fetch_query(query, [collection_id]) + entities = await self.connection_manager.fetch_query( + query, [collection_id] + ) num_entities = len(entities) estimated_llm_calls = (num_entities, num_entities) @@ -724,216 +788,43 @@ async def get_deduplication_estimate( estimated_llm_calls[0] * 1000 / 1000000, estimated_llm_calls[1] * 5000 / 1000000, ) - cost_per_million = llm_cost_per_million_tokens(kg_deduplication_settings.generation_config.model) - estimated_cost = (tokens_in_millions[0] * cost_per_million, tokens_in_millions[1] * cost_per_million) - estimated_time = (tokens_in_millions[0] * 10 / 60, tokens_in_millions[1] * 10 / 60) + cost_per_million = llm_cost_per_million_tokens( + kg_deduplication_settings.generation_config.model + ) + estimated_cost = ( + tokens_in_millions[0] * cost_per_million, + tokens_in_millions[1] * cost_per_million, + ) + estimated_time = ( + tokens_in_millions[0] * 10 / 60, + tokens_in_millions[1] * 10 / 60, + ) return { - "message": 'Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.', + "message": "Ran Deduplication Estimate (not the actual run). Note that these are estimated ranges.", "num_entities": num_entities, - "estimated_llm_calls": _get_str_estimation_output(estimated_llm_calls), - "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output(tokens_in_millions), - "estimated_cost_in_usd": _get_str_estimation_output(estimated_cost), - "estimated_total_time_in_minutes": _get_str_estimation_output(estimated_time), + "estimated_llm_calls": _get_str_estimation_output( + estimated_llm_calls + ), + "estimated_total_in_out_tokens_in_millions": _get_str_estimation_output( + tokens_in_millions + ), + "estimated_cost_in_usd": _get_str_estimation_output( + estimated_cost + ), + "estimated_total_time_in_minutes": _get_str_estimation_output( + estimated_time + ), } except UndefinedTableError: - raise R2RException("Entity embedding table not found. Please run `create-graph` first.", 404) + raise R2RException( + "Entity embedding table not found. Please run `create-graph` first.", + 404, + ) except Exception as e: logger.error(f"Error in get_deduplication_estimate: {str(e)}") raise HTTPException(500, "Error fetching deduplication estimate.") -class PostgresGraphHandler_v1(GraphHandler): - """Handler for Knowledge Graph METHODS in PostgreSQL.""" - - def __init__( - self, - project_name: str, - connection_manager: PostgresConnectionManager, - collection_handler: PostgresCollectionHandler, - dimension: int, - quantization_type: VectorQuantizationType, - *args: Any, - **kwargs: Any, - ) -> None: - """Initialize the handler with the same signature as the original provider.""" - super().__init__(project_name, connection_manager) - self.collection_handler = collection_handler - self.dimension = dimension - self.quantization_type = quantization_type - try: - import networkx as nx - - self.nx = nx - except ImportError as exc: - raise ImportError( - "NetworkX is not installed. Please install it to use this module." 
- ) from exc - - def _get_table_name(self, base_name: str) -> str: - """Get the fully qualified table name.""" - return f"{self.project_name}.{base_name}" - - ####################### TABLE CREATION METHODS ####################### - - async def create_tables(self): - # raw entities table - # create schema - - vector_column_str = _decorate_vector_type( - f"({self.dimension})", self.quantization_type - ) - - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_entity")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, - category TEXT NOT NULL, - name TEXT NOT NULL, - description TEXT NOT NULL, - chunk_ids UUID[] NOT NULL, - document_id UUID NOT NULL, - attributes JSONB - ); - """ - await self.connection_manager.execute_query(query) - - # raw relationships table, also the final table. this will have embeddings. - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("chunk_relationship")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object TEXT NOT NULL, - weight FLOAT NOT NULL, - description TEXT NOT NULL, - embedding {vector_column_str}, - chunk_ids UUID[] NOT NULL, - document_id UUID NOT NULL, - attributes JSONB NOT NULL - ); - """ - await self.connection_manager.execute_query(query) - - # embeddings tables - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("document_entity")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, - name TEXT NOT NULL, - description TEXT NOT NULL, - chunk_ids UUID[] NOT NULL, - description_embedding {vector_column_str} NOT NULL, - document_id UUID NOT NULL, - UNIQUE (name, document_id) - ); - """ - - await self.connection_manager.execute_query(query) - - # deduplicated entities table - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("collection_entity")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, - name TEXT NOT NULL, - description TEXT, - chunk_ids UUID[] NOT NULL, - document_ids UUID[] NOT NULL, - collection_id UUID NOT NULL, - description_embedding {vector_column_str}, - attributes JSONB, - UNIQUE (name, collection_id, attributes) - );""" - - await self.connection_manager.execute_query(query) - - # communities table, result of the Leiden algorithm - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("community_info")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL, - node TEXT NOT NULL, - cluster INT NOT NULL, - parent_cluster INT, - level INT NOT NULL, - is_final_cluster BOOLEAN NOT NULL, - relationship_ids INT[] NOT NULL, - collection_id UUID NOT NULL - );""" - - await self.connection_manager.execute_query(query) - - # communities_report table - query = f""" - CREATE TABLE IF NOT EXISTS {self._get_table_name("community")} ( - id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), - sid SERIAL PRIMARY KEY, - community_number INT NOT NULL, - collection_id UUID NOT NULL, - level INT NOT NULL, - name TEXT NOT NULL, - summary TEXT NOT NULL, - findings TEXT[] NOT NULL, - rating FLOAT NOT NULL, - rating_explanation TEXT NOT NULL, - embedding {vector_column_str} NOT NULL, - attributes JSONB, - UNIQUE (community_number, level, collection_id) - );""" - - await self.connection_manager.execute_query(query) - - ################### ENTITY METHODS ################### - - # async def get_entities_v3( - # self, - # level: EntityLevel, - # id: Optional[UUID] = None, - # entity_names: Optional[list[str]] = None, - # 
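For readers skimming the patch, the estimate handlers above boil down to a small amount of range arithmetic. The sketch below restates it as a standalone function; `format_range` is a hypothetical stand-in for `_get_str_estimation_output`, whose exact output format this patch does not show. Note that the creation path uses floor division (`2000 * calls // 1000000`), so workloads with fewer than 500 estimated LLM calls round down to zero tokens.

# Standalone restatement of the creation-estimate heuristics above:
# 10-20 entities per chunk, 1.25x-1.5x relationships per entity,
# ~2000 tokens per LLM call, and ~10 minutes per million tokens.

def format_range(bounds) -> str:
    # Hypothetical stand-in for _get_str_estimation_output.
    return f"{bounds[0]:.2f} - {bounds[1]:.2f}"

def creation_estimate(total_chunks: int, cost_per_million: float) -> dict:
    entities = (total_chunks * 10, total_chunks * 20)
    relationships = (int(entities[0] * 1.25), int(entities[1] * 1.5))
    llm_calls = (
        total_chunks * 2 + entities[0],
        total_chunks * 2 + entities[1],
    )
    # Floor division mirrors the handler; small workloads truncate to 0.
    tokens_in_millions = tuple(2000 * c // 1_000_000 for c in llm_calls)
    cost = tuple(t * cost_per_million for t in tokens_in_millions)
    minutes = tuple(t * 10 / 60 for t in tokens_in_millions)
    return {
        "estimated_entities": format_range(entities),
        "estimated_relationships": format_range(relationships),
        "estimated_llm_calls": format_range(llm_calls),
        "estimated_total_in_out_tokens_in_millions": format_range(
            tokens_in_millions
        ),
        "estimated_cost_in_usd": format_range(cost),
        "estimated_total_time_in_minutes": format_range(minutes),
    }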
entity_categories: Optional[list[str]] = None, - # attributes: Optional[list[str]] = None, - # offset: int = 0, - # limit: int = -1, - # ): - - # params: list = [id] - - # if level != EntityLevel.CHUNK and entity_categories: - # raise ValueError( - # "entity_categories are only supported for chunk level entities" - # ) - - # filter = { - # EntityLevel.CHUNK: "chunk_ids = ANY($1)", - # EntityLevel.DOCUMENT: "document_id = $1", - # EntityLevel.COLLECTION: "collection_id = $1", - # }[level] - - # if entity_names: - # filter += " AND name = ANY($2)" - # params.append(entity_names) - - # if entity_categories: - # filter += " AND category = ANY($3)" - # params.append(entity_categories) - - # QUERY = f""" - # SELECT * from {self._get_table_name(level.table_name)} WHERE {filter} - # OFFSET ${len(params)} LIMIT ${len(params) + 1} - # """ - - # params.extend([offset, limit]) - - # output = await self.connection_manager.fetch_query(QUERY, params) - - # if attributes: - # output = [ - # entity for entity in output if entity["name"] in attributes - # ] - - # return output - # TODO: deprecate this async def get_entities( self, @@ -974,7 +865,7 @@ async def get_entities( if entity_table_name == "collection_entity": query = f""" - SELECT id, name, description, chunk_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT sid as id, name, description, chunk_ids, document_ids {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE collection_id = $1 {" AND " + " AND ".join(conditions) if conditions else ""} @@ -983,7 +874,7 @@ async def get_entities( """ else: query = f""" - SELECT id, name, description, chunk_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} + SELECT sid as id, name, description, chunk_ids, document_id {", " + ", ".join(extra_columns) if extra_columns else ""} FROM {self._get_table_name(entity_table_name)} WHERE document_id = ANY( SELECT document_id FROM {self._get_table_name("document_info")} @@ -1034,45 +925,13 @@ async def add_entities( ) cleaned_entities.append(entity_dict) - return await self._add_objects( - cleaned_entities, table_name, conflict_columns + return await _add_objects( + objects=cleaned_entities, + full_table_name=self._get_table_name(table_name), + connection_manager=self.connection_manager, + conflict_columns=conflict_columns, ) - async def create_entities_v3( - self, level: EntityLevel, id: UUID, entities: list[Entity] - ) -> None: - - # TODO: check if already exists - await self._add_objects(entities, level.table_name) - - # async def update_entity(self, collection_id: UUID, entity: Entity) -> None: - # table_name = entity.level.value + "_entity" - - # # check if the entity already exists - # QUERY = f""" - # SELECT COUNT(*) FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - # """ - # count = ( - # await self.connection_manager.fetch_query( - # QUERY, [entity.id, collection_id] - # ) - # )[0]["count"] - - # if count == 0: - # raise R2RException("Entity does not exist", 404) - - # await self._add_objects([entity], table_name) - - # async def delete_entity(self, collection_id: UUID, entity: Entity) -> None: - - # table_name = entity.level.value + "_entity" - # QUERY = f""" - # DELETE FROM {self._get_table_name(table_name)} WHERE id = $1 AND collection_id = $2 - # """ - # await self.connection_manager.execute_query( - # QUERY, [entity.id, collection_id] - # ) - async def delete_node_via_document_id( self, document_id: UUID, 
collection_id: UUID ) -> None: @@ -1128,6 +987,7 @@ async def delete_node_via_document_id( ##################### RELATIONSHIP METHODS ##################### + # DEPRECATED async def add_relationships( self, relationships: list[Relationship], @@ -1143,59 +1003,14 @@ async def add_relationships( Returns: result: asyncpg.Record: result of the upsert operation """ - return await self._add_objects( - [ele.to_dict() for ele in relationships], table_name - ) - - async def list_relationships_v3( - self, - level: EntityLevel, - id: UUID, - entity_names: Optional[list[str]] = None, - relationship_types: Optional[list[str]] = None, - attributes: Optional[list[str]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - ): - filter_query = "" - if entity_names: - filter_query += "AND (subject IN ($2) OR object IN ($2))" - if relationship_types: - filter_query += "AND predicate IN ($3)" - - if level == EntityLevel.CHUNK: - QUERY = f""" - SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE $1 = ANY(chunk_ids) - {filter_query} - """ - elif level == EntityLevel.DOCUMENT: - QUERY = f""" - SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE $1 = document_id - {filter_query} - """ - elif level == EntityLevel.COLLECTION: - QUERY = f""" - WITH document_ids AS ( - SELECT document_id FROM {self._get_table_name("document_info")} WHERE $1 = ANY(collection_ids) - ) - SELECT * FROM {self._get_table_name("chunk_relationship")} WHERE document_id IN (SELECT document_id FROM document_ids) - {filter_query} - """ - - results = await self.connection_manager.fetch_query( - QUERY, [id, entity_names, relationship_types] + return await _add_objects( + objects=[ele.to_dict() for ele in relationships], + full_table_name=self._get_table_name(table_name), + connection_manager=self.connection_manager, ) - if attributes: - results = [ - {k: v for k, v in result.items() if k in attributes} - for result in results - ] - - return results - async def get_all_relationships( - self, collection_id: UUID + self, collection_id: UUID, document_ids: Optional[list[UUID]] = None ) -> list[Relationship]: # getting all documents for a collection @@ -1211,61 +1026,14 @@ async def get_all_relationships( ] QUERY = f""" - SELECT id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) + SELECT sid as id, subject, predicate, weight, object, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY($1) """ relationships = await self.connection_manager.fetch_query( QUERY, [document_ids] ) return [Relationship(**relationship) for relationship in relationships] - async def create_relationship( - self, collection_id: UUID, relationship: Relationship - ) -> None: - - # check if the relationship already exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} WHERE subject = $1 AND predicate = $2 AND object = $3 AND collection_id = $4 - """ - count = ( - await self.connection_manager.fetch_query( - QUERY, - [ - relationship.subject, - relationship.predicate, - relationship.object, - collection_id, - ], - ) - )[0]["count"] - - if count > 0: - raise R2RException("Relationship already exists", 400) - - await self._add_objects([relationship], "chunk_relationship") - - async def update_relationship( - self, relationship_id: UUID, relationship: Relationship - ) -> None: - - # check if relationship_id exists - QUERY = f""" - SELECT COUNT(*) FROM {self._get_table_name("chunk_relationship")} 
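Earlier in this file's diff, both `add_entities` and the deprecated `add_relationships` switch from a handler method (`self._add_objects`) to a module-level `_add_objects` function that takes the fully qualified table name and the connection manager explicitly. A minimal sketch of what such a helper can look like, inferred from those call sites; the real implementation's SQL and return value may differ.

# Sketch only: signature inferred from the add_entities /
# add_relationships call sites in this patch.
from typing import Any, Optional

async def _add_objects(
    objects: list[dict[str, Any]],
    full_table_name: str,
    connection_manager: Any,  # PostgresConnectionManager in this codebase
    conflict_columns: Optional[list[str]] = None,
) -> Any:
    if not objects:
        return None
    columns = list(objects[0].keys())
    placeholders = ", ".join(f"${i}" for i in range(1, len(columns) + 1))
    query = (
        f"INSERT INTO {full_table_name} ({', '.join(columns)}) "
        f"VALUES ({placeholders})"
    )
    if conflict_columns:
        # Assumed upsert behavior: overwrite non-conflict columns.
        updates = ", ".join(
            f"{col} = EXCLUDED.{col}"
            for col in columns
            if col not in conflict_columns
        )
        query += (
            f" ON CONFLICT ({', '.join(conflict_columns)}) "
            f"DO UPDATE SET {updates}"
        )
    params = [tuple(obj[col] for col in columns) for obj in objects]
    # execute_many is the batch API this patch also uses in _update_object.
    return await connection_manager.execute_many(query, params)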
WHERE id = $1 - """ - count = ( - await self.connection_manager.fetch_query(QUERY, [relationship.id]) - )[0]["count"] - - if count == 0: - raise R2RException("Relationship does not exist", 404) - - await self._add_objects([relationship], "chunk_relationship") - - async def delete_relationship(self, relationship_id: UUID) -> None: - QUERY = f""" - DELETE FROM {self._get_table_name("chunk_relationship")} WHERE id = $1 - """ - await self.connection_manager.execute_query(QUERY, [relationship_id]) - + # DEPRECATED async def get_relationships( self, offset: int, @@ -1304,7 +1072,7 @@ async def get_relationships( pagination_clause = " ".join(pagination_params) query = f""" - SELECT id, subject, predicate, object, description + SELECT sid as id, subject, predicate, object, description, chunk_ids, document_id FROM {self._get_table_name("chunk_relationship")} WHERE document_id = ANY( SELECT document_id FROM {self._get_table_name("document_info")} @@ -1477,7 +1245,7 @@ async def get_community_details( SELECT DISTINCT t.id, t.subject, t.predicate, t.object, t.weight, t.description FROM node_relationship_ids nti - JOIN {self._get_table_name("chunk_relationship")} t ON t.id = ANY(nti.relationship_ids); + JOIN {self._get_table_name("chunk_relationship")} t ON t.sid = ANY(nti.relationship_ids); """ relationships = await self.connection_manager.fetch_query( QUERY, [community_number, collection_id] @@ -1614,8 +1382,11 @@ async def perform_graph_clustering( relationships ) - if await self._use_community_cache( - collection_id, relationship_ids_cache + if ( + await self._use_community_cache( + collection_id, relationship_ids_cache + ) + and False ): num_communities = await self._incremental_clustering( relationship_ids_cache, leiden_params, collection_id @@ -2121,7 +1892,6 @@ async def _compute_leiden_communities( ####################### UTILITY METHODS ####################### - async def get_existing_entity_chunk_ids( self, document_id: UUID ) -> list[str]: @@ -2359,9 +2129,7 @@ async def _update_object( params.append(object[id_column]) ret = await connection_manager.execute_many(QUERY, [tuple(params)]) # type: ignore - import pdb - pdb.set_trace() return ret diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 2556541e1..9575149dc 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -58,7 +58,7 @@ class Entity(R2RSerializable): """An entity extracted from a document.""" name: Optional[str] = None - id: Optional[UUID] = None + id: Optional[Union[UUID, int]] = None sid: Optional[int] = None # serial ID level: Optional[EntityLevel] = None category: Optional[str] = None @@ -94,7 +94,7 @@ def __init__(self, **kwargs): class Relationship(R2RSerializable): """A relationship between two entities. 
This is a generic relationship, and can be used to represent any type of relationship between any two entities.""" - id: Optional[UUID] = None + id: Optional[Union[UUID, int]] = None sid: Optional[int] = None # serial ID subject: Optional[str] = None diff --git a/py/shared/api/models/kg/responses.py b/py/shared/api/models/kg/responses.py index 165c0f604..768a3c933 100644 --- a/py/shared/api/models/kg/responses.py +++ b/py/shared/api/models/kg/responses.py @@ -300,6 +300,27 @@ class Config: json_schema_extra = {"example": {"tuned_prompt": "The updated prompt"}} +class KGDeletionResponse(BaseModel): + """Response for knowledge graph deletion.""" + + message: str = Field( + ..., + description="The message to display to the user.", + ) + id: UUID = Field( + ..., + description="The ID of the deleted graph.", + ) + + class Config: + json_schema_extra = { + "example": { + "message": "Entity deleted successfully.", + "id": "123e4567-e89b-12d3-a456-426614174000", + } + } + + WrappedKGCreationResponse = ResultsWrapper[ Union[KGCreationResponse, KGCreationEstimationResponse] ] @@ -308,14 +329,12 @@ class Config: ] # KG Entities -WrappedKGEntityResponse = ResultsWrapper[KGEntitiesResponse] -WrappedKGEntitiesResponse = PaginatedResultsWrapper[KGEntitiesResponse] -WrappedKGRelationshipsResponse = PaginatedResultsWrapper[ - KGRelationshipsResponse -] +WrappedKGEntitiesResponse = ResultsWrapper[KGEntitiesResponse] +WrappedKGRelationshipsResponse = ResultsWrapper[KGRelationshipsResponse] WrappedKGTunePromptResponse = ResultsWrapper[KGTunePromptResponse] WrappedKGCommunitiesResponse = ResultsWrapper[KGCommunitiesResponse] WrappedKGEntityDeduplicationResponse = ResultsWrapper[ Union[KGEntityDeduplicationResponse, KGDeduplicationEstimationResponse] ] +WrappedKGDeletionResponse = ResultsWrapper[KGDeletionResponse] From 68327e6f9b555c168ce195c6454b809145a3a431 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 15 Nov 2024 11:19:11 -0800 Subject: [PATCH 19/21] up --- py/core/base/providers/database.py | 232 ------------------- py/core/main/api/v3/graph_router.py | 4 +- py/core/main/services/kg_service.py | 32 +-- py/core/pipes/kg/clustering.py | 12 +- py/core/pipes/kg/community_summary.py | 9 +- py/core/pipes/kg/deduplication.py | 14 +- py/core/pipes/kg/deduplication_summary.py | 4 +- py/core/pipes/kg/entity_description.py | 7 +- py/core/pipes/kg/relationships_extraction.py | 10 +- py/core/pipes/kg/storage.py | 4 +- py/core/pipes/retrieval/kg_search_pipe.py | 4 +- py/shared/abstractions/graph.py | 174 +------------- py/tests/core/providers/kg/test_kg_logic.py | 16 +- 13 files changed, 67 insertions(+), 455 deletions(-) diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 92b382eeb..2b77d28a9 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -1377,238 +1377,6 @@ async def get_semantic_neighbors( similarity_threshold=similarity_threshold, ) - async def add_entities( - self, - entities: list[Entity], - table_name: str, - conflict_columns: list[str] = [], - ) -> Any: - """Forward to KG handler add_entities method.""" - return await self.graph_handler.add_entities( - entities, table_name, conflict_columns - ) - - async def add_relationships( - self, - relationships: list[Relationship], - table_name: str = "chunk_relationship", - ) -> None: - """Forward to KG handler add_relationships method.""" - return await self.graph_handler.add_relationships( - relationships, table_name - ) - - async def get_entity_map( - self, offset: 
int, limit: int, document_id: UUID - ) -> dict[str, dict[str, list[dict[str, Any]]]]: - """Forward to KG handler get_entity_map method.""" - return await self.graph_handler.get_entity_map( - offset, limit, document_id - ) - - # Community methods - async def add_community_info(self, communities: list[Any]) -> None: - """Forward to KG handler add_communities method.""" - return await self.graph_handler.add_community_info(communities) - - async def get_communities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - levels: Optional[list[int]] = None, - community_numbers: Optional[list[int]] = None, - ) -> dict: - """Forward to KG handler get_communities method.""" - return await self.graph_handler.get_communities( - offset=offset, - limit=limit, - collection_id=collection_id, - levels=levels, - community_numbers=community_numbers, - ) - - async def add_community(self, community: Community) -> None: - """Forward to KG handler add_community method.""" - return await self.graph_handler.add_community(community) - - async def get_community_details( - self, community_number: int, collection_id: UUID - ) -> Tuple[int, list[Entity], list[Relationship]]: - """Forward to KG handler get_community_details method.""" - return await self.graph_handler.get_community_details( - community_number, collection_id - ) - - async def get_community(self, collection_id: UUID) -> list[Community]: - """Forward to KG handler get_community method.""" - return await self.graph_handler.get_community(collection_id) - - async def check_community_exists( - self, collection_id: UUID, offset: int, limit: int - ) -> list[int]: - """Forward to KG handler check_community_exists method.""" - return await self.graph_handler.check_community_exists( - collection_id, offset, limit - ) - - async def perform_graph_clustering( - self, - collection_id: UUID, - leiden_params: dict[str, Any], - ) -> int: - """Forward to KG handler perform_graph_clustering method.""" - return await self.graph_handler.perform_graph_clustering( - collection_id, leiden_params - ) - - # Graph operations - async def delete_graph_for_collection( - self, collection_id: UUID, cascade: bool = False - ) -> None: - """Forward to KG handler delete_graph_for_collection method.""" - return await self.graph_handler.delete_graph_for_collection( - collection_id, cascade - ) - - async def delete_node_via_document_id( - self, document_id: UUID, collection_id: UUID - ) -> None: - """Forward to KG handler delete_node_via_document_id method.""" - return await self.graph_handler.delete_node_via_document_id( - document_id, collection_id - ) - - # Entity and Relationship operations - async def get_entities( - self, - offset: int, - limit: int, - collection_id: Optional[UUID], - entity_ids: Optional[list[str]] = None, - entity_names: Optional[list[str]] = None, - entity_table_name: str = "document_entity", - extra_columns: Optional[list[str]] = None, - ) -> dict: - """Forward to KG handler get_entities method.""" - return await self.graph_handler.get_entities( - offset=offset, - limit=limit, - collection_id=collection_id, - entity_ids=entity_ids, - entity_names=entity_names, - entity_table_name=entity_table_name, - extra_columns=extra_columns, - ) - - async def get_relationships( - self, - offset: int, - limit: int, - collection_id: Optional[UUID] = None, - entity_names: Optional[list[str]] = None, - relationship_ids: Optional[list[str]] = None, - ) -> dict: - """Forward to KG handler get_relationships method.""" - return await 
self.graph_handler.get_relationships( - offset=offset, - limit=limit, - collection_id=collection_id, - entity_names=entity_names, - relationship_ids=relationship_ids, - ) - - async def get_entity_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - distinct: bool = False, - entity_table_name: str = "document_entity", - ) -> int: - """Forward to KG handler get_entity_count method.""" - return await self.graph_handler.get_entity_count( - collection_id, document_id, distinct, entity_table_name - ) - - async def get_relationship_count( - self, - collection_id: Optional[UUID] = None, - document_id: Optional[UUID] = None, - ) -> int: - """Forward to KG handler get_relationship_count method.""" - return await self.graph_handler.get_relationship_count( - collection_id, document_id - ) - - # Estimation methods - async def get_creation_estimate( - self, collection_id: UUID, kg_creation_settings: KGCreationSettings - ): - """Forward to KG handler get_creation_estimate method.""" - return await self.graph_handler.get_creation_estimate( - collection_id=collection_id, - kg_creation_settings=kg_creation_settings, - ) - - async def get_enrichment_estimate( - self, collection_id: UUID, kg_enrichment_settings: KGEnrichmentSettings - ): - """Forward to KG handler get_enrichment_estimate method.""" - return await self.graph_handler.get_enrichment_estimate( - collection_id, kg_enrichment_settings - ) - - async def get_deduplication_estimate( - self, - collection_id: UUID, - kg_deduplication_settings: KGEntityDeduplicationSettings, - ): - """Forward to KG handler get_deduplication_estimate method.""" - return await self.graph_handler.get_deduplication_estimate( - collection_id, kg_deduplication_settings - ) - - async def get_all_relationships( - self, collection_id: UUID - ) -> list[Relationship]: - return await self.graph_handler.get_all_relationships(collection_id) - - async def update_entity_descriptions(self, entities: list[Entity]): - return await self.graph_handler.update_entity_descriptions(entities) - - async def graph_search( - self, query: str, **kwargs: Any - ) -> AsyncGenerator[Any, None]: - return self.graph_handler.graph_search(query, **kwargs) # type: ignore - - async def create_vector_index(self) -> None: - return await self.graph_handler.create_vector_index() - - async def delete_relationships(self, relationship_ids: list[int]) -> None: - return await self.graph_handler.delete_relationships(relationship_ids) - - async def get_schema(self) -> Any: - return await self.graph_handler.get_schema() - - async def structured_query(self) -> Any: - return await self.graph_handler.structured_query() - - async def update_extraction_prompt(self) -> None: - return await self.graph_handler.update_extraction_prompt() - - async def update_kg_search_prompt(self) -> None: - return await self.graph_handler.update_kg_search_prompt() - - async def upsert_relationships(self) -> None: - return await self.graph_handler.upsert_relationships() - - async def get_existing_entity_chunk_ids( - self, document_id: UUID - ) -> list[str]: - return await self.graph_handler.get_existing_entity_chunk_ids( - document_id - ) - async def add_prompt( self, name: str, template: str, input_types: dict[str, str] ) -> None: diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 496bde99d..2a82db697 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -240,7 +240,7 @@ async def list_entities( description="The maximum 
number of entities to retrieve, up to 20,000.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Entity]]: + ) -> WrappedKGEntitiesResponse: # type: ignore """ Retrieves a list of entities associated with a specific chunk. @@ -263,7 +263,7 @@ async def list_entities( attributes=attributes, ) - return entities, { + return entities, { # type: ignore "total_entries": count, } diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 6d0cb571d..775d44f4b 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -187,7 +187,7 @@ async def get_entities( limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.get_entities( + return await self.providers.database.graph_handler.get_entities( collection_id=collection_id, entity_ids=entity_ids, entity_table_name=entity_table_name, @@ -265,7 +265,7 @@ async def get_relationships( limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.get_relationships( + return await self.providers.database.graph_handler.get_relationships( collection_id=collection_id, entity_names=entity_names, relationship_ids=relationship_ids, @@ -327,7 +327,7 @@ async def get_communities( limit: Optional[int] = None, **kwargs, ): - return await self.providers.database.get_communities( + return await self.providers.database.graph_handler.get_communities( collection_id=collection_id, levels=levels, community_numbers=community_numbers, @@ -378,10 +378,12 @@ async def kg_entity_description( f"KGService: Running kg_entity_description for document {document_id}" ) - entity_count = await self.providers.database.get_entity_count( - document_id=document_id, - distinct=True, - entity_table_name="chunk_entity", + entity_count = ( + await self.providers.database.graph_handler.get_entity_count( + document_id=document_id, + distinct=True, + entity_table_name="chunk_entity", + ) ) logger.info( @@ -519,7 +521,7 @@ async def delete_graph_for_collection( cascade: bool, **kwargs, ): - return await self.providers.database.delete_graph_for_collection( + return await self.providers.database.graph_handler.delete_graph_for_collection( collection_id, cascade ) @@ -530,7 +532,7 @@ async def delete_node_via_document_id( collection_id: UUID, **kwargs, ): - return await self.providers.database.delete_node_via_document_id( + return await self.providers.database.graph_handler.delete_node_via_document_id( collection_id, document_id ) @@ -541,9 +543,11 @@ async def get_creation_estimate( kg_creation_settings: KGCreationSettings, **kwargs, ): - return await self.providers.database.get_creation_estimate( - collection_id=collection_id, - kg_creation_settings=kg_creation_settings, + return ( + await self.providers.database.graph_handler.get_creation_estimate( + collection_id=collection_id, + kg_creation_settings=kg_creation_settings, + ) ) @telemetry_event("get_enrichment_estimate") @@ -554,7 +558,7 @@ async def get_enrichment_estimate( **kwargs, ): - return await self.providers.database.get_enrichment_estimate( + return await self.providers.database.graph_handler.get_enrichment_estimate( collection_id, kg_enrichment_settings ) @@ -565,7 +569,7 @@ async def get_deduplication_estimate( kg_deduplication_settings: KGEntityDeduplicationSettings, **kwargs, ): - return await self.providers.database.get_deduplication_estimate( + return await self.providers.database.graph_handler.get_deduplication_estimate( collection_id, kg_deduplication_settings ) diff --git 
a/py/core/pipes/kg/clustering.py b/py/core/pipes/kg/clustering.py index acbcdf944..6765096b6 100644 --- a/py/core/pipes/kg/clustering.py +++ b/py/core/pipes/kg/clustering.py @@ -6,10 +6,10 @@ AsyncPipe, AsyncState, CompletionProvider, - DatabaseProvider, EmbeddingProvider, ) from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider +from core.providers.database import PostgresDBProvider logger = logging.getLogger() @@ -21,7 +21,7 @@ class KGClusteringPipe(AsyncPipe): def __init__( self, - database_provider: DatabaseProvider, + database_provider: PostgresDBProvider, llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, @@ -49,11 +49,9 @@ async def cluster_kg( Clusters the knowledge graph relationships into communities using hierarchical Leiden algorithm. Uses graspologic library. """ - num_communities = ( - await self.database_provider.perform_graph_clustering( - collection_id, - leiden_params, - ) + num_communities = await self.database_provider.graph_handler.perform_graph_clustering( + collection_id, + leiden_params, ) # type: ignore logger.info( diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index 5481bdf3a..c303e3d89 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -11,11 +11,12 @@ AsyncState, Community, CompletionProvider, - DatabaseProvider, EmbeddingProvider, GenerationConfig, ) + from core.base.abstractions import Entity, Relationship +from core.providers.database import PostgresDBProvider from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -28,7 +29,7 @@ class KGCommunitySummaryPipe(AsyncPipe): def __init__( self, - database_provider: DatabaseProvider, + database_provider: PostgresDBProvider, llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, @@ -144,7 +145,7 @@ async def process_community( """ community_level, entities, relationships = ( - await self.database_provider.get_community_details( + await self.database_provider.graph_handler.get_community_details( community_number=community_number, collection_id=collection_id, ) @@ -223,7 +224,7 @@ async def process_community( ), ) - await self.database_provider.add_community(community) + await self.database_provider.graph_handler.add_community(community) return { "community_number": community.community_number, diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index d01dd4750..dd7b31250 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -50,8 +50,10 @@ async def kg_named_entity_deduplication( self, collection_id: UUID, **kwargs ): try: - entity_count = await self.database_provider.get_entity_count( - collection_id=collection_id, distinct=True + entity_count = ( + await self.database_provider.graph_handler.get_entity_count( + collection_id=collection_id, distinct=True + ) ) logger.info( @@ -62,7 +64,7 @@ async def kg_named_entity_deduplication( ) entities = ( - await self.database_provider.get_entities( + await self.database_provider.graph_handler.get_entities( collection_id=collection_id, offset=0, limit=-1 ) )["entities"] @@ -121,7 +123,7 @@ async def kg_named_entity_deduplication( logger.info( f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {collection_id}" ) - await self.database_provider.add_entities( + await 
self.database_provider.graph_handler.add_entities( deduplicated_entities_list, table_name="collection_entity", conflict_columns=["name", "collection_id", "attributes"], @@ -146,7 +148,7 @@ async def kg_description_entity_deduplication( from sklearn.cluster import DBSCAN entities = ( - await self.database_provider.get_entities( + await self.database_provider.graph_handler.get_entities( collection_id=collection_id, offset=0, limit=-1, @@ -247,7 +249,7 @@ async def kg_description_entity_deduplication( logger.info( f"KGEntityDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for collection {collection_id}" ) - await self.database_provider.add_entities( + await self.database_provider.graph_handler.add_entities( deduplicated_entities_list, table_name="collection_entity", conflict_columns=["name", "collection_id", "attributes"], diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index f7ee298b9..ba41c5c47 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -155,7 +155,7 @@ async def _run_logic( ) entities = ( - await self.database_provider.get_entities( + await self.database_provider.graph_handler.get_entities( collection_id=collection_id, entity_table_name="collection_entity", offset=offset, @@ -166,7 +166,7 @@ async def _run_logic( entity_names = [entity.name for entity in entities] entity_descriptions = ( - await self.database_provider.get_entities( + await self.database_provider.graph_handler.get_entities( collection_id=collection_id, entity_names=entity_names, entity_table_name="document_entity", diff --git a/py/core/pipes/kg/entity_description.py b/py/core/pipes/kg/entity_description.py index ff5b6e9b9..3f7956b8d 100644 --- a/py/core/pipes/kg/entity_description.py +++ b/py/core/pipes/kg/entity_description.py @@ -15,6 +15,7 @@ ) from core.base.abstractions import Entity from core.base.pipes.base_pipe import AsyncPipe +from core.providers.database import PostgresDBProvider from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider logger = logging.getLogger() @@ -30,7 +31,7 @@ class Input(AsyncPipe.Input): def __init__( self, - database_provider: DatabaseProvider, + database_provider: PostgresDBProvider, llm_provider: CompletionProvider, embedding_provider: EmbeddingProvider, config: AsyncPipe.PipeConfig, @@ -131,7 +132,7 @@ async def process_entity( )[0] # upsert the entity and its embedding - await self.database_provider.add_entities( + await self.database_provider.graph_handler.add_entities( [out_entity], table_name="document_entity", ) @@ -147,7 +148,7 @@ async def process_entity( f"KGEntityDescriptionPipe: Getting entity map for document {document_id}", ) - entity_map = await self.database_provider.get_entity_map( + entity_map = await self.database_provider.graph_handler.get_entity_map( offset, limit, document_id ) total_entities = len(entity_map) diff --git a/py/core/pipes/kg/relationships_extraction.py b/py/core/pipes/kg/relationships_extraction.py index cab5c78d7..8067b7b4d 100644 --- a/py/core/pipes/kg/relationships_extraction.py +++ b/py/core/pipes/kg/relationships_extraction.py @@ -8,7 +8,6 @@ from core.base import ( AsyncState, CompletionProvider, - DatabaseProvider, DocumentChunk, Entity, GenerationConfig, @@ -19,6 +18,7 @@ ) from core.base.pipes.base_pipe import AsyncPipe from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider +from core.providers.database import PostgresDBProvider logger = 
logging.getLogger() @@ -43,7 +43,7 @@ class Input(AsyncPipe.Input): def __init__( self, - database_provider: DatabaseProvider, + database_provider: PostgresDBProvider, llm_provider: CompletionProvider, config: AsyncPipe.PipeConfig, logging_provider: SqlitePersistentLoggingProvider, @@ -266,10 +266,8 @@ async def _run_logic( # type: ignore ) if filter_out_existing_chunks: - existing_chunk_ids = ( - await self.database_provider.get_existing_entity_chunk_ids( - document_id=document_id - ) + existing_chunk_ids = await self.database_provider.graph_handler.get_existing_entity_chunk_ids( + document_id=document_id ) extractions = [ extraction diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index b5cdb31e4..ad672ca79 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -75,7 +75,7 @@ async def store( extraction.document_id ) - await self.database_provider.add_entities( + await self.database_provider.graph_handler.add_entities( extraction.entities, table_name=f"chunk_entity" ) @@ -89,7 +89,7 @@ async def store( extraction.document_id ) - await self.database_provider.add_relationships( + await self.database_provider.graph_handler.add_relationships( extraction.relationships, table_name=f"chunk_relationship", ) diff --git a/py/core/pipes/retrieval/kg_search_pipe.py b/py/core/pipes/retrieval/kg_search_pipe.py index 9c5f09c08..5c3d8da76 100644 --- a/py/core/pipes/retrieval/kg_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_pipe.py @@ -110,7 +110,7 @@ async def local_search( # entity search search_type = "__Entity__" - async for search_result in await self.database_provider.graph_search( # type: ignore + async for search_result in await self.database_provider.graph_handler.graph_search( # type: ignore message, search_type=search_type, search_type_limits=kg_search_settings.local_search_limits[ @@ -167,7 +167,7 @@ async def local_search( # community search search_type = "__Community__" - async for search_result in await self.database_provider.graph_search( # type: ignore + async for search_result in await self.database_provider.graph_handler.graph_search( # type: ignore message, search_type=search_type, search_type_limits=kg_search_settings.local_search_limits[ diff --git a/py/shared/abstractions/graph.py b/py/shared/abstractions/graph.py index 9575149dc..6913047f3 100644 --- a/py/shared/abstractions/graph.py +++ b/py/shared/abstractions/graph.py @@ -25,14 +25,6 @@ class Identified: """Human readable ID used to refer to this community in prompts or texts displayed to users, such as in a report text (optional).""" -@dataclass -class Named(Identified): - """A protocol for an item with a name/title.""" - - title: str - """The name/title of the item.""" - - class EntityType(R2RSerializable): id: str name: str @@ -98,37 +90,17 @@ class Relationship(R2RSerializable): sid: Optional[int] = None # serial ID subject: Optional[str] = None - """The source entity name.""" - predicate: Optional[str] = None - """A description of the relationship (optional).""" - subject_id: Optional[UUID] = None - """The source entity ID (optional).""" - object_id: Optional[UUID] = None - """The target entity ID (optional).""" - object: Optional[str] = None - """The target entity name.""" - weight: float | None = 1.0 - """The edge weight.""" - description: str | None = None - """A description of the relationship (optional).""" - + description_embedding: list[float] | None = None predicate_embedding: list[float] | None = None - """The semantic embedding for the relationship description 
(optional).""" - chunk_ids: list[UUID] = [] - """List of text unit IDs in which the relationship appears (optional).""" - document_id: Optional[UUID] = None - """Document ID in which the relationship appears (optional).""" - attributes: dict[str, Any] | str = {} - """Additional attributes associated with the relationship (optional). To be included in the search prompt""" def __init__(self, **kwargs): super().__init__(**kwargs) @@ -138,39 +110,9 @@ def __init__(self, **kwargs): except json.JSONDecodeError: self.attributes = self.attributes - @classmethod - def from_dict( # type: ignore - cls, - d: dict[str, Any], - id_key: str = "id", - short_id_key: str = "short_id", - source_key: str = "subject", - target_key: str = "object", - predicate_key: str = "predicate", - description_key: str = "description", - weight_key: str = "weight", - chunk_ids_key: str = "chunk_ids", - document_id_key: str = "document_id", - attributes_key: str = "attributes", - ) -> "Relationship": - """Create a new relationship from the dict data.""" - - return Relationship( - id=d[id_key], - short_id=d.get(short_id_key), - subject=d[source_key], - object=d[target_key], - predicate=d.get(predicate_key), - description=d.get(description_key), - weight=d.get(weight_key, 1.0), - chunk_ids=d.get(chunk_ids_key), - document_id=d.get(document_id_key), - attributes=d.get(attributes_key, {}), - ) - @dataclass -class CommunityInfo(BaseModel): +class CommunityInfo(R2RSerializable): """A protocol for a community in the system.""" node: str @@ -184,122 +126,30 @@ class CommunityInfo(BaseModel): def __init__(self, **kwargs): super().__init__(**kwargs) - @classmethod - def from_dict(cls, d: dict[str, Any]) -> "CommunityInfo": - return CommunityInfo( - node=d["node"], - cluster=d["cluster"], - parent_cluster=d["parent_cluster"], - level=d["level"], - is_final_cluster=d["is_final_cluster"], - relationship_ids=d["relationship_ids"], - collection_id=d["collection_id"], - ) - @dataclass -class Community(BaseModel): - - id: Optional[Union[int, UUID]] = None - - """Defines an LLM-generated summary report of a community.""" +class Community(R2RSerializable): community_number: int - """The ID of the community this report is associated with.""" - level: int - """The level of the community this report is associated with.""" - collection_id: uuid.UUID - """The ID of the collection this report is associated with.""" - name: str = "" - """Name of the report.""" - summary: str = "" - """Summary of the report.""" - findings: list[str] = [] - """Findings of the report.""" - + id: Optional[Union[int, UUID]] = None rating: float | None = None - """Rating of the report.""" - rating_explanation: str | None = None - """Explanation of the rating.""" - embedding: list[float] | None = None - """Embedding of summary and findings.""" - attributes: dict[str, Any] | None = None - """A dictionary of additional attributes associated with the report (optional).""" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - if isinstance(self.attributes, str): - self.attributes = json.loads(self.attributes) - - @classmethod - def from_dict( - cls, - d: dict[str, Any], - id_key: str = "id", - title_key: str = "title", - community_number_key: str = "community_number", - short_id_key: str = "short_id", - summary_key: str = "summary", - findings_key: str = "findings", - rank_key: str = "rank", - summary_embedding_key: str = "summary_embedding", - embedding_key: str = "embedding", - attributes_key: str = "attributes", - ) -> "Community": - """Create a new community 
report from the dict data.""" - return Community( - id=d[id_key], - title=d[title_key], - community_number=d[community_number_key], - short_id=d.get(short_id_key), - summary=d[summary_key], - findings=d[findings_key], - rank=d[rank_key], - summary_embedding=d.get(summary_embedding_key), - embedding=d.get(embedding_key), - attributes=d.get(attributes_key), - ) - - -class Graph(BaseModel): - """A graph in the system.""" - - id: uuid.UUID - status: str - created_at: datetime - updated_at: datetime - document_ids: list[uuid.UUID] = [] - collection_ids: list[uuid.UUID] = [] - attributes: dict[str, Any] = {} def __init__(self, **kwargs): super().__init__(**kwargs) if isinstance(self.attributes, str): self.attributes = json.loads(self.attributes) - @classmethod - def from_dict(cls, d: dict[str, Any]) -> "Graph": - return Graph( - id=d["id"], - status=d["status"], - created_at=d["created_at"], - updated_at=d["updated_at"], - document_ids=d["document_ids"], - collection_ids=d["collection_ids"], - attributes=d["attributes"], - ) - class KGExtraction(R2RSerializable): - """An extraction from a document that is part of a knowledge graph.""" + """A protocol for a knowledge graph extraction.""" chunk_ids: list[uuid.UUID] document_id: uuid.UUID @@ -322,17 +172,3 @@ class Graph(R2RSerializable): def __init__(self, **kwargs): super().__init__(**kwargs) - - @classmethod - def from_dict(cls, d: dict[str, Any]) -> "Graph": - return Graph( - id=d["id"], - name=d["name"], - description=d["description"], - document_ids=d["document_ids"], - collection_ids=d["collection_ids"], - statistics=d["statistics"], - created_at=d["created_at"], - updated_at=d["updated_at"], - status=d["status"], - ) diff --git a/py/tests/core/providers/kg/test_kg_logic.py b/py/tests/core/providers/kg/test_kg_logic.py index bc129609c..072966385 100644 --- a/py/tests/core/providers/kg/test_kg_logic.py +++ b/py/tests/core/providers/kg/test_kg_logic.py @@ -229,7 +229,7 @@ async def test_add_entities_raw( async def test_add_entities( postgres_db_provider, entities_list, collection_id ): - await postgres_db_provider.add_entities( + await postgres_db_provider.graph_handler.add_entities( entities_list, table_name="document_entity" ) entities = await postgres_db_provider.get_entities( @@ -245,7 +245,7 @@ async def test_add_entities( async def test_add_relationships( postgres_db_provider, relationships_raw_list, collection_id ): - await postgres_db_provider.add_relationships( + await postgres_db_provider.graph_handler.add_relationships( relationships_raw_list, table_name="chunk_relationship" ) relationships = await postgres_db_provider.get_relationships(collection_id) @@ -269,7 +269,9 @@ async def test_get_entity_map( assert entity_map["Entity1"]["entities"][0].name == "Entity1" assert entity_map["Entity2"]["entities"][0].name == "Entity2" - await postgres_db_provider.add_relationships(relationships_raw_list) + await postgres_db_provider.graph_handler.add_relationships( + relationships_raw_list + ) entity_map = await postgres_db_provider.get_entity_map(0, 2, document_id) assert entity_map["Entity1"]["entities"][0].name == "Entity1" assert entity_map["Entity2"]["entities"][0].name == "Entity2" @@ -310,7 +312,9 @@ async def test_upsert_embeddings( async def test_get_all_relationships( postgres_db_provider, collection_id, relationships_raw_list ): - await postgres_db_provider.add_relationships(relationships_raw_list) + await postgres_db_provider.graph_handler.add_relationships( + relationships_raw_list + ) relationships = await 
postgres_db_provider.get_relationships(collection_id) assert relationships["relationships"][0].subject == "Entity1" assert relationships["relationships"][1].subject == "Entity2" @@ -351,7 +355,7 @@ async def test_perform_graph_clustering( await postgres_db_provider.add_entities( entities_list, table_name="document_entity" ) - await postgres_db_provider.add_relationships( + await postgres_db_provider.graph_handler.add_relationships( relationships_raw_list, table_name="chunk_relationship" ) @@ -374,7 +378,7 @@ async def test_get_community_details( await postgres_db_provider.add_entities( entities_list, table_name="document_entity" ) - await postgres_db_provider.add_relationships( + await postgres_db_provider.graph_handler.add_relationships( relationships_raw_list, table_name="chunk_relationship" ) await postgres_db_provider.add_community_info(community_table_info) From 09189442edc5521f6475f177b8a51aa87500dac6 Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 15 Nov 2024 12:15:20 -0800 Subject: [PATCH 20/21] up --- js/sdk/src/r2rClient.ts | 4 ++-- py/cli/commands/v2/kg.py | 4 ++-- py/core/main/api/v3/graph_router.py | 6 +++--- py/core/providers/database/kg.py | 5 ++--- py/sdk/v2/kg.py | 2 +- 5 files changed, 10 insertions(+), 11 deletions(-) diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 007e6a561..7b04a48fe 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -1705,8 +1705,8 @@ export class r2rClient extends BaseClient { * @param entity_level The level of entity to filter by. * @param relationship_ids Relationship IDs to filter by. */ - @feature("getRelationships") - async getRelationships( + @feature("getTriples") + async getTriples( collection_id?: string, offset?: number, limit?: number, diff --git a/py/cli/commands/v2/kg.py b/py/cli/commands/v2/kg.py index 2f2054eef..c213603df 100644 --- a/py/cli/commands/v2/kg.py +++ b/py/cli/commands/v2/kg.py @@ -229,7 +229,7 @@ async def get_entities( @click.option( "--collection-id", required=True, - help="Collection ID to retrieve relationships from.", + help="Collection ID to retrieve triples from.", ) @click.option( "--offset", @@ -254,7 +254,7 @@ async def get_entities( help="Entity names to filter by.", ) @pass_context -async def get_relationships( +async def get_triples( ctx, collection_id, offset, limit, relationship_ids, entity_names ): """ diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 2a82db697..85bac9a95 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -1,6 +1,6 @@ import logging import textwrap -from typing import Optional, Union +from typing import Optional from uuid import UUID from fastapi import Body, Depends, Path, Query @@ -689,8 +689,8 @@ async def list_relationships( limit: int = Query( 100, ge=0, - le=20_000, - description="The maximum number of relationships to retrieve, up to 20,000.", + le=1000, + description="The maximum number of relationships to retrieve, up to 1000.", ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> PaginatedResultsWrapper[list[Relationship]]: diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 25b61f732..184747449 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -149,9 +149,8 @@ async def create(self, entities: list[Entity]) -> None: if entity_level is None: raise ValueError("Entity level is not set") - for entity in entities: - if entity.level != entity_level: - raise 
ValueError("All entities must be of the same level") + if not all(entity.level == entity_level for entity in entities): + raise ValueError("All entities must be of the same level") return await _add_objects( objects=[entity.__dict__ for entity in entities], diff --git a/py/sdk/v2/kg.py b/py/sdk/v2/kg.py index ea809b036..993e96e89 100644 --- a/py/sdk/v2/kg.py +++ b/py/sdk/v2/kg.py @@ -103,7 +103,7 @@ async def get_entities( return await self._make_request("GET", "entities", params=params) # type: ignore - async def get_relationships( + async def get_triples( self, collection_id: Optional[Union[UUID, str]] = None, entity_names: Optional[list[str]] = None, From 7636c9d0ded04ae06a6df0de11171e7b3c18d70c Mon Sep 17 00:00:00 2001 From: Shreyas Pimpalgaonkar Date: Fri, 15 Nov 2024 12:44:54 -0800 Subject: [PATCH 21/21] up --- py/core/main/api/v3/graph_router.py | 833 ++++++++++------------ py/core/main/services/kg_service.py | 9 +- py/core/pipes/kg/community_summary.py | 8 +- py/core/pipes/kg/deduplication_summary.py | 4 +- py/core/pipes/kg/storage.py | 4 +- py/core/providers/database/kg.py | 110 ++- 6 files changed, 432 insertions(+), 536 deletions(-) diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 85bac9a95..6cf826aef 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -90,7 +90,7 @@ async def create_graph( ), run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGCreationResponse: + ) -> WrappedKGCreationResponse: # type: ignore """ Creates a new knowledge graph by extracting entities and relationships from a document. The graph creation process involves: @@ -98,7 +98,7 @@ async def create_graph( 2. Extracting entities and relationships using LLMs or NER """ - settings = settings.dict() if settings else None + settings = settings.dict() if settings else None # type: ignore if not auth_user.is_superuser: logger.warning("Implement permission checks here.") @@ -112,7 +112,7 @@ async def create_graph( ) if settings: - server_kg_creation_settings = update_settings_from_dict( + server_kg_creation_settings = update_settings_from_dict( # type: ignore server_kg_creation_settings, settings ) @@ -138,8 +138,8 @@ async def create_graph( from core.main.orchestration import simple_kg_factory logger.info("Running create-graph without orchestration.") - simple_kg = simple_kg_factory(self.service) - await simple_kg["create-graph"](workflow_input) + simple_kg = simple_kg_factory(self.services["kg"]) + await simple_kg["create-graph"](workflow_input) # type: ignore return { "message": "Graph created successfully.", "task_id": None, @@ -693,7 +693,7 @@ async def list_relationships( description="The maximum number of relationships to retrieve, up to 1000.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> PaginatedResultsWrapper[list[Relationship]]: + ) -> WrappedKGRelationshipsResponse: # type: ignore if not auth_user.is_superuser: raise R2RException( "Only superusers can access this endpoint.", 403 @@ -711,7 +711,7 @@ async def list_relationships( limit=limit, ) - return relationships, { + return relationships, { # type: ignore "total_entries": count, } @@ -795,7 +795,7 @@ async def create_relationships( "Only superusers can access this endpoint.", 403 ) - return { + return { # type: ignore "message": "Relationships created successfully.", "count": await self.services["kg"].create_relationships_v3( id=id, @@ -1211,86 +1211,7 @@ async def 
get_graph_status( "Only superusers can view graph status", 403 ) - # status = await self.services["kg"].get_graph_status(collection_id) - # return status # type: ignore - - # @self.router.post( - # "/graphs/{collection_id}/enrich", - # summary="Enrich an existing graph", - # openapi_extra={ - # "x-codeSamples": [ - # { - # "lang": "Python", - # "source": textwrap.dedent( - # """ - # from r2r import R2RClient - - # client = R2RClient("http://localhost:7272") - # # when using auth, do client.login(...) - - # result = client.graphs.enrich( - # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - # settings={ - # "community_detection": { - # "algorithm": "louvain", - # "resolution": 1.0 - # }, - # "embedding_model": "sentence-transformers/all-MiniLM-L6-v2" - # } - # )""" - # ), - # }, - # { - # "lang": "cURL", - # "source": textwrap.dedent( - # """ - # curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/enrich" \\ - # -H "Content-Type: application/json" \\ - # -H "Authorization: Bearer YOUR_API_KEY" \\ - # -d '{ - # "settings": { - # "community_detection": { - # "algorithm": "louvain", - # "resolution": 1.0 - # }, - # "embedding_model": "sentence-transformers/all-MiniLM-L6-v2" - # } - # }'""" - # ), - # }, - # ] - # }, - # ) - # @self.base_endpoint - # async def enrich_graph( - # collection_id: UUID = Path(...), - # settings: Optional[dict] = Body(None), - # run_with_orchestration: bool = Query(True), - # auth_user=Depends(self.providers.auth.auth_wrapper), - # ) -> ResultsWrapper[WrappedKGEnrichmentResponse]: - # """Enriches an existing graph with additional information and creates communities.""" - # if not auth_user.is_superuser: - # raise R2RException("Only superusers can enrich graphs", 403) - - # server_settings = self.providers.database.config.kg_enrichment_settings - # if settings: - # server_settings = update_settings_from_dict(server_settings, settings) - - # workflow_input = { - # "collection_id": str(collection_id), - # "kg_enrichment_settings": server_settings.model_dump_json(), - # "user": auth_user.model_dump_json(), - # } - - # if run_with_orchestration: - # return await self.orchestration_provider.run_workflow( - # "enrich-graph", {"request": workflow_input}, {} - # ) - # else: - # from core.main.orchestration import simple_kg_factory - # simple_kg = simple_kg_factory(self.services["kg"]) - # await simple_kg["enrich-graph"](workflow_input) - # return {"message": "Graph enriched successfully.", "task_id": None} + raise NotImplementedError("Not implemented") @self.router.delete( "/graphs/{collection_id}", @@ -1413,406 +1334,378 @@ async def deduplicate_entities( "task_id": None, } - @self.base_endpoint - async def create_communities( - collection_id: UUID = Path(...), - settings: Optional[dict] = Body(None), - run_type: Optional[KGRunType] = Body( - default=None, - description="Run type for the graph creation process.", - ), - run_with_orchestration: bool = Query(True), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGEnrichmentResponse: - """Creates communities in the graph by analyzing entity relationships and similarities. - - Communities are created by: - 1. Builds similarity graph between entities - 2. Applies community detection algorithm (e.g. Leiden) - 3. Creates hierarchical community levels - 4. 
Generates summaries and insights for each community - """ - if not auth_user.is_superuser: - raise R2RException( - "Only superusers can create communities", 403 - ) - - # Apply runtime settings overrides - server_kg_enrichment_settings = ( - self.providers.database.config.kg_enrichment_settings - ) - if settings: - server_kg_enrichment_settings = update_settings_from_dict( - server_kg_enrichment_settings, settings - ) - - workflow_input = { - "collection_id": str(collection_id), - "kg_enrichment_settings": server_kg_enrichment_settings.model_dump_json(), - "user": auth_user.model_dump_json(), - } - - if not run_type: - run_type = KGRunType.ESTIMATE - - # If the run type is estimate, return an estimate of the enrichment cost - if run_type is KGRunType.ESTIMATE: - return await self.services["kg"].get_enrichment_estimate( - collection_id, server_kg_enrichment_settings - ) - - else: - if run_with_orchestration: - return await self.orchestration_provider.run_workflow( # type: ignore - "enrich-graph", {"request": workflow_input}, {} - ) - else: - from core.main.orchestration import simple_kg_factory - - simple_kg = simple_kg_factory(self.services["kg"]) - await simple_kg["enrich-graph"](workflow_input) - return { # type: ignore - "message": "Communities created successfully.", - "task_id": None, - } - - @self.router.post( - "/graphs/{collection_id}/communities/{community_id}", - summary="Update community", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient - - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) + # @self.base_endpoint + # async def create_communities( + # collection_id: UUID = Path(...), + # settings: Optional[dict] = Body(None), + # run_type: Optional[KGRunType] = Body( + # default=None, + # description="Run type for the graph creation process.", + # ), + # run_with_orchestration: bool = Query(True), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> WrappedKGEnrichmentResponse: + # """Creates communities in the graph by analyzing entity relationships and similarities. + + # Communities are created by: + # 1. Builds similarity graph between entities + # 2. Applies community detection algorithm (e.g. Leiden) + # 3. Creates hierarchical community levels + # 4. 
Generates summaries and insights for each community + # """ + # if not auth_user.is_superuser: + # raise R2RException( + # "Only superusers can create communities", 403 + # ) - result = client.graphs.update_community( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - community_id="5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8", - community_update={ - "metadata": { - "topic": "Technology", - "description": "Tech companies and products" - } - } - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities/5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "metadata": { - "topic": "Technology", - "description": "Tech companies and products" - } - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def update_community( - collection_id: UUID = Path(...), - community_id: UUID = Path(...), - community_update: dict = Body(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Community]: - """Updates a community's metadata.""" - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can update communities", 403 - # ) + # # Apply runtime settings overrides + # server_kg_enrichment_settings = ( + # self.providers.database.config.kg_enrichment_settings + # ) + # if settings: + # server_kg_enrichment_settings = update_settings_from_dict( + # server_kg_enrichment_settings, settings + # ) - # updated_community = await self.services["kg"].update_community( - # collection_id, community_id, community_update - # ) - # return updated_community # type: ignore + # workflow_input = { + # "collection_id": str(collection_id), + # "kg_enrichment_settings": server_kg_enrichment_settings.model_dump_json(), + # "user": auth_user.model_dump_json(), + # } - @self.router.get( - "/graphs/{collection_id}/communities", - summary="List communities", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient + # if not run_type: + # run_type = KGRunType.ESTIMATE - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) + # # If the run type is estimate, return an estimate of the enrichment cost + # if run_type is KGRunType.ESTIMATE: + # return await self.services["kg"].get_enrichment_estimate( + # collection_id, server_kg_enrichment_settings + # ) - result = client.graphs.communities.list( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - level=1, - offset=0, - limit=100 - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities?\\ - level=1&offset=0&limit=100" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def list_communities( - collection_id: UUID = Path(...), - level: Optional[int] = Query(None), - offset: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ( - WrappedKGCommunitiesResponse - ): # PaginatedResultsWrapper[list[Community]]: - """Lists communities in the graph with optional filtering and pagination. 
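            A minimal paging sketch against this endpoint, assuming the v3 SDK
            surface used in the code samples above (client.graphs.communities.list
            returning a dict with a "results" list); the method name, response
            shape, and UUID are illustrative, not fixed by this patch:

                from r2r import R2RClient

                client = R2RClient("http://localhost:7272")
                offset, limit = 0, 100
                while True:
                    page = client.graphs.communities.list(
                        collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7",
                        offset=offset,
                        limit=limit,
                    )
                    results = page["results"]
                    if not results:
                        break
                    for community in results:
                        print(community["name"])
                    offset += limit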
- - Each community represents a group of related entities with: - - Community number and hierarchical level - - Member entities and relationships - - Generated name and summary - - Key findings and insights - - Impact rating and explanation - """ - communities = await self.services["kg"].list_communities( - collection_id, levels, community_numbers, offset, limit - ) - return communities # type: ignore + # else: + # if run_with_orchestration: + # return await self.orchestration_provider.run_workflow( # type: ignore + # "enrich-graph", {"request": workflow_input}, {} + # ) + # else: + # from core.main.orchestration import simple_kg_factory + + # simple_kg = simple_kg_factory(self.services["kg"]) + # await simple_kg["enrich-graph"](workflow_input) + # return { # type: ignore + # "message": "Communities created successfully.", + # "task_id": None, + # } - @self.router.get( - "/graphs/{collection_id}/communities/{community_id}", - summary="Get community details", - ) - @self.base_endpoint - async def get_community( - collection_id: UUID = Path(...), - community_id: UUID = Path(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[Community]: - """Retrieves details of a specific community.""" - raise NotImplementedError("Not implemented") - # community = await self.services["kg"].get_community( - # collection_id, community_id - # ) - # if not community: - # raise R2RException("Community not found", 404) - # return community # type: ignore + # @self.router.post( + # "/graphs/{collection_id}/communities/{community_id}", + # summary="Update community", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": textwrap.dedent( + # """ + # from r2r import R2RClient - @self.router.delete( - "/graphs/{collection_id}/communities", - summary="Delete all communities", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient + # client = R2RClient("http://localhost:7272") + # # when using auth, do client.login(...) - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) 
+ # result = client.graphs.update_community( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + # community_id="5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8", + # community_update={ + # "metadata": { + # "topic": "Technology", + # "description": "Tech companies and products" + # } + # } + # )""" + # ), + # }, + # { + # "lang": "cURL", + # "source": textwrap.dedent( + # """ + # curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities/5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" \\ + # -H "Content-Type: application/json" \\ + # -H "Authorization: Bearer YOUR_API_KEY" \\ + # -d '{ + # "metadata": { + # "topic": "Technology", + # "description": "Tech companies and products" + # } + # }'""" + # ), + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def update_community( + # collection_id: UUID = Path(...), + # community_id: UUID = Path(...), + # community_update: dict = Body(...), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> ResultsWrapper[Community]: + # """Updates a community's metadata.""" - # Delete all communities - result = client.graphs.communities.delete( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" - ) + # raise NotImplementedError("Not implemented") - # Delete specific level - result = client.graphs.communities.delete( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - level=1 - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - # Delete all communities - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities" \\ - -H "Authorization: Bearer YOUR_API_KEY" + # @self.router.get( + # "/graphs/{collection_id}/communities", + # summary="List communities", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": textwrap.dedent( + # """ + # from r2r import R2RClient - # Delete specific level - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities?level=1" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_communities( - collection_id: UUID = Path(...), - level: Optional[int] = Query( - None, - description="Specific community level to delete. If not provided, all levels will be deleted.", - ), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[dict]: - """ - Deletes communities from the graph. Can delete all communities or a specific level. - This is useful when you want to recreate communities with different parameters. - """ - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can delete communities", 403 - # ) + # client = R2RClient("http://localhost:7272") + # # when using auth, do client.login(...) 
- # await self.services["kg"].delete_communities(collection_id, level) + # result = client.graphs.communities.list( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + # level=1, + # offset=0, + # limit=100 + # )""" + # ), + # }, + # { + # "lang": "cURL", + # "source": textwrap.dedent( + # """ + # curl -X GET "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities?\\ + # level=1&offset=0&limit=100" \\ + # -H "Authorization: Bearer YOUR_API_KEY" """ + # ), + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def list_communities( + # collection_id: UUID = Path(...), + # offset: int = Query(0, ge=0), + # limit: int = Query(100, ge=1, le=1000), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> ( + # WrappedKGCommunitiesResponse + # ): # PaginatedResultsWrapper[list[Community]]: + # """Lists communities in the graph with optional filtering and pagination. + + # Each community represents a group of related entities with: + # - Community number and hierarchical level + # - Member entities and relationships + # - Generated name and summary + # - Key findings and insights + # - Impact rating and explanation + # """ + # communities = await self.services["kg"].list_communities( + # collection_id, offset, limit + # ) + # return communities # type: ignore + + # @self.router.delete( + # "/graphs/{collection_id}/communities", + # summary="Delete communities", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": textwrap.dedent( + # """ + # from r2r import R2RClient - # if level is not None: - # return { # type: ignore - # "message": f"Communities at level {level} deleted successfully" - # } - # return {"message": "All communities deleted successfully"} # type: ignore + # client = R2RClient("http://localhost:7272") + # # when using auth, do client.login(...) - @self.router.delete( - "/graphs/{collection_id}/communities/{community_id}", - summary="Delete a specific community", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient + # # Delete all communities + # result = client.graphs.communities.delete( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7" + # ) - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) + # # Delete specific level + # result = client.graphs.communities.delete( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + # level=1 + # )""" + # ), + # }, + # { + # "lang": "cURL", + # "source": textwrap.dedent( + # """ + # # Delete all communities + # curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities" \\ + # -H "Authorization: Bearer YOUR_API_KEY" - result = client.graphs.delete_community( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - community_id="5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities/5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" \\ - -H "Authorization: Bearer YOUR_API_KEY" """ - ), - }, - ] - }, - ) - @self.base_endpoint - async def delete_community( - collection_id: UUID = Path(...), - community_id: UUID = Path(...), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGDeletionResponse: - """ - Deletes a specific community by ID. - This operation will not affect other communities or the underlying entities. 
- """ - raise NotImplementedError("Not implemented") - # if not auth_user.is_superuser: - # raise R2RException( - # "Only superusers can delete communities", 403 - # ) - - # # First check if community exists - # community = await self.services["kg"].get_community( - # collection_id, community_id - # ) - # if not community: - # raise R2RException("Community not found", 404) - - await self.services["kg"].delete_community( - collection_id, community_id - ) - return True # type: ignore + # # Delete specific level + # curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities?level=1" \\ + # -H "Authorization: Bearer YOUR_API_KEY" """ + # ), + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def delete_communities( + # collection_id: UUID = Path(...), + # level: Optional[int] = Query( + # None, + # description="Specific community level to delete. If not provided, all levels will be deleted.", + # ), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> ResultsWrapper[dict]: + # """ + # Deletes communities from the graph. Can delete all communities or a specific level. + # This is useful when you want to recreate communities with different parameters. + # """ + # raise NotImplementedError("Not implemented") + # # if not auth_user.is_superuser: + # # raise R2RException( + # # "Only superusers can delete communities", 403 + # # ) + + # # await self.services["kg"].delete_communities(collection_id, level) + + # # if level is not None: + # # return { # type: ignore + # # "message": f"Communities at level {level} deleted successfully" + # # } + # # return {"message": "All communities deleted successfully"} # type: ignore + + # @self.router.delete( + # "/graphs/{collection_id}/communities/{community_id}", + # summary="Delete a specific community", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": textwrap.dedent( + # """ + # from r2r import R2RClient - @self.router.post( - "/graphs/{collection_id}/tune-prompt", - summary="Tune a graph-related prompt", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": textwrap.dedent( - """ - from r2r import R2RClient + # client = R2RClient("http://localhost:7272") + # # when using auth, do client.login(...) - client = R2RClient("http://localhost:7272") - # when using auth, do client.login(...) + # result = client.graphs.delete_community( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + # community_id="5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" + # )""" + # ), + # }, + # { + # "lang": "cURL", + # "source": textwrap.dedent( + # """ + # curl -X DELETE "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/communities/5xyz789a-bc12-3def-4ghi-jk5lm6no7pq8" \\ + # -H "Authorization: Bearer YOUR_API_KEY" """ + # ), + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def delete_community( + # collection_id: UUID = Path(...), + # community_id: UUID = Path(...), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> WrappedKGDeletionResponse: + # """ + # Deletes a specific community by ID. + # This operation will not affect other communities or the underlying entities. 
+ # """ + # raise NotImplementedError("Not implemented") + # # if not auth_user.is_superuser: + # # raise R2RException( + # # "Only superusers can delete communities", 403 + # # ) + + # # # First check if community exists + # # community = await self.services["kg"].get_community( + # # collection_id, community_id + # # ) + # # if not community: + # # raise R2RException("Community not found", 404) + + # await self.services["kg"].delete_community( + # collection_id, community_id + # ) + # return True # type: ignore - result = client.graphs.tune_prompt( - collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", - prompt_name="graphrag_relationships_extraction_few_shot", - documents_limit=100, - chunks_limit=1000 - )""" - ), - }, - { - "lang": "cURL", - "source": textwrap.dedent( - """ - curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/tune-prompt" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "prompt_name": "graphrag_relationships_extraction_few_shot", - "documents_limit": 100, - "chunks_limit": 1000 - }'""" - ), - }, - ] - }, - ) - @self.base_endpoint - async def tune_prompt( - collection_id: UUID = Path(...), - prompt_name: str = Body( - ..., - description="The prompt to tune. Valid options: graphrag_relationships_extraction_few_shot, graphrag_entity_description, graphrag_communities", - ), - documents_offset: int = Body(0, ge=0), - documents_limit: int = Body(100, ge=1), - chunks_offset: int = Body(0, ge=0), - chunks_limit: int = Body(100, ge=1), - auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedKGTunePromptResponse: - """Tunes a graph operation prompt using collection data. + # @self.router.post( + # "/graphs/{collection_id}/tune-prompt", + # summary="Tune a graph-related prompt", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": textwrap.dedent( + # """ + # from r2r import R2RClient - Uses sample documents and chunks from the collection to tune prompts for: - - Entity and relationship extraction - - Entity description generation - - Community report generation - """ - if not auth_user.is_superuser: - raise R2RException("Only superusers can tune prompts", 403) - - tuned_prompt = await self.services["kg"].tune_prompt( - prompt_name=prompt_name, - collection_id=collection_id, - documents_offset=documents_offset, - documents_limit=documents_limit, - chunks_offset=chunks_offset, - chunks_limit=chunks_limit, - ) + # client = R2RClient("http://localhost:7272") + # # when using auth, do client.login(...) - return tuned_prompt # type: ignore + # result = client.graphs.tune_prompt( + # collection_id="d09dedb1-b2ab-48a5-b950-6e1f464d83e7", + # prompt_name="graphrag_relationships_extraction_few_shot", + # documents_limit=100, + # chunks_limit=1000 + # )""" + # ), + # }, + # { + # "lang": "cURL", + # "source": textwrap.dedent( + # """ + # curl -X POST "https://api.example.com/v3/graphs/d09dedb1-b2ab-48a5-b950-6e1f464d83e7/tune-prompt" \\ + # -H "Content-Type: application/json" \\ + # -H "Authorization: Bearer YOUR_API_KEY" \\ + # -d '{ + # "prompt_name": "graphrag_relationships_extraction_few_shot", + # "documents_limit": 100, + # "chunks_limit": 1000 + # }'""" + # ), + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def tune_prompt( + # collection_id: UUID = Path(...), + # prompt_name: str = Body( + # ..., + # description="The prompt to tune. 
Valid options: graphrag_relationships_extraction_few_shot, graphrag_entity_description, graphrag_communities", + # ), + # documents_offset: int = Body(0, ge=0), + # documents_limit: int = Body(100, ge=1), + # chunks_offset: int = Body(0, ge=0), + # chunks_limit: int = Body(100, ge=1), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ) -> WrappedKGTunePromptResponse: + # """Tunes a graph operation prompt using collection data. + + # Uses sample documents and chunks from the collection to tune prompts for: + # - Entity and relationship extraction + # - Entity description generation + # - Community report generation + # """ + # if not auth_user.is_superuser: + # raise R2RException("Only superusers can tune prompts", 403) + + # tuned_prompt = await self.services["kg"].tune_prompt( + # prompt_name=prompt_name, + # collection_id=collection_id, + # documents_offset=documents_offset, + # documents_limit=documents_limit, + # chunks_offset=chunks_offset, + # chunks_limit=chunks_limit, + # ) + + # return tuned_prompt # type: ignore diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index 775d44f4b..6b9075975 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -309,11 +309,14 @@ async def delete_community_v3( async def list_communities_v3( self, id: UUID, - level: EntityLevel, + offset: int, + limit: int, **kwargs, ): return await self.providers.database.graph_handler.communities.get( - id, level + collection_id=id, + offset=offset, + limit=limit, ) # TODO: deprecate this @@ -442,7 +445,7 @@ async def get_graph_status( collection_id: UUID, **kwargs, ): - return await self.providers.database.get_graph_status(collection_id) + raise NotImplementedError("Not implemented") @telemetry_event("kg_clustering") async def kg_clustering( diff --git a/py/core/pipes/kg/community_summary.py b/py/core/pipes/kg/community_summary.py index c303e3d89..ef8b6796a 100644 --- a/py/core/pipes/kg/community_summary.py +++ b/py/core/pipes/kg/community_summary.py @@ -59,16 +59,16 @@ async def community_summary_prompt( entity_map: dict[str, dict[str, list[Any]]] = {} for entity in entities: if not entity.name in entity_map: - entity_map[entity.name] = {"entities": [], "relationships": []} - entity_map[entity.name]["entities"].append(entity) + entity_map[entity.name] = {"entities": [], "relationships": []} # type: ignore + entity_map[entity.name]["entities"].append(entity) # type: ignore for relationship in relationships: if not relationship.subject in entity_map: - entity_map[relationship.subject] = { + entity_map[relationship.subject] = { # type: ignore "entities": [], "relationships": [], } - entity_map[relationship.subject]["relationships"].append( + entity_map[relationship.subject]["relationships"].append( # type: ignore relationship ) diff --git a/py/core/pipes/kg/deduplication_summary.py b/py/core/pipes/kg/deduplication_summary.py index ba41c5c47..dc208fb34 100644 --- a/py/core/pipes/kg/deduplication_summary.py +++ b/py/core/pipes/kg/deduplication_summary.py @@ -120,7 +120,9 @@ async def _prepare_and_upsert_entities( f"Upserting {len(entities_batch)} entities for collection {collection_id}" ) - await self.database_provider.update_entity_descriptions(entities_batch) + await self.database_provider.graph_handler.update_entity_descriptions( + entities_batch + ) logger.info( f"Upserted {len(entities_batch)} entities for collection {collection_id}" diff --git a/py/core/pipes/kg/storage.py b/py/core/pipes/kg/storage.py index 
ad672ca79..39d4449fb 100644 --- a/py/core/pipes/kg/storage.py +++ b/py/core/pipes/kg/storage.py @@ -5,12 +5,12 @@ from core.base import ( AsyncState, - DatabaseProvider, KGExtraction, R2RDocumentProcessingError, ) from core.base.pipes.base_pipe import AsyncPipe from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider +from core.providers.database.postgres import PostgresDBProvider logger = logging.getLogger() @@ -22,7 +22,7 @@ class Input(AsyncPipe.Input): def __init__( self, - database_provider: DatabaseProvider, + database_provider: PostgresDBProvider, config: AsyncPipe.PipeConfig, logging_provider: SqlitePersistentLoggingProvider, storage_batch_size: int = 1, diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 184747449..df2f12784 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -1,7 +1,7 @@ import json import logging import time -from typing import Any, AsyncGenerator, Optional, Tuple +from typing import Any, AsyncGenerator, Optional, Tuple, Union from uuid import UUID import asyncpg @@ -64,13 +64,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: - dimension: Dimension size for vector embeddings - quantization_type: Type of vector quantization to use """ - self.dimension = kwargs.get("dimension") - self.quantization_type = kwargs.get("quantization_type") - super().__init__( - project_name=kwargs.get("project_name"), - connection_manager=kwargs.get("connection_manager"), - ) + # The signature to this class isn't finalized yet, so we need to use type ignore + self.dimension: int = kwargs.get("dimension") # type: ignore + self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type", VectorQuantizationType.FP32) # type: ignore + + self.project_name: str = kwargs.get("project_name") # type: ignore + self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore async def create_tables(self) -> None: """Create the necessary database tables for storing entities. @@ -168,7 +168,7 @@ async def get( attributes: Optional[list[str]] = None, offset: int = 0, limit: int = -1, - ) -> list[Entity]: + ): """Retrieve entities from the database based on various filters. 
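        A hedged usage sketch (the full signature is abbreviated in this hunk;
        the (rows, count) return shape is assumed by analogy with the
        relationship and community handlers below, and the handler wiring,
        filter argument, and IDs are illustrative only):

            entities, count = await graph_handler.entities.get(
                id=document_id,
                entity_names=["Aristotle"],
                offset=0,
                limit=50,
            )
            for entity in entities:
                print(entity.name, entity.description)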
Args: @@ -214,9 +214,6 @@ async def get( params.extend([offset, limit]) - print(QUERY) - print(params) - output = await self.connection_manager.fetch_query(QUERY, params) if attributes: @@ -244,10 +241,10 @@ async def update(self, entity: Entity) -> None: Raises: R2RException: If the entity does not exist in the database """ - table_name = entity.level.value + "_entity" + table_name = entity.level.value + "_entity" # type: ignore filter = "id = $1" - params = [entity.id] + params: list[Any] = [entity.id] if entity.level == EntityLevel.CHUNK: filter += " AND chunk_ids = ANY($2)" params.append(entity.chunk_ids) @@ -292,9 +289,9 @@ async def delete(self, entity: Entity) -> None: entity_id: UUID of the entity to delete level: Level of the entity (chunk, document, or collection) """ - table_name = entity.level.value + "_entity" + table_name = entity.level.value + "_entity" # type: ignore return await _delete_object( - object_id=entity.id, + object_id=entity.id, # type: ignore full_table_name=self._get_table_name(table_name), connection_manager=self.connection_manager, ) @@ -302,8 +299,8 @@ async def delete(self, entity: Entity) -> None: class PostgresRelationshipHandler(RelationshipHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: - self.project_name = kwargs.get("project_name") - self.connection_manager = kwargs.get("connection_manager") + self.project_name: str = kwargs.get("project_name") # type: ignore + self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore async def create_tables(self) -> None: """Create the relationships table if it doesn't exist.""" @@ -354,7 +351,7 @@ async def get( attributes: Optional[list[str]] = None, offset: int = 0, limit: int = -1, - ) -> list[Relationship]: + ): """Get relationships from storage by ID.""" filter = { @@ -366,11 +363,11 @@ async def get( if entity_names: filter += " AND (subject = ANY($2) OR object = ANY($2))" - params.append(entity_names) + params.append(entity_names) # type: ignore if relationship_types: filter += " AND predicate = ANY($3)" - params.append(relationship_types) + params.append(relationship_types) # type: ignore QUERY = f""" SELECT * FROM {self._get_table_name("chunk_relationship")} @@ -378,7 +375,7 @@ async def get( OFFSET ${len(params)+1} LIMIT ${len(params) + 2} """ - params.extend([offset, limit]) + params.extend([offset, limit]) # type: ignore rows = await self.connection_manager.fetch_query(QUERY, params) QUERY_COUNT = f""" @@ -388,7 +385,7 @@ async def get( await self.connection_manager.fetch_query(QUERY_COUNT, params[:-2]) )[0]["count"] - return [Relationship(**row) for row in rows], count + return [Relationship(**row) for row in rows], count # type: ignore async def update(self, relationship: Relationship) -> None: return await _update_object( @@ -412,10 +409,10 @@ async def delete(self, relationship: Relationship) -> None: class PostgresCommunityHandler(CommunityHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: - self.project_name = kwargs.get("project_name") - self.connection_manager = kwargs.get("connection_manager") - self.dimension = kwargs.get("dimension") - self.quantization_type = kwargs.get("quantization_type") + self.project_name: str = kwargs.get("project_name") # type: ignore + self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore + self.dimension: int = kwargs.get("dimension") # type: ignore + self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore 
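    # Construction sketch (hedged): these handlers are wired purely from kwargs,
    # so a missing key becomes None at runtime rather than raising a TypeError,
    # which is why the annotations above carry `# type: ignore` until the
    # signature is finalized. Illustrative values only:
    #
    #     handler = PostgresCommunityHandler(
    #         project_name="r2r_default",
    #         connection_manager=connection_manager,
    #         dimension=512,
    #         quantization_type=VectorQuantizationType.FP32,
    #     )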
async def create_tables(self) -> None: @@ -476,24 +473,33 @@ async def update(self, community: Community) -> None: async def delete(self, community: Community) -> None: return await _delete_object( - object_id=community.id, + object_id=community.id, # type: ignore full_table_name=self._get_table_name("community"), connection_manager=self.connection_manager, ) - async def get( - self, collection_id: UUID, offset: int, limit: int - ) -> list[Community]: + async def get(self, collection_id: UUID, offset: int, limit: int): QUERY = f""" SELECT * FROM {self._get_table_name("community")} WHERE collection_id = $1 OFFSET $2 LIMIT $3 """ params = [collection_id, offset, limit] - return [ + communities = [ Community(**row) for row in await self.connection_manager.fetch_query(QUERY, params) ] + QUERY_COUNT = f""" + SELECT COUNT(*) FROM {self._get_table_name("community")} WHERE collection_id = $1 + """ + count = ( + await self.connection_manager.fetch_query( + QUERY_COUNT, [collection_id] + ) + )[0]["count"] + + return communities, count + class PostgresGraphHandler(GraphHandler): """Handler for Knowledge Graph METHODS in PostgreSQL.""" @@ -509,11 +515,11 @@ def __init__( **kwargs: Any, ) -> None: - self.project_name = kwargs.get("project_name") - self.connection_manager = kwargs.get("connection_manager") - self.dimension = kwargs.get("dimension") - self.quantization_type = kwargs.get("quantization_type") - self.collection_handler = kwargs.get("collection_handler") + self.project_name: str = kwargs.get("project_name") # type: ignore + self.connection_manager: PostgresConnectionManager = kwargs.get("connection_manager") # type: ignore + self.dimension: int = kwargs.get("dimension") # type: ignore + self.quantization_type: VectorQuantizationType = kwargs.get("quantization_type") # type: ignore + self.collection_handler: PostgresCollectionHandler = kwargs.get("collection_handler") # type: ignore self.entities = PostgresEntityHandler(*args, **kwargs) self.relationships = PostgresRelationshipHandler(*args, **kwargs) @@ -554,7 +560,7 @@ async def create(self, graph: Graph) -> None: VALUES ($1, $2, $3, $4, $5, $6, $7) """ await self.connection_manager.execute_query( - QUERY, *graph.to_dict().values() + QUERY, [*graph.to_dict().values()] ) async def update(self, graph: Graph) -> None: @@ -562,21 +568,21 @@ async def update(self, graph: Graph) -> None: UPDATE {self._get_table_name("graph")} SET status = $2, updated_at = $3, document_ids = $4, collection_ids = $5, attributes = $6 WHERE id = $1 """ await self.connection_manager.execute_query( - QUERY, *graph.to_dict().values() + QUERY, [*graph.to_dict().values()] ) async def delete(self, graph_id: UUID) -> None: QUERY = f""" DELETE FROM {self._get_table_name("graph")} WHERE id = $1 """ - await self.connection_manager.execute_query(QUERY, graph_id) + await self.connection_manager.execute_query(QUERY, [graph_id]) - async def get(self, graph_id: UUID) -> Graph: + async def get(self, graph_id: UUID): QUERY = f""" SELECT * FROM {self._get_table_name("graph")} WHERE id = $1 """ return Graph.from_dict( - await self.connection_manager.fetch_query(QUERY, graph_id) + await self.connection_manager.fetch_query(QUERY, [graph_id]) ) async def add_document(self, graph_id: UUID, document_id: UUID) -> None: @@ -584,7 +590,7 @@ async def add_document(self, graph_id: UUID, document_id: UUID) -> None: UPDATE {self._get_table_name("graph")} SET document_ids = array_append(document_ids, $2) WHERE id = $1 """ await self.connection_manager.execute_query( - QUERY, graph_id, document_id + 
QUERY, [graph_id, document_id] ) async def remove_document(self, graph_id: UUID, document_id: UUID) -> None: @@ -592,7 +598,7 @@ async def remove_document(self, graph_id: UUID, document_id: UUID) -> None: UPDATE {self._get_table_name("graph")} SET document_ids = array_remove(document_ids, $2) WHERE id = $1 """ await self.connection_manager.execute_query( - QUERY, graph_id, document_id + QUERY, [graph_id, document_id] ) async def add_collection( @@ -602,7 +608,7 @@ async def add_collection( UPDATE {self._get_table_name("graph")} SET collection_ids = array_append(collection_ids, $2) WHERE id = $1 """ await self.connection_manager.execute_query( - QUERY, graph_id, collection_id + QUERY, [graph_id, collection_id] ) async def remove_collection( @@ -612,7 +618,7 @@ async def remove_collection( UPDATE {self._get_table_name("graph")} SET collection_ids = array_remove(collection_ids, $2) WHERE id = $1 """ await self.connection_manager.execute_query( - QUERY, graph_id, collection_id + QUERY, [graph_id, collection_id] ) ###### ESTIMATION METHODS ###### @@ -919,7 +925,7 @@ async def add_entities( ) entity_dict["description_embedding"] = ( str(entity_dict["description_embedding"]) - if entity_dict.get("description_embedding") + if entity_dict.get("description_embedding") # type: ignore else None ) cleaned_entities.append(entity_dict) @@ -1096,14 +1102,6 @@ async def get_relationships( ####################### COMMUNITY METHODS ####################### - async def get_communities(self, collection_id: UUID) -> list[Community]: - QUERY = f""" - SELECT *c FROM {self._get_table_name("community")} WHERE collection_id = $1 - """ - return await self.connection_manager.fetch_query( - QUERY, [collection_id] - ) - async def check_communities_exist( self, collection_id: UUID, offset: int, limit: int ) -> list[int]: @@ -1738,7 +1736,7 @@ async def _get_relationship_ids_cache( ) -> dict[str, list[int]]: # caching the relationship ids - relationship_ids_cache = dict[str, list[int]]() + relationship_ids_cache = dict[str, list[Union[int, UUID]]]() for relationship in relationships: if ( relationship.subject not in relationship_ids_cache @@ -1762,7 +1760,7 @@ async def _get_relationship_ids_cache( relationship.id ) - return relationship_ids_cache + return relationship_ids_cache # type: ignore async def _incremental_clustering( self,