From e951b596a25b5e5b3030c016b444cc6e1e5f5269 Mon Sep 17 00:00:00 2001 From: Brian Raf Date: Mon, 29 Jul 2024 17:09:51 +0000 Subject: [PATCH] Creating new generator class for concurrency sweeping --- .../config/generate/concurrency_sweeper.py | 77 +++++++++++++++++++ ..._concurrency_sweep_run_config_generator.py | 28 +------ ..._concurrency_sweep_run_config_generator.py | 28 +------ 3 files changed, 85 insertions(+), 48 deletions(-) create mode 100755 model_analyzer/config/generate/concurrency_sweeper.py diff --git a/model_analyzer/config/generate/concurrency_sweeper.py b/model_analyzer/config/generate/concurrency_sweeper.py new file mode 100755 index 000000000..5d58fd175 --- /dev/null +++ b/model_analyzer/config/generate/concurrency_sweeper.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from copy import deepcopy +from typing import Generator, List, Optional + +from model_analyzer.config.input.config_command_profile import ConfigCommandProfile +from model_analyzer.config.run.run_config import RunConfig +from model_analyzer.constants import LOGGER_NAME +from model_analyzer.result.parameter_search import ParameterSearch +from model_analyzer.result.result_manager import ResultManager +from model_analyzer.result.run_config_measurement import RunConfigMeasurement + +from .config_generator_interface import ConfigGeneratorInterface + +logger = logging.getLogger(LOGGER_NAME) + + +class ConcurrencySweeper: + """ + Sweeps concurrency for the top-N model configs + """ + + def __init__( + self, + config: ConfigCommandProfile, + result_manager: ResultManager, + ): + self._config = config + self._result_manager = result_manager + self._last_measurement: Optional[RunConfigMeasurement] = None + + def set_last_results( + self, measurements: List[Optional[RunConfigMeasurement]] + ) -> None: + self._last_measurement = measurements[-1] + + def get_configs(self) -> Generator[RunConfig, None, None]: + """ + A generator which creates RunConfigs based on sweeping + concurrency over the top-N models + """ + for model_name in self._result_manager.get_model_names(): + top_results = self._result_manager.top_n_results( + model_name=model_name, + n=self._config.num_configs_per_model, + include_default=True, + ) + + for result in top_results: + run_config = deepcopy(result.run_config()) + parameter_search = ParameterSearch(self._config) + for concurrency in parameter_search.search_parameters(): + run_config = self._create_run_config(run_config, concurrency) + yield run_config + parameter_search.add_run_config_measurement(self._last_measurement) + + def _create_run_config(self, run_config: RunConfig, concurrency: int) -> RunConfig: + for model_run_config in run_config.model_run_configs(): + perf_config = model_run_config.perf_config() + perf_config.update_config({"concurrency-range": concurrency}) + + return run_config diff --git a/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py b/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py index 35061a6ba..36dcd597a 100755 --- a/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +++ b/model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py @@ -18,6 +18,7 @@ from copy import deepcopy from typing import Dict, Generator, List, Optional +from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec from model_analyzer.config.generate.model_variant_name_manager import ( ModelVariantNameManager, @@ -115,7 +116,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]: "Done with Optuna mode search. Gathering concurrency sweep measurements for reports" ) logger.info("") - yield from self._sweep_concurrency_over_top_results() + yield from ConcurrencySweeper( + config=self._config, result_manager=self._result_manager + ).get_configs() logger.info("") logger.info("Done gathering concurrency sweep measurements for reports") logger.info("") @@ -136,26 +139,3 @@ def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator: search_parameters=self._search_parameters, composing_search_parameters=self._composing_search_parameters, ) - - def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]: - for model_name in self._result_manager.get_model_names(): - top_results = self._result_manager.top_n_results( - model_name=model_name, - n=self._config.num_configs_per_model, - include_default=True, - ) - - for result in top_results: - run_config = deepcopy(result.run_config()) - parameter_search = ParameterSearch(self._config) - for concurrency in parameter_search.search_parameters(): - run_config = self._set_concurrency(run_config, concurrency) - yield run_config - parameter_search.add_run_config_measurement(self._last_measurement) - - def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig: - for model_run_config in run_config.model_run_configs(): - perf_config = model_run_config.perf_config() - perf_config.update_config({"concurrency-range": concurrency}) - - return run_config diff --git a/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py b/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py index 3ac34daff..e4bdf0a90 100755 --- a/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +++ b/model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py @@ -18,6 +18,7 @@ from copy import deepcopy from typing import Generator, List, Optional +from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec from model_analyzer.config.generate.model_variant_name_manager import ( ModelVariantNameManager, @@ -106,7 +107,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]: "Done with quick mode search. Gathering concurrency sweep measurements for reports" ) logger.info("") - yield from self._sweep_concurrency_over_top_results() + yield from ConcurrencySweeper( + config=self._config, result_manager=self._result_manager + ).get_configs() logger.info("") logger.info("Done gathering concurrency sweep measurements for reports") logger.info("") @@ -125,26 +128,3 @@ def _create_quick_run_config_generator(self) -> QuickRunConfigGenerator: composing_models=self._composing_models, model_variant_name_manager=self._model_variant_name_manager, ) - - def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]: - for model_name in self._result_manager.get_model_names(): - top_results = self._result_manager.top_n_results( - model_name=model_name, - n=self._config.num_configs_per_model, - include_default=True, - ) - - for result in top_results: - run_config = deepcopy(result.run_config()) - parameter_search = ParameterSearch(self._config) - for concurrency in parameter_search.search_parameters(): - run_config = self._set_concurrency(run_config, concurrency) - yield run_config - parameter_search.add_run_config_measurement(self._last_measurement) - - def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig: - for model_run_config in run_config.model_run_configs(): - perf_config = model_run_config.perf_config() - perf_config.update_config({"concurrency-range": concurrency}) - - return run_config