Skip to content

Commit

Permalink
Fix type errors, pass collection id
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 15, 2024
1 parent 0b1f6c5 commit 6f2a968
Show file tree
Hide file tree
Showing 22 changed files with 87 additions and 107 deletions.
10 changes: 4 additions & 6 deletions py/core/base/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional, Type, Union
from typing import Any, AsyncGenerator, Optional, Type

from pydantic import BaseModel

Expand All @@ -26,7 +26,7 @@ def __init__(self):

def create_and_add_message(
self,
role: Union[MessageType, str],
role: MessageType | str,
content: Optional[str] = None,
name: Optional[str] = None,
function_call: Optional[dict[str, Any]] = None,
Expand Down Expand Up @@ -123,9 +123,7 @@ async def arun(
messages: Optional[list[Message]] = None,
*args,
**kwargs,
) -> Union[
list[LLMChatCompletion], AsyncGenerator[LLMChatCompletion, None]
]:
) -> list[LLMChatCompletion] | AsyncGenerator[LLMChatCompletion, None]:
pass

@abstractmethod
Expand All @@ -134,7 +132,7 @@ async def process_llm_response(
response: Any,
*args,
**kwargs,
) -> Union[None, AsyncGenerator[str, None]]:
) -> None | AsyncGenerator[str, None]:
pass

async def execute_tool(self, tool_name: str, *args, **kwargs) -> str:
Expand Down
2 changes: 1 addition & 1 deletion py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,7 @@ async def assign_document_to_collection_relational(
collection_id: UUID,
) -> UUID:
return await self.collection_handler.assign_document_to_collection_relational(
document_id, collection_id
document_id=document_id, collection_id=collection_id
)

async def remove_document_from_collection_relational(
Expand Down
11 changes: 3 additions & 8 deletions py/core/main/api/v2/ingestion_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,15 +484,10 @@ async def update_document_metadata_app(
workflow_input
)

return { # type: ignore
"message": "Update metadata task completed successfully.",
"document_id": str(document_id),
"task_id": None,
}
return [
{ # type: ignore
"message": "Ingestion task completed successfully.",
"document_id": str(document_uuid),
"document_id": str(document_id),
"task_id": None,
}
]
Expand Down Expand Up @@ -636,7 +631,7 @@ async def create_vector_index_app(
},
)

return GenericMessageResponse(message=raw_message)
return GenericMessageResponse(message=raw_message) # type: ignore

list_vector_indices_extras = self.openapi_extras.get(
"create_vector_index", {}
Expand Down Expand Up @@ -725,7 +720,7 @@ async def delete_vector_index_app(
},
)

return GenericMessageResponse(message=raw_message)
return GenericMessageResponse(message=raw_message) # type: ignore

@staticmethod
async def _process_files(files):
Expand Down
10 changes: 5 additions & 5 deletions py/core/main/api/v2/management_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def update_prompt_app(
result = await self.service.update_prompt(
name, template, input_types
)
return GenericMessageResponse(message=result)
return GenericMessageResponse(message=result) # type: ignore

@self.router.post("/add_prompt")
@self.base_endpoint
Expand All @@ -115,7 +115,7 @@ async def add_prompt_app(
403,
)
result = await self.service.add_prompt(name, template, input_types)
return GenericMessageResponse(message=result)
return GenericMessageResponse(message=result) # type: ignore

@self.router.get("/get_prompt/{prompt_name}")
@self.base_endpoint
Expand All @@ -137,7 +137,7 @@ async def get_prompt_app(
result = await self.service.get_cached_prompt(
prompt_name, inputs, prompt_override
)
return GenericMessageResponse(message=result)
return GenericMessageResponse(message=result) # type: ignore

@self.router.get("/get_all_prompts")
@self.base_endpoint
Expand Down Expand Up @@ -519,7 +519,7 @@ async def collections_overview_app(
)
)

return collections_overview_response["results"], {
return collections_overview_response["results"], { # type: ignore
"total_entries": collections_overview_response["total_entries"]
}

Expand Down Expand Up @@ -640,7 +640,7 @@ async def add_user_to_collection_app(
result = await self.service.add_user_to_collection(
user_uuid, collection_uuid
)
return WrappedBooleanResponse(result=result)
return WrappedBooleanResponse(result=result) # type: ignore

@self.router.post("/remove_user_from_collection")
@self.base_endpoint
Expand Down
10 changes: 5 additions & 5 deletions py/core/main/api/v3/collections_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ async def list_collections(
limit=limit,
)

return (
return ( # type: ignore
collections_overview_response["results"],
{
"total_entries": collections_overview_response[
Expand Down Expand Up @@ -486,7 +486,7 @@ async def delete_collection(
)

await self.services["management"].delete_collection(id)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.post(
"/collections/{id}/documents/{document_id}",
Expand Down Expand Up @@ -745,7 +745,7 @@ async def remove_document_from_collection(
await self.services["management"].remove_document_from_collection(
document_id, id
)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.get(
"/collections/{id}/users",
Expand Down Expand Up @@ -932,7 +932,7 @@ async def add_user_to_collection(
result = await self.services["management"].add_user_to_collection(
user_id, id
)
return GenericBooleanResponse(success=result)
return GenericBooleanResponse(success=result) # type: ignore

@self.router.delete(
"/collections/{id}/users/{user_id}",
Expand Down Expand Up @@ -1014,4 +1014,4 @@ async def remove_user_from_collection(
await self.services["management"].remove_user_from_collection(
user_id, id
)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore
6 changes: 3 additions & 3 deletions py/core/main/api/v3/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ async def list_conversations(
offset=offset,
limit=limit,
)
return conversations_response["results"], {
return conversations_response["results"], { # type: ignore
"total_entries": conversations_response["total_entries"]
}

Expand Down Expand Up @@ -347,7 +347,7 @@ async def delete_conversation(
This endpoint deletes a conversation identified by its UUID.
"""
await self.services["management"].delete_conversation(str(id))
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.post(
"/conversations/{id}/messages",
Expand Down Expand Up @@ -609,7 +609,7 @@ async def list_branches(
conversation_id=str(id),
)

return branches_response["results"], {
return branches_response["results"], { # type: ignore
"total_entries": branches_response["total_entries"]
}

Expand Down
19 changes: 13 additions & 6 deletions py/core/main/api/v3/documents_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ async def create_document(
None,
description="The ID of the document. If not provided, a new ID will be generated.",
),
collection_ids: Optional[list[UUID]] = Form(
None,
description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.",
),
metadata: Optional[Json[dict]] = Form(
None,
description="Metadata to associate with the document, such as title, description, or custom fields.",
Expand Down Expand Up @@ -200,6 +204,7 @@ async def create_document(
workflow_input = {
"file_data": file_data,
"document_id": str(document_id),
"collection_ids": collection_ids,
"metadata": metadata,
"ingestion_config": ingestion_config,
"user": auth_user.model_dump_json(),
Expand Down Expand Up @@ -306,7 +311,7 @@ async def create_document(
},
)
@self.base_endpoint
async def update_document(
async def update_document( # type: ignore
file: Optional[UploadFile] = File(
None,
description="The file to ingest. Either a file or content must be provided, but not both.",
Expand Down Expand Up @@ -381,8 +386,10 @@ async def update_document(

# Check if the user is a superuser
if not auth_user.is_superuser:
if "user_id" in metadata and metadata["user_id"] != str(
auth_user.id
if (
metadata is not None
and "user_id" in metadata
and metadata["user_id"] != str(auth_user.id)
):
raise R2RException(
status_code=403,
Expand Down Expand Up @@ -795,7 +802,7 @@ async def list_chunks(
"Not authorized to access this document's chunks.", 403
)

return (
return ( # type: ignore
list_document_chunks["results"],
{"total_entries": list_document_chunks["total_entries"]},
)
Expand Down Expand Up @@ -1019,7 +1026,7 @@ async def delete_document_by_id(
]
}
await self.services["management"].delete(filters=filters)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.delete(
"/documents/by-filter",
Expand Down Expand Up @@ -1085,7 +1092,7 @@ async def delete_document_by_filter(
filters=filters_dict
)

return GenericBooleanResponse(success=delete_bool)
return GenericBooleanResponse(success=delete_bool) # type: ignore

@self.router.get(
"/documents/{id}/collections",
Expand Down
25 changes: 0 additions & 25 deletions py/core/main/api/v3/graph_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,6 @@
logger = logging.getLogger()


# class Entity(BaseModel):
# """Model representing a graph entity."""

# id: UUID
# name: str
# type: str
# metadata: dict = Field(default_factory=dict)
# level: EntityLevel
# collection_ids: list[UUID]
# embedding: Optional[list[float]] = None

# class Config:
# json_schema_extra = {
# "example": {
# "id": "9fbe403b-c11c-5aae-8ade-ef22980c3ad1",
# "name": "John Smith",
# "type": "PERSON",
# "metadata": {"confidence": 0.95},
# "level": "DOCUMENT",
# "collection_ids": ["d09dedb1-b2ab-48a5-b950-6e1f464d83e7"],
# "embedding": [0.1, 0.2, 0.3],
# }
# }


class Relationship(BaseModel):
"""Model representing a graph relationship."""

Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/v3/indices_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def create_index(
},
)

return GenericMessageResponse(message=raw_message)
return GenericMessageResponse(message=raw_message) # type: ignore

@self.router.get(
"/indices",
Expand Down Expand Up @@ -625,4 +625,4 @@ async def delete_index(
},
)

return GenericMessageResponse(message=raw_message)
return GenericMessageResponse(message=raw_message) # type: ignore
8 changes: 4 additions & 4 deletions py/core/main/api/v3/prompts_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def create_prompt(
result = await self.services["management"].add_prompt(
name, template, input_types
)
return GenericMessageResponse(message=result)
return GenericMessageResponse(message=result) # type: ignore

@self.router.get(
"/prompts",
Expand Down Expand Up @@ -188,7 +188,7 @@ async def get_prompts(
"management"
].get_all_prompts()

return (
return ( # type: ignore
get_prompts_response["results"],
{
"total_entries": get_prompts_response["total_entries"],
Expand Down Expand Up @@ -365,7 +365,7 @@ async def update_prompt(
result = await self.services["management"].update_prompt(
name, template, input_types
)
return GenericMessageResponse(message=result)
return GenericMessageResponse(message=result) # type: ignore

@self.router.delete(
"/prompts/{name}",
Expand Down Expand Up @@ -439,4 +439,4 @@ async def delete_prompt(
403,
)
await self.services["management"].delete_prompt(name)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore
4 changes: 2 additions & 2 deletions py/core/main/api/v3/system_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _setup_routes(self):
)
@self.base_endpoint
async def health_check() -> WrappedGenericMessageResponse:
return GenericMessageResponse(message="ok")
return GenericMessageResponse(message="ok") # type: ignore

@self.router.get(
"/system/settings",
Expand Down Expand Up @@ -224,7 +224,7 @@ async def server_stats(
"Only an authorized user can call the `server_stats` endpoint.",
403,
)
return {
return { # type: ignore
"start_time": self.start_time.isoformat(),
"uptime_seconds": (
datetime.now(timezone.utc) - self.start_time
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/v3/users_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ async def add_user_to_collection(
await self.services["management"].add_user_to_collection( # type: ignore
id, collection_id
)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.delete(
"/users/{id}/collections/{collection_id}",
Expand Down Expand Up @@ -979,7 +979,7 @@ async def remove_user_from_collection(
await self.services["management"].remove_user_from_collection( # type: ignore
id, collection_id
)
return GenericBooleanResponse(success=True)
return GenericBooleanResponse(success=True) # type: ignore

@self.router.post(
"/users/{id}",
Expand Down
Loading

0 comments on commit 6f2a968

Please sign in to comment.