Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ repos:
name: 🌟 Starring code with pylint
language: system
types: [python]
entry: scripts/run-in-env.sh pylint
entry: scripts/run-in-env.sh pylint matter_server/ tests/
require_serial: true
pass_filenames: false
- id: trailing-whitespace
name: ✄ Trim Trailing Whitespace
language: system
Expand Down
46 changes: 31 additions & 15 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datetime import datetime
from functools import cached_property, lru_cache
import logging
import re
import secrets
import time
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -128,6 +129,10 @@
0, Clusters.IcdManagement.Attributes.AttributeList
)

RE_MDNS_SERVICE_NAME = re.compile(
rf"^([0-9A-Fa-f]{{16}})-([0-9A-Fa-f]{{16}})\.{re.escape(MDNS_TYPE_OPERATIONAL_NODE)}$"
)


# pylint: disable=too-many-lines,too-many-instance-attributes,too-many-public-methods

Expand All @@ -152,7 +157,6 @@ def __init__(
# we keep the last events in memory so we can include them in the diagnostics dump
self.event_history: deque[Attribute.EventReadResult] = deque(maxlen=25)
self._compressed_fabric_id: int | None = None
self._fabric_id_hex: str | None = None
self._wifi_credentials_set: bool = False
self._thread_credentials_set: bool = False
self._setup_node_tasks = dict[int, asyncio.Task]()
Expand All @@ -179,7 +183,6 @@ async def initialize(self) -> None:
self._compressed_fabric_id = (
await self._chip_device_controller.get_compressed_fabric_id()
)
self._fabric_id_hex = hex(self._compressed_fabric_id)[2:]
await load_local_updates(self._ota_provider_dir)

async def start(self) -> None:
Expand Down Expand Up @@ -245,8 +248,10 @@ async def stop(self) -> None:
LOGGER.debug("Stopped.")

@property
def compressed_fabric_id(self) -> int | None:
def compressed_fabric_id(self) -> int:
"""Return the compressed fabric id."""
if self._compressed_fabric_id is None:
raise RuntimeError("Compressed Fabric ID not set")
return self._compressed_fabric_id

@property
Expand Down Expand Up @@ -1524,25 +1529,36 @@ def _on_mdns_service_state_change(
)
return

if service_type == MDNS_TYPE_OPERATIONAL_NODE:
if self._fabric_id_hex is None or self._fabric_id_hex not in name.lower():
# filter out messages that are not for our fabric
return
# process the event with a debounce timer
if service_type != MDNS_TYPE_OPERATIONAL_NODE:
return

if not (match := RE_MDNS_SERVICE_NAME.match(name)):
LOGGER.getChild("mdns").warning(
"Service name doesn't match expected operational node pattern: %s", name
)
return

fabric_id_hex, node_id_hex = match.groups()

# Filter messages of other fabrics
if int(fabric_id_hex, 16) != self.compressed_fabric_id:
return

# Process the event with a debounce timer
self._mdns_event_timer[name] = self._loop.call_later(
0.5, self._on_mdns_operational_node_state, name, state_change
0.5,
self._on_mdns_operational_node_state,
name,
int(node_id_hex, 16),
state_change,
)

def _on_mdns_operational_node_state(
self, name: str, state_change: ServiceStateChange
self, name: str, node_id: int, state_change: ServiceStateChange
) -> None:
"""Handle a (operational) Matter node MDNS state change."""
self._mdns_event_timer.pop(name, None)
logger = LOGGER.getChild("mdns")
# the mdns name is constructed as [fabricid]-[nodeid]._matter._tcp.local.
# extract the node id from the name
node_id = int(name.split("-")[1].split(".")[0], 16)
node_logger = self.get_node_logger(logger, node_id)
node_logger = self.get_node_logger(LOGGER.getChild("mdns"), node_id)

if not (node := self._nodes.get(node_id)):
return # this should not happen, but guard just in case
Expand Down
23 changes: 15 additions & 8 deletions matter_server/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import asyncio
from functools import cached_property, partial
import inspect
import ipaddress
import logging
import os
Expand Down Expand Up @@ -315,18 +316,24 @@ def register_api_command(

def _register_api_commands(self) -> None:
"""Register all methods decorated as api_command."""
for cls in (self, self._device_controller, self.vendor_info):
for attr_name in dir(cls):
for obj in (self, self._device_controller, self.vendor_info):
cls = obj.__class__
for attr_name, attr in inspect.getmembers(cls):
if attr_name.startswith("_"):
continue
val = getattr(cls, attr_name)
if not hasattr(val, "api_cmd"):

if isinstance(attr, property):
continue # skip properties

# attr is the (unbound) function, we can check for the decorator
if not hasattr(attr, "api_cmd"):
continue
if hasattr(val, "mock_calls"):
# filter out mocks
if hasattr(attr, "mock_calls"):
continue
# method is decorated with our api decorator
self.register_api_command(val.api_cmd, val)

# Get the instance method before registering
bound_method = getattr(obj, attr_name)
self.register_api_command(attr.api_cmd, bound_method)

async def _handle_info(self, request: web.Request) -> web.Response:
"""Handle info endpoint to serve basic server (version) info."""
Expand Down
39 changes: 39 additions & 0 deletions tests/test_device_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Device controller tests."""

import pytest

from matter_server.server.device_controller import RE_MDNS_SERVICE_NAME


@pytest.mark.parametrize(
("name", "expected"),
[
(
"D22DC25523A78A89-0000000000000125._matter._tcp.local.",
("D22DC25523A78A89", "0000000000000125"),
),
(
"d22dc25523a78a89-0000000000000125._matter._tcp.local.",
("d22dc25523a78a89", "0000000000000125"),
),
],
)
def test_valid_mdns_service_names(name, expected):
"""Test valid mDNS service names."""
match = RE_MDNS_SERVICE_NAME.match(name)
assert match is not None
assert match.groups() == expected


@pytest.mark.parametrize(
"name",
[
"D22DC25523A78A89-0000000000000125 (2)._matter._tcp.local.",
"D22DC25523A78A89-0000000000000125.._matter._tcp.local.",
"G22DC25523A78A89-0000000000000125._matter._tcp.local.", # invalid hex
"D22DC25523A78A89-0000000000000125._matterc._udp.local.",
],
)
def test_invalid_mdns_service_names(name):
"""Test invalid mDNS service names."""
assert RE_MDNS_SERVICE_NAME.match(name) is None