Skip to content

Commit

Permalink
Merge pull request #778 from airtai/dev
Browse files Browse the repository at this point in the history
3 PRs
  • Loading branch information
kumaranvpl committed Jun 17, 2024
2 parents a43f15f + f428011 commit 453d7bb
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ client_secret.json

.vscode/
benchmarking/working/*
uploaded_files/
58 changes: 56 additions & 2 deletions captn/captn_agents/application.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import traceback
from datetime import date
from typing import Dict, List, Literal, Optional, TypeVar
from pathlib import Path
from typing import Annotated, Dict, List, Literal, Optional, TypeVar, Union

import aiofiles
import httpx
import openai
import pandas as pd
from autogen.io.websockets import IOWebsockets
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, File, Form, HTTPException, UploadFile
from prometheus_client import Counter
from pydantic import BaseModel

Expand Down Expand Up @@ -170,3 +173,54 @@ def weekly_analysis(request: WeeklyAnalysisRequest) -> str:
send_only_to_emails=request.send_only_to_emails, date=request.date
)
return "Weekly analysis has been sent to the specified emails"


AVALIABLE_FILE_CONTENT_TYPES = [
"text/csv",
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
]
MANDATORY_COLUMNS = {"from_destination", "to_destination"}

UPLOADED_FILES_DIR = Path(__file__).resolve().parent.parent.parent / "uploaded_files"


@router.post("/uploadfile/")
async def create_upload_file(
file: Annotated[UploadFile, File()],
user_id: Annotated[int, Form()],
conv_id: Annotated[int, Form()],
) -> Dict[str, Union[str, None]]:
if file.content_type not in AVALIABLE_FILE_CONTENT_TYPES:
raise HTTPException(
status_code=400,
detail=f"Invalid file content type: {file.content_type}. Only {', '.join(AVALIABLE_FILE_CONTENT_TYPES)} are allowed.",
)
if file.filename is None:
raise HTTPException(status_code=400, detail="Invalid file name")

# Create a directory if not exists
users_conv_dir = UPLOADED_FILES_DIR / str(user_id) / str(conv_id)
users_conv_dir.mkdir(parents=True, exist_ok=True)
file_path = users_conv_dir / file.filename

# Async read-write
async with aiofiles.open(file_path, "wb") as out_file:
content = await file.read()
await out_file.write(content)

# Check if the file has mandatory columns
if file.content_type == "text/csv":
df = pd.read_csv(file_path, nrows=0)
else:
df = pd.read_excel(file_path, nrows=0)
if not MANDATORY_COLUMNS.issubset(df.columns):
# Remove the file
file_path.unlink()

raise HTTPException(
status_code=400,
detail=f"Missing mandatory columns: {', '.join(MANDATORY_COLUMNS - set(df.columns))}",
)

return {"filename": file.filename}
9 changes: 8 additions & 1 deletion captn/captn_agents/backend/teams/_weekly_analysis_team.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,14 @@ def execute_weekly_analysis(
print(f"Skipping user_id: {user_id} - email {email}")
continue

conv_id, conv_uuid = _get_conv_id_and_uuid(user_id=user_id, email=email)
try:
conv_id, conv_uuid = _get_conv_id_and_uuid(user_id=user_id, email=email)
except Exception as e:
print(
f"Failed to create chat for user_id: {user_id} - email {email}.\nError: {e}"
)
WEEKLY_ANALYSIS_EXCEPTIONS_TOTAL.inc()
continue
weekly_analysis_team = None
try:
login_url_response = get_login_url(
Expand Down
44 changes: 26 additions & 18 deletions captn/captn_agents/backend/tools/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def get_get_info_from_the_web_page(
websurfer_navigator_llm_config: Optional[Dict[str, Any]] = None,
timestamp: Optional[str] = None,
max_retires_before_give_up_message: int = 7,
max_round: int = 50,
) -> Callable[[str, int, int], str]:
fx = summarizer_llm_config, websurfer_llm_config, websurfer_navigator_llm_config

Expand Down Expand Up @@ -833,6 +834,16 @@ def get_info_from_the_web_page(
# is_termination_msg=_is_termination_msg,
)

groupchat = autogen.GroupChat(
agents=[web_surfer, web_surfer_navigator],
messages=[],
max_round=max_round,
speaker_selection_method="round_robin",
)
manager = autogen.GroupChatManager(
groupchat=groupchat,
)

initial_message = (
f"Time now is {timestamp_copy}." if timestamp_copy else ""
)
Expand All @@ -844,15 +855,13 @@ def get_info_from_the_web_page(
"""

try:
web_surfer_navigator.initiate_chat(
web_surfer, message=initial_message
)
manager.initiate_chat(recipient=manager, message=initial_message)
except Exception as e:
print(f"Exception '{type(e)}' in initiating chat: {e}")

for i in range(inner_retries):
print(f"Inner retry {i + 1}/{inner_retries}")
last_message = str(web_surfer_navigator.last_message()["content"])
last_message = str(groupchat.messages[-1]["content"])

try:
if "I GIVE UP" in last_message:
Expand All @@ -868,20 +877,19 @@ def get_info_from_the_web_page(
current_retries=i,
max_retires_before_give_up_message=max_retires_before_give_up_message,
)
web_surfer.send(
retry_message,
recipient=web_surfer_navigator,
manager.send(
message=retry_message,
recipient=manager,
)
continue
if last_message.strip() == "":
retry_message = """Reminder to myself: we do not have any bad attempts, we are just trying to get the information from the web page.
Message to web_surfer: Please click on the link which you think is the most relevant for the task.
After that, I will guide you through the next steps."""
# In this case, web_surfer_navigator is sending the message to web_surfer
web_surfer_navigator.send(
retry_message,
recipient=web_surfer,
manager.send(
message=retry_message,
recipient=manager,
)
continue

Expand All @@ -899,9 +907,9 @@ def get_info_from_the_web_page(
current_retries=i,
max_retires_before_give_up_message=max_retires_before_give_up_message,
)
web_surfer.send(
retry_message,
recipient=web_surfer_navigator,
manager.send(
message=retry_message,
recipient=manager,
)
continue
last_message = _format_last_message(url=url, summary=summary)
Expand All @@ -923,9 +931,9 @@ def get_info_from_the_web_page(
current_retries=i,
max_retires_before_give_up_message=max_retires_before_give_up_message,
)
web_surfer.send(
retry_message,
recipient=web_surfer_navigator,
manager.send(
message=retry_message,
recipient=manager,
)

except Exception as e:
Expand All @@ -935,7 +943,7 @@ def get_info_from_the_web_page(
current_retries=i,
max_retires_before_give_up_message=max_retires_before_give_up_message,
)
web_surfer.send(retry_message, recipient=web_surfer_navigator)
manager.send(message=retry_message, recipient=manager)
except Exception as e:
# todo: log the exception
failure_message = str(e)
Expand Down
12 changes: 7 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ lint = [
"mypy==1.10.0",
"black==24.4.2",
"isort>=5",
"ruff==0.4.8",
"ruff==0.4.9",
"pyupgrade-directories",
"bandit==1.7.8",
"bandit==1.7.9",
"semgrep==1.75.0",
"pre-commit==3.7.1",
"detect-secrets==1.5.0",
]

test-core = [
"coverage[toml]==7.5.2",
"coverage[toml]==7.5.3",
"pytest==8.2.1",
"pytest-asyncio>=0.23.6",
"dirty-equals==0.7.1.post0",
Expand All @@ -85,7 +85,7 @@ testing = [

benchmarking = [
"typer==0.12.3",
"filelock==3.14.0",
"filelock==3.15.1",
"tabulate==0.9.0",
]

Expand All @@ -101,14 +101,16 @@ agents = [
"pandas>=2.1",
"fastcore==1.5.35",
"asyncer==0.0.7",
"pydantic==2.7.2",
"pydantic==2.7.4",
"markdownify==0.12.1",
"tenacity==8.3.0",
"prometheus-client==0.20.0",
"opentelemetry-distro==0.46b0",
"opentelemetry-instrumentation-fastapi==0.46b0",
"opentelemetry-instrumentation-logging==0.46b0",
"opentelemetry-exporter-otlp==1.25.0",
"openpyxl==3.1.4",
"aiofiles==23.2.1",
]

dev = [
Expand Down
Binary file added tests/ci/captn/captn_agents/fixtures/upload.xls
Binary file not shown.
Binary file added tests/ci/captn/captn_agents/fixtures/upload.xlsx
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import unittest
from datetime import datetime
from typing import Callable, Dict
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Dict, Optional

import autogen
import pandas as pd
import pytest
from autogen.io.websockets import IOWebsockets
from fastapi import HTTPException
from fastapi.testclient import TestClient
from websockets.sync.client import connect as ws_connect

from captn.captn_agents.application import (
ON_FAILURE_MESSAGE,
CaptnAgentRequest,
_get_message,
on_connect,
router,
)
from captn.captn_agents.backend.config import Config
from captn.captn_agents.backend.tools._functions import TeamResponse
Expand Down Expand Up @@ -395,3 +401,111 @@ def test_get_message_normal_chat() -> None:
actual = _get_message(request)
expected = "I want to Remove 'Free' keyword because it is not performing well"
assert actual == expected


class TestUploadFile:
@pytest.fixture(autouse=True)
def setup(self) -> None:
self.client = TestClient(router)
self.data = {
"user_id": 123,
"conv_id": 456,
}

def test_upload_file_raises_exception_if_invalid_content_type(self):
# Create a dummy file
file_content = b"Hello, world!"
file_name = "test.txt"
files = {"file": (file_name, file_content, "text/plain")}

# Send a POST request to the upload endpoint
with pytest.raises(HTTPException) as exc_info:
self.client.post("/uploadfile/", files=files, data=self.data)

assert exc_info.value.status_code == 400
assert "Invalid file content type" in exc_info.value.detail

@pytest.mark.parametrize(
"file_name, file_content, success, content_type",
[
(
"test.csv",
b"from_destination,to_destination,additional_column\nvalue1,value2,value3\nvalue1,value2,value3\nvalue1,value2,value3",
True,
"text/csv",
),
(
"test.csv",
b"from_destination,additional_column\nvalue1,value3",
False,
"text/csv",
),
(
"upload.xlsx",
None,
True,
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
),
(
"upload.xls",
None,
True,
"application/vnd.ms-excel",
),
],
)
def test_upload_csv_or_xlsx_file(
self,
file_name: str,
file_content: Optional[bytes],
success: bool,
content_type: str,
):
# Create a dummy CSV file
if file_content is None and "upload.xls" in file_name:
file_path = Path(__file__).parent / "fixtures" / file_name
with open(file_path, "rb") as f:
file_content = f.read()
else:
file_content = file_content
file_name = file_name
files = {"file": (file_name, file_content, content_type)}

with TemporaryDirectory() as tmp_dir:
with unittest.mock.patch(
"captn.captn_agents.application.UPLOADED_FILES_DIR",
Path(tmp_dir),
) as mock_uploaded_files_dir:
file_path = (
mock_uploaded_files_dir
/ str(self.data["user_id"])
/ str(self.data["conv_id"])
/ file_name
)

if success:
response = self.client.post(
"/uploadfile/", files=files, data=self.data
)
assert response.status_code == 200
assert response.json() == {"filename": file_name}
# Check if the file was saved
assert file_path.exists()
with open(file_path, "rb") as f:
assert f.read() == file_content
if "xls" in file_name:
df = pd.read_excel(file_path)
else:
df = pd.read_csv(file_path)

# 3 rows in all test files
assert df.shape[0] == 3

else:
with pytest.raises(HTTPException) as exc_info:
self.client.post("/uploadfile/", files=files, data=self.data)
assert not file_path.exists()
assert (
exc_info.value.detail
== "Missing mandatory columns: to_destination"
)

0 comments on commit 453d7bb

Please sign in to comment.