Skip to content

Commit

Permalink
- Refactor handling of async calls to websocket send data functions
Browse files Browse the repository at this point in the history
- Remove queues from Stream and instead create new StreamBackendInterface class to handle communication between Stream and PyNMState
- Improvements to WebsocketManager
- A bunch of typing fixes
  • Loading branch information
toni-neurosc committed Dec 5, 2024
1 parent 802db51 commit 7ca25dc
Show file tree
Hide file tree
Showing 12 changed files with 340 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def select_non_zero_voxels(
coord_arr = np.array(coord_)
ival_non_zero = ival_arr[ival != 0]
coord_non_zero = coord_arr[ival != 0]
print(coord_non_zero.shape)

return coord_non_zero, ival_non_zero

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ def write_connectome_mat(
dict_connectome[
f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")]
] = fp

print(idx)
# save the dictionary
sio.savemat(
PATH_CONNECTOME,
Expand Down
46 changes: 35 additions & 11 deletions py_neuromodulation/analysis/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,66 @@
import pandas as pd
import numpy as np
from copy import deepcopy
from pathlib import PurePath
from pathlib import Path
import pickle

from py_neuromodulation import logger
from py_neuromodulation.utils.types import _PathLike

from typing import Callable


class RealTimeDecoder:

def __init__(self, model_path: str):
self.model_path = model_path
if model_path.endswith(".skops"):
def __init__(self, model_path: _PathLike):
self.model_path = Path(model_path)
if not self.model_path.exists():
raise FileNotFoundError(f"Model file {self.model_path} not found")
if not self.model_path.is_file():
raise IsADirectoryError(f"Model file {self.model_path} is a directory")

if self.model_path.suffix == ".skops":
from skops import io as skops_io
self.model = skops_io.load(model_path)

self.model = skops_io.load(self.model_path)
else:
return NotImplementedError("Only skops models are supported")

def predict(self, feature_dict: dict, channel: str = None, fft_bands_only: bool = True ) -> dict:
def predict(
self,
feature_dict: dict,
channel: str | None = None,
fft_bands_only: bool = True,
) -> dict:
try:
if channel is not None:
features_ch = {f: feature_dict[f] for f in feature_dict.keys() if f.startswith(channel)}
features_ch = {
f: feature_dict[f]
for f in feature_dict.keys()
if f.startswith(channel)
}
if fft_bands_only is True:
features_ch_fft = {f: features_ch[f] for f in features_ch.keys() if "fft" in f and "psd" not in f}
out = self.model.predict_proba(np.array(list(features_ch_fft.values())).reshape(1, -1))
features_ch_fft = {
f: features_ch[f]
for f in features_ch.keys()
if "fft" in f and "psd" not in f
}
out = self.model.predict_proba(
np.array(list(features_ch_fft.values())).reshape(1, -1)
)
else:
out = self.model.predict_proba(features_ch)
else:
out = self.model.predict(feature_dict)
for decode_output_idx in range(out.shape[1]):
feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[decode_output_idx]
feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
decode_output_idx
]
return feature_dict
except Exception as e:
logger.error(f"Error in decoding: {e}")
return feature_dict


class CV_res:
def __init__(
self,
Expand Down
19 changes: 12 additions & 7 deletions py_neuromodulation/gui/backend/app_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import ValidationError

from . import app_pynm
from .app_socket import WebSocketManager
from .app_socket import WebsocketManager
from .app_utils import is_hidden, get_quick_access
import pandas as pd

Expand All @@ -29,7 +29,6 @@
class PyNMBackend(FastAPI):
def __init__(
self,
pynm_state: app_pynm.PyNMState,
debug=False,
dev=True,
dev_port: int | None = None,
Expand All @@ -48,7 +47,6 @@ def __init__(
cors_origins = (
["http://localhost:" + str(dev_port)] if dev_port is not None else []
)
print(cors_origins)
# Configure CORS
self.add_middleware(
CORSMiddleware,
Expand All @@ -69,8 +67,8 @@ def __init__(
name="static",
)

self.pynm_state = pynm_state
self.websocket_manager = WebSocketManager()
self.websocket_manager = WebsocketManager()
self.pynm_state = app_pynm.PyNMState()

def setup_routes(self):
@self.get("/api/health")
Expand Down Expand Up @@ -158,8 +156,7 @@ async def handle_stream_control(data: dict):

if action == "stop":
self.logger.info("Stopping stream")
self.pynm_state.stream_handling_queue.put("stop")
self.pynm_state.stop_event_ws.set()
self.pynm_state.stop_run_function()

return {"message": f"Stream action '{action}' executed"}

Expand Down Expand Up @@ -321,6 +318,14 @@ async def home_directory():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Get PYNM_DIR
@self.get("/api/pynm_dir")
async def get_pynm_dir():
try:
return {"pynm_dir": PYNM_DIR}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# Get list of available drives in Windows systems
@self.get("/api/drives")
async def list_drives():
Expand Down
2 changes: 0 additions & 2 deletions py_neuromodulation/gui/backend/app_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ def create_backend():
:return: The web application instance.
:rtype: PyNMBackend
"""
from .app_pynm import PyNMState
from .app_backend import PyNMBackend

debug = os.environ.get("PYNM_DEBUG", "False").lower() == "true"
dev = os.environ.get("PYNM_DEV", "True").lower() == "true"
dev_port = os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT))

return PyNMBackend(
pynm_state=PyNMState(),
debug=debug,
dev=dev,
dev_port=int(dev_port),
Expand Down
Loading

0 comments on commit 7ca25dc

Please sign in to comment.