From d7bc9bdb08913bb8ca9acabb9988f68b6a8d4cdc Mon Sep 17 00:00:00 2001 From: Andrew Sayre <6730289+andrewsayre@users.noreply.github.com> Date: Sun, 26 Jan 2025 05:13:05 +0000 Subject: [PATCH] Improvements to connection --- pyheos/connection.py | 52 ++++++++++++++++++-------------------------- tests/test_heos.py | 8 ++++++- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/pyheos/connection.py b/pyheos/connection.py index 7a62fef..546addc 100644 --- a/pyheos/connection.py +++ b/pyheos/connection.py @@ -2,10 +2,10 @@ import asyncio import logging -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from contextlib import suppress from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Final +from typing import TYPE_CHECKING, Any, Final from pyheos.command import COMMAND_HEART_BEAT, COMMAND_REBOOT from pyheos.message import HeosCommand, HeosMessage @@ -89,9 +89,11 @@ async def _on_command_error(self, error: CommandFailedError) -> None: for callback in self._on_command_error_callbacks: await callback(error) - def _register_task(self, future: Awaitable[None]) -> None: + def _register_task( + self, future: Coroutine[Any, Any, None], name: str | None = None + ) -> None: """Register a task that is running in the background, so it can be canceled and reset later.""" - task = asyncio.ensure_future(future) + task: asyncio.Task[None] = asyncio.create_task(future, name=name) self._running_tasks.add(task) task.add_done_callback(self._running_tasks.discard) @@ -101,19 +103,14 @@ async def _reset(self) -> None: while self._running_tasks: task = self._running_tasks.pop() if task.cancel(): - try: + with suppress(asyncio.CancelledError): await task - except asyncio.CancelledError: - pass # Close the writer if self._writer: self._writer.close() - try: + with suppress(OSError, asyncio.CancelledError): await self._writer.wait_closed() - except (ConnectionError, OSError, asyncio.CancelledError): - pass - finally: - self._writer = None + self._writer = None # Reset other parameters self._pending_command_event.clear() self._last_activity = datetime.now() @@ -135,12 +132,7 @@ async def _read_handler(self, reader: asyncio.StreamReader) -> None: while True: try: binary_result = await reader.readuntil(SEPARATOR_BYTES) - except ( - ConnectionError, - asyncio.IncompleteReadError, - RuntimeError, - OSError, - ) as error: + except (asyncio.IncompleteReadError, RuntimeError, OSError) as error: await self._disconnect_from_error(error) return else: @@ -156,7 +148,7 @@ async def _handle_message(self, message: HeosMessage) -> None: return if message.is_event: _LOGGER.debug("Event received: '%s': '%s'", message.command, message) - self._register_task(self._on_event(message)) + self._register_task(self._on_event(message), "Event Handler") return # Set the message on the pending command. @@ -180,7 +172,9 @@ async def _command_impl() -> HeosMessage: await self._writer.drain() except (ConnectionError, OSError, AttributeError) as error: # Occurs when the connection is broken. Run in the background to ensure connection is reset. - self._register_task(self._disconnect_from_error(error)) + self._register_task( + self._disconnect_from_error(error), "Disconnect From Error" + ) _LOGGER.debug( "Command failed '%s': %s: %s", command, type(error).__name__, error ) @@ -252,7 +246,7 @@ async def connect(self) -> None: ) from err # Start read handler - self._register_task(self._read_handler(reader)) + self._register_task(self._read_handler(reader), "Read Handler") self._last_activity = datetime.now() self._state = ConnectionState.CONNECTED _LOGGER.debug("Connected to %s", self._host) @@ -314,32 +308,28 @@ async def _attempt_reconnect(self) -> None: unlimited_attempts = self._reconnect_max_attempts == 0 delay = min(self._reconnect_delay, MAX_RECONNECT_DELAY) while (attempts < self._reconnect_max_attempts) or unlimited_attempts: + _LOGGER.debug("Waiting %s seconds before attempting to reconnect", delay) + await asyncio.sleep(delay) + _LOGGER.debug("Attempting reconnect #%s to %s", (attempts + 1), self._host) try: - _LOGGER.debug( - "Waiting %s seconds before attempting to reconnect", delay - ) - await asyncio.sleep(delay) - _LOGGER.debug( - "Attempting reconnect #%s to %s", (attempts + 1), self._host - ) await self.connect() except HeosError: attempts += 1 delay = min(delay * 2, MAX_RECONNECT_DELAY) else: - return # This never actually hits as the task is cancelled when the connection is established, but it's here for completeness. + return async def _on_connected(self) -> None: """Handle when the connection is established.""" # Start heart beat when enabled if self._heart_beat: - self._register_task(self._heart_beat_handler()) + self._register_task(self._heart_beat_handler(), "Heart Beat") await super()._on_connected() async def _on_disconnected(self, due_to_error: bool = False) -> None: """Handle when the connection is lost. Invoked after the connection has been reset.""" if due_to_error and self._reconnect: - self._register_task(self._attempt_reconnect()) + self._register_task(self._attempt_reconnect(), "Reconnect") await super()._on_disconnected(due_to_error) diff --git a/tests/test_heos.py b/tests/test_heos.py index 388f8dd..d8bf162 100644 --- a/tests/test_heos.py +++ b/tests/test_heos.py @@ -380,11 +380,17 @@ async def test_reconnect_during_event(mock_device: MockHeosDevice) -> None: # Assert reconnects once server is back up and fires connected # Force reconnect timeout - await asyncio.sleep(0.5) # type: ignore[unreachable] + reconnect_task = next( # type: ignore[unreachable] + task + for task in heos._connection._running_tasks + if task.get_name() == "Reconnect" + ) + await asyncio.sleep(0.5) await mock_device.start() await connect_signal.wait() assert heos.connection_state == ConnectionState.CONNECTED + await reconnect_task # Ensures task completes, otherwise disconnect cancels it await heos.disconnect()