Skip to content

Commit

Permalink
Merge branch 'main' into chemelli74-typing-4
Browse files Browse the repository at this point in the history
  • Loading branch information
rokam authored May 27, 2024
2 parents a49d047 + 959225e commit 365411a
Showing 1 changed file with 53 additions and 39 deletions.
92 changes: 53 additions & 39 deletions midealocal/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from abc import ABC
from enum import IntEnum
from typing import SupportsIndex, cast

_LOGGER = logging.getLogger(__name__)

Expand All @@ -27,57 +28,60 @@ class MessageType(IntEnum):
query_appliance = 0xA0


NONE_VALUE = 0x00


class MessageBase(ABC):
HEADER_LENGTH = 10

def __init__(self):
self._device_type = 0x00
self._message_type = 0x00
self._body_type = 0x00
self._protocol_version = 0x00
def __init__(self) -> None:
self._device_type = NONE_VALUE
self._message_type = NONE_VALUE
self._body_type = NONE_VALUE
self._protocol_version = NONE_VALUE

@staticmethod
def checksum(data):
return (~sum(data) + 1) & 0xFF
def checksum(data: bytes) -> SupportsIndex:
return cast(SupportsIndex, (~sum(data) + 1) & 0xFF)

@property
def header(self):
def header(self) -> bytearray:
raise NotImplementedError

@property
def body(self):
def body(self) -> bytearray:
raise NotImplementedError

@property
def message_type(self):
def message_type(self) -> int:
return self._message_type

@message_type.setter
def message_type(self, value):
def message_type(self, value: int) -> None:
self._message_type = value

@property
def device_type(self):
def device_type(self) -> int:
return self._device_type

@device_type.setter
def device_type(self, value):
def device_type(self, value: int) -> None:
self._device_type = value

@property
def body_type(self):
def body_type(self) -> int:
return self._body_type

@body_type.setter
def body_type(self, value):
def body_type(self, value: int) -> None:
self._body_type = value

@property
def protocol_version(self):
def protocol_version(self) -> int:
return self._protocol_version

@protocol_version.setter
def protocol_version(self, protocol_version):
def protocol_version(self, protocol_version: int) -> None:
self._protocol_version = protocol_version

def __str__(self) -> str:
Expand All @@ -94,7 +98,11 @@ def __str__(self) -> str:

class MessageRequest(MessageBase):
def __init__(
self, device_type: int, protocol_version: int, message_type: int, body_type: int
self,
device_type: int,
protocol_version: int,
message_type: int,
body_type: int,
) -> None:
super().__init__()
self.device_type = device_type
Expand Down Expand Up @@ -149,77 +157,83 @@ def serialize(self) -> bytearray:


class MessageQuestCustom(MessageRequest):
def __init__(self, device_type, protocol_version, cmd_type, cmd_body):
def __init__(
self,
device_type: int,
protocol_version: int,
cmd_type: int,
cmd_body: bytearray,
) -> None:
super().__init__(
device_type=device_type,
protocol_version=protocol_version,
message_type=cmd_type,
body_type=None,
body_type=NONE_VALUE,
)
self._cmd_body = cmd_body

@property
def _body(self):
def _body(self) -> bytearray:
return bytearray([])

@property
def body(self):
def body(self) -> bytearray:
return self._cmd_body


class MessageQueryAppliance(MessageRequest):
def __init__(self, device_type):
def __init__(self, device_type: int) -> None:
super().__init__(
device_type=device_type,
protocol_version=0,
message_type=MessageType.query_appliance,
body_type=None,
body_type=NONE_VALUE,
)

@property
def _body(self):
def _body(self) -> bytearray:
return bytearray([])

@property
def body(self):
def body(self) -> bytearray:
return bytearray([0x00] * 19)


class MessageBody:
def __init__(self, body):
def __init__(self, body: bytearray) -> None:
self._data = body

@property
def data(self):
def data(self) -> bytearray:
return self._data

@property
def body_type(self):
def body_type(self) -> int:
return self._data[0]

@staticmethod
def read_byte(body, byte, default_value=0):
def read_byte(body: bytearray, byte: int, default_value: int = 0) -> int:
return body[byte] if len(body) > byte else default_value


class NewProtocolMessageBody(MessageBody):
def __init__(self, body, bt):
def __init__(self, body: bytearray, bt: int) -> None:
super().__init__(body)
if bt == 0xB5:
self._pack_len = 4
else:
self._pack_len = 5

@staticmethod
def pack(param, value: bytearray, pack_len=4):
def pack(param: int, value: bytearray, pack_len: int = 4) -> bytearray:
length = len(value)
if pack_len == 4:
stream = bytearray([param & 0xFF, param >> 8, length]) + value
else:
stream = bytearray([param & 0xFF, param >> 8, 0x00, length]) + value
return stream

def parse(self):
def parse(self) -> dict[int, bytearray]:
result = {}
try:
pos = 2
Expand All @@ -239,7 +253,7 @@ def parse(self):


class MessageResponse(MessageBase):
def __init__(self, message):
def __init__(self, message: bytearray) -> None:
super().__init__()
if message is None or len(message) < self.HEADER_LENGTH + 1:
raise MessageLenError
Expand All @@ -252,23 +266,23 @@ def __init__(self, message):
self.body_type = self._body.body_type

@property
def header(self):
def header(self) -> bytearray:
return self._header

@property
def body(self):
def body(self) -> bytearray:
return self._body.data

def set_body(self, body: MessageBody):
def set_body(self, body: MessageBody) -> None:
self._body = body

def set_attr(self):
def set_attr(self) -> None:
for key in vars(self._body).keys():
if key != "data":
value = getattr(self._body, key, None)
setattr(self, key, value)


class MessageApplianceResponse(MessageResponse):
def __init__(self, message):
def __init__(self, message: bytearray) -> None:
super().__init__(message)

0 comments on commit 365411a

Please sign in to comment.