diff --git a/docs/chains/doc_gen/generate_reference.py b/docs/chains/doc_gen/generate_reference.py index e8b4323d2..9c92c5112 100644 --- a/docs/chains/doc_gen/generate_reference.py +++ b/docs/chains/doc_gen/generate_reference.py @@ -72,7 +72,7 @@ "truss_chains.remote.ChainService", "truss_chains.make_abs_path_here", "truss_chains.run_local", - "truss_chains.ServiceDescriptor", + "truss_chains.DeployedServiceDescriptor", "truss_chains.StubBase", "truss_chains.RemoteErrorDetail", # "truss_chains.ChainsRuntimeError", diff --git a/pyproject.toml b/pyproject.toml index 6e8a6d56a..db7ca82fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.50" +version = "0.9.51" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss-chains/tests/test_utils.py b/truss-chains/tests/test_utils.py index a438e044c..03fde8f60 100644 --- a/truss-chains/tests/test_utils.py +++ b/truss-chains/tests/test_utils.py @@ -1,10 +1,9 @@ -import copy import json import pytest from truss_chains import definitions -from truss_chains.utils import override_chainlet_to_service_metadata +from truss_chains.utils import populate_chainlet_service_predict_urls DYNAMIC_CHAINLET_CONFIG_VALUE = { "HelloWorld": { @@ -22,23 +21,22 @@ def dynamic_config_mount_dir(tmp_path, monkeypatch: pytest.MonkeyPatch): yield -def test_override_chainlet_to_service_metadata(tmp_path, dynamic_config_mount_dir): +def test_populate_chainlet_service_predict_urls(tmp_path, dynamic_config_mount_dir): with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f: f.write(json.dumps(DYNAMIC_CHAINLET_CONFIG_VALUE)) chainlet_to_service = { "HelloWorld": definitions.ServiceDescriptor( name="HelloWorld", - predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict", options=definitions.RPCOptions(), ) } - original_chainlet_to_service = copy.deepcopy(chainlet_to_service) - override_chainlet_to_service_metadata(chainlet_to_service) + new_chainlet_to_service = populate_chainlet_service_predict_urls( + chainlet_to_service + ) - assert chainlet_to_service != original_chainlet_to_service assert ( - chainlet_to_service["HelloWorld"].predict_url + new_chainlet_to_service["HelloWorld"].predict_url == DYNAMIC_CHAINLET_CONFIG_VALUE["HelloWorld"]["predict_url"] ) @@ -47,7 +45,7 @@ def test_override_chainlet_to_service_metadata(tmp_path, dynamic_config_mount_di "config", [DYNAMIC_CHAINLET_CONFIG_VALUE, {}, ""], ) -def test_no_override_chainlet_to_service_metadata( +def test_no_populate_chainlet_service_predict_urls( config, tmp_path, dynamic_config_mount_dir ): with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f: @@ -55,26 +53,12 @@ def test_no_override_chainlet_to_service_metadata( chainlet_to_service = { "RandInt": definitions.ServiceDescriptor( - name="HelloWorld", - predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict", - options=definitions.RPCOptions(), - ) - } - original_chainlet_to_service = copy.deepcopy(chainlet_to_service) - override_chainlet_to_service_metadata(chainlet_to_service) - - assert chainlet_to_service == original_chainlet_to_service - - -def test_no_config_override_chainlet_to_service_metadata(dynamic_config_mount_dir): - chainlet_to_service = { - "HelloWorld": definitions.ServiceDescriptor( - name="HelloWorld", - predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict", + name="RandInt", options=definitions.RPCOptions(), ) } - original_chainlet_to_service = copy.deepcopy(chainlet_to_service) - override_chainlet_to_service_metadata(chainlet_to_service) - assert chainlet_to_service == original_chainlet_to_service + with pytest.raises( + definitions.MissingDependencyError, match="Chainlet 'RandInt' not found" + ): + populate_chainlet_service_predict_urls(chainlet_to_service) diff --git a/truss-chains/truss_chains/__init__.py b/truss-chains/truss_chains/__init__.py index a9dd61551..ea7b07e71 100644 --- a/truss-chains/truss_chains/__init__.py +++ b/truss-chains/truss_chains/__init__.py @@ -25,12 +25,12 @@ ChainletOptions, Compute, CustomImage, + DeployedServiceDescriptor, DeploymentContext, DockerImage, RemoteConfig, RemoteErrorDetail, RPCOptions, - ServiceDescriptor, ) from truss_chains.public_api import ( ChainletBase, @@ -55,7 +55,7 @@ "RPCOptions", "RemoteConfig", "RemoteErrorDetail", - "ServiceDescriptor", + "DeployedServiceDescriptor", "StubBase", "depends", "depends_context", diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 8107a2573..0aed2f181 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -623,14 +623,12 @@ def gen_truss_chainlet( chain_name: str, chainlet_descriptor: definitions.ChainletAPIDescriptor, model_name: str, - chainlet_display_name_to_url: Mapping[str, str], ) -> pathlib.Path: # Filter needed services and customize options. dep_services = {} for dep in chainlet_descriptor.dependencies.values(): dep_services[dep.name] = definitions.ServiceDescriptor( name=dep.name, - predict_url=chainlet_display_name_to_url[dep.display_name], options=dep.options, ) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 04d08ed86..d5ba85da5 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -75,7 +75,6 @@ class SafeModel(pydantic.BaseModel): arbitrary_types_allowed=False, strict=True, validate_assignment=True, - extra="forbid", ) @@ -397,10 +396,13 @@ class ServiceDescriptor(SafeModel): specifically with ``StubBase``.""" name: str - predict_url: str options: RPCOptions +class DeployedServiceDescriptor(ServiceDescriptor): + predict_url: str + + class Environment(SafeModel): """The environment the chainlet is deployed in. @@ -422,7 +424,6 @@ class DeploymentContext(SafeModelNonSerializable): Args: data_dir: The directory where the chainlet can store and access data, e.g. for downloading model weights. - user_config: User-defined configuration for the chainlet. chainlet_to_service: A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. @@ -434,11 +435,11 @@ class DeploymentContext(SafeModelNonSerializable): """ data_dir: Optional[pathlib.Path] = None - chainlet_to_service: Mapping[str, ServiceDescriptor] + chainlet_to_service: Mapping[str, DeployedServiceDescriptor] secrets: MappingNoIter[str, str] environment: Optional[Environment] = None - def get_service_descriptor(self, chainlet_name: str) -> ServiceDescriptor: + def get_service_descriptor(self, chainlet_name: str) -> DeployedServiceDescriptor: if chainlet_name not in self.chainlet_to_service: raise MissingDependencyError(f"{chainlet_name}") return self.chainlet_to_service[chainlet_name] diff --git a/truss-chains/truss_chains/framework.py b/truss-chains/truss_chains/framework.py index 9e24366ae..771ee73d0 100644 --- a/truss-chains/truss_chains/framework.py +++ b/truss-chains/truss_chains/framework.py @@ -840,7 +840,7 @@ def _create_modified_init_for_local( ], secrets: Mapping[str, str], data_dir: Optional[pathlib.Path], - chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], + chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor], ): """Replaces the default argument values with local Chainlet instantiations. @@ -1011,7 +1011,7 @@ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None: def run_local( secrets: Mapping[str, str], data_dir: Optional[pathlib.Path], - chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], + chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor], ) -> Any: """Context to run Chainlets with dependency injection from local instances.""" # TODO: support retries in local mode. diff --git a/truss-chains/truss_chains/model_skeleton.py b/truss-chains/truss_chains/model_skeleton.py index 2173b5fed..6f637e8d9 100644 --- a/truss-chains/truss_chains/model_skeleton.py +++ b/truss-chains/truss_chains/model_skeleton.py @@ -4,7 +4,7 @@ from truss.templates.shared import secrets_resolver from truss_chains import definitions -from truss_chains.utils import override_chainlet_to_service_metadata +from truss_chains.utils import populate_chainlet_service_predict_urls class TrussChainletModel: @@ -28,10 +28,12 @@ def __init__( deployment_environment: Optional[definitions.Environment] = ( definitions.Environment.model_validate(environment) if environment else None ) - override_chainlet_to_service_metadata(truss_metadata.chainlet_to_service) + chainlet_to_deployed_service = populate_chainlet_service_predict_urls( + truss_metadata.chainlet_to_service + ) self._context = definitions.DeploymentContext( - chainlet_to_service=truss_metadata.chainlet_to_service, + chainlet_to_service=chainlet_to_deployed_service, secrets=secrets, data_dir=data_dir, environment=deployment_environment, diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index aab38f798..f3ed5109d 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -162,7 +162,9 @@ def push( def run_local( secrets: Optional[Mapping[str, str]] = None, data_dir: Optional[Union[pathlib.Path, str]] = None, - chainlet_to_service: Optional[Mapping[str, definitions.ServiceDescriptor]] = None, + chainlet_to_service: Optional[ + Mapping[str, definitions.DeployedServiceDescriptor] + ] = None, ) -> ContextManager[None]: """Context manager local debug execution of a chain. @@ -188,7 +190,7 @@ class HelloWorld(chains.ChainletBase): with chains.run_local( secrets={"some_token": os.environ["SOME_TOKEN"]}, chainlet_to_service={ - "SomeChainlet": chains.ServiceDescriptor( + "SomeChainlet": chains.DeployedServiceDescriptor( name="SomeChainlet", predict_url="https://...", options=chains.RPCOptions(), diff --git a/truss-chains/truss_chains/remote.py b/truss-chains/truss_chains/remote.py index 564376f4e..0d29b043d 100644 --- a/truss-chains/truss_chains/remote.py +++ b/truss-chains/truss_chains/remote.py @@ -1,9 +1,9 @@ import abc import concurrent.futures import inspect +import json import logging import pathlib -import re import tempfile import textwrap import traceback @@ -16,8 +16,6 @@ Iterable, Iterator, Mapping, - MutableMapping, - NamedTuple, Optional, Type, cast, @@ -28,6 +26,7 @@ if TYPE_CHECKING: from rich import console as rich_console +from truss.local import local_config_handler from truss.remote import remote_cli, remote_factory from truss.remote.baseten import core as b10_core from truss.remote.baseten import custom_types as b10_types @@ -39,38 +38,15 @@ from truss_chains import code_gen, definitions, framework, utils -_MODEL_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]+-[0-9a-f]{8}$") - - -def _push_to_baseten( - truss_dir: pathlib.Path, options: definitions.PushOptionsBaseten, chainlet_name: str -) -> b10_service.BasetenService: - truss_handle = truss_build.load(str(truss_dir)) - model_name = truss_handle.spec.config.model_name - assert model_name is not None - assert bool(_MODEL_NAME_RE.match(model_name)) - logging.info( - f"Pushing chainlet `{model_name}` as a truss model on " - f"Baseten (publish={options.publish})" - ) - # Models must be trusted to use the API KEY secret. - service = options.remote_provider.push( - truss_handle, - model_name=model_name, - trusted=True, - publish=options.publish, - origin=b10_types.ModelOrigin.CHAINS, - chain_environment=options.environment, - chainlet_name=chainlet_name, - chain_name=options.chain_name, - ) - return cast(b10_service.BasetenService, service) - class DockerTrussService(b10_service.TrussService): """This service is for Chainlets (not for Chains).""" - def __init__(self, remote_url: str, is_draft: bool, **kwargs): + def __init__(self, port: int, is_draft: bool, **kwargs): + # http://localhost:{port} seems to only work *sometimes* with docker. + remote_url = f"http://host.docker.internal:{port}" + self._port = port + super().__init__(remote_url, is_draft, **kwargs) def authenticate(self) -> Dict[str, str]: @@ -92,6 +68,10 @@ def is_ready(self) -> bool: def logs_url(self) -> str: raise NotImplementedError() + @property + def port(self) -> int: + return self._port + @property def predict_url(self) -> str: return f"{self._service_url}/v1/models/model:predict" @@ -100,44 +80,27 @@ def poll_deployment_status(self, sleep_secs: int = 1) -> Iterator[str]: raise NotImplementedError() -def _push_service( +def _push_service_docker( truss_dir: pathlib.Path, - chainlet_descriptor: definitions.ChainletAPIDescriptor, - options: definitions.PushOptions, -) -> b10_service.TrussService: - service: b10_service.TrussService - if isinstance(options, definitions.PushOptionsLocalDocker): - logging.info( - f"Running in docker container `{chainlet_descriptor.display_name}` " - ) - port = utils.get_free_port() - truss_handle = truss_build.load(str(truss_dir)) - truss_handle.add_secret( - definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key - ) - truss_handle.docker_run( - local_port=port, - detach=True, - wait_for_server_ready=True, - network="host", - container_name_prefix=chainlet_descriptor.display_name, - ) - # http://localhost:{port} seems to only work *sometimes* with docker. - service = DockerTrussService( - f"http://host.docker.internal:{port}", is_draft=True - ) - elif isinstance(options, definitions.PushOptionsBaseten): - with utils.log_level(logging.INFO): - # We send the display_name of the chainlet in subsequent steps. - service = _push_to_baseten( - truss_dir, options, chainlet_descriptor.display_name - ) - else: - raise NotImplementedError(options) + chainlet_display_name: str, + options: definitions.PushOptionsLocalDocker, + port: int, +) -> None: + logging.info(f"Running in docker container `{chainlet_display_name}` ") - logging.info(f"Pushed `{chainlet_descriptor.display_name}`") - logging.debug(f"Internal model endpoint: `{service.predict_url}`") - return service + truss_handle = truss_build.load(str(truss_dir)) + + truss_handle.add_secret( + definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key + ) + + truss_handle.docker_run( + local_port=port, + detach=True, + wait_for_server_ready=True, + network="host", + container_name_prefix=chainlet_display_name, + ) def _get_ordered_dependencies( @@ -225,14 +188,14 @@ def entrypoint_fake_json_data(self, fake_data: Any) -> None: class BasetenChainService(ChainService): - _chain_deployment_handle: b10_core.ChainDeploymentHandle + _chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic _remote: b10_remote.BasetenRemote def __init__( self, name: str, entrypoint_service: b10_service.BasetenService, - chain_deployment_handle: b10_core.ChainDeploymentHandle, + chain_deployment_handle: b10_core.ChainDeploymentHandleAtomic, remote: b10_remote.BasetenRemote, ) -> None: super().__init__(name, entrypoint_service) @@ -309,25 +272,22 @@ def _get_chain_root( def _create_baseten_chain( baseten_options: definitions.PushOptionsBaseten, - chainlet_services: list["_Pusher.ChainEntry"], - entrypoint_service: b10_service.BasetenService, + entrypoint_artifact: b10_types.ChainletArtifact, + dependency_artifacts: list[b10_types.ChainletArtifact], ): - chainlet_data = [] - for chain_entry in chainlet_services: - assert isinstance(chain_entry.service, b10_service.BasetenService) - chainlet_data.append( - b10_types.ChainletData( - name=chain_entry.chainlet_display_name, - oracle_version_id=chain_entry.service.model_version_id, - is_entrypoint=chain_entry.is_entrypoint, - ) + chain_deployment_handle, entrypoint_service = ( + baseten_options.remote_provider.push_chain_atomic( + chain_name=baseten_options.chain_name, + entrypoint_artifact=entrypoint_artifact, + dependency_artifacts=dependency_artifacts, + publish=baseten_options.publish, + environment=baseten_options.environment, ) - chain_deployment_handle = baseten_options.remote_provider.create_chain( - chain_name=baseten_options.chain_name, - chainlets=chainlet_data, - publish=baseten_options.publish, - environment=baseten_options.environment, ) + + logging.info(f"Pushed Chain '{baseten_options.chain_name}'.") + logging.debug(f"Internal model endpoint: '{entrypoint_service.predict_url}'.") + return BasetenChainService( baseten_options.chain_name, entrypoint_service, @@ -351,12 +311,7 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote) ) -class _Pusher: - class ChainEntry(NamedTuple): - service: b10_service.TrussService - chainlet_display_name: str - is_entrypoint: bool - +class _ChainSourceGenerator: def __init__( self, options: definitions.PushOptions, @@ -364,69 +319,52 @@ def __init__( ) -> None: self._options = options self._gen_root = gen_root or pathlib.Path(tempfile.gettempdir()) - if isinstance(self._options, definitions.PushOptionsBaseten): - _create_chains_secret_if_missing(self._options.remote_provider) - def push( + def generate_chainlet_artifacts( self, entrypoint: Type[definitions.ABCChainlet], non_entrypoint_root_dir: Optional[str] = None, - ) -> Optional[ChainService]: + ) -> tuple[b10_types.ChainletArtifact, list[b10_types.ChainletArtifact]]: chain_root = _get_chain_root(entrypoint, non_entrypoint_root_dir) - chainlet_display_name_to_url: MutableMapping[str, str] = {} - chainlet_services: list[_Pusher.ChainEntry] = [] - entrypoint_service = None + entrypoint_artifact: Optional[b10_types.ChainletArtifact] = None + dependency_artifacts: list[b10_types.ChainletArtifact] = [] + for chainlet_descriptor in _get_ordered_dependencies([entrypoint]): model_base_name = chainlet_descriptor.display_name # Since we are creating a distinct model for each deployment of the chain, # we add a random suffix. model_suffix = str(uuid.uuid4()).split("-")[0] model_name = f"{model_base_name}-{model_suffix}" + logging.info( - f"Generating truss chainlet model for `{chainlet_descriptor.name}`." + f"Generating Truss Chainlet model for '{chainlet_descriptor.name}'." ) + chainlet_dir = code_gen.gen_truss_chainlet( chain_root, self._gen_root, self._options.chain_name, chainlet_descriptor, model_name, - chainlet_display_name_to_url, ) - if self._options.only_generate_trusses: - chainlet_display_name_to_url[chainlet_descriptor.display_name] = ( - "http://dummy" - ) - continue + artifact = b10_types.ChainletArtifact( + truss_dir=chainlet_dir, + name=chainlet_descriptor.name, + display_name=chainlet_descriptor.display_name, + ) is_entrypoint = chainlet_descriptor.chainlet_cls == entrypoint - service = _push_service(chainlet_dir, chainlet_descriptor, self._options) - chainlet_display_name_to_url[chainlet_descriptor.display_name] = ( - service.predict_url - ) - chainlet_services.append( - _Pusher.ChainEntry( - service, chainlet_descriptor.display_name, is_entrypoint - ) - ) + if is_entrypoint: - assert entrypoint_service is None - entrypoint_service = service + assert entrypoint_artifact is None - if self._options.only_generate_trusses: - return None - assert entrypoint_service is not None + entrypoint_artifact = artifact + else: + dependency_artifacts.append(artifact) - if isinstance(self._options, definitions.PushOptionsBaseten): - assert isinstance(entrypoint_service, b10_service.BasetenService) - return _create_baseten_chain( - self._options, chainlet_services, entrypoint_service - ) - elif isinstance(self._options, definitions.PushOptionsLocalDocker): - assert isinstance(entrypoint_service, DockerTrussService) - return DockerChainService(self._options.chain_name, entrypoint_service) - else: - raise NotImplementedError(self._options) + assert entrypoint_artifact is not None + + return entrypoint_artifact, dependency_artifacts @framework.raise_validation_errors_before @@ -436,7 +374,68 @@ def push( non_entrypoint_root_dir: Optional[str] = None, gen_root: pathlib.Path = pathlib.Path(tempfile.gettempdir()), ) -> Optional[ChainService]: - return _Pusher(options, gen_root).push(entrypoint, non_entrypoint_root_dir) + entrypoint_artifact, dependency_artifacts = _ChainSourceGenerator( + options, gen_root + ).generate_chainlet_artifacts( + entrypoint, + non_entrypoint_root_dir, + ) + + if options.only_generate_trusses: + return None + + if isinstance(options, definitions.PushOptionsBaseten): + _create_chains_secret_if_missing(options.remote_provider) + return _create_baseten_chain(options, entrypoint_artifact, dependency_artifacts) + elif isinstance(options, definitions.PushOptionsLocalDocker): + chainlet_artifacts = [entrypoint_artifact, *dependency_artifacts] + chainlet_to_predict_url: Dict[str, Dict[str, str]] = {} + chainlet_to_service: Dict[str, DockerTrussService] = {} + + for chainlet_artifact in chainlet_artifacts: + port = utils.get_free_port() + + service = DockerTrussService( + is_draft=True, + port=port, + ) + chainlet_to_predict_url[chainlet_artifact.name] = { + "predict_url": service.predict_url, + } + chainlet_to_service[chainlet_artifact.name] = service + + local_config_handler.LocalConfigHandler.set_dynamic_config( + definitions.DYNAMIC_CHAINLET_CONFIG_KEY, + json.dumps(chainlet_to_predict_url), + ) + + # TODO(Tyron): We run the Docker containers in a + # separate for-loop to make sure that the dynamic + # config is populated (the same one gets mounted + # on all the containers). We should look into + # consolidating the logic into a single for-loop. + # One approach might be to use separate config + # paths for each container under the `/tmp` dir. + for chainlet_artifact in chainlet_artifacts: + truss_dir = chainlet_artifact.truss_dir + + _push_service_docker( + truss_dir, + chainlet_artifact.display_name, + options, + chainlet_to_service[chainlet_artifact.name].port, + ) + + logging.info(f"Pushed `{chainlet_artifact.display_name}`") + logging.debug( + f"Internal model endpoint: `{chainlet_to_predict_url[chainlet_artifact.name]}`" + ) + + return DockerChainService( + options.chain_name, chainlet_to_service[entrypoint_artifact.name] + ) + else: + raise NotImplementedError(options) # Watch / Live Patching ################################################################ @@ -525,10 +524,6 @@ def watch_filter(_: watchfiles.Change, path: str) -> bool: def _original_chainlet_names(self) -> set[str]: return set(self._chainlet_data.keys()) - @property - def _chainlet_display_name_to_url(self) -> Mapping[str, str]: - return {k: v.oracle_predict_url for k, v in self._chainlet_data.items()} - def _assert_chainlet_names_same(self, new_names: set[str]) -> None: missing = self._original_chainlet_names - new_names added = new_names - self._original_chainlet_names @@ -556,7 +551,6 @@ def _code_gen_and_patch_thread( self._deployed_chain_name, descr, self._chainlet_data[descr.display_name].oracle_name, - self._chainlet_display_name_to_url, ) patch_result = self._remote_provider.patch_for_chainlet( chainlet_dir, self._ignore_patterns diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 4870dae48..3d7679fd9 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -43,13 +43,13 @@ class BasetenSession: max_keepalive_connections=DEFAULT_MAX_KEEPALIVE_CONNECTIONS, ) _auth_header: Mapping[str, str] - _service_descriptor: definitions.ServiceDescriptor + _service_descriptor: definitions.DeployedServiceDescriptor _cached_sync_client: Optional[tuple[httpx.Client, int]] _cached_async_client: Optional[tuple[aiohttp.ClientSession, int]] def __init__( self, - service_descriptor: definitions.ServiceDescriptor, + service_descriptor: definitions.DeployedServiceDescriptor, api_key: str, ) -> None: logging.info( @@ -220,7 +220,9 @@ def __init__(self, ..., context=chains.depends_context()): @final def __init__( - self, service_descriptor: definitions.ServiceDescriptor, api_key: str + self, + service_descriptor: definitions.DeployedServiceDescriptor, + api_key: str, ) -> None: """ Args: @@ -245,7 +247,7 @@ def from_url( """ options = options or definitions.RPCOptions() return cls( - definitions.ServiceDescriptor( + service_descriptor=definitions.DeployedServiceDescriptor( name=cls.__name__, predict_url=predict_url, options=options ), api_key=context.get_baseten_api_key(), diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index 66a89efb2..773a0be76 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -12,7 +12,17 @@ import textwrap import threading import traceback -from typing import Any, Iterable, Iterator, Mapping, NoReturn, Type, TypeVar, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + Mapping, + NoReturn, + Type, + TypeVar, + Union, +) import aiohttp import fastapi @@ -131,33 +141,57 @@ def get_free_port() -> int: return port -def override_chainlet_to_service_metadata( +def populate_chainlet_service_predict_urls( chainlet_to_service: Mapping[str, definitions.ServiceDescriptor], -): - # Override predict_urls in chainlet_to_service ServiceDescriptors if dynamic_chainlet_config exists +) -> Mapping[str, definitions.DeployedServiceDescriptor]: + chainlet_to_deployed_service: Dict[str, definitions.DeployedServiceDescriptor] = {} + dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync( definitions.DYNAMIC_CHAINLET_CONFIG_KEY ) - if dynamic_chainlet_config_str: - dynamic_chainlet_config = json.loads(dynamic_chainlet_config_str) - for ( - chainlet_name, - service_descriptor, - ) in chainlet_to_service.items(): - if chainlet_name in dynamic_chainlet_config: - # We update the predict_url to be the one pulled from the dynamic_chainlet_config - service_descriptor.predict_url = dynamic_chainlet_config[chainlet_name][ - "predict_url" - ] - else: - logging.debug( - f"Skipped override for chainlet '{chainlet_name}': not found in {definitions.DYNAMIC_CHAINLET_CONFIG_KEY}." - ) - else: - logging.debug( - f"No {definitions.DYNAMIC_CHAINLET_CONFIG_KEY} found, skipping overrides." + + if not dynamic_chainlet_config_str: + raise definitions.MissingDependencyError( + f"No '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}' found. Cannot override Chainlet configs." + ) + + dynamic_chainlet_config = json.loads(dynamic_chainlet_config_str) + + for ( + chainlet_name, + service_descriptor, + ) in chainlet_to_service.items(): + if chainlet_name not in dynamic_chainlet_config: + raise definitions.MissingDependencyError( + f"Chainlet '{chainlet_name}' not found in '{definitions.DYNAMIC_CHAINLET_CONFIG_KEY}'." + ) + + chainlet_to_deployed_service[chainlet_name] = ( + definitions.DeployedServiceDescriptor( + name=service_descriptor.name, + options=service_descriptor.options, + predict_url=dynamic_chainlet_config[chainlet_name]["predict_url"], + ) ) + return chainlet_to_deployed_service + + +# NOTE: This needs to be available in the Context Builder +# so that older Truss CLI versions that generate code that +# expects this function to be available continue to work. +def override_chainlet_to_service_metadata( + chainlet_to_service: Dict[ + str, Union[definitions.ServiceDescriptor, definitions.DeployedServiceDescriptor] + ], +) -> None: + chainlet_to_deployed_service = populate_chainlet_service_predict_urls( + chainlet_to_service + ) + + for chainlet_name in chainlet_to_service.keys(): + chainlet_to_service[chainlet_name] = chainlet_to_deployed_service[chainlet_name] + # Error Propagation Utils. ############################################################# diff --git a/truss/base/validation.py b/truss/base/validation.py index f071ab8d1..b4b468abc 100644 --- a/truss/base/validation.py +++ b/truss/base/validation.py @@ -24,6 +24,12 @@ "Ei": 1024**6, } +_MODEL_NAME_RE = re.compile(r"^[a-zA-Z0-9_-]+-[0-9a-f]{8}$") + + +def is_valid_model_name(model_name: str) -> bool: + return bool(_MODEL_NAME_RE.match(model_name)) + def validate_secret_to_path_mapping(secret_to_path_mapping: Dict[str, str]) -> None: if not isinstance(secret_to_path_mapping, dict): diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index 85d903b63..af9814f16 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional import requests +import truss from truss.remote.baseten import custom_types as b10_types from truss.remote.baseten.auth import ApiKey, AuthService from truss.remote.baseten.error import ApiError @@ -23,14 +24,42 @@ DEFAULT_API_DOMAIN = "https://api.baseten.co" -def _chainlet_data_to_graphql_mutation(chainlet: b10_types.ChainletData): - return f""" - {{ - name: "{chainlet.name}", - oracle_version_id: "{chainlet.oracle_version_id}", - is_entrypoint: {'true' if chainlet.is_entrypoint else 'false'} - }} - """ +def _oracle_data_to_graphql_mutation(oracle: b10_types.OracleData) -> str: + args = [ + f'model_name: "{oracle.model_name}"', + f's3_key: "{oracle.s3_key}"', + f'encoded_config_str: "{oracle.encoded_config_str}"', + f"is_trusted: {str(oracle.is_trusted).lower()}", + ] + + if oracle.semver_bump: + args.append(f'semver_bump: "{oracle.semver_bump}"') + + if oracle.version_name: + args.append(f'version_name: "{oracle.version_name}"') + + args_str = ",\n".join(args) + + return f"""{{ + {args_str} + }}""" + + +def _chainlet_data_atomic_to_graphql_mutation( + chainlet: b10_types.ChainletDataAtomic, +) -> str: + oracle_data_string = _oracle_data_to_graphql_mutation(chainlet.oracle) + + args = [ + f'name: "{chainlet.name}"', + f"oracle: {oracle_data_string}", + ] + + args_str = ",\n".join(args) + + return f"""{{ + {args_str} + }}""" class BasetenApi: @@ -110,9 +139,6 @@ def create_model_from_truss( allow_truss_download: bool = True, deployment_name: Optional[str] = None, origin: Optional[b10_types.ModelOrigin] = None, - chain_environment: Optional[str] = None, - chainlet_name: Optional[str] = None, - chain_name: Optional[str] = None, ): query_string = f""" mutation {{ @@ -126,9 +152,6 @@ def create_model_from_truss( allow_truss_download: {'true' if allow_truss_download else 'false'}, {f'version_name: "{deployment_name}"' if deployment_name else ""} {f'model_origin: {origin.value}' if origin else ""} - {f'chain_environment: "{chain_environment}"' if chain_environment else ""} - {f'chainlet_name: "{chainlet_name}"' if chainlet_name else ""} - {f'chain_name: "{chain_name}"' if chain_name else ""} ) {{ id, name, @@ -202,72 +225,46 @@ def create_development_model_from_truss( resp = self._post_graphql_query(query_string) return resp["data"]["deploy_draft_truss"] - def deploy_chain(self, name: str, chainlet_data: List[b10_types.ChainletData]): - chainlet_data_strings = [ - _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data - ] - - chainlets_string = ", ".join(chainlet_data_strings) - query_string = f""" - mutation {{ - deploy_chain( - name: "{name}", - chainlets: [{chainlets_string}] - ) {{ - id - chain_id - chain_deployment_id - }} - }} - """ - resp = self._post_graphql_query(query_string) - return resp["data"]["deploy_chain"] - - def deploy_draft_chain( - self, name: str, chainlet_data: List[b10_types.ChainletData] - ): - chainlet_data_strings = [ - _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data - ] - chainlets_string = ", ".join(chainlet_data_strings) - query_string = f""" - mutation {{ - deploy_draft_chain( - name: "{name}", - chainlets: [{chainlets_string}] - ) {{ - chain_id - chain_deployment_id - }} - }} - """ - resp = self._post_graphql_query(query_string) - return resp["data"]["deploy_draft_chain"] - - def deploy_chain_deployment( + def deploy_chain_atomic( self, - chain_id: str, - chainlet_data: List[b10_types.ChainletData], + entrypoint: b10_types.ChainletDataAtomic, + dependencies: List[b10_types.ChainletDataAtomic], + chain_id: Optional[str] = None, + chain_name: Optional[str] = None, environment: Optional[str] = None, + is_draft: bool = False, ): - chainlet_data_strings = [ - _chainlet_data_to_graphql_mutation(chainlet) for chainlet in chainlet_data - ] - chainlets_string = ", ".join(chainlet_data_strings) + entrypoint_str = _chainlet_data_atomic_to_graphql_mutation(entrypoint) + + dependencies_str = ", ".join( + [ + _chainlet_data_atomic_to_graphql_mutation(dependency) + for dependency in dependencies + ] + ) + query_string = f""" - mutation {{ - deploy_chain_deployment( - chain_id: "{chain_id}", - chainlets: [{chainlets_string}], - {f'environment_name: "{environment}"' if environment else ""} - ) {{ - chain_id - chain_deployment_id + mutation {{ + deploy_chain_atomic( + {f'chain_id: "{chain_id}"' if chain_id else ""} + {f'chain_name: "{chain_name}"' if chain_name else ""} + {f'environment: "{environment}"' if environment else ""} + is_draft: {str(is_draft).lower()} + entrypoint: {entrypoint_str} + dependencies: [{dependencies_str}] + client_version: "truss=={truss.version()}" + ) {{ + chain_id + chain_deployment_id + entrypoint_model_id + entrypoint_model_version_id + }} }} - }} """ + resp = self._post_graphql_query(query_string) - return resp["data"]["deploy_chain_deployment"] + + return resp["data"]["deploy_chain_atomic"] def get_chains(self): query_string = """ diff --git a/truss/remote/baseten/core.py b/truss/remote/baseten/core.py index d3b603504..eda2597ba 100644 --- a/truss/remote/baseten/core.py +++ b/truss/remote/baseten/core.py @@ -42,12 +42,6 @@ def __init__(self, model_version_id: str): self.value = model_version_id -class ChainDeploymentHandle(typing.NamedTuple): - chain_id: str - chain_deployment_id: str - is_draft: bool - - class PatchState(typing.NamedTuple): current_hash: str current_signature: str @@ -63,6 +57,14 @@ class TrussWatchState(typing.NamedTuple): patches: Optional[TrussPatches] +class ChainDeploymentHandleAtomic(typing.NamedTuple): + chain_id: str + chain_deployment_id: str + is_draft: bool + entrypoint_model_id: str + entrypoint_model_version_id: str + + def get_chain_id_by_name(api: BasetenApi, chain_name: str) -> Optional[str]: """ Check if a chain with the given name exists in the Baseten remote. @@ -93,23 +95,45 @@ def get_dev_chain_deployment(api: BasetenApi, chain_id: str): return newest_draft_deployment -def create_chain( +def create_chain_atomic( api: BasetenApi, - chain_id: Optional[str], chain_name: str, - chainlets: List[b10_types.ChainletData], + entrypoint: b10_types.ChainletDataAtomic, + dependencies: List[b10_types.ChainletDataAtomic], is_draft: bool, environment: Optional[str], -) -> ChainDeploymentHandle: +) -> ChainDeploymentHandleAtomic: + if environment and is_draft: + logging.info( + f"Automatically publishing Chain '{chain_name}' based on environment setting." + ) + is_draft = False + + chain_id = get_chain_id_by_name(api, chain_name) + + # TODO(Tyron): Refactor for better readability: + # 1. Prepare all arguments for `deploy_chain_atomic`. + # 2. Validate argument combinations. + # 3. Make a single invocation to `deploy_chain_atomic`. if is_draft: - response = api.deploy_draft_chain(chain_name, chainlets) + res = api.deploy_chain_atomic( + chain_name=chain_name, + is_draft=True, + entrypoint=entrypoint, + dependencies=dependencies, + ) elif chain_id: # This is the only case where promote has relevance, since # if there is no chain already, the first deployment will # already be production, and only published deployments can # be promoted. try: - response = api.deploy_chain_deployment(chain_id, chainlets, environment) + res = api.deploy_chain_atomic( + chain_id=chain_id, + environment=environment, + entrypoint=entrypoint, + dependencies=dependencies, + ) except ApiError as e: if ( e.graphql_error_code @@ -118,15 +142,22 @@ def create_chain( raise ValueError( f'Environment "{environment}" does not exist. You can create environments in the Chains UI.' ) from e + raise e + elif environment and environment != PRODUCTION_ENVIRONMENT_NAME: + raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING) else: - if environment and environment != PRODUCTION_ENVIRONMENT_NAME: - raise ValueError(NO_ENVIRONMENTS_EXIST_ERROR_MESSAGING) - response = api.deploy_chain(chain_name, chainlets) + res = api.deploy_chain_atomic( + chain_name=chain_name, + entrypoint=entrypoint, + dependencies=dependencies, + ) - return ChainDeploymentHandle( - chain_id=response["chain_id"], - chain_deployment_id=response["chain_deployment_id"], + return ChainDeploymentHandleAtomic( + chain_id=res["chain_id"], + chain_deployment_id=res["chain_deployment_id"], + entrypoint_model_id=res["entrypoint_model_id"], + entrypoint_model_version_id=res["entrypoint_model_version_id"], is_draft=is_draft, ) @@ -293,9 +324,6 @@ def create_truss_service( deployment_name: Optional[str] = None, origin: Optional[b10_types.ModelOrigin] = None, environment: Optional[str] = None, - chain_environment: Optional[str] = None, - chainlet_name: Optional[str] = None, - chain_name: Optional[str] = None, ) -> Tuple[str, str]: """ Create a model in the Baseten remote. @@ -342,9 +370,6 @@ def create_truss_service( allow_truss_download=allow_truss_download, deployment_name=deployment_name, origin=origin, - chain_environment=chain_environment, - chainlet_name=chainlet_name, - chain_name=chain_name, ) return model_version_json["id"], model_version_json["version_id"] diff --git a/truss/remote/baseten/custom_types.py b/truss/remote/baseten/custom_types.py index 8fbc52833..0c34d59f8 100644 --- a/truss/remote/baseten/custom_types.py +++ b/truss/remote/baseten/custom_types.py @@ -1,4 +1,6 @@ +import pathlib from enum import Enum +from typing import Optional import pydantic @@ -13,12 +15,27 @@ class DeployedChainlet(pydantic.BaseModel): oracle_name: str -class ChainletData(pydantic.BaseModel): +class ChainletArtifact(pydantic.BaseModel): + truss_dir: pathlib.Path + display_name: str name: str - oracle_version_id: str - is_entrypoint: bool class ModelOrigin(Enum): BASETEN = "BASETEN" CHAINS = "CHAINS" + + +class OracleData(pydantic.BaseModel): + model_name: str + s3_key: str + encoded_config_str: str + semver_bump: Optional[str] = "MINOR" + is_trusted: bool + version_name: Optional[str] = None + + +# This corresponds to `ChainletInputAtomicGraphene` in the backend. +class ChainletDataAtomic(pydantic.BaseModel): + name: str + oracle: OracleData diff --git a/truss/remote/baseten/remote.py b/truss/remote/baseten/remote.py index 2b5596760..276c122bd 100644 --- a/truss/remote/baseten/remote.py +++ b/truss/remote/baseten/remote.py @@ -10,22 +10,22 @@ if TYPE_CHECKING: from rich import console as rich_console +from truss.base import validation from truss.base.truss_config import ModelServer from truss.local.local_config_handler import LocalConfigHandler from truss.remote.baseten import custom_types from truss.remote.baseten.api import BasetenApi from truss.remote.baseten.auth import AuthService from truss.remote.baseten.core import ( - ChainDeploymentHandle, + ChainDeploymentHandleAtomic, ModelId, ModelIdentifier, ModelName, ModelVersionId, archive_truss, - create_chain, + create_chain_atomic, create_truss_service, exists_model, - get_chain_id_by_name, get_dev_version, get_dev_version_from_versions, get_model_versions, @@ -37,6 +37,7 @@ from truss.remote.baseten.service import BasetenService, URLConfig from truss.remote.baseten.utils.transfer import base64_encoded_json_str from truss.remote.truss_remote import RemoteUser, TrussRemote +from truss.truss_handle import build as truss_build from truss.truss_handle.truss_handle import TrussHandle from truss.util.path import is_ignored, load_trussignore_patterns_from_truss_dir from watchfiles import watch @@ -53,6 +54,15 @@ class PatchResult(NamedTuple): message: str +class FinalPushData(custom_types.OracleData): + is_draft: bool + model_id: Optional[str] + preserve_previous_prod_deployment: bool + origin: Optional[custom_types.ModelOrigin] = None + environment: Optional[str] = None + allow_truss_download: bool + + class BasetenRemote(TrussRemote): def __init__(self, remote_url: str, api_key: str, **kwargs): super().__init__(remote_url, **kwargs) @@ -63,28 +73,6 @@ def __init__(self, remote_url: str, api_key: str, **kwargs): def api(self) -> BasetenApi: return self._api - def create_chain( - self, - chain_name: str, - chainlets: List[custom_types.ChainletData], - publish: bool = False, - environment: Optional[str] = None, - ) -> ChainDeploymentHandle: - if environment: - # If we are promoting a model to an environment after deploy, it must be published. - # Draft models cannot be promoted. - publish = True - # Returns tuple of (chain_id, chain_deployment_id) - chain_id = get_chain_id_by_name(self._api, chain_name) - return create_chain( - self._api, - chain_id=chain_id, - chain_name=chain_name, - chainlets=chainlets, - is_draft=not publish, - environment=environment, - ) - def get_chainlets( self, chain_deployment_id: str ) -> List[custom_types.DeployedChainlet]: @@ -127,7 +115,9 @@ def whoami(self) -> RemoteUser: user_email, ) - def push( # type: ignore + # Validate and finalize options. + # Upload Truss files to S3 and return S3 key. + def _prepare_push( self, truss_handle: TrussHandle, model_name: str, @@ -139,25 +129,24 @@ def push( # type: ignore deployment_name: Optional[str] = None, origin: Optional[custom_types.ModelOrigin] = None, environment: Optional[str] = None, - chain_environment: Optional[str] = None, - chainlet_name: Optional[str] = None, - chain_name: Optional[str] = None, - ) -> BasetenService: + ) -> FinalPushData: if model_name.isspace(): raise ValueError("Model name cannot be empty") - model_id = exists_model(self._api, model_name) - gathered_truss = TrussHandle(truss_handle.gather()) + if gathered_truss.spec.model_server != ModelServer.TrussServer: publish = True if promote: environment = PRODUCTION_ENVIRONMENT_NAME - if environment: - # If there is a target environment, it must be published. - # Draft models cannot be promoted. + # If there is a target environment, it must be published. + # Draft models cannot be promoted. + if environment and not publish: + logging.info( + f"Automatically publishing model '{model_name}' based on environment setting." + ) publish = True if not publish and deployment_name: @@ -176,44 +165,161 @@ def push( # type: ignore "Deployment name must only contain alphanumeric, -, _ and . characters" ) + model_id = exists_model(self._api, model_name) + if model_id is not None and disable_truss_download: raise ValueError("disable-truss-download can only be used for new models") + temp_file = archive_truss(gathered_truss) + s3_key = upload_truss(self._api, temp_file) encoded_config_str = base64_encoded_json_str( gathered_truss._spec._config.to_dict() ) - temp_file = archive_truss(gathered_truss) - s3_key = upload_truss(self._api, temp_file) - - model_id, model_version_id = create_truss_service( - api=self._api, + return FinalPushData( model_name=model_name, s3_key=s3_key, - config=encoded_config_str, + encoded_config_str=encoded_config_str, is_draft=not publish, model_id=model_id, is_trusted=trusted, preserve_previous_prod_deployment=preserve_previous_prod_deployment, - deployment_name=deployment_name, + version_name=deployment_name, origin=origin, environment=environment, - chain_environment=chain_environment, - chainlet_name=chainlet_name, - chain_name=chain_name, allow_truss_download=not disable_truss_download, ) + def push( # type: ignore + self, + truss_handle: TrussHandle, + model_name: str, + publish: bool = True, + trusted: bool = False, + promote: bool = False, + preserve_previous_prod_deployment: bool = False, + disable_truss_download: bool = False, + deployment_name: Optional[str] = None, + origin: Optional[custom_types.ModelOrigin] = None, + environment: Optional[str] = None, + ) -> BasetenService: + push_data = self._prepare_push( + truss_handle=truss_handle, + model_name=model_name, + publish=publish, + trusted=trusted, + promote=promote, + preserve_previous_prod_deployment=preserve_previous_prod_deployment, + disable_truss_download=disable_truss_download, + deployment_name=deployment_name, + origin=origin, + environment=environment, + ) + + # TODO(Tyron): This set of args is duplicated across + # many functions. We should consolidate them into a + # data class with standardized default values so + # we're not drilling these arguments everywhere. + model_id, model_version_id = create_truss_service( + api=self._api, + model_name=push_data.model_name, + s3_key=push_data.s3_key, + config=push_data.encoded_config_str, + is_draft=push_data.is_draft, + model_id=push_data.model_id, + is_trusted=push_data.is_trusted, + preserve_previous_prod_deployment=push_data.preserve_previous_prod_deployment, + allow_truss_download=push_data.allow_truss_download, + deployment_name=push_data.version_name, + origin=push_data.origin, + environment=push_data.environment, + ) + return BasetenService( model_id=model_id, model_version_id=model_version_id, - is_draft=not publish, + is_draft=push_data.is_draft, api_key=self._auth_service.authenticate().value, service_url=f"{self._remote_url}/model_versions/{model_version_id}", truss_handle=truss_handle, api=self._api, ) + def push_chain_atomic( + self, + chain_name: str, + entrypoint_artifact: custom_types.ChainletArtifact, + dependency_artifacts: List[custom_types.ChainletArtifact], + publish: bool = False, + environment: Optional[str] = None, + ) -> Tuple[ChainDeploymentHandleAtomic, BasetenService]: + # If we are promoting a model to an environment after deploy, it must be published. + # Draft models cannot be promoted. + if environment and not publish: + publish = True + + chainlet_data: List[custom_types.ChainletDataAtomic] = [] + + for artifact in [entrypoint_artifact, *dependency_artifacts]: + truss_handle = truss_build.load(str(artifact.truss_dir)) + model_name = truss_handle.spec.config.model_name + + assert model_name and validation.is_valid_model_name(model_name) + + push_data = self._prepare_push( + truss_handle=truss_handle, + model_name=model_name, + # Models must be trusted to use the API KEY secret. + trusted=True, + publish=publish, + origin=custom_types.ModelOrigin.CHAINS, + ) + oracle_data = custom_types.OracleData( + model_name=push_data.model_name, + s3_key=push_data.s3_key, + encoded_config_str=push_data.encoded_config_str, + is_draft=push_data.is_draft, + model_id=push_data.model_id, + is_trusted=push_data.is_trusted, + version_name=push_data.version_name, + ) + chainlet_data.append( + custom_types.ChainletDataAtomic( + name=artifact.display_name, + oracle=oracle_data, + ) + ) + logging.info( + f"Pushing Chainlet '{model_name}' as a Truss model on Baseten (publish={publish})." + ) + + chain_deployment_handle = create_chain_atomic( + api=self._api, + chain_name=chain_name, + entrypoint=chainlet_data[0], + dependencies=chainlet_data[1:], + is_draft=not publish, + environment=environment, + ) + + model_id = chain_deployment_handle.entrypoint_model_id + model_version_id = chain_deployment_handle.entrypoint_model_version_id + + entrypoint_service = BasetenService( + model_id=model_id, + model_version_id=model_version_id, + is_draft=not publish, + api_key=self._auth_service.authenticate().value, + service_url=f"{self._remote_url}/model_versions/{model_version_id}", + truss_handle=truss_build.load(str(entrypoint_artifact.truss_dir)), + api=self._api, + ) + + return ( + chain_deployment_handle, + entrypoint_service, + ) + @staticmethod def _get_matching_version(model_versions: List[dict], published: bool) -> dict: if not published: diff --git a/truss/tests/remote/baseten/test_api.py b/truss/tests/remote/baseten/test_api.py index 7993f2ef9..b1ebbbf4b 100644 --- a/truss/tests/remote/baseten/test_api.py +++ b/truss/tests/remote/baseten/test_api.py @@ -4,7 +4,7 @@ import requests from requests import Response from truss.remote.baseten.api import BasetenApi -from truss.remote.baseten.custom_types import ChainletData +from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData from truss.remote.baseten.error import ApiError @@ -69,9 +69,11 @@ def mock_deploy_chain_deployment_response(): response.json = mock.Mock( return_value={ "data": { - "deploy_chain_deployment": { + "deploy_chain_atomic": { "chain_id": "12345", "chain_deployment_id": "54321", + "entrypoint_model_id": "67890", + "entrypoint_model_version_id": "09876", } } } @@ -208,34 +210,6 @@ def test_create_model_from_truss(mock_post, baseten_api): assert 'version_name: "deployment_name"' in gql_mutation -@mock.patch("requests.post", return_value=mock_create_model_response()) -def test_create_model_from_truss_forwards_chainlet_data(mock_post, baseten_api): - baseten_api.create_model_from_truss( - "model_name", - "s3key", - "config_str", - "semver_bump", - "client_version", - is_trusted=False, - deployment_name="deployment_name", - chain_environment="chainstaging", - chain_name="chainchain", - chainlet_name="chainlet-1", - ) - - gql_mutation = mock_post.call_args[1]["data"]["query"] - assert 'name: "model_name"' in gql_mutation - assert 's3_key: "s3key"' in gql_mutation - assert 'config: "config_str"' in gql_mutation - assert 'semver_bump: "semver_bump"' in gql_mutation - assert 'client_version: "client_version"' in gql_mutation - assert "is_trusted: false" in gql_mutation - assert 'version_name: "deployment_name"' in gql_mutation - assert 'chain_environment: "chainstaging"' in gql_mutation - assert 'chain_name: "chainchain"' in gql_mutation - assert 'chainlet_name: "chainlet-1"' in gql_mutation - - @mock.patch("requests.post", return_value=mock_create_model_response()) def test_create_model_from_truss_does_not_send_deployment_name_if_not_specified( mock_post, baseten_api @@ -306,38 +280,48 @@ def test_create_development_model_from_truss_with_allow_truss_download( @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) def test_deploy_chain_deployment(mock_post, baseten_api): - baseten_api.deploy_chain_deployment( - "chain_id", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], - "production", + baseten_api.deploy_chain_atomic( + environment="production", + chain_id="chain_id", + dependencies=[], + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), ) gql_mutation = mock_post.call_args[1]["data"]["query"] + + assert 'environment: "production"' in gql_mutation assert 'chain_id: "chain_id"' in gql_mutation - assert "chainlets:" in gql_mutation - assert 'environment_name: "production"' in gql_mutation + assert "dependencies:" in gql_mutation + assert "entrypoint:" in gql_mutation @mock.patch("requests.post", return_value=mock_deploy_chain_deployment_response()) def test_deploy_chain_deployment_no_environment(mock_post, baseten_api): - baseten_api.deploy_chain_deployment( - "chain_id", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], + baseten_api.deploy_chain_atomic( + chain_id="chain_id", + dependencies=[], + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), ) gql_mutation = mock_post.call_args[1]["data"]["query"] + assert 'chain_id: "chain_id"' in gql_mutation - assert "chainlets:" in gql_mutation - assert "environment_name" not in gql_mutation + assert "environment" not in gql_mutation + assert "dependencies:" in gql_mutation + assert "entrypoint:" in gql_mutation diff --git a/truss/tests/remote/baseten/test_remote.py b/truss/tests/remote/baseten/test_remote.py index 6d13d47b0..281dcbac4 100644 --- a/truss/tests/remote/baseten/test_remote.py +++ b/truss/tests/remote/baseten/test_remote.py @@ -2,8 +2,14 @@ import pytest import requests_mock -from truss.remote.baseten.core import ModelId, ModelName, ModelVersionId -from truss.remote.baseten.custom_types import ChainletData +import truss +from truss.remote.baseten.core import ( + ModelId, + ModelName, + ModelVersionId, + create_chain_atomic, +) +from truss.remote.baseten.custom_types import ChainletDataAtomic, OracleData from truss.remote.baseten.error import RemoteError from truss.remote.baseten.remote import BasetenRemote from truss.truss_handle.truss_handle import TrussHandle @@ -12,13 +18,15 @@ _TEST_REMOTE_GRAPHQL_PATH = "http://test_remote.com/graphql/" -def match_graphql_request(request, expected_query): +def request_matches_expected_query(request, expected_query): unescaped_content = parse.unquote_plus(request.text) # Remove 'query=' prefix and any leading/trailing whitespace - graphql_query = unescaped_content.replace("query=", "").strip() + actual_query = unescaped_content.replace("query=", "").strip() - assert graphql_query == expected_query + return tuple( + line.strip() for line in actual_query.split("\n") if line.strip() + ) == tuple(line.strip() for line in expected_query.split("\n") if line.strip()) def test_get_service_by_version_id(): @@ -317,9 +325,11 @@ def test_create_chain_with_no_publish(): { "json": { "data": { - "deploy_draft_chain": { + "deploy_chain_atomic": { "chain_id": "new-chain-id", "chain_deployment_id": "new-chain-deployment-id", + "entrypoint_model_id": "new-entrypoint-model-id", + "entrypoint_model_version_id": "new-entrypoint-model-version-id", } } } @@ -327,52 +337,72 @@ def test_create_chain_with_no_publish(): ], ) - deployment_handle = remote.create_chain( - "draft_chain", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], - publish=False, + deployment_handle = create_chain_atomic( + api=remote.api, + chain_name="draft_chain", + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), + dependencies=[], + is_draft=True, + environment=None, ) get_chains_graphql_request = m.request_history[0] create_chain_graphql_request = m.request_history[1] expected_get_chains_query = """ - { - chains { - id - name + { + chains { + id + name + } + } + """.strip() + + assert request_matches_expected_query( + get_chains_graphql_request, expected_get_chains_query + ) + + chainlets_string = """ + { + name: "chainlet-1", + oracle: { + model_name: "model-1", + s3_key: "s3-key-1", + encoded_config_str: "encoded-config-str-1", + is_trusted: true, + semver_bump: "MINOR" + } } - } """.strip() - match_graphql_request(get_chains_graphql_request, expected_get_chains_query) # Note that if publish=False and promote=True, we set publish to True and create # a non-draft deployment - expected_create_chain_mutation = """ - mutation { - deploy_draft_chain( - name: "draft_chain", - chainlets: [ - { - name: "chainlet-1", - oracle_version_id: "some-ov-id", - is_entrypoint: true - } - ] - ) { - chain_id - chain_deployment_id - } - } + expected_create_chain_mutation = f""" + mutation {{ + deploy_chain_atomic( + chain_name: "draft_chain" + is_draft: true + entrypoint: {chainlets_string} + dependencies: [] + client_version: "truss=={truss.version()}" + ) {{ + chain_id + chain_deployment_id + entrypoint_model_id + entrypoint_model_version_id + }} + }} """.strip() - match_graphql_request( + assert request_matches_expected_query( create_chain_graphql_request, expected_create_chain_mutation ) @@ -391,10 +421,11 @@ def test_create_chain_no_existing_chain(): { "json": { "data": { - "deploy_chain": { - "id": "new-chain-id", + "deploy_chain_atomic": { "chain_id": "new-chain-id", "chain_deployment_id": "new-chain-deployment-id", + "entrypoint_model_id": "new-entrypoint-model-id", + "entrypoint_model_version_id": "new-entrypoint-model-version-id", } } } @@ -402,52 +433,70 @@ def test_create_chain_no_existing_chain(): ], ) - deployment_handle = remote.create_chain( - "new_chain", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], - publish=True, + deployment_handle = create_chain_atomic( + api=remote.api, + chain_name="new_chain", + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), + dependencies=[], + is_draft=False, + environment=None, ) get_chains_graphql_request = m.request_history[0] create_chain_graphql_request = m.request_history[1] expected_get_chains_query = """ - { - chains { - id - name + { + chains { + id + name + } } - } """.strip() - match_graphql_request(get_chains_graphql_request, expected_get_chains_query) - - expected_create_chain_mutation = """ - mutation { - deploy_chain( - name: "new_chain", - chainlets: [ - { - name: "chainlet-1", - oracle_version_id: "some-ov-id", - is_entrypoint: true - } - ] - ) { - id - chain_id - chain_deployment_id - } - } + assert request_matches_expected_query( + get_chains_graphql_request, expected_get_chains_query + ) + + chainlets_string = """ + { + name: "chainlet-1", + oracle: { + model_name: "model-1", + s3_key: "s3-key-1", + encoded_config_str: "encoded-config-str-1", + is_trusted: true, + semver_bump: "MINOR" + } + } + """.strip() + + expected_create_chain_mutation = f""" + mutation {{ + deploy_chain_atomic( + chain_name: "new_chain" + is_draft: false + entrypoint: {chainlets_string} + dependencies: [] + client_version: "truss=={truss.version()}" + ) {{ + chain_id + chain_deployment_id + entrypoint_model_id + entrypoint_model_version_id + }} + }} """.strip() - match_graphql_request( + assert request_matches_expected_query( create_chain_graphql_request, expected_create_chain_mutation ) @@ -472,10 +521,11 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false() { "json": { "data": { - "deploy_chain_deployment": { - "id": "new-chain-id", + "deploy_chain_atomic": { "chain_id": "new-chain-id", "chain_deployment_id": "new-chain-deployment-id", + "entrypoint_model_id": "new-entrypoint-model-id", + "entrypoint_model_version_id": "new-entrypoint-model-version-id", } } } @@ -483,16 +533,20 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false() ], ) - deployment_handle = remote.create_chain( - "old_chain", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], - publish=False, + deployment_handle = create_chain_atomic( + api=remote.api, + chain_name="old_chain", + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), + dependencies=[], + is_draft=True, environment="production", ) @@ -500,38 +554,52 @@ def test_create_chain_with_existing_chain_promote_to_environment_publish_false() create_chain_graphql_request = m.request_history[1] expected_get_chains_query = """ - { - chains { - id - name + { + chains { + id + name + } } - } """.strip() - match_graphql_request(get_chains_graphql_request, expected_get_chains_query) + assert request_matches_expected_query( + get_chains_graphql_request, expected_get_chains_query + ) + # Note that if publish=False and environment!=None, we set publish to True and create # a non-draft deployment chainlets_string = """ - { - name: "chainlet-1", - oracle_version_id: "some-ov-id", - is_entrypoint: true - } - """ + { + name: "chainlet-1", + oracle: { + model_name: "model-1", + s3_key: "s3-key-1", + encoded_config_str: "encoded-config-str-1", + is_trusted: true, + semver_bump: "MINOR" + } + } + """.strip() + expected_create_chain_mutation = f""" - mutation {{ - deploy_chain_deployment( - chain_id: "old-chain-id", - chainlets: [{chainlets_string}], - environment_name: "production" - ) {{ - chain_id - chain_deployment_id + mutation {{ + deploy_chain_atomic( + chain_id: "old-chain-id" + environment: "production" + is_draft: false + entrypoint: {chainlets_string} + dependencies: [] + client_version: "truss=={truss.version()}" + ) {{ + chain_id + chain_deployment_id + entrypoint_model_id + entrypoint_model_version_id + }} }} - }} """.strip() - match_graphql_request( + assert request_matches_expected_query( create_chain_graphql_request, expected_create_chain_mutation ) @@ -556,10 +624,11 @@ def test_create_chain_existing_chain_publish_true_no_promotion(): { "json": { "data": { - "deploy_chain_deployment": { - "id": "new-chain-id", + "deploy_chain_atomic": { "chain_id": "new-chain-id", "chain_deployment_id": "new-chain-deployment-id", + "entrypoint_model_id": "new-entrypoint-model-id", + "entrypoint_model_version_id": "new-entrypoint-model-version-id", } } } @@ -567,53 +636,70 @@ def test_create_chain_existing_chain_publish_true_no_promotion(): ], ) - deployment_handle = remote.create_chain( - "old_chain", - [ - ChainletData( - name="chainlet-1", - oracle_version_id="some-ov-id", - is_entrypoint=True, - ) - ], - publish=True, + deployment_handle = create_chain_atomic( + api=remote.api, + chain_name="old_chain", + entrypoint=ChainletDataAtomic( + name="chainlet-1", + oracle=OracleData( + model_name="model-1", + s3_key="s3-key-1", + encoded_config_str="encoded-config-str-1", + is_trusted=True, + ), + ), + dependencies=[], + is_draft=False, + environment=None, ) get_chains_graphql_request = m.request_history[0] create_chain_graphql_request = m.request_history[1] expected_get_chains_query = """ - { - chains { - id - name + { + chains { + id + name + } } - } """.strip() - match_graphql_request(get_chains_graphql_request, expected_get_chains_query) + assert request_matches_expected_query( + get_chains_graphql_request, expected_get_chains_query + ) + chainlets_string = """ - { - name: "chainlet-1", - oracle_version_id: "some-ov-id", - is_entrypoint: true - } - """ - environment = None + { + name: "chainlet-1", + oracle: { + model_name: "model-1", + s3_key: "s3-key-1", + encoded_config_str: "encoded-config-str-1", + is_trusted: true, + semver_bump: "MINOR" + } + } + """.strip() + expected_create_chain_mutation = f""" - mutation {{ - deploy_chain_deployment( - chain_id: "old-chain-id", - chainlets: [{chainlets_string}], - {f'environment_name: "{environment}"' if environment else ""} - ) {{ - chain_id - chain_deployment_id + mutation {{ + deploy_chain_atomic( + chain_id: "old-chain-id" + is_draft: false + entrypoint: {chainlets_string} + dependencies: [] + client_version: "truss=={truss.version()}" + ) {{ + chain_id + chain_deployment_id + entrypoint_model_id + entrypoint_model_version_id + }} }} - }} """.strip() - match_graphql_request( + assert request_matches_expected_query( create_chain_graphql_request, expected_create_chain_mutation )