Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement endpoint for file upload #768

Merged
merged 7 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ client_secret.json

.vscode/
benchmarking/working/*
uploaded_files/
58 changes: 56 additions & 2 deletions captn/captn_agents/application.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import traceback
from datetime import date
from typing import Dict, List, Literal, Optional, TypeVar
from pathlib import Path
from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union

import aiofiles
import httpx
import openai
import pandas as pd
from autogen.io.websockets import IOWebsockets
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from prometheus_client import Counter
from pydantic import BaseModel

Expand Down Expand Up @@ -170,3 +173,54 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str:
send_only_to_emails=request.send_only_to_emails, date=request.date
)
return "Weekly analysis has been sent to the specified emails"


# MIME types the upload endpoint accepts: CSV, legacy Excel (.xls) and
# OOXML Excel (.xlsx). Any other content type is rejected with HTTP 400.
# NOTE(review): "AVALIABLE" is a misspelling of "AVAILABLE"; the name is left
# unchanged here because create_upload_file references it.
AVALIABLE_FILE_CONTENT_TYPES = [
"text/csv",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
]
# Column names that must be present in the header of every uploaded file.
MANDATORY_COLUMNS = {"from_destination", "to_destination"}

# Root directory for stored uploads: "uploaded_files" located three directory
# levels above the directory containing this module.
UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files"


@router.post("/uploadfile/")
async def create_upload_file(
    file: Annotated[UploadFile, File()],
    user_id: Annotated[int, Form()],
    conv_id: Annotated[int, Form()],
) -> Dict[str, Union[str, None]]:
    """Store an uploaded CSV/Excel file and validate its header row.

    The file is written to ``UPLOADED_FILES_DIR/<user_id>/<conv_id>/<name>``
    and then its header is parsed with pandas to make sure every column in
    ``MANDATORY_COLUMNS`` is present. A file that fails validation is removed
    from disk before the error is reported.

    Args:
        file: The uploaded file (multipart form field ``file``).
        user_id: Id of the user the upload belongs to.
        conv_id: Id of the conversation the upload belongs to.

    Returns:
        ``{"filename": <stored file name>}``.

    Raises:
        HTTPException: With status 400 when the content type is not in
            ``AVALIABLE_FILE_CONTENT_TYPES``, the file name is missing or
            unusable, the file cannot be parsed, or a mandatory column is
            missing.
    """
    if file.content_type not in AVALIABLE_FILE_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file content type: {file.content_type}. Only {', '.join(AVALIABLE_FILE_CONTENT_TYPES)} are allowed.",
        )

    # The file name is supplied by the client. Keep only its last path
    # component so that names such as "../../x" or "/etc/x" cannot place the
    # file outside of the per-conversation directory.
    raw_name = file.filename or ""
    safe_name = Path(raw_name.replace("\\", "/")).name
    if safe_name in {"", ".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid file name")

    # Create the per-user / per-conversation directory if it does not exist
    users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id)
    users_conv_dir.mkdir(parents=True, exist_ok=True)
    file_path = users_conv_dir / safe_name

    # Async read-write
    async with aiofiles.open(file_path, "wb") as out_file:
        content = await file.read()
        await out_file.write(content)

    # Parse only the header (nrows=0); never leave an unreadable file on disk.
    try:
        if file.content_type == "text/csv":
            df = pd.read_csv(file_path, nrows=0)
        else:
            df = pd.read_excel(file_path, nrows=0)
    except ValueError as err:
        # Covers pandas parser/empty-data errors and decoding errors.
        file_path.unlink(missing_ok=True)
        raise HTTPException(
            status_code=400,
            detail=f"Unable to read the uploaded file: {err}",
        ) from err
    except Exception:
        # Unexpected failure: clean up, then let the original error surface.
        file_path.unlink(missing_ok=True)
        raise

    # Sorted so the error message is deterministic.
    missing_columns = sorted(MANDATORY_COLUMNS - set(df.columns))
    if missing_columns:
        file_path.unlink(missing_ok=True)
        raise HTTPException(
            status_code=400,
            detail=f"Missing mandatory columns: {', '.join(missing_columns)}",
        )

    return {"filename": safe_name}
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ agents = [
"opentelemetry-instrumentation-fastapi==0.46b0",
"opentelemetry-instrumentation-logging==0.46b0",
"opentelemetry-exporter-otlp==1.25.0",
"openpyxl==3.1.4",
"aiofiles==23.2.1",
]

dev = [
Expand Down
Binary file added tests/ci/captn/captn_agents/fixtures/upload.xls
Binary file not shown.
Binary file added tests/ci/captn/captn_agents/fixtures/upload.xlsx
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import unittest
from datetime import datetime
from typing import Callable, Dict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, Optional

import autogen
import pandas as pd
import pytest
from autogen.io.websockets import IOWebsockets
from fastapi import HTTPException
from fastapi.testclient import TestClient
from websockets.sync.client import connect as ws_connect

from captn.captn_agents.application import (
ON_FAILURE_MESSAGE,
CaptnAgentRequest,
_get_message,
on_connect,
router,
)
from captn.captn_agents.backend.config import Config
from captn.captn_agents.backend.tools._functions import TeamResponse
Expand Down Expand Up @@ -395,3 +401,111 @@ def test_get_message_normal_chat() -> None:
actual = _get_message(request)
expected = "I want to Remove 'Free' keyword because it is not performing well"
assert actual == expected


class TestUploadFile:
    """Tests for the ``/uploadfile/`` endpoint exposed by ``router``.

    The tests expect ``HTTPException`` raised by the endpoint to propagate
    out of ``self.client.post`` and assert on it with ``pytest.raises``.
    """

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        """Create the test client and the form fields sent with each upload."""
        self.client = TestClient(router)
        self.data = {
            "user_id": 123,
            "conv_id": 456,
        }

    def test_upload_file_raises_exception_if_invalid_content_type(self) -> None:
        """A content type outside the allowed list is rejected with 400."""
        # Create a dummy file
        file_content = b"Hello, world!"
        file_name = "test.txt"
        files = {"file": (file_name, file_content, "text/plain")}

        # Send a POST request to the upload endpoint
        with pytest.raises(HTTPException) as exc_info:
            self.client.post("/uploadfile/", files=files, data=self.data)

        assert exc_info.value.status_code == 400
        assert "Invalid file content type" in exc_info.value.detail

    @pytest.mark.parametrize(
        "file_name, file_content, success, content_type",
        [
            (
                "test.csv",
                b"from_destination,to_destination,additional_column\nvalue1,value2,value3\nvalue1,value2,value3\nvalue1,value2,value3",
                True,
                "text/csv",
            ),
            (
                "test.csv",
                b"from_destination,additional_column\nvalue1,value3",
                False,
                "text/csv",
            ),
            (
                "upload.xlsx",
                None,
                True,
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            ),
            (
                "upload.xls",
                None,
                True,
                "application/vnd.ms-excel",
            ),
        ],
    )
    def test_upload_csv_or_xlsx_file(
        self,
        file_name: str,
        file_content: Optional[bytes],
        success: bool,
        content_type: str,
    ) -> None:
        """Upload a CSV/Excel file and check storage or rejection.

        When ``file_content`` is ``None`` the bytes are loaded from the
        ``fixtures`` directory next to this test module.
        """
        # ``import unittest`` alone does not load the ``mock`` submodule, so
        # import it explicitly instead of relying on another module doing so.
        from unittest import mock

        if file_content is None:
            fixture_path = Path(__file__).parent / "fixtures" / file_name
            with open(fixture_path, "rb") as f:
                file_content = f.read()
        files = {"file": (file_name, file_content, content_type)}

        with TemporaryDirectory() as tmp_dir:
            upload_root = Path(tmp_dir)
            # Redirect uploads into the temporary directory for this test.
            with mock.patch(
                "captn.captn_agents.application.UPLOADED_FILES_DIR",
                upload_root,
            ):
                file_path = (
                    upload_root
                    / str(self.data["user_id"])
                    / str(self.data["conv_id"])
                    / file_name
                )

                if success:
                    response = self.client.post(
                        "/uploadfile/", files=files, data=self.data
                    )
                    assert response.status_code == 200
                    assert response.json() == {"filename": file_name}
                    # Check if the file was saved
                    assert file_path.exists()
                    with open(file_path, "rb") as f:
                        assert f.read() == file_content
                    if Path(file_name).suffix in {".xls", ".xlsx"}:
                        df = pd.read_excel(file_path)
                    else:
                        df = pd.read_csv(file_path)

                    # 3 rows in all test files
                    assert df.shape[0] == 3

                else:
                    with pytest.raises(HTTPException) as exc_info:
                        self.client.post(
                            "/uploadfile/", files=files, data=self.data
                        )
                    # The rejected file must have been removed again
                    assert not file_path.exists()
                    assert (
                        exc_info.value.detail
                        == "Missing mandatory columns: to_destination"
                    )