diff --git a/README.md b/README.md index 087b7763c..ac2d750ab 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,9 @@ Features | **[Fixed Schedule](docs/tutorials/fixed-schedule.md)** | Precise timestamp-based request execution | Traffic replay, temporal analysis, burst testing | | **[Time-based Benchmarking](docs/tutorials/time-based-benchmarking.md)** | Duration-based testing with grace period control | Stability testing, sustained performance | +### Working with Benchmark Data +- **[Profile Exports](docs/tutorials/working-with-profile-exports.md)** - Parse and analyze `profile_export.jsonl` with Pydantic models, custom metrics, and async processing + ### Quick Navigation ```bash # Basic profiling diff --git a/aiperf/common/config/__init__.py b/aiperf/common/config/__init__.py index 35a439bc0..2f11104f4 100644 --- a/aiperf/common/config/__init__.py +++ b/aiperf/common/config/__init__.py @@ -171,11 +171,11 @@ "load_user_config", "parse_file", "parse_service_types", + "parse_str_as_numeric_dict", "parse_str_or_csv_list", "parse_str_or_dict_as_tuple_list", "parse_str_or_list", "parse_str_or_list_of_positive_values", - "parse_str_as_numeric_dict", "print_developer_mode_warning", "print_str_or_list", ] diff --git a/aiperf/common/config/config_defaults.py b/aiperf/common/config/config_defaults.py index a0e8666bf..1b2e041bc 100644 --- a/aiperf/common/config/config_defaults.py +++ b/aiperf/common/config/config_defaults.py @@ -11,6 +11,7 @@ AudioFormat, CommunicationBackend, EndpointType, + ExportLevel, ImageFormat, ModelSelectionStrategy, RequestRateMode, @@ -114,12 +115,13 @@ class TurnDelayDefaults: @dataclass(frozen=True) class OutputDefaults: ARTIFACT_DIRECTORY = Path("./artifacts") - PROFILE_EXPORT_FILE = Path("profile_export.json") + PROFILE_EXPORT_FILE = Path("profile_export.jsonl") LOG_FOLDER = Path("logs") LOG_FILE = Path("aiperf.log") INPUTS_JSON_FILE = Path("inputs.json") PROFILE_EXPORT_AIPERF_CSV_FILE = Path("profile_export_aiperf.csv") PROFILE_EXPORT_AIPERF_JSON_FILE = Path("profile_export_aiperf.json") + EXPORT_LEVEL = ExportLevel.RECORDS @dataclass(frozen=True) diff --git a/aiperf/common/config/output_config.py b/aiperf/common/config/output_config.py index b38488e12..c346625e5 100644 --- a/aiperf/common/config/output_config.py +++ b/aiperf/common/config/output_config.py @@ -10,6 +10,7 @@ from aiperf.common.config.cli_parameter import CLIParameter from aiperf.common.config.config_defaults import OutputDefaults from aiperf.common.config.groups import Groups +from aiperf.common.enums import ExportLevel class OutputConfig(BaseConfig): @@ -32,3 +33,18 @@ class OutputConfig(BaseConfig): group=_CLI_GROUP, ), ] = OutputDefaults.ARTIFACT_DIRECTORY + + profile_export_file: Annotated[ + Path, + Field( + description="The file to store the profile export in JSONL format.", + ), + CLIParameter( + name=("--profile-export-file",), + group=_CLI_GROUP, + ), + ] = OutputDefaults.PROFILE_EXPORT_FILE + + @property + def export_level(self) -> ExportLevel: + return ExportLevel.RECORDS diff --git a/aiperf/common/constants.py b/aiperf/common/constants.py index 9f6a16bf1..aadf59a48 100644 --- a/aiperf/common/constants.py +++ b/aiperf/common/constants.py @@ -108,3 +108,6 @@ GOOD_REQUEST_COUNT_TAG = "good_request_count" """GoodRequestCount metric tag""" + +DEFAULT_RECORD_EXPORT_BATCH_SIZE = 100 +"""Default batch size for record export results processor.""" diff --git a/aiperf/common/enums/__init__.py b/aiperf/common/enums/__init__.py index 90aea3e32..ccfb2b9e7 100644 --- a/aiperf/common/enums/__init__.py +++ b/aiperf/common/enums/__init__.py @@ -26,6 +26,7 @@ from aiperf.common.enums.data_exporter_enums import ( ConsoleExporterType, DataExporterType, + ExportLevel, ) from aiperf.common.enums.dataset_enums import ( AudioFormat, @@ -53,7 +54,6 @@ BaseMetricUnit, BaseMetricUnitInfo, GenericMetricUnit, - MetricDateTimeUnit, MetricFlags, MetricOverTimeUnit, MetricOverTimeUnitInfo, @@ -122,12 +122,12 @@ "EndpointServiceKind", "EndpointType", "EndpointTypeInfo", + "ExportLevel", "GenericMetricUnit", "ImageFormat", "LifecycleState", "MediaType", "MessageType", - "MetricDateTimeUnit", "MetricFlags", "MetricOverTimeUnit", "MetricOverTimeUnitInfo", diff --git a/aiperf/common/enums/data_exporter_enums.py b/aiperf/common/enums/data_exporter_enums.py index ab9af638b..07548cac5 100644 --- a/aiperf/common/enums/data_exporter_enums.py +++ b/aiperf/common/enums/data_exporter_enums.py @@ -14,3 +14,16 @@ class ConsoleExporterType(CaseInsensitiveStrEnum): class DataExporterType(CaseInsensitiveStrEnum): JSON = "json" CSV = "csv" + + +class ExportLevel(CaseInsensitiveStrEnum): + """Export level for benchmark data.""" + + SUMMARY = "summary" + """Export only aggregated/summarized metrics (default, most compact)""" + + RECORDS = "records" + """Export per-record metrics after aggregation with display unit conversion""" + + RAW = "raw" + """Export raw parsed records with full request/response data (most detailed)""" diff --git a/aiperf/common/enums/metric_enums.py b/aiperf/common/enums/metric_enums.py index ccac7b14e..349021831 100644 --- a/aiperf/common/enums/metric_enums.py +++ b/aiperf/common/enums/metric_enums.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Callable -from datetime import datetime from enum import Flag from functools import cached_property from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar @@ -169,16 +168,9 @@ def long_name(self) -> str: def convert_to(self, other_unit: "MetricUnitT", value: int | float) -> float: """Convert a value from this unit to another unit.""" - if not isinstance( - other_unit, MetricTimeUnit | MetricTimeUnitInfo | MetricDateTimeUnit - ): + if not isinstance(other_unit, MetricTimeUnit | MetricTimeUnitInfo): return super().convert_to(other_unit, value) - if isinstance(other_unit, MetricDateTimeUnit): - return datetime.fromtimestamp( - self.convert_to(MetricTimeUnit.SECONDS, value) - ) - return value * (other_unit.per_second / self.per_second) @@ -197,12 +189,6 @@ class GenericMetricUnit(BaseMetricUnit): USER = _unit("user") -class MetricDateTimeUnit(BaseMetricUnit): - """Defines the various date time units that can be used for metrics.""" - - DATE_TIME = _unit("datetime") - - class MetricOverTimeUnitInfo(BaseMetricUnitInfo): """Information about a metric over time unit.""" @@ -445,6 +431,10 @@ class MetricFlags(Flag): GOODPUT = 1 << 10 """Metrics that are only applicable when goodput feature is enabled""" + NO_INDIVIDUAL_RECORDS = 1 << 11 + """Metrics that should not be exported for individual records. These are typically aggregate metrics. + This is used to filter out metrics such as request count or min/max timestamps that are not relevant to individual records.""" + def has_flags(self, flags: "MetricFlags") -> bool: """Return True if the metric has ALL of the given flag(s) (regardless of other flags).""" # Bitwise AND will return the input flags only if all of the given flags are present. diff --git a/aiperf/common/enums/post_processor_enums.py b/aiperf/common/enums/post_processor_enums.py index 96115ec8f..266d84db9 100644 --- a/aiperf/common/enums/post_processor_enums.py +++ b/aiperf/common/enums/post_processor_enums.py @@ -5,7 +5,11 @@ class RecordProcessorType(CaseInsensitiveStrEnum): - """Type of streaming record processor.""" + """Type of streaming record processor. + + Record processors are responsible for streaming records and computing metrics from MetricType.RECORD and MetricType.AGGREGATE. + This is the first stage of the processing pipeline, and is done is a distributed manner across multiple service instances. + """ METRIC_RECORD = "metric_record" """Streamer that streams records and computes metrics from MetricType.RECORD and MetricType.AGGREGATE. @@ -13,8 +17,17 @@ class RecordProcessorType(CaseInsensitiveStrEnum): class ResultsProcessorType(CaseInsensitiveStrEnum): - """Type of streaming results processor.""" + """Type of streaming results processor. + + Results processors are responsible for processing results from RecordProcessors and computing metrics from MetricType.DERIVED. + as well as aggregating the results. + This is the last stage of the processing pipeline, and is done from the single instance of the RecordsManager. + """ METRIC_RESULTS = "metric_results" """Processor that processes the metric results from METRIC_RECORD and computes metrics from MetricType.DERIVED. as well as aggregates the results. This is the last stage of the metrics processing pipeline, and is done from the RecordsManager after all the service instances have completed their processing.""" + + RECORD_EXPORT = "record_export" + """Processor that exports per-record metrics to JSONL files with display unit conversion and filtering. + Only enabled when export_level is set to RECORDS.""" diff --git a/aiperf/common/exceptions.py b/aiperf/common/exceptions.py index 912e8266a..96f0e7765 100644 --- a/aiperf/common/exceptions.py +++ b/aiperf/common/exceptions.py @@ -143,6 +143,10 @@ class NoMetricValue(AIPerfError): """Raised when a metric value is not available.""" +class PostProcessorDisabled(AIPerfError): + """Raised when initializing a post processor to indicate to the caller that it is disabled and should not be used.""" + + class ProxyError(AIPerfError): """Exception raised when a proxy encounters an error.""" diff --git a/aiperf/common/messages/__init__.py b/aiperf/common/messages/__init__.py index 7349b00be..83745b23d 100644 --- a/aiperf/common/messages/__init__.py +++ b/aiperf/common/messages/__init__.py @@ -53,6 +53,7 @@ ) from aiperf.common.messages.inference_messages import ( InferenceResultsMessage, + MetricRecordsData, MetricRecordsMessage, RealtimeMetricsMessage, ) @@ -107,6 +108,7 @@ "HeartbeatMessage", "InferenceResultsMessage", "Message", + "MetricRecordsData", "MetricRecordsMessage", "ProcessRecordsCommand", "ProcessRecordsResponse", diff --git a/aiperf/common/messages/credit_messages.py b/aiperf/common/messages/credit_messages.py index f19edb889..4ddd62cbe 100644 --- a/aiperf/common/messages/credit_messages.py +++ b/aiperf/common/messages/credit_messages.py @@ -22,7 +22,15 @@ class CreditDropMessage(BaseServiceMessage): default_factory=lambda: str(uuid.uuid4()), description="The ID of the credit drop, that will be used as the X-Correlation-ID header.", ) - phase: CreditPhase = Field(..., description="The type of credit phase") + phase: CreditPhase = Field( + ..., description="The type of credit phase, such as warmup or profiling." + ) + credit_num: int = Field( + ..., + ge=0, + description="The sequential number of the credit in the credit phase. This is used to track the progress of the credit phase," + " as well as the order that requests are sent in.", + ) conversation_id: str | None = Field( default=None, description="The ID of the conversation, if applicable." ) diff --git a/aiperf/common/messages/inference_messages.py b/aiperf/common/messages/inference_messages.py index 6337d57ac..8a8bfd46f 100644 --- a/aiperf/common/messages/inference_messages.py +++ b/aiperf/common/messages/inference_messages.py @@ -1,21 +1,19 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from pydantic import ( - Field, - SerializeAsAny, -) - -from aiperf.common.enums import ( - CreditPhase, - MessageType, -) +from pydantic import Field, SerializeAsAny + +from aiperf.common.aiperf_logger import AIPerfLogger +from aiperf.common.enums import MessageType from aiperf.common.enums.metric_enums import MetricValueTypeT from aiperf.common.messages.service_messages import BaseServiceMessage from aiperf.common.models import ErrorDetails, RequestRecord -from aiperf.common.models.record_models import MetricResult +from aiperf.common.models.base_models import AIPerfBaseModel +from aiperf.common.models.record_models import MetricRecordMetadata, MetricResult from aiperf.common.types import MessageTypeT, MetricTagT +_logger = AIPerfLogger(__name__) + class InferenceResultsMessage(BaseServiceMessage): """Message for a inference results.""" @@ -27,29 +25,36 @@ class InferenceResultsMessage(BaseServiceMessage): ) +class MetricRecordsData(AIPerfBaseModel): + """Incoming data from the record processor service to combine metric records for the profile run.""" + + metadata: MetricRecordMetadata = Field( + ..., description="The metadata of the request record." + ) + metrics: dict[MetricTagT, MetricValueTypeT] = Field( + ..., description="The combined metric records for this inference request." + ) + error: ErrorDetails | None = Field( + default=None, description="The error details if the request failed." + ) + + @property + def valid(self) -> bool: + """Whether the request was valid.""" + return self.error is None + + class MetricRecordsMessage(BaseServiceMessage): """Message from the result parser to the records manager to notify it of the metric records for a single request.""" message_type: MessageTypeT = MessageType.METRIC_RECORDS - timestamp_ns: int = Field( - ..., description="The wall clock timestamp of the request in nanoseconds." - ) - x_request_id: str | None = Field( - default=None, description="The X-Request-ID header of the request." - ) - x_correlation_id: str | None = Field( - default=None, description="The X-Correlation-ID header of the request." - ) - worker_id: str = Field( - ..., description="The ID of the worker that processed the request." - ) - credit_phase: CreditPhase = Field( - ..., description="The credit phase of the request." + metadata: MetricRecordMetadata = Field( + ..., description="The metadata of the request record." ) results: list[dict[MetricTagT, MetricValueTypeT]] = Field( - ..., description="The record processor results" + ..., description="The record processor metric results" ) error: ErrorDetails | None = Field( default=None, description="The error details if the request failed." @@ -60,6 +65,24 @@ def valid(self) -> bool: """Whether the request was valid.""" return self.error is None + def to_data(self) -> MetricRecordsData: + """Convert the metric records message to a MetricRecordsData for processing by the records manager.""" + metrics = {} + for result in self.results: + for tag, value in result.items(): + if tag in metrics: + _logger.warning( + f"Duplicate metric tag '{tag}' found in results. " + f"Overwriting previous value {metrics[tag]} with {value}." + ) + metrics[tag] = value + + return MetricRecordsData( + metadata=self.metadata, + metrics=metrics, + error=self.error, + ) + class RealtimeMetricsMessage(BaseServiceMessage): """Message from the records manager to show real-time metrics for the profile run.""" diff --git a/aiperf/common/models/__init__.py b/aiperf/common/models/__init__.py index f73c55320..fc1c2ea7a 100644 --- a/aiperf/common/models/__init__.py +++ b/aiperf/common/models/__init__.py @@ -50,7 +50,10 @@ BaseResponseData, EmbeddingResponseData, InferenceServerResponse, + MetricRecordInfo, + MetricRecordMetadata, MetricResult, + MetricValue, ParsedResponse, ParsedResponseRecord, ProcessRecordsResult, @@ -90,7 +93,10 @@ "InferenceServerResponse", "InputsFile", "Media", + "MetricRecordInfo", + "MetricRecordMetadata", "MetricResult", + "MetricValue", "ParsedResponse", "ParsedResponseRecord", "ProcessHealth", diff --git a/aiperf/common/models/base_models.py b/aiperf/common/models/base_models.py index dbb500150..e91808b65 100644 --- a/aiperf/common/models/base_models.py +++ b/aiperf/common/models/base_models.py @@ -37,7 +37,8 @@ class AIPerfBaseModel(BaseModel): are None. This is set by the @exclude_if_none decorator. """ - model_config = ConfigDict(arbitrary_types_allowed=True) + # Allow extras by default to be more flexible for end users + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @model_serializer def _serialize_model(self) -> dict[str, Any]: diff --git a/aiperf/common/models/record_models.py b/aiperf/common/models/record_models.py index 26a5777cb..38ccf3434 100644 --- a/aiperf/common/models/record_models.py +++ b/aiperf/common/models/record_models.py @@ -14,6 +14,7 @@ from aiperf.common.constants import NANOS_PER_SECOND from aiperf.common.enums import CreditPhase, SSEFieldType +from aiperf.common.enums.metric_enums import MetricValueTypeT from aiperf.common.models.base_models import AIPerfBaseModel from aiperf.common.models.dataset_models import Turn from aiperf.common.models.error_models import ErrorDetails, ErrorDetailsCount @@ -26,7 +27,7 @@ class MetricResult(AIPerfBaseModel): tag: MetricTagT = Field(description="The unique identifier of the metric") # NOTE: We do not use a MetricUnitT here, as that is harder to de-serialize from JSON strings with pydantic. # If we need an instance of a MetricUnitT, lookup the unit based on the tag in the MetricRegistry. - unit: str = Field(description="The unit of the metric, e.g. 'ms'") + unit: str = Field(description="The unit of the metric, e.g. 'ms' or 'requests/sec'") header: str = Field( description="The user friendly name of the metric (e.g. 'Inter Token Latency')" ) @@ -55,6 +56,90 @@ def to_display_unit(self) -> "MetricResult": return to_display_unit(self, MetricRegistry) +class MetricValue(AIPerfBaseModel): + """The value of a metric converted to display units for export.""" + + value: MetricValueTypeT + unit: str + + +class MetricRecordMetadata(AIPerfBaseModel): + """The metadata of a metric record for export.""" + + session_num: int = Field( + ..., + description="The sequential number of the session in the benchmark. For single-turn datasets, this will be the" + " request index. For multi-turn datasets, this will be the session index.", + ) + x_request_id: str | None = Field( + default=None, + description="The X-Request-ID header of the request. This is a unique ID for the request.", + ) + x_correlation_id: str | None = Field( + default=None, + description="The X-Correlation-ID header of the request. This is a shared ID for each user session/conversation in multi-turn.", + ) + conversation_id: str | None = Field( + default=None, + description="The ID of the conversation (if applicable). This can be used to lookup the original request data from the inputs.json file.", + ) + turn_index: int | None = Field( + default=None, + description="The index of the turn in the conversation (if applicable). This can be used to lookup the original request data from the inputs.json file.", + ) + request_start_ns: int = Field( + ..., + description="The wall clock timestamp of the request start time measured as time.time_ns().", + ) + request_ack_ns: int | None = Field( + default=None, + description="The wall clock timestamp of the request acknowledgement from the server, measured as time.time_ns(), if applicable. " + "This is only applicable to streaming requests, and servers that send 200 OK back immediately after the request is received.", + ) + request_end_ns: int = Field( + ..., + description="The wall clock timestamp of the request end time measured as time.time_ns(). If the request failed, " + "this will be the time of the error.", + ) + worker_id: str = Field( + ..., description="The ID of the AIPerf worker that processed the request." + ) + record_processor_id: str = Field( + ..., + description="The ID of the AIPerf record processor that processed the record.", + ) + benchmark_phase: CreditPhase = Field( + ..., + description="The benchmark phase of the record, either warmup or profiling.", + ) + was_cancelled: bool = Field( + default=False, + description="Whether the request was cancelled during execution.", + ) + cancellation_time_ns: int | None = Field( + default=None, + description="The wall clock timestamp of the request cancellation time measured as time.time_ns(), if applicable. " + "This is only applicable to requests that were cancelled.", + ) + + +class MetricRecordInfo(AIPerfBaseModel): + """The full info of a metric record including the metadata, metrics, and error for export.""" + + metadata: MetricRecordMetadata = Field( + ..., + description="The metadata of the record. Should match the metadata in the MetricRecordsMessage.", + ) + metrics: dict[str, MetricValue] = Field( + ..., + description="A dictionary containing all metric values along with their units.", + ) + error: ErrorDetails | None = Field( + default=None, + description="The error details if the request failed.", + ) + + class ProfileResults(AIPerfBaseModel): records: list[MetricResult] | None = Field( ..., description="The records of the profile results" @@ -175,6 +260,12 @@ class RequestRecord(AIPerfBaseModel): default=None, description="The turn of the request, if applicable.", ) + credit_num: int | None = Field( + default=None, + ge=0, + description="The sequential number of the credit in the credit phase. This is used to track the progress of the credit phase," + " as well as the order that requests are sent in.", + ) conversation_id: str | None = Field( default=None, description="The ID of the conversation (if applicable).", diff --git a/aiperf/common/protocols.py b/aiperf/common/protocols.py index a0c5df65c..6abf97417 100644 --- a/aiperf/common/protocols.py +++ b/aiperf/common/protocols.py @@ -30,7 +30,6 @@ MessageOutputT, MessageT, MessageTypeT, - MetricTagT, ModelEndpointInfoT, RequestInputT, RequestOutputT, @@ -41,7 +40,7 @@ from rich.console import Console from aiperf.common.config import ServiceConfig, UserConfig - from aiperf.common.enums.metric_enums import MetricValueTypeT + from aiperf.common.messages.inference_messages import MetricRecordsData from aiperf.common.models.record_models import MetricResult from aiperf.exporters.exporter_config import ExporterConfig, FileExportInfo from aiperf.metrics.metric_dicts import MetricRecordDict @@ -503,9 +502,7 @@ class ResultsProcessorProtocol(Protocol): """Protocol for a results processor that processes the results of multiple record processors, and provides the ability to summarize the results.""" - async def process_result( - self, result: dict[MetricTagT, "MetricValueTypeT"] - ) -> None: ... + async def process_result(self, record_data: "MetricRecordsData") -> None: ... async def summarize(self) -> list["MetricResult"]: ... diff --git a/aiperf/common/utils.py b/aiperf/common/utils.py index f07bd4337..8d67c84e3 100644 --- a/aiperf/common/utils.py +++ b/aiperf/common/utils.py @@ -107,6 +107,27 @@ async def yield_to_event_loop() -> None: await asyncio.sleep(0) +def compute_time_ns( + start_time_ns: int, start_perf_ns: int, perf_ns: int | None +) -> int | None: + """Convert a perf_ns timestamp to a wall clock time_ns timestamp by + computing the absolute duration in perf_ns (perf_ns - start_perf_ns) and adding it to the start_time_ns. + + Args: + start_time_ns: The wall clock start time in nanoseconds (time.time_ns). + start_perf_ns: The start perf time in nanoseconds (perf_counter_ns). + perf_ns: The perf time in nanoseconds to convert to time_ns (perf_counter_ns). + + Returns: + The perf_ns converted to time_ns, or None if the perf_ns is None. + """ + if perf_ns is None: + return None + if perf_ns < start_perf_ns: + raise ValueError(f"perf_ns {perf_ns} is before start_perf_ns {start_perf_ns}") + return start_time_ns + (perf_ns - start_perf_ns) + + # This is used to identify the source file of the call_all_functions function # in the AIPerfLogger class to skip it when determining the caller. # NOTE: Using similar logic to logging._srcfile diff --git a/aiperf/metrics/metric_dicts.py b/aiperf/metrics/metric_dicts.py index 152fdfa53..f8ed38aa9 100644 --- a/aiperf/metrics/metric_dicts.py +++ b/aiperf/metrics/metric_dicts.py @@ -4,6 +4,7 @@ import numpy as np +from aiperf.common.aiperf_logger import AIPerfLogger from aiperf.common.enums import MetricType from aiperf.common.enums.metric_enums import ( MetricDictValueTypeT, @@ -11,18 +12,21 @@ MetricValueTypeT, MetricValueTypeVarT, ) -from aiperf.common.exceptions import NoMetricValue -from aiperf.common.models.record_models import MetricResult +from aiperf.common.exceptions import MetricTypeError, MetricUnitError, NoMetricValue +from aiperf.common.models.record_models import MetricResult, MetricValue from aiperf.common.types import MetricTagT if TYPE_CHECKING: from aiperf.metrics.base_metric import BaseMetric + from aiperf.metrics.metric_registry import MetricRegistry MetricDictValueTypeVarT = TypeVar( "MetricDictValueTypeVarT", bound="MetricValueTypeT | MetricDictValueTypeT" ) +_logger = AIPerfLogger(__name__) + class BaseMetricDict( Generic[MetricDictValueTypeVarT], dict[MetricTagT, MetricDictValueTypeVarT] @@ -54,7 +58,57 @@ class MetricRecordDict(BaseMetricDict[MetricValueTypeT]): - No `BaseDerivedMetric`s will be included. """ - pass # Everything is handled by the BaseMetricDict class. + def to_display_dict( + self, registry: "type[MetricRegistry]", show_internal: bool = False + ) -> dict[str, MetricValue]: + """Convert to display units with filtering applied. + NOTE: This will not include metrics with the `NO_INDIVIDUAL_RECORDS` flag. + + Args: + registry: MetricRegistry class for looking up metric definitions + show_internal: If True, include experimental/internal metrics + + Returns: + Dictionary of {tag: MetricValue} for export + """ + from aiperf.common.enums import MetricFlags + + result = {} + for tag, value in self.items(): + try: + metric_class = registry.get_class(tag) + except MetricTypeError: + _logger.warning(f"Metric {tag} not found in registry") + continue + + if not show_internal and not metric_class.missing_flags( + MetricFlags.EXPERIMENTAL | MetricFlags.INTERNAL + ): + continue + + if metric_class.has_flags(MetricFlags.NO_INDIVIDUAL_RECORDS): + continue + + display_unit = metric_class.display_unit or metric_class.unit + if display_unit != metric_class.unit: + try: + if isinstance(value, list): + value = [ + metric_class.unit.convert_to(display_unit, v) for v in value + ] + else: + value = metric_class.unit.convert_to(display_unit, value) + except MetricUnitError as e: + _logger.warning( + f"Error converting {tag} from {metric_class.unit} to {display_unit}: {e}" + ) + + result[tag] = MetricValue( + value=value, + unit=str(display_unit), + ) + + return result class MetricResultsDict(BaseMetricDict[MetricDictValueTypeT]): diff --git a/aiperf/metrics/metric_registry.py b/aiperf/metrics/metric_registry.py index fa54dcd47..e683f274c 100644 --- a/aiperf/metrics/metric_registry.py +++ b/aiperf/metrics/metric_registry.py @@ -151,8 +151,8 @@ def tags_applicable_to( applicable to non-streaming endpoints, etc. Arguments: - required_flags: The flags that the metric must have. - disallowed_flags: The flags that the metric must not have. + required_flags: The flags that the metric must have ALL of. If MetricFlags.NONE, no flags are required. + disallowed_flags: The flags that the metric must not have ANY of. types: The types of metrics to include. If not provided, all types will be included. Returns: @@ -161,7 +161,10 @@ def tags_applicable_to( return [ tag for tag, metric_class in cls._metrics_map.items() - if metric_class.has_flags(required_flags) + if ( + required_flags == MetricFlags.NONE + or metric_class.has_flags(required_flags) + ) and metric_class.missing_flags(disallowed_flags) and (not types or metric_class.type in types) ] diff --git a/aiperf/metrics/types/error_request_count.py b/aiperf/metrics/types/error_request_count.py index e71fb2dd1..2de4d1277 100644 --- a/aiperf/metrics/types/error_request_count.py +++ b/aiperf/metrics/types/error_request_count.py @@ -21,5 +21,5 @@ class ErrorRequestCountMetric(BaseAggregateCounterMetric[int]): short_header = "Error Count" short_header_hide_unit = True unit = GenericMetricUnit.REQUESTS - flags = MetricFlags.ERROR_ONLY + flags = MetricFlags.ERROR_ONLY | MetricFlags.NO_INDIVIDUAL_RECORDS required_metrics = None diff --git a/aiperf/metrics/types/input_sequence_length_metric.py b/aiperf/metrics/types/input_sequence_length_metric.py index 384fdfa2d..65bd01ba5 100644 --- a/aiperf/metrics/types/input_sequence_length_metric.py +++ b/aiperf/metrics/types/input_sequence_length_metric.py @@ -11,7 +11,7 @@ class InputSequenceLengthMetric(BaseRecordMetric[int]): """ - Post-processor for calculating Input Sequence Length (ISL) metrics from records. + Post-processor for calculating Input Sequence Length (ISL) metrics from valid records. Formula: Input Sequence Length = Sum of Input Token Counts @@ -44,7 +44,7 @@ def _parse_record( class TotalInputSequenceLengthMetric(DerivedSumMetric[int, InputSequenceLengthMetric]): """ - This is the total number of input tokens processed by the benchmark. + This is the total number of input tokens processed by the benchmark for valid records. Formula: ``` @@ -61,3 +61,44 @@ class TotalInputSequenceLengthMetric(DerivedSumMetric[int, InputSequenceLengthMe | MetricFlags.LARGER_IS_BETTER | MetricFlags.NO_CONSOLE ) + + +class ErrorInputSequenceLengthMetric(InputSequenceLengthMetric): + """ + Post-processor for calculating Input Sequence Length (ISL) metrics from error records. + """ + + tag = "error_isl" + header = "Error Input Sequence Length" + short_header = "Error ISL" + unit = GenericMetricUnit.TOKENS + flags = ( + MetricFlags.PRODUCES_TOKENS_ONLY + | MetricFlags.LARGER_IS_BETTER + | MetricFlags.NO_CONSOLE + | MetricFlags.ERROR_ONLY + ) + + +class TotalErrorInputSequenceLengthMetric( + DerivedSumMetric[int, ErrorInputSequenceLengthMetric] +): + """ + This is the total number of input tokens processed in the benchmark for error records. + + Formula: + ``` + Total Error Input Sequence Length = Sum(Error Input Sequence Lengths) + ``` + """ + + tag = "total_error_isl" + header = "Total Error Input Sequence Length" + short_header = "Total Error ISL" + short_header_hide_unit = True + flags = ( + MetricFlags.PRODUCES_TOKENS_ONLY + | MetricFlags.LARGER_IS_BETTER + | MetricFlags.NO_CONSOLE + | MetricFlags.ERROR_ONLY + ) diff --git a/aiperf/metrics/types/max_response_metric.py b/aiperf/metrics/types/max_response_metric.py index 7bf34397b..3eac6e387 100644 --- a/aiperf/metrics/types/max_response_metric.py +++ b/aiperf/metrics/types/max_response_metric.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -from aiperf.common.enums import MetricDateTimeUnit, MetricFlags, MetricTimeUnit +from aiperf.common.enums import MetricFlags, MetricTimeUnit from aiperf.common.models import ParsedResponseRecord from aiperf.metrics import BaseAggregateMetric from aiperf.metrics.metric_dicts import MetricRecordDict @@ -21,8 +21,7 @@ class MaxResponseTimestampMetric(BaseAggregateMetric[int]): short_header = "Max Resp" short_header_hide_unit = True unit = MetricTimeUnit.NANOSECONDS - display_unit = MetricDateTimeUnit.DATE_TIME - flags = MetricFlags.NO_CONSOLE + flags = MetricFlags.NO_CONSOLE | MetricFlags.NO_INDIVIDUAL_RECORDS required_metrics = { RequestLatencyMetric.tag, } diff --git a/aiperf/metrics/types/min_request_metric.py b/aiperf/metrics/types/min_request_metric.py index e64ce86b6..e0d289fd3 100644 --- a/aiperf/metrics/types/min_request_metric.py +++ b/aiperf/metrics/types/min_request_metric.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from aiperf.common.enums import MetricDateTimeUnit, MetricFlags, MetricTimeUnit +from aiperf.common.enums import MetricFlags, MetricTimeUnit from aiperf.common.models import ParsedResponseRecord from aiperf.metrics import BaseAggregateMetric from aiperf.metrics.metric_dicts import MetricRecordDict @@ -21,8 +21,7 @@ class MinRequestTimestampMetric(BaseAggregateMetric[int]): short_header = "Min Req" short_header_hide_unit = True unit = MetricTimeUnit.NANOSECONDS - display_unit = MetricDateTimeUnit.DATE_TIME - flags = MetricFlags.NO_CONSOLE + flags = MetricFlags.NO_CONSOLE | MetricFlags.NO_INDIVIDUAL_RECORDS required_metrics = None def __init__(self) -> None: diff --git a/aiperf/metrics/types/request_count_metric.py b/aiperf/metrics/types/request_count_metric.py index 0f7ba89f0..75507969e 100644 --- a/aiperf/metrics/types/request_count_metric.py +++ b/aiperf/metrics/types/request_count_metric.py @@ -21,6 +21,6 @@ class RequestCountMetric(BaseAggregateCounterMetric[int]): short_header = "Requests" short_header_hide_unit = True unit = GenericMetricUnit.REQUESTS - display_order = 1000 - flags = MetricFlags.LARGER_IS_BETTER + display_order = 1100 + flags = MetricFlags.LARGER_IS_BETTER | MetricFlags.NO_INDIVIDUAL_RECORDS required_metrics = None diff --git a/aiperf/parsers/inference_result_parser.py b/aiperf/parsers/inference_result_parser.py index e3f225532..3c3e47f54 100644 --- a/aiperf/parsers/inference_result_parser.py +++ b/aiperf/parsers/inference_result_parser.py @@ -90,7 +90,7 @@ async def configure(self) -> None: self.info(f"Initialized tokenizers: {tokenizer_info} in {duration:.2f} seconds") async def get_tokenizer(self, model: str) -> Tokenizer: - """Get the tokenizer for a given model.""" + """Get the tokenizer for a given model or create it if it doesn't exist.""" async with self.tokenizer_lock: if model not in self.tokenizers: self.tokenizers[model] = Tokenizer.from_pretrained( @@ -110,9 +110,17 @@ async def parse_request_record( ) if request_record.has_error: + # Even for error records, compute input token count if possible + try: + input_token_count = await self.compute_input_token_count(request_record) + except Exception as e: + self.warning(f"Error computing input token count for error record: {e}") + input_token_count = None + return ParsedResponseRecord( request=request_record, responses=[], + input_token_count=input_token_count, ) elif request_record.valid: @@ -127,9 +135,18 @@ async def parse_request_record( # TODO: We should add an ErrorDetails to the response record and not the request record. self.exception(f"Error processing valid record: {e}") request_record.error = ErrorDetails.from_exception(e) + + try: + input_token_count = await self.compute_input_token_count( + request_record + ) + except Exception: + input_token_count = None + return ParsedResponseRecord( request=request_record, responses=[], + input_token_count=input_token_count, ) else: self.warning(f"Received invalid inference results: {request_record}") @@ -139,9 +156,16 @@ async def parse_request_record( message="Invalid inference results", type="InvalidInferenceResults", ) + + try: + input_token_count = await self.compute_input_token_count(request_record) + except Exception: + input_token_count = None + return ParsedResponseRecord( request=request_record, responses=[], + input_token_count=input_token_count, ) async def process_valid_record( @@ -159,11 +183,8 @@ async def process_valid_record( output_token_count=None, ) - tokenizer = await self.get_tokenizer(request_record.model_name) resp = await self.extractor.extract_response_data(request_record) - input_token_count = await self.compute_input_token_count( - request_record, tokenizer - ) + input_token_count = await self.compute_input_token_count(request_record) output_texts: list[str] = [] reasoning_texts: list[str] = [] @@ -176,6 +197,7 @@ async def process_valid_record( else: output_texts.append(response.data.get_text()) + tokenizer = await self.get_tokenizer(request_record.model_name) output_token_count = ( len(tokenizer.encode("".join(output_texts))) if output_texts else None ) @@ -218,13 +240,14 @@ async def get_turn(self, request_record: RequestRecord) -> Turn | None: return turn_response.turn async def compute_input_token_count( - self, request_record: RequestRecord, tokenizer: Tokenizer + self, request_record: RequestRecord ) -> int | None: """Compute the number of tokens in the input for a given request record.""" turn = await self.get_turn(request_record) if turn is None: return None + tokenizer = await self.get_tokenizer(request_record.model_name) input_token_count = 0 for text in turn.texts: input_token_count += len(tokenizer.encode("".join(text.contents))) diff --git a/aiperf/post_processors/__init__.py b/aiperf/post_processors/__init__.py index a70ca91c1..1cd73ab5d 100644 --- a/aiperf/post_processors/__init__.py +++ b/aiperf/post_processors/__init__.py @@ -17,5 +17,13 @@ from aiperf.post_processors.metric_results_processor import ( MetricResultsProcessor, ) +from aiperf.post_processors.record_export_results_processor import ( + RecordExportResultsProcessor, +) -__all__ = ["BaseMetricsProcessor", "MetricRecordProcessor", "MetricResultsProcessor"] +__all__ = [ + "BaseMetricsProcessor", + "MetricRecordProcessor", + "MetricResultsProcessor", + "RecordExportResultsProcessor", +] diff --git a/aiperf/post_processors/base_metrics_processor.py b/aiperf/post_processors/base_metrics_processor.py index be6b0221b..40318039d 100644 --- a/aiperf/post_processors/base_metrics_processor.py +++ b/aiperf/post_processors/base_metrics_processor.py @@ -82,15 +82,15 @@ def _setup_metrics( disallowed_flags |= MetricFlags.GOODPUT metrics: list[BaseMetric] = [] - supported_tags = MetricRegistry.tags_applicable_to( + applicable_tags = MetricRegistry.tags_applicable_to( required_flags, disallowed_flags, *metric_types, ) - self._configure_goodput(supported_tags) + self._configure_goodput(applicable_tags) ordered_tags = MetricRegistry.create_dependency_order_for( - supported_tags, + applicable_tags, ) for metric_tag in ordered_tags: metric = MetricRegistry.get_instance(metric_tag) diff --git a/aiperf/post_processors/metric_results_processor.py b/aiperf/post_processors/metric_results_processor.py index bb276f54d..dcf98d3fb 100644 --- a/aiperf/post_processors/metric_results_processor.py +++ b/aiperf/post_processors/metric_results_processor.py @@ -9,12 +9,13 @@ from aiperf.common.enums.metric_enums import MetricDictValueTypeT, MetricValueTypeT from aiperf.common.exceptions import NoMetricValue from aiperf.common.factories import ResultsProcessorFactory +from aiperf.common.messages.inference_messages import MetricRecordsData from aiperf.common.models import MetricResult from aiperf.common.protocols import ResultsProcessorProtocol from aiperf.common.types import MetricTagT from aiperf.metrics import BaseAggregateMetric from aiperf.metrics.base_metric import BaseMetric -from aiperf.metrics.metric_dicts import MetricArray, MetricRecordDict, MetricResultsDict +from aiperf.metrics.metric_dicts import MetricArray, MetricResultsDict from aiperf.metrics.metric_registry import MetricRegistry from aiperf.post_processors.base_metrics_processor import BaseMetricsProcessor @@ -64,12 +65,12 @@ def __init__(self, user_config: UserConfig, **kwargs: Any): if metric.type == MetricType.AGGREGATE } - async def process_result(self, incoming_metrics: MetricRecordDict) -> None: + async def process_result(self, record_data: MetricRecordsData) -> None: """Process a result from the metric record processor.""" if self.is_trace_enabled: - self.trace(f"Processing incoming metrics: {incoming_metrics}") + self.trace(f"Processing incoming metrics: {record_data.metrics}") - for tag, value in incoming_metrics.items(): + for tag, value in record_data.metrics.items(): try: metric_type = self._tags_to_types[tag] if metric_type == MetricType.RECORD: diff --git a/aiperf/post_processors/record_export_results_processor.py b/aiperf/post_processors/record_export_results_processor.py new file mode 100644 index 000000000..b7e7b4b05 --- /dev/null +++ b/aiperf/post_processors/record_export_results_processor.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import asyncio + +import aiofiles + +from aiperf.common.config import ServiceConfig, UserConfig +from aiperf.common.constants import AIPERF_DEV_MODE, DEFAULT_RECORD_EXPORT_BATCH_SIZE +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import ExportLevel, ResultsProcessorType +from aiperf.common.exceptions import PostProcessorDisabled +from aiperf.common.factories import ResultsProcessorFactory +from aiperf.common.hooks import on_init, on_stop +from aiperf.common.messages.inference_messages import MetricRecordsData +from aiperf.common.models.record_models import MetricRecordInfo, MetricResult +from aiperf.common.protocols import ResultsProcessorProtocol +from aiperf.metrics.metric_dicts import MetricRecordDict +from aiperf.metrics.metric_registry import MetricRegistry +from aiperf.post_processors.base_metrics_processor import BaseMetricsProcessor + + +@implements_protocol(ResultsProcessorProtocol) +@ResultsProcessorFactory.register(ResultsProcessorType.RECORD_EXPORT) +class RecordExportResultsProcessor(BaseMetricsProcessor): + """Exports per-record metrics to JSONL with display unit conversion and filtering.""" + + def __init__( + self, + service_id: str, + service_config: ServiceConfig, + user_config: UserConfig, + **kwargs, + ): + super().__init__(user_config=user_config, **kwargs) + export_level = user_config.output.export_level + export_file_path = user_config.output.profile_export_file + if export_level not in (ExportLevel.RECORDS, ExportLevel.RAW): + raise PostProcessorDisabled( + f"Record export results processor is disabled for export level {export_level}" + ) + + self.output_file = user_config.output.artifact_directory / export_file_path + self.output_file.parent.mkdir(parents=True, exist_ok=True) + self.record_count = 0 + self.show_internal = ( + AIPERF_DEV_MODE and service_config.developer.show_internal_metrics + ) + self.info(f"Record metrics export enabled: {self.output_file}") + self.output_file.unlink(missing_ok=True) + + # File handle for persistent writes with batching + self._file_handle = None + self._buffer: list[str] = [] + self._batch_size = DEFAULT_RECORD_EXPORT_BATCH_SIZE + self._buffer_lock = asyncio.Lock() + + @on_init + async def _open_file(self) -> None: + """Open a persistent file handle for writing.""" + self._file_handle = await aiofiles.open( + self.output_file, mode="w", encoding="utf-8" + ) + + async def process_result(self, record_data: MetricRecordsData) -> None: + try: + metric_dict = MetricRecordDict(record_data.metrics) + display_metrics = metric_dict.to_display_dict( + MetricRegistry, self.show_internal + ) + if not display_metrics: + return + + record_info = MetricRecordInfo( + metadata=record_data.metadata, + metrics=display_metrics, + error=record_data.error, + ) + json_str = record_info.model_dump_json() + + buffer_to_flush = None + async with self._buffer_lock: + self._buffer.append(json_str) + self.record_count += 1 + + if len(self._buffer) >= self._batch_size: + buffer_to_flush = self._buffer + self._buffer = [] + + if buffer_to_flush: + await self._flush_buffer(buffer_to_flush) + + except Exception as e: + self.error(f"Failed to write record metrics: {e}") + + async def summarize(self) -> list[MetricResult]: + """Summarize the results. For this processor, we don't need to summarize anything.""" + return [] + + async def _flush_buffer(self, buffer_to_flush: list[str]) -> None: + """Write buffered records to disk.""" + if not buffer_to_flush: + return + + try: + self.debug(lambda: f"Flushing {len(buffer_to_flush)} records to file") + await self._file_handle.write("\n".join(buffer_to_flush)) + await self._file_handle.write("\n") + await self._file_handle.flush() + except Exception as e: + self.error(f"Failed to flush buffer: {e}") + raise + + @on_stop + async def _shutdown(self) -> None: + async with self._buffer_lock: + buffer_to_flush = self._buffer + self._buffer = [] + + try: + await self._flush_buffer(buffer_to_flush) + except Exception as e: + self.error(f"Failed to flush remaining buffer during shutdown: {e}") + + if self._file_handle is not None: + try: + await self._file_handle.close() + except Exception as e: + self.error(f"Failed to close file handle during shutdown: {e}") + finally: + self._file_handle = None + + self.info( + f"RecordExportResultsProcessor: {self.record_count} records written to {self.output_file}" + ) diff --git a/aiperf/records/record_processor_service.py b/aiperf/records/record_processor_service.py index 5ffb092f7..3be8a8035 100644 --- a/aiperf/records/record_processor_service.py +++ b/aiperf/records/record_processor_service.py @@ -6,7 +6,13 @@ from aiperf.common.base_component_service import BaseComponentService from aiperf.common.config import ServiceConfig, UserConfig from aiperf.common.constants import DEFAULT_PULL_CLIENT_MAX_CONCURRENCY -from aiperf.common.enums import CommAddress, CommandType, MessageType, ServiceType +from aiperf.common.enums import ( + CommAddress, + CommandType, + MessageType, + ServiceType, +) +from aiperf.common.exceptions import PostProcessorDisabled from aiperf.common.factories import ( RecordProcessorFactory, ServiceFactory, @@ -18,13 +24,14 @@ ProfileConfigureCommand, ) from aiperf.common.mixins import PullClientMixin -from aiperf.common.models import ParsedResponseRecord +from aiperf.common.models import MetricRecordMetadata, ParsedResponseRecord from aiperf.common.protocols import ( PushClientProtocol, RecordProcessorProtocol, RequestClientProtocol, ) from aiperf.common.tokenizer import Tokenizer +from aiperf.common.utils import compute_time_ns from aiperf.metrics.metric_dicts import MetricRecordDict from aiperf.parsers.inference_result_parser import InferenceResultParser @@ -75,15 +82,20 @@ async def _initialize(self) -> None: """Initialize record processor-specific components.""" self.debug("Initializing record processor") - # Initialize all the records streamers + # Initialize all the records streamers that are enabled for processor_type in RecordProcessorFactory.get_all_class_types(): - self.records_processors.append( - RecordProcessorFactory.create_instance( - processor_type, - service_config=self.service_config, - user_config=self.user_config, + try: + self.records_processors.append( + RecordProcessorFactory.create_instance( + processor_type, + service_config=self.service_config, + user_config=self.user_config, + ) + ) + except PostProcessorDisabled: + self.debug( + f"Record processor {processor_type} is disabled and will not be used" ) - ) @on_command(CommandType.PROFILE_CONFIGURE) async def _profile_configure_command( @@ -103,6 +115,43 @@ async def get_tokenizer(self, model: str) -> Tokenizer: ) return self.tokenizers[model] + def _create_metric_record_metadata( + self, record: ParsedResponseRecord, worker_id: str + ) -> MetricRecordMetadata: + """Create a metric record metadata based on a parsed response record.""" + + start_time_ns = record.timestamp_ns + start_perf_ns = record.start_perf_ns + + # Convert all timestamps from perf_ns to time_ns for the user + request_end_ns = compute_time_ns( + start_time_ns, + start_perf_ns, + record.responses[-1].perf_ns if record.responses else record.end_perf_ns, + ) + request_ack_ns = compute_time_ns( + start_time_ns, start_perf_ns, record.recv_start_perf_ns + ) + cancellation_time_ns = compute_time_ns( + start_time_ns, start_perf_ns, record.cancellation_perf_ns + ) + + return MetricRecordMetadata( + request_start_ns=start_time_ns, + request_ack_ns=request_ack_ns, + request_end_ns=request_end_ns, + conversation_id=record.conversation_id, + turn_index=record.turn_index, + record_processor_id=self.service_id, + benchmark_phase=record.credit_phase, + x_request_id=record.x_request_id, + x_correlation_id=record.x_correlation_id, + session_num=record.credit_num, + worker_id=worker_id, + was_cancelled=record.was_cancelled, + cancellation_time_ns=cancellation_time_ns, + ) + @on_pull_message(MessageType.INFERENCE_RESULTS) async def _on_inference_results(self, message: InferenceResultsMessage) -> None: """Handle an inference results message.""" @@ -116,16 +165,15 @@ async def _on_inference_results(self, message: InferenceResultsMessage) -> None: self.warning(f"Error processing record: {result}") else: results.append(result) + await self.records_push_client.push( MetricRecordsMessage( service_id=self.service_id, - timestamp_ns=message.record.timestamp_ns, - x_request_id=message.record.x_request_id, - x_correlation_id=message.record.x_correlation_id, - credit_phase=message.record.credit_phase, + metadata=self._create_metric_record_metadata( + message.record, message.service_id + ), results=results, error=message.record.error, - worker_id=message.service_id, ) ) diff --git a/aiperf/records/records_manager.py b/aiperf/records/records_manager.py index 76b4080be..1036e2dc4 100644 --- a/aiperf/records/records_manager.py +++ b/aiperf/records/records_manager.py @@ -20,8 +20,8 @@ MessageType, ServiceType, ) -from aiperf.common.enums.metric_enums import MetricValueTypeT from aiperf.common.enums.ui_enums import AIPerfUIType +from aiperf.common.exceptions import PostProcessorDisabled from aiperf.common.factories import ( ResultsProcessorFactory, ServiceFactory, @@ -40,6 +40,7 @@ ) from aiperf.common.messages.command_messages import RealtimeMetricsCommand from aiperf.common.messages.credit_messages import CreditPhaseSendingCompleteMessage +from aiperf.common.messages.inference_messages import MetricRecordsData from aiperf.common.mixins import PullClientMixin from aiperf.common.models import ( ErrorDetails, @@ -50,9 +51,6 @@ ) from aiperf.common.models.record_models import MetricResult from aiperf.common.protocols import ResultsProcessorProtocol, ServiceProtocol -from aiperf.common.types import MetricTagT -from aiperf.metrics.types.min_request_metric import MinRequestTimestampMetric -from aiperf.metrics.types.request_latency_metric import RequestLatencyMetric from aiperf.records.phase_completion import ( PhaseCompletionChecker, ) @@ -106,16 +104,21 @@ def __init__( self._results_processors: list[ResultsProcessorProtocol] = [] for results_processor_type in ResultsProcessorFactory.get_all_class_types(): - results_processor = ResultsProcessorFactory.create_instance( - class_type=results_processor_type, - service_id=self.service_id, - service_config=self.service_config, - user_config=self.user_config, - ) - self.debug( - f"Created results processor: {results_processor_type}: {results_processor.__class__.__name__}" - ) - self._results_processors.append(results_processor) + try: + results_processor = ResultsProcessorFactory.create_instance( + class_type=results_processor_type, + service_id=self.service_id, + service_config=self.service_config, + user_config=self.user_config, + ) + self.debug( + f"Created results processor: {results_processor_type}: {results_processor.__class__.__name__}" + ) + self._results_processors.append(results_processor) + except PostProcessorDisabled: + self.debug( + f"Results processor {results_processor_type} is disabled and will not be used" + ) @on_pull_message(MessageType.METRIC_RECORDS) async def _on_metric_records(self, message: MetricRecordsMessage) -> None: @@ -123,20 +126,22 @@ async def _on_metric_records(self, message: MetricRecordsMessage) -> None: if self.is_trace_enabled: self.trace(f"Received metric records: {message}") - if message.credit_phase != CreditPhase.PROFILING: - self.debug(lambda: f"Skipping non-profiling record: {message.credit_phase}") + if message.metadata.benchmark_phase != CreditPhase.PROFILING: + self.debug( + lambda: f"Skipping non-profiling record: {message.metadata.benchmark_phase}" + ) return - should_include_request = self._should_include_request_by_duration( - message.results - ) + record_data = message.to_data() + + should_include_request = self._should_include_request_by_duration(record_data) if should_include_request: - await self._send_results_to_results_processors(message.results) + await self._send_results_to_results_processors(record_data) - worker_id = message.worker_id + worker_id = message.metadata.worker_id - if message.valid and should_include_request: + if record_data.valid and should_include_request: # Valid record async with self.worker_stats_lock: worker_stats = self.worker_stats.setdefault( @@ -145,7 +150,7 @@ async def _on_metric_records(self, message: MetricRecordsMessage) -> None: worker_stats.processed += 1 async with self.processing_status_lock: self.processing_stats.processed += 1 - elif message.valid and not should_include_request: + elif record_data.valid and not should_include_request: # Timed out record self.debug( f"Filtered out record from worker {worker_id} - response received after duration" @@ -159,21 +164,21 @@ async def _on_metric_records(self, message: MetricRecordsMessage) -> None: worker_stats.errors += 1 async with self.processing_status_lock: self.processing_stats.errors += 1 - if message.error: + if record_data.error: async with self.error_summary_lock: - self.error_summary[message.error] = ( - self.error_summary.get(message.error, 0) + 1 + self.error_summary[record_data.error] = ( + self.error_summary.get(record_data.error, 0) + 1 ) await self._check_if_all_records_received() def _should_include_request_by_duration( - self, results: list[dict[MetricTagT, MetricValueTypeT]] + self, record_data: MetricRecordsData ) -> bool: """Determine if the request should be included based on benchmark duration. Args: - results: List of metric results for a single request + record_data: MetricRecordsData for a single request Returns: True if the request should be included, else False @@ -188,19 +193,12 @@ def _should_include_request_by_duration( # Check if any response in this request was received after the duration # If so, filter out the entire request (all-or-nothing approach) - for result_dict in results: - request_timestamp = result_dict.get(MinRequestTimestampMetric.tag) - request_latency = result_dict.get(RequestLatencyMetric.tag) - - if request_timestamp is not None and request_latency is not None: - final_response_timestamp = request_timestamp + request_latency - - if final_response_timestamp > duration_end_ns: - self.debug( - f"Filtering out timed-out request - response received " - f"{final_response_timestamp - duration_end_ns} ns after timeout" - ) - return False + if record_data.metadata.request_end_ns > duration_end_ns: + self.debug( + f"Filtering out timed-out request - response received " + f"{record_data.metadata.request_end_ns - duration_end_ns} ns after timeout" + ) + return False return True @@ -255,14 +253,13 @@ async def _check_if_all_records_received(self) -> None: await self._process_results(cancelled=cancelled) async def _send_results_to_results_processors( - self, results: list[dict[MetricTagT, MetricValueTypeT]] + self, record_data: MetricRecordsData ) -> None: """Send the results to each of the results processors.""" await asyncio.gather( *[ - results_processor.process_result(result) + results_processor.process_result(record_data) for results_processor in self._results_processors - for result in results ] ) diff --git a/aiperf/timing/credit_manager.py b/aiperf/timing/credit_manager.py index 1e0de551b..713aa1b12 100644 --- a/aiperf/timing/credit_manager.py +++ b/aiperf/timing/credit_manager.py @@ -26,6 +26,7 @@ class CreditManagerProtocol(PubClientProtocol, Protocol): async def drop_credit( self, credit_phase: CreditPhase, + credit_num: int, conversation_id: str | None = None, credit_drop_ns: int | None = None, *, diff --git a/aiperf/timing/fixed_schedule_strategy.py b/aiperf/timing/fixed_schedule_strategy.py index 98eb9e4bd..775c5d5be 100644 --- a/aiperf/timing/fixed_schedule_strategy.py +++ b/aiperf/timing/fixed_schedule_strategy.py @@ -106,12 +106,14 @@ async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None: await self.credit_manager.drop_credit( credit_phase=CreditPhase.PROFILING, + credit_num=phase_stats.sent, conversation_id=conversation_id, # We already waited, so it can be sent ASAP credit_drop_ns=None, should_cancel=should_cancel, cancel_after_ns=cancel_after_ns, ) + # NOTE: This is incremented here, as the credit_num is used up above, and needs the current value. phase_stats.sent += 1 duration_sec = (self._perf_counter_ms() - start_time_ms) / MILLIS_PER_SECOND diff --git a/aiperf/timing/request_rate_strategy.py b/aiperf/timing/request_rate_strategy.py index 8da6f4ca6..cca7258eb 100644 --- a/aiperf/timing/request_rate_strategy.py +++ b/aiperf/timing/request_rate_strategy.py @@ -69,9 +69,11 @@ async def _execute_single_phase(self, phase_stats: CreditPhaseStats) -> None: await self.credit_manager.drop_credit( credit_phase=phase_stats.type, + credit_num=phase_stats.sent, should_cancel=should_cancel, cancel_after_ns=cancel_after_ns, ) + # NOTE: This is incremented here, as the credit_num is used up above, and needs the current value. phase_stats.sent += 1 # Check if we should break out of the loop before we sleep for the next interval. # This is to ensure we don't sleep for any unnecessary time, which could cause race conditions. diff --git a/aiperf/timing/timing_manager.py b/aiperf/timing/timing_manager.py index 48b1d0ae7..8a7f3a7d7 100644 --- a/aiperf/timing/timing_manager.py +++ b/aiperf/timing/timing_manager.py @@ -175,6 +175,8 @@ async def _on_credit_return(self, message: CreditReturnMessage) -> None: async def drop_credit( self, credit_phase: CreditPhase, + credit_num: int, + *, conversation_id: str | None = None, credit_drop_ns: int | None = None, should_cancel: bool = False, @@ -186,6 +188,7 @@ async def drop_credit( message=CreditDropMessage( service_id=self.service_id, phase=credit_phase, + credit_num=credit_num, credit_drop_ns=credit_drop_ns, conversation_id=conversation_id, should_cancel=should_cancel, diff --git a/aiperf/workers/worker.py b/aiperf/workers/worker.py index 2231bf36a..e23154a43 100644 --- a/aiperf/workers/worker.py +++ b/aiperf/workers/worker.py @@ -221,11 +221,11 @@ async def _execute_single_credit_internal(self, message: CreditDropMessage) -> N turn_list.append(turn) # TODO: how do we handle errors in the middle of a conversation? record = await self._build_response_record( - conversation.session_id, - message, - turn, - turn_index, - drop_perf_ns, + conversation_id=conversation.session_id, + message=message, + turn=turn, + turn_index=turn_index, + drop_perf_ns=drop_perf_ns, ) await self._send_inference_result_message(record) resp_turn = await self._process_response(record) @@ -234,6 +234,7 @@ async def _execute_single_credit_internal(self, message: CreditDropMessage) -> N async def _retrieve_conversation_response( self, + *, service_id: str, conversation_id: str | None, phase: CreditPhase, @@ -275,6 +276,7 @@ async def _retrieve_conversation_response( async def _build_response_record( self, + *, conversation_id: str, message: CreditDropMessage, turn: Turn, @@ -291,6 +293,7 @@ async def _build_response_record( record.cancel_after_ns = message.cancel_after_ns record.x_request_id = x_request_id record.x_correlation_id = message.request_id + record.credit_num = message.credit_num # If this is the first turn, calculate the credit drop latency if turn_index == 0: record.credit_drop_latency = record.start_perf_ns - drop_perf_ns diff --git a/docs/tutorials/working-with-profile-exports.md b/docs/tutorials/working-with-profile-exports.md new file mode 100644 index 000000000..19bb1f454 --- /dev/null +++ b/docs/tutorials/working-with-profile-exports.md @@ -0,0 +1,277 @@ + + +# Working with Profile Export Files + +This guide demonstrates how to programmatically work with AIPerf benchmark output files using the native Pydantic data models. + +## Overview + +AIPerf generates multiple output formats after each benchmark run, each optimized for different analysis workflows: + +- [**`inputs.json`**](#input-dataset-json) - Complete input dataset with formatted payloads for each request +- [**`profile_export.jsonl`**](#per-request-records-jsonl) - Per-request metric records in JSON Lines format with one record per line +- [**`profile_export_aiperf.json`**](#aggregated-statistics-json) - Aggregated statistics and user configuration as a single JSON object +- [**`profile_export_aiperf.csv`**](#aggregated-statistics-csv) - Aggregated statistics in CSV format + +## Data Models + +AIPerf uses Pydantic models for type-safe parsing and validation of all benchmark output files. These models ensure data integrity and provide IDE autocompletion support. + +### Core Models + +```python +from aiperf.common.models import ( + MetricRecordInfo, + MetricRecordMetadata, + MetricValue, + ErrorDetails, + InputsFile, + SessionPayloads, +) +``` + +| Model | Description | Source | +|-------|-------------|--------| +| `MetricRecordInfo` | Complete per-request record including metadata, metrics, and error information | [record_models.py](../../aiperf/common/models/record_models.py) | +| `MetricRecordMetadata` | Request metadata: timestamps, IDs, worker identifiers, and phase information | [record_models.py](../../aiperf/common/models/record_models.py) | +| `MetricValue` | Individual metric value with associated unit of measurement | [record_models.py](../../aiperf/common/models/record_models.py) | +| `ErrorDetails` | Error information including HTTP code, error type, and descriptive message | [error_models.py](../../aiperf/common/models/error_models.py) | +| `InputsFile` | Container for all input dataset sessions with formatted payloads for each turn | [dataset_models.py](../../aiperf/common/models/dataset_models.py) | +| `SessionPayloads` | Single conversation session with session ID and list of formatted request payloads | [dataset_models.py](../../aiperf/common/models/dataset_models.py) | + +## Output File Formats + +### Input Dataset (JSON) + +**File:** `artifacts/my-run/inputs.json` + +A structured representation of all input datasets converted to the payload format used by the endpoint. + +**Structure:** +```json +{ + "data": [ + { + "session_id": "a5cdb1fe-19a3-4ed0-9e54-ed5ed6dc5578", + "payloads": [ + { ... } // formatted payload based on the endpoint type. + ] + } + ] +} +``` + +**Key fields:** +- `session_id`: Unique identifier for the conversation. This can be used to correlate inputs with results. +- `payloads`: Array of formatted request payloads (one per turn in multi-turn conversations) + +### Per-Request Records (JSONL) + +**File:** `artifacts/my-run/profile_export.jsonl` + +The JSONL output contains one record per line, for each request sent during the benchmark. Each record includes request metadata, computed metrics, and error information if the request failed. + +#### Successful Request Record + +```json +{ + "metadata": { + "session_num": 45, + "x_request_id": "7609a2e7-aa53-4ab1-98f4-f35ecafefd25", + "x_correlation_id": "32ee4f33-cfca-4cfc-988f-79b45408b909", + "conversation_id": "77aa5b0e-b305-423f-88d5-c00da1892599", + "turn_index": 0, + "request_start_ns": 1759813207532900363, + "request_ack_ns": 1759813207650730976, + "request_end_ns": 1759813207838764604, + "worker_id": "worker_359d423a", + "record_processor_id": "record_processor_1fa47cd7", + "benchmark_phase": "profiling", + "was_cancelled": false, + "cancellation_time_ns": null + }, + "metrics": { + "input_sequence_length": {"value": 550, "unit": "tokens"}, + "ttft": {"value": 255.88656799999998, "unit": "ms"}, + "request_latency": {"value": 297.52522799999997, "unit": "ms"}, + "output_token_count": {"value": 9, "unit": "tokens"}, + "ttst": {"value": 4.8984369999999995, "unit": "ms"}, + "inter_chunk_latency": {"value": [4.898437, 5.316006, 4.801489, 5.674918, 4.811467, 5.097998, 5.504797, 5.533548], "unit": "ms"}, + "output_sequence_length": {"value": 9, "unit": "tokens"}, + "inter_token_latency": {"value": 5.2048325, "unit": "ms"}, + "output_token_throughput_per_user": {"value": 192.1291415237666, "unit": "tokens/sec/user"} + }, + "error": null +} +``` + +**Metadata Fields:** +- `session_num`: Sequential request number across the entire benchmark (0-indexed). + - For single-turn conversations, this will be the request index across all requests in the benchmark. + - For multi-turn conversations, this will be the index of the user session across all sessions in the benchmark. +- `x_request_id`: Unique identifier for this specific request. This is sent to the endpoint as the X-Request-ID header. +- `x_correlation_id`: Unique identifier for the user session. This is the same for all requests in the same user session for multi-turn conversations. This is sent to the endpoint as the X-Correlation-ID header. +- `conversation_id`: ID of the input dataset conversation. This can be used to correlate inputs with results. +- `turn_index`: Position within a multi-turn conversation (0-indexed), or 0 for single-turn conversations. +- `request_start_ns`: Epoch time in nanoseconds when request was initiated by AIPerf. +- `request_ack_ns`: Epoch time in nanoseconds when server acknowledged the request. This is only applicable to streaming requests. +- `request_end_ns`: Epoch time in nanoseconds when the last response was received from the endpoint. +- `worker_id`: ID of the AIPerf worker that executed the request against the endpoint. +- `record_processor_id`: ID of the AIPerf record processor that processed the results from the server. +- `benchmark_phase`: Phase of the benchmark. Currently only `profiling` is supported. +- `was_cancelled`: Whether the request was cancelled during execution (such as when `--request-cancellation-rate` is enabled). +- `cancellation_time_ns`: Epoch time in nanoseconds when the request was cancelled (if applicable). + +**Metrics:** +See the [Complete Metrics Reference](../metrics_reference.md) page for a list of all metrics and their descriptions. Will always be null for failed requests. + +#### Failed Request Record + +```json +{ + "metadata": { + "session_num": 80, + "x_request_id": "c35e4b1b-6775-4750-b875-94cd68e5ec15", + "x_correlation_id": "77ecf78d-b848-4efc-9579-cd695c6e89c4", + "conversation_id": "9526b41d-5dbc-41a5-a353-99ae06a53bc5", + "turn_index": 0, + "request_start_ns": 1759879161119147826, + "request_ack_ns": null, + "request_end_ns": 1759879161119772754, + "worker_id": "worker_6006099d", + "record_processor_id": "record_processor_fdeeec8f", + "benchmark_phase": "profiling", + "was_cancelled": true, + "cancellation_time_ns": 1759879161119772754 + }, + "metrics": { + "error_isl": {"value": 550, "unit": "tokens"} + }, + "error": { + "code": 499, + "type": "RequestCancellationError", + "message": "Request was cancelled after 0.000 seconds" + } +} +``` + +**Error Fields:** +- `code`: HTTP status code or custom error code +- `type`: Classification of the error (e.g., timeout, cancellation, server error). Typically the python exception class name. +- `message`: Human-readable error description + + +### Aggregated Statistics (JSON) + +**File:** `artifacts/my-run/profile_export_aiperf.json` + +A single JSON object containing statistical summaries (min, max, mean, percentiles) for all metrics across the entire benchmark run, as well as the user configuration used for the benchmark. + +### Aggregated Statistics (CSV) + +**File:** `artifacts/my-run/profile_export_aiperf.csv` + +Contains the same aggregated statistics as the JSON format, but in a spreadsheet-friendly structure with one metric per row. + +## Working with Output Data + +AIPerf output files can be parsed using the native Pydantic models for type-safe data handling and analysis. + +### Synchronous Loading +```python +from aiperf.common.models import MetricRecordInfo + +def load_records(file_path: Path) -> list[MetricRecordInfo]: + """Load artifacts/my-run/profile_export.jsonl file into structured Pydantic models in sync mode.""" + records = [] + with open(file_path, encoding="utf-8") as f: + for line in f: + if line.strip(): + record = MetricRecordInfo.model_validate_json(line) + records.append(record) + return records +``` + +### Asynchronous Loading + +For large benchmark runs with thousands of requests, use async file I/O for better performance: + +```python +import aiofiles +from aiperf.common.models import MetricRecordInfo + +async def process_streaming_records_async(file_path: Path) -> None: + """Load artifacts/my-run/profile_export.jsonl file into structured Pydantic models in async mode and process the streaming records.""" + async with aiofiles.open(file_path, encoding="utf-8") as f: + async for line in f: + if line.strip(): + record = MetricRecordInfo.model_validate_json(line) + # ... Process the streaming records here ... +``` + +### Working with Input Datasets + +Load and analyze the `inputs.json` file to understand what data was sent during the benchmark: + +```python +from pathlib import Path +from aiperf.common.models import InputsFile + +def load_inputs_file(file_path: Path) -> InputsFile: + """Load inputs.json file into structured Pydantic model.""" + with open(file_path, encoding="utf-8") as f: + return InputsFile.model_validate_json(f.read()) + +inputs = load_inputs_file(Path("artifacts/my-run/inputs.json")) +``` + +### Correlating Inputs with Results + +Combine `artifacts/my-run/inputs.json` with `artifacts/my-run/profile_export.jsonl` for deeper analysis: + +```python +from pathlib import Path +from aiperf.common.models import InputsFile, MetricRecordInfo + +def correlate_inputs_and_results(inputs_path: Path, results_path: Path): + """Correlate input prompts with performance metrics.""" + # Load inputs + with open(inputs_path, encoding="utf-8") as f: + inputs = InputsFile.model_validate_json(f.read()) + + # Create session lookup + session_inputs = {session.session_id: session for session in inputs.data} + + # Process results and correlate + with open(results_path, encoding="utf-8") as f: + for line in f: + if not line.strip(): + continue + + record = MetricRecordInfo.model_validate_json(line) + + # Find corresponding input + conv_id = record.metadata.conversation_id + if conv_id not in session_inputs: + raise ValueError(f"Conversation ID {conv_id} not found in inputs") + + session = session_inputs[conv_id] + turn_idx = record.metadata.turn_index + + if turn_idx >= len(session.payloads): + raise ValueError(f"Turn index {turn_idx} is out of range for session {conv_id}") + + # Assign the raw request payload to the record, and print it out + # You can do this because AIPerf models allow extra fields to be added to the model. + payload = session.payloads[turn_idx] + record.payload = payload + print(record.model_dump_json(indent=2)) + +correlate_inputs_and_results( + Path("artifacts/my-run/inputs.json"), + Path("artifacts/my-run/profile_export.jsonl") +) +``` diff --git a/tests/conftest.py b/tests/conftest.py index 8baaaaeb5..51b664136 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,15 @@ from aiperf.common.config import EndpointConfig, ServiceConfig, UserConfig from aiperf.common.enums import CommunicationBackend, ServiceRunType from aiperf.common.messages import Message -from aiperf.common.models import Conversation, Text, Turn +from aiperf.common.models import ( + Conversation, + ParsedResponse, + ParsedResponseRecord, + RequestRecord, + Text, + TextResponseData, + Turn, +) from aiperf.common.tokenizer import Tokenizer from aiperf.common.types import MessageTypeT from tests.comms.mock_zmq import ( @@ -34,6 +42,13 @@ logging.basicConfig(level=_TRACE) +# Shared test constants for request/response records +DEFAULT_START_TIME_NS = 1_000_000 +DEFAULT_FIRST_RESPONSE_NS = 1_050_000 +DEFAULT_LAST_RESPONSE_NS = 1_100_000 +DEFAULT_INPUT_TOKENS = 5 +DEFAULT_OUTPUT_TOKENS = 2 + def pytest_addoption(parser): """Add custom command line options for pytest.""" @@ -315,3 +330,39 @@ def sample_conversations() -> dict[str, Conversation]: ), } return conversations + + +@pytest.fixture +def sample_request_record() -> RequestRecord: + """Create a sample RequestRecord for testing.""" + return RequestRecord( + conversation_id="test-conversation", + turn_index=0, + model_name="test-model", + start_perf_ns=DEFAULT_START_TIME_NS, + timestamp_ns=DEFAULT_START_TIME_NS, + end_perf_ns=DEFAULT_LAST_RESPONSE_NS, + error=None, + ) + + +@pytest.fixture +def sample_parsed_record(sample_request_record: RequestRecord) -> ParsedResponseRecord: + """Create a valid ParsedResponseRecord for testing.""" + responses = [ + ParsedResponse( + perf_ns=DEFAULT_FIRST_RESPONSE_NS, + data=TextResponseData(text="Hello"), + ), + ParsedResponse( + perf_ns=DEFAULT_LAST_RESPONSE_NS, + data=TextResponseData(text=" world"), + ), + ] + + return ParsedResponseRecord( + request=sample_request_record, + responses=responses, + input_token_count=DEFAULT_INPUT_TOKENS, + output_token_count=DEFAULT_OUTPUT_TOKENS, + ) diff --git a/tests/metrics/test_input_sequence_length_metric.py b/tests/metrics/test_input_sequence_length_metric.py index a6039bdce..29613441b 100644 --- a/tests/metrics/test_input_sequence_length_metric.py +++ b/tests/metrics/test_input_sequence_length_metric.py @@ -7,7 +7,9 @@ from aiperf.common.exceptions import NoMetricValue from aiperf.metrics.metric_dicts import MetricRecordDict, MetricResultsDict from aiperf.metrics.types.input_sequence_length_metric import ( + ErrorInputSequenceLengthMetric, InputSequenceLengthMetric, + TotalErrorInputSequenceLengthMetric, TotalInputSequenceLengthMetric, ) from tests.metrics.conftest import ( @@ -83,3 +85,62 @@ def test_metric_metadata(self): assert TotalInputSequenceLengthMetric.has_flags(MetricFlags.LARGER_IS_BETTER) assert TotalInputSequenceLengthMetric.has_flags(MetricFlags.NO_CONSOLE) assert TotalInputSequenceLengthMetric.missing_flags(MetricFlags.INTERNAL) + + +class TestErrorInputSequenceLengthMetric: + def test_error_isl_basic(self): + """Test basic error input sequence length extraction""" + from aiperf.common.models import ErrorDetails + + record = create_record( + input_tokens=15, + error=ErrorDetails(code=500, message="Error", type="ServerError"), + ) + + metric = ErrorInputSequenceLengthMetric() + result = metric.parse_record(record, MetricRecordDict()) + assert result == 15 + + def test_error_isl_none_raises(self): + """Test handling of None input tokens raises error""" + from aiperf.common.models import ErrorDetails + + record = create_record( + input_tokens=None, + error=ErrorDetails(code=500, message="Error", type="ServerError"), + ) + + metric = ErrorInputSequenceLengthMetric() + with pytest.raises(NoMetricValue): + metric.parse_record(record, MetricRecordDict()) + + def test_error_isl_metadata(self): + """Test that ErrorInputSequenceLengthMetric has correct flags""" + assert ErrorInputSequenceLengthMetric.tag == "error_isl" + assert ErrorInputSequenceLengthMetric.has_flags(MetricFlags.ERROR_ONLY) + assert ErrorInputSequenceLengthMetric.has_flags(MetricFlags.NO_CONSOLE) + + +class TestTotalErrorInputSequenceLengthMetric: + @pytest.mark.parametrize( + "values, expected_sum", + [ + ([10, 20, 30], 60), + ([100], 100), + ([], 0), + ], + ) + def test_sum_calculation(self, values, expected_sum): + """Test that TotalErrorInputSequenceLengthMetric correctly sums error input tokens""" + metric = TotalErrorInputSequenceLengthMetric() + metric_results = MetricResultsDict() + metric_results[ErrorInputSequenceLengthMetric.tag] = create_metric_array(values) + + result = metric.derive_value(metric_results) + assert result == expected_sum + + def test_metric_metadata(self): + """Test that TotalErrorInputSequenceLengthMetric has correct metadata""" + assert TotalErrorInputSequenceLengthMetric.tag == "total_error_isl" + assert TotalErrorInputSequenceLengthMetric.has_flags(MetricFlags.ERROR_ONLY) + assert TotalErrorInputSequenceLengthMetric.has_flags(MetricFlags.NO_CONSOLE) diff --git a/tests/parsers/test_inference_result_parser.py b/tests/parsers/test_inference_result_parser.py index 6e3ec21fd..3b4b66e21 100644 --- a/tests/parsers/test_inference_result_parser.py +++ b/tests/parsers/test_inference_result_parser.py @@ -7,13 +7,14 @@ from aiperf.common.config import EndpointConfig, InputConfig, ServiceConfig, UserConfig from aiperf.common.messages import ConversationTurnResponseMessage -from aiperf.common.models import RequestRecord, Text, Turn +from aiperf.common.models import ErrorDetails, RequestRecord, Text, Turn from aiperf.common.tokenizer import Tokenizer from aiperf.parsers.inference_result_parser import InferenceResultParser @pytest.fixture def mock_tokenizer(): + """Mock tokenizer that returns token count based on word count.""" tokenizer = MagicMock(spec=Tokenizer) tokenizer.encode.side_effect = lambda x: list(range(len(x.split()))) return tokenizer @@ -21,6 +22,7 @@ def mock_tokenizer(): @pytest.fixture def sample_turn(): + """Sample turn with 4 text strings (8 words total).""" return Turn( role="user", texts=[ @@ -39,13 +41,34 @@ def mock_turn_response(sample_turn): @pytest.fixture -def sample_request_record(): - return RequestRecord(conversation_id="cid", turn_index=0) +def parser(mock_turn_response): + """Create a parser with mocked communications layer.""" + mock_client = MagicMock() + mock_client.request = AsyncMock(return_value=mock_turn_response) + mock_comms = MagicMock() + mock_comms.create_request_client.return_value = mock_client -@pytest.fixture -def parser(mock_turn_response): - with patch.object(InferenceResultParser, "__init__", lambda self, **kwargs: None): + def mock_communication_init(self, **_kwargs): + self.comms = mock_comms + # Add logger methods + for method in [ + "trace_or_debug", + "debug", + "info", + "warning", + "error", + "exception", + ]: + setattr(self, method, MagicMock()) + + with ( + patch( + "aiperf.common.mixins.CommunicationMixin.__init__", mock_communication_init + ), + patch("aiperf.clients.model_endpoint_info.ModelEndpointInfo.from_user_config"), + patch("aiperf.common.factories.ResponseExtractorFactory.create_instance"), + ): parser = InferenceResultParser( service_config=ServiceConfig(), user_config=UserConfig( @@ -53,18 +76,56 @@ def parser(mock_turn_response): input=InputConfig(), ), ) - parser.id = "test-parser" - parser.conversation_request_client = MagicMock() - parser.conversation_request_client.request = AsyncMock( - return_value=mock_turn_response - ) return parser +def create_request_record(has_error=False, is_invalid=False, model_name="test-model"): + """Helper to create request records with various states.""" + record = RequestRecord(conversation_id="cid", turn_index=0, model_name=model_name) + + if has_error: + record.error = ErrorDetails( + code=500, message="Server error", type="ServerError" + ) + if is_invalid: + record._valid = False + + return record + + +def setup_parser_for_error_tests(parser, mock_tokenizer, sample_turn): + """Common setup for error record tests.""" + parser.get_tokenizer = AsyncMock(return_value=mock_tokenizer) + parser.get_turn = AsyncMock(return_value=sample_turn) + parser.extractor = MagicMock() + + @pytest.mark.asyncio -async def test_compute_input_token_count(parser, sample_request_record, mock_tokenizer): - result = await parser.compute_input_token_count( - sample_request_record, mock_tokenizer - ) - assert result == 8 # 4 strings × 2 words each - assert mock_tokenizer.encode.call_count == 2 +@pytest.mark.parametrize( + "record_type", + ["error", "invalid", "processing_exception"], +) +async def test_error_records_compute_input_tokens( + parser, mock_tokenizer, sample_turn, record_type +): + """Test that input_token_count is computed for all error scenarios.""" + if record_type == "error": + record = create_request_record(has_error=True) + elif record_type == "invalid": + record = create_request_record(is_invalid=True) + else: # processing_exception + record = create_request_record() + + setup_parser_for_error_tests(parser, mock_tokenizer, sample_turn) + + if record_type == "processing_exception": + parser.extractor.extract_response_data = AsyncMock( + side_effect=ValueError("Processing failed") + ) + + result = await parser.parse_request_record(record) + + assert result.request == record + assert result.input_token_count == 8 + assert result.responses == [] + assert record.error is not None diff --git a/tests/post_processors/conftest.py b/tests/post_processors/conftest.py index ac5a1e8e4..0c5816e15 100644 --- a/tests/post_processors/conftest.py +++ b/tests/post_processors/conftest.py @@ -7,23 +7,27 @@ import pytest from aiperf.common.config import EndpointConfig, UserConfig -from aiperf.common.enums import EndpointType +from aiperf.common.enums import CreditPhase, EndpointType, MessageType +from aiperf.common.enums.metric_enums import MetricValueTypeT +from aiperf.common.messages import MetricRecordsMessage from aiperf.common.models import ( ErrorDetails, - ParsedResponse, ParsedResponseRecord, RequestRecord, - TextResponseData, ) +from aiperf.common.models.record_models import MetricRecordMetadata +from aiperf.common.types import MetricTagT from aiperf.metrics.base_metric import BaseMetric from aiperf.post_processors.metric_results_processor import MetricResultsProcessor - -# Constants for test data -DEFAULT_START_TIME_NS = 1_000_000 -DEFAULT_FIRST_RESPONSE_NS = 1_050_000 -DEFAULT_LAST_RESPONSE_NS = 1_100_000 -DEFAULT_INPUT_TOKENS = 5 -DEFAULT_OUTPUT_TOKENS = 2 +from tests.conftest import ( # noqa: F401 + DEFAULT_FIRST_RESPONSE_NS, + DEFAULT_INPUT_TOKENS, + DEFAULT_LAST_RESPONSE_NS, + DEFAULT_OUTPUT_TOKENS, + DEFAULT_START_TIME_NS, + sample_parsed_record, + sample_request_record, +) @pytest.fixture @@ -37,42 +41,6 @@ def mock_user_config() -> UserConfig: ) -@pytest.fixture -def sample_request_record() -> RequestRecord: - """Create a sample RequestRecord for testing.""" - return RequestRecord( - conversation_id="test-conversation", - turn_index=0, - model_name="test-model", - start_perf_ns=DEFAULT_START_TIME_NS, - timestamp_ns=DEFAULT_START_TIME_NS, - end_perf_ns=DEFAULT_LAST_RESPONSE_NS, - error=None, - ) - - -@pytest.fixture -def sample_parsed_record(sample_request_record: RequestRecord) -> ParsedResponseRecord: - """Create a valid ParsedResponseRecord for testing.""" - responses = [ - ParsedResponse( - perf_ns=DEFAULT_FIRST_RESPONSE_NS, - data=TextResponseData(text="Hello"), - ), - ParsedResponse( - perf_ns=DEFAULT_LAST_RESPONSE_NS, - data=TextResponseData(text=" world"), - ), - ] - - return ParsedResponseRecord( - request=sample_request_record, - responses=responses, - input_token_count=DEFAULT_INPUT_TOKENS, - output_token_count=DEFAULT_OUTPUT_TOKENS, - ) - - @pytest.fixture def error_parsed_record() -> ParsedResponseRecord: """Create an error ParsedResponseRecord for testing.""" @@ -191,3 +159,90 @@ def mock_metric_registry(monkeypatch): ) return mock_registry + + +def create_metric_metadata( + session_num: int = 0, + conversation_id: str | None = None, + turn_index: int = 0, + request_start_ns: int = 1_000_000_000, + request_ack_ns: int | None = None, + request_end_ns: int = 1_100_000_000, + worker_id: str = "worker-1", + record_processor_id: str = "processor-1", + benchmark_phase: CreditPhase = CreditPhase.PROFILING, + x_request_id: str | None = None, + x_correlation_id: str | None = None, +) -> MetricRecordMetadata: + """ + Create a MetricRecordMetadata object with sensible defaults. + + Args: + session_num: Sequential session number in the benchmark + conversation_id: Conversation ID (optional) + turn_index: Turn index in conversation + request_start_ns: Request start timestamp in nanoseconds + request_ack_ns: Request acknowledgement timestamp in nanoseconds (optional) + request_end_ns: Request end timestamp in nanoseconds (optional) + worker_id: Worker ID + record_processor_id: Record processor ID + benchmark_phase: Benchmark phase (warmup or profiling) + x_request_id: X-Request-ID header value (optional) + x_correlation_id: X-Correlation-ID header value (optional) + + Returns: + MetricRecordMetadata object + """ + return MetricRecordMetadata( + session_num=session_num, + conversation_id=conversation_id, + turn_index=turn_index, + request_start_ns=request_start_ns, + request_ack_ns=request_ack_ns, + request_end_ns=request_end_ns, + worker_id=worker_id, + record_processor_id=record_processor_id, + benchmark_phase=benchmark_phase, + x_request_id=x_request_id, + x_correlation_id=x_correlation_id, + ) + + +def create_metric_records_message( + service_id: str = "test-processor", + results: list[dict[MetricTagT, MetricValueTypeT]] | None = None, + error: ErrorDetails | None = None, + metadata: MetricRecordMetadata | None = None, + x_request_id: str | None = None, + **metadata_kwargs, +) -> MetricRecordsMessage: + """ + Create a MetricRecordsMessage with sensible defaults. + + Args: + service_id: Service ID + results: List of metric result dictionaries + error: Error details if any + metadata: Pre-built metadata, or None to build from kwargs + x_request_id: Record ID (will be set as x_request_id in metadata if provided) + **metadata_kwargs: Arguments to pass to create_metric_metadata if metadata is None + + Returns: + MetricRecordsMessage object + """ + if results is None: + results = [] + + if metadata is None: + # If x_request_id is provided, use it as x_request_id + if x_request_id is not None and "x_request_id" not in metadata_kwargs: + metadata_kwargs["x_request_id"] = x_request_id + metadata = create_metric_metadata(**metadata_kwargs) + + return MetricRecordsMessage( + message_type=MessageType.METRIC_RECORDS, + service_id=service_id, + metadata=metadata, + results=results, + error=error, + ) diff --git a/tests/post_processors/test_metric_results_processor.py b/tests/post_processors/test_metric_results_processor.py index f98b82dc2..21e2a5d07 100644 --- a/tests/post_processors/test_metric_results_processor.py +++ b/tests/post_processors/test_metric_results_processor.py @@ -9,11 +9,12 @@ from aiperf.common.enums import MetricType from aiperf.common.exceptions import NoMetricValue from aiperf.common.models import MetricResult -from aiperf.metrics.metric_dicts import MetricArray, MetricRecordDict, MetricResultsDict +from aiperf.metrics.metric_dicts import MetricArray, MetricResultsDict from aiperf.metrics.types.request_count_metric import RequestCountMetric from aiperf.metrics.types.request_latency_metric import RequestLatencyMetric from aiperf.metrics.types.request_throughput_metric import RequestThroughputMetric from aiperf.post_processors.metric_results_processor import MetricResultsProcessor +from tests.post_processors.conftest import create_metric_records_message class TestMetricResultsProcessor: @@ -39,14 +40,23 @@ async def test_process_result_record_metric( processor = MetricResultsProcessor(mock_user_config) processor._tags_to_types = {"test_record": MetricType.RECORD} - await processor.process_result(MetricRecordDict({"test_record": 42.0})) + message = create_metric_records_message( + x_request_id="test-1", + results=[{"test_record": 42.0}], + ) + await processor.process_result(message.to_data()) assert "test_record" in processor._results assert isinstance(processor._results["test_record"], MetricArray) assert list(processor._results["test_record"].data) == [42.0] # New data should expand the array - await processor.process_result(MetricRecordDict({"test_record": 84.0})) + message2 = create_metric_records_message( + x_request_id="test-2", + request_start_ns=1_000_000_001, + results=[{"test_record": 84.0}], + ) + await processor.process_result(message2.to_data()) assert list(processor._results["test_record"].data) == [42.0, 84.0] @pytest.mark.asyncio @@ -58,9 +68,11 @@ async def test_process_result_record_metric_list_values( processor._tags_to_types = {"test_record": MetricType.RECORD} # Process list of values - await processor.process_result( - MetricRecordDict({"test_record": [10.0, 20.0, 30.0]}) + message = create_metric_records_message( + x_request_id="test-1", + results=[{"test_record": [10.0, 20.0, 30.0]}], ) + await processor.process_result(message.to_data()) assert "test_record" in processor._results assert isinstance(processor._results["test_record"], MetricArray) @@ -76,10 +88,19 @@ async def test_process_result_aggregate_metric( processor._instances_map = {RequestCountMetric.tag: RequestCountMetric()} # Process two values and ensure they are accumulated - await processor.process_result(MetricRecordDict({RequestCountMetric.tag: 5})) + message1 = create_metric_records_message( + x_request_id="test-1", + results=[{RequestCountMetric.tag: 5}], + ) + await processor.process_result(message1.to_data()) assert processor._results[RequestCountMetric.tag] == 5 - await processor.process_result(MetricRecordDict({RequestCountMetric.tag: 3})) + message2 = create_metric_records_message( + x_request_id="test-2", + request_start_ns=1_000_000_001, + results=[{RequestCountMetric.tag: 3}], + ) + await processor.process_result(message2.to_data()) assert processor._results[RequestCountMetric.tag] == 8 @pytest.mark.asyncio diff --git a/tests/post_processors/test_post_processor_integration.py b/tests/post_processors/test_post_processor_integration.py index 8c1d7e9e4..70265c514 100644 --- a/tests/post_processors/test_post_processor_integration.py +++ b/tests/post_processors/test_post_processor_integration.py @@ -9,7 +9,7 @@ from aiperf.common.config import UserConfig from aiperf.common.constants import NANOS_PER_SECOND from aiperf.common.models import ParsedResponseRecord -from aiperf.metrics.metric_dicts import MetricArray, MetricRecordDict +from aiperf.metrics.metric_dicts import MetricArray from aiperf.metrics.types.benchmark_duration_metric import BenchmarkDurationMetric from aiperf.metrics.types.error_request_count import ErrorRequestCountMetric from aiperf.metrics.types.request_count_metric import RequestCountMetric @@ -18,6 +18,7 @@ from aiperf.post_processors.metric_record_processor import MetricRecordProcessor from aiperf.post_processors.metric_results_processor import MetricResultsProcessor from tests.post_processors.conftest import ( + create_metric_records_message, create_results_processor_with_metrics, setup_mock_registry_sequences, ) @@ -41,11 +42,12 @@ async def test_record_to_results_data_flow( results_processor = create_results_processor_with_metrics( mock_user_config, RequestLatencyMetric, RequestCountMetric ) - test_record_dict = MetricRecordDict( - {RequestLatencyMetric.tag: 100.0, RequestCountMetric.tag: 1} + message = create_metric_records_message( + x_request_id="test-1", + results=[{RequestLatencyMetric.tag: 100.0, RequestCountMetric.tag: 1}], ) - await results_processor.process_result(test_record_dict) + await results_processor.process_result(message.to_data()) assert RequestLatencyMetric.tag in results_processor._results assert isinstance( @@ -67,12 +69,14 @@ async def test_multiple_batches_accumulation( mock_user_config, RequestLatencyMetric ) - batches = [ - MetricRecordDict({RequestLatencyMetric.tag: value}) - for value in TEST_LATENCY_VALUES - ] - for batch in batches: - await results_processor.process_result(batch) + for idx, value in enumerate(TEST_LATENCY_VALUES): + message = create_metric_records_message( + x_request_id=f"test-{idx}", + request_start_ns=1_000_000_000 + idx, + x_correlation_id=f"test-correlation-{idx}", + results=[{RequestLatencyMetric.tag: value}], + ) + await results_processor.process_result(message.to_data()) assert RequestLatencyMetric.tag in results_processor._results accumulated_data = list( diff --git a/tests/post_processors/test_record_export_results_processor.py b/tests/post_processors/test_record_export_results_processor.py new file mode 100644 index 000000000..872954491 --- /dev/null +++ b/tests/post_processors/test_record_export_results_processor.py @@ -0,0 +1,573 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path +from unittest.mock import Mock, patch + +import orjson +import pytest + +from aiperf.common.config import ( + EndpointConfig, + OutputConfig, + ServiceConfig, + UserConfig, +) +from aiperf.common.enums import CreditPhase, EndpointType +from aiperf.common.enums.data_exporter_enums import ExportLevel +from aiperf.common.exceptions import PostProcessorDisabled +from aiperf.common.messages import MetricRecordsMessage +from aiperf.common.models.record_models import ( + MetricRecordInfo, + MetricRecordMetadata, + MetricValue, +) +from aiperf.metrics.metric_dicts import MetricRecordDict +from aiperf.post_processors.record_export_results_processor import ( + RecordExportResultsProcessor, +) +from tests.post_processors.conftest import create_metric_records_message + + +@pytest.fixture +def tmp_artifact_dir(tmp_path: Path) -> Path: + """Create a temporary artifact directory for testing.""" + artifact_dir = tmp_path / "artifacts" + artifact_dir.mkdir(parents=True, exist_ok=True) + return artifact_dir + + +@pytest.fixture +def user_config_records_export(tmp_artifact_dir: Path) -> UserConfig: + """Create a UserConfig with RECORDS export level.""" + return UserConfig( + endpoint=EndpointConfig( + model_names=["test-model"], + type=EndpointType.CHAT, + ), + output=OutputConfig( + artifact_directory=tmp_artifact_dir, + ), + ) + + +@pytest.fixture +def service_config() -> ServiceConfig: + """Create a ServiceConfig for testing.""" + return ServiceConfig() + + +@pytest.fixture +def sample_metric_records_message(): + """Create a sample MetricRecordsMessage for testing.""" + return create_metric_records_message( + service_id="processor-1", + x_request_id="test-record-123", + conversation_id="conv-456", + x_correlation_id="test-correlation-123", + results=[ + {"request_latency_ns": 1_000_000, "output_token_count": 10}, + {"ttft_ns": 500_000}, + ], + ) + + +class TestRecordExportResultsProcessorInitialization: + """Test RecordExportResultsProcessor initialization.""" + + @pytest.mark.parametrize( + "export_level, raise_exception", + [ + (ExportLevel.SUMMARY, True), + (ExportLevel.RECORDS, False), + (ExportLevel.RAW, False), + ], + ) + def test_init_with_export_level( + self, + monkeypatch, + export_level: ExportLevel, + raise_exception: bool, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test initialization with various export levels enable or disable the processor.""" + monkeypatch.setattr( + type(user_config_records_export.output), + "export_level", + property(lambda self: export_level), + ) + if raise_exception: + with pytest.raises(PostProcessorDisabled): + _ = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + else: + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + assert processor.record_count == 0 + assert processor.output_file.name == "profile_export.jsonl" + assert processor.output_file.parent.exists() + + def test_init_with_raw_export_level( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test initialization with RAW export level enables the processor.""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + assert processor.record_count == 0 + assert processor.output_file.name == "profile_export.jsonl" + assert processor.output_file.parent.exists() + + def test_init_creates_output_directory( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that initialization creates the output directory.""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + assert processor.output_file.parent.exists() + assert processor.output_file.parent.is_dir() + + def test_init_clears_existing_file( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that initialization clears existing output file.""" + # Create a file with existing content + output_file = ( + user_config_records_export.output.artifact_directory + / "profile_export.jsonl" + ) + output_file.parent.mkdir(parents=True, exist_ok=True) + output_file.write_text("existing content\n") + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + # File should be cleared or not exist + if processor.output_file.exists(): + content = processor.output_file.read_text() + assert content == "" + else: + assert not processor.output_file.exists() + + def test_init_sets_show_internal_in_dev_mode( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that show_internal is set based on dev mode.""" + with patch( + "aiperf.post_processors.record_export_results_processor.AIPERF_DEV_MODE", + True, + ): + service_config.developer.show_internal_metrics = True + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + assert processor.show_internal is True + + +class TestRecordExportResultsProcessorProcessResult: + """Test RecordExportResultsProcessor process_result method.""" + + @pytest.mark.asyncio + async def test_process_result_writes_valid_data( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that process_result writes valid data to file.""" + mock_display_dict = { + "request_latency": MetricValue(value=1.0, unit="ms"), + "output_token_count": MetricValue(value=10, unit="tokens"), + } + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + await processor._open_file() + + with patch.object( + MetricRecordDict, + "to_display_dict", + return_value=mock_display_dict, + ): + await processor.process_result(sample_metric_records_message.to_data()) + + await processor._shutdown() + + assert processor.record_count == 1 + assert processor.output_file.exists() + + with open(processor.output_file, "rb") as f: + lines = f.readlines() + + assert len(lines) == 1 + record_dict = orjson.loads(lines[0]) + record = MetricRecordInfo.model_validate(record_dict) + assert record.metadata.x_request_id == "test-record-123" + assert record.metadata.conversation_id == "conv-456" + assert record.metadata.turn_index == 0 + assert record.metadata.worker_id == "worker-1" + assert record.metadata.record_processor_id == "processor-1" + assert record.metadata.benchmark_phase == CreditPhase.PROFILING + assert record.metadata.request_start_ns == 1_000_000_000 + assert record.error is None + assert "request_latency" in record.metrics + assert "output_token_count" in record.metrics + + @pytest.mark.asyncio + async def test_process_result_with_empty_display_metrics( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that process_result skips records with empty display metrics.""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + # Mock to_display_dict to return empty dict + with patch.object(MetricRecordDict, "to_display_dict", return_value={}): + await processor.process_result(sample_metric_records_message.to_data()) + + # Should not write anything since display_metrics is empty + assert processor.record_count == 0 + if processor.output_file.exists(): + content = processor.output_file.read_text() + assert content == "" + + @pytest.mark.asyncio + async def test_process_result_handles_errors_gracefully( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that errors during processing don't raise exceptions.""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + # Mock to_display_dict to raise an exception + with ( + patch.object( + MetricRecordDict, "to_display_dict", side_effect=Exception("Test error") + ), + patch.object(processor, "error") as mock_error, + ): + # Should not raise + await processor.process_result(sample_metric_records_message.to_data()) + + # Should log the error + assert mock_error.call_count >= 1 + + # Record count should not increment + assert processor.record_count == 0 + + @pytest.mark.asyncio + async def test_process_result_multiple_messages( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test processing multiple messages accumulates records.""" + mock_display_dict = { + "request_latency": MetricValue(value=1.0, unit="ms"), + } + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + await processor._open_file() + + with patch.object( + MetricRecordDict, "to_display_dict", return_value=mock_display_dict + ): + for i in range(5): + message = create_metric_records_message( + x_request_id=f"record-{i}", + conversation_id=f"conv-{i}", + turn_index=i, + request_start_ns=1_000_000_000 + i, + results=[{"metric1": 100}, {"metric2": 200}], + ) + await processor.process_result(message.to_data()) + + await processor._shutdown() + + assert processor.record_count == 5 + assert processor.output_file.exists() + + with open(processor.output_file, "rb") as f: + lines = f.readlines() + + assert len(lines) == 5 + + for line in lines: + record_dict = orjson.loads(line) + record = MetricRecordInfo.model_validate(record_dict) + assert isinstance(record, MetricRecordInfo) + assert record.metadata.x_request_id.startswith("record-") # type: ignore[union-attr] + assert "request_latency" in record.metrics + + +class TestRecordExportResultsProcessorFileFormat: + """Test RecordExportResultsProcessor file format.""" + + @pytest.mark.asyncio + async def test_output_is_valid_jsonl( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that output file is valid JSONL format.""" + mock_display_dict = {"test_metric": MetricValue(value=42, unit="ms")} + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + await processor._open_file() + + with patch.object( + MetricRecordDict, "to_display_dict", return_value=mock_display_dict + ): + await processor.process_result(sample_metric_records_message.to_data()) + + await processor._shutdown() + + with open(processor.output_file, "rb") as f: + lines = f.readlines() + + for line in lines: + if line.strip(): + record_dict = orjson.loads(line) + assert isinstance(record_dict, dict) + record = MetricRecordInfo.model_validate(record_dict) + assert isinstance(record, MetricRecordInfo) + assert line.endswith(b"\n") + + @pytest.mark.asyncio + async def test_record_structure_is_complete( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that each record has the expected structure.""" + mock_display_dict = {"test_metric": MetricValue(value=42, unit="ms")} + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + await processor._open_file() + + with patch.object( + MetricRecordDict, "to_display_dict", return_value=mock_display_dict + ): + await processor.process_result(sample_metric_records_message.to_data()) + + await processor._shutdown() + + with open(processor.output_file, "rb") as f: + record_dict = orjson.loads(f.readline()) + + record = MetricRecordInfo.model_validate(record_dict) + + assert isinstance(record.metadata, MetricRecordMetadata) + assert isinstance(record.metrics, dict) + + assert record.metadata.conversation_id is not None + assert isinstance(record.metadata.turn_index, int) + assert isinstance(record.metadata.request_start_ns, int) + assert isinstance(record.metadata.worker_id, str) + assert isinstance(record.metadata.record_processor_id, str) + assert isinstance(record.metadata.benchmark_phase, CreditPhase) + + assert "test_metric" in record.metrics + assert isinstance(record.metrics["test_metric"], MetricValue) + assert record.metrics["test_metric"].value == 42 + assert record.metrics["test_metric"].unit == "ms" + + +class TestRecordExportResultsProcessorLogging: + """Test RecordExportResultsProcessor logging behavior.""" + + @pytest.mark.asyncio + async def test_periodic_debug_logging( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + mock_metric_registry: Mock, + ): + """Test that debug logging occurs when buffer is flushed.""" + mock_display_dict = {"test_metric": MetricValue(value=42, unit="ms")} + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + await processor._open_file() + + with ( + patch.object( + MetricRecordDict, "to_display_dict", return_value=mock_display_dict + ), + patch.object(processor, "debug") as mock_debug, + ): + for i in range(processor._batch_size): + message = create_metric_records_message( + x_request_id=f"record-{i}", + conversation_id=f"conv-{i}", + turn_index=i, + request_start_ns=1_000_000_000 + i, + results=[{"metric1": 100}, {"metric2": 200}], + ) + await processor.process_result(message.to_data()) + + assert mock_debug.call_count >= 1 + + @pytest.mark.asyncio + async def test_error_logging_on_write_failure( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that errors are logged when write fails.""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + with ( + patch.object( + MetricRecordDict, "to_display_dict", side_effect=OSError("Disk full") + ), + patch.object(processor, "error") as mock_error, + ): + await processor.process_result(sample_metric_records_message.to_data()) + + assert mock_error.call_count >= 1 + call_args = str(mock_error.call_args_list[0]) + assert "Failed to write record metrics" in call_args + + +class TestRecordExportResultsProcessorShutdown: + """Test RecordExportResultsProcessor shutdown behavior.""" + + @pytest.mark.asyncio + async def test_shutdown_logs_statistics( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + sample_metric_records_message: MetricRecordsMessage, + mock_metric_registry: Mock, + ): + """Test that shutdown logs final statistics.""" + mock_display_dict = {"test_metric": MetricValue(value=42, unit="ms")} + + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + with patch.object( + MetricRecordDict, "to_display_dict", return_value=mock_display_dict + ): + # Process some records + for i in range(3): + message = create_metric_records_message( + x_request_id=f"record-{i}", + conversation_id=f"conv-{i}", + turn_index=i, + request_start_ns=1_000_000_000 + i, + results=[{"metric1": 100}], + ) + await processor.process_result(message.to_data()) + + with patch.object(processor, "info") as mock_info: + await processor._shutdown() + + mock_info.assert_called_once() + call_args = str(mock_info.call_args) + assert "3 records written" in call_args or "3" in call_args + + +class TestRecordExportResultsProcessorSummarize: + """Test RecordExportResultsProcessor summarize method.""" + + @pytest.mark.asyncio + async def test_summarize_returns_empty_list( + self, + user_config_records_export: UserConfig, + service_config: ServiceConfig, + ): + """Test that summarize returns an empty list (no aggregation needed).""" + processor = RecordExportResultsProcessor( + service_id="records-manager", + service_config=service_config, + user_config=user_config_records_export, + ) + + result = await processor.summarize() + + assert result == [] + assert isinstance(result, list) diff --git a/tests/records/conftest.py b/tests/records/conftest.py new file mode 100644 index 000000000..8b4389c7f --- /dev/null +++ b/tests/records/conftest.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Shared fixtures for records tests.""" + +# Import shared constants and fixtures from root conftest +from tests.conftest import ( # noqa: F401 + DEFAULT_FIRST_RESPONSE_NS, + DEFAULT_INPUT_TOKENS, + DEFAULT_LAST_RESPONSE_NS, + DEFAULT_OUTPUT_TOKENS, + DEFAULT_START_TIME_NS, + sample_parsed_record, + sample_request_record, +) diff --git a/tests/records/test_records_filtering.py b/tests/records/test_records_filtering.py index 257ae7522..dec511037 100644 --- a/tests/records/test_records_filtering.py +++ b/tests/records/test_records_filtering.py @@ -6,72 +6,121 @@ import pytest from aiperf.common.constants import NANOS_PER_SECOND +from aiperf.common.enums import CreditPhase +from aiperf.common.messages.inference_messages import MetricRecordsData +from aiperf.common.models.record_models import MetricRecordMetadata +from aiperf.common.types import MetricTagT from aiperf.metrics.types.min_request_metric import MinRequestTimestampMetric from aiperf.metrics.types.request_latency_metric import RequestLatencyMetric +# Constants +START_TIME = 1000000000 + + +# Helper functions +def create_mock_records_manager( + start_time_ns: int, + expected_duration_sec: float | None, + grace_period_sec: float = 0.0, +) -> MagicMock: + """Create a mock RecordsManager instance for testing filtering logic.""" + instance = MagicMock() + instance.expected_duration_sec = expected_duration_sec + instance.start_time_ns = start_time_ns + instance.user_config.loadgen.benchmark_grace_period = grace_period_sec + instance.debug = MagicMock() + return instance + + +def create_metric_record_data( + request_start_ns: int, + request_end_ns: int, + metrics: dict[MetricTagT, int | float] | None = None, +) -> MetricRecordsData: + """Create a MetricRecordsData object with sensible defaults for testing.""" + return MetricRecordsData( + metadata=MetricRecordMetadata( + session_num=0, + conversation_id="test", + turn_index=0, + request_start_ns=request_start_ns, + request_end_ns=request_end_ns, + worker_id="worker-1", + record_processor_id="processor-1", + benchmark_phase=CreditPhase.PROFILING, + ), + metrics=metrics or {}, + ) + class TestRecordsManagerFiltering: - """Test the records manager's filtering logic .""" + """Test the records manager's filtering logic.""" def test_should_include_request_by_duration_no_duration_benchmark(self): """Test that request-count benchmarks always include all requests.""" from aiperf.records.records_manager import RecordsManager - instance = MagicMock() - instance.expected_duration_sec = None + instance = create_mock_records_manager( + start_time_ns=0, + expected_duration_sec=None, + ) - results = [ - { + record_data = create_metric_record_data( + request_start_ns=999999999999999, + request_end_ns=999999999999999, + metrics={ MinRequestTimestampMetric.tag: 999999999999999, RequestLatencyMetric.tag: 999999999999999, - } - ] + }, + ) - result = RecordsManager._should_include_request_by_duration(instance, results) + result = RecordsManager._should_include_request_by_duration( + instance, record_data + ) assert result is True def test_should_include_request_within_duration_no_grace_period(self): """Test filtering with zero grace period - only duration matters.""" from aiperf.records.records_manager import RecordsManager - start_time = 1000000000 - duration_sec = 2.0 - grace_period_sec = 0.0 - - instance = MagicMock() - instance.expected_duration_sec = duration_sec - instance.start_time_ns = start_time - instance.user_config.loadgen.benchmark_grace_period = grace_period_sec - instance.debug = MagicMock() # Mock debug method + instance = create_mock_records_manager( + start_time_ns=START_TIME, + expected_duration_sec=2.0, + grace_period_sec=0.0, + ) # Request that completes exactly at duration end should be included - results_at_duration = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.5 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 0.5 * NANOS_PER_SECOND - ), # Completes exactly at 2.0s - } - ] + request_start = START_TIME + int(1.5 * NANOS_PER_SECOND) + request_latency = int(0.5 * NANOS_PER_SECOND) + record_at_duration = create_metric_record_data( + request_start_ns=request_start, + request_end_ns=request_start + request_latency, # Completes exactly at 2.0s + metrics={ + MinRequestTimestampMetric.tag: request_start, + RequestLatencyMetric.tag: request_latency, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_at_duration + instance, record_at_duration ) is True ) # Request that completes after duration should be excluded - results_after_duration = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.5 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 0.6 * NANOS_PER_SECOND - ), # Completes at 2.1s - } - ] + request_start2 = START_TIME + int(1.5 * NANOS_PER_SECOND) + request_latency2 = int(0.6 * NANOS_PER_SECOND) + record_after_duration = create_metric_record_data( + request_start_ns=request_start2, + request_end_ns=request_start2 + request_latency2, # Completes at 2.1s + metrics={ + MinRequestTimestampMetric.tag: request_start2, + RequestLatencyMetric.tag: request_latency2, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_after_duration + instance, record_after_duration ) is False ) @@ -80,45 +129,46 @@ def test_should_include_request_within_grace_period(self): """Test filtering with grace period - responses within grace period are included.""" from aiperf.records.records_manager import RecordsManager - start_time = 1000000000 - duration_sec = 2.0 - grace_period_sec = 1.0 # 1 second grace period - - instance = MagicMock() - instance.expected_duration_sec = duration_sec - instance.start_time_ns = start_time - instance.user_config.loadgen.benchmark_grace_period = grace_period_sec - instance.debug = MagicMock() - - results_within_grace = [ - { - MinRequestTimestampMetric.tag: start_time - + int(1.5 * NANOS_PER_SECOND), # 1.5s after start - RequestLatencyMetric.tag: int( - 1.4 * NANOS_PER_SECOND - ), # Completes at 2.9s (within 3s total) - } - ] + instance = create_mock_records_manager( + start_time_ns=START_TIME, + expected_duration_sec=2.0, + grace_period_sec=1.0, + ) + + # Request that completes within grace period should be included + request_start_within = START_TIME + int(1.5 * NANOS_PER_SECOND) + request_latency_within = int(1.4 * NANOS_PER_SECOND) + record_within_grace = create_metric_record_data( + request_start_ns=request_start_within, + request_end_ns=request_start_within + + request_latency_within, # Completes at 2.9s + metrics={ + MinRequestTimestampMetric.tag: request_start_within, + RequestLatencyMetric.tag: request_latency_within, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_within_grace + instance, record_within_grace ) is True ) # Request that completes after grace period should be excluded - results_after_grace = [ - { - MinRequestTimestampMetric.tag: start_time - + int(1.5 * NANOS_PER_SECOND), # 1.5s after start - RequestLatencyMetric.tag: int( - 1.6 * NANOS_PER_SECOND - ), # Completes at 3.1s (after 3s total) - } - ] + request_start_after = START_TIME + int(1.5 * NANOS_PER_SECOND) + request_latency_after = int(1.6 * NANOS_PER_SECOND) + record_after_grace = create_metric_record_data( + request_start_ns=request_start_after, + request_end_ns=request_start_after + + request_latency_after, # Completes at 3.1s + metrics={ + MinRequestTimestampMetric.tag: request_start_after, + RequestLatencyMetric.tag: request_latency_after, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_after_grace + instance, record_after_grace ) is False ) @@ -127,48 +177,51 @@ def test_should_include_request_missing_metrics(self): """Test filtering behavior when required metrics are missing.""" from aiperf.records.records_manager import RecordsManager - start_time = 1000000000 - duration_sec = 2.0 - grace_period_sec = 1.0 - - instance = MagicMock() - instance.expected_duration_sec = duration_sec - instance.start_time_ns = start_time - instance.user_config.loadgen.benchmark_grace_period = grace_period_sec - instance.debug = MagicMock() + instance = create_mock_records_manager( + start_time_ns=START_TIME, + expected_duration_sec=2.0, + grace_period_sec=1.0, + ) - # Request with missing timestamp should be included (cannot filter) - results_missing_timestamp = [ - { + # Request that ends after grace period should be excluded + record_missing_timestamp = create_metric_record_data( + request_start_ns=START_TIME, + request_end_ns=START_TIME + int(5.0 * NANOS_PER_SECOND), # After grace + metrics={ RequestLatencyMetric.tag: int(5.0 * NANOS_PER_SECOND) # Only latency - } - ] + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_missing_timestamp + instance, record_missing_timestamp ) - is True + is False ) - # Request with missing latency should be included (cannot filter) - results_missing_latency = [ - { - MinRequestTimestampMetric.tag: start_time - + int(1.0 * NANOS_PER_SECOND) # Only timestamp - } - ] + # Request that ends within grace period should be included + record_missing_latency = create_metric_record_data( + request_start_ns=START_TIME + int(1.0 * NANOS_PER_SECOND), + request_end_ns=START_TIME + int(2.0 * NANOS_PER_SECOND), # Within grace + metrics={ + MinRequestTimestampMetric.tag: START_TIME + int(1.0 * NANOS_PER_SECOND) + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_missing_latency + instance, record_missing_latency ) is True ) - # Request with no metrics should be included - results_no_metrics = [{}] + # Request with no metrics should be included if it ends within grace period + record_no_metrics = create_metric_record_data( + request_start_ns=START_TIME, + request_end_ns=START_TIME + int(1.0 * NANOS_PER_SECOND), # Within grace + metrics={}, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_no_metrics + instance, record_no_metrics ) is True ) @@ -178,41 +231,44 @@ def test_should_include_request_various_grace_periods(self, grace_period: float) """Test filtering logic with various grace period values.""" from aiperf.records.records_manager import RecordsManager - start_time = 1000000000 - duration_sec = 2.0 - - instance = MagicMock() - instance.expected_duration_sec = duration_sec - instance.start_time_ns = start_time - instance.user_config.loadgen.benchmark_grace_period = grace_period - instance.debug = MagicMock() + instance = create_mock_records_manager( + start_time_ns=START_TIME, + expected_duration_sec=2.0, + grace_period_sec=grace_period, + ) # Request that completes exactly at duration + grace_period boundary - results_at_boundary = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.0 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int((1.0 + grace_period) * NANOS_PER_SECOND), - } - ] + request_start_at = START_TIME + int(1.0 * NANOS_PER_SECOND) + request_latency_at = int((1.0 + grace_period) * NANOS_PER_SECOND) + record_at_boundary = create_metric_record_data( + request_start_ns=request_start_at, + request_end_ns=request_start_at + request_latency_at, + metrics={ + MinRequestTimestampMetric.tag: request_start_at, + RequestLatencyMetric.tag: request_latency_at, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_at_boundary + instance, record_at_boundary ) is True ) # Request that completes just after duration + grace_period should be excluded - results_after_boundary = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.0 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - (1.0 + grace_period + 0.1) * NANOS_PER_SECOND - ), - } - ] + request_start_after = START_TIME + int(1.0 * NANOS_PER_SECOND) + request_latency_after = int((1.0 + grace_period + 0.1) * NANOS_PER_SECOND) + record_after_boundary = create_metric_record_data( + request_start_ns=request_start_after, + request_end_ns=request_start_after + request_latency_after, + metrics={ + MinRequestTimestampMetric.tag: request_start_after, + RequestLatencyMetric.tag: request_latency_after, + }, + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_after_boundary + instance, record_after_boundary ) is False ) @@ -221,56 +277,44 @@ def test_should_include_request_multiple_results_in_request(self): """Test filtering with multiple result dictionaries for a single request (all-or-nothing).""" from aiperf.records.records_manager import RecordsManager - start_time = 1000000000 - duration_sec = 2.0 - grace_period_sec = 1.0 - - instance = MagicMock() - instance.expected_duration_sec = duration_sec - instance.start_time_ns = start_time - instance.user_config.loadgen.benchmark_grace_period = grace_period_sec - instance.debug = MagicMock() - - # Multiple results where all complete within grace period - should include - results_all_within = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.5 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 1.0 * NANOS_PER_SECOND - ), # Completes at 2.5s - }, - { - MinRequestTimestampMetric.tag: start_time + int(1.8 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 1.1 * NANOS_PER_SECOND - ), # Completes at 2.9s + instance = create_mock_records_manager( + start_time_ns=START_TIME, + expected_duration_sec=2.0, + grace_period_sec=1.0, + ) + + # Request where the latest response completes within grace period - should include + request_start_within = START_TIME + int(1.5 * NANOS_PER_SECOND) + record_all_within = create_metric_record_data( + request_start_ns=request_start_within, + request_end_ns=START_TIME + + int(2.9 * NANOS_PER_SECOND), # Latest completion time + metrics={ + MinRequestTimestampMetric.tag: request_start_within, + RequestLatencyMetric.tag: int(1.0 * NANOS_PER_SECOND), }, - ] + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_all_within + instance, record_all_within ) is True ) - # Multiple results where one completes after grace period - should exclude entire request - results_one_after = [ - { - MinRequestTimestampMetric.tag: start_time + int(1.0 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 1.0 * NANOS_PER_SECOND - ), # Completes at 2.0s (within) + # Request where one response completes after grace period - should exclude entire request + request_start_after = START_TIME + int(1.0 * NANOS_PER_SECOND) + record_one_after = create_metric_record_data( + request_start_ns=request_start_after, + request_end_ns=START_TIME + + int(3.5 * NANOS_PER_SECOND), # Latest completion time (after grace) + metrics={ + MinRequestTimestampMetric.tag: request_start_after, + RequestLatencyMetric.tag: int(1.0 * NANOS_PER_SECOND), }, - { - MinRequestTimestampMetric.tag: start_time + int(1.5 * NANOS_PER_SECOND), - RequestLatencyMetric.tag: int( - 2.0 * NANOS_PER_SECOND - ), # Completes at 3.5s (after grace) - }, - ] + ) assert ( RecordsManager._should_include_request_by_duration( - instance, results_one_after + instance, record_one_after ) is False ) diff --git a/tests/timing_manager/conftest.py b/tests/timing_manager/conftest.py index aaae65394..1dbc5dfc5 100644 --- a/tests/timing_manager/conftest.py +++ b/tests/timing_manager/conftest.py @@ -109,6 +109,7 @@ async def publish(self, message: Message) -> None: async def drop_credit( self, credit_phase: CreditPhase, + credit_num: int, conversation_id: str | None = None, credit_drop_ns: int | None = None, should_cancel: bool = False, @@ -121,6 +122,7 @@ async def drop_credit( CreditDropMessage( service_id="test-service", phase=credit_phase, + credit_num=credit_num, conversation_id=conversation_id, credit_drop_ns=credit_drop_ns, should_cancel=should_cancel, diff --git a/tests/workers/test_worker.py b/tests/workers/test_worker.py index c987ea25b..c3d69c86c 100644 --- a/tests/workers/test_worker.py +++ b/tests/workers/test_worker.py @@ -229,6 +229,7 @@ async def test_build_response_record( credit_drop_ns=None, should_cancel=False, cancel_after_ns=123456789, + credit_num=1, ) dummy_record = RequestRecord() @@ -278,6 +279,7 @@ async def test_build_response_record_credit_drop_latency_only_first_turn( credit_drop_ns=None, should_cancel=False, cancel_after_ns=123456789, + credit_num=1, ) dummy_record = RequestRecord() @@ -324,6 +326,7 @@ async def test_x_request_id_and_x_correlation_id_passed_to_client(self, worker): message = CreditDropMessage( service_id="test-service", phase=CreditPhase.PROFILING, + credit_num=1, ) turn = Turn(texts=[Text(contents=["test"])], model="test-model") x_request_id = str(uuid.uuid4())