Rename Scheduler --> Orchestrator (#3479)
Summary:
Pull Request resolved: #3479

As titled; now is the time to do this before we write related docs

Differential Revision: D70788646
Lena Kashtelyan authored and facebook-github-bot committed Mar 7, 2025
1 parent 5857ba9 commit 5de62c3
Showing 35 changed files with 911 additions and 928 deletions.
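For readers migrating code, a minimal before/after sketch of the renamed surface follows. This is hedged, not authoritative: `experiment` and `generation_strategy` are assumed to already exist (built as usual), and the option values are illustrative; the import paths and method names are taken from the diff below.

# Before this commit:
#   from ax.service.scheduler import Scheduler
#   from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
# After this commit (same behavior, new names):
from ax.service.orchestrator import Orchestrator
from ax.service.utils.orchestrator_options import OrchestratorOptions, TrialType

options = OrchestratorOptions(
    trial_type=TrialType.TRIAL,  # assumed field, mirroring SchedulerOptions
    max_pending_trials=1,
)
orchestrator = Orchestrator(
    experiment=experiment,  # assumed: an existing Experiment
    generation_strategy=generation_strategy,  # assumed: an existing GenerationStrategy
    options=options,
)
orchestrator.run_n_trials(max_trials=10)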
52 changes: 26 additions & 26 deletions ax/benchmark/benchmark.py
@@ -41,9 +41,9 @@
from ax.core.trial_status import TrialStatus
from ax.core.types import TParamValue
from ax.core.utils import get_model_times
-from ax.service.scheduler import Scheduler
+from ax.service.orchestrator import Orchestrator
from ax.service.utils.best_point_mixin import BestPointMixin
-from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
+from ax.service.utils.orchestrator_options import OrchestratorOptions, TrialType
from ax.utils.common.logger import DEFAULT_LOG_LEVEL, get_logger
from ax.utils.common.random import with_rng_seed
from pyre_extensions import assert_is_instance
@@ -94,7 +94,7 @@ def get_benchmark_runner(
(used to generate data) and ``step_runtime_function`` (used to
determine timing for the simulator).
max_concurrency: The maximum number of trials that can be run concurrently.
-Typically, ``max_pending_trials`` from ``SchedulerOptions``, which are
+Typically, ``max_pending_trials`` from ``OrchestratorOptions``, which are
stored on the ``BenchmarkMethod``.
"""

@@ -173,26 +173,26 @@ def get_oracle_experiment_from_params(
return experiment


-def get_benchmark_scheduler_options(
+def get_benchmark_orchestrator_options(
method: BenchmarkMethod,
include_sq: bool = False,
logging_level: int = DEFAULT_LOG_LEVEL,
-) -> SchedulerOptions:
+) -> OrchestratorOptions:
"""
-Get the ``SchedulerOptions`` for the given ``BenchmarkMethod``.
+Get the ``OrchestratorOptions`` for the given ``BenchmarkMethod``.
Args:
method: The ``BenchmarkMethod``.
include_sq: Whether to include the status quo in each trial.
Returns:
-``SchedulerOptions``
+``OrchestratorOptions``
"""
if method.batch_size is None or method.batch_size > 1 or include_sq:
trial_type = TrialType.BATCH_TRIAL
else:
trial_type = TrialType.TRIAL
-return SchedulerOptions(
+return OrchestratorOptions(
# No new candidates can be generated while any are pending.
# If batched, an entire batch must finish before the next can be
# generated.
@@ -372,7 +372,7 @@ def benchmark_replication(
method: BenchmarkMethod,
seed: int,
strip_runner_before_saving: bool = True,
-scheduler_logging_level: int = DEFAULT_LOG_LEVEL,
+orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
) -> BenchmarkResult:
"""
Run one benchmarking replication (equivalent to one optimization loop).
@@ -389,7 +389,7 @@
seed: The seed to use for this replication.
strip_runner_before_saving: Whether to strip the runner from the
experiment before saving it. This enables serialization.
-scheduler_logging_level: If >INFO, logs will only appear when unexpected
+orchestrator_logging_level: If >INFO, logs will only appear when unexpected
things happen. If INFO, logs will update when a trial is completed
and when an early stopping strategy, if present, decides whether or
not to continue a trial. If DEBUG, logs additionally include
@@ -403,13 +403,13 @@
if problem.status_quo_params is None
else Arm(name="status_quo", parameters=problem.status_quo_params)
)
-scheduler_options = get_benchmark_scheduler_options(
+orchestrator_options = get_benchmark_orchestrator_options(
method=method,
include_sq=sq_arm is not None,
-logging_level=scheduler_logging_level,
+logging_level=orchestrator_logging_level,
)
runner = get_benchmark_runner(
-problem=problem, max_concurrency=scheduler_options.max_pending_trials
+problem=problem, max_concurrency=orchestrator_options.max_pending_trials
)
experiment = Experiment(
name=f"{problem.name}|{method.name}_{int(time())}",
@@ -420,10 +420,10 @@
auxiliary_experiments_by_purpose=problem.auxiliary_experiments_by_purpose,
)

-scheduler = Scheduler(
+orchestrator = Orchestrator(
experiment=experiment,
generation_strategy=method.generation_strategy.clone_reset(),
-options=scheduler_options,
+options=orchestrator_options,
)

# Each of these lists is added to when a trial completes or stops early.
@@ -447,11 +447,11 @@
)
start = monotonic()
# These next several lines do the same thing as
-# `scheduler.run_n_trials`, but
+# `orchestrator.run_n_trials`, but
# decrement the timeout with each step, so that the timeout refers to
# the total time spent in the optimization loop, not time per trial.
-scheduler.poll_and_process_results()
-for _ in scheduler.run_trials_and_yield_results(
+orchestrator.poll_and_process_results()
+for _ in orchestrator.run_trials_and_yield_results(
max_trials=problem.num_trials,
timeout_hours=remaining_hours,
):
@@ -472,7 +472,7 @@
logger.warning("The optimization loop timed out.")
break

-scheduler.summarize_final_result()
+orchestrator.summarize_final_result()

inference_trace = _get_inference_trace_from_params(
best_params_list=best_params_list,
@@ -500,9 +500,9 @@
experiment.runner = None

return BenchmarkResult(
-name=scheduler.experiment.name,
+name=orchestrator.experiment.name,
seed=seed,
-experiment=scheduler.experiment,
+experiment=orchestrator.experiment,
oracle_trace=oracle_trace,
inference_trace=inference_trace,
optimization_trace=optimization_trace,
@@ -565,7 +565,7 @@ def compute_baseline_value_from_sobol(
problem=dummy_problem,
method=method,
seed=i,
-scheduler_logging_level=WARNING,
+orchestrator_logging_level=WARNING,
)
values[i] = result.optimization_trace[-1]

@@ -576,15 +576,15 @@ def benchmark_one_method_problem(
problem: BenchmarkProblem,
method: BenchmarkMethod,
seeds: Iterable[int],
-scheduler_logging_level: int = DEFAULT_LOG_LEVEL,
+orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
) -> AggregatedBenchmarkResult:
return AggregatedBenchmarkResult.from_benchmark_results(
results=[
benchmark_replication(
problem=problem,
method=method,
seed=seed,
-scheduler_logging_level=scheduler_logging_level,
+orchestrator_logging_level=orchestrator_logging_level,
)
for seed in seeds
]
@@ -595,7 +595,7 @@ def benchmark_multiple_problems_methods(
problems: Iterable[BenchmarkProblem],
methods: Iterable[BenchmarkMethod],
seeds: Iterable[int],
-scheduler_logging_level: int = DEFAULT_LOG_LEVEL,
+orchestrator_logging_level: int = DEFAULT_LOG_LEVEL,
) -> list[AggregatedBenchmarkResult]:
"""
For each `problem` and `method` in the Cartesian product of `problems` and
@@ -608,7 +608,7 @@
problem=p,
method=m,
seeds=seeds,
-scheduler_logging_level=scheduler_logging_level,
+orchestrator_logging_level=orchestrator_logging_level,
)
for p, m in product(problems, methods)
]
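A note on the loop above: as the inline comments in `benchmark_replication` explain, the timeout is re-checked after each yielded result so that `timeout_hours` bounds the whole optimization loop rather than each trial. A self-contained sketch of that pattern follows; the helper name is hypothetical, while `run_trials_and_yield_results` and its arguments are taken from the diff above.

from time import monotonic

def run_with_total_timeout(orchestrator, max_trials: int, timeout_hours: float) -> None:
    # Hypothetical helper illustrating the decrementing-timeout pattern:
    # measure elapsed wall-clock time after each result and stop once the
    # total budget is spent, instead of giving every trial a fresh timeout.
    start = monotonic()
    for _ in orchestrator.run_trials_and_yield_results(
        max_trials=max_trials,
        timeout_hours=timeout_hours,
    ):
        elapsed_hours = (monotonic() - start) / 3600
        if timeout_hours - elapsed_hours <= 0:
            # Mirrors the warn-and-break behavior in benchmark_replication.
            break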
10 changes: 5 additions & 5 deletions ax/benchmark/benchmark_method.py
@@ -25,7 +25,7 @@
@dataclass(kw_only=True)
class BenchmarkMethod(Base):
"""Benchmark method, represented in terms of Ax generation strategy (which tells us
-which models to use when) and scheduler options (which tell us extra execution
+which models to use when) and Orchestrator options (which tell us extra execution
information like maximum parallelism, early stopping configuration, etc.).
Args:
Expand All @@ -44,9 +44,9 @@ class BenchmarkMethod(Base):
``NotImplementedError``.
batch_size: Number of arms per trial. If greater than 1, trials are
``BatchTrial``s; otherwise, they are ``Trial``s. Defaults to 1. This
-and the following arguments are passed to ``SchedulerOptions``.
-run_trials_in_batches: Passed to ``SchedulerOptions``.
-max_pending_trials: Passed to ``SchedulerOptions``.
+and the following arguments are passed to ``OrchestratorOptions``.
+run_trials_in_batches: Passed to ``OrchestratorOptions``.
+max_pending_trials: Passed to ``OrchestratorOptions``.
"""

name: str = "DEFAULT"
@@ -111,7 +111,7 @@ def _get_first_parameterization_from_last_trial() -> TParameterization:
return experiment.trials[max(experiment.trials)].arms[0].parameters

# SOO, n=1 case.
-# Note: This has the same effect as Scheduler.get_best_parameters
+# Note: This has the same effect as orchestrator.get_best_parameters
if len(experiment.trials_by_status[TrialStatus.COMPLETED]) == 0:
return [_get_first_parameterization_from_last_trial()]
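As the `BenchmarkMethod` docstring earlier in this file notes, `batch_size`, `run_trials_in_batches`, and `max_pending_trials` are passed through to ``OrchestratorOptions``. A hedged construction sketch (the method name is arbitrary and `generation_strategy` is assumed to exist; fields are those named in the docstring above):

from ax.benchmark.benchmark_method import BenchmarkMethod

method = BenchmarkMethod(
    name="example_method",  # hypothetical; defaults to "DEFAULT"
    generation_strategy=generation_strategy,  # assumed: an existing GenerationStrategy
    batch_size=2,  # >1 (or include_sq) => BatchTrials, per get_benchmark_orchestrator_options
    run_trials_in_batches=False,  # forwarded to OrchestratorOptions
    max_pending_trials=1,  # forwarded to OrchestratorOptions
)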

2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_metric.py
@@ -178,7 +178,7 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult
)
# The BackendSimulator distinguishes between queued and running
# trials "for testing particular initialization cases", but these
# are all "running" to Scheduler.
# are all "running" to orchestrator.
start_time = none_throws(sim_trial.sim_start_time)

if sim_trial.sim_completed_time is None: # Still running
2 changes: 1 addition & 1 deletion ax/benchmark/benchmark_runner.py
@@ -166,7 +166,7 @@ class BenchmarkRunner(Runner):
(in ``TParameterization`` format) and returns the runtime of a step.
max_concurrency: The maximum number of trials that can be running at a
given time. Typically, this is ``max_pending_trials`` from the
-``scheduler_options`` on the ``BenchmarkMethod``.
+``orchestrator_options`` on the ``BenchmarkMethod``.
"""

test_function: BenchmarkTestFunction
11 changes: 2 additions & 9 deletions ax/benchmark/methods/modular_botorch.py
@@ -50,11 +50,8 @@ def get_sobol_mbm_generation_strategy(
model_cls: BoTorch model class, e.g. SingleTaskGP
acquisition_cls: Acquisition function class, e.g.
`qLogNoisyExpectedImprovement`.
-scheduler_options: Passed as-is to scheduler. Default:
-`get_benchmark_scheduler_options()`.
name: Name that will be attached to the `GenerationStrategy`.
-num_sobol_trials: Number of Sobol trials; if the scheduler_options
-specify to use `BatchTrial`s, then this refers to the number of
+num_sobol_trials: Number of Sobol trials; can refer to the number of
`BatchTrial`s.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.
Expand All @@ -64,7 +61,6 @@ def get_sobol_mbm_generation_strategy(
>>> from ax.benchmark.methods.sobol_botorch_modular import (
... get_sobol_mbm_generation_strategy
... )
->>> from ax.benchmark.benchmark_method import get_benchmark_scheduler_options
>>> gs = get_sobol_mbm_generation_strategy(
... model_cls=SingleTaskGP,
... acquisition_cls=qLogNoisyExpectedImprovement,
@@ -122,10 +118,8 @@ def get_sobol_botorch_modular_acquisition(
acquisition_cls: Acquisition function class, e.g.
`qLogNoisyExpectedImprovement`.
distribute_replications: Whether to use multiple machines
-scheduler_options: Passed as-is to scheduler. Default:
-`get_benchmark_scheduler_options()`.
name: Name that will be attached to the `GenerationStrategy`.
-num_sobol_trials: Number of Sobol trials; if the scheduler_options
+num_sobol_trials: Number of Sobol trials; if the orchestrator_options
specify to use `BatchTrial`s, then this refers to the number of
`BatchTrial`s.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
Expand All @@ -138,7 +132,6 @@ def get_sobol_botorch_modular_acquisition(
>>> from ax.benchmark.methods.sobol_botorch_modular import (
... get_sobol_botorch_modular_acquisition
... )
->>> from ax.benchmark.benchmark_method import get_benchmark_scheduler_options
>>>
>>> method = get_sobol_botorch_modular_acquisition(
... model_cls=SingleTaskGP,
6 changes: 4 additions & 2 deletions ax/benchmark/problems/hpo/torchvision.py
@@ -78,7 +78,9 @@ def train_and_evaluate(
weight_decay=weight_decay,
)

-scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
+lr_scheduler = optim.lr_scheduler.StepLR(
+optimizer, step_size=step_size, gamma=gamma
+)

for inputs, labels in train_loader:
inputs = inputs.to(device=device)
@@ -92,7 +94,7 @@
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
-scheduler.step()
+lr_scheduler.step()

# Evaluate
net.eval()
10 changes: 5 additions & 5 deletions ax/benchmark/tests/methods/test_methods.py
@@ -11,15 +11,15 @@
import numpy as np
from ax.benchmark.benchmark import (
benchmark_replication,
+get_benchmark_orchestrator_options,
get_benchmark_runner,
-get_benchmark_scheduler_options,
)
from ax.benchmark.methods.modular_botorch import get_sobol_botorch_modular_acquisition
from ax.benchmark.methods.sobol import get_sobol_benchmark_method
from ax.benchmark.problems.registry import get_problem
from ax.core.experiment import Experiment
from ax.modelbridge.registry import Generators
-from ax.service.scheduler import Scheduler
+from ax.service.orchestrator import Orchestrator
from ax.service.utils.best_point import (
get_best_by_raw_objective_with_trial_index,
get_best_parameters_from_model_predictions_with_trial_index,
@@ -137,14 +137,14 @@ def _test_get_best_parameters(self, use_model_predictions: bool) -> None:
runner=get_benchmark_runner(problem=problem),
)

-scheduler = Scheduler(
+orchestrator = Orchestrator(
experiment=experiment,
generation_strategy=method.generation_strategy.clone_reset(),
-options=get_benchmark_scheduler_options(method=method),
+options=get_benchmark_orchestrator_options(method=method),
)

with with_rng_seed(seed=0):
-scheduler.run_n_trials(max_trials=problem.num_trials)
+orchestrator.run_n_trials(max_trials=problem.num_trials)

# because the second trial is a BoTorch trial, the model should be used
best_point_mixin_path = "ax.service.utils.best_point_mixin.best_point_utils."
(Diffs for the remaining 28 changed files are not shown.)
