Skip to content

Commit

Permalink
Backend: Fix issue 909 (#910)
Browse files Browse the repository at this point in the history
* fix(backend): revert AVAILABLE_MODEL_DEPLOYMENTS to dictionary from list
  • Loading branch information
ezawadski authored Jan 17, 2025
1 parent 9ce3c5c commit 3b1c27f
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 22 deletions.
17 changes: 9 additions & 8 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,28 @@
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }


def get_available_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())
def get_available_deployments() -> dict[str, type[BaseDeployment]]:
installed_deployments = ALL_MODEL_DEPLOYMENTS.copy()

if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())

installed_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured", ex=e
)

enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]
return {
key: value
for key, value in installed_deployments.items()
if value.id() in enabled_deployment_ids
}

return installed_deployments

Expand Down
2 changes: 1 addition & 1 deletion src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def update_agent(
if agent.is_private and agent.user_id != user_id:
return None

new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True)
new_agent_cleaned = new_agent.model_dump(exclude_unset=True, exclude_none=True)

for attr, value in new_agent_cleaned.items():
setattr(agent, attr, value)
Expand Down
4 changes: 2 additions & 2 deletions src/backend/scripts/cli/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def core_env_var_prompt(secrets):


def deployment_prompt(secrets, configs):
for secret in configs.env_vars:
for secret in configs.env_vars():
value = secrets.get(secret)

if not value:
Expand Down Expand Up @@ -149,7 +149,7 @@ def select_deployments_prompt(deployments, _):

deployments = inquirer.checkbox(
"Select the model deployments you want to set up",
choices=[deployment.value for deployment in deployments.keys()],
choices=[deployment for deployment in deployments.keys()],
default=["Cohere Platform"],
validate=lambda _, x: len(x) > 0,
)
Expand Down
12 changes: 7 additions & 5 deletions src/backend/services/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_db_deployment(session: DBSessionDep, deployment: DeploymentDefinition

def get_default_deployment(**kwargs) -> BaseDeployment:
try:
fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.is_available)
fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.is_available())
except StopIteration:
raise NoAvailableDeploymentsError()

Expand All @@ -47,7 +47,7 @@ def get_default_deployment(**kwargs) -> BaseDeployment:
return next(
(
d
for d in AVAILABLE_MODEL_DEPLOYMENTS
for d in AVAILABLE_MODEL_DEPLOYMENTS.values()
if d.id() == default_deployment
),
fallback,
Expand All @@ -63,7 +63,9 @@ def get_deployment_by_name(session: DBSessionDep, deployment_name: str, **kwargs
definition = get_deployment_definition_by_name(session, deployment_name)

try:
return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == definition.class_name)(db_id=definition.id, **definition.config, **kwargs)
return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == definition.class_name)(
db_id=definition.id, **definition.config, **kwargs
)
except StopIteration:
raise DeploymentNotFoundError(deployment_id=deployment_name)

Expand All @@ -73,7 +75,7 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl
return DeploymentDefinition.from_db_deployment(db_deployment)

try:
deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.id() == deployment_id)
deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.id() == deployment_id)
except StopIteration:
raise DeploymentNotFoundError(deployment_id=deployment_id)

Expand Down Expand Up @@ -101,7 +103,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti

installed_deployments = [
deployment.to_deployment_definition()
for deployment in AVAILABLE_MODEL_DEPLOYMENTS
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
if deployment.name() not in db_deployments
]

Expand Down
4 changes: 3 additions & 1 deletion src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
detail=f"Deployment {deployment} not found or is not available in the Database.",
)

deployment_config = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS if d.__name__ == found.class_name).to_deployment_definition()
deployment_config = next(
d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == found.class_name
).to_deployment_definition()
deployment_model = next(
(
model_db
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def mock_available_model_deployments(request):
MockBedrockDeployment.name(): MockBedrockDeployment,
}

with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock:
yield mock

@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/integration/routers/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_list_deployments_no_available_models_404(
session_client: TestClient, session: Session
) -> None:
session.query(Deployment).delete()
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}):
response = session_client.get("/v1/deployments")
assert response.status_code == 404
assert response.json() == {
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,5 @@ def mock_available_model_deployments(request):
MockSingleContainerDeployment.name(): MockSingleContainerDeployment,
}

with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", MOCKED_DEPLOYMENTS) as mock:
yield mock
7 changes: 5 additions & 2 deletions src/backend/tests/unit/services/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_all_tools_have_id() -> None:
assert tool.value.ID is not None

def test_get_default_deployment_none_available() -> None:
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", []):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}):
with pytest.raises(NoAvailableDeploymentsError):
deployment_service.get_default_deployment()

Expand Down Expand Up @@ -106,7 +106,10 @@ def test_get_deployment_definitions_with_db_deployments(session, mock_available_
id="db-mock-cohere-platform-id",
)
with patch("backend.crud.deployment.get_deployments", return_value=[mock_cohere_deployment]):
with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", [MockCohereDeployment, MockAzureDeployment]):
with patch(
"backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS",
{ MockCohereDeployment.name(): MockCohereDeployment, MockAzureDeployment.name(): MockAzureDeployment }
):
definitions = deployment_service.get_deployment_definitions(session)

assert len(definitions) == 2
Expand Down

0 comments on commit 3b1c27f

Please sign in to comment.