From 64d0d9c2bb8ba24d40befa668e23d86ae3c41394 Mon Sep 17 00:00:00 2001 From: "J.P. Hutchins" Date: Thu, 14 Mar 2024 13:24:11 -0700 Subject: [PATCH] fix: assert that header length matches data length when user provides header --- smp/exceptions.py | 4 + smp/message.py | 10 ++- tests/test_injected_header.py | 135 ++++++++++++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 tests/test_injected_header.py diff --git a/smp/exceptions.py b/smp/exceptions.py index 1b91f8f..7e3722b 100644 --- a/smp/exceptions.py +++ b/smp/exceptions.py @@ -27,3 +27,7 @@ class SMPDeserializationError(SMPException): class SMPMismatchedGroupId(SMPDeserializationError): ... + + +class SMPMalformed(SMPException): + ... diff --git a/smp/message.py b/smp/message.py index ea9d04c..86d64e4 100644 --- a/smp/message.py +++ b/smp/message.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict from smp import header as smpheader -from smp.exceptions import SMPMismatchedGroupId +from smp.exceptions import SMPMalformed, SMPMismatchedGroupId T = TypeVar("T", bound='_MessageBase') @@ -86,6 +86,10 @@ def model_post_init(self, _: None) -> None: ](self._COMMAND_ID), ), ) + elif self.header.length != len(data_bytes): + raise SMPMalformed( + f"header.length {self.header.length} != len(data_bytes) {len(data_bytes)}" + ) self._bytes = cast(smpheader.Header, self.header).BYTES + data_bytes @@ -123,6 +127,10 @@ def model_post_init(self, _: None) -> None: ](self._COMMAND_ID), ), ) + elif self.header.length != len(data_bytes): + raise SMPMalformed( + f"header.length {self.header.length} != len(data_bytes) {len(data_bytes)}" + ) self._bytes = cast(smpheader.Header, self.header).BYTES + data_bytes diff --git a/tests/test_injected_header.py b/tests/test_injected_header.py new file mode 100644 index 0000000..b168714 --- /dev/null +++ b/tests/test_injected_header.py @@ -0,0 +1,135 @@ +"""Test the case where the user forms the header separately.""" + +from typing import cast + +import pytest + +from smp import header as smphdr +from smp import image_management as smpimg +from smp.exceptions import SMPMalformed + + +def test_ImageUploadWriteRequest_injected_header() -> None: + h = smphdr.Header( + op=smphdr.OP.WRITE, + version=smphdr.Version.V0, + flags=smphdr.Flag(0), + length=0, + group_id=smphdr.GroupId.IMAGE_MANAGEMENT, + sequence=0, + command_id=smphdr.CommandId.ImageManagement.UPLOAD, + ) + + data = bytes([0x00] * 50) + + r = smpimg.ImageUploadWriteRequest( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=76, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + off=0, + data=data, + image=1, + len=50, + ) + + assert cast(smphdr.Header, r.header).length == 76 + assert len(r.BYTES) == 76 + smphdr.Header.SIZE + + with pytest.raises(SMPMalformed): + r = smpimg.ImageUploadWriteRequest( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=84, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + off=0, + data=data, + image=1, + len=50, + ) + + with pytest.raises(SMPMalformed): + r = smpimg.ImageUploadWriteRequest( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=0, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + off=0, + data=data, + image=1, + len=50, + ) + + +def test_ImageUploadWriteResponse_injected_header() -> None: + h = smphdr.Header( + op=smphdr.OP.WRITE_RSP, + version=smphdr.Version.V0, + flags=smphdr.Flag(0), + length=0, + group_id=smphdr.GroupId.IMAGE_MANAGEMENT, + sequence=0, + command_id=smphdr.CommandId.ImageManagement.UPLOAD, + ) + + r = smpimg.ImageUploadProgressWriteResponse( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=10, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + rc=0, + off=0, + ) + + assert cast(smphdr.Header, r.header).length == 10 + assert len(r.BYTES) == 10 + smphdr.Header.SIZE + + with pytest.raises(SMPMalformed): + r = smpimg.ImageUploadProgressWriteResponse( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=2, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + rc=0, + off=0, + ) + + with pytest.raises(SMPMalformed): + r = smpimg.ImageUploadProgressWriteResponse( + header=smphdr.Header( + op=h.op, + version=h.version, + flags=h.flags, + length=0, + group_id=h.group_id, + sequence=h.sequence, + command_id=h.command_id, + ), + rc=0, + off=0, + )