Skip to content

Commit

Permalink
Allow passing of collection id at document ingestion
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Nov 15, 2024
1 parent f0e6636 commit 86aaff8
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 25 deletions.
2 changes: 1 addition & 1 deletion js/sdk/__tests__/ChunksIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ describe("r2rClient V3 Collections Integration Tests", () => {
task_id: null,
},
]);
});
}, 10000);

test("Retrieve a chunk", async () => {
const response = await client.chunks.retrieve({
Expand Down
11 changes: 11 additions & 0 deletions js/sdk/__tests__/CollectionsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import { describe, test, beforeAll, expect } from "@jest/globals";

const baseUrl = "http://localhost:7272";

/**
* zametov.txt will have an id of 69100f1e-2839-5b37-916d-5c87afe14094
*/
describe("r2rClient V3 Collections Integration Tests", () => {
let client: r2rClient;
let collectionId: string;
Expand Down Expand Up @@ -99,6 +102,14 @@ describe("r2rClient V3 Collections Integration Tests", () => {
expect(response.results).toBeDefined();
});

test("Delete zametov.txt", async () => {
const response = await client.documents.delete({
id: "69100f1e-2839-5b37-916d-5c87afe14094",
});

expect(response.results).toBeDefined();
});

test("Delete collection", async () => {
await expect(
client.collections.delete({ id: collectionId }),
Expand Down
4 changes: 2 additions & 2 deletions js/sdk/__tests__/r2rV2ClientIntegrationUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ describe("r2rClient Integration Tests", () => {
const files = ["examples/data/folder"];

await expect(client.ingestFiles(files)).resolves.not.toThrow();
});
}, 30000);

test("Update files", async () => {
const updated_file = [
Expand All @@ -165,7 +165,7 @@ describe("r2rClient Integration Tests", () => {
document_ids: ["0b80081e-a37a-579f-a06d-7d2032435d65"],
}),
).resolves.not.toThrow();
});
}, 30000);

test("Search documents", async () => {
await expect(client.search("test")).resolves.not.toThrow();
Expand Down
1 change: 1 addition & 0 deletions js/sdk/jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ module.exports = {
"**/__tests__/**/*.ts?(x)",
"**/__tests__/**/?(*.)+(spec|test).ts?(x)",
],
maxWorkers: 1,
};
4 changes: 3 additions & 1 deletion js/sdk/src/v3/clients/documents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ export class DocumentsClient {
);
}
if (options.collectionIds) {
formData.append("collection_ids", JSON.stringify(options.collectionIds));
options.collectionIds.forEach((id) => {
formData.append("collection_ids", id);
});
}
if (options.runWithOrchestration !== undefined) {
formData.append(
Expand Down
3 changes: 2 additions & 1 deletion py/core/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,7 +1220,8 @@ async def assign_document_to_collection_relational(
collection_id: UUID,
) -> UUID:
return await self.collection_handler.assign_document_to_collection_relational(
document_id=document_id, collection_id=collection_id
document_id=document_id,
collection_id=collection_id,
)

async def remove_document_from_collection_relational(
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/api/v3/chunks_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Optional
from uuid import UUID

from fastapi import Body, Depends, Path, Query
from fastapi import Body, Depends, Path, Query, Form

from core.base import (
KGSearchSettings,
Expand Down Expand Up @@ -210,6 +210,7 @@ async def create_chunks(
chunks_by_document[chunk.document_id].append(chunk)

responses = []
# FIXME: Need to verify that the collection_id workflow is valid
for document_id, doc_chunks in chunks_by_document.items():
document_id = document_id or default_document_id
# Convert UnprocessedChunks to RawChunks for ingestion
Expand All @@ -233,7 +234,6 @@ async def create_chunks(
"user": auth_user.model_dump_json(),
}

# TODO - Modify create_chunks so that we can add chunks to existing document
# TODO - Modify create_chunks so that we can add chunks to existing document

if run_with_orchestration:
Expand Down
18 changes: 14 additions & 4 deletions py/core/main/api/v3/documents_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def create_document(
None,
description="The ID of the document. If not provided, a new ID will be generated.",
),
collection_ids: Optional[list[UUID]] = Form(
collection_ids: Optional[list[str]] = Form(
None,
description="Collection IDs to associate with the document. If none are provided, the document will be assigned to the user's default collection.",
),
Expand Down Expand Up @@ -201,10 +201,20 @@ async def create_document(
message="Either a file or content must be provided.",
)

collection_uuids = None
if collection_ids:
try:
collection_uuids = [UUID(cid) for cid in collection_ids]
except ValueError:
raise R2RException(
status_code=422,
message="Collection IDs must be valid UUIDs.",
)

workflow_input = {
"file_data": file_data,
"document_id": str(document_id),
"collection_ids": collection_ids,
"collection_ids": collection_uuids,
"metadata": metadata,
"ingestion_config": ingestion_config,
"user": auth_user.model_dump_json(),
Expand Down Expand Up @@ -389,13 +399,13 @@ async def update_document( # type: ignore
if (
metadata is not None
and "user_id" in metadata
and metadata["user_id"] != str(auth_user.id)
and metadata["user_id"] != str(auth_user.id) # type: ignore
):
raise R2RException(
status_code=403,
message="Non-superusers cannot set user_id in metadata.",
)
metadata["user_id"] = str(auth_user.id)
metadata["user_id"] = str(auth_user.id) # type: ignore

if file:
file_data = await self._process_file(file)
Expand Down
1 change: 1 addition & 0 deletions py/core/main/orchestration/hatchet/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ async def parse(self, context: Context) -> dict:
else:
for collection_id in collection_ids:
try:
# FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
await service.providers.database.create_collection(
user_id=document_info.user_id,
name=document_info.title,
Expand Down
1 change: 1 addition & 0 deletions py/core/main/orchestration/simple/ingestion_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ async def ingest_files(input_data):
else:
for collection_id in collection_ids:
try:
# FIXME: Right now we just throw a warning if the collection already exists, but we should probably handle this more gracefully
await service.providers.database.create_collection(
name=document_info.title,
collection_id=collection_id,
Expand Down
22 changes: 12 additions & 10 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from collections import defaultdict
from typing import Any, BinaryIO, Optional, Tuple, Union
from typing import Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml
Expand Down Expand Up @@ -609,7 +609,7 @@ async def get_users_in_collection(
@telemetry_event("GetDocumentsInCollection")
async def documents_in_collection(
self, collection_id: UUID, offset: int = 0, limit: int = 100
) -> dict[str, Union[list[DocumentResponse], int]]:
) -> dict[str, list[DocumentResponse], int]:
return await self.providers.database.documents_in_collection(
collection_id, offset=offset, limit=limit
)
Expand All @@ -622,7 +622,7 @@ async def add_prompt(
await self.providers.database.add_prompt(
name, template, input_types
)
return f"Prompt '{name}' added successfully."
return f"Prompt '{name}' added successfully." # type: ignore
except ValueError as e:
raise R2RException(status_code=400, message=str(e))

Expand Down Expand Up @@ -652,8 +652,10 @@ async def get_prompt(
prompt_override: Optional[str] = None,
) -> dict:
try:
return await self.providers.database.get_prompt(
prompt_name, inputs, prompt_override
return await self.providers.database.get_prompt( # type: ignore
prompt_name=prompt_name,
inputs=inputs,
prompt_override=prompt_override,
)
except ValueError as e:
raise R2RException(status_code=404, message=str(e))
Expand All @@ -673,7 +675,7 @@ async def update_prompt(
await self.providers.database.update_prompt(
name, template, input_types
)
return f"Prompt '{name}' updated successfully."
return f"Prompt '{name}' updated successfully." # type: ignore
except ValueError as e:
raise R2RException(status_code=404, message=str(e))

Expand Down Expand Up @@ -707,7 +709,7 @@ async def verify_conversation_access(
async def create_conversation(
self, user_id: Optional[UUID] = None, auth_user=None
) -> dict:
return await self.logging_connection.create_conversation(
return await self.logging_connection.create_conversation( # type: ignore
user_id=user_id
)

Expand All @@ -719,7 +721,7 @@ async def conversations_overview(
conversation_ids: Optional[list[UUID]] = None,
user_ids: Optional[UUID | list[UUID]] = None,
auth_user=None,
) -> dict[str, Union[list[dict], int]]:
) -> dict[str, list[dict], int]:
return await self.logging_connection.get_conversations_overview(
offset=offset,
limit=limit,
Expand Down Expand Up @@ -759,7 +761,7 @@ async def update_message_metadata(
@telemetry_event("exportMessagesToCSV")
async def export_messages_to_csv(
self, chunk_size: int = 1000, return_type: str = "stream"
) -> Union[StreamingResponse, str]:
) -> StreamingResponse | str:
return await self.logging_connection.export_messages_to_csv(
chunk_size, return_type
)
Expand All @@ -772,7 +774,7 @@ async def branches_overview(
conversation_id: str,
auth_user=None,
) -> list[dict]:
return await self.logging_connection.get_branches(
return await self.logging_connection.get_branches( # type: ignore
offset=offset,
limit=limit,
conversation_id=conversation_id,
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def agent(
parent_id=str(ids[-2]) if (ids and len(ids) > 1) else None, # type: ignore
)
if message is not None:
message_id = message["id"] # type: ignore
message_id = message["id"] # type: ignore

if rag_generation_config.stream:
t1 = time.time()
Expand Down
6 changes: 3 additions & 3 deletions py/sdk/v3/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ async def create(
if ingestion_config:
data["ingestion_config"] = json.dumps(ingestion_config)
if collection_ids:
data["collection_ids"] = json.dumps(
[str(collection_id) for collection_id in collection_ids]
)
for cid in collection_ids:
data["collection_ids"] = cid

if run_with_orchestration is not None:
data["run_with_orchestration"] = str(run_with_orchestration)

Expand Down

0 comments on commit 86aaff8

Please sign in to comment.