Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new health check configuration fields #1290

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
Runtime,
)
from truss_chains.public_api import (
ChainletBase,
Expand All @@ -57,6 +58,7 @@
"RPCOptions",
"RemoteConfig",
"RemoteErrorDetail",
"Runtime",
"DeployedServiceDescriptor",
"StubBase",
"depends",
Expand Down
5 changes: 5 additions & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ class ChainletOptions(SafeModelNonSerializable):
env_variables: Mapping[str, str] = {}


# Runtime-related settings carried on the chains config. `_make_truss_config`
# copies `health_checks` onto the generated truss config's
# `runtime.health_checks`; when not set, the `truss_config.HealthChecks`
# defaults are used.
class Runtime(SafeModelNonSerializable):
    health_checks: truss_config.HealthChecks = truss_config.HealthChecks()


class ChainletMetadata(SafeModelNonSerializable):
is_entrypoint: bool = False
chain_name: Optional[str] = None
Expand Down Expand Up @@ -393,6 +397,7 @@ class MyChainlet(chains.ChainletBase):
assets: Assets = Assets()
name: Optional[str] = None
options: ChainletOptions = ChainletOptions()
runtime: Runtime = Runtime()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chains doesn't map 1:1 to the structure of truss config. I wonder if it would make sense to add this field to ChainletOptions instead of creating a new class.

If you want to proceed with a new class, then there should be some differentiator, explaining which future fields go into Runtime and which ones go into ChainletOptions.


def get_compute_spec(self) -> ComputeSpec:
    """Return the spec produced by this config's ``compute`` object."""
    compute = self.compute
    return compute.get_spec()
Expand Down
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,9 @@ def _make_truss_config(
config.resources.memory = str(compute.memory)
config.resources.accelerator = compute.accelerator
config.resources.use_gpu = bool(compute.accelerator.count)
# Runtime
config.runtime.predict_concurrency = compute.predict_concurrency
config.runtime.health_checks = chains_config.runtime.health_checks
# Image.
_inplace_fill_base_image(chains_config.docker_image, config)
pip_requirements = _make_requirements(chains_config.docker_image)
Expand Down
37 changes: 36 additions & 1 deletion truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
DEFAULT_PREDICT_CONCURRENCY = 1
DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60
DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with tracing.py.

# Default value for both `HealthChecks` failure-threshold fields.
MAX_FAILURE_THRESHOLD_SECONDS = 1800
# NOTE(review): not referenced in the code visible here; presumably the lower
# bound for the same thresholds, enforced elsewhere. Confirm before relying on it.
MIN_FAILURE_THRESHOLD_SECONDS = 30
DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
DEFAULT_USE_GPU = False
Expand Down Expand Up @@ -154,12 +155,39 @@ def to_list(self, verbose=False) -> List[Dict[str, str]]:
return [model.to_dict(verbose=verbose) for model in self.models]


@dataclass
class HealthChecks:
    """Health-check timing settings nested under ``runtime`` in the truss config.

    All three fields are integer second counts. Both failure thresholds default
    to ``MAX_FAILURE_THRESHOLD_SECONDS``; the restart check delay defaults to 0.
    """

    restart_check_delay_seconds: int = 0
    restart_failure_threshold_seconds: int = MAX_FAILURE_THRESHOLD_SECONDS
    stop_traffic_failure_threshold_seconds: int = MAX_FAILURE_THRESHOLD_SECONDS

    @staticmethod
    def from_dict(d) -> "HealthChecks":
        """Build a ``HealthChecks`` from a parsed config mapping.

        Args:
            d: Mapping of field names to values. Missing keys fall back to the
                field defaults. ``None`` is accepted and treated like an empty
                mapping.

        Returns:
            The populated ``HealthChecks`` instance.
        """
        # A `health_checks:` key that is present but empty parses to None, and
        # `dict.get(key, {})` only substitutes the default for *missing* keys,
        # so normalize here instead of failing on `None.get`.
        values = d or {}
        return HealthChecks(
            restart_check_delay_seconds=values.get("restart_check_delay_seconds", 0),
            restart_failure_threshold_seconds=values.get(
                "restart_failure_threshold_seconds", MAX_FAILURE_THRESHOLD_SECONDS
            ),
            stop_traffic_failure_threshold_seconds=values.get(
                "stop_traffic_failure_threshold_seconds", MAX_FAILURE_THRESHOLD_SECONDS
            ),
        )

    def to_dict(self) -> Dict[str, int]:
        """Return the three settings as a plain dict keyed by field name."""
        return {
            "restart_check_delay_seconds": self.restart_check_delay_seconds,
            "restart_failure_threshold_seconds": self.restart_failure_threshold_seconds,
            "stop_traffic_failure_threshold_seconds": self.stop_traffic_failure_threshold_seconds,
        }


@dataclass
class Runtime:
predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY
streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
enable_tracing_data: bool = DEFAULT_ENABLE_TRACING_DATA
enable_debug_logs: bool = False
health_checks: HealthChecks = field(default_factory=HealthChecks)

@staticmethod
def from_dict(d):
Expand All @@ -176,18 +204,21 @@ def from_dict(d):
"streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
)
enable_tracing_data = d.get("enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA)
health_checks = HealthChecks.from_dict(d.get("health_checks", {}))

return Runtime(
predict_concurrency=predict_concurrency,
streaming_read_timeout=streaming_read_timeout,
enable_tracing_data=enable_tracing_data,
health_checks=health_checks,
)

def to_dict(self):
    """Serialize these runtime settings to a plain dict.

    Scalar settings are copied as-is; the nested health checks are serialized
    through their own ``to_dict``.
    """
    scalar_names = (
        "predict_concurrency",
        "streaming_read_timeout",
        "enable_tracing_data",
    )
    serialized = {name: getattr(self, name) for name in scalar_names}
    serialized["health_checks"] = self.health_checks.to_dict()
    return serialized


Expand Down Expand Up @@ -819,6 +850,10 @@ def obj_to_dict(obj, verbose: bool = False):
d["docker_auth"] = transform_optional(
field_curr_value, lambda data: data.to_dict()
)
elif isinstance(field_curr_value, HealthChecks):
d[field_name] = transform_optional(
field_curr_value, lambda data: data.to_dict()
)
else:
d[field_name] = field_curr_value

Expand Down
60 changes: 60 additions & 0 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,66 @@ def predict(self, model_input):
)


@pytest.mark.integration
def test_health_check_configuration():
    """Check that ``runtime.health_checks`` config values reach the truss spec.

    Three configs are exercised: a partial override (the unset threshold must
    keep its default), a full override, and an empty config (all defaults).
    """
    # Written at column 0 so the source is valid whether or not the helper
    # dedents it.
    model = """
class Model:
    def predict(self, model_input):
        return model_input
"""

    # YAML nesting is significant: the fields only take effect when they sit
    # under `runtime` -> `health_checks`.
    partial_config = """runtime:
  health_checks:
    restart_check_delay_seconds: 100
    restart_failure_threshold_seconds: 1700
"""
    full_config = """runtime:
  health_checks:
    restart_check_delay_seconds: 1200
    restart_failure_threshold_seconds: 90
    stop_traffic_failure_threshold_seconds: 50
"""

    # (config, restart check delay, restart threshold, stop-traffic threshold)
    cases = [
        (partial_config, 100, 1700, 1800),
        (full_config, 1200, 90, 50),
        ("", 0, 1800, 1800),
    ]

    for config, delay, restart_threshold, stop_traffic_threshold in cases:
        with ensure_kill_all(), _temp_truss(model, config) as tr:
            tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

            health_checks = tr.spec.config.runtime.health_checks
            assert health_checks.restart_check_delay_seconds == delay
            assert health_checks.restart_failure_threshold_seconds == restart_threshold
            assert (
                health_checks.stop_traffic_failure_threshold_seconds
                == stop_traffic_threshold
            )


def _patch_termination_timeout(container: Container, seconds: int, truss_container_fs):
app_path = truss_container_fs / "app"
sys.path.append(str(app_path))
Expand Down