From 00fc347975df4a1795364343e67413a55452c068 Mon Sep 17 00:00:00 2001 From: Andrew Sayre <6730289+andrewsayre@users.noreply.github.com> Date: Sun, 9 Feb 2025 17:18:57 -0600 Subject: [PATCH] Fix handling of command timeouts (#100) * Improve command logic * Add test for event during command * Change visability --- pyheos/connection.py | 37 +++++++++++++++++--------------- tests/__init__.py | 51 ++++++++++++++++++++++++++++++++++++++++---- tests/test_heos.py | 40 ++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 21 deletions(-) diff --git a/pyheos/connection.py b/pyheos/connection.py index 546addc..4ac6646 100644 --- a/pyheos/connection.py +++ b/pyheos/connection.py @@ -137,11 +137,11 @@ async def _read_handler(self, reader: asyncio.StreamReader) -> None: return else: self._last_activity = datetime.now() - await self._handle_message( + self._handle_message( HeosMessage._from_raw_message(binary_result.decode()) ) - async def _handle_message(self, message: HeosMessage) -> None: + def _handle_message(self, message: HeosMessage) -> None: """Handle a message received from the HEOS device.""" if message.is_under_process: _LOGGER.debug("Command under process '%s'", message.command) @@ -152,7 +152,10 @@ async def _handle_message(self, message: HeosMessage) -> None: return # Set the message on the pending command. - self._pending_command_event.set(message) + if not self._pending_command_event.set(message): + _LOGGER.debug( + "Unexpected response received: '%s': '%s'", message.command, message + ) async def command(self, command: HeosCommand) -> HeosMessage: """Send a command to the HEOS device.""" @@ -165,19 +168,19 @@ async def _command_impl() -> HeosMessage: raise CommandError(command.command, "Not connected to device") if TYPE_CHECKING: assert self._writer is not None - assert not self._pending_command_event.is_set() + # Send the command try: self._writer.write((command.uri + SEPARATOR).encode()) await self._writer.drain() except (ConnectionError, OSError, AttributeError) as error: # Occurs when the connection is broken. Run in the background to ensure connection is reset. - self._register_task( - self._disconnect_from_error(error), "Disconnect From Error" - ) _LOGGER.debug( "Command failed '%s': %s: %s", command, type(error).__name__, error ) + self._register_task( + self._disconnect_from_error(error), "Disconnect From Error" + ) raise CommandError( command.command, f"Command failed: {error}" ) from error @@ -192,7 +195,7 @@ async def _command_impl() -> HeosMessage: # Wait for the response with a timeout try: response = await asyncio.wait_for( - self._pending_command_event.wait(), self._timeout + self._pending_command_event.wait(command.command), self._timeout ) except asyncio.TimeoutError as error: # Occurs when the command times out @@ -201,9 +204,6 @@ async def _command_impl() -> HeosMessage: finally: self._pending_command_event.clear() - # The retrieved response should match the command - assert command.command == response.command - # Check the result if not response.result: _LOGGER.debug("Command failed '%s': '%s'", command, response) @@ -340,24 +340,27 @@ def __init__(self) -> None: """Init a new instance of the CommandEvent.""" self._event: asyncio.Event = asyncio.Event() self._response: HeosMessage | None = None + self._target_command: str | None = None - async def wait(self) -> HeosMessage: + async def wait(self, target_command: str) -> HeosMessage: """Wait until the event is set.""" + self._target_command = target_command await self._event.wait() if TYPE_CHECKING: assert self._response is not None return self._response - def set(self, response: HeosMessage) -> None: + def set(self, response: HeosMessage) -> bool: """Set the response.""" + if self._target_command is None or self._target_command != response.command: + return False + self._target_command = None self._response = response self._event.set() + return True def clear(self) -> None: """Clear the event.""" self._response = None + self._target_command = None self._event.clear() - - def is_set(self) -> bool: - """Return True if the event is set.""" - return self._event.is_set() diff --git a/tests/__init__.py b/tests/__init__.py index 3b73ce3..7834719 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,8 +3,9 @@ import asyncio import functools import json -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any, cast from urllib.parse import parse_qsl, quote_plus, urlencode, urlparse @@ -273,6 +274,7 @@ def __init__(self) -> None: self._started: bool = False self.connections: list[ConnectionLog] = [] self._matchers: list[CommandMatcher] = [] + self.modifiers: list[CommandModifier] = [] async def start(self) -> None: """Start the heos server.""" @@ -354,6 +356,18 @@ def assert_command_called( f"Command was not registered: {target_command} with args {target_args}." ) + @contextmanager + def modify( + self, command: str, *, replay_response: int = 1, delay_response: float = 0.0 + ) -> Generator[None]: + """Modifies behavior of command processing.""" + modifier = CommandModifier( + command, replay_response=replay_response, delay_response=delay_response + ) + self.modifiers.append(modifier) + yield + self.modifiers.remove(modifier) + async def _handle_connection( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: @@ -387,10 +401,27 @@ async def _handle_connection( None, ) if matcher: + # Apply modifiers + modifier = next( + ( + modifier + for modifier in self.modifiers + if modifier.command == command + ), + DEFAULT_MODIFIER, + ) + + # Delay the response if set + if modifier.delay_response > 0: + await asyncio.sleep(modifier.delay_response) + responses = await matcher.get_response(query) - for response in responses: - writer.write((response + SEPARATOR).encode()) - await writer.drain() + # Write the response multiple times if set + for _ in range(modifier.replay_response): + for response in responses: + writer.write((response + SEPARATOR).encode()) + await writer.drain() + continue # Special processing for known/unknown commands @@ -512,3 +543,15 @@ async def write(self, payload: str) -> None: data = (payload + SEPARATOR).encode() self._writer.write(data) await self._writer.drain() + + +@dataclass +class CommandModifier: + """Define a command modifier.""" + + command: str + replay_response: int = field(kw_only=True, default=1) + delay_response: float = field(kw_only=True, default=0.0) + + +DEFAULT_MODIFIER = CommandModifier(c.COMMAND_GET_PLAYERS) diff --git a/tests/test_heos.py b/tests/test_heos.py index d8bf162..f268b62 100644 --- a/tests/test_heos.py +++ b/tests/test_heos.py @@ -52,6 +52,7 @@ from . import ( CallCommand, + CommandModifier, MockHeosDevice, calls_command, calls_commands, @@ -317,6 +318,45 @@ async def test_commands_fail_when_disconnected( ) +@calls_command("system.heart_beat") +async def test_command_timeout(mock_device: MockHeosDevice, heos: Heos) -> None: + """Test command times out.""" + with mock_device.modify(c.COMMAND_HEART_BEAT, delay_response=0.2): + with pytest.raises(CommandError): + await heos.heart_beat() + await asyncio.sleep(0.2) + await heos.heart_beat() + + +@calls_command("system.heart_beat") +async def test_command_duplicate_response( + mock_device: MockHeosDevice, heos: Heos, caplog: pytest.LogCaptureFixture +) -> None: + """Test a duplicate command response is discarded.""" + with mock_device.modify(c.COMMAND_HEART_BEAT, replay_response=2): + await heos.heart_beat() + while "Unexpected response received: 'system/heart_beat'" not in caplog.text: + await asyncio.sleep(0.1) + + +@calls_command("system.heart_beat") +async def test_event_received_during_command(mock_device: MockHeosDevice) -> None: + """Test event received during command execution.""" + heos = await Heos.create_and_connect("127.0.0.1", heart_beat=False) + + mock_device.modifiers.append( + CommandModifier(c.COMMAND_HEART_BEAT, delay_response=0.2) + ) + command_task = asyncio.create_task(heos.heart_beat()) + + await asyncio.sleep(0.1) + await mock_device.write_event("event.user_changed_signed_in") + + await command_task + + await heos.disconnect() + + async def test_connection_error(mock_device: MockHeosDevice, heos: Heos) -> None: """Test connection error during event results in disconnected.""" disconnect_signal = connect_handler(