Merge pull request #381 from neuromodulation/refactor_storage

Refactor storage

timonmerk authored Nov 20, 2024
2 parents 5e67151 + 557dacf commit a50d357

Showing 8 changed files with 210 additions and 110 deletions.
2 changes: 1 addition & 1 deletion examples/plot_0_first_demo.py

@@ -152,7 +152,7 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
 # We will therefore use the :class:`~nm_analysis` class to showcase some functions. For multi-run or multi-subject analysis we will pass here the feature_file "sub" as default directory:

 analyzer = nm.FeatureReader(
-    feature_dir=stream.out_dir_root, feature_file=stream.experiment_name
+    feature_dir=stream.out_dir, feature_file=stream.experiment_name
 )

 # %%
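For context, `stream.out_dir` is now resolved inside `Stream.run()` itself (see the stream.py diff below), so the demo points the FeatureReader directly at it. A minimal sketch of that resolution rule, with illustrative paths:

    from pathlib import Path

    def resolve_out_dir(out_dir: str, experiment_name: str) -> Path:
        # Mirrors the rule added in stream.py:
        # self.out_dir = Path.cwd() / experiment_name if not out_dir else Path(out_dir)
        return Path.cwd() / experiment_name if not out_dir else Path(out_dir)

    print(resolve_out_dir("", "sub"))           # <cwd>/sub -- the default
    print(resolve_out_dir("/tmp/run1", "sub"))  # /tmp/run1 -- an explicit out_dir wins
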
8 changes: 4 additions & 4 deletions py_neuromodulation/gui/backend/app_backend.py

@@ -104,17 +104,17 @@ async def handle_stream_control(data: dict):
             self.logger.info(self.websocket_manager)
             self.logger.info("Starting stream")

-            asyncio.create_task(
-                self.pynm_state.start_run_function(
-                    # out_dir=data["out_dir"],
-                    # experiment_name=data["experiment_name"],
-                    websocket_manager_features=self.websocket_manager,
-                )
-            )
+            self.pynm_state.start_run_function(
+                # out_dir=data["out_dir"],
+                # experiment_name=data["experiment_name"],
+                websocket_manager_features=self.websocket_manager,
+            )

         if action == "stop":
             self.logger.info("Stopping stream")
-            asyncio.create_task(self.pynm_state.stream_handling_queue.put("stop"))
+            self.pynm_state.stream_handling_queue.put("stop")
+            self.pynm_state.stop_event_ws.set()

         return {"message": f"Stream action '{action}' executed"}
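The stop path now signals a plain queue.Queue and a threading.Event instead of awaiting an asyncio.Queue. A self-contained sketch of this signalling pattern (generic names, not the package's API):

    import queue
    import threading
    import time

    stream_handling_queue: queue.Queue = queue.Queue()
    stop_event = threading.Event()

    def worker() -> None:
        # Mirrors the run loop: poll the queue each batch and exit on "stop".
        while not stop_event.is_set():
            if not stream_handling_queue.empty() and stream_handling_queue.get() == "stop":
                break
            time.sleep(0.01)  # stand-in for processing one data batch

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    stream_handling_queue.put("stop")  # what the "stop" action does
    stop_event.set()                   # also unblocks the controller thread
    t.join()
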
6 changes: 0 additions & 6 deletions py_neuromodulation/gui/backend/app_manager.py

@@ -203,12 +203,6 @@ def __init__(
         self.is_child_process = os.environ.get(self.LAUNCH_FLAG) == "TRUE"
         os.environ[self.LAUNCH_FLAG] = "TRUE"

-        # PyNM state
-        # TODO: need to find a way to pass the state to the backend
-        # self.pynm_state = PyNMState()
-        # self.app = PyNMBackend(pynm_state=self.pynm_state)
-
         # Logging
         self.logger = create_logger(
             "PyNM",
             "yellow",
97 changes: 68 additions & 29 deletions py_neuromodulation/gui/backend/app_pynm.py

@@ -1,56 +1,100 @@
 import asyncio
 import logging
+import threading
 import numpy as np
-from multiprocessing import Process

+import multiprocessing as mp
+from threading import Thread
+import queue
 from py_neuromodulation.stream import Stream, NMSettings
 from py_neuromodulation.utils import set_channels
 from py_neuromodulation.utils.io import read_mne_data

+from py_neuromodulation import logger
+
+async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue.Queue,
+        websocket_manager_features: "WebSocketManager", stop_event: threading.Event):
+    while not stop_event.wait(0.002):
+        if not feature_queue.empty() and websocket_manager_features is not None:
+            feature_dict = feature_queue.get()
+            logger.info("Sending message to Websocket")
+            await websocket_manager_features.send_cbor(feature_dict)
+        # here the rawdata queue could also be used to send raw data, potentially through a different websocket?
+
+def run_stream_controller_sync(feature_queue: queue.Queue,
+        rawdata_queue: queue.Queue,
+        websocket_manager_features: "WebSocketManager",
+        stop_event: threading.Event,
+    ):
+    # run_stream_controller needs to be driven by asyncio.run because of the async websocket send
+    asyncio.run(run_stream_controller(feature_queue, rawdata_queue, websocket_manager_features, stop_event))
+
 class PyNMState:
     def __init__(
         self,
-        default_init: bool = True,
+        default_init: bool = True,  # has to be True for the backend settings communication
     ) -> None:
         self.logger = logging.getLogger("uvicorn.error")

         self.lsl_stream_name = None
+        self.stream_controller_process = None
+        self.run_func_process = None

         if default_init:
             self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
             # TODO: we currently can pass the sampling_rate_features to both the stream and the settings?
-            self.settings: NMSettings = NMSettings(sampling_rate_features=17)
+            self.settings: NMSettings = NMSettings(sampling_rate_features=10)

-    async def start_run_function(
+    def start_run_function(
         self,
         out_dir: str = "",
         experiment_name: str = "sub",
         websocket_manager_features=None,
     ) -> None:
         # TODO: we should add a way to pass the output path and the folder name
         # Initialize the stream as a process, with a queue that is passed to it;
         # the stream will then put its results into the queue.
         # There should be another websocket through which the results are sent to the frontend.

         self.stream.settings = self.settings

-        self.stream_handling_queue = asyncio.Queue()
+        self.stream_handling_queue = queue.Queue()
+        self.feature_queue = queue.Queue()
+        self.rawdata_queue = queue.Queue()

-        self.logger.info("setup stream Process")
+        self.logger.info("Starting stream_controller_process thread")

-        self.stream.settings = self.settings
-
-        asyncio.create_task(self.stream.run(
-            out_dir=out_dir,
-            experiment_name=experiment_name,
-            stream_handling_queue=self.stream_handling_queue,
-            is_stream_lsl=self.lsl_stream_name is not None,
-            stream_lsl_name=self.lsl_stream_name
-            if self.lsl_stream_name is not None
-            else "",
-            websocket_featues=websocket_manager_features,
-        ))
+        # Stop event that is set in the app_backend
+        self.stop_event_ws = threading.Event()
+
+        self.stream_controller_thread = Thread(
+            target=run_stream_controller_sync,
+            daemon=True,
+            args=(self.feature_queue,
+                self.rawdata_queue,
+                websocket_manager_features,
+                self.stop_event_ws,
+            ),
+        )
+
+        is_stream_lsl = self.lsl_stream_name is not None
+        stream_lsl_name = self.lsl_stream_name if self.lsl_stream_name is not None else ""
+
+        # The run_func_thread is terminated through the stream_handling_queue,
+        # which breaks the data generator and triggers saving the features
+        self.run_func_thread = Thread(
+            target=self.stream.run,
+            daemon=True,
+            kwargs={
+                "out_dir": out_dir,
+                "experiment_name": experiment_name,
+                "stream_handling_queue": self.stream_handling_queue,
+                "is_stream_lsl": is_stream_lsl,
+                "stream_lsl_name": stream_lsl_name,
+                "feature_queue": self.feature_queue,
+                "simulate_real_time": True,
+                # "rawdata_queue": self.rawdata_queue,
+            },
+        )
+
+        self.stream_controller_thread.start()
+        self.run_func_thread.start()

     def setup_lsl_stream(
         self,
         lsl_stream_name: str | None = None,

@@ -123,11 +167,6 @@ def setup_offline_stream(
             target_keywords=None,
         )

-        # self.settings: NMSettings = NMSettings(
-        #     sampling_rate_features=sampling_rate_features
-        # )
-
-        # self.settings.preprocessing = []
         self.logger.info(f"settings: {self.settings}")
         self.stream: Stream = Stream(
             settings=self.settings,
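The controller thread drives an async coroutine with asyncio.run, the standard way to host async websocket sends inside a plain daemon thread. A self-contained sketch of the bridge pattern (generic names, not the package's API):

    import asyncio
    import queue
    import threading
    import time

    async def controller(q: queue.Queue, stop: threading.Event) -> None:
        # Poll the queue until the stop event is set, as run_stream_controller does.
        while not stop.wait(0.002):
            if not q.empty():
                item = q.get()
                await asyncio.sleep(0)  # stand-in for an async websocket send
                print("forwarded:", item)

    def controller_sync(q: queue.Queue, stop: threading.Event) -> None:
        # Thread targets must be plain functions, so enter asyncio here.
        asyncio.run(controller(q, stop))

    q: queue.Queue = queue.Queue()
    stop = threading.Event()
    t = threading.Thread(target=controller_sync, args=(q, stop), daemon=True)
    t.start()
    q.put({"feature": 1.0})
    time.sleep(0.05)  # give the controller a chance to forward the item
    stop.set()
    t.join()
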
96 changes: 32 additions & 64 deletions py_neuromodulation/stream/stream.py

@@ -1,6 +1,6 @@
 """Module for generic and offline data streams."""

-import asyncio
+import time
 from typing import TYPE_CHECKING
 from collections.abc import Iterator
 import numpy as np

@@ -11,6 +11,7 @@
 from py_neuromodulation.stream.data_processor import DataProcessor
 from py_neuromodulation.utils.types import _PathLike, FeatureName
+from py_neuromodulation.utils.file_writer import MsgPackFileWriter
 from py_neuromodulation.stream.settings import NMSettings

 if TYPE_CHECKING:

@@ -197,27 +198,28 @@ def _handle_data(self, data: "np.ndarray | pd.DataFrame") -> np.ndarray:
         )
         return data.to_numpy().transpose()

-    async def run(
+    def run(
         self,
         data: "np.ndarray | pd.DataFrame | None" = None,
         out_dir: _PathLike = "",
         experiment_name: str = "sub",
         is_stream_lsl: bool = False,
         stream_lsl_name: str | None = None,
-        save_csv: bool = False,
+        save_csv: bool = True,
+        save_interval: int = 10,
         return_df: bool = True,
-        # feature_queue: "multiprocessing.Queue | None" = None,
-        stream_handling_queue: "multiprocessing.Queue | None" = None,
-        websocket_featues: "WebSocketManager | None" = None,
+        simulate_real_time: bool = False,
+        feature_queue: "queue.Queue | None" = None,
+        stream_handling_queue: "queue.Queue | None" = None,
     ):
         self.is_stream_lsl = is_stream_lsl
         self.stream_lsl_name = stream_lsl_name
         self.stream_handling_queue = stream_handling_queue
-        # self.feature_queue = feature_queue
         self.save_csv = save_csv
+        self.save_interval = save_interval
         self.return_df = return_df
+        self.out_dir = Path.cwd() / experiment_name if not out_dir else Path(out_dir)
+        self.experiment_name = experiment_name

         # Validate input data
         if data is not None:

@@ -227,24 +229,10 @@ async def run(
         elif self.data is None and data is None and self.is_stream_lsl is False:
             raise ValueError("No data passed to run function.")

-        # Generate output dirs
-        self.out_dir_root = Path.cwd() if not out_dir else Path(out_dir)
-        self.out_dir = self.out_dir_root / experiment_name
-        # TONI: Need better default experiment name
-        self.experiment_name = experiment_name if experiment_name else "sub"
-
         self.out_dir.mkdir(parents=True, exist_ok=True)

-        # Open database connection
-        # TONI: we should give the user control over the save format
-        from py_neuromodulation.utils.database import NMDatabase
-
-        self.db = NMDatabase(experiment_name, out_dir)  # Create output database
+        file_writer = MsgPackFileWriter(name=experiment_name, out_dir=out_dir)

         self.batch_count: int = 0  # Keep track of the number of batches processed

-        # Reinitialize the data processor in case the nm_channels or nm_settings changed between runs of the same Stream
-        # TONI: then I think we can just not initialize the data processor in the init function
         self.data_processor = DataProcessor(
             sfreq=self.sfreq,
             settings=self.settings,

@@ -258,14 +246,6 @@
         nm.logger.log_to_file(out_dir)

-        # Initialize mp.Pool for multiprocessing
-        # self.pool = mp.Pool(processes=self.settings.n_jobs)
-        # Set up shared memory for multiprocessing
-        # self.shared_memory = mp.Array(ctypes.c_double, self.settings.n_jobs * self.settings.n_jobs)
-        # Set up multiprocessing semaphores
-        # self.semaphore = mp.Semaphore(self.settings.n_jobs)
-
-        # Initialize generator
         self.generator: Iterator
         if not is_stream_lsl:
             from py_neuromodulation.stream.generator import RawDataGenerator

@@ -299,16 +279,19 @@
         for timestamps, data_batch in self.generator:
             self.is_running = True
             if self.stream_handling_queue is not None:
-                nm.logger.info("Checking for stop signal")
-                # await asyncio.sleep(0.001)
-                await asyncio.sleep(1 / self.settings.sampling_rate_features_hz)
+                if simulate_real_time:
+                    time.sleep(1 / self.settings.sampling_rate_features_hz)
                 if not self.stream_handling_queue.empty():
-                    stop_signal = await asyncio.wait_for(self.stream_handling_queue.get(), timeout=0.01)
-                    if stop_signal == "stop":
+                    signal = self.stream_handling_queue.get()
+                    nm.logger.info(f"Got signal: {signal}")
+                    if signal == "stop":
                         break

             if data_batch is None:
+                nm.logger.info("Data batch is None, stopping run function")
                 break

-            nm.logger.info("Processing new data batch")
             feature_dict = self.data_processor.process(data_batch)

             this_batch_end = timestamps[-1]

@@ -318,11 +301,6 @@
             )

             feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1)
-            # (
-            #     np.ceil(batch_length)
-            #     if self.is_stream_lsl
-            #     else
-            # )

             prev_batch_end = this_batch_end

@@ -331,38 +309,30 @@
             self._add_target(feature_dict, data_batch)

+            # We should ensure that feature output is always either float64 or None and remove this
             with suppress(TypeError):  # Need this because some features output None
                 for key, value in feature_dict.items():
                     feature_dict[key] = np.float64(value)

-            self.db.insert_data(feature_dict)
+            file_writer.insert_data(feature_dict)
+            if feature_queue is not None:
+                feature_queue.put(feature_dict)

-            # if self.feature_queue is not None:
-            #     self.feature_queue.put(feature_dict)
-
-            if websocket_featues is not None:
-                nm.logger.info("Sending message to Websocket")
-                # nm.logger.info(feature_dict)
-                await websocket_featues.send_cbor(feature_dict)
-                # await websocket_featues.send_message(feature_dict)
             self.batch_count += 1
             if self.batch_count % self.save_interval == 0:
-                self.db.commit()
+                file_writer.save()

-        self.db.commit()  # Save last batches
-
-        # If save_csv is False, still save the first row to get the column names
-        feature_df: "pd.DataFrame" = (
-            self.db.fetch_all() if (self.save_csv or self.return_df) else self.db.head()
-        )
+        file_writer.save()
+        if self.save_csv:
+            file_writer.save_as_csv(save_all_combined=True)

-        self.db.close()  # Close the database connection
+        if self.return_df:
+            feature_df = file_writer.load_all()

-        self._save_after_stream(feature_arr=feature_df)
+        self._save_after_stream()
         self.is_running = False

-        return feature_df  # TONI: Not sure if this makes sense anymore
+        return feature_df  # Timon: We could think of returning the feature_reader instead

@@ -430,12 +400,9 @@ def plot_raw_signal(
     def _save_after_stream(
         self,
-        feature_arr: "pd.DataFrame | None" = None,
     ) -> None:
-        """Save features, settings, nm_channels and sidecar after run"""
+        """Save settings, nm_channels and sidecar after run"""
         self._save_sidecar()
-        if feature_arr is not None:
-            self._save_features(feature_arr)
         self._save_settings()
         self._save_channels()

@@ -455,6 +422,7 @@ def _save_sidecar(self) -> None:
         """Save sidecar including fs, coords, sess_right to
         out_path_root and subfolder 'folder_name'"""
         additional_args = {"sess_right": self.sess_right}
+
         self.data_processor.save_sidecar(
             self.out_dir, self.experiment_name, additional_args
        )
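Taken together, the refactored run() is now a plain synchronous function that writes MsgPack batches and is driven from a worker thread. A hedged sketch of calling it under the new signature; the data shape, sfreq, and combining these particular Stream constructor arguments are assumptions for illustration, and the kwargs are only those visible in this diff:

    import queue
    import threading
    import numpy as np
    from py_neuromodulation.stream import Stream, NMSettings

    settings = NMSettings(sampling_rate_features=10)  # value the GUI backend uses above
    stream = Stream(sfreq=1500, data=np.random.random([2, 15000]), settings=settings)

    stop_queue: queue.Queue = queue.Queue()
    features: queue.Queue = queue.Queue()

    worker = threading.Thread(
        target=stream.run,
        daemon=True,
        kwargs={
            "out_dir": "",  # resolves to Path.cwd() / experiment_name
            "experiment_name": "sub",
            "stream_handling_queue": stop_queue,
            "feature_queue": features,
            "simulate_real_time": True,
        },
    )
    worker.start()
    stop_queue.put("stop")  # later, e.g. from the GUI: request a stop
    worker.join()           # run() saves the remaining features and returns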