diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx index 36aaf1dc..f011e2b9 100644 --- a/gui_dev/src/components/StatusBar/StatusBar.jsx +++ b/gui_dev/src/components/StatusBar/StatusBar.jsx @@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle"; import { SocketStatus } from "./SocketStatus"; import { WebviewStatus } from "./WebviewStatus"; -import { useWebviewStore } from "@/stores"; - +import { useUiStore, useWebviewStore } from "@/stores"; import { Stack } from "@mui/material"; export const StatusBar = () => { - const { isWebView } = useWebviewStore((state) => state.isWebView); + const isWebView = useWebviewStore((state) => state.isWebView); + const createStatusBarContent = useUiStore((state) => state.statusBarContent); + + const StatusBarContent = createStatusBarContent(); return ( { bgcolor="background.level1" borderTop="2px solid" borderColor="background.level3" + height="2rem" > - + {StatusBarContent && } + + {/* */} {/* Current experiment */} {/* Current stream */} {/* Current activity */} diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx index 56f0266b..a86b6009 100644 --- a/gui_dev/src/components/TitledBox.jsx +++ b/gui_dev/src/components/TitledBox.jsx @@ -1,4 +1,4 @@ -import { Container } from "@mui/material"; +import { Box, Container } from "@mui/material"; /** * Component that uses the Box component to render an HTML fieldset element @@ -13,14 +13,14 @@ export const TitledBox = ({ children, ...props }) => ( - {title} {children} - + ); diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx deleted file mode 100644 index afa74b02..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.jsx +++ /dev/null @@ -1,65 +0,0 @@ -import { useRef } from "react"; -import styles from "./DragAndDropList.module.css"; -import { useOptionsStore } from "@/stores"; - -export const DragAndDropList = () => { - const { options, setOptions, addOption, removeOption } = useOptionsStore(); - - const predefinedOptions = [ - { id: 1, name: "raw_resampling" }, - { id: 2, name: "notch_filter" }, - { id: 3, name: "re_referencing" }, - { id: 4, name: "preprocessing_filter" }, - { id: 5, name: "raw_normalization" }, - ]; - - const dragOption = useRef(0); - const draggedOverOption = useRef(0); - - function handleSort() { - const optionsClone = [...options]; - const temp = optionsClone[dragOption.current]; - optionsClone[dragOption.current] = optionsClone[draggedOverOption.current]; - optionsClone[draggedOverOption.current] = temp; - setOptions(optionsClone); - } - - return ( -
-

List

- {options.map((option, index) => ( -
(dragOption.current = index)} - onDragEnter={() => (draggedOverOption.current = index)} - onDragEnd={handleSort} - onDragOver={(e) => e.preventDefault()} - > -

- {[option.id, ". ", option.name.replace("_", " ")]} -

- -
- ))} -
-

Add Elements

- {predefinedOptions.map((option) => ( - - ))} -
-
- ); -}; diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css deleted file mode 100644 index 76bd7861..00000000 --- a/gui_dev/src/pages/Settings/DragAndDropList.module.css +++ /dev/null @@ -1,72 +0,0 @@ -.dragDropList { - max-width: 400px; - margin: 0 auto; - padding: 20px; -} - -.title { - text-align: center; - color: #333; - margin-bottom: 20px; -} - -.item { - background-color: #f0f0f0; - border-radius: 8px; - padding: 15px; - margin-bottom: 10px; - cursor: move; - transition: background-color 0.3s ease; - display: flex; - justify-content: space-between; - align-items: center; -} - -.item:hover { - background-color: #e0e0e0; -} - -.itemText { - margin: 0; - color: #333; - font-size: 16px; -} - -.removeButton { - background-color: #ff4d4d; - border: none; - border-radius: 4px; - color: white; - padding: 5px 10px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.removeButton:hover { - background-color: #ff1a1a; -} - -.addSection { - margin-top: 20px; - text-align: center; -} - -.subtitle { - color: #333; - margin-bottom: 10px; -} - -.addButton { - background-color: #4CAF50; - border: none; - border-radius: 4px; - color: white; - padding: 10px 15px; - margin: 5px; - cursor: pointer; - transition: background-color 0.3s ease; -} - -.addButton:hover { - background-color: #45a049; -} \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx deleted file mode 100644 index ad662eef..00000000 --- a/gui_dev/src/pages/Settings/Dropdown.jsx +++ /dev/null @@ -1,66 +0,0 @@ -import { useState } from "react"; -import "../App.css"; - -var stringJson = - '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}'; -const nm_settings = JSON.parse(stringJson); - -const filterByKeys = (dict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (typeof key === "string") { - // Top-level key - if (dict.hasOwnProperty(key)) { - filteredDict[key] = dict[key]; - } - } else if (typeof key === "object") { - // Nested key - const topLevelKey = Object.keys(key)[0]; - const nestedKeys = key[topLevelKey]; - if ( - dict.hasOwnProperty(topLevelKey) && - typeof dict[topLevelKey] === "object" - ) { - filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys); - } - } - }); - return filteredDict; -}; - -const Dropdown = ({ onChange, keysToInclude }) => { - const filteredSettings = filterByKeys(nm_settings, keysToInclude); - const [selectedOption, setSelectedOption] = useState(""); - - const handleChange = (event) => { - const newValue = event.target.value; - setSelectedOption(newValue); - onChange(keysToInclude, newValue); - }; - - return ( -
- -
- ); -}; - -export default Dropdown; diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx deleted file mode 100644 index 5def783b..00000000 --- a/gui_dev/src/pages/Settings/FrequencyRange.jsx +++ /dev/null @@ -1,88 +0,0 @@ -import { useState } from "react"; - -// const onChange = (key, newValue) => { -// settings.frequencyRanges[key] = newValue; -// }; - -// Object.entries(settings.frequencyRanges).map(([key, value]) => ( -// -// )); - -/** */ - -/** - * - * @param {String} key - * @param {Array} freqRange - * @param {Function} onChange - * @returns - */ -export const FrequencyRange = ({ settings }) => { - const [frequencyRanges, setFrequencyRanges] = useState(settings || {}); - // Handle changes in the text fields - const handleInputChange = (label, key, newValue) => { - setFrequencyRanges((prevState) => ({ - ...prevState, - [label]: { - ...prevState[label], - [key]: newValue, - }, - })); - }; - - // Add a new band - const addBand = () => { - const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`; - setFrequencyRanges((prevState) => ({ - ...prevState, - [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" }, - })); - }; - - // Remove a band - const removeBand = (label) => { - const updatedRanges = { ...frequencyRanges }; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }; - - return ( -
-
Frequency Bands
- {Object.keys(frequencyRanges).map((label) => ( -
- { - const newLabel = e.target.value; - const updatedRanges = { ...frequencyRanges }; - updatedRanges[newLabel] = updatedRanges[label]; - delete updatedRanges[label]; - setFrequencyRanges(updatedRanges); - }} - placeholder="Band Name" - /> - - handleInputChange(label, "frequency_high_hz", e.target.value) - } - placeholder="High Hz" - /> - - handleInputChange(label, "frequency_low_hz", e.target.value) - } - placeholder="Low Hz" - /> - -
- ))} - -
- ); -}; diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css deleted file mode 100644 index a638ad93..00000000 --- a/gui_dev/src/pages/Settings/FrequencySettings.module.css +++ /dev/null @@ -1,67 +0,0 @@ -.container { - background-color: #f9f9f9; /* Light gray background for the container */ - padding: 20px; - border-radius: 10px; /* Rounded corners */ - box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */ - max-width: 600px; - margin: auto; - } - - .header { - font-size: 1.5rem; - color: #333; /* Darker text color */ - margin-bottom: 20px; - text-align: center; - } - - .bandContainer { - display: flex; - align-items: center; - margin-bottom: 15px; - padding: 10px; - border: 1px solid #ddd; /* Light border for each band */ - border-radius: 8px; /* Rounded corners for individual bands */ - background-color: #fff; /* White background for bands */ - box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */ - } - - .bandNameInput, .frequencyInput { - border: 1px solid #ccc; /* Light gray border */ - border-radius: 5px; /* Slightly rounded corners */ - padding: 8px; - margin-right: 10px; - font-size: 0.875rem; - } - - .bandNameInput::placeholder, .frequencyInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - } - - .removeButton, .addButton { - border: none; - border-radius: 5px; /* Rounded corners */ - padding: 8px 12px; - font-size: 0.875rem; - cursor: pointer; - transition: background-color 0.3s, color 0.3s; - } - - .removeButton { - background-color: #e57373; /* Light red color */ - color: white; - } - - .removeButton:hover { - background-color: #d32f2f; /* Darker red on hover */ - } - - .addButton { - background-color: #4caf50; /* Green color */ - color: white; - margin-top: 10px; - } - - .addButton:hover { - background-color: #388e3c; /* Darker green on hover */ - } - \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx index e2cf5f2a..14b57246 100644 --- a/gui_dev/src/pages/Settings/Settings.jsx +++ b/gui_dev/src/pages/Settings/Settings.jsx @@ -1,19 +1,24 @@ +import { useEffect, useState } from "react"; import { + Box, Button, + ButtonGroup, InputAdornment, + Popover, Stack, Switch, TextField, + Tooltip, Typography, } from "@mui/material"; import { Link } from "react-router-dom"; import { CollapsibleBox, TitledBox } from "@/components"; -import { FrequencyRange } from "./FrequencyRange"; -import { useSettingsStore } from "@/stores"; +import { FrequencyRangeList } from "./FrequencyRange"; +import { Dropdown } from "./Dropdown"; +import { useSettingsStore, useStatusBarContent } from "@/stores"; import { filterObjectByKeys } from "@/utils/functions"; const formatKey = (key) => { - // console.log(key); return key .split("_") .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) @@ -21,15 +26,32 @@ const formatKey = (key) => { }; // Wrapper components for each type -const BooleanField = ({ value, onChange }) => ( +const BooleanField = ({ value, onChange, error }) => ( onChange(e.target.checked)} /> ); -const StringField = ({ value, onChange, label }) => ( - +const StringField = ({ value, onChange, label, error }) => ( + onChange(e.target.value)} + label={label} + sx={{ + "& .MuiOutlinedInput-root": { + "& fieldset": { + borderColor: error ? "error.main" : "inherit", + }, + "&:hover fieldset": { + borderColor: error ? "error.main" : "primary.main", + }, + "&.Mui-focused fieldset": { + borderColor: error ? "error.main" : "primary.main", + }, + }, + }} + /> ); -const NumberField = ({ value, onChange, label }) => { +const NumberField = ({ value, onChange, label, error }) => { const handleChange = (event) => { const newValue = event.target.value; // Only allow numbers and decimal point @@ -44,13 +66,26 @@ const NumberField = ({ value, onChange, label }) => { value={value} onChange={handleChange} label={label} - InputProps={{ - endAdornment: ( - - Hz - - ), + sx={{ + "& .MuiOutlinedInput-root": { + "& fieldset": { + borderColor: error ? "error.main" : "inherit", + }, + "&:hover fieldset": { + borderColor: error ? "error.main" : "primary.main", + }, + "&.Mui-focused fieldset": { + borderColor: error ? "error.main" : "primary.main", + }, + }, }} + // InputProps={{ + // endAdornment: ( + // + // Hz + // + // ), + // }} inputProps={{ pattern: "[0-9]*", }} @@ -58,107 +93,231 @@ const NumberField = ({ value, onChange, label }) => { ); }; -const FrequencyRangeField = ({ value, onChange, label }) => ( - -); - // Map component types to their respective wrappers const componentRegistry = { boolean: BooleanField, string: StringField, number: NumberField, - FrequencyRange: FrequencyRangeField, + Array: Dropdown, }; -const SettingsField = ({ path, Component, label, value, onChange, depth }) => { +const SettingsField = ({ path, Component, label, value, onChange, error }) => { return ( - - {label} - onChange(path, newValue)} - label={label} - /> - + + + {label} + onChange(path, newValue)} + label={label} + error={error} + /> + + ); }; +// Function to get the error corresponding to this field or its children +const getFieldError = (fieldPath, errors) => { + if (!errors) return null; + + return errors.find((error) => { + const errorPath = error.loc.join("."); + const currentPath = fieldPath.join("."); + return errorPath === currentPath || errorPath.startsWith(currentPath + "."); + }); +}; + const SettingsSection = ({ settings, title = null, path = [], onChange, - depth = 0, + errors, }) => { - if (Object.keys(settings).length === 0) { - return null; - } const boxTitle = title ? title : formatKey(path[path.length - 1]); + /* + 3 possible cases: + 1. Primitive type || 2. Object with component -> Don't iterate, render directly + 3. Object without component or 4. Array -> Iterate and render recursively + */ - return ( - - {Object.entries(settings).map(([key, value]) => { - if (key === "__field_type__") return null; + const type = typeof settings; + const isObject = type === "object" && !Array.isArray(settings); + const isArray = Array.isArray(settings); - const newPath = [...path, key]; - const label = key; - const isPydanticModel = - typeof value === "object" && "__field_type__" in value; + // __field_type__ should be always present + if (isObject && !settings.__field_type__) { + console.log(settings); + throw new Error("Invalid settings object"); + } + const fieldType = isObject ? settings.__field_type__ : type; + const Component = componentRegistry[fieldType]; - const fieldType = isPydanticModel ? value.__field_type__ : typeof value; + // Case 1: Primitive type -> Don't iterate, render directly + if (!isObject && !isArray) { + if (!Component) { + console.error(`Invalid component type: ${type}`); + return null; + } - const Component = componentRegistry[fieldType]; + const error = getFieldError(path, errors); + + return ( + + ); + } + + // Case 2: Object with component -> Don't iterate, render directly + if (isObject && Component) { + return ( + + ); + } + + // Case 3: Object without component or 4. Array -> Iterate and render recursively + if ((isObject && !Component) || isArray) { + return ( + + {/* Handle recursing through both objects and arrays */} + {(isArray ? settings : Object.entries(settings)).map((item, index) => { + const [key, value] = isArray ? [index.toString(), item] : item; + if (key.startsWith("__")) return null; // Skip metadata fields + + const newPath = [...path, key]; - if (Component) { - return ( - - ); - } else { return ( ); - } - })} - + })} + + ); + } + + // Default case: return null and log an error + console.error(`Invalid settings object, returning null`); + return null; +}; + +const StatusBarSettingsInfo = () => { + const validationErrors = useSettingsStore((state) => state.validationErrors); + const [anchorEl, setAnchorEl] = useState(null); + const open = Boolean(anchorEl); + + const handleOpenErrorsPopover = (event) => { + setAnchorEl(event.currentTarget); + }; + + const handleCloseErrorsPopover = () => { + setAnchorEl(null); + }; + + return ( + <> + {validationErrors?.length > 0 && ( + <> + + {validationErrors?.length} errors found in Settings + + + + {validationErrors.map((error, index) => ( + + {index} - [{error.type}] {error.msg} + + ))} + + + + )} + ); }; -const SettingsContent = () => { +export const Settings = () => { + // Get all necessary state from the settings store const settings = useSettingsStore((state) => state.settings); - const updateSettings = useSettingsStore((state) => state.updateSettings); + const uploadSettings = useSettingsStore((state) => state.uploadSettings); + const resetSettings = useSettingsStore((state) => state.resetSettings); + const validationErrors = useSettingsStore((state) => state.validationErrors); + useStatusBarContent(StatusBarSettingsInfo); + + // This is needed so that the frequency ranges stay in order between updates + const frequencyRangeOrder = useSettingsStore( + (state) => state.frequencyRangeOrder + ); + const updateFrequencyRangeOrder = useSettingsStore( + (state) => state.updateFrequencyRangeOrder + ); + + // Here I handle the selected feature in the feature settings component + const [selectedFeature, setSelectedFeature] = useState(""); + + useEffect(() => { + uploadSettings(null, true); // validateOnly = true + }, [settings]); + // Inject validation error info into status bar + + // This has to be after all the hooks, otherwise React will complain if (!settings) { return
Loading settings...
; } - const handleChange = (path, value) => { - updateSettings((settings) => { + // This are the callbacks for the different buttons + const handleChangeSettings = async (path, value) => { + uploadSettings((settings) => { let current = settings; for (let i = 0; i < path.length - 1; i++) { current = current[path[i]]; } current[path[path.length - 1]] = value; - }); + }, true); // validateOnly = true + }; + + const handleSaveSettings = () => { + uploadSettings(() => settings); + }; + + const handleResetSettings = async () => { + await resetSettings(); }; const featureSettingsKeys = Object.keys(settings.features) @@ -181,89 +340,146 @@ const SettingsContent = () => { "project_subcortex_settings", ]; + const generalSettingsKeys = [ + "sampling_rate_features_hz", + "segment_length_features_ms", + ]; + return ( - - + {/* SETTINGS LAYOUT */} + - - + {/* GENERAL SETTINGS + FREQUENCY RANGES */} + + + {generalSettingsKeys.map((key) => ( + + ))} + + + + + + - - + {/* POSTPROCESSING + PREPROCESSING SETTINGS */} + {preprocessingSettingsKeys.map((key) => ( ))} - + - + {postprocessingSettingsKeys.map((key) => ( ))} - - + - - {Object.entries(enabledFeatures).map(([feature, featureSettings]) => ( - - - - ))} - - - ); -}; + {/* FEATURE SETTINGS */} + + + + + + + {Object.entries(enabledFeatures).map( + ([feature, featureSettings]) => ( + + + + ) + )} + + + + {/* END SETTINGS LAYOUT */} +
-export const Settings = () => { - return ( - - - + + {/* */} + + + ); }; diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx deleted file mode 100644 index 86ab07b8..00000000 --- a/gui_dev/src/pages/Settings/TextField.jsx +++ /dev/null @@ -1,77 +0,0 @@ -import { useState, useEffect } from "react"; -import { - Box, - Grid, - TextField as MUITextField, - Typography, -} from "@mui/material"; -import { useSettingsStore } from "@/stores"; -import styles from "./TextField.module.css"; - -const flattenDictionary = (dict, parentKey = "", result = {}) => { - for (let key in dict) { - const newKey = parentKey ? `${parentKey}.${key}` : key; - if (typeof dict[key] === "object" && dict[key] !== null) { - flattenDictionary(dict[key], newKey, result); - } else { - result[newKey] = dict[key]; - } - } - return result; -}; - -const filterByKeys = (flatDict, keys) => { - const filteredDict = {}; - keys.forEach((key) => { - if (flatDict.hasOwnProperty(key)) { - filteredDict[key] = flatDict[key]; - } - }); - return filteredDict; -}; - -export const TextField = ({ keysToInclude }) => { - const settings = useSettingsStore((state) => state.settings); - const flatSettings = flattenDictionary(settings); - const filteredSettings = filterByKeys(flatSettings, keysToInclude); - const [textLabels, setTextLabels] = useState({}); - - useEffect(() => { - setTextLabels(filteredSettings); - }, [settings]); - - const handleTextFieldChange = (label, value) => { - setTextLabels((prevLabels) => ({ - ...prevLabels, - [label]: value, - })); - }; - - // Function to format the label - const formatLabel = (label) => { - const labelAfterDot = label.split(".").pop(); // Get everything after the last dot - return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces - }; - - return ( -
- {Object.keys(textLabels).map((label, index) => ( -
- - handleTextFieldChange(label, e.target.value)} - className={styles.textFieldInput} - /> -
- ))} -
- ); -}; diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css deleted file mode 100644 index 37791c40..00000000 --- a/gui_dev/src/pages/Settings/TextField.module.css +++ /dev/null @@ -1,67 +0,0 @@ -/* TextField.module.css */ - -/* Container for the text fields */ -.textFieldContainer { - display: flex; - flex-direction: column; - margin: 1.5rem 0; /* Increased margin for better spacing */ - } - - /* Row for each text field */ - .textFieldRow { - display: flex; - flex-direction: column; /* Stack label and input vertically */ - margin-bottom: 1rem; /* Increased margin for better separation */ - } - - /* Label for each text field */ - .textFieldLabel { - margin-bottom: 0.5rem; /* Space between label and input */ - font-weight: 600; /* Increased weight for better visibility */ - color: #333; /* Dark gray for the label */ - font-size: 1.1rem; /* Increased font size for the label */ - transition: all 0.2s ease; /* Smooth transition for label */ - } - - /* Input field styles */ - .textFieldInput { - padding: 12px 14px; /* Padding for a filled look */ - border: 1px solid #ccc; /* Light gray border */ - border-radius: 4px; /* Rounded corners */ - width: 100%; /* Full width */ - font-size: 1rem; /* Font size */ - background-color: #f5f5f5; /* Light background color for filled effect */ - transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */ - box-shadow: none; /* Remove default shadow */ - height: 48px; /* Fixed height for a more square appearance */ - } - - /* Focus styles for the input */ - .textFieldInput:focus { - border-color: #1976d2; /* Blue border color on focus */ - background-color: #fff; /* Change background to white on focus */ - outline: none; /* Remove default outline */ - } - - /* Hover effect for the input */ - .textFieldInput:hover { - border-color: #1976d2; /* Change border color on hover */ - } - - /* Placeholder styles */ - .textFieldInput::placeholder { - color: #aaa; /* Light gray placeholder text */ - opacity: 1; /* Ensure placeholder is fully opaque */ - } - - /* Hide the number input spinners in webkit browsers */ - .textFieldInput::-webkit-inner-spin-button, - .textFieldInput::-webkit-outer-spin-button { - -webkit-appearance: none; /* Remove default styling */ - margin: 0; /* Remove margin */ - } - - /* Hide the number input spinners in Firefox */ - .textFieldInput[type='number'] { - -moz-appearance: textfield; /* Use textfield appearance */ - } \ No newline at end of file diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js deleted file mode 100644 index 16355023..00000000 --- a/gui_dev/src/pages/Settings/index.js +++ /dev/null @@ -1 +0,0 @@ -export { TextField } from './TextField'; diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js index 762d9340..e683a35c 100644 --- a/gui_dev/src/stores/createStore.js +++ b/gui_dev/src/stores/createStore.js @@ -2,16 +2,23 @@ import { create } from "zustand"; import { immer } from "zustand/middleware/immer"; import { devtools, persist as persistMiddleware } from "zustand/middleware"; -export const createStore = (name, initializer, persist = false) => { +export const createStore = ( + name, + initializer, + persist = false, + dev = false +) => { const fn = persist ? persistMiddleware(immer(initializer), name) : immer(initializer); - return create( - devtools(fn, { - name: name, - }) - ); + const dev_fn = dev + ? devtools(fn, { + name: name, + }) + : fn; + + return create(dev_fn); }; export const createPersistStore = (name, initializer) => { diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js index 9b17f42d..e5f085fe 100644 --- a/gui_dev/src/stores/settingsStore.js +++ b/gui_dev/src/stores/settingsStore.js @@ -5,32 +5,18 @@ const INITIAL_DELAY = 3000; // wait for Flask const RETRY_DELAY = 1000; // ms const MAX_RETRIES = 100; -const uploadSettingsToServer = async (settings) => { - try { - const response = await fetch(getBackendURL("/api/settings"), { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(settings), - }); - if (!response.ok) { - throw new Error("Failed to update settings"); - } - return { success: true }; - } catch (error) { - console.error("Error updating settings:", error); - return { success: false, error }; - } -}; - export const useSettingsStore = createStore("settings", (set, get) => ({ settings: null, + lastValidSettings: null, + frequencyRangeOrder: [], isLoading: false, error: null, + validationErrors: null, retryCount: 0, - setSettings: (settings) => set({ settings }), + updateLocalSettings: (updater) => { + set((state) => updater(state.settings)); + }, fetchSettingsWithDelay: () => { set({ isLoading: true, error: null }); @@ -39,15 +25,25 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ }, INITIAL_DELAY); }, - fetchSettings: async () => { + fetchSettings: async (reset = false) => { + console.log("Fetching settings..."); try { - console.log("Fetching settings..."); - const response = await fetch(getBackendURL("/api/settings")); + const response = await fetch( + `/api/settings${reset ? "?reset=true" : ""}` + ); + if (!response.ok) { throw new Error("Failed to fetch settings"); } + const data = await response.json(); - set({ settings: data, retryCount: 0 }); + + set({ + settings: data, + lastValidSettings: data, + frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}), + retryCount: 0, + }); } catch (error) { console.log("Error fetching settings:", error); set((state) => ({ @@ -55,6 +51,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ retryCount: state.retryCount + 1, })); + console.log(get().retryCount); + if (get().retryCount < MAX_RETRIES) { await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY)); return get().fetchSettings(); @@ -66,29 +64,62 @@ export const useSettingsStore = createStore("settings", (set, get) => ({ resetRetryCount: () => set({ retryCount: 0 }), - updateSettings: async (updater) => { - const currentSettings = get().settings; + resetSettings: async () => { + await get().fetchSettings(true); + }, - // Apply the update optimistically - set((state) => { - updater(state.settings); - }); + updateFrequencyRangeOrder: (newOrder) => { + set({ frequencyRangeOrder: newOrder }); + }, - const newSettings = get().settings; + uploadSettings: async (updater, validateOnly = false) => { + if (updater) { + set((state) => { + updater(state.settings); + }); + } + + const currentSettings = get().settings; try { - const result = await uploadSettingsToServer(newSettings); + const response = await fetch( + `/api/settings${validateOnly ? "?validate_only=true" : ""}`, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(currentSettings), + } + ); - if (!result.success) { - // Revert the local state if the server update failed - set({ settings: currentSettings }); + const data = await response.json(); + + if (!response.ok) { + throw new Error("Failed to upload settings to backend"); } - return result; + if (data.valid) { + // Settings are valid + set({ + lastValidSettings: currentSettings, + validationErrors: null, + }); + return true; + } else { + // Settings are invalid + set({ + validationErrors: data.errors, + }); + // Note: We don't revert the settings here, keeping the potentially invalid state + return false; + } } catch (error) { - // Revert the local state if there was an error - set({ settings: currentSettings }); - throw error; + console.error( + `Error ${validateOnly ? "validating" : "updating"} settings:`, + error + ); + return false; } }, })); diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js index a229ec06..25b44382 100644 --- a/gui_dev/src/stores/uiStore.js +++ b/gui_dev/src/stores/uiStore.js @@ -1,4 +1,5 @@ import { createPersistStore } from "./createStore"; +import { useEffect } from "react"; export const useUiStore = createPersistStore("ui", (set, get) => ({ activeDrawer: null, @@ -28,4 +29,24 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({ state.accordionStates[id] = defaultState; } }), + + // Hook to inject UI elements into the status bar + statusBarContent: () => {}, + setStatusBarContent: (content) => set({ statusBarContent: content }), + clearStatusBarContent: () => set({ statusBarContent: null }), })); + +// Use this hook from Page components to inject page-specific UI elements into the status bar +export const useStatusBarContent = (content) => { + const createStatusBarContent = () => content; + + const setStatusBarContent = useUiStore((state) => state.setStatusBarContent); + const clearStatusBarContent = useUiStore( + (state) => state.clearStatusBarContent + ); + + useEffect(() => { + setStatusBarContent(createStatusBarContent); + return () => clearStatusBarContent(); + }, [content, setStatusBarContent, clearStatusBarContent]); +}; diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js index 5e2b5f11..cc99909c 100644 --- a/gui_dev/src/theme.js +++ b/gui_dev/src/theme.js @@ -22,6 +22,11 @@ export const theme = createTheme({ disableRipple: true, }, }, + MuiTextField: { + defaultProps: { + autoComplete: "off", + }, + }, MuiStack: { defaultProps: { alignItems: "center", diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/functions.js index 04289d5c..1f8ae90a 100644 --- a/gui_dev/src/utils/functions.js +++ b/gui_dev/src/utils/functions.js @@ -17,6 +17,7 @@ export function debounce(func, wait) { timeout = setTimeout(later, wait); }; } + export const flattenDictionary = (dict, parentKey = "", result = {}) => { for (let key in dict) { const newKey = parentKey ? `${parentKey}.${key}` : key; diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py index 688393fd..864df345 100644 --- a/py_neuromodulation/__init__.py +++ b/py_neuromodulation/__init__.py @@ -4,11 +4,12 @@ from importlib.metadata import version from py_neuromodulation.utils.logging import NMLogger + ##################################### # Globals and environment variables # ##################################### -__version__ = version("py_neuromodulation") # get version from pyproject.toml +__version__ = version("py_neuromodulation") # Check if the module is running headless (no display) for tests and doc builds PYNM_HEADLESS: bool = not os.environ.get("DISPLAY") @@ -18,6 +19,7 @@ os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file + # Set environment variable MNE_LSL_LIB (required to import Stream below) LSL_DICT = { "windows_32bit": "windows/x86/liblsl.1.16.2.dll", @@ -36,6 +38,7 @@ PLATFORM = platform.system().lower().strip() ARCH = platform.architecture()[0] + match PLATFORM: case "windows": KEY = PLATFORM + "_" + ARCH diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml index 49988840..e7858599 100644 --- a/py_neuromodulation/default_settings.yaml +++ b/py_neuromodulation/default_settings.yaml @@ -1,4 +1,15 @@ ---- +# We +# should +# have +# a +# brief +# explanation +# of +# the +# settings +# format +# here + ######################## ### General settings ### ######################## @@ -51,12 +62,8 @@ preprocessing_filter: lowpass_filter: true highpass_filter: true bandpass_filter: true - bandstop_filter_settings: - frequency_low_hz: 100 - frequency_high_hz: 160 - bandpass_filter_settings: - frequency_low_hz: 3 - frequency_high_hz: 200 + bandstop_filter_settings: [100, 160] # [low_hz, high_hz] + bandpass_filter_settings: [3, 200] # [hz, _hz] lowpass_filter_cutoff_hz: 200 highpass_filter_cutoff_hz: 3 @@ -162,10 +169,8 @@ sharpwave_analysis_settings: decay_steepness: false slope_ratio: false filter_ranges_hz: - - frequency_low_hz: 5 - frequency_high_hz: 80 - - frequency_low_hz: 5 - frequency_high_hz: 30 + - [5, 80] + - [5, 30] detect_troughs: estimate: true distance_troughs_ms: 10 @@ -174,6 +179,7 @@ sharpwave_analysis_settings: estimate: true distance_troughs_ms: 5 distance_peaks_ms: 10 + # TONI: Reverse this setting? e.g. interval: [mean, var] estimator: mean: [interval] median: [] diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py index 83d63bf8..2ff8ceac 100644 --- a/py_neuromodulation/features/bursts.py +++ b/py_neuromodulation/features/bursts.py @@ -11,7 +11,7 @@ from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature from typing import TYPE_CHECKING, Callable -from py_neuromodulation.utils.types import create_validation_error +from py_neuromodulation.utils.pydantic_extensions import create_validation_error if TYPE_CHECKING: from py_neuromodulation import NMSettings diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py index e6c74985..ecfa0a97 100644 --- a/py_neuromodulation/gui/backend/app_backend.py +++ b/py_neuromodulation/gui/backend/app_backend.py @@ -13,6 +13,7 @@ ) from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware +from pydantic import ValidationError from . import app_pynm from .app_socket import WebSocketManager @@ -74,21 +75,62 @@ async def healthcheck(): #################### ##### SETTINGS ##### #################### - @self.get("/api/settings") - async def get_settings(): - return self.pynm_state.settings.process_for_frontend() + async def get_settings( + reset: bool = Query(False, description="Reset settings to default"), + ): + if reset: + settings = NMSettings.get_default() + else: + settings = self.pynm_state.settings + + return settings.serialize_with_metadata() @self.post("/api/settings") - async def update_settings(data: dict): + async def update_settings(data: dict, validate_only: bool = Query(False)): try: - self.pynm_state.settings = NMSettings.model_validate(data) - self.logger.info(self.pynm_state.settings.features) - return self.pynm_state.settings.model_dump() - except ValueError as e: + # First, validate with Pydantic + try: + validated_settings = NMSettings.model_validate(data) + except ValidationError as e: + if not validate_only: + # If validation failed but we wanted to upload, return error + self.logger.error(f"Error validating settings: {e}") + raise HTTPException( + status_code=422, + detail={ + "error": "Error validating settings", + "details": str(e), + }, + ) + # Else return list of errors + return { + "valid": False, + "errors": [err for err in e.errors()], + "details": str(e), + } + + # If validation succesful, return or update settings + if validate_only: + return { + "valid": True, + "settings": validated_settings.serialize_with_metadata(), + } + + self.pynm_state.settings = validated_settings + self.logger.info("Settings successfully updated") + + return { + "valid": True, + "settings": self.pynm_state.settings.serialize_with_metadata(), + } + + # If something else than validation went wrong, return error + except Exception as e: + self.logger.error(f"Error validating/updating settings: {e}") raise HTTPException( status_code=422, - detail={"error": "Validation failed", "details": str(e)}, + detail={"error": "Error uploading settings", "details": str(e)}, ) ######################## diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py index 3345b122..66419cec 100644 --- a/py_neuromodulation/processing/normalization.py +++ b/py_neuromodulation/processing/normalization.py @@ -2,10 +2,10 @@ import numpy as np from typing import TYPE_CHECKING, Callable, Literal, get_args +from pydantic import Field from py_neuromodulation.utils.types import ( NMBaseModel, - Field, NormMethod, NMPreprocessor, ) diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py index 08a4e115..6fa0885d 100644 --- a/py_neuromodulation/processing/resample.py +++ b/py_neuromodulation/processing/resample.py @@ -1,7 +1,8 @@ """Module for resampling.""" import numpy as np -from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor +from pydantic import Field +from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor class ResamplerSettings(NMBaseModel): diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py index 7e25ec24..39a9b6fd 100644 --- a/py_neuromodulation/stream/settings.py +++ b/py_neuromodulation/stream/settings.py @@ -4,15 +4,15 @@ from typing import ClassVar from pydantic import Field, model_validator -from py_neuromodulation import PYNM_DIR, logger, user_features +from py_neuromodulation import logger, user_features from py_neuromodulation.utils.types import ( BoolSelector, FrequencyRange, - PreprocessorName, _PathLike, NMBaseModel, NormMethod, + PreprocessorList, ) from py_neuromodulation.processing.filter_preprocessing import FilterSettings @@ -72,11 +72,14 @@ class NMSettings(NMBaseModel): } # Preproceessing settings - preprocessing: list[PreprocessorName] = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + preprocessing: PreprocessorList = PreprocessorList( + [ + "raw_resampling", + "notch_filter", + "re_referencing", + ] + ) + raw_resampling_settings: ResamplerSettings = ResamplerSettings() preprocessing_filter: FilterSettings = FilterSettings() raw_normalization_settings: NormalizationSettings = NormalizationSettings() @@ -144,26 +147,29 @@ def validate_settings(self): if self.bandpass_filter_settings.kalman_filter: self.kalman_filter_settings.validate_fbands(self) - for k, v in self.frequency_ranges_hz.items(): - if not isinstance(v, FrequencyRange): - self.frequency_ranges_hz[k] = FrequencyRange.create_from(v) + # TONI: not needed after NMSequenceModel, remove in the future + # for k, v in self.frequency_ranges_hz.items(): + # if not isinstance(v, FrequencyRange): + # self.frequency_ranges_hz[k] = FrequencyRange.create_from(v) return self def reset(self) -> "NMSettings": self.features.disable_all() - self.preprocessing = [] + self.preprocessing = PreprocessorList() self.postprocessing.disable_all() return self def set_fast_compute(self) -> "NMSettings": self.reset() self.features.fft = True - self.preprocessing = [ - "raw_resampling", - "notch_filter", - "re_referencing", - ] + self.preprocessing = PreprocessorList( + [ + "raw_resampling", + "notch_filter", + "re_referencing", + ] + ) self.postprocessing.feature_normalization = True self.postprocessing.project_cortex = False self.postprocessing.project_subcortex = False @@ -250,7 +256,7 @@ def from_file(PATH: _PathLike) -> "NMSettings": @staticmethod def get_default() -> "NMSettings": - return NMSettings.from_file(PYNM_DIR / "default_settings.yaml") + return NMSettings() @staticmethod def list_normalization_methods() -> list[NormMethod]: diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py new file mode 100644 index 00000000..be677243 --- /dev/null +++ b/py_neuromodulation/utils/pydantic_extensions.py @@ -0,0 +1,347 @@ +from typing import Any, get_type_hints, TypeVar, Generic, Literal, overload +from typing_extensions import Unpack, TypedDict +from pydantic import BaseModel, ConfigDict, model_validator +from pydantic_core import PydanticUndefined, ValidationError, InitErrorDetails +from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs +from pprint import pformat + + +def create_validation_error( + error_message: str, + loc: list[str | int] | None = None, + title: str = "Validation Error", + input_type: Literal["python", "json"] = "python", + hide_input: bool = False, +) -> ValidationError: + """ + Factory function to create a Pydantic v2 ValidationError instance from a single error message. + + Args: + error_message (str): The error message for the ValidationError. + loc (List[str | int], optional): The location of the error. Defaults to None. + title (str, optional): The title of the error. Defaults to "Validation Error". + input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python". + hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False. + + Returns: + ValidationError: A Pydantic ValidationError instance. + """ + if loc is None: + loc = [] + + line_errors = [ + InitErrorDetails( + type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message} + ) + ] + + return ValidationError.from_exception_data( + title=title, + line_errors=line_errors, + input_type=input_type, + hide_input=hide_input, + ) + + +class _NMExtraFieldInputs(TypedDict, total=False): + """Additional fields to add on top of the pydantic FieldInfo""" + + custom_metadata: dict[str, Any] + + +class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo inputs with PyNM additional inputs""" + + pass + + +class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False): + """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs""" + + pass + + +class NMFieldInfo(FieldInfo): + # Add default values for any other custom fields here + _default_values = {} + + def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None: + self.custom_metadata = kwargs.pop("custom_metadata", {}) + super().__init__(**kwargs) + + @staticmethod + def from_field( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], + ) -> "NMFieldInfo": + if "annotation" in kwargs: + raise TypeError('"annotation" is not permitted as a Field keyword argument') + return NMFieldInfo(default=default, **kwargs) + + def __repr_args__(self): + yield from super().__repr_args__() + extra_fields = get_type_hints(_NMExtraFieldInputs) + for field in extra_fields: + value = getattr(self, field) + yield field, value + + +def NMField( + default: Any = PydanticUndefined, + **kwargs: Unpack[_NMFromFieldInfoInputs], +) -> Any: + return NMFieldInfo.from_field(default=default, **kwargs) + + +class NMBaseModel(BaseModel): + model_config = ConfigDict(validate_assignment=False, extra="allow") + + def __init__(self, *args, **kwargs) -> None: + """Pydantic does not support positional arguments by default. + This is a workaround to support positional arguments for models like FrequencyRange. + It converts positional arguments to kwargs and then calls the base class __init__. + """ + if not args: + # Simple case - just use kwargs + super().__init__(**kwargs) + return + + field_names = list(self.model_fields.keys()) + # If we have more positional args than fields, that's an error + if len(args) > len(field_names): + raise ValueError( + f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}" + ) + + # Convert positional args to kwargs, checking for duplicates if args: + complete_kwargs = {} + for i, arg in enumerate(args): + field_name = field_names[i] + if field_name in kwargs: + raise ValueError( + f"Got multiple values for field '{field_name}': " + f"positional argument and keyword argument" + ) + complete_kwargs[field_name] = arg + + # Add remaining kwargs + complete_kwargs.update(kwargs) + super().__init__(**complete_kwargs) + + def __str__(self): + return pformat(self.model_dump()) + + def __repr__(self): + return pformat(self.model_dump()) + + def validate(self) -> Any: # type: ignore + return self.model_validate(self.model_dump()) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value) -> None: + setattr(self, key, value) + + @property + def fields(self) -> dict[str, FieldInfo | NMFieldInfo]: + return self.model_fields # type: ignore + + def serialize_with_metadata(self): + result: dict[str, Any] = {"__field_type__": self.__class__.__name__} + + for field_name, field_info in self.fields.items(): + value = getattr(self, field_name) + if isinstance(value, NMBaseModel): + result[field_name] = value.serialize_with_metadata() + elif isinstance(value, list): + result[field_name] = [ + item.serialize_with_metadata() + if isinstance(item, NMBaseModel) + else item + for item in value + ] + elif isinstance(value, dict): + result[field_name] = { + k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v + for k, v in value.items() + } + else: + result[field_name] = value + + # Extract unit information from Annotated type + if isinstance(field_info, NMFieldInfo): + for tag, value in field_info.custom_metadata.items(): + result[f"__{tag}__"] = value + return result + + +################################# +#### Generic Pydantic models #### +################################# + + +def create_alias_property(index: int, alias: str, classname: str): + """Creates a property that accesses the root sequence at the given index""" + + def getter(self): + return self.root[index] + + def setter(self, value): + if isinstance(self.root, tuple): + new_values = list(self.root) + new_values[index] = value + self.root = tuple(new_values) + else: + self.root[index] = value + + return property( + fget=getter, + fset=setter, + doc=f"Alias '{alias}' for position [{index}] of class '{classname}'.", + ) + + +T = TypeVar("T") +C = TypeVar("C", list, tuple) + + +class NMSequenceModel(NMBaseModel, Generic[C]): + """Base class for sequence models with a single root value""" + + root: C = NMField(default_factory=list) + + # Class variable for aliases - override in subclasses + __aliases__: dict[int, list[str]] = {} + + def __init__(self, *args, **kwargs) -> None: + # Generate properties programatically (not used currently) + # for index, aliases in self.__aliases__.items(): + # for alias in aliases: + # if not hasattr(self.__class__, alias): + # setattr( + # self.__class__, + # alias, + # create_alias_property(index, alias, self.__class__.__name__), + # ) + + if len(args) == 1 and isinstance(args[0], (list, tuple)): + kwargs["root"] = args[0] + elif len(args) == 1: + kwargs["root"] = [args[0]] + elif len(args) > 1: # Add this case + kwargs["root"] = tuple(args) + super().__init__(**kwargs) + + def __iter__(self): # type: ignore[reportIncompatibleMethodOverride] + return iter(self.root) + + def __getitem__(self, idx): + return self.root[idx] + + def __len__(self): + return len(self.root) + + def model_dump(self): # type: ignore[reportIncompatibleMethodOverride] + return self.root + + def model_dump_json(self, **kwargs): + import json + + return json.dumps(self.root, **kwargs) + + def serialize_with_metadata(self) -> dict[str, Any]: + result = {"__field_type__": self.__class__.__name__, "value": self.root} + + # Add any field metadata from the root field + field_info = self.model_fields.get("root") + if isinstance(field_info, NMFieldInfo): + for tag, value in field_info.custom_metadata.items(): + result[f"__{tag}__"] = value + + return result + + @model_validator(mode="before") + @classmethod + def validate_input(cls, value: Any) -> dict[str, Any]: + # If it's a dict, just return it + if isinstance(value, dict): + if "root" in value: + return value + + # Check for aliased fields if class has aliases defined + if hasattr(cls, "__aliases__"): + # Collect all possible alias names for each position + alias_values = [] + max_index = max(cls.__aliases__.keys()) if cls.__aliases__ else -1 + + # Try to find a value for each position using its aliases + for i in range(max_index + 1): + aliases = cls.__aliases__.get(i, []) + value_found = None + + # Try each alias for this position + for alias in aliases: + if alias in value: + value_found = value[alias] + break + + if value_found is not None: + alias_values.append(value_found) + else: + # If we're missing any position's value, don't use aliases + break + + # If we found all values through aliases, use them + if len(alias_values) == max_index + 1: + return {"root": alias_values} + + # if it's a sequence, return the value as the root + if isinstance(value, (list, tuple)): + return {"root": value} + + # Else, make it a list + return {"root": [value]} + + +class NMValueModel(NMBaseModel, Generic[T]): + """Base class for single-value models that behave like their contained type""" + + root: T + + @model_validator(mode="before") + @classmethod + def validate_input(cls, value: Any) -> dict[str, Any]: + if isinstance(value, dict): + if "root" in value: + return value + # If it's a dict without root, assume the first value is our value + if len(value) > 0: + return {"root": next(iter(value.values()))} + return {"root": None} + return {"root": value} + + def __str__(self) -> str: + return str(self.root) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({repr(self.root)})" + + def model_dump(self): # type: ignore[reportIncompatibleMethodOverride] + return self.root + + def model_dump_json(self, **kwargs): + import json + + return json.dumps(self.root, **kwargs) + + def serialize_with_metadata(self) -> dict[str, Any]: + result = {"__field_type__": self.__class__.__name__, "value": self.root} + + # Add any field metadata from the root field + field_info = self.model_fields.get("root") + if isinstance(field_info, NMFieldInfo): + for tag, value in field_info.custom_metadata.items(): + result[f"__{tag}__"] = value + + return result diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py index 7886685d..a2c22db9 100644 --- a/py_neuromodulation/utils/types.py +++ b/py_neuromodulation/utils/types.py @@ -1,12 +1,13 @@ from os import PathLike from math import isnan -from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable -from pydantic import ConfigDict, Field, model_validator, BaseModel -from pydantic_core import ValidationError, InitErrorDetails -from pprint import pformat +from typing import Literal, TYPE_CHECKING +from pydantic import BaseModel, ConfigDict, model_validator +from .pydantic_extensions import NMBaseModel, NMSequenceModel + from collections.abc import Sequence from datetime import datetime + if TYPE_CHECKING: import numpy as np from py_neuromodulation import NMSettings @@ -17,7 +18,6 @@ _PathLike = str | PathLike - FeatureName = Literal[ "raw_hjorth", "return_raw", @@ -54,13 +54,13 @@ "minmax", ] + ################################### ######## PROTOCOL CLASSES ######## ################################### -@runtime_checkable -class NMFeature(Protocol): +class NMFeature: def __init__( self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float ) -> None: ... @@ -81,131 +81,47 @@ def calc_feature(self, data: "np.ndarray") -> dict: ... -class NMPreprocessor(Protocol): - def __init__(self, sfreq: float, settings: "NMSettings") -> None: ... - +class NMPreprocessor: def process(self, data: "np.ndarray") -> "np.ndarray": ... -################################### -######## PYDANTIC CLASSES ######## -################################### - - -class NMBaseModel(BaseModel): - model_config = ConfigDict(validate_assignment=False, extra="allow") +class PreprocessorList(NMSequenceModel[list[PreprocessorName]]): + model_config = ConfigDict(arbitrary_types_allowed=True) + # Useless contructor to prevent linter from complaining def __init__(self, *args, **kwargs) -> None: - if kwargs: - super().__init__(**kwargs) - else: - field_names = list(self.model_fields.keys()) - kwargs = {} - for i in range(len(args)): - kwargs[field_names[i]] = args[i] - super().__init__(**kwargs) - - def __str__(self): - return pformat(self.model_dump()) - - def __repr__(self): - return pformat(self.model_dump()) - - def validate(self) -> Any: # type: ignore - return self.model_validate(self.model_dump()) + super().__init__(*args, **kwargs) - def __getitem__(self, key): - return getattr(self, key) - def __setitem__(self, key, value) -> None: - setattr(self, key, value) +class FrequencyRange(NMSequenceModel[tuple[float, float]]): + """Frequency range as (low, high) tuple""" - def process_for_frontend(self) -> dict[str, Any]: - """ - Process the model for frontend use, adding __field_type__ information. - """ - result = {} - for field_name, field_value in self.__dict__.items(): - if isinstance(field_value, NMBaseModel): - processed_value = field_value.process_for_frontend() - processed_value["__field_type__"] = field_value.__class__.__name__ - result[field_name] = processed_value - elif isinstance(field_value, list): - result[field_name] = [ - item.process_for_frontend() - if isinstance(item, NMBaseModel) - else item - for item in field_value - ] - elif isinstance(field_value, dict): - result[field_name] = { - k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v - for k, v in field_value.items() - } - else: - result[field_name] = field_value - - return result - - -class FrequencyRange(NMBaseModel): - frequency_low_hz: float = Field(gt=0) - frequency_high_hz: float = Field(gt=0) + __aliases__ = { + 0: ["frequency_low_hz", "low_frequency_hz"], + 1: ["frequency_high_hz", "high_frequency_hz"], + } + # Useless contructor to prevent linter from complaining def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - def __getitem__(self, item: int): - match item: - case 0: - return self.frequency_low_hz - case 1: - return self.frequency_high_hz - case _: - raise IndexError(f"Index {item} out of range") - - def as_tuple(self) -> tuple[float, float]: - return (self.frequency_low_hz, self.frequency_high_hz) - - def __iter__(self): # type: ignore - return iter(self.as_tuple()) - @model_validator(mode="after") def validate_range(self): - if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)): - assert ( - self.frequency_high_hz > self.frequency_low_hz - ), "Frequency high must be greater than frequency low" + low, high = self.root + if not (isnan(low) or isnan(high)): + assert high > low, "High frequency must be greater than low frequency" return self - @classmethod - def create_from(cls, input) -> "FrequencyRange": - match input: - case FrequencyRange(): - return input - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return FrequencyRange( - input["frequency_low_hz"], input["frequency_high_hz"] - ) - case Sequence() if len(input) == 2: - return FrequencyRange(input[0], input[1]) - case _: - raise ValueError("Invalid input for FrequencyRange creation.") - - @model_validator(mode="before") - @classmethod - def check_input(cls, input): - match input: - case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input: - return input - case Sequence() if len(input) == 2: - return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]} - case _: - raise ValueError( - "Value for FrequencyRange must be a dictionary, " - "or a sequence of 2 numeric values, " - f"but got {input} instead." - ) + # Alias properties + @property + def frequency_low_hz(self) -> float: + """Lower frequency bound in Hz""" + return self.root[0] + + @property + def frequency_high_hz(self) -> float: + """Upper frequency bound in Hz""" + return self.root[1] class BoolSelector(NMBaseModel): @@ -238,47 +154,6 @@ def print_all(cls): for f in cls.list_all(): print(f) - @classmethod - def get_fields(cls): - return cls.model_fields - - -def create_validation_error( - error_message: str, - loc: list[str | int] = None, - title: str = "Validation Error", - input_type: Literal["python", "json"] = "python", - hide_input: bool = False, -) -> ValidationError: - """ - Factory function to create a Pydantic v2 ValidationError instance from a single error message. - - Args: - error_message (str): The error message for the ValidationError. - loc (List[str | int], optional): The location of the error. Defaults to None. - title (str, optional): The title of the error. Defaults to "Validation Error". - input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python". - hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False. - - Returns: - ValidationError: A Pydantic ValidationError instance. - """ - if loc is None: - loc = [] - - line_errors = [ - InitErrorDetails( - type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message} - ) - ] - - return ValidationError.from_exception_data( - title=title, - line_errors=line_errors, - input_type=input_type, - hide_input=hide_input, - ) - ################# ### GUI TYPES ### diff --git a/pyproject.toml b/pyproject.toml index 32cea5a5..193177c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,6 @@ dependencies = [ "nolds >=0.6.1", "numpy >= 1.21.2", "pandas >= 2.0.0", - "scikit-image", "scikit-learn >= 0.24.2", "scikit-optimize", "scipy >= 1.7.1", @@ -55,15 +54,12 @@ dependencies = [ "statsmodels", "mne-lsl>=1.2.0", "pynput", - #"pyqt5", "pydantic>=2.7.3", "llvmlite>=0.43.0", "pywebview", "fastapi", - "uvicorn>=0.30.6", - "websockets>=13.0", + "uvicorn[standard]>=0.30.6", "seaborn >= 0.11", - # exists only because of nolds? "numba>=0.60.0", "cbor2>=5.6.4", "msgpack>=1.1.0", @@ -78,7 +74,7 @@ dev = [ "pytest-cov", "pytest-sugar", "notebook", - "watchdog", + "uvicorn[standard]", ] docs = [ "py-neuromodulation[dev]",