Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/how-to-guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ How-to Guides
how-to-guides/set-outcome-constraints.rst
how-to-guides/acquire-baseline.rst
how-to-guides/tiled-databroker.rst
how-to-guides/global-stopping-strategies.rst
167 changes: 167 additions & 0 deletions docs/source/how-to-guides/global-stopping-strategies.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
.. testsetup::

from typing import Any
import time
import logging

from bluesky.protocols import NamedMovable, Readable, Status, Hints, HasHints, HasParent
from bluesky.run_engine import RunEngine
from bluesky.callbacks.tiled_writer import TiledWriter
from bluesky.callbacks.best_effort import BestEffortCallback

from tiled.client.container import Container
from tiled.client import from_uri
from tiled.server import SimpleTiledServer

class AlwaysSuccessfulStatus(Status):
def add_callback(self, callback) -> None:
callback(self)

def exception(self, timeout = 0.0):
return None

@property
def done(self) -> bool:
return True

@property
def success(self) -> bool:
return True

class ReadableSignal(Readable, HasHints, HasParent):
def __init__(self, name: str) -> None:
self._name = name
self._value = 0.0

@property
def name(self) -> str:
return self._name

@property
def hints(self) -> Hints:
return {
"fields": [self._name],
"dimensions": [],
"gridding": "rectilinear",
}

@property
def parent(self) -> Any | None:
return None

def read(self):
return {
self._name: { "value": self._value, "timestamp": time.time() }
}

def describe(self):
return {
self._name: { "source": self._name, "dtype": "number", "shape": [] }
}

class MovableSignal(ReadableSignal, NamedMovable):
def __init__(self, name: str, initial_value: float = 0.0) -> None:
super().__init__(name)
self._value: float = initial_value

def set(self, value: float) -> Status:
self._value = value
return AlwaysSuccessfulStatus()

# Start a local Tiled server for data storage
tiled_server = SimpleTiledServer()

# Set up the Bluesky RunEngine and connect it to Tiled
RE = RunEngine({})
tiled_client = from_uri(tiled_server.uri)
tiled_writer = TiledWriter(tiled_client)
RE.subscribe(tiled_writer)
bec = BestEffortCallback()
bec.disable_plots()
RE.subscribe(bec)

x1 = MovableSignal("x1", initial_value=0.1)
x2 = MovableSignal("x2", initial_value=0.23)

.. testcleanup::

# Suppress stdout from server.close() otherwise the doctest will fail
import os
import contextlib

with contextlib.redirect_stdout(open(os.devnull, "w")):
tiled_server.close()

Using a global stopping strategy
==================================
This guide will show you how to use a global stopping strategy. This allows you to stop an optimization early based on certain criteria, such as lack of improvement over a series of trials.

Define the stopping strategy
----------------------------
You will need to define the following parameters:
1. The minimum number of trials `min_trials` before checking for improvement
2. The window size `window_size`, how many of the most recent trials to consider when checking for improvement
3. The improvement bar `improvement_bar`, the theshold for considering improvement relative to the interquartile range of values seen so far. Must be >= 0

.. testcode::

from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy

stopping_strategy = ImprovementGlobalStoppingStrategy(
min_trials=10,
window_size=5,
improvement_bar=0.1,
)

Configure an agent
------------------

.. testcode::

from blop.ax import Agent, RangeDOF, Objective

class Himmelblau2DEvaluation():
def __init__(self, tiled_client: Container):
self.tiled_client = tiled_client

def __call__(self, uid: str, suggestions: list[dict]) -> list[dict]:
run = self.tiled_client[uid]
outcomes = []
x1_data = run["primary/x1"].read()
x2_data = run["primary/x2"].read()

for suggestion in suggestions:
suggestion_id = suggestion["_id"]
x1 = x1_data[suggestion_id % len(x1_data)]
x2 = x2_data[suggestion_id % len(x2_data)]
# Himmelblau function
outcomes.append({
"himmelblau_2d": (x1 ** 2 + x2 - 11) ** 2 + (x1 + x2 ** 2 - 7) ** 2,
"_id": suggestion_id
})

return outcomes

dofs = [
RangeDOF(actuator=x1, bounds=(-5.0, 5.0), parameter_type="float"),
RangeDOF(actuator=x2, bounds=(-5.0, 5.0), parameter_type="float"),
]

objectives = [
Objective(name="himmelblau_2d", minimize=False),
]

agent = Agent(
sensors=[],
dofs=dofs,
objectives=objectives,
stopping_strategy=stopping_strategy,
evaluation=Himmelblau2DEvaluation(tiled_client),
)

Run the experiment with Bluesky
-------------------------------
The experiment will stop early only if the stopping criteria are met. Otherwise, it will continue for the full number of iterations.
.. testcode::

RE(agent.optimize(iterations=10000, n_points=1))
16 changes: 15 additions & 1 deletion src/blop/ax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ax import Client
from ax.analysis import ContourPlot
from ax.analysis.analysis_card import AnalysisCardBase
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from bluesky.utils import MsgGenerator

from ..plans import acquire_baseline, optimize
Expand Down Expand Up @@ -42,6 +43,8 @@ class Agent:
Constraints on DOFs to refine the search space.
outcome_constraints : Sequence[OutcomeConstraint] | None, optional
Constraints on outcomes to be satisfied during optimization.
stopping_strategy : BaseGlobalStoppingStrategy | None, optional
A global stopping strategy to determine when/if to stop optimization early.
**kwargs : Any
Additional keyword arguments to configure the Ax experiment.

Expand Down Expand Up @@ -71,6 +74,7 @@ def __init__(
acquisition_plan: AcquisitionPlan | None = None,
dof_constraints: Sequence[DOFConstraint] | None = None,
outcome_constraints: Sequence[OutcomeConstraint] | None = None,
stopping_strategy: BaseGlobalStoppingStrategy | None = None,
**kwargs: Any,
):
self._sensors = sensors
Expand All @@ -80,6 +84,7 @@ def __init__(
self._acquisition_plan = acquisition_plan
self._dof_constraints = dof_constraints
self._outcome_constraints = outcome_constraints
self._stopping_strategy = stopping_strategy
self._optimizer = AxOptimizer(
parameters=[dof.to_ax_parameter_config() for dof in dofs],
objective=to_ax_objective_str(objectives),
Expand Down Expand Up @@ -124,6 +129,10 @@ def outcome_constraints(self) -> Sequence[OutcomeConstraint] | None:
def ax_client(self) -> Client:
return self._optimizer.ax_client

@property
def stopping_strategy(self) -> BaseGlobalStoppingStrategy | None:
return self._stopping_strategy

def to_optimization_problem(self) -> OptimizationProblem:
"""
Construct an optimization problem from the agent.
Expand Down Expand Up @@ -254,7 +263,12 @@ def optimize(self, iterations: int = 1, n_points: int = 1) -> MsgGenerator[None]
suggest : Get point suggestions without running acquisition.
ingest : Manually ingest evaluation results.
"""
yield from optimize(self.to_optimization_problem(), iterations=iterations, n_points=n_points)
yield from optimize(
self.to_optimization_problem(),
stopping_strategy=self._stopping_strategy,
iterations=iterations,
n_points=n_points,
)

def plot_objective(
self, x_dof_name: str, y_dof_name: str, objective_name: str, *args: Any, **kwargs: Any
Expand Down
11 changes: 11 additions & 0 deletions src/blop/plans/plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import bluesky.plan_stubs as bps
import bluesky.plans as bp
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
from bluesky.protocols import Readable, Reading
from bluesky.utils import MsgGenerator, plan

Expand Down Expand Up @@ -126,6 +127,7 @@ def optimize(
optimization_problem: OptimizationProblem,
iterations: int = 1,
n_points: int = 1,
stopping_strategy: BaseGlobalStoppingStrategy | None = None,
*args: Any,
**kwargs: Any,
) -> MsgGenerator[None]:
Expand All @@ -140,10 +142,19 @@ def optimize(
The number of optimization iterations to run.
n_points : int, optional
The number of points to suggest per iteration.
stopping_strategy : BaseGlobalStoppingStrategy | None, optional
A global stopping strategy to determine when/if to stop optimization early.
"""

for _ in range(iterations):
yield from optimize_step(optimization_problem, n_points, *args, **kwargs)
if stopping_strategy is not None:
should_stop, message = stopping_strategy.should_stop_optimization(
experiment=optimization_problem.optimizer.ax_client._experiment # type: ignore[attr-defined]
)
if should_stop:
print(f"Global stopping strategy triggered optimization stop: {message}")
break


@plan
Expand Down