We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
import mne import torch import random import numpy as np from torch.utils.data import Dataset import warnings import os import re
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
random_seed = 777 np.random.seed(random_seed) torch.manual_seed(random_seed) random.seed(random_seed)
class TorchDataset(Dataset): def init(self, paths, sfreq, rfreq, scaler: bool = False): super().init() self.x, self.y = self.get_data(paths, sfreq, rfreq, scaler) self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long)
@staticmethod def get_data(paths, sfreq, rfreq, scaler_flag): info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1']) scaler = mne.decoding.Scaler(info=info, scalings='median') total_x, total_y = [], [] for path in paths: # 获取信号文件名的前缀部分(例如 SC4161E0) base_name = os.path.basename(path).split('-')[0] # 提取 "SC4161E0" # 构造对应的注释文件路径,替换 E0 为 EC 和 PSG 为 Hypnogram annotation_path = path.replace(base_name, base_name.replace('E0', 'EC')).replace('PSG', 'Hypnogram') # 检查注释文件是否存在 if not os.path.exists(annotation_path): raise FileNotFoundError(f"Annotation file does not exist: {annotation_path}") # 读取信号文件和注释文件 raw = mne.io.read_raw_edf(path, preload=True) annotations = mne.read_annotations(annotation_path) # 将注释添加到信号文件 raw.set_annotations(annotations) # 提取信号数据 x = raw.get_data().T # 转置为 [samples, channels] y = raw.annotations.description # 注释标签 # 扩展维度以符合网络输入的要求 x = np.expand_dims(x, axis=1) if scaler_flag: x = scaler.fit_transform(x) # 转换为 EpochsArray 进行重采样 x = mne.EpochsArray(x, info=info) x = x.resample(rfreq) x = x.get_data().squeeze() # 获取数据并移除多余的维度 total_x.append(x) total_y.append(y) # 合并所有数据 total_x, total_y = np.concatenate(total_x), np.concatenate(total_y) return total_x, total_y def __len__(self): return len(self.y) def __getitem__(self, item): x = torch.tensor(self.x[item]) y = torch.tensor(self.y[item]) return x, y 我的数据集是sleep-edfx 2013 ,数据是将psg.edf 和hypnogram.psg放在一块,我根据代码报错将data_loader修改成这样,于是接着运行train文件,这个文件一直在运行,只是打印出来一些内容,而且我想打开tensorboard看看,但是打开后,没有任何东西。而且我的日志文件只有1kb,我不知道怎么办了,如果能获得帮助,十分感谢
The text was updated successfully, but these errors were encountered:
No branches or pull requests
import mne
import torch
import random
import numpy as np
from torch.utils.data import Dataset
import warnings
import os
import re
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
random_seed = 777
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)
class TorchDataset(Dataset):
def init(self, paths, sfreq, rfreq, scaler: bool = False):
super().init()
self.x, self.y = self.get_data(paths, sfreq, rfreq, scaler)
self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long)
The text was updated successfully, but these errors were encountered: