diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx
index 36aaf1dc..f011e2b9 100644
--- a/gui_dev/src/components/StatusBar/StatusBar.jsx
+++ b/gui_dev/src/components/StatusBar/StatusBar.jsx
@@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle";
import { SocketStatus } from "./SocketStatus";
import { WebviewStatus } from "./WebviewStatus";
-import { useWebviewStore } from "@/stores";
-
+import { useUiStore, useWebviewStore } from "@/stores";
import { Stack } from "@mui/material";
export const StatusBar = () => {
- const { isWebView } = useWebviewStore((state) => state.isWebView);
+ const isWebView = useWebviewStore((state) => state.isWebView);
+ const createStatusBarContent = useUiStore((state) => state.statusBarContent);
+
+ const StatusBarContent = createStatusBarContent();
return (
{
bgcolor="background.level1"
borderTop="2px solid"
borderColor="background.level3"
+ height="2rem"
>
-
+ {StatusBarContent && }
+
+ {/* */}
{/* Current experiment */}
{/* Current stream */}
{/* Current activity */}
diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx
index 56f0266b..a86b6009 100644
--- a/gui_dev/src/components/TitledBox.jsx
+++ b/gui_dev/src/components/TitledBox.jsx
@@ -1,4 +1,4 @@
-import { Container } from "@mui/material";
+import { Box, Container } from "@mui/material";
/**
* Component that uses the Box component to render an HTML fieldset element
@@ -13,14 +13,14 @@ export const TitledBox = ({
children,
...props
}) => (
-
{children}
-
+
);
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx
deleted file mode 100644
index afa74b02..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.jsx
+++ /dev/null
@@ -1,65 +0,0 @@
-import { useRef } from "react";
-import styles from "./DragAndDropList.module.css";
-import { useOptionsStore } from "@/stores";
-
-export const DragAndDropList = () => {
- const { options, setOptions, addOption, removeOption } = useOptionsStore();
-
- const predefinedOptions = [
- { id: 1, name: "raw_resampling" },
- { id: 2, name: "notch_filter" },
- { id: 3, name: "re_referencing" },
- { id: 4, name: "preprocessing_filter" },
- { id: 5, name: "raw_normalization" },
- ];
-
- const dragOption = useRef(0);
- const draggedOverOption = useRef(0);
-
- function handleSort() {
- const optionsClone = [...options];
- const temp = optionsClone[dragOption.current];
- optionsClone[dragOption.current] = optionsClone[draggedOverOption.current];
- optionsClone[draggedOverOption.current] = temp;
- setOptions(optionsClone);
- }
-
- return (
-
- List
- {options.map((option, index) => (
- (dragOption.current = index)}
- onDragEnter={() => (draggedOverOption.current = index)}
- onDragEnd={handleSort}
- onDragOver={(e) => e.preventDefault()}
- >
-
- {[option.id, ". ", option.name.replace("_", " ")]}
-
-
-
- ))}
-
-
Add Elements
- {predefinedOptions.map((option) => (
-
- ))}
-
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css
deleted file mode 100644
index 76bd7861..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.module.css
+++ /dev/null
@@ -1,72 +0,0 @@
-.dragDropList {
- max-width: 400px;
- margin: 0 auto;
- padding: 20px;
-}
-
-.title {
- text-align: center;
- color: #333;
- margin-bottom: 20px;
-}
-
-.item {
- background-color: #f0f0f0;
- border-radius: 8px;
- padding: 15px;
- margin-bottom: 10px;
- cursor: move;
- transition: background-color 0.3s ease;
- display: flex;
- justify-content: space-between;
- align-items: center;
-}
-
-.item:hover {
- background-color: #e0e0e0;
-}
-
-.itemText {
- margin: 0;
- color: #333;
- font-size: 16px;
-}
-
-.removeButton {
- background-color: #ff4d4d;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 5px 10px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.removeButton:hover {
- background-color: #ff1a1a;
-}
-
-.addSection {
- margin-top: 20px;
- text-align: center;
-}
-
-.subtitle {
- color: #333;
- margin-bottom: 10px;
-}
-
-.addButton {
- background-color: #4CAF50;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 10px 15px;
- margin: 5px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.addButton:hover {
- background-color: #45a049;
-}
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx
deleted file mode 100644
index ad662eef..00000000
--- a/gui_dev/src/pages/Settings/Dropdown.jsx
+++ /dev/null
@@ -1,66 +0,0 @@
-import { useState } from "react";
-import "../App.css";
-
-var stringJson =
- '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}';
-const nm_settings = JSON.parse(stringJson);
-
-const filterByKeys = (dict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (typeof key === "string") {
- // Top-level key
- if (dict.hasOwnProperty(key)) {
- filteredDict[key] = dict[key];
- }
- } else if (typeof key === "object") {
- // Nested key
- const topLevelKey = Object.keys(key)[0];
- const nestedKeys = key[topLevelKey];
- if (
- dict.hasOwnProperty(topLevelKey) &&
- typeof dict[topLevelKey] === "object"
- ) {
- filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys);
- }
- }
- });
- return filteredDict;
-};
-
-const Dropdown = ({ onChange, keysToInclude }) => {
- const filteredSettings = filterByKeys(nm_settings, keysToInclude);
- const [selectedOption, setSelectedOption] = useState("");
-
- const handleChange = (event) => {
- const newValue = event.target.value;
- setSelectedOption(newValue);
- onChange(keysToInclude, newValue);
- };
-
- return (
-
-
-
- );
-};
-
-export default Dropdown;
diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx
deleted file mode 100644
index 5def783b..00000000
--- a/gui_dev/src/pages/Settings/FrequencyRange.jsx
+++ /dev/null
@@ -1,88 +0,0 @@
-import { useState } from "react";
-
-// const onChange = (key, newValue) => {
-// settings.frequencyRanges[key] = newValue;
-// };
-
-// Object.entries(settings.frequencyRanges).map(([key, value]) => (
-//
-// ));
-
-/** */
-
-/**
- *
- * @param {String} key
- * @param {Array} freqRange
- * @param {Function} onChange
- * @returns
- */
-export const FrequencyRange = ({ settings }) => {
- const [frequencyRanges, setFrequencyRanges] = useState(settings || {});
- // Handle changes in the text fields
- const handleInputChange = (label, key, newValue) => {
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [label]: {
- ...prevState[label],
- [key]: newValue,
- },
- }));
- };
-
- // Add a new band
- const addBand = () => {
- const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`;
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" },
- }));
- };
-
- // Remove a band
- const removeBand = (label) => {
- const updatedRanges = { ...frequencyRanges };
- delete updatedRanges[label];
- setFrequencyRanges(updatedRanges);
- };
-
- return (
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css
deleted file mode 100644
index a638ad93..00000000
--- a/gui_dev/src/pages/Settings/FrequencySettings.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-.container {
- background-color: #f9f9f9; /* Light gray background for the container */
- padding: 20px;
- border-radius: 10px; /* Rounded corners */
- box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */
- max-width: 600px;
- margin: auto;
- }
-
- .header {
- font-size: 1.5rem;
- color: #333; /* Darker text color */
- margin-bottom: 20px;
- text-align: center;
- }
-
- .bandContainer {
- display: flex;
- align-items: center;
- margin-bottom: 15px;
- padding: 10px;
- border: 1px solid #ddd; /* Light border for each band */
- border-radius: 8px; /* Rounded corners for individual bands */
- background-color: #fff; /* White background for bands */
- box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */
- }
-
- .bandNameInput, .frequencyInput {
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 5px; /* Slightly rounded corners */
- padding: 8px;
- margin-right: 10px;
- font-size: 0.875rem;
- }
-
- .bandNameInput::placeholder, .frequencyInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- }
-
- .removeButton, .addButton {
- border: none;
- border-radius: 5px; /* Rounded corners */
- padding: 8px 12px;
- font-size: 0.875rem;
- cursor: pointer;
- transition: background-color 0.3s, color 0.3s;
- }
-
- .removeButton {
- background-color: #e57373; /* Light red color */
- color: white;
- }
-
- .removeButton:hover {
- background-color: #d32f2f; /* Darker red on hover */
- }
-
- .addButton {
- background-color: #4caf50; /* Green color */
- color: white;
- margin-top: 10px;
- }
-
- .addButton:hover {
- background-color: #388e3c; /* Darker green on hover */
- }
-
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx
index e2cf5f2a..14b57246 100644
--- a/gui_dev/src/pages/Settings/Settings.jsx
+++ b/gui_dev/src/pages/Settings/Settings.jsx
@@ -1,19 +1,24 @@
+import { useEffect, useState } from "react";
import {
+ Box,
Button,
+ ButtonGroup,
InputAdornment,
+ Popover,
Stack,
Switch,
TextField,
+ Tooltip,
Typography,
} from "@mui/material";
import { Link } from "react-router-dom";
import { CollapsibleBox, TitledBox } from "@/components";
-import { FrequencyRange } from "./FrequencyRange";
-import { useSettingsStore } from "@/stores";
+import { FrequencyRangeList } from "./FrequencyRange";
+import { Dropdown } from "./Dropdown";
+import { useSettingsStore, useStatusBarContent } from "@/stores";
import { filterObjectByKeys } from "@/utils/functions";
const formatKey = (key) => {
- // console.log(key);
return key
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
@@ -21,15 +26,32 @@ const formatKey = (key) => {
};
// Wrapper components for each type
-const BooleanField = ({ value, onChange }) => (
+const BooleanField = ({ value, onChange, error }) => (
onChange(e.target.checked)} />
);
-const StringField = ({ value, onChange, label }) => (
-
+const StringField = ({ value, onChange, label, error }) => (
+ onChange(e.target.value)}
+ label={label}
+ sx={{
+ "& .MuiOutlinedInput-root": {
+ "& fieldset": {
+ borderColor: error ? "error.main" : "inherit",
+ },
+ "&:hover fieldset": {
+ borderColor: error ? "error.main" : "primary.main",
+ },
+ "&.Mui-focused fieldset": {
+ borderColor: error ? "error.main" : "primary.main",
+ },
+ },
+ }}
+ />
);
-const NumberField = ({ value, onChange, label }) => {
+const NumberField = ({ value, onChange, label, error }) => {
const handleChange = (event) => {
const newValue = event.target.value;
// Only allow numbers and decimal point
@@ -44,13 +66,26 @@ const NumberField = ({ value, onChange, label }) => {
value={value}
onChange={handleChange}
label={label}
- InputProps={{
- endAdornment: (
-
- Hz
-
- ),
+ sx={{
+ "& .MuiOutlinedInput-root": {
+ "& fieldset": {
+ borderColor: error ? "error.main" : "inherit",
+ },
+ "&:hover fieldset": {
+ borderColor: error ? "error.main" : "primary.main",
+ },
+ "&.Mui-focused fieldset": {
+ borderColor: error ? "error.main" : "primary.main",
+ },
+ },
}}
+ // InputProps={{
+ // endAdornment: (
+ //
+ // Hz
+ //
+ // ),
+ // }}
inputProps={{
pattern: "[0-9]*",
}}
@@ -58,107 +93,231 @@ const NumberField = ({ value, onChange, label }) => {
);
};
-const FrequencyRangeField = ({ value, onChange, label }) => (
-
-);
-
// Map component types to their respective wrappers
const componentRegistry = {
boolean: BooleanField,
string: StringField,
number: NumberField,
- FrequencyRange: FrequencyRangeField,
+ Array: Dropdown,
};
-const SettingsField = ({ path, Component, label, value, onChange, depth }) => {
+const SettingsField = ({ path, Component, label, value, onChange, error }) => {
return (
-
- {label}
- onChange(path, newValue)}
- label={label}
- />
-
+
+
+ {label}
+ onChange(path, newValue)}
+ label={label}
+ error={error}
+ />
+
+
);
};
+// Function to get the error corresponding to this field or its children
+const getFieldError = (fieldPath, errors) => {
+ if (!errors) return null;
+
+ return errors.find((error) => {
+ const errorPath = error.loc.join(".");
+ const currentPath = fieldPath.join(".");
+ return errorPath === currentPath || errorPath.startsWith(currentPath + ".");
+ });
+};
+
const SettingsSection = ({
settings,
title = null,
path = [],
onChange,
- depth = 0,
+ errors,
}) => {
- if (Object.keys(settings).length === 0) {
- return null;
- }
const boxTitle = title ? title : formatKey(path[path.length - 1]);
+ /*
+ 3 possible cases:
+ 1. Primitive type || 2. Object with component -> Don't iterate, render directly
+ 3. Object without component or 4. Array -> Iterate and render recursively
+ */
- return (
-
- {Object.entries(settings).map(([key, value]) => {
- if (key === "__field_type__") return null;
+ const type = typeof settings;
+ const isObject = type === "object" && !Array.isArray(settings);
+ const isArray = Array.isArray(settings);
- const newPath = [...path, key];
- const label = key;
- const isPydanticModel =
- typeof value === "object" && "__field_type__" in value;
+ // __field_type__ should be always present
+ if (isObject && !settings.__field_type__) {
+ console.log(settings);
+ throw new Error("Invalid settings object");
+ }
+ const fieldType = isObject ? settings.__field_type__ : type;
+ const Component = componentRegistry[fieldType];
- const fieldType = isPydanticModel ? value.__field_type__ : typeof value;
+ // Case 1: Primitive type -> Don't iterate, render directly
+ if (!isObject && !isArray) {
+ if (!Component) {
+ console.error(`Invalid component type: ${type}`);
+ return null;
+ }
- const Component = componentRegistry[fieldType];
+ const error = getFieldError(path, errors);
+
+ return (
+
+ );
+ }
+
+ // Case 2: Object with component -> Don't iterate, render directly
+ if (isObject && Component) {
+ return (
+
+ );
+ }
+
+ // Case 3: Object without component or 4. Array -> Iterate and render recursively
+ if ((isObject && !Component) || isArray) {
+ return (
+
+ {/* Handle recursing through both objects and arrays */}
+ {(isArray ? settings : Object.entries(settings)).map((item, index) => {
+ const [key, value] = isArray ? [index.toString(), item] : item;
+ if (key.startsWith("__")) return null; // Skip metadata fields
+
+ const newPath = [...path, key];
- if (Component) {
- return (
-
- );
- } else {
return (
);
- }
- })}
-
+ })}
+
+ );
+ }
+
+ // Default case: return null and log an error
+ console.error(`Invalid settings object, returning null`);
+ return null;
+};
+
+const StatusBarSettingsInfo = () => {
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ const [anchorEl, setAnchorEl] = useState(null);
+ const open = Boolean(anchorEl);
+
+ const handleOpenErrorsPopover = (event) => {
+ setAnchorEl(event.currentTarget);
+ };
+
+ const handleCloseErrorsPopover = () => {
+ setAnchorEl(null);
+ };
+
+ return (
+ <>
+ {validationErrors?.length > 0 && (
+ <>
+
+ {validationErrors?.length} errors found in Settings
+
+
+
+ {validationErrors.map((error, index) => (
+
+ {index} - [{error.type}] {error.msg}
+
+ ))}
+
+
+ >
+ )}
+ >
);
};
-const SettingsContent = () => {
+export const Settings = () => {
+ // Get all necessary state from the settings store
const settings = useSettingsStore((state) => state.settings);
- const updateSettings = useSettingsStore((state) => state.updateSettings);
+ const uploadSettings = useSettingsStore((state) => state.uploadSettings);
+ const resetSettings = useSettingsStore((state) => state.resetSettings);
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ useStatusBarContent(StatusBarSettingsInfo);
+
+ // This is needed so that the frequency ranges stay in order between updates
+ const frequencyRangeOrder = useSettingsStore(
+ (state) => state.frequencyRangeOrder
+ );
+ const updateFrequencyRangeOrder = useSettingsStore(
+ (state) => state.updateFrequencyRangeOrder
+ );
+
+ // Here I handle the selected feature in the feature settings component
+ const [selectedFeature, setSelectedFeature] = useState("");
+
+ useEffect(() => {
+ uploadSettings(null, true); // validateOnly = true
+ }, [settings]);
+ // Inject validation error info into status bar
+
+ // This has to be after all the hooks, otherwise React will complain
if (!settings) {
return Loading settings...
;
}
- const handleChange = (path, value) => {
- updateSettings((settings) => {
+ // This are the callbacks for the different buttons
+ const handleChangeSettings = async (path, value) => {
+ uploadSettings((settings) => {
let current = settings;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}
current[path[path.length - 1]] = value;
- });
+ }, true); // validateOnly = true
+ };
+
+ const handleSaveSettings = () => {
+ uploadSettings(() => settings);
+ };
+
+ const handleResetSettings = async () => {
+ await resetSettings();
};
const featureSettingsKeys = Object.keys(settings.features)
@@ -181,89 +340,146 @@ const SettingsContent = () => {
"project_subcortex_settings",
];
+ const generalSettingsKeys = [
+ "sampling_rate_features_hz",
+ "segment_length_features_ms",
+ ];
+
return (
-
-
+ {/* SETTINGS LAYOUT */}
+
-
-
+ {/* GENERAL SETTINGS + FREQUENCY RANGES */}
+
+
+ {generalSettingsKeys.map((key) => (
+
+ ))}
+
+
+
+
+
+
-
-
+ {/* POSTPROCESSING + PREPROCESSING SETTINGS */}
+
{preprocessingSettingsKeys.map((key) => (
))}
-
+
-
+
{postprocessingSettingsKeys.map((key) => (
))}
-
-
+
-
- {Object.entries(enabledFeatures).map(([feature, featureSettings]) => (
-
-
-
- ))}
-
-
- );
-};
+ {/* FEATURE SETTINGS */}
+
+
+
+
+
+
+ {Object.entries(enabledFeatures).map(
+ ([feature, featureSettings]) => (
+
+
+
+ )
+ )}
+
+
+
+ {/* END SETTINGS LAYOUT */}
+
-export const Settings = () => {
- return (
-
-
-
+
+ {/* */}
+
+
+
);
};
diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx
deleted file mode 100644
index 86ab07b8..00000000
--- a/gui_dev/src/pages/Settings/TextField.jsx
+++ /dev/null
@@ -1,77 +0,0 @@
-import { useState, useEffect } from "react";
-import {
- Box,
- Grid,
- TextField as MUITextField,
- Typography,
-} from "@mui/material";
-import { useSettingsStore } from "@/stores";
-import styles from "./TextField.module.css";
-
-const flattenDictionary = (dict, parentKey = "", result = {}) => {
- for (let key in dict) {
- const newKey = parentKey ? `${parentKey}.${key}` : key;
- if (typeof dict[key] === "object" && dict[key] !== null) {
- flattenDictionary(dict[key], newKey, result);
- } else {
- result[newKey] = dict[key];
- }
- }
- return result;
-};
-
-const filterByKeys = (flatDict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (flatDict.hasOwnProperty(key)) {
- filteredDict[key] = flatDict[key];
- }
- });
- return filteredDict;
-};
-
-export const TextField = ({ keysToInclude }) => {
- const settings = useSettingsStore((state) => state.settings);
- const flatSettings = flattenDictionary(settings);
- const filteredSettings = filterByKeys(flatSettings, keysToInclude);
- const [textLabels, setTextLabels] = useState({});
-
- useEffect(() => {
- setTextLabels(filteredSettings);
- }, [settings]);
-
- const handleTextFieldChange = (label, value) => {
- setTextLabels((prevLabels) => ({
- ...prevLabels,
- [label]: value,
- }));
- };
-
- // Function to format the label
- const formatLabel = (label) => {
- const labelAfterDot = label.split(".").pop(); // Get everything after the last dot
- return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces
- };
-
- return (
-
- {Object.keys(textLabels).map((label, index) => (
-
-
- handleTextFieldChange(label, e.target.value)}
- className={styles.textFieldInput}
- />
-
- ))}
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css
deleted file mode 100644
index 37791c40..00000000
--- a/gui_dev/src/pages/Settings/TextField.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-/* TextField.module.css */
-
-/* Container for the text fields */
-.textFieldContainer {
- display: flex;
- flex-direction: column;
- margin: 1.5rem 0; /* Increased margin for better spacing */
- }
-
- /* Row for each text field */
- .textFieldRow {
- display: flex;
- flex-direction: column; /* Stack label and input vertically */
- margin-bottom: 1rem; /* Increased margin for better separation */
- }
-
- /* Label for each text field */
- .textFieldLabel {
- margin-bottom: 0.5rem; /* Space between label and input */
- font-weight: 600; /* Increased weight for better visibility */
- color: #333; /* Dark gray for the label */
- font-size: 1.1rem; /* Increased font size for the label */
- transition: all 0.2s ease; /* Smooth transition for label */
- }
-
- /* Input field styles */
- .textFieldInput {
- padding: 12px 14px; /* Padding for a filled look */
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 4px; /* Rounded corners */
- width: 100%; /* Full width */
- font-size: 1rem; /* Font size */
- background-color: #f5f5f5; /* Light background color for filled effect */
- transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */
- box-shadow: none; /* Remove default shadow */
- height: 48px; /* Fixed height for a more square appearance */
- }
-
- /* Focus styles for the input */
- .textFieldInput:focus {
- border-color: #1976d2; /* Blue border color on focus */
- background-color: #fff; /* Change background to white on focus */
- outline: none; /* Remove default outline */
- }
-
- /* Hover effect for the input */
- .textFieldInput:hover {
- border-color: #1976d2; /* Change border color on hover */
- }
-
- /* Placeholder styles */
- .textFieldInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- opacity: 1; /* Ensure placeholder is fully opaque */
- }
-
- /* Hide the number input spinners in webkit browsers */
- .textFieldInput::-webkit-inner-spin-button,
- .textFieldInput::-webkit-outer-spin-button {
- -webkit-appearance: none; /* Remove default styling */
- margin: 0; /* Remove margin */
- }
-
- /* Hide the number input spinners in Firefox */
- .textFieldInput[type='number'] {
- -moz-appearance: textfield; /* Use textfield appearance */
- }
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js
deleted file mode 100644
index 16355023..00000000
--- a/gui_dev/src/pages/Settings/index.js
+++ /dev/null
@@ -1 +0,0 @@
-export { TextField } from './TextField';
diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js
index 762d9340..e683a35c 100644
--- a/gui_dev/src/stores/createStore.js
+++ b/gui_dev/src/stores/createStore.js
@@ -2,16 +2,23 @@ import { create } from "zustand";
import { immer } from "zustand/middleware/immer";
import { devtools, persist as persistMiddleware } from "zustand/middleware";
-export const createStore = (name, initializer, persist = false) => {
+export const createStore = (
+ name,
+ initializer,
+ persist = false,
+ dev = false
+) => {
const fn = persist
? persistMiddleware(immer(initializer), name)
: immer(initializer);
- return create(
- devtools(fn, {
- name: name,
- })
- );
+ const dev_fn = dev
+ ? devtools(fn, {
+ name: name,
+ })
+ : fn;
+
+ return create(dev_fn);
};
export const createPersistStore = (name, initializer) => {
diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js
index 9b17f42d..e5f085fe 100644
--- a/gui_dev/src/stores/settingsStore.js
+++ b/gui_dev/src/stores/settingsStore.js
@@ -5,32 +5,18 @@ const INITIAL_DELAY = 3000; // wait for Flask
const RETRY_DELAY = 1000; // ms
const MAX_RETRIES = 100;
-const uploadSettingsToServer = async (settings) => {
- try {
- const response = await fetch(getBackendURL("/api/settings"), {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- },
- body: JSON.stringify(settings),
- });
- if (!response.ok) {
- throw new Error("Failed to update settings");
- }
- return { success: true };
- } catch (error) {
- console.error("Error updating settings:", error);
- return { success: false, error };
- }
-};
-
export const useSettingsStore = createStore("settings", (set, get) => ({
settings: null,
+ lastValidSettings: null,
+ frequencyRangeOrder: [],
isLoading: false,
error: null,
+ validationErrors: null,
retryCount: 0,
- setSettings: (settings) => set({ settings }),
+ updateLocalSettings: (updater) => {
+ set((state) => updater(state.settings));
+ },
fetchSettingsWithDelay: () => {
set({ isLoading: true, error: null });
@@ -39,15 +25,25 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
}, INITIAL_DELAY);
},
- fetchSettings: async () => {
+ fetchSettings: async (reset = false) => {
+ console.log("Fetching settings...");
try {
- console.log("Fetching settings...");
- const response = await fetch(getBackendURL("/api/settings"));
+ const response = await fetch(
+ `/api/settings${reset ? "?reset=true" : ""}`
+ );
+
if (!response.ok) {
throw new Error("Failed to fetch settings");
}
+
const data = await response.json();
- set({ settings: data, retryCount: 0 });
+
+ set({
+ settings: data,
+ lastValidSettings: data,
+ frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}),
+ retryCount: 0,
+ });
} catch (error) {
console.log("Error fetching settings:", error);
set((state) => ({
@@ -55,6 +51,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
retryCount: state.retryCount + 1,
}));
+ console.log(get().retryCount);
+
if (get().retryCount < MAX_RETRIES) {
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
return get().fetchSettings();
@@ -66,29 +64,62 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
resetRetryCount: () => set({ retryCount: 0 }),
- updateSettings: async (updater) => {
- const currentSettings = get().settings;
+ resetSettings: async () => {
+ await get().fetchSettings(true);
+ },
- // Apply the update optimistically
- set((state) => {
- updater(state.settings);
- });
+ updateFrequencyRangeOrder: (newOrder) => {
+ set({ frequencyRangeOrder: newOrder });
+ },
- const newSettings = get().settings;
+ uploadSettings: async (updater, validateOnly = false) => {
+ if (updater) {
+ set((state) => {
+ updater(state.settings);
+ });
+ }
+
+ const currentSettings = get().settings;
try {
- const result = await uploadSettingsToServer(newSettings);
+ const response = await fetch(
+ `/api/settings${validateOnly ? "?validate_only=true" : ""}`,
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify(currentSettings),
+ }
+ );
- if (!result.success) {
- // Revert the local state if the server update failed
- set({ settings: currentSettings });
+ const data = await response.json();
+
+ if (!response.ok) {
+ throw new Error("Failed to upload settings to backend");
}
- return result;
+ if (data.valid) {
+ // Settings are valid
+ set({
+ lastValidSettings: currentSettings,
+ validationErrors: null,
+ });
+ return true;
+ } else {
+ // Settings are invalid
+ set({
+ validationErrors: data.errors,
+ });
+ // Note: We don't revert the settings here, keeping the potentially invalid state
+ return false;
+ }
} catch (error) {
- // Revert the local state if there was an error
- set({ settings: currentSettings });
- throw error;
+ console.error(
+ `Error ${validateOnly ? "validating" : "updating"} settings:`,
+ error
+ );
+ return false;
}
},
}));
diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js
index a229ec06..25b44382 100644
--- a/gui_dev/src/stores/uiStore.js
+++ b/gui_dev/src/stores/uiStore.js
@@ -1,4 +1,5 @@
import { createPersistStore } from "./createStore";
+import { useEffect } from "react";
export const useUiStore = createPersistStore("ui", (set, get) => ({
activeDrawer: null,
@@ -28,4 +29,24 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({
state.accordionStates[id] = defaultState;
}
}),
+
+ // Hook to inject UI elements into the status bar
+ statusBarContent: () => {},
+ setStatusBarContent: (content) => set({ statusBarContent: content }),
+ clearStatusBarContent: () => set({ statusBarContent: null }),
}));
+
+// Use this hook from Page components to inject page-specific UI elements into the status bar
+export const useStatusBarContent = (content) => {
+ const createStatusBarContent = () => content;
+
+ const setStatusBarContent = useUiStore((state) => state.setStatusBarContent);
+ const clearStatusBarContent = useUiStore(
+ (state) => state.clearStatusBarContent
+ );
+
+ useEffect(() => {
+ setStatusBarContent(createStatusBarContent);
+ return () => clearStatusBarContent();
+ }, [content, setStatusBarContent, clearStatusBarContent]);
+};
diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js
index 5e2b5f11..cc99909c 100644
--- a/gui_dev/src/theme.js
+++ b/gui_dev/src/theme.js
@@ -22,6 +22,11 @@ export const theme = createTheme({
disableRipple: true,
},
},
+ MuiTextField: {
+ defaultProps: {
+ autoComplete: "off",
+ },
+ },
MuiStack: {
defaultProps: {
alignItems: "center",
diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/functions.js
index 04289d5c..1f8ae90a 100644
--- a/gui_dev/src/utils/functions.js
+++ b/gui_dev/src/utils/functions.js
@@ -17,6 +17,7 @@ export function debounce(func, wait) {
timeout = setTimeout(later, wait);
};
}
+
export const flattenDictionary = (dict, parentKey = "", result = {}) => {
for (let key in dict) {
const newKey = parentKey ? `${parentKey}.${key}` : key;
diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py
index 688393fd..864df345 100644
--- a/py_neuromodulation/__init__.py
+++ b/py_neuromodulation/__init__.py
@@ -4,11 +4,12 @@
from importlib.metadata import version
from py_neuromodulation.utils.logging import NMLogger
+
#####################################
# Globals and environment variables #
#####################################
-__version__ = version("py_neuromodulation") # get version from pyproject.toml
+__version__ = version("py_neuromodulation")
# Check if the module is running headless (no display) for tests and doc builds
PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
@@ -18,6 +19,7 @@
os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
+
# Set environment variable MNE_LSL_LIB (required to import Stream below)
LSL_DICT = {
"windows_32bit": "windows/x86/liblsl.1.16.2.dll",
@@ -36,6 +38,7 @@
PLATFORM = platform.system().lower().strip()
ARCH = platform.architecture()[0]
+
match PLATFORM:
case "windows":
KEY = PLATFORM + "_" + ARCH
diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml
index 49988840..e7858599 100644
--- a/py_neuromodulation/default_settings.yaml
+++ b/py_neuromodulation/default_settings.yaml
@@ -1,4 +1,15 @@
----
+# We
+# should
+# have
+# a
+# brief
+# explanation
+# of
+# the
+# settings
+# format
+# here
+
########################
### General settings ###
########################
@@ -51,12 +62,8 @@ preprocessing_filter:
lowpass_filter: true
highpass_filter: true
bandpass_filter: true
- bandstop_filter_settings:
- frequency_low_hz: 100
- frequency_high_hz: 160
- bandpass_filter_settings:
- frequency_low_hz: 3
- frequency_high_hz: 200
+ bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
+ bandpass_filter_settings: [3, 200] # [hz, _hz]
lowpass_filter_cutoff_hz: 200
highpass_filter_cutoff_hz: 3
@@ -162,10 +169,8 @@ sharpwave_analysis_settings:
decay_steepness: false
slope_ratio: false
filter_ranges_hz:
- - frequency_low_hz: 5
- frequency_high_hz: 80
- - frequency_low_hz: 5
- frequency_high_hz: 30
+ - [5, 80]
+ - [5, 30]
detect_troughs:
estimate: true
distance_troughs_ms: 10
@@ -174,6 +179,7 @@ sharpwave_analysis_settings:
estimate: true
distance_troughs_ms: 5
distance_peaks_ms: 10
+ # TONI: Reverse this setting? e.g. interval: [mean, var]
estimator:
mean: [interval]
median: []
diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py
index 83d63bf8..2ff8ceac 100644
--- a/py_neuromodulation/features/bursts.py
+++ b/py_neuromodulation/features/bursts.py
@@ -11,7 +11,7 @@
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
from typing import TYPE_CHECKING, Callable
-from py_neuromodulation.utils.types import create_validation_error
+from py_neuromodulation.utils.pydantic_extensions import create_validation_error
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py
index e6c74985..ecfa0a97 100644
--- a/py_neuromodulation/gui/backend/app_backend.py
+++ b/py_neuromodulation/gui/backend/app_backend.py
@@ -13,6 +13,7 @@
)
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
+from pydantic import ValidationError
from . import app_pynm
from .app_socket import WebSocketManager
@@ -74,21 +75,62 @@ async def healthcheck():
####################
##### SETTINGS #####
####################
-
@self.get("/api/settings")
- async def get_settings():
- return self.pynm_state.settings.process_for_frontend()
+ async def get_settings(
+ reset: bool = Query(False, description="Reset settings to default"),
+ ):
+ if reset:
+ settings = NMSettings.get_default()
+ else:
+ settings = self.pynm_state.settings
+
+ return settings.serialize_with_metadata()
@self.post("/api/settings")
- async def update_settings(data: dict):
+ async def update_settings(data: dict, validate_only: bool = Query(False)):
try:
- self.pynm_state.settings = NMSettings.model_validate(data)
- self.logger.info(self.pynm_state.settings.features)
- return self.pynm_state.settings.model_dump()
- except ValueError as e:
+ # First, validate with Pydantic
+ try:
+ validated_settings = NMSettings.model_validate(data)
+ except ValidationError as e:
+ if not validate_only:
+ # If validation failed but we wanted to upload, return error
+ self.logger.error(f"Error validating settings: {e}")
+ raise HTTPException(
+ status_code=422,
+ detail={
+ "error": "Error validating settings",
+ "details": str(e),
+ },
+ )
+ # Else return list of errors
+ return {
+ "valid": False,
+ "errors": [err for err in e.errors()],
+ "details": str(e),
+ }
+
+ # If validation succesful, return or update settings
+ if validate_only:
+ return {
+ "valid": True,
+ "settings": validated_settings.serialize_with_metadata(),
+ }
+
+ self.pynm_state.settings = validated_settings
+ self.logger.info("Settings successfully updated")
+
+ return {
+ "valid": True,
+ "settings": self.pynm_state.settings.serialize_with_metadata(),
+ }
+
+ # If something else than validation went wrong, return error
+ except Exception as e:
+ self.logger.error(f"Error validating/updating settings: {e}")
raise HTTPException(
status_code=422,
- detail={"error": "Validation failed", "details": str(e)},
+ detail={"error": "Error uploading settings", "details": str(e)},
)
########################
diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py
index 3345b122..66419cec 100644
--- a/py_neuromodulation/processing/normalization.py
+++ b/py_neuromodulation/processing/normalization.py
@@ -2,10 +2,10 @@
import numpy as np
from typing import TYPE_CHECKING, Callable, Literal, get_args
+from pydantic import Field
from py_neuromodulation.utils.types import (
NMBaseModel,
- Field,
NormMethod,
NMPreprocessor,
)
diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py
index 08a4e115..6fa0885d 100644
--- a/py_neuromodulation/processing/resample.py
+++ b/py_neuromodulation/processing/resample.py
@@ -1,7 +1,8 @@
"""Module for resampling."""
import numpy as np
-from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor
+from pydantic import Field
+from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor
class ResamplerSettings(NMBaseModel):
diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py
index 7e25ec24..39a9b6fd 100644
--- a/py_neuromodulation/stream/settings.py
+++ b/py_neuromodulation/stream/settings.py
@@ -4,15 +4,15 @@
from typing import ClassVar
from pydantic import Field, model_validator
-from py_neuromodulation import PYNM_DIR, logger, user_features
+from py_neuromodulation import logger, user_features
from py_neuromodulation.utils.types import (
BoolSelector,
FrequencyRange,
- PreprocessorName,
_PathLike,
NMBaseModel,
NormMethod,
+ PreprocessorList,
)
from py_neuromodulation.processing.filter_preprocessing import FilterSettings
@@ -72,11 +72,14 @@ class NMSettings(NMBaseModel):
}
# Preproceessing settings
- preprocessing: list[PreprocessorName] = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ preprocessing: PreprocessorList = PreprocessorList(
+ [
+ "raw_resampling",
+ "notch_filter",
+ "re_referencing",
+ ]
+ )
+
raw_resampling_settings: ResamplerSettings = ResamplerSettings()
preprocessing_filter: FilterSettings = FilterSettings()
raw_normalization_settings: NormalizationSettings = NormalizationSettings()
@@ -144,26 +147,29 @@ def validate_settings(self):
if self.bandpass_filter_settings.kalman_filter:
self.kalman_filter_settings.validate_fbands(self)
- for k, v in self.frequency_ranges_hz.items():
- if not isinstance(v, FrequencyRange):
- self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
+ # TONI: not needed after NMSequenceModel, remove in the future
+ # for k, v in self.frequency_ranges_hz.items():
+ # if not isinstance(v, FrequencyRange):
+ # self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
return self
def reset(self) -> "NMSettings":
self.features.disable_all()
- self.preprocessing = []
+ self.preprocessing = PreprocessorList()
self.postprocessing.disable_all()
return self
def set_fast_compute(self) -> "NMSettings":
self.reset()
self.features.fft = True
- self.preprocessing = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ self.preprocessing = PreprocessorList(
+ [
+ "raw_resampling",
+ "notch_filter",
+ "re_referencing",
+ ]
+ )
self.postprocessing.feature_normalization = True
self.postprocessing.project_cortex = False
self.postprocessing.project_subcortex = False
@@ -250,7 +256,7 @@ def from_file(PATH: _PathLike) -> "NMSettings":
@staticmethod
def get_default() -> "NMSettings":
- return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")
+ return NMSettings()
@staticmethod
def list_normalization_methods() -> list[NormMethod]:
diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py
new file mode 100644
index 00000000..be677243
--- /dev/null
+++ b/py_neuromodulation/utils/pydantic_extensions.py
@@ -0,0 +1,347 @@
+from typing import Any, get_type_hints, TypeVar, Generic, Literal, overload
+from typing_extensions import Unpack, TypedDict
+from pydantic import BaseModel, ConfigDict, model_validator
+from pydantic_core import PydanticUndefined, ValidationError, InitErrorDetails
+from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs
+from pprint import pformat
+
+
+def create_validation_error(
+ error_message: str,
+ loc: list[str | int] | None = None,
+ title: str = "Validation Error",
+ input_type: Literal["python", "json"] = "python",
+ hide_input: bool = False,
+) -> ValidationError:
+ """
+ Factory function to create a Pydantic v2 ValidationError instance from a single error message.
+
+ Args:
+ error_message (str): The error message for the ValidationError.
+ loc (List[str | int], optional): The location of the error. Defaults to None.
+ title (str, optional): The title of the error. Defaults to "Validation Error".
+ input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python".
+ hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False.
+
+ Returns:
+ ValidationError: A Pydantic ValidationError instance.
+ """
+ if loc is None:
+ loc = []
+
+ line_errors = [
+ InitErrorDetails(
+ type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message}
+ )
+ ]
+
+ return ValidationError.from_exception_data(
+ title=title,
+ line_errors=line_errors,
+ input_type=input_type,
+ hide_input=hide_input,
+ )
+
+
+class _NMExtraFieldInputs(TypedDict, total=False):
+ """Additional fields to add on top of the pydantic FieldInfo"""
+
+ custom_metadata: dict[str, Any]
+
+
+class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo inputs with PyNM additional inputs"""
+
+ pass
+
+
+class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs"""
+
+ pass
+
+
+class NMFieldInfo(FieldInfo):
+ # Add default values for any other custom fields here
+ _default_values = {}
+
+ def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None:
+ self.custom_metadata = kwargs.pop("custom_metadata", {})
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def from_field(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+ ) -> "NMFieldInfo":
+ if "annotation" in kwargs:
+ raise TypeError('"annotation" is not permitted as a Field keyword argument')
+ return NMFieldInfo(default=default, **kwargs)
+
+ def __repr_args__(self):
+ yield from super().__repr_args__()
+ extra_fields = get_type_hints(_NMExtraFieldInputs)
+ for field in extra_fields:
+ value = getattr(self, field)
+ yield field, value
+
+
+def NMField(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+) -> Any:
+ return NMFieldInfo.from_field(default=default, **kwargs)
+
+
+class NMBaseModel(BaseModel):
+ model_config = ConfigDict(validate_assignment=False, extra="allow")
+
+ def __init__(self, *args, **kwargs) -> None:
+ """Pydantic does not support positional arguments by default.
+ This is a workaround to support positional arguments for models like FrequencyRange.
+ It converts positional arguments to kwargs and then calls the base class __init__.
+ """
+ if not args:
+ # Simple case - just use kwargs
+ super().__init__(**kwargs)
+ return
+
+ field_names = list(self.model_fields.keys())
+ # If we have more positional args than fields, that's an error
+ if len(args) > len(field_names):
+ raise ValueError(
+ f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}"
+ )
+
+ # Convert positional args to kwargs, checking for duplicates if args:
+ complete_kwargs = {}
+ for i, arg in enumerate(args):
+ field_name = field_names[i]
+ if field_name in kwargs:
+ raise ValueError(
+ f"Got multiple values for field '{field_name}': "
+ f"positional argument and keyword argument"
+ )
+ complete_kwargs[field_name] = arg
+
+ # Add remaining kwargs
+ complete_kwargs.update(kwargs)
+ super().__init__(**complete_kwargs)
+
+ def __str__(self):
+ return pformat(self.model_dump())
+
+ def __repr__(self):
+ return pformat(self.model_dump())
+
+ def validate(self) -> Any: # type: ignore
+ return self.model_validate(self.model_dump())
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value) -> None:
+ setattr(self, key, value)
+
+ @property
+ def fields(self) -> dict[str, FieldInfo | NMFieldInfo]:
+ return self.model_fields # type: ignore
+
+ def serialize_with_metadata(self):
+ result: dict[str, Any] = {"__field_type__": self.__class__.__name__}
+
+ for field_name, field_info in self.fields.items():
+ value = getattr(self, field_name)
+ if isinstance(value, NMBaseModel):
+ result[field_name] = value.serialize_with_metadata()
+ elif isinstance(value, list):
+ result[field_name] = [
+ item.serialize_with_metadata()
+ if isinstance(item, NMBaseModel)
+ else item
+ for item in value
+ ]
+ elif isinstance(value, dict):
+ result[field_name] = {
+ k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v
+ for k, v in value.items()
+ }
+ else:
+ result[field_name] = value
+
+ # Extract unit information from Annotated type
+ if isinstance(field_info, NMFieldInfo):
+ for tag, value in field_info.custom_metadata.items():
+ result[f"__{tag}__"] = value
+ return result
+
+
+#################################
+#### Generic Pydantic models ####
+#################################
+
+
+def create_alias_property(index: int, alias: str, classname: str):
+ """Creates a property that accesses the root sequence at the given index"""
+
+ def getter(self):
+ return self.root[index]
+
+ def setter(self, value):
+ if isinstance(self.root, tuple):
+ new_values = list(self.root)
+ new_values[index] = value
+ self.root = tuple(new_values)
+ else:
+ self.root[index] = value
+
+ return property(
+ fget=getter,
+ fset=setter,
+ doc=f"Alias '{alias}' for position [{index}] of class '{classname}'.",
+ )
+
+
+T = TypeVar("T")
+C = TypeVar("C", list, tuple)
+
+
+class NMSequenceModel(NMBaseModel, Generic[C]):
+ """Base class for sequence models with a single root value"""
+
+ root: C = NMField(default_factory=list)
+
+ # Class variable for aliases - override in subclasses
+ __aliases__: dict[int, list[str]] = {}
+
+ def __init__(self, *args, **kwargs) -> None:
+ # Generate properties programatically (not used currently)
+ # for index, aliases in self.__aliases__.items():
+ # for alias in aliases:
+ # if not hasattr(self.__class__, alias):
+ # setattr(
+ # self.__class__,
+ # alias,
+ # create_alias_property(index, alias, self.__class__.__name__),
+ # )
+
+ if len(args) == 1 and isinstance(args[0], (list, tuple)):
+ kwargs["root"] = args[0]
+ elif len(args) == 1:
+ kwargs["root"] = [args[0]]
+ elif len(args) > 1: # Add this case
+ kwargs["root"] = tuple(args)
+ super().__init__(**kwargs)
+
+ def __iter__(self): # type: ignore[reportIncompatibleMethodOverride]
+ return iter(self.root)
+
+ def __getitem__(self, idx):
+ return self.root[idx]
+
+ def __len__(self):
+ return len(self.root)
+
+ def model_dump(self): # type: ignore[reportIncompatibleMethodOverride]
+ return self.root
+
+ def model_dump_json(self, **kwargs):
+ import json
+
+ return json.dumps(self.root, **kwargs)
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = {"__field_type__": self.__class__.__name__, "value": self.root}
+
+ # Add any field metadata from the root field
+ field_info = self.model_fields.get("root")
+ if isinstance(field_info, NMFieldInfo):
+ for tag, value in field_info.custom_metadata.items():
+ result[f"__{tag}__"] = value
+
+ return result
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_input(cls, value: Any) -> dict[str, Any]:
+ # If it's a dict, just return it
+ if isinstance(value, dict):
+ if "root" in value:
+ return value
+
+ # Check for aliased fields if class has aliases defined
+ if hasattr(cls, "__aliases__"):
+ # Collect all possible alias names for each position
+ alias_values = []
+ max_index = max(cls.__aliases__.keys()) if cls.__aliases__ else -1
+
+ # Try to find a value for each position using its aliases
+ for i in range(max_index + 1):
+ aliases = cls.__aliases__.get(i, [])
+ value_found = None
+
+ # Try each alias for this position
+ for alias in aliases:
+ if alias in value:
+ value_found = value[alias]
+ break
+
+ if value_found is not None:
+ alias_values.append(value_found)
+ else:
+ # If we're missing any position's value, don't use aliases
+ break
+
+ # If we found all values through aliases, use them
+ if len(alias_values) == max_index + 1:
+ return {"root": alias_values}
+
+ # if it's a sequence, return the value as the root
+ if isinstance(value, (list, tuple)):
+ return {"root": value}
+
+ # Else, make it a list
+ return {"root": [value]}
+
+
+class NMValueModel(NMBaseModel, Generic[T]):
+ """Base class for single-value models that behave like their contained type"""
+
+ root: T
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_input(cls, value: Any) -> dict[str, Any]:
+ if isinstance(value, dict):
+ if "root" in value:
+ return value
+ # If it's a dict without root, assume the first value is our value
+ if len(value) > 0:
+ return {"root": next(iter(value.values()))}
+ return {"root": None}
+ return {"root": value}
+
+ def __str__(self) -> str:
+ return str(self.root)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({repr(self.root)})"
+
+ def model_dump(self): # type: ignore[reportIncompatibleMethodOverride]
+ return self.root
+
+ def model_dump_json(self, **kwargs):
+ import json
+
+ return json.dumps(self.root, **kwargs)
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = {"__field_type__": self.__class__.__name__, "value": self.root}
+
+ # Add any field metadata from the root field
+ field_info = self.model_fields.get("root")
+ if isinstance(field_info, NMFieldInfo):
+ for tag, value in field_info.custom_metadata.items():
+ result[f"__{tag}__"] = value
+
+ return result
diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py
index 7886685d..a2c22db9 100644
--- a/py_neuromodulation/utils/types.py
+++ b/py_neuromodulation/utils/types.py
@@ -1,12 +1,13 @@
from os import PathLike
from math import isnan
-from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable
-from pydantic import ConfigDict, Field, model_validator, BaseModel
-from pydantic_core import ValidationError, InitErrorDetails
-from pprint import pformat
+from typing import Literal, TYPE_CHECKING
+from pydantic import BaseModel, ConfigDict, model_validator
+from .pydantic_extensions import NMBaseModel, NMSequenceModel
+
from collections.abc import Sequence
from datetime import datetime
+
if TYPE_CHECKING:
import numpy as np
from py_neuromodulation import NMSettings
@@ -17,7 +18,6 @@
_PathLike = str | PathLike
-
FeatureName = Literal[
"raw_hjorth",
"return_raw",
@@ -54,13 +54,13 @@
"minmax",
]
+
###################################
######## PROTOCOL CLASSES ########
###################################
-@runtime_checkable
-class NMFeature(Protocol):
+class NMFeature:
def __init__(
self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float
) -> None: ...
@@ -81,131 +81,47 @@ def calc_feature(self, data: "np.ndarray") -> dict:
...
-class NMPreprocessor(Protocol):
- def __init__(self, sfreq: float, settings: "NMSettings") -> None: ...
-
+class NMPreprocessor:
def process(self, data: "np.ndarray") -> "np.ndarray": ...
-###################################
-######## PYDANTIC CLASSES ########
-###################################
-
-
-class NMBaseModel(BaseModel):
- model_config = ConfigDict(validate_assignment=False, extra="allow")
+class PreprocessorList(NMSequenceModel[list[PreprocessorName]]):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+ # Useless contructor to prevent linter from complaining
def __init__(self, *args, **kwargs) -> None:
- if kwargs:
- super().__init__(**kwargs)
- else:
- field_names = list(self.model_fields.keys())
- kwargs = {}
- for i in range(len(args)):
- kwargs[field_names[i]] = args[i]
- super().__init__(**kwargs)
-
- def __str__(self):
- return pformat(self.model_dump())
-
- def __repr__(self):
- return pformat(self.model_dump())
-
- def validate(self) -> Any: # type: ignore
- return self.model_validate(self.model_dump())
+ super().__init__(*args, **kwargs)
- def __getitem__(self, key):
- return getattr(self, key)
- def __setitem__(self, key, value) -> None:
- setattr(self, key, value)
+class FrequencyRange(NMSequenceModel[tuple[float, float]]):
+ """Frequency range as (low, high) tuple"""
- def process_for_frontend(self) -> dict[str, Any]:
- """
- Process the model for frontend use, adding __field_type__ information.
- """
- result = {}
- for field_name, field_value in self.__dict__.items():
- if isinstance(field_value, NMBaseModel):
- processed_value = field_value.process_for_frontend()
- processed_value["__field_type__"] = field_value.__class__.__name__
- result[field_name] = processed_value
- elif isinstance(field_value, list):
- result[field_name] = [
- item.process_for_frontend()
- if isinstance(item, NMBaseModel)
- else item
- for item in field_value
- ]
- elif isinstance(field_value, dict):
- result[field_name] = {
- k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v
- for k, v in field_value.items()
- }
- else:
- result[field_name] = field_value
-
- return result
-
-
-class FrequencyRange(NMBaseModel):
- frequency_low_hz: float = Field(gt=0)
- frequency_high_hz: float = Field(gt=0)
+ __aliases__ = {
+ 0: ["frequency_low_hz", "low_frequency_hz"],
+ 1: ["frequency_high_hz", "high_frequency_hz"],
+ }
+ # Useless contructor to prevent linter from complaining
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- def __getitem__(self, item: int):
- match item:
- case 0:
- return self.frequency_low_hz
- case 1:
- return self.frequency_high_hz
- case _:
- raise IndexError(f"Index {item} out of range")
-
- def as_tuple(self) -> tuple[float, float]:
- return (self.frequency_low_hz, self.frequency_high_hz)
-
- def __iter__(self): # type: ignore
- return iter(self.as_tuple())
-
@model_validator(mode="after")
def validate_range(self):
- if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)):
- assert (
- self.frequency_high_hz > self.frequency_low_hz
- ), "Frequency high must be greater than frequency low"
+ low, high = self.root
+ if not (isnan(low) or isnan(high)):
+ assert high > low, "High frequency must be greater than low frequency"
return self
- @classmethod
- def create_from(cls, input) -> "FrequencyRange":
- match input:
- case FrequencyRange():
- return input
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return FrequencyRange(
- input["frequency_low_hz"], input["frequency_high_hz"]
- )
- case Sequence() if len(input) == 2:
- return FrequencyRange(input[0], input[1])
- case _:
- raise ValueError("Invalid input for FrequencyRange creation.")
-
- @model_validator(mode="before")
- @classmethod
- def check_input(cls, input):
- match input:
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return input
- case Sequence() if len(input) == 2:
- return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]}
- case _:
- raise ValueError(
- "Value for FrequencyRange must be a dictionary, "
- "or a sequence of 2 numeric values, "
- f"but got {input} instead."
- )
+ # Alias properties
+ @property
+ def frequency_low_hz(self) -> float:
+ """Lower frequency bound in Hz"""
+ return self.root[0]
+
+ @property
+ def frequency_high_hz(self) -> float:
+ """Upper frequency bound in Hz"""
+ return self.root[1]
class BoolSelector(NMBaseModel):
@@ -238,47 +154,6 @@ def print_all(cls):
for f in cls.list_all():
print(f)
- @classmethod
- def get_fields(cls):
- return cls.model_fields
-
-
-def create_validation_error(
- error_message: str,
- loc: list[str | int] = None,
- title: str = "Validation Error",
- input_type: Literal["python", "json"] = "python",
- hide_input: bool = False,
-) -> ValidationError:
- """
- Factory function to create a Pydantic v2 ValidationError instance from a single error message.
-
- Args:
- error_message (str): The error message for the ValidationError.
- loc (List[str | int], optional): The location of the error. Defaults to None.
- title (str, optional): The title of the error. Defaults to "Validation Error".
- input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python".
- hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False.
-
- Returns:
- ValidationError: A Pydantic ValidationError instance.
- """
- if loc is None:
- loc = []
-
- line_errors = [
- InitErrorDetails(
- type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message}
- )
- ]
-
- return ValidationError.from_exception_data(
- title=title,
- line_errors=line_errors,
- input_type=input_type,
- hide_input=hide_input,
- )
-
#################
### GUI TYPES ###
diff --git a/pyproject.toml b/pyproject.toml
index 32cea5a5..193177c5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,6 @@ dependencies = [
"nolds >=0.6.1",
"numpy >= 1.21.2",
"pandas >= 2.0.0",
- "scikit-image",
"scikit-learn >= 0.24.2",
"scikit-optimize",
"scipy >= 1.7.1",
@@ -55,15 +54,12 @@ dependencies = [
"statsmodels",
"mne-lsl>=1.2.0",
"pynput",
- #"pyqt5",
"pydantic>=2.7.3",
"llvmlite>=0.43.0",
"pywebview",
"fastapi",
- "uvicorn>=0.30.6",
- "websockets>=13.0",
+ "uvicorn[standard]>=0.30.6",
"seaborn >= 0.11",
- # exists only because of nolds?
"numba>=0.60.0",
"cbor2>=5.6.4",
"msgpack>=1.1.0",
@@ -78,7 +74,7 @@ dev = [
"pytest-cov",
"pytest-sugar",
"notebook",
- "watchdog",
+ "uvicorn[standard]",
]
docs = [
"py-neuromodulation[dev]",