Skip to content

Commit

Permalink
Sharpen public interface (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Jun 11, 2024
1 parent 2852786 commit cb03acb
Show file tree
Hide file tree
Showing 13 changed files with 91 additions and 97 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
- uses: extractions/setup-just@v2
- run: just install check-types

lint-format:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -45,9 +45,9 @@ jobs:
path: |
~/.cache/pip
~/.cache/pypoetry
key: lint-format-${{ hashFiles('pyproject.toml') }}
key: lint-${{ hashFiles('pyproject.toml') }}
- uses: extractions/setup-just@v2
- run: just install lint-format
- run: just install lint

test:
runs-on: ubuntu-latest
Expand Down
4 changes: 2 additions & 2 deletions Justfile
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
default: install lint-format check-types test
default: install lint check-types test

install:
poetry install --sync

test *args:
poetry run pytest {{args}}

lint-format:
lint:
poetry run ruff check .
poetry run ruff format .

Expand Down
15 changes: 12 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,18 @@ stompman takes care of cleaning up resources automatically. When you leave the c

### Handling Connectivity Issues

- If multiple servers are provided, stompman will attempt to connect to each one simultaneously and use the first that succeeds.
- If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. There're no need to handle it, if you should tune retry and timeout parameters to your needs.
- If a connection is lost, a `stompman.ReadTimeoutError` will be raised. You'll need to implement reconnect logic manually. Implementing reconnect logic in the library would be too complex, since there're no global state and clean-ups are automatic (e.g. it won't be possible to re-subscribe to destination because client doesn't keep track of subscriptions).
- If multiple servers were provided, stompman will attempt to connect to each one simultaneously and will use the first that succeeds.

- If all servers fail to connect, an `stompman.FailedAllConnectAttemptsError` will be raised. In normal situation it doesn't need to be handled: tune retry and timeout parameters in `stompman.Client()` to your needs.

- If a connection is lost, a `stompman.ConnectionLostError` will be raised. You should implement reconnect logic manually, for example, with stamina:

```python
for attempt in stamina.retry_context(on=stompman.ConnectionLostError):
with attempt:
async with stompman.Client(...) as client:
...
```

### ...and caveats

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ target-version = "py311"
fix = true
unsafe-fixes = true
line-length = 120

[tool.ruff.lint]
preview = true
select = ["ALL"]
ignore = [
"EM",
Expand All @@ -58,6 +60,7 @@ ignore = [
"ISC001",
"S101",
"SLF001",
"CPY001",
]

[tool.pytest.ini_options]
Expand Down
27 changes: 11 additions & 16 deletions stompman/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from stompman.client import Client, Heartbeat
from stompman.connection import AbstractConnection, Connection, ConnectionParameters
from stompman.errors import (
ConnectError,
ConnectionConfirmationTimeoutError,
ConnectionLostError,
Error,
FailedAllConnectAttemptsError,
ReadTimeoutError,
UnsupportedProtocolVersionError,
)
from stompman.frames import (
AbortFrame,
AckFrame,
AnyFrame,
AnyClientFrame,
AnyServerFrame,
BeginFrame,
ClientFrame,
CommitFrame,
ConnectedFrame,
ConnectFrame,
Expand All @@ -24,7 +23,6 @@
NackFrame,
ReceiptFrame,
SendFrame,
ServerFrame,
SubscribeFrame,
UnsubscribeFrame,
)
Expand All @@ -34,34 +32,31 @@
"AbortFrame",
"AbstractConnection",
"AckFrame",
"FailedAllConnectAttemptsError",
"AnyFrame",
"AnyClientFrame",
"AnyListeningEvent",
"BaseListenEvent",
"AnyServerFrame",
"BeginFrame",
"ClientFrame",
"Client",
"CommitFrame",
"ConnectedFrame",
"ConnectError",
"ConnectFrame",
"ConnectedFrame",
"Connection",
"ConnectionConfirmationTimeoutError",
"ConnectionLostError",
"ConnectionParameters",
"DisconnectFrame",
"Error",
"ErrorEvent",
"ErrorFrame",
"FailedAllConnectAttemptsError",
"Heartbeat",
"HeartbeatEvent",
"HeartbeatFrame",
"MessageEvent",
"MessageFrame",
"NackFrame",
"ReadTimeoutError",
"ReceiptFrame",
"SendFrame",
"ServerFrame",
"Client",
"Connection",
"Error",
"SubscribeFrame",
"UnsubscribeFrame",
"UnsupportedProtocolVersionError",
Expand Down
8 changes: 2 additions & 6 deletions stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from stompman.connection import AbstractConnection, Connection, ConnectionParameters
from stompman.errors import (
ConnectError,
ConnectionConfirmationTimeoutError,
FailedAllConnectAttemptsError,
UnsupportedProtocolVersionError,
Expand Down Expand Up @@ -80,12 +79,9 @@ async def _connect_to_one_server(self, server: ConnectionParameters) -> Abstract
read_timeout=self.read_timeout,
read_max_chunk_size=self.read_max_chunk_size,
)
try:
await connection.connect()
except ConnectError:
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
else:
if await connection.connect():
return connection
await asyncio.sleep(self.connect_retry_interval * (attempt + 1))
return None

async def _connect_to_any_server(self) -> AbstractConnection:
Expand Down
27 changes: 14 additions & 13 deletions stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from dataclasses import dataclass, field
from typing import Protocol, TypeVar, cast

from stompman.errors import ConnectError, ReadTimeoutError
from stompman.frames import AnyRealFrame, ClientFrame, ServerFrame
from stompman.errors import ConnectionLostError
from stompman.frames import AnyClientFrame, AnyServerFrame
from stompman.protocol import NEWLINE, Parser, dump_frame


Expand All @@ -16,7 +16,7 @@ class ConnectionParameters:
passcode: str = field(repr=False)


FrameT = TypeVar("FrameT", bound=AnyRealFrame)
FrameT = TypeVar("FrameT", bound=AnyClientFrame | AnyServerFrame)


@dataclass
Expand All @@ -26,11 +26,11 @@ class AbstractConnection(Protocol):
read_timeout: int
read_max_chunk_size: int

async def connect(self) -> None: ...
async def connect(self) -> bool: ...
async def close(self) -> None: ...
def write_heartbeat(self) -> None: ...
async def write_frame(self, frame: ClientFrame) -> None: ...
def read_frames(self) -> AsyncGenerator[ServerFrame, None]: ...
async def write_frame(self, frame: AnyClientFrame) -> None: ...
def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]: ...

async def read_frame_of_type(self, type_: type[FrameT]) -> FrameT:
while True:
Expand All @@ -48,14 +48,15 @@ class Connection(AbstractConnection):
reader: asyncio.StreamReader = field(init=False)
writer: asyncio.StreamWriter = field(init=False)

async def connect(self) -> None:
async def connect(self) -> bool:
try:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(self.connection_parameters.host, self.connection_parameters.port),
timeout=self.connect_timeout,
)
except (TimeoutError, ConnectionError) as exception:
raise ConnectError(self.connection_parameters) from exception
except (TimeoutError, ConnectionError):
return False
return True

async def close(self) -> None:
self.writer.close()
Expand All @@ -64,7 +65,7 @@ async def close(self) -> None:
def write_heartbeat(self) -> None:
return self.writer.write(NEWLINE)

async def write_frame(self, frame: ClientFrame) -> None:
async def write_frame(self, frame: AnyClientFrame) -> None:
self.writer.write(dump_frame(frame))
await self.writer.drain()

Expand All @@ -75,14 +76,14 @@ async def _read_non_empty_bytes(self) -> bytes:
await asyncio.sleep(0)
return chunk

async def read_frames(self) -> AsyncGenerator[ServerFrame, None]:
async def read_frames(self) -> AsyncGenerator[AnyServerFrame, None]:
parser = Parser()

while True:
try:
raw_frames = await asyncio.wait_for(self._read_non_empty_bytes(), timeout=self.read_timeout)
except TimeoutError as exception:
raise ReadTimeoutError(self.read_timeout) from exception
raise ConnectionLostError(self.read_timeout) from exception

for frame in cast(Iterator[ServerFrame], parser.load_frames(raw_frames)):
for frame in cast(Iterator[AnyServerFrame], parser.load_frames(raw_frames)):
yield frame
7 changes: 1 addition & 6 deletions stompman/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ class UnsupportedProtocolVersionError(Error):
supported_version: str


@dataclass
class ConnectError(Error):
connection_parameters: "ConnectionParameters"


@dataclass
class FailedAllConnectAttemptsError(Error):
servers: list["ConnectionParameters"]
Expand All @@ -36,5 +31,5 @@ class FailedAllConnectAttemptsError(Error):


@dataclass
class ReadTimeoutError(Error):
class ConnectionLostError(Error):
timeout: int
6 changes: 2 additions & 4 deletions stompman/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class HeartbeatFrame: ...
}
FRAMES_TO_COMMANDS = {value: key for key, value in COMMANDS_TO_FRAMES.items()}

ClientFrame = (
AnyClientFrame = (
SendFrame
| SubscribeFrame
| UnsubscribeFrame
Expand All @@ -257,6 +257,4 @@ class HeartbeatFrame: ...
| ConnectFrame
| StompFrame
)
ServerFrame = ConnectedFrame | MessageFrame | ReceiptFrame | ErrorFrame
AnyRealFrame = ClientFrame | ServerFrame
AnyFrame = AnyRealFrame | HeartbeatFrame
AnyServerFrame = ConnectedFrame | MessageFrame | ReceiptFrame | ErrorFrame
10 changes: 5 additions & 5 deletions stompman/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from stompman.frames import (
COMMANDS_TO_FRAMES,
FRAMES_TO_COMMANDS,
AnyFrame,
AnyRealFrame,
AnyClientFrame,
AnyServerFrame,
HeartbeatFrame,
)

Expand Down Expand Up @@ -37,7 +37,7 @@ def dump_header(key: str, value: str) -> bytes:
return f"{escaped_key}:{escaped_value}\n".encode()


def dump_frame(frame: AnyRealFrame) -> bytes:
def dump_frame(frame: AnyClientFrame | AnyServerFrame) -> bytes:
lines = (
FRAMES_TO_COMMANDS[type(frame)],
NEWLINE,
Expand Down Expand Up @@ -79,7 +79,7 @@ def parse_headers(buffer: list[bytes]) -> tuple[str, str] | None:
return (b"".join(key_buffer).decode(), b"".join(value_buffer).decode()) if key_parsed else None


def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyFrame | None:
def parse_lines_into_frame(lines: deque[list[bytes]]) -> AnyClientFrame | AnyServerFrame | None:
command = b"".join(lines.popleft())
headers = {}

Expand All @@ -101,7 +101,7 @@ class Parser:
_previous_byte: bytes = field(default=b"", init=False)
_headers_processed: bool = field(default=False, init=False)

def load_frames(self, raw_frames: bytes) -> Iterator[AnyFrame]:
def load_frames(self, raw_frames: bytes) -> Iterator[AnyClientFrame | AnyServerFrame | HeartbeatFrame]:
buffer = deque(struct.unpack(f"{len(raw_frames)!s}c", raw_frames))
while buffer:
byte = buffer.popleft()
Expand Down
Loading

0 comments on commit cb03acb

Please sign in to comment.