Skip to content

Commit ee3c07d

Browse files
committed
refactor ei
1 parent 9543012 commit ee3c07d

File tree

11 files changed

+193
-59
lines changed

11 files changed

+193
-59
lines changed

src/arduino/app_bricks/audio_classification/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import wave
77
from typing import Callable
88

9-
from arduino.app_internal.core.audio import AudioDetector, NO_MIC as NO_MIC
9+
from arduino.app_internal.core.audio import AudioDetector
1010
from arduino.app_peripherals.microphone import Microphone
1111
from arduino.app_utils import brick, Logger
12+
from arduino.app_internal.core import EdgeImpulseRunnerFacade
1213

1314
logger = Logger("AudioClassification")
1415

@@ -28,8 +29,6 @@ def __init__(self, mic: Microphone = None, confidence: float = 0.8):
2829
2930
Args:
3031
mic (Microphone, optional): Microphone instance used as the audio source. If None, a default Microphone will be initialized.
31-
If NO_MIC is passed, no microphone will be initialized, and only file-based classification
32-
will be available.
3332
confidence (float, optional): Minimum confidence threshold (0.0–1.0) required
3433
for a detection to be considered valid. Defaults to 0.8 (80%).
3534
@@ -68,7 +67,8 @@ def stop(self):
6867
"""
6968
super().stop()
7069

71-
def classify_from_file(self, audio_path: str, confidence: int = None) -> dict | None:
70+
@staticmethod
71+
def classify_from_file(audio_path: str, confidence: int) -> dict | None:
7272
"""Classify audio content from a WAV file.
7373
7474
Supported sample widths:
@@ -91,10 +91,10 @@ def classify_from_file(self, audio_path: str, confidence: int = None) -> dict |
9191
9292
Raises:
9393
AudioClassificationException: If the file cannot be found, read, or processed.
94-
ValueError: If the file uses an unsupported sample width.
94+
ValueError: If the file uses an unsupported sample width or if confidence is not specified.
9595
"""
9696
if confidence is None:
97-
confidence = self.confidence
97+
raise ValueError("Confidence level must be specified.")
9898

9999
try:
100100
with wave.open(audio_path, "rb") as wf:
@@ -127,8 +127,8 @@ def classify_from_file(self, audio_path: str, confidence: int = None) -> dict |
127127
else:
128128
raise ValueError(f"Unsupported sample width: {samp_width} bytes. Cannot process this WAV file.")
129129

130-
classification = super().infer_from_features(features[: int(self.model_info.input_features_count)])
131-
best_match = super().get_best_match(classification, confidence)
130+
classification = EdgeImpulseRunnerFacade.infer_from_features(features)
131+
best_match = AudioDetector.get_best_match(classification, confidence)
132132
if not best_match:
133133
return None
134134
keyword, confidence = best_match

src/arduino/app_bricks/audio_classification/examples/2_glass_breaking_from_file.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44

55
# EXAMPLE_NAME = "Detect the glass breaking sound from audio file"
66
# EXAMPLE_REQUIRES = "Requires an audio file with the glass breaking sound."
7-
from arduino.app_bricks.audio_classification import AudioClassification, NO_MIC
7+
from arduino.app_bricks.audio_classification import AudioClassification
88

9-
classifier = AudioClassification(mic=NO_MIC)
10-
11-
classification = classifier.classify_from_file("glass_breaking.wav")
9+
classification = AudioClassification.classify_from_file("glass_breaking.wav", confidence=0.8)
1210
print("Result:", classification)

src/arduino/app_bricks/motion_detection/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, confidence: float = 0.4):
2626
"""
2727
self._confidence = confidence
2828
super().__init__()
29-
model_info = self.get_model_info()
29+
model_info = EdgeImpulseRunnerFacade.get_model_info()
3030
if not model_info:
3131
raise ValueError("Failed to retrieve model information. Ensure the EI model runner service is running.")
3232
if model_info.frequency <= 0 or model_info.input_features_count <= 0:
@@ -133,7 +133,7 @@ def _detection_loop(self):
133133
return
134134

135135
try:
136-
ret = super().infer_from_features(features[: int(self._model_info.input_features_count)].flatten().tolist())
136+
ret = EdgeImpulseRunnerFacade.infer_from_features(features.tolist())
137137
spotted_movement = self._movement_spotted(ret)
138138
if spotted_movement is not None:
139139
keyword, confidence, complete_detection = spotted_movement

src/arduino/app_bricks/object_detection/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, confidence: float = 0.3):
3131
"""
3232
self.confidence = confidence
3333
super().__init__()
34-
self._model_info = self.get_model_info()
34+
self._model_info = EdgeImpulseRunnerFacade.get_model_info()
3535
if not self._model_info:
3636
raise ValueError("Failed to retrieve model information. Ensure the Edge Impulse service is running.")
3737

src/arduino/app_bricks/vibration_anomaly_detection/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, anomaly_detection_threshold: float = 1.0):
4848
"""
4949
self._anomaly_detection_threshold = anomaly_detection_threshold
5050
super().__init__()
51-
model_info = self.get_model_info()
51+
model_info = EdgeImpulseRunnerFacade.get_model_info()
5252
if not model_info:
5353
raise ValueError("Failed to retrieve model information. Ensure the EI model runner service is running.")
5454
if model_info.frequency <= 0 or model_info.input_features_count <= 0:
@@ -133,7 +133,7 @@ def loop(self):
133133
if features is None or len(features) == 0:
134134
return
135135

136-
ret = super().infer_from_features(features[: int(self._model_info.input_features_count)].flatten().tolist())
136+
ret = EdgeImpulseRunnerFacade.infer_from_features(features.tolist())
137137
logger.debug(f"Inference result: {ret}")
138138
spotted_anomaly = self._extract_anomaly_score(ret)
139139
if spotted_anomaly is not None:
@@ -184,6 +184,14 @@ def stop(self):
184184
"""
185185
self._clear()
186186

187+
def get_model_info(self):
188+
"""Get the Edge Impulse model information used by this detector.
189+
190+
Returns:
191+
EdgeImpulseModelInfo: The model information object.
192+
"""
193+
return self._model_info
194+
187195
def _clear(self):
188196
"""Internal helper: flush the sensor data buffer and log the action.
189197

src/arduino/app_internal/core/audio.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
logger = Logger(__name__)
1515

16-
NO_MIC = object() # Sentinel value for no microphone
17-
1816

1917
class AudioDetector(EdgeImpulseRunnerFacade):
2018
"""AudioDetector module for detecting sounds and classifying audio using a specified model."""
@@ -37,14 +35,14 @@ def __init__(self, mic: Microphone = None, confidence: float = 0.8, debounce_sec
3735
self._debounce_sec = debounce_sec
3836
self._last_detected = {}
3937

40-
model_info = self.get_model_info()
38+
model_info = EdgeImpulseRunnerFacade.get_model_info()
4139
if not model_info:
4240
raise ValueError("Failed to retrieve model information. Ensure the Edge Impulse service is running.")
4341
if model_info.frequency <= 0 or model_info.input_features_count <= 0:
4442
raise ValueError("Model parameters are missing or incomplete in the retrieved model information.")
4543
self.model_info = model_info
4644

47-
self._mic = None if mic is NO_MIC else (mic or Microphone(sample_rate=model_info.frequency, channels=model_info.axis_count))
45+
self._mic = mic if mic else Microphone(sample_rate=model_info.frequency, channels=model_info.axis_count)
4846
self._mic_lock = threading.Lock()
4947

5048
self._window_size = int(model_info.input_features_count / model_info.axis_count)
@@ -87,17 +85,24 @@ def stop(self):
8785
self._mic.stop()
8886
self._buffer.flush()
8987

90-
def get_best_match(self, item: dict, confidence: int = None) -> tuple[str, float] | None:
88+
@staticmethod
89+
def get_best_match(item: dict, confidence: float) -> tuple[str, float] | None:
9190
"""Extract the best matched keyword from the classification results.
9291
9392
Args:
9493
item (dict): The classification result from the inference.
95-
confidence (int): The confidence threshold for classification. If None, uses the instance's confidence level.
94+
confidence (float): The confidence threshold for classification.
9695
9796
Returns:
9897
tuple[str, float] | None: The best matched keyword and its confidence, or None if no match is found.
98+
99+
Raises:
100+
ValueError: If confidence level is not provided.
99101
"""
100-
classification = _extract_classification(item, confidence or self.confidence)
102+
if confidence is None:
103+
raise ValueError("Confidence level must be provided.")
104+
105+
classification = _extract_classification(item, confidence)
101106
if not classification:
102107
return None
103108

@@ -143,8 +148,8 @@ def _inference_loop(self):
143148

144149
logger.debug(f"Processing sensor data with {len(features)} features.")
145150
try:
146-
ret = super().infer_from_features(features[: int(self.model_info.input_features_count)].tolist())
147-
spotted_keyword = self.get_best_match(ret)
151+
ret = EdgeImpulseRunnerFacade.infer_from_features(features.tolist())
152+
spotted_keyword = AudioDetector.get_best_match(ret, self.confidence)
148153
if spotted_keyword:
149154
keyword, confidence = spotted_keyword
150155
keyword = keyword.lower()

src/arduino/app_internal/core/ei.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,8 @@ class EdgeImpulseRunnerFacade:
4343

4444
def __init__(self):
4545
"""Initialize the EdgeImpulseRunnerFacade with the API path."""
46-
infra = load_brick_compose_file(self.__class__)
47-
for k, v in infra["services"].items():
48-
self.host = k
49-
self.infra = v
50-
break # Only one service is expected
51-
52-
self.host = resolve_address(self.host)
53-
54-
self.port = 1337 # Default EI HTTP port
55-
self.url = f"http://{self.host}:{self.port}"
56-
logger.warning(f"[{self.__class__.__name__}] Host: {self.host} - Ports: {self.port} - URL: {self.url}")
46+
self.url = _get_ei_url(self.__class__)
47+
logger.warning(f"[{self.__class__.__name__}] URL: {self.url}")
5748

5849
def infer_from_file(self, image_path: str) -> dict | None:
5950
if not image_path or image_path == "":
@@ -124,47 +115,56 @@ def process(self, item):
124115
logger.error(f"[{self.__class__}] Error processing file {item}: {e}")
125116
return None
126117

127-
def infer_from_features(self, features: list) -> dict | None:
128-
"""Infer from features using the Edge Impulse API.
118+
@classmethod
119+
def infer_from_features(cls, features: list) -> dict | None:
120+
"""
121+
Infer from features using the Edge Impulse API.
129122
130123
Args:
124+
cls: The class method caller.
131125
features (list): A list of features to send to the Edge Impulse API.
132126
133127
Returns:
134128
dict | None: The response from the Edge Impulse API as a dictionary, or None if an error occurs.
135129
"""
136130
try:
137-
response = requests.post(f"{self.url}/api/features", json={"features": features})
131+
url = _get_ei_url(cls)
132+
model_info = EdgeImpulseRunnerFacade.get_model_info()
133+
features = features[: int(model_info.input_features_count)]
134+
135+
response = requests.post(f"{url}/api/features", json={"features": features})
138136
if response.status_code == 200:
139137
return response.json()
140138
else:
141-
logger.warning(f"[{self.__class__}] error: {response.status_code}. Message: {response.text}")
139+
logger.warning(f"[{cls.__name__}] error: {response.status_code}. Message: {response.text}")
142140
return None
143141
except Exception as e:
144-
logger.error(f"[{self.__class__.__name__}] Error: {e}")
142+
logger.error(f"[{cls.__name__}] Error: {e}")
145143
return None
146144

147-
def get_model_info(self) -> EdgeImpulseModelInfo | None:
145+
@classmethod
146+
def get_model_info(cls) -> EdgeImpulseModelInfo | None:
148147
"""Get model information from the Edge Impulse API.
149148
149+
Args:
150+
cls: The class method caller.
151+
150152
Returns:
151153
model_info (EdgeImpulseModelInfo | None): An instance of EdgeImpulseModelInfo containing model details, None if an error occurs.
152154
"""
153-
if not self.host or not self.port:
154-
logger.error(f"[{self.__class__}] Host or port not set. Cannot fetch model info.")
155-
return None
155+
url = _get_ei_url(cls)
156156

157157
http_client = HttpClient(total_retries=6) # Initialize the HTTP client with retry logic
158158
try:
159-
response = http_client.request_with_retry(f"{self.url}/api/info")
159+
response = http_client.request_with_retry(f"{url}/api/info")
160160
if response.status_code == 200:
161-
logger.debug(f"[{self.__class__.__name__}] Fetching model info from {self.url}/api/info -> {response.status_code} {response.json}")
161+
logger.debug(f"[{cls.__name__}] Fetching model info from {url}/api/info -> {response.status_code} {response.json}")
162162
return EdgeImpulseModelInfo(response.json())
163163
else:
164-
logger.warning(f"[{self.__class__}] Error fetching model info: {response.status_code}. Message: {response.text}")
164+
logger.warning(f"[{cls}] Error fetching model info: {response.status_code}. Message: {response.text}")
165165
return None
166166
except Exception as e:
167-
logger.error(f"[{self.__class__}] Error fetching model info: {e}")
167+
logger.error(f"[{cls}] Error fetching model info: {e}")
168168
return None
169169
finally:
170170
http_client.close() # Close the HTTP client session
@@ -237,3 +237,14 @@ def _extract_anomaly_score(self, item: dict):
237237
return class_results["anomaly"]
238238

239239
return None
240+
241+
242+
def _get_ei_url(cls):
243+
infra = load_brick_compose_file(cls)
244+
for k, v in infra["services"].items():
245+
host = k
246+
break
247+
host = resolve_address(host)
248+
port = 1337
249+
url = f"http://{host}:{port}"
250+
return url

tests/arduino/app_bricks/motion_detection/test_motion_detection.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ def app_instance(monkeypatch):
2424
@pytest.fixture(autouse=True)
2525
def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
2626
"""Mock out docker-compose lookups and image helpers."""
27-
fake_compose = {"services": {"models-runner": {"ports": ["${BIND_ADDRESS:-127.0.0.1}:${BIND_PORT:-8100}:8100"]}}}
28-
monkeypatch.setattr("arduino.app_internal.core.load_brick_compose_file", lambda cls: fake_compose)
27+
fake_compose = {"services": {"ei-inference": {"ports": ["${BIND_ADDRESS:-127.0.0.1}:${BIND_PORT:-1337}:1337"]}}}
28+
monkeypatch.setattr("arduino.app_internal.core.ei.load_brick_compose_file", lambda cls: fake_compose)
2929
monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda host: "127.0.0.1")
30-
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8200")])
30+
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "1337")])
3131

3232
class FakeResp:
3333
status_code = 200

tests/arduino/app_bricks/objectdetection/test_objectdetection.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import io
88
from PIL import Image
99
from arduino.app_bricks.object_detection import ObjectDetection
10+
from arduino.app_utils import HttpClient
1011

1112

1213
class ModelInfo:
@@ -20,12 +21,65 @@ def mock_dependencies(monkeypatch: pytest.MonkeyPatch):
2021
2122
This is needed to avoid network calls and other side effects.
2223
"""
23-
fake_compose = {"services": {"models-runner": {"ports": ["${BIND_ADDRESS:-127.0.0.1}:${BIND_PORT:-8100}:8100"]}}}
24-
monkeypatch.setattr("arduino.app_internal.core.load_brick_compose_file", lambda cls: fake_compose)
24+
fake_compose = {"services": {"ei-inference": {"ports": ["${BIND_ADDRESS:-127.0.0.1}:${BIND_PORT:-1337}:1337"]}}}
25+
monkeypatch.setattr("arduino.app_internal.core.ei.load_brick_compose_file", lambda cls: fake_compose)
2526
monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda host: "127.0.0.1")
26-
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8100")])
27+
monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "1337")])
2728
monkeypatch.setattr("arduino.app_bricks.object_detection.ObjectDetection.get_model_info", lambda self: ModelInfo("object-detection"))
2829

30+
class FakeResp:
31+
status_code = 200
32+
33+
def json(self):
34+
return {
35+
"project": {
36+
"deploy_version": 11,
37+
"id": 774707,
38+
"impulse_id": 1,
39+
"impulse_name": "Time series data, Spectral Analysis, Classification (Keras), Anomaly Detection (K-means)",
40+
"name": "Fan Monitoring - Advanced Anomaly Detection",
41+
"owner": "Arduino",
42+
},
43+
"modelParameters": {
44+
"has_visual_anomaly_detection": False,
45+
"axis_count": 3,
46+
"frequency": 100,
47+
"has_anomaly": 1,
48+
"has_object_tracking": False,
49+
"has_performance_calibration": False,
50+
"image_channel_count": 0,
51+
"image_input_frames": 0,
52+
"image_input_height": 0,
53+
"image_input_width": 0,
54+
"image_resize_mode": "none",
55+
"inferencing_engine": 4,
56+
"input_features_count": 600,
57+
"interval_ms": 10,
58+
"label_count": 2,
59+
"labels": ["nominal", "off"],
60+
"model_type": "classification",
61+
"sensor": 2,
62+
"slice_size": 50,
63+
"thresholds": [],
64+
"use_continuous_mode": False,
65+
"sensorType": "accelerometer",
66+
},
67+
}
68+
69+
def fake_get(
70+
self,
71+
url: str,
72+
method: str = "GET",
73+
data: dict | str = None,
74+
json: dict = None,
75+
headers: dict = None,
76+
timeout: int = 5,
77+
):
78+
return FakeResp()
79+
80+
# Mock the requests.get method to return a fake response
81+
monkeypatch.setattr(HttpClient, "request_with_retry", fake_get)
82+
2983

3084
@pytest.fixture
3185
def detector():

0 commit comments

Comments
 (0)