Skip to content

Commit

Permalink
Creating new generator class for concurrency sweeping
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-braf committed Jul 29, 2024
1 parent 2dabb62 commit e951b59
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 48 deletions.
77 changes: 77 additions & 0 deletions model_analyzer/config/generate/concurrency_sweeper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from copy import deepcopy
from typing import Generator, List, Optional

from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
from model_analyzer.config.run.run_config import RunConfig
from model_analyzer.constants import LOGGER_NAME
from model_analyzer.result.parameter_search import ParameterSearch
from model_analyzer.result.result_manager import ResultManager
from model_analyzer.result.run_config_measurement import RunConfigMeasurement

from .config_generator_interface import ConfigGeneratorInterface

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'ConfigGeneratorInterface' is not used.

logger = logging.getLogger(LOGGER_NAME)


class ConcurrencySweeper:
    """
    Sweeps concurrency for the top-N model configs

    Intended protocol: iterate over get_configs() and, after each yielded
    RunConfig has been measured, call set_last_results() on this same
    instance before requesting the next config. The measurement recorded
    there is what get_configs() hands to the concurrency search when it
    resumes. If set_last_results() is never called, the search is only
    ever given None.
    """

    def __init__(
        self,
        config: ConfigCommandProfile,
        result_manager: ResultManager,
    ):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            Profile configuration; supplies num_configs_per_model and is
            passed to every ParameterSearch that is created
        result_manager: ResultManager
            Source of the model names and their top-N results
        """
        self._config = config
        self._result_manager = result_manager
        self._last_measurement: Optional[RunConfigMeasurement] = None

    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        """
        Records the measurement for the most recently yielded RunConfig

        Only the final entry of the list is kept. An empty list records
        None instead of raising IndexError.
        """
        self._last_measurement = measurements[-1] if measurements else None

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        A generator which creates RunConfigs based on sweeping
        concurrency over the top-N models

        Every yielded RunConfig is an independent copy, so a consumer may
        keep a reference to it without it changing on later steps.
        """
        for model_name in self._result_manager.get_model_names():
            top_results = self._result_manager.top_n_results(
                model_name=model_name,
                n=self._config.num_configs_per_model,
                include_default=True,
            )

            for result in top_results:
                base_run_config = result.run_config()
                parameter_search = ParameterSearch(self._config)
                for concurrency in parameter_search.search_parameters():
                    yield self._create_run_config(base_run_config, concurrency)
                    # Execution resumes here when the next config is
                    # requested; by then the caller is expected to have
                    # stored the new measurement via set_last_results().
                    parameter_search.add_run_config_measurement(
                        self._last_measurement
                    )

    def _create_run_config(
        self, run_config: RunConfig, concurrency: int
    ) -> RunConfig:
        """
        Returns a copy of run_config in which every model's perf config
        has "concurrency-range" set to the given concurrency

        The run_config passed in is left untouched.
        """
        new_run_config = deepcopy(run_config)
        for model_run_config in new_run_config.model_run_configs():
            perf_config = model_run_config.perf_config()
            perf_config.update_config({"concurrency-range": concurrency})

        return new_run_config
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from copy import deepcopy
from typing import Dict, Generator, List, Optional

from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
from model_analyzer.config.generate.model_variant_name_manager import (
ModelVariantNameManager,
Expand Down Expand Up @@ -115,7 +116,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]:
"Done with Optuna mode search. Gathering concurrency sweep measurements for reports"
)
logger.info("")
yield from self._sweep_concurrency_over_top_results()
yield from ConcurrencySweeper(
config=self._config, result_manager=self._result_manager
).get_configs()
logger.info("")
logger.info("Done gathering concurrency sweep measurements for reports")
logger.info("")
Expand All @@ -136,26 +139,3 @@ def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator:
search_parameters=self._search_parameters,
composing_search_parameters=self._composing_search_parameters,
)

def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]:
    """
    Yields RunConfigs that step concurrency across each model's top results

    For every model name known to the result manager, the top
    num_configs_per_model results (default included) are fetched, and a
    fresh ParameterSearch drives the concurrency values tried on a copy of
    each result's RunConfig.
    """
    for name in self._result_manager.get_model_names():
        best_results = self._result_manager.top_n_results(
            model_name=name,
            n=self._config.num_configs_per_model,
            include_default=True,
        )

        for best in best_results:
            # Work on a copy so the stored result's RunConfig is not edited.
            swept_config = deepcopy(best.run_config())
            search = ParameterSearch(self._config)
            for level in search.search_parameters():
                swept_config = self._set_concurrency(swept_config, level)
                yield swept_config
                # Runs when the next config is requested: whatever
                # self._last_measurement holds at that moment is given to
                # the search (presumably refreshed by the owner between
                # steps; confirm against the rest of the class).
                search.add_run_config_measurement(self._last_measurement)

def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig:
for model_run_config in run_config.model_run_configs():
perf_config = model_run_config.perf_config()
perf_config.update_config({"concurrency-range": concurrency})

return run_config
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from copy import deepcopy
from typing import Generator, List, Optional

from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
from model_analyzer.config.generate.model_variant_name_manager import (
ModelVariantNameManager,
Expand Down Expand Up @@ -106,7 +107,9 @@ def get_configs(self) -> Generator[RunConfig, None, None]:
"Done with quick mode search. Gathering concurrency sweep measurements for reports"
)
logger.info("")
yield from self._sweep_concurrency_over_top_results()
yield from ConcurrencySweeper(
config=self._config, result_manager=self._result_manager
).get_configs()
logger.info("")
logger.info("Done gathering concurrency sweep measurements for reports")
logger.info("")
Expand All @@ -125,26 +128,3 @@ def _create_quick_run_config_generator(self) -> QuickRunConfigGenerator:
composing_models=self._composing_models,
model_variant_name_manager=self._model_variant_name_manager,
)

def _sweep_concurrency_over_top_results(self) -> Generator[RunConfig, None, None]:
    """
    Generates the concurrency-sweep RunConfigs for the top results

    Walks every model name reported by the result manager, takes its top
    num_configs_per_model results (default included), and for each one
    lets a new ParameterSearch choose the concurrency values applied to a
    copy of that result's RunConfig.
    """
    for current_model in self._result_manager.get_model_names():
        candidates = self._result_manager.top_n_results(
            model_name=current_model,
            n=self._config.num_configs_per_model,
            include_default=True,
        )

        for candidate in candidates:
            # Copy first: the sweep edits the config it yields.
            working_config = deepcopy(candidate.run_config())
            concurrency_search = ParameterSearch(self._config)
            for value in concurrency_search.search_parameters():
                working_config = self._set_concurrency(working_config, value)
                yield working_config
                # Reached only after the consumer asks for another config;
                # the current self._last_measurement is reported to the
                # search (assumed to be updated by the owner in between;
                # verify against the enclosing class).
                concurrency_search.add_run_config_measurement(
                    self._last_measurement
                )

def _set_concurrency(self, run_config: RunConfig, concurrency: int) -> RunConfig:
for model_run_config in run_config.model_run_configs():
perf_config = model_run_config.perf_config()
perf_config.update_config({"concurrency-range": concurrency})

return run_config

0 comments on commit e951b59

Please sign in to comment.