Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

toolkit: add file content viewer #825

Merged
merged 27 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
02d0703
Add icons
danylo-boiko Oct 19, 2024
f75d0e6
Run format-web for MessageRow
danylo-boiko Oct 19, 2024
6f98847
Add handlers
danylo-boiko Oct 19, 2024
264ffbc
Fetch conv files
danylo-boiko Oct 19, 2024
ceeaa94
Merge branch 'main' into add-file-content-viewer
danylo-boiko Oct 24, 2024
8541c69
Merge branch 'main' into add-file-content-viewer
danylo-boiko Oct 25, 2024
4bf8949
Fetch agent files
danylo-boiko Oct 25, 2024
fecef10
Add unit tests
danylo-boiko Oct 25, 2024
2c64fcb
Generate web client
danylo-boiko Oct 25, 2024
835742c
Add API calls
danylo-boiko Oct 30, 2024
3f984d7
Merge branch 'main' into add-file-content-viewer
danylo-boiko Oct 30, 2024
bc34a3d
Add content for FileViewer
danylo-boiko Oct 30, 2024
94e128d
Add padding settings for modals
danylo-boiko Oct 31, 2024
d0dcfc8
Refactor styles
danylo-boiko Oct 31, 2024
78c8407
Minor clean up
danylo-boiko Oct 31, 2024
e60a1ea
Run format-web
danylo-boiko Oct 31, 2024
e556faf
Merge branch 'main' into add-file-content-viewer
danylo-boiko Oct 31, 2024
faf33ff
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 1, 2024
2175913
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 5, 2024
29fdc0b
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 8, 2024
c6e8f5f
Merge remote-tracking branch 'upstream/main' into add-file-content-vi…
danylo-boiko Nov 8, 2024
0a44b77
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 8, 2024
46b89df
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 11, 2024
38ccf0c
Merge remote-tracking branch 'upstream/main' into add-file-content-vi…
danylo-boiko Nov 12, 2024
e42dc3f
Merge branch 'main' into add-file-content-viewer
danylo-boiko Nov 12, 2024
48c6df4
Merge AgentFileFull and ConversationFileFull
danylo-boiko Nov 12, 2024
b2274cd
Add error message
danylo-boiko Nov 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/backend/crud/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def batch_create_files(db: Session, files: list[File]) -> list[File]:


@validate_transaction
def get_file(db: Session, file_id: str, user_id: str) -> File:
def get_file(db: Session, file_id: str, user_id: str | None = None) -> File:
"""
Get a file by ID.

Expand All @@ -47,7 +47,12 @@ def get_file(db: Session, file_id: str, user_id: str) -> File:
Returns:
File: File with the given ID.
"""
return db.query(File).filter(File.id == file_id, File.user_id == user_id).first()
filters = [File.id == file_id]

if user_id:
filters.append(File.user_id == user_id)

return db.query(File).filter(*filters).first()


@validate_transaction
Expand Down
57 changes: 55 additions & 2 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from backend.config.routers import RouterName
from backend.crud import agent as agent_crud
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.crud import file as file_crud
from backend.crud import snapshot as snapshot_crud
from backend.database_models.agent import Agent as AgentModel
from backend.database_models.agent_tool_metadata import (
Expand All @@ -34,7 +35,11 @@
)
from backend.schemas.context import Context
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.file import DeleteAgentFileResponse, UploadAgentFileResponse
from backend.schemas.file import (
AgentFileFull,
DeleteAgentFileResponse,
UploadAgentFileResponse,
)
from backend.services.agent import (
raise_db_error,
validate_agent_exists,
Expand Down Expand Up @@ -583,6 +588,54 @@ async def batch_upload_file(
return uploaded_files


@router.get("/{agent_id}/files/{file_id}", response_model=AgentFileFull)
async def get_agent_file(
agent_id: str,
file_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> AgentFileFull:
"""
Get an agent file by ID.

Args:
agent_id (str): Agent ID.
file_id (str): File ID.
session (DBSessionDep): Database session.
ctx (Context): Context object.

Returns:
AgentFileFull: File with the given ID.

Raises:
HTTPException: If the agent or file with the given ID is not found, or if the file does not belong to the agent.
"""
user_id = ctx.get_user_id()

if file_id not in get_file_service().get_file_ids_by_agent_id(session, user_id, agent_id, ctx):
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} does not belong to the agent with ID: {agent_id}."
)

file = file_crud.get_file(session, file_id)

if not file:
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} not found.",
)

return AgentFileFull(
id=file.id,
file_name=file.file_name,
file_content=file.file_content,
file_size=file.file_size,
created_at=file.created_at,
updated_at=file.updated_at,
)


@router.delete("/{agent_id}/files/{file_id}")
async def delete_agent_file(
agent_id: str,
Expand All @@ -605,7 +658,7 @@ async def delete_agent_file(
HTTPException: If the agent with the given ID is not found.
"""
user_id = ctx.get_user_id()
_ = validate_agent_exists(session, agent_id)
_ = validate_agent_exists(session, agent_id, user_id)
validate_file(session, file_id, user_id)

# Delete the File DB object
Expand Down
47 changes: 45 additions & 2 deletions src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
UpdateConversationRequest,
)
from backend.schemas.file import (
ConversationFileFull,
DeleteConversationFileResponse,
ListConversationFile,
UploadConversationFileResponse,
Expand Down Expand Up @@ -461,6 +462,49 @@ async def list_files(
return files_with_conversation_id


@router.get("/{conversation_id}/files/{file_id}", response_model=ConversationFileFull)
async def get_file(
conversation_id: str, file_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> ConversationFileFull:
"""
Get a conversation file by ID.

Args:
conversation_id (str): Conversation ID.
file_id (str): File ID.
session (DBSessionDep): Database session.
ctx (Context): Context object.

Returns:
ConversationFileFull: File with the given ID.

Raises:
HTTPException: If the conversation or file with the given ID is not found, or if the file does not belong to the conversation.
"""
user_id = ctx.get_user_id()

conversation = validate_conversation(session, conversation_id, user_id)

if file_id not in conversation.file_ids:
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} does not belong to the conversation with ID: {conversation.id}."
)

file = validate_file(session, file_id, user_id)

return ConversationFileFull(
id=file.id,
conversation_id=conversation.id,
file_name=file.file_name,
file_content=file.file_content,
file_size=file.file_size,
user_id=file.user_id,
created_at=file.created_at,
updated_at=file.updated_at,
)


@router.delete("/{conversation_id}/files/{file_id}")
async def delete_file(
conversation_id: str,
Expand All @@ -484,8 +528,7 @@ async def delete_file(
"""
user_id = ctx.get_user_id()
_ = validate_conversation(session, conversation_id, user_id)
validate_file(session, file_id, user_id )

validate_file(session, file_id, user_id)
# Delete the File DB object
get_file_service().delete_conversation_file_by_id(
session, conversation_id, file_id, user_id, ctx
Expand Down
8 changes: 8 additions & 0 deletions src/backend/schemas/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ class ConversationFilePublic(BaseModel):
file_size: int = Field(default=0, ge=0)


class ConversationFileFull(ConversationFilePublic):
file_content: str


class AgentFilePublic(BaseModel):
id: str
Expand All @@ -39,6 +42,11 @@ class AgentFilePublic(BaseModel):
file_name: str
file_size: int = Field(default=0, ge=0)


class AgentFileFull(AgentFilePublic):
file_content: str


class ListConversationFile(ConversationFilePublic):
pass

Expand Down
69 changes: 44 additions & 25 deletions src/backend/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,49 +119,66 @@ async def create_agent_files(

return uploaded_files

def get_files_by_agent_id(
def get_file_ids_by_agent_id(
self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context
) -> list[File]:
) -> list[str]:
"""
Get files by agent ID
Get file IDs associated with a specific agent ID

Args:
session (DBSessionDep): The database session
user_id (str): The user ID
agent_id (str): The agent ID
ctx (Context): Context object

Returns:
list[File]: The files that were created
list[str]: IDs of files that were created
"""
from backend.config.tools import Tool
from backend.tools.files import FileToolsArtifactTypes

agent = validate_agent_exists(session, agent_id, user_id)

files = []
agent_tool_metadata = agent.tools_metadata
if agent_tool_metadata is not None and len(agent_tool_metadata) > 0:
artifacts = next(
(
tool_metadata.artifacts
for tool_metadata in agent_tool_metadata
if tool_metadata.tool_name == Tool.Read_File.value.ID
or tool_metadata.tool_name == Tool.Search_File.value.ID
),
[], # Default value if the generator is empty
)
if not agent.tools_metadata:
malexw marked this conversation as resolved.
Show resolved Hide resolved
return []

artifacts = next(
(
tool_metadata.artifacts
for tool_metadata in agent.tools_metadata
if tool_metadata.tool_name == Tool.Read_File.value.ID
or tool_metadata.tool_name == Tool.Search_File.value.ID
),
[], # Default value if the generator is empty
)

file_ids = list(
{
artifact.get("id")
for artifact in artifacts
if artifact.get("type") == FileToolsArtifactTypes.local_file
}
)
return [
artifact.get("id")
for artifact in artifacts
if artifact.get("type") == FileToolsArtifactTypes.local_file
]

files = file_crud.get_files_by_ids(session, file_ids, user_id)
def get_files_by_agent_id(
self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context
) -> list[File]:
"""
Get files by agent ID

return files
Args:
session (DBSessionDep): The database session
user_id (str): The user ID
agent_id (str): The agent ID
ctx (Context): Context object

Returns:
list[File]: The files that were created
"""
file_ids = self.get_file_ids_by_agent_id(session, user_id, agent_id, ctx)

if not file_ids:
return []

return file_crud.get_files_by_ids(session, file_ids, user_id)

def get_files_by_conversation_id(
self, session: DBSessionDep, user_id: str, conversation_id: str, ctx: Context
Expand Down Expand Up @@ -312,6 +329,8 @@ def validate_file(
detail=f"File with ID: {file_id} not found.",
)

return file


async def insert_files_in_db(
session: DBSessionDep,
Expand Down
51 changes: 51 additions & 0 deletions src/backend/tests/unit/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,57 @@ def test_list_files_missing_user_id(
assert response.json() == {"detail": "User-Id required in request headers."}


def test_get_file(
session_client: TestClient, session: Session, user: User
) -> None:
conversation = get_factory("Conversation", session).create(user_id=user.id)
response = session_client.post(
"/v1/conversations/batch_upload_file",
headers={"User-Id": conversation.user_id},
files=[
("files", ("Mariana_Trench.pdf", open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb")))
],
data={"conversation_id": conversation.id},
)
assert response.status_code == 200
uploaded_file = response.json()[0]

response = session_client.get(
f"/v1/conversations/{conversation.id}/files/{uploaded_file['id']}",
headers={"User-Id": conversation.user_id},
)

assert response.status_code == 200
response_file = response.json()
assert response_file["id"] == uploaded_file["id"]
assert response_file["file_name"] == uploaded_file["file_name"]


def test_fail_get_file_nonexistent_conversation(
session_client: TestClient, session: Session, user: User
) -> None:
response = session_client.get(
"/v1/conversations/123/files/456",
headers={"User-Id": user.id},
)

assert response.status_code == 404
assert response.json() == {"detail": "Conversation with ID: 123 not found."}


def test_fail_get_file_nonbelong_file(
session_client: TestClient, session: Session, user: User
) -> None:
conversation = get_factory("Conversation", session).create(user_id=user.id)
response = session_client.get(
f"/v1/conversations/{conversation.id}/files/123",
headers={"User-Id": conversation.user_id},
)

assert response.status_code == 404
assert response.json() == {"detail": f"File with ID: 123 does not belong to the conversation with ID: {conversation.id}."}


def test_batch_upload_file_existing_conversation(
session_client: TestClient, session: Session, user
) -> None:
Expand Down
20 changes: 20 additions & 0 deletions src/interfaces/assistants_web/src/cohere-client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ export class CohereClient {
});
}

public getConversationFile({
conversationId,
fileId,
}: {
conversationId: string;
fileId: string;
}) {
return this.cohereService.default.getFileV1ConversationsConversationIdFilesFileIdGet({
conversationId,
fileId,
});
}

public batchUploadConversationFile(
formData: Body_batch_upload_file_v1_conversations_batch_upload_file_post
) {
Expand All @@ -61,6 +74,13 @@ export class CohereClient {
});
}

public getAgentFile({ agentId, fileId }: { agentId: string; fileId: string }) {
return this.cohereService.default.getAgentFileV1AgentsAgentIdFilesFileIdGet({
agentId,
fileId,
});
}

public batchUploadAgentFile(formData: Body_batch_upload_file_v1_agents_batch_upload_file_post) {
return this.cohereService.default.batchUploadFileV1AgentsBatchUploadFilePost({
formData,
Expand Down
Loading