From aa90cb3cbcb2881fbd6093088070fb2b53918824 Mon Sep 17 00:00:00 2001 From: rjambrecic <32619626+rjambrecic@users.noreply.github.com> Date: Thu, 13 Jun 2024 05:24:05 +0200 Subject: [PATCH 1/4] Use groupchat for websurfer team (#767) * wip * wip * wip * In groupchat, manager needs to send and receive the first message when starting or continuing the conversation * Cleanup --- .../captn_agents/backend/tools/_functions.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/captn/captn_agents/backend/tools/_functions.py b/captn/captn_agents/backend/tools/_functions.py index 18baae44..74e03d7b 100644 --- a/captn/captn_agents/backend/tools/_functions.py +++ b/captn/captn_agents/backend/tools/_functions.py @@ -765,6 +765,7 @@ def get_get_info_from_the_web_page( websurfer_navigator_llm_config: Optional[Dict[str, Any]] = None, timestamp: Optional[str] = None, max_retires_before_give_up_message: int = 7, + max_round: int = 50, ) -> Callable[[str, int, int], str]: fx = summarizer_llm_config, websurfer_llm_config, websurfer_navigator_llm_config @@ -833,6 +834,16 @@ def get_info_from_the_web_page( # is_termination_msg=_is_termination_msg, ) + groupchat = autogen.GroupChat( + agents=[web_surfer, web_surfer_navigator], + messages=[], + max_round=max_round, + speaker_selection_method="round_robin", + ) + manager = autogen.GroupChatManager( + groupchat=groupchat, + ) + initial_message = ( f"Time now is {timestamp_copy}." 
if timestamp_copy else "" ) @@ -844,15 +855,13 @@ def get_info_from_the_web_page( """ try: - web_surfer_navigator.initiate_chat( - web_surfer, message=initial_message - ) + manager.initiate_chat(recipient=manager, message=initial_message) except Exception as e: print(f"Exception '{type(e)}' in initiating chat: {e}") for i in range(inner_retries): print(f"Inner retry {i + 1}/{inner_retries}") - last_message = str(web_surfer_navigator.last_message()["content"]) + last_message = str(groupchat.messages[-1]["content"]) try: if "I GIVE UP" in last_message: @@ -868,9 +877,9 @@ def get_info_from_the_web_page( current_retries=i, max_retires_before_give_up_message=max_retires_before_give_up_message, ) - web_surfer.send( - retry_message, - recipient=web_surfer_navigator, + manager.send( + message=retry_message, + recipient=manager, ) continue if last_message.strip() == "": @@ -878,10 +887,9 @@ def get_info_from_the_web_page( Message to web_surfer: Please click on the link which you think is the most relevant for the task. 
After that, I will guide you through the next steps.""" - # In this case, web_surfer_navigator is sending the message to web_surfer - web_surfer_navigator.send( - retry_message, - recipient=web_surfer, + manager.send( + message=retry_message, + recipient=manager, ) continue @@ -899,9 +907,9 @@ def get_info_from_the_web_page( current_retries=i, max_retires_before_give_up_message=max_retires_before_give_up_message, ) - web_surfer.send( - retry_message, - recipient=web_surfer_navigator, + manager.send( + message=retry_message, + recipient=manager, ) continue last_message = _format_last_message(url=url, summary=summary) @@ -923,9 +931,9 @@ def get_info_from_the_web_page( current_retries=i, max_retires_before_give_up_message=max_retires_before_give_up_message, ) - web_surfer.send( - retry_message, - recipient=web_surfer_navigator, + manager.send( + message=retry_message, + recipient=manager, ) except Exception as e: @@ -935,7 +943,7 @@ def get_info_from_the_web_page( current_retries=i, max_retires_before_give_up_message=max_retires_before_give_up_message, ) - web_surfer.send(retry_message, recipient=web_surfer_navigator) + manager.send(message=retry_message, recipient=manager) except Exception as e: # todo: log the exception failure_message = str(e) From 1b672f6cfd99762f0d99f6c2bd8bd96999612ff7 Mon Sep 17 00:00:00 2001 From: rjambrecic <32619626+rjambrecic@users.noreply.github.com> Date: Fri, 14 Jun 2024 08:52:11 +0200 Subject: [PATCH 2/4] Implement endpoint for file upload (#768) * Initial endpoint for file uploading implemented * Add user_id and conv_id to the uploadfile endpoint * Save file to disk * Add validation of columns for the uploaded csv * Enable uploading the excel format * Update tests * Async read and write uploaded file --- .gitignore | 1 + captn/captn_agents/application.py | 58 ++++++++- pyproject.toml | 2 + .../ci/captn/captn_agents/fixtures/upload.xls | Bin 0 -> 4763 bytes .../captn/captn_agents/fixtures/upload.xlsx | Bin 0 -> 4763 bytes 
.../captn_agents/test_application.py} | 116 +++++++++++++++++- 6 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 tests/ci/captn/captn_agents/fixtures/upload.xls create mode 100644 tests/ci/captn/captn_agents/fixtures/upload.xlsx rename tests/ci/{test_captn_agents_application.py => captn/captn_agents/test_application.py} (80%) diff --git a/.gitignore b/.gitignore index d7e9a06b..38a00e72 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ client_secret.json .vscode/ benchmarking/working/* +uploaded_files/ diff --git a/captn/captn_agents/application.py b/captn/captn_agents/application.py index 4cc6e6eb..cf7833b2 100644 --- a/captn/captn_agents/application.py +++ b/captn/captn_agents/application.py @@ -1,11 +1,14 @@ import traceback from datetime import date -from typing import Dict, List, Literal, Optional, TypeVar +from pathlib import Path +from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union +import aiofiles import httpx import openai +import pandas as pd from autogen.io.websockets import IOWebsockets -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, File, Form, HTTPException, UploadFile from prometheus_client import Counter from pydantic import BaseModel @@ -170,3 +173,54 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str: send_only_to_emails=request.send_only_to_emails, date=request.date ) return "Weekly analysis has been sent to the specified emails" + + +AVALIABLE_FILE_CONTENT_TYPES = [ + "text/csv", + "application/vnd.ms-excel", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +] +MANDATORY_COLUMNS = {"from_destination", "to_destination"} + +UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files" + + +@router.post("/uploadfile/") +async def create_upload_file( + file: Annotated[UploadFile, File()], + user_id: Annotated[int, Form()], + conv_id: Annotated[int, Form()], +) -> Dict[str, Union[str, None]]: + if 
file.content_type not in AVALIABLE_FILE_CONTENT_TYPES: + raise HTTPException( + status_code=400, + detail=f"Invalid file content type: {file.content_type}. Only {', '.join(AVALIABLE_FILE_CONTENT_TYPES)} are allowed.", + ) + if file.filename is None: + raise HTTPException(status_code=400, detail="Invalid file name") + + # Create a directory if not exists + users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id) + users_conv_dir.mkdir(parents=True, exist_ok=True) + file_path = users_conv_dir / file.filename + + # Async read-write + async with aiofiles.open(file_path, "wb") as out_file: + content = await file.read() + await out_file.write(content) + + # Check if the file has mandatory columns + if file.content_type == "text/csv": + df = pd.read_csv(file_path, nrows=0) + else: + df = pd.read_excel(file_path, nrows=0) + if not MANDATORY_COLUMNS.issubset(df.columns): + # Remove the file + file_path.unlink() + + raise HTTPException( + status_code=400, + detail=f"Missing mandatory columns: {', '.join(MANDATORY_COLUMNS - set(df.columns))}", + ) + + return {"filename": file.filename} diff --git a/pyproject.toml b/pyproject.toml index 509cef20..5ca2c861 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,8 @@ agents = [ "opentelemetry-instrumentation-fastapi==0.46b0", "opentelemetry-instrumentation-logging==0.46b0", "opentelemetry-exporter-otlp==1.25.0", + "openpyxl==3.1.4", + "aiofiles==23.2.1", ] dev = [ diff --git a/tests/ci/captn/captn_agents/fixtures/upload.xls b/tests/ci/captn/captn_agents/fixtures/upload.xls new file mode 100644 index 0000000000000000000000000000000000000000..cf1b206d5028ac09c487c76a02813fb239eef0ea GIT binary patch literal 4763 zcmai12{_dK7RK0j#x`WjmMt{4B(he7v6LlChH1tw6e9aJmN1sA3E2w8Sh9>Y>tsnG zyX;FyqR4V*bnln1d%y0Pd1n6eJpbQ2&w2mnyyra-Z2&P70XaE2ft^5)3BieA!(Dqh zO4zztd)hg}-T!kTb=J$tF|*&oxkZwi>@)Ba?e$4JGd7?alV!iIB<-|-Ge>-EH*M~? 
z-&-wNJ0C(WC#>@BMvt>c)pvC$(3fweoRQ={bH%0wAMJ(D| zAeoZB`hwGIK9&XL3=g7P&9rV@RG2epAUjwI|OT!UEkR`F4A{xQMJ&{~~0cu-!ih97z-|n=JBb zRe0oz>VYDK{2}}HM5jSR_n=kh1XBkgr(UaAjoSUr@G?Qmh6I?V=Xzi@0>BxF?R+Kw zeRiTuq)WE)jHqLoOHd-?p^f12u?%oT*`hCXm@%)22}FKC;H7@c6mqAtX)LRthvsvIO3tNgl&C=9+v&4M@%Ick?`lMB zid|ZwHfn!Q5yEa9!ncM^Blj`TYKt$sc{iZ;1~GFk63j-@%;ly|A`fJ&4CBS3pkL11 zX+=uHBTiF;IIN5-hODF~C0sM00+nJhu>x6FWC_EbcGyyTxHCu0${^(pI`St?Y@G~! z4K#b~-pornk+KV*kY#tINcEr_dS<;1uOG0zz+w;!;`JD1T;V*pUH%DM%)Fl|4RXR#txym8Ta2n$c7K#SZ|v?x zdLH&o1aKo`JD(_CFW-9K)McnBwl1z8YwplrHM*ced6R=eFETvuv=I&iAdjLCeNR5;G={2X9&{-tK`xkYn}^K$PAn!J zk>Jb1@<2HAEv_sGNd8n7d_OzvL|It5!5mK($65UEm_BXF7D;;2X`olQq))<7XoU2d zW{7hh*R3o0bdFs7n{UU`V{)cj=HK-Bwpz*1fzceH^5;df=SU>WyO+CY8G9%0U;r#Y zvluj*-#VLI2lGr<+NN9XPKK{7s7${l1vN#*FnEQO&-1aP?tND3e8n&oFC*4dR~#du zqAdl4+mE`;#Ca=fiHF1=AC_56=|$4&TF{Jqoek|L+^8`Tl%h`LcRd`%w4&fF>K4cs zk=sB#2PK88I%r&mA~+uYxix@4Iq*QhoL~~ipT}L%V?5?O&O?6?bO@NiG!J8h>-5Xh zuXCGSU#VIFmu==Xe;|^zJ8WR{aq`v9xy=W&BANwLMrME8{m7dzFfJuDWg)uwBaq><{nj$$3tQE5M2Q!GqV-|O#5PJ^hBIImR&SM* zJ(hXJRs$XHXw`?@%ubr|Hm2ajkBQBrS&i+=u{qK_b&B%0L%JeEfRTBqv#1JsddUpo znD7vK;pqkbFUhC@wZyQ1FKSPc=w5!rSR67aHjEk2a!FdKX&>>7n&*&Sny0TGZe%wG zwD>FvC>6|b%-r{74+y@bA!R=1&T=^Uyo|Jvowmm1#l_+63tqufoc1*}BW4zU)onh2 z`c(@F-|FCRX7$R9tB&vy&ibhtc&is&u5PV3c=$<3{wKeHDyOL_X9o`e(rs@Z<`t#d z-dnVvnlXAAbeOP(AzGSww$>HkD{>_S5L&k&oG)KtIu z8a`3jxXM?Q`_Mp(?h6w2LRFjbyiP0T+zRrZJ<6Ql-b7-!x$V@d)O5{$X7^m&w5qtx zDl=5yu!D%H1`$dvz6xzkxOJsF8EL4U8FRofi zw3rl&ws0~?d0GjzRQ2XT&*wm`l?#YK*Y(Lrrpq0^Tjr5R5EwEh-+i{PH+az}uK+bd zGX|n?SX5D9dR#hmKpv8%Z7%MJj6xd<@7>9bLT*I^$3|(mTHQ zY9@K81AA(PHC3H~VlE^v^O*~C*V3?D5caB{Q4ZbJ@8s)C@kdP}`iJ9Zfo%DzvLn=n z&W`J{QJB?dN|AgLiU$GN8bPZMp_U8G=CK^-HBXt@DT_!ax72z0KP>%GZn!a!f#BKd zy7}oKtv z=yvfpy|J$HT{3*5!Eytw&*0`y3GJU6jp=8jxg)IIV77)HZupswALS2XRqO#$)EavI z7ueP~^vbVjm1S!13L%D?Kcq!li|L8$*g12DixF9R>`tjI`Md99(R*#(MRB2Pw+TX- zw61wnx*jTF`~#qyOb6!V;^T22mnNyvh4-`^tr<;~Diy$Y78gLP;=~GJ4xF$FRV^wqC>}9u;q> zd?MW8`FhKja6+^00zPS7C#}JpryQdYe2BalK!-K 
zEMgBp75u24$}U01(U>VKb$;~i8&_p#mzq~IWoBOFPi8a&U{4ZRzN9TboRB(uL(oTO zOH(2F$D>)-;0i!&?Vx6dU2o)29uX^m!PTQIr;el$`#HVdEqKvrTBWR(=~ffaA0av_ z-L}vF;f|>LbnfstKcd%$u;kt}u=AXQAG|cB1u_kg-V)-;lwCOINxZ@s6+KW#N>g6^ z?OlKwEl0ZD!YhAujt5J#35)MWl6Ql<#lO>Dwi}!rVAjfx0#YTlD;8|-7M)eStl{=% zxHLyT9O+>Jg<%IcootMQ>4H*55v{J`ed+;x@g3e4@OD2_b6A>2O#w0W!rS=@rHyr^&H2X()|n< zrI91&DfUe1>0mcmA;(k{aj~nxh@oyn<0K0QDx|YyZq&51iuC*;)+7Nr)rCb#b5jw*7rN%u zAH0IS`W15?6dG}ArrB}3YEoPQdZfwZ8oPx~*a5nIY10imv{lc|<*r@rxg+j#$c@ZPEK@F&@%N3|H8a1) z52gbjbkfH|BgeuE+DaG0tv?Nbq*7uiR|*69lR3=&?!l?iV=-jNw}Qw%Cf$9NQVN5m zh#Iwv;TYsKdsKmCvj|cN>D2Ht&;S3 zIf@t=E>AGfLT2cku^r1NqaXNiu1tVeXLN5TQ+8T`UX|4ndDOOl)Q>mu5vViY?H9 z)CKiDKx#3!ky3tJXB4-wKLD>{$)IM#6&xQmM1QJUD%>4E>H4=HRG|%$q$lae`N7*e z2Cml8*)OF89cj3pRH^Y^DzJOrvbJUrL-ini1?yEJA=#x&jTL|XmW_H5i zo%)LeRQhJ(Lgi;}wGrH{6qbZTJPng7; z1Aj`}rLq7s>n67+(N2_`-7GO`hJW==yh$sSt^}oVNRJBlYs@PpevOC;degw4fSj@G zxaESjz$(u<)hkJX71)Gc8~x=WbdtW|zjOa=stEqjgn^7hzA2 zM$gMWy)L&QGsQE{>UvGY^M+qAf1RW-UDD$zGMP9?d|nA<+KhMMx-Vh9ygdN* zjxK@KKA))8P|wcuNnn4fQ>c{yzk0;;^U1W)M{7L~!d@z@VAtV&4aclIMN~j!aEybF zWs#7GiQw0=e{y%lFZ+Lwe_H(^+J8qmx$68KMH~0OIFu9X&9ANZca)RM0)8#}6;Pa* z{2k?gE=j)wo-{@HqWTrLaa7;{f8RKM2RLak@ZRZH^pPC{{FjsZ9pU6i!W)ra@tXW6 z!e6%J_oFA36hE_m#e2%X==jeB`~C1qQNfGtuP~uLKKu{T`#a9bj>SvRuZYB@b3%}Q hKX|g}{?7efoJIW~?G4c;A;qIp;$CSuDO^3y`CnH#ro8|F literal 0 HcmV?d00001 diff --git a/tests/ci/captn/captn_agents/fixtures/upload.xlsx b/tests/ci/captn/captn_agents/fixtures/upload.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..21a1d9474e1ff791a2d61d679d6b28362885945f GIT binary patch literal 4763 zcmai12{_dK7RK0j#x`WjmMt{4B(he7v6LlChH1tw6e9aJmN1sA3E2w8Sh9>Y>tsnG zyX;FyqR4V*bnln1d%y0Pd1n6eJpbQ2&w2mnyyra-Z2&P70XaE2fhKE@3BieA!(Dqh zO4zztd)hg}-T!kTb=J$tF|*&oxkZwi>@)Ba?e$4JGd7?alV!iIB<-|-Ge>-EH*M~? 
z-&-wNJ0C(WC#>@BMvt>c)pvC$(3fweoRQ={bH%0wAMJ(D| zAeoZB`hwGIK9&XL3=g7P&9rV@RG2epAUjwI|OT!UEkR`F4A{xQMJ&{~~0cu-!ih97z-|n=JBb zRe0oz>VYDK{2}}HM5jSR_n=kh1XBkgr(UaAjoSUr@G?Qmh6I?V=Xzi@0>BxF?R+Kw zeRiTuq)WE)jHqLoOHd-?p^f12u?%oT*`hCXm@%)22}FKC;H7@c6mqAtX)LRthvsvIO3tNgl&C=9+v&4M@%Ick?`lMB zid|ZwHfn!Q5yEa9!ncM^Blj`TYKt$sc{iZ;1~GFk63j-@%;ly|A`fJ&4CBS3pkL11 zX+=uHBTiF;IIN5-hODF~C0sM00+nJhu>x6FWC_EbcGyyTxHCu0${^(pI`St?Y@G~! z4K#b~-pornk+KV*kY#tINcEr_dS<;1uOG0zz+w;!;`JD1T;V*pUH%DM%)Fl|4RXR#txym8Ta2n$c7K#SZ|v?x zdLH&o1aKo`JD(_CFW-9K)McnBwl1z8YwplrHM*ced6R=eFETvuv=I&iAdjLCeNR5;G={2X9&{-tK`xkYn}^K$PAn!J zk>Jb1@<2HAEv_sGNd8n7d_OzvL|It5!5mK($65UEm_BXF7D;;2X`olQq))<7XoU2d zW{7hh*R3o0bdFs7n{UU`V{)cj=HK-Bwpz*1fzceH^5;df=SU>WyO+CY8G9%0U;r#Y zvluj*-#VLI2lGr<+NN9XPKK{7s7${l1vN#*FnEQO&-1aP?tND3e8n&oFC*4dR~#du zqAdl4+mE`;#Ca=fiHF1=AC_56=|$4&TF{Jqoek|L+^8`Tl%h`LcRd`%w4&fF>K4cs zk=sB#2PK88I%r&mA~+uYxix@4Iq*QhoL~~ipT}L%V?5?O&O?6?bO@NiG!J8h>-5Xh zuXCGSU#VIFmu==Xe;|^zJ8WR{aq`v9xy=W&BANwLMrME8{m7dzFfJuDWg)uwBaq><{nj$$3tQE5M2Q!GqV-|O#5PJ^hBIImR&SM* zJ(hXJRs$XHXw`?@%ubr|Hm2ajkBQBrS&i+=u{qK_b&B%0L%JeEfRTBqv#1JsddUpo znD7vK;pqkbFUhC@wZyQ1FKSPc=w5!rSR67aHjEk2a!FdKX&>>7n&*&Sny0TGZe%wG zwD>FvC>6|b%-r{74+y@bA!R=1&T=^Uyo|Jvowmm1#l_+63tqufoc1*}BW4zU)onh2 z`c(@F-|FCRX7$R9tB&vy&ibhtc&is&u5PV3c=$<3{wKeHDyOL_X9o`e(rs@Z<`t#d z-dnVvnlXAAbeOP(AzGSww$>HkD{>_S5L&k&oG)KtIu z8a`3jxXM?Q`_Mp(?h6w2LRFjbyiP0T+zRrZJ<6Ql-b7-!x$V@d)O5{$X7^m&w5qtx zDl=5yu!D%H1`$dvz6xzkxOJsF8EL4U8FRofi zw3rl&ws0~?d0GjzRQ2XT&*wm`l?#YK*Y(Lrrpq0^Tjr5R5EwEh-+i{PH+az}uK+bd zGX|n?SX5D9dR#hmKpv8%Z7%MJj6xd<@7>9bLT*I^$3|(mTHQ zY9@K81AA(PHC3H~VlE^v^O*~C*V3?D5caB{Q4ZbJ@8s)C@kdP}`iJ9Zfo%DzvLn=n z&W`J{QJB?dN|AgLiU$GN8bPZMp_U8G=CK^-HBXt@DT_!ax72z0KP>%GZn!a!f#BKd zy7}oKtv z=yvfpy|J$HT{3*5!Eytw&*0`y3GJU6jp=8jxg)IIV77)HZupswALS2XRqO#$)EavI z7ueP~^vbVjm1S!13L%D?Kcq!li|L8$*g12DixF9R>`tjI`Md99(R*#(MRB2Pw+TX- zw61wnx*jTF`~#qyOb6!V;^T22mnNyvh4-`^tr<;~Diy$Y78gLP;=~GJ4xF$FRV^wqC>}9u;q> zd?MW8`FhKja6+^00zPS7C#}JpryQdYe2BalK!-K 
zEMgBp75u24$}U01(U>VKb$;~i8&_p#mzq~IWoBOFPi8a&U{4ZRzN9TboRB(uL(oTO zOH(2F$D>)-;0i!&?Vx6dU2o)29uX^m!PTQIr;el$`#HVdEqKvrTBWR(=~ffaA0av_ z-L}vF;f|>LbnfstKcd%$u;kt}u=AXQAG|cB1u_kg-V)-;lwCOINxZ@s6+KW#N>g6^ z?OlKwEl0ZD!YhAujt5J#35)MWl6Ql<#lO>Dwi}!rVAjfx0#YTlD;8|-7M)eStl{=% zxHLyT9O+>Jg<%IcootMQ>4H*55v{J`ed+;x@g3e4@OD2_b6A>2O#w0W!rS=@rHyr^&H2X()|n< zrI91&DfUe1>0mcmA;(k{aj~nxh@oyn<0K0QDx|YyZq&51iuC*;)+7Nr)rCb#b5jw*7rN%u zAH0IS`W15?6dG}ArrB}3YEoPQdZfwZ8oPx~*a5nIY10imv{lc|<*r@rxg+j#$c@ZPEK@F&@%N3|H8a1) z52gbjbkfH|BgeuE+DaG0tv?Nbq*7uiR|*69lR3=&?!l?iV=-jNw}Qw%Cf$9NQVN5m zh#Iwv;TYsKdsKmCvj|cN>D2Ht&;S3 zIf@t=E>AGfLT2cku^r1NqaXNiu1tVeXLN5TQ+8T`UX|4ndDOOl)Q>mu5vViY?H9 z)CKiDKx#3!ky3tJXB4-wKLD>{$)IM#6&xQmM1QJUD%>4E>H4=HRG|%$q$lae`N7*e z2Cml8*)OF89cj3pRH^Y^DzJOrvbJUrL-ini1?yEJA=#x&jTL|XmW_H5i zo%)LeRQhJ(Lgi;}wGrH{6qbZTJPng7; z1Aj`}rLq7s>n67+(N2_`-7GO`hJW==yh$sSt^}oVNRJBlYs@PpevOC;degw4fSj@G zxaESjz$(u<)hkJX71)Gc8~x=WbdtW|zjOa=stEqjgn^7hzA2 zM$gMWy)L&QGsQE{>UvGY^M+qAf1RW-UDD$zGMP9?d|nA<+KhMMx-Vh9ygdN* zjxK@KKA))8P|wcuNnn4fQ>c{yzk0;;^U1W)M{7L~!d@z@VAtV&4aclIMN~j!aEybF zWs#7GiQw0=e{y%lFZ+Lwe_H(^+J8qmx$68KMH~0OIFu9X&9ANZca)RM0)8#}6;Pa* z{2k?gE=j)wo-{@HqWTrLaa7;{f8RKM2RLak@ZRZH^pPC{{FjsZ9pU6i!W)ra@tXW6 z!e6%J_oFA36hE_m#e2%X==jeB`~C1qQNfGtuP~uLKKu{T`#a9bj>SvRuZYB@b3%}Q hKX|g}{?7efoJIW~?G4c;A;qIp;$CSuDO^3y`Cr6AqUitt literal 0 HcmV?d00001 diff --git a/tests/ci/test_captn_agents_application.py b/tests/ci/captn/captn_agents/test_application.py similarity index 80% rename from tests/ci/test_captn_agents_application.py rename to tests/ci/captn/captn_agents/test_application.py index 8e329ab5..a1a4dacb 100644 --- a/tests/ci/test_captn_agents_application.py +++ b/tests/ci/captn/captn_agents/test_application.py @@ -1,10 +1,15 @@ import unittest from datetime import datetime -from typing import Callable, Dict +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable, Dict, Optional import autogen +import pandas as pd import pytest from autogen.io.websockets 
import IOWebsockets +from fastapi import HTTPException +from fastapi.testclient import TestClient from websockets.sync.client import connect as ws_connect from captn.captn_agents.application import ( @@ -12,6 +17,7 @@ CaptnAgentRequest, _get_message, on_connect, + router, ) from captn.captn_agents.backend.config import Config from captn.captn_agents.backend.tools._functions import TeamResponse @@ -395,3 +401,111 @@ def test_get_message_normal_chat() -> None: actual = _get_message(request) expected = "I want to Remove 'Free' keyword because it is not performing well" assert actual == expected + + +class TestUploadFile: + @pytest.fixture(autouse=True) + def setup(self) -> None: + self.client = TestClient(router) + self.data = { + "user_id": 123, + "conv_id": 456, + } + + def test_upload_file_raises_exception_if_invalid_content_type(self): + # Create a dummy file + file_content = b"Hello, world!" + file_name = "test.txt" + files = {"file": (file_name, file_content, "text/plain")} + + # Send a POST request to the upload endpoint + with pytest.raises(HTTPException) as exc_info: + self.client.post("/uploadfile/", files=files, data=self.data) + + assert exc_info.value.status_code == 400 + assert "Invalid file content type" in exc_info.value.detail + + @pytest.mark.parametrize( + "file_name, file_content, success, content_type", + [ + ( + "test.csv", + b"from_destination,to_destination,additional_column\nvalue1,value2,value3\nvalue1,value2,value3\nvalue1,value2,value3", + True, + "text/csv", + ), + ( + "test.csv", + b"from_destination,additional_column\nvalue1,value3", + False, + "text/csv", + ), + ( + "upload.xlsx", + None, + True, + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ), + ( + "upload.xls", + None, + True, + "application/vnd.ms-excel", + ), + ], + ) + def test_upload_csv_or_xlsx_file( + self, + file_name: str, + file_content: Optional[bytes], + success: bool, + content_type: str, + ): + # Create a dummy CSV file + if file_content is 
None and "upload.xls" in file_name: + file_path = Path(__file__).parent / "fixtures" / file_name + with open(file_path, "rb") as f: + file_content = f.read() + else: + file_content = file_content + file_name = file_name + files = {"file": (file_name, file_content, content_type)} + + with TemporaryDirectory() as tmp_dir: + with unittest.mock.patch( + "captn.captn_agents.application.UPLOADED_FILES_DIR", + Path(tmp_dir), + ) as mock_uploaded_files_dir: + file_path = ( + mock_uploaded_files_dir + / str(self.data["user_id"]) + / str(self.data["conv_id"]) + / file_name + ) + + if success: + response = self.client.post( + "/uploadfile/", files=files, data=self.data + ) + assert response.status_code == 200 + assert response.json() == {"filename": file_name} + # Check if the file was saved + assert file_path.exists() + with open(file_path, "rb") as f: + assert f.read() == file_content + if "xls" in file_name: + df = pd.read_excel(file_path) + else: + df = pd.read_csv(file_path) + + # 3 rows in all test files + assert df.shape[0] == 3 + + else: + with pytest.raises(HTTPException) as exc_info: + self.client.post("/uploadfile/", files=files, data=self.data) + assert not file_path.exists() + assert ( + exc_info.value.detail + == "Missing mandatory columns: to_destination" + ) From 0b672e1fd8be10df98233b4e623d25a4524f8d20 Mon Sep 17 00:00:00 2001 From: rjambrecic <32619626+rjambrecic@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:57:15 +0200 Subject: [PATCH 3/4] Handle exception in weekly analysis if we are not able to create conv id (#776) --- .../captn_agents/backend/teams/_weekly_analysis_team.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/captn/captn_agents/backend/teams/_weekly_analysis_team.py b/captn/captn_agents/backend/teams/_weekly_analysis_team.py index 1ef5c763..33103674 100644 --- a/captn/captn_agents/backend/teams/_weekly_analysis_team.py +++ b/captn/captn_agents/backend/teams/_weekly_analysis_team.py @@ -1078,7 +1078,14 @@ def 
execute_weekly_analysis( print(f"Skipping user_id: {user_id} - email {email}") continue - conv_id, conv_uuid = _get_conv_id_and_uuid(user_id=user_id, email=email) + try: + conv_id, conv_uuid = _get_conv_id_and_uuid(user_id=user_id, email=email) + except Exception as e: + print( + f"Failed to create chat for user_id: {user_id} - email {email}.\nError: {e}" + ) + WEEKLY_ANALYSIS_EXCEPTIONS_TOTAL.inc() + continue weekly_analysis_team = None try: login_url_response = get_login_url( From f428011e9c2a839aa6c65e60939ef5d0f6fc5435 Mon Sep 17 00:00:00 2001 From: rjambrecic <32619626+rjambrecic@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:14:42 +0200 Subject: [PATCH 4/4] Update packages (#777) --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5ca2c861..be8792ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,16 +58,16 @@ lint = [ "mypy==1.10.0", "black==24.4.2", "isort>=5", - "ruff==0.4.8", + "ruff==0.4.9", "pyupgrade-directories", - "bandit==1.7.8", + "bandit==1.7.9", "semgrep==1.75.0", "pre-commit==3.7.1", "detect-secrets==1.5.0", ] test-core = [ - "coverage[toml]==7.5.2", + "coverage[toml]==7.5.3", "pytest==8.2.1", "pytest-asyncio>=0.23.6", "dirty-equals==0.7.1.post0", @@ -85,7 +85,7 @@ testing = [ benchmarking = [ "typer==0.12.3", - "filelock==3.14.0", + "filelock==3.15.1", "tabulate==0.9.0", ] @@ -101,7 +101,7 @@ agents = [ "pandas>=2.1", "fastcore==1.5.35", "asyncer==0.0.7", - "pydantic==2.7.2", + "pydantic==2.7.4", "markdownify==0.12.1", "tenacity==8.3.0", "prometheus-client==0.20.0",