diff --git a/docs/source/how-to-guides/acquire-baseline.rst b/docs/source/how-to-guides/acquire-baseline.rst
index 6660ccae..bdb6e451 100644
--- a/docs/source/how-to-guides/acquire-baseline.rst
+++ b/docs/source/how-to-guides/acquire-baseline.rst
@@ -130,7 +130,7 @@ Here we configure an agent with three DOFs and two objectives. The second object
         sensors=[readable1, readable2],
         dofs=dofs,
         objectives=objectives,
-        evaluation=evaluation_function,
+        evaluation_function=evaluation_function,
         outcome_constraints=outcome_constraints,
     )
diff --git a/docs/source/how-to-guides/attach-data-to-experiments.rst b/docs/source/how-to-guides/attach-data-to-experiments.rst
index d92ff85e..69203e79 100644
--- a/docs/source/how-to-guides/attach-data-to-experiments.rst
+++ b/docs/source/how-to-guides/attach-data-to-experiments.rst
@@ -133,7 +133,7 @@ The ``DOF`` and ``Objective`` names must match the keys in the data dictionaries
         sensors=[readable1, readable2],
         dofs=dofs,
         objectives=objectives,
-        evaluation=evaluation_function,
+        evaluation_function=evaluation_function,
     )
 
 Ingest your data
diff --git a/docs/source/how-to-guides/custom-generation-strategies.rst b/docs/source/how-to-guides/custom-generation-strategies.rst
index 548789ff..854171ad 100644
--- a/docs/source/how-to-guides/custom-generation-strategies.rst
+++ b/docs/source/how-to-guides/custom-generation-strategies.rst
@@ -123,7 +123,7 @@ Configure an agent
         sensors=[readable1, readable2],
         dofs=dofs,
         objectives=objectives,
-        evaluation=evaluation_function,
+        evaluation_function=evaluation_function,
     )
 
 Configure a generation strategy
diff --git a/docs/source/how-to-guides/set-dof-constraints.rst b/docs/source/how-to-guides/set-dof-constraints.rst
index c8e1cc2b..c793cb17 100644
--- a/docs/source/how-to-guides/set-dof-constraints.rst
+++ b/docs/source/how-to-guides/set-dof-constraints.rst
@@ -124,6 +124,6 @@ Configure an agent with DOF constraints
         sensors=[],
         dofs=[dof1, dof2, dof3],
         objectives=[objective],
-        evaluation=evaluation_function,
+        evaluation_function=evaluation_function,
         dof_constraints=[constraint],
     )
diff --git a/docs/source/how-to-guides/set-outcome-constraints.rst b/docs/source/how-to-guides/set-outcome-constraints.rst
index 12b7617a..7784b159 100644
--- a/docs/source/how-to-guides/set-outcome-constraints.rst
+++ b/docs/source/how-to-guides/set-outcome-constraints.rst
@@ -138,6 +138,6 @@ Configure an agent with outcome constraints
         sensors=[],
         dofs=dofs,
         objectives=objectives,
-        evaluation=evaluation_function,
+        evaluation_function=evaluation_function,
         outcome_constraints=[constraint],
     )
diff --git a/docs/source/how-to-guides/tiled-databroker.rst b/docs/source/how-to-guides/tiled-databroker.rst
index a0561031..674fd873 100644
--- a/docs/source/how-to-guides/tiled-databroker.rst
+++ b/docs/source/how-to-guides/tiled-databroker.rst
@@ -189,7 +189,7 @@ Configure an agent
         sensors=[motor_x],
         dofs=[dof1],
         objectives=[objective],
-        evaluation=TiledEvaluation(tiled_client=tiled_client),
+        evaluation_function=TiledEvaluation(tiled_client=tiled_client),
     )
     RE(agent.optimize())
     server.close()
@@ -205,6 +205,6 @@ or for Databroker:
         sensors=[motor_x],
         dofs=[dof1],
         objectives=[objective],
-        evaluation=DatabrokerEvaluation(db=db),
+        evaluation_function=DatabrokerEvaluation(db=db),
     )
     RE(agent_db.optimize())
diff --git a/docs/source/how-to-guides/use-ophyd-devices.rst b/docs/source/how-to-guides/use-ophyd-devices.rst
index 379e1450..8c05ba08 100644
--- a/docs/source/how-to-guides/use-ophyd-devices.rst
+++ b/docs/source/how-to-guides/use-ophyd-devices.rst
@@ -23,7 +23,7 @@ The ``name`` attribute of the signal will be used as the name of the :class:`blo
         sensors=[some_readable_signal],
         dofs=[dof],
         objectives=[Objective(name="result", minimize=False)],
-        evaluation=lambda uid, suggestions: [{"result": 0.1}],
+        evaluation_function=lambda uid, suggestions: [{"result": 0.1}],
     )
 
 Ophyd-async devices
@@ -48,7 +48,7 @@ Once again, the ``name`` attribute of the signal will be used as the name of the
         sensors=[some_readable_signal],
         dofs=[dof],
         objectives=[Objective(name="result", minimize=False)],
-        evaluation=lambda uid, suggestions: [{"result": 0.1}],
+        evaluation_function=lambda uid, suggestions: [{"result": 0.1}],
     )
 
 Using your devices in custom acquisition plans
@@ -83,7 +83,7 @@ If you use a custom acquisition plan by implementing the :class:`blop.protocols.
         dofs=[dof],
         acquisition_plan=custom_acquire,
         objectives=[Objective(name="result", minimize=False)],
-        evaluation=lambda uid, suggestions: [{"result": 0.1, "_id": 0}],
+        evaluation_function=lambda uid, suggestions: [{"result": 0.1, "_id": 0}],
     )
 
     RE(agent.optimize())
diff --git a/docs/source/tutorials/simple-experiment.md b/docs/source/tutorials/simple-experiment.md
index 4f4df7ac..368e4e93 100644
--- a/docs/source/tutorials/simple-experiment.md
+++ b/docs/source/tutorials/simple-experiment.md
@@ -151,7 +151,7 @@ agent = Agent(
     sensors=sensors,
     dofs=dofs,
     objectives=objectives,
-    evaluation=Himmelblau2DEvaluation(tiled_client=tiled_client),
+    evaluation_function=Himmelblau2DEvaluation(tiled_client=tiled_client),
     name="simple-experiment",
     description="A simple experiment optimizing the Himmelblau function",
 )
diff --git a/docs/source/tutorials/xrt-kb-mirrors.md b/docs/source/tutorials/xrt-kb-mirrors.md
index 871693ad..61378e87 100644
--- a/docs/source/tutorials/xrt-kb-mirrors.md
+++ b/docs/source/tutorials/xrt-kb-mirrors.md
@@ -170,7 +170,7 @@ agent = Agent(
     sensors=[beamline.det],
     dofs=dofs,
     objectives=objectives,
-    evaluation=DetectorEvaluation(tiled_client),
+    evaluation_function=DetectorEvaluation(tiled_client),
     name="xrt-blop-demo",
     description="A demo of the Blop agent with XRT simulated beamline",
     experiment_type="demo",
diff --git a/docs/wip/qserver-experiment.md b/docs/wip/qserver-experiment.md
index 6d62416c..80f0a230 100644
--- a/docs/wip/qserver-experiment.md
+++ b/docs/wip/qserver-experiment.md
@@ -303,7 +303,7 @@ agent = QueueserverAgent(
     sensors=sensors, # The list of sensors to read from
     dofs=dofs, # The list of DOFs to search over
     objectives=objectives, # The list of objectives to be optimized
-    evaluation= DetectorEvaluation(tiled_client), # The function to create objective function values
+    evaluation_function=DetectorEvaluation(tiled_client), # The function to create objective function values
     acquisition_plan= "acquire", # The name of the plan in the Queueserver environment
     Queueserver_control_addr="tcp://localhost:60615",
     Queueserver_info_addr="tcp://localhost:60625",
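Taken together, the documentation changes above are a mechanical rename of the ``evaluation`` keyword to ``evaluation_function`` at every ``Agent`` construction site. A minimal sketch of the new call signature (module paths and the ``RangeDOF`` arguments are inferred from the source and test changes below; the evaluation function body is illustrative only):

    from blop.ax.agent import Agent
    from blop.ax.dof import RangeDOF
    from blop.ax.objective import Objective


    def evaluation_function(uid: str, suggestions: list[dict]) -> list[dict]:
        # Produce one outcome dict per suggestion from the acquired run.
        return [{"result": 0.1} for _ in suggestions]


    agent = Agent(
        sensors=[],
        dofs=[RangeDOF(name="x1", bounds=(0, 10), parameter_type="float")],
        objectives=[Objective(name="result", minimize=False)],
        evaluation_function=evaluation_function,  # formerly ``evaluation=``
    )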
diff --git a/src/blop/ax/agent.py b/src/blop/ax/agent.py
index 3e491e84..687e382e 100644
--- a/src/blop/ax/agent.py
+++ b/src/blop/ax/agent.py
@@ -8,7 +8,7 @@ from bluesky.utils import MsgGenerator
 
 from ..plans import acquire_baseline, optimize
-from ..protocols import AcquisitionPlan, EvaluationFunction, OptimizationProblem, Sensor
+from ..protocols import AcquisitionPlan, Actuator, EvaluationFunction, OptimizationProblem, Sensor
 from .dof import DOF, DOFConstraint
 from .objective import Objective, OutcomeConstraint, to_ax_objective_str
 from .optimizer import AxOptimizer
@@ -33,7 +33,7 @@ class Agent:
         The degrees of freedom that the agent can control, which determine the search space.
     objectives : Sequence[Objective]
         The objectives which the agent will try to optimize.
-    evaluation : EvaluationFunction
+    evaluation_function : EvaluationFunction
         The function to evaluate acquired data and produce outcomes.
     acquisition_plan : AcquisitionPlan | None, optional
         The acquisition plan to use for acquiring data from the beamline. If not provided,
@@ -42,6 +42,8 @@ class Agent:
         Constraints on DOFs to refine the search space.
     outcome_constraints : Sequence[OutcomeConstraint] | None, optional
         Constraints on outcomes to be satisfied during optimization.
+    checkpoint_path : str | None, optional
+        The path to the checkpoint file to save the optimizer's state to.
     **kwargs : Any
        Additional keyword arguments to configure the Ax experiment.
 
@@ -67,42 +69,75 @@
     def __init__(
         self,
         sensors: Sequence[Sensor],
         dofs: Sequence[DOF],
         objectives: Sequence[Objective],
-        evaluation: EvaluationFunction,
+        evaluation_function: EvaluationFunction,
         acquisition_plan: AcquisitionPlan | None = None,
         dof_constraints: Sequence[DOFConstraint] | None = None,
         outcome_constraints: Sequence[OutcomeConstraint] | None = None,
+        checkpoint_path: str | None = None,
         **kwargs: Any,
     ):
         self._sensors = sensors
-        self._dofs = {dof.parameter_name: dof for dof in dofs}
-        self._objectives = {obj.name: obj for obj in objectives}
-        self._evaluation_function = evaluation
+        self._actuators = [dof.actuator for dof in dofs if dof.actuator is not None]
+        self._evaluation_function = evaluation_function
         self._acquisition_plan = acquisition_plan
-        self._dof_constraints = dof_constraints
-        self._outcome_constraints = outcome_constraints
         self._optimizer = AxOptimizer(
             parameters=[dof.to_ax_parameter_config() for dof in dofs],
             objective=to_ax_objective_str(objectives),
-            parameter_constraints=[constraint.ax_constraint for constraint in self._dof_constraints]
-            if self._dof_constraints
-            else None,
-            outcome_constraints=[constraint.ax_constraint for constraint in self._outcome_constraints]
-            if self._outcome_constraints
+            parameter_constraints=[constraint.ax_constraint for constraint in dof_constraints] if dof_constraints else None,
+            outcome_constraints=[constraint.ax_constraint for constraint in outcome_constraints]
+            if outcome_constraints
             else None,
+            checkpoint_path=checkpoint_path,
             **kwargs,
         )
 
+    @classmethod
+    def from_checkpoint(
+        cls,
+        checkpoint_path: str,
+        actuators: Sequence[Actuator],
+        sensors: Sequence[Sensor],
+        evaluation_function: EvaluationFunction,
+        acquisition_plan: AcquisitionPlan | None = None,
+    ) -> "Agent":
+        """
+        Load an agent from the optimizer's checkpoint file.
+
+        .. note::
+
+            Only the optimizer state is saved during a checkpoint, so we cannot reliably validate
+            the remaining state against the optimizer configuration.
+
+        Parameters
+        ----------
+        checkpoint_path : str
+            The checkpoint path to load the agent from.
+        actuators : Sequence[Actuator]
+            Objects that can be moved to control the beamline using the Bluesky RunEngine.
+            A subset of the actuators' names must match the names of suggested parameterizations.
+        sensors : Sequence[Sensor]
+            Objects that can be read to acquire data from the beamline using the Bluesky RunEngine.
+        evaluation_function : EvaluationFunction
+            A callable to evaluate data from a Bluesky run and produce outcomes.
+        acquisition_plan : AcquisitionPlan, optional
+            A Bluesky plan to acquire data from the beamline. If not provided, a default plan will be used.
+ """ + instance = object.__new__(cls) + instance._optimizer = AxOptimizer.from_checkpoint(checkpoint_path) + instance._actuators = actuators + instance._sensors = sensors + instance._evaluation_function = evaluation_function + instance._acquisition_plan = acquisition_plan + + return instance + @property def sensors(self) -> Sequence[Sensor]: return self._sensors @property - def dofs(self) -> Sequence[DOF]: - return list(self._dofs.values()) - - @property - def objectives(self) -> Sequence[Objective]: - return list(self._objectives.values()) + def actuators(self) -> Sequence[Actuator]: + return self._actuators @property def evaluation_function(self) -> EvaluationFunction: @@ -112,18 +147,14 @@ def evaluation_function(self) -> EvaluationFunction: def acquisition_plan(self) -> AcquisitionPlan | None: return self._acquisition_plan - @property - def dof_constraints(self) -> Sequence[DOFConstraint] | None: - return self._dof_constraints - - @property - def outcome_constraints(self) -> Sequence[OutcomeConstraint] | None: - return self._outcome_constraints - @property def ax_client(self) -> Client: return self._optimizer.ax_client + @property + def checkpoint_path(self) -> str | None: + return self._optimizer.checkpoint_path + def to_optimization_problem(self) -> OptimizationProblem: """ Construct an optimization problem from the agent. @@ -144,7 +175,7 @@ def to_optimization_problem(self) -> OptimizationProblem: """ return OptimizationProblem( optimizer=self._optimizer, - actuators=[dof.actuator for dof in self.dofs if dof.actuator is not None], + actuators=self.actuators, sensors=self.sensors, evaluation_function=self.evaluation_function, acquisition_plan=self.acquisition_plan, @@ -299,3 +330,9 @@ def plot_objective( *args, **kwargs, ) + + def checkpoint(self) -> None: + """ + Save the agent's state to a JSON file. + """ + self._optimizer.checkpoint() diff --git a/src/blop/ax/optimizer.py b/src/blop/ax/optimizer.py index f44ff8eb..9b473246 100644 --- a/src/blop/ax/optimizer.py +++ b/src/blop/ax/optimizer.py @@ -3,10 +3,10 @@ from ax import ChoiceParameterConfig, Client, RangeParameterConfig -from ..protocols import ID_KEY, Optimizer +from ..protocols import ID_KEY, Checkpointable, Optimizer -class AxOptimizer(Optimizer): +class AxOptimizer(Optimizer, Checkpointable): """ An optimizer that uses Ax as the backend for optimization and experiment tracking. @@ -22,6 +22,8 @@ class AxOptimizer(Optimizer): The parameter constraints to apply to the optimization. outcome_constraints : Sequence[str] | None, optional The outcome constraints to apply to the optimization. + checkpoint_path : str | None, optional + The path to the checkpoint file to save the optimizer's state to. client_kwargs : dict[str, Any] | None, optional Additional keyword arguments to configure the Ax client. **kwargs : Any @@ -39,10 +41,12 @@ def __init__( objective: str, parameter_constraints: Sequence[str] | None = None, outcome_constraints: Sequence[str] | None = None, + checkpoint_path: str | None = None, client_kwargs: dict[str, Any] | None = None, **kwargs: Any, ): self._parameter_names = [parameter.name for parameter in parameters] + self._checkpoint_path = checkpoint_path self._client = Client(**(client_kwargs or {})) self._client.configure_experiment( parameters=parameters, @@ -54,6 +58,33 @@ def __init__( outcome_constraints=outcome_constraints, ) + @classmethod + def from_checkpoint(cls, checkpoint_path: str) -> "AxOptimizer": + """ + Load an optimizer from a checkpoint file. 
diff --git a/src/blop/ax/optimizer.py b/src/blop/ax/optimizer.py
index f44ff8eb..9b473246 100644
--- a/src/blop/ax/optimizer.py
+++ b/src/blop/ax/optimizer.py
@@ -3,10 +3,10 @@
 
 from ax import ChoiceParameterConfig, Client, RangeParameterConfig
 
-from ..protocols import ID_KEY, Optimizer
+from ..protocols import ID_KEY, Checkpointable, Optimizer
 
 
-class AxOptimizer(Optimizer):
+class AxOptimizer(Optimizer, Checkpointable):
     """
     An optimizer that uses Ax as the backend for optimization and experiment tracking.
@@ -22,6 +22,8 @@ class AxOptimizer(Optimizer):
         The parameter constraints to apply to the optimization.
     outcome_constraints : Sequence[str] | None, optional
         The outcome constraints to apply to the optimization.
+    checkpoint_path : str | None, optional
+        The path to the checkpoint file to save the optimizer's state to.
     client_kwargs : dict[str, Any] | None, optional
         Additional keyword arguments to configure the Ax client.
     **kwargs : Any
@@ -39,10 +41,12 @@ def __init__(
         self,
         objective: str,
         parameter_constraints: Sequence[str] | None = None,
         outcome_constraints: Sequence[str] | None = None,
+        checkpoint_path: str | None = None,
         client_kwargs: dict[str, Any] | None = None,
         **kwargs: Any,
     ):
         self._parameter_names = [parameter.name for parameter in parameters]
+        self._checkpoint_path = checkpoint_path
         self._client = Client(**(client_kwargs or {}))
         self._client.configure_experiment(
             parameters=parameters,
@@ -54,6 +58,33 @@
             outcome_constraints=outcome_constraints,
         )
 
+    @classmethod
+    def from_checkpoint(cls, checkpoint_path: str) -> "AxOptimizer":
+        """
+        Load an optimizer from a checkpoint file.
+
+        Parameters
+        ----------
+        checkpoint_path : str
+            The path to the checkpoint file to load the optimizer from.
+
+        Returns
+        -------
+        AxOptimizer
+            An instance of the optimizer class, initialized from the checkpoint.
+        """
+        client = Client.load_from_json_file(checkpoint_path)
+        instance = object.__new__(cls)
+        instance._parameter_names = list(client._experiment.parameters.keys())
+        instance._checkpoint_path = checkpoint_path
+        instance._client = client
+
+        return instance
+
+    @property
+    def checkpoint_path(self) -> str | None:
+        return self._checkpoint_path
+
     @property
     def ax_client(self) -> Client:
         return self._client
@@ -126,3 +157,11 @@ def ingest(self, points: list[dict]) -> None:
         elif trial_idx == "baseline":
             trial_idx = self._client.attach_baseline(parameters=parameters)
         self._client.complete_trial(trial_index=trial_idx, raw_data=outcomes)
+
+    def checkpoint(self) -> None:
+        """
+        Save the optimizer's state to a JSON file.
+        """
+        if not self.checkpoint_path:
+            raise ValueError("Checkpoint path is not set. Please set a checkpoint path when initializing the optimizer.")
+        self._client.save_to_json_file(self.checkpoint_path)
diff --git a/src/blop/ax/qserver_agent.py b/src/blop/ax/qserver_agent.py
index 1714ee04..694b036a 100644
--- a/src/blop/ax/qserver_agent.py
+++ b/src/blop/ax/qserver_agent.py
@@ -87,7 +87,7 @@ class BlopQserverAgent(BlopAxAgent):
     The degrees of freedom that the agent can control, which determine the search space.
     objectives : Sequence[Objective]
         The objectives which the agent will try to optimize.
-    evaluation : EvaluationFunction
+    evaluation_function : EvaluationFunction
         The function to evaluate acquired data and produce outcomes.
     acquisition_plan : str, optional
         The name of the plan on the queueserver
@@ -121,7 +121,7 @@ def __init__(
         sensors: Sequence[Sensor],
         dofs: Sequence[DOF],
         objectives: Sequence[Objective],
-        evaluation: EvaluationFunction = None,
+        evaluation_function: EvaluationFunction | None = None,
         acquisition_plan: str = "acquire",
         dof_constraints: Sequence[DOFConstraint] = None,
         qserver_control_addr: str = "tcp://localhost:60615",
@@ -134,7 +134,7 @@ def __init__(
             sensors=sensors,
             dofs=dofs,
             objectives=objectives,
-            evaluation=evaluation,
+            evaluation_function=evaluation_function,
             acquisition_plan=acquisition_plan,
             dof_constraints=dof_constraints,
             **kwargs,
diff --git a/src/blop/plans/plans.py b/src/blop/plans/plans.py
index 08f5de88..a1efa157 100644
--- a/src/blop/plans/plans.py
+++ b/src/blop/plans/plans.py
@@ -8,7 +8,7 @@
 from bluesky.protocols import Readable, Reading
 from bluesky.utils import MsgGenerator, plan
 
-from ..protocols import ID_KEY, Actuator, OptimizationProblem, Sensor
+from ..protocols import ID_KEY, Actuator, Checkpointable, OptimizationProblem, Sensor
 from .utils import route_suggestions
 
 logger = logging.getLogger(__name__)
@@ -126,6 +126,7 @@ def optimize(
     optimization_problem: OptimizationProblem,
     iterations: int = 1,
     n_points: int = 1,
+    checkpoint_interval: int | None = None,
     *args: Any,
     **kwargs: Any,
 ) -> MsgGenerator[None]:
@@ -140,10 +141,30 @@
         The number of optimization iterations to run.
     n_points : int, optional
         The number of points to suggest per iteration.
+    checkpoint_interval : int | None, optional
+        The number of iterations between optimizer checkpoints. If ``None`` or ``0``,
+        checkpoints will not be saved. The optimizer must implement the
+        :class:`blop.protocols.Checkpointable` protocol.
+    *args : Any
+        Additional positional arguments to pass to the :func:`optimize_step` plan.
+    **kwargs : Any
+        Additional keyword arguments to pass to the :func:`optimize_step` plan.
+
+    See Also
+    --------
+    blop.protocols.OptimizationProblem : The problem to solve.
+    blop.protocols.Checkpointable : The protocol for checkpointable objects.
+    optimize_step : The plan to execute a single step of the optimization.
     """
-    for _ in range(iterations):
+    for i in range(iterations):
         yield from optimize_step(optimization_problem, n_points, *args, **kwargs)
+        if checkpoint_interval and (i + 1) % checkpoint_interval == 0:
+            if not isinstance(optimization_problem.optimizer, Checkpointable):
+                raise ValueError(
+                    "The optimizer is not checkpointable. Please review your optimizer configuration or implementation."
+                )
+            optimization_problem.optimizer.checkpoint()
 
 
 @plan
diff --git a/src/blop/protocols.py b/src/blop/protocols.py
index 3646a1a6..66938492 100644
--- a/src/blop/protocols.py
+++ b/src/blop/protocols.py
@@ -10,6 +10,22 @@
 Sensor = Readable | EventCollectable | EventPageCollectable
 
 
+@runtime_checkable
+class Checkpointable(Protocol):
+    """
+    A protocol for objects that can write state to persistent storage.
+
+    Implementers configure storage at construction time (e.g., a file path or database URI).
+    The checkpoint method then saves or updates state at that pre-configured location.
+    """
+
+    def checkpoint(self) -> None:
+        """
+        Write the object's state to persistent storage.
+        """
+        ...
+
+
 @runtime_checkable
 class Optimizer(Protocol):
     """
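Because ``Checkpointable`` is marked ``@runtime_checkable``, the ``isinstance`` guard in the ``optimize`` plan above accepts any object that structurally provides a zero-argument ``checkpoint()`` method, with no inheritance required. A toy sketch of a third-party optimizer opting in (the pickle-based storage and the class name are illustrative, not part of this patch):

    import pickle

    from blop.protocols import Checkpointable


    class PickleCheckpointOptimizer:
        """Illustrative only: state persisted to a path fixed at construction time."""

        def __init__(self, path: str):
            self._path = path  # storage location configured up front
            self._state: dict = {}

        def checkpoint(self) -> None:
            # Save to the pre-configured location, per the protocol's contract.
            with open(self._path, "wb") as f:
                pickle.dump(self._state, f)


    assert isinstance(PickleCheckpointOptimizer("/tmp/state.pkl"), Checkpointable)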
"test_objective": 0.2}]) + agent.ax_client.configure_generation_strategy() + agent.checkpoint() + assert checkpoint_path.exists() + + agent = Agent.from_checkpoint( + str(checkpoint_path), + sensors=[readable], + actuators=[], + evaluation_function=mock_evaluation_function, + acquisition_plan=mock_acquisition_plan, + ) + assert len(agent.ax_client.summarize()) == 1 + + def test_agent_to_optimization_problem(mock_evaluation_function): """Test that the agent can be converted to an optimization problem.""" movable1 = MovableSignal(name="test_movable1") @@ -62,7 +89,7 @@ def test_agent_to_optimization_problem(mock_evaluation_function): sensors=[], dofs=[dof1, dof2], objectives=[objective], - evaluation=mock_evaluation_function, + evaluation_function=mock_evaluation_function, dof_constraints=[constraint], ) optimization_problem = agent.to_optimization_problem() @@ -79,7 +106,7 @@ def test_agent_suggest(mock_evaluation_function): dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") objective = Objective(name="test_objective", minimize=False) - agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation=mock_evaluation_function) + agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) parameterizations = agent.suggest(1) assert len(parameterizations) == 1 @@ -96,7 +123,7 @@ def test_agent_suggest_multiple(mock_evaluation_function): dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") objective = Objective(name="test_objective", minimize=False) - agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation=mock_evaluation_function) + agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) parameterizations = agent.suggest(5) assert len(parameterizations) == 5 @@ -114,7 +141,7 @@ def test_agent_ingest(mock_evaluation_function): dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") objective = Objective(name="test_objective", minimize=False) - agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation=mock_evaluation_function) + agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) agent.ingest([{"test_movable1": 0.1, "test_movable2": 0.2, "test_objective": 0.3}]) @@ -132,7 +159,7 @@ def test_agent_ingest_multiple(mock_evaluation_function): dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") objective = Objective(name="test_objective", minimize=False) - agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation=mock_evaluation_function) + agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function) agent.ingest( [ @@ -154,7 +181,7 @@ def test_ingest_baseline(mock_evaluation_function): dof1 = RangeDOF(actuator=movable1, bounds=(0, 10), parameter_type="float") dof2 = RangeDOF(actuator=movable2, bounds=(0, 10), parameter_type="float") objective = Objective(name="test_objective", minimize=False) - agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation=mock_evaluation_function) + agent = 
+    agent = Agent(sensors=[], dofs=[dof1, dof2], objectives=[objective], evaluation_function=mock_evaluation_function)
 
     agent.ingest([{"test_movable1": 0.1, "test_movable2": 0.2, "test_objective": 0.3, "_id": "baseline"}])
diff --git a/src/blop/tests/unit/ax/test_optimizer.py b/src/blop/tests/unit/ax/test_optimizer.py
index 729d1ca7..8ec67772 100644
--- a/src/blop/tests/unit/ax/test_optimizer.py
+++ b/src/blop/tests/unit/ax/test_optimizer.py
@@ -117,3 +117,46 @@ def test_ax_optimizer_suggest_ingest():
     assert len(summary_df) == 2
     assert np.all(summary_df["y1"].values == [1.0, 3.0])
     assert np.all(summary_df["y2"].values == [2.0, 4.0])
+
+
+def test_ax_optimizer_checkpoint(tmp_path):
+    checkpoint_path = tmp_path / "checkpoint.json"
+
+    # Save to checkpoint
+    optimizer = AxOptimizer(
+        parameters=[
+            RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
+        ],
+        objective="y1",
+        checkpoint_path=str(checkpoint_path),
+    )
+    suggestions = optimizer.suggest(num_points=2)
+    outcomes = [
+        {"_id": suggestions[0]["_id"], "y1": 1.0},
+        {"_id": suggestions[1]["_id"], "y1": 3.0},
+    ]
+    optimizer.ingest(outcomes)
+
+    assert not checkpoint_path.exists()
+    optimizer.checkpoint()
+    assert checkpoint_path.exists()
+
+    # Load from checkpoint
+    optimizer = AxOptimizer.from_checkpoint(str(checkpoint_path))
+    summary_df = optimizer.ax_client.summarize()
+    assert "x1" in summary_df.columns
+    assert "y1" in summary_df.columns
+    assert len(summary_df) == 2
+    assert optimizer.checkpoint_path == str(checkpoint_path)
+
+
+def test_ax_optimizer_checkpoint_no_path():
+    optimizer = AxOptimizer(
+        parameters=[
+            RangeParameterConfig(name="x1", bounds=(-5.0, 5.0), parameter_type="float"),
+        ],
+        objective="y1",
+    )
+
+    with pytest.raises(ValueError):
+        optimizer.checkpoint()
diff --git a/src/blop/tests/unit/test_plans.py b/src/blop/tests/unit/test_plans.py
index 8df55427..9d30b23b 100644
--- a/src/blop/tests/unit/test_plans.py
+++ b/src/blop/tests/unit/test_plans.py
@@ -5,11 +5,14 @@
 from bluesky.run_engine import RunEngine
 
 from blop.plans import acquire_baseline, acquire_with_background, default_acquire, optimize, optimize_step
-from blop.protocols import AcquisitionPlan, EvaluationFunction, OptimizationProblem, Optimizer
+from blop.protocols import AcquisitionPlan, Checkpointable, EvaluationFunction, OptimizationProblem, Optimizer
 
 from .conftest import MovableSignal, ReadableSignal
 
 
+class CheckpointableOptimizer(Optimizer, Checkpointable): ...
+
+
 @pytest.fixture(scope="function")
 def RE():
     return RunEngine({})
@@ -105,6 +108,40 @@ def test_optimize_complex_case(RE):
     assert evaluation_function.call_count == 2
 
 
+@pytest.mark.parametrize("checkpoint_interval", [0, 1, 2, 3])
+def test_optimize_with_checkpoint_interval(RE, checkpoint_interval):
+    optimizer = MagicMock(spec=CheckpointableOptimizer)
+    optimizer.suggest.return_value = [{"x1": 0.0, "_id": 0}]
+    evaluation_function = MagicMock(spec=EvaluationFunction, return_value={"objective": 0.0})
+    optimization_problem = OptimizationProblem(
+        optimizer=optimizer,
+        actuators=[MovableSignal("x1", initial_value=-1.0)],
+        sensors=[ReadableSignal("objective")],
+        evaluation_function=evaluation_function,
+    )
+
+    with patch.object(optimizer, "checkpoint", wraps=optimizer.checkpoint) as mock_checkpoint:
+        RE(optimize(optimization_problem, iterations=5, n_points=2, checkpoint_interval=checkpoint_interval))
+        if checkpoint_interval == 0:
+            assert mock_checkpoint.call_count == 0
+        else:
+            assert mock_checkpoint.call_count == 5 // checkpoint_interval
+
+
+def test_optimize_with_non_checkpointable_optimizer(RE):
+    optimizer = MagicMock(spec=Optimizer)
+    optimizer.suggest.return_value = [{"x1": 0.0, "_id": 0}]
+    evaluation_function = MagicMock(spec=EvaluationFunction, return_value={"objective": 0.0})
+    optimization_problem = OptimizationProblem(
+        optimizer=optimizer,
+        actuators=[MovableSignal("x1", initial_value=-1.0)],
+        sensors=[ReadableSignal("objective")],
+        evaluation_function=evaluation_function,
+    )
+    with pytest.raises(ValueError):
+        RE(optimize(optimization_problem, iterations=5, n_points=2, checkpoint_interval=1))
+
+
 def test_optimize_step_default(RE):
     optimizer = MagicMock(spec=Optimizer)
     optimizer.suggest.return_value = [{"x1": 0.0, "_id": 0}]
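End to end, periodic checkpointing is driven from the plan rather than from user code. A usage sketch (``agent`` is the instance from the earlier examples; the iteration counts are arbitrary):

    from bluesky.run_engine import RunEngine

    from blop.plans import optimize

    RE = RunEngine({})
    problem = agent.to_optimization_problem()
    # Persist the optimizer state after every 5th iteration; the plan raises
    # ValueError mid-run if the optimizer is not Checkpointable.
    RE(optimize(problem, iterations=20, n_points=1, checkpoint_interval=5))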