diff --git a/gui_dev/package.json b/gui_dev/package.json index 99fb9b0c..6766b874 100644 --- a/gui_dev/package.json +++ b/gui_dev/package.json @@ -20,7 +20,7 @@ "react-dom": "next", "react-icons": "^5.3.0", "react-plotly.js": "^2.6.0", - "react-router-dom": "^6.28.0", + "react-router-dom": "^7.0.1", "zustand": "latest" }, "devDependencies": { diff --git a/gui_dev/src/App.jsx b/gui_dev/src/App.jsx index 7ebfdc55..a97f8c75 100644 --- a/gui_dev/src/App.jsx +++ b/gui_dev/src/App.jsx @@ -45,8 +45,6 @@ export const App = () => { const connectSocket = useSocketStore((state) => state.connectSocket); const disconnectSocket = useSocketStore((state) => state.disconnectSocket); - connectSocket(); - useEffect(() => { console.log("Connecting socket from App component..."); connectSocket(); diff --git a/gui_dev/src/components/FileBrowser/FileBrowser.jsx b/gui_dev/src/components/FileBrowser/FileBrowser.jsx index 9cf315c3..84b5a91b 100644 --- a/gui_dev/src/components/FileBrowser/FileBrowser.jsx +++ b/gui_dev/src/components/FileBrowser/FileBrowser.jsx @@ -1,5 +1,5 @@ import { useReducer, useEffect } from "react"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; import { Box, Button, @@ -36,7 +36,7 @@ import { import { QuickAccessSidebar } from "./QuickAccess"; import { FileManager } from "@/utils/FileManager"; -const fileManager = new FileManager(""); +const fileManager = new FileManager(getBackendURL("/api/files")); const ALLOWED_EXTENSIONS = [".npy", ".vhdr", ".fif", ".edf", ".bdf"]; diff --git a/gui_dev/src/components/FileBrowser/QuickAccess.jsx b/gui_dev/src/components/FileBrowser/QuickAccess.jsx index 84b70064..6e1253bb 100644 --- a/gui_dev/src/components/FileBrowser/QuickAccess.jsx +++ b/gui_dev/src/components/FileBrowser/QuickAccess.jsx @@ -1,5 +1,5 @@ import React, { useState, useEffect } from "react"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; import { Paper, Typography, diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx index 36aaf1dc..81f1b32b 100644 --- a/gui_dev/src/components/StatusBar/StatusBar.jsx +++ b/gui_dev/src/components/StatusBar/StatusBar.jsx @@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle"; import { SocketStatus } from "./SocketStatus"; import { WebviewStatus } from "./WebviewStatus"; -import { useWebviewStore } from "@/stores"; - +import { useUiStore, useWebviewStore } from "@/stores"; import { Stack } from "@mui/material"; export const StatusBar = () => { - const { isWebView } = useWebviewStore((state) => state.isWebView); + const isWebView = useWebviewStore((state) => state.isWebView); + const getStatusBarContent = useUiStore((state) => state.getStatusBarContent); + + const StatusBarContent = getStatusBarContent(); return ( { bgcolor="background.level1" borderTop="2px solid" borderColor="background.level3" + height="2rem" > - + {StatusBarContent && } + + {/* */} {/* Current experiment */} {/* Current stream */} {/* Current activity */} diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx index 56f0266b..a86b6009 100644 --- a/gui_dev/src/components/TitledBox.jsx +++ b/gui_dev/src/components/TitledBox.jsx @@ -1,4 +1,4 @@ -import { Container } from "@mui/material"; +import { Box, Container } from "@mui/material"; /** * Component that uses the Box component to render an HTML fieldset element @@ -13,14 +13,14 @@ export const TitledBox = ({ children, ...props }) => ( - {title} {children} - + ); diff --git a/gui_dev/src/main.jsx b/gui_dev/src/main.jsx index f4524069..c1cabb62 100644 --- a/gui_dev/src/main.jsx +++ b/gui_dev/src/main.jsx @@ -1,7 +1,15 @@ +// import { scan } from "react-scan"; +// scan({ +// enabled: true, +// log: true, // logs render info to console +// }); + import { StrictMode } from "react"; import ReactDOM from "react-dom/client"; import { App } from "./App.jsx"; +// Set up react-scan + // Ignore React 19 warning about accessing element.ref const originalConsoleError = console.error; console.error = (message, ...messageArgs) => { diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx deleted file mode 100644 index afa74b02..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.jsx +++ /dev/null @@ -1,65 +0,0 @@ -import { useRef } from "react"; -import styles from "./DragAndDropList.module.css"; -import { useOptionsStore } from "@/stores"; - -export const DragAndDropList = () => { - const { options, setOptions, addOption, removeOption } = useOptionsStore(); - - const predefinedOptions = [ - { id: 1, name: "raw_resampling" }, - { id: 2, name: "notch_filter" }, - { id: 3, name: "re_referencing" }, - { id: 4, name: "preprocessing_filter" }, - { id: 5, name: "raw_normalization" }, - ]; - - const dragOption = useRef(0); - const draggedOverOption = useRef(0); - - function handleSort() { - const optionsClone = [...options]; - const temp = optionsClone[dragOption.current]; - optionsClone[dragOption.current] = optionsClone[draggedOverOption.current]; - optionsClone[draggedOverOption.current] = temp; - setOptions(optionsClone); - } - - return ( -
-

List

- {options.map((option, index) => ( -
(dragOption.current = index)} - onDragEnter={() => (draggedOverOption.current = index)} - onDragEnd={handleSort} - onDragOver={(e) => e.preventDefault()} - > -

- {[option.id, ". ", option.name.replace("_", " ")]} -

- -
- ))} -
-

Add Elements

- {predefinedOptions.map((option) => ( - - ))} -
-
- ); -}; diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css deleted file mode 100644 index 76bd7861..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.module.css +++ /dev/null @@ -1,72 +0,0 @@ -.dragDropList { - max-width: 400px; - margin: 0 auto; - padding: 20px; -} - -.title { - text-align: center; - color: #333; - margin-bottom: 20px; -} - -.item { - background-color: #f0f0f0; - border-radius: 8px; - padding: 15px; - margin-bottom: 10px; - cursor: move; - transition: background-color 0.3s ease; - display: flex; - justify-content: space-between; - align-items: center; -} - -.item:hover { - background-color: #e0e0e0; -} - -.itemText { - margin: 0; - color: #333; - font-size: 16px; -} - -.removeButton { - background-color: #ff4d4d; - border: none; - border-radius: 4px; - color: white; - padding: 5px 10px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.removeButton:hover { - background-color: #ff1a1a; -} - -.addSection { - margin-top: 20px; - text-align: center; -} - -.subtitle { - color: #333; - margin-bottom: 10px; -} - -.addButton { - background-color: #4CAF50; - border: none; - border-radius: 4px; - color: white; - padding: 10px 15px; - margin: 5px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.addButton:hover { - background-color: #45a049; -} \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx deleted file mode 100644 index ad662eef..00000000 --- a/gui_dev/src/pages/Settings/Dropdown.jsx +++ /dev/null @@ -1,66 +0,0 @@ -import { useState } from "react"; -import "../App.css"; - -var stringJson = - '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}'; -const nm_settings = JSON.parse(stringJson); - -const filterByKeys = (dict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (typeof key === "string") { - // Top-level key - if (dict.hasOwnProperty(key)) { - filteredDict[key] = dict[key]; - } - } else if (typeof key === "object") { - // Nested key - const topLevelKey = Object.keys(key)[0]; - const nestedKeys = key[topLevelKey]; - if ( - dict.hasOwnProperty(topLevelKey) && - typeof dict[topLevelKey] === "object" - ) { - filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys); - } - } - }); - return filteredDict; -}; - -const Dropdown = ({ onChange, keysToInclude }) => { - const filteredSettings = filterByKeys(nm_settings, keysToInclude); - const [selectedOption, setSelectedOption] = useState(""); - - const handleChange = (event) => { - const newValue = event.target.value; - setSelectedOption(newValue); - onChange(keysToInclude, newValue); - }; - - return ( -
- -
- ); -}; - -export default Dropdown; diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx deleted file mode 100644 index 5def783b..00000000 --- a/gui_dev/src/pages/Settings/FrequencyRange.jsx +++ /dev/null @@ -1,88 +0,0 @@ -import { useState } from "react"; - -// const onChange = (key, newValue) => { -// settings.frequencyRanges[key] = newValue; -// }; - -// Object.entries(settings.frequencyRanges).map(([key, value]) => ( -// -// )); - -/** */ - -/** - * - * @param {String} key - * @param {Array} freqRange - * @param {Function} onChange - * @returns - */ -export const FrequencyRange = ({ settings }) => { - const [frequencyRanges, setFrequencyRanges] = useState(settings || {}); - // Handle changes in the text fields - const handleInputChange = (label, key, newValue) => { - setFrequencyRanges((prevState) => ({ - ...prevState, - [label]: { - ...prevState[label], - [key]: newValue, - }, - })); - }; - - // Add a new band - const addBand = () => { - const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`; - setFrequencyRanges((prevState) => ({ - ...prevState, - [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" }, - })); - }; - - // Remove a band - const removeBand = (label) => { - const updatedRanges = { ...frequencyRanges }; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }; - - return ( -
-
Frequency Bands
- {Object.keys(frequencyRanges).map((label) => ( -
- { - const newLabel = e.target.value; - const updatedRanges = { ...frequencyRanges }; - updatedRanges[newLabel] = updatedRanges[label]; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }} - placeholder="Band Name" - /> - - handleInputChange(label, "frequency_high_hz", e.target.value) - } - placeholder="High Hz" - /> - - handleInputChange(label, "frequency_low_hz", e.target.value) - } - placeholder="Low Hz" - /> - -
- ))} - -
- ); -}; diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css deleted file mode 100644 index a638ad93..00000000 --- a/gui_dev/src/pages/Settings/FrequencySettings.module.css +++ /dev/null @@ -1,67 +0,0 @@ -.container { - background-color: #f9f9f9; /* Light gray background for the container */ - padding: 20px; - border-radius: 10px; /* Rounded corners */ - box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */ - max-width: 600px; - margin: auto; - } - - .header { - font-size: 1.5rem; - color: #333; /* Darker text color */ - margin-bottom: 20px; - text-align: center; - } - - .bandContainer { - display: flex; - align-items: center; - margin-bottom: 15px; - padding: 10px; - border: 1px solid #ddd; /* Light border for each band */ - border-radius: 8px; /* Rounded corners for individual bands */ - background-color: #fff; /* White background for bands */ - box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */ - } - - .bandNameInput, .frequencyInput { - border: 1px solid #ccc; /* Light gray border */ - border-radius: 5px; /* Slightly rounded corners */ - padding: 8px; - margin-right: 10px; - font-size: 0.875rem; - } - - .bandNameInput::placeholder, .frequencyInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - } - - .removeButton, .addButton { - border: none; - border-radius: 5px; /* Rounded corners */ - padding: 8px 12px; - font-size: 0.875rem; - cursor: pointer; - transition: background-color 0.3s, color 0.3s; - } - - .removeButton { - background-color: #e57373; /* Light red color */ - color: white; - } - - .removeButton:hover { - background-color: #d32f2f; /* Darker red on hover */ - } - - .addButton { - background-color: #4caf50; /* Green color */ - color: white; - margin-top: 10px; - } - - .addButton:hover { - background-color: #388e3c; /* Darker green on hover */ - } - \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx index e2cf5f2a..69458cbc 100644 --- a/gui_dev/src/pages/Settings/Settings.jsx +++ b/gui_dev/src/pages/Settings/Settings.jsx @@ -1,19 +1,25 @@ +import { useEffect, useState } from "react"; import { + Box, Button, InputAdornment, + Popover, Stack, Switch, TextField, + Tooltip, Typography, } from "@mui/material"; import { Link } from "react-router-dom"; import { CollapsibleBox, TitledBox } from "@/components"; -import { FrequencyRange } from "./FrequencyRange"; -import { useSettingsStore } from "@/stores"; +import { + FrequencyRangeList, + FrequencyRange, +} from "./components/FrequencyRange"; +import { useSettingsStore, useStatusBar } from "@/stores"; import { filterObjectByKeys } from "@/utils/functions"; const formatKey = (key) => { - // console.log(key); return key .split("_") .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) @@ -21,15 +27,42 @@ const formatKey = (key) => { }; // Wrapper components for each type -const BooleanField = ({ value, onChange }) => ( - onChange(e.target.checked)} /> +const BooleanField = ({ label, value, onChange, error }) => ( + + {label} + onChange(e.target.checked)} /> + ); +const errorStyle = { + "& .MuiOutlinedInput-root": { + "& fieldset": { borderColor: "error.main" }, + "&:hover fieldset": { + borderColor: "error.main", + }, + "&.Mui-focused fieldset": { + borderColor: "error.main", + }, + }, +}; -const StringField = ({ value, onChange, label }) => ( - -); +const StringField = ({ label, value, onChange, error }) => { + const errorSx = error ? errorStyle : {}; + return ( + + {label} + onChange(e.target.value)} + label={label} + sx={{ ...errorSx }} + /> + + ); +}; + +const NumberField = ({ label, value, onChange, error, unit }) => { + const errorSx = error ? errorStyle : {}; -const NumberField = ({ value, onChange, label }) => { const handleChange = (event) => { const newValue = event.target.value; // Only allow numbers and decimal point @@ -39,126 +72,237 @@ const NumberField = ({ value, onChange, label }) => { }; return ( - - Hz - - ), - }} - inputProps={{ - pattern: "[0-9]*", - }} - /> + + {label} + + {unit}, + }} + inputProps={{ + pattern: "[0-9]*", + }} + /> + ); }; -const FrequencyRangeField = ({ value, onChange, label }) => ( - -); +const FrequencyRangeField = ({ label, value, onChange, error }) => { + console.log(label, value); + return ; +}; // Map component types to their respective wrappers const componentRegistry = { boolean: BooleanField, string: StringField, + int: NumberField, + float: NumberField, number: NumberField, FrequencyRange: FrequencyRangeField, }; -const SettingsField = ({ path, Component, label, value, onChange, depth }) => { +const SettingsField = ({ + path, + Component, + label, + value, + onChange, + error, + unit, +}) => { return ( - - {label} + onChange(path, newValue)} label={label} + error={error} + unit={unit} /> - + ); }; +// Function to get the error corresponding to this field or its children +const getFieldError = (fieldPath, errors) => { + if (!errors) return null; + + return errors.find((error) => { + const errorPath = error.loc.join("."); + const currentPath = fieldPath.join("."); + return errorPath === currentPath || errorPath.startsWith(currentPath + "."); + }); +}; + const SettingsSection = ({ settings, title = null, path = [], onChange, - depth = 0, + errors, }) => { - if (Object.keys(settings).length === 0) { - return null; - } const boxTitle = title ? title : formatKey(path[path.length - 1]); + const type = typeof settings; + const isObject = type === "object" && !Array.isArray(settings); + const isArray = Array.isArray(settings); - return ( - - {Object.entries(settings).map(([key, value]) => { - if (key === "__field_type__") return null; + // __field_type__ should be always present + if (isObject && !settings.__field_type__) { + throw new Error("Invalid settings object"); + } + const fieldType = isObject ? settings.__field_type__ : type; + const Component = componentRegistry[fieldType]; - const newPath = [...path, key]; - const label = key; - const isPydanticModel = - typeof value === "object" && "__field_type__" in value; + // Case 1: Object or primitive with component -> Don't iterate, render directly + if (Component) { + const value = + isObject && "__value__" in settings ? settings.__value__ : settings; + const unit = isObject && "__unit__" in settings ? settings.__unit__ : null; - const fieldType = isPydanticModel ? value.__field_type__ : typeof value; + return ( + + ); + } - const Component = componentRegistry[fieldType]; + // Case 2: Object without component or Array -> Iterate and render recursively + else { + return ( + + {/* Handle recursing through both objects and arrays */} + {(isArray ? settings : Object.entries(settings)).map((item, index) => { + const [key, value] = isArray ? [index.toString(), item] : item; + if (key.startsWith("__")) return null; // Skip metadata fields + + const newPath = [...path, key]; - if (Component) { - return ( - - ); - } else { return ( ); - } - })} - + })} + + ); + } +}; + +const StatusBarSettingsInfo = () => { + const validationErrors = useSettingsStore((state) => state.validationErrors); + const [anchorEl, setAnchorEl] = useState(null); + const open = Boolean(anchorEl); + + const handleOpenErrorsPopover = (event) => { + setAnchorEl(event.currentTarget); + }; + + const handleCloseErrorsPopover = () => { + setAnchorEl(null); + }; + + return ( + <> + {validationErrors?.length > 0 && ( + <> + + {validationErrors?.length} errors found in Settings + + + + {validationErrors.map((error, index) => ( + + {index} - [{error.type}] {error.msg} + + ))} + + + + )} + ); }; -const SettingsContent = () => { +export const Settings = () => { + // Get all necessary state from the settings store const settings = useSettingsStore((state) => state.settings); - const updateSettings = useSettingsStore((state) => state.updateSettings); + const uploadSettings = useSettingsStore((state) => state.uploadSettings); + const resetSettings = useSettingsStore((state) => state.resetSettings); + const validationErrors = useSettingsStore((state) => state.validationErrors); + useStatusBar(StatusBarSettingsInfo); + + // This is needed so that the frequency ranges stay in order between updates + const frequencyRangeOrder = useSettingsStore( + (state) => state.frequencyRangeOrder + ); + const updateFrequencyRangeOrder = useSettingsStore( + (state) => state.updateFrequencyRangeOrder + ); + // Here I handle the selected feature in the feature settings component + const [selectedFeature, setSelectedFeature] = useState(""); + + useEffect(() => { + uploadSettings(null, true); // validateOnly = true + }, [settings]); + + // Inject validation error info into status bar + + // This has to be after all the hooks, otherwise React will complain if (!settings) { return
Loading settings...
; } - const handleChange = (path, value) => { - updateSettings((settings) => { + // This are the callbacks for the different buttons + const handleChangeSettings = async (path, value) => { + uploadSettings((settings) => { let current = settings; for (let i = 0; i < path.length - 1; i++) { current = current[path[i]]; } current[path[path.length - 1]] = value; - }); + }, true); // validateOnly = true + }; + + const handleSaveSettings = () => { + uploadSettings(() => settings); + }; + + const handleResetSettings = async () => { + await resetSettings(); }; const featureSettingsKeys = Object.keys(settings.features) @@ -181,89 +325,147 @@ const SettingsContent = () => { "project_subcortex_settings", ]; + const generalSettingsKeys = [ + "sampling_rate_features_hz", + "segment_length_features_ms", + ]; + return ( - - + {/* SETTINGS LAYOUT */} + - - - - - + {/* GENERAL SETTINGS + FREQUENCY RANGES */} + + + {generalSettingsKeys.map((key) => ( + + ))} + + + + + + + + {/* POSTPROCESSING + PREPROCESSING SETTINGS */} + {preprocessingSettingsKeys.map((key) => ( ))} - + - + {postprocessingSettingsKeys.map((key) => ( ))} - - + - - {Object.entries(enabledFeatures).map(([feature, featureSettings]) => ( - - - - ))} - - - ); -}; + {/* FEATURE SETTINGS */} + + + + + + + {Object.entries(enabledFeatures).map( + ([feature, featureSettings]) => ( + + + + ) + )} + + + + {/* END SETTINGS LAYOUT */} +
-export const Settings = () => { - return ( - - - + + {/* */} + + + ); }; diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx deleted file mode 100644 index 86ab07b8..00000000 --- a/gui_dev/src/pages/Settings/TextField.jsx +++ /dev/null @@ -1,77 +0,0 @@ -import { useState, useEffect } from "react"; -import { - Box, - Grid, - TextField as MUITextField, - Typography, -} from "@mui/material"; -import { useSettingsStore } from "@/stores"; -import styles from "./TextField.module.css"; - -const flattenDictionary = (dict, parentKey = "", result = {}) => { - for (let key in dict) { - const newKey = parentKey ? `${parentKey}.${key}` : key; - if (typeof dict[key] === "object" && dict[key] !== null) { - flattenDictionary(dict[key], newKey, result); - } else { - result[newKey] = dict[key]; - } - } - return result; -}; - -const filterByKeys = (flatDict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (flatDict.hasOwnProperty(key)) { - filteredDict[key] = flatDict[key]; - } - }); - return filteredDict; -}; - -export const TextField = ({ keysToInclude }) => { - const settings = useSettingsStore((state) => state.settings); - const flatSettings = flattenDictionary(settings); - const filteredSettings = filterByKeys(flatSettings, keysToInclude); - const [textLabels, setTextLabels] = useState({}); - - useEffect(() => { - setTextLabels(filteredSettings); - }, [settings]); - - const handleTextFieldChange = (label, value) => { - setTextLabels((prevLabels) => ({ - ...prevLabels, - [label]: value, - })); - }; - - // Function to format the label - const formatLabel = (label) => { - const labelAfterDot = label.split(".").pop(); // Get everything after the last dot - return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces - }; - - return ( -
- {Object.keys(textLabels).map((label, index) => ( -
- - handleTextFieldChange(label, e.target.value)} - className={styles.textFieldInput} - /> -
- ))} -
- ); -}; diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css deleted file mode 100644 index 37791c40..00000000 --- a/gui_dev/src/pages/Settings/TextField.module.css +++ /dev/null @@ -1,67 +0,0 @@ -/* TextField.module.css */ - -/* Container for the text fields */ -.textFieldContainer { - display: flex; - flex-direction: column; - margin: 1.5rem 0; /* Increased margin for better spacing */ - } - - /* Row for each text field */ - .textFieldRow { - display: flex; - flex-direction: column; /* Stack label and input vertically */ - margin-bottom: 1rem; /* Increased margin for better separation */ - } - - /* Label for each text field */ - .textFieldLabel { - margin-bottom: 0.5rem; /* Space between label and input */ - font-weight: 600; /* Increased weight for better visibility */ - color: #333; /* Dark gray for the label */ - font-size: 1.1rem; /* Increased font size for the label */ - transition: all 0.2s ease; /* Smooth transition for label */ - } - - /* Input field styles */ - .textFieldInput { - padding: 12px 14px; /* Padding for a filled look */ - border: 1px solid #ccc; /* Light gray border */ - border-radius: 4px; /* Rounded corners */ - width: 100%; /* Full width */ - font-size: 1rem; /* Font size */ - background-color: #f5f5f5; /* Light background color for filled effect */ - transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */ - box-shadow: none; /* Remove default shadow */ - height: 48px; /* Fixed height for a more square appearance */ - } - - /* Focus styles for the input */ - .textFieldInput:focus { - border-color: #1976d2; /* Blue border color on focus */ - background-color: #fff; /* Change background to white on focus */ - outline: none; /* Remove default outline */ - } - - /* Hover effect for the input */ - .textFieldInput:hover { - border-color: #1976d2; /* Change border color on hover */ - } - - /* Placeholder styles */ - .textFieldInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - opacity: 1; /* Ensure placeholder is fully opaque */ - } - - /* Hide the number input spinners in webkit browsers */ - .textFieldInput::-webkit-inner-spin-button, - .textFieldInput::-webkit-outer-spin-button { - -webkit-appearance: none; /* Remove default styling */ - margin: 0; /* Remove margin */ - } - - /* Hide the number input spinners in Firefox */ - .textFieldInput[type='number'] { - -moz-appearance: textfield; /* Use textfield appearance */ - } \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/components/FrequencyRange.jsx b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx new file mode 100644 index 00000000..7fb50a23 --- /dev/null +++ b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx @@ -0,0 +1,198 @@ +import { useState } from "react"; +import { + TextField, + Button, + IconButton, + Stack, + Typography, +} from "@mui/material"; +import { Add, Close } from "@mui/icons-material"; +import { debounce } from "@/utils"; + +const NumberField = ({ ...props }) => ( + +); + +export const FrequencyRange = ({ + name, + range, + onChangeName, + onChangeRange, + error, + nameEditable = false, +}) => { + console.log(range); + const [localName, setLocalName] = useState(name); + + const debouncedChangeName = debounce((newName) => { + onChangeName(newName, name); + }, 1000); + + const handleNameChange = (e) => { + if (!nameEditable) return; + const newName = e.target.value; + setLocalName(newName); + debouncedChangeName(newName); + }; + + const handleNameBlur = () => { + if (!nameEditable) return; + onChangeName(localName, name); + }; + + const handleKeyPress = (e) => { + if (!nameEditable) return; + if (e.key === "Enter") { + console.log(e.target.value, name); + onChangeName(localName, name); + } + }; + + const handleRangeChange = (name, field, value) => { + // onChangeRange takes the name of the range as the first argument + onChangeRange(name, { ...range, [field]: value }); + }; + + return ( + + {nameEditable ? ( + + ) : ( + {name} + )} + + handleRangeChange(name, "frequency_low_hz", e.target.value) + } + label="Low Hz" + /> + + handleRangeChange(name, "frequency_high_hz", e.target.value) + } + label="High Hz" + /> + + ); +}; + +export const FrequencyRangeList = ({ + ranges, + rangeOrder, + onChange, + onOrderChange, + errors, +}) => { + const handleChangeRange = (name, newRange) => { + const updatedRanges = { ...ranges }; + updatedRanges[name] = newRange; + onChange(["frequency_ranges_hz"], updatedRanges); + }; + + const handleChangeName = (newName, oldName) => { + if (oldName === newName) { + return; + } + + const updatedRanges = { ...ranges, [newName]: ranges[oldName] }; + delete updatedRanges[oldName]; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = rangeOrder.map((name) => + name === oldName ? newName : name + ); + onOrderChange(updatedOrder); + }; + + const handleRemove = (name) => { + const updatedRanges = { ...ranges }; + delete updatedRanges[name]; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = rangeOrder.filter((item) => item !== name); + onOrderChange(updatedOrder); + }; + + const addRange = () => { + let newName = "NewRange"; + let counter = 0; + // Find first available name + while (Object.hasOwn(ranges, newName)) { + counter++; + newName = `NewRange${counter}`; + } + + const updatedRanges = { + ...ranges, + [newName]: { + __field_type__: "FrequencyRange", + frequency_low_hz: 1, + frequency_high_hz: 500, + }, + }; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = [...rangeOrder, newName]; + onOrderChange(updatedOrder); + }; + + return ( + + {rangeOrder.map((name, index) => ( + + + handleRemove(name)} + color="primary" + disableRipple + sx={{ m: 0, p: 0 }} + > + + + + ))} + + + ); +}; diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js deleted file mode 100644 index 16355023..00000000 --- a/gui_dev/src/pages/Settings/index.js +++ /dev/null @@ -1 +0,0 @@ -export { TextField } from './TextField'; diff --git a/gui_dev/src/pages/SourceSelection/FileSelector.jsx b/gui_dev/src/pages/SourceSelection/FileSelector.jsx index ffa53bab..e119265a 100644 --- a/gui_dev/src/pages/SourceSelection/FileSelector.jsx +++ b/gui_dev/src/pages/SourceSelection/FileSelector.jsx @@ -5,6 +5,8 @@ import { useSessionStore } from "@/stores"; import { FileBrowser, TitledBox } from "@/components"; +import { getPyNMDirectory } from "@/utils"; + export const FileSelector = () => { const fileSource = useSessionStore((state) => state.fileSource); const setFileSource = useSessionStore((state) => state.setFileSource); @@ -19,16 +21,22 @@ export const FileSelector = () => { ); const setSourceType = useSessionStore((state) => state.setSourceType); - const fileBrowserDirRef = useRef( - "C:\\code\\py_neuromodulation\\py_neuromodulation\\data\\sub-testsub\\ses-EphysMedOff\\ieeg\\sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr" - ); + const fileBrowserDirRef = useRef(""); const [isSelecting, setIsSelecting] = useState(false); const [showFileBrowser, setShowFileBrowser] = useState(false); + const [showFolderBrowser, setShowFolderBrowser] = useState(false); useEffect(() => { setSourceType("lsl"); - }, []); + + const fetchPyNMDirectory = async () => { + const pynmDir = await getPyNMDirectory(); + fileBrowserDirRef.current = + pynmDir + "\\data\\sub-testsub\\ses-EphysMedOff\\ieeg\\"; + }; + fetchPyNMDirectory(); + }, [setSourceType]); const handleFileSelect = (file) => { setIsSelecting(true); @@ -48,6 +56,10 @@ export const FileSelector = () => { } }; + const handleFolderSelect = (folder) => { + setShowFolderBrowser(false); + }; + return ( + {streamSetupMessage && ( { onSelect={handleFileSelect} /> )} + {showFolderBrowser && ( + setShowFolderBrowser(false)} + onSelect={handleFolderSelect} + onlyDirectories={true} + /> + )} ); }; diff --git a/gui_dev/src/stores/appInfoStore.js b/gui_dev/src/stores/appInfoStore.js index c7887fe4..773aa936 100644 --- a/gui_dev/src/stores/appInfoStore.js +++ b/gui_dev/src/stores/appInfoStore.js @@ -1,5 +1,5 @@ import { createStore } from "./createStore"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; export const useAppInfoStore = createStore("appInfo", (set) => ({ version: "", diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js index 762d9340..e683a35c 100644 --- a/gui_dev/src/stores/createStore.js +++ b/gui_dev/src/stores/createStore.js @@ -2,16 +2,23 @@ import { create } from "zustand"; import { immer } from "zustand/middleware/immer"; import { devtools, persist as persistMiddleware } from "zustand/middleware"; -export const createStore = (name, initializer, persist = false) => { +export const createStore = ( + name, + initializer, + persist = false, + dev = false +) => { const fn = persist ? persistMiddleware(immer(initializer), name) : immer(initializer); - return create( - devtools(fn, { - name: name, - }) - ); + const dev_fn = dev + ? devtools(fn, { + name: name, + }) + : fn; + + return create(dev_fn); }; export const createPersistStore = (name, initializer) => { diff --git a/gui_dev/src/stores/sessionStore.js b/gui_dev/src/stores/sessionStore.js index 959d7fcf..971e6314 100644 --- a/gui_dev/src/stores/sessionStore.js +++ b/gui_dev/src/stores/sessionStore.js @@ -3,7 +3,7 @@ // the data source, stream paramerters, the output files paths, etc import { createStore } from "@/stores/createStore"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; // Workflow stages enum-like object export const WorkflowStage = Object.freeze({ @@ -110,7 +110,6 @@ export const useSessionStore = createStore("session", (set, get) => ({ // Check that all stream parameters are valid checkStreamParameters: () => { - // const { samplingRate, lineNoise, samplingRateFeatures } = get(); set({ areParametersValid: get().streamParameters.samplingRate && diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js index 9b17f42d..73574815 100644 --- a/gui_dev/src/stores/settingsStore.js +++ b/gui_dev/src/stores/settingsStore.js @@ -1,36 +1,22 @@ -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; import { createStore } from "./createStore"; const INITIAL_DELAY = 3000; // wait for Flask const RETRY_DELAY = 1000; // ms const MAX_RETRIES = 100; -const uploadSettingsToServer = async (settings) => { - try { - const response = await fetch(getBackendURL("/api/settings"), { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(settings), - }); - if (!response.ok) { - throw new Error("Failed to update settings"); - } - return { success: true }; - } catch (error) { - console.error("Error updating settings:", error); - return { success: false, error }; - } -}; - export const useSettingsStore = createStore("settings", (set, get) => ({ settings: null, + lastValidSettings: null, + frequencyRangeOrder: [], isLoading: false, error: null, + validationErrors: null, retryCount: 0, - setSettings: (settings) => set({ settings }), + updateLocalSettings: (updater) => { + set((state) => updater(state.settings)); + }, fetchSettingsWithDelay: () => { set({ isLoading: true, error: null }); @@ -46,8 +32,15 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ if (!response.ok) { throw new Error("Failed to fetch settings"); } + const data = await response.json(); - set({ settings: data, retryCount: 0 }); + + set({ + settings: data, + lastValidSettings: data, + frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}), + retryCount: 0, + }); } catch (error) { console.log("Error fetching settings:", error); set((state) => ({ @@ -55,6 +48,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ retryCount: state.retryCount + 1, })); + console.log(get().retryCount); + if (get().retryCount < MAX_RETRIES) { await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY)); return get().fetchSettings(); @@ -66,29 +61,64 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ resetRetryCount: () => set({ retryCount: 0 }), - updateSettings: async (updater) => { - const currentSettings = get().settings; + resetSettings: async () => { + await get().fetchSettings(true); + }, - // Apply the update optimistically - set((state) => { - updater(state.settings); - }); + updateFrequencyRangeOrder: (newOrder) => { + set({ frequencyRangeOrder: newOrder }); + }, - const newSettings = get().settings; + uploadSettings: async (updater, validateOnly = false) => { + if (updater) { + set((state) => { + updater(state.settings); + }); + } + + const currentSettings = get().settings; try { - const result = await uploadSettingsToServer(newSettings); + const response = await fetch( + getBackendURL( + `/api/settings${validateOnly ? "?validate_only=true" : ""}` + ), + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(currentSettings), + } + ); + + const data = await response.json(); - if (!result.success) { - // Revert the local state if the server update failed - set({ settings: currentSettings }); + if (!response.ok) { + throw new Error("Failed to upload settings to backend"); } - return result; + if (data.valid) { + // Settings are valid + set({ + lastValidSettings: currentSettings, + validationErrors: null, + }); + return true; + } else { + // Settings are invalid + set({ + validationErrors: data.errors, + }); + // Note: We don't revert the settings here, keeping the potentially invalid state + return false; + } } catch (error) { - // Revert the local state if there was an error - set({ settings: currentSettings }); - throw error; + console.error( + `Error ${validateOnly ? "validating" : "updating"} settings:`, + error + ); + return false; } }, })); diff --git a/gui_dev/src/stores/socketStore.js b/gui_dev/src/stores/socketStore.js index 4e6d7264..8e091260 100644 --- a/gui_dev/src/stores/socketStore.js +++ b/gui_dev/src/stores/socketStore.js @@ -1,5 +1,5 @@ import { createStore } from "./createStore"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; import CBOR from "cbor-js"; const WEBSOCKET_URL = getBackendURL("/ws"); @@ -89,16 +89,15 @@ export const useSocketStore = createStore("socket", (set, get) => ({ // check if this is the same: Object.entries(decodedData).forEach(([key, value]) => { - (key.startsWith("decode") ? decodingData : dataNonDecodingFeatures)[key] = value; + (key.startsWith("decode") ? decodingData : dataNonDecodingFeatures)[ + key + ] = value; }); set({ availableDecodingOutputs: Object.keys(decodingData) }); - - set({ graphDecodingData: decodingData }); set({ graphData: dataNonDecodingFeatures }); - } } catch (error) { console.error("Failed to decode CBOR message:", error); diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js index a229ec06..0462871f 100644 --- a/gui_dev/src/stores/uiStore.js +++ b/gui_dev/src/stores/uiStore.js @@ -1,6 +1,7 @@ -import { createPersistStore } from "./createStore"; +import { createStore } from "./createStore"; +import { useEffect } from "react"; -export const useUiStore = createPersistStore("ui", (set, get) => ({ +export const useUiStore = createStore("ui", (set, get) => ({ activeDrawer: null, toggleDrawer: (drawerName) => set((state) => { @@ -28,4 +29,25 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({ state.accordionStates[id] = defaultState; } }), + + // Hook to inject UI elements into the status bar + getStatusBarContent: () => null, + setStatusBarContent: (contentGetter) => + set({ getStatusBarContent: contentGetter }), + clearStatusBarContent: () => set({ getStatusBarContent: () => null }), })); + +// Use this hook from Page components to inject page-specific UI elements into the status bar +export const useStatusBar = (content) => { + const createStatusBarContent = () => content; + + const setStatusBarContent = useUiStore((state) => state.setStatusBarContent); + const clearStatusBarContent = useUiStore( + (state) => state.clearStatusBarContent + ); + + useEffect(() => { + setStatusBarContent(createStatusBarContent); + return () => clearStatusBarContent(); + }, [content, setStatusBarContent, clearStatusBarContent]); +}; diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js index 5e2b5f11..cc99909c 100644 --- a/gui_dev/src/theme.js +++ b/gui_dev/src/theme.js @@ -22,6 +22,11 @@ export const theme = createTheme({ disableRipple: true, }, }, + MuiTextField: { + defaultProps: { + autoComplete: "off", + }, + }, MuiStack: { defaultProps: { alignItems: "center", diff --git a/gui_dev/src/utils/FileInfo.js b/gui_dev/src/utils/FileInfo.js new file mode 100644 index 00000000..74ae8347 --- /dev/null +++ b/gui_dev/src/utils/FileInfo.js @@ -0,0 +1,137 @@ +/** + * Represents information about a file or directory in the system + */ +export class FileInfo { + /** + * Creates a new FileInfo instance + * @param {Object} params - The parameters to initialize the FileInfo object + * @param {string} [params.name=''] - The name of the file or directory + * @param {string} [params.path=''] - The full path of the file or directory + * @param {string} [params.dir=''] - The directory containing the file or directory + * @param {boolean} [params.is_directory=false] - Whether the entry is a directory + * @param {number} [params.size=0] - The size of the file in bytes (0 for directories) + * @param {string} [params.created_at=''] - The creation timestamp of the file + * @param {string} [params.modified_at=''] - The last modification timestamp of the file + */ + constructor({ + name = "", + path = "", + dir = "", + is_directory = false, + size = 0, + created_at = "", + modified_at = "", + } = {}) { + this.name = name; + this.path = path; + this.dir = dir; + this.is_directory = is_directory; + this.size = size; + this.created_at = created_at; + this.modified_at = modified_at; + } + + /** + * Creates a FileInfo instance from a plain object + * @param {Object} obj - The object containing file information + * @returns {FileInfo} A new FileInfo instance + */ + static fromObject(obj) { + return new FileInfo(obj); + } + + /** + * Resets all properties to their default values + */ + reset() { + Object.assign(this, new FileInfo()); + } + + /** + * Updates the FileInfo instance with new values + * @param {Partial} updates - The properties to update + */ + update(updates) { + Object.assign(this, updates); + } + + /** + * Gets the file extension + * @returns {string} The file extension (empty string for directories) + */ + getExtension() { + if (this.is_directory) return ""; + const ext = this.name.split(".").pop(); + return ext === this.name ? "" : ext; + } + + /** + * Gets the base name without extension + * @returns {string} The base name + */ + getBaseName() { + if (this.is_directory) return this.name; + const lastDotIndex = this.name.lastIndexOf("."); + return lastDotIndex === -1 ? this.name : this.name.slice(0, lastDotIndex); + } + + /** + * Formats the file size in a human-readable format + * @returns {string} The formatted file size + */ + getFormattedSize() { + if (this.is_directory) return "-"; + const units = ["B", "KB", "MB", "GB", "TB"]; + let size = this.size; + let unitIndex = 0; + + while (size >= 1024 && unitIndex < units.length - 1) { + size /= 1024; + unitIndex++; + } + + return `${Math.round(size * 100) / 100} ${units[unitIndex]}`; + } + + /** + * Checks if the file/directory is hidden + * @returns {boolean} Whether the file/directory is hidden + */ + isHidden() { + return this.name.startsWith("."); + } + + /** + * Creates a plain object representation of the FileInfo instance + * @returns {Object} A plain object containing the file information + */ + toObject() { + return { + name: this.name, + path: this.path, + dir: this.dir, + is_directory: this.is_directory, + size: this.size, + created_at: this.created_at, + modified_at: this.modified_at, + }; + } + + /** + * Creates a clone of the FileInfo instance + * @returns {FileInfo} A new FileInfo instance with the same values + */ + clone() { + return new FileInfo(this.toObject()); + } + + /** + * Compares this FileInfo instance with another + * @param {FileInfo} other - The other FileInfo instance to compare with + * @returns {boolean} Whether the two instances have the same values + */ + equals(other) { + if (!(other instanceof FileInfo)) return false; + return JSON.stringify(this.toObject()) === JSON.stringify(other.toObject()); + } +} diff --git a/gui_dev/src/utils/FileManager.js b/gui_dev/src/utils/FileManager.js index 9065a3ad..aa625b5c 100644 --- a/gui_dev/src/utils/FileManager.js +++ b/gui_dev/src/utils/FileManager.js @@ -1,14 +1,4 @@ -/** - * @typedef {Object} FileInfo - * @property {string} name - The name of the file or directory - * @property {string} path - The full path of the file or directory - * @property {string} dir - The directory containing the file or directory - * @property {boolean} is_directory - Whether the entry is a directory - * @property {number} size - The size of the file in bytes (0 for directories) - * @property {string} created_at - The creation timestamp of the file - * @property {string} modified_at - The last modification timestamp of the file - */ -import { getBackendURL } from "@/utils/getBackendURL"; +import { FileInfo } from "./FileInfo"; /** * Manages file operations and interactions with the file API @@ -41,13 +31,14 @@ export class FileManager { show_hidden: showHidden, }); - const response = await fetch(getBackendURL(`${this.apiBaseUrl}/api/files?${queryParams}`)); + const response = await fetch(`${this.apiBaseUrl}?${queryParams}`); if (!response.ok) { throw new Error(`HTTP error! status: ${response.status}`); } - return await response.json(); + const filesData = await response.json(); + return filesData.map((fileData) => FileInfo.fromObject(fileData)); } /** diff --git a/gui_dev/src/utils/debounced_sync.js b/gui_dev/src/utils/debounced_sync.js index bdc2054f..0ad34469 100644 --- a/gui_dev/src/utils/debounced_sync.js +++ b/gui_dev/src/utils/debounced_sync.js @@ -1,5 +1,5 @@ import { debounce } from "@/utils"; -import { getBackendURL } from "@/utils/getBackendURL"; +import { getBackendURL } from "@/utils"; const DEBOUNCE_MS = 500; // Adjust as needed @@ -22,28 +22,28 @@ const syncWithBackend = async (state) => { const debouncedSync = debounce(syncWithBackend, DEBOUNCE_MS); - /*****************************/ - /******** BACKEND SYNC *******/ - /*****************************/ - - // Wrap state updates with sync logic - setState: async (newState) => { - set((state) => ({ ...state, ...newState, syncStatus: "syncing" })); - try { - await debouncedSync(get()); - set({ syncStatus: "synced", syncError: null }); - } catch (error) { - set({ syncStatus: "error", syncError: error.message }); - } - }, - - // // Use this for actions that need immediate sync - // setStateAndSync: async (newState) => { - // set((state) => ({ ...state, ...newState, syncStatus: "syncing" })); - // try { - // const syncedState = await syncWithBackend(get()); - // set({ ...syncedState, syncStatus: "synced", syncError: null }); - // } catch (error) { - // set({ syncStatus: "error", syncError: error.message }); - // } - // }, \ No newline at end of file +/*****************************/ +/******** BACKEND SYNC *******/ +/*****************************/ + +// Wrap state updates with sync logic +setState: async (newState) => { + set((state) => ({ ...state, ...newState, syncStatus: "syncing" })); + try { + await debouncedSync(get()); + set({ syncStatus: "synced", syncError: null }); + } catch (error) { + set({ syncStatus: "error", syncError: error.message }); + } +}, + +// // Use this for actions that need immediate sync +// setStateAndSync: async (newState) => { +// set((state) => ({ ...state, ...newState, syncStatus: "syncing" })); +// try { +// const syncedState = await syncWithBackend(get()); +// set({ ...syncedState, syncStatus: "synced", syncError: null }); +// } catch (error) { +// set({ syncStatus: "error", syncError: error.message }); +// } +// } \ No newline at end of file diff --git a/gui_dev/src/utils/getBackendURL.js b/gui_dev/src/utils/getBackendURL.js deleted file mode 100644 index b35c21de..00000000 --- a/gui_dev/src/utils/getBackendURL.js +++ /dev/null @@ -1,3 +0,0 @@ -export const getBackendURL = (route) => { - return "http://localhost:50001" + route; -} diff --git a/gui_dev/src/utils/index.js b/gui_dev/src/utils/index.js index f9810c70..f257e714 100644 --- a/gui_dev/src/utils/index.js +++ b/gui_dev/src/utils/index.js @@ -1 +1 @@ -export * from "./functions"; +export * from "./utils"; diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/utils.js similarity index 69% rename from gui_dev/src/utils/functions.js rename to gui_dev/src/utils/utils.js index 04289d5c..c81709d9 100644 --- a/gui_dev/src/utils/functions.js +++ b/gui_dev/src/utils/utils.js @@ -17,6 +17,7 @@ export function debounce(func, wait) { timeout = setTimeout(later, wait); }; } + export const flattenDictionary = (dict, parentKey = "", result = {}) => { for (let key in dict) { const newKey = parentKey ? `${parentKey}.${key}` : key; @@ -32,7 +33,7 @@ export const flattenDictionary = (dict, parentKey = "", result = {}) => { export const filterObjectByKeys = (flatDict, keys) => { const filteredDict = {}; keys.forEach((key) => { - if (flatDict.hasOwnProperty(key)) { + if (Object.hasOwn(flatDict, key)) { filteredDict[key] = flatDict[key]; } }); @@ -48,3 +49,22 @@ export const filterObjectByKeyPrefix = (obj, prefix = "") => { } return result; }; + +export const getBackendURL = (route) => { + return "http://localhost:" + import.meta.env.VITE_BACKEND_PORT + route; +}; + +/** + * Fetches PyNeuromodulation directory from the backend + * @returns {string} PyNeuromodulation directory + */ +export const getPyNMDirectory = async () => { + const response = await fetch(getBackendURL("/api/pynm_dir")); + if (!response.ok) { + throw new Error("Failed to fetch settings"); + } + + const data = await response.json(); + + return data.pynm_dir; +}; diff --git a/gui_dev/vite.config.js b/gui_dev/vite.config.js index f0abbec9..792593d7 100644 --- a/gui_dev/vite.config.js +++ b/gui_dev/vite.config.js @@ -48,8 +48,5 @@ export default defineConfig(() => { }, }, }, - server: { - port: 54321, - }, }; }); diff --git a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py index fa97e424..4f4409f6 100644 --- a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py +++ b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py @@ -70,7 +70,6 @@ def select_non_zero_voxels( coord_arr = np.array(coord_) ival_non_zero = ival_arr[ival != 0] coord_non_zero = coord_arr[ival != 0] - print(coord_non_zero.shape) return coord_non_zero, ival_non_zero diff --git a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py index 2238ff6b..89cd5b2e 100644 --- a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py +++ b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py @@ -58,8 +58,6 @@ def write_connectome_mat( dict_connectome[ f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")] ] = fp - - print(idx) # save the dictionary sio.savemat( PATH_CONNECTOME, diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py index 688393fd..864df345 100644 --- a/py_neuromodulation/__init__.py +++ b/py_neuromodulation/__init__.py @@ -4,11 +4,12 @@ from importlib.metadata import version from py_neuromodulation.utils.logging import NMLogger + ##################################### # Globals and environment variables # ##################################### -__version__ = version("py_neuromodulation") # get version from pyproject.toml +__version__ = version("py_neuromodulation") # Check if the module is running headless (no display) for tests and doc builds PYNM_HEADLESS: bool = not os.environ.get("DISPLAY") @@ -18,6 +19,7 @@ os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file + # Set environment variable MNE_LSL_LIB (required to import Stream below) LSL_DICT = { "windows_32bit": "windows/x86/liblsl.1.16.2.dll", @@ -36,6 +38,7 @@ PLATFORM = platform.system().lower().strip() ARCH = platform.architecture()[0] + match PLATFORM: case "windows": KEY = PLATFORM + "_" + ARCH diff --git a/py_neuromodulation/analysis/decode.py b/py_neuromodulation/analysis/decode.py index 4971852b..722f7192 100644 --- a/py_neuromodulation/analysis/decode.py +++ b/py_neuromodulation/analysis/decode.py @@ -6,42 +6,66 @@ import pandas as pd import numpy as np from copy import deepcopy -from pathlib import PurePath +from pathlib import Path import pickle from py_neuromodulation import logger +from py_neuromodulation.utils.types import _PathLike from typing import Callable class RealTimeDecoder: - - def __init__(self, model_path: str): - self.model_path = model_path - if model_path.endswith(".skops"): + def __init__(self, model_path: _PathLike): + self.model_path = Path(model_path) + if not self.model_path.exists(): + raise FileNotFoundError(f"Model file {self.model_path} not found") + if not self.model_path.is_file(): + raise IsADirectoryError(f"Model file {self.model_path} is a directory") + + if self.model_path.suffix == ".skops": from skops import io as skops_io - self.model = skops_io.load(model_path) + + self.model = skops_io.load(self.model_path) else: return NotImplementedError("Only skops models are supported") - def predict(self, feature_dict: dict, channel: str = None, fft_bands_only: bool = True ) -> dict: + def predict( + self, + feature_dict: dict, + channel: str | None = None, + fft_bands_only: bool = True, + ) -> dict: try: if channel is not None: - features_ch = {f: feature_dict[f] for f in feature_dict.keys() if f.startswith(channel)} + features_ch = { + f: feature_dict[f] + for f in feature_dict.keys() + if f.startswith(channel) + } if fft_bands_only is True: - features_ch_fft = {f: features_ch[f] for f in features_ch.keys() if "fft" in f and "psd" not in f} - out = self.model.predict_proba(np.array(list(features_ch_fft.values())).reshape(1, -1)) + features_ch_fft = { + f: features_ch[f] + for f in features_ch.keys() + if "fft" in f and "psd" not in f + } + out = self.model.predict_proba( + np.array(list(features_ch_fft.values())).reshape(1, -1) + ) else: out = self.model.predict_proba(features_ch) else: out = self.model.predict(feature_dict) for decode_output_idx in range(out.shape[1]): - feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[decode_output_idx] + feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[ + decode_output_idx + ] return feature_dict except Exception as e: logger.error(f"Error in decoding: {e}") return feature_dict + class CV_res: def __init__( self, diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index c50b5f8e..e68194f8 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -1,4 +1,15 @@ ---- +# We +# should +# have +# a +# brief +# explanation +# of +# the +# settings +# format +# here + ######################## ### General settings ### ######################## @@ -51,12 +62,8 @@ preprocessing_filter: lowpass_filter: true highpass_filter: true bandpass_filter: true - bandstop_filter_settings: - frequency_low_hz: 100 - frequency_high_hz: 160 - bandpass_filter_settings: - frequency_low_hz: 3 - frequency_high_hz: 200 + bandstop_filter_settings: [100, 160] # [low_hz, high_hz] + bandpass_filter_settings: [3, 200] # [hz, _hz] lowpass_filter_cutoff_hz: 200 highpass_filter_cutoff_hz: 3 @@ -162,11 +169,9 @@ sharpwave_analysis_settings: rise_steepness: false decay_steepness: false slope_ratio: false - filter_ranges_hz: - - frequency_low_hz: 5 - frequency_high_hz: 80 - - frequency_low_hz: 5 - frequency_high_hz: 30 + filter_ranges_hz: # list of [low_hz, high_hz] + - [5, 80] + - [5, 30] detect_troughs: estimate: true distance_troughs_ms: 10 @@ -175,6 +180,7 @@ sharpwave_analysis_settings: estimate: true distance_troughs_ms: 5 distance_peaks_ms: 10 + # TONI: Reverse this setting? e.g. interval: [mean, var] estimator: mean: [interval] median: [] @@ -223,8 +229,8 @@ nolds_settings: frequency_bands: [low_beta] mne_connectiviy_settings: - method: plv - mode: multitaper + method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr'] + mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet'] bispectrum_settings: f1s: [5, 35] diff --git a/py_neuromodulation/features/bandpower.py b/py_neuromodulation/features/bandpower.py index dac13f4b..72dab126 100644 --- a/py_neuromodulation/features/bandpower.py +++ b/py_neuromodulation/features/bandpower.py @@ -2,8 +2,13 @@ from collections.abc import Sequence from typing import TYPE_CHECKING from pydantic import field_validator - from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature +from py_neuromodulation.utils.pydantic_extensions import ( + NMField, + NMErrorList, + create_validation_error, +) +from py_neuromodulation import logger if TYPE_CHECKING: from py_neuromodulation.stream.settings import NMSettings @@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector): class BandPowerSettings(NMBaseModel): - segment_lengths_ms: dict[str, int] = { - "theta": 1000, - "alpha": 500, - "low beta": 333, - "high beta": 333, - "low gamma": 100, - "high gamma": 100, - "HFA": 100, - } + segment_lengths_ms: dict[str, int] = NMField( + default={ + "theta": 1000, + "alpha": 500, + "low beta": 333, + "high beta": 333, + "low gamma": 100, + "high gamma": 100, + "HFA": 100, + }, + custom_metadata={"field_type": "FrequencySegmentLength"}, + ) bandpower_features: BandpowerFeatures = BandpowerFeatures() log_transform: bool = True kalman_filter: bool = False @@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel): @field_validator("bandpower_features") @classmethod def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures): - assert ( - len(bandpower_features.get_enabled()) > 0 - ), "Set at least one bandpower_feature to True." - + if not len(bandpower_features.get_enabled()) > 0: + raise create_validation_error( + error_message="Set at least one bandpower_feature to True.", + location=["bandpass_filter_settings", "bandpower_features"], + ) return bandpower_features - def validate_fbands(self, settings: "NMSettings") -> None: + def validate_fbands(self, settings: "NMSettings") -> NMErrorList: + """_summary_ + + :param settings: _description_ + :type settings: NMSettings + :raises create_validation_error: _description_ + :raises create_validation_error: _description_ + :raises ValueError: _description_ + """ + errors: NMErrorList = NMErrorList() + for fband_name, seg_length_fband in self.segment_lengths_ms.items(): - assert seg_length_fband <= settings.segment_length_features_ms, ( - f"segment length {seg_length_fband} needs to be smaller than " - f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}" - ) + # Check that all frequency bands are defined in settings.frequency_ranges_hz + if fband_name not in settings.frequency_ranges_hz: + logger.warning( + f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms" + " is not defined in settings.frequency_ranges_hz" + ) + + # Check that all segment lengths are smaller than settings.segment_length_features_ms + if not seg_length_fband <= settings.segment_length_features_ms: + errors.add_error( + f"segment length {seg_length_fband} needs to be smaller than " + f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}", + location=[ + "bandpass_filter_settings", + "segment_lengths_ms", + fband_name, + ], + ) + # Check that all frequency bands defined in settings.frequency_ranges_hz for fband_name in settings.frequency_ranges_hz.keys(): - assert fband_name in self.segment_lengths_ms, ( - f"frequency range {fband_name} " - "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]" - ) + if fband_name not in self.segment_lengths_ms: + errors.add_error( + f"frequency range {fband_name} " + "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms", + location=[ + "bandpass_filter_settings", + "segment_lengths_ms", + fband_name, + ], + ) + + return errors class BandPower(NMFeature): diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py index 83d63bf8..e9f5b632 100644 --- a/py_neuromodulation/features/bursts.py +++ b/py_neuromodulation/features/bursts.py @@ -7,11 +7,12 @@ from collections.abc import Sequence from itertools import product -from pydantic import Field, field_validator +from pydantic import field_validator from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING, Callable -from py_neuromodulation.utils.types import create_validation_error +from py_neuromodulation.utils.pydantic_extensions import create_validation_error if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector): class BurstsSettings(NMBaseModel): - threshold: float = Field(default=75, ge=0, le=100) - time_duration_s: float = Field(default=30, ge=0) + threshold: float = NMField(default=75, ge=0) + time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"}) frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"] burst_features: BurstFeatures = BurstFeatures() diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py index 21ca471b..6a100710 100644 --- a/py_neuromodulation/features/coherence.py +++ b/py_neuromodulation/features/coherence.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Annotated -from pydantic import Field, field_validator +from pydantic import field_validator from py_neuromodulation.utils.types import ( NMFeature, @@ -11,6 +11,7 @@ FrequencyRange, NMBaseModel, ) +from py_neuromodulation.utils.pydantic_extensions import NMField from py_neuromodulation import logger if TYPE_CHECKING: @@ -28,14 +29,16 @@ class CoherenceFeatures(BoolSelector): max_allfbands: bool = True -ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)] +# TODO: make this into a pydantic model that only accepts names from +# the channels objects and does not accept the same string twice +ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)] class CoherenceSettings(NMBaseModel): features: CoherenceFeatures = CoherenceFeatures() method: CoherenceMethods = CoherenceMethods() channels: list[ListOfTwoStr] = [] - frequency_bands: list[str] = Field(default=["high_beta"], min_length=1) + frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1) @field_validator("frequency_bands") def fbands_spaces_to_underscores(cls, frequency_bands): diff --git a/py_neuromodulation/features/feature_processor.py b/py_neuromodulation/features/feature_processor.py index 9e970e8f..dbb7e5c9 100644 --- a/py_neuromodulation/features/feature_processor.py +++ b/py_neuromodulation/features/feature_processor.py @@ -1,13 +1,13 @@ from typing import Type, TYPE_CHECKING -from py_neuromodulation.utils.types import NMFeature, FeatureName +from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME if TYPE_CHECKING: import numpy as np from py_neuromodulation import NMSettings -FEATURE_DICT: dict[FeatureName | str, str] = { +FEATURE_DICT: dict[FEATURE_NAME | str, str] = { "raw_hjorth": "Hjorth", "return_raw": "Raw", "bandpass_filter": "BandPower", @@ -42,7 +42,7 @@ def __init__( from importlib import import_module # Accept 'str' for custom features - self.features: dict[FeatureName | str, NMFeature] = { + self.features: dict[FEATURE_NAME | str, NMFeature] = { feature_name: getattr( import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name] )(settings, ch_names, sfreq) @@ -83,7 +83,7 @@ def estimate_features(self, data: "np.ndarray") -> dict: return feature_results - def get_feature(self, fname: FeatureName) -> NMFeature: + def get_feature(self, fname: FEATURE_NAME) -> NMFeature: return self.features[fname] diff --git a/py_neuromodulation/features/fooof.py b/py_neuromodulation/features/fooof.py index 683c5304..9f4b9544 100644 --- a/py_neuromodulation/features/fooof.py +++ b/py_neuromodulation/features/fooof.py @@ -9,6 +9,7 @@ BoolSelector, FrequencyRange, ) +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector): class FooofSettings(NMBaseModel): aperiodic: FooofAperiodicSettings = FooofAperiodicSettings() periodic: FooofPeriodicSettings = FooofPeriodicSettings() - windowlength_ms: float = 800 + windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"}) peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12) - max_n_peaks: int = 3 - min_peak_height: float = 0 - peak_threshold: float = 2 + max_n_peaks: int = NMField(3, ge=0) + min_peak_height: float = NMField(0, ge=0) + peak_threshold: float = NMField(2, ge=0) freq_range_hz: FrequencyRange = FrequencyRange(2, 40) knee: bool = True diff --git a/py_neuromodulation/features/mne_connectivity.py b/py_neuromodulation/features/mne_connectivity.py index 88883771..f7a186e7 100644 --- a/py_neuromodulation/features/mne_connectivity.py +++ b/py_neuromodulation/features/mne_connectivity.py @@ -1,8 +1,9 @@ from collections.abc import Iterable import numpy as np -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from py_neuromodulation.utils.types import NMFeature, NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -10,9 +11,30 @@ from mne import Epochs +MNE_CONNECTIVITY_METHOD = Literal[ + "coh", + "cohy", + "imcoh", + "cacoh", + "mic", + "mim", + "plv", + "ciplv", + "ppc", + "pli", + "dpli", + "wpli", + "wpli2_debiased", + "gc", + "gc_tr", +] + +MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"] + + class MNEConnectivitySettings(NMBaseModel): - method: str = "plv" - mode: str = "multitaper" + method: MNE_CONNECTIVITY_METHOD = NMField(default="plv") + mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper") class MNEConnectivity(NMFeature): diff --git a/py_neuromodulation/features/oscillatory.py b/py_neuromodulation/features/oscillatory.py index ba376cc6..bb4b9e6a 100644 --- a/py_neuromodulation/features/oscillatory.py +++ b/py_neuromodulation/features/oscillatory.py @@ -3,6 +3,7 @@ from itertools import product from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -17,7 +18,7 @@ class OscillatoryFeatures(BoolSelector): class OscillatorySettings(NMBaseModel): - windowlength_ms: int = 1000 + windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"}) log_transform: bool = True features: OscillatoryFeatures = OscillatoryFeatures( mean=True, median=False, std=False, max=False diff --git a/py_neuromodulation/filter/kalman_filter.py b/py_neuromodulation/filter/kalman_filter.py index b7ae1075..0ebfe195 100644 --- a/py_neuromodulation/filter/kalman_filter.py +++ b/py_neuromodulation/filter/kalman_filter.py @@ -1,7 +1,9 @@ import numpy as np from typing import TYPE_CHECKING + from py_neuromodulation.utils.types import NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMErrorList if TYPE_CHECKING: @@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel): "HFA", ] - def validate_fbands(self, settings: "NMSettings") -> None: - assert all( + def validate_fbands(self, settings: "NMSettings") -> NMErrorList: + errors: NMErrorList = NMErrorList() + + if not all( (item in settings.frequency_ranges_hz for item in self.frequency_bands) - ), ( - "Frequency bands for Kalman filter must also be specified in " - "bandpass_filter_settings." - ) + ): + errors.add_error( + "Frequency bands for Kalman filter must also be specified in " + "frequency_ranges_hz.", + location=[ + "kalman_filter_settings", + "frequency_bands", + ], + ) + + return errors def define_KF(Tp, sigma_w, sigma_v): diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py index 06b929f0..f64e837d 100644 --- a/py_neuromodulation/gui/backend/app_backend.py +++ b/py_neuromodulation/gui/backend/app_backend.py @@ -1,5 +1,4 @@ import logging -import asyncio import importlib.metadata from datetime import datetime from pathlib import Path @@ -13,9 +12,10 @@ ) from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from pydantic import ValidationError from . import app_pynm -from .app_socket import WebSocketManager +from .app_socket import WebsocketManager from .app_utils import is_hidden, get_quick_access import pandas as pd @@ -29,9 +29,9 @@ class PyNMBackend(FastAPI): def __init__( self, - pynm_state: app_pynm.PyNMState, debug=False, dev=True, + dev_port: int | None = None, fastapi_kwargs: dict = {}, ) -> None: super().__init__(debug=debug, **fastapi_kwargs) @@ -43,14 +43,18 @@ def __init__( self.logger = logging.getLogger("uvicorn.error") self.logger.warning(PYNM_DIR) - # Configure CORS - self.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:54321"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) + if dev: + cors_origins = ( + ["http://localhost:" + str(dev_port)] if dev_port is not None else [] + ) + # Configure CORS + self.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) # Has to be before mounting static files self.setup_routes() @@ -63,8 +67,8 @@ def __init__( name="static", ) - self.pynm_state = pynm_state - self.websocket_manager = WebSocketManager() + self.websocket_manager = WebsocketManager() + self.pynm_state = app_pynm.PyNMState() def setup_routes(self): @self.get("/api/health") @@ -74,21 +78,63 @@ async def healthcheck(): #################### ##### SETTINGS ##### #################### - @self.get("/api/settings") - async def get_settings(): - return self.pynm_state.settings.process_for_frontend() + async def get_settings( + reset: bool = Query(False, description="Reset settings to default"), + ): + if reset: + settings = NMSettings.get_default() + else: + settings = self.pynm_state.settings + + return settings.serialize_with_metadata() @self.post("/api/settings") - async def update_settings(data: dict): + async def update_settings(data: dict, validate_only: bool = Query(False)): try: - self.pynm_state.settings = NMSettings.model_validate(data) - self.logger.info(self.pynm_state.settings.features) - return self.pynm_state.settings.model_dump() - except ValueError as e: + # First, validate with Pydantic + try: + # TODO: check if this works properly or needs model_validate_strings + validated_settings = NMSettings.model_validate(data) + except ValidationError as e: + self.logger.error(f"Error validating settings: {e}") + if not validate_only: + # If validation failed but we wanted to upload, return error + raise HTTPException( + status_code=422, + detail={ + "error": "Error validating settings", + "details": str(e), + }, + ) + # Else return list of errors + return { + "valid": False, + "errors": [err for err in e.errors()], + "details": str(e), + } + + # If validation succesful, return or update settings + if validate_only: + return { + "valid": True, + "settings": validated_settings.serialize_with_metadata(), + } + + self.pynm_state.settings = validated_settings + self.logger.info("Settings successfully updated") + + return { + "valid": True, + "settings": self.pynm_state.settings.serialize_with_metadata(), + } + + # If something else than validation went wrong, return error + except Exception as e: + self.logger.error(f"Error validating/updating settings: {e}") raise HTTPException( status_code=422, - detail={"error": "Validation failed", "details": str(e)}, + detail={"error": "Error uploading settings", "details": str(e)}, ) ######################## @@ -105,14 +151,12 @@ async def handle_stream_control(data: dict): self.logger.info("Starting stream") self.pynm_state.start_run_function( - websocket_manager=self.websocket_manager, + websocket_manager=self.websocket_manager, ) - if action == "stop": self.logger.info("Stopping stream") - self.pynm_state.stream_handling_queue.put("stop") - self.pynm_state.stop_event_ws.set() + self.pynm_state.stop_run_function() return {"message": f"Stream action '{action}' executed"} @@ -274,6 +318,14 @@ async def home_directory(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + # Get PYNM_DIR + @self.get("/api/pynm_dir") + async def get_pynm_dir(): + try: + return {"pynm_dir": PYNM_DIR} + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + # Get list of available drives in Windows systems @self.get("/api/drives") async def list_drives(): diff --git a/py_neuromodulation/gui/backend/app_manager.py b/py_neuromodulation/gui/backend/app_manager.py index b1cb6971..8bec1023 100644 --- a/py_neuromodulation/gui/backend/app_manager.py +++ b/py_neuromodulation/gui/backend/app_manager.py @@ -12,24 +12,40 @@ if TYPE_CHECKING: from multiprocessing.synchronize import Event - from .app_backend import PyNMBackend # Shared memory configuration ARRAY_SIZE = 1000 # Adjust based on your needs +SERVER_PORT = 50001 +DEV_SERVER_PORT = 54321 + -def create_backend() -> "PyNMBackend": - from .app_pynm import PyNMState +def create_backend(): + """Factory function passed to Uvicorn to create the web application instance. + + :return: The web application instance. + :rtype: PyNMBackend + """ from .app_backend import PyNMBackend debug = os.environ.get("PYNM_DEBUG", "False").lower() == "true" dev = os.environ.get("PYNM_DEV", "True").lower() == "true" + dev_port = os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT)) - return PyNMBackend(pynm_state=PyNMState(), debug=debug, dev=dev) + return PyNMBackend( + debug=debug, + dev=dev, + dev_port=int(dev_port), + ) -def run_vite(shutdown_event: "Event", debug: bool = False) -> None: +def run_vite( + shutdown_event: "Event", + debug: bool = False, + dev_port: int = DEV_SERVER_PORT, + backend_port: int = SERVER_PORT, +) -> None: """Run Vite in a separate shell""" import subprocess @@ -41,6 +57,8 @@ def run_vite(shutdown_event: "Event", debug: bool = False) -> None: logging.DEBUG if debug else logging.INFO, ) + os.environ["VITE_BACKEND_PORT"] = str(backend_port) + def output_reader(shutdown_event: "Event", process: subprocess.Popen): logger.debug("Initialized output stream") color = ansi_color(color="magenta", bright=True, styles=["BOLD"]) @@ -73,7 +91,7 @@ def read_stream(stream, stream_name): subprocess_flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0 process = subprocess.Popen( - "bun run dev", + "bun run dev --port " + str(dev_port), cwd="gui_dev", stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -106,7 +124,9 @@ def read_stream(stream, stream_name): logger.info("Development server stopped") -def run_uvicorn(debug: bool = False, reload=False) -> None: +def run_uvicorn( + debug: bool = False, reload=False, server_port: int = SERVER_PORT +) -> None: from uvicorn.server import Server from uvicorn.config import LOGGING_CONFIG, Config @@ -131,7 +151,7 @@ def run_uvicorn(debug: bool = False, reload=False) -> None: host="localhost", reload=reload, factory=True, - port=50001, + port=server_port, log_level="debug" if debug else "info", log_config=log_config, ) @@ -160,17 +180,23 @@ def restart(self) -> None: def run_backend( - shutdown_event: "Event", debug: bool = False, reload: bool = True, dev: bool = True + shutdown_event: "Event", + dev: bool = True, + debug: bool = False, + reload: bool = True, + server_port: int = SERVER_PORT, + dev_port: int = DEV_SERVER_PORT, ) -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) # Pass create_backend parameters through environment variables os.environ["PYNM_DEBUG"] = str(debug) os.environ["PYNM_DEV"] = str(dev) + os.environ["PYNM_DEV_PORT"] = str(dev_port) server_process = mp.Process( target=run_uvicorn, - kwargs={"debug": debug, "reload": reload}, + kwargs={"debug": debug, "reload": reload, "server_port": server_port}, name="Server", ) server_process.start() @@ -182,7 +208,12 @@ class AppManager: LAUNCH_FLAG = "PYNM_RUNNING" def __init__( - self, debug: bool = False, dev: bool = True, run_in_webview=False + self, + debug: bool = False, + dev: bool = True, + run_in_webview=False, + server_port=SERVER_PORT, + dev_port=DEV_SERVER_PORT, ) -> None: """_summary_ @@ -197,6 +228,9 @@ def __init__( self.debug = debug self.dev = dev self.run_in_webview = run_in_webview + self.server_port = server_port + self.dev_port = dev_port + self._reset() # Prevent launching multiple instances of the app due to multiprocessing # This allows the absence of a main guard in the main script @@ -270,7 +304,12 @@ def launch(self) -> None: self.logger.info("Starting Vite server...") self.tasks["vite"] = mp.Process( target=run_vite, - kwargs={"shutdown_event": self.shutdown_event, "debug": self.debug}, + kwargs={ + "shutdown_event": self.shutdown_event, + "debug": self.debug, + "dev_port": self.dev_port, + "backend_port": self.server_port, + }, name="Vite", ) @@ -282,6 +321,8 @@ def launch(self) -> None: "debug": self.debug, "reload": self.dev, "dev": self.dev, + "server_port": self.server_port, + "dev_port": self.dev_port, }, name="Backend", ) diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py index 523a4f62..23641786 100644 --- a/py_neuromodulation/gui/backend/app_pynm.py +++ b/py_neuromodulation/gui/backend/app_pynm.py @@ -1,152 +1,145 @@ import os -import asyncio -import logging -import threading import numpy as np -import multiprocessing as mp from threading import Thread -import queue +import time +import asyncio +import multiprocessing as mp +from queue import Empty +from pathlib import Path from py_neuromodulation.stream import Stream, NMSettings from py_neuromodulation.analysis.decode import RealTimeDecoder from py_neuromodulation.utils import set_channels from py_neuromodulation.utils.io import read_mne_data +from py_neuromodulation.utils.types import _PathLike +from py_neuromodulation import logger +from py_neuromodulation.gui.backend.app_socket import WebsocketManager +from py_neuromodulation.stream.backend_interface import StreamBackendInterface from py_neuromodulation import logger -async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue.Queue, - websocket_manager: "WebSocketManager", stop_event: threading.Event): - while not stop_event.wait(0.002): - if not feature_queue.empty() and websocket_manager is not None: - feature_dict = feature_queue.get() - logger.info("Sending message to Websocket") - await websocket_manager.send_cbor(feature_dict) - - if not rawdata_queue.empty() and websocket_manager is not None: - raw_data = rawdata_queue.get() - - await websocket_manager.send_cbor(raw_data) - -def run_stream_controller_sync(feature_queue: queue.Queue, - rawdata_queue: queue.Queue, - websocket_manager: "WebSocketManager", - stop_event: threading.Event - ): - # The run_stream_controller needs to be started as an asyncio function due to the async websocket - asyncio.run(run_stream_controller(feature_queue, rawdata_queue, websocket_manager, stop_event)) class PyNMState: def __init__( self, - default_init: bool = True, # has to be true for the backend settings communication ) -> None: - self.logger = logging.getLogger("uvicorn.error") + self.lsl_stream_name: str = "" + self.experiment_name: str = "PyNM_Experiment" # set by set_stream_params + self.out_dir: _PathLike = str( + Path.home() / "PyNM" / self.experiment_name + ) # set by set_stream_params + self.decoding_model_path: _PathLike | None = None + self.decoder: RealTimeDecoder | None = None + self.is_stream_lsl: bool = False - self.lsl_stream_name = None - self.stream_controller_process = None - self.run_func_process = None - self.out_dir = None # will be set by set_stream_params - self.experiment_name = None # will be set by set_stream_params - self.decoding_model_path = None - self.decoder = None + self.backend_interface: StreamBackendInterface | None = None + self.websocket_manager: WebsocketManager | None = None - if default_init: - self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1])) - self.settings: NMSettings = NMSettings(sampling_rate_features=10) + # Note: sfreq and data are required for stream init + self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1])) + self.settings: NMSettings = NMSettings(sampling_rate_features=10) + self.feature_queue = mp.Queue() + self.rawdata_queue = mp.Queue() + self.control_queue = mp.Queue() + self.stop_event = asyncio.Event() + + self.messages_sent = 0 def start_run_function( self, - websocket_manager=None, + websocket_manager: WebsocketManager | None = None, ) -> None: - - self.stream.settings = self.settings - - self.stream_handling_queue = queue.Queue() - self.feature_queue = queue.Queue() - self.rawdata_queue = queue.Queue() + # TONI: This is dangerous to do here, should be done by the setup functions + # self.stream.settings = self.settings - self.logger.info("Starting stream_controller_process thread") + self.is_stream_lsl = self.lsl_stream_name is not None + self.websocket_manager = websocket_manager + # Create decoder if self.decoding_model_path is not None and self.decoding_model_path != "None": if os.path.exists(self.decoding_model_path): self.decoder = RealTimeDecoder(self.decoding_model_path) else: - logger.debug("Passed decoding model path does't exist") - # Stop event - # .set() is called from app_backend - self.stop_event_ws = threading.Event() - - self.stream_controller_thread = Thread( - target=run_stream_controller_sync, - daemon=True, - args=(self.feature_queue, - self.rawdata_queue, - websocket_manager, - self.stop_event_ws - ), - ) + logger.debug("Passed decoding model path does't exist") + + # Initialize the backend interface if not already done + if not self.backend_interface: + self.backend_interface = StreamBackendInterface( + self.feature_queue, self.rawdata_queue, self.control_queue + ) - is_stream_lsl = self.lsl_stream_name is not None - stream_lsl_name = self.lsl_stream_name if self.lsl_stream_name is not None else "" - # The run_func_thread is terminated through the stream_handling_queue # which initiates to break the data generator and save the features - out_dir = "" if self.out_dir == "default" else self.out_dir - self.run_func_thread = Thread( + stream_process = mp.Process( target=self.stream.run, - daemon=True, kwargs={ - "out_dir" : out_dir, - "experiment_name" : self.experiment_name, - "stream_handling_queue" : self.stream_handling_queue, - "is_stream_lsl" : is_stream_lsl, - "stream_lsl_name" : stream_lsl_name, - "feature_queue" : self.feature_queue, - "simulate_real_time" : True, - "rawdata_queue" : self.rawdata_queue, - "decoder" : self.decoder, + "out_dir": "" if self.out_dir == "default" else self.out_dir, + "experiment_name": self.experiment_name, + "is_stream_lsl": self.lsl_stream_name is not None, + "stream_lsl_name": self.lsl_stream_name or "", + "simulate_real_time": True, + "decoder": self.decoder, + "backend_interface": self.backend_interface, }, ) - self.stream_controller_thread.start() - self.run_func_thread.start() + stream_process.start() + + # Start websocket sender process + + if self.websocket_manager: + # Get the current event loop and run the queue processor + loop = asyncio.get_running_loop() + queue_task = loop.create_task(self._process_queue()) + + # Store task reference for cleanup + self._queue_task = queue_task + + # Store processes for cleanup + self.stream_process = stream_process + + def stop_run_function(self) -> None: + """Stop the stream processing""" + if self.backend_interface: + self.backend_interface.send_command("stop") + self.stop_event.set() def setup_lsl_stream( self, - lsl_stream_name: str | None = None, + lsl_stream_name: str = "", line_noise: float | None = None, sampling_rate_features: float | None = None, ): from mne_lsl.lsl import resolve_streams - self.logger.info("resolving streams") + logger.info("resolving streams") lsl_streams = resolve_streams() for stream in lsl_streams: if stream.name == lsl_stream_name: - self.logger.info(f"found stream {lsl_stream_name}") + logger.info(f"found stream {lsl_stream_name}") # setup this stream self.lsl_stream_name = lsl_stream_name ch_names = stream.get_channel_names() if ch_names is None: ch_names = ["ch" + str(i) for i in range(stream.n_channels)] - self.logger.info(f"channel names: {ch_names}") + logger.info(f"channel names: {ch_names}") ch_types = stream.get_channel_types() if ch_types is None: ch_types = ["eeg" for i in range(stream.n_channels)] - self.logger.info(f"channel types: {ch_types}") + logger.info(f"channel types: {ch_types}") info_ = stream.get_channel_info() - self.logger.info(f"channel info: {info_}") + logger.info(f"channel info: {info_}") channels = set_channels( ch_names=ch_names, ch_types=ch_types, used_types=["eeg", "ecog", "dbs", "seeg"], ) - self.logger.info(channels) + logger.info(channels) sfreq = stream.sfreq self.stream: Stream = Stream( @@ -156,20 +149,19 @@ def setup_lsl_stream( sampling_rate_features_hz=sampling_rate_features, settings=self.settings, ) - self.logger.info("stream setup") - #self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq) - self.logger.info("settings setup") + logger.info("stream setup") + # self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq) + logger.info("settings setup") break - - if channels.shape[0] == 0: - self.logger.error(f"Stream {lsl_stream_name} not found") + else: + logger.error(f"Stream {lsl_stream_name} not found") raise ValueError(f"Stream {lsl_stream_name} not found") def setup_offline_stream( self, file_path: str, - line_noise: float | None = None, - sampling_rate_features: float | None = None, + line_noise: float, + sampling_rate_features: float, ): data, sfreq, ch_names, ch_types, bads = read_mne_data(file_path) @@ -182,7 +174,9 @@ def setup_offline_stream( target_keywords=None, ) - self.logger.info(f"settings: {self.settings}") + self.settings.sampling_rate_features_hz = sampling_rate_features + + logger.info(f"settings: {self.settings}") self.stream: Stream = Stream( settings=self.settings, sfreq=sfreq, @@ -191,3 +185,53 @@ def setup_offline_stream( line_noise=line_noise, sampling_rate_features_hz=sampling_rate_features, ) + + # Async function that will continuously run in the Uvicorn async loop + # and handle sending data through the websocket manager + async def _process_queue(self): + last_queue_check = time.time() + + while not self.stop_event.is_set(): + # Use asyncio.gather to process both queues concurrently + tasks = [] + current_time = time.time() + + # Check feature queue + while not self.feature_queue.empty(): + try: + data = self.feature_queue.get_nowait() + tasks.append(self.websocket_manager.send_cbor(data)) # type: ignore + self.messages_sent += 1 + except Empty: + break + + # Check raw data queue + while not self.rawdata_queue.empty(): + try: + data = self.rawdata_queue.get_nowait() + self.messages_sent += 1 + tasks.append(self.websocket_manager.send_cbor(data)) # type: ignore + except Empty: + break + + if tasks: + # Wait for all send operations to complete + await asyncio.gather(*tasks, return_exceptions=True) + else: + # Only sleep if we didn't process any messages + await asyncio.sleep(0.001) + + # Log queue diagnostics every 5 seconds + if current_time - last_queue_check > 5: + logger.info( + "\nQueue diagnostics:\n" + f"\tMessages send to websocket: {self.messages_sent}.\n" + f"\tFeature queue size: ~{self.feature_queue.qsize()}\n" + f"\tRaw data queue size: ~{self.rawdata_queue.qsize()}" + ) + + last_queue_check = current_time + + # Check if stream process is still alive + if not self.stream_process.is_alive(): + break diff --git a/py_neuromodulation/gui/backend/app_socket.py b/py_neuromodulation/gui/backend/app_socket.py index a7afed67..f81cc864 100644 --- a/py_neuromodulation/gui/backend/app_socket.py +++ b/py_neuromodulation/gui/backend/app_socket.py @@ -1,9 +1,10 @@ from fastapi import WebSocket import logging -import struct -import json import cbor2 -class WebSocketManager: +import time + + +class WebsocketManager: """ Manages WebSocket connections and messages. Perhaps in the future it will handle multiple connections. @@ -12,6 +13,12 @@ class WebSocketManager: def __init__(self): self.active_connections: list[WebSocket] = [] self.logger = logging.getLogger("PyNM") + self.disconnected = [] + self._queue_task = None + self._stop_event = None + self.loop = None + self.messages_sent = 0 + self._last_diagnostic_time = time.time() async def connect(self, websocket: WebSocket): await websocket.accept() @@ -30,25 +37,66 @@ def disconnect(self, websocket: WebSocket): f"Client {client_address.port}:{client_address.port} disconnected." ) + async def _cleanup_disconnected(self): + for connection in self.disconnected: + self.active_connections.remove(connection) + await connection.close() + # Combine IP and port to create a unique client ID async def send_cbor(self, object: dict): + if not self.active_connections: + self.logger.warning("No active connection to send message.") + return + + start_time = time.time() + cbor_data = cbor2.dumps(object) + serialize_time = time.time() - start_time + + if serialize_time > 0.1: # Log slow serializations + self.logger.warning(f"CBOR serialization took {serialize_time:.3f}s") + + send_start = time.time() for connection in self.active_connections: try: - await connection.send_bytes(cbor2.dumps(object)) + await connection.send_bytes(cbor_data) except RuntimeError as e: - self.active_connections.remove(connection) + self.logger.error(f"Error sending CBOR message: {e}") + self.disconnected.append(connection) + + send_time = time.time() - send_start + if send_time > 0.1: # Log slow sends + self.logger.warning(f"WebSocket send took {send_time:.3f}s") + + self.messages_sent += 1 + + # Log diagnostics every 5 seconds + current_time = time.time() + if current_time - self._last_diagnostic_time > 5: + self.logger.info(f"Messages sent: {self.messages_sent}") + self._last_diagnostic_time = current_time + + await self._cleanup_disconnected() async def send_message(self, message: str | dict): - self.logger.info(f"Sending message within app_socket: {message.keys()}") - if self.active_connections: - for connection in self.active_connections: + if not self.active_connections: + self.logger.warning("No active connection to send message.") + return + + self.logger.info( + f"Sending message within app_socket: {message.keys() if type(message) is dict else message}" + ) + for connection in self.active_connections: + try: if type(message) is dict: await connection.send_json(message) elif type(message) is str: await connection.send_text(message) - self.logger.info(f"Message sent") - else: - self.logger.warning("No active connection to send message.") + self.logger.info(f"Message sent to {connection.client}") + except RuntimeError as e: + self.logger.error(f"Error sending message: {e}") + self.disconnected.append(connection) + + await self._cleanup_disconnected() @property def is_connected(self): diff --git a/py_neuromodulation/gui/backend/app_utils.py b/py_neuromodulation/gui/backend/app_utils.py index 7fcce8d3..c623bf1b 100644 --- a/py_neuromodulation/gui/backend/app_utils.py +++ b/py_neuromodulation/gui/backend/app_utils.py @@ -6,6 +6,7 @@ from py_neuromodulation.utils.types import _PathLike from functools import lru_cache import platform +from py_neuromodulation import logger def force_terminate_process( @@ -211,20 +212,18 @@ def get_pinned_folders_windows(): """ try: - print(powershell_command) result = subprocess.run( ["powershell", "-Command", powershell_command], capture_output=True, text=True, check=True, ) - print(result.stdout) return json.loads(result.stdout) except subprocess.CalledProcessError as e: - print(f"Error running PowerShell command: {e}") + logger.error(f"Error running PowerShell command: {e}") return [] except json.JSONDecodeError as e: - print(f"Error decoding JSON: {e}") + logger.error(f"Error decoding JSON: {e}") return [] @@ -299,9 +298,9 @@ def get_macos_favorites(): path = item["URL"].replace("file://", "") favorites.append({"Name": item["Name"], "Path": path}) except (subprocess.CalledProcessError, json.JSONDecodeError) as e: - print(f"Error processing macOS favorites: {e}") + logger.error(f"Error processing macOS favorites: {e}") except Exception as e: - print(f"Error getting macOS favorites: {e}") + logger.error(f"Error getting macOS favorites: {e}") return favorites diff --git a/py_neuromodulation/processing/data_preprocessor.py b/py_neuromodulation/processing/data_preprocessor.py index 319926d9..3b9a9088 100644 --- a/py_neuromodulation/processing/data_preprocessor.py +++ b/py_neuromodulation/processing/data_preprocessor.py @@ -1,12 +1,12 @@ from typing import TYPE_CHECKING, Type -from py_neuromodulation.utils.types import PreprocessorName, NMPreprocessor +from py_neuromodulation.utils.types import PREPROCESSOR_NAME, NMPreprocessor if TYPE_CHECKING: import numpy as np import pandas as pd from py_neuromodulation.stream.settings import NMSettings -PREPROCESSOR_DICT: dict[PreprocessorName, str] = { +PREPROCESSOR_DICT: dict[PREPROCESSOR_NAME, str] = { "preprocessing_filter": "PreprocessingFilter", "notch_filter": "NotchFilter", "raw_resampling": "Resampler", diff --git a/py_neuromodulation/processing/filter_preprocessing.py b/py_neuromodulation/processing/filter_preprocessing.py index e6515ae0..2c5addbc 100644 --- a/py_neuromodulation/processing/filter_preprocessing.py +++ b/py_neuromodulation/processing/filter_preprocessing.py @@ -1,22 +1,14 @@ import numpy as np -from pydantic import Field from typing import TYPE_CHECKING from py_neuromodulation.utils.types import BoolSelector, FrequencyRange, NMPreprocessor +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings -FILTER_SETTINGS_MAP = { - "bandstop_filter": "bandstop_filter_settings", - "bandpass_filter": "bandpass_filter_settings", - "lowpass_filter": "lowpass_filter_cutoff_hz", - "highpass_filter": "highpass_filter_cutoff_hz", -} - - class FilterSettings(BoolSelector): bandstop_filter: bool = True bandpass_filter: bool = True @@ -25,21 +17,23 @@ class FilterSettings(BoolSelector): bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160) bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200) - lowpass_filter_cutoff_hz: float = Field(default=200) - highpass_filter_cutoff_hz: float = Field(default=3) - - def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]: - filter_value = self[FILTER_SETTINGS_MAP[filter_name]] - + lowpass_filter_cutoff_hz: float = NMField( + default=200, gt=0, custom_metadata={"unit": "Hz"} + ) + highpass_filter_cutoff_hz: float = NMField( + default=3, gt=0, custom_metadata={"unit": "Hz"} + ) + + def get_filter_tuple(self, filter_name) -> FrequencyRange: match filter_name: case "bandstop_filter": - return (filter_value.frequency_high_hz, filter_value.frequency_low_hz) + return self.bandstop_filter_settings case "bandpass_filter": - return (filter_value.frequency_low_hz, filter_value.frequency_high_hz) + return self.bandpass_filter_settings case "lowpass_filter": - return (None, filter_value) + return FrequencyRange(None, self.lowpass_filter_cutoff_hz) case "highpass_filter": - return (filter_value, None) + return FrequencyRange(self.highpass_filter_cutoff_hz, None) case _: raise ValueError( "Filter name must be one of 'bandstop_filter', 'lowpass_filter', " diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py index e3f25d55..713e2308 100644 --- a/py_neuromodulation/processing/normalization.py +++ b/py_neuromodulation/processing/normalization.py @@ -3,10 +3,10 @@ import numpy as np from typing import TYPE_CHECKING, Callable, Literal, get_args +from py_neuromodulation.utils.pydantic_extensions import NMField from py_neuromodulation.utils.types import ( NMBaseModel, - Field, - NormMethod, + NORM_METHOD, NMPreprocessor, ) @@ -17,13 +17,13 @@ class NormalizationSettings(NMBaseModel): - normalization_time_s: float = 30 - normalization_method: NormMethod = "zscore" - clip: float = Field(default=3, ge=0) + normalization_time_s: float = NMField(30, gt=0, custom_metadata={"unit": "s"}) + normalization_method: NORM_METHOD = NMField(default="zscore") + clip: float = NMField(default=3, ge=0, custom_metadata={"unit": "a.u."}) @staticmethod - def list_normalization_methods() -> list[NormMethod]: - return list(get_args(NormMethod)) + def list_normalization_methods() -> list[NORM_METHOD]: + return list(get_args(NORM_METHOD)) class FeatureNormalizationSettings(NormalizationSettings): normalize_psd: bool = False @@ -60,7 +60,7 @@ def __init__( if self.using_sklearn: import sklearn.preprocessing as skpp - NORM_METHODS_SKLEARN: dict[NormMethod, Callable] = { + NORM_METHODS_SKLEARN: dict[NORM_METHOD, Callable] = { "quantile": lambda: skpp.QuantileTransformer(n_quantiles=300), "robust": skpp.RobustScaler, "minmax": skpp.MinMaxScaler, diff --git a/py_neuromodulation/processing/projection.py b/py_neuromodulation/processing/projection.py index 4cdc3610..3a83a29f 100644 --- a/py_neuromodulation/processing/projection.py +++ b/py_neuromodulation/processing/projection.py @@ -1,6 +1,6 @@ import numpy as np -from pydantic import Field from py_neuromodulation.utils.types import NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -9,7 +9,7 @@ class ProjectionSettings(NMBaseModel): - max_dist_mm: float = Field(default=20.0, gt=0.0) + max_dist_mm: float = NMField(default=20.0, gt=0.0, custom_metadata={"unit": "mm"}) class Projection: diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py index 08a4e115..1b7107eb 100644 --- a/py_neuromodulation/processing/resample.py +++ b/py_neuromodulation/processing/resample.py @@ -1,11 +1,14 @@ """Module for resampling.""" import numpy as np -from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor +from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor +from py_neuromodulation.utils.pydantic_extensions import NMField class ResamplerSettings(NMBaseModel): - resample_freq_hz: float = Field(default=1000, gt=0) + resample_freq_hz: float = NMField( + default=1000, gt=0, custom_metadata={"unit": "Hz"} + ) class Resampler(NMPreprocessor): diff --git a/py_neuromodulation/stream/backend_interface.py b/py_neuromodulation/stream/backend_interface.py new file mode 100644 index 00000000..568998b0 --- /dev/null +++ b/py_neuromodulation/stream/backend_interface.py @@ -0,0 +1,47 @@ +from typing import Any +import logging +import multiprocessing as mp + + +class StreamBackendInterface: + """Handles stream data output via queues""" + + def __init__( + self, feature_queue: mp.Queue, raw_data_queue: mp.Queue, control_queue: mp.Queue + ): + self.feature_queue = feature_queue + self.rawdata_queue = raw_data_queue + self.control_queue = control_queue + + self.logger = logging.getLogger("PyNM") + + def send_command(self, command: str) -> None: + """Send a command through the control queue""" + try: + self.control_queue.put(command) + except Exception as e: + self.logger.error(f"Error sending command: {e}") + + def send_features(self, features: dict[str, Any]) -> None: + """Send feature data through the feature queue""" + try: + self.feature_queue.put(features) + except Exception as e: + self.logger.error(f"Error sending features: {e}") + + def send_raw_data(self, data: dict[str, Any]) -> None: + """Send raw data through the rawdata queue""" + try: + self.rawdata_queue.put(data) + except Exception as e: + self.logger.error(f"Error sending raw data: {e}") + + def check_control_signals(self) -> str | None: + """Check for control signals (non-blocking)""" + try: + if not self.control_queue.empty(): + return self.control_queue.get_nowait() + return None + except Exception as e: + self.logger.error(f"Error checking control signals: {e}") + return None diff --git a/py_neuromodulation/stream/mnelsl_player.py b/py_neuromodulation/stream/mnelsl_player.py index 086e4d60..bd39dd89 100644 --- a/py_neuromodulation/stream/mnelsl_player.py +++ b/py_neuromodulation/stream/mnelsl_player.py @@ -125,7 +125,7 @@ def start_player( try: self.wait_for_completion() except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping the player...") + logger.info("\nKeyboard interrupt received. Stopping the player...") self.stop_player() def _run_player(self, chunk_size, n_repeat, stop_flag, streaming_complete): @@ -156,7 +156,7 @@ def wait_for_completion(self): if self._streaming_complete.is_set(): break except KeyboardInterrupt: - print("\nKeyboard interrupt received. Stopping the player...") + logger.info("\nKeyboard interrupt received. Stopping the player...") self.stop_player() break @@ -172,7 +172,7 @@ def stop_player(self): self._player_process.kill() self._player_process = None - print(f"Player stopped: {self.stream_name}") + logger.info(f"Player stopped: {self.stream_name}") LSLOfflinePlayer._instances.discard(self) @classmethod diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py index 669c97db..c2b45e4c 100644 --- a/py_neuromodulation/stream/settings.py +++ b/py_neuromodulation/stream/settings.py @@ -1,22 +1,27 @@ """Module for handling settings.""" from pathlib import Path -from typing import ClassVar -from pydantic import Field, model_validator +from typing import Any, ClassVar, get_args +from pydantic import model_validator, ValidationError +from pydantic.functional_validators import ModelWrapValidatorHandler -from py_neuromodulation import PYNM_DIR, logger, user_features +from py_neuromodulation import logger, user_features from py_neuromodulation.utils.types import ( BoolSelector, FrequencyRange, - PreprocessorName, _PathLike, NMBaseModel, - NormMethod, + NORM_METHOD, + PREPROCESSOR_NAME, ) +from py_neuromodulation.utils.pydantic_extensions import NMErrorList, NMField from py_neuromodulation.processing.filter_preprocessing import FilterSettings -from py_neuromodulation.processing.normalization import FeatureNormalizationSettings, NormalizationSettings +from py_neuromodulation.processing.normalization import ( + FeatureNormalizationSettings, + NormalizationSettings, +) from py_neuromodulation.processing.resample import ResamplerSettings from py_neuromodulation.processing.projection import ProjectionSettings @@ -31,7 +36,9 @@ from py_neuromodulation.features import BurstsSettings -class FeatureSelection(BoolSelector): +# TONI: this class has the proble that if a feature is absent, +# it won't default to False but to whatever is defined here as default +class FeatureSelector(BoolSelector): raw_hjorth: bool = True return_raw: bool = True bandpass_filter: bool = False @@ -54,13 +61,24 @@ class PostprocessingSettings(BoolSelector): project_subcortex: bool = False +DEFAULT_PREPROCESSORS: list[PREPROCESSOR_NAME] = [ + "raw_resampling", + "notch_filter", + "re_referencing", +] + + class NMSettings(NMBaseModel): # Class variable to store instances _instances: ClassVar[list["NMSettings"]] = [] # General settings - sampling_rate_features_hz: float = Field(default=10, gt=0) - segment_length_features_ms: float = Field(default=1000, gt=0) + sampling_rate_features_hz: float = NMField( + default=10, gt=0, custom_metadata={"unit": "Hz"} + ) + segment_length_features_ms: float = NMField( + default=1000, gt=0, custom_metadata={"unit": "ms"} + ) frequency_ranges_hz: dict[str, FrequencyRange] = { "theta": FrequencyRange(4, 8), "alpha": FrequencyRange(8, 12), @@ -72,23 +90,28 @@ class NMSettings(NMBaseModel): } # Preproceessing settings - preprocessing: list[PreprocessorName] = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + preprocessing: list[PREPROCESSOR_NAME] = NMField( + default=DEFAULT_PREPROCESSORS, + custom_metadata={ + "field_type": "PreprocessorList", + "valid_values": list(get_args(PREPROCESSOR_NAME)), + }, + ) + raw_resampling_settings: ResamplerSettings = ResamplerSettings() preprocessing_filter: FilterSettings = FilterSettings() raw_normalization_settings: NormalizationSettings = NormalizationSettings() # Postprocessing settings postprocessing: PostprocessingSettings = PostprocessingSettings() - feature_normalization_settings: FeatureNormalizationSettings = FeatureNormalizationSettings() + feature_normalization_settings: FeatureNormalizationSettings = ( + FeatureNormalizationSettings() + ) project_cortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=20) project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5) # Feature settings - features: FeatureSelection = FeatureSelection() + features: FeatureSelector = FeatureSelector() fft_settings: OscillatorySettings = OscillatorySettings() welch_settings: OscillatorySettings = OscillatorySettings() @@ -126,10 +149,38 @@ def _remove_feature(cls, feature: str) -> None: for instance in cls._instances: delattr(instance.features, feature) - @model_validator(mode="after") - def validate_settings(self): + @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride] + def validate_settings(self, handler: ModelWrapValidatorHandler) -> Any: + # Perform all necessary custom validations in the settings class and also + # all validations in the feature classes that need additional information from + # the settings class + errors: NMErrorList = NMErrorList() + + def remove_private_keys(data): + if isinstance(data, dict): + if "__value__" in data: + return data["__value__"] + else: + return { + key: remove_private_keys(value) + for key, value in data.items() + if not key.startswith("__") + } + elif isinstance(data, (list, tuple, set)): + return type(data)(remove_private_keys(item) for item in data) + else: + return data + + self = remove_private_keys(self) + + try: + self = handler(self) # validate the model + except ValidationError as e: + self = NMSettings.unvalidated(**self) # type: ignore + errors.extend(NMErrorList(e.errors())) + if len(self.features.get_enabled()) == 0: - raise ValueError("At least one feature must be selected.") + errors.add_error("At least one feature must be selected.") # Replace spaces with underscores in frequency band names self.frequency_ranges_hz = { @@ -138,32 +189,27 @@ def validate_settings(self): if self.features.bandpass_filter: # Check BandPass settings frequency bands - self.bandpass_filter_settings.validate_fbands(self) + errors.extend(self.bandpass_filter_settings.validate_fbands(self)) # Check Kalman filter frequency bands if self.bandpass_filter_settings.kalman_filter: - self.kalman_filter_settings.validate_fbands(self) + errors.extend(self.kalman_filter_settings.validate_fbands(self)) - for k, v in self.frequency_ranges_hz.items(): - if not isinstance(v, FrequencyRange): - self.frequency_ranges_hz[k] = FrequencyRange.create_from(v) + if len(errors) > 0: + raise errors.create_error() return self def reset(self) -> "NMSettings": self.features.disable_all() - self.preprocessing = [] + self.preprocessing = DEFAULT_PREPROCESSORS self.postprocessing.disable_all() return self def set_fast_compute(self) -> "NMSettings": self.reset() self.features.fft = True - self.preprocessing = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + self.preprocessing = DEFAULT_PREPROCESSORS self.postprocessing.feature_normalization = True self.postprocessing.project_cortex = False self.postprocessing.project_subcortex = False @@ -250,10 +296,10 @@ def from_file(PATH: _PathLike) -> "NMSettings": @staticmethod def get_default() -> "NMSettings": - return NMSettings.from_file(PYNM_DIR / "default_settings.yaml") + return NMSettings() @staticmethod - def list_normalization_methods() -> list[NormMethod]: + def list_normalization_methods() -> list[NORM_METHOD]: return NormalizationSettings.list_normalization_methods() def save( diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py index 2ee03421..1ec685f6 100644 --- a/py_neuromodulation/stream/stream.py +++ b/py_neuromodulation/stream/stream.py @@ -1,19 +1,19 @@ """Module for generic and offline data streams.""" import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from collections.abc import Iterator import numpy as np from pathlib import Path import py_neuromodulation as nm -from contextlib import suppress from py_neuromodulation.stream.data_processor import DataProcessor -from py_neuromodulation.utils.types import _PathLike, FeatureName +from py_neuromodulation.utils.types import _PathLike, FEATURE_NAME from py_neuromodulation.utils.file_writer import MsgPackFileWriter from py_neuromodulation.stream.settings import NMSettings from py_neuromodulation.analysis.decode import RealTimeDecoder +from py_neuromodulation.stream.backend_interface import StreamBackendInterface if TYPE_CHECKING: import pandas as pd @@ -68,6 +68,7 @@ def __init__( verbose : bool, optional print out stream computation time information, by default True """ + # This is calling NMSettings.validate() which is making a copy self.settings: NMSettings = NMSettings.load(settings) if channels is None and data is not None: @@ -85,7 +86,7 @@ def __init__( raise ValueError("Either `channels` or `data` must be passed to `Stream`.") # If features that use frequency ranges are on, test them against nyquist frequency - use_freq_ranges: list[FeatureName] = [ + use_freq_ranges: list[FEATURE_NAME] = [ "bandpass_filter", "stft", "fft", @@ -208,14 +209,11 @@ def run( save_interval: int = 10, return_df: bool = True, simulate_real_time: bool = False, - decoder: RealTimeDecoder = None, - feature_queue: "queue.Queue | None" = None, - rawdata_queue: "queue.Queue | None" = None, - stream_handling_queue: "queue.Queue | None" = None, + decoder: RealTimeDecoder | None = None, + backend_interface: StreamBackendInterface | None = None, ): self.is_stream_lsl = is_stream_lsl self.stream_lsl_name = stream_lsl_name - self.stream_handling_queue = stream_handling_queue self.save_csv = save_csv self.save_interval = save_interval self.return_df = return_df @@ -248,7 +246,7 @@ def run( nm.logger.log_to_file(out_dir) self.generator: Iterator - if not is_stream_lsl: + if not is_stream_lsl and data is not None: from py_neuromodulation.stream.generator import RawDataGenerator self.generator = RawDataGenerator( @@ -265,7 +263,10 @@ def run( settings=self.settings, stream_name=stream_lsl_name ) - if self.sfreq != self.lsl_stream.stream.sinfo.sfreq: + if ( + self.lsl_stream.stream.sinfo is not None + and self.sfreq != self.lsl_stream.stream.sinfo.sfreq + ): error_msg = ( f"Sampling frequency of the lsl-stream ({self.lsl_stream.stream.sinfo.sfreq}) " f"does not match the settings ({self.sfreq})." @@ -279,14 +280,14 @@ def run( prev_batch_end = 0 for timestamps, data_batch in self.generator: self.is_running = True - if self.stream_handling_queue is not None: + if backend_interface: + # Only simulate real-time if connected to GUI if simulate_real_time: time.sleep(1 / self.settings.sampling_rate_features_hz) - if not self.stream_handling_queue.empty(): - signal = self.stream_handling_queue.get() - nm.logger.info(f"Got signal: {signal}") - if signal == "stop": - break + + signal = backend_interface.check_control_signals() + if signal == "stop": + break if data_batch is None: nm.logger.info("Data batch is None, stopping run function") @@ -303,10 +304,11 @@ def run( if decoder is not None: ch_to_decode = self.channels.query("used == 1").iloc[0]["name"] - feature_dict = decoder.predict(feature_dict, ch_to_decode, fft_bands_only=True) + feature_dict = decoder.predict( + feature_dict, ch_to_decode, fft_bands_only=True + ) feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1) - prev_batch_end = this_batch_end if self.verbose: @@ -314,40 +316,41 @@ def run( self._add_target(feature_dict, data_batch) - with suppress(TypeError): # Need this because some features output None - for key, value in feature_dict.items(): - feature_dict[key] = np.float64(value) - + # Push data to file writer file_writer.insert_data(feature_dict) - if feature_queue is not None: - feature_queue.put(feature_dict) - if rawdata_queue is not None: - # convert raw data into dict with new raw data in unit self.sfreq - new_time_ms = 1000 / self.settings.sampling_rate_features_hz - new_samples = int(new_time_ms * self.sfreq / 1000) - data_batch_dict = {} - data_batch_dict["raw_data"] = {} - for i, ch in enumerate(self.channels["name"]): - # needs to be list since cbor doesn't support np array - data_batch_dict["raw_data"][ch] = list(data_batch[i, -new_samples:]) - rawdata_queue.put(data_batch_dict) + # Send data to frontend + if backend_interface: + backend_interface.send_features(feature_dict) + backend_interface.send_raw_data(self._prepare_raw_data_dict(data_batch)) + + # Save features to file in intervals self.batch_count += 1 if self.batch_count % self.save_interval == 0: file_writer.save() file_writer.save() + if self.save_csv: file_writer.save_as_csv(save_all_combined=True) - if self.return_df: - feature_df = file_writer.load_all() + feature_df = file_writer.load_all() if self.return_df else {} self._save_after_stream() self.is_running = False - return feature_df # Timon: We could think of returning the feature_reader instead - + return feature_df # Timon: We could think of returnader instead + + def _prepare_raw_data_dict(self, data_batch: np.ndarray) -> dict[str, Any]: + """Prepare raw data dictionary for sending through queue""" + new_time_ms = 1000 / self.settings.sampling_rate_features_hz + new_samples = int(new_time_ms * self.sfreq / 1000) + return { + "raw_data": { + ch: list(data_batch[i, -new_samples:]) + for i, ch in enumerate(self.channels["name"]) + } + } def plot_raw_signal( self, @@ -383,11 +386,15 @@ def plot_raw_signal( ValueError raise Exception when no data is passed """ - if self.data is None and data is None: - raise ValueError("No data passed to plot_raw_signal function.") - - if data is None and self.data is not None: - data = self.data + if data is None: + if self.data is None: + raise ValueError("No data passed to plot_raw_signal function.") + else: + data = ( + self.data.to_numpy() + if isinstance(self.data, pd.DataFrame) + else self.data + ) if sfreq is None: sfreq = self.sfreq @@ -402,7 +409,7 @@ def plot_raw_signal( from mne import create_info from mne.io import RawArray - info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) + info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) # type: ignore raw = RawArray(data, info) if picks is not None: @@ -437,7 +444,7 @@ def _save_sidecar(self) -> None: """Save sidecar incduing fs, coords, sess_right to out_path_root and subfolder 'folder_name'""" additional_args = {"sess_right": self.sess_right} - + self.data_processor.save_sidecar( self.out_dir, self.experiment_name, additional_args ) diff --git a/py_neuromodulation/utils/channels.py b/py_neuromodulation/utils/channels.py index c2fffd00..61f5b68e 100644 --- a/py_neuromodulation/utils/channels.py +++ b/py_neuromodulation/utils/channels.py @@ -251,7 +251,7 @@ def _get_default_references( def get_default_channels_from_data( - data: np.ndarray, + data: "np.ndarray | pd.DataFrame", car_rereferencing: bool = True, ): """Return default channels dataframe with diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py new file mode 100644 index 00000000..e489712d --- /dev/null +++ b/py_neuromodulation/utils/pydantic_extensions.py @@ -0,0 +1,320 @@ +import copy +from typing import ( + Any, + get_origin, + get_args, + get_type_hints, + Literal, + cast, + Sequence, +) +from typing_extensions import Unpack, TypedDict +from pydantic import BaseModel, GetCoreSchemaHandler + +from pydantic_core import ( + ErrorDetails, + PydanticUndefined, + InitErrorDetails, + ValidationError, + CoreSchema, + core_schema, +) +from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs +from pprint import pformat + + +def create_validation_error( + error_message: str, + location: list[str | int] = [], + title: str = "Validation Error", + error_type="value_error", +) -> ValidationError: + """ + Factory function to create a Pydantic v2 ValidationError. + + Args: + error_message (str): The error message for the ValidationError. + loc (List[str | int], optional): The location of the error. Defaults to None. + title (str, optional): The title of the error. Defaults to "Validation Error". + + Returns: + ValidationError: A Pydantic ValidationError instance. + """ + + return ValidationError.from_exception_data( + title=title, + line_errors=[ + InitErrorDetails( + type=error_type, + loc=tuple(location), + input=None, + ctx={"error": error_message}, + ) + ], + input_type="python", + hide_input=False, + ) + + +class NMErrorList: + """Class to handle data about Pydantic errors. + Stores data in a list of InitErrorDetails. Errors can be accessed but not modified. + + :return: _description_ + :rtype: _type_ + """ + + def __init__( + self, errors: Sequence[InitErrorDetails | ErrorDetails] | None = None + ) -> None: + if errors is None: + self.__errors: list[InitErrorDetails | ErrorDetails] = [] + else: + self.__errors: list[InitErrorDetails | ErrorDetails] = [e for e in errors] + + def add_error( + self, + error_message: str, + location: list[str | int] = [], + error_type="value_error", + ) -> None: + self.__errors.append( + InitErrorDetails( + type=error_type, + loc=tuple(location), + input=None, + ctx={"error": error_message}, + ) + ) + + def create_error(self, title: str = "Validation Error") -> ValidationError: + return ValidationError.from_exception_data( + title=title, line_errors=cast(list[InitErrorDetails], self.__errors) + ) + + def extend(self, errors: "NMErrorList"): + self.__errors.extend(errors.__errors) + + def __iter__(self): + return iter(self.__errors) + + def __len__(self): + return len(self.__errors) + + def __getitem__(self, idx): + # Return a copy of the error to prevent modification + return copy.deepcopy(self.__errors[idx]) + + def __repr__(self): + return repr(self.__errors) + + def __str__(self): + return str(self.__errors) + + +class _NMExtraFieldInputs(TypedDict, total=False): + """Additional fields to add on top of the pydantic FieldInfo""" + + custom_metadata: dict[str, Any] + + +class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo inputs with PyNM additional inputs""" + + pass + + +class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs""" + + pass + + +class NMFieldInfo(FieldInfo): + # Add default values for any other custom fields here + _default_values = {} + + def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None: + self.sequence: bool = kwargs.pop("sequence", False) # type: ignore + self.custom_metadata: dict[str, Any] = kwargs.pop("custom_metadata", {}) + super().__init__(**kwargs) + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + schema = handler(source_type) + + if self.sequence: + + def sequence_validator(v: Any) -> Any: + if isinstance(v, (list, tuple)): + return v + if isinstance(v, dict) and "root" in v: + return v["root"] + return [v] + + return core_schema.no_info_before_validator_function( + sequence_validator, schema + ) + + return schema + + @staticmethod + def from_field( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], + ) -> "NMFieldInfo": + if "annotation" in kwargs: + raise TypeError('"annotation" is not permitted as a Field keyword argument') + return NMFieldInfo(default=default, **kwargs) + + def __repr_args__(self): + yield from super().__repr_args__() + extra_fields = get_type_hints(_NMExtraFieldInputs) + for field in extra_fields: + value = getattr(self, field) + yield field, value + + +def NMField( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], +) -> Any: + return NMFieldInfo.from_field(default=default, **kwargs) + + +class NMBaseModel(BaseModel): + def __init__(self, *args, **kwargs) -> None: + """Pydantic does not support positional arguments by default. + This is a workaround to support positional arguments for models like FrequencyRange. + It converts positional arguments to kwargs and then calls the base class __init__. + """ + + if not args: + # Simple case - just use kwargs + super().__init__(*args, **kwargs) + return + + field_names = list(self.model_fields.keys()) + # If we have more positional args than fields, that's an error + if len(args) > len(field_names): + raise ValueError( + f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}" + ) + + # Convert positional args to kwargs, checking for duplicates if args: + complete_kwargs = {} + for i, arg in enumerate(args): + field_name = field_names[i] + if field_name in kwargs: + raise ValueError( + f"Got multiple values for field '{field_name}': " + f"positional argument and keyword argument" + ) + complete_kwargs[field_name] = arg + + # Add remaining kwargs + complete_kwargs.update(kwargs) + super().__init__(**complete_kwargs) + + __init__.__pydantic_base_init__ = True # type: ignore + + def __str__(self): + return pformat(self.model_dump()) + + # def __repr__(self): + # return pformat(self.model_dump()) + + def validate(self, context: Any | None = None) -> Any: # type: ignore + return self.model_validate(self.model_dump(), context=context) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value) -> None: + setattr(self, key, value) + + @property + def fields(self) -> dict[str, FieldInfo | NMFieldInfo]: + return self.model_fields # type: ignore + + def serialize_with_metadata(self): + result: dict[str, Any] = {"__field_type__": self.__class__.__name__} + + for field_name, field_info in self.fields.items(): + value = getattr(self, field_name) + if isinstance(value, NMBaseModel): + result[field_name] = value.serialize_with_metadata() + elif isinstance(value, list): + result[field_name] = [ + item.serialize_with_metadata() + if isinstance(item, NMBaseModel) + else item + for item in value + ] + elif isinstance(value, dict): + result[field_name] = { + k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v + for k, v in value.items() + } + else: + result[field_name] = value + + # Extract unit information from Annotated type + if isinstance(field_info, NMFieldInfo): + # Convert scalar value to dict with metadata + field_dict = { + "__value__": value, + # __field_type__ will be overwritte if set in custom_metadata + "__field_type__": type(value).__name__, + **{ + f"__{tag}__": value + for tag, value in field_info.custom_metadata.items() + }, + } + # Add possible values for Literal types + if get_origin(field_info.annotation) is Literal: + field_dict["__valid_values__"] = list( + get_args(field_info.annotation) + ) + + result[field_name] = field_dict + return result + + @classmethod + def unvalidated(cls, **data: Any) -> Any: + def process_value(value: Any, field_type: Any) -> Any: + if isinstance(value, dict) and hasattr( + field_type, "__pydantic_core_schema__" + ): + # Recursively handle nested Pydantic models + return field_type.unvalidated(**value) + elif isinstance(value, list): + # Handle lists of Pydantic models + if hasattr(field_type, "__args__") and hasattr( + field_type.__args__[0], "__pydantic_core_schema__" + ): + return [ + field_type.__args__[0].unvalidated(**item) + if isinstance(item, dict) + else item + for item in value + ] + return value + + processed_data = {} + for name, field in cls.model_fields.items(): + try: + value = data[name] + processed_data[name] = process_value(value, field.annotation) + except KeyError: + if not field.is_required(): + processed_data[name] = copy.deepcopy(field.default) + else: + raise TypeError(f"Missing required keyword argument {name!r}") + + self = cls.__new__(cls) + object.__setattr__(self, "__dict__", processed_data) + object.__setattr__(self, "__pydantic_private__", {"extra": None}) + object.__setattr__(self, "__pydantic_fields_set__", set(processed_data.keys())) + return self diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py index 7886685d..6f7ab636 100644 --- a/py_neuromodulation/utils/types.py +++ b/py_neuromodulation/utils/types.py @@ -1,12 +1,14 @@ from os import PathLike from math import isnan -from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable -from pydantic import ConfigDict, Field, model_validator, BaseModel -from pydantic_core import ValidationError, InitErrorDetails -from pprint import pformat +from typing import Literal, TYPE_CHECKING, Any +from pydantic import BaseModel, model_validator, Field +from .pydantic_extensions import NMBaseModel, NMField +from abc import abstractmethod + from collections.abc import Sequence from datetime import datetime + if TYPE_CHECKING: import numpy as np from py_neuromodulation import NMSettings @@ -17,8 +19,7 @@ _PathLike = str | PathLike - -FeatureName = Literal[ +FEATURE_NAME = Literal[ "raw_hjorth", "return_raw", "bandpass_filter", @@ -35,7 +36,7 @@ "bispectrum", ] -PreprocessorName = Literal[ +PREPROCESSOR_NAME = Literal[ "preprocessing_filter", "notch_filter", "raw_resampling", @@ -43,7 +44,7 @@ "raw_normalization", ] -NormMethod = Literal[ +NORM_METHOD = Literal[ "mean", "median", "zscore", @@ -54,13 +55,8 @@ "minmax", ] -################################### -######## PROTOCOL CLASSES ######## -################################### - -@runtime_checkable -class NMFeature(Protocol): +class NMFeature: def __init__( self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float ) -> None: ... @@ -81,73 +77,10 @@ def calc_feature(self, data: "np.ndarray") -> dict: ... -class NMPreprocessor(Protocol): - def __init__(self, sfreq: float, settings: "NMSettings") -> None: ... - +class NMPreprocessor: def process(self, data: "np.ndarray") -> "np.ndarray": ... -################################### -######## PYDANTIC CLASSES ######## -################################### - - -class NMBaseModel(BaseModel): - model_config = ConfigDict(validate_assignment=False, extra="allow") - - def __init__(self, *args, **kwargs) -> None: - if kwargs: - super().__init__(**kwargs) - else: - field_names = list(self.model_fields.keys()) - kwargs = {} - for i in range(len(args)): - kwargs[field_names[i]] = args[i] - super().__init__(**kwargs) - - def __str__(self): - return pformat(self.model_dump()) - - def __repr__(self): - return pformat(self.model_dump()) - - def validate(self) -> Any: # type: ignore - return self.model_validate(self.model_dump()) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value) -> None: - setattr(self, key, value) - - def process_for_frontend(self) -> dict[str, Any]: - """ - Process the model for frontend use, adding __field_type__ information. - """ - result = {} - for field_name, field_value in self.__dict__.items(): - if isinstance(field_value, NMBaseModel): - processed_value = field_value.process_for_frontend() - processed_value["__field_type__"] = field_value.__class__.__name__ - result[field_name] = processed_value - elif isinstance(field_value, list): - result[field_name] = [ - item.process_for_frontend() - if isinstance(item, NMBaseModel) - else item - for item in field_value - ] - elif isinstance(field_value, dict): - result[field_name] = { - k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v - for k, v in field_value.items() - } - else: - result[field_name] = field_value - - return result - - class FrequencyRange(NMBaseModel): frequency_low_hz: float = Field(gt=0) frequency_high_hz: float = Field(gt=0) @@ -178,35 +111,6 @@ def validate_range(self): ), "Frequency high must be greater than frequency low" return self - @classmethod - def create_from(cls, input) -> "FrequencyRange": - match input: - case FrequencyRange(): - return input - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return FrequencyRange( - input["frequency_low_hz"], input["frequency_high_hz"] - ) - case Sequence() if len(input) == 2: - return FrequencyRange(input[0], input[1]) - case _: - raise ValueError("Invalid input for FrequencyRange creation.") - - @model_validator(mode="before") - @classmethod - def check_input(cls, input): - match input: - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return input - case Sequence() if len(input) == 2: - return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]} - case _: - raise ValueError( - "Value for FrequencyRange must be a dictionary, " - "or a sequence of 2 numeric values, " - f"but got {input} instead." - ) - class BoolSelector(NMBaseModel): def get_enabled(self): @@ -238,47 +142,6 @@ def print_all(cls): for f in cls.list_all(): print(f) - @classmethod - def get_fields(cls): - return cls.model_fields - - -def create_validation_error( - error_message: str, - loc: list[str | int] = None, - title: str = "Validation Error", - input_type: Literal["python", "json"] = "python", - hide_input: bool = False, -) -> ValidationError: - """ - Factory function to create a Pydantic v2 ValidationError instance from a single error message. - - Args: - error_message (str): The error message for the ValidationError. - loc (List[str | int], optional): The location of the error. Defaults to None. - title (str, optional): The title of the error. Defaults to "Validation Error". - input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python". - hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False. - - Returns: - ValidationError: A Pydantic ValidationError instance. - """ - if loc is None: - loc = [] - - line_errors = [ - InitErrorDetails( - type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message} - ) - ] - - return ValidationError.from_exception_data( - title=title, - line_errors=line_errors, - input_type=input_type, - hide_input=hide_input, - ) - ################# ### GUI TYPES ### diff --git a/pyproject.toml b/pyproject.toml index 5facc8ba..7ad64fae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ dependencies = [ "nolds >=0.6.1", "numpy >= 1.21.2", "pandas >= 2.0.0", - "scikit-image", "scikit-learn >= 0.24.2", "scikit-optimize", "scipy >= 1.7.1", @@ -55,13 +54,11 @@ dependencies = [ "statsmodels", "mne-lsl>=1.2.0", "pynput", - #"pyqt5", "pydantic>=2.7.3", "llvmlite>=0.43.0", "pywebview", "fastapi", - "uvicorn>=0.30.6", - "websockets>=13.0", + "uvicorn[standard]>=0.30.6", "seaborn >= 0.11", # exists only because of nolds? "numba>=0.60.0", @@ -73,14 +70,7 @@ dependencies = [ [project.optional-dependencies] test = ["pytest>=8.0.2", "pytest-xdist"] -dev = [ - "ruff", - "pytest>=8.0.2", - "pytest-cov", - "pytest-sugar", - "notebook", - "watchdog", -] +dev = ["ruff", "pytest>=8.0.2", "pytest-cov", "pytest-sugar", "notebook"] docs = [ "py-neuromodulation[dev]", "sphinx", diff --git a/tests/test_osc_features.py b/tests/test_osc_features.py index 8cf84111..fd2428d1 100644 --- a/tests/test_osc_features.py +++ b/tests/test_osc_features.py @@ -3,11 +3,11 @@ import numpy as np from py_neuromodulation import NMSettings, Stream, features -from py_neuromodulation.utils.types import FeatureName +from py_neuromodulation.utils.types import FEATURE_NAME def setup_osc_settings( - osc_feature_name: FeatureName, + osc_feature_name: FEATURE_NAME, osc_feature_setting: str, windowlength_ms: int, log_transform: bool,