From 6b70245eb984f1c49201a0dbdfb33811077a799c Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 09:43:57 +0200 Subject: [PATCH 1/7] Initial endpoint for file uploading implemented --- captn/captn_agents/application.py | 15 ++++++-- .../captn_agents/test_application.py} | 34 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) rename tests/ci/{test_captn_agents_application.py => captn/captn_agents/test_application.py} (93%) diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 4cc6e6eb..4bd1695f 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -1,11 +1,11 @@ import traceback from datetime import date -from typing import Dict, List, Literal, Optional, TypeVar +from typing import Dict, List, Literal, Optional, TypeVar, Union import httpx import openai from autogen.io.websockets import IOWebsockets -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, UploadFile from prometheus_client import Counter from pydantic import BaseModel @@ -170,3 +170,14 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: send_only_to_emails=request.send_only_to_emails, date=request.date ) return "Weekly analysis has been sent to the specified emails" + + +AVALIABLE_CONTENT_TYPES = ["text/csv"] + + +@router.post("/uploadfile/") +async def create_upload_file(file: UploadFile) -> Dict[str, Union[str, None]]: + if file.content_type not in AVALIABLE_CONTENT_TYPES: + raise HTTPException(status_code=400, detail="Invalid file content type") + + return {"filename": file.filename} diff --git a/tests/ci/test_captn_agents_application.py b/tests/ci/captn/captn_agents/test_application.py similarity index 93% rename from tests/ci/test_captn_agents_application.py rename to tests/ci/captn/captn_agents/test_application.py index 8e329ab5..c7b15237 100644 --- a/tests/ci/test_captn_agents_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -5,6 +5,8 @@ import autogen import pytest from autogen.io.websockets import IOWebsockets +from fastapi import HTTPException +from fastapi.testclient import TestClient from websockets.sync.client import connect as ws_connect from captn.captn_agents.application import ( @@ -12,6 +14,7 @@ CaptnAgentRequest, _get_message, on_connect, + router, ) from captn.captn_agents.backend.config import Config from captn.captn_agents.backend.tools._functions import TeamResponse @@ -395,3 +398,34 @@ def test_get_message_normal_chat() -> None: actual = _get_message(request) expected = "I want to Remove 'Free' keyword because it is not performing well" assert actual == expected + + +class TestUploadFile: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.client = TestClient(router) + + def test_upload_file_raises_exception_if_invalid_content_type(self): + # Create a dummy file + file_content = b"Hello, world!" + file_name = "test.txt" + files = {"file": (file_name, file_content, "text/plain")} + + # Send a POST request to the upload endpoint + with pytest.raises(HTTPException) as exc_info: + self.client.post("/uploadfile/", files=files) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Invalid file content type" + + def test_upload_csv_file(self): + # Create a dummy CSV file + file_content = b"column1,column2\nvalue1,value2" + file_name = "test.csv" + files = {"file": (file_name, file_content, "text/csv")} + + # Send a POST request to the upload endpoint + response = self.client.post("/uploadfile/", files=files) + + assert response.status_code == 200 + assert response.json() == {"filename": file_name} From 275a3df6a13be0f021635182775d948b0d67c934 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 10:03:31 +0200 Subject: [PATCH 2/7] Add user_id and conv_id to the uploadfile endpoint --- captn/captn_agents/application.py | 14 +++++++++++--- tests/ci/captn/captn_agents/test_application.py | 8 ++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 4bd1695f..8ce29942 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -1,11 +1,11 @@ import traceback from datetime import date -from typing import Dict, List, Literal, Optional, TypeVar, Union +from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union import httpx import openai from autogen.io.websockets import IOWebsockets -from fastapi import APIRouter, HTTPException, UploadFile +from fastapi import APIRouter, File, Form, HTTPException, UploadFile from prometheus_client import Counter from pydantic import BaseModel @@ -176,8 +176,16 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: @router.post("/uploadfile/") -async def create_upload_file(file: UploadFile) -> Dict[str, Union[str, None]]: +async def create_upload_file( + file: Annotated[UploadFile, File()], + user_id: Annotated[int, Form()], + conv_id: Annotated[int, Form()], +) -> Dict[str, Union[str, None]]: if file.content_type not in AVALIABLE_CONTENT_TYPES: raise HTTPException(status_code=400, detail="Invalid file content type") + print( + f"Received file: {file.filename} with user_id: {user_id} and conv_id: {conv_id}" + ) + return {"filename": file.filename} diff --git a/tests/ci/captn/captn_agents/test_application.py b/tests/ci/captn/captn_agents/test_application.py index c7b15237..34b8f3ae 100644 --- a/tests/ci/captn/captn_agents/test_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -404,6 +404,10 @@ class TestUploadFile: @pytest.fixture(autouse=True) def setup(self) -> None: self.client = TestClient(router) + self.data = { + "user_id": 123, + "conv_id": 456, + } def test_upload_file_raises_exception_if_invalid_content_type(self): # Create a dummy file @@ -413,7 +417,7 @@ def test_upload_file_raises_exception_if_invalid_content_type(self): # Send a POST request to the upload endpoint with pytest.raises(HTTPException) as exc_info: - self.client.post("/uploadfile/", files=files) + self.client.post("/uploadfile/", files=files, data=self.data) assert exc_info.value.status_code == 400 assert exc_info.value.detail == "Invalid file content type" @@ -425,7 +429,7 @@ def test_upload_csv_file(self): files = {"file": (file_name, file_content, "text/csv")} # Send a POST request to the upload endpoint - response = self.client.post("/uploadfile/", files=files) + response = self.client.post("/uploadfile/", files=files, data=self.data) assert response.status_code == 200 assert response.json() == {"filename": file_name} From 569b6bd1cccd1fdcfd358bea4f840b83473cf3bd Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 10:46:18 +0200 Subject: [PATCH 3/7] Save file to disk --- .gitignore | 1 + captn/captn_agents/application.py | 12 +++++++- .../ci/captn/captn_agents/test_application.py | 28 +++++++++++++++---- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d7e9a06b..38a00e72 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ client_secret.json .vscode/ benchmarking/working/* +uploaded_files/ diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 8ce29942..fd923b15 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -1,5 +1,6 @@ import traceback from datetime import date +from pathlib import Path from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union import httpx @@ -174,6 +175,8 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: AVALIABLE_CONTENT_TYPES = ["text/csv"] +UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files" + @router.post("/uploadfile/") async def create_upload_file( @@ -183,9 +186,16 @@ async def create_upload_file( ) -> Dict[str, Union[str, None]]: if file.content_type not in AVALIABLE_CONTENT_TYPES: raise HTTPException(status_code=400, detail="Invalid file content type") - + if file.filename is None: + raise HTTPException(status_code=400, detail="Invalid file name") print( f"Received file: {file.filename} with user_id: {user_id} and conv_id: {conv_id}" ) + # Create a directory if not exists + users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id) + users_conv_dir.mkdir(parents=True, exist_ok=True) + with open(users_conv_dir / file.filename, "wb") as f: + f.write(file.file.read()) + print(f"File saved to {users_conv_dir}") return {"filename": file.filename} diff --git a/tests/ci/captn/captn_agents/test_application.py b/tests/ci/captn/captn_agents/test_application.py index 34b8f3ae..723bf44c 100644 --- a/tests/ci/captn/captn_agents/test_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -1,5 +1,7 @@ import unittest from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory from typing import Callable, Dict import autogen @@ -428,8 +430,24 @@ def test_upload_csv_file(self): file_name = "test.csv" files = {"file": (file_name, file_content, "text/csv")} - # Send a POST request to the upload endpoint - response = self.client.post("/uploadfile/", files=files, data=self.data) - - assert response.status_code == 200 - assert response.json() == {"filename": file_name} + with TemporaryDirectory() as tmp_dir: + with unittest.mock.patch( + "captn.captn_agents.application.UPLOADED_FILES_DIR", + Path(tmp_dir), + ) as mock_uploaded_files_dir: + # Send a POST request to the upload endpoint + response = self.client.post("/uploadfile/", files=files, data=self.data) + + assert response.status_code == 200 + assert response.json() == {"filename": file_name} + + # Check if the file was saved + file_path = ( + mock_uploaded_files_dir + / str(self.data["user_id"]) + / str(self.data["conv_id"]) + / file_name + ) + assert file_path.exists() + with open(file_path, "rb") as f: + assert f.read() == file_content From ffe4a0ec133eb1bb5a74ce8dfb14127fae891916 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 11:45:28 +0200 Subject: [PATCH 4/7] Add validation of columns for the uploaded csv --- captn/captn_agents/application.py | 21 ++++++--- .../ci/captn/captn_agents/test_application.py | 43 +++++++++++++------ 2 files changed, 47 insertions(+), 17 deletions(-) diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index fd923b15..2de48ab5 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -5,6 +5,7 @@ import httpx import openai +import pandas as pd from autogen.io.websockets import IOWebsockets from fastapi import APIRouter, File, Form, HTTPException, UploadFile from prometheus_client import Counter @@ -174,6 +175,7 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: AVALIABLE_CONTENT_TYPES = ["text/csv"] +MANDATORY_COLUMNS = {"from_destination", "to_destination"} UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files" @@ -188,14 +190,23 @@ async def create_upload_file( raise HTTPException(status_code=400, detail="Invalid file content type") if file.filename is None: raise HTTPException(status_code=400, detail="Invalid file name") - print( - f"Received file: {file.filename} with user_id: {user_id} and conv_id: {conv_id}" - ) + # Create a directory if not exists users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id) users_conv_dir.mkdir(parents=True, exist_ok=True) - with open(users_conv_dir / file.filename, "wb") as f: + file_path = users_conv_dir / file.filename + with open(file_path, "wb") as f: f.write(file.file.read()) - print(f"File saved to {users_conv_dir}") + + # Check if the file has mandatory columns + df = pd.read_csv(file_path, nrows=0) + if not MANDATORY_COLUMNS.issubset(df.columns): + # Remove the file + file_path.unlink() + + raise HTTPException( + status_code=400, + detail=f"Missing mandatory columns: {', '.join(MANDATORY_COLUMNS - set(df.columns))}", + ) return {"filename": file.filename} diff --git a/tests/ci/captn/captn_agents/test_application.py b/tests/ci/captn/captn_agents/test_application.py index 723bf44c..3cda509e 100644 --- a/tests/ci/captn/captn_agents/test_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -424,9 +424,19 @@ def test_upload_file_raises_exception_if_invalid_content_type(self): assert exc_info.value.status_code == 400 assert exc_info.value.detail == "Invalid file content type" - def test_upload_csv_file(self): + @pytest.mark.parametrize( + "file_content, success", + [ + ( + b"from_destination,to_destination,additional_column\nvalue1,value2,value3", + True, + ), + (b"from_destination,additional_column\nvalue1,value3", False), + ], + ) + def test_upload_csv_file(self, file_content: bytes, success: bool): # Create a dummy CSV file - file_content = b"column1,column2\nvalue1,value2" + file_content = file_content file_name = "test.csv" files = {"file": (file_name, file_content, "text/csv")} @@ -435,19 +445,28 @@ def test_upload_csv_file(self): "captn.captn_agents.application.UPLOADED_FILES_DIR", Path(tmp_dir), ) as mock_uploaded_files_dir: - # Send a POST request to the upload endpoint - response = self.client.post("/uploadfile/", files=files, data=self.data) - - assert response.status_code == 200 - assert response.json() == {"filename": file_name} - - # Check if the file was saved file_path = ( mock_uploaded_files_dir / str(self.data["user_id"]) / str(self.data["conv_id"]) / file_name ) - assert file_path.exists() - with open(file_path, "rb") as f: - assert f.read() == file_content + + if success: + response = self.client.post( + "/uploadfile/", files=files, data=self.data + ) + assert response.status_code == 200 + assert response.json() == {"filename": file_name} + # Check if the file was saved + assert file_path.exists() + with open(file_path, "rb") as f: + assert f.read() == file_content + else: + with pytest.raises(HTTPException) as exc_info: + self.client.post("/uploadfile/", files=files, data=self.data) + assert not file_path.exists() + assert ( + exc_info.value.detail + == "Missing mandatory columns: to_destination" + ) From 49e00a5da92ca2156d437e8e1af500d369e219b9 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 12:09:47 +0200 Subject: [PATCH 5/7] Enable uploading the excel format --- captn/captn_agents/application.py | 18 ++++++++++++++---- pyproject.toml | 1 + 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 2de48ab5..0a14b09d 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -174,7 +174,11 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: return "Weekly analysis has been sent to the specified emails" -AVALIABLE_CONTENT_TYPES = ["text/csv"] +AVALIABLE_FILE_CONTENT_TYPES = [ + "text/csv", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +] MANDATORY_COLUMNS = {"from_destination", "to_destination"} UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files" @@ -186,8 +190,11 @@ async def create_upload_file( user_id: Annotated[int, Form()], conv_id: Annotated[int, Form()], ) -> Dict[str, Union[str, None]]: - if file.content_type not in AVALIABLE_CONTENT_TYPES: - raise HTTPException(status_code=400, detail="Invalid file content type") + if file.content_type not in AVALIABLE_FILE_CONTENT_TYPES: + raise HTTPException( + status_code=400, + detail=f"Invalid file content type: {file.content_type}. Only {', '.join(AVALIABLE_FILE_CONTENT_TYPES)} are allowed.", + ) if file.filename is None: raise HTTPException(status_code=400, detail="Invalid file name") @@ -199,7 +206,10 @@ async def create_upload_file( f.write(file.file.read()) # Check if the file has mandatory columns - df = pd.read_csv(file_path, nrows=0) + if file.content_type == "text/csv": + df = pd.read_csv(file_path, nrows=0) + else: + df = pd.read_excel(file_path, nrows=0) if not MANDATORY_COLUMNS.issubset(df.columns): # Remove the file file_path.unlink() diff --git a/pyproject.toml b/pyproject.toml index 509cef20..19cc9400 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ agents = [ "opentelemetry-instrumentation-fastapi==0.46b0", "opentelemetry-instrumentation-logging==0.46b0", "opentelemetry-exporter-otlp==1.25.0", + "openpyxl==3.1.4", ] dev = [ From 777f198afe239c229ec55603be3e6b17eb84af98 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 14:25:48 +0200 Subject: [PATCH 6/7] Update tests --- .../ci/captn/captn_agents/fixtures/upload.xls | Bin 0 -> 4763 bytes .../captn/captn_agents/fixtures/upload.xlsx | Bin 0 -> 4763 bytes .../ci/captn/captn_agents/test_application.py | 57 +++++++++++++++--- 3 files changed, 48 insertions(+), 9 deletions(-) create mode 100644 tests/ci/captn/captn_agents/fixtures/upload.xls create mode 100644 tests/ci/captn/captn_agents/fixtures/upload.xlsx diff --git a/tests/ci/captn/captn_agents/fixtures/upload.xls b/tests/ci/captn/captn_agents/fixtures/upload.xls new file mode 100644 index 0000000000000000000000000000000000000000..cf1b206d5028ac09c487c76a02813fb239eef0ea GIT binary patch literal 4763 zcmai12{_dK7RK0j#x`WjmMt{4B(he7v6LlChH1tw6e9aJmN1sA3E2w8Sh9>Y>tsnG zyX;FyqR4V*bnln1d%y0Pd1n6eJpbQ2&w2mnyyra-Z2&P70XaE2ft^5)3BieA!(Dqh zO4zztd)hg}-T!kTb=J$tF|*&oxkZwi>@)Ba?e$4JGd7?alV!iIB<-|-Ge>-EH*M~? z-&-wNJ0C(WC#>@BMvt>c)pvC$(3fweoRQ={bH%0wAMJ(D| zAeoZB`hwGIK9&XL3=g7P&9rV@RG2epAUjwI|OT!UEkR`F4A{xQMJ&{~~0cu-!ih97z-|n=JBb zRe0oz>VYDK{2}}HM5jSR_n=kh1XBkgr(UaAjoSUr@G?Qmh6I?V=Xzi@0>BxF?R+Kw zeRiTuq)WE)jHqLoOHd-?p^f12u?%oT*`hCXm@%)22}FKC;H7@c6mqAtX)LRthvsvIO3tNgl&C=9+v&4M@%Ick?`lMB zid|ZwHfn!Q5yEa9!ncM^Blj`TYKt$sc{iZ;1~GFk63j-@%;ly|A`fJ&4CBS3pkL11 zX+=uHBTiF;IIN5-hODF~C0sM00+nJhu>x6FWC_EbcGyyTxHCu0${^(pI`St?Y@G~! z4K#b~-pornk+KV*kY#tINcEr_dS<;1uOG0zz+w;!;`JD1T;V*pUH%DM%)Fl|4RXR#txym8Ta2n$c7K#SZ|v?x zdLH&o1aKo`JD(_CFW-9K)McnBwl1z8YwplrHM*ced6R=eFETvuv=I&iAdjLCeNR5;G={2X9&{-tK`xkYn}^K$PAn!J zk>Jb1@<2HAEv_sGNd8n7d_OzvL|It5!5mK($65UEm_BXF7D;;2X`olQq))<7XoU2d zW{7hh*R3o0bdFs7n{UU`V{)cj=HK-Bwpz*1fzceH^5;df=SU>WyO+CY8G9%0U;r#Y zvluj*-#VLI2lGr<+NN9XPKK{7s7${l1vN#*FnEQO&-1aP?tND3e8n&oFC*4dR~#du zqAdl4+mE`;#Ca=fiHF1=AC_56=|$4&TF{Jqoek|L+^8`Tl%h`LcRd`%w4&fF>K4cs zk=sB#2PK88I%r&mA~+uYxix@4Iq*QhoL~~ipT}L%V?5?O&O?6?bO@NiG!J8h>-5Xh zuXCGSU#VIFmu==Xe;|^zJ8WR{aq`v9xy=W&BANwLMrME8{m7dzFfJuDWg)uwBaq><{nj$$3tQE5M2Q!GqV-|O#5PJ^hBIImR&SM* zJ(hXJRs$XHXw`?@%ubr|Hm2ajkBQBrS&i+=u{qK_b&B%0L%JeEfRTBqv#1JsddUpo znD7vK;pqkbFUhC@wZyQ1FKSPc=w5!rSR67aHjEk2a!FdKX&>>7n&*&Sny0TGZe%wG zwD>FvC>6|b%-r{74+y@bA!R=1&T=^Uyo|Jvowmm1#l_+63tqufoc1*}BW4zU)onh2 z`c(@F-|FCRX7$R9tB&vy&ibhtc&is&u5PV3c=$<3{wKeHDyOL_X9o`e(rs@Z<`t#d z-dnVvnlXAAbeOP(AzGSww$>HkD{>_S5L&k&oG)KtIu z8a`3jxXM?Q`_Mp(?h6w2LRFjbyiP0T+zRrZJ<6Ql-b7-!x$V@d)O5{$X7^m&w5qtx zDl=5yu!D%H1`$dvz6xzkxOJsF8EL4U8FRofi zw3rl&ws0~?d0GjzRQ2XT&*wm`l?#YK*Y(Lrrpq0^Tjr5R5EwEh-+i{PH+az}uK+bd zGX|n?SX5D9dR#hmKpv8%Z7%MJj6xd<@7>9bLT*I^$3|(mTHQ zY9@K81AA(PHC3H~VlE^v^O*~C*V3?D5caB{Q4ZbJ@8s)C@kdP}`iJ9Zfo%DzvLn=n z&W`J{QJB?dN|AgLiU$GN8bPZMp_U8G=CK^-HBXt@DT_!ax72z0KP>%GZn!a!f#BKd zy7}oKtv z=yvfpy|J$HT{3*5!Eytw&*0`y3GJU6jp=8jxg)IIV77)HZupswALS2XRqO#$)EavI z7ueP~^vbVjm1S!13L%D?Kcq!li|L8$*g12DixF9R>`tjI`Md99(R*#(MRB2Pw+TX- zw61wnx*jTF`~#qyOb6!V;^T22mnNyvh4-`^tr<;~Diy$Y78gLP;=~GJ4xF$FRV^wqC>}9u;q> zd?MW8`FhKja6+^00zPS7C#}JpryQdYe2BalK!-K zEMgBp75u24$}U01(U>VKb$;~i8&_p#mzq~IWoBOFPi8a&U{4ZRzN9TboRB(uL(oTO zOH(2F$D>)-;0i!&?Vx6dU2o)29uX^m!PTQIr;el$`#HVdEqKvrTBWR(=~ffaA0av_ z-L}vF;f|>LbnfstKcd%$u;kt}u=AXQAG|cB1u_kg-V)-;lwCOINxZ@s6+KW#N>g6^ z?OlKwEl0ZD!YhAujt5J#35)MWl6Ql<#lO>Dwi}!rVAjfx0#YTlD;8|-7M)eStl{=% zxHLyT9O+>Jg<%IcootMQ>4H*55v{J`ed+;x@g3e4@OD2_b6A>2O#w0W!rS=@rHyr^&H2X()|n< zrI91&DfUe1>0mcmA;(k{aj~nxh@oyn<0K0QDx|YyZq&51iuC*;)+7Nr)rCb#b5jw*7rN%u zAH0IS`W15?6dG}ArrB}3YEoPQdZfwZ8oPx~*a5nIY10imv{lc|<*r@rxg+j#$c@ZPEK@F&@%N3|H8a1) z52gbjbkfH|BgeuE+DaG0tv?Nbq*7uiR|*69lR3=&?!l?iV=-jNw}Qw%Cf$9NQVN5m zh#Iwv;TYsKdsKmCvj|cN>D2Ht&;S3 zIf@t=E>AGfLT2cku^r1NqaXNiu1tVeXLN5TQ+8T`UX|4ndDOOl)Q>mu5vViY?H9 z)CKiDKx#3!ky3tJXB4-wKLD>{$)IM#6&xQmM1QJUD%>4E>H4=HRG|%$q$lae`N7*e z2Cml8*)OF89cj3pRH^Y^DzJOrvbJUrL-ini1?yEJA=#x&jTL|XmW_H5i zo%)LeRQhJ(Lgi;}wGrH{6qbZTJPng7; z1Aj`}rLq7s>n67+(N2_`-7GO`hJW==yh$sSt^}oVNRJBlYs@PpevOC;degw4fSj@G zxaESjz$(u<)hkJX71)Gc8~x=WbdtW|zjOa=stEqjgn^7hzA2 zM$gMWy)L&QGsQE{>UvGY^M+qAf1RW-UDD$zGMP9?d|nA<+KhMMx-Vh9ygdN* zjxK@KKA))8P|wcuNnn4fQ>c{yzk0;;^U1W)M{7L~!d@z@VAtV&4aclIMN~j!aEybF zWs#7GiQw0=e{y%lFZ+Lwe_H(^+J8qmx$68KMH~0OIFu9X&9ANZca)RM0)8#}6;Pa* z{2k?gE=j)wo-{@HqWTrLaa7;{f8RKM2RLak@ZRZH^pPC{{FjsZ9pU6i!W)ra@tXW6 z!e6%J_oFA36hE_m#e2%X==jeB`~C1qQNfGtuP~uLKKu{T`#a9bj>SvRuZYB@b3%}Q hKX|g}{?7efoJIW~?G4c;A;qIp;$CSuDO^3y`CnH#ro8|F literal 0 HcmV?d00001 diff --git a/tests/ci/captn/captn_agents/fixtures/upload.xlsx b/tests/ci/captn/captn_agents/fixtures/upload.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..21a1d9474e1ff791a2d61d679d6b28362885945f GIT binary patch literal 4763 zcmai12{_dK7RK0j#x`WjmMt{4B(he7v6LlChH1tw6e9aJmN1sA3E2w8Sh9>Y>tsnG zyX;FyqR4V*bnln1d%y0Pd1n6eJpbQ2&w2mnyyra-Z2&P70XaE2fhKE@3BieA!(Dqh zO4zztd)hg}-T!kTb=J$tF|*&oxkZwi>@)Ba?e$4JGd7?alV!iIB<-|-Ge>-EH*M~? z-&-wNJ0C(WC#>@BMvt>c)pvC$(3fweoRQ={bH%0wAMJ(D| zAeoZB`hwGIK9&XL3=g7P&9rV@RG2epAUjwI|OT!UEkR`F4A{xQMJ&{~~0cu-!ih97z-|n=JBb zRe0oz>VYDK{2}}HM5jSR_n=kh1XBkgr(UaAjoSUr@G?Qmh6I?V=Xzi@0>BxF?R+Kw zeRiTuq)WE)jHqLoOHd-?p^f12u?%oT*`hCXm@%)22}FKC;H7@c6mqAtX)LRthvsvIO3tNgl&C=9+v&4M@%Ick?`lMB zid|ZwHfn!Q5yEa9!ncM^Blj`TYKt$sc{iZ;1~GFk63j-@%;ly|A`fJ&4CBS3pkL11 zX+=uHBTiF;IIN5-hODF~C0sM00+nJhu>x6FWC_EbcGyyTxHCu0${^(pI`St?Y@G~! z4K#b~-pornk+KV*kY#tINcEr_dS<;1uOG0zz+w;!;`JD1T;V*pUH%DM%)Fl|4RXR#txym8Ta2n$c7K#SZ|v?x zdLH&o1aKo`JD(_CFW-9K)McnBwl1z8YwplrHM*ced6R=eFETvuv=I&iAdjLCeNR5;G={2X9&{-tK`xkYn}^K$PAn!J zk>Jb1@<2HAEv_sGNd8n7d_OzvL|It5!5mK($65UEm_BXF7D;;2X`olQq))<7XoU2d zW{7hh*R3o0bdFs7n{UU`V{)cj=HK-Bwpz*1fzceH^5;df=SU>WyO+CY8G9%0U;r#Y zvluj*-#VLI2lGr<+NN9XPKK{7s7${l1vN#*FnEQO&-1aP?tND3e8n&oFC*4dR~#du zqAdl4+mE`;#Ca=fiHF1=AC_56=|$4&TF{Jqoek|L+^8`Tl%h`LcRd`%w4&fF>K4cs zk=sB#2PK88I%r&mA~+uYxix@4Iq*QhoL~~ipT}L%V?5?O&O?6?bO@NiG!J8h>-5Xh zuXCGSU#VIFmu==Xe;|^zJ8WR{aq`v9xy=W&BANwLMrME8{m7dzFfJuDWg)uwBaq><{nj$$3tQE5M2Q!GqV-|O#5PJ^hBIImR&SM* zJ(hXJRs$XHXw`?@%ubr|Hm2ajkBQBrS&i+=u{qK_b&B%0L%JeEfRTBqv#1JsddUpo znD7vK;pqkbFUhC@wZyQ1FKSPc=w5!rSR67aHjEk2a!FdKX&>>7n&*&Sny0TGZe%wG zwD>FvC>6|b%-r{74+y@bA!R=1&T=^Uyo|Jvowmm1#l_+63tqufoc1*}BW4zU)onh2 z`c(@F-|FCRX7$R9tB&vy&ibhtc&is&u5PV3c=$<3{wKeHDyOL_X9o`e(rs@Z<`t#d z-dnVvnlXAAbeOP(AzGSww$>HkD{>_S5L&k&oG)KtIu z8a`3jxXM?Q`_Mp(?h6w2LRFjbyiP0T+zRrZJ<6Ql-b7-!x$V@d)O5{$X7^m&w5qtx zDl=5yu!D%H1`$dvz6xzkxOJsF8EL4U8FRofi zw3rl&ws0~?d0GjzRQ2XT&*wm`l?#YK*Y(Lrrpq0^Tjr5R5EwEh-+i{PH+az}uK+bd zGX|n?SX5D9dR#hmKpv8%Z7%MJj6xd<@7>9bLT*I^$3|(mTHQ zY9@K81AA(PHC3H~VlE^v^O*~C*V3?D5caB{Q4ZbJ@8s)C@kdP}`iJ9Zfo%DzvLn=n z&W`J{QJB?dN|AgLiU$GN8bPZMp_U8G=CK^-HBXt@DT_!ax72z0KP>%GZn!a!f#BKd zy7}oKtv z=yvfpy|J$HT{3*5!Eytw&*0`y3GJU6jp=8jxg)IIV77)HZupswALS2XRqO#$)EavI z7ueP~^vbVjm1S!13L%D?Kcq!li|L8$*g12DixF9R>`tjI`Md99(R*#(MRB2Pw+TX- zw61wnx*jTF`~#qyOb6!V;^T22mnNyvh4-`^tr<;~Diy$Y78gLP;=~GJ4xF$FRV^wqC>}9u;q> zd?MW8`FhKja6+^00zPS7C#}JpryQdYe2BalK!-K zEMgBp75u24$}U01(U>VKb$;~i8&_p#mzq~IWoBOFPi8a&U{4ZRzN9TboRB(uL(oTO zOH(2F$D>)-;0i!&?Vx6dU2o)29uX^m!PTQIr;el$`#HVdEqKvrTBWR(=~ffaA0av_ z-L}vF;f|>LbnfstKcd%$u;kt}u=AXQAG|cB1u_kg-V)-;lwCOINxZ@s6+KW#N>g6^ z?OlKwEl0ZD!YhAujt5J#35)MWl6Ql<#lO>Dwi}!rVAjfx0#YTlD;8|-7M)eStl{=% zxHLyT9O+>Jg<%IcootMQ>4H*55v{J`ed+;x@g3e4@OD2_b6A>2O#w0W!rS=@rHyr^&H2X()|n< zrI91&DfUe1>0mcmA;(k{aj~nxh@oyn<0K0QDx|YyZq&51iuC*;)+7Nr)rCb#b5jw*7rN%u zAH0IS`W15?6dG}ArrB}3YEoPQdZfwZ8oPx~*a5nIY10imv{lc|<*r@rxg+j#$c@ZPEK@F&@%N3|H8a1) z52gbjbkfH|BgeuE+DaG0tv?Nbq*7uiR|*69lR3=&?!l?iV=-jNw}Qw%Cf$9NQVN5m zh#Iwv;TYsKdsKmCvj|cN>D2Ht&;S3 zIf@t=E>AGfLT2cku^r1NqaXNiu1tVeXLN5TQ+8T`UX|4ndDOOl)Q>mu5vViY?H9 z)CKiDKx#3!ky3tJXB4-wKLD>{$)IM#6&xQmM1QJUD%>4E>H4=HRG|%$q$lae`N7*e z2Cml8*)OF89cj3pRH^Y^DzJOrvbJUrL-ini1?yEJA=#x&jTL|XmW_H5i zo%)LeRQhJ(Lgi;}wGrH{6qbZTJPng7; z1Aj`}rLq7s>n67+(N2_`-7GO`hJW==yh$sSt^}oVNRJBlYs@PpevOC;degw4fSj@G zxaESjz$(u<)hkJX71)Gc8~x=WbdtW|zjOa=stEqjgn^7hzA2 zM$gMWy)L&QGsQE{>UvGY^M+qAf1RW-UDD$zGMP9?d|nA<+KhMMx-Vh9ygdN* zjxK@KKA))8P|wcuNnn4fQ>c{yzk0;;^U1W)M{7L~!d@z@VAtV&4aclIMN~j!aEybF zWs#7GiQw0=e{y%lFZ+Lwe_H(^+J8qmx$68KMH~0OIFu9X&9ANZca)RM0)8#}6;Pa* z{2k?gE=j)wo-{@HqWTrLaa7;{f8RKM2RLak@ZRZH^pPC{{FjsZ9pU6i!W)ra@tXW6 z!e6%J_oFA36hE_m#e2%X==jeB`~C1qQNfGtuP~uLKKu{T`#a9bj>SvRuZYB@b3%}Q hKX|g}{?7efoJIW~?G4c;A;qIp;$CSuDO^3y`Cr6AqUitt literal 0 HcmV?d00001 diff --git a/tests/ci/captn/captn_agents/test_application.py b/tests/ci/captn/captn_agents/test_application.py index 3cda509e..a1a4dacb 100644 --- a/tests/ci/captn/captn_agents/test_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -2,9 +2,10 @@ from datetime import datetime from pathlib import Path from tempfile import TemporaryDirectory -from typing import Callable, Dict +from typing import Callable, Dict, Optional import autogen +import pandas as pd import pytest from autogen.io.websockets import IOWebsockets from fastapi import HTTPException @@ -422,23 +423,53 @@ def test_upload_file_raises_exception_if_invalid_content_type(self): self.client.post("/uploadfile/", files=files, data=self.data) assert exc_info.value.status_code == 400 - assert exc_info.value.detail == "Invalid file content type" + assert "Invalid file content type" in exc_info.value.detail @pytest.mark.parametrize( - "file_content, success", + "file_name, file_content, success, content_type", [ ( - b"from_destination,to_destination,additional_column\nvalue1,value2,value3", + "test.csv", + b"from_destination,to_destination,additional_column\nvalue1,value2,value3\nvalue1,value2,value3\nvalue1,value2,value3", True, + "text/csv", + ), + ( + "test.csv", + b"from_destination,additional_column\nvalue1,value3", + False, + "text/csv", + ), + ( + "upload.xlsx", + None, + True, + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ), + ( + "upload.xls", + None, + True, + "application/vnd.ms-excel", ), - (b"from_destination,additional_column\nvalue1,value3", False), ], ) - def test_upload_csv_file(self, file_content: bytes, success: bool): + def test_upload_csv_or_xlsx_file( + self, + file_name: str, + file_content: Optional[bytes], + success: bool, + content_type: str, + ): # Create a dummy CSV file - file_content = file_content - file_name = "test.csv" - files = {"file": (file_name, file_content, "text/csv")} + if file_content is None and "upload.xls" in file_name: + file_path = Path(__file__).parent / "fixtures" / file_name + with open(file_path, "rb") as f: + file_content = f.read() + else: + file_content = file_content + file_name = file_name + files = {"file": (file_name, file_content, content_type)} with TemporaryDirectory() as tmp_dir: with unittest.mock.patch( @@ -462,6 +493,14 @@ def test_upload_csv_file(self, file_content: bytes, success: bool): assert file_path.exists() with open(file_path, "rb") as f: assert f.read() == file_content + if "xls" in file_name: + df = pd.read_excel(file_path) + else: + df = pd.read_csv(file_path) + + # 3 rows in all test files + assert df.shape[0] == 3 + else: with pytest.raises(HTTPException) as exc_info: self.client.post("/uploadfile/", files=files, data=self.data) From 5f00d5f7451562a6f402ee8c434b6e7d58cda626 Mon Sep 17 00:00:00 2001 From: Robert Jambrecic Date: Thu, 13 Jun 2024 15:00:12 +0200 Subject: [PATCH 7/7] Async read and write uploaded file --- captn/captn_agents/application.py | 8 ++++++-- pyproject.toml | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 0a14b09d..cf7833b2 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union +import aiofiles import httpx import openai import pandas as pd @@ -202,8 +203,11 @@ async def create_upload_file( users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id) users_conv_dir.mkdir(parents=True, exist_ok=True) file_path = users_conv_dir / file.filename - with open(file_path, "wb") as f: - f.write(file.file.read()) + + # Async read-write + async with aiofiles.open(file_path, "wb") as out_file: + content = await file.read() + await out_file.write(content) # Check if the file has mandatory columns if file.content_type == "text/csv": diff --git a/pyproject.toml b/pyproject.toml index 19cc9400..5ca2c861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,7 @@ agents = [ "opentelemetry-instrumentation-logging==0.46b0", "opentelemetry-exporter-otlp==1.25.0", "openpyxl==3.1.4", + "aiofiles==23.2.1", ] dev = [