diff --git a/js/sdk/src/r2rClient.ts b/js/sdk/src/r2rClient.ts index 4ba1e93cd..25ccd6d12 100644 --- a/js/sdk/src/r2rClient.ts +++ b/js/sdk/src/r2rClient.ts @@ -494,6 +494,7 @@ export class r2rClient { document_ids?: string[]; user_ids?: (string | null)[]; ingestion_config?: Record; + collection_ids?: string[]; run_with_orchestration?: boolean; } = {}, ): Promise { @@ -560,6 +561,9 @@ export class r2rClient { ingestion_config: options.ingestion_config ? JSON.stringify(options.ingestion_config) : undefined, + collection_ids: options.collection_ids + ? JSON.stringify(options.collection_ids) + : undefined, run_with_orchestration: options.run_with_orchestration != undefined ? String(options.run_with_orchestration) @@ -601,6 +605,7 @@ export class r2rClient { document_ids: string[]; metadatas?: Record[]; ingestion_config?: Record; + collection_ids?: string[]; run_with_orchestration?: boolean; }, ): Promise { @@ -642,6 +647,9 @@ export class r2rClient { ingestion_config: options.ingestion_config ? JSON.stringify(options.ingestion_config) : undefined, + collection_ids: options.collection_ids + ? JSON.stringify(options.collection_ids) + : undefined, run_with_orchestration: options.run_with_orchestration != undefined ? 
String(options.run_with_orchestration) @@ -675,6 +683,7 @@ export class r2rClient { chunks: RawChunk[], documentId?: string, metadata?: Record, + collection_ids?: string[], run_with_orchestration?: boolean, ): Promise> { this._ensureAuthenticated(); @@ -682,6 +691,7 @@ export class r2rClient { chunks: chunks, document_id: documentId, metadata: metadata, + collection_ids: collection_ids, run_with_orchestration: run_with_orchestration, }; diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 1fa824a80..a130cd416 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -20,6 +20,7 @@ ) from shared.abstractions.graph import ( Community, + CommunityInfo, CommunityReport, Entity, EntityLevel, @@ -27,7 +28,6 @@ KGExtraction, RelationshipType, Triple, - CommunityInfo, ) from shared.abstractions.ingestion import ( ChunkEnrichmentSettings, diff --git a/py/core/main/api/ingestion_router.py b/py/core/main/api/ingestion_router.py index 7d1eebd09..2c451aa48 100644 --- a/py/core/main/api/ingestion_router.py +++ b/py/core/main/api/ingestion_router.py @@ -122,6 +122,10 @@ async def ingest_files_app( None, description=ingest_files_descriptions.get("document_ids"), ), + collection_ids: Optional[Json[list[list[UUID]]]] = Form( + None, + description="Optional collection IDs for the documents, if provided the document will be assigned to them at ingestion.", + ), metadatas: Optional[Json[list[dict]]] = Form( None, description=ingest_files_descriptions.get("metadatas") ), @@ -181,6 +185,9 @@ async def ingest_files_app( "ingestion_config": ingestion_config, "user": auth_user.model_dump_json(), "size_in_bytes": content_length, + "collection_ids": ( + collection_ids[it] if collection_ids else None + ), "is_update": False, } @@ -241,6 +248,10 @@ async def update_files_app( document_ids: Optional[Json[list[UUID]]] = Form( None, description=ingest_files_descriptions.get("document_ids") ), + collection_ids: 
Optional[Json[list[list[UUID]]]] = Form( + None, + description="Optional collection IDs for the documents, if provided the document will be assigned to them at ingestion.", + ), metadatas: Optional[Json[list[dict]]] = Form( None, description=ingest_files_descriptions.get("metadatas") ), @@ -314,6 +325,7 @@ async def update_files_app( "ingestion_config": ingestion_config, "user": auth_user.model_dump_json(), "is_update": True, + "collection_ids": collection_ids, } if run_with_orchestration: @@ -357,6 +369,10 @@ async def ingest_chunks_app( metadata: Optional[dict] = Body( None, description=ingest_files_descriptions.get("metadata") ), + collection_ids: Optional[Json[list[list[UUID]]]] = Body( + None, + description="Optional collection IDs for the documents, if provided the document will be assigned to them at ingestion.", + ), run_with_orchestration: Optional[bool] = Body( True, description=ingest_files_descriptions.get( @@ -388,6 +404,7 @@ async def ingest_chunks_app( "chunks": [chunk.model_dump() for chunk in chunks], "metadata": metadata or {}, "user": auth_user.model_dump_json(), + "collection_ids": collection_ids, } if run_with_orchestration: raw_message = await self.orchestration_provider.run_workflow( diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index aa5ff5ab8..6672c3d20 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -154,17 +154,43 @@ async def parse(self, context: Context) -> dict: status=IngestionStatus.SUCCESS, ) - # TODO: Move logic onto the `management service` - collection_id = generate_default_user_collection_id( - document_info.user_id - ) - await service.providers.database.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.assign_document_to_collection_vector( - 
document_id=document_info.id, collection_id=collection_id + collection_ids = context.workflow_input()["request"].get( + "collection_ids" ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.user_id + ) + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) + else: + for collection_id in collection_ids: + try: + await service.providers.database.create_collection( + name=document_info.title, + collection_id=collection_id, + description="", + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) # get server chunk enrichment settings and override parts of it if provided in the ingestion config server_chunk_enrichment_settings = getattr( @@ -450,16 +476,43 @@ async def finalize(self, context: Context) -> dict: try: # TODO - Move logic onto the `management service` - collection_id = generate_default_user_collection_id( - document_info.user_id - ) - await self.ingestion_service.providers.database.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await self.ingestion_service.providers.database.assign_document_to_collection_vector( - document_id=document_info.id, collection_id=collection_id + collection_ids = context.workflow_input()["request"].get( + "collection_ids" ) + if not collection_ids: + # TODO: Move logic onto the `management service` + 
collection_id = generate_default_user_collection_id( + document_info.user_id + ) + await self.ingestion_service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await self.ingestion_service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) + else: + for collection_id in collection_ids: + try: + await self.ingestion_service.providers.database.create_collection( + name=document_info.title or "N/A", + collection_id=collection_id, + description="", + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await self.ingestion_service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await self.ingestion_service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) except Exception as e: logger.error( f"Error during assigning document to collection: {str(e)}" diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 9cf71d1cd..9665e66b1 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -67,18 +67,43 @@ async def ingest_files(input_data): document_info, status=IngestionStatus.SUCCESS ) + collection_ids = parsed_data.get("collection_ids") + try: - # TODO - Move logic onto management service - collection_id = generate_default_user_collection_id( - str(document_info.user_id) - ) - await service.providers.database.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.assign_document_to_collection_vector( - document_info.id, collection_id - ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = 
generate_default_user_collection_id( + document_info.user_id + ) + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) + else: + for collection_id in collection_ids: + try: + await service.providers.database.create_collection( + name=document_info.title, + collection_id=collection_id, + description="", + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) except Exception as e: logger.error( f"Error during assigning document to collection: {str(e)}" @@ -229,18 +254,44 @@ async def ingest_chunks(input_data): document_info, status=IngestionStatus.SUCCESS ) + collection_ids = parsed_data.get("collection_ids") + try: # TODO - Move logic onto management service - collection_id = generate_default_user_collection_id( - str(document_info.user_id) - ) - await service.providers.database.assign_document_to_collection_relational( - document_id=document_info.id, - collection_id=collection_id, - ) - await service.providers.database.assign_document_to_collection_vector( - document_id=document_info.id, collection_id=collection_id - ) + if not collection_ids: + # TODO: Move logic onto the `management service` + collection_id = generate_default_user_collection_id( + document_info.user_id + ) + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + 
document_id=document_info.id, + collection_id=collection_id, + ) + else: + for collection_id in collection_ids: + try: + await service.providers.database.create_collection( + name=document_info.title, + collection_id=collection_id, + description="", + ) + except Exception as e: + logger.warning( + f"Warning, could not create collection with error: {str(e)}" + ) + + await service.providers.database.assign_document_to_collection_relational( + document_id=document_info.id, + collection_id=collection_id, + ) + await service.providers.database.assign_document_to_collection_vector( + document_id=document_info.id, + collection_id=collection_id, + ) except Exception as e: logger.error( f"Error during assigning document to collection: {str(e)}" diff --git a/py/core/main/orchestration/simple/kg_workflow.py b/py/core/main/orchestration/simple/kg_workflow.py index cca48cc7d..60bc47735 100644 --- a/py/core/main/orchestration/simple/kg_workflow.py +++ b/py/core/main/orchestration/simple/kg_workflow.py @@ -3,11 +3,10 @@ import math import uuid -from core import GenerationConfig -from core import R2RException +from core import GenerationConfig, R2RException +from core.base.abstractions import KGEnrichmentStatus from ...services import KgService -from core.base.abstractions import KGEnrichmentStatus logger = logging.getLogger() diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 54832a785..9e6278158 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -73,6 +73,7 @@ async def ingest_file_ingress( metadata: Optional[dict] = None, version: Optional[str] = None, is_update: bool = False, + collection_ids: Optional[list[UUID]] = None, *args: Any, **kwargs: Any, ) -> dict: @@ -634,6 +635,7 @@ def parse_ingest_file_input(data: dict) -> dict: "is_update": data.get("is_update", False), "file_data": data["file_data"], "size_in_bytes": data["size_in_bytes"], + "collection_ids": 
data.get("collection_ids", []), } @staticmethod diff --git a/py/core/pipes/kg/deduplication.py b/py/core/pipes/kg/deduplication.py index 669c26b39..b6d8e74d8 100644 --- a/py/core/pipes/kg/deduplication.py +++ b/py/core/pipes/kg/deduplication.py @@ -1,7 +1,8 @@ +import json import logging from typing import Any, Optional, Union from uuid import UUID -import json + from core.base import AsyncState, R2RException from core.base.abstractions import Entity, KGEntityDeduplicationType from core.base.pipes import AsyncPipe diff --git a/py/core/providers/database/kg.py b/py/core/providers/database/kg.py index 3238975e5..48b814036 100644 --- a/py/core/providers/database/kg.py +++ b/py/core/providers/database/kg.py @@ -17,11 +17,11 @@ Triple, ) from core.base.abstractions import ( - EntityLevel, CommunityInfo, - KGEnrichmentStatus, + EntityLevel, KGCreationSettings, KGEnrichmentSettings, + KGEnrichmentStatus, KGEntityDeduplicationSettings, VectorQuantizationType, ) diff --git a/py/sdk/mixins/ingestion.py b/py/sdk/mixins/ingestion.py index 4aa8d8d5e..ceb5054f5 100644 --- a/py/sdk/mixins/ingestion.py +++ b/py/sdk/mixins/ingestion.py @@ -14,6 +14,7 @@ async def ingest_files( document_ids: Optional[list[Union[str, UUID]]] = None, metadatas: Optional[list[dict]] = None, ingestion_config: Optional[dict] = None, + collection_ids: Optional[list[list[Union[str, UUID]]]] = None, run_with_orchestration: Optional[bool] = None, ) -> dict: """ @@ -74,6 +75,17 @@ async def ingest_files( if run_with_orchestration is not None: data["run_with_orchestration"] = str(run_with_orchestration) + if collection_ids: + data["collection_ids"] = json.dumps( + [ + [ + str(collection_id) + for collection_id in doc_collection_ids + ] + for doc_collection_ids in collection_ids + ] + ) + return await self._make_request( # type: ignore "POST", "ingest_files", data=data, files=files_tuples ) @@ -84,6 +96,7 @@ async def update_files( document_ids: Optional[list[Union[str, UUID]]] = None, metadatas: 
Optional[list[dict]] = None, ingestion_config: Optional[dict] = None, + collection_ids: Optional[list[list[Union[str, UUID]]]] = None, run_with_orchestration: Optional[bool] = None, ) -> dict: """ @@ -133,6 +146,16 @@ async def update_files( if run_with_orchestration is not None: data["run_with_orchestration"] = str(run_with_orchestration) + if collection_ids: + data["collection_ids"] = json.dumps( + [ + [ + str(collection_id) + for collection_id in doc_collection_ids + ] + for doc_collection_ids in collection_ids + ] + ) return await self._make_request( # type: ignore "POST", "update_files", data=data, files=files ) @@ -142,6 +165,7 @@ async def ingest_chunks( chunks: list[dict], document_id: Optional[UUID] = None, metadata: Optional[dict] = None, + collection_ids: Optional[list[list[Union[str, UUID]]]] = None, run_with_orchestration: Optional[bool] = None, ) -> dict: """ @@ -163,6 +187,18 @@ async def ingest_chunks( } if run_with_orchestration is not None: data["run_with_orchestration"] = str(run_with_orchestration) # type: ignore + + if collection_ids: + data["collection_ids"] = json.dumps( # type: ignore + [ + [ + str(collection_id) + for collection_id in doc_collection_ids + ] + for doc_collection_ids in collection_ids + ] + ) + return await self._make_request("POST", "ingest_chunks", json=data) # type: ignore async def update_chunks( diff --git a/py/shared/abstractions/document.py b/py/shared/abstractions/document.py index 624fca0ee..cc6d82d08 100644 --- a/py/shared/abstractions/document.py +++ b/py/shared/abstractions/document.py @@ -4,7 +4,7 @@ import logging from datetime import datetime from enum import Enum -from typing import Optional, Union, ClassVar +from typing import ClassVar, Optional, Union from uuid import UUID, uuid4 from pydantic import Field