diff --git a/gui_dev/package.json b/gui_dev/package.json index e1cb7f00..5fa2d606 100644 --- a/gui_dev/package.json +++ b/gui_dev/package.json @@ -21,7 +21,7 @@ "react-dom": "next", "react-icons": "^5.3.0", "react-plotly.js": "^2.6.0", - "react-router-dom": "^6.28.0", + "react-router-dom": "^7.0.1", "zustand": "latest" }, "devDependencies": { diff --git a/gui_dev/src/App.jsx b/gui_dev/src/App.jsx index 7ebfdc55..a97f8c75 100644 --- a/gui_dev/src/App.jsx +++ b/gui_dev/src/App.jsx @@ -45,8 +45,6 @@ export const App = () => { const connectSocket = useSocketStore((state) => state.connectSocket); const disconnectSocket = useSocketStore((state) => state.disconnectSocket); - connectSocket(); - useEffect(() => { console.log("Connecting socket from App component..."); connectSocket(); diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx index 36aaf1dc..f011e2b9 100644 --- a/gui_dev/src/components/StatusBar/StatusBar.jsx +++ b/gui_dev/src/components/StatusBar/StatusBar.jsx @@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle"; import { SocketStatus } from "./SocketStatus"; import { WebviewStatus } from "./WebviewStatus"; -import { useWebviewStore } from "@/stores"; - +import { useUiStore, useWebviewStore } from "@/stores"; import { Stack } from "@mui/material"; export const StatusBar = () => { - const { isWebView } = useWebviewStore((state) => state.isWebView); + const isWebView = useWebviewStore((state) => state.isWebView); + const createStatusBarContent = useUiStore((state) => state.statusBarContent); + + const StatusBarContent = createStatusBarContent(); return ( { bgcolor="background.level1" borderTop="2px solid" borderColor="background.level3" + height="2rem" > - + {StatusBarContent && } + + {/* */} {/* Current experiment */} {/* Current stream */} {/* Current activity */} diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx index 56f0266b..a86b6009 100644 --- a/gui_dev/src/components/TitledBox.jsx +++ b/gui_dev/src/components/TitledBox.jsx @@ -1,4 +1,4 @@ -import { Container } from "@mui/material"; +import { Box, Container } from "@mui/material"; /** * Component that uses the Box component to render an HTML fieldset element @@ -13,14 +13,14 @@ export const TitledBox = ({ children, ...props }) => ( - {title} {children} - + ); diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx deleted file mode 100644 index afa74b02..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.jsx +++ /dev/null @@ -1,65 +0,0 @@ -import { useRef } from "react"; -import styles from "./DragAndDropList.module.css"; -import { useOptionsStore } from "@/stores"; - -export const DragAndDropList = () => { - const { options, setOptions, addOption, removeOption } = useOptionsStore(); - - const predefinedOptions = [ - { id: 1, name: "raw_resampling" }, - { id: 2, name: "notch_filter" }, - { id: 3, name: "re_referencing" }, - { id: 4, name: "preprocessing_filter" }, - { id: 5, name: "raw_normalization" }, - ]; - - const dragOption = useRef(0); - const draggedOverOption = useRef(0); - - function handleSort() { - const optionsClone = [...options]; - const temp = optionsClone[dragOption.current]; - optionsClone[dragOption.current] = optionsClone[draggedOverOption.current]; - optionsClone[draggedOverOption.current] = temp; - setOptions(optionsClone); - } - - return ( -
-

List

- {options.map((option, index) => ( -
(dragOption.current = index)} - onDragEnter={() => (draggedOverOption.current = index)} - onDragEnd={handleSort} - onDragOver={(e) => e.preventDefault()} - > -

- {[option.id, ". ", option.name.replace("_", " ")]} -

- -
- ))} -
-

Add Elements

- {predefinedOptions.map((option) => ( - - ))} -
-
- ); -}; diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css deleted file mode 100644 index 76bd7861..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.module.css +++ /dev/null @@ -1,72 +0,0 @@ -.dragDropList { - max-width: 400px; - margin: 0 auto; - padding: 20px; -} - -.title { - text-align: center; - color: #333; - margin-bottom: 20px; -} - -.item { - background-color: #f0f0f0; - border-radius: 8px; - padding: 15px; - margin-bottom: 10px; - cursor: move; - transition: background-color 0.3s ease; - display: flex; - justify-content: space-between; - align-items: center; -} - -.item:hover { - background-color: #e0e0e0; -} - -.itemText { - margin: 0; - color: #333; - font-size: 16px; -} - -.removeButton { - background-color: #ff4d4d; - border: none; - border-radius: 4px; - color: white; - padding: 5px 10px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.removeButton:hover { - background-color: #ff1a1a; -} - -.addSection { - margin-top: 20px; - text-align: center; -} - -.subtitle { - color: #333; - margin-bottom: 10px; -} - -.addButton { - background-color: #4CAF50; - border: none; - border-radius: 4px; - color: white; - padding: 10px 15px; - margin: 5px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.addButton:hover { - background-color: #45a049; -} \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx deleted file mode 100644 index ad662eef..00000000 --- a/gui_dev/src/pages/Settings/Dropdown.jsx +++ /dev/null @@ -1,66 +0,0 @@ -import { useState } from "react"; -import "../App.css"; - -var stringJson = - '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}'; -const nm_settings = JSON.parse(stringJson); - -const filterByKeys = (dict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (typeof key === "string") { - // Top-level key - if (dict.hasOwnProperty(key)) { - filteredDict[key] = dict[key]; - } - } else if (typeof key === "object") { - // Nested key - const topLevelKey = Object.keys(key)[0]; - const nestedKeys = key[topLevelKey]; - if ( - dict.hasOwnProperty(topLevelKey) && - typeof dict[topLevelKey] === "object" - ) { - filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys); - } - } - }); - return filteredDict; -}; - -const Dropdown = ({ onChange, keysToInclude }) => { - const filteredSettings = filterByKeys(nm_settings, keysToInclude); - const [selectedOption, setSelectedOption] = useState(""); - - const handleChange = (event) => { - const newValue = event.target.value; - setSelectedOption(newValue); - onChange(keysToInclude, newValue); - }; - - return ( -
- -
- ); -}; - -export default Dropdown; diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx deleted file mode 100644 index 5def783b..00000000 --- a/gui_dev/src/pages/Settings/FrequencyRange.jsx +++ /dev/null @@ -1,88 +0,0 @@ -import { useState } from "react"; - -// const onChange = (key, newValue) => { -// settings.frequencyRanges[key] = newValue; -// }; - -// Object.entries(settings.frequencyRanges).map(([key, value]) => ( -// -// )); - -/** */ - -/** - * - * @param {String} key - * @param {Array} freqRange - * @param {Function} onChange - * @returns - */ -export const FrequencyRange = ({ settings }) => { - const [frequencyRanges, setFrequencyRanges] = useState(settings || {}); - // Handle changes in the text fields - const handleInputChange = (label, key, newValue) => { - setFrequencyRanges((prevState) => ({ - ...prevState, - [label]: { - ...prevState[label], - [key]: newValue, - }, - })); - }; - - // Add a new band - const addBand = () => { - const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`; - setFrequencyRanges((prevState) => ({ - ...prevState, - [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" }, - })); - }; - - // Remove a band - const removeBand = (label) => { - const updatedRanges = { ...frequencyRanges }; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }; - - return ( -
-
Frequency Bands
- {Object.keys(frequencyRanges).map((label) => ( -
- { - const newLabel = e.target.value; - const updatedRanges = { ...frequencyRanges }; - updatedRanges[newLabel] = updatedRanges[label]; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }} - placeholder="Band Name" - /> - - handleInputChange(label, "frequency_high_hz", e.target.value) - } - placeholder="High Hz" - /> - - handleInputChange(label, "frequency_low_hz", e.target.value) - } - placeholder="Low Hz" - /> - -
- ))} - -
- ); -}; diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css deleted file mode 100644 index a638ad93..00000000 --- a/gui_dev/src/pages/Settings/FrequencySettings.module.css +++ /dev/null @@ -1,67 +0,0 @@ -.container { - background-color: #f9f9f9; /* Light gray background for the container */ - padding: 20px; - border-radius: 10px; /* Rounded corners */ - box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */ - max-width: 600px; - margin: auto; - } - - .header { - font-size: 1.5rem; - color: #333; /* Darker text color */ - margin-bottom: 20px; - text-align: center; - } - - .bandContainer { - display: flex; - align-items: center; - margin-bottom: 15px; - padding: 10px; - border: 1px solid #ddd; /* Light border for each band */ - border-radius: 8px; /* Rounded corners for individual bands */ - background-color: #fff; /* White background for bands */ - box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */ - } - - .bandNameInput, .frequencyInput { - border: 1px solid #ccc; /* Light gray border */ - border-radius: 5px; /* Slightly rounded corners */ - padding: 8px; - margin-right: 10px; - font-size: 0.875rem; - } - - .bandNameInput::placeholder, .frequencyInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - } - - .removeButton, .addButton { - border: none; - border-radius: 5px; /* Rounded corners */ - padding: 8px 12px; - font-size: 0.875rem; - cursor: pointer; - transition: background-color 0.3s, color 0.3s; - } - - .removeButton { - background-color: #e57373; /* Light red color */ - color: white; - } - - .removeButton:hover { - background-color: #d32f2f; /* Darker red on hover */ - } - - .addButton { - background-color: #4caf50; /* Green color */ - color: white; - margin-top: 10px; - } - - .addButton:hover { - background-color: #388e3c; /* Darker green on hover */ - } - \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx index e2cf5f2a..045603be 100644 --- a/gui_dev/src/pages/Settings/Settings.jsx +++ b/gui_dev/src/pages/Settings/Settings.jsx @@ -1,19 +1,21 @@ +import { useEffect, useState } from "react"; import { + Box, Button, - InputAdornment, + Popover, Stack, Switch, TextField, + Tooltip, Typography, } from "@mui/material"; import { Link } from "react-router-dom"; import { CollapsibleBox, TitledBox } from "@/components"; -import { FrequencyRange } from "./FrequencyRange"; -import { useSettingsStore } from "@/stores"; +import { FrequencyRangeList } from "./components/FrequencyRange"; +import { useSettingsStore, useStatusBarContent } from "@/stores"; import { filterObjectByKeys } from "@/utils/functions"; const formatKey = (key) => { - // console.log(key); return key .split("_") .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) @@ -21,15 +23,37 @@ const formatKey = (key) => { }; // Wrapper components for each type -const BooleanField = ({ value, onChange }) => ( +const BooleanField = ({ value, onChange, error }) => ( onChange(e.target.checked)} /> ); -const StringField = ({ value, onChange, label }) => ( - -); +const errorStyle = { + "& .MuiOutlinedInput-root": { + "& fieldset": { borderColor: "error.main" }, + "&:hover fieldset": { + borderColor: "error.main", + }, + "&.Mui-focused fieldset": { + borderColor: "error.main", + }, + }, +}; + +const StringField = ({ value, onChange, label, error }) => { + const errorSx = error ? errorStyle : {}; + return ( + onChange(e.target.value)} + label={label} + sx={{ ...errorSx }} + /> + ); +}; + +const NumberField = ({ value, onChange, label, error }) => { + const errorSx = error ? errorStyle : {}; -const NumberField = ({ value, onChange, label }) => { const handleChange = (event) => { const newValue = event.target.value; // Only allow numbers and decimal point @@ -44,13 +68,14 @@ const NumberField = ({ value, onChange, label }) => { value={value} onChange={handleChange} label={label} - InputProps={{ - endAdornment: ( - - Hz - - ), - }} + sx={{ ...errorSx }} + // InputProps={{ + // endAdornment: ( + // + // Hz + // + // ), + // }} inputProps={{ pattern: "[0-9]*", }} @@ -58,107 +83,230 @@ const NumberField = ({ value, onChange, label }) => { ); }; -const FrequencyRangeField = ({ value, onChange, label }) => ( - -); - // Map component types to their respective wrappers const componentRegistry = { boolean: BooleanField, string: StringField, number: NumberField, - FrequencyRange: FrequencyRangeField, }; -const SettingsField = ({ path, Component, label, value, onChange, depth }) => { +const SettingsField = ({ path, Component, label, value, onChange, error }) => { return ( - - {label} - onChange(path, newValue)} - label={label} - /> - + + + {label} + onChange(path, newValue)} + label={label} + error={error} + /> + + ); }; +// Function to get the error corresponding to this field or its children +const getFieldError = (fieldPath, errors) => { + if (!errors) return null; + + return errors.find((error) => { + const errorPath = error.loc.join("."); + const currentPath = fieldPath.join("."); + return errorPath === currentPath || errorPath.startsWith(currentPath + "."); + }); +}; + const SettingsSection = ({ settings, title = null, path = [], onChange, - depth = 0, + errors, }) => { - if (Object.keys(settings).length === 0) { - return null; - } const boxTitle = title ? title : formatKey(path[path.length - 1]); + /* + 3 possible cases: + 1. Primitive type || 2. Object with component -> Don't iterate, render directly + 3. Object without component or 4. Array -> Iterate and render recursively + */ - return ( - - {Object.entries(settings).map(([key, value]) => { - if (key === "__field_type__") return null; + const type = typeof settings; + const isObject = type === "object" && !Array.isArray(settings); + const isArray = Array.isArray(settings); - const newPath = [...path, key]; - const label = key; - const isPydanticModel = - typeof value === "object" && "__field_type__" in value; + // __field_type__ should be always present + if (isObject && !settings.__field_type__) { + console.log(settings); + throw new Error("Invalid settings object"); + } + const fieldType = isObject ? settings.__field_type__ : type; + const Component = componentRegistry[fieldType]; - const fieldType = isPydanticModel ? value.__field_type__ : typeof value; + // Case 1: Primitive type -> Don't iterate, render directly + if (!isObject && !isArray) { + if (!Component) { + console.error(`Invalid component type: ${type}`); + return null; + } - const Component = componentRegistry[fieldType]; + const error = getFieldError(path, errors); + + return ( + + ); + } + + // Case 2: Object with component -> Don't iterate, render directly + if (isObject && Component) { + return ( + + ); + } + + // Case 3: Object without component or 4. Array -> Iterate and render recursively + if ((isObject && !Component) || isArray) { + return ( + + {/* Handle recursing through both objects and arrays */} + {(isArray ? settings : Object.entries(settings)).map((item, index) => { + const [key, value] = isArray ? [index.toString(), item] : item; + if (key.startsWith("__")) return null; // Skip metadata fields + + const newPath = [...path, key]; - if (Component) { - return ( - - ); - } else { return ( ); - } - })} - + })} + + ); + } + + // Default case: return null and log an error + console.error(`Invalid settings object, returning null`); + return null; +}; + +const StatusBarSettingsInfo = () => { + const validationErrors = useSettingsStore((state) => state.validationErrors); + const [anchorEl, setAnchorEl] = useState(null); + const open = Boolean(anchorEl); + + const handleOpenErrorsPopover = (event) => { + setAnchorEl(event.currentTarget); + }; + + const handleCloseErrorsPopover = () => { + setAnchorEl(null); + }; + + return ( + <> + {validationErrors?.length > 0 && ( + <> + + {validationErrors?.length} errors found in Settings + + + + {validationErrors.map((error, index) => ( + + {index} - [{error.type}] {error.msg} + + ))} + + + + )} + ); }; -const SettingsContent = () => { +export const Settings = () => { + // Get all necessary state from the settings store const settings = useSettingsStore((state) => state.settings); - const updateSettings = useSettingsStore((state) => state.updateSettings); + const uploadSettings = useSettingsStore((state) => state.uploadSettings); + const resetSettings = useSettingsStore((state) => state.resetSettings); + const validationErrors = useSettingsStore((state) => state.validationErrors); + useStatusBarContent(StatusBarSettingsInfo); + + // This is needed so that the frequency ranges stay in order between updates + const frequencyRangeOrder = useSettingsStore( + (state) => state.frequencyRangeOrder + ); + const updateFrequencyRangeOrder = useSettingsStore( + (state) => state.updateFrequencyRangeOrder + ); + // Here I handle the selected feature in the feature settings component + const [selectedFeature, setSelectedFeature] = useState(""); + + useEffect(() => { + uploadSettings(null, true); // validateOnly = true + }, [settings]); + + // Inject validation error info into status bar + + // This has to be after all the hooks, otherwise React will complain if (!settings) { return
Loading settings...
; } - const handleChange = (path, value) => { - updateSettings((settings) => { + // This are the callbacks for the different buttons + const handleChangeSettings = async (path, value) => { + uploadSettings((settings) => { let current = settings; for (let i = 0; i < path.length - 1; i++) { current = current[path[i]]; } current[path[path.length - 1]] = value; - }); + }, true); // validateOnly = true + }; + + const handleSaveSettings = () => { + uploadSettings(() => settings); + }; + + const handleResetSettings = async () => { + await resetSettings(); }; const featureSettingsKeys = Object.keys(settings.features) @@ -181,89 +329,147 @@ const SettingsContent = () => { "project_subcortex_settings", ]; + const generalSettingsKeys = [ + "sampling_rate_features_hz", + "segment_length_features_ms", + ]; + return ( - - + {/* SETTINGS LAYOUT */} + - - + {/* GENERAL SETTINGS + FREQUENCY RANGES */} + + + {generalSettingsKeys.map((key) => ( + + ))} + + + + + + - - + {/* POSTPROCESSING + PREPROCESSING SETTINGS */} + {preprocessingSettingsKeys.map((key) => ( ))} - + - + {postprocessingSettingsKeys.map((key) => ( ))} - - + - - {Object.entries(enabledFeatures).map(([feature, featureSettings]) => ( - - - - ))} - - - ); -}; + {/* FEATURE SETTINGS */} + + + + + + + {Object.entries(enabledFeatures).map( + ([feature, featureSettings]) => ( + + + + ) + )} + + + + {/* END SETTINGS LAYOUT */} +
-export const Settings = () => { - return ( - - - + + {/* */} + + + ); }; diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx deleted file mode 100644 index 86ab07b8..00000000 --- a/gui_dev/src/pages/Settings/TextField.jsx +++ /dev/null @@ -1,77 +0,0 @@ -import { useState, useEffect } from "react"; -import { - Box, - Grid, - TextField as MUITextField, - Typography, -} from "@mui/material"; -import { useSettingsStore } from "@/stores"; -import styles from "./TextField.module.css"; - -const flattenDictionary = (dict, parentKey = "", result = {}) => { - for (let key in dict) { - const newKey = parentKey ? `${parentKey}.${key}` : key; - if (typeof dict[key] === "object" && dict[key] !== null) { - flattenDictionary(dict[key], newKey, result); - } else { - result[newKey] = dict[key]; - } - } - return result; -}; - -const filterByKeys = (flatDict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (flatDict.hasOwnProperty(key)) { - filteredDict[key] = flatDict[key]; - } - }); - return filteredDict; -}; - -export const TextField = ({ keysToInclude }) => { - const settings = useSettingsStore((state) => state.settings); - const flatSettings = flattenDictionary(settings); - const filteredSettings = filterByKeys(flatSettings, keysToInclude); - const [textLabels, setTextLabels] = useState({}); - - useEffect(() => { - setTextLabels(filteredSettings); - }, [settings]); - - const handleTextFieldChange = (label, value) => { - setTextLabels((prevLabels) => ({ - ...prevLabels, - [label]: value, - })); - }; - - // Function to format the label - const formatLabel = (label) => { - const labelAfterDot = label.split(".").pop(); // Get everything after the last dot - return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces - }; - - return ( -
- {Object.keys(textLabels).map((label, index) => ( -
- - handleTextFieldChange(label, e.target.value)} - className={styles.textFieldInput} - /> -
- ))} -
- ); -}; diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css deleted file mode 100644 index 37791c40..00000000 --- a/gui_dev/src/pages/Settings/TextField.module.css +++ /dev/null @@ -1,67 +0,0 @@ -/* TextField.module.css */ - -/* Container for the text fields */ -.textFieldContainer { - display: flex; - flex-direction: column; - margin: 1.5rem 0; /* Increased margin for better spacing */ - } - - /* Row for each text field */ - .textFieldRow { - display: flex; - flex-direction: column; /* Stack label and input vertically */ - margin-bottom: 1rem; /* Increased margin for better separation */ - } - - /* Label for each text field */ - .textFieldLabel { - margin-bottom: 0.5rem; /* Space between label and input */ - font-weight: 600; /* Increased weight for better visibility */ - color: #333; /* Dark gray for the label */ - font-size: 1.1rem; /* Increased font size for the label */ - transition: all 0.2s ease; /* Smooth transition for label */ - } - - /* Input field styles */ - .textFieldInput { - padding: 12px 14px; /* Padding for a filled look */ - border: 1px solid #ccc; /* Light gray border */ - border-radius: 4px; /* Rounded corners */ - width: 100%; /* Full width */ - font-size: 1rem; /* Font size */ - background-color: #f5f5f5; /* Light background color for filled effect */ - transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */ - box-shadow: none; /* Remove default shadow */ - height: 48px; /* Fixed height for a more square appearance */ - } - - /* Focus styles for the input */ - .textFieldInput:focus { - border-color: #1976d2; /* Blue border color on focus */ - background-color: #fff; /* Change background to white on focus */ - outline: none; /* Remove default outline */ - } - - /* Hover effect for the input */ - .textFieldInput:hover { - border-color: #1976d2; /* Change border color on hover */ - } - - /* Placeholder styles */ - .textFieldInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - opacity: 1; /* Ensure placeholder is fully opaque */ - } - - /* Hide the number input spinners in webkit browsers */ - .textFieldInput::-webkit-inner-spin-button, - .textFieldInput::-webkit-outer-spin-button { - -webkit-appearance: none; /* Remove default styling */ - margin: 0; /* Remove margin */ - } - - /* Hide the number input spinners in Firefox */ - .textFieldInput[type='number'] { - -moz-appearance: textfield; /* Use textfield appearance */ - } \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/components/FrequencyRange.jsx b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx new file mode 100644 index 00000000..16333166 --- /dev/null +++ b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx @@ -0,0 +1,182 @@ +import { useState } from "react"; +import { TextField, Button, IconButton, Stack } from "@mui/material"; +import { Add, Close } from "@mui/icons-material"; +import { debounce } from "@/utils"; + +const NumberField = ({ ...props }) => ( + +); + +export const FrequencyRange = ({ + name, + range, + onChangeName, + onChangeRange, + onRemove, + error, +}) => { + const [localName, setLocalName] = useState(name); + + const debouncedChangeName = debounce((newName) => { + onChangeName(newName, name); + }, 1000); + + const handleNameChange = (e) => { + const newName = e.target.value; + setLocalName(newName); + debouncedChangeName(newName); + }; + + const handleNameBlur = () => { + onChangeName(localName, name); + }; + + const handleKeyPress = (e) => { + if (e.key === "Enter") { + console.log(e.target.value, name); + onChangeName(localName, name); + } + }; + + const handleRangeChange = (name, field, value) => { + // onChangeRange takes the name of the range as the first argument + onChangeRange(name, { ...range, [field]: value }); + }; + + return ( + + + + handleRangeChange(name, "frequency_low_hz", e.target.value) + } + label="Low Hz" + /> + + handleRangeChange(name, "frequency_high_hz", e.target.value) + } + label="High Hz" + /> + onRemove(name)} + color="primary" + disableRipple + sx={{ m: 0, p: 0 }} + > + + + + ); +}; + +export const FrequencyRangeList = ({ + ranges, + rangeOrder, + onChange, + onOrderChange, + errors, +}) => { + const handleChangeRange = (name, newRange) => { + const updatedRanges = { ...ranges }; + updatedRanges[name] = newRange; + onChange(["frequency_ranges_hz"], updatedRanges); + }; + + const handleChangeName = (newName, oldName) => { + if (oldName === newName) { + return; + } + + const updatedRanges = { ...ranges, [newName]: ranges[oldName] }; + delete updatedRanges[oldName]; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = rangeOrder.map((name) => + name === oldName ? newName : name + ); + onOrderChange(updatedOrder); + }; + + const handleRemove = (name) => { + const updatedRanges = { ...ranges }; + delete updatedRanges[name]; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = rangeOrder.filter((item) => item !== name); + onOrderChange(updatedOrder); + }; + + const addRange = () => { + let newName = "NewRange"; + let counter = 0; + // Find first available name + while (Object.hasOwn(ranges, newName)) { + counter++; + newName = `NewRange${counter}`; + } + + const updatedRanges = { + ...ranges, + [newName]: { + __field_type__: "FrequencyRange", + frequency_low_hz: 1, + frequency_high_hz: 500, + }, + }; + onChange(["frequency_ranges_hz"], updatedRanges); + + const updatedOrder = [...rangeOrder, newName]; + onOrderChange(updatedOrder); + }; + + return ( + + {rangeOrder.map((name, index) => ( + + ))} + + + ); +}; diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js deleted file mode 100644 index 16355023..00000000 --- a/gui_dev/src/pages/Settings/index.js +++ /dev/null @@ -1 +0,0 @@ -export { TextField } from './TextField'; diff --git a/gui_dev/src/pages/SourceSelection/FileSelector.jsx b/gui_dev/src/pages/SourceSelection/FileSelector.jsx index ffa53bab..dd0663b1 100644 --- a/gui_dev/src/pages/SourceSelection/FileSelector.jsx +++ b/gui_dev/src/pages/SourceSelection/FileSelector.jsx @@ -25,6 +25,7 @@ export const FileSelector = () => { const [isSelecting, setIsSelecting] = useState(false); const [showFileBrowser, setShowFileBrowser] = useState(false); + const [showFolderBrowser, setShowFolderBrowser] = useState(false); useEffect(() => { setSourceType("lsl"); @@ -48,6 +49,10 @@ export const FileSelector = () => { } }; + const handleFolderSelect = (folder) => { + setShowFolderBrowser(false); + }; + return ( + {streamSetupMessage && ( { onSelect={handleFileSelect} /> )} + {showFolderBrowser && ( + setShowFolderBrowser(false)} + onSelect={handleFolderSelect} + onlyDirectories={true} + /> + )} ); }; diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js index 762d9340..e683a35c 100644 --- a/gui_dev/src/stores/createStore.js +++ b/gui_dev/src/stores/createStore.js @@ -2,16 +2,23 @@ import { create } from "zustand"; import { immer } from "zustand/middleware/immer"; import { devtools, persist as persistMiddleware } from "zustand/middleware"; -export const createStore = (name, initializer, persist = false) => { +export const createStore = ( + name, + initializer, + persist = false, + dev = false +) => { const fn = persist ? persistMiddleware(immer(initializer), name) : immer(initializer); - return create( - devtools(fn, { - name: name, - }) - ); + const dev_fn = dev + ? devtools(fn, { + name: name, + }) + : fn; + + return create(dev_fn); }; export const createPersistStore = (name, initializer) => { diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js index 9b17f42d..68929f1e 100644 --- a/gui_dev/src/stores/settingsStore.js +++ b/gui_dev/src/stores/settingsStore.js @@ -26,11 +26,16 @@ const uploadSettingsToServer = async (settings) => { export const useSettingsStore = createStore("settings", (set, get) => ({ settings: null, + lastValidSettings: null, + frequencyRangeOrder: [], isLoading: false, error: null, + validationErrors: null, retryCount: 0, - setSettings: (settings) => set({ settings }), + updateLocalSettings: (updater) => { + set((state) => updater(state.settings)); + }, fetchSettingsWithDelay: () => { set({ isLoading: true, error: null }); @@ -46,8 +51,15 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ if (!response.ok) { throw new Error("Failed to fetch settings"); } + const data = await response.json(); - set({ settings: data, retryCount: 0 }); + + set({ + settings: data, + lastValidSettings: data, + frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}), + retryCount: 0, + }); } catch (error) { console.log("Error fetching settings:", error); set((state) => ({ @@ -55,6 +67,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ retryCount: state.retryCount + 1, })); + console.log(get().retryCount); + if (get().retryCount < MAX_RETRIES) { await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY)); return get().fetchSettings(); @@ -66,29 +80,62 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ resetRetryCount: () => set({ retryCount: 0 }), - updateSettings: async (updater) => { - const currentSettings = get().settings; + resetSettings: async () => { + await get().fetchSettings(true); + }, - // Apply the update optimistically - set((state) => { - updater(state.settings); - }); + updateFrequencyRangeOrder: (newOrder) => { + set({ frequencyRangeOrder: newOrder }); + }, - const newSettings = get().settings; + uploadSettings: async (updater, validateOnly = false) => { + if (updater) { + set((state) => { + updater(state.settings); + }); + } + + const currentSettings = get().settings; try { - const result = await uploadSettingsToServer(newSettings); + const response = await fetch( + `/api/settings${validateOnly ? "?validate_only=true" : ""}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(currentSettings), + } + ); + + const data = await response.json(); - if (!result.success) { - // Revert the local state if the server update failed - set({ settings: currentSettings }); + if (!response.ok) { + throw new Error("Failed to upload settings to backend"); } - return result; + if (data.valid) { + // Settings are valid + set({ + lastValidSettings: currentSettings, + validationErrors: null, + }); + return true; + } else { + // Settings are invalid + set({ + validationErrors: data.errors, + }); + // Note: We don't revert the settings here, keeping the potentially invalid state + return false; + } } catch (error) { - // Revert the local state if there was an error - set({ settings: currentSettings }); - throw error; + console.error( + `Error ${validateOnly ? "validating" : "updating"} settings:`, + error + ); + return false; } }, })); diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js index a229ec06..25b44382 100644 --- a/gui_dev/src/stores/uiStore.js +++ b/gui_dev/src/stores/uiStore.js @@ -1,4 +1,5 @@ import { createPersistStore } from "./createStore"; +import { useEffect } from "react"; export const useUiStore = createPersistStore("ui", (set, get) => ({ activeDrawer: null, @@ -28,4 +29,24 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({ state.accordionStates[id] = defaultState; } }), + + // Hook to inject UI elements into the status bar + statusBarContent: () => {}, + setStatusBarContent: (content) => set({ statusBarContent: content }), + clearStatusBarContent: () => set({ statusBarContent: null }), })); + +// Use this hook from Page components to inject page-specific UI elements into the status bar +export const useStatusBarContent = (content) => { + const createStatusBarContent = () => content; + + const setStatusBarContent = useUiStore((state) => state.setStatusBarContent); + const clearStatusBarContent = useUiStore( + (state) => state.clearStatusBarContent + ); + + useEffect(() => { + setStatusBarContent(createStatusBarContent); + return () => clearStatusBarContent(); + }, [content, setStatusBarContent, clearStatusBarContent]); +}; diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js index 5e2b5f11..cc99909c 100644 --- a/gui_dev/src/theme.js +++ b/gui_dev/src/theme.js @@ -22,6 +22,11 @@ export const theme = createTheme({ disableRipple: true, }, }, + MuiTextField: { + defaultProps: { + autoComplete: "off", + }, + }, MuiStack: { defaultProps: { alignItems: "center", diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/functions.js index 04289d5c..1f8ae90a 100644 --- a/gui_dev/src/utils/functions.js +++ b/gui_dev/src/utils/functions.js @@ -17,6 +17,7 @@ export function debounce(func, wait) { timeout = setTimeout(later, wait); }; } + export const flattenDictionary = (dict, parentKey = "", result = {}) => { for (let key in dict) { const newKey = parentKey ? `${parentKey}.${key}` : key; diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py index 688393fd..864df345 100644 --- a/py_neuromodulation/__init__.py +++ b/py_neuromodulation/__init__.py @@ -4,11 +4,12 @@ from importlib.metadata import version from py_neuromodulation.utils.logging import NMLogger + ##################################### # Globals and environment variables # ##################################### -__version__ = version("py_neuromodulation") # get version from pyproject.toml +__version__ = version("py_neuromodulation") # Check if the module is running headless (no display) for tests and doc builds PYNM_HEADLESS: bool = not os.environ.get("DISPLAY") @@ -18,6 +19,7 @@ os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file + # Set environment variable MNE_LSL_LIB (required to import Stream below) LSL_DICT = { "windows_32bit": "windows/x86/liblsl.1.16.2.dll", @@ -36,6 +38,7 @@ PLATFORM = platform.system().lower().strip() ARCH = platform.architecture()[0] + match PLATFORM: case "windows": KEY = PLATFORM + "_" + ARCH diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index c50b5f8e..e68194f8 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -1,4 +1,15 @@ ---- +# We +# should +# have +# a +# brief +# explanation +# of +# the +# settings +# format +# here + ######################## ### General settings ### ######################## @@ -51,12 +62,8 @@ preprocessing_filter: lowpass_filter: true highpass_filter: true bandpass_filter: true - bandstop_filter_settings: - frequency_low_hz: 100 - frequency_high_hz: 160 - bandpass_filter_settings: - frequency_low_hz: 3 - frequency_high_hz: 200 + bandstop_filter_settings: [100, 160] # [low_hz, high_hz] + bandpass_filter_settings: [3, 200] # [hz, _hz] lowpass_filter_cutoff_hz: 200 highpass_filter_cutoff_hz: 3 @@ -162,11 +169,9 @@ sharpwave_analysis_settings: rise_steepness: false decay_steepness: false slope_ratio: false - filter_ranges_hz: - - frequency_low_hz: 5 - frequency_high_hz: 80 - - frequency_low_hz: 5 - frequency_high_hz: 30 + filter_ranges_hz: # list of [low_hz, high_hz] + - [5, 80] + - [5, 30] detect_troughs: estimate: true distance_troughs_ms: 10 @@ -175,6 +180,7 @@ sharpwave_analysis_settings: estimate: true distance_troughs_ms: 5 distance_peaks_ms: 10 + # TONI: Reverse this setting? e.g. interval: [mean, var] estimator: mean: [interval] median: [] @@ -223,8 +229,8 @@ nolds_settings: frequency_bands: [low_beta] mne_connectiviy_settings: - method: plv - mode: multitaper + method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr'] + mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet'] bispectrum_settings: f1s: [5, 35] diff --git a/py_neuromodulation/features/bandpower.py b/py_neuromodulation/features/bandpower.py index dac13f4b..72dab126 100644 --- a/py_neuromodulation/features/bandpower.py +++ b/py_neuromodulation/features/bandpower.py @@ -2,8 +2,13 @@ from collections.abc import Sequence from typing import TYPE_CHECKING from pydantic import field_validator - from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature +from py_neuromodulation.utils.pydantic_extensions import ( + NMField, + NMErrorList, + create_validation_error, +) +from py_neuromodulation import logger if TYPE_CHECKING: from py_neuromodulation.stream.settings import NMSettings @@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector): class BandPowerSettings(NMBaseModel): - segment_lengths_ms: dict[str, int] = { - "theta": 1000, - "alpha": 500, - "low beta": 333, - "high beta": 333, - "low gamma": 100, - "high gamma": 100, - "HFA": 100, - } + segment_lengths_ms: dict[str, int] = NMField( + default={ + "theta": 1000, + "alpha": 500, + "low beta": 333, + "high beta": 333, + "low gamma": 100, + "high gamma": 100, + "HFA": 100, + }, + custom_metadata={"field_type": "FrequencySegmentLength"}, + ) bandpower_features: BandpowerFeatures = BandpowerFeatures() log_transform: bool = True kalman_filter: bool = False @@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel): @field_validator("bandpower_features") @classmethod def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures): - assert ( - len(bandpower_features.get_enabled()) > 0 - ), "Set at least one bandpower_feature to True." - + if not len(bandpower_features.get_enabled()) > 0: + raise create_validation_error( + error_message="Set at least one bandpower_feature to True.", + location=["bandpass_filter_settings", "bandpower_features"], + ) return bandpower_features - def validate_fbands(self, settings: "NMSettings") -> None: + def validate_fbands(self, settings: "NMSettings") -> NMErrorList: + """_summary_ + + :param settings: _description_ + :type settings: NMSettings + :raises create_validation_error: _description_ + :raises create_validation_error: _description_ + :raises ValueError: _description_ + """ + errors: NMErrorList = NMErrorList() + for fband_name, seg_length_fband in self.segment_lengths_ms.items(): - assert seg_length_fband <= settings.segment_length_features_ms, ( - f"segment length {seg_length_fband} needs to be smaller than " - f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}" - ) + # Check that all frequency bands are defined in settings.frequency_ranges_hz + if fband_name not in settings.frequency_ranges_hz: + logger.warning( + f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms" + " is not defined in settings.frequency_ranges_hz" + ) + + # Check that all segment lengths are smaller than settings.segment_length_features_ms + if not seg_length_fband <= settings.segment_length_features_ms: + errors.add_error( + f"segment length {seg_length_fband} needs to be smaller than " + f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}", + location=[ + "bandpass_filter_settings", + "segment_lengths_ms", + fband_name, + ], + ) + # Check that all frequency bands defined in settings.frequency_ranges_hz for fband_name in settings.frequency_ranges_hz.keys(): - assert fband_name in self.segment_lengths_ms, ( - f"frequency range {fband_name} " - "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]" - ) + if fband_name not in self.segment_lengths_ms: + errors.add_error( + f"frequency range {fband_name} " + "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms", + location=[ + "bandpass_filter_settings", + "segment_lengths_ms", + fband_name, + ], + ) + + return errors class BandPower(NMFeature): diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py index 83d63bf8..e9f5b632 100644 --- a/py_neuromodulation/features/bursts.py +++ b/py_neuromodulation/features/bursts.py @@ -7,11 +7,12 @@ from collections.abc import Sequence from itertools import product -from pydantic import Field, field_validator +from pydantic import field_validator from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING, Callable -from py_neuromodulation.utils.types import create_validation_error +from py_neuromodulation.utils.pydantic_extensions import create_validation_error if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector): class BurstsSettings(NMBaseModel): - threshold: float = Field(default=75, ge=0, le=100) - time_duration_s: float = Field(default=30, ge=0) + threshold: float = NMField(default=75, ge=0) + time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"}) frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"] burst_features: BurstFeatures = BurstFeatures() diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py index 21ca471b..6a100710 100644 --- a/py_neuromodulation/features/coherence.py +++ b/py_neuromodulation/features/coherence.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Annotated -from pydantic import Field, field_validator +from pydantic import field_validator from py_neuromodulation.utils.types import ( NMFeature, @@ -11,6 +11,7 @@ FrequencyRange, NMBaseModel, ) +from py_neuromodulation.utils.pydantic_extensions import NMField from py_neuromodulation import logger if TYPE_CHECKING: @@ -28,14 +29,16 @@ class CoherenceFeatures(BoolSelector): max_allfbands: bool = True -ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)] +# TODO: make this into a pydantic model that only accepts names from +# the channels objects and does not accept the same string twice +ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)] class CoherenceSettings(NMBaseModel): features: CoherenceFeatures = CoherenceFeatures() method: CoherenceMethods = CoherenceMethods() channels: list[ListOfTwoStr] = [] - frequency_bands: list[str] = Field(default=["high_beta"], min_length=1) + frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1) @field_validator("frequency_bands") def fbands_spaces_to_underscores(cls, frequency_bands): diff --git a/py_neuromodulation/features/feature_processor.py b/py_neuromodulation/features/feature_processor.py index 9e970e8f..dbb7e5c9 100644 --- a/py_neuromodulation/features/feature_processor.py +++ b/py_neuromodulation/features/feature_processor.py @@ -1,13 +1,13 @@ from typing import Type, TYPE_CHECKING -from py_neuromodulation.utils.types import NMFeature, FeatureName +from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME if TYPE_CHECKING: import numpy as np from py_neuromodulation import NMSettings -FEATURE_DICT: dict[FeatureName | str, str] = { +FEATURE_DICT: dict[FEATURE_NAME | str, str] = { "raw_hjorth": "Hjorth", "return_raw": "Raw", "bandpass_filter": "BandPower", @@ -42,7 +42,7 @@ def __init__( from importlib import import_module # Accept 'str' for custom features - self.features: dict[FeatureName | str, NMFeature] = { + self.features: dict[FEATURE_NAME | str, NMFeature] = { feature_name: getattr( import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name] )(settings, ch_names, sfreq) @@ -83,7 +83,7 @@ def estimate_features(self, data: "np.ndarray") -> dict: return feature_results - def get_feature(self, fname: FeatureName) -> NMFeature: + def get_feature(self, fname: FEATURE_NAME) -> NMFeature: return self.features[fname] diff --git a/py_neuromodulation/features/fooof.py b/py_neuromodulation/features/fooof.py index 683c5304..9f4b9544 100644 --- a/py_neuromodulation/features/fooof.py +++ b/py_neuromodulation/features/fooof.py @@ -9,6 +9,7 @@ BoolSelector, FrequencyRange, ) +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector): class FooofSettings(NMBaseModel): aperiodic: FooofAperiodicSettings = FooofAperiodicSettings() periodic: FooofPeriodicSettings = FooofPeriodicSettings() - windowlength_ms: float = 800 + windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"}) peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12) - max_n_peaks: int = 3 - min_peak_height: float = 0 - peak_threshold: float = 2 + max_n_peaks: int = NMField(3, ge=0) + min_peak_height: float = NMField(0, ge=0) + peak_threshold: float = NMField(2, ge=0) freq_range_hz: FrequencyRange = FrequencyRange(2, 40) knee: bool = True diff --git a/py_neuromodulation/features/mne_connectivity.py b/py_neuromodulation/features/mne_connectivity.py index 88883771..f7a186e7 100644 --- a/py_neuromodulation/features/mne_connectivity.py +++ b/py_neuromodulation/features/mne_connectivity.py @@ -1,8 +1,9 @@ from collections.abc import Iterable import numpy as np -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from py_neuromodulation.utils.types import NMFeature, NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings @@ -10,9 +11,30 @@ from mne import Epochs +MNE_CONNECTIVITY_METHOD = Literal[ + "coh", + "cohy", + "imcoh", + "cacoh", + "mic", + "mim", + "plv", + "ciplv", + "ppc", + "pli", + "dpli", + "wpli", + "wpli2_debiased", + "gc", + "gc_tr", +] + +MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"] + + class MNEConnectivitySettings(NMBaseModel): - method: str = "plv" - mode: str = "multitaper" + method: MNE_CONNECTIVITY_METHOD = NMField(default="plv") + mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper") class MNEConnectivity(NMFeature): diff --git a/py_neuromodulation/features/oscillatory.py b/py_neuromodulation/features/oscillatory.py index ba376cc6..bb4b9e6a 100644 --- a/py_neuromodulation/features/oscillatory.py +++ b/py_neuromodulation/features/oscillatory.py @@ -3,6 +3,7 @@ from itertools import product from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -17,7 +18,7 @@ class OscillatoryFeatures(BoolSelector): class OscillatorySettings(NMBaseModel): - windowlength_ms: int = 1000 + windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"}) log_transform: bool = True features: OscillatoryFeatures = OscillatoryFeatures( mean=True, median=False, std=False, max=False diff --git a/py_neuromodulation/filter/kalman_filter.py b/py_neuromodulation/filter/kalman_filter.py index b7ae1075..0ebfe195 100644 --- a/py_neuromodulation/filter/kalman_filter.py +++ b/py_neuromodulation/filter/kalman_filter.py @@ -1,7 +1,9 @@ import numpy as np from typing import TYPE_CHECKING + from py_neuromodulation.utils.types import NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMErrorList if TYPE_CHECKING: @@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel): "HFA", ] - def validate_fbands(self, settings: "NMSettings") -> None: - assert all( + def validate_fbands(self, settings: "NMSettings") -> NMErrorList: + errors: NMErrorList = NMErrorList() + + if not all( (item in settings.frequency_ranges_hz for item in self.frequency_bands) - ), ( - "Frequency bands for Kalman filter must also be specified in " - "bandpass_filter_settings." - ) + ): + errors.add_error( + "Frequency bands for Kalman filter must also be specified in " + "frequency_ranges_hz.", + location=[ + "kalman_filter_settings", + "frequency_bands", + ], + ) + + return errors def define_KF(Tp, sigma_w, sigma_v): diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py index 06b929f0..01858fe6 100644 --- a/py_neuromodulation/gui/backend/app_backend.py +++ b/py_neuromodulation/gui/backend/app_backend.py @@ -13,6 +13,7 @@ ) from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from pydantic import ValidationError from . import app_pynm from .app_socket import WebSocketManager @@ -74,21 +75,63 @@ async def healthcheck(): #################### ##### SETTINGS ##### #################### - @self.get("/api/settings") - async def get_settings(): - return self.pynm_state.settings.process_for_frontend() + async def get_settings( + reset: bool = Query(False, description="Reset settings to default"), + ): + if reset: + settings = NMSettings.get_default() + else: + settings = self.pynm_state.settings + + return settings.serialize_with_metadata() @self.post("/api/settings") - async def update_settings(data: dict): + async def update_settings(data: dict, validate_only: bool = Query(False)): try: - self.pynm_state.settings = NMSettings.model_validate(data) - self.logger.info(self.pynm_state.settings.features) - return self.pynm_state.settings.model_dump() - except ValueError as e: + # First, validate with Pydantic + try: + # TODO: check if this works properly or needs model_validate_strings + validated_settings = NMSettings.model_validate(data) + except ValidationError as e: + if not validate_only: + # If validation failed but we wanted to upload, return error + self.logger.error(f"Error validating settings: {e}") + raise HTTPException( + status_code=422, + detail={ + "error": "Error validating settings", + "details": str(e), + }, + ) + # Else return list of errors + return { + "valid": False, + "errors": [err for err in e.errors()], + "details": str(e), + } + + # If validation succesful, return or update settings + if validate_only: + return { + "valid": True, + "settings": validated_settings.serialize_with_metadata(), + } + + self.pynm_state.settings = validated_settings + self.logger.info("Settings successfully updated") + + return { + "valid": True, + "settings": self.pynm_state.settings.serialize_with_metadata(), + } + + # If something else than validation went wrong, return error + except Exception as e: + self.logger.error(f"Error validating/updating settings: {e}") raise HTTPException( status_code=422, - detail={"error": "Validation failed", "details": str(e)}, + detail={"error": "Error uploading settings", "details": str(e)}, ) ######################## @@ -105,10 +148,9 @@ async def handle_stream_control(data: dict): self.logger.info("Starting stream") self.pynm_state.start_run_function( - websocket_manager=self.websocket_manager, + websocket_manager=self.websocket_manager, ) - if action == "stop": self.logger.info("Stopping stream") self.pynm_state.stream_handling_queue.put("stop") diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py index 523a4f62..d12ef89c 100644 --- a/py_neuromodulation/gui/backend/app_pynm.py +++ b/py_neuromodulation/gui/backend/app_pynm.py @@ -12,8 +12,13 @@ from py_neuromodulation.utils.io import read_mne_data from py_neuromodulation import logger -async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue.Queue, - websocket_manager: "WebSocketManager", stop_event: threading.Event): + +async def run_stream_controller( + feature_queue: queue.Queue, + rawdata_queue: queue.Queue, + websocket_manager: "WebSocketManager", + stop_event: threading.Event, +): while not stop_event.wait(0.002): if not feature_queue.empty() and websocket_manager is not None: feature_dict = feature_queue.get() @@ -22,16 +27,23 @@ async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue if not rawdata_queue.empty() and websocket_manager is not None: raw_data = rawdata_queue.get() - + await websocket_manager.send_cbor(raw_data) -def run_stream_controller_sync(feature_queue: queue.Queue, - rawdata_queue: queue.Queue, - websocket_manager: "WebSocketManager", - stop_event: threading.Event - ): + +def run_stream_controller_sync( + feature_queue: queue.Queue, + rawdata_queue: queue.Queue, + websocket_manager: "WebSocketManager", + stop_event: threading.Event, +): # The run_stream_controller needs to be started as an asyncio function due to the async websocket - asyncio.run(run_stream_controller(feature_queue, rawdata_queue, websocket_manager, stop_event)) + asyncio.run( + run_stream_controller( + feature_queue, rawdata_queue, websocket_manager, stop_event + ) + ) + class PyNMState: def __init__( @@ -52,12 +64,10 @@ def __init__( self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1])) self.settings: NMSettings = NMSettings(sampling_rate_features=10) - def start_run_function( self, websocket_manager=None, ) -> None: - self.stream.settings = self.settings self.stream_handling_queue = queue.Queue() @@ -70,7 +80,7 @@ def start_run_function( if os.path.exists(self.decoding_model_path): self.decoder = RealTimeDecoder(self.decoding_model_path) else: - logger.debug("Passed decoding model path does't exist") + logger.debug("Passed decoding model path does't exist") # Stop event # .set() is called from app_backend self.stop_event_ws = threading.Event() @@ -78,16 +88,19 @@ def start_run_function( self.stream_controller_thread = Thread( target=run_stream_controller_sync, daemon=True, - args=(self.feature_queue, - self.rawdata_queue, - websocket_manager, - self.stop_event_ws - ), + args=( + self.feature_queue, + self.rawdata_queue, + websocket_manager, + self.stop_event_ws, + ), ) is_stream_lsl = self.lsl_stream_name is not None - stream_lsl_name = self.lsl_stream_name if self.lsl_stream_name is not None else "" - + stream_lsl_name = ( + self.lsl_stream_name if self.lsl_stream_name is not None else "" + ) + # The run_func_thread is terminated through the stream_handling_queue # which initiates to break the data generator and save the features out_dir = "" if self.out_dir == "default" else self.out_dir @@ -95,15 +108,15 @@ def start_run_function( target=self.stream.run, daemon=True, kwargs={ - "out_dir" : out_dir, - "experiment_name" : self.experiment_name, - "stream_handling_queue" : self.stream_handling_queue, - "is_stream_lsl" : is_stream_lsl, - "stream_lsl_name" : stream_lsl_name, - "feature_queue" : self.feature_queue, - "simulate_real_time" : True, - "rawdata_queue" : self.rawdata_queue, - "decoder" : self.decoder, + "out_dir": out_dir, + "experiment_name": self.experiment_name, + "stream_handling_queue": self.stream_handling_queue, + "is_stream_lsl": is_stream_lsl, + "stream_lsl_name": stream_lsl_name, + "feature_queue": self.feature_queue, + "simulate_real_time": True, + "rawdata_queue": self.rawdata_queue, + "decoder": self.decoder, }, ) @@ -157,10 +170,10 @@ def setup_lsl_stream( settings=self.settings, ) self.logger.info("stream setup") - #self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq) + # self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq) self.logger.info("settings setup") break - + if channels.shape[0] == 0: self.logger.error(f"Stream {lsl_stream_name} not found") raise ValueError(f"Stream {lsl_stream_name} not found") diff --git a/py_neuromodulation/processing/data_preprocessor.py b/py_neuromodulation/processing/data_preprocessor.py index 319926d9..3b9a9088 100644 --- a/py_neuromodulation/processing/data_preprocessor.py +++ b/py_neuromodulation/processing/data_preprocessor.py @@ -1,12 +1,12 @@ from typing import TYPE_CHECKING, Type -from py_neuromodulation.utils.types import PreprocessorName, NMPreprocessor +from py_neuromodulation.utils.types import PREPROCESSOR_NAME, NMPreprocessor if TYPE_CHECKING: import numpy as np import pandas as pd from py_neuromodulation.stream.settings import NMSettings -PREPROCESSOR_DICT: dict[PreprocessorName, str] = { +PREPROCESSOR_DICT: dict[PREPROCESSOR_NAME, str] = { "preprocessing_filter": "PreprocessingFilter", "notch_filter": "NotchFilter", "raw_resampling": "Resampler", diff --git a/py_neuromodulation/processing/filter_preprocessing.py b/py_neuromodulation/processing/filter_preprocessing.py index e6515ae0..2c5addbc 100644 --- a/py_neuromodulation/processing/filter_preprocessing.py +++ b/py_neuromodulation/processing/filter_preprocessing.py @@ -1,22 +1,14 @@ import numpy as np -from pydantic import Field from typing import TYPE_CHECKING from py_neuromodulation.utils.types import BoolSelector, FrequencyRange, NMPreprocessor +from py_neuromodulation.utils.pydantic_extensions import NMField if TYPE_CHECKING: from py_neuromodulation import NMSettings -FILTER_SETTINGS_MAP = { - "bandstop_filter": "bandstop_filter_settings", - "bandpass_filter": "bandpass_filter_settings", - "lowpass_filter": "lowpass_filter_cutoff_hz", - "highpass_filter": "highpass_filter_cutoff_hz", -} - - class FilterSettings(BoolSelector): bandstop_filter: bool = True bandpass_filter: bool = True @@ -25,21 +17,23 @@ class FilterSettings(BoolSelector): bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160) bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200) - lowpass_filter_cutoff_hz: float = Field(default=200) - highpass_filter_cutoff_hz: float = Field(default=3) - - def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]: - filter_value = self[FILTER_SETTINGS_MAP[filter_name]] - + lowpass_filter_cutoff_hz: float = NMField( + default=200, gt=0, custom_metadata={"unit": "Hz"} + ) + highpass_filter_cutoff_hz: float = NMField( + default=3, gt=0, custom_metadata={"unit": "Hz"} + ) + + def get_filter_tuple(self, filter_name) -> FrequencyRange: match filter_name: case "bandstop_filter": - return (filter_value.frequency_high_hz, filter_value.frequency_low_hz) + return self.bandstop_filter_settings case "bandpass_filter": - return (filter_value.frequency_low_hz, filter_value.frequency_high_hz) + return self.bandpass_filter_settings case "lowpass_filter": - return (None, filter_value) + return FrequencyRange(None, self.lowpass_filter_cutoff_hz) case "highpass_filter": - return (filter_value, None) + return FrequencyRange(self.highpass_filter_cutoff_hz, None) case _: raise ValueError( "Filter name must be one of 'bandstop_filter', 'lowpass_filter', " diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py index e3f25d55..713e2308 100644 --- a/py_neuromodulation/processing/normalization.py +++ b/py_neuromodulation/processing/normalization.py @@ -3,10 +3,10 @@ import numpy as np from typing import TYPE_CHECKING, Callable, Literal, get_args +from py_neuromodulation.utils.pydantic_extensions import NMField from py_neuromodulation.utils.types import ( NMBaseModel, - Field, - NormMethod, + NORM_METHOD, NMPreprocessor, ) @@ -17,13 +17,13 @@ class NormalizationSettings(NMBaseModel): - normalization_time_s: float = 30 - normalization_method: NormMethod = "zscore" - clip: float = Field(default=3, ge=0) + normalization_time_s: float = NMField(30, gt=0, custom_metadata={"unit": "s"}) + normalization_method: NORM_METHOD = NMField(default="zscore") + clip: float = NMField(default=3, ge=0, custom_metadata={"unit": "a.u."}) @staticmethod - def list_normalization_methods() -> list[NormMethod]: - return list(get_args(NormMethod)) + def list_normalization_methods() -> list[NORM_METHOD]: + return list(get_args(NORM_METHOD)) class FeatureNormalizationSettings(NormalizationSettings): normalize_psd: bool = False @@ -60,7 +60,7 @@ def __init__( if self.using_sklearn: import sklearn.preprocessing as skpp - NORM_METHODS_SKLEARN: dict[NormMethod, Callable] = { + NORM_METHODS_SKLEARN: dict[NORM_METHOD, Callable] = { "quantile": lambda: skpp.QuantileTransformer(n_quantiles=300), "robust": skpp.RobustScaler, "minmax": skpp.MinMaxScaler, diff --git a/py_neuromodulation/processing/projection.py b/py_neuromodulation/processing/projection.py index 4cdc3610..3a83a29f 100644 --- a/py_neuromodulation/processing/projection.py +++ b/py_neuromodulation/processing/projection.py @@ -1,6 +1,6 @@ import numpy as np -from pydantic import Field from py_neuromodulation.utils.types import NMBaseModel +from py_neuromodulation.utils.pydantic_extensions import NMField from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -9,7 +9,7 @@ class ProjectionSettings(NMBaseModel): - max_dist_mm: float = Field(default=20.0, gt=0.0) + max_dist_mm: float = NMField(default=20.0, gt=0.0, custom_metadata={"unit": "mm"}) class Projection: diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py index 08a4e115..1b7107eb 100644 --- a/py_neuromodulation/processing/resample.py +++ b/py_neuromodulation/processing/resample.py @@ -1,11 +1,14 @@ """Module for resampling.""" import numpy as np -from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor +from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor +from py_neuromodulation.utils.pydantic_extensions import NMField class ResamplerSettings(NMBaseModel): - resample_freq_hz: float = Field(default=1000, gt=0) + resample_freq_hz: float = NMField( + default=1000, gt=0, custom_metadata={"unit": "Hz"} + ) class Resampler(NMPreprocessor): diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py index 669c97db..bb3ca7cd 100644 --- a/py_neuromodulation/stream/settings.py +++ b/py_neuromodulation/stream/settings.py @@ -1,19 +1,22 @@ """Module for handling settings.""" from pathlib import Path -from typing import ClassVar -from pydantic import Field, model_validator +from typing import Any, ClassVar +from pydantic import model_validator, ValidationError +from pydantic.functional_validators import ModelWrapValidatorHandler -from py_neuromodulation import PYNM_DIR, logger, user_features +from py_neuromodulation import logger, user_features +from types import SimpleNamespace from py_neuromodulation.utils.types import ( BoolSelector, FrequencyRange, - PreprocessorName, _PathLike, NMBaseModel, - NormMethod, + NORM_METHOD, + PreprocessorList, ) +from py_neuromodulation.utils.pydantic_extensions import NMErrorList, NMField from py_neuromodulation.processing.filter_preprocessing import FilterSettings from py_neuromodulation.processing.normalization import FeatureNormalizationSettings, NormalizationSettings @@ -31,7 +34,9 @@ from py_neuromodulation.features import BurstsSettings -class FeatureSelection(BoolSelector): +# TONI: this class has the proble that if a feature is absent, +# it won't default to False but to whatever is defined here as default +class FeatureSelector(BoolSelector): raw_hjorth: bool = True return_raw: bool = True bandpass_filter: bool = False @@ -59,8 +64,12 @@ class NMSettings(NMBaseModel): _instances: ClassVar[list["NMSettings"]] = [] # General settings - sampling_rate_features_hz: float = Field(default=10, gt=0) - segment_length_features_ms: float = Field(default=1000, gt=0) + sampling_rate_features_hz: float = NMField( + default=10, gt=0, custom_metadata={"unit": "Hz"} + ) + segment_length_features_ms: float = NMField( + default=1000, gt=0, custom_metadata={"unit": "ms"} + ) frequency_ranges_hz: dict[str, FrequencyRange] = { "theta": FrequencyRange(4, 8), "alpha": FrequencyRange(8, 12), @@ -72,11 +81,14 @@ class NMSettings(NMBaseModel): } # Preproceessing settings - preprocessing: list[PreprocessorName] = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + preprocessing: PreprocessorList = PreprocessorList( + [ + "raw_resampling", + "notch_filter", + "re_referencing", + ] + ) + raw_resampling_settings: ResamplerSettings = ResamplerSettings() preprocessing_filter: FilterSettings = FilterSettings() raw_normalization_settings: NormalizationSettings = NormalizationSettings() @@ -88,7 +100,7 @@ class NMSettings(NMBaseModel): project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5) # Feature settings - features: FeatureSelection = FeatureSelection() + features: FeatureSelector = FeatureSelector() fft_settings: OscillatorySettings = OscillatorySettings() welch_settings: OscillatorySettings = OscillatorySettings() @@ -126,10 +138,23 @@ def _remove_feature(cls, feature: str) -> None: for instance in cls._instances: delattr(instance.features, feature) - @model_validator(mode="after") - def validate_settings(self): + @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride] + def validate_settings(self, handler: ModelWrapValidatorHandler) -> Any: + # Perform all necessary custom validations in the settings class and also + # all validations in the feature classes that need additional information from + # the settings class + errors: NMErrorList = NMErrorList() + + try: + # validate the model + self = handler(self) + except ValidationError as e: + self = NMSettings.unvalidated(**self) + NMSettings.model_fields_set + errors.extend(NMErrorList(e.errors())) + if len(self.features.get_enabled()) == 0: - raise ValueError("At least one feature must be selected.") + errors.add_error("At least one feature must be selected.") # Replace spaces with underscores in frequency band names self.frequency_ranges_hz = { @@ -138,32 +163,33 @@ def validate_settings(self): if self.features.bandpass_filter: # Check BandPass settings frequency bands - self.bandpass_filter_settings.validate_fbands(self) + errors.extend(self.bandpass_filter_settings.validate_fbands(self)) # Check Kalman filter frequency bands if self.bandpass_filter_settings.kalman_filter: - self.kalman_filter_settings.validate_fbands(self) + errors.extend(self.kalman_filter_settings.validate_fbands(self)) - for k, v in self.frequency_ranges_hz.items(): - if not isinstance(v, FrequencyRange): - self.frequency_ranges_hz[k] = FrequencyRange.create_from(v) + if len(errors) > 0: + raise errors.create_error() return self def reset(self) -> "NMSettings": self.features.disable_all() - self.preprocessing = [] + self.preprocessing = PreprocessorList() self.postprocessing.disable_all() return self def set_fast_compute(self) -> "NMSettings": self.reset() self.features.fft = True - self.preprocessing = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + self.preprocessing = PreprocessorList( + [ + "raw_resampling", + "notch_filter", + "re_referencing", + ] + ) self.postprocessing.feature_normalization = True self.postprocessing.project_cortex = False self.postprocessing.project_subcortex = False @@ -250,10 +276,10 @@ def from_file(PATH: _PathLike) -> "NMSettings": @staticmethod def get_default() -> "NMSettings": - return NMSettings.from_file(PYNM_DIR / "default_settings.yaml") + return NMSettings() @staticmethod - def list_normalization_methods() -> list[NormMethod]: + def list_normalization_methods() -> list[NORM_METHOD]: return NormalizationSettings.list_normalization_methods() def save( diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py index 2ee03421..df8338aa 100644 --- a/py_neuromodulation/stream/stream.py +++ b/py_neuromodulation/stream/stream.py @@ -85,7 +85,7 @@ def __init__( raise ValueError("Either `channels` or `data` must be passed to `Stream`.") # If features that use frequency ranges are on, test them against nyquist frequency - use_freq_ranges: list[FeatureName] = [ + use_freq_ranges: list[FEATURE_NAME] = [ "bandpass_filter", "stft", "fft", @@ -209,7 +209,7 @@ def run( return_df: bool = True, simulate_real_time: bool = False, decoder: RealTimeDecoder = None, - feature_queue: "queue.Queue | None" = None, + feature_queue: "queue.Queue | None" = None, rawdata_queue: "queue.Queue | None" = None, stream_handling_queue: "queue.Queue | None" = None, ): @@ -303,7 +303,9 @@ def run( if decoder is not None: ch_to_decode = self.channels.query("used == 1").iloc[0]["name"] - feature_dict = decoder.predict(feature_dict, ch_to_decode, fft_bands_only=True) + feature_dict = decoder.predict( + feature_dict, ch_to_decode, fft_bands_only=True + ) feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1) @@ -346,8 +348,9 @@ def run( self._save_after_stream() self.is_running = False - return feature_df # Timon: We could think of returning the feature_reader instead - + return ( + feature_df # Timon: We could think of returning the feature_reader instead + ) def plot_raw_signal( self, @@ -437,7 +440,7 @@ def _save_sidecar(self) -> None: """Save sidecar incduing fs, coords, sess_right to out_path_root and subfolder 'folder_name'""" additional_args = {"sess_right": self.sess_right} - + self.data_processor.save_sidecar( self.out_dir, self.experiment_name, additional_args ) diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py new file mode 100644 index 00000000..ee89ccea --- /dev/null +++ b/py_neuromodulation/utils/pydantic_extensions.py @@ -0,0 +1,420 @@ +import copy +from typing import ( + Any, + get_origin, + get_args, + get_type_hints, + TypeVar, + Generic, + Literal, + cast, + Sequence, +) +from typing_extensions import Unpack, TypedDict +from pydantic import BaseModel, ConfigDict, model_validator, model_serializer + +from pydantic_core import ( + ErrorDetails, + PydanticUndefined, + InitErrorDetails, + ValidationError, +) +from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs +from pprint import pformat + + +def create_validation_error( + error_message: str, + location: list[str | int] = [], + title: str = "Validation Error", + error_type="value_error", +) -> ValidationError: + """ + Factory function to create a Pydantic v2 ValidationError. + + Args: + error_message (str): The error message for the ValidationError. + loc (List[str | int], optional): The location of the error. Defaults to None. + title (str, optional): The title of the error. Defaults to "Validation Error". + + Returns: + ValidationError: A Pydantic ValidationError instance. + """ + + return ValidationError.from_exception_data( + title=title, + line_errors=[ + InitErrorDetails( + type=error_type, + loc=tuple(location), + input=None, + ctx={"error": error_message}, + ) + ], + input_type="python", + hide_input=False, + ) + + +class NMErrorList: + """Class to handle data about Pydantic errors. + Stores data in a list of InitErrorDetails. Errors can be accessed but not modified. + + :return: _description_ + :rtype: _type_ + """ + + def __init__( + self, errors: Sequence[InitErrorDetails | ErrorDetails] | None = None + ) -> None: + self.__errors: list[InitErrorDetails | ErrorDetails] = [e for e in errors or []] + + def add_error( + self, + error_message: str, + location: list[str | int] = [], + error_type="value_error", + ) -> None: + self.__errors.append( + InitErrorDetails( + type=error_type, + loc=tuple(location), + input=None, + ctx={"error": error_message}, + ) + ) + + def create_error(self, title: str = "Validation Error") -> ValidationError: + return ValidationError.from_exception_data( + title=title, line_errors=cast(list[InitErrorDetails], self.__errors) + ) + + def extend(self, errors: "NMErrorList"): + self.__errors.extend(errors.__errors) + + def __iter__(self): + return iter(self.__errors) + + def __len__(self): + return len(self.__errors) + + def __getitem__(self, idx): + # Return a copy of the error to prevent modification + return copy.deepcopy(self.__errors[idx]) + + def __repr__(self): + return repr(self.__errors) + + def __str__(self): + return str(self.__errors) + + +class _NMExtraFieldInputs(TypedDict, total=False): + """Additional fields to add on top of the pydantic FieldInfo""" + + custom_metadata: dict[str, Any] + + +class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo inputs with PyNM additional inputs""" + + pass + + +class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs""" + + pass + + +class NMFieldInfo(FieldInfo): + # Add default values for any other custom fields here + _default_values = {} + + def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None: + self.custom_metadata: dict[str, Any] = kwargs.pop("custom_metadata", {}) + super().__init__(**kwargs) + + @staticmethod + def from_field( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], + ) -> "NMFieldInfo": + if "annotation" in kwargs: + raise TypeError('"annotation" is not permitted as a Field keyword argument') + return NMFieldInfo(default=default, **kwargs) + + def __repr_args__(self): + yield from super().__repr_args__() + extra_fields = get_type_hints(_NMExtraFieldInputs) + for field in extra_fields: + value = getattr(self, field) + yield field, value + + +def NMField( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], +) -> Any: + return NMFieldInfo.from_field(default=default, **kwargs) + + +class NMBaseModel(BaseModel): + # model_config = ConfigDict(validate_assignment=False, extra="allow") + + def __init__(self, *args, **kwargs) -> None: + """Pydantic does not support positional arguments by default. + This is a workaround to support positional arguments for models like FrequencyRange. + It converts positional arguments to kwargs and then calls the base class __init__. + """ + + if not args: + # Simple case - just use kwargs + super().__init__(*args, **kwargs) + return + + field_names = list(self.model_fields.keys()) + # If we have more positional args than fields, that's an error + if len(args) > len(field_names): + raise ValueError( + f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}" + ) + + # Convert positional args to kwargs, checking for duplicates if args: + complete_kwargs = {} + for i, arg in enumerate(args): + field_name = field_names[i] + if field_name in kwargs: + raise ValueError( + f"Got multiple values for field '{field_name}': " + f"positional argument and keyword argument" + ) + complete_kwargs[field_name] = arg + + # Add remaining kwargs + complete_kwargs.update(kwargs) + super().__init__(**complete_kwargs) + + __init__.__pydantic_base_init__ = True # type: ignore + + def __str__(self): + return pformat(self.model_dump()) + + # def __repr__(self): + # return pformat(self.model_dump()) + + def validate(self, context: Any | None = None) -> Any: # type: ignore + return self.model_validate(self.model_dump(), context=context) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value) -> None: + setattr(self, key, value) + + @property + def fields(self) -> dict[str, FieldInfo | NMFieldInfo]: + return self.model_fields # type: ignore + + def serialize_with_metadata(self): + result: dict[str, Any] = {"__field_type__": self.__class__.__name__} + + for field_name, field_info in self.fields.items(): + value = getattr(self, field_name) + if isinstance(value, NMBaseModel): + result[field_name] = value.serialize_with_metadata() + elif isinstance(value, list): + result[field_name] = [ + item.serialize_with_metadata() + if isinstance(item, NMBaseModel) + else item + for item in value + ] + elif isinstance(value, dict): + result[field_name] = { + k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v + for k, v in value.items() + } + else: + result[field_name] = value + + # Extract unit information from Annotated type + if isinstance(field_info, NMFieldInfo): + # Convert scalar value to dict with metadata + field_dict = { + "value": value, + # __field_type__ will be overwritte if set in custom_metadata + "__field_type__": type(value).__name__, + **{ + f"__{tag}__": value + for tag, value in field_info.custom_metadata.items() + }, + } + # Add possible values for Literal types + if get_origin(field_info.annotation) is Literal: + field_dict["__valid_values__"] = list( + get_args(field_info.annotation) + ) + + result[field_name] = field_dict + return result + + @classmethod + def unvalidated(cls, **data: Any) -> Any: + def process_value(value: Any, field_type: Any) -> Any: + if isinstance(value, dict) and hasattr( + field_type, "__pydantic_core_schema__" + ): + # Recursively handle nested Pydantic models + return field_type.unvalidated(**value) + elif isinstance(value, list): + # Handle lists of Pydantic models + if hasattr(field_type, "__args__") and hasattr( + field_type.__args__[0], "__pydantic_core_schema__" + ): + return [ + field_type.__args__[0].unvalidated(**item) + if isinstance(item, dict) + else item + for item in value + ] + return value + + processed_data = {} + for name, field in cls.model_fields.items(): + try: + value = data[name] + processed_data[name] = process_value(value, field.annotation) + except KeyError: + if not field.is_required(): + processed_data[name] = copy.deepcopy(field.default) + else: + raise TypeError(f"Missing required keyword argument {name!r}") + + self = cls.__new__(cls) + object.__setattr__(self, "__dict__", processed_data) + object.__setattr__(self, "__pydantic_private__", {"extra": None}) + object.__setattr__(self, "__pydantic_fields_set__", set(processed_data.keys())) + return self + + +################################# +#### Generic Pydantic models #### +################################# + + +T = TypeVar("T") +C = TypeVar("C", list, tuple) + + +class NMSequenceModel(NMBaseModel, Generic[C]): + """Base class for sequence models with a single root value""" + + root: C = NMField(default_factory=list) + + # Class variable for aliases - override in subclasses + __aliases__: dict[int, list[str]] = {} + + def __init__(self, *args, **kwargs) -> None: + if len(args) == 1 and isinstance(args[0], (list, tuple)): + kwargs["root"] = args[0] + elif len(args) == 1: + kwargs["root"] = [args[0]] + elif len(args) > 1: # Add this case + kwargs["root"] = tuple(args) + super().__init__(**kwargs) + + def __iter__(self): # type: ignore[reportIncompatibleMethodOverride] + return iter(self.root) + + def __getitem__(self, idx): + return self.root[idx] + + def __len__(self): + return len(self.root) + + def model_dump(self): # type: ignore[reportIncompatibleMethodOverride] + return self.root + + def model_dump_json(self, **kwargs): + import json + + return json.dumps(self.root, **kwargs) + + def serialize_with_metadata(self) -> dict[str, Any]: + result = {"__field_type__": self.__class__.__name__, "value": self.root} + + # Add any field metadata from the root field + field_info = self.model_fields.get("root") + if isinstance(field_info, NMFieldInfo): + for tag, value in field_info.custom_metadata.items(): + result[f"__{tag}__"] = value + + return result + + @model_validator(mode="before") + @classmethod + def validate_input(cls, value: Any) -> dict[str, Any]: + # If it's a dict, just return it + if isinstance(value, dict): + if "root" in value: + return value + + # Check for aliased fields if class has aliases defined + if hasattr(cls, "__aliases__"): + # Collect all possible alias names for each position + alias_values = [] + max_index = max(cls.__aliases__.keys()) if cls.__aliases__ else -1 + + # Try to find a value for each position using its aliases + for i in range(max_index + 1): + aliases = cls.__aliases__.get(i, []) + value_found = None + + # Try each alias for this position + for alias in aliases: + if alias in value: + value_found = value[alias] + break + + if value_found is not None: + alias_values.append(value_found) + else: + # If we're missing any position's value, don't use aliases + break + + # If we found all values through aliases, use them + if len(alias_values) == max_index + 1: + return {"root": alias_values} + + # if it's a sequence, return the value as the root + if isinstance(value, (list, tuple)): + return {"root": value} + + # Else, make it a list + return {"root": [value]} + + @model_serializer + def ser_model(self): + return self.root + + # Custom validator to skip the 'root' field in validation errors + @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride] + def rewrite_error_locations(self, handler): + try: + return handler(self) + except ValidationError as e: + errors = [] + for err in e.errors(): + loc = list(err["loc"]) + # Find and remove 'root' from the location path + if "root" in loc: + root_idx = loc.index("root") + if root_idx < len(loc) - 1: + loc = loc[:root_idx] + loc[root_idx + 1 :] + err["loc"] = tuple(loc) + errors.append(err) + print(errors) + raise ValidationError.from_exception_data( + title="ValidationError", line_errors=errors + ) diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py index 7886685d..13b2923e 100644 --- a/py_neuromodulation/utils/types.py +++ b/py_neuromodulation/utils/types.py @@ -1,12 +1,14 @@ from os import PathLike from math import isnan -from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable -from pydantic import ConfigDict, Field, model_validator, BaseModel -from pydantic_core import ValidationError, InitErrorDetails -from pprint import pformat +from typing import Literal, TYPE_CHECKING, Any, TypeVar +from pydantic import BaseModel, ConfigDict, model_validator +from .pydantic_extensions import NMBaseModel, NMSequenceModel, NMField +from abc import abstractmethod + from collections.abc import Sequence from datetime import datetime + if TYPE_CHECKING: import numpy as np from py_neuromodulation import NMSettings @@ -17,8 +19,7 @@ _PathLike = str | PathLike - -FeatureName = Literal[ +FEATURE_NAME = Literal[ "raw_hjorth", "return_raw", "bandpass_filter", @@ -35,7 +36,7 @@ "bispectrum", ] -PreprocessorName = Literal[ +PREPROCESSOR_NAME = Literal[ "preprocessing_filter", "notch_filter", "raw_resampling", @@ -43,7 +44,7 @@ "raw_normalization", ] -NormMethod = Literal[ +NORM_METHOD = Literal[ "mean", "median", "zscore", @@ -54,13 +55,8 @@ "minmax", ] -################################### -######## PROTOCOL CLASSES ######## -################################### - -@runtime_checkable -class NMFeature(Protocol): +class NMFeature: def __init__( self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float ) -> None: ... @@ -81,131 +77,47 @@ def calc_feature(self, data: "np.ndarray") -> dict: ... -class NMPreprocessor(Protocol): - def __init__(self, sfreq: float, settings: "NMSettings") -> None: ... - +class NMPreprocessor: def process(self, data: "np.ndarray") -> "np.ndarray": ... -################################### -######## PYDANTIC CLASSES ######## -################################### - - -class NMBaseModel(BaseModel): - model_config = ConfigDict(validate_assignment=False, extra="allow") +class PreprocessorList(NMSequenceModel[list[PREPROCESSOR_NAME]]): + model_config = ConfigDict(arbitrary_types_allowed=True) + # Useless contructor to prevent linter from complaining def __init__(self, *args, **kwargs) -> None: - if kwargs: - super().__init__(**kwargs) - else: - field_names = list(self.model_fields.keys()) - kwargs = {} - for i in range(len(args)): - kwargs[field_names[i]] = args[i] - super().__init__(**kwargs) - - def __str__(self): - return pformat(self.model_dump()) - - def __repr__(self): - return pformat(self.model_dump()) - - def validate(self) -> Any: # type: ignore - return self.model_validate(self.model_dump()) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value) -> None: - setattr(self, key, value) - - def process_for_frontend(self) -> dict[str, Any]: - """ - Process the model for frontend use, adding __field_type__ information. - """ - result = {} - for field_name, field_value in self.__dict__.items(): - if isinstance(field_value, NMBaseModel): - processed_value = field_value.process_for_frontend() - processed_value["__field_type__"] = field_value.__class__.__name__ - result[field_name] = processed_value - elif isinstance(field_value, list): - result[field_name] = [ - item.process_for_frontend() - if isinstance(item, NMBaseModel) - else item - for item in field_value - ] - elif isinstance(field_value, dict): - result[field_name] = { - k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v - for k, v in field_value.items() - } - else: - result[field_name] = field_value + super().__init__(*args, **kwargs) - return result +class FrequencyRange(NMSequenceModel[tuple[float, float]]): + """Frequency range as (low, high) tuple""" -class FrequencyRange(NMBaseModel): - frequency_low_hz: float = Field(gt=0) - frequency_high_hz: float = Field(gt=0) + __aliases__ = { + 0: ["frequency_low_hz", "low_frequency_hz"], + 1: ["frequency_high_hz", "high_frequency_hz"], + } + # Useless contructor to prevent linter from complaining def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def __getitem__(self, item: int): - match item: - case 0: - return self.frequency_low_hz - case 1: - return self.frequency_high_hz - case _: - raise IndexError(f"Index {item} out of range") - - def as_tuple(self) -> tuple[float, float]: - return (self.frequency_low_hz, self.frequency_high_hz) - - def __iter__(self): # type: ignore - return iter(self.as_tuple()) - @model_validator(mode="after") def validate_range(self): - if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)): - assert ( - self.frequency_high_hz > self.frequency_low_hz - ), "Frequency high must be greater than frequency low" + low, high = self.root + if not (isnan(low) or isnan(high)): + assert high > low, "High frequency must be greater than low frequency" return self - @classmethod - def create_from(cls, input) -> "FrequencyRange": - match input: - case FrequencyRange(): - return input - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return FrequencyRange( - input["frequency_low_hz"], input["frequency_high_hz"] - ) - case Sequence() if len(input) == 2: - return FrequencyRange(input[0], input[1]) - case _: - raise ValueError("Invalid input for FrequencyRange creation.") - - @model_validator(mode="before") - @classmethod - def check_input(cls, input): - match input: - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return input - case Sequence() if len(input) == 2: - return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]} - case _: - raise ValueError( - "Value for FrequencyRange must be a dictionary, " - "or a sequence of 2 numeric values, " - f"but got {input} instead." - ) + # Alias properties + @property + def frequency_low_hz(self) -> float: + """Lower frequency bound in Hz""" + return self.root[0] + + @property + def frequency_high_hz(self) -> float: + """Upper frequency bound in Hz""" + return self.root[1] class BoolSelector(NMBaseModel): @@ -238,46 +150,135 @@ def print_all(cls): for f in cls.list_all(): print(f) - @classmethod - def get_fields(cls): - return cls.model_fields + +################################################ +### Generic Pydantic models for the frontend ### +################################################ + + +class UniqueStringSequence(NMSequenceModel[list[str]]): + """ + A sequence of strings where: + - Values must come from a predefined set + - Each value can only appear once + - Order is preserved + """ + + @property + @abstractmethod + def valid_values(self) -> list[str]: + """Each subclass must implement this to provide its valid values""" + raise NotImplementedError + + def __init__(self, **data): + valid_values = data.pop("valid_values", []) + super().__init__(**data) + object.__setattr__(self, "valid_values", valid_values) + + @model_validator(mode="after") + def validate_sequence(self): + seen = set() + validated = [] + for item in self.root: + if item not in seen and item in self.valid_values: + seen.add(item) + validated.append(item) + self.root = validated + return self + + def serialize_with_metadata(self) -> dict[str, Any]: + result = super().serialize_with_metadata() + result["__valid_values__"] = self.valid_values + return result -def create_validation_error( - error_message: str, - loc: list[str | int] = None, - title: str = "Validation Error", - input_type: Literal["python", "json"] = "python", - hide_input: bool = False, -) -> ValidationError: +class DependentKeysList(NMSequenceModel[list[str]]): + """ + A list of strings where valid values are keys from another settings field """ - Factory function to create a Pydantic v2 ValidationError instance from a single error message. - Args: - error_message (str): The error message for the ValidationError. - loc (List[str | int], optional): The location of the error. Defaults to None. - title (str, optional): The title of the error. Defaults to "Validation Error". - input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python". - hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False. + root: list[str] = NMField(default_factory=list) + source_dict: dict[str, Any] = NMField(default_factory=dict, exclude=True) + + def __init__(self, **data): + source_dict = data.pop("source_dict", {}) + super().__init__(**data) + object.__setattr__(self, "source_dict", source_dict) + + @model_validator(mode="after") + def validate_keys(self): + valid_keys = set(self.source_dict.keys()) + seen = set() + validated = [] + for item in self.root: + if item not in seen and item in valid_keys: + seen.add(item) + validated.append(item) + self.root = validated + return self + + def serialize_with_metadata(self) -> dict[str, Any]: + result = super().serialize_with_metadata() + result["__valid_values__"] = list(self.source_dict.keys()) + result["__dependent__"] = True # Indicates this needs dynamic updating + return result + - Returns: - ValidationError: A Pydantic ValidationError instance. +class StringPairsList(NMSequenceModel[list[tuple[str, str]]]): + """ + A list of string pairs where values must come from predetermined lists """ - if loc is None: - loc = [] - - line_errors = [ - InitErrorDetails( - type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message} - ) - ] - - return ValidationError.from_exception_data( - title=title, - line_errors=line_errors, - input_type=input_type, - hide_input=hide_input, - ) + + root: list[tuple[str, str]] = NMField(default_factory=list) + valid_first: list[str] = NMField(default_factory=list, exclude=True) + valid_second: list[str] = NMField(default_factory=list, exclude=True) + + def __init__(self, **data): + valid_first = data.pop("valid_first", []) + valid_second = data.pop("valid_second", []) + super().__init__(**data) + object.__setattr__(self, "valid_first", valid_first) + object.__setattr__(self, "valid_second", valid_second) + + @model_validator(mode="after") + def validate_pairs(self): + validated = [ + (first, second) + for first, second in self.root + if first in self.valid_first and second in self.valid_second + ] + self.root = validated + return self + + def serialize_with_metadata(self) -> dict[str, Any]: + result = super().serialize_with_metadata() + result["__valid_first__"] = self.valid_first + result["__valid_second__"] = self.valid_second + return result + + +# class LiteralValue(NMValueModel[str]): +# """ +# A string field that must be one of a predefined set of literals +# """ + +# valid_values: list[str] = NMField(default_factory=list, exclude=True) + +# def __init__(self, **data): +# valid_values = data.pop("valid_values", []) +# super().__init__(**data) +# object.__setattr__(self, "valid_values", valid_values) + +# @model_validator(mode="after") +# def validate_value(self): +# if self.root not in self.valid_values: +# raise ValueError(f"Value must be one of: {self.valid_values}") +# return self + +# def serialize_with_metadata(self) -> dict[str, Any]: +# result = super().serialize_with_metadata() +# result["__valid_values__"] = self.valid_values +# return result ################# diff --git a/pyproject.toml b/pyproject.toml index 5facc8ba..29c85832 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ dependencies = [ "nolds >=0.6.1", "numpy >= 1.21.2", "pandas >= 2.0.0", - "scikit-image", "scikit-learn >= 0.24.2", "scikit-optimize", "scipy >= 1.7.1", @@ -55,13 +54,11 @@ dependencies = [ "statsmodels", "mne-lsl>=1.2.0", "pynput", - #"pyqt5", "pydantic>=2.7.3", "llvmlite>=0.43.0", "pywebview", "fastapi", - "uvicorn>=0.30.6", - "websockets>=13.0", + "uvicorn[standard]>=0.30.6", "seaborn >= 0.11", # exists only because of nolds? "numba>=0.60.0", @@ -79,7 +76,7 @@ dev = [ "pytest-cov", "pytest-sugar", "notebook", - "watchdog", + "uvicorn[standard]", ] docs = [ "py-neuromodulation[dev]", diff --git a/tests/test_osc_features.py b/tests/test_osc_features.py index 8cf84111..fd2428d1 100644 --- a/tests/test_osc_features.py +++ b/tests/test_osc_features.py @@ -3,11 +3,11 @@ import numpy as np from py_neuromodulation import NMSettings, Stream, features -from py_neuromodulation.utils.types import FeatureName +from py_neuromodulation.utils.types import FEATURE_NAME def setup_osc_settings( - osc_feature_name: FeatureName, + osc_feature_name: FEATURE_NAME, osc_feature_setting: str, windowlength_ms: int, log_transform: bool,