3 changes: 3 additions & 0 deletions aiperf/common/config/config_defaults.py
@@ -11,6 +11,7 @@
     AudioFormat,
     CommunicationBackend,
     CustomDatasetType,
+    DatasetType,
     EndpointType,
     ImageFormat,
     ModelSelectionStrategy,
@@ -48,6 +49,7 @@ class InputDefaults:
     FIXED_SCHEDULE_AUTO_OFFSET = False
     FIXED_SCHEDULE_START_OFFSET = None
     FIXED_SCHEDULE_END_OFFSET = None
+    DATASET_TYPE = DatasetType.SYNTHETIC
     CUSTOM_DATASET_TYPE = CustomDatasetType.MOONCAKE_TRACE
     RANDOM_SEED = None
    NUM_DATASET_ENTRIES = 100
@@ -142,6 +144,7 @@ class ServiceDefaults:
     EXTRA_VERBOSE = False
     LOG_PATH = None
     RECORD_PROCESSOR_SERVICE_COUNT = None
+    DATASET_PROCESSOR_SERVICE_COUNT = 1
     PROGRESS_REPORT_INTERVAL = 1.0
     UI_TYPE = AIPerfUIType.DASHBOARD
29 changes: 28 additions & 1 deletion aiperf/common/config/input_config.py
@@ -19,7 +19,7 @@
 from aiperf.common.config.groups import Groups
 from aiperf.common.config.image_config import ImageConfig
 from aiperf.common.config.prompt_config import PromptConfig
-from aiperf.common.enums import CustomDatasetType
+from aiperf.common.enums import CustomDatasetType, DatasetType
 
 logger = AIPerfLogger(__name__)
 
@@ -66,6 +66,24 @@ def validate_fixed_schedule_start_and_end_offset(self) -> Self:
             )
         return self
 
+    @model_validator(mode="after")
+    def validate_dataset_type(self) -> Self:
+        """Validate the dataset type configuration."""
+        if self.dataset_type == DatasetType.SYNTHETIC and self.file is not None:
+            self.dataset_type = DatasetType.CUSTOM
+            logger.warning(
+                "Dataset type was set to CUSTOM because a file was provided "
+                "while the dataset type was SYNTHETIC"
+            )
+        if self.dataset_type == DatasetType.CUSTOM:
+            if self.custom_dataset_type is None:
+                raise ValueError(
+                    "A CUSTOM dataset requires a custom dataset type to be provided"
+                )
+            if self.file is None:
+                raise ValueError("A CUSTOM dataset requires a file to be provided")
+        return self
+
     extra: Annotated[
         Any,
         Field(
@@ -175,6 +193,15 @@ def validate_fixed_schedule_start_and_end_offset(self) -> Self:
         ),
     ] = InputDefaults.FIXED_SCHEDULE_END_OFFSET
 
+    dataset_type: Annotated[
+        DatasetType,
+        Field(description="The type of dataset to generate for the requests."),
+        Parameter(
+            name=("--dataset-type",),
+            group=_CLI_GROUP,
+        ),
+    ] = InputDefaults.DATASET_TYPE
+
     # NEW AIPerf Option
     custom_dataset_type: Annotated[
         CustomDatasetType,
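To illustrate what the new validator does in practice, a minimal sketch follows. The field names (file, dataset_type, custom_dataset_type) come from the diff above; the InputConfig import path and the assumption that the remaining fields have defaults are mine:

from aiperf.common.config import InputConfig  # assumed import path
from aiperf.common.enums import DatasetType

# Providing a file while dataset_type is left at its SYNTHETIC default
# triggers the coercion branch: the validator flips it to CUSTOM and warns.
config = InputConfig(file="traces/mooncake.jsonl")
assert config.dataset_type == DatasetType.CUSTOM

# Requesting CUSTOM explicitly without a file fails validation.
# (Pydantic v2 raises a ValidationError, which subclasses ValueError.)
try:
    InputConfig(dataset_type=DatasetType.CUSTOM)
except ValueError as err:
    print(err)  # ... requires a file to be provided

Note that the coercion path only succeeds because custom_dataset_type defaults to MOONCAKE_TRACE rather than None, so the follow-up CUSTOM checks pass.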
11 changes: 11 additions & 0 deletions aiperf/common/config/service_config.py
@@ -192,6 +192,17 @@ def validate_comm_config(self) -> Self:
         ),
     ] = ServiceDefaults.RECORD_PROCESSOR_SERVICE_COUNT
 
+    dataset_processor_service_count: Annotated[
+        int,
+        Field(
+            description="Number of services to spawn for dataset generation.",
+        ),
+        Parameter(
+            name=("--dataset-processor-service-count", "--dataset-processors"),
+            group=_CLI_GROUP,
+        ),
+    ] = ServiceDefaults.DATASET_PROCESSOR_SERVICE_COUNT
+
     progress_report_interval: Annotated[
         float,
         Field(
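A short sketch of how the new knob is consumed, assuming ServiceConfig can be constructed with defaults (the CLI spellings come from the Parameter annotation above):

from aiperf.common.config import ServiceConfig  # assumed import path

# Defaults to ServiceDefaults.DATASET_PROCESSOR_SERVICE_COUNT, i.e. 1.
config = ServiceConfig()
print(config.dataset_processor_service_count)  # 1

# Programmatic equivalent of passing --dataset-processors 4
# (or the long form --dataset-processor-service-count 4) on the CLI.
config = ServiceConfig(dataset_processor_service_count=4)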
28 changes: 28 additions & 0 deletions aiperf/common/config/zmq_config.py
@@ -56,6 +56,16 @@ def credit_drop_address(self) -> str:
     def credit_return_address(self) -> str:
         """Get the credit return address based on protocol configuration."""
 
+    @property
+    @abstractmethod
+    def dataset_job_address(self) -> str:
+        """Get the dataset job address based on protocol configuration."""
+
+    @property
+    @abstractmethod
+    def dataset_result_address(self) -> str:
+        """Get the dataset result address based on protocol configuration."""
+
     def get_address(self, address_type: CommAddress) -> str:
         """Get the actual address based on the address type."""
         address_map = {
@@ -65,6 +75,8 @@ def get_address(self, address_type: CommAddress) -> str:
             CommAddress.DATASET_MANAGER_PROXY_BACKEND: self.dataset_manager_proxy_config.backend_address,
             CommAddress.CREDIT_DROP: self.credit_drop_address,
             CommAddress.CREDIT_RETURN: self.credit_return_address,
+            CommAddress.DATASET_JOB: self.dataset_job_address,
+            CommAddress.DATASET_RESULT: self.dataset_result_address,
             CommAddress.RECORDS: self.records_push_pull_address,
             CommAddress.RAW_INFERENCE_PROXY_FRONTEND: self.raw_inference_proxy_config.frontend_address,
             CommAddress.RAW_INFERENCE_PROXY_BACKEND: self.raw_inference_proxy_config.backend_address,
@@ -170,6 +182,12 @@ class ZMQTCPConfig(BaseZMQCommunicationConfig):
     credit_return_port: int = Field(
         default=5563, description="Port for credit return operations"
     )
+    dataset_job_port: int = Field(
+        default=5665, description="Port for dataset job operations"
+    )
+    dataset_result_port: int = Field(
+        default=5666, description="Port for dataset result operations"
+    )
     dataset_manager_proxy_config: ZMQTCPProxyConfig = Field(  # type: ignore
         default=ZMQTCPProxyConfig(
             frontend_port=5661,
@@ -239,3 +257,13 @@ def credit_drop_address(self) -> str:
     def credit_return_address(self) -> str:
         """Get the credit return address based on protocol configuration."""
         return f"ipc://{self.path}/credit_return.ipc"
+
+    @property
+    def dataset_job_address(self) -> str:
+        """Get the dataset job address based on protocol configuration."""
+        return f"ipc://{self.path}/dataset_job.ipc"
+
+    @property
+    def dataset_result_address(self) -> str:
+        """Get the dataset result address based on protocol configuration."""
+        return f"ipc://{self.path}/dataset_result.ipc"
4 changes: 2 additions & 2 deletions aiperf/common/enums/__init__.py
@@ -29,8 +29,8 @@
 )
 from aiperf.common.enums.dataset_enums import (
     AudioFormat,
-    ComposerType,
     CustomDatasetType,
+    DatasetType,
     ImageFormat,
     PromptSource,
 )
@@ -109,11 +109,11 @@
     "CommandResponseStatus",
     "CommandType",
     "CommunicationBackend",
-    "ComposerType",
     "ConsoleExporterType",
     "CreditPhase",
     "CustomDatasetType",
     "DataExporterType",
+    "DatasetType",
     "EndpointType",
     "EndpointTypeInfo",
     "GenericMetricUnit",
2 changes: 2 additions & 0 deletions aiperf/common/enums/command_enums.py
@@ -14,6 +14,8 @@ class CommandType(CaseInsensitiveStrEnum):
     SHUTDOWN = "shutdown"
     SHUTDOWN_WORKERS = "shutdown_workers"
     SPAWN_WORKERS = "spawn_workers"
+    SPAWN_DATASET_PROCESSORS = "spawn_dataset_processors"
+    SHUTDOWN_DATASET_PROCESSORS = "shutdown_dataset_processors"
 
 
 class CommandResponseStatus(CaseInsensitiveStrEnum):
6 changes: 6 additions & 0 deletions aiperf/common/enums/communication_enums.py
@@ -49,6 +49,12 @@ class CommAddress(CaseInsensitiveStrEnum):
     RAW_INFERENCE_PROXY_BACKEND = "raw_inference_proxy_backend"
     """Backend address for the InferenceParser to receive raw inference messages from Workers."""
 
+    DATASET_JOB = "dataset_job"
+    """Address for sending dataset generation jobs to the DatasetProcessor."""
+
+    DATASET_RESULT = "dataset_result"
+    """Address for sending dataset generation results to the DatasetManager."""
+
 
 class ZMQProxyType(CaseInsensitiveStrEnum):
     DEALER_ROUTER = "dealer_router"
4 changes: 2 additions & 2 deletions aiperf/common/enums/dataset_enums.py
@@ -4,10 +4,10 @@
 from aiperf.common.enums.base_enums import CaseInsensitiveStrEnum
 
 
-class ComposerType(CaseInsensitiveStrEnum):
+class DatasetType(CaseInsensitiveStrEnum):
     SYNTHETIC = "synthetic"
     CUSTOM = "custom"
-    PUBLIC_DATASET = "public_dataset"
+    PUBLIC = "public"
 
 
 class CustomDatasetType(CaseInsensitiveStrEnum):
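Because the renamed enum still derives from CaseInsensitiveStrEnum, string inputs should keep parsing regardless of case. A sketch, assuming the base class implements the case-folding lookup its name implies:

from aiperf.common.enums import DatasetType

assert DatasetType("SYNTHETIC") is DatasetType.SYNTHETIC
assert DatasetType("Custom") is DatasetType.CUSTOM

# Note the value rename: "public_dataset" is now plain "public".
print(DatasetType.PUBLIC.value)  # public

Any call site or stored config that persisted the old "public_dataset" string will need migrating to "public".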
6 changes: 6 additions & 0 deletions aiperf/common/enums/message_enums.py
@@ -30,6 +30,12 @@ class MessageType(CaseInsensitiveStrEnum):
     DATASET_CONFIGURED_NOTIFICATION = "dataset_configured_notification"
     DATASET_TIMING_REQUEST = "dataset_timing_request"
     DATASET_TIMING_RESPONSE = "dataset_timing_response"
+    DATASET_RESULT = "dataset_result"
+    PROCESS_SYNTHETIC_DATASET = "process_synthetic_dataset"
+    PROCESS_MOONCAKE_TRACE_DATASET = "process_mooncake_trace_dataset"
+    PROCESS_MULTI_TURN_DATASET = "process_multi_turn_dataset"
+    PROCESS_SINGLE_TURN_DATASET = "process_single_turn_dataset"
+    PROCESS_RANDOM_POOL_DATASET = "process_random_pool_dataset"
     ERROR = "error"
     HEARTBEAT = "heartbeat"
     INFERENCE_RESULTS = "inference_results"
1 change: 1 addition & 0 deletions aiperf/common/enums/service_enums.py
@@ -42,6 +42,7 @@ class ServiceType(CaseInsensitiveStrEnum):
     TIMING_MANAGER = "timing_manager"
     RECORD_PROCESSOR = "record_processor"
     RECORDS_MANAGER = "records_manager"
+    DATASET_PROCESSOR = "dataset_processor"
     WORKER_MANAGER = "worker_manager"
     WORKER = "worker"
 
16 changes: 0 additions & 16 deletions aiperf/common/factories.py
@@ -10,7 +10,6 @@
     AIPerfUIType,
     CommClientType,
     CommunicationBackend,
-    ComposerType,
     ConsoleExporterType,
     CustomDatasetType,
     DataExporterType,
@@ -58,7 +57,6 @@
 from aiperf.dataset import (
     CustomDatasetLoaderProtocol,
 )
-from aiperf.dataset.composer.base import BaseDatasetComposer
 from aiperf.exporters.exporter_config import ExporterConfig
 from aiperf.zmq.zmq_proxy_base import BaseZMQProxy
 
@@ -377,20 +375,6 @@ def create_instance(  # type: ignore[override]
         return super().create_instance(class_type, config=config, **kwargs)
 
 
-class ComposerFactory(AIPerfFactory[ComposerType, "BaseDatasetComposer"]):
-    """Factory for registering and creating BaseDatasetComposer instances based on the specified composer type.
-    see: :class:`aiperf.common.factories.AIPerfFactory` for more details.
-    """
-
-    @classmethod
-    def create_instance(  # type: ignore[override]
-        cls,
-        class_type: ComposerType | str,
-        **kwargs,
-    ) -> "BaseDatasetComposer":
-        return super().create_instance(class_type, **kwargs)
-
-
 class ConsoleExporterFactory(
     AIPerfFactory[ConsoleExporterType, "ConsoleExporterProtocol"]
 ):
14 changes: 14 additions & 0 deletions aiperf/common/messages/__init__.py
@@ -50,6 +50,13 @@
     DatasetConfiguredNotification,
     DatasetTimingRequest,
     DatasetTimingResponse,
+    ProcessDatasetMessage,
+    ProcessDatasetResponseMessage,
+    ProcessMooncakeTraceDatasetMessage,
+    ProcessMultiTurnDatasetMessage,
+    ProcessRandomPoolDatasetMessage,
+    ProcessSingleTurnDatasetMessage,
+    ProcessSyntheticDatasetMessage,
 )
 from aiperf.common.messages.inference_messages import (
     InferenceResultsMessage,
@@ -110,9 +117,16 @@
     "Message",
     "MetricRecordsMessage",
     "ParsedInferenceResultsMessage",
+    "ProcessDatasetMessage",
+    "ProcessDatasetResponseMessage",
+    "ProcessMooncakeTraceDatasetMessage",
+    "ProcessMultiTurnDatasetMessage",
+    "ProcessRandomPoolDatasetMessage",
     "ProcessRecordsCommand",
     "ProcessRecordsResponse",
     "ProcessRecordsResultMessage",
+    "ProcessSingleTurnDatasetMessage",
+    "ProcessSyntheticDatasetMessage",
     "ProcessingStatsMessage",
     "ProfileCancelCommand",
     "ProfileConfigureCommand",
68 changes: 68 additions & 0 deletions aiperf/common/messages/dataset_messages.py
@@ -1,6 +1,8 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import Any
+
 from pydantic import Field
 
 from aiperf.common.enums import CreditPhase, MessageType
@@ -9,6 +11,72 @@
 from aiperf.common.types import MessageTypeT
 
 
+class ProcessDatasetMessage(BaseServiceMessage):
+    """Message for sending dataset processing requests to processors."""
+
+    random_seed: int | None = Field(
+        default=None, description="Random seed for the dataset generation"
+    )
+
+
+class ProcessSyntheticDatasetMessage(ProcessDatasetMessage):
+    """Message for processing synthetic data."""
+
+    message_type: MessageTypeT = MessageType.PROCESS_SYNTHETIC_DATASET
+    num_conversations: int = Field(
+        ..., description="Number of conversations to generate"
+    )
+
+
+class ProcessMooncakeTraceDatasetMessage(ProcessDatasetMessage):
+    """Message for processing mooncake trace data."""
+
+    message_type: MessageTypeT = MessageType.PROCESS_MOONCAKE_TRACE_DATASET
+    dataset: list[tuple[str, Any]] = Field(
+        ..., description="The Mooncake trace dataset"
+    )
+
+
+class ProcessMultiTurnDatasetMessage(ProcessDatasetMessage):
+    """Message for processing multi-turn data."""
+
+    message_type: MessageTypeT = MessageType.PROCESS_MULTI_TURN_DATASET
+    dataset: list[tuple[str, Any]] = Field(..., description="The multi-turn dataset")
+
+
+class ProcessSingleTurnDatasetMessage(ProcessDatasetMessage):
+    """Message for processing single-turn data."""
+
+    message_type: MessageTypeT = MessageType.PROCESS_SINGLE_TURN_DATASET
+    dataset: list[tuple[str, Any]] = Field(..., description="The single-turn dataset")
+
+
+class ProcessRandomPoolDatasetMessage(ProcessDatasetMessage):
+    """Message for processing random pool data."""
+
+    message_type: MessageTypeT = MessageType.PROCESS_RANDOM_POOL_DATASET
+    dataset: list[tuple[str, Any]] = Field(..., description="The random pool dataset")
+    num_conversations: int = Field(
+        ..., description="Number of conversations to generate"
+    )
+
+
+class ProcessDatasetResponseMessage(ProcessDatasetMessage):
+    """Message for returning dataset processing responses."""
+
+    message_type: MessageTypeT = MessageType.DATASET_RESULT
+
+    generated_data: list[Conversation] = Field(
+        default_factory=list, description="Generated conversations"
+    )
+    error_message: str | None = Field(
+        default=None, description="Error message if job failed"
+    )
+    processing_time_ms: float | None = Field(
+        default=None, description="Time taken to process the job in milliseconds"
+    )
+
+
 class ConversationRequestMessage(BaseServiceMessage):
     """Message to request a full conversation by ID."""
 
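To make the job and response shapes concrete, a minimal round-trip sketch. model_dump_json and model_validate_json are standard Pydantic v2; treating service_id as a BaseServiceMessage field is an assumption:

from aiperf.common.messages import (
    ProcessDatasetResponseMessage,
    ProcessSyntheticDatasetMessage,
)

# DatasetManager side: describe a synthetic-generation job.
job = ProcessSyntheticDatasetMessage(
    service_id="dataset_manager",  # assumed BaseServiceMessage field
    num_conversations=100,
    random_seed=42,
)
wire = job.model_dump_json()  # payload sent over the DATASET_JOB socket

# DatasetProcessor side: parse, generate, then answer on DATASET_RESULT.
received = ProcessSyntheticDatasetMessage.model_validate_json(wire)
response = ProcessDatasetResponseMessage(
    service_id="dataset_processor_0",
    generated_data=[],  # list[Conversation] on success
    processing_time_ms=12.5,
)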
5 changes: 5 additions & 0 deletions aiperf/controller/system_controller.py
@@ -92,6 +92,11 @@ def __init__(
         else:
             self.scale_record_processors_with_workers = True
 
+        if self.service_config.dataset_processor_service_count is not None:
+            self.required_services[ServiceType.DATASET_PROCESSOR] = (
+                self.service_config.dataset_processor_service_count
+            )
+
         self.proxy_manager: ProxyManager = ProxyManager(
             service_config=self.service_config
         )
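Conceptually, the controller change adds one more entry to the required_services table that drives service spawning. A rough sketch, with the table structure inferred from the surrounding code:

from aiperf.common.enums import ServiceType

# After __init__ runs with default settings, the table contains an entry
# like this alongside the existing services (workers, record processors, ...):
required_services = {
    ServiceType.DATASET_PROCESSOR: 1,  # dataset_processor_service_count default
}

Since the default count is 1 rather than None, the guard effectively always fires under a default configuration; the branch is skipped only if the count is explicitly set to None.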