diff --git a/truss/templates/control/control/application.py b/truss/templates/control/control/application.py index dac4804d2..4b121538e 100644 --- a/truss/templates/control/control/application.py +++ b/truss/templates/control/control/application.py @@ -1,5 +1,4 @@ import asyncio -import contextlib import logging import re from pathlib import Path @@ -46,20 +45,6 @@ async def handle_model_load_failed(_, error): return JSONResponse({"error": str(error)}, 503) -@contextlib.asynccontextmanager -async def lifespan_context(app: FastAPI): - # Before start. - yield # Run. - # Shutdown. - # FastApi handles the term signal to start the shutdown flow. Here we - # make sure that the inference server is stopeed when control server - # shuts down. Inference server has logic to wait until all requests are - # finished before exiting. By waiting on that, we inherit the same - # behavior for control server. - app.state.logger.info("Term signal received, shutting down.") - app.state.inference_server_process_controller.terminate_with_wait() - - def create_app(base_config: Dict): app_state = State() setup_logging() @@ -114,11 +99,20 @@ async def start_background_inference_startup(): ModelLoadFailed: handle_model_load_failed, Exception: generic_error_handler, }, - lifespan=lifespan_context, ) app.state = app_state app.include_router(control_app) + @app.on_event("shutdown") + def on_shutdown(): + # FastApi handles the term signal to start the shutdown flow. Here we + # make sure that the inference server is stopped when control server + # shuts down. Inference server has logic to wait until all requests are + # finished before exiting. By waiting on that, we inherit the same + # behavior for control server.
+ app.state.logger.info("Term signal received, shutting down.") + app.state.inference_server_process_controller.terminate_with_wait() + return app diff --git a/truss/templates/control/control/endpoints.py b/truss/templates/control/control/endpoints.py index 200bd97cd..98a6c57ef 100644 --- a/truss/templates/control/control/endpoints.py +++ b/truss/templates/control/control/endpoints.py @@ -57,7 +57,6 @@ async def proxy(request: Request): ), stop=stop_after_attempt(INFERENCE_SERVER_START_WAIT_SECS), wait=wait_fixed(1), - reraise=True, ): with attempt: try: diff --git a/truss/templates/control/control/helpers/inference_server_starter.py b/truss/templates/control/control/helpers/inference_server_starter.py index c8f1d23e9..a9d5157c8 100644 --- a/truss/templates/control/control/helpers/inference_server_starter.py +++ b/truss/templates/control/control/helpers/inference_server_starter.py @@ -41,7 +41,6 @@ def inference_server_startup_flow( for attempt in Retrying( stop=stop_after_attempt(15), wait=wait_exponential(multiplier=2, min=1, max=4), - reraise=True, ): with attempt: try: diff --git a/truss/templates/shared/util.py b/truss/templates/shared/util.py index 2a35ee5e4..f4d7f45c0 100644 --- a/truss/templates/shared/util.py +++ b/truss/templates/shared/util.py @@ -3,7 +3,7 @@ import shutil import sys from pathlib import Path -from typing import List, TypeVar +from typing import List import psutil import requests @@ -80,11 +80,6 @@ def kill_child_processes(parent_pid: int): process.kill() -X = TypeVar("X") -Y = TypeVar("Y") -Z = TypeVar("Z") - - def download_from_url_using_requests(URL: str, download_to: Path): # Streaming download to keep memory usage low resp = requests.get( diff --git a/truss/test_data/test_streaming_async_generator_truss/model/model.py b/truss/test_data/test_streaming_async_generator_truss/model/model.py index d5bdecf30..d120d2c87 100644 --- a/truss/test_data/test_streaming_async_generator_truss/model/model.py +++ 
b/truss/test_data/test_streaming_async_generator_truss/model/model.py @@ -3,5 +3,5 @@ class Model: async def predict(self, model_input: Any) -> Dict[str, List]: - for i in range(100): + for i in range(5): yield str(i) diff --git a/truss/util/data_structures.py b/truss/util/data_structures.py index 4f3eb2b24..0834dbfe6 100644 --- a/truss/util/data_structures.py +++ b/truss/util/data_structures.py @@ -2,7 +2,6 @@ X = TypeVar("X") Y = TypeVar("Y") -Z = TypeVar("Z") def transform_optional(x: Optional[X], fn: Callable[[X], Optional[Y]]) -> Optional[Y]: