Skip to content

Commit

Permalink
Merge pull request #4 from petrovichkr/shell-group
Browse files Browse the repository at this point in the history
Add shell management
  • Loading branch information
JPHutchins authored Feb 13, 2024
2 parents a786681 + ce4b620 commit 3c6c675
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 2 deletions.
9 changes: 8 additions & 1 deletion smp/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class ImageManagement(IntEnum):
CORELOAD = 4
ERASE = 5

@unique
class ShellManagement(IntEnum):
EXECUTE = 0


AnyCommandId = TypeVar("AnyCommandId", bound=IntEnum)

Expand Down Expand Up @@ -85,11 +89,14 @@ class Header:
length: int
group_id: GroupId
sequence: int
command_id: CommandId.OSManagement | CommandId.ImageManagement | IntEnum
command_id: (
CommandId.OSManagement | CommandId.ImageManagement | CommandId.ShellManagement | IntEnum
)

_MAP_GROUP_ID_TO_COMMAND_ID_ENUM = {
GroupId.OS_MANAGEMENT: CommandId.OSManagement,
GroupId.IMAGE_MANAGEMENT: CommandId.ImageManagement,
GroupId.SHELL_MANAGEMENT: CommandId.ShellManagement,
}
_STRUCT = struct.Struct("!BBHHBB")
SIZE = _STRUCT.size
Expand Down
7 changes: 6 additions & 1 deletion smp/message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Simple Management Protocol (SMP) Message base class."""

import itertools
from abc import ABC
from enum import IntEnum, unique
Expand All @@ -23,7 +24,11 @@ class _MessageBase(ABC, BaseModel):
_VERSION: ClassVar[smpheader.Version] = smpheader.Version.V0
_FLAGS: ClassVar[smpheader.Flag] = smpheader.Flag(0)
_GROUP_ID: ClassVar[smpheader.GroupId]
_COMMAND_ID: ClassVar[smpheader.CommandId.ImageManagement | smpheader.CommandId.OSManagement]
_COMMAND_ID: ClassVar[
smpheader.CommandId.ImageManagement
| smpheader.CommandId.OSManagement
| smpheader.CommandId.ShellManagement
]

header: smpheader.Header | None = None

Expand Down
44 changes: 44 additions & 0 deletions smp/shell_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""The Simple Management Protocol (SMP) Shell Management group."""


from enum import IntEnum, auto, unique
from typing import List

from smp import error, header, message


class _ShellManagementGroup:
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT


class ExecuteRequest(_ShellManagementGroup, message.WriteRequest):
_COMMAND_ID = header.CommandId.ShellManagement.EXECUTE

argv: List[str]


class ExecuteResponse(_ShellManagementGroup, message.WriteResponse):
_COMMAND_ID = header.CommandId.ShellManagement.EXECUTE

o: str
ret: int


@unique
class SHELL_MGMT_RET_RC(IntEnum):
OK = 0
"""No error, this is implied if there is no ret value in the response."""

UNKNOWN = auto()
"""Unknown error occurred."""

INVALID_FORMAT = auto()
"""The provided format value is not valid."""


class ShellManagementErrorV0(error.ErrorV0[SHELL_MGMT_RET_RC]):
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT


class ShellManagementErrorV1(error.ErrorV1[SHELL_MGMT_RET_RC]):
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT
67 changes: 67 additions & 0 deletions tests/test_shell_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Test the SMP Shell Management group."""

from typing import Any, Dict, Type, TypeVar

import cbor2
from pydantic import BaseModel

from smp import header as smphdr
from smp import message as smpmsg
from smp import shell_management as smpshell
from tests.helpers import make_assert_header

shellcmd = smphdr.CommandId.ShellManagement


T = TypeVar("T", bound=smpmsg._MessageBase)


def _do_test(
msg: Type[T],
op: smphdr.OP,
command_id: smphdr.CommandId.ShellManagement,
data: Dict[str, Any],
nested_model: Type[BaseModel] | None = None,
) -> T:
cbor = cbor2.dumps(data)
assert_header = make_assert_header(smphdr.GroupId.SHELL_MANAGEMENT, op, command_id, len(cbor))

def _assert_common(r: smpmsg._MessageBase) -> None:
assert_header(r)
for k, v in data.items():
if type(v) is dict and nested_model is not None:
for k2, v2 in v.items():
one_deep = getattr(r, k)
assert isinstance(one_deep[k2], nested_model)
assert v2 == one_deep[k2].model_dump()
else:
assert v == getattr(r, k)
assert cbor == r.BYTES[8:]

r = msg(**data)

_assert_common(r) # serialize
_assert_common(msg.loads(r.BYTES)) # deserialize

return r


def test_ExecuteRequest() -> None:
r = _do_test(
smpshell.ExecuteRequest,
smphdr.OP.WRITE,
shellcmd.EXECUTE,
{"argv": ["echo", "Hello"]},
)
assert r.argv == ["echo", "Hello"]


def test_ExecuteResponse() -> None:
r = _do_test(
smpshell.ExecuteResponse,
smphdr.OP.WRITE_RSP,
shellcmd.EXECUTE,
{"o": "Hello", "ret": 0},
)
assert r.o == "Hello"
assert r.ret == 0

0 comments on commit 3c6c675

Please sign in to comment.