From 47cecfc8bcbe6f3fe6540efa341f56b52aade5d0 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Wed, 28 Aug 2024 10:13:54 -0700 Subject: [PATCH] Fix tests. --- pyproject.toml | 2 +- .../image_builder/serving_image_builder.py | 6 +- truss/templates/server/common/tracing.py | 64 +++++++++++---- truss/templates/server/common/truss_server.py | 25 +++--- truss/templates/server/model_wrapper.py | 82 +++++++------------ truss/test_data/server.Dockerfile | 4 + .../model/model.py | 2 +- 7 files changed, 102 insertions(+), 83 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5f4039731..6fc7b9263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.30rc5" +version = "0.9.30rc667" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index e11a6d92d..4442ca8d4 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -459,11 +459,13 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str): # are detected and cause a build failure. If there are no # requirements provided, we just pass an empty string, # as there's no need to install anything. - # TODO: above reasoning leads to inconsistencies. Needs revisit. + # TODO: above reasoning leads to inconsistencies. To get consistent images + # tentatively add server requirements always. This whole point needs more + # thought and potentially a re-design. user_provided_python_requirements = ( base_server_requirements + spec.requirements_txt if spec.requirements - else "" + else base_server_requirements ) if spec.requirements_file is not None: copy_into_build_dir( diff --git a/truss/templates/server/common/tracing.py b/truss/templates/server/common/tracing.py index 4ef750993..32c984e12 100644 --- a/truss/templates/server/common/tracing.py +++ b/truss/templates/server/common/tracing.py @@ -1,20 +1,26 @@ import contextlib -import functools import json import logging import os import pathlib import time -from typing import Iterator, List, Sequence +from typing import Iterator, List, Optional, Sequence -import opentelemetry.exporter.otlp.proto.grpc.trace_exporter as oltp_exporter +import opentelemetry.exporter.otlp.proto.http.trace_exporter as oltp_exporter import opentelemetry.sdk.resources as resources import opentelemetry.sdk.trace as sdk_trace import opentelemetry.sdk.trace.export as trace_export from opentelemetry import context, trace +from shared import secrets_resolver logger = logging.getLogger(__name__) +ATTR_NAME_DURATION = "duration_sec" +OTEL_EXPORTER_OTLP_ENDPOINT = "OTEL_EXPORTER_OTLP_ENDPOINT" +OTEL_TRACING_NDJSON_FILE = "OTEL_TRACING_NDJSON_FILE" +HONEYCOMB_DATASET = "HONEYCOMB_DATASET" +HONEYCOMB_API_KEY = "HONEYCOMB_API_KEY" + class JSONFileExporter(trace_export.SpanExporter): """Writes spans to newline-delimited JSON file for debugging / testing.""" @@ -36,8 +42,10 @@ def shutdown(self) -> None: self._file.close() -@functools.lru_cache(maxsize=1) -def get_truss_tracer() -> trace.Tracer: +_truss_tracer: Optional[trace.Tracer] = None + + +def get_truss_tracer(secrets: secrets_resolver.SecretsResolver) -> trace.Tracer: """Creates a cached tracer (i.e. runtime-singleton) to be used for truss internal tracing. @@ -45,17 +53,39 @@ def get_truss_tracer() -> trace.Tracer: completely from potential user-defined tracing - see also `detach_context` below. """ + global _truss_tracer + if _truss_tracer: + return _truss_tracer + span_processors: List[sdk_trace.SpanProcessor] = [] - if otlp_endpoint := os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT"): + if otlp_endpoint := os.getenv(OTEL_EXPORTER_OTLP_ENDPOINT): + logger.info(f"Exporting trace data to {OTEL_EXPORTER_OTLP_ENDPOINT}.") otlp_exporter = oltp_exporter.OTLPSpanExporter(endpoint=otlp_endpoint) otlp_processor = sdk_trace.export.BatchSpanProcessor(otlp_exporter) span_processors.append(otlp_processor) - if tracing_log_file := os.getenv("OTEL_TRACING_NDJSON_FILE"): + if tracing_log_file := os.getenv(OTEL_TRACING_NDJSON_FILE): + logger.info("Exporting trace data to `tracing_log_file`.") json_file_exporter = JSONFileExporter(pathlib.Path(tracing_log_file)) file_processor = sdk_trace.export.SimpleSpanProcessor(json_file_exporter) span_processors.append(file_processor) + if honeycomb_dataset := os.getenv(HONEYCOMB_DATASET): + if HONEYCOMB_API_KEY in secrets: + honeycomb_api_key = secrets[HONEYCOMB_API_KEY] + logger.info("Exporting trace data to honeycomb.") + honeycomb_exporter = oltp_exporter.OTLPSpanExporter( + endpoint="https://api.honeycomb.io/v1/traces", + headers={ + "x-honeycomb-team": honeycomb_api_key, + "x-honeycomb-dataset": honeycomb_dataset, + }, + ) + honeycomb_processor = sdk_trace.export.BatchSpanProcessor( + honeycomb_exporter + ) + span_processors.append(honeycomb_processor) + if span_processors: logger.info("Instantiating truss tracer.") resource = resources.Resource.create({resources.SERVICE_NAME: "TrussServer"}) @@ -67,7 +97,8 @@ def get_truss_tracer() -> trace.Tracer: logger.info("Using no-op tracing.") tracer = sdk_trace.NoOpTracer() - return tracer + _truss_tracer = tracer + return _truss_tracer @contextlib.contextmanager @@ -82,14 +113,9 @@ def detach_context() -> Iterator[None]: be wrapped in this context for isolation. """ current_context = context.get_current() - # Set the current context to an invalid span context, effectively clearing it. - # This makes sure inside the context a new root is context is created. - transient_token = context.attach( - trace.set_span_in_context( - trace.INVALID_SPAN, - trace.INVALID_SPAN_CONTEXT, # type: ignore[arg-type] - ) - ) + # Create an invalid tracing context. This forces that tracing code inside this + # context manager creates a new root tracing context. + transient_token = context.attach(trace.set_span_in_context(trace.INVALID_SPAN)) try: yield finally: @@ -105,9 +131,11 @@ def section_as_event(span: sdk_trace.Span, section_name: str) -> Iterator[None]: Note that events are much cheaper to create than dedicated spans. """ t0 = time.time() - span.add_event(f"start-{section_name}") + span.add_event(f"start: {section_name}") try: yield finally: t1 = time.time() - span.add_event(f"done-{section_name}", attributes={"duration_sec": t1 - t0}) + span.add_event( + f"done: {section_name}", attributes={ATTR_NAME_DURATION: t1 - t0} + ) diff --git a/truss/templates/server/common/truss_server.py b/truss/templates/server/common/truss_server.py index 5ae611e10..fe502a7ce 100644 --- a/truss/templates/server/common/truss_server.py +++ b/truss/templates/server/common/truss_server.py @@ -23,6 +23,7 @@ from opentelemetry import propagate as otel_propagate from opentelemetry.sdk import trace as sdk_trace from shared.logging import setup_logging +from shared.secrets_resolver import SecretsResolver from shared.serialization import ( DeepNumpyEncoder, truss_msgpack_deserialize, @@ -173,16 +174,19 @@ async def predict( response_headers = {} if self.is_binary(request): - response_headers["Content-Type"] = "application/octet-stream" - return Response( - content=truss_msgpack_serialize(response), headers=response_headers - ) + with tracing.section_as_event(span, "binary-serialize"): + response_headers["Content-Type"] = "application/octet-stream" + return Response( + content=truss_msgpack_serialize(response), + headers=response_headers, + ) else: - response_headers["Content-Type"] = "application/json" - return Response( - content=json.dumps(response, cls=DeepNumpyEncoder), - headers=response_headers, - ) + with tracing.section_as_event(span, "json-serialize"): + response_headers["Content-Type"] = "application/json" + return Response( + content=json.dumps(response, cls=DeepNumpyEncoder), + headers=response_headers, + ) async def schema(self, model_name: str) -> Dict: model: ModelWrapper = self._safe_lookup_model(model_name) @@ -223,7 +227,8 @@ def __init__( config: Dict, setup_json_logger: bool = True, ): - tracer = tracing.get_truss_tracer() + secrets = SecretsResolver.get_secrets(config) + tracer = tracing.get_truss_tracer(secrets) self.http_port = http_port self._config = config self._model = ModelWrapper(self._config, tracer) diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index 2569b6894..cd309adf2 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -20,7 +20,6 @@ Mapping, NoReturn, Optional, - Set, TypeVar, Union, ) @@ -50,38 +49,26 @@ TRT_LLM_EXTENSION_NAME = "trt_llm" -def aprint(msg: str): - task_id = str(hash(id(asyncio.current_task())))[:3] - print(f"Task[ {task_id} ]: {msg}") - - @asynccontextmanager async def deferred_semaphore_and_span( - semaphore: Semaphore, span: sdk_trace.Span + semaphore: Semaphore, span: trace.Span ) -> AsyncGenerator[Callable[[], Callable[[], None]], None]: """ Context manager that allows deferring the release of a semaphore and the ending of a trace span. - Yields a function that, when called, releases the semaphore and ends the span. If - that function is not called, the resources are cleand up when exiting the context. + Yields a function that, when called, releases the semaphore and ends the span. + If that function is not called, the resources are cleand up when exiting. """ - val_before = semaphore.value - aprint("requesting semaphore") await semaphore.acquire() - val_after = semaphore.value - aprint(f"acquired semaphore. {val_before} -> {val_after}") trace.use_span(span, end_on_exit=False) deferred = False def release_and_end() -> None: - aprint("called release.") semaphore.release() span.end() - aprint("releases semaphore.") def defer() -> Callable[[], None]: - aprint("called defer.") nonlocal deferred deferred = True return release_and_end @@ -89,12 +76,8 @@ def defer() -> Callable[[], None]: try: yield defer finally: - aprint("ending context.") if not deferred: - aprint("ending context - release.") release_and_end() - else: - aprint("ending context - keep.") class ModelWrapper: @@ -119,7 +102,6 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer): "predict_concurrency", DEFAULT_PREDICT_CONCURRENCY ) ) - self._background_tasks: Set[asyncio.Task] = set() self.truss_schema: TrussSchema = None def load(self) -> bool: @@ -305,12 +287,13 @@ async def postprocess( ) async def write_response_to_queue( - self, queue: asyncio.Queue, generator: AsyncGenerator, span: sdk_trace.Span + self, queue: asyncio.Queue, generator: AsyncGenerator, span: trace.Span ): with tracing.section_as_event(span, "write_response_to_queue"): - aprint("start-write_response_to_queue") try: async for chunk in generator: + # TODO: consider checking `request.is_disconnected()` for + # client-side cancellations and freeing resources. await queue.put(ResponseChunk(chunk)) except Exception as e: self._logger.exception( @@ -318,12 +301,12 @@ async def write_response_to_queue( ) finally: await queue.put(None) - aprint("end-write_response_to_queue") - async def _gather_generator(self, response: Any, span: sdk_trace.Span) -> str: - # In the case of gathering, it might make more sense to apply the post-process + async def _gather_generator(self, response: Any, span: trace.Span) -> str: + # In the case of gathering, it might make more sense to apply the postprocess # to the gathered result, but that would be inconsistent with streaming. - # In general it might even be better to forbid postprocessing completely. + # In general, it might even be better to strictly forbid postprocessing + # for generators. if hasattr(self._model, "postprocess"): logging.warning( "Predict returned a streaming response, while a postprocess is defined." @@ -341,7 +324,7 @@ async def _gather_generator(self, response: Any, span: sdk_trace.Span) -> str: async def _stream_with_background_task( self, response: Any, - span: sdk_trace.Span, + span: trace.Span, release_and_end: Callable[[], None], ): # The streaming read timeout is the amount of time in between streamed chunk @@ -358,28 +341,23 @@ async def _stream_with_background_task( response_queue: asyncio.Queue = asyncio.Queue() # `write_response_to_queue` keeps running the background until completion. - task = asyncio.create_task( + gen_task = asyncio.create_task( self.write_response_to_queue(response_queue, async_generator, span) ) - # We add the task to the ModelWrapper instance to ensure it does - # not get garbage collected after the predict method completes, - # and continues running. - self._background_tasks.add(task) # Defer the release of the semaphore until the write_response_to_queue task. - task.add_done_callback(lambda _: release_and_end()) - task.add_done_callback(self._background_tasks.discard) + gen_task.add_done_callback(lambda _: release_and_end) # The gap between responses in a stream must be < streaming_read_timeout async def _response_generator(): - with tracing.section_as_event(span, "response_generator"): - aprint("start-response_generator") + # `span` is tied to the "producer" `gen_task` which might complete before + # "consume" part here finishes, therefore a dedicated span is required. + with self._tracer.start_as_current_span("response_generator"): while True: chunk = await asyncio.wait_for( response_queue.get(), timeout=streaming_read_timeout, ) if chunk is None: - aprint("done-response_generator") return yield chunk.value @@ -399,24 +377,28 @@ async def __call__( Generator: In case of streaming response String: in case of non-streamed generator (the string is the JSON result). """ - with self._tracer.start_as_current_span("predict-call-pre") as span: + with self._tracer.start_as_current_span("call-pre") as span_pre: if self.truss_schema is not None: try: - with tracing.section_as_event(span, "parse-pydantic"): + with tracing.section_as_event(span_pre, "parse-pydantic"): body = self.truss_schema.input_type(**body) except pydantic.ValidationError as e: self._logger.info("Request Validation Error") raise HTTPException( status_code=400, detail=f"Request Validation Error, {str(e)}" ) from e - with tracing.section_as_event(span, "preprocess"), tracing.detach_context(): + with tracing.section_as_event( + span_pre, "preprocess" + ), tracing.detach_context(): payload = await self.preprocess(body) - span = self._tracer.start_span("predict-call-predict") + span_predict = self._tracer.start_span("call-predict") async with deferred_semaphore_and_span( - self._predict_semaphore, span + self._predict_semaphore, span_predict ) as get_defer_fn: - with tracing.section_as_event(span, "predict"), tracing.detach_context(): + with tracing.section_as_event( + span_predict, "predict" + ), tracing.detach_context(): # To prevent span pollution, we need to make sure spans created by user # code don't inherit context from our spans (which happens even if # different tracer instances are used). @@ -430,29 +412,27 @@ async def __call__( # exactly handle that case we would need to apply `detach_context` # around each `next`-invocation that consumes the generator, which is # prohibitive. - aprint("start-predict") response = await self.predict(payload) - aprint("done-predict") if inspect.isgenerator(response) or inspect.isasyncgen(response): if headers and headers.get("accept") == "application/json": # In the case of a streaming response, consume stream # if the http accept header is set, and json is requested. - return await self._gather_generator(response, span) + return await self._gather_generator(response, span_predict) else: return await self._stream_with_background_task( - response, span, release_and_end=get_defer_fn() + response, span_predict, release_and_end=get_defer_fn() ) - with self._tracer.start_as_current_span("predict-call-post") as span: + with self._tracer.start_as_current_span("call-post") as span_post: with tracing.section_as_event( - span, "postprocess" + span_post, "postprocess" ), tracing.detach_context(): processed_response = await self.postprocess(response) if isinstance(processed_response, BaseModel): # If we return a pydantic object, convert it back to a dict - with tracing.section_as_event(span, "dump-pydantic"): + with tracing.section_as_event(span_post, "dump-pydantic"): processed_response = processed_response.dict() return processed_response diff --git a/truss/test_data/server.Dockerfile b/truss/test_data/server.Dockerfile index 7be57e609..7d97c435a 100644 --- a/truss/test_data/server.Dockerfile +++ b/truss/test_data/server.Dockerfile @@ -28,6 +28,10 @@ RUN apt update && \ COPY ./base_server_requirements.txt base_server_requirements.txt RUN pip install -r base_server_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip +COPY ./requirements.txt requirements.txt +RUN cat requirements.txt +RUN pip install -r requirements.txt --no-cache-dir && rm -rf /root/.cache/pip + ENV APP_HOME /app WORKDIR $APP_HOME diff --git a/truss/test_data/test_streaming_truss_with_tracing/model/model.py b/truss/test_data/test_streaming_truss_with_tracing/model/model.py index 5c541159b..249a0603f 100644 --- a/truss/test_data/test_streaming_truss_with_tracing/model/model.py +++ b/truss/test_data/test_streaming_truss_with_tracing/model/model.py @@ -57,7 +57,7 @@ def predict(self, model_input: Any) -> Generator[str, None, None]: with tracer.start_as_current_span("start-predict") as span: def inner(): - time.sleep(2) + time.sleep(0.02) for i in range(5): span.add_event("yield") yield str(i)