Skip to content

Commit

Permalink
fix: Some minor cleanup fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman committed Aug 29, 2024
1 parent 03729ca commit 16c27f8
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 79 deletions.
10 changes: 4 additions & 6 deletions neps/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

from functools import partial
from typing import Callable, Mapping
from typing import TYPE_CHECKING, Callable, Mapping

from .base_optimizer import BaseOptimizer
from .bayesian_optimization.cost_cooling import CostCooling
from .bayesian_optimization.mf_tpe import MultiFidelityPriorWeightedTreeParzenEstimator
from .bayesian_optimization.optimizer import BayesianOptimization
from .grid_search.optimizer import GridSearch
from .multi_fidelity.dyhpo import MFEIBO
Expand All @@ -26,13 +24,14 @@
from .random_search.optimizer import RandomSearch
from .regularized_evolution.optimizer import RegularizedEvolution

if TYPE_CHECKING:
from .base_optimizer import BaseOptimizer

# TODO: Rename Searcher to Optimizer...
SearcherMapping: Mapping[str, Callable[..., BaseOptimizer]] = {
"bayesian_optimization": BayesianOptimization,
"pibo": partial(BayesianOptimization, disable_priors=False),
"cost_cooling_bayesian_optimization": CostCooling,
"random_search": RandomSearch,
"cost_cooling": CostCooling,
"regularized_evolution": RegularizedEvolution,
"assisted_regularized_evolution": partial(RegularizedEvolution, assisted=True),
"grid_search": GridSearch,
Expand All @@ -41,7 +40,6 @@
"asha": AsynchronousSuccessiveHalving,
"hyperband": Hyperband,
"asha_prior": AsynchronousSuccessiveHalvingWithPriors,
"multifidelity_tpe": MultiFidelityPriorWeightedTreeParzenEstimator,
"hyperband_custom_default": HyperbandCustomDefault,
"priorband": PriorBand,
"mobster": MOBSTER,
Expand Down
2 changes: 1 addition & 1 deletion neps/distributions.py → neps/sampling/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ def log_prob(self, value):


@dataclass
class DistributionOverDomain:
class TorchDistributionWithDomain:
distribution: Distribution
domain: Domain
140 changes: 68 additions & 72 deletions neps/state/neps_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,75 +32,6 @@
Loc = TypeVar("Loc")
T = TypeVar("T")

def sample_trial(
    neps_state,
    optimizer: BaseOptimizer,
    *,
    worker_id: str,
    _sample_hooks: list[Callable] | None = None,
) -> Trial:
    """Sample a new trial from the optimizer.

    Holds both the optimizer-state lock and the seed-state lock for the whole
    sampling step so that seeding, asking the optimizer, registering the new
    trial and writing back the updated states happen atomically with respect
    to other workers.

    Args:
        neps_state: The shared state object holding the trials repo, the seed
            state and the optimizer state.  # presumably a NePSState — TODO confirm
        optimizer: The optimizer to sample the trial from.
        worker_id: The worker that is sampling the trial.
        _sample_hooks: A list of hooks to apply to the optimizer before sampling.

    Returns:
        The new trial.

    Raises:
        ValueError: If the sampled config references a previous config id that
            is not among the known trials.
    """
    with neps_state._optimizer_state.acquire() as (
        opt_state,
        put_opt,
    ), neps_state._seed_state.acquire() as (seed_state, put_seed_state):
        # Snapshot every known trial so the optimizer sees a consistent view.
        trials: dict[Trial.ID, Trial] = {}
        for trial_id, shared_trial in neps_state._trials.all().items():
            trial = shared_trial.synced()
            trials[trial_id] = trial

        # Restore the stored RNG state so sampling is reproducible across workers.
        seed_state.set_as_global_seed_state()

        # TODO: Not sure if any existing pre_load hooks required
        # it to be done after `load_results`... I hope not.
        if _sample_hooks is not None:
            for hook in _sample_hooks:
                optimizer = hook(optimizer)

        # NOTE: We don't want optimizers mutating this before serialization
        budget = opt_state.budget.clone() if opt_state.budget is not None else None
        sampled_config, new_opt_state = optimizer.ask(
            trials=trials,
            budget_info=budget,
            optimizer_state=opt_state.shared_state,
        )

        # If this config continues an earlier one, resolve where that previous
        # trial lives so the new trial can reference its location.
        if sampled_config.previous_config_id is not None:
            previous_trial = trials.get(sampled_config.previous_config_id)
            if previous_trial is None:
                raise ValueError(
                    f"Previous trial '{sampled_config.previous_config_id}' not found."
                )
            previous_trial_location = previous_trial.metadata.location
        else:
            previous_trial_location = None

        trial = Trial.new(
            trial_id=sampled_config.id,
            location="",  # HACK: This will be set by the `TrialRepo`
            config=sampled_config.config,
            previous_trial=sampled_config.previous_config_id,
            previous_trial_location=previous_trial_location,
            time_sampled=time.time(),
            worker_id=worker_id,
        )
        shared_trial = neps_state._trials.put_new(trial)
        # Capture the advanced RNG state and persist both states before the
        # locks are released, so no other worker observes a partial update.
        seed_state.recapture()
        put_seed_state(seed_state)
        put_opt(
            OptimizationState(budget=opt_state.budget, shared_state=new_opt_state)
        )

    return trial

@dataclass
class NePSState(Generic[Loc]):
Expand Down Expand Up @@ -140,10 +71,75 @@ def get_trials_by_ids(self, trial_ids: list[str], /) -> dict[str, Trial | None]:
for _id, shared_trial in self._trials.get_by_ids(trial_ids).items()
}

    def get_optimizer_instance(self) -> BaseOptimizer:
        """Get the optimizer instance.

        Raises:
            NotImplementedError: Always; this accessor is not implemented yet.
        """
        raise NotImplementedError
    def sample_trial(
        self,
        optimizer: BaseOptimizer,
        *,
        worker_id: str,
        _sample_hooks: list[Callable] | None = None,
    ) -> Trial:
        """Sample a new trial from the optimizer.

        Holds both the optimizer-state lock and the seed-state lock for the
        whole sampling step so that seeding, asking the optimizer, registering
        the new trial and writing back the updated states happen atomically
        with respect to other workers.

        Args:
            optimizer: The optimizer to sample the trial from.
            worker_id: The worker that is sampling the trial.
            _sample_hooks: A list of hooks to apply to the optimizer before sampling.

        Returns:
            The new trial.

        Raises:
            ValueError: If the sampled config references a previous config id
                that is not among the known trials.
        """
        with self._optimizer_state.acquire() as (
            opt_state,
            put_opt,
        ), self._seed_state.acquire() as (seed_state, put_seed_state):
            # Snapshot every known trial so the optimizer sees a consistent view.
            trials: dict[Trial.ID, Trial] = {}
            for trial_id, shared_trial in self._trials.all().items():
                trial = shared_trial.synced()
                trials[trial_id] = trial

            # Restore the stored RNG state so sampling is reproducible across workers.
            seed_state.set_as_global_seed_state()

            # TODO: Not sure if any existing pre_load hooks required
            # it to be done after `load_results`... I hope not.
            if _sample_hooks is not None:
                for hook in _sample_hooks:
                    optimizer = hook(optimizer)

            # NOTE: We don't want optimizers mutating this before serialization
            budget = opt_state.budget.clone() if opt_state.budget is not None else None
            sampled_config, new_opt_state = optimizer.ask(
                trials=trials,
                budget_info=budget,
                optimizer_state=opt_state.shared_state,
            )

            # If this config continues an earlier one, resolve where that
            # previous trial lives so the new trial can reference its location.
            if sampled_config.previous_config_id is not None:
                previous_trial = trials.get(sampled_config.previous_config_id)
                if previous_trial is None:
                    raise ValueError(
                        f"Previous trial '{sampled_config.previous_config_id}' not found."
                    )
                previous_trial_location = previous_trial.metadata.location
            else:
                previous_trial_location = None

            trial = Trial.new(
                trial_id=sampled_config.id,
                location="",  # HACK: This will be set by the `TrialRepo`
                config=sampled_config.config,
                previous_trial=sampled_config.previous_config_id,
                previous_trial_location=previous_trial_location,
                time_sampled=time.time(),
                worker_id=worker_id,
            )
            shared_trial = self._trials.put_new(trial)
            # Capture the advanced RNG state and persist both states before
            # the locks are released, so no other worker observes a partial update.
            seed_state.recapture()
            put_seed_state(seed_state)
            put_opt(
                OptimizationState(budget=opt_state.budget, shared_state=new_opt_state)
            )

        return trial

def report_trial_evaluation(
self,
Expand Down

0 comments on commit 16c27f8

Please sign in to comment.