Skip to content

Commit

Permalink
Merge branch 'main' into feat/mock_integration_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ezawadski authored Jan 14, 2025
2 parents 4bc40e1 + 802c232 commit 74ed897
Show file tree
Hide file tree
Showing 101 changed files with 6,261 additions and 3,723 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from alembic import op

from backend.database_models.seeders.deplyments_models_seed import (
from backend.database_models.seeders.deployments_models_seed import (
delete_default_models,
deployments_models_seed,
)
Expand Down
29 changes: 9 additions & 20 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from typing import Any

from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
get_default_deployment,
)
from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Expand All @@ -16,22 +15,12 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.
Raises:
ValueError: If the deployment is not supported.
"""
kwargs["ctx"] = ctx
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None:
return deployment.deployment_class(**kwargs, **deployment.kwargs)

# Fallback to first available deployment
default = get_default_deployment(**kwargs)
if default is not None:
return default
try:
session = next(get_session())
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment(**kwargs)

raise ValueError(
f"Deployment {name} is not supported, and no available deployments were found."
)
return deployment
4 changes: 2 additions & 2 deletions src/backend/config/default_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import Tool
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.schemas.agent import AgentPublic

DEFAULT_AGENT_ID = "default"
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
DEFAULT_DEPLOYMENT = CohereDeployment.name()
DEFAULT_MODEL = "command-r-plus"

def get_default_agent() -> AgentPublic:
Expand Down
135 changes: 15 additions & 120 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,35 @@
from enum import StrEnum

from backend.config.settings import Settings
from backend.model_deployments import (
AzureDeployment,
BedrockDeployment,
CohereDeployment,
SageMakerDeployment,
SingleContainerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.model_deployments.single_container import SC_ENV_VARS
from backend.schemas.deployment import Deployment
from backend.services.logger.utils import LoggerFactory

logger = LoggerFactory().get_logger()


class ModelDeploymentName(StrEnum):
CoherePlatform = "Cohere Platform"
SageMaker = "SageMaker"
Azure = "Azure"
Bedrock = "Bedrock"
SingleContainer = "Single Container"


use_community_features = Settings().get('feature_flags.use_community_features')
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SingleContainer: Deployment(
id="single_container",
name=ModelDeploymentName.SingleContainer,
deployment_class=SingleContainerDeployment,
models=SingleContainerDeployment.list_models(),
is_available=SingleContainerDeployment.is_available(),
env_vars=SC_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=BEDROCK_ENV_VARS,
),
}

def get_available_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())

def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
if use_community_features:
if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
return model_deployments
except ImportError:
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured"
event="[Deployments] No available community deployments have been configured", ex=e
)

deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS


def get_default_deployment(**kwargs) -> BaseDeployment:
# Fallback to the first available deployment
fallback = None
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().get('deployments.default_deployment')
if default:
return next(
(
v.deployment_class(**kwargs)
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
if v.id == default
),
fallback,
)
else:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]

return installed_deployments

AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
1 change: 0 additions & 1 deletion src/backend/config/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class RouterName(StrEnum):
TOOL = "tool"
USER = "user"
AGENT = "agent"
DEFAULT_AGENT = "default_agent"
SNAPSHOT = "snapshot"
MODEL = "model"
SCIM = "scim"
Expand Down
26 changes: 11 additions & 15 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import os

from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
from backend.schemas.deployment import (
DeploymentCreate,
DeploymentDefinition,
DeploymentUpdate,
)
from backend.services.transaction import validate_transaction


@validate_transaction
Expand All @@ -19,7 +18,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:
Args:
db (Session): Database session.
deployment (DeploymentSchema): Deployment data to be created.
deployment (DeploymentCreate): Deployment data to be created.
Returns:
Deployment: Created deployment.
Expand Down Expand Up @@ -132,27 +131,24 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
"""
Create a new deployment by config.
Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
deployment_config (DeploymentDefinition): Deployment config.
Returns:
Deployment: Created deployment.
"""
deployment = Deployment(
name=deployment_config.name,
description="",
default_deployment_config= {
env_var: os.environ.get(env_var, "")
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
default_deployment_config=deployment_config.config,
deployment_class_name=deployment_config.class_name,
is_community=deployment_config.is_community,
)
db.add(deployment)
db.commit()
Expand Down
22 changes: 12 additions & 10 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from sqlalchemy.orm import Session

from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.logger.utils import LoggerFactory
from backend.services.transaction import validate_transaction

logger = LoggerFactory().get_logger()


@validate_transaction
def create_model(db: Session, model: ModelCreate) -> Model:
Expand Down Expand Up @@ -127,29 +129,29 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
def create_model_by_config(db: Session, deployment_config: DeploymentDefinition, deployment_id: str, model: str | None) -> Model:
"""
Create a new model by config if present
Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
model (str): Model data.
deployment_config (DeploymentDefinition): A deployment definition for any kind of deployment.
deployment_id (str): Deployment ID for a deployment from the DB.
model (str): Optional model name that should have its data returned from this call.
Returns:
Model: Created model.
"""
deployment_config_models = deployment_config.models
deployment_db_models = get_models_by_deployment_id(db, deployment.id)
logger.debug(event="create_model_by_config", deployment_models=deployment_config.models, deployment_id=deployment_id, model=model)
deployment_db_models = get_models_by_deployment_id(db, deployment_id)
model_to_return = None
for deployment_config_model in deployment_config_models:
for deployment_config_model in deployment_config.models:
model_in_db = any(record.name == deployment_config_model for record in deployment_db_models)
if not model_in_db:
new_model = Model(
name=deployment_config_model,
cohere_name=deployment_config_model,
deployment_id=deployment.id,
deployment_id=deployment_id,
)
db.add(new_model)
db.commit()
Expand Down
25 changes: 25 additions & 0 deletions src/backend/database_models/seeders/deployments_models_seed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from sqlalchemy.orm import Session

from backend.database_models import Deployment, Model, Organization


def deployments_models_seed(op):
    """
    Intentionally do nothing; no seed data is inserted by this function.

    Args:
        op: Operations object supplied by the caller. It is accepted but unused.

    Returns:
        None
    """
    # Previously we would seed the default deployments and models here. We've changed this
    # behaviour during a refactor of the deployments module so that deployments and models
    # are inserted when they're first used. This solves an issue where seed data would
    # sometimes be inserted with invalid config data.
    return None


def delete_default_models(op):
    """
    Delete all models and deployments, plus the organization whose id is "default".

    Args:
        op: Object exposing ``get_bind()``; the returned bind backs the session
            used for the deletes.

    Returns:
        None
    """
    session = Session(op.get_bind())
    try:
        # Models are removed before deployments: models carry a deployment_id, so
        # deleting the referencing rows first does not depend on any cascade rule
        # (NOTE(review): presumably a foreign key -- confirm against the Model table).
        session.query(Model).delete()
        session.query(Deployment).delete()
        session.query(Organization).filter_by(id="default").delete()
        session.commit()
    finally:
        # Always release the session, even when a delete or the commit raises.
        session.close()
Loading

0 comments on commit 74ed897

Please sign in to comment.