Skip to content

Commit

Permalink
Merge pull request #1250 from basetenlabs/bump-version-0.9.51
Browse files Browse the repository at this point in the history
Release 0.9.51
  • Loading branch information
tyranitar authored Nov 18, 2024
2 parents 51afc15 + bae1882 commit 8aebf26
Show file tree
Hide file tree
Showing 19 changed files with 789 additions and 551 deletions.
2 changes: 1 addition & 1 deletion docs/chains/doc_gen/generate_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
"truss_chains.remote.ChainService",
"truss_chains.make_abs_path_here",
"truss_chains.run_local",
"truss_chains.ServiceDescriptor",
"truss_chains.DeployedServiceDescriptor",
"truss_chains.StubBase",
"truss_chains.RemoteErrorDetail",
# "truss_chains.ChainsRuntimeError",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.50"
version = "0.9.51"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
40 changes: 12 additions & 28 deletions truss-chains/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import copy
import json

import pytest

from truss_chains import definitions
from truss_chains.utils import override_chainlet_to_service_metadata
from truss_chains.utils import populate_chainlet_service_predict_urls

DYNAMIC_CHAINLET_CONFIG_VALUE = {
"HelloWorld": {
Expand All @@ -22,23 +21,22 @@ def dynamic_config_mount_dir(tmp_path, monkeypatch: pytest.MonkeyPatch):
yield


def test_override_chainlet_to_service_metadata(tmp_path, dynamic_config_mount_dir):
def test_populate_chainlet_service_predict_urls(tmp_path, dynamic_config_mount_dir):
with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f:
f.write(json.dumps(DYNAMIC_CHAINLET_CONFIG_VALUE))

chainlet_to_service = {
"HelloWorld": definitions.ServiceDescriptor(
name="HelloWorld",
predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict",
options=definitions.RPCOptions(),
)
}
original_chainlet_to_service = copy.deepcopy(chainlet_to_service)
override_chainlet_to_service_metadata(chainlet_to_service)
new_chainlet_to_service = populate_chainlet_service_predict_urls(
chainlet_to_service
)

assert chainlet_to_service != original_chainlet_to_service
assert (
chainlet_to_service["HelloWorld"].predict_url
new_chainlet_to_service["HelloWorld"].predict_url
== DYNAMIC_CHAINLET_CONFIG_VALUE["HelloWorld"]["predict_url"]
)

Expand All @@ -47,34 +45,20 @@ def test_override_chainlet_to_service_metadata(tmp_path, dynamic_config_mount_di
"config",
[DYNAMIC_CHAINLET_CONFIG_VALUE, {}, ""],
)
def test_no_override_chainlet_to_service_metadata(
def test_no_populate_chainlet_service_predict_urls(
config, tmp_path, dynamic_config_mount_dir
):
with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f:
f.write(json.dumps(config))

chainlet_to_service = {
"RandInt": definitions.ServiceDescriptor(
name="HelloWorld",
predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict",
options=definitions.RPCOptions(),
)
}
original_chainlet_to_service = copy.deepcopy(chainlet_to_service)
override_chainlet_to_service_metadata(chainlet_to_service)

assert chainlet_to_service == original_chainlet_to_service


def test_no_config_override_chainlet_to_service_metadata(dynamic_config_mount_dir):
chainlet_to_service = {
"HelloWorld": definitions.ServiceDescriptor(
name="HelloWorld",
predict_url="https://model-model_id.api.baseten.co/deployments/deployment_id/predict",
name="RandInt",
options=definitions.RPCOptions(),
)
}
original_chainlet_to_service = copy.deepcopy(chainlet_to_service)
override_chainlet_to_service_metadata(chainlet_to_service)

assert chainlet_to_service == original_chainlet_to_service
with pytest.raises(
definitions.MissingDependencyError, match="Chainlet 'RandInt' not found"
):
populate_chainlet_service_predict_urls(chainlet_to_service)
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@
ChainletOptions,
Compute,
CustomImage,
DeployedServiceDescriptor,
DeploymentContext,
DockerImage,
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
ServiceDescriptor,
)
from truss_chains.public_api import (
ChainletBase,
Expand All @@ -55,7 +55,7 @@
"RPCOptions",
"RemoteConfig",
"RemoteErrorDetail",
"ServiceDescriptor",
"DeployedServiceDescriptor",
"StubBase",
"depends",
"depends_context",
Expand Down
2 changes: 0 additions & 2 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,14 +623,12 @@ def gen_truss_chainlet(
chain_name: str,
chainlet_descriptor: definitions.ChainletAPIDescriptor,
model_name: str,
chainlet_display_name_to_url: Mapping[str, str],
) -> pathlib.Path:
# Filter needed services and customize options.
dep_services = {}
for dep in chainlet_descriptor.dependencies.values():
dep_services[dep.name] = definitions.ServiceDescriptor(
name=dep.name,
predict_url=chainlet_display_name_to_url[dep.display_name],
options=dep.options,
)

Expand Down
11 changes: 6 additions & 5 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class SafeModel(pydantic.BaseModel):
arbitrary_types_allowed=False,
strict=True,
validate_assignment=True,
extra="forbid",
)


Expand Down Expand Up @@ -397,10 +396,13 @@ class ServiceDescriptor(SafeModel):
specifically with ``StubBase``."""

name: str
predict_url: str
options: RPCOptions


class DeployedServiceDescriptor(ServiceDescriptor):
predict_url: str


class Environment(SafeModel):
"""The environment the chainlet is deployed in.
Expand All @@ -422,7 +424,6 @@ class DeploymentContext(SafeModelNonSerializable):
Args:
data_dir: The directory where the chainlet can store and access data,
e.g. for downloading model weights.
user_config: User-defined configuration for the chainlet.
chainlet_to_service: A mapping from chainlet names to service descriptors.
This is used create RPCs sessions to dependency chainlets. It contains only
the chainlet services that are dependencies of the current chainlet.
Expand All @@ -434,11 +435,11 @@ class DeploymentContext(SafeModelNonSerializable):
"""

data_dir: Optional[pathlib.Path] = None
chainlet_to_service: Mapping[str, ServiceDescriptor]
chainlet_to_service: Mapping[str, DeployedServiceDescriptor]
secrets: MappingNoIter[str, str]
environment: Optional[Environment] = None

def get_service_descriptor(self, chainlet_name: str) -> ServiceDescriptor:
def get_service_descriptor(self, chainlet_name: str) -> DeployedServiceDescriptor:
if chainlet_name not in self.chainlet_to_service:
raise MissingDependencyError(f"{chainlet_name}")
return self.chainlet_to_service[chainlet_name]
Expand Down
4 changes: 2 additions & 2 deletions truss-chains/truss_chains/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ def _create_modified_init_for_local(
],
secrets: Mapping[str, str],
data_dir: Optional[pathlib.Path],
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
):
"""Replaces the default argument values with local Chainlet instantiations.
Expand Down Expand Up @@ -1011,7 +1011,7 @@ def __init_local__(self: definitions.ABCChainlet, **kwargs) -> None:
def run_local(
secrets: Mapping[str, str],
data_dir: Optional[pathlib.Path],
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
chainlet_to_service: Mapping[str, definitions.DeployedServiceDescriptor],
) -> Any:
"""Context to run Chainlets with dependency injection from local instances."""
# TODO: support retries in local mode.
Expand Down
8 changes: 5 additions & 3 deletions truss-chains/truss_chains/model_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from truss.templates.shared import secrets_resolver

from truss_chains import definitions
from truss_chains.utils import override_chainlet_to_service_metadata
from truss_chains.utils import populate_chainlet_service_predict_urls


class TrussChainletModel:
Expand All @@ -28,10 +28,12 @@ def __init__(
deployment_environment: Optional[definitions.Environment] = (
definitions.Environment.model_validate(environment) if environment else None
)
override_chainlet_to_service_metadata(truss_metadata.chainlet_to_service)
chainlet_to_deployed_service = populate_chainlet_service_predict_urls(
truss_metadata.chainlet_to_service
)

self._context = definitions.DeploymentContext(
chainlet_to_service=truss_metadata.chainlet_to_service,
chainlet_to_service=chainlet_to_deployed_service,
secrets=secrets,
data_dir=data_dir,
environment=deployment_environment,
Expand Down
6 changes: 4 additions & 2 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def push(
def run_local(
secrets: Optional[Mapping[str, str]] = None,
data_dir: Optional[Union[pathlib.Path, str]] = None,
chainlet_to_service: Optional[Mapping[str, definitions.ServiceDescriptor]] = None,
chainlet_to_service: Optional[
Mapping[str, definitions.DeployedServiceDescriptor]
] = None,
) -> ContextManager[None]:
"""Context manager local debug execution of a chain.
Expand All @@ -188,7 +190,7 @@ class HelloWorld(chains.ChainletBase):
with chains.run_local(
secrets={"some_token": os.environ["SOME_TOKEN"]},
chainlet_to_service={
"SomeChainlet": chains.ServiceDescriptor(
"SomeChainlet": chains.DeployedServiceDescriptor(
name="SomeChainlet",
predict_url="https://...",
options=chains.RPCOptions(),
Expand Down
Loading

0 comments on commit 8aebf26

Please sign in to comment.