diff --git a/pyproject.toml b/pyproject.toml
index 5f343c225..79647e8ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.31"
+version = "0.9.30rc01+writer"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py
index 7db57a99c..8c141148e 100644
--- a/truss/templates/server/common/truss_server.py
+++ b/truss/templates/server/common/truss_server.py
@@ -158,6 +158,7 @@ async def predict(
         with tracing.section_as_event(span, "model-call"):
             response: Union[Dict, Generator] = await model(
                 body,
+                request.is_disconnected,
                 headers=utils.transform_keys(
                     request.headers, lambda key: key.lower()
                 ),
diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py
index b807f2350..f1739313a 100644
--- a/truss/templates/server/model_wrapper.py
+++ b/truss/templates/server/model_wrapper.py
@@ -14,6 +14,7 @@
 from typing import (
     Any,
     AsyncGenerator,
+    Awaitable,
     Callable,
     Coroutine,
     Dict,
@@ -81,6 +82,7 @@ def defer() -> Callable[[], None]:
 
 class ModelWrapper:
     _tracer: sdk_trace.Tracer
+    _predict_cancellable: bool
 
     class Status(Enum):
         NOT_READY = 0
@@ -102,6 +104,7 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
             )
         )
         self.truss_schema: TrussSchema = None
+        self._predict_cancellable = False
 
     def load(self) -> bool:
         if self.ready:
@@ -192,6 +195,7 @@ def try_load(self):
             raise RuntimeError("No module class file found")
 
         self.set_truss_schema()
+        self._set_predict_cancellable()
 
         if hasattr(self._model, "load"):
             retry(
@@ -217,6 +221,13 @@ def set_truss_schema(self):
 
         self.truss_schema = TrussSchema.from_signature(parameters, outputs_annotation)
 
+    def _set_predict_cancellable(self):
+        sig = inspect.signature(self._model.predict)
+        params = list(sig.parameters.values())
+        if len(params) < 2:
+            return False
+        self._predict_cancellable = params[1].name == "is_cancelled_fn"
+
     async def preprocess(
         self,
         payload: Any,
@@ -232,8 +243,7 @@ async def preprocess(
         )
 
     async def predict(
-        self,
-        payload: Any,
+        self, payload: Any, is_cancelled_fn: Callable[[], Awaitable[bool]]
     ) -> Any:
         # It's possible for the user's predict function to be a:
         # 1. Generator function (function that returns a generator)
@@ -243,16 +253,15 @@ async def predict(
         # 3. Coroutine -- in this case, await the predict function as it is async
         # 4. Normal function -- in this case, offload to a separate thread to prevent
         #    blocking the main event loop
+        args = (payload, is_cancelled_fn) if self._predict_cancellable else (payload,)
         if inspect.isasyncgenfunction(
             self._model.predict
         ) or inspect.isgeneratorfunction(self._model.predict):
-            return self._model.predict(payload)
-
+            return self._model.predict(*args)
         if inspect.iscoroutinefunction(self._model.predict):
-            return await _intercept_exceptions_async(self._model.predict)(payload)
-
+            return await _intercept_exceptions_async(self._model.predict)(*args)
         return await to_thread.run_sync(
-            _intercept_exceptions_sync(self._model.predict), payload
+            _intercept_exceptions_sync(self._model.predict), *args
         )
 
     async def postprocess(
@@ -285,6 +294,8 @@ async def write_response_to_queue(
     ):
         with tracing.section_as_event(span, "write_response_to_queue"):
             try:
+                # Special case for writer: the triton client checks for cancellations
+                # in each iteration.
                 async for chunk in generator:
                     # TODO: consider checking `request.is_disconnected()` for
                     # client-side cancellations and freeing resources.
@@ -368,7 +379,10 @@ async def _buffered_response_generator():
         return _buffered_response_generator()
 
     async def __call__(
-        self, body: Any, headers: Optional[Mapping[str, str]] = None
+        self,
+        body: Any,
+        is_cancelled_fn: Callable[[], Awaitable[bool]],
+        headers: Optional[Mapping[str, str]] = None,
     ) -> Union[Dict, Generator, AsyncGenerator, str]:
         """Method to call predictor or explainer with the given input.
 
@@ -416,12 +430,15 @@ async def __call__(
                 # exactly handle that case we would need to apply `detach_context`
                 # around each `next`-invocation that consumes the generator, which is
                 # prohibitive.
-                response = await self.predict(payload)
+                response = await self.predict(payload, is_cancelled_fn)
 
                 if inspect.isgenerator(response) or inspect.isasyncgen(response):
                     if headers and headers.get("accept") == "application/json":
                         # In the case of a streaming response, consume stream
                         # if the http accept header is set, and json is requested.
+                        # TODO: cancellation does not work for this case.
+                        # This is unexpected, because `is_cancelled_fn` should still be
+                        # called in this code branch.
                         return await self._gather_generator(response, span_predict)
                     else:
                         return await self._stream_with_background_task(
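
[Editor's note, not part of the diff: `_set_predict_cancellable` above opts a model into cancellation based solely on the name of `predict`'s second positional parameter. Below is a minimal sketch of a coroutine-style `predict` that would be detected, assuming that naming convention; the model body itself is hypothetical.]

import asyncio
from typing import Any, Awaitable, Callable


class Model:
    async def predict(
        self, model_input: Any, is_cancelled_fn: Callable[[], Awaitable[bool]]
    ) -> str:
        # The parameter must be named `is_cancelled_fn` for the wrapper to
        # pass it through; poll it between units of work and bail out early.
        chunks = []
        for step in range(3):
            await asyncio.sleep(0.5)  # Stand-in for real work.
            if await is_cancelled_fn():
                return "cancelled"
            chunks.append(str(step))
        return "".join(chunks)
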
diff --git a/truss/test_data/test_streaming_async_cancellable_generator_truss/config.yaml b/truss/test_data/test_streaming_async_cancellable_generator_truss/config.yaml
new file mode 100644
index 000000000..0b3e897ce
--- /dev/null
+++ b/truss/test_data/test_streaming_async_cancellable_generator_truss/config.yaml
@@ -0,0 +1,4 @@
+model_name: Test Streaming Async Generator
+python_version: py39
+environment_variables:
+  OTEL_TRACING_NDJSON_FILE: "/tmp/otel_traces.ndjson"
diff --git a/truss/test_data/test_streaming_async_cancellable_generator_truss/model/model.py b/truss/test_data/test_streaming_async_cancellable_generator_truss/model/model.py
new file mode 100644
index 000000000..694d0d66f
--- /dev/null
+++ b/truss/test_data/test_streaming_async_cancellable_generator_truss/model/model.py
@@ -0,0 +1,31 @@
+import asyncio
+from typing import Any, Awaitable, Callable, Dict, List
+
+
+class Model:
+    def __init__(self, **kwargs) -> None:
+        self._data_dir = kwargs["data_dir"]
+        self._config = kwargs["config"]
+        self._secrets = kwargs["secrets"]
+        self._model = None
+
+    def load(self):
+        # Load model here and assign to self._model.
+        pass
+
+    async def predict(
+        self, model_input: Any, is_cancelled_fn: Callable[[], Awaitable[bool]]
+    ) -> Dict[str, List]:
+        # Invoke model on model_input and calculate predictions here.
+        await asyncio.sleep(1)
+        if await is_cancelled_fn():
+            print("Cancelled (before gen).")
+            return
+
+        for i in range(5):
+            await asyncio.sleep(1.0)
+            print(i)
+            yield str(i)
+            if await is_cancelled_fn():
+                print("Cancelled (during gen).")
+                return
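
[Editor's note, not part of the diff: the fixture above exercises cancellation via client disconnect. Below is a minimal sketch of triggering the same path by hand against a locally running container; the port and route mirror the integration test that follows, and the assumption is that closing the streaming response drops the connection so that `request.is_disconnected` reports True on the server.]

import requests

FULL_URL = "http://localhost:8090/v1/models/model:predict"

# Start a streaming prediction, read a single chunk, then disconnect.
resp = requests.post(FULL_URL, json={}, stream=True, timeout=10)
for chunk in resp.iter_content(chunk_size=None):
    print("first chunk:", chunk)
    break
resp.close()  # Client disconnect; is_cancelled_fn() should now return True.
# The container logs should subsequently show "Cancelled (during gen)."
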
diff --git a/truss/tests/templates/server/test_model_wrapper.py b/truss/tests/templates/server/test_model_wrapper.py
index e8c1ad829..bd17727d1 100644
--- a/truss/tests/templates/server/test_model_wrapper.py
+++ b/truss/tests/templates/server/test_model_wrapper.py
@@ -1,3 +1,4 @@
+import asyncio
 import importlib
 import os
 import sys
@@ -143,7 +144,9 @@ async def mock_predict(return_value):
     ):
         model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
         model_wrapper.load()
-        resp = await model_wrapper.predict({})
+        resp = await model_wrapper.predict(
+            {}, is_cancelled_fn=lambda: asyncio.sleep(0, result=False)
+        )
         mock_extension.load.assert_called()
         mock_extension.model_args.assert_called()
         assert mock_predict_called
@@ -180,7 +183,9 @@ async def mock_predict(return_value):
     ):
         model_wrapper = model_wrapper_class(config, sdk_trace.NoOpTracer())
         model_wrapper.load()
-        resp = await model_wrapper.predict({})
+        resp = await model_wrapper.predict(
+            {}, is_cancelled_fn=lambda: asyncio.sleep(0, result=False)
+        )
         mock_extension.load.assert_called()
         mock_extension.model_override.assert_called()
         assert mock_predict_called
diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py
index c3283627d..4e2729f3f 100644
--- a/truss/tests/test_model_inference.py
+++ b/truss/tests/test_model_inference.py
@@ -261,6 +261,31 @@ def test_async_streaming():
         assert predict_non_stream_response.json() == "01234"
 
 
+@pytest.mark.integration
+def test_async_streaming_with_cancel():
+    with ensure_kill_all():
+        truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"
+
+        truss_dir = (
+            truss_root
+            / "test_data"
+            / "test_streaming_async_cancellable_generator_truss"
+        )
+
+        tr = TrussHandle(truss_dir)
+
+        container = tr.docker_run(
+            local_port=8090, detach=True, wait_for_server_ready=True
+        )
+        truss_server_addr = "http://localhost:8090"
+        full_url = f"{truss_server_addr}/v1/models/model:predict"
+
+        with pytest.raises(requests.ConnectionError):
+            requests.post(full_url, json={}, stream=False, timeout=1)
+        time.sleep(2)  # Wait a bit to get all logs.
+        assert "Cancelled (during gen)." in container.logs()
+
+
 @pytest.mark.integration
 def test_async_streaming_timeout():
     with ensure_kill_all():
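
[Editor's note, not part of the diff: given the `@pytest.mark.integration` marker, the new test can presumably be run selectively with something like `pytest truss/tests/test_model_inference.py -m integration -k test_async_streaming_with_cancel`; the exact invocation depends on the repo's pytest configuration.]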