diff --git a/bpod_rig/calibration/__init__.py b/bpod_rig/calibration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bpod_rig/calibration/liquid/__init__.py b/bpod_rig/calibration/liquid/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py new file mode 100644 index 0000000..0f3d412 --- /dev/null +++ b/bpod_rig/calibration/liquid/models.py @@ -0,0 +1,243 @@ +"""Data models for liquid calibration.""" + +import datetime +import logging + +import numpy as np +from pydantic import BaseModel, Field, ConfigDict, field_serializer + +logger = logging.getLogger(__name__) + + +class ValveData(BaseModel): + """Data model for the liquid calibration data for a single valve. + + Valves are calibrated by measuring the amount of liquid dispensed for a specific + duration. + """ + + name: str = Field(alias="ValveName", description="Name of valve (1-indexed)") + lastdatemodified: datetime.datetime | str = Field( + serialization_alias="LastDateModified", + validation_alias="LastDateModified", + default="", + description="Modification datetime in ISO format or empty string if not set.", + ) + coeffs: list[float] = Field( + alias="Coeffs", + default=[], + description="Coefficients for polynomial fitting of durations vs amounts.", + ) + durations: list[float | int] = Field( + alias="Durations", + default=[], + description="Durations in ms for each dispense, corresponding to the amounts.", + ) + amounts: list[float | int] = Field( + alias="Amounts", + default=[], + description="Amounts in uL for each dispense, corresponding to the durations.", + ) + + # MATLAB .json compatibility is maintained by using the alias to load/save + model_config = ConfigDict( + serialize_by_alias=True, + validate_by_name=True, + validate_by_alias=True, + ) + + def get_valve_time(self, amount: float) -> float: + """Get the valve time for a given amount of liquid. + + Parameters + ---------- + amount : float + The amount of liquid in mL. + + Returns + ------- + float + The duration in ms for the given amount. + """ + if len(self.coeffs) == 0: + raise ValueError("Coefficients are not set. Please add measurements first.") + duration = np.polyval(self.coeffs, amount) + logger.debug( + "%s calculated valve time: %s ms for amount: %s mL", + self.name, + duration, + amount, + ) + return float(duration) + + def add_measurement(self, duration: float, amount: float) -> None: + """Add a measurement to the valve data. + + Parameters + ---------- + duration : float + The duration of the dispense in ms. + amount : float + The amount of liquid dispensed mL. + """ + self.amounts.append(amount) + self.durations.append(duration) + logger.debug("Added measurement: amount = %s, duration = %s", amount, duration) + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def remove_measurement(self, value: float, method: str = "index") -> None: + """Remove a measurement from the valve data. + + Parameters + ---------- + value : int | float + The index or duration (ms) to remove. + method : str, optional (default = 'index') + How to use value to find the measurement to remove. + - 'duration' to remove by duration value (assuming unique duration). + - 'index' to remove by index in `.durations` + """ + if method == "duration": + if value in self.durations: + # Throw error if value is in durations multiple times + if self.durations.count(value) > 1: + raise ValueError( + "Duration value is present multiple times. " + "Index must be specified." + ) + index = self.durations.index(value) + del self.amounts[index] + del self.durations[index] + logger.debug("Removed measurement by duration: %s", value) + else: + raise ValueError("Duration value not found in durations list.") + elif method == "index": + if 0 <= value < len(self.amounts): + del self.amounts[value] + del self.durations[value] + logger.debug("Removed measurement by index: %s", value) + else: + raise IndexError("Index out of range for amounts and durations lists.") + else: + raise ValueError(f"Unknown method for removing measurement: {method}") + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def _update_coeffs(self) -> None: + """Update the polynomial coefficients based on current measurements.""" + if len(self.amounts) < 2: + # potential feature: 1 value assumes intercept at 0 + self.coeffs = [] + return + # If only two measurements, use linear fit, otherwise use 2 degree polynomial + order = 1 if len(self.amounts) == 2 else 2 + self.coeffs = np.polyfit(self.amounts, self.durations, order).tolist() + + @field_serializer("lastdatemodified") + def serialize_datetime(self, dt: datetime.datetime, _info): + if isinstance(dt, str): # if uncalibrated it is an empty string + return dt + return dt.isoformat(timespec="seconds") + + +class ValveManagerMetaData(BaseModel): + modification_datetime: datetime.datetime = Field( + default_factory=datetime.datetime.now, + description="Time of when the valve data was last saved to json file.", + ) + + COM: str = Field( + default="", + description="The COM port of the last state machine to modify the valves.", + ) + + id: str = Field( + default="", + description="Unique identifier of the machine that controls the valve " + "(i.e. Bpod State Machine or Port Array Module.", + ) + + @field_serializer("modification_datetime") + def serialize_datetime(self, modification_datetime: datetime.datetime, _info): + return modification_datetime.isoformat(timespec="seconds") + + +class ValveDataManager(BaseModel): + """Parent of multiple ValveData objects.""" + + metadata: ValveManagerMetaData = Field( + default_factory=ValveManagerMetaData, + description="Metadata regarding the set of valves.", + ) + valve_datas: list[ValveData] = Field( + alias="ValveDatas", default=[], description="Array of valve data objects." + ) + + model_config = ConfigDict(serialize_by_alias=True, validate_by_name=True) + + def get_valve(self, valvename: str) -> ValveData: + """Get a valve by name. + + Parameters + ---------- + valvename : str + The name of the valve to retrieve. + + Returns + ------- + ValveData + """ + # Check if the valve exists in the ValveDatas list + if valvename in self.valve_names: + # check if there are two valves with same name, raise error if so + if sum(name == valvename for name in self.valve_names) > 1: + raise KeyError(f"Multiple valves with name '{valvename}' found.") + return next(valve for valve in self.valve_datas if valve.name == valvename) + + raise KeyError(f"Valve '{valvename}' not found in valve manager.") + + def create_valve(self, valvename: str) -> None: + """Create a new valve with the given name.""" + if valvename in self.valve_names: + raise KeyError(f"Valve '{valvename}' already exists.") + self.valve_datas.append(ValveData(ValveName=valvename)) + logger.debug(f"Created new valve: {valvename}") + + @property + def n_valves(self) -> int: + """The number of valves in the manager.""" + return len(self.valve_datas) + + @property + def valve_names(self) -> list[str]: + """List of valve names.""" + return [valve.name for valve in self.valve_datas] + + def to_json(self, machineid: str | None = None) -> str: + """Convert the ValveDataManager to JSON string. + + Parameters + ---------- + machineid : str, optional + The COM port or machine ID to set in the metadata. + + Returns + ------- + str + JSON text (2 space indented) for writing to file. + """ + if machineid is not None: + if not isinstance(machineid, str): + raise ValueError("machineid must be a string.") + + # check if the COM is changing + if (self.metadata.COM != "") & (self.metadata.COM != machineid): # noqa: SIM300 + logger.warning( + "COM port of liquid calibration file is changing from %s to %s", + self.metadata.COM, + machineid, + ) + self.metadata.COM = machineid + self.metadata.modification_datetime = datetime.datetime.now() + return self.model_dump_json(indent=2) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py new file mode 100644 index 0000000..04371ea --- /dev/null +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -0,0 +1,329 @@ +import logging + +import bpod_core.bpod +from pydantic import Field +from bpod_core.fsm import StateMachine + +from bpod_rig.calibration.liquid.models import ValveDataManager, ValveData + +logger = logging.getLogger(__name__) + + +class PendingValve(ValveData): + pending_durations: list[float] = Field( + default=[], + description="List of durations that are pending calibration measurements.", + ) + + valve_controller: str = Field( + default="unknown", + description="The controlling hardware of the valve " + "(e.g., 'StateMachine', 'PortArray').", + ) + + def __init__(self, **data): + super().__init__(**data) + if self.name.startswith("PA"): + self.valve_controller = "PortArray" + elif self.name.startswith("Valve"): + self.valve_controller = "StateMachine" + else: + raise ValueError("Valve name not recognised for controller assignment.") + + def action(self, actiontype: str) -> dict[str, str | list[str | int] | bool]: + """Define the output action for this valve + + Parameters + ---------- + actiontype : str + What kind of action the valve is performing: "open" or "close" + + Returns + ------- + dict + Action definition for valve that can be fed to + StateMachine.add_state(**, output_actions=) + """ + if actiontype not in ["open", "close"]: + raise ValueError("Action type not recognised.") + if self.valve_controller == "StateMachine": + if actiontype == "open": + # matlab code + # ValvePhysicalAddress = 2.^(0:7); + # ValveAddress = ValvePhysicalAddress(valveID) + # valve_id = int(self.name[5:]) + # valve_address = 2 ** (valve_id - 1) + # return {'ValveState': valve_address} + return {self.name: True} + + if actiontype == "close": + raise ValueError("StateMachine valves do not have a close action.") + raise ValueError("Action type not recognised.") + if self.valve_controller == "PortArray": + target_module = self.name[0:3] + portnumber = self.name[4] + if actiontype == "open": + # TODO: serial message byte parsing hasn't been checked + action_serial = ["V", portnumber, 1] + return {target_module: action_serial} + if actiontype == "close": + action_serial = ["V", portnumber, 0] + return {target_module: action_serial} + raise ValueError("Action type not recognised") + raise ValueError("Valve controller not recognised.") + + @property + def is_pending(self) -> bool: + return len(self.pending_durations) > 0 + + +class PendingMeasurementsManager: + """Manage valves and their pending values. + + Maintains list of pending valves (which store pending values), and can construct a + state machine to test the pending durations. + """ + + _valvemanager: ValveDataManager + """The real data used by Bpod.""" + valves: list[PendingValve] + """All of the pending valves that are being managed""" + n_pulses: int = 100 + """number of pulses to delivery across the test""" + pulse_interval: float = 0.2 + """time (s) between two different valves pulsing""" + pulse_set_pause: float = 0.5 + """time (s) between each set of pulses""" + + def __init__(self, valvemanager: ValveDataManager): + self._valvemanager = valvemanager + self.valves = [ + PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas + ] + + def get_valve(self, valvename: str) -> PendingValve: + """Retrieve object for a pending valve.""" + if valvename not in self._valvemanager.valve_names: + raise KeyError(f"Valvename '{valvename}' not found in manager.") + return next(valve for valve in self.valves if valve.name == valvename) + + def get_pending(self, valvename: str) -> list[float]: + """Return pending durations from a given valve.""" + return self.get_valve(valvename).pending_durations + + def add_pending(self, valvename: str, duration: float) -> None: + """Add a pending duration to the valve's list.""" + if duration in self.get_pending(valvename): + raise ValueError("Duration already pending.") + if duration in self._valvemanager.get_valve(valvename).durations: + raise ValueError("Duration already exists in valve calibration.") + + self.get_pending(valvename).append(duration) + + def remove_pending(self, valvename: str, duration: float) -> None: + """Remove a pending duration from the valve's pending record.""" + if duration not in self.get_pending(valvename): + raise ValueError("Duration for completed measurement is not pending.") + + self.get_pending(valvename).remove(duration) + + def complete_measurement( + self, valvename: str, duration: float, amount: float + ) -> None: + """Record an amount dispensed for a duration that was measured. + + Remove duration from pending list and adds completed measurement to the valve + manager. + + Parameters + ---------- + valvename : str + Name of the valve that has been tested and weighed. + duration : float + Duration (ms) of valve opening + amount : float + Quantity (mg or mL) released from the duration. + """ + if duration not in self.get_pending(valvename): + raise ValueError("Duration not found in existing pending values.") + self.remove_pending(valvename, duration) + self._valvemanager.get_valve(valvename).add_measurement(duration, amount) + logger.debug("Completed measurement for %s %s ms", valvename, duration) + + def build_test_set(self) -> dict[str, float]: + """Create dict of valves and the duration (ms) to test. + + Will take the first (oldest) duration in the valve's pending values. + """ + pending_valves = [valve for valve in self.valves if valve.is_pending] + test_set = {} + for valve in pending_valves: + test_set[valve.name] = valve.pending_durations[0] + return test_set + + def build_statemachine( + self, test_set: dict | None = None + ) -> tuple[StateMachine, dict[str, float]]: + """Build the state machine to test all pending valves. + + The state machine is constructed using object parameters: + - n_pulses: number of pulses to deliver. + - pulse_interval: Duration (s) between one valve's pulse end and next's start. + - pulse_set_pause: Duration (s) between a set of pulses. + + The amount of liquid dispensed from the state machine + For example, 100 pulses totaling 0.200g (200mg) is 0.200g / 100 = 2mg = 2mL + + Parameters + ---------- + test_set : dict | None, optional + Names of valves and durations to build state machine for. If not provided, + test all valves that have pending durations using their oldest duration. + Returns + ------- + tuple[StateMachine, dict[str, float]] + - The finite state machine. + - dict of which valves and what duration the state machine was built for. + """ + fsm = StateMachine() + if test_set is None: + test_set = self.build_test_set() + valve_names = list(test_set.keys()) + logger.debug("Building test set for %s", valve_names) + pending_valves = [valve for valve in self.valves if valve.is_pending] + n_pending = len(pending_valves) + if n_pending == 0: + raise AssertionError("No pending valves found.") + for x, valve_name in enumerate(valve_names): + pending_valve = self.get_valve(valve_name) + duration = test_set[valve_name] + if x < n_pending - 1: + next_valve = self.get_valve(valve_names[x + 1]) + else: + next_valve = self.pulse_set_pause + add_valve_states( + fsm, pending_valve, duration, next_valve, self.pulse_interval + ) + return fsm, test_set + + +def add_valve_states( + fsm: StateMachine, + pending_valve: PendingValve, + duration: float, + next_valve: PendingValve | float, + interval_delay_duration: float, +): + """Add the pending valve to a state machine. + + Parameters + ---------- + fsm : StateMachine + Existing state machine to add states for the valve to. + pending_valve : PendingValve + Valve to create states to use for calibration + duration : float + Duration (ms) of valve open. + next_valve : PendingValve | float + The next valve object in the cycle, or the float for the pause duration (s) + between pulse sets. + interval_delay_duration : float + """ + last_valve = isinstance(next_valve, float) + if last_valve: + pulse_set_pause_duration = next_valve + next_state = "PulseSetPause" + else: + if not isinstance(next_valve, PendingValve): + raise ValueError("Next valve must be a PendingValve or float.") + pulse_set_pause_duration = None + next_state = f"Pulse{next_valve.name}" + + # Add valve openings to the state machine + if pending_valve.valve_controller == "StateMachine": + # Valve is wired i.e. open while state is active + fsm.add_state( + name=f"Pulse{pending_valve.name}", + timer=duration / 1000, + transitions={"Tup": f"Delay{pending_valve.name}"}, + actions=pending_valve.action("open"), + comment="Open valve for duration of state, with closure at the end.", + ) + elif pending_valve.valve_controller == "PortArray": + # Valve is controlled by serial messaging i.e. discrete open/close messages + fsm.add_state( + name=f"Pulse{pending_valve.name}", + timer=duration / 1000, + transitions={"Tup": f"EndPulse{pending_valve.name}"}, + actions=pending_valve.action("open"), + comment="Send message to Port Array Module to open the valve.", + ) + fsm.add_state( + name=f"EndPulse{pending_valve.name}", + timer=0, + transitions={"Tup": f"Delay{pending_valve.name}"}, + actions=pending_valve.action("close"), + comment="Send message to Port Array module to close the valve.", + ) + else: + raise ValueError("Valve controller not recognised.") + + fsm.add_state( + name=f"Delay{pending_valve.name}", + timer=interval_delay_duration, + transitions={"Tup": next_state}, + actions={}, + comment="Pause before opening the next valve.", + ) + + if last_valve: + # If last valve in set, enter this state following the valve's Delay. + fsm.add_state( + name="PulseSetPause", + timer=pulse_set_pause_duration, + transitions={"Tup": "exit"}, + actions={}, + comment="Pause when the entire set of valves is completed.", + ) + + +def run_calibration( + bpodsystem: bpod_core.bpod.Bpod, + pending_manager: PendingMeasurementsManager, + verbose: bool = False, +) -> dict[str, float]: + """Run a calibration sequence using Bpod state machine. + + Parameters + ---------- + bpodsystem : bpod_core.bpod.Bpod + Active Bpod machine. + pending_manager : PendingMeasurementsManager + Object storing valve data + verbose : bool, optional + Print progress, default False. + + Returns + ------- + dict[str, float] + Dictionary of valve names and the duration (ms) of the pulse added to the + state machine. + """ + + # Build the state machine + fsm, test_set = pending_manager.build_statemachine() + if verbose: + print("Running liquid calibration:") + print(f"\t{pending_manager.n_pulses} pulses.") + print(f"\t{pending_manager.pulse_interval} between each pulse.") + print(f"\t{pending_manager.pulse_set_pause} between each pulse set.") + for valvename in test_set: + print(f"- {valvename}: duration {test_set[valvename]} ms") + + bpodsystem.send_state_machine(fsm) + logger.debug("Running calibration state machine.") + # Run the state machine + for _ in range(pending_manager.n_pulses): + bpodsystem.run_state_machine() + + return test_set diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py new file mode 100644 index 0000000..3a5b566 --- /dev/null +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -0,0 +1,62 @@ +"""Create the example liquid calibration JSON with 8 valves and dummy measurements.""" + +import datetime +from pathlib import Path + +from bpod_rig.examples import calibration as example_folder +from bpod_rig.calibration.liquid.models import ValveDataManager +from bpod_rig.calibration.liquid.utils import create_empty_valve_data_manager + + +def add_dummy_measurements(valvemanager: ValveDataManager) -> None: + """Add dummy measurements to the valves in manager. + + Modifies the data in place. + Uses 2000-01-01 as the origin date for the data. + + Examples + -------- + >>> dummy_liquid_manager = create_empty_valve_data_manager() + >>> add_dummy_measurements(dummy_liquid_manager) + """ + origin_date = datetime.datetime(2000, 1, 1) + valve1 = valvemanager.get_valve("Valve1") + valve1.add_measurement(22, 2) + valve1.add_measurement(66, 9.5) + valve1.add_measurement(44, 5) + valve1.add_measurement(57, 7.5) + valve1.add_measurement(34, 3.5) + valve1.lastdatemodified = origin_date + + valve3 = valvemanager.get_valve("Valve3") + valve3.add_measurement(22, 1.5) + valve3.add_measurement(46, 5.5) + valve3.add_measurement(59, 7.5) + valve3.add_measurement(35, 3.5) + valve3.add_measurement(66, 8.5) + valve3.lastdatemodified = origin_date + + valvemanager.metadata.modification_datetime = origin_date + + +def create_default_json() -> str: + """Produce a default valve JSON string with 8 valves and dummy measurements.""" + dummy = create_empty_valve_data_manager() + add_dummy_measurements(dummy) + jsontext = dummy.to_json(machineid=None) + # replace modification datetime with a fixed date for consistency + return jsontext.replace( + dummy.metadata.modification_datetime.isoformat(timespec="seconds"), + "2000-01-01T00:00:00", + ) + + +def main(): + """Create the example liquid calibration JSON in the example folder.""" + example_json = create_default_json() + example_path = Path(example_folder.__path__[0]) / "LiquidCalibration.json" + example_path.write_text(example_json) + + +if __name__ == "__main__": + main() diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py new file mode 100644 index 0000000..21bfe0f --- /dev/null +++ b/bpod_rig/calibration/liquid/utils.py @@ -0,0 +1,167 @@ +"""Liquid calibration data management and calibration routines.""" + +import datetime +import logging +import numpy as np + +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager + +logger = logging.getLogger(__name__) + + +def create_empty_valve_data_manager( + source: str = "statemachine", + n_valves: int = 8, +) -> ValveDataManager: + """Create an empty valve manager with 8 valves.""" + valvemanager = ValveDataManager() + if source == "statemachine": + for index in range(n_valves): + valvemanager.create_valve(f"Valve{index + 1}") + elif source == "portarray": + for index in range(n_valves): + for port_array in range(4): + valvemanager.create_valve(f"PA{port_array + 1}_{index + 1}") + else: + raise ValueError(f"Unknown source for valve names: {source}") + + return valvemanager + + +def check_valvemanager_user_updated(valvemanager: ValveDataManager) -> bool: + """Check if the user has updated the valve manager from the example. + + This is determined by checking if any valve has a modification date later than the + origin date of 2000-01-01. + + Parameters + ---------- + valvemanager : ValveDataManager + The valve manager to check. + + Returns + ------- + bool + True if any valve has measurements, False otherwise. + """ + return valvemanager.metadata.modification_datetime > datetime.datetime(2000, 1, 1) + + +def suggest_duration( + valveobject: ValveData, + range_low: float, + range_high: float, +) -> float: + """Suggest a duration for a given valve and range. + + Parameters + ---------- + valveobject : ValveData + The valve object to use for the suggestion. + range_low : float + The lower bound of microliters (uL) for the range. + range_high : float + The upper bound of microliters (uL) for the range. + + Returns + ------- + float + The suggested duration in ms. + """ + if len(valveobject.durations) >= 1: + durations = np.array(valveobject.durations) + amounts = np.array(valveobject.amounts) + else: + durations = np.array([]) + amounts = np.array([]) + + n_measurements = len(durations) + + if n_measurements >= 2: # Use curve fit to predict next measurement + amounts_vector = calculate_ranged_amounts(amounts, range_low, range_high) + suggested_amount = calculate_largest_gap_midpoint(amounts_vector) + suggested_duration_ms = valveobject.get_valve_time(suggested_amount) + elif n_measurements == 1: + # Use a linear estimate + suggested_duration_ms = linearly_suggest_duration( + amounts[0], durations[0], range_low, range_high + ) + else: + # Use an estimate of the middle of the range based on range and our experience + # with the CSHL configuration. + # (nResearch pinch valves, silastic tubing, specs in Bpod literature) + suggested_duration_ms = (range_high + range_low) * 4 + return suggested_duration_ms + + +def calculate_largest_gap_midpoint(sorted_values: list[float]) -> float: + """Calculate the midpoint of the largest gap in a sorted list of values.""" + distances = np.zeros(len(sorted_values)) + for y in range(1, len(sorted_values)): + distances[y] = abs(sorted_values[y] - sorted_values[y - 1]) + max_distance_pos = np.argmax(distances) + # Return the midpoint location + return ( + sorted_values[max_distance_pos - 1] + + (sorted_values[max_distance_pos] - sorted_values[max_distance_pos - 1]) / 2 + ) + + +def linearly_suggest_duration( + amount: float | np.typing.ArrayLike, + duration: float | np.typing.ArrayLike, + range_low: float, + range_high: float, +) -> float: + """Suggest a duration based on a single measurement and a range.""" + ul_per_ms = float(amount / duration) + if (amount < range_low) or (amount > range_high): + target_amount = (range_high - range_low) / 2 + else: + bottom_part = amount - range_low + top_part = range_high - amount + if bottom_part > top_part: + target_amount = range_low + (bottom_part / 2) + else: + target_amount = range_high - (top_part / 2) + return round(target_amount / ul_per_ms) + + +def calculate_ranged_amounts( + amounts: np.array, range_low: float, range_high: float +) -> list[float]: + """Calculate a sorted list of amounts within a range, including the range bounds.""" + amounts_vector = amounts.tolist() + if range_low not in amounts: + amounts_vector.append(range_low) + if range_high not in amounts: + amounts_vector.append(range_high) + + # take only values within range + amounts_vector = sorted(amounts_vector) + startpoint = amounts_vector.index(range_low) + endpoint = amounts_vector.index(range_high) + return amounts_vector[startpoint : endpoint + 1] + + +def check_com(valvemanager: ValveDataManager, com_port: str) -> str: + """Check if the COM port in the valve manager matches the given COM port. + + Parameters + ---------- + valvemanager : ValveDataManager + The valve manager to check. + com_port : str + The COM port to check against. + + Returns + ------- + str + 'yes' if the COM ports match, 'no' if it's different, 'unknown' otherwise. + """ + valvecom = valvemanager.metadata.COM + if valvecom == "": # e.g. uncalibrated + return "unknown" + if valvecom == com_port: + return "yes" + return "no" diff --git a/bpod_rig/examples/calibration/LiquidCalibration.json b/bpod_rig/examples/calibration/LiquidCalibration.json new file mode 100644 index 0000000..4454d0c --- /dev/null +++ b/bpod_rig/examples/calibration/LiquidCalibration.json @@ -0,0 +1,96 @@ +{ + "metadata": { + "modification_datetime": "2000-01-01T00:00:00", + "COM": "" + }, + "ValveDatas": [ + { + "ValveName": "Valve1", + "LastDateModified": "2000-01-01T00:00:00", + "Coeffs": [ + -0.30576102280847484, + 9.320110469495358, + 4.82071882423374 + ], + "Durations": [ + 22, + 66, + 44, + 57, + 34 + ], + "Amounts": [ + 2, + 9.5, + 5, + 7.5, + 3.5 + ] + }, + { + "ValveName": "Valve2", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve3", + "LastDateModified": "2000-01-01T00:00:00", + "Coeffs": [ + 0.047977795400477294, + 5.724028548770809, + 13.60021808088822 + ], + "Durations": [ + 22, + 46, + 59, + 35, + 66 + ], + "Amounts": [ + 1.5, + 5.5, + 7.5, + 3.5, + 8.5 + ] + }, + { + "ValveName": "Valve4", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve5", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve6", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve7", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve8", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + } + ] +} \ No newline at end of file diff --git a/tests/bpod_rig/calibration/__init__.py b/tests/bpod_rig/calibration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/bpod_rig/calibration/liquid/__init__.py b/tests/bpod_rig/calibration/liquid/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/bpod_rig/calibration/liquid/test_models.py b/tests/bpod_rig/calibration/liquid/test_models.py new file mode 100644 index 0000000..805ceaf --- /dev/null +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -0,0 +1,142 @@ +import pytest +import urllib.request + +from bpod_rig.calibration.liquid import utils, populate_examples +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager + + +class TestValveDataClass: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valve = ValveData(ValveName="Test Valve") + + def test_init_alias(self): + # Test that alias works correctly + # type checker may complain about this because of pydantic aliasing + assert ValveData(name="attrname").name == "attrname" # noqa: Pydantic's validation aliases raises Unexpected Argument + assert ValveData(ValveName="aliasname").name == "aliasname" + + def test_add_measurement(self): + self.valve.add_measurement(10, 1.0) + assert len(self.valve.amounts) == 1 + assert len(self.valve.durations) == 1 + + def test_get_valve_time(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + duration = self.valve.get_valve_time(1.0) + assert duration == pytest.approx(10) + + def test_remove_measurement_by_index(self): + self.valve.add_measurement(10, 1.0) + self.valve.remove_measurement(0) + assert len(self.valve.amounts) == 0 + + def test_remove_measurement_by_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.remove_measurement(10, method="duration") + assert len(self.valve.amounts) == 0 + + def test_remove_measurement_invalid_index(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + self.valve.remove_measurement(0) + with pytest.raises(IndexError): + self.valve.remove_measurement(5) + + def test_remove_measurement_invalid_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + with pytest.raises(ValueError): + self.valve.remove_measurement(15, method="duration") + + def test_remove_measurement_duplicate_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(10, 2.0) + with pytest.raises(ValueError): + self.valve.remove_measurement(10, method="duration") + + def test_no_coeffs_with_insufficient_data(self): + self.valve.add_measurement(10, 1.0) + assert len(self.valve.coeffs) == 0 + with pytest.raises(ValueError): + self.valve.get_valve_time(1.0) + self.valve.add_measurement(20, 2.0) + assert self.valve.coeffs is not None + + +class TestValveDataManagerClass: + @pytest.fixture(autouse=True) + def setup_method(self): + self.manager = ValveDataManager() + self.manager.create_valve("Test Valve") + dummy = utils.create_empty_valve_data_manager() + populate_examples.add_dummy_measurements(dummy) + self.dummy = dummy + + def test_create_valve(self): + assert self.manager.n_valves == 1 + assert "Test Valve" in self.manager.valve_names + + def test_get_valve(self): + valve = self.manager.get_valve("Test Valve") + assert isinstance(valve, ValveData) + + def test_measurements(self): + """Test that the default values are behaving as expected.""" + liquid_amount = 15 + + valve = "Valve1" + expected_duration = 0.0758261 * 1000 + duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) + assert duration == pytest.approx(expected_duration, abs=0.001) + + valve = "Valve3" + expected_duration = 0.1102557 * 1000 + duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) + assert duration == pytest.approx(expected_duration, abs=0.001) + + +class TestGen2CalibrationFile: + """Test loading a calibration file from the Bpod_Gen2 (MATLAB) repo.""" + + @pytest.fixture(scope="class") + def jsontext(self): + return ( + urllib.request.urlopen( + "https://raw.githubusercontent.com/sanworks/Bpod_Gen2/refs/heads/develop/Examples/Example%20Calibration%20Files/LiquidCalibration.json" + ) + .read() + .decode() + ) + + def test_load_calibration_file(self, jsontext): + loaded_valves = ValveDataManager.model_validate_json(jsontext) + assert isinstance(loaded_valves, ValveDataManager) + assert loaded_valves.n_valves >= 1 + assert "Valve1" in loaded_valves.valve_names + + def test_valve_modtime(self, jsontext): + valvemanager = ValveDataManager.model_validate_json(jsontext) + assert not utils.check_valvemanager_user_updated(valvemanager) + newvalve = ValveDataManager.model_validate_json(valvemanager.to_json()) + assert utils.check_valvemanager_user_updated(newvalve) + + +class TestValveDataManagerJSON: + @pytest.fixture(autouse=True) + def setup_method(self): + self.manager = utils.create_empty_valve_data_manager() + populate_examples.add_dummy_measurements(self.manager) + self.json_str = self.manager.to_json() + self.loaded_manager = ValveDataManager.model_validate_json(self.json_str) + + def test_json_round_trip(self): + assert self.manager.n_valves == self.loaded_manager.n_valves + assert self.manager.valve_names == self.loaded_manager.valve_names + for valve_name in self.manager.valve_names: + original_valve = self.manager.get_valve(valve_name) + loaded_valve = self.loaded_manager.get_valve(valve_name) + assert original_valve.amounts == loaded_valve.amounts + assert original_valve.durations == loaded_valve.durations + assert original_valve.coeffs == loaded_valve.coeffs diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py new file mode 100644 index 0000000..6078d88 --- /dev/null +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -0,0 +1,91 @@ +import pytest + +from bpod_core.fsm import StateMachine + +from bpod_rig.calibration.liquid import pending_calibration +from bpod_rig.calibration.liquid import utils, populate_examples + + +class TestPendingValve: + @pytest.fixture(autouse=True) + def setup_method(self): + self.pending_valve = pending_calibration.PendingValve(ValveName="Valve3") + + def test_valve_controller_assignment(self): + valve1 = pending_calibration.PendingValve(ValveName="PA1") + valve2 = pending_calibration.PendingValve(ValveName="Valve1") + assert valve1.valve_controller == "PortArray" + assert valve2.valve_controller == "StateMachine" + + def test_unrecognied_controller_assignment(self): + with pytest.raises(ValueError): + pending_calibration.PendingValve(ValveName="HeartValve") + + def test_ispending(self): + valve = pending_calibration.PendingValve(ValveName="Valve1") + assert not valve.is_pending + valve = pending_calibration.PendingValve(ValveName="Valve1") + valve.pending_durations.append(20) + assert valve.is_pending + + +class TestPendingMeasurementsManager: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valvemanager = utils.create_empty_valve_data_manager() + populate_examples.add_dummy_measurements(self.valvemanager) + self.manager = pending_calibration.PendingMeasurementsManager(self.valvemanager) + + def test_get_pending_initial(self): + assert self.manager.get_pending("Valve1") == [] + assert self.manager.get_pending("Valve2") == [] + + def test_get_pending_invalid_valve(self): + with pytest.raises(KeyError): + self.manager.get_pending("InvalidValve") + + def test_add_pending(self): + self.manager.add_pending("Valve1", 10.0) # valve with values + assert 10.0 in self.manager.get_pending("Valve1") + self.manager.add_pending("Valve2", 10.0) # valve without values + assert 10.0 in self.manager.get_pending("Valve2") + + def test_add_pending_duplicate(self): + self.manager.add_pending("Valve1", 10.0) + with pytest.raises(ValueError): + self.manager.add_pending("Valve1", 10.0) + + def test_remove_pending(self): + self.manager.add_pending("Valve1", 10.0) + self.manager.remove_pending("Valve1", 10.0) + assert 10.0 not in self.manager.get_pending("Valve1") + + def test_remove_pending_nonexistent(self): + with pytest.raises(ValueError): + self.manager.remove_pending("Valve1", 10.0) + + def test_complete_measurement(self): + self.manager.add_pending("Valve2", 10.0) + self.manager.complete_measurement("Valve2", 10.0, 1.5) + valve = self.valvemanager.get_valve("Valve2") + assert 10.0 in valve.durations + assert 1.5 in valve.amounts + assert 10.0 not in self.manager.get_pending("Valve2") + + def test_build_testset(self): + self.manager.add_pending("Valve1", 10) + self.manager.add_pending("Valve1", 5) + self.manager.add_pending("Valve2", 20) + assert self.manager.build_test_set() == {"Valve1": 10, "Valve2": 20} + + def test_build_statemachine(self): + self.manager.add_pending("Valve1", 10) + self.manager.add_pending("Valve2", 15) + statemachine, test_set = self.manager.build_statemachine() + assert isinstance(statemachine, StateMachine) + assert len(test_set.keys()) == 2 + # expect 2 states of open and delay, with final pulse set delay + n_states = 2 + 2 + 1 + assert len(statemachine.states) == n_states + + # TODO: test build, send, and run of state machine with emulator bpod diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py new file mode 100644 index 0000000..afd62d2 --- /dev/null +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -0,0 +1,70 @@ +import datetime +import pytest + +from bpod_rig.calibration.liquid import utils +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager + + +class TestSuggestDuration: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valve = ValveData(ValveName="Test Valve") + self.range_low = 2 + self.range_high = 10 + self.suggest_duration = lambda: utils.suggest_duration( + self.valve, self.range_low, self.range_high + ) + + def test_no_measures(self): + assert self.suggest_duration() == 48 + + def test_one_measure(self): + self.valve.add_measurement(57, 7.5) + assert self.suggest_duration() == 36 + + def test_two_measures(self): + self.valve.add_measurement(22, 2) + self.valve.add_measurement(66, 9.5) + assert self.suggest_duration() == pytest.approx(44.0, abs=0.001) + + +class TestCheckValveManagerUserUpdate: + @pytest.fixture(autouse=True) + def setup_method(self): + self.manager = ValveDataManager() + + def test_earlier(self): + self.manager.metadata.modification_datetime = datetime.datetime(1999, 1, 1) + assert not utils.check_valvemanager_user_updated(self.manager) + + def test_later(self): + self.manager.create_valve("Valve1") + self.manager.get_valve("Valve1").add_measurement(15, 20) + assert utils.check_valvemanager_user_updated(self.manager) + + def test_equal(self): + self.manager.metadata.modification_datetime = datetime.datetime(2000, 1, 1) + assert not utils.check_valvemanager_user_updated(self.manager) + + def test_dummy(self): + from bpod_rig.calibration.liquid.populate_examples import create_default_json + + manager = ValveDataManager.model_validate_json(create_default_json()) + assert not utils.check_valvemanager_user_updated(manager) + + +class TestCheckCOM: + @pytest.fixture(autouse=True) + def setup_method(self): + self.manager = ValveDataManager() + self.manager.metadata.COM = "COM3" + + def test_matching(self): + assert utils.check_com(self.manager, "COM3") == "yes" + + def test_no_match(self): + assert utils.check_com(self.manager, "COM4") == "no" + + def test_unknown(self): + self.manager.metadata.COM = "" + assert utils.check_com(self.manager, "COM3") == "unknown"