From c66838c6fc7f497ecfdba434bb0129829dd3d8be Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:51:03 +1000 Subject: [PATCH 01/30] Create liquid calibration module --- bpod_rig/calibration/__init__.py | 0 bpod_rig/calibration/liquid.py | 486 +++++++++++++++++++++++++++++++ 2 files changed, 486 insertions(+) create mode 100644 bpod_rig/calibration/__init__.py create mode 100644 bpod_rig/calibration/liquid.py diff --git a/bpod_rig/calibration/__init__.py b/bpod_rig/calibration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bpod_rig/calibration/liquid.py b/bpod_rig/calibration/liquid.py new file mode 100644 index 0000000..17638eb --- /dev/null +++ b/bpod_rig/calibration/liquid.py @@ -0,0 +1,486 @@ +"""Liquid calibration data management and calibratino routines.""" +import datetime +import logging +from pydantic import BaseModel, Field, ConfigDict, field_serializer +import numpy as np + +logger = logging.getLogger(__name__) + + +class ValveDataClass(BaseModel): + name: str = Field(alias="ValveName") + lastdatemodified: datetime.datetime | str = Field( + serialization_alias="LastDateModified", + validation_alias="LastDateModified", + default="", + description="Modification datetime in ISO format or empty string if not set.", + ) + coeffs: list[float] = Field( + alias="Coeffs", + default=[], + description="Coefficients for polynomial fitting of durations vs amounts.", + ) + durations: list[float | int] = Field( + alias="Durations", + default=[], + description="Durations in ms for each dispense, corresponding to the amounts.", + ) + amounts: list[float | int] = Field( + alias="Amounts", + default=[], + description="Amounts in uL for each dispense, corresponding to the durations.", + ) + + model_config = ConfigDict( + serialize_by_alias=True, validate_by_name=True, validate_by_alias=True, + ) + + def get_valve_time(self, amount: float) -> float: + """Get the valve time for a given amount of liquid. + + Parameters + ---------- + amount : float + The amount of liquid in mL. + + Returns + ------- + float + The duration in ms for the given amount. + """ + if len(self.coeffs) == 0: + raise ValueError("Coefficients are not set. Please add measurements first.") + duration = np.polyval(self.coeffs, amount) + logger.debug( + "%s calculated valve time: %s ms for amount: %s mL", + self.name, + duration, + amount, + ) + return float(duration) + + def add_measurement(self, duration: float, amount: float) -> None: + """Add a measurement to the valve data. + + Parameters + ---------- + duration : float + The duration of the dispense. + amount : float + The amount of liquid dispensed. + """ + self.amounts.append(amount) + self.durations.append(duration) + logger.debug("Added measurement: amount = %s, duration = %s", amount, duration) + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def remove_measurement(self, value: int | float, method: str = "index") -> None: + """Remove a measurement from the valve data. + + Parameters + ---------- + value : int | float + The amount or duration to remove. + method : str, optional (default = 'index') + How to use value to find the measurement to remove. 'index' to remove by index, + 'duration' to remove by duration value. + """ + if method == "duration": + if value in self.durations: + # Throw error if value is in durations multiple times + if self.durations.count(value) > 1: + raise ValueError( + "Duration value is present multiple times. Index must be specified." + ) + index = self.durations.index(value) + del self.amounts[index] + del self.durations[index] + logger.debug("Removed measurement by duration: %s", value) + else: + raise ValueError("Duration value not found in durations list.") + elif method == "index": + if 0 <= value < len(self.amounts): + del self.amounts[value] + del self.durations[value] + logger.debug("Removed measurement by index: %s", value) + else: + raise IndexError("Index out of range for amounts and durations lists.") + else: + raise ValueError(f"Unknown method for removing measurement: {method}") + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def _update_coeffs(self) -> None: + if len(self.amounts) < 2: + self.coeffs = [] + return + # TODO: 1 value assumes intercept at 0? + elif len(self.amounts) == 2: + # If only two measurements, use linear fit + order = 1 + else: + order = 2 + # Example: Fit a polynomial of degree 2 + self.coeffs = np.polyfit(self.amounts, self.durations, order).tolist() + + @field_serializer("lastdatemodified") + def serialize_datetime(self, dt: datetime.datetime, _info): + if isinstance(dt, str): # if uncalibrated it is an empty string + return dt + return dt.isoformat(timespec="seconds") + + +class ValveManagerMetaData(BaseModel): + modification_datetime: datetime.datetime = Field( + default_factory=datetime.datetime.now, + description="Time of when the valve data was last saved to json file.", + ) + COM: str = "" + + @field_serializer("modification_datetime") + def serialize_datetime(self, modification_datetime: datetime.datetime, _info): + return modification_datetime.isoformat(timespec="seconds") + + + +class ValveDataManagerClass(BaseModel): + metadata: ValveManagerMetaData = Field(default_factory=ValveManagerMetaData) + valve_datas: list[ValveDataClass] = Field( + alias="ValveDatas", default=[], description="Array of valve data objects." + ) + + model_config = ConfigDict(serialize_by_alias=True, validate_by_name=True) + + def get_valve(self, valvename: str) -> ValveDataClass: + """Get a valve by name. + + Parameters + ---------- + valvename : str + The name of the valve to retrieve. + + Returns + ------- + ValveDataClass + """ + # Check if the valve exists in the ValveDatas list + if valvename in self.valve_names: + # check if there are two valves with same name, raise error if so + if sum(name == valvename for name in self.valve_names) > 1: + raise KeyError(f"Multiple valves with name '{valvename}' found.") + return next(valve for valve in self.valve_datas if valve.name == valvename) + else: + raise KeyError(f"Valve '{valvename}' not found in valve manager.") + + def create_valve(self, valvename: str) -> None: + """Create a new valve with the given name.""" + if valvename in self.valve_names: + raise KeyError(f"Valve '{valvename}' already exists.") + self.valve_datas.append(ValveDataClass(name=valvename)) # noqa: aliasing with pydantic can cause type check issues + logger.debug(f"Created new valve: {valvename}") + + @property + def n_valves(self) -> int: + """Get the number of valves in the manager.""" + return len(self.valve_datas) + + @property + def valve_names(self) -> list[str]: + """Get a list of valve names.""" + return list(valve.name for valve in self.valve_datas) + + def to_json(self, machineid: str = None) -> str: + """Convert the ValveDataManagerPydantic to JSON string. + + Parameters + ---------- + machineid : str, optional + The COM port or machine ID to set in the metadata. + """ + if machineid is not None: + if not isinstance(machineid, str): + raise ValueError("machineid must be a string.") + + # check if the COM is changing + if (self.metadata.COM != "") & (self.metadata.COM != machineid): + logger.warning( + "COM port of liquid calibration file is changing from %s to %s", self.metadata.COM, machineid + ) + self.metadata.COM = machineid + self.metadata.modification_datetime = datetime.datetime.now() + return self.model_dump_json(indent=2) + + +def create_empty_valve_data_manager( + source: str = 'statemachine', + n_valves: int = 8, +) -> ValveDataManagerClass: + """Create an empty valve manager with 8 valves.""" + valvemanager = ValveDataManagerClass() + if source == 'statemachine': + for index in range(n_valves): + valvemanager.create_valve(f"Valve{index + 1}") + elif source == 'portarray': + for index in range(n_valves): + for port_array in range(4): + valvemanager.create_valve(f"PA{port_array + 1 }_{index + 1}") + else: + raise ValueError(f"Unknown source for valve names: {source}") + + return valvemanager + + +def add_dummy_measurements(valvemanager: ValveDataManagerClass) -> None: + """Add dummy measurements to the valves in manager.""" + origin_date = datetime.datetime(2000, 1, 1) + valve1 = valvemanager.get_valve("Valve1") + valve1.add_measurement(22, 2) + valve1.add_measurement(66, 9.5) + valve1.add_measurement(44, 5) + valve1.add_measurement(57, 7.5) + valve1.add_measurement(34, 3.5) + valve1.lastdatemodified = origin_date + + valve3 = valvemanager.get_valve("Valve3") + valve3.add_measurement(22, 1.5) + valve3.add_measurement(46, 5.5) + valve3.add_measurement(59, 7.5) + valve3.add_measurement(35, 3.5) + valve3.add_measurement(66, 8.5) + valve3.lastdatemodified = origin_date + + valvemanager.metadata.modification_datetime = origin_date + +def create_default_json() -> str: + """Produce a default valve JSON string with 8 valves and dummy measurements.""" + dummy = create_empty_valve_data_manager() + add_dummy_measurements(dummy) + jsontext = dummy.to_json(machineid=None) + # replace modification datetime with a fixed date for consistency + jsontext = jsontext.replace( + dummy.metadata.modification_datetime.isoformat(timespec="seconds"), + "2000-01-01T00:00:00", + ) + return jsontext + +def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool: + """Check if the user has updated the valve manager has been updated from the example. + + This is determined by checking if any valve has a modification date later than the + origin date of 2000-01-01. + + Parameters + ---------- + valvemanager : ValveDataManagerClass + The valve manager to check. + + Returns + ------- + bool + True if any valve has measurements, False otherwise. + """ + return valvemanager.metadata.modification_datetime > datetime.datetime(2000, 1, 1) + + +def suggest_duration( + valveobject: ValveDataClass, + range_low: float, + range_high: float, +) -> float: + """Suggest a duration for a given valve and range. + + Parameters + ---------- + valveobject : ValveDataClass + The valve object to use for the suggestion. + range_low : float + The lower bound of microliters (uL) for the range. + range_high : float + The upper bound of microliters (uL) for the range. + + Returns + ------- + float + The suggested duration in ms. + """ + if len(valveobject.durations) >= 1: + durations = np.array(valveobject.durations) + amounts = np.array(valveobject.amounts) + else: + durations = np.array([]) + amounts = np.array([]) + + n_measurements = len(durations) + + if n_measurements >= 2: # Use curve fit to predict next measurement + amounts_vector = calculate_ranged_amounts(amounts, range_low, range_high) + suggested_amount = calculate_largest_gap_midpoint(amounts_vector) + suggested_duration_ms = valveobject.get_valve_time(suggested_amount) + elif n_measurements == 1: + # Use a linear estimate + suggested_duration_ms = linearly_suggest_duration( + amounts[0], durations[0], range_low, range_high + ) + else: + # Use an estimate of the middle of the range based on range and our experience with + # the CSHL configuration (nResearch pinch valves, silastic tubing, specs in Bpod literature) + suggested_duration_ms = (range_high + range_low) * 4 + return suggested_duration_ms + + +def calculate_largest_gap_midpoint(sorted_values: list[float]) -> float: + """Calculate the midpoint of the largest gap in a sorted list of values.""" + distances = np.zeros(len(sorted_values)) + for y in range(1, len(sorted_values)): + distances[y] = abs(sorted_values[y] - sorted_values[y - 1]) + max_distance_pos = np.argmax(distances) + midpoint_value = sorted_values[max_distance_pos - 1] + ( + sorted_values[max_distance_pos] - sorted_values[max_distance_pos - 1] + ) / 2 + return midpoint_value + + +def linearly_suggest_duration( + amount: float | np.typing.ArrayLike, + duration: float | np.typing.ArrayLike, + range_low: float, + range_high: float, +) -> float: + """Suggest a duration based on a single measurement and a range.""" + ul_per_ms = float(amount / duration) + if (amount < range_low) or (amount > range_high): + target_amount = (range_high - range_low) / 2 + else: + bottom_part = amount - range_low + top_part = range_high - amount + if bottom_part > top_part: + target_amount = range_low + (bottom_part / 2) + else: + target_amount = range_high - (top_part / 2) + suggested_duration = round(target_amount / ul_per_ms) + return suggested_duration + + +def calculate_ranged_amounts(amounts: np.array, range_low: float, range_high: float) -> list[float]: + """Calculate a sorted list of amounts within a given range, including the range bounds.""" + amounts_vector = amounts.tolist() + if range_low not in amounts: + amounts_vector.append(range_low) + if range_high not in amounts: + amounts_vector.append(range_high) + + # take only values within range + amounts_vector = sorted(amounts_vector) + startpoint = amounts_vector.index(range_low) + endpoint = amounts_vector.index(range_high) + amounts_vector = amounts_vector[startpoint: endpoint + 1] + return amounts_vector + +def check_COM(valvemanager: ValveDataManagerClass, com_port: str) -> str: + """Check if the COM port in the valve manager matches the given COM port. + + Parameters + ---------- + valvemanager : ValveDataManagerClass + The valve manager to check. + com_port : str + The COM port to check against. + + Returns + ------- + str + 'yes' if the COM ports match, 'no' if it's different, 'unknown' otherwise. + """ + valvecom = valvemanager.metadata.COM + if valvecom == "": # e.g. uncalibrated + return 'unknown' + if valvecom == com_port: + return 'yes' + else: + return 'no' + +class PendingValve(ValveDataClass): + pending_durations: list[float] = Field( + default=[], + description="List of durations that are pending calibration measurements.", + ) + + valve_controller: str = Field( + default="unknown", + description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", + ) + + def __init__(self, **data): + super().__init__(**data) + if self.name.startswith("PA"): + self.valve_controller = "PortArray" + else: + self.valve_controller = "StateMachine" + + def get_calibration_run_info(self) -> float: + pass + + +class PendingMeasurementsManager: + _valvemanager: ValveDataManagerClass + valves: list[PendingValve] + + def __init__(self, valvemanager: ValveDataManagerClass): + self._valvemanager = valvemanager + self.valves = [PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas] + + def get_pending(self, valvename: str) -> list[float]: + if valvename not in self._valvemanager.valve_names: + raise KeyError(f"Valvename '{valvename}' not found in manager.") + + return next(valve for valve in self.valves if valve.name == valvename).pending_durations + + def add_pending(self, valvename: str, duration: float) -> None: + if duration in self.get_pending(valvename): + raise ValueError("Duration already pending.") + if duration in self._valvemanager.get_valve(valvename).durations: + raise ValueError("Duration already exists in valve calibration.") + + self.get_pending(valvename).append(duration) + + def remove_pending(self, valvename: str, duration: float) -> None: + if duration not in self.get_pending(valvename): + raise ValueError("Duration for completed measurement is not pending.") + + self.get_pending(valvename).remove(duration) + + def complete_measurement(self, valvename: str, duration: float, amount: float) -> None: + self.remove_pending(valvename, duration) + self._valvemanager.get_valve(valvename).add_mesurement(duration, amount) + +def run_calibration( + pending_manager: PendingMeasurementsManager, + n_pulses: int, + pulse_interval: float = 0.2, + pulse_set_pause: float = 0.5, + verbose: bool = True, +) -> None: + """Run a calibration sequence using Bpod state machine. + + Parameters + ---------- + + """ + + from bpod_core.fsm import StateMachine + + fsm = StateMachine() + + fsm.add_state( + name='Port1Light', + timer=1, + state_change_conditions={'BNC1_High': 'Port2Light'}, + output_actions={'PWM1': 255}, + ) + fsm.add_state( + name='Port2Light', + timer=1, + state_change_conditions={'Tup': '>exit'}, + output_actions={'PWM2': 255}, + ) From 89528266a0d6e785b2fd26d1cea60d959e5662e9 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:51:18 +1000 Subject: [PATCH 02/30] Create liquid calibration tests --- tests/bpod_rig/calibration/__init__.py | 0 tests/bpod_rig/calibration/test_liquid.py | 155 ++++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 tests/bpod_rig/calibration/__init__.py create mode 100644 tests/bpod_rig/calibration/test_liquid.py diff --git a/tests/bpod_rig/calibration/__init__.py b/tests/bpod_rig/calibration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/test_liquid.py new file mode 100644 index 0000000..ad6472d --- /dev/null +++ b/tests/bpod_rig/calibration/test_liquid.py @@ -0,0 +1,155 @@ +import unittest +import urllib.request + +from bpod_rig.calibration import liquid + +class TestValveDataClass(unittest.TestCase): + + def setUp(self): + self.valve = liquid.ValveDataClass(ValveName="Test Valve") + + def test_init_alias(self): + # Test that alias works correctly + # type checker may complain about this because of pydantic aliasing + self.assertEqual(liquid.ValveDataClass(name='attrname').name, 'attrname') + self.assertEqual(liquid.ValveDataClass(ValveName='aliasname').name, 'aliasname') + + def test_add_measurement(self): + self.valve.add_measurement(10, 1.0) + self.assertEqual(len(self.valve.amounts), 1) + self.assertEqual(len(self.valve.durations), 1) + + def test_get_valve_time(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + duration = self.valve.get_valve_time(1.0) + self.assertAlmostEqual(duration, 10) + + def test_remove_measurement_by_index(self): + self.valve.add_measurement(10, 1.0) + self.valve.remove_measurement(0) + self.assertEqual(len(self.valve.amounts), 0) + + def test_remove_measurement_by_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.remove_measurement(10, method="duration") + self.assertEqual(len(self.valve.amounts), 0) + + def test_remove_measurement_invalid_index(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + self.valve.remove_measurement(0) + with self.assertRaises(IndexError): + self.valve.remove_measurement(5) + + def test_remove_measurement_invalid_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(20, 2.0) + with self.assertRaises(ValueError): + self.valve.remove_measurement(15, method="duration") + + def test_remove_measurement_duplicate_duration(self): + self.valve.add_measurement(10, 1.0) + self.valve.add_measurement(10, 2.0) + with self.assertRaises(ValueError): + self.valve.remove_measurement(10, method="duration") + + def test_no_coeffs_with_insufficient_data(self): + self.valve.add_measurement(10, 1.0) + self.assertEqual(len(self.valve.coeffs), 0) + with self.assertRaises(ValueError): + self.valve.get_valve_time(1.0) + self.valve.add_measurement(20, 2.0) + self.assertIsNotNone(self.valve.coeffs) + +class TestValveDataManagerClass(unittest.TestCase): + + def setUp(self): + self.manager = liquid.ValveDataManagerClass() + self.manager.create_valve("Test Valve") + dummy = liquid.create_empty_valve_data_manager() + liquid.add_dummy_measurements(dummy) + self.dummy = dummy + + def test_create_valve(self): + self.assertEqual(self.manager.n_valves, 1) + self.assertIn("Test Valve", self.manager.valve_names) + + def test_get_valve(self): + valve = self.manager.get_valve("Test Valve") + self.assertIsInstance(valve, liquid.ValveDataClass) + + def test_measurements(self): + """Test that the default values are behaving as expeted.""" + liquid_amount = 15 + + valve = "Valve1" + expected_duration = 0.0758261 * 1000 + duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) + self.assertAlmostEqual(duration, expected_duration, places=3) + + valve = "Valve3" + expected_duration = 0.1102557 * 1000 + duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) + self.assertAlmostEqual(duration, expected_duration, places=3) + +class TestGen2CalibrationFile(unittest.TestCase): + """Test loading a calibration file from the Bpod_Gen2 (MATLAB) repo.""" + + @classmethod + def setUpClass(cls): + super(TestGen2CalibrationFile, cls).setUpClass() + url = "https://raw.githubusercontent.com/sanworks/Bpod_Gen2/refs/heads/develop/Examples/Example%20Calibration%20Files/LiquidCalibration.json" + cls.jsontext = urllib.request.urlopen(url).read().decode() + + def test_load_calibration_file(self): + loaded_valves = liquid.ValveDataManagerClass.model_validate_json(self.jsontext) + self.assertIsInstance(loaded_valves, liquid.ValveDataManagerClass) + self.assertGreaterEqual(loaded_valves.n_valves, 1) + self.assertIn("Valve1", loaded_valves.valve_names) + + def test_valve_modtime(self): + valvemanager = liquid.ValveDataManagerClass.model_validate_json(self.jsontext) + self.assertFalse(liquid.check_valvemanager_user_updated(valvemanager)) + newvalve = liquid.ValveDataManagerClass.model_validate_json(valvemanager.to_json()) + self.assertTrue(liquid.check_valvemanager_user_updated(newvalve)) + + +class TestValveDataManagerJSON(unittest.TestCase): + + def setUp(self): + self.manager = liquid.create_empty_valve_data_manager() + liquid.add_dummy_measurements(self.manager) + self.json_str = self.manager.to_json() + self.loaded_manager = liquid.ValveDataManagerClass.model_validate_json(self.json_str) + + def test_json_round_trip(self): + self.assertEqual(self.manager.n_valves, self.loaded_manager.n_valves) + self.assertEqual(self.manager.valve_names, self.loaded_manager.valve_names) + for valve_name in self.manager.valve_names: + original_valve = self.manager.get_valve(valve_name) + loaded_valve = self.loaded_manager.get_valve(valve_name) + self.assertEqual(original_valve.amounts, loaded_valve.amounts) + self.assertEqual(original_valve.durations, loaded_valve.durations) + self.assertEqual(original_valve.coeffs, loaded_valve.coeffs) + + +class TestSuggestDuration(unittest.TestCase): + + def setUp(self): + self.valve = liquid.ValveDataClass(ValveName="Test Valve") + self.range_low = 2 + self.range_high = 10 + self.suggest_duration = lambda: liquid.suggest_duration(self.valve, self.range_low, self.range_high) + + def test_no_measures(self): + self.assertEqual(self.suggest_duration(), 48) + + def test_one_measure(self): + self.valve.add_measurement(57, 7.5) + self.assertEqual(self.suggest_duration(), 36) + + def test_two_measures(self): + self.valve.add_measurement(22, 2) + self.valve.add_measurement(66, 9.5) + self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) From 987fd473e3f4900ddaf36ae576732522e4eca643 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:15:17 +1000 Subject: [PATCH 03/30] Refactor into calibration.liquid module --- bpod_rig/calibration/liquid/__init__.py | 0 bpod_rig/calibration/{ => liquid}/liquid.py | 209 +------------------ bpod_rig/calibration/liquid/models.py | 211 ++++++++++++++++++++ tests/bpod_rig/calibration/test_liquid.py | 28 +-- 4 files changed, 230 insertions(+), 218 deletions(-) create mode 100644 bpod_rig/calibration/liquid/__init__.py rename bpod_rig/calibration/{ => liquid}/liquid.py (55%) create mode 100644 bpod_rig/calibration/liquid/models.py diff --git a/bpod_rig/calibration/liquid/__init__.py b/bpod_rig/calibration/liquid/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bpod_rig/calibration/liquid.py b/bpod_rig/calibration/liquid/liquid.py similarity index 55% rename from bpod_rig/calibration/liquid.py rename to bpod_rig/calibration/liquid/liquid.py index 17638eb..a84988f 100644 --- a/bpod_rig/calibration/liquid.py +++ b/bpod_rig/calibration/liquid/liquid.py @@ -1,215 +1,12 @@ """Liquid calibration data management and calibratino routines.""" import datetime import logging -from pydantic import BaseModel, Field, ConfigDict, field_serializer +from pydantic import Field import numpy as np -logger = logging.getLogger(__name__) - - -class ValveDataClass(BaseModel): - name: str = Field(alias="ValveName") - lastdatemodified: datetime.datetime | str = Field( - serialization_alias="LastDateModified", - validation_alias="LastDateModified", - default="", - description="Modification datetime in ISO format or empty string if not set.", - ) - coeffs: list[float] = Field( - alias="Coeffs", - default=[], - description="Coefficients for polynomial fitting of durations vs amounts.", - ) - durations: list[float | int] = Field( - alias="Durations", - default=[], - description="Durations in ms for each dispense, corresponding to the amounts.", - ) - amounts: list[float | int] = Field( - alias="Amounts", - default=[], - description="Amounts in uL for each dispense, corresponding to the durations.", - ) - - model_config = ConfigDict( - serialize_by_alias=True, validate_by_name=True, validate_by_alias=True, - ) - - def get_valve_time(self, amount: float) -> float: - """Get the valve time for a given amount of liquid. - - Parameters - ---------- - amount : float - The amount of liquid in mL. - - Returns - ------- - float - The duration in ms for the given amount. - """ - if len(self.coeffs) == 0: - raise ValueError("Coefficients are not set. Please add measurements first.") - duration = np.polyval(self.coeffs, amount) - logger.debug( - "%s calculated valve time: %s ms for amount: %s mL", - self.name, - duration, - amount, - ) - return float(duration) - - def add_measurement(self, duration: float, amount: float) -> None: - """Add a measurement to the valve data. - - Parameters - ---------- - duration : float - The duration of the dispense. - amount : float - The amount of liquid dispensed. - """ - self.amounts.append(amount) - self.durations.append(duration) - logger.debug("Added measurement: amount = %s, duration = %s", amount, duration) - self.lastdatemodified = datetime.datetime.now() - self._update_coeffs() - - def remove_measurement(self, value: int | float, method: str = "index") -> None: - """Remove a measurement from the valve data. - - Parameters - ---------- - value : int | float - The amount or duration to remove. - method : str, optional (default = 'index') - How to use value to find the measurement to remove. 'index' to remove by index, - 'duration' to remove by duration value. - """ - if method == "duration": - if value in self.durations: - # Throw error if value is in durations multiple times - if self.durations.count(value) > 1: - raise ValueError( - "Duration value is present multiple times. Index must be specified." - ) - index = self.durations.index(value) - del self.amounts[index] - del self.durations[index] - logger.debug("Removed measurement by duration: %s", value) - else: - raise ValueError("Duration value not found in durations list.") - elif method == "index": - if 0 <= value < len(self.amounts): - del self.amounts[value] - del self.durations[value] - logger.debug("Removed measurement by index: %s", value) - else: - raise IndexError("Index out of range for amounts and durations lists.") - else: - raise ValueError(f"Unknown method for removing measurement: {method}") - self.lastdatemodified = datetime.datetime.now() - self._update_coeffs() - - def _update_coeffs(self) -> None: - if len(self.amounts) < 2: - self.coeffs = [] - return - # TODO: 1 value assumes intercept at 0? - elif len(self.amounts) == 2: - # If only two measurements, use linear fit - order = 1 - else: - order = 2 - # Example: Fit a polynomial of degree 2 - self.coeffs = np.polyfit(self.amounts, self.durations, order).tolist() - - @field_serializer("lastdatemodified") - def serialize_datetime(self, dt: datetime.datetime, _info): - if isinstance(dt, str): # if uncalibrated it is an empty string - return dt - return dt.isoformat(timespec="seconds") - - -class ValveManagerMetaData(BaseModel): - modification_datetime: datetime.datetime = Field( - default_factory=datetime.datetime.now, - description="Time of when the valve data was last saved to json file.", - ) - COM: str = "" - - @field_serializer("modification_datetime") - def serialize_datetime(self, modification_datetime: datetime.datetime, _info): - return modification_datetime.isoformat(timespec="seconds") +from .models import ValveDataClass, ValveDataManagerClass - - -class ValveDataManagerClass(BaseModel): - metadata: ValveManagerMetaData = Field(default_factory=ValveManagerMetaData) - valve_datas: list[ValveDataClass] = Field( - alias="ValveDatas", default=[], description="Array of valve data objects." - ) - - model_config = ConfigDict(serialize_by_alias=True, validate_by_name=True) - - def get_valve(self, valvename: str) -> ValveDataClass: - """Get a valve by name. - - Parameters - ---------- - valvename : str - The name of the valve to retrieve. - - Returns - ------- - ValveDataClass - """ - # Check if the valve exists in the ValveDatas list - if valvename in self.valve_names: - # check if there are two valves with same name, raise error if so - if sum(name == valvename for name in self.valve_names) > 1: - raise KeyError(f"Multiple valves with name '{valvename}' found.") - return next(valve for valve in self.valve_datas if valve.name == valvename) - else: - raise KeyError(f"Valve '{valvename}' not found in valve manager.") - - def create_valve(self, valvename: str) -> None: - """Create a new valve with the given name.""" - if valvename in self.valve_names: - raise KeyError(f"Valve '{valvename}' already exists.") - self.valve_datas.append(ValveDataClass(name=valvename)) # noqa: aliasing with pydantic can cause type check issues - logger.debug(f"Created new valve: {valvename}") - - @property - def n_valves(self) -> int: - """Get the number of valves in the manager.""" - return len(self.valve_datas) - - @property - def valve_names(self) -> list[str]: - """Get a list of valve names.""" - return list(valve.name for valve in self.valve_datas) - - def to_json(self, machineid: str = None) -> str: - """Convert the ValveDataManagerPydantic to JSON string. - - Parameters - ---------- - machineid : str, optional - The COM port or machine ID to set in the metadata. - """ - if machineid is not None: - if not isinstance(machineid, str): - raise ValueError("machineid must be a string.") - - # check if the COM is changing - if (self.metadata.COM != "") & (self.metadata.COM != machineid): - logger.warning( - "COM port of liquid calibration file is changing from %s to %s", self.metadata.COM, machineid - ) - self.metadata.COM = machineid - self.metadata.modification_datetime = datetime.datetime.now() - return self.model_dump_json(indent=2) +logger = logging.getLogger(__name__) def create_empty_valve_data_manager( diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py new file mode 100644 index 0000000..4b96986 --- /dev/null +++ b/bpod_rig/calibration/liquid/models.py @@ -0,0 +1,211 @@ +import datetime +import logging + +import numpy as np +from pydantic import BaseModel, Field, ConfigDict, field_serializer + +logger = logging.getLogger(__name__) + +class ValveDataClass(BaseModel): + name: str = Field(alias="ValveName") + lastdatemodified: datetime.datetime | str = Field( + serialization_alias="LastDateModified", + validation_alias="LastDateModified", + default="", + description="Modification datetime in ISO format or empty string if not set.", + ) + coeffs: list[float] = Field( + alias="Coeffs", + default=[], + description="Coefficients for polynomial fitting of durations vs amounts.", + ) + durations: list[float | int] = Field( + alias="Durations", + default=[], + description="Durations in ms for each dispense, corresponding to the amounts.", + ) + amounts: list[float | int] = Field( + alias="Amounts", + default=[], + description="Amounts in uL for each dispense, corresponding to the durations.", + ) + + model_config = ConfigDict( + serialize_by_alias=True, validate_by_name=True, validate_by_alias=True, + ) + + def get_valve_time(self, amount: float) -> float: + """Get the valve time for a given amount of liquid. + + Parameters + ---------- + amount : float + The amount of liquid in mL. + + Returns + ------- + float + The duration in ms for the given amount. + """ + if len(self.coeffs) == 0: + raise ValueError("Coefficients are not set. Please add measurements first.") + duration = np.polyval(self.coeffs, amount) + logger.debug( + "%s calculated valve time: %s ms for amount: %s mL", + self.name, + duration, + amount, + ) + return float(duration) + + def add_measurement(self, duration: float, amount: float) -> None: + """Add a measurement to the valve data. + + Parameters + ---------- + duration : float + The duration of the dispense in ms. + amount : float + The amount of liquid dispensed mL. + """ + self.amounts.append(amount) + self.durations.append(duration) + logger.debug("Added measurement: amount = %s, duration = %s", amount, duration) + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def remove_measurement(self, value: int | float, method: str = "index") -> None: + """Remove a measurement from the valve data. + + Parameters + ---------- + value : int | float + The index or duration to remove. + method : str, optional (default = 'index') + How to use value to find the measurement to remove. 'index' to remove by index, + 'duration' to remove by duration value (assuming unique duration). + """ + if method == "duration": + if value in self.durations: + # Throw error if value is in durations multiple times + if self.durations.count(value) > 1: + raise ValueError( + "Duration value is present multiple times. Index must be specified." + ) + index = self.durations.index(value) + del self.amounts[index] + del self.durations[index] + logger.debug("Removed measurement by duration: %s", value) + else: + raise ValueError("Duration value not found in durations list.") + elif method == "index": + if 0 <= value < len(self.amounts): + del self.amounts[value] + del self.durations[value] + logger.debug("Removed measurement by index: %s", value) + else: + raise IndexError("Index out of range for amounts and durations lists.") + else: + raise ValueError(f"Unknown method for removing measurement: {method}") + self.lastdatemodified = datetime.datetime.now() + self._update_coeffs() + + def _update_coeffs(self) -> None: + """Update the polynomial coefficients based on current measurements.""" + if len(self.amounts) < 2: + self.coeffs = [] + return + # TODO: 1 value assumes intercept at 0? + elif len(self.amounts) == 2: + # If only two measurements, use linear fit + order = 1 + else: + # Fit a polynomial of degree 2 + order = 2 + self.coeffs = np.polyfit(self.amounts, self.durations, order).tolist() + + @field_serializer("lastdatemodified") + def serialize_datetime(self, dt: datetime.datetime, _info): + if isinstance(dt, str): # if uncalibrated it is an empty string + return dt + return dt.isoformat(timespec="seconds") + + +class ValveManagerMetaData(BaseModel): + modification_datetime: datetime.datetime = Field( + default_factory=datetime.datetime.now, + description="Time of when the valve data was last saved to json file.", + ) + COM: str = "" + + @field_serializer("modification_datetime") + def serialize_datetime(self, modification_datetime: datetime.datetime, _info): + return modification_datetime.isoformat(timespec="seconds") + + +class ValveDataManagerClass(BaseModel): + metadata: ValveManagerMetaData = Field(default_factory=ValveManagerMetaData) + valve_datas: list[ValveDataClass] = Field( + alias="ValveDatas", default=[], description="Array of valve data objects." + ) + + model_config = ConfigDict(serialize_by_alias=True, validate_by_name=True) + + def get_valve(self, valvename: str) -> ValveDataClass: + """Get a valve by name. + + Parameters + ---------- + valvename : str + The name of the valve to retrieve. + + Returns + ------- + ValveDataClass + """ + # Check if the valve exists in the ValveDatas list + if valvename in self.valve_names: + # check if there are two valves with same name, raise error if so + if sum(name == valvename for name in self.valve_names) > 1: + raise KeyError(f"Multiple valves with name '{valvename}' found.") + return next(valve for valve in self.valve_datas if valve.name == valvename) + else: + raise KeyError(f"Valve '{valvename}' not found in valve manager.") + + def create_valve(self, valvename: str) -> None: + """Create a new valve with the given name.""" + if valvename in self.valve_names: + raise KeyError(f"Valve '{valvename}' already exists.") + self.valve_datas.append(ValveDataClass(name=valvename)) # noqa: aliasing with pydantic can cause type check issues + logger.debug(f"Created new valve: {valvename}") + + @property + def n_valves(self) -> int: + """Get the number of valves in the manager.""" + return len(self.valve_datas) + + @property + def valve_names(self) -> list[str]: + """Get a list of valve names.""" + return list(valve.name for valve in self.valve_datas) + + def to_json(self, machineid: str = None) -> str: + """Convert the ValveDataManagerPydantic to JSON string. + + Parameters + ---------- + machineid : str, optional + The COM port or machine ID to set in the metadata. + """ + if machineid is not None: + if not isinstance(machineid, str): + raise ValueError("machineid must be a string.") + + # check if the COM is changing + if (self.metadata.COM != "") & (self.metadata.COM != machineid): + logger.warning( + "COM port of liquid calibration file is changing from %s to %s", self.metadata.COM, machineid + ) + self.metadata.COM = machineid + self.metadata.modification_datetime = datetime.datetime.now() + return self.model_dump_json(indent=2) diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/test_liquid.py index ad6472d..b1bb2c3 100644 --- a/tests/bpod_rig/calibration/test_liquid.py +++ b/tests/bpod_rig/calibration/test_liquid.py @@ -1,18 +1,20 @@ import unittest import urllib.request -from bpod_rig.calibration import liquid +from bpod_rig.calibration.liquid import liquid +from bpod_rig.calibration.liquid.models import ValveDataClass, ValveDataManagerClass class TestValveDataClass(unittest.TestCase): def setUp(self): - self.valve = liquid.ValveDataClass(ValveName="Test Valve") + self.valve = ValveDataClass(ValveName="Test Valve") def test_init_alias(self): # Test that alias works correctly # type checker may complain about this because of pydantic aliasing - self.assertEqual(liquid.ValveDataClass(name='attrname').name, 'attrname') - self.assertEqual(liquid.ValveDataClass(ValveName='aliasname').name, 'aliasname') + self.assertEqual(ValveDataClass(name='attrname').name, 'attrname') + self.assertEqual( + ValveDataClass(ValveName='aliasname').name, 'aliasname') def test_add_measurement(self): self.valve.add_measurement(10, 1.0) @@ -65,7 +67,7 @@ def test_no_coeffs_with_insufficient_data(self): class TestValveDataManagerClass(unittest.TestCase): def setUp(self): - self.manager = liquid.ValveDataManagerClass() + self.manager = ValveDataManagerClass() self.manager.create_valve("Test Valve") dummy = liquid.create_empty_valve_data_manager() liquid.add_dummy_measurements(dummy) @@ -77,7 +79,7 @@ def test_create_valve(self): def test_get_valve(self): valve = self.manager.get_valve("Test Valve") - self.assertIsInstance(valve, liquid.ValveDataClass) + self.assertIsInstance(valve, ValveDataClass) def test_measurements(self): """Test that the default values are behaving as expeted.""" @@ -103,15 +105,17 @@ def setUpClass(cls): cls.jsontext = urllib.request.urlopen(url).read().decode() def test_load_calibration_file(self): - loaded_valves = liquid.ValveDataManagerClass.model_validate_json(self.jsontext) - self.assertIsInstance(loaded_valves, liquid.ValveDataManagerClass) + loaded_valves = ValveDataManagerClass.model_validate_json(self.jsontext) + self.assertIsInstance(loaded_valves, + ValveDataManagerClass + ) self.assertGreaterEqual(loaded_valves.n_valves, 1) self.assertIn("Valve1", loaded_valves.valve_names) def test_valve_modtime(self): - valvemanager = liquid.ValveDataManagerClass.model_validate_json(self.jsontext) + valvemanager = ValveDataManagerClass.model_validate_json(self.jsontext) self.assertFalse(liquid.check_valvemanager_user_updated(valvemanager)) - newvalve = liquid.ValveDataManagerClass.model_validate_json(valvemanager.to_json()) + newvalve = ValveDataManagerClass.model_validate_json(valvemanager.to_json()) self.assertTrue(liquid.check_valvemanager_user_updated(newvalve)) @@ -121,7 +125,7 @@ def setUp(self): self.manager = liquid.create_empty_valve_data_manager() liquid.add_dummy_measurements(self.manager) self.json_str = self.manager.to_json() - self.loaded_manager = liquid.ValveDataManagerClass.model_validate_json(self.json_str) + self.loaded_manager = ValveDataManagerClass.model_validate_json(self.json_str) def test_json_round_trip(self): self.assertEqual(self.manager.n_valves, self.loaded_manager.n_valves) @@ -137,7 +141,7 @@ def test_json_round_trip(self): class TestSuggestDuration(unittest.TestCase): def setUp(self): - self.valve = liquid.ValveDataClass(ValveName="Test Valve") + self.valve = ValveDataClass(ValveName="Test Valve") self.range_low = 2 self.range_high = 10 self.suggest_duration = lambda: liquid.suggest_duration(self.valve, self.range_low, self.range_high) From 6df048d3b50b3e91628d81f9a0c2bcfcb1e33e57 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:17:43 +1000 Subject: [PATCH 04/30] Ruff format --- bpod_rig/calibration/liquid/liquid.py | 87 +++++++++++++---------- bpod_rig/calibration/liquid/models.py | 11 ++- tests/bpod_rig/calibration/test_liquid.py | 20 +++--- 3 files changed, 68 insertions(+), 50 deletions(-) diff --git a/bpod_rig/calibration/liquid/liquid.py b/bpod_rig/calibration/liquid/liquid.py index a84988f..36bd590 100644 --- a/bpod_rig/calibration/liquid/liquid.py +++ b/bpod_rig/calibration/liquid/liquid.py @@ -1,4 +1,5 @@ """Liquid calibration data management and calibratino routines.""" + import datetime import logging from pydantic import Field @@ -10,18 +11,18 @@ def create_empty_valve_data_manager( - source: str = 'statemachine', - n_valves: int = 8, + source: str = "statemachine", + n_valves: int = 8, ) -> ValveDataManagerClass: """Create an empty valve manager with 8 valves.""" valvemanager = ValveDataManagerClass() - if source == 'statemachine': + if source == "statemachine": for index in range(n_valves): valvemanager.create_valve(f"Valve{index + 1}") - elif source == 'portarray': + elif source == "portarray": for index in range(n_valves): for port_array in range(4): - valvemanager.create_valve(f"PA{port_array + 1 }_{index + 1}") + valvemanager.create_valve(f"PA{port_array + 1}_{index + 1}") else: raise ValueError(f"Unknown source for valve names: {source}") @@ -49,6 +50,7 @@ def add_dummy_measurements(valvemanager: ValveDataManagerClass) -> None: valvemanager.metadata.modification_datetime = origin_date + def create_default_json() -> str: """Produce a default valve JSON string with 8 valves and dummy measurements.""" dummy = create_empty_valve_data_manager() @@ -61,6 +63,7 @@ def create_default_json() -> str: ) return jsontext + def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool: """Check if the user has updated the valve manager has been updated from the example. @@ -81,9 +84,9 @@ def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool def suggest_duration( - valveobject: ValveDataClass, - range_low: float, - range_high: float, + valveobject: ValveDataClass, + range_low: float, + range_high: float, ) -> float: """Suggest a duration for a given valve and range. @@ -118,7 +121,7 @@ def suggest_duration( # Use a linear estimate suggested_duration_ms = linearly_suggest_duration( amounts[0], durations[0], range_low, range_high - ) + ) else: # Use an estimate of the middle of the range based on range and our experience with # the CSHL configuration (nResearch pinch valves, silastic tubing, specs in Bpod literature) @@ -132,17 +135,18 @@ def calculate_largest_gap_midpoint(sorted_values: list[float]) -> float: for y in range(1, len(sorted_values)): distances[y] = abs(sorted_values[y] - sorted_values[y - 1]) max_distance_pos = np.argmax(distances) - midpoint_value = sorted_values[max_distance_pos - 1] + ( - sorted_values[max_distance_pos] - sorted_values[max_distance_pos - 1] - ) / 2 + midpoint_value = ( + sorted_values[max_distance_pos - 1] + + (sorted_values[max_distance_pos] - sorted_values[max_distance_pos - 1]) / 2 + ) return midpoint_value def linearly_suggest_duration( - amount: float | np.typing.ArrayLike, - duration: float | np.typing.ArrayLike, - range_low: float, - range_high: float, + amount: float | np.typing.ArrayLike, + duration: float | np.typing.ArrayLike, + range_low: float, + range_high: float, ) -> float: """Suggest a duration based on a single measurement and a range.""" ul_per_ms = float(amount / duration) @@ -159,7 +163,9 @@ def linearly_suggest_duration( return suggested_duration -def calculate_ranged_amounts(amounts: np.array, range_low: float, range_high: float) -> list[float]: +def calculate_ranged_amounts( + amounts: np.array, range_low: float, range_high: float +) -> list[float]: """Calculate a sorted list of amounts within a given range, including the range bounds.""" amounts_vector = amounts.tolist() if range_low not in amounts: @@ -171,9 +177,10 @@ def calculate_ranged_amounts(amounts: np.array, range_low: float, range_high: fl amounts_vector = sorted(amounts_vector) startpoint = amounts_vector.index(range_low) endpoint = amounts_vector.index(range_high) - amounts_vector = amounts_vector[startpoint: endpoint + 1] + amounts_vector = amounts_vector[startpoint : endpoint + 1] return amounts_vector + def check_COM(valvemanager: ValveDataManagerClass, com_port: str) -> str: """Check if the COM port in the valve manager matches the given COM port. @@ -190,12 +197,13 @@ def check_COM(valvemanager: ValveDataManagerClass, com_port: str) -> str: 'yes' if the COM ports match, 'no' if it's different, 'unknown' otherwise. """ valvecom = valvemanager.metadata.COM - if valvecom == "": # e.g. uncalibrated - return 'unknown' + if valvecom == "": # e.g. uncalibrated + return "unknown" if valvecom == com_port: - return 'yes' + return "yes" else: - return 'no' + return "no" + class PendingValve(ValveDataClass): pending_durations: list[float] = Field( @@ -225,13 +233,17 @@ class PendingMeasurementsManager: def __init__(self, valvemanager: ValveDataManagerClass): self._valvemanager = valvemanager - self.valves = [PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas] + self.valves = [ + PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas + ] def get_pending(self, valvename: str) -> list[float]: if valvename not in self._valvemanager.valve_names: raise KeyError(f"Valvename '{valvename}' not found in manager.") - return next(valve for valve in self.valves if valve.name == valvename).pending_durations + return next( + valve for valve in self.valves if valve.name == valvename + ).pending_durations def add_pending(self, valvename: str, duration: float) -> None: if duration in self.get_pending(valvename): @@ -247,16 +259,19 @@ def remove_pending(self, valvename: str, duration: float) -> None: self.get_pending(valvename).remove(duration) - def complete_measurement(self, valvename: str, duration: float, amount: float) -> None: + def complete_measurement( + self, valvename: str, duration: float, amount: float + ) -> None: self.remove_pending(valvename, duration) self._valvemanager.get_valve(valvename).add_mesurement(duration, amount) + def run_calibration( - pending_manager: PendingMeasurementsManager, - n_pulses: int, - pulse_interval: float = 0.2, - pulse_set_pause: float = 0.5, - verbose: bool = True, + pending_manager: PendingMeasurementsManager, + n_pulses: int, + pulse_interval: float = 0.2, + pulse_set_pause: float = 0.5, + verbose: bool = True, ) -> None: """Run a calibration sequence using Bpod state machine. @@ -270,14 +285,14 @@ def run_calibration( fsm = StateMachine() fsm.add_state( - name='Port1Light', + name="Port1Light", timer=1, - state_change_conditions={'BNC1_High': 'Port2Light'}, - output_actions={'PWM1': 255}, + state_change_conditions={"BNC1_High": "Port2Light"}, + output_actions={"PWM1": 255}, ) fsm.add_state( - name='Port2Light', + name="Port2Light", timer=1, - state_change_conditions={'Tup': '>exit'}, - output_actions={'PWM2': 255}, + state_change_conditions={"Tup": ">exit"}, + output_actions={"PWM2": 255}, ) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 4b96986..16a5bca 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class ValveDataClass(BaseModel): name: str = Field(alias="ValveName") lastdatemodified: datetime.datetime | str = Field( @@ -31,7 +32,9 @@ class ValveDataClass(BaseModel): ) model_config = ConfigDict( - serialize_by_alias=True, validate_by_name=True, validate_by_alias=True, + serialize_by_alias=True, + validate_by_name=True, + validate_by_alias=True, ) def get_valve_time(self, amount: float) -> float: @@ -126,7 +129,7 @@ def _update_coeffs(self) -> None: @field_serializer("lastdatemodified") def serialize_datetime(self, dt: datetime.datetime, _info): - if isinstance(dt, str): # if uncalibrated it is an empty string + if isinstance(dt, str): # if uncalibrated it is an empty string return dt return dt.isoformat(timespec="seconds") @@ -204,7 +207,9 @@ def to_json(self, machineid: str = None) -> str: # check if the COM is changing if (self.metadata.COM != "") & (self.metadata.COM != machineid): logger.warning( - "COM port of liquid calibration file is changing from %s to %s", self.metadata.COM, machineid + "COM port of liquid calibration file is changing from %s to %s", + self.metadata.COM, + machineid, ) self.metadata.COM = machineid self.metadata.modification_datetime = datetime.datetime.now() diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/test_liquid.py index b1bb2c3..b841451 100644 --- a/tests/bpod_rig/calibration/test_liquid.py +++ b/tests/bpod_rig/calibration/test_liquid.py @@ -4,17 +4,16 @@ from bpod_rig.calibration.liquid import liquid from bpod_rig.calibration.liquid.models import ValveDataClass, ValveDataManagerClass -class TestValveDataClass(unittest.TestCase): +class TestValveDataClass(unittest.TestCase): def setUp(self): self.valve = ValveDataClass(ValveName="Test Valve") def test_init_alias(self): # Test that alias works correctly # type checker may complain about this because of pydantic aliasing - self.assertEqual(ValveDataClass(name='attrname').name, 'attrname') - self.assertEqual( - ValveDataClass(ValveName='aliasname').name, 'aliasname') + self.assertEqual(ValveDataClass(name="attrname").name, "attrname") + self.assertEqual(ValveDataClass(ValveName="aliasname").name, "aliasname") def test_add_measurement(self): self.valve.add_measurement(10, 1.0) @@ -64,8 +63,8 @@ def test_no_coeffs_with_insufficient_data(self): self.valve.add_measurement(20, 2.0) self.assertIsNotNone(self.valve.coeffs) -class TestValveDataManagerClass(unittest.TestCase): +class TestValveDataManagerClass(unittest.TestCase): def setUp(self): self.manager = ValveDataManagerClass() self.manager.create_valve("Test Valve") @@ -95,6 +94,7 @@ def test_measurements(self): duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) self.assertAlmostEqual(duration, expected_duration, places=3) + class TestGen2CalibrationFile(unittest.TestCase): """Test loading a calibration file from the Bpod_Gen2 (MATLAB) repo.""" @@ -106,9 +106,7 @@ def setUpClass(cls): def test_load_calibration_file(self): loaded_valves = ValveDataManagerClass.model_validate_json(self.jsontext) - self.assertIsInstance(loaded_valves, - ValveDataManagerClass - ) + self.assertIsInstance(loaded_valves, ValveDataManagerClass) self.assertGreaterEqual(loaded_valves.n_valves, 1) self.assertIn("Valve1", loaded_valves.valve_names) @@ -120,7 +118,6 @@ def test_valve_modtime(self): class TestValveDataManagerJSON(unittest.TestCase): - def setUp(self): self.manager = liquid.create_empty_valve_data_manager() liquid.add_dummy_measurements(self.manager) @@ -139,12 +136,13 @@ def test_json_round_trip(self): class TestSuggestDuration(unittest.TestCase): - def setUp(self): self.valve = ValveDataClass(ValveName="Test Valve") self.range_low = 2 self.range_high = 10 - self.suggest_duration = lambda: liquid.suggest_duration(self.valve, self.range_low, self.range_high) + self.suggest_duration = lambda: liquid.suggest_duration( + self.valve, self.range_low, self.range_high + ) def test_no_measures(self): self.assertEqual(self.suggest_duration(), 48) From c42f5fee473bc04bb8480c55020ca1d6cbdf6557 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:18:46 +1000 Subject: [PATCH 05/30] Rename to utils --- bpod_rig/calibration/liquid/{liquid.py => utils.py} | 0 tests/bpod_rig/calibration/test_liquid.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename bpod_rig/calibration/liquid/{liquid.py => utils.py} (100%) diff --git a/bpod_rig/calibration/liquid/liquid.py b/bpod_rig/calibration/liquid/utils.py similarity index 100% rename from bpod_rig/calibration/liquid/liquid.py rename to bpod_rig/calibration/liquid/utils.py diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/test_liquid.py index b841451..d4e1af3 100644 --- a/tests/bpod_rig/calibration/test_liquid.py +++ b/tests/bpod_rig/calibration/test_liquid.py @@ -1,7 +1,7 @@ import unittest import urllib.request -from bpod_rig.calibration.liquid import liquid +from bpod_rig.calibration.liquid import utils from bpod_rig.calibration.liquid.models import ValveDataClass, ValveDataManagerClass From 330cf5cd0836f805abbe15fc87d794ce01690dd5 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:23:02 +1000 Subject: [PATCH 06/30] Refactor into pending_calibration.py --- .../calibration/liquid/pending_calibration.py | 94 +++++++++++++++++++ bpod_rig/calibration/liquid/utils.py | 91 ------------------ 2 files changed, 94 insertions(+), 91 deletions(-) create mode 100644 bpod_rig/calibration/liquid/pending_calibration.py diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py new file mode 100644 index 0000000..a006dbc --- /dev/null +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -0,0 +1,94 @@ +from calibration.liquid.models import ValveDataManagerClass, ValveDataClass + + +class PendingMeasurementsManager: + _valvemanager: ValveDataManagerClass + valves: list[PendingValve] + + def __init__(self, valvemanager: ValveDataManagerClass): + self._valvemanager = valvemanager + self.valves = [ + PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas + ] + + def get_pending(self, valvename: str) -> list[float]: + if valvename not in self._valvemanager.valve_names: + raise KeyError(f"Valvename '{valvename}' not found in manager.") + + return next( + valve for valve in self.valves if valve.name == valvename + ).pending_durations + + def add_pending(self, valvename: str, duration: float) -> None: + if duration in self.get_pending(valvename): + raise ValueError("Duration already pending.") + if duration in self._valvemanager.get_valve(valvename).durations: + raise ValueError("Duration already exists in valve calibration.") + + self.get_pending(valvename).append(duration) + + def remove_pending(self, valvename: str, duration: float) -> None: + if duration not in self.get_pending(valvename): + raise ValueError("Duration for completed measurement is not pending.") + + self.get_pending(valvename).remove(duration) + + def complete_measurement( + self, valvename: str, duration: float, amount: float + ) -> None: + self.remove_pending(valvename, duration) + self._valvemanager.get_valve(valvename).add_mesurement(duration, amount) + + +def run_calibration( + pending_manager: PendingMeasurementsManager, + n_pulses: int, + pulse_interval: float = 0.2, + pulse_set_pause: float = 0.5, + verbose: bool = True, +) -> None: + """Run a calibration sequence using Bpod state machine. + + Parameters + ---------- + + """ + + from bpod_core.fsm import StateMachine + + fsm = StateMachine() + + fsm.add_state( + name="Port1Light", + timer=1, + state_change_conditions={"BNC1_High": "Port2Light"}, + output_actions={"PWM1": 255}, + ) + fsm.add_state( + name="Port2Light", + timer=1, + state_change_conditions={"Tup": ">exit"}, + output_actions={"PWM2": 255}, + ) + + +class PendingValve(ValveDataClass): + pending_durations: list[float] = Field( + default=[], + description="List of durations that are pending calibration measurements.", + ) + + valve_controller: str = Field( + default="unknown", + description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", + ) + + def __init__(self, **data): + super().__init__(**data) + if self.name.startswith("PA"): + self.valve_controller = "PortArray" + else: + self.valve_controller = "StateMachine" + + def get_calibration_run_info(self) -> float: + pass diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index 36bd590..c46b2bd 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -205,94 +205,3 @@ def check_COM(valvemanager: ValveDataManagerClass, com_port: str) -> str: return "no" -class PendingValve(ValveDataClass): - pending_durations: list[float] = Field( - default=[], - description="List of durations that are pending calibration measurements.", - ) - - valve_controller: str = Field( - default="unknown", - description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", - ) - - def __init__(self, **data): - super().__init__(**data) - if self.name.startswith("PA"): - self.valve_controller = "PortArray" - else: - self.valve_controller = "StateMachine" - - def get_calibration_run_info(self) -> float: - pass - - -class PendingMeasurementsManager: - _valvemanager: ValveDataManagerClass - valves: list[PendingValve] - - def __init__(self, valvemanager: ValveDataManagerClass): - self._valvemanager = valvemanager - self.valves = [ - PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas - ] - - def get_pending(self, valvename: str) -> list[float]: - if valvename not in self._valvemanager.valve_names: - raise KeyError(f"Valvename '{valvename}' not found in manager.") - - return next( - valve for valve in self.valves if valve.name == valvename - ).pending_durations - - def add_pending(self, valvename: str, duration: float) -> None: - if duration in self.get_pending(valvename): - raise ValueError("Duration already pending.") - if duration in self._valvemanager.get_valve(valvename).durations: - raise ValueError("Duration already exists in valve calibration.") - - self.get_pending(valvename).append(duration) - - def remove_pending(self, valvename: str, duration: float) -> None: - if duration not in self.get_pending(valvename): - raise ValueError("Duration for completed measurement is not pending.") - - self.get_pending(valvename).remove(duration) - - def complete_measurement( - self, valvename: str, duration: float, amount: float - ) -> None: - self.remove_pending(valvename, duration) - self._valvemanager.get_valve(valvename).add_mesurement(duration, amount) - - -def run_calibration( - pending_manager: PendingMeasurementsManager, - n_pulses: int, - pulse_interval: float = 0.2, - pulse_set_pause: float = 0.5, - verbose: bool = True, -) -> None: - """Run a calibration sequence using Bpod state machine. - - Parameters - ---------- - - """ - - from bpod_core.fsm import StateMachine - - fsm = StateMachine() - - fsm.add_state( - name="Port1Light", - timer=1, - state_change_conditions={"BNC1_High": "Port2Light"}, - output_actions={"PWM1": 255}, - ) - fsm.add_state( - name="Port2Light", - timer=1, - state_change_conditions={"Tup": ">exit"}, - output_actions={"PWM2": 255}, - ) From 843f4840f47201d7c5da0d2c47ea89ceaeed7541 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:23:11 +1000 Subject: [PATCH 07/30] Update liquid to utils --- tests/bpod_rig/calibration/test_liquid.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/test_liquid.py index d4e1af3..42ebe28 100644 --- a/tests/bpod_rig/calibration/test_liquid.py +++ b/tests/bpod_rig/calibration/test_liquid.py @@ -68,8 +68,8 @@ class TestValveDataManagerClass(unittest.TestCase): def setUp(self): self.manager = ValveDataManagerClass() self.manager.create_valve("Test Valve") - dummy = liquid.create_empty_valve_data_manager() - liquid.add_dummy_measurements(dummy) + dummy = utils.create_empty_valve_data_manager() + utils.add_dummy_measurements(dummy) self.dummy = dummy def test_create_valve(self): @@ -112,15 +112,15 @@ def test_load_calibration_file(self): def test_valve_modtime(self): valvemanager = ValveDataManagerClass.model_validate_json(self.jsontext) - self.assertFalse(liquid.check_valvemanager_user_updated(valvemanager)) + self.assertFalse(utils.check_valvemanager_user_updated(valvemanager)) newvalve = ValveDataManagerClass.model_validate_json(valvemanager.to_json()) - self.assertTrue(liquid.check_valvemanager_user_updated(newvalve)) + self.assertTrue(utils.check_valvemanager_user_updated(newvalve)) class TestValveDataManagerJSON(unittest.TestCase): def setUp(self): - self.manager = liquid.create_empty_valve_data_manager() - liquid.add_dummy_measurements(self.manager) + self.manager = utils.create_empty_valve_data_manager() + utils.add_dummy_measurements(self.manager) self.json_str = self.manager.to_json() self.loaded_manager = ValveDataManagerClass.model_validate_json(self.json_str) @@ -140,7 +140,7 @@ def setUp(self): self.valve = ValveDataClass(ValveName="Test Valve") self.range_low = 2 self.range_high = 10 - self.suggest_duration = lambda: liquid.suggest_duration( + self.suggest_duration = lambda: utils.suggest_duration( self.valve, self.range_low, self.range_high ) From 59f9cdb5f415b8cd7eb4c87a7df7e2212f904542 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:25:19 +1000 Subject: [PATCH 08/30] Refactor models.py tests --- tests/bpod_rig/calibration/liquid/__init__.py | 0 .../calibration/liquid/test_liquid.py | 26 +++++++++++++++++++ .../{test_liquid.py => liquid/test_models.py} | 22 ---------------- 3 files changed, 26 insertions(+), 22 deletions(-) create mode 100644 tests/bpod_rig/calibration/liquid/__init__.py create mode 100644 tests/bpod_rig/calibration/liquid/test_liquid.py rename tests/bpod_rig/calibration/{test_liquid.py => liquid/test_models.py} (88%) diff --git a/tests/bpod_rig/calibration/liquid/__init__.py b/tests/bpod_rig/calibration/liquid/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/bpod_rig/calibration/liquid/test_liquid.py b/tests/bpod_rig/calibration/liquid/test_liquid.py new file mode 100644 index 0000000..a5051dc --- /dev/null +++ b/tests/bpod_rig/calibration/liquid/test_liquid.py @@ -0,0 +1,26 @@ +import unittest + +from bpod_rig.calibration.liquid import utils +from bpod_rig.calibration.liquid.models import ValveDataClass + + +class TestSuggestDuration(unittest.TestCase): + def setUp(self): + self.valve = ValveDataClass(ValveName="Test Valve") + self.range_low = 2 + self.range_high = 10 + self.suggest_duration = lambda: utils.suggest_duration( + self.valve, self.range_low, self.range_high + ) + + def test_no_measures(self): + self.assertEqual(self.suggest_duration(), 48) + + def test_one_measure(self): + self.valve.add_measurement(57, 7.5) + self.assertEqual(self.suggest_duration(), 36) + + def test_two_measures(self): + self.valve.add_measurement(22, 2) + self.valve.add_measurement(66, 9.5) + self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) diff --git a/tests/bpod_rig/calibration/test_liquid.py b/tests/bpod_rig/calibration/liquid/test_models.py similarity index 88% rename from tests/bpod_rig/calibration/test_liquid.py rename to tests/bpod_rig/calibration/liquid/test_models.py index 42ebe28..20e817f 100644 --- a/tests/bpod_rig/calibration/test_liquid.py +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -133,25 +133,3 @@ def test_json_round_trip(self): self.assertEqual(original_valve.amounts, loaded_valve.amounts) self.assertEqual(original_valve.durations, loaded_valve.durations) self.assertEqual(original_valve.coeffs, loaded_valve.coeffs) - - -class TestSuggestDuration(unittest.TestCase): - def setUp(self): - self.valve = ValveDataClass(ValveName="Test Valve") - self.range_low = 2 - self.range_high = 10 - self.suggest_duration = lambda: utils.suggest_duration( - self.valve, self.range_low, self.range_high - ) - - def test_no_measures(self): - self.assertEqual(self.suggest_duration(), 48) - - def test_one_measure(self): - self.valve.add_measurement(57, 7.5) - self.assertEqual(self.suggest_duration(), 36) - - def test_two_measures(self): - self.valve.add_measurement(22, 2) - self.valve.add_measurement(66, 9.5) - self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) From da711251068275435572b37ecc619a6793726511 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:25:39 +1000 Subject: [PATCH 09/30] Rename utils.py --- .../bpod_rig/calibration/liquid/{test_liquid.py => test_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/bpod_rig/calibration/liquid/{test_liquid.py => test_utils.py} (100%) diff --git a/tests/bpod_rig/calibration/liquid/test_liquid.py b/tests/bpod_rig/calibration/liquid/test_utils.py similarity index 100% rename from tests/bpod_rig/calibration/liquid/test_liquid.py rename to tests/bpod_rig/calibration/liquid/test_utils.py From c0423be03a77322f809984b0e2d0f8ead31d1574 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 5 Sep 2025 23:35:34 +1000 Subject: [PATCH 10/30] Create example file --- .../calibration/liquid/populate_examples.py | 14 +++ .../calibration/LiquidCalibration.json | 96 +++++++++++++++++++ 2 files changed, 110 insertions(+) create mode 100644 bpod_rig/calibration/liquid/populate_examples.py create mode 100644 bpod_rig/examples/calibration/LiquidCalibration.json diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py new file mode 100644 index 0000000..dccffc5 --- /dev/null +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -0,0 +1,14 @@ +"""Create the example liquid calibration JSON with 8 vales and dummy measurements.""" +from pathlib import Path + +from bpod_rig.examples import calibration as example_folder +from .utils import create_default_json + +def main(): + """Create the example liquid calibration JSON with 8 vales and dummy measurements.""" + example_json = create_default_json() + example_path = Path(example_folder.__path__[0]) / "LiquidCalibration.json" + example_path.write_text(example_json) + +if __name__ == "__main__": + main() diff --git a/bpod_rig/examples/calibration/LiquidCalibration.json b/bpod_rig/examples/calibration/LiquidCalibration.json new file mode 100644 index 0000000..4454d0c --- /dev/null +++ b/bpod_rig/examples/calibration/LiquidCalibration.json @@ -0,0 +1,96 @@ +{ + "metadata": { + "modification_datetime": "2000-01-01T00:00:00", + "COM": "" + }, + "ValveDatas": [ + { + "ValveName": "Valve1", + "LastDateModified": "2000-01-01T00:00:00", + "Coeffs": [ + -0.30576102280847484, + 9.320110469495358, + 4.82071882423374 + ], + "Durations": [ + 22, + 66, + 44, + 57, + 34 + ], + "Amounts": [ + 2, + 9.5, + 5, + 7.5, + 3.5 + ] + }, + { + "ValveName": "Valve2", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve3", + "LastDateModified": "2000-01-01T00:00:00", + "Coeffs": [ + 0.047977795400477294, + 5.724028548770809, + 13.60021808088822 + ], + "Durations": [ + 22, + 46, + 59, + 35, + 66 + ], + "Amounts": [ + 1.5, + 5.5, + 7.5, + 3.5, + 8.5 + ] + }, + { + "ValveName": "Valve4", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve5", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve6", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve7", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + }, + { + "ValveName": "Valve8", + "LastDateModified": "", + "Coeffs": [], + "Durations": [], + "Amounts": [] + } + ] +} \ No newline at end of file From 042e8a662fc8527615470636982aae3da24e41f3 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 7 Sep 2025 22:15:41 +1000 Subject: [PATCH 11/30] Refactor example population --- .../calibration/liquid/populate_examples.py | 43 ++++++++++++++++++- bpod_rig/calibration/liquid/utils.py | 36 ---------------- .../calibration/liquid/test_models.py | 6 +-- 3 files changed, 44 insertions(+), 41 deletions(-) diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index dccffc5..9309e85 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -1,14 +1,53 @@ """Create the example liquid calibration JSON with 8 vales and dummy measurements.""" +import datetime from pathlib import Path from bpod_rig.examples import calibration as example_folder -from .utils import create_default_json +from bpod_rig.calibration.liquid.models import ValveDataManagerClass +from bpod_rig.calibration.liquid.utils import create_empty_valve_data_manager + + +def add_dummy_measurements(valvemanager: ValveDataManagerClass) -> None: + """Add dummy measurements to the valves in manager.""" + origin_date = datetime.datetime(2000, 1, 1) + valve1 = valvemanager.get_valve("Valve1") + valve1.add_measurement(22, 2) + valve1.add_measurement(66, 9.5) + valve1.add_measurement(44, 5) + valve1.add_measurement(57, 7.5) + valve1.add_measurement(34, 3.5) + valve1.lastdatemodified = origin_date + + valve3 = valvemanager.get_valve("Valve3") + valve3.add_measurement(22, 1.5) + valve3.add_measurement(46, 5.5) + valve3.add_measurement(59, 7.5) + valve3.add_measurement(35, 3.5) + valve3.add_measurement(66, 8.5) + valve3.lastdatemodified = origin_date + + valvemanager.metadata.modification_datetime = origin_date + + +def create_default_json() -> str: + """Produce a default valve JSON string with 8 valves and dummy measurements.""" + dummy = create_empty_valve_data_manager() + add_dummy_measurements(dummy) + jsontext = dummy.to_json(machineid=None) + # replace modification datetime with a fixed date for consistency + jsontext = jsontext.replace( + dummy.metadata.modification_datetime.isoformat(timespec="seconds"), + "2000-01-01T00:00:00", + ) + return jsontext + def main(): - """Create the example liquid calibration JSON with 8 vales and dummy measurements.""" + """Create the example liquid calibration JSON with 8 valves and dummy measurements.""" example_json = create_default_json() example_path = Path(example_folder.__path__[0]) / "LiquidCalibration.json" example_path.write_text(example_json) + if __name__ == "__main__": main() diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index c46b2bd..1844bdc 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -2,7 +2,6 @@ import datetime import logging -from pydantic import Field import numpy as np from .models import ValveDataClass, ValveDataManagerClass @@ -29,41 +28,6 @@ def create_empty_valve_data_manager( return valvemanager -def add_dummy_measurements(valvemanager: ValveDataManagerClass) -> None: - """Add dummy measurements to the valves in manager.""" - origin_date = datetime.datetime(2000, 1, 1) - valve1 = valvemanager.get_valve("Valve1") - valve1.add_measurement(22, 2) - valve1.add_measurement(66, 9.5) - valve1.add_measurement(44, 5) - valve1.add_measurement(57, 7.5) - valve1.add_measurement(34, 3.5) - valve1.lastdatemodified = origin_date - - valve3 = valvemanager.get_valve("Valve3") - valve3.add_measurement(22, 1.5) - valve3.add_measurement(46, 5.5) - valve3.add_measurement(59, 7.5) - valve3.add_measurement(35, 3.5) - valve3.add_measurement(66, 8.5) - valve3.lastdatemodified = origin_date - - valvemanager.metadata.modification_datetime = origin_date - - -def create_default_json() -> str: - """Produce a default valve JSON string with 8 valves and dummy measurements.""" - dummy = create_empty_valve_data_manager() - add_dummy_measurements(dummy) - jsontext = dummy.to_json(machineid=None) - # replace modification datetime with a fixed date for consistency - jsontext = jsontext.replace( - dummy.metadata.modification_datetime.isoformat(timespec="seconds"), - "2000-01-01T00:00:00", - ) - return jsontext - - def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool: """Check if the user has updated the valve manager has been updated from the example. diff --git a/tests/bpod_rig/calibration/liquid/test_models.py b/tests/bpod_rig/calibration/liquid/test_models.py index 20e817f..3d04d10 100644 --- a/tests/bpod_rig/calibration/liquid/test_models.py +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -1,7 +1,7 @@ import unittest import urllib.request -from bpod_rig.calibration.liquid import utils +from bpod_rig.calibration.liquid import utils, populate_examples from bpod_rig.calibration.liquid.models import ValveDataClass, ValveDataManagerClass @@ -69,7 +69,7 @@ def setUp(self): self.manager = ValveDataManagerClass() self.manager.create_valve("Test Valve") dummy = utils.create_empty_valve_data_manager() - utils.add_dummy_measurements(dummy) + populate_examples.add_dummy_measurements(dummy) self.dummy = dummy def test_create_valve(self): @@ -120,7 +120,7 @@ def test_valve_modtime(self): class TestValveDataManagerJSON(unittest.TestCase): def setUp(self): self.manager = utils.create_empty_valve_data_manager() - utils.add_dummy_measurements(self.manager) + populate_examples.add_dummy_measurements(self.manager) self.json_str = self.manager.to_json() self.loaded_manager = ValveDataManagerClass.model_validate_json(self.json_str) From 2ea8718ca85d16cdcb66a103213f529550ee2d0e Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 7 Sep 2025 23:07:02 +1000 Subject: [PATCH 12/30] Improve pending valve logic - Add functionality to PendingValve - Create add_valve_states - Add tests --- .../calibration/liquid/pending_calibration.py | 209 ++++++++++++++---- .../liquid/test_pending_calibration.py | 67 ++++++ 2 files changed, 237 insertions(+), 39 deletions(-) create mode 100644 tests/bpod_rig/calibration/liquid/test_pending_calibration.py diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index a006dbc..eed2e59 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -1,4 +1,61 @@ -from calibration.liquid.models import ValveDataManagerClass, ValveDataClass +from pydantic import Field +from bpod_core.fsm import StateMachine + +from bpod_rig.calibration.liquid.models import ValveDataManagerClass, ValveDataClass + + +class PendingValve(ValveDataClass): + pending_durations: list[float] = Field( + default=[], + description="List of durations that are pending calibration measurements.", + ) + + valve_controller: str = Field( + default="unknown", + description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", + ) + + def __init__(self, **data): + super().__init__(**data) + if self.name.startswith("PA"): + self.valve_controller = "PortArray" + elif self.name.startswith("Valve"): + self.valve_controller = "StateMachine" + else: + raise ValueError("Valve name not recognised for controller assignment.") + + def action(self, actiontype: str) -> dict[str, str | list[str | int]]: + assert actiontype in ["open", "close"] + if self.valve_controller == "StateMachine": + valve_id = int(self.name[5:]) + if actiontype == "open": + # matlab code + # ValvePhysicalAddress = 2.^(0:7); + # ValveAddress = ValvePhysicalAddress(valveID) + valve_address = 2 ** (valve_id - 1) + return {'ValveState': valve_address} + elif actiontype == "close": + raise ValueError("StateMachine valves do not have a close action.") + else: + raise ValueError("Action type not recognised.") + elif self.valve_controller == "PortArray": + target_module = self.name[0:3] + portnumber = self.name[4] + if actiontype == "open": + # TODO: serial message byte parsing hasn't been checked + action_serial = ['V', portnumber, 1] + return {target_module: action_serial} + elif actiontype == "close": + action_serial = ['V', portnumber, 0] + return {target_module: action_serial} + else: + raise ValueError("Action type not recognised") + else: + raise ValueError("Valve controller not recognised.") + + @property + def is_pending(self) -> bool: + return len(self.pending_durations) > 0 class PendingMeasurementsManager: @@ -36,8 +93,84 @@ def remove_pending(self, valvename: str, duration: float) -> None: def complete_measurement( self, valvename: str, duration: float, amount: float ) -> None: + assert duration in self.get_pending(valvename) self.remove_pending(valvename, duration) - self._valvemanager.get_valve(valvename).add_mesurement(duration, amount) + self._valvemanager.get_valve(valvename).add_measurement(duration, amount) + + +def add_valve_states( + fsm: StateMachine, pending_valve: PendingValve, next_valve: PendingValve | float, + delay_duration: float +) -> tuple[str, float]: + """Add the pending valve to a state machine. + + Parameters + ---------- + fsm : StateMachine + Existing state machine to add states for the valve to. + pending_valve : PendingValve + Valve to create states to use for calibration + next_valve : PendingValve | float + The next valve object in the cycle, or the float for the pause duration (s) between pulse sets. + delay_duration : float + + Returns + ------- + tuple[str, float] + Valve name and the duration (ms) of the pulse added to the state machine. + """ + assert len(pending_valve.durations) > 0 + duration = pending_valve.pending_durations[0] + last_valve = isinstance(next_valve, float) + if last_valve: + pulse_set_pause_duration = next_valve + next_state = 'PulseSetPause' + else: + assert (isinstance(next_valve, PendingValve)) + pulse_set_pause_duration = None + next_state = f"Pulse{next_valve.name}" + + # Add valve openings to the state machine + if pending_valve.valve_controller == 'StateMachine': + # Valve is wired i.e. open while state is active + fsm.add_state( + name=f"Pulse{pending_valve.name}", + timer=duration / 1000, + state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, + output_actions={'Valve1'} + ) + elif pending_valve.valve_controller == 'PortArray': + # Valve is controlled by serial messaging i.e. discrete open/close messages + fsm.add_state( + name=f"Pulse{pending_valve.name}", + timer=duration / 1000, + state_change_conditions={'Tup': f"EndPulse{pending_valve.name}"}, + output_actins={} + ) + fsm.add_state( + name=f"EndPulse{pending_valve.name}", + timer=0, + state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, + output_actions={} + ) + else: + raise ValueError("Valve controller not recognised.") + + fsm.add_state( + name=f"Delay{pending_valve.name}", + timer=delay_duration, + state_change_conditions={'Tup', next_state}, + output_actions={} + ) + + if last_valve: + fsm.add_state( + name="PulseSetPause", + timer=pulse_set_pause_duration, + state_change_conditions={'Tup': 'exit'}, + output_actions={} + ) + return pending_valve.name, duration def run_calibration( @@ -46,49 +179,47 @@ def run_calibration( pulse_interval: float = 0.2, pulse_set_pause: float = 0.5, verbose: bool = True, -) -> None: +) -> dict[str, float]: """Run a calibration sequence using Bpod state machine. Parameters ---------- - + pending_manager : PendingMeasurementsManager + Object storing valve data + n_pulses : int + Number of pulses to deliver for all valves. + pulse_interval : float, optional + Time (s) between each pulse of each valve. + pulse_set_pause : float, optional + Time (s) between a round (after all valves pulsed) + verbose : bool, optional + Print progress. + + Returns + ------- + dict[str, float] + Dictionary of valve names and the duration (ms) of the pulse added to the state machine. """ - from bpod_core.fsm import StateMachine - + # Build the state machine fsm = StateMachine() - - fsm.add_state( - name="Port1Light", - timer=1, - state_change_conditions={"BNC1_High": "Port2Light"}, - output_actions={"PWM1": 255}, - ) - fsm.add_state( - name="Port2Light", - timer=1, - state_change_conditions={"Tup": ">exit"}, - output_actions={"PWM2": 255}, - ) - - -class PendingValve(ValveDataClass): - pending_durations: list[float] = Field( - default=[], - description="List of durations that are pending calibration measurements.", - ) - - valve_controller: str = Field( - default="unknown", - description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", - ) - - def __init__(self, **data): - super().__init__(**data) - if self.name.startswith("PA"): - self.valve_controller = "PortArray" + pending_valves = [valve for valve in pending_manager.valves if valve.is_pending] + n_pending = len(pending_valves) + if n_pending == 0: + raise AssertionError("No pending valves found.") + test_set = {} + for x, pending_valve in enumerate(pending_valves): + if x < n_pending - 1: + next_valve = pending_valves[x + 1] else: - self.valve_controller = "StateMachine" - - def get_calibration_run_info(self) -> float: + next_valve = pulse_set_pause + name, duration = add_valve_states( + fsm, pending_valve, next_valve, pulse_interval + ) + test_set[name] = duration + + # Run the state machine + for trial in range(n_pulses): pass + + return test_set diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py new file mode 100644 index 0000000..06f71f0 --- /dev/null +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -0,0 +1,67 @@ +import unittest + +from bpod_rig.calibration.liquid import pending_calibration +from bpod_rig.calibration.liquid import utils, populate_examples + + +class test_PendingValve(unittest.TestCase): + def setUp(self): + self.pending_valve = pending_calibration.PendingValve(ValveName="Valve3") + + def test_valve_controller_assignment(self): + valve1 = pending_calibration.PendingValve(ValveName="PA1") + valve2 = pending_calibration.PendingValve(ValveName="Valve1") + self.assertEqual(valve1.valve_controller, "PortArray") + self.assertEqual(valve2.valve_controller, "StateMachine") + + def test_unrecognied_controller_assignment(self): + with self.assertRaises(ValueError): + valve1 = pending_calibration.PendingValve(ValveName="HeartValve") + + def test_ispending(self): + valve = pending_calibration.PendingValve(ValveName="Valve1") + self.assertFalse(valve.is_pending) + valve = pending_calibration.PendingValve(ValveName="Valve1") + valve.pending_durations.append(20) + self.assertTrue(valve.is_pending) + + +class test_PendingMeasurementsManager(unittest.TestCase): + def setUp(self): + self.valvemanager = utils.create_empty_valve_data_manager() + populate_examples.add_dummy_measurements(self.valvemanager) + self.manager = pending_calibration.PendingMeasurementsManager(self.valvemanager) + + def test_get_pending_initial(self): + self.assertEqual(self.manager.get_pending("Valve1"), []) + self.assertEqual(self.manager.get_pending("Valve2"), []) + + def test_get_pending_invalid_valve(self): + with self.assertRaises(KeyError): + self.manager.get_pending("InvalidValve") + + def test_add_pending(self): + self.manager.add_pending("Valve1", 10.0) + self.assertIn(10.0, self.manager.get_pending("Valve1")) + + def test_add_pending_duplicate(self): + self.manager.add_pending("Valve1", 10.0) + with self.assertRaises(ValueError): + self.manager.add_pending("Valve1", 10.0) + + def test_remove_pending(self): + self.manager.add_pending("Valve1", 10.0) + self.manager.remove_pending("Valve1", 10.0) + self.assertNotIn(10.0, self.manager.get_pending("Valve1")) + + def test_remove_pending_nonexistent(self): + with self.assertRaises(ValueError): + self.manager.remove_pending("Valve1", 10.0) + + def test_complete_measurement(self): + self.manager.add_pending("Valve2", 10.0) + self.manager.complete_measurement("Valve2", 10.0, 1.5) + valve = self.valvemanager.get_valve("Valve2") + self.assertIn(10.0, valve.durations) + self.assertIn(1.5, valve.amounts) + self.assertNotIn(10.0, self.manager.get_pending("Valve2")) From 2e34367c77ef5a6e2d438e38b4b0eb36117e51bd Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:06:24 +1000 Subject: [PATCH 13/30] Refactor pending manager - Functional state machine factory (but not hardware tested) - Easier parameter modification for state machine construction - Cleaner output action construction --- .../calibration/liquid/pending_calibration.py | 143 ++++++++++++------ .../liquid/test_pending_calibration.py | 18 ++- 2 files changed, 117 insertions(+), 44 deletions(-) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index eed2e59..b4a98e0 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -59,8 +59,15 @@ def is_pending(self) -> bool: class PendingMeasurementsManager: + """Manage valves and their pending values. + + Object would handle all + """ _valvemanager: ValveDataManagerClass valves: list[PendingValve] + n_pulses: int = 100 + pulse_interval: float = 0.2 + pulse_set_pause: float = 0.5 def __init__(self, valvemanager: ValveDataManagerClass): self._valvemanager = valvemanager @@ -68,15 +75,21 @@ def __init__(self, valvemanager: ValveDataManagerClass): PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas ] - def get_pending(self, valvename: str) -> list[float]: + + def get_valve(self, valvename: str) -> PendingValve: + """Retrieve object for a pending valve.""" if valvename not in self._valvemanager.valve_names: raise KeyError(f"Valvename '{valvename}' not found in manager.") + return next(valve for valve in self.valves if valve.name == valvename) + + + def get_pending(self, valvename: str) -> list[float]: + """Return pending durations from a given valve.""" + return self.get_valve(valvename).pending_durations - return next( - valve for valve in self.valves if valve.name == valvename - ).pending_durations def add_pending(self, valvename: str, duration: float) -> None: + """Add a pending duration to the valve's list.""" if duration in self.get_pending(valvename): raise ValueError("Duration already pending.") if duration in self._valvemanager.get_valve(valvename).durations: @@ -84,24 +97,94 @@ def add_pending(self, valvename: str, duration: float) -> None: self.get_pending(valvename).append(duration) + def remove_pending(self, valvename: str, duration: float) -> None: + """Remove a pending duration from the valve's pending record.""" if duration not in self.get_pending(valvename): raise ValueError("Duration for completed measurement is not pending.") self.get_pending(valvename).remove(duration) + def complete_measurement( self, valvename: str, duration: float, amount: float ) -> None: + """Record an amount dispensed for a duration that was measured. + + Remove duration from pending list and adds completed measurement to the valve + manager. + + Parameters + ---------- + valvename : str + Name of the valve that has been tested and weighed. + duration : float + Duration (ms) of valve opening + amount : float + Quantity (mg or mL) released from the duration. + """ assert duration in self.get_pending(valvename) self.remove_pending(valvename, duration) self._valvemanager.get_valve(valvename).add_measurement(duration, amount) + def build_test_set(self) -> dict[str, float]: + """Create dict of valves and the duration to test. + + Will take the first (oldest) duration in the valve's pending values. + """ + pending_valves = [valve for valve in self.valves if valve.is_pending] + test_set = {} + for valve in pending_valves: + test_set[valve.name] = valve.pending_durations[0] + return test_set + + def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine, dict[str, float]]: + """Build the state machine to test all pending valves. + + The state machine is constructed using object parameters: + - n_pulses: number of pulses to deliver + - pulse_interval: Duration (ms) + - pulse_set_pause: Duration (ms) + + The amount of liquid dispensed from the state machine + For example, 100 pulses totaling 0.200g (200mg) is 0.200g / 100 = 2mg = 2mL + + Parameters + ---------- + test_set : dict | None, optional + Names of valves and durations to build state machine for. If not provided, test all valves that have pending durations using their oldest duration. + Returns + ------- + tuple[StateMachine, dict[str, float]] + - The finite state machine. + - dict of which valves and what duration the state machine was built for. + """ + fsm = StateMachine() + if test_set is None: + test_set = self.build_test_set() + valve_names = list(test_set.keys()) + pending_valves = [valve for valve in self.valves if valve.is_pending] + n_pending = len(pending_valves) + if n_pending == 0: + raise AssertionError("No pending valves found.") + for x, valve_name in enumerate(valve_names): + pending_valve = self.get_valve(valve_name) + duration = test_set[valve_name] + if x < n_pending - 1: + next_valve = self.get_valve(valve_names[x + 1]) + else: + next_valve = self.pulse_set_pause + add_valve_states( + fsm, pending_valve, duration, next_valve, self.pulse_interval + ) + return fsm, test_set + def add_valve_states( - fsm: StateMachine, pending_valve: PendingValve, next_valve: PendingValve | float, + fsm: StateMachine, pending_valve: PendingValve, duration: float, + next_valve: PendingValve | float, delay_duration: float -) -> tuple[str, float]: +): """Add the pending valve to a state machine. Parameters @@ -110,17 +193,12 @@ def add_valve_states( Existing state machine to add states for the valve to. pending_valve : PendingValve Valve to create states to use for calibration + duration : float + Duration (ms) of valve open. next_valve : PendingValve | float The next valve object in the cycle, or the float for the pause duration (s) between pulse sets. delay_duration : float - - Returns - ------- - tuple[str, float] - Valve name and the duration (ms) of the pulse added to the state machine. """ - assert len(pending_valve.durations) > 0 - duration = pending_valve.pending_durations[0] last_valve = isinstance(next_valve, float) if last_valve: pulse_set_pause_duration = next_valve @@ -137,7 +215,7 @@ def add_valve_states( name=f"Pulse{pending_valve.name}", timer=duration / 1000, state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, - output_actions={'Valve1'} + output_actions=pending_valve.action('open') ) elif pending_valve.valve_controller == 'PortArray': # Valve is controlled by serial messaging i.e. discrete open/close messages @@ -145,13 +223,13 @@ def add_valve_states( name=f"Pulse{pending_valve.name}", timer=duration / 1000, state_change_conditions={'Tup': f"EndPulse{pending_valve.name}"}, - output_actins={} + output_actins=pending_valve.action('open') ) fsm.add_state( name=f"EndPulse{pending_valve.name}", timer=0, state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, - output_actions={} + output_actions=pending_valve.action('close') ) else: raise ValueError("Valve controller not recognised.") @@ -159,8 +237,8 @@ def add_valve_states( fsm.add_state( name=f"Delay{pending_valve.name}", timer=delay_duration, - state_change_conditions={'Tup', next_state}, - output_actions={} + state_change_conditions={'Tup': next_state}, + # output_actions={} ) if last_valve: @@ -168,18 +246,13 @@ def add_valve_states( name="PulseSetPause", timer=pulse_set_pause_duration, state_change_conditions={'Tup': 'exit'}, - output_actions={} + # output_actions={} ) - return pending_valve.name, duration def run_calibration( pending_manager: PendingMeasurementsManager, - n_pulses: int, - pulse_interval: float = 0.2, - pulse_set_pause: float = 0.5, - verbose: bool = True, -) -> dict[str, float]: +): """Run a calibration sequence using Bpod state machine. Parameters @@ -202,24 +275,8 @@ def run_calibration( """ # Build the state machine - fsm = StateMachine() - pending_valves = [valve for valve in pending_manager.valves if valve.is_pending] - n_pending = len(pending_valves) - if n_pending == 0: - raise AssertionError("No pending valves found.") - test_set = {} - for x, pending_valve in enumerate(pending_valves): - if x < n_pending - 1: - next_valve = pending_valves[x + 1] - else: - next_valve = pulse_set_pause - name, duration = add_valve_states( - fsm, pending_valve, next_valve, pulse_interval - ) - test_set[name] = duration + fsm, test_set = pending_manager.build_statemachine() # Run the state machine - for trial in range(n_pulses): + for trial in range(pending_manager.n_pulses): pass - - return test_set diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index 06f71f0..edb61ca 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -1,5 +1,7 @@ import unittest +from bpod_core.fsm import StateMachine + from bpod_rig.calibration.liquid import pending_calibration from bpod_rig.calibration.liquid import utils, populate_examples @@ -41,8 +43,10 @@ def test_get_pending_invalid_valve(self): self.manager.get_pending("InvalidValve") def test_add_pending(self): - self.manager.add_pending("Valve1", 10.0) + self.manager.add_pending("Valve1", 10.0) # valve with values self.assertIn(10.0, self.manager.get_pending("Valve1")) + self.manager.add_pending("Valve2", 10.0) # valve without values + self.assertIn(10.0, self.manager.get_pending("Valve2")) def test_add_pending_duplicate(self): self.manager.add_pending("Valve1", 10.0) @@ -65,3 +69,15 @@ def test_complete_measurement(self): self.assertIn(10.0, valve.durations) self.assertIn(1.5, valve.amounts) self.assertNotIn(10.0, self.manager.get_pending("Valve2")) + + def test_build_testset(self): + self.manager.add_pending("Valve1", 10) + self.manager.add_pending("Valve1", 5) + self.manager.add_pending("Valve2", 20) + self.assertEqual(self.manager.build_test_set(), {"Valve1": 10, "Valve2": 20}) + + def test_build_statemachine(self): + self.manager.add_pending("Valve1", 10) + self.manager.add_pending("Valve2", 15) + statemachine, test_set = self.manager.build_statemachine() + self.assertIsInstance(statemachine, StateMachine) From ca489211b44f2d90d727b85062ddeb4c644707aa Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:16:27 +1000 Subject: [PATCH 14/30] Rename classes Remove redundant Class from end of name. --- bpod_rig/calibration/liquid/models.py | 19 ++++++++++------ .../calibration/liquid/pending_calibration.py | 8 +++---- .../calibration/liquid/populate_examples.py | 4 ++-- bpod_rig/calibration/liquid/utils.py | 18 +++++++-------- .../calibration/liquid/test_models.py | 22 +++++++++---------- .../bpod_rig/calibration/liquid/test_utils.py | 4 ++-- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 16a5bca..53b4584 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -7,8 +7,11 @@ logger = logging.getLogger(__name__) -class ValveDataClass(BaseModel): - name: str = Field(alias="ValveName") +class ValveData(BaseModel): + name: str = Field( + alias="ValveName", + description="Name of valve (1-indexed)" + ) lastdatemodified: datetime.datetime | str = Field( serialization_alias="LastDateModified", validation_alias="LastDateModified", @@ -146,15 +149,15 @@ def serialize_datetime(self, modification_datetime: datetime.datetime, _info): return modification_datetime.isoformat(timespec="seconds") -class ValveDataManagerClass(BaseModel): +class ValveDataManager(BaseModel): metadata: ValveManagerMetaData = Field(default_factory=ValveManagerMetaData) - valve_datas: list[ValveDataClass] = Field( + valve_datas: list[ValveData] = Field( alias="ValveDatas", default=[], description="Array of valve data objects." ) model_config = ConfigDict(serialize_by_alias=True, validate_by_name=True) - def get_valve(self, valvename: str) -> ValveDataClass: + def get_valve(self, valvename: str) -> ValveData: """Get a valve by name. Parameters @@ -164,7 +167,7 @@ def get_valve(self, valvename: str) -> ValveDataClass: Returns ------- - ValveDataClass + ValveData """ # Check if the valve exists in the ValveDatas list if valvename in self.valve_names: @@ -179,7 +182,9 @@ def create_valve(self, valvename: str) -> None: """Create a new valve with the given name.""" if valvename in self.valve_names: raise KeyError(f"Valve '{valvename}' already exists.") - self.valve_datas.append(ValveDataClass(name=valvename)) # noqa: aliasing with pydantic can cause type check issues + self.valve_datas.append( + ValveData(name=valvename) + ) # noqa: aliasing with pydantic can cause type check issues logger.debug(f"Created new valve: {valvename}") @property diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index b4a98e0..8a09407 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -1,10 +1,10 @@ from pydantic import Field from bpod_core.fsm import StateMachine -from bpod_rig.calibration.liquid.models import ValveDataManagerClass, ValveDataClass +from bpod_rig.calibration.liquid.models import ValveDataManager, ValveData -class PendingValve(ValveDataClass): +class PendingValve(ValveData): pending_durations: list[float] = Field( default=[], description="List of durations that are pending calibration measurements.", @@ -63,13 +63,13 @@ class PendingMeasurementsManager: Object would handle all """ - _valvemanager: ValveDataManagerClass + _valvemanager: ValveDataManager valves: list[PendingValve] n_pulses: int = 100 pulse_interval: float = 0.2 pulse_set_pause: float = 0.5 - def __init__(self, valvemanager: ValveDataManagerClass): + def __init__(self, valvemanager: ValveDataManager): self._valvemanager = valvemanager self.valves = [ PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index 9309e85..4950e05 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -3,11 +3,11 @@ from pathlib import Path from bpod_rig.examples import calibration as example_folder -from bpod_rig.calibration.liquid.models import ValveDataManagerClass +from bpod_rig.calibration.liquid.models import ValveDataManager from bpod_rig.calibration.liquid.utils import create_empty_valve_data_manager -def add_dummy_measurements(valvemanager: ValveDataManagerClass) -> None: +def add_dummy_measurements(valvemanager: ValveDataManager) -> None: """Add dummy measurements to the valves in manager.""" origin_date = datetime.datetime(2000, 1, 1) valve1 = valvemanager.get_valve("Valve1") diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index 1844bdc..319edbb 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -4,7 +4,7 @@ import logging import numpy as np -from .models import ValveDataClass, ValveDataManagerClass +from .models import ValveData, ValveDataManager logger = logging.getLogger(__name__) @@ -12,9 +12,9 @@ def create_empty_valve_data_manager( source: str = "statemachine", n_valves: int = 8, -) -> ValveDataManagerClass: +) -> ValveDataManager: """Create an empty valve manager with 8 valves.""" - valvemanager = ValveDataManagerClass() + valvemanager = ValveDataManager() if source == "statemachine": for index in range(n_valves): valvemanager.create_valve(f"Valve{index + 1}") @@ -28,7 +28,7 @@ def create_empty_valve_data_manager( return valvemanager -def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool: +def check_valvemanager_user_updated(valvemanager: ValveDataManager) -> bool: """Check if the user has updated the valve manager has been updated from the example. This is determined by checking if any valve has a modification date later than the @@ -36,7 +36,7 @@ def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool Parameters ---------- - valvemanager : ValveDataManagerClass + valvemanager : ValveDataManager The valve manager to check. Returns @@ -48,7 +48,7 @@ def check_valvemanager_user_updated(valvemanager: ValveDataManagerClass) -> bool def suggest_duration( - valveobject: ValveDataClass, + valveobject: ValveData, range_low: float, range_high: float, ) -> float: @@ -56,7 +56,7 @@ def suggest_duration( Parameters ---------- - valveobject : ValveDataClass + valveobject : ValveData The valve object to use for the suggestion. range_low : float The lower bound of microliters (uL) for the range. @@ -145,12 +145,12 @@ def calculate_ranged_amounts( return amounts_vector -def check_COM(valvemanager: ValveDataManagerClass, com_port: str) -> str: +def check_COM(valvemanager: ValveDataManager, com_port: str) -> str: """Check if the COM port in the valve manager matches the given COM port. Parameters ---------- - valvemanager : ValveDataManagerClass + valvemanager : ValveDataManager The valve manager to check. com_port : str The COM port to check against. diff --git a/tests/bpod_rig/calibration/liquid/test_models.py b/tests/bpod_rig/calibration/liquid/test_models.py index 3d04d10..0b6e62b 100644 --- a/tests/bpod_rig/calibration/liquid/test_models.py +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -2,18 +2,18 @@ import urllib.request from bpod_rig.calibration.liquid import utils, populate_examples -from bpod_rig.calibration.liquid.models import ValveDataClass, ValveDataManagerClass +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager class TestValveDataClass(unittest.TestCase): def setUp(self): - self.valve = ValveDataClass(ValveName="Test Valve") + self.valve = ValveData(ValveName="Test Valve") def test_init_alias(self): # Test that alias works correctly # type checker may complain about this because of pydantic aliasing - self.assertEqual(ValveDataClass(name="attrname").name, "attrname") - self.assertEqual(ValveDataClass(ValveName="aliasname").name, "aliasname") + self.assertEqual(ValveData(name="attrname").name, "attrname") + self.assertEqual(ValveData(ValveName="aliasname").name, "aliasname") def test_add_measurement(self): self.valve.add_measurement(10, 1.0) @@ -66,7 +66,7 @@ def test_no_coeffs_with_insufficient_data(self): class TestValveDataManagerClass(unittest.TestCase): def setUp(self): - self.manager = ValveDataManagerClass() + self.manager = ValveDataManager() self.manager.create_valve("Test Valve") dummy = utils.create_empty_valve_data_manager() populate_examples.add_dummy_measurements(dummy) @@ -78,7 +78,7 @@ def test_create_valve(self): def test_get_valve(self): valve = self.manager.get_valve("Test Valve") - self.assertIsInstance(valve, ValveDataClass) + self.assertIsInstance(valve, ValveData) def test_measurements(self): """Test that the default values are behaving as expeted.""" @@ -105,15 +105,15 @@ def setUpClass(cls): cls.jsontext = urllib.request.urlopen(url).read().decode() def test_load_calibration_file(self): - loaded_valves = ValveDataManagerClass.model_validate_json(self.jsontext) - self.assertIsInstance(loaded_valves, ValveDataManagerClass) + loaded_valves = ValveDataManager.model_validate_json(self.jsontext) + self.assertIsInstance(loaded_valves, ValveDataManager) self.assertGreaterEqual(loaded_valves.n_valves, 1) self.assertIn("Valve1", loaded_valves.valve_names) def test_valve_modtime(self): - valvemanager = ValveDataManagerClass.model_validate_json(self.jsontext) + valvemanager = ValveDataManager.model_validate_json(self.jsontext) self.assertFalse(utils.check_valvemanager_user_updated(valvemanager)) - newvalve = ValveDataManagerClass.model_validate_json(valvemanager.to_json()) + newvalve = ValveDataManager.model_validate_json(valvemanager.to_json()) self.assertTrue(utils.check_valvemanager_user_updated(newvalve)) @@ -122,7 +122,7 @@ def setUp(self): self.manager = utils.create_empty_valve_data_manager() populate_examples.add_dummy_measurements(self.manager) self.json_str = self.manager.to_json() - self.loaded_manager = ValveDataManagerClass.model_validate_json(self.json_str) + self.loaded_manager = ValveDataManager.model_validate_json(self.json_str) def test_json_round_trip(self): self.assertEqual(self.manager.n_valves, self.loaded_manager.n_valves) diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py index a5051dc..f26cc29 100644 --- a/tests/bpod_rig/calibration/liquid/test_utils.py +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -1,12 +1,12 @@ import unittest from bpod_rig.calibration.liquid import utils -from bpod_rig.calibration.liquid.models import ValveDataClass +from bpod_rig.calibration.liquid.models import ValveData class TestSuggestDuration(unittest.TestCase): def setUp(self): - self.valve = ValveDataClass(ValveName="Test Valve") + self.valve = ValveData(ValveName="Test Valve") self.range_low = 2 self.range_high = 10 self.suggest_duration = lambda: utils.suggest_duration( From 010b5339fc50f5fe26f0a3b13a282eeb95488434 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 00:08:50 +1000 Subject: [PATCH 15/30] Modify valve action to function with real state machine --- .../calibration/liquid/pending_calibration.py | 93 +++++++++++++------ .../liquid/test_pending_calibration.py | 6 ++ 2 files changed, 69 insertions(+), 30 deletions(-) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index 8a09407..ff866ad 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -1,3 +1,4 @@ +import bpod_core.bpod from pydantic import Field from bpod_core.fsm import StateMachine @@ -24,7 +25,20 @@ def __init__(self, **data): else: raise ValueError("Valve name not recognised for controller assignment.") - def action(self, actiontype: str) -> dict[str, str | list[str | int]]: + def action(self, actiontype: str) -> dict[str, str | list[str | int] | bool]: + """Define the output action for this valve + + Parameters + ---------- + actiontype : str + What kind of action the valve is performing: "open" or "close" + + Returns + ------- + dict + Action definition for valve that can be fed to + StateMachine.add_state(**, output_actions=) + """ assert actiontype in ["open", "close"] if self.valve_controller == "StateMachine": valve_id = int(self.name[5:]) @@ -33,7 +47,8 @@ def action(self, actiontype: str) -> dict[str, str | list[str | int]]: # ValvePhysicalAddress = 2.^(0:7); # ValveAddress = ValvePhysicalAddress(valveID) valve_address = 2 ** (valve_id - 1) - return {'ValveState': valve_address} + # return {'ValveState': valve_address} + return {self.name: True} elif actiontype == "close": raise ValueError("StateMachine valves do not have a close action.") else: @@ -61,13 +76,14 @@ def is_pending(self) -> bool: class PendingMeasurementsManager: """Manage valves and their pending values. - Object would handle all + Maintains lists of valve pending values, and can construct a state machine to test + the pending durations. """ _valvemanager: ValveDataManager valves: list[PendingValve] - n_pulses: int = 100 - pulse_interval: float = 0.2 - pulse_set_pause: float = 0.5 + n_pulses: int = 100 # number of pulses to delivery across the test + pulse_interval: float = 0.2 # time (s) between two different valves pulsing + pulse_set_pause: float = 0.5 # time (s) between each set of pulses def __init__(self, valvemanager: ValveDataManager): self._valvemanager = valvemanager @@ -128,7 +144,7 @@ def complete_measurement( self._valvemanager.get_valve(valvename).add_measurement(duration, amount) def build_test_set(self) -> dict[str, float]: - """Create dict of valves and the duration to test. + """Create dict of valves and the duration (ms) to test. Will take the first (oldest) duration in the valve's pending values. """ @@ -142,9 +158,9 @@ def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine """Build the state machine to test all pending valves. The state machine is constructed using object parameters: - - n_pulses: number of pulses to deliver - - pulse_interval: Duration (ms) - - pulse_set_pause: Duration (ms) + - n_pulses: number of pulses to deliver. + - pulse_interval: Duration (s) between one valve's end of pulse and another's start. + - pulse_set_pause: Duration (s) between a set of pulses. The amount of liquid dispensed from the state machine For example, 100 pulses totaling 0.200g (200mg) is 0.200g / 100 = 2mg = 2mL @@ -183,7 +199,7 @@ def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine def add_valve_states( fsm: StateMachine, pending_valve: PendingValve, duration: float, next_valve: PendingValve | float, - delay_duration: float + interval_delay_duration: float ): """Add the pending valve to a state machine. @@ -197,7 +213,7 @@ def add_valve_states( Duration (ms) of valve open. next_valve : PendingValve | float The next valve object in the cycle, or the float for the pause duration (s) between pulse sets. - delay_duration : float + interval_delay_duration : float """ last_valve = isinstance(next_valve, float) if last_valve: @@ -214,8 +230,11 @@ def add_valve_states( fsm.add_state( name=f"Pulse{pending_valve.name}", timer=duration / 1000, - state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, - output_actions=pending_valve.action('open') + transitions={'Tup': f"Delay{pending_valve.name}"}, + # state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, + actions=pending_valve.action('open'), + # output_actions=pending_valve.action('open'), + comment="Open valve for duration of state, with closure at the end." ) elif pending_valve.valve_controller == 'PortArray': # Valve is controlled by serial messaging i.e. discrete open/close messages @@ -223,50 +242,53 @@ def add_valve_states( name=f"Pulse{pending_valve.name}", timer=duration / 1000, state_change_conditions={'Tup': f"EndPulse{pending_valve.name}"}, - output_actins=pending_valve.action('open') + output_actions=pending_valve.action('open'), + comment="Send message to Port Array Module to open the valve." ) fsm.add_state( name=f"EndPulse{pending_valve.name}", timer=0, state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, - output_actions=pending_valve.action('close') + output_actions=pending_valve.action('close'), + comment="Send message to Port Array module to close the valve." ) else: raise ValueError("Valve controller not recognised.") fsm.add_state( name=f"Delay{pending_valve.name}", - timer=delay_duration, - state_change_conditions={'Tup': next_state}, - # output_actions={} + timer=interval_delay_duration, + transitions={'Tup': next_state}, + actions={}, + comment="Pause before opening the next valve." ) if last_valve: + # If last valve in set, following the Delay this stated is entered into. fsm.add_state( name="PulseSetPause", timer=pulse_set_pause_duration, - state_change_conditions={'Tup': 'exit'}, - # output_actions={} + transitions={'Tup': 'exit'}, + actions={}, + comment="Pause when the entire set of valves is completed." ) def run_calibration( + bpodsystem: bpod_core.bpod.Bpod, pending_manager: PendingMeasurementsManager, -): + verbose: bool=False, +) -> dict[str, float]: """Run a calibration sequence using Bpod state machine. Parameters ---------- + bpodsystem : bpod_core.bpod.Bpod + Active Bpod machine. pending_manager : PendingMeasurementsManager Object storing valve data - n_pulses : int - Number of pulses to deliver for all valves. - pulse_interval : float, optional - Time (s) between each pulse of each valve. - pulse_set_pause : float, optional - Time (s) between a round (after all valves pulsed) verbose : bool, optional - Print progress. + Print progress, default False. Returns ------- @@ -276,7 +298,18 @@ def run_calibration( # Build the state machine fsm, test_set = pending_manager.build_statemachine() + if verbose: + print('Running liquid calibration:') + print(f'\t{pending_manager.n_pulses} pulses.') + print(f'\t{pending_manager.pulse_interval} between each pulse.') + print(f'\t{pending_manager.pulse_set_pause} between each pulse set.') + for valvename in test_set: + print(f'- {valvename}: duration {test_set[valvename]} ms') + + bpodsystem.send_state_machine(fsm) # Run the state machine for trial in range(pending_manager.n_pulses): - pass + bpodsystem.run_state_machine() + + return test_set diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index edb61ca..83d4f15 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -81,3 +81,9 @@ def test_build_statemachine(self): self.manager.add_pending("Valve2", 15) statemachine, test_set = self.manager.build_statemachine() self.assertIsInstance(statemachine, StateMachine) + self.assertEqual(len(test_set.keys()), 2) + # expect 2 states of open and delay, with final pulse set delay + n_states = 2 + 2 + 1 + self.assertEqual(len(statemachine.states), n_states) + + # TODO: test build, send, and run of state machine with emulator bpod From 309bac89ebe1a69cd879275051ad6eb05bd22a9f Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 00:19:18 +1000 Subject: [PATCH 16/30] Fix port array state machine definition Output actions haven't been tested because modules haven't been implemented in bpod-core --- bpod_rig/calibration/liquid/pending_calibration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index ff866ad..1b16737 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -241,15 +241,15 @@ def add_valve_states( fsm.add_state( name=f"Pulse{pending_valve.name}", timer=duration / 1000, - state_change_conditions={'Tup': f"EndPulse{pending_valve.name}"}, - output_actions=pending_valve.action('open'), + transitions={'Tup': f"EndPulse{pending_valve.name}"}, + actions=pending_valve.action('open'), comment="Send message to Port Array Module to open the valve." ) fsm.add_state( name=f"EndPulse{pending_valve.name}", timer=0, - state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, - output_actions=pending_valve.action('close'), + transitions={'Tup': f"Delay{pending_valve.name}"}, + actions=pending_valve.action('close'), comment="Send message to Port Array module to close the valve." ) else: From cc45b42750e3b713d598cd74f525fe71498db9f7 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:19:41 +1000 Subject: [PATCH 17/30] Improve docs --- bpod_rig/calibration/liquid/models.py | 36 +++++++++++++++---- .../calibration/liquid/pending_calibration.py | 20 +++++++---- .../calibration/liquid/populate_examples.py | 11 +++++- 3 files changed, 53 insertions(+), 14 deletions(-) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 53b4584..49bf491 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -1,3 +1,4 @@ +"""Data models for liquid calibration.""" import datetime import logging @@ -8,6 +9,11 @@ class ValveData(BaseModel): + """Data model for the liquid calibration data for a single valve. + + Valves are calibrated by measuring the amount of liquid dispensed for a specific + duration. + """ name: str = Field( alias="ValveName", description="Name of valve (1-indexed)" @@ -34,6 +40,7 @@ class ValveData(BaseModel): description="Amounts in uL for each dispense, corresponding to the durations.", ) + # MATLAB .json compatibility is maintained by using the alias to load/save model_config = ConfigDict( serialize_by_alias=True, validate_by_name=True, @@ -142,7 +149,12 @@ class ValveManagerMetaData(BaseModel): default_factory=datetime.datetime.now, description="Time of when the valve data was last saved to json file.", ) - COM: str = "" + COM: str = Field( + default="", + description="The COM port of the last state machine to modify the valves." + ) + # TODO : if the serial number approach to identifying Bpods works out then + # the serial number should be used instead of COM (which can change) @field_serializer("modification_datetime") def serialize_datetime(self, modification_datetime: datetime.datetime, _info): @@ -150,7 +162,12 @@ def serialize_datetime(self, modification_datetime: datetime.datetime, _info): class ValveDataManager(BaseModel): - metadata: ValveManagerMetaData = Field(default_factory=ValveManagerMetaData) + """Parent of multiple ValveData objects. + """ + metadata: ValveManagerMetaData = Field( + default_factory=ValveManagerMetaData, + description="Metadata regarding the set of valves." + ) valve_datas: list[ValveData] = Field( alias="ValveDatas", default=[], description="Array of valve data objects." ) @@ -183,27 +200,32 @@ def create_valve(self, valvename: str) -> None: if valvename in self.valve_names: raise KeyError(f"Valve '{valvename}' already exists.") self.valve_datas.append( - ValveData(name=valvename) + ValveData(ValveName=valvename) ) # noqa: aliasing with pydantic can cause type check issues logger.debug(f"Created new valve: {valvename}") @property def n_valves(self) -> int: - """Get the number of valves in the manager.""" + """The number of valves in the manager.""" return len(self.valve_datas) @property def valve_names(self) -> list[str]: - """Get a list of valve names.""" + """List of valve names.""" return list(valve.name for valve in self.valve_datas) - def to_json(self, machineid: str = None) -> str: - """Convert the ValveDataManagerPydantic to JSON string. + def to_json(self, machineid: str | None = None) -> str: + """Convert the ValveDataManager to JSON string. Parameters ---------- machineid : str, optional The COM port or machine ID to set in the metadata. + + Returns + ------- + str + JSON text (2 space indented) for writing to file. """ if machineid is not None: if not isinstance(machineid, str): diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index 1b16737..ccf5592 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -1,9 +1,12 @@ +import logging + import bpod_core.bpod from pydantic import Field from bpod_core.fsm import StateMachine from bpod_rig.calibration.liquid.models import ValveDataManager, ValveData +logger = logging.getLogger(__name__) class PendingValve(ValveData): pending_durations: list[float] = Field( @@ -80,10 +83,15 @@ class PendingMeasurementsManager: the pending durations. """ _valvemanager: ValveDataManager + """The real data used by Bpod.""" valves: list[PendingValve] - n_pulses: int = 100 # number of pulses to delivery across the test - pulse_interval: float = 0.2 # time (s) between two different valves pulsing - pulse_set_pause: float = 0.5 # time (s) between each set of pulses + """All of the pending valves that are being managed""" + n_pulses: int = 100 + """number of pulses to delivery across the test""" + pulse_interval: float = 0.2 + """time (s) between two different valves pulsing""" + pulse_set_pause: float = 0.5 + """time (s) between each set of pulses""" def __init__(self, valvemanager: ValveDataManager): self._valvemanager = valvemanager @@ -142,6 +150,7 @@ def complete_measurement( assert duration in self.get_pending(valvename) self.remove_pending(valvename, duration) self._valvemanager.get_valve(valvename).add_measurement(duration, amount) + logger.debug("Completed measurement for %s %s ms", valvename, duration) def build_test_set(self) -> dict[str, float]: """Create dict of valves and the duration (ms) to test. @@ -179,6 +188,7 @@ def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine if test_set is None: test_set = self.build_test_set() valve_names = list(test_set.keys()) + logger.debug("Building test set for %s", valve_names) pending_valves = [valve for valve in self.valves if valve.is_pending] n_pending = len(pending_valves) if n_pending == 0: @@ -231,9 +241,7 @@ def add_valve_states( name=f"Pulse{pending_valve.name}", timer=duration / 1000, transitions={'Tup': f"Delay{pending_valve.name}"}, - # state_change_conditions={'Tup': f"Delay{pending_valve.name}"}, actions=pending_valve.action('open'), - # output_actions=pending_valve.action('open'), comment="Open valve for duration of state, with closure at the end." ) elif pending_valve.valve_controller == 'PortArray': @@ -307,7 +315,7 @@ def run_calibration( print(f'- {valvename}: duration {test_set[valvename]} ms') bpodsystem.send_state_machine(fsm) - + logger.debug("Running calibration state machine.") # Run the state machine for trial in range(pending_manager.n_pulses): bpodsystem.run_state_machine() diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index 4950e05..3d2ba52 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -8,7 +8,16 @@ def add_dummy_measurements(valvemanager: ValveDataManager) -> None: - """Add dummy measurements to the valves in manager.""" + """Add dummy measurements to the valves in manager. + + Modifies the data in place. + Uses 2000-01-01 as the origin date for the data. + + Examples + -------- + >>> dummy_liquid_manager = create_empty_valve_data_manager() + >>> add_dummy_measurements(dummy_liquid_manager) + """ origin_date = datetime.datetime(2000, 1, 1) valve1 = valvemanager.get_valve("Valve1") valve1.add_measurement(22, 2) From 87e9fa79159bcb8b6cdbdf4bd1a0cf9374b2a8ef Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:19:51 +1000 Subject: [PATCH 18/30] Add more tests --- .../bpod_rig/calibration/liquid/test_utils.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py index f26cc29..616bb8c 100644 --- a/tests/bpod_rig/calibration/liquid/test_utils.py +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -1,7 +1,8 @@ +import datetime import unittest from bpod_rig.calibration.liquid import utils -from bpod_rig.calibration.liquid.models import ValveData +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager class TestSuggestDuration(unittest.TestCase): @@ -24,3 +25,41 @@ def test_two_measures(self): self.valve.add_measurement(22, 2) self.valve.add_measurement(66, 9.5) self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) + +class TestCheckValveManagerUserUpdate(unittest.TestCase): + def setUp(self): + self.manager = ValveDataManager() + + def test_earlier(self): + self.manager.metadata.modification_datetime = datetime.datetime(1999, 1, 1) + self.assertFalse(utils.check_valvemanager_user_updated(self.manager)) + + def test_later(self): + self.manager.create_valve("Valve1") + self.manager.get_valve("Valve1").add_measurement(15, 20) + self.assertTrue(utils.check_valvemanager_user_updated(self.manager)) + + def test_equal(self): + self.manager.metadata.modification_datetime = datetime.datetime(2000, 1, 1) + self.assertFalse(utils.check_valvemanager_user_updated(self.manager)) + + def test_dummy(self): + from bpod_rig.calibration.liquid.populate_examples import create_default_json + + manager = ValveDataManager.model_validate_json(create_default_json()) + self.assertFalse(utils.check_valvemanager_user_updated(manager)) + +class TestCheckCOM(unittest.TestCase): + def setUp(self): + self.manager = ValveDataManager() + self.manager.metadata.COM = 'COM3' + + def test_matching(self): + self.assertEqual(utils.check_COM(self.manager, "COM3"), "yes") + + def test_no_match(self): + self.assertEqual(utils.check_COM(self.manager, "COM4"), "no") + + def test_unknown(self): + self.manager.metadata.COM = "" + self.assertEqual(utils.check_COM(self.manager, "COM3"), "unknown") From 138d3a627294f863c39f26f6183b676c029e3f9a Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:23:15 +1000 Subject: [PATCH 19/30] Run ruff --- bpod_rig/calibration/liquid/models.py | 19 +++--- .../calibration/liquid/pending_calibration.py | 67 ++++++++++--------- .../calibration/liquid/populate_examples.py | 1 + bpod_rig/calibration/liquid/utils.py | 2 - .../liquid/test_pending_calibration.py | 4 +- .../bpod_rig/calibration/liquid/test_utils.py | 4 +- 6 files changed, 48 insertions(+), 49 deletions(-) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 49bf491..7b1bf1a 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -1,4 +1,5 @@ """Data models for liquid calibration.""" + import datetime import logging @@ -14,10 +15,8 @@ class ValveData(BaseModel): Valves are calibrated by measuring the amount of liquid dispensed for a specific duration. """ - name: str = Field( - alias="ValveName", - description="Name of valve (1-indexed)" - ) + + name: str = Field(alias="ValveName", description="Name of valve (1-indexed)") lastdatemodified: datetime.datetime | str = Field( serialization_alias="LastDateModified", validation_alias="LastDateModified", @@ -151,7 +150,7 @@ class ValveManagerMetaData(BaseModel): ) COM: str = Field( default="", - description="The COM port of the last state machine to modify the valves." + description="The COM port of the last state machine to modify the valves.", ) # TODO : if the serial number approach to identifying Bpods works out then # the serial number should be used instead of COM (which can change) @@ -162,11 +161,11 @@ def serialize_datetime(self, modification_datetime: datetime.datetime, _info): class ValveDataManager(BaseModel): - """Parent of multiple ValveData objects. - """ + """Parent of multiple ValveData objects.""" + metadata: ValveManagerMetaData = Field( default_factory=ValveManagerMetaData, - description="Metadata regarding the set of valves." + description="Metadata regarding the set of valves.", ) valve_datas: list[ValveData] = Field( alias="ValveDatas", default=[], description="Array of valve data objects." @@ -199,9 +198,7 @@ def create_valve(self, valvename: str) -> None: """Create a new valve with the given name.""" if valvename in self.valve_names: raise KeyError(f"Valve '{valvename}' already exists.") - self.valve_datas.append( - ValveData(ValveName=valvename) - ) # noqa: aliasing with pydantic can cause type check issues + self.valve_datas.append(ValveData(ValveName=valvename)) # noqa: aliasing with pydantic can cause type check issues logger.debug(f"Created new valve: {valvename}") @property diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index ccf5592..7709c82 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) + class PendingValve(ValveData): pending_durations: list[float] = Field( default=[], @@ -61,10 +62,10 @@ def action(self, actiontype: str) -> dict[str, str | list[str | int] | bool]: portnumber = self.name[4] if actiontype == "open": # TODO: serial message byte parsing hasn't been checked - action_serial = ['V', portnumber, 1] + action_serial = ["V", portnumber, 1] return {target_module: action_serial} elif actiontype == "close": - action_serial = ['V', portnumber, 0] + action_serial = ["V", portnumber, 0] return {target_module: action_serial} else: raise ValueError("Action type not recognised") @@ -82,6 +83,7 @@ class PendingMeasurementsManager: Maintains lists of valve pending values, and can construct a state machine to test the pending durations. """ + _valvemanager: ValveDataManager """The real data used by Bpod.""" valves: list[PendingValve] @@ -99,19 +101,16 @@ def __init__(self, valvemanager: ValveDataManager): PendingValve(ValveName=valve.name) for valve in valvemanager.valve_datas ] - def get_valve(self, valvename: str) -> PendingValve: """Retrieve object for a pending valve.""" if valvename not in self._valvemanager.valve_names: raise KeyError(f"Valvename '{valvename}' not found in manager.") return next(valve for valve in self.valves if valve.name == valvename) - def get_pending(self, valvename: str) -> list[float]: """Return pending durations from a given valve.""" return self.get_valve(valvename).pending_durations - def add_pending(self, valvename: str, duration: float) -> None: """Add a pending duration to the valve's list.""" if duration in self.get_pending(valvename): @@ -121,7 +120,6 @@ def add_pending(self, valvename: str, duration: float) -> None: self.get_pending(valvename).append(duration) - def remove_pending(self, valvename: str, duration: float) -> None: """Remove a pending duration from the valve's pending record.""" if duration not in self.get_pending(valvename): @@ -129,7 +127,6 @@ def remove_pending(self, valvename: str, duration: float) -> None: self.get_pending(valvename).remove(duration) - def complete_measurement( self, valvename: str, duration: float, amount: float ) -> None: @@ -163,7 +160,9 @@ def build_test_set(self) -> dict[str, float]: test_set[valve.name] = valve.pending_durations[0] return test_set - def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine, dict[str, float]]: + def build_statemachine( + self, test_set: dict | None = None + ) -> tuple[StateMachine, dict[str, float]]: """Build the state machine to test all pending valves. The state machine is constructed using object parameters: @@ -207,9 +206,11 @@ def build_statemachine(self, test_set: dict | None = None) -> tuple[StateMachine def add_valve_states( - fsm: StateMachine, pending_valve: PendingValve, duration: float, + fsm: StateMachine, + pending_valve: PendingValve, + duration: float, next_valve: PendingValve | float, - interval_delay_duration: float + interval_delay_duration: float, ): """Add the pending valve to a state machine. @@ -228,37 +229,37 @@ def add_valve_states( last_valve = isinstance(next_valve, float) if last_valve: pulse_set_pause_duration = next_valve - next_state = 'PulseSetPause' + next_state = "PulseSetPause" else: - assert (isinstance(next_valve, PendingValve)) + assert isinstance(next_valve, PendingValve) pulse_set_pause_duration = None next_state = f"Pulse{next_valve.name}" # Add valve openings to the state machine - if pending_valve.valve_controller == 'StateMachine': + if pending_valve.valve_controller == "StateMachine": # Valve is wired i.e. open while state is active fsm.add_state( name=f"Pulse{pending_valve.name}", timer=duration / 1000, - transitions={'Tup': f"Delay{pending_valve.name}"}, - actions=pending_valve.action('open'), - comment="Open valve for duration of state, with closure at the end." + transitions={"Tup": f"Delay{pending_valve.name}"}, + actions=pending_valve.action("open"), + comment="Open valve for duration of state, with closure at the end.", ) - elif pending_valve.valve_controller == 'PortArray': + elif pending_valve.valve_controller == "PortArray": # Valve is controlled by serial messaging i.e. discrete open/close messages fsm.add_state( name=f"Pulse{pending_valve.name}", timer=duration / 1000, - transitions={'Tup': f"EndPulse{pending_valve.name}"}, - actions=pending_valve.action('open'), - comment="Send message to Port Array Module to open the valve." + transitions={"Tup": f"EndPulse{pending_valve.name}"}, + actions=pending_valve.action("open"), + comment="Send message to Port Array Module to open the valve.", ) fsm.add_state( name=f"EndPulse{pending_valve.name}", timer=0, - transitions={'Tup': f"Delay{pending_valve.name}"}, - actions=pending_valve.action('close'), - comment="Send message to Port Array module to close the valve." + transitions={"Tup": f"Delay{pending_valve.name}"}, + actions=pending_valve.action("close"), + comment="Send message to Port Array module to close the valve.", ) else: raise ValueError("Valve controller not recognised.") @@ -266,9 +267,9 @@ def add_valve_states( fsm.add_state( name=f"Delay{pending_valve.name}", timer=interval_delay_duration, - transitions={'Tup': next_state}, + transitions={"Tup": next_state}, actions={}, - comment="Pause before opening the next valve." + comment="Pause before opening the next valve.", ) if last_valve: @@ -276,16 +277,16 @@ def add_valve_states( fsm.add_state( name="PulseSetPause", timer=pulse_set_pause_duration, - transitions={'Tup': 'exit'}, + transitions={"Tup": "exit"}, actions={}, - comment="Pause when the entire set of valves is completed." + comment="Pause when the entire set of valves is completed.", ) def run_calibration( bpodsystem: bpod_core.bpod.Bpod, pending_manager: PendingMeasurementsManager, - verbose: bool=False, + verbose: bool = False, ) -> dict[str, float]: """Run a calibration sequence using Bpod state machine. @@ -307,12 +308,12 @@ def run_calibration( # Build the state machine fsm, test_set = pending_manager.build_statemachine() if verbose: - print('Running liquid calibration:') - print(f'\t{pending_manager.n_pulses} pulses.') - print(f'\t{pending_manager.pulse_interval} between each pulse.') - print(f'\t{pending_manager.pulse_set_pause} between each pulse set.') + print("Running liquid calibration:") + print(f"\t{pending_manager.n_pulses} pulses.") + print(f"\t{pending_manager.pulse_interval} between each pulse.") + print(f"\t{pending_manager.pulse_set_pause} between each pulse set.") for valvename in test_set: - print(f'- {valvename}: duration {test_set[valvename]} ms') + print(f"- {valvename}: duration {test_set[valvename]} ms") bpodsystem.send_state_machine(fsm) logger.debug("Running calibration state machine.") diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index 3d2ba52..de03bf6 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -1,4 +1,5 @@ """Create the example liquid calibration JSON with 8 vales and dummy measurements.""" + import datetime from pathlib import Path diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index 319edbb..c464719 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -167,5 +167,3 @@ def check_COM(valvemanager: ValveDataManager, com_port: str) -> str: return "yes" else: return "no" - - diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index 83d4f15..6a22b25 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -43,9 +43,9 @@ def test_get_pending_invalid_valve(self): self.manager.get_pending("InvalidValve") def test_add_pending(self): - self.manager.add_pending("Valve1", 10.0) # valve with values + self.manager.add_pending("Valve1", 10.0) # valve with values self.assertIn(10.0, self.manager.get_pending("Valve1")) - self.manager.add_pending("Valve2", 10.0) # valve without values + self.manager.add_pending("Valve2", 10.0) # valve without values self.assertIn(10.0, self.manager.get_pending("Valve2")) def test_add_pending_duplicate(self): diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py index 616bb8c..0f3a9b3 100644 --- a/tests/bpod_rig/calibration/liquid/test_utils.py +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -26,6 +26,7 @@ def test_two_measures(self): self.valve.add_measurement(66, 9.5) self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) + class TestCheckValveManagerUserUpdate(unittest.TestCase): def setUp(self): self.manager = ValveDataManager() @@ -49,10 +50,11 @@ def test_dummy(self): manager = ValveDataManager.model_validate_json(create_default_json()) self.assertFalse(utils.check_valvemanager_user_updated(manager)) + class TestCheckCOM(unittest.TestCase): def setUp(self): self.manager = ValveDataManager() - self.manager.metadata.COM = 'COM3' + self.manager.metadata.COM = "COM3" def test_matching(self): self.assertEqual(utils.check_COM(self.manager, "COM3"), "yes") From 4bdc19d29592ecd628da5e6073a45779425ffa97 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Sun, 14 Sep 2025 01:32:29 +1000 Subject: [PATCH 20/30] Fix CamelCase --- tests/bpod_rig/calibration/liquid/test_pending_calibration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index 6a22b25..e100926 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -6,7 +6,7 @@ from bpod_rig.calibration.liquid import utils, populate_examples -class test_PendingValve(unittest.TestCase): +class TestPendingValve(unittest.TestCase): def setUp(self): self.pending_valve = pending_calibration.PendingValve(ValveName="Valve3") @@ -28,7 +28,7 @@ def test_ispending(self): self.assertTrue(valve.is_pending) -class test_PendingMeasurementsManager(unittest.TestCase): +class TestPendingMeasurementsManager(unittest.TestCase): def setUp(self): self.valvemanager = utils.create_empty_valve_data_manager() populate_examples.add_dummy_measurements(self.valvemanager) From 8d1e91a9c637dae09e1e603988db5056a69bc857 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Wed, 17 Sep 2025 15:50:55 +1000 Subject: [PATCH 21/30] Fix typos and formatting --- bpod_rig/calibration/liquid/pending_calibration.py | 6 +++--- bpod_rig/calibration/liquid/populate_examples.py | 2 +- bpod_rig/calibration/liquid/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index 7709c82..5b40164 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -80,8 +80,8 @@ def is_pending(self) -> bool: class PendingMeasurementsManager: """Manage valves and their pending values. - Maintains lists of valve pending values, and can construct a state machine to test - the pending durations. + Maintains list of pending valves (which store pending values), and can construct a + state machine to test the pending durations. """ _valvemanager: ValveDataManager @@ -273,7 +273,7 @@ def add_valve_states( ) if last_valve: - # If last valve in set, following the Delay this stated is entered into. + # If last valve in set, enter this state following the valve's Delay. fsm.add_state( name="PulseSetPause", timer=pulse_set_pause_duration, diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index de03bf6..fc80238 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -1,4 +1,4 @@ -"""Create the example liquid calibration JSON with 8 vales and dummy measurements.""" +"""Create the example liquid calibration JSON with 8 valves and dummy measurements.""" import datetime from pathlib import Path diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index c464719..c117687 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -1,10 +1,10 @@ -"""Liquid calibration data management and calibratino routines.""" +"""Liquid calibration data management and calibration routines.""" import datetime import logging import numpy as np -from .models import ValveData, ValveDataManager +from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager logger = logging.getLogger(__name__) From 8b8284cf5d186166a71e0fc29f33aa2bd00d26e9 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:18:12 +1000 Subject: [PATCH 22/30] Add id field The exact nature of the unique id is TBD, but the information will be stored. --- bpod_rig/calibration/liquid/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 7b1bf1a..9f6d845 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -152,8 +152,12 @@ class ValveManagerMetaData(BaseModel): default="", description="The COM port of the last state machine to modify the valves.", ) - # TODO : if the serial number approach to identifying Bpods works out then - # the serial number should be used instead of COM (which can change) + + id: str = Field( + default="", + description="Unique identifier of the machine that controls the valve " + "(i.e. Bpod State Machine or Port Array Module.", + ) @field_serializer("modification_datetime") def serialize_datetime(self, modification_datetime: datetime.datetime, _info): From 2efc2151a52137ce5a5b0e7d0c364ce981699ed5 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:22:33 +1000 Subject: [PATCH 23/30] Rename check_COM to check_com --- bpod_rig/calibration/liquid/utils.py | 2 +- tests/bpod_rig/calibration/liquid/test_utils.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index c117687..304960c 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -145,7 +145,7 @@ def calculate_ranged_amounts( return amounts_vector -def check_COM(valvemanager: ValveDataManager, com_port: str) -> str: +def check_com(valvemanager: ValveDataManager, com_port: str) -> str: """Check if the COM port in the valve manager matches the given COM port. Parameters diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py index 0f3a9b3..e02d8fc 100644 --- a/tests/bpod_rig/calibration/liquid/test_utils.py +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -57,11 +57,11 @@ def setUp(self): self.manager.metadata.COM = "COM3" def test_matching(self): - self.assertEqual(utils.check_COM(self.manager, "COM3"), "yes") + self.assertEqual(utils.check_com(self.manager, "COM3"), "yes") def test_no_match(self): - self.assertEqual(utils.check_COM(self.manager, "COM4"), "no") + self.assertEqual(utils.check_com(self.manager, "COM4"), "no") def test_unknown(self): self.manager.metadata.COM = "" - self.assertEqual(utils.check_COM(self.manager, "COM3"), "unknown") + self.assertEqual(utils.check_com(self.manager, "COM3"), "unknown") From 88e92a8628e378b38774f16e67ec898345a82001 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:22:42 +1000 Subject: [PATCH 24/30] Resolve ruff warnings --- bpod_rig/calibration/liquid/models.py | 33 +++++++++++++-------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/bpod_rig/calibration/liquid/models.py b/bpod_rig/calibration/liquid/models.py index 9f6d845..0f3d412 100644 --- a/bpod_rig/calibration/liquid/models.py +++ b/bpod_rig/calibration/liquid/models.py @@ -86,23 +86,25 @@ def add_measurement(self, duration: float, amount: float) -> None: self.lastdatemodified = datetime.datetime.now() self._update_coeffs() - def remove_measurement(self, value: int | float, method: str = "index") -> None: + def remove_measurement(self, value: float, method: str = "index") -> None: """Remove a measurement from the valve data. Parameters ---------- value : int | float - The index or duration to remove. + The index or duration (ms) to remove. method : str, optional (default = 'index') - How to use value to find the measurement to remove. 'index' to remove by index, - 'duration' to remove by duration value (assuming unique duration). + How to use value to find the measurement to remove. + - 'duration' to remove by duration value (assuming unique duration). + - 'index' to remove by index in `.durations` """ if method == "duration": if value in self.durations: # Throw error if value is in durations multiple times if self.durations.count(value) > 1: raise ValueError( - "Duration value is present multiple times. Index must be specified." + "Duration value is present multiple times. " + "Index must be specified." ) index = self.durations.index(value) del self.amounts[index] @@ -125,15 +127,11 @@ def remove_measurement(self, value: int | float, method: str = "index") -> None: def _update_coeffs(self) -> None: """Update the polynomial coefficients based on current measurements.""" if len(self.amounts) < 2: + # potential feature: 1 value assumes intercept at 0 self.coeffs = [] return - # TODO: 1 value assumes intercept at 0? - elif len(self.amounts) == 2: - # If only two measurements, use linear fit - order = 1 - else: - # Fit a polynomial of degree 2 - order = 2 + # If only two measurements, use linear fit, otherwise use 2 degree polynomial + order = 1 if len(self.amounts) == 2 else 2 self.coeffs = np.polyfit(self.amounts, self.durations, order).tolist() @field_serializer("lastdatemodified") @@ -148,6 +146,7 @@ class ValveManagerMetaData(BaseModel): default_factory=datetime.datetime.now, description="Time of when the valve data was last saved to json file.", ) + COM: str = Field( default="", description="The COM port of the last state machine to modify the valves.", @@ -195,14 +194,14 @@ def get_valve(self, valvename: str) -> ValveData: if sum(name == valvename for name in self.valve_names) > 1: raise KeyError(f"Multiple valves with name '{valvename}' found.") return next(valve for valve in self.valve_datas if valve.name == valvename) - else: - raise KeyError(f"Valve '{valvename}' not found in valve manager.") + + raise KeyError(f"Valve '{valvename}' not found in valve manager.") def create_valve(self, valvename: str) -> None: """Create a new valve with the given name.""" if valvename in self.valve_names: raise KeyError(f"Valve '{valvename}' already exists.") - self.valve_datas.append(ValveData(ValveName=valvename)) # noqa: aliasing with pydantic can cause type check issues + self.valve_datas.append(ValveData(ValveName=valvename)) logger.debug(f"Created new valve: {valvename}") @property @@ -213,7 +212,7 @@ def n_valves(self) -> int: @property def valve_names(self) -> list[str]: """List of valve names.""" - return list(valve.name for valve in self.valve_datas) + return [valve.name for valve in self.valve_datas] def to_json(self, machineid: str | None = None) -> str: """Convert the ValveDataManager to JSON string. @@ -233,7 +232,7 @@ def to_json(self, machineid: str | None = None) -> str: raise ValueError("machineid must be a string.") # check if the COM is changing - if (self.metadata.COM != "") & (self.metadata.COM != machineid): + if (self.metadata.COM != "") & (self.metadata.COM != machineid): # noqa: SIM300 logger.warning( "COM port of liquid calibration file is changing from %s to %s", self.metadata.COM, From 6f9d11613fcfaccfe517a6a5b12786c2f517e762 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:22:50 +1000 Subject: [PATCH 25/30] Resolve ruff warnings --- bpod_rig/calibration/liquid/utils.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/bpod_rig/calibration/liquid/utils.py b/bpod_rig/calibration/liquid/utils.py index 304960c..21bfe0f 100644 --- a/bpod_rig/calibration/liquid/utils.py +++ b/bpod_rig/calibration/liquid/utils.py @@ -29,7 +29,7 @@ def create_empty_valve_data_manager( def check_valvemanager_user_updated(valvemanager: ValveDataManager) -> bool: - """Check if the user has updated the valve manager has been updated from the example. + """Check if the user has updated the valve manager from the example. This is determined by checking if any valve has a modification date later than the origin date of 2000-01-01. @@ -87,8 +87,9 @@ def suggest_duration( amounts[0], durations[0], range_low, range_high ) else: - # Use an estimate of the middle of the range based on range and our experience with - # the CSHL configuration (nResearch pinch valves, silastic tubing, specs in Bpod literature) + # Use an estimate of the middle of the range based on range and our experience + # with the CSHL configuration. + # (nResearch pinch valves, silastic tubing, specs in Bpod literature) suggested_duration_ms = (range_high + range_low) * 4 return suggested_duration_ms @@ -99,11 +100,11 @@ def calculate_largest_gap_midpoint(sorted_values: list[float]) -> float: for y in range(1, len(sorted_values)): distances[y] = abs(sorted_values[y] - sorted_values[y - 1]) max_distance_pos = np.argmax(distances) - midpoint_value = ( + # Return the midpoint location + return ( sorted_values[max_distance_pos - 1] + (sorted_values[max_distance_pos] - sorted_values[max_distance_pos - 1]) / 2 ) - return midpoint_value def linearly_suggest_duration( @@ -123,14 +124,13 @@ def linearly_suggest_duration( target_amount = range_low + (bottom_part / 2) else: target_amount = range_high - (top_part / 2) - suggested_duration = round(target_amount / ul_per_ms) - return suggested_duration + return round(target_amount / ul_per_ms) def calculate_ranged_amounts( amounts: np.array, range_low: float, range_high: float ) -> list[float]: - """Calculate a sorted list of amounts within a given range, including the range bounds.""" + """Calculate a sorted list of amounts within a range, including the range bounds.""" amounts_vector = amounts.tolist() if range_low not in amounts: amounts_vector.append(range_low) @@ -141,8 +141,7 @@ def calculate_ranged_amounts( amounts_vector = sorted(amounts_vector) startpoint = amounts_vector.index(range_low) endpoint = amounts_vector.index(range_high) - amounts_vector = amounts_vector[startpoint : endpoint + 1] - return amounts_vector + return amounts_vector[startpoint : endpoint + 1] def check_com(valvemanager: ValveDataManager, com_port: str) -> str: @@ -165,5 +164,4 @@ def check_com(valvemanager: ValveDataManager, com_port: str) -> str: return "unknown" if valvecom == com_port: return "yes" - else: - return "no" + return "no" From b30464440c3f6e7d6d2078035fcc55a72c57a4ed Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:24:59 +1000 Subject: [PATCH 26/30] Resolve ruff warnings --- bpod_rig/calibration/liquid/populate_examples.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bpod_rig/calibration/liquid/populate_examples.py b/bpod_rig/calibration/liquid/populate_examples.py index fc80238..3a5b566 100644 --- a/bpod_rig/calibration/liquid/populate_examples.py +++ b/bpod_rig/calibration/liquid/populate_examples.py @@ -45,15 +45,14 @@ def create_default_json() -> str: add_dummy_measurements(dummy) jsontext = dummy.to_json(machineid=None) # replace modification datetime with a fixed date for consistency - jsontext = jsontext.replace( + return jsontext.replace( dummy.metadata.modification_datetime.isoformat(timespec="seconds"), "2000-01-01T00:00:00", ) - return jsontext def main(): - """Create the example liquid calibration JSON with 8 valves and dummy measurements.""" + """Create the example liquid calibration JSON in the example folder.""" example_json = create_default_json() example_path = Path(example_folder.__path__[0]) / "LiquidCalibration.json" example_path.write_text(example_json) From 8ed1e9aab1e17e86f47fce5a22ca3147a2989b60 Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:28:02 +1000 Subject: [PATCH 27/30] Resolve ruff warnings --- tests/bpod_rig/calibration/liquid/test_models.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/bpod_rig/calibration/liquid/test_models.py b/tests/bpod_rig/calibration/liquid/test_models.py index 0b6e62b..5bb1a42 100644 --- a/tests/bpod_rig/calibration/liquid/test_models.py +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -12,7 +12,7 @@ def setUp(self): def test_init_alias(self): # Test that alias works correctly # type checker may complain about this because of pydantic aliasing - self.assertEqual(ValveData(name="attrname").name, "attrname") + self.assertEqual(ValveData(name="attrname").name, "attrname") # noqa: Pydantic's validation aliases raises Unexpected Argument self.assertEqual(ValveData(ValveName="aliasname").name, "aliasname") def test_add_measurement(self): @@ -101,8 +101,13 @@ class TestGen2CalibrationFile(unittest.TestCase): @classmethod def setUpClass(cls): super(TestGen2CalibrationFile, cls).setUpClass() - url = "https://raw.githubusercontent.com/sanworks/Bpod_Gen2/refs/heads/develop/Examples/Example%20Calibration%20Files/LiquidCalibration.json" - cls.jsontext = urllib.request.urlopen(url).read().decode() + cls.jsontext = ( + urllib.request.urlopen( + "https://raw.githubusercontent.com/sanworks/Bpod_Gen2/refs/heads/develop/Examples/Example%20Calibration%20Files/LiquidCalibration.json" + ) + .read() + .decode() + ) def test_load_calibration_file(self): loaded_valves = ValveDataManager.model_validate_json(self.jsontext) From ce91201645563406d9a052541e861d4adbca3fff Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:35:19 +1000 Subject: [PATCH 28/30] Resolve ruff warnings --- .../calibration/liquid/pending_calibration.py | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/bpod_rig/calibration/liquid/pending_calibration.py b/bpod_rig/calibration/liquid/pending_calibration.py index 5b40164..04371ea 100644 --- a/bpod_rig/calibration/liquid/pending_calibration.py +++ b/bpod_rig/calibration/liquid/pending_calibration.py @@ -17,7 +17,8 @@ class PendingValve(ValveData): valve_controller: str = Field( default="unknown", - description="The controlling hardware of the valve (e.g., 'StateMachine', 'PortArray').", + description="The controlling hardware of the valve " + "(e.g., 'StateMachine', 'PortArray').", ) def __init__(self, **data): @@ -43,34 +44,33 @@ def action(self, actiontype: str) -> dict[str, str | list[str | int] | bool]: Action definition for valve that can be fed to StateMachine.add_state(**, output_actions=) """ - assert actiontype in ["open", "close"] + if actiontype not in ["open", "close"]: + raise ValueError("Action type not recognised.") if self.valve_controller == "StateMachine": - valve_id = int(self.name[5:]) if actiontype == "open": # matlab code # ValvePhysicalAddress = 2.^(0:7); # ValveAddress = ValvePhysicalAddress(valveID) - valve_address = 2 ** (valve_id - 1) + # valve_id = int(self.name[5:]) + # valve_address = 2 ** (valve_id - 1) # return {'ValveState': valve_address} return {self.name: True} - elif actiontype == "close": + + if actiontype == "close": raise ValueError("StateMachine valves do not have a close action.") - else: - raise ValueError("Action type not recognised.") - elif self.valve_controller == "PortArray": + raise ValueError("Action type not recognised.") + if self.valve_controller == "PortArray": target_module = self.name[0:3] portnumber = self.name[4] if actiontype == "open": # TODO: serial message byte parsing hasn't been checked action_serial = ["V", portnumber, 1] return {target_module: action_serial} - elif actiontype == "close": + if actiontype == "close": action_serial = ["V", portnumber, 0] return {target_module: action_serial} - else: - raise ValueError("Action type not recognised") - else: - raise ValueError("Valve controller not recognised.") + raise ValueError("Action type not recognised") + raise ValueError("Valve controller not recognised.") @property def is_pending(self) -> bool: @@ -144,7 +144,8 @@ def complete_measurement( amount : float Quantity (mg or mL) released from the duration. """ - assert duration in self.get_pending(valvename) + if duration not in self.get_pending(valvename): + raise ValueError("Duration not found in existing pending values.") self.remove_pending(valvename, duration) self._valvemanager.get_valve(valvename).add_measurement(duration, amount) logger.debug("Completed measurement for %s %s ms", valvename, duration) @@ -167,7 +168,7 @@ def build_statemachine( The state machine is constructed using object parameters: - n_pulses: number of pulses to deliver. - - pulse_interval: Duration (s) between one valve's end of pulse and another's start. + - pulse_interval: Duration (s) between one valve's pulse end and next's start. - pulse_set_pause: Duration (s) between a set of pulses. The amount of liquid dispensed from the state machine @@ -176,7 +177,8 @@ def build_statemachine( Parameters ---------- test_set : dict | None, optional - Names of valves and durations to build state machine for. If not provided, test all valves that have pending durations using their oldest duration. + Names of valves and durations to build state machine for. If not provided, + test all valves that have pending durations using their oldest duration. Returns ------- tuple[StateMachine, dict[str, float]] @@ -223,7 +225,8 @@ def add_valve_states( duration : float Duration (ms) of valve open. next_valve : PendingValve | float - The next valve object in the cycle, or the float for the pause duration (s) between pulse sets. + The next valve object in the cycle, or the float for the pause duration (s) + between pulse sets. interval_delay_duration : float """ last_valve = isinstance(next_valve, float) @@ -231,7 +234,8 @@ def add_valve_states( pulse_set_pause_duration = next_valve next_state = "PulseSetPause" else: - assert isinstance(next_valve, PendingValve) + if not isinstance(next_valve, PendingValve): + raise ValueError("Next valve must be a PendingValve or float.") pulse_set_pause_duration = None next_state = f"Pulse{next_valve.name}" @@ -302,7 +306,8 @@ def run_calibration( Returns ------- dict[str, float] - Dictionary of valve names and the duration (ms) of the pulse added to the state machine. + Dictionary of valve names and the duration (ms) of the pulse added to the + state machine. """ # Build the state machine @@ -318,7 +323,7 @@ def run_calibration( bpodsystem.send_state_machine(fsm) logger.debug("Running calibration state machine.") # Run the state machine - for trial in range(pending_manager.n_pulses): + for _ in range(pending_manager.n_pulses): bpodsystem.run_state_machine() return test_set From f0dac4d2bdeefb0ae496551ab37c2908d396345e Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 19 Sep 2025 10:37:40 +1000 Subject: [PATCH 29/30] Resolve ruff warnings --- tests/bpod_rig/calibration/liquid/test_pending_calibration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index e100926..04dd15b 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -18,7 +18,7 @@ def test_valve_controller_assignment(self): def test_unrecognied_controller_assignment(self): with self.assertRaises(ValueError): - valve1 = pending_calibration.PendingValve(ValveName="HeartValve") + pending_calibration.PendingValve(ValveName="HeartValve") def test_ispending(self): valve = pending_calibration.PendingValve(ValveName="Valve1") From 8181af6bb53dfe30c40d30f0b3633a3000c8b2cf Mon Sep 17 00:00:00 2001 From: George Stuyt <15712782+ogeesan@users.noreply.github.com> Date: Fri, 3 Oct 2025 01:55:13 +1000 Subject: [PATCH 30/30] Convert unittest to pytest In anticipation of change to dev branch to use pytest --- .../calibration/liquid/test_models.py | 92 ++++++++++--------- .../liquid/test_pending_calibration.py | 52 ++++++----- .../bpod_rig/calibration/liquid/test_utils.py | 37 ++++---- 3 files changed, 94 insertions(+), 87 deletions(-) diff --git a/tests/bpod_rig/calibration/liquid/test_models.py b/tests/bpod_rig/calibration/liquid/test_models.py index 5bb1a42..805ceaf 100644 --- a/tests/bpod_rig/calibration/liquid/test_models.py +++ b/tests/bpod_rig/calibration/liquid/test_models.py @@ -1,71 +1,73 @@ -import unittest +import pytest import urllib.request from bpod_rig.calibration.liquid import utils, populate_examples from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager -class TestValveDataClass(unittest.TestCase): - def setUp(self): +class TestValveDataClass: + @pytest.fixture(autouse=True) + def setup_method(self): self.valve = ValveData(ValveName="Test Valve") def test_init_alias(self): # Test that alias works correctly # type checker may complain about this because of pydantic aliasing - self.assertEqual(ValveData(name="attrname").name, "attrname") # noqa: Pydantic's validation aliases raises Unexpected Argument - self.assertEqual(ValveData(ValveName="aliasname").name, "aliasname") + assert ValveData(name="attrname").name == "attrname" # noqa: Pydantic's validation aliases raises Unexpected Argument + assert ValveData(ValveName="aliasname").name == "aliasname" def test_add_measurement(self): self.valve.add_measurement(10, 1.0) - self.assertEqual(len(self.valve.amounts), 1) - self.assertEqual(len(self.valve.durations), 1) + assert len(self.valve.amounts) == 1 + assert len(self.valve.durations) == 1 def test_get_valve_time(self): self.valve.add_measurement(10, 1.0) self.valve.add_measurement(20, 2.0) duration = self.valve.get_valve_time(1.0) - self.assertAlmostEqual(duration, 10) + assert duration == pytest.approx(10) def test_remove_measurement_by_index(self): self.valve.add_measurement(10, 1.0) self.valve.remove_measurement(0) - self.assertEqual(len(self.valve.amounts), 0) + assert len(self.valve.amounts) == 0 def test_remove_measurement_by_duration(self): self.valve.add_measurement(10, 1.0) self.valve.remove_measurement(10, method="duration") - self.assertEqual(len(self.valve.amounts), 0) + assert len(self.valve.amounts) == 0 def test_remove_measurement_invalid_index(self): self.valve.add_measurement(10, 1.0) self.valve.add_measurement(20, 2.0) self.valve.remove_measurement(0) - with self.assertRaises(IndexError): + with pytest.raises(IndexError): self.valve.remove_measurement(5) def test_remove_measurement_invalid_duration(self): self.valve.add_measurement(10, 1.0) self.valve.add_measurement(20, 2.0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.valve.remove_measurement(15, method="duration") def test_remove_measurement_duplicate_duration(self): self.valve.add_measurement(10, 1.0) self.valve.add_measurement(10, 2.0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.valve.remove_measurement(10, method="duration") def test_no_coeffs_with_insufficient_data(self): self.valve.add_measurement(10, 1.0) - self.assertEqual(len(self.valve.coeffs), 0) - with self.assertRaises(ValueError): + assert len(self.valve.coeffs) == 0 + with pytest.raises(ValueError): self.valve.get_valve_time(1.0) self.valve.add_measurement(20, 2.0) - self.assertIsNotNone(self.valve.coeffs) + assert self.valve.coeffs is not None -class TestValveDataManagerClass(unittest.TestCase): - def setUp(self): +class TestValveDataManagerClass: + @pytest.fixture(autouse=True) + def setup_method(self): self.manager = ValveDataManager() self.manager.create_valve("Test Valve") dummy = utils.create_empty_valve_data_manager() @@ -73,35 +75,34 @@ def setUp(self): self.dummy = dummy def test_create_valve(self): - self.assertEqual(self.manager.n_valves, 1) - self.assertIn("Test Valve", self.manager.valve_names) + assert self.manager.n_valves == 1 + assert "Test Valve" in self.manager.valve_names def test_get_valve(self): valve = self.manager.get_valve("Test Valve") - self.assertIsInstance(valve, ValveData) + assert isinstance(valve, ValveData) def test_measurements(self): - """Test that the default values are behaving as expeted.""" + """Test that the default values are behaving as expected.""" liquid_amount = 15 valve = "Valve1" expected_duration = 0.0758261 * 1000 duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) - self.assertAlmostEqual(duration, expected_duration, places=3) + assert duration == pytest.approx(expected_duration, abs=0.001) valve = "Valve3" expected_duration = 0.1102557 * 1000 duration = self.dummy.get_valve(valve).get_valve_time(liquid_amount) - self.assertAlmostEqual(duration, expected_duration, places=3) + assert duration == pytest.approx(expected_duration, abs=0.001) -class TestGen2CalibrationFile(unittest.TestCase): +class TestGen2CalibrationFile: """Test loading a calibration file from the Bpod_Gen2 (MATLAB) repo.""" - @classmethod - def setUpClass(cls): - super(TestGen2CalibrationFile, cls).setUpClass() - cls.jsontext = ( + @pytest.fixture(scope="class") + def jsontext(self): + return ( urllib.request.urlopen( "https://raw.githubusercontent.com/sanworks/Bpod_Gen2/refs/heads/develop/Examples/Example%20Calibration%20Files/LiquidCalibration.json" ) @@ -109,32 +110,33 @@ def setUpClass(cls): .decode() ) - def test_load_calibration_file(self): - loaded_valves = ValveDataManager.model_validate_json(self.jsontext) - self.assertIsInstance(loaded_valves, ValveDataManager) - self.assertGreaterEqual(loaded_valves.n_valves, 1) - self.assertIn("Valve1", loaded_valves.valve_names) + def test_load_calibration_file(self, jsontext): + loaded_valves = ValveDataManager.model_validate_json(jsontext) + assert isinstance(loaded_valves, ValveDataManager) + assert loaded_valves.n_valves >= 1 + assert "Valve1" in loaded_valves.valve_names - def test_valve_modtime(self): - valvemanager = ValveDataManager.model_validate_json(self.jsontext) - self.assertFalse(utils.check_valvemanager_user_updated(valvemanager)) + def test_valve_modtime(self, jsontext): + valvemanager = ValveDataManager.model_validate_json(jsontext) + assert not utils.check_valvemanager_user_updated(valvemanager) newvalve = ValveDataManager.model_validate_json(valvemanager.to_json()) - self.assertTrue(utils.check_valvemanager_user_updated(newvalve)) + assert utils.check_valvemanager_user_updated(newvalve) -class TestValveDataManagerJSON(unittest.TestCase): - def setUp(self): +class TestValveDataManagerJSON: + @pytest.fixture(autouse=True) + def setup_method(self): self.manager = utils.create_empty_valve_data_manager() populate_examples.add_dummy_measurements(self.manager) self.json_str = self.manager.to_json() self.loaded_manager = ValveDataManager.model_validate_json(self.json_str) def test_json_round_trip(self): - self.assertEqual(self.manager.n_valves, self.loaded_manager.n_valves) - self.assertEqual(self.manager.valve_names, self.loaded_manager.valve_names) + assert self.manager.n_valves == self.loaded_manager.n_valves + assert self.manager.valve_names == self.loaded_manager.valve_names for valve_name in self.manager.valve_names: original_valve = self.manager.get_valve(valve_name) loaded_valve = self.loaded_manager.get_valve(valve_name) - self.assertEqual(original_valve.amounts, loaded_valve.amounts) - self.assertEqual(original_valve.durations, loaded_valve.durations) - self.assertEqual(original_valve.coeffs, loaded_valve.coeffs) + assert original_valve.amounts == loaded_valve.amounts + assert original_valve.durations == loaded_valve.durations + assert original_valve.coeffs == loaded_valve.coeffs diff --git a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py index 04dd15b..6078d88 100644 --- a/tests/bpod_rig/calibration/liquid/test_pending_calibration.py +++ b/tests/bpod_rig/calibration/liquid/test_pending_calibration.py @@ -1,4 +1,4 @@ -import unittest +import pytest from bpod_core.fsm import StateMachine @@ -6,84 +6,86 @@ from bpod_rig.calibration.liquid import utils, populate_examples -class TestPendingValve(unittest.TestCase): - def setUp(self): +class TestPendingValve: + @pytest.fixture(autouse=True) + def setup_method(self): self.pending_valve = pending_calibration.PendingValve(ValveName="Valve3") def test_valve_controller_assignment(self): valve1 = pending_calibration.PendingValve(ValveName="PA1") valve2 = pending_calibration.PendingValve(ValveName="Valve1") - self.assertEqual(valve1.valve_controller, "PortArray") - self.assertEqual(valve2.valve_controller, "StateMachine") + assert valve1.valve_controller == "PortArray" + assert valve2.valve_controller == "StateMachine" def test_unrecognied_controller_assignment(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): pending_calibration.PendingValve(ValveName="HeartValve") def test_ispending(self): valve = pending_calibration.PendingValve(ValveName="Valve1") - self.assertFalse(valve.is_pending) + assert not valve.is_pending valve = pending_calibration.PendingValve(ValveName="Valve1") valve.pending_durations.append(20) - self.assertTrue(valve.is_pending) + assert valve.is_pending -class TestPendingMeasurementsManager(unittest.TestCase): - def setUp(self): +class TestPendingMeasurementsManager: + @pytest.fixture(autouse=True) + def setup_method(self): self.valvemanager = utils.create_empty_valve_data_manager() populate_examples.add_dummy_measurements(self.valvemanager) self.manager = pending_calibration.PendingMeasurementsManager(self.valvemanager) def test_get_pending_initial(self): - self.assertEqual(self.manager.get_pending("Valve1"), []) - self.assertEqual(self.manager.get_pending("Valve2"), []) + assert self.manager.get_pending("Valve1") == [] + assert self.manager.get_pending("Valve2") == [] def test_get_pending_invalid_valve(self): - with self.assertRaises(KeyError): + with pytest.raises(KeyError): self.manager.get_pending("InvalidValve") def test_add_pending(self): self.manager.add_pending("Valve1", 10.0) # valve with values - self.assertIn(10.0, self.manager.get_pending("Valve1")) + assert 10.0 in self.manager.get_pending("Valve1") self.manager.add_pending("Valve2", 10.0) # valve without values - self.assertIn(10.0, self.manager.get_pending("Valve2")) + assert 10.0 in self.manager.get_pending("Valve2") def test_add_pending_duplicate(self): self.manager.add_pending("Valve1", 10.0) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.manager.add_pending("Valve1", 10.0) def test_remove_pending(self): self.manager.add_pending("Valve1", 10.0) self.manager.remove_pending("Valve1", 10.0) - self.assertNotIn(10.0, self.manager.get_pending("Valve1")) + assert 10.0 not in self.manager.get_pending("Valve1") def test_remove_pending_nonexistent(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.manager.remove_pending("Valve1", 10.0) def test_complete_measurement(self): self.manager.add_pending("Valve2", 10.0) self.manager.complete_measurement("Valve2", 10.0, 1.5) valve = self.valvemanager.get_valve("Valve2") - self.assertIn(10.0, valve.durations) - self.assertIn(1.5, valve.amounts) - self.assertNotIn(10.0, self.manager.get_pending("Valve2")) + assert 10.0 in valve.durations + assert 1.5 in valve.amounts + assert 10.0 not in self.manager.get_pending("Valve2") def test_build_testset(self): self.manager.add_pending("Valve1", 10) self.manager.add_pending("Valve1", 5) self.manager.add_pending("Valve2", 20) - self.assertEqual(self.manager.build_test_set(), {"Valve1": 10, "Valve2": 20}) + assert self.manager.build_test_set() == {"Valve1": 10, "Valve2": 20} def test_build_statemachine(self): self.manager.add_pending("Valve1", 10) self.manager.add_pending("Valve2", 15) statemachine, test_set = self.manager.build_statemachine() - self.assertIsInstance(statemachine, StateMachine) - self.assertEqual(len(test_set.keys()), 2) + assert isinstance(statemachine, StateMachine) + assert len(test_set.keys()) == 2 # expect 2 states of open and delay, with final pulse set delay n_states = 2 + 2 + 1 - self.assertEqual(len(statemachine.states), n_states) + assert len(statemachine.states) == n_states # TODO: test build, send, and run of state machine with emulator bpod diff --git a/tests/bpod_rig/calibration/liquid/test_utils.py b/tests/bpod_rig/calibration/liquid/test_utils.py index e02d8fc..afd62d2 100644 --- a/tests/bpod_rig/calibration/liquid/test_utils.py +++ b/tests/bpod_rig/calibration/liquid/test_utils.py @@ -1,12 +1,13 @@ import datetime -import unittest +import pytest from bpod_rig.calibration.liquid import utils from bpod_rig.calibration.liquid.models import ValveData, ValveDataManager -class TestSuggestDuration(unittest.TestCase): - def setUp(self): +class TestSuggestDuration: + @pytest.fixture(autouse=True) + def setup_method(self): self.valve = ValveData(ValveName="Test Valve") self.range_low = 2 self.range_high = 10 @@ -15,53 +16,55 @@ def setUp(self): ) def test_no_measures(self): - self.assertEqual(self.suggest_duration(), 48) + assert self.suggest_duration() == 48 def test_one_measure(self): self.valve.add_measurement(57, 7.5) - self.assertEqual(self.suggest_duration(), 36) + assert self.suggest_duration() == 36 def test_two_measures(self): self.valve.add_measurement(22, 2) self.valve.add_measurement(66, 9.5) - self.assertAlmostEqual(self.suggest_duration(), 44.0, places=3) + assert self.suggest_duration() == pytest.approx(44.0, abs=0.001) -class TestCheckValveManagerUserUpdate(unittest.TestCase): - def setUp(self): +class TestCheckValveManagerUserUpdate: + @pytest.fixture(autouse=True) + def setup_method(self): self.manager = ValveDataManager() def test_earlier(self): self.manager.metadata.modification_datetime = datetime.datetime(1999, 1, 1) - self.assertFalse(utils.check_valvemanager_user_updated(self.manager)) + assert not utils.check_valvemanager_user_updated(self.manager) def test_later(self): self.manager.create_valve("Valve1") self.manager.get_valve("Valve1").add_measurement(15, 20) - self.assertTrue(utils.check_valvemanager_user_updated(self.manager)) + assert utils.check_valvemanager_user_updated(self.manager) def test_equal(self): self.manager.metadata.modification_datetime = datetime.datetime(2000, 1, 1) - self.assertFalse(utils.check_valvemanager_user_updated(self.manager)) + assert not utils.check_valvemanager_user_updated(self.manager) def test_dummy(self): from bpod_rig.calibration.liquid.populate_examples import create_default_json manager = ValveDataManager.model_validate_json(create_default_json()) - self.assertFalse(utils.check_valvemanager_user_updated(manager)) + assert not utils.check_valvemanager_user_updated(manager) -class TestCheckCOM(unittest.TestCase): - def setUp(self): +class TestCheckCOM: + @pytest.fixture(autouse=True) + def setup_method(self): self.manager = ValveDataManager() self.manager.metadata.COM = "COM3" def test_matching(self): - self.assertEqual(utils.check_com(self.manager, "COM3"), "yes") + assert utils.check_com(self.manager, "COM3") == "yes" def test_no_match(self): - self.assertEqual(utils.check_com(self.manager, "COM4"), "no") + assert utils.check_com(self.manager, "COM4") == "no" def test_unknown(self): self.manager.metadata.COM = "" - self.assertEqual(utils.check_com(self.manager, "COM3"), "unknown") + assert utils.check_com(self.manager, "COM3") == "unknown"