diff --git a/gui_dev/package.json b/gui_dev/package.json
index 99fb9b0c..6766b874 100644
--- a/gui_dev/package.json
+++ b/gui_dev/package.json
@@ -20,7 +20,7 @@
"react-dom": "next",
"react-icons": "^5.3.0",
"react-plotly.js": "^2.6.0",
- "react-router-dom": "^6.28.0",
+ "react-router-dom": "^7.0.1",
"zustand": "latest"
},
"devDependencies": {
diff --git a/gui_dev/src/App.jsx b/gui_dev/src/App.jsx
index 7ebfdc55..a97f8c75 100644
--- a/gui_dev/src/App.jsx
+++ b/gui_dev/src/App.jsx
@@ -45,8 +45,6 @@ export const App = () => {
const connectSocket = useSocketStore((state) => state.connectSocket);
const disconnectSocket = useSocketStore((state) => state.disconnectSocket);
- connectSocket();
-
useEffect(() => {
console.log("Connecting socket from App component...");
connectSocket();
diff --git a/gui_dev/src/components/FileBrowser/FileBrowser.jsx b/gui_dev/src/components/FileBrowser/FileBrowser.jsx
index 9cf315c3..84b5a91b 100644
--- a/gui_dev/src/components/FileBrowser/FileBrowser.jsx
+++ b/gui_dev/src/components/FileBrowser/FileBrowser.jsx
@@ -1,5 +1,5 @@
import { useReducer, useEffect } from "react";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
import {
Box,
Button,
@@ -36,7 +36,7 @@ import {
import { QuickAccessSidebar } from "./QuickAccess";
import { FileManager } from "@/utils/FileManager";
-const fileManager = new FileManager("");
+const fileManager = new FileManager(getBackendURL("/api/files"));
const ALLOWED_EXTENSIONS = [".npy", ".vhdr", ".fif", ".edf", ".bdf"];
diff --git a/gui_dev/src/components/FileBrowser/QuickAccess.jsx b/gui_dev/src/components/FileBrowser/QuickAccess.jsx
index 84b70064..6e1253bb 100644
--- a/gui_dev/src/components/FileBrowser/QuickAccess.jsx
+++ b/gui_dev/src/components/FileBrowser/QuickAccess.jsx
@@ -1,5 +1,5 @@
import React, { useState, useEffect } from "react";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
import {
Paper,
Typography,
diff --git a/gui_dev/src/components/StatusBar/StatusBar.jsx b/gui_dev/src/components/StatusBar/StatusBar.jsx
index 36aaf1dc..81f1b32b 100644
--- a/gui_dev/src/components/StatusBar/StatusBar.jsx
+++ b/gui_dev/src/components/StatusBar/StatusBar.jsx
@@ -2,12 +2,14 @@ import { ResizeHandle } from "./ResizeHandle";
import { SocketStatus } from "./SocketStatus";
import { WebviewStatus } from "./WebviewStatus";
-import { useWebviewStore } from "@/stores";
-
+import { useUiStore, useWebviewStore } from "@/stores";
import { Stack } from "@mui/material";
export const StatusBar = () => {
- const { isWebView } = useWebviewStore((state) => state.isWebView);
+ const isWebView = useWebviewStore((state) => state.isWebView);
+ const getStatusBarContent = useUiStore((state) => state.getStatusBarContent);
+
+ const StatusBarContent = getStatusBarContent();
return (
{
bgcolor="background.level1"
borderTop="2px solid"
borderColor="background.level3"
+ height="2rem"
>
-
+ {StatusBarContent && }
+
+ {/* */}
{/* Current experiment */}
{/* Current stream */}
{/* Current activity */}
diff --git a/gui_dev/src/components/TitledBox.jsx b/gui_dev/src/components/TitledBox.jsx
index 56f0266b..a86b6009 100644
--- a/gui_dev/src/components/TitledBox.jsx
+++ b/gui_dev/src/components/TitledBox.jsx
@@ -1,4 +1,4 @@
-import { Container } from "@mui/material";
+import { Box, Container } from "@mui/material";
/**
* Component that uses the Box component to render an HTML fieldset element
@@ -13,14 +13,14 @@ export const TitledBox = ({
children,
...props
}) => (
-
{children}
-
+
);
diff --git a/gui_dev/src/main.jsx b/gui_dev/src/main.jsx
index f4524069..c1cabb62 100644
--- a/gui_dev/src/main.jsx
+++ b/gui_dev/src/main.jsx
@@ -1,7 +1,15 @@
+// import { scan } from "react-scan";
+// scan({
+// enabled: true,
+// log: true, // logs render info to console
+// });
+
import { StrictMode } from "react";
import ReactDOM from "react-dom/client";
import { App } from "./App.jsx";
+// NOTE: react-scan setup above is commented out; uncomment it to enable render profiling
+
// Ignore React 19 warning about accessing element.ref
const originalConsoleError = console.error;
console.error = (message, ...messageArgs) => {
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.jsx b/gui_dev/src/pages/Settings/DragAndDropList.jsx
deleted file mode 100644
index afa74b02..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.jsx
+++ /dev/null
@@ -1,65 +0,0 @@
-import { useRef } from "react";
-import styles from "./DragAndDropList.module.css";
-import { useOptionsStore } from "@/stores";
-
-export const DragAndDropList = () => {
- const { options, setOptions, addOption, removeOption } = useOptionsStore();
-
- const predefinedOptions = [
- { id: 1, name: "raw_resampling" },
- { id: 2, name: "notch_filter" },
- { id: 3, name: "re_referencing" },
- { id: 4, name: "preprocessing_filter" },
- { id: 5, name: "raw_normalization" },
- ];
-
- const dragOption = useRef(0);
- const draggedOverOption = useRef(0);
-
- function handleSort() {
- const optionsClone = [...options];
- const temp = optionsClone[dragOption.current];
- optionsClone[dragOption.current] = optionsClone[draggedOverOption.current];
- optionsClone[draggedOverOption.current] = temp;
- setOptions(optionsClone);
- }
-
- return (
-
- List
- {options.map((option, index) => (
- (dragOption.current = index)}
- onDragEnter={() => (draggedOverOption.current = index)}
- onDragEnd={handleSort}
- onDragOver={(e) => e.preventDefault()}
- >
-
- {[option.id, ". ", option.name.replace("_", " ")]}
-
-
-
- ))}
-
-
Add Elements
- {predefinedOptions.map((option) => (
-
- ))}
-
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/DragAndDropList.module.css b/gui_dev/src/pages/Settings/DragAndDropList.module.css
deleted file mode 100644
index 76bd7861..00000000
--- a/gui_dev/src/pages/Settings/DragAndDropList.module.css
+++ /dev/null
@@ -1,72 +0,0 @@
-.dragDropList {
- max-width: 400px;
- margin: 0 auto;
- padding: 20px;
-}
-
-.title {
- text-align: center;
- color: #333;
- margin-bottom: 20px;
-}
-
-.item {
- background-color: #f0f0f0;
- border-radius: 8px;
- padding: 15px;
- margin-bottom: 10px;
- cursor: move;
- transition: background-color 0.3s ease;
- display: flex;
- justify-content: space-between;
- align-items: center;
-}
-
-.item:hover {
- background-color: #e0e0e0;
-}
-
-.itemText {
- margin: 0;
- color: #333;
- font-size: 16px;
-}
-
-.removeButton {
- background-color: #ff4d4d;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 5px 10px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.removeButton:hover {
- background-color: #ff1a1a;
-}
-
-.addSection {
- margin-top: 20px;
- text-align: center;
-}
-
-.subtitle {
- color: #333;
- margin-bottom: 10px;
-}
-
-.addButton {
- background-color: #4CAF50;
- border: none;
- border-radius: 4px;
- color: white;
- padding: 10px 15px;
- margin: 5px;
- cursor: pointer;
- transition: background-color 0.3s ease;
-}
-
-.addButton:hover {
- background-color: #45a049;
-}
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Dropdown.jsx b/gui_dev/src/pages/Settings/Dropdown.jsx
deleted file mode 100644
index ad662eef..00000000
--- a/gui_dev/src/pages/Settings/Dropdown.jsx
+++ /dev/null
@@ -1,66 +0,0 @@
-import { useState } from "react";
-import "../App.css";
-
-var stringJson =
- '{"sampling_rate_features_hz":10.0,"segment_length_features_ms":1000.0,"frequency_ranges_hz":{"theta":{"frequency_low_hz":4.0,"frequency_high_hz":8.0},"alpha":{"frequency_low_hz":8.0,"frequency_high_hz":12.0},"low beta":{"frequency_low_hz":13.0,"frequency_high_hz":20.0},"high beta":{"frequency_low_hz":20.0,"frequency_high_hz":35.0},"low gamma":{"frequency_low_hz":60.0,"frequency_high_hz":80.0},"high gamma":{"frequency_low_hz":90.0,"frequency_high_hz":200.0},"HFA":{"frequency_low_hz":200.0,"frequency_high_hz":400.0}},"preprocessing":["raw_resampling","notch_filter","re_referencing"],"raw_resampling_settings":{"resample_freq_hz":1000.0},"preprocessing_filter":{"bandstop_filter":true,"bandpass_filter":true,"lowpass_filter":true,"highpass_filter":true,"bandstop_filter_settings":{"frequency_low_hz":100.0,"frequency_high_hz":160.0},"bandpass_filter_settings":{"frequency_low_hz":2.0,"frequency_high_hz":200.0},"lowpass_filter_cutoff_hz":200.0,"highpass_filter_cutoff_hz":3.0},"raw_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"postprocessing":{"feature_normalization":true,"project_cortex":false,"project_subcortex":false},"feature_normalization_settings":{"normalization_time_s":30.0,"normalization_method":"zscore","clip":3.0},"project_cortex_settings":{"max_dist_mm":20.0},"project_subcortex_settings":{"max_dist_mm":5.0},"features":{"raw_hjorth":true,"return_raw":true,"bandpass_filter":false,"stft":false,"fft":true,"welch":true,"sharpwave_analysis":true,"fooof":false,"nolds":false,"coherence":false,"bursts":true,"linelength":true,"mne_connectivity":false,"bispectrum":false},"fft_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"welch_settings":{"windowlength_ms":1000,"log_transform":true,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"stft_settings":{"windowlength_ms":1000,"log_transform":t
rue,"features":{"mean":true,"median":false,"std":false,"max":false},"return_spectrum":false},"bandpass_filter_settings":{"segment_lengths_ms":{"theta":1000,"alpha":500,"low beta":333,"high beta":333,"low gamma":100,"high gamma":100,"HFA":100},"bandpower_features":{"activity":true,"mobility":false,"complexity":false},"log_transform":true,"kalman_filter":false},"kalman_filter_settings":{"Tp":0.1,"sigma_w":0.7,"sigma_v":1.0,"frequency_bands":["theta","alpha","low_beta","high_beta","low_gamma","high_gamma","HFA"]},"burst_settings":{"threshold":75.0,"time_duration_s":30.0,"frequency_bands":["low beta","high beta","low gamma"],"burst_features":{"duration":true,"amplitude":true,"burst_rate_per_s":true,"in_burst":true}},"sharpwave_analysis_settings":{"sharpwave_features":{"peak_left":false,"peak_right":false,"trough":false,"width":false,"prominence":true,"interval":true,"decay_time":false,"rise_time":false,"sharpness":true,"rise_steepness":false,"decay_steepness":false,"slope_ratio":false},"filter_ranges_hz":[{"frequency_low_hz":5.0,"frequency_high_hz":80.0},{"frequency_low_hz":5.0,"frequency_high_hz":30.0}],"detect_troughs":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"detect_peaks":{"estimate":true,"distance_troughs_ms":10.0,"distance_peaks_ms":5.0},"estimator":{"mean":["interval"],"median":[],"max":["prominence","sharpness"],"min":[],"var":[]},"apply_estimator_between_peaks_and_troughs":true},"mne_connectivity":{"method":"plv","mode":"multitaper"},"coherence":{"features":{"mean_fband":true,"max_fband":true,"max_allfbands":true},"method":{"coh":true,"icoh":true},"channels":[],"frequency_bands":["high 
beta"]},"fooof":{"aperiodic":{"exponent":true,"offset":true,"knee":true},"periodic":{"center_frequency":false,"band_width":false,"height_over_ap":false},"windowlength_ms":800.0,"peak_width_limits":{"frequency_low_hz":0.5,"frequency_high_hz":12.0},"max_n_peaks":3,"min_peak_height":0.0,"peak_threshold":2.0,"freq_range_hz":{"frequency_low_hz":2.0,"frequency_high_hz":40.0},"knee":true},"nolds_features":{"raw":true,"frequency_bands":["low beta"],"features":{"sample_entropy":false,"correlation_dimension":false,"lyapunov_exponent":true,"hurst_exponent":false,"detrended_fluctuation_analysis":false}},"bispectrum":{"f1s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"f2s":{"frequency_low_hz":5.0,"frequency_high_hz":35.0},"compute_features_for_whole_fband_range":true,"frequency_bands":["theta","alpha","low_beta","high_beta"],"components":{"absolute":true,"real":true,"imag":true,"phase":true},"bispectrum_features":{"mean":true,"sum":true,"var":true}}}';
-const nm_settings = JSON.parse(stringJson);
-
-const filterByKeys = (dict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (typeof key === "string") {
- // Top-level key
- if (dict.hasOwnProperty(key)) {
- filteredDict[key] = dict[key];
- }
- } else if (typeof key === "object") {
- // Nested key
- const topLevelKey = Object.keys(key)[0];
- const nestedKeys = key[topLevelKey];
- if (
- dict.hasOwnProperty(topLevelKey) &&
- typeof dict[topLevelKey] === "object"
- ) {
- filteredDict[topLevelKey] = filterByKeys(dict[topLevelKey], nestedKeys);
- }
- }
- });
- return filteredDict;
-};
-
-const Dropdown = ({ onChange, keysToInclude }) => {
- const filteredSettings = filterByKeys(nm_settings, keysToInclude);
- const [selectedOption, setSelectedOption] = useState("");
-
- const handleChange = (event) => {
- const newValue = event.target.value;
- setSelectedOption(newValue);
- onChange(keysToInclude, newValue);
- };
-
- return (
-
-
-
- );
-};
-
-export default Dropdown;
diff --git a/gui_dev/src/pages/Settings/FrequencyRange.jsx b/gui_dev/src/pages/Settings/FrequencyRange.jsx
deleted file mode 100644
index 5def783b..00000000
--- a/gui_dev/src/pages/Settings/FrequencyRange.jsx
+++ /dev/null
@@ -1,88 +0,0 @@
-import { useState } from "react";
-
-// const onChange = (key, newValue) => {
-// settings.frequencyRanges[key] = newValue;
-// };
-
-// Object.entries(settings.frequencyRanges).map(([key, value]) => (
-//
-// ));
-
-/** */
-
-/**
- *
- * @param {String} key
- * @param {Array} freqRange
- * @param {Function} onChange
- * @returns
- */
-export const FrequencyRange = ({ settings }) => {
- const [frequencyRanges, setFrequencyRanges] = useState(settings || {});
- // Handle changes in the text fields
- const handleInputChange = (label, key, newValue) => {
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [label]: {
- ...prevState[label],
- [key]: newValue,
- },
- }));
- };
-
- // Add a new band
- const addBand = () => {
- const newLabel = `Band ${Object.keys(frequencyRanges).length + 1}`;
- setFrequencyRanges((prevState) => ({
- ...prevState,
- [newLabel]: { frequency_high_hz: "", frequency_low_hz: "" },
- }));
- };
-
- // Remove a band
- const removeBand = (label) => {
- const updatedRanges = { ...frequencyRanges };
- delete updatedRanges[label];
- setFrequencyRanges(updatedRanges);
- };
-
- return (
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/FrequencySettings.module.css b/gui_dev/src/pages/Settings/FrequencySettings.module.css
deleted file mode 100644
index a638ad93..00000000
--- a/gui_dev/src/pages/Settings/FrequencySettings.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-.container {
- background-color: #f9f9f9; /* Light gray background for the container */
- padding: 20px;
- border-radius: 10px; /* Rounded corners */
- box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); /* Subtle shadow */
- max-width: 600px;
- margin: auto;
- }
-
- .header {
- font-size: 1.5rem;
- color: #333; /* Darker text color */
- margin-bottom: 20px;
- text-align: center;
- }
-
- .bandContainer {
- display: flex;
- align-items: center;
- margin-bottom: 15px;
- padding: 10px;
- border: 1px solid #ddd; /* Light border for each band */
- border-radius: 8px; /* Rounded corners for individual bands */
- background-color: #fff; /* White background for bands */
- box-shadow: 0 1px 5px rgba(0, 0, 0, 0.1); /* Light shadow for depth */
- }
-
- .bandNameInput, .frequencyInput {
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 5px; /* Slightly rounded corners */
- padding: 8px;
- margin-right: 10px;
- font-size: 0.875rem;
- }
-
- .bandNameInput::placeholder, .frequencyInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- }
-
- .removeButton, .addButton {
- border: none;
- border-radius: 5px; /* Rounded corners */
- padding: 8px 12px;
- font-size: 0.875rem;
- cursor: pointer;
- transition: background-color 0.3s, color 0.3s;
- }
-
- .removeButton {
- background-color: #e57373; /* Light red color */
- color: white;
- }
-
- .removeButton:hover {
- background-color: #d32f2f; /* Darker red on hover */
- }
-
- .addButton {
- background-color: #4caf50; /* Green color */
- color: white;
- margin-top: 10px;
- }
-
- .addButton:hover {
- background-color: #388e3c; /* Darker green on hover */
- }
-
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/Settings.jsx b/gui_dev/src/pages/Settings/Settings.jsx
index e2cf5f2a..69458cbc 100644
--- a/gui_dev/src/pages/Settings/Settings.jsx
+++ b/gui_dev/src/pages/Settings/Settings.jsx
@@ -1,19 +1,25 @@
+import { useEffect, useState } from "react";
import {
+ Box,
Button,
InputAdornment,
+ Popover,
Stack,
Switch,
TextField,
+ Tooltip,
Typography,
} from "@mui/material";
import { Link } from "react-router-dom";
import { CollapsibleBox, TitledBox } from "@/components";
-import { FrequencyRange } from "./FrequencyRange";
-import { useSettingsStore } from "@/stores";
+import {
+ FrequencyRangeList,
+ FrequencyRange,
+} from "./components/FrequencyRange";
+import { useSettingsStore, useStatusBar } from "@/stores";
import { filterObjectByKeys } from "@/utils/functions";
const formatKey = (key) => {
- // console.log(key);
return key
.split("_")
.map((word) => word.charAt(0).toUpperCase() + word.slice(1))
@@ -21,15 +27,42 @@ const formatKey = (key) => {
};
// Wrapper components for each type
-const BooleanField = ({ value, onChange }) => (
- onChange(e.target.checked)} />
+const BooleanField = ({ label, value, onChange, error }) => (
+
+ {label}
+ onChange(e.target.checked)} />
+
);
+const errorStyle = {
+ "& .MuiOutlinedInput-root": {
+ "& fieldset": { borderColor: "error.main" },
+ "&:hover fieldset": {
+ borderColor: "error.main",
+ },
+ "&.Mui-focused fieldset": {
+ borderColor: "error.main",
+ },
+ },
+};
-const StringField = ({ value, onChange, label }) => (
-
-);
+const StringField = ({ label, value, onChange, error }) => {
+ const errorSx = error ? errorStyle : {};
+ return (
+
+ {label}
+ onChange(e.target.value)}
+ label={label}
+ sx={{ ...errorSx }}
+ />
+
+ );
+};
+
+const NumberField = ({ label, value, onChange, error, unit }) => {
+ const errorSx = error ? errorStyle : {};
-const NumberField = ({ value, onChange, label }) => {
const handleChange = (event) => {
const newValue = event.target.value;
// Only allow numbers and decimal point
@@ -39,126 +72,237 @@ const NumberField = ({ value, onChange, label }) => {
};
return (
-
- Hz
-
- ),
- }}
- inputProps={{
- pattern: "[0-9]*",
- }}
- />
+
+ {label}
+
+ {unit},
+ }}
+ inputProps={{
+ pattern: "[0-9]*",
+ }}
+ />
+
);
};
-const FrequencyRangeField = ({ value, onChange, label }) => (
-
-);
+const FrequencyRangeField = ({ label, value, onChange, error }) => {
+ console.log(label, value);
+ return ;
+};
// Map component types to their respective wrappers
const componentRegistry = {
boolean: BooleanField,
string: StringField,
+ int: NumberField,
+ float: NumberField,
number: NumberField,
FrequencyRange: FrequencyRangeField,
};
-const SettingsField = ({ path, Component, label, value, onChange, depth }) => {
+const SettingsField = ({
+ path,
+ Component,
+ label,
+ value,
+ onChange,
+ error,
+ unit,
+}) => {
return (
-
- {label}
+
onChange(path, newValue)}
label={label}
+ error={error}
+ unit={unit}
/>
-
+
);
};
+// Function to get the error corresponding to this field or its children
+const getFieldError = (fieldPath, errors) => {
+ if (!errors) return null;
+
+ return errors.find((error) => {
+ const errorPath = error.loc.join(".");
+ const currentPath = fieldPath.join(".");
+ return errorPath === currentPath || errorPath.startsWith(currentPath + ".");
+ });
+};
+
const SettingsSection = ({
settings,
title = null,
path = [],
onChange,
- depth = 0,
+ errors,
}) => {
- if (Object.keys(settings).length === 0) {
- return null;
- }
const boxTitle = title ? title : formatKey(path[path.length - 1]);
+ const type = typeof settings;
+ const isObject = type === "object" && !Array.isArray(settings);
+ const isArray = Array.isArray(settings);
- return (
-
- {Object.entries(settings).map(([key, value]) => {
- if (key === "__field_type__") return null;
+ // __field_type__ should be always present
+ if (isObject && !settings.__field_type__) {
+ throw new Error("Invalid settings object");
+ }
+ const fieldType = isObject ? settings.__field_type__ : type;
+ const Component = componentRegistry[fieldType];
- const newPath = [...path, key];
- const label = key;
- const isPydanticModel =
- typeof value === "object" && "__field_type__" in value;
+ // Case 1: Object or primitive with component -> Don't iterate, render directly
+ if (Component) {
+ const value =
+ isObject && "__value__" in settings ? settings.__value__ : settings;
+ const unit = isObject && "__unit__" in settings ? settings.__unit__ : null;
- const fieldType = isPydanticModel ? value.__field_type__ : typeof value;
+ return (
+
+ );
+ }
- const Component = componentRegistry[fieldType];
+ // Case 2: Object without component or Array -> Iterate and render recursively
+ else {
+ return (
+
+ {/* Handle recursing through both objects and arrays */}
+ {(isArray ? settings : Object.entries(settings)).map((item, index) => {
+ const [key, value] = isArray ? [index.toString(), item] : item;
+ if (key.startsWith("__")) return null; // Skip metadata fields
+
+ const newPath = [...path, key];
- if (Component) {
- return (
-
- );
- } else {
return (
);
- }
- })}
-
+ })}
+
+ );
+ }
+};
+
+const StatusBarSettingsInfo = () => {
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ const [anchorEl, setAnchorEl] = useState(null);
+ const open = Boolean(anchorEl);
+
+ const handleOpenErrorsPopover = (event) => {
+ setAnchorEl(event.currentTarget);
+ };
+
+ const handleCloseErrorsPopover = () => {
+ setAnchorEl(null);
+ };
+
+ return (
+ <>
+ {validationErrors?.length > 0 && (
+ <>
+
+ {validationErrors?.length} errors found in Settings
+
+
+
+ {validationErrors.map((error, index) => (
+
+ {index} - [{error.type}] {error.msg}
+
+ ))}
+
+
+ >
+ )}
+ >
);
};
-const SettingsContent = () => {
+export const Settings = () => {
+ // Get all necessary state from the settings store
const settings = useSettingsStore((state) => state.settings);
- const updateSettings = useSettingsStore((state) => state.updateSettings);
+ const uploadSettings = useSettingsStore((state) => state.uploadSettings);
+ const resetSettings = useSettingsStore((state) => state.resetSettings);
+ const validationErrors = useSettingsStore((state) => state.validationErrors);
+ useStatusBar(StatusBarSettingsInfo);
+
+ // This is needed so that the frequency ranges stay in order between updates
+ const frequencyRangeOrder = useSettingsStore(
+ (state) => state.frequencyRangeOrder
+ );
+ const updateFrequencyRangeOrder = useSettingsStore(
+ (state) => state.updateFrequencyRangeOrder
+ );
+ // State for the currently selected feature in the feature settings component
+ const [selectedFeature, setSelectedFeature] = useState("");
+
+ useEffect(() => {
+ uploadSettings(null, true); // validateOnly = true
+ }, [settings]);
+
+ // Inject validation error info into status bar
+
+ // This has to be after all the hooks, otherwise React will complain
if (!settings) {
return Loading settings...
;
}
- const handleChange = (path, value) => {
- updateSettings((settings) => {
+ // These are the callbacks for the different buttons
+ const handleChangeSettings = async (path, value) => {
+ uploadSettings((settings) => {
let current = settings;
for (let i = 0; i < path.length - 1; i++) {
current = current[path[i]];
}
current[path[path.length - 1]] = value;
- });
+ }, true); // validateOnly = true
+ };
+
+ const handleSaveSettings = () => {
+ uploadSettings(() => settings);
+ };
+
+ const handleResetSettings = async () => {
+ await resetSettings();
};
const featureSettingsKeys = Object.keys(settings.features)
@@ -181,89 +325,147 @@ const SettingsContent = () => {
"project_subcortex_settings",
];
+ const generalSettingsKeys = [
+ "sampling_rate_features_hz",
+ "segment_length_features_ms",
+ ];
+
return (
-
-
+ {/* SETTINGS LAYOUT */}
+
-
-
-
-
-
+ {/* GENERAL SETTINGS + FREQUENCY RANGES */}
+
+
+ {generalSettingsKeys.map((key) => (
+
+ ))}
+
+
+
+
+
+
+
+ {/* POSTPROCESSING + PREPROCESSING SETTINGS */}
+
{preprocessingSettingsKeys.map((key) => (
))}
-
+
-
+
{postprocessingSettingsKeys.map((key) => (
))}
-
-
+
-
- {Object.entries(enabledFeatures).map(([feature, featureSettings]) => (
-
-
-
- ))}
-
-
- );
-};
+ {/* FEATURE SETTINGS */}
+
+
+
+
+
+
+ {Object.entries(enabledFeatures).map(
+ ([feature, featureSettings]) => (
+
+
+
+ )
+ )}
+
+
+
+ {/* END SETTINGS LAYOUT */}
+
-export const Settings = () => {
- return (
-
-
-
+
+ {/* */}
+
+
+
);
};
diff --git a/gui_dev/src/pages/Settings/TextField.jsx b/gui_dev/src/pages/Settings/TextField.jsx
deleted file mode 100644
index 86ab07b8..00000000
--- a/gui_dev/src/pages/Settings/TextField.jsx
+++ /dev/null
@@ -1,77 +0,0 @@
-import { useState, useEffect } from "react";
-import {
- Box,
- Grid,
- TextField as MUITextField,
- Typography,
-} from "@mui/material";
-import { useSettingsStore } from "@/stores";
-import styles from "./TextField.module.css";
-
-const flattenDictionary = (dict, parentKey = "", result = {}) => {
- for (let key in dict) {
- const newKey = parentKey ? `${parentKey}.${key}` : key;
- if (typeof dict[key] === "object" && dict[key] !== null) {
- flattenDictionary(dict[key], newKey, result);
- } else {
- result[newKey] = dict[key];
- }
- }
- return result;
-};
-
-const filterByKeys = (flatDict, keys) => {
- const filteredDict = {};
- keys.forEach((key) => {
- if (flatDict.hasOwnProperty(key)) {
- filteredDict[key] = flatDict[key];
- }
- });
- return filteredDict;
-};
-
-export const TextField = ({ keysToInclude }) => {
- const settings = useSettingsStore((state) => state.settings);
- const flatSettings = flattenDictionary(settings);
- const filteredSettings = filterByKeys(flatSettings, keysToInclude);
- const [textLabels, setTextLabels] = useState({});
-
- useEffect(() => {
- setTextLabels(filteredSettings);
- }, [settings]);
-
- const handleTextFieldChange = (label, value) => {
- setTextLabels((prevLabels) => ({
- ...prevLabels,
- [label]: value,
- }));
- };
-
- // Function to format the label
- const formatLabel = (label) => {
- const labelAfterDot = label.split(".").pop(); // Get everything after the last dot
- return labelAfterDot.replace(/_/g, " "); // Replace underscores with spaces
- };
-
- return (
-
- {Object.keys(textLabels).map((label, index) => (
-
-
- handleTextFieldChange(label, e.target.value)}
- className={styles.textFieldInput}
- />
-
- ))}
-
- );
-};
diff --git a/gui_dev/src/pages/Settings/TextField.module.css b/gui_dev/src/pages/Settings/TextField.module.css
deleted file mode 100644
index 37791c40..00000000
--- a/gui_dev/src/pages/Settings/TextField.module.css
+++ /dev/null
@@ -1,67 +0,0 @@
-/* TextField.module.css */
-
-/* Container for the text fields */
-.textFieldContainer {
- display: flex;
- flex-direction: column;
- margin: 1.5rem 0; /* Increased margin for better spacing */
- }
-
- /* Row for each text field */
- .textFieldRow {
- display: flex;
- flex-direction: column; /* Stack label and input vertically */
- margin-bottom: 1rem; /* Increased margin for better separation */
- }
-
- /* Label for each text field */
- .textFieldLabel {
- margin-bottom: 0.5rem; /* Space between label and input */
- font-weight: 600; /* Increased weight for better visibility */
- color: #333; /* Dark gray for the label */
- font-size: 1.1rem; /* Increased font size for the label */
- transition: all 0.2s ease; /* Smooth transition for label */
- }
-
- /* Input field styles */
- .textFieldInput {
- padding: 12px 14px; /* Padding for a filled look */
- border: 1px solid #ccc; /* Light gray border */
- border-radius: 4px; /* Rounded corners */
- width: 100%; /* Full width */
- font-size: 1rem; /* Font size */
- background-color: #f5f5f5; /* Light background color for filled effect */
- transition: border-color 0.2s ease, background-color 0.2s ease; /* Smooth transitions */
- box-shadow: none; /* Remove default shadow */
- height: 48px; /* Fixed height for a more square appearance */
- }
-
- /* Focus styles for the input */
- .textFieldInput:focus {
- border-color: #1976d2; /* Blue border color on focus */
- background-color: #fff; /* Change background to white on focus */
- outline: none; /* Remove default outline */
- }
-
- /* Hover effect for the input */
- .textFieldInput:hover {
- border-color: #1976d2; /* Change border color on hover */
- }
-
- /* Placeholder styles */
- .textFieldInput::placeholder {
- color: #aaa; /* Light gray placeholder text */
- opacity: 1; /* Ensure placeholder is fully opaque */
- }
-
- /* Hide the number input spinners in webkit browsers */
- .textFieldInput::-webkit-inner-spin-button,
- .textFieldInput::-webkit-outer-spin-button {
- -webkit-appearance: none; /* Remove default styling */
- margin: 0; /* Remove margin */
- }
-
- /* Hide the number input spinners in Firefox */
- .textFieldInput[type='number'] {
- -moz-appearance: textfield; /* Use textfield appearance */
- }
\ No newline at end of file
diff --git a/gui_dev/src/pages/Settings/components/FrequencyRange.jsx b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx
new file mode 100644
index 00000000..7fb50a23
--- /dev/null
+++ b/gui_dev/src/pages/Settings/components/FrequencyRange.jsx
@@ -0,0 +1,198 @@
+import { useState } from "react";
+import {
+ TextField,
+ Button,
+ IconButton,
+ Stack,
+ Typography,
+} from "@mui/material";
+import { Add, Close } from "@mui/icons-material";
+import { debounce } from "@/utils";
+
+const NumberField = ({ ...props }) => (
+
+);
+
+export const FrequencyRange = ({
+ name,
+ range,
+ onChangeName,
+ onChangeRange,
+ error,
+ nameEditable = false,
+}) => {
+ console.log(range);
+ const [localName, setLocalName] = useState(name);
+
+ const debouncedChangeName = debounce((newName) => {
+ onChangeName(newName, name);
+ }, 1000);
+
+ const handleNameChange = (e) => {
+ if (!nameEditable) return;
+ const newName = e.target.value;
+ setLocalName(newName);
+ debouncedChangeName(newName);
+ };
+
+ const handleNameBlur = () => {
+ if (!nameEditable) return;
+ onChangeName(localName, name);
+ };
+
+ const handleKeyPress = (e) => {
+ if (!nameEditable) return;
+ if (e.key === "Enter") {
+ console.log(e.target.value, name);
+ onChangeName(localName, name);
+ }
+ };
+
+ const handleRangeChange = (name, field, value) => {
+ // onChangeRange takes the name of the range as the first argument
+ onChangeRange(name, { ...range, [field]: value });
+ };
+
+ return (
+
+ {nameEditable ? (
+
+ ) : (
+ {name}
+ )}
+
+ handleRangeChange(name, "frequency_low_hz", e.target.value)
+ }
+ label="Low Hz"
+ />
+
+ handleRangeChange(name, "frequency_high_hz", e.target.value)
+ }
+ label="High Hz"
+ />
+
+ );
+};
+
+export const FrequencyRangeList = ({
+ ranges,
+ rangeOrder,
+ onChange,
+ onOrderChange,
+ errors,
+}) => {
+ const handleChangeRange = (name, newRange) => {
+ const updatedRanges = { ...ranges };
+ updatedRanges[name] = newRange;
+ onChange(["frequency_ranges_hz"], updatedRanges);
+ };
+
+ const handleChangeName = (newName, oldName) => {
+ if (oldName === newName) {
+ return;
+ }
+
+ const updatedRanges = { ...ranges, [newName]: ranges[oldName] };
+ delete updatedRanges[oldName];
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = rangeOrder.map((name) =>
+ name === oldName ? newName : name
+ );
+ onOrderChange(updatedOrder);
+ };
+
+ const handleRemove = (name) => {
+ const updatedRanges = { ...ranges };
+ delete updatedRanges[name];
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = rangeOrder.filter((item) => item !== name);
+ onOrderChange(updatedOrder);
+ };
+
+ const addRange = () => {
+ let newName = "NewRange";
+ let counter = 0;
+ // Find first available name
+ while (Object.hasOwn(ranges, newName)) {
+ counter++;
+ newName = `NewRange${counter}`;
+ }
+
+ const updatedRanges = {
+ ...ranges,
+ [newName]: {
+ __field_type__: "FrequencyRange",
+ frequency_low_hz: 1,
+ frequency_high_hz: 500,
+ },
+ };
+ onChange(["frequency_ranges_hz"], updatedRanges);
+
+ const updatedOrder = [...rangeOrder, newName];
+ onOrderChange(updatedOrder);
+ };
+
+ return (
+
+ {rangeOrder.map((name, index) => (
+
+
+ handleRemove(name)}
+ color="primary"
+ disableRipple
+ sx={{ m: 0, p: 0 }}
+ >
+
+
+
+ ))}
+ }
+ onClick={addRange}
+ sx={{ mt: 1 }}
+ >
+ Add Range
+
+
+ );
+};
diff --git a/gui_dev/src/pages/Settings/index.js b/gui_dev/src/pages/Settings/index.js
deleted file mode 100644
index 16355023..00000000
--- a/gui_dev/src/pages/Settings/index.js
+++ /dev/null
@@ -1 +0,0 @@
-export { TextField } from './TextField';
diff --git a/gui_dev/src/pages/SourceSelection/FileSelector.jsx b/gui_dev/src/pages/SourceSelection/FileSelector.jsx
index ffa53bab..e119265a 100644
--- a/gui_dev/src/pages/SourceSelection/FileSelector.jsx
+++ b/gui_dev/src/pages/SourceSelection/FileSelector.jsx
@@ -5,6 +5,8 @@ import { useSessionStore } from "@/stores";
import { FileBrowser, TitledBox } from "@/components";
+import { getPyNMDirectory } from "@/utils";
+
export const FileSelector = () => {
const fileSource = useSessionStore((state) => state.fileSource);
const setFileSource = useSessionStore((state) => state.setFileSource);
@@ -19,16 +21,22 @@ export const FileSelector = () => {
);
const setSourceType = useSessionStore((state) => state.setSourceType);
- const fileBrowserDirRef = useRef(
- "C:\\code\\py_neuromodulation\\py_neuromodulation\\data\\sub-testsub\\ses-EphysMedOff\\ieeg\\sub-testsub_ses-EphysMedOff_task-gripforce_run-0_ieeg.vhdr"
- );
+ const fileBrowserDirRef = useRef("");
const [isSelecting, setIsSelecting] = useState(false);
const [showFileBrowser, setShowFileBrowser] = useState(false);
+ const [showFolderBrowser, setShowFolderBrowser] = useState(false);
useEffect(() => {
setSourceType("lsl");
- }, []);
+
+ const fetchPyNMDirectory = async () => {
+ const pynmDir = await getPyNMDirectory();
+ fileBrowserDirRef.current =
+ pynmDir + "\\data\\sub-testsub\\ses-EphysMedOff\\ieeg\\";
+ };
+ fetchPyNMDirectory();
+ }, [setSourceType]);
const handleFileSelect = (file) => {
setIsSelecting(true);
@@ -48,6 +56,10 @@ export const FileSelector = () => {
}
};
+ const handleFolderSelect = (folder) => {
+ setShowFolderBrowser(false);
+ };
+
return (
+
{streamSetupMessage && (
{
onSelect={handleFileSelect}
/>
)}
+ {showFolderBrowser && (
+ setShowFolderBrowser(false)}
+ onSelect={handleFolderSelect}
+ onlyDirectories={true}
+ />
+ )}
);
};
diff --git a/gui_dev/src/stores/appInfoStore.js b/gui_dev/src/stores/appInfoStore.js
index c7887fe4..773aa936 100644
--- a/gui_dev/src/stores/appInfoStore.js
+++ b/gui_dev/src/stores/appInfoStore.js
@@ -1,5 +1,5 @@
import { createStore } from "./createStore";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
export const useAppInfoStore = createStore("appInfo", (set) => ({
version: "",
diff --git a/gui_dev/src/stores/createStore.js b/gui_dev/src/stores/createStore.js
index 762d9340..e683a35c 100644
--- a/gui_dev/src/stores/createStore.js
+++ b/gui_dev/src/stores/createStore.js
@@ -2,16 +2,23 @@ import { create } from "zustand";
import { immer } from "zustand/middleware/immer";
import { devtools, persist as persistMiddleware } from "zustand/middleware";
-export const createStore = (name, initializer, persist = false) => {
+export const createStore = (
+ name,
+ initializer,
+ persist = false,
+ dev = false
+) => {
const fn = persist
? persistMiddleware(immer(initializer), name)
: immer(initializer);
- return create(
- devtools(fn, {
- name: name,
- })
- );
+ const dev_fn = dev
+ ? devtools(fn, {
+ name: name,
+ })
+ : fn;
+
+ return create(dev_fn);
};
export const createPersistStore = (name, initializer) => {
diff --git a/gui_dev/src/stores/sessionStore.js b/gui_dev/src/stores/sessionStore.js
index 959d7fcf..971e6314 100644
--- a/gui_dev/src/stores/sessionStore.js
+++ b/gui_dev/src/stores/sessionStore.js
@@ -3,7 +3,7 @@
// the data source, stream paramerters, the output files paths, etc
import { createStore } from "@/stores/createStore";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
// Workflow stages enum-like object
export const WorkflowStage = Object.freeze({
@@ -110,7 +110,6 @@ export const useSessionStore = createStore("session", (set, get) => ({
// Check that all stream parameters are valid
checkStreamParameters: () => {
- // const { samplingRate, lineNoise, samplingRateFeatures } = get();
set({
areParametersValid:
get().streamParameters.samplingRate &&
diff --git a/gui_dev/src/stores/settingsStore.js b/gui_dev/src/stores/settingsStore.js
index 9b17f42d..73574815 100644
--- a/gui_dev/src/stores/settingsStore.js
+++ b/gui_dev/src/stores/settingsStore.js
@@ -1,36 +1,22 @@
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
import { createStore } from "./createStore";
const INITIAL_DELAY = 3000; // wait for Flask
const RETRY_DELAY = 1000; // ms
const MAX_RETRIES = 100;
-const uploadSettingsToServer = async (settings) => {
- try {
- const response = await fetch(getBackendURL("/api/settings"), {
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- },
- body: JSON.stringify(settings),
- });
- if (!response.ok) {
- throw new Error("Failed to update settings");
- }
- return { success: true };
- } catch (error) {
- console.error("Error updating settings:", error);
- return { success: false, error };
- }
-};
-
export const useSettingsStore = createStore("settings", (set, get) => ({
settings: null,
+ lastValidSettings: null,
+ frequencyRangeOrder: [],
isLoading: false,
error: null,
+ validationErrors: null,
retryCount: 0,
- setSettings: (settings) => set({ settings }),
+ updateLocalSettings: (updater) => {
+ set((state) => updater(state.settings));
+ },
fetchSettingsWithDelay: () => {
set({ isLoading: true, error: null });
@@ -46,8 +32,15 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
if (!response.ok) {
throw new Error("Failed to fetch settings");
}
+
const data = await response.json();
- set({ settings: data, retryCount: 0 });
+
+ set({
+ settings: data,
+ lastValidSettings: data,
+ frequencyRangeOrder: Object.keys(data.frequency_ranges_hz || {}),
+ retryCount: 0,
+ });
} catch (error) {
console.log("Error fetching settings:", error);
set((state) => ({
@@ -55,6 +48,8 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
retryCount: state.retryCount + 1,
}));
+ console.log(get().retryCount);
+
if (get().retryCount < MAX_RETRIES) {
await new Promise((resolve) => setTimeout(resolve, RETRY_DELAY));
return get().fetchSettings();
@@ -66,29 +61,64 @@ export const useSettingsStore = createStore("settings", (set, get) => ({
resetRetryCount: () => set({ retryCount: 0 }),
- updateSettings: async (updater) => {
- const currentSettings = get().settings;
+ resetSettings: async () => {
+ await get().fetchSettings(true);
+ },
- // Apply the update optimistically
- set((state) => {
- updater(state.settings);
- });
+ updateFrequencyRangeOrder: (newOrder) => {
+ set({ frequencyRangeOrder: newOrder });
+ },
- const newSettings = get().settings;
+ uploadSettings: async (updater, validateOnly = false) => {
+ if (updater) {
+ set((state) => {
+ updater(state.settings);
+ });
+ }
+
+ const currentSettings = get().settings;
try {
- const result = await uploadSettingsToServer(newSettings);
+ const response = await fetch(
+ getBackendURL(
+ `/api/settings${validateOnly ? "?validate_only=true" : ""}`
+ ),
+ {
+ method: "POST",
+ headers: {
+ "Content-Type": "application/json",
+ },
+ body: JSON.stringify(currentSettings),
+ }
+ );
+
+ const data = await response.json();
- if (!result.success) {
- // Revert the local state if the server update failed
- set({ settings: currentSettings });
+ if (!response.ok) {
+ throw new Error("Failed to upload settings to backend");
}
- return result;
+ if (data.valid) {
+ // Settings are valid
+ set({
+ lastValidSettings: currentSettings,
+ validationErrors: null,
+ });
+ return true;
+ } else {
+ // Settings are invalid
+ set({
+ validationErrors: data.errors,
+ });
+ // Note: We don't revert the settings here, keeping the potentially invalid state
+ return false;
+ }
} catch (error) {
- // Revert the local state if there was an error
- set({ settings: currentSettings });
- throw error;
+ console.error(
+ `Error ${validateOnly ? "validating" : "updating"} settings:`,
+ error
+ );
+ return false;
}
},
}));
diff --git a/gui_dev/src/stores/socketStore.js b/gui_dev/src/stores/socketStore.js
index 4e6d7264..8e091260 100644
--- a/gui_dev/src/stores/socketStore.js
+++ b/gui_dev/src/stores/socketStore.js
@@ -1,5 +1,5 @@
import { createStore } from "./createStore";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
import CBOR from "cbor-js";
const WEBSOCKET_URL = getBackendURL("/ws");
@@ -89,16 +89,15 @@ export const useSocketStore = createStore("socket", (set, get) => ({
// check if this is the same:
Object.entries(decodedData).forEach(([key, value]) => {
- (key.startsWith("decode") ? decodingData : dataNonDecodingFeatures)[key] = value;
+ (key.startsWith("decode") ? decodingData : dataNonDecodingFeatures)[
+ key
+ ] = value;
});
set({ availableDecodingOutputs: Object.keys(decodingData) });
-
-
set({ graphDecodingData: decodingData });
set({ graphData: dataNonDecodingFeatures });
-
}
} catch (error) {
console.error("Failed to decode CBOR message:", error);
diff --git a/gui_dev/src/stores/uiStore.js b/gui_dev/src/stores/uiStore.js
index a229ec06..0462871f 100644
--- a/gui_dev/src/stores/uiStore.js
+++ b/gui_dev/src/stores/uiStore.js
@@ -1,6 +1,7 @@
-import { createPersistStore } from "./createStore";
+import { createStore } from "./createStore";
+import { useEffect } from "react";
-export const useUiStore = createPersistStore("ui", (set, get) => ({
+export const useUiStore = createStore("ui", (set, get) => ({
activeDrawer: null,
toggleDrawer: (drawerName) =>
set((state) => {
@@ -28,4 +29,25 @@ export const useUiStore = createPersistStore("ui", (set, get) => ({
state.accordionStates[id] = defaultState;
}
}),
+
+ // Hook to inject UI elements into the status bar
+ getStatusBarContent: () => null,
+ setStatusBarContent: (contentGetter) =>
+ set({ getStatusBarContent: contentGetter }),
+ clearStatusBarContent: () => set({ getStatusBarContent: () => null }),
}));
+
+// Use this hook from Page components to inject page-specific UI elements into the status bar
+export const useStatusBar = (content) => {
+ const createStatusBarContent = () => content;
+
+ const setStatusBarContent = useUiStore((state) => state.setStatusBarContent);
+ const clearStatusBarContent = useUiStore(
+ (state) => state.clearStatusBarContent
+ );
+
+ useEffect(() => {
+ setStatusBarContent(createStatusBarContent);
+ return () => clearStatusBarContent();
+ }, [content, setStatusBarContent, clearStatusBarContent]);
+};
diff --git a/gui_dev/src/theme.js b/gui_dev/src/theme.js
index 5e2b5f11..cc99909c 100644
--- a/gui_dev/src/theme.js
+++ b/gui_dev/src/theme.js
@@ -22,6 +22,11 @@ export const theme = createTheme({
disableRipple: true,
},
},
+ MuiTextField: {
+ defaultProps: {
+ autoComplete: "off",
+ },
+ },
MuiStack: {
defaultProps: {
alignItems: "center",
diff --git a/gui_dev/src/utils/FileInfo.js b/gui_dev/src/utils/FileInfo.js
new file mode 100644
index 00000000..74ae8347
--- /dev/null
+++ b/gui_dev/src/utils/FileInfo.js
@@ -0,0 +1,137 @@
+/**
+ * Represents information about a file or directory in the system
+ */
+export class FileInfo {
+ /**
+ * Creates a new FileInfo instance
+ * @param {Object} params - The parameters to initialize the FileInfo object
+ * @param {string} [params.name=''] - The name of the file or directory
+ * @param {string} [params.path=''] - The full path of the file or directory
+ * @param {string} [params.dir=''] - The directory containing the file or directory
+ * @param {boolean} [params.is_directory=false] - Whether the entry is a directory
+ * @param {number} [params.size=0] - The size of the file in bytes (0 for directories)
+ * @param {string} [params.created_at=''] - The creation timestamp of the file
+ * @param {string} [params.modified_at=''] - The last modification timestamp of the file
+ */
+ constructor({
+ name = "",
+ path = "",
+ dir = "",
+ is_directory = false,
+ size = 0,
+ created_at = "",
+ modified_at = "",
+ } = {}) {
+ this.name = name;
+ this.path = path;
+ this.dir = dir;
+ this.is_directory = is_directory;
+ this.size = size;
+ this.created_at = created_at;
+ this.modified_at = modified_at;
+ }
+
+ /**
+ * Creates a FileInfo instance from a plain object
+ * @param {Object} obj - The object containing file information
+ * @returns {FileInfo} A new FileInfo instance
+ */
+ static fromObject(obj) {
+ return new FileInfo(obj);
+ }
+
+ /**
+ * Resets all properties to their default values
+ */
+ reset() {
+ Object.assign(this, new FileInfo());
+ }
+
+ /**
+ * Updates the FileInfo instance with new values
+ * @param {Partial} updates - The properties to update
+ */
+ update(updates) {
+ Object.assign(this, updates);
+ }
+
+ /**
+ * Gets the file extension
+ * @returns {string} The file extension (empty string for directories)
+ */
+ getExtension() {
+ if (this.is_directory) return "";
+ const ext = this.name.split(".").pop();
+ return ext === this.name ? "" : ext;
+ }
+
+ /**
+ * Gets the base name without extension
+ * @returns {string} The base name
+ */
+ getBaseName() {
+ if (this.is_directory) return this.name;
+ const lastDotIndex = this.name.lastIndexOf(".");
+ return lastDotIndex === -1 ? this.name : this.name.slice(0, lastDotIndex);
+ }
+
+ /**
+ * Formats the file size in a human-readable format
+ * @returns {string} The formatted file size
+ */
+ getFormattedSize() {
+ if (this.is_directory) return "-";
+ const units = ["B", "KB", "MB", "GB", "TB"];
+ let size = this.size;
+ let unitIndex = 0;
+
+ while (size >= 1024 && unitIndex < units.length - 1) {
+ size /= 1024;
+ unitIndex++;
+ }
+
+ return `${Math.round(size * 100) / 100} ${units[unitIndex]}`;
+ }
+
+ /**
+ * Checks if the file/directory is hidden
+ * @returns {boolean} Whether the file/directory is hidden
+ */
+ isHidden() {
+ return this.name.startsWith(".");
+ }
+
+ /**
+ * Creates a plain object representation of the FileInfo instance
+ * @returns {Object} A plain object containing the file information
+ */
+ toObject() {
+ return {
+ name: this.name,
+ path: this.path,
+ dir: this.dir,
+ is_directory: this.is_directory,
+ size: this.size,
+ created_at: this.created_at,
+ modified_at: this.modified_at,
+ };
+ }
+
+ /**
+ * Creates a clone of the FileInfo instance
+ * @returns {FileInfo} A new FileInfo instance with the same values
+ */
+ clone() {
+ return new FileInfo(this.toObject());
+ }
+
+ /**
+ * Compares this FileInfo instance with another
+ * @param {FileInfo} other - The other FileInfo instance to compare with
+ * @returns {boolean} Whether the two instances have the same values
+ */
+ equals(other) {
+ if (!(other instanceof FileInfo)) return false;
+ return JSON.stringify(this.toObject()) === JSON.stringify(other.toObject());
+ }
+}
diff --git a/gui_dev/src/utils/FileManager.js b/gui_dev/src/utils/FileManager.js
index 9065a3ad..aa625b5c 100644
--- a/gui_dev/src/utils/FileManager.js
+++ b/gui_dev/src/utils/FileManager.js
@@ -1,14 +1,4 @@
-/**
- * @typedef {Object} FileInfo
- * @property {string} name - The name of the file or directory
- * @property {string} path - The full path of the file or directory
- * @property {string} dir - The directory containing the file or directory
- * @property {boolean} is_directory - Whether the entry is a directory
- * @property {number} size - The size of the file in bytes (0 for directories)
- * @property {string} created_at - The creation timestamp of the file
- * @property {string} modified_at - The last modification timestamp of the file
- */
-import { getBackendURL } from "@/utils/getBackendURL";
+import { FileInfo } from "./FileInfo";
/**
* Manages file operations and interactions with the file API
@@ -41,13 +31,14 @@ export class FileManager {
show_hidden: showHidden,
});
- const response = await fetch(getBackendURL(`${this.apiBaseUrl}/api/files?${queryParams}`));
+ const response = await fetch(`${this.apiBaseUrl}?${queryParams}`);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
- return await response.json();
+ const filesData = await response.json();
+ return filesData.map((fileData) => FileInfo.fromObject(fileData));
}
/**
diff --git a/gui_dev/src/utils/debounced_sync.js b/gui_dev/src/utils/debounced_sync.js
index bdc2054f..0ad34469 100644
--- a/gui_dev/src/utils/debounced_sync.js
+++ b/gui_dev/src/utils/debounced_sync.js
@@ -1,5 +1,5 @@
import { debounce } from "@/utils";
-import { getBackendURL } from "@/utils/getBackendURL";
+import { getBackendURL } from "@/utils";
const DEBOUNCE_MS = 500; // Adjust as needed
@@ -22,28 +22,28 @@ const syncWithBackend = async (state) => {
const debouncedSync = debounce(syncWithBackend, DEBOUNCE_MS);
- /*****************************/
- /******** BACKEND SYNC *******/
- /*****************************/
-
- // Wrap state updates with sync logic
- setState: async (newState) => {
- set((state) => ({ ...state, ...newState, syncStatus: "syncing" }));
- try {
- await debouncedSync(get());
- set({ syncStatus: "synced", syncError: null });
- } catch (error) {
- set({ syncStatus: "error", syncError: error.message });
- }
- },
-
- // // Use this for actions that need immediate sync
- // setStateAndSync: async (newState) => {
- // set((state) => ({ ...state, ...newState, syncStatus: "syncing" }));
- // try {
- // const syncedState = await syncWithBackend(get());
- // set({ ...syncedState, syncStatus: "synced", syncError: null });
- // } catch (error) {
- // set({ syncStatus: "error", syncError: error.message });
- // }
- // },
\ No newline at end of file
+/*****************************/
+/******** BACKEND SYNC *******/
+/*****************************/
+
+// Wrap state updates with sync logic
+setState: async (newState) => {
+ set((state) => ({ ...state, ...newState, syncStatus: "syncing" }));
+ try {
+ await debouncedSync(get());
+ set({ syncStatus: "synced", syncError: null });
+ } catch (error) {
+ set({ syncStatus: "error", syncError: error.message });
+ }
+},
+
+// // Use this for actions that need immediate sync
+// setStateAndSync: async (newState) => {
+// set((state) => ({ ...state, ...newState, syncStatus: "syncing" }));
+// try {
+// const syncedState = await syncWithBackend(get());
+// set({ ...syncedState, syncStatus: "synced", syncError: null });
+// } catch (error) {
+// set({ syncStatus: "error", syncError: error.message });
+// }
+// }
\ No newline at end of file
diff --git a/gui_dev/src/utils/getBackendURL.js b/gui_dev/src/utils/getBackendURL.js
deleted file mode 100644
index b35c21de..00000000
--- a/gui_dev/src/utils/getBackendURL.js
+++ /dev/null
@@ -1,3 +0,0 @@
-export const getBackendURL = (route) => {
- return "http://localhost:50001" + route;
-}
diff --git a/gui_dev/src/utils/index.js b/gui_dev/src/utils/index.js
index f9810c70..f257e714 100644
--- a/gui_dev/src/utils/index.js
+++ b/gui_dev/src/utils/index.js
@@ -1 +1 @@
-export * from "./functions";
+export * from "./utils";
diff --git a/gui_dev/src/utils/functions.js b/gui_dev/src/utils/utils.js
similarity index 69%
rename from gui_dev/src/utils/functions.js
rename to gui_dev/src/utils/utils.js
index 04289d5c..c81709d9 100644
--- a/gui_dev/src/utils/functions.js
+++ b/gui_dev/src/utils/utils.js
@@ -17,6 +17,7 @@ export function debounce(func, wait) {
timeout = setTimeout(later, wait);
};
}
+
export const flattenDictionary = (dict, parentKey = "", result = {}) => {
for (let key in dict) {
const newKey = parentKey ? `${parentKey}.${key}` : key;
@@ -32,7 +33,7 @@ export const flattenDictionary = (dict, parentKey = "", result = {}) => {
export const filterObjectByKeys = (flatDict, keys) => {
const filteredDict = {};
keys.forEach((key) => {
- if (flatDict.hasOwnProperty(key)) {
+ if (Object.hasOwn(flatDict, key)) {
filteredDict[key] = flatDict[key];
}
});
@@ -48,3 +49,22 @@ export const filterObjectByKeyPrefix = (obj, prefix = "") => {
}
return result;
};
+
+export const getBackendURL = (route) => {
+ return "http://localhost:" + import.meta.env.VITE_BACKEND_PORT + route;
+};
+
+/**
+ * Fetches PyNeuromodulation directory from the backend
+ * @returns {string} PyNeuromodulation directory
+ */
+export const getPyNMDirectory = async () => {
+ const response = await fetch(getBackendURL("/api/pynm_dir"));
+ if (!response.ok) {
+ throw new Error("Failed to fetch settings");
+ }
+
+ const data = await response.json();
+
+ return data.pynm_dir;
+};
diff --git a/gui_dev/vite.config.js b/gui_dev/vite.config.js
index f0abbec9..792593d7 100644
--- a/gui_dev/vite.config.js
+++ b/gui_dev/vite.config.js
@@ -48,8 +48,5 @@ export default defineConfig(() => {
},
},
},
- server: {
- port: 54321,
- },
};
});
diff --git a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
index fa97e424..4f4409f6 100644
--- a/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
+++ b/py_neuromodulation/ConnectivityDecoding/_get_grid_whole_brain.py
@@ -70,7 +70,6 @@ def select_non_zero_voxels(
coord_arr = np.array(coord_)
ival_non_zero = ival_arr[ival != 0]
coord_non_zero = coord_arr[ival != 0]
- print(coord_non_zero.shape)
return coord_non_zero, ival_non_zero
diff --git a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
index 2238ff6b..89cd5b2e 100644
--- a/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
+++ b/py_neuromodulation/ConnectivityDecoding/_helper_write_connectome.py
@@ -58,8 +58,6 @@ def write_connectome_mat(
dict_connectome[
f[f.find("ROI-") + 4 : f.find("_func_seed_AvgR_Fz.nii")]
] = fp
-
- print(idx)
# save the dictionary
sio.savemat(
PATH_CONNECTOME,
diff --git a/py_neuromodulation/__init__.py b/py_neuromodulation/__init__.py
index 688393fd..864df345 100644
--- a/py_neuromodulation/__init__.py
+++ b/py_neuromodulation/__init__.py
@@ -4,11 +4,12 @@
from importlib.metadata import version
from py_neuromodulation.utils.logging import NMLogger
+
#####################################
# Globals and environment variables #
#####################################
-__version__ = version("py_neuromodulation") # get version from pyproject.toml
+__version__ = version("py_neuromodulation")
# Check if the module is running headless (no display) for tests and doc builds
PYNM_HEADLESS: bool = not os.environ.get("DISPLAY")
@@ -18,6 +19,7 @@
os.environ["LSLAPICFG"] = str(PYNM_DIR / "lsl_api.cfg") # LSL config file
+
# Set environment variable MNE_LSL_LIB (required to import Stream below)
LSL_DICT = {
"windows_32bit": "windows/x86/liblsl.1.16.2.dll",
@@ -36,6 +38,7 @@
PLATFORM = platform.system().lower().strip()
ARCH = platform.architecture()[0]
+
match PLATFORM:
case "windows":
KEY = PLATFORM + "_" + ARCH
diff --git a/py_neuromodulation/analysis/decode.py b/py_neuromodulation/analysis/decode.py
index 4971852b..722f7192 100644
--- a/py_neuromodulation/analysis/decode.py
+++ b/py_neuromodulation/analysis/decode.py
@@ -6,42 +6,66 @@
import pandas as pd
import numpy as np
from copy import deepcopy
-from pathlib import PurePath
+from pathlib import Path
import pickle
from py_neuromodulation import logger
+from py_neuromodulation.utils.types import _PathLike
from typing import Callable
class RealTimeDecoder:
-
- def __init__(self, model_path: str):
- self.model_path = model_path
- if model_path.endswith(".skops"):
+ def __init__(self, model_path: _PathLike):
+ self.model_path = Path(model_path)
+ if not self.model_path.exists():
+ raise FileNotFoundError(f"Model file {self.model_path} not found")
+ if not self.model_path.is_file():
+ raise IsADirectoryError(f"Model file {self.model_path} is a directory")
+
+ if self.model_path.suffix == ".skops":
from skops import io as skops_io
- self.model = skops_io.load(model_path)
+
+ self.model = skops_io.load(self.model_path)
else:
return NotImplementedError("Only skops models are supported")
- def predict(self, feature_dict: dict, channel: str = None, fft_bands_only: bool = True ) -> dict:
+ def predict(
+ self,
+ feature_dict: dict,
+ channel: str | None = None,
+ fft_bands_only: bool = True,
+ ) -> dict:
try:
if channel is not None:
- features_ch = {f: feature_dict[f] for f in feature_dict.keys() if f.startswith(channel)}
+ features_ch = {
+ f: feature_dict[f]
+ for f in feature_dict.keys()
+ if f.startswith(channel)
+ }
if fft_bands_only is True:
- features_ch_fft = {f: features_ch[f] for f in features_ch.keys() if "fft" in f and "psd" not in f}
- out = self.model.predict_proba(np.array(list(features_ch_fft.values())).reshape(1, -1))
+ features_ch_fft = {
+ f: features_ch[f]
+ for f in features_ch.keys()
+ if "fft" in f and "psd" not in f
+ }
+ out = self.model.predict_proba(
+ np.array(list(features_ch_fft.values())).reshape(1, -1)
+ )
else:
out = self.model.predict_proba(features_ch)
else:
out = self.model.predict(feature_dict)
for decode_output_idx in range(out.shape[1]):
- feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[decode_output_idx]
+ feature_dict[f"decode_{decode_output_idx}"] = np.squeeze(out)[
+ decode_output_idx
+ ]
return feature_dict
except Exception as e:
logger.error(f"Error in decoding: {e}")
return feature_dict
+
class CV_res:
def __init__(
self,
diff --git a/py_neuromodulation/default_settings.yaml b/py_neuromodulation/default_settings.yaml
index c50b5f8e..e68194f8 100644
--- a/py_neuromodulation/default_settings.yaml
+++ b/py_neuromodulation/default_settings.yaml
@@ -1,4 +1,15 @@
----
+# Default settings for py_neuromodulation.
+# This file defines the default values for the stream settings:
+# general settings, preprocessing options, per-feature settings
+# (each grouped under a key named after the feature, e.g.
+# sharpwave_analysis_settings) and postprocessing options.
+# Boolean flags enable or disable individual features.
+# Frequency ranges are given as [low_hz, high_hz] pairs.
+# These defaults can be overridden at runtime by loading a custom
+# settings file or by modifying the settings object programmatically.
+# Settings are validated against the NMSettings pydantic model, which
+# reports invalid values with their location in this structure.
+
########################
### General settings ###
########################
@@ -51,12 +62,8 @@ preprocessing_filter:
lowpass_filter: true
highpass_filter: true
bandpass_filter: true
- bandstop_filter_settings:
- frequency_low_hz: 100
- frequency_high_hz: 160
- bandpass_filter_settings:
- frequency_low_hz: 3
- frequency_high_hz: 200
+ bandstop_filter_settings: [100, 160] # [low_hz, high_hz]
+ bandpass_filter_settings: [3, 200] # [low_hz, high_hz]
lowpass_filter_cutoff_hz: 200
highpass_filter_cutoff_hz: 3
@@ -162,11 +169,9 @@ sharpwave_analysis_settings:
rise_steepness: false
decay_steepness: false
slope_ratio: false
- filter_ranges_hz:
- - frequency_low_hz: 5
- frequency_high_hz: 80
- - frequency_low_hz: 5
- frequency_high_hz: 30
+ filter_ranges_hz: # list of [low_hz, high_hz]
+ - [5, 80]
+ - [5, 30]
detect_troughs:
estimate: true
distance_troughs_ms: 10
@@ -175,6 +180,7 @@ sharpwave_analysis_settings:
estimate: true
distance_troughs_ms: 5
distance_peaks_ms: 10
+ # TONI: Reverse this setting? e.g. interval: [mean, var]
estimator:
mean: [interval]
median: []
@@ -223,8 +229,8 @@ nolds_settings:
frequency_bands: [low_beta]
mne_connectiviy_settings:
- method: plv
- mode: multitaper
+ method: plv # One of ['coh', 'cohy', 'imcoh', 'cacoh', 'mic', 'mim', 'plv', 'ciplv', 'ppc', 'pli', 'dpli','wpli', 'wpli2_debiased', 'gc', 'gc_tr']
+ mode: multitaper # One of ['multitaper', 'fourier', 'cwt_morlet']
bispectrum_settings:
f1s: [5, 35]
diff --git a/py_neuromodulation/features/bandpower.py b/py_neuromodulation/features/bandpower.py
index dac13f4b..72dab126 100644
--- a/py_neuromodulation/features/bandpower.py
+++ b/py_neuromodulation/features/bandpower.py
@@ -2,8 +2,13 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING
from pydantic import field_validator
-
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import (
+ NMField,
+ NMErrorList,
+ create_validation_error,
+)
+from py_neuromodulation import logger
if TYPE_CHECKING:
from py_neuromodulation.stream.settings import NMSettings
@@ -17,15 +22,18 @@ class BandpowerFeatures(BoolSelector):
class BandPowerSettings(NMBaseModel):
- segment_lengths_ms: dict[str, int] = {
- "theta": 1000,
- "alpha": 500,
- "low beta": 333,
- "high beta": 333,
- "low gamma": 100,
- "high gamma": 100,
- "HFA": 100,
- }
+ segment_lengths_ms: dict[str, int] = NMField(
+ default={
+ "theta": 1000,
+ "alpha": 500,
+ "low beta": 333,
+ "high beta": 333,
+ "low gamma": 100,
+ "high gamma": 100,
+ "HFA": 100,
+ },
+ custom_metadata={"field_type": "FrequencySegmentLength"},
+ )
bandpower_features: BandpowerFeatures = BandpowerFeatures()
log_transform: bool = True
kalman_filter: bool = False
@@ -33,24 +41,58 @@ class BandPowerSettings(NMBaseModel):
@field_validator("bandpower_features")
@classmethod
def bandpower_features_validator(cls, bandpower_features: BandpowerFeatures):
- assert (
- len(bandpower_features.get_enabled()) > 0
- ), "Set at least one bandpower_feature to True."
-
+ if not len(bandpower_features.get_enabled()) > 0:
+ raise create_validation_error(
+ error_message="Set at least one bandpower_feature to True.",
+ location=["bandpass_filter_settings", "bandpower_features"],
+ )
return bandpower_features
- def validate_fbands(self, settings: "NMSettings") -> None:
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
+ """_summary_
+
+ :param settings: _description_
+ :type settings: NMSettings
+ :raises create_validation_error: _description_
+ :raises create_validation_error: _description_
+ :raises ValueError: _description_
+ """
+ errors: NMErrorList = NMErrorList()
+
for fband_name, seg_length_fband in self.segment_lengths_ms.items():
- assert seg_length_fband <= settings.segment_length_features_ms, (
- f"segment length {seg_length_fband} needs to be smaller than "
- f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}"
- )
+ # Check that all frequency bands are defined in settings.frequency_ranges_hz
+ if fband_name not in settings.frequency_ranges_hz:
+ logger.warning(
+ f"Frequency band {fband_name} in bandpass_filter_settings.segment_lengths_ms"
+ " is not defined in settings.frequency_ranges_hz"
+ )
+
+ # Check that all segment lengths are smaller than settings.segment_length_features_ms
+ if not seg_length_fband <= settings.segment_length_features_ms:
+ errors.add_error(
+ f"segment length {seg_length_fband} needs to be smaller than "
+ f" settings['segment_length_features_ms'] = {settings.segment_length_features_ms}",
+ location=[
+ "bandpass_filter_settings",
+ "segment_lengths_ms",
+ fband_name,
+ ],
+ )
+ # Check that all frequency bands defined in settings.frequency_ranges_hz
for fband_name in settings.frequency_ranges_hz.keys():
- assert fband_name in self.segment_lengths_ms, (
- f"frequency range {fband_name} "
- "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms]"
- )
+ if fband_name not in self.segment_lengths_ms:
+ errors.add_error(
+ f"frequency range {fband_name} "
+ "needs to be defined in settings.bandpass_filter_settings.segment_lengths_ms",
+ location=[
+ "bandpass_filter_settings",
+ "segment_lengths_ms",
+ fband_name,
+ ],
+ )
+
+ return errors
class BandPower(NMFeature):
diff --git a/py_neuromodulation/features/bursts.py b/py_neuromodulation/features/bursts.py
index 83d63bf8..e9f5b632 100644
--- a/py_neuromodulation/features/bursts.py
+++ b/py_neuromodulation/features/bursts.py
@@ -7,11 +7,12 @@
from collections.abc import Sequence
from itertools import product
-from pydantic import Field, field_validator
+from pydantic import field_validator
from py_neuromodulation.utils.types import BoolSelector, NMBaseModel, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING, Callable
-from py_neuromodulation.utils.types import create_validation_error
+from py_neuromodulation.utils.pydantic_extensions import create_validation_error
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -46,8 +47,8 @@ class BurstFeatures(BoolSelector):
class BurstsSettings(NMBaseModel):
- threshold: float = Field(default=75, ge=0, le=100)
- time_duration_s: float = Field(default=30, ge=0)
+ threshold: float = NMField(default=75, ge=0)
+ time_duration_s: float = NMField(default=30, ge=0, custom_metadata={"unit": "s"})
frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
burst_features: BurstFeatures = BurstFeatures()
diff --git a/py_neuromodulation/features/coherence.py b/py_neuromodulation/features/coherence.py
index 21ca471b..6a100710 100644
--- a/py_neuromodulation/features/coherence.py
+++ b/py_neuromodulation/features/coherence.py
@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Annotated
-from pydantic import Field, field_validator
+from pydantic import field_validator
from py_neuromodulation.utils.types import (
NMFeature,
@@ -11,6 +11,7 @@
FrequencyRange,
NMBaseModel,
)
+from py_neuromodulation.utils.pydantic_extensions import NMField
from py_neuromodulation import logger
if TYPE_CHECKING:
@@ -28,14 +29,16 @@ class CoherenceFeatures(BoolSelector):
max_allfbands: bool = True
-ListOfTwoStr = Annotated[list[str], Field(min_length=2, max_length=2)]
+# TODO: make this into a pydantic model that only accepts names from
+# the channels objects and does not accept the same string twice
+ListOfTwoStr = Annotated[list[str], NMField(min_length=2, max_length=2)]
class CoherenceSettings(NMBaseModel):
features: CoherenceFeatures = CoherenceFeatures()
method: CoherenceMethods = CoherenceMethods()
channels: list[ListOfTwoStr] = []
- frequency_bands: list[str] = Field(default=["high_beta"], min_length=1)
+ frequency_bands: list[str] = NMField(default=["high_beta"], min_length=1)
@field_validator("frequency_bands")
def fbands_spaces_to_underscores(cls, frequency_bands):
diff --git a/py_neuromodulation/features/feature_processor.py b/py_neuromodulation/features/feature_processor.py
index 9e970e8f..dbb7e5c9 100644
--- a/py_neuromodulation/features/feature_processor.py
+++ b/py_neuromodulation/features/feature_processor.py
@@ -1,13 +1,13 @@
from typing import Type, TYPE_CHECKING
-from py_neuromodulation.utils.types import NMFeature, FeatureName
+from py_neuromodulation.utils.types import NMFeature, FEATURE_NAME
if TYPE_CHECKING:
import numpy as np
from py_neuromodulation import NMSettings
-FEATURE_DICT: dict[FeatureName | str, str] = {
+FEATURE_DICT: dict[FEATURE_NAME | str, str] = {
"raw_hjorth": "Hjorth",
"return_raw": "Raw",
"bandpass_filter": "BandPower",
@@ -42,7 +42,7 @@ def __init__(
from importlib import import_module
# Accept 'str' for custom features
- self.features: dict[FeatureName | str, NMFeature] = {
+ self.features: dict[FEATURE_NAME | str, NMFeature] = {
feature_name: getattr(
import_module("py_neuromodulation.features"), FEATURE_DICT[feature_name]
)(settings, ch_names, sfreq)
@@ -83,7 +83,7 @@ def estimate_features(self, data: "np.ndarray") -> dict:
return feature_results
- def get_feature(self, fname: FeatureName) -> NMFeature:
+ def get_feature(self, fname: FEATURE_NAME) -> NMFeature:
return self.features[fname]
diff --git a/py_neuromodulation/features/fooof.py b/py_neuromodulation/features/fooof.py
index 683c5304..9f4b9544 100644
--- a/py_neuromodulation/features/fooof.py
+++ b/py_neuromodulation/features/fooof.py
@@ -9,6 +9,7 @@
BoolSelector,
FrequencyRange,
)
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -29,11 +30,11 @@ class FooofPeriodicSettings(BoolSelector):
class FooofSettings(NMBaseModel):
aperiodic: FooofAperiodicSettings = FooofAperiodicSettings()
periodic: FooofPeriodicSettings = FooofPeriodicSettings()
- windowlength_ms: float = 800
+ windowlength_ms: float = NMField(800, gt=0, custom_metadata={"unit": "ms"})
peak_width_limits: FrequencyRange = FrequencyRange(0.5, 12)
- max_n_peaks: int = 3
- min_peak_height: float = 0
- peak_threshold: float = 2
+ max_n_peaks: int = NMField(3, ge=0)
+ min_peak_height: float = NMField(0, ge=0)
+ peak_threshold: float = NMField(2, ge=0)
freq_range_hz: FrequencyRange = FrequencyRange(2, 40)
knee: bool = True
diff --git a/py_neuromodulation/features/mne_connectivity.py b/py_neuromodulation/features/mne_connectivity.py
index 88883771..f7a186e7 100644
--- a/py_neuromodulation/features/mne_connectivity.py
+++ b/py_neuromodulation/features/mne_connectivity.py
@@ -1,8 +1,9 @@
from collections.abc import Iterable
import numpy as np
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
from py_neuromodulation.utils.types import NMFeature, NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
@@ -10,9 +11,30 @@
from mne import Epochs
+MNE_CONNECTIVITY_METHOD = Literal[
+ "coh",
+ "cohy",
+ "imcoh",
+ "cacoh",
+ "mic",
+ "mim",
+ "plv",
+ "ciplv",
+ "ppc",
+ "pli",
+ "dpli",
+ "wpli",
+ "wpli2_debiased",
+ "gc",
+ "gc_tr",
+]
+
+MNE_CONNECTIVITY_MODE = Literal["multitaper", "fourier", "cwt_morlet"]
+
+
class MNEConnectivitySettings(NMBaseModel):
- method: str = "plv"
- mode: str = "multitaper"
+ method: MNE_CONNECTIVITY_METHOD = NMField(default="plv")
+ mode: MNE_CONNECTIVITY_MODE = NMField(default="multitaper")
class MNEConnectivity(NMFeature):
diff --git a/py_neuromodulation/features/oscillatory.py b/py_neuromodulation/features/oscillatory.py
index ba376cc6..bb4b9e6a 100644
--- a/py_neuromodulation/features/oscillatory.py
+++ b/py_neuromodulation/features/oscillatory.py
@@ -3,6 +3,7 @@
from itertools import product
from py_neuromodulation.utils.types import NMBaseModel, BoolSelector, NMFeature
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -17,7 +18,7 @@ class OscillatoryFeatures(BoolSelector):
class OscillatorySettings(NMBaseModel):
- windowlength_ms: int = 1000
+ windowlength_ms: int = NMField(1000, gt=0, custom_metadata={"unit": "ms"})
log_transform: bool = True
features: OscillatoryFeatures = OscillatoryFeatures(
mean=True, median=False, std=False, max=False
diff --git a/py_neuromodulation/filter/kalman_filter.py b/py_neuromodulation/filter/kalman_filter.py
index b7ae1075..0ebfe195 100644
--- a/py_neuromodulation/filter/kalman_filter.py
+++ b/py_neuromodulation/filter/kalman_filter.py
@@ -1,7 +1,9 @@
import numpy as np
from typing import TYPE_CHECKING
+
from py_neuromodulation.utils.types import NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMErrorList
if TYPE_CHECKING:
@@ -22,13 +24,22 @@ class KalmanSettings(NMBaseModel):
"HFA",
]
- def validate_fbands(self, settings: "NMSettings") -> None:
- assert all(
+ def validate_fbands(self, settings: "NMSettings") -> NMErrorList:
+ errors: NMErrorList = NMErrorList()
+
+ if not all(
(item in settings.frequency_ranges_hz for item in self.frequency_bands)
- ), (
- "Frequency bands for Kalman filter must also be specified in "
- "bandpass_filter_settings."
- )
+ ):
+ errors.add_error(
+ "Frequency bands for Kalman filter must also be specified in "
+ "frequency_ranges_hz.",
+ location=[
+ "kalman_filter_settings",
+ "frequency_bands",
+ ],
+ )
+
+ return errors
def define_KF(Tp, sigma_w, sigma_v):
diff --git a/py_neuromodulation/gui/backend/app_backend.py b/py_neuromodulation/gui/backend/app_backend.py
index 06b929f0..f64e837d 100644
--- a/py_neuromodulation/gui/backend/app_backend.py
+++ b/py_neuromodulation/gui/backend/app_backend.py
@@ -1,5 +1,4 @@
import logging
-import asyncio
import importlib.metadata
from datetime import datetime
from pathlib import Path
@@ -13,9 +12,10 @@
)
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
+from pydantic import ValidationError
from . import app_pynm
-from .app_socket import WebSocketManager
+from .app_socket import WebsocketManager
from .app_utils import is_hidden, get_quick_access
import pandas as pd
@@ -29,9 +29,9 @@
class PyNMBackend(FastAPI):
def __init__(
self,
- pynm_state: app_pynm.PyNMState,
debug=False,
dev=True,
+ dev_port: int | None = None,
fastapi_kwargs: dict = {},
) -> None:
super().__init__(debug=debug, **fastapi_kwargs)
@@ -43,14 +43,18 @@ def __init__(
self.logger = logging.getLogger("uvicorn.error")
self.logger.warning(PYNM_DIR)
- # Configure CORS
- self.add_middleware(
- CORSMiddleware,
- allow_origins=["http://localhost:54321"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
- )
+ if dev:
+ cors_origins = (
+ ["http://localhost:" + str(dev_port)] if dev_port is not None else []
+ )
+ # Configure CORS
+ self.add_middleware(
+ CORSMiddleware,
+ allow_origins=cors_origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
# Has to be before mounting static files
self.setup_routes()
@@ -63,8 +67,8 @@ def __init__(
name="static",
)
- self.pynm_state = pynm_state
- self.websocket_manager = WebSocketManager()
+ self.websocket_manager = WebsocketManager()
+ self.pynm_state = app_pynm.PyNMState()
def setup_routes(self):
@self.get("/api/health")
@@ -74,21 +78,63 @@ async def healthcheck():
####################
##### SETTINGS #####
####################
-
@self.get("/api/settings")
- async def get_settings():
- return self.pynm_state.settings.process_for_frontend()
+ async def get_settings(
+ reset: bool = Query(False, description="Reset settings to default"),
+ ):
+ if reset:
+ settings = NMSettings.get_default()
+ else:
+ settings = self.pynm_state.settings
+
+ return settings.serialize_with_metadata()
@self.post("/api/settings")
- async def update_settings(data: dict):
+ async def update_settings(data: dict, validate_only: bool = Query(False)):
try:
- self.pynm_state.settings = NMSettings.model_validate(data)
- self.logger.info(self.pynm_state.settings.features)
- return self.pynm_state.settings.model_dump()
- except ValueError as e:
+ # First, validate with Pydantic
+ try:
+ # TODO: check if this works properly or needs model_validate_strings
+ validated_settings = NMSettings.model_validate(data)
+ except ValidationError as e:
+ self.logger.error(f"Error validating settings: {e}")
+ if not validate_only:
+ # If validation failed but we wanted to upload, return error
+ raise HTTPException(
+ status_code=422,
+ detail={
+ "error": "Error validating settings",
+ "details": str(e),
+ },
+ )
+ # Else return list of errors
+ return {
+ "valid": False,
+ "errors": [err for err in e.errors()],
+ "details": str(e),
+ }
+
+ # If validation successful, return or update settings
+ if validate_only:
+ return {
+ "valid": True,
+ "settings": validated_settings.serialize_with_metadata(),
+ }
+
+ self.pynm_state.settings = validated_settings
+ self.logger.info("Settings successfully updated")
+
+ return {
+ "valid": True,
+ "settings": self.pynm_state.settings.serialize_with_metadata(),
+ }
+
+ # If something else than validation went wrong, return error
+ except Exception as e:
+ self.logger.error(f"Error validating/updating settings: {e}")
raise HTTPException(
status_code=422,
- detail={"error": "Validation failed", "details": str(e)},
+ detail={"error": "Error uploading settings", "details": str(e)},
)
########################
@@ -105,14 +151,12 @@ async def handle_stream_control(data: dict):
self.logger.info("Starting stream")
self.pynm_state.start_run_function(
- websocket_manager=self.websocket_manager,
+ websocket_manager=self.websocket_manager,
)
-
if action == "stop":
self.logger.info("Stopping stream")
- self.pynm_state.stream_handling_queue.put("stop")
- self.pynm_state.stop_event_ws.set()
+ self.pynm_state.stop_run_function()
return {"message": f"Stream action '{action}' executed"}
@@ -274,6 +318,14 @@ async def home_directory():
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
+ # Get PYNM_DIR
+ @self.get("/api/pynm_dir")
+ async def get_pynm_dir():
+ try:
+ return {"pynm_dir": PYNM_DIR}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
# Get list of available drives in Windows systems
@self.get("/api/drives")
async def list_drives():
diff --git a/py_neuromodulation/gui/backend/app_manager.py b/py_neuromodulation/gui/backend/app_manager.py
index b1cb6971..8bec1023 100644
--- a/py_neuromodulation/gui/backend/app_manager.py
+++ b/py_neuromodulation/gui/backend/app_manager.py
@@ -12,24 +12,40 @@
if TYPE_CHECKING:
from multiprocessing.synchronize import Event
- from .app_backend import PyNMBackend
# Shared memory configuration
ARRAY_SIZE = 1000 # Adjust based on your needs
+SERVER_PORT = 50001
+DEV_SERVER_PORT = 54321
+
-def create_backend() -> "PyNMBackend":
- from .app_pynm import PyNMState
+def create_backend():
+ """Factory function passed to Uvicorn to create the web application instance.
+
+ :return: The web application instance.
+ :rtype: PyNMBackend
+ """
from .app_backend import PyNMBackend
debug = os.environ.get("PYNM_DEBUG", "False").lower() == "true"
dev = os.environ.get("PYNM_DEV", "True").lower() == "true"
+ dev_port = os.environ.get("PYNM_DEV_PORT", str(DEV_SERVER_PORT))
- return PyNMBackend(pynm_state=PyNMState(), debug=debug, dev=dev)
+ return PyNMBackend(
+ debug=debug,
+ dev=dev,
+ dev_port=int(dev_port),
+ )
-def run_vite(shutdown_event: "Event", debug: bool = False) -> None:
+def run_vite(
+ shutdown_event: "Event",
+ debug: bool = False,
+ dev_port: int = DEV_SERVER_PORT,
+ backend_port: int = SERVER_PORT,
+) -> None:
"""Run Vite in a separate shell"""
import subprocess
@@ -41,6 +57,8 @@ def run_vite(shutdown_event: "Event", debug: bool = False) -> None:
logging.DEBUG if debug else logging.INFO,
)
+ os.environ["VITE_BACKEND_PORT"] = str(backend_port)
+
def output_reader(shutdown_event: "Event", process: subprocess.Popen):
logger.debug("Initialized output stream")
color = ansi_color(color="magenta", bright=True, styles=["BOLD"])
@@ -73,7 +91,7 @@ def read_stream(stream, stream_name):
subprocess_flags = subprocess.CREATE_NEW_PROCESS_GROUP if os.name == "nt" else 0
process = subprocess.Popen(
- "bun run dev",
+ "bun run dev --port " + str(dev_port),
cwd="gui_dev",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
@@ -106,7 +124,9 @@ def read_stream(stream, stream_name):
logger.info("Development server stopped")
-def run_uvicorn(debug: bool = False, reload=False) -> None:
+def run_uvicorn(
+ debug: bool = False, reload=False, server_port: int = SERVER_PORT
+) -> None:
from uvicorn.server import Server
from uvicorn.config import LOGGING_CONFIG, Config
@@ -131,7 +151,7 @@ def run_uvicorn(debug: bool = False, reload=False) -> None:
host="localhost",
reload=reload,
factory=True,
- port=50001,
+ port=server_port,
log_level="debug" if debug else "info",
log_config=log_config,
)
@@ -160,17 +180,23 @@ def restart(self) -> None:
def run_backend(
- shutdown_event: "Event", debug: bool = False, reload: bool = True, dev: bool = True
+ shutdown_event: "Event",
+ dev: bool = True,
+ debug: bool = False,
+ reload: bool = True,
+ server_port: int = SERVER_PORT,
+ dev_port: int = DEV_SERVER_PORT,
) -> None:
signal.signal(signal.SIGINT, signal.SIG_IGN)
# Pass create_backend parameters through environment variables
os.environ["PYNM_DEBUG"] = str(debug)
os.environ["PYNM_DEV"] = str(dev)
+ os.environ["PYNM_DEV_PORT"] = str(dev_port)
server_process = mp.Process(
target=run_uvicorn,
- kwargs={"debug": debug, "reload": reload},
+ kwargs={"debug": debug, "reload": reload, "server_port": server_port},
name="Server",
)
server_process.start()
@@ -182,7 +208,12 @@ class AppManager:
LAUNCH_FLAG = "PYNM_RUNNING"
def __init__(
- self, debug: bool = False, dev: bool = True, run_in_webview=False
+ self,
+ debug: bool = False,
+ dev: bool = True,
+ run_in_webview=False,
+ server_port=SERVER_PORT,
+ dev_port=DEV_SERVER_PORT,
) -> None:
"""_summary_
@@ -197,6 +228,9 @@ def __init__(
self.debug = debug
self.dev = dev
self.run_in_webview = run_in_webview
+ self.server_port = server_port
+ self.dev_port = dev_port
+
self._reset()
# Prevent launching multiple instances of the app due to multiprocessing
# This allows the absence of a main guard in the main script
@@ -270,7 +304,12 @@ def launch(self) -> None:
self.logger.info("Starting Vite server...")
self.tasks["vite"] = mp.Process(
target=run_vite,
- kwargs={"shutdown_event": self.shutdown_event, "debug": self.debug},
+ kwargs={
+ "shutdown_event": self.shutdown_event,
+ "debug": self.debug,
+ "dev_port": self.dev_port,
+ "backend_port": self.server_port,
+ },
name="Vite",
)
@@ -282,6 +321,8 @@ def launch(self) -> None:
"debug": self.debug,
"reload": self.dev,
"dev": self.dev,
+ "server_port": self.server_port,
+ "dev_port": self.dev_port,
},
name="Backend",
)
diff --git a/py_neuromodulation/gui/backend/app_pynm.py b/py_neuromodulation/gui/backend/app_pynm.py
index 523a4f62..23641786 100644
--- a/py_neuromodulation/gui/backend/app_pynm.py
+++ b/py_neuromodulation/gui/backend/app_pynm.py
@@ -1,152 +1,145 @@
import os
-import asyncio
-import logging
-import threading
import numpy as np
-import multiprocessing as mp
from threading import Thread
-import queue
+import time
+import asyncio
+import multiprocessing as mp
+from queue import Empty
+from pathlib import Path
from py_neuromodulation.stream import Stream, NMSettings
from py_neuromodulation.analysis.decode import RealTimeDecoder
from py_neuromodulation.utils import set_channels
from py_neuromodulation.utils.io import read_mne_data
+from py_neuromodulation.utils.types import _PathLike
+from py_neuromodulation import logger
+from py_neuromodulation.gui.backend.app_socket import WebsocketManager
+from py_neuromodulation.stream.backend_interface import StreamBackendInterface
from py_neuromodulation import logger
-async def run_stream_controller(feature_queue: queue.Queue, rawdata_queue: queue.Queue,
- websocket_manager: "WebSocketManager", stop_event: threading.Event):
- while not stop_event.wait(0.002):
- if not feature_queue.empty() and websocket_manager is not None:
- feature_dict = feature_queue.get()
- logger.info("Sending message to Websocket")
- await websocket_manager.send_cbor(feature_dict)
-
- if not rawdata_queue.empty() and websocket_manager is not None:
- raw_data = rawdata_queue.get()
-
- await websocket_manager.send_cbor(raw_data)
-
-def run_stream_controller_sync(feature_queue: queue.Queue,
- rawdata_queue: queue.Queue,
- websocket_manager: "WebSocketManager",
- stop_event: threading.Event
- ):
- # The run_stream_controller needs to be started as an asyncio function due to the async websocket
- asyncio.run(run_stream_controller(feature_queue, rawdata_queue, websocket_manager, stop_event))
class PyNMState:
def __init__(
self,
- default_init: bool = True, # has to be true for the backend settings communication
) -> None:
- self.logger = logging.getLogger("uvicorn.error")
+ self.lsl_stream_name: str = ""
+ self.experiment_name: str = "PyNM_Experiment" # set by set_stream_params
+ self.out_dir: _PathLike = str(
+ Path.home() / "PyNM" / self.experiment_name
+ ) # set by set_stream_params
+ self.decoding_model_path: _PathLike | None = None
+ self.decoder: RealTimeDecoder | None = None
+ self.is_stream_lsl: bool = False
- self.lsl_stream_name = None
- self.stream_controller_process = None
- self.run_func_process = None
- self.out_dir = None # will be set by set_stream_params
- self.experiment_name = None # will be set by set_stream_params
- self.decoding_model_path = None
- self.decoder = None
+ self.backend_interface: StreamBackendInterface | None = None
+ self.websocket_manager: WebsocketManager | None = None
- if default_init:
- self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
- self.settings: NMSettings = NMSettings(sampling_rate_features=10)
+ # Note: sfreq and data are required for stream init
+ self.stream: Stream = Stream(sfreq=1500, data=np.random.random([1, 1]))
+ self.settings: NMSettings = NMSettings(sampling_rate_features=10)
+ self.feature_queue = mp.Queue()
+ self.rawdata_queue = mp.Queue()
+ self.control_queue = mp.Queue()
+ self.stop_event = asyncio.Event()
+
+ self.messages_sent = 0
def start_run_function(
self,
- websocket_manager=None,
+ websocket_manager: WebsocketManager | None = None,
) -> None:
-
- self.stream.settings = self.settings
-
- self.stream_handling_queue = queue.Queue()
- self.feature_queue = queue.Queue()
- self.rawdata_queue = queue.Queue()
+ # TONI: This is dangerous to do here, should be done by the setup functions
+ # self.stream.settings = self.settings
- self.logger.info("Starting stream_controller_process thread")
+ self.is_stream_lsl = self.lsl_stream_name is not None
+ self.websocket_manager = websocket_manager
+ # Create decoder
if self.decoding_model_path is not None and self.decoding_model_path != "None":
if os.path.exists(self.decoding_model_path):
self.decoder = RealTimeDecoder(self.decoding_model_path)
else:
- logger.debug("Passed decoding model path does't exist")
- # Stop event
- # .set() is called from app_backend
- self.stop_event_ws = threading.Event()
-
- self.stream_controller_thread = Thread(
- target=run_stream_controller_sync,
- daemon=True,
- args=(self.feature_queue,
- self.rawdata_queue,
- websocket_manager,
- self.stop_event_ws
- ),
- )
+ logger.debug("Passed decoding model path does't exist")
+
+ # Initialize the backend interface if not already done
+ if not self.backend_interface:
+ self.backend_interface = StreamBackendInterface(
+ self.feature_queue, self.rawdata_queue, self.control_queue
+ )
- is_stream_lsl = self.lsl_stream_name is not None
- stream_lsl_name = self.lsl_stream_name if self.lsl_stream_name is not None else ""
-
# The run_func_thread is terminated through the stream_handling_queue
# which initiates to break the data generator and save the features
- out_dir = "" if self.out_dir == "default" else self.out_dir
- self.run_func_thread = Thread(
+ stream_process = mp.Process(
target=self.stream.run,
- daemon=True,
kwargs={
- "out_dir" : out_dir,
- "experiment_name" : self.experiment_name,
- "stream_handling_queue" : self.stream_handling_queue,
- "is_stream_lsl" : is_stream_lsl,
- "stream_lsl_name" : stream_lsl_name,
- "feature_queue" : self.feature_queue,
- "simulate_real_time" : True,
- "rawdata_queue" : self.rawdata_queue,
- "decoder" : self.decoder,
+ "out_dir": "" if self.out_dir == "default" else self.out_dir,
+ "experiment_name": self.experiment_name,
+ "is_stream_lsl": self.lsl_stream_name is not None,
+ "stream_lsl_name": self.lsl_stream_name or "",
+ "simulate_real_time": True,
+ "decoder": self.decoder,
+ "backend_interface": self.backend_interface,
},
)
- self.stream_controller_thread.start()
- self.run_func_thread.start()
+ stream_process.start()
+
+ # Start websocket sender process
+
+ if self.websocket_manager:
+ # Get the current event loop and run the queue processor
+ loop = asyncio.get_running_loop()
+ queue_task = loop.create_task(self._process_queue())
+
+ # Store task reference for cleanup
+ self._queue_task = queue_task
+
+ # Store processes for cleanup
+ self.stream_process = stream_process
+
+ def stop_run_function(self) -> None:
+ """Stop the stream processing"""
+ if self.backend_interface:
+ self.backend_interface.send_command("stop")
+ self.stop_event.set()
def setup_lsl_stream(
self,
- lsl_stream_name: str | None = None,
+ lsl_stream_name: str = "",
line_noise: float | None = None,
sampling_rate_features: float | None = None,
):
from mne_lsl.lsl import resolve_streams
- self.logger.info("resolving streams")
+ logger.info("resolving streams")
lsl_streams = resolve_streams()
for stream in lsl_streams:
if stream.name == lsl_stream_name:
- self.logger.info(f"found stream {lsl_stream_name}")
+ logger.info(f"found stream {lsl_stream_name}")
# setup this stream
self.lsl_stream_name = lsl_stream_name
ch_names = stream.get_channel_names()
if ch_names is None:
ch_names = ["ch" + str(i) for i in range(stream.n_channels)]
- self.logger.info(f"channel names: {ch_names}")
+ logger.info(f"channel names: {ch_names}")
ch_types = stream.get_channel_types()
if ch_types is None:
ch_types = ["eeg" for i in range(stream.n_channels)]
- self.logger.info(f"channel types: {ch_types}")
+ logger.info(f"channel types: {ch_types}")
info_ = stream.get_channel_info()
- self.logger.info(f"channel info: {info_}")
+ logger.info(f"channel info: {info_}")
channels = set_channels(
ch_names=ch_names,
ch_types=ch_types,
used_types=["eeg", "ecog", "dbs", "seeg"],
)
- self.logger.info(channels)
+ logger.info(channels)
sfreq = stream.sfreq
self.stream: Stream = Stream(
@@ -156,20 +149,19 @@ def setup_lsl_stream(
sampling_rate_features_hz=sampling_rate_features,
settings=self.settings,
)
- self.logger.info("stream setup")
- #self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq)
- self.logger.info("settings setup")
+ logger.info("stream setup")
+ # self.settings: NMSettings = NMSettings(sampling_rate_features=sfreq)
+ logger.info("settings setup")
break
-
- if channels.shape[0] == 0:
- self.logger.error(f"Stream {lsl_stream_name} not found")
+ else:
+ logger.error(f"Stream {lsl_stream_name} not found")
raise ValueError(f"Stream {lsl_stream_name} not found")
def setup_offline_stream(
self,
file_path: str,
- line_noise: float | None = None,
- sampling_rate_features: float | None = None,
+ line_noise: float,
+ sampling_rate_features: float,
):
data, sfreq, ch_names, ch_types, bads = read_mne_data(file_path)
@@ -182,7 +174,9 @@ def setup_offline_stream(
target_keywords=None,
)
- self.logger.info(f"settings: {self.settings}")
+ self.settings.sampling_rate_features_hz = sampling_rate_features
+
+ logger.info(f"settings: {self.settings}")
self.stream: Stream = Stream(
settings=self.settings,
sfreq=sfreq,
@@ -191,3 +185,53 @@ def setup_offline_stream(
line_noise=line_noise,
sampling_rate_features_hz=sampling_rate_features,
)
+
+ # Async function that will continuously run in the Uvicorn async loop
+ # and handle sending data through the websocket manager
+ async def _process_queue(self):
+ last_queue_check = time.time()
+
+ while not self.stop_event.is_set():
+ # Use asyncio.gather to process both queues concurrently
+ tasks = []
+ current_time = time.time()
+
+ # Check feature queue
+ while not self.feature_queue.empty():
+ try:
+ data = self.feature_queue.get_nowait()
+ tasks.append(self.websocket_manager.send_cbor(data)) # type: ignore
+ self.messages_sent += 1
+ except Empty:
+ break
+
+ # Check raw data queue
+ while not self.rawdata_queue.empty():
+ try:
+ data = self.rawdata_queue.get_nowait()
+ self.messages_sent += 1
+ tasks.append(self.websocket_manager.send_cbor(data)) # type: ignore
+ except Empty:
+ break
+
+ if tasks:
+ # Wait for all send operations to complete
+ await asyncio.gather(*tasks, return_exceptions=True)
+ else:
+ # Only sleep if we didn't process any messages
+ await asyncio.sleep(0.001)
+
+ # Log queue diagnostics every 5 seconds
+ if current_time - last_queue_check > 5:
+ logger.info(
+ "\nQueue diagnostics:\n"
+ f"\tMessages send to websocket: {self.messages_sent}.\n"
+ f"\tFeature queue size: ~{self.feature_queue.qsize()}\n"
+ f"\tRaw data queue size: ~{self.rawdata_queue.qsize()}"
+ )
+
+ last_queue_check = current_time
+
+ # Check if stream process is still alive
+ if not self.stream_process.is_alive():
+ break
diff --git a/py_neuromodulation/gui/backend/app_socket.py b/py_neuromodulation/gui/backend/app_socket.py
index a7afed67..f81cc864 100644
--- a/py_neuromodulation/gui/backend/app_socket.py
+++ b/py_neuromodulation/gui/backend/app_socket.py
@@ -1,9 +1,10 @@
from fastapi import WebSocket
import logging
-import struct
-import json
import cbor2
-class WebSocketManager:
+import time
+
+
+class WebsocketManager:
"""
Manages WebSocket connections and messages.
Perhaps in the future it will handle multiple connections.
@@ -12,6 +13,12 @@ class WebSocketManager:
def __init__(self):
self.active_connections: list[WebSocket] = []
self.logger = logging.getLogger("PyNM")
+ self.disconnected = []
+ self._queue_task = None
+ self._stop_event = None
+ self.loop = None
+ self.messages_sent = 0
+ self._last_diagnostic_time = time.time()
async def connect(self, websocket: WebSocket):
await websocket.accept()
@@ -30,25 +37,66 @@ def disconnect(self, websocket: WebSocket):
f"Client {client_address.port}:{client_address.port} disconnected."
)
+ async def _cleanup_disconnected(self):
+ for connection in self.disconnected:
+ self.active_connections.remove(connection)
+ await connection.close()
+
# Combine IP and port to create a unique client ID
async def send_cbor(self, object: dict):
+ if not self.active_connections:
+ self.logger.warning("No active connection to send message.")
+ return
+
+ start_time = time.time()
+ cbor_data = cbor2.dumps(object)
+ serialize_time = time.time() - start_time
+
+ if serialize_time > 0.1: # Log slow serializations
+ self.logger.warning(f"CBOR serialization took {serialize_time:.3f}s")
+
+ send_start = time.time()
for connection in self.active_connections:
try:
- await connection.send_bytes(cbor2.dumps(object))
+ await connection.send_bytes(cbor_data)
except RuntimeError as e:
- self.active_connections.remove(connection)
+ self.logger.error(f"Error sending CBOR message: {e}")
+ self.disconnected.append(connection)
+
+ send_time = time.time() - send_start
+ if send_time > 0.1: # Log slow sends
+ self.logger.warning(f"WebSocket send took {send_time:.3f}s")
+
+ self.messages_sent += 1
+
+ # Log diagnostics every 5 seconds
+ current_time = time.time()
+ if current_time - self._last_diagnostic_time > 5:
+ self.logger.info(f"Messages sent: {self.messages_sent}")
+ self._last_diagnostic_time = current_time
+
+ await self._cleanup_disconnected()
async def send_message(self, message: str | dict):
- self.logger.info(f"Sending message within app_socket: {message.keys()}")
- if self.active_connections:
- for connection in self.active_connections:
+ if not self.active_connections:
+ self.logger.warning("No active connection to send message.")
+ return
+
+ self.logger.info(
+ f"Sending message within app_socket: {message.keys() if type(message) is dict else message}"
+ )
+ for connection in self.active_connections:
+ try:
if type(message) is dict:
await connection.send_json(message)
elif type(message) is str:
await connection.send_text(message)
- self.logger.info(f"Message sent")
- else:
- self.logger.warning("No active connection to send message.")
+ self.logger.info(f"Message sent to {connection.client}")
+ except RuntimeError as e:
+ self.logger.error(f"Error sending message: {e}")
+ self.disconnected.append(connection)
+
+ await self._cleanup_disconnected()
@property
def is_connected(self):
diff --git a/py_neuromodulation/gui/backend/app_utils.py b/py_neuromodulation/gui/backend/app_utils.py
index 7fcce8d3..c623bf1b 100644
--- a/py_neuromodulation/gui/backend/app_utils.py
+++ b/py_neuromodulation/gui/backend/app_utils.py
@@ -6,6 +6,7 @@
from py_neuromodulation.utils.types import _PathLike
from functools import lru_cache
import platform
+from py_neuromodulation import logger
def force_terminate_process(
@@ -211,20 +212,18 @@ def get_pinned_folders_windows():
"""
try:
- print(powershell_command)
result = subprocess.run(
["powershell", "-Command", powershell_command],
capture_output=True,
text=True,
check=True,
)
- print(result.stdout)
return json.loads(result.stdout)
except subprocess.CalledProcessError as e:
- print(f"Error running PowerShell command: {e}")
+ logger.error(f"Error running PowerShell command: {e}")
return []
except json.JSONDecodeError as e:
- print(f"Error decoding JSON: {e}")
+ logger.error(f"Error decoding JSON: {e}")
return []
@@ -299,9 +298,9 @@ def get_macos_favorites():
path = item["URL"].replace("file://", "")
favorites.append({"Name": item["Name"], "Path": path})
except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
- print(f"Error processing macOS favorites: {e}")
+ logger.error(f"Error processing macOS favorites: {e}")
except Exception as e:
- print(f"Error getting macOS favorites: {e}")
+ logger.error(f"Error getting macOS favorites: {e}")
return favorites
diff --git a/py_neuromodulation/processing/data_preprocessor.py b/py_neuromodulation/processing/data_preprocessor.py
index 319926d9..3b9a9088 100644
--- a/py_neuromodulation/processing/data_preprocessor.py
+++ b/py_neuromodulation/processing/data_preprocessor.py
@@ -1,12 +1,12 @@
from typing import TYPE_CHECKING, Type
-from py_neuromodulation.utils.types import PreprocessorName, NMPreprocessor
+from py_neuromodulation.utils.types import PREPROCESSOR_NAME, NMPreprocessor
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from py_neuromodulation.stream.settings import NMSettings
-PREPROCESSOR_DICT: dict[PreprocessorName, str] = {
+PREPROCESSOR_DICT: dict[PREPROCESSOR_NAME, str] = {
"preprocessing_filter": "PreprocessingFilter",
"notch_filter": "NotchFilter",
"raw_resampling": "Resampler",
diff --git a/py_neuromodulation/processing/filter_preprocessing.py b/py_neuromodulation/processing/filter_preprocessing.py
index e6515ae0..2c5addbc 100644
--- a/py_neuromodulation/processing/filter_preprocessing.py
+++ b/py_neuromodulation/processing/filter_preprocessing.py
@@ -1,22 +1,14 @@
import numpy as np
-from pydantic import Field
from typing import TYPE_CHECKING
from py_neuromodulation.utils.types import BoolSelector, FrequencyRange, NMPreprocessor
+from py_neuromodulation.utils.pydantic_extensions import NMField
if TYPE_CHECKING:
from py_neuromodulation import NMSettings
-FILTER_SETTINGS_MAP = {
- "bandstop_filter": "bandstop_filter_settings",
- "bandpass_filter": "bandpass_filter_settings",
- "lowpass_filter": "lowpass_filter_cutoff_hz",
- "highpass_filter": "highpass_filter_cutoff_hz",
-}
-
-
class FilterSettings(BoolSelector):
bandstop_filter: bool = True
bandpass_filter: bool = True
@@ -25,21 +17,23 @@ class FilterSettings(BoolSelector):
bandstop_filter_settings: FrequencyRange = FrequencyRange(100, 160)
bandpass_filter_settings: FrequencyRange = FrequencyRange(2, 200)
- lowpass_filter_cutoff_hz: float = Field(default=200)
- highpass_filter_cutoff_hz: float = Field(default=3)
-
- def get_filter_tuple(self, filter_name) -> tuple[float | None, float | None]:
- filter_value = self[FILTER_SETTINGS_MAP[filter_name]]
-
+ lowpass_filter_cutoff_hz: float = NMField(
+ default=200, gt=0, custom_metadata={"unit": "Hz"}
+ )
+ highpass_filter_cutoff_hz: float = NMField(
+ default=3, gt=0, custom_metadata={"unit": "Hz"}
+ )
+
+ def get_filter_tuple(self, filter_name) -> FrequencyRange:
match filter_name:
case "bandstop_filter":
- return (filter_value.frequency_high_hz, filter_value.frequency_low_hz)
+ return self.bandstop_filter_settings
case "bandpass_filter":
- return (filter_value.frequency_low_hz, filter_value.frequency_high_hz)
+ return self.bandpass_filter_settings
case "lowpass_filter":
- return (None, filter_value)
+ return FrequencyRange(None, self.lowpass_filter_cutoff_hz)
case "highpass_filter":
- return (filter_value, None)
+ return FrequencyRange(self.highpass_filter_cutoff_hz, None)
case _:
raise ValueError(
"Filter name must be one of 'bandstop_filter', 'lowpass_filter', "
diff --git a/py_neuromodulation/processing/normalization.py b/py_neuromodulation/processing/normalization.py
index e3f25d55..713e2308 100644
--- a/py_neuromodulation/processing/normalization.py
+++ b/py_neuromodulation/processing/normalization.py
@@ -3,10 +3,10 @@
import numpy as np
from typing import TYPE_CHECKING, Callable, Literal, get_args
+from py_neuromodulation.utils.pydantic_extensions import NMField
from py_neuromodulation.utils.types import (
NMBaseModel,
- Field,
- NormMethod,
+ NORM_METHOD,
NMPreprocessor,
)
@@ -17,13 +17,13 @@
class NormalizationSettings(NMBaseModel):
- normalization_time_s: float = 30
- normalization_method: NormMethod = "zscore"
- clip: float = Field(default=3, ge=0)
+ normalization_time_s: float = NMField(30, gt=0, custom_metadata={"unit": "s"})
+ normalization_method: NORM_METHOD = NMField(default="zscore")
+ clip: float = NMField(default=3, ge=0, custom_metadata={"unit": "a.u."})
@staticmethod
- def list_normalization_methods() -> list[NormMethod]:
- return list(get_args(NormMethod))
+ def list_normalization_methods() -> list[NORM_METHOD]:
+ return list(get_args(NORM_METHOD))
class FeatureNormalizationSettings(NormalizationSettings): normalize_psd: bool = False
@@ -60,7 +60,7 @@ def __init__(
if self.using_sklearn:
import sklearn.preprocessing as skpp
- NORM_METHODS_SKLEARN: dict[NormMethod, Callable] = {
+ NORM_METHODS_SKLEARN: dict[NORM_METHOD, Callable] = {
"quantile": lambda: skpp.QuantileTransformer(n_quantiles=300),
"robust": skpp.RobustScaler,
"minmax": skpp.MinMaxScaler,
diff --git a/py_neuromodulation/processing/projection.py b/py_neuromodulation/processing/projection.py
index 4cdc3610..3a83a29f 100644
--- a/py_neuromodulation/processing/projection.py
+++ b/py_neuromodulation/processing/projection.py
@@ -1,6 +1,6 @@
import numpy as np
-from pydantic import Field
from py_neuromodulation.utils.types import NMBaseModel
+from py_neuromodulation.utils.pydantic_extensions import NMField
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -9,7 +9,7 @@
class ProjectionSettings(NMBaseModel):
- max_dist_mm: float = Field(default=20.0, gt=0.0)
+ max_dist_mm: float = NMField(default=20.0, gt=0.0, custom_metadata={"unit": "mm"})
class Projection:
diff --git a/py_neuromodulation/processing/resample.py b/py_neuromodulation/processing/resample.py
index 08a4e115..1b7107eb 100644
--- a/py_neuromodulation/processing/resample.py
+++ b/py_neuromodulation/processing/resample.py
@@ -1,11 +1,14 @@
"""Module for resampling."""
import numpy as np
-from py_neuromodulation.utils.types import NMBaseModel, Field, NMPreprocessor
+from py_neuromodulation.utils.types import NMBaseModel, NMPreprocessor
+from py_neuromodulation.utils.pydantic_extensions import NMField
class ResamplerSettings(NMBaseModel):
- resample_freq_hz: float = Field(default=1000, gt=0)
+ resample_freq_hz: float = NMField(
+ default=1000, gt=0, custom_metadata={"unit": "Hz"}
+ )
class Resampler(NMPreprocessor):
diff --git a/py_neuromodulation/stream/backend_interface.py b/py_neuromodulation/stream/backend_interface.py
new file mode 100644
index 00000000..568998b0
--- /dev/null
+++ b/py_neuromodulation/stream/backend_interface.py
@@ -0,0 +1,47 @@
+from typing import Any
+import logging
+import multiprocessing as mp
+
+
+class StreamBackendInterface:
+ """Handles stream data output via queues"""
+
+ def __init__(
+ self, feature_queue: mp.Queue, raw_data_queue: mp.Queue, control_queue: mp.Queue
+ ):
+ self.feature_queue = feature_queue
+ self.rawdata_queue = raw_data_queue
+ self.control_queue = control_queue
+
+ self.logger = logging.getLogger("PyNM")
+
+ def send_command(self, command: str) -> None:
+ """Send a command through the control queue"""
+ try:
+ self.control_queue.put(command)
+ except Exception as e:
+ self.logger.error(f"Error sending command: {e}")
+
+ def send_features(self, features: dict[str, Any]) -> None:
+ """Send feature data through the feature queue"""
+ try:
+ self.feature_queue.put(features)
+ except Exception as e:
+ self.logger.error(f"Error sending features: {e}")
+
+ def send_raw_data(self, data: dict[str, Any]) -> None:
+ """Send raw data through the rawdata queue"""
+ try:
+ self.rawdata_queue.put(data)
+ except Exception as e:
+ self.logger.error(f"Error sending raw data: {e}")
+
+ def check_control_signals(self) -> str | None:
+ """Check for control signals (non-blocking)"""
+ try:
+ if not self.control_queue.empty():
+ return self.control_queue.get_nowait()
+ return None
+ except Exception as e:
+ self.logger.error(f"Error checking control signals: {e}")
+ return None
diff --git a/py_neuromodulation/stream/mnelsl_player.py b/py_neuromodulation/stream/mnelsl_player.py
index 086e4d60..bd39dd89 100644
--- a/py_neuromodulation/stream/mnelsl_player.py
+++ b/py_neuromodulation/stream/mnelsl_player.py
@@ -125,7 +125,7 @@ def start_player(
try:
self.wait_for_completion()
except KeyboardInterrupt:
- print("\nKeyboard interrupt received. Stopping the player...")
+ logger.info("\nKeyboard interrupt received. Stopping the player...")
self.stop_player()
def _run_player(self, chunk_size, n_repeat, stop_flag, streaming_complete):
@@ -156,7 +156,7 @@ def wait_for_completion(self):
if self._streaming_complete.is_set():
break
except KeyboardInterrupt:
- print("\nKeyboard interrupt received. Stopping the player...")
+ logger.info("\nKeyboard interrupt received. Stopping the player...")
self.stop_player()
break
@@ -172,7 +172,7 @@ def stop_player(self):
self._player_process.kill()
self._player_process = None
- print(f"Player stopped: {self.stream_name}")
+ logger.info(f"Player stopped: {self.stream_name}")
LSLOfflinePlayer._instances.discard(self)
@classmethod
diff --git a/py_neuromodulation/stream/settings.py b/py_neuromodulation/stream/settings.py
index 669c97db..c2b45e4c 100644
--- a/py_neuromodulation/stream/settings.py
+++ b/py_neuromodulation/stream/settings.py
@@ -1,22 +1,27 @@
"""Module for handling settings."""
from pathlib import Path
-from typing import ClassVar
-from pydantic import Field, model_validator
+from typing import Any, ClassVar, get_args
+from pydantic import model_validator, ValidationError
+from pydantic.functional_validators import ModelWrapValidatorHandler
-from py_neuromodulation import PYNM_DIR, logger, user_features
+from py_neuromodulation import logger, user_features
from py_neuromodulation.utils.types import (
BoolSelector,
FrequencyRange,
- PreprocessorName,
_PathLike,
NMBaseModel,
- NormMethod,
+ NORM_METHOD,
+ PREPROCESSOR_NAME,
)
+from py_neuromodulation.utils.pydantic_extensions import NMErrorList, NMField
from py_neuromodulation.processing.filter_preprocessing import FilterSettings
-from py_neuromodulation.processing.normalization import FeatureNormalizationSettings, NormalizationSettings
+from py_neuromodulation.processing.normalization import (
+ FeatureNormalizationSettings,
+ NormalizationSettings,
+)
from py_neuromodulation.processing.resample import ResamplerSettings
from py_neuromodulation.processing.projection import ProjectionSettings
@@ -31,7 +36,9 @@
from py_neuromodulation.features import BurstsSettings
-class FeatureSelection(BoolSelector):
+# TONI: this class has the problem that if a feature is absent,
+# it won't default to False but to whatever is defined here as default
+class FeatureSelector(BoolSelector):
raw_hjorth: bool = True
return_raw: bool = True
bandpass_filter: bool = False
@@ -54,13 +61,24 @@ class PostprocessingSettings(BoolSelector):
project_subcortex: bool = False
+DEFAULT_PREPROCESSORS: list[PREPROCESSOR_NAME] = [
+ "raw_resampling",
+ "notch_filter",
+ "re_referencing",
+]
+
+
class NMSettings(NMBaseModel):
# Class variable to store instances
_instances: ClassVar[list["NMSettings"]] = []
# General settings
- sampling_rate_features_hz: float = Field(default=10, gt=0)
- segment_length_features_ms: float = Field(default=1000, gt=0)
+ sampling_rate_features_hz: float = NMField(
+ default=10, gt=0, custom_metadata={"unit": "Hz"}
+ )
+ segment_length_features_ms: float = NMField(
+ default=1000, gt=0, custom_metadata={"unit": "ms"}
+ )
frequency_ranges_hz: dict[str, FrequencyRange] = {
"theta": FrequencyRange(4, 8),
"alpha": FrequencyRange(8, 12),
@@ -72,23 +90,28 @@ class NMSettings(NMBaseModel):
}
# Preproceessing settings
- preprocessing: list[PreprocessorName] = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ preprocessing: list[PREPROCESSOR_NAME] = NMField(
+ default=DEFAULT_PREPROCESSORS,
+ custom_metadata={
+ "field_type": "PreprocessorList",
+ "valid_values": list(get_args(PREPROCESSOR_NAME)),
+ },
+ )
+
raw_resampling_settings: ResamplerSettings = ResamplerSettings()
preprocessing_filter: FilterSettings = FilterSettings()
raw_normalization_settings: NormalizationSettings = NormalizationSettings()
# Postprocessing settings
postprocessing: PostprocessingSettings = PostprocessingSettings()
- feature_normalization_settings: FeatureNormalizationSettings = FeatureNormalizationSettings()
+ feature_normalization_settings: FeatureNormalizationSettings = (
+ FeatureNormalizationSettings()
+ )
project_cortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=20)
project_subcortex_settings: ProjectionSettings = ProjectionSettings(max_dist_mm=5)
# Feature settings
- features: FeatureSelection = FeatureSelection()
+ features: FeatureSelector = FeatureSelector()
fft_settings: OscillatorySettings = OscillatorySettings()
welch_settings: OscillatorySettings = OscillatorySettings()
@@ -126,10 +149,38 @@ def _remove_feature(cls, feature: str) -> None:
for instance in cls._instances:
delattr(instance.features, feature)
- @model_validator(mode="after")
- def validate_settings(self):
+ @model_validator(mode="wrap") # type: ignore[reportIncompatibleMethodOverride]
+ def validate_settings(self, handler: ModelWrapValidatorHandler) -> Any:
+ # Perform all necessary custom validations in the settings class and also
+ # all validations in the feature classes that need additional information from
+ # the settings class
+ errors: NMErrorList = NMErrorList()
+
+ def remove_private_keys(data):
+ if isinstance(data, dict):
+ if "__value__" in data:
+ return data["__value__"]
+ else:
+ return {
+ key: remove_private_keys(value)
+ for key, value in data.items()
+ if not key.startswith("__")
+ }
+ elif isinstance(data, (list, tuple, set)):
+ return type(data)(remove_private_keys(item) for item in data)
+ else:
+ return data
+
+ self = remove_private_keys(self)
+
+ try:
+ self = handler(self) # validate the model
+ except ValidationError as e:
+ self = NMSettings.unvalidated(**self) # type: ignore
+ errors.extend(NMErrorList(e.errors()))
+
if len(self.features.get_enabled()) == 0:
- raise ValueError("At least one feature must be selected.")
+ errors.add_error("At least one feature must be selected.")
# Replace spaces with underscores in frequency band names
self.frequency_ranges_hz = {
@@ -138,32 +189,27 @@ def validate_settings(self):
if self.features.bandpass_filter:
# Check BandPass settings frequency bands
- self.bandpass_filter_settings.validate_fbands(self)
+ errors.extend(self.bandpass_filter_settings.validate_fbands(self))
# Check Kalman filter frequency bands
if self.bandpass_filter_settings.kalman_filter:
- self.kalman_filter_settings.validate_fbands(self)
+ errors.extend(self.kalman_filter_settings.validate_fbands(self))
- for k, v in self.frequency_ranges_hz.items():
- if not isinstance(v, FrequencyRange):
- self.frequency_ranges_hz[k] = FrequencyRange.create_from(v)
+ if len(errors) > 0:
+ raise errors.create_error()
return self
def reset(self) -> "NMSettings":
self.features.disable_all()
- self.preprocessing = []
+ self.preprocessing = DEFAULT_PREPROCESSORS
self.postprocessing.disable_all()
return self
def set_fast_compute(self) -> "NMSettings":
self.reset()
self.features.fft = True
- self.preprocessing = [
- "raw_resampling",
- "notch_filter",
- "re_referencing",
- ]
+ self.preprocessing = DEFAULT_PREPROCESSORS
self.postprocessing.feature_normalization = True
self.postprocessing.project_cortex = False
self.postprocessing.project_subcortex = False
@@ -250,10 +296,10 @@ def from_file(PATH: _PathLike) -> "NMSettings":
@staticmethod
def get_default() -> "NMSettings":
- return NMSettings.from_file(PYNM_DIR / "default_settings.yaml")
+ return NMSettings()
@staticmethod
- def list_normalization_methods() -> list[NormMethod]:
+ def list_normalization_methods() -> list[NORM_METHOD]:
return NormalizationSettings.list_normalization_methods()
def save(
diff --git a/py_neuromodulation/stream/stream.py b/py_neuromodulation/stream/stream.py
index 2ee03421..1ec685f6 100644
--- a/py_neuromodulation/stream/stream.py
+++ b/py_neuromodulation/stream/stream.py
@@ -1,19 +1,19 @@
"""Module for generic and offline data streams."""
import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from collections.abc import Iterator
import numpy as np
from pathlib import Path
import py_neuromodulation as nm
-from contextlib import suppress
from py_neuromodulation.stream.data_processor import DataProcessor
-from py_neuromodulation.utils.types import _PathLike, FeatureName
+from py_neuromodulation.utils.types import _PathLike, FEATURE_NAME
from py_neuromodulation.utils.file_writer import MsgPackFileWriter
from py_neuromodulation.stream.settings import NMSettings
from py_neuromodulation.analysis.decode import RealTimeDecoder
+from py_neuromodulation.stream.backend_interface import StreamBackendInterface
if TYPE_CHECKING:
import pandas as pd
@@ -68,6 +68,7 @@ def __init__(
verbose : bool, optional
print out stream computation time information, by default True
"""
+ # This is calling NMSettings.validate() which is making a copy
self.settings: NMSettings = NMSettings.load(settings)
if channels is None and data is not None:
@@ -85,7 +86,7 @@ def __init__(
raise ValueError("Either `channels` or `data` must be passed to `Stream`.")
# If features that use frequency ranges are on, test them against nyquist frequency
- use_freq_ranges: list[FeatureName] = [
+ use_freq_ranges: list[FEATURE_NAME] = [
"bandpass_filter",
"stft",
"fft",
@@ -208,14 +209,11 @@ def run(
save_interval: int = 10,
return_df: bool = True,
simulate_real_time: bool = False,
- decoder: RealTimeDecoder = None,
- feature_queue: "queue.Queue | None" = None,
- rawdata_queue: "queue.Queue | None" = None,
- stream_handling_queue: "queue.Queue | None" = None,
+ decoder: RealTimeDecoder | None = None,
+ backend_interface: StreamBackendInterface | None = None,
):
self.is_stream_lsl = is_stream_lsl
self.stream_lsl_name = stream_lsl_name
- self.stream_handling_queue = stream_handling_queue
self.save_csv = save_csv
self.save_interval = save_interval
self.return_df = return_df
@@ -248,7 +246,7 @@ def run(
nm.logger.log_to_file(out_dir)
self.generator: Iterator
- if not is_stream_lsl:
+ if not is_stream_lsl and data is not None:
from py_neuromodulation.stream.generator import RawDataGenerator
self.generator = RawDataGenerator(
@@ -265,7 +263,10 @@ def run(
settings=self.settings, stream_name=stream_lsl_name
)
- if self.sfreq != self.lsl_stream.stream.sinfo.sfreq:
+ if (
+ self.lsl_stream.stream.sinfo is not None
+ and self.sfreq != self.lsl_stream.stream.sinfo.sfreq
+ ):
error_msg = (
f"Sampling frequency of the lsl-stream ({self.lsl_stream.stream.sinfo.sfreq}) "
f"does not match the settings ({self.sfreq})."
@@ -279,14 +280,14 @@ def run(
prev_batch_end = 0
for timestamps, data_batch in self.generator:
self.is_running = True
- if self.stream_handling_queue is not None:
+ if backend_interface:
+ # Only simulate real-time if connected to GUI
if simulate_real_time:
time.sleep(1 / self.settings.sampling_rate_features_hz)
- if not self.stream_handling_queue.empty():
- signal = self.stream_handling_queue.get()
- nm.logger.info(f"Got signal: {signal}")
- if signal == "stop":
- break
+
+ signal = backend_interface.check_control_signals()
+ if signal == "stop":
+ break
if data_batch is None:
nm.logger.info("Data batch is None, stopping run function")
@@ -303,10 +304,11 @@ def run(
if decoder is not None:
ch_to_decode = self.channels.query("used == 1").iloc[0]["name"]
- feature_dict = decoder.predict(feature_dict, ch_to_decode, fft_bands_only=True)
+ feature_dict = decoder.predict(
+ feature_dict, ch_to_decode, fft_bands_only=True
+ )
feature_dict["time"] = np.ceil(this_batch_end * 1000 + 1)
-
prev_batch_end = this_batch_end
if self.verbose:
@@ -314,40 +316,41 @@ def run(
self._add_target(feature_dict, data_batch)
- with suppress(TypeError): # Need this because some features output None
- for key, value in feature_dict.items():
- feature_dict[key] = np.float64(value)
-
+ # Push data to file writer
file_writer.insert_data(feature_dict)
- if feature_queue is not None:
- feature_queue.put(feature_dict)
- if rawdata_queue is not None:
- # convert raw data into dict with new raw data in unit self.sfreq
- new_time_ms = 1000 / self.settings.sampling_rate_features_hz
- new_samples = int(new_time_ms * self.sfreq / 1000)
- data_batch_dict = {}
- data_batch_dict["raw_data"] = {}
- for i, ch in enumerate(self.channels["name"]):
- # needs to be list since cbor doesn't support np array
- data_batch_dict["raw_data"][ch] = list(data_batch[i, -new_samples:])
- rawdata_queue.put(data_batch_dict)
+ # Send data to frontend
+ if backend_interface:
+ backend_interface.send_features(feature_dict)
+ backend_interface.send_raw_data(self._prepare_raw_data_dict(data_batch))
+
+ # Save features to file in intervals
self.batch_count += 1
if self.batch_count % self.save_interval == 0:
file_writer.save()
file_writer.save()
+
if self.save_csv:
file_writer.save_as_csv(save_all_combined=True)
- if self.return_df:
- feature_df = file_writer.load_all()
+ feature_df = file_writer.load_all() if self.return_df else {}
self._save_after_stream()
self.is_running = False
- return feature_df # Timon: We could think of returning the feature_reader instead
-
+        return feature_df  # Timon: We could think of returning the feature_reader instead
+
+ def _prepare_raw_data_dict(self, data_batch: np.ndarray) -> dict[str, Any]:
+ """Prepare raw data dictionary for sending through queue"""
+ new_time_ms = 1000 / self.settings.sampling_rate_features_hz
+ new_samples = int(new_time_ms * self.sfreq / 1000)
+ return {
+ "raw_data": {
+ ch: list(data_batch[i, -new_samples:])
+ for i, ch in enumerate(self.channels["name"])
+ }
+ }
def plot_raw_signal(
self,
@@ -383,11 +386,15 @@ def plot_raw_signal(
ValueError
raise Exception when no data is passed
"""
- if self.data is None and data is None:
- raise ValueError("No data passed to plot_raw_signal function.")
-
- if data is None and self.data is not None:
- data = self.data
+ if data is None:
+ if self.data is None:
+ raise ValueError("No data passed to plot_raw_signal function.")
+ else:
+ data = (
+ self.data.to_numpy()
+ if isinstance(self.data, pd.DataFrame)
+ else self.data
+ )
if sfreq is None:
sfreq = self.sfreq
@@ -402,7 +409,7 @@ def plot_raw_signal(
from mne import create_info
from mne.io import RawArray
- info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
+ info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) # type: ignore
raw = RawArray(data, info)
if picks is not None:
@@ -437,7 +444,7 @@ def _save_sidecar(self) -> None:
"""Save sidecar incduing fs, coords, sess_right to
out_path_root and subfolder 'folder_name'"""
additional_args = {"sess_right": self.sess_right}
-
+
self.data_processor.save_sidecar(
self.out_dir, self.experiment_name, additional_args
)
diff --git a/py_neuromodulation/utils/channels.py b/py_neuromodulation/utils/channels.py
index c2fffd00..61f5b68e 100644
--- a/py_neuromodulation/utils/channels.py
+++ b/py_neuromodulation/utils/channels.py
@@ -251,7 +251,7 @@ def _get_default_references(
def get_default_channels_from_data(
- data: np.ndarray,
+ data: "np.ndarray | pd.DataFrame",
car_rereferencing: bool = True,
):
"""Return default channels dataframe with
diff --git a/py_neuromodulation/utils/pydantic_extensions.py b/py_neuromodulation/utils/pydantic_extensions.py
new file mode 100644
index 00000000..e489712d
--- /dev/null
+++ b/py_neuromodulation/utils/pydantic_extensions.py
@@ -0,0 +1,320 @@
+import copy
+from typing import (
+ Any,
+ get_origin,
+ get_args,
+ get_type_hints,
+ Literal,
+ cast,
+ Sequence,
+)
+from typing_extensions import Unpack, TypedDict
+from pydantic import BaseModel, GetCoreSchemaHandler
+
+from pydantic_core import (
+ ErrorDetails,
+ PydanticUndefined,
+ InitErrorDetails,
+ ValidationError,
+ CoreSchema,
+ core_schema,
+)
+from pydantic.fields import FieldInfo, _FieldInfoInputs, _FromFieldInfoInputs
+from pprint import pformat
+
+
+def create_validation_error(
+ error_message: str,
+ location: list[str | int] = [],
+ title: str = "Validation Error",
+ error_type="value_error",
+) -> ValidationError:
+ """
+ Factory function to create a Pydantic v2 ValidationError.
+
+ Args:
+ error_message (str): The error message for the ValidationError.
+ loc (List[str | int], optional): The location of the error. Defaults to None.
+ title (str, optional): The title of the error. Defaults to "Validation Error".
+
+ Returns:
+ ValidationError: A Pydantic ValidationError instance.
+ """
+
+ return ValidationError.from_exception_data(
+ title=title,
+ line_errors=[
+ InitErrorDetails(
+ type=error_type,
+ loc=tuple(location),
+ input=None,
+ ctx={"error": error_message},
+ )
+ ],
+ input_type="python",
+ hide_input=False,
+ )
+
+
+class NMErrorList:
+ """Class to handle data about Pydantic errors.
+ Stores data in a list of InitErrorDetails. Errors can be accessed but not modified.
+
+    Use create_error() to convert the accumulated errors into a
+    single ValidationError that can be raised.
+ """
+
+ def __init__(
+ self, errors: Sequence[InitErrorDetails | ErrorDetails] | None = None
+ ) -> None:
+ if errors is None:
+ self.__errors: list[InitErrorDetails | ErrorDetails] = []
+ else:
+ self.__errors: list[InitErrorDetails | ErrorDetails] = [e for e in errors]
+
+ def add_error(
+ self,
+ error_message: str,
+ location: list[str | int] = [],
+ error_type="value_error",
+ ) -> None:
+ self.__errors.append(
+ InitErrorDetails(
+ type=error_type,
+ loc=tuple(location),
+ input=None,
+ ctx={"error": error_message},
+ )
+ )
+
+ def create_error(self, title: str = "Validation Error") -> ValidationError:
+ return ValidationError.from_exception_data(
+ title=title, line_errors=cast(list[InitErrorDetails], self.__errors)
+ )
+
+ def extend(self, errors: "NMErrorList"):
+ self.__errors.extend(errors.__errors)
+
+ def __iter__(self):
+ return iter(self.__errors)
+
+ def __len__(self):
+ return len(self.__errors)
+
+ def __getitem__(self, idx):
+ # Return a copy of the error to prevent modification
+ return copy.deepcopy(self.__errors[idx])
+
+ def __repr__(self):
+ return repr(self.__errors)
+
+ def __str__(self):
+ return str(self.__errors)
+
+
+class _NMExtraFieldInputs(TypedDict, total=False):
+ """Additional fields to add on top of the pydantic FieldInfo"""
+
+ custom_metadata: dict[str, Any]
+
+
+class _NMFieldInfoInputs(_FieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo inputs with PyNM additional inputs"""
+
+ pass
+
+
+class _NMFromFieldInfoInputs(_FromFieldInfoInputs, _NMExtraFieldInputs, total=False):
+ """Combine pydantic FieldInfo.from_field inputs with PyNM additional inputs"""
+
+ pass
+
+
+class NMFieldInfo(FieldInfo):
+ # Add default values for any other custom fields here
+ _default_values = {}
+
+ def __init__(self, **kwargs: Unpack[_NMFieldInfoInputs]) -> None:
+ self.sequence: bool = kwargs.pop("sequence", False) # type: ignore
+ self.custom_metadata: dict[str, Any] = kwargs.pop("custom_metadata", {})
+ super().__init__(**kwargs)
+
+ def __get_pydantic_core_schema__(
+ self, source_type: Any, handler: GetCoreSchemaHandler
+ ) -> CoreSchema:
+ schema = handler(source_type)
+
+ if self.sequence:
+
+ def sequence_validator(v: Any) -> Any:
+ if isinstance(v, (list, tuple)):
+ return v
+ if isinstance(v, dict) and "root" in v:
+ return v["root"]
+ return [v]
+
+ return core_schema.no_info_before_validator_function(
+ sequence_validator, schema
+ )
+
+ return schema
+
+ @staticmethod
+ def from_field(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+ ) -> "NMFieldInfo":
+ if "annotation" in kwargs:
+ raise TypeError('"annotation" is not permitted as a Field keyword argument')
+ return NMFieldInfo(default=default, **kwargs)
+
+ def __repr_args__(self):
+ yield from super().__repr_args__()
+ extra_fields = get_type_hints(_NMExtraFieldInputs)
+ for field in extra_fields:
+ value = getattr(self, field)
+ yield field, value
+
+
+def NMField(
+ default: Any = PydanticUndefined,
+ **kwargs: Unpack[_NMFromFieldInfoInputs],
+) -> Any:
+ return NMFieldInfo.from_field(default=default, **kwargs)
+
+
+class NMBaseModel(BaseModel):
+ def __init__(self, *args, **kwargs) -> None:
+ """Pydantic does not support positional arguments by default.
+ This is a workaround to support positional arguments for models like FrequencyRange.
+ It converts positional arguments to kwargs and then calls the base class __init__.
+ """
+
+ if not args:
+ # Simple case - just use kwargs
+ super().__init__(*args, **kwargs)
+ return
+
+ field_names = list(self.model_fields.keys())
+ # If we have more positional args than fields, that's an error
+ if len(args) > len(field_names):
+ raise ValueError(
+ f"Too many positional arguments. Expected at most {len(field_names)}, got {len(args)}"
+ )
+
+        # Convert positional args to kwargs, checking for duplicates
+ complete_kwargs = {}
+ for i, arg in enumerate(args):
+ field_name = field_names[i]
+ if field_name in kwargs:
+ raise ValueError(
+ f"Got multiple values for field '{field_name}': "
+ f"positional argument and keyword argument"
+ )
+ complete_kwargs[field_name] = arg
+
+ # Add remaining kwargs
+ complete_kwargs.update(kwargs)
+ super().__init__(**complete_kwargs)
+
+ __init__.__pydantic_base_init__ = True # type: ignore
+
+ def __str__(self):
+ return pformat(self.model_dump())
+
+ # def __repr__(self):
+ # return pformat(self.model_dump())
+
+ def validate(self, context: Any | None = None) -> Any: # type: ignore
+ return self.model_validate(self.model_dump(), context=context)
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __setitem__(self, key, value) -> None:
+ setattr(self, key, value)
+
+ @property
+ def fields(self) -> dict[str, FieldInfo | NMFieldInfo]:
+ return self.model_fields # type: ignore
+
+ def serialize_with_metadata(self):
+ result: dict[str, Any] = {"__field_type__": self.__class__.__name__}
+
+ for field_name, field_info in self.fields.items():
+ value = getattr(self, field_name)
+ if isinstance(value, NMBaseModel):
+ result[field_name] = value.serialize_with_metadata()
+ elif isinstance(value, list):
+ result[field_name] = [
+ item.serialize_with_metadata()
+ if isinstance(item, NMBaseModel)
+ else item
+ for item in value
+ ]
+ elif isinstance(value, dict):
+ result[field_name] = {
+ k: v.serialize_with_metadata() if isinstance(v, NMBaseModel) else v
+ for k, v in value.items()
+ }
+ else:
+ result[field_name] = value
+
+ # Extract unit information from Annotated type
+ if isinstance(field_info, NMFieldInfo):
+ # Convert scalar value to dict with metadata
+ field_dict = {
+ "__value__": value,
+                    # __field_type__ will be overwritten if set in custom_metadata
+ "__field_type__": type(value).__name__,
+ **{
+ f"__{tag}__": value
+ for tag, value in field_info.custom_metadata.items()
+ },
+ }
+ # Add possible values for Literal types
+ if get_origin(field_info.annotation) is Literal:
+ field_dict["__valid_values__"] = list(
+ get_args(field_info.annotation)
+ )
+
+ result[field_name] = field_dict
+ return result
+
+ @classmethod
+ def unvalidated(cls, **data: Any) -> Any:
+ def process_value(value: Any, field_type: Any) -> Any:
+ if isinstance(value, dict) and hasattr(
+ field_type, "__pydantic_core_schema__"
+ ):
+ # Recursively handle nested Pydantic models
+ return field_type.unvalidated(**value)
+ elif isinstance(value, list):
+ # Handle lists of Pydantic models
+ if hasattr(field_type, "__args__") and hasattr(
+ field_type.__args__[0], "__pydantic_core_schema__"
+ ):
+ return [
+ field_type.__args__[0].unvalidated(**item)
+ if isinstance(item, dict)
+ else item
+ for item in value
+ ]
+ return value
+
+ processed_data = {}
+ for name, field in cls.model_fields.items():
+ try:
+ value = data[name]
+ processed_data[name] = process_value(value, field.annotation)
+ except KeyError:
+ if not field.is_required():
+ processed_data[name] = copy.deepcopy(field.default)
+ else:
+ raise TypeError(f"Missing required keyword argument {name!r}")
+
+ self = cls.__new__(cls)
+ object.__setattr__(self, "__dict__", processed_data)
+ object.__setattr__(self, "__pydantic_private__", {"extra": None})
+ object.__setattr__(self, "__pydantic_fields_set__", set(processed_data.keys()))
+ return self
diff --git a/py_neuromodulation/utils/types.py b/py_neuromodulation/utils/types.py
index 7886685d..6f7ab636 100644
--- a/py_neuromodulation/utils/types.py
+++ b/py_neuromodulation/utils/types.py
@@ -1,12 +1,14 @@
from os import PathLike
from math import isnan
-from typing import Any, Literal, Protocol, TYPE_CHECKING, runtime_checkable
-from pydantic import ConfigDict, Field, model_validator, BaseModel
-from pydantic_core import ValidationError, InitErrorDetails
-from pprint import pformat
+from typing import Literal, TYPE_CHECKING, Any
+from pydantic import BaseModel, model_validator, Field
+from .pydantic_extensions import NMBaseModel, NMField
+from abc import abstractmethod
+
from collections.abc import Sequence
from datetime import datetime
+
if TYPE_CHECKING:
import numpy as np
from py_neuromodulation import NMSettings
@@ -17,8 +19,7 @@
_PathLike = str | PathLike
-
-FeatureName = Literal[
+FEATURE_NAME = Literal[
"raw_hjorth",
"return_raw",
"bandpass_filter",
@@ -35,7 +36,7 @@
"bispectrum",
]
-PreprocessorName = Literal[
+PREPROCESSOR_NAME = Literal[
"preprocessing_filter",
"notch_filter",
"raw_resampling",
@@ -43,7 +44,7 @@
"raw_normalization",
]
-NormMethod = Literal[
+NORM_METHOD = Literal[
"mean",
"median",
"zscore",
@@ -54,13 +55,8 @@
"minmax",
]
-###################################
-######## PROTOCOL CLASSES ########
-###################################
-
-@runtime_checkable
-class NMFeature(Protocol):
+class NMFeature:
def __init__(
self, settings: "NMSettings", ch_names: Sequence[str], sfreq: int | float
) -> None: ...
@@ -81,73 +77,10 @@ def calc_feature(self, data: "np.ndarray") -> dict:
...
-class NMPreprocessor(Protocol):
- def __init__(self, sfreq: float, settings: "NMSettings") -> None: ...
-
+class NMPreprocessor:
def process(self, data: "np.ndarray") -> "np.ndarray": ...
-###################################
-######## PYDANTIC CLASSES ########
-###################################
-
-
-class NMBaseModel(BaseModel):
- model_config = ConfigDict(validate_assignment=False, extra="allow")
-
- def __init__(self, *args, **kwargs) -> None:
- if kwargs:
- super().__init__(**kwargs)
- else:
- field_names = list(self.model_fields.keys())
- kwargs = {}
- for i in range(len(args)):
- kwargs[field_names[i]] = args[i]
- super().__init__(**kwargs)
-
- def __str__(self):
- return pformat(self.model_dump())
-
- def __repr__(self):
- return pformat(self.model_dump())
-
- def validate(self) -> Any: # type: ignore
- return self.model_validate(self.model_dump())
-
- def __getitem__(self, key):
- return getattr(self, key)
-
- def __setitem__(self, key, value) -> None:
- setattr(self, key, value)
-
- def process_for_frontend(self) -> dict[str, Any]:
- """
- Process the model for frontend use, adding __field_type__ information.
- """
- result = {}
- for field_name, field_value in self.__dict__.items():
- if isinstance(field_value, NMBaseModel):
- processed_value = field_value.process_for_frontend()
- processed_value["__field_type__"] = field_value.__class__.__name__
- result[field_name] = processed_value
- elif isinstance(field_value, list):
- result[field_name] = [
- item.process_for_frontend()
- if isinstance(item, NMBaseModel)
- else item
- for item in field_value
- ]
- elif isinstance(field_value, dict):
- result[field_name] = {
- k: v.process_for_frontend() if isinstance(v, NMBaseModel) else v
- for k, v in field_value.items()
- }
- else:
- result[field_name] = field_value
-
- return result
-
-
class FrequencyRange(NMBaseModel):
frequency_low_hz: float = Field(gt=0)
frequency_high_hz: float = Field(gt=0)
@@ -178,35 +111,6 @@ def validate_range(self):
), "Frequency high must be greater than frequency low"
return self
- @classmethod
- def create_from(cls, input) -> "FrequencyRange":
- match input:
- case FrequencyRange():
- return input
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return FrequencyRange(
- input["frequency_low_hz"], input["frequency_high_hz"]
- )
- case Sequence() if len(input) == 2:
- return FrequencyRange(input[0], input[1])
- case _:
- raise ValueError("Invalid input for FrequencyRange creation.")
-
- @model_validator(mode="before")
- @classmethod
- def check_input(cls, input):
- match input:
- case dict() if "frequency_low_hz" in input and "frequency_high_hz" in input:
- return input
- case Sequence() if len(input) == 2:
- return {"frequency_low_hz": input[0], "frequency_high_hz": input[1]}
- case _:
- raise ValueError(
- "Value for FrequencyRange must be a dictionary, "
- "or a sequence of 2 numeric values, "
- f"but got {input} instead."
- )
-
class BoolSelector(NMBaseModel):
def get_enabled(self):
@@ -238,47 +142,6 @@ def print_all(cls):
for f in cls.list_all():
print(f)
- @classmethod
- def get_fields(cls):
- return cls.model_fields
-
-
-def create_validation_error(
- error_message: str,
- loc: list[str | int] = None,
- title: str = "Validation Error",
- input_type: Literal["python", "json"] = "python",
- hide_input: bool = False,
-) -> ValidationError:
- """
- Factory function to create a Pydantic v2 ValidationError instance from a single error message.
-
- Args:
- error_message (str): The error message for the ValidationError.
- loc (List[str | int], optional): The location of the error. Defaults to None.
- title (str, optional): The title of the error. Defaults to "Validation Error".
- input_type (Literal["python", "json"], optional): Whether the error is for a Python object or JSON. Defaults to "python".
- hide_input (bool, optional): Whether to hide the input value in the error message. Defaults to False.
-
- Returns:
- ValidationError: A Pydantic ValidationError instance.
- """
- if loc is None:
- loc = []
-
- line_errors = [
- InitErrorDetails(
- type="value_error", loc=tuple(loc), input=None, ctx={"error": error_message}
- )
- ]
-
- return ValidationError.from_exception_data(
- title=title,
- line_errors=line_errors,
- input_type=input_type,
- hide_input=hide_input,
- )
-
#################
### GUI TYPES ###
diff --git a/pyproject.toml b/pyproject.toml
index 5facc8ba..7ad64fae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,6 @@ dependencies = [
"nolds >=0.6.1",
"numpy >= 1.21.2",
"pandas >= 2.0.0",
- "scikit-image",
"scikit-learn >= 0.24.2",
"scikit-optimize",
"scipy >= 1.7.1",
@@ -55,13 +54,11 @@ dependencies = [
"statsmodels",
"mne-lsl>=1.2.0",
"pynput",
- #"pyqt5",
"pydantic>=2.7.3",
"llvmlite>=0.43.0",
"pywebview",
"fastapi",
- "uvicorn>=0.30.6",
- "websockets>=13.0",
+ "uvicorn[standard]>=0.30.6",
"seaborn >= 0.11",
# exists only because of nolds?
"numba>=0.60.0",
@@ -73,14 +70,7 @@ dependencies = [
[project.optional-dependencies]
test = ["pytest>=8.0.2", "pytest-xdist"]
-dev = [
- "ruff",
- "pytest>=8.0.2",
- "pytest-cov",
- "pytest-sugar",
- "notebook",
- "watchdog",
-]
+dev = ["ruff", "pytest>=8.0.2", "pytest-cov", "pytest-sugar", "notebook"]
docs = [
"py-neuromodulation[dev]",
"sphinx",
diff --git a/tests/test_osc_features.py b/tests/test_osc_features.py
index 8cf84111..fd2428d1 100644
--- a/tests/test_osc_features.py
+++ b/tests/test_osc_features.py
@@ -3,11 +3,11 @@
import numpy as np
from py_neuromodulation import NMSettings, Stream, features
-from py_neuromodulation.utils.types import FeatureName
+from py_neuromodulation.utils.types import FEATURE_NAME
def setup_osc_settings(
- osc_feature_name: FeatureName,
+ osc_feature_name: FEATURE_NAME,
osc_feature_setting: str,
windowlength_ms: int,
log_transform: bool,