From b67af2e1df656ff5bddf6fcb04c1402ed17f1cd3 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Mon, 13 Jan 2025 11:27:09 -0800 Subject: [PATCH 1/7] chore(backend): add type hints to integration tests --- .../tests/integration/crud/test_deployment.py | 30 +++++++++--------- .../tests/integration/crud/test_model.py | 31 ++++++++++--------- .../tests/integration/routers/test_agent.py | 10 +++--- .../tests/integration/routers/test_model.py | 30 +++++++++++++----- 4 files changed, 60 insertions(+), 41 deletions(-) diff --git a/src/backend/tests/integration/crud/test_deployment.py b/src/backend/tests/integration/crud/test_deployment.py index 48c2c7bc74..908d57e844 100644 --- a/src/backend/tests/integration/crud/test_deployment.py +++ b/src/backend/tests/integration/crud/test_deployment.py @@ -1,12 +1,14 @@ import pytest +from sqlalchemy.orm import Session from backend.crud import deployment as deployment_crud from backend.database_models.deployment import Deployment from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate +from backend.schemas.user import User from backend.tests.unit.factories import get_factory -def test_create_deployment(session, deployment): +def test_create_deployment(session: Session, deployment: Deployment) -> None: deployment_data = DeploymentCreate( name="Test Deployment", deployment_class_name="CohereDeployment", @@ -29,7 +31,7 @@ def test_create_deployment(session, deployment): assert deployment.name == deployment_data.name -def test_create_deployment_invalid_class_name(session): +def test_create_deployment_invalid_class_name(session: Session) -> None: with pytest.raises(ValueError) as e: deployment_data = DeploymentCreate( name="Test Deployment", @@ -43,19 +45,19 @@ def test_create_deployment_invalid_class_name(session): assert "Deployment class not found" in str(e.value) -def test_get_deployment(session): +def test_get_deployment(session: Session) -> None: deployment = get_factory("Deployment", session).create(name="Test Deployment") retrieved_deployment = deployment_crud.get_deployment(session, deployment.id) assert retrieved_deployment.id == deployment.id assert retrieved_deployment.name == deployment.name -def test_fail_get_nonexistent_deployment(session): +def test_fail_get_nonexistent_deployment(session: Session) -> None: deployment = deployment_crud.get_deployment(session, "123") assert deployment is None -def test_list_deployments(session): +def test_list_deployments(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() _ = get_factory("Deployment", session).create(name="Test Deployment") @@ -65,14 +67,14 @@ def test_list_deployments(session): assert deployments[0].name == "Test Deployment" -def test_list_deployments_empty(session): +def test_list_deployments_empty(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() deployments = deployment_crud.get_deployments(session) assert len(deployments) == 0 -def test_list_deployments_with_pagination(session): +def test_list_deployments_with_pagination(session: Session) -> None: # Delete default deployments session.query(Deployment).delete() for i in range(10): @@ -82,7 +84,7 @@ def test_list_deployments_with_pagination(session): assert len(deployments) == 5 -def test_get_available_deployments(session, user): +def test_get_available_deployments(session: Session, user: User) -> None: session.query(Deployment).delete() deployment = get_factory("Deployment", session).create() _ = get_factory("Deployment", session).create( @@ -95,14 +97,14 @@ def test_get_available_deployments(session, user): assert deployments[0].id == deployment.id -def test_get_available_deployments_empty(session, user): +def test_get_available_deployments_empty(session: Session, user: User) -> None: session.query(Deployment).delete() deployments = deployment_crud.get_available_deployments(session) assert len(deployments) == 0 -def test_update_deployment(session, deployment): +def test_update_deployment(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate( name="NewName", description="New Description", @@ -122,7 +124,7 @@ def test_update_deployment(session, deployment): assert updated_deployment.id == deployment.id -def test_update_deployment_partial(session, deployment): +def test_update_deployment_partial(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate(name="Cohere") updated_deployment = deployment_crud.update_deployment( @@ -133,7 +135,7 @@ def test_update_deployment_partial(session, deployment): assert updated_deployment.id == deployment.id -def test_do_not_update_deployment(session, deployment): +def test_do_not_update_deployment(session: Session, deployment: Deployment) -> None: new_deployment_data = DeploymentUpdate(name="Test Deployment") updated_deployment = deployment_crud.update_deployment( @@ -142,7 +144,7 @@ def test_do_not_update_deployment(session, deployment): assert updated_deployment.name == deployment.name -def test_delete_deployment(session): +def test_delete_deployment(session: Session) -> None: deployment = get_factory("Deployment", session).create() deployment_crud.delete_deployment(session, deployment.id) @@ -151,7 +153,7 @@ def test_delete_deployment(session): assert deployment is None -def test_delete_nonexistent_deployment(session): +def test_delete_nonexistent_deployment(session: Session) -> None: deployment_crud.delete_deployment(session, "123") # no error deployment = deployment_crud.get_deployment(session, "123") assert deployment is None diff --git a/src/backend/tests/integration/crud/test_model.py b/src/backend/tests/integration/crud/test_model.py index 5fbd3bacf0..95389177bc 100644 --- a/src/backend/tests/integration/crud/test_model.py +++ b/src/backend/tests/integration/crud/test_model.py @@ -1,10 +1,13 @@ +from sqlalchemy.orm import Session + from backend.crud import model as model_crud +from backend.database_models.deployment import Deployment from backend.database_models.model import Model from backend.schemas.model import ModelCreate, ModelUpdate from backend.tests.unit.factories import get_factory -def test_create_model(session, deployment): +def test_create_model(session: Session, deployment: Deployment) -> None: model_data = ModelCreate( name="Test Model", cohere_name="Test Cohere Model", @@ -21,7 +24,7 @@ def test_create_model(session, deployment): assert model.name == model_data.name -def test_get_model(session, deployment): +def test_get_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model", deployment=deployment ) @@ -30,12 +33,12 @@ def test_get_model(session, deployment): assert retrieved_model.name == model.name -def test_fail_get_nonexistent_model(session): +def test_fail_get_nonexistent_model(session: Session) -> None: model = model_crud.get_model(session, "123") assert model is None -def test_list_models(session, deployment): +def test_list_models(session: Session, deployment: Deployment) -> None: # Delete default models session.query(Model).delete() _ = get_factory("Model", session).create(name="Test Model", deployment=deployment) @@ -45,14 +48,14 @@ def test_list_models(session, deployment): assert models[0].name == "Test Model" -def test_list_models_empty(session): +def test_list_models_empty(session: Session) -> None: # Delete default models session.query(Model).delete() models = model_crud.get_models(session) assert len(models) == 0 -def test_list_models_with_pagination(session, deployment): +def test_list_models_with_pagination(session: Session, deployment: Deployment) -> None: # Delete default models session.query(Model).delete() for i in range(10): @@ -67,7 +70,7 @@ def test_list_models_with_pagination(session, deployment): assert model.name == f"Test Model {i + 5}" -def test_get_models_by_deployment_id(session, deployment): +def test_get_models_by_deployment_id(session: Session, deployment: Deployment) -> None: for i in range(10): model = get_factory("Model", session).create( name=f"Test Model {i}", deployment=deployment @@ -80,12 +83,12 @@ def test_get_models_by_deployment_id(session, deployment): assert model.name == f"Test Model {i}" -def test_get_models_by_deployment_id_empty(session, deployment): +def test_get_models_by_deployment_id_empty(session: Session, deployment: Deployment) -> None: models = model_crud.get_models_by_deployment_id(session, deployment.id) assert len(models) == 0 -def test_get_models_by_deployment_id_with_pagination(session, deployment): +def test_get_models_by_deployment_id_with_pagination(session: Session, deployment: Deployment) -> None: for i in range(10): model = get_factory("Model", session).create( name=f"Test Model {i}", deployment=deployment @@ -100,7 +103,7 @@ def test_get_models_by_deployment_id_with_pagination(session, deployment): assert model.name == f"Test Model {i + 5}" -def test_update_model(session, deployment): +def test_update_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Sagemaker model", deployment=deployment ) @@ -127,7 +130,7 @@ def test_update_model(session, deployment): assert model.deployment_id == new_model_data.deployment_id -def test_update_model_partial(session, deployment): +def test_update_model_partial(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model U", deployment=deployment ) @@ -148,7 +151,7 @@ def test_update_model_partial(session, deployment): assert model.deployment_id == model.deployment_id -def test_do_not_update_model(session, deployment): +def test_do_not_update_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create( name="Test Model", deployment=deployment ) @@ -159,7 +162,7 @@ def test_do_not_update_model(session, deployment): assert updated_model.name == model.name -def test_delete_model(session, deployment): +def test_delete_model(session: Session, deployment: Deployment) -> None: model = get_factory("Model", session).create(deployment=deployment) model_crud.delete_model(session, model.id) @@ -168,7 +171,7 @@ def test_delete_model(session, deployment): assert model is None -def test_delete_nonexistent_model(session): +def test_delete_nonexistent_model(session: Session) -> None: model_crud.delete_model(session, "123") # no error model = model_crud.get_model(session, "123") assert model is None diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 9ba0be0649..f39749c588 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -1,4 +1,3 @@ - from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -6,10 +5,11 @@ from backend.config.tools import Tool from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata +from backend.schemas.user import User from backend.tests.unit.factories import get_factory -def test_create_agent(session_client: TestClient, session: Session, user) -> None: +def test_create_agent(session_client: TestClient, session: Session, user: User) -> None: request_json = { "name": "test agent", "version": 1, @@ -49,7 +49,7 @@ def test_create_agent(session_client: TestClient, session: Session, user) -> Non def test_create_agent_with_tool_metadata( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: request_json = { "name": "test agent", @@ -107,7 +107,7 @@ def test_create_agent_with_tool_metadata( def test_create_agent_missing_non_required_fields( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: request_json = { "name": "test agent", @@ -138,7 +138,7 @@ def test_create_agent_missing_non_required_fields( assert agent.model == request_json["model"] -def test_update_agent(session_client: TestClient, session: Session, user) -> None: +def test_update_agent(session_client: TestClient, session: Session, user: User) -> None: agent = get_factory("Agent", session).create( name="test agent", version=1, diff --git a/src/backend/tests/integration/routers/test_model.py b/src/backend/tests/integration/routers/test_model.py index 133a684859..5f56bf929c 100644 --- a/src/backend/tests/integration/routers/test_model.py +++ b/src/backend/tests/integration/routers/test_model.py @@ -2,10 +2,11 @@ from sqlalchemy.orm import Session from backend.database_models import Model +from backend.database_models.deployment import Deployment from backend.tests.unit.factories import get_factory -def test_create_model(session_client: TestClient, session: Session, deployment) -> None: +def test_create_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: request_json = { "name": "sagemaker-command-created", "cohere_name": "command", @@ -29,7 +30,7 @@ def test_create_model(session_client: TestClient, session: Session, deployment) def test_create_model_non_existing_deployment( - session_client: TestClient, session: Session + session_client: TestClient, session: Session, ) -> None: request_json = { "name": "sagemaker-command-created", @@ -50,7 +51,7 @@ def test_create_model_non_existing_deployment( ) -def test_update_model(session_client: TestClient, session: Session, deployment) -> None: +def test_update_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: request_json = { "name": "sagemaker-command-updated", "cohere_name": "command", @@ -69,7 +70,7 @@ def test_update_model(session_client: TestClient, session: Session, deployment) assert model.deployment_id == response_json["deployment_id"] -def test_get_model(session_client: TestClient, session: Session, deployment) -> None: +def test_get_model(session_client: TestClient, session: Session, deployment: Deployment) -> None: # Delete all models session.query(Model).delete() model = get_factory("Model", session).create(deployment=deployment) @@ -89,7 +90,11 @@ def test_get_model_non_existing(session_client: TestClient, session: Session) -> assert "Model not found" in response_json["detail"] -def test_list_models(session_client: TestClient, session: Session, deployment) -> None: +def test_list_models( + session_client: TestClient, + session: Session, + deployment: Deployment, +) -> None: # Delete all models session.query(Model).delete() for _ in range(5): @@ -101,7 +106,10 @@ def test_list_models(session_client: TestClient, session: Session, deployment) - assert len(models) == 5 -def test_list_models_empty(session_client: TestClient, session: Session) -> None: +def test_list_models_empty( + session_client: TestClient, + session: Session, +) -> None: session.query(Model).delete() response = session_client.get("/v1/models") assert response.status_code == 200 @@ -110,7 +118,9 @@ def test_list_models_empty(session_client: TestClient, session: Session) -> None def test_list_models_with_pagination( - session_client: TestClient, session: Session, deployment + session_client: TestClient, + session: Session, + deployment: Deployment, ) -> None: # Delete all models session.query(Model).delete() @@ -128,7 +138,11 @@ def test_list_models_with_pagination( assert model["name"] == f"Test Model {i + 5}" -def test_delete_model(session_client: TestClient, session: Session, deployment) -> None: +def test_delete_model( + session_client: TestClient, + session: Session, + deployment: Deployment, +) -> None: model = get_factory("Model", session).create(deployment=deployment) response = session_client.delete(f"/v1/models/{model.id}") assert response.status_code == 200 From c4bf297fb4847e31ab2fa6afdda94ed774984bb7 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 06:07:38 -0800 Subject: [PATCH 2/7] chore(backend): remove deprecated method --- src/backend/routers/organization.py | 2 +- src/backend/tests/unit/routers/test_organization.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/backend/routers/organization.py b/src/backend/routers/organization.py index 6c252f7c6c..e88fed066b 100644 --- a/src/backend/routers/organization.py +++ b/src/backend/routers/organization.py @@ -41,7 +41,7 @@ def create_organization( Returns: Organization: Created organization. """ - organization_data = OrganizationModel(**organization.dict()) + organization_data = OrganizationModel(**organization.model_dump()) return organization_crud.create_organization(session, organization_data) diff --git a/src/backend/tests/unit/routers/test_organization.py b/src/backend/tests/unit/routers/test_organization.py index 658321c992..67459a10fb 100644 --- a/src/backend/tests/unit/routers/test_organization.py +++ b/src/backend/tests/unit/routers/test_organization.py @@ -8,7 +8,7 @@ def test_create_organization(session_client: TestClient, session: Session) -> None: organization = CreateOrganization(name="test organization") - response = session_client.post("/v1/organizations", json=organization.dict()) + response = session_client.post("/v1/organizations", json=organization.model_dump()) assert response.status_code == 200 assert response.json()["name"] == organization.name @@ -18,7 +18,7 @@ def test_create_organization_with_existing_name( ) -> None: get_factory("Organization", session).create(name="test organization") new_organization = CreateOrganization(name="test organization") - response = session_client.post("/v1/organizations", json=new_organization.dict()) + response = session_client.post("/v1/organizations", json=new_organization.model_dump()) assert response.status_code == 400 assert response.json() == { "detail": "Organization with name: test organization already exists." @@ -29,7 +29,7 @@ def test_update_organization(session_client: TestClient, session: Session) -> No organization = get_factory("Organization", session).create(name="test organization") new_organization = UpdateOrganization(name="new organization") response = session_client.put( - f"/v1/organizations/{organization.id}", json=new_organization.dict() + f"/v1/organizations/{organization.id}", json=new_organization.model_dump() ) assert response.status_code == 200 assert response.json()["name"] == new_organization.name @@ -39,7 +39,7 @@ def test_update_not_existing_organization( session_client: TestClient, session: Session ) -> None: new_organization = UpdateOrganization(name="new organization") - response = session_client.put("/v1/organizations/123", json=new_organization.dict()) + response = session_client.put("/v1/organizations/123", json=new_organization.model_dump()) assert response.status_code == 404 assert response.json() == {"detail": "Organization with ID: 123 not found."} From 3c4d8c8dde08dd9d51d557162417d584ed96ea38 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 12:28:46 -0800 Subject: [PATCH 3/7] chore(ci): improve readability --- .github/workflows/backend_integration_tests.yml | 9 ++++++++- .github/workflows/backend_unit_tests.yml | 7 +++++++ src/backend/pytest_integration.ini | 2 ++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/.github/workflows/backend_integration_tests.yml b/.github/workflows/backend_integration_tests.yml index 8537b85670..9d398d51b5 100644 --- a/.github/workflows/backend_integration_tests.yml +++ b/.github/workflows/backend_integration_tests.yml @@ -8,16 +8,18 @@ on: jobs: pytest: permissions: write-all - environment: development + # environment: development runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@v3 + - uses: actions/setup-python@v5 with: python-version: '3.11' cache: 'pip' + - name: Install poetry uses: snok/install-poetry@v1 with: @@ -25,23 +27,28 @@ jobs: virtualenvs-in-project: true virtualenvs-path: .venv installer-parallel: true + - name: Load cached venv id: cached-poetry-dependencies uses: actions/cache@v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --with dev --no-interaction --no-root + - name: Setup test DB container run: make test-db + - name: Test with pytest if: github.actor != 'dependabot[bot]' run: | make run-integration-tests env: PYTHONPATH: src + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 with: diff --git a/.github/workflows/backend_unit_tests.yml b/.github/workflows/backend_unit_tests.yml index 8664706c75..a175c3a79e 100644 --- a/.github/workflows/backend_unit_tests.yml +++ b/.github/workflows/backend_unit_tests.yml @@ -14,10 +14,12 @@ jobs: steps: - name: Checkout repo uses: actions/checkout@v3 + - uses: actions/setup-python@v5 with: python-version: '3.11' cache: 'pip' + - name: Install poetry uses: snok/install-poetry@v1 with: @@ -25,23 +27,28 @@ jobs: virtualenvs-in-project: true virtualenvs-path: .venv installer-parallel: true + - name: Load cached venv id: cached-poetry-dependencies uses: actions/cache@v4 with: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} + - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --with dev --no-interaction --no-root + - name: Setup test DB container run: make test-db + - name: Test with pytest if: github.actor != 'dependabot[bot]' run: | make run-unit-tests-debug env: PYTHONPATH: src + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 with: diff --git a/src/backend/pytest_integration.ini b/src/backend/pytest_integration.ini index bc9ea9572c..c686703e0c 100644 --- a/src/backend/pytest_integration.ini +++ b/src/backend/pytest_integration.ini @@ -1,3 +1,5 @@ [pytest] env = DATABASE_URL=postgresql://postgres:postgres@db:5432/postgres +filterwarnings = + ignore::UserWarning:pydantic.* From 0a42a0d50f3527e096e8c2dab9a9bbc263e1705c Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 12:30:55 -0800 Subject: [PATCH 4/7] chore(backend): fix type hints --- src/backend/model_deployments/azure.py | 6 ++--- src/backend/model_deployments/base.py | 8 +++--- src/backend/model_deployments/bedrock.py | 6 ++--- .../model_deployments/cohere_platform.py | 6 ++--- src/backend/model_deployments/sagemaker.py | 6 ++--- .../model_deployments/single_container.py | 6 ++--- src/backend/schemas/context.py | 26 +++++++++---------- src/backend/services/conversation.py | 16 ++++++------ src/backend/tests/integration/conftest.py | 2 +- .../mock_deployments/mock_azure.py | 6 ++--- .../mock_deployments/mock_bedrock.py | 4 +-- .../mock_deployments/mock_sagemaker.py | 4 +-- .../mock_deployments/mock_single_container.py | 4 +-- 13 files changed, 50 insertions(+), 50 deletions(-) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index e7849f0371..3829f8819e 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -49,7 +49,7 @@ def rerank_enabled(self) -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] @@ -79,6 +79,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context ) -> Any: return None diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index 6436421e5a..c235cca76a 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context @@ -19,8 +19,8 @@ class BaseDeployment: @abstractmethod def rerank_enabled(self) -> bool: ... - @staticmethod - def list_models() -> List[str]: ... + @classmethod + def list_models(cls) -> list[str]: ... @staticmethod def is_available() -> bool: ... @@ -37,5 +37,5 @@ async def invoke_chat_stream( @abstractmethod async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: ... diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index 094ed243a3..f5dd726266 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -53,7 +53,7 @@ def rerank_enabled(self) -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] @@ -94,6 +94,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context ) -> Any: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index f8da71693d..d1d7cd98c5 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any import cohere import requests @@ -34,7 +34,7 @@ def rerank_enabled(self) -> bool: return True @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: logger = LoggerFactory().get_logger() if not CohereDeployment.is_available(): return [] @@ -91,7 +91,7 @@ async def invoke_chat_stream( yield event_dict async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: response = self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index 56d2a96555..5777c59c0a 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -1,6 +1,6 @@ import io import json -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import boto3 @@ -77,7 +77,7 @@ def rerank_enabled(self) -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not SageMakerDeployment.is_available(): return [] @@ -114,7 +114,7 @@ async def invoke_chat_stream( yield stream_event async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context ) -> Any: return None diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 9c727a2186..64639f2bc3 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -1,4 +1,4 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator import cohere @@ -39,7 +39,7 @@ def rerank_enabled(self) -> bool: return SingleContainerDeployment.default_model.startswith("rerank") @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not SingleContainerDeployment.is_available(): return [] @@ -73,7 +73,7 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context + self, query: str, documents: list[str], ctx: Context ) -> Any: return self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/schemas/context.py b/src/backend/schemas/context.py index 19365e62a8..29a288fcd1 100644 --- a/src/backend/schemas/context.py +++ b/src/backend/schemas/context.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Self from pydantic import BaseModel @@ -65,7 +65,7 @@ def with_deployment_name(self, deployment_name: str): def with_user( self, session: DBSessionDep | None = None, user: User | None = None - ) -> "Context": + ) -> Self: if not user and not session: return self @@ -78,42 +78,42 @@ def with_user( return self - def with_agent(self, agent: Agent | None) -> "Context": + def with_agent(self, agent: Agent | None) -> Self: self.agent = agent return self def with_agent_tool_metadata( self, agent_tool_metadata: AgentToolMetadata - ) -> "Context": + ) -> Self: self.agent_tool_metadata = agent_tool_metadata return self - def with_model(self, model: str) -> "Context": + def with_model(self, model: str) -> Self: self.model = model return self - def with_deployment_config(self, deployment_config=None) -> "Context": + def with_deployment_config(self, deployment_config=None) -> Self: if deployment_config: self.deployment_config = deployment_config else: self.deployment_config = get_deployment_config(self.request) return self - def with_conversation_id(self, conversation_id: str) -> "Context": + def with_conversation_id(self, conversation_id: str) -> Self: self.conversation_id = conversation_id return self - def with_stream_start_ms(self, now_ms: float) -> "Context": + def with_stream_start_ms(self, now_ms: float) -> Self: self.stream_start_ms = now_ms - def with_agent_id(self, agent_id: str) -> "Context": + def with_agent_id(self, agent_id: str) -> Self: if not agent_id: return self self.agent_id = agent_id return self - def with_organization_id(self, organization_id: str) -> "Context": + def with_organization_id(self, organization_id: str) -> Self: self.organization_id = organization_id return self @@ -121,7 +121,7 @@ def with_organization( self, session: DBSessionDep | None = None, organization: Organization | None = None, - ) -> "Context": + ) -> Self: if not organization and not session: return self @@ -138,11 +138,11 @@ def with_organization( return self - def with_global_filtering(self) -> "Context": + def with_global_filtering(self) -> Self: self.use_global_filtering = True return self - def without_global_filtering(self) -> "Context": + def without_global_filtering(self) -> Self: self.use_global_filtering = False return self diff --git a/src/backend/services/conversation.py b/src/backend/services/conversation.py index 8c47412b5f..5fc038c236 100644 --- a/src/backend/services/conversation.py +++ b/src/backend/services/conversation.py @@ -1,5 +1,5 @@ import uuid -from typing import List, Optional +from typing import Optional from fastapi import HTTPException @@ -141,7 +141,7 @@ def get_messages_with_files( return messages_with_file -def get_documents_to_rerank(conversations: List[Conversation]) -> List[str]: +def get_documents_to_rerank(conversations: list[Conversation]) -> list[str]: """Get documents (strings) to rerank from a list of conversations Args: @@ -165,22 +165,22 @@ def get_documents_to_rerank(conversations: List[Conversation]) -> List[str]: async def filter_conversations( query: str, - conversations: List[Conversation], - rerank_documents: List[str], + conversations: list[Conversation], + rerank_documents: list[str], model_deployment, ctx: Context, -) -> List[Conversation]: +) -> list[Conversation]: """Filter conversations based on the rerank score Args: query (str): The query to filter conversations - conversations (List[Conversation]): List of conversations - rerank_documents (List[str]): List of documents to rerank + conversations (list[Conversation]): List of conversations + rerank_documents (list[str]): List of documents to rerank model_deployment: Model deployment object ctx (Context): Context object Returns: - List[Conversation]: List of filtered conversations + list[Conversation]: List of filtered conversations """ # if rerank is not enabled, filter out conversations that don't contain the query if not model_deployment.rerank_enabled: diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 5932ab18e4..3c7aa6ec38 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -24,7 +24,7 @@ @pytest.fixture -def client(): +def client() -> Generator[TestClient, None, None]: yield TestClient(app) diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index 7104e5c603..f98eeaa1f6 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -24,8 +24,8 @@ def list_models(cls) -> List[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( @@ -84,6 +84,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index 798d235070..528faa58eb 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -21,8 +21,8 @@ def rerank_enabled(self) -> bool: def list_models(cls) -> List[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py index b68e312518..91c3f6c0d6 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py @@ -21,8 +21,8 @@ def rerank_enabled(self) -> bool: def list_models(cls) -> List[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat_stream( diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index c64f7f5f94..1045ec578c 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -21,8 +21,8 @@ def rerank_enabled(self) -> bool: def list_models(cls) -> List[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True def invoke_chat( From 4bc40e1852d34e104d9c6512ffa5f25290cfa8f9 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Tue, 14 Jan 2025 12:43:08 -0800 Subject: [PATCH 5/7] feat(backend): remove dependency on Cohere API key --- .../integration/routers/test_conversation.py | 32 +++++---------- .../mock_deployments/mock_cohere_platform.py | 41 ++++++++++++++----- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 80a25d245b..c34a2afa37 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -1,5 +1,3 @@ -import os - import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -10,11 +8,14 @@ from backend.schemas.user import User from backend.tests.unit.factories import get_factory +_IS_GOOGLE_CLOUD_API_KEY_SET = bool(Settings().get('google_cloud.api_key')) + def test_search_conversations( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: conversation = get_factory("Conversation", session).create( title="test title", user_id=user.id @@ -24,8 +25,6 @@ def test_search_conversations( headers={"User-Id": user.id}, params={"query": "test"}, ) - print("here") - print(response.json) results = response.json() assert response.status_code == 200 @@ -33,14 +32,11 @@ def test_search_conversations( assert results[0]["id"] == conversation.id -@pytest.mark.skipif( - os.environ.get("COHERE_API_KEY") is None, - reason="Cohere API key not set, skipping test", -) def test_search_conversations_with_reranking( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: _ = get_factory("Conversation", session).create( title="Hello, how are you?", text_messages=[], user_id=user.id @@ -83,19 +79,16 @@ def test_search_conversations_no_conversations( assert response.json() == [] -# MISC - - -@pytest.mark.skip(reason="Restore this test when we get access to run models on Huggingface") def test_generate_title( session_client: TestClient, session: Session, user: User, + mock_available_model_deployments, ) -> None: - conversation = get_factory("Conversation", session).create(user_id=user.id) + conversation_initial = get_factory("Conversation", session).create(user_id=user.id) response = session_client.post( - f"/v1/conversations/{conversation.id}/generate-title", - headers={"User-Id": conversation.user_id}, + f"/v1/conversations/{conversation_initial.id}/generate-title", + headers={"User-Id": conversation_initial.user_id}, ) response_json = response.json() @@ -105,7 +98,7 @@ def test_generate_title( # Check if the conversation was updated conversation = ( session.query(Conversation) - .filter_by(id=conversation.id, user_id=conversation.user_id) + .filter_by(id=conversation_initial.id, user_id=conversation_initial.user_id) .first() ) assert conversation is not None @@ -165,10 +158,7 @@ def test_generate_title_error_invalid_model( # SYNTHESIZE -is_google_cloud_api_key_set = bool(Settings().get('google_cloud.api_key')) - - -@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test") def test_synthesize_english_message( session_client: TestClient, session: Session, @@ -186,7 +176,7 @@ def test_synthesize_english_message( assert response.headers["Content-Type"] == "audio/mp3" -@pytest.mark.skipif(not is_google_cloud_api_key_set, reason="Google Cloud API key not set, skipping test") +@pytest.mark.skipif(not _IS_GOOGLE_CLOUD_API_KEY_SET, reason="Google Cloud API key not set, skipping test") def test_synthesize_non_english_message( session_client: TestClient, session: Session, diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index 3fe818d497..eb0101163e 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Generator, List +import random +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -6,6 +7,7 @@ from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD class MockCohereDeployment(BaseDeployment): @@ -13,20 +15,23 @@ class MockCohereDeployment(BaseDeployment): DEFAULT_MODELS = ["command", "command-r"] + def __init__(self, **kwargs: Any): + pass + @property def rerank_enabled(self) -> bool: return True @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return True - def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + async def invoke_chat( + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -51,7 +56,7 @@ def invoke_chat( } yield event - def invoke_chat_stream( + async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: events = [ @@ -79,8 +84,22 @@ def invoke_chat_stream( for event in events: yield event - def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + async def invoke_rerank( + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: - # TODO: Add - pass + results = [] + for idx, doc in enumerate(documents): + if query in doc: + results.append({ + "index": idx, + "relevance_score": random.uniform(SEARCH_RELEVANCE_THRESHOLD, 1), + }) + event = { + "id": "eae2b023-bf49-4139-bf15-9825022762f4", + "results": results, + "meta": { + "api_version":{"version":"1"}, + "billed_units":{"search_units":1} + } + } + return event From d341c24088028841ee0947d7bdca7f1a3bf2872a Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Wed, 15 Jan 2025 06:12:59 -0800 Subject: [PATCH 6/7] chore(backend): fix issues from rebase --- src/backend/model_deployments/azure.py | 20 +++++------ src/backend/model_deployments/base.py | 23 ++++++------- src/backend/model_deployments/bedrock.py | 20 +++++------ .../model_deployments/cohere_platform.py | 18 +++++----- src/backend/model_deployments/sagemaker.py | 18 +++++----- .../model_deployments/single_container.py | 20 +++++------ .../tests/integration/routers/test_agent.py | 33 ++++++++++++------- .../mock_deployments/mock_azure.py | 18 +++++----- .../mock_deployments/mock_bedrock.py | 20 +++++------ .../mock_deployments/mock_cohere_platform.py | 15 +++++---- .../mock_deployments/mock_sagemaker.py | 20 +++++------ .../mock_deployments/mock_single_container.py | 20 +++++------ 12 files changed, 129 insertions(+), 116 deletions(-) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index 2463f610d5..dce6660516 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -43,16 +43,16 @@ def __init__(self, **kwargs: Any): base_url=self.chat_endpoint_url, api_key=self.api_key ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -62,14 +62,14 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( AzureDeployment.default_api_key is not None and AzureDeployment.default_chat_endpoint_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), ) @@ -86,6 +86,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index f578d15f80..f6bec64471 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator +from typing import Any from backend.config.settings import Settings from backend.schemas.cohere_chat import CohereChatRequest @@ -25,31 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any): def id(cls) -> str: return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower() - @classmethod + @staticmethod @abstractmethod - def name(cls) -> str: ... + def name() -> str: ... - @classmethod + @staticmethod @abstractmethod - def env_vars(cls) -> List[str]: ... + def env_vars() -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def rerank_enabled(cls) -> bool: ... + def rerank_enabled() -> bool: ... @classmethod @abstractmethod def list_models(cls) -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def is_available(cls) -> bool: ... + def is_available() -> bool: ... @classmethod def is_community(cls) -> bool: return False - def config(cls) -> Dict[str, Any]: + @classmethod + def config(cls) -> dict[str, Any]: config = Settings().get(f"deployments.{cls.id()}") config_dict = {} if not config else dict(config) for key, value in config_dict.items(): @@ -78,7 +79,7 @@ async def invoke_chat( @abstractmethod async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any - ) -> AsyncGenerator[Any, Any]: ... + ) -> Any: ... @abstractmethod async def invoke_rerank( diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index aba4e5a091..9403deed47 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -42,12 +42,12 @@ def __init__(self, **kwargs: Any): ), ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ BEDROCK_ACCESS_KEY_ENV_VAR, BEDROCK_SECRET_KEY_ENV_VAR, @@ -55,8 +55,8 @@ def env_vars(cls) -> List[str]: BEDROCK_REGION_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -66,8 +66,8 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( BedrockDeployment.access_key is not None and BedrockDeployment.secret_access_key is not None @@ -75,7 +75,7 @@ def is_available(cls) -> bool: and BedrockDeployment.region_name is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: # bedrock accepts a subset of the chat request fields bedrock_chat_req = chat_request.model_dump( exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True @@ -101,6 +101,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index dac13d1347..a718d0b68e 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -29,16 +29,16 @@ def __init__(self, **kwargs: Any): ) self.client = cohere.Client(api_key, client_name=self.client_name) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [COHERE_API_KEY_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod @@ -64,12 +64,12 @@ def list_models(cls) -> list[str]: models = response.json()["models"] return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return CohereDeployment.api_key is not None async def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index 317ec70a1a..6a686a378c 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -65,12 +65,12 @@ def __init__(self, **kwargs: Any): "ContentType": "application/json", } - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ SAGE_MAKER_ACCESS_KEY_ENV_VAR, SAGE_MAKER_SECRET_KEY_ENV_VAR, @@ -79,8 +79,8 @@ def env_vars(cls) -> List[str]: SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -90,8 +90,8 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SageMakerDeployment.region_name is not None and SageMakerDeployment.aws_access_key_id is not None @@ -121,7 +121,7 @@ async def invoke_chat_stream( yield stream_event async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 78c7bf0a0a..4ddcb2e174 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -33,16 +33,16 @@ def __init__(self, **kwargs: Any): base_url=self.url, client_name=self.client_name, api_key="none" ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [SC_URL_ENV_VAR, SC_MODEL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return SingleContainerDeployment.default_model.startswith("rerank") @classmethod @@ -52,14 +52,14 @@ def list_models(cls) -> list[str]: return [SingleContainerDeployment.default_model] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SingleContainerDeployment.default_model is not None and SingleContainerDeployment.default_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump( exclude={"stream", "file_ids", "model", "agent_id"} @@ -80,7 +80,7 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index cb8f863386..af276bbfda 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -21,7 +21,12 @@ ) -def test_create_agent(session_client: TestClient, session: Session, user: User, mock_cohere_list_models) -> None: +def test_create_agent( + session_client: TestClient, + session: Session, + user: User, + mock_cohere_list_models, +) -> None: request_json = { "name": "test agent", "version": 1, @@ -297,7 +302,7 @@ def test_create_agent_deployment_not_in_db( def test_create_agent_invalid_tool( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: request_json = { "name": "test agent", @@ -314,7 +319,7 @@ def test_create_agent_invalid_tool( def test_create_existing_agent( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: agent = get_factory("Agent", session).create(name="test agent") request_json = { @@ -336,7 +341,9 @@ def test_list_agents_empty_returns_default_agent(session_client: TestClient, ses assert len(response_agents) == 1 -def test_list_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_agents( + session_client: TestClient, session: Session, user: User, +) -> None: num_agents = 3 for _ in range(num_agents): _ = get_factory("Agent", session).create(user=user) @@ -350,7 +357,7 @@ def test_list_agents(session_client: TestClient, session: Session, user) -> None def test_list_organization_agents( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -379,7 +386,7 @@ def test_list_organization_agents( def test_list_organization_agents_query_param( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -408,7 +415,7 @@ def test_list_organization_agents_query_param( def test_list_organization_agents_nonexistent_organization( session_client: TestClient, session: Session, - user, + user: User, ) -> None: response = session_client.get( "/v1/agents", headers={"User-Id": user.id, "Organization-Id": "123"} @@ -418,7 +425,7 @@ def test_list_organization_agents_nonexistent_organization( def test_list_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -438,7 +445,9 @@ def test_list_private_agents( assert len(response_agents) == 3 -def test_list_public_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_public_agents( + session_client: TestClient, session: Session, user: User, +) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -451,6 +460,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) ) assert response.status_code == 200 + breakpoint() response_agents = filter_default_agent(response.json()) # Only the agents created by user should be returned @@ -458,7 +468,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) def list_public_and_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -479,7 +489,7 @@ def list_public_and_private_agents( def test_list_agents_with_pagination( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(5): _ = get_factory("Agent", session).create(user=user) @@ -495,6 +505,7 @@ def test_list_agents_with_pagination( "/v1/agents?limit=2&offset=4", headers={"User-Id": user.id} ) assert response.status_code == 200 + breakpoint() response_agents = filter_default_agent(response.json()) assert len(response_agents) == 1 diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index 2d4129cf9b..12279f1c23 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockAzureDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["AZURE_API_KEY", "AZURE_CHAT_ENDPOINT_URL"] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] @@ -42,7 +42,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: event = { "text": "Hi! Hello there! How's it going?", diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index f2fd3eb9a2..53fa171faa 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockBedrockDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,6 +93,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index 02120cbe97..dd2e95b03a 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -6,6 +6,7 @@ from backend.chat.enums import StreamEvent from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( MockDeployment, ) @@ -19,16 +20,16 @@ class MockCohereDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["COHERE_API_KEY"] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod @@ -40,7 +41,7 @@ def is_available() -> bool: return True @classmethod - def config(cls) -> Dict[str, Any]: + def config(cls) -> dict[str, Any]: return {"COHERE_API_KEY": "fake-api-key"} def invoke_chat( diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py index 0b7353ff53..2f0f577562 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockSageMakerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: pass @@ -72,6 +72,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index 7451af8541..2ed0464ff4 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockSingleContainerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,7 +93,7 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: # TODO: Add pass From 7fc86a9ea3c77ca0f03e17eba99248dbb1956683 Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Wed, 15 Jan 2025 06:16:45 -0800 Subject: [PATCH 7/7] chore(backend): remove breakpoints --- src/backend/tests/integration/routers/test_agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index af276bbfda..30c2e6b8ce 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -460,7 +460,6 @@ def test_list_public_agents( ) assert response.status_code == 200 - breakpoint() response_agents = filter_default_agent(response.json()) # Only the agents created by user should be returned @@ -505,7 +504,6 @@ def test_list_agents_with_pagination( "/v1/agents?limit=2&offset=4", headers={"User-Id": user.id} ) assert response.status_code == 200 - breakpoint() response_agents = filter_default_agent(response.json()) assert len(response_agents) == 1