Skip to content

Commit

Permalink
Fix handling of command timeouts (#100)
Browse files Browse the repository at this point in the history
* Improve command logic

* Add test for event during command

* Change visability
  • Loading branch information
andrewsayre authored Feb 9, 2025
1 parent e04cd54 commit 00fc347
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 21 deletions.
37 changes: 20 additions & 17 deletions pyheos/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ async def _read_handler(self, reader: asyncio.StreamReader) -> None:
return
else:
self._last_activity = datetime.now()
await self._handle_message(
self._handle_message(
HeosMessage._from_raw_message(binary_result.decode())
)

async def _handle_message(self, message: HeosMessage) -> None:
def _handle_message(self, message: HeosMessage) -> None:
"""Handle a message received from the HEOS device."""
if message.is_under_process:
_LOGGER.debug("Command under process '%s'", message.command)
Expand All @@ -152,7 +152,10 @@ async def _handle_message(self, message: HeosMessage) -> None:
return

# Set the message on the pending command.
self._pending_command_event.set(message)
if not self._pending_command_event.set(message):
_LOGGER.debug(
"Unexpected response received: '%s': '%s'", message.command, message
)

async def command(self, command: HeosCommand) -> HeosMessage:
"""Send a command to the HEOS device."""
Expand All @@ -165,19 +168,19 @@ async def _command_impl() -> HeosMessage:
raise CommandError(command.command, "Not connected to device")
if TYPE_CHECKING:
assert self._writer is not None
assert not self._pending_command_event.is_set()

# Send the command
try:
self._writer.write((command.uri + SEPARATOR).encode())
await self._writer.drain()
except (ConnectionError, OSError, AttributeError) as error:
# Occurs when the connection is broken. Run in the background to ensure connection is reset.
self._register_task(
self._disconnect_from_error(error), "Disconnect From Error"
)
_LOGGER.debug(
"Command failed '%s': %s: %s", command, type(error).__name__, error
)
self._register_task(
self._disconnect_from_error(error), "Disconnect From Error"
)
raise CommandError(
command.command, f"Command failed: {error}"
) from error
Expand All @@ -192,7 +195,7 @@ async def _command_impl() -> HeosMessage:
# Wait for the response with a timeout
try:
response = await asyncio.wait_for(
self._pending_command_event.wait(), self._timeout
self._pending_command_event.wait(command.command), self._timeout
)
except asyncio.TimeoutError as error:
# Occurs when the command times out
Expand All @@ -201,9 +204,6 @@ async def _command_impl() -> HeosMessage:
finally:
self._pending_command_event.clear()

# The retrieved response should match the command
assert command.command == response.command

# Check the result
if not response.result:
_LOGGER.debug("Command failed '%s': '%s'", command, response)
Expand Down Expand Up @@ -340,24 +340,27 @@ def __init__(self) -> None:
"""Init a new instance of the CommandEvent."""
self._event: asyncio.Event = asyncio.Event()
self._response: HeosMessage | None = None
self._target_command: str | None = None

async def wait(self) -> HeosMessage:
async def wait(self, target_command: str) -> HeosMessage:
"""Wait until the event is set."""
self._target_command = target_command
await self._event.wait()
if TYPE_CHECKING:
assert self._response is not None
return self._response

def set(self, response: HeosMessage) -> None:
def set(self, response: HeosMessage) -> bool:
"""Set the response."""
if self._target_command is None or self._target_command != response.command:
return False
self._target_command = None
self._response = response
self._event.set()
return True

def clear(self) -> None:
"""Clear the event."""
self._response = None
self._target_command = None
self._event.clear()

def is_set(self) -> bool:
"""Return True if the event is set."""
return self._event.is_set()
51 changes: 47 additions & 4 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import asyncio
import functools
import json
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, cast
from urllib.parse import parse_qsl, quote_plus, urlencode, urlparse
Expand Down Expand Up @@ -273,6 +274,7 @@ def __init__(self) -> None:
self._started: bool = False
self.connections: list[ConnectionLog] = []
self._matchers: list[CommandMatcher] = []
self.modifiers: list[CommandModifier] = []

async def start(self) -> None:
"""Start the heos server."""
Expand Down Expand Up @@ -354,6 +356,18 @@ def assert_command_called(
f"Command was not registered: {target_command} with args {target_args}."
)

@contextmanager
def modify(
self, command: str, *, replay_response: int = 1, delay_response: float = 0.0
) -> Generator[None]:
"""Modifies behavior of command processing."""
modifier = CommandModifier(
command, replay_response=replay_response, delay_response=delay_response
)
self.modifiers.append(modifier)
yield
self.modifiers.remove(modifier)

async def _handle_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
Expand Down Expand Up @@ -387,10 +401,27 @@ async def _handle_connection(
None,
)
if matcher:
# Apply modifiers
modifier = next(
(
modifier
for modifier in self.modifiers
if modifier.command == command
),
DEFAULT_MODIFIER,
)

# Delay the response if set
if modifier.delay_response > 0:
await asyncio.sleep(modifier.delay_response)

responses = await matcher.get_response(query)
for response in responses:
writer.write((response + SEPARATOR).encode())
await writer.drain()
# Write the response multiple times if set
for _ in range(modifier.replay_response):
for response in responses:
writer.write((response + SEPARATOR).encode())
await writer.drain()

continue

# Special processing for known/unknown commands
Expand Down Expand Up @@ -512,3 +543,15 @@ async def write(self, payload: str) -> None:
data = (payload + SEPARATOR).encode()
self._writer.write(data)
await self._writer.drain()


@dataclass
class CommandModifier:
"""Define a command modifier."""

command: str
replay_response: int = field(kw_only=True, default=1)
delay_response: float = field(kw_only=True, default=0.0)


DEFAULT_MODIFIER = CommandModifier(c.COMMAND_GET_PLAYERS)
40 changes: 40 additions & 0 deletions tests/test_heos.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

from . import (
CallCommand,
CommandModifier,
MockHeosDevice,
calls_command,
calls_commands,
Expand Down Expand Up @@ -317,6 +318,45 @@ async def test_commands_fail_when_disconnected(
)


@calls_command("system.heart_beat")
async def test_command_timeout(mock_device: MockHeosDevice, heos: Heos) -> None:
"""Test command times out."""
with mock_device.modify(c.COMMAND_HEART_BEAT, delay_response=0.2):
with pytest.raises(CommandError):
await heos.heart_beat()
await asyncio.sleep(0.2)
await heos.heart_beat()


@calls_command("system.heart_beat")
async def test_command_duplicate_response(
mock_device: MockHeosDevice, heos: Heos, caplog: pytest.LogCaptureFixture
) -> None:
"""Test a duplicate command response is discarded."""
with mock_device.modify(c.COMMAND_HEART_BEAT, replay_response=2):
await heos.heart_beat()
while "Unexpected response received: 'system/heart_beat'" not in caplog.text:
await asyncio.sleep(0.1)


@calls_command("system.heart_beat")
async def test_event_received_during_command(mock_device: MockHeosDevice) -> None:
"""Test event received during command execution."""
heos = await Heos.create_and_connect("127.0.0.1", heart_beat=False)

mock_device.modifiers.append(
CommandModifier(c.COMMAND_HEART_BEAT, delay_response=0.2)
)
command_task = asyncio.create_task(heos.heart_beat())

await asyncio.sleep(0.1)
await mock_device.write_event("event.user_changed_signed_in")

await command_task

await heos.disconnect()


async def test_connection_error(mock_device: MockHeosDevice, heos: Heos) -> None:
"""Test connection error during event results in disconnected."""
disconnect_signal = connect_handler(
Expand Down

0 comments on commit 00fc347

Please sign in to comment.