From 2731d58457c9a93de2d5861393a37853969ce842 Mon Sep 17 00:00:00 2001 From: timonmerk Date: Tue, 26 Nov 2024 20:23:43 +0100 Subject: [PATCH] Add decoding model integration and update session store for model path management --- gui_dev/src/components/DecodingGraph.jsx | 336 ++++++++++++++++++ .../components/FileBrowser/FileBrowser.jsx | 2 +- gui_dev/src/pages/Dashboard.jsx | 31 +- .../pages/SourceSelection/DecodingModel.jsx | 48 +++ .../pages/SourceSelection/SourceSelection.jsx | 2 + gui_dev/src/stores/sessionStore.js | 4 + gui_dev/src/stores/socketStore.js | 34 +- py_neuromodulation/analysis/decode.py | 28 ++ py_neuromodulation/gui/backend/app_backend.py | 1 + py_neuromodulation/gui/backend/app_pynm.py | 13 +- py_neuromodulation/stream/stream.py | 6 + pyproject.toml | 1 + 12 files changed, 493 insertions(+), 13 deletions(-) create mode 100644 gui_dev/src/components/DecodingGraph.jsx create mode 100644 gui_dev/src/pages/SourceSelection/DecodingModel.jsx diff --git a/gui_dev/src/components/DecodingGraph.jsx b/gui_dev/src/components/DecodingGraph.jsx new file mode 100644 index 00000000..7765e1c9 --- /dev/null +++ b/gui_dev/src/components/DecodingGraph.jsx @@ -0,0 +1,336 @@ +import { useEffect, useRef, useState, useMemo } from "react"; +import { useSocketStore } from "@/stores"; +import { useSessionStore } from "@/stores/sessionStore"; +import Plotly from "plotly.js-basic-dist-min"; +import { + Box, + Typography, + FormControlLabel, + Checkbox, + Radio, + RadioGroup, + FormControl, + Slider, +} from "@mui/material"; +import { CollapsibleBox } from "./CollapsibleBox"; +import { getChannelAndFeature } from "./utils"; +import { shallow } from "zustand/shallow"; + +// TODO redundant and might be candidate for refactor +const generateColors = (numColors) => { + const colors = []; + for (let i = 0; i < numColors; i++) { + const hue = (i * 360) / numColors; + colors.push(`hsl(${hue}, 100%, 50%)`); + } + return colors; +}; + +export const DecodingGraph = ({ + title = "Decoding Output", + xAxisTitle = "Nr. of Samples", + yAxisTitle = "Value", +}) => { + //const graphData = useSocketStore((state) => state.graphData); + const graphDecodingData = useSocketStore((state) => state.graphDecodingData); + + //const channels = useSessionStore((state) => state.channels, shallow); + + //const usedChannels = useMemo( + // () => channels.filter((channel) => channel.used === 1), + // [channels] + //); + + //const availableChannels = useMemo( + // () => usedChannels.map((channel) => channel.name), + // [usedChannels] + //); + + const availableDecodingOutputs = useSocketStore((state) => state.availableDecodingOutputs); + + //const [selectedChannels, setSelectedChannels] = useState([]); + const [selectedDecodingOutputs, setSelectedDecodingOutputs] = useState(availableDecodingOutputs); + + const hasInitialized = useRef(false); + //const [rawData, setRawData] = useState({}); + const [decodingData, setDecodingData] = useState({}); + const graphRef = useRef(null); + const plotlyRef = useRef(null); + const [yAxisMaxValue, setYAxisMaxValue] = useState("Auto"); + const [maxDataPointsDecoding, setMaxDataPointsDecoding] = useState(10000); + + const layoutRef = useRef({ + // title: { + // text: title, + // font: { color: "#f4f4f4" }, + // }, + autosize: true, + height: 400, + paper_bgcolor: "#333", + plot_bgcolor: "#333", + margin: { + l: 50, + r: 50, + b: 50, + t: 0, + }, + xaxis: { + title: { + text: xAxisTitle, + font: { color: "#f4f4f4" }, + }, + color: "#cccccc", + autorange: "reversed", + }, + yaxis: { + // title: { + // text: yAxisTitle, + // font: { color: "#f4f4f4" }, + // }, + // color: "#cccccc", + }, + font: { + color: "#f4f4f4", + }, + }); + + // Handling the channel selection here -> TODO: see if this is better done in the socketStore + const handleDecodingOutputToggle = (decodingOutput) => { + setSelectedDecodingOutputs((prevSelected) => { + if (prevSelected.includes(decodingOutput)) { + return prevSelected.filter((name) => name !== decodingOutput); + } else { + return [...prevSelected, decodingOutput]; + } + }); + }; + + const handleYAxisMaxValueChange = (event) => { + setYAxisMaxValue(event.target.value); + }; + + const handleMaxDataPointsChangeDecoding = (event, newValue) => { + setMaxDataPointsDecoding(newValue); + }; + + //useEffect(() => { + // if (usedChannels.length > 0 && !hasInitialized.current) { + // const availableChannelNames = usedChannels.map((channel) => channel.name); + // setSelectedChannels(availableChannelNames); + // hasInitialized.current = true; + // } + //}, [usedChannels]); + + // Process incoming graphData to extract raw data for each channel -> TODO: Check later if this fits here better than socketStore + useEffect(() => { + // if (!graphData || Object.keys(graphData).length === 0) return; + if (!graphDecodingData || Object.keys(graphDecodingData).length === 0) return; + + //const latestData = graphData; + const latestData = graphDecodingData; + + setDecodingData((prevDecodingData) => { + const updatedDecodingData = { ...prevDecodingData }; + + Object.entries(latestData).forEach(([key, value]) => { + //const { channelName = "", featureName = "" } = getChannelAndFeature( + // availableChannels, + // key + //); + + //if (!channelName) return; + + //if (featureName !== "raw") return; + + // filter here for "decoding_xyz" --> this is the channelName + // availableDecodingOutputs might change --> this should lead to + + // check if value is in availableDecodingOutputs + // if not return; + + + const decodingOutput = key; + + if (!selectedDecodingOutputs.includes(key)) return; + + if (!updatedDecodingData[decodingOutput]) { + updatedDecodingData[decodingOutput] = []; + } + + updatedDecodingData[decodingOutput].push(value); + + if (updatedDecodingData[decodingOutput].length > maxDataPointsDecoding) { + updatedDecodingData[decodingOutput] = updatedDecodingData[decodingOutput].slice( + -maxDataPointsDecoding + ); + } + }); + + return updatedDecodingData; + }); + }, [graphDecodingData, availableDecodingOutputs, maxDataPointsDecoding]); + + useEffect(() => { + if (!graphRef.current) return; + + if (selectedDecodingOutputs.length === 0) { + Plotly.purge(graphRef.current); + return; + } + + const colors = generateColors(selectedDecodingOutputs.length); + + const totalDecodingOutputs = selectedDecodingOutputs.length; + const domainHeight = 1 / totalDecodingOutputs; + + const yAxes = {}; + const maxVal = yAxisMaxValue !== "Auto" ? Number(yAxisMaxValue) : null; + + selectedDecodingOutputs.forEach((decodingOutput, idx) => { + const start = 1 - (idx + 1) * domainHeight; + const end = 1 - idx * domainHeight; + + const yAxisKey = `yaxis${idx === 0 ? "" : idx + 1}`; + + yAxes[yAxisKey] = { + domain: [start, end], + nticks: 5, + tickfont: { + size: 10, + color: "#cccccc", + }, + // Titles necessary? Legend works but what if people are color blind? Rotate not supported! Annotations are a possibility though + // title: { + // text: channelName, + // font: { color: "#f4f4f4", size: 12 }, + // standoff: 30, + // textangle: -90, + // }, + color: "#cccccc", + automargin: true, + }; + + if (maxVal !== null) { + yAxes[yAxisKey].range = [-maxVal, maxVal]; + } + }); + + const traces = selectedDecodingOutputs.map((decodingOutput, idx) => { + const yData = decodingData[decodingOutput] || []; + const y = yData.slice().reverse(); + const x = Array.from({ length: y.length }, (_, i) => i); + + return { + x, + y, + type: "scatter", + mode: "lines", + name: decodingOutput, + line: { simplify: false, color: colors[idx] }, + yaxis: idx === 0 ? "y" : `y${idx + 1}`, + }; + }); + + const layout = { + ...layoutRef.current, + xaxis: { + ...layoutRef.current.xaxis, + autorange: "reversed", + range: [0, maxDataPointsDecoding], + domain: [0, 1], + anchor: totalDecodingOutputs === 1 ? "y" : `y${totalDecodingOutputs}`, + }, + ...yAxes, + height: 350, // TODO height autoadjust to screen + }; + + Plotly.react(graphRef.current, traces, layout, { + responsive: true, + displayModeBar: false, + }) + .then((gd) => { + plotlyRef.current = gd; + }) + .catch((error) => { + console.error("Plotly error:", error); + }); + }, [decodingData, selectedDecodingOutputs, yAxisMaxValue, maxDataPointsDecoding]); + + return ( + + + + {title} + + + + + {/* TODO: Fix the typing errors -> How to solve this in jsx? */} + + {availableDecodingOutputs.map((decodingOutput, index) => ( + handleDecodingOutputToggle(decodingOutput)} + color="primary" + /> + } + label={decodingOutput} + /> + ))} + + + + + + + + } label="Auto" /> + } label="1" /> + } label="5" /> + } label="10" /> + } label="20" /> + } label="50" /> + } label="100" /> + } label="500" /> + + + + + + + + Current Value: {maxDataPointsDecoding} + + + + + + + +
+
+ ); +}; diff --git a/gui_dev/src/components/FileBrowser/FileBrowser.jsx b/gui_dev/src/components/FileBrowser/FileBrowser.jsx index 278a4f4d..9cf315c3 100644 --- a/gui_dev/src/components/FileBrowser/FileBrowser.jsx +++ b/gui_dev/src/components/FileBrowser/FileBrowser.jsx @@ -178,7 +178,7 @@ export const FileBrowser = ({ if (file.is_directory) { dispatch({ type: "SET_CURRENT_PATH", payload: file.path }); } else if ( - ALLOWED_EXTENSIONS.some((ext) => file.name.toLowerCase().endsWith(ext)) + allowedExtensions.some((ext) => file.name.toLowerCase().endsWith(ext)) ) { onSelect(file); } diff --git a/gui_dev/src/pages/Dashboard.jsx b/gui_dev/src/pages/Dashboard.jsx index 23b0f9a5..f7b07dcc 100644 --- a/gui_dev/src/pages/Dashboard.jsx +++ b/gui_dev/src/pages/Dashboard.jsx @@ -1,5 +1,6 @@ import { RawDataGraph } from '@/components/RawDataGraph'; import { PSDGraph } from '@/components/PSDGraph'; +import { DecodingGraph } from '@/components/DecodingGraph'; import { HeatmapGraph } from '@/components/HeatmapGraph'; import { BandPowerGraph } from '@/components/BandPowerGraph'; import { Box, Button } from '@mui/material'; @@ -19,7 +20,7 @@ export const Dashboard = () => { return ( <> - + {/* */} { + - {/* PSDGraph */} - + + {/* PSDGraph */} + - {/* Bottom Row - HeatmapGraph and BandPowerGraph */} { + + + + {/* BandPowerGraph */} + + + {/* DecoddingGraph */} + + + + - ); diff --git a/gui_dev/src/pages/SourceSelection/DecodingModel.jsx b/gui_dev/src/pages/SourceSelection/DecodingModel.jsx new file mode 100644 index 00000000..1167b892 --- /dev/null +++ b/gui_dev/src/pages/SourceSelection/DecodingModel.jsx @@ -0,0 +1,48 @@ +import { TitledBox } from "@/components"; +import { MyTextField } from "@/components/utils"; +import { Button } from "@mui/material"; +import { useState } from "react"; +import { FileBrowser } from "@/components"; +import { useSessionStore } from "@/stores"; + +export const DecodingModel = () => { + + const decodingModelPath = useSessionStore((state) => state.decodingModelPath); + const setDecodingModelPath = useSessionStore((state) => state.setDecodingModelPath); + + const [showFileBrowser, setShowFileBrowser] = useState(false); + + const handleFileSelect = (file) => { + setDecodingModelPath(file.path); + setShowFileBrowser(false); + }; + + return ( + +
+ setDecodingModelPath(event.target.value)} + style={{ flexGrow: 1 }} + /> + +
+ {showFileBrowser && ( + setShowFileBrowser(false)} + onSelect={handleFileSelect} + allowedExtensions={[".skops"]} + /> + )} +
+ ); +}; \ No newline at end of file diff --git a/gui_dev/src/pages/SourceSelection/SourceSelection.jsx b/gui_dev/src/pages/SourceSelection/SourceSelection.jsx index 1a19fcd9..cf882582 100644 --- a/gui_dev/src/pages/SourceSelection/SourceSelection.jsx +++ b/gui_dev/src/pages/SourceSelection/SourceSelection.jsx @@ -4,6 +4,7 @@ import { Stack, Typography } from "@mui/material"; import { useSessionStore, WorkflowStage } from "@/stores"; import { LinkButton } from "@/components/utils"; import { StreamParameters } from "./StreamParameters"; +import { DecodingModel } from "./DecodingModel"; export const SourceSelection = () => { const setSourceType = useSessionStore((state) => state.setSourceType); @@ -44,6 +45,7 @@ export const SourceSelection = () => { + ({ selectedStream: null, availableStreams: [], }, + decodingModelPath: "None", + streamParameters: { samplingRate: 1000, lineNoise: 50, @@ -55,6 +57,7 @@ export const useSessionStore = createStore("session", (set, get) => ({ }, setSourceType: (type) => set({ sourceType: type }), + setDecodingModelPath: (path) => set({ decodingModelPath: path }), updateStreamParameter: (field, value) => set((state) => { state.streamParameters[field] = value; @@ -219,6 +222,7 @@ export const useSessionStore = createStore("session", (set, get) => ({ line_noise: streamParameters.lineNoise, experiment_name: streamParameters.experimentName, out_dir: streamParameters.outputDirectory, + decoding_path: get().decodingModelPath, }), }); diff --git a/gui_dev/src/stores/socketStore.js b/gui_dev/src/stores/socketStore.js index c638ba07..4e6d7264 100644 --- a/gui_dev/src/stores/socketStore.js +++ b/gui_dev/src/stores/socketStore.js @@ -10,7 +10,9 @@ export const useSocketStore = createStore("socket", (set, get) => ({ status: "disconnected", // 'disconnected', 'connecting', 'connected' error: null, graphData: [], - graphRawData : [], + graphRawData: [], + graphDecodingData: [], + availableDecodingOutputs: [], infoMessages: [], reconnectTimer: null, intentionalDisconnect: false, @@ -70,12 +72,34 @@ export const useSocketStore = createStore("socket", (set, get) => ({ const decodedData = CBOR.decode(arrayBuffer); // console.log("Decoded message from server:", decodedData); if (Object.keys(decodedData)[0] == "raw_data") { - set({graphRawData: decodedData.raw_data}); + set({ graphRawData: decodedData.raw_data }); } else { - set({graphData: decodedData}); + // check here if there are values in decodedData that start with "decoding" + // if so, set graphDecodingData to the value of those keys + // else, set graphData to decodedData + let decodingData = {}; + let dataNonDecodingFeatures = {}; + //for (const [key, value] of Object.entries(decodedData)) { + // if (key.startsWith("decode")) { + // decodingData[key] = value; + // } else { + // dataNonDecodingFeatures[key] = value; + // } + // } + + // check if this is the same: + Object.entries(decodedData).forEach(([key, value]) => { + (key.startsWith("decode") ? decodingData : dataNonDecodingFeatures)[key] = value; + }); + + set({ availableDecodingOutputs: Object.keys(decodingData) }); + + + + set({ graphDecodingData: decodingData }); + set({ graphData: dataNonDecodingFeatures }); + } - - } catch (error) { console.error("Failed to decode CBOR message:", error); } diff --git a/py_neuromodulation/analysis/decode.py b/py_neuromodulation/analysis/decode.py index 84d10644..4971852b 100644 --- a/py_neuromodulation/analysis/decode.py +++ b/py_neuromodulation/analysis/decode.py @@ -14,6 +14,34 @@ from typing import Callable +class RealTimeDecoder: + + def __init__(self, model_path: str): + self.model_path = model_path + if model_path.endswith(".skops"): + from skops import io as skops_io + self.model = skops_io.load(model_path) + else: + return NotImplementedError("Only skops models are supported") + + def predict(self, feature_dict: dict, channel: str = None, fft_bands_only: bool = True ) -> dict: + try: + if channel is not None: + features_ch = {f: feature_dict[f] for f in feature_dict.keys() if f.startswith(channel)} + if fft_bands_only is True: + features_ch_fft = {f: features_ch[f] for f in features_ch.keys() if "fft" in f and "psd" not in f} + out = self.model.predict_proba(np.array(list(features_ch_fft.values())).reshape(1, -1)) + else: + out = self.model.predict_proba(features_ch) + else: + out = self.model.predict(feature_dict) + for decode_output_idx in range(out.shape[1]): + feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[decode_output_idx] + return feature_dict + except Exception as e: + logger.error(f"Error in decoding: {e}") + return feature_dict + class CV_res: def __init__( self, diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py index 328bc34c..06b929f0 100644 --- a/py_neuromodulation/gui/backend/app_backend.py +++ b/py_neuromodulation/gui/backend/app_backend.py @@ -218,6 +218,7 @@ async def set_stream_params(data: dict): self.pynm_state.stream.sfreq = float(data["sampling_rate"]) self.pynm_state.experiment_name = data["experiment_name"] self.pynm_state.out_dir = data["out_dir"] + self.pynm_state.decoding_model_path = data["decoding_path"] return {"message": "Stream parameters updated successfully"} except ValueError: diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py index fc12f34a..523a4f62 100644 --- a/py_neuromodulation/gui/backend/app_pynm.py +++ b/py_neuromodulation/gui/backend/app_pynm.py @@ -1,3 +1,4 @@ +import os import asyncio import logging import threading @@ -6,6 +7,7 @@ from threading import Thread import queue from py_neuromodulation.stream import Stream, NMSettings +from py_neuromodulation.analysis.decode import RealTimeDecoder from py_neuromodulation.utils import set_channels from py_neuromodulation.utils.io import read_mne_data from py_neuromodulation import logger @@ -43,6 +45,8 @@ def __init__( self.run_func_process = None self.out_dir = None # will be set by set_stream_params self.experiment_name = None # will be set by set_stream_params + self.decoding_model_path = None + self.decoder = None if default_init: self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1])) @@ -62,7 +66,11 @@ def start_run_function( self.logger.info("Starting stream_controller_process thread") - + if self.decoding_model_path is not None and self.decoding_model_path != "None": + if os.path.exists(self.decoding_model_path): + self.decoder = RealTimeDecoder(self.decoding_model_path) + else: + logger.debug("Passed decoding model path does't exist") # Stop event # .set() is called from app_backend self.stop_event_ws = threading.Event() @@ -94,7 +102,8 @@ def start_run_function( "stream_lsl_name" : stream_lsl_name, "feature_queue" : self.feature_queue, "simulate_real_time" : True, - "rawdata_queue" : self.rawdata_queue, + "rawdata_queue" : self.rawdata_queue, + "decoder" : self.decoder, }, ) diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py index 8fdb9aec..2ee03421 100644 --- a/py_neuromodulation/stream/stream.py +++ b/py_neuromodulation/stream/stream.py @@ -13,6 +13,7 @@ from py_neuromodulation.utils.types import _PathLike, FeatureName from py_neuromodulation.utils.file_writer import MsgPackFileWriter from py_neuromodulation.stream.settings import NMSettings +from py_neuromodulation.analysis.decode import RealTimeDecoder if TYPE_CHECKING: import pandas as pd @@ -207,6 +208,7 @@ def run( save_interval: int = 10, return_df: bool = True, simulate_real_time: bool = False, + decoder: RealTimeDecoder = None, feature_queue: "queue.Queue | None" = None, rawdata_queue: "queue.Queue | None" = None, stream_handling_queue: "queue.Queue | None" = None, @@ -299,6 +301,10 @@ def run( f"{batch_length:.3f} seconds of new data processed", ) + if decoder is not None: + ch_to_decode = self.channels.query("used == 1").iloc[0]["name"] + feature_dict = decoder.predict(feature_dict, ch_to_decode, fft_bands_only=True) + feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1) prev_batch_end = this_batch_end diff --git a/pyproject.toml b/pyproject.toml index 32cea5a5..5facc8ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ dependencies = [ "cbor2>=5.6.4", "msgpack>=1.1.0", "multiprocess>=0.70.17", + "skops>=0.10.0", ] [project.optional-dependencies]