Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new health check configuration fields #1290

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
RemoteConfig,
RemoteErrorDetail,
RPCOptions,
Runtime,
)
from truss_chains.public_api import (
ChainletBase,
Expand All @@ -57,6 +58,7 @@
"RPCOptions",
"RemoteConfig",
"RemoteErrorDetail",
"Runtime",
"DeployedServiceDescriptor",
"StubBase",
"depends",
Expand Down
5 changes: 5 additions & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,10 @@ class ChainletOptions(SafeModelNonSerializable):
env_variables: Mapping[str, str] = {}


class Runtime(SafeModelNonSerializable):
    """Runtime-related settings exposed through the chains definitions.

    At present this only holds the health check configuration.
    """

    # Defaults to a `truss_config.HealthCheck` constructed without arguments.
    health_checks: truss_config.HealthCheck = truss_config.HealthCheck()


class ChainletMetadata(SafeModelNonSerializable):
is_entrypoint: bool = False
chain_name: Optional[str] = None
Expand Down Expand Up @@ -393,6 +397,7 @@ class MyChainlet(chains.ChainletBase):
assets: Assets = Assets()
name: Optional[str] = None
options: ChainletOptions = ChainletOptions()
runtime: Runtime = Runtime()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chains doesn't map 1:1 to the structure of truss config. I wonder if it would make sense to add this field to ChainletOptions instead of creating a new class.

If you want to proceed with a new class, then there should be some differentiator, explaining which future fields go into Runtime and which ones go into ChainletOptions.


def get_compute_spec(self) -> ComputeSpec:
    """Return the spec produced by this object's `compute` attribute."""
    compute = self.compute
    return compute.get_spec()
Expand Down
2 changes: 2 additions & 0 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,9 @@ def _make_truss_config(
config.resources.memory = str(compute.memory)
config.resources.accelerator = compute.accelerator
config.resources.use_gpu = bool(compute.accelerator.count)
# Runtime
config.runtime.predict_concurrency = compute.predict_concurrency
config.runtime.health_checks = chains_config.runtime.health_checks
# Image.
_inplace_fill_base_image(chains_config.docker_image, config)
pip_requirements = _make_requirements(chains_config.docker_image)
Expand Down
68 changes: 67 additions & 1 deletion truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
DEFAULT_PREDICT_CONCURRENCY = 1
DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT = 60
DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with tracing.py.

# Bounds (in seconds) used by `HealthCheck` when validating its threshold
# fields; the maximum also serves as the default for both failure thresholds.
MAX_FAILURE_THRESHOLD_SECONDS = 1800
MIN_FAILURE_THRESHOLD_SECONDS = 30
DEFAULT_CPU = "1"
DEFAULT_MEMORY = "2Gi"
DEFAULT_USE_GPU = False
Expand Down Expand Up @@ -154,12 +155,70 @@ def to_list(self, verbose=False) -> List[Dict[str, str]]:
return [model.to_dict(verbose=verbose) for model in self.models]


@dataclass
class HealthCheck:
spal1 marked this conversation as resolved.
Show resolved Hide resolved
restart_check_delay_seconds: int = 0
restart_failure_threshold_seconds: int = MAX_FAILURE_THRESHOLD_SECONDS
stop_traffic_failure_threshold_seconds: int = MAX_FAILURE_THRESHOLD_SECONDS

def __post_init__(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want to do these validations on the backend. That gives us more flexibility to not be stuck with these checks for old clients.

Can you look into how that might work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline--I'll use this as an opportunity to create a first version of a validate_truss_config GQL mutation that we can call from here: https://github.com/basetenlabs/truss/blob/main/truss/remote/baseten/remote.py#L124

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we write so much boilerplate code here when pydantic can do most of this — though that's more a comment on the pre-existing code.

I think we should refactor this whole file soon. It's also very error-prone that you have to add key literals to the `to_dict` methods each time you add a field.

# Validate every field against its own allowed range first, so the error
# names the first offending field, then validate the combined limit.
#
# Raises:
#     ValidationError: if any single value is out of range, or if the restart
#         delay plus the restart failure threshold exceeds the maximum.
if not (0 <= self.restart_check_delay_seconds <= MAX_FAILURE_THRESHOLD_SECONDS):
    raise ValidationError(
        f"restart_check_delay_seconds must be between 0 and {MAX_FAILURE_THRESHOLD_SECONDS} seconds."
    )
if not (
    MIN_FAILURE_THRESHOLD_SECONDS
    <= self.restart_failure_threshold_seconds
    <= MAX_FAILURE_THRESHOLD_SECONDS
):
    raise ValidationError(
        f"restart_failure_threshold_seconds must be between {MIN_FAILURE_THRESHOLD_SECONDS} and {MAX_FAILURE_THRESHOLD_SECONDS} seconds."
    )
if not (
    MIN_FAILURE_THRESHOLD_SECONDS
    <= self.stop_traffic_failure_threshold_seconds
    <= MAX_FAILURE_THRESHOLD_SECONDS
):
    raise ValidationError(
        f"stop_traffic_failure_threshold_seconds must be between {MIN_FAILURE_THRESHOLD_SECONDS} and {MAX_FAILURE_THRESHOLD_SECONDS} seconds."
    )

# The delay before restart checks begin and the restart failure threshold
# together must not exceed the maximum.
if (
    self.restart_check_delay_seconds + self.restart_failure_threshold_seconds
    > MAX_FAILURE_THRESHOLD_SECONDS
):
    # The message names the two fields that are actually summed above.
    raise ValidationError(
        "The sum of restart_check_delay_seconds and restart_failure_threshold_seconds "
        f"must not exceed {MAX_FAILURE_THRESHOLD_SECONDS} seconds."
    )

@staticmethod
def from_dict(d):
    """Build a `HealthCheck` from a mapping, using defaults for missing keys.

    Args:
        d: Mapping that may contain `restart_check_delay_seconds`,
            `restart_failure_threshold_seconds` and
            `stop_traffic_failure_threshold_seconds`. `None` is treated as
            an empty mapping.

    Returns:
        A new `HealthCheck` instance.

    Raises:
        ValidationError: propagated from `HealthCheck` construction when a
            value is outside its allowed range.
    """
    # A `health_checks:` key with no value in YAML is parsed as None; fall
    # back to the defaults instead of failing with AttributeError on `.get`.
    if d is None:
        d = {}
    return HealthCheck(
        restart_check_delay_seconds=d.get("restart_check_delay_seconds", 0),
        restart_failure_threshold_seconds=d.get(
            "restart_failure_threshold_seconds", MAX_FAILURE_THRESHOLD_SECONDS
        ),
        stop_traffic_failure_threshold_seconds=d.get(
            "stop_traffic_failure_threshold_seconds", MAX_FAILURE_THRESHOLD_SECONDS
        ),
    )

def to_dict(self):
    """Return the three health check settings as a plain dict keyed by field name."""
    serialized = {}
    serialized["restart_check_delay_seconds"] = self.restart_check_delay_seconds
    serialized["restart_failure_threshold_seconds"] = (
        self.restart_failure_threshold_seconds
    )
    serialized["stop_traffic_failure_threshold_seconds"] = (
        self.stop_traffic_failure_threshold_seconds
    )
    return serialized


@dataclass
class Runtime:
predict_concurrency: int = DEFAULT_PREDICT_CONCURRENCY
streaming_read_timeout: int = DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
enable_tracing_data: bool = DEFAULT_ENABLE_TRACING_DATA
enable_debug_logs: bool = False
health_checks: HealthCheck = field(default_factory=HealthCheck)

@staticmethod
def from_dict(d):
Expand All @@ -176,18 +235,21 @@ def from_dict(d):
"streaming_read_timeout", DEFAULT_STREAMING_RESPONSE_READ_TIMEOUT
)
enable_tracing_data = d.get("enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA)
health_checks = HealthCheck.from_dict(d.get("health_checks", {}))

return Runtime(
predict_concurrency=predict_concurrency,
streaming_read_timeout=streaming_read_timeout,
enable_tracing_data=enable_tracing_data,
health_checks=health_checks,
)

def to_dict(self):
    """Serialize the runtime settings into a plain dict.

    The nested health check settings are serialized through their own
    `to_dict` method.
    """
    # NOTE(review): `enable_debug_logs` is not part of the output — confirm
    # that this omission is intended.
    result = {
        "predict_concurrency": self.predict_concurrency,
        "streaming_read_timeout": self.streaming_read_timeout,
        "enable_tracing_data": self.enable_tracing_data,
    }
    result["health_checks"] = self.health_checks.to_dict()
    return result


Expand Down Expand Up @@ -819,6 +881,10 @@ def obj_to_dict(obj, verbose: bool = False):
d["docker_auth"] = transform_optional(
field_curr_value, lambda data: data.to_dict()
)
elif isinstance(field_curr_value, HealthCheck):
d[field_name] = transform_optional(
field_curr_value, lambda data: data.to_dict()
)
else:
d[field_name] = field_curr_value

Expand Down
60 changes: 60 additions & 0 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,66 @@ def predict(self, model_input):
)


# Integration test: starts a truss in docker for three configs and checks that
# the `runtime.health_checks` values exposed on `tr.spec.config` match the
# config, with defaults used for anything omitted.
@pytest.mark.integration
def test_health_check_configuration():
model = """
class Model:
def predict(self, model_input):
return model_input
"""

# Case 1: two of the three fields are set; the third is left out.
config = """runtime:
health_checks:
restart_check_delay_seconds: 100
restart_failure_threshold_seconds: 1700
"""

with ensure_kill_all(), _temp_truss(model, config) as tr:
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 100
assert (
tr.spec.config.runtime.health_checks.restart_failure_threshold_seconds
== 1700
)
# The omitted field is expected to take the value 1800.
assert (
tr.spec.config.runtime.health_checks.stop_traffic_failure_threshold_seconds
== 1800
)

# Case 2: all three fields are set explicitly.
config = """runtime:
health_checks:
restart_check_delay_seconds: 1200
restart_failure_threshold_seconds: 90
stop_traffic_failure_threshold_seconds: 50
"""

with ensure_kill_all(), _temp_truss(model, config) as tr:
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 1200
assert (
tr.spec.config.runtime.health_checks.restart_failure_threshold_seconds == 90
)
assert (
tr.spec.config.runtime.health_checks.stop_traffic_failure_threshold_seconds
== 50
)

# Case 3: empty config; every health check field is expected to be its default.
with ensure_kill_all(), _temp_truss(model, "") as tr:
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

assert tr.spec.config.runtime.health_checks.restart_check_delay_seconds == 0
assert (
tr.spec.config.runtime.health_checks.restart_failure_threshold_seconds
== 1800
)
assert (
tr.spec.config.runtime.health_checks.stop_traffic_failure_threshold_seconds
== 1800
)


def _patch_termination_timeout(container: Container, seconds: int, truss_container_fs):
app_path = truss_container_fs / "app"
sys.path.append(str(app_path))
Expand Down