diff --git a/.github/workflows/autodeps.yml b/.github/workflows/autodeps.yml index 712ceca..cc127ad 100644 --- a/.github/workflows/autodeps.yml +++ b/.github/workflows/autodeps.yml @@ -38,7 +38,7 @@ jobs: # apply newer versions' formatting - name: Black - run: black src/trio + run: black src/checkers - name: uv run: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8ee528..7315954 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.0 hooks: - id: ruff types: [file] diff --git a/computer_players/MiniMax_AI.py b/computer_players/MiniMax_AI.py index 6a20ca5..a7baaa4 100755 --- a/computer_players/MiniMax_AI.py +++ b/computer_players/MiniMax_AI.py @@ -25,8 +25,6 @@ T = TypeVar("T") -PORT = 31613 - # Player: # 0 = False = Person = MIN = 0, 2 # 1 = True = AI (Us) = MAX = 1, 3 diff --git a/computer_players/machine_client.py b/computer_players/machine_client.py index e9681da..72bdd37 100644 --- a/computer_players/machine_client.py +++ b/computer_players/machine_client.py @@ -8,6 +8,8 @@ import sys from abc import ABCMeta, abstractmethod +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING import trio @@ -20,6 +22,9 @@ ) from checkers.state import Action, Pos, State +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup @@ -140,12 +145,14 @@ def __init__(self, remote_state_class: type[RemoteState]) -> None: self.running = True - self.add_components( - ( - remote_state_class(), - GameClient("game_client"), - ), - ) + self.add_component(remote_state_class()) + + @asynccontextmanager + async def client_with_block(self) -> AsyncGenerator[GameClient, None]: + """Add client temporarily with `with` block, ensuring closure.""" + async with GameClient("game_client") as client: + with self.temporary_component(client): + yield client def bind_handlers(self) -> None: """Register client event handlers.""" @@ -182,14 +189,21 @@ async def run_client( ) client = MachineClient(remote_state_class) with event_manager.temporary_component(client): - await event_manager.raise_event( - Event("client_connect", (host, port)), - ) - print(f"Connected to server {host}:{port}") - while client.running: # noqa: ASYNC110 - # Wait so backlog things happen - await trio.sleep(1) - print(f"Disconnected from server {host}:{port}") + async with client.client_with_block(): + await event_manager.raise_event( + Event("client_connect", (host, port)), + ) + print(f"Connected to server {host}:{port}") + try: + while client.running: # noqa: ASYNC110 + # Wait so backlog things happen + await trio.sleep(1) + except KeyboardInterrupt: + print("Shutting down client from keyboard interrupt.") + await event_manager.raise_event( + Event("network_stop", None), + ) + print(f"Disconnected from server {host}:{port}") client.unbind_components() connected.remove((host, port)) @@ -225,13 +239,11 @@ async def run_clients_in_local_servers( ) await trio.sleep(1) except BaseExceptionGroup as exc: - caught = False for ex in exc.exceptions: if isinstance(ex, KeyboardInterrupt): print("Shutting down from keyboard interrupt.") - caught = True break - if not caught: + else: raise diff --git a/pyproject.toml b/pyproject.toml index 55d238a..0f3ff0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "pygame~=2.6.0", "typing_extensions>=4.12.2", "mypy_extensions>=1.0.0", - "trio~=0.26.2", + "trio~=0.27.0", "cryptography>=43.0.0", "exceptiongroup; python_version < '3.11'", ] @@ -97,7 +97,6 @@ disable_all_dunder_policy = true [tool.black] line-length = 79 -target-version = ['py312'] [tool.ruff] line-length = 79 @@ -117,12 +116,16 @@ select = [ "EXE", # flake8-executable "F", # pyflakes "FA", # flake8-future-annotations + "FLY", # flynt + "FURB", # refurb "I", # isort + "ICN", # flake8-import-conventions "N", # pep8-naming "PIE", # flake8-pie "PT", # flake8-pytest-style "PYI", # flake8-pyi "Q", # flake8-quotes + "R", # Refactor "RET", # flake8-return "RUF", # Ruff-specific rules "S", # flake8-bandit diff --git a/src/checkers/base_io.py b/src/checkers/base_io.py index 62bed1f..1f8e7b5 100644 --- a/src/checkers/base_io.py +++ b/src/checkers/base_io.py @@ -174,6 +174,8 @@ async def write_varint(self, value: int, /) -> None: """Write a 32-bit signed integer in a variable length format. For more information about variable length format check :meth:`._write_varuint`. + + Raises ValueError if value is outside of the range of a 32-bit signed integer. """ val = to_twos_complement(value, bits=32) await self._write_varuint(val, max_bits=32) @@ -182,12 +184,17 @@ async def write_varlong(self, value: int, /) -> None: """Write a 64-bit signed integer in a variable length format. For more information about variable length format check :meth:`._write_varuint`. + + Raises ValueError if value is outside of the range of a 64-bit signed integer. """ val = to_twos_complement(value, bits=64) await self._write_varuint(val, max_bits=64) async def write_bytearray(self, data: bytes, /) -> None: - """Write an arbitrary sequence of bytes, prefixed with a varint of it's size.""" + """Write an arbitrary sequence of bytes, prefixed with a varint of it's size. + + Raises ValueError if length is is outside of the range of a 32-bit signed integer. + """ await self.write_varint(len(data)) await self.write(data) @@ -321,6 +328,8 @@ def write_varint(self, value: int, /) -> None: """Write a 32-bit signed integer in a variable length format. For more information about variable length format check :meth:`._write_varuint`. + + Raises ValueError if length is is outside of the range of a 32-bit signed integer. """ val = to_twos_complement(value, bits=32) self._write_varuint(val, max_bits=32) @@ -329,12 +338,17 @@ def write_varlong(self, value: int, /) -> None: """Write a 64-bit signed integer in a variable length format. For more information about variable length format check :meth:`._write_varuint` docstring. + + Raises ValueError if length is is outside of the range of a 64-bit signed integer. """ val = to_twos_complement(value, bits=64) self._write_varuint(val, max_bits=64) def write_bytearray(self, data: bytes, /) -> None: - """Write an arbitrary sequence of bytes, prefixed with a varint of it's size.""" + """Write an arbitrary sequence of bytes, prefixed with a varint of it's size. + + Raises ValueError if length is is outside of the range of a 32-bit signed integer. + """ self.write_varint(len(data)) self.write(data) @@ -429,7 +443,7 @@ async def _read_varuint(self, *, max_bits: int | None = None) -> int: This is a standard way of transmitting ints, and it allows smaller numbers to take less bytes. Reading will be limited up to integer values of ``max_bits`` bits, and trying to read bigger values will rase - an :exc:`IOError`. Note that setting ``max_bits`` to for example 32 bits doesn't mean that at most 4 bytes + an :exc:`OSError`. Note that setting ``max_bits`` to for example 32 bits doesn't mean that at most 4 bytes will be read, in this case we would actually read at most 5 bytes, due to the variable encoding overhead. Varints send bytes where 7 least significant bits are value bits, and the most significant bit is continuation diff --git a/src/checkers/client.py b/src/checkers/client.py index c5712d1..0f8e2ad 100644 --- a/src/checkers/client.py +++ b/src/checkers/client.py @@ -25,11 +25,13 @@ __version__ = "0.0.0" import struct +import time import traceback from typing import TYPE_CHECKING import trio +from checkers import network from checkers.base_io import StructFormat from checkers.buffer import Buffer from checkers.component import Event @@ -39,10 +41,6 @@ encrypt_token_and_secret, generate_shared_secret, ) -from checkers.network import ( - NetworkStreamNotConnectedError, - NetworkTimeoutError, -) from checkers.network_shared import ( ADVERTISEMENT_IP, ADVERTISEMENT_PORT, @@ -108,7 +106,7 @@ async def read_advertisements( trio.socket.IP_ADD_MEMBERSHIP, mreq, ) - else: + else: # IPv6 mreq = group_bin + struct.pack("@I", 0) udp_socket.setsockopt( trio.socket.IPPROTO_IPV6, @@ -181,7 +179,7 @@ def __init__(self, name: str) -> None: cbe = ClientBoundEvents self.register_read_network_events( { - cbe.callback_ping: "callback_ping->client", + cbe.callback_ping: "server->callback_ping", cbe.create_piece: "server->create_piece", cbe.select_piece: "server->select_piece", cbe.create_tile: "server->create_tile", @@ -206,7 +204,7 @@ def bind_handlers(self) -> None: super().bind_handlers() self.register_handlers( { - # "callback_ping->client": self.print_callback_ping, + "server->callback_ping": self.read_callback_ping, "gameboard_piece_clicked": self.write_piece_click, "gameboard_tile_clicked": self.write_tile_click, "server->create_piece": self.read_create_piece, @@ -249,7 +247,22 @@ async def raise_disconnect(self, message: str) -> None: assert self.not_connected async def handle_read_event(self) -> None: - """Raise events from server.""" + """Raise events from server. + + Can raise following exceptions: + RuntimeError - Unhandled packet id + network.NetworkStreamNotConnectedError - Network stream is not connected + OSError - Stopped responding + trio.BrokenResourceError - Something is wrong and stream is broken + + Shouldn't happen with write lock but still: + trio.BusyResourceError - Another task is already writing data + + Handled exceptions: + trio.ClosedResourceError - Stream is closed or another task closes stream + network.NetworkTimeoutError - Timeout + network.NetworkEOFError - Server closed connection + """ ##print(f"{self.__class__.__name__}[{self.name}]: handle_read_event") if not self.manager_exists: return @@ -260,34 +273,43 @@ async def handle_read_event(self) -> None: # print("handle_read_event start") event = await self.read_event() except trio.ClosedResourceError: + self.running = False await self.close() - print("Client side socket closed from another task.") + print(f"[{self.name}] Socket closed from another task.") return - except NetworkTimeoutError as exc: + except network.NetworkTimeoutError as exc: if self.running: + self.running = False + print(f"[{self.name}] NetworkTimeoutError") await self.close() traceback.print_exception(exc) await self.raise_disconnect( "Failed to read event from server.", ) return - except NetworkStreamNotConnectedError as exc: + except network.NetworkStreamNotConnectedError as exc: + self.running = False + print(f"[{self.name}] NetworkStreamNotConnectedError") traceback.print_exception(exc) await self.close() assert self.not_connected raise - else: - await self.raise_event(event) - ## await self.raise_event( - ## Event(f"client[{self.name}]_read_event", None), - ## ) + except network.NetworkEOFError: + self.running = False + print(f"[{self.name}] NetworkEOFError") + await self.close() + await self.raise_disconnect( + "Server closed connection.", + ) + return + + await self.raise_event(event) async def handle_client_connect( self, event: Event[tuple[str, int]], ) -> None: """Have client connect to address specified in event.""" - print("handle_client_connect event fired") if self.connect_event_lock.locked(): raise RuntimeError("2nd client connect fired!") async with self.connect_event_lock: @@ -310,10 +332,25 @@ async def handle_client_connect( await self.raise_event( Event("client_connection_closed", None), ) - + else: + print( + "manager does not exist, cannot send client connection closed event.", + ) return await self.raise_disconnect("Error connecting to server.") + async def read_callback_ping(self, event: Event[bytearray]) -> None: + """Read callback_ping event from server.""" + ns = int.from_bytes(event.data) + now = int(time.time() * 1e9) + difference = now - ns + + # print(f'{difference / 1e9 = } seconds') + + await self.raise_event( + Event("callback_ping", difference), + ) + async def read_create_piece(self, event: Event[bytearray]) -> None: """Read create_piece event from server.""" buffer = Buffer(event.data) @@ -360,7 +397,7 @@ async def write_piece_click(self, event: Event[tuple[Pos, int]]) -> None: buffer = Buffer() write_position(buffer, piece_position) - buffer.write_value(StructFormat.UINT, piece_type) + # buffer.write_value(StructFormat.UINT, piece_type) await self.write_event(Event("select_piece->server", buffer)) @@ -510,6 +547,7 @@ async def handle_network_stop(self, event: Event[None]) -> None: """Send EOF if connected and close socket.""" if self.not_connected: return + self.running = False try: await self.send_eof() finally: diff --git a/src/checkers/component.py b/src/checkers/component.py index 14c5708..c4b7e6a 100644 --- a/src/checkers/component.py +++ b/src/checkers/component.py @@ -70,7 +70,7 @@ class Component: __slots__ = ("name", "__manager") - def __init__(self, name: str) -> None: + def __init__(self, name: object) -> None: """Initialise with name.""" self.name = name self.__manager: ref[ComponentManager] | None = None @@ -102,7 +102,10 @@ def register_handler( event_name: str, handler_coro: Callable[[Event[Any]], Awaitable[Any]], ) -> None: - """Register handler with bound component manager.""" + """Register handler with bound component manager. + + Raises AttributeError if this component is not bound. + """ self.manager.register_component_handler( event_name, handler_coro, @@ -113,7 +116,10 @@ def register_handlers( self, handlers: dict[str, Callable[[Event[Any]], Awaitable[Any]]], ) -> None: - """Register multiple handler Coroutines.""" + """Register multiple handler Coroutines. + + Raises AttributeError if this component is not bound. + """ for name, coro in handlers.items(): self.register_handler(name, coro) @@ -121,7 +127,10 @@ def bind_handlers(self) -> None: """Add handlers in subclass.""" def bind(self, manager: ComponentManager) -> None: - """Bind self to manager.""" + """Bind self to manager. + + Raises RuntimeError if component is already bound to a manager. + """ if self.manager_exists: raise RuntimeError( f"{self.name} component is already bound to {self.manager}", @@ -130,30 +139,48 @@ def bind(self, manager: ComponentManager) -> None: self.bind_handlers() def has_handler(self, event_name: str) -> bool: - """Return if manager has event handlers registered for a given event.""" + """Return if manager has event handlers registered for a given event. + + Raises AttributeError if this component is not bound. + """ return self.manager.has_handler(event_name) async def raise_event(self, event: Event[Any]) -> None: - """Raise event for bound manager.""" + """Raise event for bound manager. + + Raises AttributeError if this component is not bound. + """ await self.manager.raise_event(event) def component_exists(self, component_name: str) -> bool: - """Return if component exists in manager.""" + """Return if component exists in manager. + + Raises AttributeError if this component is not bound. + """ return self.manager.component_exists(component_name) def components_exist(self, component_names: Iterable[str]) -> bool: - """Return if all component names given exist in manager.""" + """Return if all component names given exist in manager. + + Raises AttributeError if this component is not bound. + """ return self.manager.components_exist(component_names) def get_component(self, component_name: str) -> Any: - """Get Component from manager.""" + """Get Component from manager. + + Raises AttributeError if this component is not bound. + """ return self.manager.get_component(component_name) def get_components( self, component_names: Iterable[str], ) -> list[Component]: - """Return Components from manager.""" + """Return Components from manager. + + Raises AttributeError if this component is not bound. + """ return self.manager.get_components(component_names) @@ -165,14 +192,14 @@ class ComponentManager(Component): __slots__ = ("__event_handlers", "__components", "__weakref__") - def __init__(self, name: str, own_name: str | None = None) -> None: + def __init__(self, name: object, own_name: object | None = None) -> None: """If own_name is set, add self to list of components as specified name.""" super().__init__(name) self.__event_handlers: dict[ str, - set[tuple[Callable[[Event[Any]], Awaitable[Any]], str]], + set[tuple[Callable[[Event[Any]], Awaitable[Any]], object]], ] = {} - self.__components: dict[str, Component] = {} + self.__components: dict[object, Component] = {} if own_name is not None: self.__add_self_as_component(own_name) @@ -182,8 +209,11 @@ def __repr__(self) -> str: """Return representation of self.""" return f"<{self.__class__.__name__} Components: {self.__components}>" - def __add_self_as_component(self, name: str) -> None: - """Add this manager as component to self without binding.""" + def __add_self_as_component(self, name: object) -> None: + """Add this manager as component to self without binding. + + Raises ValueError if a component with given name already exists. + """ if self.component_exists(name): # pragma: nocover raise ValueError(f'Component named "{name}" already exists!') self.__components[name] = self @@ -200,9 +230,12 @@ def register_component_handler( self, event_name: str, handler_coro: Callable[[Event[Any]], Awaitable[None]], - component_name: str, + component_name: object, ) -> None: - """Register handler_func as handler for event_name.""" + """Register handler_func as handler for event_name. + + Raises ValueError if no component with given name is registered. + """ if ( component_name != self.name and component_name not in self.__components @@ -223,7 +256,10 @@ async def raise_event_in_nursery( event: Event[Any], nursery: trio.Nursery, ) -> None: - """Raise event in a particular trio nursery.""" + """Raise event in a particular trio nursery. + + Could raise RuntimeError if given nursery is no longer open. + """ await trio.lowlevel.checkpoint() # Forward leveled events up; They'll come back to us soon enough. @@ -257,7 +293,11 @@ async def raise_event(self, event: Event[Any]) -> None: await self.raise_event_in_nursery(event, nursery) def add_component(self, component: Component) -> None: - """Add component to this manager.""" + """Add component to this manager. + + Raises ValueError if component already exists with component name. + `component` must be an instance of Component. + """ assert isinstance(component, Component), "Must be component instance" if self.component_exists(component.name): raise ValueError( @@ -267,12 +307,19 @@ def add_component(self, component: Component) -> None: component.bind(self) def add_components(self, components: Iterable[Component]) -> None: - """Add multiple components to this manager.""" + """Add multiple components to this manager. + + Raises ValueError if any component already exists with component name. + `component`s must be instances of Component. + """ for component in components: self.add_component(component) - def remove_component(self, component_name: str) -> None: - """Remove a component.""" + def remove_component(self, component_name: object) -> None: + """Remove a component. + + Raises ValueError if component name does not exist. + """ if not self.component_exists(component_name): raise ValueError(f"Component {component_name!r} does not exist!") # Remove component from registered components @@ -294,7 +341,7 @@ def remove_component(self, component_name: str) -> None: for name in empty: self.__event_handlers.pop(name) - def component_exists(self, component_name: str) -> bool: + def component_exists(self, component_name: object) -> bool: """Return if component exists in this manager.""" return component_name in self.__components @@ -312,26 +359,26 @@ def temporary_component( if self.component_exists(name): self.remove_component(name) - def components_exist(self, component_names: Iterable[str]) -> bool: + def components_exist(self, component_names: Iterable[object]) -> bool: """Return if all component names given exist in this manager.""" return all(self.component_exists(name) for name in component_names) - def get_component(self, component_name: str) -> Any: - """Return Component or raise ValueError.""" + def get_component(self, component_name: object) -> Any: + """Return Component or raise ValueError because it doesn't exist.""" if not self.component_exists(component_name): raise ValueError(f'"{component_name}" component does not exist') return self.__components[component_name] - def get_components(self, component_names: Iterable[str]) -> list[Any]: + def get_components(self, component_names: Iterable[object]) -> list[Any]: """Return iterable of components asked for or raise ValueError.""" return [self.get_component(name) for name in component_names] - def list_components(self) -> tuple[str, ...]: - """Return list of components bound to this manager.""" + def list_components(self) -> tuple[object, ...]: + """Return tuple of the names of components bound to this manager.""" return tuple(self.__components) def get_all_components(self) -> tuple[Component, ...]: - """Return all bound components.""" + """Return tuple of all components bound to this manager.""" return tuple(self.__components.values()) def unbind_components(self) -> None: @@ -358,16 +405,19 @@ class ExternalRaiseManager(ComponentManager): def __init__( self, - name: str, + name: object, nursery: trio.Nursery, - own_name: str | None = None, + own_name: object | None = None, ) -> None: """Initialize with name, own component name, and nursery.""" super().__init__(name, own_name) self.nursery = nursery async def raise_event(self, event: Event[Any]) -> None: - """Raise event in nursery.""" + """Raise event in nursery. + + Could raise RuntimeError if self.nursery is no longer open. + """ await self.raise_event_in_nursery(event, self.nursery) async def raise_event_internal(self, event: Event[Any]) -> None: diff --git a/src/checkers/element_list.py b/src/checkers/element_list.py new file mode 100644 index 0000000..3c1ff2d --- /dev/null +++ b/src/checkers/element_list.py @@ -0,0 +1,139 @@ +"""Element List - List of element sprites.""" + +# Programmed by CoolCat467 + +from __future__ import annotations + +# Element List - List of element sprites. +# Copyright (C) 2024 CoolCat467 +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +__title__ = "Element List" +__author__ = "CoolCat467" +__version__ = "0.0.0" +__license__ = "GNU General Public License Version 3" + + +from typing import TYPE_CHECKING + +from checkers import sprite +from checkers.vector import Vector2 + +if TYPE_CHECKING: + from collections.abc import Generator + + +class Element(sprite.Sprite): + """Element sprite.""" + + __slots__ = () + + def self_destruct(self) -> None: + """Remove this element.""" + self.kill() + if self.manager_exists: + self.manager.remove_component(self.name) + + def __del__(self) -> None: + """Clean up this element for garbage collecting.""" + self.self_destruct() + super().__del__() + + +class ElementList(sprite.Sprite): + """Element List sprite.""" + + __slots__ = ("_order",) + + def __init__(self, name: object) -> None: + """Initialize connection list.""" + super().__init__(name) + + self._order: list[object] = [] + + def add_element(self, element: Element) -> None: + """Add element to this list.""" + group = self.groups()[-1] + group.add(element) # type: ignore[arg-type] + self.add_component(element) + self._order.append(element.name) + + def delete_element(self, element_name: object) -> None: + """Delete an element (only from component).""" + element = self.get_component(element_name) + index = self._order.index(element_name) + if element.visible: + assert element.image is not None + height = element.image.get_height() + self.offset_elements_after(index, (0, -height)) + self._order.pop(index) + assert isinstance(element, Element) + element.self_destruct() + + def yield_elements(self) -> Generator[Element, None, None]: + """Yield bound Element components in order.""" + for component_name in self._order: + if not self.component_exists(component_name): + self._order.remove(component_name) + continue + component = self.get_component(component_name) + assert isinstance(component, Element) + yield component + + def get_last_rendered_element(self) -> Element | None: + """Return last bound Element sprite or None.""" + for component_name in reversed(self._order): + if not self.component_exists(component_name): + self._order.remove(component_name) + continue + component = self.get_component(component_name) + assert isinstance(component, Element) + if component.visible: + assert component.image is not None + return component + return None + + def get_new_connection_position(self) -> Vector2: + """Return location for new connection.""" + last_element = self.get_last_rendered_element() + if last_element is None: + return Vector2.from_iter(self.rect.topleft) + location = Vector2.from_iter(last_element.rect.topleft) + assert last_element.image is not None + location += (0, last_element.image.get_height()) + return location + + def offset_elements(self, diff: tuple[int, int]) -> None: + """Offset all element locations by given difference.""" + for element in self.yield_elements(): + element.location += diff + + def offset_elements_after(self, index: int, diff: tuple[int, int]) -> None: + """Offset elements after index by given difference.""" + for idx, element in enumerate(self.yield_elements()): + if idx <= index: + continue + element.location += diff + + def _set_location(self, value: tuple[int, int]) -> None: + """Set rect center from tuple of integers.""" + current = self.location + super()._set_location(value) + diff = Vector2.from_iter(value) - current + self.offset_elements(diff) + + +if __name__ == "__main__": + print(f"{__title__} v{__version__}\nProgrammed by {__author__}.\n") diff --git a/src/checkers/encrypted_event.py b/src/checkers/encrypted_event.py index 638d8d9..e689567 100644 --- a/src/checkers/encrypted_event.py +++ b/src/checkers/encrypted_event.py @@ -84,18 +84,45 @@ def enable_encryption( self.decryptor = self.cipher.decryptor() async def write(self, data: bytes) -> None: - """Write encrypted data to stream.""" + """Send the given data, encrypted through the stream, blocking if necessary. + + Args: + data (bytes, bytearray, or memoryview): The data to send. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_all` is running. + + Most low-level operations in Trio provide a guarantee: if they raise + :exc:`trio.Cancelled`, this means that they had no effect, so the + system remains in a known state. This is **not true** for + :meth:`send_all`. If this operation raises :exc:`trio.Cancelled` (or + any other exception for that matter), then it may have sent some, all, + or none of the requested data, and there is no way to know which. + + Copied from Trio docs. + + """ if self.encryption_enabled: data = self.encryptor.update(data) return await super().write(data) async def read(self, length: int) -> bytearray: - """Read `length` encrypted bytes from stream. + """Read `length` bytes from stream. Can raise following exceptions: NetworkStreamNotConnectedError NetworkTimeoutError - Timeout or no data - OSError - Stopped responding + OSError - Stopped responding + trio.BusyResourceError - Another task is already writing data + trio.BrokenResourceError - Something is wrong and stream is broken + trio.ClosedResourceError - Stream is closed or another task closes stream """ data = await super().read(length) if self.encryption_enabled: diff --git a/src/checkers/encryption.py b/src/checkers/encryption.py index e6c5d29..e2ad0d2 100644 --- a/src/checkers/encryption.py +++ b/src/checkers/encryption.py @@ -65,6 +65,17 @@ def generate_rsa_key() -> RSAPrivateKey: # pragma: no cover ) +def encrypt_with_rsa( + public_key: RSAPublicKey, + data: bytes, +) -> bytes: + """Encrypt given data with given RSA public key.""" + return public_key.encrypt( + bytes(data), + OAEP(MGF1(SHA256()), SHA256(), None), + ) + + def encrypt_token_and_secret( public_key: RSAPublicKey, verification_token: bytes, @@ -77,15 +88,20 @@ def encrypt_token_and_secret( :param shared_secret: The generated shared secret :return: A tuple containing (encrypted token, encrypted secret) """ - encrypted_token = public_key.encrypt( - bytes(verification_token), - OAEP(MGF1(SHA256()), SHA256(), None), - ) - encrypted_secret = public_key.encrypt( - bytes(shared_secret), + encrypted_token = encrypt_with_rsa(public_key, verification_token) + encrypted_secret = encrypt_with_rsa(public_key, shared_secret) + return encrypted_token, encrypted_secret + + +def decrypt_with_rsa( + private_key: RSAPrivateKey, + data: bytes, +) -> bytes: + """Decrypt given data with given RSA private key.""" + return private_key.decrypt( + bytes(data), OAEP(MGF1(SHA256()), SHA256(), None), ) - return encrypted_token, encrypted_secret def decrypt_token_and_secret( @@ -100,14 +116,8 @@ def decrypt_token_and_secret( :param shared_secret: The shared secret encrypted and sent by the client :return: A tuple containing (decrypted token, decrypted secret) """ - decrypted_token = private_key.decrypt( - bytes(verification_token), - OAEP(MGF1(SHA256()), SHA256(), None), - ) - decrypted_secret = private_key.decrypt( - bytes(shared_secret), - OAEP(MGF1(SHA256()), SHA256(), None), - ) + decrypted_token = decrypt_with_rsa(private_key, verification_token) + decrypted_secret = decrypt_with_rsa(private_key, shared_secret) return decrypted_token, decrypted_secret diff --git a/src/checkers/game.py b/src/checkers/game.py index 9d92076..5a66d45 100644 --- a/src/checkers/game.py +++ b/src/checkers/game.py @@ -43,7 +43,7 @@ from pygame.locals import K_ESCAPE, KEYUP, QUIT, WINDOWRESIZED from pygame.rect import Rect -from checkers import base2d, objects, sprite +from checkers import base2d, element_list, objects, sprite from checkers.async_clock import Clock from checkers.client import GameClient, read_advertisements from checkers.component import ( @@ -281,7 +281,7 @@ async def handle_update_event(self, event: Event[int]) -> None: class Tile(sprite.Sprite): """Outlined tile sprite - Only exists for selecting destination.""" - __slots__ = ("color", "board_position", "position_name") + __slots__ = ("color", "board_position") def __init__( self, @@ -295,7 +295,6 @@ def __init__( self.color = color self.board_position = position - self.position_name = position_name self.location = location self.update_location_on_resize = True @@ -313,7 +312,7 @@ def bind_handlers(self) -> None: self.register_handlers( { "click": self.handle_click_event, - f"self_destruct_tile_{self.position_name}": self.handle_self_destruct_event, + f"self_destruct_{self.name}": self.handle_self_destruct_event, }, ) @@ -487,10 +486,8 @@ async def handle_create_piece_event( event: Event[tuple[Pos, int]], ) -> None: """Handle create_piece event.""" - if not self.visible: - # If not visible, re-raise until board is set up right - await self.raise_event(event) - return + while not self.visible: + raise RuntimeError("handle_create_piece_event not visible yet.") piece_pos, piece_type = event.data self.add_piece(piece_type, piece_pos) @@ -737,6 +734,7 @@ def add_piece( group.add(piece) # type: ignore[arg-type] self.pieces[position] = piece_type + assert isinstance(piece.name, str) return piece.name def add_tile(self, position: Pos) -> str: @@ -751,6 +749,7 @@ def add_tile(self, position: Pos) -> str: self.add_component(tile) group.add(tile) # type: ignore[arg-type] + assert isinstance(tile.name, str) return tile.name def generate_board_image(self) -> Surface: @@ -860,10 +859,13 @@ async def mouse_down( assert isinstance(event.data["pos"], tuple) target.destination = Vector2.from_iter(event.data["pos"]) - async def move_towards_dest(self, event: Event[dict[str, float]]) -> None: + async def move_towards_dest( + self, + event: Event[sprite.TickEventData], + ) -> None: """Move closer to destination.""" target: sprite.TargetingComponent = self.get_component("targeting") - await target.move_destination_time(event.data["time_passed"]) + await target.move_destination_time(event.data.time_passed) class MrFloppy(sprite.Sprite): @@ -955,12 +957,13 @@ def __init__(self) -> None: ) super().__init__("fps", font) - self.location = Vector2.from_iter(self.image.get_size()) / 2 + (5, 5) + self.location = (20, 20) - async def on_tick(self, event: Event[dict[str, float]]) -> None: + async def on_tick(self, event: Event[sprite.TickEventData]) -> None: """Update text.""" - # self.text = f'FPS: {event.data["fps"]:.2f}' - self.text = f'FPS: {event.data["fps"]:.0f}' + # self.text = f'FPS: {event.data.fps:.2f}' + self.text = f"FPS: {event.data.fps:.0f}" + self.visible = True def bind_handlers(self) -> None: """Register tick event handler.""" @@ -1179,6 +1182,21 @@ async def entry_actions(self) -> None: ) self.group_add(internal_button) + quit_button = KwargButton( + "quit_button", + button_font, + visible=True, + color=Color(0, 0, 0), + text="Quit", + location=join_button.location + + Vector2( + 0, + join_button.rect.h + 10, + ), + handle_click=self.change_state("Halt"), + ) + self.group_add(quit_button) + await self.machine.raise_event(Event("init", None)) @@ -1236,44 +1254,72 @@ class PlayInternalHostingState(PlayHostingState): internal_server = True -class JoinButton(Button): - """Join Button.""" +class ReturnElement(element_list.Element, objects.Button): + """Connection list return to title element sprite.""" __slots__ = () - def __init__(self, id_: int, font: pygame.font.Font, motd: str) -> None: - """Initialize Join Button.""" - super().__init__(f"join_button_{id_}", font) - self.text = motd - self.location = [x // 2 for x in SCREEN_SIZE] + def __init__(self, name: str, font: pygame.font.Font) -> None: + """Initialize return element.""" + super().__init__(name, font) + + self.update_location_on_resize = False + self.border_width = 4 + self.outline = RED + self.text = "Return to Title" + self.visible = True + self.location = (SCREEN_SIZE[0] // 2, self.location.y + 10) + + async def handle_click( + self, + _: Event[sprite.PygameMouseButtonEventData], + ) -> None: + """Handle Click Event.""" + await self.raise_event( + Event("return_to_title", None, 2), + ) + + +class ConnectionElement(element_list.Element, objects.Button): + """Connection list element sprite.""" + + __slots__ = () + + def __init__( + self, + name: tuple[str, int], + font: pygame.font.Font, + motd: str, + ) -> None: + """Initialize connection element.""" + super().__init__(name, font) + + self.text = f"[{name[0]}:{name[1]}]\n{motd}" + self.visible = True async def handle_click( self, _: Event[sprite.PygameMouseButtonEventData], ) -> None: """Handle Click Event.""" - print(f"{self!r} handle_click") + details = self.name + await self.raise_event( + Event("join_server", details, 2), + ) class PlayJoiningState(GameState): """Start running client.""" - __slots__ = ( - "font", - "buttons", - "next_button", - ) + __slots__ = ("font",) def __init__(self) -> None: """Initialize Joining State.""" super().__init__("play_joining") - self.next_button = 0 - self.buttons: dict[tuple[str, int], int] = {} - self.font = pygame.font.Font( DATA_FOLDER / "VeraSerif.ttf", - 28, + 12, ) async def entry_actions(self) -> None: @@ -1283,14 +1329,28 @@ async def entry_actions(self) -> None: self.id = self.machine.new_group("join") client = GameClient("network") + # Add network to higher level manager self.machine.manager.add_component(client) - self.buttons.clear() - self.next_button = 0 + connections = element_list.ElementList("connection_list") + self.manager.add_component(connections) + group = self.machine.get_group(self.id) + assert group is not None + group.add(connections) - self.manager.register_handler( - "update_listing", - self.handle_update_listing, + return_font = pygame.font.Font( + DATA_FOLDER / "VeraSerif.ttf", + 30, + ) + return_button = ReturnElement("return_button", return_font) + connections.add_element(return_button) + + self.manager.register_handlers( + { + "update_listing": self.handle_update_listing, + "return_to_title": self.handle_return_to_title, + "join_server": self.handle_join_server, + }, ) await self.manager.raise_event(Event("update_listing", None)) @@ -1298,25 +1358,53 @@ async def entry_actions(self) -> None: async def handle_update_listing(self, _: Event[None]) -> None: """Update server listing.""" assert self.machine is not None - for advertisement in await read_advertisements(): - motd, details = advertisement - if details not in self.buttons: - self.buttons[details] = self.next_button + connections = self.manager.get_component("connection_list") + + old: list[tuple[str, int]] = [] + current: list[tuple[str, int]] = [] + + ## print(f'{self.machine.active_state = }') + ## print(f'{self.name = }') + while ( + self.machine.active_state is not None + and self.machine.active_state is self + ): + ## print("handle_update_listing click") + + for motd, details in await read_advertisements(): + current.append(details) + if connections.component_exists(details): + continue + element = ConnectionElement(details, self.font, motd) + element.rect.topleft = ( + connections.get_new_connection_position() + ) + element.rect.topleft = (10, element.location.y + 3) + connections.add_element(element) + for details in old: + if details in current: + continue + connections.delete_element(details) + old, current = current, [] + + async def handle_join_server(self, event: Event[tuple[str, int]]) -> None: + """Handle join server event.""" + details = event.data + await self.machine.raise_event( + Event("client_connect", details), + ) + await self.machine.set_state("play") - print(f"handle_update_listing {motd = } {details = }") - ##button = JoinButton(self.next_button, self.font, motd, details) - ##self.group_add(button) + async def handle_return_to_title(self, _: Event[None]) -> None: + """Handle return to title event.""" + # Fire server stop event so server shuts down if it exists + await self.machine.raise_event_internal(Event("network_stop", None)) - self.next_button += 1 - #### - await self.machine.raise_event( - Event("client_connect", details), - ) - await self.machine.set_state("play") - return - print("handle_update_listing click") - await self.manager.raise_event(Event("update_listing", None)) + if self.machine.manager.component_exists("network"): + self.machine.manager.remove_component("network") + + await self.machine.set_state("title") ## async def check_conditions(self) -> str | None: @@ -1375,7 +1463,8 @@ async def check_conditions(self) -> str | None: async def exit_actions(self) -> None: """Raise network stop event and remove components.""" # Fire server stop event so server shuts down if it exists - await self.machine.raise_event(Event("network_stop", None)) + # await self.machine.raise_event(Event("network_stop", None)) + await self.machine.raise_event_internal(Event("network_stop", None)) if self.machine.manager.component_exists("network"): self.machine.manager.remove_component("network") @@ -1394,7 +1483,7 @@ async def handle_game_over(self, event: Event[int]) -> None: winner = event.data self.exit_data = (0, f"{PLAYERS[winner]} Won", False) - await self.machine.raise_event(Event("network_stop", None)) + await self.machine.raise_event_internal(Event("network_stop", None)) async def handle_client_disconnected(self, event: Event[str]) -> None: """Handle client disconnected error.""" @@ -1437,6 +1526,9 @@ async def do_actions(self) -> None: handle_click=self.change_state("title"), ) self.group_add(continue_button) + group = continue_button.groups()[0] + # LayeredDirty, not just AbstractGroup + group.move_to_front(continue_button) # type: ignore[attr-defined] else: continue_button = self.manager.get_component("continue_button") @@ -1463,7 +1555,7 @@ class CheckersClient(sprite.GroupProcessor): __slots__ = ("manager",) - def __init__(self, manager: ComponentManager) -> None: + def __init__(self, manager: ExternalRaiseManager) -> None: """Initialize Checkers Client.""" super().__init__() self.manager = manager @@ -1484,6 +1576,10 @@ async def raise_event(self, event: Event[Any]) -> None: """Raise component event in all groups.""" await self.manager.raise_event(event) + async def raise_event_internal(self, event: Event[Any]) -> None: + """Raise component event in all groups.""" + await self.manager.raise_event_internal(event) + async def async_run() -> None: """Handle main event loop.""" diff --git a/src/checkers/network.py b/src/checkers/network.py index 172009d..72b4280 100644 --- a/src/checkers/network.py +++ b/src/checkers/network.py @@ -29,7 +29,6 @@ from typing import ( TYPE_CHECKING, Any, - AnyStr, Literal, NoReturn, ) @@ -60,6 +59,12 @@ class NetworkTimeoutError(Exception): __slots__ = () +class NetworkEOFError(Exception): + """Network End of File Error.""" + + __slots__ = () + + class NetworkStreamNotConnectedError(Exception): """Network Stream Not Connected Error.""" @@ -105,7 +110,13 @@ def from_stream( return self async def connect(self, host: str, port: int) -> None: - """Connect to host:port on TCP.""" + """Connect to host:port on TCP. + + Raises: + OSError: if the connection fails. + RuntimeError: if stream is already connected + + """ if not self.not_connected: raise RuntimeError("Already connected!") try: # pragma: nocover @@ -118,8 +129,9 @@ async def read(self, length: int) -> bytearray: """Read `length` bytes from stream. Can raise following exceptions: - NetworkStreamNotConnectedError - NetworkTimeoutError - Timeout or no data + NetworkStreamNotConnectedError - Network stream is not connected + NetworkTimeoutError - Timeout + NetworkEOFError - End of File OSError - Stopped responding trio.BusyResourceError - Another task is already writing data trio.BrokenResourceError - Something is wrong and stream is broken @@ -129,17 +141,19 @@ async def read(self, length: int) -> bytearray: while max_read_count := length - len(content): received = b"" ## try: - with trio.move_on_after(self.timeout): + with trio.move_on_after(self.timeout) as cancel_scope: received = await self.stream.receive_some(max_read_count) + cancel_called = cancel_scope.cancel_called ## except (trio.BrokenResourceError, trio.ClosedResourceError): ## await self.close() ## raise if len(received) == 0: # No information at all if len(content) == 0: - raise NetworkTimeoutError( - "Server did not respond with any information. " - "This may be from a connection timeout.", + if cancel_called: + raise NetworkTimeoutError("Read timed out.") + raise NetworkEOFError( + "Server did not respond with any information.", ) # Only sent a few bytes, but we requested more raise OSError( @@ -150,8 +164,32 @@ async def read(self, length: int) -> bytearray: content.extend(received) return content - async def write(self, data: bytes) -> None: - """Write data to stream.""" + async def write(self, data: bytes | bytearray | memoryview) -> None: + """Send the given data through the stream, blocking if necessary. + + Args: + data (bytes, bytearray, or memoryview): The data to send. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_all` is running. + + Most low-level operations in Trio provide a guarantee: if they raise + :exc:`trio.Cancelled`, this means that they had no effect, so the + system remains in a known state. This is **not true** for + :meth:`send_all`. If this operation raises :exc:`trio.Cancelled` (or + any other exception for that matter), then it may have sent some, all, + or none of the requested data, and there is no way to know which. + + Copied from Trio docs. + + """ await self.stream.send_all(data) ## try: @@ -161,7 +199,7 @@ async def write(self, data: bytes) -> None: ## raise async def close(self) -> None: - """Close the stream.""" + """Close the stream, possibly blocking.""" if self._stream is None: await trio.lowlevel.checkpoint() return @@ -169,12 +207,88 @@ async def close(self) -> None: self._stream = None async def send_eof(self) -> None: - """Close the sending half of the stream.""" + """Close the sending half of the stream. + + This corresponds to ``shutdown(..., SHUT_WR)`` (`man + page `__). + + If an EOF has already been sent, then this method should silently + succeed. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`~SendStream.send_all`, + :meth:`~SendStream.wait_send_all_might_not_block`, or + :meth:`send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + + Suppresses: + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_eof` is running. + + Copied from trio docs. + + """ with contextlib.suppress(trio.ClosedResourceError): await self.stream.send_eof() async def wait_write_might_not_block(self) -> None: - """stream.wait_send_all_might_not_block.""" + """Block until it's possible that :meth:`write` might not block. + + This method may return early: it's possible that after it returns, + :meth:`send_all` will still block. (In the worst case, if no better + implementation is available, then it might always return immediately + without blocking. It's nice to do better than that when possible, + though.) + + This method **must not** return *late*: if it's possible for + :meth:`send_all` to complete without blocking, then it must + return. When implementing it, err on the side of returning early. + + Raises: + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`wait_send_all_might_not_block` is running. + + Note: + This method is intended to aid in implementing protocols that want + to delay choosing which data to send until the last moment. E.g., + suppose you're working on an implementation of a remote display server + like `VNC + `__, and + the network connection is currently backed up so that if you call + :meth:`send_all` now then it will sit for 0.5 seconds before actually + sending anything. In this case it doesn't make sense to take a + screenshot, then wait 0.5 seconds, and then send it, because the + screen will keep changing while you wait; it's better to wait 0.5 + seconds, then take the screenshot, and then send it, because this + way the data you deliver will be more + up-to-date. Using :meth:`wait_send_all_might_not_block` makes it + possible to implement the better strategy. + + If you use this method, you might also want to read up on + ``TCP_NOTSENT_LOWAT``. + + Further reading: + + * `Prioritization Only Works When There's Pending Data to Prioritize + `__ + + * WWDC 2015: Your App and Next Generation Networks: `slides + `__, + `video and transcript + `__ + + Copied from Trio docs. + + """ return await self.stream.wait_send_all_might_not_block() async def __aenter__(self) -> Self: @@ -187,7 +301,7 @@ async def __aexit__( exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: - """Async context manager exit.""" + """Async context manager exit. Close connection.""" await self.close() @@ -233,7 +347,12 @@ def register_network_write_event( event_name: str, packet_id: int, ) -> None: - """Map event name to serverbound packet id.""" + """Map event name to serverbound packet id. + + Raises: + ValueError: Event name already registered or infinite network loop. + + """ if event_name in self._write_event_name_to_packet_id: raise ValueError(f"{event_name!r} event already registered!") if self._read_packet_id_to_event_name.get(packet_id) == event_name: @@ -252,7 +371,20 @@ def register_network_write_events(self, event_map: dict[str, int]) -> None: self.register_network_write_event(event_name, packet_id) async def write_event(self, event: Event[bytearray]) -> None: - """Send event to network.""" + """Send event to network. + + Raises: + RuntimeError: if unregistered packet id received from network + trio.BusyResourceError: if another task is already executing a + :meth:`send_all`, :meth:`wait_send_all_might_not_block`, or + :meth:`HalfCloseableStream.send_eof` on this stream. + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if you previously closed this stream + object, or if another task closes this stream object while + :meth:`send_all` is running. + + """ packet_id = self._write_event_name_to_packet_id.get(event.name) if packet_id is None: raise RuntimeError(f"Unhandled network event name {event.name!r}") @@ -263,7 +395,19 @@ async def write_event(self, event: Event[bytearray]) -> None: await self.write(buffer) async def read_event(self) -> Event[bytearray]: - """Receive event from network.""" + """Receive event from network. + + Can raise following exceptions: + RuntimeError - Unhandled packet id + NetworkStreamNotConnectedError - Network stream is not connected + NetworkTimeoutError - Timeout or no data + OSError - Stopped responding + trio.BrokenResourceError - Something is wrong and stream is broken + trio.ClosedResourceError - Stream is closed or another task closes stream + + Shouldn't happen with write lock but still: + trio.BusyResourceError - Another task is already writing data + """ async with self.read_lock: packet_id = await self.read_value(self.packet_id_format) event_data = await self.read_bytearray() @@ -297,12 +441,12 @@ def register_read_network_events(self, packet_map: dict[int, str]) -> None: class Server(ComponentManager): """Asynchronous TCP Server.""" - __slots__ = ("cancel_scope",) + __slots__ = ("serve_cancel_scope",) def __init__(self, name: str, own_name: str | None = None) -> None: """Initialize Server.""" super().__init__(name, own_name) - self.cancel_scope: trio.CancelScope | None = None + self.serve_cancel_scope: trio.CancelScope | None = None def stop_serving(self) -> None: """Cancel serve scope immediately. @@ -310,19 +454,19 @@ def stop_serving(self) -> None: This method is idempotent, i.e., if the scope was already cancelled then this method silently does nothing. """ - if self.cancel_scope is None: + if self.serve_cancel_scope is None: return - self.cancel_scope.cancel() + self.serve_cancel_scope.cancel() # "Implicit return in function which does not return" async def serve( # type: ignore[misc] # pragma: nocover self, port: int, - host: AnyStr | None = None, + host: str | bytes | None = None, backlog: int | None = None, ) -> NoReturn: """Serve over TCP. See trio.open_tcp_listeners for argument details.""" - self.cancel_scope = trio.CancelScope() + self.serve_cancel_scope = trio.CancelScope() async with trio.open_nursery() as nursery: listeners = await trio.open_tcp_listeners( port, @@ -333,9 +477,9 @@ async def serve( # type: ignore[misc] # pragma: nocover async def handle_serve( task_status: trio.TaskStatus[Any] = trio.TASK_STATUS_IGNORED, ) -> None: - assert self.cancel_scope is not None + assert self.serve_cancel_scope is not None try: - with self.cancel_scope: + with self.serve_cancel_scope: await trio.serve_listeners( self.handler, listeners, diff --git a/src/checkers/objects.py b/src/checkers/objects.py index 4a30216..886de80 100644 --- a/src/checkers/objects.py +++ b/src/checkers/objects.py @@ -53,7 +53,7 @@ class Text(sprite.Sprite): __slots__ = ("__text", "font") - def __init__(self, name: str, font: Font) -> None: + def __init__(self, name: object, font: Font) -> None: """Initialize with font. Defaults to white text.""" super().__init__(name) @@ -135,7 +135,7 @@ class OutlinedText(Text): border_width = 3 border_radius = 8 - def __init__(self, name: str, font: Font) -> None: + def __init__(self, name: object, font: Font) -> None: """Initialize with name and font.""" super().__init__(name, font) @@ -196,7 +196,7 @@ class Button(OutlinedText): __slots__ = () - def __init__(self, name: str, font: Font) -> None: + def __init__(self, name: object, font: Font) -> None: """Initialize with name and font.""" super().__init__(name, font) diff --git a/src/checkers/server.py b/src/checkers/server.py index ec07838..42ae222 100755 --- a/src/checkers/server.py +++ b/src/checkers/server.py @@ -27,14 +27,15 @@ __license__ = "GNU General Public License Version 3" __version__ = "0.0.0" +import time import traceback from collections import deque from functools import partial -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, NoReturn, cast import trio -from checkers.async_clock import Clock +from checkers import network from checkers.base_io import StructFormat from checkers.buffer import Buffer from checkers.component import ComponentManager, Event, ExternalRaiseManager @@ -46,7 +47,6 @@ generate_verify_token, serialize_public_key, ) -from checkers.network import NetworkEventComponent, NetworkTimeoutError, Server from checkers.network_shared import ( ADVERTISEMENT_IP, ADVERTISEMENT_PORT, @@ -54,12 +54,11 @@ ClientBoundEvents, Pos, ServerBoundEvents, - TickEventData, find_ip, read_position, write_position, ) -from checkers.state import ActionSet, State, generate_pieces +from checkers.state import State, generate_pieces if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Iterable @@ -133,6 +132,7 @@ def bind_handlers(self) -> None: "action_complete->network": self.handle_action_complete, "initial_config->network": self.handle_initial_config, f"playing_as->network[{self.client_id}]": self.handle_playing_as, + f"callback_ping->network[{self.client_id}]": self.handle_callback_ping, }, ) @@ -305,6 +305,34 @@ async def handle_playing_as( buffer.write_value(StructFormat.UBYTE, playing_as) await self.write_event(Event("server[write]->playing_as", buffer)) + async def write_callback_ping(self) -> None: + """Write callback_ping packet to client. + + Could raise the following exceptions: + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if stream was previously closed + + Listed as possible but probably not because of write lock: + trio.BusyResourceError: if another task is using :meth:`write` + """ + buffer = Buffer() + + # Try to be as accurate with time as possible + await self.wait_write_might_not_block() + ns = int(time.time() * 1e9) + # Use as many bits as time needs, write_buffer handles size for us. + buffer.write(ns.to_bytes(-(-ns.bit_length() // 8), "big")) + + await self.write_event(Event("server[write]->callback_ping", buffer)) + + async def handle_callback_ping( + self, + _: Event[None], + ) -> None: + """Reraise as server[write]->callback_ping.""" + await self.write_callback_ping() + async def start_encryption_request(self) -> None: """Start encryption request and raise as server[write]->encryption_request.""" if self.encryption_enabled: @@ -372,11 +400,9 @@ def __init__( size: Pos, pieces: dict[Pos, int], turn: bool = True, - /, - pre_calculated_actions: dict[Pos, ActionSet] | None = None, ) -> None: """Initialize Checkers State.""" - super().__init__(size, pieces, turn, pre_calculated_actions) + super().__init__(size, pieces, turn) self.action_queue: deque[tuple[str, Iterable[Pos | int]]] = deque() def piece_kinged(self, piece_pos: Pos, new_type: int) -> None: @@ -407,7 +433,7 @@ def get_action_queue(self) -> deque[tuple[str, Iterable[Pos | int]]]: return self.action_queue -class GameServer(Server): +class GameServer(network.Server): """Checkers server. Handles accepting incoming connections from clients and handles @@ -463,7 +489,7 @@ async def stop_server(self, event: Event[None] | None = None) -> None: close_methods: deque[Callable[[], Awaitable[object]]] = deque() for component in self.get_all_components(): - if isinstance(component, NetworkEventComponent): + if isinstance(component, network.NetworkEventComponent): close_methods.append(component.close) print(f"stop_server {component.name = }") self.remove_component(component.name) @@ -580,10 +606,11 @@ def new_game_init(self) -> None: self.players_can_interact = True - async def start_server( + # "Implicit return in function which does not return" + async def start_server( # type: ignore[misc] self, event: Event[tuple[str | None, int]], - ) -> None: + ) -> NoReturn: """Serve clients.""" print(f"{self.__class__.__name__}: Closing old server clients") await self.stop_server() @@ -597,6 +624,7 @@ async def start_server( # Do not post advertisements when using internal singleplayer mode if not self.internal_singleplayer_mode: nursery.start_soon(self.post_advertisements, port) + # Serve runs forever until canceled nursery.start_soon(partial(self.serve, port, host, backlog=0)) async def transmit_playing_as(self) -> None: @@ -641,20 +669,39 @@ async def handle_server_start_new_game(self, event: Event[None]) -> None: ) async def client_network_loop(self, client: ServerClient) -> None: - """Network loop for given ServerClient.""" + """Network loop for given ServerClient. + + Could raise the following exceptions: + trio.BrokenResourceError: if something has gone wrong, and the stream + is broken. + trio.ClosedResourceError: if stream was previously closed + + Probably couldn't raise because of write lock but still: + trio.BusyResourceError: More than one task is trying to write + to socket at once. + """ while not self.can_start() and not client.not_connected: - await client.write_event( - Event("server[write]->callback_ping", bytearray()), - ) + try: + await client.write_callback_ping() + except ( + trio.BrokenResourceError, + trio.ClosedResourceError, + network.NetworkStreamNotConnectedError, + ): + print(f"{client.name} Disconnected in lobby.") + return while not client.not_connected: - print(f"{client.name} client_network_loop tick") + event: Event[bytearray] | None = None try: - await client.write_event( - Event("server[write]->callback_ping", bytearray()), - ) - event = await client.read_event() - except NetworkTimeoutError: - continue + await client.write_callback_ping() + with trio.move_on_after(2): + event = await client.read_event() + except network.NetworkTimeoutError: + print(f"{client.name} Timeout") + break + except network.NetworkEOFError: + print(f"{client.name} EOF") + break except ( trio.BrokenResourceError, trio.ClosedResourceError, @@ -664,7 +711,9 @@ async def client_network_loop(self, client: ServerClient) -> None: except Exception as exc: traceback.print_exception(exc) break - else: + if event is not None: + # print(f"{client.name} client_network_loop tick") + # print(f"{client.name} {event = }") await client.raise_event(event) def can_start(self) -> bool: @@ -713,7 +762,11 @@ async def send_spectator_join_packets( ) async def handler(self, stream: trio.SocketStream) -> None: - """Accept clients.""" + """Accept clients. Called by network.Server.serve.""" + if self.client_count == 0 and self.game_active(): + # Old game was running but everyone left, restart + self.state.pieces.clear() + # self.state = CheckersState(self.board_size, {}) new_client_id = self.client_count print( f"{self.__class__.__name__}: client connected [client_id {new_client_id}]", @@ -755,6 +808,7 @@ async def handler(self, stream: trio.SocketStream) -> None: f"{self.__class__.__name__}: client disconnected [client_id {new_client_id}]", ) self.client_count -= 1 + # ServerClient's `with` block handles closing stream. async def handle_network_select_piece( self, @@ -799,14 +853,14 @@ async def player_select_piece( if piece_pos is not None: # Calculate actions if required - new_action_set = self.state.get_actions_set(piece_pos) + new_action_set = self.state.calculate_actions(piece_pos) ignore = new_action_set.ends ignored: set[Pos] = set() # Remove outlined tiles from previous selection if existed if prev_selection := self.player_selections.get(player): - action_set = self.state.get_actions_set(prev_selection) + action_set = self.state.calculate_actions(prev_selection) ignored = action_set.ends & ignore remove = action_set.ends - ignore async with trio.open_nursery() as nursery: @@ -921,7 +975,7 @@ async def handle_network_select_tile( ) return - if tile_pos not in self.state.get_actions_set(piece_pos).ends: + if tile_pos not in self.state.calculate_actions(piece_pos).ends: print( f"{player = } cannot select tile {piece_pos!r} because not valid move", ) @@ -980,7 +1034,6 @@ async def run_server( event_manager = ExternalRaiseManager( "checkers", main_nursery, - "client", ) server = server_class() event_manager.add_component(server) @@ -990,45 +1043,23 @@ async def run_server( print("Server starting...") await trio.sleep(1) - print("Server running") - - clock = Clock() + print("\nServer running.") try: - while server.running: - await clock.tick() - await event_manager.raise_event( - Event( - "tick", - TickEventData( - time_passed=clock.get_time() - / 1e9, # nanoseconds -> seconds - fps=clock.get_fps(), - ), - ), - ) + while server.running: # noqa: ASYNC110 # sleep in while loop + # Process background tasks in the main nursery await trio.sleep(0.01) - finally: - server.unbind_components() - - -def run_server_sync( - server_class: type[GameServer], - host: str, - port: int, -) -> None: - """Run server given server class and address to host at.""" - trio.run(run_server, server_class, host, port) + except KeyboardInterrupt: + print("\nClosing from keyboard interrupt.") + await server.stop_server() + server.unbind_components() async def cli_run_async() -> None: """Run game server.""" host = await find_ip() port = DEFAULT_PORT - try: - await run_server(GameServer, host, port) - except KeyboardInterrupt: - print("Closing from keyboard interrupt") + await run_server(GameServer, host, port) def cli_run() -> None: diff --git a/src/checkers/sprite.py b/src/checkers/sprite.py index 5ce8c05..622590d 100644 --- a/src/checkers/sprite.py +++ b/src/checkers/sprite.py @@ -75,7 +75,7 @@ class Sprite(ComponentManager, WeakDirtySprite): __slots__ = ("rect", "__image", "mask", "update_location_on_resize") - def __init__(self, name: str) -> None: + def __init__(self, name: object) -> None: """Initialize with name.""" ComponentManager.__init__(self, name, "sprite") WeakDirtySprite.__init__(self) @@ -95,10 +95,14 @@ def __get_location(self) -> Vector2: """Return rect center as new Vector2.""" return Vector2.from_iter(self.rect.center) - def __set_location(self, value: tuple[int, int]) -> None: + def _set_location(self, value: tuple[int, int]) -> None: """Set rect center from tuple of integers.""" self.rect.center = value + def __set_location(self, value: tuple[int, int]) -> None: + """Set rect center from tuple of integers.""" + self._set_location(value) + location = property( __get_location, __set_location, diff --git a/src/checkers/state.py b/src/checkers/state.py index b75b851..09da7c9 100644 --- a/src/checkers/state.py +++ b/src/checkers/state.py @@ -26,9 +26,9 @@ import copy import math +from dataclasses import dataclass from typing import ( TYPE_CHECKING, - Any, NamedTuple, TypeAlias, TypeVar, @@ -106,31 +106,13 @@ def pawn_modify(moves: tuple[T, ...], piece_type: u8) -> tuple[T, ...]: return moves +@dataclass(slots=True) class State: """Represents state of checkers game.""" - __slots__ = ("size", "turn", "pieces", "pre_calculated_actions") - - def __init__( - self, - size: tuple[int, int], - pieces: dict[Pos, int], - turn: bool = True, # Black moves first - /, - pre_calculated_actions: dict[Pos, ActionSet] | None = None, - ) -> None: - """Initialize state.""" - self.size = size - self.turn = turn - self.pieces = pieces - - if pre_calculated_actions is None: - pre_calculated_actions = {} - self.pre_calculated_actions = pre_calculated_actions - - def __repr__(self) -> str: - """Return representation of self.""" - return f"{self.__class__.__name__}({self.size}, {self.turn}, {self.pieces})" + size: tuple[int, int] + pieces: dict[Pos, int] + turn: bool = True # Black moves first def __str__(self) -> str: """Return text representation of game board state.""" @@ -150,28 +132,6 @@ def __str__(self) -> str: ## lines.append("--+-"*(w-1)+"-") return "\n".join(lines) - @classmethod - def from_game_board(cls, board_data: dict[str, Any]) -> Self: - """Return new instance from board data.""" - size = board_data.get("boardsize", (8, 8)) - turn = True - pieces = cls.get_pieces_from_tiles(board_data.get("tiles", {})) - return cls(size, pieces, turn) - - @staticmethod - def get_pieces_from_tiles( - tiles: dict[str, dict[str, Any]], - ) -> dict[Pos, int]: - """Convert board data from game to internal representation.""" - pieces: dict[Pos, int] = {} - for _tile_name, tile_data in tiles.items(): - piece_type = tile_data["piece"] - if piece_type in {None, "None"}: - continue - x, y = tile_data["xy"] - pieces[(x, y)] = int(piece_type) - return pieces - def calculate_actions(self, position: Pos) -> ActionSet: """Return actions the piece at given position can make.""" if MANDATORY_CAPTURE: @@ -189,39 +149,16 @@ def calculate_actions(self, position: Pos) -> ActionSet: ends.update(moves) return ActionSet(jumps, moves, ends) - def get_actions_set(self, piece_position: Pos) -> ActionSet: - """Return potentially cached actions.""" - if piece_position in self.pre_calculated_actions: - new_action_set = self.pre_calculated_actions[piece_position] - else: - new_action_set = self.calculate_actions(piece_position) - self.pre_calculated_actions[piece_position] = new_action_set - return new_action_set - - def invalidate_location(self, position: Pos) -> None: - """Delete pre-calculated actions for a given position if calculated.""" - if position in self.pre_calculated_actions: - del self.pre_calculated_actions[position] - - def invalidate_all_locations(self) -> None: - """Clear all pre-calculated actions.""" - self.pre_calculated_actions.clear() - - ## print(position) - def piece_kinged(self, piece_pos: Pos, new_type: int) -> None: """Piece kinged.""" - ## print(f'piece_kinged {piece = }') - self.invalidate_location(piece_pos) + # print(f'piece_kinged {piece = }') def piece_moved(self, start_pos: Pos, end_pos: Pos) -> None: """Piece moved from start_pos to end_pos.""" - self.invalidate_location(start_pos) def piece_jumped(self, jumped_piece_pos: Pos) -> None: """Piece has been jumped.""" - ## print(f'piece_jumped {position = }') - self.invalidate_all_locations() + # print(f'piece_jumped {position = }') def preform_action(self, action: Action) -> Self: """Return new state after performing action on self.""" @@ -266,32 +203,22 @@ def preform_action(self, action: Action) -> Self: # Move piece to it's end position pieces_copy[to_pos] = piece_type - self.invalidate_location(from_pos) - - self.invalidate_all_locations() # Swap turn return self.__class__( self.size, pieces_copy, not self.turn, - pre_calculated_actions=self.pre_calculated_actions, ) def get_tile_name(self, x: int, y: int) -> str: """Return name of a given tile.""" return chr(65 + x) + str(self.size[1] - y) - def get_tile_pos(self, name: str) -> Pos: - """Return tile position from it's name.""" - x = ord(name[0]) - 65 - y = self.size[1] - int(name[1:]) - return (x, y) - @staticmethod def action_from_points(start: Pos, end: Pos) -> Action: """Return action from given start and end coordinates.""" - ## return Action(self.get_tile_name(*start), self.get_tile_name(*end)) + # return Action(self.get_tile_name(*start), self.get_tile_name(*end)) return Action(start, end) def get_turn(self) -> int: @@ -441,21 +368,21 @@ def get_moves(self, position: Pos) -> tuple[Pos, ...]: ], ) + @classmethod def wrap_actions( - self, + cls, position: Pos, calculate_ends: Callable[[Pos], Iterable[Pos]], ) -> Generator[Action, None, None]: """Yield end calculation function results as Actions.""" for end in calculate_ends(position): - yield self.action_from_points(position, end) + yield cls.action_from_points(position, end) def get_actions(self, position: Pos) -> Generator[Action, None, None]: """Yield all moves and jumps the piece at position can make.""" ends = set(self.get_jumps(position)) if not (ends and MANDATORY_CAPTURE): ends.update(self.get_moves(position)) - ## ends = self.get_actions_set(position).ends for end in ends: yield self.action_from_points(position, end) diff --git a/test-requirements.in b/test-requirements.in index 878673a..4d33f34 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -19,6 +19,6 @@ typing-extensions pygame~=2.6.0 typing_extensions>=4.12.2 mypy_extensions>=1.0.0 -trio~=0.26.2 +trio~=0.27.0 cryptography>=43.0.0 exceptiongroup; python_version < '3.11' diff --git a/test-requirements.txt b/test-requirements.txt index 5f1d280..e452c72 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -18,7 +18,7 @@ colorama==0.4.6 ; (implementation_name != 'cpython' and sys_platform == 'win32') # via # click # pytest -coverage==7.6.2 +coverage==7.6.3 # via # -r test-requirements.in # pytest-cov @@ -33,7 +33,7 @@ idna==3.10 # via trio iniconfig==2.0.0 # via pytest -mypy==1.11.2 +mypy==1.12.0 # via -r test-requirements.in mypy-extensions==1.0.0 # via @@ -79,7 +79,7 @@ tomli==2.0.2 ; python_full_version <= '3.11' # coverage # mypy # pytest -trio==0.26.2 +trio==0.27.0 # via # -r test-requirements.in # pytest-trio diff --git a/tests/test_network.py b/tests/test_network.py index 9c85c8c..4a12208 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -174,5 +174,5 @@ async def test_event_transmission() -> None: def test_server() -> None: server = Server("name") server.stop_serving() - server.cancel_scope = trio.CancelScope() + server.serve_cancel_scope = trio.CancelScope() server.stop_serving()