diff --git a/neps/optimizers/__init__.py b/neps/optimizers/__init__.py
index 31cb4c4a..518952cd 100644
--- a/neps/optimizers/__init__.py
+++ b/neps/optimizers/__init__.py
@@ -1,11 +1,9 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import Callable, Mapping
+from typing import TYPE_CHECKING, Callable, Mapping
 
 from .base_optimizer import BaseOptimizer
-from .bayesian_optimization.cost_cooling import CostCooling
-from .bayesian_optimization.mf_tpe import MultiFidelityPriorWeightedTreeParzenEstimator
 from .bayesian_optimization.optimizer import BayesianOptimization
 from .grid_search.optimizer import GridSearch
 from .multi_fidelity.dyhpo import MFEIBO
@@ -26,13 +24,14 @@
 from .random_search.optimizer import RandomSearch
 from .regularized_evolution.optimizer import RegularizedEvolution
 
+if TYPE_CHECKING:
+    from .base_optimizer import BaseOptimizer
+
 # TODO: Rename Searcher to Optimizer...
 SearcherMapping: Mapping[str, Callable[..., BaseOptimizer]] = {
     "bayesian_optimization": BayesianOptimization,
     "pibo": partial(BayesianOptimization, disable_priors=False),
-    "cost_cooling_bayesian_optimization": CostCooling,
     "random_search": RandomSearch,
-    "cost_cooling": CostCooling,
     "regularized_evolution": RegularizedEvolution,
     "assisted_regularized_evolution": partial(RegularizedEvolution, assisted=True),
     "grid_search": GridSearch,
@@ -41,7 +40,6 @@
     "asha": AsynchronousSuccessiveHalving,
     "hyperband": Hyperband,
     "asha_prior": AsynchronousSuccessiveHalvingWithPriors,
-    "multifidelity_tpe": MultiFidelityPriorWeightedTreeParzenEstimator,
     "hyperband_custom_default": HyperbandCustomDefault,
     "priorband": PriorBand,
     "mobster": MOBSTER,
diff --git a/neps/distributions.py b/neps/sampling/distributions.py
similarity index 99%
rename from neps/distributions.py
rename to neps/sampling/distributions.py
index 2361e191..fb552949 100644
--- a/neps/distributions.py
+++ b/neps/sampling/distributions.py
@@ -225,6 +225,6 @@ def log_prob(self, value):
 
 
 @dataclass
-class DistributionOverDomain:
+class TorchDistributionWithDomain:
     distribution: Distribution
     domain: Domain
diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py
index 163679d8..8afaee62 100644
--- a/neps/state/neps_state.py
+++ b/neps/state/neps_state.py
@@ -32,75 +32,6 @@
 Loc = TypeVar("Loc")
 T = TypeVar("T")
 
-def sample_trial(
-    neps_state,
-    optimizer: BaseOptimizer,
-    *,
-    worker_id: str,
-    _sample_hooks: list[Callable] | None = None,
-) -> Trial:
-    """Sample a new trial from the optimizer.
-
-    Args:
-        optimizer: The optimizer to sample the trial from.
-        worker_id: The worker that is sampling the trial.
-        _sample_hooks: A list of hooks to apply to the optimizer before sampling.
-
-    Returns:
-        The new trial.
-    """
-    with neps_state._optimizer_state.acquire() as (
-        opt_state,
-        put_opt,
-    ), neps_state._seed_state.acquire() as (seed_state, put_seed_state):
-        trials: dict[Trial.ID, Trial] = {}
-        for trial_id, shared_trial in neps_state._trials.all().items():
-            trial = shared_trial.synced()
-            trials[trial_id] = trial
-
-        seed_state.set_as_global_seed_state()
-
-        # TODO: Not sure if any existing pre_load hooks required
-        # it to be done after `load_results`... I hope not.
-        if _sample_hooks is not None:
-            for hook in _sample_hooks:
-                optimizer = hook(optimizer)
-
-        # NOTE: We don't want optimizers mutating this before serialization
-        budget = opt_state.budget.clone() if opt_state.budget is not None else None
-        sampled_config, new_opt_state = optimizer.ask(
-            trials=trials,
-            budget_info=budget,
-            optimizer_state=opt_state.shared_state,
-        )
-
-        if sampled_config.previous_config_id is not None:
-            previous_trial = trials.get(sampled_config.previous_config_id)
-            if previous_trial is None:
-                raise ValueError(
-                    f"Previous trial '{sampled_config.previous_config_id}' not found."
-                )
-            previous_trial_location = previous_trial.metadata.location
-        else:
-            previous_trial_location = None
-
-        trial = Trial.new(
-            trial_id=sampled_config.id,
-            location="",  # HACK: This will be set by the `TrialRepo`
-            config=sampled_config.config,
-            previous_trial=sampled_config.previous_config_id,
-            previous_trial_location=previous_trial_location,
-            time_sampled=time.time(),
-            worker_id=worker_id,
-        )
-        shared_trial = neps_state._trials.put_new(trial)
-        seed_state.recapture()
-        put_seed_state(seed_state)
-        put_opt(
-            OptimizationState(budget=opt_state.budget, shared_state=new_opt_state)
-        )
-
-    return trial
 
 @dataclass
 class NePSState(Generic[Loc]):
@@ -140,10 +71,75 @@ def get_trials_by_ids(self, trial_ids: list[str], /) -> dict[str, Trial | None]:
             for _id, shared_trial in self._trials.get_by_ids(trial_ids).items()
         }
 
-    def get_optimizer_instance(self) -> BaseOptimizer:
-        """Get the optimizer instance."""
-        raise NotImplementedError
+    def sample_trial(
+        self,
+        optimizer: BaseOptimizer,
+        *,
+        worker_id: str,
+        _sample_hooks: list[Callable] | None = None,
+    ) -> Trial:
+        """Sample a new trial from the optimizer.
+
+        Args:
+            optimizer: The optimizer to sample the trial from.
+            worker_id: The worker that is sampling the trial.
+            _sample_hooks: A list of hooks to apply to the optimizer before sampling.
+        Returns:
+            The new trial.
+        """
+        with self._optimizer_state.acquire() as (
+            opt_state,
+            put_opt,
+        ), self._seed_state.acquire() as (seed_state, put_seed_state):
+            trials: dict[Trial.ID, Trial] = {}
+            for trial_id, shared_trial in self._trials.all().items():
+                trial = shared_trial.synced()
+                trials[trial_id] = trial
+
+            seed_state.set_as_global_seed_state()
+
+            # TODO: Not sure if any existing pre_load hooks required
+            # it to be done after `load_results`... I hope not.
+            if _sample_hooks is not None:
+                for hook in _sample_hooks:
+                    optimizer = hook(optimizer)
+
+            # NOTE: We don't want optimizers mutating this before serialization
+            budget = opt_state.budget.clone() if opt_state.budget is not None else None
+            sampled_config, new_opt_state = optimizer.ask(
+                trials=trials,
+                budget_info=budget,
+                optimizer_state=opt_state.shared_state,
+            )
+
+            if sampled_config.previous_config_id is not None:
+                previous_trial = trials.get(sampled_config.previous_config_id)
+                if previous_trial is None:
+                    raise ValueError(
+                        f"Previous trial '{sampled_config.previous_config_id}' not found."
+                    )
+                previous_trial_location = previous_trial.metadata.location
+            else:
+                previous_trial_location = None
+
+            trial = Trial.new(
+                trial_id=sampled_config.id,
+                location="",  # HACK: This will be set by the `TrialRepo`
+                config=sampled_config.config,
+                previous_trial=sampled_config.previous_config_id,
+                previous_trial_location=previous_trial_location,
+                time_sampled=time.time(),
+                worker_id=worker_id,
+            )
+            shared_trial = self._trials.put_new(trial)
+            seed_state.recapture()
+            put_seed_state(seed_state)
+            put_opt(
+                OptimizationState(budget=opt_state.budget, shared_state=new_opt_state)
+            )
+
+        return trial
 
     def report_trial_evaluation(
         self,
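Note on the neps/state/neps_state.py change: sample_trial moves from a module-level
function that received the state as its first argument to a method on NePSState, with
the locking, seeding, and ask/record flow otherwise unchanged. A minimal call-site
sketch under assumptions: `state` is an already-constructed NePSState, `space` is a
search space, and the `pipeline_space` constructor argument and `worker_id` value are
illustrative rather than taken from this diff:

    from neps.optimizers import SearcherMapping

    # Constructor arguments are assumed for illustration only.
    optimizer = SearcherMapping["bayesian_optimization"](pipeline_space=space)

    # Before this change: trial = sample_trial(state, optimizer, worker_id="worker-0")
    trial = state.sample_trial(optimizer, worker_id="worker-0")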