diff --git a/gui_dev/package.json b/gui_dev/package.json
index e1cb7f00..5fa2d606 100644
--- a/gui_dev/package.json
+++ b/gui_dev/package.json
@@ -21,7 +21,7 @@
"react-dom": "next",
"react-icons": "^5.3.0",
"react-plotly.js": "^2.6.0",
- "react-router-dom": "^6.28.0",
+ "react-router-dom": "^7.0.1",
"zustand": "latest"
},
"devDependencies": {
diff --git a/gui_dev/src/App.jsx b/gui_dev/src/App.jsx
index 7ebfdc55..a97f8c75 100644
--- a/gui_dev/src/App.jsx
+++ b/gui_dev/src/App.jsx
@@ -45,8 +45,6 @@ export const App = () => {
const connectSocket = useSocketStore((state) => state.connectSocket);
const disconnectSocket = useSocketStore((state) => state.disconnectSocket);
- connectSocket();
-
useEffect(() => {
console.log("Connecting socket from App component...");
connectSocket();
diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx
index 36aaf1dc..f011e2b9 100644
--- a/gui_dev/src/components/StatusBar/StatusBar.jsx
+++ b/gui_dev/src/components/StatusBar/StatusBar.jsx
@@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle";
import { SocketStatus } from "./SocketStatus";
import { WebviewStatus } from "./WebviewStatus";
-import { useWebviewStore } from "@/stores";
-
+import { useUiStore, useWebviewStore } from "@/stores";
import { Stack } from "@mui/material";
export const StatusBar = () => {
- const { isWebView } = useWebviewStore((state) => state.isWebView);
+ const isWebView = useWebviewStore((state) => state.isWebView);
+ const createStatusBarContent = useUiStore((state) => state.statusBarContent);
+
+ const StatusBarContent = createStatusBarContent();
return (
{
bgcolor="background.level1"
borderTop="2px solid"
borderColor="background.level3"
+ height="2rem"
>
-
+ {StatusBarContent && }
+
+ {/* */}
{/* Current experiment */}
{/* Current stream */}
{/* Current activity */}
diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx
index 56f0266b..a86b6009 100644
--- a/gui_dev/src/components/TitledBox.jsx
+++ b/gui_dev/src/components/TitledBox.jsx
@@ -1,4 +1,4 @@
-import { Container } from "@mui/material";
+import { Box, Container } from "@mui/material";
/**
* Component that uses the Box component to render an HTML fieldset element
@@ -13,14 +13,14 @@ export const TitledBox = ({
children,
...props
}) => (
-
{children}
-
+
);
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx
deleted file mode 100644
index afa74b02..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.jsx
+++ /dev/null
@@ -1,65 +0,0 @@
-import { useRef } from "react";
-import styles from "./DragAndDropList.module.css";
-import { useOptionsStore } from "@/stores";
-
-export const DragAndDropList = () => {
- const { options, setOptions, addOption, removeOption } = useOptionsStore();
-
- const predefinedOptions = [
- { id: 1, name: "raw_resampling" },
- { id: 2, name: "notch_filter" },
- { id: 3, name: "re_referencing" },
- { id: 4, name: "preprocessing_filter" },
- { id: 5, name: "raw_normalization" },
- ];
-
- const dragOption = useRef(0);
- const draggedOverOption = useRef(0);
-
- function handleSort() {
- const optionsClone = [...options];
- const temp = optionsClone[dragOption.current];
- optionsClone[dragOption.current] = optionsClone[draggedOverOption.current];
- optionsClone[draggedOverOption.current] = temp;
- setOptions(optionsClone);
- }
-
- return (
-
- List
- {options.map((option, index) => (
- (dragOption.current = index)}
- onDragEnter={() => (draggedOverOption.current = index)}
- onDragEnd={handleSort}
- onDragOver={(e) => e.preventDefault()}
- >
-
- {[option.id, ". ", option.name.replace("_", " ")]}
-
-
-
- ))}
-
-
Add Elements
- {predefinedOptions.map((option) => (
-
- ))}
-
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css
deleted file mode 100644
index 76bd7861..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.module.css
+++ /dev/null
@@ -1,72 +0,0 @@
-.dragDropList {
- max-width: 400px;
- margin: 0 auto;
- padding: 20px;
-}
-
-.title {
- text-align: center;
- color: #333;
- margin-bottom: 20px;
-}
-
-.item {
- background-color: #f0f0f0;
- border-radius: 8px;
- padding: 15px;
- margin-bottom: 10px;
- cursor: move;
- transition: background-color 0.3s ease;
- display: flex;
- justify-content: space-between;
- align-items: center;
-}
-
-.item:hover {
- background-color: #e0e0e0;
-}
-
-.itemText {
- margin: 0;
- color: #333;
- font-size: 16px;
-}
-
-.removeButton {
- background-color: #ff4d4d;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 5px 10px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.removeButton:hover {
- background-color: #ff1a1a;
-}
-
-.addSection {
- margin-top: 20px;
- text-align: center;
-}
-
-.subtitle {
- color: #333;
- margin-bottom: 10px;
-}
-
-.addButton {
- background-color: #4CAF50;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 10px 15px;
- margin: 5px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.addButton:hover {
- background-color: #45a049;
-}
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx
deleted file mode 100644
index ad662eef..00000000
--- a/gui_dev/src/pages/Settings/Dropdown.jsx
+++ /dev/null
@@ -1,66 +0,0 @@
-import { useState } from "react";
-import "../App.css";
-
-var stringJson =
- '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":t
rue,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high 
beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}';
-const nm_settings = JSON.parse(stringJson);
-
-const filterByKeys = (dict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (typeof key === "string") {
- // Top-level key
- if (dict.hasOwnProperty(key)) {
- filteredDict[key] = dict[key];
- }
- } else if (typeof key === "object") {
- // Nested key
- const topLevelKey = Object.keys(key)[0];
- const nestedKeys = key[topLevelKey];
- if (
- dict.hasOwnProperty(topLevelKey) &&
- typeof dict[topLevelKey] === "object"
- ) {
- filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys);
- }
- }
- });
- return filteredDict;
-};
-
-const Dropdown = ({ onChange, keysToInclude }) => {
- const filteredSettings = filterByKeys(nm_settings, keysToInclude);
- const [selectedOption, setSelectedOption] = useState("");
-
- const handleChange = (event) => {
- const newValue = event.target.value;
- setSelectedOption(newValue);
- onChange(keysToInclude, newValue);
- };
-
- return (
-
-
-
- );
-};
-
-export default Dropdown;
diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx
deleted file mode 100644
index 5def783b..00000000
--- a/gui_dev/src/pages/Settings/FrequencyRange.jsx
+++ /dev/null
@@ -1,88 +0,0 @@
-import { useState } from "react";
-
-// const onChange = (key, newValue) => {
-// settings.frequencyRanges[key] = newValue;
-// };
-
-// Object.entries(settings.frequencyRanges).map(([key, value]) => (
-//
-// ));
-
-/** */
-
-/**
- *
- * @param {String} key
- * @param {Array} freqRange
- * @param {Function} onChange
- * @returns
- */
-export const FrequencyRange = ({ settings }) => {
- const [frequencyRanges, setFrequencyRanges] = useState(settings || {});
- // Handle changes in the text fields
- const handleInputChange = (label, key, newValue) => {
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [label]: {
- ...prevState[label],
- [key]: newValue,
- },
- }));
- };
-
- // Add a new band
- const addBand = () => {
- const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`;
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" },
- }));
- };
-
- // Remove a band
- const removeBand = (label) => {
- const updatedRanges = { ...frequencyRanges };
- delete updatedRanges[label];
- setFrequencyRanges(updatedRanges);
- };
-
- return (
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css
deleted file mode 100644
index a638ad93..00000000
--- a/gui_dev/src/pages/Settings/FrequencySettings.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-.container {
- background-color: #f9f9f9; /* Light gray background for the container */
- padding: 20px;
- border-radius: 10px; /* Rounded corners */
- box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */
- max-width: 600px;
- margin: auto;
- }
-
- .header {
- font-size: 1.5rem;
- color: #333; /* Darker text color */
- margin-bottom: 20px;
- text-align: center;
- }
-
- .bandContainer {
- display: flex;
- align-items: center;
- margin-bottom: 15px;
- padding: 10px;
- border: 1px solid #ddd; /* Light border for each band */
- border-radius: 8px; /* Rounded corners for individual bands */
- background-color: #fff; /* White background for bands */
- box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */
- }
-
- .bandNameInput, .frequencyInput {
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 5px; /* Slightly rounded corners */
- padding: 8px;
- margin-right: 10px;
- font-size: 0.875rem;
- }
-
- .bandNameInput::placeholder, .frequencyInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- }
-
- .removeButton, .addButton {
- border: none;
- border-radius: 5px; /* Rounded corners */
- padding: 8px 12px;
- font-size: 0.875rem;
- cursor: pointer;
- transition: background-color 0.3s, color 0.3s;
- }
-
- .removeButton {
- background-color: #e57373; /* Light red color */
- color: white;
- }
-
- .removeButton:hover {
- background-color: #d32f2f; /* Darker red on hover */
- }
-
- .addButton {
- background-color: #4caf50; /* Green color */
- color: white;
- margin-top: 10px;
- }
-
- .addButton:hover {
- background-color: #388e3c; /* Darker green on hover */
- }
-
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx
index e2cf5f2a..045603be 100644
--- a/gui_dev/src/pages/Settings/Settings.jsx
+++ b/gui_dev/src/pages/Settings/Settings.jsx
@@ -1,19 +1,21 @@
+import { useEffect, useState } from "react";
import {
+ Box,
Button,
- InputAdornment,
+ Popover,
Stack,
Switch,
TextField,
+ Tooltip,
Typography,
} from "@mui/material";
import { Link } from "react-router-dom";
import { CollapsibleBox, TitledBox } from "@/components";
-import { FrequencyRange } from "./FrequencyRange";
-import { useSettingsStore } from "@/stores";
+import { FrequencyRangeList } from "./components/FrequencyRange";
+import { useSettingsStore, useStatusBarContent } from "@/stores";
import { filterObjectByKeys } from "@/utils/functions";
const formatKey = (key) => {
- // console.log(key);
return key
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
@@ -21,15 +23,37 @@ const formatKey = (key) => {
};
// Wrapper components for each type
-const BooleanField = ({ value, onChange }) => (
+const BooleanField = ({ value, onChange, error }) => (
onChange(e.target.checked)} />
);
-const StringField = ({ value, onChange, label }) => (
-
-);
+const errorStyle = {
+ "& .MuiOutlinedInput-root": {
+ "& fieldset": { borderColor: "error.main" },
+ "&:hover fieldset": {
+ borderColor: "error.main",
+ },
+ "&.Mui-focused fieldset": {
+ borderColor: "error.main",
+ },
+ },
+};
+
+const StringField = ({ value, onChange, label, error }) => {
+ const errorSx = error ? errorStyle : {};
+ return (
+ onChange(e.target.value)}
+ label={label}
+ sx={{ ...errorSx }}
+ />
+ );
+};
+
+const NumberField = ({ value, onChange, label, error }) => {
+ const errorSx = error ? errorStyle : {};
-const NumberField = ({ value, onChange, label }) => {
const handleChange = (event) => {
const newValue = event.target.value;
// Only allow numbers and decimal point
@@ -44,13 +68,14 @@ const NumberField = ({ value, onChange, label }) => {
value={value}
onChange={handleChange}
label={label}
- InputProps={{
- endAdornment: (
-
- Hz
-
- ),
- }}
+ sx={{ ...errorSx }}
+ // InputProps={{
+ // endAdornment: (
+ //
+ // Hz
+ //
+ // ),
+ // }}
inputProps={{
pattern: "[0-9]*",
}}
@@ -58,107 +83,230 @@ const NumberField = ({ value, onChange, label }) => {
);
};
-const FrequencyRangeField = ({ value, onChange, label }) => (
-
-);
-
// Map component types to their respective wrappers
const componentRegistry = {
boolean: BooleanField,
string: StringField,
number: NumberField,
- FrequencyRange: FrequencyRangeField,
};
-const SettingsField = ({ path, Component, label, value, onChange, depth }) => {
+const SettingsField = ({ path, Component, label, value, onChange, error }) => {
return (
-
- {label}
- onChange(path, newValue)}
- label={label}
- />
-
+
+
+ {label}
+ onChange(path, newValue)}
+ label={label}
+ error={error}
+ />
+
+
);
};
+// Function to get the error corresponding to this field or its children
+const getFieldError = (fieldPath, errors) => {
+ if (!errors) return null;
+
+ return errors.find((error) => {
+ const errorPath = error.loc.join(".");
+ const currentPath = fieldPath.join(".");
+ return errorPath === currentPath || errorPath.startsWith(currentPath + ".");
+ });
+};
+
const SettingsSection = ({
settings,
title = null,
path = [],
onChange,
- depth = 0,
+ errors,
}) => {
- if (Object.keys(settings).length === 0) {
- return null;
- }
const boxTitle = title ? title : formatKey(path[path.length - 1]);
+ /*
+ 4 possible cases:
+ 1. Primitive type || 2. Object with component -> Don't iterate, render directly
+ 3. Object without component or 4. Array -> Iterate and render recursively
+ */
- return (
-
- {Object.entries(settings).map(([key, value]) => {
- if (key === "__field_type__") return null;
+ const type = typeof settings;
+ const isObject = type === "object" && !Array.isArray(settings);
+ const isArray = Array.isArray(settings);
- const newPath = [...path, key];
- const label = key;
- const isPydanticModel =
- typeof value === "object" && "__field_type__" in value;
+ // __field_type__ should be always present
+ if (isObject && !settings.__field_type__) {
+ console.log(settings);
+ throw new Error("Invalid settings object");
+ }
+ const fieldType = isObject ? settings.__field_type__ : type;
+ const Component = componentRegistry[fieldType];
- const fieldType = isPydanticModel ? value.__field_type__ : typeof value;
+ // Case 1: Primitive type -> Don't iterate, render directly
+ if (!isObject && !isArray) {
+ if (!Component) {
+ console.error(`Invalid component type: ${type}`);
+ return null;
+ }
- const Component = componentRegistry[fieldType];
+ const error = getFieldError(path, errors);
+
+ return (
+
+ );
+ }
+
+ // Case 2: Object with component -> Don't iterate, render directly
+ if (isObject && Component) {
+ return (
+
+ );
+ }
+
+ // Case 3: Object without component or 4. Array -> Iterate and render recursively
+ if ((isObject && !Component) || isArray) {
+ return (
+
+ {/* Handle recursing through both objects and arrays */}
+ {(isArray ? settings : Object.entries(settings)).map((item, index) => {
+ const [key, value] = isArray ? [index.toString(), item] : item;
+ if (key.startsWith("__")) return null; // Skip metadata fields
+
+ const newPath = [...path, key];
- if (Component) {
- return (
-
- );
- } else {
return (
);
- }
- })}
-
+ })}
+
+ );
+ }
+
+ // Default case: return null and log an error
+ console.error(`Invalid settings object, returning null`);
+ return null;
+};
+
+const StatusBarSettingsInfo = () => {
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ const [anchorEl, setAnchorEl] = useState(null);
+ const open = Boolean(anchorEl);
+
+ const handleOpenErrorsPopover = (event) => {
+ setAnchorEl(event.currentTarget);
+ };
+
+ const handleCloseErrorsPopover = () => {
+ setAnchorEl(null);
+ };
+
+ return (
+ <>
+ {validationErrors?.length > 0 && (
+ <>
+
+ {validationErrors?.length} errors found in Settings
+
+
+
+ {validationErrors.map((error, index) => (
+
+ {index} - [{error.type}] {error.msg}
+
+ ))}
+
+
+ >
+ )}
+ >
);
};
-const SettingsContent = () => {
+export const Settings = () => {
+ // Get all necessary state from the settings store
const settings = useSettingsStore((state) => state.settings);
- const updateSettings = useSettingsStore((state) => state.updateSettings);
+ const uploadSettings = useSettingsStore((state) => state.uploadSettings);
+ const resetSettings = useSettingsStore((state) => state.resetSettings);
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ useStatusBarContent(StatusBarSettingsInfo);
+
+ // This is needed so that the frequency ranges stay in order between updates
+ const frequencyRangeOrder = useSettingsStore(
+ (state) => state.frequencyRangeOrder
+ );
+ const updateFrequencyRangeOrder = useSettingsStore(
+ (state) => state.updateFrequencyRangeOrder
+ );
+ // Here I handle the selected feature in the feature settings component
+ const [selectedFeature, setSelectedFeature] = useState("");
+
+ useEffect(() => {
+ uploadSettings(null, true); // validateOnly = true
+ }, [settings]);
+
+ // Inject validation error info into status bar
+
+ // This has to be after all the hooks, otherwise React will complain
if (!settings) {
return Loading settings...
;
}
- const handleChange = (path, value) => {
- updateSettings((settings) => {
+ // These are the callbacks for the different buttons
+ const handleChangeSettings = async (path, value) => {
+ uploadSettings((settings) => {
let current = settings;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}
current[path[path.length - 1]] = value;
- });
+ }, true); // validateOnly = true
+ };
+
+ const handleSaveSettings = () => {
+ uploadSettings(() => settings);
+ };
+
+ const handleResetSettings = async () => {
+ await resetSettings();
};
const featureSettingsKeys = Object.keys(settings.features)
@@ -181,89 +329,147 @@ const SettingsContent = () => {
"project_subcortex_settings",
];
+ const generalSettingsKeys = [
+ "sampling_rate_features_hz",
+ "segment_length_features_ms",
+ ];
+
return (
-
-
+ {/* SETTINGS LAYOUT */}
+
-
-
+ {/* GENERAL SETTINGS + FREQUENCY RANGES */}
+
+
+ {generalSettingsKeys.map((key) => (
+
+ ))}
+
+
+
+
+
+
-
-
+ {/* POSTPROCESSING + PREPROCESSING SETTINGS */}
+
{preprocessingSettingsKeys.map((key) => (
))}
-
+
-
+
{postprocessingSettingsKeys.map((key) => (
))}
-
-
+
-
- {Object.entries(enabledFeatures).map(([feature, featureSettings]) => (
-
-
-
- ))}
-
-
- );
-};
+ {/* FEATURE SETTINGS */}
+
+
+
+
+
+
+ {Object.entries(enabledFeatures).map(
+ ([feature, featureSettings]) => (
+
+
+
+ )
+ )}
+
+
+
+ {/* END SETTINGS LAYOUT */}
+
-export const Settings = () => {
- return (
-
-
-
+
+ {/* */}
+
+
+
);
};
diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx
deleted file mode 100644
index 86ab07b8..00000000
--- a/gui_dev/src/pages/Settings/TextField.jsx
+++ /dev/null
@@ -1,77 +0,0 @@
-import { useState, useEffect } from "react";
-import {
- Box,
- Grid,
- TextField as MUITextField,
- Typography,
-} from "@mui/material";
-import { useSettingsStore } from "@/stores";
-import styles from "./TextField.module.css";
-
-const flattenDictionary = (dict, parentKey = "", result = {}) => {
- for (let key in dict) {
- const newKey = parentKey ? `${parentKey}.${key}` : key;
- if (typeof dict[key] === "object" && dict[key] !== null) {
- flattenDictionary(dict[key], newKey, result);
- } else {
- result[newKey] = dict[key];
- }
- }
- return result;
-};
-
-const filterByKeys = (flatDict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (flatDict.hasOwnProperty(key)) {
- filteredDict[key] = flatDict[key];
- }
- });
- return filteredDict;
-};
-
-export const TextField = ({ keysToInclude }) => {
- const settings = useSettingsStore((state) => state.settings);
- const flatSettings = flattenDictionary(settings);
- const filteredSettings = filterByKeys(flatSettings, keysToInclude);
- const [textLabels, setTextLabels] = useState({});
-
- useEffect(() => {
- setTextLabels(filteredSettings);
- }, [settings]);
-
- const handleTextFieldChange = (label, value) => {
- setTextLabels((prevLabels) => ({
- ...prevLabels,
- [label]: value,
- }));
- };
-
- // Function to format the label
- const formatLabel = (label) => {
- const labelAfterDot = label.split(".").pop(); // Get everything after the last dot
- return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces
- };
-
- return (
-
- {Object.keys(textLabels).map((label, index) => (
-
-
- handleTextFieldChange(label, e.target.value)}
- className={styles.textFieldInput}
- />
-
- ))}
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css
deleted file mode 100644
index 37791c40..00000000
--- a/gui_dev/src/pages/Settings/TextField.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-/* TextField.module.css */
-
-/* Container for the text fields */
-.textFieldContainer {
- display: flex;
- flex-direction: column;
- margin: 1.5rem 0; /* Increased margin for better spacing */
- }
-
- /* Row for each text field */
- .textFieldRow {
- display: flex;
- flex-direction: column; /* Stack label and input vertically */
- margin-bottom: 1rem; /* Increased margin for better separation */
- }
-
- /* Label for each text field */
- .textFieldLabel {
- margin-bottom: 0.5rem; /* Space between label and input */
- font-weight: 600; /* Increased weight for better visibility */
- color: #333; /* Dark gray for the label */
- font-size: 1.1rem; /* Increased font size for the label */
- transition: all 0.2s ease; /* Smooth transition for label */
- }
-
- /* Input field styles */
- .textFieldInput {
- padding: 12px 14px; /* Padding for a filled look */
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 4px; /* Rounded corners */
- width: 100%; /* Full width */
- font-size: 1rem; /* Font size */
- background-color: #f5f5f5; /* Light background color for filled effect */
- transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */
- box-shadow: none; /* Remove default shadow */
- height: 48px; /* Fixed height for a more square appearance */
- }
-
- /* Focus styles for the input */
- .textFieldInput:focus {
- border-color: #1976d2; /* Blue border color on focus */
- background-color: #fff; /* Change background to white on focus */
- outline: none; /* Remove default outline */
- }
-
- /* Hover effect for the input */
- .textFieldInput:hover {
- border-color: #1976d2; /* Change border color on hover */
- }
-
- /* Placeholder styles */
- .textFieldInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- opacity: 1; /* Ensure placeholder is fully opaque */
- }
-
- /* Hide the number input spinners in webkit browsers */
- .textFieldInput::-webkit-inner-spin-button,
- .textFieldInput::-webkit-outer-spin-button {
- -webkit-appearance: none; /* Remove default styling */
- margin: 0; /* Remove margin */
- }
-
- /* Hide the number input spinners in Firefox */
- .textFieldInput[type='number'] {
- -moz-appearance: textfield; /* Use textfield appearance */
- }
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/components/FrequencyRange.jsx b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx
new file mode 100644
index 00000000..16333166
--- /dev/null
+++ b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx
@@ -0,0 +1,182 @@
+import { useState } from "react";
+import { TextField, Button, IconButton, Stack } from "@mui/material";
+import { Add, Close } from "@mui/icons-material";
+import { debounce } from "@/utils";
+
+const NumberField = ({ ...props }) => (
+
+);
+
+export const FrequencyRange = ({
+ name,
+ range,
+ onChangeName,
+ onChangeRange,
+ onRemove,
+ error,
+}) => {
+ const [localName, setLocalName] = useState(name);
+
+ const debouncedChangeName = debounce((newName) => {
+ onChangeName(newName, name);
+ }, 1000);
+
+ const handleNameChange = (e) => {
+ const newName = e.target.value;
+ setLocalName(newName);
+ debouncedChangeName(newName);
+ };
+
+ const handleNameBlur = () => {
+ onChangeName(localName, name);
+ };
+
+ const handleKeyPress = (e) => {
+ if (e.key === "Enter") {
+ console.log(e.target.value, name);
+ onChangeName(localName, name);
+ }
+ };
+
+ const handleRangeChange = (name, field, value) => {
+ // onChangeRange takes the name of the range as the first argument
+ onChangeRange(name, { ...range, [field]: value });
+ };
+
+ return (
+
+
+
+ handleRangeChange(name, "frequency_low_hz", e.target.value)
+ }
+ label="Low Hz"
+ />
+
+ handleRangeChange(name, "frequency_high_hz", e.target.value)
+ }
+ label="High Hz"
+ />
+ onRemove(name)}
+ color="primary"
+ disableRipple
+ sx={{ m: 0, p: 0 }}
+ >
+
+
+
+ );
+};
+
+export const FrequencyRangeList = ({
+ ranges,
+ rangeOrder,
+ onChange,
+ onOrderChange,
+ errors,
+}) => {
+ const handleChangeRange = (name, newRange) => {
+ const updatedRanges = { ...ranges };
+ updatedRanges[name] = newRange;
+ onChange(["frequency_ranges_hz"], updatedRanges);
+ };
+
+ const handleChangeName = (newName, oldName) => {
+ if (oldName === newName) {
+ return;
+ }
+
+ const updatedRanges = { ...ranges, [newName]: ranges[oldName] };
+ delete updatedRanges[oldName];
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = rangeOrder.map((name) =>
+ name === oldName ? newName : name
+ );
+ onOrderChange(updatedOrder);
+ };
+
+ const handleRemove = (name) => {
+ const updatedRanges = { ...ranges };
+ delete updatedRanges[name];
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = rangeOrder.filter((item) => item !== name);
+ onOrderChange(updatedOrder);
+ };
+
+ const addRange = () => {
+ let newName = "NewRange";
+ let counter = 0;
+ // Find first available name
+ while (Object.hasOwn(ranges, newName)) {
+ counter++;
+ newName = `NewRange${counter}`;
+ }
+
+ const updatedRanges = {
+ ...ranges,
+ [newName]: {
+ __field_type__: "FrequencyRange",
+ frequency_low_hz: 1,
+ frequency_high_hz: 500,
+ },
+ };
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = [...rangeOrder, newName];
+ onOrderChange(updatedOrder);
+ };
+
+ return (
+
+ {rangeOrder.map((name, index) => (
+
+ ))}
+ }
+ onClick={addRange}
+ sx={{ mt: 1 }}
+ >
+ Add Range
+
+
+ );
+};
diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js
deleted file mode 100644
index 16355023..00000000
--- a/gui_dev/src/pages/Settings/index.js
+++ /dev/null
@@ -1 +0,0 @@
-export { TextField } from './TextField';
diff --git a/gui_dev/src/pages/SourceSelection/FileSelector.jsx b/gui_dev/src/pages/SourceSelection/FileSelector.jsx
index ffa53bab..dd0663b1 100644
--- a/gui_dev/src/pages/SourceSelection/FileSelector.jsx
+++ b/gui_dev/src/pages/SourceSelection/FileSelector.jsx
@@ -25,6 +25,7 @@ export const FileSelector = () => {
const [isSelecting, setIsSelecting] = useState(false);
const [showFileBrowser, setShowFileBrowser] = useState(false);
+ const [showFolderBrowser, setShowFolderBrowser] = useState(false);
useEffect(() => {
setSourceType("lsl");
@@ -48,6 +49,10 @@ export const FileSelector = () => {
}
};
+ const handleFolderSelect = (folder) => {
+ setShowFolderBrowser(false);
+ };
+
return (
+
{streamSetupMessage && (
{
onSelect={handleFileSelect}
/>
)}
+ {showFolderBrowser && (
+ setShowFolderBrowser(false)}
+ onSelect={handleFolderSelect}
+ onlyDirectories={true}
+ />
+ )}
);
};
diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js
index 762d9340..e683a35c 100644
--- a/gui_dev/src/stores/createStore.js
+++ b/gui_dev/src/stores/createStore.js
@@ -2,16 +2,23 @@ import { create } from "zustand";
import { immer } from "zustand/middleware/immer";
import { devtools, persist as persistMiddleware } from "zustand/middleware";
-export const createStore = (name, initializer, persist = false) => {
+export const createStore = (
+ name,
+ initializer,
+ persist = false,
+ dev = false
+) => {
const fn = persist
? persistMiddleware(immer(initializer), name)
: immer(initializer);
- return create(
- devtools(fn, {
- name: name,
- })
- );
+ const dev_fn = dev
+ ? devtools(fn, {
+ name: name,
+ })
+ : fn;
+
+ return create(dev_fn);
};
export const createPersistStore = (name, initializer) => {
diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js
index 9b17f42d..68929f1e 100644
--- a/gui_dev/src/stores/settingsStore.js
+++ b/gui_dev/src/stores/settingsStore.js
@@ -26,11 +26,16 @@ const uploadSettingsToServer = async (settings) => {
export const useSettingsStore = createStore("settings", (set, get) => ({
settings: null,
+ lastValidSettings: null,
+ frequencyRangeOrder: [],
isLoading: false,
error: null,
+ validationErrors: null,
retryCount: 0,
- setSettings: (settings) => set({ settings }),
+ updateLocalSettings: (updater) => {
+ set((state) => updater(state.settings));
+ },
fetchSettingsWithDelay: () => {
set({ isLoading: true, error: null });
@@ -46,8 +51,15 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
if (!response.ok) {
throw new Error("Failed to fetch settings");
}
+
const data = await response.json();
- set({ settings: data, retryCount: 0 });
+
+ set({
+ settings: data,
+ lastValidSettings: data,
+ frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}),
+ retryCount: 0,
+ });
} catch (error) {
console.log("Error fetching settings:", error);
set((state) => ({
@@ -55,6 +67,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
retryCount: state.retryCount + 1,
}));
+ console.log(get().retryCount);
+
if (get().retryCount < MAX_RETRIES) {
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
return get().fetchSettings();
@@ -66,29 +80,62 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
resetRetryCount: () => set({ retryCount: 0 }),
- updateSettings: async (updater) => {
- const currentSettings = get().settings;
+ resetSettings: async () => {
+ await get().fetchSettings(true);
+ },
- // Apply the update optimistically
- set((state) => {
- updater(state.settings);
- });
+ updateFrequencyRangeOrder: (newOrder) => {
+ set({ frequencyRangeOrder: newOrder });
+ },
- const newSettings = get().settings;
+ uploadSettings: async (updater, validateOnly = false) => {
+ if (updater) {
+ set((state) => {
+ updater(state.settings);
+ });
+ }
+
+ const currentSettings = get().settings;
try {
- const result = await uploadSettingsToServer(newSettings);
+ const response = await fetch(
+ `/api/settings${validateOnly ? "?validate_only=true" : ""}`,
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify(currentSettings),
+ }
+ );
+
+ const data = await response.json();
- if (!result.success) {
- // Revert the local state if the server update failed
- set({ settings: currentSettings });
+ if (!response.ok) {
+ throw new Error("Failed to upload settings to backend");
}
- return result;
+ if (data.valid) {
+ // Settings are valid
+ set({
+ lastValidSettings: currentSettings,
+ validationErrors: null,
+ });
+ return true;
+ } else {
+ // Settings are invalid
+ set({
+ validationErrors: data.errors,
+ });
+ // Note: We don't revert the settings here, keeping the potentially invalid state
+ return false;
+ }
} catch (error) {
- // Revert the local state if there was an error
- set({ settings: currentSettings });
- throw error;
+ console.error(
+ `Error ${validateOnly ? "validating" : "updating"} settings:`,
+ error
+ );
+ return false;
}
},
}));
diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js
index a229ec06..25b44382 100644
--- a/gui_dev/src/stores/uiStore.js
+++ b/gui_dev/src/stores/uiStore.js
@@ -1,4 +1,5 @@
import { createPersistStore } from "./createStore";
+import { useEffect } from "react";
export const useUiStore = createPersistStore("ui", (set, get) => ({
activeDrawer: null,
@@ -28,4 +29,24 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({
state.accordionStates[id] = defaultState;
}
}),
+
+ // Hook to inject UI elements into the status bar
+ statusBarContent: () => {},
+ setStatusBarContent: (content) => set({ statusBarContent: content }),
+ clearStatusBarContent: () => set({ statusBarContent: null }),
}));
+
+// Use this hook from Page components to inject page-specific UI elements into the status bar
+export const useStatusBarContent = (content) => {
+ const createStatusBarContent = () => content;
+
+ const setStatusBarContent = useUiStore((state) => state.setStatusBarContent);
+ const clearStatusBarContent = useUiStore(
+ (state) => state.clearStatusBarContent
+ );
+
+ useEffect(() => {
+ setStatusBarContent(createStatusBarContent);
+ return () => clearStatusBarContent();
+ }, [content, setStatusBarContent, clearStatusBarContent]);
+};
diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js
index 5e2b5f11..cc99909c 100644
--- a/gui_dev/src/theme.js
+++ b/gui_dev/src/theme.js
@@ -22,6 +22,11 @@ export const theme = createTheme({
disableRipple: true,
},
},
+ MuiTextField: {
+ defaultProps: {
+ autoComplete: "off",
+ },
+ },
MuiStack: {
defaultProps: {
alignItems: "center",
diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/functions.js
index 04289d5c..1f8ae90a 100644
--- a/gui_dev/src/utils/functions.js
+++ b/gui_dev/src/utils/functions.js
@@ -17,6 +17,7 @@ export function debounce(func, wait) {
timeout = setTimeout(later, wait);
};
}
+
export const flattenDictionary = (dict, parentKey = "", result = {}) => {
for (let key in dict) {
const newKey = parentKey ? `${parentKey}.${key}` : key;
diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py
index 688393fd..864df345 100644
--- a/py_neuromodulation/__init__.py
+++ b/py_neuromodulation/__init__.py
@@ -4,11 +4,12 @@
from importlib.metadata import version
from py_neuromodulation.utils.logging import NMLogger
+
#####################################
# Globals and environment variables #
#####################################
-__version__ = version("py_neuromodulation") # get version from pyproject.toml
+__version__ = version("py_neuromodulation")
# Check if the module is running headless (no display) for tests and doc builds
PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
@@ -18,6 +19,7 @@
os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
+
# Set environment variable MNE_LSL_LIB (required to import Stream below)
LSL_DICT = {
"windows_32bit": "windows/x86/liblsl.1.16.2.dll",
@@ -36,6 +38,7 @@
PLATFORM = platform.system().lower().strip()
ARCH = platform.architecture()[0]
+
match PLATFORM:
case "windows":
KEY = PLATFORM + "_" + ARCH
diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml
index c50b5f8e..e68194f8 100644
--- a/py_neuromodulation/default_settings.yaml
+++ b/py_neuromodulation/default_settings.yaml
@@ -1,4 +1,15 @@
----
+# We
+# should
+# have
+# a
+# brief
+# explanation
+# of
+# the
+# settings
+# format
+# here
+
########################
### General settings ###
########################
@@ -51,12 +62,8 @@ preprocessing_filter:
lowpass_filter: true
highpass_filter: true
bandpass_filter: true
- bandstop_filter_settings:
- frequency_low_hz: 100
- frequency_high_hz: 160
- bandpass_filter_settings:
- frequency_low_hz: 3
- frequency_high_hz: 200
+ bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
+ bandpass_filter_settings: [3, 200] # [low_hz, high_hz]
lowpass_filter_cutoff_hz: 200
highpass_filter_cutoff_hz: 3
@@ -162,11 +169,9 @@ sharpwave_analysis_settings:
rise_steepness: false
decay_steepness: false
slope_ratio: false
- filter_ranges_hz:
- - frequency_low_hz: 5
- frequency_high_hz: 80
- - frequency_low_hz: 5
- frequency_high_hz: 30
+ filter_ranges_hz: # list of [low_hz, high_hz]
+ - [5, 80]
+ - [5, 30]
detect_troughs:
estimate: true
distance_troughs_ms: 10
@@ -175,6 +180,7 @@ sharpwave_analysis_settings:
estimate: true
distance_troughs_ms: 5
distance_peaks_ms: 10
+ # TONI: Reverse this setting? e.g. interval: [mean, var]
estimator:
mean: [interval]
median: []
@@ -223,8 +229,8 @@ nolds_settings:
frequency_bands: [low_beta]
mne_connectiviy_settings:
- method: plv
- mode: multitaper
+ method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
+ mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
bispectrum_settings:
f1s: [5, 35]
diff --git a/py_neuromodulation/features/bandpower.py b/py_neuromodulation/features/bandpower.py
index dac13f4b..72dab126 100644
--- a/py_neuromodulation/features/bandpower.py
+++ b/py_neuromodulation/features/bandpower.py
@@ -2,8 +2,13 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING
from pydantic import field_validator
-
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import (
+ NMField,
+ NMErrorList,
+ create_validation_error,
+)
+from py_neuromodulation import logger
if TYPE_CHECKING:
from py_neuromodulation.stream.settings import NMSettings
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
class BandPowerSettings(NMBaseModel):
- segment_lengths_ms: dict[str, int] = {
- "theta": 1000,
- "alpha": 500,
- "low beta": 333,
- "high beta": 333,
- "low gamma": 100,
- "high gamma": 100,
- "HFA": 100,
- }
+ segment_lengths_ms: dict[str, int] = NMField(
+ default={
+ "theta": 1000,
+ "alpha": 500,
+ "low beta": 333,
+ "high beta": 333,
+ "low gamma": 100,
+ "high gamma": 100,
+ "HFA": 100,
+ },
+ custom_metadata={"field_type": "FrequencySegmentLength"},
+ )
bandpower_features: BandpowerFeatures = BandpowerFeatures()
log_transform: bool = True
kalman_filter: bool = False
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
@field_validator("bandpower_features")
@classmethod
def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
- assert (
- len(bandpower_features.get_enabled()) > 0
- ), "Set at least one bandpower_feature to True."
-
+ if not len(bandpower_features.get_enabled()) > 0:
+ raise create_validation_error(
+ error_message="Set at least one bandpower_feature to True.",
+ location=["bandpass_filter_settings", "bandpower_features"],
+ )
return bandpower_features
- def validate_fbands(self, settings: "NMSettings") -> None:
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
+ """Cross-validate band segment lengths against the global settings.
+
+ :param settings: Settings providing frequency_ranges_hz and
+ segment_length_features_ms for cross-validation.
+ :type settings: NMSettings
+ :return: Collected validation errors (empty if all checks pass).
+ :rtype: NMErrorList
+ """
+ errors: NMErrorList = NMErrorList()
+
for fband_name, seg_length_fband in self.segment_lengths_ms.items():
- assert seg_length_fband <= settings.segment_length_features_ms, (
- f"segment length {seg_length_fband} needs to be smaller than "
- f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
- )
+ # Check that all frequency bands are defined in settings.frequency_ranges_hz
+ if fband_name not in settings.frequency_ranges_hz:
+ logger.warning(
+ f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
+ " is not defined in settings.frequency_ranges_hz"
+ )
+
+ # Check that all segment lengths are smaller than settings.segment_length_features_ms
+ if not seg_length_fband <= settings.segment_length_features_ms:
+ errors.add_error(
+ f"segment length {seg_length_fband} needs to be smaller than "
+ f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
+ location=[
+ "bandpass_filter_settings",
+ "segment_lengths_ms",
+ fband_name,
+ ],
+ )
+ # Check that all frequency bands defined in settings.frequency_ranges_hz
for fband_name in settings.frequency_ranges_hz.keys():
- assert fband_name in self.segment_lengths_ms, (
- f"frequency range {fband_name} "
- "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
- )
+ if fband_name not in self.segment_lengths_ms:
+ errors.add_error(
+ f"frequency range {fband_name} "
+ "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
+ location=[
+ "bandpass_filter_settings",
+ "segment_lengths_ms",
+ fband_name,
+ ],
+ )
+
+ return errors
class BandPower(NMFeature):
diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py
index 83d63bf8..e9f5b632 100644
--- a/py_neuromodulation/features/bursts.py
+++ b/py_neuromodulation/features/bursts.py
@@ -7,11 +7,12 @@
from collections.abc import Sequence
from itertools import product
-from pydantic import Field, field_validator
+from pydantic import field_validator
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING, Callable
-from py_neuromodulation.utils.types import create_validation_error
+from py_neuromodulation.utils.pydantic_extensions import create_validation_error
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
class BurstsSettings(NMBaseModel):
- threshold: float = Field(default=75, ge=0, le=100)
- time_duration_s: float = Field(default=30, ge=0)
+ threshold: float = NMField(default=75, ge=0)
+ time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
burst_features: BurstFeatures = BurstFeatures()
diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py
index 21ca471b..6a100710 100644
--- a/py_neuromodulation/features/coherence.py
+++ b/py_neuromodulation/features/coherence.py
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Annotated
-from pydantic import Field, field_validator
+from pydantic import field_validator
from py_neuromodulation.utils.types import (
NMFeature,
@@ -11,6 +11,7 @@
FrequencyRange,
NMBaseModel,
)
+from py_neuromodulation.utils.pydantic_extensions import NMField
from py_neuromodulation import logger
if TYPE_CHECKING:
@@ -28,14 +29,16 @@ class CoherenceFeatures(BoolSelector):
max_allfbands: bool = True
-ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
+# TODO: make this into a pydantic model that only accepts names from
+# the channels objects and does not accept the same string twice
+ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
class CoherenceSettings(NMBaseModel):
features: CoherenceFeatures = CoherenceFeatures()
method: CoherenceMethods = CoherenceMethods()
channels: list[ListOfTwoStr] = []
- frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)
+ frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
@field_validator("frequency_bands")
def fbands_spaces_to_underscores(cls, frequency_bands):
diff --git a/py_neuromodulation/features/feature_processor.py b/py_neuromodulation/features/feature_processor.py
index 9e970e8f..dbb7e5c9 100644
--- a/py_neuromodulation/features/feature_processor.py
+++ b/py_neuromodulation/features/feature_processor.py
@@ -1,13 +1,13 @@
from typing import Type, TYPE_CHECKING
-from py_neuromodulation.utils.types import NMFeature, FeatureName
+from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
if TYPE_CHECKING:
import numpy as np
from py_neuromodulation import NMSettings
-FEATURE_DICT: dict[FeatureName | str, str] = {
+FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
"raw_hjorth": "Hjorth",
"return_raw": "Raw",
"bandpass_filter": "BandPower",
@@ -42,7 +42,7 @@ def __init__(
from importlib import import_module
# Accept 'str' for custom features
- self.features: dict[FeatureName | str, NMFeature] = {
+ self.features: dict[FEATURE_NAME | str, NMFeature] = {
feature_name: getattr(
import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
)(settings, ch_names, sfreq)
@@ -83,7 +83,7 @@ def estimate_features(self, data: "np.ndarray") -> dict:
return feature_results
- def get_feature(self, fname: FeatureName) -> NMFeature:
+ def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
return self.features[fname]
diff --git a/py_neuromodulation/features/fooof.py b/py_neuromodulation/features/fooof.py
index 683c5304..9f4b9544 100644
--- a/py_neuromodulation/features/fooof.py
+++ b/py_neuromodulation/features/fooof.py
@@ -9,6 +9,7 @@
BoolSelector,
FrequencyRange,
)
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
class FooofSettings(NMBaseModel):
aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
periodic: FooofPeriodicSettings = FooofPeriodicSettings()
- windowlength_ms: float = 800
+ windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
- max_n_peaks: int = 3
- min_peak_height: float = 0
- peak_threshold: float = 2
+ max_n_peaks: int = NMField(3, ge=0)
+ min_peak_height: float = NMField(0, ge=0)
+ peak_threshold: float = NMField(2, ge=0)
freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
knee: bool = True
diff --git a/py_neuromodulation/features/mne_connectivity.py b/py_neuromodulation/features/mne_connectivity.py
index 88883771..f7a186e7 100644
--- a/py_neuromodulation/features/mne_connectivity.py
+++ b/py_neuromodulation/features/mne_connectivity.py
@@ -1,8 +1,9 @@
from collections.abc import Iterable
import numpy as np
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -10,9 +11,30 @@
from mne import Epochs
+MNE_CONNECTIVITY_METHOD = Literal[
+ "coh",
+ "cohy",
+ "imcoh",
+ "cacoh",
+ "mic",
+ "mim",
+ "plv",
+ "ciplv",
+ "ppc",
+ "pli",
+ "dpli",
+ "wpli",
+ "wpli2_debiased",
+ "gc",
+ "gc_tr",
+]
+
+MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
+
+
class MNEConnectivitySettings(NMBaseModel):
- method: str = "plv"
- mode: str = "multitaper"
+ method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
+ mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
class MNEConnectivity(NMFeature):
diff --git a/py_neuromodulation/features/oscillatory.py b/py_neuromodulation/features/oscillatory.py
index ba376cc6..bb4b9e6a 100644
--- a/py_neuromodulation/features/oscillatory.py
+++ b/py_neuromodulation/features/oscillatory.py
@@ -3,6 +3,7 @@
from itertools import product
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -17,7 +18,7 @@ class OscillatoryFeatures(BoolSelector):
class OscillatorySettings(NMBaseModel):
- windowlength_ms: int = 1000
+ windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
log_transform: bool = True
features: OscillatoryFeatures = OscillatoryFeatures(
mean=True, median=False, std=False, max=False
diff --git a/py_neuromodulation/filter/kalman_filter.py b/py_neuromodulation/filter/kalman_filter.py
index b7ae1075..0ebfe195 100644
--- a/py_neuromodulation/filter/kalman_filter.py
+++ b/py_neuromodulation/filter/kalman_filter.py
@@ -1,7 +1,9 @@
import numpy as np
from typing import TYPE_CHECKING
+
from py_neuromodulation.utils.types import NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMErrorList
if TYPE_CHECKING:
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
"HFA",
]
- def validate_fbands(self, settings: "NMSettings") -> None:
- assert all(
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
+ errors: NMErrorList = NMErrorList()
+
+ if not all(
(item in settings.frequency_ranges_hz for item in self.frequency_bands)
- ), (
- "Frequency bands for Kalman filter must also be specified in "
- "bandpass_filter_settings."
- )
+ ):
+ errors.add_error(
+ "Frequency bands for Kalman filter must also be specified in "
+ "frequency_ranges_hz.",
+ location=[
+ "kalman_filter_settings",
+ "frequency_bands",
+ ],
+ )
+
+ return errors
def define_KF(Tp, sigma_w, sigma_v):
diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py
index 06b929f0..01858fe6 100644
--- a/py_neuromodulation/gui/backend/app_backend.py
+++ b/py_neuromodulation/gui/backend/app_backend.py
@@ -13,6 +13,7 @@
)
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
+from pydantic import ValidationError
from . import app_pynm
from .app_socket import WebSocketManager
@@ -74,21 +75,63 @@ async def healthcheck():
####################
##### SETTINGS #####
####################
-
@self.get("/api/settings")
- async def get_settings():
- return self.pynm_state.settings.process_for_frontend()
+ async def get_settings(
+ reset: bool = Query(False, description="Reset settings to default"),
+ ):
+ if reset:
+ settings = NMSettings.get_default()
+ else:
+ settings = self.pynm_state.settings
+
+ return settings.serialize_with_metadata()
@self.post("/api/settings")
- async def update_settings(data: dict):
+ async def update_settings(data: dict, validate_only: bool = Query(False)):
try:
- self.pynm_state.settings = NMSettings.model_validate(data)
- self.logger.info(self.pynm_state.settings.features)
- return self.pynm_state.settings.model_dump()
- except ValueError as e:
+ # First, validate with Pydantic
+ try:
+ # TODO: check if this works properly or needs model_validate_strings
+ validated_settings = NMSettings.model_validate(data)
+ except ValidationError as e:
+ if not validate_only:
+ # If validation failed but we wanted to upload, return error
+ self.logger.error(f"Error validating settings: {e}")
+ raise HTTPException(
+ status_code=422,
+ detail={
+ "error": "Error validating settings",
+ "details": str(e),
+ },
+ )
+ # Else return list of errors
+ return {
+ "valid": False,
+ "errors": [err for err in e.errors()],
+ "details": str(e),
+ }
+
+ # If validation successful, return or update settings
+ if validate_only:
+ return {
+ "valid": True,
+ "settings": validated_settings.serialize_with_metadata(),
+ }
+
+ self.pynm_state.settings = validated_settings
+ self.logger.info("Settings successfully updated")
+
+ return {
+ "valid": True,
+ "settings": self.pynm_state.settings.serialize_with_metadata(),
+ }
+
+ # If something else than validation went wrong, return error
+ except Exception as e:
+ self.logger.error(f"Error validating/updating settings: {e}")
raise HTTPException(
status_code=422,
- detail={"error": "Validation failed", "details": str(e)},
+ detail={"error": "Error uploading settings", "details": str(e)},
)
########################
@@ -105,10 +148,9 @@ async def handle_stream_control(data: dict):
self.logger.info("Starting stream")
self.pynm_state.start_run_function(
- websocket_manager=self.websocket_manager,
+ websocket_manager=self.websocket_manager,
)
-
if action == "stop":
self.logger.info("Stopping stream")
self.pynm_state.stream_handling_queue.put("stop")
diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py
index 523a4f62..d12ef89c 100644
--- a/py_neuromodulation/gui/backend/app_pynm.py
+++ b/py_neuromodulation/gui/backend/app_pynm.py
@@ -12,8 +12,13 @@
from py_neuromodulation.utils.io import read_mne_data
from py_neuromodulation import logger
-async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue.Queue,
- websocket_manager: "WebSocketManager", stop_event: threading.Event):
+
+async def run_stream_controller(
+ feature_queue: queue.Queue,
+ rawdata_queue: queue.Queue,
+ websocket_manager: "WebSocketManager",
+ stop_event: threading.Event,
+):
while not stop_event.wait(0.002):
if not feature_queue.empty() and websocket_manager is not None:
feature_dict = feature_queue.get()
@@ -22,16 +27,23 @@ async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue
if not rawdata_queue.empty() and websocket_manager is not None:
raw_data = rawdata_queue.get()
-
+
await websocket_manager.send_cbor(raw_data)
-def run_stream_controller_sync(feature_queue: queue.Queue,
- rawdata_queue: queue.Queue,
- websocket_manager: "WebSocketManager",
- stop_event: threading.Event
- ):
+
+def run_stream_controller_sync(
+ feature_queue: queue.Queue,
+ rawdata_queue: queue.Queue,
+ websocket_manager: "WebSocketManager",
+ stop_event: threading.Event,
+):
# The run_stream_controller needs to be started as an asyncio function due to the async websocket
- asyncio.run(run_stream_controller(feature_queue, rawdata_queue, websocket_manager, stop_event))
+ asyncio.run(
+ run_stream_controller(
+ feature_queue, rawdata_queue, websocket_manager, stop_event
+ )
+ )
+
class PyNMState:
def __init__(
@@ -52,12 +64,10 @@ def __init__(
self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
self.settings: NMSettings = NMSettings(sampling_rate_features=10)
-
def start_run_function(
self,
websocket_manager=None,
) -> None:
-
self.stream.settings = self.settings
self.stream_handling_queue = queue.Queue()
@@ -70,7 +80,7 @@ def start_run_function(
if os.path.exists(self.decoding_model_path):
self.decoder = RealTimeDecoder(self.decoding_model_path)
else:
- logger.debug("Passed decoding model path does't exist")
+ logger.debug("Passed decoding model path doesn't exist")
# Stop event
# .set() is called from app_backend
self.stop_event_ws = threading.Event()
@@ -78,16 +88,19 @@ def start_run_function(
self.stream_controller_thread = Thread(
target=run_stream_controller_sync,
daemon=True,
- args=(self.feature_queue,
- self.rawdata_queue,
- websocket_manager,
- self.stop_event_ws
- ),
+ args=(
+ self.feature_queue,
+ self.rawdata_queue,
+ websocket_manager,
+ self.stop_event_ws,
+ ),
)
is_stream_lsl = self.lsl_stream_name is not None
- stream_lsl_name = self.lsl_stream_name if self.lsl_stream_name is not None else ""
-
+ stream_lsl_name = (
+ self.lsl_stream_name if self.lsl_stream_name is not None else ""
+ )
+
# The run_func_thread is terminated through the stream_handling_queue
# which initiates to break the data generator and save the features
out_dir = "" if self.out_dir == "default" else self.out_dir
@@ -95,15 +108,15 @@ def start_run_function(
target=self.stream.run,
daemon=True,
kwargs={
- "out_dir" : out_dir,
- "experiment_name" : self.experiment_name,
- "stream_handling_queue" : self.stream_handling_queue,
- "is_stream_lsl" : is_stream_lsl,
- "stream_lsl_name" : stream_lsl_name,
- "feature_queue" : self.feature_queue,
- "simulate_real_time" : True,
- "rawdata_queue" : self.rawdata_queue,
- "decoder" : self.decoder,
+ "out_dir": out_dir,
+ "experiment_name": self.experiment_name,
+ "stream_handling_queue": self.stream_handling_queue,
+ "is_stream_lsl": is_stream_lsl,
+ "stream_lsl_name": stream_lsl_name,
+ "feature_queue": self.feature_queue,
+ "simulate_real_time": True,
+ "rawdata_queue": self.rawdata_queue,
+ "decoder": self.decoder,
},
)
@@ -157,10 +170,10 @@ def setup_lsl_stream(
settings=self.settings,
)
self.logger.info("stream setup")
- #self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq)
+ # self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq)
self.logger.info("settings setup")
break
-
+
if channels.shape[0] == 0:
self.logger.error(f"Stream {lsl_stream_name} not found")
raise ValueError(f"Stream {lsl_stream_name} not found")
diff --git a/py_neuromodulation/processing/data_preprocessor.py b/py_neuromodulation/processing/data_preprocessor.py
index 319926d9..3b9a9088 100644
--- a/py_neuromodulation/processing/data_preprocessor.py
+++ b/py_neuromodulation/processing/data_preprocessor.py
@@ -1,12 +1,12 @@
from typing import TYPE_CHECKING, Type
-from py_neuromodulation.utils.types import PreprocessorName, NMPreprocessor
+from py_neuromodulation.utils.types import PREPROCESSOR_NAME, NMPreprocessor
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from py_neuromodulation.stream.settings import NMSettings
-PREPROCESSOR_DICT: dict[PreprocessorName, str] = {
+PREPROCESSOR_DICT: dict[PREPROCESSOR_NAME, str] = {
"preprocessing_filter": "PreprocessingFilter",
"notch_filter": "NotchFilter",
"raw_resampling": "Resampler",
diff --git a/py_neuromodulation/processing/filter_preprocessing.py b/py_neuromodulation/processing/filter_preprocessing.py
index e6515ae0..2c5addbc 100644
--- a/py_neuromodulation/processing/filter_preprocessing.py
+++ b/py_neuromodulation/processing/filter_preprocessing.py
@@ -1,22 +1,14 @@
import numpy as np
-from pydantic import Field
from typing import TYPE_CHECKING
from py_neuromodulation.utils.types import BoolSelector, FrequencyRange, NMPreprocessor
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
-FILTER_SETTINGS_MAP = {
- "bandstop_filter": "bandstop_filter_settings",
- "bandpass_filter": "bandpass_filter_settings",
- "lowpass_filter": "lowpass_filter_cutoff_hz",
- "highpass_filter": "highpass_filter_cutoff_hz",
-}
-
-
class FilterSettings(BoolSelector):
bandstop_filter: bool = True
bandpass_filter: bool = True
@@ -25,21 +17,23 @@ class FilterSettings(BoolSelector):
bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160)
bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200)
- lowpass_filter_cutoff_hz: float = Field(default=200)
- highpass_filter_cutoff_hz: float = Field(default=3)
-
- def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]:
- filter_value = self[FILTER_SETTINGS_MAP[filter_name]]
-
+ lowpass_filter_cutoff_hz: float = NMField(
+ default=200, gt=0, custom_metadata={"unit": "Hz"}
+ )
+ highpass_filter_cutoff_hz: float = NMField(
+ default=3, gt=0, custom_metadata={"unit": "Hz"}
+ )
+
+ def get_filter_tuple(self, filter_name) -> FrequencyRange:
match filter_name:
case "bandstop_filter":
- return (filter_value.frequency_high_hz, filter_value.frequency_low_hz)
+ return self.bandstop_filter_settings
case "bandpass_filter":
- return (filter_value.frequency_low_hz, filter_value.frequency_high_hz)
+ return self.bandpass_filter_settings
case "lowpass_filter":
- return (None, filter_value)
+ return FrequencyRange(None, self.lowpass_filter_cutoff_hz)
case "highpass_filter":
- return (filter_value, None)
+ return FrequencyRange(self.highpass_filter_cutoff_hz, None)
case _:
raise ValueError(
"Filter name must be one of 'bandstop_filter', 'lowpass_filter', "
diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py
index e3f25d55..713e2308 100644
--- a/py_neuromodulation/processing/normalization.py
+++ b/py_neuromodulation/processing/normalization.py
@@ -3,10 +3,10 @@
import numpy as np
from typing import TYPE_CHECKING, Callable, Literal, get_args
+from py_neuromodulation.utils.pydantic_extensions import NMField
from py_neuromodulation.utils.types import (
NMBaseModel,
- Field,
- NormMethod,
+ NORM_METHOD,
NMPreprocessor,
)
@@ -17,13 +17,13 @@
class NormalizationSettings(NMBaseModel):
- normalization_time_s: float = 30
- normalization_method: NormMethod = "zscore"
- clip: float = Field(default=3, ge=0)
+ normalization_time_s: float = NMField(30, gt=0, custom_metadata={"unit": "s"})
+ normalization_method: NORM_METHOD = NMField(default="zscore")
+ clip: float = NMField(default=3, ge=0, custom_metadata={"unit": "a.u."})
@staticmethod
- def list_normalization_methods() -> list[NormMethod]:
- return list(get_args(NormMethod))
+ def list_normalization_methods() -> list[NORM_METHOD]:
+ return list(get_args(NORM_METHOD))
class FeatureNormalizationSettings(NormalizationSettings): normalize_psd: bool = False
@@ -60,7 +60,7 @@ def __init__(
if self.using_sklearn:
import sklearn.preprocessing as skpp
- NORM_METHODS_SKLEARN: dict[NormMethod, Callable] = {
+ NORM_METHODS_SKLEARN: dict[NORM_METHOD, Callable] = {
"quantile": lambda: skpp.QuantileTransformer(n_quantiles=300),
"robust": skpp.RobustScaler,
"minmax": skpp.MinMaxScaler,
diff --git a/py_neuromodulation/processing/projection.py b/py_neuromodulation/processing/projection.py
index 4cdc3610..3a83a29f 100644
--- a/py_neuromodulation/processing/projection.py
+++ b/py_neuromodulation/processing/projection.py
@@ -1,6 +1,6 @@
import numpy as np
-from pydantic import Field
from py_neuromodulation.utils.types import NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -9,7 +9,7 @@
class ProjectionSettings(NMBaseModel):
- max_dist_mm: float = Field(default=20.0, gt=0.0)
+ max_dist_mm: float = NMField(default=20.0, gt=0.0, custom_metadata={"unit": "mm"})
class Projection:
diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py
index 08a4e115..1b7107eb 100644
--- a/py_neuromodulation/processing/resample.py
+++ b/py_neuromodulation/processing/resample.py
@@ -1,11 +1,14 @@
"""Module for resampling."""
import numpy as np
-from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor
+from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor
+from py_neuromodulation.utils.pydantic_extensions import NMField
class ResamplerSettings(NMBaseModel):
- resample_freq_hz: float = Field(default=1000, gt=0)
+ resample_freq_hz: float = NMField(
+ default=1000, gt=0, custom_metadata={"unit": "Hz"}
+ )
class Resampler(NMPreprocessor):
diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py
index 669c97db..bb3ca7cd 100644
--- a/py_neuromodulation/stream/settings.py
+++ b/py_neuromodulation/stream/settings.py
@@ -1,19 +1,22 @@
"""Module for handling settings."""
from pathlib import Path
-from typing import ClassVar
-from pydantic import Field, model_validator
+from typing import Any, ClassVar
+from pydantic import model_validator, ValidationError
+from pydantic.functional_validators import ModelWrapValidatorHandler
-from py_neuromodulation import PYNM_DIR, logger, user_features
+from py_neuromodulation import logger, user_features
+from types import SimpleNamespace
from py_neuromodulation.utils.types import (
BoolSelector,
FrequencyRange,
- PreprocessorName,
_PathLike,
NMBaseModel,
- NormMethod,
+ NORM_METHOD,
+ PreprocessorList,
)
+from py_neuromodulation.utils.pydantic_extensions import NMErrorList, NMField
from py_neuromodulation.processing.filter_preprocessing import FilterSettings
from py_neuromodulation.processing.normalization import FeatureNormalizationSettings, NormalizationSettings
@@ -31,7 +34,9 @@
from py_neuromodulation.features import BurstsSettings
-class FeatureSelection(BoolSelector):
+# TONI: this class has the problem that if a feature is absent,
+# it won't default to False but to whatever is defined here as default
+class FeatureSelector(BoolSelector):
raw_hjorth: bool = True
return_raw: bool = True
bandpass_filter: bool = False
@@ -59,8 +64,12 @@ class NMSettings(NMBaseModel):
_instances: ClassVar[list["NMSettings"]] = []
# General settings
- sampling_rate_features_hz: float = Field(default=10, gt=0)
- segment_length_features_ms: float = Field(default=1000, gt=0)
+ sampling_rate_features_hz: float = NMField(
+ default=10, gt=0, custom_metadata={"unit": "Hz"}
+ )
+ segment_length_features_ms: float = NMField(
+ default=1000, gt=0, custom_metadata={"unit": "ms"}
+ )
frequency_ranges_hz: dict[str, FrequencyRange] = {
"theta": FrequencyRange(4, 8),
"alpha": FrequencyRange(8, 12),
@@ -72,11 +81,14 @@ class NMSettings(NMBaseModel):
}
# Preproceessing settings
- preprocessing: list[PreprocessorName] = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ preprocessing: PreprocessorList = PreprocessorList(
+ [
+ "raw_resampling",
+ "notch_filter",
+ "re_referencing",
+ ]
+ )
+
raw_resampling_settings: ResamplerSettings = ResamplerSettings()
preprocessing_filter: FilterSettings = FilterSettings()
raw_normalization_settings: NormalizationSettings = NormalizationSettings()
@@ -88,7 +100,7 @@ class NMSettings(NMBaseModel):
project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5)
# Feature settings
- features: FeatureSelection = FeatureSelection()
+ features: FeatureSelector = FeatureSelector()
fft_settings: OscillatorySettings = OscillatorySettings()
welch_settings: OscillatorySettings = OscillatorySettings()
@@ -126,10 +138,23 @@ def _remove_feature(cls, feature: str) -> None:
for instance in cls._instances:
delattr(instance.features, feature)
- @model_validator(mode="after")
- def validate_settings(self):
+ @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride]
+ def validate_settings(self, handler: ModelWrapValidatorHandler) -> Any:
+ # Perform all necessary custom validations in the settings class and also
+ # all validations in the feature classes that need additional information from
+ # the settings class
+ errors: NMErrorList = NMErrorList()
+
+ try:
+ # validate the model
+ self = handler(self)
+ except ValidationError as e:
+ self = NMSettings.unvalidated(**self)
+ NMSettings.model_fields_set
+ errors.extend(NMErrorList(e.errors()))
+
if len(self.features.get_enabled()) == 0:
- raise ValueError("At least one feature must be selected.")
+ errors.add_error("At least one feature must be selected.")
# Replace spaces with underscores in frequency band names
self.frequency_ranges_hz = {
@@ -138,32 +163,33 @@ def validate_settings(self):
if self.features.bandpass_filter:
# Check BandPass settings frequency bands
- self.bandpass_filter_settings.validate_fbands(self)
+ errors.extend(self.bandpass_filter_settings.validate_fbands(self))
# Check Kalman filter frequency bands
if self.bandpass_filter_settings.kalman_filter:
- self.kalman_filter_settings.validate_fbands(self)
+ errors.extend(self.kalman_filter_settings.validate_fbands(self))
- for k, v in self.frequency_ranges_hz.items():
- if not isinstance(v, FrequencyRange):
- self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
+ if len(errors) > 0:
+ raise errors.create_error()
return self
def reset(self) -> "NMSettings":
self.features.disable_all()
- self.preprocessing = []
+ self.preprocessing = PreprocessorList()
self.postprocessing.disable_all()
return self
def set_fast_compute(self) -> "NMSettings":
self.reset()
self.features.fft = True
- self.preprocessing = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ self.preprocessing = PreprocessorList(
+ [
+ "raw_resampling",
+ "notch_filter",
+ "re_referencing",
+ ]
+ )
self.postprocessing.feature_normalization = True
self.postprocessing.project_cortex = False
self.postprocessing.project_subcortex = False
@@ -250,10 +276,10 @@ def from_file(PATH: _PathLike) -> "NMSettings":
@staticmethod
def get_default() -> "NMSettings":
- return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")
+ return NMSettings()
@staticmethod
- def list_normalization_methods() -> list[NormMethod]:
+ def list_normalization_methods() -> list[NORM_METHOD]:
return NormalizationSettings.list_normalization_methods()
def save(
diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py
index 2ee03421..df8338aa 100644
--- a/py_neuromodulation/stream/stream.py
+++ b/py_neuromodulation/stream/stream.py
@@ -85,7 +85,7 @@ def __init__(
raise ValueError("Either `channels` or `data` must be passed to `Stream`.")
# If features that use frequency ranges are on, test them against nyquist frequency
- use_freq_ranges: list[FeatureName] = [
+ use_freq_ranges: list[FEATURE_NAME] = [
"bandpass_filter",
"stft",
"fft",
@@ -209,7 +209,7 @@ def run(
return_df: bool = True,
simulate_real_time: bool = False,
decoder: RealTimeDecoder = None,
- feature_queue: "queue.Queue | None" = None,
+ feature_queue: "queue.Queue | None" = None,
rawdata_queue: "queue.Queue | None" = None,
stream_handling_queue: "queue.Queue | None" = None,
):
@@ -303,7 +303,9 @@ def run(
if decoder is not None:
ch_to_decode = self.channels.query("used == 1").iloc[0]["name"]
- feature_dict = decoder.predict(feature_dict, ch_to_decode, fft_bands_only=True)
+ feature_dict = decoder.predict(
+ feature_dict, ch_to_decode, fft_bands_only=True
+ )
feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1)
@@ -346,8 +348,9 @@ def run(
self._save_after_stream()
self.is_running = False
- return feature_df # Timon: We could think of returning the feature_reader instead
-
+ return (
+ feature_df # Timon: We could think of returning the feature_reader instead
+ )
def plot_raw_signal(
self,
@@ -437,7 +440,7 @@ def _save_sidecar(self) -> None:
"""Save sidecar incduing fs, coords, sess_right to
out_path_root and subfolder 'folder_name'"""
additional_args = {"sess_right": self.sess_right}
-
+
self.data_processor.save_sidecar(
self.out_dir, self.experiment_name, additional_args
)
diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py
new file mode 100644
index 00000000..ee89ccea
--- /dev/null
+++ b/py_neuromodulation/utils/pydantic_extensions.py
@@ -0,0 +1,420 @@
+import copy
+from typing import (
+ Any,
+ get_origin,
+ get_args,
+ get_type_hints,
+ TypeVar,
+ Generic,
+ Literal,
+ cast,
+ Sequence,
+)
+from typing_extensions import Unpack, TypedDict
+from pydantic import BaseModel, ConfigDict, model_validator, model_serializer
+
+from pydantic_core import (
+ ErrorDetails,
+ PydanticUndefined,
+ InitErrorDetails,
+ ValidationError,
+)
+from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs
+from pprint import pformat
+
+
+def create_validation_error(
+ error_message: str,
+ location: list[str | int] = [],
+ title: str = "Validation Error",
+ error_type="value_error",
+) -> ValidationError:
+ """
+ Factory function to create a Pydantic v2 ValidationError.
+
+ Args:
+ error_message (str): The error message for the ValidationError.
+        location (list[str | int], optional): The location of the error. Defaults to [].
+ title (str, optional): The title of the error. Defaults to "Validation Error".
+
+ Returns:
+ ValidationError: A Pydantic ValidationError instance.
+ """
+
+ return ValidationError.from_exception_data(
+ title=title,
+ line_errors=[
+ InitErrorDetails(
+ type=error_type,
+ loc=tuple(location),
+ input=None,
+ ctx={"error": error_message},
+ )
+ ],
+ input_type="python",
+ hide_input=False,
+ )
+
+
+class NMErrorList:
+ """Class to handle data about Pydantic errors.
+ Stores data in a list of InitErrorDetails. Errors can be accessed but not modified.
+
+ :return: _description_
+ :rtype: _type_
+ """
+
+ def __init__(
+ self, errors: Sequence[InitErrorDetails | ErrorDetails] | None = None
+ ) -> None:
+ self.__errors: list[InitErrorDetails | ErrorDetails] = [e for e in errors or []]
+
+ def add_error(
+ self,
+ error_message: str,
+ location: list[str | int] = [],
+ error_type="value_error",
+ ) -> None:
+ self.__errors.append(
+ InitErrorDetails(
+ type=error_type,
+ loc=tuple(location),
+ input=None,
+ ctx={"error": error_message},
+ )
+ )
+
+ def create_error(self, title: str = "Validation Error") -> ValidationError:
+ return ValidationError.from_exception_data(
+ title=title, line_errors=cast(list[InitErrorDetails], self.__errors)
+ )
+
+ def extend(self, errors: "NMErrorList"):
+ self.__errors.extend(errors.__errors)
+
+ def __iter__(self):
+ return iter(self.__errors)
+
+ def __len__(self):
+ return len(self.__errors)
+
+ def __getitem__(self, idx):
+ # Return a copy of the error to prevent modification
+ return copy.deepcopy(self.__errors[idx])
+
+ def __repr__(self):
+ return repr(self.__errors)
+
+ def __str__(self):
+ return str(self.__errors)
+
+
+class _NMExtraFieldInputs(TypedDict, total=False):
+ """Additional fields to add on top of the pydantic FieldInfo"""
+
+ custom_metadata: dict[str, Any]
+
+
+class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo inputs with PyNM additional inputs"""
+
+ pass
+
+
+class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs"""
+
+ pass
+
+
+class NMFieldInfo(FieldInfo):
+ # Add default values for any other custom fields here
+ _default_values = {}
+
+ def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None:
+ self.custom_metadata: dict[str, Any] = kwargs.pop("custom_metadata", {})
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def from_field(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+ ) -> "NMFieldInfo":
+ if "annotation" in kwargs:
+ raise TypeError('"annotation" is not permitted as a Field keyword argument')
+ return NMFieldInfo(default=default, **kwargs)
+
+ def __repr_args__(self):
+ yield from super().__repr_args__()
+ extra_fields = get_type_hints(_NMExtraFieldInputs)
+ for field in extra_fields:
+ value = getattr(self, field)
+ yield field, value
+
+
+def NMField(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+) -> Any:
+ return NMFieldInfo.from_field(default=default, **kwargs)
+
+
+class NMBaseModel(BaseModel):
+ # model_config = ConfigDict(validate_assignment=False, extra="allow")
+
+ def __init__(self, *args, **kwargs) -> None:
+ """Pydantic does not support positional arguments by default.
+ This is a workaround to support positional arguments for models like FrequencyRange.
+ It converts positional arguments to kwargs and then calls the base class __init__.
+ """
+
+ if not args:
+ # Simple case - just use kwargs
+ super().__init__(*args, **kwargs)
+ return
+
+ field_names = list(self.model_fields.keys())
+ # If we have more positional args than fields, that's an error
+ if len(args) > len(field_names):
+ raise ValueError(
+ f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}"
+ )
+
+        # Convert positional args to kwargs, checking for duplicates
+ complete_kwargs = {}
+ for i, arg in enumerate(args):
+ field_name = field_names[i]
+ if field_name in kwargs:
+ raise ValueError(
+ f"Got multiple values for field '{field_name}': "
+ f"positional argument and keyword argument"
+ )
+ complete_kwargs[field_name] = arg
+
+ # Add remaining kwargs
+ complete_kwargs.update(kwargs)
+ super().__init__(**complete_kwargs)
+
+ __init__.__pydantic_base_init__ = True # type: ignore
+
+ def __str__(self):
+ return pformat(self.model_dump())
+
+ # def __repr__(self):
+ # return pformat(self.model_dump())
+
+ def validate(self, context: Any | None = None) -> Any: # type: ignore
+ return self.model_validate(self.model_dump(), context=context)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value) -> None:
+ setattr(self, key, value)
+
+ @property
+ def fields(self) -> dict[str, FieldInfo | NMFieldInfo]:
+ return self.model_fields # type: ignore
+
+ def serialize_with_metadata(self):
+ result: dict[str, Any] = {"__field_type__": self.__class__.__name__}
+
+ for field_name, field_info in self.fields.items():
+ value = getattr(self, field_name)
+ if isinstance(value, NMBaseModel):
+ result[field_name] = value.serialize_with_metadata()
+ elif isinstance(value, list):
+ result[field_name] = [
+ item.serialize_with_metadata()
+ if isinstance(item, NMBaseModel)
+ else item
+ for item in value
+ ]
+ elif isinstance(value, dict):
+ result[field_name] = {
+ k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v
+ for k, v in value.items()
+ }
+ else:
+ result[field_name] = value
+
+ # Extract unit information from Annotated type
+ if isinstance(field_info, NMFieldInfo):
+ # Convert scalar value to dict with metadata
+ field_dict = {
+ "value": value,
+                    # __field_type__ will be overwritten if set in custom_metadata
+ "__field_type__": type(value).__name__,
+ **{
+ f"__{tag}__": value
+ for tag, value in field_info.custom_metadata.items()
+ },
+ }
+ # Add possible values for Literal types
+ if get_origin(field_info.annotation) is Literal:
+ field_dict["__valid_values__"] = list(
+ get_args(field_info.annotation)
+ )
+
+ result[field_name] = field_dict
+ return result
+
+ @classmethod
+ def unvalidated(cls, **data: Any) -> Any:
+ def process_value(value: Any, field_type: Any) -> Any:
+ if isinstance(value, dict) and hasattr(
+ field_type, "__pydantic_core_schema__"
+ ):
+ # Recursively handle nested Pydantic models
+ return field_type.unvalidated(**value)
+ elif isinstance(value, list):
+ # Handle lists of Pydantic models
+ if hasattr(field_type, "__args__") and hasattr(
+ field_type.__args__[0], "__pydantic_core_schema__"
+ ):
+ return [
+ field_type.__args__[0].unvalidated(**item)
+ if isinstance(item, dict)
+ else item
+ for item in value
+ ]
+ return value
+
+ processed_data = {}
+ for name, field in cls.model_fields.items():
+ try:
+ value = data[name]
+ processed_data[name] = process_value(value, field.annotation)
+ except KeyError:
+ if not field.is_required():
+ processed_data[name] = copy.deepcopy(field.default)
+ else:
+ raise TypeError(f"Missing required keyword argument {name!r}")
+
+ self = cls.__new__(cls)
+ object.__setattr__(self, "__dict__", processed_data)
+ object.__setattr__(self, "__pydantic_private__", {"extra": None})
+ object.__setattr__(self, "__pydantic_fields_set__", set(processed_data.keys()))
+ return self
+
+
+#################################
+#### Generic Pydantic models ####
+#################################
+
+
+T = TypeVar("T")
+C = TypeVar("C", list, tuple)
+
+
+class NMSequenceModel(NMBaseModel, Generic[C]):
+ """Base class for sequence models with a single root value"""
+
+ root: C = NMField(default_factory=list)
+
+ # Class variable for aliases - override in subclasses
+ __aliases__: dict[int, list[str]] = {}
+
+ def __init__(self, *args, **kwargs) -> None:
+ if len(args) == 1 and isinstance(args[0], (list, tuple)):
+ kwargs["root"] = args[0]
+ elif len(args) == 1:
+ kwargs["root"] = [args[0]]
+ elif len(args) > 1: # Add this case
+ kwargs["root"] = tuple(args)
+ super().__init__(**kwargs)
+
+ def __iter__(self): # type: ignore[reportIncompatibleMethodOverride]
+ return iter(self.root)
+
+ def __getitem__(self, idx):
+ return self.root[idx]
+
+ def __len__(self):
+ return len(self.root)
+
+ def model_dump(self): # type: ignore[reportIncompatibleMethodOverride]
+ return self.root
+
+ def model_dump_json(self, **kwargs):
+ import json
+
+ return json.dumps(self.root, **kwargs)
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = {"__field_type__": self.__class__.__name__, "value": self.root}
+
+ # Add any field metadata from the root field
+ field_info = self.model_fields.get("root")
+ if isinstance(field_info, NMFieldInfo):
+ for tag, value in field_info.custom_metadata.items():
+ result[f"__{tag}__"] = value
+
+ return result
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_input(cls, value: Any) -> dict[str, Any]:
+ # If it's a dict, just return it
+ if isinstance(value, dict):
+ if "root" in value:
+ return value
+
+ # Check for aliased fields if class has aliases defined
+ if hasattr(cls, "__aliases__"):
+ # Collect all possible alias names for each position
+ alias_values = []
+ max_index = max(cls.__aliases__.keys()) if cls.__aliases__ else -1
+
+ # Try to find a value for each position using its aliases
+ for i in range(max_index + 1):
+ aliases = cls.__aliases__.get(i, [])
+ value_found = None
+
+ # Try each alias for this position
+ for alias in aliases:
+ if alias in value:
+ value_found = value[alias]
+ break
+
+ if value_found is not None:
+ alias_values.append(value_found)
+ else:
+ # If we're missing any position's value, don't use aliases
+ break
+
+ # If we found all values through aliases, use them
+ if len(alias_values) == max_index + 1:
+ return {"root": alias_values}
+
+ # if it's a sequence, return the value as the root
+ if isinstance(value, (list, tuple)):
+ return {"root": value}
+
+ # Else, make it a list
+ return {"root": [value]}
+
+ @model_serializer
+ def ser_model(self):
+ return self.root
+
+ # Custom validator to skip the 'root' field in validation errors
+ @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride]
+ def rewrite_error_locations(self, handler):
+ try:
+ return handler(self)
+ except ValidationError as e:
+ errors = []
+ for err in e.errors():
+ loc = list(err["loc"])
+ # Find and remove 'root' from the location path
+ if "root" in loc:
+ root_idx = loc.index("root")
+ if root_idx < len(loc) - 1:
+ loc = loc[:root_idx] + loc[root_idx + 1 :]
+ err["loc"] = tuple(loc)
+ errors.append(err)
+ print(errors)
+ raise ValidationError.from_exception_data(
+ title="ValidationError", line_errors=errors
+ )
diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py
index 7886685d..13b2923e 100644
--- a/py_neuromodulation/utils/types.py
+++ b/py_neuromodulation/utils/types.py
@@ -1,12 +1,14 @@
from os import PathLike
from math import isnan
-from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable
-from pydantic import ConfigDict, Field, model_validator, BaseModel
-from pydantic_core import ValidationError, InitErrorDetails
-from pprint import pformat
+from typing import Literal, TYPE_CHECKING, Any, TypeVar
+from pydantic import BaseModel, ConfigDict, model_validator
+from .pydantic_extensions import NMBaseModel, NMSequenceModel, NMField
+from abc import abstractmethod
+
from collections.abc import Sequence
from datetime import datetime
+
if TYPE_CHECKING:
import numpy as np
from py_neuromodulation import NMSettings
@@ -17,8 +19,7 @@
_PathLike = str | PathLike
-
-FeatureName = Literal[
+FEATURE_NAME = Literal[
"raw_hjorth",
"return_raw",
"bandpass_filter",
@@ -35,7 +36,7 @@
"bispectrum",
]
-PreprocessorName = Literal[
+PREPROCESSOR_NAME = Literal[
"preprocessing_filter",
"notch_filter",
"raw_resampling",
@@ -43,7 +44,7 @@
"raw_normalization",
]
-NormMethod = Literal[
+NORM_METHOD = Literal[
"mean",
"median",
"zscore",
@@ -54,13 +55,8 @@
"minmax",
]
-###################################
-######## PROTOCOL CLASSES ########
-###################################
-
-@runtime_checkable
-class NMFeature(Protocol):
+class NMFeature:
def __init__(
self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float
) -> None: ...
@@ -81,131 +77,47 @@ def calc_feature(self, data: "np.ndarray") -> dict:
...
-class NMPreprocessor(Protocol):
- def __init__(self, sfreq: float, settings: "NMSettings") -> None: ...
-
+class NMPreprocessor:
def process(self, data: "np.ndarray") -> "np.ndarray": ...
-###################################
-######## PYDANTIC CLASSES ########
-###################################
-
-
-class NMBaseModel(BaseModel):
- model_config = ConfigDict(validate_assignment=False, extra="allow")
+class PreprocessorList(NMSequenceModel[list[PREPROCESSOR_NAME]]):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+    # Useless constructor to prevent linter from complaining
def __init__(self, *args, **kwargs) -> None:
- if kwargs:
- super().__init__(**kwargs)
- else:
- field_names = list(self.model_fields.keys())
- kwargs = {}
- for i in range(len(args)):
- kwargs[field_names[i]] = args[i]
- super().__init__(**kwargs)
-
- def __str__(self):
- return pformat(self.model_dump())
-
- def __repr__(self):
- return pformat(self.model_dump())
-
- def validate(self) -> Any: # type: ignore
- return self.model_validate(self.model_dump())
-
- def __getitem__(self, key):
- return getattr(self, key)
-
- def __setitem__(self, key, value) -> None:
- setattr(self, key, value)
-
- def process_for_frontend(self) -> dict[str, Any]:
- """
- Process the model for frontend use, adding __field_type__ information.
- """
- result = {}
- for field_name, field_value in self.__dict__.items():
- if isinstance(field_value, NMBaseModel):
- processed_value = field_value.process_for_frontend()
- processed_value["__field_type__"] = field_value.__class__.__name__
- result[field_name] = processed_value
- elif isinstance(field_value, list):
- result[field_name] = [
- item.process_for_frontend()
- if isinstance(item, NMBaseModel)
- else item
- for item in field_value
- ]
- elif isinstance(field_value, dict):
- result[field_name] = {
- k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v
- for k, v in field_value.items()
- }
- else:
- result[field_name] = field_value
+ super().__init__(*args, **kwargs)
- return result
+class FrequencyRange(NMSequenceModel[tuple[float, float]]):
+ """Frequency range as (low, high) tuple"""
-class FrequencyRange(NMBaseModel):
- frequency_low_hz: float = Field(gt=0)
- frequency_high_hz: float = Field(gt=0)
+ __aliases__ = {
+ 0: ["frequency_low_hz", "low_frequency_hz"],
+ 1: ["frequency_high_hz", "high_frequency_hz"],
+ }
+    # Useless constructor to prevent linter from complaining
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- def __getitem__(self, item: int):
- match item:
- case 0:
- return self.frequency_low_hz
- case 1:
- return self.frequency_high_hz
- case _:
- raise IndexError(f"Index {item} out of range")
-
- def as_tuple(self) -> tuple[float, float]:
- return (self.frequency_low_hz, self.frequency_high_hz)
-
- def __iter__(self): # type: ignore
- return iter(self.as_tuple())
-
@model_validator(mode="after")
def validate_range(self):
- if not (isnan(self.frequency_high_hz) or isnan(self.frequency_low_hz)):
- assert (
- self.frequency_high_hz > self.frequency_low_hz
- ), "Frequency high must be greater than frequency low"
+ low, high = self.root
+ if not (isnan(low) or isnan(high)):
+ assert high > low, "High frequency must be greater than low frequency"
return self
- @classmethod
- def create_from(cls, input) -> "FrequencyRange":
- match input:
- case FrequencyRange():
- return input
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return FrequencyRange(
- input["frequency_low_hz"], input["frequency_high_hz"]
- )
- case Sequence() if len(input) == 2:
- return FrequencyRange(input[0], input[1])
- case _:
- raise ValueError("Invalid input for FrequencyRange creation.")
-
- @model_validator(mode="before")
- @classmethod
- def check_input(cls, input):
- match input:
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return input
- case Sequence() if len(input) == 2:
- return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]}
- case _:
- raise ValueError(
- "Value for FrequencyRange must be a dictionary, "
- "or a sequence of 2 numeric values, "
- f"but got {input} instead."
- )
+ # Alias properties
+ @property
+ def frequency_low_hz(self) -> float:
+ """Lower frequency bound in Hz"""
+ return self.root[0]
+
+ @property
+ def frequency_high_hz(self) -> float:
+ """Upper frequency bound in Hz"""
+ return self.root[1]
class BoolSelector(NMBaseModel):
@@ -238,46 +150,135 @@ def print_all(cls):
for f in cls.list_all():
print(f)
- @classmethod
- def get_fields(cls):
- return cls.model_fields
+
+################################################
+### Generic Pydantic models for the frontend ###
+################################################
+
+
+class UniqueStringSequence(NMSequenceModel[list[str]]):
+ """
+ A sequence of strings where:
+ - Values must come from a predefined set
+ - Each value can only appear once
+ - Order is preserved
+ """
+
+ @property
+ @abstractmethod
+ def valid_values(self) -> list[str]:
+ """Each subclass must implement this to provide its valid values"""
+ raise NotImplementedError
+
+ def __init__(self, **data):
+ valid_values = data.pop("valid_values", [])
+ super().__init__(**data)
+ object.__setattr__(self, "valid_values", valid_values)
+
+ @model_validator(mode="after")
+ def validate_sequence(self):
+ seen = set()
+ validated = []
+ for item in self.root:
+ if item not in seen and item in self.valid_values:
+ seen.add(item)
+ validated.append(item)
+ self.root = validated
+ return self
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = super().serialize_with_metadata()
+ result["__valid_values__"] = self.valid_values
+ return result
-def create_validation_error(
- error_message: str,
- loc: list[str | int] = None,
- title: str = "Validation Error",
- input_type: Literal["python", "json"] = "python",
- hide_input: bool = False,
-) -> ValidationError:
+class DependentKeysList(NMSequenceModel[list[str]]):
+ """
+ A list of strings where valid values are keys from another settings field
"""
- Factory function to create a Pydantic v2 ValidationError instance from a single error message.
- Args:
- error_message (str): The error message for the ValidationError.
- loc (List[str | int], optional): The location of the error. Defaults to None.
- title (str, optional): The title of the error. Defaults to "Validation Error".
- input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python".
- hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False.
+ root: list[str] = NMField(default_factory=list)
+ source_dict: dict[str, Any] = NMField(default_factory=dict, exclude=True)
+
+ def __init__(self, **data):
+ source_dict = data.pop("source_dict", {})
+ super().__init__(**data)
+ object.__setattr__(self, "source_dict", source_dict)
+
+ @model_validator(mode="after")
+ def validate_keys(self):
+ valid_keys = set(self.source_dict.keys())
+ seen = set()
+ validated = []
+ for item in self.root:
+ if item not in seen and item in valid_keys:
+ seen.add(item)
+ validated.append(item)
+ self.root = validated
+ return self
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = super().serialize_with_metadata()
+ result["__valid_values__"] = list(self.source_dict.keys())
+ result["__dependent__"] = True # Indicates this needs dynamic updating
+ return result
+
- Returns:
- ValidationError: A Pydantic ValidationError instance.
+class StringPairsList(NMSequenceModel[list[tuple[str, str]]]):
+ """
+ A list of string pairs where values must come from predetermined lists
"""
- if loc is None:
- loc = []
-
- line_errors = [
- InitErrorDetails(
- type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message}
- )
- ]
-
- return ValidationError.from_exception_data(
- title=title,
- line_errors=line_errors,
- input_type=input_type,
- hide_input=hide_input,
- )
+
+ root: list[tuple[str, str]] = NMField(default_factory=list)
+ valid_first: list[str] = NMField(default_factory=list, exclude=True)
+ valid_second: list[str] = NMField(default_factory=list, exclude=True)
+
+ def __init__(self, **data):
+ valid_first = data.pop("valid_first", [])
+ valid_second = data.pop("valid_second", [])
+ super().__init__(**data)
+ object.__setattr__(self, "valid_first", valid_first)
+ object.__setattr__(self, "valid_second", valid_second)
+
+ @model_validator(mode="after")
+ def validate_pairs(self):
+ validated = [
+ (first, second)
+ for first, second in self.root
+ if first in self.valid_first and second in self.valid_second
+ ]
+ self.root = validated
+ return self
+
+ def serialize_with_metadata(self) -> dict[str, Any]:
+ result = super().serialize_with_metadata()
+ result["__valid_first__"] = self.valid_first
+ result["__valid_second__"] = self.valid_second
+ return result
+
+
+# class LiteralValue(NMValueModel[str]):
+# """
+# A string field that must be one of a predefined set of literals
+# """
+
+# valid_values: list[str] = NMField(default_factory=list, exclude=True)
+
+# def __init__(self, **data):
+# valid_values = data.pop("valid_values", [])
+# super().__init__(**data)
+# object.__setattr__(self, "valid_values", valid_values)
+
+# @model_validator(mode="after")
+# def validate_value(self):
+# if self.root not in self.valid_values:
+# raise ValueError(f"Value must be one of: {self.valid_values}")
+# return self
+
+# def serialize_with_metadata(self) -> dict[str, Any]:
+# result = super().serialize_with_metadata()
+# result["__valid_values__"] = self.valid_values
+# return result
#################
diff --git a/pyproject.toml b/pyproject.toml
index 5facc8ba..29c85832 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,6 @@ dependencies = [
"nolds >=0.6.1",
"numpy >= 1.21.2",
"pandas >= 2.0.0",
- "scikit-image",
"scikit-learn >= 0.24.2",
"scikit-optimize",
"scipy >= 1.7.1",
@@ -55,13 +54,11 @@ dependencies = [
"statsmodels",
"mne-lsl>=1.2.0",
"pynput",
- #"pyqt5",
"pydantic>=2.7.3",
"llvmlite>=0.43.0",
"pywebview",
"fastapi",
- "uvicorn>=0.30.6",
- "websockets>=13.0",
+ "uvicorn[standard]>=0.30.6",
"seaborn >= 0.11",
# exists only because of nolds?
"numba>=0.60.0",
@@ -79,7 +76,7 @@ dev = [
"pytest-cov",
"pytest-sugar",
"notebook",
- "watchdog",
+ "uvicorn[standard]",
]
docs = [
"py-neuromodulation[dev]",
diff --git a/tests/test_osc_features.py b/tests/test_osc_features.py
index 8cf84111..fd2428d1 100644
--- a/tests/test_osc_features.py
+++ b/tests/test_osc_features.py
@@ -3,11 +3,11 @@
import numpy as np
from py_neuromodulation import NMSettings, Stream, features
-from py_neuromodulation.utils.types import FeatureName
+from py_neuromodulation.utils.types import FEATURE_NAME
def setup_osc_settings(
- osc_feature_name: FeatureName,
+ osc_feature_name: FEATURE_NAME,
osc_feature_setting: str,
windowlength_ms: int,
log_transform: bool,