# NOTE: the misspelling in "AVALIABLE" is kept on purpose; the name may be
# imported by other modules, so renaming it would break the public surface.
AVALIABLE_FILE_CONTENT_TYPES = [
    "text/csv",
    "application/vnd.ms-excel",
    "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
]
# Column headers every uploaded table must contain.
MANDATORY_COLUMNS = {"from_destination", "to_destination"}

# <repository root>/uploaded_files, resolved relative to this module.
UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files"


def _safe_file_name(raw_name: Optional[str]) -> Optional[str]:
    """Return the bare file name of a client-supplied name, or ``None`` if unusable.

    The upload's file name is untrusted input: it may contain directory
    components ("../../x"), be absolute ("/etc/x"), or be empty.  Only the
    final path component is kept so the result can never point outside the
    directory it is joined to.
    """
    if not raw_name or "\x00" in raw_name:
        return None
    # Treat both separator styles as separators so that neither "../x" nor
    # "..\\x" survives, regardless of the platform the server runs on.
    base_name = Path(raw_name.replace("\\", "/")).name
    if base_name in {"", ".", ".."}:
        return None
    return base_name


@router.post("/uploadfile/")
async def create_upload_file(
    file: Annotated[UploadFile, File()],
    user_id: Annotated[int, Form()],
    conv_id: Annotated[int, Form()],
) -> Dict[str, Union[str, None]]:
    """Store an uploaded CSV/Excel file and validate its header row.

    The file is written to ``UPLOADED_FILES_DIR / <user_id> / <conv_id>`` and
    then only its header is parsed to check that every column listed in
    ``MANDATORY_COLUMNS`` is present.  A file that fails validation (or that
    cannot be parsed at all) is removed again.

    Args:
        file: The uploaded file (multipart form field ``file``).
        user_id: Id of the uploading user; used as a directory name.
        conv_id: Id of the conversation; used as a sub-directory name.

    Returns:
        ``{"filename": <stored file name>}``.

    Raises:
        HTTPException: 400 if the content type is not allowed, the file name
            is missing/unusable, or a mandatory column is missing.
    """
    if file.content_type not in AVALIABLE_FILE_CONTENT_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file content type: {file.content_type}. Only {', '.join(AVALIABLE_FILE_CONTENT_TYPES)} are allowed.",
        )

    # Never join the raw client-supplied name onto the storage directory:
    # strip it down to a plain file name first (path traversal protection).
    safe_name = _safe_file_name(file.filename)
    if safe_name is None:
        raise HTTPException(status_code=400, detail="Invalid file name")

    # Create the per-user / per-conversation directory if it does not exist
    users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id)
    users_conv_dir.mkdir(parents=True, exist_ok=True)
    file_path = users_conv_dir / safe_name

    # Async read-write
    async with aiofiles.open(file_path, "wb") as out_file:
        content = await file.read()
        await out_file.write(content)

    # Some clients label plain CSV files as "application/vnd.ms-excel", so the
    # ".csv" suffix is honoured in addition to the declared content type.
    treat_as_csv = (
        file.content_type == "text/csv" or file_path.suffix.lower() == ".csv"
    )

    # Parse the header only (nrows=0); the data rows are not needed here.
    try:
        if treat_as_csv:
            df = pd.read_csv(file_path, nrows=0)
        else:
            df = pd.read_excel(file_path, nrows=0)
    except Exception:
        # Do not leave an unparsable upload behind; the error itself is
        # propagated unchanged.
        file_path.unlink(missing_ok=True)
        raise

    missing_columns = MANDATORY_COLUMNS - set(df.columns)
    if missing_columns:
        # Remove the rejected file
        file_path.unlink()

        # Sorted so the message is stable when several columns are missing.
        raise HTTPException(
            status_code=400,
            detail=f"Missing mandatory columns: {', '.join(sorted(missing_columns))}",
        )

    return {"filename": safe_name}
class TestUploadFile:
    """Tests for the ``/uploadfile/`` endpoint registered on ``router``."""

    @pytest.fixture(autouse=True)
    def setup(self) -> None:
        """Create a test client and the form fields sent with every upload."""
        self.client = TestClient(router)
        self.data = {
            "user_id": 123,
            "conv_id": 456,
        }

    def test_upload_file_raises_exception_if_invalid_content_type(self) -> None:
        """A content type outside the allowed list is rejected with a 400."""
        # Create a dummy file
        file_content = b"Hello, world!"
        file_name = "test.txt"
        files = {"file": (file_name, file_content, "text/plain")}

        # Send a POST request to the upload endpoint
        with pytest.raises(HTTPException) as exc_info:
            self.client.post("/uploadfile/", files=files, data=self.data)

        # Checked after the ``with`` block: nothing following the raising
        # call inside it would ever execute.
        assert exc_info.value.status_code == 400
        assert "Invalid file content type" in exc_info.value.detail

    @pytest.mark.parametrize(
        "file_name, file_content, success, content_type",
        [
            (
                "test.csv",
                b"from_destination,to_destination,additional_column\nvalue1,value2,value3\nvalue1,value2,value3\nvalue1,value2,value3",
                True,
                "text/csv",
            ),
            (
                "test.csv",
                b"from_destination,additional_column\nvalue1,value3",
                False,
                "text/csv",
            ),
            (
                "upload.xlsx",
                None,
                True,
                "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            ),
            (
                "upload.xls",
                None,
                True,
                "application/vnd.ms-excel",
            ),
        ],
    )
    def test_upload_csv_or_xlsx_file(
        self,
        file_name: str,
        file_content: Optional[bytes],
        success: bool,
        content_type: str,
    ) -> None:
        """Valid tables are stored; tables missing a mandatory column are not.

        When ``file_content`` is ``None`` the bytes are loaded from the
        ``fixtures`` directory next to this test module.
        """
        if file_content is None:
            # Binary spreadsheets live on disk rather than inline
            fixture_path = Path(__file__).parent / "fixtures" / file_name
            file_content = fixture_path.read_bytes()
        files = {"file": (file_name, file_content, content_type)}

        with TemporaryDirectory() as tmp_dir:
            # Redirect uploads into a throw-away directory for this test
            with unittest.mock.patch(
                "captn.captn_agents.application.UPLOADED_FILES_DIR",
                Path(tmp_dir),
            ) as mock_uploaded_files_dir:
                stored_path = (
                    mock_uploaded_files_dir
                    / str(self.data["user_id"])
                    / str(self.data["conv_id"])
                    / file_name
                )

                if success:
                    response = self.client.post(
                        "/uploadfile/", files=files, data=self.data
                    )
                    assert response.status_code == 200
                    assert response.json() == {"filename": file_name}
                    # Check if the file was saved byte-for-byte
                    assert stored_path.exists()
                    assert stored_path.read_bytes() == file_content
                    if "xls" in file_name:
                        df = pd.read_excel(stored_path)
                    else:
                        df = pd.read_csv(stored_path)

                    # 3 rows in all test files
                    assert df.shape[0] == 3

                else:
                    with pytest.raises(HTTPException) as exc_info:
                        self.client.post("/uploadfile/", files=files, data=self.data)

                    # These run only because they sit after the ``with``
                    # block; inside it they would be skipped by the raise.
                    assert not stored_path.exists()
                    assert (
                        exc_info.value.detail
                        == "Missing mandatory columns: to_destination"
                    )