From dfeaf47549d406a7cf459b03adf630f6c95dc3a4 Mon Sep 17 00:00:00 2001 From: emrgnt-cmplxty Date: Tue, 5 Nov 2024 10:04:17 -0800 Subject: [PATCH] add indices --- py/core/base/__init__.py | 2 +- py/core/base/api/models/__init__.py | 3 + py/core/base/providers/database.py | 7 +- py/core/base/utils/__init__.py | 4 +- py/core/main/api/v2/ingestion_router.py | 2 +- py/core/main/api/v3/chunks_router.py | 24 +- py/core/main/api/v3/collections_router.py | 7 +- py/core/main/api/v3/conversations_router.py | 372 ++++++++++++++++-- py/core/main/api/v3/documents_router.py | 52 +-- py/core/main/api/v3/indices_router.py | 240 +++++++---- py/core/main/api/v3/prompts_router.py | 240 ++++++++++- py/core/main/api/v3/retrieval_router.py | 8 +- py/core/main/api/v3/users_router.py | 8 +- .../simple/ingestion_workflow.py | 8 +- py/core/main/services/ingestion_service.py | 6 +- py/core/main/services/management_service.py | 6 + py/core/providers/database/collection.py | 4 +- py/core/providers/database/vector.py | 158 +++++--- py/core/utils/__init__.py | 4 +- py/sdk/async_client.py | 6 + py/sdk/sync_client.py | 6 + py/sdk/v3/chunks.py | 2 +- py/sdk/v3/documents.py | 6 +- py/sdk/v3/indices.py | 150 +++++++ py/sdk/v3/retrieval.py | 226 +++++++++++ py/sdk/v3/users.py | 196 +++++++++ py/shared/api/models/ingestion/responses.py | 6 +- py/shared/utils/__init__.py | 4 +- py/shared/utils/base_utils.py | 2 +- 29 files changed, 1509 insertions(+), 250 deletions(-) create mode 100644 py/sdk/v3/indices.py create mode 100644 py/sdk/v3/retrieval.py create mode 100644 py/sdk/v3/users.py diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 769c642ee..c83c3eb19 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -136,7 +136,7 @@ "generate_document_id", "generate_extraction_id", "generate_default_user_collection_id", - "generate_collection_id_from_name", + "generate_id_from_label", "generate_user_id", "increment_version", "EntityType", diff --git a/py/core/base/api/models/__init__.py b/py/core/base/api/models/__init__.py index 8da9dfa54..054de59e9 100644 --- a/py/core/base/api/models/__init__.py +++ b/py/core/base/api/models/__init__.py @@ -6,6 +6,7 @@ WrappedTokenResponse, WrappedUserResponse, ) +from shared.api.models.base import PaginatedResultsWrapper, ResultsWrapper from shared.api.models.ingestion.responses import ( CreateVectorIndexResponse, IngestionResponse, @@ -147,4 +148,6 @@ "WrappedCompletionResponse", "WrappedRAGResponse", "WrappedRAGAgentResponse", + "PaginatedResultsWrapper", + "ResultsWrapper", ] diff --git a/py/core/base/providers/database.py b/py/core/base/providers/database.py index 2d73dfdf9..aacf5f064 100644 --- a/py/core/base/providers/database.py +++ b/py/core/base/providers/database.py @@ -564,6 +564,7 @@ async def get_chunk(self, chunk_id: UUID) -> Optional[dict[str, Any]]: @abstractmethod async def create_index( self, + name: Optional[str] = None, table_name: Optional[VectorTableName] = None, index_measure: IndexMeasure = IndexMeasure.cosine_distance, index_method: IndexMethod = IndexMethod.auto, @@ -578,7 +579,7 @@ async def create_index( @abstractmethod async def list_indices( - self, table_name: Optional[VectorTableName] = None + self, offset: int = 0, limit: int = 10, filters: Optional[dict] = None ) -> list[dict]: pass @@ -1479,9 +1480,9 @@ async def create_index( ) async def list_indices( - self, table_name: Optional[VectorTableName] = None + self, offset: int = 0, limit: int = 10, filters: Optional[dict] = None ) -> list[dict]: - return await self.vector_handler.list_indices(table_name) + return await self.vector_handler.list_indices(offset, limit, filters) async def delete_index( self, diff --git a/py/core/base/utils/__init__.py b/py/core/base/utils/__init__.py index 4b07ed5b9..69b525716 100644 --- a/py/core/base/utils/__init__.py +++ b/py/core/base/utils/__init__.py @@ -7,12 +7,12 @@ format_relations, format_search_results_for_llm, format_search_results_for_stream, - generate_collection_id_from_name, generate_default_prompt_id, generate_default_user_collection_id, generate_document_id, generate_extraction_id, generate_id, + generate_id_from_label, generate_user_id, increment_version, llm_cost_per_million_tokens, @@ -35,7 +35,7 @@ "generate_document_id", "generate_extraction_id", "generate_user_id", - "generate_collection_id_from_name", + "generate_id_from_label", "generate_default_prompt_id", "RecursiveCharacterTextSplitter", "TextSplitter", diff --git a/py/core/main/api/v2/ingestion_router.py b/py/core/main/api/v2/ingestion_router.py index 7250f4507..e9b1a7e2b 100644 --- a/py/core/main/api/v2/ingestion_router.py +++ b/py/core/main/api/v2/ingestion_router.py @@ -572,7 +572,7 @@ async def list_vector_indices_app( description=list_vector_indices_descriptions.get("table_name"), ), auth_user=Depends(self.service.providers.auth.auth_wrapper), - ) -> WrappedListVectorIndicesResponse: + ): indices = await self.service.providers.database.list_indices( table_name=table_name ) diff --git a/py/core/main/api/v3/chunks_router.py b/py/core/main/api/v3/chunks_router.py index 2ba1c51d1..f0dda31a7 100644 --- a/py/core/main/api/v3/chunks_router.py +++ b/py/core/main/api/v3/chunks_router.py @@ -182,14 +182,16 @@ async def create_chunks( run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> ResultsWrapper[list[ChunkIngestionResponse]]: - f""" + """ Create multiple chunks and process them through the ingestion pipeline. This endpoint allows creating multiple chunks at once, optionally associating them with documents and collections. The chunks will be processed asynchronously if run_with_orchestration is True. - Maximum of {MAX_CHUNKS_PER_REQUEST} chunks can be created in a single request. + Maximum of 100,000 chunks can be created in a single request. + + Note, it is not yet possible to add chunks to an existing document using this endpoint. """ default_document_id = generate_id() if len(raw_chunks) > MAX_CHUNKS_PER_REQUEST: @@ -212,9 +214,10 @@ async def create_chunks( document_id = document_id or default_document_id # Convert UnprocessedChunks to RawChunks for ingestion raw_chunks_for_doc = [ - RawChunk( + UnprocessedChunk( text=chunk.text if hasattr(chunk, "text") else "", metadata=chunk.metadata, + id=chunk.id, ) for chunk in doc_chunks ] @@ -229,6 +232,9 @@ async def create_chunks( "user": auth_user.model_dump_json(), } + # TODO - Modify create_chunks so that we can add chunks to existing document + # TODO - Modify create_chunks so that we can add chunks to existing document + if run_with_orchestration: # Run ingestion with orchestration raw_message = ( @@ -384,9 +390,11 @@ async def retrieve_chunk( client = R2RClient("http://localhost:7272") result = client.chunks.update( - id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa", - text="Updated content", - metadata={"key": "new value"} + { + "id": first_chunk_id, + "text": "Updated content", + "metadata": {"key": "new value"} + } ) """, } @@ -498,7 +506,7 @@ async def enrich_chunk( @self.base_endpoint async def delete_chunk( id: Json[UUID] = Path(...), - ) -> ResultsWrapper[ChunkResponse]: + ) -> ResultsWrapper[bool]: """ Delete a specific chunk by ID. @@ -512,7 +520,7 @@ async def delete_chunk( raise R2RException(f"Chunk {id} not found", 404) await self.services["management"].delete({"$eq": {"chunk_id": id}}) - return None + return True @self.router.get( "/chunks", diff --git a/py/core/main/api/v3/collections_router.py b/py/core/main/api/v3/collections_router.py index 7fa95d09c..deec0a3c6 100644 --- a/py/core/main/api/v3/collections_router.py +++ b/py/core/main/api/v3/collections_router.py @@ -8,6 +8,7 @@ from core.base import R2RException, RunType from core.base.api.models import ( + ResultsWrapper, WrappedAddUserResponse, WrappedCollectionListResponse, WrappedCollectionResponse, @@ -348,7 +349,7 @@ async def delete_collection( description="The unique identifier of the collection to delete", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> ResultsWrapper[bool]: """ Delete an existing collection. @@ -566,7 +567,7 @@ async def remove_document_from_collection( description="The unique identifier of the document to remove", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> ResultsWrapper[bool]: """ Remove a document from a collection. @@ -774,7 +775,7 @@ async def remove_user_from_collection( ..., description="The unique identifier of the user to remove" ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> ResultsWrapper[bool]: """ Remove a user from a collection. diff --git a/py/core/main/api/v3/conversations_router.py b/py/core/main/api/v3/conversations_router.py index 51430f3c2..68a9d1a2b 100644 --- a/py/core/main/api/v3/conversations_router.py +++ b/py/core/main/api/v3/conversations_router.py @@ -3,10 +3,11 @@ from uuid import UUID from fastapi import Body, Depends, Path, Query -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.base import Message, R2RException, RunType from core.base.api.models import ( + ResultsWrapper, WrappedConversationResponse, WrappedConversationsOverviewResponse, WrappedDeleteResponse, @@ -22,9 +23,13 @@ class MessageContent(BaseModel): - content: str - parent_id: Optional[str] = None - metadata: Optional[dict] = None + content: str = Field(..., description="The content of the message") + parent_id: Optional[str] = Field( + None, description="The ID of the parent message, if any" + ) + metadata: Optional[dict] = Field( + None, description="Additional metadata for the message" + ) class ConversationsRouter(BaseRouterV3): @@ -40,27 +45,120 @@ def __init__( super().__init__(providers, services, orchestration_provider, run_type) def _setup_routes(self): - @self.router.post("/conversations") + @self.router.post( + "/conversations", + summary="Create a new conversation", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.create() +""", + }, + { + "lang": "cURL", + "source": """ +curl -X POST "https://api.example.com/v3/conversations" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def create_conversation( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedConversationResponse: """ Create a new conversation. + + This endpoint initializes a new conversation for the authenticated user. + + Args: + auth_user: The authenticated user making the request. + + Returns: + WrappedConversationResponse: Details of the newly created conversation. + + Raises: + R2RException: If there's an error in creating the conversation. """ return await self.services.management.create_conversation() - @self.router.get("/conversations") + @self.router.get( + "/conversations", + summary="List conversations", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.list( + offset=0, + limit=10, + sort_by="created_at", + sort_order="desc" +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X GET "https://api.example.com/v3/conversations?offset=0&limit=10&sort_by=created_at&sort_order=desc" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def list_conversations( - offset: int = Query(0, ge=0), - limit: int = Query(100, ge=1, le=1000), - sort_by: Optional[str] = Query(None), - sort_order: Optional[str] = Query("desc"), + offset: int = Query( + 0, ge=0, description="The number of conversations to skip" + ), + limit: int = Query( + 100, + ge=1, + le=1000, + description="The maximum number of conversations to return", + ), + sort_by: Optional[str] = Query( + None, description="The field to sort the conversations by" + ), + sort_order: Optional[str] = Query( + "desc", + description="The order to sort the conversations ('asc' or 'desc')", + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedConversationsOverviewResponse: """ List conversations with pagination and sorting options. + + This endpoint returns a paginated list of conversations for the authenticated user. + + Args: + offset (int): The number of conversations to skip (for pagination). + limit (int): The maximum number of conversations to return (1-1000). + sort_by (str, optional): The field to sort the conversations by. + sort_order (str, optional): The order to sort the conversations ("asc" or "desc"). + auth_user: The authenticated user making the request. + + Returns: + WrappedConversationsOverviewResponse: A paginated list of conversations and total count. + + Raises: + R2RException: If there's an error in retrieving the conversations. """ conversations_response = ( await self.services.management.conversations_overview( @@ -73,42 +171,176 @@ async def list_conversations( "total_entries": conversations_response["total_entries"] } - @self.router.get("/conversations/{id}") + @self.router.get( + "/conversations/{id}", + summary="Get conversation details", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.get( + "123e4567-e89b-12d3-a456-426614174000", + branch_id="branch_1" +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X GET "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000?branch_id=branch_1" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def get_conversation( - id: UUID = Path(...), - branch_id: Optional[str] = Query(None), + id: UUID = Path( + ..., description="The unique identifier of the conversation" + ), + branch_id: Optional[str] = Query( + None, description="The ID of the specific branch to retrieve" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedConversationResponse: """ Get details of a specific conversation. + + This endpoint retrieves detailed information about a single conversation identified by its UUID. + + Args: + id (UUID): The unique identifier of the conversation. + branch_id (str, optional): The ID of the specific branch to retrieve. + auth_user: The authenticated user making the request. + + Returns: + WrappedConversationResponse: Detailed information about the requested conversation. + + Raises: + R2RException: If the conversation is not found or the user doesn't have access. """ return await self.services.management.get_conversation( str(id), branch_id, ) - @self.router.delete("/conversations/{id}") + @self.router.delete( + "/conversations/{id}", + summary="Delete conversation", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.delete("123e4567-e89b-12d3-a456-426614174000") +""", + }, + { + "lang": "cURL", + "source": """ +curl -X DELETE "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def delete_conversation( - id: UUID = Path(...), + id: UUID = Path( + ..., + description="The unique identifier of the conversation to delete", + ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> ResultsWrapper[bool]: """ Delete an existing conversation. + + This endpoint deletes a conversation identified by its UUID. + + Args: + id (UUID): The unique identifier of the conversation to delete. + auth_user: The authenticated user making the request. + + Returns: + WrappedDeleteResponse: Confirmation of the deletion. + + Raises: + R2RException: If the conversation is not found or the user doesn't have permission to delete it. """ await self.services.management.delete_conversation(str(id)) return None - @self.router.post("/conversations/{id}/messages") + @self.router.post( + "/conversations/{id}/messages", + summary="Add message to conversation", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.add_message( + "123e4567-e89b-12d3-a456-426614174000", + content="Hello, world!", + parent_id="parent_message_id", + metadata={"key": "value"} +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X POST "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"content": "Hello, world!", "parent_id": "parent_message_id", "metadata": {"key": "value"}}' +""", + }, + ] + }, + ) @self.base_endpoint async def add_message( - id: UUID = Path(...), - message_content: MessageContent = Body(...), + id: UUID = Path( + ..., description="The unique identifier of the conversation" + ), + message_content: MessageContent = Body( + ..., description="The content of the message to add" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> dict: """ Add a new message to a conversation. + + This endpoint adds a new message to an existing conversation. + + Args: + id (UUID): The unique identifier of the conversation. + message_content (MessageContent): The content of the message to add. + auth_user: The authenticated user making the request. + + Returns: + dict: A dictionary containing the ID of the newly added message. + + Raises: + R2RException: If the conversation is not found or the user doesn't have permission to add messages. """ message = Message(content=message_content.content) message_id = await self.services.management.add_message( @@ -119,16 +351,67 @@ async def add_message( ) return {"message_id": message_id} - @self.router.put("/conversations/{id}/messages/{message_id}") + @self.router.put( + "/conversations/{id}/messages/{message_id}", + summary="Update message in conversation", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.update_message( + "123e4567-e89b-12d3-a456-426614174000", + "message_id_to_update", + content="Updated content" +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X PUT "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/messages/message_id_to_update" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"content": "Updated content"}' +""", + }, + ] + }, + ) @self.base_endpoint async def update_message( - id: UUID = Path(...), - message_id: str = Path(...), - content: str = Body(...), + id: UUID = Path( + ..., description="The unique identifier of the conversation" + ), + message_id: str = Path( + ..., description="The ID of the message to update" + ), + content: str = Body( + ..., description="The new content for the message" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> dict: """ Update an existing message in a conversation. + + This endpoint updates the content of an existing message in a conversation. + + Args: + id (UUID): The unique identifier of the conversation. + message_id (str): The ID of the message to update. + content (str): The new content for the message. + auth_user: The authenticated user making the request. + + Returns: + dict: A dictionary containing the new message ID and new branch ID. + + Raises: + R2RException: If the conversation or message is not found, or if the user doesn't have permission to update. """ new_message_id, new_branch_id = ( await self.services.management.edit_message( @@ -140,14 +423,53 @@ async def update_message( "new_branch_id": new_branch_id, } - @self.router.get("/conversations/{id}/branches") + @self.router.get( + "/conversations/{id}/branches", + summary="List branches in conversation", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.conversations.list_branches("123e4567-e89b-12d3-a456-426614174000") +""", + }, + { + "lang": "cURL", + "source": """ +curl -X GET "https://api.example.com/v3/conversations/123e4567-e89b-12d3-a456-426614174000/branches" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def list_branches( - id: UUID = Path(...), + id: UUID = Path( + ..., description="The unique identifier of the conversation" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> dict: """ List all branches in a conversation. + + This endpoint retrieves all branches associated with a specific conversation. + + Args: + id (UUID): The unique identifier of the conversation. + auth_user: The authenticated user making the request. + + Returns: + dict: A dictionary containing a list of branches in the conversation. + + Raises: + R2RException: If the conversation is not found or the user doesn't have permission to view branches. """ branches = await self.services.management.branches_overview( str(id) diff --git a/py/core/main/api/v3/documents_router.py b/py/core/main/api/v3/documents_router.py index ba877cd5d..c90a7490d 100644 --- a/py/core/main/api/v3/documents_router.py +++ b/py/core/main/api/v3/documents_router.py @@ -142,11 +142,11 @@ def _setup_routes(self): "source": textwrap.dedent( """ curl -X POST "https://api.example.com/v3/documents" \\ - -H "Content-Type: multipart/form-data" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -F "file=@pg_essay_1.html;type=text/html" \\ - -F 'metadata={}' \\ - -F 'id=null'""" + -H "Content-Type: multipart/form-data" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -F "file=@pg_essay_1.html;type=text/html" \\ + -F 'metadata={}' \\ + -F 'id=null' """ ), }, ] @@ -279,7 +279,7 @@ async def create_document( await simple_ingestor["ingest-files"](workflow_input) return { # type: ignore "message": "Ingestion task completed successfully.", - "document_id": str(id), + "document_id": str(document_id), "task_id": None, } @@ -307,10 +307,10 @@ async def create_document( "lang": "cURL", "source": textwrap.dedent( """ - curl -X POST "https://api.example.com/document/9fbe403b-c11c-5aae-8ade-ef22980c3ad1" \ - -H "Content-Type: multipart/form-data" \ - -H "Authorization: Bearer YOUR_API_KEY" \ - -F "file=@pg_essay_1.html;type=text/plain" """ + curl -X POST "https://api.example.com/document/9fbe403b-c11c-5aae-8ade-ef22980c3ad1" \\ + -H "Content-Type: multipart/form-data" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -F "file=@pg_essay_1.html;type=text/plain" """ ), }, ] @@ -491,8 +491,8 @@ async def update_document( "lang": "cURL", "source": textwrap.dedent( """ - curl -X GET "https://api.example.com/v3/documents" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/documents" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -575,8 +575,8 @@ async def get_documents( "lang": "cURL", "source": textwrap.dedent( """ - curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -642,8 +642,8 @@ async def get_document( "lang": "cURL", "source": textwrap.dedent( """ - curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1/chunks" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1/chunks" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -744,8 +744,8 @@ async def list_chunks( "lang": "cURL", "source": textwrap.dedent( """ - curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa/download" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -825,8 +825,8 @@ async def file_stream(): "lang": "cURL", "source": textwrap.dedent( """ - curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X DELETE "https://api.example.com/v3/documents/b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -844,12 +844,12 @@ async def delete_document_by_id( """ filters = { "$and": [ - {"$eq": str(auth_user.id)}, + {"user_id": {"$eq": str(auth_user.id)}}, {"document_id": {"$eq": id}}, ] } await self.services["management"].delete(filters=filters) - return None + return True @self.router.delete( "/documents/by-filter", @@ -874,7 +874,7 @@ async def delete_document_by_id( "source": textwrap.dedent( """ curl -X DELETE "https://api.example.com/v3/documents/by-filter?filters=%7B%22document_type%22%3A%7B%22%24eq%22%3A%22text%22%7D%2C%22created_at%22%3A%7B%22%24lt%22%3A%222023-01-01T00%3A00%3A00Z%22%7D%7D" \\ - -H "Authorization: Bearer YOUR_API_KEY" + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, @@ -936,8 +936,8 @@ async def delete_document_by_filter( "lang": "cURL", "source": textwrap.dedent( """ - curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1/collections" \ - -H "Authorization: Bearer YOUR_API_KEY" """ + curl -X GET "https://api.example.com/v3/documents/9fbe403b-c11c-5aae-8ade-ef22980c3ad1/collections" \\ + -H "Authorization: Bearer YOUR_API_KEY" """ ), }, ] @@ -958,7 +958,7 @@ async def get_document_collections( description="The maximum number of collections to retrieve, up to 1,000.", ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[list[CollectionResponse]]: + ) -> PaginatedResultsWrapper[list[CollectionResponse]]: """ Retrieves all collections that contain the specified document. This endpoint is restricted to superusers only and provides a system-wide view of document organization. diff --git a/py/core/main/api/v3/indices_router.py b/py/core/main/api/v3/indices_router.py index 59b49fcb8..15f5d50f0 100644 --- a/py/core/main/api/v3/indices_router.py +++ b/py/core/main/api/v3/indices_router.py @@ -1,3 +1,12 @@ +# TODO - Move indices to 'id' basis +# TODO - Move indices to 'id' basis + +# TODO - Implement update index +# TODO - Implement update index + +# TODO - Implement index data model +# TODO - Implement index data model + import logging from typing import Any, Optional, Union from uuid import UUID @@ -61,6 +70,7 @@ class IndexConfig(BaseModel): # description=create_vector_descriptions.get("concurrently"), # ), # auth_user=Depends(self.service.providers.auth.auth_wrapper), + name: Optional[str] = None table_name: Optional[str] = VectorTableName.VECTORS index_method: Optional[str] = IndexMethod.hnsw index_measure: Optional[str] = IndexMeasure.cosine_distance @@ -85,6 +95,8 @@ def __init__( def _setup_routes(self): + ## TODO - Allow developer to pass the index id with the request + ## TODO - Allow developer to pass the index id with the request @self.router.post( "/indices", summary="Create Vector Index", @@ -182,7 +194,7 @@ def _setup_routes(self): async def create_index( config: IndexConfig = Body( ..., - description="Configuration for the vector index", + description="Configuration for the vector index, acceptable table_name values are 'vectors', 'document_entity', 'document_collections'", example={ "table_name": "vectors", "index_method": "hnsw", @@ -204,7 +216,8 @@ async def create_index( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedCreateVectorIndexResponse: """ - Create a new vector similarity search index in the database. + Create a new vector similarity search index in over the target table. Allowed tables include 'vectors', 'document_entity', 'document_collections'. + Vectors correspond to the chunks of text that are indexed for similarity search, whereas document_entity and document_collections are created during knowledge graph construction. This endpoint creates a database index optimized for efficient similarity search over vector embeddings. It supports two main indexing methods: @@ -238,7 +251,29 @@ async def create_index( - Index names must be unique per table """ # TODO: Implement index creation logic - pass + logger.info( + f"Creating vector index for {config.table_name} with method {config.index_method}, measure {config.index_measure}, concurrently {config.concurrently}" + ) + + raw_message = await self.orchestration_provider.run_workflow( + "create-vector-index", + { + "request": { + "table_name": config.table_name, + "index_method": config.index_method, + "index_measure": config.index_measure, + "index_name": config.index_name, + "index_column": config.index_column, + "index_arguments": config.index_arguments, + "concurrently": config.concurrently, + }, + }, + options={ + "additional_metadata": {}, + }, + ) + + return raw_message # type: ignore @self.router.get( "/indices", @@ -256,7 +291,7 @@ async def create_index( indices = client.indices.list( offset=0, limit=10, - filter_by={"table_name": "vectors"} + filters={"table_name": "vectors"} ) # Print index details @@ -274,7 +309,7 @@ async def create_index( -H "Content-Type: application/json" # With filters -curl -X GET "https://api.example.com/indices?offset=0&limit=10&filter_by={\"table_name\":\"vectors\"}" \\ +curl -X GET "https://api.example.com/indices?offset=0&limit=10&filters={\"table_name\":\"vectors\"}" \\ -H "Authorization: Bearer YOUR_API_KEY" \\ -H "Content-Type: application/json" """, @@ -293,7 +328,7 @@ async def list_indices( le=100, description="Maximum number of records to return", ), - filter_by: Optional[Json[dict]] = Query( + filters: Optional[Json[dict]] = Query( None, description='Filter criteria for indices (e.g., {"table_name": "vectors"})', ), @@ -313,10 +348,13 @@ async def list_indices( based on table name, index method, or other attributes. """ # TODO: Implement index listing logic - pass + indices = await self.providers.database.list_indices( + offset=offset, limit=limit, filters=filters + ) + return {"indices": indices["indices"]}, indices["page_info"] # type: ignore @self.router.get( - "/indices/{id}", + "/indices/{table_name}/{index_name}", summary="Get Vector Index Details", openapi_extra={ "x-codeSamples": [ @@ -347,9 +385,15 @@ async def list_indices( ) @self.base_endpoint async def get_index( - id: UUID = Path(...), + table_name: VectorTableName = Path( + ..., + description="The table of vector embeddings to delete (e.g. `vectors`, `document_entity`, `document_collections`)", + ), + index_name: str = Path( + ..., description="The name of the index to delete" + ), auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedGetIndexResponse: + ) -> dict: # -> WrappedGetIndexResponse: """ Get detailed information about a specific vector index. @@ -368,68 +412,77 @@ async def get_index( * Recommended optimizations """ # TODO: Implement get index logic - pass - - @self.router.put( - "/indices/{id}", - summary="Update Vector Index", - openapi_extra={ - "x-codeSamples": [ - { - "lang": "Python", - "source": """ -from r2r import R2RClient - -client = R2RClient("http://localhost:7272") - -# Update HNSW index parameters -result = client.indices.update( - "550e8400-e29b-41d4-a716-446655440000", - config={ - "index_arguments": { - "ef": 80, # Increase search quality - "m": 24 # Increase connections per layer - }, - "concurrently": True - }, - run_with_orchestration=True -)""", - }, - { - "lang": "Shell", - "source": """ -curl -X PUT "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\ - -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "config": { - "index_arguments": { - "ef": 80, - "m": 24 - }, - "concurrently": true - }, - "run_with_orchestration": true - }'""", - }, - ] - }, - ) - @self.base_endpoint - async def update_index( - id: UUID = Path(...), - config: IndexConfig = Body(...), - run_with_orchestration: Optional[bool] = Body(True), - auth_user=Depends(self.providers.auth.auth_wrapper), - ): # -> WrappedUpdateIndexResponse: - """ - Update an existing index's configuration. - """ - # TODO: Implement index update logic - pass + indices = await self.providers.database.list_indices( + filters={"index_name": index_name, "table_name": table_name} + ) + if len(indices["indices"]) != 1: + raise R2RException( + f"Index '{index_name}' not found", status_code=404 + ) + return {"index": indices["indices"][0]} + + # TODO - Implement update index + # TODO - Implement update index + # @self.router.post( + # "/indices/{name}", + # summary="Update Vector Index", + # openapi_extra={ + # "x-codeSamples": [ + # { + # "lang": "Python", + # "source": """ + # from r2r import R2RClient + + # client = R2RClient("http://localhost:7272") + + # # Update HNSW index parameters + # result = client.indices.update( + # "550e8400-e29b-41d4-a716-446655440000", + # config={ + # "index_arguments": { + # "ef": 80, # Increase search quality + # "m": 24 # Increase connections per layer + # }, + # "concurrently": True + # }, + # run_with_orchestration=True + # )""", + # }, + # { + # "lang": "Shell", + # "source": """ + # curl -X PUT "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\ + # -H "Content-Type: application/json" \\ + # -H "Authorization: Bearer YOUR_API_KEY" \\ + # -d '{ + # "config": { + # "index_arguments": { + # "ef": 80, + # "m": 24 + # }, + # "concurrently": true + # }, + # "run_with_orchestration": true + # }'""", + # }, + # ] + # }, + # ) + # @self.base_endpoint + # async def update_index( + # id: UUID = Path(...), + # config: IndexConfig = Body(...), + # run_with_orchestration: Optional[bool] = Body(True), + # auth_user=Depends(self.providers.auth.auth_wrapper), + # ): # -> WrappedUpdateIndexResponse: + # """ + # Update an existing index's configuration. + # """ + # # TODO: Implement index update logic + # pass @self.router.delete( - "/indices/{id}", + "/indices/{table_name}/{index_name}", summary="Delete Vector Index", openapi_extra={ "x-codeSamples": [ @@ -442,27 +495,34 @@ async def update_index( # Delete an index with orchestration for cleanup result = client.indices.delete( - "550e8400-e29b-41d4-a716-446655440000", + index_name="index_1", run_with_orchestration=True )""", }, { "lang": "Shell", "source": """ -curl -X DELETE "https://api.example.com/indices/550e8400-e29b-41d4-a716-446655440000" \\ +curl -X DELETE "https://api.example.com/indices/index_1" \\ -H "Content-Type: application/json" \\ - -H "Authorization: Bearer YOUR_API_KEY" \\ - -d '{ - "run_with_orchestration": true - }'""", + -H "Authorization: Bearer YOUR_API_KEY" """, }, ] }, ) @self.base_endpoint async def delete_index( - id: UUID = Path(...), - run_with_orchestration: Optional[bool] = Body(True), + table_name: VectorTableName = Path( + default=..., + description="The table of vector embeddings to delete (e.g. `vectors`, `document_entity`, `document_collections`)", + ), + index_name: str = Path( + ..., description="The name of the index to delete" + ), + # concurrently: bool = Body( + # default=True, + # description="Whether to delete the index concurrently (recommended for large indices)", + # ), + # run_with_orchestration: Optional[bool] = Body(True), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedDeleteVectorIndexResponse: """ @@ -477,8 +537,24 @@ async def delete_index( - Use run_with_orchestration=True for large indices to prevent timeouts - Consider index dependencies before deletion - The operation returns immediately but cleanup may continue in background - when run_with_orchestration=True. + The operation returns immediately but cleanup may continue in background. """ - # TODO: Implement index deletion logic - pass + logger.info( + f"Deleting vector index {index_name} from table {table_name}" + ) + + raw_message = await self.orchestration_provider.run_workflow( + "delete-vector-index", + { + "request": { + "index_name": index_name, + "table_name": table_name, + "concurrently": True, + }, + }, + options={ + "additional_metadata": {}, + }, + ) + + return raw_message # type: ignore diff --git a/py/core/main/api/v3/prompts_router.py b/py/core/main/api/v3/prompts_router.py index 358a565ff..af468b286 100644 --- a/py/core/main/api/v3/prompts_router.py +++ b/py/core/main/api/v3/prompts_router.py @@ -2,10 +2,11 @@ from uuid import UUID from fastapi import Body, Depends, Path, Query -from pydantic import BaseModel, Json +from pydantic import BaseModel, Field, Json from core.base import R2RException, RunType from core.base.api.models import ( + ResultsWrapper, WrappedDeleteResponse, WrappedGetPromptsResponse, WrappedPromptMessageResponse, @@ -19,9 +20,14 @@ class PromptConfig(BaseModel): - name: str - template: str - input_types: dict[str, str] = {} + name: str = Field(..., description="The name of the prompt") + template: str = Field( + ..., description="The template string for the prompt" + ) + input_types: dict[str, str] = Field( + default={}, + description="A dictionary mapping input names to their types", + ) class PromptsRouter(BaseRouterV3): @@ -37,44 +43,152 @@ def __init__( super().__init__(providers, services, orchestration_provider, run_type) def _setup_routes(self): - @self.router.post("/prompts") + @self.router.post( + "/prompts", + summary="Create a new prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.prompts.create( + name="greeting_prompt", + template="Hello, {name}!", + input_types={"name": "string"} +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X POST "https://api.example.com/v3/prompts" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"name": "greeting_prompt", "template": "Hello, {name}!", "input_types": {"name": "string"}}' +""", + }, + ] + }, + ) @self.base_endpoint async def create_prompt( - config: PromptConfig = Body(...), + config: PromptConfig = Body( + ..., description="The configuration for the new prompt" + ), auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedPromptMessageResponse: """ Create a new prompt with the given configuration. + + This endpoint allows superusers to create a new prompt with a specified name, template, and input types. + + Args: + config (PromptConfig): The configuration for the new prompt. + auth_user: The authenticated user making the request. + + Returns: + WrappedPromptMessageResponse: Details of the newly created prompt. + + Raises: + R2RException: If the user is not a superuser or if there's an error in creating the prompt. """ if not auth_user.is_superuser: raise R2RException( "Only a superuser can create prompts.", 403, ) - result = await self.services.add_prompt( config.name, config.template, config.input_types ) return result - @self.router.get("/prompts") + @self.router.get( + "/prompts", + summary="List all prompts", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.prompts.list() +""", + }, + { + "lang": "cURL", + "source": """ +curl -X GET "https://api.example.com/v3/prompts" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def list_prompts( auth_user=Depends(self.providers.auth.auth_wrapper), ) -> WrappedGetPromptsResponse: """ List all available prompts. + + This endpoint retrieves a list of all prompts in the system. Only superusers can access this endpoint. + + Args: + auth_user: The authenticated user making the request. + + Returns: + WrappedGetPromptsResponse: A list of all available prompts. + + Raises: + R2RException: If the user is not a superuser. """ if not auth_user.is_superuser: raise R2RException( "Only a superuser can list prompts.", 403, ) - result = await self.services.get_all_prompts() return {"prompts": result} - @self.router.get("/prompts/{name}") + @self.router.get( + "/prompts/{name}", + summary="Get a specific prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.prompts.get( + "greeting_prompt", + inputs={"name": "John"}, + prompt_override="Hi, {name}!" +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X GET "https://api.example.com/v3/prompts/greeting_prompt?inputs=%7B%22name%22%3A%22John%22%7D&prompt_override=Hi%2C%20%7Bname%7D!" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def get_prompt( name: str = Path(..., description="Prompt name"), @@ -88,19 +202,64 @@ async def get_prompt( ) -> WrappedPromptMessageResponse: """ Get a specific prompt by name, optionally with inputs and override. + + This endpoint retrieves a specific prompt and allows for optional inputs and template override. + Only superusers can access this endpoint. + + Args: + name (str): The name of the prompt to retrieve. + inputs (dict, optional): JSON-encoded inputs for the prompt. + prompt_override (str, optional): An override for the prompt template. + auth_user: The authenticated user making the request. + + Returns: + WrappedPromptMessageResponse: The requested prompt with applied inputs and/or override. + + Raises: + R2RException: If the user is not a superuser or if the prompt is not found. """ if not auth_user.is_superuser: raise R2RException( "Only a superuser can retrieve prompts.", 403, ) - result = await self.services.get_prompt( name, inputs, prompt_override ) return result - @self.router.put("/prompts/{name}") + @self.router.put( + "/prompts/{name}", + summary="Update an existing prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.prompts.update( + "greeting_prompt", + template="Greetings, {name}!", + input_types={"name": "string", "age": "integer"} +) +""", + }, + { + "lang": "cURL", + "source": """ +curl -X PUT "https://api.example.com/v3/prompts/greeting_prompt" \\ + -H "Authorization: Bearer YOUR_API_KEY" \\ + -H "Content-Type: application/json" \\ + -d '{"template": "Greetings, {name}!", "input_types": {"name": "string", "age": "integer"}}' +""", + }, + ] + }, + ) @self.base_endpoint async def update_prompt( name: str = Path(..., description="Prompt name"), @@ -114,32 +273,81 @@ async def update_prompt( ) -> WrappedPromptMessageResponse: """ Update an existing prompt's template and/or input types. + + This endpoint allows superusers to update the template and input types of an existing prompt. + + Args: + name (str): The name of the prompt to update. + template (str, optional): The updated template string for the prompt. + input_types (dict, optional): The updated dictionary mapping input names to their types. + auth_user: The authenticated user making the request. + + Returns: + WrappedPromptMessageResponse: The updated prompt details. + + Raises: + R2RException: If the user is not a superuser or if the prompt is not found. """ if not auth_user.is_superuser: raise R2RException( "Only a superuser can update prompts.", 403, ) - result = await self.services.update_prompt( name, template, input_types ) return result - @self.router.delete("/prompts/{name}") + @self.router.delete( + "/prompts/{name}", + summary="Delete a prompt", + openapi_extra={ + "x-codeSamples": [ + { + "lang": "Python", + "source": """ +from r2r import R2RClient + +client = R2RClient("http://localhost:7272") +# when using auth, do client.login(...) + +result = client.prompts.delete("greeting_prompt") +""", + }, + { + "lang": "cURL", + "source": """ +curl -X DELETE "https://api.example.com/v3/prompts/greeting_prompt" \\ + -H "Authorization: Bearer YOUR_API_KEY" +""", + }, + ] + }, + ) @self.base_endpoint async def delete_prompt( name: str = Path(..., description="Prompt name"), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> WrappedDeleteResponse: + ) -> ResultsWrapper[bool]: """ Delete a prompt by name. + + This endpoint allows superusers to delete an existing prompt. + + Args: + name (str): The name of the prompt to delete. + auth_user: The authenticated user making the request. + + Returns: + WrappedDeleteResponse: Confirmation of the deletion. + + Raises: + R2RException: If the user is not a superuser or if the prompt is not found. """ if not auth_user.is_superuser: raise R2RException( "Only a superuser can delete prompts.", 403, ) - await self.services.delete_prompt(name) return None diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index a8a4c7ec0..b593205d6 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -197,7 +197,7 @@ async def search_app( auth_user, kg_search_settings ) - results = await self.service.search( + results = await self.services["retrieval"].search( query=query, vector_search_settings=vector_search_settings, kg_search_settings=kg_search_settings, @@ -313,7 +313,7 @@ async def rag_app( auth_user, vector_search_settings ) - response = await self.service.rag( + response = await self.services["retrieval"].rag( query=query, vector_search_settings=vector_search_settings, kg_search_settings=kg_search_settings, @@ -486,7 +486,7 @@ async def agent_app( kg_search_settings.filters = vector_search_settings.filters try: - response = await self.service.agent( + response = await self.services["retrieval"].agent( message=message, messages=messages, vector_search_settings=vector_search_settings, @@ -613,7 +613,7 @@ async def completion( system message at the start. Each message should have a 'role' and 'content'. """ - return await self.service.completion( + return await self.services["retrieval"].completion( messages=messages, generation_config=generation_config, ) diff --git a/py/core/main/api/v3/users_router.py b/py/core/main/api/v3/users_router.py index 4e6d70c3b..f994b4297 100644 --- a/py/core/main/api/v3/users_router.py +++ b/py/core/main/api/v3/users_router.py @@ -267,7 +267,7 @@ def _setup_routes(self): # client.login(...) # List users with filters -users = client.list_users( +users = client.users.list( offset=0, limit=100, username="john", @@ -340,7 +340,9 @@ async def list_users( # client.login(...) # Get user details -user = client.get_user("550e8400-e29b-41d4-a716-446655440000") +users = users.retrieve( + id="b4ac4dd6-5f27-596e-a55b-7cf242ca30aa" +) """, }, { @@ -502,7 +504,7 @@ async def remove_user_from_collection( ..., example="750e8400-e29b-41d4-a716-446655440000" ), auth_user=Depends(self.providers.auth.auth_wrapper), - ) -> ResultsWrapper[None]: + ) -> ResultsWrapper[bool]: """ Remove a user from a collection. Requires either superuser status or access to the collection. diff --git a/py/core/main/orchestration/simple/ingestion_workflow.py b/py/core/main/orchestration/simple/ingestion_workflow.py index 5b3a40209..5b5293d50 100644 --- a/py/core/main/orchestration/simple/ingestion_workflow.py +++ b/py/core/main/orchestration/simple/ingestion_workflow.py @@ -187,10 +187,12 @@ async def ingest_chunks(input_data): from core.base import IngestionStatus from core.main import IngestionServiceAdapter + print("input_data = ", input_data) parsed_data = IngestionServiceAdapter.parse_ingest_chunks_input( input_data ) + print("parsed_data = ", parsed_data) document_info = await service.ingest_chunks_ingress(**parsed_data) await service.update_document_status( @@ -200,7 +202,11 @@ async def ingest_chunks(input_data): extractions = [ DocumentChunk( - id=generate_extraction_id(document_id, i), + id=( + generate_extraction_id(document_id, i) + if chunk.id is None + else chunk.id + ), document_id=document_id, collection_ids=[], user_id=document_info.user_id, diff --git a/py/core/main/services/ingestion_service.py b/py/core/main/services/ingestion_service.py index 0e7037fa4..3d62f03df 100644 --- a/py/core/main/services/ingestion_service.py +++ b/py/core/main/services/ingestion_service.py @@ -15,6 +15,7 @@ R2RException, RawChunk, RunManager, + UnprocessedChunk, Vector, VectorEntry, VectorType, @@ -655,7 +656,10 @@ def parse_ingest_chunks_input(data: dict) -> dict: "user": IngestionServiceAdapter._parse_user_data(data["user"]), "metadata": data["metadata"], "document_id": data["document_id"], - "chunks": [RawChunk.from_dict(chunk) for chunk in data["chunks"]], + "chunks": [ + UnprocessedChunk.from_dict(chunk) for chunk in data["chunks"] + ], + "id": data.get("id"), } @staticmethod diff --git a/py/core/main/services/management_service.py b/py/core/main/services/management_service.py index 5779c9bdc..e4ac374a9 100644 --- a/py/core/main/services/management_service.py +++ b/py/core/main/services/management_service.py @@ -228,6 +228,8 @@ async def delete( NOTE: This method is not atomic and may result in orphaned entries in the documents overview table. NOTE: This method assumes that filters delete entire contents of any touched documents. """ + ### TODO - FIX THIS, ENSURE THAT DOCUMENTS OVERVIEW IS CLEARED + ### TODO - FIX THIS, ENSURE THAT DOCUMENTS OVERVIEW IS CLEARED def validate_filters(filters: dict[str, Any]) -> None: ALLOWED_FILTERS = { @@ -235,6 +237,10 @@ def validate_filters(filters: dict[str, Any]) -> None: "user_id", "collection_ids", "chunk_id", + # TODO - Modify these checks such that they can be used PROPERLY for nested filters + # TODO - Modify these checks such that they can be used PROPERLY for nested filters + "$and", + "$or", } if not filters: diff --git a/py/core/providers/database/collection.py b/py/core/providers/database/collection.py index 7dff5d8b6..99509dfb7 100644 --- a/py/core/providers/database/collection.py +++ b/py/core/providers/database/collection.py @@ -14,8 +14,8 @@ from core.base.abstractions import DocumentInfo, DocumentType, IngestionStatus from core.base.api.models import CollectionOverviewResponse, CollectionResponse from core.utils import ( - generate_collection_id_from_name, generate_default_user_collection_id, + generate_id_from_label, ) from .base import PostgresConnectionManager @@ -58,7 +58,7 @@ async def create_default_collection( user_id ) else: - default_collection_uuid = generate_collection_id_from_name( + default_collection_uuid = generate_id_from_label( self.config.default_collection_name ) diff --git a/py/core/providers/database/vector.py b/py/core/providers/database/vector.py index 2fd3586aa..d2d650c23 100644 --- a/py/core/providers/database/vector.py +++ b/py/core/providers/database/vector.py @@ -19,6 +19,7 @@ VectorSearchResult, VectorSearchSettings, VectorTableName, + generate_id_from_label, ) from .base import PostgresConnectionManager @@ -995,75 +996,108 @@ def parse_filter(filter_dict: dict) -> str: return where_clause async def list_indices( - self, table_name: Optional[VectorTableName] = None - ) -> list[dict[str, Any]]: - """ - Lists all vector indices for the specified table. + self, + offset: int = 0, + limit: int = 10, + filters: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + where_clauses = [] + params: list[Any] = [self.project_name] # Start with schema name + param_count = 1 - Args: - table_name (VectorTableName, optional): The table to list indices for. - If None, defaults to VECTORS table. + # Handle filtering + if filters: + if "table_name" in filters: + where_clauses.append(f"i.tablename = ${param_count + 1}") + params.append(filters["table_name"]) + param_count += 1 + if "index_method" in filters: + where_clauses.append(f"am.amname = ${param_count + 1}") + params.append(filters["index_method"]) + param_count += 1 + if "index_name" in filters: + where_clauses.append( + f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})" + ) + params.append(f"%{filters['index_name']}%") + param_count += 1 - Returns: - List[dict]: List of indices with their properties + where_clause = " AND ".join(where_clauses) if where_clauses else "" + if where_clause: + where_clause = "AND " + where_clause - Raises: - ArgError: If an invalid table name is provided + query = f""" + WITH index_info AS ( + SELECT + i.indexname as name, + i.tablename as table_name, + i.indexdef as definition, + am.amname as method, + pg_relation_size(c.oid) as size_in_bytes, + c.reltuples::bigint as row_estimate, + COALESCE(psat.idx_scan, 0) as number_of_scans, + COALESCE(psat.idx_tup_read, 0) as tuples_read, + COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched, + COUNT(*) OVER() as total_count + FROM pg_indexes i + JOIN pg_class c ON c.relname = i.indexname + JOIN pg_am am ON c.relam = am.oid + LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname + AND psat.schemaname = i.schemaname + WHERE i.schemaname = $1 + AND i.indexdef LIKE '%vector%' + {where_clause} + ) + SELECT * + FROM index_info + ORDER BY name + LIMIT ${param_count + 1} + OFFSET ${param_count + 2} """ - if table_name == VectorTableName.VECTORS: - table_name_str = f"{self.project_name}.{VectorTableName.VECTORS}" - col_name = "vec" - elif table_name == VectorTableName.ENTITIES_DOCUMENT: - table_name_str = ( - f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}" - ) - col_name = "description_embedding" - elif table_name == VectorTableName.ENTITIES_COLLECTION: - table_name_str = ( - f"{self.project_name}.{VectorTableName.ENTITIES_COLLECTION}" - ) - elif table_name == VectorTableName.COMMUNITIES: - table_name_str = ( - f"{self.project_name}.{VectorTableName.COMMUNITIES}" - ) - col_name = "embedding" - else: - raise ArgError("invalid table name") - query = """ - SELECT - i.indexname as name, - i.indexdef as definition, - am.amname as method, - pg_relation_size(c.oid) as size_in_bytes, - COALESCE(psat.idx_scan, 0) as number_of_scans, - COALESCE(psat.idx_tup_read, 0) as tuples_read, - COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched - FROM pg_indexes i - JOIN pg_class c ON c.relname = i.indexname - JOIN pg_am am ON c.relam = am.oid - LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname - AND psat.schemaname = i.schemaname - WHERE i.schemaname || '.' || i.tablename = $1 - AND i.indexdef LIKE $2; - """ + # Add limit and offset to params + params.extend([limit, offset]) - results = await self.connection_manager.fetch_query( - query, (table_name_str, f"%({col_name}%") - ) + results = await self.connection_manager.fetch_query(query, params) - return [ - { - "name": result["name"], - "definition": result["definition"], - "method": result["method"], - "size_in_bytes": result["size_in_bytes"], - "number_of_scans": result["number_of_scans"], - "tuples_read": result["tuples_read"], - "tuples_fetched": result["tuples_fetched"], - } - for result in results - ] + indices = [] + total_entries = 0 + + if results: + total_entries = results[0]["total_count"] + for result in results: + index_info = { + "name": result["name"], + "table_name": result["table_name"], + "definition": result["definition"], + "method": result["method"], + "size_in_bytes": result["size_in_bytes"], + "row_estimate": result["row_estimate"], + "number_of_scans": result["number_of_scans"], + "tuples_read": result["tuples_read"], + "tuples_fetched": result["tuples_fetched"], + } + indices.append(index_info) + + # Calculate pagination info + total_pages = (total_entries + limit - 1) // limit if limit > 0 else 1 + current_page = (offset // limit) + 1 if limit > 0 else 1 + + page_info = { + "total_entries": total_entries, + "total_pages": total_pages, + "current_page": current_page, + "limit": limit, + "offset": offset, + "has_previous": offset > 0, + "has_next": offset + limit < total_entries, + "previous_offset": max(0, offset - limit) if offset > 0 else None, + "next_offset": ( + offset + limit if offset + limit < total_entries else None + ), + } + + return {"indices": indices, "page_info": page_info} async def delete_index( self, diff --git a/py/core/utils/__init__.py b/py/core/utils/__init__.py index d885198b2..51e6bb958 100644 --- a/py/core/utils/__init__.py +++ b/py/core/utils/__init__.py @@ -4,11 +4,11 @@ format_relations, format_search_results_for_llm, format_search_results_for_stream, - generate_collection_id_from_name, generate_default_user_collection_id, generate_document_id, generate_extraction_id, generate_id, + generate_id_from_label, generate_user_id, increment_version, run_pipeline, @@ -29,7 +29,7 @@ "generate_id", "generate_document_id", "generate_extraction_id", - "generate_collection_id_from_name", + "generate_id_from_label", "generate_user_id", "increment_version", "decrement_version", diff --git a/py/sdk/async_client.py b/py/sdk/async_client.py index 27aa1cd0e..1acbfb43b 100644 --- a/py/sdk/async_client.py +++ b/py/sdk/async_client.py @@ -16,6 +16,9 @@ ) from .v3.chunks import ChunksSDK from .v3.documents import DocumentsSDK +from .v3.indices import IndicesSDK +from .v3.retrieval import RetrievalSDK +from .v3.users import UsersSDK class R2RAsyncClient( @@ -48,6 +51,9 @@ def __init__( self.client = custom_client or httpx.AsyncClient(timeout=timeout) self.documents = DocumentsSDK(self) self.chunks = ChunksSDK(self) + self.retrieval = RetrievalSDK(self) + self.indices = IndicesSDK(self) + self.users = UsersSDK(self) async def _make_request(self, method: str, endpoint: str, **kwargs): url = self._get_full_url(endpoint) diff --git a/py/sdk/sync_client.py b/py/sdk/sync_client.py index dbd7599c3..85fb7e07e 100644 --- a/py/sdk/sync_client.py +++ b/py/sdk/sync_client.py @@ -4,6 +4,9 @@ from .utils import SyncClientMetaclass from .v3.chunks import SyncChunkSDK from .v3.documents import SyncDocumentSDK +from .v3.indices import SyncIndexSDK +from .v3.retrieval import SyncRetrievalSDK +from .v3.users import SyncUsersSDK class R2RClient(R2RAsyncClient, metaclass=SyncClientMetaclass): @@ -21,6 +24,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.documents = SyncDocumentSDK(self.documents) self.chunks = SyncChunkSDK(self.chunks) + self.retrieval = SyncRetrievalSDK(self.retrieval) + self.indices = SyncIndexSDK(self.indices) + self.users = SyncUsersSDK(self.users) def _make_streaming_request(self, method: str, endpoint: str, **kwargs): async_gen = super()._make_streaming_request(method, endpoint, **kwargs) diff --git a/py/sdk/v3/chunks.py b/py/sdk/v3/chunks.py index e9b5400cc..a5c485777 100644 --- a/py/sdk/v3/chunks.py +++ b/py/sdk/v3/chunks.py @@ -171,7 +171,7 @@ async def delete( """ await self.client._make_request("DELETE", f"chunks/{str(id)}") - async def list_chunks( + async def list( self, offset: int = 0, limit: int = 10, diff --git a/py/sdk/v3/documents.py b/py/sdk/v3/documents.py index d26bc8ea2..ed2f7f7bb 100644 --- a/py/sdk/v3/documents.py +++ b/py/sdk/v3/documents.py @@ -203,7 +203,9 @@ async def delete( Args: id (Union[str, UUID]): ID of document to delete """ - await self.client._make_request("DELETE", f"documents/{str(id)}") + return await self.client._make_request( + "DELETE", f"documents/{str(id)}" + ) async def list_chunks( self, @@ -274,7 +276,7 @@ async def delete_by_filter( filters (Dict[str, Any]): Filters to apply when selecting documents to delete """ filters_json = json.dumps(filters) - await self.client._make_request( + return await self.client._make_request( "DELETE", "documents/by-filter", params={"filters": filters_json} ) diff --git a/py/sdk/v3/indices.py b/py/sdk/v3/indices.py new file mode 100644 index 000000000..64236302a --- /dev/null +++ b/py/sdk/v3/indices.py @@ -0,0 +1,150 @@ +import json +import logging +from inspect import getmembers, isasyncgenfunction, iscoroutinefunction +from typing import AsyncGenerator, Optional, Union +from uuid import UUID + +from ..base.base_client import sync_generator_wrapper, sync_wrapper + +# from ..models import ( +# IndexConfig, +# WrappedCreateVectorIndexResponse, +# WrappedListVectorIndicesResponse, +# WrappedGetIndexResponse, +# WrappedUpdateIndexResponse, +# WrappedDeleteVectorIndexResponse, +# ) + +logger = logging.getLogger() + + +class IndicesSDK: + def __init__(self, client): + self.client = client + + async def create_index( + self, + config: dict, # Union[dict, IndexConfig], + run_with_orchestration: Optional[bool] = True, + ) -> dict: + """ + Create a new vector similarity search index in the database. + + Args: + config (Union[dict, IndexConfig]): Configuration for the vector index. + run_with_orchestration (Optional[bool]): Whether to run index creation as an orchestrated task. + + Returns: + WrappedCreateVectorIndexResponse: The response containing the created index details. + """ + if not isinstance(config, dict): + config = config.model_dump() + + data = { + "config": config, + "run_with_orchestration": run_with_orchestration, + } + return await self.client._make_request("POST", "indices", json=data) # type: ignore + + async def list_indices( + self, + offset: int = 0, + limit: int = 10, + filters: Optional[dict] = None, + ) -> dict: + """ + List existing vector similarity search indices with pagination support. + + Args: + offset (int): Number of records to skip. + limit (int): Maximum number of records to return. + filters (Optional[dict]): Filter criteria for indices. + + Returns: + WrappedListVectorIndicesResponse: The response containing the list of indices. + """ + params = { + "offset": offset, + "limit": limit, + } + if filters: + params["filters"] = json.dumps(filters) + return await self.client._make_request("GET", "indices", params=params) # type: ignore + + async def get_index( + self, + index_name: str, + table_name: str = "vectors", + ) -> dict: + """ + Get detailed information about a specific vector index. + + Args: + id (Union[str, UUID]): The ID of the index to retrieve. + + Returns: + WrappedGetIndexResponse: The response containing the index details. + """ + return await self.client._make_request("GET", f"indices/{table_name}/{index_name}") # type: ignore + + # async def update_index( + # self, + # id: Union[str, UUID], + # config: dict, # Union[dict, IndexConfig], + # run_with_orchestration: Optional[bool] = True, + # ) -> dict: + # """ + # Update an existing index's configuration. + + # Args: + # id (Union[str, UUID]): The ID of the index to update. + # config (Union[dict, IndexConfig]): The new configuration for the index. + # run_with_orchestration (Optional[bool]): Whether to run the update as an orchestrated task. + + # Returns: + # WrappedUpdateIndexResponse: The response containing the updated index details. + # """ + # if not isinstance(config, dict): + # config = config.model_dump() + + # data = { + # "config": config, + # "run_with_orchestration": run_with_orchestration, + # } + # return await self.client._make_request("POST", f"indices/{id}", json=data) # type: ignore + + async def delete_index( + self, + index_name: str, + table_name: str = "vectors", + ) -> dict: + """ + Get detailed information about a specific vector index. + + Args: + id (Union[str, UUID]): The ID of the index to retrieve. + + Returns: + WrappedGetIndexResponse: The response containing the index details. + """ + return await self.client._make_request("DELETE", f"indices/{table_name}/{index_name}") # type: ignore + + +class SyncIndexSDK: + """Synchronous wrapper for DocumentsSDK""" + + def __init__(self, async_sdk: IndicesSDK): + self._async_sdk = async_sdk + + # Get all attributes from the instance + for name in dir(async_sdk): + if not name.startswith("_"): # Skip private methods + attr = getattr(async_sdk, name) + # Check if it's a method and if it's async + if callable(attr) and ( + iscoroutinefunction(attr) or isasyncgenfunction(attr) + ): + if isasyncgenfunction(attr): + setattr(self, name, sync_generator_wrapper(attr)) + else: + setattr(self, name, sync_wrapper(attr)) diff --git a/py/sdk/v3/retrieval.py b/py/sdk/v3/retrieval.py new file mode 100644 index 000000000..d62a13e2f --- /dev/null +++ b/py/sdk/v3/retrieval.py @@ -0,0 +1,226 @@ +import logging +from inspect import getmembers, isasyncgenfunction, iscoroutinefunction +from typing import AsyncGenerator, Optional, Union + +from ..base.base_client import sync_generator_wrapper, sync_wrapper +from ..models import ( + CombinedSearchResponse, + GenerationConfig, + KGSearchSettings, + Message, + RAGResponse, + VectorSearchSettings, +) + +logger = logging.getLogger() + + +class RetrievalSDK: + """ + SDK for interacting with documents in the v3 API. + """ + + def __init__(self, client): + self.client = client + + async def search( + self, + query: str, + vector_search_settings: Optional[ + Union[dict, VectorSearchSettings] + ] = None, + kg_search_settings: Optional[Union[dict, KGSearchSettings]] = None, + ) -> CombinedSearchResponse: + """ + Conduct a vector and/or KG search. + + Args: + query (str): The query to search for. + vector_search_settings (Optional[Union[dict, VectorSearchSettings]]): Vector search settings. + kg_search_settings (Optional[Union[dict, KGSearchSettings]]): KG search settings. + + Returns: + CombinedSearchResponse: The search response. + """ + if vector_search_settings and not isinstance( + vector_search_settings, dict + ): + vector_search_settings = vector_search_settings.model_dump() + if kg_search_settings and not isinstance(kg_search_settings, dict): + kg_search_settings = kg_search_settings.model_dump() + + data = { + "query": query, + "vector_search_settings": vector_search_settings, + "kg_search_settings": kg_search_settings, + } + return await self.client._make_request("POST", "retrieval/search", json=data) # type: ignore + + async def completion( + self, + messages: list[Union[dict, Message]], + generation_config: Optional[Union[dict, GenerationConfig]] = None, + ): + cast_messages: list[Message] = [ + Message(**msg) if isinstance(msg, dict) else msg + for msg in messages + ] + + if generation_config and not isinstance(generation_config, dict): + generation_config = generation_config.model_dump() + + data = { + "messages": [msg.model_dump() for msg in cast_messages], + "generation_config": generation_config, + } + + return await self.client._make_request("POST", "retrieval/completion", json=data) # type: ignore + + async def rag( + self, + query: str, + rag_generation_config: Optional[Union[dict, GenerationConfig]] = None, + vector_search_settings: Optional[ + Union[dict, VectorSearchSettings] + ] = None, + kg_search_settings: Optional[Union[dict, KGSearchSettings]] = None, + task_prompt_override: Optional[str] = None, + include_title_if_available: Optional[bool] = False, + ) -> Union[RAGResponse, AsyncGenerator[RAGResponse, None]]: + """ + Conducts a Retrieval Augmented Generation (RAG) search with the given query. + + Args: + query (str): The query to search for. + rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration. + vector_search_settings (Optional[Union[dict, VectorSearchSettings]]): Vector search settings. + kg_search_settings (Optional[Union[dict, KGSearchSettings]]): KG search settings. + task_prompt_override (Optional[str]): Task prompt override. + include_title_if_available (Optional[bool]): Include the title if available. + + Returns: + Union[RAGResponse, AsyncGenerator[RAGResponse, None]]: The RAG response + """ + if rag_generation_config and not isinstance( + rag_generation_config, dict + ): + rag_generation_config = rag_generation_config.model_dump() + if vector_search_settings and not isinstance( + vector_search_settings, dict + ): + vector_search_settings = vector_search_settings.model_dump() + if kg_search_settings and not isinstance(kg_search_settings, dict): + kg_search_settings = kg_search_settings.model_dump() + + data = { + "query": query, + "rag_generation_config": rag_generation_config, + "vector_search_settings": vector_search_settings, + "kg_search_settings": kg_search_settings, + "task_prompt_override": task_prompt_override, + "include_title_if_available": include_title_if_available, + } + + if rag_generation_config and rag_generation_config.get( # type: ignore + "stream", False + ): + return self._make_streaming_request("POST", "retrieval/rag", json=data) # type: ignore + else: + return await self.client._make_request("POST", "retrieval/rag", json=data) # type: ignore + + async def agent( + self, + message: Optional[Union[dict, Message]] = None, + rag_generation_config: Optional[Union[dict, GenerationConfig]] = None, + vector_search_settings: Optional[ + Union[dict, VectorSearchSettings] + ] = None, + kg_search_settings: Optional[Union[dict, KGSearchSettings]] = None, + task_prompt_override: Optional[str] = None, + include_title_if_available: Optional[bool] = False, + conversation_id: Optional[str] = None, + branch_id: Optional[str] = None, + # TODO - Deprecate messages + messages: Optional[Union[dict, Message]] = None, + ) -> Union[list[Message], AsyncGenerator[Message, None]]: + """ + Performs a single turn in a conversation with a RAG agent. + + Args: + messages (List[Union[dict, Message]]): The messages to send to the agent. + rag_generation_config (Optional[Union[dict, GenerationConfig]]): RAG generation configuration. + vector_search_settings (Optional[Union[dict, VectorSearchSettings]]): Vector search settings. + kg_search_settings (Optional[Union[dict, KGSearchSettings]]): KG search settings. + task_prompt_override (Optional[str]): Task prompt override. + include_title_if_available (Optional[bool]): Include the title if available. + + Returns: + Union[List[Message], AsyncGenerator[Message, None]]: The agent response. + """ + if messages: + logger.warning( + "The `messages` argument is deprecated. Please use `message` instead." + ) + if rag_generation_config and not isinstance( + rag_generation_config, dict + ): + rag_generation_config = rag_generation_config.model_dump() + if vector_search_settings and not isinstance( + vector_search_settings, dict + ): + vector_search_settings = vector_search_settings.model_dump() + if kg_search_settings and not isinstance(kg_search_settings, dict): + kg_search_settings = kg_search_settings.model_dump() + + data = { + "rag_generation_config": rag_generation_config or {}, + "vector_search_settings": vector_search_settings or {}, + "kg_search_settings": kg_search_settings, + "task_prompt_override": task_prompt_override, + "include_title_if_available": include_title_if_available, + "conversation_id": conversation_id, + "branch_id": branch_id, + } + + if message: + cast_message: Message = ( + Message(**message) if isinstance(message, dict) else message + ) + data["message"] = cast_message.model_dump() + + if messages: + data["messages"] = [ + ( + Message(**msg).model_dump() # type: ignore + if isinstance(msg, dict) + else msg.model_dump() # type: ignore + ) + for msg in messages + ] + + if rag_generation_config and rag_generation_config.get( # type: ignore + "stream", False + ): + return self._make_streaming_request("POST", "retrieval/agent", json=data) # type: ignore + else: + return await self.client._make_request("POST", "retrieval/agent", json=data) # type: ignore + + +class SyncRetrievalSDK: + """Synchronous wrapper for ChunksSDK""" + + def __init__(self, async_sdk: RetrievalSDK): + self._async_sdk = async_sdk + + # Get all attributes from the instance + for name in dir(async_sdk): + if not name.startswith("_"): # Skip private methods + attr = getattr(async_sdk, name) + # Check if it's a method and if it's async + if callable(attr) and ( + iscoroutinefunction(attr) or isasyncgenfunction(attr) + ): + if isasyncgenfunction(attr): + setattr(self, name, sync_generator_wrapper(attr)) + else: + setattr(self, name, sync_wrapper(attr)) diff --git a/py/sdk/v3/users.py b/py/sdk/v3/users.py new file mode 100644 index 000000000..5c2f1184d --- /dev/null +++ b/py/sdk/v3/users.py @@ -0,0 +1,196 @@ +import json +from inspect import getmembers, isasyncgenfunction, iscoroutinefunction +from typing import Any, Dict, List, Optional, Union +from uuid import UUID + +from ..base.base_client import sync_generator_wrapper, sync_wrapper + + +class UsersSDK: + """ + SDK for interacting with users in the v3 API. + """ + + def __init__(self, client): + self.client = client + + async def list( + self, + offset: int = 0, + limit: int = 100, + username: Optional[str] = None, + email: Optional[str] = None, + is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = "desc", + ) -> dict: + """ + List users with pagination and filtering options. + + Args: + offset (int): Number of records to skip + limit (int): Maximum number of records to return + username (Optional[str]): Filter by username (partial match) + email (Optional[str]): Filter by email (partial match) + is_active (Optional[bool]): Filter by active status + is_superuser (Optional[bool]): Filter by superuser status + sort_by (Optional[str]): Field to sort by (created_at, username, email) + sort_order (Optional[str]): Sort order (asc or desc) + + Returns: + dict: List of users and pagination information + """ + params = { + "offset": offset, + "limit": limit, + "sort_order": sort_order, + } + + if username: + params["username"] = username + if email: + params["email"] = email + if is_active is not None: + params["is_active"] = is_active + if is_superuser is not None: + params["is_superuser"] = is_superuser + if sort_by: + params["sort_by"] = sort_by + + return await self.client._make_request("GET", "users", params=params) + + async def retrieve( + self, + id: Union[str, UUID], + ) -> dict: + """ + Get detailed information about a specific user. + + Args: + id (Union[str, UUID]): User ID to retrieve + + Returns: + dict: Detailed user information + """ + return await self.client._make_request("GET", f"users/{str(id)}") + + async def update( + self, + id: Union[str, UUID], + username: Optional[str] = None, + email: Optional[str] = None, + is_active: Optional[bool] = None, + is_superuser: Optional[bool] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> dict: + """ + Update user information. + + Args: + id (Union[str, UUID]): User ID to update + username (Optional[str]): New username + email (Optional[str]): New email address + is_active (Optional[bool]): Update active status + is_superuser (Optional[bool]): Update superuser status + metadata (Optional[Dict[str, Any]]): Update user metadata + + Returns: + dict: Updated user information + """ + data = {} + if username is not None: + data["username"] = username + if email is not None: + data["email"] = email + if is_active is not None: + data["is_active"] = is_active + if is_superuser is not None: + data["is_superuser"] = is_superuser + if metadata is not None: + data["metadata"] = metadata + + return await self.client._make_request( + "POST", f"users/{str(id)}", json=data + ) + + async def list_collections( + self, + id: Union[str, UUID], + offset: int = 0, + limit: int = 100, + ) -> dict: + """ + Get all collections associated with a specific user. + + Args: + id (Union[str, UUID]): User ID to get collections for + offset (int): Number of records to skip + limit (int): Maximum number of records to return + + Returns: + dict: List of collections and pagination information + """ + params = { + "offset": offset, + "limit": limit, + } + + return await self.client._make_request( + "GET", f"users/{str(id)}/collections", params=params + ) + + async def add_to_collection( + self, + id: Union[str, UUID], + collection_id: Union[str, UUID], + ) -> None: + """ + Add a user to a collection. + + Args: + id (Union[str, UUID]): User ID to add + collection_id (Union[str, UUID]): Collection ID to add user to + """ + await self.client._make_request( + "POST", f"users/{str(id)}/collections/{str(collection_id)}" + ) + + async def remove_from_collection( + self, + id: Union[str, UUID], + collection_id: Union[str, UUID], + ) -> bool: + """ + Remove a user from a collection. + + Args: + id (Union[str, UUID]): User ID to remove + collection_id (Union[str, UUID]): Collection ID to remove user from + + Returns: + bool: True if successful + """ + return await self.client._make_request( + "DELETE", f"users/{str(id)}/collections/{str(collection_id)}" + ) + + +class SyncUsersSDK: + """Synchronous wrapper for UsersSDK""" + + def __init__(self, async_sdk: UsersSDK): + self._async_sdk = async_sdk + + # Get all attributes from the instance + for name in dir(async_sdk): + if not name.startswith("_"): # Skip private methods + attr = getattr(async_sdk, name) + # Check if it's a method and if it's async + if callable(attr) and ( + iscoroutinefunction(attr) or isasyncgenfunction(attr) + ): + if isasyncgenfunction(attr): + setattr(self, name, sync_generator_wrapper(attr)) + else: + setattr(self, name, sync_wrapper(attr)) diff --git a/py/shared/api/models/ingestion/responses.py b/py/shared/api/models/ingestion/responses.py index 58e187d1a..bf3c7fabe 100644 --- a/py/shared/api/models/ingestion/responses.py +++ b/py/shared/api/models/ingestion/responses.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field -from shared.api.models.base import ResultsWrapper +from shared.api.models.base import PaginatedResultsWrapper, ResultsWrapper T = TypeVar("T") @@ -78,6 +78,8 @@ class SelectVectorIndexResponse(BaseModel): WrappedIngestionResponse = ResultsWrapper[list[IngestionResponse]] WrappedUpdateResponse = ResultsWrapper[UpdateResponse] WrappedCreateVectorIndexResponse = ResultsWrapper[CreateVectorIndexResponse] -WrappedListVectorIndicesResponse = ResultsWrapper[ListVectorIndicesResponse] +WrappedListVectorIndicesResponse = PaginatedResultsWrapper[ + ListVectorIndicesResponse +] WrappedDeleteVectorIndexResponse = ResultsWrapper[DeleteVectorIndexResponse] WrappedSelectVectorIndexResponse = ResultsWrapper[SelectVectorIndexResponse] diff --git a/py/shared/utils/__init__.py b/py/shared/utils/__init__.py index 41edbec67..e0c5aa934 100644 --- a/py/shared/utils/__init__.py +++ b/py/shared/utils/__init__.py @@ -5,12 +5,12 @@ format_relations, format_search_results_for_llm, format_search_results_for_stream, - generate_collection_id_from_name, generate_default_prompt_id, generate_default_user_collection_id, generate_document_id, generate_extraction_id, generate_id, + generate_id_from_label, generate_user_id, increment_version, llm_cost_per_million_tokens, @@ -31,7 +31,7 @@ "generate_extraction_id", "generate_default_user_collection_id", "generate_user_id", - "generate_collection_id_from_name", + "generate_id_from_label", "generate_default_prompt_id", # Other "increment_version", diff --git a/py/shared/utils/base_utils.py b/py/shared/utils/base_utils.py index 8c39b0cf5..b919cd97c 100644 --- a/py/shared/utils/base_utils.py +++ b/py/shared/utils/base_utils.py @@ -138,7 +138,7 @@ def generate_default_user_collection_id(user_id: UUID) -> UUID: return _generate_id_from_label(str(user_id)) -def generate_collection_id_from_name(collection_name: str) -> UUID: +def generate_id_from_label(collection_name: str) -> UUID: """ Generates a unique collection id from a given collection name """