Skip to content

Commit

Permalink
Move to STOMP 1.2 & improve parsing (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Jun 9, 2024
1 parent e610165 commit 37868ba
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 314 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ async with asyncio.TaskGroup() as task_group:
)
case stompman.HeartbeatEvent():
task_group.create_task(update_healthcheck_status())
case stompman.UnknownEvent():
logger.error("Received unknown event from server", event=event)
```

### Cleaning Up
Expand All @@ -120,7 +118,7 @@ stompman takes care of cleaning up resources automatically. When you leave the c
### ...and caveats

- stompman only runs on Python 3.11 and newer.
- It only implements [STOMP 1.1](https://stomp.github.io/stomp-specification-1.1.html). I'm open to implementing [STOMP 1.2](https://stomp.github.io/stomp-specification-1.1.html).
- It implements [STOMP 1.2](https://stomp.github.io/stomp-specification-1.2.html) — the latest version of the protocol.
- The client-individual ack mode is used, which means that server requires `ack` or `nack`. In contrast, with `client` ack mode server assumes you don't care about messages that occured before you connected. And, with `auto` ack mode server assumes client successfully received the message.
- Heartbeats are required, and sent automatically on `listen_to_events()` (defaults to 1 second).

Expand Down
7 changes: 1 addition & 6 deletions stompman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
AbortFrame,
AckFrame,
AnyFrame,
BaseFrame,
BeginFrame,
ClientFrame,
CommitFrame,
Expand All @@ -27,10 +26,9 @@
SendFrame,
ServerFrame,
SubscribeFrame,
UnknownFrame,
UnsubscribeFrame,
)
from stompman.listen_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent, UnknownEvent
from stompman.listening_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent

__all__ = [
"AbortFrame",
Expand All @@ -39,7 +37,6 @@
"FailedAllConnectAttemptsError",
"AnyFrame",
"AnyListeningEvent",
"BaseFrame",
"BaseListenEvent",
"BeginFrame",
"ClientFrame",
Expand All @@ -66,8 +63,6 @@
"Connection",
"Error",
"SubscribeFrame",
"UnknownEvent",
"UnknownFrame",
"UnsubscribeFrame",
"UnsupportedProtocolVersionError",
]
13 changes: 8 additions & 5 deletions stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
MessageFrame,
ReceiptFrame,
SendFrame,
SendHeaders,
SubscribeFrame,
UnsubscribeFrame,
)
from stompman.listen_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent, UnknownEvent
from stompman.listening_events import AnyListeningEvent, ErrorEvent, HeartbeatEvent, MessageEvent
from stompman.protocol import PROTOCOL_VERSION


Expand Down Expand Up @@ -117,6 +118,7 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
headers={
"accept-version": PROTOCOL_VERSION,
"heart-beat": self.heartbeat.to_header(),
"host": self._connection.connection_parameters.host,
"login": self._connection.connection_parameters.login,
"passcode": self._connection.connection_parameters.passcode,
},
Expand Down Expand Up @@ -176,8 +178,6 @@ async def listen_to_events(self) -> AsyncIterator[AnyListeningEvent]:
yield HeartbeatEvent(_client=self, _frame=frame)
case ConnectedFrame() | ReceiptFrame():
raise AssertionError("Should be unreachable! Report the issue.", frame)
case _:
yield UnknownEvent(_client=self, _frame=frame)

@asynccontextmanager
async def enter_transaction(self) -> AsyncGenerator[str, None]:
Expand All @@ -192,16 +192,19 @@ async def enter_transaction(self) -> AsyncGenerator[str, None]:
else:
await self._connection.write_frame(CommitFrame(headers={"transaction": transaction_id}))

async def send(
async def send( # noqa: PLR0913
self,
body: bytes,
destination: str,
transaction: str | None = None,
content_type: str | None = None,
headers: dict[str, str] | None = None,
) -> None:
full_headers = headers or {}
full_headers: SendHeaders = headers or {} # type: ignore[assignment]
full_headers["destination"] = destination
full_headers["content-length"] = str(len(body))
if content_type is not None:
full_headers["content-type"] = content_type
if transaction is not None:
full_headers["transaction"] = transaction
await self._connection.write_frame(SendFrame(headers=full_headers, body=body))
21 changes: 11 additions & 10 deletions stompman/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Protocol, TypeVar, cast

from stompman.errors import ConnectError, ReadTimeoutError
from stompman.frames import ClientFrame, ServerFrame, UnknownFrame
from stompman.protocol import HEARTBEAT_MARKER, dump_frame, load_frames, separate_complete_and_incomplete_packet_parts
from stompman.frames import AnyRealFrame, ClientFrame, ServerFrame
from stompman.protocol import NEWLINE, Parser, dump_frame, separate_complete_and_incomplete_packet_parts


@dataclass
Expand All @@ -16,7 +16,7 @@ class ConnectionParameters:
passcode: str = field(repr=False)


ServerFrameT = TypeVar("ServerFrameT", bound=ServerFrame | UnknownFrame)
FrameT = TypeVar("FrameT", bound=AnyRealFrame)


@dataclass
Expand All @@ -29,10 +29,10 @@ class AbstractConnection(Protocol):
async def connect(self) -> None: ...
async def close(self) -> None: ...
def write_heartbeat(self) -> None: ...
async def write_frame(self, frame: ClientFrame | UnknownFrame) -> None: ...
def read_frames(self) -> AsyncGenerator[ServerFrame | UnknownFrame, None]: ...
async def write_frame(self, frame: ClientFrame) -> None: ...
def read_frames(self) -> AsyncGenerator[ServerFrame, None]: ...

async def read_frame_of_type(self, type_: type[ServerFrameT]) -> ServerFrameT:
async def read_frame_of_type(self, type_: type[FrameT]) -> FrameT:
while True:
async for frame in self.read_frames():
if isinstance(frame, type_):
Expand Down Expand Up @@ -62,9 +62,9 @@ async def close(self) -> None:
await self.writer.wait_closed()

def write_heartbeat(self) -> None:
return self.writer.write(HEARTBEAT_MARKER)
return self.writer.write(NEWLINE)

async def write_frame(self, frame: ClientFrame | UnknownFrame) -> None:
async def write_frame(self, frame: ClientFrame) -> None:
self.writer.write(dump_frame(frame))
await self.writer.drain()

Expand All @@ -75,8 +75,9 @@ async def _read_non_empty_bytes(self) -> bytes:
await asyncio.sleep(0)
return chunk

async def read_frames(self) -> AsyncGenerator[ServerFrame | UnknownFrame, None]:
async def read_frames(self) -> AsyncGenerator[ServerFrame, None]:
incomplete_bytes = b""
parser = Parser()

while True:
try:
Expand All @@ -87,5 +88,5 @@ async def read_frames(self) -> AsyncGenerator[ServerFrame | UnknownFrame, None]:
complete_bytes, incomplete_bytes = separate_complete_and_incomplete_packet_parts(
incomplete_bytes + received_bytes
)
for frame in cast(Iterable[ServerFrame | UnknownFrame], load_frames(complete_bytes)):
for frame in cast(Iterable[ServerFrame], parser.load_frames(received_bytes)):
yield frame
Loading

0 comments on commit 37868ba

Please sign in to comment.