diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py index 6abe81d28..6758512dd 100644 --- a/py/core/base/agent/agent.py +++ b/py/core/base/agent/agent.py @@ -2,7 +2,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Optional, Type, Union +from typing import Any, AsyncGenerator, Optional, Type from pydantic import BaseModel @@ -26,7 +26,7 @@ def __init__(self): def create_and_add_message( self, - role: Union[MessageType, str], + role: MessageType | str, content: Optional[str] = None, name: Optional[str] = None, function_call: Optional[dict[str, Any]] = None, @@ -123,9 +123,7 @@ async def arun( messages: Optional[list[Message]] = None, *args, **kwargs, - ) -> Union[ - list[LLMChatCompletion], AsyncGenerator[LLMChatCompletion, None] - ]: + ) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]: pass @abstractmethod @@ -134,7 +132,7 @@ async def process_llm_response( response: Any, *args, **kwargs, - ) -> Union[None, AsyncGenerator[str, None]]: + ) -> None | AsyncGenerator[str, None]: pass async def execute_tool(self, tool_name: str, *args, **kwargs) -> str: diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 5359329c8..8ee1236ae 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -1220,7 +1220,7 @@ async def assign_document_to_collection_relational( collection_id: UUID, ) -> UUID: return await self.collection_handler.assign_document_to_collection_relational( - document_id, collection_id + document_id=document_id, collection_id=collection_id ) async def remove_document_from_collection_relational( diff --git a/py/core/main/api/v2/ingestion_router.py b/py/core/main/api/v2/ingestion_router.py index d5432ab04..e74da9ea7 100644 --- a/py/core/main/api/v2/ingestion_router.py +++ b/py/core/main/api/v2/ingestion_router.py @@ -484,15 +484,10 @@ async def update_document_metadata_app( workflow_input ) - return { # type: ignore - "message": "Update metadata task completed successfully.", - "document_id": str(document_id), - "task_id": None, - } return [ { # type: ignore "message": "Ingestion task completed successfully.", - "document_id": str(document_uuid), + "document_id": str(document_id), "task_id": None, } ] @@ -636,7 +631,7 @@ async def create_vector_index_app( }, ) - return GenericMessageResponse(message=raw_message) + return GenericMessageResponse(message=raw_message) # type: ignore list_vector_indices_extras = self.openapi_extras.get( "create_vector_index", {} @@ -725,7 +720,7 @@ async def delete_vector_index_app( }, ) - return GenericMessageResponse(message=raw_message) + return GenericMessageResponse(message=raw_message) # type: ignore @staticmethod async def _process_files(files): diff --git a/py/core/main/api/v2/management_router.py b/py/core/main/api/v2/management_router.py index 1ff2aab7c..5dbe05bf5 100644 --- a/py/core/main/api/v2/management_router.py +++ b/py/core/main/api/v2/management_router.py @@ -99,7 +99,7 @@ async def update_prompt_app( result = await self.service.update_prompt( name, template, input_types ) - return GenericMessageResponse(message=result) + return GenericMessageResponse(message=result) # type: ignore @self.router.post("/add_prompt") @self.base_endpoint @@ -115,7 +115,7 @@ async def add_prompt_app( 403, ) result = await self.service.add_prompt(name, template, input_types) - return GenericMessageResponse(message=result) + return GenericMessageResponse(message=result) # type: ignore @self.router.get("/get_prompt/{prompt_name}") @self.base_endpoint @@ -137,7 +137,7 @@ async def get_prompt_app( result = await self.service.get_cached_prompt( prompt_name, inputs, prompt_override ) - return GenericMessageResponse(message=result) + return GenericMessageResponse(message=result) # type: ignore @self.router.get("/get_all_prompts") @self.base_endpoint @@ -519,7 +519,7 @@ async def collections_overview_app( ) ) - return collections_overview_response["results"], { + return collections_overview_response["results"], { # type: ignore "total_entries": collections_overview_response["total_entries"] } @@ -640,7 +640,7 @@ async def add_user_to_collection_app( result = await self.service.add_user_to_collection( user_uuid, collection_uuid ) - return WrappedBooleanResponse(result=result) + return WrappedBooleanResponse(result=result) # type: ignore @self.router.post("/remove_user_from_collection") @self.base_endpoint diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 1e7128e61..0ab045782 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -223,7 +223,7 @@ async def list_collections( limit=limit, ) - return ( + return ( # type: ignore collections_overview_response["results"], { "total_entries": collections_overview_response[ @@ -486,7 +486,7 @@ async def delete_collection( ) await self.services["management"].delete_collection(id) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( "/collections/{id}/documents/{document_id}", @@ -745,7 +745,7 @@ async def remove_document_from_collection( await self.services["management"].remove_document_from_collection( document_id, id ) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.get( "/collections/{id}/users", @@ -932,7 +932,7 @@ async def add_user_to_collection( result = await self.services["management"].add_user_to_collection( user_id, id ) - return GenericBooleanResponse(success=result) + return GenericBooleanResponse(success=result) # type: ignore @self.router.delete( "/collections/{id}/users/{user_id}", @@ -1014,4 +1014,4 @@ async def remove_user_from_collection( await self.services["management"].remove_user_from_collection( user_id, id ) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore diff --git a/py/core/main/api/v3/conversations_router.py b/py/core/main/api/v3/conversations_router.py index 04d517e28..369a159c6 100644 --- a/py/core/main/api/v3/conversations_router.py +++ b/py/core/main/api/v3/conversations_router.py @@ -195,7 +195,7 @@ async def list_conversations( offset=offset, limit=limit, ) - return conversations_response["results"], { + return conversations_response["results"], { # type: ignore "total_entries": conversations_response["total_entries"] } @@ -347,7 +347,7 @@ async def delete_conversation( This endpoint deletes a conversation identified by its UUID. """ await self.services["management"].delete_conversation(str(id)) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( "/conversations/{id}/messages", @@ -609,7 +609,7 @@ async def list_branches( conversation_id=str(id), ) - return branches_response["results"], { + return branches_response["results"], { # type: ignore "total_entries": branches_response["total_entries"] } diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index ccd7754d5..a93faced2 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -126,6 +126,10 @@ async def create_document( None, description="The ID of the document. If not provided, a new ID will be generated.", ), + collection_ids: Optional[list[UUID]] = Form( + None, + description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.", + ), metadata: Optional[Json[dict]] = Form( None, description="Metadata to associate with the document, such as title, description, or custom fields.", @@ -200,6 +204,7 @@ async def create_document( workflow_input = { "file_data": file_data, "document_id": str(document_id), + "collection_ids": collection_ids, "metadata": metadata, "ingestion_config": ingestion_config, "user": auth_user.model_dump_json(), @@ -306,7 +311,7 @@ async def create_document( }, ) @self.base_endpoint - async def update_document( + async def update_document( # type: ignore file: Optional[UploadFile] = File( None, description="The file to ingest. Either a file or content must be provided, but not both.", @@ -381,8 +386,10 @@ async def update_document( # Check if the user is a superuser if not auth_user.is_superuser: - if "user_id" in metadata and metadata["user_id"] != str( - auth_user.id + if ( + metadata is not None + and "user_id" in metadata + and metadata["user_id"] != str(auth_user.id) ): raise R2RException( status_code=403, @@ -795,7 +802,7 @@ async def list_chunks( "Not authorized to access this document's chunks.", 403 ) - return ( + return ( # type: ignore list_document_chunks["results"], {"total_entries": list_document_chunks["total_entries"]}, ) @@ -1019,7 +1026,7 @@ async def delete_document_by_id( ] } await self.services["management"].delete(filters=filters) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.delete( "/documents/by-filter", @@ -1085,7 +1092,7 @@ async def delete_document_by_filter( filters=filters_dict ) - return GenericBooleanResponse(success=delete_bool) + return GenericBooleanResponse(success=delete_bool) # type: ignore @self.router.get( "/documents/{id}/collections", diff --git a/py/core/main/api/v3/graph_router.py b/py/core/main/api/v3/graph_router.py index 250cc62f6..976467251 100644 --- a/py/core/main/api/v3/graph_router.py +++ b/py/core/main/api/v3/graph_router.py @@ -32,31 +32,6 @@ logger = logging.getLogger() -# class Entity(BaseModel): -# """Model representing a graph entity.""" - -# id: UUID -# name: str -# type: str -# metadata: dict = Field(default_factory=dict) -# level: EntityLevel -# collection_ids: list[UUID] -# embedding: Optional[list[float]] = None - -# class Config: -# json_schema_extra = { -# "example": { -# "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1", -# "name": "John Smith", -# "type": "PERSON", -# "metadata": {"confidence": 0.95}, -# "level": "DOCUMENT", -# "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"], -# "embedding": [0.1, 0.2, 0.3], -# } -# } - - class Relationship(BaseModel): """Model representing a graph relationship.""" diff --git a/py/core/main/api/v3/indices_router.py b/py/core/main/api/v3/indices_router.py index e16a10ad7..428f29e0b 100644 --- a/py/core/main/api/v3/indices_router.py +++ b/py/core/main/api/v3/indices_router.py @@ -241,7 +241,7 @@ async def create_index( }, ) - return GenericMessageResponse(message=raw_message) + return GenericMessageResponse(message=raw_message) # type: ignore @self.router.get( "/indices", @@ -625,4 +625,4 @@ async def delete_index( }, ) - return GenericMessageResponse(message=raw_message) + return GenericMessageResponse(message=raw_message) # type: ignore diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index 99c0e5933..73ac4d91b 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -114,7 +114,7 @@ async def create_prompt( result = await self.services["management"].add_prompt( name, template, input_types ) - return GenericMessageResponse(message=result) + return GenericMessageResponse(message=result) # type: ignore @self.router.get( "/prompts", @@ -188,7 +188,7 @@ async def get_prompts( "management" ].get_all_prompts() - return ( + return ( # type: ignore get_prompts_response["results"], { "total_entries": get_prompts_response["total_entries"], @@ -365,7 +365,7 @@ async def update_prompt( result = await self.services["management"].update_prompt( name, template, input_types ) - return GenericMessageResponse(message=result) + return GenericMessageResponse(message=result) # type: ignore @self.router.delete( "/prompts/{name}", @@ -439,4 +439,4 @@ async def delete_prompt( 403, ) await self.services["management"].delete_prompt(name) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore diff --git a/py/core/main/api/v3/system_router.py b/py/core/main/api/v3/system_router.py index b587bab17..669168219 100644 --- a/py/core/main/api/v3/system_router.py +++ b/py/core/main/api/v3/system_router.py @@ -94,7 +94,7 @@ def _setup_routes(self): ) @self.base_endpoint async def health_check() -> WrappedGenericMessageResponse: - return GenericMessageResponse(message="ok") + return GenericMessageResponse(message="ok") # type: ignore @self.router.get( "/system/settings", @@ -224,7 +224,7 @@ async def server_stats( "Only an authorized user can call the `server_stats` endpoint.", 403, ) - return { + return { # type: ignore "start_time": self.start_time.isoformat(), "uptime_seconds": ( datetime.now(timezone.utc) - self.start_time diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index c4c1ad9d9..5bf6e3e1f 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -892,7 +892,7 @@ async def add_user_to_collection( await self.services["management"].add_user_to_collection( # type: ignore id, collection_id ) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.delete( "/users/{id}/collections/{collection_id}", @@ -979,7 +979,7 @@ async def remove_user_from_collection( await self.services["management"].remove_user_from_collection( # type: ignore id, collection_id ) - return GenericBooleanResponse(success=True) + return GenericBooleanResponse(success=True) # type: ignore @self.router.post( "/users/{id}", diff --git a/py/core/main/orchestration/hatchet/ingestion_workflow.py b/py/core/main/orchestration/hatchet/ingestion_workflow.py index b2b508402..0ff165d04 100644 --- a/py/core/main/orchestration/hatchet/ingestion_workflow.py +++ b/py/core/main/orchestration/hatchet/ingestion_workflow.py @@ -183,9 +183,10 @@ async def parse(self, context: Context) -> dict: for collection_id in collection_ids: try: await service.providers.database.create_collection( + user_id=document_info.user_id, name=document_info.title, - collection_id=collection_id, description="", + collection_id=collection_id, ) except Exception as e: logger.warning( @@ -511,9 +512,10 @@ async def finalize(self, context: Context) -> dict: for collection_id in collection_ids: try: await service.providers.database.create_collection( + user_id=document_info.user_id, name=document_info.title or "N/A", - collection_id=collection_id, description="", + collection_id=collection_id, ) except Exception as e: logger.warning( diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 820a7bac7..e14814f6e 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -688,11 +688,11 @@ async def update_document_metadata( user: UserResponse, ) -> None: # Verify document exists and user has access - existing_document = ( - await self.providers.database.get_documents_overview( - filter_document_ids=[document_id], - filter_user_ids=[user.id], - ) + existing_document = await self.providers.database.get_documents_overview( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended. + offset=0, + limit=100, + filter_document_ids=[document_id], + filter_user_ids=[user.id], ) if not existing_document["results"]: diff --git a/py/core/main/services/kg_service.py b/py/core/main/services/kg_service.py index f3c74a5f8..76e1a819c 100644 --- a/py/core/main/services/kg_service.py +++ b/py/core/main/services/kg_service.py @@ -410,11 +410,11 @@ async def get_communities( @telemetry_event("list_communities") async def list_communities( self, + offset: int, + limit: int, collection_id: Optional[UUID] = None, levels: Optional[list[int]] = None, community_numbers: Optional[list[int]] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, **kwargs, ): return await self.providers.database.get_communities( diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 9f711d84c..64272155b 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -320,12 +320,13 @@ async def agent( conversation_id = conversation["id"] parent_id = None - for inner_message in messages[:-1]: - parent_id = await self.logging_connection.add_message( - conversation_id, # Use the stored conversation_id - inner_message, - parent_id, - ) + if conversation_id and messages: + for inner_message in messages[:-1]: + parent_id = await self.logging_connection.add_message( + conversation_id, # Use the stored conversation_id + inner_message, + parent_id, + ) messages = messages or [] current_message = messages[-1] # type: ignore @@ -336,7 +337,8 @@ async def agent( current_message, # type: ignore parent_id=str(ids[-2]) if (ids and len(ids) > 1) else None, # type: ignore ) - message_id = message["id"] + if message is not None: + message_id = message.model_dump()["id"] if rag_generation_config.stream: t1 = time.time() @@ -381,8 +383,8 @@ async def stream_response(): **kwargs, ) await self.logging_connection.add_message( - conversation_id, - Message(**results[-1]), + conversation_id=conversation_id, + content=Message(**results[-1]), parent_id=message_id, ) diff --git a/py/core/providers/database/prompt.py b/py/core/providers/database/prompt.py index fc6b56e13..526daacc8 100644 --- a/py/core/providers/database/prompt.py +++ b/py/core/providers/database/prompt.py @@ -156,7 +156,7 @@ async def get_cached_prompt( self._prompt_cache.set(cache_key, result) return result - async def get_prompt( + async def get_prompt( # type: ignore self, name: str, inputs: Optional[dict] = None, diff --git a/py/core/providers/logger/r2r_logger.py b/py/core/providers/logger/r2r_logger.py index 0ef93e9b6..72ba6d941 100644 --- a/py/core/providers/logger/r2r_logger.py +++ b/py/core/providers/logger/r2r_logger.py @@ -145,6 +145,7 @@ async def savepoint(self, name: str): """Create a savepoint with proper error handling.""" if self.conn is None: await self.initialize() + assert self.conn is not None async with self.conn.cursor() as cursor: await cursor.execute(f"SAVEPOINT {name}") try: @@ -865,8 +866,8 @@ async def delete_conversation(self, conversation_id: str): """Delete a conversation and all related data.""" if self.conn is None: await self.initialize() - try: + assert self.conn is not None # Delete all message branches associated with the conversation await self.conn.execute( "DELETE FROM message_branches WHERE message_id IN (SELECT id FROM messages WHERE conversation_id = ?)", @@ -888,6 +889,7 @@ async def delete_conversation(self, conversation_id: str): ) await self.conn.commit() except Exception: + assert self.conn is not None await self.conn.rollback() raise diff --git a/py/sdk/v3/chunks.py b/py/sdk/v3/chunks.py index 5c43ebf7a..a3423613c 100644 --- a/py/sdk/v3/chunks.py +++ b/py/sdk/v3/chunks.py @@ -32,7 +32,7 @@ async def create( list[dict]: List of creation results containing processed chunk information """ data = { - "chunks": [chunk.dict() for chunk in chunks], + "chunks": chunks, "run_with_orchestration": run_with_orchestration, } return await self.client._make_request( diff --git a/py/sdk/v3/conversations.py b/py/sdk/v3/conversations.py index de984c8af..c35c34a8a 100644 --- a/py/sdk/v3/conversations.py +++ b/py/sdk/v3/conversations.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Any from uuid import UUID @@ -101,7 +101,7 @@ async def add_message( content: str, role: str, parent_id: Optional[str] = None, - metadata: Optional[dict[str, str]] = None, + metadata: Optional[dict] = None, ) -> dict: """ Add a new message to a conversation. diff --git a/py/sdk/v3/documents.py b/py/sdk/v3/documents.py index d96aab52d..0a9e47ae7 100644 --- a/py/sdk/v3/documents.py +++ b/py/sdk/v3/documents.py @@ -17,12 +17,22 @@ async def create( file_path: Optional[str] = None, content: Optional[str] = None, id: Optional[str | UUID] = None, + collection_ids: Optional[list[str | UUID]] = None, metadata: Optional[dict] = None, ingestion_config: Optional[dict] = None, run_with_orchestration: Optional[bool] = True, ) -> dict: """ Create a new document from either a file or content. + + Args: + file_path (Optional[str]): The file to upload, if any + content (Optional[str]): Optional text content to upload, if no file path is provided + id (Optional[Union[str, UUID]]): Optional ID to assign to the document + collection_ids (Optional[list[Union[str, UUID]]]): Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection. + metadata (Optional[dict]): Optional metadata to assign to the document + ingestion_config (Optional[dict]): Optional ingestion configuration to use + run_with_orchestration (Optional[bool]): Whether to run with orchestration """ if not file_path and not content: raise ValueError("Either file_path or content must be provided") @@ -38,6 +48,10 @@ async def create( data["metadata"] = json.dumps(metadata) if ingestion_config: data["ingestion_config"] = json.dumps(ingestion_config) + if collection_ids: + data["collection_ids"] = json.dumps( + [str(collection_id) for collection_id in collection_ids] + ) if run_with_orchestration is not None: data["run_with_orchestration"] = str(run_with_orchestration) diff --git a/py/shared/api/models/retrieval/responses.py b/py/shared/api/models/retrieval/responses.py index 6fb5f1929..2069147e3 100644 --- a/py/shared/api/models/retrieval/responses.py +++ b/py/shared/api/models/retrieval/responses.py @@ -124,21 +124,6 @@ class DocumentSearchResult(BaseModel): ) -class DocumentSearchResult(BaseModel): - document_id: str = Field( - ..., - description="The document ID", - ) - metadata: Optional[dict] = Field( - None, - description="The metadata of the document", - ) - score: float = Field( - ..., - description="The score of the document", - ) - - WrappedCompletionResponse = ResultsWrapper[LLMChatCompletion] # Create wrapped versions of the responses WrappedVectorSearchResponse = ResultsWrapper[list[VectorSearchResult]]