diff --git a/.gitignore b/.gitignore index 6043f058..ef90c25c 100644 --- a/.gitignore +++ b/.gitignore @@ -132,6 +132,10 @@ reports/ # IDE .idea/ -# Optional config file for library_test.py script +# Optional config file for library script library_test.json midea-local.json + +# LUA protocol implementation +*.lua +cov.xml diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 8a5d7d48..71c108a8 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -2,6 +2,7 @@ "recommendations": [ "charliermarsh.ruff", "esbenp.prettier-vscode", - "ms-python.python" + "ms-python.python", + "ryanluker.vscode-coverage-gutters" ] } diff --git a/.vscode/launch.json b/.vscode/launch.json index 779f725b..6cfbfc23 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -22,6 +22,7 @@ "module": "midealocal.cli", "args": [ "save", + "-d", "--cloud-name", "${input:cloud-name}", "--username", @@ -29,6 +30,27 @@ "--password", "${input:password}" ] + }, + { + "name": "Python Debugger: Download from host", + "type": "debugpy", + "request": "launch", + "module": "midealocal.cli", + "args": ["download", "-d", "--host", "${input:host}"] + }, + { + "name": "Python Debugger: Download from sn and type", + "type": "debugpy", + "request": "launch", + "module": "midealocal.cli", + "args": [ + "download", + "-d", + "--device-type", + "${input:device-type}", + "--device-sn", + "${input:device-sn}" + ] } ], "inputs": [ @@ -38,6 +60,18 @@ "description": "Enter the host IP address (leave empty to broadcast on the network)", "default": "" }, + { + "id": "device-type", + "type": "promptString", + "description": "Enter the device type ex. AC", + "default": "" + }, + { + "id": "device-sn", + "type": "promptString", + "description": "Enter the device serial number.", + "default": "" + }, { "id": "message", "type": "promptString", diff --git a/.vscode/settings.json b/.vscode/settings.json index 5a30d876..3d249c17 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,11 @@ { - "python.testing.pytestArgs": ["tests", "--no-cov"], + "python.testing.pytestArgs": [ + "tests", + "--cov-report", + "xml:cov.xml", + "--cov", + "midealocal" + ], "python.testing.pytestEnabled": true, "pylint.importStrategy": "fromEnvironment", "python.testing.unittestEnabled": false diff --git a/midealocal/cli.py b/midealocal/cli.py index be036346..dfcbcfdf 100644 --- a/midealocal/cli.py +++ b/midealocal/cli.py @@ -14,7 +14,7 @@ import platformdirs from colorlog import ColoredFormatter -from midealocal.cloud import SUPPORTED_CLOUDS, get_midea_cloud +from midealocal.cloud import SUPPORTED_CLOUDS, MideaCloud, get_midea_cloud from midealocal.device import ProtocolVersion from midealocal.devices import device_selector from midealocal.discover import discover @@ -23,110 +23,197 @@ _LOGGER = logging.getLogger("cli") +LOG_FORMAT = ( + "%(asctime)s.%(msecs)03d %(levelname)s (%(threadName)s) [%(name)s] %(message)s" +) -def get_config_file_path(relative: bool = False) -> Path: - """Get the config file path.""" - local_path = Path("midea-local.json") - if relative or local_path.exists(): - return local_path - return platformdirs.user_config_path(appname="midea-local").joinpath( - "midea-local.json", - ) +class MideaCLI: + """Midea CLI.""" -async def _get_keys(args: Namespace, device_id: int) -> dict[int, dict[str, Any]]: - if not args.cloud_name or not args.username or not args.password: - raise ElementMissing("Missing required parameters for cloud request.") - cloud_keys = {} - async with aiohttp.ClientSession() as session: - cloud = get_midea_cloud( - cloud_name=args.cloud_name, - session=session, - account=args.username, - password=args.password, - ) + session: aiohttp.ClientSession + namespace: Namespace - cloud_keys = await cloud.get_cloud_keys(device_id) - default_keys = await cloud.get_default_keys() - return {**cloud_keys, **default_keys} + async def _get_cloud(self) -> MideaCloud: + """Get cloud instance.""" + if ( + not self.namespace.cloud_name + or not self.namespace.username + or not self.namespace.password + ): + raise ElementMissing("Missing required parameters for cloud request.") + if not hasattr(self, "session"): + self.session = aiohttp.ClientSession() -async def _discover(args: Namespace) -> None: - """Discover device information.""" - devices = discover(ip_address=args.host) + return get_midea_cloud( + cloud_name=self.namespace.cloud_name, + session=self.session, + account=self.namespace.username, + password=self.namespace.password, + ) - if len(devices) == 0: - _LOGGER.error("No devices found.") - return + async def _get_keys(self, device_id: int) -> dict[int, dict[str, Any]]: + cloud = await self._get_cloud() + cloud_keys = await cloud.get_cloud_keys(device_id) + default_keys = await cloud.get_default_keys() + return {**cloud_keys, **default_keys} + + async def discover(self) -> None: + """Discover device information.""" + devices = discover(ip_address=self.namespace.host) + + if len(devices) == 0: + _LOGGER.error("No devices found.") + return + + # Dump only basic device info from the base class + _LOGGER.info("Found %d devices.", len(devices)) + for device in devices.values(): + keys = ( + {0: {"token": "", "key": ""}} + if device["protocol"] != ProtocolVersion.V3 + else await self._get_keys(device["device_id"]) + ) - # Dump only basic device info from the base class - _LOGGER.info("Found %d devices.", len(devices)) - for device in devices.values(): - keys = ( - {0: {"token": "", "key": ""}} - if device["protocol"] != ProtocolVersion.V3 - else await _get_keys(args, device["device_id"]) + for key in keys.values(): + dev = device_selector( + name=device["device_id"], + device_id=device["device_id"], + device_type=device["type"], + ip_address=device["ip_address"], + port=device["port"], + token=key["token"], + key=key["key"], + protocol=device["protocol"], + model=device["model"], + subtype=0, + customize="", + ) + _LOGGER.debug("Trying to connect with key: %s", key) + if dev.connect(): + _LOGGER.info("Found device:\n%s", dev.attributes) + break + + _LOGGER.debug("Unable to connect with key: %s", key) + + def message(self) -> None: + """Load message into device.""" + device_type = int(self.namespace.message[2]) + + device = device_selector( + device_id=0, + name="", + device_type=device_type, + ip_address="192.168.192.168", + port=6664, + protocol=ProtocolVersion.V2, + model="0000", + token="", + key="", + subtype=0, + customize="", ) - for key in keys.values(): - dev = device_selector( - name=device["device_id"], - device_id=device["device_id"], - device_type=device["type"], - ip_address=device["ip_address"], - port=device["port"], - token=key["token"], - key=key["key"], - protocol=device["protocol"], - model=device["model"], - subtype=0, - customize="", - ) - _LOGGER.debug("Trying to connect with key: %s", key) - if dev.connect(): - _LOGGER.info("Found device:\n%s", dev.attributes) - break - - _LOGGER.debug("Unable to connect with key: %s", key) - - -def _message(args: Namespace) -> None: - """Load message into device.""" - device_type = int(args.message[2]) - - device = device_selector( - device_id=0, - name="", - device_type=device_type, - ip_address="192.168.192.168", - port=6664, - protocol=ProtocolVersion.V2, - model="0000", - token="", - key="", - subtype=0, - customize="", - ) + result = device.process_message(self.namespace.message) + + _LOGGER.info("Parsed message: %s", result) + + def save(self) -> None: + """Save credentials to config file.""" + data = { + "username": self.namespace.username, + "password": self.namespace.password, + "cloud_name": self.namespace.cloud_name, + } + json_data = json.dumps(data) + file = get_config_file_path(not self.namespace.user) + with file.open(mode="w+", encoding="utf-8") as f: + f.write(json_data) + + async def download(self) -> None: + """Download lua from cloud.""" + device_type = int.from_bytes(self.namespace.device_type or bytearray()) + device_sn = str(self.namespace.device_sn) + model: str | None = None + + if self.namespace.host: + devices = discover(ip_address=self.namespace.host) + + if len(devices) == 0: + _LOGGER.error("No devices found.") + return + + _, device = devices.popitem() + device_type = device["type"] + device_sn = device["sn"] + model = device["model"] + + cloud = await self._get_cloud() + _LOGGER.debug("Try to authenticate to the cloud.") + if not await cloud.login(): + _LOGGER.error("Failed to authenticate to the cloud.") + return + + _LOGGER.debug("Download lua file for %s [%s]", device_sn, hex(device_type)) + lua = await cloud.download_lua(str(Path()), device_type, device_sn, model) + _LOGGER.info("Downloaded lua file: %s", lua) + + def run(self, namespace: Namespace) -> None: + """Do setup logging, validate args and execute the desired function.""" + self.namespace = namespace + # Configure logging + if self.namespace.debug: + logging.basicConfig(level=logging.DEBUG) + # Keep httpx as info level + logging.getLogger("asyncio").setLevel(logging.INFO) + logging.getLogger("charset_normalizer").setLevel(logging.INFO) + else: + logging.basicConfig(level=logging.INFO) + # Set httpx to warning level + logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("charset_normalizer").setLevel(logging.WARNING) + + fmt = LOG_FORMAT + colorfmt = f"%(log_color)s{fmt}%(reset)s" + logging.getLogger().handlers[0].setFormatter( + ColoredFormatter( + colorfmt, + datefmt="%Y-%m-%d %H:%M:%S", + reset=True, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red", + }, + ), + ) - result = device.process_message(args.message) + with contextlib.suppress(KeyboardInterrupt): + if inspect.iscoroutinefunction(self.namespace.func): + asyncio.run(self.namespace.func()) + else: + self.namespace.func() - _LOGGER.info("Parsed message: %s", result) + if hasattr(self, "session") and self.session: + asyncio.run(self.session.close()) -def _save(args: Namespace) -> None: - data = { - "username": args.username, - "password": args.password, - "cloud_name": args.cloud_name, - } - json_data = json.dumps(data) - file = get_config_file_path(not args.user) - with file.open(mode="w+", encoding="utf-8") as f: - f.write(json_data) +def get_config_file_path(relative: bool = False) -> Path: + """Get the config file path.""" + local_path = Path("midea-local.json") + if relative or local_path.exists(): + return local_path + return platformdirs.user_config_path(appname="midea-local").joinpath( + "midea-local.json", + ) def main() -> NoReturn: """Launch main entry.""" + cli = MideaCLI() # Define the main parser to select subcommands parser = ArgumentParser(description="Command line utility for midea-local.") parser.add_argument( @@ -176,7 +263,7 @@ def main() -> NoReturn: help="Hostname or IP address of a single device to discover.", default=None, ) - discover_parser.set_defaults(func=_discover) + discover_parser.set_defaults(func=cli.discover) decode_msg_parser = subparsers.add_parser( "decode", @@ -188,7 +275,7 @@ def main() -> NoReturn: help="Received message", type=bytes.fromhex, ) - decode_msg_parser.set_defaults(func=_message) + decode_msg_parser.set_defaults(func=cli.message) save_parser = subparsers.add_parser( "save", @@ -200,57 +287,36 @@ def main() -> NoReturn: help="Save config file in your user config folder.", action="store_true", ) - save_parser.set_defaults(func=_save) - - config = get_config_file_path() - namespace = Namespace() - if config.exists(): - with config.open("r", encoding="utf-8") as f: - namespace = Namespace(**json.load(f)) + save_parser.set_defaults(func=cli.save) - # Run with args - _run(parser.parse_args(namespace=namespace)) - - -def _run(args: Namespace) -> NoReturn: - """Do setup logging, validate args and execute the desired function.""" - # Configure logging - if args.debug: - logging.basicConfig(level=logging.DEBUG) - # Keep httpx as info level - logging.getLogger("asyncio").setLevel(logging.INFO) - logging.getLogger("charset_normalizer").setLevel(logging.INFO) - else: - logging.basicConfig(level=logging.INFO) - # Set httpx to warning level - logging.getLogger("asyncio").setLevel(logging.WARNING) - logging.getLogger("charset_normalizer").setLevel(logging.WARNING) - - fmt = ( - "%(asctime)s.%(msecs)03d %(levelname)s (%(threadName)s) [%(name)s] %(message)s" + download_parser = subparsers.add_parser( + "download", + description="Download lua scripts from cloud.", + parents=[common_parser], ) - colorfmt = f"%(log_color)s{fmt}%(reset)s" - logging.getLogger().handlers[0].setFormatter( - ColoredFormatter( - colorfmt, - datefmt="%Y-%m-%d %H:%M:%S", - reset=True, - log_colors={ - "DEBUG": "cyan", - "INFO": "green", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "red", - }, - ), + download_parser.add_argument( + "--device-type", + help="Device Type", + type=bytes.fromhex, + ) + download_parser.add_argument("--device-sn", help="Device SN") + download_parser.add_argument( + "--host", + help="IP Address of the device.", ) + download_parser.set_defaults(func=cli.download) - with contextlib.suppress(KeyboardInterrupt): - if inspect.iscoroutinefunction(args.func): - asyncio.run(args.func(args)) - else: - args.func(args) + config = get_config_file_path() + namespace = parser.parse_args() + if config.exists(): + with config.open(encoding="utf-8") as f: + config_data = json.load(f) + for key, value in config_data.items(): + if not getattr(namespace, key): + setattr(namespace, key, value) + # Run with args + cli.run(namespace) sys.exit(0) diff --git a/midealocal/cloud.py b/midealocal/cloud.py index 8a5787ce..8913314d 100644 --- a/midealocal/cloud.py +++ b/midealocal/cloud.py @@ -272,7 +272,7 @@ async def download_lua( path: str, device_type: int, sn: str, - model_number: str | None, + model_number: str | None = None, manufacturer_code: str = "0000", ) -> str | None: """Download lua integration.""" @@ -442,7 +442,7 @@ async def download_lua( path: str, device_type: int, sn: str, - model_number: str | None, # noqa: ARG002 + model_number: str | None = None, # noqa: ARG002 manufacturer_code: str = "0000", ) -> str | None: """Download lua integration.""" @@ -628,7 +628,7 @@ async def download_lua( path: str, device_type: int, sn: str, - model_number: str | None, + model_number: str | None = None, manufacturer_code: str = "0000", ) -> str | None: """Download lua integration.""" @@ -639,14 +639,15 @@ async def download_lua( "deviceId": self._device_id, "iotAppId": self._app_id, "applianceMFCode": manufacturer_code, - "applianceType": f".{f'x{device_type:02x}'}", - "modelNumber": model_number, + "applianceType": hex(device_type), "applianceSn": self._security.aes_encrypt_with_fixed_key( sn.encode("ascii"), ).hex(), "version": "0", "encryptedType ": "2", } + if model_number is not None: + data["modelNumber"] = model_number fnm = None if response := await self._api_request( endpoint="/v2/luaEncryption/luaGet", diff --git a/tests/cli_test.py b/tests/cli_test.py new file mode 100644 index 00000000..79e786f6 --- /dev/null +++ b/tests/cli_test.py @@ -0,0 +1,263 @@ +"""Midea Local CLI tests.""" + +import json +import logging +import subprocess # noqa: S404 +import sys +from argparse import Namespace +from pathlib import Path +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from midealocal.cli import ( + ElementMissing, + MideaCLI, + get_config_file_path, +) +from midealocal.cloud import MSmartHomeCloud +from midealocal.device import ProtocolVersion + + +class TestMideaCLI(IsolatedAsyncioTestCase): + """Test Midea CLI.""" + + def setUp(self) -> None: + """Create namespace for testing.""" + self.cli = MideaCLI() + self.namespace = Namespace( + cloud_name="MSmartHome", + username="user", + password="pass", + host="192.168.0.1", + message=bytes.fromhex("00a1a2ac0f2a0000"), + device_type=bytearray(), + device_sn="", + user=False, + debug=True, + func=MagicMock(), + ) + self.cli.namespace = self.namespace + + async def test_get_cloud(self) -> None: + """Test get cloud.""" + mock_session_instance = AsyncMock() + with ( + patch("aiohttp.ClientSession", return_value=mock_session_instance), + ): + cloud = await self.cli._get_cloud() + + assert isinstance(cloud, MSmartHomeCloud) + assert cloud._account == self.namespace.username + assert cloud._password == self.namespace.password + assert cloud._session == mock_session_instance + + self.namespace.cloud_name = None + with pytest.raises(ElementMissing): + await self.cli._get_cloud() + + async def test_discover(self) -> None: + """Test discover.""" + mock_device = { + "device_id": 1, + "protocol": ProtocolVersion.V3, + "type": "AC", + "ip_address": "192.168.0.2", + "port": 6444, + "model": "AC123", + "sn": "AC123", + } + mock_cloud_instance = AsyncMock() + mock_device_instance = MagicMock() + mock_device_instance.connect.return_value = True + with ( + patch( + "midealocal.cli.discover", + ) as mock_discover, + patch.object( + self.cli, + "_get_cloud", + return_value=mock_cloud_instance, + ), + patch( + "midealocal.cli.device_selector", + return_value=mock_device_instance, + ), + ): + mock_discover.return_value = {1: mock_device} + mock_cloud_instance.get_cloud_keys.return_value = { + 0: {"token": "token", "key": "key"}, + } + mock_cloud_instance.get_default_keys.return_value = { + 99: {"token": "token", "key": "key"}, + } + + await self.cli.discover() # V3 device + + mock_device["protocol"] = ProtocolVersion.V2 + await self.cli.discover() # V2 device + + mock_device_instance.connect.return_value = False + await self.cli.discover() # connect failed + + mock_discover.return_value = {} + + await self.cli.discover() # No devices + + def test_message(self) -> None: + """Test message.""" + mock_device_instance = MagicMock() + with patch( + "midealocal.cli.device_selector", + return_value=mock_device_instance, + ) as mock_device_selector: + mock_device_selector.return_value = mock_device_instance + + self.cli.message() + + mock_device_selector.assert_called_once_with( + device_id=0, + name="", + device_type=int(self.namespace.message[2]), + ip_address="192.168.192.168", + port=6664, + protocol=ProtocolVersion.V2, + model="0000", + token="", + key="", + subtype=0, + customize="", + ) + mock_device_instance.process_message.assert_called_once_with( + self.namespace.message, + ) + + def test_save(self) -> None: + """Test save.""" + mock_path_instance = MagicMock() + with patch("midealocal.cli.get_config_file_path") as mock_get_config_file_path: + mock_get_config_file_path.return_value = mock_path_instance + + self.cli.save() + + mock_get_config_file_path.assert_called_once_with(not self.namespace.user) + mock_path_instance.open.assert_called_once_with(mode="w+", encoding="utf-8") + handle = mock_path_instance.open.return_value.__enter__.return_value + handle.write.assert_called_once_with( + json.dumps( + { + "username": self.namespace.username, + "password": self.namespace.password, + "cloud_name": self.namespace.cloud_name, + }, + ), + ) + + async def test_download(self) -> None: + """Test download.""" + mock_device = { + "device_id": 1, + "protocol": ProtocolVersion.V3, + "type": 0xAC, + "ip_address": "192.168.0.2", + "port": 6444, + "model": "AC123", + "sn": "AC123", + } + mock_cloud_instance = AsyncMock() + with ( + patch( + "midealocal.cli.discover", + side_effect=[{}, {1: mock_device}, {1: mock_device}], + ) as mock_discover, + patch.object( + self.cli, + "_get_cloud", + return_value=mock_cloud_instance, + ), + ): + await self.cli.download() # No device found + mock_discover.assert_called_once_with(ip_address=self.namespace.host) + mock_discover.reset_mock() + + mock_cloud_instance.login.side_effect = [False, True] + await self.cli.download() # Cloud login failed + mock_discover.assert_called_once_with(ip_address=self.namespace.host) + mock_discover.reset_mock() + mock_cloud_instance.login.assert_called_once() + mock_cloud_instance.login.reset_mock() + + await self.cli.download() + mock_discover.assert_called_once_with(ip_address=self.namespace.host) + mock_cloud_instance.login.assert_called_once() + mock_cloud_instance.download_lua.assert_called_once_with( + str(Path()), + mock_device["type"], + mock_device["sn"], + mock_device["model"], + ) + + def test_run(self) -> None: + """Test run.""" + mock_logger = MagicMock() + with ( + patch("logging.basicConfig") as mock_basic_config, + patch("logging.getLogger", return_value=mock_logger), + patch.object(mock_logger, "setLevel") as mock_set_level, + ): + self.cli.session = AsyncMock() + self.cli.run(self.namespace) + mock_basic_config.assert_called_once_with(level=logging.DEBUG) + mock_basic_config.reset_mock() + mock_set_level.assert_called_with(logging.INFO) + mock_set_level.reset_mock() + self.namespace.func.assert_called_once() + + # Test coroutine function + self.namespace.func = AsyncMock() + self.namespace.debug = False + self.cli.run(self.namespace) + mock_basic_config.assert_called_once_with(level=logging.INFO) + mock_set_level.assert_called_with(logging.WARNING) + self.namespace.func.assert_called_once() + + def test_main_call(self) -> None: + """Test main call.""" + # Command to run the script + cmd = [ + sys.executable, + "-m", + "midealocal.cli", + ] + clear_config = False + if not get_config_file_path().exists(): + clear_config = True + subprocess.run([*cmd, "save"], capture_output=True, text=True, check=False) # noqa: S603 + + # Run the command and capture the output + result = subprocess.run(cmd, capture_output=True, text=True, check=False) # noqa: S603 + + # Check if the script executed without errors + assert result.returncode == 2 + + result = subprocess.run( # noqa: S603 + [*cmd, "save"], + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0 + + if clear_config: + get_config_file_path().unlink() + + def test_get_config_file_path(self) -> None: + """Test get config file path.""" + mock_path = MagicMock() + with ( + patch("midealocal.cli.Path", return_value=mock_path), + patch.object(mock_path, "exists", return_value=False), + ): + get_config_file_path()