diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b870673d..be2d4979 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -85,7 +85,9 @@ repos: name: 🌟 Starring code with pylint language: system types: [python] - entry: scripts/run-in-env.sh pylint + entry: scripts/run-in-env.sh pylint matter_server/ tests/ + require_serial: true + pass_filenames: false - id: trailing-whitespace name: ✄ Trim Trailing Whitespace language: system diff --git a/matter_server/server/device_controller.py b/matter_server/server/device_controller.py index cfb9b133..0e344c84 100644 --- a/matter_server/server/device_controller.py +++ b/matter_server/server/device_controller.py @@ -13,6 +13,7 @@ from datetime import datetime from functools import cached_property, lru_cache import logging +import re import secrets import time from typing import TYPE_CHECKING, Any, cast @@ -128,6 +129,10 @@ 0, Clusters.IcdManagement.Attributes.AttributeList ) +RE_MDNS_SERVICE_NAME = re.compile( + rf"^([0-9A-Fa-f]{{16}})-([0-9A-Fa-f]{{16}})\.{re.escape(MDNS_TYPE_OPERATIONAL_NODE)}$" +) + # pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods @@ -152,7 +157,6 @@ def __init__( # we keep the last events in memory so we can include them in the diagnostics dump self.event_history: deque[Attribute.EventReadResult] = deque(maxlen=25) self._compressed_fabric_id: int | None = None - self._fabric_id_hex: str | None = None self._wifi_credentials_set: bool = False self._thread_credentials_set: bool = False self._setup_node_tasks = dict[int, asyncio.Task]() @@ -179,7 +183,6 @@ async def initialize(self) -> None: self._compressed_fabric_id = ( await self._chip_device_controller.get_compressed_fabric_id() ) - self._fabric_id_hex = hex(self._compressed_fabric_id)[2:] await load_local_updates(self._ota_provider_dir) async def start(self) -> None: @@ -245,8 +248,10 @@ async def stop(self) -> None: LOGGER.debug("Stopped.") @property - def compressed_fabric_id(self) -> int | None: + def compressed_fabric_id(self) -> int: """Return the compressed fabric id.""" + if self._compressed_fabric_id is None: + raise RuntimeError("Compressed Fabric ID not set") return self._compressed_fabric_id @property @@ -1524,25 +1529,36 @@ def _on_mdns_service_state_change( ) return - if service_type == MDNS_TYPE_OPERATIONAL_NODE: - if self._fabric_id_hex is None or self._fabric_id_hex not in name.lower(): - # filter out messages that are not for our fabric - return - # process the event with a debounce timer + if service_type != MDNS_TYPE_OPERATIONAL_NODE: + return + + if not (match := RE_MDNS_SERVICE_NAME.match(name)): + LOGGER.getChild("mdns").warning( + "Service name doesn't match expected operational node pattern: %s", name + ) + return + + fabric_id_hex, node_id_hex = match.groups() + + # Filter messages of other fabrics + if int(fabric_id_hex, 16) != self.compressed_fabric_id: + return + + # Process the event with a debounce timer self._mdns_event_timer[name] = self._loop.call_later( - 0.5, self._on_mdns_operational_node_state, name, state_change + 0.5, + self._on_mdns_operational_node_state, + name, + int(node_id_hex, 16), + state_change, ) def _on_mdns_operational_node_state( - self, name: str, state_change: ServiceStateChange + self, name: str, node_id: int, state_change: ServiceStateChange ) -> None: """Handle a (operational) Matter node MDNS state change.""" self._mdns_event_timer.pop(name, None) - logger = LOGGER.getChild("mdns") - # the mdns name is constructed as [fabricid]-[nodeid]._matter._tcp.local. - # extract the node id from the name - node_id = int(name.split("-")[1].split(".")[0], 16) - node_logger = self.get_node_logger(logger, node_id) + node_logger = self.get_node_logger(LOGGER.getChild("mdns"), node_id) if not (node := self._nodes.get(node_id)): return # this should not happen, but guard just in case diff --git a/matter_server/server/server.py b/matter_server/server/server.py index 7fd96f22..7e04ff84 100644 --- a/matter_server/server/server.py +++ b/matter_server/server/server.py @@ -4,6 +4,7 @@ import asyncio from functools import cached_property, partial +import inspect import ipaddress import logging import os @@ -315,18 +316,24 @@ def register_api_command( def _register_api_commands(self) -> None: """Register all methods decorated as api_command.""" - for cls in (self, self._device_controller, self.vendor_info): - for attr_name in dir(cls): + for obj in (self, self._device_controller, self.vendor_info): + cls = obj.__class__ + for attr_name, attr in inspect.getmembers(cls): if attr_name.startswith("_"): continue - val = getattr(cls, attr_name) - if not hasattr(val, "api_cmd"): + + if isinstance(attr, property): + continue # skip properties + + # attr is the (unbound) function, we can check for the decorator + if not hasattr(attr, "api_cmd"): continue - if hasattr(val, "mock_calls"): - # filter out mocks + if hasattr(attr, "mock_calls"): continue - # method is decorated with our api decorator - self.register_api_command(val.api_cmd, val) + + # Get the instance method before registering + bound_method = getattr(obj, attr_name) + self.register_api_command(attr.api_cmd, bound_method) async def _handle_info(self, request: web.Request) -> web.Response: """Handle info endpoint to serve basic server (version) info.""" diff --git a/tests/test_device_controller.py b/tests/test_device_controller.py new file mode 100644 index 00000000..eb142595 --- /dev/null +++ b/tests/test_device_controller.py @@ -0,0 +1,39 @@ +"""Device controller tests.""" + +import pytest + +from matter_server.server.device_controller import RE_MDNS_SERVICE_NAME + + +@pytest.mark.parametrize( + ("name", "expected"), + [ + ( + "D22DC25523A78A89-0000000000000125._matter._tcp.local.", + ("D22DC25523A78A89", "0000000000000125"), + ), + ( + "d22dc25523a78a89-0000000000000125._matter._tcp.local.", + ("d22dc25523a78a89", "0000000000000125"), + ), + ], +) +def test_valid_mdns_service_names(name, expected): + """Test valid mDNS service names.""" + match = RE_MDNS_SERVICE_NAME.match(name) + assert match is not None + assert match.groups() == expected + + +@pytest.mark.parametrize( + "name", + [ + "D22DC25523A78A89-0000000000000125 (2)._matter._tcp.local.", + "D22DC25523A78A89-0000000000000125.._matter._tcp.local.", + "G22DC25523A78A89-0000000000000125._matter._tcp.local.", # invalid hex + "D22DC25523A78A89-0000000000000125._matterc._udp.local.", + ], +) +def test_invalid_mdns_service_names(name): + """Test invalid mDNS service names.""" + assert RE_MDNS_SERVICE_NAME.match(name) is None