[WIP] Propagate Client Disconnect through Truss (#1118)
* Wire up client disconnect - not working in test

* Works for streaming with cURL. Test doesn't work and JSON doesn't work.

* Streaming test does work now.

* Make passing of is_cancelled_fn dependent on function signature
marius-baseten authored Sep 5, 2024
1 parent 333a0e6 commit 01192bd
Showing 7 changed files with 95 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.31"
version = "0.9.30rc01+writer"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
1 change: 1 addition & 0 deletions truss/templates/server/common/truss_server.py
@@ -158,6 +158,7 @@ async def predict(
with tracing.section_as_event(span, "model-call"):
response: Union[Dict, Generator] = await model(
body,
request.is_disconnected,
headers=utils.transform_keys(
request.headers, lambda key: key.lower()
),
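The server side of the change is small: Starlette's `Request.is_disconnected` is an async bound method, so passing it through unchanged already satisfies the `Callable[[], Awaitable[bool]]` type the model wrapper expects. Below is a minimal sketch of that pattern outside of Truss; the endpoint and the `run_model` helper are illustrative, not part of this diff.

from fastapi import FastAPI, Request

app = FastAPI()

async def run_model(body, is_cancelled_fn):
    # Illustrative stand-in for the model wrapper: poll the probe
    # before (or between) expensive steps and bail out early.
    if await is_cancelled_fn():
        return None
    return body

@app.post("/v1/models/model:predict")
async def predict(request: Request):
    body = await request.json()
    # The bound method itself is handed down; no wrapper is needed.
    return await run_model(body, is_cancelled_fn=request.is_disconnected)
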
35 changes: 26 additions & 9 deletions truss/templates/server/model_wrapper.py
@@ -14,6 +14,7 @@
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Coroutine,
Dict,
@@ -81,6 +82,7 @@ def defer() -> Callable[[], None]:

class ModelWrapper:
_tracer: sdk_trace.Tracer
_predict_cancellable: bool

class Status(Enum):
NOT_READY = 0
@@ -102,6 +104,7 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
)
)
self.truss_schema: TrussSchema = None
self._predict_cancellable = False

def load(self) -> bool:
if self.ready:
@@ -192,6 +195,7 @@ def try_load(self):
raise RuntimeError("No module class file found")

self.set_truss_schema()
self._set_predict_cancellable()

if hasattr(self._model, "load"):
retry(
@@ -217,6 +221,13 @@ def set_truss_schema(self):

self.truss_schema = TrussSchema.from_signature(parameters, outputs_annotation)

def _set_predict_cancellable(self):
sig = inspect.signature(self._model.predict)
params = list(sig.parameters.values())
if len(params) < 2:
return
self._predict_cancellable = params[1].name == "is_cancelled_fn"

async def preprocess(
self,
payload: Any,
@@ -232,8 +243,7 @@ async def preprocess(
)

async def predict(
-self,
-payload: Any,
+self, payload: Any, is_cancelled_fn: Callable[[], Awaitable[bool]]
) -> Any:
# It's possible for the user's predict function to be a:
# 1. Generator function (function that returns a generator)
@@ -243,16 +253,15 @@
# 3. Coroutine -- in this case, await the predict function as it is async
# 4. Normal function -- in this case, offload to a separate thread to prevent
# blocking the main event loop
args = (payload, is_cancelled_fn) if self._predict_cancellable else (payload,)
if inspect.isasyncgenfunction(
self._model.predict
) or inspect.isgeneratorfunction(self._model.predict):
-return self._model.predict(payload)
+return self._model.predict(*args)

if inspect.iscoroutinefunction(self._model.predict):
-return await _intercept_exceptions_async(self._model.predict)(payload)
+return await _intercept_exceptions_async(self._model.predict)(*args)
return await to_thread.run_sync(
-_intercept_exceptions_sync(self._model.predict), payload
+_intercept_exceptions_sync(self._model.predict), *args
)

async def postprocess(
@@ -285,6 +294,8 @@ async def write_response_to_queue(
):
with tracing.section_as_event(span, "write_response_to_queue"):
try:
# Special case for writer: the Triton client checks for cancellations
# in each iteration.
async for chunk in generator:
# TODO: consider checking `request.is_disconnected()` for
# client-side cancellations and freeing resources.
@@ -368,7 +379,10 @@ async def _buffered_response_generator():
return _buffered_response_generator()

async def __call__(
-self, body: Any, headers: Optional[Mapping[str, str]] = None
+self,
+body: Any,
+is_cancelled_fn: Callable[[], Awaitable[bool]],
+headers: Optional[Mapping[str, str]] = None,
) -> Union[Dict, Generator, AsyncGenerator, str]:
"""Method to call predictor or explainer with the given input.
@@ -416,12 +430,15 @@ async def __call__(
# exactly handle that case we would need to apply `detach_context`
# around each `next`-invocation that consumes the generator, which is
# prohibitive.
-response = await self.predict(payload)
+response = await self.predict(payload, is_cancelled_fn)

if inspect.isgenerator(response) or inspect.isasyncgen(response):
if headers and headers.get("accept") == "application/json":
# In the case of a streaming response, consume the stream
# if the HTTP accept header requests JSON.
# TODO: cancellation does not work for this case.
# This is unexpected, because `is_cancelled_fn` should still be
# called in this code branch.
return await self._gather_generator(response, span_predict)
else:
return await self._stream_with_background_task(
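The opt-in mechanism above is purely signature-based: a model's `predict` receives the probe only if its second positional parameter is named `is_cancelled_fn`. A small self-contained sketch of that check (the helper and the two predict functions are made up for illustration):

import inspect

def takes_cancel_fn(fn) -> bool:
    # Mirrors _set_predict_cancellable: look only at the name of the
    # second positional parameter.
    params = list(inspect.signature(fn).parameters.values())
    return len(params) >= 2 and params[1].name == "is_cancelled_fn"

def predict_plain(model_input):
    return model_input

def predict_cancellable(model_input, is_cancelled_fn):
    return model_input

assert not takes_cancel_fn(predict_plain)
assert takes_cancel_fn(predict_cancellable)
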
4 changes: 4 additions & 0 deletions truss/test_data/test_streaming_async_cancellable_generator_truss/config.yaml
@@ -0,0 +1,4 @@
model_name: Test Streaming Async Generator
python_version: py39
environment_variables:
OTEL_TRACING_NDJSON_FILE: "/tmp/otel_traces.ndjson"
31 changes: 31 additions & 0 deletions truss/test_data/test_streaming_async_cancellable_generator_truss/model/model.py
@@ -0,0 +1,31 @@
import asyncio
from typing import Any, Awaitable, Callable, Dict, List


class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._secrets = kwargs["secrets"]
self._model = None

def load(self):
# Load model here and assign to self._model.
pass

async def predict(
self, model_input: Any, is_cancelled_fn: Callable[[], Awaitable[bool]]
) -> Dict[str, List]:
# Invoke model on model_input and calculate predictions here.
await asyncio.sleep(1)
if await is_cancelled_fn():
print("Cancelled (before gen).")
return

for i in range(5):
await asyncio.sleep(1.0)
print(i)
yield str(i)
if await is_cancelled_fn():
print("Cancelled (during gen).")
return
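To see the example model's cancellation path without running a server, a hypothetical driver can stand in for the client: flip a flag after two chunks and the generator stops with "Cancelled (during gen).". This driver is not part of the diff and assumes the `Model` class above is available in the same module.

import asyncio

async def drive():
    state = {"cancelled": False}

    async def is_cancelled_fn() -> bool:
        return state["cancelled"]

    model = Model(data_dir=None, config={}, secrets={})
    received = []
    async for chunk in model.predict({}, is_cancelled_fn):
        received.append(chunk)
        if len(received) == 2:
            state["cancelled"] = True  # simulate the client going away
    # The generator checks the probe after each yield and returns early.
    assert received == ["0", "1"]

asyncio.run(drive())
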
9 changes: 7 additions & 2 deletions truss/tests/templates/server/test_model_wrapper.py
@@ -1,3 +1,4 @@
import asyncio
import importlib
import os
import sys
@@ -143,7 +144,9 @@ async def mock_predict(return_value):
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
-resp = await model_wrapper.predict({})
+resp = await model_wrapper.predict(
+{}, is_cancelled_fn=lambda: asyncio.sleep(0, result=False)
+)
mock_extension.load.assert_called()
mock_extension.model_args.assert_called()
assert mock_predict_called
@@ -180,7 +183,9 @@ async def mock_predict(return_value):
):
model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
model_wrapper.load()
-resp = await model_wrapper.predict({})
+resp = await model_wrapper.predict(
+{}, is_cancelled_fn=lambda: asyncio.sleep(0, result=False)
+)
mock_extension.load.assert_called()
mock_extension.model_override.assert_called()
assert mock_predict_called
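The stub passed in these tests deserves a note: `asyncio.sleep(0, result=False)` returns a coroutine that resolves to `False`, so a plain lambda satisfies `Callable[[], Awaitable[bool]]` without defining an `async def`. A minimal demonstration of just that trick:

import asyncio

never_cancelled = lambda: asyncio.sleep(0, result=False)

async def main():
    # Awaiting the coroutine yields the `result` value, i.e. False.
    assert await never_cancelled() is False

asyncio.run(main())
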
25 changes: 25 additions & 0 deletions truss/tests/test_model_inference.py
@@ -261,6 +261,31 @@ def test_async_streaming():
assert predict_non_stream_response.json() == "01234"


@pytest.mark.integration
def test_async_streaming_with_cancel():
with ensure_kill_all():
truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"

truss_dir = (
truss_root
/ "test_data"
/ "test_streaming_async_cancellable_generator_truss"
)

tr = TrussHandle(truss_dir)

container = tr.docker_run(
local_port=8090, detach=True, wait_for_server_ready=True
)
truss_server_addr = "http://localhost:8090"
full_url = f"{truss_server_addr}/v1/models/model:predict"

with pytest.raises(requests.ConnectionError):
requests.post(full_url, json={}, stream=False, timeout=1)
time.sleep(2) # Wait a bit to get all logs.
assert "Cancelled (during gen)." in container.logs()


@pytest.mark.integration
def test_async_streaming_timeout():
with ensure_kill_all():
