From 3b1c27f8f73a138befc7302eb973d1b48bbf26fa Mon Sep 17 00:00:00 2001 From: Eric Zawadski <58191072+ezawadski@users.noreply.github.com> Date: Fri, 17 Jan 2025 06:08:56 -0800 Subject: [PATCH] Backend: Fix issue 909 (#910) * fix(backend): revert AVAILABLE_MODEL_DEPLOYMENTS to dictionary from list --- src/backend/config/deployments.py | 17 +++++++++-------- src/backend/crud/agent.py | 2 +- src/backend/scripts/cli/prompts.py | 4 ++-- src/backend/services/deployment.py | 12 +++++++----- src/backend/services/request_validators.py | 4 +++- src/backend/tests/integration/conftest.py | 2 +- .../integration/routers/test_deployment.py | 2 +- src/backend/tests/unit/conftest.py | 2 +- .../tests/unit/services/test_deployment.py | 7 +++++-- 9 files changed, 30 insertions(+), 22 deletions(-) diff --git a/src/backend/config/deployments.py b/src/backend/config/deployments.py index 55b4b1c74b..9a9e03ff34 100644 --- a/src/backend/config/deployments.py +++ b/src/backend/config/deployments.py @@ -8,15 +8,16 @@ ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() } -def get_available_deployments() -> list[type[BaseDeployment]]: - installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values()) +def get_available_deployments() -> dict[str, type[BaseDeployment]]: + installed_deployments = ALL_MODEL_DEPLOYMENTS.copy() if Settings().get("feature_flags.use_community_features"): try: from community.config.deployments import ( AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, ) - installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values()) + + installed_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP) except ImportError as e: logger.warning( event="[Deployments] No available community deployments have been configured", ex=e @@ -24,11 +25,11 @@ def get_available_deployments() -> list[type[BaseDeployment]]: enabled_deployment_ids = Settings().get("deployments.enabled_deployments") if enabled_deployment_ids: - return [ - deployment - for deployment in installed_deployments - if deployment.id() in enabled_deployment_ids - ] + return { + key: value + for key, value in installed_deployments.items() + if value.id() in enabled_deployment_ids + } return installed_deployments diff --git a/src/backend/crud/agent.py b/src/backend/crud/agent.py index 4eaa3c58ea..8dd34c5727 100644 --- a/src/backend/crud/agent.py +++ b/src/backend/crud/agent.py @@ -140,7 +140,7 @@ def update_agent( if agent.is_private and agent.user_id != user_id: return None - new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True) + new_agent_cleaned = new_agent.model_dump(exclude_unset=True, exclude_none=True) for attr, value in new_agent_cleaned.items(): setattr(agent, attr, value) diff --git a/src/backend/scripts/cli/prompts.py b/src/backend/scripts/cli/prompts.py index 713e7fdcfe..871d2aa078 100644 --- a/src/backend/scripts/cli/prompts.py +++ b/src/backend/scripts/cli/prompts.py @@ -65,7 +65,7 @@ def core_env_var_prompt(secrets): def deployment_prompt(secrets, configs): - for secret in configs.env_vars: + for secret in configs.env_vars(): value = secrets.get(secret) if not value: @@ -149,7 +149,7 @@ def select_deployments_prompt(deployments, _): deployments = inquirer.checkbox( "Select the model deployments you want to set up", - choices=[deployment.value for deployment in deployments.keys()], + choices=[deployment for deployment in deployments.keys()], default=["Cohere Platform"], validate=lambda _, x: len(x) > 0, ) diff --git a/src/backend/services/deployment.py b/src/backend/services/deployment.py index 67ed2923d1..ac4c597af4 100644 --- a/src/backend/services/deployment.py +++ b/src/backend/services/deployment.py @@ -38,7 +38,7 @@ def create_db_deployment(session: DBSessionDep, deployment: DeploymentDefinition def get_default_deployment(**kwargs) -> BaseDeployment: try: - fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.is_available) + fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.is_available()) except StopIteration: raise NoAvailableDeploymentsError() @@ -47,7 +47,7 @@ def get_default_deployment(**kwargs) -> BaseDeployment: return next( ( d - for d in AVAILABLE_MODEL_DEPLOYMENTS + for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.id() == default_deployment ), fallback, @@ -63,7 +63,9 @@ def get_deployment_by_name(session: DBSessionDep, deployment_name: str, **kwargs definition = get_deployment_definition_by_name(session, deployment_name) try: - return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == definition.class_name)(db_id=definition.id, **definition.config, **kwargs) + return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == definition.class_name)( + db_id=definition.id, **definition.config, **kwargs + ) except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_name) @@ -73,7 +75,7 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl return DeploymentDefinition.from_db_deployment(db_deployment) try: - deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.id() == deployment_id) + deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.id() == deployment_id) except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_id) @@ -101,7 +103,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti installed_deployments = [ deployment.to_deployment_definition() - for deployment in AVAILABLE_MODEL_DEPLOYMENTS + for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values() if deployment.name() not in db_deployments ] diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index ab35ab0dc9..5bbdd65248 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -42,7 +42,9 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep detail=f"Deployment {deployment} not found or is not available in the Database.", ) - deployment_config = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == found.class_name).to_deployment_definition() + deployment_config = next( + d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == found.class_name + ).to_deployment_definition() deployment_model = next( ( model_db diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 0b005901a7..d2207d04fc 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -189,7 +189,7 @@ def mock_available_model_deployments(request): MockBedrockDeployment.name(): MockBedrockDeployment, } - with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock: + with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock: yield mock @pytest.fixture diff --git a/src/backend/tests/integration/routers/test_deployment.py b/src/backend/tests/integration/routers/test_deployment.py index 4ffd5f7af6..6df4d29a1a 100644 --- a/src/backend/tests/integration/routers/test_deployment.py +++ b/src/backend/tests/integration/routers/test_deployment.py @@ -93,7 +93,7 @@ def test_list_deployments_no_available_models_404( session_client: TestClient, session: Session ) -> None: session.query(Deployment).delete() - with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []): + with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}): response = session_client.get("/v1/deployments") assert response.status_code == 404 assert response.json() == { diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 156c175b95..9b180aaef5 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -207,5 +207,5 @@ def mock_available_model_deployments(request): MockSingleContainerDeployment.name(): MockSingleContainerDeployment, } - with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock: + with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock: yield mock diff --git a/src/backend/tests/unit/services/test_deployment.py b/src/backend/tests/unit/services/test_deployment.py index d3b76df229..44df0d0512 100644 --- a/src/backend/tests/unit/services/test_deployment.py +++ b/src/backend/tests/unit/services/test_deployment.py @@ -41,7 +41,7 @@ def test_all_tools_have_id() -> None: assert tool.value.ID is not None def test_get_default_deployment_none_available() -> None: - with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []): + with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}): with pytest.raises(NoAvailableDeploymentsError): deployment_service.get_default_deployment() @@ -106,7 +106,10 @@ def test_get_deployment_definitions_with_db_deployments(session, mock_available_ id="db-mock-cohere-platform-id", ) with patch("backend.crud.deployment.get_deployments", return_value=[mock_cohere_deployment]): - with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", [MockCohereDeployment, MockAzureDeployment]): + with patch( + "backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", + { MockCohereDeployment.name(): MockCohereDeployment, MockAzureDeployment.name(): MockAzureDeployment } + ): definitions = deployment_service.get_deployment_definitions(session) assert len(definitions) == 2