From 7ca25dc009f971dda89b1cad03fae97eb4c6fd95 Mon Sep 17 00:00:00 2001
From: "Toni M. Brotons" <10654467+toni-neurosc@users.noreply.github.com>
Date: Thu, 5 Dec 2024 15:22:07 +0100
Subject: [PATCH] Refactor handling of async calls to websocket send data
 functions

- Remove queues from Stream and instead create a new StreamBackendInterface
  class to handle communication between Stream and PyNMState
- Improvements to WebsocketManager
- Assorted typing fixes
---
 .../_get_grid_whole_brain.py                  |   1 -
 .../_helper_write_connectome.py               |   2 -
 py_neuromodulation/analysis/decode.py         |  46 +++-
 py_neuromodulation/gui/backend/app_backend.py |  19 +-
 py_neuromodulation/gui/backend/app_manager.py |   2 -
 py_neuromodulation/gui/backend/app_pynm.py    | 231 ++++++++++--------
 py_neuromodulation/gui/backend/app_socket.py  |  70 +++++-
 py_neuromodulation/gui/backend/app_utils.py   |  11 +-
 .../stream/backend_interface.py               |  47 ++++
 py_neuromodulation/stream/mnelsl_player.py    |   6 +-
 py_neuromodulation/stream/stream.py           |  90 +++----
 py_neuromodulation/utils/channels.py          |   2 +-
 12 files changed, 340 insertions(+), 187 deletions(-)
 create mode 100644 py_neuromodulation/stream/backend_interface.py

diff --git a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
index fa97e424..4f4409f6 100644
--- a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
+++ b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
@@ -70,7 +70,6 @@ def select_non_zero_voxels(
     coord_arr = np.array(coord_)
     ival_non_zero = ival_arr[ival != 0]
     coord_non_zero = coord_arr[ival != 0]
-    print(coord_non_zero.shape)
 
     return coord_non_zero, ival_non_zero
 
diff --git a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
index 2238ff6b..89cd5b2e 100644
--- a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
+++ b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
@@ -58,8 +58,6 @@ def write_connectome_mat(
         dict_connectome[
             f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")]
         ] = fp
-
-        print(idx)
     # save the dictionary
     sio.savemat(
         PATH_CONNECTOME,
diff --git a/py_neuromodulation/analysis/decode.py b/py_neuromodulation/analysis/decode.py
index 4971852b..722f7192 100644
--- a/py_neuromodulation/analysis/decode.py
+++ b/py_neuromodulation/analysis/decode.py
@@ -6,42 +6,66 @@
 import pandas as pd
 import numpy as np
 from copy import deepcopy
-from pathlib import PurePath
+from pathlib import Path
 import pickle
 
 from py_neuromodulation import logger
+from py_neuromodulation.utils.types import _PathLike
 
 from typing import Callable
 
 
 class RealTimeDecoder:
-
-    def __init__(self, model_path: str):
-        self.model_path = model_path
-        if model_path.endswith(".skops"):
+    def __init__(self, model_path: _PathLike):
+        self.model_path = Path(model_path)
+        if not self.model_path.exists():
+            raise FileNotFoundError(f"Model file {self.model_path} not found")
+        if not self.model_path.is_file():
+            raise IsADirectoryError(f"Model file {self.model_path} is a directory")
+
+        if self.model_path.suffix == ".skops":
             from skops import io as skops_io
-            self.model = skops_io.load(model_path)
+
+            self.model = skops_io.load(self.model_path)
         else:
-            return NotImplementedError("Only skops models are supported")
+            raise NotImplementedError("Only skops models are supported")
 
-    def predict(self, feature_dict: dict, channel: str = None, fft_bands_only: bool = True ) -> dict:
+    def predict(
+        self,
+        feature_dict: dict,
+        channel: str | None = None,
+        fft_bands_only: bool = True,
+    ) -> dict:
         try:
             if channel is not None:
-                features_ch = {f: feature_dict[f] for f in feature_dict.keys() if f.startswith(channel)}
+                features_ch = {
+                    f: feature_dict[f]
+                    for f in feature_dict.keys()
+                    if f.startswith(channel)
+                }
                 if fft_bands_only is True:
-                    features_ch_fft = {f: features_ch[f] for f in features_ch.keys() if "fft" in f and "psd" not in f}
-                    out = self.model.predict_proba(np.array(list(features_ch_fft.values())).reshape(1, -1))
+                    features_ch_fft = {
+                        f: features_ch[f]
+                        for f in features_ch.keys()
+                        if "fft" in f and "psd" not in f
+                    }
+                    out = self.model.predict_proba(
+                        np.array(list(features_ch_fft.values())).reshape(1, -1)
+                    )
                 else:
                     out = self.model.predict_proba(features_ch)
             else:
                 out = self.model.predict(feature_dict)
             for decode_output_idx in range(out.shape[1]):
-                feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[decode_output_idx]
+                feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
+                    decode_output_idx
+                ]
             return feature_dict
         except Exception as e:
             logger.error(f"Error in decoding: {e}")
             return feature_dict
 
+
 class CV_res:
     def __init__(
         self,
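For review context, a minimal sketch of how the refactored `RealTimeDecoder` is driven (the model file name and feature keys below are illustrative, not part of this patch; the constructor requires an existing `.skops` file):

```python
from py_neuromodulation.analysis.decode import RealTimeDecoder

# Hypothetical model path; RealTimeDecoder raises FileNotFoundError otherwise.
decoder = RealTimeDecoder(model_path="linear_model.skops")

# A feature dict as produced per batch by Stream.run()
feature_dict = {
    "ECOG_L_1_fft_theta": 0.12,
    "ECOG_L_1_fft_beta": 0.40,
    "ECOG_L_1_fft_psd_mean": 0.33,  # filtered out when fft_bands_only=True
    "time": 1000.0,
}

# predict_proba runs on the channel's fft features; outputs are written back
# into the dict under decode_0, decode_1, ... keys.
out = decoder.predict(feature_dict, channel="ECOG_L_1", fft_bands_only=True)
```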
diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py
index ec6f4bf2..f64e837d 100644
--- a/py_neuromodulation/gui/backend/app_backend.py
+++ b/py_neuromodulation/gui/backend/app_backend.py
@@ -15,7 +15,7 @@
 from pydantic import ValidationError
 
 from . import app_pynm
-from .app_socket import WebSocketManager
+from .app_socket import WebsocketManager
 from .app_utils import is_hidden, get_quick_access
 
 import pandas as pd
@@ -29,7 +29,6 @@ class PyNMBackend(FastAPI):
     def __init__(
         self,
-        pynm_state: app_pynm.PyNMState,
         debug=False,
         dev=True,
         dev_port: int | None = None,
@@ -48,7 +47,6 @@ def __init__(
         cors_origins = (
             ["http://localhost:" + str(dev_port)] if dev_port is not None else []
         )
-        print(cors_origins)
         # Configure CORS
         self.add_middleware(
             CORSMiddleware,
@@ -69,8 +67,8 @@ def __init__(
             name="static",
         )
 
-        self.pynm_state = pynm_state
-        self.websocket_manager = WebSocketManager()
+        self.websocket_manager = WebsocketManager()
+        self.pynm_state = app_pynm.PyNMState()
 
     def setup_routes(self):
         @self.get("/api/health")
@@ -158,8 +156,7 @@ async def handle_stream_control(data: dict):
 
             if action == "stop":
                 self.logger.info("Stopping stream")
-                self.pynm_state.stream_handling_queue.put("stop")
-                self.pynm_state.stop_event_ws.set()
+                self.pynm_state.stop_run_function()
 
             return {"message": f"Stream action '{action}' executed"}
 
@@ -321,6 +318,14 @@ async def home_directory():
             except Exception as e:
                 raise HTTPException(status_code=500, detail=str(e))
 
+        # Get PYNM_DIR
+        @self.get("/api/pynm_dir")
+        async def get_pynm_dir():
+            try:
+                return {"pynm_dir": PYNM_DIR}
+            except Exception as e:
+                raise HTTPException(status_code=500, detail=str(e))
+
         # Get list of available drives in Windows systems
         @self.get("/api/drives")
         async def list_drives():
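The new `/api/pynm_dir` route can be exercised as below (host and port are whatever uvicorn was started with; `localhost:8000` is an assumption):

```python
import requests

# Query the endpoint added in this patch
resp = requests.get("http://localhost:8000/api/pynm_dir")
print(resp.json())  # {"pynm_dir": "<py_neuromodulation installation dir>"}
```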
diff --git a/py_neuromodulation/gui/backend/app_manager.py b/py_neuromodulation/gui/backend/app_manager.py
index ec1587da..8bec1023 100644
--- a/py_neuromodulation/gui/backend/app_manager.py
+++ b/py_neuromodulation/gui/backend/app_manager.py
@@ -27,7 +27,6 @@ def create_backend():
     :return: The web application instance.
     :rtype: PyNMBackend
     """
-    from .app_pynm import PyNMState
     from .app_backend import PyNMBackend
 
     debug = os.environ.get("PYNM_DEBUG", "False").lower() == "true"
@@ -35,7 +34,6 @@ def create_backend():
     dev_port = os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT))
 
     return PyNMBackend(
-        pynm_state=PyNMState(),
         debug=debug,
         dev=dev,
         dev_port=int(dev_port),
diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py
index d12ef89c..23641786 100644
--- a/py_neuromodulation/gui/backend/app_pynm.py
+++ b/py_neuromodulation/gui/backend/app_pynm.py
@@ -1,165 +1,145 @@
 import os
-import asyncio
-import logging
-import threading
 import numpy as np
-import multiprocessing as mp
-from threading import Thread
-import queue
+import time
+import asyncio
+import multiprocessing as mp
+from queue import Empty
+from pathlib import Path
 
 from py_neuromodulation.stream import Stream, NMSettings
 from py_neuromodulation.analysis.decode import RealTimeDecoder
 from py_neuromodulation.utils import set_channels
 from py_neuromodulation.utils.io import read_mne_data
+from py_neuromodulation.utils.types import _PathLike
+from py_neuromodulation.gui.backend.app_socket import WebsocketManager
+from py_neuromodulation.stream.backend_interface import StreamBackendInterface
 
 from py_neuromodulation import logger
-
-
-async def run_stream_controller(
-    feature_queue: queue.Queue,
-    rawdata_queue: queue.Queue,
-    websocket_manager: "WebSocketManager",
-    stop_event: threading.Event,
-):
-    while not stop_event.wait(0.002):
-        if not feature_queue.empty() and websocket_manager is not None:
-            feature_dict = feature_queue.get()
-            logger.info("Sending message to Websocket")
-            await websocket_manager.send_cbor(feature_dict)
-
-        if not rawdata_queue.empty() and websocket_manager is not None:
-            raw_data = rawdata_queue.get()
-
-            await websocket_manager.send_cbor(raw_data)
-
-
-def run_stream_controller_sync(
-    feature_queue: queue.Queue,
-    rawdata_queue: queue.Queue,
-    websocket_manager: "WebSocketManager",
-    stop_event: threading.Event,
-):
-    # The run_stream_controller needs to be started as an asyncio function due to the async websocket
-    asyncio.run(
-        run_stream_controller(
-            feature_queue, rawdata_queue, websocket_manager, stop_event
-        )
-    )
 
 
 class PyNMState:
     def __init__(
         self,
-        default_init: bool = True,  # has to be true for the backend settings communication
     ) -> None:
-        self.logger = logging.getLogger("uvicorn.error")
+        self.lsl_stream_name: str = ""
+        self.experiment_name: str = "PyNM_Experiment"  # set by set_stream_params
+        self.out_dir: _PathLike = str(
+            Path.home() / "PyNM" / self.experiment_name
+        )  # set by set_stream_params
+        self.decoding_model_path: _PathLike | None = None
+        self.decoder: RealTimeDecoder | None = None
+        self.is_stream_lsl: bool = False
+
+        self.backend_interface: StreamBackendInterface | None = None
+        self.websocket_manager: WebsocketManager | None = None
 
-        self.lsl_stream_name = None
-        self.stream_controller_process = None
-        self.run_func_process = None
-        self.out_dir = None  # will be set by set_stream_params
-        self.experiment_name = None  # will be set by set_stream_params
-        self.decoding_model_path = None
-        self.decoder = None
+        # Note: sfreq and data are required for stream init
+        self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
+        self.settings: NMSettings = NMSettings(sampling_rate_features=10)
 
-        if default_init:
-            self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
-            self.settings: NMSettings = NMSettings(sampling_rate_features=10)
+        self.feature_queue = mp.Queue()
+        self.rawdata_queue = mp.Queue()
+        self.control_queue = mp.Queue()
+        self.stop_event = asyncio.Event()
+
+        self.messages_sent = 0
 
     def start_run_function(
         self,
-        websocket_manager=None,
+        websocket_manager: WebsocketManager | None = None,
     ) -> None:
-        self.stream.settings = self.settings
-
-        self.stream_handling_queue = queue.Queue()
-        self.feature_queue = queue.Queue()
-        self.rawdata_queue = queue.Queue()
+        # TONI: This is dangerous to do here, should be done by the setup functions
+        # self.stream.settings = self.settings
 
-        self.logger.info("Starting stream_controller_process thread")
+        self.is_stream_lsl = bool(self.lsl_stream_name)
+        self.websocket_manager = websocket_manager
 
+        # Create decoder
         if self.decoding_model_path is not None and self.decoding_model_path != "None":
             if os.path.exists(self.decoding_model_path):
                 self.decoder = RealTimeDecoder(self.decoding_model_path)
             else:
-                logger.debug("Passed decoding model path does't exist")
+                logger.debug("Passed decoding model path doesn't exist")
 
-        # Stop event
-        # .set() is called from app_backend
-        self.stop_event_ws = threading.Event()
-
-        self.stream_controller_thread = Thread(
-            target=run_stream_controller_sync,
-            daemon=True,
-            args=(
-                self.feature_queue,
-                self.rawdata_queue,
-                websocket_manager,
-                self.stop_event_ws,
-            ),
-        )
-        is_stream_lsl = self.lsl_stream_name is not None
-        stream_lsl_name = (
-            self.lsl_stream_name if self.lsl_stream_name is not None else ""
-        )
+        # Initialize the backend interface if not already done
+        if not self.backend_interface:
+            self.backend_interface = StreamBackendInterface(
+                self.feature_queue, self.rawdata_queue, self.control_queue
+            )
 
-        # The run_func_thread is terminated through the stream_handling_queue
-        # which initiates to break the data generator and save the features
-        out_dir = "" if self.out_dir == "default" else self.out_dir
-        self.run_func_thread = Thread(
+        # The stream process is terminated through the control queue,
+        # which breaks the data generator and triggers saving of the features
+        stream_process = mp.Process(
             target=self.stream.run,
-            daemon=True,
             kwargs={
-                "out_dir": out_dir,
+                "out_dir": "" if self.out_dir == "default" else self.out_dir,
                 "experiment_name": self.experiment_name,
-                "stream_handling_queue": self.stream_handling_queue,
-                "is_stream_lsl": is_stream_lsl,
-                "stream_lsl_name": stream_lsl_name,
-                "feature_queue": self.feature_queue,
+                "is_stream_lsl": self.is_stream_lsl,
+                "stream_lsl_name": self.lsl_stream_name or "",
                 "simulate_real_time": True,
-                "rawdata_queue": self.rawdata_queue,
                 "decoder": self.decoder,
+                "backend_interface": self.backend_interface,
             },
         )
 
-        self.stream_controller_thread.start()
-        self.run_func_thread.start()
+        stream_process.start()
+
+        # Start websocket sender task
+        if self.websocket_manager:
+            # Get the current event loop and run the queue processor
+            loop = asyncio.get_running_loop()
+            queue_task = loop.create_task(self._process_queue())
+
+            # Store task reference for cleanup
+            self._queue_task = queue_task
+
+        # Store process for cleanup
+        self.stream_process = stream_process
+
+    def stop_run_function(self) -> None:
+        """Stop the stream processing"""
+        if self.backend_interface:
+            self.backend_interface.send_command("stop")
+        self.stop_event.set()
 
     def setup_lsl_stream(
         self,
-        lsl_stream_name: str | None = None,
+        lsl_stream_name: str = "",
         line_noise: float | None = None,
         sampling_rate_features: float | None = None,
     ):
         from mne_lsl.lsl import resolve_streams
 
-        self.logger.info("resolving streams")
+        logger.info("resolving streams")
         lsl_streams = resolve_streams()
 
         for stream in lsl_streams:
             if stream.name == lsl_stream_name:
-                self.logger.info(f"found stream {lsl_stream_name}")
+                logger.info(f"found stream {lsl_stream_name}")
                 # setup this stream
                 self.lsl_stream_name = lsl_stream_name
 
                 ch_names = stream.get_channel_names()
                 if ch_names is None:
                     ch_names = ["ch" + str(i) for i in range(stream.n_channels)]
-                self.logger.info(f"channel names: {ch_names}")
+                logger.info(f"channel names: {ch_names}")
 
                 ch_types = stream.get_channel_types()
                 if ch_types is None:
                     ch_types = ["eeg" for i in range(stream.n_channels)]
-                self.logger.info(f"channel types: {ch_types}")
+                logger.info(f"channel types: {ch_types}")
 
                 info_ = stream.get_channel_info()
-                self.logger.info(f"channel info: {info_}")
+                logger.info(f"channel info: {info_}")
 
                 channels = set_channels(
                     ch_names=ch_names,
                     ch_types=ch_types,
                     used_types=["eeg", "ecog", "dbs", "seeg"],
                 )
-                self.logger.info(channels)
+                logger.info(channels)
                 sfreq = stream.sfreq
 
                 self.stream: Stream = Stream(
@@ -169,20 +149,19 @@ def setup_lsl_stream(
                     sampling_rate_features_hz=sampling_rate_features,
                     settings=self.settings,
                 )
-                self.logger.info("stream setup")
+                logger.info("stream setup")
                 # self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq)
-                self.logger.info("settings setup")
+                logger.info("settings setup")
                 break
-
-        if channels.shape[0] == 0:
-            self.logger.error(f"Stream {lsl_stream_name} not found")
+        else:
+            logger.error(f"Stream {lsl_stream_name} not found")
             raise ValueError(f"Stream {lsl_stream_name} not found")
 
     def setup_offline_stream(
         self,
         file_path: str,
-        line_noise: float | None = None,
-        sampling_rate_features: float | None = None,
+        line_noise: float,
+        sampling_rate_features: float,
     ):
         data, sfreq, ch_names, ch_types, bads = read_mne_data(file_path)
 
@@ -195,7 +174,9 @@ def setup_offline_stream(
             target_keywords=None,
         )
 
-        self.logger.info(f"settings: {self.settings}")
+        self.settings.sampling_rate_features_hz = sampling_rate_features
+
+        logger.info(f"settings: {self.settings}")
         self.stream: Stream = Stream(
             settings=self.settings,
             sfreq=sfreq,
             data=data,
             line_noise=line_noise,
             sampling_rate_features_hz=sampling_rate_features,
         )
+
+    # Async function that will continuously run in the Uvicorn async loop
+    # and handle sending data through the websocket manager
+    async def _process_queue(self):
+        last_queue_check = time.time()
+
+        while not self.stop_event.is_set():
+            # Use asyncio.gather to process both queues concurrently
+            tasks = []
+            current_time = time.time()
+
+            # Check feature queue
+            while not self.feature_queue.empty():
+                try:
+                    data = self.feature_queue.get_nowait()
+                    tasks.append(self.websocket_manager.send_cbor(data))  # type: ignore
+                    self.messages_sent += 1
+                except Empty:
+                    break
+
+            # Check raw data queue
+            while not self.rawdata_queue.empty():
+                try:
+                    data = self.rawdata_queue.get_nowait()
+                    self.messages_sent += 1
+                    tasks.append(self.websocket_manager.send_cbor(data))  # type: ignore
+                except Empty:
+                    break
+
+            if tasks:
+                # Wait for all send operations to complete
+                await asyncio.gather(*tasks, return_exceptions=True)
+            else:
+                # Only sleep if we didn't process any messages
+                await asyncio.sleep(0.001)
+
+            # Log queue diagnostics every 5 seconds
+            if current_time - last_queue_check > 5:
+                logger.info(
+                    "\nQueue diagnostics:\n"
+                    f"\tMessages sent to websocket: {self.messages_sent}.\n"
+                    f"\tFeature queue size: ~{self.feature_queue.qsize()}\n"
+                    f"\tRaw data queue size: ~{self.rawdata_queue.qsize()}"
+                )
+
+                last_queue_check = current_time
+
+            # Check if stream process is still alive
+            if not self.stream_process.is_alive():
+                break
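For orientation, this is roughly how the refactored pieces fit together after this patch. A sketch only, assuming an async FastAPI route context; the file name and the `websocket_manager` variable are assumptions, not part of the diff:

```python
# PyNMState owns the feature/rawdata/control mp.Queues
state = PyNMState()
state.setup_offline_stream("example.fif", line_noise=50, sampling_rate_features=10)

# Called from an async route: Stream.run() starts in a separate process and
# writes into the queues via StreamBackendInterface, while _process_queue()
# runs on the server event loop and forwards queue items to the websocket
# as CBOR frames.
state.start_run_function(websocket_manager=websocket_manager)

# The stream-control route stops everything: "stop" goes onto the control
# queue (breaking Stream.run's generator loop) and stop_event ends the
# websocket sender task.
state.stop_run_function()
```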
diff --git a/py_neuromodulation/gui/backend/app_socket.py b/py_neuromodulation/gui/backend/app_socket.py
index a7afed67..f81cc864 100644
--- a/py_neuromodulation/gui/backend/app_socket.py
+++ b/py_neuromodulation/gui/backend/app_socket.py
@@ -1,9 +1,10 @@
 from fastapi import WebSocket
 import logging
-import struct
-import json
 import cbor2
 
-class WebSocketManager:
+import time
+
+
+class WebsocketManager:
     """
     Manages WebSocket connections and messages.
     Perhaps in the future it will handle multiple connections.
@@ -12,6 +13,12 @@ class WebSocketManager:
     def __init__(self):
         self.active_connections: list[WebSocket] = []
         self.logger = logging.getLogger("PyNM")
+        self.disconnected = []
+        self._queue_task = None
+        self._stop_event = None
+        self.loop = None
+        self.messages_sent = 0
+        self._last_diagnostic_time = time.time()
 
     async def connect(self, websocket: WebSocket):
         await websocket.accept()
@@ -30,25 +37,66 @@ def disconnect(self, websocket: WebSocket):
-            f"Client {client_address.port}:{client_address.port} disconnected."
+            f"Client {client_address.host}:{client_address.port} disconnected."
         )
 
+    async def _cleanup_disconnected(self):
+        for connection in self.disconnected:
+            self.active_connections.remove(connection)
+            await connection.close()
+        self.disconnected.clear()
+
     # Combine IP and port to create a unique client ID
     async def send_cbor(self, object: dict):
+        if not self.active_connections:
+            self.logger.warning("No active connection to send message.")
+            return
+
+        start_time = time.time()
+        cbor_data = cbor2.dumps(object)
+        serialize_time = time.time() - start_time
+
+        if serialize_time > 0.1:  # Log slow serializations
+            self.logger.warning(f"CBOR serialization took {serialize_time:.3f}s")
+
+        send_start = time.time()
         for connection in self.active_connections:
             try:
-                await connection.send_bytes(cbor2.dumps(object))
+                await connection.send_bytes(cbor_data)
             except RuntimeError as e:
-                self.active_connections.remove(connection)
+                self.logger.error(f"Error sending CBOR message: {e}")
+                self.disconnected.append(connection)
+
+        send_time = time.time() - send_start
+        if send_time > 0.1:  # Log slow sends
+            self.logger.warning(f"WebSocket send took {send_time:.3f}s")
+
+        self.messages_sent += 1
+
+        # Log diagnostics every 5 seconds
+        current_time = time.time()
+        if current_time - self._last_diagnostic_time > 5:
+            self.logger.info(f"Messages sent: {self.messages_sent}")
+            self._last_diagnostic_time = current_time
+
+        await self._cleanup_disconnected()
 
     async def send_message(self, message: str | dict):
-        self.logger.info(f"Sending message within app_socket: {message.keys()}")
-        if self.active_connections:
-            for connection in self.active_connections:
+        if not self.active_connections:
+            self.logger.warning("No active connection to send message.")
+            return
+
+        self.logger.info(
+            f"Sending message within app_socket: {message.keys() if type(message) is dict else message}"
+        )
+        for connection in self.active_connections:
+            try:
                 if type(message) is dict:
                     await connection.send_json(message)
                 elif type(message) is str:
                     await connection.send_text(message)
-                self.logger.info(f"Message sent")
-            else:
-                self.logger.warning("No active connection to send message.")
+                self.logger.info(f"Message sent to {connection.client}")
+            except RuntimeError as e:
+                self.logger.error(f"Error sending message: {e}")
+                self.disconnected.append(connection)
+
+        await self._cleanup_disconnected()
 
     @property
     def is_connected(self):
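A minimal client-side counterpart for reading the CBOR frames that `WebsocketManager.send_cbor()` emits. Not part of this patch: the `websockets` package and the URL/endpoint path are assumptions (the websocket route is defined elsewhere in the backend):

```python
import asyncio
import cbor2
import websockets

async def listen(url: str = "ws://localhost:8000/ws"):  # URL is hypothetical
    async with websockets.connect(url) as ws:
        async for frame in ws:
            if isinstance(frame, bytes):
                # Either a feature dict or {"raw_data": {channel: [...]}}
                msg = cbor2.loads(frame)
                print(msg)

asyncio.run(listen())
```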
Stopping the player...") self.stop_player() def _run_player(self, chunk_size, n_repeat, stop_flag, streaming_complete): @@ -156,7 +156,7 @@ def wait_for_completion(self): if self._streaming_complete.is_set(): break except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping the player...") + logger.info("\nKeyboard interrupt received. Stopping the player...") self.stop_player() break @@ -172,7 +172,7 @@ def stop_player(self): self._player_process.kill() self._player_process = None - print(f"Player stopped: {self.stream_name}") + logger.info(f"Player stopped: {self.stream_name}") LSLOfflinePlayer._instances.discard(self) @classmethod diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py index 97446085..1ec685f6 100644 --- a/py_neuromodulation/stream/stream.py +++ b/py_neuromodulation/stream/stream.py @@ -1,19 +1,19 @@ """Module for generic and offline data streams.""" import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from collections.abc import Iterator import numpy as np from pathlib import Path import py_neuromodulation as nm -from contextlib import suppress from py_neuromodulation.stream.data_processor import DataProcessor from py_neuromodulation.utils.types import _PathLike, FEATURE_NAME from py_neuromodulation.utils.file_writer import MsgPackFileWriter from py_neuromodulation.stream.settings import NMSettings from py_neuromodulation.analysis.decode import RealTimeDecoder +from py_neuromodulation.stream.backend_interface import StreamBackendInterface if TYPE_CHECKING: import pandas as pd @@ -68,6 +68,7 @@ def __init__( verbose : bool, optional print out stream computation time information, by default True """ + # This is calling NMSettings.validate() which is making a copy self.settings: NMSettings = NMSettings.load(settings) if channels is None and data is not None: @@ -208,14 +209,11 @@ def run( save_interval: int = 10, return_df: bool = True, simulate_real_time: bool = False, - decoder: RealTimeDecoder = None, - feature_queue: "queue.Queue | None" = None, - rawdata_queue: "queue.Queue | None" = None, - stream_handling_queue: "queue.Queue | None" = None, + decoder: RealTimeDecoder | None = None, + backend_interface: StreamBackendInterface | None = None, ): self.is_stream_lsl = is_stream_lsl self.stream_lsl_name = stream_lsl_name - self.stream_handling_queue = stream_handling_queue self.save_csv = save_csv self.save_interval = save_interval self.return_df = return_df @@ -248,7 +246,7 @@ def run( nm.logger.log_to_file(out_dir) self.generator: Iterator - if not is_stream_lsl: + if not is_stream_lsl and data is not None: from py_neuromodulation.stream.generator import RawDataGenerator self.generator = RawDataGenerator( @@ -265,7 +263,10 @@ def run( settings=self.settings, stream_name=stream_lsl_name ) - if self.sfreq != self.lsl_stream.stream.sinfo.sfreq: + if ( + self.lsl_stream.stream.sinfo is not None + and self.sfreq != self.lsl_stream.stream.sinfo.sfreq + ): error_msg = ( f"Sampling frequency of the lsl-stream ({self.lsl_stream.stream.sinfo.sfreq}) " f"does not match the settings ({self.sfreq})." 
diff --git a/py_neuromodulation/stream/mnelsl_player.py b/py_neuromodulation/stream/mnelsl_player.py
index 086e4d60..bd39dd89 100644
--- a/py_neuromodulation/stream/mnelsl_player.py
+++ b/py_neuromodulation/stream/mnelsl_player.py
@@ -125,7 +125,7 @@ def start_player(
         try:
             self.wait_for_completion()
         except KeyboardInterrupt:
-            print("\nKeyboard interrupt received. Stopping the player...")
+            logger.info("\nKeyboard interrupt received. Stopping the player...")
             self.stop_player()
 
     def _run_player(self, chunk_size, n_repeat, stop_flag, streaming_complete):
@@ -156,7 +156,7 @@ def wait_for_completion(self):
                 if self._streaming_complete.is_set():
                     break
             except KeyboardInterrupt:
-                print("\nKeyboard interrupt received. Stopping the player...")
+                logger.info("\nKeyboard interrupt received. Stopping the player...")
                 self.stop_player()
                 break
 
@@ -172,7 +172,7 @@ def stop_player(self):
             self._player_process.kill()
             self._player_process = None
 
-        print(f"Player stopped: {self.stream_name}")
+        logger.info(f"Player stopped: {self.stream_name}")
         LSLOfflinePlayer._instances.discard(self)
 
     @classmethod
diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py
index 97446085..1ec685f6 100644
--- a/py_neuromodulation/stream/stream.py
+++ b/py_neuromodulation/stream/stream.py
@@ -1,19 +1,19 @@
 """Module for generic and offline data streams."""
 
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from collections.abc import Iterator
 import numpy as np
 from pathlib import Path
 
 import py_neuromodulation as nm
-from contextlib import suppress
 
 from py_neuromodulation.stream.data_processor import DataProcessor
 from py_neuromodulation.utils.types import _PathLike, FEATURE_NAME
 from py_neuromodulation.utils.file_writer import MsgPackFileWriter
 from py_neuromodulation.stream.settings import NMSettings
 from py_neuromodulation.analysis.decode import RealTimeDecoder
+from py_neuromodulation.stream.backend_interface import StreamBackendInterface
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -68,6 +68,7 @@ def __init__(
         verbose : bool, optional
             print out stream computation time information, by default True
         """
+        # This is calling NMSettings.validate() which is making a copy
        self.settings: NMSettings = NMSettings.load(settings)
 
         if channels is None and data is not None:
@@ -208,14 +209,11 @@ def run(
         save_interval: int = 10,
         return_df: bool = True,
         simulate_real_time: bool = False,
-        decoder: RealTimeDecoder = None,
-        feature_queue: "queue.Queue | None" = None,
-        rawdata_queue: "queue.Queue | None" = None,
-        stream_handling_queue: "queue.Queue | None" = None,
+        decoder: RealTimeDecoder | None = None,
+        backend_interface: StreamBackendInterface | None = None,
     ):
         self.is_stream_lsl = is_stream_lsl
         self.stream_lsl_name = stream_lsl_name
-        self.stream_handling_queue = stream_handling_queue
         self.save_csv = save_csv
         self.save_interval = save_interval
         self.return_df = return_df
@@ -248,7 +246,7 @@ def run(
             nm.logger.log_to_file(out_dir)
 
         self.generator: Iterator
-        if not is_stream_lsl:
+        if not is_stream_lsl and data is not None:
             from py_neuromodulation.stream.generator import RawDataGenerator
 
             self.generator = RawDataGenerator(
@@ -265,7 +263,10 @@ def run(
                 settings=self.settings, stream_name=stream_lsl_name
             )
 
-            if self.sfreq != self.lsl_stream.stream.sinfo.sfreq:
+            if (
+                self.lsl_stream.stream.sinfo is not None
+                and self.sfreq != self.lsl_stream.stream.sinfo.sfreq
+            ):
                 error_msg = (
                     f"Sampling frequency of the lsl-stream ({self.lsl_stream.stream.sinfo.sfreq}) "
                     f"does not match the settings ({self.sfreq})."
@@ -279,14 +280,14 @@ def run(
         prev_batch_end = 0
         for timestamps, data_batch in self.generator:
             self.is_running = True
-            if self.stream_handling_queue is not None:
+            if backend_interface:
+                # Only simulate real-time if connected to GUI
                 if simulate_real_time:
                     time.sleep(1 / self.settings.sampling_rate_features_hz)
-                if not self.stream_handling_queue.empty():
-                    signal = self.stream_handling_queue.get()
-                    nm.logger.info(f"Got signal: {signal}")
-                    if signal == "stop":
-                        break
+
+                signal = backend_interface.check_control_signals()
+                if signal == "stop":
+                    break
 
             if data_batch is None:
                 nm.logger.info("Data batch is None, stopping run function")
@@ -308,7 +309,6 @@ def run(
                 )
 
             feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1)
-
             prev_batch_end = this_batch_end
 
             if self.verbose:
@@ -316,41 +316,41 @@ def run(
 
             self._add_target(feature_dict, data_batch)
 
-            with suppress(TypeError):  # Need this because some features output None
-                for key, value in feature_dict.items():
-                    feature_dict[key] = np.float64(value)
-
+            # Push data to file writer
             file_writer.insert_data(feature_dict)
-            if feature_queue is not None:
-                feature_queue.put(feature_dict)
 
-            if rawdata_queue is not None:
-                # convert raw data into dict with new raw data in unit self.sfreq
-                new_time_ms = 1000 / self.settings.sampling_rate_features_hz
-                new_samples = int(new_time_ms * self.sfreq / 1000)
-                data_batch_dict = {}
-                data_batch_dict["raw_data"] = {}
-                for i, ch in enumerate(self.channels["name"]):
-                    # needs to be list since cbor doesn't support np array
-                    data_batch_dict["raw_data"][ch] = list(data_batch[i, -new_samples:])
-                rawdata_queue.put(data_batch_dict)
+            # Send data to frontend
+            if backend_interface:
+                backend_interface.send_features(feature_dict)
+                backend_interface.send_raw_data(self._prepare_raw_data_dict(data_batch))
 
+            # Save features to file in intervals
             self.batch_count += 1
             if self.batch_count % self.save_interval == 0:
                 file_writer.save()
 
         file_writer.save()
+
         if self.save_csv:
             file_writer.save_as_csv(save_all_combined=True)
 
-        if self.return_df:
-            feature_df = file_writer.load_all()
+        feature_df = file_writer.load_all() if self.return_df else {}
 
         self._save_after_stream()
         self.is_running = False
 
-        return (
-            feature_df  # Timon: We could think of returning the feature_reader instead
-        )
+        return feature_df  # Timon: We could think of returning the feature_reader instead
+
+    def _prepare_raw_data_dict(self, data_batch: np.ndarray) -> dict[str, Any]:
+        """Prepare raw data dictionary for sending through queue"""
+        new_time_ms = 1000 / self.settings.sampling_rate_features_hz
+        new_samples = int(new_time_ms * self.sfreq / 1000)
+        return {
+            "raw_data": {
+                # lists, since CBOR doesn't support numpy arrays
+                ch: list(data_batch[i, -new_samples:])
+                for i, ch in enumerate(self.channels["name"])
+            }
+        }
 
     def plot_raw_signal(
         self,
@@ -386,11 +386,15 @@ def plot_raw_signal(
         ValueError
             raise Exception when no data is passed
         """
-        if self.data is None and data is None:
-            raise ValueError("No data passed to plot_raw_signal function.")
-
-        if data is None and self.data is not None:
-            data = self.data
+        if data is None:
+            if self.data is None:
+                raise ValueError("No data passed to plot_raw_signal function.")
+            else:
+                import pandas as pd  # runtime import; pd is type-checking-only at module level
+
+                data = (
+                    self.data.to_numpy()
+                    if isinstance(self.data, pd.DataFrame)
+                    else self.data
+                )
 
         if sfreq is None:
             sfreq = self.sfreq
@@ -405,7 +409,7 @@ def plot_raw_signal(
         from mne import create_info
         from mne.io import RawArray
 
-        info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
+        info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)  # type: ignore
         raw = RawArray(data, info)
 
         if picks is not None:
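Offline use is unaffected by the refactor: without a `backend_interface`, `Stream.run()` neither throttles to real time nor polls a control queue. A minimal sketch (data shape, output directory, and experiment name are illustrative):

```python
import numpy as np
from py_neuromodulation.stream import Stream

data = np.random.random((2, 10_000))  # 2 channels of dummy data
stream = Stream(sfreq=1000, data=data, sampling_rate_features_hz=10)

# Returns the feature DataFrame since return_df defaults to True
features = stream.run(out_dir="out", experiment_name="offline_test")
```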
diff --git a/py_neuromodulation/utils/channels.py b/py_neuromodulation/utils/channels.py
index c2fffd00..61f5b68e 100644
--- a/py_neuromodulation/utils/channels.py
+++ b/py_neuromodulation/utils/channels.py
@@ -251,7 +251,7 @@ def _get_default_references(
 
 
 def get_default_channels_from_data(
-    data: np.ndarray,
+    data: "np.ndarray | pd.DataFrame",
     car_rereferencing: bool = True,
 ):
     """Return default channels dataframe with