diff --git a/gui_dev/data_processor/src/lib.rs b/gui_dev/data_processor/src/lib.rs index d4ac3287..2681d64b 100644 --- a/gui_dev/data_processor/src/lib.rs +++ b/gui_dev/data_processor/src/lib.rs @@ -1,81 +1,137 @@ use wasm_bindgen::prelude::*; use serde_cbor::Value; -use serde_wasm_bindgen::to_value; +use serde_wasm_bindgen::{from_value, Serializer}; use serde::Serialize; -use std::collections::{BTreeMap, BTreeSet}; +use std::collections::{BTreeMap}; +use web_sys::console; #[wasm_bindgen] -pub fn process_cbor_data(data: &[u8]) -> JsValue { - match serde_cbor::from_slice::>(data) { - Ok(decoded_data) => { - let mut data_by_channel: BTreeMap = BTreeMap::new(); - let mut all_features_set: BTreeSet = BTreeSet::new(); +pub fn process_cbor_data(data: &[u8], channels_js: JsValue) -> JsValue { + // Deserialize channels_js into Vec + let channels: Vec = match from_value(channels_js) { + Ok(c) => c, + Err(err) => { + console::error_1(&format!("Failed to parse channels: {:?}", err).into()); + return JsValue::NULL; + } + }; - for (key, value) in decoded_data { - let (channel_name, feature_name) = get_channel_and_feature(&key); + match serde_cbor::from_slice::(data) { + Ok(decoded_value) => { + console::log_1(&format!("Decoded value: {:?}", decoded_value).into()); + if let Value::Map(decoded_map) = decoded_value { + // create output data structures for each graph + let mut psd_data_by_channel: BTreeMap = BTreeMap::new(); + let mut raw_data_by_channel: BTreeMap = BTreeMap::new(); + let mut bandwidth_data_by_channel: BTreeMap> = BTreeMap::new(); + let mut all_data: BTreeMap = BTreeMap::new(); - if channel_name.is_empty() { - continue; - } + let bandwidth_features = vec![ + "fft_theta_mean", + "fft_alpha_mean", + "fft_low_beta_mean", + "fft_high_beta_mean", + "fft_low_gamma_mean", + "fft_high_gamma_mean", + ]; - if !feature_name.starts_with("fft_psd_") { - continue; - } + for (key_value, value) in decoded_map { + let key_str = match key_value { + Value::Text(s) => s, + _ => continue, + }; - let feature_number = &feature_name["fft_psd_".len()..]; - let feature_index = match feature_number.parse::() { - Ok(n) => n, - Err(_) => continue, - }; + // Insert into all_data + all_data.insert(key_str.clone(), value.clone()); - all_features_set.insert(feature_index); + let (channel_name, feature_name) = + get_channel_and_feature(&key_str, &channels); - let channel_data = data_by_channel - .entry(channel_name.clone()) - .or_insert_with(|| ChannelData { - channel_name: channel_name.clone(), - feature_map: BTreeMap::new(), - }); + if channel_name.is_empty() { + continue; + } - channel_data.feature_map.insert(feature_index, value); - } + if feature_name == "raw" { + raw_data_by_channel.insert(channel_name.clone(), value.clone()); + } else if feature_name.starts_with("fft_psd_") { + let feature_number = &feature_name["fft_psd_".len()..]; + let feature_index = match feature_number.parse::() { + Ok(n) => n, + Err(_) => continue, + }; + + let feature_index_str = feature_index.to_string(); - let all_features: Vec = all_features_set.into_iter().collect(); + let channel_data = psd_data_by_channel + .entry(channel_name.clone()) + .or_insert_with(|| ChannelData { + channel_name: channel_name.clone(), + feature_map: BTreeMap::new(), + }); - let result = ProcessedData { - data_by_channel, - all_features, - }; + channel_data + .feature_map + .insert(feature_index_str, value.clone()); + } else if bandwidth_features.contains(&feature_name.as_str()) { - to_value(&result).unwrap_or(JsValue::NULL) + let channel_bandwidth_data = bandwidth_data_by_channel + .entry(channel_name.clone()) + .or_insert_with(BTreeMap::new); + + channel_bandwidth_data.insert(feature_name.clone(), value.clone()); + } + } + + let result = ProcessedData { + psd_data_by_channel, + raw_data_by_channel, + bandwidth_data_by_channel, + all_data, + }; + + // Serialize maps as plain JavaScript objects + let serializer = Serializer::new().serialize_maps_as_objects(true); + match result.serialize(&serializer) { + Ok(js_value) => js_value, + Err(err) => { + console::error_1(&format!("Serialization error: {:?}", err).into()); + JsValue::NULL + } + } + } else { + console::error_1(&"Decoded CBOR data is not a map.".into()); + JsValue::NULL + } } - Err(e) => { - // Optionally log the error for debugging + Err(err) => { + console::error_1(&format!("Failed to decode CBOR data: {:?}", err).into()); JsValue::NULL - }, + } } } -fn get_channel_and_feature(key: &str) -> (String, String) { - // Adjusted to split at the "_fft_psd_" pattern - let pattern = "_fft_psd_"; - if let Some(pos) = key.find(pattern) { - let channel_name = &key[..pos]; - let feature_name = &key[pos + 1..]; // Skip the underscore - (channel_name.to_string(), feature_name.to_string()) - } else { - ("".to_string(), key.to_string()) +fn get_channel_and_feature(key: &str, channels: &[String]) -> (String, String) { + // Iterate over channels to find if the key starts with any channel name + for channel in channels { + if key.starts_with(channel) { + let feature_name = key[channel.len()..].trim_start_matches('_'); + return (channel.clone(), feature_name.to_string()); + } } + // No matching channel found + ("".to_string(), key.to_string()) } #[derive(Serialize)] struct ChannelData { channel_name: String, - feature_map: BTreeMap, + feature_map: BTreeMap, } #[derive(Serialize)] struct ProcessedData { - data_by_channel: BTreeMap, - all_features: Vec, + psd_data_by_channel: BTreeMap, + raw_data_by_channel: BTreeMap, + bandwidth_data_by_channel: BTreeMap>, + all_data: BTreeMap, } diff --git a/gui_dev/src/components/BandPowerGraph.jsx b/gui_dev/src/components/BandPowerGraph.jsx index 7622389c..0663f8b7 100644 --- a/gui_dev/src/components/BandPowerGraph.jsx +++ b/gui_dev/src/components/BandPowerGraph.jsx @@ -10,7 +10,6 @@ import { FormControlLabel, } from "@mui/material"; import { CollapsibleBox } from "./CollapsibleBox"; -import { getChannelAndFeature } from "./utils"; import { shallow } from "zustand/shallow"; const generateColors = (numColors) => { @@ -47,45 +46,27 @@ export const BandPowerGraph = () => { const [selectedChannel, setSelectedChannel] = useState(""); const hasInitialized = useRef(false); - const socketData = useSocketStore((state) => state.graphData); + const processedData = useSocketStore((state) => state.processedData); const data = useMemo(() => { - if (!socketData || !selectedChannel) return null; - const dataByChannel = {}; + if (!processedData || !selectedChannel) return null; - Object.entries(socketData).forEach(([key, value]) => { - const { channelName = "", featureName = "" } = getChannelAndFeature( - availableChannels, - key - ); - if (!channelName) return; - - if (!fftFeatures.includes(featureName)) return; - - if (!dataByChannel[channelName]) { - dataByChannel[channelName] = { - channelName, - features: [], - values: [], - }; - } + const bandwidthDataByChannel = processedData.bandwidth_data_by_channel; - dataByChannel[channelName].features.push(featureName); - dataByChannel[channelName].values.push(value); - }); - - const channelData = dataByChannel[selectedChannel]; + const channelData = bandwidthDataByChannel[selectedChannel]; if (channelData) { - const sortedValues = fftFeatures.map((feature) => { - const index = channelData.features.indexOf(feature); - return index !== -1 ? channelData.values[index] : null; + const features = fftFeatures.map((f) => + f.replace("_mean", "").replace("fft_", "") + ); + const values = fftFeatures.map((feature) => { + const value = channelData[feature]; + return value !== undefined ? value : null; }); + return { channelName: selectedChannel, - features: fftFeatures.map((f) => - f.replace("_mean", "").replace("fft_", "") - ), - values: sortedValues, + features, + values, }; } else { return { @@ -96,7 +77,7 @@ export const BandPowerGraph = () => { values: fftFeatures.map(() => null), }; } - }, [socketData, selectedChannel, availableChannels]); + }, [processedData, selectedChannel]); const graphRef = useRef(null); const plotlyRef = useRef(null); @@ -169,7 +150,7 @@ export const BandPowerGraph = () => { Band Power - + { @@ -37,7 +37,7 @@ export const HeatmapGraph = () => { const [isDataStale, setIsDataStale] = useState(false); const [lastDataTime, setLastDataTime] = useState(null); - const graphData = useSocketStore((state) => state.graphData); + const processedData = useSocketStore((state) => state.processedData); const [maxDataPoints, setMaxDataPoints] = useState(100); @@ -83,9 +83,10 @@ export const HeatmapGraph = () => { }, [usedChannels, selectedChannel]); useEffect(() => { - if (!graphData || !selectedChannel) return; + if (!processedData || !selectedChannel) return; - const dataKeys = Object.keys(graphData); + const allData = processedData.all_data || {}; + const dataKeys = Object.keys(allData); const channelPrefix = `${selectedChannel}_`; const featureKeys = dataKeys.filter( (key) => key.startsWith(channelPrefix) && key !== 'time' @@ -113,17 +114,18 @@ export const HeatmapGraph = () => { setIsDataStale(false); setLastDataTime(null); } - }, [graphData, selectedChannel, features]); + }, [processedData, selectedChannel]); useEffect(() => { if ( - !graphData || + !processedData || !selectedChannel || features.length === 0 || selectedFeatures.length === 0 ) return; + const allData = processedData.all_data || {}; setLastDataTime(Date.now()); setIsDataStale(false); @@ -137,7 +139,7 @@ export const HeatmapGraph = () => { selectedFeatures.forEach((featureName, idx) => { const key = `${selectedChannel}_${featureName}`; - const value = graphData[key]; + const value = allData[key]; const numericValue = typeof value === 'number' && !isNaN(value) ? value : 0; // Shift existing data to the left if necessary @@ -155,7 +157,7 @@ export const HeatmapGraph = () => { setHeatmapData({ x, z }); }, [ - graphData, + processedData, selectedChannel, features, selectedFeatures, @@ -183,13 +185,12 @@ export const HeatmapGraph = () => { plot_bgcolor: '#333', autosize: true, xaxis: { - title: { text: 'Nr. of Samples', font: { color: '#f4f4f4' } }, + title: { text: 'Number of Samples', font: { color: '#f4f4f4' } }, color: '#cccccc', tickfont: { color: '#cccccc', }, automargin: false, - // autorange: 'reversed' }, yaxis: { title: { text: 'Features', font: { color: '#f4f4f4' } }, @@ -300,8 +301,8 @@ export const HeatmapGraph = () => { ]} layout={layout} useResizeHandler={true} - style={{ width: '100%', height: '100%'}} - config={{ responsive: true, displayModeBar: false}} + style={{ width: '100%', height: '100%' }} + config={{ responsive: true, displayModeBar: false }} /> )} diff --git a/gui_dev/src/components/PSDGraph.jsx b/gui_dev/src/components/PSDGraph.jsx index 97ab5dbf..836f7c7e 100644 --- a/gui_dev/src/components/PSDGraph.jsx +++ b/gui_dev/src/components/PSDGraph.jsx @@ -1,5 +1,5 @@ import { useEffect, useRef, useState, useMemo } from "react"; -import { useSocketStore } from "@/stores/socketStore"; +import { useSocketStore } from "@/stores/socketStore"; import { useSessionStore } from "@/stores/sessionStore"; import Plotly from "plotly.js-basic-dist-min"; import { @@ -8,20 +8,20 @@ import { FormControlLabel, Checkbox, } from "@mui/material"; -import { CollapsibleBox } from "./CollapsibleBox"; -import { shallow } from 'zustand/shallow'; +import { CollapsibleBox } from "./CollapsibleBox"; +import { shallow } from 'zustand/shallow'; const generateColors = (numColors) => { const colors = []; for (let i = 0; i < numColors; i++) { - const hue = (i * 360) / numColors; + const hue = (i * 360) / numColors; colors.push(`hsl(${hue}, 100%, 50%)`); } return colors; }; export const PSDGraph = () => { - const channels = useSessionStore((state) => state.channels, shallow); + const channels = useSessionStore((state) => state.channels, shallow); const usedChannels = useMemo( () => channels.filter((channel) => channel.used === 1), @@ -31,43 +31,49 @@ export const PSDGraph = () => { const availableChannels = useMemo( () => usedChannels.map((channel) => channel.name), [usedChannels] - ); + ); const [selectedChannels, setSelectedChannels] = useState([]); const hasInitialized = useRef(false); - - const psdProcessedData = useSocketStore((state) => state.psdProcessedData); - console.log(psdProcessedData); - - const psdData = useMemo(() => { - if (!psdProcessedData) return []; - const dataByChannel = psdProcessedData.data_by_channel || new Map(); - const allFeatures = psdProcessedData.all_features || []; + const processedData = useSocketStore((state) => state.processedData); + + const psdData = useMemo(() => { + if (!processedData || !selectedChannels.length) return []; + + const dataByChannel = processedData.psd_data_by_channel || {}; const selectedData = selectedChannels.map((channelName) => { - const channelData = dataByChannel.get(channelName); + const channelData = dataByChannel[channelName]; if (channelData) { - const values = allFeatures.map((featureIndex) => { - const value = channelData.feature_map.get(featureIndex); + const featureMap = channelData.feature_map || {}; + + // Extract feature indices and sort them numerically + const sortedFeatures = Object.keys(featureMap) + .map(Number) + .sort((a, b) => a - b); + + const values = sortedFeatures.map((featureIndex) => { + const value = featureMap[featureIndex.toString()]; return value !== undefined ? value : null; }); + return { channelName, - features: allFeatures, + features: sortedFeatures, values, }; } else { return { channelName, - features: allFeatures, - values: allFeatures.map(() => null), + features: [], + values: [], }; } }); return selectedData; - }, [psdProcessedData, selectedChannels]); + }, [processedData, selectedChannels]); const graphRef = useRef(null); const plotlyRef = useRef(null); @@ -84,7 +90,7 @@ export const PSDGraph = () => { useEffect(() => { if (usedChannels.length > 0 && !hasInitialized.current) { - const availableChannelNames = usedChannels.map((channel) => channel.name); + const availableChannelNames = usedChannels.map((channel) => channel.name); setSelectedChannels(availableChannelNames); hasInitialized.current = true; } @@ -104,7 +110,7 @@ export const PSDGraph = () => { paper_bgcolor: "#333", plot_bgcolor: "#333", xaxis: { - title: { text: "Feature Index", font: { color: "#f4f4f4" } }, + title: { text: "Frequency (Hz)", font: { color: "#f4f4f4" } }, color: "#cccccc", type: 'linear', }, @@ -138,7 +144,7 @@ export const PSDGraph = () => { .catch((error) => { console.error("Plotly error:", error); }); - }, [psdData, selectedChannels.length]); + }, [psdData]); return ( @@ -152,7 +158,7 @@ export const PSDGraph = () => { PSD Trace - + {usedChannels.map((channel, index) => ( @@ -172,10 +178,10 @@ export const PSDGraph = () => { - +
); diff --git a/gui_dev/src/components/RawDataGraph.jsx b/gui_dev/src/components/RawDataGraph.jsx index 23b4c336..d0f9f6d6 100644 --- a/gui_dev/src/components/RawDataGraph.jsx +++ b/gui_dev/src/components/RawDataGraph.jsx @@ -1,5 +1,5 @@ import { useEffect, useRef, useState, useMemo } from "react"; -import { useSocketStore } from "@/stores"; +import { useSocketStore } from "@/stores/socketStore"; import { useSessionStore } from "@/stores/sessionStore"; import Plotly from "plotly.js-basic-dist-min"; import { @@ -13,10 +13,8 @@ import { Slider, } from "@mui/material"; import { CollapsibleBox } from "./CollapsibleBox"; -import { getChannelAndFeature } from "./utils"; import { shallow } from "zustand/shallow"; -// TODO redundant and might be candidate for refactor const generateColors = (numColors) => { const colors = []; for (let i = 0; i < numColors; i++) { @@ -28,10 +26,10 @@ const generateColors = (numColors) => { export const RawDataGraph = ({ title = "Raw Data", - xAxisTitle = "Nr. of Samples", + xAxisTitle = "Number of Samples", yAxisTitle = "Value", }) => { - const graphData = useSocketStore((state) => state.graphData); + const processedData = useSocketStore((state) => state.processedData); const channels = useSessionStore((state) => state.channels, shallow); @@ -47,17 +45,13 @@ export const RawDataGraph = ({ const [selectedChannels, setSelectedChannels] = useState([]); const hasInitialized = useRef(false); - const [rawData, setRawData] = useState({}); + const [rawDataBuffer, setRawDataBuffer] = useState({}); const graphRef = useRef(null); const plotlyRef = useRef(null); const [yAxisMaxValue, setYAxisMaxValue] = useState("Auto"); const [maxDataPoints, setMaxDataPoints] = useState(100); const layoutRef = useRef({ - // title: { - // text: title, - // font: { color: "#f4f4f4" }, - // }, autosize: true, height: 400, paper_bgcolor: "#333", @@ -74,21 +68,12 @@ export const RawDataGraph = ({ font: { color: "#f4f4f4" }, }, color: "#cccccc", - // autorange: "reversed", - }, - yaxis: { - // title: { - // text: yAxisTitle, - // font: { color: "#f4f4f4" }, - // }, - // color: "#cccccc", }, font: { color: "#f4f4f4", }, }); - // Handling the channel selection here -> TODO: see if this is better done in the socketStore const handleChannelToggle = (channelName) => { setSelectedChannels((prevSelected) => { if (prevSelected.includes(channelName)) { @@ -115,25 +100,16 @@ export const RawDataGraph = ({ } }, [usedChannels]); - // Process incoming graphData to extract raw data for each channel -> TODO: Check later if this fits here better than socketStore useEffect(() => { - const startPSDGraph = performance.now(); - if (!graphData || Object.keys(graphData).length === 0) return; + if (!processedData || !processedData.raw_data_by_channel) return; - const latestData = graphData; + const latestRawData = processedData.raw_data_by_channel; - setRawData((prevRawData) => { + setRawDataBuffer((prevRawData) => { const updatedRawData = { ...prevRawData }; - Object.entries(latestData).forEach(([key, value]) => { - const { channelName = "", featureName = "" } = getChannelAndFeature( - availableChannels, - key - ); - - if (!channelName) return; - - if (featureName !== "raw") return; + Object.entries(latestRawData).forEach(([channelName, value]) => { + if (!availableChannels.includes(channelName)) return; if (!updatedRawData[channelName]) { updatedRawData[channelName] = []; @@ -150,9 +126,7 @@ export const RawDataGraph = ({ return updatedRawData; }); - const endPSDGraph = performance.now(); - console.log("Time taken to process data: ", endPSDGraph - startPSDGraph); - }, [graphData, availableChannels, maxDataPoints]); + }, [processedData, availableChannels, maxDataPoints]); useEffect(() => { if (!graphRef.current) return; @@ -200,7 +174,7 @@ export const RawDataGraph = ({ }); const traces = selectedChannels.map((channelName, idx) => { - const yData = rawData[channelName] || []; + const yData = rawDataBuffer[channelName] || []; const y = yData.slice().reverse(); const x = Array.from({ length: y.length }, (_, i) => i); @@ -219,13 +193,13 @@ export const RawDataGraph = ({ ...layoutRef.current, xaxis: { ...layoutRef.current.xaxis, - autorange: "reversed", + autorange: "reversed", range: [0, maxDataPoints], domain: [0, 1], anchor: totalChannels === 1 ? "y" : `y${totalChannels}`, }, ...yAxes, - height: 350, // TODO height autoadjust to screen + height: 350, }; Plotly.react(graphRef.current, traces, layout, { @@ -238,7 +212,7 @@ export const RawDataGraph = ({ .catch((error) => { console.error("Plotly error:", error); }); - }, [rawData, selectedChannels, yAxisMaxValue, maxDataPoints]); + }, [rawDataBuffer, selectedChannels, yAxisMaxValue, maxDataPoints]); return ( @@ -255,7 +229,6 @@ export const RawDataGraph = ({ - {/* TODO: Fix the typing errors -> How to solve this in jsx? */} {usedChannels.map((channel, index) => ( ({ socket: null, status: "disconnected", // 'disconnected', 'connecting', 'connected' error: null, - psdProcessedData: null, + processedData: null, infoMessages: [], reconnectTimer: null, intentionalDisconnect: false, @@ -72,15 +73,18 @@ export const useSocketStore = createStore("socket", (set, get) => ({ const arrayBuffer = event.data; const uint8Array = new Uint8Array(arrayBuffer); - // Ensure the WASM module is initialized await initWasm(); + const channels = useSessionStore.getState().channels.map( + (channel) => channel.name + ); + // Process CBOR data using Rust module - const processedData = process_cbor_data(uint8Array); + const processedData = process_cbor_data(uint8Array, channels); + console.log("Processed data:", processedData); // Set processed data in store - set({ psdProcessedData: processedData }); - console.log("PSD processed data:", processedData); + set({ processedData }); } catch (error) { console.error("Failed to process CBOR message:", error); }