diff --git a/bpod_core/bpod.py b/bpod_core/bpod/__init__.py similarity index 90% rename from bpod_core/bpod.py rename to bpod_core/bpod/__init__.py index abad485d..52761439 100644 --- a/bpod_core/bpod.py +++ b/bpod_core/bpod/__init__.py @@ -1,155 +1,57 @@ """Module for interfacing with the Bpod Finite State Machine.""" +import contextlib import logging import re import struct import threading import traceback import weakref -from abc import ABC, abstractmethod from collections.abc import Callable, Iterator from dataclasses import dataclass, field from types import TracebackType from typing import Any, NamedTuple, cast -import msgspec import numpy as np from numpy.typing import NDArray -from platformdirs import user_config_path from pydantic import validate_call from serial import SerialException from bpod_core import __version__ as bpod_core_version +from bpod_core.bpod.abc import AbstractBpod +from bpod_core.bpod.constants import ( + CHANNEL_TYPES_INPUT, + CHANNEL_TYPES_OUTPUT, + CONFIG_PATH, + DISCOVERY_TIMEOUT, + MACHINE_TYPES, + MAX_BPOD_HW_VERSION, + MIN_BPOD_FW_VERSION, + MIN_BPOD_HW_VERSION, + N_SERIAL_EVENTS_DEFAULT, + PIDS_BPOD, + VALID_OPERATORS, + VIDS_BPOD, +) +from bpod_core.bpod.structs import BpodInfo, HardwareConfiguration, VersionInfo from bpod_core.com import ( ExtendedSerial, - USBSerialDevice, + SerialDevice, find_ports, verify_serial_discovery, ) -from bpod_core.constants import STRUCT_UINT32, VID_TEENSY, PIDsTeensy +from bpod_core.constants import STRUCT_UINT32, TeensyPID from bpod_core.fsm import StateMachine from bpod_core.ipc import DualChannelClient, DualChannelHost from bpod_core.misc import ( - DocstringInheritanceMixin, SettingsDict, extend_packed, suggest_similar, ) -VIDS_BPOD = [VID_TEENSY] -"""Vendor IDs of supported Bpod devices""" - -PIDS_BPOD = [PIDsTeensy.SERIAL, PIDsTeensy.DUAL_SERIAL, PIDsTeensy.TRIPLE_SERIAL] -"""List of Product IDs of supported Bpod devices""" - -MIN_BPOD_FW_VERSION = (23, 0) -"""minimum supported firmware version (major, minor)""" - -MIN_BPOD_HW_VERSION = 3 -"""minimum supported hardware version""" - -MAX_BPOD_HW_VERSION = 4 -"""maximum supported hardware version""" - -CHANNEL_TYPES_INPUT = { - b'U': 'Serial', - b'X': 'SoftCode', - b'Z': 'SoftCodeApp', - b'F': 'Flex', - b'D': 'Digital', - b'B': 'BNC', - b'W': 'Wire', - b'P': 'Port', -} -CHANNEL_TYPES_OUTPUT = CHANNEL_TYPES_INPUT.copy() -CHANNEL_TYPES_OUTPUT.update({b'V': 'Valve', b'P': 'PWM'}) -N_SERIAL_EVENTS_DEFAULT = 15 -VALID_OPERATORS = {'exit', '>exit', '>back'} -MACHINE_TYPES = {3: 'r2.0-2.5', 4: '2+ r1.0'} -CONFIG_PATH = user_config_path('bpod-core', False) -DISCOVERY_TIMEOUT = 0.11 - logger = logging.getLogger(__name__) -class BpodSettings(msgspec.Struct): - """Settings for a specific Bpod device.""" - - serial_number: str - """Serial number of the device.""" - name: str = '' - """User-defined name of the device.""" - location: str = '' - """User-defined location of the device.""" - zmq_port_pub: int | None = None - """Port number for the ZeroMQ PUB service.""" - zmq_port_rep: int | None = None - """Port number for the ZeroMQ REP service.""" - - -class BpodInfo(msgspec.Struct): - """Information about a specific Bpod device.""" - - serial_number: str - """Serial number of the device.""" - port: str | None = None - """Port on which the device is connected.""" - name: str = '' - """User-defined name of the device.""" - location: str = '' - """User-defined location of the device.""" - zmq_pub: str | None = None - """ZeroMQ PUB service address.""" - zmq_rep: str | None = None - """ZeroMQ REP service address.""" - - -class VersionInfo(msgspec.Struct, frozen=True): - """Data structure representing various version information.""" - - firmware: tuple[int, int] - """Firmware version (major, minor)""" - machine: int - """Machine type (numerical)""" - machine_str: str - """Machine type (string)""" - pcb: int | None - """PCB revision, if applicable""" - bpod_core: str - """bpod-core version""" - - -class HardwareConfiguration(msgspec.Struct, frozen=True): - """Represents the Bpod's on-board hardware configuration.""" - - max_states: int - """Maximum number of supported states in a single state machine description.""" - cycle_period: int - """Period of the state machine's refresh cycle during a trial in microseconds.""" - max_serial_events: int - """Maximum number of behavior events allocatable among connected modules.""" - max_bytes_per_serial_message: int - """Maximum number of bytes allowed per serial message.""" - n_global_timers: int - """Number of global timers supported.""" - n_global_counters: int - """Number of global counters supported.""" - n_conditions: int - """Number of condition-events supported.""" - n_inputs: int - """Number of input channels.""" - input_description: bytes - """Array indicating the state machine's onboard input channel types.""" - n_outputs: int - """Number of channels in the state machine's output channel description array.""" - output_description: bytes - """Array indicating the state machine's onboard output channel types.""" - cycle_frequency: int - """Frequency of the state machine's refresh cycle during a trial in Hertz.""" - n_modules: int - """Number of modules supported by the state machine.""" - - class BpodError(Exception): """ Exception class for Bpod-related errors. @@ -224,9 +126,10 @@ def run(self) -> None: use_back_op = self._use_back_op event_names = self._event_names - # create buffers for repeated serial reads + # create buffers / memoryview for repeated serial reads opcode_buf = bytearray(2) # buffer for opcodes event_data_buf = bytearray(259) # max 255 events + 4 bytes for n_cycles + event_data_view = memoryview(event_data_buf) # confirm the state machine if self._confirm_fsm: @@ -236,6 +139,7 @@ def run(self) -> None: # read the start time of the state machine t0 = serial.read_uint64() + # TODO: get time.perf_counter() logger.debug('%d µs: Starting state machine #%d', t0, index) logger.debug('%d µs: State %d', t0, current_state) # TODO: handle start of state machine @@ -243,14 +147,26 @@ def run(self) -> None: # enter the reading loop while not self._stop_event.is_set(): + # TODO: thread is blocked by readinto() + # + # Options: + # a) use serial timeout, + # b) while serial.in_waiting() < 2: + # if self._stop_event.wait(timeout=0.01): + # break + # c) thread entirely controlled by bpod (no _stop_event required) + # + # readinto returns number of bytes read, so we can use it to check for + # timeout + # read the next two opcodes serial.readinto(opcode_buf) + # TODO: get time.perf_counter() opcode, param = opcode_buf if opcode == 1: # handle events # read `param` event bytes + 4 bytes for n_cycles (uInt32) - event_data_view = memoryview(event_data_buf)[: param + 4] - serial.readinto(event_data_view) + serial.readinto(event_data_view[: param + 4]) # unpack the number of cycles, calculate the event's timestamp (n_cycles,) = STRUCT_UINT32.unpack_from(event_data_view, param) @@ -302,54 +218,9 @@ def run(self) -> None: # TODO: handle end of state machine -class AbstractBpod(DocstringInheritanceMixin, ABC): - """Abstract base class for Bpod objects.""" - - _version: VersionInfo - _hardware: HardwareConfiguration - _serial_number: str - - @property - @abstractmethod - def name(self) -> str | None: - """The Bpod's user-defined name, or :obj:`None` if not set.""" - - @property - @abstractmethod - def location(self) -> str | None: - """The Bpod's user-defined location, or :obj:`None` if not set.""" - - @property - def version(self) -> VersionInfo: - """Version information of the Bpod's firmware and hardware.""" - return self._version - - @property - def serial_number(self) -> str: - """The Bpod's unique serial number.""" - return self._serial_number - - @abstractmethod - def set_status_led(self, enabled: bool) -> bool: - """ - Enable or disable the Bpod's status LED. - - Parameters - ---------- - enabled : bool - True to enable the status LED, False to disable. - - Returns - ------- - bool - True if the operation was successful, False otherwise. - """ - - -class Bpod(USBSerialDevice, AbstractBpod): +class Bpod(SerialDevice, AbstractBpod): """Class for interfacing with a Bpod Finite State Machine.""" - _device_type = 'Bpod Finite State Machine' _settings: SettingsDict _fsm_thread: FSMThread | None = None _zmq_service: DualChannelHost @@ -379,9 +250,11 @@ class Bpod(USBSerialDevice, AbstractBpod): @validate_call def __init__( - self, port: str | None = None, serial_number: str | None = None + self, + port: str | None = None, + serial_number: str | None = None, + remote: bool = False, ) -> None: - self._finalizer = weakref.finalize(self, self._finalize) logger.info('bpod_core %s', bpod_core_version) self._settings = SettingsDict(CONFIG_PATH / 'settings.json') @@ -394,7 +267,7 @@ def __init__( # identify Bpod by port or serial number, open connection bpod_port, _ = self._identify_bpod(port, serial_number) - super().__init__(bpod_port) + super().__init__(port=bpod_port, open_connection=True) self._serial_number = self._port_info.serial_number or 'unknown' # get firmware version and machine type; enforce version requirements @@ -413,7 +286,15 @@ def __init__( self.update_modules() # start ZeroMQ service - self._start_zmq() + self._start_zmq(use_zeroconf=remote) + + # register destructor + self._finalizer = weakref.finalize( + self, + Bpod._finalize, + self._serial, + self._zmq_service, + ) # log hardware information logger.info( @@ -428,10 +309,12 @@ def __init__( self.version.pcb, ) - @property - def serial0(self) -> ExtendedSerial: - """Primary serial device for communication with the Bpod.""" - return self._serial + @staticmethod + def _finalize(serial: ExtendedSerial, zmq_service: DualChannelHost) -> None: + with contextlib.suppress(SerialException): + Bpod._request_disconnect(serial) + serial.close() + zmq_service.close() def __exit__( self, @@ -440,7 +323,8 @@ def __exit__( exc_tb: TracebackType | None, ) -> None: """Exit context and close connection.""" - super().__exit__(exc_type, exc_val, exc_tb) + self._finalizer.detach() + self.close() self._stop_zmq() def open(self) -> None: @@ -458,13 +342,30 @@ def open(self) -> None: self._handshake() def close(self) -> None: - """Close the connection to the Bpod.""" + """ + Close the connection to the Bpod. + + Raises + ------ + SerialException + If the port could not be closed. + """ self.stop_state_machine() + if hasattr(self, 'serial0'): + self._request_disconnect(self.serial0) super().close() - def _finalize(self) -> None: - self.close() - self._stop_zmq() + @staticmethod + def _request_disconnect(serial: ExtendedSerial) -> None: + """Send a close request to the Bpod.""" + if getattr(serial, 'is_open', False): + logger.debug('Sending close request to Bpod') + serial.write(b'Z') + + @property + def serial0(self) -> ExtendedSerial: + """Primary serial device for communication with the Bpod.""" + return self._serial def _zmq_handler(self, message: dict[str, Any]) -> dict[str, Any]: msg_type = message.get('type', 'unknown') @@ -500,13 +401,13 @@ def _zmq_handler(self, message: dict[str, Any]) -> dict[str, Any]: } return response - def _start_zmq(self) -> None: + def _start_zmq(self, use_zeroconf: bool) -> None: port_pub = self._get_setting(['devices', self._serial_number, 'port_pub']) port_rep = self._get_setting(['devices', self._serial_number, 'port_rep']) self._zmq_service = DualChannelHost( service_name=self.name if self.name else f'bpod_{self._serial_number}', - service_type='_bpod', - txt_record={ + service_type='bpod', + properties={ 'description': f'Bpod Finite State Machine {self.version.machine_str}', 'serial': self._serial_number or '', 'name': self.name or '', @@ -515,6 +416,7 @@ def _start_zmq(self) -> None: 'core': bpod_core_version, }, event_handler=self._zmq_handler, + remote=use_zeroconf, port_pub=cast('int | None', port_pub), port_rep=cast('int | None', port_rep), ) @@ -668,7 +570,7 @@ def _detect_additional_serial_ports(self) -> None: # First, assemble a list of candidate ports candidate_ports = find_ports( vid=VIDS_BPOD, - pid=[PIDsTeensy.DUAL_SERIAL, PIDsTeensy.TRIPLE_SERIAL], + pid=[TeensyPID.DUAL_SERIAL, TeensyPID.TRIPLE_SERIAL], serial_number=self._serial_number, device=re.compile(rf'^(?!{re.escape(str(self.port))}$).*$'), ) @@ -744,6 +646,7 @@ def _set_enable_inputs(self) -> bool: return self.serial0.read(1) == b'\x01' def reset_session_clock(self) -> bool: + # TODO: Get timestamp / time.monotonic() / time.perf_counter() logger.debug('Resetting session clock') return self.serial0.verify(b'*') @@ -1616,7 +1519,7 @@ def __init__( try: self._zmq = DualChannelClient( - '_bpod._tcp.local.', + service_type='bpod', address=address, discovery_timeout=timeout, txt_properties=properties, diff --git a/bpod_core/bpod/abc.py b/bpod_core/bpod/abc.py new file mode 100644 index 00000000..e70b3c8b --- /dev/null +++ b/bpod_core/bpod/abc.py @@ -0,0 +1,50 @@ +"""Abstract base classes used by the bpod module.""" + +from abc import ABC, abstractmethod + +from bpod_core.bpod.structs import HardwareConfiguration, VersionInfo +from bpod_core.misc import DocstringInheritanceMixin + + +class AbstractBpod(DocstringInheritanceMixin, ABC): + """Abstract base class for Bpod objects.""" + + _version: VersionInfo + _hardware: HardwareConfiguration + _serial_number: str + + @property + @abstractmethod + def name(self) -> str | None: + """The Bpod's user-defined name, or :obj:`None` if not set.""" + + @property + @abstractmethod + def location(self) -> str | None: + """The Bpod's user-defined location, or :obj:`None` if not set.""" + + @property + def version(self) -> VersionInfo: + """Version information of the Bpod's firmware and hardware.""" + return self._version + + @property + def serial_number(self) -> str: + """The Bpod's unique serial number.""" + return self._serial_number + + @abstractmethod + def set_status_led(self, enabled: bool) -> bool: + """ + Enable or disable the Bpod's status LED. + + Parameters + ---------- + enabled : bool + True to enable the status LED, False to disable. + + Returns + ------- + bool + True if the operation was successful, False otherwise. + """ diff --git a/bpod_core/bpod/constants.py b/bpod_core/bpod/constants.py new file mode 100644 index 00000000..4f4c41e0 --- /dev/null +++ b/bpod_core/bpod/constants.py @@ -0,0 +1,36 @@ +"""Constants used by the bpod module.""" + +from bpod_core.constants import PLATFORMDIRS, VID_TEENSY, TeensyPID + +VIDS_BPOD = [VID_TEENSY] +"""Vendor IDs of supported Bpod devices""" + +PIDS_BPOD = [TeensyPID.SERIAL, TeensyPID.DUAL_SERIAL, TeensyPID.TRIPLE_SERIAL] +"""List of Product IDs of supported Bpod devices""" + +MIN_BPOD_FW_VERSION = (23, 0) +"""minimum supported firmware version (major, minor)""" + +MIN_BPOD_HW_VERSION = 3 +"""minimum supported hardware version""" + +MAX_BPOD_HW_VERSION = 4 +"""maximum supported hardware version""" + +CHANNEL_TYPES_INPUT = { + b'U': 'Serial', + b'X': 'SoftCode', + b'Z': 'SoftCodeApp', + b'F': 'Flex', + b'D': 'Digital', + b'B': 'BNC', + b'W': 'Wire', + b'P': 'Port', +} +CHANNEL_TYPES_OUTPUT = CHANNEL_TYPES_INPUT.copy() +CHANNEL_TYPES_OUTPUT.update({b'V': 'Valve', b'P': 'PWM'}) +N_SERIAL_EVENTS_DEFAULT = 15 +VALID_OPERATORS = {'exit', '>exit', '>back'} +MACHINE_TYPES = {3: 'r2.0-2.5', 4: '2+ r1.0'} +CONFIG_PATH = PLATFORMDIRS.user_config_path +DISCOVERY_TIMEOUT = 0.11 diff --git a/bpod_core/bpod/structs.py b/bpod_core/bpod/structs.py new file mode 100644 index 00000000..cb17af54 --- /dev/null +++ b/bpod_core/bpod/structs.py @@ -0,0 +1,81 @@ +"""Data structures used by the bpod module.""" + +import msgspec + + +class BpodSettings(msgspec.Struct): + """Settings for a specific Bpod device.""" + + serial_number: str + """Serial number of the device.""" + name: str = '' + """User-defined name of the device.""" + location: str = '' + """User-defined location of the device.""" + zmq_port_pub: int | None = None + """Port number for the ZeroMQ PUB service.""" + zmq_port_rep: int | None = None + """Port number for the ZeroMQ REP service.""" + + +class BpodInfo(msgspec.Struct): + """Information about a specific Bpod device.""" + + serial_number: str + """Serial number of the device.""" + port: str | None = None + """Port on which the device is connected.""" + name: str = '' + """User-defined name of the device.""" + location: str = '' + """User-defined location of the device.""" + zmq_pub: str | None = None + """ZeroMQ PUB service address.""" + zmq_rep: str | None = None + """ZeroMQ REP service address.""" + + +class VersionInfo(msgspec.Struct, frozen=True): + """Data structure representing various version information.""" + + firmware: tuple[int, int] + """Firmware version (major, minor)""" + machine: int + """Machine type (numerical)""" + machine_str: str + """Machine type (string)""" + pcb: int | None + """PCB revision, if applicable""" + bpod_core: str + """bpod-core version""" + + +class HardwareConfiguration(msgspec.Struct, frozen=True): + """Represents the Bpod's on-board hardware configuration.""" + + max_states: int + """Maximum number of supported states in a single state machine description.""" + cycle_period: int + """Period of the state machine's refresh cycle during a trial in microseconds.""" + max_serial_events: int + """Maximum number of behavior events allocatable among connected modules.""" + max_bytes_per_serial_message: int + """Maximum number of bytes allowed per serial message.""" + n_global_timers: int + """Number of global timers supported.""" + n_global_counters: int + """Number of global counters supported.""" + n_conditions: int + """Number of condition-events supported.""" + n_inputs: int + """Number of input channels.""" + input_description: bytes + """Array indicating the state machine's onboard input channel types.""" + n_outputs: int + """Number of channels in the state machine's output channel description array.""" + output_description: bytes + """Array indicating the state machine's onboard output channel types.""" + cycle_frequency: int + """Frequency of the state machine's refresh cycle during a trial in Hertz.""" + n_modules: int + """Number of modules supported by the state machine.""" diff --git a/bpod_core/com.py b/bpod_core/com.py index 0a034c28..559fd2fc 100644 --- a/bpod_core/com.py +++ b/bpod_core/com.py @@ -3,7 +3,9 @@ import logging import re import struct +import weakref from collections.abc import Callable, Sequence +from contextlib import AbstractContextManager from types import TracebackType from typing import Any, cast @@ -412,20 +414,38 @@ def verify_serial_discovery( return False -class USBSerialDevice: +def _close_serial_connection(serial: Serial, raise_errors: bool = False) -> None: + """Close a serial connection if open.""" + if not getattr(serial, 'is_open', False): + return + logger.debug('Closing connection to serial device on %s', serial.port) + try: + serial.close() + except Exception as e: + if not raise_errors: + return + raise SerialException( + f'Failed to close connection to serial device on {serial.port}' + ) from e + + +class SerialDevice(AbstractContextManager): """Class that interfaces with a USB serial device.""" _serial: ExtendedSerial - """The serial connection to the USB device.""" + """The serial connection to the device.""" _port_info: ListPortInfo """Information about the serial port associated with the device.""" - _device_type: str = 'serial device' - """The type of the USB device, e.g., 'Bpod'.""" - - def __init__(self, port: str, open_connection: bool = True, **kwargs: Any) -> None: - """Initialize the USB serial device. + def __init__( + self, + port: str, + open_connection: bool = True, + serial_device_name: str = 'serial_device', + **kwargs: Any, + ) -> None: + """Initialize the serial device. Parameters ---------- @@ -445,21 +465,13 @@ def __init__(self, port: str, open_connection: bool = True, **kwargs: Any) -> No self._port_info = next(p for p in comports() if p.device == port) except StopIteration as e: raise SerialException(f'Serial port not found: {port}') from e + self._serial_device_name = serial_device_name self._serial = ExtendedSerial() + weakref.finalize(self, _close_serial_connection, self._serial) self._serial.port = port if open_connection: self.open() - def __enter__(self) -> Self: - """Enter the context manager. - - Returns - ------- - Self - The device instance. - """ - return self - def __exit__( self, exc_type: type[BaseException] | None, @@ -479,7 +491,8 @@ def __exit__( exc_tb : TracebackType | None The traceback object, if any. """ - self.close() + if hasattr(self, '_serial'): + _close_serial_connection(self._serial) def open(self) -> None: """Open the serial connection. @@ -493,12 +506,12 @@ def open(self) -> None: """ if self._serial.is_open: return - logger.debug('Opening connection to %s on %s', self._device_type, self.port) + logger.debug('Opening connection to serial device on %s', self.port) try: self._serial.open() except Exception as e: raise SerialException( - f'Failed to open connection to {self._device_type} on {self.port}' + f'Failed to open connection to serial device on {self.port}' ) from e def close(self) -> None: @@ -511,15 +524,8 @@ def close(self) -> None: serial.SerialException If the connection cannot be closed. """ - if not self._serial.is_open: - return - logger.debug('Closing connection to %s on %s', self._device_type, self.port) - try: - self._serial.close() - except Exception as e: - raise SerialException( - f'Failed to close connection to {self._device_type} on {self.port}' - ) from e + if hasattr(self, '_serial'): + _close_serial_connection(self._serial, raise_errors=True) @property def port(self) -> str: diff --git a/bpod_core/constants.py b/bpod_core/constants.py index dde20cd2..599a2c21 100644 --- a/bpod_core/constants.py +++ b/bpod_core/constants.py @@ -1,7 +1,10 @@ """Constants and identifiers used throughout the package.""" +from enum import IntEnum from struct import Struct +import platformdirs + # pre-compiled structs for common data types STRUCT_BOOL = Struct('?') """Compiled struct representing a boolean value.""" @@ -30,13 +33,15 @@ VID_TEENSY: int = 0x16C0 """Vendor ID of Teensy microcontrollers.""" +PLATFORMDIRS = platformdirs.PlatformDirs(appname='bpod-core', appauthor=False) + -class PIDsTeensy: +class TeensyPID(IntEnum): """Product IDs of Teensy microcontrollers.""" - SERIAL: int = 0x0483 + SERIAL = 0x0483 """Product ID of Teensy microcontrollers with single USB serial port.""" - DUAL_SERIAL: int = 0x048B + DUAL_SERIAL = 0x048B """Product ID of Teensy microcontrollers with dual USB serial ports.""" - TRIPLE_SERIAL: int = 0x048C + TRIPLE_SERIAL = 0x048C """Product ID of Teensy microcontrollers with triple USB serial ports.""" diff --git a/bpod_core/ipc.py b/bpod_core/ipc.py index af0f054d..fa6d0579 100644 --- a/bpod_core/ipc.py +++ b/bpod_core/ipc.py @@ -1,21 +1,26 @@ """Inter-process Communication, service discovery and related.""" import contextlib +import json +import logging import os -import re import socket import sys import threading -import uuid import weakref -from abc import ABC, abstractmethod -from collections.abc import Callable +from abc import abstractmethod +from collections.abc import Callable, Iterator +from pathlib import Path from types import TracebackType from typing import Any, Literal, cast +from uuid import UUID, uuid4 import msgspec +import platformdirs import zmq -from typing_extensions import Self +from platformdirs import user_runtime_path +from psutil import pid_exists +from pydantic import UUID4, validate_call from zeroconf import ( InterfaceChoice, IPVersion, @@ -25,9 +30,15 @@ Zeroconf, ) -from bpod_core.com import logger from bpod_core.constants import IP_ANY, IP_LOOPBACK -from bpod_core.misc import convert_to_snake_case, get_local_ipv4 +from bpod_core.misc import ( + RE_NON_ALPHANUMERIC, + get_local_ipv4, + prune_empty_parent_directories, + to_snake_case, +) + +logger = logging.getLogger(__name__) class DualChannelMessage(msgspec.Struct, omit_defaults=True, array_like=True): @@ -35,78 +46,271 @@ class DualChannelMessage(msgspec.Struct, omit_defaults=True, array_like=True): data: Any | None = msgspec.field(default=None, name='D') # message data -class DualChannelBase(ABC): +class DualChannelHandshake(msgspec.Struct, kw_only=True): + ipc_pub_sub: str | None = None + ipc_req_rep: str | None = None + tcp_pub_sub: str | None = None + tcp_req_rep: str | None = None + + +class LocalServiceInfo(msgspec.Struct): + """Information about a locally advertised service.""" + + service_name: str + service_type: str + address: str + pid: int + uuid: UUID + properties: dict[str, str | None] + + +class LocalServiceAdvertisement(contextlib.AbstractContextManager): + """ + File-based local service advertisement for IPC discovery. + + Advertises a service by writing a JSON file to the user's runtime directory. + This provides a lightweight alternative to Zeroconf for discovering services + on the same machine. Stale advertisements (from dead processes) are automatically + cleaned up during discovery. + + The advertisement is automatically removed when the instance is garbage collected + or when `stop()` is called explicitly. + """ + + runtime_directory = user_runtime_path('LocalServiceAdvertisements') + """Directory where service advertisement files are stored.""" + + service_file: Path + """Path to the advertisement file.""" + + _closed = False + """Flag to prevent double-finalization.""" + + @validate_call + def __init__( + self, + service_name: str, + service_type: str, + address: str, + properties: dict[str, str | None] | None = None, + *, + pid: int | None = None, + uuid: UUID4 | None = None, + ) -> None: + """ + Create a local service advertisement. + + Parameters + ---------- + service_name + The name of the service being advertised (e.g., 'Bpod 3'). + service_type : str + The type of service being advertised (e.g., 'bpod'). + address : str + The address where the service can be reached (e.g., 'ipc:///tmp/foo.ipc'). + properties : dict, optional + Additional key-value properties to advertise with the service. + pid : int, optional + Process ID of the service. Used to detect stale advertisements. + uuid : UUID4, optional + Unique identifier for this service instance. Generated if not provided. + """ + uuid = uuid or uuid4() + info = LocalServiceInfo( + service_name=service_name, + service_type=service_type, + address=address, + uuid=uuid, + pid=pid if pid is not None else os.getpid(), + properties=properties or {}, + ) + + self.service_file = self._get_service_file(service_type, uuid) + self.service_file.parent.mkdir(parents=True, exist_ok=True) + self._finalizer = weakref.finalize(self, self._close, self.service_file) + + json_data = msgspec.to_builtins(info) + self.service_file.write_text(json.dumps(json_data, indent=2)) + logger.debug("Advertising local service at '%s'", self.service_file) + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """Remove the service advertisement and clean up empty directories.""" + if self._closed: + return + self._closed = True + self._finalizer.detach() + self._close(self.service_file) + + @staticmethod + def _close(service_file: Path) -> None: + with contextlib.suppress(Exception): + service_file.unlink(missing_ok=True) + logger.debug("Removed local service advertisement '%s'", service_file) + with contextlib.suppress(Exception): + prune_empty_parent_directories( + service_file.parent, + LocalServiceAdvertisement.runtime_directory, + remove_root=True, + ) + + @staticmethod + def _get_service_directory(service_type: str) -> Path: + """Get the directory for a service type.""" + runtime_directory = LocalServiceAdvertisement.runtime_directory + sanitized = RE_NON_ALPHANUMERIC.sub('_', service_type) + return runtime_directory / sanitized + + @staticmethod + def _get_service_file(service_type: str, uuid: UUID) -> Path: + """Get the path to a local service file.""" + service_dir = LocalServiceAdvertisement._get_service_directory(service_type) + return service_dir / f'{uuid.hex}.json' + + @staticmethod + def discover( + service_type: str, + properties: dict[str, str | None] | None = None, + ) -> Iterator[LocalServiceInfo]: + """Discover locally advertised services. + + Parameters + ---------- + service_type : str + The service type to discover. + properties : dict, optional + Properties to match against the service's properties. + + Yields + ------ + LocalServiceInfo + Information structure describing the discovered services. + """ + service_dir = LocalServiceAdvertisement._get_service_directory(service_type) + properties = properties or {} + + if service_dir.exists(): + for service_file in service_dir.glob('*.json'): + # Load service info + try: + data = json.loads(service_file.read_text()) + info = msgspec.convert(data, LocalServiceInfo) + except ( + json.JSONDecodeError, + msgspec.ValidationError, + OSError, + ): + continue + + # Remove service file if process no longer exists + if not pid_exists(info.pid): + service_file.unlink(missing_ok=True) + continue + + # Check if properties match + if all(info.properties.get(k) == v for k, v in properties.items()): + yield info + + # Clean up empty directories + with contextlib.suppress(OSError, ValueError): + prune_empty_parent_directories( + service_dir, + LocalServiceAdvertisement.runtime_directory, + remove_root=True, + ) + + +class DualChannelBase(contextlib.AbstractContextManager): _serialization: Literal['json', 'msgpack'] = 'msgpack' _encoder: msgspec.msgpack.Encoder | msgspec.json.Encoder _decoder: msgspec.msgpack.Decoder | msgspec.json.Decoder - _event_thread: threading.Thread + _event_thread: threading.Thread | None = None _socket_req_rep: zmq.Socket _socket_pub_sub: zmq.Socket + _closed = False def __init__(self) -> None: - self._closed = False self._lock_close = threading.Lock() - self._finalizer = weakref.finalize(self, self.close) self._zmq_context = zmq.Context() self._stop_event_loop = threading.Event() - def __enter__(self) -> Self: - """Enter context manager.""" - return self + @staticmethod + def _finalize_base( + event_thread: threading.Thread | None, + stop_event: threading.Event, + socket_req_rep: zmq.Socket, + socket_pub_sub: zmq.Socket, + zmq_context: zmq.Context, + ) -> None: + # event thread + is_alive = False + with contextlib.suppress(Exception): + is_alive = event_thread is not None and event_thread.is_alive() + if event_thread is not None and is_alive: + with contextlib.suppress(Exception): + stop_event.set() + event_thread.join(timeout=1) + if event_thread.is_alive(): + with contextlib.suppress(Exception): + logger.warning('Event thread did not terminate cleanly') + + # ZMQ sockets + with contextlib.suppress(Exception): + socket_req_rep.close(linger=0) + with contextlib.suppress(Exception): + socket_pub_sub.close(linger=0) + + # ZMQ context + with contextlib.suppress(Exception): + zmq_context.term() def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, - ) -> bool | None: + ) -> None: """Exit context manager.""" self.close() - return None @abstractmethod - def _event_loop(self) -> None: ... - - def close(self) -> bool: - """ - Close the instance. - - Releases resources, stops the event loop, and terminates ZeroMQ sockets and - context. - - Returns - ------- - bool - Returns True if the instance was successfully closed, False if it was - already closed. - """ - with self._lock_close: - if self._closed: - return False - self._closed = True - - if self._event_thread and self._event_thread.is_alive(): - self._stop_event_loop.set() - self._event_thread.join() - - # close sockets and terminate context - self._socket_req_rep.close(linger=0) - self._socket_pub_sub.close(linger=0) - self._zmq_context.term() - return True + def close(self) -> None: ... class DualChannelHost(DualChannelBase): - """A ZeroMQ host providing REQ/REP and PUB/SUB sockets with Zeroconf discovery.""" + """ + A ZeroMQ host providing REQ/REP and PUB/SUB sockets with service discovery. + + Provides two communication channels: a REQ/REP channel for synchronous + request-reply messaging and a PUB/SUB channel for broadcasting events to + subscribers. Incoming requests are dispatched to a user-provided + ``event_handler`` callback. + + The service is automatically advertised for discovery by + :class:`DualChannelClient`. When ``remote=True``, the service is advertised + via Zeroconf (mDNS) for network-wide discovery. When ``remote=False``, it is + advertised locally via a file in the user's runtime directory for IPC-only use + cases. On POSIX systems, clients on the same machine transparently upgrade to IPC + for lower latency. + """ - _rep_ipc_addr: str | None = None - _pub_ipc_addr: str | None = None + _zeroconf: Zeroconf | None = None + _zeroconf_service_info: ServiceInfo | None = None + _named_pipe_rep: Path | None = None + _named_pipe_pub: Path | None = None def __init__( self, service_name: str, service_type: str, - txt_record: dict[str | bytes, str | bytes | None] | None = None, + properties: dict[str, str | None] | None = None, event_handler: Callable[[Any], Any] | None = None, remote: bool = True, port_pub: int | None = None, @@ -121,9 +325,9 @@ def __init__( service_name : str Service name to advertise. service_type : str - Zeroconf service type (e.g., 'my_service'). - txt_record : dict, optional - Additional TXT records for Zeroconf service advertisement. + Service type. + properties : dict, optional + Additional properties for service advertisement. event_handler : callable, optional Function to handle incoming requests. remote : bool, default=True @@ -138,8 +342,7 @@ def __init__( # initialize base class super().__init__() - self.name = convert_to_snake_case(service_name).strip('_') - self.uuid = uuid.uuid4() + self.uuid = uuid4() self._bind_ip = IP_ANY if remote else IP_LOOPBACK self._local_ip = get_local_ipv4() if remote else IP_LOOPBACK @@ -151,34 +354,49 @@ def __init__( self._socket_req_rep.setsockopt(zmq.SNDHWM, 1000) self._socket_req_rep.setsockopt(zmq.RCVHWM, 1000) self._socket_pub_sub.setsockopt(zmq.SNDHWM, 1000) - - # bind sockets to IPC addresses (POSIX only) - # clients on localhost can upgrade to IPC (named pipes) for improved performance - if 'win' not in sys.platform: - self._rep_ipc_addr = f'ipc:///tmp/REQ_REP_{self.uuid.hex}.ipc' - self._pub_ipc_addr = f'ipc:///tmp/PUB_SUB_{self.uuid.hex}.ipc' + self._socket_pub_sub.setsockopt(zmq.IMMEDIATE, 1) + + # Bind sockets to IPC addresses for improved performance (POSIX only) + rep_ipc_addr: str | None = None + pub_ipc_addr: str | None = None + if os.name == 'posix': + if sys.platform.startswith('linux'): + # On linux we use abstract sockets for IPC, avoiding issues with + # filesystem, cleanup, permissions and stale files + rep_ipc_addr = f'ipc://@REQ_REP_{self.uuid.hex}' + pub_ipc_addr = f'ipc://@PUB_SUB_{self.uuid.hex}' + else: + # On other POSIX platforms we use filesystem Unix domain sockets + runtime_path = platformdirs.user_runtime_path(ensure_exists=True) + self._named_pipe_rep = runtime_path / f'REQ_REP_{self.uuid.hex}.ipc' + self._named_pipe_pub = runtime_path / f'PUB_SUB_{self.uuid.hex}.ipc' + self._named_pipe_rep.unlink(missing_ok=True) # pre-unlink before bind + self._named_pipe_pub.unlink(missing_ok=True) # to avoid collisions + rep_ipc_addr = 'ipc://' + self._named_pipe_rep.as_posix() + pub_ipc_addr = 'ipc://' + self._named_pipe_pub.as_posix() try: - self._socket_req_rep.bind(self._rep_ipc_addr) - self._socket_pub_sub.bind(self._pub_ipc_addr) - logger.debug("Binding REP socket to '%s'", self._rep_ipc_addr) - logger.debug("Binding PUB socket to '%s'", self._pub_ipc_addr) + self._socket_req_rep.bind(rep_ipc_addr) + logger.debug("Binding REP socket to '%s'", rep_ipc_addr) + self._socket_pub_sub.bind(pub_ipc_addr) + logger.debug("Binding PUB socket to '%s'", pub_ipc_addr) except zmq.ZMQError: logger.warning('Failed to bind IPC sockets; continuing without IPC') - self._rep_ipc_addr = None - self._pub_ipc_addr = None + if rep_ipc_addr: + with contextlib.suppress(zmq.ZMQError): + self._socket_req_rep.unbind(rep_ipc_addr) + rep_ipc_addr = None + pub_ipc_addr = None def bind_tcp(zmq_socket: zmq.Socket, tcp_port: int | None) -> tuple[str, int]: - """Bind socket to TCP address with preferred port.""" - bind_address = f'tcp://{self._bind_ip}' - service_address = f'tcp://{self._local_ip}' + """Helper function binding socket to TCP address with preferred port.""" if tcp_port is not None: try: - zmq_socket.bind(f'{bind_address}:{tcp_port}') + zmq_socket.bind(f'tcp://{self._bind_ip}:{tcp_port}') except zmq.ZMQError: tcp_port = None if tcp_port is None: - tcp_port = zmq_socket.bind_to_random_port(bind_address) - return f'{service_address}:{tcp_port}', tcp_port + tcp_port = zmq_socket.bind_to_random_port(f'tcp://{self._bind_ip}') + return f'tcp://{self._local_ip}:{tcp_port}', tcp_port # bind sockets to TCP addresses self.rep_tcp_addr, self.rep_tcp_port = bind_tcp(self._socket_req_rep, port_rep) @@ -187,7 +405,7 @@ def bind_tcp(zmq_socket: zmq.Socket, tcp_port: int | None) -> tuple[str, int]: logger.debug("Binding PUB socket to '%s'", self.pub_tcp_addr) # select serialization protocol / initialize encoders + decoders - self._serialization_protocol = serialization + self._serialization = serialization if serialization == 'msgpack': self._encoder = msgspec.msgpack.Encoder() self._decoder = msgspec.msgpack.Decoder(type=DualChannelMessage) @@ -198,36 +416,137 @@ def bind_tcp(zmq_socket: zmq.Socket, tcp_port: int | None) -> tuple[str, int]: raise ValueError(f'Unsupported serialization protocol: {serialization}') # start event loop for request handling - self._user_event_handler = event_handler or self._empty_event_handler - self._event_handler_lock = threading.RLock() - self._event_thread = threading.Thread(target=self._event_loop, daemon=True) + self._event_handler_lock = threading.Lock() + handshake_data = DualChannelHandshake( + ipc_pub_sub=pub_ipc_addr, + ipc_req_rep=rep_ipc_addr, + tcp_pub_sub=self.pub_tcp_addr, + tcp_req_rep=self.rep_tcp_addr, + ) + self._event_thread = threading.Thread( + target=DualChannelHost._event_loop, + args=( + self._stop_event_loop, + self._socket_req_rep, + self._decoder, + self._encoder, + self._serialization, + event_handler or self._empty_event_handler, + self._event_handler_lock, + handshake_data, + ), + daemon=True, + ) self._event_thread.start() - # advertise service via Zeroconf - self._service_type = f'_{service_type.strip("_")}._tcp.local.' - self._service_name = f'{self.name}.{self._service_type}' - self._service_info = ServiceInfo( - type_=self._service_type, - name=self._service_name, - port=self.rep_tcp_port, - addresses=[socket.inet_aton(self._local_ip)], - properties=txt_record or {}, - server=f'{socket.gethostname()}.local.', + # advertise service locally + self._local_advertisement = LocalServiceAdvertisement( + service_name=service_name, + service_type=service_type, + address=rep_ipc_addr or self.rep_tcp_addr, + pid=os.getpid(), + uuid=self.uuid, + properties=properties, ) - self._zeroconf = Zeroconf( - interfaces=InterfaceChoice.Default if remote else IP_LOOPBACK, - ip_version=IPVersion.V4Only, + + # advertise service via Zeroconf for remote discovery + if remote: + zeroconf_type = f'_{to_snake_case(service_type)}._tcp.local.' + zeroconf_name = f'{service_name}.{zeroconf_type}' + self._zeroconf_service_info = ServiceInfo( + type_=zeroconf_type, + name=zeroconf_name, + port=self.rep_tcp_port, + addresses=[socket.inet_aton(self._local_ip)], + properties=properties or {}, + server=f'{socket.gethostname()}.local.', + ) + self._zeroconf = Zeroconf( + interfaces=InterfaceChoice.Default, + ip_version=IPVersion.V4Only, + ) + self._zeroconf.register_service( + self._zeroconf_service_info, allow_name_change=True + ) + self._zeroconf_service_name = self._zeroconf_service_info.name + logger.debug( + "Registering Zeroconf service '%s'", self._zeroconf_service_name + ) + + # register finalizer to clean up resources on exit + self._finalizer = weakref.finalize( + self, + DualChannelHost._finalize, + self._event_thread, + self._stop_event_loop, + self._socket_req_rep, + self._socket_pub_sub, + self._zmq_context, + self._local_advertisement, + self._zeroconf, + self._zeroconf_service_info, + self._named_pipe_rep, + self._named_pipe_pub, + ) + + @staticmethod + def _finalize( + event_thread: threading.Thread | None, + stop_event: threading.Event, + socket_req_rep: zmq.Socket, + socket_pub_sub: zmq.Socket, + zmq_context: zmq.Context, + local_advertisement: LocalServiceAdvertisement, + zeroconf: Zeroconf | None, + service: ServiceInfo | None, + named_pipe_rep: Path | None, + named_pipe_pub: Path | None, + ) -> None: + """Finalize the host by unregistering the service.""" + # local advertisement + with contextlib.suppress(Exception): + local_advertisement.close() + + # zeroconf + if zeroconf is not None: + if service is not None: + with contextlib.suppress(Exception): + logger.debug("Unregistering Zeroconf service '%s'", service.name) + zeroconf.unregister_service(service) + with contextlib.suppress(Exception): + zeroconf.close() + + # call base class finalizer + DualChannelBase._finalize_base( + event_thread, + stop_event, + socket_req_rep, + socket_pub_sub, + zmq_context, ) - self._zeroconf.register_service(self._service_info, allow_name_change=True) - self._service_name = self._service_info.name - logger.debug("Registering Zeroconf service '%s'", self._service_name) + + # named pipes + for pipe in (named_pipe_rep, named_pipe_pub): + if pipe is not None: + with contextlib.suppress(Exception): + pipe.unlink(missing_ok=True) @staticmethod def _empty_event_handler(_: Any) -> dict: """Default event handler that returns an empty dict.""" return {} - def _event_loop(self) -> None: + @staticmethod + def _event_loop( # noqa: PLR0913 + stop_event: threading.Event, + req_rep_socket: zmq.Socket, + decoder: msgspec.msgpack.Decoder | msgspec.json.Decoder, + encoder: msgspec.msgpack.Encoder | msgspec.json.Encoder, + serialization_protocol: str, + event_handler: Callable[[Any], dict], + event_handler_lock: threading.Lock, + handshake_data: DualChannelHandshake, + ) -> None: """ Handle incoming REQ messages. @@ -238,87 +557,83 @@ def _event_loop(self) -> None: - Responds with handshake data for type 'H'. - Sends an error for unknown types. """ - while not self._stop_event_loop.is_set(): + format_error = DualChannelHost._format_error + + # avoid overhead of attribute lookups + send = req_rep_socket.send + recv = req_rep_socket.recv + decode = decoder.decode + encode = encoder.encode + + while not stop_event.is_set(): # wait for incoming requests (short poll so we can check stop_event) - if not self._socket_req_rep.poll(100): + if not req_rep_socket.poll(100): continue # guard the actual recv with try/except so shutdown can't hang us try: - request_frame = self._socket_req_rep.recv(copy=False) + request_frame = recv(copy=False) except zmq.ZMQError as e: logger.exception('Error receiving request from client', exc_info=e) continue # try to decode the request try: - request: DualChannelMessage = self._decoder.decode(request_frame.bytes) + req: DualChannelMessage = decode(request_frame.buffer) except msgspec.DecodeError as e1: # try the other serialization as a fallback try: - if self._serialization_protocol == 'msgpack': - request = msgspec.json.decode( - request_frame.bytes, type=DualChannelMessage + if serialization_protocol == 'msgpack': + req = msgspec.json.decode( + request_frame.buffer, type=DualChannelMessage ) else: - request = msgspec.msgpack.decode( - request_frame.bytes, type=DualChannelMessage + req = msgspec.msgpack.decode( + request_frame.buffer, type=DualChannelMessage ) except msgspec.DecodeError: - logger.exception( - 'Error decoding request from client: %s', - request_frame.bytes, - exc_info=e1, - ) - reply = self._format_error(type(e1).__name__, e1.args[0]) + logger.exception('Error decoding request from client', exc_info=e1) + reply = format_error(type(e1).__name__, str(e1)) try: - reply_bytes = self._encoder.encode(reply) - self._socket_req_rep.send(reply_bytes, copy=False) + reply_bytes = encode(reply) + send(reply_bytes, copy=False) except (msgspec.EncodeError, zmq.ZMQError) as e2: logger.exception('Error sending reply to client', exc_info=e2) continue # handle request depending on request type - match request.type: + match req.type: case 'R': # general request try: - with self._event_handler_lock: - reply_data = self._user_event_handler(request.data) + with event_handler_lock: + reply_data = event_handler(req.data) except Exception as e: logger.exception( 'Event handler raised an exception', exc_info=e ) - reply = self._format_error(type(e).__name__, e.args[0]) + reply = format_error(type(e).__name__, str(e)) else: reply = DualChannelMessage('R', reply_data) case 'H': # handshake for communicating TCP and IPC addresses - reply = DualChannelMessage( - type='H', - data={ - 'ipc_pub_sub': self._pub_ipc_addr, - 'ipc_req_rep': self._rep_ipc_addr, - 'tcp_pub_sub': self.pub_tcp_addr, - 'tcp_req_rep': self.rep_tcp_addr, - }, - ) + reply = DualChannelMessage(type='H', data=handshake_data) case _: # unknown request type - message = f'Unknown request type: {request.type}' + message = f'Unknown request type: {req.type}' logger.error(message) - reply = self._format_error('RequestError', message) + reply = format_error('RequestError', message) # encode reply try: - reply_bytes = self._encoder.encode(reply) + reply_bytes = encode(reply) except msgspec.EncodeError as e: logger.exception('Error encoding reply to client', exc_info=e) - reply = self._format_error(type(e).__name__, e.args[0]) - reply_bytes = self._encoder.encode(reply) + reply = format_error(type(e).__name__, str(e)) + reply_bytes = encode(reply) # send reply try: - self._socket_req_rep.send(reply_bytes, copy=False) + send(reply_bytes, copy=False) except zmq.ZMQError as e: logger.exception('Error sending reply to client', exc_info=e) @@ -341,31 +656,25 @@ def _format_error(name: str, message: str) -> DualChannelMessage: """ return DualChannelMessage(type='E', data={'name': name, 'message': message}) - def close(self) -> bool: - """ - Close the host and clean up resources. - - Returns - ------- - bool - True if the host closed successfully, False otherwise. - """ - if not super().close(): - return False - - # Unregister Zeroconf service - logger.debug("Unregistering Zeroconf service '%s'", self._service_name) - self._zeroconf.unregister_service(self._service_info) - self._zeroconf.close() - - # remove IPC files - with contextlib.suppress(FileNotFoundError): - if self._rep_ipc_addr is not None: - os.remove(self._rep_ipc_addr[6:]) - if self._pub_ipc_addr is not None: - os.remove(self._pub_ipc_addr[6:]) - - return True + def close(self) -> None: + """Close the host and clean up resources.""" + with self._lock_close: + if self._closed: + return + self._closed = True + self._finalizer.detach() + self._finalize( + self._event_thread, + self._stop_event_loop, + self._socket_req_rep, + self._socket_pub_sub, + self._zmq_context, + self._local_advertisement, + self._zeroconf, + self._zeroconf_service_info, + self._named_pipe_rep, + self._named_pipe_pub, + ) class DualChannelClient(DualChannelBase): @@ -378,6 +687,7 @@ def __init__( event_handler: Callable[[dict], Any] | None = None, discovery_timeout: float = 10.0, txt_properties: dict | None = None, + remote: bool = True, ) -> None: """ Initialize a DualChannelClient instance. @@ -385,7 +695,7 @@ def __init__( Parameters ---------- service_type : str - The mDNS service type to discover or connect to. + The service type to discover or connect to. address : str, optional The direct connection address for the REQ channel, by default None. event_handler : callable, optional @@ -394,6 +704,8 @@ def __init__( Timeout in seconds for service discovery, by default 10.0. txt_properties : dict, optional Properties for service filtering during discovery, by default None. + remote : bool, optional + Whether to use Zeroconf for discovering remote services, by default True. """ # initialize base class super().__init__() @@ -412,12 +724,12 @@ def __init__( self._address_req = address else: self._address_req, _ = discover( - service_type, txt_properties, discovery_timeout + service_type, txt_properties, remote, discovery_timeout ) self._socket_req_rep.connect(self._address_req) self._lock_req = threading.Lock() - logger.debug("Binding REQ socket to '%s'", self._address_req) - self.is_local = '127.0.0.1' in self._address_req + logger.debug("Connecting REQ socket to '%s'", self._address_req) + self.is_local = any(x in self._address_req for x in ('127.0.0.1', 'ipc://')) # perform handshake self._handshake() @@ -426,25 +738,50 @@ def __init__( if event_handler is not None: self._socket_pub_sub.connect(self._address_sub) self._socket_pub_sub.setsockopt_string(zmq.SUBSCRIBE, '') - logger.debug("Binding SUB socket to '%s'", self._address_sub) + logger.debug("Connecting SUB socket to '%s'", self._address_sub) else: - logger.debug('Not binding SUB socket for lack of event handler') + logger.debug('Not connecting SUB socket for lack of event handler') # start event loop for subscription handling - self._event_handler = event_handler - self._event_thread = threading.Thread(target=self._event_loop, daemon=True) + self._event_thread = threading.Thread( + target=DualChannelClient._event_loop, + args=( + self._stop_event_loop, + self._socket_pub_sub, + self._decoder, + event_handler, + ), + daemon=True, + ) if event_handler is not None: self._event_thread.start() - def _event_loop(self) -> None: + # register finalizer to clean up resources on exit + self._finalizer = weakref.finalize( + self, + DualChannelBase._finalize_base, + self._event_thread, + self._stop_event_loop, + self._socket_req_rep, + self._socket_pub_sub, + self._zmq_context, + ) + + @staticmethod + def _event_loop( + stop_event: threading.Event, + socket_sub: zmq.Socket, + decoder: msgspec.msgpack.Decoder | msgspec.json.Decoder, + event_handler: Callable, + ) -> None: """Process incoming PUB messages.""" - while not self._stop_event_loop.is_set(): - if not self._socket_pub_sub.poll(100): + while not stop_event.is_set(): + if not socket_sub.poll(100): continue - msg = self._socket_pub_sub.recv() - msg = self._decoder.decode(msg) + msg = socket_sub.recv() + msg = decoder.decode(msg) try: - cast('Callable', self._event_handler)(msg) + event_handler(msg) except Exception as e: logger.exception('Subscription handler raised an exception', exc_info=e) @@ -453,14 +790,15 @@ def _handshake(self) -> None: reply_type, reply_data = self._req('H') if ( self.is_local - and sys.platform in ('darwin', 'linux') + and os.name == 'posix' and reply_data.get('ipc_req_rep') is not None and reply_data.get('ipc_pub_sub') is not None ): - self._socket_req_rep.disconnect(self._address_req) - self._address_req = reply_data.get('ipc_req_rep') - logger.debug("Rebinding REQ socket to '%s'", self._address_req) - self._socket_req_rep.connect(self._address_req) + if self._address_req.startswith('tcp://'): + self._socket_req_rep.disconnect(self._address_req) + self._address_req = reply_data.get('ipc_req_rep') + logger.debug("Reconnecting REQ socket to '%s'", self._address_req) + self._socket_req_rep.connect(self._address_req) self._address_sub = reply_data.get('ipc_pub_sub') else: self._address_sub = reply_data.get('tcp_pub_sub') @@ -488,7 +826,7 @@ def _req(self, request_type: str, data: Any | None = None) -> tuple[str, Any]: # receive and decode reply try: - reply: DualChannelMessage = self._decoder.decode(reply_frame.bytes) + reply: DualChannelMessage = self._decoder.decode(reply_frame.buffer) # switch serialization format except msgspec.DecodeError as e: @@ -496,7 +834,7 @@ def _req(self, request_type: str, data: Any | None = None) -> tuple[str, Any]: new_serialization_module = getattr(msgspec, new_format) new_decoder = new_serialization_module.Decoder(type=DualChannelMessage) try: - reply = new_decoder.decode(reply_frame.bytes) + reply = new_decoder.decode(reply_frame.buffer) except msgspec.DecodeError: raise ValueError('Error decoding reply from host') from e logger.debug('Switching to %s serialization', new_format) @@ -536,21 +874,39 @@ def request(self, **kwargs: Any) -> Any: logger.error("Received unknown reply type: '%s'", reply_type) return {} + def close(self) -> None: + """Close the client and clean up resources.""" + with self._lock_close: + if self._closed: + return + self._closed = True + self._finalizer.detach() + self._finalize_base( + self._event_thread, + self._stop_event_loop, + self._socket_req_rep, + self._socket_pub_sub, + self._zmq_context, + ) + def discover( service_type: str, properties: dict[str, str | None] | None = None, + remote: bool = True, timeout: float = 10, -) -> tuple[str, dict[bytes, bytes | None]]: +) -> tuple[str, dict[str, str | None]]: """ - Discover a Zeroconf device/service on the local network matching given properties. + Discover a device/service on the local network matching given properties. Parameters ---------- service_type : str - The Zeroconf service type to discover, e.g., '_zmq._tcp.local.' + The service type to discover, e.g., 'bpod' properties : dict, optional Dictionary of expected service properties to match. + remote : bool, optional + Whether to search for a matching service on the network, by default True. timeout : float, optional How many seconds to wait for a matching service before timing out. Default is 10. @@ -560,7 +916,7 @@ def discover( str The Zeroconf service address, e.g., 'tcp://192.168.1.10:1234'. dict - The TXT record of the service + A dictionary of service properties. Raises ------ @@ -569,36 +925,41 @@ def discover( """ properties = properties or {} address = '' - protocol = (m := re.search(r'_(tcp|udp)\.', service_type)) and m.group(1) event = threading.Event() txt_record = {} + zeroconf_service_type = f'_{to_snake_case(service_type)}._tcp.local.' + + for local_info in LocalServiceAdvertisement.discover(service_type, properties): + return local_info.address, local_info.properties + if not remote: + raise RuntimeError('No matching service found locally') def on_state_change( *, name: str, state_change: ServiceStateChange, **_: Any ) -> None: - nonlocal address, protocol, txt_record, event + nonlocal address, txt_record + if event.is_set(): + return if state_change is ServiceStateChange.Added: - info = zeroconf.get_service_info(service_type, name) - if not info or not info.addresses: + remote_info = zeroconf.get_service_info(zeroconf_service_type, name) + if remote_info is None or len(remote_info.addresses) == 0: return for k, v in properties.items(): - key = k.encode('utf-8') if isinstance(k, str) else k - value = v.encode('utf-8') if isinstance(v, str) else v - if info.properties.get(key) != value: + if remote_info.decoded_properties.get(k) != v: return - port = info.port - ip = socket.inet_ntoa(info.addresses[0]) + port = remote_info.port + ip = socket.inet_ntoa(remote_info.addresses[0]) ip = '127.0.0.1' if ip == get_local_ipv4() else ip - address = f'{protocol}://{ip}:{port}' - txt_record = info.properties + address = f'tcp://{ip}:{port}' + txt_record = remote_info.decoded_properties event.set() zeroconf = Zeroconf() try: - ServiceBrowser(zeroconf, service_type, handlers=[on_state_change]) + ServiceBrowser(zeroconf, zeroconf_service_type, handlers=[on_state_change]) found = event.wait(timeout) finally: zeroconf.close() if not found: - raise TimeoutError('No matching device found via Zeroconf') + raise TimeoutError('No matching service found') return address, txt_record diff --git a/bpod_core/misc.py b/bpod_core/misc.py index 68c0bedb..8c05f3fb 100644 --- a/bpod_core/misc.py +++ b/bpod_core/misc.py @@ -21,9 +21,14 @@ K = TypeVar('K') V = TypeVar('V') -RE_SANITIZE = re.compile(r'[^a-zA-Z0-9_]') -RE_SNAKE_CASE = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=\D)(?=\d)|(?<=\d)(?=\D)') -RE_UNDERSCORES = re.compile(r'_{2,}') +RE_NON_ALPHANUMERIC = re.compile(r'[^a-zA-Z0-9_]') +"""Match non-alphanumeric characters except underscores.""" +RE_ACRONYM = re.compile(r'([A-Z]+)([A-Z][a-z])') +"""Match acronym boundaries.""" +RE_CASE_TRANSITION = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=\D)(?=\d)|(?<=\d)(?=\D)') +"""Match case and digit transitions.""" +RE_MULTIPLE_UNDERSCORES = re.compile(r'_{2,}') +"""Match multiple consecutive underscores.""" class DocstringInheritanceMixin: @@ -46,35 +51,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: break -def sanitize_string(string: str, substitute: str = '_') -> str: - """ - Replace non-alphanumeric characters in a string with a given substitute. - - Parameters - ---------- - string : str - The input string to be sanitized. - substitute : str, optional - The character(s) to replace non-alphanumeric characters with. - Defaults to '_'. - - Returns - ------- - str - A sanitized string where all non-alphanumeric characters have been replaced with - the specified substitute. - - Raises - ------ - TypeError - If either `string` or `substitute` is not an instance of ``str``. - """ - if not (isinstance(string, str) and isinstance(substitute, str)): - raise TypeError('Both `string` and `substitute` must be strings.') - return re.sub(RE_SANITIZE, substitute, string) - - -def convert_to_snake_case(string: str) -> str: +def to_snake_case(string: str) -> str: """ Convert a given string to snake_case. @@ -88,9 +65,10 @@ def convert_to_snake_case(string: str) -> str: str The converted snake_case string. """ - string = sanitize_string(string) - string = RE_SNAKE_CASE.sub('_', string) - string = RE_UNDERSCORES.sub('_', string) + string = RE_NON_ALPHANUMERIC.sub('_', string) + string = RE_ACRONYM.sub(r'\1_\2', string) + string = RE_CASE_TRANSITION.sub('_', string) + string = RE_MULTIPLE_UNDERSCORES.sub('_', string) string = string.strip('_') return string.lower() @@ -411,3 +389,70 @@ def extend_packed( """ if values: byte_array.extend(struct.pack(f'<{len(values)}{fmt}', *values)) + + +def prune_empty_parent_directories( + target_directory: PathLike | str, + root_directory: PathLike | str, + remove_root: bool = False, +) -> None: + """Remove empty parent directories recursively up to root directory. + + Recursively removes the given directory if empty, then checks and removes parent + directories up to (and optionally including) the root directory. Stops at the first + non-empty directory encountered. + + Parameters + ---------- + target_directory : PathLike or str + Directory to check and remove if empty. + root_directory : PathLike or str + Root directory to stop at. Must be a parent directory of target_directory. + remove_root : bool, optional + If True, also remove root_directory if it becomes empty. + + Raises + ------ + ValueError + If target_directory is not a subpath of root_directory. + FileNotFoundError + If target_directory or root_directory does not exist. + NotADirectoryError + If target_directory or root_directory is not a directory. + """ + target_directory = Path(target_directory).absolute() + root_directory = Path(root_directory).absolute() + + for path in (target_directory, root_directory): + if not path.exists(): + raise FileNotFoundError(f"'{path}' does not exist") + if not path.is_dir(): + raise NotADirectoryError(f"'{path}' is not a directory") + + if not target_directory.is_relative_to(root_directory): + raise ValueError( + f"'{target_directory}' is not a sub-directory of '{root_directory}'" + ) + + if target_directory == root_directory: + if remove_root: + try: + root_directory.rmdir() + logger.debug('Removed empty root directory: %s', root_directory) + except OSError: + return + return + + try: + target_directory.rmdir() + logger.debug('Removed empty directory: %s', target_directory) + except OSError: + return + + parent = target_directory.parent + if parent.is_dir(): # Guard against race condition + prune_empty_parent_directories( + target_directory=parent, + root_directory=root_directory, + remove_root=remove_root, + ) diff --git a/pyproject.toml b/pyproject.toml index a2944e37..02370a95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "pyyaml>=6.0.0", "platformdirs>=4.4.0", "filelock>=3.20.0", + "psutil>=7.0.0", ] license = { text = "MIT" } readme = "README.md" @@ -70,9 +71,10 @@ typing = [ "appdirs-stubs>=0.1.0", "docutils-stubs>=0.0.22", "mypy>=1.15.0", - "types-docutils>=0.22.0.20250822", - "types-pyserial>=3.5.0.20250326", - "types-pyyaml>=6.0.12.20250822", + "types-docutils>=0.22.0", + "types-psutil>=7.0.0", + "types-pyserial>=3.5.0", + "types-pyyaml>=6.0.0", ] ci = [ "coveralls>=4.0.1", diff --git a/tests/conftest.py b/tests/conftest.py index 5b2028da..8f2824c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,10 @@ import pytest +from bpod_core import ipc from bpod_core.bpod import Bpod from bpod_core.com import ExtendedSerial -from bpod_core.constants import VID_TEENSY, PIDsTeensy +from bpod_core.constants import VID_TEENSY, TeensyPID fixture_bpod_all = { b'6': b'5', @@ -57,7 +58,7 @@ def mock_comports(mocker, mock_serial_discovery): mock_port_info.device = 'COM3' mock_port_info.serial_number = '12345' mock_port_info.vid = VID_TEENSY - mock_port_info.pid = PIDsTeensy.SERIAL + mock_port_info.pid = TeensyPID.SERIAL mock_comports = mocker.patch('bpod_core.com.comports') mock_comports.return_value = [mock_port_info] return mock_comports @@ -99,6 +100,25 @@ def in_waiting() -> int: return extended_serial +@pytest.fixture +def mock_zeroconf(mocker): + """Mock Zeroconf class.""" + return mocker.patch('bpod_core.ipc.Zeroconf', spec=ipc.Zeroconf) + + +@pytest.fixture +def mock_local_discovery_dir(tmp_path, mocker): + """Mock runtime directory for local advertisements.""" + mocker.patch.object(ipc.LocalServiceAdvertisement, 'runtime_directory', tmp_path) + return tmp_path + + +@pytest.fixture +def mock_advertisement(mock_zeroconf, mock_local_discovery_dir): + """Mock, both, zeroconf and local advertisement.""" + yield {'zeroconf': mock_zeroconf, 'runtime_dir': mock_local_discovery_dir} + + @pytest.fixture def mock_bpod(mocker, mock_ext_serial, mock_settings): mock_bpod = mocker.MagicMock(spec=Bpod) @@ -118,7 +138,9 @@ def mock_settings(mocker): @pytest.fixture -def mock_bpod_20(mock_comports, mock_ext_serial, mock_settings, mocker): +def mock_bpod_20( + mock_comports, mock_ext_serial, mock_settings, mocker, mock_advertisement +): mock_ext_serial.mock_responses.update(fixture_bpod_20) mocker.patch('bpod_core.com.ExtendedSerial', return_value=mock_ext_serial) mocker.patch('bpod_core.bpod.Bpod._detect_additional_serial_ports') @@ -127,7 +149,9 @@ def mock_bpod_20(mock_comports, mock_ext_serial, mock_settings, mocker): @pytest.fixture -def mock_bpod_25(mock_comports, mock_ext_serial, mock_settings, mocker): +def mock_bpod_25( + mock_comports, mock_ext_serial, mock_settings, mocker, mock_advertisement +): mock_ext_serial.mock_responses.update(fixture_bpod_25) mocker.patch('bpod_core.com.ExtendedSerial', return_value=mock_ext_serial) mocker.patch('bpod_core.bpod.Bpod._detect_additional_serial_ports') @@ -136,7 +160,9 @@ def mock_bpod_25(mock_comports, mock_ext_serial, mock_settings, mocker): @pytest.fixture -def mock_bpod_2p(mock_comports, mock_ext_serial, mock_settings, mocker): +def mock_bpod_2p( + mock_comports, mock_ext_serial, mock_settings, mocker, mock_advertisement +): mock_ext_serial.mock_responses.update(fixture_bpod_2p) mocker.patch('bpod_core.com.ExtendedSerial', return_value=mock_ext_serial) mocker.patch('bpod_core.bpod.Bpod._detect_additional_serial_ports') diff --git a/tests/test_com.py b/tests/test_com.py index a5ceb322..88657cc5 100644 --- a/tests/test_com.py +++ b/tests/test_com.py @@ -1,3 +1,4 @@ +import gc import logging import re from unittest.mock import MagicMock, call @@ -343,8 +344,8 @@ def test_serial_exception(self, mock_serial): assert result is False -class TestUSBSerialDevice: - """Tests for USBSerialDevice class.""" +class TestSerialDevice: + """Tests for SerialDevice class.""" @pytest.fixture def mock_port_info(self, mocker): @@ -375,13 +376,13 @@ def test_init_opens_connection_by_default( self, mock_comports, mock_extended_serial ): """Device opens serial connection by default on init.""" - device = com.USBSerialDevice('/dev/ttyACM0') + device = com.SerialDevice('/dev/ttyACM0') mock_extended_serial.open.assert_called_once() assert device.port == '/dev/ttyACM0' def test_init_without_opening_connection(self, mock_comports, mock_extended_serial): """Device can be initialized without opening connection.""" - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) mock_extended_serial.open.assert_not_called() assert device.port == '/dev/ttyACM0' @@ -389,32 +390,32 @@ def test_init_port_not_found(self, mock_comports): """Raises SerialException when port does not exist.""" mock_comports.return_value = [] with pytest.raises(SerialException, match='Serial port not found'): - com.USBSerialDevice('/dev/ttyACM0') + com.SerialDevice('/dev/ttyACM0') def test_context_manager_enter(self, mock_comports, mock_extended_serial): """Context manager __enter__ returns the device instance.""" - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) result = device.__enter__() assert result is device def test_context_manager_exit(self, mock_comports, mock_extended_serial): """Context manager __exit__ closes the connection.""" mock_extended_serial.is_open = True - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) device.__exit__(None, None, None) mock_extended_serial.close.assert_called_once() def test_context_manager_with_statement(self, mock_comports, mock_extended_serial): """Device works correctly with 'with' statement.""" mock_extended_serial.is_open = True - with com.USBSerialDevice('/dev/ttyACM0', open_connection=False) as device: + with com.SerialDevice('/dev/ttyACM0', open_connection=False) as device: assert device.port == '/dev/ttyACM0' mock_extended_serial.close.assert_called_once() def test_open_when_closed(self, mock_comports, mock_extended_serial, caplog): """open() opens connection when not already open.""" mock_extended_serial.is_open = False - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) with caplog.at_level(logging.DEBUG): device.open() mock_extended_serial.open.assert_called_once() @@ -423,7 +424,7 @@ def test_open_when_closed(self, mock_comports, mock_extended_serial, caplog): def test_open_when_already_open(self, mock_comports, mock_extended_serial): """open() does nothing when connection is already open.""" mock_extended_serial.is_open = True - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) device.open() mock_extended_serial.open.assert_not_called() @@ -431,14 +432,14 @@ def test_open_raises_on_failure(self, mock_comports, mock_extended_serial): """open() raises SerialException when connection fails.""" mock_extended_serial.is_open = False mock_extended_serial.open.side_effect = Exception('Connection failed') - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) with pytest.raises(SerialException, match='Failed to open connection'): device.open() def test_close_when_open(self, mock_comports, mock_extended_serial, caplog): """close() closes connection when open.""" mock_extended_serial.is_open = True - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) with caplog.at_level(logging.DEBUG): device.close() mock_extended_serial.close.assert_called_once() @@ -447,7 +448,7 @@ def test_close_when_open(self, mock_comports, mock_extended_serial, caplog): def test_close_when_already_closed(self, mock_comports, mock_extended_serial): """close() does nothing when connection is already closed.""" mock_extended_serial.is_open = False - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) device.close() mock_extended_serial.close.assert_not_called() @@ -455,18 +456,53 @@ def test_close_raises_on_failure(self, mock_comports, mock_extended_serial): """close() raises SerialException when closing fails.""" mock_extended_serial.is_open = True mock_extended_serial.close.side_effect = Exception('Close failed') - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) with pytest.raises(SerialException, match='Failed to close connection'): device.close() def test_port_property(self, mock_comports, mock_extended_serial): """port property returns the device path.""" - device = com.USBSerialDevice('/dev/ttyACM0', open_connection=False) + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) assert device.port == '/dev/ttyACM0' def test_accepts_kwargs(self, mock_comports, mock_extended_serial): """Extra kwargs are accepted for subclass compatibility.""" - device = com.USBSerialDevice( + device = com.SerialDevice( '/dev/ttyACM0', open_connection=False, custom_arg='value' ) assert device.port == '/dev/ttyACM0' + + def test_finalizer_closes_on_gc(self, mock_comports, mock_extended_serial): + """Finalizer closes the serial connection during garbage collection.""" + mock_extended_serial.is_open = True + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) + del device + gc.collect() + mock_extended_serial.close.assert_called_once() + + def test_finalizer_swallows_errors(self, mock_comports, mock_extended_serial): + """Finalizer does not raise when serial.close() fails.""" + mock_extended_serial.is_open = True + mock_extended_serial.close.side_effect = Exception('Close failed') + device = com.SerialDevice('/dev/ttyACM0', open_connection=False) + del device + gc.collect() # should not raise + + def test_exit_swallows_close_errors(self, mock_comports, mock_extended_serial): + """__exit__ does not raise when closing fails.""" + mock_extended_serial.is_open = True + mock_extended_serial.close.side_effect = Exception('Close failed') + with com.SerialDevice('/dev/ttyACM0', open_connection=False): + pass # should not raise on exit + + def test_exit_preserves_original_exception( + self, mock_comports, mock_extended_serial + ): + """__exit__ does not mask the original exception with a close error.""" + mock_extended_serial.is_open = True + mock_extended_serial.close.side_effect = Exception('Close failed') + with ( + pytest.raises(ValueError, match='original error'), + com.SerialDevice('/dev/ttyACM0', open_connection=False), + ): + raise ValueError('original error') diff --git a/tests/test_ipc.py b/tests/test_ipc.py index 60580b2e..ad8e003b 100644 --- a/tests/test_ipc.py +++ b/tests/test_ipc.py @@ -1,11 +1,150 @@ +import json +import os +from uuid import uuid4 + import pytest from bpod_core import ipc -@pytest.fixture -def mock_zeroconf(mocker): - return mocker.patch('bpod_core.ipc.Zeroconf') +class TestLocalServiceAdvertisement: + """Tests for LocalServiceAdvertisement class.""" + + def test_context_manager(self, mock_local_discovery_dir): + """Advertisement can be used as a context manager.""" + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + ) as ad: + assert ad.service_file.is_relative_to(mock_local_discovery_dir) + assert ad.service_file.exists() + assert ad.service_file.suffix == '.json' + assert not ad.service_file.exists() + assert not mock_local_discovery_dir.exists() + + def test_close_removes_file(self, mock_local_discovery_dir): + """close() removes the service file and removes empty parent directories.""" + ad = ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + ) + assert ad.service_file.exists() + ad.close() + assert not ad.service_file.exists() + assert not mock_local_discovery_dir.exists() + + def test_file_contains_correct_data(self, mock_local_discovery_dir): + """Service file contains correct JSON data.""" + uuid = uuid4() + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=12345, + uuid=uuid, + properties={'key': 'value'}, + ) as ad: + data = json.loads(ad.service_file.read_text()) + assert data['service_name'] == 'service_name' + assert data['service_type'] == 'service_type' + assert data['address'] == 'tcp://127.0.0.1:5555' + assert data['pid'] == 12345 + assert data['uuid'] == str(uuid) + assert data['properties'] == {'key': 'value'} + + def test_discover_finds_service(self, mock_local_discovery_dir): + """discover() yields advertised services.""" + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + properties={'name': 'test'}, + ): + results = list(ipc.LocalServiceAdvertisement.discover('service_type')) + assert len(results) == 1 + assert results[0].service_name == 'service_name' + assert results[0].address == 'tcp://127.0.0.1:5555' + assert results[0].properties == {'name': 'test'} + + def test_discover_filters_by_properties(self, mock_local_discovery_dir): + """discover() filters services by matching properties.""" + with ( + ipc.LocalServiceAdvertisement( + service_name='service_name_1', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + properties={'name': 'first'}, + ), + ipc.LocalServiceAdvertisement( + service_name='service_name_2', + service_type='service_type', + address='tcp://127.0.0.1:5556', + pid=os.getpid(), + properties={'name': 'second'}, + ), + ): + results = list(ipc.LocalServiceAdvertisement.discover('service_type')) + assert len(results) == 2 + results = list( + ipc.LocalServiceAdvertisement.discover( + 'service_type', properties={'name': 'second'} + ) + ) + assert len(results) == 1 + assert results[0].address == 'tcp://127.0.0.1:5556' + + def test_discover_removes_stale_advertisements( + self, mock_local_discovery_dir, mocker + ): + """discover() removes files for dead processes.""" + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=99999999, # non-existent PID + ) as ad: + assert ad.service_file.exists() + mocker.patch('bpod_core.ipc.pid_exists', return_value=False) + results = list(ipc.LocalServiceAdvertisement.discover('service_type')) + assert len(results) == 0 + assert not ad.service_file.exists() + assert not mock_local_discovery_dir.exists() + + def test_discover_nonexistent_service(self, mock_local_discovery_dir): + """discover() returns empty iterator for unknown service types.""" + assert list(ipc.LocalServiceAdvertisement.discover('nonexistent')) == [] + + def test_service_type_sanitization(self, mock_local_discovery_dir): + """Service types with special characters are sanitized.""" + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type.local', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + ) as ad: + assert '.' not in ad.service_file.parent.name + + def test_handles_corrupted_file(self, mock_local_discovery_dir): + """discover() skips corrupted JSON files.""" + # Create a valid advertisement first to get the directory + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:5555', + pid=os.getpid(), + ) as ad: + service_dir = ad.service_file.parent + corrupted_file = service_dir / 'corrupted.json' + corrupted_file.write_text('not valid json') + results = list(ipc.LocalServiceAdvertisement.discover('service_type')) + assert len(results) == 1 + assert results[0].address == 'tcp://127.0.0.1:5555' @pytest.fixture @@ -17,7 +156,7 @@ class TestClient: """Tests for the DualChannelClient class.""" @pytest.fixture - def host(self, mock_zeroconf): + def host(self, mock_advertisement): with ipc.DualChannelHost( service_name='TestService', service_type='dualtest', @@ -38,8 +177,7 @@ def test_handshake(self, client): """Verify handshake exchanges addresses and negotiates serialization.""" assert client._address_req.startswith(('tcp://', 'ipc://')) assert client._address_sub.startswith(('tcp://', 'ipc://')) - # client should downgrade serialization to host's json - assert client._serialization == 'json' + assert client._serialization == 'json', 'Serialization should be JSON' def test_request_response(self, client): """Round-trip a request to the host and validate payload.""" @@ -51,51 +189,130 @@ def test_unknown_request_type(self, client, caplog): with caplog.at_level('ERROR'): client._req(request_type='invalid') - def test_error_response(self, host, client, caplog): + def test_error_response(self, mock_advertisement, mock_service_browser, caplog): """Verify server exceptions are logged and client gets empty dict.""" def bad_handler(_): raise RuntimeError('boom') - host._user_event_handler = bad_handler - with caplog.at_level('ERROR'): - reply = client.request(foo='bar') - assert reply == {} - error_logs = [rec for rec in caplog.records if rec.levelname == 'ERROR'] - assert any( - 'RuntimeError' in rec.message and 'boom' in rec.message - for rec in error_logs - ) + with ( + ipc.DualChannelHost('Test', 'service', event_handler=bad_handler) as host, + ipc.DualChannelClient('service', host.rep_tcp_addr) as client, + ): + with caplog.at_level('ERROR'): + reply = client.request(foo='bar') + assert reply == {} + error_logs = [rec for rec in caplog.records if rec.levelname == 'ERROR'] + assert any( + 'RuntimeError' in rec.message and 'boom' in rec.message + for rec in error_logs + ) class TestHost: """Tests for the DualChannelHost class.""" - @pytest.fixture - def mock_service(self, mock_zeroconf): - with ipc.DualChannelHost('test', 'testservice') as service: - yield service - - def test_basic_init_and_properties(self, mock_service): + def test_basic_init_and_properties(self, mock_advertisement): """Check ports, addresses, and Zeroconf objects are initialized.""" - assert mock_service.rep_tcp_port > 0 - assert mock_service.rep_tcp_addr.startswith('tcp://') - assert mock_service._zeroconf is not None - assert mock_service._service_info is not None + with ipc.DualChannelHost('test', 'test_service') as host: + assert host.rep_tcp_addr.startswith('tcp://') + assert host._zeroconf is not None + assert host._zeroconf_service_info is not None - @pytest.mark.parametrize('remote', [True, False]) - def test_bind_address_matches_local_flag(self, mock_zeroconf, remote): + @pytest.mark.parametrize('remote', [True, False], ids=['remote', 'local']) + def test_bind_address(self, mock_advertisement, remote): """Validate bind address switches between 0.0.0.0 and 127.0.0.1.""" - service = ipc.DualChannelHost('test', 'testservice', remote=remote) - ip = service._bind_ip - expected_ip = '0.0.0.0' if remote else '127.0.0.1' - assert ip == expected_ip + with ipc.DualChannelHost('test', 'service_type', remote=remote) as host: + expected_ip = '0.0.0.0' if remote else '127.0.0.1' + assert host._bind_ip == expected_ip, f'Bind IP should be {expected_ip}' + + def test_remote_true_creates_zeroconf_and_local(self, mock_advertisement): + """remote=True creates both zeroconf and local advertisement.""" + with ipc.DualChannelHost('test', 'test_service', remote=True) as host: + assert host._zeroconf is not None + assert host._zeroconf_service_info is not None + host._zeroconf.register_service.assert_called_once() + host._zeroconf.close.assert_not_called() + assert host._local_advertisement is not None + assert host._local_advertisement.service_file.exists() + host._zeroconf.close.assert_called_once() + assert not host._local_advertisement.service_file.exists() + assert not host._local_advertisement.runtime_directory.exists() + + def test_remote_false_creates_only_local(self, mock_advertisement): + """remote=False creates only local advertisement, no zeroconf.""" + with ipc.DualChannelHost('test', 'test_service', remote=False) as host: + assert host._zeroconf is None + assert host._zeroconf_service_info is None + assert host._local_advertisement is not None + assert host._local_advertisement.service_file.exists() + assert not host._local_advertisement.service_file.exists() + assert not host._local_advertisement.runtime_directory.exists() + + def test_close_removes_local_advertisement(self, mock_advertisement): + """close() removes the local advertisement file.""" + host = ipc.DualChannelHost('test', 'test_service', remote=False) + service_file = host._local_advertisement.service_file + assert service_file.exists() + host.close() + assert not service_file.exists() + + +class TestLocalDiscovery: + """Tests for local vs remote discovery behavior.""" + + def test_client_discovers_host_locally(self, mock_advertisement): + """Client discovers host via local advertisement without zeroconf.""" + with ( + ipc.DualChannelHost('test', 'service', event_handler=lambda d: {'req': d}), + ipc.DualChannelClient(service_type='service', remote=False) as client, + ): + assert client._address_req.startswith(('tcp://', 'ipc://')) + reply = client.request(test='value') + assert reply == {'req': {'test': 'value'}} + + def test_discover_prefers_local_over_zeroconf( + self, mock_advertisement, mock_service_browser + ): + """discover() returns local service without invoking zeroconf.""" + with ipc.LocalServiceAdvertisement( + service_name='service_name', + service_type='service_type', + address='tcp://127.0.0.1:9999', + pid=os.getpid(), + ): + address, properties = ipc.discover('service_type', remote=True, timeout=0) + assert address == 'tcp://127.0.0.1:9999' + mock_advertisement['zeroconf'].assert_not_called() + + def test_discover_remote_false_raises_if_no_local(self, mock_advertisement): + """discover() with remote=False raises if no local service found.""" + with pytest.raises(RuntimeError, match='No matching service found locally'): + ipc.discover('nonexistent', remote=False, timeout=0) + + def test_discover_falls_back_to_zeroconf( + self, mocker, mock_advertisement, mock_service_browser + ): + """discover() falls back to zeroconf when no local service exists.""" + mocker.patch('threading.Event.wait', return_value=False) + with pytest.raises(TimeoutError): + ipc.discover('test_service', remote=True, timeout=0) + mock_advertisement['zeroconf'].assert_called_once() + def test_client_remote_false_uses_only_local(self, mock_advertisement): + """Client with remote=False only uses local discovery.""" + with ( + ipc.DualChannelHost('test', 'localonly', remote=False), + ipc.DualChannelClient(service_type='localonly', remote=False), + ): + mock_advertisement['zeroconf'].assert_not_called() -class TestDiscover: - """Tests for the discover function.""" + def test_client_remote_false_raises_if_no_local(self, mock_advertisement): + """Client with remote=False raises if no local service found.""" + with pytest.raises(RuntimeError, match='No matching service found locally'): + ipc.DualChannelClient(service_type='nonexistent', remote=False) - def test_discover_timeout(self, mocker, mock_zeroconf, mock_service_browser): + def test_discover_timeout(self, mocker, mock_advertisement, mock_service_browser): """Timeout when no matching service is discovered within deadline.""" mocker.patch('threading.Event.wait', return_value=False) with pytest.raises(TimeoutError): diff --git a/tests/test_misc.py b/tests/test_misc.py index cf1a864d..ea40078e 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -10,25 +10,6 @@ from bpod_core.misc import ValidatedDict -class TestSanitizeString: - """Tests for sanitize_string utility.""" - - def test_basic_substitution(self): - """Replaces spaces and hyphens with underscores.""" - assert misc.sanitize_string(' foo bar-123 ') == '_foo_bar_123_' - - def test_custom_substitute(self): - """Uses a custom substitute character.""" - assert misc.sanitize_string('foo bar!', substitute='-') == 'foo-bar-' - - def test_invalid_types(self): - """Raises TypeError on invalid argument types.""" - with pytest.raises(TypeError): - misc.sanitize_string('foo', substitute=1) # type: ignore - with pytest.raises(TypeError): - misc.sanitize_string(1) # type: ignore - - @pytest.mark.parametrize( ('text', 'expected'), [ @@ -40,11 +21,12 @@ def test_invalid_types(self): ('_Foo_Bar_', 'foo_bar'), ('123Bar', '123_bar'), ('Foo123', 'foo_123'), + ('HTTPSConnection', 'https_connection'), ], ) -def test_convert_to_snake_case(text, expected): +def test_to_snakecase(text, expected): """Converts various input styles to snake_case.""" - assert misc.convert_to_snake_case(text) == expected + assert misc.to_snake_case(text) == expected class TestSuggestSimilar: @@ -490,6 +472,187 @@ def test_missing_key(self, validated_dict): del validated_dict['missing'] +class TestPruneEmptyParentDirectories: + """Tests for prune_empty_parent_directories utility.""" + + def test_removes_single_empty_directory(self, tmp_path): + """Removes a single empty target directory.""" + root = tmp_path / 'root' + target = root / 'empty' + root.mkdir() + target.mkdir() + + misc.prune_empty_parent_directories(target, root) + + assert not target.exists() + assert root.exists() + + def test_removes_nested_empty_directories(self, tmp_path): + """Recursively removes empty parent directories up to root.""" + root = tmp_path / 'root' + nested = root / 'a' / 'b' / 'c' + nested.mkdir(parents=True) + + misc.prune_empty_parent_directories(nested, root) + + assert not (root / 'a').exists() + assert root.exists() + + def test_stops_at_non_empty_parent(self, tmp_path): + """Stops removing when a parent directory contains other files.""" + root = tmp_path / 'root' + branch_a = root / 'parent' / 'empty' + branch_b = root / 'parent' / 'sibling.txt' + branch_a.mkdir(parents=True) + branch_b.write_text('content') + + misc.prune_empty_parent_directories(branch_a, root) + + assert not branch_a.exists() + assert (root / 'parent').exists() # parent not removed (has sibling) + assert branch_b.exists() + + def test_does_not_remove_non_empty_target(self, tmp_path): + """Target directory is not removed if it contains files.""" + root = tmp_path / 'root' + target = root / 'not_empty' + target.mkdir(parents=True) + (target / 'file.txt').write_text('content') + + misc.prune_empty_parent_directories(target, root) + + assert target.exists() + assert (target / 'file.txt').exists() + + def test_raises_for_missing_target(self, tmp_path): + """Raises FileNotFoundError if target directory does not exist.""" + root = tmp_path / 'root' + root.mkdir() + missing = root / 'missing' + + with pytest.raises(FileNotFoundError, match='does not exist'): + misc.prune_empty_parent_directories(missing, root) + + def test_raises_for_missing_root(self, tmp_path): + """Raises FileNotFoundError if root directory does not exist.""" + target = tmp_path / 'target' + target.mkdir() + missing_root = tmp_path / 'missing_root' + + with pytest.raises(FileNotFoundError, match='does not exist'): + misc.prune_empty_parent_directories(target, missing_root) + + def test_raises_for_file_target(self, tmp_path): + """Raises NotADirectoryError if target is a file.""" + root = tmp_path / 'root' + root.mkdir() + file_target = root / 'file.txt' + file_target.write_text('content') + + with pytest.raises(NotADirectoryError, match='is not a directory'): + misc.prune_empty_parent_directories(file_target, root) + + def test_raises_for_file_root(self, tmp_path): + """Raises NotADirectoryError if root is a file.""" + file_root = tmp_path / 'file.txt' + file_root.write_text('content') + target = tmp_path / 'target' + target.mkdir() + + with pytest.raises(NotADirectoryError, match='is not a directory'): + misc.prune_empty_parent_directories(target, file_root) + + def test_raises_for_non_subpath(self, tmp_path): + """Raises ValueError if target is not a subpath of root.""" + root = tmp_path / 'root' + other = tmp_path / 'other' + root.mkdir() + other.mkdir() + + with pytest.raises(ValueError, match='is not a sub-directory'): + misc.prune_empty_parent_directories(other, root) + + def test_accepts_string_paths(self, tmp_path): + """Works with string paths, not just Path objects.""" + root = tmp_path / 'root' + target = root / 'empty' + root.mkdir() + target.mkdir() + + misc.prune_empty_parent_directories(str(target), str(root)) + + assert not target.exists() + assert root.exists() + + @pytest.mark.parametrize('error_class', (PermissionError, FileNotFoundError)) + def test_handles_errors_gracefully(self, tmp_path, mocker, error_class): + """Returns silently when rmdir raises OSError.""" + root = tmp_path / 'root' + target = root / 'problematic' + root.mkdir() + target.mkdir() + + mocker.patch.object(target.__class__, 'rmdir', side_effect=error_class()) + + misc.prune_empty_parent_directories(target, root) + assert target.exists() + + def test_remove_root_true_removes_empty_root(self, tmp_path): + """With remove_root=True, empty root directory is removed.""" + root = tmp_path / 'root' + nested = root / 'a' / 'b' / 'c' + nested.mkdir(parents=True) + + misc.prune_empty_parent_directories(nested, root, remove_root=True) + + assert not root.exists() + + def test_remove_root_true_preserves_non_empty_root(self, tmp_path): + """With remove_root=True, non-empty root is still preserved.""" + root = tmp_path / 'root' + target = root / 'empty' + sibling = root / 'sibling.txt' + target.mkdir(parents=True) + sibling.write_text('content') + + misc.prune_empty_parent_directories(target, root, remove_root=True) + + assert not target.exists() + assert root.exists() + assert sibling.exists() + + @pytest.mark.parametrize('error_class', (PermissionError, FileNotFoundError)) + def test_remove_root_handles_errors_gracefully(self, tmp_path, mocker, error_class): + """With remove_root=True, OSError on root removal is handled gracefully.""" + root = tmp_path / 'root' + root.mkdir() + + mocker.patch.object(root.__class__, 'rmdir', side_effect=error_class()) + + misc.prune_empty_parent_directories(root, root, remove_root=True) + assert root.exists() + + def test_handles_race_condition(self, tmp_path, mocker): + """Returns silently if parent is deleted between rmdir and recursion.""" + root = tmp_path / 'root' + parent = root / 'parent' + target = parent / 'child' + target.mkdir(parents=True) + + original_is_dir = type(target).is_dir + + def simulate_race(path): + # Simulate race: parent deleted by another process after target removed + if path == parent and not target.exists(): + return False + return original_is_dir(path) + + mocker.patch.object(type(target), 'is_dir', simulate_race) + + misc.prune_empty_parent_directories(target, root) + assert not target.exists() + + class TestDocstringInheritanceMixin: """Tests for DocstringInheritanceMixin.""" diff --git a/uv.lock b/uv.lock index c665224d..ff8cd1cc 100644 --- a/uv.lock +++ b/uv.lock @@ -137,6 +137,7 @@ dependencies = [ { name = "numpy", version = "2.4.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "packaging" }, { name = "platformdirs" }, + { name = "psutil" }, { name = "pydantic" }, { name = "pyserial" }, { name = "pyyaml" }, @@ -174,6 +175,7 @@ dev = [ { name = "tox" }, { name = "tox-uv" }, { name = "types-docutils" }, + { name = "types-psutil" }, { name = "types-pyserial" }, { name = "types-pyyaml" }, ] @@ -207,6 +209,7 @@ test = [ { name = "tox" }, { name = "tox-uv" }, { name = "types-docutils" }, + { name = "types-psutil" }, { name = "types-pyserial" }, { name = "types-pyyaml" }, ] @@ -215,6 +218,7 @@ typing = [ { name = "docutils-stubs" }, { name = "mypy" }, { name = "types-docutils" }, + { name = "types-psutil" }, { name = "types-pyserial" }, { name = "types-pyyaml" }, ] @@ -227,6 +231,7 @@ requires-dist = [ { name = "numpy" }, { name = "packaging", specifier = ">=25.0" }, { name = "platformdirs", specifier = ">=4.4.0" }, + { name = "psutil", specifier = ">=7.0.0" }, { name = "pydantic", specifier = ">=2.0" }, { name = "pyserial", specifier = ">=3.5" }, { name = "pyyaml", specifier = ">=6.0.0" }, @@ -261,9 +266,10 @@ dev = [ { name = "sphinx-toolbox", specifier = ">=4.0.0" }, { name = "tox", specifier = ">=4.25.0" }, { name = "tox-uv", specifier = ">=0.7.2" }, - { name = "types-docutils", specifier = ">=0.22.0.20250822" }, - { name = "types-pyserial", specifier = ">=3.5.0.20250326" }, - { name = "types-pyyaml", specifier = ">=6.0.12.20250822" }, + { name = "types-docutils", specifier = ">=0.22.0" }, + { name = "types-psutil", specifier = ">=7.0.0" }, + { name = "types-pyserial", specifier = ">=3.5.0" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, ] doc = [ { name = "myst-parser", specifier = ">=4.0.1" }, @@ -290,17 +296,19 @@ test = [ { name = "ruff", specifier = ">=0.11.5" }, { name = "tox", specifier = ">=4.25.0" }, { name = "tox-uv", specifier = ">=0.7.2" }, - { name = "types-docutils", specifier = ">=0.22.0.20250822" }, - { name = "types-pyserial", specifier = ">=3.5.0.20250326" }, - { name = "types-pyyaml", specifier = ">=6.0.12.20250822" }, + { name = "types-docutils", specifier = ">=0.22.0" }, + { name = "types-psutil", specifier = ">=7.0.0" }, + { name = "types-pyserial", specifier = ">=3.5.0" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, ] typing = [ { name = "appdirs-stubs", specifier = ">=0.1.0" }, { name = "docutils-stubs", specifier = ">=0.0.22" }, { name = "mypy", specifier = ">=1.15.0" }, - { name = "types-docutils", specifier = ">=0.22.0.20250822" }, - { name = "types-pyserial", specifier = ">=3.5.0.20250326" }, - { name = "types-pyyaml", specifier = ">=6.0.12.20250822" }, + { name = "types-docutils", specifier = ">=0.22.0" }, + { name = "types-psutil", specifier = ">=7.0.0" }, + { name = "types-pyserial", specifier = ">=3.5.0" }, + { name = "types-pyyaml", specifier = ">=6.0.0" }, ] [[package]] @@ -1470,6 +1478,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "psutil" +version = "7.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" }, + { url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" }, + { url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" }, + { url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" }, + { url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" }, + { url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" }, + { url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" }, + { url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" }, + { url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" }, + { url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" }, + { url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" }, + { url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" }, + { url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" }, + { url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" }, + { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" }, + { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" }, + { url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" }, + { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -2446,6 +2482,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/01/61ac9eb38f1f978b47443dc6fd2e0a3b0f647c2da741ddad30771f1b2b6f/types_docutils-0.22.3.20251115-py3-none-any.whl", hash = "sha256:c6e53715b65395d00a75a3a8a74e352c669bc63959e65a207dffaa22f4a2ad6e", size = 91951, upload-time = "2025-11-15T02:59:56.413Z" }, ] +[[package]] +name = "types-psutil" +version = "7.2.2.20260130" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/69/14/fc5fb0a6ddfadf68c27e254a02ececd4d5c7fdb0efcb7e7e917a183497fb/types_psutil-7.2.2.20260130.tar.gz", hash = "sha256:15b0ab69c52841cf9ce3c383e8480c620a4d13d6a8e22b16978ebddac5590950", size = 26535, upload-time = "2026-01-30T03:58:14.116Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/d7/60974b7e31545d3768d1770c5fe6e093182c3bfd819429b33133ba6b3e89/types_psutil-7.2.2.20260130-py3-none-any.whl", hash = "sha256:15523a3caa7b3ff03ac7f9b78a6470a59f88f48df1d74a39e70e06d2a99107da", size = 32876, upload-time = "2026-01-30T03:58:13.172Z" }, +] + [[package]] name = "types-pyserial" version = "3.5.0.20251001"