From cd565fc3ae5a84773d1be94c4e115551355a9075 Mon Sep 17 00:00:00 2001 From: jessica-moylan Date: Fri, 9 Jan 2026 15:59:40 -0500 Subject: [PATCH 1/2] adding global stopping but has error with pre-commit --- docs/source/how-to-guides.rst | 1 + .../global-stopping-strategies.rst | 167 ++++++++++++++++++ src/blop/ax/agent.py | 16 +- src/blop/plans/plans.py | 11 ++ 4 files changed, 194 insertions(+), 1 deletion(-) create mode 100644 docs/source/how-to-guides/global-stopping-strategies.rst diff --git a/docs/source/how-to-guides.rst b/docs/source/how-to-guides.rst index e4a1780b..b179a22e 100644 --- a/docs/source/how-to-guides.rst +++ b/docs/source/how-to-guides.rst @@ -11,3 +11,4 @@ How-to Guides how-to-guides/set-outcome-constraints.rst how-to-guides/acquire-baseline.rst how-to-guides/tiled-databroker.rst + how-to-guides/global-stopping-strategies.rst diff --git a/docs/source/how-to-guides/global-stopping-strategies.rst b/docs/source/how-to-guides/global-stopping-strategies.rst new file mode 100644 index 00000000..1a1063cf --- /dev/null +++ b/docs/source/how-to-guides/global-stopping-strategies.rst @@ -0,0 +1,167 @@ +.. testsetup:: + + from typing import Any + import time + import logging + + from bluesky.protocols import NamedMovable, Readable, Status, Hints, HasHints, HasParent + from bluesky.run_engine import RunEngine + from bluesky.callbacks.tiled_writer import TiledWriter + from bluesky.callbacks.best_effort import BestEffortCallback + + from tiled.client.container import Container + from tiled.client import from_uri + from tiled.server import SimpleTiledServer + + class AlwaysSuccessfulStatus(Status): + def add_callback(self, callback) -> None: + callback(self) + + def exception(self, timeout = 0.0): + return None + + @property + def done(self) -> bool: + return True + + @property + def success(self) -> bool: + return True + + class ReadableSignal(Readable, HasHints, HasParent): + def __init__(self, name: str) -> None: + self._name = name + self._value = 0.0 + + @property + def name(self) -> str: + return self._name + + @property + def hints(self) -> Hints: + return { + "fields": [self._name], + "dimensions": [], + "gridding": "rectilinear", + } + + @property + def parent(self) -> Any | None: + return None + + def read(self): + return { + self._name: { "value": self._value, "timestamp": time.time() } + } + + def describe(self): + return { + self._name: { "source": self._name, "dtype": "number", "shape": [] } + } + + class MovableSignal(ReadableSignal, NamedMovable): + def __init__(self, name: str, initial_value: float = 0.0) -> None: + super().__init__(name) + self._value: float = initial_value + + def set(self, value: float) -> Status: + self._value = value + return AlwaysSuccessfulStatus() + + # Start a local Tiled server for data storage + tiled_server = SimpleTiledServer() + + # Set up the Bluesky RunEngine and connect it to Tiled + RE = RunEngine({}) + tiled_client = from_uri(tiled_server.uri) + tiled_writer = TiledWriter(tiled_client) + RE.subscribe(tiled_writer) + bec = BestEffortCallback() + bec.disable_plots() + RE.subscribe(bec) + + x1 = MovableSignal("x1", initial_value=0.1) + x2 = MovableSignal("x2", initial_value=0.23) + +.. testcleanup:: + + # Suppress stdout from server.close() otherwise the doctest will fail + import os + import contextlib + + with contextlib.redirect_stdout(open(os.devnull, "w")): + tiled_server.close() + +Using a global stopping strategy +================================== +This guide will show you how to use a global stopping strategy. This allows you to stop an optimization early based on certain criteria, such as lack of improvement over a series of trials. + +Define the stopping strategy +---------------------------- +You will need to define the following parameters: +1. The minimum number of trials `min_trials` before checking for improvement +2. The window size `window_size`, how many of the most recent trials to consider when checking for improvement +3. The improvement bar `improvement_bar`, the theshold for considering improvement relative to the interquartile range of values seen so far. Must be >= 0 + +.. testcode:: + + from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy + + stopping_strategy = ImprovementGlobalStoppingStrategy( + min_trials=10, + window_size=5, + improvement_bar=0.1, + ) + +Configure an agent +------------------ + +.. testcode:: + + from blop.ax import Agent, RangeDOF, Objective + + class Himmelblau2DEvaluation(): + def __init__(self, tiled_client: Container): + self.tiled_client = tiled_client + + def __call__(self, uid: str, suggestions: list[dict]) -> list[dict]: + run = self.tiled_client[uid] + outcomes = [] + x1_data = run["primary/x1"].read() + x2_data = run["primary/x2"].read() + + for suggestion in suggestions: + suggestion_id = suggestion["_id"] + x1 = x1_data[suggestion_id % len(x1_data)] + x2 = x2_data[suggestion_id % len(x2_data)] + # Himmelblau function + outcomes.append({ + "himmelblau_2d": (x1 ** 2 + x2 - 11) ** 2 + (x1 + x2 ** 2 - 7) ** 2, + "_id": suggestion_id + }) + + return outcomes + + dofs = [ + RangeDOF(actuator=x1, bounds=(-5.0, 5.0), parameter_type="float"), + RangeDOF(actuator=x2, bounds=(-5.0, 5.0), parameter_type="float"), + ] + + objectives = [ + Objective(name="himmelblau_2d", minimize=False), + ] + + agent = Agent( + sensors=[], + dofs=dofs, + objectives=objectives, + stopping_strategy=stopping_strategy, + evaluation=Himmelblau2DEvaluation(tiled_client), + ) + +Run the experiment with Bluesky +------------------------------- +The experiment will stop early only if the stopping criteria are met. Otherwise, it will continue for the full number of iterations. +.. testcode:: + + RE(agent.optimize(iterations=10000, n_points=1)) diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py index 3e491e84..507cd8ab 100644 --- a/src/blop/ax/agent.py +++ b/src/blop/ax/agent.py @@ -5,6 +5,7 @@ from ax import Client from ax.analysis import ContourPlot from ax.analysis.analysis_card import AnalysisCardBase +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from bluesky.utils import MsgGenerator from ..plans import acquire_baseline, optimize @@ -42,6 +43,8 @@ class Agent: Constraints on DOFs to refine the search space. outcome_constraints : Sequence[OutcomeConstraint] | None, optional Constraints on outcomes to be satisfied during optimization. + stopping_strategy : BaseGlobalStoppingStrategy | None, optional + A global stopping strategy to determine when/if to stop optimization early. **kwargs : Any Additional keyword arguments to configure the Ax experiment. @@ -71,6 +74,7 @@ def __init__( acquisition_plan: AcquisitionPlan | None = None, dof_constraints: Sequence[DOFConstraint] | None = None, outcome_constraints: Sequence[OutcomeConstraint] | None = None, + stopping_strategy: BaseGlobalStoppingStrategy | None = None, **kwargs: Any, ): self._sensors = sensors @@ -80,6 +84,7 @@ def __init__( self._acquisition_plan = acquisition_plan self._dof_constraints = dof_constraints self._outcome_constraints = outcome_constraints + self._stopping_strategy = stopping_strategy self._optimizer = AxOptimizer( parameters=[dof.to_ax_parameter_config() for dof in dofs], objective=to_ax_objective_str(objectives), @@ -124,6 +129,10 @@ def outcome_constraints(self) -> Sequence[OutcomeConstraint] | None: def ax_client(self) -> Client: return self._optimizer.ax_client + @property + def stopping_strategy(self) -> BaseGlobalStoppingStrategy | None: + return self._stopping_strategy + def to_optimization_problem(self) -> OptimizationProblem: """ Construct an optimization problem from the agent. @@ -254,7 +263,12 @@ def optimize(self, iterations: int = 1, n_points: int = 1) -> MsgGenerator[None] suggest : Get point suggestions without running acquisition. ingest : Manually ingest evaluation results. """ - yield from optimize(self.to_optimization_problem(), iterations=iterations, n_points=n_points) + yield from optimize( + self.to_optimization_problem(), + stopping_strategy=self._stopping_strategy, + iterations=iterations, + n_points=n_points, + ) def plot_objective( self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any diff --git a/src/blop/plans/plans.py b/src/blop/plans/plans.py index 08f5de88..088c875e 100644 --- a/src/blop/plans/plans.py +++ b/src/blop/plans/plans.py @@ -5,6 +5,7 @@ import bluesky.plan_stubs as bps import bluesky.plans as bp +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from bluesky.protocols import Readable, Reading from bluesky.utils import MsgGenerator, plan @@ -126,6 +127,7 @@ def optimize( optimization_problem: OptimizationProblem, iterations: int = 1, n_points: int = 1, + stopping_strategy: BaseGlobalStoppingStrategy | None = None, *args: Any, **kwargs: Any, ) -> MsgGenerator[None]: @@ -140,10 +142,19 @@ def optimize( The number of optimization iterations to run. n_points : int, optional The number of points to suggest per iteration. + stopping_strategy : BaseGlobalStoppingStrategy | None, optional + A global stopping strategy to determine when/if to stop optimization early. """ for _ in range(iterations): yield from optimize_step(optimization_problem, n_points, *args, **kwargs) + if stopping_strategy is not None: + should_stop, message = stopping_strategy.should_stop_optimization( + experiment=optimization_problem.optimizer.ax_client._experiment + ) + if should_stop: + print(f"Global stopping strategy triggered optimization stop: {message}") + break @plan From 8188894cc3107d3244625f2b4abf809fa2a8f773 Mon Sep 17 00:00:00 2001 From: jessica-moylan Date: Mon, 12 Jan 2026 10:24:57 -0500 Subject: [PATCH 2/2] fixed pyright --- src/blop/plans/plans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blop/plans/plans.py b/src/blop/plans/plans.py index 088c875e..d5f4791c 100644 --- a/src/blop/plans/plans.py +++ b/src/blop/plans/plans.py @@ -150,7 +150,7 @@ def optimize( yield from optimize_step(optimization_problem, n_points, *args, **kwargs) if stopping_strategy is not None: should_stop, message = stopping_strategy.should_stop_optimization( - experiment=optimization_problem.optimizer.ax_client._experiment + experiment=optimization_problem.optimizer.ax_client._experiment # type: ignore[attr-defined] ) if should_stop: print(f"Global stopping strategy triggered optimization stop: {message}")