Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensorboard打开没有任何内容 #6

Closed
YYrgb opened this issue Dec 21, 2024 · 0 comments
Closed

tensorboard打开没有任何内容 #6

YYrgb opened this issue Dec 21, 2024 · 0 comments

Comments

@YYrgb
Copy link

YYrgb commented Dec 21, 2024

import mne
import torch
import random
import numpy as np
from torch.utils.data import Dataset
import warnings
import os
import re

warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

random_seed = 777
np.random.seed(random_seed)
torch.manual_seed(random_seed)
random.seed(random_seed)

class TorchDataset(Dataset):
def init(self, paths, sfreq, rfreq, scaler: bool = False):
super().init()
self.x, self.y = self.get_data(paths, sfreq, rfreq, scaler)
self.x, self.y = torch.tensor(self.x, dtype=torch.float32), torch.tensor(self.y, dtype=torch.long)

@staticmethod
def get_data(paths, sfreq, rfreq, scaler_flag):
    info = mne.create_info(sfreq=sfreq, ch_types='eeg', ch_names=['Fp1'])
    scaler = mne.decoding.Scaler(info=info, scalings='median')
    total_x, total_y = [], []

    for path in paths:
        # 获取信号文件名的前缀部分(例如 SC4161E0)
        base_name = os.path.basename(path).split('-')[0]  # 提取 "SC4161E0"

        # 构造对应的注释文件路径,替换 E0 为 EC 和 PSG 为 Hypnogram
        annotation_path = path.replace(base_name, base_name.replace('E0', 'EC')).replace('PSG', 'Hypnogram')

        # 检查注释文件是否存在
        if not os.path.exists(annotation_path):
            raise FileNotFoundError(f"Annotation file does not exist: {annotation_path}")

        # 读取信号文件和注释文件
        raw = mne.io.read_raw_edf(path, preload=True)
        annotations = mne.read_annotations(annotation_path)

        # 将注释添加到信号文件
        raw.set_annotations(annotations)

        # 提取信号数据
        x = raw.get_data().T  # 转置为 [samples, channels]
        y = raw.annotations.description  # 注释标签

        # 扩展维度以符合网络输入的要求
        x = np.expand_dims(x, axis=1)

        if scaler_flag:
            x = scaler.fit_transform(x)

        # 转换为 EpochsArray 进行重采样
        x = mne.EpochsArray(x, info=info)
        x = x.resample(rfreq)
        x = x.get_data().squeeze()  # 获取数据并移除多余的维度

        total_x.append(x)
        total_y.append(y)

    # 合并所有数据
    total_x, total_y = np.concatenate(total_x), np.concatenate(total_y)
    return total_x, total_y

def __len__(self):
    return len(self.y)

def __getitem__(self, item):
    x = torch.tensor(self.x[item])
    y = torch.tensor(self.y[item])
    return x, y
   我的数据集是sleep-edfx 2013 ,数据是将psg.edf 和hypnogram.psg放在一块,我根据代码报错将data_loader修改成这样,于是接着运行train文件,这个文件一直在运行,只是打印出来一些内容,而且我想打开tensorboard看看,但是打开后,没有任何东西。而且我的日志文件只有1kb,我不知道怎么办了,如果能获得帮助,十分感谢
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants