Skip to content

Commit

Permalink
Merge pull request #14 from JPHutchins/fix/#13/smp-message-inheritance
Browse files Browse the repository at this point in the history
fix #13: allow users to define custom groups and commands
  • Loading branch information
JPHutchins authored Apr 24, 2024
2 parents 3f5e330 + c134b6a commit bbb6b47
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 39 deletions.
46 changes: 35 additions & 11 deletions smp/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import struct
from dataclasses import dataclass
from enum import IntEnum, IntFlag, auto, unique
from typing import TypeVar
from typing import ClassVar, Dict, Type, TypeAlias


class CommandId:
Expand Down Expand Up @@ -37,7 +37,7 @@ class Intercreate(IntEnum):
UPLOAD = 1


AnyCommandId = TypeVar("AnyCommandId", bound=IntEnum)
AnyCommandId: TypeAlias = IntEnum | int


class GroupId(IntEnum):
Expand All @@ -56,6 +56,9 @@ class GroupId(IntEnum):
INTERCREATE = 64


AnyGroupId: TypeAlias = IntEnum | int


@unique
class OP(IntEnum):
READ = 0
Expand Down Expand Up @@ -91,20 +94,24 @@ class Header:
version: Version
flags: Flag
length: int
group_id: GroupId
group_id: AnyGroupId | GroupId
sequence: int
command_id: (
CommandId.OSManagement | CommandId.ImageManagement | CommandId.ShellManagement | IntEnum
AnyCommandId
| CommandId.OSManagement
| CommandId.ImageManagement
| CommandId.ShellManagement
| CommandId.Intercreate
)

_MAP_GROUP_ID_TO_COMMAND_ID_ENUM = {
_MAP_GROUP_ID_TO_COMMAND_ID_ENUM: ClassVar[Dict[int, Type[IntEnum]]] = {
GroupId.OS_MANAGEMENT: CommandId.OSManagement,
GroupId.IMAGE_MANAGEMENT: CommandId.ImageManagement,
GroupId.SHELL_MANAGEMENT: CommandId.ShellManagement,
GroupId.INTERCREATE: CommandId.Intercreate,
}
_STRUCT = struct.Struct("!BBHHBB")
SIZE = _STRUCT.size
_STRUCT: ClassVar = struct.Struct("!BBHHBB")
SIZE: ClassVar = _STRUCT.size

@staticmethod
def _pack_op(op: OP) -> int:
Expand All @@ -131,17 +138,32 @@ def _pack_op_and_version(op: OP, version: Version) -> int:
"""The op and version packed into one byte."""
return Header._pack_op(op) | Header._pack_version(version)

@staticmethod
def _validate_command_id(group_id: int, command_id: int) -> None:
"""Validate the command_id if the GroupId is known."""

if command_id_t := Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM.get(group_id):
try:
command_id_t(command_id)
except ValueError:
raise ValueError(
f"Command ID {command_id} is not valid for Group ID {group_id}"
f" ({GroupId(group_id).name})"
)

def __post_init__(self) -> None:
Header._validate_command_id(self.group_id, self.command_id)

object.__setattr__(
self,
'_bytes',
self._STRUCT.pack(
self._pack_op_and_version(self.op, self.version),
Flag(self.flags),
self.length,
GroupId(self.group_id),
self.group_id,
self.sequence,
Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[GroupId(self.group_id)](self.command_id),
self.command_id,
),
)

Expand All @@ -161,12 +183,14 @@ def loads(header: bytes) -> 'Header':
header
)

Header._validate_command_id(group_id, command_id)

return Header(
Header._unpack_op(res_ver_op_byte),
Header._unpack_version(res_ver_op_byte),
Flag(flags),
length,
GroupId(group_id),
group_id,
sequence,
Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[GroupId(group_id)](command_id),
command_id,
)
4 changes: 2 additions & 2 deletions smp/image_management.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""The Simple Management Protocol (SMP) Image Management group."""

from enum import IntEnum, auto, unique
from typing import Generator, List
from typing import ClassVar, Generator, List

from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator

from smp import error, header, message


class _ImageManagementGroup:
_GROUP_ID = header.GroupId.IMAGE_MANAGEMENT
_GROUP_ID: ClassVar = header.GroupId.IMAGE_MANAGEMENT


class HashBytes(bytes): # pragma: no cover
Expand Down
17 changes: 7 additions & 10 deletions smp/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ class _MessageBase(ABC, BaseModel):
_OP: ClassVar[smpheader.OP]
_VERSION: ClassVar[smpheader.Version] = smpheader.Version.V0
_FLAGS: ClassVar[smpheader.Flag] = smpheader.Flag(0)
_GROUP_ID: ClassVar[smpheader.GroupId]
_GROUP_ID: ClassVar[smpheader.GroupId | smpheader.AnyGroupId]
_COMMAND_ID: ClassVar[
smpheader.CommandId.ImageManagement
smpheader.AnyCommandId
| smpheader.CommandId.ImageManagement
| smpheader.CommandId.OSManagement
| smpheader.CommandId.ShellManagement
| smpheader.CommandId.Intercreate
Expand Down Expand Up @@ -79,11 +80,9 @@ def model_post_init(self, _: None) -> None:
version=self._VERSION,
flags=smpheader.Flag(self._FLAGS),
length=len(data_bytes),
group_id=smpheader.GroupId(self._GROUP_ID),
group_id=self._GROUP_ID,
sequence=next(_counter) % 0xFF,
command_id=smpheader.Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[
smpheader.GroupId(self._GROUP_ID)
](self._COMMAND_ID),
command_id=self._COMMAND_ID,
),
)
elif self.header.length != len(data_bytes):
Expand Down Expand Up @@ -120,11 +119,9 @@ def model_post_init(self, _: None) -> None:
version=self._VERSION,
flags=smpheader.Flag(self._FLAGS),
length=len(data_bytes),
group_id=smpheader.GroupId(self._GROUP_ID),
group_id=self._GROUP_ID,
sequence=self.sequence,
command_id=smpheader.Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[
smpheader.GroupId(self._GROUP_ID)
](self._COMMAND_ID),
command_id=self._COMMAND_ID,
),
)
self._bytes = cast(smpheader.Header, self.header).BYTES + data_bytes
Expand Down
4 changes: 2 additions & 2 deletions smp/os_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@


from enum import IntEnum, auto, unique
from typing import Any, Dict
from typing import Any, ClassVar, Dict

from pydantic import BaseModel, ConfigDict, Field

from smp import error, header, message


class _OSManagementGroup:
_GROUP_ID = header.GroupId.OS_MANAGEMENT
_GROUP_ID: ClassVar = header.GroupId.OS_MANAGEMENT


class EchoWriteRequest(_OSManagementGroup, message.WriteRequest):
Expand Down
4 changes: 2 additions & 2 deletions smp/shell_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@


from enum import IntEnum, auto, unique
from typing import List
from typing import ClassVar, List

from smp import error, header, message


class _ShellManagementGroup:
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT
_GROUP_ID: ClassVar = header.GroupId.SHELL_MANAGEMENT


class ExecuteRequest(_ShellManagementGroup, message.WriteRequest):
Expand Down
3 changes: 2 additions & 1 deletion smp/user/intercreate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""The Simple Management Protocol (SMP) Intercreate Management group."""

from enum import IntEnum, auto, unique
from typing import ClassVar

from smp import error, header, message


class _IntercreateManagementGroup:
_GROUP_ID = header.GroupId.INTERCREATE
_GROUP_ID: ClassVar = header.GroupId.INTERCREATE


class ImageUploadWriteRequest(_IntercreateManagementGroup, message.WriteRequest):
Expand Down
17 changes: 6 additions & 11 deletions tests/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@ def test_header_serialization(
command_id: int,
) -> None:
# the validators will raise exceptions
if group_id > max(GroupId):
with pytest.raises((KeyError, ValueError)):
h = Header(op, version, flags, length, group_id, sequence, command_id) # type: ignore
return
elif command_id > max(Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[group_id]):
if group_id <= max(GroupId) and command_id > max(
Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[group_id]
):
with pytest.raises((KeyError, ValueError)):
h = Header(op, version, flags, length, group_id, sequence, command_id) # type: ignore
return
Expand Down Expand Up @@ -76,12 +74,9 @@ def test_header_deserialization(
command_id: int,
) -> None:
# the validators will raise exceptions
if group_id > max(GroupId):
with pytest.raises((KeyError, ValueError)):
_h = Header(op, version, flags, length, group_id, sequence, command_id) # type: ignore
h = Header.loads(_h.BYTES)
return
elif command_id > max(Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[group_id]):
if group_id <= max(GroupId) and command_id > max(
Header._MAP_GROUP_ID_TO_COMMAND_ID_ENUM[group_id]
):
with pytest.raises((KeyError, ValueError)):
_h = Header(op, version, flags, length, group_id, sequence, command_id) # type: ignore
h = Header.loads(_h.BYTES)
Expand Down
91 changes: 91 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for user-defined inheritance of classes."""


import struct
from typing import Final, Type

import pytest

from smp import header as smphdr
from smp import message as smpmsg

GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM: Final = 65
assert GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM not in list(smphdr.GroupId)


def test_custom_ReadRequest() -> None:
"""Test ReadRequest inheritance."""

class CustomInts(smpmsg.ReadRequest):
_GROUP_ID = GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM
_COMMAND_ID = 0

m = CustomInts()
assert m._GROUP_ID == GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM
assert m._COMMAND_ID == 0


@pytest.mark.parametrize(
"cls",
[
smpmsg.ReadRequest,
smpmsg.WriteRequest,
smpmsg.ReadResponse,
smpmsg.WriteResponse,
smpmsg.Request,
smpmsg.Response,
],
)
@pytest.mark.parametrize("group_id", [GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM, 0xFFFF])
@pytest.mark.parametrize("command_id", [0, 1, 0xFF])
def test_custom_message(cls: Type[smpmsg._MessageBase], group_id: int, command_id: int) -> None:
"""Test ReadRequest inheritance."""

class CustomInts(cls): # type: ignore
_OP = getattr(cls, "_OP", 0)
_GROUP_ID = group_id
_COMMAND_ID = command_id

m = CustomInts()
assert m._GROUP_ID == group_id
assert m._COMMAND_ID == command_id


def test_invalid_group_id() -> None:
"""Test invalid group_id."""

with pytest.raises(struct.error):

class A(smpmsg.ReadRequest):
_GROUP_ID = 0x10000
_COMMAND_ID = 0

A()

with pytest.raises(struct.error):

class B(smpmsg.ReadRequest):
_GROUP_ID = -1
_COMMAND_ID = 0

B()


def test_invalid_command_id() -> None:
"""Test invalid command_id."""

with pytest.raises(struct.error):

class A(smpmsg.ReadRequest):
_GROUP_ID = GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM
_COMMAND_ID = 0x100

A()

with pytest.raises(struct.error):

class B(smpmsg.ReadRequest):
_GROUP_ID = GROUP_ID_THAT_IS_NOT_IN_GROUP_ID_ENUM
_COMMAND_ID = -1

B()

0 comments on commit bbb6b47

Please sign in to comment.