Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shell management #4

Merged
merged 4 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion smp/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class ImageManagement(IntEnum):
CORELOAD = 4
ERASE = 5

@unique
class ShellManagement(IntEnum):
EXECUTE = 0


AnyCommandId = TypeVar("AnyCommandId", bound=IntEnum)

Expand Down Expand Up @@ -85,11 +89,14 @@ class Header:
length: int
group_id: GroupId
sequence: int
command_id: CommandId.OSManagement | CommandId.ImageManagement | IntEnum
command_id: (
CommandId.OSManagement | CommandId.ImageManagement | CommandId.ShellManagement | IntEnum
)

_MAP_GROUP_ID_TO_COMMAND_ID_ENUM = {
GroupId.OS_MANAGEMENT: CommandId.OSManagement,
GroupId.IMAGE_MANAGEMENT: CommandId.ImageManagement,
GroupId.SHELL_MANAGEMENT: CommandId.ShellManagement,
}
_STRUCT = struct.Struct("!BBHHBB")
SIZE = _STRUCT.size
Expand Down
7 changes: 6 additions & 1 deletion smp/message.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The Simple Management Protocol (SMP) Message base class."""

import itertools
from abc import ABC
from enum import IntEnum, unique
Expand All @@ -23,7 +24,11 @@ class _MessageBase(ABC, BaseModel):
_VERSION: ClassVar[smpheader.Version] = smpheader.Version.V0
_FLAGS: ClassVar[smpheader.Flag] = smpheader.Flag(0)
_GROUP_ID: ClassVar[smpheader.GroupId]
_COMMAND_ID: ClassVar[smpheader.CommandId.ImageManagement | smpheader.CommandId.OSManagement]
_COMMAND_ID: ClassVar[
smpheader.CommandId.ImageManagement
| smpheader.CommandId.OSManagement
| smpheader.CommandId.ShellManagement
]

header: smpheader.Header | None = None

Expand Down
44 changes: 44 additions & 0 deletions smp/shell_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""The Simple Management Protocol (SMP) Shell Management group."""


from enum import IntEnum, auto, unique
from typing import List

from smp import error, header, message


class _ShellManagementGroup:
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT


class ExecuteRequest(_ShellManagementGroup, message.WriteRequest):
_COMMAND_ID = header.CommandId.ShellManagement.EXECUTE

argv: List[str]


class ExecuteResponse(_ShellManagementGroup, message.WriteResponse):
_COMMAND_ID = header.CommandId.ShellManagement.EXECUTE

o: str
ret: int


@unique
class SHELL_MGMT_RET_RC(IntEnum):
OK = 0
"""No error, this is implied if there is no ret value in the response."""

UNKNOWN = auto()
"""Unknown error occurred."""

INVALID_FORMAT = auto()
"""The provided format value is not valid."""


class ShellManagementErrorV0(error.ErrorV0[SHELL_MGMT_RET_RC]):
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT


class ShellManagementErrorV1(error.ErrorV1[SHELL_MGMT_RET_RC]):
_GROUP_ID = header.GroupId.SHELL_MANAGEMENT
67 changes: 67 additions & 0 deletions tests/test_shell_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Test the SMP Shell Management group."""

from typing import Any, Dict, Type, TypeVar

import cbor2
from pydantic import BaseModel

from smp import header as smphdr
from smp import message as smpmsg
from smp import shell_management as smpshell
from tests.helpers import make_assert_header

shellcmd = smphdr.CommandId.ShellManagement


T = TypeVar("T", bound=smpmsg._MessageBase)


def _do_test(
msg: Type[T],
op: smphdr.OP,
command_id: smphdr.CommandId.ShellManagement,
data: Dict[str, Any],
nested_model: Type[BaseModel] | None = None,
) -> T:
cbor = cbor2.dumps(data)
assert_header = make_assert_header(smphdr.GroupId.SHELL_MANAGEMENT, op, command_id, len(cbor))

def _assert_common(r: smpmsg._MessageBase) -> None:
assert_header(r)
for k, v in data.items():
if type(v) is dict and nested_model is not None:
for k2, v2 in v.items():
one_deep = getattr(r, k)
assert isinstance(one_deep[k2], nested_model)
assert v2 == one_deep[k2].model_dump()
else:
assert v == getattr(r, k)
assert cbor == r.BYTES[8:]

r = msg(**data)

_assert_common(r) # serialize
_assert_common(msg.loads(r.BYTES)) # deserialize

return r


def test_ExecuteRequest() -> None:
r = _do_test(
smpshell.ExecuteRequest,
smphdr.OP.WRITE,
shellcmd.EXECUTE,
{"argv": ["echo", "Hello"]},
)
assert r.argv == ["echo", "Hello"]


def test_ExecuteResponse() -> None:
r = _do_test(
smpshell.ExecuteResponse,
smphdr.OP.WRITE_RSP,
shellcmd.EXECUTE,
{"o": "Hello", "ret": 0},
)
assert r.o == "Hello"
assert r.ret == 0
Loading